tsagentkit-timesfm 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,520 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Helper functions for in-context covariates and regression."""
15
+
16
+ import itertools
17
+ import math
18
+ from typing import Any, Iterable, Literal, Mapping, Sequence
19
+
20
+ try:
21
+ import jax
22
+ import jax.numpy as jnp
23
+ import numpy as np
24
+ from sklearn import preprocessing
25
+ except ImportError:
26
+ raise ImportError(
27
+ "Failed to load the XReg module. Did you forget to install `timesfm[xreg]`?"
28
+ )
29
+
30
+ Category = int | str
31
+
32
+ _TOL = 1e-6
33
+ XRegMode = Literal["timesfm + xreg", "xreg + timesfm"]
34
+
35
+
36
def _unnest(nested: Sequence[Sequence[Any]]) -> np.ndarray:
    """Concatenates the inner sequences of `nested`, in order, into one array.

    Args:
      nested: A sequence of sequences, e.g. one list of values per task.

    Returns:
      A numpy array holding every inner element, inner sequences laid out one
      after another in their original order.
    """
    flat = [item for inner in nested for item in inner]
    return np.array(flat)
38
+
39
+
40
def _repeat(elements: Iterable[Any], counts: Iterable[int]) -> np.ndarray:
    """Repeats each element as many times as its paired entry in `counts`.

    `elements` and `counts` are walked in lockstep, stopping at the shorter of
    the two. When the elements are rows of a 2-D array, whole rows are
    repeated, so the result is again 2-D.

    Args:
      elements: Values (or array rows) to repeat.
      counts: Number of repetitions for each element.

    Returns:
      A numpy array of the repeated elements, in order.
    """
    repeated = [
        element for element, count in zip(elements, counts) for _ in range(count)
    ]
    return np.array(repeated)
44
+
45
+
46
def _pow2_padding(size: int) -> int:
    """Returns how many entries to append so `size` becomes a power of two.

    Uses exact integer arithmetic: `math.log2` raises for a size of 0 and can
    misround for very large sizes. An empty axis (size 0) needs no padding.
    """
    if size <= 0:
        return 0
    return (1 << (size - 1).bit_length()) - size


def _to_padded_jax_array(x: np.ndarray) -> jax.Array:
    """Zero-pads each axis of a 1-D or 2-D array up to the next power of two.

    Padding is appended at the end of every axis, so existing values keep
    their indices. Bucketing shapes to powers of two limits how many distinct
    shapes the downstream solver has to compile for.

    Args:
      x: The array to pad. It must be 1-D or 2-D.

    Returns:
      A `jax.Array` whose non-empty axes have power-of-two lengths. It holds
      `x` in its leading corner and zeros elsewhere; empty axes stay empty.

    Raises:
      ValueError: If `x` is not 1-D or 2-D.
    """
    if x.ndim not in (1, 2):
        raise ValueError(f"Unsupported array shape: {x.shape}")
    pad_width = tuple((0, _pow2_padding(n)) for n in x.shape)
    return jnp.pad(x, pad_width, mode="constant", constant_values=0.0)
+
59
+
60
+ # Per time series normalization: forward.
61
def normalize(batch):
    """Standardizes each series in `batch` on its own (forward transform).

    Every series is shifted by its mean and divided by its standard deviation.
    When that standard deviation is not greater than `_TOL`, a scale of 1.0 is
    used instead, so near-constant series are only centered.

    Args:
      batch: A sequence of numeric arrays, one per time series. Assumed to be
        numpy arrays, since plain lists do not support `-`.

    Returns:
      A pair `(normalized, stats)`:
      - `normalized` is the list of transformed series.
      - `stats` is the list of `(mean, scale)` pairs, suitable for passing to
        `renormalize`.
    """
    stats = []
    for series in batch:
        mean = np.mean(series)
        std = np.std(series)
        stats.append((mean, np.where(std > _TOL, std, 1.0)))
    normalized = [
        (series - center) / scale for series, (center, scale) in zip(batch, stats)
    ]
    return normalized, stats
65
+
66
+
67
+ # Per time series normalization: inverse.
68
def renormalize(batch, stats):
    """Maps normalized series back to their original scale (inverse transform).

    For each series and its paired statistics, computes
    `series * stat[1] + stat[0]`, i.e. `series * scale + mean` for the
    `(mean, scale)` pairs produced by `normalize`.

    Args:
      batch: A sequence of normalized numeric arrays.
      stats: Per-series statistics, paired with `batch` in order.

    Returns:
      A list of the restored series.
    """
    restored = []
    for series, stat in zip(batch, stats):
        offset, scale = stat[0], stat[1]
        restored.append(series * scale + offset)
    return restored
70
+
71
+
72
class BatchedInContextXRegBase:
    """Helper class for in-context regression covariate formatting.

    Covariate arguments left as `None` are stored as empty dicts.

    Attributes:
      targets: List of targets (responses) of the in-context regression.
      train_lens: List of lengths of each target vector from the context.
      test_lens: List of lengths of each forecast horizon.
      train_dynamic_numerical_covariates: Dict of covariate names mapping to the
        dynamic numerical covariates of each forecast task on the context. Their
        lengths should match the corresponding lengths in `train_lens`.
      train_dynamic_categorical_covariates: Dict of covariate names mapping to
        the dynamic categorical covariates of each forecast task on the context.
        Their lengths should match the corresponding lengths in `train_lens`.
      test_dynamic_numerical_covariates: Dict of covariate names mapping to the
        dynamic numerical covariates of each forecast task on the horizon. Their
        lengths should match the corresponding lengths in `test_lens`.
      test_dynamic_categorical_covariates: Dict of covariate names mapping to
        the dynamic categorical covariates of each forecast task on the horizon.
        Their lengths should match the corresponding lengths in `test_lens`.
      static_numerical_covariates: Dict of covariate names mapping to the static
        numerical covariates of each forecast task.
      static_categorical_covariates: Dict of covariate names mapping to the
        static categorical covariates of each forecast task.
    """

    def __init__(
        self,
        targets: Sequence[Sequence[float]],
        train_lens: Sequence[int],
        test_lens: Sequence[int],
        train_dynamic_numerical_covariates: (
            Mapping[str, Sequence[Sequence[float]]] | None
        ) = None,
        train_dynamic_categorical_covariates: (
            Mapping[str, Sequence[Sequence[Category]]] | None
        ) = None,
        test_dynamic_numerical_covariates: (
            Mapping[str, Sequence[Sequence[float]]] | None
        ) = None,
        test_dynamic_categorical_covariates: (
            Mapping[str, Sequence[Sequence[Category]]] | None
        ) = None,
        static_numerical_covariates: Mapping[str, Sequence[float]] | None = None,
        static_categorical_covariates: (Mapping[str, Sequence[Category]] | None) = None,
    ) -> None:
        """Initializes with the exogenous covariate inputs.

        Here we use model fitting language to refer to the context as 'train'
        and the horizon as 'test'. We assume batched inputs. To properly format
        the request:

        - `train_lens` represents the contexts in the batch. Targets and all
          train dynamic covariates should have the same lengths as the
          corresponding elements in `train_lens`. Notice each `train_len` can be
          different from the exact length of the corresponding context,
          depending on how much of the context is used for fitting the
          in-context model.
        - `test_lens` represents the horizon lengths in the batch. All test
          dynamic covariates should have the same lengths as the corresponding
          elements in `test_lens`.
        - Static covariates should be one for each input.
        - Train and test dynamic covariates should have the same covariate
          names.

        Pass an empty dict {} (or leave the argument as None) for a covariate
        type if it is not present.

        Example:
          Here is a set of valid inputs whose schema can be used for reference.
          ```
          targets = [
              [0.0, 0.1, 0.2],
              [0.0, 0.1, 0.2, 0.3],
          ]  # Two inputs in this batch.
          train_lens = [3, 4]
          test_lens = [2, 5]  # Forecast horizons 2 and 5 respectively.
          train_dynamic_numerical_covariates = {
              "cov_1_dn": [[0.0, 0.5, 1.0], [0.0, 0.5, 1.0, 1.5]],
              "cov_2_dn": [[0.0, 1.5, 1.0], [0.0, 1.5, 1.0, 2.5]],
          }  # Each train dynamic covariate has 3 and 4 elements respectively.
          test_dynamic_numerical_covariates = {
              "cov_1_dn": [[0.1, 0.6], [0.1, 0.6, 1.1, 1.6, 2.4]],
              "cov_2_dn": [[0.1, 1.1], [0.1, 1.6, 1.1, 2.6, 10.0]],
          }  # Each test dynamic covariate has 2 and 5 elements respectively.
          train_dynamic_categorical_covariates = {
              "cov_1_dc": [[0, 1, 0], [0, 1, 2, 3]],
              "cov_2_dc": [["good", "bad", "good"], ["good", "good", "bad", "bad"]],
          }
          test_dynamic_categorical_covariates = {
              "cov_1_dc": [[1, 0], [1, 0, 2, 3, 1]],
              "cov_2_dc": [["bad", "good"], ["bad", "bad", "bad", "bad", "bad"]],
          }
          static_numerical_covariates = {
              "cov_1_sn": [0.0, 3.0],
              "cov_2_sn": [2.0, 1.0],
              "cov_3_sn": [1.0, 2.0],
          }  # Each static covariate has 1 element for each input.
          static_categorical_covariates = {
              "cov_1_sc": ["apple", "orange"],
              "cov_2_sc": [2, 3],
          }
          ```

        Args:
          targets: List of targets (responses) of the in-context regression.
          train_lens: List of lengths of each target vector from the context.
          test_lens: List of lengths of each forecast horizon.
          train_dynamic_numerical_covariates: Dict of covariate names mapping to
            the dynamic numerical covariates of each forecast task on the
            context. Their lengths should match the corresponding lengths in
            `train_lens`.
          train_dynamic_categorical_covariates: Dict of covariate names mapping
            to the dynamic categorical covariates of each forecast task on the
            context. Their lengths should match the corresponding lengths in
            `train_lens`.
          test_dynamic_numerical_covariates: Dict of covariate names mapping to
            the dynamic numerical covariates of each forecast task on the
            horizon. Their lengths should match the corresponding lengths in
            `test_lens`.
          test_dynamic_categorical_covariates: Dict of covariate names mapping
            to the dynamic categorical covariates of each forecast task on the
            horizon. Their lengths should match the corresponding lengths in
            `test_lens`.
          static_numerical_covariates: Dict of covariate names mapping to the
            static numerical covariates of each forecast task.
          static_categorical_covariates: Dict of covariate names mapping to the
            static categorical covariates of each forecast task.
        """
        self.targets = targets
        self.train_lens = train_lens
        self.test_lens = test_lens
        # Missing covariate groups are normalized to empty dicts so later code
        # can iterate over them unconditionally.
        self.train_dynamic_numerical_covariates = train_dynamic_numerical_covariates or {}
        self.train_dynamic_categorical_covariates = (
            train_dynamic_categorical_covariates or {}
        )
        self.test_dynamic_numerical_covariates = test_dynamic_numerical_covariates or {}
        self.test_dynamic_categorical_covariates = test_dynamic_categorical_covariates or {}
        self.static_numerical_covariates = static_numerical_covariates or {}
        self.static_categorical_covariates = static_categorical_covariates or {}

    def _assert_covariates(self, assert_covariate_shapes: bool = False) -> None:
        """Verifies the validity of the covariate inputs.

        Always checks that:
        - train and test dynamic covariates are either both present or both
          absent;
        - train and test dynamic covariates use the same names.

        When `assert_covariate_shapes` is True, it also checks the number of
        tasks in every input and the per-task lengths of targets and dynamic
        covariates.

        Args:
          assert_covariate_shapes: Whether to also run the shape checks.

        Raises:
          ValueError: On the first inconsistency found.
        """
        # Check presence: train/test dynamic covariates must come in pairs.
        for train_covs, test_covs, kind in (
            (
                self.train_dynamic_numerical_covariates,
                self.test_dynamic_numerical_covariates,
                "numerical",
            ),
            (
                self.train_dynamic_categorical_covariates,
                self.test_dynamic_categorical_covariates,
                "categorical",
            ),
        ):
            if bool(train_covs) != bool(test_covs):
                raise ValueError(
                    f"train_dynamic_{kind}_covariates and"
                    f" test_dynamic_{kind}_covariates must be both present or both"
                    " absent."
                )

        # Check keys: both directions, so the error names the offending dict.
        for dict_a, dict_b, dict_a_name, dict_b_name in (
            (
                self.train_dynamic_numerical_covariates,
                self.test_dynamic_numerical_covariates,
                "train_dynamic_numerical_covariates",
                "test_dynamic_numerical_covariates",
            ),
            (
                self.train_dynamic_categorical_covariates,
                self.test_dynamic_categorical_covariates,
                "train_dynamic_categorical_covariates",
                "test_dynamic_categorical_covariates",
            ),
        ):
            if w := set(dict_a.keys()) - set(dict_b.keys()):
                raise ValueError(f"{dict_a_name} has keys not present in {dict_b_name}: {w}")
            if w := set(dict_b.keys()) - set(dict_a.keys()):
                raise ValueError(f"{dict_b_name} has keys not present in {dict_a_name}: {w}")

        if not assert_covariate_shapes:
            return

        # Check shapes.
        if len(self.targets) != len(self.train_lens):
            raise ValueError(
                "targets and train_lens must have the same number of elements."
            )

        if len(self.train_lens) != len(self.test_lens):
            raise ValueError(
                "train_lens and test_lens must have the same number of elements."
            )

        for i, (target, train_len) in enumerate(zip(self.targets, self.train_lens)):
            if len(target) != train_len:
                raise ValueError(
                    f"targets[{i}] has length {len(target)} != expected {train_len}."
                )

        for key, values in self.static_numerical_covariates.items():
            if len(values) != len(self.train_lens):
                raise ValueError(
                    f"static_numerical_covariates has key {key} with number of"
                    f" examples {len(values)} != expected {len(self.train_lens)}."
                )

        for key, values in self.static_categorical_covariates.items():
            if len(values) != len(self.train_lens):
                raise ValueError(
                    f"static_categorical_covariates has key {key} with number of"
                    f" examples {len(values)} != expected {len(self.train_lens)}."
                )

        for lens, dict_cov, dict_cov_name in (
            (
                self.train_lens,
                self.train_dynamic_numerical_covariates,
                "train_dynamic_numerical_covariates",
            ),
            (
                self.train_lens,
                self.train_dynamic_categorical_covariates,
                "train_dynamic_categorical_covariates",
            ),
            (
                self.test_lens,
                self.test_dynamic_numerical_covariates,
                "test_dynamic_numerical_covariates",
            ),
            (
                self.test_lens,
                self.test_dynamic_categorical_covariates,
                "test_dynamic_categorical_covariates",
            ),
        ):
            for key, cov_values in dict_cov.items():
                if len(cov_values) != len(lens):
                    raise ValueError(
                        f"{dict_cov_name} has key {key} with number of examples"
                        f" {len(cov_values)} != expected {len(lens)}."
                    )
                for i, cov_value in enumerate(cov_values):
                    if len(cov_value) != lens[i]:
                        raise ValueError(
                            f"{dict_cov_name} has key {key} with its {i}-th example"
                            f" length {len(cov_value)} != expected {lens[i]}."
                        )

    def create_covariate_matrix(
        self,
        one_hot_encoder_drop: str | None = "first",
        use_intercept: bool = True,
        assert_covariates: bool = False,
        assert_covariate_shapes: bool = False,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Creates target vector and covariate matrices for in-context regression.

        Here we use model fitting language to refer to the context as 'train'
        and the horizon as 'test'. Rows of the matrices are the time steps of
        all tasks, flattened in task order. Columns are laid out as follows:
        - the intercept, if requested;
        - numerical covariates (dynamic ones in sorted-name order, then static
          ones repeated per time step), standardized with context statistics;
        - one-hot encoded dynamic categorical covariates (sorted by name);
        - one-hot encoded static categorical covariates.

        If no covariates are supplied at all, the matrices have zero covariate
        columns, leaving only the intercept when `use_intercept` is True.

        Args:
          one_hot_encoder_drop: Which drop strategy to use for the one hot
            encoder.
          use_intercept: Whether to prepare an intercept (all 1) column in the
            matrices.
          assert_covariates: Whether to assert the validity of the covariate
            inputs.
          assert_covariate_shapes: Whether to assert the shapes of the covariate
            inputs when `assert_covariates` is True.

        Returns:
          A tuple of the target vector, the covariate matrix for the context,
          and the covariate matrix for the horizon.
        """
        if assert_covariates:
            self._assert_covariates(assert_covariate_shapes)

        x_train, x_test = [], []

        # Numerical features.
        num_train, num_test = [], []
        for name in sorted(self.train_dynamic_numerical_covariates):
            num_train.append(
                _unnest(self.train_dynamic_numerical_covariates[name])[:, np.newaxis]
            )
            num_test.append(
                _unnest(self.test_dynamic_numerical_covariates[name])[:, np.newaxis]
            )

        for covs in self.static_numerical_covariates.values():
            num_train.append(_repeat(covs, self.train_lens)[:, np.newaxis])
            num_test.append(_repeat(covs, self.test_lens)[:, np.newaxis])

        if num_train:
            num_train = np.concatenate(num_train, axis=1)
            num_test = np.concatenate(num_test, axis=1)

            # Normalize for robustness. Statistics come from the context only,
            # and the horizon is transformed with those same statistics.
            x_mean = np.mean(num_train, axis=0, keepdims=True)
            x_std = np.where(
                (w := np.std(num_train, axis=0, keepdims=True)) > _TOL, w, 1.0
            )
            x_train.append((num_train - x_mean) / x_std)
            x_test.append((num_test - x_mean) / x_std)

        # Categorical features. Encode one by one; the encoder is refit for
        # each covariate, so every covariate gets its own category set.
        one_hot_encoder = preprocessing.OneHotEncoder(
            drop=one_hot_encoder_drop,
            sparse_output=False,
            handle_unknown="ignore",
        )
        for name in sorted(self.train_dynamic_categorical_covariates.keys()):
            ohe_train = _unnest(self.train_dynamic_categorical_covariates[name])[
                :, np.newaxis
            ]
            ohe_test = _unnest(self.test_dynamic_categorical_covariates[name])[
                :, np.newaxis
            ]
            x_train.append(np.array(one_hot_encoder.fit_transform(ohe_train)))
            x_test.append(np.array(one_hot_encoder.transform(ohe_test)))

        for covs in self.static_categorical_covariates.values():
            # Encode one row per task, then repeat each row over that task's
            # time steps.
            ohe = one_hot_encoder.fit_transform(np.array(covs)[:, np.newaxis])
            x_train.append(_repeat(ohe, self.train_lens))
            x_test.append(_repeat(ohe, self.test_lens))

        if not x_train:
            # No covariates at all: np.concatenate cannot take an empty list,
            # so start from zero-width matrices (e.g. an intercept-only design).
            x_train = [np.zeros((sum(self.train_lens), 0))]
            x_test = [np.zeros((sum(self.test_lens), 0))]

        x_train = np.concatenate(x_train, axis=1)
        x_test = np.concatenate(x_test, axis=1)

        if use_intercept:
            # Prepend a column of ones.
            x_train = np.pad(x_train, ((0, 0), (1, 0)), constant_values=1.0)
            x_test = np.pad(x_test, ((0, 0), (1, 0)), constant_values=1.0)

        return _unnest(self.targets), x_train, x_test

    def fit(self) -> Any:
        """Fits the in-context model; subclasses must override this."""
        raise NotImplementedError("Fit is not implemented.")
