autogluon.timeseries 1.4.1b20250926__py3-none-any.whl → 1.4.1b20250927__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/timeseries/models/__init__.py +2 -0
- autogluon/timeseries/models/toto/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
- autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/backbone/attention.py +197 -0
- autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
- autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
- autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
- autogluon/timeseries/models/toto/_internal/backbone/rope.py +94 -0
- autogluon/timeseries/models/toto/_internal/backbone/scaler.py +306 -0
- autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
- autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
- autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
- autogluon/timeseries/models/toto/dataloader.py +108 -0
- autogluon/timeseries/models/toto/hf_pretrained_model.py +119 -0
- autogluon/timeseries/models/toto/model.py +234 -0
- autogluon/timeseries/version.py +1 -1
- {autogluon.timeseries-1.4.1b20250926.dist-info → autogluon.timeseries-1.4.1b20250927.dist-info}/METADATA +10 -5
- {autogluon.timeseries-1.4.1b20250926.dist-info → autogluon.timeseries-1.4.1b20250927.dist-info}/RECORD +26 -11
- /autogluon.timeseries-1.4.1b20250926-py3.9-nspkg.pth → /autogluon.timeseries-1.4.1b20250927-py3.9-nspkg.pth +0 -0
- {autogluon.timeseries-1.4.1b20250926.dist-info → autogluon.timeseries-1.4.1b20250927.dist-info}/LICENSE +0 -0
- {autogluon.timeseries-1.4.1b20250926.dist-info → autogluon.timeseries-1.4.1b20250927.dist-info}/NOTICE +0 -0
- {autogluon.timeseries-1.4.1b20250926.dist-info → autogluon.timeseries-1.4.1b20250927.dist-info}/WHEEL +0 -0
- {autogluon.timeseries-1.4.1b20250926.dist-info → autogluon.timeseries-1.4.1b20250927.dist-info}/namespace_packages.txt +0 -0
- {autogluon.timeseries-1.4.1b20250926.dist-info → autogluon.timeseries-1.4.1b20250927.dist-info}/top_level.txt +0 -0
- {autogluon.timeseries-1.4.1b20250926.dist-info → autogluon.timeseries-1.4.1b20250927.dist-info}/zip-safe +0 -0
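The headline change is a new toto subpackage that adds Datadog's Toto forecasting model to autogluon.timeseries. The sketch below is illustrative only and not part of this diff; it assumes the model can be selected through TimeSeriesPredictor hyperparameters under the key "Toto" (an assumed key; the actual registration lives in the new models/toto/model.py and the updated models/__init__.py).

# Hypothetical usage sketch, not taken from the diff. The "Toto" hyperparameter key is
# an assumption based on the usual class-name convention in autogluon.timeseries.
import pandas as pd

from autogluon.timeseries import TimeSeriesDataFrame, TimeSeriesPredictor

df = pd.DataFrame(
    {
        "item_id": ["A"] * 200,
        "timestamp": pd.date_range("2025-01-01", periods=200, freq="h"),
        "target": [float(i) for i in range(200)],
    }
)
train_data = TimeSeriesDataFrame.from_data_frame(df, id_column="item_id", timestamp_column="timestamp")

predictor = TimeSeriesPredictor(prediction_length=24, freq="h").fit(
    train_data,
    hyperparameters={"Toto": {}},  # assumed key for the newly added model
)
forecasts = predictor.predict(train_data)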
autogluon/timeseries/models/toto/_internal/forecaster.py
@@ -0,0 +1,423 @@
# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
#
# This product includes software developed at Datadog (https://www.datadoghq.com/)
# Copyright 2025 Datadog, Inc.

from dataclasses import dataclass
from typing import Optional, Union, cast

import numpy as np
import torch
from einops import rearrange, repeat
from gluonts.torch.distributions import AffineTransformed
from torch.distributions import Distribution

from .backbone import TotoBackbone
from .dataset import (
    MaskedTimeseries,
    pad_array,
    pad_id_mask,
    replace_extreme_values,
)


@dataclass(frozen=True)
class Forecast:
    mean: torch.Tensor
    samples: Optional[torch.Tensor]

    def quantile(self, q: Union[float, torch.Tensor]) -> torch.Tensor:
        """
        Compute the quantile of the forecast samples.
        """
        assert self.samples is not None, "samples must be provided to compute quantiles"
        assert isinstance(q, float) or isinstance(q, torch.Tensor), "q must be a float or a tensor"
        if isinstance(q, float):
            q = torch.tensor(q, device=self.samples.device, dtype=self.samples.dtype)
        return self.samples.quantile(q, dim=-1)

    @property
    def median(self) -> torch.Tensor:
        """
        The median of the forecast samples.
        """
        return self.quantile(0.5)

    @property
    def std(self) -> torch.Tensor:
        """
        Compute the standard deviation of the forecast samples.
        """
        assert self.samples is not None, "samples must be provided to compute standard deviation"
        return self.samples.std(dim=-1)


class TotoForecaster:
    """
    A forecaster class for the Toto model that handles autoregressive decoding for time series forecasting.

    This class wraps a TotoBackbone model and provides methods to generate forecasts for time series data.
    The forecasting process uses an autoregressive decoding algorithm:

    1. The model first processes the entire input context (historical data)
    2. For each future time step:
       - The model generates a distribution over possible values
       - Either the mean or random samples are drawn from this distribution
       - The generated value(s) are appended to the input sequence
       - The process repeats with this extended sequence

    When generating multiple samples (num_samples > 1), the model creates separate trajectories for each sample:
    - Each trajectory starts with the same historical context
    - As sampling progresses, each trajectory evolves independently
    - This results in num_samples different possible future paths
    - Samples can be processed in batches (samples_per_batch) to manage memory usage

    The forecaster efficiently reuses computation from the context processing phase using a key-value cache,
    which stores intermediate transformer attention states to avoid redundant computation.

    The forecaster handles data preprocessing, including padding to match the model's patch size,
    and postprocessing to format the outputs as a Forecast object containing means and optional samples.
    """

    model: TotoBackbone

    def __init__(self, model: TotoBackbone):
        self.model = model

    def forecast(
        self,
        inputs: MaskedTimeseries,
        prediction_length: int,
        num_samples: Optional[int] = None,
        samples_per_batch: int = 10,
        use_kv_cache: bool = True,
    ) -> Forecast:
        """
        Generate a forecast for a batch of time series. This method works autoregressively,
        i.e. it feeds the model's predictions back into itself. The decoding process is as follows:

        1. The model first processes the entire input context (historical data)
        2. For each future time step:
           - The model generates a distribution over possible values
           - Either the mean or random samples are drawn from this distribution
           - The generated value(s) are appended to the input sequence
           - The process repeats with this extended sequence

        There are two modes of operation:
        1. num_samples is None: generate a single mean prediction
        2. num_samples is not None: generate num_samples random samples

        When num_samples is not None, the model creates num_samples separate trajectories for each sample:
        - Each trajectory starts with the same historical context
        - As sampling progresses, each trajectory evolves independently
        - This results in num_samples different possible future paths
        - Samples can be processed in batches (samples_per_batch) to manage memory usage

        When using samples_per_batch, this batch size compounds with the optional batch dimension of the input.
        For example, if you have a batch of 10 time series, and you set samples_per_batch to 10,
        the effective batch size is 100. For the best performance, set samples_per_batch
        as high as possible, subject to memory constraints.

        Args:
            inputs: A MaskedTimeseries object containing the input time series.
            prediction_length: The number of future time steps to predict.
            num_samples:
                The number of samples to generate.
                If None, a single mean prediction is generated. However,
                the mean point forecast tends to be less accurate than the
                median or mean of the samples (provided enough samples are generated).
                It's recommended to use at least 128 samples for reliable forecasts.
            samples_per_batch:
                The number of samples to generate per batch.
                In most cases, this should be as high as possible, subject to memory constraints.
                When the inputs have a batch dimension, the effective batch size is samples_per_batch * batch_size.
            use_kv_cache:
                Whether to use a key-value cache for the model. In most cases, this should be True,
                as it significantly speeds up inference.
        """
        if len(inputs.series.shape) == 2:
            # unbatched input, variates x time_steps
            batch = cast(MaskedTimeseries, torch.utils.data.default_collate([inputs]))
        else:
            # input is already batched
            batch = inputs

        # pad the input to the nearest multiple of the patch size
        series = pad_array(batch.series, self.model.patch_embed.stride)
        padding_mask = pad_array(batch.padding_mask, self.model.patch_embed.stride)
        id_mask = batch.id_mask
        if id_mask is not None:
            id_mask = pad_id_mask(batch.id_mask, self.model.patch_embed.stride)
        timestamp_seconds = pad_array(batch.timestamp_seconds, self.model.patch_embed.stride)
        time_interval_seconds: torch.Tensor = torch.as_tensor(
            batch.time_interval_seconds, device=series.device, dtype=torch.int
        )

        if num_samples is not None:
            samples = self.generate_samples(
                inputs=series,
                prediction_length=prediction_length,
                num_samples=num_samples,
                timestamp_seconds=timestamp_seconds,
                time_interval_seconds=time_interval_seconds,
                input_padding_mask=padding_mask,
                id_mask=id_mask,
                sampling_batch_size=samples_per_batch,
                use_kv_cache=use_kv_cache,
            )
            mean = samples.mean(dim=-1)
        else:
            mean = self.generate_mean(
                inputs=series,
                prediction_length=prediction_length,
                timestamp_seconds=timestamp_seconds,
                time_interval_seconds=time_interval_seconds,
                input_padding_mask=padding_mask,
                id_mask=id_mask,
                use_kv_cache=use_kv_cache,
            )
            samples = None

        return Forecast(mean=mean, samples=samples)

    @torch.no_grad()
    def generate_mean(
        self,
        inputs: torch.Tensor,
        prediction_length: int,
        timestamp_seconds: torch.Tensor,
        time_interval_seconds: torch.Tensor,
        input_padding_mask: Optional[torch.Tensor] = None,
        id_mask: Optional[torch.Tensor] = None,
        use_kv_cache: bool = False,
    ) -> torch.Tensor:
        """
        Generate a point prediction by taking the mean of the output distribution at each step.
        This method works autoregressively, i.e. it feeds the model's predictions back into itself
        to generate the next prediction.
        """
        if input_padding_mask is None:
            input_padding_mask = torch.ones_like(inputs, dtype=torch.bool, device=inputs.device)
        if id_mask is None:
            id_mask = torch.zeros_like(inputs, dtype=torch.int, device=inputs.device)

        ## round up the prediction length to the nearest multiple of the patch size
        patch_size = self.model.patch_embed.stride
        rounded_steps = int(np.ceil(prediction_length / patch_size) * patch_size)
        start_index = inputs.shape[-1]
        end_index = start_index + prediction_length

        # TODO: maybe pass in future masks, rather than making assumptions here?
        dummy_padding = torch.ones(
            (input_padding_mask.shape[0], input_padding_mask.shape[1], patch_size),
            device=inputs.device,
            dtype=torch.bool,
        )
        dummy_id_mask = repeat(
            id_mask[:, :, -1:],
            "batch variates 1 -> batch variates patch_size",
            patch_size=patch_size,
        )
        if use_kv_cache:
            kv_cache = self.model.allocate_kv_cache(
                batch_size=inputs.shape[0],
                num_variates=inputs.shape[1],
                max_time_steps=inputs.shape[2] + rounded_steps,
                device=inputs.device,
                dtype=inputs.dtype,
            )
        else:
            kv_cache = None

        scaling_prefix_length = inputs.shape[-1]

        for _ in range(rounded_steps // patch_size):
            base_distr, loc, scale = self.model(
                inputs=inputs,
                input_padding_mask=input_padding_mask,
                id_mask=id_mask,
                kv_cache=kv_cache,
                scaling_prefix_length=scaling_prefix_length,
            )
            distr = self.create_affine_transformed(base_distr, loc, scale)

            # We remove extreme values that can occur early in training
            # and cause validation metrics to be NaN
            samples = replace_extreme_values(distr.mean[:, :, -patch_size:])

            inputs = torch.cat([inputs, samples], dim=-1)
            id_mask = torch.cat([id_mask, dummy_id_mask], dim=-1)
            input_padding_mask = torch.cat([input_padding_mask, dummy_padding], dim=-1)
            for _ in range(patch_size):
                next_timestamp: torch.Tensor = timestamp_seconds[:, :, -1] + time_interval_seconds
                timestamp_seconds = torch.cat([timestamp_seconds, next_timestamp.unsqueeze(-1)], dim=-1)

        return inputs.detach()[:, :, start_index:end_index]

    @torch.no_grad()
    def generate_samples(
        self,
        inputs: torch.Tensor,
        prediction_length: int,
        num_samples: int,
        timestamp_seconds: torch.Tensor,
        time_interval_seconds: torch.Tensor,
        input_padding_mask: Optional[torch.Tensor] = None,
        id_mask: Optional[torch.Tensor] = None,
        sampling_batch_size: int = 10,
        use_kv_cache: bool = False,
    ) -> torch.Tensor:
        """
        Generate samples from the output distribution.
        This method works autorregressively, i.e. it feeds the model's predictions back into itself.
        It works by creating num_samples chains. Each chain is a separate sequence of predictions.
        At each time step, for each chain we take a single sample from the output distribution and append
        it to the end of the sequence.
        """
        if input_padding_mask is None:
            input_padding_mask = torch.ones_like(inputs, dtype=torch.bool, device=inputs.device)
        if id_mask is None:
            id_mask = torch.zeros_like(inputs, dtype=torch.int, device=inputs.device)

        assert num_samples % sampling_batch_size == 0, "num_samples must be divisible by sampling_batch_size"
        num_batches = num_samples // sampling_batch_size

        # round up the prediction length to the nearest multiple of the patch size
        patch_size = self.model.patch_embed.patch_size
        rounded_steps = int(np.ceil(prediction_length / patch_size) * patch_size)
        start_index = inputs.shape[-1]
        end_index = start_index + prediction_length

        dummy_padding = torch.ones(
            (
                input_padding_mask.shape[0] * sampling_batch_size,
                input_padding_mask.shape[1],
                patch_size,
            ),
            dtype=torch.bool,
            device=inputs.device,
        )
        dummy_id_mask = repeat(
            id_mask[:, :, -1:],
            "batch variates 1 -> (sampling_batch_size batch) variates patch_size",
            sampling_batch_size=sampling_batch_size,
            patch_size=patch_size,
        )
        inputs = repeat(
            inputs,
            "batch variates seq_len -> (sampling_batch_size batch) variates seq_len",
            sampling_batch_size=sampling_batch_size,
        )
        input_padding_mask = repeat(
            input_padding_mask,
            "batch variates seq_len -> (sampling_batch_size batch) variates seq_len",
            sampling_batch_size=sampling_batch_size,
        )
        id_mask = repeat(
            id_mask,
            "batch variates seq_len -> (sampling_batch_size batch) variates seq_len",
            sampling_batch_size=sampling_batch_size,
        )
        timestamp_seconds = repeat(
            timestamp_seconds,
            "batch variates seq_len -> (sampling_batch_size batch) variates seq_len",
            sampling_batch_size=sampling_batch_size,
        )
        time_interval_seconds = repeat(
            time_interval_seconds,
            "batch variates -> (sampling_batch_size batch) variates",
            sampling_batch_size=sampling_batch_size,
        )

        all_samples = []
        if use_kv_cache:
            kv_cache = self.model.allocate_kv_cache(
                batch_size=inputs.shape[0],
                num_variates=inputs.shape[1],
                max_time_steps=inputs.shape[2] + rounded_steps,
                device=inputs.device,
                dtype=inputs.dtype,
            )
        else:
            kv_cache = None

        scaling_prefix_length = inputs.shape[-1]

        for _ in range(num_batches):
            batch_inputs = torch.clone(inputs)
            batch_input_padding_mask = torch.clone(input_padding_mask)
            batch_id_mask = torch.clone(id_mask)
            batch_timestamp_seconds = torch.clone(timestamp_seconds)

            for _ in range(rounded_steps // patch_size):
                base_distr, loc, scale = self.model(
                    inputs=batch_inputs,
                    input_padding_mask=batch_input_padding_mask,
                    id_mask=batch_id_mask,
                    kv_cache=kv_cache,
                    scaling_prefix_length=scaling_prefix_length,
                )
                distr = self.create_affine_transformed(base_distr, loc, scale)

                sample = distr.sample()
                assert sample is not None

                # We remove extreme values that can occur early in training
                # and cause validation metrics to be NaN
                samples = replace_extreme_values(sample[:, :, -patch_size:])
                batch_inputs = torch.cat([batch_inputs, samples], dim=-1)
                batch_id_mask = torch.cat([batch_id_mask, dummy_id_mask], dim=-1)
                batch_input_padding_mask = torch.cat([batch_input_padding_mask, dummy_padding], dim=-1)
                for _ in range(patch_size):
                    next_timestamp = batch_timestamp_seconds[:, :, -1] + time_interval_seconds
                    batch_timestamp_seconds = torch.cat(
                        [batch_timestamp_seconds, next_timestamp.unsqueeze(-1)], dim=-1
                    )
            all_samples.append(batch_inputs)
            if kv_cache is not None:
                kv_cache.reset()

        outputs = torch.cat(all_samples, dim=0)
        unfolded_outputs = rearrange(
            outputs,
            "(samples batch) variates seq_len -> batch variates seq_len samples",
            samples=num_samples,
        ).detach()

        trimmed_predictions = unfolded_outputs[:, :, start_index:end_index, :]
        return trimmed_predictions

    @staticmethod
    def create_affine_transformed(base_distr: Distribution, loc: torch.Tensor, scale: torch.Tensor) -> Distribution:
        """
        Creates an AffineTransformed distribution with correctly matched shapes.

        Handles three cases:
        1. When loc/scale are per-timestep (from CausalStdMeanScaler)
        2. When base_distr only contains the distribution for the latest patch
           while loc/scale contain values for the entire sequence
        3. When loc/scale have a single time step (from StdMeanScaler/StdMinScaler)
           and need to be broadcast to match a multi-step base distribution

        Args:
            base_distr: The base distribution to transform
            loc: Location parameter
            scale: Scale parameter

        Returns:
            An AffineTransformed distribution with properly handled shapes
        """
        # Get the shape of the base distribution
        # We'll use this to match the time dimension of loc/scale
        base_shape = base_distr.mean.shape

        base_time_dim = base_shape[-1]  # Time dimension of base distribution
        loc_time_dim = loc.shape[-1]  # Time dimension of loc

        if loc_time_dim == 1:
            # Case 1: If loc/scale have time dimension 1 (standard scalers), PyTorch broadcasting will handle it
            return AffineTransformed(base_distr, loc=loc, scale=scale)

        # Case 2: If loc/scale have time dimension > 1 (causal scaler with history)
        # We need to extract only the suffix that matches the base distribution
        return AffineTransformed(base_distr, loc=loc[:, :, -base_time_dim:], scale=scale[:, :, -base_time_dim:])
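As an illustrative sketch (not part of the diff) of how the forecaster above is driven directly: build a MaskedTimeseries for an unbatched series, then request sampled forecasts. It assumes backbone is an already-loaded TotoBackbone, for example the .model attribute of the TotoPretrainedModel wrapper added in the hf_pretrained_model.py hunk further down.

# Illustrative sketch, assuming `backbone` is a loaded TotoBackbone instance.
import torch

from autogluon.timeseries.models.toto._internal.dataset import MaskedTimeseries
from autogluon.timeseries.models.toto._internal.forecaster import TotoForecaster

n_variates, context_length = 1, 512
series = torch.randn(n_variates, context_length)  # unbatched input: variates x time_steps
inputs = MaskedTimeseries(
    series,
    padding_mask=torch.ones_like(series, dtype=torch.bool),
    id_mask=torch.zeros_like(series, dtype=torch.int),
    timestamp_seconds=torch.zeros_like(series, dtype=torch.int),
    time_interval_seconds=torch.tensor([3600]),  # hourly spacing, in seconds
)

forecaster = TotoForecaster(backbone)
forecast = forecaster.forecast(
    inputs,
    prediction_length=64,
    num_samples=128,        # must be divisible by samples_per_batch
    samples_per_batch=32,
)
point = forecast.median      # shape: (batch, variates, prediction_length)
upper = forecast.quantile(0.9)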
autogluon/timeseries/models/toto/dataloader.py
@@ -0,0 +1,108 @@
import functools
import time
from typing import Any, Callable, Iterator, Optional, Union

import numpy as np
import torch

from autogluon.core.utils.exceptions import TimeLimitExceeded
from autogluon.timeseries import TimeSeriesDataFrame

from ._internal.dataset import MaskedTimeseries, freq_to_seconds


class TotoInferenceDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        target_df: TimeSeriesDataFrame,
        max_context_length: int,
        target_column: str = "target",
    ):
        assert max_context_length > 0
        self.max_context_length = max_context_length
        self.target_array = target_df[target_column].to_numpy(dtype=np.float32)

        # store pointer to start:end of each time series
        self.indptr = target_df.get_indptr()

        self.freq = target_df.freq

    def __len__(self):
        return len(self.indptr) - 1  # noqa

    def __getitem__(self, idx) -> np.ndarray:
        start_idx = self.indptr[idx]
        end_idx = self.indptr[idx + 1]

        if end_idx - start_idx > self.max_context_length:
            start_idx = end_idx - self.max_context_length

        return self.target_array[start_idx:end_idx]


class TotoDataLoader:
    def __init__(
        self,
        dataset: TotoInferenceDataset,
        freq: Optional[str] = None,
        batch_size: int = 1,
        time_limit: Optional[Union[int, float]] = None,
        device: Any = None,
    ):
        self.device = torch.device(device)
        self.batch_loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            collate_fn=functools.partial(self._collate, device=self.device),
        )
        self.on_batch = self._get_timeout_callback(time_limit) if time_limit is not None else (lambda *a, **k: None)

        self.freq: str = freq or dataset.freq or "h"

    @staticmethod
    def _get_timeout_callback(seconds: Optional[float]) -> Callable:
        start_time = time.monotonic()

        def callback() -> None:
            if seconds is not None and time.monotonic() - start_time > seconds:
                raise TimeLimitExceeded

        return callback

    @staticmethod
    def _collate(time_series: list[np.ndarray], device: Any) -> torch.Tensor:
        return torch.nn.utils.rnn.pad_sequence(
            sequences=[torch.tensor(c, device=device, dtype=torch.float32) for c in time_series],
            batch_first=True,
            padding_value=torch.nan,
            padding_side="left",
        )

    def __iter__(self) -> Iterator[MaskedTimeseries]:
        for batch in self.batch_loader:
            time_series = batch.unsqueeze(1).to(self.device).to(torch.float32)
            nan_mask = torch.isnan(time_series)
            time_series[nan_mask] = 0.0  # pad with zeros instead of nan

            current_batch_size, _, context_length = time_series.shape

            id_mask = torch.arange(current_batch_size, dtype=torch.int, device=self.device)[:, None, None].repeat(
                1, 1, context_length
            )

            time_interval_seconds = torch.full(
                (current_batch_size, 1),
                fill_value=freq_to_seconds(self.freq),
                device=self.device,
                dtype=torch.int,
            )

            yield MaskedTimeseries(
                time_series,
                padding_mask=~nan_mask,
                id_mask=id_mask,
                timestamp_seconds=torch.zeros_like(time_series, dtype=torch.int),
                time_interval_seconds=time_interval_seconds,
            )

            self.on_batch()
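An illustrative sketch (not part of the diff) of the loader in use, assuming data is an existing TimeSeriesDataFrame with a "target" column:

# Illustrative sketch: iterate MaskedTimeseries batches built from a TimeSeriesDataFrame.
from autogluon.timeseries.models.toto.dataloader import TotoDataLoader, TotoInferenceDataset

dataset = TotoInferenceDataset(target_df=data, max_context_length=2048)
loader = TotoDataLoader(dataset, batch_size=16, device="cpu", time_limit=600)

for masked_ts in loader:
    # each batch carries a left-padded series of shape (batch, 1, context_length);
    # the padding mask is True for observed values and False where _collate padded with NaN
    ...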
autogluon/timeseries/models/toto/hf_pretrained_model.py
@@ -0,0 +1,119 @@
import logging
from typing import Optional

from transformers import PretrainedConfig, PreTrainedModel

from ._internal.backbone import TotoBackbone


class TotoConfig(PretrainedConfig):
    model_type = "toto"

    def __init__(
        self,
        dropout: float = 0.0,
        embed_dim: int = 768,
        num_heads: int = 12,
        num_layers: int = 12,
        output_distribution_classes: Optional[list[str]] = None,
        output_distribution_kwargs: Optional[dict] = None,
        patch_size: int = 64,
        scale_factor_exponent: float = 10.0,
        spacewise_every_n_layers: int = 12,
        spacewise_first: bool = False,
        stabilize_with_global: bool = True,
        stride: int = 64,
        transformers_version: str = "4.49.0",
        use_memory_efficient_attention: bool = False,
        **kwargs,
    ):
        self.dropout = dropout
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.output_distribution_classes = output_distribution_classes or ["MixtureOfStudentTsOutput"]
        self.output_distribution_kwargs = output_distribution_kwargs or {"k_components": 24}
        self.patch_size = patch_size
        self.scale_factor_exponent = scale_factor_exponent
        self.spacewise_every_n_layers = spacewise_every_n_layers
        self.spacewise_first = spacewise_first
        self.stabilize_with_global = stabilize_with_global
        self.stride = stride
        self.transformers_version = transformers_version
        self.use_memory_efficient_attention = use_memory_efficient_attention

        super().__init__(**kwargs)


class TotoPretrainedModel(PreTrainedModel):
    config_class = TotoConfig
    base_model_prefix = "model"  # optional, used for weight naming conventions

    def __init__(self, config: TotoConfig):
        super().__init__(config)
        self.model = TotoBackbone(
            patch_size=config.patch_size,
            stride=config.stride,
            embed_dim=config.embed_dim,
            num_layers=config.num_layers,
            num_heads=config.num_heads,
            mlp_hidden_dim=getattr(config, "mlp_hidden_dim", 3072),
            dropout=config.dropout,
            spacewise_every_n_layers=config.spacewise_every_n_layers,
            scaler_cls=getattr(config, "scaler_cls", "model.scaler.CausalPatchStdMeanScaler"),
            output_distribution_classes=config.output_distribution_classes,
            spacewise_first=config.spacewise_first,
            output_distribution_kwargs=config.output_distribution_kwargs,
            use_memory_efficient_attention=False,
            stabilize_with_global=config.stabilize_with_global,
            scale_factor_exponent=config.scale_factor_exponent,
            **getattr(config, "extra_kwargs", {}),
        )
        self._register_load_state_dict_pre_hook(self._remap_state_dict_keys_hook)
        self.post_init()

    def _remap_state_dict_keys_hook(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        remap = {
            "mlp.0.w12.weight": "mlp.0.weight",
            "mlp.0.w12.bias": "mlp.0.bias",
            "mlp.0.w3.weight": "mlp.2.weight",
            "mlp.0.w3.bias": "mlp.2.bias",
        }

        keys_to_remap = []
        for key in list(state_dict.keys()):
            for old, new in remap.items():
                if old in key:
                    new_key = key.replace(old, new)
                    keys_to_remap.append((key, new_key))
                    break

        for old_key, new_key in keys_to_remap:
            state_dict[new_key] = state_dict.pop(old_key)

    @classmethod
    def from_pretrained(cls, model_name_or_path, config=None, torch_dtype=None, device_map=None, **kwargs):
        transformers_logger = logging.getLogger("transformers.modeling_utils")
        original_level = transformers_logger.level

        try:
            # Here we suppress transformers logger's "some weights were not initialized" error since the
            # remapping hook is only called after the initial model loading.
            transformers_logger.setLevel(logging.ERROR)

            # Transformers follows a different load path that does not call load_state_dict hooks when
            # loading with explicit device maps. Here, we first load the model with no device maps and
            # move it.
            model = super().from_pretrained(model_name_or_path, config=config, torch_dtype=torch_dtype, **kwargs)
            if device_map is not None:
                model = model.to(device_map)

        finally:
            transformers_logger.setLevel(original_level)

        return model

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)
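Finally, an illustrative loading sketch (not part of the diff) for the wrapper above. The checkpoint id "Datadog/Toto-Open-Base-1.0" refers to the publicly released Toto weights on Hugging Face and is an assumption here, not something stated in this diff.

# Illustrative sketch: load the checkpoint through TotoPretrainedModel, then hand the
# underlying backbone to TotoForecaster (see the forecaster.py hunk above).
import torch

from autogluon.timeseries.models.toto.hf_pretrained_model import TotoPretrainedModel

wrapper = TotoPretrainedModel.from_pretrained(
    "Datadog/Toto-Open-Base-1.0",  # assumed checkpoint id
    torch_dtype=torch.float32,
    device_map="cpu",  # applied via .to() after loading, so the remap hook still runs
)
backbone = wrapper.model  # the TotoBackbone consumed by TotoForecaster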