dsipts 1.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dsipts might be problematic. Click here for more details.
- dsipts/__init__.py +48 -0
- dsipts/data_management/__init__.py +0 -0
- dsipts/data_management/monash.py +338 -0
- dsipts/data_management/public_datasets.py +162 -0
- dsipts/data_structure/__init__.py +0 -0
- dsipts/data_structure/data_structure.py +1167 -0
- dsipts/data_structure/modifiers.py +213 -0
- dsipts/data_structure/utils.py +173 -0
- dsipts/models/Autoformer.py +199 -0
- dsipts/models/CrossFormer.py +152 -0
- dsipts/models/D3VAE.py +196 -0
- dsipts/models/Diffusion.py +818 -0
- dsipts/models/DilatedConv.py +342 -0
- dsipts/models/DilatedConvED.py +310 -0
- dsipts/models/Duet.py +197 -0
- dsipts/models/ITransformer.py +167 -0
- dsipts/models/Informer.py +180 -0
- dsipts/models/LinearTS.py +222 -0
- dsipts/models/PatchTST.py +181 -0
- dsipts/models/Persistent.py +44 -0
- dsipts/models/RNN.py +213 -0
- dsipts/models/Samformer.py +139 -0
- dsipts/models/TFT.py +269 -0
- dsipts/models/TIDE.py +296 -0
- dsipts/models/TTM.py +252 -0
- dsipts/models/TimeXER.py +184 -0
- dsipts/models/VQVAEA.py +299 -0
- dsipts/models/VVA.py +247 -0
- dsipts/models/__init__.py +0 -0
- dsipts/models/autoformer/__init__.py +0 -0
- dsipts/models/autoformer/layers.py +352 -0
- dsipts/models/base.py +439 -0
- dsipts/models/base_v2.py +444 -0
- dsipts/models/crossformer/__init__.py +0 -0
- dsipts/models/crossformer/attn.py +118 -0
- dsipts/models/crossformer/cross_decoder.py +77 -0
- dsipts/models/crossformer/cross_embed.py +18 -0
- dsipts/models/crossformer/cross_encoder.py +99 -0
- dsipts/models/d3vae/__init__.py +0 -0
- dsipts/models/d3vae/diffusion_process.py +169 -0
- dsipts/models/d3vae/embedding.py +108 -0
- dsipts/models/d3vae/encoder.py +326 -0
- dsipts/models/d3vae/model.py +211 -0
- dsipts/models/d3vae/neural_operations.py +314 -0
- dsipts/models/d3vae/resnet.py +153 -0
- dsipts/models/d3vae/utils.py +630 -0
- dsipts/models/duet/__init__.py +0 -0
- dsipts/models/duet/layers.py +438 -0
- dsipts/models/duet/masked.py +202 -0
- dsipts/models/informer/__init__.py +0 -0
- dsipts/models/informer/attn.py +185 -0
- dsipts/models/informer/decoder.py +50 -0
- dsipts/models/informer/embed.py +125 -0
- dsipts/models/informer/encoder.py +100 -0
- dsipts/models/itransformer/Embed.py +142 -0
- dsipts/models/itransformer/SelfAttention_Family.py +355 -0
- dsipts/models/itransformer/Transformer_EncDec.py +134 -0
- dsipts/models/itransformer/__init__.py +0 -0
- dsipts/models/patchtst/__init__.py +0 -0
- dsipts/models/patchtst/layers.py +569 -0
- dsipts/models/samformer/__init__.py +0 -0
- dsipts/models/samformer/utils.py +154 -0
- dsipts/models/tft/__init__.py +0 -0
- dsipts/models/tft/sub_nn.py +234 -0
- dsipts/models/timexer/Layers.py +127 -0
- dsipts/models/timexer/__init__.py +0 -0
- dsipts/models/ttm/__init__.py +0 -0
- dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
- dsipts/models/ttm/consts.py +16 -0
- dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
- dsipts/models/ttm/utils.py +438 -0
- dsipts/models/utils.py +624 -0
- dsipts/models/vva/__init__.py +0 -0
- dsipts/models/vva/minigpt.py +83 -0
- dsipts/models/vva/vqvae.py +459 -0
- dsipts/models/xlstm/__init__.py +0 -0
- dsipts/models/xlstm/xLSTM.py +255 -0
- dsipts-1.1.5.dist-info/METADATA +31 -0
- dsipts-1.1.5.dist-info/RECORD +81 -0
- dsipts-1.1.5.dist-info/WHEEL +5 -0
- dsipts-1.1.5.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,438 @@
|
|
|
1
|
+
# Default
|
|
2
|
+
import os
|
|
3
|
+
import yaml
|
|
4
|
+
import enum
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import pandas as pd
|
|
8
|
+
from importlib import resources
|
|
9
|
+
from typing import Optional, Union
|
|
10
|
+
from pandas.tseries.frequencies import to_offset
|
|
11
|
+
|
|
12
|
+
# Custom
|
|
13
|
+
from .configuration_tinytimemixer import TinyTimeMixerConfig
|
|
14
|
+
from .modeling_tinytimemixer import TinyTimeMixerForPrediction
|
|
15
|
+
from .consts import DEFAULT_FREQUENCY_MAPPING, TTM_LOW_RESOLUTION_MODELS_MAX_CONTEXT
|
|
16
|
+
|
|
17
|
+
# Hugging face
|
|
18
|
+
from transformers import PreTrainedModel
|
|
19
|
+
|
|
20
|
+
# PyTorch
|
|
21
|
+
import torch
|
|
22
|
+
|
|
23
|
+
# TODO: fix — this mapping is hard-coded; it should be loaded from the packaged
# YAML instead (see the commented-out loading code in `get_model`).
# Catalog of released TTM checkpoints, keyed by model family:
#   'ibm-granite-models'  -> checkpoints hosted on the ibm-granite HF cards
#   'research-use-models' -> checkpoints hosted on the ibm-research HF card
# Each inner key is a "<context>-<horizon>[-ft][-l1]-<release>" identifier and
# records the HF model card, the git revision to pull, and the model's
# context/prediction lengths (used by `get_model` to pick a checkpoint).
TTM_CONF = {
    'ibm-granite-models': {
        '512-96-r1': {'release': 'r1', 'model_card': 'ibm-granite/granite-timeseries-ttm-r1', 'revision': 'main', 'context_length': 512, 'prediction_length': 96},
        '1024-96-r1': {'release': 'r1', 'model_card': 'ibm-granite/granite-timeseries-ttm-r1', 'revision': '1024_96_v1', 'context_length': 1024, 'prediction_length': 96},
        '512-96-r2': {'release': 'r2', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': 'main', 'context_length': 512, 'prediction_length': 96},
        '512-192-r2': {'release': 'r2', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '512-192-r2', 'context_length': 512, 'prediction_length': 192},
        '512-336-r2': {'release': 'r2', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '512-336-r2', 'context_length': 512, 'prediction_length': 336},
        '512-720-r2': {'release': 'r2', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '512-720-r2', 'context_length': 512, 'prediction_length': 720},
        '1024-96-r2': {'release': 'r2', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '1024-96-r2', 'context_length': 1024, 'prediction_length': 96},
        '1024-192-r2': {'release': 'r2', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '1024-192-r2', 'context_length': 1024, 'prediction_length': 192},
        '1024-336-r2': {'release': 'r2', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '1024-336-r2', 'context_length': 1024, 'prediction_length': 336},
        '1024-720-r2': {'release': 'r2', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '1024-720-r2', 'context_length': 1024, 'prediction_length': 720},
        '1536-96-r2': {'release': 'r2', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '1536-96-r2', 'context_length': 1536, 'prediction_length': 96},
        '1536-192-r2': {'release': 'r2', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '1536-192-r2', 'context_length': 1536, 'prediction_length': 192},
        '1536-336-r2': {'release': 'r2', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '1536-336-r2', 'context_length': 1536, 'prediction_length': 336},
        '1536-720-r2': {'release': 'r2', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '1536-720-r2', 'context_length': 1536, 'prediction_length': 720},
        '52-16-ft-r2.1': {'release': 'r2.1', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '52-16-ft-r2.1', 'context_length': 52, 'prediction_length': 16},
        '52-16-ft-l1-r2.1': {'release': 'r2.1', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '52-16-ft-l1-r2.1', 'context_length': 52, 'prediction_length': 16},
        '90-30-ft-r2.1': {'release': 'r2.1', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '90-30-ft-r2.1', 'context_length': 90, 'prediction_length': 30},
        '90-30-ft-l1-r2.1': {'release': 'r2.1', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '90-30-ft-l1-r2.1', 'context_length': 90, 'prediction_length': 30},
        '180-60-ft-l1-r2.1': {'release': 'r2.1', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '180-60-ft-l1-r2.1', 'context_length': 180, 'prediction_length': 60},
        '360-60-ft-l1-r2.1': {'release': 'r2.1', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '360-60-ft-l1-r2.1', 'context_length': 360, 'prediction_length': 60},
        '512-48-ft-r2.1': {'release': 'r2.1', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '512-48-ft-r2.1', 'context_length': 512, 'prediction_length': 48},
        '512-48-ft-l1-r2.1': {'release': 'r2.1', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '512-48-ft-l1-r2.1', 'context_length': 512, 'prediction_length': 48},
        '512-96-ft-r2.1': {'release': 'r2.1', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '512-96-ft-r2.1', 'context_length': 512, 'prediction_length': 96},
        '512-96-ft-l1-r2.1': {'release': 'r2.1', 'model_card': 'ibm-granite/granite-timeseries-ttm-r2', 'revision': '512-96-ft-l1-r2.1', 'context_length': 512, 'prediction_length': 96},
    },
    'research-use-models': {
        '512-96-ft-r2': {'release': 'r2', 'model_card': 'ibm-research/ttm-research-r2', 'revision': 'main', 'context_length': 512, 'prediction_length': 96},
        '512-192-ft-r2': {'release': 'r2', 'model_card': 'ibm-research/ttm-research-r2', 'revision': '512-192-ft-r2', 'context_length': 512, 'prediction_length': 192},
        '512-336-ft-r2': {'release': 'r2', 'model_card': 'ibm-research/ttm-research-r2', 'revision': '512-336-ft-r2', 'context_length': 512, 'prediction_length': 336},
        '512-720-ft-r2': {'release': 'r2', 'model_card': 'ibm-research/ttm-research-r2', 'revision': '512-720-ft-r2', 'context_length': 512, 'prediction_length': 720},
        '1024-96-ft-r2': {'release': 'r2', 'model_card': 'ibm-research/ttm-research-r2', 'revision': '1024-96-ft-r2', 'context_length': 1024, 'prediction_length': 96},
        '1024-192-ft-r2': {'release': 'r2', 'model_card': 'ibm-research/ttm-research-r2', 'revision': '1024-192-ft-r2', 'context_length': 1024, 'prediction_length': 192},
        '1024-336-ft-r2': {'release': 'r2', 'model_card': 'ibm-research/ttm-research-r2', 'revision': '1024-336-ft-r2', 'context_length': 1024, 'prediction_length': 336},
        '1024-720-ft-r2': {'release': 'r2', 'model_card': 'ibm-research/ttm-research-r2', 'revision': '1024-720-ft-r2', 'context_length': 1024, 'prediction_length': 720},
        '1536-96-ft-r2': {'release': 'r2', 'model_card': 'ibm-research/ttm-research-r2', 'revision': '1536-96-ft-r2', 'context_length': 1536, 'prediction_length': 96},
        '1536-192-ft-r2': {'release': 'r2', 'model_card': 'ibm-research/ttm-research-r2', 'revision': '1536-192-ft-r2', 'context_length': 1536, 'prediction_length': 192},
        '1536-336-ft-r2': {'release': 'r2', 'model_card': 'ibm-research/ttm-research-r2', 'revision': '1536-336-ft-r2', 'context_length': 1536, 'prediction_length': 336},
        '1536-720-ft-r2': {'release': 'r2', 'model_card': 'ibm-research/ttm-research-r2', 'revision': '1536-720-ft-r2', 'context_length': 1536, 'prediction_length': 720},
    },
}
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ForceReturn(enum.Enum):
    """Allowed values of the `force_return` argument of `get_model`.

    ZEROPAD: select a pre-trained TTM whose context length is longer than the
        requested one; the caller must zero-pad inputs up to that context.
    ROLLING: select a pre-trained TTM whose prediction length is shorter than
        the requested one; the caller must forecast recursively (e.g. with a
        `RecursivePredictor`) to reach the desired horizon.
    RANDOM_INIT_SMALL / RANDOM_INIT_MEDIUM / RANDOM_INIT_LARGE: build a
        randomly initialized TTM of the given size; it must be trained before
        it can be used for inference.
    """

    ZEROPAD = "zeropad"
    ROLLING = "rolling"
    RANDOM_INIT_SMALL = "random_init_small"
    RANDOM_INIT_MEDIUM = "random_init_medium"
    RANDOM_INIT_LARGE = "random_init_large"
|
|
46
|
+
|
|
47
|
+
class ModelSize(enum.Enum):
    """Allowed TTM sizes for the `size` parameter of `get_random_ttm`."""

    SMALL = "small"
    MEDIUM = "medium"
    LARGE = "large"
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def check_ttm_model_path(model_path):
    """Classify a TTM model path by the model-card family it refers to.

    Returns:
        int: 1 for first-release granite cards (r1/v1/1m or the legacy
        `ibm/TTM` card), 2 for the granite r2 card, 3 for the research r2
        card, and 0 when the path matches no known hosted card (e.g. a
        local checkpoint directory).
    """
    release_one_markers = (
        "ibm/TTM",
        "ibm-granite/granite-timeseries-ttm-r1",
        "ibm-granite/granite-timeseries-ttm-v1",
        "ibm-granite/granite-timeseries-ttm-1m",
    )
    if any(marker in model_path for marker in release_one_markers):
        return 1
    if "ibm-granite/granite-timeseries-ttm-r2" in model_path:
        return 2
    if "ibm-research/ttm-research-r2" in model_path:
        return 3
    return 0
|
|
69
|
+
|
|
70
|
+
def count_parameters(model: torch.nn.Module) -> int:
    """Count trainable parameters in a model.

    Args:
        model (torch.nn.Module): The model.

    Returns:
        int: Number of parameters requiring gradients.
    """
    total = 0
    for param in model.parameters():
        # Only parameters that will receive gradients count as trainable.
        if param.requires_grad:
            total += param.numel()
    return total
|
|
80
|
+
|
|
81
|
+
def get_random_ttm(
    context_length: int, prediction_length: int, size: str = ModelSize.SMALL.value, **kwargs
) -> PreTrainedModel:
    """Get a TTM with random weights.

    Args:
        context_length (int): Context length or history.
        prediction_length (int): Prediction length or forecast horizon.
        size (str, optional): Size of the desired TTM (small/medium/large). Defaults to "small".
            Matching is by substring, so values such as "random_init_small" are accepted too.
        **kwargs: Extra keyword arguments forwarded to `TinyTimeMixerConfig`.

    Raises:
        ValueError: If wrong size is provided.
        ValueError: Context length should be at least 4 if `size=small`,
            or at least 16 if `size=medium`,
            or at least 32 if `size=large`.

    Returns:
        PreTrainedModel: TTM model with randomly initialized weights.
    """
    # Resolve the size exactly once: minimum usable context length, adaptive
    # patching levels (apl) and number of mixer layers. The previous version
    # validated `size` in two duplicated if/elif chains; the second chain's
    # `else: raise ValueError(...)` was unreachable because the first chain had
    # already raised for any unknown size. Collapsing the chains keeps behavior
    # identical while removing the dead branch.
    size_lower = size.lower()
    if ModelSize.SMALL.value in size_lower:
        cl_lower_bound, apl, num_layers = 4, 0, 3
    elif ModelSize.MEDIUM.value in size_lower:
        cl_lower_bound, apl, num_layers = 16, 3, 3
    elif ModelSize.LARGE.value in size_lower:
        cl_lower_bound, apl, num_layers = 32, 5, 5
    else:
        raise ValueError("Wrong size. Should be either of these [small/medium/large].")

    if context_length < cl_lower_bound:
        raise ValueError(f"Context length should be at least {cl_lower_bound} if `size={size}`.")

    # Use an even context length so a power-of-two patch length can divide it.
    cl = context_length if context_length % 2 == 0 else context_length - 1

    # Grow the patch length in powers of two while it still divides `cl` and
    # leaves at least 8 patches. (The final doubling may overshoot; behavior
    # intentionally matches the original heuristic.)
    pl = 2
    while cl % pl == 0 and cl / pl >= 8:
        pl = pl * 2

    # Small models scale the hidden size with the patch length; medium and
    # large models use the fixed 16 * 2**apl rule (128 and 512 respectively).
    d_model = 2 * pl if ModelSize.SMALL.value in size_lower else 16 * 2**apl

    ttm_config = TinyTimeMixerConfig(
        context_length=cl,
        prediction_length=prediction_length,
        patch_length=pl,
        patch_stride=pl,
        d_model=d_model,
        num_layers=num_layers,
        decoder_num_layers=2,
        decoder_d_model=d_model,
        adaptive_patching_levels=apl,
        dropout=0.2,
        **kwargs,
    )
    model = TinyTimeMixerForPrediction(config=ttm_config)

    return model
|
|
148
|
+
|
|
149
|
+
def get_frequency_token(token_name: str):
    """Map a frequency identifier to its TTM frequency-prefix token tensor.

    Resolution order:
      1. direct lookup of `token_name` in `DEFAULT_FREQUENCY_MAPPING`;
      2. normalize `token_name` through `to_offset(...).freqstr` and retry;
      3. on a `ValueError` from step 2, parse it as a `pd.Timedelta`, convert
         that to an offset string, and retry;
      4. fall back to the out-of-vocabulary ("oov") token.

    Args:
        token_name (str): Frequency name, pandas offset alias, or a string
            parseable as a timedelta.

    Returns:
        torch.Tensor: Scalar int tensor holding the frequency token.
    """
    token = DEFAULT_FREQUENCY_MAPPING.get(token_name, None)
    if token is not None:
        return torch.tensor(token, dtype=torch.int)

    # try to map as a frequency string
    try:
        token_name_offs = to_offset(token_name).freqstr
        token = DEFAULT_FREQUENCY_MAPPING.get(token_name_offs, None)
        if token is not None:
            return torch.tensor(token, dtype=torch.int)
    except ValueError:
        # lastly try to map the timedelta to a frequency string.
        # Use the public `pd.Timedelta` instead of the private
        # `pd._libs.tslibs.timedeltas.Timedelta` — pandas re-exports the same
        # class, and only the public name is a stable API.
        token_name_td = pd.Timedelta(token_name)
        token_name_offs = to_offset(token_name_td).freqstr
        token = DEFAULT_FREQUENCY_MAPPING.get(token_name_offs, None)
        if token is not None:
            return torch.tensor(token, dtype=torch.int)

    token = DEFAULT_FREQUENCY_MAPPING["oov"]

    return torch.tensor(token, dtype=torch.int)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class RMSELoss(torch.nn.Module):
    """Root-mean-squared-error loss: ``sqrt(MSELoss(yhat, y))``."""

    def __init__(self):
        super().__init__()
        # Delegate the mean-squared-error part to the stock torch criterion.
        self.mse = torch.nn.MSELoss()

    def forward(self, yhat, y):
        mean_squared_error = self.mse(yhat, y)
        return torch.sqrt(mean_squared_error)
|
|
180
|
+
|
|
181
|
+
def get_model(
    model_path: str,
    model_name: str = "ttm",
    context_length: Optional[int] = None,
    prediction_length: Optional[int] = None,
    freq_prefix_tuning: bool = False,
    freq: Optional[str] = None,
    prefer_l1_loss: bool = False,
    prefer_longer_context: bool = True,
    force_return: Optional[str] = None,
    return_model_key: bool = False,
    **kwargs,
) -> Union[str, PreTrainedModel]:
    """TTM Model card offers a suite of models with varying `context_length` and `prediction_length` combinations.
    This wrapper automatically selects the right model based on the given input `context_length` and
    `prediction_length` abstracting away the internal complexity.

    Args:
        model_path (str): HuggingFace model card path or local model path (Ex. ibm-granite/granite-timeseries-ttm-r2)
        model_name (str, optional): Model name to use. Current allowed values: [ttm]. Defaults to "ttm".
        context_length (int, optional): Input Context length or history. Defaults to None.
        prediction_length (int, optional): Length of the forecast horizon. Defaults to None.
        freq_prefix_tuning (bool, optional): If true, it will prefer TTM models that are trained with frequency prefix
            tuning configuration. Defaults to None.
        freq (str, optional): Resolution or frequency of the data. Defaults to None. Allowed values are as per the
            `tsfm_public.toolkit.time_series_preprocessor.DEFAULT_FREQUENCY_MAPPING`.
            See this for details: https://github.com/ibm-granite/granite-tsfm/blob/main/tsfm_public/toolkit/time_series_preprocessor.py.
        prefer_l1_loss (bool, optional): If True, it will prefer choosing models that were trained with L1 loss or
            mean absolute error loss. Defaults to False.
        prefer_longer_context (bool, optional): If True, it will prefer selecting model with longer context/history
            Defaults to True.
        force_return (str, optional): This is used to force the get_model() to return a TTM model even when the provided
            configurations don't match with the existing TTMs. It gets the closest TTM possible. Allowed values are
            ["zeropad"/"rolling"/"random_init_small"/"random_init_medium"/"random_init_large"/`None`].
            "zeropad" = Returns a pre-trained TTM that has a context length higher than the input context length, hence,
            the user must apply zero-padding to use the returned model.
            "rolling" = Returns a pre-trained TTM that has a prediction length lower than the requested prediction length,
            hence, the user must apply rolling technique to use the returned model to forecast to the desired length.
            The `RecursivePredictor` class can be utilized in this scenario.
            "random_init_small" = Returns a randomly initialized small TTM which must be trained before performing inference.
            "random_init_medium" = Returns a randomly initialized medium TTM which must be trained before performing inference.
            "random_init_large" = Returns a randomly initialized large TTM which must be trained before performing inference.
            `None` = `force_return` is disabled. Raises an error if no suitable model is found.
            Defaults to None.
        return_model_key (bool, optional): If True, only the TTM model name will be returned, instead of the actual model.
            This does not download the model, and only returns the name of the suitable model. Defaults to False.

    Returns:
        Union[str, PreTrainedModel]: Returns the Model, or the model name.
    """
    if model_name.lower() == "ttm":
        model_path_type = check_ttm_model_path(model_path)
        prediction_filter_length = None
        ttm_model_revision = None
        # model_path_type == 0 means a local checkpoint: skip all catalog
        # selection and load the path directly.
        if model_path_type != 0:
            if context_length is None or prediction_length is None:
                raise ValueError(
                    "Provide `context_length` and `prediction_length` when `model_path` is a hugginface model path."
                )

            # Get freq
            # R is the numeric resolution code of `freq` (0 when unknown);
            # used below to restrict daily-or-coarser data to r2.1 models.
            R = DEFAULT_FREQUENCY_MAPPING.get(freq, 0)

            # Get list of all TTM models
            '''
            config_dir = resources.files("tsfm_public.resources.model_paths_config")
            with open(os.path.join(config_dir, "ttm.yaml"), "r") as file:
                model_revisions = yaml.safe_load(file)
            '''
            model_revisions = TTM_CONF ##TODO fix this: restore loading from the packaged YAML above
            # Narrow the catalog to the family the path refers to, keeping only
            # the matching release generation for granite cards.
            if model_path_type == 1 or model_path_type == 2:
                available_models = model_revisions["ibm-granite-models"]
                filtered_models = {}
                if model_path_type == 1:
                    for k in available_models.keys():
                        if available_models[k]["release"].startswith("r1"):
                            filtered_models[k] = available_models[k]
                if model_path_type == 2:
                    for k in available_models.keys():
                        # NOTE: "r2" also matches "r2.1" releases by design of startswith.
                        if available_models[k]["release"].startswith("r2"):
                            filtered_models[k] = available_models[k]
                available_models = filtered_models
            else:
                available_models = model_revisions["research-use-models"]

            # Calculate shortest TTM context length, will be needed later
            available_model_keys = list(available_models.keys())
            available_ttm_context_lengths = [available_models[m]["context_length"] for m in available_model_keys]
            shortest_ttm_context_length = min(available_ttm_context_lengths)

            # Step 1: Filter models based on freq (R)
            if model_path_type == 1 or model_path_type == 2:
                # Only, r2.1 models are suitable for Daily or longer freq
                # (R >= 8 corresponds to daily-or-coarser resolutions in the mapping
                # — TODO confirm against DEFAULT_FREQUENCY_MAPPING).
                if R >= 8:
                    models = [m for m in available_models.keys() if "r2.1" in available_models[m]["release"]]
                else:
                    models = list(available_models.keys())
            else:
                models = list(available_models.keys())

            # Step 2: Filter models by context length constraint
            # Choose all models which have lower context length than
            # the input available length
            selected_models_ = []
            if context_length < shortest_ttm_context_length:
                if force_return is None:
                    raise ValueError(
                        "Requested context length is less than the "
                        f"shortest context length for TTMs: {shortest_ttm_context_length}. "
                        "Set `force_return=zeropad` to get a TTM with longer context."
                    )
                elif force_return == ForceReturn.ZEROPAD.value:  # force_return.startswith("zero"):
                    # Keep all models. Zero-padding must be done outside.
                    selected_models_ = models
            else:
                # Track the models seen with the (running) shortest context so
                # they can serve as the zeropad fallback below.
                lowest_context_length = np.inf
                shortest_context_models = []
                for m in models:
                    if available_models[m]["context_length"] <= context_length:
                        selected_models_.append(m)
                    if available_models[m]["context_length"] <= lowest_context_length:
                        lowest_context_length = available_models[m]["context_length"]
                        shortest_context_models.append(m)

                if len(selected_models_) == 0:
                    if force_return is None:
                        raise ValueError(
                            "Could not find a TTM with `context_length` shorter "
                            f"than the requested context length = {context_length}. "
                            "Set `force_return=zeropad` to get a TTM with longer context."
                        )
                    elif force_return == ForceReturn.ZEROPAD.value:  # force_return.startswith("zero"):
                        selected_models_ = shortest_context_models
            models = selected_models_

            # Step 3: Apply L1 and FT preferences only when context_length <= 512
            if len(models) > 0:
                # reference_context is the context length the preference rules
                # will effectively operate at.
                if prefer_longer_context:
                    reference_context = min(
                        context_length, max([available_models[m]["context_length"] for m in models])
                    )
                else:
                    reference_context = min([available_models[m]["context_length"] for m in models])
                if reference_context <= TTM_LOW_RESOLUTION_MODELS_MAX_CONTEXT:
                    # Step 3a: Filter based on L1 preference
                    if prefer_l1_loss:
                        l1_models = [m for m in models if "-l1-" in m]
                        if l1_models:
                            models = l1_models

                    # Step 3b: Filter based on frequency tuning indicator preference
                    if freq_prefix_tuning:
                        ft_models = [m for m in models if "-ft-" in m]
                        if ft_models:
                            models = ft_models

            # Step 4: Sort models by context length (descending if prefer_longer_context else ascending)
            # Step 5: Sub-sort for each context length by forecast length in ascending order
            if len(models) > 0:
                sign = -1 if prefer_longer_context else 1
                models = sorted(
                    models,
                    key=lambda m: (
                        sign * int(available_models[m]["context_length"]),
                        int(available_models[m]["prediction_length"]),
                    ),
                )

            # Step 6: Remove models whose forecast length is less than input forecast length
            # Because, this needs recursion which has to be handled outside this get_model() utility
            if len(models) > 0:
                selected_models_ = []
                highest_prediction_length = -np.inf
                highest_prediction_model = None
                for m in models:
                    if int(available_models[m]["prediction_length"]) >= prediction_length:
                        selected_models_.append(m)
                    if available_models[m]["prediction_length"] > highest_prediction_length:
                        highest_prediction_length = available_models[m]["prediction_length"]
                        highest_prediction_model = m
                if len(selected_models_) == 0:
                    if force_return is None:
                        raise ValueError(
                            "Could not find a TTM with `prediction_length` higher "
                            f"than the requested prediction length = {prediction_length}. "
                            "Set `force_return=rolling` to get a TTM with shorted prediction "
                            "length. Rolling must be done outside."
                        )
                    elif force_return == ForceReturn.ROLLING.value:  # force_return.startswith("roll"):
                        selected_models_.append(highest_prediction_model)
                models = selected_models_

            # Step 7: Do not allow unknown frequency
            if freq_prefix_tuning and (freq is not None) and (freq not in DEFAULT_FREQUENCY_MAPPING.keys()):
                models = []

            # Step 8: Return the first available model or a dummy model if none found
            if len(models) == 0:
                if force_return is None:
                    raise ValueError(
                        "No suitable pre-trained TTM was found! Set `force_return` to "
                        "random_init_small/random_init_medium/random_init_large "
                        "to get a randomly initialized TTM of size small/medium/large "
                        "respectively."
                    )
                elif force_return in [
                    ForceReturn.RANDOM_INIT_SMALL.value,
                    ForceReturn.RANDOM_INIT_MEDIUM.value,
                    ForceReturn.RANDOM_INIT_LARGE.value,
                ]:  # "sma" in force_return.lower() or "med" in force_return.lower() or "lar" in force_return.lower():
                    # get_random_ttm matches the size by substring, so passing
                    # the full "random_init_*" string works.
                    model = get_random_ttm(context_length, prediction_length, size=force_return)
                    if return_model_key:
                        model_key = force_return.split("_")[-1]
                        return f"TTM({model_key})"
                    else:
                        return model
                else:
                    raise ValueError(
                        "Could not find a suitable TTM for the given "
                        f"context_length = {context_length}, and "
                        f"prediction_length = {prediction_length}. "
                        "Check the model card for more information. "
                        "set `force_return` properly (see the docstrings) "
                        "if you want to get a randomly initialized TTM."
                    )
            else:
                # Sort order above makes models[0] the preferred candidate.
                model_key = models[0]

            # selected_context_length = available_models[model_key]["context_length"]
            selected_prediction_length = available_models[model_key]["prediction_length"]
            # Truncate the model's forecast to the requested horizon when the
            # checkpoint predicts further than needed.
            if selected_prediction_length > prediction_length:
                prediction_filter_length = prediction_length

            # if selected_prediction_length < prediction_length:
            #     LOGGER.warning(
            #         "Selected `prediction_length` is shorter than the requested "
            #         "length since no suitable model could be found. You can use "
            #         " `RecursivePredictor` for forecast to the desired length."
            #     )

            ttm_model_revision = available_models[model_key]["revision"]

        else:
            prediction_filter_length = prediction_length

        # NOTE(review): on the local-path branch (model_path_type == 0)
        # `model_key` is never assigned, so `return_model_key=True` would raise
        # NameError here — confirm intended usage.
        if return_model_key:
            return model_key
        # Load model
        model = TinyTimeMixerForPrediction.from_pretrained(
            model_path,
            revision=ttm_model_revision,
            prediction_filter_length=prediction_filter_length,
            **kwargs,
        )
    else:
        raise ValueError("Currently supported values for `model_name` = 'ttm'.")

    return model
|