tirex-mirror 2025.10.2__tar.gz → 2025.10.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {tirex_mirror-2025.10.2/src/tirex_mirror.egg-info → tirex_mirror-2025.10.7}/PKG-INFO +1 -1
  2. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/pyproject.toml +1 -1
  3. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex/api_adapter/forecast.py +163 -5
  4. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex/api_adapter/gluon.py +2 -2
  5. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex/api_adapter/hf_data.py +2 -2
  6. tirex_mirror-2025.10.7/src/tirex/api_adapter/standard_adapter.py +90 -0
  7. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex/base.py +26 -5
  8. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex/models/slstm/cell.py +66 -77
  9. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex/models/tirex.py +17 -0
  10. tirex_mirror-2025.10.7/src/tirex/util.py +617 -0
  11. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7/src/tirex_mirror.egg-info}/PKG-INFO +1 -1
  12. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex_mirror.egg-info/SOURCES.txt +2 -1
  13. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/tests/test_forecast.py +17 -6
  14. tirex_mirror-2025.10.7/tests/test_standard_adapter.py +166 -0
  15. tirex_mirror-2025.10.7/tests/test_util_freq.py +112 -0
  16. tirex_mirror-2025.10.2/src/tirex/api_adapter/standard_adapter.py +0 -67
  17. tirex_mirror-2025.10.2/src/tirex/util.py +0 -13
  18. tirex_mirror-2025.10.2/tests/test_standard_adapter.py +0 -183
  19. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/LICENSE +0 -0
  20. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/LICENSE_MIRROR.txt +0 -0
  21. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/MANIFEST.in +0 -0
  22. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/NOTICE.txt +0 -0
  23. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/README.md +0 -0
  24. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/setup.cfg +0 -0
  25. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex/__init__.py +0 -0
  26. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex/api_adapter/__init__.py +0 -0
  27. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex/models/__init__.py +0 -0
  28. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex/models/patcher.py +0 -0
  29. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex/models/slstm/block.py +0 -0
  30. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex/models/slstm/layer.py +0 -0
  31. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex_mirror.egg-info/dependency_links.txt +0 -0
  32. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex_mirror.egg-info/requires.txt +0 -0
  33. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/src/tirex_mirror.egg-info/top_level.txt +0 -0
  34. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/tests/test_chronos_zs.py +0 -0
  35. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/tests/test_forecast_adapter.py +0 -0
  36. {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.7}/tests/test_slstm_torch_vs_cuda.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: tirex-mirror
- Version: 2025.10.2
+ Version: 2025.10.7
  Summary: Unofficial mirror of NX-AI/tirex for packaging
  Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
  License: NXAI COMMUNITY LICENSE AGREEMENT
pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "tirex-mirror"
- version = "2025.10.02"
+ version = "2025.10.07"
  description = "Unofficial mirror of NX-AI/tirex for packaging"
  readme = "README.md"
  requires-python = ">=3.11"
src/tirex/api_adapter/forecast.py
@@ -2,15 +2,22 @@
  # This software may be used and distributed according to the terms of the NXAI Community License Agreement.

  from abc import ABC, abstractmethod
- from typing import Literal
+ from functools import partial
+ from math import ceil
+ from typing import Literal, Optional

  import torch

+ from tirex.util import frequency_resample
+
  from .standard_adapter import ContextType, get_batches

  DEF_TARGET_COLUMN = "target"
  DEF_META_COLUMNS = ("start", "item_id")

+ # Allowed resampling strategies (extend as new strategies are implemented)
+ RESAMPLE_STRATEGIES: list[str] = ["frequency"]
+

  def _format_output(
      quantiles: torch.Tensor,
@@ -33,6 +40,27 @@ def _format_output(
      raise ValueError(f"Invalid output type: {output_type}")


+ def _pad_time_series_batch(
+     batch_series: list[torch.Tensor],
+     max_length: int,
+ ) -> torch.Tensor:
+     if not batch_series:
+         return torch.empty((0, max_length))
+
+     first = batch_series[0]
+     dtype = first.dtype if first.is_floating_point() else torch.float32
+     device = first.device
+
+     padded = torch.full((len(batch_series), max_length), float("nan"), dtype=dtype, device=device)
+
+     for idx, series in enumerate(batch_series):
+         series = series.to(padded.dtype)
+         series_len = series.shape[0]
+         padded[idx, max_length - series_len :] = series
+
+     return padded
+
+
  def _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs):
      for batch_ctx, batch_meta in batches:
          quantiles, mean = fc_func(batch_ctx, **predict_kwargs)
@@ -45,7 +73,105 @@ def _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwar
          )


- def _gen_forecast(fc_func, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs):
+ def _call_fc_with_padding(fc_func, batch_series: list[torch.Tensor], **predict_kwargs):
+     if not batch_series:
+         raise ValueError("Received empty batch for forecasting")
+
+     max_len = max(series.shape[0] for series in batch_series)
+     padded_ts = _pad_time_series_batch(batch_series, max_len)
+
+     return fc_func(padded_ts, **predict_kwargs)
+
+
+ def _resample_fc_func_wrapper(
+     fc_func,
+     batch,
+     resample_strategy: str,
+     max_context: int = 2016,
+     **predict_kwargs,
+ ):
+     # downsample the time series based on the dominant frequencies, if enabled
+     max_period = (max_context // 1000) * 500
+     prediction_length = predict_kwargs.get("prediction_length", 100)
+     batch_resampled_ts: list[torch.Tensor] = []
+     fc_resample_fns = []
+     scaling_factors = []
+
+     # select the function doing the resampling
+     ctx_resample_fn = lambda x: (x, 1.0, (lambda y: y))
+     match resample_strategy:
+         case "frequency":
+             ctx_resample_fn = frequency_resample
+         case _:
+             raise RuntimeError("This shouldn't happen.")
+
+     for series in batch:
+         resampled_ts, _sample_factor, fc_resample_fn = ctx_resample_fn(
+             series,
+             prediction_length=prediction_length,
+             max_period=max_period,
+         )
+
+         batch_resampled_ts.append(resampled_ts)
+         fc_resample_fns.append(fc_resample_fn)
+         scaling_factors.append(_sample_factor)
+
+     # Compute per-item required horizons (in downsampled domain)
+     per_item_pred_lens = [int(ceil(prediction_length * sf)) for sf in scaling_factors]
+     max_pred_len = max(per_item_pred_lens) if per_item_pred_lens else int(prediction_length)
+     predict_kwargs.update(prediction_length=max_pred_len)
+
+     max_ts_length = max(ts.shape[0] for ts in batch_resampled_ts)
+     padded_ts = _pad_time_series_batch(batch_resampled_ts, max_ts_length)
+     print(f"Average sample batch factor: {sum(scaling_factors) / len(scaling_factors)}")
+
+     # generate prediction
+     fc_quantiles, fc_mean = fc_func(padded_ts, **predict_kwargs)
+
+     batch_prediction_q = []
+     batch_prediction_m = []
+     for el_q, el_m, fc_resample_fn, item_pred_len in zip(fc_quantiles, fc_mean, fc_resample_fns, per_item_pred_lens):
+         # truncate the forecasts to their individual sample factor adjusted prediction lengths
+         el_q = el_q[:item_pred_len, ...]  # [T, Q]
+         el_m = el_m[:item_pred_len]  # [T]
+
+         # upsample prediction
+         quantiles = fc_resample_fn(el_q.squeeze(0).transpose(0, 1)).transpose(0, 1)  # [T, Q]
+         mean = fc_resample_fn(el_m.squeeze(0))
+
+         quantiles = quantiles[:prediction_length, ...]
+         mean = mean[:prediction_length]
+
+         batch_prediction_q.append(quantiles)
+         batch_prediction_m.append(mean)
+
+     return torch.stack(batch_prediction_q, dim=0), torch.stack(batch_prediction_m, dim=0)
+
+
+ def _gen_forecast(
+     fc_func,
+     batches,
+     output_type,
+     quantile_levels,
+     yield_per_batch,
+     resample_strategy: str | None = None,
+     max_context: int = 2016,
+     **predict_kwargs,
+ ):
+     base_fc_func = fc_func
+
+     if resample_strategy is not None:
+         if resample_strategy not in RESAMPLE_STRATEGIES:
+             raise ValueError(f"Invalid resample strategy: {resample_strategy}. Allowed: {RESAMPLE_STRATEGIES}")
+         fc_func = partial(
+             _resample_fc_func_wrapper,
+             base_fc_func,
+             resample_strategy=resample_strategy,
+             max_context=max_context,
+         )
+     else:
+         fc_func = partial(_call_fc_with_padding, base_fc_func)
+
      if yield_per_batch:
          return _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs)

@@ -92,6 +218,9 @@ def _common_forecast_doc():
              forecasts batch by batch as they are computed.
              Defaults to `False`.

+         resample_strategy (Optional[str], optional): Choose a resampling strategy. Allowed values: {RESAMPLE_STRATEGIES}.
+             If `None`, no resampling is applied. Currently only "frequency" is supported.
+
          **predict_kwargs: Additional keyword arguments that are passed directly to the underlying
              prediction mechanism of the pre-trained model. Refer to the model's
              internal prediction method documentation for available options.
@@ -113,6 +242,11 @@ class ForecastModel(ABC):
      def _forecast_quantiles(self, batch, **predict_kwargs):
          pass

+     @property
+     def max_context_length(self) -> int:
+         # retrieve the max_context attribute of the model configuration if present
+         return getattr(getattr(self, "config", None), "max_context", 2016)
+
      def forecast(
          self,
          context: ContextType,
@@ -120,6 +254,7 @@ class ForecastModel(ABC):
          batch_size: int = 512,
          quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
          yield_per_batch: bool = False,
+         resample_strategy: Literal["frequency"] | None = None,
          **predict_kwargs,
      ):
          f"""
@@ -134,7 +269,14 @@ class ForecastModel(ABC):
          assert batch_size >= 1, "Batch size must be >= 1"
          batches = get_batches(context, batch_size)
          return _gen_forecast(
-             self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
+             self._forecast_quantiles,
+             batches,
+             output_type,
+             quantile_levels,
+             yield_per_batch,
+             resample_strategy=resample_strategy,
+             max_context=self.max_context_length,
+             **predict_kwargs,
          )

      def forecast_gluon(
@@ -144,6 +286,7 @@ class ForecastModel(ABC):
          batch_size: int = 512,
          quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
          yield_per_batch: bool = False,
+         resample_strategy: Literal["frequency"] | None = None,
          data_kwargs: dict = {},
          **predict_kwargs,
      ):
@@ -165,7 +308,14 @@ class ForecastModel(ABC):

          batches = get_gluon_batches(gluonDataset, batch_size, **data_kwargs)
          return _gen_forecast(
-             self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
+             self._forecast_quantiles,
+             batches,
+             output_type,
+             quantile_levels,
+             yield_per_batch,
+             resample_strategy=resample_strategy,
+             max_context=self.max_context_length,
+             **predict_kwargs,
          )

      def forecast_hfdata(
@@ -175,6 +325,7 @@ class ForecastModel(ABC):
          batch_size: int = 512,
          quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
          yield_per_batch: bool = False,
+         resample_strategy: Literal["frequency"] | None = None,
          data_kwargs: dict = {},
          **predict_kwargs,
      ):
@@ -198,5 +349,12 @@ class ForecastModel(ABC):

          batches = get_hfdata_batches(hf_dataset, batch_size, **data_kwargs)
          return _gen_forecast(
-             self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
+             self._forecast_quantiles,
+             batches,
+             output_type,
+             quantile_levels,
+             yield_per_batch,
+             resample_strategy=resample_strategy,
+             max_context=self.max_context_length,
+             **predict_kwargs,
          )
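
A minimal usage sketch of the new resample_strategy option (not part of the diff; assumes the package's public load_model entry point, the default torch output type, and illustrative sizes):

    import torch
    from tirex import load_model

    model = load_model("NX-AI/TiRex")            # example path from the docstring above
    context = torch.randn(4, 2016)               # 4 hypothetical series of length 2016
    quantiles, mean = model.forecast(
        context,
        prediction_length=100,
        resample_strategy="frequency",           # new in 2025.10.7; None (default) disables resampling
    )
    # expected shapes: quantiles [4, 100, 9] (9 default quantile levels), mean [4, 100]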
src/tirex/api_adapter/gluon.py
@@ -7,7 +7,7 @@ from gluonts.dataset.common import Dataset
  from gluonts.dataset.field_names import FieldName
  from gluonts.model.forecast import QuantileForecast

- from .standard_adapter import _batch_pad_iterable
+ from .standard_adapter import _batch_iterable

  DEF_TARGET_COLUMN = FieldName.TARGET  # target
  DEF_META_COLUMNS = (FieldName.START, FieldName.ITEM_ID)
@@ -27,7 +27,7 @@ def _get_gluon_ts_map(**gluon_kwargs):


  def get_gluon_batches(gluonDataset: Dataset, batch_size: int, **gluon_kwargs):
-     return _batch_pad_iterable(map(_get_gluon_ts_map(**gluon_kwargs), gluonDataset), batch_size)
+     return _batch_iterable(map(_get_gluon_ts_map(**gluon_kwargs), gluonDataset), batch_size)


  def format_gluonts_output(quantile_forecasts: torch.Tensor, mean_forecasts, meta: list[dict], quantile_levels):
src/tirex/api_adapter/hf_data.py
@@ -4,7 +4,7 @@
  import datasets
  import torch

- from .standard_adapter import _batch_pad_iterable
+ from .standard_adapter import _batch_iterable

  DEF_TARGET_COLUMN = "target"

@@ -35,4 +35,4 @@ def _get_hf_map(dataset: datasets.Dataset, **hf_kwargs):

  def get_hfdata_batches(hf_dataset: datasets.Dataset, batch_size: int, **hf_kwargs):
      dataset, map_func = _get_hf_map(hf_dataset, **hf_kwargs)
-     return _batch_pad_iterable(map(map_func, dataset), batch_size)
+     return _batch_iterable(map(map_func, dataset), batch_size)
src/tirex/api_adapter/standard_adapter.py (new file)
@@ -0,0 +1,90 @@
+ # Copyright (c) NXAI GmbH.
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+ import itertools
+ from collections.abc import Iterable, Iterator, Sequence
+ from typing import Union
+
+ import numpy as np
+ import torch
+
+ ContextType = Union[
+     torch.Tensor,
+     np.ndarray,
+     list[torch.Tensor],
+     list[np.ndarray],
+ ]
+
+
+ def _ensure_1d_tensor(sample) -> torch.Tensor:
+     if isinstance(sample, torch.Tensor):
+         tensor = sample
+     else:
+         tensor = torch.as_tensor(sample)
+
+     if tensor.ndim > 1:
+         tensor = tensor.squeeze()
+
+     assert tensor.ndim == 1, "Each sample must be one-dimensional"
+     return tensor
+
+
+ def _batched_slice(
+     full_batch,
+     full_meta: list[dict] | None,
+     batch_size: int,
+ ) -> Iterator[tuple[list[torch.Tensor], list[dict]]]:
+     total = len(full_batch)
+     for start in range(0, total, batch_size):
+         batch = full_batch[start : start + batch_size]
+         meta = full_meta[start : start + batch_size] if full_meta is not None else [{} for _ in range(len(batch))]
+
+         batch_series = []
+         for idx in range(len(batch)):
+             sample = batch[idx]
+             tensor = _ensure_1d_tensor(sample)
+             batch_series.append(tensor)
+
+         yield batch_series, meta
+
+
+ def _batched(iterable: Iterable, n: int):
+     it = iter(iterable)
+     while batch := tuple(itertools.islice(it, n)):
+         yield batch
+
+
+ def _batch_iterable(
+     iterable: Iterable[tuple[torch.Tensor, dict | None]],
+     batch_size: int,
+ ) -> Iterator[tuple[list[torch.Tensor], list[dict]]]:
+     for batch in _batched(iterable, batch_size):
+         series_list: list[torch.Tensor] = []
+         meta_list: list[dict] = []
+
+         for sample, meta in batch:
+             tensor = _ensure_1d_tensor(sample)
+             assert len(tensor) > 0, "Each sample needs to have a length > 0"
+             series_list.append(tensor)
+             meta_list.append(meta if meta is not None else {})
+
+         yield series_list, meta_list
+
+
+ def get_batches(context: ContextType, batch_size: int):
+     batches = None
+     if isinstance(context, torch.Tensor):
+         if context.ndim == 1:
+             context = context.unsqueeze(0)
+         assert context.ndim == 2
+         batches = _batched_slice(context, None, batch_size)
+     elif isinstance(context, np.ndarray):
+         if context.ndim == 1:
+             context = np.expand_dims(context, axis=0)
+         assert context.ndim == 2
+         batches = _batched_slice(context, None, batch_size)
+     elif isinstance(context, (list, Iterable)):
+         batches = _batch_iterable(map(lambda x: (torch.Tensor(x), None), context), batch_size)
+     if batches is None:
+         raise ValueError(f"Context type {type(context)} not supported! Supported Types: {ContextType}")
+     return batches
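
Judging from the import renames above, the rewritten adapter no longer pads inside the batching step (the former _batch_pad_iterable); it only groups 1-D series with their metadata, and padding is applied later by _pad_time_series_batch in forecast.py. A small illustration of the new get_batches behavior (not part of the diff; module path follows the file list above):

    import numpy as np
    import torch
    from tirex.api_adapter.standard_adapter import get_batches

    context = [np.arange(10.0), torch.arange(25.0)]       # mixed, unequal-length inputs
    for series_list, meta_list in get_batches(context, batch_size=2):
        # series keep their original lengths; missing metadata becomes empty dicts
        print([tuple(s.shape) for s in series_list], meta_list)   # [(10,), (25,)] [{}, {}]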
src/tirex/base.py
@@ -1,6 +1,7 @@
  # Copyright (c) NXAI GmbH.
  # This software may be used and distributed according to the terms of the NXAI Community License Agreement.

+ import logging
  import os
  from abc import ABC, abstractmethod
  from typing import Literal, TypeVar
@@ -8,7 +9,10 @@ from typing import Literal, TypeVar
  import torch
  from huggingface_hub import hf_hub_download

+ from tirex.models.slstm.cell import sLSTMCellTorch
+
  T = TypeVar("T", bound="PretrainedModel")
+ VERSION_DELIMITER = "-"


  def skip_cuda():
@@ -29,6 +33,17 @@ def parse_hf_repo_id(path):
      return "/".join(parts[0:2])


+ def parse_model_string(model_string):
+     if VERSION_DELIMITER in model_string:
+         parts = model_string.split(VERSION_DELIMITER)
+         model_id, version = parts[0], parts[0]
+     else:
+         model_id = model_string
+         version = None
+
+     return model_id, version
+
+
  class PretrainedModel(ABC):
      REGISTRY: dict[str, "PretrainedModel"] = {}

@@ -38,7 +53,7 @@ class PretrainedModel(ABC):

      @classmethod
      def from_pretrained(
-         cls: type[T], path: str, backend: str, device: str | None = None, hf_kwargs=None, ckp_kwargs=None
+         cls: type[T], path: str, backend: str, device: str | None = None, compile=False, hf_kwargs=None, ckp_kwargs=None
      ) -> T:
          if hf_kwargs is None:
              hf_kwargs = {}
@@ -58,9 +73,10 @@ class PretrainedModel(ABC):
          model: T = cls(backend=backend, **checkpoint["hyper_parameters"])
          model.on_load_checkpoint(checkpoint)
          model.load_state_dict(checkpoint["state_dict"])
+         model = model.to(device)

-         if backend == "cuda":
-             model = model.to(device)
+         if compile and backend == "torch":
+             sLSTMCellTorch.slstm_forward = torch.compile(sLSTMCellTorch.slstm_forward, mode="max-autotune")
          return model

      @classmethod
@@ -76,6 +92,7 @@ def load_model(
      path: str,
      device: str | None = None,
      backend: Literal["torch", "cuda"] | None = None,
+     compile: bool = False,
      hf_kwargs=None,
      ckp_kwargs=None,
  ) -> PretrainedModel:
@@ -85,6 +102,7 @@ def load_model(
          path (str): Hugging Face path to the model (e.g. NX-AI/TiRex)
          device (str, optional): The device on which to load the model (e.g., "cuda:0", "cpu").
          backend (torch | cuda): What backend to use, torch or the custom CUDA kernels. Defaults to cuda when xlstm is installed, else torch.
+         compile (bool, optional): toch.compile the sLSTM cells, only works with the torch backend
          hf_kwargs (dict, optional): Keyword arguments to pass to the Hugging Face Hub download method.
          ckp_kwargs (dict, optional): Keyword arguments to pass when loading the checkpoint.

@@ -99,11 +117,14 @@
          backend = "torch" if skip_cuda() or not xlstm_available() else "cuda"

      try:
-         _, model_id = parse_hf_repo_id(path).split("/")
+         _, model_string = parse_hf_repo_id(path).split("/")
+         model_id, version = parse_model_string(model_string)
      except:
          raise ValueError(f"Invalid model path {path}")
      model_cls = PretrainedModel.REGISTRY.get(model_id, None)
      if model_cls is None:
          raise ValueError(f"Invalid model id {model_id}")

-     return model_cls.from_pretrained(path, device=device, backend=backend, hf_kwargs=hf_kwargs, ckp_kwargs=ckp_kwargs)
+     return model_cls.from_pretrained(
+         path, device=device, backend=backend, compile=compile, hf_kwargs=hf_kwargs, ckp_kwargs=ckp_kwargs
+     )
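
A hedged example of the new compile flag on load_model (not part of the diff; per from_pretrained above it only takes effect with the torch backend, and the path and device are illustrative):

    from tirex import load_model

    # torch.compile is applied to sLSTMCellTorch.slstm_forward with mode="max-autotune";
    # on the CUDA backend the flag is ignored.
    model = load_model("NX-AI/TiRex", backend="torch", device="cpu", compile=True)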
src/tirex/models/slstm/cell.py
@@ -43,13 +43,11 @@ class sLSTMCell(nn.Module):
          state = self._get_state(input, state)

          if self.backend == "torch":
-             all_states = self._impl_torch(input, state)
+             output, state = self._impl_torch(input, state)
          elif self.backend == "cuda":
-             all_states = self._impl_cuda(input, state)
+             output, state = self._impl_cuda(input, state)

-         state = all_states[:, -1]
-         output = self._permute_output(all_states[0][1:])
-         return output.to(input.dtype), state.to(input.dtype)
+         return self._permute_output(output).to(input.dtype), state.to(input.dtype)

      def _impl_torch(self, input: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
          input = input.to(dtype=torch.bfloat16)
@@ -64,7 +62,7 @@
              .reshape(-1)
          )

-         return slstm_forward(input, state, recurrent_kernel, bias)[0]
+         return sLSTMCellTorch.slstm_forward(input, state, recurrent_kernel, bias)

      def _impl_cuda(self, input: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
          if input.device.type != "cuda":
@@ -88,7 +86,7 @@

          input = input.permute(0, 1, 3, 2, 4).reshape(input.shape[0], input.shape[1], -1)

-         return self.func.apply(
+         all_states = self.func.apply(
              False,
              input.contiguous(),
              state.contiguous(),
@@ -96,6 +94,10 @@
              self._bias_.contiguous(),
          )

+         state = all_states[:, -1]
+         output = all_states[0][1:]
+         return output, state
+
      def _get_input(self, x: torch.Tensor) -> torch.Tensor:
          assert x.shape[-1] == self.config.embedding_dim * self.config.num_gates, (
              f"Input size mismatch: Expected input size {self.config.embedding_dim * self.config.num_gates}, but got {input.size(-1)}."
@@ -119,73 +121,60 @@ class sLSTMCell(nn.Module):
          return output.permute(1, 2, 0, 3)


- def slstm_forward(
-     x: torch.Tensor,  # [S, B, G*I]
-     states: torch.Tensor,  # [4, B, H] only the first is used for recurrence!
-     R: torch.Tensor,  # [K, R*H, H] - K num_heads
-     b: torch.Tensor,  # [T*H]
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-     num_states = states.shape[0]
-     sequence_dim = x.shape[0]
-     # this only works for a fully-connected RNN, for a hin change this
-     num_gates_r = R.shape[2] // R.shape[1]
-     hidden_dim = R.shape[1] * R.shape[0]
-     batch_dim = x.shape[1]
-     num_heads = R.shape[0]
-
-     assert batch_dim == states.shape[1]
-     assert hidden_dim == states.shape[2]
-
-     states_all = torch.zeros(
-         [num_states, sequence_dim + 1, batch_dim, hidden_dim],
-         device=x.device,
-         dtype=x.dtype,
-     )
-     states_all[:, 0] = states
-     for i, Wx_t in enumerate(x.unbind(dim=0)):
-         Ry = (
-             states[0]
-             .reshape(batch_dim, num_heads, 1, -1)
-             .matmul(R.unsqueeze(0))
-             .reshape(batch_dim, num_heads, num_gates_r, -1)
-             .transpose(1, 2)
-             .reshape(batch_dim, -1)
-         )
-         sdtype = states.dtype
-         Wx_t, Ry, b, states = Wx_t.float(), Ry.float(), b.float(), states.float()
-         states, gates = slstm_forward_pointwise(Wx_t, Ry, b, states)
-         states = states.to(dtype=sdtype)
-         states_all[:, i + 1] = states
-
-     # shapes ([S, B, H], ([B,H], [B,H], [B,H])
-     return states_all, states
-
-
- def slstm_forward_pointwise(
-     Wx: torch.Tensor,  # dim [B, 4*H]
-     Ry: torch.Tensor,  # dim [B, 4*H]
-     b: torch.Tensor,  # dim [1, 4*H]
-     states: torch.Tensor,  # dim [4, B, H]
- ) -> tuple[torch.Tensor, torch.Tensor]:
-     raw = Wx + Ry + b
-
-     iraw, fraw, zraw, oraw = torch.unbind(raw.view(raw.shape[0], 4, -1), dim=1)
-     y, c, n, m = torch.unbind(states.view(4, states.shape[1], -1), dim=0)
-
-     # with torch.no_grad(): # THE difference to maxg aka max_gradient (here max / max_static)
-     # Equations reference the xlstm paper on page 4: https://arxiv.org/pdf/2405.04517
-     logfplusm = m + F.logsigmoid(fraw)  # eq 15
-     if torch.all(n == 0.0):
-         mnew = iraw
-     else:
-         mnew = torch.max(iraw, logfplusm)  # eq 15
-     ogate = torch.sigmoid(oraw)  # eq 14
-     igate = torch.minimum(torch.exp(iraw - mnew), torch.ones_like(iraw))  # eq 16
-     fgate = torch.minimum(torch.exp(logfplusm - mnew), torch.ones_like(iraw))  # eq 17
-     zgate = torch.tanh(zraw)  # eq 11
-     cnew = fgate * c + igate * zgate  # eq 8
-     nnew = fgate * n + igate  # eq 9
-     hnew = ogate * cnew / nnew  # eq 10
-
-     # y (4, B, H), state (4, B, H)
-     return torch.stack((hnew, cnew, nnew, mnew), dim=0), torch.stack((igate, fgate, zraw, ogate), dim=0)
+ class sLSTMCellTorch:
+     @staticmethod
+     def slstm_forward(
+         x: torch.Tensor,  # [S, B, G*I]
+         states: torch.Tensor,  # [4, B, H] only the first is used for recurrence!
+         R: torch.Tensor,  # [K, R*H, H] - K num_heads
+         b: torch.Tensor,  # [T*H]
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         num_gates = 4
+         num_heads = R.shape[0]
+         S, B, _ = x.shape
+         H = R.shape[1] * num_heads
+         assert states.shape == (num_gates, B, H)
+
+         states = states.to(R.dtype).unbind(dim=0)
+         output = []
+         for i in range(S):
+             Ry = (
+                 states[0]
+                 .reshape(B, num_heads, 1, -1)
+                 .matmul(R.unsqueeze(0))
+                 .reshape(B, num_heads, num_gates, -1)
+                 .transpose(1, 2)
+                 .reshape(B, -1)
+             )
+             states = sLSTMCellTorch.slstm_forward_pointwise(
+                 x[i].float(), Ry.float(), b.float(), [s.float() for s in states]
+             )
+             states = [s.to(dtype=R.dtype) for s in states]
+             output.append(states[0])
+
+         return torch.stack(output), torch.stack(states)  # (S, B, H), 4 x (B, H)
+
+     @staticmethod
+     def slstm_forward_pointwise(
+         Wx: torch.Tensor,  # dim [B, 4*H]
+         Ry: torch.Tensor,  # dim [B, 4*H]
+         b: torch.Tensor,  # dim [1, 4*H]
+         states: torch.Tensor,  # dim 4 x [B, H]
+     ) -> list[torch.Tensor]:
+         y, c, n, m = states
+
+         raw = Wx + Ry + b
+         iraw, fraw, zraw, oraw = torch.unbind(raw.view(raw.shape[0], 4, -1), dim=1)
+
+         # Equations reference the xlstm paper on page 4: https://arxiv.org/pdf/2405.04517
+         logfplusm = m + F.logsigmoid(fraw)  # eq 15
+         mnew = torch.where(torch.all(n == 0.0), iraw, torch.max(iraw, logfplusm))  # eq 15
+         ogate = torch.sigmoid(oraw)  # eq 14
+         igate = torch.minimum(torch.exp(iraw - mnew), torch.ones_like(iraw))  # eq 16
+         fgate = torch.minimum(torch.exp(logfplusm - mnew), torch.ones_like(iraw))  # eq 17
+         zgate = torch.tanh(zraw)  # eq 11
+         cnew = fgate * c + igate * zgate  # eq 8
+         nnew = fgate * n + igate  # eq 9
+         hnew = ogate * cnew / nnew  # eq 10
+
+         return [hnew, cnew, nnew, mnew]  # 4 x (B, H)
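
A shape-check sketch for the rewritten pure-torch forward (not part of the diff; sizes are illustrative and the recurrent-kernel layout [num_heads, head_dim, 4 * head_dim] is inferred from the reshapes above):

    import torch
    from tirex.models.slstm.cell import sLSTMCellTorch

    S, B, num_heads, head_dim = 3, 2, 4, 8
    H = num_heads * head_dim
    x = torch.randn(S, B, 4 * H)                         # per-step gate pre-activations Wx
    states = torch.zeros(4, B, H)                        # (h, c, n, m)
    R = torch.randn(num_heads, head_dim, 4 * head_dim)   # recurrent kernel
    b = torch.zeros(4 * H)

    output, new_states = sLSTMCellTorch.slstm_forward(x, states, R, b)
    print(output.shape, new_states.shape)                # torch.Size([3, 2, 32]) torch.Size([4, 2, 32])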