dsipts-1.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (81)
  1. dsipts/__init__.py +48 -0
  2. dsipts/data_management/__init__.py +0 -0
  3. dsipts/data_management/monash.py +338 -0
  4. dsipts/data_management/public_datasets.py +162 -0
  5. dsipts/data_structure/__init__.py +0 -0
  6. dsipts/data_structure/data_structure.py +1167 -0
  7. dsipts/data_structure/modifiers.py +213 -0
  8. dsipts/data_structure/utils.py +173 -0
  9. dsipts/models/Autoformer.py +199 -0
  10. dsipts/models/CrossFormer.py +152 -0
  11. dsipts/models/D3VAE.py +196 -0
  12. dsipts/models/Diffusion.py +818 -0
  13. dsipts/models/DilatedConv.py +342 -0
  14. dsipts/models/DilatedConvED.py +310 -0
  15. dsipts/models/Duet.py +197 -0
  16. dsipts/models/ITransformer.py +167 -0
  17. dsipts/models/Informer.py +180 -0
  18. dsipts/models/LinearTS.py +222 -0
  19. dsipts/models/PatchTST.py +181 -0
  20. dsipts/models/Persistent.py +44 -0
  21. dsipts/models/RNN.py +213 -0
  22. dsipts/models/Samformer.py +139 -0
  23. dsipts/models/TFT.py +269 -0
  24. dsipts/models/TIDE.py +296 -0
  25. dsipts/models/TTM.py +252 -0
  26. dsipts/models/TimeXER.py +184 -0
  27. dsipts/models/VQVAEA.py +299 -0
  28. dsipts/models/VVA.py +247 -0
  29. dsipts/models/__init__.py +0 -0
  30. dsipts/models/autoformer/__init__.py +0 -0
  31. dsipts/models/autoformer/layers.py +352 -0
  32. dsipts/models/base.py +439 -0
  33. dsipts/models/base_v2.py +444 -0
  34. dsipts/models/crossformer/__init__.py +0 -0
  35. dsipts/models/crossformer/attn.py +118 -0
  36. dsipts/models/crossformer/cross_decoder.py +77 -0
  37. dsipts/models/crossformer/cross_embed.py +18 -0
  38. dsipts/models/crossformer/cross_encoder.py +99 -0
  39. dsipts/models/d3vae/__init__.py +0 -0
  40. dsipts/models/d3vae/diffusion_process.py +169 -0
  41. dsipts/models/d3vae/embedding.py +108 -0
  42. dsipts/models/d3vae/encoder.py +326 -0
  43. dsipts/models/d3vae/model.py +211 -0
  44. dsipts/models/d3vae/neural_operations.py +314 -0
  45. dsipts/models/d3vae/resnet.py +153 -0
  46. dsipts/models/d3vae/utils.py +630 -0
  47. dsipts/models/duet/__init__.py +0 -0
  48. dsipts/models/duet/layers.py +438 -0
  49. dsipts/models/duet/masked.py +202 -0
  50. dsipts/models/informer/__init__.py +0 -0
  51. dsipts/models/informer/attn.py +185 -0
  52. dsipts/models/informer/decoder.py +50 -0
  53. dsipts/models/informer/embed.py +125 -0
  54. dsipts/models/informer/encoder.py +100 -0
  55. dsipts/models/itransformer/Embed.py +142 -0
  56. dsipts/models/itransformer/SelfAttention_Family.py +355 -0
  57. dsipts/models/itransformer/Transformer_EncDec.py +134 -0
  58. dsipts/models/itransformer/__init__.py +0 -0
  59. dsipts/models/patchtst/__init__.py +0 -0
  60. dsipts/models/patchtst/layers.py +569 -0
  61. dsipts/models/samformer/__init__.py +0 -0
  62. dsipts/models/samformer/utils.py +154 -0
  63. dsipts/models/tft/__init__.py +0 -0
  64. dsipts/models/tft/sub_nn.py +234 -0
  65. dsipts/models/timexer/Layers.py +127 -0
  66. dsipts/models/timexer/__init__.py +0 -0
  67. dsipts/models/ttm/__init__.py +0 -0
  68. dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
  69. dsipts/models/ttm/consts.py +16 -0
  70. dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
  71. dsipts/models/ttm/utils.py +438 -0
  72. dsipts/models/utils.py +624 -0
  73. dsipts/models/vva/__init__.py +0 -0
  74. dsipts/models/vva/minigpt.py +83 -0
  75. dsipts/models/vva/vqvae.py +459 -0
  76. dsipts/models/xlstm/__init__.py +0 -0
  77. dsipts/models/xlstm/xLSTM.py +255 -0
  78. dsipts-1.1.5.dist-info/METADATA +31 -0
  79. dsipts-1.1.5.dist-info/RECORD +81 -0
  80. dsipts-1.1.5.dist-info/WHEEL +5 -0
  81. dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/ttm/utils.py
@@ -0,0 +1,438 @@
+ # Default
+ import os
+ import yaml
+ import enum
+
+ import numpy as np
+ import pandas as pd
+ from importlib import resources
+ from typing import Optional, Union
+ from pandas.tseries.frequencies import to_offset
+
+ # Custom
+ from .configuration_tinytimemixer import TinyTimeMixerConfig
+ from .modeling_tinytimemixer import TinyTimeMixerForPrediction
+ from .consts import DEFAULT_FREQUENCY_MAPPING, TTM_LOW_RESOLUTION_MODELS_MAX_CONTEXT
+
+ # Hugging Face
+ from transformers import PreTrainedModel
+
+ # PyTorch
+ import torch
+
+ ##TODO fix
+ TTM_CONF = {'ibm-granite-models': {'512-96-r1': {'release': 'r1', 'model_card': 'ibm-granite/granite-timeseries-ttm-r1', 'revision': 'main', 'context_length': 512, 'prediction_length': 96}, '1024-96-r1': {'release': 'r1', 'model_card': 'ibm-granite/granite-timeseries-ttm-r1', 'revision': '1024_96_v1', 'context_length': 1024, 'prediction_length': 96}, '512-96-r2': {'release': 'r2', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': 'main', 'context_length': 512, 'prediction_length': 96}, '512-192-r2': {'release': 'r2', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '512-192-r2', 'context_length': 512, 'prediction_length': 192}, '512-336-r2': {'release': 'r2', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '512-336-r2', 'context_length': 512, 'prediction_length': 336}, '512-720-r2': {'release': 'r2', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '512-720-r2', 'context_length': 512, 'prediction_length': 720}, '1024-96-r2': {'release': 'r2', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '1024-96-r2', 'context_length': 1024, 'prediction_length': 96}, '1024-192-r2': {'release': 'r2', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '1024-192-r2', 'context_length': 1024, 'prediction_length': 192}, '1024-336-r2': {'release': 'r2', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '1024-336-r2', 'context_length': 1024, 'prediction_length': 336}, '1024-720-r2': {'release': 'r2', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '1024-720-r2', 'context_length': 1024, 'prediction_length': 720}, '1536-96-r2': {'release': 'r2', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '1536-96-r2', 'context_length': 1536, 'prediction_length': 96}, '1536-192-r2': {'release': 'r2', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '1536-192-r2', 'context_length': 1536, 'prediction_length': 192}, '1536-336-r2': {'release': 'r2', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '1536-336-r2', 'context_length': 1536, 'prediction_length': 336}, '1536-720-r2': {'release': 'r2', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '1536-720-r2', 'context_length': 1536, 'prediction_length': 720}, '52-16-ft-r2.1': {'release': 'r2.1', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '52-16-ft-r2.1', 'context_length': 52, 'prediction_length': 16}, '52-16-ft-l1-r2.1': {'release': 'r2.1', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '52-16-ft-l1-r2.1', 'context_length': 52, 'prediction_length': 16}, '90-30-ft-r2.1': {'release': 'r2.1', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '90-30-ft-r2.1', 'context_length': 90, 'prediction_length': 30}, '90-30-ft-l1-r2.1': {'release': 'r2.1', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '90-30-ft-l1-r2.1', 'context_length': 90, 'prediction_length': 30}, '180-60-ft-l1-r2.1': {'release': 'r2.1', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '180-60-ft-l1-r2.1', 'context_length': 180, 'prediction_length': 60}, '360-60-ft-l1-r2.1': {'release': 'r2.1', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '360-60-ft-l1-r2.1', 'context_length': 360, 'prediction_length': 60}, '512-48-ft-r2.1': {'release': 'r2.1', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '512-48-ft-r2.1', 'context_length': 512, 'prediction_length': 48}, '512-48-ft-l1-r2.1': {'release': 'r2.1', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '512-48-ft-l1-r2.1', 'context_length': 512, 'prediction_length': 48}, '512-96-ft-r2.1': {'release': 'r2.1', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '512-96-ft-r2.1', 'context_length': 512, 'prediction_length': 96}, '512-96-ft-l1-r2.1': {'release': 'r2.1', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '512-96-ft-l1-r2.1', 'context_length': 512, 'prediction_length': 96}}, 'research-use-models': {'512-96-ft-r2': {'release': 'r2', 'model_card': 'ibm-research/ttm-research-r2', 'revision': 'main', 'context_length': 512, 'prediction_length': 96}, '512-192-ft-r2': {'release': 'r2', 'model_card': 'ibm-research/ttm-research-r2', 'revision': '512-192-ft-r2', 'context_length': 512, 'prediction_length': 192}, '512-336-ft-r2': {'release': 'r2', 'model_card': 'ibm-research/ttm-research-r2', 'revision': '512-336-ft-r2', 'context_length': 512, 'prediction_length': 336}, '512-720-ft-r2': {'release': 'r2', 'model_card': 'ibm-research/ttm-research-r2', 'revision': '512-720-ft-r2', 'context_length': 512, 'prediction_length': 720}, '1024-96-ft-r2': {'release': 'r2', 'model_card': 'ibm-research/ttm-research-r2', 'revision': '1024-96-ft-r2', 'context_length': 1024, 'prediction_length': 96}, '1024-192-ft-r2': {'release': 'r2', 'model_card': 'ibm-research/ttm-research-r2', 'revision': '1024-192-ft-r2', 'context_length': 1024, 'prediction_length': 192}, '1024-336-ft-r2': {'release': 'r2', 'model_card': 'ibm-research/ttm-research-r2', 'revision': '1024-336-ft-r2', 'context_length': 1024, 'prediction_length': 336}, '1024-720-ft-r2': {'release': 'r2', 'model_card': 'ibm-research/ttm-research-r2', 'revision': '1024-720-ft-r2', 'context_length': 1024, 'prediction_length': 720}, '1536-96-ft-r2': {'release': 'r2', 'model_card': 'ibm-research/ttm-research-r2', 'revision': '1536-96-ft-r2', 'context_length': 1536, 'prediction_length': 96}, '1536-192-ft-r2': {'release': 'r2', 'model_card': 'ibm-research/ttm-research-r2', 'revision': '1536-192-ft-r2', 'context_length': 1536, 'prediction_length': 192}, '1536-336-ft-r2': {'release': 'r2', 'model_card': 'ibm-research/ttm-research-r2', 'revision': '1536-336-ft-r2', 'context_length': 1536, 'prediction_length': 336}, '1536-720-ft-r2': {'release': 'r2', 'model_card': 'ibm-research/ttm-research-r2', 'revision': '1536-720-ft-r2', 'context_length': 1536, 'prediction_length': 720}}}
+
+
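Editor's note: each entry in TTM_CONF maps a model key to the Hugging Face model card, revision, and (context_length, prediction_length) pair it serves. A minimal sketch of querying the table, using keys that appear in the dict above:

    # Look up the hosted revision behind the 512->96 granite r2 model key.
    entry = TTM_CONF["ibm-granite-models"]["512-96-r2"]
    print(entry["model_card"], entry["revision"])                # ibm-granite/granite-timeseries-ttm-r2 main
    print(entry["context_length"], entry["prediction_length"])   # 512 96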
+ class ForceReturn(enum.Enum):
+     """`Enum` for the `force_return` parameter in the `get_model` function.
+
+     "zeropad" = Returns a pre-trained TTM that has a context length higher than the input context length, hence,
+     the user must apply zero-padding to use the returned model.
+     "rolling" = Returns a pre-trained TTM that has a prediction length lower than the requested prediction length,
+     hence, the user must apply rolling technique to use the returned model to forecast to the desired length.
+     The `RecursivePredictor` class can be utilized in this scenario.
+     "random_init_small" = Returns a randomly initialized small TTM which must be trained before performing inference.
+     "random_init_medium" = Returns a randomly initialized medium TTM which must be trained before performing inference.
+     "random_init_large" = Returns a randomly initialized large TTM which must be trained before performing inference.
+
+     """
+
+     ZEROPAD = "zeropad"
+     ROLLING = "rolling"
+     RANDOM_INIT_SMALL = "random_init_small"
+     RANDOM_INIT_MEDIUM = "random_init_medium"
+     RANDOM_INIT_LARGE = "random_init_large"
+
+ class ModelSize(enum.Enum):
+     """`Enum` for the `size` parameter in the `get_random_ttm` function."""
+
+     SMALL = "small"
+     MEDIUM = "medium"
+     LARGE = "large"
+
+
+ def check_ttm_model_path(model_path):
+     if (
+         "ibm/TTM" in model_path
+         or "ibm-granite/granite-timeseries-ttm-r1" in model_path
+         or "ibm-granite/granite-timeseries-ttm-v1" in model_path
+         or "ibm-granite/granite-timeseries-ttm-1m" in model_path
+     ):
+         return 1
+     elif "ibm-granite/granite-timeseries-ttm-r2" in model_path:
+         return 2
+     elif "ibm-research/ttm-research-r2" in model_path:
+         return 3
+     else:
+         return 0
+
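Editor's note: the integer code distinguishes the model families handled below: 1 for the r1-era granite cards, 2 for the r2 granite card, 3 for the research card, and 0 for anything else such as a local checkpoint. A quick sketch, assuming the module's names are in scope:

    assert check_ttm_model_path("ibm-granite/granite-timeseries-ttm-r2") == 2
    assert check_ttm_model_path("ibm-research/ttm-research-r2") == 3
    assert check_ttm_model_path("/tmp/my_finetuned_ttm") == 0  # local path: loaded as-is later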
+ def count_parameters(model: torch.nn.Module) -> int:
+     """Count trainable parameters in a model.
+
+     Args:
+         model (torch.nn.Module): The model.
+
+     Returns:
+         int: Number of parameters requiring gradients.
+     """
+     return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ def get_random_ttm(
+     context_length: int, prediction_length: int, size: str = ModelSize.SMALL.value, **kwargs
+ ) -> PreTrainedModel:
+     """Get a TTM with random weights.
+
+     Args:
+         context_length (int): Context length or history.
+         prediction_length (int): Prediction length or forecast horizon.
+         size (str, optional): Size of the desired TTM (small/medium/large). Defaults to "small".
+
+     Raises:
+         ValueError: If a wrong size is provided.
+         ValueError: Context length should be at least 4 if `size=small`,
+             at least 16 if `size=medium`,
+             or at least 32 if `size=large`.
+
+     Returns:
+         PreTrainedModel: TTM model with randomly initialized weights.
+     """
+     if ModelSize.SMALL.value in size.lower():
+         cl_lower_bound = 4
+         apl = 0
+     elif ModelSize.MEDIUM.value in size.lower():
+         cl_lower_bound = 16
+         apl = 3
+     elif ModelSize.LARGE.value in size.lower():
+         cl_lower_bound = 32
+         apl = 5
+     else:
+         raise ValueError("Wrong size. Should be one of [small/medium/large].")
+     if context_length < cl_lower_bound:
+         raise ValueError(f"Context length should be at least {cl_lower_bound} if `size={size}`.")
+
+     cl = context_length if context_length % 2 == 0 else context_length - 1
+
+     pl = 2
+     while cl % pl == 0 and cl / pl >= 8:
+         pl = pl * 2
+
+     if ModelSize.SMALL.value in size.lower():
+         d_model = 2 * pl
+         num_layers = 3
+     elif ModelSize.MEDIUM.value in size.lower():
+         d_model = 16 * 2**apl
+         num_layers = 3
+     elif ModelSize.LARGE.value in size.lower():
+         d_model = 16 * 2**apl
+         num_layers = 5
+     else:
+         raise ValueError("Wrong size. Should be one of [small/medium/large].")
+
+     ttm_config = TinyTimeMixerConfig(
+         context_length=cl,
+         prediction_length=prediction_length,
+         patch_length=pl,
+         patch_stride=pl,
+         d_model=d_model,
+         num_layers=num_layers,
+         decoder_num_layers=2,
+         decoder_d_model=d_model,
+         adaptive_patching_levels=apl,
+         dropout=0.2,
+         **kwargs,
+     )
+     model = TinyTimeMixerForPrediction(config=ttm_config)
+
+     return model
+
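Editor's note, to make the sizing arithmetic concrete: for context_length=512, cl stays 512 and the loop doubles pl while cl / pl >= 8, stopping at pl = 128 (4 patches), so a "small" model gets d_model = 2 * 128 = 256. A usage sketch; the forward-pass keyword `past_values` and the `prediction_outputs` field are assumed from the bundled TinyTimeMixer modeling code:

    # Build a small randomly initialized TTM and run one forward pass.
    model = get_random_ttm(context_length=512, prediction_length=96, size="small")
    print(count_parameters(model))        # trainable parameter count of the fresh model
    x = torch.randn(8, 512, 1)            # (batch, context_length, channels)
    out = model(past_values=x)
    print(out.prediction_outputs.shape)   # expected: torch.Size([8, 96, 1])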
+ def get_frequency_token(token_name: str):
+     token = DEFAULT_FREQUENCY_MAPPING.get(token_name, None)
+     if token is not None:
+         return torch.tensor(token, dtype=torch.int)
+
+     # try to map as a frequency string
+     try:
+         token_name_offs = to_offset(token_name).freqstr
+         token = DEFAULT_FREQUENCY_MAPPING.get(token_name_offs, None)
+         if token is not None:
+             return torch.tensor(token, dtype=torch.int)
+     except ValueError:
+         # lastly try to map the timedelta to a frequency string
+         token_name_td = pd.Timedelta(token_name)
+         token_name_offs = to_offset(token_name_td).freqstr
+         token = DEFAULT_FREQUENCY_MAPPING.get(token_name_offs, None)
+         if token is not None:
+             return torch.tensor(token, dtype=torch.int)
+
+     token = DEFAULT_FREQUENCY_MAPPING["oov"]
+
+     return torch.tensor(token, dtype=torch.int)
+
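Editor's note: the lookup falls through three stages: a direct key hit in DEFAULT_FREQUENCY_MAPPING, normalization through pandas' to_offset, and finally parsing the string as a timedelta; anything still unmapped gets the "oov" code. A sketch (the exact keys and integer codes depend on the mapping defined in consts.py):

    print(get_frequency_token("h"))                # direct or offset-normalized lookup
    print(get_frequency_token("0 days 01:00:00"))  # parsed as a timedelta if to_offset rejects it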
+
+ class RMSELoss(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.mse = torch.nn.MSELoss()
+
+     def forward(self, yhat, y):
+         return torch.sqrt(self.mse(yhat, y))
+
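Editor's note: RMSELoss is just the square root of MSELoss, so the value is reported in the units of the target. A quick sketch:

    criterion = RMSELoss()
    yhat, y = torch.randn(8, 96), torch.randn(8, 96)
    loss = criterion(yhat, y)  # equals torch.sqrt(torch.mean((yhat - y) ** 2))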
+ def get_model(
+     model_path: str,
+     model_name: str = "ttm",
+     context_length: Optional[int] = None,
+     prediction_length: Optional[int] = None,
+     freq_prefix_tuning: bool = False,
+     freq: Optional[str] = None,
+     prefer_l1_loss: bool = False,
+     prefer_longer_context: bool = True,
+     force_return: Optional[str] = None,
+     return_model_key: bool = False,
+     **kwargs,
+ ) -> Union[str, PreTrainedModel]:
+     """The TTM model card offers a suite of models with varying `context_length` and `prediction_length` combinations.
+     This wrapper automatically selects the right model based on the given input `context_length` and
+     `prediction_length`, abstracting away the internal complexity.
+
+     Args:
+         model_path (str): Hugging Face model card path or local model path (e.g. ibm-granite/granite-timeseries-ttm-r2).
+         model_name (str, optional): Model name to use. Currently allowed values: [ttm]. Defaults to "ttm".
+         context_length (int, optional): Input context length or history. Defaults to None.
+         prediction_length (int, optional): Length of the forecast horizon. Defaults to None.
+         freq_prefix_tuning (bool, optional): If True, prefer TTM models that were trained with the frequency prefix
+             tuning configuration. Defaults to False.
+         freq (str, optional): Resolution or frequency of the data. Defaults to None. Allowed values are as per the
+             `tsfm_public.toolkit.time_series_preprocessor.DEFAULT_FREQUENCY_MAPPING`.
+             See this for details: https://github.com/ibm-granite/granite-tsfm/blob/main/tsfm_public/toolkit/time_series_preprocessor.py.
+         prefer_l1_loss (bool, optional): If True, prefer models that were trained with L1 (mean absolute error) loss.
+             Defaults to False.
+         prefer_longer_context (bool, optional): If True, prefer selecting the model with the longer context/history.
+             Defaults to True.
+         force_return (str, optional): Forces get_model() to return a TTM even when the provided configuration does
+             not match an existing TTM; the closest possible TTM is returned. Allowed values are
+             ["zeropad"/"rolling"/"random_init_small"/"random_init_medium"/"random_init_large"/`None`].
+             "zeropad" = Returns a pre-trained TTM that has a context length higher than the input context length, hence,
+             the user must apply zero-padding to use the returned model.
+             "rolling" = Returns a pre-trained TTM that has a prediction length lower than the requested prediction length,
+             hence, the user must apply rolling technique to use the returned model to forecast to the desired length.
+             The `RecursivePredictor` class can be utilized in this scenario.
+             "random_init_small" = Returns a randomly initialized small TTM which must be trained before performing inference.
+             "random_init_medium" = Returns a randomly initialized medium TTM which must be trained before performing inference.
+             "random_init_large" = Returns a randomly initialized large TTM which must be trained before performing inference.
+             `None` = `force_return` is disabled. Raises an error if no suitable model is found.
+             Defaults to None.
+         return_model_key (bool, optional): If True, only the TTM model name is returned instead of the actual model.
+             This does not download the model and only returns the name of the suitable model. Defaults to False.
+
+     Returns:
+         Union[str, PreTrainedModel]: The model, or the model name.
+     """
+     if model_name.lower() == "ttm":
+         model_path_type = check_ttm_model_path(model_path)
+         prediction_filter_length = None
+         ttm_model_revision = None
+         if model_path_type != 0:
+             if context_length is None or prediction_length is None:
+                 raise ValueError(
+                     "Provide `context_length` and `prediction_length` when `model_path` is a Hugging Face model path."
+                 )
+
+             # Get freq
+             R = DEFAULT_FREQUENCY_MAPPING.get(freq, 0)
+
+             # Get list of all TTM models
+             '''
+             config_dir = resources.files("tsfm_public.resources.model_paths_config")
+             with open(os.path.join(config_dir, "ttm.yaml"), "r") as file:
+                 model_revisions = yaml.safe_load(file)
+             '''
+             model_revisions = TTM_CONF  ##TODO fix this
+             if model_path_type == 1 or model_path_type == 2:
+                 available_models = model_revisions["ibm-granite-models"]
+                 filtered_models = {}
+                 if model_path_type == 1:
+                     for k in available_models.keys():
+                         if available_models[k]["release"].startswith("r1"):
+                             filtered_models[k] = available_models[k]
+                 if model_path_type == 2:
+                     for k in available_models.keys():
+                         if available_models[k]["release"].startswith("r2"):
+                             filtered_models[k] = available_models[k]
+                 available_models = filtered_models
+             else:
+                 available_models = model_revisions["research-use-models"]
+
+             # Calculate the shortest TTM context length; it will be needed later
+             available_model_keys = list(available_models.keys())
+             available_ttm_context_lengths = [available_models[m]["context_length"] for m in available_model_keys]
+             shortest_ttm_context_length = min(available_ttm_context_lengths)
+
+             # Step 1: Filter models based on freq (R)
+             if model_path_type == 1 or model_path_type == 2:
+                 # Only r2.1 models are suitable for daily or longer frequencies
+                 if R >= 8:
+                     models = [m for m in available_models.keys() if "r2.1" in available_models[m]["release"]]
+                 else:
+                     models = list(available_models.keys())
+             else:
+                 models = list(available_models.keys())
+
+             # Step 2: Filter models by the context length constraint.
+             # Choose all models whose context length is no longer than
+             # the available input length.
+             selected_models_ = []
+             if context_length < shortest_ttm_context_length:
+                 if force_return is None:
+                     raise ValueError(
+                         "Requested context length is less than the "
+                         f"shortest context length for TTMs: {shortest_ttm_context_length}. "
+                         "Set `force_return=zeropad` to get a TTM with longer context."
+                     )
+                 elif force_return == ForceReturn.ZEROPAD.value:  # force_return.startswith("zero"):
+                     # Keep all models. Zero-padding must be done outside.
+                     selected_models_ = models
+             else:
+                 lowest_context_length = np.inf
+                 shortest_context_models = []
+                 for m in models:
+                     if available_models[m]["context_length"] <= context_length:
+                         selected_models_.append(m)
+                     if available_models[m]["context_length"] <= lowest_context_length:
+                         lowest_context_length = available_models[m]["context_length"]
+                         shortest_context_models.append(m)
+
+                 if len(selected_models_) == 0:
+                     if force_return is None:
+                         raise ValueError(
+                             "Could not find a TTM with `context_length` shorter "
+                             f"than the requested context length = {context_length}. "
+                             "Set `force_return=zeropad` to get a TTM with longer context."
+                         )
+                     elif force_return == ForceReturn.ZEROPAD.value:  # force_return.startswith("zero"):
+                         selected_models_ = shortest_context_models
+             models = selected_models_
+
+             # Step 3: Apply L1 and FT preferences only when context_length <= 512
+             if len(models) > 0:
+                 if prefer_longer_context:
+                     reference_context = min(
+                         context_length, max([available_models[m]["context_length"] for m in models])
+                     )
+                 else:
+                     reference_context = min([available_models[m]["context_length"] for m in models])
+                 if reference_context <= TTM_LOW_RESOLUTION_MODELS_MAX_CONTEXT:
+                     # Step 3a: Filter based on the L1 preference
+                     if prefer_l1_loss:
+                         l1_models = [m for m in models if "-l1-" in m]
+                         if l1_models:
+                             models = l1_models
+
+                     # Step 3b: Filter based on the frequency tuning indicator preference
+                     if freq_prefix_tuning:
+                         ft_models = [m for m in models if "-ft-" in m]
+                         if ft_models:
+                             models = ft_models
+
+             # Step 4: Sort models by context length (descending if prefer_longer_context else ascending)
+             # Step 5: Sub-sort for each context length by forecast length in ascending order
+             if len(models) > 0:
+                 sign = -1 if prefer_longer_context else 1
+                 models = sorted(
+                     models,
+                     key=lambda m: (
+                         sign * int(available_models[m]["context_length"]),
+                         int(available_models[m]["prediction_length"]),
+                     ),
+                 )
+
+             # Step 6: Remove models whose forecast length is less than the requested forecast length,
+             # because that needs recursion, which has to be handled outside this get_model() utility
+             if len(models) > 0:
+                 selected_models_ = []
+                 highest_prediction_length = -np.inf
+                 highest_prediction_model = None
+                 for m in models:
+                     if int(available_models[m]["prediction_length"]) >= prediction_length:
+                         selected_models_.append(m)
+                     if available_models[m]["prediction_length"] > highest_prediction_length:
+                         highest_prediction_length = available_models[m]["prediction_length"]
+                         highest_prediction_model = m
+                 if len(selected_models_) == 0:
+                     if force_return is None:
+                         raise ValueError(
+                             "Could not find a TTM with `prediction_length` higher "
+                             f"than the requested prediction length = {prediction_length}. "
+                             "Set `force_return=rolling` to get a TTM with shorter prediction "
+                             "length. Rolling must be done outside."
+                         )
+                     elif force_return == ForceReturn.ROLLING.value:  # force_return.startswith("roll"):
+                         selected_models_.append(highest_prediction_model)
+                 models = selected_models_
+
+             # Step 7: Do not allow an unknown frequency
+             if freq_prefix_tuning and (freq is not None) and (freq not in DEFAULT_FREQUENCY_MAPPING.keys()):
+                 models = []
+
+             # Step 8: Return the first available model, or a randomly initialized model if none is found
+             if len(models) == 0:
+                 if force_return is None:
+                     raise ValueError(
+                         "No suitable pre-trained TTM was found! Set `force_return` to "
+                         "random_init_small/random_init_medium/random_init_large "
+                         "to get a randomly initialized TTM of size small/medium/large "
+                         "respectively."
+                     )
+                 elif force_return in [
+                     ForceReturn.RANDOM_INIT_SMALL.value,
+                     ForceReturn.RANDOM_INIT_MEDIUM.value,
+                     ForceReturn.RANDOM_INIT_LARGE.value,
+                 ]:  # "sma" in force_return.lower() or "med" in force_return.lower() or "lar" in force_return.lower():
+                     model = get_random_ttm(context_length, prediction_length, size=force_return)
+                     if return_model_key:
+                         model_key = force_return.split("_")[-1]
+                         return f"TTM({model_key})"
+                     else:
+                         return model
+                 else:
+                     raise ValueError(
+                         "Could not find a suitable TTM for the given "
+                         f"context_length = {context_length}, and "
+                         f"prediction_length = {prediction_length}. "
+                         "Check the model card for more information. "
+                         "Set `force_return` properly (see the docstrings) "
+                         "if you want to get a randomly initialized TTM."
+                     )
+             else:
+                 model_key = models[0]
+
+                 # selected_context_length = available_models[model_key]["context_length"]
+                 selected_prediction_length = available_models[model_key]["prediction_length"]
+                 if selected_prediction_length > prediction_length:
+                     prediction_filter_length = prediction_length
+
+                 # if selected_prediction_length < prediction_length:
+                 #     LOGGER.warning(
+                 #         "Selected `prediction_length` is shorter than the requested "
+                 #         "length since no suitable model could be found. You can use "
+                 #         "`RecursivePredictor` to forecast to the desired length."
+                 #     )
+
+                 ttm_model_revision = available_models[model_key]["revision"]
+
+         else:
+             prediction_filter_length = prediction_length
+
+         if return_model_key:
+             return model_key
+         # Load model
+         model = TinyTimeMixerForPrediction.from_pretrained(
+             model_path,
+             revision=ttm_model_revision,
+             prediction_filter_length=prediction_filter_length,
+             **kwargs,
+         )
+     else:
+         raise ValueError("Currently, the only supported value for `model_name` is 'ttm'.")
+
+     return model
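
Editor's note, putting the selection steps together in a usage sketch: `return_model_key=True` reports the chosen TTM_CONF key without downloading anything, while the second call pulls the weights from the Hugging Face Hub (network access assumed); the exact key chosen depends on the preference flags.

    key = get_model(
        "ibm-granite/granite-timeseries-ttm-r2",
        context_length=512,
        prediction_length=48,
        return_model_key=True,
    )
    print(key)  # the TTM_CONF key selected by steps 1-8 above
    model = get_model(
        "ibm-granite/granite-timeseries-ttm-r2",
        context_length=512,
        prediction_length=48,
    )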