409
+
410
+
411
class BatchedInContextXRegLinear(BatchedInContextXRegBase):
    """Linear in-context regression model."""

    def fit(
        self,
        ridge: float = 0.0,
        one_hot_encoder_drop: str | None = "first",
        use_intercept: bool = True,
        force_on_cpu: bool = False,
        max_rows_per_col: int = 0,
        max_rows_per_col_sample_seed: int = 42,
        debug_info: bool = False,
        assert_covariates: bool = False,
        assert_covariate_shapes: bool = False,
    ) -> (
        list[np.ndarray]
        | tuple[list[np.ndarray], list[np.ndarray], jax.Array, jax.Array, jax.Array]
    ):
        """Fits a linear model for in-context regression.

        Coefficients are computed in closed form as
        `pinv(X^T X + ridge * I) @ X^T @ y`. Here `X` is the (optionally
        subsampled) context covariate matrix and `y` is the flattened target
        vector. All arrays are zero-padded to power-of-two shapes first.

        Args:
          ridge: A non-negative value for specifying the ridge regression
            penalty. If 0 is provided, fallback to ordinary least squares. Note
            this penalty is added to the normalized covariate matrix.
          one_hot_encoder_drop: Which drop strategy to use for the one hot
            encoder.
          use_intercept: Whether to prepare an intercept (all 1) column in the
            matrices.
          force_on_cpu: Whether to force execution on cpu for accelerator
            machines.
          max_rows_per_col: How many rows to subsample per column. 0 for no
            subsampling. This is for speeding up model fitting.
          max_rows_per_col_sample_seed: The seed for the subsampling if needed
            by `max_rows_per_col`.
          debug_info: Whether to return debug info.
          assert_covariates: Whether to assert the validity of the covariate
            inputs.
          assert_covariate_shapes: Whether to assert the shapes of the covariate
            inputs when `assert_covariates` is True.

        Returns:
          If `debug_info` is False:
            The linear fits on the horizon, one array per task.
          If `debug_info` is True:
            A tuple of:
            - the linear fits on the horizon,
            - the linear fits on the context, computed from the full
              (non-subsampled) context matrix,
            - the flattened target vector,
            - the covariate matrix for the context, and
            - the covariate matrix for the horizon.
            The last three are the zero-padded `jax.Array`s used in the solve.
            The target vector and context matrix reflect any row subsampling.
        """
        targets_flat, design_full, design_horizon = self.create_covariate_matrix(
            one_hot_encoder_drop=one_hot_encoder_drop,
            use_intercept=use_intercept,
            assert_covariates=assert_covariates,
            assert_covariate_shapes=assert_covariate_shapes,
        )

        # Rows actually used for fitting. `design_full` is kept intact for the
        # optional in-context (debug) predictions.
        design_fit = design_full.copy()
        if max_rows_per_col:
            n_rows, n_cols = design_fit.shape
            row_budget = n_cols * max_rows_per_col
            if n_rows > row_budget:
                picked = jax.random.choice(
                    jax.random.PRNGKey(max_rows_per_col_sample_seed),
                    n_rows,
                    (row_budget,),
                    replace=False,
                )
                design_fit = design_fit[picked]
                targets_flat = targets_flat[picked]

        # Padded power-of-two shapes mean new compilations only happen when a
        # new padded shape appears. Forcing CPU execution can help on
        # accelerator machines: it avoids moving data to accelerator memory and
        # avoids possible precision loss there.
        cpu_device = jax.devices("cpu")[0] if force_on_cpu else None
        with jax.default_device(cpu_device):
            design_full = _to_padded_jax_array(design_full)
            design_fit = _to_padded_jax_array(design_fit)
            targets_flat = _to_padded_jax_array(targets_flat)
            design_horizon = _to_padded_jax_array(design_horizon)
            # Zero-padded rows add nothing to X^T X or X^T y, and zero-padded
            # columns receive zero coefficients.
            gram = design_fit.T @ design_fit + ridge * jnp.eye(design_fit.shape[1])
            gram_pinv = jnp.linalg.pinv(gram, hermitian=True)
            beta_hat = gram_pinv @ design_fit.T @ targets_flat
            y_hat = design_horizon @ beta_hat
            y_hat_context = design_full @ beta_hat if debug_info else None

        # Split the flat predictions back into one (ragged) array per task.
        horizon_pieces, context_pieces = [], []
        train_start = test_start = 0
        for n_train, n_test in zip(self.train_lens, self.test_lens):
            horizon_pieces.append(np.array(y_hat[test_start : test_start + n_test]))
            if debug_info:
                context_pieces.append(
                    np.array(y_hat_context[train_start : train_start + n_train])
                )
            train_start += n_train
            test_start += n_test

        if not debug_info:
            return horizon_pieces
        return horizon_pieces, context_pieces, targets_flat, design_fit, design_horizon
