autogluon.timeseries 1.4.1b20250926__py3-none-any.whl → 1.4.1b20250927__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. autogluon/timeseries/models/__init__.py +2 -0
  2. autogluon/timeseries/models/toto/__init__.py +3 -0
  3. autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
  4. autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
  5. autogluon/timeseries/models/toto/_internal/backbone/attention.py +197 -0
  6. autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
  7. autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
  8. autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
  9. autogluon/timeseries/models/toto/_internal/backbone/rope.py +94 -0
  10. autogluon/timeseries/models/toto/_internal/backbone/scaler.py +306 -0
  11. autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
  12. autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
  13. autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
  14. autogluon/timeseries/models/toto/dataloader.py +108 -0
  15. autogluon/timeseries/models/toto/hf_pretrained_model.py +119 -0
  16. autogluon/timeseries/models/toto/model.py +234 -0
  17. autogluon/timeseries/version.py +1 -1
  18. {autogluon.timeseries-1.4.1b20250926.dist-info → autogluon.timeseries-1.4.1b20250927.dist-info}/METADATA +10 -5
  19. {autogluon.timeseries-1.4.1b20250926.dist-info → autogluon.timeseries-1.4.1b20250927.dist-info}/RECORD +26 -11
  20. /autogluon.timeseries-1.4.1b20250926-py3.9-nspkg.pth → /autogluon.timeseries-1.4.1b20250927-py3.9-nspkg.pth +0 -0
  21. {autogluon.timeseries-1.4.1b20250926.dist-info → autogluon.timeseries-1.4.1b20250927.dist-info}/LICENSE +0 -0
  22. {autogluon.timeseries-1.4.1b20250926.dist-info → autogluon.timeseries-1.4.1b20250927.dist-info}/NOTICE +0 -0
  23. {autogluon.timeseries-1.4.1b20250926.dist-info → autogluon.timeseries-1.4.1b20250927.dist-info}/WHEEL +0 -0
  24. {autogluon.timeseries-1.4.1b20250926.dist-info → autogluon.timeseries-1.4.1b20250927.dist-info}/namespace_packages.txt +0 -0
  25. {autogluon.timeseries-1.4.1b20250926.dist-info → autogluon.timeseries-1.4.1b20250927.dist-info}/top_level.txt +0 -0
  26. {autogluon.timeseries-1.4.1b20250926.dist-info → autogluon.timeseries-1.4.1b20250927.dist-info}/zip-safe +0 -0
autogluon/timeseries/models/toto/_internal/forecaster.py
@@ -0,0 +1,423 @@
+ # Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
+ #
+ # This product includes software developed at Datadog (https://www.datadoghq.com/)
+ # Copyright 2025 Datadog, Inc.
+
+ from dataclasses import dataclass
+ from typing import Optional, Union, cast
+
+ import numpy as np
+ import torch
+ from einops import rearrange, repeat
+ from gluonts.torch.distributions import AffineTransformed
+ from torch.distributions import Distribution
+
+ from .backbone import TotoBackbone
+ from .dataset import (
+     MaskedTimeseries,
+     pad_array,
+     pad_id_mask,
+     replace_extreme_values,
+ )
+
+
+ @dataclass(frozen=True)
+ class Forecast:
+     mean: torch.Tensor
+     samples: Optional[torch.Tensor]
+
+     def quantile(self, q: Union[float, torch.Tensor]) -> torch.Tensor:
+         """
+         Compute the quantile of the forecast samples.
+         """
+         assert self.samples is not None, "samples must be provided to compute quantiles"
+         assert isinstance(q, float) or isinstance(q, torch.Tensor), "q must be a float or a tensor"
+         if isinstance(q, float):
+             q = torch.tensor(q, device=self.samples.device, dtype=self.samples.dtype)
+         return self.samples.quantile(q, dim=-1)
+
+     @property
+     def median(self) -> torch.Tensor:
+         """
+         The median of the forecast samples.
+         """
+         return self.quantile(0.5)
+
+     @property
+     def std(self) -> torch.Tensor:
+         """
+         Compute the standard deviation of the forecast samples.
+         """
+         assert self.samples is not None, "samples must be provided to compute standard deviation"
+         return self.samples.std(dim=-1)
+
+
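Note: a minimal sketch of how the Forecast container above is typically consumed; the tensor shapes are illustrative assumptions (batch, variates, steps, num_samples), not values from this diff.

    import torch

    # Hypothetical forecast with 2 series, 1 variate, 4 future steps, 128 sample paths
    samples = torch.randn(2, 1, 4, 128)
    forecast = Forecast(mean=samples.mean(dim=-1), samples=samples)

    p10 = forecast.quantile(0.1)   # shape (2, 1, 4): 10th percentile across sample paths
    p50 = forecast.median          # same as forecast.quantile(0.5)
    spread = forecast.std          # per-step standard deviation across samples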
+ class TotoForecaster:
+     """
+     A forecaster class for the Toto model that handles autoregressive decoding for time series forecasting.
+
+     This class wraps a TotoBackbone model and provides methods to generate forecasts for time series data.
+     The forecasting process uses an autoregressive decoding algorithm:
+
+     1. The model first processes the entire input context (historical data)
+     2. For each future time step:
+        - The model generates a distribution over possible values
+        - Either the mean or random samples are drawn from this distribution
+        - The generated value(s) are appended to the input sequence
+        - The process repeats with this extended sequence
+
+     When generating multiple samples (num_samples > 1), the model creates separate trajectories for each sample:
+     - Each trajectory starts with the same historical context
+     - As sampling progresses, each trajectory evolves independently
+     - This results in num_samples different possible future paths
+     - Samples can be processed in batches (samples_per_batch) to manage memory usage
+
+     The forecaster efficiently reuses computation from the context processing phase using a key-value cache,
+     which stores intermediate transformer attention states to avoid redundant computation.
+
+     The forecaster handles data preprocessing, including padding to match the model's patch size,
+     and postprocessing to format the outputs as a Forecast object containing means and optional samples.
+     """
+
+     model: TotoBackbone
+
+     def __init__(self, model: TotoBackbone):
+         self.model = model
+
+     def forecast(
+         self,
+         inputs: MaskedTimeseries,
+         prediction_length: int,
+         num_samples: Optional[int] = None,
+         samples_per_batch: int = 10,
+         use_kv_cache: bool = True,
+     ) -> Forecast:
+         """
+         Generate a forecast for a batch of time series. This method works autoregressively,
+         i.e. it feeds the model's predictions back into itself. The decoding process is as follows:
+
+         1. The model first processes the entire input context (historical data)
+         2. For each future time step:
+            - The model generates a distribution over possible values
+            - Either the mean or random samples are drawn from this distribution
+            - The generated value(s) are appended to the input sequence
+            - The process repeats with this extended sequence
+
+         There are two modes of operation:
+         1. num_samples is None: generate a single mean prediction
+         2. num_samples is not None: generate num_samples random samples
+
+         When num_samples is not None, the model creates num_samples separate trajectories:
+         - Each trajectory starts with the same historical context
+         - As sampling progresses, each trajectory evolves independently
+         - This results in num_samples different possible future paths
+         - Samples can be processed in batches (samples_per_batch) to manage memory usage
+
+         When using samples_per_batch, this batch size compounds with the optional batch dimension of the input.
+         For example, if you have a batch of 10 time series and you set samples_per_batch to 10,
+         the effective batch size is 100. For the best performance, set samples_per_batch
+         as high as possible, subject to memory constraints.
+
+         Args:
+             inputs: A MaskedTimeseries object containing the input time series.
+             prediction_length: The number of future time steps to predict.
+             num_samples:
+                 The number of samples to generate.
+                 If None, a single mean prediction is generated. However,
+                 the mean point forecast tends to be less accurate than the
+                 median or mean of the samples (provided enough samples are generated).
+                 It's recommended to use at least 128 samples for reliable forecasts.
+             samples_per_batch:
+                 The number of samples to generate per batch.
+                 In most cases, this should be as high as possible, subject to memory constraints.
+                 When the inputs have a batch dimension, the effective batch size is samples_per_batch * batch_size.
+             use_kv_cache:
+                 Whether to use a key-value cache for the model. In most cases, this should be True,
+                 as it significantly speeds up inference.
+         """
+         if len(inputs.series.shape) == 2:
+             # unbatched input, variates x time_steps
+             batch = cast(MaskedTimeseries, torch.utils.data.default_collate([inputs]))
+         else:
+             # input is already batched
+             batch = inputs
+
+         # pad the input to the nearest multiple of the patch size
+         series = pad_array(batch.series, self.model.patch_embed.stride)
+         padding_mask = pad_array(batch.padding_mask, self.model.patch_embed.stride)
+         id_mask = batch.id_mask
+         if id_mask is not None:
+             id_mask = pad_id_mask(batch.id_mask, self.model.patch_embed.stride)
+         timestamp_seconds = pad_array(batch.timestamp_seconds, self.model.patch_embed.stride)
+         time_interval_seconds: torch.Tensor = torch.as_tensor(
+             batch.time_interval_seconds, device=series.device, dtype=torch.int
+         )
+
+         if num_samples is not None:
+             samples = self.generate_samples(
+                 inputs=series,
+                 prediction_length=prediction_length,
+                 num_samples=num_samples,
+                 timestamp_seconds=timestamp_seconds,
+                 time_interval_seconds=time_interval_seconds,
+                 input_padding_mask=padding_mask,
+                 id_mask=id_mask,
+                 sampling_batch_size=samples_per_batch,
+                 use_kv_cache=use_kv_cache,
+             )
+             mean = samples.mean(dim=-1)
+         else:
+             mean = self.generate_mean(
+                 inputs=series,
+                 prediction_length=prediction_length,
+                 timestamp_seconds=timestamp_seconds,
+                 time_interval_seconds=time_interval_seconds,
+                 input_padding_mask=padding_mask,
+                 id_mask=id_mask,
+                 use_kv_cache=use_kv_cache,
+             )
+             samples = None
+
+         return Forecast(mean=mean, samples=samples)
+
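Note: a minimal end-to-end sketch of driving TotoForecaster.forecast. The attribute names follow the MaskedTimeseries usage visible in this diff, but the keyword-argument construction, context length, interval in seconds, and the already-loaded `model` backbone are assumptions for illustration.

    import torch

    context = torch.randn(1, 336)  # unbatched input: (variates, time_steps)
    inputs = MaskedTimeseries(
        series=context,
        padding_mask=torch.ones_like(context, dtype=torch.bool),
        id_mask=torch.zeros_like(context, dtype=torch.int),
        timestamp_seconds=torch.zeros_like(context, dtype=torch.int),
        time_interval_seconds=torch.tensor([3600]),  # hourly data, expressed in seconds
    )

    forecaster = TotoForecaster(model)  # `model` is an already-loaded TotoBackbone
    forecast = forecaster.forecast(
        inputs,
        prediction_length=24,
        num_samples=128,       # probabilistic mode; None would return only the mean
        samples_per_batch=32,  # must evenly divide num_samples
    )
    print(forecast.mean.shape, forecast.quantile(0.9).shape)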
+     @torch.no_grad()
+     def generate_mean(
+         self,
+         inputs: torch.Tensor,
+         prediction_length: int,
+         timestamp_seconds: torch.Tensor,
+         time_interval_seconds: torch.Tensor,
+         input_padding_mask: Optional[torch.Tensor] = None,
+         id_mask: Optional[torch.Tensor] = None,
+         use_kv_cache: bool = False,
+     ) -> torch.Tensor:
+         """
+         Generate a point prediction by taking the mean of the output distribution at each step.
+         This method works autoregressively, i.e. it feeds the model's predictions back into itself
+         to generate the next prediction.
+         """
+         if input_padding_mask is None:
+             input_padding_mask = torch.ones_like(inputs, dtype=torch.bool, device=inputs.device)
+         if id_mask is None:
+             id_mask = torch.zeros_like(inputs, dtype=torch.int, device=inputs.device)
+
+         ## round up the prediction length to the nearest multiple of the patch size
+         patch_size = self.model.patch_embed.stride
+         rounded_steps = int(np.ceil(prediction_length / patch_size) * patch_size)
+         start_index = inputs.shape[-1]
+         end_index = start_index + prediction_length
+
+         # TODO: maybe pass in future masks, rather than making assumptions here?
+         dummy_padding = torch.ones(
+             (input_padding_mask.shape[0], input_padding_mask.shape[1], patch_size),
+             device=inputs.device,
+             dtype=torch.bool,
+         )
+         dummy_id_mask = repeat(
+             id_mask[:, :, -1:],
+             "batch variates 1 -> batch variates patch_size",
+             patch_size=patch_size,
+         )
+         if use_kv_cache:
+             kv_cache = self.model.allocate_kv_cache(
+                 batch_size=inputs.shape[0],
+                 num_variates=inputs.shape[1],
+                 max_time_steps=inputs.shape[2] + rounded_steps,
+                 device=inputs.device,
+                 dtype=inputs.dtype,
+             )
+         else:
+             kv_cache = None
+
+         scaling_prefix_length = inputs.shape[-1]
+
+         for _ in range(rounded_steps // patch_size):
+             base_distr, loc, scale = self.model(
+                 inputs=inputs,
+                 input_padding_mask=input_padding_mask,
+                 id_mask=id_mask,
+                 kv_cache=kv_cache,
+                 scaling_prefix_length=scaling_prefix_length,
+             )
+             distr = self.create_affine_transformed(base_distr, loc, scale)
+
+             # We remove extreme values that can occur early in training
+             # and cause validation metrics to be NaN
+             samples = replace_extreme_values(distr.mean[:, :, -patch_size:])
+
+             inputs = torch.cat([inputs, samples], dim=-1)
+             id_mask = torch.cat([id_mask, dummy_id_mask], dim=-1)
+             input_padding_mask = torch.cat([input_padding_mask, dummy_padding], dim=-1)
+             for _ in range(patch_size):
+                 next_timestamp: torch.Tensor = timestamp_seconds[:, :, -1] + time_interval_seconds
+                 timestamp_seconds = torch.cat([timestamp_seconds, next_timestamp.unsqueeze(-1)], dim=-1)
+
+         return inputs.detach()[:, :, start_index:end_index]
+
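Note: to make the decode-loop bookkeeping above concrete, here is the rounding arithmetic as a worked example, assuming the default patch size of 64 used elsewhere in this diff and an illustrative context of 336 steps.

    import numpy as np

    patch_size = 64
    prediction_length = 50
    rounded_steps = int(np.ceil(prediction_length / patch_size) * patch_size)  # 64 -> one decode iteration
    num_iterations = rounded_steps // patch_size                               # 1
    # With a 336-step context, the returned slice is inputs[:, :, 336:386],
    # i.e. exactly prediction_length values even though 64 steps were generated.
    start_index, end_index = 336, 336 + prediction_length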
+     @torch.no_grad()
+     def generate_samples(
+         self,
+         inputs: torch.Tensor,
+         prediction_length: int,
+         num_samples: int,
+         timestamp_seconds: torch.Tensor,
+         time_interval_seconds: torch.Tensor,
+         input_padding_mask: Optional[torch.Tensor] = None,
+         id_mask: Optional[torch.Tensor] = None,
+         sampling_batch_size: int = 10,
+         use_kv_cache: bool = False,
+     ) -> torch.Tensor:
+         """
+         Generate samples from the output distribution.
+         This method works autoregressively, i.e. it feeds the model's predictions back into itself.
+         It works by creating num_samples chains. Each chain is a separate sequence of predictions.
+         At each time step, for each chain we take a single sample from the output distribution and append
+         it to the end of the sequence.
+         """
+         if input_padding_mask is None:
+             input_padding_mask = torch.ones_like(inputs, dtype=torch.bool, device=inputs.device)
+         if id_mask is None:
+             id_mask = torch.zeros_like(inputs, dtype=torch.int, device=inputs.device)
+
+         assert num_samples % sampling_batch_size == 0, "num_samples must be divisible by sampling_batch_size"
+         num_batches = num_samples // sampling_batch_size
+
+         # round up the prediction length to the nearest multiple of the patch size
+         patch_size = self.model.patch_embed.patch_size
+         rounded_steps = int(np.ceil(prediction_length / patch_size) * patch_size)
+         start_index = inputs.shape[-1]
+         end_index = start_index + prediction_length
+
+         dummy_padding = torch.ones(
+             (
+                 input_padding_mask.shape[0] * sampling_batch_size,
+                 input_padding_mask.shape[1],
+                 patch_size,
+             ),
+             dtype=torch.bool,
+             device=inputs.device,
+         )
+         dummy_id_mask = repeat(
+             id_mask[:, :, -1:],
+             "batch variates 1 -> (sampling_batch_size batch) variates patch_size",
+             sampling_batch_size=sampling_batch_size,
+             patch_size=patch_size,
+         )
+         inputs = repeat(
+             inputs,
+             "batch variates seq_len -> (sampling_batch_size batch) variates seq_len",
+             sampling_batch_size=sampling_batch_size,
+         )
+         input_padding_mask = repeat(
+             input_padding_mask,
+             "batch variates seq_len -> (sampling_batch_size batch) variates seq_len",
+             sampling_batch_size=sampling_batch_size,
+         )
+         id_mask = repeat(
+             id_mask,
+             "batch variates seq_len -> (sampling_batch_size batch) variates seq_len",
+             sampling_batch_size=sampling_batch_size,
+         )
+         timestamp_seconds = repeat(
+             timestamp_seconds,
+             "batch variates seq_len -> (sampling_batch_size batch) variates seq_len",
+             sampling_batch_size=sampling_batch_size,
+         )
+         time_interval_seconds = repeat(
+             time_interval_seconds,
+             "batch variates -> (sampling_batch_size batch) variates",
+             sampling_batch_size=sampling_batch_size,
+         )
+
+         all_samples = []
+         if use_kv_cache:
+             kv_cache = self.model.allocate_kv_cache(
+                 batch_size=inputs.shape[0],
+                 num_variates=inputs.shape[1],
+                 max_time_steps=inputs.shape[2] + rounded_steps,
+                 device=inputs.device,
+                 dtype=inputs.dtype,
+             )
+         else:
+             kv_cache = None
+
+         scaling_prefix_length = inputs.shape[-1]
+
+         for _ in range(num_batches):
+             batch_inputs = torch.clone(inputs)
+             batch_input_padding_mask = torch.clone(input_padding_mask)
+             batch_id_mask = torch.clone(id_mask)
+             batch_timestamp_seconds = torch.clone(timestamp_seconds)
+
+             for _ in range(rounded_steps // patch_size):
+                 base_distr, loc, scale = self.model(
+                     inputs=batch_inputs,
+                     input_padding_mask=batch_input_padding_mask,
+                     id_mask=batch_id_mask,
+                     kv_cache=kv_cache,
+                     scaling_prefix_length=scaling_prefix_length,
+                 )
+                 distr = self.create_affine_transformed(base_distr, loc, scale)
+
+                 sample = distr.sample()
+                 assert sample is not None
+
+                 # We remove extreme values that can occur early in training
+                 # and cause validation metrics to be NaN
+                 samples = replace_extreme_values(sample[:, :, -patch_size:])
+                 batch_inputs = torch.cat([batch_inputs, samples], dim=-1)
+                 batch_id_mask = torch.cat([batch_id_mask, dummy_id_mask], dim=-1)
+                 batch_input_padding_mask = torch.cat([batch_input_padding_mask, dummy_padding], dim=-1)
+                 for _ in range(patch_size):
+                     next_timestamp = batch_timestamp_seconds[:, :, -1] + time_interval_seconds
+                     batch_timestamp_seconds = torch.cat(
+                         [batch_timestamp_seconds, next_timestamp.unsqueeze(-1)], dim=-1
+                     )
+             all_samples.append(batch_inputs)
+             if kv_cache is not None:
+                 kv_cache.reset()
+
+         outputs = torch.cat(all_samples, dim=0)
+         unfolded_outputs = rearrange(
+             outputs,
+             "(samples batch) variates seq_len -> batch variates seq_len samples",
+             samples=num_samples,
+         ).detach()
+
+         trimmed_predictions = unfolded_outputs[:, :, start_index:end_index, :]
+         return trimmed_predictions
+
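Note: a small shape sketch of how generate_samples lays out its sample chains, assuming an illustrative batch of 2 series, 1 variate, sampling_batch_size=4 and num_samples=8 (so two outer batches).

    import torch
    from einops import rearrange, repeat

    batch, variates, seq_len = 2, 1, 96
    x = torch.randn(batch, variates, seq_len)

    # Each outer batch expands the inputs to (sampling_batch_size * batch) independent chains
    chains = repeat(x, "batch variates seq_len -> (s batch) variates seq_len", s=4)
    print(chains.shape)  # torch.Size([8, 1, 96])

    # After both outer batches are concatenated, the sample axis is folded back out
    outputs = torch.cat([chains, chains], dim=0)  # stands in for two decoded batches
    per_series = rearrange(
        outputs, "(samples batch) variates seq_len -> batch variates seq_len samples", samples=8
    )
    print(per_series.shape)  # torch.Size([2, 1, 96, 8])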
+     @staticmethod
+     def create_affine_transformed(base_distr: Distribution, loc: torch.Tensor, scale: torch.Tensor) -> Distribution:
+         """
+         Creates an AffineTransformed distribution with correctly matched shapes.
+
+         Handles three cases:
+         1. When loc/scale are per-timestep (from CausalStdMeanScaler)
+         2. When base_distr only contains the distribution for the latest patch
+            while loc/scale contain values for the entire sequence
+         3. When loc/scale have a single time step (from StdMeanScaler/StdMinScaler)
+            and need to be broadcast to match a multi-step base distribution
+
+         Args:
+             base_distr: The base distribution to transform
+             loc: Location parameter
+             scale: Scale parameter
+
+         Returns:
+             An AffineTransformed distribution with properly handled shapes
+         """
+         # Get the shape of the base distribution
+         # We'll use this to match the time dimension of loc/scale
+         base_shape = base_distr.mean.shape
+
+         base_time_dim = base_shape[-1]  # Time dimension of base distribution
+         loc_time_dim = loc.shape[-1]  # Time dimension of loc
+
+         if loc_time_dim == 1:
+             # Case 1: If loc/scale have time dimension 1 (standard scalers), PyTorch broadcasting will handle it
+             return AffineTransformed(base_distr, loc=loc, scale=scale)
+
+         # Case 2: If loc/scale have time dimension > 1 (causal scaler with history)
+         # We need to extract only the suffix that matches the base distribution
+         return AffineTransformed(base_distr, loc=loc[:, :, -base_time_dim:], scale=scale[:, :, -base_time_dim:])
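Note: a quick illustration of the shape handling above, not the library's exact scaler output. With a per-timestep (causal) scaler, loc and scale cover the full padded history, so only their trailing base_time_dim entries line up with the latest patch; the transformed mean is then loc + scale * base.mean.

    import torch
    from torch.distributions import Normal
    from gluonts.torch.distributions import AffineTransformed

    base = Normal(torch.zeros(2, 1, 64), torch.ones(2, 1, 64))  # distribution over the latest 64-step patch
    loc = torch.randn(2, 1, 400)          # hypothetical per-timestep statistics over the whole sequence
    scale = torch.rand(2, 1, 400) + 0.1

    # Equivalent to the causal-scaler branch above: keep only the trailing 64 positions
    distr = AffineTransformed(base, loc=loc[:, :, -64:], scale=scale[:, :, -64:])
    print(distr.mean.shape)  # torch.Size([2, 1, 64])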
autogluon/timeseries/models/toto/dataloader.py
@@ -0,0 +1,108 @@
+ import functools
+ import time
+ from typing import Any, Callable, Iterator, Optional, Union
+
+ import numpy as np
+ import torch
+
+ from autogluon.core.utils.exceptions import TimeLimitExceeded
+ from autogluon.timeseries import TimeSeriesDataFrame
+
+ from ._internal.dataset import MaskedTimeseries, freq_to_seconds
+
+
+ class TotoInferenceDataset(torch.utils.data.Dataset):
+     def __init__(
+         self,
+         target_df: TimeSeriesDataFrame,
+         max_context_length: int,
+         target_column: str = "target",
+     ):
+         assert max_context_length > 0
+         self.max_context_length = max_context_length
+         self.target_array = target_df[target_column].to_numpy(dtype=np.float32)
+
+         # store pointer to start:end of each time series
+         self.indptr = target_df.get_indptr()
+
+         self.freq = target_df.freq
+
+     def __len__(self):
+         return len(self.indptr) - 1  # noqa
+
+     def __getitem__(self, idx) -> np.ndarray:
+         start_idx = self.indptr[idx]
+         end_idx = self.indptr[idx + 1]
+
+         if end_idx - start_idx > self.max_context_length:
+             start_idx = end_idx - self.max_context_length
+
+         return self.target_array[start_idx:end_idx]
+
+
+ class TotoDataLoader:
+     def __init__(
+         self,
+         dataset: TotoInferenceDataset,
+         freq: Optional[str] = None,
+         batch_size: int = 1,
+         time_limit: Optional[Union[int, float]] = None,
+         device: Any = None,
+     ):
+         self.device = torch.device(device)
+         self.batch_loader = torch.utils.data.DataLoader(
+             dataset=dataset,
+             batch_size=batch_size,
+             collate_fn=functools.partial(self._collate, device=self.device),
+         )
+         self.on_batch = self._get_timeout_callback(time_limit) if time_limit is not None else (lambda *a, **k: None)
+
+         self.freq: str = freq or dataset.freq or "h"
+
+     @staticmethod
+     def _get_timeout_callback(seconds: Optional[float]) -> Callable:
+         start_time = time.monotonic()
+
+         def callback() -> None:
+             if seconds is not None and time.monotonic() - start_time > seconds:
+                 raise TimeLimitExceeded
+
+         return callback
+
+     @staticmethod
+     def _collate(time_series: list[np.ndarray], device: Any) -> torch.Tensor:
+         return torch.nn.utils.rnn.pad_sequence(
+             sequences=[torch.tensor(c, device=device, dtype=torch.float32) for c in time_series],
+             batch_first=True,
+             padding_value=torch.nan,
+             padding_side="left",
+         )
+
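Note: a minimal sketch of the left-padding behaviour of the collate function above. The padding_side argument of pad_sequence is only available in recent PyTorch releases, so treat the exact call as an assumption about the supported torch version.

    import numpy as np

    # Two series of different lengths; the shorter one is left-padded with NaN so that
    # the most recent observations of every series align on the right.
    batch = TotoDataLoader._collate(
        [np.array([1.0, 2.0, 3.0], dtype=np.float32), np.array([4.0, 5.0], dtype=np.float32)],
        device="cpu",
    )
    print(batch)
    # tensor([[1., 2., 3.],
    #         [nan, 4., 5.]])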
+     def __iter__(self) -> Iterator[MaskedTimeseries]:
+         for batch in self.batch_loader:
+             time_series = batch.unsqueeze(1).to(self.device).to(torch.float32)
+             nan_mask = torch.isnan(time_series)
+             time_series[nan_mask] = 0.0  # pad with zeros instead of nan
+
+             current_batch_size, _, context_length = time_series.shape
+
+             id_mask = torch.arange(current_batch_size, dtype=torch.int, device=self.device)[:, None, None].repeat(
+                 1, 1, context_length
+             )
+
+             time_interval_seconds = torch.full(
+                 (current_batch_size, 1),
+                 fill_value=freq_to_seconds(self.freq),
+                 device=self.device,
+                 dtype=torch.int,
+             )
+
+             yield MaskedTimeseries(
+                 time_series,
+                 padding_mask=~nan_mask,
+                 id_mask=id_mask,
+                 timestamp_seconds=torch.zeros_like(time_series, dtype=torch.int),
+                 time_interval_seconds=time_interval_seconds,
+             )
+
+             self.on_batch()
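Note: a sketch of how the two classes above fit together, assuming a long-format TimeSeriesDataFrame with a "target" column; the file path, batch size, context length, and time limit are illustrative, not values from this diff.

    from autogluon.timeseries import TimeSeriesDataFrame

    data = TimeSeriesDataFrame.from_path("train.csv")  # hypothetical file with item_id, timestamp, target
    dataset = TotoInferenceDataset(data, max_context_length=2048)
    loader = TotoDataLoader(dataset, freq=data.freq, batch_size=16, time_limit=600, device="cpu")

    for masked_ts in loader:
        # each batch is a MaskedTimeseries with series of shape (batch, 1, context_length), NaNs zeroed out
        print(masked_ts.series.shape, masked_ts.padding_mask.shape)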
autogluon/timeseries/models/toto/hf_pretrained_model.py
@@ -0,0 +1,119 @@
+ import logging
+ from typing import Optional
+
+ from transformers import PretrainedConfig, PreTrainedModel
+
+ from ._internal.backbone import TotoBackbone
+
+
+ class TotoConfig(PretrainedConfig):
+     model_type = "toto"
+
+     def __init__(
+         self,
+         dropout: float = 0.0,
+         embed_dim: int = 768,
+         num_heads: int = 12,
+         num_layers: int = 12,
+         output_distribution_classes: Optional[list[str]] = None,
+         output_distribution_kwargs: Optional[dict] = None,
+         patch_size: int = 64,
+         scale_factor_exponent: float = 10.0,
+         spacewise_every_n_layers: int = 12,
+         spacewise_first: bool = False,
+         stabilize_with_global: bool = True,
+         stride: int = 64,
+         transformers_version: str = "4.49.0",
+         use_memory_efficient_attention: bool = False,
+         **kwargs,
+     ):
+         self.dropout = dropout
+         self.embed_dim = embed_dim
+         self.num_heads = num_heads
+         self.num_layers = num_layers
+         self.output_distribution_classes = output_distribution_classes or ["MixtureOfStudentTsOutput"]
+         self.output_distribution_kwargs = output_distribution_kwargs or {"k_components": 24}
+         self.patch_size = patch_size
+         self.scale_factor_exponent = scale_factor_exponent
+         self.spacewise_every_n_layers = spacewise_every_n_layers
+         self.spacewise_first = spacewise_first
+         self.stabilize_with_global = stabilize_with_global
+         self.stride = stride
+         self.transformers_version = transformers_version
+         self.use_memory_efficient_attention = use_memory_efficient_attention
+
+         super().__init__(**kwargs)
+
+
+ class TotoPretrainedModel(PreTrainedModel):
+     config_class = TotoConfig
+     base_model_prefix = "model"  # optional, used for weight naming conventions
+
+     def __init__(self, config: TotoConfig):
+         super().__init__(config)
+         self.model = TotoBackbone(
+             patch_size=config.patch_size,
+             stride=config.stride,
+             embed_dim=config.embed_dim,
+             num_layers=config.num_layers,
+             num_heads=config.num_heads,
+             mlp_hidden_dim=getattr(config, "mlp_hidden_dim", 3072),
+             dropout=config.dropout,
+             spacewise_every_n_layers=config.spacewise_every_n_layers,
+             scaler_cls=getattr(config, "scaler_cls", "model.scaler.CausalPatchStdMeanScaler"),
+             output_distribution_classes=config.output_distribution_classes,
+             spacewise_first=config.spacewise_first,
+             output_distribution_kwargs=config.output_distribution_kwargs,
+             use_memory_efficient_attention=False,
+             stabilize_with_global=config.stabilize_with_global,
+             scale_factor_exponent=config.scale_factor_exponent,
+             **getattr(config, "extra_kwargs", {}),
+         )
+         self._register_load_state_dict_pre_hook(self._remap_state_dict_keys_hook)
+         self.post_init()
+
+     def _remap_state_dict_keys_hook(
+         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+     ):
+         remap = {
+             "mlp.0.w12.weight": "mlp.0.weight",
+             "mlp.0.w12.bias": "mlp.0.bias",
+             "mlp.0.w3.weight": "mlp.2.weight",
+             "mlp.0.w3.bias": "mlp.2.bias",
+         }
+
+         keys_to_remap = []
+         for key in list(state_dict.keys()):
+             for old, new in remap.items():
+                 if old in key:
+                     new_key = key.replace(old, new)
+                     keys_to_remap.append((key, new_key))
+                     break
+
+         for old_key, new_key in keys_to_remap:
+             state_dict[new_key] = state_dict.pop(old_key)
+
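Note: a concrete before/after for the remapping hook above, using a hypothetical SwiGLU-style key; only the substring mapping is taken from this diff.

    remap = {
        "mlp.0.w12.weight": "mlp.0.weight",
        "mlp.0.w12.bias": "mlp.0.bias",
        "mlp.0.w3.weight": "mlp.2.weight",
        "mlp.0.w3.bias": "mlp.2.bias",
    }

    key = "model.transformer.layers.3.mlp.0.w12.weight"  # hypothetical checkpoint key
    for old, new in remap.items():
        if old in key:
            key = key.replace(old, new)
            break
    print(key)  # model.transformer.layers.3.mlp.0.weight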
+     @classmethod
+     def from_pretrained(cls, model_name_or_path, config=None, torch_dtype=None, device_map=None, **kwargs):
+         transformers_logger = logging.getLogger("transformers.modeling_utils")
+         original_level = transformers_logger.level
+
+         try:
+             # Here we suppress the transformers logger's "some weights were not initialized" message, since the
+             # remapping hook is only called after the initial model loading.
+             transformers_logger.setLevel(logging.ERROR)
+
+             # Transformers follows a different load path that does not call load_state_dict hooks when
+             # loading with explicit device maps. Here, we first load the model with no device map and
+             # then move it.
+             model = super().from_pretrained(model_name_or_path, config=config, torch_dtype=torch_dtype, **kwargs)
+             if device_map is not None:
+                 model = model.to(device_map)
+
+         finally:
+             transformers_logger.setLevel(original_level)
+
+         return model
+
+     def forward(self, *args, **kwargs):
+         return self.model(*args, **kwargs)
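Note: a minimal sketch of loading the pretrained wrapper above and handing its backbone to the forecaster from the first file in this diff. The checkpoint identifier is a placeholder; substitute the Hugging Face repo or local path you actually use.

    import torch

    model = TotoPretrainedModel.from_pretrained(
        "path/to/toto-checkpoint",     # hypothetical checkpoint location
        torch_dtype=torch.float32,
        device_map="cpu",              # loaded without a device map first, then moved, per the override above
    )
    forecaster = TotoForecaster(model.model)  # model.model is the wrapped TotoBackbone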