tirex-mirror 2025.10.3__tar.gz → 2025.10.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {tirex_mirror-2025.10.3/src/tirex_mirror.egg-info → tirex_mirror-2025.10.7}/PKG-INFO +1 -1
  2. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/pyproject.toml +1 -1
  3. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/src/tirex/api_adapter/forecast.py +163 -5
  4. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/src/tirex/api_adapter/gluon.py +2 -2
  5. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/src/tirex/api_adapter/hf_data.py +2 -2
  6. tirex_mirror-2025.10.7/src/tirex/api_adapter/standard_adapter.py +90 -0
  7. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/src/tirex/base.py +14 -1
  8. tirex_mirror-2025.10.7/src/tirex/util.py +617 -0
  9. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7/src/tirex_mirror.egg-info}/PKG-INFO +1 -1
  10. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/src/tirex_mirror.egg-info/SOURCES.txt +2 -1
  11. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/tests/test_forecast.py +17 -6
  12. tirex_mirror-2025.10.7/tests/test_standard_adapter.py +166 -0
  13. tirex_mirror-2025.10.7/tests/test_util_freq.py +112 -0
  14. tirex_mirror-2025.10.3/src/tirex/api_adapter/standard_adapter.py +0 -67
  15. tirex_mirror-2025.10.3/src/tirex/util.py +0 -13
  16. tirex_mirror-2025.10.3/tests/test_standard_adapter.py +0 -183
  17. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/LICENSE +0 -0
  18. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/LICENSE_MIRROR.txt +0 -0
  19. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/MANIFEST.in +0 -0
  20. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/NOTICE.txt +0 -0
  21. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/README.md +0 -0
  22. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/setup.cfg +0 -0
  23. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/src/tirex/__init__.py +0 -0
  24. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/src/tirex/api_adapter/__init__.py +0 -0
  25. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/src/tirex/models/__init__.py +0 -0
  26. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/src/tirex/models/patcher.py +0 -0
  27. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/src/tirex/models/slstm/block.py +0 -0
  28. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/src/tirex/models/slstm/cell.py +0 -0
  29. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/src/tirex/models/slstm/layer.py +0 -0
  30. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/src/tirex/models/tirex.py +0 -0
  31. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/src/tirex_mirror.egg-info/dependency_links.txt +0 -0
  32. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/src/tirex_mirror.egg-info/requires.txt +0 -0
  33. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/src/tirex_mirror.egg-info/top_level.txt +0 -0
  34. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/tests/test_chronos_zs.py +0 -0
  35. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/tests/test_forecast_adapter.py +0 -0
  36. {tirex_mirror-2025.10.3 → tirex_mirror-2025.10.7}/tests/test_slstm_torch_vs_cuda.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tirex-mirror
3
- Version: 2025.10.3
3
+ Version: 2025.10.7
4
4
  Summary: Unofficial mirror of NX-AI/tirex for packaging
5
5
  Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
6
6
  License: NXAI COMMUNITY LICENSE AGREEMENT
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "tirex-mirror"
3
- version = "2025.10.03"
3
+ version = "2025.10.07"
4
4
  description = "Unofficial mirror of NX-AI/tirex for packaging"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.11"
@@ -2,15 +2,22 @@
2
2
  # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
3
 
4
4
  from abc import ABC, abstractmethod
5
- from typing import Literal
5
+ from functools import partial
6
+ from math import ceil
7
+ from typing import Literal, Optional
6
8
 
7
9
  import torch
8
10
 
11
+ from tirex.util import frequency_resample
12
+
9
13
  from .standard_adapter import ContextType, get_batches
10
14
 
11
15
  DEF_TARGET_COLUMN = "target"
12
16
  DEF_META_COLUMNS = ("start", "item_id")
13
17
 
18
+ # Allowed resampling strategies (extend as new strategies are implemented)
19
+ RESAMPLE_STRATEGIES: list[str] = ["frequency"]
20
+
14
21
 