@@ -0,0 +1,152 @@
1
+ Metadata-Version: 2.4
2
+ Name: tsagentkit-timesfm
3
+ Version: 1.0.0
4
+ Summary: TimesFM fork distribution for tsagentkit.
5
+ Author-email: Rajat Sen <senrajat@google.com>, Yichen Zhou <yichenzhou@google.com>, Abhimanyu Das <abhidas@google.com>, Petros Mol <pmol@google.com>, Michael Chertushkin <chertushkinmichael@gmail.com>
6
+ Maintainer-email: cuizhengliang <czlbou2012@gmail.com>
7
+ License: Apache-2.0
8
+ Project-URL: Homepage, https://github.com/LeonEthan/timesfm
9
+ Project-URL: Repository, https://github.com/LeonEthan/timesfm
10
+ Project-URL: Issues, https://github.com/LeonEthan/timesfm/issues
11
+ Project-URL: Upstream, https://github.com/google-research/timesfm
12
+ Requires-Python: >=3.10
13
+ Description-Content-Type: text/markdown
14
+ License-File: LICENSE
15
+ Requires-Dist: numpy>=1.26.4
16
+ Requires-Dist: huggingface_hub[cli]>=0.23.0
17
+ Requires-Dist: safetensors>=0.5.3
18
+ Provides-Extra: torch
19
+ Requires-Dist: torch>=2.0.0; extra == "torch"
20
+ Provides-Extra: flax
21
+ Requires-Dist: flax; extra == "flax"
22
+ Requires-Dist: optax; extra == "flax"
23
+ Requires-Dist: einshape; extra == "flax"
24
+ Requires-Dist: orbax-checkpoint; extra == "flax"
25
+ Requires-Dist: jaxtyping; extra == "flax"
26
+ Requires-Dist: jax[cuda]; extra == "flax"
27
+ Provides-Extra: xreg
28
+ Requires-Dist: jax[cuda]; extra == "xreg"
29
+ Requires-Dist: scikit-learn; extra == "xreg"
30
+ Dynamic: license-file
31
+
32
+ # TimesFM
33
+
34
+ TimesFM (Time Series Foundation Model) is a pretrained time-series foundation
35
+ model developed by Google Research for time-series forecasting.
36
+
37
+ * Paper:
38
+ [A decoder-only foundation model for time-series forecasting](https://arxiv.org/abs/2310.10688),
39
+ ICML 2024.
40
+ * All checkpoints:
41
+ [TimesFM Hugging Face Collection](https://huggingface.co/collections/google/timesfm-release-66e4be5fdb56e960c1e482a6).
42
+ * [Google Research blog](https://research.google/blog/a-decoder-only-foundation-model-for-time-series-forecasting/).
43
+ * [TimesFM in BigQuery](https://cloud.google.com/bigquery/docs/timesfm-model):
44
+ an official Google product.
45
+
46
+ This open version is not an officially supported Google product.
47
+
48
+ ## Fork Notice (tsagentkit)
49
+
50
+ This repository is a fork of `google-research/timesfm`, published to PyPI as
51
+ `tsagentkit-timesfm` to support the `tsagentkit` project without a GitHub
52
+ dependency. The Python import name remains `timesfm`.
53
+
54
+ **Latest Model Version:** TimesFM 2.5
55
+
56
+ **Archived Model Versions:**
57
+
58
+ - 1.0 and 2.0: relevant code is archived in the subdirectory `v1`. You can `pip
59
+ install timesfm==1.3.0` to install an older version of the upstream `timesfm` package to load
60
+ them.
61
+
62
+ ## Update - Oct. 29, 2025
63
+
64
+ Added back the covariate support through XReg for TimesFM 2.5.
65
+
66
+
67
+ ## Update - Sept. 15, 2025
68
+
69
+ TimesFM 2.5 is out!
70
+
71
+ Compared to TimesFM 2.0, the new 2.5 model:
72
+
73
+ - uses 200M parameters, down from 500M.
74
+ - supports up to 16k context length, up from 2048.
75
+ - supports continuous quantile forecast up to 1k horizon via an optional 30M
76
+ quantile head.
77
+ - gets rid of the `frequency` indicator.
78
+ - has a couple of new forecasting flags.
79
+
80
+ Along with the model upgrade we have also upgraded the inference API. This repo
81
+ will be under construction over the next few weeks to
82
+
83
+ 1. add support for an upcoming Flax version of the model (faster inference).
84
+ 2. add back covariate support (done; see the Oct. 29, 2025 update above).
85
+ 3. populate more docstrings, docs, and notebooks.
86
+
87
+ ### Install
88
+
89
+ 1. Install from PyPI:
90
+ ```shell
91
+ pip install "tsagentkit-timesfm[torch]"
92
+ # Or with flax
93
+ pip install "tsagentkit-timesfm[flax]"
94
+ # Or XReg support
95
+ pip install "tsagentkit-timesfm[xreg]"
96
+ ```
97
+
98
+ 2. [Optional] Install your preferred `torch` / `jax` backend based on your OS and accelerators
99
+ (CPU, GPU, TPU, or Apple Silicon):
100
+
101
+ - [Install PyTorch](https://pytorch.org/get-started/locally/).
102
+ - [Install Jax](https://docs.jax.dev/en/latest/installation.html#installation)
103
+ for Flax.
104
+
105
+ ### From Source (optional)
106
+
107
+ ```shell
108
+ git clone https://github.com/LeonEthan/timesfm.git
109
+ cd timesfm
110
+
111
+ uv venv
112
+ source .venv/bin/activate
113
+
114
+ uv pip install -e .[torch]
115
+ # Or with flax
116
+ uv pip install -e .[flax]
117
+ # Or XReg support
118
+ uv pip install -e .[xreg]
119
+ ```
120
+
121
+ ### Code Example
122
+
123
+ ```python
124
+ import torch
125
+ import numpy as np
126
+ import timesfm
127
+
128
+ torch.set_float32_matmul_precision("high")
129
+
130
+ model = timesfm.TimesFM_2p5_200M_torch.from_pretrained("google/timesfm-2.5-200m-pytorch")
131
+
132
+ model.compile(
133
+ timesfm.ForecastConfig(
134
+ max_context=1024,
135
+ max_horizon=256,
136
+ normalize_inputs=True,
137
+ use_continuous_quantile_head=True,
138
+ force_flip_invariance=True,
139
+ infer_is_positive=True,
140
+ fix_quantile_crossing=True,
141
+ )
142
+ )
143
+ point_forecast, quantile_forecast = model.forecast(
144
+ horizon=12,
145
+ inputs=[
146
+ np.linspace(0, 1, 100),
147
+ np.sin(np.linspace(0, 20, 67)),
148
+ ], # Two dummy inputs
149
+ )
150
+ point_forecast.shape # (2, 12)
151
+ quantile_forecast.shape # (2, 12, 10): mean, then 10th to 90th quantiles.
152
+ ```
@@ -0,0 +1,21 @@
1
+ timesfm/__init__.py,sha256=B2pgyjc3hgIwQQqs5BuJh9tZy3cv1ox39bhyp_JIJBs,919
2
+ timesfm/configs.py,sha256=hhAsa7LEZgEct-gZqP6xl7HBKWR5gHNgHguE7JPjJxk,3617
3
+ timesfm/flax/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
4
+ timesfm/flax/dense.py,sha256=vkcDueNfcXgR1HL61vCOWzwHCKH9TreiQANJzMK2_Pg,3522
5
+ timesfm/flax/normalization.py,sha256=KGeCNAftZVXtE_5_pfBecBNjJzDMp6JOCl2WoPdBvEM,2168
6
+ timesfm/flax/transformer.py,sha256=K9qO31l5bwUf9zl-myl8HDltAkC-O4CoSWg8up7vIm4,11506
7
+ timesfm/flax/util.py,sha256=2SFLO6xHT84whuNumVIwZTXXPag-XKrKK2hbj74grmM,3027
8
+ timesfm/timesfm_2p5/timesfm_2p5_base.py,sha256=hR-BIcNki2DVJ5k4B3FczT4C0bowpcJu381XR9FLwfw,15174
9
+ timesfm/timesfm_2p5/timesfm_2p5_flax.py,sha256=ts42ros5VP-IdldYqWtWqG9ckNm30Z0dEaXuJ_r9UAo,19375
10
+ timesfm/timesfm_2p5/timesfm_2p5_torch.py,sha256=MPwE6LL_6OOjKJ-qSINO4zGQB_Vho7LfUYjDeoAvboI,16943
11
+ timesfm/torch/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
12
+ timesfm/torch/dense.py,sha256=KIlJhThOqj3aexuILVu4Qm2LJ0stmgHLFMRlHAWE7B0,3109
13
+ timesfm/torch/normalization.py,sha256=w3tWr9GYrr9e9x7gTvJEpjMY2i5KBJfsmzGQiPbMUVQ,1204
14
+ timesfm/torch/transformer.py,sha256=dEoOHkHBJiFQmY-8Yhu35LsvucebD-RucN6nRFEZYrs,11720
15
+ timesfm/torch/util.py,sha256=xwCoJFnhFaHGyXlwWObjJyqnnocFoR1I2DSJtndm6Rc,2654
16
+ timesfm/utils/xreg_lib.py,sha256=RgmzwRqhdE2zZcih9wf3Kkm2iUi8XWseDr2IZ3Kj9NI,20721
17
+ tsagentkit_timesfm-1.0.0.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
18
+ tsagentkit_timesfm-1.0.0.dist-info/METADATA,sha256=_hzD2eY53jd9XY2CmlVlhqumxRLmp-Phu_C4nhNo1ks,4823
19
+ tsagentkit_timesfm-1.0.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
20
+ tsagentkit_timesfm-1.0.0.dist-info/top_level.txt,sha256=5CtB5agsTzKTMKI3M87NsaJk7gh51fFKxmwRkdfeLdg,8
21
+ tsagentkit_timesfm-1.0.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.10.2)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+