tsagentkit-timesfm 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- timesfm/__init__.py +29 -0
- timesfm/configs.py +105 -0
- timesfm/flax/__init__.py +13 -0
- timesfm/flax/dense.py +110 -0
- timesfm/flax/normalization.py +71 -0
- timesfm/flax/transformer.py +356 -0
- timesfm/flax/util.py +107 -0
- timesfm/timesfm_2p5/timesfm_2p5_base.py +422 -0
- timesfm/timesfm_2p5/timesfm_2p5_flax.py +602 -0
- timesfm/timesfm_2p5/timesfm_2p5_torch.py +472 -0
- timesfm/torch/__init__.py +13 -0
- timesfm/torch/dense.py +94 -0
- timesfm/torch/normalization.py +39 -0
- timesfm/torch/transformer.py +370 -0
- timesfm/torch/util.py +94 -0
- timesfm/utils/xreg_lib.py +520 -0
- tsagentkit_timesfm-1.0.0.dist-info/METADATA +152 -0
- tsagentkit_timesfm-1.0.0.dist-info/RECORD +21 -0
- tsagentkit_timesfm-1.0.0.dist-info/WHEEL +5 -0
- tsagentkit_timesfm-1.0.0.dist-info/licenses/LICENSE +202 -0
- tsagentkit_timesfm-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,370 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Transformer layers for TimesFM."""
|
|
16
|
+
|
|
17
|
+
import math
|
|
18
|
+
from typing import Callable
|
|
19
|
+
|
|
20
|
+
import torch
|
|
21
|
+
import torch.nn.functional as F
|
|
22
|
+
from torch import nn
|
|
23
|
+
|
|
24
|
+
from .. import configs
|
|
25
|
+
from . import normalization, util
|
|
26
|
+
|
|
27
|
+
# Module-level aliases so the normalization layers and the decode-cache type
# can be referenced through this module.
LayerNorm, RMSNorm, DecodeCache = (
    nn.LayerNorm,
    normalization.RMSNorm,
    util.DecodeCache,
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def make_attn_mask(
|
|
33
|
+
query_length: int,
|
|
34
|
+
num_all_masked_kv: torch.Tensor,
|
|
35
|
+
query_index_offset: torch.Tensor | None = None,
|
|
36
|
+
kv_length: int = 0,
|
|
37
|
+
) -> torch.Tensor:
|
|
38
|
+
"""Makes attention mask."""
|
|
39
|
+
if kv_length == 0:
|
|
40
|
+
kv_length = query_length
|
|
41
|
+
|
|
42
|
+
q_index = torch.arange(query_length, device=num_all_masked_kv.device)[
|
|
43
|
+
None, None, :, None
|
|
44
|
+
]
|
|
45
|
+
if query_index_offset is not None:
|
|
46
|
+
q_index = q_index + query_index_offset[:, None, None, None]
|
|
47
|
+
kv_index = torch.arange(kv_length, device=num_all_masked_kv.device)[
|
|
48
|
+
None, None, None, :
|
|
49
|
+
]
|
|
50
|
+
return torch.logical_and(
|
|
51
|
+
q_index >= kv_index,
|
|
52
|
+
kv_index >= num_all_masked_kv[:, None, None, None],
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class RotaryPositionalEmbedding(nn.Module):
|
|
57
|
+
"""Rotary positional embedding."""
|
|
58
|
+
|
|
59
|
+
def __init__(
|
|
60
|
+
self,
|
|
61
|
+
embedding_dims: int,
|
|
62
|
+
min_timescale: float = 1.0,
|
|
63
|
+
max_timescale: float = 10000.0,
|
|
64
|
+
):
|
|
65
|
+
super().__init__()
|
|
66
|
+
self.embedding_dims = embedding_dims
|
|
67
|
+
self.min_timescale = min_timescale
|
|
68
|
+
self.max_timescale = max_timescale
|
|
69
|
+
|
|
70
|
+
def forward(
|
|
71
|
+
self,
|
|
72
|
+
inputs: torch.Tensor,
|
|
73
|
+
position: torch.Tensor | None = None,
|
|
74
|
+
):
|
|
75
|
+
"""Generates a JTensor of sinusoids with different frequencies."""
|
|
76
|
+
if self.embedding_dims != inputs.shape[-1]:
|
|
77
|
+
raise ValueError(
|
|
78
|
+
"The embedding dims of the rotary position embedding"
|
|
79
|
+
"must match the hidden dimension of the inputs."
|
|
80
|
+
)
|
|
81
|
+
half_embedding_dim = self.embedding_dims // 2
|
|
82
|
+
fraction = (
|
|
83
|
+
2
|
|
84
|
+
* torch.arange(0, half_embedding_dim, device=inputs.device)
|
|
85
|
+
/ self.embedding_dims
|
|
86
|
+
)
|
|
87
|
+
timescale = (
|
|
88
|
+
self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
|
|
89
|
+
).to(inputs.device)
|
|
90
|
+
if position is None:
|
|
91
|
+
seq_length = inputs.shape[1]
|
|
92
|
+
position = torch.arange(seq_length, dtype=torch.float32, device=inputs.device)[
|
|
93
|
+
None, :
|
|
94
|
+
]
|
|
95
|
+
|
|
96
|
+
if len(inputs.shape) == 4:
|
|
97
|
+
position = position[..., None, None]
|
|
98
|
+
timescale = timescale[None, None, None, :]
|
|
99
|
+
elif len(inputs.shape) == 3:
|
|
100
|
+
position = position[..., None]
|
|
101
|
+
timescale = timescale[None, None, :]
|
|
102
|
+
else:
|
|
103
|
+
raise ValueError("Inputs must be of rank 3 or 4.")
|
|
104
|
+
|
|
105
|
+
sinusoid_inp = position / timescale
|
|
106
|
+
sin = torch.sin(sinusoid_inp)
|
|
107
|
+
cos = torch.cos(sinusoid_inp)
|
|
108
|
+
first_half, second_half = torch.chunk(inputs, 2, dim=-1)
|
|
109
|
+
first_part = first_half * cos - second_half * sin
|
|
110
|
+
second_part = second_half * cos + first_half * sin
|
|
111
|
+
return torch.cat([first_part, second_part], dim=-1)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _dot_product_attention(
|
|
115
|
+
query,
|
|
116
|
+
key,
|
|
117
|
+
value,
|
|
118
|
+
mask=None,
|
|
119
|
+
):
|
|
120
|
+
"""Computes dot-product attention given query, key, and value."""
|
|
121
|
+
attn_weights = torch.einsum("...qhd,...khd->...hqk", query, key)
|
|
122
|
+
if mask is not None:
|
|
123
|
+
attn_weights = torch.where(
|
|
124
|
+
mask, attn_weights, -torch.finfo(attn_weights.dtype).max / 2
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
attn_weights = F.softmax(attn_weights, dim=-1)
|
|
128
|
+
|
|
129
|
+
return torch.einsum("...hqk,...khd->...qhd", attn_weights, value)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _torch_dot_product_attention(query, key, value, mask=None):
|
|
133
|
+
"""
|
|
134
|
+
Performs the exact same (unscaled) attention as the above function,
|
|
135
|
+
but using the fast and fused F.scaled_dot_product_attention kernel.
|
|
136
|
+
"""
|
|
137
|
+
|
|
138
|
+
# 1. Permute inputs from (B, L, H, D) to the expected (B, H, L, D)
|
|
139
|
+
query = query.permute(0, 2, 1, 3)
|
|
140
|
+
key = key.permute(0, 2, 1, 3)
|
|
141
|
+
value = value.permute(0, 2, 1, 3)
|
|
142
|
+
|
|
143
|
+
# 2. Call the fused attention kernel
|
|
144
|
+
# - Pass the mask to `attn_mask`.
|
|
145
|
+
# - Set `scale=1.0` to disable the default 1/sqrt(d_k) scaling.
|
|
146
|
+
output = F.scaled_dot_product_attention(query, key, value, attn_mask=mask, scale=1.0)
|
|
147
|
+
|
|
148
|
+
# 3. Permute the output back to the original (B, L, H, D) layout
|
|
149
|
+
output = output.permute(0, 2, 1, 3)
|
|
150
|
+
|
|
151
|
+
return output
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class PerDimScale(nn.Module):
    """Learnable per-dimension scaling, applied along the last axis.

    The factor is ``1.442695041 * softplus(per_dim_scale) / sqrt(num_dims)``.
    Since ``1.442695041`` is approximately ``1 / ln(2)`` and
    ``softplus(0) == ln(2)``, the zero-initialized parameter starts out as a
    plain ``1 / sqrt(num_dims)`` scale.
    """

    def __init__(self, num_dims: int):
        """Creates a zero-initialized scale parameter of length ``num_dims``."""
        super().__init__()
        self.num_dims = num_dims
        self.per_dim_scale = nn.Parameter(torch.zeros(num_dims))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Returns ``x`` multiplied by the per-dimension factor."""
        inv_ln2_over_sqrt_dims = 1.442695041 / math.sqrt(self.num_dims)
        return x * (inv_ln2_over_sqrt_dims * F.softplus(self.per_dim_scale))
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with optional RoPE, QK-norm and a decode cache.

    Steps:
      1. Project ``inputs_q`` to queries, keys and values, either with one
         fused ``qkv_proj`` linear or with separate ``query``/``key``/``value``
         linears.
      2. Optionally apply rotary position embeddings to queries and keys.
      3. Apply RMS normalization to queries and keys (``qk_norm == "rms"``).
      4. Optionally apply a learned per-dimension query scale.
      5. Run ``attention_fn`` under a mask that is causal and also hides
         masked (padded) positions.
      6. Project the result back with ``out``.

    Neither attention function in this module applies a 1/sqrt(head_dim)
    factor. Query scaling comes from ``PerDimScale`` when
    ``use_per_dim_scale`` is set.
    """

    def __init__(
        self,
        num_heads: int,
        in_features: int,
        *,
        use_per_dim_scale: bool = True,
        use_rotary_position_embeddings: bool = True,
        use_bias: bool = False,
        attention_fn: Callable[..., torch.Tensor] = _torch_dot_product_attention,
        qk_norm: str = "rms",
        fuse_qkv: bool = False,
    ):
        """Builds the projection, normalization and scaling sub-modules.

        Args:
          num_heads: Number of attention heads.
          in_features: Model width; must be divisible by ``num_heads``.
          use_per_dim_scale: Whether to scale queries with ``PerDimScale``.
          use_rotary_position_embeddings: Whether to apply RoPE to queries
            and keys.
          use_bias: Whether the projection linears have a bias.
          attention_fn: Callable ``(query, key, value, mask=...)`` that
            computes the attention.
          qk_norm: ``"rms"`` applies RMSNorm to queries and keys; any other
            value leaves them unnormalized.
          fuse_qkv: Use a single fused QKV projection instead of three linears.

        Raises:
          ValueError: If ``in_features`` is not divisible by ``num_heads``.
        """
        super().__init__()
        self.num_heads = num_heads
        self.in_features = in_features
        self.head_dim = in_features // num_heads
        self.use_bias = use_bias
        self.attention_fn = attention_fn
        self.qk_norm = qk_norm
        self.fuse_qkv = fuse_qkv

        if self.in_features % self.num_heads != 0:
            raise ValueError(
                f"Memory dimension ({self.in_features}) must be divisible by "
                f"'num_heads' heads ({self.num_heads})."
            )

        # Attribute names below double as state-dict keys; keep them stable.
        if self.fuse_qkv:
            self.qkv_proj = nn.Linear(self.in_features, 3 * self.in_features, bias=use_bias)
        else:
            self.query = nn.Linear(self.in_features, self.in_features, bias=use_bias)
            self.key = nn.Linear(self.in_features, self.in_features, bias=use_bias)
            self.value = nn.Linear(self.in_features, self.in_features, bias=use_bias)
        self.out = nn.Linear(self.in_features, self.in_features, bias=use_bias)

        if self.qk_norm == "rms":
            self.query_ln = RMSNorm(self.head_dim)
            self.key_ln = RMSNorm(self.head_dim)
        else:
            self.query_ln = nn.Identity()
            self.key_ln = nn.Identity()

        self.use_rotary_position_embeddings = use_rotary_position_embeddings
        if self.use_rotary_position_embeddings:
            self.rotary_position_embedding = RotaryPositionalEmbedding(
                embedding_dims=self.head_dim,
            )

        self.use_per_dim_scale = use_per_dim_scale
        if use_per_dim_scale:
            self.per_dim_scale = PerDimScale(num_dims=self.head_dim)

    def forward(
        self,
        inputs_q: torch.Tensor,
        *,
        decode_cache: DecodeCache | None = None,
        patch_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, DecodeCache | None]:
        """Runs self-attention over ``inputs_q``.

        Args:
          inputs_q: Input of shape ``(batch, n_patches, in_features)``.
          decode_cache: Optional cache, mutated in place when given:
            - the new keys and values are written at ``next_index[0]``;
            - attention runs over the whole cache;
            - ``next_index`` and ``num_masked`` are advanced.
          patch_mask: Optional boolean ``(batch, n_patches)`` tensor, ``True``
            at masked positions. Defaults to nothing masked.
            NOTE(review): the attention mask only hides the first
            ``num_masked`` key positions, so masked patches are presumably
            always leading (left padding). Confirm with callers.

        Returns:
          ``(output, decode_cache)``. ``output`` has shape
          ``(batch, n_patches, in_features)``. ``decode_cache`` is the object
          that was passed in (possibly mutated) or ``None``.
        """
        b, n_patches, _ = inputs_q.shape
        if patch_mask is None:
            patch_mask = torch.zeros(b, n_patches, dtype=torch.bool, device=inputs_q.device)

        # Project and split into heads: (b, n_patches, num_heads, head_dim).
        if self.fuse_qkv:
            qkv = self.qkv_proj(inputs_q)
            query, key, value = torch.chunk(qkv, 3, dim=-1)
            query = query.view(b, n_patches, self.num_heads, self.head_dim)
            key = key.view(b, n_patches, self.num_heads, self.head_dim)
            value = value.view(b, n_patches, self.num_heads, self.head_dim)
        else:
            query = self.query(inputs_q).view(b, n_patches, self.num_heads, self.head_dim)
            key = self.key(inputs_q).view(b, n_patches, self.num_heads, self.head_dim)
            value = self.value(inputs_q).view(b, n_patches, self.num_heads, self.head_dim)

        if decode_cache is None:
            num_masked = torch.sum(patch_mask.to(torch.int32), dim=-1)
            next_index = torch.zeros_like(num_masked, dtype=torch.int32)
        else:
            # Masked counts accumulate across decode steps.
            num_masked = (
                torch.sum(patch_mask.to(torch.int32), dim=-1) + decode_cache.num_masked
            )
            # Snapshot taken before the in-place `+=` further down, so the
            # positions and mask below use the pre-update write index.
            next_index = decode_cache.next_index.clone()

        if self.use_rotary_position_embeddings:
            # Absolute index minus the number of masked positions.
            # NOTE(review): with leading masked patches this puts the first
            # unmasked patch at position 0.
            position = (
                torch.arange(n_patches, device=inputs_q.device)[None, :]
                + next_index[:, None]
                - num_masked[:, None]
            )
            query = self.rotary_position_embedding(query, position)
            key = self.rotary_position_embedding(key, position)

        query = self.query_ln(query)
        key = self.key_ln(key)

        if self.use_per_dim_scale:
            query = self.per_dim_scale(query)

        if decode_cache is not None:
            _, decode_cache_size, _, _ = decode_cache.value.shape

            # Batch element 0's index is used for the whole batch, i.e. all
            # rows are assumed to be at the same decode step.
            start = decode_cache.next_index[0]
            end = start + n_patches

            # Perform a single, vectorized slice assignment for the entire batch.
            # This is vastly more efficient than a Python for-loop.

            decode_cache.key[:, start:end] = key
            decode_cache.value[:, start:end] = value

            # Attend over the full cache. Slots beyond `end` are excluded by
            # the causal condition of the mask.
            key = decode_cache.key
            value = decode_cache.value
            decode_cache.next_index += n_patches
            decode_cache.num_masked = num_masked
            attn_mask = make_attn_mask(
                query_length=n_patches,
                num_all_masked_kv=num_masked,
                query_index_offset=next_index,
                kv_length=decode_cache_size,
            )
        else:
            attn_mask = make_attn_mask(query_length=n_patches, num_all_masked_kv=num_masked)

        x = self.attention_fn(
            query,
            key,
            value,
            mask=attn_mask,
        )

        # Merge heads back into the model width and project.
        x = x.reshape(b, n_patches, self.in_features)
        out = self.out(x)
        return out, decode_cache
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
class Transformer(nn.Module):
    """A single TimesFM transformer layer.

    The layer has two sub-blocks: self-attention and a two-layer feed-forward
    network. Each sub-block:
      - normalizes its input with an RMSNorm,
      - normalizes its output with a second RMSNorm,
      - adds the result back onto its own input as a residual.
    """

    def __init__(self, config: configs.TransformerConfig):
        """Builds the layer from ``config``.

        Raises:
          ValueError: If ``config.attention_norm`` or ``config.feedforward_norm``
            is not ``"rms"``, or if ``config.ff_activation`` is not one of
            ``"relu"``, ``"swish"`` or ``"none"``.
        """
        super().__init__()
        self.config = config
        width = config.model_dims

        # Attention sub-block. Sub-modules are created in a fixed order, which
        # also fixes the order of parameter initialization.
        if config.attention_norm != "rms":
            raise ValueError(f"Layer norm: {config.attention_norm} not supported.")
        self.pre_attn_ln = RMSNorm(num_features=width)
        self.post_attn_ln = RMSNorm(num_features=width)
        # NOTE(review): the attention projections do not receive
        # ``config.use_bias`` (MultiHeadAttention's default ``use_bias=False``
        # applies), unlike the feed-forward linears below. Confirm this
        # matches the checkpoints.
        self.attn = MultiHeadAttention(
            num_heads=config.num_heads,
            in_features=width,
            use_per_dim_scale=True,
            use_rotary_position_embeddings=config.use_rotary_position_embeddings,
            qk_norm=config.qk_norm,
            fuse_qkv=config.fuse_qkv,
        )

        # Feed-forward sub-block.
        if config.feedforward_norm != "rms":
            raise ValueError(f"Layer norm: {config.feedforward_norm} not supported.")
        self.pre_ff_ln = RMSNorm(num_features=width)
        self.post_ff_ln = RMSNorm(num_features=width)
        self.ff0 = nn.Linear(
            in_features=width,
            out_features=config.hidden_dims,
            bias=config.use_bias,
        )
        self.ff1 = nn.Linear(
            in_features=config.hidden_dims,
            out_features=width,
            bias=config.use_bias,
        )

        activation_factories = {
            "relu": nn.ReLU,
            "swish": nn.SiLU,
            "none": nn.Identity,
        }
        activation_factory = activation_factories.get(config.ff_activation)
        if activation_factory is None:
            raise ValueError(f"Activation: {config.ff_activation} not supported.")
        self.activation = activation_factory()

    def forward(
        self,
        input_embeddings: torch.Tensor,
        patch_mask: torch.Tensor,
        decode_cache: DecodeCache | None = None,
    ) -> tuple[torch.Tensor, DecodeCache | None]:
        """Applies attention, then the feed-forward network, each with a residual.

        Args:
          input_embeddings: Layer input; normalized and passed to the attention
            module.
          patch_mask: Mask forwarded unchanged to the attention module.
          decode_cache: Optional decode cache, forwarded to the attention
            module and returned from it.

        Returns:
          ``(output_embeddings, decode_cache)``.
        """
        attn_out, decode_cache = self.attn(
            inputs_q=self.pre_attn_ln(input_embeddings),
            decode_cache=decode_cache,
            patch_mask=patch_mask,
        )
        hidden = self.post_attn_ln(attn_out) + input_embeddings

        ff_out = self.pre_ff_ln(hidden)
        ff_out = self.ff0(ff_out)
        ff_out = self.activation(ff_out)
        ff_out = self.ff1(ff_out)
        return self.post_ff_ln(ff_out) + hidden, decode_cache
|
timesfm/torch/util.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""PyTorch utility functions for TimesFM layers."""
|
|
16
|
+
|
|
17
|
+
import dataclasses
|
|
18
|
+
import torch
|
|
19
|
+
|
|
20
|
+
# Floor on the standard deviation that ``revin`` divides by when normalizing.
# Smaller sigmas are replaced by 1.0 to avoid dividing by (near) zero.
_TOLERANCE = 1.0e-6
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclasses.dataclass
class DecodeCache:
    """Key/value cache carried between incremental decoding steps.

    The attention layer updates the fields in place (it writes into ``key`` and
    ``value``, advances ``next_index`` and replaces ``num_masked``), so the
    dataclass is deliberately left mutable.
    """

    # Per-batch write position for the next chunk of keys/values.
    next_index: torch.Tensor
    # Per-batch running count of masked (padded) positions seen so far.
    num_masked: torch.Tensor
    # Cached keys. The attention layer indexes them as
    # (batch, cache_length, heads, head_dim) — confirm against the allocator.
    key: torch.Tensor
    # Cached values, same layout as ``key``.
    value: torch.Tensor
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def update_running_stats(
    n: torch.Tensor,
    mu: torch.Tensor,
    sigma: torch.Tensor,
    x: torch.Tensor,
    mask: torch.Tensor,
) -> tuple[
    tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    tuple[torch.Tensor, torch.Tensor, torch.Tensor],
]:
    """Merges the unmasked entries of ``x`` into running count/mean/std stats.

    The new chunk's count, mean and population variance are computed over the
    last axis of ``x``. They are then combined with the running statistics
    using the standard two-group (parallel) update for counts, means and
    variances.

    Args:
      n: Running count of observations so far.
      mu: Running mean.
      sigma: Running population standard deviation.
      x: New values, reduced over the last axis.
      mask: Boolean tensor shaped like ``x``; ``True`` marks entries to ignore.
        Masked entries never influence the result, even if they hold NaN or
        inf.

    Returns:
      ``(stats, stats)`` where ``stats = (new_n, new_mu, new_sigma)``. The same
      tuple object is returned twice, presumably so the function can serve as
      both the carry and the output of a scan-style loop.
    """
    is_legit = torch.logical_not(mask)
    inc_n = torch.sum(is_legit.to(x.dtype), dim=-1)
    inc_n_safe = torch.where(inc_n == 0, 1.0, inc_n)

    # Drop masked entries with `where` rather than multiplying by the mask:
    # NaN * 0 and inf * 0 are NaN, which would poison the sums even though
    # those entries are meant to be ignored.
    x_legit = torch.where(is_legit, x, 0.0)
    inc_mu = torch.sum(x_legit, dim=-1) / inc_n_safe
    inc_mu = torch.where(inc_n == 0, 0.0, inc_mu)

    deviation = torch.where(is_legit, x_legit - inc_mu.unsqueeze(-1), 0.0)
    inc_var = torch.sum(deviation**2, dim=-1) / inc_n_safe
    inc_var = torch.where(inc_n == 0, 0.0, inc_var)
    inc_sigma = torch.sqrt(inc_var)

    new_n = n + inc_n
    new_n_safe = torch.where(new_n == 0, 1.0, new_n)

    new_mu = (n * mu + inc_mu * inc_n) / new_n_safe
    new_mu = torch.where(new_n == 0, 0.0, new_mu)

    # Pooled variance: within-group spread of both groups plus the spread of
    # each group's mean around the combined mean.
    term1 = n * sigma.pow(2)
    term2 = inc_n * inc_sigma.pow(2)
    term3 = n * (mu - new_mu).pow(2)
    term4 = inc_n * (inc_mu - new_mu).pow(2)

    new_var = (term1 + term2 + term3 + term4) / new_n_safe
    new_var = torch.where(new_n == 0, 0.0, new_var)
    # The clamp guards against tiny negative values from rounding.
    new_sigma = torch.sqrt(torch.clamp(new_var, min=0.0))

    stats = (new_n, new_mu, new_sigma)
    return stats, stats
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def revin(
    x: torch.Tensor,
    mu: torch.Tensor,
    sigma: torch.Tensor,
    reverse: bool = False,
):
    """Reversible instance normalization.

    If ``mu`` has one or two fewer axes than ``x``, trailing singleton axes are
    appended to both ``mu`` and ``sigma`` so they broadcast against ``x``. The
    axis count of ``mu`` decides this for both tensors.

    Args:
      x: Tensor to normalize (or de-normalize).
      mu: Mean to subtract (forward) or add back (reverse).
      sigma: Standard deviation to divide by (forward) or multiply by (reverse).
      reverse: ``False`` computes ``(x - mu) / sigma``; ``True`` computes
        ``x * sigma + mu``.

    Returns:
      The normalized or de-normalized tensor.

    NOTE(review): the forward direction replaces ``sigma < _TOLERANCE`` with
    1.0, but the reverse direction multiplies by the raw ``sigma``. A round
    trip is therefore not the identity for near-constant inputs. Confirm this
    asymmetry is intended.
    """
    missing_axes = x.ndim - mu.ndim
    if missing_axes in (1, 2):
        expand_index = (...,) + (None,) * missing_axes
        mu = mu[expand_index]
        sigma = sigma[expand_index]

    if not reverse:
        safe_sigma = torch.where(sigma < _TOLERANCE, 1.0, sigma)
        return (x - mu) / safe_sigma
    return x * sigma + mu
|