tirex-mirror 2025.10.2__py3-none-any.whl → 2025.10.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,15 +2,22 @@
2
2
  # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
3
 
4
4
  from abc import ABC, abstractmethod
5
- from typing import Literal
5
+ from functools import partial
6
+ from math import ceil
7
+ from typing import Literal, Optional
6
8
 
7
9
  import torch
8
10
 
11
+ from tirex.util import frequency_resample
12
+
9
13
  from .standard_adapter import ContextType, get_batches
10
14
 
11
15
  DEF_TARGET_COLUMN = "target"
12
16
  DEF_META_COLUMNS = ("start", "item_id")
13
17
 
18
+ # Allowed resampling strategies (extend as new strategies are implemented)
19
+ RESAMPLE_STRATEGIES: list[str] = ["frequency"]
20
+
14
21
 
15
22
  def _format_output(
16
23
  quantiles: torch.Tensor,
@@ -33,6 +40,27 @@ def _format_output(
33
40
  raise ValueError(f"Invalid output type: {output_type}")
34
41
 
35
42
 
43
+ def _pad_time_series_batch(
44
+ batch_series: list[torch.Tensor],
45
+ max_length: int,
46
+ ) -> torch.Tensor:
47
+ if not batch_series:
48
+ return torch.empty((0, max_length))
49
+
50
+ first = batch_series[0]
51
+ dtype = first.dtype if first.is_floating_point() else torch.float32
52
+ device = first.device
53
+
54
+ padded = torch.full((len(batch_series), max_length), float("nan"), dtype=dtype, device=device)
55
+
56
+ for idx, series in enumerate(batch_series):
57
+ series = series.to(padded.dtype)
58
+ series_len = series.shape[0]
59
+ padded[idx, max_length - series_len :] = series
60
+
61
+ return padded
62
+
63
+
36
64
  def _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs):
37
65
  for batch_ctx, batch_meta in batches:
38
66
  quantiles, mean = fc_func(batch_ctx, **predict_kwargs)
@@ -45,7 +73,105 @@ def _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwar
45
73
  )
46
74
 
47
75
 
48
- def _gen_forecast(fc_func, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs):
76
+ def _call_fc_with_padding(fc_func, batch_series: list[torch.Tensor], **predict_kwargs):
77
+ if not batch_series:
78
+ raise ValueError("Received empty batch for forecasting")
79
+
80
+ max_len = max(series.shape[0] for series in batch_series)
81
+ padded_ts = _pad_time_series_batch(batch_series, max_len)
82
+
83
+ return fc_func(padded_ts, **predict_kwargs)
84
+
85
+
86
+ def _resample_fc_func_wrapper(
87
+ fc_func,
88
+ batch,
89
+ resample_strategy: str,
90
+ max_context: int = 2016,
91
+ **predict_kwargs,
92
+ ):
93
+ # downsample the time series based on the dominant frequencies, if enabled
94
+ max_period = (max_context // 1000) * 500
95
+ prediction_length = predict_kwargs.get("prediction_length", 100)
96
+ batch_resampled_ts: list[torch.Tensor] = []
97
+ fc_resample_fns = []
98
+ scaling_factors = []
99
+
100
+ # select the function doing the resampling
101
+ ctx_resample_fn = lambda x: (x, 1.0, (lambda y: y))
102
+ match resample_strategy:
103
+ case "frequency":
104
+ ctx_resample_fn = frequency_resample
105
+ case _:
106
+ raise RuntimeError("This shouldn't happen.")
107
+
108
+ for series in batch:
109
+ resampled_ts, _sample_factor, fc_resample_fn = ctx_resample_fn(
110
+ series,
111
+ prediction_length=prediction_length,
112
+ max_period=max_period,
113
+ )
114
+
115
+ batch_resampled_ts.append(resampled_ts)
116
+ fc_resample_fns.append(fc_resample_fn)
117
+ scaling_factors.append(_sample_factor)
118
+
119
+ # Compute per-item required horizons (in downsampled domain)
120
+ per_item_pred_lens = [int(ceil(prediction_length * sf)) for sf in scaling_factors]
121
+ max_pred_len = max(per_item_pred_lens) if per_item_pred_lens else int(prediction_length)
122
+ predict_kwargs.update(prediction_length=max_pred_len)
123
+
124
+ max_ts_length = max(ts.shape[0] for ts in batch_resampled_ts)
125
+ padded_ts = _pad_time_series_batch(batch_resampled_ts, max_ts_length)
126
+ print(f"Average sample batch factor: {sum(scaling_factors) / len(scaling_factors)}")
127
+
128
+ # generate prediction
129
+ fc_quantiles, fc_mean = fc_func(padded_ts, **predict_kwargs)
130
+
131
+ batch_prediction_q = []
132
+ batch_prediction_m = []
133
+ for el_q, el_m, fc_resample_fn, item_pred_len in zip(fc_quantiles, fc_mean, fc_resample_fns, per_item_pred_lens):
134
+ # truncate the forecasts to their individual sample factor adjusted prediction lengths
135
+ el_q = el_q[:item_pred_len, ...] # [T, Q]
136
+ el_m = el_m[:item_pred_len] # [T]
137
+
138
+ # upsample prediction
139
+ quantiles = fc_resample_fn(el_q.squeeze(0).transpose(0, 1)).transpose(0, 1) # [T, Q]
140
+ mean = fc_resample_fn(el_m.squeeze(0))
141
+
142
+ quantiles = quantiles[:prediction_length, ...]
143
+ mean = mean[:prediction_length]
144
+
145
+ batch_prediction_q.append(quantiles)
146
+ batch_prediction_m.append(mean)
147
+
148
+ return torch.stack(batch_prediction_q, dim=0), torch.stack(batch_prediction_m, dim=0)
149
+
150
+
151
+ def _gen_forecast(
152
+ fc_func,
153
+ batches,
154
+ output_type,
155
+ quantile_levels,
156
+ yield_per_batch,
157
+ resample_strategy: str | None = None,
158
+ max_context: int = 2016,
159
+ **predict_kwargs,
160
+ ):
161
+ base_fc_func = fc_func
162
+
163
+ if resample_strategy is not None:
164
+ if resample_strategy not in RESAMPLE_STRATEGIES:
165
+ raise ValueError(f"Invalid resample strategy: {resample_strategy}. Allowed: {RESAMPLE_STRATEGIES}")
166
+ fc_func = partial(
167
+ _resample_fc_func_wrapper,
168
+ base_fc_func,
169
+ resample_strategy=resample_strategy,
170
+ max_context=max_context,
171
+ )
172
+ else:
173
+ fc_func = partial(_call_fc_with_padding, base_fc_func)
174
+
49
175
  if yield_per_batch:
50
176
  return _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs)
51
177
 
@@ -92,6 +218,9 @@ def _common_forecast_doc():
92
218
  forecasts batch by batch as they are computed.
93
219
  Defaults to `False`.
94
220
 
221
+ resample_strategy (Optional[str], optional): Choose a resampling strategy. Allowed values: {RESAMPLE_STRATEGIES}.
222
+ If `None`, no resampling is applied. Currently only "frequency" is supported.
223
+
95
224
  **predict_kwargs: Additional keyword arguments that are passed directly to the underlying
96
225
  prediction mechanism of the pre-trained model. Refer to the model's
97
226
  internal prediction method documentation for available options.
@@ -113,6 +242,11 @@ class ForecastModel(ABC):
113
242
  def _forecast_quantiles(self, batch, **predict_kwargs):
114
243
  pass
115
244
 
