tsagentkit-timesfm 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- timesfm/__init__.py +29 -0
- timesfm/configs.py +105 -0
- timesfm/flax/__init__.py +13 -0
- timesfm/flax/dense.py +110 -0
- timesfm/flax/normalization.py +71 -0
- timesfm/flax/transformer.py +356 -0
- timesfm/flax/util.py +107 -0
- timesfm/timesfm_2p5/timesfm_2p5_base.py +422 -0
- timesfm/timesfm_2p5/timesfm_2p5_flax.py +602 -0
- timesfm/timesfm_2p5/timesfm_2p5_torch.py +472 -0
- timesfm/torch/__init__.py +13 -0
- timesfm/torch/dense.py +94 -0
- timesfm/torch/normalization.py +39 -0
- timesfm/torch/transformer.py +370 -0
- timesfm/torch/util.py +94 -0
- timesfm/utils/xreg_lib.py +520 -0
- tsagentkit_timesfm-1.0.0.dist-info/METADATA +152 -0
- tsagentkit_timesfm-1.0.0.dist-info/RECORD +21 -0
- tsagentkit_timesfm-1.0.0.dist-info/WHEEL +5 -0
- tsagentkit_timesfm-1.0.0.dist-info/licenses/LICENSE +202 -0
- tsagentkit_timesfm-1.0.0.dist-info/top_level.txt +1 -0
timesfm/flax/util.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Flax utility functions for TimesFM layers."""
|
|
16
|
+
|
|
17
|
+
import dataclasses
|
|
18
|
+
import functools
|
|
19
|
+
import jax
|
|
20
|
+
import jax.numpy as jnp
|
|
21
|
+
import jaxtyping
|
|
22
|
+
|
|
23
|
+
Float = jaxtyping.Float
|
|
24
|
+
Array = jaxtyping.Array
|
|
25
|
+
Bool = jaxtyping.Bool
|
|
26
|
+
Integer = jaxtyping.Integer
|
|
27
|
+
|
|
28
|
+
_TOLERANCE = 1e-6
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=False)
class DecodeCache:
  """Cache for decoding.

  Registered as a JAX pytree so instances can flow through `jit`/`scan`.
  Mutable (`frozen=False`) so decode steps can update it in place.
  """

  # Per-batch index of the next cache slot to write. Shape: (batch,).
  next_index: Integer[Array, "b"]
  # Per-batch count of masked (padding) positions. Shape: (batch,).
  num_masked: Integer[Array, "b"]
  # Cached projections, annotated (batch, seq, heads, head_dim) — presumably
  # the attention K/V cache for autoregressive decoding; confirm against the
  # transformer module that consumes this.
  key: Float[Array, "b n h d"]
  value: Float[Array, "b n h d"]
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@jax.jit
def update_running_stats(
    # Annotations are quoted (lazy) so the function can be defined without the
    # jaxtyping aliases in scope; semantics are unchanged in the real module.
    n: 'Float[Array, "b"]',
    mu: 'Float[Array, "b"]',
    sigma: 'Float[Array, "b"]',
    x: 'Float[Array, "b p"]',
    mask: 'Bool[Array, "b p"]',
) -> (
    'tuple['
    'tuple[Float[Array, "b"], Float[Array, "b"], Float[Array, "b"]],'
    'tuple[Float[Array, "b"], Float[Array, "b"], Float[Array, "b"]],'
    ']'
):
  """Updates running (count, mean, std) stats with one patch of observations.

  Combines the existing per-series statistics (n, mu, sigma) with the
  statistics of the unmasked entries of `x` using the standard pooled
  mean/variance formulas (parallel Welford combine).

  Args:
    n: Running count of observed points per series.
    mu: Running mean per series.
    sigma: Running standard deviation per series.
    x: New patch of values, shape (batch, patch).
    mask: True where `x` is padding and must be ignored.

  Returns:
    A ((new_n, new_mu, new_sigma), (new_n, new_mu, new_sigma)) pair — the same
    tuple twice so this function can be used directly as a `jax.lax.scan` body
    (carry, per-step output).
  """
  is_legit = jnp.logical_not(mask)
  inc_n = jnp.sum(is_legit.astype(jnp.float32), axis=-1, keepdims=False)
  # Guard the fully-masked case: jnp.mean/std with an all-False `where` would
  # produce NaN; select 0 contribution instead.
  inc_mu = jnp.where(
      inc_n == 0, 0.0, jnp.mean(x, axis=-1, keepdims=False, where=is_legit)
  )
  inc_sigma = jnp.where(
      inc_n == 0, 0.0, jnp.std(x, axis=-1, keepdims=False, where=is_legit)
  )
  new_n = n + inc_n
  new_mu = jnp.where(new_n == 0, 0.0, (n * mu + inc_mu * inc_n) / new_n)
  variance = jnp.where(
      new_n == 0,
      0.0,
      (
          n * sigma * sigma
          + inc_n * inc_sigma * inc_sigma
          + n * (mu - new_mu) * (mu - new_mu)
          + inc_n * (inc_mu - new_mu) * (inc_mu - new_mu)
      )
      / new_n,
  )
  # Floating-point cancellation can leave `variance` a hair below zero, which
  # would turn into NaN under sqrt; clamp before taking the root.
  new_sigma = jnp.sqrt(jnp.maximum(variance, 0.0))
  stats = (new_n, new_mu, new_sigma)
  return stats, stats
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def scan_along_axis(f, init, xs, axis: int, **kwargs):
  """Runs `jax.lax.scan` over an arbitrary axis of `xs`.

  `jax.lax.scan` always scans over the leading axis, so the requested axis of
  every leaf in `xs` is moved to the front, scanned, and the stacked outputs
  are moved back to the same position.
  """
  to_front = lambda leaf: jnp.moveaxis(leaf, axis, 0)
  to_original = lambda leaf: jnp.moveaxis(leaf, 0, axis)

  front_xs = jax.tree_util.tree_map(to_front, xs)
  carry, stacked_ys = jax.lax.scan(f, init, front_xs, **kwargs)
  return carry, jax.tree_util.tree_map(to_original, stacked_ys)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@functools.partial(jax.jit, static_argnames=("reverse",))
def revin(
    x: Float[Array, "b ..."],
    mu: Float[Array, "b ..."],
    sigma: Float[Array, "b ..."],
    reverse: bool = False,
):
  """Reversible per-instance normalization.

  Forward: standardizes `x` with (mu, sigma); near-zero sigmas are treated as
  1.0 so constant series do not blow up. Reverse: undoes the standardization
  with the raw sigma.
  """
  # Broadcast the stats up to x's rank when they trail by one or two axes.
  rank_gap = len(x.shape) - len(mu.shape)
  if rank_gap == 1:
    mu, sigma = mu[..., None], sigma[..., None]
  elif rank_gap == 2:
    mu, sigma = mu[..., None, None], sigma[..., None, None]

  if reverse:
    return x * sigma + mu
  safe_sigma = jnp.where(sigma < _TOLERANCE, 1.0, sigma)
  return (x - mu) / safe_sigma
|
|
@@ -0,0 +1,422 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""TimesFM 2p5 base implementation."""
|
|
16
|
+
|
|
17
|
+
import dataclasses
|
|
18
|
+
from typing import Any, Callable, Sequence
|
|
19
|
+
|
|
20
|
+
import collections
|
|
21
|
+
import numpy as np
|
|
22
|
+
|
|
23
|
+
from .. import configs
|
|
24
|
+
|
|
25
|
+
ResidualBlockConfig = configs.ResidualBlockConfig
|
|
26
|
+
StackedTransformersConfig = configs.StackedTransformersConfig
|
|
27
|
+
TransformerConfig = configs.TransformerConfig
|
|
28
|
+
ForecastConfig = configs.ForecastConfig
|
|
29
|
+
Category = int | str
|
|
30
|
+
XRegMode = str
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def strip_leading_nans(arr):
  """Removes contiguous NaN values from the beginning of a NumPy array.

  Args:
    arr: The input NumPy array.

  Returns:
    A NumPy array (view) with leading NaN values removed.
    If the array is all NaNs or empty, returns an empty array.
  """
  isnan = np.isnan(arr)
  # `np.argmax(~isnan)` returns 0 when there is no valid value at all (and
  # raises ValueError on an empty array), which would wrongly return the whole
  # all-NaN array. Handle the no-valid-value case explicitly so the documented
  # "empty array" contract actually holds.
  if isnan.all():
    return arr[:0]
  first_valid_index = np.argmax(~isnan)
  return arr[first_valid_index:]
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def linear_interpolation(arr):
  """Performs linear interpolation to fill NaN values in a 1D numpy array.

  Args:
    arr: The 1D numpy array containing NaN values.

  Returns:
    The original array if no NaNs are present; otherwise a new array with NaN
    values filled by linear interpolation between the surrounding valid
    points. If the array has no valid values at all, every non-finite entry is
    replaced with 0.0.
  """
  nans = np.isnan(arr)
  if not np.any(nans):  # Nothing to fill.
    return arr

  nan_indices = np.flatnonzero(nans)
  valid_indices = np.flatnonzero(~nans)

  if valid_indices.size == 0:
    # No anchor points for interpolation (np.interp would raise ValueError on
    # an empty sample-point array): fall back to zero-filling non-finite
    # entries, matching the previous exception-path behavior.
    return np.where(np.isfinite(arr), arr, 0.0)

  # Fill into a copy so the caller's array is not mutated in place (the
  # previous implementation wrote through `arr[nans] = ...`).
  filled = arr.copy()
  filled[nans] = np.interp(nan_indices, valid_indices, arr[valid_indices])
  return filled
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@dataclasses.dataclass(frozen=True)
class TimesFM_2p5_200M_Definition:
  """Framework-agnostic config of TimesFM 2.5."""

  # NOTE(review): no type annotation, so this is a plain class attribute, not
  # a dataclass field (it will not appear in __init__ or asdict()). Confirm
  # this is intentional before annotating it.
  context_limit = 16384
  # Number of time points per input patch (token).
  input_patch_len: int = 32
  # Number of forecast points produced per decode step.
  output_patch_len: int = 128
  output_quantile_len: int = 1024
  # Quantile levels produced alongside the point forecast.
  quantiles: list[float] = dataclasses.field(
      default_factory=lambda: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
  )
  # NOTE(review): presumably selects the decode channel out of a 10-wide
  # output head (mean + 9 quantiles), making 5 the median — confirm against
  # the decode implementation.
  decode_index: int = 5
  # Residual MLP that embeds each 64-dim patch input (32 values + 32 mask
  # bits, presumably) into the 1280-dim model space — confirm input layout.
  tokenizer: ResidualBlockConfig = ResidualBlockConfig(
      input_dims=64,
      hidden_dims=1280,
      output_dims=1280,
      use_bias=True,
      activation="swish",
  )
  # 20-layer decoder stack: 1280-dim model, 16 heads, RMSNorm everywhere,
  # rotary position embeddings, fused QKV projection.
  stacked_transformers: StackedTransformersConfig = StackedTransformersConfig(
      num_layers=20,
      transformer=TransformerConfig(
          model_dims=1280,
          hidden_dims=1280,
          num_heads=16,
          attention_norm="rms",
          feedforward_norm="rms",
          qk_norm="rms",
          use_bias=False,
          use_rotary_position_embeddings=True,
          ff_activation="swish",
          fuse_qkv=True,
      ),
  )
  # Output head for the point forecast.
  output_projection_point: ResidualBlockConfig = ResidualBlockConfig(
      input_dims=1280,
      hidden_dims=1280,
      output_dims=1280,
      use_bias=False,
      activation="swish",
  )
  # Output head for quantiles; 10240 = output_quantile_len (1024) x 10
  # channels, presumably — confirm against the decode reshape.
  output_projection_quantiles: ResidualBlockConfig = ResidualBlockConfig(
      input_dims=1280,
      hidden_dims=1280,
      output_dims=10240,
      use_bias=False,
      activation="swish",
  )
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class TimesFM_2p5:
  """Abstract base class for TimesFM models.

  Attributes:
    forecast_config: Configuration for forecasting flags.
    compiled_decode: Compiled decode function.
    global_batch_size: Global batch size.
  """

  # Populated by concrete subclasses' compile(); forecast() refuses to run
  # until compiled_decode is set and global_batch_size is positive.
  forecast_config: ForecastConfig | None = None
  compiled_decode: Callable[..., Any] | None = None
  global_batch_size: int = 0

  def load_checkpoint(self, path: str):
    """Loads a TimesFM model from a checkpoint.

    Args:
      path: Path to the checkpoint.

    Raises:
      NotImplementedError: Always — concrete subclasses must override.
    """
    raise NotImplementedError()

  def compile(self, forecast_config: ForecastConfig | None = None):
    """Compiles the TimesFM model for fast decoding.

    Args:
      forecast_config: Forecast-time configuration; expected to set
        `forecast_config`, `compiled_decode` and `global_batch_size`.

    Raises:
      NotImplementedError: Always — concrete subclasses must override.
    """
    raise NotImplementedError()
