tsagentkit 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tsagentkit/__init__.py +126 -0
- tsagentkit/anomaly/__init__.py +130 -0
- tsagentkit/backtest/__init__.py +48 -0
- tsagentkit/backtest/engine.py +788 -0
- tsagentkit/backtest/metrics.py +244 -0
- tsagentkit/backtest/report.py +342 -0
- tsagentkit/calibration/__init__.py +136 -0
- tsagentkit/contracts/__init__.py +133 -0
- tsagentkit/contracts/errors.py +275 -0
- tsagentkit/contracts/results.py +418 -0
- tsagentkit/contracts/schema.py +44 -0
- tsagentkit/contracts/task_spec.py +300 -0
- tsagentkit/covariates/__init__.py +340 -0
- tsagentkit/eval/__init__.py +285 -0
- tsagentkit/features/__init__.py +20 -0
- tsagentkit/features/covariates.py +328 -0
- tsagentkit/features/extra/__init__.py +5 -0
- tsagentkit/features/extra/native.py +179 -0
- tsagentkit/features/factory.py +187 -0
- tsagentkit/features/matrix.py +159 -0
- tsagentkit/features/tsfeatures_adapter.py +115 -0
- tsagentkit/features/versioning.py +203 -0
- tsagentkit/hierarchy/__init__.py +39 -0
- tsagentkit/hierarchy/aggregation.py +62 -0
- tsagentkit/hierarchy/evaluator.py +400 -0
- tsagentkit/hierarchy/reconciliation.py +232 -0
- tsagentkit/hierarchy/structure.py +453 -0
- tsagentkit/models/__init__.py +182 -0
- tsagentkit/models/adapters/__init__.py +83 -0
- tsagentkit/models/adapters/base.py +321 -0
- tsagentkit/models/adapters/chronos.py +387 -0
- tsagentkit/models/adapters/moirai.py +256 -0
- tsagentkit/models/adapters/registry.py +171 -0
- tsagentkit/models/adapters/timesfm.py +440 -0
- tsagentkit/models/baselines.py +207 -0
- tsagentkit/models/sktime.py +307 -0
- tsagentkit/monitoring/__init__.py +51 -0
- tsagentkit/monitoring/alerts.py +302 -0
- tsagentkit/monitoring/coverage.py +203 -0
- tsagentkit/monitoring/drift.py +330 -0
- tsagentkit/monitoring/report.py +214 -0
- tsagentkit/monitoring/stability.py +275 -0
- tsagentkit/monitoring/triggers.py +423 -0
- tsagentkit/qa/__init__.py +347 -0
- tsagentkit/router/__init__.py +37 -0
- tsagentkit/router/bucketing.py +489 -0
- tsagentkit/router/fallback.py +132 -0
- tsagentkit/router/plan.py +23 -0
- tsagentkit/router/router.py +271 -0
- tsagentkit/series/__init__.py +26 -0
- tsagentkit/series/alignment.py +206 -0
- tsagentkit/series/dataset.py +449 -0
- tsagentkit/series/sparsity.py +261 -0
- tsagentkit/series/validation.py +393 -0
- tsagentkit/serving/__init__.py +39 -0
- tsagentkit/serving/orchestration.py +943 -0
- tsagentkit/serving/packaging.py +73 -0
- tsagentkit/serving/provenance.py +317 -0
- tsagentkit/serving/tsfm_cache.py +214 -0
- tsagentkit/skill/README.md +135 -0
- tsagentkit/skill/__init__.py +8 -0
- tsagentkit/skill/recipes.md +429 -0
- tsagentkit/skill/tool_map.md +21 -0
- tsagentkit/time/__init__.py +134 -0
- tsagentkit/utils/__init__.py +20 -0
- tsagentkit/utils/quantiles.py +83 -0
- tsagentkit/utils/signature.py +47 -0
- tsagentkit/utils/temporal.py +41 -0
- tsagentkit-1.0.2.dist-info/METADATA +371 -0
- tsagentkit-1.0.2.dist-info/RECORD +72 -0
- tsagentkit-1.0.2.dist-info/WHEEL +4 -0
- tsagentkit-1.0.2.dist-info/licenses/LICENSE +201 -0
tsagentkit/features/versioning.py
@@ -0,0 +1,203 @@
"""Feature versioning and hashing for provenance tracking."""

from __future__ import annotations

import hashlib
import json
from dataclasses import dataclass, field
from typing import Any, Literal


@dataclass(frozen=True)
class FeatureConfig:
    """Configuration for feature engineering.

    Attributes:
        engine: Feature backend selection ("auto", "native", "tsfeatures")
        lags: List of lag periods to create (e.g., [1, 7, 14])
        calendar_features: List of calendar features to create
            (e.g., ["dayofweek", "month", "quarter", "year"])
        rolling_windows: Dict mapping window sizes to aggregation functions
            (e.g., {7: ["mean", "std"], 30: ["mean"]})
        known_covariates: List of column names for known covariates
        observed_covariates: List of column names for observed covariates
        include_intercept: Whether to include an intercept column (all 1s)
        tsfeatures_features: Optional list of tsfeatures function names
        tsfeatures_freq: Optional season length to pass to tsfeatures
        tsfeatures_dict_freqs: Optional dict mapping pandas freq -> season length
        allow_fallback: Allow fallback to native when tsfeatures is unavailable

    Example:
        >>> config = FeatureConfig(
        ...     lags=[1, 7, 14],
        ...     calendar_features=["dayofweek", "month"],
        ...     rolling_windows={7: ["mean", "std"], 30: ["mean"]},
        ...     known_covariates=["holiday"],
        ...     observed_covariates=["promotion"],
        ... )
        >>> print(compute_feature_hash(config))
        abc123def456...
    """

    engine: Literal["auto", "native", "tsfeatures"] = "auto"
    lags: list[int] = field(default_factory=list)
    calendar_features: list[str] = field(default_factory=list)
    rolling_windows: dict[int, list[str]] = field(default_factory=dict)
    known_covariates: list[str] = field(default_factory=list)
    observed_covariates: list[str] = field(default_factory=list)
    include_intercept: bool = False
    tsfeatures_features: list[str] = field(default_factory=list)
    tsfeatures_freq: int | None = None
    tsfeatures_dict_freqs: dict[str, int] = field(default_factory=dict)
    allow_fallback: bool = False

    def __post_init__(self) -> None:
        """Validate configuration after creation."""
        if self.engine not in {"auto", "native", "tsfeatures"}:
            raise ValueError(f"Invalid feature engine: {self.engine}")

        # Validate lags are positive integers
        for lag in self.lags:
            if not isinstance(lag, int) or lag < 1:
                raise ValueError(f"Lags must be positive integers, got {lag}")

        # Validate calendar features
        valid_calendar = {
            "dayofweek", "month", "quarter", "year", "dayofmonth",
            "dayofyear", "weekofyear", "hour", "minute", "is_month_start",
            "is_month_end", "is_quarter_start", "is_quarter_end",
        }
        invalid = set(self.calendar_features) - valid_calendar
        if invalid:
            raise ValueError(f"Invalid calendar features: {invalid}. Valid: {valid_calendar}")

        # Validate rolling aggregations
        valid_aggs = {"mean", "std", "min", "max", "sum", "median"}
        for window, aggs in self.rolling_windows.items():
            if not isinstance(window, int) or window < 1:
                raise ValueError(f"Window sizes must be positive integers, got {window}")
            invalid_aggs = set(aggs) - valid_aggs
            if invalid_aggs:
                raise ValueError(f"Invalid aggregations: {invalid_aggs}. Valid: {valid_aggs}")

        # Check for overlap in covariates
        overlap = set(self.known_covariates) & set(self.observed_covariates)
        if overlap:
            raise ValueError(f"Covariates cannot be both known and observed: {overlap}")

        if self.tsfeatures_freq is not None and self.tsfeatures_freq < 1:
            raise ValueError("tsfeatures_freq must be a positive integer when provided")

        for key, value in self.tsfeatures_dict_freqs.items():
            if not isinstance(value, int) or value < 1:
                raise ValueError(
                    f"tsfeatures_dict_freqs must map to positive integers, got {key}={value}"
                )


def compute_feature_hash(config: FeatureConfig) -> str:
    """Compute deterministic hash of feature configuration.

    The hash includes all feature configuration parameters to ensure
    that any change to the feature engineering setup results in a
    different hash for provenance tracking.

    Args:
        config: Feature configuration to hash

    Returns:
        SHA-256 hash string (truncated to 16 characters)

    Example:
        >>> config = FeatureConfig(lags=[1, 7], calendar_features=["dayofweek"])
        >>> h = compute_feature_hash(config)
        >>> len(h)
        16
    """
    # Build normalized configuration dict
    config_dict = {
        "engine": config.engine,
        "lags": sorted(config.lags) if config.lags else [],
        "calendar": sorted(config.calendar_features),
        "rolling": [
            {"window": w, "aggs": sorted(a)}
            for w, a in sorted(config.rolling_windows.items())
        ],
        "known_covariates": sorted(config.known_covariates),
        "observed_covariates": sorted(config.observed_covariates),
        "include_intercept": config.include_intercept,
        "tsfeatures_features": sorted(config.tsfeatures_features),
        "tsfeatures_freq": config.tsfeatures_freq,
        "tsfeatures_dict_freqs": {
            k: config.tsfeatures_dict_freqs[k]
            for k in sorted(config.tsfeatures_dict_freqs)
        },
        "allow_fallback": config.allow_fallback,
    }

    # Create deterministic JSON representation
    json_str = json.dumps(config_dict, sort_keys=True, separators=(",", ":"))

    # Compute hash
    return hashlib.sha256(json_str.encode()).hexdigest()[:16]


def configs_equal(config1: FeatureConfig, config2: FeatureConfig) -> bool:
    """Check if two feature configurations are equivalent.

    Args:
        config1: First configuration
        config2: Second configuration

    Returns:
        True if configurations produce identical features
    """
    return compute_feature_hash(config1) == compute_feature_hash(config2)


def config_to_dict(config: FeatureConfig) -> dict[str, Any]:
    """Convert feature config to dictionary for serialization.

    Args:
        config: Feature configuration

    Returns:
        Dictionary representation
    """
    return {
        "engine": config.engine,
        "lags": config.lags,
        "calendar_features": config.calendar_features,
        "rolling_windows": config.rolling_windows,
        "known_covariates": config.known_covariates,
        "observed_covariates": config.observed_covariates,
        "include_intercept": config.include_intercept,
        "tsfeatures_features": config.tsfeatures_features,
        "tsfeatures_freq": config.tsfeatures_freq,
        "tsfeatures_dict_freqs": config.tsfeatures_dict_freqs,
        "allow_fallback": config.allow_fallback,
    }


def config_from_dict(data: dict[str, Any]) -> FeatureConfig:
    """Create feature config from dictionary.

    Args:
        data: Dictionary with configuration values

    Returns:
        FeatureConfig instance
    """
    return FeatureConfig(
        engine=data.get("engine", "auto"),
        lags=data.get("lags", []),
        calendar_features=data.get("calendar_features", []),
        rolling_windows=data.get("rolling_windows", {}),
        known_covariates=data.get("known_covariates", []),
        observed_covariates=data.get("observed_covariates", []),
        include_intercept=data.get("include_intercept", False),
        tsfeatures_features=data.get("tsfeatures_features", []),
        tsfeatures_freq=data.get("tsfeatures_freq"),
        tsfeatures_dict_freqs=data.get("tsfeatures_dict_freqs", {}),
        allow_fallback=data.get("allow_fallback", True),
    )
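A minimal usage sketch of the module above, assuming only the names it defines and the import path shown in the file listing (tsagentkit.features.versioning); the lag and calendar values are illustrative:

# Sketch: hashing and round-tripping a feature configuration.
from tsagentkit.features.versioning import (
    FeatureConfig,
    compute_feature_hash,
    config_from_dict,
    config_to_dict,
    configs_equal,
)

config = FeatureConfig(lags=[1, 7, 14], calendar_features=["dayofweek", "month"])

# 16-character SHA-256 prefix; identical configurations always hash identically.
print(compute_feature_hash(config))

# Serialization round-trip preserves the hash, so provenance records stay stable.
restored = config_from_dict(config_to_dict(config))
assert configs_equal(config, restored)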
tsagentkit/hierarchy/__init__.py
@@ -0,0 +1,39 @@
"""Hierarchical time series forecasting and reconciliation.

This module provides tools for working with hierarchical time series data,
including structure definition, aggregation matrix operations, and multiple
reconciliation methods (bottom-up, top-down, middle-out, OLS, MinT).

Example:
    >>> from tsagentkit.hierarchy import HierarchyStructure, Reconciler
    >>> from tsagentkit.hierarchy import ReconciliationMethod
    >>>
    >>> # Define hierarchy
    >>> structure = HierarchyStructure.from_dataframe(
    ...     df=sales_data,
    ...     hierarchy_columns=["region", "state", "store"]
    ... )
    >>>
    >>> # Reconcile forecasts
    >>> reconciler = Reconciler(ReconciliationMethod.MIN_TRACE, structure)
    >>> reconciled = reconciler.reconcile(base_forecasts, residuals=residuals)
"""

from __future__ import annotations

from .evaluator import CoherenceViolation, HierarchyEvaluationReport, HierarchyEvaluator
from .reconciliation import Reconciler, ReconciliationMethod, reconcile_forecasts
from .structure import HierarchyStructure

__all__ = [
    # Structure
    "HierarchyStructure",
    # Reconciliation
    "Reconciler",
    "ReconciliationMethod",
    "reconcile_forecasts",
    # Evaluation
    "HierarchyEvaluator",
    "HierarchyEvaluationReport",
    "CoherenceViolation",
]
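The coherence property that reconciliation enforces (every parent series equals the sum of its children) can be sketched with plain pandas, independent of the Reconciler API; the series names and values below are hypothetical:

# Concept sketch: a bottom-up coherent forecast, children summing to the parent.
import pandas as pd

bottom = pd.DataFrame(
    {
        "unique_id": ["total/storeA", "total/storeA", "total/storeB", "total/storeB"],
        "ds": ["2024-01-01", "2024-01-02", "2024-01-01", "2024-01-02"],
        "yhat": [10.0, 12.0, 5.0, 7.0],
    }
)

# Aggregate the bottom level to obtain the coherent "total" series.
total = (
    bottom.assign(unique_id="total")
    .groupby(["unique_id", "ds"], as_index=False)["yhat"]
    .sum()
)
coherent = pd.concat([total, bottom], ignore_index=True)
print(coherent)  # parent rows equal the sum of their children at every ds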
tsagentkit/hierarchy/aggregation.py
@@ -0,0 +1,62 @@
"""Deprecated aggregation helpers.

This module previously contained projection-matrix implementations for
hierarchical reconciliation. Those algorithms are now delegated to
`hierarchicalforecast` to keep this package as a thin adapter.
"""

from __future__ import annotations

import warnings

import numpy as np

from .structure import HierarchyStructure

_DEPRECATION_MESSAGE = (
    "tsagentkit.hierarchy.aggregation is deprecated. "
    "Use hierarchicalforecast methods directly instead."
)


def _deprecated(name: str) -> None:
    warnings.warn(
        f"{name} is deprecated. {_DEPRECATION_MESSAGE}",
        DeprecationWarning,
        stacklevel=2,
    )


def create_bottom_up_matrix(structure: HierarchyStructure) -> np.ndarray:  # pragma: no cover
    _deprecated("create_bottom_up_matrix")
    raise NotImplementedError(_DEPRECATION_MESSAGE)


def create_top_down_matrix(  # pragma: no cover
    structure: HierarchyStructure,
    proportions: dict[str, float] | None = None,
    historical_data: np.ndarray | None = None,
) -> np.ndarray:
    _deprecated("create_top_down_matrix")
    raise NotImplementedError(_DEPRECATION_MESSAGE)


def create_middle_out_matrix(  # pragma: no cover
    structure: HierarchyStructure,
    middle_level: int,
) -> np.ndarray:
    _deprecated("create_middle_out_matrix")
    raise NotImplementedError(_DEPRECATION_MESSAGE)


def create_ols_matrix(structure: HierarchyStructure) -> np.ndarray:  # pragma: no cover
    _deprecated("create_ols_matrix")
    raise NotImplementedError(_DEPRECATION_MESSAGE)


def create_wls_matrix(  # pragma: no cover
    structure: HierarchyStructure,
    weights: np.ndarray,
) -> np.ndarray:
    _deprecated("create_wls_matrix")
    raise NotImplementedError(_DEPRECATION_MESSAGE)
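A short sketch of the deprecation behavior above, assuming the import path from the file listing (tsagentkit.hierarchy.aggregation); None is passed only because the stubs never read their structure argument:

# Sketch: the deprecated stubs warn, then raise, without touching their arguments.
import warnings

from tsagentkit.hierarchy.aggregation import create_bottom_up_matrix

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    try:
        create_bottom_up_matrix(None)  # structure is ignored by the stub
    except NotImplementedError as exc:
        print(exc)  # points callers at hierarchicalforecast instead

assert any(issubclass(w.category, DeprecationWarning) for w in caught)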
tsagentkit/hierarchy/evaluator.py
@@ -0,0 +1,400 @@
"""Hierarchy-aware evaluation metrics and coherence checking.

Provides tools to evaluate hierarchical forecast quality and detect
coherence violations.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd

if TYPE_CHECKING:
    from .structure import HierarchyStructure


@dataclass(frozen=True)
class CoherenceViolation:
    """Single coherence violation record.

    Records when forecasts violate the hierarchical aggregation constraint
    (i.e., children don't sum to parent).

    Attributes:
        parent_node: Name of the parent node
        child_nodes: List of child node names
        expected_value: Expected value (sum of children)
        actual_value: Actual value (parent forecast)
        difference: Absolute difference between expected and actual
        timestamp: Timestamp of the violation
    """

    parent_node: str
    child_nodes: list[str]
    expected_value: float
    actual_value: float
    difference: float
    timestamp: str


@dataclass(frozen=True)
class HierarchyEvaluationReport:
    """Evaluation report for hierarchical forecasts.

    Contains metrics and diagnostics for hierarchical forecast quality.

    Attributes:
        level_metrics: Metrics aggregated by hierarchy level
        coherence_violations: List of coherence violations found
        coherence_score: Overall coherence score (0-1, higher is better)
        reconciliation_improvement: Improvement vs base forecasts
        total_violations: Total number of violations
        violation_rate: Proportion of forecasts with violations
    """

    level_metrics: dict[int, dict[str, float]] = field(default_factory=dict)
    coherence_violations: list[CoherenceViolation] = field(default_factory=list)
    coherence_score: float = 0.0
    reconciliation_improvement: dict[str, float] = field(default_factory=dict)
    total_violations: int = 0
    violation_rate: float = 0.0

    def to_dict(self) -> dict:
        """Convert report to dictionary."""
        return {
            "level_metrics": self.level_metrics,
            "coherence_violations": [
                {
                    "parent_node": v.parent_node,
                    "child_nodes": v.child_nodes,
                    "expected_value": v.expected_value,
                    "actual_value": v.actual_value,
                    "difference": v.difference,
                    "timestamp": v.timestamp,
                }
                for v in self.coherence_violations
            ],
            "coherence_score": self.coherence_score,
            "reconciliation_improvement": self.reconciliation_improvement,
            "total_violations": self.total_violations,
            "violation_rate": self.violation_rate,
        }


class HierarchyEvaluator:
    """Evaluate hierarchical forecast quality and coherence.

    Provides methods to compute hierarchical metrics and detect
    coherence violations.

    Example:
        >>> evaluator = HierarchyEvaluator(hierarchy_structure)
        >>> report = evaluator.evaluate(forecasts, actuals)
        >>> print(f"Coherence score: {report.coherence_score:.3f}")
        Coherence score: 0.998
    """

    def __init__(self, structure: HierarchyStructure):
        """Initialize evaluator.

        Args:
            structure: Hierarchy structure
        """
        self.structure = structure

    def evaluate(
        self,
        forecasts: pd.DataFrame,
        actuals: pd.DataFrame | None = None,
        tolerance: float = 1e-6,
    ) -> HierarchyEvaluationReport:
        """Evaluate hierarchical forecasts.

        Computes standard forecast metrics per level and checks coherence.

        Args:
            forecasts: Forecast DataFrame with columns [unique_id, ds, yhat]
            actuals: Optional actual values for accuracy metrics
            tolerance: Tolerance for coherence violations

        Returns:
            HierarchyEvaluationReport with metrics and violations
        """
        # Compute per-level metrics if actuals provided
        level_metrics = {}
        if actuals is not None:
            level_metrics = self._compute_level_metrics(forecasts, actuals)

        # Detect coherence violations
        violations = self._detect_violations(forecasts, tolerance)

        # Compute coherence score
        coherence_score = self._compute_coherence_score(forecasts, tolerance)

        # Compute violation rate
        total_checks = self._count_total_checks(forecasts)
        violation_rate = len(violations) / max(total_checks, 1)

        return HierarchyEvaluationReport(
            level_metrics=level_metrics,
            coherence_violations=violations,
            coherence_score=coherence_score,
            total_violations=len(violations),
            violation_rate=violation_rate,
        )

    def _compute_level_metrics(
        self,
        forecasts: pd.DataFrame,
        actuals: pd.DataFrame,
    ) -> dict[int, dict[str, float]]:
        """Compute metrics aggregated by hierarchy level.

        Args:
            forecasts: Forecast DataFrame
            actuals: Actual values DataFrame

        Returns:
            Dictionary mapping level to metrics dict
        """
        metrics_by_level: dict[int, dict[str, list[float]]] = {}

        # Merge forecasts with actuals
        merged = forecasts.merge(
            actuals,
            on=["unique_id", "ds"],
            suffixes=("_forecast", "_actual"),
        )

        # Compute metrics per series
        for _, row in merged.iterrows():
            uid = row["unique_id"]
            if uid not in self.structure.all_nodes:
                continue

            level = self.structure.get_level(uid)
            if level not in metrics_by_level:
                metrics_by_level[level] = {"mae": [], "mape": [], "rmse": []}

            actual = row.get("y_actual", row.get("y", 0))
            forecast = row["yhat"]
            error = forecast - actual

            metrics_by_level[level]["mae"].append(abs(error))
            metrics_by_level[level]["rmse"].append(error ** 2)
            if actual != 0:
                metrics_by_level[level]["mape"].append(abs(error / actual) * 100)

        # Aggregate
        result = {}
        for level, metrics in metrics_by_level.items():
            result[level] = {
                "mae": np.mean(metrics["mae"]) if metrics["mae"] else 0,
                "rmse": np.sqrt(np.mean(metrics["rmse"])) if metrics["rmse"] else 0,
                "mape": np.mean(metrics["mape"]) if metrics["mape"] else 0,
                "count": len(metrics["mae"]),
            }

        return result

    def _detect_violations(
        self,
        forecasts: pd.DataFrame,
        tolerance: float = 1e-6,
    ) -> list[CoherenceViolation]:
        """Detect where forecasts violate hierarchical coherence.

        A coherence violation occurs when the sum of child forecasts
        doesn't equal the parent forecast (within tolerance).

        Args:
            forecasts: Forecast DataFrame
            tolerance: Numerical tolerance for violations

        Returns:
            List of coherence violations
        """
        violations = []

        # Pivot to wide format for easier computation
        pivot = forecasts.pivot(index="unique_id", columns="ds", values="yhat")

        for parent, children in self.structure.aggregation_graph.items():
            if parent not in pivot.index:
                continue

            parent_forecast = pivot.loc[parent]

            # Sum children forecasts
            available_children = [c for c in children if c in pivot.index]
            if not available_children:
                continue

            children_sum = pivot.loc[available_children].sum()

            # Check for violations at each time point
            for ds in parent_forecast.index:
                parent_val = parent_forecast[ds]
                children_val = children_sum[ds]
                diff = abs(parent_val - children_val)

                if diff > tolerance:
                    violations.append(
                        CoherenceViolation(
                            parent_node=parent,
                            child_nodes=available_children,
                            expected_value=float(children_val),
                            actual_value=float(parent_val),
                            difference=float(diff),
                            timestamp=str(ds),
                        )
                    )

        return violations

    def _compute_coherence_score(
        self,
        forecasts: pd.DataFrame,
        tolerance: float = 1e-6,
    ) -> float:
        """Compute overall coherence score.

        Score is 1.0 if perfectly coherent, decreases with violations.

        Args:
            forecasts: Forecast DataFrame
            tolerance: Numerical tolerance

        Returns:
            Coherence score between 0 and 1
        """
        pivot = forecasts.pivot(index="unique_id", columns="ds", values="yhat")

        total_abs_sum = 0.0
        total_violation = 0.0

        for parent, children in self.structure.aggregation_graph.items():
            if parent not in pivot.index:
                continue

            parent_forecast = pivot.loc[parent]
            available_children = [c for c in children if c in pivot.index]

            if not available_children:
                continue

            children_sum = pivot.loc[available_children].sum()

            for ds in parent_forecast.index:
                parent_val = parent_forecast[ds]
                children_val = children_sum[ds]

                total_abs_sum += abs(parent_val)
                violation = abs(parent_val - children_val)

                if violation > tolerance:
                    total_violation += violation

        if total_abs_sum == 0:
            return 1.0

        # Score decreases as violation magnitude increases
        return max(0.0, 1.0 - (total_violation / total_abs_sum))

    def _count_total_checks(self, forecasts: pd.DataFrame) -> int:
        """Count total number of coherence checks performed.

        Args:
            forecasts: Forecast DataFrame

        Returns:
            Total number of parent-timestamp pairs checked
        """
        pivot = forecasts.pivot(index="unique_id", columns="ds", values="yhat")

        count = 0
        for parent, children in self.structure.aggregation_graph.items():
            if parent not in pivot.index:
                continue

            available_children = [c for c in children if c in pivot.index]
            if available_children:
                count += len(pivot.columns)

        return count

    def compute_improvement(
        self,
        base_forecasts: pd.DataFrame,
        reconciled_forecasts: pd.DataFrame,
        actuals: pd.DataFrame,
    ) -> dict[str, float]:
        """Compute improvement of reconciled vs base forecasts.

        Args:
            base_forecasts: Base (unreconciled) forecasts
            reconciled_forecasts: Reconciled forecasts
            actuals: Actual values

        Returns:
            Dictionary with improvement metrics
        """
        base_metrics = self._compute_overall_metrics(base_forecasts, actuals)
        reconciled_metrics = self._compute_overall_metrics(
            reconciled_forecasts, actuals
        )

        improvement = {}
        for metric in ["mae", "rmse", "mape"]:
            if metric in base_metrics and base_metrics[metric] > 0:
                improvement[metric] = (
                    (base_metrics[metric] - reconciled_metrics[metric])
                    / base_metrics[metric]
                ) * 100
            else:
                improvement[metric] = 0.0

        return improvement

    def _compute_overall_metrics(
        self,
        forecasts: pd.DataFrame,
        actuals: pd.DataFrame,
    ) -> dict[str, float]:
        """Compute overall metrics for forecasts.

        Args:
            forecasts: Forecast DataFrame
            actuals: Actual values DataFrame

        Returns:
            Dictionary of metrics
        """
        merged = forecasts.merge(
            actuals,
            on=["unique_id", "ds"],
            suffixes=("_forecast", "_actual"),
        )

        if len(merged) == 0:
            return {"mae": 0, "rmse": 0, "mape": 0}

        actual_col = "y_actual" if "y_actual" in merged.columns else "y"
        errors = merged["yhat"] - merged[actual_col]

        mae = np.mean(np.abs(errors))
        rmse = np.sqrt(np.mean(errors ** 2))

        # Compute MAPE only for non-zero actuals
        non_zero_mask = merged[actual_col] != 0
        if non_zero_mask.any():
            mape = np.mean(
                np.abs(errors[non_zero_mask] / merged.loc[non_zero_mask, actual_col])
            ) * 100
        else:
            mape = 0.0

        return {"mae": mae, "rmse": rmse, "mape": mape}
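A minimal sketch of the coherence check above, assuming the import path from the file listing. HierarchyStructure itself lives in structure.py (not part of this hunk), so a hypothetical stand-in exposing only the attributes the evaluator reads (aggregation_graph, all_nodes, get_level) is used here; any object with those members would do:

# Sketch: detecting an incoherent parent forecast with HierarchyEvaluator.
from types import SimpleNamespace

import pandas as pd

from tsagentkit.hierarchy.evaluator import HierarchyEvaluator

# Stand-in for HierarchyStructure: one parent ("total") with two children.
structure = SimpleNamespace(
    aggregation_graph={"total": ["a", "b"]},
    all_nodes={"total", "a", "b"},
    get_level=lambda uid: 0 if uid == "total" else 1,
)

forecasts = pd.DataFrame(
    {
        "unique_id": ["total", "a", "b"],
        "ds": ["2024-01-01"] * 3,
        "yhat": [16.0, 10.0, 5.0],  # children sum to 15, parent says 16
    }
)

report = HierarchyEvaluator(structure).evaluate(forecasts)
print(report.total_violations, round(report.coherence_score, 3))  # 1 violation, score below 1.0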