15
22
  def _format_output(
16
23
  quantiles: torch.Tensor,
@@ -33,6 +40,27 @@ def _format_output(
33
40
  raise ValueError(f"Invalid output type: {output_type}")
34
41
 
35
42
 
43
+ def _pad_time_series_batch(
44
+ batch_series: list[torch.Tensor],
45
+ max_length: int,
46
+ ) -> torch.Tensor:
47
+ if not batch_series:
48
+ return torch.empty((0, max_length))
49
+
50
+ first = batch_series[0]
51
+ dtype = first.dtype if first.is_floating_point() else torch.float32
52
+ device = first.device
53
+
54
+ padded = torch.full((len(batch_series), max_length), float("nan"), dtype=dtype, device=device)
55
+
56
+ for idx, series in enumerate(batch_series):
57
+ series = series.to(padded.dtype)
58
+ series_len = series.shape[0]
59
+ padded[idx, max_length - series_len :] = series
60
+
61
+ return padded
62
+
63
+
36
64
  def _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs):
37
65
  for batch_ctx, batch_meta in batches:
38
66
  quantiles, mean = fc_func(batch_ctx, **predict_kwargs)
@@ -45,7 +73,105 @@ def _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwar
45
73
  )
46
74
 
47
75
 
48
- def _gen_forecast(fc_func, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs):
76
+ def _call_fc_with_padding(fc_func, batch_series: list[torch.Tensor], **predict_kwargs):
77
+ if not batch_series:
78
+ raise ValueError("Received empty batch for forecasting")
79
+
80
+ max_len = max(series.shape[0] for series in batch_series)
81
+ padded_ts = _pad_time_series_batch(batch_series, max_len)
82
+
83
+ return fc_func(padded_ts, **predict_kwargs)
84
+
85
+
86
+ def _resample_fc_func_wrapper(
87
+ fc_func,
88
+ batch,
89
+ resample_strategy: str,
90
+ max_context: int = 2016,
91
+ **predict_kwargs,
92
+ ):
93
+ # downsample the time series based on the dominant frequencies, if enabled
94
+ max_period = (max_context // 1000) * 500
95
+ prediction_length = predict_kwargs.get("prediction_length", 100)
96
+ batch_resampled_ts: list[torch.Tensor] = []
97
+ fc_resample_fns = []
98
+ scaling_factors = []
99
+
100
+ # select the function doing the resampling
101
+ ctx_resample_fn = lambda x: (x, 1.0, (lambda y: y))
102
+ match resample_strategy:
103
+ case "frequency":
104
+ ctx_resample_fn = frequency_resample
105
+ case _:
106
+ raise RuntimeError("This shouldn't happen.")
107
+
108
+ for series in batch:
109
+ resampled_ts, _sample_factor, fc_resample_fn = ctx_resample_fn(
110
+ series,
111
+ prediction_length=prediction_length,
112
+ max_period=max_period,
113
+ )
114
+
115
+ batch_resampled_ts.append(resampled_ts)
116
+ fc_resample_fns.append(fc_resample_fn)
117
+ scaling_factors.append(_sample_factor)
118
+
119
+ # Compute per-item required horizons (in downsampled domain)
120
+ per_item_pred_lens = [int(ceil(prediction_length * sf)) for sf in scaling_factors]
121
+ max_pred_len = max(per_item_pred_lens) if per_item_pred_lens else int(prediction_length)
122
+ predict_kwargs.update(prediction_length=max_pred_len)
123
+
124
+ max_ts_length = max(ts.shape[0] for ts in batch_resampled_ts)
125
+ padded_ts = _pad_time_series_batch(batch_resampled_ts, max_ts_length)
126
+ print(f"Average sample batch factor: {sum(scaling_factors) / len(scaling_factors)}")
127
+
128
+ # generate prediction
129
+ fc_quantiles, fc_mean = fc_func(padded_ts, **predict_kwargs)
130
+
131
+ batch_prediction_q = []
132
+ batch_prediction_m = []
133
+ for el_q, el_m, fc_resample_fn, item_pred_len in zip(fc_quantiles, fc_mean, fc_resample_fns, per_item_pred_lens):
134
+ # truncate the forecasts to their individual sample factor adjusted prediction lengths
135
+ el_q = el_q[:item_pred_len, ...] # [T, Q]
136
+ el_m = el_m[:item_pred_len] # [T]
137
+
138
+ # upsample prediction
139
+ quantiles = fc_resample_fn(el_q.squeeze(0).transpose(0, 1)).transpose(0, 1) # [T, Q]
140
+ mean = fc_resample_fn(el_m.squeeze(0))
141
+
142
+ quantiles = quantiles[:prediction_length, ...]
143
+ mean = mean[:prediction_length]
144
+
145
+ batch_prediction_q.append(quantiles)
146
+ batch_prediction_m.append(mean)
147
+
148
+ return torch.stack(batch_prediction_q, dim=0), torch.stack(batch_prediction_m, dim=0)
149
+
150
+
151
+ def _gen_forecast(
152
+ fc_func,
153
+ batches,
154
+ output_type,
155
+ quantile_levels,
156
+ yield_per_batch,
157
+ resample_strategy: str | None = None,
158
+ max_context: int = 2016,
159
+ **predict_kwargs,
160
+ ):
161
+ base_fc_func = fc_func
162
+
163
+ if resample_strategy is not None:
164
+ if resample_strategy not in RESAMPLE_STRATEGIES:
165
+ raise ValueError(f"Invalid resample strategy: {resample_strategy}. Allowed: {RESAMPLE_STRATEGIES}")
166
+ fc_func = partial(
167
+ _resample_fc_func_wrapper,
168
+ base_fc_func,
169
+ resample_strategy=resample_strategy,
170
+ max_context=max_context,
171
+ )
172
+ else:
173
+ fc_func = partial(_call_fc_with_padding, base_fc_func)
174
+
49
175
  if yield_per_batch:
50
176
  return _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs)