245
+ @property
246
+ def max_context_length(self) -> int:
247
+ # retrieve the max_context attribute of the model configuration if present
248
+ return getattr(getattr(self, "config", None), "max_context", 2016)
249
+
116
250
  def forecast(
117
251
  self,
118
252
  context: ContextType,
@@ -120,6 +254,7 @@ class ForecastModel(ABC):
120
254
  batch_size: int = 512,
121
255
  quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
122
256
  yield_per_batch: bool = False,
257
+ resample_strategy: Literal["frequency"] | None = None,
123
258
  **predict_kwargs,
124
259
  ):
125
260
  f"""
@@ -134,7 +269,14 @@ class ForecastModel(ABC):
134
269
  assert batch_size >= 1, "Batch size must be >= 1"
135
270
  batches = get_batches(context, batch_size)
136
271
  return _gen_forecast(
137
- self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
272
+ self._forecast_quantiles,
273
+ batches,
274
+ output_type,
275
+ quantile_levels,
276
+ yield_per_batch,
277
+ resample_strategy=resample_strategy,
278
+ max_context=self.max_context_length,
279
+ **predict_kwargs,
138
280
  )
139
281
 
140
282
  def forecast_gluon(
@@ -144,6 +286,7 @@ class ForecastModel(ABC):
144
286
  batch_size: int = 512,
145
287
  quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
146
288
  yield_per_batch: bool = False,
289
+ resample_strategy: Literal["frequency"] | None = None,
147
290
  data_kwargs: dict = {},
148
291
  **predict_kwargs,
149
292
  ):
@@ -165,7 +308,14 @@ class ForecastModel(ABC):
165
308
 
166
309
  batches = get_gluon_batches(gluonDataset, batch_size, **data_kwargs)
167
310
  return _gen_forecast(
168
- self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
311
+ self._forecast_quantiles,
312
+ batches,
313
+ output_type,
314
+ quantile_levels,
315
+ yield_per_batch,
316
+ resample_strategy=resample_strategy,
317
+ max_context=self.max_context_length,
318
+ **predict_kwargs,
169
319
  )
170
320
 
171
321
  def forecast_hfdata(
@@ -175,6 +325,7 @@ class ForecastModel(ABC):
175
325
  batch_size: int = 512,
176
326
  quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
177
327
  yield_per_batch: bool = False,
328
+ resample_strategy: Literal["frequency"] | None = None,
178
329
  data_kwargs: dict = {},
179
330
  **predict_kwargs,
180
331
  ):
@@ -198,5 +349,12 @@ class ForecastModel(ABC):
198
349
 
199
350
  batches = get_hfdata_batches(hf_dataset, batch_size, **data_kwargs)
200
351
  return _gen_forecast(
201
- self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
352
+ self._forecast_quantiles,
353
+ batches,
354
+ output_type,
355
+ quantile_levels,
356
+ yield_per_batch,
357
+ resample_strategy=resample_strategy,
358
+ max_context=self.max_context_length,
359
+ **predict_kwargs,
202
360
  )
@@ -7,7 +7,7 @@ from gluonts.dataset.common import Dataset
7
7
  from gluonts.dataset.field_names import FieldName
8
8
  from gluonts.model.forecast import QuantileForecast
9
9
 
10
- from .standard_adapter import _batch_pad_iterable
10
+ from .standard_adapter import _batch_iterable
11
11
 
12
12
  DEF_TARGET_COLUMN = FieldName.TARGET # target
13
13
  DEF_META_COLUMNS = (FieldName.START, FieldName.ITEM_ID)
@@ -27,7 +27,7 @@ def _get_gluon_ts_map(**gluon_kwargs):
27
27
 
28
28
 
29
29
  def get_gluon_batches(gluonDataset: Dataset, batch_size: int, **gluon_kwargs):
30
- return _batch_pad_iterable(map(_get_gluon_ts_map(**gluon_kwargs), gluonDataset), batch_size)
30
+ return _batch_iterable(map(_get_gluon_ts_map(**gluon_kwargs), gluonDataset), batch_size)
31
31
 
32
32
 
33
33
  def format_gluonts_output(quantile_forecasts: torch.Tensor, mean_forecasts, meta: list[dict], quantile_levels):
@@ -4,7 +4,7 @@
4
4
  import datasets
5
5
  import torch
6
6
 
7
- from .standard_adapter import _batch_pad_iterable
7
+ from .standard_adapter import _batch_iterable
8
8
 
9
9
  DEF_TARGET_COLUMN = "target"
10
10
 
@@ -35,4 +35,4 @@ def _get_hf_map(dataset: datasets.Dataset, **hf_kwargs):
35
35
 
36
36
  def get_hfdata_batches(hf_dataset: datasets.Dataset, batch_size: int, **hf_kwargs):
37
37
  dataset, map_func = _get_hf_map(hf_dataset, **hf_kwargs)
38
- return _batch_pad_iterable(map(map_func, dataset), batch_size)
38
+ return _batch_iterable(map(map_func, dataset), batch_size)
@@ -16,13 +16,36 @@ ContextType = Union[
16
16
  ]
17
17
 
18
18
 
19
- def _batched_slice(full_batch, full_meta: list[dict] | None, batch_size: int) -> Iterator[tuple[Sequence, list[dict]]]:
20
- if len(full_batch) <= batch_size:
21
- yield full_batch, full_meta if full_meta is not None else [{} for _ in range(len(full_batch))]
19
+ def _ensure_1d_tensor(sample) -> torch.Tensor:
20
+ if isinstance(sample, torch.Tensor):
21
+ tensor = sample
22
22
  else:
23
- for i in range(0, len(full_batch), batch_size):
24
- batch = full_batch[i : i + batch_size]
25
- yield batch, (full_meta[i : i + batch_size] if full_meta is not None else [{} for _ in range(len(batch))])
23
+ tensor = torch.as_tensor(sample)
24
+
25
+ if tensor.ndim > 1:
26
+ tensor = tensor.squeeze()
27
+
28
+ assert tensor.ndim == 1, "Each sample must be one-dimensional"
29
+ return tensor
30
+
31
+
32
+ def _batched_slice(
33
+ full_batch,
34
+ full_meta: list[dict] | None,
35
+ batch_size: int,
36
+ ) -> Iterator[tuple[list[torch.Tensor], list[dict]]]:
37
+ total = len(full_batch)
38
+ for start in range(0, total, batch_size):
39
+ batch = full_batch[start : start + batch_size]
40
+ meta = full_meta[start : start + batch_size] if full_meta is not None else [{} for _ in range(len(batch))]
41
+
42
+ batch_series = []
43
+ for idx in range(len(batch)):
44
+ sample = batch[idx]
45
+ tensor = _ensure_1d_tensor(sample)
46
+ batch_series.append(tensor)
47
+
48
+ yield batch_series, meta
26
49
 
27
50
 
28
51
  def _batched(iterable: Iterable, n: int):
@@ -31,21 +54,21 @@ def _batched(iterable: Iterable, n: int):
31
54
  yield batch
32
55
 
33
56
 
34
- def _batch_pad_iterable(iterable: Iterable[tuple[torch.Tensor, dict]], batch_size: int):
57
+ def _batch_iterable(
58
+ iterable: Iterable[tuple[torch.Tensor, dict | None]],
59
+ batch_size: int,
60
+ ) -> Iterator[tuple[list[torch.Tensor], list[dict]]]:
35
61
  for batch in _batched(iterable, batch_size):
36
- # ctx_it_len, ctx_it_data, it_meta = itertools.tee(batch, 3)
37
- max_len = max(len(el[0]) for el in batch)
38
- padded_batch = []
39
- meta = []
40
- for el in batch:
41
- sample = el[0]
42
- assert isinstance(sample, torch.Tensor)
43
- assert sample.ndim == 1
44
- assert len(sample) > 0, "Each sample needs to have a length > 0"
45
- padding = torch.full(size=(max_len - len(sample),), fill_value=torch.nan, device=sample.device)
46
- padded_batch.append(torch.cat((padding, sample)))
47
- meta.append(el[1])
48
- yield torch.stack(padded_batch), meta
62
+ series_list: list[torch.Tensor] = []
63
+ meta_list: list[dict] = []
64
+
65
+ for sample, meta in batch:
66
+ tensor = _ensure_1d_tensor(sample)
67
+ assert len(tensor) > 0, "Each sample needs to have a length > 0"
68
+ series_list.append(tensor)
69
+ meta_list.append(meta if meta is not None else {})
70
+
71
+ yield series_list, meta_list
49
72
 
50
73
 
51
74
  def get_batches(context: ContextType, batch_size: int):
@@ -59,9 +82,9 @@ def get_batches(context: ContextType, batch_size: int):
59
82
  if context.ndim == 1:
60
83
  context = np.expand_dims(context, axis=0)
61
84
  assert context.ndim == 2
62
- batches = map(lambda x: (torch.Tensor(x[0]), x[1]), _batched_slice(context, None, batch_size))
85
+ batches = _batched_slice(context, None, batch_size)
63
86
  elif isinstance(context, (list, Iterable)):
64
- batches = _batch_pad_iterable(map(lambda x: (torch.Tensor(x), None), context), batch_size)
87
+ batches = _batch_iterable(map(lambda x: (torch.Tensor(x), None), context), batch_size)
65
88
  if batches is None:
66
89
  raise ValueError(f"Context type {type(context)} not supported! Supported Types: {ContextType}")
67
90
  return batches
tirex/base.py CHANGED
@@ -1,6 +1,7 @@
1
1
  # Copyright (c) NXAI GmbH.
2
2
  # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
3
 
4
+ import logging
4
5
  import os
5
6
  from abc import ABC, abstractmethod
6
7
  from typing import Literal, TypeVar
@@ -8,7 +9,10 @@ from typing import Literal, TypeVar
8
9
  import torch
9
10
  from huggingface_hub import hf_hub_download
10
11
 
12
+ from tirex.models.slstm.cell import sLSTMCellTorch
13
+
11
14
  T = TypeVar("T", bound="PretrainedModel")
15
+ VERSION_DELIMITER = "-"
12
16
 
13
17
 
14
18
  def skip_cuda():
@@ -29,6 +33,17 @@ def parse_hf_repo_id(path):
29
33
  return "/".join(parts[0:2])
30
34
 
31
35
 
36
+ def parse_model_string(model_string):
37
+ if VERSION_DELIMITER in model_string:
38
+ parts = model_string.split(VERSION_DELIMITER)
39
+ model_id, version = parts[0], parts[1]
40
+ else:
41
+ model_id = model_string
42
+ version = None
43
+
44
+ return model_id, version
45
+
46
+
32
47
  class PretrainedModel(ABC):
33
48
  REGISTRY: dict[str, "PretrainedModel"] = {}
34
49
 
@@ -38,7 +53,7 @@ class PretrainedModel(ABC):
38
53
 
39
54
  @classmethod
40
55
  def from_pretrained(
41
- cls: type[T], path: str, backend: str, device: str | None = None, hf_kwargs=None, ckp_kwargs=None
56
+ cls: type[T], path: str, backend: str, device: str | None = None, compile=False, hf_kwargs=None, ckp_kwargs=None
42
57
  ) -> T:
43
58
  if hf_kwargs is None:
44
59
  hf_kwargs = {}
@@ -58,9 +73,10 @@ class PretrainedModel(ABC):
58
73
  model: T = cls(backend=backend, **checkpoint["hyper_parameters"])
59
74
  model.on_load_checkpoint(checkpoint)
60
75
  model.load_state_dict(checkpoint["state_dict"])
76
+ model = model.to(device)
61
77
 
62
- if backend == "cuda":
63
- model = model.to(device)
78
+ if compile and backend == "torch":
79
+ sLSTMCellTorch.slstm_forward = torch.compile(sLSTMCellTorch.slstm_forward, mode="max-autotune")
64
80
  return model
65
81
 
66
82
  @classmethod
@@ -76,6 +92,7 @@ def load_model(
76
92
  path: str,
77
93
  device: str | None = None,
78
94
  backend: Literal["torch", "cuda"] | None = None,
95
+ compile: bool = False,
79
96
  hf_kwargs=None,
80
97
  ckp_kwargs=None,
81
98
  ) -> PretrainedModel:
@@ -85,6 +102,7 @@ def load_model(
85
102
  path (str): Hugging Face path to the model (e.g. NX-AI/TiRex)
86
103
  device (str, optional): The device on which to load the model (e.g., "cuda:0", "cpu").
87
104
  backend (torch | cuda): What backend to use, torch or the custom CUDA kernels. Defaults to cuda when xlstm is installed, else torch.
105
+ compile (bool, optional): torch.compile the sLSTM cells, only works with the torch backend
88
106
  hf_kwargs (dict, optional): Keyword arguments to pass to the Hugging Face Hub download method.
89
107
  ckp_kwargs (dict, optional): Keyword arguments to pass when loading the checkpoint.
90
108
 
@@ -99,11 +117,14 @@ def load_model(
99
117
  backend = "torch" if skip_cuda() or not xlstm_available() else "cuda"
100
118
 
101
119
  try:
102
- _, model_id = parse_hf_repo_id(path).split("/")
120
+ _, model_string = parse_hf_repo_id(path).split("/")
121
+ model_id, version = parse_model_string(model_string)
103
122
  except:
104
123
  raise ValueError(f"Invalid model path {path}")
105
124
  model_cls = PretrainedModel.REGISTRY.get(model_id, None)
106
125
  if model_cls is None:
107
126
  raise ValueError(f"Invalid model id {model_id}")
108
127
 
109
- return model_cls.from_pretrained(path, device=device, backend=backend, hf_kwargs=hf_kwargs, ckp_kwargs=ckp_kwargs)
128
+ return model_cls.from_pretrained(
129
+ path, device=device, backend=backend, compile=compile, hf_kwargs=hf_kwargs, ckp_kwargs=ckp_kwargs
130
+ )
@@ -43,13 +43,11 @@ class sLSTMCell(nn.Module):
43
43
  state = self._get_state(input, state)
44
44
 
45
45
  if self.backend == "torch":
46
- all_states = self._impl_torch(input, state)
46
+ output, state = self._impl_torch(input, state)
47
47
  elif self.backend == "cuda":
48
- all_states = self._impl_cuda(input, state)
48
+ output, state = self._impl_cuda(input, state)
49
49
 
50
- state = all_states[:, -1]
51
- output = self._permute_output(all_states[0][1:])
52
- return output.to(input.dtype), state.to(input.dtype)
50
+ return self._permute_output(output).to(input.dtype), state.to(input.dtype)
53
51
 
54
52
  def _impl_torch(self, input: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
55
53
  input = input.to(dtype=torch.bfloat16)
@@ -64,7 +62,7 @@ class sLSTMCell(nn.Module):
64
62
  .reshape(-1)
65
63
  )
66
64
 
67
- return slstm_forward(input, state, recurrent_kernel, bias)[0]
65
+ return sLSTMCellTorch.slstm_forward(input, state, recurrent_kernel, bias)
68
66
 
69
67
  def _impl_cuda(self, input: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
70
68
  if input.device.type != "cuda":
@@ -88,7 +86,7 @@ class sLSTMCell(nn.Module):
88
86
 
89
87
  input = input.permute(0, 1, 3, 2, 4).reshape(input.shape[0], input.shape[1], -1)
90
88
 
91
- return self.func.apply(
89
+ all_states = self.func.apply(
92
90
  False,
93
91
  input.contiguous(),
94
92
  state.contiguous(),
@@ -96,6 +94,10 @@ class sLSTMCell(nn.Module):
96
94
  self._bias_.contiguous(),
97
95
  )
98
96
 
97
+ state = all_states[:, -1]
98
+ output = all_states[0][1:]
99
+ return output, state
100
+
99
101
  def _get_input(self, x: torch.Tensor) -> torch.Tensor:
100
102
  assert x.shape[-1] == self.config.embedding_dim * self.config.num_gates, (
101
103
  f"Input size mismatch: Expected input size {self.config.embedding_dim * self.config.num_gates}, but got {input.size(-1)}."
@@ -119,73 +121,60 @@ class sLSTMCell(nn.Module):
119
121
  return output.permute(1, 2, 0, 3)
120
122
 
121
123
 
122
- def slstm_forward(
123
- x: torch.Tensor, # [S, B, G*I]
124
- states: torch.Tensor, # [4, B, H] only the first is used for recurrence!
125
- R: torch.Tensor, # [K, R*H, H] - K num_heads
126
- b: torch.Tensor, # [T*H]
127
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
128
- num_states = states.shape[0]
129
- sequence_dim = x.shape[0]
130
- # this only works for a fully-connected RNN, for a hin change this
131
- num_gates_r = R.shape[2] // R.shape[1]
132
- hidden_dim = R.shape[1] * R.shape[0]
133
- batch_dim = x.shape[1]
134
- num_heads = R.shape[0]
135
-
136
- assert batch_dim == states.shape[1]
137
- assert hidden_dim == states.shape[2]
138
-
139
- states_all = torch.zeros(
140
- [num_states, sequence_dim + 1, batch_dim, hidden_dim],
141
- device=x.device,
142
- dtype=x.dtype,
143
- )
144
- states_all[:, 0] = states
145
- for i, Wx_t in enumerate(x.unbind(dim=0)):
146
- Ry = (
147
- states[0]
148
- .reshape(batch_dim, num_heads, 1, -1)
149
- .matmul(R.unsqueeze(0))
150
- .reshape(batch_dim, num_heads, num_gates_r, -1)
151
- .transpose(1, 2)
152
- .reshape(batch_dim, -1)
153
- )
154
- sdtype = states.dtype
155
- Wx_t, Ry, b, states = Wx_t.float(), Ry.float(), b.float(), states.float()
156
- states, gates = slstm_forward_pointwise(Wx_t, Ry, b, states)
157
- states = states.to(dtype=sdtype)
158
- states_all[:, i + 1] = states
159
-
160
- # shapes ([S, B, H], ([B,H], [B,H], [B,H])
161
- return states_all, states
162
-
163
-
164
- def slstm_forward_pointwise(
165
- Wx: torch.Tensor, # dim [B, 4*H]
166
- Ry: torch.Tensor, # dim [B, 4*H]
167
- b: torch.Tensor, # dim [1, 4*H]
168
- states: torch.Tensor, # dim [4, B, H]
169
- ) -> tuple[torch.Tensor, torch.Tensor]:
170
- raw = Wx + Ry + b
171
-
172
- iraw, fraw, zraw, oraw = torch.unbind(raw.view(raw.shape[0], 4, -1), dim=1)
173
- y, c, n, m = torch.unbind(states.view(4, states.shape[1], -1), dim=0)
174
-
175
- # with torch.no_grad(): # THE difference to maxg aka max_gradient (here max / max_static)
176
- # Equations reference the xlstm paper on page 4: https://arxiv.org/pdf/2405.04517
177
- logfplusm = m + F.logsigmoid(fraw) # eq 15
178
- if torch.all(n == 0.0):
179
- mnew = iraw
180
- else:
181
- mnew = torch.max(iraw, logfplusm) # eq 15
182
- ogate = torch.sigmoid(oraw) # eq 14
183
- igate = torch.minimum(torch.exp(iraw - mnew), torch.ones_like(iraw)) # eq 16
184
- fgate = torch.minimum(torch.exp(logfplusm - mnew), torch.ones_like(iraw)) # eq 17
185
- zgate = torch.tanh(zraw) # eq 11
186
- cnew = fgate * c + igate * zgate # eq 8
187
- nnew = fgate * n + igate # eq 9
188
- hnew = ogate * cnew / nnew # eq 10
189
-
190
- # y (4, B, H), state (4, B, H)
191
- return torch.stack((hnew, cnew, nnew, mnew), dim=0), torch.stack((igate, fgate, zraw, ogate), dim=0)
124
+ class sLSTMCellTorch:
125
+ @staticmethod
126
+ def slstm_forward(
127
+ x: torch.Tensor, # [S, B, G*I]
128
+ states: torch.Tensor, # [4, B, H] only the first is used for recurrence!
129
+ R: torch.Tensor, # [K, R*H, H] - K num_heads
130
+ b: torch.Tensor, # [T*H]
131
+ ) -> tuple[torch.Tensor, torch.Tensor]:
132
+ num_gates = 4
133
+ num_heads = R.shape[0]
134
+ S, B, _ = x.shape
135
+ H = R.shape[1] * num_heads
136
+ assert states.shape == (num_gates, B, H)
137
+
138
+ states = states.to(R.dtype).unbind(dim=0)
139
+ output = []
140
+ for i in range(S):
141
+ Ry = (
142
+ states[0]
143
+ .reshape(B, num_heads, 1, -1)
144
+ .matmul(R.unsqueeze(0))
145
+ .reshape(B, num_heads, num_gates, -1)
146
+ .transpose(1, 2)
147
+ .reshape(B, -1)
148
+ )
149
+ states = sLSTMCellTorch.slstm_forward_pointwise(
150
+ x[i].float(), Ry.float(), b.float(), [s.float() for s in states]
151
+ )
152
+ states = [s.to(dtype=R.dtype) for s in states]
153
+ output.append(states[0])
154
+
155
+ return torch.stack(output), torch.stack(states) # (S, B, H), 4 x (B, H)
156
+
157
+ @staticmethod
158
+ def slstm_forward_pointwise(
159
+ Wx: torch.Tensor, # dim [B, 4*H]
160
+ Ry: torch.Tensor, # dim [B, 4*H]
161
+ b: torch.Tensor, # dim [1, 4*H]
162
+ states: torch.Tensor, # dim 4 x [B, H]
163
+ ) -> list[torch.Tensor]:
164
+ y, c, n, m = states
165
+
166
+ raw = Wx + Ry + b
167
+ iraw, fraw, zraw, oraw = torch.unbind(raw.view(raw.shape[0], 4, -1), dim=1)
168
+
169
+ # Equations reference the xlstm paper on page 4: https://arxiv.org/pdf/2405.04517
170
+ logfplusm = m + F.logsigmoid(fraw) # eq 15
171
+ mnew = torch.where(torch.all(n == 0.0), iraw, torch.max(iraw, logfplusm)) # eq 15
172
+ ogate = torch.sigmoid(oraw) # eq 14
173
+ igate = torch.minimum(torch.exp(iraw - mnew), torch.ones_like(iraw)) # eq 16
174
+ fgate = torch.minimum(torch.exp(logfplusm - mnew), torch.ones_like(iraw)) # eq 17
175
+ zgate = torch.tanh(zraw) # eq 11
176
+ cnew = fgate * c + igate * zgate # eq 8
177
+ nnew = fgate * n + igate # eq 9
178
+ hnew = ogate * cnew / nnew # eq 10
179
+
180
+ return [hnew, cnew, nnew, mnew] # 4 x (B, H)
tirex/models/tirex.py CHANGED
@@ -179,8 +179,25 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
179
179
  quantile_preds = torch.transpose(quantile_preds, 1, 2) # switch quantile and num_token_dimension
180
180
  # quantile_preds: [batch_size, num_quantiles, num_token, output_patch_size]
181
181
 
182
+ quantile_preds = self._forward_model(torch.cat((input_token, input_mask), dim=2))
183
+
184
+ quantile_preds = torch.unflatten(
185
+ quantile_preds, -1, (len(self.config.quantiles), self.config.output_patch_size)
186
+ )
187
+ quantile_preds = torch.transpose(quantile_preds, 1, 2) # switch quantile and num_token_dimension
188
+ # quantile_preds: [batch_size, num_quantiles, num_token, output_patch_size]
182
189
  return quantile_preds, hidden_states
183
190
 
191
+ def _forward_model(self, input: torch.Tensor):
192
+ hidden_states = self.input_patch_embedding(input)
193
+
194
+ for block in self.blocks:
195
+ hidden_states = block(hidden_states)
196
+
197
+ hidden_states = self.out_norm(hidden_states)
198
+
199
+ return self.output_patch_embedding(hidden_states)
200
+
184
201
  def _interpolate_quantiles(self, predictions: torch.Tensor, quantile_levels: list[float]):
185
202
  training_quantile_levels = self.config.quantiles
186
203
  if min(quantile_levels) < min(training_quantile_levels) or max(quantile_levels) > max(training_quantile_levels):
tirex/util.py CHANGED
@@ -1,7 +1,611 @@
1
1
  # Copyright (c) NXAI GmbH.
2
2
  # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
3
 
4
+ from collections.abc import Callable
4
5
  from dataclasses import fields
6
+ from functools import partial
7
+ from math import ceil
8
+ from typing import Literal, Optional
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+
14
def frequency_resample(
    ts: torch.Tensor,
    prediction_length: int,
    patch_size: int = 64,
    peak_prominence: float = 0.1,
    selection_method: Literal["low_harmonic", "high_harmonic", "highest_amplitude"] = "low_harmonic",
    min_period: int | None = None,
    max_period: int = 1000,
    bandpass_filter: bool = True,
    nifr_enabled: bool = True,
    nifr_start_integer: int = 2,
    nifr_end_integer: int = 12,
    nifr_clamp_large_factors: bool = False,
) -> tuple[torch.Tensor, float, Callable[[torch.Tensor], torch.Tensor]]:
    """
    Downsample a series based on its dominant frequency and return an inverse mapper.

    The sampling factor is estimated with :func:`frequency_factor` so that roughly one
    dominant period fits into one patch of ``patch_size`` samples. The factor is never
    allowed to exceed 1 (the context is never upsampled), and resampling is disabled
    entirely for short horizons.

    Parameters
    ----------
    ts : torch.Tensor
        1D time series of shape [T].
    prediction_length : int
        Requested forecast horizon; horizons below 100 disable resampling.
    patch_size : int, default 64
        Nominal patch size used to align one dominant period to one patch.
    peak_prominence : float, default 0.1
        Threshold for FFT peak detection on the normalized spectrum.
    selection_method : {"low_harmonic", "high_harmonic", "highest_amplitude"}, default "low_harmonic"
        How to resolve two dominant peaks in ~2x harmonic relation.
    min_period : int or None, optional
        Minimum period to consider; if None, defaults to ``patch_size``.
    max_period : int, default 1000
        Maximum period to consider for FFT peak search.
    bandpass_filter : bool, default True
        If True, suppresses very low frequencies before peak search.
    nifr_enabled : bool, default True
        Enable nearest-integer-fraction rounding of the factor.
    nifr_start_integer : int, default 2
        Smallest integer k used for the 1/k grid when NIFR is enabled.
    nifr_end_integer : int, default 12
        Largest integer k used for the 1/k grid when NIFR is enabled.
    nifr_clamp_large_factors : bool, default False
        If True, clamps large factors in [1, 1/nifr_start_integer] to 1.0.

    Returns
    -------
    resampled_ts : torch.Tensor
        The input series resampled by the chosen factor (identity when the factor is 1).
    sample_factor : float
        Applied sampling factor (<= 1 means downsampling; 1.0 = identity).
    fc_resample_fn : Callable[[torch.Tensor], torch.Tensor]
        Maps a forecast at the resampled resolution back to the original one
        (resamples with the inverse factor).
    """
    factor = frequency_factor(
        ts,
        max_period=max_period,
        min_period=min_period,
        bandpass_filter=bandpass_filter,
        selection_method=selection_method,
        peak_prominence=peak_prominence,
        patch_size=patch_size,
        nifr_enabled=nifr_enabled,
        nifr_start_integer=nifr_start_integer,
        nifr_end_integer=nifr_end_integer,
        nifr_clamp_large_factors=nifr_clamp_large_factors,
    )

    # Never upsample the context: clamp the factor to at most 1.
    if factor > 1:
        factor = 1.0

    # Short horizons are forecast at the native resolution.
    if prediction_length < 100:
        factor = 1.0

    # Forecasts are mapped back to the original resolution with the inverse factor.
    fc_resample_fn = partial(resample, sample_rate=1 / factor)

    return resample(ts, sample_rate=factor), factor, fc_resample_fn
98
+
99
+
100
def frequency_factor(
    ts: torch.Tensor,
    patch_size: int = 64,  # This doesn't have to match model patch size, but rather the 'target frequency'
    peak_prominence: float = 0.1,
    selection_method: Literal["low_harmonic", "high_harmonic", "highest_amplitude"] = "low_harmonic",
    min_period: int | None = None,
    max_period: int = 1000,
    bandpass_filter: bool = True,
    nifr_enabled: bool = False,
    nifr_start_integer: int = 2,
    nifr_end_integer: int = 12,
    nifr_clamp_large_factors: bool = False,
) -> float:
    """
    Estimate a sampling factor from the dominant FFT frequency of a 1D series.

    The factor is ``patch_size / period`` with ``period = 1 / f*``, where ``f*`` is the
    selected dominant frequency of the one-sided FFT (NaNs are linearly interpolated for
    analysis only). If the two strongest peaks are roughly one octave apart (frequency
    ratio in [1.5, 2.5]), ``selection_method`` decides which harmonic wins. Several
    guards map degenerate situations (short series, no prominent peak, non-finite or
    non-positive results) to the identity factor 1.0. With NIFR enabled, the factor is
    snapped to the closest value in ``{1} ∪ {1/k | k ∈ [nifr_start_integer, nifr_end_integer]}``.

    Parameters
    ----------
    ts : torch.Tensor
        Input 1D series (last dim is time). NaNs are interpolated for the FFT only;
        the original series is not modified.
    patch_size : int, default 64
        Target number of samples per period.
    peak_prominence : float, default 0.1
        Minimum normalized spectrum height to treat a bin as a peak.
    selection_method : {"low_harmonic", "high_harmonic", "highest_amplitude"}, default "low_harmonic"
        Rule for picking between two ~2x related peaks.
    min_period : int or None, optional
        Minimum period to consider. If None, defaults to ``patch_size``.
    max_period : int, default 1000
        Series shorter than ``2 * max_period`` return 1.0.
    bandpass_filter : bool, default True
        If True, very low frequencies below ``1 / max_period`` are suppressed.
    nifr_enabled : bool, default False
        Enable nearest-integer-fraction rounding of the factor.
    nifr_start_integer : int, default 2
        Smallest integer k used for the 1/k grid when NIFR is enabled.
    nifr_end_integer : int, default 12
        Largest integer k used for the 1/k grid when NIFR is enabled.
    nifr_clamp_large_factors : bool, default False
        If True, clamps factors in [1/nifr_start_integer, 1] to 1.0.

    Returns
    -------
    float
        The sampling factor, or 1.0 when no valid dominant frequency can be determined.
    """
    if min_period is None:
        # NOTE: be careful when min_period is not matching patch_size, it can create unexpected scaling factors!
        min_period = patch_size

    # FFT analysis runs on a CPU numpy copy of the series.
    series = ts.detach().cpu().numpy() if isinstance(ts, torch.Tensor) else np.asarray(ts)

    # NOTE: with fewer than two full max_period cycles the FFT peaks are unreliable,
    # so no scaling is applied.
    if series.size < 2 * max_period:
        return 1.0

    freqs, _specs, peaks = run_fft_analysis(
        series,
        scaling="amplitude",
        peak_prominence=peak_prominence,
        min_period=min_period,
        max_period=max_period,
        bandpass_filter=bandpass_filter,
    )

    # No detectable peaks -> keep original sampling.
    if peaks.size == 0:
        return 1.0

    # Default candidate: the highest-amplitude peak.
    best = int(peaks[0])

    # With two peaks, check for a ~2x harmonic relation and let the
    # selection method pick the lower or higher harmonic.
    if peaks.size >= 2:
        second = int(peaks[1])  # second-highest amplitude
        freq_best = float(freqs[best])
        freq_second = float(freqs[second])

        low_f, high_f = sorted((freq_best, freq_second))

        if low_f > 0 and 1.5 <= high_f / low_f <= 2.5:
            if selection_method == "low_harmonic":
                best = best if freq_best < freq_second else second
            elif selection_method == "high_harmonic":
                best = best if freq_best > freq_second else second

    dominant_freq = float(freqs[best])

    # Guard against zero or non-finite frequency.
    if not np.isfinite(dominant_freq) or dominant_freq <= 0:
        return 1.0

    # One period should fit one patch: factor = patch_size / period.
    factor = round(resampling_factor(1.0 / dominant_freq, patch_size), 4)

    # Guard against non-finite or non-positive factors.
    if not np.isfinite(factor) or factor <= 0:
        return 1.0

    # Nearest-integer-fraction rounding (NIFR).
    if nifr_enabled:
        grid = np.concatenate([[1], 1 / np.arange(nifr_start_integer, nifr_end_integer + 1)])
        factor = grid[np.argmin(np.abs(factor - grid))]

        if nifr_clamp_large_factors:
            # Anything between 1/nifr_start_integer and 1 counts as "no scaling".
            if factor >= grid[1]:
                factor = 1

    return float(factor)
239
+
240
+
241
def resample(ts: torch.Tensor, sample_rate: float, window_position: str = "center") -> torch.Tensor:
    """
    Resample the time series along its last dimension.

    - If sample_rate > 1 the series is upsampled via NaN-aware linear interpolation
      between the two nearest source samples.
    - If sample_rate < 1 the series is downsampled via NaN-tolerant window averaging
      with window size 1/sample_rate.
    - If sample_rate == 1 (or sample_rate <= 0, which is invalid) the series is
      returned unchanged, cast to float for NaN support.

    Window alignment (downsampling only) controlled by `window_position`:
    - "center": average over [c - L/2, c + L/2]
    - "left"  : average over [c - L, c]
    - "right" : average over [c, c + L]

    The window is truncated at the boundaries. NaNs are ignored via nan-mean semantics;
    if a window contains only NaNs, the output is NaN.

    Arguments:
    ----------
    ts: torch.Tensor of shape [..., T]
        The time series to be rescaled (last dim is time).
    sample_rate: float
        The factor determining the final number of timesteps in the series, i.e., T' = ceil(T * sample_rate).
    window_position: {"center", "left", "right"}
        Placement of the averaging window relative to each target coordinate
        (only used when downsampling).

    Returns:
    --------
    torch.Tensor of shape [..., ceil(T * sample_rate)] with dtype float.

    Raises:
    -------
    ValueError
        If `window_position` is invalid (checked only on the downsampling path).
    """
    # Validate inputs
    if sample_rate <= 0 or sample_rate == 1:
        # Invalid or no scaling; return original as float
        return ts.to(torch.float)

    src_num_timesteps = ts.shape[-1]
    tgt_num_timesteps = ceil(src_num_timesteps * sample_rate)

    # Target coordinates expressed in source-index space.
    # Do not change coordinate creation logic
    src_coords = torch.arange(src_num_timesteps, device=ts.device)
    tgt_coords = torch.linspace(0, src_num_timesteps - 1, tgt_num_timesteps, device=ts.device)

    # Branch: upsampling -> linear interpolation between nearest neighbors (NaN-aware)
    if sample_rate > 1:
        # Neighbour indices for each target coordinate along the last dimension
        tgt_in_src_idx_lo = tgt_coords.floor().to(torch.long)
        tgt_in_src_idx_hi = tgt_coords.ceil().to(torch.long)

        # Distances in index space and offsets from lower index
        dist = src_coords[tgt_in_src_idx_hi] - src_coords[tgt_in_src_idx_lo]

        # Work in float for NaN support; gather neighbour values
        src_lo_vals = ts[..., tgt_in_src_idx_lo].to(torch.float)
        src_hi_vals = ts[..., tgt_in_src_idx_hi].to(torch.float)
        diff = src_hi_vals - src_lo_vals
        offset = tgt_coords - src_coords[tgt_in_src_idx_lo]

        # Allocate output
        tgt_values = torch.empty(*ts.shape[:-1], tgt_num_timesteps, dtype=torch.float, device=ts.device)

        # Masks: target coordinates landing exactly on a source index vs. between two
        exact_mask = dist == 0
        interp_mask = ~exact_mask

        # Exact source index -> take the source value
        if exact_mask.any():
            tgt_values[..., exact_mask] = src_lo_vals[..., exact_mask]

        # Linear interpolate where indices differ
        if interp_mask.any():
            tgt_values[..., interp_mask] = (
                diff[..., interp_mask] / dist[interp_mask].to(torch.float) * offset[interp_mask]
                + src_lo_vals[..., interp_mask]
            )

        # Propagate NaNs from either neighbour
        nan_mask = torch.isnan(src_lo_vals) | torch.isnan(src_hi_vals)
        if nan_mask.any():
            tgt_values[..., nan_mask] = torch.nan

        return tgt_values

    # Branch: downsampling -> NaN-tolerant window averaging.
    # Window length in source-index units
    L = 1.0 / sample_rate
    half_L = 0.5 * L

    if window_position == "center":
        left_f = tgt_coords - half_L
        right_f = tgt_coords + half_L
    elif window_position == "left":
        left_f = tgt_coords - L
        right_f = tgt_coords
    elif window_position == "right":
        left_f = tgt_coords
        right_f = tgt_coords + L
    else:
        raise ValueError("window_position must be one of {'center','left','right'}")

    # Convert to integer indices, inclusive bounds
    left_idx = torch.ceil(left_f).to(torch.long)
    right_idx = torch.floor(right_f).to(torch.long)

    # Clip to valid range and ensure non-empty windows (at least one index)
    left_idx = torch.clamp(left_idx, 0, src_num_timesteps - 1)
    right_idx = torch.clamp(right_idx, 0, src_num_timesteps - 1)
    right_idx = torch.maximum(right_idx, left_idx)

    # Prepare cumulative sums for fast [l, r] segment nan-mean along the last dim
    ts_float = ts.to(torch.float)
    valid_mask = ~torch.isnan(ts_float)

    values_filled = torch.where(valid_mask, ts_float, torch.zeros_like(ts_float))
    counts = valid_mask.to(torch.float)

    cumsum_vals = values_filled.cumsum(dim=-1)
    cumsum_cnts = counts.cumsum(dim=-1)

    # Pad a leading zero to make inclusive range sums easy: sum[l:r] = cs[r] - cs[l-1]
    pad_shape = (*ts.shape[:-1], 1)
    zeros_vals = torch.zeros(pad_shape, dtype=cumsum_vals.dtype, device=ts.device)
    zeros_cnts = torch.zeros(pad_shape, dtype=cumsum_cnts.dtype, device=ts.device)
    cumsum_vals = torch.cat([zeros_vals, cumsum_vals], dim=-1)
    cumsum_cnts = torch.cat([zeros_cnts, cumsum_cnts], dim=-1)

    # Build broadcastable indices for gather along the last dim
    prefix_shape = ts.shape[:-1]
    target_len = tgt_num_timesteps

    def _expand_index(idx: torch.Tensor) -> torch.Tensor:
        # idx shape [target_len] -> [..., target_len]
        view_shape = (1,) * len(prefix_shape) + (target_len,)
        return idx.view(view_shape).expand(*prefix_shape, target_len)

    # For inclusive [l, r], use cumsum at (r+1) and (l)
    r_plus1 = torch.clamp(right_idx + 1, 0, src_num_timesteps)
    l_idx = left_idx

    r_plus1_exp = _expand_index(r_plus1)
    l_exp = _expand_index(l_idx)

    seg_sums = cumsum_vals.gather(dim=-1, index=r_plus1_exp) - cumsum_vals.gather(dim=-1, index=l_exp)
    seg_cnts = cumsum_cnts.gather(dim=-1, index=r_plus1_exp) - cumsum_cnts.gather(dim=-1, index=l_exp)

    # Compute nan-mean: where count==0 -> NaN
    with torch.no_grad():
        safe_cnts = torch.where(seg_cnts > 0, seg_cnts, torch.ones_like(seg_cnts))
        averages = seg_sums / safe_cnts
        averages = torch.where(seg_cnts > 0, averages, torch.full_like(averages, float("nan")))

    return averages
393
+
394
+
395
def run_fft_analysis(
    y,
    dt: float = 1.0,
    window: str = "hann",
    detrend: bool = True,
    scaling: str = "amplitude",
    peak_prominence: float = 0.1,
    min_period: int = 64,
    max_period: int = 1000,
    bandpass_filter: bool = True,
):
    """
    One-sided FFT frequencies, normalized spectrum magnitude, and peak indices for a real 1D signal.

    Parameters
    ----------
    y : array_like
        1D time series (regularly sampled). NaNs are linearly interpolated first.
    dt : float
        Sampling period (time between samples). Frequencies are cycles per unit of dt.
    window : {'hann', None}
        Optional taper to reduce spectral leakage.
    detrend : bool
        If True, remove the mean before the FFT.
    scaling : {'amplitude', 'power', 'raw'}
        - 'amplitude': one-sided amplitude spectrum with window-power compensation.
        - 'power'    : one-sided power (not density) with window-power compensation.
        - 'raw'      : |rfft(yw)| (no normalization, mostly for debugging).
    peak_prominence : float
        Absolute threshold on the normalized spectrum for peak detection.
    min_period : int
        Minimum period forwarded to the peak search.
    max_period : int
        Maximum period forwarded to the peak search.
    bandpass_filter : bool
        If True, the peak search suppresses frequencies below 1 / max_period.

    Returns
    -------
    f : ndarray
        Non-negative frequencies, length N//2 + 1, in cycles per unit (1/dt).
    spec : ndarray
        Spectrum for `f` under the chosen `scaling`, normalized by its maximum.
    peaks_idx : ndarray
        Indices into `f` of detected peaks (at most 2, strongest first).
    """
    signal = np.asarray(y, dtype=float)
    if signal.ndim != 1:
        signal = signal.reshape(-1)
    n = signal.size
    if n < 2:
        return np.array([]), np.array([]), np.array([])

    # NaNs (including edge NaNs) are filled by linear interpolation before analysis.
    signal = _nan_linear_interpolate(signal)

    if detrend:
        signal = signal - np.mean(signal)

    # Optional taper; track average window power for amplitude/power normalization.
    if window == "hann":
        taper = np.hanning(n)
        tapered = signal * taper
        w_power = np.sum(taper**2) / n
    elif window is None:
        tapered = signal
        w_power = 1.0
    else:
        raise ValueError("window must be either 'hann' or None")

    # One-sided FFT
    spectrum = np.fft.rfft(tapered)
    f = np.fft.rfftfreq(n, d=dt)  # cycles per unit time

    if scaling == "raw":
        spec = np.abs(spectrum)
    elif scaling == "amplitude":
        # One-sided amplitude with window power compensation.
        spec = np.abs(spectrum) / (n * np.sqrt(w_power))
        # Fold in the negative frequencies: double all bins except DC
        # (and, for even n, the Nyquist bin).
        upper = -1 if n % 2 == 0 else None
        spec[1:upper] *= 2.0
    elif scaling == "power":
        # One-sided power (not PSD).
        spec = (np.abs(spectrum) ** 2) / (n**2 * w_power)
        upper = -1 if n % 2 == 0 else None
        spec[1:upper] *= 2.0
    else:
        raise ValueError("scaling must be 'amplitude', 'power', or 'raw'")

    # Normalize by the spectrum maximum (if non-degenerate).
    top = spec.max()
    if top > 0:
        spec = spec / top

    # Locate the prominent peaks.
    peaks_idx = custom_find_peaks(
        f,
        spec,
        max_peaks=2,
        prominence_threshold=peak_prominence,
        min_period=min_period,
        max_period=max_period,
        bandpass_filter=bandpass_filter,
    )

    return f, spec, peaks_idx
499
+
500
+
501
def _nan_linear_interpolate(y: np.ndarray) -> np.ndarray:
    """Return *y* as a flat float32 array with non-finite entries filled by linear interpolation.

    Edge gaps are extended with the nearest finite value (np.interp semantics).
    An input with no finite value at all yields an all-zero array.
    """
    values = y.astype(np.float32)
    if values.ndim != 1:
        values = values.reshape(-1)
    finite = np.isfinite(values)
    if finite.all():
        return values
    if not finite.any():
        return np.zeros(values.size, dtype=np.float32)
    positions = np.arange(values.size)
    filled = values.copy()
    filled[~finite] = np.interp(positions[~finite], positions[finite], values[finite])
    return filled
515
+
516
+
517
def resampling_factor(inverted_freq, path_size):
    """
    Scaling factor that maps one period (`inverted_freq`) onto `path_size` samples.

    Non-positive periods fall back to the identity factor 1.0.
    """
    # NOTE(review): `path_size` is presumably a typo for `patch_size`; kept for
    # backward compatibility with keyword callers.
    return path_size / inverted_freq if inverted_freq > 0 else 1.0
525
+
526
+
527
def custom_find_peaks(
    f,
    spec,
    *,
    max_peaks=5,
    prominence_threshold=0.1,
    min_period=64,
    max_period=1000,
    bandpass_filter=True,
):
    """
    Locate prominent peaks in a normalized spectrum with simple custom logic.

    A peak is a strict local maximum (first bin and last two bins excluded) whose
    height exceeds `prominence_threshold`. If any prominent peak corresponds to a
    period shorter than `min_period`, the whole spectrum is rejected. Remaining
    peaks are filtered to periods in [min_period, max_period], sorted by height,
    and at most `max_peaks` indices are returned.

    Parameters
    ----------
    f : np.ndarray
        Frequency array matching `spec`.
    spec : np.ndarray
        The normalized spectrum.
    max_peaks : int
        Maximum number of peak indices to return.
    prominence_threshold : float
        Minimum height for a peak to be considered prominent.
    min_period : int
        Shortest admissible period (1 / frequency).
    max_period : int
        Longest admissible period; with `bandpass_filter`, slower components are zeroed.
    bandpass_filter : bool
        If True, suppress frequencies below 1 / max_period before the search.

    Returns
    -------
    np.ndarray
        Indices of the detected peaks in `spec`, strongest first; empty if none qualify.
    """
    no_peaks = np.array([], dtype=int)

    # At least 5 bins are needed since the last two are always excluded.
    if len(spec) < 5:
        return no_peaks

    if bandpass_filter:
        # Only low frequencies are truly filtered here; fast ones are rejected below.
        spec = spec * (f >= 1 / max_period)

    # Strict local maxima, excluding the last two bins.
    maxima = [i for i in range(1, len(spec) - 2) if spec[i - 1] < spec[i] > spec[i + 1]]
    if not maxima:
        return no_peaks

    # Keep only maxima above the prominence threshold.
    candidates = [(i, spec[i]) for i in maxima if spec[i] > prominence_threshold]
    if not candidates:
        return no_peaks

    # A clear peak faster than min_period disqualifies the spectrum entirely (lowpass guard).
    if any(1 / f[i] < min_period for i, _ in candidates):
        return no_peaks

    # Restrict to the admissible period band.
    in_band = [(i, height) for i, height in candidates if min_period <= 1 / f[i] <= max_period]
    if not in_band:
        return no_peaks

    # Strongest peaks first, capped at max_peaks.
    in_band.sort(key=lambda pair: pair[1], reverse=True)
    return np.array([i for i, _ in in_band[:max_peaks]], dtype=int)
5
609
 
6
610
 
7
611
  def round_up_to_next_multiple_of(x: int, multiple_of: int) -> int:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tirex-mirror
3
- Version: 2025.10.2
3
+ Version: 2025.10.7
4
4
  Summary: Unofficial mirror of NX-AI/tirex for packaging
5
5
  Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
6
6
  License: NXAI COMMUNITY LICENSE AGREEMENT
@@ -0,0 +1,21 @@
1
+ tirex/__init__.py,sha256=rfsOeCJ7eRqU3K3TOhfN5-4XUuZFqt11wBRxk5SoAWA,292
2
+ tirex/base.py,sha256=JZuWzezuDenrlDekMaaPLBWzB9mw6vJ_qF2GWjou8ws,4229
3
+ tirex/util.py,sha256=AQ8eM-7IpGeAqegVqflrjPQITzIEKltRcJOBbJmikcA,22758
4
+ tirex/api_adapter/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
5
+ tirex/api_adapter/forecast.py,sha256=CpX9YfiQ1X6nvzODklCPzuKPlYidJDha8FnEU-6zr1Q,13919
6
+ tirex/api_adapter/gluon.py,sha256=CbqL8ZgSRvxWAHK6TWqKdxUrMVWmGskRWOHNF84Lh1U,1819
7
+ tirex/api_adapter/hf_data.py,sha256=TRyys2xKIGZS0Yhq2Eb61lWCMg5CWWn1yRlLIN1mU7o,1369
8
+ tirex/api_adapter/standard_adapter.py,sha256=vdlxNs8mTUtPgK_5WMqYqNdMj8W44igqWsAgtggt_xk,2809
9
+ tirex/models/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
10
+ tirex/models/patcher.py,sha256=EOXFkHsPkq0nuxRNLAbnrgJtcYq0IMC3YIg_16WArg4,3213
11
+ tirex/models/tirex.py,sha256=JKNuCzTI6B9_yCbcmTf2UFjAQXulLNEmloqtAhKJKjQ,9830
12
+ tirex/models/slstm/block.py,sha256=V91Amgz8WAOOHo4fK1UZxd4Dgbx4-X6kUBS6X4m0tKQ,2006
13
+ tirex/models/slstm/cell.py,sha256=JfCs1aUy9IHuz9RwExhUwiUtbg8WmbEg4upcO7hA5Rg,7229
14
+ tirex/models/slstm/layer.py,sha256=93CAYuG-HmUpF7mBAQ-z1S1u2__W10EW5jPToR57qqM,2747
15
+ tirex_mirror-2025.10.7.dist-info/licenses/LICENSE,sha256=HlwHKnGTlE2oNm6734V-Vy62zlkWohnuZpYXSdkqDk4,7362
16
+ tirex_mirror-2025.10.7.dist-info/licenses/LICENSE_MIRROR.txt,sha256=ulPZMcOZdN7JvISjiID3KUwovTjrPwiMv5ku9dM7nls,496
17
+ tirex_mirror-2025.10.7.dist-info/licenses/NOTICE.txt,sha256=rcgDscFHb-uuZO3L0_vIxYhTYl-a2Rm0lBpp3_kKdFQ,147
18
+ tirex_mirror-2025.10.7.dist-info/METADATA,sha256=KdlklRtYjCWVG3GB12RFNi1fPxYO0qmVqlrq6r7mURs,11443
19
+ tirex_mirror-2025.10.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
20
+ tirex_mirror-2025.10.7.dist-info/top_level.txt,sha256=AOLDhfv0F_7nn3pFq0Kapg6Ky_28I_cGDXzQX3w9eO4,6
21
+ tirex_mirror-2025.10.7.dist-info/RECORD,,
@@ -1,21 +0,0 @@
1
- tirex/__init__.py,sha256=rfsOeCJ7eRqU3K3TOhfN5-4XUuZFqt11wBRxk5SoAWA,292
2
- tirex/base.py,sha256=fwyUTGL103kK5jgK5MoLSIHQcZb4lrox_D9fNbY1W1k,3507
3
- tirex/util.py,sha256=7DFVBXwGQA4niT9VhYbt8iKMBINJVW4LfwwpggFS0Us,469
4
- tirex/api_adapter/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
5
- tirex/api_adapter/forecast.py,sha256=snv0sT1_1WzjkhP1YV-I7CMQmSChl93qFc3b6fwUAS0,8502
6
- tirex/api_adapter/gluon.py,sha256=faiYyn0kBBVQKbpWqrVoyylxZUrmr-qce66twpguVds,1827
7
- tirex/api_adapter/hf_data.py,sha256=T1eaxqC3OO9yOzIvw4sr55x6iA2AHKJTZd36rROM4fQ,1377
8
- tirex/api_adapter/standard_adapter.py,sha256=bI3XGYlWQu5EDyhDZyYqOJMbwi5h1aovPQvfHuWETJk,2618
9
- tirex/models/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
10
- tirex/models/patcher.py,sha256=EOXFkHsPkq0nuxRNLAbnrgJtcYq0IMC3YIg_16WArg4,3213
11
- tirex/models/tirex.py,sha256=Kglea86t_f3nXXHSjFgssxxrd1Qbwfr1eB_5gKfWYxM,9098
12
- tirex/models/slstm/block.py,sha256=V91Amgz8WAOOHo4fK1UZxd4Dgbx4-X6kUBS6X4m0tKQ,2006
13
- tirex/models/slstm/cell.py,sha256=ippaAPKI83j3_1l3pu9ks-iBGO641Elm1W4HsHgVu-c,7601
14
- tirex/models/slstm/layer.py,sha256=93CAYuG-HmUpF7mBAQ-z1S1u2__W10EW5jPToR57qqM,2747
15
- tirex_mirror-2025.10.2.dist-info/licenses/LICENSE,sha256=HlwHKnGTlE2oNm6734V-Vy62zlkWohnuZpYXSdkqDk4,7362
16
- tirex_mirror-2025.10.2.dist-info/licenses/LICENSE_MIRROR.txt,sha256=ulPZMcOZdN7JvISjiID3KUwovTjrPwiMv5ku9dM7nls,496
17
- tirex_mirror-2025.10.2.dist-info/licenses/NOTICE.txt,sha256=rcgDscFHb-uuZO3L0_vIxYhTYl-a2Rm0lBpp3_kKdFQ,147
18
- tirex_mirror-2025.10.2.dist-info/METADATA,sha256=Aq9VAU0pojVsrxwbfFLCmwmk1Gfl_Z_49G4yB-Z9eLY,11443
19
- tirex_mirror-2025.10.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
20
- tirex_mirror-2025.10.2.dist-info/top_level.txt,sha256=AOLDhfv0F_7nn3pFq0Kapg6Ky_28I_cGDXzQX3w9eO4,6
21
- tirex_mirror-2025.10.2.dist-info/RECORD,,