tirex-mirror 2025.10.28.tar.gz → 2025.11.4.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/PKG-INFO +1 -1
  2. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/pyproject.toml +1 -1
  3. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/api_adapter/forecast.py +92 -49
  4. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/api_adapter/gluon.py +2 -1
  5. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/util.py +14 -2
  6. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex_mirror.egg-info/PKG-INFO +1 -1
  7. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/tests/test_chronos_zs.py +4 -4
  8. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/tests/test_forecast_adapter.py +6 -11
  9. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/LICENSE +0 -0
  10. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/LICENSE_MIRROR.txt +0 -0
  11. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/MANIFEST.in +0 -0
  12. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/NOTICE.txt +0 -0
  13. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/README.md +0 -0
  14. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/setup.cfg +0 -0
  15. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/__init__.py +0 -0
  16. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/api_adapter/__init__.py +0 -0
  17. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/api_adapter/hf_data.py +0 -0
  18. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/api_adapter/standard_adapter.py +0 -0
  19. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/base.py +0 -0
  20. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/models/__init__.py +0 -0
  21. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/models/patcher.py +0 -0
  22. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/models/slstm/block.py +0 -0
  23. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/models/slstm/cell.py +0 -0
  24. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/models/slstm/layer.py +0 -0
  25. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex/models/tirex.py +0 -0
  26. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex_mirror.egg-info/SOURCES.txt +0 -0
  27. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex_mirror.egg-info/dependency_links.txt +0 -0
  28. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex_mirror.egg-info/requires.txt +0 -0
  29. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/src/tirex_mirror.egg-info/top_level.txt +0 -0
  30. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/tests/test_compile.py +0 -0
  31. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/tests/test_forecast.py +0 -0
  32. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/tests/test_patcher.py +0 -0
  33. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/tests/test_slstm_torch_vs_cuda.py +0 -0
  34. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/tests/test_standard_adapter.py +0 -0
  35. {tirex_mirror-2025.10.28 → tirex_mirror-2025.11.4}/tests/test_util_freq.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tirex-mirror
-Version: 2025.10.28
+Version: 2025.11.4
 Summary: Unofficial mirror of NX-AI/tirex for packaging
 Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
 License: NXAI COMMUNITY LICENSE AGREEMENT
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "tirex-mirror"
-version = "2025.10.28"
+version = "2025.11.04"
 description = "Unofficial mirror of NX-AI/tirex for packaging"
 readme = "README.md"
 requires-python = ">=3.11"
src/tirex/api_adapter/forecast.py
@@ -23,7 +23,6 @@ def _format_output(
     quantiles: torch.Tensor,
     means: torch.Tensor,
     sample_meta: list[dict],
-    quantile_levels: list[float],
     output_type: Literal["torch", "numpy", "gluonts"],
 ):
     if output_type == "torch":
@@ -35,7 +34,7 @@ def _format_output(
             from .gluon import format_gluonts_output
         except ImportError:
             raise ValueError("output_type glutonts needs GluonTs but GluonTS is not available (not installed)!")
-        return format_gluonts_output(quantiles, means, sample_meta, quantile_levels)
+        return format_gluonts_output(quantiles, means, sample_meta)
     else:
         raise ValueError(f"Invalid output type: {output_type}")
 
@@ -61,14 +60,13 @@ def _pad_time_series_batch(
     return padded
 
 
-def _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs):
+def _as_generator(batches, fc_func, output_type, **predict_kwargs):
     for batch_ctx, batch_meta in batches:
         quantiles, mean = fc_func(batch_ctx, **predict_kwargs)
         yield _format_output(
             quantiles=quantiles,
             means=mean,
             sample_meta=batch_meta,
-            quantile_levels=quantile_levels,
             output_type=output_type,
         )
 
@@ -152,7 +150,6 @@ def _gen_forecast(
     fc_func,
     batches,
     output_type,
-    quantile_levels,
    yield_per_batch,
     resample_strategy: str | None = None,
     max_context: int = 2016,
@@ -173,7 +170,7 @@ def _gen_forecast(
         fc_func = partial(_call_fc_with_padding, base_fc_func)
 
     if yield_per_batch:
-        return _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs)
+        return _as_generator(batches, fc_func, output_type, **predict_kwargs)
 
     prediction_q = []
     prediction_m = []
@@ -191,34 +188,48 @@ def _gen_forecast(
         quantiles=prediction_q,
         means=prediction_m,
         sample_meta=sample_meta,
-        quantile_levels=quantile_levels,
         output_type=output_type,
     )
 
 
-def _common_forecast_doc():
-    common_doc = f"""
+class ForecastModel(ABC):
+    @abstractmethod
+    def _forecast_quantiles(self, batch, **predict_kwargs):
+        pass
+
+    @property
+    def max_context_length(self) -> int:
+        # retrieve the max_context attribute of the model configuration if present
+        return getattr(getattr(self, "config", None), "max_context", 2016)
+
+    def forecast(
+        self,
+        context: ContextType,
+        output_type: Literal["torch", "numpy", "gluonts"] = "torch",
+        batch_size: int = 512,
+        yield_per_batch: bool = False,
+        resample_strategy: Literal["frequency"] | None = None,
+        **predict_kwargs,
+    ):
+        """
         This method takes historical context data as input and outputs probabilistic forecasts.
 
         Args:
             output_type (Literal["torch", "numpy", "gluonts"], optional):
                 Specifies the desired format of the returned forecasts:
-                - "torch": Returns forecasts as `torch.Tensor` objects [batch_dim, forecast_len, |quantile_levels|]
-                - "numpy": Returns forecasts as `numpy.ndarray` objects [batch_dim, forecast_len, |quantile_levels|]
+                - "torch": Returns forecasts as `torch.Tensor` objects [batch_dim, forecast_len, quantile_count]
+                - "numpy": Returns forecasts as `numpy.ndarray` objects [batch_dim, forecast_len, quantile_count]
                 - "gluonts": Returns forecasts as a list of GluonTS `Forecast` objects.
                 Defaults to "torch".
 
             batch_size (int, optional): The number of time series instances to process concurrently by the model.
                 Defaults to 512. Must be $>= 1$.
 
-            quantile_levels (List[float], optional): Quantile levels for which predictions should be generated.
-                Defaults to (0.1, 0.2, ..., 0.9).
-
             yield_per_batch (bool, optional): If `True`, the method will act as a generator, yielding
                 forecasts batch by batch as they are computed.
                 Defaults to `False`.
 
-            resample_strategy (Optional[str], optional): Choose a resampling strategy. Allowed values: {RESAMPLE_STRATEGIES}.
+            resample_strategy (Optional[str], optional): Choose a resampling strategy. Allowed values: "frequency".
                 If `None`, no resampling is applied. Currently only "frequency" is supported.
 
             **predict_kwargs: Additional keyword arguments that are passed directly to the underlying
@@ -233,32 +244,7 @@ def _common_forecast_doc():
                 - If `output_type="torch"`: `Tuple[torch.Tensor, torch.Tensor]` (quantiles, mean).
                 - If `output_type="numpy"`: `Tuple[numpy.ndarray, numpy.ndarray]` (quantiles, mean).
                 - If `output_type="gluonts"`: A `List[gluonts.model.forecast.Forecast]` of all forecasts.
-    """
-    return common_doc
-
 
-class ForecastModel(ABC):
-    @abstractmethod
-    def _forecast_quantiles(self, batch, **predict_kwargs):
-        pass
-
-    @property
-    def max_context_length(self) -> int:
-        # retrieve the max_context attribute of the model configuration if present
-        return getattr(getattr(self, "config", None), "max_context", 2016)
-
-    def forecast(
-        self,
-        context: ContextType,
-        output_type: Literal["torch", "numpy", "gluonts"] = "torch",
-        batch_size: int = 512,
-        quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
-        yield_per_batch: bool = False,
-        resample_strategy: Literal["frequency"] | None = None,
-        **predict_kwargs,
-    ):
-        f"""
-        {_common_forecast_doc}
         Args:
             context (ContextType): The historical "context" data of the time series:
                 - `torch.Tensor`: 1D `[context_length]` or 2D `[batch_dim, context_length]` tensor
@@ -272,7 +258,6 @@ class ForecastModel(ABC):
             self._forecast_quantiles,
             batches,
             output_type,
-            quantile_levels,
             yield_per_batch,
             resample_strategy=resample_strategy,
             max_context=self.max_context_length,
@@ -284,14 +269,44 @@ class ForecastModel(ABC):
         gluonDataset,
         output_type: Literal["torch", "numpy", "gluonts"] = "torch",
         batch_size: int = 512,
-        quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
         yield_per_batch: bool = False,
         resample_strategy: Literal["frequency"] | None = None,
         data_kwargs: dict = {},
         **predict_kwargs,
     ):
-        f"""
-        {_common_forecast_doc()}
+        """
+        This method takes historical context data as input and outputs probabilistic forecasts.
+
+        Args:
+            output_type (Literal["torch", "numpy", "gluonts"], optional):
+                Specifies the desired format of the returned forecasts:
+                - "torch": Returns forecasts as `torch.Tensor` objects [batch_dim, forecast_len, quantile_count]
+                - "numpy": Returns forecasts as `numpy.ndarray` objects [batch_dim, forecast_len, quantile_count]
+                - "gluonts": Returns forecasts as a list of GluonTS `Forecast` objects.
+                Defaults to "torch".
+
+            batch_size (int, optional): The number of time series instances to process concurrently by the model.
+                Defaults to 512. Must be $>= 1$.
+
+            yield_per_batch (bool, optional): If `True`, the method will act as a generator, yielding
+                forecasts batch by batch as they are computed.
+                Defaults to `False`.
+
+            resample_strategy (Optional[str], optional): Choose a resampling strategy. Allowed values: "frequency".
+                If `None`, no resampling is applied. Currently only "frequency" is supported.
+
+            **predict_kwargs: Additional keyword arguments that are passed directly to the underlying
+                prediction mechanism of the pre-trained model. Refer to the model's
+                internal prediction method documentation for available options.
+
+        Returns:
+            The return type depends on `output_type` and `yield_per_batch`:
+            - If `yield_per_batch` is `True`: An iterator that yields forecasts. Each yielded item
+              will correspond to a batch of forecasts in the format specified by `output_type`.
+            - If `yield_per_batch` is `False`: A single object containing all forecasts.
+                - If `output_type="torch"`: `Tuple[torch.Tensor, torch.Tensor]` (quantiles, mean).
+                - If `output_type="numpy"`: `Tuple[numpy.ndarray, numpy.ndarray]` (quantiles, mean).
+                - If `output_type="gluonts"`: A `List[gluonts.model.forecast.Forecast]` of all forecasts.
 
         Args:
             gluonDataset (gluon_ts.dataset.common.Dataset): A GluonTS dataset object containing the
@@ -311,7 +326,6 @@ class ForecastModel(ABC):
             self._forecast_quantiles,
             batches,
             output_type,
-            quantile_levels,
             yield_per_batch,
             resample_strategy=resample_strategy,
             max_context=self.max_context_length,
@@ -323,14 +337,44 @@ class ForecastModel(ABC):
         hf_dataset,
         output_type: Literal["torch", "numpy", "gluonts"] = "torch",
         batch_size: int = 512,
-        quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
         yield_per_batch: bool = False,
         resample_strategy: Literal["frequency"] | None = None,
         data_kwargs: dict = {},
         **predict_kwargs,
     ):
-        f"""
-        {_common_forecast_doc()}
+        """
+        This method takes historical context data as input and outputs probabilistic forecasts.
+
+        Args:
+            output_type (Literal["torch", "numpy", "gluonts"], optional):
+                Specifies the desired format of the returned forecasts:
+                - "torch": Returns forecasts as `torch.Tensor` objects [batch_dim, forecast_len, quantile_count]
+                - "numpy": Returns forecasts as `numpy.ndarray` objects [batch_dim, forecast_len, quantile_count]
+                - "gluonts": Returns forecasts as a list of GluonTS `Forecast` objects.
+                Defaults to "torch".
+
+            batch_size (int, optional): The number of time series instances to process concurrently by the model.
+                Defaults to 512. Must be $>= 1$.
+
+            yield_per_batch (bool, optional): If `True`, the method will act as a generator, yielding
+                forecasts batch by batch as they are computed.
+                Defaults to `False`.
+
+            resample_strategy (Optional[str], optional): Choose a resampling strategy. Allowed values: "frequency".
+                If `None`, no resampling is applied. Currently only "frequency" is supported.
+
+            **predict_kwargs: Additional keyword arguments that are passed directly to the underlying
+                prediction mechanism of the pre-trained model. Refer to the model's
+                internal prediction method documentation for available options.
+
+        Returns:
+            The return type depends on `output_type` and `yield_per_batch`:
+            - If `yield_per_batch` is `True`: An iterator that yields forecasts. Each yielded item
+              will correspond to a batch of forecasts in the format specified by `output_type`.
+            - If `yield_per_batch` is `False`: A single object containing all forecasts.
+                - If `output_type="torch"`: `Tuple[torch.Tensor, torch.Tensor]` (quantiles, mean).
+                - If `output_type="numpy"`: `Tuple[numpy.ndarray, numpy.ndarray]` (quantiles, mean).
+                - If `output_type="gluonts"`: A `List[gluonts.model.forecast.Forecast]` of all forecasts.
 
         Args:
             hf_dataset (datasets.Dataset): A Hugging Face `Dataset` object containing the
@@ -352,7 +396,6 @@ class ForecastModel(ABC):
             self._forecast_quantiles,
             batches,
             output_type,
-            quantile_levels,
             yield_per_batch,
             resample_strategy=resample_strategy,
             max_context=self.max_context_length,
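For orientation, a minimal sketch of how the changed forecast() call looks after this release: quantile_levels is no longer a parameter, and all nine trained quantiles (0.1 through 0.9) are always returned on the last axis. The checkpoint id passed to load_model and the prediction_length value are illustrative assumptions, not taken from this diff.

import torch
from tirex import load_model

model = load_model("NX-AI/TiRex")  # assumed checkpoint id; any ForecastModel subclass behaves the same
context = torch.rand(8, 256)       # [batch_dim, context_length]

# quantile_levels is gone from the signature; the quantile axis always has size 9
quantiles, mean = model.forecast(context, output_type="torch", prediction_length=24)
# quantiles: [8, 24, 9], mean: [8, 24]

# yield_per_batch=True turns the call into a generator that yields one batch at a time
for q_batch, m_batch in model.forecast(context, batch_size=4, yield_per_batch=True, prediction_length=24):
    print(q_batch.shape, m_batch.shape)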
src/tirex/api_adapter/gluon.py
@@ -30,7 +30,8 @@ def get_gluon_batches(gluonDataset: Dataset, batch_size: int, **gluon_kwargs):
     return _batch_iterable(map(_get_gluon_ts_map(**gluon_kwargs), gluonDataset), batch_size)
 
 
-def format_gluonts_output(quantile_forecasts: torch.Tensor, mean_forecasts, meta: list[dict], quantile_levels):
+def format_gluonts_output(quantile_forecasts: torch.Tensor, mean_forecasts, meta: list[dict]):
+    quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)
     forecasts = []
     for i in range(quantile_forecasts.shape[0]):
         start_date = meta[i].get(FieldName.START, pd.Period("01-01-2000", freq=meta[i].get("freq", "h")))
src/tirex/util.py
@@ -419,7 +419,7 @@ def run_fft_analysis(
     scaling : {'amplitude', 'power', 'raw'}
         - 'amplitude': one-sided amplitude spectrum with window-power compensation.
        - 'power' : one-sided power (not density) with window-power compensation.
-        - 'raw' : |rfft(yw)| (no normalization, mostly for debugging).
+        - 'raw' : rfft(yw) (no normalization, mostly for debugging).
     peak_prominence : float
         Absolute threshold on the normalized spectrum for peak detection.
 
@@ -527,7 +527,6 @@ def resampling_factor(inverted_freq, path_size):
 def custom_find_peaks(
     f,
     spec,
-    *,
     max_peaks=5,
     prominence_threshold=0.1,
     min_period=64,
@@ -615,3 +614,16 @@ def round_up_to_next_multiple_of(x: int, multiple_of: int) -> int:
 def dataclass_from_dict(cls, dict: dict):
     class_fields = {f.name for f in fields(cls)}
     return cls(**{k: v for k, v in dict.items() if k in class_fields})
+
+
+def select_quantile_subset(quantiles: torch.Tensor, quantile_levels: list[float]):
+    """
+    Select specified quantile levels from the quantiles.
+    """
+    trained_quantiles = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
+    assert set(quantile_levels).issubset(trained_quantiles), (
+        f"Only the following quantile_levels are supported: {quantile_levels}"
+    )
+    quantile_levels_idx = [trained_quantiles.index(q) for q in quantile_levels]
+    quantiles_idx = torch.tensor(quantile_levels_idx, dtype=torch.long, device=quantiles.device)
+    return torch.index_select(quantiles, dim=-1, index=quantiles_idx).squeeze(-1)
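Since forecast() no longer filters quantiles, callers that need specific levels select them afterwards with the new tirex.util.select_quantile_subset helper, as the updated tests/test_chronos_zs.py does. A minimal sketch (the random input tensor is only for illustration):

import torch
from tirex.util import select_quantile_subset

# forecast() output: [batch_dim, forecast_len, 9 trained quantiles 0.1 ... 0.9]
quantiles = torch.rand(4, 24, 9)

# keep only the 0.1 / 0.5 / 0.9 levels; indices are looked up against the trained grid
subset = select_quantile_subset(quantiles, [0.1, 0.5, 0.9])
print(subset.shape)  # torch.Size([4, 24, 3]); a single requested level squeezes the trailing axis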
src/tirex_mirror.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tirex-mirror
-Version: 2025.10.28
+Version: 2025.11.4
 Summary: Unofficial mirror of NX-AI/tirex for packaging
 Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
 License: NXAI COMMUNITY LICENSE AGREEMENT
tests/test_chronos_zs.py
@@ -10,6 +10,7 @@ import fev
 import pytest
 
 from tirex import ForecastModel, load_model
+from tirex.util import select_quantile_subset
 
 
 def geometric_mean(s):
@@ -25,14 +26,13 @@ def eval_task(model, task):
         loaded_targets = [t for t in past_data["target"]]
 
         start_time = time.monotonic()
-        quantiles, means = model.forecast(
-            loaded_targets, quantile_levels=task.quantile_levels, prediction_length=task.horizon
-        )
+        quantiles, means = model.forecast(loaded_targets, prediction_length=task.horizon)
         inference_time += time.monotonic() - start_time
 
         predictions_dict = {"predictions": means}
+        quantiles_subset = select_quantile_subset(quantiles, task.quantile_levels)
         for idx, level in enumerate(task.quantile_levels):
-            predictions_dict[str(level)] = quantiles[:, :, idx]
+            predictions_dict[str(level)] = quantiles_subset[:, :, idx]
 
         predictions_per_window.append(
             fev.combine_univariate_predictions_to_multivariate(
tests/test_forecast_adapter.py
@@ -25,18 +25,15 @@ class DummyForecaster(ForecastModel):
         return fc_random_from_tensor(batch, **kwargs)
 
 
-quantile_levels = list(np.linspace(0.1, 0.9, 9))
-
-
 # ----- Tests: Output formatting -----
 def test_format_output_shapes(dummy_fc_func):
     B, L = 2, 5
     PL = 10
     q, m = dummy_fc_func(torch.rand(B, L), prediction_length=PL)
-    out_q, out_m = _format_output(q, m, [{}] * B, quantile_levels, "torch")
+    out_q, out_m = _format_output(q, m, [{}] * B, "torch")
     assert isinstance(out_q, torch.Tensor)
     assert isinstance(out_m, torch.Tensor)
-    assert out_q.shape == (B, PL, len(quantile_levels))
+    assert out_q.shape == (B, PL, 9)
     assert out_m.shape == (B, PL)
 
 
@@ -44,10 +41,10 @@ def test_format_output_shapes(dummy_fc_func):
     B, L = 2, 5
     PL = 10
     q, m = dummy_fc_func(torch.rand(B, L), prediction_length=PL)
-    out_q, out_m = _format_output(q, m, [{}] * B, quantile_levels, "numpy")
+    out_q, out_m = _format_output(q, m, [{}] * B, "numpy")
     assert isinstance(out_q, np.ndarray)
     assert isinstance(out_m, np.ndarray)
-    assert out_q.shape == (B, PL, len(quantile_levels))
+    assert out_q.shape == (B, PL, 9)
     assert out_m.shape == (B, PL)
 
 
@@ -55,7 +52,7 @@ def test_format_output_shapes(dummy_fc_func):
 def test_gen_forecast_single_batch(dummy_fc_func):
     context = torch.rand((5, 20))
     batches = get_batches(context, batch_size=2)
-    q, m = _gen_forecast(dummy_fc_func, batches, "torch", quantile_levels, prediction_length=10, yield_per_batch=False)
+    q, m = _gen_forecast(dummy_fc_func, batches, "torch", prediction_length=10, yield_per_batch=False)
     assert q.shape == (5, 10, 9)
     assert m.shape == (5, 10)
 
@@ -63,9 +60,7 @@ def test_gen_forecast_single_batch(dummy_fc_func):
 def test_gen_forecast_iterator(dummy_fc_func):
     context = torch.rand((5, 20))
     batches = get_batches(context, batch_size=2)
-    iterator = _gen_forecast(
-        dummy_fc_func, batches, "torch", quantile_levels, prediction_length=10, yield_per_batch=True
-    )
+    iterator = _gen_forecast(dummy_fc_func, batches, "torch", prediction_length=10, yield_per_batch=True)
     outputs = list(iterator)
     assert len(outputs) == 3
     for i, (q, m) in enumerate(outputs):