51
177
 
@@ -92,6 +218,9 @@ def _common_forecast_doc():
92
218
  forecasts batch by batch as they are computed.
93
219
  Defaults to `False`.
94
220
 
221
+ resample_strategy (Optional[str], optional): Choose a resampling strategy. Allowed values: {RESAMPLE_STRATEGIES}.
222
+ If `None`, no resampling is applied. Currently only "frequency" is supported.
223
+
95
224
  **predict_kwargs: Additional keyword arguments that are passed directly to the underlying
96
225
  prediction mechanism of the pre-trained model. Refer to the model's
97
226
  internal prediction method documentation for available options.
@@ -113,6 +242,11 @@ class ForecastModel(ABC):
113
242
  def _forecast_quantiles(self, batch, **predict_kwargs):
114
243
  pass
115
244
 
245
+ @property
246
+ def max_context_length(self) -> int:
247
+ # retrieve the max_context attribute of the model configuration if present
248
+ return getattr(getattr(self, "config", None), "max_context", 2016)
249
+
116
250
  def forecast(
117
251
  self,
118
252
  context: ContextType,
@@ -120,6 +254,7 @@ class ForecastModel(ABC):
120
254
  batch_size: int = 512,
121
255
  quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
122
256
  yield_per_batch: bool = False,
257
+ resample_strategy: Literal["frequency"] | None = None,
123
258
  **predict_kwargs,
124
259
  ):
