tsagentkit-timesfm 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,472 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """TimesFM models."""
15
+
16
+ import dataclasses
17
+ import logging
18
+ import math
19
+ import os
20
+ from pathlib import Path
21
+ from typing import Dict, Optional, Sequence, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+ from huggingface_hub import ModelHubMixin, hf_hub_download
26
+ from safetensors.torch import load_file, save_file
27
+ from torch import nn
28
+
29
+ from .. import configs
30
+ from ..torch import dense, transformer, util
31
+ from . import timesfm_2p5_base
32
+
33
+ revin = util.revin
34
+
35
+
36
+ class TimesFM_2p5_200M_torch_module(nn.Module):
+   """TimesFM 2.5 with 200M parameters."""
+ 
+   config = timesfm_2p5_base.TimesFM_2p5_200M_Definition()
+ 
+   def __init__(self):
+     super().__init__()
+ 
+     # Named constants.
+     self.p = self.config.input_patch_len  # 32
+     self.o = self.config.output_patch_len  # 128
+     self.os = self.config.output_quantile_len  # 1024
+     self.m = self.o // self.p  # 4
+     self.x = self.config.stacked_transformers.num_layers  # 20
+     self.h = self.config.stacked_transformers.transformer.num_heads  # 16
+     self.md = self.config.stacked_transformers.transformer.model_dims  # 1280
+     self.hd = self.md // self.h  # 80
+     self.q = len(self.config.quantiles) + 1  # 10
+     self.aridx = self.config.decode_index  # 5
+ 
+     # Layers.
+     self.tokenizer = dense.ResidualBlock(self.config.tokenizer)
+     self.stacked_xf = nn.ModuleList(
+         [
+             transformer.Transformer(self.config.stacked_transformers.transformer)
+             for _ in range(self.x)
+         ]
+     )
+     self.output_projection_point = dense.ResidualBlock(
+         self.config.output_projection_point
+     )
+     self.output_projection_quantiles = dense.ResidualBlock(
+         self.config.output_projection_quantiles
+     )
+ 
+     # Device.
+     if torch.cuda.is_available():
+       self.device = torch.device("cuda:0")
+       self.device_count = torch.cuda.device_count()
+     else:
+       self.device = torch.device("cpu")
+       self.device_count = 1
+ 
+   def load_checkpoint(self, path: str, **kwargs):
+     """Loads a PyTorch TimesFM model from a checkpoint."""
+     tensors = load_file(path)
+     self.load_state_dict(tensors, strict=True)
+     self.to(self.device)
+     torch_compile = kwargs.get("torch_compile", True)
+     if torch_compile:
+       print("Compiling model...")
+       # Compile in place. Rebinding the local name (`self = torch.compile(self)`)
+       # would discard the compiled module as soon as this method returns.
+       self.compile()
+ 
+     self.eval()
+ 
+   def forward(
+       self,
+       inputs: torch.Tensor,
+       masks: torch.Tensor,
+       decode_caches: list[util.DecodeCache] | None = None,
+   ):
+     tokenizer_inputs = torch.cat([inputs, masks.to(inputs.dtype)], dim=-1)
+     input_embeddings = self.tokenizer(tokenizer_inputs)
+ 
+     if decode_caches is None:
+       decode_caches = [None] * self.x
+ 
+     output_embeddings = input_embeddings
+     new_decode_caches = []
+     for i, layer in enumerate(self.stacked_xf):
+       output_embeddings, new_cache = layer(
+           output_embeddings, masks[..., -1], decode_caches[i]
+       )
+       new_decode_caches.append(new_cache)
+     output_ts = self.output_projection_point(output_embeddings)
+     output_quantile_spread = self.output_projection_quantiles(output_embeddings)
+ 
+     return (
+         input_embeddings,
+         output_embeddings,
+         output_ts,
+         output_quantile_spread,
+     ), new_decode_caches
+ 
+   def decode(self, horizon: int, inputs, masks):
+     """Decodes the time series."""
+ 
+     with torch.no_grad():
+       batch_size, context = inputs.shape[0], inputs.shape[1]
+       num_decode_steps = (horizon - 1) // self.o
+       num_input_patches = context // self.p
+       decode_cache_size = num_input_patches + num_decode_steps * self.m
+ 
+       # Prefill.
+       patched_inputs = torch.reshape(inputs, (batch_size, -1, self.p))
+       patched_masks = torch.reshape(masks, (batch_size, -1, self.p))
+ 
+       # Running stats.
+       n = torch.zeros(batch_size, device=inputs.device)
+       mu = torch.zeros(batch_size, device=inputs.device)
+       sigma = torch.zeros(batch_size, device=inputs.device)
+       patch_mu = []
+       patch_sigma = []
+       for i in range(num_input_patches):
+         (n, mu, sigma), _ = util.update_running_stats(
+             n, mu, sigma, patched_inputs[:, i], patched_masks[:, i]
+         )
+         patch_mu.append(mu)
+         patch_sigma.append(sigma)
+       last_n, last_mu, last_sigma = n, mu, sigma
+       context_mu = torch.stack(patch_mu, dim=1)
+       context_sigma = torch.stack(patch_sigma, dim=1)
+ 
+       decode_caches = [
+           util.DecodeCache(
+               next_index=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device),
+               num_masked=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device),
+               key=torch.zeros(
+                   batch_size,
+                   decode_cache_size,
+                   self.h,
+                   self.hd,
+                   device=inputs.device,
+               ),
+               value=torch.zeros(
+                   batch_size,
+                   decode_cache_size,
+                   self.h,
+                   self.hd,
+                   device=inputs.device,
+               ),
+           )
+           for _ in range(self.x)
+       ]
+ 
+       normed_inputs = revin(patched_inputs, context_mu, context_sigma, reverse=False)
+       normed_inputs = torch.where(patched_masks, 0.0, normed_inputs)
+       (_, _, normed_outputs, normed_quantile_spread), decode_caches = self(
+           normed_inputs, patched_masks, decode_caches
+       )
+       renormed_outputs = torch.reshape(
+           revin(normed_outputs, context_mu, context_sigma, reverse=True),
+           (batch_size, -1, self.o, self.q),
+       )
+       renormed_quantile_spread = torch.reshape(
+           revin(normed_quantile_spread, context_mu, context_sigma, reverse=True),
+           (batch_size, -1, self.os, self.q),
+       )[:, -1, ...]
+ 
+       # Autoregressive decode.
+       ar_outputs = []
+       last_renormed_output = renormed_outputs[:, -1, :, self.aridx]
+ 
+       for _ in range(num_decode_steps):
+         new_patched_input = torch.reshape(
+             last_renormed_output, (batch_size, self.m, self.p)
+         )
+         new_mask = torch.zeros_like(new_patched_input, dtype=torch.bool)
+ 
+         n, mu, sigma = last_n, last_mu, last_sigma
+         new_mus, new_sigmas = [], []
+         for i in range(self.m):
+           (n, mu, sigma), _ = util.update_running_stats(
+               n, mu, sigma, new_patched_input[:, i], new_mask[:, i]
+           )
+           new_mus.append(mu)
+           new_sigmas.append(sigma)
+         last_n, last_mu, last_sigma = n, mu, sigma
+         new_mu = torch.stack(new_mus, dim=1)
+         new_sigma = torch.stack(new_sigmas, dim=1)
+ 
+         new_normed_input = revin(new_patched_input, new_mu, new_sigma, reverse=False)
+         (_, _, new_normed_output, _), decode_caches = self(
+             new_normed_input, new_mask, decode_caches
+         )
+ 
+         new_renormed_output = torch.reshape(
+             revin(new_normed_output, new_mu, new_sigma, reverse=True),
+             (batch_size, self.m, self.o, self.q),
+         )
+         ar_outputs.append(new_renormed_output[:, -1, ...])
+         last_renormed_output = new_renormed_output[:, -1, :, self.aridx]
+ 
+       if num_decode_steps > 0:
+         ar_renormed_outputs = torch.stack(ar_outputs, dim=1)
+       else:
+         ar_renormed_outputs = None
+ 
+       return renormed_outputs, renormed_quantile_spread, ar_renormed_outputs
+ 
+   def forecast_naive(
+       self, horizon: int, inputs: Sequence[np.ndarray]
+   ) -> list[np.ndarray]:
+     """Forecasts the time series.
+ 
+     This is a naive implementation for debugging purposes. No forecasting
+     flags are used here, so forecast quality can be subpar.
+ 
+     Args:
+       horizon: The number of time points to forecast.
+       inputs: A sequence of numpy arrays, each representing a time series to
+         forecast.
+ 
+     Returns:
+       A list of numpy arrays of forecasts.
+     """
+     outputs = []
+     for each_input in inputs:
+       input_t = torch.tensor(each_input, dtype=torch.float32)
+       mask = torch.zeros_like(input_t, dtype=torch.bool)
+       # Front-pad (and mask) the context so its length is a multiple of the
+       # input patch length.
+       len_front_mask = self.p - (len(each_input) % self.p)
+       if len_front_mask < self.p:
+         input_t = torch.cat(
+             [torch.zeros(len_front_mask, dtype=torch.float32), input_t], dim=0
+         )
+         mask = torch.cat([torch.ones(len_front_mask, dtype=torch.bool), mask], dim=0)
+       input_t = input_t[None, ...].to(self.device)
+       mask = mask[None, ...].to(self.device)
+       t_pf, _, t_ar = self.decode(horizon, input_t, mask)
+       to_concat = [t_pf[:, -1, ...]]
+       if t_ar is not None:
+         to_concat.append(t_ar.reshape(1, -1, self.q))
+       # Trim along the time axis; the last axis holds the quantile columns.
+       torch_forecast = torch.cat(to_concat, dim=1)[:, :horizon, :]
+       torch_forecast = torch_forecast.squeeze(0)
+       outputs.append(torch_forecast.detach().cpu().numpy())
+     return outputs
+ 
+ 
+ class TimesFM_2p5_200M_torch(timesfm_2p5_base.TimesFM_2p5, ModelHubMixin):
+   """PyTorch implementation of TimesFM 2.5 with 200M parameters."""
+ 
+   model: nn.Module = TimesFM_2p5_200M_torch_module()
+ 
+   @classmethod
+   def _from_pretrained(
+       cls,
+       *,
+       model_id: str = "google/timesfm-2.5-200m-pytorch",
+       revision: Optional[str],
+       cache_dir: Optional[Union[str, Path]],
+       force_download: bool = True,
+       proxies: Optional[Dict] = None,
+       resume_download: Optional[bool] = None,
+       local_files_only: bool,
+       token: Optional[Union[str, bool]],
+       **model_kwargs,
+   ):
+     """Loads a PyTorch safetensors TimesFM model from a local path or the
+     Hugging Face Hub.
+ 
+     This method is the backend for the `from_pretrained` class method
+     provided by `ModelHubMixin`.
+     """
+     # Create an instance of the model wrapper class.
+     instance = cls(**model_kwargs)
+     # Download the config file for hf tracking.
+     _ = hf_hub_download(
+         repo_id="google/timesfm-2.5-200m-pytorch",
+         filename="config.json",
+         force_download=True,
+     )
+     print("Downloaded.")
+ 
+     # Determine the path to the model weights.
+     model_file_path = ""
+     if os.path.isdir(model_id):
+       logging.info("Loading checkpoint from local directory: %s", model_id)
+       model_file_path = os.path.join(model_id, "model.safetensors")
+       if not os.path.exists(model_file_path):
+         raise FileNotFoundError(f"model.safetensors not found in directory {model_id}")
+     else:
+       logging.info("Downloading checkpoint from Hugging Face repo %s", model_id)
+       model_file_path = hf_hub_download(
+           repo_id=model_id,
+           filename="model.safetensors",
+           revision=revision,
+           cache_dir=cache_dir,
+           force_download=force_download,
+           proxies=proxies,
+           resume_download=resume_download,
+           token=token,
+           local_files_only=local_files_only,
+       )
+ 
+     logging.info("Loading checkpoint from: %s", model_file_path)
+     # Load the weights into the model.
+     instance.model.load_checkpoint(model_file_path, **model_kwargs)
+     return instance
+ 
+   def _save_pretrained(self, save_directory: Union[str, Path]):
+     """Saves the model's state dictionary to a safetensors file.
+ 
+     This method is called by the `save_pretrained` method from
+     `ModelHubMixin`.
+     """
+     if not os.path.exists(save_directory):
+       os.makedirs(save_directory)
+ 
+     weights_path = os.path.join(save_directory, "model.safetensors")
+     save_file(self.model.state_dict(), weights_path)
+ 
+   def compile(self, forecast_config: configs.ForecastConfig, **kwargs) -> None:
+     """Attempts to compile the model for fast decoding.
+ 
+     See configs.ForecastConfig for more details on the supported flags.
+ 
+     Args:
+       forecast_config: Configuration for forecasting flags.
+       **kwargs: Additional keyword arguments to pass to model.compile().
+     """
+     self.global_batch_size = (
+         forecast_config.per_core_batch_size * self.model.device_count
+     )
+ 
+     # Shortcut.
+     fc = forecast_config
+ 
+     if fc.max_context % self.model.p != 0:
+       logging.info(
+           "When compiling, max context needs to be a multiple of the patch"
+           " size %d. Using max context = %d instead.",
+           self.model.p,
+           new_context := math.ceil(fc.max_context / self.model.p) * self.model.p,
+       )
+       fc = dataclasses.replace(fc, max_context=new_context)
+     if fc.max_horizon % self.model.o != 0:
+       logging.info(
+           "When compiling, max horizon needs to be a multiple of the output"
+           " patch size %d. Using max horizon = %d instead.",
+           self.model.o,
+           new_horizon := math.ceil(fc.max_horizon / self.model.o) * self.model.o,
+       )
+       fc = dataclasses.replace(fc, max_horizon=new_horizon)
+     if fc.max_context + fc.max_horizon > self.model.config.context_limit:
+       raise ValueError(
+           "Context + horizon must not exceed the context limit:"
+           f" {fc.max_context} + {fc.max_horizon} >"
+           f" {self.model.config.context_limit}."
+       )
+     if fc.use_continuous_quantile_head and (fc.max_horizon > self.model.os):
+       raise ValueError(
+           f"Continuous quantile head is not supported for horizons > {self.model.os}."
+       )
+     self.forecast_config = fc
+ 
+     def _compiled_decode(horizon, inputs, masks):
+       if horizon > fc.max_horizon:
+         raise ValueError(
+             f"Horizon must not exceed the max horizon: {horizon} > {fc.max_horizon}."
+         )
+ 
+       inputs = (
+           torch.from_numpy(np.array(inputs)).to(self.model.device).to(torch.float32)
+       )
+       masks = torch.from_numpy(np.array(masks)).to(self.model.device).to(torch.bool)
+       batch_size = inputs.shape[0]
+ 
+       if fc.infer_is_positive:
+         is_positive = torch.all(inputs >= 0, dim=-1, keepdim=True)
+       else:
+         is_positive = None
+ 
+       if fc.normalize_inputs:
+         mu = torch.mean(inputs, dim=-1, keepdim=True)
+         sigma = torch.std(inputs, dim=-1, keepdim=True)
+         inputs = revin(inputs, mu, sigma, reverse=False)
+       else:
+         mu, sigma = None, None
+ 
+       pf_outputs, quantile_spreads, ar_outputs = self.model.decode(
+           fc.max_horizon, inputs, masks
+       )
+       to_cat = [pf_outputs[:, -1, ...]]
+       if ar_outputs is not None:
+         to_cat.append(ar_outputs.reshape(batch_size, -1, self.model.q))
+       full_forecast = torch.cat(to_cat, dim=1)
+ 
+       def flip_quantile_fn(x):
+         # Quantile levels reverse order under negation: the forecast of the
+         # q-th quantile of -x corresponds to the (1 - q)-th quantile of x.
+         return torch.cat([x[..., :1], torch.flip(x[..., 1:], dims=(-1,))], dim=-1)
+ 
+       if fc.force_flip_invariance:
+         flipped_pf_outputs, flipped_quantile_spreads, flipped_ar_outputs = (
+             self.model.decode(fc.max_horizon, -inputs, masks)
+         )
+         flipped_quantile_spreads = flip_quantile_fn(flipped_quantile_spreads)
+         flipped_pf_outputs = flip_quantile_fn(flipped_pf_outputs)
+         to_cat = [flipped_pf_outputs[:, -1, ...]]
+         if flipped_ar_outputs is not None:
+           to_cat.append(flipped_ar_outputs.reshape(batch_size, -1, self.model.q))
+         flipped_full_forecast = torch.cat(to_cat, dim=1)
+         # Average the forecast of x with the negated forecast of -x.
+         quantile_spreads = (quantile_spreads - flipped_quantile_spreads) / 2
+         pf_outputs = (pf_outputs - flipped_pf_outputs) / 2
+         full_forecast = (full_forecast - flipped_full_forecast) / 2
+ 
+       if fc.use_continuous_quantile_head:
+         for quantile_index in [1, 2, 3, 4, 6, 7, 8, 9]:
+           full_forecast[:, :, quantile_index] = (
+               quantile_spreads[:, : fc.max_horizon, quantile_index]
+               - quantile_spreads[:, : fc.max_horizon, 5]
+               + full_forecast[:, : fc.max_horizon, 5]
+           )
+       full_forecast = full_forecast[:, :horizon, :]
+ 
+       if fc.return_backcast:
+         full_backcast = pf_outputs[:, :-1, : self.model.p, :].reshape(
+             batch_size, -1, self.model.q
+         )
+         full_forecast = torch.cat([full_backcast, full_forecast], dim=1)
+ 
+       if fc.fix_quantile_crossing:
+         for i in [4, 3, 2, 1]:
+           full_forecast[:, :, i] = torch.where(
+               full_forecast[:, :, i] < full_forecast[:, :, i + 1],
+               full_forecast[:, :, i],
+               full_forecast[:, :, i + 1],
+           )
+         for i in [6, 7, 8, 9]:
+           full_forecast[:, :, i] = torch.where(
+               full_forecast[:, :, i] > full_forecast[:, :, i - 1],
+               full_forecast[:, :, i],
+               full_forecast[:, :, i - 1],
+           )
+ 
+       if fc.normalize_inputs:
+         full_forecast = revin(full_forecast, mu, sigma, reverse=True)
+ 
+       if is_positive is not None:
+         full_forecast = torch.where(
+             is_positive[..., None],
+             torch.maximum(full_forecast, torch.zeros_like(full_forecast)),
+             full_forecast,
+         )
+ 
+       full_forecast = full_forecast.detach().cpu().numpy()
+       return full_forecast[..., 5], full_forecast
+ 
+     self.compiled_decode = _compiled_decode
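Taken together, the intended flow is: load weights via `from_pretrained`, fix the decoding shapes with `compile`, then call `compiled_decode`. Below is a hypothetical end-to-end sketch; the `configs.ForecastConfig` constructor is not shown in this diff, so its signature here is assumed from the flags the code above reads.

    import numpy as np

    model = TimesFM_2p5_200M_torch.from_pretrained("google/timesfm-2.5-200m-pytorch")
    # Field names taken from the flags read in compile(); the full signature is assumed.
    fc = configs.ForecastConfig(
        max_context=512,        # rounded up to a multiple of p=32 if needed
        max_horizon=256,        # rounded up to a multiple of o=128 if needed
        per_core_batch_size=32,
        normalize_inputs=True,
        fix_quantile_crossing=True,
    )
    model.compile(fc)

    history = np.sin(np.arange(512) / 10.0)[None, :]  # (batch, context), context % 32 == 0
    masks = np.zeros_like(history, dtype=bool)        # nothing padded, nothing masked
    point, full = model.compiled_decode(128, history, masks)
    # point: (1, 128), taken from output column 5; full: (1, 128, 10) output columns.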
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
timesfm/torch/dense.py ADDED
@@ -0,0 +1,94 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ 
+ """Dense layers for TimesFM."""
+ 
+ import torch
+ from torch import nn
+ 
+ from .. import configs
+ 
+ 
+ class ResidualBlock(nn.Module):
+   """Residual block with two linear layers and a linear residual connection."""
+ 
+   def __init__(self, config: configs.ResidualBlockConfig):
+     super().__init__()
+     self.config = config
+     self.hidden_layer = nn.Linear(
+         in_features=config.input_dims,
+         out_features=config.hidden_dims,
+         bias=config.use_bias,
+     )
+     self.output_layer = nn.Linear(
+         in_features=config.hidden_dims,
+         out_features=config.output_dims,
+         bias=config.use_bias,
+     )
+     self.residual_layer = nn.Linear(
+         in_features=config.input_dims,
+         out_features=config.output_dims,
+         bias=config.use_bias,
+     )
+     if config.activation == "relu":
+       self.activation = nn.ReLU()
+     elif config.activation == "swish":
+       self.activation = nn.SiLU()
+     elif config.activation == "none":
+       self.activation = nn.Identity()
+     else:
+       raise ValueError(f"Activation: {config.activation} not supported.")
+ 
+   def forward(self, x: torch.Tensor) -> torch.Tensor:
+     return self.output_layer(
+         self.activation(self.hidden_layer(x))
+     ) + self.residual_layer(x)
+ 
+ 
+ class RandomFourierFeatures(nn.Module):
+   """Random Fourier features layer."""
+ 
+   def __init__(self, config: configs.RandomFourierFeaturesConfig):
+     super().__init__()
+     self.config = config
+ 
+     if config.output_dims % 4 != 0:
+       raise ValueError(
+           f"Output dims must be a multiple of 4: {config.output_dims} % 4 != 0."
+       )
+     num_projected_features = config.output_dims // 4
+ 
+     self.phase_shifts = nn.Parameter(torch.zeros(2, num_projected_features))
+     self.projection_layer = nn.Linear(
+         in_features=config.input_dims,
+         out_features=num_projected_features,
+         bias=config.use_bias,
+     )
+     self.residual_layer = nn.Linear(
+         in_features=config.input_dims,
+         out_features=config.output_dims,
+         bias=config.use_bias,
+     )
+ 
+   def forward(self, x: torch.Tensor) -> torch.Tensor:
+     projected = self.projection_layer(x)
+     cos_features = torch.cos(projected)
+     sin_features = torch.sin(projected)
+     sq_wave_1 = torch.sign(torch.sin(projected + self.phase_shifts[0, :]))
+     sq_wave_2 = torch.sign(torch.sin(projected + self.phase_shifts[1, :]))
+     fourier_features = torch.cat(
+         [cos_features, sin_features, sq_wave_1, sq_wave_2], dim=-1
+     )
+     residual = self.residual_layer(x)
+     return fourier_features + residual
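Both layers are thin compositions of nn.Linear. For orientation, a hypothetical construction sketch, assuming the config classes accept exactly the fields read in the constructors above and that the package is importable as `timesfm`:

    import torch
    from timesfm import configs          # assumed import path
    from timesfm.torch import dense

    # Assumed constructor fields; only the attributes read in __init__ are known.
    block = dense.ResidualBlock(configs.ResidualBlockConfig(
        input_dims=64, hidden_dims=128, output_dims=64,
        use_bias=True, activation="swish",
    ))
    x = torch.randn(2, 16, 64)
    print(block(x).shape)  # torch.Size([2, 16, 64]): SiLU MLP plus linear residual

    rff = dense.RandomFourierFeatures(configs.RandomFourierFeaturesConfig(
        input_dims=64, output_dims=256, use_bias=True,  # output_dims divisible by 4
    ))
    print(rff(x).shape)  # torch.Size([2, 16, 256]): [cos, sin, sq1, sq2] + residual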
@@ -0,0 +1,39 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ 
+ """Normalization layers for TimesFM."""
+ 
+ import torch
+ from torch import nn
+ 
+ 
+ class RMSNorm(nn.Module):
+   """RMS normalization."""
+ 
+   def __init__(
+       self,
+       num_features: int,
+       *,
+       epsilon: float = 1e-6,
+   ):
+     super().__init__()
+     # Learned gain, expected to be populated from a checkpoint. Note that a
+     # freshly constructed layer outputs zeros until `scale` is loaded.
+     self.scale = nn.Parameter(torch.zeros(num_features))
+     self.num_features = num_features
+     self.epsilon = epsilon
+ 
+   def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+     var = torch.mean(torch.square(inputs), dim=-1, keepdim=True)
+     normed_inputs = inputs * torch.rsqrt(var + self.epsilon)
+     normed_inputs = normed_inputs * self.scale
+     return normed_inputs
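Unlike LayerNorm, this RMSNorm does not subtract the mean; it only rescales by the root mean square of the features. A minimal standalone check of the arithmetic (the learned `scale` must be set to a nonzero value first, since it initializes to zeros):

    import torch

    norm = RMSNorm(4)
    with torch.no_grad():
      norm.scale.fill_(1.0)  # a checkpoint would normally populate this

    x = torch.tensor([[3.0, -3.0, 3.0, -3.0]])
    # rms(x) = sqrt(mean(x**2)) = 3, so the output is approximately x / 3.
    print(norm(x))  # tensor([[ 1., -1.,  1., -1.]])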