tirex-mirror 2025.10.2-py3-none-any.whl → 2025.10.7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tirex/api_adapter/forecast.py +163 -5
- tirex/api_adapter/gluon.py +2 -2
- tirex/api_adapter/hf_data.py +2 -2
- tirex/api_adapter/standard_adapter.py +45 -22
- tirex/base.py +26 -5
- tirex/models/slstm/cell.py +66 -77
- tirex/models/tirex.py +17 -0
- tirex/util.py +604 -0
- {tirex_mirror-2025.10.2.dist-info → tirex_mirror-2025.10.7.dist-info}/METADATA +1 -1
- tirex_mirror-2025.10.7.dist-info/RECORD +21 -0
- tirex_mirror-2025.10.2.dist-info/RECORD +0 -21
- {tirex_mirror-2025.10.2.dist-info → tirex_mirror-2025.10.7.dist-info}/WHEEL +0 -0
- {tirex_mirror-2025.10.2.dist-info → tirex_mirror-2025.10.7.dist-info}/licenses/LICENSE +0 -0
- {tirex_mirror-2025.10.2.dist-info → tirex_mirror-2025.10.7.dist-info}/licenses/LICENSE_MIRROR.txt +0 -0
- {tirex_mirror-2025.10.2.dist-info → tirex_mirror-2025.10.7.dist-info}/licenses/NOTICE.txt +0 -0
- {tirex_mirror-2025.10.2.dist-info → tirex_mirror-2025.10.7.dist-info}/top_level.txt +0 -0
tirex/api_adapter/forecast.py
CHANGED
@@ -2,15 +2,22 @@
 # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
 
 from abc import ABC, abstractmethod
-from
+from functools import partial
+from math import ceil
+from typing import Literal, Optional
 
 import torch
 
+from tirex.util import frequency_resample
+
 from .standard_adapter import ContextType, get_batches
 
 DEF_TARGET_COLUMN = "target"
 DEF_META_COLUMNS = ("start", "item_id")
 
+# Allowed resampling strategies (extend as new strategies are implemented)
+RESAMPLE_STRATEGIES: list[str] = ["frequency"]
+
 
 def _format_output(
     quantiles: torch.Tensor,
@@ -33,6 +40,27 @@ def _format_output(
         raise ValueError(f"Invalid output type: {output_type}")
 
 
+def _pad_time_series_batch(
+    batch_series: list[torch.Tensor],
+    max_length: int,
+) -> torch.Tensor:
+    if not batch_series:
+        return torch.empty((0, max_length))
+
+    first = batch_series[0]
+    dtype = first.dtype if first.is_floating_point() else torch.float32
+    device = first.device
+
+    padded = torch.full((len(batch_series), max_length), float("nan"), dtype=dtype, device=device)
+
+    for idx, series in enumerate(batch_series):
+        series = series.to(padded.dtype)
+        series_len = series.shape[0]
+        padded[idx, max_length - series_len :] = series
+
+    return padded
+
+
 def _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs):
     for batch_ctx, batch_meta in batches:
         quantiles, mean = fc_func(batch_ctx, **predict_kwargs)
@@ -45,7 +73,105 @@ def _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwar
     )
 
 
-def
+def _call_fc_with_padding(fc_func, batch_series: list[torch.Tensor], **predict_kwargs):
+    if not batch_series:
+        raise ValueError("Received empty batch for forecasting")
+
+    max_len = max(series.shape[0] for series in batch_series)
+    padded_ts = _pad_time_series_batch(batch_series, max_len)
+
+    return fc_func(padded_ts, **predict_kwargs)
+
+
+def _resample_fc_func_wrapper(
+    fc_func,
+    batch,
+    resample_strategy: str,
+    max_context: int = 2016,
+    **predict_kwargs,
+):
+    # downsample the time series based on the dominant frequencies, if enabled
+    max_period = (max_context // 1000) * 500
+    prediction_length = predict_kwargs.get("prediction_length", 100)
+    batch_resampled_ts: list[torch.Tensor] = []
+    fc_resample_fns = []
+    scaling_factors = []
+
+    # select the function doing the resampling
+    ctx_resample_fn = lambda x: (x, 1.0, (lambda y: y))
+    match resample_strategy:
+        case "frequency":
+            ctx_resample_fn = frequency_resample
+        case _:
+            raise RuntimeError("This shouldn't happen.")
+
+    for series in batch:
+        resampled_ts, _sample_factor, fc_resample_fn = ctx_resample_fn(
+            series,
+            prediction_length=prediction_length,
+            max_period=max_period,
+        )
+
+        batch_resampled_ts.append(resampled_ts)
+        fc_resample_fns.append(fc_resample_fn)
+        scaling_factors.append(_sample_factor)
+
+    # Compute per-item required horizons (in downsampled domain)
+    per_item_pred_lens = [int(ceil(prediction_length * sf)) for sf in scaling_factors]
+    max_pred_len = max(per_item_pred_lens) if per_item_pred_lens else int(prediction_length)
+    predict_kwargs.update(prediction_length=max_pred_len)
+
+    max_ts_length = max(ts.shape[0] for ts in batch_resampled_ts)
+    padded_ts = _pad_time_series_batch(batch_resampled_ts, max_ts_length)
+    print(f"Average sample batch factor: {sum(scaling_factors) / len(scaling_factors)}")
+
+    # generate prediction
+    fc_quantiles, fc_mean = fc_func(padded_ts, **predict_kwargs)
+
+    batch_prediction_q = []
+    batch_prediction_m = []
+    for el_q, el_m, fc_resample_fn, item_pred_len in zip(fc_quantiles, fc_mean, fc_resample_fns, per_item_pred_lens):
+        # truncate the forecasts to their individual sample factor adjusted prediction lengths
+        el_q = el_q[:item_pred_len, ...]  # [T, Q]
+        el_m = el_m[:item_pred_len]  # [T]
+
+        # upsample prediction
+        quantiles = fc_resample_fn(el_q.squeeze(0).transpose(0, 1)).transpose(0, 1)  # [T, Q]
+        mean = fc_resample_fn(el_m.squeeze(0))
+
+        quantiles = quantiles[:prediction_length, ...]
+        mean = mean[:prediction_length]
+
+        batch_prediction_q.append(quantiles)
+        batch_prediction_m.append(mean)
+
+    return torch.stack(batch_prediction_q, dim=0), torch.stack(batch_prediction_m, dim=0)
+
+
+def _gen_forecast(
+    fc_func,
+    batches,
+    output_type,
+    quantile_levels,
+    yield_per_batch,
+    resample_strategy: str | None = None,
+    max_context: int = 2016,
+    **predict_kwargs,
+):
+    base_fc_func = fc_func
+
+    if resample_strategy is not None:
+        if resample_strategy not in RESAMPLE_STRATEGIES:
+            raise ValueError(f"Invalid resample strategy: {resample_strategy}. Allowed: {RESAMPLE_STRATEGIES}")
+        fc_func = partial(
+            _resample_fc_func_wrapper,
+            base_fc_func,
+            resample_strategy=resample_strategy,
+            max_context=max_context,
+        )
+    else:
+        fc_func = partial(_call_fc_with_padding, base_fc_func)
+
     if yield_per_batch:
         return _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs)
 
@@ -92,6 +218,9 @@ def _common_forecast_doc():
            forecasts batch by batch as they are computed.
            Defaults to `False`.
 
+        resample_strategy (Optional[str], optional): Choose a resampling strategy. Allowed values: {RESAMPLE_STRATEGIES}.
+            If `None`, no resampling is applied. Currently only "frequency" is supported.
+
        **predict_kwargs: Additional keyword arguments that are passed directly to the underlying
            prediction mechanism of the pre-trained model. Refer to the model's
            internal prediction method documentation for available options.
@@ -113,6 +242,11 @@ class ForecastModel(ABC):
     def _forecast_quantiles(self, batch, **predict_kwargs):
         pass
 
+    @property
+    def max_context_length(self) -> int:
+        # retrieve the max_context attribute of the model configuration if present
+        return getattr(getattr(self, "config", None), "max_context", 2016)
+
     def forecast(
         self,
         context: ContextType,
@@ -120,6 +254,7 @@ class ForecastModel(ABC):
         batch_size: int = 512,
         quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
         yield_per_batch: bool = False,
+        resample_strategy: Literal["frequency"] | None = None,
         **predict_kwargs,
     ):
         f"""
@@ -134,7 +269,14 @@ class ForecastModel(ABC):
         assert batch_size >= 1, "Batch size must be >= 1"
         batches = get_batches(context, batch_size)
         return _gen_forecast(
-            self._forecast_quantiles,
+            self._forecast_quantiles,
+            batches,
+            output_type,
+            quantile_levels,
+            yield_per_batch,
+            resample_strategy=resample_strategy,
+            max_context=self.max_context_length,
+            **predict_kwargs,
         )
 
     def forecast_gluon(
@@ -144,6 +286,7 @@ class ForecastModel(ABC):
         batch_size: int = 512,
         quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
         yield_per_batch: bool = False,
+        resample_strategy: Literal["frequency"] | None = None,
         data_kwargs: dict = {},
         **predict_kwargs,
     ):
@@ -165,7 +308,14 @@ class ForecastModel(ABC):
 
         batches = get_gluon_batches(gluonDataset, batch_size, **data_kwargs)
         return _gen_forecast(
-            self._forecast_quantiles,
+            self._forecast_quantiles,
+            batches,
+            output_type,
+            quantile_levels,
+            yield_per_batch,
+            resample_strategy=resample_strategy,
+            max_context=self.max_context_length,
+            **predict_kwargs,
         )
 
     def forecast_hfdata(
@@ -175,6 +325,7 @@ class ForecastModel(ABC):
         batch_size: int = 512,
         quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
         yield_per_batch: bool = False,
+        resample_strategy: Literal["frequency"] | None = None,
         data_kwargs: dict = {},
         **predict_kwargs,
     ):
@@ -198,5 +349,12 @@ class ForecastModel(ABC):
 
         batches = get_hfdata_batches(hf_dataset, batch_size, **data_kwargs)
         return _gen_forecast(
-            self._forecast_quantiles,
+            self._forecast_quantiles,
+            batches,
+            output_type,
+            quantile_levels,
+            yield_per_batch,
+            resample_strategy=resample_strategy,
+            max_context=self.max_context_length,
+            **predict_kwargs,
         )
tirex/api_adapter/gluon.py
CHANGED
@@ -7,7 +7,7 @@ from gluonts.dataset.common import Dataset
 from gluonts.dataset.field_names import FieldName
 from gluonts.model.forecast import QuantileForecast
 
-from .standard_adapter import
+from .standard_adapter import _batch_iterable
 
 DEF_TARGET_COLUMN = FieldName.TARGET  # target
 DEF_META_COLUMNS = (FieldName.START, FieldName.ITEM_ID)
@@ -27,7 +27,7 @@ def _get_gluon_ts_map(**gluon_kwargs):
 
 
 def get_gluon_batches(gluonDataset: Dataset, batch_size: int, **gluon_kwargs):
-    return
+    return _batch_iterable(map(_get_gluon_ts_map(**gluon_kwargs), gluonDataset), batch_size)
 
 
 def format_gluonts_output(quantile_forecasts: torch.Tensor, mean_forecasts, meta: list[dict], quantile_levels):
tirex/api_adapter/hf_data.py
CHANGED
@@ -4,7 +4,7 @@
 import datasets
 import torch
 
-from .standard_adapter import
+from .standard_adapter import _batch_iterable
 
 DEF_TARGET_COLUMN = "target"
 
@@ -35,4 +35,4 @@ def _get_hf_map(dataset: datasets.Dataset, **hf_kwargs):
 
 def get_hfdata_batches(hf_dataset: datasets.Dataset, batch_size: int, **hf_kwargs):
     dataset, map_func = _get_hf_map(hf_dataset, **hf_kwargs)
-    return
+    return _batch_iterable(map(map_func, dataset), batch_size)
tirex/api_adapter/standard_adapter.py
CHANGED
@@ -16,13 +16,36 @@ ContextType = Union[
 ]
 
 
-def
-    if
-
+def _ensure_1d_tensor(sample) -> torch.Tensor:
+    if isinstance(sample, torch.Tensor):
+        tensor = sample
     else:
-
-
-
+        tensor = torch.as_tensor(sample)
+
+    if tensor.ndim > 1:
+        tensor = tensor.squeeze()
+
+    assert tensor.ndim == 1, "Each sample must be one-dimensional"
+    return tensor
+
+
+def _batched_slice(
+    full_batch,
+    full_meta: list[dict] | None,
+    batch_size: int,
+) -> Iterator[tuple[list[torch.Tensor], list[dict]]]:
+    total = len(full_batch)
+    for start in range(0, total, batch_size):
+        batch = full_batch[start : start + batch_size]
+        meta = full_meta[start : start + batch_size] if full_meta is not None else [{} for _ in range(len(batch))]
+
+        batch_series = []
+        for idx in range(len(batch)):
+            sample = batch[idx]
+            tensor = _ensure_1d_tensor(sample)
+            batch_series.append(tensor)
+
+        yield batch_series, meta
 
 
 def _batched(iterable: Iterable, n: int):
@@ -31,21 +54,21 @@ def _batched(iterable: Iterable, n: int):
         yield batch
 
 
-def
+def _batch_iterable(
+    iterable: Iterable[tuple[torch.Tensor, dict | None]],
+    batch_size: int,
+) -> Iterator[tuple[list[torch.Tensor], list[dict]]]:
     for batch in _batched(iterable, batch_size):
-
-
-
-        meta
-
-        sample
-
-
-
-
-        padded_batch.append(torch.cat((padding, sample)))
-        meta.append(el[1])
-        yield torch.stack(padded_batch), meta
+        series_list: list[torch.Tensor] = []
+        meta_list: list[dict] = []
+
+        for sample, meta in batch:
+            tensor = _ensure_1d_tensor(sample)
+            assert len(tensor) > 0, "Each sample needs to have a length > 0"
+            series_list.append(tensor)
+            meta_list.append(meta if meta is not None else {})
+
+        yield series_list, meta_list
 
 
 def get_batches(context: ContextType, batch_size: int):
@@ -59,9 +82,9 @@ def get_batches(context: ContextType, batch_size: int):
         if context.ndim == 1:
             context = np.expand_dims(context, axis=0)
         assert context.ndim == 2
-        batches =
+        batches = _batched_slice(context, None, batch_size)
     elif isinstance(context, (list, Iterable)):
-        batches =
+        batches = _batch_iterable(map(lambda x: (torch.Tensor(x), None), context), batch_size)
     if batches is None:
        raise ValueError(f"Context type {type(context)} not supported! Supported Types: {ContextType}")
    return batches
tirex/base.py
CHANGED
@@ -1,6 +1,7 @@
 # Copyright (c) NXAI GmbH.
 # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
 
+import logging
 import os
 from abc import ABC, abstractmethod
 from typing import Literal, TypeVar
@@ -8,7 +9,10 @@ from typing import Literal, TypeVar
 import torch
 from huggingface_hub import hf_hub_download
 
+from tirex.models.slstm.cell import sLSTMCellTorch
+
 T = TypeVar("T", bound="PretrainedModel")
+VERSION_DELIMITER = "-"
 
 
 def skip_cuda():
@@ -29,6 +33,17 @@ def parse_hf_repo_id(path):
     return "/".join(parts[0:2])
 
 
+def parse_model_string(model_string):
+    if VERSION_DELIMITER in model_string:
+        parts = model_string.split(VERSION_DELIMITER)
+        model_id, version = parts[0], parts[0]
+    else:
+        model_id = model_string
+        version = None
+
+    return model_id, version
+
+
 class PretrainedModel(ABC):
     REGISTRY: dict[str, "PretrainedModel"] = {}
 
@@ -38,7 +53,7 @@ class PretrainedModel(ABC):
 
     @classmethod
     def from_pretrained(
-        cls: type[T], path: str, backend: str, device: str | None = None, hf_kwargs=None, ckp_kwargs=None
+        cls: type[T], path: str, backend: str, device: str | None = None, compile=False, hf_kwargs=None, ckp_kwargs=None
     ) -> T:
         if hf_kwargs is None:
             hf_kwargs = {}
@@ -58,9 +73,10 @@ class PretrainedModel(ABC):
         model: T = cls(backend=backend, **checkpoint["hyper_parameters"])
         model.on_load_checkpoint(checkpoint)
         model.load_state_dict(checkpoint["state_dict"])
+        model = model.to(device)
 
-        if backend == "
-
+        if compile and backend == "torch":
+            sLSTMCellTorch.slstm_forward = torch.compile(sLSTMCellTorch.slstm_forward, mode="max-autotune")
         return model
 
     @classmethod
@@ -76,6 +92,7 @@ def load_model(
     path: str,
     device: str | None = None,
     backend: Literal["torch", "cuda"] | None = None,
+    compile: bool = False,
    hf_kwargs=None,
    ckp_kwargs=None,
 ) -> PretrainedModel:
@@ -85,6 +102,7 @@ def load_model(
         path (str): Hugging Face path to the model (e.g. NX-AI/TiRex)
         device (str, optional): The device on which to load the model (e.g., "cuda:0", "cpu").
         backend (torch | cuda): What backend to use, torch or the custom CUDA kernels. Defaults to cuda when xlstm is installed, else torch.
+        compile (bool, optional): torch.compile the sLSTM cells, only works with the torch backend
         hf_kwargs (dict, optional): Keyword arguments to pass to the Hugging Face Hub download method.
         ckp_kwargs (dict, optional): Keyword arguments to pass when loading the checkpoint.
 
@@ -99,11 +117,14 @@ def load_model(
         backend = "torch" if skip_cuda() or not xlstm_available() else "cuda"
 
     try:
-        _,
+        _, model_string = parse_hf_repo_id(path).split("/")
+        model_id, version = parse_model_string(model_string)
     except:
        raise ValueError(f"Invalid model path {path}")
    model_cls = PretrainedModel.REGISTRY.get(model_id, None)
    if model_cls is None:
        raise ValueError(f"Invalid model id {model_id}")
 
-    return model_cls.from_pretrained(
+    return model_cls.from_pretrained(
+        path, device=device, backend=backend, compile=compile, hf_kwargs=hf_kwargs, ckp_kwargs=ckp_kwargs
+    )
tirex/models/slstm/cell.py
CHANGED
@@ -43,13 +43,11 @@ class sLSTMCell(nn.Module):
         state = self._get_state(input, state)
 
         if self.backend == "torch":
-
+            output, state = self._impl_torch(input, state)
         elif self.backend == "cuda":
-
+            output, state = self._impl_cuda(input, state)
 
-
-        output = self._permute_output(all_states[0][1:])
-        return output.to(input.dtype), state.to(input.dtype)
+        return self._permute_output(output).to(input.dtype), state.to(input.dtype)
 
     def _impl_torch(self, input: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
         input = input.to(dtype=torch.bfloat16)
@@ -64,7 +62,7 @@ class sLSTMCell(nn.Module):
            .reshape(-1)
        )
 
-        return slstm_forward(input, state, recurrent_kernel, bias)
+        return sLSTMCellTorch.slstm_forward(input, state, recurrent_kernel, bias)
 
     def _impl_cuda(self, input: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
         if input.device.type != "cuda":
@@ -88,7 +86,7 @@ class sLSTMCell(nn.Module):
 
         input = input.permute(0, 1, 3, 2, 4).reshape(input.shape[0], input.shape[1], -1)
 
-
+        all_states = self.func.apply(
             False,
             input.contiguous(),
             state.contiguous(),
@@ -96,6 +94,10 @@ class sLSTMCell(nn.Module):
             self._bias_.contiguous(),
         )
 
+        state = all_states[:, -1]
+        output = all_states[0][1:]
+        return output, state
+
     def _get_input(self, x: torch.Tensor) -> torch.Tensor:
         assert x.shape[-1] == self.config.embedding_dim * self.config.num_gates, (
             f"Input size mismatch: Expected input size {self.config.embedding_dim * self.config.num_gates}, but got {input.size(-1)}."
@@ -119,73 +121,60 @@ class sLSTMCell(nn.Module):
         return output.permute(1, 2, 0, 3)
 
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        states
-
-
-        .
-        .
-
-        )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        mnew = iraw
-    else:
-        mnew = torch.max(iraw, logfplusm)  # eq 15
-    ogate = torch.sigmoid(oraw)  # eq 14
-    igate = torch.minimum(torch.exp(iraw - mnew), torch.ones_like(iraw))  # eq 16
-    fgate = torch.minimum(torch.exp(logfplusm - mnew), torch.ones_like(iraw))  # eq 17
-    zgate = torch.tanh(zraw)  # eq 11
-    cnew = fgate * c + igate * zgate  # eq 8
-    nnew = fgate * n + igate  # eq 9
-    hnew = ogate * cnew / nnew  # eq 10
-
-    # y (4, B, H), state (4, B, H)
-    return torch.stack((hnew, cnew, nnew, mnew), dim=0), torch.stack((igate, fgate, zraw, ogate), dim=0)
+class sLSTMCellTorch:
+    @staticmethod
+    def slstm_forward(
+        x: torch.Tensor,  # [S, B, G*I]
+        states: torch.Tensor,  # [4, B, H] only the first is used for recurrence!
+        R: torch.Tensor,  # [K, R*H, H] - K num_heads
+        b: torch.Tensor,  # [T*H]
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        num_gates = 4
+        num_heads = R.shape[0]
+        S, B, _ = x.shape
+        H = R.shape[1] * num_heads
+        assert states.shape == (num_gates, B, H)
+
+        states = states.to(R.dtype).unbind(dim=0)
+        output = []
+        for i in range(S):
+            Ry = (
+                states[0]
+                .reshape(B, num_heads, 1, -1)
+                .matmul(R.unsqueeze(0))
+                .reshape(B, num_heads, num_gates, -1)
+                .transpose(1, 2)
+                .reshape(B, -1)
+            )
+            states = sLSTMCellTorch.slstm_forward_pointwise(
+                x[i].float(), Ry.float(), b.float(), [s.float() for s in states]
+            )
+            states = [s.to(dtype=R.dtype) for s in states]
+            output.append(states[0])
+
+        return torch.stack(output), torch.stack(states)  # (S, B, H), 4 x (B, H)
+
+    @staticmethod
+    def slstm_forward_pointwise(
+        Wx: torch.Tensor,  # dim [B, 4*H]
+        Ry: torch.Tensor,  # dim [B, 4*H]
+        b: torch.Tensor,  # dim [1, 4*H]
+        states: torch.Tensor,  # dim 4 x [B, H]
+    ) -> list[torch.Tensor]:
+        y, c, n, m = states
+
+        raw = Wx + Ry + b
+        iraw, fraw, zraw, oraw = torch.unbind(raw.view(raw.shape[0], 4, -1), dim=1)
+
+        # Equations reference the xlstm paper on page 4: https://arxiv.org/pdf/2405.04517
+        logfplusm = m + F.logsigmoid(fraw)  # eq 15
+        mnew = torch.where(torch.all(n == 0.0), iraw, torch.max(iraw, logfplusm))  # eq 15
+        ogate = torch.sigmoid(oraw)  # eq 14
+        igate = torch.minimum(torch.exp(iraw - mnew), torch.ones_like(iraw))  # eq 16
+        fgate = torch.minimum(torch.exp(logfplusm - mnew), torch.ones_like(iraw))  # eq 17
+        zgate = torch.tanh(zraw)  # eq 11
+        cnew = fgate * c + igate * zgate  # eq 8
+        nnew = fgate * n + igate  # eq 9
+        hnew = ogate * cnew / nnew  # eq 10
+
+        return [hnew, cnew, nnew, mnew]  # 4 x (B, H)
tirex/models/tirex.py
CHANGED
@@ -179,8 +179,25 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
         quantile_preds = torch.transpose(quantile_preds, 1, 2)  # switch quantile and num_token_dimension
         # quantile_preds: [batch_size, num_quantiles, num_token, output_patch_size]
 
+        quantile_preds = self._forward_model(torch.cat((input_token, input_mask), dim=2))
+
+        quantile_preds = torch.unflatten(
+            quantile_preds, -1, (len(self.config.quantiles), self.config.output_patch_size)
+        )
+        quantile_preds = torch.transpose(quantile_preds, 1, 2)  # switch quantile and num_token_dimension
+        # quantile_preds: [batch_size, num_quantiles, num_token, output_patch_size]
         return quantile_preds, hidden_states
 
+    def _forward_model(self, input: torch.Tensor):
+        hidden_states = self.input_patch_embedding(input)
+
+        for block in self.blocks:
+            hidden_states = block(hidden_states)
+
+        hidden_states = self.out_norm(hidden_states)
+
+        return self.output_patch_embedding(hidden_states)
+
     def _interpolate_quantiles(self, predictions: torch.Tensor, quantile_levels: list[float]):
         training_quantile_levels = self.config.quantiles
         if min(quantile_levels) < min(training_quantile_levels) or max(quantile_levels) > max(training_quantile_levels):
tirex/util.py
CHANGED
@@ -1,7 +1,611 @@
 # Copyright (c) NXAI GmbH.
 # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
 
+from collections.abc import Callable
 from dataclasses import fields
+from functools import partial
+from math import ceil
+from typing import Literal, Optional
+
+import numpy as np
+import torch
+
+
+def frequency_resample(
+    ts: torch.Tensor,
+    prediction_length: int,
+    patch_size: int = 64,
+    peak_prominence: float = 0.1,
+    selection_method: Literal["low_harmonic", "high_harmonic", "highest_amplitude"] = "low_harmonic",
+    min_period: int | None = None,
+    max_period: int = 1000,
+    bandpass_filter: bool = True,
+    nifr_enabled: bool = True,
+    nifr_start_integer: int = 2,
+    nifr_end_integer: int = 12,
+    nifr_clamp_large_factors: bool = False,
+) -> tuple[torch.Tensor, float, Callable[[torch.Tensor], torch.Tensor]]:
+    """
+    Downsample a time series according to a frequency-based strategy and return a helper to upsample
+    forecasts back to the original resolution.
+
+    Parameters
+    ----------
+    ts : torch.Tensor
+        1D time series of shape [T].
+    prediction_length : int
+        Requested forecast horizon; short horizons (<100) skip resampling.
+    patch_size : int, default 64
+        Nominal patch size used to align one dominant period to one patch.
+    peak_prominence : float, default 0.1
+        Threshold for FFT peak detection on the normalized spectrum.
+    selection_method : {"low_harmonic", "high_harmonic", "highest_amplitude"}, default "low_harmonic"
+        How to resolve two dominant peaks in ~2x harmonic relation.
+    min_period : int or None, optional
+        Minimum period to consider; if None, defaults to `patch_size`.
+    max_period : int, default 1000
+        Maximum period to consider for FFT peak search.
+    bandpass_filter : bool, default True
+        If True, suppresses very low frequencies before peak search.
+    nifr_enabled : bool, default True
+        Enable nearest-integer-fraction rounding of the factor.
+    nifr_start_integer : int, default 2
+        Smallest integer k used for 1/k grid when NIFR is enabled.
+    nifr_end_integer : int, default 12
+        Largest integer k used for 1/k grid when NIFR is enabled.
+    nifr_clamp_large_factors : bool, default False
+        If True, clamps large factors in [1, 1/nifr_start_integer] to 1.0.
+
+    Returns
+    -------
+    resampled_ts : torch.Tensor
+        The resampled input series.
+    sample_factor : float
+        Applied sampling factor (<= 1 means downsampling; 1.0 = identity).
+    fc_resample_fn : Callable[[torch.Tensor], torch.Tensor]
+        Function that upsamples a forecast back to the original resolution using the inverse factor.
+
+    Notes
+    -----
+    - For short horizons (prediction_length < 100), resampling is disabled and the factor is set to 1.0.
+    - The factor is clamped to at most 1.0 to avoid upsampling the context.
+    """
+    sample_factor = frequency_factor(
+        ts,
+        max_period=max_period,
+        min_period=min_period,
+        bandpass_filter=bandpass_filter,
+        selection_method=selection_method,
+        peak_prominence=peak_prominence,
+        patch_size=patch_size,
+        nifr_enabled=nifr_enabled,
+        nifr_start_integer=nifr_start_integer,
+        nifr_end_integer=nifr_end_integer,
+        nifr_clamp_large_factors=nifr_clamp_large_factors,
+    )
+
+    sample_factor = min(1, sample_factor)
+
+    if prediction_length < 100:
+        # do not resample for short forecasts
+        sample_factor = 1.0
+
+    fc_resample_factor = 1 / sample_factor
+    fc_resample_fn = partial(resample, sample_rate=fc_resample_factor)
+    resampled_ts = resample(ts, sample_rate=sample_factor)
+
+    return resampled_ts, sample_factor, fc_resample_fn
+
+
+def frequency_factor(
+    ts: torch.Tensor,
+    patch_size: int = 64,  # This doesn't have to match model patch size, but rather the 'target frequency'
+    peak_prominence: float = 0.1,
+    selection_method: Literal["low_harmonic", "high_harmonic", "highest_amplitude"] = "low_harmonic",
+    min_period: int | None = None,
+    max_period: int = 1000,
+    bandpass_filter: bool = True,
+    nifr_enabled: bool = False,
+    nifr_start_integer: int = 2,
+    nifr_end_integer: int = 12,
+    nifr_clamp_large_factors: bool = False,
+) -> float:
+    """
+    Estimate a sampling factor from the dominant frequency of a 1D series so that one period
+    approximately fits into one patch of length `patch_size`.
+
+    The factor is computed as `patch_size / period`, where `period = 1 / f*` and `f*`
+    is the selected dominant frequency from the one-sided FFT of the series (NaNs are
+    linearly interpolated for analysis). If two prominent peaks are detected whose
+    frequencies are in roughly a 2x harmonic relation (ratio in [1.5, 2.5]),
+    `selection_method` determines whether to select the lower or higher harmonic. A set of
+    guards returns 1.0 (identity) for short series, invalid/non-finite results, or when no
+    prominent peak is found. Optional nearest-integer-fraction rounding (NIFR) can snap the
+    factor to the closest value in {1} ∪ {1/k | k ∈ [nifr_start_integer, nifr_end_integer]}.
+
+    Parameters
+    ----------
+    ts : torch.Tensor
+        Input 1D series (last dim is time). NaNs are linearly interpolated for FFT analysis only;
+        the original series is not modified.
+    patch_size : int, default 64
+        Target number of samples per period.
+    peak_prominence : float, default 0.1
+        Minimum normalized spectrum height to treat a bin as a peak.
+    selection_method : {"low_harmonic", "high_harmonic", "highest_amplitude"}, default "low_harmonic"
+        Rule for picking between two ~2x related peaks.
+    min_period : int or None, optional
+        Minimum period to consider. If None, defaults to `patch_size`.
+    max_period : int, default 1000
+        Series shorter than `2 * max_period` return 1.0.
+    bandpass_filter : bool, default True
+        If True, very low frequencies below 1 / max_period are suppressed.
+    nifr_enabled : bool, default False
+        Enable nearest-integer-fraction rounding of the factor.
+    nifr_start_integer : int, default 2
+        Smallest integer k used for the 1/k grid when NIFR is enabled.
+    nifr_end_integer : int, default 12
+        Largest integer k used for the 1/k grid when NIFR is enabled.
+    nifr_clamp_large_factors : bool, default False
+        If True, clamps factors in [1, 1/nifr_start_integer] to 1.0.
+
+    Returns
+    -------
+    float
+        The sampling factor. Values <= 0 or non-finite are mapped to 1.0. If no valid
+        dominant frequency is found or the series is too short, returns 1.0.
+
+    Notes
+    -----
+    - The factor is computed as `patch_size / period`, where `period = 1 / f*` and `f*` is the selected
+      dominant FFT frequency.
+    - If two prominent peaks are detected ~2x apart, `selection_method` determines whether to select the lower or higher harmonic.
+    - Optional nearest-integer-fraction rounding (NIFR) can snap the factor to the closest value in {1} ∪ {1/k | k ∈ [nifr_start_integer, nifr_end_integer]}.
+    """
+    if min_period is None:
+        # NOTE: be careful when min_period is not matching patch_size, it can create unexpected scaling factors!
+        min_period = patch_size
+
+    # Ensure CPU numpy array for FFT analysis
+    ts_np = ts.detach().cpu().numpy() if isinstance(ts, torch.Tensor) else np.asarray(ts)
+
+    # NOTE: If the series is shorter than max_period * 2, the FFT may not be accurate; to avoid detecting these peaks, we don't scale
+    if ts_np.size < max_period * 2:
+        return 1.0
+
+    freqs, specs, peak_idc = run_fft_analysis(
+        ts_np,
+        scaling="amplitude",
+        peak_prominence=peak_prominence,
+        min_period=min_period,
+        max_period=max_period,
+        bandpass_filter=bandpass_filter,
+    )
+
+    # No detectable peaks -> keep original sampling
+    if peak_idc.size == 0:
+        return 1.0
+
+    # Choose initial candidate as the highest-amplitude peak
+    chosen_idx = int(peak_idc[0])
+
+    # If two peaks exist, check for ~2x harmonic relation and prefer the higher/lower one
+    if peak_idc.size >= 2:
+        idx_a = int(peak_idc[0])  # highest amplitude
+        idx_b = int(peak_idc[1])  # second highest amplitude
+        f_a = float(freqs[idx_a])
+        f_b = float(freqs[idx_b])
+
+        # Determine lower/higher frequency
+        low_f = min(f_a, f_b)
+        high_f = max(f_a, f_b)
+
+        if low_f > 0:
+            ratio = high_f / low_f
+            # Roughly half relation
+            if 1.5 <= ratio <= 2.5:
+                if selection_method == "low_harmonic":
+                    chosen_idx = idx_a if f_a < f_b else idx_b
+                elif selection_method == "high_harmonic":
+                    chosen_idx = idx_a if f_a > f_b else idx_b
+
+    chosen_freq = float(freqs[chosen_idx])
+
+    # Guard against zero or non-finite frequency
+    if not np.isfinite(chosen_freq) or chosen_freq <= 0:
+        return 1.0
+
+    # Convert to period and compute scaling factor so one period fits one patch
+    period = 1.0 / chosen_freq
+    factor = resampling_factor(period, patch_size)
+    factor = round(factor, 4)
+
+    # Guard against factor being negative
+    if not np.isfinite(factor) or factor <= 0:
+        return 1.0
+
+    # nearest integer fraction rounding (nifr)
+    if nifr_enabled:
+        int_fractions = np.concatenate([[1], 1 / np.arange(nifr_start_integer, nifr_end_integer + 1)])
+        diff = np.abs(factor - int_fractions)
+        min_diff_idc = np.argmin(diff)
+        factor = int_fractions[min_diff_idc]
+
+        if nifr_clamp_large_factors:
+            # Clamp everything between 1 and 1/nifr_start_integer to 1, that is no scaling
+            factor = factor if factor < int_fractions[1] else 1
+
+    return float(factor)
+
+
+def resample(ts: torch.Tensor, sample_rate: float, window_position: str = "center") -> torch.Tensor:
+    """
+    Resample the time series using NaN-tolerant window averaging with size 1/sample_rate.
+
+    - If sample_rate > 1 the series is upsampled; windows may collapse to a single index.
+    - If sample_rate < 1 the series is downsampled; windows span multiple indices.
+    - If sample_rate == 1 the series is returned unchanged (cast to float for NaN support).
+
+    Window alignment controlled by `window_position`:
+    - "center": average over [c - L/2, c + L/2]
+    - "left"  : average over [c - L, c]
+    - "right" : average over [c, c + L]
+
+    The window is truncated at the boundaries. NaNs are ignored via nan-mean semantics;
+    if a window contains only NaNs, the output is NaN.
+
+    Arguments:
+    ----------
+    ts: torch.Tensor of shape [..., T]
+        The time series to be rescaled (last dim is time).
+    sample_rate: float
+        The factor determining the final number of timesteps in the series, i.e., T' = ceil(T * sample_rate).
+    window_position: {"center", "left", "right"}
+        Placement of the averaging window relative to each target coordinate.
+
+    Returns:
+    --------
+    torch.Tensor of shape [..., ceil(T * sample_rate)] with dtype float.
+    """
+    # Validate inputs
+    if sample_rate <= 0 or sample_rate == 1:
+        # Invalid or no scaling; return original as float
+        return ts.to(torch.float)
+
+    src_num_timesteps = ts.shape[-1]
+    tgt_num_timesteps = ceil(src_num_timesteps * sample_rate)
+
+    # Do not change coordinate creation logic
+    src_coords = torch.arange(src_num_timesteps, device=ts.device)
+    tgt_coords = torch.linspace(0, src_num_timesteps - 1, tgt_num_timesteps, device=ts.device)
+
+    if sample_rate == 1:
+        return ts.to(torch.float)
+
+    # Branch: upsampling -> linear interpolation between nearest neighbors (NaN-aware)
+    if sample_rate > 1:
+        # Neighbour indices for each target coordinate along the last dimension
+        tgt_in_src_idx_lo = tgt_coords.floor().to(torch.long)
+        tgt_in_src_idx_hi = tgt_coords.ceil().to(torch.long)
+
+        # Distances in index space and offsets from lower index
+        dist = src_coords[tgt_in_src_idx_hi] - src_coords[tgt_in_src_idx_lo]
+
+        # Work in float for NaN support; gather neighbour values
+        src_lo_vals = ts[..., tgt_in_src_idx_lo].to(torch.float)
+        src_hi_vals = ts[..., tgt_in_src_idx_hi].to(torch.float)
+        diff = src_hi_vals - src_lo_vals
+        offset = tgt_coords - src_coords[tgt_in_src_idx_lo]
+
+        # Allocate output
+        tgt_values = torch.empty(*ts.shape[:-1], tgt_num_timesteps, dtype=torch.float, device=ts.device)
+
+        # Masks
+        exact_mask = dist == 0
+        interp_mask = ~exact_mask
+
+        # Exact source index -> take the source value
+        if exact_mask.any():
+            tgt_values[..., exact_mask] = src_lo_vals[..., exact_mask]
+
+        # Linear interpolate where indices differ
+        if interp_mask.any():
+            tgt_values[..., interp_mask] = (
+                diff[..., interp_mask] / dist[interp_mask].to(torch.float) * offset[interp_mask]
+                + src_lo_vals[..., interp_mask]
+            )
+
+        # Propagate NaNs from either neighbour
+        nan_mask = torch.isnan(src_lo_vals) | torch.isnan(src_hi_vals)
+        if nan_mask.any():
+            tgt_values[..., nan_mask] = torch.nan
+
+        return tgt_values
+
+    # Window length in source-index units
+    L = 1.0 / sample_rate
+    half_L = 0.5 * L
+
+    if window_position == "center":
+        left_f = tgt_coords - half_L
+        right_f = tgt_coords + half_L
+    elif window_position == "left":
+        left_f = tgt_coords - L
+        right_f = tgt_coords
+    elif window_position == "right":
+        left_f = tgt_coords
+        right_f = tgt_coords + L
+    else:
+        raise ValueError("window_position must be one of {'center','left','right'}")
+
+    # Convert to integer indices, inclusive bounds
+    left_idx = torch.ceil(left_f).to(torch.long)
+    right_idx = torch.floor(right_f).to(torch.long)
+
+    # Clip to valid range and ensure non-empty windows (at least one index)
+    left_idx = torch.clamp(left_idx, 0, src_num_timesteps - 1)
+    right_idx = torch.clamp(right_idx, 0, src_num_timesteps - 1)
+    right_idx = torch.maximum(right_idx, left_idx)
+
+    # Prepare cumulative sums for fast [l, r] segment nan-mean along the last dim
+    ts_float = ts.to(torch.float)
+    valid_mask = ~torch.isnan(ts_float)
+
+    values_filled = torch.where(valid_mask, ts_float, torch.zeros_like(ts_float))
+    counts = valid_mask.to(torch.float)
+
+    cumsum_vals = values_filled.cumsum(dim=-1)
+    cumsum_cnts = counts.cumsum(dim=-1)
+
+    # Pad a leading zero to make inclusive range sums easy: sum[l:r] = cs[r] - cs[l-1]
+    pad_shape = (*ts.shape[:-1], 1)
+    zeros_vals = torch.zeros(pad_shape, dtype=cumsum_vals.dtype, device=ts.device)
+    zeros_cnts = torch.zeros(pad_shape, dtype=cumsum_cnts.dtype, device=ts.device)
+    cumsum_vals = torch.cat([zeros_vals, cumsum_vals], dim=-1)
+    cumsum_cnts = torch.cat([zeros_cnts, cumsum_cnts], dim=-1)
+
+    # Build broadcastable indices for gather along the last dim
+    prefix_shape = ts.shape[:-1]
+    target_len = tgt_num_timesteps
+
+    def _expand_index(idx: torch.Tensor) -> torch.Tensor:
+        # idx shape [target_len] -> [..., target_len]
+        view_shape = (1,) * len(prefix_shape) + (target_len,)
+        return idx.view(view_shape).expand(*prefix_shape, target_len)
+
+    # For inclusive [l, r], use cumsum at (r+1) and (l)
+    r_plus1 = torch.clamp(right_idx + 1, 0, src_num_timesteps)
+    l_idx = left_idx
+
+    r_plus1_exp = _expand_index(r_plus1)
+    l_exp = _expand_index(l_idx)
+
+    seg_sums = cumsum_vals.gather(dim=-1, index=r_plus1_exp) - cumsum_vals.gather(dim=-1, index=l_exp)
+    seg_cnts = cumsum_cnts.gather(dim=-1, index=r_plus1_exp) - cumsum_cnts.gather(dim=-1, index=l_exp)
+
+    # Compute nan-mean: where count==0 -> NaN
+    with torch.no_grad():
+        safe_cnts = torch.where(seg_cnts > 0, seg_cnts, torch.ones_like(seg_cnts))
+        averages = seg_sums / safe_cnts
+        averages = torch.where(seg_cnts > 0, averages, torch.full_like(averages, float("nan")))
+
+    return averages
+
+
+def run_fft_analysis(
+    y,
+    dt: float = 1.0,
+    window: str = "hann",
+    detrend: bool = True,
+    scaling: str = "amplitude",
+    peak_prominence: float = 0.1,
+    min_period: int = 64,
+    max_period: int = 1000,
+    bandpass_filter: bool = True,
+):
+    """
+    Compute one-sided FFT frequencies and spectrum magnitude for a real 1D signal.
+
+    Parameters
+    ----------
+    y : array_like
+        1D time series (regularly sampled). NaNs will be linearly interpolated.
+    dt : float
+        Sampling period (time between samples). Frequencies are cycles per unit of dt.
+    window : {'hann', None}
+        Optional taper to reduce leakage.
+    detrend : bool
+        If True, remove the mean before FFT.
+    scaling : {'amplitude', 'power', 'raw'}
+        - 'amplitude': one-sided amplitude spectrum with window-power compensation.
+        - 'power'    : one-sided power (not density) with window-power compensation.
+        - 'raw'      : |rfft(yw)| (no normalization, mostly for debugging).
+    peak_prominence : float
+        Absolute threshold on the normalized spectrum for peak detection.
+
+    Returns
+    -------
+    f : ndarray
+        Frequencies (non-negative), length N//2 + 1, in cycles per unit (1/dt).
+    spec : ndarray
+        Spectrum corresponding to `f` under the chosen `scaling`.
+    peaks_idx : ndarray
+        Indices into f of detected peaks.
+    """
+    y = np.asarray(y, dtype=float)
+    if y.ndim != 1:
+        y = y.reshape(-1)
+    n = y.size
+    if n < 2:
+        return np.array([]), np.array([]), np.array([])
+
+    # Fill NaNs linearly (handles edge NaNs as well)
+    y = _nan_linear_interpolate(y)
+
+    if detrend:
+        y = y - np.mean(y)
+
+    # Windowing
+    if window == "hann":
+        w = np.hanning(n)
+        yw = y * w
+        # average window power (for proper amplitude/power normalization)
+        w_power = np.sum(w**2) / n
+    elif window is None:
+        yw = y
+        w_power = 1.0
+    else:
+        raise ValueError("window must be either 'hann' or None")
+
+    # FFT (one-sided)
+    Y = np.fft.rfft(yw)
+    f = np.fft.rfftfreq(n, d=dt)  # cycles per unit time
+
+    if scaling == "raw":
+        spec = np.abs(Y)
+    elif scaling == "amplitude":
+        # One-sided amplitude with window power compensation
+        spec = np.abs(Y) / (n * np.sqrt(w_power))
+        if n % 2 == 0:
+            spec[1:-1] *= 2.0
+        else:
+            spec[1:] *= 2.0
+    elif scaling == "power":
+        # One-sided power (not PSD)
+        spec = (np.abs(Y) ** 2) / (n**2 * w_power)
+        if n % 2 == 0:
+            spec[1:-1] *= 2.0
+        else:
+            spec[1:] *= 2.0
+    else:
+        raise ValueError("scaling must be 'amplitude', 'power', or 'raw'")
+
+    # Normalize the spectrum by its maximum value
+    if spec.max() > 0:
+        spec = spec / spec.max()
+
+    # Find peaks in the spectrum
+    peaks_idx = custom_find_peaks(
+        f,
+        spec,
+        max_peaks=2,
+        prominence_threshold=peak_prominence,
+        min_period=min_period,
+        max_period=max_period,
+        bandpass_filter=bandpass_filter,
+    )
+
+    return f, spec, peaks_idx
+
+
+def _nan_linear_interpolate(y: np.ndarray) -> np.ndarray:
+    y = y.astype(np.float32)
+    if y.ndim != 1:
+        y = y.reshape(-1)
+    n = y.size
+    mask = np.isfinite(y)
+    if mask.all():
+        return y
+    if (~mask).all():
+        return np.zeros(n, dtype=np.float32)
+    idx = np.arange(n)
+    y_interp = y.copy()
+    y_interp[~mask] = np.interp(idx[~mask], idx[mask], y[mask])
+    return y_interp
+
+
+def resampling_factor(inverted_freq, path_size):
+    """
+    Compute the resampling factor based on the inverted frequency and path size.
+    """
+    if inverted_freq <= 0:
+        return 1.0
+    factor = path_size / inverted_freq
+    return factor
+
+
+def custom_find_peaks(
+    f,
+    spec,
+    *,
+    max_peaks=5,
+    prominence_threshold=0.1,
+    min_period=64,
+    max_period=1000,
+    bandpass_filter=True,
+):
+    """
+    Finds prominent peaks in a spectrum using a simple custom logic.
+
+    A peak is a local maximum. A peak is considered prominent if its height
+    (on a normalized spectrum) is greater than a given threshold.
+
+    Parameters
+    ----------
+    f : np.ndarray
+        Frequency array (currently unused but kept for API consistency).
+    spec : np.ndarray
+        The normalized spectrum.
+    max_peaks : int
+        The maximum number of peaks to return.
+    prominence_threshold : float
+        The minimum height for a peak to be considered prominent.
+
+    Returns
+    -------
+    np.ndarray
+        An array of indices of the detected peaks in the spectrum. Returns an
+        empty array if no prominent peaks are found.
+    """
+    if len(spec) < 5:  # Need at least 5 points to exclude last two bins
+        return np.array([], dtype=int)
+
+    if bandpass_filter:  # only truly filter low frequencies, high frequencies are dealt with later
+        min_freq = 1 / max_period
+        freq_mask = f >= min_freq
+        spec = spec * freq_mask
+
+    # Find all local maxima, excluding the last two bins
+    local_maxima_indices = []
+    for i in range(1, len(spec) - 2):
+        if spec[i] > spec[i - 1] and spec[i] > spec[i + 1]:
+            local_maxima_indices.append(i)
+
+    if not local_maxima_indices:
+        return np.array([], dtype=int)
+
+    # Filter by prominence (height)
+    prominent_peaks = []
+    for idx in local_maxima_indices:
+        if spec[idx] > prominence_threshold:
+            prominent_peaks.append((idx, spec[idx]))
+
+    # If no peaks are above the threshold, return an empty list
+    if not prominent_peaks:
+        return np.array([], dtype=int)
+
+    # Check for clear peaks below min_period (do lowpass filter)
+    for idx, _ in prominent_peaks:
+        period = 1 / f[idx]
+        if period < min_period:
+            return np.array([], dtype=int)
+
+    # Filter by period
+    period_filtered_peaks = []
+    for idx, prominence in prominent_peaks:
+        period = 1 / f[idx]
+
+        if min_period <= period <= max_period:
+            period_filtered_peaks.append((idx, prominence))
+
+    if not period_filtered_peaks:
+        return np.array([], dtype=int)
+
+    # Sort by height and return the top `max_peaks`
+    period_filtered_peaks.sort(key=lambda x: x[1], reverse=True)
+    peak_indices = np.array([p[0] for p in period_filtered_peaks[:max_peaks]], dtype=int)
+
+    return peak_indices
 
 
 def round_up_to_next_multiple_of(x: int, multiple_of: int) -> int:
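A hedged round-trip sketch for the new resampling utilities, assuming a cleanly periodic input; the exact factor depends on FFT peak detection, but for a 256-step period and the default 64-step patch it should come out near 64 / 256 = 0.25:

import torch
from tirex.util import frequency_resample

t = torch.arange(4096, dtype=torch.float32)
ts = torch.sin(2 * torch.pi * t / 256)  # dominant period of 256 steps

resampled, factor, fc_resample_fn = frequency_resample(ts, prediction_length=128, max_period=1000)
print(factor)           # expected ~0.25 (snapped to 1/4 by the NIFR grid)
print(resampled.shape)  # ceil(4096 * factor) samples

# A forecast made at the downsampled rate is mapped back with the inverse factor:
fake_forecast = torch.zeros(int(128 * factor))
print(fc_resample_fn(fake_forecast).shape)  # back to ~128 steps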
tirex_mirror-2025.10.7.dist-info/RECORD
ADDED
@@ -0,0 +1,21 @@
+tirex/__init__.py,sha256=rfsOeCJ7eRqU3K3TOhfN5-4XUuZFqt11wBRxk5SoAWA,292
+tirex/base.py,sha256=JZuWzezuDenrlDekMaaPLBWzB9mw6vJ_qF2GWjou8ws,4229
+tirex/util.py,sha256=AQ8eM-7IpGeAqegVqflrjPQITzIEKltRcJOBbJmikcA,22758
+tirex/api_adapter/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
+tirex/api_adapter/forecast.py,sha256=CpX9YfiQ1X6nvzODklCPzuKPlYidJDha8FnEU-6zr1Q,13919
+tirex/api_adapter/gluon.py,sha256=CbqL8ZgSRvxWAHK6TWqKdxUrMVWmGskRWOHNF84Lh1U,1819
+tirex/api_adapter/hf_data.py,sha256=TRyys2xKIGZS0Yhq2Eb61lWCMg5CWWn1yRlLIN1mU7o,1369
+tirex/api_adapter/standard_adapter.py,sha256=vdlxNs8mTUtPgK_5WMqYqNdMj8W44igqWsAgtggt_xk,2809
+tirex/models/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
+tirex/models/patcher.py,sha256=EOXFkHsPkq0nuxRNLAbnrgJtcYq0IMC3YIg_16WArg4,3213
+tirex/models/tirex.py,sha256=JKNuCzTI6B9_yCbcmTf2UFjAQXulLNEmloqtAhKJKjQ,9830
+tirex/models/slstm/block.py,sha256=V91Amgz8WAOOHo4fK1UZxd4Dgbx4-X6kUBS6X4m0tKQ,2006
+tirex/models/slstm/cell.py,sha256=JfCs1aUy9IHuz9RwExhUwiUtbg8WmbEg4upcO7hA5Rg,7229
+tirex/models/slstm/layer.py,sha256=93CAYuG-HmUpF7mBAQ-z1S1u2__W10EW5jPToR57qqM,2747
+tirex_mirror-2025.10.7.dist-info/licenses/LICENSE,sha256=HlwHKnGTlE2oNm6734V-Vy62zlkWohnuZpYXSdkqDk4,7362
+tirex_mirror-2025.10.7.dist-info/licenses/LICENSE_MIRROR.txt,sha256=ulPZMcOZdN7JvISjiID3KUwovTjrPwiMv5ku9dM7nls,496
+tirex_mirror-2025.10.7.dist-info/licenses/NOTICE.txt,sha256=rcgDscFHb-uuZO3L0_vIxYhTYl-a2Rm0lBpp3_kKdFQ,147
+tirex_mirror-2025.10.7.dist-info/METADATA,sha256=KdlklRtYjCWVG3GB12RFNi1fPxYO0qmVqlrq6r7mURs,11443
+tirex_mirror-2025.10.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+tirex_mirror-2025.10.7.dist-info/top_level.txt,sha256=AOLDhfv0F_7nn3pFq0Kapg6Ky_28I_cGDXzQX3w9eO4,6
+tirex_mirror-2025.10.7.dist-info/RECORD,,
tirex_mirror-2025.10.2.dist-info/RECORD
DELETED
@@ -1,21 +0,0 @@
-tirex/__init__.py,sha256=rfsOeCJ7eRqU3K3TOhfN5-4XUuZFqt11wBRxk5SoAWA,292
-tirex/base.py,sha256=fwyUTGL103kK5jgK5MoLSIHQcZb4lrox_D9fNbY1W1k,3507
-tirex/util.py,sha256=7DFVBXwGQA4niT9VhYbt8iKMBINJVW4LfwwpggFS0Us,469
-tirex/api_adapter/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
-tirex/api_adapter/forecast.py,sha256=snv0sT1_1WzjkhP1YV-I7CMQmSChl93qFc3b6fwUAS0,8502
-tirex/api_adapter/gluon.py,sha256=faiYyn0kBBVQKbpWqrVoyylxZUrmr-qce66twpguVds,1827
-tirex/api_adapter/hf_data.py,sha256=T1eaxqC3OO9yOzIvw4sr55x6iA2AHKJTZd36rROM4fQ,1377
-tirex/api_adapter/standard_adapter.py,sha256=bI3XGYlWQu5EDyhDZyYqOJMbwi5h1aovPQvfHuWETJk,2618
-tirex/models/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
-tirex/models/patcher.py,sha256=EOXFkHsPkq0nuxRNLAbnrgJtcYq0IMC3YIg_16WArg4,3213
-tirex/models/tirex.py,sha256=Kglea86t_f3nXXHSjFgssxxrd1Qbwfr1eB_5gKfWYxM,9098
-tirex/models/slstm/block.py,sha256=V91Amgz8WAOOHo4fK1UZxd4Dgbx4-X6kUBS6X4m0tKQ,2006
-tirex/models/slstm/cell.py,sha256=ippaAPKI83j3_1l3pu9ks-iBGO641Elm1W4HsHgVu-c,7601
-tirex/models/slstm/layer.py,sha256=93CAYuG-HmUpF7mBAQ-z1S1u2__W10EW5jPToR57qqM,2747
-tirex_mirror-2025.10.2.dist-info/licenses/LICENSE,sha256=HlwHKnGTlE2oNm6734V-Vy62zlkWohnuZpYXSdkqDk4,7362
-tirex_mirror-2025.10.2.dist-info/licenses/LICENSE_MIRROR.txt,sha256=ulPZMcOZdN7JvISjiID3KUwovTjrPwiMv5ku9dM7nls,496
-tirex_mirror-2025.10.2.dist-info/licenses/NOTICE.txt,sha256=rcgDscFHb-uuZO3L0_vIxYhTYl-a2Rm0lBpp3_kKdFQ,147
-tirex_mirror-2025.10.2.dist-info/METADATA,sha256=Aq9VAU0pojVsrxwbfFLCmwmk1Gfl_Z_49G4yB-Z9eLY,11443
-tirex_mirror-2025.10.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-tirex_mirror-2025.10.2.dist-info/top_level.txt,sha256=AOLDhfv0F_7nn3pFq0Kapg6Ky_28I_cGDXzQX3w9eO4,6
-tirex_mirror-2025.10.2.dist-info/RECORD,,
{tirex_mirror-2025.10.2.dist-info → tirex_mirror-2025.10.7.dist-info}/WHEEL
RENAMED
File without changes
{tirex_mirror-2025.10.2.dist-info → tirex_mirror-2025.10.7.dist-info}/licenses/LICENSE
RENAMED
File without changes
{tirex_mirror-2025.10.2.dist-info → tirex_mirror-2025.10.7.dist-info}/licenses/LICENSE_MIRROR.txt
RENAMED
File without changes
{tirex_mirror-2025.10.2.dist-info → tirex_mirror-2025.10.7.dist-info}/licenses/NOTICE.txt
RENAMED
File without changes
{tirex_mirror-2025.10.2.dist-info → tirex_mirror-2025.10.7.dist-info}/top_level.txt
RENAMED
File without changes