tsagentkit-timesfm 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- timesfm/__init__.py +29 -0
- timesfm/configs.py +105 -0
- timesfm/flax/__init__.py +13 -0
- timesfm/flax/dense.py +110 -0
- timesfm/flax/normalization.py +71 -0
- timesfm/flax/transformer.py +356 -0
- timesfm/flax/util.py +107 -0
- timesfm/timesfm_2p5/timesfm_2p5_base.py +422 -0
- timesfm/timesfm_2p5/timesfm_2p5_flax.py +602 -0
- timesfm/timesfm_2p5/timesfm_2p5_torch.py +472 -0
- timesfm/torch/__init__.py +13 -0
- timesfm/torch/dense.py +94 -0
- timesfm/torch/normalization.py +39 -0
- timesfm/torch/transformer.py +370 -0
- timesfm/torch/util.py +94 -0
- timesfm/utils/xreg_lib.py +520 -0
- tsagentkit_timesfm-1.0.0.dist-info/METADATA +152 -0
- tsagentkit_timesfm-1.0.0.dist-info/RECORD +21 -0
- tsagentkit_timesfm-1.0.0.dist-info/WHEEL +5 -0
- tsagentkit_timesfm-1.0.0.dist-info/licenses/LICENSE +202 -0
- tsagentkit_timesfm-1.0.0.dist-info/top_level.txt +1 -0
timesfm/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""TimesFM API."""
|
|
16
|
+
|
|
17
|
+
from .configs import ForecastConfig
|
|
18
|
+
|
|
19
|
+
try:
|
|
20
|
+
from .timesfm_2p5 import timesfm_2p5_torch
|
|
21
|
+
TimesFM_2p5_200M_torch = timesfm_2p5_torch.TimesFM_2p5_200M_torch
|
|
22
|
+
except ImportError:
|
|
23
|
+
pass
|
|
24
|
+
|
|
25
|
+
try:
|
|
26
|
+
from .timesfm_2p5 import timesfm_2p5_flax
|
|
27
|
+
TimesFM_2p5_200M_flax = timesfm_2p5_flax.TimesFM_2p5_200M_flax
|
|
28
|
+
except ImportError:
|
|
29
|
+
pass
|
timesfm/configs.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Abstract configs for TimesFM layers."""
|
|
16
|
+
|
|
17
|
+
import dataclasses
|
|
18
|
+
from typing import Literal
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclasses.dataclass(frozen=True)
class ForecastConfig:
  """Options for forecasting.

  Attributes:
    max_context: The maximum context length. This is used by the compiled
      decode function at inference time during batched inference. Any input
      time series with length less than max_context will be padded with zeros,
      and with length greater than max_context will be truncated.
    max_horizon: The maximum horizon length. This is used by the compiled
      decode function at inference time during batched inference. The compiled
      cached decoding function will by default forecast up to max_horizon.
    normalize_inputs: Whether to normalize the inputs. This is useful when the
      raw inputs are of extremely large or small magnitudes which may result in
      numerical issues.
    window_size: The window size for decomposed forecasting.
      TODO(siriuz42): implement it.
    per_core_batch_size: The batch size per core. Used at inference time during
      batched inference when multiple GPU / TPU devices are used.
    use_continuous_quantile_head: Whether to use a separate continuous quantile
      head to avoid quantile collapsing.
    force_flip_invariance: Whether to force flip invariance. TimesFM guarantees
      that TimesFM(a * x + b) = a * TimesFM(x) + b for a >= 0 by default. This
      flag extends it to a < 0 as well.
    infer_is_positive: Whether to guarantee nonnegativity of the output if the
      input is nonnegative.
    fix_quantile_crossing: Whether to fix quantile crossing.
    return_backcast: Whether to return backcast.
  """

  max_context: int = 0
  max_horizon: int = 0
  normalize_inputs: bool = False
  window_size: int = 0
  per_core_batch_size: int = 1
  use_continuous_quantile_head: bool = False
  force_flip_invariance: bool = True
  infer_is_positive: bool = True
  fix_quantile_crossing: bool = False
  return_backcast: bool = False
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@dataclasses.dataclass(frozen=True)
class ResidualBlockConfig:
  """Immutable, framework-neutral settings for a residual MLP block.

  The Flax `ResidualBlock` builds a hidden projection, an output projection
  and a linear skip projection from these values.
  """

  # Width of the features fed into the block.
  input_dims: int
  # Width of the intermediate (hidden) projection.
  hidden_dims: int
  # Width of the block's output.
  output_dims: int
  # Whether every linear projection carries a bias term.
  use_bias: bool
  # Nonlinearity applied to the hidden projection.
  activation: Literal["relu", "swish", "none"]
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@dataclasses.dataclass(frozen=True)
class RandomFourierFeaturesConfig:
  """Immutable, framework-neutral settings for a random Fourier features layer.

  The Flax `RandomFourierFeatures` layer requires `output_dims` to be a
  multiple of 4.
  """

  # Width of the features fed into the layer.
  input_dims: int
  # Width of the layer's output.
  output_dims: int
  # NOTE(review): presumably the stddev of the random projection; the Flax
  # layer in this package view does not read it.
  projection_stddev: float
  # Whether the linear projections carry a bias term.
  use_bias: bool
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@dataclasses.dataclass(frozen=True)
class TransformerConfig:
  """Immutable, framework-neutral settings for a single transformer layer."""

  # Residual-stream width (norm size and attention input/output width).
  model_dims: int
  # Inner width of the feed-forward sub-layer.
  hidden_dims: int
  # Number of attention heads; must divide model_dims.
  num_heads: int
  # Norm around the attention sub-layer (the Flax Transformer accepts "rms").
  attention_norm: Literal["rms"]
  # Norm around the feed-forward sub-layer (the Flax Transformer accepts "rms").
  feedforward_norm: Literal["rms"]
  # "rms" adds per-head RMS norm to queries and keys; "none" disables it.
  qk_norm: Literal["rms", "none"]
  # Bias in linear layers; the Flax Transformer applies it to the
  # feed-forward layers.
  use_bias: bool
  # Whether attention applies rotary position embeddings.
  use_rotary_position_embeddings: bool
  # Nonlinearity between the two feed-forward layers.
  ff_activation: Literal["relu", "swish", "none"]
  # NOTE(review): not read by the Flax Transformer in this package view;
  # presumably consumed by another backend.
  fuse_qkv: bool
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@dataclasses.dataclass(frozen=True)
class StackedTransformersConfig:
  """Immutable, framework-neutral settings for a stack of transformer layers."""

  # How many transformer layers are stacked.
  num_layers: int
  # Settings shared by every layer in the stack.
  transformer: TransformerConfig
