tsagentkit-timesfm 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,370 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Transformer layers for TimesFM."""
+
+ import math
+ from typing import Callable
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from .. import configs
+ from . import normalization, util
+
+ LayerNorm = nn.LayerNorm
+ RMSNorm = normalization.RMSNorm
+ DecodeCache = util.DecodeCache
+
+
+ def make_attn_mask(
+     query_length: int,
+     num_all_masked_kv: torch.Tensor,
+     query_index_offset: torch.Tensor | None = None,
+     kv_length: int = 0,
+ ) -> torch.Tensor:
+   """Builds a causal attention mask that also hides the leading masked kv positions."""
+   if kv_length == 0:
+     kv_length = query_length
+
+   q_index = torch.arange(query_length, device=num_all_masked_kv.device)[
+       None, None, :, None
+   ]
+   if query_index_offset is not None:
+     q_index = q_index + query_index_offset[:, None, None, None]
+   kv_index = torch.arange(kv_length, device=num_all_masked_kv.device)[
+       None, None, None, :
+   ]
+   return torch.logical_and(
+       q_index >= kv_index,
+       kv_index >= num_all_masked_kv[:, None, None, None],
+   )
+
+
+ class RotaryPositionalEmbedding(nn.Module):
+   """Rotary positional embedding."""
+
+   def __init__(
+       self,
+       embedding_dims: int,
+       min_timescale: float = 1.0,
+       max_timescale: float = 10000.0,
+   ):
+     super().__init__()
+     self.embedding_dims = embedding_dims
+     self.min_timescale = min_timescale
+     self.max_timescale = max_timescale
+
+   def forward(
+       self,
+       inputs: torch.Tensor,
+       position: torch.Tensor | None = None,
+   ):
+     """Applies rotary position embedding to `inputs` using sinusoids of different frequencies."""
+     if self.embedding_dims != inputs.shape[-1]:
+       raise ValueError(
+           "The embedding dims of the rotary position embedding "
+           "must match the hidden dimension of the inputs."
+       )
+     half_embedding_dim = self.embedding_dims // 2
+     fraction = (
+         2
+         * torch.arange(0, half_embedding_dim, device=inputs.device)
+         / self.embedding_dims
+     )
+     timescale = (
+         self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
+     ).to(inputs.device)
+     if position is None:
+       seq_length = inputs.shape[1]
+       position = torch.arange(seq_length, dtype=torch.float32, device=inputs.device)[
+           None, :
+       ]
+
+     if len(inputs.shape) == 4:
+       position = position[..., None, None]
+       timescale = timescale[None, None, None, :]
+     elif len(inputs.shape) == 3:
+       position = position[..., None]
+       timescale = timescale[None, None, :]
+     else:
+       raise ValueError("Inputs must be of rank 3 or 4.")
+
+     sinusoid_inp = position / timescale
+     sin = torch.sin(sinusoid_inp)
+     cos = torch.cos(sinusoid_inp)
+     first_half, second_half = torch.chunk(inputs, 2, dim=-1)
+     first_part = first_half * cos - second_half * sin
+     second_part = second_half * cos + first_half * sin
+     return torch.cat([first_part, second_part], dim=-1)
+
+
+ def _dot_product_attention(
+     query,
+     key,
+     value,
+     mask=None,
+ ):
+   """Computes dot-product attention (without 1/sqrt(d_k) scaling) given query, key, and value."""
+   attn_weights = torch.einsum("...qhd,...khd->...hqk", query, key)
+   if mask is not None:
+     attn_weights = torch.where(
+         mask, attn_weights, -torch.finfo(attn_weights.dtype).max / 2
+     )
+
+   attn_weights = F.softmax(attn_weights, dim=-1)
+
+   return torch.einsum("...hqk,...khd->...qhd", attn_weights, value)
+
+
+ def _torch_dot_product_attention(query, key, value, mask=None):
+   """Same (unscaled) attention as _dot_product_attention, computed with the fused
+   F.scaled_dot_product_attention kernel.
+   """
+   # 1. Permute inputs from (B, L, H, D) to the expected (B, H, L, D).
+   query = query.permute(0, 2, 1, 3)
+   key = key.permute(0, 2, 1, 3)
+   value = value.permute(0, 2, 1, 3)
+
+   # 2. Call the fused attention kernel:
+   #    - pass the boolean mask as `attn_mask`,
+   #    - set `scale=1.0` to disable the default 1/sqrt(d_k) scaling.
+   output = F.scaled_dot_product_attention(query, key, value, attn_mask=mask, scale=1.0)
+
+   # 3. Permute the output back to the original (B, L, H, D) layout.
+   output = output.permute(0, 2, 1, 3)
+
+   return output
+
+
+ class PerDimScale(nn.Module):
+   """Per-dimension scaling."""
+
+   def __init__(self, num_dims: int):
+     super().__init__()
+     self.num_dims = num_dims
+     self.per_dim_scale = nn.Parameter(torch.zeros(num_dims))
+
+   def forward(self, x: torch.Tensor) -> torch.Tensor:
+     # 1.442695041 == 1 / ln(2) and softplus(0) == ln(2), so at initialization
+     # (per_dim_scale == 0) this reduces to the standard 1 / sqrt(num_dims) scaling.
+     scale_factor = (
+         1.442695041 / math.sqrt(self.num_dims) * F.softplus(self.per_dim_scale)
+     )
+     return x * scale_factor
+
+
+ class MultiHeadAttention(nn.Module):
+   """Multi-head attention."""
+
+   def __init__(
+       self,
+       num_heads: int,
+       in_features: int,
+       *,
+       use_per_dim_scale: bool = True,
+       use_rotary_position_embeddings: bool = True,
+       use_bias: bool = False,
+       attention_fn: Callable[..., torch.Tensor] = _torch_dot_product_attention,
+       qk_norm: str = "rms",
+       fuse_qkv: bool = False,
+   ):
+     super().__init__()
+     self.num_heads = num_heads
+     self.in_features = in_features
+     self.head_dim = in_features // num_heads
+     self.use_bias = use_bias
+     self.attention_fn = attention_fn
+     self.qk_norm = qk_norm
+     self.fuse_qkv = fuse_qkv
+
+     if self.in_features % self.num_heads != 0:
+       raise ValueError(
+           f"Input feature dimension ({self.in_features}) must be divisible by "
+           f"the number of heads ({self.num_heads})."
+       )
+
+     if self.fuse_qkv:
+       self.qkv_proj = nn.Linear(self.in_features, 3 * self.in_features, bias=use_bias)
+     else:
+       self.query = nn.Linear(self.in_features, self.in_features, bias=use_bias)
+       self.key = nn.Linear(self.in_features, self.in_features, bias=use_bias)
+       self.value = nn.Linear(self.in_features, self.in_features, bias=use_bias)
+     self.out = nn.Linear(self.in_features, self.in_features, bias=use_bias)
+
+     if self.qk_norm == "rms":
+       self.query_ln = RMSNorm(self.head_dim)
+       self.key_ln = RMSNorm(self.head_dim)
+     else:
+       self.query_ln = nn.Identity()
+       self.key_ln = nn.Identity()
+
+     self.use_rotary_position_embeddings = use_rotary_position_embeddings
+     if self.use_rotary_position_embeddings:
+       self.rotary_position_embedding = RotaryPositionalEmbedding(
+           embedding_dims=self.head_dim,
+       )
+
+     self.use_per_dim_scale = use_per_dim_scale
+     if use_per_dim_scale:
+       self.per_dim_scale = PerDimScale(num_dims=self.head_dim)
+
+   def forward(
+       self,
+       inputs_q: torch.Tensor,
+       *,
+       decode_cache: DecodeCache | None = None,
+       patch_mask: torch.Tensor | None = None,
+   ) -> tuple[torch.Tensor, DecodeCache | None]:
+     b, n_patches, _ = inputs_q.shape
+     if patch_mask is None:
+       patch_mask = torch.zeros(b, n_patches, dtype=torch.bool, device=inputs_q.device)
+
+     if self.fuse_qkv:
+       qkv = self.qkv_proj(inputs_q)
+       query, key, value = torch.chunk(qkv, 3, dim=-1)
+       query = query.view(b, n_patches, self.num_heads, self.head_dim)
+       key = key.view(b, n_patches, self.num_heads, self.head_dim)
+       value = value.view(b, n_patches, self.num_heads, self.head_dim)
+     else:
+       query = self.query(inputs_q).view(b, n_patches, self.num_heads, self.head_dim)
+       key = self.key(inputs_q).view(b, n_patches, self.num_heads, self.head_dim)
+       value = self.value(inputs_q).view(b, n_patches, self.num_heads, self.head_dim)
+
+     if decode_cache is None:
+       num_masked = torch.sum(patch_mask.to(torch.int32), dim=-1)
+       next_index = torch.zeros_like(num_masked, dtype=torch.int32)
+     else:
+       num_masked = (
+           torch.sum(patch_mask.to(torch.int32), dim=-1) + decode_cache.num_masked
+       )
+       next_index = decode_cache.next_index.clone()
+
+     if self.use_rotary_position_embeddings:
+       position = (
+           torch.arange(n_patches, device=inputs_q.device)[None, :]
+           + next_index[:, None]
+           - num_masked[:, None]
+       )
+       query = self.rotary_position_embedding(query, position)
+       key = self.rotary_position_embedding(key, position)
+
+     query = self.query_ln(query)
+     key = self.key_ln(key)
+
+     if self.use_per_dim_scale:
+       query = self.per_dim_scale(query)
+
+     if decode_cache is not None:
+       _, decode_cache_size, _, _ = decode_cache.value.shape
+
+       start = decode_cache.next_index[0]
+       end = start + n_patches
+
+       # Write the new keys and values into the cache with a single vectorized
+       # slice assignment instead of a Python loop over the batch. This assumes
+       # every batch entry shares the same next_index, hence indexing element [0].
+       decode_cache.key[:, start:end] = key
+       decode_cache.value[:, start:end] = value
+
+       key = decode_cache.key
+       value = decode_cache.value
+       decode_cache.next_index += n_patches
+       decode_cache.num_masked = num_masked
+       attn_mask = make_attn_mask(
+           query_length=n_patches,
+           num_all_masked_kv=num_masked,
+           query_index_offset=next_index,
+           kv_length=decode_cache_size,
+       )
+     else:
+       attn_mask = make_attn_mask(query_length=n_patches, num_all_masked_kv=num_masked)
+
+     x = self.attention_fn(
+         query,
+         key,
+         value,
+         mask=attn_mask,
+     )
+
+     x = x.reshape(b, n_patches, self.in_features)
+     out = self.out(x)
+     return out, decode_cache
+
+
+ class Transformer(nn.Module):
+   """Classic Transformer used in TimesFM."""
+
+   def __init__(self, config: configs.TransformerConfig):
+     super().__init__()
+     self.config = config
+
+     if config.attention_norm == "rms":
+       self.pre_attn_ln = RMSNorm(num_features=config.model_dims)
+       self.post_attn_ln = RMSNorm(num_features=config.model_dims)
+     else:
+       raise ValueError(f"Layer norm: {config.attention_norm} not supported.")
+
+     self.attn = MultiHeadAttention(
+         num_heads=config.num_heads,
+         in_features=config.model_dims,
+         use_per_dim_scale=True,
+         use_rotary_position_embeddings=config.use_rotary_position_embeddings,
+         qk_norm=config.qk_norm,
+         fuse_qkv=config.fuse_qkv,
+     )
+
+     if config.feedforward_norm == "rms":
+       self.pre_ff_ln = RMSNorm(num_features=config.model_dims)
+       self.post_ff_ln = RMSNorm(num_features=config.model_dims)
+     else:
+       raise ValueError(f"Layer norm: {config.feedforward_norm} not supported.")
+
+     self.ff0 = nn.Linear(
+         in_features=config.model_dims,
+         out_features=config.hidden_dims,
+         bias=config.use_bias,
+     )
+     self.ff1 = nn.Linear(
+         in_features=config.hidden_dims,
+         out_features=config.model_dims,
+         bias=config.use_bias,
+     )
+     if config.ff_activation == "relu":
+       self.activation = nn.ReLU()
+     elif config.ff_activation == "swish":
+       self.activation = nn.SiLU()
+     elif config.ff_activation == "none":
+       self.activation = nn.Identity()
+     else:
+       raise ValueError(f"Activation: {config.ff_activation} not supported.")
+
+   def forward(
+       self,
+       input_embeddings: torch.Tensor,
+       patch_mask: torch.Tensor,
+       decode_cache: DecodeCache | None = None,
+   ) -> tuple[torch.Tensor, DecodeCache | None]:
+     attn_output, decode_cache = self.attn(
+         inputs_q=self.pre_attn_ln(input_embeddings),
+         decode_cache=decode_cache,
+         patch_mask=patch_mask,
+     )
+     attn_output = self.post_attn_ln(attn_output) + input_embeddings
+     output_embeddings = (
+         self.post_ff_ln(self.ff1(self.activation(self.ff0(self.pre_ff_ln(attn_output)))))
+         + attn_output
+     )
+     return output_embeddings, decode_cache
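
A minimal usage sketch for the Transformer block above (an illustration, not part of the packaged file). The module path and the keyword construction of configs.TransformerConfig are assumptions read off the constructor; the real config class may expose these fields differently.

import torch
from timesfm import configs
from timesfm.torch import transformer as tfm_transformer  # hypothetical module name for the file above

config = configs.TransformerConfig(  # assumed to accept the fields read in __init__
    model_dims=512,
    hidden_dims=1280,
    num_heads=8,
    attention_norm="rms",
    feedforward_norm="rms",
    qk_norm="rms",
    ff_activation="swish",
    use_bias=False,
    use_rotary_position_embeddings=True,
    fuse_qkv=False,
)
layer = tfm_transformer.Transformer(config)

batch, n_patches = 2, 16
embeddings = torch.randn(batch, n_patches, config.model_dims)
patch_mask = torch.zeros(batch, n_patches, dtype=torch.bool)  # True marks padded patches

output, _ = layer(embeddings, patch_mask)  # no decode cache: full causal self-attention
print(output.shape)  # torch.Size([2, 16, 512])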
timesfm/torch/util.py ADDED
@@ -0,0 +1,94 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """PyTorch utility functions for TimesFM layers."""
+
+ import dataclasses
+ import torch
+
+ _TOLERANCE = 1e-6
+
+
+ @dataclasses.dataclass(frozen=False)
+ class DecodeCache:
+   """Cache for decoding."""
+
+   next_index: torch.Tensor  # Per-example index of the next cache slot to write.
+   num_masked: torch.Tensor  # Per-example running count of masked (padded) positions.
+   key: torch.Tensor  # Cached keys, shape (batch, cache_len, num_heads, head_dim).
+   value: torch.Tensor  # Cached values, same shape as `key`.
+
+
+ def update_running_stats(
+     n: torch.Tensor,
+     mu: torch.Tensor,
+     sigma: torch.Tensor,
+     x: torch.Tensor,
+     mask: torch.Tensor,
+ ) -> tuple[
+     tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+     tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+ ]:
+   """Updates running (count, mean, std) statistics with the unmasked values in `x`."""
+   is_legit = torch.logical_not(mask)
+   inc_n = torch.sum(is_legit.to(x.dtype), dim=-1)
+
+   inc_mu_numerator = torch.sum(x * is_legit, dim=-1)
+   inc_n_safe = torch.where(inc_n == 0, 1.0, inc_n)
+   inc_mu = inc_mu_numerator / inc_n_safe
+   inc_mu = torch.where(inc_n == 0, 0.0, inc_mu)
+
+   inc_var_numerator = torch.sum(
+       ((x - inc_mu.unsqueeze(-1)) ** 2) * is_legit, dim=-1
+   )
+   inc_var = inc_var_numerator / inc_n_safe
+   inc_var = torch.where(inc_n == 0, 0.0, inc_var)
+   inc_sigma = torch.sqrt(inc_var)
+
+   new_n = n + inc_n
+   new_n_safe = torch.where(new_n == 0, 1.0, new_n)
+
+   new_mu = (n * mu + inc_mu * inc_n) / new_n_safe
+   new_mu = torch.where(new_n == 0, 0.0, new_mu)
+
+   # Combine the existing and incremental statistics with the standard
+   # parallel-variance formula (population variance).
+   term1 = n * sigma.pow(2)
+   term2 = inc_n * inc_sigma.pow(2)
+   term3 = n * (mu - new_mu).pow(2)
+   term4 = inc_n * (inc_mu - new_mu).pow(2)
+
+   new_var = (term1 + term2 + term3 + term4) / new_n_safe
+   new_var = torch.where(new_n == 0, 0.0, new_var)
+   new_sigma = torch.sqrt(torch.clamp(new_var, min=0.0))
+
+   new_stats = (new_n, new_mu, new_sigma)
+   return new_stats, new_stats
+
+
+ def revin(
+     x: torch.Tensor,
+     mu: torch.Tensor,
+     sigma: torch.Tensor,
+     reverse: bool = False,
+ ):
+   """Reversible instance normalization (RevIN)."""
+   if len(mu.shape) == len(x.shape) - 1:
+     mu = mu[..., None]
+     sigma = sigma[..., None]
+   elif len(mu.shape) == len(x.shape) - 2:
+     mu = mu[..., None, None]
+     sigma = sigma[..., None, None]
+
+   if reverse:
+     return x * sigma + mu
+   else:
+     # Replace near-zero scales with 1.0 to avoid dividing by (almost) zero.
+     return (x - mu) / torch.where(sigma < _TOLERANCE, 1.0, sigma)
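
A short sketch of how the two helpers above compose for masked streaming normalization (an illustration only; the tensor values are made up, and the expected results in the comments follow from the formulas above).

import torch
from timesfm.torch import util  # path as given in the file header above

x = torch.tensor([[1.0, 2.0, 3.0, 0.0]])
mask = torch.tensor([[False, False, False, True]])  # True marks padding

# Fold the unmasked values of x into empty running statistics.
zero = torch.zeros(1)
(n, mu, sigma), _ = util.update_running_stats(zero, zero, zero, x, mask)
# n == 3, mu == 2.0, sigma == sqrt(2/3), about 0.8165

x_norm = util.revin(x, mu, sigma)                     # (x - mu) / sigma
x_back = util.revin(x_norm, mu, sigma, reverse=True)  # recovers x, padding included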