tsagentkit-timesfm 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
timesfm/flax/util.py ADDED
@@ -0,0 +1,107 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Flax utility functions for TimesFM layers."""
16
+
17
+ import dataclasses
18
+ import functools
19
+ import jax
20
+ import jax.numpy as jnp
21
+ import jaxtyping
22
+
23
+ Float = jaxtyping.Float
24
+ Array = jaxtyping.Array
25
+ Bool = jaxtyping.Bool
26
+ Integer = jaxtyping.Integer
27
+
28
+ _TOLERANCE = 1e-6
29
+
30
+
31
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=False)
class DecodeCache:
  """Cache for decoding.

  Registered as a JAX dataclass pytree so instances can be passed through
  `jax.jit`/`lax.scan` transformations. Declared `frozen=False` so the decode
  loop can rebind fields between steps.
  """

  # Presumably the next cache slot to write per example — confirm at call
  # site. Shape: (batch,).
  next_index: Integer[Array, "b"]
  # Presumably the count of masked (padded) positions per example. Shape: (batch,).
  num_masked: Integer[Array, "b"]
  # Cached attention keys, shape (batch, seq, heads, head_dim).
  key: Float[Array, "b n h d"]
  # Cached attention values, shape (batch, seq, heads, head_dim).
  value: Float[Array, "b n h d"]
40
+
41
+
42
@jax.jit
def update_running_stats(
    n: Float[Array, "b"],
    mu: Float[Array, "b"],
    sigma: Float[Array, "b"],
    x: Float[Array, "b p"],
    mask: Bool[Array, "b p"],
) -> tuple[
    tuple[Float[Array, "b"], Float[Array, "b"], Float[Array, "b"]],
    tuple[Float[Array, "b"], Float[Array, "b"], Float[Array, "b"]],
]:
  """Updates the running stats.

  Folds one new patch `x` into per-example running (count, mean, std) using
  the standard parallel mean/variance combination. Positions where `mask` is
  True are excluded, and the `jnp.where` guards keep an all-masked patch (or
  an all-zero running count) from producing divisions by zero.

  Args:
    n: Running count of observed (unmasked) values, shape (batch,).
    mu: Running mean, shape (batch,).
    sigma: Running standard deviation, shape (batch,).
    x: New patch of values, shape (batch, patch).
    mask: True where `x` is padding to be ignored, shape (batch, patch).

  Returns:
    A pair `(stats, stats)` where `stats = (new_n, new_mu, new_sigma)`; the
    same tuple is returned twice, matching the `(carry, output)` convention
    expected by `jax.lax.scan`.
  """
  # Positions that actually contribute to the statistics.
  is_legit = jnp.logical_not(mask)
  # Count/mean/std of the new patch alone; the `where=` reductions are
  # ill-defined when no position is legit, hence the `inc_n == 0` guards.
  inc_n = jnp.sum(is_legit.astype(jnp.float32), axis=-1, keepdims=False)
  inc_mu = jnp.where(
      inc_n == 0, 0.0, jnp.mean(x, axis=-1, keepdims=False, where=is_legit)
  )
  inc_sigma = jnp.where(
      inc_n == 0, 0.0, jnp.std(x, axis=-1, keepdims=False, where=is_legit)
  )
  new_n = n + inc_n
  # Combined mean: count-weighted average of the old and patch means.
  new_mu = jnp.where(new_n == 0, 0.0, (n * mu + inc_mu * inc_n) / new_n)
  # Combined variance: count-weighted variances plus the between-group
  # shifts of each mean relative to the combined mean.
  new_sigma = jnp.sqrt(
      jnp.where(
          new_n == 0,
          0.0,
          (
              n * sigma * sigma
              + inc_n * inc_sigma * inc_sigma
              + n * (mu - new_mu) * (mu - new_mu)
              + inc_n * (inc_mu - new_mu) * (inc_mu - new_mu)
          )
          / new_n,
      )
  )
  # Return the new stats twice: once as the scan carry, once as the output.
  return (w := (new_n, new_mu, new_sigma), w)
78
+
79
+
80
def scan_along_axis(f, init, xs, axis: int, **kwargs):
  """Runs `jax.lax.scan` over an arbitrary axis of `xs`.

  `lax.scan` always iterates over the leading axis, so every leaf of the
  `xs` pytree is rotated to put `axis` first, scanned, and the stacked
  outputs are rotated back to the original layout.

  Args:
    f: Scan body, `(carry, slice) -> (carry, output)`.
    init: Initial carry.
    xs: Pytree of arrays to scan over.
    axis: Axis of each leaf to iterate along.
    **kwargs: Forwarded to `jax.lax.scan` (e.g. `length`, `reverse`).

  Returns:
    `(carry, ys)` with each leaf of `ys` moved back so the scanned axis
    sits at `axis`.
  """

  def to_front(leaf):
    return jnp.moveaxis(leaf, axis, 0)

  def from_front(leaf):
    return jnp.moveaxis(leaf, 0, axis)

  rotated = jax.tree_util.tree_map(to_front, xs)
  carry, stacked = jax.lax.scan(f, init, rotated, **kwargs)
  return carry, jax.tree_util.tree_map(from_front, stacked)
88
+
89
+
90
@functools.partial(jax.jit, static_argnames=("reverse",))
def revin(
    x: Float[Array, "b ..."],
    mu: Float[Array, "b ..."],
    sigma: Float[Array, "b ..."],
    reverse: bool = False,
):
  """Reversible per-instance normalization.

  Forward (`reverse=False`) normalizes `x` with per-instance statistics;
  reverse applies the inverse affine map. `mu`/`sigma` may trail `x` by one
  or two axes and are broadcast by appending singleton dimensions.
  """
  # Align the stats' rank with x by appending singleton axes.
  rank_gap = len(x.shape) - len(mu.shape)
  if rank_gap == 1:
    mu, sigma = mu[..., None], sigma[..., None]
  elif rank_gap == 2:
    mu, sigma = mu[..., None, None], sigma[..., None, None]
  if reverse:
    return x * sigma + mu
  # Forward direction only: clamp near-zero scales to 1 to avoid division
  # blow-ups (the reverse path intentionally uses the raw sigma).
  safe_sigma = jnp.where(sigma < _TOLERANCE, 1.0, sigma)
  return (x - mu) / safe_sigma
@@ -0,0 +1,422 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """TimesFM 2p5 base implementation."""
16
+
17
+ import dataclasses
18
+ from typing import Any, Callable, Sequence
19
+
20
+ import collections
21
+ import numpy as np
22
+
23
+ from .. import configs
24
+
25
+ ResidualBlockConfig = configs.ResidualBlockConfig
26
+ StackedTransformersConfig = configs.StackedTransformersConfig
27
+ TransformerConfig = configs.TransformerConfig
28
+ ForecastConfig = configs.ForecastConfig
29
+ Category = int | str
30
+ XRegMode = str
31
+
32
+
33
def strip_leading_nans(arr):
  """Removes contiguous NaN values from the beginning of a NumPy array.

  Args:
    arr: The input 1D NumPy array.

  Returns:
    A NumPy array (a view of `arr`) with leading NaN values removed.
    If the array is all NaNs or empty, returns an empty array.
  """

  isnan = np.isnan(arr)
  # Bug fix: `np.argmax(~isnan)` returns 0 when there is no valid element
  # at all, which used to return the entire all-NaN array instead of the
  # empty array promised by the docstring.
  if isnan.all():
    return arr[:0]
  first_valid_index = np.argmax(~isnan)
  return arr[first_valid_index:]
47
+
48
+
49
def linear_interpolation(arr):
  """Performs linear interpolation to fill NaN values in a 1D numpy array.

  Args:
    arr: The 1D numpy array containing NaN values.

  Returns:
    A new numpy array with NaN values filled using linear interpolation
    (leading/trailing NaNs take the nearest valid value, per `np.interp`
    edge behavior), or the original array object unchanged if no NaNs are
    present. If the array contains no valid values at all, NaNs are
    replaced with 0.0.
  """

  nans = np.isnan(arr)
  if not np.any(nans):  # Fast path: nothing to fill.
    return arr

  # Bug fix: operate on a copy. The original wrote `arr[nans] = ...` into
  # the caller's buffer, contradicting the documented "new numpy array".
  arr = arr.copy()

  def indices_of(z):
    return z.nonzero()[0]

  nans_indices = indices_of(nans)
  non_nans_indices = indices_of(~nans)
  non_nans_values = arr[~nans]

  try:
    arr[nans] = np.interp(nans_indices, non_nans_indices, non_nans_values)
  except ValueError:
    # np.interp raises when the set of sample points is empty (all-NaN
    # input). Fall back to a constant fill. Bug fix: the original tested
    # `if non_nans_values:`, which raises ValueError for arrays with more
    # than one element; use `.size` for an unambiguous emptiness check.
    mu = np.nanmean(arr) if non_nans_values.size else 0.0
    arr = np.where(np.isfinite(arr), arr, mu)
  return arr
82
+
83
+
84
@dataclasses.dataclass(frozen=True)
class TimesFM_2p5_200M_Definition:
  """Framework-agnostic config of TimesFM 2.5.

  Architecture hyperparameters of the 200M-parameter TimesFM 2.5 model,
  consumed by the framework-specific (e.g. Flax) implementations.
  """

  # NOTE(review): no type annotation, so this is a plain class attribute,
  # not a dataclass field (it will not appear in __init__ / asdict).
  # Presumably intentional — confirm.
  context_limit = 16384
  # Length of each input patch.
  input_patch_len: int = 32
  # Length of each output patch (point forecast head).
  output_patch_len: int = 128
  # Length of the quantile output.
  output_quantile_len: int = 1024
  # Quantile levels predicted by the quantile head.
  quantiles: list[float] = dataclasses.field(
      default_factory=lambda: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
  )
  # Index used when decoding; exact semantics (which head/quantile it
  # selects) are not visible here — TODO confirm against the decoder.
  decode_index: int = 5
  # Residual-block tokenizer: 64 input dims -> 1280-dim embeddings.
  tokenizer: ResidualBlockConfig = ResidualBlockConfig(
      input_dims=64,
      hidden_dims=1280,
      output_dims=1280,
      use_bias=True,
      activation="swish",
  )
  # 20-layer transformer stack: 1280 model dims, 16 heads, RMS norms,
  # rotary position embeddings, fused QKV projections.
  stacked_transformers: StackedTransformersConfig = StackedTransformersConfig(
      num_layers=20,
      transformer=TransformerConfig(
          model_dims=1280,
          hidden_dims=1280,
          num_heads=16,
          attention_norm="rms",
          feedforward_norm="rms",
          qk_norm="rms",
          use_bias=False,
          use_rotary_position_embeddings=True,
          ff_activation="swish",
          fuse_qkv=True,
      ),
  )
  # Output projection for the point forecast (1280 -> 1280).
  output_projection_point: ResidualBlockConfig = ResidualBlockConfig(
      input_dims=1280,
      hidden_dims=1280,
      output_dims=1280,
      use_bias=False,
      activation="swish",
  )
  # Output projection for the quantile forecast (1280 -> 10240).
  output_projection_quantiles: ResidualBlockConfig = ResidualBlockConfig(
      input_dims=1280,
      hidden_dims=1280,
      output_dims=10240,
      use_bias=False,
      activation="swish",
  )
132
+
133
+
134
class TimesFM_2p5:
  """Abstract base class for TimesFM models.

  Subclasses implement `load_checkpoint` and `compile`; this base provides
  the batching/preprocessing driver (`forecast`) and the covariate-regression
  wrapper (`forecast_with_covariates`).

  Attributes:
    forecast_config: Configuration for forecasting flags; None until
      `compile()` has been called.
    compiled_decode: Compiled decode function taking
      `(horizon, values, masks)` and returning point and quantile forecasts;
      None until `compile()` has been called.
    global_batch_size: Global batch size used per compiled decode call.
  """

  forecast_config: ForecastConfig | None = None
  compiled_decode: Callable[..., Any] | None = None
  global_batch_size: int = 0

  def load_checkpoint(self, path: str):
    """Loads a TimesFM model from a checkpoint."""
    raise NotImplementedError()

  def compile(self, forecast_config: ForecastConfig | None = None):
    """Compiles the TimesFM model for fast decoding."""
    raise NotImplementedError()

  def forecast(
      self, horizon: int, inputs: list[np.ndarray]
  ) -> tuple[np.ndarray, np.ndarray]:
    """Forecasts the time series.

    Each series is preprocessed (leading NaNs stripped, interior NaNs
    linearly interpolated, then left-padded with zeros or truncated to
    `max_context`), grouped into batches of `global_batch_size`, and fed to
    the compiled decode function.

    Args:
      horizon: Number of future steps to forecast.
      inputs: List of 1D context arrays; may contain NaNs.

    Returns:
      A `(point_forecast, quantile_forecast)` pair whose leading dimension
      equals `len(inputs)`.

    Raises:
      RuntimeError: If the model has not been compiled.
    """
    if self.compiled_decode is None:
      raise RuntimeError("Model is not compiled. Please call compile() first.")

    assert self.global_batch_size > 0
    assert self.forecast_config is not None

    context = self.forecast_config.max_context
    num_inputs = len(inputs)
    # Pad with short dummy series so the total divides the batch size.
    # NOTE(review): `inputs +=` mutates the caller's list in place — confirm
    # this side effect is acceptable to callers.
    if (w := num_inputs % self.global_batch_size) != 0:
      inputs += [np.array([0.0] * 3)] * (self.global_batch_size - w)

    output_points = []
    output_quantiles = []
    values = []
    masks = []
    idx = 0
    for each_input in inputs:
      # Clean the series: drop leading NaNs, fill interior NaNs.
      value = linear_interpolation(strip_leading_nans(np.array(each_input)))
      if (w := len(value)) >= context:
        # Long series: keep only the most recent `context` points.
        value = value[-context:]
        mask = np.zeros_like(value, dtype=bool)
      else:
        # Short series: left-pad with zeros and mark the padding in the mask.
        mask = np.array([True] * (context - w) + [False] * w)
        value = np.pad(value, (context - w, 0), "constant", constant_values=0.0)
      values.append(value)
      masks.append(mask)
      idx += 1
      if idx == self.global_batch_size:
        # Full batch accumulated: run one compiled decode call.
        idx = 0
        point_forecast, quantile_forecast = self.compiled_decode(horizon, values, masks)
        output_points.append(point_forecast)
        output_quantiles.append(quantile_forecast)
        values = []
        masks = []

    output_points = np.concatenate(output_points, axis=0)
    output_quantiles = np.concatenate(output_quantiles, axis=0)
    # Drop the results for the dummy padding entries.
    return output_points[:num_inputs], output_quantiles[:num_inputs]

  def forecast_with_covariates(
      self,
      inputs: list[Sequence[float]],
      dynamic_numerical_covariates: dict[str, Sequence[Sequence[float]]] | None = None,
      dynamic_categorical_covariates: (
          dict[str, Sequence[Sequence[Category]]] | None
      ) = None,
      static_numerical_covariates: dict[str, Sequence[float]] | None = None,
      static_categorical_covariates: dict[str, Sequence[Category]] | None = None,
      xreg_mode: XRegMode = "xreg + timesfm",
      normalize_xreg_target_per_input: bool = True,
      ridge: float = 0.0,
      max_rows_per_col: int = 0,
      force_on_cpu: bool = False,
  ):
    """Forecasts on a list of time series with covariates.

    To optimize inference speed, avoid string valued categorical covariates.

    Args:
      inputs: A list of time series forecast contexts. Each context time series
        should be in a format convertible to JTensor by `jnp.array`.
      dynamic_numerical_covariates: A dict of dynamic numerical covariates.
      dynamic_categorical_covariates: A dict of dynamic categorical covariates.
      static_numerical_covariates: A dict of static numerical covariates.
      static_categorical_covariates: A dict of static categorical covariates.
      xreg_mode: one of "xreg + timesfm" or "timesfm + xreg". "xreg + timesfm"
        fits a model on the residuals of the TimesFM forecast. "timesfm + xreg"
        fits a model on the targets then forecasts on the residuals via TimesFM.
      normalize_xreg_target_per_input: whether to normalize the xreg target per
        input in the given batch.
      ridge: ridge penalty for the linear model.
      max_rows_per_col: max number of rows per column for the linear model.
      force_on_cpu: whether to force running on cpu for the linear model.

    Returns:
      A tuple of two lists. The first is the outputs of the model. The second is
      the outputs of the xreg.

    Raises:
      ValueError: If the model is not compiled, `return_backcast` is off, no
        covariates are given, the mode is unknown, or an inferred horizon
        exceeds `max_horizon`.
    """
    if self.forecast_config is None:
      raise ValueError("Model is not compiled. Please call compile() first.")
    elif not self.forecast_config.return_backcast:
      raise ValueError(
          "For XReg, `return_backcast` must be set to True in the forecast config. Please recompile the model."
      )

    # Imported lazily to avoid a hard dependency for plain forecasting.
    from ..utils import xreg_lib

    # Verify and bookkeep covariates.
    if not (
        dynamic_numerical_covariates
        or dynamic_categorical_covariates
        or static_numerical_covariates
        or static_categorical_covariates
    ):
      raise ValueError(
          "At least one of dynamic_numerical_covariates,"
          " dynamic_categorical_covariates, static_numerical_covariates,"
          " static_categorical_covariates must be set."
      )

    # Track the lengths of (1) each input, (2) the part that can be used in the
    # linear model, and (3) the horizon.
    input_lens, train_lens, test_lens = [], [], []

    for i, input_ts in enumerate(inputs):
      input_len = len(input_ts)
      input_lens.append(input_len)

      if xreg_mode == "timesfm + xreg":
        # For fitting residuals, no TimesFM forecast on the first patch.
        # NOTE(review): `self.model.p` is not defined anywhere in this base
        # class — presumably provided by a subclass; confirm.
        train_lens.append(max(0, input_len - self.model.p))
      elif xreg_mode == "xreg + timesfm":
        train_lens.append(input_len)
      else:
        raise ValueError(f"Unsupported mode: {xreg_mode}")

      # Horizon per series: inferred from dynamic covariate length when
      # available, otherwise the compiled max horizon.
      if dynamic_numerical_covariates:
        test_lens.append(
            len(list(dynamic_numerical_covariates.values())[0][i]) - input_len
        )
      elif dynamic_categorical_covariates:
        test_lens.append(
            len(list(dynamic_categorical_covariates.values())[0][i]) - input_len
        )
      else:
        test_lens.append(self.forecast_config.max_horizon)

      if test_lens[-1] > self.forecast_config.max_horizon:
        # NOTE(review): message has a typo ('covaraites') and the adjacent
        # string literals concatenate without a space ("themax_horizon");
        # left byte-identical here since this is runtime text.
        raise ValueError(
            "Forecast horizon length inferred from the dynamic covaraites is longer than the"
            f"max_horizon defined in the forecast config: {test_lens[-1]} > {self.forecast_config.max_horizon=}."
        )

    # Prepare the covariates into train and test.
    train_dynamic_numerical_covariates = collections.defaultdict(list)
    test_dynamic_numerical_covariates = collections.defaultdict(list)
    train_dynamic_categorical_covariates = collections.defaultdict(list)
    test_dynamic_categorical_covariates = collections.defaultdict(list)
    # Split each dynamic covariate at the input boundary: the in-context part
    # trains the linear model; the part past the context is the test horizon.
    for covariates, train_covariates, test_covariates in (
        (
            dynamic_numerical_covariates,
            train_dynamic_numerical_covariates,
            test_dynamic_numerical_covariates,
        ),
        (
            dynamic_categorical_covariates,
            train_dynamic_categorical_covariates,
            test_dynamic_categorical_covariates,
        ),
    ):
      if not covariates:
        continue
      for covariate_name, covariate_values in covariates.items():
        for input_len, train_len, covariate_value in zip(
            input_lens, train_lens, covariate_values
        ):
          train_covariates[covariate_name].append(
              covariate_value[(input_len - train_len) : input_len]
          )
          test_covariates[covariate_name].append(covariate_value[input_len:])

    # Fit models.
    if xreg_mode == "timesfm + xreg":
      # Forecast via TimesFM then fit a model on the residuals.
      point_outputs, quantile_outputs = self.forecast(
          horizon=self.forecast_config.max_horizon, inputs=inputs
      )
      # Residual targets: actuals minus the TimesFM backcast over the
      # trainable tail of each input.
      targets = [
          (
              np.array(input_ts)[-train_len:]
              - point_output[: -self.forecast_config.max_horizon][-train_len:]
          )
          for input_ts, point_output, train_len in zip(inputs, point_outputs, train_lens)
      ]
      per_instance_stats = None
      if normalize_xreg_target_per_input:
        targets, per_instance_stats = xreg_lib.normalize(targets)
      xregs = xreg_lib.BatchedInContextXRegLinear(
          targets=targets,
          train_lens=train_lens,
          test_lens=test_lens,
          train_dynamic_numerical_covariates=train_dynamic_numerical_covariates,
          test_dynamic_numerical_covariates=test_dynamic_numerical_covariates,
          train_dynamic_categorical_covariates=train_dynamic_categorical_covariates,
          test_dynamic_categorical_covariates=test_dynamic_categorical_covariates,
          static_numerical_covariates=static_numerical_covariates,
          static_categorical_covariates=static_categorical_covariates,
      ).fit(
          ridge=ridge,
          # Drop one one-hot column only for the unpenalized fit, to avoid
          # collinearity; ridge regularization handles it otherwise.
          one_hot_encoder_drop=None if ridge > 0 else "first",
          max_rows_per_col=max_rows_per_col,
          force_on_cpu=force_on_cpu,
          debug_info=False,
          assert_covariates=True,
          assert_covariate_shapes=True,
      )
      if normalize_xreg_target_per_input:
        xregs = xreg_lib.renormalize(xregs, per_instance_stats)
      xregs = np.array(xregs)
      # Final outputs: TimesFM horizon forecast plus the xreg correction,
      # truncated to each series' inferred horizon.
      new_point_outputs = [
          (point_output[-self.forecast_config.max_horizon :][:test_len] + xreg)
          for point_output, test_len, xreg in zip(point_outputs, test_lens, xregs)
      ]
      new_quantile_outputs = [
          (
              quantile_output[-self.forecast_config.max_horizon :][:test_len]
              + xreg[..., None]
          )
          for quantile_output, test_len, xreg in zip(quantile_outputs, test_lens, xregs)
      ]

    else:
      # Fit a model on the targets then forecast on the residuals via TimesFM.
      targets = [
          np.array(input_ts)[-train_len:]
          for input_ts, train_len in zip(inputs, train_lens)
      ]
      per_instance_stats = None
      if normalize_xreg_target_per_input:
        targets, per_instance_stats = xreg_lib.normalize(targets)
      # debug_info=True also returns the in-context xreg fit, needed below.
      xregs, xregs_on_context, _, _, _ = xreg_lib.BatchedInContextXRegLinear(
          targets=targets,
          train_lens=train_lens,
          test_lens=test_lens,
          train_dynamic_numerical_covariates=train_dynamic_numerical_covariates,
          test_dynamic_numerical_covariates=test_dynamic_numerical_covariates,
          train_dynamic_categorical_covariates=train_dynamic_categorical_covariates,
          test_dynamic_categorical_covariates=test_dynamic_categorical_covariates,
          static_numerical_covariates=static_numerical_covariates,
          static_categorical_covariates=static_categorical_covariates,
      ).fit(
          ridge=ridge,
          one_hot_encoder_drop=None if ridge > 0 else "first",
          max_rows_per_col=max_rows_per_col,
          force_on_cpu=force_on_cpu,
          debug_info=True,
          assert_covariates=True,
          assert_covariate_shapes=True,
      )
      # TimesFM forecasts the residuals left after the linear fit.
      point_outputs, quantile_outputs = self.forecast(
          horizon=self.forecast_config.max_horizon,
          inputs=[
              target - xreg_on_context
              for target, xreg_on_context in zip(targets, xregs_on_context)
          ],
      )
      new_point_outputs = [
          (point_output[-self.forecast_config.max_horizon :][:test_len] + xreg)
          for point_output, test_len, xreg in zip(point_outputs, test_lens, xregs)
      ]
      new_quantile_outputs = [
          (
              quantile_output[-self.forecast_config.max_horizon :][:test_len]
              + xreg[..., None]
          )
          for quantile_output, test_len, xreg in zip(quantile_outputs, test_lens, xregs)
      ]
      # Undo the per-instance normalization applied to the targets in this
      # branch (the other branch renormalized the xreg correction instead).
      if normalize_xreg_target_per_input:
        new_point_outputs = xreg_lib.renormalize(new_point_outputs, per_instance_stats)
        new_quantile_outputs = xreg_lib.renormalize(
            new_quantile_outputs, per_instance_stats
        )

    return new_point_outputs, new_quantile_outputs