125
260
  f"""
@@ -134,7 +269,14 @@ class ForecastModel(ABC):
134
269
  assert batch_size >= 1, "Batch size must be >= 1"
135
270
  batches = get_batches(context, batch_size)
136
271
  return _gen_forecast(
137
- self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
272
+ self._forecast_quantiles,
273
+ batches,
274
+ output_type,
275
+ quantile_levels,
276
+ yield_per_batch,
277
+ resample_strategy=resample_strategy,
278
+ max_context=self.max_context_length,
279
+ **predict_kwargs,
138
280
  )
139
281
 
140
282
  def forecast_gluon(
@@ -144,6 +286,7 @@ class ForecastModel(ABC):
144
286
  batch_size: int = 512,
145
287
  quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
146
288
  yield_per_batch: bool = False,
289
+ resample_strategy: Literal["frequency"] | None = None,
147
290
  data_kwargs: dict = {},
148
291
  **predict_kwargs,
149
292
  ):
@@ -165,7 +308,14 @@ class ForecastModel(ABC):
165
308
 
166
309
  batches = get_gluon_batches(gluonDataset, batch_size, **data_kwargs)
167
310
  return _gen_forecast(
168
- self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
311
+ self._forecast_quantiles,
312
+ batches,
313
+ output_type,
314
+ quantile_levels,
315
+ yield_per_batch,
316
+ resample_strategy=resample_strategy,
317
+ max_context=self.max_context_length,
318
+ **predict_kwargs,
169
319
  )
170
320
 
171
321
  def forecast_hfdata(
@@ -175,6 +325,7 @@ class ForecastModel(ABC):
175
325
  batch_size: int = 512,
176
326
  quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
177
327
  yield_per_batch: bool = False,
328
+ resample_strategy: Literal["frequency"] | None = None,
178
329
  data_kwargs: dict = {},
179
330
  **predict_kwargs,
180
331
  ):
@@ -198,5 +349,12 @@ class ForecastModel(ABC):
198
349
 
199
350
  batches = get_hfdata_batches(hf_dataset, batch_size, **data_kwargs)
200
351
  return _gen_forecast(
201
- self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
352
+ self._forecast_quantiles,
353
+ batches,
354
+ output_type,
355
+ quantile_levels,
356
+ yield_per_batch,
357
+ resample_strategy=resample_strategy,
358
+ max_context=self.max_context_length,
359
+ **predict_kwargs,
202
360
  )
@@ -7,7 +7,7 @@ from gluonts.dataset.common import Dataset
7
7
  from gluonts.dataset.field_names import FieldName
8
8
  from gluonts.model.forecast import QuantileForecast
9
9
 
10
- from .standard_adapter import _batch_pad_iterable
10
+ from .standard_adapter import _batch_iterable
11
11
 
12
12
  DEF_TARGET_COLUMN = FieldName.TARGET # target
13
13
  DEF_META_COLUMNS = (FieldName.START, FieldName.ITEM_ID)
@@ -27,7 +27,7 @@ def _get_gluon_ts_map(**gluon_kwargs):
27
27
 
28
28
 
29
29
  def get_gluon_batches(gluonDataset: Dataset, batch_size: int, **gluon_kwargs):
30
- return _batch_pad_iterable(map(_get_gluon_ts_map(**gluon_kwargs), gluonDataset), batch_size)
30
+ return _batch_iterable(map(_get_gluon_ts_map(**gluon_kwargs), gluonDataset), batch_size)
31
31
 
32
32
 
33
33
  def format_gluonts_output(quantile_forecasts: torch.Tensor, mean_forecasts, meta: list[dict], quantile_levels):
@@ -4,7 +4,7 @@
4
4
  import datasets
5
5
  import torch
6
6
 
7
- from .standard_adapter import _batch_pad_iterable
7
+ from .standard_adapter import _batch_iterable
8
8
 
9
9
  DEF_TARGET_COLUMN = "target"
10
10
 
@@ -35,4 +35,4 @@ def _get_hf_map(dataset: datasets.Dataset, **hf_kwargs):
35
35
 
36
36
  def get_hfdata_batches(hf_dataset: datasets.Dataset, batch_size: int, **hf_kwargs):
37
37
  dataset, map_func = _get_hf_map(hf_dataset, **hf_kwargs)
38
- return _batch_pad_iterable(map(map_func, dataset), batch_size)
38
+ return _batch_iterable(map(map_func, dataset), batch_size)
@@ -0,0 +1,90 @@
1
+ # Copyright (c) NXAI GmbH.
2
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
+
4
+ import itertools
5
+ from collections.abc import Iterable, Iterator, Sequence
6
+ from typing import Union
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ ContextType = Union[
12
+ torch.Tensor,
13
+ np.ndarray,
14
+ list[torch.Tensor],
15
+ list[np.ndarray],
16
+ ]
17
+
18
+
19
+ def _ensure_1d_tensor(sample) -> torch.Tensor:
20
+ if isinstance(sample, torch.Tensor):
21
+ tensor = sample
22
+ else:
23
+ tensor = torch.as_tensor(sample)
24
+
25
+ if tensor.ndim > 1:
26
+ tensor = tensor.squeeze()
27
+
28
+ assert tensor.ndim == 1, "Each sample must be one-dimensional"
29
+ return tensor
30
+
31
+
32
+ def _batched_slice(
33
+ full_batch,
34
+ full_meta: list[dict] | None,
35
+ batch_size: int,
36
+ ) -> Iterator[tuple[list[torch.Tensor], list[dict]]]:
37
+ total = len(full_batch)
38
+ for start in range(0, total, batch_size):
39
+ batch = full_batch[start : start + batch_size]
40
+ meta = full_meta[start : start + batch_size] if full_meta is not None else [{} for _ in range(len(batch))]
41
+
42
+ batch_series = []
43
+ for idx in range(len(batch)):
44
+ sample = batch[idx]
45
+ tensor = _ensure_1d_tensor(sample)
46
+ batch_series.append(tensor)
47
+
48
+ yield batch_series, meta
49
+
50
+
51
+ def _batched(iterable: Iterable, n: int):
52
+ it = iter(iterable)
53
+ while batch := tuple(itertools.islice(it, n)):
54
+ yield batch
55
+
56
+
57
+ def _batch_iterable(
58
+ iterable: Iterable[tuple[torch.Tensor, dict | None]],
59
+ batch_size: int,
60
+ ) -> Iterator[tuple[list[torch.Tensor], list[dict]]]:
61
+ for batch in _batched(iterable, batch_size):
62
+ series_list: list[torch.Tensor] = []
63
+ meta_list: list[dict] = []
64
+
65
+ for sample, meta in batch:
66
+ tensor = _ensure_1d_tensor(sample)
67
+ assert len(tensor) > 0, "Each sample needs to have a length > 0"
68
+ series_list.append(tensor)
69
+ meta_list.append(meta if meta is not None else {})
70
+
71
+ yield series_list, meta_list
72
+
73
+
74
+ def get_batches(context: ContextType, batch_size: int):
75
+ batches = None
76
+ if isinstance(context, torch.Tensor):
77
+ if context.ndim == 1:
78
+ context = context.unsqueeze(0)
79
+ assert context.ndim == 2
80
+ batches = _batched_slice(context, None, batch_size)
81
+ elif isinstance(context, np.ndarray):
82
+ if context.ndim == 1:
83
+ context = np.expand_dims(context, axis=0)
84
+ assert context.ndim == 2
85
+ batches = _batched_slice(context, None, batch_size)
86
+ elif isinstance(context, (list, Iterable)):
87
+ batches = _batch_iterable(map(lambda x: (torch.Tensor(x), None), context), batch_size)
88
+ if batches is None:
89
+ raise ValueError(f"Context type {type(context)} not supported! Supported Types: {ContextType}")
90
+ return batches
@@ -12,6 +12,7 @@ from huggingface_hub import hf_hub_download
12
12
  from tirex.models.slstm.cell import sLSTMCellTorch
13
13
 
14
14
  T = TypeVar("T", bound="PretrainedModel")
15
+ VERSION_DELIMITER = "-"
15
16
 
16
17
 
17
18
  def skip_cuda():
@@ -32,6 +33,17 @@ def parse_hf_repo_id(path):
32
33
  return "/".join(parts[0:2])
33
34
 
34
35
 
36
+ def parse_model_string(model_string):
37
+ if VERSION_DELIMITER in model_string:
38
+ parts = model_string.split(VERSION_DELIMITER)
39
+ model_id, version = parts[0], parts[1]
40
+ else:
41
+ model_id = model_string
42
+ version = None
43
+
44
+ return model_id, version
45
+
46
+
35
47
  class PretrainedModel(ABC):
36
48
  REGISTRY: dict[str, "PretrainedModel"] = {}
37
49
 
@@ -105,7 +117,8 @@ def load_model(
105
117
  backend = "torch" if skip_cuda() or not xlstm_available() else "cuda"
106
118
 
107
119
  try:
108
- _, model_id = parse_hf_repo_id(path).split("/")
120
+ _, model_string = parse_hf_repo_id(path).split("/")
121
+ model_id, version = parse_model_string(model_string)
109
122
  except:
110
123
  raise ValueError(f"Invalid model path {path}")
111
124
  model_cls = PretrainedModel.REGISTRY.get(model_id, None)