dsipts 1.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dsipts might be problematic. Click here for more details.
- dsipts/__init__.py +48 -0
- dsipts/data_management/__init__.py +0 -0
- dsipts/data_management/monash.py +338 -0
- dsipts/data_management/public_datasets.py +162 -0
- dsipts/data_structure/__init__.py +0 -0
- dsipts/data_structure/data_structure.py +1167 -0
- dsipts/data_structure/modifiers.py +213 -0
- dsipts/data_structure/utils.py +173 -0
- dsipts/models/Autoformer.py +199 -0
- dsipts/models/CrossFormer.py +152 -0
- dsipts/models/D3VAE.py +196 -0
- dsipts/models/Diffusion.py +818 -0
- dsipts/models/DilatedConv.py +342 -0
- dsipts/models/DilatedConvED.py +310 -0
- dsipts/models/Duet.py +197 -0
- dsipts/models/ITransformer.py +167 -0
- dsipts/models/Informer.py +180 -0
- dsipts/models/LinearTS.py +222 -0
- dsipts/models/PatchTST.py +181 -0
- dsipts/models/Persistent.py +44 -0
- dsipts/models/RNN.py +213 -0
- dsipts/models/Samformer.py +139 -0
- dsipts/models/TFT.py +269 -0
- dsipts/models/TIDE.py +296 -0
- dsipts/models/TTM.py +252 -0
- dsipts/models/TimeXER.py +184 -0
- dsipts/models/VQVAEA.py +299 -0
- dsipts/models/VVA.py +247 -0
- dsipts/models/__init__.py +0 -0
- dsipts/models/autoformer/__init__.py +0 -0
- dsipts/models/autoformer/layers.py +352 -0
- dsipts/models/base.py +439 -0
- dsipts/models/base_v2.py +444 -0
- dsipts/models/crossformer/__init__.py +0 -0
- dsipts/models/crossformer/attn.py +118 -0
- dsipts/models/crossformer/cross_decoder.py +77 -0
- dsipts/models/crossformer/cross_embed.py +18 -0
- dsipts/models/crossformer/cross_encoder.py +99 -0
- dsipts/models/d3vae/__init__.py +0 -0
- dsipts/models/d3vae/diffusion_process.py +169 -0
- dsipts/models/d3vae/embedding.py +108 -0
- dsipts/models/d3vae/encoder.py +326 -0
- dsipts/models/d3vae/model.py +211 -0
- dsipts/models/d3vae/neural_operations.py +314 -0
- dsipts/models/d3vae/resnet.py +153 -0
- dsipts/models/d3vae/utils.py +630 -0
- dsipts/models/duet/__init__.py +0 -0
- dsipts/models/duet/layers.py +438 -0
- dsipts/models/duet/masked.py +202 -0
- dsipts/models/informer/__init__.py +0 -0
- dsipts/models/informer/attn.py +185 -0
- dsipts/models/informer/decoder.py +50 -0
- dsipts/models/informer/embed.py +125 -0
- dsipts/models/informer/encoder.py +100 -0
- dsipts/models/itransformer/Embed.py +142 -0
- dsipts/models/itransformer/SelfAttention_Family.py +355 -0
- dsipts/models/itransformer/Transformer_EncDec.py +134 -0
- dsipts/models/itransformer/__init__.py +0 -0
- dsipts/models/patchtst/__init__.py +0 -0
- dsipts/models/patchtst/layers.py +569 -0
- dsipts/models/samformer/__init__.py +0 -0
- dsipts/models/samformer/utils.py +154 -0
- dsipts/models/tft/__init__.py +0 -0
- dsipts/models/tft/sub_nn.py +234 -0
- dsipts/models/timexer/Layers.py +127 -0
- dsipts/models/timexer/__init__.py +0 -0
- dsipts/models/ttm/__init__.py +0 -0
- dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
- dsipts/models/ttm/consts.py +16 -0
- dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
- dsipts/models/ttm/utils.py +438 -0
- dsipts/models/utils.py +624 -0
- dsipts/models/vva/__init__.py +0 -0
- dsipts/models/vva/minigpt.py +83 -0
- dsipts/models/vva/vqvae.py +459 -0
- dsipts/models/xlstm/__init__.py +0 -0
- dsipts/models/xlstm/xLSTM.py +255 -0
- dsipts-1.1.5.dist-info/METADATA +31 -0
- dsipts-1.1.5.dist-info/RECORD +81 -0
- dsipts-1.1.5.dist-info/WHEEL +5 -0
- dsipts-1.1.5.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,2099 @@
|
|
|
1
|
+
# Copyright contributors to the TSFM project
|
|
2
|
+
#
|
|
3
|
+
# This code is based on layers and components from the PatchTSMixer model in the HuggingFace Transformers
|
|
4
|
+
# Library: https://github.com/huggingface/transformers/blob/main/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py
|
|
5
|
+
"""PyTorch TinyTimeMixer model."""
|
|
6
|
+
|
|
7
|
+
import copy
|
|
8
|
+
import math
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from typing import Optional, Tuple, Union
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
import torch.nn as nn
|
|
14
|
+
from transformers.modeling_utils import PreTrainedModel
|
|
15
|
+
from transformers.time_series_utils import (
|
|
16
|
+
NegativeBinomialOutput,
|
|
17
|
+
NormalOutput,
|
|
18
|
+
StudentTOutput,
|
|
19
|
+
)
|
|
20
|
+
from transformers.utils import ModelOutput
|
|
21
|
+
|
|
22
|
+
from .configuration_tinytimemixer import TinyTimeMixerConfig
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class PinballLoss(nn.Module):
    """Pinball (quantile) loss, averaged over every element of the inputs.

    For a residual ``e = targets - predictions`` the per-element loss is
    ``max(q * e, (q - 1) * e)``: under-prediction (``e > 0``) is weighted by ``q``
    and over-prediction (``e < 0``) by ``1 - q``.
    """

    def __init__(self, quantile: float):
        """
        Args:
            quantile (float): Target quantile, e.g. 0.5 for the median or 0.9 for the 90th percentile.
        """
        super().__init__()
        self.quantile = quantile

    def forward(self, predictions, targets):
        """
        Compute the mean pinball loss.

        Args:
            predictions (torch.Tensor): Predicted values, e.g. shape [b, seq_len, channels].
            targets (torch.Tensor): Ground truth values, same shape as `predictions`.

        Returns:
            torch.Tensor: Scalar mean of the element-wise pinball loss.
        """
        residual = targets - predictions
        under_weighted = self.quantile * residual
        over_weighted = (self.quantile - 1) * residual
        return torch.max(under_weighted, over_weighted).mean()
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class TinyTimeMixerGatedAttention(nn.Module):
    """
    Gate the input with a learned softmax weighting over its last dimension.

    A linear layer followed by a softmax (over the last dim) produces weights that are multiplied
    element-wise with the input.

    Args:
        in_size (`int`): Size of the input's last dimension.
        out_size (`int`): Output size of the gating projection (must broadcast against the input).
    """

    def __init__(self, in_size: int, out_size: int):
        super().__init__()
        self.attn_layer = nn.Linear(in_size, out_size)
        self.attn_softmax = nn.Softmax(dim=-1)

    def forward(self, inputs):
        gate = self.attn_softmax(self.attn_layer(inputs))
        return gate * inputs
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class TinyTimeMixerCategoricalEmbeddingLayer(nn.Module):
    """
    Embed static categorical variables and repeat the embeddings across patches.

    One `nn.Embedding(vocab_size, config.d_model)` is created for every entry of
    `config.categorical_vocab_size_list`, in the same order.
    """

    def __init__(self, config: TinyTimeMixerConfig):
        super().__init__()
        self.categorical_vocab_size_list = config.categorical_vocab_size_list
        self.embedding_layers = nn.ModuleList(
            [nn.Embedding(vocab_size, config.d_model) for vocab_size in self.categorical_vocab_size_list]
        )
        self.number_of_categorical_variables = len(self.categorical_vocab_size_list)
        self.num_patches = config.num_patches

    def forward(self, static_categorical_values: torch.Tensor):
        """
        Parameters:
            static_categorical_values (`torch.Tensor` of shape `(batch_size, number_of_categorical_variables)`):
                Tokenized categorical values (cast to `long` before lookup), in the same order as the
                vocab size list passed via the TinyTimeMixerConfig param `categorical_vocab_size_list`.
        Returns:
            `torch.Tensor` of shape `(batch_size, number_of_categorical_variables, num_patches, d_model)`
        """
        per_variable = [
            self.embedding_layers[idx](static_categorical_values[:, idx].long())
            for idx in range(self.number_of_categorical_variables)
        ]
        # (bs, n_vars, d_model) -> (bs, n_vars, 1, d_model) -> (bs, n_vars, num_patches, d_model)
        stacked = torch.stack(per_variable, dim=1)
        return stacked.unsqueeze(2).repeat(1, 1, self.num_patches, 1)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class TinyTimeMixerBatchNorm(nn.Module):
    """
    Batch normalization of the `d_model` features, with statistics taken over batch and sequence positions.
    """

    def __init__(self, config: TinyTimeMixerConfig):
        super().__init__()
        self.batchnorm = nn.BatchNorm1d(config.d_model, eps=config.norm_eps)

    def forward(self, inputs: torch.Tensor):
        """
        Parameters:
            inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`):
                input for Batch norm calculation
        Returns:
            `torch.Tensor` of shape `(batch_size, sequence_length, d_model)`
        """
        # BatchNorm1d expects the feature axis at dim 1, so swap d_model in and back out.
        channels_first = inputs.transpose(1, 2)
        return self.batchnorm(channels_first).transpose(1, 2)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class TinyTimeMixerPositionalEncoding(nn.Module):
    """
    Add a `(num_patches, d_model)` positional encoding to patch embeddings.

    Variants, selected from the config:
      - `use_positional_encoding=False`: a zero-initialised `nn.Parameter` (requires grad by default).
      - `positional_encoding_type="random"`: a trainable parameter drawn from a standard normal.
      - `positional_encoding_type="sincos"`: a fixed sinusoidal table, mean-centred and scaled to std 0.1.
    """

    def __init__(self, config: TinyTimeMixerConfig):
        super().__init__()
        # positional encoding: [num_patches x d_model]
        if config.use_positional_encoding:
            self.position_enc = self._init_pe(config)
        else:
            self.position_enc = nn.Parameter(torch.zeros(config.num_patches, config.d_model))

    @staticmethod
    def _init_pe(config: TinyTimeMixerConfig) -> nn.Parameter:
        """
        Build the positional encoding parameter.

        Raises:
            ValueError: if `config.positional_encoding_type` is neither "random" nor "sincos".
        """
        if config.positional_encoding_type == "random":
            position_enc = nn.Parameter(torch.randn(config.num_patches, config.d_model), requires_grad=True)
        elif config.positional_encoding_type == "sincos":
            position_enc = torch.zeros(config.num_patches, config.d_model)
            position = torch.arange(0, config.num_patches).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, config.d_model, 2) * -(math.log(10000.0) / config.d_model))
            position_enc[:, 0::2] = torch.sin(position * div_term)
            # There are ceil(d_model / 2) frequencies but only d_model // 2 odd columns; slicing keeps odd
            # d_model from raising a shape mismatch (no-op for even d_model).
            position_enc[:, 1::2] = torch.cos(position * div_term[: config.d_model // 2])
            position_enc = position_enc - position_enc.mean()
            position_enc = position_enc / (position_enc.std() * 10)
            position_enc = nn.Parameter(position_enc, requires_grad=False)
        else:
            raise ValueError(
                f"{config.positional_encoding_type} is not a valid positional encoder. Available types are 'random' and 'sincos'."
            )
        return position_enc

    def forward(self, patch_input: torch.Tensor):
        # patch_input presumably [bs x num_channels x num_patches x d_model]; position_enc broadcasts
        # over the leading dims.
        hidden_state = patch_input + self.position_enc
        return hidden_state
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class TinyTimeMixerNormLayer(nn.Module):
    """Normalization block.

    Uses `TinyTimeMixerBatchNorm` when `config.norm_mlp` contains "batch" (case-insensitive),
    otherwise `nn.LayerNorm(config.d_model)`.

    Args:
        config (`TinyTimeMixerConfig`, *required*):
            Configuration.
    """

    def __init__(self, config: TinyTimeMixerConfig):
        super().__init__()
        self.norm_mlp = config.norm_mlp
        self.norm = (
            TinyTimeMixerBatchNorm(config)
            if "batch" in config.norm_mlp.lower()
            else nn.LayerNorm(config.d_model, eps=config.norm_eps)
        )

    def forward(self, inputs: torch.Tensor):
        """
        Args:
            inputs (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
                Input to the normalization layer.
        Returns:
            `torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`
        """
        if "batch" not in self.norm_mlp.lower():
            return self.norm(inputs)

        # Batch norm works on 3-D input: fold batch and channel axes together, then restore them.
        original_shape = inputs.shape
        folded = inputs.reshape(original_shape[0] * original_shape[1], original_shape[2], original_shape[3])
        return self.norm(folded).reshape(original_shape)
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
class TinyTimeMixerMLP(nn.Module):
    """Two-layer feed-forward network applied along the last dimension.

    Hidden width is `in_features * config.expansion_factor`. GELU follows the first projection and
    dropout (rate `config.dropout`) follows each projection.

    Args:
        in_features (`int`): Size of the input's last dimension.
        out_features (`int`): Size of the output's last dimension.
        config: Object providing `expansion_factor` and `dropout`.
    """

    def __init__(self, in_features, out_features, config):
        super().__init__()
        hidden_width = in_features * config.expansion_factor
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.dropout1 = nn.Dropout(config.dropout)
        self.fc2 = nn.Linear(hidden_width, out_features)
        self.dropout2 = nn.Dropout(config.dropout)

    def forward(self, inputs: torch.Tensor):
        """
        Args:
            inputs (`torch.Tensor` whose last dimension is `in_features`):
                Input to the MLP layer.
        Returns:
            `torch.Tensor` with the same leading dimensions and last dimension `out_features`.
        """
        activated = nn.functional.gelu(self.fc1(inputs))
        projected = self.fc2(self.dropout1(activated))
        return self.dropout2(projected)
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
class TinyTimeMixerChannelFeatureMixerBlock(nn.Module):
    """Mix features across the channel dimension, with a residual connection.

    Args:
        config (`TinyTimeMixerConfig`, *required*):
            Configuration.
    """

    def __init__(self, config: TinyTimeMixerConfig):
        super().__init__()
        channels = config.num_input_channels

        self.norm = TinyTimeMixerNormLayer(config)
        self.gated_attn = config.gated_attn
        self.mlp = TinyTimeMixerMLP(
            in_features=channels,
            out_features=channels,
            config=config,
        )
        if config.gated_attn:
            self.gating_block = TinyTimeMixerGatedAttention(in_size=channels, out_size=channels)

    def forward(self, inputs: torch.Tensor):
        """
        Args:
            inputs (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
                input to the MLP layer
        Returns:
            `torch.Tensor` of the same shape as `inputs`
        """
        skip = inputs
        # Move channels to the last axis -> (batch_size, d_model, num_patches, num_channels)
        # so the gate and MLP operate across channels.
        mixed = self.norm(inputs).permute(0, 3, 2, 1)
        if self.gated_attn:
            mixed = self.gating_block(mixed)
        mixed = self.mlp(mixed).permute(0, 3, 2, 1)
        return mixed + skip
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
class TinyTimeMixerAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Acts as self-attention by default and as cross-attention when `key_value_states` is passed.
    When `is_decoder` is True, the projected key/value states are returned for reuse on later calls.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[TinyTimeMixerConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        # Creation order (k, v, q, out) is kept so parameter initialisation and state_dict order are unchanged.
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Split the embedding into heads: (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _project_key_value(self, hidden_states, key_value_states, past_key_value, bsz):
        """Return `(key_states, value_states)`, each shaped (bsz, num_heads, src_len, head_dim)."""
        if key_value_states is not None:
            # Cross-attention: reuse cached projections when their sequence length matches the
            # provided `key_value_states` (supports prefix tuning).
            if past_key_value is not None and past_key_value[0].shape[2] == key_value_states.shape[1]:
                return past_key_value[0], past_key_value[1]
            return (
                self._shape(self.k_proj(key_value_states), -1, bsz),
                self._shape(self.v_proj(key_value_states), -1, bsz),
            )

        # Self-attention, optionally appending to previously cached key/value states along the time axis.
        keys = self._shape(self.k_proj(hidden_states), -1, bsz)
        values = self._shape(self.v_proj(hidden_states), -1, bsz)
        if past_key_value is not None:
            keys = torch.cat([past_key_value[0], keys], dim=2)
            values = torch.cat([past_key_value[1], values], dim=2)
        return keys, values

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Returns:
            `(attn_output, attn_weights or None, past_key_value)`. `past_key_value` is replaced by the
            current `(key_states, value_states)` when `is_decoder` is True; otherwise the argument is
            returned unchanged.
        """
        bsz, tgt_len, _ = hidden_states.size()

        scaled_query = self.q_proj(hidden_states) * self.scaling
        key_states, value_states = self._project_key_value(hidden_states, key_value_states, past_key_value, bsz)

        if self.is_decoder:
            past_key_value = (key_states, value_states)

        # Fold heads into the batch axis for batched matrix multiplication.
        flat_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(scaled_query, tgt_len, bsz).view(*flat_shape)
        key_states = key_states.reshape(*flat_shape)
        value_states = value_states.reshape(*flat_shape)

        src_len = key_states.size(1)
        scores = torch.bmm(query_states, key_states.transpose(1, 2))

        weights_shape = (bsz * self.num_heads, tgt_len, src_len)
        if scores.size() != weights_shape:
            raise ValueError(
                f"Attention weights should be of size {weights_shape}, but is"
                f" {scores.size()}"
            )

        if attention_mask is not None:
            mask_shape = (bsz, 1, tgt_len, src_len)
            if attention_mask.size() != mask_shape:
                raise ValueError(
                    f"Attention mask should be of size {mask_shape}, but is {attention_mask.size()}"
                )
            # Additive mask, broadcast over heads.
            scores = (scores.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask).view(*weights_shape)

        probs = nn.functional.softmax(scores, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            per_head = probs.view(bsz, self.num_heads, tgt_len, src_len)
            probs = (layer_head_mask.view(1, -1, 1, 1) * per_head).view(*weights_shape)

        returned_weights = None
        if output_attentions:
            # Reshape twice so the returned weights and the ones used below share the autograd graph.
            returned_weights = probs.view(bsz, self.num_heads, tgt_len, src_len)
            probs = returned_weights.view(*weights_shape)

        dropped = nn.functional.dropout(probs, p=self.dropout, training=self.training)
        context = torch.bmm(dropped, value_states)

        context_shape = (bsz * self.num_heads, tgt_len, self.head_dim)
        if context.size() != context_shape:
            raise ValueError(
                f"`attn_output` should be of size {context_shape}, but is"
                f" {context.size()}"
            )

        # Undo the head split; `self.embed_dim` is used rather than the input width so this still works
        # when the output is partitioned across GPUs under tensor parallelism.
        context = context.view(bsz, self.num_heads, tgt_len, self.head_dim).transpose(1, 2)
        context = context.reshape(bsz, tgt_len, self.embed_dim)

        return self.out_proj(context), returned_weights, past_key_value
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
class PatchMixerBlock(nn.Module):
    """Mix information across the patch dimension, with a residual connection.

    Optionally adds a multi-head self-attention branch over patches, controlled by `config.self_attn`.

    Args:
        config (`TinyTimeMixerConfig`, *required*):
            Configuration.
    """

    def __init__(self, config: TinyTimeMixerConfig):
        super().__init__()
        patches = config.num_patches

        self.norm = TinyTimeMixerNormLayer(config)
        self.self_attn = config.self_attn
        self.gated_attn = config.gated_attn

        self.mlp = TinyTimeMixerMLP(in_features=patches, out_features=patches, config=config)

        if config.gated_attn:
            self.gating_block = TinyTimeMixerGatedAttention(in_size=patches, out_size=patches)

        if config.self_attn:
            self.self_attn_layer = TinyTimeMixerAttention(
                embed_dim=config.d_model,
                num_heads=config.self_attn_heads,
                dropout=config.dropout,
            )
            self.norm_attn = TinyTimeMixerNormLayer(config)

    def forward(self, hidden_state):
        """
        Args:
            hidden_state (`torch.Tensor`): Input tensor, expected shape
                `(batch_size, num_channels, num_patches, d_model)`.

        Returns:
            `torch.Tensor`: Transformed tensor of the same shape.
        """
        residual = hidden_state
        normed = self.norm(hidden_state)

        attn_branch = None
        if self.self_attn:
            batch_size, n_vars, num_patches, d_model = normed.shape
            flat = normed.reshape(batch_size * n_vars, num_patches, d_model)
            attn_branch, _, _ = self.self_attn_layer(flat, output_attentions=False)
            attn_branch = attn_branch.reshape(batch_size, n_vars, num_patches, d_model)

        # Put num_patches last so the MLP and gate mix along the patch axis, then swap back.
        mixed = self.mlp(normed.transpose(2, 3))
        if self.gated_attn:
            mixed = self.gating_block(mixed)
        mixed = mixed.transpose(2, 3)

        if self.self_attn:
            mixed = self.norm_attn(mixed + attn_branch)

        return mixed + residual
|
|
518
|
+
|
|
519
|
+
|
|
520
|
+
class FeatureMixerBlock(nn.Module):
    """Mix the hidden feature (`d_model`) dimension, with a residual connection.

    Args:
        config (`TinyTimeMixerConfig`, *required*):
            Configuration.
    """

    def __init__(self, config: TinyTimeMixerConfig):
        super().__init__()
        width = config.d_model

        self.norm = TinyTimeMixerNormLayer(config)
        self.gated_attn = config.gated_attn
        self.mlp = TinyTimeMixerMLP(in_features=width, out_features=width, config=config)

        if config.gated_attn:
            self.gating_block = TinyTimeMixerGatedAttention(in_size=width, out_size=width)

    def forward(self, hidden: torch.Tensor):
        """
        Args:
            hidden (`torch.Tensor` whose last dimension is `d_model`):
                Input tensor to the layer.

        Returns:
            `torch.Tensor`: Transformed tensor of the same shape.
        """
        transformed = self.mlp(self.norm(hidden))
        if self.gated_attn:
            transformed = self.gating_block(transformed)
        return transformed + hidden
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
class ForecastChannelHeadMixer(nn.Module):
|
|
566
|
+
"""ForecastChannelMixer Module to reconcile forecasts across channels with exogenous support.
|
|
567
|
+
|
|
568
|
+
When channel_context_length is positive this mode creates a patch for every multi-variate forecast point with its surronding context
|
|
569
|
+
it then flattens it and applies MLP to it.
|
|
570
|
+
By this way, every forecast point learn from its pre and post surrounding context in a channel mixed way.
|
|
571
|
+
Residual is added to ensure noise reduction with initial forecasts.
|
|
572
|
+
"""
|
|
573
|
+
|
|
574
|
+
def __init__(self, config: TinyTimeMixerConfig):
    """
    Build the forecast channel mixer head.

    Args:
        config (`TinyTimeMixerConfig`):
            Configuration. Reads `fcm_context_length`, `prediction_channel_indices`,
            `exogenous_channel_indices`, `num_input_channels`, `fcm_use_mixer`, `fcm_mix_layers`,
            `head_dropout`, `fcm_gated_attn`, `prediction_length`, `fcm_prepend_past` and
            `fcm_prepend_past_offset`.
    """
    super().__init__()

    self.fcm_context_length = config.fcm_context_length
    # `fcm_context_length` steps on each side of a forecast point, plus the point itself.
    self.scl = 2 * self.fcm_context_length + 1

    if config.prediction_channel_indices is not None:
        self.prediction_channel_count = len(config.prediction_channel_indices)
    else:
        self.prediction_channel_count = config.num_input_channels

    if config.exogenous_channel_indices is not None:
        self.exogenous_channel_count = len(config.exogenous_channel_indices)
    else:
        self.exogenous_channel_count = 0

    self.total_channel_count = self.prediction_channel_count + self.exogenous_channel_count

    self.fcm_use_mixer = config.fcm_use_mixer

    self.exogenous_channel_indices = config.exogenous_channel_indices
    self.prediction_channel_indices = config.prediction_channel_indices
    scl_features = self.scl

    if self.fcm_use_mixer:
        # model mixer considering channel dim as patch dim for lag computation
        temp_config = copy.deepcopy(config)
        temp_config.num_patches = self.total_channel_count
        temp_config.patch_length = self.scl
        temp_config.num_input_channels = config.prediction_length
        temp_config.d_model = self.scl * 2
        temp_config.patch_stride = 1
        temp_config.num_layers = config.fcm_mix_layers
        temp_config.dropout = config.head_dropout
        temp_config.mode = "common_channel"
        temp_config.gated_attn = config.fcm_gated_attn
        temp_config.adaptive_patching_levels = 0
        self.exog_mixer = TinyTimeMixerBlock(temp_config)
        # fcm_embedding maps scl inputs to d_model = 2 * scl features, so the flattened width doubles.
        scl_features = self.scl * 2
        self.fcm_embedding = nn.Linear(temp_config.patch_length, temp_config.d_model)

    self.mlp = nn.Linear(
        self.total_channel_count * (scl_features),
        self.prediction_channel_count,
    )
    if config.fcm_gated_attn:
        self.fcm_gating_block = TinyTimeMixerGatedAttention(
            in_size=self.total_channel_count * (scl_features),
            out_size=self.total_channel_count * (scl_features),
        )
    if self.fcm_context_length > 0:
        patch_config = copy.deepcopy(config)
        patch_config.context_length = config.prediction_length + (2 * config.fcm_context_length)
        patch_config.masked_context_length = None
        patch_config.patch_length = self.scl
        patch_config.patch_stride = 1
        self.fcm_patch_block = TinyTimeMixerPatchify(patch_config)

    self.fcm_gated_attn = config.fcm_gated_attn
    self.prediction_length = config.prediction_length
    self.fcm_prepend_past = config.fcm_prepend_past

    # Number of items to skip in the context window from the end
    self.fcm_prepend_past_offset = config.fcm_prepend_past_offset

    if self.fcm_prepend_past_offset is None or self.fcm_prepend_past_offset == 0:
        # No offset: take the last `fcm_context_length` steps. An offset of 0 must take this branch too,
        # because slice(-k, -0) == slice(-k, 0) always selects nothing.
        # NOTE(review): with fcm_context_length == 0 this slice selects the whole past; confirm that the
        # forward pass only uses it when fcm_context_length > 0.
        self.fcm_prepend_slicing_indices = slice(-self.fcm_context_length, None)
    else:
        self.fcm_prepend_slicing_indices = slice(
            -(self.fcm_prepend_past_offset + self.fcm_context_length),
            -self.fcm_prepend_past_offset,
        )
|
|
647
|
+
|
|
648
|
+
def forward(
    self,
    base_forecasts: torch.Tensor,
    past_values: Optional[torch.Tensor],
    future_values: Optional[torch.Tensor] = None,
):
    """
    Refine base forecasts by mixing information across channels (and, optionally, across
    neighbouring forecast steps), adding the result to the base forecasts as a residual.

    Args:
        base_forecasts (`torch.Tensor` of shape `(batch_size, prediction length, forecast_channels)`):
            Base Forecasts to reconcile

        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a forecasting task, this denotes the history/past time series values.
            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
            greater than 1.

        future_values (`torch.Tensor` of shape `(batch_size, prediction length, input_channels)`, *optional*, Defaults to None):
            Actual groundtruths of the forecasts. Pass dummy values (say 0) for forecast channels, if groundtruth is unknown.
            Pass the correct values for Exogenous channels where the forecast values are known.

    Returns:
        `torch.Tensor`: Updated forecasts of shape `(batch_size, prediction length, forecast_channels)`

    Raises:
        ValueError: if exogenous channels are configured but `future_values` is None, or if the
            patched forecast length does not equal `prediction_length`.
    """
    # base_forecasts.shape == (batch_size x forecast_len x n_vars)

    # Tail of the history (selected by fcm_prepend_slicing_indices) for the forecast channels;
    # only consumed when fcm_prepend_past is enabled.
    if self.prediction_channel_indices is not None:
        past_prepend_values = past_values[
            :, self.fcm_prepend_slicing_indices, self.prediction_channel_indices
        ]  # bs x fcm_context_len x forecast_channels
    else:
        past_prepend_values = past_values[
            :, self.fcm_prepend_slicing_indices, :
        ]  # bs x fcm_context_len x n_vars

    if self.exogenous_channel_count > 0 and future_values is None:
        raise ValueError("future_values cannot be none when we have exogenous channels.")

    if self.exogenous_channel_count > 0:
        exog_values = future_values[..., self.exogenous_channel_indices]  # bs x prediction len x exog_channels
        past_exog_values = past_values[
            :, self.fcm_prepend_slicing_indices, self.exogenous_channel_indices
        ]  # bs x fcm_context_len x exog_channels

        past_prepend_values = torch.cat(
            (past_prepend_values, past_exog_values), dim=-1
        )  # bs x fcm_context_len x (forecast_channels+exog_channels)
    else:
        exog_values = None

    # Keep the un-mixed forecasts: the mixer only learns a correction on top of them.
    residual = base_forecasts

    if exog_values is not None:
        base_forecasts = torch.cat(
            (base_forecasts, exog_values), dim=-1
        )  # bs x forecast_len x (forecast_channels+exog_channels)

    if self.fcm_context_length > 0:
        # Build, for every forecast point, a patch of the multivariate forecast with its
        # surrounding context (fcm_context_length steps on each side), flatten it and apply an
        # MLP, so each forecast point learns from its pre/post context in a channel-mixed way.

        # Zero padding before/after the horizon so every step gets a full window.
        # new_zeros keeps base_forecasts' dtype and device; a default float32 tensor would
        # upcast half-precision forecasts in the concat below and break the following layers.
        dummy = base_forecasts.new_zeros(
            (
                base_forecasts.shape[0],
                self.fcm_context_length,
                base_forecasts.shape[2],
            )
        )  # bs x fcm_context_length x n_vars

        if self.fcm_prepend_past:
            # Use the tail of the observed history as left context instead of zeros.
            extend_forecasts = torch.cat(
                (past_prepend_values, base_forecasts, dummy), dim=1
            )  # bs x forecast_len + 2*fcm_context_length x n_vars
        else:
            extend_forecasts = torch.cat(
                (dummy, base_forecasts, dummy), dim=1
            )  # bs x forecast_len + 2*fcm_context_length x n_vars

        # create patch
        extend_forecasts = self.fcm_patch_block(extend_forecasts)  # [bs x n_vars x forecast_len x scl]
        extend_forecasts = extend_forecasts.transpose(1, 2)  # [bs x forecast_len x n_vars x scl]

        if extend_forecasts.shape[1] != self.prediction_length:
            raise ValueError("out_patches should match to forecast length")

        if self.fcm_use_mixer:
            extend_forecasts = self.fcm_embedding(extend_forecasts)
            extend_forecasts, _ = self.exog_mixer(extend_forecasts)

        extend_forecasts = extend_forecasts.flatten(start_dim=2)  # [bs x forecast_len x n_vars * last_dim]

        if self.fcm_gated_attn:
            extend_forecasts = self.fcm_gating_block(extend_forecasts)

        extend_forecasts = self.mlp(extend_forecasts)  # [bs x forecast_len x forecast_channels]

    else:
        # No surrounding context: mix channels point-wise at each forecast step.
        if self.fcm_gated_attn:
            extend_forecasts = self.fcm_gating_block(base_forecasts)
        else:
            # Without gating the MLP consumes the (exog-augmented) forecasts directly;
            # previously this path left extend_forecasts unbound.
            extend_forecasts = base_forecasts

        extend_forecasts = self.mlp(extend_forecasts)

    new_forecast = extend_forecasts + residual

    return new_forecast
|
|
759
|
+
|
|
760
|
+
|
|
761
|
+
class TinyTimeMixerLayer(nn.Module):
    """
    One `TinyTimeMixer` mixing layer.

    Depending on the configuration it chains up to three mixers:
    channel-feature mixing (only when `config.mode == "mix_channel"`), patch mixing
    (only when `config.num_patches > 1`) and feature mixing (always).

    Args:
        config (`TinyTimeMixerConfig`, *required*):
            Configuration.
    """

    def __init__(self, config: TinyTimeMixerConfig):
        super().__init__()

        # Submodules are created in this fixed order (patch, feature, channel) so that
        # parameter initialization consumes the RNG in the same sequence.
        if config.num_patches > 1:
            self.patch_mixer = PatchMixerBlock(config=config)

        self.feature_mixer = FeatureMixerBlock(config=config)

        self.num_patches = config.num_patches
        self.mode = config.mode

        if self.mode == "mix_channel":
            self.channel_feature_mixer = TinyTimeMixerChannelFeatureMixerBlock(config=config)

    def forward(self, hidden: torch.Tensor):
        """
        Args:
            hidden (`torch.Tensor`):
                Input tensor to the layer. When called from `TinyTimeMixerAdaptivePatchingBlock`
                it is 4-D, `(batch_size, nvars, num_patches, d_model)`.

        Returns:
            `torch.Tensor`: Transformed tensor.
        """
        mixed = hidden

        if self.mode == "mix_channel":
            mixed = self.channel_feature_mixer(mixed)

        if self.num_patches > 1:
            mixed = self.patch_mixer(mixed)

        return self.feature_mixer(mixed)
|
|
800
|
+
|
|
801
|
+
|
|
802
|
+
class TinyTimeMixerAdaptivePatchingBlock(nn.Module):
    """
    Runs a stack of `TinyTimeMixerLayer`s at a finer patch resolution.

    At level `adapt_patch_level` the input is reshaped from
    `(batch, nvars, num_patches, d_model)` to `(batch, nvars, num_patches * f, d_model // f)`
    with `f = 2 ** adapt_patch_level`, mixed, and reshaped back. If `d_model // f <= 4`,
    `f` falls back to 1 (no re-patching).

    Args:
        config (`TinyTimeMixerConfig`, *required*):
            Configuration.
        adapt_patch_level (`int`):
            Adaptive patching level `i`; the patch factor is `2 ** i`.

    Raises:
        ValueError: if `d_model` is not divisible by the resulting patch factor.
    """

    def __init__(self, config: TinyTimeMixerConfig, adapt_patch_level: int):
        super().__init__()
        level_config = copy.deepcopy(config)
        self.adapt_patch_level = adapt_patch_level
        self.adaptive_patch_factor = 2**adapt_patch_level

        # Do not shrink the per-patch width to 4 features or fewer; disable re-patching instead.
        if config.d_model // self.adaptive_patch_factor <= 4:
            self.adaptive_patch_factor = 1

        if config.d_model % self.adaptive_patch_factor != 0:
            raise ValueError("d_model should be divisible by 2^i, where i varies from 0 to adaptive_patching_levels.")

        # Trade width for length: factor-times more patches, each factor-times narrower.
        level_config.num_patches *= self.adaptive_patch_factor
        level_config.d_model //= self.adaptive_patch_factor

        self.mixer_layers = nn.ModuleList(
            [TinyTimeMixerLayer(level_config) for _ in range(level_config.num_layers)]
        )

    def forward(self, hidden: torch.Tensor):
        """
        Args:
            hidden (`torch.Tensor` of shape `(batch_size x nvars x num_patch x d_model)`):
                Input tensor to the layer.

        Returns:
            `torch.Tensor`: Transformed tensor, same shape as the input.
            `list`: The input, the re-patched input, each layer output, and the final output.
        """
        factor = self.adaptive_patch_factor
        trace = [hidden]

        finer_shape = (
            hidden.shape[0],
            hidden.shape[1],
            hidden.shape[2] * factor,
            hidden.shape[3] // factor,
        )
        hidden = torch.reshape(hidden, finer_shape)
        trace.append(hidden)

        for layer in self.mixer_layers:
            hidden = layer(hidden)
            trace.append(hidden)

        coarser_shape = (
            hidden.shape[0],
            hidden.shape[1],
            hidden.shape[2] // factor,
            hidden.shape[3] * factor,
        )
        hidden = torch.reshape(hidden, coarser_shape)
        trace.append(hidden)

        return hidden, trace
|
|
872
|
+
|
|
873
|
+
|
|
874
|
+
class TinyTimeMixerBlock(nn.Module):
    """The main computing framework of the `TinyTimeMixer` model.

    With `adaptive_patching_levels > 0` it holds one `TinyTimeMixerAdaptivePatchingBlock`
    per level, ordered from the highest level down to 0; otherwise it holds `num_layers`
    plain `TinyTimeMixerLayer`s.

    Args:
        config (`TinyTimeMixerConfig`, *required*):
            Configuration.
    """

    def __init__(self, config: TinyTimeMixerConfig):
        super().__init__()

        num_layers = config.num_layers
        self.adaptive_patching_levels = config.adaptive_patching_levels

        if self.adaptive_patching_levels > 0:
            stages = [
                TinyTimeMixerAdaptivePatchingBlock(config=config, adapt_patch_level=level)
                for level in reversed(range(config.adaptive_patching_levels))
            ]
        else:
            stages = [TinyTimeMixerLayer(config=config) for _ in range(num_layers)]

        self.mixers = nn.ModuleList(stages)

    def forward(self, hidden_state, output_hidden_states: bool = False):
        """
        Args:
            hidden_state (`torch.Tensor`): The input tensor.
            output_hidden_states (`bool`, *optional*, defaults to False.):
                Whether to output the hidden states as well.

        Returns:
            `torch.Tensor`: The embedding. `list`: List of all hidden states if `output_hidden_states` is set to
            `True`, otherwise `None`.
        """
        collected = []
        adaptive = self.adaptive_patching_levels > 0
        embedding = hidden_state

        for stage in self.mixers:
            if adaptive:
                # Adaptive blocks always report their intermediate states.
                embedding, stage_states = stage(embedding)
                collected.extend(stage_states)
                continue

            embedding = stage(embedding)
            if output_hidden_states:
                collected.append(embedding)

        return embedding, (collected if output_hidden_states else None)
|
|
928
|
+
|
|
929
|
+
|
|
930
|
+
class TinyTimeMixerDecoder(nn.Module):
    """Decoder for tiny time mixer

    Optionally projects backbone embeddings from `d_model` to `decoder_d_model`, optionally
    adds an embedding of the raw input patches (`decoder_raw_residual`), optionally appends
    static categorical embeddings as extra channels, and runs a `TinyTimeMixerBlock` built from
    the `decoder_*` fields of the config.

    Args:
        config (`TinyTimeMixerConfig`, *required*):
            Configuration.

    Raises:
        ValueError: if `categorical_vocab_size_list` is set while `decoder_mode` is `"common_channel"`.
    """

    def __init__(self, config: TinyTimeMixerConfig):
        super().__init__()

        # Project backbone width to decoder width only when the two differ.
        if config.d_model != config.decoder_d_model:
            self.adapter = nn.Linear(config.d_model, config.decoder_d_model)
        else:
            self.adapter = None

        self.decoder_raw_residual = config.decoder_raw_residual
        self.num_input_channels = config.num_input_channels

        if config.decoder_raw_residual:
            self.decoder_raw_embedding = nn.Linear(config.patch_length, config.decoder_d_model)

        # Decoder-specific overrides on a copy so the shared config is not mutated.
        decoder_config = copy.deepcopy(config)
        decoder_config.num_layers = config.decoder_num_layers
        decoder_config.d_model = config.decoder_d_model
        decoder_config.dropout = config.head_dropout
        decoder_config.adaptive_patching_levels = config.decoder_adaptive_patching_levels
        decoder_config.mode = config.decoder_mode

        if config.categorical_vocab_size_list is not None:
            if config.decoder_mode == "common_channel":
                raise ValueError("set decoder_mode to mix_channel when using static categorical variables")

            # Each static categorical variable is appended as an extra channel.
            decoder_config.num_input_channels += len(config.categorical_vocab_size_list)
            self.decoder_cat_embedding_layer = TinyTimeMixerCategoricalEmbeddingLayer(decoder_config)
        else:
            self.decoder_cat_embedding_layer = None

        self.decoder_block = TinyTimeMixerBlock(decoder_config)

        self.resolution_prefix_tuning = config.resolution_prefix_tuning

    def forward(
        self,
        hidden_state,
        patch_input,
        output_hidden_states: bool = False,
        static_categorical_values: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            hidden_state (`torch.Tensor` of shape `(batch_size x nvars x num_patch x d_model)`): The input tensor from backbone.
            patch_input (`torch.Tensor`):
                Raw input patches. Only used when `decoder_raw_residual` is enabled, where it is fed to
                `decoder_raw_embedding` (a `Linear(patch_length, decoder_d_model)`), so its last dimension
                must be `patch_length`.
            output_hidden_states (`bool`, *optional*, defaults to False.):
                Whether to output the hidden states as well.

            static_categorical_values (`torch.FloatTensor` of shape `(batch_size, number_of_categorical_variables)`, *optional*):
                Tokenized categorical values can be passed here. Ensure to pass in the same order as the vocab size list used in the
                TinyTimeMixerConfig param `categorical_vocab_size_list`

        Returns:
            `torch.Tensor`: The embedding. `list`: List of all hidden states if `output_hidden_states` is set to
            `True`.

        Raises:
            ValueError: if categorical embeddings are configured but `static_categorical_values` is None.
        """
        decoder_hidden_states = [] if output_hidden_states else None

        decoder_input = hidden_state

        if self.adapter is not None:
            decoder_input = self.adapter(hidden_state)  # [batch_size x nvars x num_patch x decoder_d_model]
            if output_hidden_states:
                decoder_hidden_states.append(decoder_input)

        if self.decoder_raw_residual:
            if self.resolution_prefix_tuning:
                if patch_input.shape[-2] == decoder_input.shape[-2] - 1:
                    # The embeddings carry one extra patch (presumably a resolution prefix added
                    # upstream); pad the raw patches with one zero patch at the front to align.
                    # new_zeros keeps patch_input's dtype/device: a float32 tensor would upcast
                    # half-precision inputs and break the half-precision Linear below.
                    pad_shape = list(patch_input.shape)
                    pad_shape[-2] = 1
                    patch_input = torch.cat([patch_input.new_zeros(pad_shape), patch_input], dim=-2)

            decoder_input = decoder_input + self.decoder_raw_embedding(
                patch_input
            )  # [batch_size x nvars x num_patch x decoder_d_model]
            if output_hidden_states:
                decoder_hidden_states.append(decoder_input)

        if self.decoder_cat_embedding_layer is not None:
            if static_categorical_values is None:
                raise ValueError("Missing static_categorical_values tensor in forward call")
            cat_embeddings = self.decoder_cat_embedding_layer(
                static_categorical_values
            )  # bs x n_cat x n_patches x d_model

            decoder_input = torch.cat(
                (decoder_input, cat_embeddings), dim=1
            )  # bs x nvars+n_cat x n_patches x d_model

        decoder_output, hidden_states = self.decoder_block(
            hidden_state=decoder_input, output_hidden_states=output_hidden_states
        )  # bs x nvars+n_cat x n_patches x d_model

        if output_hidden_states:
            decoder_hidden_states.extend(hidden_states)

        if self.decoder_cat_embedding_layer is not None:
            # Drop the categorical channels again; only real input channels are decoded.
            decoder_output = decoder_output[:, : self.num_input_channels, :, :]  # bs x nvars x n_patches x d_model
            if output_hidden_states:
                decoder_hidden_states.append(decoder_output)

        return decoder_output, decoder_hidden_states
|
|
1049
|
+
|
|
1050
|
+
|
|
1051
|
+
class TinyTimeMixerForPredictionHead(nn.Module):
    """Prediction Head for Forecasting

    Flattens per-channel patch embeddings and projects them to `prediction_length` steps
    (or to distribution parameters when `distribution_output` is given). It then optionally
    keeps only `prediction_channel_indices`, truncates to `prediction_filter_length`, and
    applies forecast channel mixing.

    Args:
        config (`TinyTimeMixerConfig`, *required*): Configuration.
        distribution_output (*optional*): object providing `get_parameter_projection(in_features)`;
            when None, a plain `nn.Linear` point-forecast projection is used.
    """

    def __init__(self, config: TinyTimeMixerConfig, distribution_output=None):
        super().__init__()

        # NOTE: sorted in place, so the config's own list becomes sorted too; the channel
        # mixer below is built from a deep copy of this (already sorted) config.
        self.prediction_channel_indices = config.prediction_channel_indices
        if self.prediction_channel_indices is not None:
            self.prediction_channel_indices.sort()

        self.prediction_filter_length = config.prediction_filter_length

        self.dropout_layer = nn.Dropout(config.head_dropout)
        self.enable_forecast_channel_mixing = config.enable_forecast_channel_mixing

        head_d_model = config.decoder_d_model if config.use_decoder else config.d_model
        flat_features = config.num_patches * head_d_model

        if distribution_output is None:
            self.base_forecast_block = nn.Linear(flat_features, config.prediction_length)
        else:
            self.base_forecast_block = distribution_output.get_parameter_projection(flat_features)

        self.flatten = nn.Flatten(start_dim=-2)

        if self.enable_forecast_channel_mixing:
            fcm_config = copy.deepcopy(config)
            if self.prediction_filter_length is not None:
                fcm_config.prediction_length = self.prediction_filter_length
            self.fcm_block = ForecastChannelHeadMixer(config=fcm_config)

    def forward(self, hidden_features, past_values, future_values=None):
        """
        Args:
            hidden_features `(batch_size, n_vars, num_patch, d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
                features.

            past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
                Context values of the time series. For a forecasting task, this denotes the history/past time series values.
                For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
                greater than 1.

            future_values (`torch.Tensor` of shape `(batch_size, prediction length, input_channels)`, *optional*, Defaults to None):
                Actual groundtruths of the forecasts. Pass dummy values (say 0) for forecast channels, if groundtruth is unknown.
                Pass the correct values for Exogenous channels where the forecast values are known.

        Returns:
            `torch.Tensor` of shape `(batch_size, prediction_length, forecast_channels)`, or a tuple of such
            tensors when the projection returns a tuple (distribution head).

        Raises:
            ValueError: if channel mixing is enabled and the projection returned a tuple.
        """

        def _apply(fn, value):
            # Distribution projections may return a tuple of tensors; transform each one.
            if isinstance(value, tuple):
                return tuple(fn(part) for part in value)
            return fn(value)

        flat = self.flatten(hidden_features)  # [batch_size x n_vars x num_patch * d_model]
        flat = self.dropout_layer(flat)
        projected = self.base_forecast_block(flat)  # [batch_size x n_vars x prediction_length]
        forecast = _apply(lambda part: part.transpose(-1, -2), projected)  # [bs x prediction_length x n_vars]

        channels = self.prediction_channel_indices
        if channels is not None:
            forecast = _apply(lambda part: part[..., channels], forecast)

        horizon = self.prediction_filter_length
        if horizon is not None:
            forecast = _apply(lambda part: part[:, :horizon, :], forecast)
            if future_values is not None and future_values.shape[1] != horizon:
                future_values = future_values[:, :horizon, :]

        if self.enable_forecast_channel_mixing:
            if isinstance(forecast, tuple):
                raise ValueError("Forecast channel mixing is not enabled for distribution head")
            forecast = self.fcm_block(forecast, past_values=past_values, future_values=future_values)

        return forecast
|
|
1152
|
+
|
|
1153
|
+
|
|
1154
|
+
class TinyTimeMixerPreTrainedModel(PreTrainedModel):
    """Base pretrained-model class for TinyTimeMixer; defines config binding and weight initialization."""

    # Weight initialization
    config_class = TinyTimeMixerConfig
    base_model_prefix = "model"
    main_input_name = "past_values"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """Initialize weights.

        Positional encodings are re-drawn only for the "random" encoding type. Normalization layers
        get weight=1 and bias=0. Linear and Embedding layers follow `config.init_linear` /
        `config.init_embed` ("normal", "uniform", "xavier_uniform", anything else -> the layer's own
        `reset_parameters`).
        """

        if isinstance(module, TinyTimeMixerPositionalEncoding):
            # initialize positional encoding
            if self.config.positional_encoding_type == "random":
                nn.init.normal_(module.position_enc, mean=0.0, std=0.1)
        elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
            # Affine parameters are None for elementwise_affine=False / affine=False
            # (and LayerNorm(bias=False)); skip them instead of raising AttributeError.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
        elif isinstance(module, TinyTimeMixerBatchNorm):
            if module.batchnorm.bias is not None:
                module.batchnorm.bias.data.zero_()
            if module.batchnorm.weight is not None:
                module.batchnorm.weight.data.fill_(1.0)
        elif isinstance(module, nn.Linear):
            if self.config.init_linear == "normal":
                module.weight.data.normal_(mean=0.0, std=self.config.init_std)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif self.config.init_linear == "uniform":
                nn.init.uniform_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif self.config.init_linear == "xavier_uniform":
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            else:
                module.reset_parameters()
        elif isinstance(module, nn.Embedding):
            if self.config.init_embed == "normal":
                nn.init.normal_(module.weight)
            elif self.config.init_embed == "uniform":
                nn.init.uniform_(module.weight)
            elif self.config.init_embed == "xavier_uniform":
                nn.init.xavier_uniform_(module.weight)
            else:
                module.reset_parameters()
|
|
1200
|
+
|
|
1201
|
+
|
|
1202
|
+
class TinyTimeMixerPatchify(nn.Module):
    """
    Split a time series into (possibly overlapping) patches along the time axis.

    Leading time steps that do not fit an integer number of strides are dropped, so the
    patches always end at the last time step.

    Returns:
        `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
    """

    def __init__(self, config: TinyTimeMixerConfig):
        super().__init__()

        if config.masked_context_length is not None:
            self.sequence_length = config.masked_context_length
        else:
            self.sequence_length = config.context_length

        self.patch_length = config.patch_length
        self.patch_stride = config.patch_stride

        if self.sequence_length <= self.patch_length:
            raise ValueError(
                f"Sequence length ({self.sequence_length}) has to be greater than the patch length ({self.patch_length})"
            )

        # sequence_length > patch_length is guaranteed by the check above.
        self.num_patches = (self.sequence_length - self.patch_length) // self.patch_stride + 1
        covered_length = self.patch_length + self.patch_stride * (self.num_patches - 1)
        # Number of leading steps to skip so the patches align with the end of the series.
        self.sequence_start = self.sequence_length - covered_length

    def forward(self, past_values: torch.Tensor):
        """
        Parameters:
            past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*):
                Input for patchification

        Returns:
            `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`

        Raises:
            ValueError: if the time dimension does not match the configured sequence length.
        """
        observed_length = past_values.shape[-2]
        if observed_length != self.sequence_length:
            raise ValueError(
                f"Input sequence length ({observed_length}) doesn't match model configuration ({self.sequence_length})."
            )

        tail = past_values[:, self.sequence_start :, :]  # [bs x covered_length x num_channels]
        windows = tail.unfold(
            dimension=-2, size=self.patch_length, step=self.patch_stride
        )  # [bs x num_patches x num_channels x patch_length]
        return windows.transpose(-2, -3).contiguous()  # [bs x num_channels x num_patches x patch_length]
|
|
1251
|
+
|
|
1252
|
+
|
|
1253
|
+
class TinyTimeMixerStdScaler(nn.Module):
    """
    Standardize data using the mean and standard deviation of the observed entries along
    `scaling_dim` (default 1): returns `(data - loc) / scale`.
    """

    def __init__(self, config: TinyTimeMixerConfig):
        super().__init__()
        self.dim = getattr(config, "scaling_dim", 1)
        self.keepdim = getattr(config, "keepdim", True)
        self.minimum_scale = getattr(config, "minimum_scale", 1e-5)

    def forward(
        self, data: torch.Tensor, observed_indicator: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
            observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Calculating the scale on the observed indicator.
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
        """
        dim, keep = self.dim, self.keepdim

        # Number of observed entries, at least 1 to avoid division by zero.
        count = observed_indicator.sum(dim, keepdim=keep)
        count = torch.clamp(count, min=1)

        loc = (data * observed_indicator).sum(dim, keepdim=keep) / count
        deviations = (data - loc) * observed_indicator
        variance = deviations.pow(2).sum(dim, keepdim=keep) / count
        # minimum_scale keeps the scale strictly positive for constant series.
        scale = torch.sqrt(variance + self.minimum_scale)

        return (data - loc) / scale, loc, scale
|
|
1287
|
+
|
|
1288
|
+
|
|
1289
|
+
class TinyTimeMixerMeanScaler(nn.Module):
    """
    Scale data by the mean absolute value of its observed entries along `scaling_dim`
    (default 1). Series with no observations fall back to `default_scale` (or to the
    batch-wide mean absolute value when `default_scale` is None).
    """

    def __init__(self, config: TinyTimeMixerConfig):
        super().__init__()
        self.dim = getattr(config, "scaling_dim", 1)
        self.keepdim = getattr(config, "keepdim", True)
        self.minimum_scale = getattr(config, "minimum_scale", 1e-10)
        self.default_scale = getattr(config, "default_scale", None)

    def forward(
        self, data: torch.Tensor, observed_indicator: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
            observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Calculating the scale on the observed indicator.
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
            The second element (loc) is all zeros.
        """
        abs_total = (data * observed_indicator).abs().sum(self.dim, keepdim=True)
        n_observed = observed_indicator.sum(self.dim, keepdim=True)

        scale = abs_total / torch.clamp(n_observed, min=1)

        if self.default_scale is not None:
            fallback = self.default_scale * torch.ones_like(scale)
        else:
            # Pool over the batch to get a fallback scale per channel.
            pooled_total = abs_total.sum(dim=0)
            pooled_count = torch.clamp(n_observed.sum(0), min=1)
            fallback = torch.squeeze(pooled_total / pooled_count)

        # Series with no observations use the fallback scale.
        scale = torch.where(n_observed > 0, scale, fallback)
        scale = torch.clamp(scale, min=self.minimum_scale)

        scaled_data = data / scale
        if not self.keepdim:
            scale = scale.squeeze(dim=self.dim)

        return scaled_data, torch.zeros_like(scale), scale
|
|
1341
|
+
|
|
1342
|
+
|
|
1343
|
+
class TinyTimeMixerNOPScaler(nn.Module):
    """
    Identity scaler: hands `data` back untouched together with `loc = 0` and `scale = 1`.

    `loc`/`scale` have the shape of `data` reduced over `scaling_dim` (default 1), keeping that
    dimension when `keepdim` (default True), so it can stand in for the other scalers.
    """

    def __init__(self, config: TinyTimeMixerConfig):
        super().__init__()
        self.dim = getattr(config, "scaling_dim", 1)
        self.keepdim = getattr(config, "keepdim", True)

    def forward(
        self, data: torch.Tensor, observed_indicator: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
            observed_indicator: accepted for interface compatibility with the other scalers; unused.
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
        """
        reduce_kwargs = {"dim": self.dim, "keepdim": self.keepdim}
        unit_scale = torch.ones_like(data, requires_grad=False).mean(**reduce_kwargs)
        zero_loc = torch.zeros_like(data, requires_grad=False).mean(**reduce_kwargs)
        return data, zero_loc, unit_scale
|
|
1368
|
+
|
|
1369
|
+
|
|
1370
|
+
@dataclass
class TinyTimeMixerEncoderOutput(ModelOutput):
    """
    Container for the outputs of `TinyTimeMixerEncoder`.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
            Embeddings produced by the last mixer layer (one extra patch is present when resolution
            prefix tuning is enabled).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the model at the output of each layer.
    """

    # Field order matters: tuples are unpacked positionally into this class.
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
|
1384
|
+
|
|
1385
|
+
|
|
1386
|
+
class TinyTimeMixerEncoder(TinyTimeMixerPreTrainedModel):
    """
    Encoder for TinyTimeMixer which inputs patched time-series and outputs patched embeddings.

    Each patch is projected to `d_model` by a linear layer. The encoder then optionally
    prepends a learned resolution (frequency) token, optionally adds positional encodings,
    and runs the result through the TinyTimeMixer mixer block.

    Args:
        config (`TinyTimeMixerConfig`, *required*):
            Configuration. When `config.init_processing` is `False`, it is completed in place via
            `config.check_and_init_preprocessing()` before the layers are built.
    """

    def __init__(self, config: TinyTimeMixerConfig):
        if config.init_processing is False:
            config.check_and_init_preprocessing()

        super().__init__(config)

        self.use_return_dict = config.use_return_dict

        # Per-patch projection: patch_length -> d_model.
        self.patcher = nn.Linear(config.patch_length, config.d_model)
        if config.use_positional_encoding:
            self.positional_encoder = TinyTimeMixerPositionalEncoding(config=config)
        else:
            self.positional_encoder = None
        self.mlp_mixer_encoder = TinyTimeMixerBlock(config=config)

        if config.resolution_prefix_tuning:
            mid_dim = (config.patch_length + config.d_model) // 2

            # Maps an integer frequency token to a d_model-sized embedding that `forward`
            # prepends as an extra patch.
            self.freq_mod = nn.Sequential(
                nn.Embedding(config.frequency_token_vocab_size, config.patch_length),
                nn.Linear(config.patch_length, mid_dim),
                nn.GELU(),
                nn.Linear(mid_dim, config.d_model),
            )
        self.resolution_prefix_tuning = config.resolution_prefix_tuning
        self.d_model = config.d_model

        # # Initialize weights and apply final processing
        # if config.post_init:
        #     self.post_init()

    def forward(
        self,
        past_values: torch.Tensor,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = None,
        freq_token: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, TinyTimeMixerEncoderOutput]:
        r"""
        Args:
            past_values (`torch.FloatTensor`):
                Patched context values. The last dimension must be `patch_length`, because the tensor
                is fed to `nn.Linear(patch_length, d_model)`. `TinyTimeMixerModel` passes a tensor of
                shape `(batch_size, num_input_channels, num_patches, patch_length)`.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Defaults to
                `config.use_return_dict`.
            freq_token (`torch.Tensor`, *optional*):
                Integer frequency/resolution token per sample. Required when
                `config.resolution_prefix_tuning` is enabled; ignored otherwise.

        Returns:
            `TinyTimeMixerEncoderOutput`, or the tuple `(last_hidden_state, hidden_states)` when
            `return_dict` is false. `last_hidden_state` has shape
            `(batch_size, n_vars, num_patches, d_model)`, with one extra leading patch when resolution
            prefix tuning is enabled.

        Raises:
            ValueError: If resolution prefix tuning is enabled and `freq_token` is `None`.
        """
        return_dict = return_dict if return_dict is not None else self.use_return_dict

        # flatten [bs x num_patch x d_model]. common_channel/mix_channel: [bs x n_vars x num_patch x d_model]
        patches = self.patcher(past_values)

        if self.resolution_prefix_tuning:
            if freq_token is None:
                # ValueError subclasses Exception, so existing `except Exception` handlers still work.
                raise ValueError("Expecting freq_token in forward")

            freq_embedding = self.freq_mod(freq_token.long())  # bs x d_model
            freq_embedding = freq_embedding.view(patches.shape[0], 1, 1, self.d_model)
            # bs x channels x 1 x num_features
            freq_embedding = freq_embedding.expand(patches.shape[0], patches.shape[1], 1, self.d_model)

            # Prepend the frequency embedding as an extra patch: bs x channels x num_patch+1 x num_features
            patches = torch.cat((freq_embedding, patches), dim=-2)

        if self.positional_encoder is not None:
            patches = self.positional_encoder(patches)

        last_hidden_state, hidden_states = self.mlp_mixer_encoder(patches, output_hidden_states=output_hidden_states)

        if not return_dict:
            return (last_hidden_state, hidden_states)

        return TinyTimeMixerEncoderOutput(last_hidden_state=last_hidden_state, hidden_states=hidden_states)
|
|
1488
|
+
|
|
1489
|
+
|
|
1490
|
+
@dataclass
class TinyTimeMixerModelOutput(ModelOutput):
    """
    Container for the outputs of the `TinyTimeMixerModel` backbone.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
            Hidden-state at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the model at the output of each layer.
        patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
            Scaled and patched input that was fed to the encoder.
        loc (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
            Location statistic returned by the backbone's scaler; zeros for the no-op scaler.
            Used to undo the scaling (revin denorm) outside the backbone.
        scale (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
            Scale statistic returned by the backbone's scaler; ones for the no-op scaler.
            Used to undo the scaling (revin denorm) outside the backbone.
    """

    # Field order matters: tuples are unpacked positionally into this class.
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    patch_input: Optional[torch.FloatTensor] = None
    loc: Optional[torch.FloatTensor] = None
    scale: Optional[torch.FloatTensor] = None
|
|
1515
|
+
|
|
1516
|
+
|
|
1517
|
+
class TinyTimeMixerModel(TinyTimeMixerPreTrainedModel):
    """
    TinyTimeMixer backbone: scales the context window, splits it into patches and encodes them.

    Args:
        config (`TinyTimeMixerConfig`, *required*):
            Configuration. When `config.init_processing` is `False`, it is completed in place via
            `config.check_and_init_preprocessing()`. `config.scaling` selects the scaler:
            `"mean"` -> `TinyTimeMixerMeanScaler`, `"std"` or `True` -> `TinyTimeMixerStdScaler`,
            anything else -> `TinyTimeMixerNOPScaler`.
    """

    def __init__(self, config: TinyTimeMixerConfig):
        if config.init_processing is False:
            config.check_and_init_preprocessing()

        super().__init__(config)

        self.use_return_dict = config.use_return_dict
        self.encoder = TinyTimeMixerEncoder(config)
        self.patching = TinyTimeMixerPatchify(config)

        if config.scaling == "mean":
            self.scaler = TinyTimeMixerMeanScaler(config)
        elif config.scaling == "std" or config.scaling is True:
            self.scaler = TinyTimeMixerStdScaler(config)
        else:
            self.scaler = TinyTimeMixerNOPScaler(config)

        self.d_model = config.d_model

        # # Initialize weights and apply final processing
        # if config.post_init:
        #     self.post_init()

    def forward(
        self,
        past_values: torch.Tensor,
        past_observed_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = None,
        freq_token: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, TinyTimeMixerModelOutput]:
        r"""
        Args:
            past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Context values of the time series.
            past_observed_mask (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]` or `[False, True]`:
                    - 1 or True for values that are **observed**,
                    - 0 or False for values that are **missing** (i.e. NaNs that were replaced by zeros).
                Defaults to all-observed.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            return_dict (`bool`, *optional*):
                Whether to return a `TinyTimeMixerModelOutput` instead of a tuple. Defaults to
                `config.use_return_dict`.
            freq_token (`torch.Tensor`, *optional*):
                Forwarded to the encoder; required when `config.resolution_prefix_tuning` is enabled.

        Returns:
            `TinyTimeMixerModelOutput`, or the tuple
            `(last_hidden_state, hidden_states, patch_input, loc, scale)` when `return_dict` is false.
        """
        return_dict = return_dict if return_dict is not None else self.use_return_dict

        if past_observed_mask is None:
            # Treat every value as observed.
            past_observed_mask = torch.ones_like(past_values)
        scaled_past_values, loc, scale = self.scaler(past_values, past_observed_mask)

        # [batch_size x num_input_channels x num_patch x patch_length]
        patched_x = self.patching(scaled_past_values)

        encoder_output = self.encoder(
            patched_x,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            freq_token=freq_token,
        )

        if isinstance(encoder_output, tuple):
            encoder_output = TinyTimeMixerEncoderOutput(*encoder_output)

        if not return_dict:
            return (
                encoder_output.last_hidden_state,
                encoder_output.hidden_states,
                patched_x,
                loc,
                scale,
            )

        return TinyTimeMixerModelOutput(
            last_hidden_state=encoder_output.last_hidden_state,
            hidden_states=encoder_output.hidden_states,
            patch_input=patched_x,
            loc=loc,
            scale=scale,
        )
|
|
1598
|
+
|
|
1599
|
+
|
|
1600
|
+
@dataclass
class TinyTimeMixerForPredictionOutput(ModelOutput):
    """
    Output container for [`TinyTimeMixerForPrediction`].

    Args:
        loss (`torch.FloatTensor` of shape `()`, *optional*):
            Total loss, when one is computed.
        prediction_outputs (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_input_channels)`):
            Prediction output from the forecast head.
        backbone_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
            Backbone embeddings before passing through the decoder.
        decoder_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
            Decoder embeddings before passing through the head.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        loc (`torch.FloatTensor` of shape `(batch_size, 1, num_input_channels)`, *optional*):
            Location statistic produced by the input scaler.
        scale (`torch.FloatTensor` of shape `(batch_size, 1, num_input_channels)`, *optional*):
            Scale statistic produced by the input scaler.
    """

    # Field order matters for positional construction/unpacking.
    loss: Optional[torch.FloatTensor] = None
    prediction_outputs: Optional[torch.FloatTensor] = None
    backbone_hidden_state: Optional[torch.FloatTensor] = None
    decoder_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    loc: Optional[torch.FloatTensor] = None
    scale: Optional[torch.FloatTensor] = None
|
|
1630
|
+
|
|
1631
|
+
|
|
1632
|
+
@dataclass
class SampleTinyTimeMixerPredictionOutput(ModelOutput):
    """
    Samples drawn from the forecast distribution of a probabilistic TinyTimeMixer head.

    Args:
        sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, number_channels)`):
            Sampled values from the chosen distribution, stacked along dimension 1.
    """

    sequences: Optional[torch.FloatTensor] = None
|
|
1644
|
+
|
|
1645
|
+
|
|
1646
|
+
def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor:
    """
    Return the element-wise negative log-likelihood of `target` under the distribution `input`.

    No reduction is applied; the result has the shape of `input.log_prob(target)`.
    """
    log_likelihood = input.log_prob(target)
    return log_likelihood.neg()
|
|
1651
|
+
|
|
1652
|
+
|
|
1653
|
+
def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor:
    """
    Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
    meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.

    Args:
        input_tensor (`torch.FloatTensor`):
            Input tensor, of which the average must be computed.
        weights (`torch.FloatTensor`, *optional*):
            Weights tensor, of the same shape as `input_tensor`.
        dim (`int`, *optional*):
            The dim along which to average `input_tensor`. `None` averages over all elements;
            `0` is a valid axis (it was previously treated like `None`).

    Returns:
        `torch.FloatTensor`: The tensor with values averaged along the specified `dim`.
        The sum of weights is clamped to at least 1, so all-zero weights yield 0 instead of NaN.
    """
    if weights is None:
        # Avoid passing dim=None to `mean`, which older torch versions reject.
        return input_tensor.mean() if dim is None else input_tensor.mean(dim=dim)

    # `torch.where` drops entries with zero weight, so NaNs there do not propagate.
    weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor))
    if dim is None:
        total, sum_weights = weighted_tensor.sum(), weights.sum()
    else:
        total, sum_weights = weighted_tensor.sum(dim=dim), weights.sum(dim=dim)
    return total / torch.clamp(sum_weights, min=1.0)
|
|
1675
|
+
|
|
1676
|
+
|
|
1677
|
+
class TinyTimeMixerForPrediction(TinyTimeMixerPreTrainedModel):
    r"""
    `TinyTimeMixer` for forecasting application.

    In this package `forward` returns only the de-normalised forecast tensor. It returns no
    loss and no `TinyTimeMixerForPredictionOutput`; loss computation is left to the caller.

    Args:
        config (`TinyTimeMixerConfig`, *required*):
            Configuration. `config.check_and_init_preprocessing()` is called on it (in place) before
            the sub-modules are built.
    """

    def __init__(self, config: TinyTimeMixerConfig):
        config.check_and_init_preprocessing()
        super().__init__(config)

        self.config = config

        self.loss = config.loss

        self.use_return_dict = config.use_return_dict

        self.prediction_channel_indices = config.prediction_channel_indices
        self.num_parallel_samples = config.num_parallel_samples

        self.num_input_channels = config.num_input_channels

        self.prediction_filter_length = config.prediction_filter_length

        # Point-forecast losses use a plain regression head; "nll" needs a distribution head.
        if config.loss in ["mse", "mae", "pinball", "huber"] or config.loss is None:
            self.distribution_output = None
        elif config.loss == "nll":
            # The distribution head covers the (possibly filtered) forecast horizon.
            if self.prediction_filter_length is None:
                dim = config.prediction_length
            else:
                dim = config.prediction_filter_length

            distribution_output_map = {
                "student_t": StudentTOutput,
                "normal": NormalOutput,
                "negative_binomial": NegativeBinomialOutput,
            }
            output_class = distribution_output_map.get(config.distribution_output, None)
            if output_class is not None:
                self.distribution_output = output_class(dim=dim)
            else:
                raise ValueError(f"Unknown distribution output {config.distribution_output}")
        # NOTE(review): any other `config.loss` value leaves `self.distribution_output` unset, so
        # building the head below fails with AttributeError — confirm whether that is intended.

        self.backbone = TinyTimeMixerModel(config)

        self.use_decoder = config.use_decoder

        if config.use_decoder:
            self.decoder = TinyTimeMixerDecoder(config)

        self.head = TinyTimeMixerForPredictionHead(
            config=config,
            distribution_output=self.distribution_output,
        )

        # Initialize weights and apply final processing
        if config.post_init:
            self.post_init()

    def _trim_context(
        self, past_values: torch.Tensor, past_observed_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Validate `past_values` and keep only the last `sequence_length` steps of it and of its mask."""
        if past_values.dim() != 3:
            raise ValueError(
                "`past_values` must have 3 dimensions of shape `(batch_size, sequence_length, num_input_channels)`."
            )

        sequence_length = (
            self.config.masked_context_length
            if self.config.masked_context_length is not None
            else self.config.context_length
        )

        if past_values.shape[1] > sequence_length:
            past_values = past_values[:, -sequence_length:, :]
            # Keep the mask aligned with the trimmed values; otherwise the scaler sees mismatched shapes.
            if past_observed_mask is not None:
                past_observed_mask = past_observed_mask[:, -sequence_length:, :]
        elif past_values.shape[1] < sequence_length:
            raise ValueError("Context length in `past_values` is shorter that TTM context_length.")

        return past_values, past_observed_mask

    def _normalized_forecast(
        self,
        past_values: torch.Tensor,
        past_observed_mask: Optional[torch.Tensor] = None,
        future_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = None,
        freq_token: Optional[torch.Tensor] = None,
        static_categorical_values: Optional[torch.Tensor] = None,
    ):
        """
        Run backbone, optional decoder and head without undoing the input scaling.

        Returns:
            `(y_hat, loc, scale)`. `y_hat` is the head output in the scaled space. `loc`/`scale`
            are the scaler statistics, restricted to `prediction_channel_indices` when that is set.
        """
        past_values, past_observed_mask = self._trim_context(past_values, past_observed_mask)

        return_dict = return_dict if return_dict is not None else self.use_return_dict

        # past_values: tensor [batch_size x context_length x num_input_channels]
        model_output = self.backbone(
            past_values,
            past_observed_mask=past_observed_mask,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            freq_token=freq_token,
        )  # model_output: [batch_size x nvars x num_patch x d_model]

        if isinstance(model_output, tuple):
            model_output = TinyTimeMixerModelOutput(*model_output)

        if self.use_decoder:
            # Decoder hidden states are not returned by this class, so they are discarded.
            decoder_output, _ = self.decoder(
                hidden_state=model_output.last_hidden_state,
                patch_input=model_output.patch_input,
                output_hidden_states=output_hidden_states,
                static_categorical_values=static_categorical_values,
            )  # [batch_size x nvars x num_patch x decoder_d_model]
        else:
            decoder_output = model_output.last_hidden_state

        # head should take future mask
        y_hat = self.head(decoder_output, past_values=past_values, future_values=future_values)

        # loc/scale: batch_size x 1 x prediction_channel_indices or num_targets
        if self.prediction_channel_indices is not None:
            loc = model_output.loc[..., self.prediction_channel_indices]
            scale = model_output.scale[..., self.prediction_channel_indices]
        else:
            loc = model_output.loc
            scale = model_output.scale

        return y_hat, loc, scale

    def forward(
        self,
        past_values: torch.Tensor,
        future_values: Optional[torch.Tensor] = None,
        past_observed_mask: Optional[torch.Tensor] = None,
        future_observed_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = False,
        return_loss: bool = True,
        return_dict: Optional[bool] = None,
        freq_token: Optional[torch.Tensor] = None,
        static_categorical_values: Optional[torch.Tensor] = None,
        metadata: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
            Context window. Longer inputs are trimmed to the last `context_length` (or
            `masked_context_length`) steps; shorter inputs raise `ValueError`.
        past_observed_mask (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]` or `[False, True]`:
                - 1 or True for values that are **observed**,
                - 0 or False for values that are **missing** (i.e. NaNs that were replaced by zeros).
            Trimmed together with `past_values`.
        future_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)`, *optional*):
            Passed unchanged to the prediction head.
        future_observed_mask, return_loss, metadata:
            Accepted for API compatibility; not used by this implementation (no loss is computed).
        output_hidden_states (`bool`, *optional*):
            Forwarded to the backbone and decoder.
        return_dict (`bool`, *optional*):
            Controls the backbone's internal output format only; this method always returns a tensor.
        freq_token (`torch.Tensor`, *optional*):
            Required when `config.resolution_prefix_tuning` is enabled.
        static_categorical_values (`torch.FloatTensor` of shape `(batch_size, number_of_categorical_variables)`, *optional*):
            Tokenized categorical values can be passed here. Ensure to pass in the same order as the vocab size list used in the
            TinyTimeMixerConfig param `categorical_vocab_size_list`

        Returns:
            `torch.Tensor`: the head output mapped back to the input scale (`y_hat * scale + loc`),
            tensor [batch_size x prediction_length x num_input_channels].

        NOTE(review): with a distribution head (`loss="nll"`) the head output may not be a plain
        tensor, in which case the de-normalisation here fails; use `generate` for that setup.
        """
        y_hat, loc, scale = self._normalized_forecast(
            past_values,
            past_observed_mask=past_observed_mask,
            future_values=future_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            freq_token=freq_token,
            static_categorical_values=static_categorical_values,
        )
        return y_hat * scale + loc

    def generate(
        self,
        past_values: torch.Tensor,
        past_observed_mask: Optional[torch.Tensor] = None,
    ) -> SampleTinyTimeMixerPredictionOutput:
        """
        Generate sequences of sample predictions from a model with a probability distribution head.

        Args:
            past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Past values of the time series that serves as context in order to predict the future.

            past_observed_mask (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]` or `[False, True]`:
                    - 1 or True for values that are **observed**,
                    - 0 or False for values that are **missing** (i.e. NaNs that were replaced by zeros).

        Return:
            [`SampleTinyTimeMixerPredictionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
            number of samples, prediction_length, num_input_channels)`.

        Raises:
            ValueError: If the model was built without a distribution head.

        Note:
            `forward` returns only a de-normalised tensor, so this method uses the raw head output
            from `_normalized_forecast` instead of calling `self(...)`. Subclasses that override
            `forward` to transform inputs are therefore not routed through that override.
        """
        if self.distribution_output is None:
            raise ValueError("`generate` requires a distribution head; build the model with `loss='nll'`.")

        distribution_params, loc, scale = self._normalized_forecast(
            past_values,
            past_observed_mask=past_observed_mask,
            output_hidden_states=False,
        )
        distribution = self.distribution_output.distribution(distribution_params, loc=loc, scale=scale)

        # [batch_size x num_samples x prediction_length x num_channels]
        samples = torch.stack([distribution.sample() for _ in range(self.num_parallel_samples)], dim=1)
        return SampleTinyTimeMixerPredictionOutput(sequences=samples)
|
|
1988
|
+
|
|
1989
|
+
|
|
1990
|
+
class TinyTimeMixerForMaskedPrediction(TinyTimeMixerForPrediction):
|
|
1991
|
+
def __init__(self, config: TinyTimeMixerConfig):
    """
    Configure masked prediction: extend the context by the forecast horizon and record which
    channels are not exogenous.

    `config` is mutated in place before the parent constructor builds the sub-modules:
    `masked_context_length = context_length + append_length` and
    `fcm_prepend_past_offset = append_length`. Here `append_length` is
    `prediction_filter_length` when set, otherwise `prediction_length`.
    """
    if config.prediction_filter_length is not None:
        append_length = config.prediction_filter_length
    else:
        append_length = config.prediction_length

    self.append_length = append_length
    # These must be set before super().__init__, which builds the backbone from the config.
    config.masked_context_length = config.context_length + append_length
    config.fcm_prepend_past_offset = append_length

    if config.exogenous_channel_indices is not None:
        # sorted(): a plain set difference has no guaranteed iteration order,
        # e.g. {3, 9} can iterate as [9, 3].
        self.non_exog_channels = sorted(
            set(range(config.num_input_channels)) - set(config.exogenous_channel_indices)
        )
    else:
        self.non_exog_channels = list(range(config.num_input_channels))

    super().__init__(config)
|
|
2009
|
+
|
|
2010
|
+
def forward(
|
|
2011
|
+
self,
|
|
2012
|
+
past_values: torch.Tensor,
|
|
2013
|
+
future_values: Optional[torch.Tensor] = None,
|
|
2014
|
+
past_observed_mask: Optional[torch.Tensor] = None,
|
|
2015
|
+
future_observed_mask: Optional[torch.Tensor] = None,
|
|
2016
|
+
output_hidden_states: Optional[bool] = False,
|
|
2017
|
+
return_loss: bool = True,
|
|
2018
|
+
return_dict: Optional[bool] = None,
|
|
2019
|
+
freq_token: Optional[torch.Tensor] = None,
|
|
2020
|
+
static_categorical_values: Optional[torch.Tensor] = None,
|
|
2021
|
+
metadata: Optional[torch.Tensor] = None,
|
|
2022
|
+
) -> TinyTimeMixerForPredictionOutput:
|
|
2023
|
+
r"""
|
|
2024
|
+
past_observed_mask (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
|
|
2025
|
+
Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
|
|
2026
|
+
in `[0, 1]` or `[False, True]`:
|
|
2027
|
+
- 1 or True for values that are **observed**,
|
|
2028
|
+
- 0 or False for values that are **missing** (i.e. NaNs that were replaced by zeros).
|
|
2029
|
+
future_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,:
|
|
2030
|
+
`(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*): Target
|
|
2031
|
+
values of the time series, that serve as labels for the model. The `future_values` is what the
|
|
2032
|
+
Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
|
|
2033
|
+
required for a pretraining task.
|
|
2034
|
+
|
|
2035
|
+
For a forecasting task, the shape is be `(batch_size, target_len, num_input_channels)`. Even if we want
|
|
2036
|
+
to forecast only specific channels by setting the indices in `prediction_channel_indices` parameter,
|
|
2037
|
+
pass the target data with all channels, as channel Filtering for both prediction and target will be
|
|
2038
|
+
manually applied before the loss computation.
|
|
2039
|
+
future_observed_mask (`torch.Tensor` of shape `(batch_size, prediction_length, num_targets)`, *optional*):
|
|
2040
|
+
Boolean mask to indicate which `future_values` were observed and which were missing. Mask values selected
|
|
2041
|
+
in `[0, 1]` or `[False, True]`:
|
|
2042
|
+
- 1 or True for values that are **observed**,
|
|
2043
|
+
- 0 or False for values that are **missing** (i.e. NaNs that were replaced by zeros).
|
|
2044
|
+
return_loss (`bool`, *optional*):
|
|
2045
|
+
Whether to return the loss in the `forward` call.
|
|
2046
|
+
static_categorical_values (`torch.FloatTensor` of shape `(batch_size, number_of_categorical_variables)`, *optional*):
|
|
2047
|
+
Tokenized categorical values can be passed here. Ensure to pass in the same order as the vocab size list used in the
|
|
2048
|
+
TinyTimeMixerConfig param `categorical_vocab_size_list`
|
|
2049
|
+
metadata (`torch.Tensor`, *optional*): A tensor containing metadata. Currently unused in TinyTimeMixer, but used
|
|
2050
|
+
to support custom trainers. Defaults to None.
|
|
2051
|
+
|
|
2052
|
+
Returns:
|
|
2053
|
+
|
|
2054
|
+
"""
|
|
2055
|
+
if future_values is not None:
|
|
2056
|
+
future_values_masked = future_values.clone()
|
|
2057
|
+
else:
|
|
2058
|
+
future_values_masked = torch.zeros(past_values.shape[0], self.append_length, past_values.shape[2])
|
|
2059
|
+
|
|
2060
|
+
if (
|
|
2061
|
+
self.config.prediction_filter_length is not None
|
|
2062
|
+
and future_values_masked is not None
|
|
2063
|
+
and future_values_masked.shape[1] != self.config.prediction_filter_length
|
|
2064
|
+
):
|
|
2065
|
+
future_values_masked = future_values_masked[:, : self.config.prediction_filter_length, :]
|
|
2066
|
+
|
|
2067
|
+
if self.config.exogenous_channel_indices is not None:
|
|
2068
|
+
future_values_masked[:, :, self.non_exog_channels] = self.config.mask_value
|
|
2069
|
+
else:
|
|
2070
|
+
future_values_masked.fill_(self.config.mask_value)
|
|
2071
|
+
past_values = torch.cat((past_values, future_values_masked), dim=-2) # xb: [bs x seq_len+ fl x n_vars]
|
|
2072
|
+
|
|
2073
|
+
if past_observed_mask is None:
|
|
2074
|
+
past_observed_mask = torch.ones_like(past_values)
|
|
2075
|
+
|
|
2076
|
+
# if there is already a past mask - update with it
|
|
2077
|
+
# index 1 refers to the seq len
|
|
2078
|
+
|
|
2079
|
+
if past_observed_mask.shape[1] < past_values.shape[1]:
|
|
2080
|
+
temp_mask = torch.ones_like(past_values)
|
|
2081
|
+
temp_mask[:, : past_observed_mask.shape[1], :] = past_observed_mask
|
|
2082
|
+
past_observed_mask = temp_mask
|
|
2083
|
+
|
|
2084
|
+
# past_observed_mask[:, -self.config.prediction_length :, :] = 0
|
|
2085
|
+
past_observed_mask[:, -self.config.prediction_length :, self.non_exog_channels] = 0
|
|
2086
|
+
# [bs x seq_len+ fl x n_vars]
|
|
2087
|
+
|
|
2088
|
+
return super().forward(
|
|
2089
|
+
past_values=past_values,
|
|
2090
|
+
future_values=future_values,
|
|
2091
|
+
past_observed_mask=past_observed_mask,
|
|
2092
|
+
future_observed_mask=future_observed_mask,
|
|
2093
|
+
output_hidden_states=output_hidden_states,
|
|
2094
|
+
return_loss=return_loss,
|
|
2095
|
+
return_dict=return_dict,
|
|
2096
|
+
freq_token=freq_token,
|
|
2097
|
+
static_categorical_values=static_categorical_values,
|
|
2098
|
+
metadata=metadata,
|
|
2099
|
+
)
|