tsagentkit-timesfm 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,602 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """TimesFM models in Flax."""
+
+ import dataclasses
+ import functools
+ import gc
+ import logging
+ import math
+ import os
+ from pathlib import Path
+ from typing import Any, Callable, Dict
+
+ import einshape
+ from flax import nnx
+ import huggingface_hub
+ import jax
+ import jax.numpy as jnp
+ import jaxtyping
+ import numpy as np
+ import orbax.checkpoint as ocp
+
+ from .. import configs
+ from ..flax import dense, transformer, util
+ from . import timesfm_2p5_base
+
+ jax_einshape = einshape.jax_einshape
+ scan = util.scan_along_axis
+ revin = util.revin
+
+ Float = jaxtyping.Float
+ Bool = jaxtyping.Bool
+ Array = jaxtyping.Array
+
+
+ def try_gc():
+   """Collects garbage when any local device nears its memory limit."""
+   for d in jax.local_devices():
+     stats = d.memory_stats()
+     if stats is None:
+       return
+     if stats["bytes_in_use"] / stats["bytes_limit"] > 0.75:
+       gc.collect()
+       break
+
+
+ @nnx.vmap(in_axes=(None, 0), out_axes=0)
+ def _create_stacked_transformers(
+     config: configs.StackedTransformersConfig, key: jax.Array
+ ):
+   """Creates transformer layers with stacked (vmapped) parameters."""
+   return transformer.Transformer(config.transformer, rngs=nnx.Rngs(key))
+
+
+ def _scan_along_axis(f, init, xs, axis: int, **kwargs):
+   """Scans along an axis."""
+   moved_xs = jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, axis, 0), xs)
+   carry, moved_ys = jax.lax.scan(f, init, moved_xs, **kwargs)
+   return (
+       carry,
+       jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, 0, axis), moved_ys),
+   )
+
+
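+ # A minimal sketch of `_scan_along_axis` semantics (shapes illustrative):
+ # scanning a running sum over axis 1 of a (batch, time) array.
+ #
+ #   xs = jnp.ones((2, 5))
+ #   carry, ys = _scan_along_axis(
+ #       lambda c, x: (c + x, c + x), init=jnp.zeros((2,)), xs=xs, axis=1
+ #   )
+ #   # carry.shape == (2,)  -- final running sum per batch element
+ #   # ys.shape == (2, 5)   -- per-step outputs moved back to axis 1
+
+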
+ @nnx.scan(in_axes=(0, nnx.Carry, None, 0), out_axes=(nnx.Carry, 0))
+ def _apply_stacked_transformers(
+     model: transformer.Transformer,
+     x: Float[Array, "b n d"],
+     m: Bool[Array, "b n"],
+     decode_cache: util.DecodeCache | None = None,
+ ) -> tuple[Float[Array, "b n d"], util.DecodeCache | None]:
+   return model(x, m, decode_cache=decode_cache)
+
+
+ class TimesFM_2p5_200M_flax_module(nnx.Module):  # pylint: disable=invalid-name
+   """TimesFM 2.5 with 200M parameters."""
+
+   config = timesfm_2p5_base.TimesFM_2p5_200M_Definition()
+   decode_index: int = 5
+   compiled_decode: Callable[..., Any] | None = None
+   backend: str = ""
+   context: int = 0
+   horizon: int = 0
+   per_core_batch_size: int = 0
+
+   def __init__(self):
+     super().__init__()
+     self.backend = jax.devices()[0].platform
+     self.num_devices = len(jax.devices(self.backend))
+
+     # Named constants.
+     self.p = self.config.input_patch_len  # 32
+     self.o = self.config.output_patch_len  # 128
+     self.os = self.config.output_quantile_len  # 1024
+     self.m = self.o // self.p  # 4
+     self.x = self.config.stacked_transformers.num_layers  # 20
+     self.h = self.config.stacked_transformers.transformer.num_heads  # 16
+     self.md = self.config.stacked_transformers.transformer.model_dims  # 1280
+     self.hd = self.md // self.h  # 80
+     self.q = len(self.config.quantiles) + 1  # 10
+     self.aridx = self.config.decode_index  # 5
+
+     # Layers.
+     self.tokenizer = dense.ResidualBlock(self.config.tokenizer)
+     self.stacked_xf = _create_stacked_transformers(
+         self.config.stacked_transformers,
+         jax.random.split(jax.random.key(42), self.x),
+     )
+     self.output_projection_point = dense.ResidualBlock(
+         self.config.output_projection_point
+     )
+     self.output_projection_quantiles = dense.ResidualBlock(
+         self.config.output_projection_quantiles
+     )
+
+   def __call__(
+       self,
+       inputs: Float[Array, "b n p"],
+       masks: Bool[Array, "b n p"],
+       decode_cache: util.DecodeCache | None = None,
+   ):
+     tokenizer_inputs = jnp.concatenate([inputs, masks.astype(inputs.dtype)], axis=-1)
+     input_embeddings = self.tokenizer(tokenizer_inputs)
+     if decode_cache is None:
+       decode_cache = [None] * self.x
+     output_embeddings, decode_cache = _apply_stacked_transformers(
+         self.stacked_xf, input_embeddings, masks[..., -1], decode_cache
+     )
+     output_ts = self.output_projection_point(output_embeddings)
+     output_quantile_spread = self.output_projection_quantiles(output_embeddings)
+     return (
+         input_embeddings,
+         output_embeddings,
+         output_ts,
+         output_quantile_spread,
+     ), decode_cache
+
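+   # Shape sketch for one forward pass, using the constants noted in
+   # __init__ (p=32, md=1280, o=128, os=1024, q=10); b=batch, n=patches:
+   #   inputs, masks: (b, n, 32) -> input_embeddings: (b, n, 1280)
+   #   output_ts: (b, n, 128 * 10)              -- o steps x q quantile columns
+   #   output_quantile_spread: (b, n, 1024 * 10)
+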
+   @nnx.jit(static_argnames=("horizon",))
+   def decode(self, horizon: int, inputs, masks):
+     batch_size, context = inputs.shape[0], inputs.shape[1]
+     num_decode_steps = (horizon - 1) // self.o
+     num_input_patches = context // self.p
+     decode_cache_size = num_input_patches + num_decode_steps * self.m
+
+     # Prefill.
+     patched_inputs = jax_einshape("b(np)->bnp", inputs, b=batch_size, p=self.p)
+     patched_masks = jax_einshape("b(np)->bnp", masks, b=batch_size, p=self.p)
+     (last_n, last_mu, last_sigma), (_, context_mu, context_sigma) = scan(
+         lambda carry, xs: util.update_running_stats(*carry, *xs),
+         init=(zero := jnp.zeros(shape=(batch_size,)), zero, zero),
+         xs=(patched_inputs, patched_masks),
+         axis=1,
+     )
+     decode_cache = util.DecodeCache(
+         next_index=jnp.zeros(shape=(self.x, batch_size), dtype=jnp.int32),
+         num_masked=jnp.zeros(shape=(self.x, batch_size), dtype=jnp.int32),
+         key=jnp.zeros(shape=(self.x, batch_size, decode_cache_size, self.h, self.hd)),
+         value=jnp.zeros(shape=(self.x, batch_size, decode_cache_size, self.h, self.hd)),
+     )
+     normed_inputs = revin(patched_inputs, context_mu, context_sigma, reverse=False)
+     normed_inputs = jnp.where(patched_masks, 0.0, normed_inputs)
+     (_, _, normed_outputs, normed_quantile_spread), decode_cache = self(
+         normed_inputs, patched_masks, decode_cache
+     )
+     renormed_outputs = jax_einshape(
+         "bn(oq)->bnoq",
+         revin(normed_outputs, context_mu, context_sigma, reverse=True),
+         o=self.o,
+         q=self.q,
+     )
+     renormed_quantile_spread = jax_einshape(
+         "bn(oq)->bnoq",
+         revin(normed_quantile_spread, context_mu, context_sigma, reverse=True),
+         o=self.os,
+         q=self.q,
+     )[:, -1, ...]
+
+     # Autoregressive decode.
+     @nnx.scan(in_axes=(None, nnx.Carry, 0), out_axes=(nnx.Carry, 1))
+     def _ar_decode(module, carry, unused_iter):
+       last_renormed_output, (last_n, last_mu, last_sigma), decode_cache = carry
+       new_patched_input = jax_einshape(
+           "b(mp)->bmp", last_renormed_output, m=module.m, p=module.p
+       )
+       new_mask = jnp.zeros_like(new_patched_input, dtype=jnp.bool)
+       carry_stats, (_, new_mu, new_sigma) = scan(
+           lambda carry, xs: util.update_running_stats(*carry, *xs),
+           init=(last_n, last_mu, last_sigma),
+           xs=(new_patched_input, new_mask),
+           axis=1,
+       )
+       new_normed_input = revin(new_patched_input, new_mu, new_sigma, reverse=False)
+       (_, _, new_normed_output, _), decode_cache = module(
+           new_normed_input, new_mask, decode_cache
+       )
+       new_renormed_output = jax_einshape(
+           "bm(oq)->bmoq",
+           revin(new_normed_output, new_mu, new_sigma, reverse=True),
+           o=module.o,
+           q=module.q,
+       )[..., -1, :, :]
+
+       return (
+           (
+               new_renormed_output[..., module.decode_index],
+               carry_stats,
+               decode_cache,
+           ),
+           new_renormed_output,
+       )
+
+     if num_decode_steps > 0:
+       _, ar_renormed_outputs = _ar_decode(
+           self,
+           (
+               renormed_outputs[..., -1, :, self.decode_index],
+               (last_n, last_mu, last_sigma),
+               decode_cache,
+           ),
+           jnp.arange(num_decode_steps),
+       )
+     else:
+       ar_renormed_outputs = None
+
+     return renormed_outputs, renormed_quantile_spread, ar_renormed_outputs
+
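+   # Worked example of the bookkeeping in `decode` (values from the config
+   # comments above: p=32, o=128, m=4). For context=512 and horizon=256:
+   #   num_decode_steps  = (256 - 1) // 128 = 1
+   #   num_input_patches = 512 // 32        = 16
+   #   decode_cache_size = 16 + 1 * 4       = 20
+   # i.e. prefill fills 16 KV-cache slots and the single AR step appends 4.
+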
+   def compile(
+       self,
+       context: int,
+       horizon: int,
+       per_core_batch_size: int = 1,
+   ):
+     if context % self.p != 0:
+       logging.info(
+           "When compiling, context needs to be a multiple of the patch size"
+           " %d. Modifying context to %d.",
+           self.p,
+           context := math.ceil(context / self.p) * self.p,
+       )
+     if horizon % self.o != 0:
+       logging.info(
+           "When compiling, horizon needs to be a multiple of the output patch"
+           " size %d. Modifying horizon to %d.",
+           self.o,
+           horizon := math.ceil(horizon / self.o) * self.o,
+       )
+
+     self.context = context
+     self.horizon = horizon
+     self.per_core_batch_size = per_core_batch_size
+
+     @nnx.pmap(
+         in_axes=(None, None, 0, 0),
+         out_axes=(0, 0, 0),
+         devices=jax.devices(self.backend),
+         axis_size=self.num_devices,
+         static_broadcasted_argnums=(1,),
+         axis_name="global_batch",
+     )
+     def compiled_decode_kernel(model, horizon, inputs, masks):
+       return model.decode(horizon, inputs, masks)
+
+     self.compiled_decode = functools.partial(compiled_decode_kernel, self)
+
+
+ def _flip_quantile_fn(x):
+   return jnp.concatenate([x[..., :1], jnp.flip(x[..., 1:], axis=-1)], axis=-1)
+
+
+ @functools.partial(
+     jax.jit,
+     donate_argnums=(0, 1, 2),
+ )
+ def _force_flip_invariance_fn(
+     flipped_pf_outputs,
+     flipped_quantile_spreads,
+     flipped_ar_outputs,
+ ):
+   """Forces flip invariance."""
+   flipped_pf_outputs = _flip_quantile_fn(flipped_pf_outputs)
+   flipped_pf_outputs = jax_einshape("tb...->(tb)...", flipped_pf_outputs)
+   flipped_quantile_spreads = _flip_quantile_fn(flipped_quantile_spreads)
+   flipped_quantile_spreads = jax_einshape("tb...->(tb)...", flipped_quantile_spreads)
+   to_concat = [flipped_pf_outputs[:, -1, ...]]
+   if flipped_ar_outputs is not None:
+     flipped_ar_outputs = _flip_quantile_fn(flipped_ar_outputs)
+     flipped_ar_outputs = jax_einshape("tbno...->(tb)(no)...", flipped_ar_outputs)
+     to_concat.append(flipped_ar_outputs)
+   flipped_full_forecast = jnp.concatenate(to_concat, axis=1)
+
+   return flipped_quantile_spreads, flipped_pf_outputs, flipped_full_forecast
+
+
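+ # The symmetry being enforced: a flip-invariant forecaster f satisfies
+ # f(-x) == -flip(f(x)), where flip reverses the quantile columns (the 0.1
+ # quantile of -x is the negated 0.9 quantile of x; column 0, the mean, maps
+ # to itself). The (a - b) / 2 steps in `_after_model_decode` average f(x)
+ # with -flip(f(-x)), projecting the forecast onto that symmetry.
+
+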
+ @functools.partial(
+     jax.jit,
+     static_argnames=("max_horizon",),
+     donate_argnums=(0,),
+ )
+ def _use_continuous_quantile_head_fn(full_forecast, quantile_spreads, max_horizon):
+   """Uses continuous quantile head."""
+   to_stack = [full_forecast[..., :max_horizon, 0]]
+   for quantile_index in [1, 2, 3, 4]:
+     to_stack.append(
+         quantile_spreads[:, :max_horizon, quantile_index]
+         - quantile_spreads[:, :max_horizon, 5]
+         + full_forecast[:, :max_horizon, 5]
+     )
+   to_stack.append(full_forecast[..., :max_horizon, 5])
+   for quantile_index in [6, 7, 8, 9]:
+     to_stack.append(
+         quantile_spreads[:, :max_horizon, quantile_index]
+         - quantile_spreads[:, :max_horizon, 5]
+         + full_forecast[:, :max_horizon, 5]
+     )
+   return jnp.stack(to_stack, axis=-1)
+
+
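+ # In other words, the continuous head recenters the dedicated quantile
+ # outputs on the point forecast's median column:
+ #   quantile_i = spread_i - spread_median + forecast_median,  for i not in {0, 5}
+ # while column 0 (mean) and column 5 (median) are kept from the forecast.
+
+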
+ @functools.partial(jax.jit, donate_argnums=(0,))
+ def _fix_quantile_crossing_fn(full_forecast):
+   """Fixes quantile crossing."""
+   lower_quantiles = _scan_along_axis(
+       lambda carry, x: (w := jnp.minimum(carry, x), w),
+       init=full_forecast[..., 5],
+       xs=full_forecast[..., 1:5],
+       axis=-1,
+       reverse=True,
+   )[1]
+   upper_quantiles = _scan_along_axis(
+       lambda carry, x: (w := jnp.maximum(carry, x), w),
+       init=full_forecast[..., 5],
+       xs=full_forecast[..., 6:10],
+       axis=-1,
+       reverse=False,
+   )[1]
+   return jnp.concatenate(
+       [
+           full_forecast[..., :1],
+           lower_quantiles,
+           full_forecast[..., 5:6],
+           upper_quantiles,
+       ],
+       axis=-1,
+   )
+
+
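+ # The crossing fix sweeps outward from the median: columns 1..4 get a
+ # right-to-left running minimum seeded at column 5, and columns 6..9 get a
+ # left-to-right running maximum, so the result satisfies
+ #   q1 <= q2 <= q3 <= q4 <= q5 <= q6 <= q7 <= q8 <= q9
+ # pointwise without touching the mean in column 0.
+
+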
+ @functools.partial(jax.jit, static_argnames=("fc",), donate_argnums=(1, 2))
+ def _before_model_decode(fc, inputs, masks):
+   """All Jax steps before model decode call."""
+   if fc.infer_is_positive:
+     is_positive = jnp.all(inputs >= 0, axis=-1, keepdims=True)
+   else:
+     is_positive = None
+
+   if fc.normalize_inputs:
+     mu = jnp.mean(inputs, axis=-1, keepdims=True)
+     sigma = jnp.std(inputs, axis=-1, keepdims=True)
+     inputs = revin(inputs, mu, sigma, reverse=False)
+   else:
+     mu, sigma = None, None
+
+   inputs = jax_einshape("(tb)...->tb...", inputs, b=fc.per_core_batch_size)
+   masks = jax_einshape("(tb)...->tb...", masks, b=fc.per_core_batch_size)
+
+   return inputs, masks, is_positive, mu, sigma
+
+
+ @functools.partial(
+     jax.jit,
+     static_argnames=(
+         "fc",
+         "p",
+     ),
+     donate_argnums=(1, 2, 3, 4, 5, 6, 7, 8, 9),
+ )
+ def _after_model_decode(
+     fc,
+     pf_outputs,
+     quantile_spreads,
+     ar_outputs,
+     flipped_pf_outputs,
+     flipped_quantile_spreads,
+     flipped_ar_outputs,
+     is_positive,
+     mu,
+     sigma,
+     p,
+ ):
+   """All Jax steps after model decode call."""
+   # t: num_devices, b: per_core_batch_size
+   pf_outputs = jax_einshape("tb...->(tb)...", pf_outputs)
+   quantile_spreads = jax_einshape("tb...->(tb)...", quantile_spreads)
+   to_concat = [pf_outputs[:, -1, ...]]
+   if ar_outputs is not None:
+     ar_outputs = jax_einshape("tbno...->(tb)(no)...", ar_outputs)
+     to_concat.append(ar_outputs)
+   full_forecast = jnp.concatenate(to_concat, axis=1)
+
+   if fc.force_flip_invariance:
+     (
+         flipped_quantile_spreads,
+         flipped_pf_outputs,
+         flipped_full_forecast,
+     ) = _force_flip_invariance_fn(
+         flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs
+     )
+     quantile_spreads = (quantile_spreads - flipped_quantile_spreads) / 2
+     pf_outputs = (pf_outputs - flipped_pf_outputs) / 2
+     full_forecast = (full_forecast - flipped_full_forecast) / 2
+
+   if fc.use_continuous_quantile_head:
+     full_forecast = _use_continuous_quantile_head_fn(
+         full_forecast, quantile_spreads, fc.max_horizon
+     )
+
+   if fc.return_backcast:
+     full_backcast = jax_einshape("...npq->...(np)q", pf_outputs[:, :-1, :p, :])
+     full_forecast = jnp.concatenate([full_backcast, full_forecast], axis=1)
+
+   if fc.fix_quantile_crossing:
+     full_forecast = _fix_quantile_crossing_fn(full_forecast)
+
+   if fc.normalize_inputs:
+     full_forecast = revin(full_forecast, mu, sigma, reverse=True)
+
+   if is_positive is not None:
+     full_forecast = jnp.where(
+         is_positive[..., None],
+         jnp.maximum(full_forecast, jnp.zeros_like(full_forecast)),
+         full_forecast,
+     )
+
+   return full_forecast
+
+
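+ # `_before_model_decode` and `_after_model_decode` bracket the pmapped model
+ # call: the former optionally flags positive series, normalizes, and reshapes
+ # the global batch into per-device shards (t: num_devices, b:
+ # per_core_batch_size); the latter gathers the shards, applies the optional
+ # invariance / quantile-head / backcast / crossing fixes, and denormalizes.
+
+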
+ class TimesFM_2p5_200M_flax(timesfm_2p5_base.TimesFM_2p5):
+   """Flax implementation of TimesFM 2.5 with 200M parameters."""
+
+   model: nnx.Module = TimesFM_2p5_200M_flax_module()
+
+   @classmethod
+   def from_pretrained(
+       cls,
+       model_id: str = "google/timesfm-2.5-200m-flax",
+       *,
+       revision: str | None = None,
+       cache_dir: str | Path | None = None,
+       force_download: bool = False,
+       proxies: Dict | None = None,
+       resume_download: bool | None = None,
+       local_files_only: bool | None = None,
+       token: str | None = None,
+       **model_kwargs,
+   ):
+     """Loads a Flax TimesFM model."""
+
+     # Create an instance of the model wrapper class.
+     instance = cls(**model_kwargs)
+
+     # Determine the path to the model weights.
+     model_file_path = ""
+     if os.path.isdir(model_id):
+       logging.info("Loading checkpoint from local directory: %s", model_id)
+       model_file_path = model_id
+     else:
+       logging.info("Downloading checkpoint from Hugging Face repo %s", model_id)
+       model_file_path = huggingface_hub.snapshot_download(
+           repo_id=model_id,
+           revision=revision,
+           cache_dir=cache_dir,
+           force_download=force_download,
+           proxies=proxies,
+           resume_download=resume_download,
+           token=token,
+           local_files_only=local_files_only,
+       )
+       logging.info("Loading checkpoint from: %s", model_file_path)
+
+     checkpointer = ocp.StandardCheckpointer()
+     graph, state = nnx.split(instance.model)
+     state = checkpointer.restore(model_file_path, state)
+     instance.model = nnx.merge(graph, state)
+     return instance
+
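+   # Note on the restore above: `nnx.split` separates the module's graph
+   # definition from its state pytree, Orbax restores the checkpoint into that
+   # state in place of the randomly initialized weights, and `nnx.merge`
+   # rebuilds the module around the restored state.
+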
+   def compile(
+       self,
+       forecast_config: configs.ForecastConfig,
+       dryrun: bool = True,
+       **kwargs,
+   ):
+     print("Compiling model...")
+
+     # Abbreviation used during validation.
+     fc = forecast_config
+     if fc.max_context % self.model.p != 0:
+       logging.info(
+           "When compiling, max context needs to be a multiple of the patch"
+           " size %d. Using max context = %d instead.",
+           self.model.p,
+           new_context := math.ceil(fc.max_context / self.model.p) * self.model.p,
+       )
+       fc = dataclasses.replace(fc, max_context=new_context)
+     if fc.max_horizon % self.model.o != 0:
+       logging.info(
+           "When compiling, max horizon needs to be a multiple of the output"
+           " patch size %d. Using max horizon = %d instead.",
+           self.model.o,
+           new_horizon := math.ceil(fc.max_horizon / self.model.o) * self.model.o,
+       )
+       fc = dataclasses.replace(fc, max_horizon=new_horizon)
+     if fc.max_context + fc.max_horizon > self.model.config.context_limit:
+       raise ValueError(
+           "Context + horizon must not exceed the context limit."
+           f" {fc.max_context} + {fc.max_horizon} >"
+           f" {self.model.config.context_limit}."
+       )
+     if fc.use_continuous_quantile_head and (fc.max_horizon > self.model.os):
+       raise ValueError(
+           f"Continuous quantile head is not supported for horizons > {self.model.os}."
+       )
+
+     self.forecast_config = fc
+     self.model.compile(
+         context=self.forecast_config.max_context,
+         horizon=self.forecast_config.max_horizon,
+         per_core_batch_size=fc.per_core_batch_size,
+     )
+     self.per_core_batch_size = self.forecast_config.per_core_batch_size
+     self.num_devices = self.model.num_devices
+     self.global_batch_size = (
+         self.forecast_config.per_core_batch_size * self.model.num_devices
+     )
+
+     def compiled_decode_kernel(fc, horizon, inputs, masks):
+       inputs = jnp.array(inputs, dtype=jnp.float32)
+       masks = jnp.array(masks, dtype=jnp.bool)
+       if horizon > fc.max_horizon:
+         raise ValueError(
+             f"Horizon must not exceed the max horizon. {horizon} > {fc.max_horizon}."
+         )
+       to_trim = fc.max_horizon - horizon
+
+       inputs, masks, is_positive, mu, sigma = _before_model_decode(fc, inputs, masks)
+
+       pf_outputs, quantile_spreads, ar_outputs = self.model.compiled_decode(
+           fc.max_horizon, inputs, masks
+       )
+       if fc.force_flip_invariance:
+         flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs = (
+             self.model.compiled_decode(fc.max_horizon, -inputs, masks)
+         )
+       else:
+         flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs = (
+             None,
+             None,
+             None,
+         )
+
+       full_forecast = _after_model_decode(
+           fc,
+           pf_outputs,
+           quantile_spreads,
+           ar_outputs,
+           flipped_pf_outputs,
+           flipped_quantile_spreads,
+           flipped_ar_outputs,
+           is_positive,
+           mu,
+           sigma,
+           self.model.p,
+       )
+       full_forecast_np = np.array(full_forecast)
+       del full_forecast
+       try_gc()
+       if to_trim > 0:
+         full_forecast_np = full_forecast_np[..., :-to_trim, :]
+       return full_forecast_np[..., 5], full_forecast_np
+
+     self.compiled_decode = functools.partial(
+         compiled_decode_kernel, self.forecast_config
+     )
+
+     if dryrun:
+       _ = self.compiled_decode(
+           self.forecast_config.max_horizon,
+           jnp.zeros(
+               (self.global_batch_size, self.forecast_config.max_context),
+               dtype=jnp.float32,
+           ),
+           jnp.zeros(
+               (self.global_batch_size, self.forecast_config.max_context),
+               dtype=jnp.bool,
+           ),
+       )
+     print("Compiling done.")
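+
+
+ # End-to-end usage sketch (illustrative, not a guaranteed API contract;
+ # assumes the remaining ForecastConfig fields have workable defaults and that
+ # 512/256 satisfy the patch-size and context-limit checks above):
+ #
+ #   model = TimesFM_2p5_200M_flax.from_pretrained("google/timesfm-2.5-200m-flax")
+ #   model.compile(configs.ForecastConfig(max_context=512, max_horizon=256))
+ #   inputs = np.zeros((model.global_batch_size, 512), dtype=np.float32)
+ #   masks = np.zeros_like(inputs, dtype=bool)
+ #   point, full = model.compiled_decode(256, inputs, masks)
+ #   # point == full[..., 5]: the median column of the (batch, horizon, 10)
+ #   # quantile forecast.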