From 25913ff3e83ad6b6e3abb6ca010796284e590b23 Mon Sep 17 00:00:00 2001 From: mesarcik Date: Mon, 1 Jun 2026 10:04:42 +0000 Subject: [PATCH 1/6] closed->open --- sync_repos.sh | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100755 sync_repos.sh diff --git a/sync_repos.sh b/sync_repos.sh new file mode 100755 index 0000000..195370c --- /dev/null +++ b/sync_repos.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +usage() { + cat < + +Copies files changed on (vs main) from the closed-source +repo to the open-source repo, preserving directory structure. + +Arguments: + closed_path Path to the closed-source repo + closed_branch Branch in the closed-source repo to sync from + open_path Path to the open-source repo + +Example: + $0 ~/code/closed-source feature-x ~/code/open-source +EOF +} + +if [[ "$1" == "-h" || "$1" == "--help" ]]; then + usage + exit 0 +fi + +if [[ $# -ne 3 ]]; then + usage + exit 1 +fi + +CLOSED_PATH="$1" +CLOSED_BRANCH="$2" +OPEN_PATH="$3" + +cd "$CLOSED_PATH" +git checkout "$CLOSED_BRANCH" +cp --parents $(git diff main --name-only) "$OPEN_PATH" + From 44cdbf5e695ad5aa39a10c069def5f9166d99c13 Mon Sep 17 00:00:00 2001 From: mesarcik Date: Mon, 1 Jun 2026 10:05:12 +0000 Subject: [PATCH 2/6] added missing weather.py --- src/s4casting/model/weather.py | 86 ++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 src/s4casting/model/weather.py diff --git a/src/s4casting/model/weather.py b/src/s4casting/model/weather.py new file mode 100644 index 0000000..ff238e6 --- /dev/null +++ b/src/s4casting/model/weather.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Contributors to the s4casting project +# +# SPDX-License-Identifier: MPL-2.0 + +import torch +from torch import nn + + +class WeatherAuxTask(nn.Module): + """Self-supervised auxiliary task that learns to forecast masked weather features. + + During training, a random trailing fraction of each sequence's weather features + is zeroed out and the model learns to reconstruct those values from the latent + representation. During eval (self.training=False) the mask fraction collapses to + zero so no features are hidden and compute_loss returns a zero scalar. + + Only valid when patch_size=1 (no temporal compression), so the encoder output + is per-timestep and the forecaster can map directly to (B, T, C). + + When n_weather_features=0 both methods are no-ops. + """ + + def __init__(self, latent_dim: int, n_weather_features: int): + """Initialize WeatherAuxTask. + + Args: + latent_dim: Size of the latent representation produced by the patch encoder. + n_weather_features: Number of auxiliary weather channels (features after the + first target feature). Set to 0 to disable. + """ + super().__init__() + self.n_weather_features = n_weather_features + # max(1, ...) avoids nn.Linear(latent_dim, 0) when disabled; the layer is + # never called when n_weather_features=0. + self.weather_forecaster = nn.Linear(latent_dim, max(1, n_weather_features)) + + def prepare(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Randomly mask weather features and return targets for the auxiliary loss. + + Args: + x: Normalised input tensor of shape (B, T, F). + + Returns: + x_masked: x with a random trailing fraction of weather channels zeroed. + weather_gt: Ground-truth values in the masked region, zero elsewhere. Shape (B, T, C). + weather_mask: Binary float mask (1=unmasked, 0=masked). Shape (B, T, C). + """ + B, T = x.shape[:2] + C = self.n_weather_features + + if C == 0 or not self.training: + empty = x.new_empty(B, T, 0) + return x, empty, empty + + fractions = torch.rand(B, device=x.device) + n_masked = (fractions * T).long() + time_idx = torch.arange(T, device=x.device).view(1, T) + weather_mask = (time_idx < (T - n_masked).unsqueeze(1)).float().unsqueeze(-1).expand(B, T, C) + + weather_gt = x[..., 1 : 1 + C].clone() * (1 - weather_mask) + x = x.clone() + x[..., 1 : 1 + C] = x[..., 1 : 1 + C] * weather_mask + return x, weather_gt, weather_mask + + def compute_loss( + self, + x_enc: torch.Tensor, + weather_gt: torch.Tensor, + weather_mask: torch.Tensor, + ) -> torch.Tensor: + """Compute MSE over the masked weather region. + + Args: + x_enc: Per-timestep encoder representation of shape (B, T, E). + weather_gt: Ground-truth weather values in the masked region. Shape (B, T, C). + weather_mask: Binary mask (1=unmasked, 0=masked). Shape (B, T, C). + + Returns: + Scalar MSE loss over the masked region; zero when n_weather_features=0. + """ + if self.n_weather_features == 0: + return x_enc.new_zeros(()) + + weather_forecast = self.weather_forecaster(x_enc) # (B, T, C) + n_masked = (1 - weather_mask).sum().clamp(min=1) + return ((weather_forecast - weather_gt) ** 2 * (1 - weather_mask)).sum() / n_masked From e9f9b9ff1309d7d606cda0f7fa49e1b3eb5847e3 Mon Sep 17 00:00:00 2001 From: mesarcik Date: Mon, 1 Jun 2026 10:08:55 +0000 Subject: [PATCH 3/6] closed->open --- configs/cuda.toml | 1 + notebooks/00_data_preparation.ipynb | 3 +- notebooks/01_train_model.ipynb | 3 +- notebooks/02_inference.ipynb | 3 +- notebooks/03_evaluation.ipynb | 3 +- .../04_train_with_weather_time_location.ipynb | 3 +- notebooks/05_load_forecasting.ipynb | 85 ++++--- scripts/train.py | 2 +- src/s4casting/core/batcher.py | 91 ++++--- src/s4casting/core/benchmarker.py | 32 ++- src/s4casting/core/config.py | 35 ++- src/s4casting/core/evaluator.py | 21 +- src/s4casting/core/functional.py | 11 +- src/s4casting/core/loss.py | 51 +++- src/s4casting/core/tasks.py | 120 ++++++--- src/s4casting/core/trainer.py | 16 +- src/s4casting/data/utils.py | 227 ++++++++++-------- src/s4casting/eval/evaluator_heads.py | 21 +- src/s4casting/eval/metrics.py | 6 +- src/s4casting/factories/model_container.py | 26 +- src/s4casting/factories/trainer.py | 8 +- src/s4casting/model/mamba.py | 14 +- src/s4casting/model/mambacpu.py | 9 +- src/s4casting/model/ss.py | 35 ++- sync_repos.sh | 4 + tests/utils.py | 2 +- 26 files changed, 545 insertions(+), 287 deletions(-) diff --git a/configs/cuda.toml b/configs/cuda.toml index 02bc90b..559272f 100644 --- a/configs/cuda.toml +++ b/configs/cuda.toml @@ -33,6 +33,7 @@ latent_dim = 1024 [model.loss] loss = "nll" sigma_regularisation_factor = 1 +components = {primary = 1.0, weather = 1.0} [model.ssm] kernel = "s6" diff --git a/notebooks/00_data_preparation.ipynb b/notebooks/00_data_preparation.ipynb index 5ce52df..d7fe3eb 100644 --- a/notebooks/00_data_preparation.ipynb +++ b/notebooks/00_data_preparation.ipynb @@ -7,8 +7,7 @@ "source": [ "# S4Casting: Data Preparation\n", "\n", - "This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/.", - "\n", + "This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/.\n", "This notebook contains scripts and utilities for preparing and formatting datasets for use with the S4Casting framework. Proper data formatting is crucial for effective model training and evaluation.\n", "\n", "## Overview\n", diff --git a/notebooks/01_train_model.ipynb b/notebooks/01_train_model.ipynb index 1668b99..e058724 100644 --- a/notebooks/01_train_model.ipynb +++ b/notebooks/01_train_model.ipynb @@ -7,8 +7,7 @@ "source": [ "# S4Casting: Training Pipeline\n", "\n", - "This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/.", - "\n", + "This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/.\n", "This notebook demonstrates the **training pipeline**, as in *scripts/train.py*.\n", "It walks through loading the configuration, setting up the model, optimizer, trainer, and running the training loop.\n", "\n", diff --git a/notebooks/02_inference.ipynb b/notebooks/02_inference.ipynb index 1391e5a..a8e844e 100644 --- a/notebooks/02_inference.ipynb +++ b/notebooks/02_inference.ipynb @@ -7,8 +7,7 @@ "source": [ "# Inference Pipeline\n", "\n", - "This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/.", - "\n", + "This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/.\n", "This notebook demonstrates how to perform **inference** using the previously trained S4Casting model. \n", "The goal is to take recent time-series data (e.g., energy measurements) and generate short- or medium-term forecasts.\n", "\n", diff --git a/notebooks/03_evaluation.ipynb b/notebooks/03_evaluation.ipynb index 880b0fb..7e81996 100644 --- a/notebooks/03_evaluation.ipynb +++ b/notebooks/03_evaluation.ipynb @@ -7,8 +7,7 @@ "source": [ "# Evaluation and Benchmarking of Model Forecasts\n", "\n", - "This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/.", - "\n", + "This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/.\n", "This notebook demonstrates how to evaluate and visualize the results of a previously trained **S4Casting** model. \n", "It covers:\n", "1. Loading the saved inference output from disk. \n", diff --git a/notebooks/04_train_with_weather_time_location.ipynb b/notebooks/04_train_with_weather_time_location.ipynb index 48f06e4..752f2b3 100644 --- a/notebooks/04_train_with_weather_time_location.ipynb +++ b/notebooks/04_train_with_weather_time_location.ipynb @@ -7,8 +7,7 @@ "source": [ "# Training with extra features: weather, time, and location\n", "\n", - "This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/.", - "\n", + "This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/.\n", "This notebook demonstrates how to train a model with additional contextual features:\n", "\n", "* Weather data aligned to your time series locations \n", diff --git a/notebooks/05_load_forecasting.ipynb b/notebooks/05_load_forecasting.ipynb index 8d1e915..b13dac3 100644 --- a/notebooks/05_load_forecasting.ipynb +++ b/notebooks/05_load_forecasting.ipynb @@ -5,14 +5,13 @@ "id": "0", "metadata": {}, "source": [ - "This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/.", - "\n" + "This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/.\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "0", + "id": "1", "metadata": {}, "outputs": [], "source": [ @@ -40,7 +39,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1", + "id": "2", "metadata": {}, "outputs": [], "source": [ @@ -93,7 +92,7 @@ }, { "cell_type": "markdown", - "id": "2", + "id": "3", "metadata": {}, "source": [ "## s4casting with historical load measurements\n", @@ -113,7 +112,7 @@ }, { "cell_type": "markdown", - "id": "3", + "id": "4", "metadata": {}, "source": [ "## Build model\n", @@ -124,7 +123,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4", + "id": "5", "metadata": {}, "outputs": [], "source": [ @@ -175,7 +174,7 @@ }, { "cell_type": "markdown", - "id": "5", + "id": "6", "metadata": {}, "source": [ "In the image below, you see how this model could be illustrated in a children's story book. \n", @@ -199,7 +198,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6", + "id": "7", "metadata": {}, "outputs": [], "source": [ @@ -208,7 +207,7 @@ }, { "cell_type": "markdown", - "id": "7", + "id": "8", "metadata": {}, "source": [ "The entire model architecture can be broken down into four components - \n", @@ -227,7 +226,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8", + "id": "9", "metadata": {}, "outputs": [], "source": [ @@ -238,7 +237,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9", + "id": "10", "metadata": {}, "outputs": [], "source": [ @@ -253,7 +252,7 @@ }, { "cell_type": "markdown", - "id": "10", + "id": "11", "metadata": {}, "source": [ "## Build dataset\n", @@ -267,7 +266,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "12", "metadata": {}, "outputs": [], "source": [ @@ -294,7 +293,7 @@ }, { "cell_type": "markdown", - "id": "12", + "id": "13", "metadata": {}, "source": [ "X and Y have the same ```B * L * F``` shape, and so do the masks output by ```PredictionTaskDataset```. \n", @@ -305,7 +304,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "14", "metadata": {}, "outputs": [], "source": [ @@ -333,7 +332,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -365,7 +364,7 @@ }, { "cell_type": "markdown", - "id": "15", + "id": "16", "metadata": {}, "source": [ "## Example model prediction" @@ -374,7 +373,7 @@ { "cell_type": "code", "execution_count": null, - "id": "16", + "id": "17", "metadata": {}, "outputs": [], "source": [ @@ -400,7 +399,7 @@ }, { "cell_type": "markdown", - "id": "17", + "id": "18", "metadata": {}, "source": [ "As this model is probablistic, instead of predicting a point value, the model predicts the value in different quantiles. \n", @@ -419,7 +418,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -433,7 +432,7 @@ { "cell_type": "code", "execution_count": null, - "id": "19", + "id": "20", "metadata": {}, "outputs": [], "source": [ @@ -461,7 +460,7 @@ { "cell_type": "code", "execution_count": null, - "id": "20", + "id": "21", "metadata": {}, "outputs": [], "source": [ @@ -501,7 +500,7 @@ }, { "cell_type": "markdown", - "id": "21", + "id": "22", "metadata": {}, "source": [ "The green lines are the models prediction of the 50th quantile. This is from our untrained model with randomly initialized weights, so of course these predictions don't make no sense. We're going to change that!" @@ -509,7 +508,7 @@ }, { "cell_type": "markdown", - "id": "22", + "id": "23", "metadata": {}, "source": [ "## Let's train the model\n", @@ -531,7 +530,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23", + "id": "24", "metadata": {}, "outputs": [], "source": [ @@ -594,7 +593,7 @@ }, { "cell_type": "markdown", - "id": "24", + "id": "25", "metadata": {}, "source": [ "# How does the model do after 1000 steps?\n", @@ -605,7 +604,7 @@ { "cell_type": "code", "execution_count": null, - "id": "25", + "id": "26", "metadata": {}, "outputs": [], "source": [ @@ -657,7 +656,7 @@ { "cell_type": "code", "execution_count": null, - "id": "26", + "id": "27", "metadata": {}, "outputs": [], "source": [ @@ -697,7 +696,7 @@ }, { "cell_type": "markdown", - "id": "27", + "id": "28", "metadata": {}, "source": [ "The model that has been trained for 1000 iterations has learned the broad/rough measurement patterns. \n", @@ -716,7 +715,7 @@ { "cell_type": "code", "execution_count": null, - "id": "28", + "id": "29", "metadata": {}, "outputs": [], "source": [ @@ -789,7 +788,7 @@ { "cell_type": "code", "execution_count": null, - "id": "29", + "id": "30", "metadata": {}, "outputs": [], "source": [ @@ -823,7 +822,7 @@ }, { "cell_type": "markdown", - "id": "30", + "id": "31", "metadata": {}, "source": [ "In the close up on prediction, we see that the median prediction tracks the ground truth reasonably well (for most of the samples at least, for certain signals the model just fails because it has not been trained long enough)." @@ -831,7 +830,7 @@ }, { "cell_type": "markdown", - "id": "31", + "id": "32", "metadata": {}, "source": [ "## What about a model that has been trained for 100_000 steps?\n", @@ -843,7 +842,7 @@ { "cell_type": "code", "execution_count": null, - "id": "32", + "id": "33", "metadata": {}, "outputs": [], "source": [ @@ -900,7 +899,7 @@ }, { "cell_type": "markdown", - "id": "33", + "id": "34", "metadata": {}, "source": [ "Now let's load in the same sample that we plotted above to see how much our model has improved." @@ -909,7 +908,7 @@ { "cell_type": "code", "execution_count": null, - "id": "34", + "id": "35", "metadata": {}, "outputs": [], "source": [ @@ -947,7 +946,7 @@ { "cell_type": "code", "execution_count": null, - "id": "35", + "id": "36", "metadata": {}, "outputs": [], "source": [ @@ -1023,7 +1022,7 @@ }, { "cell_type": "markdown", - "id": "36", + "id": "37", "metadata": {}, "source": [ "It depends on which sample you're plotting, for most signals, you should see improvements after training for longer. However, for solar/wind related signals, the improvements will be minimal because the model has not been trained to associate the load with weather data.\n", @@ -1034,7 +1033,7 @@ { "cell_type": "code", "execution_count": null, - "id": "37", + "id": "38", "metadata": {}, "outputs": [], "source": [ @@ -1086,7 +1085,7 @@ { "cell_type": "code", "execution_count": null, - "id": "38", + "id": "39", "metadata": {}, "outputs": [], "source": [ @@ -1162,7 +1161,7 @@ }, { "cell_type": "markdown", - "id": "39", + "id": "40", "metadata": {}, "source": [ "You should see that the model is better able to follow the higher frequency changes (the 'spikiness'/volatility of targets) than with the model that's only trained for 1000 steps. At 10k steps, the model better captures the timing, shape and amplitude of peaks. Though for signals which don't have a clear daily cycles, the model performs less well. This is something we can improve by making the context bigger (e.g. instead of 7 days, perhaps 30 days) - effectively allowing the model to see further back. \n", @@ -1174,7 +1173,7 @@ }, { "cell_type": "markdown", - "id": "40", + "id": "41", "metadata": {}, "source": [ "## Conclusion: s4casting\n", diff --git a/scripts/train.py b/scripts/train.py index 9aa03b6..5db66d3 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -51,7 +51,7 @@ def train(config: Configuration): optimizer = fc.provide_optimizer(config.optimizer, model_container.raw_model.parameters()) scheduler = fc.provide_scheduler(config.scheduler, optimizer) - trainer = fc.provide_trainer(config=config.training, io=config.io, optimizer=config.optimizer, machine=machine) + trainer = fc.provide_trainer(config=config.training, optimizer=config.optimizer, machine=machine) checkpointer = fc.provide_checkpointer(config.io, trainer.hooks) evaluator_head = fc.provide_evaluator_head(config.model, trainer.hooks) evaluator = fc.provide_evaluation(trainer.hooks, evaluator_head) diff --git a/src/s4casting/core/batcher.py b/src/s4casting/core/batcher.py index b19e602..87f01c7 100644 --- a/src/s4casting/core/batcher.py +++ b/src/s4casting/core/batcher.py @@ -3,12 +3,12 @@ # SPDX-License-Identifier: MPL-2.0 from collections import defaultdict -from itertools import product import numpy as np import torch from numpy.typing import NDArray -from torch.utils.data import ConcatDataset +from torch.utils.data import ConcatDataset, RandomSampler +from torch.utils.data.distributed import DistributedSampler from s4casting.core.config import ( BenchmarkingConfiguration, @@ -35,7 +35,7 @@ intervals_for_year, substract, ) -from s4casting.data.utils import ConcatDatasetSampler, collate_single_interval +from s4casting.data.utils import build_valid_context_sampling_pairs, collate_single_interval class Batcher: @@ -70,12 +70,16 @@ def __init__( io_config, model_config, self.datasets_per_source ) - context_sample_rates = list( - product( - model_config.context_window, - model_config.input_sample_intervals_minutes, - ) + valid_context_windows = build_valid_context_sampling_pairs( + context_days=model_config.context_window, + sample_intervals_minutes=model_config.input_sample_intervals_minutes, + min_points=32, + max_context_len=model_config.ssm.mixer_size if model_config.ssm is not None else None, + interval_context_limits=io_config.interval_context_limits, ) + context_sample_rates = [ + (r["context_days"], r["sample_interval_minutes"]) for r in valid_context_windows["valid_pairs"] + ] # check if we have specified local benchmarks and remove them from trianing set if so local_bench = bench_config.benchmarks.get("LocalBenchmark") @@ -98,12 +102,11 @@ def __init__( # Fill gaps for each combination of context window and sample rate non_benchmark_intervals = [] for context_days, sample_rate in context_sample_rates: - sample_rate_factor = sample_rate / model_config.base_sample_interval_minutes context_window_minutes = context_days * 24 * 60 _non_benchmark_intervals = fill_gaps( non_benchmark_intervals_temp, - io_config.gap_skip_hours, - context_days * sample_rate_factor, + int(24 * context_days * (io_config.gap_skip_perc / 100)), + context_days, io_config.context_window_valid_ratio, ) _non_benchmark_intervals = add_duration(_non_benchmark_intervals, -context_window_minutes * 60) @@ -126,6 +129,8 @@ def __init__( self.train = self._get_task_dataset( train_config.task, ConcatDataset(self.train), + valid_context_windows["recommended_max_context_samples"], + train_config.max_retries, model_config.alignment, model_config.base_sample_interval_minutes, model_config.predict_width, @@ -134,6 +139,8 @@ def __init__( self.validation = self._get_task_dataset( train_config.task, ConcatDataset(self.validation), + valid_context_windows["recommended_max_context_samples"], + train_config.max_retries, model_config.alignment, model_config.base_sample_interval_minutes, model_config.predict_width, @@ -279,6 +286,8 @@ def create_data_loaders( ) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]: """Create data loaders for training and validation. + Note: we offset the seed by iteration, this should only be triggered when resuming training. + Args: train_config (TrainingConfiguration): Training configuration. machine (Machine): Machine configuration. @@ -287,36 +296,46 @@ def create_data_loaders( Returns: tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]: Training and validation data loaders. """ - # Create a sampler to properly distribute the data across multiple GPUs - ddp_kwargs = {"num_replicas": machine.world_size, "rank": machine.ddp.global_rank} if machine.ddp else {} - - train_sampler = ConcatDatasetSampler( - self.train_ds_lengths, - train_config.batch_size, - drop_last=True, - seed=run_config.seed, - **ddp_kwargs, - ) + if machine.ddp: + train_sampler = DistributedSampler( + self.train, + num_replicas=machine.world_size, + rank=machine.ddp.global_rank, + shuffle=True, + seed=run_config.seed + train_config.iteration, + drop_last=True, + ) + validation_sampler = DistributedSampler( + self.validation, + num_replicas=machine.world_size, + rank=machine.ddp.global_rank, + shuffle=True, + seed=run_config.seed, + drop_last=False, + ) + else: + train_sampler = RandomSampler( + self.train, generator=torch.Generator().manual_seed(run_config.seed + train_config.iteration) + ) + validation_sampler = RandomSampler( + self.validation, generator=torch.Generator().manual_seed(run_config.seed + 1028) + ) train_loader = torch.utils.data.DataLoader( self.train, - batch_sampler=train_sampler, + batch_size=train_config.batch_size, + sampler=train_sampler, + drop_last=True, collate_fn=collate_single_interval, - num_workers=4, + num_workers=8, persistent_workers=True, pin_memory=True, ) - validation_sampler = ConcatDatasetSampler( - self.validation_ds_lengths, - train_config.batch_size, - drop_last=False, - seed=run_config.seed, - ) - validation_loader = torch.utils.data.DataLoader( self.validation, - batch_sampler=validation_sampler, + batch_size=train_config.batch_size, + sampler=validation_sampler, collate_fn=collate_single_interval, num_workers=2, persistent_workers=True, @@ -329,6 +348,8 @@ def create_data_loaders( def _get_task_dataset( task_name: str, dataset: torch.utils.data.Dataset, + max_context_samples: int, + max_retries: int, alignment: int, sample_rate: int, predict_width: int | tuple[float, float] = 2, @@ -338,6 +359,8 @@ def _get_task_dataset( Args: task_name (str): The name of the task ("prediction", "masking", or "randomprediction"). dataset: The dataset to wrap in the task dataset. + max_context_samples: Maximum number for context_samples to zero pad to. + max_retries (int) : Maximum number of retries for rejection sampling. alignment (int): data alignment. sample_rate (int): base sample rate of dataset. predict_width (int or tuple[float, float]): prediction window if using prediction task. @@ -350,6 +373,8 @@ def _get_task_dataset( raise ValueError("For a prediction task prediction_width must be an int.") return PredictionTaskDataset( dataset, + max_context_samples, + max_retries, 0, (predict_width * 24 * 60) // sample_rate, # type: ignore ) @@ -357,6 +382,8 @@ def _get_task_dataset( if task_name == "masking": return RandomMaskingTaskDataset( dataset, + max_context_samples, + max_retries, alignment // sample_rate, ) # type: ignore @@ -365,6 +392,8 @@ def _get_task_dataset( raise ValueError("Task 'randomprediction' requires predict width of list.") return VariablePredictionTaskDataset( dataset, + max_context_samples, + max_retries, predict_dim=0, min_predict_width_perc=predict_width[0], max_predict_width_perc=predict_width[1], diff --git a/src/s4casting/core/benchmarker.py b/src/s4casting/core/benchmarker.py index 448e545..d249ca9 100644 --- a/src/s4casting/core/benchmarker.py +++ b/src/s4casting/core/benchmarker.py @@ -87,13 +87,19 @@ def sample_dataset(self, location: str, context: Context): context.batcher.benchmark.datas, # type: ignore[attr-defined] self.bench_config.input_sample_interval_minutes * 60, ) - predict_window = ( + predict_window_samples = ( self.bench_config.predict_window_days * 24 * 60 ) // self.bench_config.input_sample_interval_minutes + + context_window_samples = ( + self.bench_config.context_window_days * 24 * 60 + ) // self.bench_config.input_sample_interval_minutes location_dataset = PredictionTaskDataset( dataset, + context_window_samples, + 0, self.bench_config.predict_dim, # type: ignore[attr-defined] - predict_window, + predict_window_samples, ) times = get_timestamps( @@ -138,10 +144,12 @@ def benchmark(self, context: Context, iteration) -> None: context.benchmark_location = location (X, xm, Y, ym), times, sign, sample_config = self.sample_dataset(location, context) output_interval = ( - self.bench_config.output_sample_interval_minutes + torch.ones_like(sample_config.sample_interval_minutes) + * self.bench_config.output_sample_interval_minutes if self.bench_config.output_sample_interval_minutes is not None else sample_config.sample_interval_minutes ) + prediction, loss = run_in_batches( benchmarking_model, context.configuration.training.batch_size, @@ -162,8 +170,22 @@ def benchmark(self, context: Context, iteration) -> None: times=times, sign=sign, report_type="benchmark", - sample_config=sample_config, - output_interval=output_interval, + predict_window_days=int( + ( + sample_config.predict_window_samples[0].item() + * sample_config.sample_interval_minutes[0].item() + ) + // (24 * 60) + ), + context_window_days=int( + ( + sample_config.context_window_samples[0].item() + * sample_config.sample_interval_minutes[0].item() + ) + // (24 * 60) + ), + input_interval=sample_config.sample_interval_minutes[0].item(), + output_interval=output_interval[0].item(), n_day_ahead=self.bench_config.n_day_ahead, ) context.model_container.model.train(mode=True) diff --git a/src/s4casting/core/config.py b/src/s4casting/core/config.py index f6ae414..9c464ff 100644 --- a/src/s4casting/core/config.py +++ b/src/s4casting/core/config.py @@ -116,6 +116,14 @@ class LossConfiguration(BaseModel): 0.0, description="Prevents sigma from growing too big", ) + components: dict[str, float] = Field( + default_factory=lambda: {"primary": 1.0}, + description=( + "Named loss components and their scalar weights. " + "Keys must match terms accumulated in the model forward pass. " + "Example: {primary = 1.0, weather = 0.5}" + ), + ) class ChronosConfiguration(BaseModel): @@ -152,6 +160,10 @@ class SSMConfiguration(BaseModel): 4, description="Number of stacked layers in the SSM/GRU block.", ) + mixer_size: PositiveInt | None = Field( + None, + description="Whether to do time domain mixing, limits maximum model context width.", + ) class TransformerConfiguration(BaseModel): @@ -335,11 +347,24 @@ class IOConfiguration(BaseModel): features: dict[str, DatasetConfiguration] = Field(..., description="Dictionary of (feature) datasets.") output: str = Field(..., description="Location to save outputs.") load_checkpoint: str | None = Field(None, description="Path to load checkpoint from.") - iteration: int = Field(0, description="Current iteration number.") - gap_skip_hours: int = Field(1, description="Number of hours to skip for gaps.") + gap_skip_perc: PositiveInt = Field( + 5, ge=0, le=100, description="Percentage of context window of hours to skip for gaps." + ) context_window_valid_ratio: float = Field(0.8, description="Valid ratio for input window.") hash_datasets: bool = Field(False, description="Whether to hash the datasets to be logged") to_memory: bool = Field(False, description="Whether to move the memmap'd data to CPU memory") + interval_context_limits: dict[int, dict[str, int]] = Field( + default={ + 5: {"min_days": 7, "max_days": 10}, + 10: {"min_days": 7, "max_days": 14}, + 15: {"min_days": 11, "max_days": 32}, + 30: {"min_days": 16, "max_days": 64}, + 60: {"min_days": 16, "max_days": 64}, + 1440: {"min_days": 32, "max_days": 364}, + 10080: {"min_days": 64, "max_days": 364}, + }, + description="Pairs of valid interval and min max days in the context windows", + ) class StefBeamBenchmark(BaseModel): @@ -419,6 +444,11 @@ class TrainingConfiguration(BaseModel): description="Training task : 'prediction' (fixed window), " "'masking' (random masking), or 'randomprediction' (random prediction window percentage).", ) + max_retries: NonNegativeInt = Field( + 10, + description="Max retries for rejection sampling, invalid due to changing prediction/context windows." + "Default 0 means no rejection sampling.", + ) gradient_accumulation_steps: int = Field(2, description="Number of gradient accumulation steps.") batch_size: int = Field(32, description="Batch size for training.") @@ -429,6 +459,7 @@ class TrainingConfiguration(BaseModel): n_samples_per_epoch: PositiveInt | None = Field( default=None, init=False, description="Populated at runtime with the number of samples per epoch." ) + iteration: int = Field(1, description="Current iteration number. Set to 0 for zero-shot performance.") def get_data_range_spans(dataset_dict: DatasetConfiguration) -> tuple[datetime, datetime]: diff --git a/src/s4casting/core/evaluator.py b/src/s4casting/core/evaluator.py index 08cf1e9..f479fb5 100644 --- a/src/s4casting/core/evaluator.py +++ b/src/s4casting/core/evaluator.py @@ -64,8 +64,8 @@ def evaluate(self, context: Context, iteration: int) -> None: prediction, loss = evaluation_model(X, Xm, sample_config.sample_interval_minutes, output_interval, Y, Ym) context.validation_loss = loss.item() - context.input_validation_sample_rate = sample_config.sample_interval_minutes - context.output_validation_sample_rate = output_interval + context.input_validation_sample_rate = sample_config.sample_interval_minutes[0].item() + context.output_validation_sample_rate = output_interval[0].item() self.head_evaluator.report( context=context, @@ -77,9 +77,20 @@ def evaluate(self, context: Context, iteration: int) -> None: loss=context.validation_loss, iteration=iteration, report_type="evaluation", - sample_config=sample_config, - output_interval=output_interval, - n_day_ahead=sample_config.predict_window_days[0].item(), + output_interval=output_interval[0].item(), + n_day_ahead=int( + (sample_config.predict_window_samples[0].item() * sample_config.sample_interval_minutes[0].item()) + // (24 * 60) + ), + predict_window_days=int( + (sample_config.predict_window_samples[0].item() * sample_config.sample_interval_minutes[0].item()) + // (24 * 60) + ), + context_window_days=int( + (sample_config.context_window_samples[0].item() * sample_config.sample_interval_minutes[0].item()) + // (24 * 60) + ), + input_interval=sample_config.sample_interval_minutes[0].item(), ) context.model_container.model.train(mode=True) diff --git a/src/s4casting/core/functional.py b/src/s4casting/core/functional.py index 9edaf70..c301379 100644 --- a/src/s4casting/core/functional.py +++ b/src/s4casting/core/functional.py @@ -132,14 +132,16 @@ def resample(data: torch.Tensor, patch_size, maxpool=True) -> torch.Tensor: def select_rate( - input_rate: int, + input_rate: torch.Tensor, output_sample_intervals_minutes: list[int], -) -> int: + transcoding: bool = False, +) -> torch.Tensor: """Randomly choose an output sample interval that is greater than or equal to the given input sample interval. Args: input_rate (int): input_sample rate for batch. output_sample_intervals_minutes(list[int]): Possible output sample rate. + transcoding (bool): Determines if input and output rates can be different. Returns: int : selected sample rate. @@ -147,6 +149,11 @@ def select_rate( Raises: ValueError: If no valid output sample interval exists. """ + if not transcoding: + return input_rate + + raise ValueError("Transcoding currently unsupported.") + valid_rates = [rate for rate in output_sample_intervals_minutes if rate >= input_rate] if not valid_rates: diff --git a/src/s4casting/core/loss.py b/src/s4casting/core/loss.py index 3818b35..b6b2662 100644 --- a/src/s4casting/core/loss.py +++ b/src/s4casting/core/loss.py @@ -10,6 +10,34 @@ from torch import nn +class CompositeLoss(nn.Module): + """Combines named scalar loss terms with per-component weights. + + Weights are specified in config under ``model.loss.components``. Any key + present in the ``losses`` dict passed to ``forward`` that has no entry in + ``weights`` is silently ignored (opt-in model). Any configured key that is + absent from ``losses`` at call time is also skipped gracefully, which covers + cases where a loss term is only active for certain inputs (e.g. weather loss + requires multi-feature data). + + Adding a new loss term requires: + 1. One line in the model forward: ``_losses["new_term"] = value`` + 2. One config entry: ``components = {primary = 1.0, new_term = 0.5}`` + """ + + def __init__(self, weights: dict[str, float]): + """Init doctring for linting.""" + super().__init__() + self.weights = weights + + def forward(self, losses: dict[str, torch.Tensor]) -> torch.Tensor: + """Return the weighted sum of all loss terms present in both dicts.""" + total = sum(self.weights[k] * v for k, v in losses.items() if k in self.weights and self.weights[k] != 0.0) + if isinstance(total, int): + raise ValueError(f"No active loss terms found. Configured: {self.weights}, received: {list(losses)}") + return total + + class SoftClipLoss(nn.Module): """Soft clipping loss function.""" @@ -61,8 +89,8 @@ def forward( self, out: torch.Tensor, target: torch.Tensor, - input_interval: int, - output_interval: int, + input_interval: torch.Tensor, + output_interval: torch.Tensor, mask: torch.Tensor | None = None, ): """Calculates the nll loss assuming a out is a tensor containing GMM parameters. @@ -74,8 +102,8 @@ def forward( (batch_size, seq_len, n_out_features, or , or 1). target (torch.Tensor): Tensor with one dimension less than the GMM parameters. Shape: (batch_size, seq_len, n_out_features). - input_interval (int): Input sample rate of eval step. - output_interval (int): Output sample rate of eval step. + input_interval (torch.Tensor): Input sample rate of eval step. + output_interval (torch.Tensor): Output sample rate of eval step. mask (torch.Tensor, optional): Tensor determining on which values to calculate the loss. 1 = calculate, 0 = ignore. Shape: same as `target`. If no mask is supplied, the loss will be calculated on the entire signal. @@ -95,17 +123,22 @@ def forward( # Note: clone is necessary for gradient flow logpi, sigma, mu = (t.unsqueeze(2).clone() for t in out.unbind(dim=-1)) # -> (B, T, 1, 1, D) + # Check that the ratio between input and output intervals are the same, such that the reshape is consistent + ratio = output_interval // input_interval + if not all(ratio[0] == r for r in ratio): + raise ValueError("Mixed ratios of sample rates in model forward") + # reshape such that a number of samples fit inside each distribution target = rearrange( target * mask, "b (t s) f -> b t s f", - s=output_interval // input_interval, + s=ratio[0].item(), f=target.shape[-1], ).unsqueeze(-1) mask = rearrange( mask, "b (t s) f -> b t s f", - s=output_interval // input_interval, + s=ratio[0].item(), f=mask.shape[-1], ) @@ -144,8 +177,8 @@ def forward( self, out: torch.Tensor, target: torch.Tensor, - input_interval: int, # noqa - output_interval: int, # noqa + input_interval: torch.Tensor, # noqa + output_interval: torch.Tensor, # noqa mask: torch.Tensor | None = None, ): """Calculates the pinball loss. @@ -154,7 +187,7 @@ def forward( out (torch.Tensor): Model output of shape (batch_size, seq_len, n_out_features, ,). target (torch.Tensor): Target tensor of shape (batch_size, seq_len, n_out_features). mask (torch.Tensor, optional): Mask tensor determining on which values to calculate the loss. - input_interval (int): Input sample rate of eval step. (Not used) + input_interval (torch.Tensor): Input sample rate of eval step. (Not used) output_interval (int): Output sample rate of eval step. (Not used) Returns: diff --git a/src/s4casting/core/tasks.py b/src/s4casting/core/tasks.py index 0f764ff..c091c62 100644 --- a/src/s4casting/core/tasks.py +++ b/src/s4casting/core/tasks.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: MPL-2.0 +import warnings from collections import namedtuple import torch @@ -14,13 +15,17 @@ class TaskDataset: """Dataset wrapper that provides input and output masks for each sample.""" - def __init__(self, dataset): + def __init__(self, dataset, max_context_samples, max_retries): """Initialize the TaskDataset. Args: dataset: The underlying dataset to wrap. + max_context_samples: maximum context witdth. + max_retries: Number of retries at a diffent index if data is not valid. """ self.dataset = dataset + self.max_context_samples = max_context_samples + self.max_retries = max(1, max_retries) self.predict_window_samples = None self.predict_dim = None @@ -32,12 +37,11 @@ def __len__(self): """ return len(self.dataset) - def get_masks(self, sample, _sample_interval): + def get_masks(self, sample): """Get the input and output masks for a given sample. Args: sample: The sample for which to get the masks. - _sample_interval: The sample sample_interval. Returns: tuple: A tuple containing the input mask and output mask. @@ -83,57 +87,104 @@ def valid_predict_window( return not scaled_offset > offset_threshold + def zero_pad( + self, X: torch.Tensor, xm: torch.Tensor, ym: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Zero pad context to maximum width. + + Note: only works if max context samples is different then X.shape[0]. + + Args: + X (torch.Tensor): Sample to be padded. + xm (torch.Tensor): Sample context mask to be padded. + ym (torch.Tensor): Sample prediction mask to be padded. + + Returns: + X (torch.Tensor): Padded sample. + xm (torch.Tensor): Padded context mask. + ym (torch.Tensor): Padded prediction mask. + + """ + _diff = self.max_context_samples - X.shape[0] + + if _diff == 0: + return (X, xm, ym) + if _diff < 0: + raise ValueError("Context width greater than pad width") + + zeros = torch.zeros([_diff, X.shape[1]], device=X.device) + X = torch.cat([zeros, X], dim=0) + xm = torch.cat([zeros, xm], dim=0) + ym = torch.cat([zeros, ym], dim=0) + return (X, xm, ym) + def __getitem__(self, idx): """Get the task sample at the specified index. + Note that this function does rejection sampling if max_retries > 0. + Args: idx: The index of the sample to retrieve. Returns: TaskSample: A named tuple containing the input data, input mask, output data, and output mask. + SampleConfig: A named tuple of the configuration Parameters for a givem sample. """ - X, sample_config = self.dataset[idx] - xm, ym = self.get_masks(X, sample_config.sample_interval_minutes) - # A hack so that prediction window is accessible downstream - sample_config.predict_window_days = (self.predict_window_samples * sample_config.sample_interval_minutes) // ( - 24 * 60 - ) - - # Note we only do this according to prediction window split - # But could theoretically be done with random masking as well. - if isinstance(self, (PredictionTaskDataset, VariablePredictionTaskDataset)) and not self.valid_predict_window( - X - ): - # i.e. if no valid then mask the whole sample - ym = torch.zeros_like(ym) - xm = torch.zeros_like(ym) - - # set prediction window - - return TaskSample(torch.nan_to_num(X) * xm, xm, torch.nan_to_num(X.detach().clone()) * ym, ym), sample_config + for attempt in range(self.max_retries): + X, sample_config = self.dataset[idx] + xm, ym = self.get_masks(X) + sample_config.predict_window_samples = self.predict_window_samples + sample_config.context_window_samples = self.max_context_samples + + if xm.sum() < 10 or ym.sum() < 10: + idx = torch.randint(len(self.dataset), (1,)).item() + continue + + if isinstance( + self, (PredictionTaskDataset, VariablePredictionTaskDataset) + ) and not self.valid_predict_window(X): + if attempt == self.max_retries - 1: + warnings.warn( + f"Could not find a valid sample after {self.max_retries} attempts, " + f"returning zero-masked sample at idx={idx}." + f"\n{sample_config}" + ) + break + idx = torch.randint(len(self.dataset), (1,)).item() + continue + + break # Valid sample found, exit loop + X, xm, ym = self.zero_pad(X, xm, ym) + return TaskSample( + torch.nan_to_num(X) * xm, + xm, + torch.nan_to_num(X.detach().clone()) * ym, + ym, + ), sample_config class PredictionTaskDataset(TaskDataset): """Dataset wrapper for prediction tasks.""" - def __init__(self, dataset, predict_dim, predict_window_samples): + def __init__(self, dataset, max_context_samples, max_retries, predict_dim, predict_window_samples): """Initialize the PredictionTaskDataset. Args: dataset: The underlying dataset to wrap. + max_context_samples: The max number of samples in a sample. + max_retries: Number of retries at a different index if data is not valid. predict_dim: The dimension to predict. predict_window_samples: The window size for prediction. """ - super().__init__(dataset) + super().__init__(dataset, max_context_samples, max_retries) self.predict_window_samples = predict_window_samples self.predict_dim = predict_dim - def get_masks(self, sample, _sample_interval): + def get_masks(self, sample): """Get the input and output masks for prediction tasks. Args: sample: The sample for which to get the masks. - _sample_interval: The sample sample_interval. Returns: tuple: A tuple containing the input mask and output mask. @@ -158,18 +209,22 @@ class VariablePredictionTaskDataset(TaskDataset): of the sample length each time a sample is retrieved. """ - def __init__(self, dataset, predict_dim, min_predict_width_perc, max_predict_width_perc): + def __init__( + self, dataset, max_context_samples, max_retries, predict_dim, min_predict_width_perc, max_predict_width_perc + ): """Initialize the VariablePredictionTaskDataset. Args: dataset: The underlying dataset to wrap. + max_context_samples: maximum context witdth. + max_retries: Number of retries at a different index if data is not valid. predict_dim: The dimension to predict. min_predict_width_perc: Minimum prediction window as a percentage of sample length (0.0 to 1.0). max_predict_width_perc: Maximum prediction window as a percentage of sample length (0.0 to 1.0). """ - super().__init__(dataset) + super().__init__(dataset, max_context_samples, max_retries) self.predict_dim = predict_dim self.min_predict_width_perc = min_predict_width_perc self.max_predict_width_perc = max_predict_width_perc @@ -211,25 +266,26 @@ def get_masks(self, sample, sample_interval): class RandomMaskingTaskDataset(TaskDataset): """Dataset wrapper that applies random masking to samples.""" - def __init__(self, dataset, min_mask_size, mask_fraction=0.3): + def __init__(self, dataset, max_context_samples, max_retries, min_mask_size, mask_fraction=0.3): """Initialize the RandomMaskingTaskDataset. Args: dataset: The dataset whose samples are to be masked. + max_context_samples: maximum context witdth. + max_retries: Number of retries at a different index if data is not valid. min_mask_size: The min_mask_size in samples. Should be a multiple of the model's `patch_size`. mask_fraction: The fraction of mask samples. """ - super().__init__(dataset) + super().__init__(dataset, max_context_samples, max_retries) self.min_mask_size = min_mask_size self.mask_fraction = mask_fraction - def get_masks(self, sample, _sample_interval): + def get_masks(self, sample): """Get the input and output masks with random masking. Args: sample: The sample for which to get the masks. - _sample_interval: The sample sample_interval. Returns: tuple: A tuple containing the input mask and output mask. diff --git a/src/s4casting/core/trainer.py b/src/s4casting/core/trainer.py index 0fd55f8..54a182a 100644 --- a/src/s4casting/core/trainer.py +++ b/src/s4casting/core/trainer.py @@ -7,7 +7,7 @@ import torch import torch.distributed as dist -from s4casting.core.config import IOConfiguration, OptimizerConfiguration, TrainingConfiguration +from s4casting.core.config import OptimizerConfiguration, TrainingConfiguration from s4casting.core.context import Context from s4casting.core.functional import select_rate from s4casting.core.hooks import TrainingHooks @@ -20,7 +20,6 @@ class Trainer: def __init__( self, config: TrainingConfiguration, - io: IOConfiguration, optimizer: OptimizerConfiguration, machine: Machine, gradient_accumulation_steps: int, @@ -29,15 +28,13 @@ def __init__( Args: config (TrainingConfiguration): Training configuration. - io (IOConfiguration): IO configuration. optimizer (OptimizerConfiguration): Optimizer configuration. machine (Machine): Machine information. gradient_accumulation_steps (int): Number of gradient accumulation steps. """ self._config = config - self._io = io self._optimizer = optimizer - self._iteration = self._io.iteration + self._iteration = self._config.iteration self._scores = {} self._gradient_accumulation_steps = gradient_accumulation_steps self._evaluation_interval = config.evaluation_interval @@ -81,7 +78,7 @@ def train(self, context: Context) -> None: """ self.hooks.start.call(context) - # Baseline benchmark at iteration 0 so wandb has a pre-training reference point. + # Baseline benchmark at iteration 0 for zero-shot performance. if self._iteration == 0 and self._main_process: self.hooks.benchmark.call(context, 0) self.hooks.stef_beam.call(context, 0) @@ -125,14 +122,13 @@ def _train_step(self, context: Context, X_all, sample_config) -> None: X, Xm, Y, Ym = ( x[micro_step * B : (micro_step + 1) * B].float().to(context.machine.torch_device) for x in X_all ) + input_interval = sample_config.sample_interval_minutes[micro_step * B : (micro_step + 1) * B] - output_interval = select_rate( - sample_config.sample_interval_minutes, context.configuration.model.output_sample_intervals_minutes - ) # ty: ignore[possibly-missing-attribute] + output_interval = select_rate(input_interval, context.configuration.model.output_sample_intervals_minutes) # ty: ignore[possibly-missing-attribute] _, loss = context.model_container.model( X, Xm, - sample_config.sample_interval_minutes, + input_interval, output_interval, Y, Ym, diff --git a/src/s4casting/data/utils.py b/src/s4casting/data/utils.py index c835121..8f8a7f1 100644 --- a/src/s4casting/data/utils.py +++ b/src/s4casting/data/utils.py @@ -2,23 +2,143 @@ # # SPDX-License-Identifier: MPL-2.0 +import warnings from dataclasses import dataclass -import numpy as np import torch from torch.utils.data import default_collate +def build_valid_context_sampling_pairs( + context_days=(7, 10, 14, 16, 32, 64, 364), + sample_intervals_minutes=(5, 10, 15, 30, 60, 1440, 10080), + min_points=32, + max_context_len=None, + interval_context_limits={ + 5: {"min_days": 7, "max_days": 10}, + 10: {"min_days": 7, "max_days": 14}, + 15: {"min_days": 11, "max_days": 32}, + 30: {"min_days": 16, "max_days": 64}, + 60: {"min_days": 16, "max_days": 64}, + 1440: {"min_days": 32, "max_days": 364}, + 10080: {"min_days": 64, "max_days": 364}, + }, +): + """Build valid (context_days, sample_interval_minutes) pairs for a fixed padded model input. + + All sample intervals are specified in minutes: + 5 = 5 minutes + 10 = 10 minutes + 60 = 1 hour + 1440 = 1 day + 10080 = 1 week + + A pair is valid if: + - the sample count is an integer + - sample count >= min_points + - sample count <= max_context_samples (if max_context_samples is provided) + - it satisfies optional manual pruning rules in interval_context_limits + + Args: + context_days: iterable of allowed context widths in days + sample_intervals_minutes: iterable of allowed sample intervals in minutes + min_points: minimum raw sample count required + max_context_len: fixed padded input length; if None, inferred from valid pairs + interval_context_limits: optional dict of the form + { + 5: {"min_days": 7, "max_days": 10}, + 10: {"min_days": 7, "max_days": 14}, + 15: {"min_days": 14, "max_days": 32}, + 30: {"min_days": 16, "max_days": 64}, + 60: {"min_days": 16, "max_days": 64}, + 1440: {"min_days": 32, "max_days": 364}, + 10080: {"min_days": 364, "max_days": 364}, + } + + Returns: + dict with: + - "valid_pairs": list of dicts + - "recommended_max_context_samples": smallest exact padded length covering all valid pairs + """ + valid = [] + + for days in context_days: + context_minutes = days * 1440 + + for interval_minutes in sample_intervals_minutes: + # Optional policy pruning + if interval_context_limits and interval_minutes in interval_context_limits: + limits = interval_context_limits[interval_minutes] + if days < limits["min_days"] or days > limits["max_days"]: + warnings.warn( + f"Skipping ({days}d, {interval_minutes}min): context window outside allowed range " + f"[{limits['min_days']}d, {limits['max_days']}d] for {interval_minutes}-minute interval.", + stacklevel=2, + ) + continue + + points = context_minutes / interval_minutes + + # Must align exactly + if abs(points - round(points)) > 1e-9: + warnings.warn( + f"Skipping ({days}d, {interval_minutes}min): {days * 1440} minutes is not exactly " + f"divisible by {interval_minutes}-minute interval (would give {points:.4f} points).", + stacklevel=2, + ) + continue + + points = round(points) + + if points < min_points: + warnings.warn( + f"Skipping ({days}d, {interval_minutes}min): {points} points is below the minimum of {min_points}.", + stacklevel=2, + ) + continue + + if max_context_len is not None and points > max_context_len: + warnings.warn( + f"Skipping ({days}d, {interval_minutes}min): {points} points exceeds max_context_len " + f"of {max_context_len}.", + stacklevel=2, + ) + continue + + valid.append({ + "context_days": days, + "sample_interval_minutes": interval_minutes, + "points": points, + }) + + if not valid: + raise ValueError("No valid (context_days, sample_interval_minutes) pairs found.") + + recommended_max_context_samples = max(row["points"] for row in valid) + + final_max_context_samples = max_context_len if max_context_len is not None else recommended_max_context_samples + + # Add padding metadata + for row in valid: + row["pad_left"] = final_max_context_samples - row["points"] + row["padded_length"] = final_max_context_samples + + return { + "valid_pairs": valid, + "recommended_max_context_samples": recommended_max_context_samples, + } + + @dataclass class SampleConfig: """Class for keeping track of parameters when sampling from dataset.""" location: int # TODO map this back to a string? sample_interval_minutes: int - context_window_days: int + context_window_samples: int start_timestamp: float n_features: int - predict_window_days: float | None = None + predict_window_samples: float | None = None # TODO: add feature names? @@ -27,11 +147,11 @@ class SampleConfigBatch: """Class for keeping track of parameters when collating configs from dataset.""" location: torch.Tensor # [B] long - sample_interval_minutes: int - context_window_days: int + sample_interval_minutes: torch.Tensor # [B] float + context_window_samples: torch.Tensor # [B] long start_timestamp: torch.Tensor # [B] float n_features: int - predict_window_days: torch.Tensor + predict_window_samples: torch.Tensor def collate_sample_configs(configs: list[SampleConfig]) -> SampleConfigBatch: @@ -43,18 +163,10 @@ def collate_sample_configs(configs: list[SampleConfig]) -> SampleConfigBatch: # Always collate required numeric fields into tensors location = torch.tensor([c.location for c in configs], dtype=torch.long) sample_interval_minutes = torch.tensor([c.sample_interval_minutes for c in configs], dtype=torch.long) - context_window_days = torch.tensor([c.context_window_days for c in configs], dtype=torch.long) + context_window_samples = torch.tensor([c.context_window_samples for c in configs], dtype=torch.long) start_timestamp = torch.tensor([c.start_timestamp for c in configs], dtype=torch.float64) n_features = torch.tensor([c.n_features for c in configs], dtype=torch.long) - predict_window_days = torch.tensor([c.predict_window_days for c in configs], dtype=torch.float32) - - if not torch.all(sample_interval_minutes == sample_interval_minutes[0]): - unique = torch.unique(sample_interval_minutes).tolist() - raise ValueError(f"Mixed sample_interval in batch: {unique}") - - if not torch.all(context_window_days == context_window_days[0]): - unique = torch.unique(context_window_days).tolist() - raise ValueError(f"Mixed context_window_days in batch: {unique}") + predict_window_samples = torch.tensor([c.predict_window_samples for c in configs], dtype=torch.float32) if not torch.all(n_features == n_features[0]): unique = torch.unique(n_features).tolist() @@ -62,11 +174,11 @@ def collate_sample_configs(configs: list[SampleConfig]) -> SampleConfigBatch: return SampleConfigBatch( location=location, - sample_interval_minutes=sample_interval_minutes[0].item(), - context_window_days=context_window_days[0].item(), + sample_interval_minutes=sample_interval_minutes, + context_window_samples=context_window_samples, start_timestamp=start_timestamp, n_features=n_features[0].item(), - predict_window_days=predict_window_days, + predict_window_samples=predict_window_samples, ) @@ -83,80 +195,3 @@ def collate_single_interval(batch): configs = collate_sample_configs(configs) return task_batch, configs - - -class ConcatDatasetSampler(torch.utils.data.Sampler): - """Batch sampler for a torch.utils.data.ConcatDataset. - - It that guarantees each mini-batch is drawn from exactly one underlying dataset. - """ - - def __init__( - self, - train_ds_lengths, - batch_size, - drop_last=True, - num_replicas=1, - rank=0, - seed=0, - ): - """Init fn for sampler. - - Parameters - ---------- - train_ds_lengths (tuple) : Tuple of lengths of each dataset. - batch_size (int) : Number of samples per batch. - drop_last (bool) : Drop samples that dont fit in batch. - num_replicas (int): Worldsize i.e. number of GPUs. - rank (int): Which gpu is being used. - seed (int): batcher seed. - """ - self.num_replicas = num_replicas - self.rank = rank - self.seed = seed - self.epoch = 0 - - self.batch_samplers = [] - base = 0 - for L in train_ds_lengths: - self.batch_samplers.append( - list( - torch.utils.data.BatchSampler( - torch.utils.data.SubsetRandomSampler(range(base, base + L)), - batch_size=batch_size, - drop_last=drop_last, - ) - ) - ) - base += L - - self.cumsum = np.cumsum([len(bs) for bs in self.batch_samplers]) - self.total_batches = int(self.cumsum[-1]) if len(self.cumsum) else 0 - - def set_epoch(self, epoch): - """Set the epoch to change the shuffle order for ddp.""" - self.epoch = epoch - - def __iter__(self): - """Yield batches of indices such that each batch is drawn from a single underlying dataset.""" - if self.total_batches == 0: - return - - g = torch.Generator().manual_seed(self.seed + self.epoch) - order = torch.randperm(self.total_batches, generator=g).tolist() - - if self.num_replicas > 1: - total_size = (self.total_batches // self.num_replicas) * self.num_replicas - order = order[:total_size] - order = order[self.rank : total_size : self.num_replicas] - - for idx in order: - i = int(np.searchsorted(self.cumsum, idx, "right")) - prev = int(self.cumsum[i - 1]) if i else 0 - yield self.batch_samplers[i][idx - prev] - - def __len__(self): - """Return the nominal number of batches produced by the sampler.""" - if self.num_replicas > 1: - return self.total_batches // self.num_replicas - return self.total_batches diff --git a/src/s4casting/eval/evaluator_heads.py b/src/s4casting/eval/evaluator_heads.py index 2c7c3ab..7d8e2f8 100644 --- a/src/s4casting/eval/evaluator_heads.py +++ b/src/s4casting/eval/evaluator_heads.py @@ -9,7 +9,6 @@ from s4casting.core.distributions import gmm_to_quantiles from s4casting.core.hooks import CommonHooks, TrainingHooks from s4casting.data.dataset.interface import get_ordered_feature_names -from s4casting.data.utils import SampleConfigBatch from s4casting.eval.metrics import Metrics from s4casting.visualisation import plot_quantiles @@ -105,7 +104,9 @@ def report( loss: float, iteration: int, report_type: str, - sample_config: SampleConfigBatch, + context_window_days: int, + predict_window_days: int, + input_interval: int, output_interval: int, n_day_ahead: int, location: str | None = None, @@ -129,7 +130,9 @@ def report( report_type (str): Type of report (e.g., "benchmark", "evaluation", "inference"). output_interval (int): Output sample rate of eval step. n_day_ahead (int): Days ahead for the forecast (different from prediction width). - sample_config (SampleConfigBatch): Sample configuration. + context_window_days (int): Context window in days. + predict_window_days (int): Prediction window in days. + input_interval (int): Input sample interval. """ if self.head_type == "gmm": logpi, sigma, mu = (x for x in prediction.unbind(dim=-1)) # (B, T, G, 3) @@ -146,9 +149,9 @@ def report( Y=Y, Ym=Ym, quantiles=prediction, - input_window_days=sample_config.context_window_days - sample_config.predict_window_days[0].item(), + input_window_days=context_window_days - predict_window_days, n_day_ahead=n_day_ahead, - input_interval=sample_config.sample_interval_minutes, + input_interval=input_interval, output_interval=output_interval, ) @@ -161,8 +164,8 @@ def report( ) metrics = Metrics( output_sample_interval_minutes=output_interval, - prediction_window_days=int(sample_config.predict_window_days[0].item()), - input_sample_interval_minutes=sample_config.sample_interval_minutes, + predict_window_days=predict_window_days, + input_sample_interval_minutes=input_interval, climits=climits, # type: ignore[arg-type] quantiles=quantiles, Y=Y, @@ -180,10 +183,10 @@ def report( Y, Ym, times, - sample_config.sample_interval_minutes, + input_interval, output_interval, report_type, - "short" if sample_config.sample_interval_minutes == output_interval else "medium", + "short" if input_interval == output_interval else "medium", feature_names=get_ordered_feature_names(context.configuration), ) diff --git a/src/s4casting/eval/metrics.py b/src/s4casting/eval/metrics.py index 27b7fe6..3f5f941 100644 --- a/src/s4casting/eval/metrics.py +++ b/src/s4casting/eval/metrics.py @@ -21,7 +21,7 @@ class Metrics: def __init__( self, Y: torch.Tensor, - prediction_window_days: int, + predict_window_days: int, model_config: ModelConfiguration, metrics_config: MetricsConfiguration, input_sample_interval_minutes: int, @@ -42,7 +42,7 @@ def __init__( quantile_values (torch.Tensor): quantile values. Y (torch.Tensor): Ground truth tensor. Y (int): prediction window size. - prediction_window_days (int): prediction window size. + predict_window_days (int): prediction window size. model_config (ModelConfiguration): Model configuration for setting up config metrics_config (MetricsConfiguration): Metrics configuration for setting up config sign (str): Specifies whether to calculate metrics for "LDN", "ODN", or "BOTH". @@ -78,7 +78,7 @@ def __init__( self.ldn_index = np.argwhere( np.array(self.model_config.output_head.quantile_values) == mape_ldn_quantile_value )[0][0] - self.sufficient_days = prediction_window_days >= model_config.days_per_month + self.sufficient_days = predict_window_days >= model_config.days_per_month # check whether input and outputs rates are the same # as CRPS, and peak metrics require that they are diff --git a/src/s4casting/factories/model_container.py b/src/s4casting/factories/model_container.py index 8ea8766..a69f621 100644 --- a/src/s4casting/factories/model_container.py +++ b/src/s4casting/factories/model_container.py @@ -3,13 +3,14 @@ # SPDX-License-Identifier: MPL-2.0 # type: ignore +import warnings from copy import deepcopy from torch import nn from torch.nn.parallel import DistributedDataParallel as DDP from s4casting.core.config import DTYPE_MAP, IOConfiguration, ModelConfiguration -from s4casting.core.loss import SoftClipLoss, SubsetNLLLoss, SubsetPinballLoss +from s4casting.core.loss import CompositeLoss, SoftClipLoss, SubsetNLLLoss, SubsetPinballLoss from s4casting.core.machine import Machine from s4casting.core.model_container import ModelContainer from s4casting.model._encoders import PatchDecoder, PatchEncoder, SeperateLocTime, SSEncoder @@ -87,32 +88,41 @@ def provide_model_container(config: ModelConfiguration, io_config: IOConfigurati len(x.subset_features) if x.subset_features else x.n_features for x in {k.split("_")[0]: v for k, v in io_config.features.items()}.values() ) + has_time = any(v.loader == "time" for v in io_config.features.values()) + n_data_features = n_features - 3 * has_time + if config.patch_encoder.patch_size != 1 and n_data_features > config.n_out_features: + warnings.warn( + f"Weather auxiliary loss is disabled because patch_size=" + f"{config.patch_encoder.patch_size} > 1. Set patch_size=1 to enable it.", + UserWarning, + stacklevel=2, + ) + n_weather_features = max(0, n_data_features - config.n_out_features) if config.patch_encoder.patch_size == 1 else 0 latent_dim = config.transformer.latent_dim if config.model == "transformer" else config.latent_dim if config.patch_encoder.arch == "linear": if config.model == "transformer": patch_encoder = PatchEncoder( latent_dim, - (n_features - 3 * any(v.loader == "time" for v in io_config.features.values())) - * 2, # to accept target + mask + n_data_features * 2, # to accept target + mask config.patch_encoder.patch_size, ) else: patch_encoder = PatchEncoder( latent_dim, - n_features - 3 * any(v.loader == "time" for v in io_config.features.values()), + n_data_features, config.patch_encoder.patch_size, ) elif config.patch_encoder.arch == "ss": patch_encoder = SSEncoder( latent_dim, - n_features - 3 * any(v.loader == "time" for v in io_config.features.values()), + n_data_features, n_layers=config.patch_encoder.n_layers, patch_size=config.patch_encoder.patch_size, ) - if any(v.loader == "time" for v in io_config.features.values()): + if has_time: patch_encoder = SeperateLocTime(patch_encoder) patch_decoder = PatchDecoder( @@ -137,6 +147,7 @@ def provide_model_container(config: ModelConfiguration, io_config: IOConfigurati output_head = QuantileHead(latent_dim, config.n_out_features, config.output_head.quantile_values) loss_fn = _build_loss_fn(config) + composite_loss = CompositeLoss(config.loss.components) assert config.loss.loss in ["nll", "mse", "pinball"], f"Loss function {config.loss.loss} not implemented" @@ -154,14 +165,17 @@ def loss_fn(*args, **kwargs): n_layer=config.ssm.n_layers, kernel=config.ssm.kernel, backend="keops" if machine.torch_device_kind == "cuda" else "naive", + mixer_size=config.ssm.mixer_size, patch_size=config.patch_encoder.patch_size, norm_clamp=config.norm_clamp, norm_eps=config.norm_eps, loss_fn=loss_fn, + composite_loss=composite_loss, output_head=output_head, patch_encoder=patch_encoder, patch_decoder=patch_decoder, base_sample_interval_minutes=config.base_sample_interval_minutes, + n_weather_features=n_weather_features, ) elif config.model == "transformer": diff --git a/src/s4casting/factories/trainer.py b/src/s4casting/factories/trainer.py index c25890b..d060c5c 100644 --- a/src/s4casting/factories/trainer.py +++ b/src/s4casting/factories/trainer.py @@ -2,19 +2,16 @@ # # SPDX-License-Identifier: MPL-2.0 -from s4casting.core.config import IOConfiguration, OptimizerConfiguration, TrainingConfiguration +from s4casting.core.config import OptimizerConfiguration, TrainingConfiguration from s4casting.core.machine import Machine from s4casting.core.trainer import Trainer -def provide_trainer( - config: TrainingConfiguration, io: IOConfiguration, optimizer: OptimizerConfiguration, machine: Machine -) -> Trainer: +def provide_trainer(config: TrainingConfiguration, optimizer: OptimizerConfiguration, machine: Machine) -> Trainer: """Provide a Trainer instance. Args: config (TrainingConfiguration): Training configuration. - io (IOConfiguration): IO configuration. optimizer (OptimizerConfiguration): Optimizer configuration. machine (Machine): Machine information. @@ -27,7 +24,6 @@ def provide_trainer( ) return Trainer( config, - io, optimizer, machine, gradient_accumulation_steps=config.gradient_accumulation_steps // machine.world_size, diff --git a/src/s4casting/model/mamba.py b/src/s4casting/model/mamba.py index c3b68e0..f45f26f 100644 --- a/src/s4casting/model/mamba.py +++ b/src/s4casting/model/mamba.py @@ -57,7 +57,7 @@ def __init__( self.expand = expand self.d_inner = int(self.expand * self.d_model) self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank - self.use_fast_path = use_fast_path + self.use_fast_path = use_fast_path and causal_conv1d_fn is not None self.layer_idx = layer_idx self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) @@ -142,7 +142,7 @@ def forward(self, hidden_states, inference_params=None, *, rate=1.0): A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # In the backward pass we write dx and dz next to each other to avoid torch.cat if ( - self.use_fast_path and causal_conv1d_fn is not None and inference_params is None + self.use_fast_path and causal_conv1d_fn is not None and inference_params is None and isinstance(rate, float) ): # Doesn't support outputting the states out = mamba_inner_fn( xz, @@ -184,8 +184,14 @@ def forward(self, hidden_states, inference_params=None, *, rate=1.0): x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) dt = self.dt_proj.weight @ dt.t() - dt = dt * rate - dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) + + if isinstance(rate, torch.Tensor): + dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) # move rearrange before rate scaling + dt = dt * rate.to(dt.device).view(-1, 1, 1) # (B,) broadcast over d_inner and L + elif isinstance(rate, float): + dt = dt * rate + dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) + B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() assert self.activation in ["silu", "swish"] diff --git a/src/s4casting/model/mambacpu.py b/src/s4casting/model/mambacpu.py index f048b4d..e6dfecb 100644 --- a/src/s4casting/model/mambacpu.py +++ b/src/s4casting/model/mambacpu.py @@ -135,8 +135,13 @@ def forward(self, hidden_states, inference_params=None, *, rate=1.0): x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) dt = self.dt_proj.weight @ dt.t() - dt = dt * rate - dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) + + if isinstance(rate, torch.Tensor): + dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) + dt = dt * rate.view(-1, 1, 1) + elif isinstance(rate, float): + dt = dt * rate + dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() assert self.activation in ["silu", "swish"] diff --git a/src/s4casting/model/ss.py b/src/s4casting/model/ss.py index 1d92143..7f77c4b 100644 --- a/src/s4casting/model/ss.py +++ b/src/s4casting/model/ss.py @@ -5,11 +5,12 @@ import torch from torch import nn -from s4casting.core.loss import SubsetNLLLoss +from s4casting.core.loss import CompositeLoss, SubsetNLLLoss from s4casting.model._blocks import GruBlock, S4Block, S6Block, SequenceResidualBlock from s4casting.model._encoders import PatchDecoder, PatchEncoder, SeperateLocTime from s4casting.model._heads import GMMHead from s4casting.model._norm import denorm, norm, norm_target +from s4casting.model.weather import WeatherAuxTask class SSModel(nn.Module): @@ -22,15 +23,18 @@ def __init__( n_layer=4, kernel="s4", backend="keops", + mixer_size=None, patch_size=1, n_out_features=1, norm_clamp=10.0, norm_eps=1e-5, loss_fn: nn.Module = SubsetNLLLoss(1, "masked"), + composite_loss: nn.Module = CompositeLoss({"primary": 1.0}), output_head: nn.Module = GMMHead(256, 2, 1), patch_encoder: nn.Module = PatchEncoder(256, 5, 8), patch_decoder: nn.Module = PatchDecoder(256, 256, 8, [15], [15]), base_sample_interval_minutes: int = 15, + n_weather_features: int = 0, ): """Initialize the SSModel. @@ -39,20 +43,25 @@ def __init__( n_layer (int): Number of S4/GRU layers. kernel (str): Kernel type ("s4", "s6", or "gru"). backend (str): Backend to use for the kernel. + mixer_size (int|None): Time domain mixer size. patch_size (int): Size of the patches for patch encoding/decoding. n_out_features (int): Number of output features. norm_clamp (float): Clamping value for normalization. norm_eps (float): Epsilon for norm. loss_fn (nn.Module): Loss function module. + composite_loss (nn.Module): Combines named loss terms with per-component weights. output_head (nn.Module): Output head module. patch_encoder (nn.Module): Patch encoder module. patch_decoder (nn.Module): Patch decoder module. base_sample_interval_minutes (int): Base sample interval. + n_weather_features (int): Number of auxiliary weather channels for the masking + loss. Set to 0 to disable. """ super().__init__() self.patch_size = patch_size self.kernel = kernel self.loss_fn = loss_fn + self.composite_loss = composite_loss self.norm_clamp = norm_clamp self.norm_eps = norm_eps self.n_out_features = n_out_features @@ -61,10 +70,12 @@ def __init__( self.output_head = output_head self.latent_dim = latent_dim self.base_sample_interval_minutes = base_sample_interval_minutes + self.weather_aux = WeatherAuxTask(latent_dim, n_weather_features) Kernel = {"s4": S4Block, "s6": S6Block, "gru": GruBlock}[kernel] self.ss_layers = nn.ModuleList([ - SequenceResidualBlock(self.latent_dim, Kernel, backend=backend) for _ in range(n_layer) + SequenceResidualBlock(self.latent_dim, Kernel, backend=backend, mixer_size=mixer_size) + for _ in range(n_layer) ]) def forward(self, x, xm, input_interval, output_interval, y=None, ym=None): @@ -75,8 +86,8 @@ def forward(self, x, xm, input_interval, output_interval, y=None, ym=None): xm (torch.Tensor): Mask tensor of shape (batch_size, seq_len, n_features). y (torch.Tensor | None): Target tensor of shape (batch_size, seq_len, n_features). ym (torch.Tensor | None): Target mask tensor of shape (batch_size, seq_len, n_features). - input_interval (int): Used for multi-rate training of state space models. - output_interval (int): Used for multi-rate training of state space models. + input_interval (torch.Tensor | float): Used for multi-rate training of state space models. + output_interval (torch.Tensor | float): Used for multi-rate training of state space models. Returns: torch.Tensor: Output tensor of shape (batch_size, seq_len, n_out_features). @@ -93,9 +104,9 @@ def forward(self, x, xm, input_interval, output_interval, y=None, ym=None): clamp=self.norm_clamp, dims=torch.arange(x.shape[-1] - 3 * isinstance(self.patch_encoder, SeperateLocTime)), ) - x = self.patch_encoder( - x, input_interval / self.base_sample_interval_minutes, output_interval // input_interval - ) # B T F -> B T/P E + x, weather_gt, weather_mask = self.weather_aux.prepare(x) + x = self.patch_encoder(x, input_interval, output_interval) # B T F -> B T/P E + x_enc = x for layer in self.ss_layers: x = layer(x, output_interval / self.base_sample_interval_minutes) @@ -107,9 +118,13 @@ def forward(self, x, xm, input_interval, output_interval, y=None, ym=None): y = norm_target( mean_in=mean_in[..., : self.n_out_features], std_in=std_in[..., : self.n_out_features], y=y, ym=ym ) - loss = self.loss_fn( - out=x, target=y, input_interval=input_interval, output_interval=output_interval, mask=ym - ) + _losses: dict[str, torch.Tensor] = { + "primary": self.loss_fn( + out=x, target=y, input_interval=input_interval, output_interval=output_interval, mask=ym + ), + "weather": self.weather_aux.compute_loss(x_enc, weather_gt, weather_mask), + } + loss = self.composite_loss(_losses) x = denorm( mean_in=mean_in[..., : self.n_out_features].unsqueeze(-1), diff --git a/sync_repos.sh b/sync_repos.sh index 195370c..7fcc9a6 100755 --- a/sync_repos.sh +++ b/sync_repos.sh @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: Contributors to the s4casting project +# +# SPDX-License-Identifier: MPL-2.0 + #!/bin/bash usage() { diff --git a/tests/utils.py b/tests/utils.py index 8a4273d..21b7e89 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -160,7 +160,7 @@ def create_tmp_model_checkpoint( model_container = fc.provide_model_container(config.model, config.io, machine) optimizer = fc.provide_optimizer(config.optimizer, model_container.raw_model.parameters()) scheduler = fc.provide_scheduler(config.scheduler, optimizer) - trainer = fc.provide_trainer(config=config.training, io=config.io, optimizer=config.optimizer, machine=machine) + trainer = fc.provide_trainer(config=config.training, optimizer=config.optimizer, machine=machine) checkpointer = fc.provide_checkpointer(config.io, trainer.hooks) evaluator_head = fc.provide_evaluator_head(config.model, trainer.hooks) evaluator = fc.provide_evaluation(trainer.hooks, evaluator_head) From 44eae430a431e0871ab889176ce0aba0e23df375 Mon Sep 17 00:00:00 2001 From: mesarcik Date: Mon, 1 Jun 2026 11:38:44 +0000 Subject: [PATCH 4/6] added _blocks --- src/s4casting/model/_blocks.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/s4casting/model/_blocks.py b/src/s4casting/model/_blocks.py index 74bdfe9..f55a281 100644 --- a/src/s4casting/model/_blocks.py +++ b/src/s4casting/model/_blocks.py @@ -231,21 +231,26 @@ def forward(self, x, prev_hidden=None): class SequenceResidualBlock(nn.Module): """Sequence Residual Block with LayerNorm and S4/GRU layer.""" - def __init__(self, d_input, kernel, backend="keops"): + def __init__(self, d_input, kernel, backend="keops", mixer_size=None): """Initialize the SequenceResidualBlock. Args: d_input (int): Dimension of the input. kernel (nn.Module): Kernel layer (S4Block or GruBlock). backend (str): Backend to use for the kernel. + mixer_size (int | None): maximum context width size to mix along. """ super().__init__() + self.mixer_size = mixer_size self.layer = kernel(d_input, backend) self.norm = torch.nn.LayerNorm((d_input,)) + self.mixer = nn.Identity() if mixer_size is None else nn.Linear(mixer_size, mixer_size) def forward(self, x, rate=1, **kwargs): """Forward pass of the SequenceResidualBlock. + Note mixer is applied in time if mixer time is not None. + Args: x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_input). rate: Scaling factor for the kernel. @@ -257,4 +262,5 @@ def forward(self, x, rate=1, **kwargs): y = x y = self.norm(rearrange(y, "b ... d -> b (...) d")).view(y.shape) y, _new_state = self.layer(y, rate=rate, **kwargs) + y = self.mixer(y.swapaxes(1, 2)).swapaxes(1, 2) return x + y From b42a300a0f52c4b0f7e203e5745724e8ae3cfb6f Mon Sep 17 00:00:00 2001 From: mesarcik Date: Mon, 1 Jun 2026 12:34:02 +0000 Subject: [PATCH 5/6] _encoders added --- src/s4casting/model/_encoders.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/s4casting/model/_encoders.py b/src/s4casting/model/_encoders.py index 999aeae..72f2ef1 100644 --- a/src/s4casting/model/_encoders.py +++ b/src/s4casting/model/_encoders.py @@ -36,11 +36,11 @@ def __init__( SequenceResidualBlock(latent_dim, _kernel, backend=backend) for _ in range(n_layers) ]) - def forward(self, x, sample_rate_conversion_factor, patch_size): + def forward(self, x, input_interval, output_interval): x = self.expand(x) for layer in self.ss_layers: - x = layer(x, sample_rate_conversion_factor) - return x[:, patch_size - 1 :: patch_size, :] + x = layer(x, input_interval / output_interval) + return x class PatchEncoder(nn.Module): From f2eef0447ed0fb99fd306e89798ae7fc10d6969b Mon Sep 17 00:00:00 2001 From: mesarcik Date: Tue, 2 Jun 2026 09:28:35 +0000 Subject: [PATCH 6/6] minor patch --- src/s4casting/model/weather.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/s4casting/model/weather.py b/src/s4casting/model/weather.py index ff238e6..39a8a8d 100644 --- a/src/s4casting/model/weather.py +++ b/src/s4casting/model/weather.py @@ -78,7 +78,7 @@ def compute_loss( Returns: Scalar MSE loss over the masked region; zero when n_weather_features=0. """ - if self.n_weather_features == 0: + if weather_gt.shape[2] == 0: return x_enc.new_zeros(()) weather_forecast = self.weather_forecaster(x_enc) # (B, T, C)