tsagentkit_timesfm-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
timesfm/__init__.py ADDED
@@ -0,0 +1,29 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """TimesFM API."""
+
+ from .configs import ForecastConfig
+
+ try:
+   from .timesfm_2p5 import timesfm_2p5_torch
+   TimesFM_2p5_200M_torch = timesfm_2p5_torch.TimesFM_2p5_200M_torch
+ except ImportError:
+   pass
+
+ try:
+   from .timesfm_2p5 import timesfm_2p5_flax
+   TimesFM_2p5_200M_flax = timesfm_2p5_flax.TimesFM_2p5_200M_flax
+ except ImportError:
+   pass
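
A minimal usage sketch of this public API, with illustrative values. Only ForecastConfig is imported unconditionally; the backend-specific model classes appear only when the optional torch or flax dependencies import successfully, and their loading and forecasting methods are not shown in this diff, so they are omitted here:

import timesfm

# Illustrative values; all ForecastConfig fields have defaults.
config = timesfm.ForecastConfig(
    max_context=1024,
    max_horizon=256,
    normalize_inputs=True,
)

# Backend classes are exported only when their imports succeed
# (see the try/except blocks above).
if hasattr(timesfm, "TimesFM_2p5_200M_torch"):
  print("torch backend available")
if hasattr(timesfm, "TimesFM_2p5_200M_flax"):
  print("flax backend available")
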
timesfm/configs.py ADDED
@@ -0,0 +1,105 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Abstract configs for TimesFM layers."""
+
+ import dataclasses
+ from typing import Literal
+
+
+ @dataclasses.dataclass(frozen=True)
+ class ForecastConfig:
+   """Options for forecasting.
+
+   Attributes:
+     max_context: The maximum context length. This is used by the compiled
+       decode function at inference time during batched inference. Any input
+       time series shorter than max_context is padded with zeros, and any
+       longer than max_context is truncated.
+     max_horizon: The maximum horizon length. This is used by the compiled
+       decode function at inference time during batched inference. The compiled
+       cached decoding function forecasts up to max_horizon by default.
+     normalize_inputs: Whether to normalize the inputs. This is useful when the
+       raw inputs have extremely large or small magnitudes that may cause
+       numerical issues.
+     window_size: The window size for decomposed forecasting.
+       TODO(siriuz42): implement it.
+     per_core_batch_size: The batch size per core. Used at inference time during
+       batched inference when multiple GPU / TPU devices are used.
+     use_continuous_quantile_head: Whether to use a separate continuous quantile
+       head to avoid quantile collapsing.
+     force_flip_invariance: Whether to force flip invariance. By default, TimesFM
+       guarantees that TimesFM(aX + b) = a * TimesFM(X) + b for a >= 0. This
+       flag extends the guarantee to a < 0 as well.
+     infer_is_positive: Whether to guarantee nonnegativity of the output if the
+       input is nonnegative.
+     fix_quantile_crossing: Whether to fix quantile crossing.
+     return_backcast: Whether to return the backcast.
+   """
+
+   max_context: int = 0
+   max_horizon: int = 0
+   normalize_inputs: bool = False
+   window_size: int = 0
+   per_core_batch_size: int = 1
+   use_continuous_quantile_head: bool = False
+   force_flip_invariance: bool = True
+   infer_is_positive: bool = True
+   fix_quantile_crossing: bool = False
+   return_backcast: bool = False
+
+
+ @dataclasses.dataclass(frozen=True)
+ class ResidualBlockConfig:
+   """Framework-agnostic config for a residual block."""
+
+   input_dims: int
+   hidden_dims: int
+   output_dims: int
+   use_bias: bool
+   activation: Literal["relu", "swish", "none"]
+
+
+ @dataclasses.dataclass(frozen=True)
+ class RandomFourierFeaturesConfig:
+   """Framework-agnostic config for random Fourier features."""
+
+   input_dims: int
+   output_dims: int
+   projection_stddev: float
+   use_bias: bool
+
+
+ @dataclasses.dataclass(frozen=True)
+ class TransformerConfig:
+   """Framework-agnostic config for a transformer."""
+
+   model_dims: int
+   hidden_dims: int
+   num_heads: int
+   attention_norm: Literal["rms"]
+   feedforward_norm: Literal["rms"]
+   qk_norm: Literal["rms", "none"]
+   use_bias: bool
+   use_rotary_position_embeddings: bool
+   ff_activation: Literal["relu", "swish", "none"]
+   fuse_qkv: bool
+
+
+ @dataclasses.dataclass(frozen=True)
+ class StackedTransformersConfig:
+   """Framework-agnostic config for stacked transformers."""
+
+   num_layers: int
+   transformer: TransformerConfig
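
A short sketch of constructing these frozen dataclasses. The values are illustrative only, not the settings of any released checkpoint:

from timesfm import configs

forecast_cfg = configs.ForecastConfig(
    max_context=1024,
    max_horizon=256,
    normalize_inputs=True,
    use_continuous_quantile_head=True,
)

stack_cfg = configs.StackedTransformersConfig(
    num_layers=20,
    transformer=configs.TransformerConfig(
        model_dims=1280,
        hidden_dims=1280,
        num_heads=16,
        attention_norm="rms",
        feedforward_norm="rms",
        qk_norm="rms",
        use_bias=False,
        use_rotary_position_embeddings=True,
        ff_activation="swish",
        fuse_qkv=False,
    ),
)
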
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
timesfm/flax/dense.py ADDED
@@ -0,0 +1,110 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Dense layers for TimesFM."""
+
+ from flax import nnx
+ import jax
+ import jax.numpy as jnp
+ import jaxtyping
+
+ from .. import configs
+
+ Array = jaxtyping.Array
+ Bool = jaxtyping.Bool
+ Float = jaxtyping.Float
+ Integer = jaxtyping.Integer
+ Num = jaxtyping.Num
+
+ ResidualBlockConfig = configs.ResidualBlockConfig
+ RandomFourierFeaturesConfig = configs.RandomFourierFeaturesConfig
+
+
+ class ResidualBlock(nnx.Module):
+   """Residual block with two linear layers and a linear residual connection."""
+
+   def __init__(self, config: ResidualBlockConfig, *, rngs=nnx.Rngs(42)):
+     self.config = config
+     self.hidden_layer = nnx.Linear(
+         in_features=config.input_dims,
+         out_features=config.hidden_dims,
+         use_bias=config.use_bias,
+         rngs=rngs,
+     )
+     self.output_layer = nnx.Linear(
+         in_features=config.hidden_dims,
+         out_features=config.output_dims,
+         use_bias=config.use_bias,
+         rngs=rngs,
+     )
+     self.residual_layer = nnx.Linear(
+         in_features=config.input_dims,
+         out_features=config.output_dims,
+         use_bias=config.use_bias,
+         rngs=rngs,
+     )
+     if config.activation == "relu":
+       self.activation = jax.nn.relu
+     elif config.activation == "swish":
+       self.activation = jax.nn.swish
+     elif config.activation == "none":
+       self.activation = lambda x: x
+     else:
+       raise ValueError(f"Activation: {config.activation} not supported.")
+
+   def __call__(self, x: Float[Array, "b ... i"]) -> Float[Array, "b ... o"]:
+     return self.output_layer(
+         self.activation(self.hidden_layer(x))
+     ) + self.residual_layer(x)
+
+
+ class RandomFourierFeatures(nnx.Module):
+   """Random Fourier features layer."""
+
+   __data__ = ("phase_shifts",)
+
+   def __init__(self, config: RandomFourierFeaturesConfig, *, rngs=nnx.Rngs(42)):
+     self.config = config
+
+     if config.output_dims % 4 != 0:
+       raise ValueError(
+           f"Output dims must be a multiple of 4: {config.output_dims} % 4 != 0."
+       )
+     num_projected_features = config.output_dims // 4
+
+     self.phase_shifts = nnx.Param(jnp.zeros(shape=(2, num_projected_features)))
+     self.projection_layer = nnx.Linear(
+         in_features=config.input_dims,
+         out_features=num_projected_features,
+         use_bias=config.use_bias,
+         rngs=rngs,
+     )
+     self.residual_layer = nnx.Linear(
+         in_features=config.input_dims,
+         out_features=config.output_dims,
+         use_bias=config.use_bias,
+         rngs=rngs,
+     )
+
+   def __call__(self, x: Float[Array, "b ... i"]) -> Float[Array, "b ... o"]:
+     projected = self.projection_layer(x)
+     cos_features = jnp.cos(projected)
+     sin_features = jnp.sin(projected)
+     sq_wave_1 = jnp.sign(jnp.sin(projected + self.phase_shifts[0, :]))
+     sq_wave_2 = jnp.sign(jnp.sin(projected + self.phase_shifts[1, :]))
+     fourier_features = jnp.concatenate(
+         [cos_features, sin_features, sq_wave_1, sq_wave_2], axis=-1
+     )
+     residual = self.residual_layer(x)
+     return fourier_features + residual
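
A minimal sketch of driving these layers directly, assuming only that the configs above and this module (timesfm/flax/dense.py) are importable:

import jax.numpy as jnp
from flax import nnx

from timesfm import configs
from timesfm.flax import dense

block = dense.ResidualBlock(
    configs.ResidualBlockConfig(
        input_dims=32, hidden_dims=64, output_dims=32,
        use_bias=True, activation="swish",
    ),
    rngs=nnx.Rngs(0),
)
x = jnp.ones((2, 16, 32))   # [batch, patches, input_dims]
y = block(x)                # -> (2, 16, 32)

rff = dense.RandomFourierFeatures(
    configs.RandomFourierFeaturesConfig(
        input_dims=32, output_dims=64,   # output_dims must be divisible by 4
        projection_stddev=1.0, use_bias=True,
    ),
    rngs=nnx.Rngs(0),
)
z = rff(x)                  # -> (2, 16, 64): cos, sin and two square waves plus a linear residual

Note that projection_stddev is carried in the config but is not referenced by the Flax layer shown here.
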
@@ -0,0 +1,71 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Normalization layers for TimesFM."""
+
+ from flax import nnx
+ import jax
+ import jax.numpy as jnp
+ import jaxtyping
+
+ Array = jaxtyping.Array
+ Bool = jaxtyping.Bool
+ Float = jaxtyping.Float
+ Integer = jaxtyping.Integer
+ Num = jaxtyping.Num
+
+
+ class RMSNorm(nnx.Module):
+   """RMS normalization."""
+
+   __data__ = ("scale",)
+
+   def __init__(
+       self,
+       num_features: int,
+       *,
+       epsilon: float = 1e-6,
+       rngs=nnx.Rngs(42),
+   ):
+     del rngs
+     self.scale = nnx.Param(jnp.zeros(shape=(num_features,)))
+     self.num_features = num_features
+     self.epsilon = epsilon
+
+   def __call__(self, inputs: Float[Array, "b ... d"]) -> Float[Array, "b ... d"]:
+     var = jnp.mean(jnp.square(inputs), axis=-1, keepdims=True)
+     normed_inputs = inputs * jax.lax.rsqrt(var + self.epsilon)
+     normed_inputs *= self.scale
+     return normed_inputs
+
+
+ class LayerNorm(nnx.Module):
+   """Layer normalization."""
+
+   __data__ = ("scale", "bias")
+
+   def __init__(self, num_features: int, *, epsilon: float = 1e-6, rngs=nnx.Rngs(42)):
+     del rngs
+     self.scale = nnx.Param(jnp.ones(shape=(num_features,)))
+     self.bias = nnx.Param(jnp.zeros(shape=(num_features,)))
+     self.num_features = num_features
+     self.epsilon = epsilon
+
+   def __call__(self, inputs: Float[Array, "b ... d"]) -> Float[Array, "b ... d"]:
+     mean = jnp.mean(inputs, axis=-1, keepdims=True)
+     var = jnp.mean(jnp.square(inputs - mean), axis=-1, keepdims=True)
+     normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon)
+     normed_inputs *= self.scale
+     normed_inputs += self.bias
+     return normed_inputs
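
A small sketch of the normalization layers, assuming this hunk ships as the `normalization` module referenced by the transformer code below (the file header for this hunk is missing from the diff):

import jax.numpy as jnp
from timesfm.flax import normalization  # assumed module path

rms = normalization.RMSNorm(num_features=64)
ln = normalization.LayerNorm(num_features=64)

x = jnp.linspace(-1.0, 1.0, 2 * 8 * 64).reshape(2, 8, 64)
y = rms(x)  # same shape; scale is zero-initialized, so outputs are zero
            # until real weights are loaded from a checkpoint
z = ln(x)   # standard layer norm; scale starts at ones, bias at zeros
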
@@ -0,0 +1,356 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Transformer layers for TimesFM."""
+
+ import functools
+ from typing import Callable
+
+ from flax import nnx
+ from flax.nnx.nn import linear
+ import jax
+ from jax import lax
+ import jax.numpy as jnp
+ import jaxtyping
+
+ from .. import configs
+ from . import normalization, util
+
+ Array = jaxtyping.Array
+ Bool = jaxtyping.Bool
+ Float = jaxtyping.Float
+ Integer = jaxtyping.Integer
+ Num = jaxtyping.Num
+ LayerNorm = normalization.LayerNorm
+ RMSNorm = normalization.RMSNorm
+ LinearGeneral = linear.LinearGeneral
+ TransformerConfig = configs.TransformerConfig
+ DecodeCache = util.DecodeCache
+
+
+ @functools.partial(
+     jax.jit,
+     static_argnames=("query_length", "kv_length"),
+ )
+ def make_attn_mask(
+     query_length: int,
+     num_all_masked_kv: Integer[Array, "b"],
+     query_index_offset: Integer[Array, "b"] | None = None,
+     kv_length: int = 0,
+ ) -> Bool[Array, "b 1 q n"]:
+   """Makes a causal attention mask that also hides fully masked kv positions."""
+
+   if kv_length == 0:
+     kv_length = query_length
+
+   q_index = jnp.arange(query_length)[None, None, :, None]
+   if query_index_offset is not None:
+     q_index += query_index_offset[:, None, None, None]
+   kv_index = jnp.arange(kv_length)[None, None, None, :]
+   return jnp.logical_and(
+       q_index >= kv_index,
+       kv_index >= num_all_masked_kv[:, None, None, None],
+   )
+
+
+ class RotaryPositionalEmbedding(nnx.Module):
+   """Rotary positional embedding."""
+
+   def __init__(
+       self,
+       embedding_dims: int,
+       min_timescale: int = 1,
+       max_timescale: int = 10000,
+   ):
+     self.embedding_dims = embedding_dims
+     self.min_timescale = min_timescale
+     self.max_timescale = max_timescale
+
+   def __call__(
+       self,
+       inputs: Float[Array, "b ... d"],
+       position: Array | None = None,
+   ):
+     """Applies rotary position embeddings to the inputs."""
+     if self.embedding_dims != inputs.shape[-1]:
+       raise ValueError(
+           "The embedding dims of the rotary position embedding "
+           "must match the hidden dimension of the inputs."
+       )
+     half_embedding_dim = self.embedding_dims // 2
+     fraction = 2 * jnp.arange(0, half_embedding_dim) / self.embedding_dims
+     timescale = (
+         self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
+     )
+     if position is None:
+       seq_length = inputs.shape[1]
+       position = jnp.arange(seq_length, dtype=jnp.float32)[None, :]
+     if len(inputs.shape) == 4:
+       position = position[..., None, None]
+       timescale = timescale[None, None, None, :]
+     elif len(inputs.shape) == 3:
+       position = position[..., None]
+       timescale = timescale[None, None, :]
+     else:
+       raise ValueError("Inputs must be of rank 3 or 4.")
+     sinusoid_inp = position / timescale
+     sin = jnp.sin(sinusoid_inp)
+     cos = jnp.cos(sinusoid_inp)
+     first_half, second_half = jnp.split(inputs, 2, axis=-1)
+     first_part = first_half * cos - second_half * sin
+     second_part = second_half * cos + first_half * sin
+     first_part = first_part.astype(None)
+     second_part = second_part.astype(None)
+     return jnp.concatenate([first_part, second_part], axis=-1)
+
+
+ class PerDimScale(nnx.Module):
+   """Per-dimension scaling."""
+
+   __data__ = ("per_dim_scale",)
+
+   def __init__(self, num_dims: int, *, rngs=nnx.Rngs(42)):
+     del rngs
+     self.num_dims = num_dims
+     self.per_dim_scale = nnx.Param(jnp.zeros(shape=(num_dims,)))
+
+   def __call__(self, x: Float[Array, "b ... d"]) -> Float[Array, "b ... d"]:
+     return x * (
+         1.442695041 / jnp.sqrt(self.num_dims) * jax.nn.softplus(self.per_dim_scale)
+     )
+
+
+ class MultiHeadAttention(nnx.Module):
+   """Multi-head attention."""
+
+   def __init__(
+       self,
+       num_heads: int,
+       in_features: int,
+       *,
+       use_per_dim_scale: bool = True,
+       use_rotary_position_embeddings: bool = True,
+       use_bias: bool = False,
+       deterministic: bool | None = None,
+       attention_fn: Callable[..., Array] = nnx.dot_product_attention,
+       qk_norm: str = "rms",
+       rngs=nnx.Rngs(42),
+   ):
+     self.num_heads = num_heads
+     self.in_features = in_features
+     self.qkv_features = in_features
+     self.out_features = in_features
+     self.in_kv_features = in_features
+     self.deterministic = deterministic
+     self.use_bias = use_bias
+     self.attention_fn = attention_fn
+     self.qk_norm = qk_norm
+
+     if self.qkv_features % self.num_heads != 0:
+       raise ValueError(
+           f"Memory dimension ({self.qkv_features}) must be divisible by "
+           f"'num_heads' heads ({self.num_heads})."
+       )
+     self.head_dim = self.qkv_features // self.num_heads
+
+     linear_general = functools.partial(
+         LinearGeneral,
+         out_features=(self.num_heads, self.head_dim),
+         use_bias=self.use_bias,
+     )
+     # Project inputs_q to multi-headed q/k/v.
+     # Dimensions are then [batch..., length, n_heads, n_features_per_head].
+     self.query = linear_general(self.in_features, rngs=rngs)
+     self.key = linear_general(self.in_kv_features, rngs=rngs)
+     self.value = linear_general(self.in_kv_features, rngs=rngs)
+
+     if self.qk_norm == "rms":
+       self.query_ln = RMSNorm(self.head_dim)
+       self.key_ln = RMSNorm(self.head_dim)
+     else:
+       self.query_ln = None
+       self.key_ln = None
+
+     self.out = LinearGeneral(
+         in_features=(self.num_heads, self.head_dim),
+         out_features=self.out_features,
+         axis=(-2, -1),
+         use_bias=self.use_bias,
+         rngs=rngs,
+     )
+
+     self.use_per_dim_scale = use_per_dim_scale
+     self.use_rotary_position_embeddings = use_rotary_position_embeddings
+     if self.use_rotary_position_embeddings:
+       self.rotary_position_embedding = RotaryPositionalEmbedding(
+           embedding_dims=self.head_dim,
+       )
+     else:
+       self.rotary_position_embedding = None
+
+     if use_per_dim_scale:
+       self.per_dim_scale = PerDimScale(num_dims=self.head_dim, rngs=rngs)
+     else:
+       self.per_dim_scale = None
+
+   def __call__(
+       self,
+       inputs_q: Array,
+       *,
+       decode_cache: DecodeCache | None = None,
+       patch_mask: Array | None = None,
+       deterministic: bool | None = None,
+       sow_weights: bool = False,
+   ) -> tuple[Float[Array, "b ... o"], DecodeCache | None]:
+     """Applies multi-head dot product attention on the input data."""
+     _, n_patches, input_in_features = inputs_q.shape
+     if input_in_features != self.in_features:
+       raise ValueError(
+           f"Incompatible input dimension, got {input_in_features} "
+           f"but module expects {self.in_features}."
+       )
+     if patch_mask is None:
+       patch_mask = jnp.zeros(inputs_q.shape[:-1], dtype=jnp.bool)
+
+     # For query: rope -> ln -> per_dim_scale.
+     query = self.query(inputs_q)
+     key = self.key(inputs_q)
+     value = self.value(inputs_q)
+
+     if decode_cache is None:
+       num_masked = jnp.sum(patch_mask.astype(jnp.int32), axis=-1, keepdims=False)
+       next_index = jnp.zeros_like(num_masked, dtype=jnp.int32)
+     else:
+       num_masked = (
+           jnp.sum(patch_mask.astype(jnp.int32), axis=-1, keepdims=False)
+           + decode_cache.num_masked
+       )
+       next_index = decode_cache.next_index
+
+     if self.use_rotary_position_embeddings:
+       position = (
+           jnp.arange(n_patches, dtype=jnp.int32)[None, :]
+           + next_index[:, None]
+           - num_masked[:, None]
+       )
+       query = self.rotary_position_embedding(query, position)
+       key = self.rotary_position_embedding(key, position)
+     if self.query_ln is not None:
+       query = self.query_ln(query)
+     if self.key_ln is not None:
+       key = self.key_ln(key)
+     if self.use_per_dim_scale:
+       query = self.per_dim_scale(query)
+
+     if decode_cache is not None:
+       # Cached decoding.
+       _, decode_cache_size, _, _ = decode_cache.value.shape
+       zero = jnp.array(0, dtype=lax.dtype(next_index.dtype))
+       start_indices = (zero, next_index[0], zero, zero)
+       key = lax.dynamic_update_slice(decode_cache.key, key, start_indices)
+       value = lax.dynamic_update_slice(decode_cache.value, value, start_indices)
+       decode_cache.key = key
+       decode_cache.value = value
+       decode_cache.next_index = next_index + n_patches
+       decode_cache.num_masked = num_masked
+       attn_mask = make_attn_mask(
+           query_length=n_patches,
+           num_all_masked_kv=num_masked,
+           query_index_offset=next_index,
+           kv_length=decode_cache_size,
+       )
+     else:
+       # Training or non-cached decoding.
+       attn_mask = make_attn_mask(query_length=n_patches, num_all_masked_kv=num_masked)
+
+     # Apply attention. The sqrt(head_dim) factor cancels the 1/sqrt(head_dim) scaling inside dot_product_attention, so query scaling comes from PerDimScale when enabled.
+     x = self.attention_fn(
+         query * jnp.sqrt(self.head_dim),
+         key,
+         value,
+         mask=attn_mask,
+         deterministic=deterministic,
+         module=self if sow_weights else None,
+     )
+     # Back to the original input dimensions.
+     out = self.out(x)
+     return out, decode_cache
+
+
+ class Transformer(nnx.Module):
+   """Classic Transformer used in TimesFM."""
+
+   def __init__(self, config: TransformerConfig, *, rngs=nnx.Rngs(42)):
+     self.config = config
+
+     if config.attention_norm == "rms":
+       self.pre_attn_ln = RMSNorm(num_features=config.model_dims, rngs=rngs)
+       self.post_attn_ln = RMSNorm(num_features=config.model_dims, rngs=rngs)
+     else:
+       raise ValueError(f"Layer norm: {config.attention_norm} not supported.")
+
+     self.attn = MultiHeadAttention(
+         num_heads=config.num_heads,
+         in_features=config.model_dims,
+         use_per_dim_scale=True,
+         use_rotary_position_embeddings=config.use_rotary_position_embeddings,
+         qk_norm=config.qk_norm,
+         rngs=rngs,
+     )
+
+     if config.feedforward_norm == "rms":
+       self.pre_ff_ln = RMSNorm(num_features=config.model_dims, rngs=rngs)
+       self.post_ff_ln = RMSNorm(num_features=config.model_dims, rngs=rngs)
+     else:
+       raise ValueError(f"Layer norm: {config.feedforward_norm} not supported.")
+     self.ff0 = nnx.Linear(
+         in_features=config.model_dims,
+         out_features=config.hidden_dims,
+         use_bias=config.use_bias,
+         rngs=rngs,
+     )
+     self.ff1 = nnx.Linear(
+         in_features=config.hidden_dims,
+         out_features=config.model_dims,
+         use_bias=config.use_bias,
+         rngs=rngs,
+     )
+     if config.ff_activation == "relu":
+       self.activation = jax.nn.relu
+     elif config.ff_activation == "swish":
+       self.activation = jax.nn.swish
+     elif config.ff_activation == "none":
+       self.activation = lambda x: x
+     else:
+       raise ValueError(f"Activation: {config.ff_activation} not supported.")
+
+   def __call__(
+       self,
+       input_embeddings: Float[Array, "b n d"],
+       patch_mask: Bool[Array, "b n"],
+       decode_cache: DecodeCache | None = None,
+   ) -> tuple[Float[Array, "b n d"], DecodeCache | None]:
+     attn_output, decode_cache = self.attn(
+         inputs_q=self.pre_attn_ln(input_embeddings),
+         decode_cache=decode_cache,
+         patch_mask=patch_mask,
+         sow_weights=False,
+         deterministic=True,
+     )
+     attn_output = self.post_attn_ln(attn_output) + input_embeddings
+     output_embeddings = (
+         self.post_ff_ln(self.ff1(self.activation(self.ff0(self.pre_ff_ln(attn_output)))))
+         + attn_output
+     )
+     return output_embeddings, decode_cache
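
Finally, a sketch wiring a single Transformer block from a TransformerConfig. It assumes this hunk is the `transformer` module of the flax subpackage (its file header is also missing from the diff) and that the `util` module it imports is present in the wheel; the non-cached path used here never touches the decode cache:

import jax.numpy as jnp
from flax import nnx

from timesfm import configs
from timesfm.flax import transformer  # assumed module path

cfg = configs.TransformerConfig(
    model_dims=64,
    hidden_dims=128,
    num_heads=4,
    attention_norm="rms",
    feedforward_norm="rms",
    qk_norm="rms",
    use_bias=False,
    use_rotary_position_embeddings=True,
    ff_activation="swish",
    fuse_qkv=False,
)
layer = transformer.Transformer(cfg, rngs=nnx.Rngs(0))

x = jnp.ones((2, 8, 64))              # [batch, patches, model_dims]
mask = jnp.zeros((2, 8), dtype=bool)  # no padded / masked patches
out, cache = layer(x, patch_mask=mask, decode_cache=None)
# out: (2, 8, 64); cache stays None on the non-cached (training) path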