|
timesfm/flax/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
timesfm/flax/dense.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Dense layers for TimesFM."""
|
|
16
|
+
|
|
17
|
+
from flax import nnx
|
|
18
|
+
import jax
|
|
19
|
+
import jax.numpy as jnp
|
|
20
|
+
import jaxtyping
|
|
21
|
+
|
|
22
|
+
from .. import configs
|
|
23
|
+
|
|
24
|
+
# Short names for the jaxtyping helpers used in annotations below.
Array, Bool, Float, Integer, Num = (
    jaxtyping.Array,
    jaxtyping.Bool,
    jaxtyping.Float,
    jaxtyping.Integer,
    jaxtyping.Num,
)

# Framework-neutral configs consumed by the layers in this module.
ResidualBlockConfig, RandomFourierFeaturesConfig = (
    configs.ResidualBlockConfig,
    configs.RandomFourierFeaturesConfig,
)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class ResidualBlock(nnx.Module):
|
|
35
|
+
"""Residual block with two linear layers and a linear residual connection."""
|
|
36
|
+
|
|
37
|
+
def __init__(self, config: ResidualBlockConfig, *, rngs=nnx.Rngs(42)):
|
|
38
|
+
self.config = config
|
|
39
|
+
self.hidden_layer = nnx.Linear(
|
|
40
|
+
in_features=config.input_dims,
|
|
41
|
+
out_features=config.hidden_dims,
|
|
42
|
+
use_bias=config.use_bias,
|
|
43
|
+
rngs=rngs,
|
|
44
|
+
)
|
|
45
|
+
self.output_layer = nnx.Linear(
|
|
46
|
+
in_features=config.hidden_dims,
|
|
47
|
+
out_features=config.output_dims,
|
|
48
|
+
use_bias=config.use_bias,
|
|
49
|
+
rngs=rngs,
|
|
50
|
+
)
|
|
51
|
+
self.residual_layer = nnx.Linear(
|
|
52
|
+
in_features=config.input_dims,
|
|
53
|
+
out_features=config.output_dims,
|
|
54
|
+
use_bias=config.use_bias,
|
|
55
|
+
rngs=rngs,
|
|
56
|
+
)
|
|
57
|
+
if config.activation == "relu":
|
|
58
|
+
self.activation = jax.nn.relu
|
|
59
|
+
elif config.activation == "swish":
|
|
60
|
+
self.activation = jax.nn.swish
|
|
61
|
+
elif config.activation == "none":
|
|
62
|
+
self.activation = lambda x: x
|
|
63
|
+
else:
|
|
64
|
+
raise ValueError(f"Activation: {config.activation} not supported.")
|
|
65
|
+
|
|
66
|
+
def __call__(self, x: Float[Array, "b ... i"]) -> Float[Array, "b ... o"]:
|
|
67
|
+
return self.output_layer(
|
|
68
|
+
self.activation(self.hidden_layer(x))
|
|
69
|
+
) + self.residual_layer(x)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class RandomFourierFeatures(nnx.Module):
  """Random Fourier features layer.

  Projects the input to `p` with `output_dims // 4` features and returns
  `concat([cos(p), sin(p), sign(sin(p + s0)), sign(sin(p + s1))], -1)` plus a
  linear residual projection of the input, where `s0` and `s1` are the two
  rows of `phase_shifts`.
  """

  # Every name listed here must match an attribute assigned in __init__.
  __data__ = ("phase_shifts",)

  def __init__(self, config: RandomFourierFeaturesConfig, *, rngs=nnx.Rngs(42)):
    """Builds the projection, residual layer and phase-shift parameter.

    Args:
      config: Input/output widths and bias flag. `projection_stddev` is not
        read here.
      rngs: RNG streams for parameter initialization. NOTE: the default is a
        single `nnx.Rngs` evaluated once at definition time and shared by
        every call that omits this argument.

    Raises:
      ValueError: If `config.output_dims` is not a multiple of 4.
    """
    self.config = config

    # Four feature families (cos, sin, two square waves) share the output.
    if config.output_dims % 4 != 0:
      raise ValueError(
          f"Output dims must be a multiple of 4: {config.output_dims} % 4 != 0."
      )
    num_projected_features = config.output_dims // 4

    # NOTE(review): zero-initialized; presumably overwritten from a checkpoint.
    self.phase_shifts = nnx.Param(jnp.zeros(shape=(2, num_projected_features)))
    self.projection_layer = nnx.Linear(
        in_features=config.input_dims,
        out_features=num_projected_features,
        use_bias=config.use_bias,
        rngs=rngs,
    )
    self.residual_layer = nnx.Linear(
        in_features=config.input_dims,
        out_features=config.output_dims,
        use_bias=config.use_bias,
        rngs=rngs,
    )

  def __call__(self, x: Float[Array, "b ... i"]) -> Float[Array, "b ... o"]:
    """Maps the last axis from `input_dims` to `output_dims`."""
    projected = self.projection_layer(x)
    cos_features = jnp.cos(projected)
    sin_features = jnp.sin(projected)
    sq_wave_1 = jnp.sign(jnp.sin(projected + self.phase_shifts[0, :]))
    sq_wave_2 = jnp.sign(jnp.sin(projected + self.phase_shifts[1, :]))
    fourier_features = jnp.concatenate(
        [cos_features, sin_features, sq_wave_1, sq_wave_2], axis=-1
    )
    residual = self.residual_layer(x)
    return fourier_features + residual
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Normalization layers for TimesFM."""
|
|
16
|
+
|
|
17
|
+
from flax import nnx
|
|
18
|
+
import jax
|
|
19
|
+
import jax.numpy as jnp
|
|
20
|
+
import jaxtyping
|
|
21
|
+
|
|
22
|
+
# Short names for the jaxtyping helpers used in annotations below.
Array, Bool, Float, Integer, Num = (
    jaxtyping.Array,
    jaxtyping.Bool,
    jaxtyping.Float,
    jaxtyping.Integer,
    jaxtyping.Num,
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class RMSNorm(nnx.Module):
  """Root-mean-square normalization with a learned per-feature scale.

  Divides the input by sqrt(mean(x ** 2) + epsilon) over the last axis, then
  multiplies by `scale`. There is no mean subtraction and no bias.
  """

  __data__ = ("scale",)

  def __init__(
      self,
      num_features: int,
      *,
      epsilon: float = 1e-6,
      rngs=nnx.Rngs(42),
  ):
    # Initialization is deterministic, so the RNG streams are unused.
    del rngs
    # NOTE(review): scale starts at zero, so an un-loaded layer outputs zeros;
    # presumably real values come from a checkpoint.
    self.scale = nnx.Param(jnp.zeros(shape=(num_features,)))
    self.num_features = num_features
    self.epsilon = epsilon

  def __call__(self, inputs: Float[Array, "b ... d"]) -> Float[Array, "b ... d"]:
    """Normalizes `inputs` over the last axis."""
    mean_square = jnp.mean(jnp.square(inputs), axis=-1, keepdims=True)
    inverse_rms = jax.lax.rsqrt(mean_square + self.epsilon)
    return (inputs * inverse_rms) * self.scale
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class LayerNorm(nnx.Module):
  """Standard layer normalization over the last axis.

  Subtracts the mean, divides by sqrt(variance + epsilon), then applies a
  learned per-feature `scale` (initialized to ones) and `bias` (initialized
  to zeros).
  """

  __data__ = ("scale", "bias")

  def __init__(self, num_features: int, *, epsilon: float = 1e-6, rngs=nnx.Rngs(42)):
    # Initialization is deterministic, so the RNG streams are unused.
    del rngs
    self.scale = nnx.Param(jnp.ones(shape=(num_features,)))
    self.bias = nnx.Param(jnp.zeros(shape=(num_features,)))
    self.num_features = num_features
    self.epsilon = epsilon

  def __call__(self, inputs: Float[Array, "b ... d"]) -> Float[Array, "b ... d"]:
    """Normalizes `inputs` over the last axis."""
    centered = inputs - jnp.mean(inputs, axis=-1, keepdims=True)
    variance = jnp.mean(jnp.square(centered), axis=-1, keepdims=True)
    normalized = centered * jax.lax.rsqrt(variance + self.epsilon)
    return normalized * self.scale + self.bias
|
|
@@ -0,0 +1,356 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Transformer layers for TimesFM."""
|
|
16
|
+
|
|
17
|
+
import functools
|
|
18
|
+
from typing import Callable
|
|
19
|
+
|
|
20
|
+
from flax import nnx
|
|
21
|
+
from flax.nnx.nn import linear
|
|
22
|
+
import jax
|
|
23
|
+
from jax import lax
|
|
24
|
+
import jax.numpy as jnp
|
|
25
|
+
import jaxtyping
|
|
26
|
+
|
|
27
|
+
from .. import configs
|
|
28
|
+
from . import normalization, util
|
|
29
|
+
|
|
30
|
+
# Short names for the jaxtyping helpers used in annotations below.
Array, Bool, Float, Integer, Num = (
    jaxtyping.Array,
    jaxtyping.Bool,
    jaxtyping.Float,
    jaxtyping.Integer,
    jaxtyping.Num,
)

# Layers, configs and cache types defined in sibling modules.
LayerNorm, RMSNorm = normalization.LayerNorm, normalization.RMSNorm
LinearGeneral = linear.LinearGeneral
TransformerConfig = configs.TransformerConfig
DecodeCache = util.DecodeCache
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@functools.partial(jax.jit, static_argnames=("query_length", "kv_length"))
def make_attn_mask(
    query_length: int,
    num_all_masked_kv: Integer[Array, "b"],
    query_index_offset: Integer[Array, "b"] | None = None,
    kv_length: int = 0,
) -> Bool[Array, "b 1 q n"]:
  """Builds a boolean attention mask where True means "may attend".

  A query at position q may attend to key/value slot k when k <= q (no
  attending to the future) and k >= num_all_masked_kv (skip the leading
  masked slots).

  Args:
    query_length: Number of query positions (static under jit).
    num_all_masked_kv: Per-batch count of leading key/value slots to mask.
    query_index_offset: Optional per-batch offset added to query positions,
      e.g. the decode cache's write index.
    kv_length: Number of key/value slots (static). 0 means `query_length`.

  Returns:
    Boolean mask of shape (b, 1, query_length, kv_length).
  """
  num_slots = kv_length or query_length

  query_pos = jnp.arange(query_length)[None, None, :, None]
  if query_index_offset is not None:
    query_pos = query_pos + query_index_offset[:, None, None, None]
  slot_pos = jnp.arange(num_slots)[None, None, None, :]

  not_future = slot_pos <= query_pos
  not_masked = slot_pos >= num_all_masked_kv[:, None, None, None]
  return not_future & not_masked
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class RotaryPositionalEmbedding(nnx.Module):
  """Rotary positional embedding (RoPE).

  Splits the last axis into two halves and rotates each pair
  (first_half[i], second_half[i]) by the angle `position / timescale[i]`.
  The timescales are spaced geometrically from `min_timescale` toward
  `max_timescale`.
  """

  def __init__(
      self,
      embedding_dims: int,
      min_timescale: int = 1,
      max_timescale: int = 10000,
  ):
    self.embedding_dims = embedding_dims
    self.min_timescale = min_timescale
    self.max_timescale = max_timescale

  def __call__(
      self,
      inputs: Float[Array, "b ... d"],
      position: Array | None = None,
  ):
    """Applies the rotary embedding to `inputs`.

    Args:
      inputs: Rank-3 or rank-4 array whose last axis equals `embedding_dims`.
      position: Optional positions with shape (batch, seq). Defaults to
        0..seq-1 along axis 1 of `inputs`, with shape (1, seq).

    Returns:
      Array with the same shape as `inputs`.

    Raises:
      ValueError: If the last axis of `inputs` is not `embedding_dims`, or if
        `inputs` is not rank 3 or 4.
    """
    if self.embedding_dims != inputs.shape[-1]:
      raise ValueError(
          "The embedding dims of the rotary position embedding "
          "must match the hidden dimension of the inputs."
      )
    half_embedding_dim = self.embedding_dims // 2
    fraction = 2 * jnp.arange(0, half_embedding_dim) / self.embedding_dims
    timescale = (
        self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
    )
    if position is None:
      seq_length = inputs.shape[1]
      position = jnp.arange(seq_length, dtype=jnp.float32)[None, :]
    # Add trailing singleton axes so positions broadcast against the
    # per-frequency timescales along the last axis.
    if len(inputs.shape) == 4:
      position = position[..., None, None]
      timescale = timescale[None, None, None, :]
    elif len(inputs.shape) == 3:
      position = position[..., None]
      timescale = timescale[None, None, :]
    else:
      raise ValueError("Inputs must be of rank 3 or 4.")
    sinusoid_inp = position / timescale
    sin = jnp.sin(sinusoid_inp)
    cos = jnp.cos(sinusoid_inp)
    first_half, second_half = jnp.split(inputs, 2, axis=-1)
    # 2-D rotation of each (first_half[i], second_half[i]) pair.
    first_part = first_half * cos - second_half * sin
    second_part = second_half * cos + first_half * sin
    # NOTE(review): astype(None) presumably casts to JAX's default float
    # dtype; confirm this is intended rather than keeping inputs.dtype.
    first_part = first_part.astype(None)
    second_part = second_part.astype(None)
    return jnp.concatenate([first_part, second_part], axis=-1)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class PerDimScale(nnx.Module):
  """Learned per-dimension scaling, applied to queries in attention.

  Multiplies the last axis by `softplus(per_dim_scale) / (ln(2) * sqrt(d))`.
  Since softplus(0) == ln(2), the zero initialization yields approximately a
  plain 1 / sqrt(d) scale.
  """

  __data__ = ("per_dim_scale",)

  def __init__(self, num_dims: int, *, rngs=nnx.Rngs(42)):
    # Zero initialization needs no randomness.
    del rngs
    self.num_dims = num_dims
    self.per_dim_scale = nnx.Param(jnp.zeros(shape=(num_dims,)))

  def __call__(self, x: Float[Array, "b ... d"]) -> Float[Array, "b ... d"]:
    """Scales the last axis of `x`."""
    # 1.442695041 ~= 1 / ln(2), cancelling softplus(0) at initialization.
    inv_sqrt_dim = 1.442695041 / jnp.sqrt(self.num_dims)
    return x * (inv_sqrt_dim * jax.nn.softplus(self.per_dim_scale))
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class MultiHeadAttention(nnx.Module):
  """Multi-head self-attention with optional RoPE, QK RMS norm and KV cache.

  Query, key and value are all projected from the same input. The query goes
  through rotary embedding -> RMS norm -> per-dimension scale. The key goes
  through rotary embedding -> RMS norm. Each stage is optional.
  """

  def __init__(
      self,
      num_heads: int,
      in_features: int,
      *,
      use_per_dim_scale: bool = True,
      use_rotary_position_embeddings: bool = True,
      use_bias: bool = False,
      deterministic: bool | None = None,
      attention_fn: Callable[..., Array] = nnx.dot_product_attention,
      qk_norm: str = "rms",
      rngs=nnx.Rngs(42),
  ):
    """Builds the projections and optional sub-modules.

    Args:
      num_heads: Number of attention heads; must divide `in_features`.
      in_features: Model width; also used as the q/k/v and output width.
      use_per_dim_scale: Whether to apply `PerDimScale` to the query.
      use_rotary_position_embeddings: Whether to rotate query and key.
      use_bias: Whether the q/k/v/output projections have a bias.
      deterministic: Stored on the module; `__call__` does not read it.
      attention_fn: Attention kernel called with query, key, value, mask,
        deterministic and module.
      qk_norm: "rms" enables per-head RMS norm on query and key; any other
        value disables it.
      rngs: RNG streams for parameter initialization. NOTE: the default is a
        single `nnx.Rngs` evaluated once at definition time and shared by
        every call that omits this argument.

    Raises:
      ValueError: If `in_features` is not divisible by `num_heads`.
    """
    self.num_heads = num_heads
    self.in_features = in_features
    self.qkv_features = in_features
    self.out_features = in_features
    self.in_kv_features = in_features
    self.deterministic = deterministic
    self.use_bias = use_bias
    self.attention_fn = attention_fn
    self.qk_norm = qk_norm

    if self.qkv_features % self.num_heads != 0:
      raise ValueError(
          f"Memory dimension ({self.qkv_features}) must be divisible by "
          f"'num_heads' heads ({self.num_heads})."
      )
    self.head_dim = self.qkv_features // self.num_heads

    linear_general = functools.partial(
        LinearGeneral,
        out_features=(self.num_heads, self.head_dim),
        use_bias=self.use_bias,
    )
    # project inputs_q to multi-headed q/k/v
    # dimensions are then [batch..., length, n_heads, n_features_per_head]
    self.query = linear_general(self.in_features, rngs=rngs)
    self.key = linear_general(self.in_kv_features, rngs=rngs)
    self.value = linear_general(self.in_kv_features, rngs=rngs)

    if self.qk_norm == "rms":
      self.query_ln = RMSNorm(self.head_dim)
      self.key_ln = RMSNorm(self.head_dim)
    else:
      self.query_ln = None
      self.key_ln = None

    self.out = LinearGeneral(
        in_features=(self.num_heads, self.head_dim),
        out_features=self.out_features,
        axis=(-2, -1),
        use_bias=self.use_bias,
        rngs=rngs,
    )

    self.use_per_dim_scale = use_per_dim_scale
    self.use_rotary_position_embeddings = use_rotary_position_embeddings
    if self.use_rotary_position_embeddings:
      self.rotary_position_embedding = RotaryPositionalEmbedding(
          embedding_dims=self.head_dim,
      )
    else:
      self.rotary_position_embedding = None

    if use_per_dim_scale:
      self.per_dim_scale = PerDimScale(num_dims=self.head_dim, rngs=rngs)
    else:
      self.per_dim_scale = None

  def __call__(
      self,
      inputs_q: Array,
      *,
      decode_cache: DecodeCache | None = None,
      patch_mask: Array | None = None,
      deterministic: bool | None = None,
      sow_weights: bool = False,
  ) -> tuple[Float[Array, "b ... o"], DecodeCache | None]:
    """Applies multi-head dot product attention on the input data.

    Args:
      inputs_q: Input of shape (batch, n_patches, in_features).
      decode_cache: Optional cache for incremental decoding. When given, the
        new keys/values are written at `decode_cache.next_index[0]`, and the
        cache's `key`, `value`, `next_index` and `num_masked` are updated in
        place.
      patch_mask: Optional boolean mask of shape (batch, n_patches). The count
        of True entries per row is used as the number of masked patches, and
        the attention mask excludes that many leading key/value slots.
        Defaults to no masking.
      deterministic: Forwarded to `attention_fn`.
      sow_weights: If True, passes `self` as `module` to `attention_fn`.

    Returns:
      Tuple of the attention output (same shape as `inputs_q`) and the
      (possibly updated) decode cache.

    Raises:
      ValueError: If the last axis of `inputs_q` is not `in_features`.
    """
    _, n_patches, input_in_features = inputs_q.shape
    if input_in_features != self.in_features:
      raise ValueError(
          f"Incompatible input dimension, got {input_in_features} "
          f"but module expects {self.in_features}."
      )
    if patch_mask is None:
      # All-False mask of shape (batch, n_patches). `zeros_like` applied to
      # the shape tuple would instead give a length-2 vector, breaking the
      # per-batch indexing below.
      patch_mask = jnp.zeros(inputs_q.shape[:-1], dtype=jnp.bool_)

    # For query: rope -> ln -> per_dim_scale
    query = self.query(inputs_q)
    key = self.key(inputs_q)
    value = self.value(inputs_q)

    if decode_cache is None:
      num_masked = jnp.sum(patch_mask.astype(jnp.int32), axis=-1, keepdims=False)
      next_index = jnp.zeros_like(num_masked, dtype=jnp.int32)
    else:
      num_masked = (
          jnp.sum(patch_mask.astype(jnp.int32), axis=-1, keepdims=False)
          + decode_cache.num_masked
      )
      next_index = decode_cache.next_index

    if self.use_rotary_position_embeddings:
      # Positions count from the first unmasked patch and continue from the
      # cache's write index when decoding.
      position = (
          jnp.arange(n_patches, dtype=jnp.int32)[None, :]
          + next_index[:, None]
          - num_masked[:, None]
      )
      query = self.rotary_position_embedding(query, position)
      key = self.rotary_position_embedding(key, position)
    if self.query_ln is not None:
      query = self.query_ln(query)
    if self.key_ln is not None:
      key = self.key_ln(key)
    if self.use_per_dim_scale:
      query = self.per_dim_scale(query)

    if decode_cache is not None:
      # Cached decoding.
      _, decode_cache_size, _, _ = decode_cache.value.shape
      zero = jnp.array(0, dtype=lax.dtype(next_index.dtype))
      # NOTE(review): uses batch row 0's index for every row, so all rows are
      # assumed to share the same write position.
      start_indices = (zero, next_index[0], zero, zero)
      key = lax.dynamic_update_slice(decode_cache.key, key, start_indices)
      value = lax.dynamic_update_slice(decode_cache.value, value, start_indices)
      decode_cache.key = key
      decode_cache.value = value
      decode_cache.next_index = next_index + n_patches
      decode_cache.num_masked = num_masked
      attn_mask = make_attn_mask(
          query_length=n_patches,
          num_all_masked_kv=num_masked,
          query_index_offset=next_index,
          kv_length=decode_cache_size,
      )
    else:
      # Training
      attn_mask = make_attn_mask(query_length=n_patches, num_all_masked_kv=num_masked)

    # apply attention
    # NOTE(review): multiplying by sqrt(head_dim) presumably undoes the
    # 1/sqrt(depth) scaling inside `nnx.dot_product_attention`, since
    # PerDimScale already scales the query.
    x = self.attention_fn(
        query * jnp.sqrt(self.head_dim),
        key,
        value,
        mask=attn_mask,
        deterministic=deterministic,
        module=self if sow_weights else None,
    )
    # back to the original inputs dimensions
    out = self.out(x)
    return out, decode_cache
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
class Transformer(nnx.Module):
|
|
292
|
+
"""Classic Transformer used in TimesFM."""
|
|
293
|
+
|
|
294
|
+
def __init__(self, config: TransformerConfig, *, rngs=nnx.Rngs(42)):
|
|
295
|
+
self.config = config
|
|
296
|
+
|
|
297
|
+
if config.attention_norm == "rms":
|
|
298
|
+
self.pre_attn_ln = RMSNorm(num_features=config.model_dims, rngs=rngs)
|
|
299
|
+
self.post_attn_ln = RMSNorm(num_features=config.model_dims, rngs=rngs)
|
|
300
|
+
else:
|
|
301
|
+
raise ValueError(f"Layer norm: {config.attention_norm} not supported.")
|
|
302
|
+
|
|
303
|
+
self.attn = MultiHeadAttention(
|
|
304
|
+
num_heads=config.num_heads,
|
|
305
|
+
in_features=config.model_dims,
|
|
306
|
+
use_per_dim_scale=True,
|
|
307
|
+
use_rotary_position_embeddings=config.use_rotary_position_embeddings,
|
|
308
|
+
qk_norm=config.qk_norm,
|
|
309
|
+
rngs=rngs,
|
|
310
|
+
)
|
|
311
|
+
|
|
312
|
+
if config.feedforward_norm == "rms":
|
|
313
|
+
self.pre_ff_ln = RMSNorm(num_features=config.model_dims, rngs=rngs)
|
|
314
|
+
self.post_ff_ln = RMSNorm(num_features=config.model_dims, rngs=rngs)
|
|
315
|
+
else:
|
|
316
|
+
raise ValueError(f"Layer norm: {config.feedforward_norm} not supported.")
|
|
317
|
+
self.ff0 = nnx.Linear(
|
|
318
|
+
in_features=config.model_dims,
|
|
319
|
+
out_features=config.hidden_dims,
|
|
320
|
+
use_bias=config.use_bias,
|
|
321
|
+
rngs=rngs,
|
|
322
|
+
)
|
|
323
|
+
self.ff1 = nnx.Linear(
|
|
324
|
+
in_features=config.hidden_dims,
|
|
325
|
+
out_features=config.model_dims,
|
|
326
|
+
use_bias=config.use_bias,
|
|
327
|
+
rngs=rngs,
|
|
328
|
+
)
|
|
329
|
+
if config.ff_activation == "relu":
|
|
330
|
+
self.activation = jax.nn.relu
|
|
331
|
+
elif config.ff_activation == "swish":
|
|
332
|
+
self.activation = jax.nn.swish
|
|
333
|
+
elif config.ff_activation == "none":
|
|
334
|
+
self.activation = lambda x: x
|
|
335
|
+
else:
|
|
336
|
+
raise ValueError(f"Activation: {config.ff_activation} not supported.")
|
|
337
|
+
|
|
338
|
+
def __call__(
|
|
339
|
+
self,
|
|
340
|
+
input_embeddings: Float[Array, "b n d"],
|
|
341
|
+
patch_mask: Bool[Array, "b n"],
|
|
342
|
+
decode_cache: DecodeCache | None = None,
|
|
343
|
+
) -> tuple[Float[Array, "b n d"], DecodeCache | None]:
|
|
344
|
+
attn_output, decode_cache = self.attn(
|
|
345
|
+
inputs_q=self.pre_attn_ln(input_embeddings),
|
|
346
|
+
decode_cache=decode_cache,
|
|
347
|
+
patch_mask=patch_mask,
|
|
348
|
+
sow_weights=False,
|
|
349
|
+
deterministic=True,
|
|
350
|
+
)
|
|
351
|
+
attn_output = self.post_attn_ln(attn_output) + input_embeddings
|
|
352
|
+
output_embeddings = (
|
|
353
|
+
self.post_ff_ln(self.ff1(self.activation(self.ff0(self.pre_ff_ln(attn_output)))))
|
|
354
|
+
+ attn_output
|
|
355
|
+
)
|
|
356
|
+
return output_embeddings, decode_cache
|