tsagentkit-timesfm 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- timesfm/__init__.py +29 -0
- timesfm/configs.py +105 -0
- timesfm/flax/__init__.py +13 -0
- timesfm/flax/dense.py +110 -0
- timesfm/flax/normalization.py +71 -0
- timesfm/flax/transformer.py +356 -0
- timesfm/flax/util.py +107 -0
- timesfm/timesfm_2p5/timesfm_2p5_base.py +422 -0
- timesfm/timesfm_2p5/timesfm_2p5_flax.py +602 -0
- timesfm/timesfm_2p5/timesfm_2p5_torch.py +472 -0
- timesfm/torch/__init__.py +13 -0
- timesfm/torch/dense.py +94 -0
- timesfm/torch/normalization.py +39 -0
- timesfm/torch/transformer.py +370 -0
- timesfm/torch/util.py +94 -0
- timesfm/utils/xreg_lib.py +520 -0
- tsagentkit_timesfm-1.0.0.dist-info/METADATA +152 -0
- tsagentkit_timesfm-1.0.0.dist-info/RECORD +21 -0
- tsagentkit_timesfm-1.0.0.dist-info/WHEEL +5 -0
- tsagentkit_timesfm-1.0.0.dist-info/licenses/LICENSE +202 -0
- tsagentkit_timesfm-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,602 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""TimesFM models in Flax."""
|
|
16
|
+
|
|
17
|
+
import dataclasses
|
|
18
|
+
import functools
|
|
19
|
+
import gc
|
|
20
|
+
import logging
|
|
21
|
+
import math
|
|
22
|
+
import os
|
|
23
|
+
from pathlib import Path
|
|
24
|
+
from typing import Any, Callable, Dict
|
|
25
|
+
|
|
26
|
+
import einshape
|
|
27
|
+
from flax import nnx
|
|
28
|
+
import huggingface_hub
|
|
29
|
+
import jax
|
|
30
|
+
import jax.numpy as jnp
|
|
31
|
+
import jaxtyping
|
|
32
|
+
import numpy as np
|
|
33
|
+
import orbax.checkpoint as ocp
|
|
34
|
+
|
|
35
|
+
from .. import configs
|
|
36
|
+
from ..flax import dense, transformer, util
|
|
37
|
+
from . import timesfm_2p5_base
|
|
38
|
+
|
|
39
|
+
# Short module-level aliases for helpers used throughout this file.
jax_einshape = einshape.jax_einshape  # einshape-style reshapes on jax arrays
scan = util.scan_along_axis  # lax.scan along an arbitrary axis
revin = util.revin  # forward / reverse instance normalisation

# jaxtyping aliases used only in shape annotations below.
Float, Bool, Array = jaxtyping.Float, jaxtyping.Bool, jaxtyping.Array
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def try_gc(threshold: float = 0.75):
  """Triggers Python garbage collection when a local device is nearly full.

  Iterates over `jax.local_devices()`. As soon as one device reports
  `bytes_in_use / bytes_limit` above `threshold`, `gc.collect()` is called
  once and the loop stops. If a device reports no memory statistics at all
  (`memory_stats()` returns None), the function returns without collecting.
  Devices whose statistics lack a usable (present, non-zero) `bytes_in_use`
  or `bytes_limit` entry are skipped rather than raising, because this is a
  best-effort helper.

  Args:
    threshold: Fraction of device memory in use above which to collect.
      Defaults to 0.75, the previously hard-coded value.
  """
  for device in jax.local_devices():
    stats = device.memory_stats()
    if stats is None:
      return
    in_use = stats.get("bytes_in_use")
    limit = stats.get("bytes_limit")
    if not in_use or not limit:
      # Missing or zero entries: indexing / dividing would raise
      # KeyError or ZeroDivisionError, so just move on to the next device.
      continue
    if in_use / limit > threshold:
      gc.collect()
      break
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@nnx.vmap(in_axes=(None, 0), out_axes=0)
def _create_stacked_transformers(
    config: configs.StackedTransformersConfig, key: jax.Array
):
  """Builds one Transformer per RNG key, stacked along axis 0 by nnx.vmap.

  `config` is broadcast to every call (in_axes None); `key` is mapped over
  its leading axis (in_axes 0), so passing N keys yields N independently
  initialised transformer layers stacked on a leading axis of size N.
  """
  layer_rngs = nnx.Rngs(key)
  return transformer.Transformer(config.transformer, rngs=layer_rngs)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _scan_along_axis(f, init, xs, axis: int, **kwargs):
|
|
66
|
+
"""Scans along an axis."""
|
|
67
|
+
moved_xs = jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, axis, 0), xs)
|
|
68
|
+
carry, moved_ys = jax.lax.scan(f, init, moved_xs, **kwargs)
|
|
69
|
+
return (
|
|
70
|
+
carry,
|
|
71
|
+
jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, 0, axis), moved_ys),
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# Applies the stacked transformer layers one after another via nnx.scan.
# in_axes: the stacked `model` is iterated along axis 0 (one layer per scan
# step), `x` is the scan carry, `m` is broadcast to every layer (None), and
# `decode_cache` is iterated along axis 0 together with the layers.
# out_axes: the first returned value is the final carry; the second is
# stacked along axis 0 (one entry per layer).
@nnx.scan(in_axes=(0, nnx.Carry, None, 0), out_axes=(nnx.Carry, 0))
def _apply_stacked_transformers(
    model: transformer.Transformer,
    x: Float[Array, "b n d"],
    m: Bool[Array, "b n"],
    decode_cache: util.DecodeCache | None = None,
) -> tuple[Float[Array, "b n d"], Any]:
  """Runs one transformer layer; nnx.scan lifts this over all stacked layers.

  Args:
    model: A single transformer layer (the scan slices the stacked module).
    x: Layer input embeddings; threaded through the layers as the carry.
    m: Per-patch mask passed unchanged to every layer (the caller in this
      file passes `masks[..., -1]`, i.e. a boolean array).
    decode_cache: This layer's slice of the decode cache. NOTE(review): the
      caller may pass a list of Nones instead of a `DecodeCache` when no
      cache is used.

  Returns:
    `(output_embeddings, per_layer_cache)`: the caller unpacks exactly two
    values. NOTE(review): assumes `model(...)` returns a `(x, cache)` pair,
    which is what nnx.scan's `(nnx.Carry, 0)` out_axes requires.
  """
  return model(x, m, decode_cache=decode_cache)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class TimesFM_2p5_200M_flax_module(nnx.Module):  # pylint: disable=invalid-name
  """TimesFM 2.5 with 200M parameters.

  Flax NNX module holding the input tokenizer, the stacked transformer
  layers and the point / quantile output projections. `decode` runs a jitted
  prefill over the context followed by autoregressive decoding, and
  `compile` wraps `decode` in an `nnx.pmap` over all devices of the default
  backend.
  """

  # Architecture definition; every size set in __init__ is read from it.
  config = timesfm_2p5_base.TimesFM_2p5_200M_Definition()
  # Output channel fed back as the next input during autoregressive decoding.
  decode_index: int = 5
  # Set by `compile()`: the pmapped decode with this module bound as `model`.
  compiled_decode: Callable[..., Any] | None = None
  # Platform name of jax.devices()[0]; set in __init__.
  backend: str = ""
  # Context / horizon / per-core batch size recorded by `compile()`.
  context: int = 0
  horizon: int = 0
  per_core_batch_size: int = 0

  def __init__(self):
    super().__init__()
    self.backend = jax.devices()[0].platform
    self.num_devices = len(jax.devices(self.backend))

    # Named constants (trailing values are those of the 200M definition).
    self.p = self.config.input_patch_len  # 32
    self.o = self.config.output_patch_len  # 128
    self.os = self.config.output_quantile_len  # 1024
    self.m = self.o // self.p  # 4
    self.x = self.config.stacked_transformers.num_layers  # 20
    self.h = self.config.stacked_transformers.transformer.num_heads  # 16
    self.md = self.config.stacked_transformers.transformer.model_dims  # 1280
    self.hd = self.md // self.h  # 80
    self.q = len(self.config.quantiles) + 1  # 10
    self.aridx = self.config.decode_index  # 5

    # Layers.
    self.tokenizer = dense.ResidualBlock(self.config.tokenizer)
    # One transformer per layer, stacked on a leading axis of size self.x;
    # each layer is initialised from its own key split off a fixed seed (42).
    self.stacked_xf = _create_stacked_transformers(
        self.config.stacked_transformers,
        jax.random.split(jax.random.key(42), self.x),
    )
    self.output_projection_point = dense.ResidualBlock(
        self.config.output_projection_point
    )
    self.output_projection_quantiles = dense.ResidualBlock(
        self.config.output_projection_quantiles
    )

  def __call__(
      self,
      inputs: Float[Array, "b n p"],
      masks: Bool[Array, "b n p"],
      decode_cache: util.DecodeCache | None = None,
  ):
    """Single forward pass over already-patched, already-normalised inputs.

    Args:
      inputs: Patched inputs of shape (b, n, p).
      masks: Boolean masks with the same shape as `inputs`.
      decode_cache: Optional per-layer decode cache; when None, a list of
        `self.x` Nones is used instead.

    Returns:
      `((input_embeddings, output_embeddings, output_ts,
      output_quantile_spread), decode_cache)`.
    """
    # The tokenizer sees values and masks side by side on the last axis.
    tokenizer_inputs = jnp.concatenate([inputs, masks.astype(inputs.dtype)], axis=-1)
    input_embeddings = self.tokenizer(tokenizer_inputs)
    if decode_cache is None:
      decode_cache = [None] * self.x
    # Only the last time step's mask of each patch is given to the layers.
    output_embeddings, decode_cache = _apply_stacked_transformers(
        self.stacked_xf, input_embeddings, masks[..., -1], decode_cache
    )
    output_ts = self.output_projection_point(output_embeddings)
    output_quantile_spread = self.output_projection_quantiles(output_embeddings)
    return (
        input_embeddings,
        output_embeddings,
        output_ts,
        output_quantile_spread,
    ), decode_cache

  @nnx.jit(static_argnames=("horizon",))
  def decode(self, horizon: int, inputs, masks):
    """Runs prefill over the context, then autoregressive decoding.

    Args:
      horizon: Requested horizon (static under jit). After the prefill step,
        `(horizon - 1) // self.o` further autoregressive steps are run.
      inputs: 2-D array (batch, context); the context is reshaped into
        `context // self.p` patches, so it should be a multiple of `self.p`.
      masks: Same shape as `inputs`; positions where it is True get a
        normalised input of 0.0.

    Returns:
      `(renormed_outputs, renormed_quantile_spread, ar_renormed_outputs)`:
        * renormed_outputs: de-normalised per-patch outputs, (b, n, o, q).
        * renormed_quantile_spread: de-normalised quantile-spread output of
          the last input patch, (b, os, q).
        * ar_renormed_outputs: stacked autoregressive outputs,
          (b, steps, o, q), or None when no autoregressive step is needed.
    """
    batch_size, context = inputs.shape[0], inputs.shape[1]
    num_decode_steps = (horizon - 1) // self.o
    num_input_patches = context // self.p
    # One cache slot per input patch plus self.m slots per decode step.
    decode_cache_size = num_input_patches + num_decode_steps * self.m

    # Prefill
    patched_inputs = jax_einshape("b(np)->bnp", inputs, b=batch_size, p=self.p)
    patched_masks = jax_einshape("b(np)->bnp", masks, b=batch_size, p=self.p)
    # Running (count, mean, sigma) over the patches: the final carry seeds
    # the AR loop; the per-patch values normalise each patch.
    (last_n, last_mu, last_sigma), (_, context_mu, context_sigma) = scan(
        lambda carry, xs: util.update_running_stats(*carry, *xs),
        init=(zero := jnp.zeros(shape=(batch_size)), zero, zero),
        xs=(patched_inputs, patched_masks),
        axis=1,
    )
    # Empty per-layer KV cache: (layers, batch, cache_size, heads, head_dim).
    decode_cache = util.DecodeCache(
        next_index=jnp.zeros(shape=(self.x, batch_size), dtype=jnp.int32),
        num_masked=jnp.zeros(shape=(self.x, batch_size), dtype=jnp.int32),
        key=jnp.zeros(shape=(self.x, batch_size, decode_cache_size, self.h, self.hd)),
        value=jnp.zeros(shape=(self.x, batch_size, decode_cache_size, self.h, self.hd)),
    )
    normed_inputs = revin(patched_inputs, context_mu, context_sigma, reverse=False)
    normed_inputs = jnp.where(patched_masks, 0.0, normed_inputs)
    (_, _, normed_outputs, normed_quantile_spread), decode_cache = self(
        normed_inputs, patched_masks, decode_cache
    )
    renormed_outputs = jax_einshape(
        "bn(oq)->bnoq",
        revin(normed_outputs, context_mu, context_sigma, reverse=True),
        o=self.o,
        q=self.q,
    )
    # Only the last input patch's quantile spread is kept.
    renormed_quantile_spread = jax_einshape(
        "bn(oq)->bnoq",
        revin(normed_quantile_spread, context_mu, context_sigma, reverse=True),
        o=self.os,
        q=self.q,
    )[:, -1, ...]

    # Autoregressive decode
    # Scan over step indices; the module is broadcast (None) and each step's
    # (b, o, q) output is stacked on axis 1.
    @nnx.scan(in_axes=(None, nnx.Carry, 0), out_axes=(nnx.Carry, 1))
    def _ar_decode(module, carry, unused_iter):
      last_renormed_output, (last_n, last_mu, last_sigma), decode_cache = carry
      # The previous o-length output becomes m new input patches of size p.
      new_patched_input = jax_einshape(
          "b(mp)->bmp", last_renormed_output, m=module.m, p=module.p
      )
      new_mask = jnp.zeros_like(new_patched_input, dtype=jnp.bool)
      carry_stats, (_, new_mu, new_sigma) = scan(
          lambda carry, xs: util.update_running_stats(*carry, *xs),
          init=(last_n, last_mu, last_sigma),
          xs=(new_patched_input, new_mask),
          axis=1,
      )
      new_normed_input = revin(new_patched_input, new_mu, new_sigma, reverse=False)
      (_, _, new_normed_output, _), decode_cache = module(
          new_normed_input, new_mask, decode_cache
      )
      # Keep only the output predicted from the last of the m new patches.
      new_renormed_output = jax_einshape(
          "bm(oq)->bmoq",
          revin(new_normed_output, new_mu, new_sigma, reverse=True),
          o=module.o,
          q=module.q,
      )[..., -1, :, :]

      return (
          (
              new_renormed_output[..., module.decode_index],
              carry_stats,
              decode_cache,
          ),
          new_renormed_output,
      )

    if num_decode_steps > 0:
      _, ar_renormed_outputs = _ar_decode(
          self,
          (
              renormed_outputs[..., -1, :, self.decode_index],
              (last_n, last_mu, last_sigma),
              decode_cache,
          ),
          jnp.arange(num_decode_steps),
      )
    else:
      ar_renormed_outputs = None

    return renormed_outputs, renormed_quantile_spread, ar_renormed_outputs

  def compile(
      self,
      context: int,
      horizon: int,
      per_core_batch_size: int = 1,
  ):
    """Records sizes and builds a pmapped `decode` in `self.compiled_decode`.

    `context` is rounded up to a multiple of `self.p` and `horizon` to a
    multiple of `self.o` (with an info log) before being stored on the
    module. No tracing happens here; the pmapped kernel compiles on first
    call.

    Args:
      context: Context length to record.
      horizon: Horizon to record.
      per_core_batch_size: Batch size per device to record.
    """
    if context % self.p != 0:
      logging.info(
          "When compiling, context needs to be multiple of the patch size %d."
          " Modifying context to %d.",
          self.p,
          context := math.ceil(context / self.p) * self.p,
      )
    if horizon % self.o != 0:
      logging.info(
          "When compiling, horizon needs to be multiple of the output patch"
          " size %d. Modifying horizon to %d.",
          self.o,
          horizon := math.ceil(horizon / self.o) * self.o,
      )

    self.context = context
    self.horizon = horizon
    self.per_core_batch_size = per_core_batch_size

    # model and horizon are broadcast to every device (horizon is static);
    # inputs / masks are split along their leading (device) axis.
    @nnx.pmap(
        in_axes=(None, None, 0, 0),
        out_axes=(0, 0, 0),
        devices=jax.devices(self.backend),
        axis_size=self.num_devices,
        static_broadcasted_argnums=(1,),
        axis_name="global_batch",
    )
    def compiled_decode_kernel(model, horizon, inputs, masks):
      return model.decode(horizon, inputs, masks)

    # Callers invoke self.compiled_decode(horizon, inputs, masks).
    self.compiled_decode = functools.partial(compiled_decode_kernel, self)
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def _flip_quantile_fn(x):
|
|
277
|
+
return jnp.concatenate([x[..., :1], jnp.flip(x[..., 1:], axis=-1)], axis=-1)
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
@functools.partial(
    jax.jit,
    donate_argnums=(0, 1, 2),
)
def _force_flip_invariance_fn(
    flipped_pf_outputs,
    flipped_quantile_spreads,
    flipped_ar_outputs,
):
  """Forces flip invariance.

  Post-processes the outputs produced from the negated inputs:
  `_flip_quantile_fn` re-orders the quantile channels, and the leading
  (device, per-device batch) axes are merged into one batch axis.

  Args:
    flipped_pf_outputs: Prefill outputs with leading axes (t, b).
    flipped_quantile_spreads: Quantile-spread outputs with leading axes (t, b).
    flipped_ar_outputs: Autoregressive outputs with leading axes
      (t, b, steps, o), or None.

  Returns:
    `(quantile_spreads, pf_outputs, full_forecast)` where `full_forecast` is
    the last prefill patch's output followed by the flattened AR outputs
    along axis 1.
  """

  def _merge_device_axis(arr):
    return jax_einshape("tb...->(tb)...", arr)

  pf = _merge_device_axis(_flip_quantile_fn(flipped_pf_outputs))
  spreads = _merge_device_axis(_flip_quantile_fn(flipped_quantile_spreads))

  pieces = [pf[:, -1, ...]]
  if flipped_ar_outputs is not None:
    ar = _flip_quantile_fn(flipped_ar_outputs)
    pieces.append(jax_einshape("tbno...->(tb)(no)...", ar))

  return spreads, pf, jnp.concatenate(pieces, axis=1)
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
@functools.partial(
|
|
305
|
+
jax.jit,
|
|
306
|
+
static_argnames=("max_horizon",),
|
|
307
|
+
donate_argnums=(0,),
|
|
308
|
+
)
|
|
309
|
+
def _use_continuous_quantile_head_fn(full_forecast, quantile_spreads, max_horizon):
|
|
310
|
+
"""Uses continuous quantile head."""
|
|
311
|
+
to_stack = [full_forecast[..., :max_horizon, 0]]
|
|
312
|
+
for quantile_index in [1, 2, 3, 4]:
|
|
313
|
+
to_stack.append(
|
|
314
|
+
quantile_spreads[:, :max_horizon, quantile_index]
|
|
315
|
+
- quantile_spreads[:, :max_horizon, 5]
|
|
316
|
+
+ full_forecast[:, :max_horizon, 5]
|
|
317
|
+
)
|
|
318
|
+
to_stack.append(full_forecast[..., :max_horizon, 5])
|
|
319
|
+
for quantile_index in [6, 7, 8, 9]:
|
|
320
|
+
to_stack.append(
|
|
321
|
+
quantile_spreads[:, :max_horizon, quantile_index]
|
|
322
|
+
- quantile_spreads[:, :max_horizon, 5]
|
|
323
|
+
+ full_forecast[:, :max_horizon, 5]
|
|
324
|
+
)
|
|
325
|
+
return jnp.stack(to_stack, axis=-1)
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
@functools.partial(jax.jit, donate_argnums=(0,))
|
|
329
|
+
def _fix_quantile_crossing_fn(full_forecast):
|
|
330
|
+
"""Fixes quantile crossing."""
|
|
331
|
+
lower_quantiles = _scan_along_axis(
|
|
332
|
+
lambda carry, x: (w := jnp.minimum(carry, x), w),
|
|
333
|
+
init=full_forecast[..., 5],
|
|
334
|
+
xs=full_forecast[..., 1:5],
|
|
335
|
+
axis=-1,
|
|
336
|
+
reverse=True,
|
|
337
|
+
)[1]
|
|
338
|
+
upper_quantiles = _scan_along_axis(
|
|
339
|
+
lambda carry, x: (w := jnp.maximum(carry, x), w),
|
|
340
|
+
init=full_forecast[..., 5],
|
|
341
|
+
xs=full_forecast[..., 6:10],
|
|
342
|
+
axis=-1,
|
|
343
|
+
reverse=False,
|
|
344
|
+
)[1]
|
|
345
|
+
return jnp.concatenate(
|
|
346
|
+
[
|
|
347
|
+
full_forecast[..., :1],
|
|
348
|
+
lower_quantiles,
|
|
349
|
+
full_forecast[..., 5:6],
|
|
350
|
+
upper_quantiles,
|
|
351
|
+
],
|
|
352
|
+
axis=-1,
|
|
353
|
+
)
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
@functools.partial(jax.jit, static_argnames=("fc",), donate_argnums=(1, 2))
def _before_model_decode(fc, inputs, masks):
  """All Jax steps before model decode call.

  Optionally infers per-series non-negativity and normalises the inputs,
  then splits the global batch into (devices, per_core_batch_size, ...).

  Args:
    fc: Forecast config (static; must be hashable). Reads `infer_is_positive`,
      `normalize_inputs` and `per_core_batch_size`.
    inputs: Global-batch inputs; statistics are taken over the last axis.
    masks: Global-batch masks, reshaped like `inputs`.

  Returns:
    `(inputs, masks, is_positive, mu, sigma)`; `is_positive` is None unless
    `fc.infer_is_positive`, and `mu` / `sigma` are None unless
    `fc.normalize_inputs`.
  """
  # Computed on the raw (un-normalised) inputs.
  is_positive = (
      jnp.all(inputs >= 0, axis=-1, keepdims=True)
      if fc.infer_is_positive
      else None
  )

  mu, sigma = None, None
  if fc.normalize_inputs:
    mu = jnp.mean(inputs, axis=-1, keepdims=True)
    sigma = jnp.std(inputs, axis=-1, keepdims=True)
    inputs = revin(inputs, mu, sigma, reverse=False)

  def _split_devices(arr):
    return jax_einshape("(tb)...->tb...", arr, b=fc.per_core_batch_size)

  return _split_devices(inputs), _split_devices(masks), is_positive, mu, sigma
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
@functools.partial(
    jax.jit,
    static_argnames=(
        "fc",
        "p",
    ),
    donate_argnums=(1, 2, 3, 4, 5, 6, 7, 8, 9),
)
def _after_model_decode(
    fc,
    pf_outputs,
    quantile_spreads,
    ar_outputs,
    flipped_pf_outputs,
    flipped_quantile_spreads,
    flipped_ar_outputs,
    is_positive,
    mu,
    sigma,
    p,
):
  """All Jax steps after model decode call.

  Pipeline, in order:
    1. merge the (device, per-device batch) axes of every model output;
    2. build the forecast = last prefill patch output followed by the
       flattened autoregressive outputs (when present);
    3. if `fc.force_flip_invariance`, average with the outputs computed on
       negated inputs (`(x - flipped_x) / 2`);
    4. if `fc.use_continuous_quantile_head`, rebuild quantiles from the
       quantile-spread head (also truncates to `fc.max_horizon`);
    5. if `fc.return_backcast`, prepend the first `p` outputs of every input
       patch except the last;
    6. if `fc.fix_quantile_crossing`, make quantile channels monotone;
    7. if `fc.normalize_inputs`, undo the input normalisation;
    8. clamp to >= 0 for series where `is_positive` is True.

  Args:
    fc: Forecast config (static; must be hashable).
    pf_outputs: Prefill outputs, (t, b, n, o, q) before merging.
    quantile_spreads: Quantile-spread outputs, (t, b, spread_len, q).
    ar_outputs: Autoregressive outputs, (t, b, steps, o, q), or None.
    flipped_pf_outputs / flipped_quantile_spreads / flipped_ar_outputs: The
      same outputs computed on negated inputs; only read when
      `fc.force_flip_invariance` is set.
    is_positive: Boolean per-series flag or None. NOTE(review): assumed to
      broadcast against the forecast after `[..., None]` — confirm shape.
    mu, sigma: Normalisation statistics, only used when
      `fc.normalize_inputs`.
    p: Input patch length (static); number of backcast steps kept per patch.

  Returns:
    The full forecast, (batch, [backcast +] horizon, q).
  """
  # t: num_devices, b: per_core_batch_size
  pf_outputs = jax_einshape("tb...->(tb)...", pf_outputs)
  quantile_spreads = jax_einshape("tb...->(tb)...", quantile_spreads)
  to_concat = [pf_outputs[:, -1, ...]]
  if ar_outputs is not None:
    ar_outputs = jax_einshape("tbno...->(tb)(no)...", ar_outputs)
    to_concat.append(ar_outputs)
  full_forecast = jnp.concatenate(to_concat, axis=1)

  if fc.force_flip_invariance:
    (
        flipped_quantile_spreads,
        flipped_pf_outputs,
        flipped_full_forecast,
    ) = _force_flip_invariance_fn(
        flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs
    )
    # Subtract: the flipped run saw -inputs, so its outputs are negated.
    quantile_spreads = (quantile_spreads - flipped_quantile_spreads) / 2
    pf_outputs = (pf_outputs - flipped_pf_outputs) / 2
    full_forecast = (full_forecast - flipped_full_forecast) / 2

  if fc.use_continuous_quantile_head:
    full_forecast = _use_continuous_quantile_head_fn(
        full_forecast, quantile_spreads, fc.max_horizon
    )

  if fc.return_backcast:
    # First p outputs of every input patch except the last, flattened in time.
    full_backcast = jax_einshape("...npq->...(np)q", pf_outputs[:, :-1, :p, :])
    full_forecast = jnp.concatenate([full_backcast, full_forecast], axis=1)

  if fc.fix_quantile_crossing:
    full_forecast = _fix_quantile_crossing_fn(full_forecast)

  if fc.normalize_inputs:
    full_forecast = revin(full_forecast, mu, sigma, reverse=True)

  if is_positive is not None:
    full_forecast = jnp.where(
        is_positive[..., None],
        jnp.maximum(full_forecast, jnp.zeros_like(full_forecast)),
        full_forecast,
    )

  return full_forecast
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
class TimesFM_2p5_200M_flax(timesfm_2p5_base.TimesFM_2p5):
  """Flax implementation of TimesFM 2.5 with 200M parameters.

  Wraps `TimesFM_2p5_200M_flax_module`: `from_pretrained` restores weights
  with Orbax, and `compile` validates a forecast config and installs
  `self.compiled_decode`, a host-side function that runs the pmapped model
  plus the jitted pre/post-processing.
  """

  # NOTE(review): this default is constructed once, when the class body is
  # executed (i.e. at import time), so importing the module already builds a
  # full model and queries jax.devices(). Unless the base class assigns
  # `self.model`, instances share this object until `from_pretrained`
  # replaces `instance.model`.
  model: nnx.Module = TimesFM_2p5_200M_flax_module()

  @classmethod
  def from_pretrained(
      cls,
      model_id: str = "google/timesfm-2.5-200m-flax",
      *,
      revision: str | None = None,
      cache_dir: str | Path | None = None,
      force_download: bool = False,
      proxies: Dict | None = None,
      resume_download: bool | None = None,
      local_files_only: bool | None = None,
      token: str | None = None,
      **model_kwargs,
  ):
    """Loads a Flax TimesFM model.

    Args:
      model_id: A local directory containing the checkpoint, or a Hugging
        Face repo id to download with `huggingface_hub.snapshot_download`.
      revision, cache_dir, force_download, proxies, resume_download,
        local_files_only, token: Forwarded to `snapshot_download`; ignored
        when `model_id` is a local directory.
      **model_kwargs: Forwarded to the class constructor.

    Returns:
      A new instance whose `model` holds the restored Orbax state.
    """

    # Create an instance of the model wrapper class.
    instance = cls(**model_kwargs)

    # Determine the path to the model weights.
    model_file_path = ""
    if os.path.isdir(model_id):
      logging.info("Loading checkpoint from local directory: %s", model_id)
      model_file_path = model_id
    else:
      logging.info("Downloading checkpoint from Hugging Face repo %s", model_id)
      model_file_path = huggingface_hub.snapshot_download(
          repo_id=model_id,
          revision=revision,
          cache_dir=cache_dir,
          force_download=force_download,
          proxies=proxies,
          resume_download=resume_download,
          token=token,
          local_files_only=local_files_only,
      )
    logging.info("Loading checkpoint from: %s", model_file_path)

    # Restore into the state structure of the freshly built module, then
    # rebuild the module from its graph definition and the restored state.
    checkpointer = ocp.StandardCheckpointer()
    graph, state = nnx.split(instance.model)
    state = checkpointer.restore(model_file_path, state)
    instance.model = nnx.merge(graph, state)
    return instance

  def compile(
      self,
      forecast_config: configs.ForecastConfig,
      dryrun: bool = True,
      **kwargs
  ):
    """Validates `forecast_config` and installs `self.compiled_decode`.

    `max_context` is rounded up to a multiple of the input patch length and
    `max_horizon` to a multiple of the output patch length (with an info log).

    Args:
      forecast_config: Forecast configuration; the (possibly adjusted) copy
        is stored in `self.forecast_config`.
      dryrun: If True, calls `self.compiled_decode` once on all-zero inputs
        of shape (global_batch_size, max_context) to trigger compilation.
      **kwargs: Accepted but not used.

    Raises:
      ValueError: If max_context + max_horizon exceeds the model's context
        limit (equality is allowed despite the message wording), or if the
        continuous quantile head is requested with max_horizon larger than
        the quantile output length.
    """
    print("Compiling model...")

    # `fc` is a short alias for forecast_config used during validation.
    fc = forecast_config
    if fc.max_context % self.model.p != 0:
      logging.info(
          "When compiling, max context needs to be multiple of the patch size"
          " %d. Using max context = %d instead.",
          self.model.p,
          new_context := math.ceil(fc.max_context / self.model.p) * self.model.p,
      )
      fc = dataclasses.replace(fc, max_context=new_context)
    if fc.max_horizon % self.model.o != 0:
      logging.info(
          "When compiling, max horizon needs to be multiple of the output patch"
          " size %d. Using max horizon = %d instead.",
          self.model.o,
          new_horizon := math.ceil(fc.max_horizon / self.model.o) * self.model.o,
      )
      fc = dataclasses.replace(fc, max_horizon=new_horizon)
    if fc.max_context + fc.max_horizon > self.model.config.context_limit:
      raise ValueError(
          "Context + horizon must be less than the context limit."
          f" {fc.max_context} + {fc.max_horizon} >"
          f" {self.model.config.context_limit}."
      )
    if fc.use_continuous_quantile_head and (fc.max_horizon > self.model.os):
      raise ValueError(
          f"Continuous quantile head is not supported for horizons > {self.model.os}."
      )

    self.forecast_config = fc
    self.model.compile(
        context=self.forecast_config.max_context,
        horizon=self.forecast_config.max_horizon,
        per_core_batch_size=fc.per_core_batch_size,
    )
    self.per_core_batch_size = self.forecast_config.per_core_batch_size
    self.num_devices = self.model.num_devices
    self.global_batch_size = (
        self.forecast_config.per_core_batch_size * self.model.num_devices
    )

    def compiled_decode_kernel(fc, horizon, inputs, masks):
      """Forecasts `horizon` steps for a global batch.

      The model always decodes `fc.max_horizon` steps; the surplus
      `fc.max_horizon - horizon` trailing steps are trimmed on the host.

      Args:
        fc: The validated forecast config (bound via functools.partial).
        horizon: Requested horizon; must not exceed `fc.max_horizon`
          (equality is allowed despite the message wording).
        inputs: Global-batch inputs, converted to float32.
        masks: Global-batch masks, converted to bool.

      Returns:
        `(point_forecast, full_forecast)` as NumPy arrays, where
        `point_forecast` is channel 5 of the last axis of `full_forecast`.

      Raises:
        ValueError: If `horizon > fc.max_horizon`.
      """
      inputs = jnp.array(inputs, dtype=jnp.float32)
      masks = jnp.array(masks, dtype=jnp.bool)
      if horizon > fc.max_horizon:
        raise ValueError(
            f"Horizon must be less than the max horizon. {horizon} > {fc.max_horizon}."
        )
      to_trim = fc.max_horizon - horizon

      inputs, masks, is_positive, mu, sigma = _before_model_decode(fc, inputs, masks)

      pf_outputs, quantile_spreads, ar_outputs = self.model.compiled_decode(
          fc.max_horizon, inputs, masks
      )
      # Second pass on negated inputs for flip-invariance averaging.
      if fc.force_flip_invariance:
        flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs = (
            self.model.compiled_decode(fc.max_horizon, -inputs, masks)
        )
      else:
        flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs = (
            None,
            None,
            None,
        )

      full_forecast = _after_model_decode(
          fc,
          pf_outputs,
          quantile_spreads,
          ar_outputs,
          flipped_pf_outputs,
          flipped_quantile_spreads,
          flipped_ar_outputs,
          is_positive,
          mu,
          sigma,
          self.model.p,
      )
      # Copy to host, drop the device array and collect if memory is tight.
      full_forecast_np = np.array(full_forecast)
      del full_forecast
      try_gc()
      if to_trim > 0:
        full_forecast_np = full_forecast_np[..., :-to_trim, :]
      return full_forecast_np[..., 5], full_forecast_np

    self.compiled_decode = functools.partial(
        compiled_decode_kernel, self.forecast_config
    )

    if dryrun:
      _ = self.compiled_decode(
          self.forecast_config.max_horizon,
          jnp.zeros(
              (self.global_batch_size, self.forecast_config.max_context), dtype=jnp.float32
          ),
          jnp.zeros(
              (self.global_batch_size, self.forecast_config.max_context), dtype=jnp.bool
          ),
      )
    print("Compiling done.")
|