tirex-mirror 2025.6.9.dev3__tar.gz → 2025.8.28__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {tirex_mirror-2025.6.9.dev3/src/tirex_mirror.egg-info → tirex_mirror-2025.8.28}/PKG-INFO +4 -1
  2. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/pyproject.toml +10 -2
  3. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/src/tirex/api_adapter/forecast.py +11 -17
  4. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28/src/tirex_mirror.egg-info}/PKG-INFO +4 -1
  5. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/src/tirex_mirror.egg-info/SOURCES.txt +2 -0
  6. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/src/tirex_mirror.egg-info/requires.txt +4 -0
  7. tirex_mirror-2025.8.28/tests/test_chronos_zs.py +85 -0
  8. tirex_mirror-2025.8.28/tests/test_forecast.py +35 -0
  9. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/LICENSE +0 -0
  10. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/LICENSE_MIRROR.txt +0 -0
  11. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/MANIFEST.in +0 -0
  12. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/NOTICE.txt +0 -0
  13. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/README.md +0 -0
  14. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/setup.cfg +0 -0
  15. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/src/tirex/__init__.py +0 -0
  16. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/src/tirex/api_adapter/__init__.py +0 -0
  17. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/src/tirex/api_adapter/gluon.py +0 -0
  18. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/src/tirex/api_adapter/hf_data.py +0 -0
  19. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/src/tirex/api_adapter/standard_adapter.py +0 -0
  20. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/src/tirex/base.py +0 -0
  21. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/src/tirex/models/__init__.py +0 -0
  22. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/src/tirex/models/components.py +0 -0
  23. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/src/tirex/models/mixed_stack.py +0 -0
  24. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/src/tirex/models/predict_utils.py +0 -0
  25. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/src/tirex/models/tirex.py +0 -0
  26. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/src/tirex_mirror.egg-info/dependency_links.txt +0 -0
  27. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/src/tirex_mirror.egg-info/top_level.txt +0 -0
  28. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/tests/test_forecast_adapter.py +0 -0
  29. {tirex_mirror-2025.6.9.dev3 → tirex_mirror-2025.8.28}/tests/test_standard_adapter.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tirex-mirror
3
- Version: 2025.6.9.dev3
3
+ Version: 2025.8.28
4
4
  Summary: Unofficial mirror of NX-AI/tirex for packaging
5
5
  Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
6
6
  License: NXAI COMMUNITY LICENSE AGREEMENT
@@ -80,10 +80,13 @@ Provides-Extra: gluonts
80
80
  Requires-Dist: gluonts; extra == "gluonts"
81
81
  Provides-Extra: hfdataset
82
82
  Requires-Dist: datasets; extra == "hfdataset"
83
+ Provides-Extra: test
84
+ Requires-Dist: fev; extra == "test"
83
85
  Provides-Extra: all
84
86
  Requires-Dist: ipykernel; extra == "all"
85
87
  Requires-Dist: gluonts; extra == "all"
86
88
  Requires-Dist: datasets; extra == "all"
89
+ Requires-Dist: fev; extra == "all"
87
90
  Dynamic: license-file
88
91
 
89
92
  # tirex-mirror
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "tirex-mirror"
3
- version = "2025.06.09dev3"
3
+ version = "2025.08.28"
4
4
  description = "Unofficial mirror of NX-AI/tirex for packaging"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.11"
@@ -26,7 +26,8 @@ Issues = "https://github.com/rozsasarpi/tirex-mirror/issues"
26
26
  notebooks = [ "ipykernel",]
27
27
  gluonts = [ "gluonts",]
28
28
  hfdataset = [ "datasets",]
29
- all = [ "ipykernel", "gluonts", "datasets",]
29
+ test = [ "fev",]
30
+ all = [ "ipykernel", "gluonts", "datasets", "fev",]
30
31
 
31
32
  [tool.docformatter]
32
33
  diff = false
