tsagentkit-timesfm 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- timesfm/__init__.py +29 -0
- timesfm/configs.py +105 -0
- timesfm/flax/__init__.py +13 -0
- timesfm/flax/dense.py +110 -0
- timesfm/flax/normalization.py +71 -0
- timesfm/flax/transformer.py +356 -0
- timesfm/flax/util.py +107 -0
- timesfm/timesfm_2p5/timesfm_2p5_base.py +422 -0
- timesfm/timesfm_2p5/timesfm_2p5_flax.py +602 -0
- timesfm/timesfm_2p5/timesfm_2p5_torch.py +472 -0
- timesfm/torch/__init__.py +13 -0
- timesfm/torch/dense.py +94 -0
- timesfm/torch/normalization.py +39 -0
- timesfm/torch/transformer.py +370 -0
- timesfm/torch/util.py +94 -0
- timesfm/utils/xreg_lib.py +520 -0
- tsagentkit_timesfm-1.0.0.dist-info/METADATA +152 -0
- tsagentkit_timesfm-1.0.0.dist-info/RECORD +21 -0
- tsagentkit_timesfm-1.0.0.dist-info/WHEEL +5 -0
- tsagentkit_timesfm-1.0.0.dist-info/licenses/LICENSE +202 -0
- tsagentkit_timesfm-1.0.0.dist-info/top_level.txt +1 -0
timesfm/utils/xreg_lib.py:

@@ -0,0 +1,520 @@

````python
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper functions for in-context covariates and regression."""

import itertools
import math
from typing import Any, Iterable, Literal, Mapping, Sequence

try:
  import jax
  import jax.numpy as jnp
  import numpy as np
  from sklearn import preprocessing
except ImportError:
  raise ImportError(
      "Failed to load the XReg module. Did you forget to install `timesfm[xreg]`?"
  )

Category = int | str

_TOL = 1e-6
XRegMode = Literal["timesfm + xreg", "xreg + timesfm"]


def _unnest(nested: Sequence[Sequence[Any]]) -> np.ndarray:
  return np.array(list(itertools.chain.from_iterable(nested)))


def _repeat(elements: Iterable[Any], counts: Iterable[int]) -> np.ndarray:
  return np.array(
      list(itertools.chain.from_iterable(map(itertools.repeat, elements, counts)))
  )


def _to_padded_jax_array(x: np.ndarray) -> jax.Array:
  if x.ndim == 1:
    (i,) = x.shape
    di = 2 ** math.ceil(math.log2(i)) - i
    return jnp.pad(x, ((0, di),), mode="constant", constant_values=0.0)
  elif x.ndim == 2:
    i, j = x.shape
    di = 2 ** math.ceil(math.log2(i)) - i
    dj = 2 ** math.ceil(math.log2(j)) - j
    return jnp.pad(x, ((0, di), (0, dj)), mode="constant", constant_values=0.0)
  else:
    raise ValueError(f"Unsupported array shape: {x.shape}")

# Per time series normalization: forward.
def normalize(batch):
  stats = [(np.mean(x), np.where((w := np.std(x)) > _TOL, w, 1.0)) for x in batch]
  new_batch = [(x - stat[0]) / stat[1] for x, stat in zip(batch, stats)]
  return new_batch, stats


# Per time series normalization: inverse.
def renormalize(batch, stats):
  return [x * stat[1] + stat[0] for x, stat in zip(batch, stats)]
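
# Illustrative round trip (not part of the original file): `renormalize`
# undoes `normalize` series by series.
#   normed, stats = normalize([np.array([1.0, 2.0, 3.0])])
#   renormalize(normed, stats)  # recovers [1.0, 2.0, 3.0]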

class BatchedInContextXRegBase:
  """Helper class for in-context regression covariate formatting.

  Attributes:
    targets: List of targets (responses) of the in-context regression.
    train_lens: List of lengths of each target vector from the context.
    test_lens: List of lengths of each forecast horizon.
    train_dynamic_numerical_covariates: Dict of covariate names mapping to the
      dynamic numerical covariates of each forecast task on the context. Their
      lengths should match the corresponding lengths in `train_lens`.
    train_dynamic_categorical_covariates: Dict of covariate names mapping to
      the dynamic categorical covariates of each forecast task on the context.
      Their lengths should match the corresponding lengths in `train_lens`.
    test_dynamic_numerical_covariates: Dict of covariate names mapping to the
      dynamic numerical covariates of each forecast task on the horizon. Their
      lengths should match the corresponding lengths in `test_lens`.
    test_dynamic_categorical_covariates: Dict of covariate names mapping to
      the dynamic categorical covariates of each forecast task on the horizon.
      Their lengths should match the corresponding lengths in `test_lens`.
    static_numerical_covariates: Dict of covariate names mapping to the static
      numerical covariates of each forecast task.
    static_categorical_covariates: Dict of covariate names mapping to the
      static categorical covariates of each forecast task.
  """

  def __init__(
      self,
      targets: Sequence[Sequence[float]],
      train_lens: Sequence[int],
      test_lens: Sequence[int],
      train_dynamic_numerical_covariates: (
          Mapping[str, Sequence[Sequence[float]]] | None
      ) = None,
      train_dynamic_categorical_covariates: (
          Mapping[str, Sequence[Sequence[Category]]] | None
      ) = None,
      test_dynamic_numerical_covariates: (
          Mapping[str, Sequence[Sequence[float]]] | None
      ) = None,
      test_dynamic_categorical_covariates: (
          Mapping[str, Sequence[Sequence[Category]]] | None
      ) = None,
      static_numerical_covariates: Mapping[str, Sequence[float]] | None = None,
      static_categorical_covariates: Mapping[str, Sequence[Category]] | None = None,
  ) -> None:
    """Initializes with the exogenous covariate inputs.

    Here we use model fitting language to refer to the context as 'train' and
    the horizon as 'test'. We assume batched inputs. To properly format the
    request:

    - `train_lens` represents the contexts in the batch. Targets and all train
      dynamic covariates should have the same lengths as the corresponding
      elements in `train_lens`. Notice each `train_len` can be different from
      the exact length of the corresponding context, depending on how much of
      the context is used for fitting the in-context model.
    - `test_lens` represents the horizon lengths in the batch. All test
      dynamic covariates should have the same lengths as the corresponding
      elements in `test_lens`.
    - Static covariates should be one for each input.
    - Train and test dynamic covariates should have the same covariate names.

    Pass an empty dict {} for a covariate type if it is not present.

    Example:
      Here is a set of valid inputs whose schema can be used for reference.
      ```
      targets = [
          [0.0, 0.1, 0.2],
          [0.0, 0.1, 0.2, 0.3],
      ]  # Two inputs in this batch.
      train_lens = [3, 4]
      test_lens = [2, 5]  # Forecast horizons 2 and 5 respectively.
      train_dynamic_numerical_covariates = {
          "cov_1_dn": [[0.0, 0.5, 1.0], [0.0, 0.5, 1.0, 1.5]],
          "cov_2_dn": [[0.0, 1.5, 1.0], [0.0, 1.5, 1.0, 2.5]],
      }  # Each train dynamic covariate has 3 and 4 elements respectively.
      test_dynamic_numerical_covariates = {
          "cov_1_dn": [[0.1, 0.6], [0.1, 0.6, 1.1, 1.6, 2.4]],
          "cov_2_dn": [[0.1, 1.1], [0.1, 1.6, 1.1, 2.6, 10.0]],
      }  # Each test dynamic covariate has 2 and 5 elements respectively.
      train_dynamic_categorical_covariates = {
          "cov_1_dc": [[0, 1, 0], [0, 1, 2, 3]],
          "cov_2_dc": [["good", "bad", "good"], ["good", "good", "bad", "bad"]],
      }
      test_dynamic_categorical_covariates = {
          "cov_1_dc": [[1, 0], [1, 0, 2, 3, 1]],
          "cov_2_dc": [["bad", "good"], ["bad", "bad", "bad", "bad", "bad"]],
      }
      static_numerical_covariates = {
          "cov_1_sn": [0.0, 3.0],
          "cov_2_sn": [2.0, 1.0],
          "cov_3_sn": [1.0, 2.0],
      }  # Each static covariate has 1 element for each input.
      static_categorical_covariates = {
          "cov_1_sc": ["apple", "orange"],
          "cov_2_sc": [2, 3],
      }
      ```

    Args:
      targets: List of targets (responses) of the in-context regression.
      train_lens: List of lengths of each target vector from the context.
      test_lens: List of lengths of each forecast horizon.
      train_dynamic_numerical_covariates: Dict of covariate names mapping to
        the dynamic numerical covariates of each forecast task on the context.
        Their lengths should match the corresponding lengths in `train_lens`.
      train_dynamic_categorical_covariates: Dict of covariate names mapping to
        the dynamic categorical covariates of each forecast task on the
        context. Their lengths should match the corresponding lengths in
        `train_lens`.
      test_dynamic_numerical_covariates: Dict of covariate names mapping to
        the dynamic numerical covariates of each forecast task on the horizon.
        Their lengths should match the corresponding lengths in `test_lens`.
      test_dynamic_categorical_covariates: Dict of covariate names mapping to
        the dynamic categorical covariates of each forecast task on the
        horizon. Their lengths should match the corresponding lengths in
        `test_lens`.
      static_numerical_covariates: Dict of covariate names mapping to the
        static numerical covariates of each forecast task.
      static_categorical_covariates: Dict of covariate names mapping to the
        static categorical covariates of each forecast task.
    """
    self.targets = targets
    self.train_lens = train_lens
    self.test_lens = test_lens
    self.train_dynamic_numerical_covariates = train_dynamic_numerical_covariates or {}
    self.train_dynamic_categorical_covariates = (
        train_dynamic_categorical_covariates or {}
    )
    self.test_dynamic_numerical_covariates = test_dynamic_numerical_covariates or {}
    self.test_dynamic_categorical_covariates = test_dynamic_categorical_covariates or {}
    self.static_numerical_covariates = static_numerical_covariates or {}
    self.static_categorical_covariates = static_categorical_covariates or {}

  def _assert_covariates(self, assert_covariate_shapes: bool = False) -> None:
    """Verifies the validity of the covariate inputs."""

    # Check presence.
    if (
        self.train_dynamic_numerical_covariates
        and not self.test_dynamic_numerical_covariates
    ) or (
        not self.train_dynamic_numerical_covariates
        and self.test_dynamic_numerical_covariates
    ):
      raise ValueError(
          "train_dynamic_numerical_covariates and"
          " test_dynamic_numerical_covariates must be both present or both"
          " absent."
      )

    if (
        self.train_dynamic_categorical_covariates
        and not self.test_dynamic_categorical_covariates
    ) or (
        not self.train_dynamic_categorical_covariates
        and self.test_dynamic_categorical_covariates
    ):
      raise ValueError(
          "train_dynamic_categorical_covariates and"
          " test_dynamic_categorical_covariates must be both present or both"
          " absent."
      )

    # Check keys.
    for dict_a, dict_b, dict_a_name, dict_b_name in (
        (
            self.train_dynamic_numerical_covariates,
            self.test_dynamic_numerical_covariates,
            "train_dynamic_numerical_covariates",
            "test_dynamic_numerical_covariates",
        ),
        (
            self.train_dynamic_categorical_covariates,
            self.test_dynamic_categorical_covariates,
            "train_dynamic_categorical_covariates",
            "test_dynamic_categorical_covariates",
        ),
    ):
      if w := set(dict_a.keys()) - set(dict_b.keys()):
        raise ValueError(f"{dict_a_name} has keys not present in {dict_b_name}: {w}")
      if w := set(dict_b.keys()) - set(dict_a.keys()):
        raise ValueError(f"{dict_b_name} has keys not present in {dict_a_name}: {w}")

    # Check shapes.
    if assert_covariate_shapes:
      if len(self.targets) != len(self.train_lens):
        raise ValueError(
            "targets and train_lens must have the same number of elements."
        )

      if len(self.train_lens) != len(self.test_lens):
        raise ValueError(
            "train_lens and test_lens must have the same number of elements."
        )

      for i, (target, train_len) in enumerate(zip(self.targets, self.train_lens)):
        if len(target) != train_len:
          raise ValueError(
              f"targets[{i}] has length {len(target)} != expected {train_len}."
          )

      for key, values in self.static_numerical_covariates.items():
        if len(values) != len(self.train_lens):
          raise ValueError(
              f"static_numerical_covariates has key {key} with number of"
              f" examples {len(values)} != expected {len(self.train_lens)}."
          )

      for key, values in self.static_categorical_covariates.items():
        if len(values) != len(self.train_lens):
          raise ValueError(
              f"static_categorical_covariates has key {key} with number of"
              f" examples {len(values)} != expected {len(self.train_lens)}."
          )

      for lens, dict_cov, dict_cov_name in (
          (
              self.train_lens,
              self.train_dynamic_numerical_covariates,
              "train_dynamic_numerical_covariates",
          ),
          (
              self.train_lens,
              self.train_dynamic_categorical_covariates,
              "train_dynamic_categorical_covariates",
          ),
          (
              self.test_lens,
              self.test_dynamic_numerical_covariates,
              "test_dynamic_numerical_covariates",
          ),
          (
              self.test_lens,
              self.test_dynamic_categorical_covariates,
              "test_dynamic_categorical_covariates",
          ),
      ):
        for key, cov_values in dict_cov.items():
          if len(cov_values) != len(lens):
            raise ValueError(
                f"{dict_cov_name} has key {key} with number of examples"
                f" {len(cov_values)} != expected {len(lens)}."
            )
          for i, cov_value in enumerate(cov_values):
            if len(cov_value) != lens[i]:
              raise ValueError(
                  f"{dict_cov_name} has key {key} with its {i}-th example"
                  f" length {len(cov_value)} != expected {lens[i]}."
              )

  def create_covariate_matrix(
      self,
      one_hot_encoder_drop: str | None = "first",
      use_intercept: bool = True,
      assert_covariates: bool = False,
      assert_covariate_shapes: bool = False,
  ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Creates target vector and covariate matrices for in-context regression.

    Here we use model fitting language to refer to the context as 'train' and
    the horizon as 'test'.

    Args:
      one_hot_encoder_drop: Which drop strategy to use for the one-hot
        encoder.
      use_intercept: Whether to prepend an intercept (all-1s) column to the
        matrices.
      assert_covariates: Whether to assert the validity of the covariate
        inputs.
      assert_covariate_shapes: Whether to assert the shapes of the covariate
        inputs when `assert_covariates` is True.

    Returns:
      A tuple of the target vector, the covariate matrix for the context, and
      the covariate matrix for the horizon.
    """
    if assert_covariates:
      self._assert_covariates(assert_covariate_shapes)

    x_train, x_test = [], []

    # Numerical features.
    for name in sorted(self.train_dynamic_numerical_covariates):
      x_train.append(
          _unnest(self.train_dynamic_numerical_covariates[name])[:, np.newaxis]
      )
      x_test.append(
          _unnest(self.test_dynamic_numerical_covariates[name])[:, np.newaxis]
      )

    for covs in self.static_numerical_covariates.values():
      x_train.append(_repeat(covs, self.train_lens)[:, np.newaxis])
      x_test.append(_repeat(covs, self.test_lens)[:, np.newaxis])

    if x_train:
      x_train = np.concatenate(x_train, axis=1)
      x_test = np.concatenate(x_test, axis=1)

      # Normalize for robustness.
      x_mean = np.mean(x_train, axis=0, keepdims=True)
      x_std = np.where((w := np.std(x_train, axis=0, keepdims=True)) > _TOL, w, 1.0)
      x_train = [(x_train - x_mean) / x_std]
      x_test = [(x_test - x_mean) / x_std]

    # Categorical features. Encode one by one.
    one_hot_encoder = preprocessing.OneHotEncoder(
        drop=one_hot_encoder_drop,
        sparse_output=False,
        handle_unknown="ignore",
    )
    for name in sorted(self.train_dynamic_categorical_covariates.keys()):
      ohe_train = _unnest(self.train_dynamic_categorical_covariates[name])[
          :, np.newaxis
      ]
      ohe_test = _unnest(self.test_dynamic_categorical_covariates[name])[:, np.newaxis]
      x_train.append(np.array(one_hot_encoder.fit_transform(ohe_train)))
      x_test.append(np.array(one_hot_encoder.transform(ohe_test)))

    for covs in self.static_categorical_covariates.values():
      ohe = one_hot_encoder.fit_transform(np.array(covs)[:, np.newaxis])
      x_train.append(_repeat(ohe, self.train_lens))
      x_test.append(_repeat(ohe, self.test_lens))

    x_train = np.concatenate(x_train, axis=1)
    x_test = np.concatenate(x_test, axis=1)

    if use_intercept:
      x_train = np.pad(x_train, ((0, 0), (1, 0)), constant_values=1.0)
      x_test = np.pad(x_test, ((0, 0), (1, 0)), constant_values=1.0)

    return _unnest(self.targets), x_train, x_test
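
  # Illustrative shapes (not part of the original file), using the docstring
  # example above with drop="first" and use_intercept=True: the flattened
  # targets have 3 + 4 = 7 entries, and both x_train and x_test come out as
  # (7, 12): 5 normalized numeric columns (2 dynamic + 3 static), 6 one-hot
  # columns (3 + 1 dynamic, 1 + 1 static), plus the leading intercept column.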

  def fit(self) -> Any:
    raise NotImplementedError("Fit is not implemented.")

class BatchedInContextXRegLinear(BatchedInContextXRegBase):
  """Linear in-context regression model."""

  def fit(
      self,
      ridge: float = 0.0,
      one_hot_encoder_drop: str | None = "first",
      use_intercept: bool = True,
      force_on_cpu: bool = False,
      max_rows_per_col: int = 0,
      max_rows_per_col_sample_seed: int = 42,
      debug_info: bool = False,
      assert_covariates: bool = False,
      assert_covariate_shapes: bool = False,
  ) -> (
      list[np.ndarray]
      | tuple[list[np.ndarray], list[np.ndarray], jax.Array, jax.Array, jax.Array]
  ):
    """Fits a linear model for in-context regression.

    Args:
      ridge: A non-negative value specifying the ridge regression penalty. If
        0 is provided, falls back to ordinary least squares. Note this penalty
        is added to the normalized covariate matrix.
      one_hot_encoder_drop: Which drop strategy to use for the one-hot
        encoder.
      use_intercept: Whether to prepend an intercept (all-1s) column to the
        matrices.
      force_on_cpu: Whether to force execution on CPU for accelerator
        machines.
      max_rows_per_col: How many rows to subsample per column. 0 for no
        subsampling. This is for speeding up model fitting.
      max_rows_per_col_sample_seed: The seed for the subsampling if needed by
        `max_rows_per_col`.
      debug_info: Whether to return debug info.
      assert_covariates: Whether to assert the validity of the covariate
        inputs.
      assert_covariate_shapes: Whether to assert the shapes of the covariate
        inputs when `assert_covariates` is True.

    Returns:
      If `debug_info` is False:
        The linear fits on the horizon.
      If `debug_info` is True:
        A tuple of:
        - the linear fits on the horizon,
        - the linear fits on the context,
        - the flattened target vector,
        - the covariate matrix for the context, and
        - the covariate matrix for the horizon.
    """
    flat_targets, x_train_raw, x_test = self.create_covariate_matrix(
        one_hot_encoder_drop=one_hot_encoder_drop,
        use_intercept=use_intercept,
        assert_covariates=assert_covariates,
        assert_covariate_shapes=assert_covariate_shapes,
    )

    x_train = x_train_raw.copy()
    if max_rows_per_col:
      nrows, ncols = x_train.shape
      if nrows > (w := ncols * max_rows_per_col):
        subsample = jax.random.choice(
            jax.random.PRNGKey(max_rows_per_col_sample_seed),
            nrows,
            (w,),
            replace=False,
        )
        x_train = x_train[subsample]
        flat_targets = flat_targets[subsample]

    device = jax.devices("cpu")[0] if force_on_cpu else None
    # Runs jitted versions of the solvers, which are quicker at the cost of
    # compilation on the first call. Re-jitting happens whenever new (padded)
    # shapes are encountered.
    # Occasionally it helps both speed and accuracy to force single-threaded
    # execution on CPU for accelerator machines:
    # 1. Avoids moving data to accelerator memory.
    # 2. Avoids precision loss, if any.
    with jax.default_device(device):
      x_train_raw = _to_padded_jax_array(x_train_raw)
      x_train = _to_padded_jax_array(x_train)
      flat_targets = _to_padded_jax_array(flat_targets)
      x_test = _to_padded_jax_array(x_test)
      beta_hat = (
          jnp.linalg.pinv(
              x_train.T @ x_train + ridge * jnp.eye(x_train.shape[1]),
              hermitian=True,
          )
          @ x_train.T
          @ flat_targets
      )
      y_hat = x_test @ beta_hat
      y_hat_context = x_train_raw @ beta_hat if debug_info else None

    outputs = []
    outputs_context = []

    # Reconstruct the ragged 2-dim batched forecasts from the flattened
    # linear fits.
    train_index, test_index = 0, 0
    for train_index_delta, test_index_delta in zip(self.train_lens, self.test_lens):
      outputs.append(np.array(y_hat[test_index : (test_index + test_index_delta)]))
      if debug_info:
        outputs_context.append(
            np.array(y_hat_context[train_index : (train_index + train_index_delta)])
        )
      train_index += train_index_delta
      test_index += test_index_delta

    if debug_info:
      return outputs, outputs_context, flat_targets, x_train, x_test
    else:
      return outputs
````
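For orientation, here is a minimal end-to-end sketch of the linear helper above; it is not part of the package. The import path is assumed from the wheel layout (`timesfm/utils/xreg_lib.py`), the `xreg` extras are assumed installed, and the inputs follow the schema documented in `BatchedInContextXRegBase`:

```python
import numpy as np
from timesfm.utils import xreg_lib  # import path assumed from the wheel layout

model = xreg_lib.BatchedInContextXRegLinear(
    targets=[[0.0, 0.1, 0.2], [0.0, 0.1, 0.2, 0.3]],
    train_lens=[3, 4],
    test_lens=[2, 5],
    train_dynamic_numerical_covariates={
        "cov_dn": [[0.0, 0.5, 1.0], [0.0, 0.5, 1.0, 1.5]],
    },
    test_dynamic_numerical_covariates={
        "cov_dn": [[1.5, 2.0], [2.0, 2.5, 3.0, 3.5, 4.0]],
    },
    static_categorical_covariates={"group": ["a", "b"]},
)
# One fitted array per series, with lengths matching `test_lens`.
fits = model.fit(assert_covariates=True, assert_covariate_shapes=True)
print(len(fits), fits[0].shape, fits[1].shape)  # 2 (2,) (5,)
```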
tsagentkit_timesfm-1.0.0.dist-info/METADATA:

@@ -0,0 +1,152 @@

Metadata-Version: 2.4
Name: tsagentkit-timesfm
Version: 1.0.0
Summary: TimesFM fork distribution for tsagentkit.
Author-email: Rajat Sen <senrajat@google.com>, Yichen Zhou <yichenzhou@google.com>, Abhimanyu Das <abhidas@google.com>, Petros Mol <pmol@google.com>, Michael Chertushkin <chertushkinmichael@gmail.com>
Maintainer-email: cuizhengliang <czlbou2012@gmail.com>
License: Apache-2.0
Project-URL: Homepage, https://github.com/LeonEthan/timesfm
Project-URL: Repository, https://github.com/LeonEthan/timesfm
Project-URL: Issues, https://github.com/LeonEthan/timesfm/issues
Project-URL: Upstream, https://github.com/google-research/timesfm
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: numpy>=1.26.4
Requires-Dist: huggingface_hub[cli]>=0.23.0
Requires-Dist: safetensors>=0.5.3
Provides-Extra: torch
Requires-Dist: torch>=2.0.0; extra == "torch"
Provides-Extra: flax
Requires-Dist: flax; extra == "flax"
Requires-Dist: optax; extra == "flax"
Requires-Dist: einshape; extra == "flax"
Requires-Dist: orbax-checkpoint; extra == "flax"
Requires-Dist: jaxtyping; extra == "flax"
Requires-Dist: jax[cuda]; extra == "flax"
Provides-Extra: xreg
Requires-Dist: jax[cuda]; extra == "xreg"
Requires-Dist: scikit-learn; extra == "xreg"
Dynamic: license-file

# TimesFM

TimesFM (Time Series Foundation Model) is a pretrained time-series foundation
model developed by Google Research for time-series forecasting.

* Paper:
  [A decoder-only foundation model for time-series forecasting](https://arxiv.org/abs/2310.10688),
  ICML 2024.
* All checkpoints:
  [TimesFM Hugging Face Collection](https://huggingface.co/collections/google/timesfm-release-66e4be5fdb56e960c1e482a6).
* [Google Research blog](https://research.google/blog/a-decoder-only-foundation-model-for-time-series-forecasting/).
* [TimesFM in BigQuery](https://cloud.google.com/bigquery/docs/timesfm-model):
  an official Google product.

This open version is not an officially supported Google product.

## Fork Notice (tsagentkit)

This repository is a fork of `google-research/timesfm`, published to PyPI as
`tsagentkit-timesfm` to support the `tsagentkit` project without a GitHub
dependency. The Python import name remains `timesfm`.

**Latest Model Version:** TimesFM 2.5

**Archived Model Versions:**

- 1.0 and 2.0: relevant code is archived in the subdirectory `v1`. You can
  `pip install timesfm==1.3.0` to install an older version of this package to
  load them.

## Update - Oct. 29, 2025

Added back covariate support through XReg for TimesFM 2.5.

## Update - Sept. 15, 2025

TimesFM 2.5 is out!

Compared to TimesFM 2.0, this new 2.5 model:

- uses 200M parameters, down from 500M.
- supports up to 16k context length, up from 2048.
- supports continuous quantile forecasts up to a 1k horizon via an optional
  30M-parameter quantile head.
- gets rid of the `frequency` indicator.
- has a couple of new forecasting flags.

Along with the model upgrade we have also upgraded the inference API. This
repo will be under construction over the next few weeks to

1. add support for an upcoming Flax version of the model (faster inference).
2. add back covariate support.
3. populate more docstrings, docs and notebooks.
### Install

1. Install from PyPI:

```shell
pip install "tsagentkit-timesfm[torch]"
# Or with flax
pip install "tsagentkit-timesfm[flax]"
# Or with XReg support
pip install "tsagentkit-timesfm[xreg]"
```

2. [Optional] Install your preferred `torch` / `jax` backend based on your OS
   and accelerators (CPU, GPU, TPU or Apple Silicon):

   - [Install PyTorch](https://pytorch.org/get-started/locally/).
   - [Install Jax](https://docs.jax.dev/en/latest/installation.html#installation)
     for Flax.

### From Source (optional)

```shell
git clone https://github.com/LeonEthan/timesfm.git
cd timesfm

uv venv
source .venv/bin/activate

uv pip install -e .[torch]
# Or with flax
uv pip install -e .[flax]
# Or with XReg support
uv pip install -e .[xreg]
```

### Code Example

```python
import torch
import numpy as np
import timesfm

torch.set_float32_matmul_precision("high")

model = timesfm.TimesFM_2p5_200M_torch.from_pretrained("google/timesfm-2.5-200m-pytorch")

model.compile(
    timesfm.ForecastConfig(
        max_context=1024,
        max_horizon=256,
        normalize_inputs=True,
        use_continuous_quantile_head=True,
        force_flip_invariance=True,
        infer_is_positive=True,
        fix_quantile_crossing=True,
    )
)
point_forecast, quantile_forecast = model.forecast(
    horizon=12,
    inputs=[
        np.linspace(0, 1, 100),
        np.sin(np.linspace(0, 20, 67)),
    ],  # Two dummy inputs
)
point_forecast.shape  # (2, 12)
quantile_forecast.shape  # (2, 12, 10): mean, then 10th to 90th quantiles.
```
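
As a reading aid (not from the upstream README), the last axis of
`quantile_forecast` can be unpacked per the shape comment above: index 0 is
the mean, indices 1 through 9 are the 10th through 90th percentiles.

```python
mean_forecast = quantile_forecast[..., 0]  # same as the mean output
deciles = quantile_forecast[..., 1:]       # 10th, 20th, ..., 90th percentiles
p10, p50, p90 = (quantile_forecast[..., i] for i in (1, 5, 9))
```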

tsagentkit_timesfm-1.0.0.dist-info/RECORD:

@@ -0,0 +1,21 @@

```
timesfm/__init__.py,sha256=B2pgyjc3hgIwQQqs5BuJh9tZy3cv1ox39bhyp_JIJBs,919
timesfm/configs.py,sha256=hhAsa7LEZgEct-gZqP6xl7HBKWR5gHNgHguE7JPjJxk,3617
timesfm/flax/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
timesfm/flax/dense.py,sha256=vkcDueNfcXgR1HL61vCOWzwHCKH9TreiQANJzMK2_Pg,3522
timesfm/flax/normalization.py,sha256=KGeCNAftZVXtE_5_pfBecBNjJzDMp6JOCl2WoPdBvEM,2168
timesfm/flax/transformer.py,sha256=K9qO31l5bwUf9zl-myl8HDltAkC-O4CoSWg8up7vIm4,11506
timesfm/flax/util.py,sha256=2SFLO6xHT84whuNumVIwZTXXPag-XKrKK2hbj74grmM,3027
timesfm/timesfm_2p5/timesfm_2p5_base.py,sha256=hR-BIcNki2DVJ5k4B3FczT4C0bowpcJu381XR9FLwfw,15174
timesfm/timesfm_2p5/timesfm_2p5_flax.py,sha256=ts42ros5VP-IdldYqWtWqG9ckNm30Z0dEaXuJ_r9UAo,19375
timesfm/timesfm_2p5/timesfm_2p5_torch.py,sha256=MPwE6LL_6OOjKJ-qSINO4zGQB_Vho7LfUYjDeoAvboI,16943
timesfm/torch/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
timesfm/torch/dense.py,sha256=KIlJhThOqj3aexuILVu4Qm2LJ0stmgHLFMRlHAWE7B0,3109
timesfm/torch/normalization.py,sha256=w3tWr9GYrr9e9x7gTvJEpjMY2i5KBJfsmzGQiPbMUVQ,1204
timesfm/torch/transformer.py,sha256=dEoOHkHBJiFQmY-8Yhu35LsvucebD-RucN6nRFEZYrs,11720
timesfm/torch/util.py,sha256=xwCoJFnhFaHGyXlwWObjJyqnnocFoR1I2DSJtndm6Rc,2654
timesfm/utils/xreg_lib.py,sha256=RgmzwRqhdE2zZcih9wf3Kkm2iUi8XWseDr2IZ3Kj9NI,20721
tsagentkit_timesfm-1.0.0.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
tsagentkit_timesfm-1.0.0.dist-info/METADATA,sha256=_hzD2eY53jd9XY2CmlVlhqumxRLmp-Phu_C4nhNo1ks,4823
tsagentkit_timesfm-1.0.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
tsagentkit_timesfm-1.0.0.dist-info/top_level.txt,sha256=5CtB5agsTzKTMKI3M87NsaJk7gh51fFKxmwRkdfeLdg,8
tsagentkit_timesfm-1.0.0.dist-info/RECORD,,
```