tirex-mirror 2025.10.2__tar.gz → 2025.10.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {tirex_mirror-2025.10.2/src/tirex_mirror.egg-info → tirex_mirror-2025.10.7}/PKG-INFO +1 -1
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/pyproject.toml +1 -1
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex/api_adapter/forecast.py +163 -5
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex/api_adapter/gluon.py +2 -2
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex/api_adapter/hf_data.py +2 -2
- tirex_mirror-2025.10.7/src/tirex/api_adapter/standard_adapter.py +90 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex/base.py +26 -5
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex/models/slstm/cell.py +66 -77
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex/models/tirex.py +17 -0
- tirex_mirror-2025.10.7/src/tirex/util.py +617 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7/src/tirex_mirror.egg-info}/PKG-INFO +1 -1
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex_mirror.egg-info/SOURCES.txt +2 -1
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/tests/test_forecast.py +17 -6
- tirex_mirror-2025.10.7/tests/test_standard_adapter.py +166 -0
- tirex_mirror-2025.10.7/tests/test_util_freq.py +112 -0
- tirex_mirror-2025.10.2/src/tirex/api_adapter/standard_adapter.py +0 -67
- tirex_mirror-2025.10.2/src/tirex/util.py +0 -13
- tirex_mirror-2025.10.2/tests/test_standard_adapter.py +0 -183
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/LICENSE +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/LICENSE_MIRROR.txt +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/MANIFEST.in +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/NOTICE.txt +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/README.md +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/setup.cfg +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex/__init__.py +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex/api_adapter/__init__.py +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex/models/__init__.py +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex/models/patcher.py +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex/models/slstm/block.py +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex/models/slstm/layer.py +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex_mirror.egg-info/dependency_links.txt +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex_mirror.egg-info/requires.txt +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex_mirror.egg-info/top_level.txt +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/tests/test_chronos_zs.py +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/tests/test_forecast_adapter.py +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/tests/test_slstm_torch_vs_cuda.py +0 -0
src/tirex/api_adapter/forecast.py:

```diff
@@ -2,15 +2,22 @@
 # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
 
 from abc import ABC, abstractmethod
-from …
+from functools import partial
+from math import ceil
+from typing import Literal, Optional
 
 import torch
 
+from tirex.util import frequency_resample
+
 from .standard_adapter import ContextType, get_batches
 
 DEF_TARGET_COLUMN = "target"
 DEF_META_COLUMNS = ("start", "item_id")
 
+# Allowed resampling strategies (extend as new strategies are implemented)
+RESAMPLE_STRATEGIES: list[str] = ["frequency"]
+
 
 def _format_output(
     quantiles: torch.Tensor,
```
```diff
@@ -33,6 +40,27 @@ def _format_output(
     raise ValueError(f"Invalid output type: {output_type}")
 
 
+def _pad_time_series_batch(
+    batch_series: list[torch.Tensor],
+    max_length: int,
+) -> torch.Tensor:
+    if not batch_series:
+        return torch.empty((0, max_length))
+
+    first = batch_series[0]
+    dtype = first.dtype if first.is_floating_point() else torch.float32
+    device = first.device
+
+    padded = torch.full((len(batch_series), max_length), float("nan"), dtype=dtype, device=device)
+
+    for idx, series in enumerate(batch_series):
+        series = series.to(padded.dtype)
+        series_len = series.shape[0]
+        padded[idx, max_length - series_len :] = series
+
+    return padded
+
+
 def _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs):
     for batch_ctx, batch_meta in batches:
         quantiles, mean = fc_func(batch_ctx, **predict_kwargs)
```
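`_pad_time_series_batch` right-aligns every series in the batch and fills the missing left context with NaN. A minimal standalone sketch of the same padding convention (not the package code itself):

```python
import torch

# Hedged sketch: same convention as _pad_time_series_batch above;
# rows are right-aligned, missing history on the left becomes NaN.
def pad_left_nan(series: list[torch.Tensor]) -> torch.Tensor:
    max_len = max(s.shape[0] for s in series)
    out = torch.full((len(series), max_len), float("nan"))
    for i, s in enumerate(series):
        out[i, max_len - s.shape[0]:] = s.to(out.dtype)
    return out

batch = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([9.0])]
print(pad_left_nan(batch))
# tensor([[1., 2., 3.],
#         [nan, nan, 9.]])
```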
```diff
@@ -45,7 +73,105 @@ def _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs):
         )
 
 
-def …
+def _call_fc_with_padding(fc_func, batch_series: list[torch.Tensor], **predict_kwargs):
+    if not batch_series:
+        raise ValueError("Received empty batch for forecasting")
+
+    max_len = max(series.shape[0] for series in batch_series)
+    padded_ts = _pad_time_series_batch(batch_series, max_len)
+
+    return fc_func(padded_ts, **predict_kwargs)
+
+
+def _resample_fc_func_wrapper(
+    fc_func,
+    batch,
+    resample_strategy: str,
+    max_context: int = 2016,
+    **predict_kwargs,
+):
+    # downsample the time series based on the dominant frequencies, if enabled
+    max_period = (max_context // 1000) * 500
+    prediction_length = predict_kwargs.get("prediction_length", 100)
+    batch_resampled_ts: list[torch.Tensor] = []
+    fc_resample_fns = []
+    scaling_factors = []
+
+    # select the function doing the resampling
+    ctx_resample_fn = lambda x: (x, 1.0, (lambda y: y))
+    match resample_strategy:
+        case "frequency":
+            ctx_resample_fn = frequency_resample
+        case _:
+            raise RuntimeError("This shouldn't happen.")
+
+    for series in batch:
+        resampled_ts, _sample_factor, fc_resample_fn = ctx_resample_fn(
+            series,
+            prediction_length=prediction_length,
+            max_period=max_period,
+        )
+
+        batch_resampled_ts.append(resampled_ts)
+        fc_resample_fns.append(fc_resample_fn)
+        scaling_factors.append(_sample_factor)
+
+    # Compute per-item required horizons (in downsampled domain)
+    per_item_pred_lens = [int(ceil(prediction_length * sf)) for sf in scaling_factors]
+    max_pred_len = max(per_item_pred_lens) if per_item_pred_lens else int(prediction_length)
+    predict_kwargs.update(prediction_length=max_pred_len)
+
+    max_ts_length = max(ts.shape[0] for ts in batch_resampled_ts)
+    padded_ts = _pad_time_series_batch(batch_resampled_ts, max_ts_length)
+    print(f"Average sample batch factor: {sum(scaling_factors) / len(scaling_factors)}")
+
+    # generate prediction
+    fc_quantiles, fc_mean = fc_func(padded_ts, **predict_kwargs)
+
+    batch_prediction_q = []
+    batch_prediction_m = []
+    for el_q, el_m, fc_resample_fn, item_pred_len in zip(fc_quantiles, fc_mean, fc_resample_fns, per_item_pred_lens):
+        # truncate the forecasts to their individual sample factor adjusted prediction lengths
+        el_q = el_q[:item_pred_len, ...]  # [T, Q]
+        el_m = el_m[:item_pred_len]  # [T]
+
+        # upsample prediction
+        quantiles = fc_resample_fn(el_q.squeeze(0).transpose(0, 1)).transpose(0, 1)  # [T, Q]
+        mean = fc_resample_fn(el_m.squeeze(0))
+
+        quantiles = quantiles[:prediction_length, ...]
+        mean = mean[:prediction_length]
+
+        batch_prediction_q.append(quantiles)
+        batch_prediction_m.append(mean)
+
+    return torch.stack(batch_prediction_q, dim=0), torch.stack(batch_prediction_m, dim=0)
+
+
+def _gen_forecast(
+    fc_func,
+    batches,
+    output_type,
+    quantile_levels,
+    yield_per_batch,
+    resample_strategy: str | None = None,
+    max_context: int = 2016,
+    **predict_kwargs,
+):
+    base_fc_func = fc_func
+
+    if resample_strategy is not None:
+        if resample_strategy not in RESAMPLE_STRATEGIES:
+            raise ValueError(f"Invalid resample strategy: {resample_strategy}. Allowed: {RESAMPLE_STRATEGIES}")
+        fc_func = partial(
+            _resample_fc_func_wrapper,
+            base_fc_func,
+            resample_strategy=resample_strategy,
+            max_context=max_context,
+        )
+    else:
+        fc_func = partial(_call_fc_with_padding, base_fc_func)
+
     if yield_per_batch:
         return _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs)
 
```
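Concretely: with the default `max_context=2016`, `max_period = (2016 // 1000) * 500 = 1000`, so only dominant periods up to 1000 steps are candidates for downsampling. A series that `frequency_resample` downsamples with `_sample_factor = 0.5` at `prediction_length = 100` only needs `ceil(100 * 0.5) = 50` steps in the downsampled domain; the batch is forecast at the largest such per-item horizon, and each item's forecast is then upsampled back and truncated to the requested 100 steps.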
```diff
@@ -92,6 +218,9 @@ def _common_forecast_doc():
             forecasts batch by batch as they are computed.
             Defaults to `False`.
 
+        resample_strategy (Optional[str], optional): Choose a resampling strategy. Allowed values: {RESAMPLE_STRATEGIES}.
+            If `None`, no resampling is applied. Currently only "frequency" is supported.
+
         **predict_kwargs: Additional keyword arguments that are passed directly to the underlying
             prediction mechanism of the pre-trained model. Refer to the model's
             internal prediction method documentation for available options.
@@ -113,6 +242,11 @@ class ForecastModel(ABC):
     def _forecast_quantiles(self, batch, **predict_kwargs):
         pass
 
+    @property
+    def max_context_length(self) -> int:
+        # retrieve the max_context attribute of the model configuration if present
+        return getattr(getattr(self, "config", None), "max_context", 2016)
+
     def forecast(
         self,
         context: ContextType,
@@ -120,6 +254,7 @@ class ForecastModel(ABC):
         batch_size: int = 512,
         quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
         yield_per_batch: bool = False,
+        resample_strategy: Literal["frequency"] | None = None,
         **predict_kwargs,
     ):
         f"""
```
```diff
@@ -134,7 +269,14 @@ class ForecastModel(ABC):
         assert batch_size >= 1, "Batch size must be >= 1"
         batches = get_batches(context, batch_size)
         return _gen_forecast(
-            self._forecast_quantiles, …
+            self._forecast_quantiles,
+            batches,
+            output_type,
+            quantile_levels,
+            yield_per_batch,
+            resample_strategy=resample_strategy,
+            max_context=self.max_context_length,
+            **predict_kwargs,
         )
 
     def forecast_gluon(
```
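Assuming the package's usual entry point (`load_model` re-exported from `tirex`; model name and data here are illustrative), opting into the new path is a single keyword:

```python
import torch
from tirex import load_model  # assumed re-export of src/tirex/base.py:load_model

model = load_model("NX-AI/TiRex")
ctx = torch.randn(8, 2016)  # 8 series at the model's default max context

# unchanged default path: batches are NaN-padded and forecast directly
# (assuming the default quantile output, this returns (quantiles, mean))
quantiles, mean = model.forecast(ctx, prediction_length=96)

# new in 2025.10.7: frequency-based downsampling before, upsampling after
quantiles, mean = model.forecast(ctx, prediction_length=96, resample_strategy="frequency")
```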
```diff
@@ -144,6 +286,7 @@ class ForecastModel(ABC):
         batch_size: int = 512,
         quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
         yield_per_batch: bool = False,
+        resample_strategy: Literal["frequency"] | None = None,
         data_kwargs: dict = {},
         **predict_kwargs,
     ):
@@ -165,7 +308,14 @@ class ForecastModel(ABC):
 
         batches = get_gluon_batches(gluonDataset, batch_size, **data_kwargs)
         return _gen_forecast(
-            self._forecast_quantiles, …
+            self._forecast_quantiles,
+            batches,
+            output_type,
+            quantile_levels,
+            yield_per_batch,
+            resample_strategy=resample_strategy,
+            max_context=self.max_context_length,
+            **predict_kwargs,
         )
 
     def forecast_hfdata(
@@ -175,6 +325,7 @@ class ForecastModel(ABC):
         batch_size: int = 512,
         quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
         yield_per_batch: bool = False,
+        resample_strategy: Literal["frequency"] | None = None,
         data_kwargs: dict = {},
         **predict_kwargs,
     ):
@@ -198,5 +349,12 @@ class ForecastModel(ABC):
 
         batches = get_hfdata_batches(hf_dataset, batch_size, **data_kwargs)
         return _gen_forecast(
-            self._forecast_quantiles, …
+            self._forecast_quantiles,
+            batches,
+            output_type,
+            quantile_levels,
+            yield_per_batch,
+            resample_strategy=resample_strategy,
+            max_context=self.max_context_length,
+            **predict_kwargs,
         )
```
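The wrapper only relies on `frequency_resample` returning a `(resampled_series, sample_factor, inverse_fn)` triple; the actual implementation lives in the new `src/tirex/util.py` (+617 lines, not expanded here). A toy stand-in that satisfies the same contract, for illustration only:

```python
import torch

def naive_resample(series: torch.Tensor, prediction_length: int, max_period: int):
    """Toy stand-in for tirex.util.frequency_resample (illustrating the assumed contract):
    returns (resampled context, sampling factor, fn that upsamples a forecast back)."""
    factor = 0.5  # pretend the dominant period suggests keeping every 2nd point
    resampled = series[::2]

    def upsample(fc: torch.Tensor) -> torch.Tensor:
        # repeat-interleave back to the original rate along the time axis
        return fc.repeat_interleave(2, dim=-1)

    return resampled, factor, upsample
```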
src/tirex/api_adapter/gluon.py:

```diff
@@ -7,7 +7,7 @@ from gluonts.dataset.common import Dataset
 from gluonts.dataset.field_names import FieldName
 from gluonts.model.forecast import QuantileForecast
 
-from .standard_adapter import …
+from .standard_adapter import _batch_iterable
 
 DEF_TARGET_COLUMN = FieldName.TARGET  # target
 DEF_META_COLUMNS = (FieldName.START, FieldName.ITEM_ID)
@@ -27,7 +27,7 @@ def _get_gluon_ts_map(**gluon_kwargs):
 
 
 def get_gluon_batches(gluonDataset: Dataset, batch_size: int, **gluon_kwargs):
-    return …
+    return _batch_iterable(map(_get_gluon_ts_map(**gluon_kwargs), gluonDataset), batch_size)
 
 
 def format_gluonts_output(quantile_forecasts: torch.Tensor, mean_forecasts, meta: list[dict], quantile_levels):
```
src/tirex/api_adapter/hf_data.py:

```diff
@@ -4,7 +4,7 @@
 import datasets
 import torch
 
-from .standard_adapter import …
+from .standard_adapter import _batch_iterable
 
 DEF_TARGET_COLUMN = "target"
 
@@ -35,4 +35,4 @@ def _get_hf_map(dataset: datasets.Dataset, **hf_kwargs):
 
 def get_hfdata_batches(hf_dataset: datasets.Dataset, batch_size: int, **hf_kwargs):
     dataset, map_func = _get_hf_map(hf_dataset, **hf_kwargs)
-    return …
+    return _batch_iterable(map(map_func, dataset), batch_size)
```
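Both adapters now share one pattern: lazily `map` each dataset entry to a `(1-D tensor, meta dict)` pair and hand the iterator to `_batch_iterable` from the new `standard_adapter.py` (shown next). A self-contained sketch of that pattern with stand-in data, not the real GluonTS/HF mappers:

```python
import torch

def fake_dataset():
    # stand-in for a GluonTS/HF dataset: one (series, metadata) pair per item
    for i in range(5):
        yield torch.arange(4, dtype=torch.float32), {"item_id": i}

# _batch_iterable(fake_dataset(), batch_size=2) would yield batches like
#   ([series0, series1], [{"item_id": 0}, {"item_id": 1}])
# ending with a short final batch ([series4], [{"item_id": 4}]).
```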
src/tirex/api_adapter/standard_adapter.py (new file):

```diff
@@ -0,0 +1,90 @@
+# Copyright (c) NXAI GmbH.
+# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+import itertools
+from collections.abc import Iterable, Iterator, Sequence
+from typing import Union
+
+import numpy as np
+import torch
+
+ContextType = Union[
+    torch.Tensor,
+    np.ndarray,
+    list[torch.Tensor],
+    list[np.ndarray],
+]
+
+
+def _ensure_1d_tensor(sample) -> torch.Tensor:
+    if isinstance(sample, torch.Tensor):
+        tensor = sample
+    else:
+        tensor = torch.as_tensor(sample)
+
+    if tensor.ndim > 1:
+        tensor = tensor.squeeze()
+
+    assert tensor.ndim == 1, "Each sample must be one-dimensional"
+    return tensor
+
+
+def _batched_slice(
+    full_batch,
+    full_meta: list[dict] | None,
+    batch_size: int,
+) -> Iterator[tuple[list[torch.Tensor], list[dict]]]:
+    total = len(full_batch)
+    for start in range(0, total, batch_size):
+        batch = full_batch[start : start + batch_size]
+        meta = full_meta[start : start + batch_size] if full_meta is not None else [{} for _ in range(len(batch))]
+
+        batch_series = []
+        for idx in range(len(batch)):
+            sample = batch[idx]
+            tensor = _ensure_1d_tensor(sample)
+            batch_series.append(tensor)
+
+        yield batch_series, meta
+
+
+def _batched(iterable: Iterable, n: int):
+    it = iter(iterable)
+    while batch := tuple(itertools.islice(it, n)):
+        yield batch
+
+
+def _batch_iterable(
+    iterable: Iterable[tuple[torch.Tensor, dict | None]],
+    batch_size: int,
+) -> Iterator[tuple[list[torch.Tensor], list[dict]]]:
+    for batch in _batched(iterable, batch_size):
+        series_list: list[torch.Tensor] = []
+        meta_list: list[dict] = []
+
+        for sample, meta in batch:
+            tensor = _ensure_1d_tensor(sample)
+            assert len(tensor) > 0, "Each sample needs to have a length > 0"
+            series_list.append(tensor)
+            meta_list.append(meta if meta is not None else {})
+
+        yield series_list, meta_list
+
+
+def get_batches(context: ContextType, batch_size: int):
+    batches = None
+    if isinstance(context, torch.Tensor):
+        if context.ndim == 1:
+            context = context.unsqueeze(0)
+        assert context.ndim == 2
+        batches = _batched_slice(context, None, batch_size)
+    elif isinstance(context, np.ndarray):
+        if context.ndim == 1:
+            context = np.expand_dims(context, axis=0)
+        assert context.ndim == 2
+        batches = _batched_slice(context, None, batch_size)
+    elif isinstance(context, (list, Iterable)):
+        batches = _batch_iterable(map(lambda x: (torch.Tensor(x), None), context), batch_size)
+    if batches is None:
+        raise ValueError(f"Context type {type(context)} not supported! Supported Types: {ContextType}")
+    return batches
```
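`get_batches` accepts a 1-D or 2-D tensor/array or any iterable of 1-D series and normalizes everything to `(list of 1-D tensors, list of meta dicts)` batches. A quick behavioral check (import path taken from the file location above; data is illustrative):

```python
import numpy as np
import torch

from tirex.api_adapter.standard_adapter import get_batches  # path as in this diff

# 2-D tensor input: rows become 1-D series, meta defaults to empty dicts
series, meta = next(get_batches(torch.randn(5, 16), batch_size=2))
assert len(series) == 2 and series[0].shape == (16,) and meta == [{}, {}]

# ragged input: a list of differently sized 1-D arrays also works
series, _ = next(get_batches([np.ones(8), np.ones(3)], batch_size=4))
assert [len(s) for s in series] == [8, 3]
```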
src/tirex/base.py:

```diff
@@ -1,6 +1,7 @@
 # Copyright (c) NXAI GmbH.
 # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
 
+import logging
 import os
 from abc import ABC, abstractmethod
 from typing import Literal, TypeVar
@@ -8,7 +9,10 @@ from typing import Literal, TypeVar
 import torch
 from huggingface_hub import hf_hub_download
 
+from tirex.models.slstm.cell import sLSTMCellTorch
+
 T = TypeVar("T", bound="PretrainedModel")
+VERSION_DELIMITER = "-"
 
 
 def skip_cuda():
```
```diff
@@ -29,6 +33,17 @@ def parse_hf_repo_id(path):
     return "/".join(parts[0:2])
 
 
+def parse_model_string(model_string):
+    if VERSION_DELIMITER in model_string:
+        parts = model_string.split(VERSION_DELIMITER)
+        model_id, version = parts[0], parts[1]
+    else:
+        model_id = model_string
+        version = None
+
+    return model_id, version
+
+
 class PretrainedModel(ABC):
     REGISTRY: dict[str, "PretrainedModel"] = {}
 
```
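With `VERSION_DELIMITER = "-"`, an unversioned name passes through unchanged and a suffixed one is split at the delimiter (hedged sketch of the expected behavior; the strings are illustrative):

```python
assert parse_model_string("TiRex") == ("TiRex", None)
assert parse_model_string("TiRex-1.0") == ("TiRex", "1.0")
# names with more than one delimiter keep only the first two segments
```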
```diff
@@ -38,7 +53,7 @@ class PretrainedModel(ABC):
 
     @classmethod
     def from_pretrained(
-        cls: type[T], path: str, backend: str, device: str | None = None, hf_kwargs=None, ckp_kwargs=None
+        cls: type[T], path: str, backend: str, device: str | None = None, compile=False, hf_kwargs=None, ckp_kwargs=None
     ) -> T:
         if hf_kwargs is None:
             hf_kwargs = {}
```
```diff
@@ -58,9 +73,10 @@ class PretrainedModel(ABC):
         model: T = cls(backend=backend, **checkpoint["hyper_parameters"])
         model.on_load_checkpoint(checkpoint)
         model.load_state_dict(checkpoint["state_dict"])
+        model = model.to(device)
 
-        if backend == " …
-
+        if compile and backend == "torch":
+            sLSTMCellTorch.slstm_forward = torch.compile(sLSTMCellTorch.slstm_forward, mode="max-autotune")
         return model
 
     @classmethod
```
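`from_pretrained` rebinds the static method once at load time, so every subsequent torch-backend forward pass runs the compiled kernel. The same rebinding pattern in isolation (toy kernel, illustrative only):

```python
import torch

class Kernel:
    @staticmethod
    def step(x: torch.Tensor) -> torch.Tensor:
        return torch.tanh(x) * x

# one-time, process-wide rebinding, as from_pretrained does for
# sLSTMCellTorch.slstm_forward; later Kernel.step calls hit the compiled version
Kernel.step = torch.compile(Kernel.step, mode="max-autotune")
```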
```diff
@@ -76,6 +92,7 @@ def load_model(
     path: str,
     device: str | None = None,
     backend: Literal["torch", "cuda"] | None = None,
+    compile: bool = False,
     hf_kwargs=None,
     ckp_kwargs=None,
 ) -> PretrainedModel:
@@ -85,6 +102,7 @@ def load_model(
         path (str): Hugging Face path to the model (e.g. NX-AI/TiRex)
         device (str, optional): The device on which to load the model (e.g., "cuda:0", "cpu").
         backend (torch | cuda): What backend to use, torch or the custom CUDA kernels. Defaults to cuda when xlstm is installed, else torch.
+        compile (bool, optional): torch.compile the sLSTM cells; only works with the torch backend.
         hf_kwargs (dict, optional): Keyword arguments to pass to the Hugging Face Hub download method.
         ckp_kwargs (dict, optional): Keyword arguments to pass when loading the checkpoint.
 
@@ -99,11 +117,14 @@ def load_model(
         backend = "torch" if skip_cuda() or not xlstm_available() else "cuda"
 
     try:
-        _, …
+        _, model_string = parse_hf_repo_id(path).split("/")
+        model_id, version = parse_model_string(model_string)
     except:
         raise ValueError(f"Invalid model path {path}")
     model_cls = PretrainedModel.REGISTRY.get(model_id, None)
     if model_cls is None:
         raise ValueError(f"Invalid model id {model_id}")
 
-    return model_cls.from_pretrained( …
+    return model_cls.from_pretrained(
+        path, device=device, backend=backend, compile=compile, hf_kwargs=hf_kwargs, ckp_kwargs=ckp_kwargs
+    )
```
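End to end, the flag is threaded from `load_model` down to `from_pretrained` (repo id as in the docstring; the re-export path is assumed):

```python
from tirex import load_model  # assumed re-export of src/tirex/base.py:load_model

# compile only has an effect together with the pure-PyTorch backend
model = load_model("NX-AI/TiRex", backend="torch", compile=True, device="cpu")
```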
src/tirex/models/slstm/cell.py:

```diff
@@ -43,13 +43,11 @@ class sLSTMCell(nn.Module):
         state = self._get_state(input, state)
 
         if self.backend == "torch":
-            …
+            output, state = self._impl_torch(input, state)
         elif self.backend == "cuda":
-            …
+            output, state = self._impl_cuda(input, state)
 
-            …
-        output = self._permute_output(all_states[0][1:])
-        return output.to(input.dtype), state.to(input.dtype)
+        return self._permute_output(output).to(input.dtype), state.to(input.dtype)
 
     def _impl_torch(self, input: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
         input = input.to(dtype=torch.bfloat16)
@@ -64,7 +62,7 @@ class sLSTMCell(nn.Module):
             .reshape(-1)
         )
 
-        return slstm_forward(input, state, recurrent_kernel, bias)
+        return sLSTMCellTorch.slstm_forward(input, state, recurrent_kernel, bias)
 
     def _impl_cuda(self, input: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
         if input.device.type != "cuda":
@@ -88,7 +86,7 @@ class sLSTMCell(nn.Module):
 
         input = input.permute(0, 1, 3, 2, 4).reshape(input.shape[0], input.shape[1], -1)
 
-        …
+        all_states = self.func.apply(
             False,
             input.contiguous(),
             state.contiguous(),
@@ -96,6 +94,10 @@ class sLSTMCell(nn.Module):
             self._bias_.contiguous(),
         )
 
+        state = all_states[:, -1]
+        output = all_states[0][1:]
+        return output, state
+
     def _get_input(self, x: torch.Tensor) -> torch.Tensor:
         assert x.shape[-1] == self.config.embedding_dim * self.config.num_gates, (
             f"Input size mismatch: Expected input size {self.config.embedding_dim * self.config.num_gates}, but got {input.size(-1)}."
```
```diff
@@ -119,73 +121,60 @@ class sLSTMCell(nn.Module):
         return output.permute(1, 2, 0, 3)
 
 
- …
-            states
- …
-        .
-        .
- …
-    )
- …
-        mnew = iraw
-    else:
-        mnew = torch.max(iraw, logfplusm)  # eq 15
-    ogate = torch.sigmoid(oraw)  # eq 14
-    igate = torch.minimum(torch.exp(iraw - mnew), torch.ones_like(iraw))  # eq 16
-    fgate = torch.minimum(torch.exp(logfplusm - mnew), torch.ones_like(iraw))  # eq 17
-    zgate = torch.tanh(zraw)  # eq 11
-    cnew = fgate * c + igate * zgate  # eq 8
-    nnew = fgate * n + igate  # eq 9
-    hnew = ogate * cnew / nnew  # eq 10
-
-    # y (4, B, H), state (4, B, H)
-    return torch.stack((hnew, cnew, nnew, mnew), dim=0), torch.stack((igate, fgate, zraw, ogate), dim=0)
+class sLSTMCellTorch:
+    @staticmethod
+    def slstm_forward(
+        x: torch.Tensor,  # [S, B, G*I]
+        states: torch.Tensor,  # [4, B, H] only the first is used for recurrence!
+        R: torch.Tensor,  # [K, R*H, H] - K num_heads
+        b: torch.Tensor,  # [T*H]
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        num_gates = 4
+        num_heads = R.shape[0]
+        S, B, _ = x.shape
+        H = R.shape[1] * num_heads
+        assert states.shape == (num_gates, B, H)
+
+        states = states.to(R.dtype).unbind(dim=0)
+        output = []
+        for i in range(S):
+            Ry = (
+                states[0]
+                .reshape(B, num_heads, 1, -1)
+                .matmul(R.unsqueeze(0))
+                .reshape(B, num_heads, num_gates, -1)
+                .transpose(1, 2)
+                .reshape(B, -1)
+            )
+            states = sLSTMCellTorch.slstm_forward_pointwise(
+                x[i].float(), Ry.float(), b.float(), [s.float() for s in states]
+            )
+            states = [s.to(dtype=R.dtype) for s in states]
+            output.append(states[0])
+
+        return torch.stack(output), torch.stack(states)  # (S, B, H), 4 x (B, H)
+
+    @staticmethod
+    def slstm_forward_pointwise(
+        Wx: torch.Tensor,  # dim [B, 4*H]
+        Ry: torch.Tensor,  # dim [B, 4*H]
+        b: torch.Tensor,  # dim [1, 4*H]
+        states: torch.Tensor,  # dim 4 x [B, H]
+    ) -> list[torch.Tensor]:
+        y, c, n, m = states
+
+        raw = Wx + Ry + b
+        iraw, fraw, zraw, oraw = torch.unbind(raw.view(raw.shape[0], 4, -1), dim=1)
+
+        # Equations reference the xlstm paper on page 4: https://arxiv.org/pdf/2405.04517
+        logfplusm = m + F.logsigmoid(fraw)  # eq 15
+        mnew = torch.where(torch.all(n == 0.0), iraw, torch.max(iraw, logfplusm))  # eq 15
+        ogate = torch.sigmoid(oraw)  # eq 14
+        igate = torch.minimum(torch.exp(iraw - mnew), torch.ones_like(iraw))  # eq 16
+        fgate = torch.minimum(torch.exp(logfplusm - mnew), torch.ones_like(iraw))  # eq 17
+        zgate = torch.tanh(zraw)  # eq 11
+        cnew = fgate * c + igate * zgate  # eq 8
+        nnew = fgate * n + igate  # eq 9
+        hnew = ogate * cnew / nnew  # eq 10
+
+        return [hnew, cnew, nnew, mnew]  # 4 x (B, H)
```
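For reference, `slstm_forward_pointwise` computes the stabilized sLSTM update from the xLSTM paper, with equation numbers as in the code comments (on the first step, where n = 0, the code takes m_t = ĩ_t directly):

```latex
\begin{aligned}
m_t &= \max\!\bigl(\tilde{i}_t,\ \log\sigma(\tilde{f}_t) + m_{t-1}\bigr) &&\text{(eq. 15)}\\
i_t &= \min\!\bigl(\exp(\tilde{i}_t - m_t),\,1\bigr) &&\text{(eq. 16)}\\
f_t &= \min\!\bigl(\exp(\log\sigma(\tilde{f}_t) + m_{t-1} - m_t),\,1\bigr) &&\text{(eq. 17)}\\
z_t &= \tanh(\tilde{z}_t), \quad o_t = \sigma(\tilde{o}_t) &&\text{(eqs. 11, 14)}\\
c_t &= f_t\,c_{t-1} + i_t\,z_t, \quad n_t = f_t\,n_{t-1} + i_t, \quad h_t = o_t\,c_t / n_t &&\text{(eqs. 8--10)}
\end{aligned}
```

The explicit min(·, 1) clamps mirror the `torch.minimum(..., torch.ones_like(...))` calls in the code rather than the bare exponentials of the paper.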