tirex-mirror 2025.10.28__tar.gz → 2025.11.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/PKG-INFO +1 -1
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/pyproject.toml +1 -1
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/api_adapter/forecast.py +92 -49
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/api_adapter/gluon.py +2 -1
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/util.py +14 -2
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex_mirror.egg-info/PKG-INFO +1 -1
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/tests/test_chronos_zs.py +4 -4
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/tests/test_forecast_adapter.py +6 -11
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/LICENSE +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/LICENSE_MIRROR.txt +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/MANIFEST.in +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/NOTICE.txt +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/README.md +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/setup.cfg +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/__init__.py +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/api_adapter/__init__.py +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/api_adapter/hf_data.py +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/api_adapter/standard_adapter.py +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/base.py +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/models/__init__.py +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/models/patcher.py +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/models/slstm/block.py +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/models/slstm/cell.py +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/models/slstm/layer.py +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/models/tirex.py +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex_mirror.egg-info/SOURCES.txt +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex_mirror.egg-info/dependency_links.txt +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex_mirror.egg-info/requires.txt +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex_mirror.egg-info/top_level.txt +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/tests/test_compile.py +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/tests/test_forecast.py +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/tests/test_patcher.py +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/tests/test_slstm_torch_vs_cuda.py +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/tests/test_standard_adapter.py +0 -0
- {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/tests/test_util_freq.py +0 -0
|
@@ -23,7 +23,6 @@ def _format_output(
|
|
|
23
23
|
quantiles: torch.Tensor,
|
|
24
24
|
means: torch.Tensor,
|
|
25
25
|
sample_meta: list[dict],
|
|
26
|
-
quantile_levels: list[float],
|
|
27
26
|
output_type: Literal["torch", "numpy", "gluonts"],
|
|
28
27
|
):
|
|
29
28
|
if output_type == "torch":
|
|
@@ -35,7 +34,7 @@ def _format_output(
|
|
|
35
34
|
from .gluon import format_gluonts_output
|
|
36
35
|
except ImportError:
|
|
37
36
|
raise ValueError("output_type glutonts needs GluonTs but GluonTS is not available (not installed)!")
|
|
38
|
-
return format_gluonts_output(quantiles, means, sample_meta
|
|
37
|
+
return format_gluonts_output(quantiles, means, sample_meta)
|
|
39
38
|
else:
|
|
40
39
|
raise ValueError(f"Invalid output type: {output_type}")
|
|
41
40
|
|
|
@@ -61,14 +60,13 @@ def _pad_time_series_batch(
|
|
|
61
60
|
return padded
|
|
62
61
|
|
|
63
62
|
|
|
64
|
-
def _as_generator(batches, fc_func,
|
|
63
|
+
def _as_generator(batches, fc_func, output_type, **predict_kwargs):
|
|
65
64
|
for batch_ctx, batch_meta in batches:
|
|
66
65
|
quantiles, mean = fc_func(batch_ctx, **predict_kwargs)
|
|
67
66
|
yield _format_output(
|
|
68
67
|
quantiles=quantiles,
|
|
69
68
|
means=mean,
|
|
70
69
|
sample_meta=batch_meta,
|
|
71
|
-
quantile_levels=quantile_levels,
|
|
72
70
|
output_type=output_type,
|
|
73
71
|
)
|
|
74
72
|
|
|
@@ -152,7 +150,6 @@ def _gen_forecast(
|
|
|
152
150
|
fc_func,
|
|
153
151
|
batches,
|
|
154
152
|
output_type,
|
|
155
|
-
quantile_levels,
|
|
156
153
|
yield_per_batch,
|
|
157
154
|
resample_strategy: str | None = None,
|
|
158
155
|
max_context: int = 2016,
|
|
@@ -173,7 +170,7 @@ def _gen_forecast(
|
|
|
173
170
|
fc_func = partial(_call_fc_with_padding, base_fc_func)
|
|
174
171
|
|
|
175
172
|
if yield_per_batch:
|
|
176
|
-
return _as_generator(batches, fc_func,
|
|
173
|
+
return _as_generator(batches, fc_func, output_type, **predict_kwargs)
|
|
177
174
|
|
|
178
175
|
prediction_q = []
|
|
179
176
|
prediction_m = []
|
|
@@ -191,34 +188,48 @@ def _gen_forecast(
|
|
|
191
188
|
quantiles=prediction_q,
|
|
192
189
|
means=prediction_m,
|
|
193
190
|
sample_meta=sample_meta,
|
|
194
|
-
quantile_levels=quantile_levels,
|
|
195
191
|
output_type=output_type,
|
|
196
192
|
)
|
|
197
193
|
|
|
198
194
|
|
|
199
|
-
|
|
200
|
-
|
|
195
|
+
class ForecastModel(ABC):
|
|
196
|
+
@abstractmethod
|
|
197
|
+
def _forecast_quantiles(self, batch, **predict_kwargs):
|
|
198
|
+
pass
|
|
199
|
+
|
|
200
|
+
@property
|
|
201
|
+
def max_context_length(self) -> int:
|
|
202
|
+
# retrieve the max_context attribute of the model configuration if present
|
|
203
|
+
return getattr(getattr(self, "config", None), "max_context", 2016)
|
|
204
|
+
|
|
205
|
+
def forecast(
|
|
206
|
+
self,
|
|
207
|
+
context: ContextType,
|
|
208
|
+
output_type: Literal["torch", "numpy", "gluonts"] = "torch",
|
|
209
|
+
batch_size: int = 512,
|
|
210
|
+
yield_per_batch: bool = False,
|
|
211
|
+
resample_strategy: Literal["frequency"] | None = None,
|
|
212
|
+
**predict_kwargs,
|
|
213
|
+
):
|
|
214
|
+
"""
|
|
201
215
|
This method takes historical context data as input and outputs probabilistic forecasts.
|
|
202
216
|
|
|
203
217
|
Args:
|
|
204
218
|
output_type (Literal["torch", "numpy", "gluonts"], optional):
|
|
205
219
|
Specifies the desired format of the returned forecasts:
|
|
206
|
-
- "torch": Returns forecasts as `torch.Tensor` objects [batch_dim, forecast_len,
|
|
207
|
-
- "numpy": Returns forecasts as `numpy.ndarray` objects [batch_dim, forecast_len,
|
|
220
|
+
- "torch": Returns forecasts as `torch.Tensor` objects [batch_dim, forecast_len, quantile_count]
|
|
221
|
+
- "numpy": Returns forecasts as `numpy.ndarray` objects [batch_dim, forecast_len, quantile_count]
|
|
208
222
|
- "gluonts": Returns forecasts as a list of GluonTS `Forecast` objects.
|
|
209
223
|
Defaults to "torch".
|
|
210
224
|
|
|
211
225
|
batch_size (int, optional): The number of time series instances to process concurrently by the model.
|
|
212
226
|
Defaults to 512. Must be $>= 1$.
|
|
213
227
|
|
|
214
|
-
quantile_levels (List[float], optional): Quantile levels for which predictions should be generated.
|
|
215
|
-
Defaults to (0.1, 0.2, ..., 0.9).
|
|
216
|
-
|
|
217
228
|
yield_per_batch (bool, optional): If `True`, the method will act as a generator, yielding
|
|
218
229
|
forecasts batch by batch as they are computed.
|
|
219
230
|
Defaults to `False`.
|
|
220
231
|
|
|
221
|
-
resample_strategy (Optional[str], optional): Choose a resampling strategy. Allowed values:
|
|
232
|
+
resample_strategy (Optional[str], optional): Choose a resampling strategy. Allowed values: "frequency".
|
|
222
233
|
If `None`, no resampling is applied. Currently only "frequency" is supported.
|
|
223
234
|
|
|
224
235
|
**predict_kwargs: Additional keyword arguments that are passed directly to the underlying
|
|
@@ -233,32 +244,7 @@ def _common_forecast_doc():
|
|
|
233
244
|
- If `output_type="torch"`: `Tuple[torch.Tensor, torch.Tensor]` (quantiles, mean).
|
|
234
245
|
- If `output_type="numpy"`: `Tuple[numpy.ndarray, numpy.ndarray]` (quantiles, mean).
|
|
235
246
|
- If `output_type="gluonts"`: A `List[gluonts.model.forecast.Forecast]` of all forecasts.
|
|
236
|
-
"""
|
|
237
|
-
return common_doc
|
|
238
|
-
|
|
239
247
|
|
|
240
|
-
class ForecastModel(ABC):
|
|
241
|
-
@abstractmethod
|
|
242
|
-
def _forecast_quantiles(self, batch, **predict_kwargs):
|
|
243
|
-
pass
|
|
244
|
-
|
|
245
|
-
@property
|
|
246
|
-
def max_context_length(self) -> int:
|
|
247
|
-
# retrieve the max_context attribute of the model configuration if present
|
|
248
|
-
return getattr(getattr(self, "config", None), "max_context", 2016)
|
|
249
|
-
|
|
250
|
-
def forecast(
|
|
251
|
-
self,
|
|
252
|
-
context: ContextType,
|
|
253
|
-
output_type: Literal["torch", "numpy", "gluonts"] = "torch",
|
|
254
|
-
batch_size: int = 512,
|
|
255
|
-
quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
|
|
256
|
-
yield_per_batch: bool = False,
|
|
257
|
-
resample_strategy: Literal["frequency"] | None = None,
|
|
258
|
-
**predict_kwargs,
|
|
259
|
-
):
|
|
260
|
-
f"""
|
|
261
|
-
{_common_forecast_doc}
|
|
262
248
|
Args:
|
|
263
249
|
context (ContextType): The historical "context" data of the time series:
|
|
264
250
|
- `torch.Tensor`: 1D `[context_length]` or 2D `[batch_dim, context_length]` tensor
|
|
@@ -272,7 +258,6 @@ class ForecastModel(ABC):
|
|
|
272
258
|
self._forecast_quantiles,
|
|
273
259
|
batches,
|
|
274
260
|
output_type,
|
|
275
|
-
quantile_levels,
|
|
276
261
|
yield_per_batch,
|
|
277
262
|
resample_strategy=resample_strategy,
|
|
278
263
|
max_context=self.max_context_length,
|
|
@@ -284,14 +269,44 @@ class ForecastModel(ABC):
|
|
|
284
269
|
gluonDataset,
|
|
285
270
|
output_type: Literal["torch", "numpy", "gluonts"] = "torch",
|
|
286
271
|
batch_size: int = 512,
|
|
287
|
-
quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
|
|
288
272
|
yield_per_batch: bool = False,
|
|
289
273
|
resample_strategy: Literal["frequency"] | None = None,
|
|
290
274
|
data_kwargs: dict = {},
|
|
291
275
|
**predict_kwargs,
|
|
292
276
|
):
|
|
293
|
-
|
|
294
|
-
|
|
277
|
+
"""
|
|
278
|
+
This method takes historical context data as input and outputs probabilistic forecasts.
|
|
279
|
+
|
|
280
|
+
Args:
|
|
281
|
+
output_type (Literal["torch", "numpy", "gluonts"], optional):
|
|
282
|
+
Specifies the desired format of the returned forecasts:
|
|
283
|
+
- "torch": Returns forecasts as `torch.Tensor` objects [batch_dim, forecast_len, quantile_count]
|
|
284
|
+
- "numpy": Returns forecasts as `numpy.ndarray` objects [batch_dim, forecast_len, quantile_count]
|
|
285
|
+
- "gluonts": Returns forecasts as a list of GluonTS `Forecast` objects.
|
|
286
|
+
Defaults to "torch".
|
|
287
|
+
|
|
288
|
+
batch_size (int, optional): The number of time series instances to process concurrently by the model.
|
|
289
|
+
Defaults to 512. Must be $>= 1$.
|
|
290
|
+
|
|
291
|
+
yield_per_batch (bool, optional): If `True`, the method will act as a generator, yielding
|
|
292
|
+
forecasts batch by batch as they are computed.
|
|
293
|
+
Defaults to `False`.
|
|
294
|
+
|
|
295
|
+
resample_strategy (Optional[str], optional): Choose a resampling strategy. Allowed values: "frequency".
|
|
296
|
+
If `None`, no resampling is applied. Currently only "frequency" is supported.
|
|
297
|
+
|
|
298
|
+
**predict_kwargs: Additional keyword arguments that are passed directly to the underlying
|
|
299
|
+
prediction mechanism of the pre-trained model. Refer to the model's
|
|
300
|
+
internal prediction method documentation for available options.
|
|
301
|
+
|
|
302
|
+
Returns:
|
|
303
|
+
The return type depends on `output_type` and `yield_per_batch`:
|
|
304
|
+
- If `yield_per_batch` is `True`: An iterator that yields forecasts. Each yielded item
|
|
305
|
+
will correspond to a batch of forecasts in the format specified by `output_type`.
|
|
306
|
+
- If `yield_per_batch` is `False`: A single object containing all forecasts.
|
|
307
|
+
- If `output_type="torch"`: `Tuple[torch.Tensor, torch.Tensor]` (quantiles, mean).
|
|
308
|
+
- If `output_type="numpy"`: `Tuple[numpy.ndarray, numpy.ndarray]` (quantiles, mean).
|
|
309
|
+
- If `output_type="gluonts"`: A `List[gluonts.model.forecast.Forecast]` of all forecasts.
|
|
295
310
|
|
|
296
311
|
Args:
|
|
297
312
|
gluonDataset (gluon_ts.dataset.common.Dataset): A GluonTS dataset object containing the
|
|
@@ -311,7 +326,6 @@ class ForecastModel(ABC):
|
|
|
311
326
|
self._forecast_quantiles,
|
|
312
327
|
batches,
|
|
313
328
|
output_type,
|
|
314
|
-
quantile_levels,
|
|
315
329
|
yield_per_batch,
|
|
316
330
|
resample_strategy=resample_strategy,
|
|
317
331
|
max_context=self.max_context_length,
|
|
@@ -323,14 +337,44 @@ class ForecastModel(ABC):
|
|
|
323
337
|
hf_dataset,
|
|
324
338
|
output_type: Literal["torch", "numpy", "gluonts"] = "torch",
|
|
325
339
|
batch_size: int = 512,
|
|
326
|
-
quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
|
|
327
340
|
yield_per_batch: bool = False,
|
|
328
341
|
resample_strategy: Literal["frequency"] | None = None,
|
|
329
342
|
data_kwargs: dict = {},
|
|
330
343
|
**predict_kwargs,
|
|
331
344
|
):
|
|
332
|
-
|
|
333
|
-
|
|
345
|
+
"""
|
|
346
|
+
This method takes historical context data as input and outputs probabilistic forecasts.
|
|
347
|
+
|
|
348
|
+
Args:
|
|
349
|
+
output_type (Literal["torch", "numpy", "gluonts"], optional):
|
|
350
|
+
Specifies the desired format of the returned forecasts:
|
|
351
|
+
- "torch": Returns forecasts as `torch.Tensor` objects [batch_dim, forecast_len, quantile_count]
|
|
352
|
+
- "numpy": Returns forecasts as `numpy.ndarray` objects [batch_dim, forecast_len, quantile_count]
|
|
353
|
+
- "gluonts": Returns forecasts as a list of GluonTS `Forecast` objects.
|
|
354
|
+
Defaults to "torch".
|
|
355
|
+
|
|
356
|
+
batch_size (int, optional): The number of time series instances to process concurrently by the model.
|
|
357
|
+
Defaults to 512. Must be $>= 1$.
|
|
358
|
+
|
|
359
|
+
yield_per_batch (bool, optional): If `True`, the method will act as a generator, yielding
|
|
360
|
+
forecasts batch by batch as they are computed.
|
|
361
|
+
Defaults to `False`.
|
|
362
|
+
|
|
363
|
+
resample_strategy (Optional[str], optional): Choose a resampling strategy. Allowed values: "frequency".
|
|
364
|
+
If `None`, no resampling is applied. Currently only "frequency" is supported.
|
|
365
|
+
|
|
366
|
+
**predict_kwargs: Additional keyword arguments that are passed directly to the underlying
|
|
367
|
+
prediction mechanism of the pre-trained model. Refer to the model's
|
|
368
|
+
internal prediction method documentation for available options.
|
|
369
|
+
|
|
370
|
+
Returns:
|
|
371
|
+
The return type depends on `output_type` and `yield_per_batch`:
|
|
372
|
+
- If `yield_per_batch` is `True`: An iterator that yields forecasts. Each yielded item
|
|
373
|
+
will correspond to a batch of forecasts in the format specified by `output_type`.
|
|
374
|
+
- If `yield_per_batch` is `False`: A single object containing all forecasts.
|
|
375
|
+
- If `output_type="torch"`: `Tuple[torch.Tensor, torch.Tensor]` (quantiles, mean).
|
|
376
|
+
- If `output_type="numpy"`: `Tuple[numpy.ndarray, numpy.ndarray]` (quantiles, mean).
|
|
377
|
+
- If `output_type="gluonts"`: A `List[gluonts.model.forecast.Forecast]` of all forecasts.
|
|
334
378
|
|
|
335
379
|
Args:
|
|
336
380
|
hf_dataset (datasets.Dataset): A Hugging Face `Dataset` object containing the
|
|
@@ -352,7 +396,6 @@ class ForecastModel(ABC):
|
|
|
352
396
|
self._forecast_quantiles,
|
|
353
397
|
batches,
|
|
354
398
|
output_type,
|
|
355
|
-
quantile_levels,
|
|
356
399
|
yield_per_batch,
|
|
357
400
|
resample_strategy=resample_strategy,
|
|
358
401
|
max_context=self.max_context_length,
|
|
@@ -30,7 +30,8 @@ def get_gluon_batches(gluonDataset: Dataset, batch_size: int, **gluon_kwargs):
|
|
|
30
30
|
return _batch_iterable(map(_get_gluon_ts_map(**gluon_kwargs), gluonDataset), batch_size)
|
|
31
31
|
|
|
32
32
|
|
|
33
|
-
def format_gluonts_output(quantile_forecasts: torch.Tensor, mean_forecasts, meta: list[dict]
|
|
33
|
+
def format_gluonts_output(quantile_forecasts: torch.Tensor, mean_forecasts, meta: list[dict]):
|
|
34
|
+
quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)
|
|
34
35
|
forecasts = []
|
|
35
36
|
for i in range(quantile_forecasts.shape[0]):
|
|
36
37
|
start_date = meta[i].get(FieldName.START, pd.Period("01-01-2000", freq=meta[i].get("freq", "h")))
|
|
@@ -419,7 +419,7 @@ def run_fft_analysis(
|
|
|
419
419
|
scaling : {'amplitude', 'power', 'raw'}
|
|
420
420
|
- 'amplitude': one-sided amplitude spectrum with window-power compensation.
|
|
421
421
|
- 'power' : one-sided power (not density) with window-power compensation.
|
|
422
|
-
- 'raw' :
|
|
422
|
+
- 'raw' : rfft(yw) (no normalization, mostly for debugging).
|
|
423
423
|
peak_prominence : float
|
|
424
424
|
Absolute threshold on the normalized spectrum for peak detection.
|
|
425
425
|
|
|
@@ -527,7 +527,6 @@ def resampling_factor(inverted_freq, path_size):
|
|
|
527
527
|
def custom_find_peaks(
|
|
528
528
|
f,
|
|
529
529
|
spec,
|
|
530
|
-
*,
|
|
531
530
|
max_peaks=5,
|
|
532
531
|
prominence_threshold=0.1,
|
|
533
532
|
min_period=64,
|
|
@@ -615,3 +614,16 @@ def round_up_to_next_multiple_of(x: int, multiple_of: int) -> int:
|
|
|
615
614
|
def dataclass_from_dict(cls, dict: dict):
|
|
616
615
|
class_fields = {f.name for f in fields(cls)}
|
|
617
616
|
return cls(**{k: v for k, v in dict.items() if k in class_fields})
|
|
617
|
+
|
|
618
|
+
|
|
619
|
+
def select_quantile_subset(quantiles: torch.Tensor, quantile_levels: list[float]):
|
|
620
|
+
"""
|
|
621
|
+
Select specified quantile levels from the quantiles.
|
|
622
|
+
"""
|
|
623
|
+
trained_quantiles = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
|
|
624
|
+
assert set(quantile_levels).issubset(trained_quantiles), (
|
|
625
|
+
f"Only the following quantile_levels are supported: {quantile_levels}"
|
|
626
|
+
)
|
|
627
|
+
quantile_levels_idx = [trained_quantiles.index(q) for q in quantile_levels]
|
|
628
|
+
quantiles_idx = torch.tensor(quantile_levels_idx, dtype=torch.long, device=quantiles.device)
|
|
629
|
+
return torch.index_select(quantiles, dim=-1, index=quantiles_idx).squeeze(-1)
|
|
@@ -10,6 +10,7 @@ import fev
|
|
|
10
10
|
import pytest
|
|
11
11
|
|
|
12
12
|
from tirex import ForecastModel, load_model
|
|
13
|
+
from tirex.util import select_quantile_subset
|
|
13
14
|
|
|
14
15
|
|
|
15
16
|
def geometric_mean(s):
|
|
@@ -25,14 +26,13 @@ def eval_task(model, task):
|
|
|
25
26
|
loaded_targets = [t for t in past_data["target"]]
|
|
26
27
|
|
|
27
28
|
start_time = time.monotonic()
|
|
28
|
-
quantiles, means = model.forecast(
|
|
29
|
-
loaded_targets, quantile_levels=task.quantile_levels, prediction_length=task.horizon
|
|
30
|
-
)
|
|
29
|
+
quantiles, means = model.forecast(loaded_targets, prediction_length=task.horizon)
|
|
31
30
|
inference_time += time.monotonic() - start_time
|
|
32
31
|
|
|
33
32
|
predictions_dict = {"predictions": means}
|
|
33
|
+
quantiles_subset = select_quantile_subset(quantiles, task.quantile_levels)
|
|
34
34
|
for idx, level in enumerate(task.quantile_levels):
|
|
35
|
-
predictions_dict[str(level)] =
|
|
35
|
+
predictions_dict[str(level)] = quantiles_subset[:, :, idx]
|
|
36
36
|
|
|
37
37
|
predictions_per_window.append(
|
|
38
38
|
fev.combine_univariate_predictions_to_multivariate(
|
|
@@ -25,18 +25,15 @@ class DummyForecaster(ForecastModel):
|
|
|
25
25
|
return fc_random_from_tensor(batch, **kwargs)
|
|
26
26
|
|
|
27
27
|
|
|
28
|
-
quantile_levels = list(np.linspace(0.1, 0.9, 9))
|
|
29
|
-
|
|
30
|
-
|
|
31
28
|
# ----- Tests: Output formatting -----
|
|
32
29
|
def test_format_output_shapes(dummy_fc_func):
|
|
33
30
|
B, L = 2, 5
|
|
34
31
|
PL = 10
|
|
35
32
|
q, m = dummy_fc_func(torch.rand(B, L), prediction_length=PL)
|
|
36
|
-
out_q, out_m = _format_output(q, m, [{}] * B,
|
|
33
|
+
out_q, out_m = _format_output(q, m, [{}] * B, "torch")
|
|
37
34
|
assert isinstance(out_q, torch.Tensor)
|
|
38
35
|
assert isinstance(out_m, torch.Tensor)
|
|
39
|
-
assert out_q.shape == (B, PL,
|
|
36
|
+
assert out_q.shape == (B, PL, 9)
|
|
40
37
|
assert out_m.shape == (B, PL)
|
|
41
38
|
|
|
42
39
|
|
|
@@ -44,10 +41,10 @@ def test_format_output_shapes(dummy_fc_func):
|
|
|
44
41
|
B, L = 2, 5
|
|
45
42
|
PL = 10
|
|
46
43
|
q, m = dummy_fc_func(torch.rand(B, L), prediction_length=PL)
|
|
47
|
-
out_q, out_m = _format_output(q, m, [{}] * B,
|
|
44
|
+
out_q, out_m = _format_output(q, m, [{}] * B, "numpy")
|
|
48
45
|
assert isinstance(out_q, np.ndarray)
|
|
49
46
|
assert isinstance(out_m, np.ndarray)
|
|
50
|
-
assert out_q.shape == (B, PL,
|
|
47
|
+
assert out_q.shape == (B, PL, 9)
|
|
51
48
|
assert out_m.shape == (B, PL)
|
|
52
49
|
|
|
53
50
|
|
|
@@ -55,7 +52,7 @@ def test_format_output_shapes(dummy_fc_func):
|
|
|
55
52
|
def test_gen_forecast_single_batch(dummy_fc_func):
|
|
56
53
|
context = torch.rand((5, 20))
|
|
57
54
|
batches = get_batches(context, batch_size=2)
|
|
58
|
-
q, m = _gen_forecast(dummy_fc_func, batches, "torch",
|
|
55
|
+
q, m = _gen_forecast(dummy_fc_func, batches, "torch", prediction_length=10, yield_per_batch=False)
|
|
59
56
|
assert q.shape == (5, 10, 9)
|
|
60
57
|
assert m.shape == (5, 10)
|
|
61
58
|
|
|
@@ -63,9 +60,7 @@ def test_gen_forecast_single_batch(dummy_fc_func):
|
|
|
63
60
|
def test_gen_forecast_iterator(dummy_fc_func):
|
|
64
61
|
context = torch.rand((5, 20))
|
|
65
62
|
batches = get_batches(context, batch_size=2)
|
|
66
|
-
iterator = _gen_forecast(
|
|
67
|
-
dummy_fc_func, batches, "torch", quantile_levels, prediction_length=10, yield_per_batch=True
|
|
68
|
-
)
|
|
63
|
+
iterator = _gen_forecast(dummy_fc_func, batches, "torch", prediction_length=10, yield_per_batch=True)
|
|
69
64
|
outputs = list(iterator)
|
|
70
65
|
assert len(outputs) == 3
|
|
71
66
|
for i, (q, m) in enumerate(outputs):
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/api_adapter/standard_adapter.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex_mirror.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|