@@ -43,6 +44,13 @@ exclude = [ ".eggs", ".git", ".ipynb_checkpoints", ".pytest_cache", ".ruff_cache
43
44
  line-length = 120
44
45
  target-version = "py311"
45
46
 
47
+ [tool.pytest.ini_options]
48
+ pythonpath = "."
49
+ log_cli = true
50
+ log_cli_level = "INFO"
51
+ log_format = "%(asctime)s %(levelname)s %(message)s"
52
+ log_date_format = "%Y-%m-%d %H:%M:%S"
53
+
46
54
  [tool.ruff.format]
47
55
  quote-style = "double"
48
56
 
@@ -8,20 +8,6 @@ import torch
8
8
 
9
9
  from .standard_adapter import ContextType, get_batches
10
10
 
11
- try:
12
- from .gluon import format_gluonts_output, get_gluon_batches
13
-
14
- _GLUONTS_AVAILABLE = True
15
- except ImportError:
16
- _GLUONTS_AVAILABLE = False
17
-
18
- try:
19
- from .hf_data import get_hfdata_batches
20
-
21
- _HF_DATASETS_AVAILABLE = True
22
- except ImportError:
23
- _HF_DATASETS_AVAILABLE = False
24
-
25
11
 
26
12
  DEF_TARGET_COLUMN = "target"
27
13
  DEF_META_COLUMNS = ("start", "item_id")
@@ -39,7 +25,9 @@ def _format_output(
39
25
  elif output_type == "numpy":
40
26
  return quantiles.cpu().numpy(), means.cpu().numpy()
41
27
  elif output_type == "gluonts":
42
- if not _GLUONTS_AVAILABLE:
28
+ try:
29
+ from .gluon import format_gluonts_output
30
+ except ImportError:
43
31
  raise ValueError("output_type glutonts needs GluonTs but GluonTS is not available (not installed)!")
44
32
  return format_gluonts_output(quantiles, means, sample_meta, quantile_levels)
45
33
  else:
@@ -171,8 +159,11 @@ class ForecastModel(ABC):
171
159
  autogluon data processing function.
172
160
  """
173
161
  assert batch_size >= 1, "Batch size must be >= 1"
174
- if not _GLUONTS_AVAILABLE:
162
+ try:
163
+ from .gluon import get_gluon_batches
164
+ except ImportError:
175
165
  raise ValueError("forecast_gluon glutonts needs GluonTs but GluonTS is not available (not installed)!")
166
+
176
167
  batches = get_gluon_batches(gluonDataset, batch_size, **data_kwargs)
177
168
  return _gen_forecast(
178
169
  self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
@@ -199,10 +190,13 @@ class ForecastModel(ABC):
199
190
  datasets data processing function.
200
191
  """
201
192
  assert batch_size >= 1, "Batch size must be >= 1"
202
- if not _HF_DATASETS_AVAILABLE:
193
+ try:
194
+ from .hf_data import get_hfdata_batches
195
+ except ImportError:
203
196
  raise ValueError(
204
197
  "forecast_hfdata glutonts needs HuggingFace datasets but datasets is not available (not installed)!"
205
198
  )
199
+
206
200
  batches = get_hfdata_batches(hf_dataset, batch_size, **data_kwargs)
207
201
  return _gen_forecast(
208
202
  self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tirex-mirror
3
- Version: 2025.6.9.dev3
3
+ Version: 2025.8.28
4
4
  Summary: Unofficial mirror of NX-AI/tirex for packaging
5
5
  Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
6
6
  License: NXAI COMMUNITY LICENSE AGREEMENT
@@ -80,10 +80,13 @@ Provides-Extra: gluonts
80
80
  Requires-Dist: gluonts; extra == "gluonts"
81
81
  Provides-Extra: hfdataset
82
82
  Requires-Dist: datasets; extra == "hfdataset"
83
+ Provides-Extra: test
84
+ Requires-Dist: fev; extra == "test"
83
85
  Provides-Extra: all
84
86
  Requires-Dist: ipykernel; extra == "all"
85
87
  Requires-Dist: gluonts; extra == "all"
86
88
  Requires-Dist: datasets; extra == "all"
89
+ Requires-Dist: fev; extra == "all"
87
90
  Dynamic: license-file
88
91
 
89
92
  # tirex-mirror
@@ -21,5 +21,7 @@ src/tirex_mirror.egg-info/SOURCES.txt
21
21
  src/tirex_mirror.egg-info/dependency_links.txt
22
22
  src/tirex_mirror.egg-info/requires.txt
23
23
  src/tirex_mirror.egg-info/top_level.txt
24
+ tests/test_chronos_zs.py
25
+ tests/test_forecast.py
24
26
  tests/test_forecast_adapter.py
25
27
  tests/test_standard_adapter.py
@@ -14,6 +14,7 @@ tqdm
14
14
  ipykernel
15
15
  gluonts
16
16
  datasets
17
+ fev
17
18
 
18
19
  [gluonts]
19
20
  gluonts
@@ -23,3 +24,6 @@ datasets
23
24
 
24
25
  [notebooks]
25
26
  ipykernel
27
+
28
+ [test]
29
+ fev
@@ -0,0 +1,85 @@
1
+ from tirex import ForecastModel, load_model
2
+ import time
3
+ import datasets
4
+ import fev
5
+ import pytest
6
+ import os
7
+ import math
8
+
9
+ def geometric_mean(s):
10
+ return math.prod(s) ** (1/len(s))
11
+
12
+ def eval_task(model, task):
13
+ past_data, _ = task.get_input_data(trust_remote_code=True)
14
+ quantile_levels = task.quantile_levels
15
+ past_data = past_data.with_format("torch").cast_column(
16
+ task.target_column, datasets.Sequence(datasets.Value("float32"))
17
+ )[task.target_column]
18
+ loaded_data = [t for t in past_data]
19
+
20
+ start_time = time.monotonic()
21
+ quantiles, means = model.forecast(loaded_data, quantile_levels=quantile_levels, prediction_length=task.horizon)
22
+ inference_time = time.monotonic() - start_time
23
+ predictions_dict = {"predictions": means}
24
+ for idx, level in enumerate(quantile_levels):
25
+ predictions_dict[str(level)] = quantiles[:, :, idx] # [num_items, horizon]
26
+
27
+ predictions = datasets.Dataset.from_dict(predictions_dict)
28
+ return predictions, inference_time
29
+
30
+
31
+ @pytest.fixture
32
+ def tirex_model() -> ForecastModel:
33
+ return load_model("NX-AI/TiRex")
34
+
35
+
36
+ @pytest.fixture
37
+ def benchmark():
38
+ url = "https://raw.githubusercontent.com/autogluon/fev/refs/heads/main/benchmarks/chronos_zeroshot/tasks.yaml"
39
+ return fev.Benchmark.from_yaml(url)
40
+
41
+
42
+ def test_chronos_single(tirex_model, benchmark):
43
+ task_name = "monash_australian_electricity"
44
+ task = [task for task in benchmark.tasks if task.dataset_config == task_name][0]
45
+ predictions, inference_time = eval_task(tirex_model, task)
46
+ evaluation_summary = task.evaluation_summary(
47
+ predictions,
48
+ model_name="TiRex",
49
+ inference_time_s=inference_time,
50
+ )
51
+
52
+ assert evaluation_summary["WQL"] < 0.055, "WQL on the electricity task needs to be less than 0.055"
53
+ assert evaluation_summary["MASE"] < 0.99, "MASE on the electricity task needs to be less than 0.99"
54
+
55
+
56
+ @pytest.mark.skipif(os.getenv("CI"), reason="Skip full chromos benchmarking in the CI")
57
+ def test_chronos_all(tirex_model, benchmark):
58
+ tasks_wql = []
59
+ tasks_mase = []
60
+ for task in benchmark.tasks:
61
+ predictions, inference_time = eval_task(tirex_model, task)
62
+ evaluation_summary = task.evaluation_summary(
63
+ predictions,
64
+ model_name="TiRex",
65
+ inference_time_s=inference_time,
66
+ )
67
+ tasks_wql.append(evaluation_summary["WQL"])
68
+ tasks_mase.append(evaluation_summary["MASE"])
69
+
70
+ # Calculated from the geometric mean of the WQL and MASE data of the seasonal_naive model
71
+ # https://github.com/autogluon/fev/blob/main/benchmarks/chronos_zeroshot/results/seasonal_naive.csv
72
+ agg_wql_baseline = 0.1460642781226389
73
+ agg_mase_baseline = 1.6708210897174531
74
+
75
+ agg_wql = geometric_mean(tasks_wql)
76
+ agg_mase = geometric_mean(tasks_mase)
77
+
78
+ print(f"WQL: {agg_wql / agg_wql_baseline:.3f}")
79
+ print(f"MASE: {agg_mase / agg_mase_baseline:.3f}")
80
+
81
+ tolerance = 0.01
82
+
83
+ # Values from Tirex paper: https://arxiv.org/pdf/2505.23719
84
+ assert agg_wql / agg_wql_baseline < 0.59 + tolerance, "WQL on chromos needs to be less than 0.60"
85
+ assert agg_mase / agg_mase_baseline < 0.78 + tolerance, "MASE on chromos needs to be less than 0.79"
@@ -0,0 +1,35 @@
1
+ from tirex import ForecastModel, load_model
2
+ import pytest
3
+ from pathlib import Path
4
+ import numpy as np
5
+ import torch
6
+
7
+
8
+ def load_tensor_from_file(path):
9
+ base_path = Path(__file__).parent.resolve() / "data"
10
+ return torch.from_numpy(np.genfromtxt(base_path / path, dtype=np.float32))
11
+
12
+
13
+ @pytest.fixture
14
+ def tirex_model() -> ForecastModel:
15
+ return load_model("NX-AI/TiRex")
16
+
17
+
18
+ def test_forecast_air_traffic(tirex_model):
19
+ context = load_tensor_from_file("air_passengers.csv")[:-12]
20
+
21
+ quantiles, mean = tirex_model.forecast(context, prediction_length=24)
22
+
23
+ ref_data = load_tensor_from_file("air_passengers_forecast_ref.csv")
24
+
25
+ assert torch.allclose(mean, ref_data), "Forecasted tensor has to match reference data."
26
+
27
+
28
+ def test_forecast_seattle_5T(tirex_model):
29
+ context = load_tensor_from_file("loop_seattle_5T.csv")[:-512]
30
+
31
+ quantiles, mean = tirex_model.forecast(context, prediction_length=768)
32
+
33
+ ref_data = load_tensor_from_file("loop_seattle_5T_forecast_ref.csv")
34
+
35
+ assert torch.allclose(mean, ref_data), "Forecasted tensor has to match reference data."