dsipts-1.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dsipts might be problematic.

Files changed (81)
  1. dsipts/__init__.py +48 -0
  2. dsipts/data_management/__init__.py +0 -0
  3. dsipts/data_management/monash.py +338 -0
  4. dsipts/data_management/public_datasets.py +162 -0
  5. dsipts/data_structure/__init__.py +0 -0
  6. dsipts/data_structure/data_structure.py +1167 -0
  7. dsipts/data_structure/modifiers.py +213 -0
  8. dsipts/data_structure/utils.py +173 -0
  9. dsipts/models/Autoformer.py +199 -0
  10. dsipts/models/CrossFormer.py +152 -0
  11. dsipts/models/D3VAE.py +196 -0
  12. dsipts/models/Diffusion.py +818 -0
  13. dsipts/models/DilatedConv.py +342 -0
  14. dsipts/models/DilatedConvED.py +310 -0
  15. dsipts/models/Duet.py +197 -0
  16. dsipts/models/ITransformer.py +167 -0
  17. dsipts/models/Informer.py +180 -0
  18. dsipts/models/LinearTS.py +222 -0
  19. dsipts/models/PatchTST.py +181 -0
  20. dsipts/models/Persistent.py +44 -0
  21. dsipts/models/RNN.py +213 -0
  22. dsipts/models/Samformer.py +139 -0
  23. dsipts/models/TFT.py +269 -0
  24. dsipts/models/TIDE.py +296 -0
  25. dsipts/models/TTM.py +252 -0
  26. dsipts/models/TimeXER.py +184 -0
  27. dsipts/models/VQVAEA.py +299 -0
  28. dsipts/models/VVA.py +247 -0
  29. dsipts/models/__init__.py +0 -0
  30. dsipts/models/autoformer/__init__.py +0 -0
  31. dsipts/models/autoformer/layers.py +352 -0
  32. dsipts/models/base.py +439 -0
  33. dsipts/models/base_v2.py +444 -0
  34. dsipts/models/crossformer/__init__.py +0 -0
  35. dsipts/models/crossformer/attn.py +118 -0
  36. dsipts/models/crossformer/cross_decoder.py +77 -0
  37. dsipts/models/crossformer/cross_embed.py +18 -0
  38. dsipts/models/crossformer/cross_encoder.py +99 -0
  39. dsipts/models/d3vae/__init__.py +0 -0
  40. dsipts/models/d3vae/diffusion_process.py +169 -0
  41. dsipts/models/d3vae/embedding.py +108 -0
  42. dsipts/models/d3vae/encoder.py +326 -0
  43. dsipts/models/d3vae/model.py +211 -0
  44. dsipts/models/d3vae/neural_operations.py +314 -0
  45. dsipts/models/d3vae/resnet.py +153 -0
  46. dsipts/models/d3vae/utils.py +630 -0
  47. dsipts/models/duet/__init__.py +0 -0
  48. dsipts/models/duet/layers.py +438 -0
  49. dsipts/models/duet/masked.py +202 -0
  50. dsipts/models/informer/__init__.py +0 -0
  51. dsipts/models/informer/attn.py +185 -0
  52. dsipts/models/informer/decoder.py +50 -0
  53. dsipts/models/informer/embed.py +125 -0
  54. dsipts/models/informer/encoder.py +100 -0
  55. dsipts/models/itransformer/Embed.py +142 -0
  56. dsipts/models/itransformer/SelfAttention_Family.py +355 -0
  57. dsipts/models/itransformer/Transformer_EncDec.py +134 -0
  58. dsipts/models/itransformer/__init__.py +0 -0
  59. dsipts/models/patchtst/__init__.py +0 -0
  60. dsipts/models/patchtst/layers.py +569 -0
  61. dsipts/models/samformer/__init__.py +0 -0
  62. dsipts/models/samformer/utils.py +154 -0
  63. dsipts/models/tft/__init__.py +0 -0
  64. dsipts/models/tft/sub_nn.py +234 -0
  65. dsipts/models/timexer/Layers.py +127 -0
  66. dsipts/models/timexer/__init__.py +0 -0
  67. dsipts/models/ttm/__init__.py +0 -0
  68. dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
  69. dsipts/models/ttm/consts.py +16 -0
  70. dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
  71. dsipts/models/ttm/utils.py +438 -0
  72. dsipts/models/utils.py +624 -0
  73. dsipts/models/vva/__init__.py +0 -0
  74. dsipts/models/vva/minigpt.py +83 -0
  75. dsipts/models/vva/vqvae.py +459 -0
  76. dsipts/models/xlstm/__init__.py +0 -0
  77. dsipts/models/xlstm/xLSTM.py +255 -0
  78. dsipts-1.1.5.dist-info/METADATA +31 -0
  79. dsipts-1.1.5.dist-info/RECORD +81 -0
  80. dsipts-1.1.5.dist-info/WHEEL +5 -0
  81. dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/ttm/modeling_tinytimemixer.py
@@ -0,0 +1,2099 @@
1
+ # Copyright contributors to the TSFM project
2
+ #
3
+ # This code is based on layers and components from the PatchTSMixer model in the HuggingFace Transformers
4
+ # Library: https://github.com/huggingface/transformers/blob/main/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py
5
+ """PyTorch TinyTimeMixer model."""
6
+
7
+ import copy
8
+ import math
9
+ from dataclasses import dataclass
10
+ from typing import Optional, Tuple, Union
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.time_series_utils import (
16
+ NegativeBinomialOutput,
17
+ NormalOutput,
18
+ StudentTOutput,
19
+ )
20
+ from transformers.utils import ModelOutput
21
+
22
+ from .configuration_tinytimemixer import TinyTimeMixerConfig
23
+
24
+
25
+
26
+ class PinballLoss(nn.Module):
27
+ def __init__(self, quantile: float):
28
+ """
29
+ Initialize the Pinball Loss for multidimensional tensors.
30
+
31
+ Args:
32
+ quantile (float): The desired quantile (e.g., 0.5 for median, 0.9 for 90th percentile).
33
+ """
34
+ super().__init__()
35
+ self.quantile = quantile
36
+
37
+ def forward(self, predictions, targets):
38
+ """
39
+ Compute the Pinball Loss for shape [b, seq_len, channels].
40
+
41
+ Args:
42
+ predictions (torch.Tensor): Predicted values, shape [b, seq_len, channels].
43
+ targets (torch.Tensor): Ground truth values, shape [b, seq_len, channels].
44
+
45
+ Returns:
46
+ torch.Tensor: The mean pinball loss over all dimensions.
47
+ """
48
+ errors = targets - predictions
49
+
50
+ loss = torch.max(self.quantile * errors, (self.quantile - 1) * errors)
51
+
52
+ return loss.mean()
53
+
54
+
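The pinball (quantile) loss defined above penalizes under- and over-prediction asymmetrically. A minimal standalone sketch of the same computation on toy tensors (illustrative only, not part of the package):

import torch

quantile = 0.9
predictions = torch.zeros(2, 5, 3)   # [batch, seq_len, channels]
targets = torch.ones(2, 5, 3)

errors = targets - predictions        # positive errors = under-prediction
loss = torch.max(quantile * errors, (quantile - 1) * errors)
print(loss.mean())                    # tensor(0.9000): under-prediction is weighted by the quantile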
55
+ class TinyTimeMixerGatedAttention(nn.Module):
56
+ """
57
+ Module that applies gated attention to input data.
58
+
59
+ Args:
60
+ in_size (`int`): The input size.
61
+ out_size (`int`): The output size.
62
+ """
63
+
64
+ def __init__(self, in_size: int, out_size: int):
65
+ super().__init__()
66
+ self.attn_layer = nn.Linear(in_size, out_size)
67
+ self.attn_softmax = nn.Softmax(dim=-1)
68
+
69
+ def forward(self, inputs):
70
+ attn_weight = self.attn_softmax(self.attn_layer(inputs))
71
+ inputs = inputs * attn_weight
72
+ return inputs
73
+
74
+
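The gated attention above is a learned softmax weighting applied elementwise over the last dimension; no mixing across positions happens here. A tiny illustrative sketch with assumed sizes:

import torch
import torch.nn as nn

in_size = 8
gate = nn.Sequential(nn.Linear(in_size, in_size), nn.Softmax(dim=-1))
x = torch.randn(4, 16, in_size)   # [batch, num_patches, features]
gated = x * gate(x)               # same shape; features are re-weighted, not mixed
print(gated.shape)                # torch.Size([4, 16, 8])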
75
+ class TinyTimeMixerCategoricalEmbeddingLayer(nn.Module):
76
+ """ """
77
+
78
+ def __init__(self, config: TinyTimeMixerConfig):
79
+ super().__init__()
80
+ self.categorical_vocab_size_list = config.categorical_vocab_size_list
81
+ self.embedding_layers = nn.ModuleList(
82
+ [nn.Embedding(vocab, config.d_model) for vocab in self.categorical_vocab_size_list]
83
+ )
84
+ self.number_of_categorical_variables = len(self.categorical_vocab_size_list)
85
+ self.num_patches = config.num_patches
86
+
87
+ def forward(self, static_categorical_values: torch.Tensor):
88
+ """
89
+ Parameters:
90
+ static_categorical_values (`torch.FloatTensor` of shape `(batch_size, number_of_categorical_variables)`):
91
+ Tokenized categorical values can be passed here. Ensure to pass in the same order as the vocab size list used in the
92
+ TinyTimeMixerConfig param `categorical_vocab_size_list`
93
+ Returns:
94
+ `torch.Tensor` of shape `(batch_size, number_of_categorical_variables, num_patches, d_model)`
95
+ """
96
+ # static_categorical_values [bs x number_of_categorical_variables]
97
+ embedded_tensors = []
98
+
99
+ for i in range(self.number_of_categorical_variables):
100
+ embedded_tensor = self.embedding_layers[i](static_categorical_values[:, i].long())
101
+ embedded_tensors.append(embedded_tensor)
102
+
103
+ output_tensor = torch.stack(embedded_tensors, dim=1) # bs x number_of_categorical_variables x d_model
104
+
105
+ output_tensor = output_tensor.unsqueeze(2).repeat(
106
+ 1, 1, self.num_patches, 1
107
+ ) # bs x number_of_categorical_variables x num_patches x d_model
108
+
109
+ return output_tensor
110
+
111
+
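The categorical embedding layer maps each static categorical id to a d_model vector and repeats it across patches. A shape-only sketch with made-up vocabulary sizes, d_model and num_patches (illustrative, not the package defaults):

import torch
import torch.nn as nn

vocab_sizes, d_model, num_patches = [7, 12], 16, 10
layers = nn.ModuleList([nn.Embedding(v, d_model) for v in vocab_sizes])

static_cats = torch.tensor([[3, 11], [0, 5]])   # [batch, number_of_categorical_variables]
embedded = torch.stack([layers[i](static_cats[:, i]) for i in range(len(vocab_sizes))], dim=1)
embedded = embedded.unsqueeze(2).repeat(1, 1, num_patches, 1)
print(embedded.shape)   # torch.Size([2, 2, 10, 16]) = [batch, n_categorical, num_patches, d_model]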
112
+ class TinyTimeMixerBatchNorm(nn.Module):
113
+ """
114
+ Compute batch normalization over the sequence length (time) dimension.
115
+ """
116
+
117
+ def __init__(self, config: TinyTimeMixerConfig):
118
+ super().__init__()
119
+ self.batchnorm = nn.BatchNorm1d(config.d_model, eps=config.norm_eps)
120
+
121
+ def forward(self, inputs: torch.Tensor):
122
+ """
123
+ Parameters:
124
+ inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`):
125
+ input for Batch norm calculation
126
+ Returns:
127
+ `torch.Tensor` of shape `(batch_size, sequence_length, d_model)`
128
+ """
129
+ output = inputs.transpose(1, 2) # output: (batch_size, d_model, sequence_length)
130
+ output = self.batchnorm(output)
131
+ return output.transpose(1, 2)
132
+
133
+
134
+ class TinyTimeMixerPositionalEncoding(nn.Module):
135
+ """
136
+ Class for positional encoding
137
+ """
138
+
139
+ def __init__(self, config: TinyTimeMixerConfig):
140
+ super().__init__()
141
+ # positional encoding: [num_patches x d_model]
142
+ if config.use_positional_encoding:
143
+ self.position_enc = self._init_pe(config)
144
+ else:
145
+ self.position_enc = nn.Parameter(torch.zeros(config.num_patches, config.d_model))
146
+
147
+ @staticmethod
148
+ def _init_pe(config: TinyTimeMixerConfig) -> nn.Parameter:
149
+ # Positional encoding
150
+ if config.positional_encoding_type == "random":
151
+ position_enc = nn.Parameter(torch.randn(config.num_patches, config.d_model), requires_grad=True)
152
+ elif config.positional_encoding_type == "sincos":
153
+ position_enc = torch.zeros(config.num_patches, config.d_model)
154
+ position = torch.arange(0, config.num_patches).unsqueeze(1)
155
+ div_term = torch.exp(torch.arange(0, config.d_model, 2) * -(math.log(10000.0) / config.d_model))
156
+ position_enc[:, 0::2] = torch.sin(position * div_term)
157
+ position_enc[:, 1::2] = torch.cos(position * div_term)
158
+ position_enc = position_enc - position_enc.mean()
159
+ position_enc = position_enc / (position_enc.std() * 10)
160
+ position_enc = nn.Parameter(position_enc, requires_grad=False)
161
+ else:
162
+ raise ValueError(
163
+ f"{config.positional_encoding_type} is not a valid positional encoder. Available types are 'random' and 'sincos'."
164
+ )
165
+ return position_enc
166
+
167
+ def forward(self, patch_input: torch.Tensor):
168
+ # hidden_state: [bs x num_channels x num_patches x d_model]
169
+ hidden_state = patch_input + self.position_enc
170
+ return hidden_state
171
+
172
+
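The "sincos" branch of _init_pe builds the standard fixed sinusoidal encoding and then centers and rescales it. A standalone sketch of the same construction (num_patches and d_model are illustrative values):

import math
import torch

num_patches, d_model = 12, 16
pe = torch.zeros(num_patches, d_model)
position = torch.arange(0, num_patches).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = (pe - pe.mean()) / (pe.std() * 10)   # zero-centered and scaled down, as in _init_pe
print(pe.shape)                           # torch.Size([12, 16])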
173
+ class TinyTimeMixerNormLayer(nn.Module):
174
+ """Normalization block
175
+
176
+ Args:
177
+ config (`TinyTimeMixerConfig`, *required*):
178
+ Configuration.
179
+ """
180
+
181
+ def __init__(self, config: TinyTimeMixerConfig):
182
+ super().__init__()
183
+
184
+ self.norm_mlp = config.norm_mlp
185
+
186
+ if "batch" in config.norm_mlp.lower():
187
+ self.norm = TinyTimeMixerBatchNorm(config)
188
+ else:
189
+ self.norm = nn.LayerNorm(config.d_model, eps=config.norm_eps)
190
+
191
+ def forward(self, inputs: torch.Tensor):
192
+ """
193
+ Args:
194
+ inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`):
195
+ Input to the normalization layer.
196
+ Returns:
197
+ `torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`
198
+ """
199
+ if "batch" in self.norm_mlp.lower():
200
+ # reshape the data
201
+ inputs_reshaped = torch.reshape(
202
+ inputs,
203
+ (
204
+ inputs.shape[0] * inputs.shape[1],
205
+ inputs.shape[2],
206
+ inputs.shape[3],
207
+ ),
208
+ ) # inputs_reshaped: [batch_size*num_channels, num_patches, d_model]
209
+
210
+ # inputs_reshaped: [batch_size*num_channels, num_patches, d_model]
211
+ inputs_reshaped = self.norm(inputs_reshaped)
212
+
213
+ # put back data to the original shape
214
+ inputs = torch.reshape(inputs_reshaped, inputs.shape)
215
+
216
+ else:
217
+ inputs = self.norm(inputs)
218
+
219
+ return inputs
220
+
221
+
222
+ class TinyTimeMixerMLP(nn.Module):
223
+ def __init__(self, in_features, out_features, config):
224
+ super().__init__()
225
+ num_hidden = in_features * config.expansion_factor
226
+ self.fc1 = nn.Linear(in_features, num_hidden)
227
+ self.dropout1 = nn.Dropout(config.dropout)
228
+ self.fc2 = nn.Linear(num_hidden, out_features)
229
+ self.dropout2 = nn.Dropout(config.dropout)
230
+
231
+ def forward(self, inputs: torch.Tensor):
232
+ """
233
+ Args:
234
+ inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`):
235
+ Input to the MLP layer.
236
+ Returns:
237
+ `torch.Tensor` of the same shape as `inputs`
238
+ """
239
+ inputs = self.dropout1(nn.functional.gelu(self.fc1(inputs)))
240
+ inputs = self.fc2(inputs)
241
+ inputs = self.dropout2(inputs)
242
+ return inputs
243
+
244
+
245
+ class TinyTimeMixerChannelFeatureMixerBlock(nn.Module):
246
+ """This module mixes the features in the channel dimension.
247
+
248
+ Args:
249
+ config (`TinyTimeMixerConfig`, *required*):
250
+ Configuration.
251
+ """
252
+
253
+ def __init__(self, config: TinyTimeMixerConfig):
254
+ super().__init__()
255
+
256
+ self.norm = TinyTimeMixerNormLayer(config)
257
+ self.gated_attn = config.gated_attn
258
+ self.mlp = TinyTimeMixerMLP(
259
+ in_features=config.num_input_channels,
260
+ out_features=config.num_input_channels,
261
+ config=config,
262
+ )
263
+
264
+ if config.gated_attn:
265
+ self.gating_block = TinyTimeMixerGatedAttention(
266
+ in_size=config.num_input_channels, out_size=config.num_input_channels
267
+ )
268
+
269
+ def forward(self, inputs: torch.Tensor):
270
+ """
271
+ Args:
272
+ inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`):
273
+ input to the MLP layer
274
+ Returns:
275
+ `torch.Tensor` of the same shape as `inputs`
276
+ """
277
+ residual = inputs
278
+ inputs = self.norm(inputs)
279
+
280
+ inputs = inputs.permute(0, 3, 2, 1)
281
+
282
+ if self.gated_attn:
283
+ inputs = self.gating_block(inputs)
284
+
285
+ inputs = self.mlp(inputs)
286
+
287
+ inputs = inputs.permute(0, 3, 2, 1)
288
+
289
+ out = inputs + residual
290
+ return out
291
+
292
+
293
+ class TinyTimeMixerAttention(nn.Module):
294
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
295
+
296
+ def __init__(
297
+ self,
298
+ embed_dim: int,
299
+ num_heads: int,
300
+ dropout: float = 0.0,
301
+ is_decoder: bool = False,
302
+ bias: bool = True,
303
+ is_causal: bool = False,
304
+ config: Optional[TinyTimeMixerConfig] = None,
305
+ ):
306
+ super().__init__()
307
+ self.embed_dim = embed_dim
308
+ self.num_heads = num_heads
309
+ self.dropout = dropout
310
+ self.head_dim = embed_dim // num_heads
311
+ self.config = config
312
+
313
+ if (self.head_dim * num_heads) != self.embed_dim:
314
+ raise ValueError(
315
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
316
+ f" and `num_heads`: {num_heads})."
317
+ )
318
+ self.scaling = self.head_dim**-0.5
319
+ self.is_decoder = is_decoder
320
+ self.is_causal = is_causal
321
+
322
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
323
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
324
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
325
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
326
+
327
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
328
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
329
+
330
+ def forward(
331
+ self,
332
+ hidden_states: torch.Tensor,
333
+ key_value_states: Optional[torch.Tensor] = None,
334
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
335
+ attention_mask: Optional[torch.Tensor] = None,
336
+ layer_head_mask: Optional[torch.Tensor] = None,
337
+ output_attentions: bool = False,
338
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
339
+ """Input shape: Batch x Time x Channel"""
340
+
341
+ # if key_value_states are provided this layer is used as a cross-attention layer
342
+ # for the decoder
343
+ is_cross_attention = key_value_states is not None
344
+
345
+ bsz, tgt_len, _ = hidden_states.size()
346
+
347
+ # get query proj
348
+ query_states = self.q_proj(hidden_states) * self.scaling
349
+ # get key, value proj
350
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
351
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
352
+ # the provided `key_value_states` to support prefix tuning
353
+ if (
354
+ is_cross_attention
355
+ and past_key_value is not None
356
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
357
+ ):
358
+ # reuse k,v, cross_attentions
359
+ key_states = past_key_value[0]
360
+ value_states = past_key_value[1]
361
+ elif is_cross_attention:
362
+ # cross_attentions
363
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
364
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
365
+ elif past_key_value is not None:
366
+ # reuse k, v, self_attention
367
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
368
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
369
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
370
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
371
+ else:
372
+ # self_attention
373
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
374
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
375
+
376
+ if self.is_decoder:
377
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
378
+ # Further calls to cross_attention layer can then reuse all cross-attention
379
+ # key/value_states (first "if" case)
380
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
381
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
382
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
383
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
384
+ past_key_value = (key_states, value_states)
385
+
386
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
387
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
388
+ key_states = key_states.reshape(*proj_shape)
389
+ value_states = value_states.reshape(*proj_shape)
390
+
391
+ src_len = key_states.size(1)
392
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
393
+
394
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
395
+ raise ValueError(
396
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
397
+ f" {attn_weights.size()}"
398
+ )
399
+
400
+ if attention_mask is not None:
401
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
402
+ raise ValueError(
403
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
404
+ )
405
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
406
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
407
+
408
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
409
+
410
+ if layer_head_mask is not None:
411
+ if layer_head_mask.size() != (self.num_heads,):
412
+ raise ValueError(
413
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
414
+ f" {layer_head_mask.size()}"
415
+ )
416
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
417
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
418
+
419
+ if output_attentions:
420
+ # this operation is a bit awkward, but it's required to
421
+ # make sure that attn_weights keeps its gradient.
422
+ # In order to do so, attn_weights have to be reshaped
423
+ # twice and have to be reused in the following
424
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
425
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
426
+ else:
427
+ attn_weights_reshaped = None
428
+
429
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
430
+
431
+ attn_output = torch.bmm(attn_probs, value_states)
432
+
433
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
434
+ raise ValueError(
435
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
436
+ f" {attn_output.size()}"
437
+ )
438
+
439
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
440
+ attn_output = attn_output.transpose(1, 2)
441
+
442
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
443
+ # partitioned across GPUs when using tensor-parallelism.
444
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
445
+
446
+ attn_output = self.out_proj(attn_output)
447
+
448
+ return attn_output, attn_weights_reshaped, past_key_value
449
+
450
+
451
+ class PatchMixerBlock(nn.Module):
452
+ """This module mixes the patch dimension.
453
+
454
+ Args:
455
+ config (`TinyTimeMixerConfig`, *required*):
456
+ Configuration.
457
+ """
458
+
459
+ def __init__(self, config: TinyTimeMixerConfig):
460
+ super().__init__()
461
+
462
+ self.norm = TinyTimeMixerNormLayer(config)
463
+
464
+ self.self_attn = config.self_attn
465
+ self.gated_attn = config.gated_attn
466
+
467
+ self.mlp = TinyTimeMixerMLP(
468
+ in_features=config.num_patches,
469
+ out_features=config.num_patches,
470
+ config=config,
471
+ )
472
+
473
+ if config.gated_attn:
474
+ self.gating_block = TinyTimeMixerGatedAttention(in_size=config.num_patches, out_size=config.num_patches)
475
+
476
+ if config.self_attn:
477
+ self.self_attn_layer = TinyTimeMixerAttention(
478
+ embed_dim=config.d_model,
479
+ num_heads=config.self_attn_heads,
480
+ dropout=config.dropout,
481
+ )
482
+ self.norm_attn = TinyTimeMixerNormLayer(config)
483
+
484
+ def forward(self, hidden_state):
485
+ """
486
+ Args:
487
+ hidden_state (`torch.Tensor`): Input tensor.
488
+
489
+ Returns:
490
+ `torch.Tensor`: Transformed tensor.
491
+ """
492
+ residual = hidden_state
493
+
494
+ hidden_state = self.norm(hidden_state)
495
+
496
+ if self.self_attn:
497
+ batch_size, n_vars, num_patches, d_model = hidden_state.shape
498
+ hidden_state_reshaped = hidden_state.reshape(batch_size * n_vars, num_patches, d_model)
499
+
500
+ x_attn, _, _ = self.self_attn_layer(hidden_state_reshaped, output_attentions=False)
501
+ x_attn = x_attn.reshape(batch_size, n_vars, num_patches, d_model)
502
+
503
+ # Transpose so that num_patches is the last dimension
504
+ hidden_state = hidden_state.transpose(2, 3)
505
+ hidden_state = self.mlp(hidden_state)
506
+
507
+ if self.gated_attn:
508
+ hidden_state = self.gating_block(hidden_state)
509
+
510
+ # Transpose back
511
+ hidden_state = hidden_state.transpose(2, 3)
512
+
513
+ if self.self_attn:
514
+ hidden_state = self.norm_attn(hidden_state + x_attn)
515
+
516
+ out = hidden_state + residual
517
+ return out
518
+
519
+
520
+ class FeatureMixerBlock(nn.Module):
521
+ """This module mixes the hidden feature dimension.
522
+
523
+ Args:
524
+ config (`TinyTimeMixerConfig`, *required*):
525
+ Configuration.
526
+
527
+ """
528
+
529
+ def __init__(self, config: TinyTimeMixerConfig):
530
+ super().__init__()
531
+
532
+ self.norm = TinyTimeMixerNormLayer(config)
533
+
534
+ self.gated_attn = config.gated_attn
535
+
536
+ self.mlp = TinyTimeMixerMLP(
537
+ in_features=config.d_model,
538
+ out_features=config.d_model,
539
+ config=config,
540
+ )
541
+
542
+ if config.gated_attn:
543
+ self.gating_block = TinyTimeMixerGatedAttention(in_size=config.d_model, out_size=config.d_model)
544
+
545
+ def forward(self, hidden: torch.Tensor):
546
+ """
547
+ Args:
548
+ hidden (`torch.Tensor` of shape `(batch_size, num_patches, d_model)`):
549
+ Input tensor to the layer.
550
+
551
+ Returns:
552
+ `torch.Tensor`: Transformed tensor.
553
+ """
554
+ residual = hidden
555
+ hidden = self.norm(hidden)
556
+ hidden = self.mlp(hidden)
557
+
558
+ if self.gated_attn:
559
+ hidden = self.gating_block(hidden)
560
+
561
+ out = hidden + residual
562
+ return out
563
+
564
+
565
+ class ForecastChannelHeadMixer(nn.Module):
566
+ """ForecastChannelMixer Module to reconcile forecasts across channels with exogenous support.
567
+
568
+ When fcm_context_length is positive, this mode creates a patch for every multivariate forecast point together with its surrounding context,
569
+ then flattens it and applies an MLP to it.
570
+ In this way, every forecast point learns from its pre- and post-context in a channel-mixed way.
571
+ A residual connection to the initial forecasts is added to reduce noise.
572
+ """
573
+
574
+ def __init__(self, config: TinyTimeMixerConfig):
575
+ super().__init__()
576
+
577
+ self.fcm_context_length = config.fcm_context_length
578
+ self.scl = 2 * self.fcm_context_length + 1
579
+
580
+ if config.prediction_channel_indices is not None:
581
+ self.prediction_channel_count = len(config.prediction_channel_indices)
582
+ else:
583
+ self.prediction_channel_count = config.num_input_channels
584
+
585
+ if config.exogenous_channel_indices is not None:
586
+ self.exogenous_channel_count = len(config.exogenous_channel_indices)
587
+ else:
588
+ self.exogenous_channel_count = 0
589
+
590
+ self.total_channel_count = self.prediction_channel_count + self.exogenous_channel_count
591
+
592
+ self.fcm_use_mixer = config.fcm_use_mixer
593
+
594
+ self.exogenous_channel_indices = config.exogenous_channel_indices
595
+ self.prediction_channel_indices = config.prediction_channel_indices
596
+ scl_features = self.scl
597
+
598
+ if self.fcm_use_mixer:
599
+ # model mixer considering channel dim as patch dim for lag computation
600
+ temp_config = copy.deepcopy(config)
601
+ temp_config.num_patches = self.total_channel_count
602
+ temp_config.patch_length = self.scl
603
+ temp_config.num_input_channels = config.prediction_length
604
+ temp_config.d_model = self.scl * 2
605
+ temp_config.patch_stride = 1
606
+ temp_config.num_layers = config.fcm_mix_layers
607
+ temp_config.dropout = config.head_dropout
608
+ temp_config.mode = "common_channel"
609
+ temp_config.gated_attn = config.fcm_gated_attn
610
+ temp_config.adaptive_patching_levels = 0
611
+ self.exog_mixer = TinyTimeMixerBlock(temp_config)
612
+ scl_features = self.scl * 2
613
+ self.fcm_embedding = nn.Linear(temp_config.patch_length, temp_config.d_model)
614
+
615
+ self.mlp = nn.Linear(
616
+ self.total_channel_count * (scl_features),
617
+ self.prediction_channel_count,
618
+ )
619
+ if config.fcm_gated_attn:
620
+ self.fcm_gating_block = TinyTimeMixerGatedAttention(
621
+ in_size=self.total_channel_count * (scl_features),
622
+ out_size=self.total_channel_count * (scl_features),
623
+ )
624
+ if self.fcm_context_length > 0:
625
+ patch_config = copy.deepcopy(config)
626
+ patch_config.context_length = config.prediction_length + (2 * config.fcm_context_length)
627
+ patch_config.masked_context_length = None
628
+ patch_config.patch_length = self.scl
629
+ patch_config.patch_stride = 1
630
+ self.fcm_patch_block = TinyTimeMixerPatchify(patch_config)
631
+
632
+ self.fcm_gated_attn = config.fcm_gated_attn
633
+ self.prediction_length = config.prediction_length
634
+ self.fcm_prepend_past = config.fcm_prepend_past
635
+
636
+ self.fcm_prepend_past_offset = (
637
+ config.fcm_prepend_past_offset
638
+ ) # Number of items to skip in the context window from the end
639
+
640
+ if self.fcm_prepend_past_offset is None:
641
+ self.fcm_prepend_slicing_indices = slice(-self.fcm_context_length, None)
642
+ else:
643
+ self.fcm_prepend_slicing_indices = slice(
644
+ -(self.fcm_prepend_past_offset + self.fcm_context_length),
645
+ -self.fcm_prepend_past_offset,
646
+ )
647
+
648
+ def forward(
649
+ self,
650
+ base_forecasts: torch.Tensor,
651
+ past_values: Optional[torch.Tensor],
652
+ future_values: Optional[torch.Tensor] = None,
653
+ ):
654
+ """
655
+ Args:
656
+ base_forecasts (`torch.Tensor` of shape `(batch_size, prediction length, forecast_channels)`):
657
+ Base Forecasts to reconcile
658
+
659
+ past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
660
+ Context values of the time series. For a forecasting task, this denotes the history/past time series values.
661
+ For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
662
+ greater than 1.
663
+
664
+ future_values (`torch.Tensor` of shape `(batch_size, prediction length, input_channels)`, *optional*, Defaults to None):
665
+ Actual groundtruths of the forecasts. Pass dummy values (say 0) for forecast channels, if groundtruth is unknown.
666
+ Pass the correct values for Exogenous channels where the forecast values are known.
667
+
668
+ Returns:
669
+ `torch.Tensor`: Updated forecasts of shape `(batch_size, prediction length, forecast_channels)`
670
+ """
671
+ # base_forecasts.shape == (batch_size x forecast_len x n_vars)
672
+
673
+ if self.prediction_channel_indices is not None:
674
+ past_prepend_values = past_values[
675
+ :, self.fcm_prepend_slicing_indices, self.prediction_channel_indices
676
+ ] # bs x context_len x forecast_channels
677
+ else:
678
+ past_prepend_values = past_values[
679
+ :, self.fcm_prepend_slicing_indices, :
680
+ ] # bs x fcm_context_len x forecast_channels
681
+
682
+ if self.exogenous_channel_count > 0 and future_values is None:
683
+ raise ValueError("future_values cannot be none when we have exogenous channels.")
684
+
685
+ if self.exogenous_channel_count > 0:
686
+ exog_values = future_values[..., self.exogenous_channel_indices] # bs x prediction len x exog_channels
687
+ past_exog_values = past_values[
688
+ :, self.fcm_prepend_slicing_indices, self.exogenous_channel_indices
689
+ ] # bs x context_len x exog_channels
690
+
691
+ past_prepend_values = torch.cat(
692
+ (past_prepend_values, past_exog_values), dim=-1
693
+ ) # bs x fcm_context_len x (forecast_channels+exog_channels)
694
+
695
+ else:
696
+ exog_values = None
697
+
698
+ residual = base_forecasts
699
+
700
+ if exog_values is not None:
701
+ base_forecasts = torch.cat(
702
+ (base_forecasts, exog_values), dim=-1
703
+ ) # x.shape == (batch_size x forecast_len x (forecast_channels+exog_channels))
704
+
705
+ if self.fcm_context_length > 0:
706
+ # this mode creates a patch for every multivariate forecast point with its surrounding context,
707
+ # then flattens it and applies an MLP to it.
708
+ # In this way, every forecast point learns from its pre- and post-context in a channel-mixed way.
709
+ # A residual connection to the initial forecasts is added to reduce noise.
710
+
711
+ # prefill and postfill zeros to enable patching for every forecast point with surrounding context
712
+
713
+ dummy = torch.zeros(
714
+ base_forecasts.shape[0],
715
+ self.fcm_context_length,
716
+ base_forecasts.shape[2],
717
+ device=base_forecasts.device,
718
+ ) # bs x fcm_context_length x n_vars
719
+
720
+ if self.fcm_prepend_past:
721
+ # add prefill and postfill
722
+ extend_forecasts = torch.concat(
723
+ (past_prepend_values, base_forecasts, dummy), dim=1
724
+ ) # bs x forecast_len + 2*fcm_context_length x n_vars
725
+ else:
726
+ # add prefill and postfill
727
+ extend_forecasts = torch.concat(
728
+ (dummy, base_forecasts, dummy), dim=1
729
+ ) # bs x forecast_len + 2*fcm_context_length x n_vars
730
+
731
+ # create patch
732
+ extend_forecasts = self.fcm_patch_block(extend_forecasts) # xb: [bs x n_vars x forecast_len x scl]
733
+
734
+ extend_forecasts = extend_forecasts.transpose(1, 2) # [bs x forecast_len x n_vars x scl]
735
+
736
+ if extend_forecasts.shape[1] != self.prediction_length:
737
+ raise ValueError("out_patches should match to forecast length")
738
+
739
+ if self.fcm_use_mixer:
740
+ extend_forecasts = self.fcm_embedding(extend_forecasts)
741
+ extend_forecasts, _ = self.exog_mixer(extend_forecasts)
742
+
743
+ extend_forecasts = extend_forecasts.flatten(start_dim=2) # xb: [bs x forecast_len x n_vars * scl]
744
+
745
+ if self.fcm_gated_attn:
746
+ extend_forecasts = self.fcm_gating_block(extend_forecasts) # xb: [bs x forecast_len x n_vars * scl]
747
+
748
+ extend_forecasts = self.mlp(extend_forecasts) # xb: [bs x forecast_len x n_vars]
749
+
750
+ else:
751
+ if self.fcm_gated_attn:
752
+ extend_forecasts = self.fcm_gating_block(base_forecasts)
753
+
754
+ extend_forecasts = self.mlp(extend_forecasts)
755
+
756
+ new_forecast = extend_forecasts + residual
757
+
758
+ return new_forecast
759
+
760
+
761
+ class TinyTimeMixerLayer(nn.Module):
762
+ """
763
+ The `TinyTimeMixer` layer that does all three kinds of mixing.
764
+
765
+ Args:
766
+ config (`TinyTimeMixerConfig`, *required*):
767
+ Configuration.
768
+
769
+ """
770
+
771
+ def __init__(self, config: TinyTimeMixerConfig):
772
+ super().__init__()
773
+
774
+ if config.num_patches > 1:
775
+ self.patch_mixer = PatchMixerBlock(config=config)
776
+
777
+ self.feature_mixer = FeatureMixerBlock(config=config)
778
+
779
+ self.mode = config.mode
780
+ self.num_patches = config.num_patches
781
+ if config.mode == "mix_channel":
782
+ self.channel_feature_mixer = TinyTimeMixerChannelFeatureMixerBlock(config=config)
783
+
784
+ def forward(self, hidden: torch.Tensor):
785
+ """
786
+ Args:
787
+ hidden (`torch.Tensor` of shape `(batch_size, num_patches, d_model)`):
788
+ Input tensor to the layer.
789
+
790
+ Returns:
791
+ `torch.Tensor`: Transformed tensor.
792
+ """
793
+ if self.mode == "mix_channel":
794
+ hidden = self.channel_feature_mixer(hidden)
795
+
796
+ if self.num_patches > 1:
797
+ hidden = self.patch_mixer(hidden)
798
+ hidden = self.feature_mixer(hidden) # hidden: (batch_size x num_patches x d_model)
799
+ return hidden
800
+
801
+
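Conceptually, a mixer layer applies MLPs along different axes of the [batch, channels, patches, d_model] tensor: channel mixing permutes channels into the last axis, patch mixing transposes patches into the last axis, and feature mixing acts on d_model directly. A toy sketch of just the axis handling (sizes illustrative; normalization and gating omitted):

import torch
import torch.nn as nn

bs, n_vars, n_patch, d_model = 2, 3, 8, 16
x = torch.randn(bs, n_vars, n_patch, d_model)

channel_mlp = nn.Linear(n_vars, n_vars)
patch_mlp = nn.Linear(n_patch, n_patch)
feature_mlp = nn.Linear(d_model, d_model)

x = x + channel_mlp(x.permute(0, 3, 2, 1)).permute(0, 3, 2, 1)   # mix across channels
x = x + patch_mlp(x.transpose(2, 3)).transpose(2, 3)             # mix across patches
x = x + feature_mlp(x)                                           # mix across features
print(x.shape)   # torch.Size([2, 3, 8, 16])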
802
+ class TinyTimeMixerAdaptivePatchingBlock(nn.Module):
803
+ """
804
+ Adaptive patching block: runs mixer layers on a finer view of the input (more patches, proportionally smaller d_model) and reshapes back.
805
+
806
+ Args:
807
+ config (`TinyTimeMixerConfig`, *required*):
808
+ Configuration.
809
+
810
+ """
811
+
812
+ def __init__(self, config: TinyTimeMixerConfig, adapt_patch_level: int):
813
+ super().__init__()
814
+ temp_config = copy.deepcopy(config)
815
+ self.adapt_patch_level = adapt_patch_level
816
+ adaptive_patch_factor = 2**adapt_patch_level
817
+ self.adaptive_patch_factor = adaptive_patch_factor
818
+
819
+ if config.d_model // self.adaptive_patch_factor <= 4:
820
+ # do not allow the reduction to shrink d_model to 4 or below
821
+ # logger.warning(
822
+ # "Disabling adaptive patching at level %s. Either increase d_model or reduce adaptive_patching_levels"
823
+ # % (adapt_patch_level)
824
+ # )
825
+ self.adaptive_patch_factor = 1
826
+
827
+ if config.d_model % self.adaptive_patch_factor != 0:
828
+ raise ValueError("d_model should be divisible by 2^i, where i varies from 0 to adaptive_patching_levels.")
829
+ temp_config.num_patches = temp_config.num_patches * self.adaptive_patch_factor
830
+ temp_config.d_model = temp_config.d_model // self.adaptive_patch_factor
831
+
832
+ self.mixer_layers = nn.ModuleList([TinyTimeMixerLayer(temp_config) for i in range(temp_config.num_layers)])
833
+
834
+ def forward(self, hidden: torch.Tensor):
835
+ """
836
+ Args:
837
+ hidden (`torch.Tensor` of shape `(batch_size x nvars x num_patch x d_model)`):
838
+ Input tensor to the layer.
839
+
840
+ Returns:
841
+ `torch.Tensor`: Transformed tensor.
842
+ """
843
+ all_hidden_states = []
844
+ all_hidden_states.append(hidden)
845
+ hidden = torch.reshape(
846
+ hidden,
847
+ (
848
+ hidden.shape[0],
849
+ hidden.shape[1],
850
+ hidden.shape[2] * self.adaptive_patch_factor,
851
+ hidden.shape[3] // self.adaptive_patch_factor,
852
+ ),
853
+ )
854
+ all_hidden_states.append(hidden)
855
+
856
+ for mod in self.mixer_layers:
857
+ hidden = mod(hidden)
858
+ all_hidden_states.append(hidden)
859
+
860
+ hidden = torch.reshape(
861
+ hidden,
862
+ (
863
+ hidden.shape[0],
864
+ hidden.shape[1],
865
+ hidden.shape[2] // self.adaptive_patch_factor,
866
+ hidden.shape[3] * self.adaptive_patch_factor,
867
+ ),
868
+ )
869
+ all_hidden_states.append(hidden)
870
+
871
+ return hidden, all_hidden_states
872
+
873
+
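Adaptive patching trades feature width for patch count: at level i the hidden tensor is viewed with 2**i times more patches and 2**i times fewer features, mixed, and reshaped back. A toy sketch of just the reshapes (sizes illustrative):

import torch

bs, n_vars, n_patch, d_model = 2, 3, 8, 16
factor = 4   # 2 ** adapt_patch_level
x = torch.randn(bs, n_vars, n_patch, d_model)

x_fine = x.reshape(bs, n_vars, n_patch * factor, d_model // factor)   # more, thinner patches
# ... mixer layers would run here on the finer view ...
x_back = x_fine.reshape(bs, n_vars, n_patch, d_model)
print(x_fine.shape, x_back.shape)   # torch.Size([2, 3, 32, 4]) torch.Size([2, 3, 8, 16])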
874
+ class TinyTimeMixerBlock(nn.Module):
875
+ """The main computing framework of the `TinyTimeMixer` model.
876
+
877
+ Args:
878
+ config (`TinyTimeMixerConfig`, *required*):
879
+ Configuration.
880
+ """
881
+
882
+ def __init__(self, config: TinyTimeMixerConfig):
883
+ super().__init__()
884
+
885
+ num_layers = config.num_layers
886
+
887
+ self.adaptive_patching_levels = config.adaptive_patching_levels
888
+
889
+ if self.adaptive_patching_levels > 0:
890
+ self.mixers = nn.ModuleList(
891
+ [
892
+ TinyTimeMixerAdaptivePatchingBlock(config=config, adapt_patch_level=i)
893
+ for i in reversed(range(config.adaptive_patching_levels))
894
+ ]
895
+ )
896
+
897
+ else:
898
+ self.mixers = nn.ModuleList([TinyTimeMixerLayer(config=config) for _ in range(num_layers)])
899
+
900
+ def forward(self, hidden_state, output_hidden_states: bool = False):
901
+ """
902
+ Args:
903
+ hidden_state (`torch.Tensor`): The input tensor.
904
+ output_hidden_states (`bool`, *optional*, defaults to False.):
905
+ Whether to output the hidden states as well.
906
+
907
+ Returns:
908
+ `torch.Tensor`: The embedding. `list`: List of all hidden states if `output_hidden_states` is set to
909
+ `True`.
910
+ """
911
+ all_hidden_states = []
912
+
913
+ embedding = hidden_state
914
+
915
+ for mod in self.mixers:
916
+ if self.adaptive_patching_levels > 0:
917
+ embedding, hidden_states = mod(embedding)
918
+ all_hidden_states.extend(hidden_states)
919
+ else:
920
+ embedding = mod(embedding)
921
+ if output_hidden_states:
922
+ all_hidden_states.append(embedding)
923
+
924
+ if output_hidden_states:
925
+ return embedding, all_hidden_states
926
+ else:
927
+ return embedding, None
928
+
929
+
930
+ class TinyTimeMixerDecoder(nn.Module):
931
+ """Decoder for tiny time mixer
932
+
933
+ Args:
934
+ config (`TinyTimeMixerConfig`, *required*):
935
+ Configuration.
936
+ """
937
+
938
+ def __init__(self, config: TinyTimeMixerConfig):
939
+ super().__init__()
940
+
941
+ if config.d_model != config.decoder_d_model:
942
+ self.adapter = nn.Linear(config.d_model, config.decoder_d_model)
943
+ else:
944
+ self.adapter = None
945
+
946
+ self.decoder_raw_residual = config.decoder_raw_residual
947
+ self.num_input_channels = config.num_input_channels
948
+
949
+ if config.decoder_raw_residual:
950
+ self.decoder_raw_embedding = nn.Linear(config.patch_length, config.decoder_d_model)
951
+ # nn.init.zeros_(self.decoder_raw_embedding.weight)
952
+ # nn.init.zeros_(self.decoder_raw_embedding.bias)
953
+
954
+ decoder_config = copy.deepcopy(config)
955
+ decoder_config.num_layers = config.decoder_num_layers
956
+ decoder_config.d_model = config.decoder_d_model
957
+ decoder_config.dropout = config.head_dropout
958
+ decoder_config.adaptive_patching_levels = config.decoder_adaptive_patching_levels
959
+ decoder_config.mode = config.decoder_mode
960
+
961
+ if config.categorical_vocab_size_list is not None:
962
+ if config.decoder_mode == "common_channel":
963
+ # logger.warning("Setting decoder_mode to mix_channel as static categorical variables is available")
964
+ # config.decoder_mode = "mix_channel"
965
+ raise ValueError("set decoder_mode to mix_channel when using static categorical variables")
966
+
967
+ decoder_config.num_input_channels += len(config.categorical_vocab_size_list)
968
+ self.decoder_cat_embedding_layer = TinyTimeMixerCategoricalEmbeddingLayer(decoder_config)
969
+ else:
970
+ self.decoder_cat_embedding_layer = None
971
+
972
+ self.decoder_block = TinyTimeMixerBlock(decoder_config)
973
+
974
+ self.resolution_prefix_tuning = config.resolution_prefix_tuning
975
+
976
+ def forward(
977
+ self,
978
+ hidden_state,
979
+ patch_input,
980
+ output_hidden_states: bool = False,
981
+ static_categorical_values: Optional[torch.Tensor] = None,
982
+ ):
983
+ """
984
+ Args:
985
+ hidden_state (`torch.Tensor` of shape `(batch_size x nvars x num_patch x d_model)`): The input tensor from backbone.
986
+ output_hidden_states (`bool`, *optional*, defaults to False.):
987
+ Whether to output the hidden states as well.
988
+
989
+ static_categorical_values (`torch.FloatTensor` of shape `(batch_size, number_of_categorical_variables)`, *optional*):
990
+ Tokenized categorical values can be passed here. Ensure to pass in the same order as the vocab size list used in the
991
+ TinyTimeMixerConfig param `categorical_vocab_size_list`
992
+
993
+ Returns:
994
+ `torch.Tensor`: The embedding. `list`: List of all hidden states if `output_hidden_states` is set to
995
+ `True`.
996
+ """
997
+ if output_hidden_states:
998
+ decoder_hidden_states = []
999
+ else:
1000
+ decoder_hidden_states = None
1001
+
1002
+ decoder_input = hidden_state
1003
+
1004
+ if self.adapter is not None:
1005
+ decoder_input = self.adapter(
1006
+ hidden_state
1007
+ ) # model_output: [batch_size x nvars x num_patch x decoder_d_model]
1008
+ if output_hidden_states:
1009
+ decoder_hidden_states.append(decoder_input)
1010
+
1011
+ if self.decoder_raw_residual:
1012
+ if self.resolution_prefix_tuning:
1013
+ if patch_input.shape[-2] == decoder_input.shape[-2] - 1:
1014
+ temp_shape = list(patch_input.shape)
1015
+ temp_shape[-2] = 1
1016
+ temp_zeros = torch.zeros(*temp_shape).to(patch_input.device)
1017
+ patch_input = torch.cat([temp_zeros, patch_input], dim=-2)
1018
+
1019
+ decoder_input = decoder_input + self.decoder_raw_embedding(
1020
+ patch_input
1021
+ ) # [batch_size x nvars x num_patch x decoder_d_model]
1022
+ if output_hidden_states:
1023
+ decoder_hidden_states.append(decoder_input)
1024
+
1025
+ if self.decoder_cat_embedding_layer is not None:
1026
+ if static_categorical_values is None:
1027
+ raise ValueError("Missing static_categorical_values tensor in forward call")
1028
+ cat_embeddings = self.decoder_cat_embedding_layer(
1029
+ static_categorical_values
1030
+ ) # bs x n_cat x n_patches x d_model
1031
+
1032
+ decoder_input = torch.concat(
1033
+ (decoder_input, cat_embeddings), dim=1
1034
+ ) # bs x nvars+n_cat x n_patches x d_model
1035
+
1036
+ decoder_output, hidden_states = self.decoder_block(
1037
+ hidden_state=decoder_input, output_hidden_states=output_hidden_states
1038
+ ) # bs x nvars+n_cat x n_patches x d_model
1039
+
1040
+ if output_hidden_states:
1041
+ decoder_hidden_states.extend(hidden_states)
1042
+
1043
+ if self.decoder_cat_embedding_layer is not None:
1044
+ decoder_output = decoder_output[:, : self.num_input_channels, :, :] # bs x nvars x n_patches x d_model
1045
+ if output_hidden_states:
1046
+ decoder_hidden_states.append(decoder_output)
1047
+
1048
+ return decoder_output, decoder_hidden_states
1049
+
1050
+
1051
+ class TinyTimeMixerForPredictionHead(nn.Module):
1052
+ """Prediction Head for Forecasting
1053
+
1054
+ Args:
1055
+ config (`TinyTimeMixerConfig`, *required*): Configuration.
1056
+ """
1057
+
1058
+ def __init__(self, config: TinyTimeMixerConfig, distribution_output=None):
1059
+ super().__init__()
1060
+
1061
+ self.prediction_channel_indices = config.prediction_channel_indices
1062
+
1063
+ if self.prediction_channel_indices is not None:
1064
+ self.prediction_channel_indices.sort()
1065
+
1066
+ self.prediction_filter_length = config.prediction_filter_length
1067
+
1068
+ self.dropout_layer = nn.Dropout(config.head_dropout)
1069
+ self.enable_forecast_channel_mixing = config.enable_forecast_channel_mixing
1070
+ if config.use_decoder:
1071
+ head_d_model = config.decoder_d_model
1072
+ else:
1073
+ head_d_model = config.d_model
1074
+
1075
+ if distribution_output is None:
1076
+ self.base_forecast_block = nn.Linear((config.num_patches * head_d_model), config.prediction_length)
1077
+ else:
1078
+ self.base_forecast_block = distribution_output.get_parameter_projection(config.num_patches * head_d_model)
1079
+
1080
+ self.flatten = nn.Flatten(start_dim=-2)
1081
+
1082
+ if self.enable_forecast_channel_mixing:
1083
+ temp_config = copy.deepcopy(config)
1084
+ if self.prediction_filter_length is not None:
1085
+ temp_config.prediction_length = self.prediction_filter_length
1086
+
1087
+ self.fcm_block = ForecastChannelHeadMixer(config=temp_config)
1088
+
1089
+ def forward(self, hidden_features, past_values, future_values=None):
1090
+ """
1091
+
1092
+ Args:
1093
+ hidden_features (`torch.Tensor` of shape `(batch_size, n_vars, num_patch, d_model)` in `common_channel`/`mix_channel` mode): Input hidden
1094
+ features.
1095
+
1096
+ past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
1097
+ Context values of the time series. For a forecasting task, this denotes the history/past time series values.
1098
+ For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
1099
+ greater than 1.
1100
+
1101
+ future_values (`torch.Tensor` of shape `(batch_size, prediction length, input_channels)`, *optional*, Defaults to None):
1102
+ Actual groundtruths of the forecasts. Pass dummy values (say 0) for forecast channels, if groundtruth is unknown.
1103
+ Pass the correct values for Exogenous channels where the forecast values are known.
1104
+
1105
+
1106
+ Returns:
1107
+ `torch.Tensor` of shape `(batch_size, prediction_length, forecast_channels)`.
1108
+
1109
+ """
1110
+
1111
+ hidden_features = self.flatten(hidden_features) # [batch_size x n_vars x num_patch * d_model]
1112
+ hidden_features = self.dropout_layer(hidden_features) # [batch_size x n_vars x num_patch * d_model]
1113
+ forecast = self.base_forecast_block(hidden_features) # [batch_size x n_vars x prediction_length]
1114
+ if isinstance(forecast, tuple):
1115
+ forecast = tuple(z.transpose(-1, -2) for z in forecast)
1116
+ else:
1117
+ forecast = forecast.transpose(-1, -2) # [batch_size x prediction_length x n_vars]
1118
+
1119
+ if self.prediction_channel_indices is not None:
1120
+ if isinstance(forecast, tuple):
1121
+ forecast = tuple(z[..., self.prediction_channel_indices] for z in forecast)
1122
+ else:
1123
+ forecast = forecast[
1124
+ ..., self.prediction_channel_indices
1125
+ ] # [batch_size x prediction_length x prediction_n_vars]
1126
+
1127
+ if self.prediction_filter_length is not None:
1128
+ if isinstance(forecast, tuple):
1129
+ forecast = tuple(z[:, : self.prediction_filter_length, :] for z in forecast)
1130
+ else:
1131
+ forecast = forecast[
1132
+ :, : self.prediction_filter_length, :
1133
+ ] # [batch_size x prediction_filter_length x prediction_n_vars]
1134
+
1135
+ if (
1136
+ self.prediction_filter_length is not None
1137
+ and future_values is not None
1138
+ and future_values.shape[1] != self.prediction_filter_length
1139
+ ):
1140
+ future_values = future_values[
1141
+ :, : self.prediction_filter_length, :
1142
+ ] # [batch_size x prediction_filter_length x n_vars]
1143
+
1144
+ if self.enable_forecast_channel_mixing:
1145
+ if isinstance(forecast, tuple):
1146
+ raise ValueError("Forecast channel mixing is not enabled for distribution head")
1147
+ else:
1148
+ forecast = self.fcm_block(forecast, past_values=past_values, future_values=future_values)
1149
+ # [batch_size x prediction_length x prediction_n_vars]
1150
+
1151
+ return forecast
1152
+
1153
+
1154
+ class TinyTimeMixerPreTrainedModel(PreTrainedModel):
1155
+ # Weight initialization
1156
+ config_class = TinyTimeMixerConfig
1157
+ base_model_prefix = "model"
1158
+ main_input_name = "past_values"
1159
+ supports_gradient_checkpointing = False
1160
+
1161
+ def _init_weights(self, module):
1162
+ """Initialize weights"""
1163
+
1164
+ if isinstance(module, TinyTimeMixerPositionalEncoding):
1165
+ # initialize positional encoding
1166
+ if self.config.positional_encoding_type == "random":
1167
+ nn.init.normal_(module.position_enc, mean=0.0, std=0.1)
1168
+ elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
1169
+ module.bias.data.zero_()
1170
+ module.weight.data.fill_(1.0)
1171
+ elif isinstance(module, TinyTimeMixerBatchNorm):
1172
+ module.batchnorm.bias.data.zero_()
1173
+ module.batchnorm.weight.data.fill_(1.0)
1174
+ elif isinstance(module, nn.Linear):
1175
+ # print(f"Initializing Linear layers with method: {self.config.init_linear}")
1176
+ if self.config.init_linear == "normal":
1177
+ module.weight.data.normal_(mean=0.0, std=self.config.init_std)
1178
+ if module.bias is not None:
1179
+ module.bias.data.zero_()
1180
+ elif self.config.init_linear == "uniform":
1181
+ nn.init.uniform_(module.weight)
1182
+ if module.bias is not None:
1183
+ module.bias.data.zero_()
1184
+ elif self.config.init_linear == "xavier_uniform":
1185
+ nn.init.xavier_uniform_(module.weight)
1186
+ if module.bias is not None:
1187
+ module.bias.data.zero_()
1188
+ else:
1189
+ module.reset_parameters()
1190
+ elif isinstance(module, nn.Embedding):
1191
+ # print(f"Initializing Embedding layers with method: {self.config.init_embed}")
1192
+ if self.config.init_embed == "normal":
1193
+ nn.init.normal_(module.weight)
1194
+ elif self.config.init_embed == "uniform":
1195
+ nn.init.uniform_(module.weight)
1196
+ elif self.config.init_embed == "xavier_uniform":
1197
+ nn.init.xavier_uniform_(module.weight)
1198
+ else:
1199
+ module.reset_parameters()
1200
+
1201
+
1202
+ class TinyTimeMixerPatchify(nn.Module):
1203
+ """
1204
+ A class to patchify the time series sequence into different patches
1205
+
1206
+ Returns:
1207
+ `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
1208
+ """
1209
+
1210
+ def __init__(self, config: TinyTimeMixerConfig):
1211
+ super().__init__()
1212
+
1213
+ self.sequence_length = (
1214
+ config.masked_context_length if config.masked_context_length is not None else config.context_length
1215
+ )
1216
+
1217
+ self.patch_length = config.patch_length
1218
+ self.patch_stride = config.patch_stride
1219
+
1220
+ if self.sequence_length <= self.patch_length:
1221
+ raise ValueError(
1222
+ f"Sequence length ({self.sequence_length}) has to be greater than the patch length ({self.patch_length})"
1223
+ )
1224
+
1225
+ # get the number of patches
1226
+ self.num_patches = (max(self.sequence_length, self.patch_length) - self.patch_length) // self.patch_stride + 1
1227
+ new_sequence_length = self.patch_length + self.patch_stride * (self.num_patches - 1)
1228
+ self.sequence_start = self.sequence_length - new_sequence_length
1229
+
1230
+ def forward(self, past_values: torch.Tensor):
1231
+ """
1232
+ Parameters:
1233
+ past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*):
1234
+ Input for patchification
1235
+
1236
+ Returns:
1237
+ `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
1238
+ """
1239
+ sequence_length = past_values.shape[-2]
1240
+ if sequence_length != self.sequence_length:
1241
+ raise ValueError(
1242
+ f"Input sequence length ({sequence_length}) doesn't match model configuration ({self.sequence_length})."
1243
+ )
1244
+ # output: [bs x new_sequence_length x num_channels]
1245
+ output = past_values[:, self.sequence_start :, :]
1246
+ # output: [bs x num_patches x num_input_channels x patch_length]
1247
+ output = output.unfold(dimension=-2, size=self.patch_length, step=self.patch_stride)
1248
+ # output: [bs x num_input_channels x num_patches x patch_length]
1249
+ output = output.transpose(-2, -3).contiguous()
1250
+ return output
1251
+
1252
+
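Patchification is a strided unfold over the time axis followed by a transpose that puts channels before patches. A standalone sketch of the same slicing (context length, patch length and stride are illustrative values):

import torch

context_length, patch_length, patch_stride = 96, 16, 8
num_patches = (context_length - patch_length) // patch_stride + 1   # 11

past_values = torch.randn(2, context_length, 3)   # [batch, seq_len, channels]
patches = past_values.unfold(dimension=-2, size=patch_length, step=patch_stride)
patches = patches.transpose(-2, -3).contiguous()   # [batch, channels, num_patches, patch_length]
print(num_patches, patches.shape)   # 11 torch.Size([2, 3, 11, 16])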
1253
+ class TinyTimeMixerStdScaler(nn.Module):
1254
+ """
1255
+ Standardizes features by computing the mean and standard deviation along the first dimension, then normalizes them by
1256
+ subtracting the mean and dividing by the standard deviation.
1257
+ """
1258
+
1259
+ def __init__(self, config: TinyTimeMixerConfig):
1260
+ super().__init__()
1261
+ self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
1262
+ self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
1263
+ self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-5
1264
+
1265
+ def forward(
1266
+ self, data: torch.Tensor, observed_indicator: torch.Tensor
1267
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1268
+ """
1269
+ Parameters:
1270
+ data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
1271
+ input for Batch norm calculation
1272
+ observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
1273
+ Calculating the scale on the observed indicator.
1274
+ Returns:
1275
+ tuple of `torch.Tensor` of shapes
1276
+ (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
1277
+ `(batch_size, 1, num_input_channels)`)
1278
+ """
1279
+
1280
+ denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim)
1281
+ denominator = denominator.clamp_min(torch.tensor(1, device=denominator.device))
1282
+ loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator
1283
+
1284
+ variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator
1285
+ scale = torch.sqrt(variance + self.minimum_scale)
1286
+ return (data - loc) / scale, loc, scale
1287
+
1288
+
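The std scaler computes a masked mean and standard deviation over the time dimension and standardizes the series with them; loc and scale are returned so forecasts can be de-scaled later. A compact sketch of the same arithmetic (1e-5 is the minimum_scale default used above):

import torch

data = torch.randn(2, 96, 3)       # [batch, seq_len, channels]
observed = torch.ones_like(data)   # 1.0 where a value was observed

denom = observed.sum(dim=1, keepdim=True).clamp_min(1.0)
loc = (data * observed).sum(dim=1, keepdim=True) / denom
var = (((data - loc) * observed) ** 2).sum(dim=1, keepdim=True) / denom
scale = torch.sqrt(var + 1e-5)

scaled = (data - loc) / scale
print(scaled.shape, loc.shape, scale.shape)   # [2, 96, 3], [2, 1, 3], [2, 1, 3]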
1289
+ class TinyTimeMixerMeanScaler(nn.Module):
1290
+ """
1291
+ Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data
1292
+ accordingly.
1293
+ """
1294
+
1295
+ def __init__(self, config: TinyTimeMixerConfig):
1296
+ super().__init__()
1297
+ self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
1298
+ self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
1299
+ self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10
1300
+ self.default_scale = config.default_scale if hasattr(config, "default_scale") else None
1301
+
1302
+ def forward(
1303
+ self, data: torch.Tensor, observed_indicator: torch.Tensor
1304
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1305
+ """
1306
+ Parameters:
1307
+ data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
1308
+ input for Batch norm calculation
1309
+ observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
1310
+ Calculating the scale on the observed indicator.
1311
+ Returns:
1312
+ tuple of `torch.Tensor` of shapes
1313
+ (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
1314
+ `(batch_size, 1, num_input_channels)`)
1315
+ """
1316
+ ts_sum = (data * observed_indicator).abs().sum(self.dim, keepdim=True)
1317
+ num_observed = observed_indicator.sum(self.dim, keepdim=True)
1318
+
1319
+ scale = ts_sum / torch.clamp(num_observed, min=1)
1320
+
1321
+ # If `default_scale` is provided, we use it, otherwise we use the scale
1322
+ # of the batch.
1323
+ if self.default_scale is None:
1324
+ batch_sum = ts_sum.sum(dim=0)
1325
+ batch_observations = torch.clamp(num_observed.sum(0), min=1)
1326
+ default_scale = torch.squeeze(batch_sum / batch_observations)
1327
+ else:
1328
+ default_scale = self.default_scale * torch.ones_like(scale)
1329
+
1330
+ # apply default scale where there are no observations
1331
+ scale = torch.where(num_observed > 0, scale, default_scale)
1332
+
1333
+ # ensure the scale is at least `self.minimum_scale`
1334
+ scale = torch.clamp(scale, min=self.minimum_scale)
1335
+ scaled_data = data / scale
1336
+
1337
+ if not self.keepdim:
1338
+ scale = scale.squeeze(dim=self.dim)
1339
+
1340
+ return scaled_data, torch.zeros_like(scale), scale
1341
+
1342
+
1343
+ class TinyTimeMixerNOPScaler(nn.Module):
1344
+ """
1345
+ Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data.
1346
+ """
1347
+
1348
+ def __init__(self, config: TinyTimeMixerConfig):
1349
+ super().__init__()
1350
+ self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
1351
+ self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
1352
+
1353
+ def forward(
1354
+ self, data: torch.Tensor, observed_indicator: torch.Tensor = None
1355
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1356
+ """
1357
+ Parameters:
1358
+ data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
1359
+ input for Batch norm calculation
1360
+ Returns:
1361
+ tuple of `torch.Tensor` of shapes
1362
+ (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
1363
+ `(batch_size, 1, num_input_channels)`)
1364
+ """
1365
+ scale = torch.ones_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
1366
+ loc = torch.zeros_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
1367
+ return data, loc, scale
1368
+
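# A minimal sketch (hypothetical helper, not part of this file): all three scalers return
# the triple (scaled_data, loc, scale), so downstream code can always invert the
# normalisation as `x_hat * scale + loc`, which is an identity for the NOP scaler.
def _example_nop_roundtrip():
    import torch
    data = torch.randn(2, 8, 3)                               # (batch, seq_len, channels)
    scale = torch.ones_like(data).mean(dim=1, keepdim=True)    # all ones
    loc = torch.zeros_like(data).mean(dim=1, keepdim=True)     # all zeros
    assert torch.allclose(data * scale + loc, data)            # identity round trip
    return loc, scale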
1369
+
1370
+ @dataclass
1371
+ class TinyTimeMixerEncoderOutput(ModelOutput):
1372
+ """
1373
+ Base class for `TinyTimeMixerEncoderOutput`, with potential hidden states.
1374
+
1375
+ Args:
1376
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
1377
+ Hidden-state at the output of the last layer of the model.
1378
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*):
1379
+ Hidden-states of the model at the output of each layer.
1380
+ """
1381
+
1382
+ last_hidden_state: torch.FloatTensor = None
1383
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
1384
+
1385
+
1386
+ class TinyTimeMixerEncoder(TinyTimeMixerPreTrainedModel):
1387
+ """
1388
+ Encoder for TinyTimeMixer which inputs patched time-series and outputs patched embeddings.
1389
+
1390
+ Args:
1391
+ config (`TinyTimeMixerConfig`, *required*):
1392
+ Configuration.
1393
+ """
1394
+
1395
+ def __init__(self, config: TinyTimeMixerConfig):
1396
+ if config.init_processing is False:
1397
+ config.check_and_init_preprocessing()
1398
+
1399
+ super().__init__(config)
1400
+
1401
+ self.use_return_dict = config.use_return_dict
1402
+
1403
+ self.patcher = nn.Linear(config.patch_length, config.d_model)
1404
+ if config.use_positional_encoding:
1405
+ self.positional_encoder = TinyTimeMixerPositionalEncoding(config=config)
1406
+ else:
1407
+ self.positional_encoder = None
1408
+ self.mlp_mixer_encoder = TinyTimeMixerBlock(config=config)
1409
+
1410
+ if config.resolution_prefix_tuning:
1411
+ mid_dim = (config.patch_length + config.d_model) // 2
1412
+
1413
+ self.freq_mod = nn.Sequential(
1414
+ nn.Embedding(config.frequency_token_vocab_size, config.patch_length),
1415
+ nn.Linear(config.patch_length, mid_dim),
1416
+ nn.GELU(),
1417
+ nn.Linear(mid_dim, config.d_model),
1418
+ )
1419
+ self.resolution_prefix_tuning = config.resolution_prefix_tuning
1420
+ self.d_model = config.d_model
1421
+
1422
+ # # Initialize weights and apply final processing
1423
+ # if config.post_init:
1424
+ # self.post_init()
1425
+
1426
+ def forward(
1427
+ self,
1428
+ past_values: torch.Tensor,
1429
+ output_hidden_states: Optional[bool] = False,
1430
+ return_dict: Optional[bool] = None,
1431
+ freq_token: Optional[torch.Tensor] = None,
1432
+ ) -> Union[Tuple, TinyTimeMixerEncoderOutput]:
1433
+ r"""
1434
+ Args:
1435
+ past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
1436
+ Context values of the time series.
1437
+ For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series,
1438
+ it is greater than 1.
1439
+
1440
+ output_hidden_states (`bool`, *optional*):
1441
+ Whether or not to return the hidden states of all layers.
1442
+
1443
+ return_dict (`bool`, *optional*):
1444
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1445
+
1446
+ Returns:
1447
+ `torch.FloatTensor` of shape `(batch_size, n_vars, num_patches, d_model)`
1448
+ """
1449
+
1450
+ return_dict = return_dict if return_dict is not None else self.use_return_dict
1451
+
1452
+ # flatten [bs x num_patch x d_model]. common_channel/mix_channel: [bs x n_vars x num_patch x d_model]
1453
+ patches = self.patcher(past_values)
1454
+
1455
+ if self.resolution_prefix_tuning:
1456
+ if freq_token is not None:
1457
+ freq_embedding = self.freq_mod(freq_token.long()) # bs x d_model
1458
+
1459
+ freq_embedding = freq_embedding.view(patches.shape[0], 1, 1, self.d_model)
1460
+ freq_embedding = freq_embedding.expand(
1461
+ patches.shape[0],
1462
+ patches.shape[1],
1463
+ 1,
1464
+ self.d_model,
1465
+ ) # bs x channels x 1 x num_features
1466
+
1467
+ patches = torch.cat((freq_embedding, patches), dim=-2) # bs x channels x num_patch+1 x num_features
1468
+
1469
+ else:
1470
+ raise Exception("Expecting `freq_token` in forward when `resolution_prefix_tuning` is enabled")
1471
+
1472
+ # add positional encoder
1473
+ if self.positional_encoder is not None:
1474
+ patches = self.positional_encoder(patches)
1475
+
1476
+ last_hidden_state, hidden_states = self.mlp_mixer_encoder(patches, output_hidden_states=output_hidden_states)
1477
+
1478
+ if not return_dict:
1479
+ return tuple(
1480
+ v
1481
+ for v in [
1482
+ last_hidden_state,
1483
+ hidden_states,
1484
+ ]
1485
+ )
1486
+
1487
+ return TinyTimeMixerEncoderOutput(last_hidden_state=last_hidden_state, hidden_states=hidden_states)
1488
+
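# A minimal sketch (hypothetical shapes, not part of this file) of the resolution prefix
# tuning step above: the frequency embedding is broadcast to one extra "patch" per channel
# and concatenated in front of the patch sequence, so num_patches grows by 1.
def _example_prefix_concat():
    import torch
    bs, n_vars, num_patch, d_model = 2, 3, 8, 16
    patches = torch.randn(bs, n_vars, num_patch, d_model)
    freq_embedding = torch.randn(bs, 1, 1, d_model).expand(bs, n_vars, 1, d_model)
    out = torch.cat((freq_embedding, patches), dim=-2)
    assert out.shape == (bs, n_vars, num_patch + 1, d_model)
    return out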
1489
+
1490
+ @dataclass
1491
+ class TinyTimeMixerModelOutput(ModelOutput):
1492
+ """
1493
+ Base class for model's outputs, with potential hidden states.
1494
+
1495
+ Args:
1496
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
1497
+ Hidden-state at the output of the last layer of the model.
1498
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*):
1499
+ Hidden-states of the model at the output of each layer.
1500
+ patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
1501
+ Patched input data to the model.
1502
+ loc: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`,*optional*):
1503
+ Gives the mean of the context window per channel. Used for RevIN denormalization outside the model, if RevIN is
1504
+ enabled.
1505
+ scale: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`,*optional*):
1506
+ Gives the std dev of the context window per channel. Used for RevIN denormalization outside the model, if RevIN is
1507
+ enabled.
1508
+ """
1509
+
1510
+ last_hidden_state: torch.FloatTensor = None
1511
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
1512
+ patch_input: torch.FloatTensor = None
1513
+ loc: Optional[torch.FloatTensor] = None
1514
+ scale: Optional[torch.FloatTensor] = None
1515
+
1516
+
1517
+ class TinyTimeMixerModel(TinyTimeMixerPreTrainedModel):
1518
+ def __init__(self, config: TinyTimeMixerConfig):
1519
+ if config.init_processing is False:
1520
+ config.check_and_init_preprocessing()
1521
+
1522
+ super().__init__(config)
1523
+
1524
+ self.use_return_dict = config.use_return_dict
1525
+ self.encoder = TinyTimeMixerEncoder(config)
1526
+ self.patching = TinyTimeMixerPatchify(config)
1527
+
1528
+ if config.scaling == "mean":
1529
+ self.scaler = TinyTimeMixerMeanScaler(config)
1530
+ elif config.scaling == "std" or config.scaling is True:
1531
+ self.scaler = TinyTimeMixerStdScaler(config)
1532
+ else:
1533
+ self.scaler = TinyTimeMixerNOPScaler(config)
1534
+
1535
+ self.d_model = config.d_model
1536
+
1537
+ # # Initialize weights and apply final processing
1538
+ # if config.post_init:
1539
+ # self.post_init()
1540
+
1541
+ def forward(
1542
+ self,
1543
+ past_values: torch.Tensor,
1544
+ past_observed_mask: Optional[torch.Tensor] = None,
1545
+ output_hidden_states: Optional[bool] = False,
1546
+ return_dict: Optional[bool] = None,
1547
+ freq_token: Optional[torch.Tensor] = None,
1548
+ ) -> TinyTimeMixerModelOutput:
1549
+ r"""
1550
+ past_observed_mask (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
1551
+ Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
1552
+ in `[0, 1]` or `[False, True]`:
1553
+ - 1 or True for values that are **observed**,
1554
+ - 0 or False for values that are **missing** (i.e. NaNs that were replaced by zeros).
1555
+
1556
+ Returns:
1557
+
1558
+ """
1559
+ return_dict = return_dict if return_dict is not None else self.use_return_dict
1560
+
1561
+ if past_observed_mask is None:
1562
+ past_observed_mask = torch.ones_like(past_values)
1563
+ scaled_past_values, loc, scale = self.scaler(past_values, past_observed_mask)
1564
+
1565
+ patched_x = self.patching(scaled_past_values) # [batch_size x num_input_channels x num_patch x patch_length]
1566
+
1567
+ enc_input = patched_x
1568
+
1569
+ encoder_output = self.encoder(
1570
+ enc_input,
1571
+ output_hidden_states=output_hidden_states,
1572
+ return_dict=return_dict,
1573
+ freq_token=freq_token,
1574
+ )
1575
+
1576
+ if isinstance(encoder_output, tuple):
1577
+ encoder_output = TinyTimeMixerEncoderOutput(*encoder_output)
1578
+
1579
+ if not return_dict:
1580
+ return tuple(
1581
+ v
1582
+ for v in [
1583
+ encoder_output.last_hidden_state,
1584
+ encoder_output.hidden_states,
1585
+ patched_x,
1586
+ loc,
1587
+ scale,
1588
+ ]
1589
+ )
1590
+
1591
+ return TinyTimeMixerModelOutput(
1592
+ last_hidden_state=encoder_output.last_hidden_state,
1593
+ hidden_states=encoder_output.hidden_states,
1594
+ patch_input=patched_x,
1595
+ loc=loc,
1596
+ scale=scale,
1597
+ )
1598
+
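# A minimal sketch (hypothetical helper, not part of this file) of what the backbone forward
# does before the encoder: build a default observed mask, normalise per channel, then cut the
# sequence into patches. Patching is approximated here with `unfold` under an assumed
# patch_length/patch_stride pair; in the model it is done by TinyTimeMixerPatchify.
def _example_scale_then_patch(patch_length: int = 8, patch_stride: int = 8):
    import torch
    past_values = torch.randn(2, 64, 3)                        # (batch, seq_len, channels)
    past_observed_mask = torch.ones_like(past_values)           # default: everything observed
    loc = past_values.mean(dim=1, keepdim=True)
    scale = past_values.std(dim=1, keepdim=True).clamp(min=1e-5)
    scaled = (past_values - loc) / scale                         # "std" scaling
    # (batch, channels, num_patches, patch_length)
    patched = scaled.transpose(1, 2).unfold(-1, patch_length, patch_stride)
    return patched.shape                                         # -> (2, 3, 8, 8)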
1599
+
1600
+ @dataclass
1601
+ class TinyTimeMixerForPredictionOutput(ModelOutput):
1602
+ """
1603
+ Output type of [`TinyTimeMixerForPrediction`].
1604
+
1605
+ Args:
1606
+ prediction_outputs (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_input_channels)`):
1607
+ Prediction output from the forecast head.
1608
+ backbone_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
1609
+ Backbone embeddings before passing through the decoder
1610
+ decoder_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
1611
+ Decoder embeddings before passing through the head.
1612
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*):
1613
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
1614
+ loss (*optional*, returned when `future_values` is provided, `torch.FloatTensor` of shape `()`):
1615
+ Total loss.
1616
+ loc (`torch.FloatTensor`, *optional* of shape `(batch_size, 1, num_input_channels)`):
1617
+ Input mean
1618
+ scale (`torch.FloatTensor`, *optional* of shape `(batch_size, 1, num_input_channels)`):
1619
+ Input std dev
1620
+
1621
+ """
1622
+
1623
+ loss: Optional[torch.FloatTensor] = None
1624
+ prediction_outputs: torch.FloatTensor = None
1625
+ backbone_hidden_state: torch.FloatTensor = None
1626
+ decoder_hidden_state: torch.FloatTensor = None
1627
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
1628
+ loc: torch.FloatTensor = None
1629
+ scale: torch.FloatTensor = None
1630
+
1631
+
1632
+ @dataclass
1633
+ class SampleTinyTimeMixerPredictionOutput(ModelOutput):
1634
+ """
1635
+ Base class for time series model's predictions outputs that contains the sampled values from the chosen
1636
+ distribution.
1637
+
1638
+ Args:
1639
+ sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, number_channels)`):
1640
+ Sampled values from the chosen distribution.
1641
+ """
1642
+
1643
+ sequences: torch.FloatTensor = None
1644
+
1645
+
1646
+ def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor:
1647
+ """
1648
+ Computes the negative log likelihood loss from input distribution with respect to target.
1649
+ """
1650
+ return -input.log_prob(target)
1651
+
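# A minimal sketch (hypothetical values, not part of this file) of `nll`: for a unit Normal,
# the negative log-likelihood at 0 is 0.5 * log(2 * pi), roughly 0.9189.
def _example_nll():
    import math
    import torch
    dist = torch.distributions.Normal(loc=torch.zeros(1), scale=torch.ones(1))
    loss = nll(dist, torch.zeros(1))
    assert torch.allclose(loss, torch.tensor([0.5 * math.log(2 * math.pi)]))
    return loss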
1652
+
1653
+ def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor:
1654
+ """
1655
+ Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
1656
+ meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.
1657
+
1658
+ Args:
1659
+ input_tensor (`torch.FloatTensor`):
1660
+ Input tensor, of which the average must be computed.
1661
+ weights (`torch.FloatTensor`, *optional*):
1662
+ Weights tensor, of the same shape as `input_tensor`.
1663
+ dim (`int`, *optional*):
1664
+ The dim along which to average `input_tensor`.
1665
+
1666
+ Returns:
1667
+ `torch.FloatTensor`: The tensor with values averaged along the specified `dim`.
1668
+ """
1669
+ if weights is not None:
1670
+ weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor))
1671
+ sum_weights = torch.clamp(weights.sum(dim=dim) if dim is not None else weights.sum(), min=1.0)
1672
+ return (weighted_tensor.sum(dim=dim) if dim is not None else weighted_tensor.sum()) / sum_weights
1673
+ else:
1674
+ return input_tensor.mean(dim=dim)
1675
+
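# A minimal sketch (hypothetical values, not part of this file) of `weighted_average`:
# entries whose weight is zero are excluded from both numerator and denominator, so a NaN
# sitting under a zero weight cannot poison the result.
def _example_weighted_average():
    import torch
    values = torch.tensor([1.0, 3.0, float("nan")])
    weights = torch.tensor([1.0, 1.0, 0.0])
    out = weighted_average(values, weights=weights, dim=0)
    assert torch.allclose(out, torch.tensor(2.0))
    return out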
1676
+
1677
+ class TinyTimeMixerForPrediction(TinyTimeMixerPreTrainedModel):
1678
+ r"""
1679
+ `TinyTimeMixer` for forecasting application.
1680
+
1681
+ Args:
1682
+ config (`TinyTimeMixerConfig`, *required*):
1683
+ Configuration.
1684
+
1685
+ Returns:
1686
+ `None`.
1687
+ """
1688
+
1689
+ def __init__(self, config: TinyTimeMixerConfig):
1690
+ config.check_and_init_preprocessing()
1691
+ super().__init__(config)
1692
+
1693
+ self.config = config
1694
+
1695
+ self.loss = config.loss
1696
+
1697
+ self.use_return_dict = config.use_return_dict
1698
+
1699
+ self.prediction_channel_indices = config.prediction_channel_indices
1700
+ self.num_parallel_samples = config.num_parallel_samples
1701
+
1702
+ self.num_input_channels = config.num_input_channels
1703
+
1704
+ self.prediction_filter_length = config.prediction_filter_length
1705
+
1706
+ if config.loss in ["mse", "mae", "pinball", "huber"] or config.loss is None:
1707
+ self.distribution_output = None
1708
+ elif config.loss == "nll":
1709
+ if self.prediction_filter_length is None:
1710
+ dim = config.prediction_length
1711
+ else:
1712
+ dim = config.prediction_filter_length
1713
+
1714
+ distribution_output_map = {
1715
+ "student_t": StudentTOutput,
1716
+ "normal": NormalOutput,
1717
+ "negative_binomial": NegativeBinomialOutput,
1718
+ }
1719
+ output_class = distribution_output_map.get(config.distribution_output, None)
1720
+ if output_class is not None:
1721
+ self.distribution_output = output_class(dim=dim)
1722
+ else:
1723
+ raise ValueError(f"Unknown distribution output {config.distribution_output}")
1724
+
1725
+ self.backbone = TinyTimeMixerModel(config)
1726
+
1727
+ self.use_decoder = config.use_decoder
1728
+
1729
+ if config.use_decoder:
1730
+ self.decoder = TinyTimeMixerDecoder(config)
1731
+
1732
+ self.head = TinyTimeMixerForPredictionHead(
1733
+ config=config,
1734
+ distribution_output=self.distribution_output,
1735
+ )
1736
+
1737
+ # Initialize weights and apply final processing
1738
+ if config.post_init:
1739
+ self.post_init()
1740
+
1741
+ def forward(
1742
+ self,
1743
+ past_values: torch.Tensor,
1744
+ future_values: Optional[torch.Tensor] = None,
1745
+ past_observed_mask: Optional[torch.Tensor] = None,
1746
+ future_observed_mask: Optional[torch.Tensor] = None,
1747
+ output_hidden_states: Optional[bool] = False,
1748
+ return_loss: bool = True,
1749
+ return_dict: Optional[bool] = None,
1750
+ freq_token: Optional[torch.Tensor] = None,
1751
+ static_categorical_values: Optional[torch.Tensor] = None,
1752
+ metadata: Optional[torch.Tensor] = None,
1753
+ ) -> TinyTimeMixerForPredictionOutput:
1754
+ r"""
1755
+ past_observed_mask (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
1756
+ Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
1757
+ in `[0, 1]` or `[False, True]`:
1758
+ - 1 or True for values that are **observed**,
1759
+ - 0 or False for values that are **missing** (i.e. NaNs that were replaced by zeros).
1760
+ future_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
1761
+ `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*): Target
1762
+ values of the time series, that serve as labels for the model. The `future_values` is what the
1763
+ Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
1764
+ required for a pretraining task.
1765
+
1766
+ For a forecasting task, the shape should be `(batch_size, target_len, num_input_channels)`. Even if we want
1767
+ to forecast only specific channels by setting the indices in `prediction_channel_indices` parameter,
1768
+ pass the target data with all channels, as channel filtering for both prediction and target will be
1769
+ manually applied before the loss computation.
1770
+ future_observed_mask (`torch.Tensor` of shape `(batch_size, prediction_length, num_targets)`, *optional*):
1771
+ Boolean mask to indicate which `future_values` were observed and which were missing. Mask values selected
1772
+ in `[0, 1]` or `[False, True]`:
1773
+ - 1 or True for values that are **observed**,
1774
+ - 0 or False for values that are **missing** (i.e. NaNs that were replaced by zeros).
1775
+ return_loss (`bool`, *optional*):
1776
+ Whether to return the loss in the `forward` call.
1777
+ static_categorical_values (`torch.FloatTensor` of shape `(batch_size, number_of_categorical_variables)`, *optional*):
1778
+ Tokenized categorical values can be passed here. Ensure to pass in the same order as the vocab size list used in the
1779
+ TinyTimeMixerConfig param `categorical_vocab_size_list`
1780
+ metadata (`torch.Tensor`, *optional*): A tensor containing metadata. Currently unused in TinyTimeMixer, but used
1781
+ to support custom trainers. Defaults to None.
1782
+
1783
+ Returns:
1784
+
1785
+ """
1786
+ if past_values.dim() != 3:
1787
+ raise ValueError(
1788
+ "`past_values` must have 3 dimensions of shape `(batch_size, sequence_length, num_input_channels)`."
1789
+ )
1790
+
1791
+ sequence_length = (
1792
+ self.config.masked_context_length
1793
+ if self.config.masked_context_length is not None
1794
+ else self.config.context_length
1795
+ )
1796
+
1797
+ if past_values.shape[1] > sequence_length:
1798
+ past_values = past_values[:, -sequence_length:, :]
1799
+ elif past_values.shape[1] < sequence_length:
1800
+ raise ValueError("Context length in `past_values` is shorter that TTM context_length.")
1801
+
1802
+ # if self.loss == "mse":
1803
+ # loss = nn.MSELoss(reduction="mean")
1804
+ # elif self.loss == "mae":
1805
+ # loss = nn.L1Loss(reduction="mean")
1806
+ # elif self.loss == "pinball":
1807
+ # loss = PinballLoss(quantile=self.config.quantile)
1808
+ # elif self.loss == "huber":
1809
+ # loss = nn.HuberLoss(delta=self.config.huber_delta)
1810
+ # elif self.loss == "nll":
1811
+ # raise Exception(
1812
+ # "NLL loss and Distribution heads are currently not allowed. Use mse or mae as loss functions."
1813
+ # )
1814
+ # loss = nll
1815
+ # elif self.loss is None:
1816
+ # loss = None
1817
+ # else:
1818
+ # raise ValueError("Invalid loss function: Allowed values: mse, mae and nll")
1819
+
1820
+ loss = nn.MSELoss(reduction="mean")
1821
+
1822
+ return_dict = return_dict if return_dict is not None else self.use_return_dict
1823
+
1824
+ # past_values: tensor [batch_size x context_length x num_input_channels]
1825
+ model_output = self.backbone(
1826
+ past_values,
1827
+ past_observed_mask=past_observed_mask,
1828
+ output_hidden_states=output_hidden_states,
1829
+ return_dict=return_dict,
1830
+ freq_token=freq_token,
1831
+ ) # model_output: [batch_size x nvars x num_patch x d_model]
1832
+
1833
+ if isinstance(model_output, tuple):
1834
+ model_output = TinyTimeMixerModelOutput(*model_output)
1835
+
1836
+ decoder_input = model_output.last_hidden_state
1837
+ hidden_states = model_output.hidden_states
1838
+
1839
+ if self.use_decoder:
1840
+ decoder_output, decoder_hidden_states = self.decoder(
1841
+ hidden_state=decoder_input,
1842
+ patch_input=model_output.patch_input,
1843
+ output_hidden_states=output_hidden_states,
1844
+ static_categorical_values=static_categorical_values,
1845
+ ) # [batch_size x nvars x num_patch x decoder_d_model]
1846
+
1847
+ if decoder_hidden_states:
1848
+ hidden_states.extend(decoder_hidden_states)
1849
+
1850
+ else:
1851
+ decoder_output = decoder_input
1852
+
1853
+ # tensor [batch_size x prediction_length x num_input_channels]
1854
+
1855
+ # head should take future mask
1856
+ y_hat = self.head(decoder_output, past_values=past_values, future_values=future_values)
1857
+
1858
+ if (
1859
+ self.prediction_filter_length is not None
1860
+ and future_values is not None
1861
+ and future_values.shape[1] != self.prediction_filter_length
1862
+ ):
1863
+ future_values = future_values[:, : self.prediction_filter_length, :]
1864
+
1865
+ if future_observed_mask is not None:
1866
+ future_observed_mask = future_observed_mask[:, : self.prediction_filter_length, :]
1867
+
1868
+ if (
1869
+ self.prediction_channel_indices is not None
1870
+ and future_values is not None
1871
+ and future_values.shape[2] != len(self.prediction_channel_indices)
1872
+ and future_values.shape[2] == self.num_input_channels
1873
+ ):
1874
+ future_values = future_values[..., self.prediction_channel_indices]
1875
+
1876
+ if future_observed_mask is not None:
1877
+ future_observed_mask = future_observed_mask[..., self.prediction_channel_indices]
1878
+
1879
+ if self.prediction_channel_indices is not None:
1880
+ loc = model_output.loc[..., self.prediction_channel_indices]
1881
+ scale = model_output.scale[..., self.prediction_channel_indices]
1882
+ else:
1883
+ loc = model_output.loc
1884
+ scale = model_output.scale
1885
+
1886
+ y_hat = y_hat * scale + loc
1887
+ # loc/scale: batch_size x 1 x prediction_channel_indices or num_targets
1888
+
1889
+ # loss_val = None
1890
+
1891
+ # if future_observed_mask is not None:
1892
+ # fut_mask_bool = future_observed_mask.type(torch.bool)
1893
+
1894
+ # if self.distribution_output:
1895
+ # distribution = self.distribution_output.distribution(y_hat, loc=loc, scale=scale)
1896
+ # if future_values is not None and return_loss is True and loss is not None:
1897
+ # if future_observed_mask is not None and (~fut_mask_bool).any():
1898
+ # if (~fut_mask_bool).all():
1899
+ # # no valid observed values
1900
+ # print(future_observed_mask)
1901
+ # raise ValueError("Loss computation failed due to too many missing values")
1902
+ # loss_val = loss(distribution, future_values)
1903
+ # # select only values of loss where entire timepoint is observed
1904
+ # loss_val = loss_val[fut_mask_bool.all(dim=-1)]
1905
+ # else:
1906
+ # loss_val = loss(distribution, future_values)
1907
+ # loss_val = weighted_average(loss_val)
1908
+ # else:
1909
+ # y_hat = y_hat * scale + loc
1910
+ # if future_values is not None and return_loss is True and loss is not None:
1911
+ # if future_observed_mask is not None:
1912
+ # loss_val = loss(y_hat[fut_mask_bool], future_values[fut_mask_bool])
1913
+ # else:
1914
+ # # avoiding mask operations for performance benefits on normal scenarios.
1915
+ # loss_val = loss(y_hat, future_values)
1916
+
1917
+ return y_hat
1918
+
1919
+ # if not return_dict:
1920
+ # return tuple(
1921
+ # v
1922
+ # for v in [
1923
+ # loss_val,
1924
+ # y_hat,
1925
+ # model_output.last_hidden_state,
1926
+ # decoder_output,
1927
+ # hidden_states,
1928
+ # loc,
1929
+ # scale,
1930
+ # ]
1931
+ # )
1932
+
1933
+ # return TinyTimeMixerForPredictionOutput(
1934
+ # loss=loss_val,
1935
+ # prediction_outputs=y_hat, # tensor [batch_size x prediction_length x num_input_channels]
1936
+ # backbone_hidden_state=model_output.last_hidden_state, # x: [batch_size x nvars x num_patch x d_model]
1937
+ # decoder_hidden_state=decoder_output, # x: [batch_size x nvars x num_patch x decoder_d_model]
1938
+ # hidden_states=hidden_states,
1939
+ # loc=loc,
1940
+ # scale=scale,
1941
+ # )
1942
+
1943
+ def generate(
1944
+ self,
1945
+ past_values: torch.Tensor,
1946
+ past_observed_mask: Optional[torch.Tensor] = None,
1947
+ ) -> SampleTinyTimeMixerPredictionOutput:
1948
+ """
1949
+ Generate sequences of sample predictions from a model with a probability distribution head.
1950
+
1951
+ Args:
1952
+ past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
1953
+ Past values of the time series that serves as context in order to predict the future.
1954
+
1955
+ past_observed_mask (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
1956
+ Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
1957
+ in `[0, 1]` or `[False, True]`:
1958
+ - 1 or True for values that are **observed**,
1959
+ - 0 or False for values that are **missing** (i.e. NaNs that were replaced by zeros).
1960
+
1961
+ Return:
1962
+ [`SampleTinyTimeMixerPredictionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
1963
+ number of samples, prediction_length, num_input_channels)`.
1964
+ """
1965
+ # get number of samples
1966
+ num_parallel_samples = self.num_parallel_samples
1967
+
1968
+ # get model output
1969
+ outputs = self(
1970
+ past_values=past_values,
1971
+ future_values=None,
1972
+ past_observed_mask=past_observed_mask,
1973
+ output_hidden_states=False,
1974
+ )
1975
+
1976
+ # get distribution
1977
+
1978
+ distribution = self.distribution_output.distribution(
1979
+ outputs.prediction_outputs, loc=outputs.loc, scale=outputs.scale
1980
+ )
1981
+
1982
+ # get samples: list of [batch_size x prediction_length x num_channels]
1983
+ samples = [distribution.sample() for _ in range(num_parallel_samples)]
1984
+
1985
+ # stack tensors
1986
+ samples = torch.stack(samples, dim=1) # [batch_size x num_samples x prediction_length x num_channels]
1987
+ return SampleTinyTimeMixerPredictionOutput(sequences=samples)
1988
+
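# A minimal sketch (hypothetical shapes, not part of this file) of the final de-normalisation
# step in the forward above: per-channel loc/scale of shape (batch, 1, channels) broadcast
# over the prediction_length dimension to map predictions back to the original data scale.
def _example_denormalise():
    import torch
    y_hat = torch.randn(2, 96, 3)              # (batch, prediction_length, channels), normalised
    loc = torch.zeros(2, 1, 3)                  # per-channel context mean
    scale = 2.0 * torch.ones(2, 1, 3)           # per-channel context std dev
    y = y_hat * scale + loc                      # broadcast denormalisation
    assert y.shape == y_hat.shape
    return y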
1989
+
1990
+ class TinyTimeMixerForMaskedPrediction(TinyTimeMixerForPrediction):
1991
+ def __init__(self, config: TinyTimeMixerConfig):
1992
+ if config.prediction_filter_length is not None:
1993
+ append_length = config.prediction_filter_length
1994
+ else:
1995
+ append_length = config.prediction_length
1996
+
1997
+ self.append_length = append_length
1998
+ config.masked_context_length = config.context_length + append_length
1999
+ config.fcm_prepend_past_offset = append_length
2000
+
2001
+ if config.exogenous_channel_indices is not None:
2002
+ self.non_exog_channels = list(
2003
+ set(range(config.num_input_channels)) - set(config.exogenous_channel_indices)
2004
+ )
2005
+ else:
2006
+ self.non_exog_channels = list(range(config.num_input_channels))
2007
+
2008
+ super().__init__(config)
2009
+
2010
+ def forward(
2011
+ self,
2012
+ past_values: torch.Tensor,
2013
+ future_values: Optional[torch.Tensor] = None,
2014
+ past_observed_mask: Optional[torch.Tensor] = None,
2015
+ future_observed_mask: Optional[torch.Tensor] = None,
2016
+ output_hidden_states: Optional[bool] = False,
2017
+ return_loss: bool = True,
2018
+ return_dict: Optional[bool] = None,
2019
+ freq_token: Optional[torch.Tensor] = None,
2020
+ static_categorical_values: Optional[torch.Tensor] = None,
2021
+ metadata: Optional[torch.Tensor] = None,
2022
+ ) -> TinyTimeMixerForPredictionOutput:
2023
+ r"""
2024
+ past_observed_mask (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
2025
+ Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
2026
+ in `[0, 1]` or `[False, True]`:
2027
+ - 1 or True for values that are **observed**,
2028
+ - 0 or False for values that are **missing** (i.e. NaNs that were replaced by zeros).
2029
+ future_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
2030
+ `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*): Target
2031
+ values of the time series, that serve as labels for the model. The `future_values` is what the
2032
+ Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
2033
+ required for a pretraining task.
2034
+
2035
+ For a forecasting task, the shape should be `(batch_size, target_len, num_input_channels)`. Even if we want
2036
+ to forecast only specific channels by setting the indices in `prediction_channel_indices` parameter,
2037
+ pass the target data with all channels, as channel filtering for both prediction and target will be
2038
+ manually applied before the loss computation.
2039
+ future_observed_mask (`torch.Tensor` of shape `(batch_size, prediction_length, num_targets)`, *optional*):
2040
+ Boolean mask to indicate which `future_values` were observed and which were missing. Mask values selected
2041
+ in `[0, 1]` or `[False, True]`:
2042
+ - 1 or True for values that are **observed**,
2043
+ - 0 or False for values that are **missing** (i.e. NaNs that were replaced by zeros).
2044
+ return_loss (`bool`, *optional*):
2045
+ Whether to return the loss in the `forward` call.
2046
+ static_categorical_values (`torch.FloatTensor` of shape `(batch_size, number_of_categorical_variables)`, *optional*):
2047
+ Tokenized categorical values can be passed here. Ensure to pass in the same order as the vocab size list used in the
2048
+ TinyTimeMixerConfig param `categorical_vocab_size_list`
2049
+ metadata (`torch.Tensor`, *optional*): A tensor containing metadata. Currently unused in TinyTimeMixer, but used
2050
+ to support custom trainers. Defaults to None.
2051
+
2052
+ Returns:
2053
+
2054
+ """
2055
+ if future_values is not None:
2056
+ future_values_masked = future_values.clone()
2057
+ else:
2058
+ future_values_masked = torch.zeros(past_values.shape[0], self.append_length, past_values.shape[2], dtype=past_values.dtype, device=past_values.device)
2059
+
2060
+ if (
2061
+ self.config.prediction_filter_length is not None
2062
+ and future_values_masked is not None
2063
+ and future_values_masked.shape[1] != self.config.prediction_filter_length
2064
+ ):
2065
+ future_values_masked = future_values_masked[:, : self.config.prediction_filter_length, :]
2066
+
2067
+ if self.config.exogenous_channel_indices is not None:
2068
+ future_values_masked[:, :, self.non_exog_channels] = self.config.mask_value
2069
+ else:
2070
+ future_values_masked.fill_(self.config.mask_value)
2071
+ past_values = torch.cat((past_values, future_values_masked), dim=-2) # xb: [bs x seq_len+ fl x n_vars]
2072
+
2073
+ if past_observed_mask is None:
2074
+ past_observed_mask = torch.ones_like(past_values)
2075
+
2076
+ # if there is already a past mask - update with it
2077
+ # index 1 refers to the seq len
2078
+
2079
+ if past_observed_mask.shape[1] < past_values.shape[1]:
2080
+ temp_mask = torch.ones_like(past_values)
2081
+ temp_mask[:, : past_observed_mask.shape[1], :] = past_observed_mask
2082
+ past_observed_mask = temp_mask
2083
+
2084
+ # past_observed_mask[:, -self.config.prediction_length :, :] = 0
2085
+ past_observed_mask[:, -self.config.prediction_length :, self.non_exog_channels] = 0
2086
+ # [bs x seq_len+ fl x n_vars]
2087
+
2088
+ return super().forward(
2089
+ past_values=past_values,
2090
+ future_values=future_values,
2091
+ past_observed_mask=past_observed_mask,
2092
+ future_observed_mask=future_observed_mask,
2093
+ output_hidden_states=output_hidden_states,
2094
+ return_loss=return_loss,
2095
+ return_dict=return_dict,
2096
+ freq_token=freq_token,
2097
+ static_categorical_values=static_categorical_values,
2098
+ metadata=metadata,
2099
+ )
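# A minimal sketch (hypothetical values, not part of this file) of how the masked-prediction
# wrapper above extends the context: placeholder steps filled with `mask_value` are appended
# to `past_values`, and the observed mask is zeroed on those appended steps so the backbone
# treats them as missing values to be reconstructed.
def _example_masked_context(mask_value: float = 0.0, prediction_length: int = 4):
    import torch
    past_values = torch.randn(2, 16, 3)                        # (batch, context_length, channels)
    placeholder = torch.full(
        (2, prediction_length, 3), mask_value,
        dtype=past_values.dtype, device=past_values.device,
    )
    extended = torch.cat((past_values, placeholder), dim=-2)    # (batch, context+pred, channels)
    observed = torch.ones_like(extended)
    observed[:, -prediction_length:, :] = 0                     # appended steps are "missing"
    return extended.shape, observed[:, -1, 0].item()            # -> (2, 20, 3), 0.0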