|
|
154
|
+
|
|
155
|
+
def forecast(
|
|
156
|
+
self, horizon: int, inputs: list[np.ndarray]
|
|
157
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
|
158
|
+
"""Forecasts the time series."""
|
|
159
|
+
if self.compiled_decode is None:
|
|
160
|
+
raise RuntimeError("Model is not compiled. Please call compile() first.")
|
|
161
|
+
|
|
162
|
+
assert self.global_batch_size > 0
|
|
163
|
+
assert self.forecast_config is not None
|
|
164
|
+
|
|
165
|
+
context = self.forecast_config.max_context
|
|
166
|
+
num_inputs = len(inputs)
|
|
167
|
+
if (w := num_inputs % self.global_batch_size) != 0:
|
|
168
|
+
inputs += [np.array([0.0] * 3)] * (self.global_batch_size - w)
|
|
169
|
+
|
|
170
|
+
output_points = []
|
|
171
|
+
output_quantiles = []
|
|
172
|
+
values = []
|
|
173
|
+
masks = []
|
|
174
|
+
idx = 0
|
|
175
|
+
for each_input in inputs:
|
|
176
|
+
value = linear_interpolation(strip_leading_nans(np.array(each_input)))
|
|
177
|
+
if (w := len(value)) >= context:
|
|
178
|
+
value = value[-context:]
|
|
179
|
+
mask = np.zeros_like(value, dtype=bool)
|
|
180
|
+
else:
|
|
181
|
+
mask = np.array([True] * (context - w) + [False] * w)
|
|
182
|
+
value = np.pad(value, (context - w, 0), "constant", constant_values=0.0)
|
|
183
|
+
values.append(value)
|
|
184
|
+
masks.append(mask)
|
|
185
|
+
idx += 1
|
|
186
|
+
if idx == self.global_batch_size:
|
|
187
|
+
idx = 0
|
|
188
|
+
point_forecast, quantile_forecast = self.compiled_decode(horizon, values, masks)
|
|
189
|
+
output_points.append(point_forecast)
|
|
190
|
+
output_quantiles.append(quantile_forecast)
|
|
191
|
+
values = []
|
|
192
|
+
masks = []
|
|
193
|
+
|
|
194
|
+
output_points = np.concatenate(output_points, axis=0)
|
|
195
|
+
output_quantiles = np.concatenate(output_quantiles, axis=0)
|
|
196
|
+
return output_points[:num_inputs], output_quantiles[:num_inputs]
|
|
197
|
+
|
|
198
|
+
  def forecast_with_covariates(
      self,
      inputs: list[Sequence[float]],
      dynamic_numerical_covariates: dict[str, Sequence[Sequence[float]]] | None = None,
      dynamic_categorical_covariates: (
          dict[str, Sequence[Sequence[Category]]] | None
      ) = None,
      static_numerical_covariates: dict[str, Sequence[float]] | None = None,
      static_categorical_covariates: dict[str, Sequence[Category]] | None = None,
      xreg_mode: XRegMode = "xreg + timesfm",
      normalize_xreg_target_per_input: bool = True,
      ridge: float = 0.0,
      max_rows_per_col: int = 0,
      force_on_cpu: bool = False,
  ):
    """Forecasts on a list of time series with covariates.

    To optimize inference speed, avoid string valued categorical covariates.

    Args:
      inputs: A list of time series forecast contexts. Each context time series
        should be in a format convertible to JTensor by `jnp.array`.
      dynamic_numerical_covariates: A dict of dynamic numerical covariates.
      dynamic_categorical_covariates: A dict of dynamic categorical covariates.
      static_numerical_covariates: A dict of static numerical covariates.
      static_categorical_covariates: A dict of static categorical covariates.
      xreg_mode: one of "xreg + timesfm" or "timesfm + xreg". "xreg + timesfm"
        fits a model on the residuals of the TimesFM forecast. "timesfm + xreg"
        fits a model on the targets then forecasts on the residuals via TimesFM.
      normalize_xreg_target_per_input: whether to normalize the xreg target per
        input in the given batch.
      ridge: ridge penalty for the linear model.
      max_rows_per_col: max number of rows per column for the linear model.
      force_on_cpu: whether to force running on cpu for the linear model.

    Returns:
      A tuple of two lists. The first is the outputs of the model. The second is
      the outputs of the xreg.

    Raises:
      ValueError: If the model is not compiled, `return_backcast` is False, no
        covariates are given, the mode is unknown, or the inferred horizon
        exceeds `max_horizon`.
    """
    if self.forecast_config is None:
      raise ValueError("Model is not compiled. Please call compile() first.")
    elif not self.forecast_config.return_backcast:
      # The backcast portion of the output is needed below to compute
      # residuals over the context.
      raise ValueError(
          "For XReg, `return_backcast` must be set to True in the forecast config. Please recompile the model."
      )

    # Imported lazily so the xreg dependency is only needed on this path.
    from ..utils import xreg_lib

    # Verify and bookkeep covariates.
    if not (
        dynamic_numerical_covariates
        or dynamic_categorical_covariates
        or static_numerical_covariates
        or static_categorical_covariates
    ):
      raise ValueError(
          "At least one of dynamic_numerical_covariates,"
          " dynamic_categorical_covariates, static_numerical_covariates,"
          " static_categorical_covariates must be set."
      )

    # Track the lengths of (1) each input, (2) the part that can be used in the
    # linear model, and (3) the horizon.
    input_lens, train_lens, test_lens = [], [], []

    for i, input_ts in enumerate(inputs):
      input_len = len(input_ts)
      input_lens.append(input_len)

      if xreg_mode == "timesfm + xreg":
        # For fitting residuals, no TimesFM forecast on the first patch.
        # NOTE(review): `self.model.p` is not defined on this base class —
        # presumably provided by a concrete subclass (patch length?); confirm
        # the attribute exists and its meaning.
        train_lens.append(max(0, input_len - self.model.p))
      elif xreg_mode == "xreg + timesfm":
        train_lens.append(input_len)
      else:
        raise ValueError(f"Unsupported mode: {xreg_mode}")

      # The horizon is inferred from how far any dynamic covariate extends
      # past the input; with only static covariates it defaults to
      # max_horizon.
      if dynamic_numerical_covariates:
        test_lens.append(
            len(list(dynamic_numerical_covariates.values())[0][i]) - input_len
        )
      elif dynamic_categorical_covariates:
        test_lens.append(
            len(list(dynamic_categorical_covariates.values())[0][i]) - input_len
        )
      else:
        test_lens.append(self.forecast_config.max_horizon)

      if test_lens[-1] > self.forecast_config.max_horizon:
        raise ValueError(
            "Forecast horizon length inferred from the dynamic covaraites is longer than the"
            f"max_horizon defined in the forecast config: {test_lens[-1]} > {self.forecast_config.max_horizon=}."
        )

    # Prepare the covariates into train and test.
    train_dynamic_numerical_covariates = collections.defaultdict(list)
    test_dynamic_numerical_covariates = collections.defaultdict(list)
    train_dynamic_categorical_covariates = collections.defaultdict(list)
    test_dynamic_categorical_covariates = collections.defaultdict(list)
    for covariates, train_covariates, test_covariates in (
        (
            dynamic_numerical_covariates,
            train_dynamic_numerical_covariates,
            test_dynamic_numerical_covariates,
        ),
        (
            dynamic_categorical_covariates,
            train_dynamic_categorical_covariates,
            test_dynamic_categorical_covariates,
        ),
    ):
      if not covariates:
        continue
      for covariate_name, covariate_values in covariates.items():
        for input_len, train_len, covariate_value in zip(
            input_lens, train_lens, covariate_values
        ):
          # Train slice: the last `train_len` in-context points; test slice:
          # everything past the context (the horizon).
          train_covariates[covariate_name].append(
              covariate_value[(input_len - train_len) : input_len]
          )
          test_covariates[covariate_name].append(covariate_value[input_len:])

    # Fit models.
    if xreg_mode == "timesfm + xreg":
      # Forecast via TimesFM then fit a model on the residuals.
      point_outputs, quantile_outputs = self.forecast(
          horizon=self.forecast_config.max_horizon, inputs=inputs
      )
      # Residual target = context values minus the backcast portion of the
      # output (output layout: backcast, then max_horizon of forecast).
      targets = [
          (
              np.array(input_ts)[-train_len:]
              - point_output[: -self.forecast_config.max_horizon][-train_len:]
          )
          for input_ts, point_output, train_len in zip(inputs, point_outputs, train_lens)
      ]
      per_instance_stats = None
      if normalize_xreg_target_per_input:
        targets, per_instance_stats = xreg_lib.normalize(targets)
      xregs = xreg_lib.BatchedInContextXRegLinear(
          targets=targets,
          train_lens=train_lens,
          test_lens=test_lens,
          train_dynamic_numerical_covariates=train_dynamic_numerical_covariates,
          test_dynamic_numerical_covariates=test_dynamic_numerical_covariates,
          train_dynamic_categorical_covariates=train_dynamic_categorical_covariates,
          test_dynamic_categorical_covariates=test_dynamic_categorical_covariates,
          static_numerical_covariates=static_numerical_covariates,
          static_categorical_covariates=static_categorical_covariates,
      ).fit(
          ridge=ridge,
          # With no ridge penalty, drop the first one-hot level to avoid a
          # rank-deficient design matrix.
          one_hot_encoder_drop=None if ridge > 0 else "first",
          max_rows_per_col=max_rows_per_col,
          force_on_cpu=force_on_cpu,
          debug_info=False,
          assert_covariates=True,
          assert_covariate_shapes=True,
      )
      if normalize_xreg_target_per_input:
        xregs = xreg_lib.renormalize(xregs, per_instance_stats)
      xregs = np.array(xregs)
      # Final forecast = TimesFM horizon forecast + xreg correction,
      # truncated to each series' inferred horizon.
      new_point_outputs = [
          (point_output[-self.forecast_config.max_horizon :][:test_len] + xreg)
          for point_output, test_len, xreg in zip(point_outputs, test_lens, xregs)
      ]
      new_quantile_outputs = [
          (
              quantile_output[-self.forecast_config.max_horizon :][:test_len]
              + xreg[..., None]
          )
          for quantile_output, test_len, xreg in zip(quantile_outputs, test_lens, xregs)
      ]

    else:
      # Fit a model on the targets then forecast on the residuals via TimesFM.
      targets = [
          np.array(input_ts)[-train_len:]
          for input_ts, train_len in zip(inputs, train_lens)
      ]
      per_instance_stats = None
      if normalize_xreg_target_per_input:
        targets, per_instance_stats = xreg_lib.normalize(targets)
      # debug_info=True also returns the linear model's in-context fit, which
      # is subtracted below to form the residual series for TimesFM.
      xregs, xregs_on_context, _, _, _ = xreg_lib.BatchedInContextXRegLinear(
          targets=targets,
          train_lens=train_lens,
          test_lens=test_lens,
          train_dynamic_numerical_covariates=train_dynamic_numerical_covariates,
          test_dynamic_numerical_covariates=test_dynamic_numerical_covariates,
          train_dynamic_categorical_covariates=train_dynamic_categorical_covariates,
          test_dynamic_categorical_covariates=test_dynamic_categorical_covariates,
          static_numerical_covariates=static_numerical_covariates,
          static_categorical_covariates=static_categorical_covariates,
      ).fit(
          ridge=ridge,
          one_hot_encoder_drop=None if ridge > 0 else "first",
          max_rows_per_col=max_rows_per_col,
          force_on_cpu=force_on_cpu,
          debug_info=True,
          assert_covariates=True,
          assert_covariate_shapes=True,
      )
      point_outputs, quantile_outputs = self.forecast(
          horizon=self.forecast_config.max_horizon,
          inputs=[
              target - xreg_on_context
              for target, xreg_on_context in zip(targets, xregs_on_context)
          ],
      )
      new_point_outputs = [
          (point_output[-self.forecast_config.max_horizon :][:test_len] + xreg)
          for point_output, test_len, xreg in zip(point_outputs, test_lens, xregs)
      ]
      new_quantile_outputs = [
          (
              quantile_output[-self.forecast_config.max_horizon :][:test_len]
              + xreg[..., None]
          )
          for quantile_output, test_len, xreg in zip(quantile_outputs, test_lens, xregs)
      ]
      # Targets were normalized before fitting, so the combined outputs must
      # be mapped back to the original scale.
      if normalize_xreg_target_per_input:
        new_point_outputs = xreg_lib.renormalize(new_point_outputs, per_instance_stats)
        new_quantile_outputs = xreg_lib.renormalize(
            new_quantile_outputs, per_instance_stats
        )

    return new_point_outputs, new_quantile_outputs
|