tsagentkit-timesfm 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- timesfm/__init__.py +29 -0
- timesfm/configs.py +105 -0
- timesfm/flax/__init__.py +13 -0
- timesfm/flax/dense.py +110 -0
- timesfm/flax/normalization.py +71 -0
- timesfm/flax/transformer.py +356 -0
- timesfm/flax/util.py +107 -0
- timesfm/timesfm_2p5/timesfm_2p5_base.py +422 -0
- timesfm/timesfm_2p5/timesfm_2p5_flax.py +602 -0
- timesfm/timesfm_2p5/timesfm_2p5_torch.py +472 -0
- timesfm/torch/__init__.py +13 -0
- timesfm/torch/dense.py +94 -0
- timesfm/torch/normalization.py +39 -0
- timesfm/torch/transformer.py +370 -0
- timesfm/torch/util.py +94 -0
- timesfm/utils/xreg_lib.py +520 -0
- tsagentkit_timesfm-1.0.0.dist-info/METADATA +152 -0
- tsagentkit_timesfm-1.0.0.dist-info/RECORD +21 -0
- tsagentkit_timesfm-1.0.0.dist-info/WHEEL +5 -0
- tsagentkit_timesfm-1.0.0.dist-info/licenses/LICENSE +202 -0
- tsagentkit_timesfm-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,472 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""TimesFM models."""
|
|
15
|
+
|
|
16
|
+
import dataclasses
|
|
17
|
+
import logging
|
|
18
|
+
import math
|
|
19
|
+
import os
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
from typing import Dict, Optional, Sequence, Union
|
|
22
|
+
|
|
23
|
+
import numpy as np
|
|
24
|
+
import torch
|
|
25
|
+
from huggingface_hub import ModelHubMixin, hf_hub_download
|
|
26
|
+
from safetensors.torch import load_file, save_file
|
|
27
|
+
from torch import nn
|
|
28
|
+
|
|
29
|
+
from .. import configs
|
|
30
|
+
from ..torch import dense, transformer, util
|
|
31
|
+
from . import timesfm_2p5_base
|
|
32
|
+
|
|
33
|
+
# Module-level alias for `util.revin`. It is called below with
# `reverse=False` on inputs (producing the `normed_*` values) and with
# `reverse=True` on model outputs (producing the `renormed_*` values).
revin = util.revin
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class TimesFM_2p5_200M_torch_module(nn.Module):
  """TimesFM 2.5 with 200M parameters.

  Holds the tokenizer, the stack of transformer layers and the point/quantile
  output projections, plus helpers for checkpoint loading and patch-wise
  prefill + autoregressive decoding.
  """

  config = timesfm_2p5_base.TimesFM_2p5_200M_Definition()

  def __init__(self):
    super().__init__()

    # Names constants (trailing numbers are the values for this definition).
    self.p = self.config.input_patch_len  # 32
    self.o = self.config.output_patch_len  # 128
    self.os = self.config.output_quantile_len  # 1024
    self.m = self.o // self.p  # 4
    self.x = self.config.stacked_transformers.num_layers  # 20
    self.h = self.config.stacked_transformers.transformer.num_heads  # 16
    self.md = self.config.stacked_transformers.transformer.model_dims  # 1280
    self.hd = self.md // self.h  # 80
    self.q = len(self.config.quantiles) + 1  # 10
    self.aridx = self.config.decode_index  # 5

    # Layers.
    self.tokenizer = dense.ResidualBlock(self.config.tokenizer)
    self.stacked_xf = nn.ModuleList(
        [
            transformer.Transformer(self.config.stacked_transformers.transformer)
            for _ in range(self.x)
        ]
    )
    self.output_projection_point = dense.ResidualBlock(
        self.config.output_projection_point
    )
    self.output_projection_quantiles = dense.ResidualBlock(
        self.config.output_projection_quantiles
    )

    # Device: first CUDA device when available, otherwise CPU. Within this
    # class the weights are moved there only by `load_checkpoint`.
    if torch.cuda.is_available():
      self.device = torch.device("cuda:0")
      self.device_count = torch.cuda.device_count()
    else:
      self.device = torch.device("cpu")
      self.device_count = 1

  def load_checkpoint(self, path: str, **kwargs):
    """Loads a PyTorch TimesFM model from a safetensors checkpoint.

    Args:
      path: Path to a safetensors file whose keys must match this module's
        state dict exactly (loaded with `strict=True`).
      **kwargs: Only `torch_compile` (default True) is read here.
    """
    tensors = load_file(path)
    self.load_state_dict(tensors, strict=True)
    self.to(self.device)
    torch_compile = kwargs.get("torch_compile", True)
    if torch_compile:
      print("Compiling model...")
      # NOTE(review): this rebinds only the local name `self`; the compiled
      # wrapper is discarded when this method returns, so callers keep using
      # the uncompiled module. Left unchanged because making compilation take
      # effect (e.g. `nn.Module.compile()`) changes runtime behavior and
      # platform requirements.
      self = torch.compile(self)

    self.eval()

  def forward(
      self,
      inputs: torch.Tensor,
      masks: torch.Tensor,
      decode_caches: list[util.DecodeCache] | None = None,
  ):
    """Runs the tokenizer, the transformer stack and both output projections.

    Args:
      inputs: Patched values; concatenated with `masks` (cast to the inputs'
        dtype) along the last axis before tokenization.
      masks: Padding mask with the same leading shape as `inputs`;
        `masks[..., -1]` is passed to every transformer layer.
      decode_caches: One cache per transformer layer, or None to run every
        layer without a cache.

    Returns:
      A pair `((input_embeddings, output_embeddings, output_ts,
      output_quantile_spread), new_decode_caches)`, where `new_decode_caches`
      holds the cache returned by each transformer layer, in layer order.
    """
    tokenizer_inputs = torch.cat([inputs, masks.to(inputs.dtype)], dim=-1)
    input_embeddings = self.tokenizer(tokenizer_inputs)

    if decode_caches is None:
      decode_caches = [None] * self.x

    output_embeddings = input_embeddings
    new_decode_caches = []
    for i, layer in enumerate(self.stacked_xf):
      output_embeddings, new_cache = layer(
          output_embeddings, masks[..., -1], decode_caches[i]
      )
      new_decode_caches.append(new_cache)
    output_ts = self.output_projection_point(output_embeddings)
    output_quantile_spread = self.output_projection_quantiles(output_embeddings)

    return (
        input_embeddings,
        output_embeddings,
        output_ts,
        output_quantile_spread,
    ), new_decode_caches

  def decode(self, horizon: int, inputs, masks):
    """Decodes the time series: one prefill pass, then autoregressive steps.

    Args:
      horizon: Number of future points requested. `(horizon - 1) // self.o`
        autoregressive steps of `self.o` points follow the prefill pass.
      inputs: Tensor of shape (batch, context); it is reshaped to
        (batch, -1, self.p), so `context` must be a multiple of `self.p`.
      masks: Boolean tensor shaped like `inputs`; True positions are zeroed
        after normalization.

    Returns:
      `(renormed_outputs, renormed_quantile_spread, ar_renormed_outputs)`:
        renormed_outputs: (batch, context // self.p, self.o, self.q) prefill
          outputs mapped back through `revin(..., reverse=True)`.
        renormed_quantile_spread: (batch, self.os, self.q) quantile-head output
          of the last context patch.
        ar_renormed_outputs: (batch, num_decode_steps, self.o, self.q), or None
          when no autoregressive step is needed.
    """
    with torch.no_grad():
      batch_size, context = inputs.shape[0], inputs.shape[1]
      num_decode_steps = (horizon - 1) // self.o
      num_input_patches = context // self.p
      decode_cache_size = num_input_patches + num_decode_steps * self.m

      # Prefill
      patched_inputs = torch.reshape(inputs, (batch_size, -1, self.p))
      patched_masks = torch.reshape(masks, (batch_size, -1, self.p))

      # Running stats, accumulated patch by patch.
      n = torch.zeros(batch_size, device=inputs.device)
      mu = torch.zeros(batch_size, device=inputs.device)
      sigma = torch.zeros(batch_size, device=inputs.device)
      patch_mu = []
      patch_sigma = []
      for i in range(num_input_patches):
        (n, mu, sigma), _ = util.update_running_stats(
            n, mu, sigma, patched_inputs[:, i], patched_masks[:, i]
        )
        patch_mu.append(mu)
        patch_sigma.append(sigma)
      last_n, last_mu, last_sigma = n, mu, sigma
      context_mu = torch.stack(patch_mu, dim=1)
      context_sigma = torch.stack(patch_sigma, dim=1)

      # One zero-initialized cache per layer, sized for prefill + all AR steps.
      decode_caches = [
          util.DecodeCache(
              next_index=torch.zeros(
                  batch_size, dtype=torch.int32, device=inputs.device
              ),
              num_masked=torch.zeros(
                  batch_size, dtype=torch.int32, device=inputs.device
              ),
              key=torch.zeros(
                  batch_size,
                  decode_cache_size,
                  self.h,
                  self.hd,
                  device=inputs.device,
              ),
              value=torch.zeros(
                  batch_size,
                  decode_cache_size,
                  self.h,
                  self.hd,
                  device=inputs.device,
              ),
          )
          for _ in range(self.x)
      ]

      normed_inputs = revin(patched_inputs, context_mu, context_sigma, reverse=False)
      normed_inputs = torch.where(patched_masks, 0.0, normed_inputs)
      (_, _, normed_outputs, normed_quantile_spread), decode_caches = self(
          normed_inputs, patched_masks, decode_caches
      )
      renormed_outputs = torch.reshape(
          revin(normed_outputs, context_mu, context_sigma, reverse=True),
          (batch_size, -1, self.o, self.q),
      )
      renormed_quantile_spread = torch.reshape(
          revin(normed_quantile_spread, context_mu, context_sigma, reverse=True),
          (batch_size, -1, self.os, self.q),
      )[:, -1, ...]

      # Autoregressive decode: feed channel `self.aridx` of the last output
      # patch back in as the next `self.m` input patches.
      ar_outputs = []
      last_renormed_output = renormed_outputs[:, -1, :, self.aridx]

      for _ in range(num_decode_steps):
        new_patched_input = torch.reshape(
            last_renormed_output, (batch_size, self.m, self.p)
        )
        new_mask = torch.zeros_like(new_patched_input, dtype=torch.bool)

        n, mu, sigma = last_n, last_mu, last_sigma
        new_mus, new_sigmas = [], []
        for i in range(self.m):
          (n, mu, sigma), _ = util.update_running_stats(
              n, mu, sigma, new_patched_input[:, i], new_mask[:, i]
          )
          new_mus.append(mu)
          new_sigmas.append(sigma)
        last_n, last_mu, last_sigma = n, mu, sigma
        new_mu = torch.stack(new_mus, dim=1)
        new_sigma = torch.stack(new_sigmas, dim=1)

        new_normed_input = revin(new_patched_input, new_mu, new_sigma, reverse=False)
        (_, _, new_normed_output, _), decode_caches = self(
            new_normed_input, new_mask, decode_caches
        )

        new_renormed_output = torch.reshape(
            revin(new_normed_output, new_mu, new_sigma, reverse=True),
            (batch_size, self.m, self.o, self.q),
        )
        ar_outputs.append(new_renormed_output[:, -1, ...])
        last_renormed_output = new_renormed_output[:, -1, :, self.aridx]

      if num_decode_steps > 0:
        ar_renormed_outputs = torch.stack(ar_outputs, dim=1)
      else:
        ar_renormed_outputs = None

    return renormed_outputs, renormed_quantile_spread, ar_renormed_outputs

  def forecast_naive(
      self, horizon: int, inputs: Sequence[np.ndarray]
  ) -> list[np.ndarray]:
    """Forecasts the time series.

    This is a naive implementation for debugging purposes. No forecasting
    flags are used here. Forecasting quality can be subpar.

    Args:
      horizon: The number of time points to forecast.
      inputs: A sequence of numpy arrays, each representing a time series to
        query forecast for. Series whose length is not a multiple of the patch
        length are front-padded with masked zeros.

    Returns:
      A list of numpy arrays of forecasts, one per input, each of shape
      (horizon, self.q).
    """
    # Inputs must live where the weights are (CUDA after `load_checkpoint`,
    # CPU otherwise); `decode` allocates its caches on the inputs' device.
    device = next(self.parameters()).device
    outputs = []
    for each_input in inputs:
      input_t = torch.tensor(each_input, dtype=torch.float32)
      mask = torch.zeros_like(input_t, dtype=torch.bool)
      len_front_mask = self.p - (len(each_input) % self.p)
      if len_front_mask < self.p:
        input_t = torch.cat(
            [torch.zeros(len_front_mask, dtype=torch.float32), input_t], dim=0
        )
        mask = torch.cat([torch.ones(len_front_mask, dtype=torch.bool), mask], dim=0)
      input_t = input_t[None, ...].to(device)
      mask = mask[None, ...].to(device)
      t_pf, _, t_ar = self.decode(horizon, input_t, mask)
      to_concat = [t_pf[:, -1, ...]]
      if t_ar is not None:
        to_concat.append(t_ar.reshape(1, -1, self.q))
      # Truncate the time axis (dim 1); the last axis holds the q outputs.
      torch_forecast = torch.cat(to_concat, dim=1)[:, :horizon, :]
      torch_forecast = torch_forecast.squeeze(0)
      outputs.append(torch_forecast.detach().cpu().numpy())
    return outputs
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
class TimesFM_2p5_200M_torch(timesfm_2p5_base.TimesFM_2p5, ModelHubMixin):
  """PyTorch implementation of TimesFM 2.5 with 200M parameters."""

  # NOTE(review): this module is constructed once, when the class body runs,
  # and is shared by all instances unless the base class assigns a
  # per-instance `model` (not visible here) — confirm before relying on
  # multiple independent instances.
  model: nn.Module = TimesFM_2p5_200M_torch_module()

  @classmethod
  def _from_pretrained(
      cls,
      *,
      model_id: str = "google/timesfm-2.5-200m-pytorch",
      revision: Optional[str],
      cache_dir: Optional[Union[str, Path]],
      force_download: bool = True,
      proxies: Optional[Dict] = None,
      resume_download: Optional[bool] = None,
      local_files_only: bool,
      token: Optional[Union[str, bool]],
      **model_kwargs,
  ):
    """
    Loads a PyTorch safetensors TimesFM model from a local path or the Hugging
    Face Hub. This method is the backend for the `from_pretrained` class
    method provided by `ModelHubMixin`.

    If `model_id` is an existing directory, `model.safetensors` is read from
    it; otherwise the file is downloaded from the Hub repo `model_id`.
    `model_kwargs` go both to the constructor and to `load_checkpoint`.

    Raises:
      FileNotFoundError: If `model_id` is a directory without
        `model.safetensors`.
    """
    # Create an instance of the model wrapper class.
    instance = cls(**model_kwargs)

    # Download the config file for hf tracking. This is best-effort only:
    # it is skipped in offline mode and a failure must not prevent loading
    # weights (e.g. from a local directory without network access).
    if not local_files_only:
      try:
        hf_hub_download(
            repo_id="google/timesfm-2.5-200m-pytorch",
            filename="config.json",
            force_download=True,
            proxies=proxies,
        )
        logging.info("Downloaded.")
      except Exception as err:  # pylint: disable=broad-except
        logging.warning("Skipping config.json tracking download: %s", err)

    # Determine the path to the model weights.
    if os.path.isdir(model_id):
      logging.info("Loading checkpoint from local directory: %s", model_id)
      model_file_path = os.path.join(model_id, "model.safetensors")
      if not os.path.exists(model_file_path):
        raise FileNotFoundError(f"model.safetensors not found in directory {model_id}")
    else:
      logging.info("Downloading checkpoint from Hugging Face repo %s", model_id)
      model_file_path = hf_hub_download(
          repo_id=model_id,
          filename="model.safetensors",
          revision=revision,
          cache_dir=cache_dir,
          force_download=force_download,
          proxies=proxies,
          resume_download=resume_download,
          token=token,
          local_files_only=local_files_only,
      )

    logging.info("Loading checkpoint from: %s", model_file_path)
    # Load the weights into the model.
    instance.model.load_checkpoint(model_file_path, **model_kwargs)
    return instance

  def _save_pretrained(self, save_directory: Union[str, Path]):
    """
    Saves the model's state dictionary to a safetensors file. This method
    is called by the `save_pretrained` method from `ModelHubMixin`.

    The weights are written to `<save_directory>/model.safetensors`; the
    directory is created if needed.
    """
    os.makedirs(save_directory, exist_ok=True)
    weights_path = os.path.join(save_directory, "model.safetensors")
    save_file(self.model.state_dict(), weights_path)

  def compile(self, forecast_config: configs.ForecastConfig, **kwargs) -> None:
    """Attempts to compile the model for fast decoding.

    Validates and adjusts the config, stores it on `self.forecast_config`,
    sets `self.global_batch_size` and installs `self.compiled_decode`.

    See configs.ForecastConfig for more details on the supported flags.

    Args:
      forecast_config: Configuration for forecasting flags. `max_context` is
        rounded up to a multiple of the input patch length and `max_horizon`
        to a multiple of the output patch length.
      **kwargs: Accepted for interface compatibility; currently unused.

    Raises:
      ValueError: If the adjusted max_context + max_horizon exceeds the
        model's context limit, or if the continuous quantile head is requested
        with a max_horizon larger than the quantile output length.
    """
    self.global_batch_size = (
        forecast_config.per_core_batch_size * self.model.device_count
    )

    # Shortcut.
    fc = forecast_config

    if fc.max_context % self.model.p != 0:
      logging.info(
          "When compiling, max context needs to be multiple of the patch size"
          " %d. Using max context = %d instead.",
          self.model.p,
          new_context := math.ceil(fc.max_context / self.model.p) * self.model.p,
      )
      fc = dataclasses.replace(fc, max_context=new_context)
    if fc.max_horizon % self.model.o != 0:
      logging.info(
          "When compiling, max horizon needs to be multiple of the output patch"
          " size %d. Using max horizon = %d instead.",
          self.model.o,
          new_horizon := math.ceil(fc.max_horizon / self.model.o) * self.model.o,
      )
      fc = dataclasses.replace(fc, max_horizon=new_horizon)
    if fc.max_context + fc.max_horizon > self.model.config.context_limit:
      raise ValueError(
          "Context + horizon must be less than the context limit."
          f" {fc.max_context} + {fc.max_horizon} >"
          f" {self.model.config.context_limit}."
      )
    if fc.use_continuous_quantile_head and (fc.max_horizon > self.model.os):
      raise ValueError(
          f"Continuous quantile head is not supported for horizons > {self.model.os}."
      )
    self.forecast_config = fc

    def flip_quantile_fn(x):
      # Keep channel 0 and reverse the order of the remaining channels, so the
      # outputs for a negated series line up with the original's (presumably
      # channels 1.. are quantiles in increasing order — confirm in configs).
      return torch.cat([x[..., :1], torch.flip(x[..., 1:], dims=(-1,))], dim=-1)

    def _assemble_forecast(pf_outputs, ar_outputs, batch_size):
      # Last prefill output patch followed by all autoregressive patches,
      # flattened to (batch, time, q).
      to_cat = [pf_outputs[:, -1, ...]]
      if ar_outputs is not None:
        to_cat.append(ar_outputs.reshape(batch_size, -1, self.model.q))
      return torch.cat(to_cat, dim=1)

    def _compiled_decode(horizon, inputs, masks):
      """Forecasts `horizon` points for a batch of equally-sized contexts.

      Args:
        horizon: Number of points to return; must not exceed fc.max_horizon.
        inputs: Array-like of shape (batch, context).
        masks: Array-like boolean padding mask shaped like `inputs`.

      Returns:
        `(point_forecast, full_forecast)` as numpy arrays, where
        `full_forecast` is (batch, time, q) and `point_forecast` is its
        channel 5. With `return_backcast`, the time axis is prefixed by the
        backcast.
      """
      if horizon > fc.max_horizon:
        raise ValueError(
            f"Horizon must be less than the max horizon. {horizon} > {fc.max_horizon}."
        )

      inputs = (
          torch.from_numpy(np.array(inputs)).to(self.model.device).to(torch.float32)
      )
      masks = torch.from_numpy(np.array(masks)).to(self.model.device).to(torch.bool)
      batch_size = inputs.shape[0]

      if fc.infer_is_positive:
        is_positive = torch.all(inputs >= 0, dim=-1, keepdim=True)
      else:
        is_positive = None

      if fc.normalize_inputs:
        mu = torch.mean(inputs, dim=-1, keepdim=True)
        sigma = torch.std(inputs, dim=-1, keepdim=True)
        inputs = revin(inputs, mu, sigma, reverse=False)
      else:
        mu, sigma = None, None

      pf_outputs, quantile_spreads, ar_outputs = self.model.decode(
          fc.max_horizon, inputs, masks
      )
      full_forecast = _assemble_forecast(pf_outputs, ar_outputs, batch_size)

      if fc.force_flip_invariance:
        flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs = (
            self.model.decode(fc.max_horizon, -inputs, masks)
        )
        flipped_quantile_spreads = flip_quantile_fn(flipped_quantile_spreads)
        flipped_pf_outputs = flip_quantile_fn(flipped_pf_outputs)
        if flipped_ar_outputs is not None:
          # The autoregressive outputs need the same channel flip as the
          # prefill outputs before both are combined with the originals.
          flipped_ar_outputs = flip_quantile_fn(flipped_ar_outputs)
        flipped_full_forecast = _assemble_forecast(
            flipped_pf_outputs, flipped_ar_outputs, batch_size
        )
        quantile_spreads = (quantile_spreads - flipped_quantile_spreads) / 2
        pf_outputs = (pf_outputs - flipped_pf_outputs) / 2
        full_forecast = (full_forecast - flipped_full_forecast) / 2

      if fc.use_continuous_quantile_head:
        for quantile_index in [1, 2, 3, 4, 6, 7, 8, 9]:
          full_forecast[:, :, quantile_index] = (
              quantile_spreads[:, : fc.max_horizon, quantile_index]
              - quantile_spreads[:, : fc.max_horizon, 5]
              + full_forecast[:, : fc.max_horizon, 5]
          )
      full_forecast = full_forecast[:, :horizon, :]

      if fc.return_backcast:
        full_backcast = pf_outputs[:, :-1, : self.model.p, :].reshape(
            batch_size, -1, self.model.q
        )
        full_forecast = torch.cat([full_backcast, full_forecast], dim=1)

      if fc.fix_quantile_crossing:
        # Enforce monotonicity outward from channel 5.
        for i in [4, 3, 2, 1]:
          full_forecast[:, :, i] = torch.where(
              full_forecast[:, :, i] < full_forecast[:, :, i + 1],
              full_forecast[:, :, i],
              full_forecast[:, :, i + 1],
          )
        for i in [6, 7, 8, 9]:
          full_forecast[:, :, i] = torch.where(
              full_forecast[:, :, i] > full_forecast[:, :, i - 1],
              full_forecast[:, :, i],
              full_forecast[:, :, i - 1],
          )

      if fc.normalize_inputs:
        full_forecast = revin(full_forecast, mu, sigma, reverse=True)

      if is_positive is not None:
        full_forecast = torch.where(
            is_positive[..., None],
            torch.maximum(full_forecast, torch.zeros_like(full_forecast)),
            full_forecast,
        )

      full_forecast = full_forecast.detach().cpu().numpy()
      return full_forecast[..., 5], full_forecast

    self.compiled_decode = _compiled_decode
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
timesfm/torch/dense.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Dense layers for TimesFM."""
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from torch import nn
|
|
19
|
+
|
|
20
|
+
from .. import configs
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ResidualBlock(nn.Module):
  """Two-layer MLP with an additive linear skip connection.

  Computes `output_layer(activation(hidden_layer(x))) + residual_layer(x)`,
  where all three linear layers share `config.use_bias`.
  """

  def __init__(self, config: configs.ResidualBlockConfig):
    super().__init__()
    self.config = config
    bias = config.use_bias

    # Creation order is kept stable (hidden, output, residual) so parameter
    # registration and random initialization order do not change.
    self.hidden_layer = nn.Linear(config.input_dims, config.hidden_dims, bias=bias)
    self.output_layer = nn.Linear(config.hidden_dims, config.output_dims, bias=bias)
    self.residual_layer = nn.Linear(config.input_dims, config.output_dims, bias=bias)

    # Map the configured activation name onto its module class.
    supported = (("relu", nn.ReLU), ("swish", nn.SiLU), ("none", nn.Identity))
    for name, factory in supported:
      if config.activation == name:
        self.activation = factory()
        break
    else:
      raise ValueError(f"Activation: {config.activation} not supported.")

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    hidden = self.activation(self.hidden_layer(x))
    main_path = self.output_layer(hidden)
    return main_path + self.residual_layer(x)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class RandomFourierFeatures(nn.Module):
  """Random Fourier features layer.

  Projects the input to `output_dims // 4` values `z`, emits
  `[cos(z), sin(z), sign(sin(z + s0)), sign(sin(z + s1))]` along the last
  axis (with learned phase shifts s0, s1), and adds a linear skip of the input.
  """

  def __init__(self, config: configs.RandomFourierFeaturesConfig):
    super().__init__()
    self.config = config

    out_dims = config.output_dims
    if out_dims % 4 != 0:
      raise ValueError(
          f"Output dims must be a multiple of 4: {config.output_dims} % 4 != 0."
      )
    quarter = out_dims // 4

    # Registration order matches the checkpoint layout: phase shifts,
    # projection, residual.
    self.phase_shifts = nn.Parameter(torch.zeros(2, quarter))
    self.projection_layer = nn.Linear(
        config.input_dims, quarter, bias=config.use_bias
    )
    self.residual_layer = nn.Linear(
        config.input_dims, out_dims, bias=config.use_bias
    )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    z = self.projection_layer(x)
    first_shift = self.phase_shifts[0, :]
    second_shift = self.phase_shifts[1, :]
    parts = [
        torch.cos(z),
        torch.sin(z),
        torch.sign(torch.sin(z + first_shift)),
        torch.sign(torch.sin(z + second_shift)),
    ]
    return torch.cat(parts, dim=-1) + self.residual_layer(x)
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Normalization layers for TimesFM."""
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from torch import nn
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class RMSNorm(nn.Module):
  """Root-mean-square normalization over the last axis.

  Computes `x / sqrt(mean(x**2, axis=-1) + epsilon) * scale`. `scale` is
  initialized to zeros and used as a direct multiplier, so the module outputs
  zeros until trained or loaded weights are assigned.
  """

  def __init__(
      self,
      num_features: int,
      *,
      epsilon: float = 1e-6,
  ):
    super().__init__()
    self.scale = nn.Parameter(torch.zeros(num_features))
    self.num_features = num_features
    self.epsilon = epsilon

  def forward(self, inputs: torch.Tensor) -> torch.Tensor:
    mean_square = inputs.square().mean(dim=-1, keepdim=True)
    inv_rms = (mean_square + self.epsilon).rsqrt()
    return inputs * inv_rms * self.scale
|