tsagentkit 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72) hide show
  1. tsagentkit/__init__.py +126 -0
  2. tsagentkit/anomaly/__init__.py +130 -0
  3. tsagentkit/backtest/__init__.py +48 -0
  4. tsagentkit/backtest/engine.py +788 -0
  5. tsagentkit/backtest/metrics.py +244 -0
  6. tsagentkit/backtest/report.py +342 -0
  7. tsagentkit/calibration/__init__.py +136 -0
  8. tsagentkit/contracts/__init__.py +133 -0
  9. tsagentkit/contracts/errors.py +275 -0
  10. tsagentkit/contracts/results.py +418 -0
  11. tsagentkit/contracts/schema.py +44 -0
  12. tsagentkit/contracts/task_spec.py +300 -0
  13. tsagentkit/covariates/__init__.py +340 -0
  14. tsagentkit/eval/__init__.py +285 -0
  15. tsagentkit/features/__init__.py +20 -0
  16. tsagentkit/features/covariates.py +328 -0
  17. tsagentkit/features/extra/__init__.py +5 -0
  18. tsagentkit/features/extra/native.py +179 -0
  19. tsagentkit/features/factory.py +187 -0
  20. tsagentkit/features/matrix.py +159 -0
  21. tsagentkit/features/tsfeatures_adapter.py +115 -0
  22. tsagentkit/features/versioning.py +203 -0
  23. tsagentkit/hierarchy/__init__.py +39 -0
  24. tsagentkit/hierarchy/aggregation.py +62 -0
  25. tsagentkit/hierarchy/evaluator.py +400 -0
  26. tsagentkit/hierarchy/reconciliation.py +232 -0
  27. tsagentkit/hierarchy/structure.py +453 -0
  28. tsagentkit/models/__init__.py +182 -0
  29. tsagentkit/models/adapters/__init__.py +83 -0
  30. tsagentkit/models/adapters/base.py +321 -0
  31. tsagentkit/models/adapters/chronos.py +387 -0
  32. tsagentkit/models/adapters/moirai.py +256 -0
  33. tsagentkit/models/adapters/registry.py +171 -0
  34. tsagentkit/models/adapters/timesfm.py +440 -0
  35. tsagentkit/models/baselines.py +207 -0
  36. tsagentkit/models/sktime.py +307 -0
  37. tsagentkit/monitoring/__init__.py +51 -0
  38. tsagentkit/monitoring/alerts.py +302 -0
  39. tsagentkit/monitoring/coverage.py +203 -0
  40. tsagentkit/monitoring/drift.py +330 -0
  41. tsagentkit/monitoring/report.py +214 -0
  42. tsagentkit/monitoring/stability.py +275 -0
  43. tsagentkit/monitoring/triggers.py +423 -0
  44. tsagentkit/qa/__init__.py +347 -0
  45. tsagentkit/router/__init__.py +37 -0
  46. tsagentkit/router/bucketing.py +489 -0
  47. tsagentkit/router/fallback.py +132 -0
  48. tsagentkit/router/plan.py +23 -0
  49. tsagentkit/router/router.py +271 -0
  50. tsagentkit/series/__init__.py +26 -0
  51. tsagentkit/series/alignment.py +206 -0
  52. tsagentkit/series/dataset.py +449 -0
  53. tsagentkit/series/sparsity.py +261 -0
  54. tsagentkit/series/validation.py +393 -0
  55. tsagentkit/serving/__init__.py +39 -0
  56. tsagentkit/serving/orchestration.py +943 -0
  57. tsagentkit/serving/packaging.py +73 -0
  58. tsagentkit/serving/provenance.py +317 -0
  59. tsagentkit/serving/tsfm_cache.py +214 -0
  60. tsagentkit/skill/README.md +135 -0
  61. tsagentkit/skill/__init__.py +8 -0
  62. tsagentkit/skill/recipes.md +429 -0
  63. tsagentkit/skill/tool_map.md +21 -0
  64. tsagentkit/time/__init__.py +134 -0
  65. tsagentkit/utils/__init__.py +20 -0
  66. tsagentkit/utils/quantiles.py +83 -0
  67. tsagentkit/utils/signature.py +47 -0
  68. tsagentkit/utils/temporal.py +41 -0
  69. tsagentkit-1.0.2.dist-info/METADATA +371 -0
  70. tsagentkit-1.0.2.dist-info/RECORD +72 -0
  71. tsagentkit-1.0.2.dist-info/WHEEL +4 -0
  72. tsagentkit-1.0.2.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,203 @@
1
+ """Feature versioning and hashing for provenance tracking."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import hashlib
6
+ import json
7
+ from dataclasses import dataclass, field
8
+ from typing import Any, Literal
9
+
10
+
11
@dataclass(frozen=True)
class FeatureConfig:
    """Configuration for feature engineering.

    Attributes:
        engine: Feature backend selection ("auto", "native", "tsfeatures")
        lags: List of lag periods to create (e.g., [1, 7, 14])
        calendar_features: List of calendar features to create
            (e.g., ["dayofweek", "month", "quarter", "year"])
        rolling_windows: Dict mapping window sizes to aggregation functions
            (e.g., {7: ["mean", "std"], 30: ["mean"]})
        known_covariates: List of column names for known covariates
        observed_covariates: List of column names for observed covariates
        include_intercept: Whether to include an intercept column (all 1s)
        tsfeatures_features: Optional list of tsfeatures function names
        tsfeatures_freq: Optional season length to pass to tsfeatures
        tsfeatures_dict_freqs: Optional dict mapping pandas freq -> season length
        allow_fallback: Allow fallback to native when tsfeatures is unavailable

    Example:
        >>> config = FeatureConfig(
        ...     lags=[1, 7, 14],
        ...     calendar_features=["dayofweek", "month"],
        ...     rolling_windows={7: ["mean", "std"], 30: ["mean"]},
        ...     known_covariates=["holiday"],
        ...     observed_covariates=["promotion"],
        ... )
    """

    engine: Literal["auto", "native", "tsfeatures"] = "auto"
    lags: list[int] = field(default_factory=list)
    calendar_features: list[str] = field(default_factory=list)
    rolling_windows: dict[int, list[str]] = field(default_factory=dict)
    known_covariates: list[str] = field(default_factory=list)
    observed_covariates: list[str] = field(default_factory=list)
    include_intercept: bool = False
    tsfeatures_features: list[str] = field(default_factory=list)
    tsfeatures_freq: int | None = None
    tsfeatures_dict_freqs: dict[str, int] = field(default_factory=dict)
    allow_fallback: bool = False

    def __post_init__(self) -> None:
        """Validate configuration after creation.

        Raises:
            ValueError: If any field violates its documented contract.
        """
        if self.engine not in {"auto", "native", "tsfeatures"}:
            raise ValueError(f"Invalid feature engine: {self.engine}")

        # Validate lags are positive integers
        for lag in self.lags:
            if not isinstance(lag, int) or lag < 1:
                raise ValueError(f"Lags must be positive integers, got {lag}")

        # Validate calendar features against the supported set
        valid_calendar = {
            "dayofweek", "month", "quarter", "year", "dayofmonth",
            "dayofyear", "weekofyear", "hour", "minute", "is_month_start",
            "is_month_end", "is_quarter_start", "is_quarter_end",
        }
        invalid = set(self.calendar_features) - valid_calendar
        if invalid:
            raise ValueError(f"Invalid calendar features: {invalid}. Valid: {valid_calendar}")

        # Validate rolling window sizes and aggregation names
        valid_aggs = {"mean", "std", "min", "max", "sum", "median"}
        for window, aggs in self.rolling_windows.items():
            if not isinstance(window, int) or window < 1:
                raise ValueError(f"Window sizes must be positive integers, got {window}")
            invalid_aggs = set(aggs) - valid_aggs
            if invalid_aggs:
                raise ValueError(f"Invalid aggregations: {invalid_aggs}. Valid: {valid_aggs}")

        # A column cannot be both a known and an observed covariate
        overlap = set(self.known_covariates) & set(self.observed_covariates)
        if overlap:
            raise ValueError(f"Covariates cannot be both known and observed: {overlap}")

        # BUG FIX: previously only the lower bound was checked, so non-int
        # values such as 2.5 slipped through despite the "positive integer"
        # contract (and the stricter check applied to tsfeatures_dict_freqs).
        if self.tsfeatures_freq is not None and (
            not isinstance(self.tsfeatures_freq, int) or self.tsfeatures_freq < 1
        ):
            raise ValueError("tsfeatures_freq must be a positive integer when provided")

        for key, value in self.tsfeatures_dict_freqs.items():
            if not isinstance(value, int) or value < 1:
                raise ValueError(
                    f"tsfeatures_dict_freqs must map to positive integers, got {key}={value}"
                )
96
+
97
+
98
def compute_feature_hash(config: FeatureConfig) -> str:
    """Compute a deterministic hash of a feature configuration.

    Every parameter that influences feature construction is folded into
    the hash, so any change to the feature-engineering setup yields a
    different value for provenance tracking.

    Args:
        config: Feature configuration to hash

    Returns:
        SHA-256 hash string (truncated to 16 characters)

    Example:
        >>> config = FeatureConfig(lags=[1, 7], calendar_features=["dayofweek"])
        >>> h = compute_feature_hash(config)
        >>> len(h)
        16
    """
    # Normalize order-sensitive collections so logically-equal configs
    # serialize to the same JSON payload.
    normalized = {
        "engine": config.engine,
        "lags": sorted(config.lags) if config.lags else [],
        "calendar": sorted(config.calendar_features),
        "rolling": [
            {"window": window, "aggs": sorted(aggs)}
            for window, aggs in sorted(config.rolling_windows.items())
        ],
        "known_covariates": sorted(config.known_covariates),
        "observed_covariates": sorted(config.observed_covariates),
        "include_intercept": config.include_intercept,
        "tsfeatures_features": sorted(config.tsfeatures_features),
        "tsfeatures_freq": config.tsfeatures_freq,
        "tsfeatures_dict_freqs": {
            key: config.tsfeatures_dict_freqs[key]
            for key in sorted(config.tsfeatures_dict_freqs)
        },
        "allow_fallback": config.allow_fallback,
    }

    # Deterministic JSON: sorted keys, no whitespace.
    payload = json.dumps(normalized, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(payload.encode()).hexdigest()[:16]
143
+
144
+
145
def configs_equal(config1: FeatureConfig, config2: FeatureConfig) -> bool:
    """Check whether two feature configurations are equivalent.

    Equivalence means both configs hash to the same provenance value,
    i.e. they would produce identical features.

    Args:
        config1: First configuration
        config2: Second configuration

    Returns:
        True if configurations produce identical features
    """
    hash_one = compute_feature_hash(config1)
    hash_two = compute_feature_hash(config2)
    return hash_one == hash_two
156
+
157
+
158
def config_to_dict(config: FeatureConfig) -> dict[str, Any]:
    """Convert a feature config to a dictionary for serialization.

    Args:
        config: Feature configuration

    Returns:
        Dictionary representation, keyed by field name
    """
    # Every serialized key matches its attribute name, so a getattr
    # sweep over an explicit, ordered field list is sufficient.
    field_names = (
        "engine",
        "lags",
        "calendar_features",
        "rolling_windows",
        "known_covariates",
        "observed_covariates",
        "include_intercept",
        "tsfeatures_features",
        "tsfeatures_freq",
        "tsfeatures_dict_freqs",
        "allow_fallback",
    )
    return {name: getattr(config, name) for name in field_names}
180
+
181
+
182
def config_from_dict(data: dict[str, Any]) -> FeatureConfig:
    """Create a feature config from a dictionary.

    Missing keys fall back to the same defaults as ``FeatureConfig``
    itself, so ``config_from_dict({})`` equals ``FeatureConfig()``.

    Args:
        data: Dictionary with configuration values

    Returns:
        FeatureConfig instance
    """
    return FeatureConfig(
        engine=data.get("engine", "auto"),
        lags=data.get("lags", []),
        calendar_features=data.get("calendar_features", []),
        rolling_windows=data.get("rolling_windows", {}),
        known_covariates=data.get("known_covariates", []),
        observed_covariates=data.get("observed_covariates", []),
        include_intercept=data.get("include_intercept", False),
        tsfeatures_features=data.get("tsfeatures_features", []),
        tsfeatures_freq=data.get("tsfeatures_freq"),
        tsfeatures_dict_freqs=data.get("tsfeatures_dict_freqs", {}),
        # BUG FIX: the missing-key default here was True while the
        # FeatureConfig dataclass defaults allow_fallback to False, so
        # deserializing a dict without the key silently flipped the flag.
        allow_fallback=data.get("allow_fallback", False),
    )
@@ -0,0 +1,39 @@
1
+ """Hierarchical time series forecasting and reconciliation.
2
+
3
+ This module provides tools for working with hierarchical time series data,
4
+ including structure definition, aggregation matrix operations, and multiple
5
+ reconciliation methods (bottom-up, top-down, middle-out, OLS, MinT).
6
+
7
+ Example:
8
+ >>> from tsagentkit.hierarchy import HierarchyStructure, Reconciler
9
+ >>> from tsagentkit.hierarchy import ReconciliationMethod
10
+ >>>
11
+ >>> # Define hierarchy
12
+ >>> structure = HierarchyStructure.from_dataframe(
13
+ ... df=sales_data,
14
+ ... hierarchy_columns=["region", "state", "store"]
15
+ ... )
16
+ >>>
17
+ >>> # Reconcile forecasts
18
+ >>> reconciler = Reconciler(ReconciliationMethod.MIN_TRACE, structure)
19
+ >>> reconciled = reconciler.reconcile(base_forecasts, residuals=residuals)
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ from .evaluator import CoherenceViolation, HierarchyEvaluationReport, HierarchyEvaluator
25
+ from .reconciliation import Reconciler, ReconciliationMethod, reconcile_forecasts
26
+ from .structure import HierarchyStructure
27
+
28
+ __all__ = [
29
+ # Structure
30
+ "HierarchyStructure",
31
+ # Reconciliation
32
+ "Reconciler",
33
+ "ReconciliationMethod",
34
+ "reconcile_forecasts",
35
+ # Evaluation
36
+ "HierarchyEvaluator",
37
+ "HierarchyEvaluationReport",
38
+ "CoherenceViolation",
39
+ ]
@@ -0,0 +1,62 @@
1
+ """Deprecated aggregation helpers.
2
+
3
+ This module previously contained projection-matrix implementations for
4
+ hierarchical reconciliation. Those algorithms are now delegated to
5
+ `hierarchicalforecast` to keep this package as a thin adapter.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import warnings
11
+
12
+ import numpy as np
13
+
14
+ from .structure import HierarchyStructure
15
+
16
_DEPRECATION_MESSAGE = (
    "tsagentkit.hierarchy.aggregation is deprecated. "
    "Use hierarchicalforecast methods directly instead."
)


def _deprecated(name: str) -> None:
    """Emit a DeprecationWarning for the named helper, attributed to its caller."""
    message = f"{name} is deprecated. {_DEPRECATION_MESSAGE}"
    # stacklevel=2 points the warning at the deprecated function's caller.
    warnings.warn(message, DeprecationWarning, stacklevel=2)
28
+
29
+
30
def create_bottom_up_matrix(structure: HierarchyStructure) -> np.ndarray:  # pragma: no cover
    """Deprecated stub: bottom-up projection now lives in hierarchicalforecast."""
    _deprecated("create_bottom_up_matrix")
    raise NotImplementedError(_DEPRECATION_MESSAGE)
33
+
34
+
35
def create_top_down_matrix(  # pragma: no cover
    structure: HierarchyStructure,
    proportions: dict[str, float] | None = None,
    historical_data: np.ndarray | None = None,
) -> np.ndarray:
    """Deprecated stub: top-down projection now lives in hierarchicalforecast."""
    _deprecated("create_top_down_matrix")
    raise NotImplementedError(_DEPRECATION_MESSAGE)
42
+
43
+
44
def create_middle_out_matrix(  # pragma: no cover
    structure: HierarchyStructure,
    middle_level: int,
) -> np.ndarray:
    """Deprecated stub: middle-out projection now lives in hierarchicalforecast."""
    _deprecated("create_middle_out_matrix")
    raise NotImplementedError(_DEPRECATION_MESSAGE)
50
+
51
+
52
def create_ols_matrix(structure: HierarchyStructure) -> np.ndarray:  # pragma: no cover
    """Deprecated stub: OLS projection now lives in hierarchicalforecast."""
    _deprecated("create_ols_matrix")
    raise NotImplementedError(_DEPRECATION_MESSAGE)
55
+
56
+
57
def create_wls_matrix(  # pragma: no cover
    structure: HierarchyStructure,
    weights: np.ndarray,
) -> np.ndarray:
    """Deprecated stub: WLS projection now lives in hierarchicalforecast."""
    _deprecated("create_wls_matrix")
    raise NotImplementedError(_DEPRECATION_MESSAGE)
@@ -0,0 +1,400 @@
1
+ """Hierarchy-aware evaluation metrics and coherence checking.
2
+
3
+ Provides tools to evaluate hierarchical forecast quality and detect
4
+ coherence violations.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass, field
10
+ from typing import TYPE_CHECKING
11
+
12
+ import numpy as np
13
+ import pandas as pd
14
+
15
+ if TYPE_CHECKING:
16
+ from .structure import HierarchyStructure
17
+
18
+
19
@dataclass(frozen=True)
class CoherenceViolation:
    """Single coherence violation record.

    Captures one instance where the hierarchical aggregation constraint
    is broken — the child forecasts do not sum to the parent's forecast
    at a given timestamp.

    Attributes:
        parent_node: Name of the parent node
        child_nodes: List of child node names
        expected_value: Expected value (sum of children)
        actual_value: Actual value (parent forecast)
        difference: Absolute difference between expected and actual
        timestamp: Timestamp of the violation
    """

    parent_node: str
    child_nodes: list[str]
    expected_value: float
    actual_value: float
    difference: float
    timestamp: str
41
+
42
+
43
@dataclass(frozen=True)
class HierarchyEvaluationReport:
    """Evaluation report for hierarchical forecasts.

    Bundles accuracy metrics per hierarchy level together with coherence
    diagnostics (violations of the children-sum-to-parent constraint).

    Attributes:
        level_metrics: Metrics aggregated by hierarchy level
        coherence_violations: List of coherence violations found
        coherence_score: Overall coherence score (0-1, higher is better)
        reconciliation_improvement: Improvement vs base forecasts
        total_violations: Total number of violations
        violation_rate: Proportion of forecasts with violations
    """

    level_metrics: dict[int, dict[str, float]] = field(default_factory=dict)
    coherence_violations: list[CoherenceViolation] = field(default_factory=list)
    coherence_score: float = 0.0
    reconciliation_improvement: dict[str, float] = field(default_factory=dict)
    total_violations: int = 0
    violation_rate: float = 0.0

    def to_dict(self) -> dict:
        """Convert report to dictionary."""
        violation_records = [
            {
                "parent_node": violation.parent_node,
                "child_nodes": violation.child_nodes,
                "expected_value": violation.expected_value,
                "actual_value": violation.actual_value,
                "difference": violation.difference,
                "timestamp": violation.timestamp,
            }
            for violation in self.coherence_violations
        ]
        return {
            "level_metrics": self.level_metrics,
            "coherence_violations": violation_records,
            "coherence_score": self.coherence_score,
            "reconciliation_improvement": self.reconciliation_improvement,
            "total_violations": self.total_violations,
            "violation_rate": self.violation_rate,
        }
85
+
86
+
87
class HierarchyEvaluator:
    """Evaluate hierarchical forecast quality and coherence.

    Computes per-level accuracy metrics and checks that forecasts respect
    the hierarchical aggregation constraint (children summing to parent).

    Example:
        >>> evaluator = HierarchyEvaluator(hierarchy_structure)
        >>> report = evaluator.evaluate(forecasts, actuals)
        >>> print(f"Coherence score: {report.coherence_score:.3f}")
        Coherence score: 0.998
    """

    def __init__(self, structure: HierarchyStructure):
        """Initialize evaluator.

        Args:
            structure: Hierarchy structure describing parent/child links
        """
        self.structure = structure

    def evaluate(
        self,
        forecasts: pd.DataFrame,
        actuals: pd.DataFrame | None = None,
        tolerance: float = 1e-6,
    ) -> HierarchyEvaluationReport:
        """Evaluate hierarchical forecasts.

        Computes standard forecast metrics per level and checks coherence.

        Args:
            forecasts: Forecast DataFrame with columns [unique_id, ds, yhat]
            actuals: Optional actual values for accuracy metrics
            tolerance: Tolerance for coherence violations

        Returns:
            HierarchyEvaluationReport with metrics and violations
        """
        # Accuracy metrics need actuals; coherence checks do not.
        per_level: dict[int, dict[str, float]] = {}
        if actuals is not None:
            per_level = self._compute_level_metrics(forecasts, actuals)

        violations = self._detect_violations(forecasts, tolerance)
        score = self._compute_coherence_score(forecasts, tolerance)

        # Violation rate is relative to every parent/timestamp pair checked.
        checks = self._count_total_checks(forecasts)

        return HierarchyEvaluationReport(
            level_metrics=per_level,
            coherence_violations=violations,
            coherence_score=score,
            total_violations=len(violations),
            violation_rate=len(violations) / max(checks, 1),
        )

    def _compute_level_metrics(
        self,
        forecasts: pd.DataFrame,
        actuals: pd.DataFrame,
    ) -> dict[int, dict[str, float]]:
        """Compute metrics aggregated by hierarchy level.

        Args:
            forecasts: Forecast DataFrame
            actuals: Actual values DataFrame

        Returns:
            Dictionary mapping level to metrics dict
        """
        joined = forecasts.merge(
            actuals,
            on=["unique_id", "ds"],
            suffixes=("_forecast", "_actual"),
        )

        per_level: dict[int, dict[str, list[float]]] = {}
        for _, record in joined.iterrows():
            series_id = record["unique_id"]
            # Skip series that are not part of the declared hierarchy.
            if series_id not in self.structure.all_nodes:
                continue

            level = self.structure.get_level(series_id)
            bucket = per_level.setdefault(level, {"mae": [], "mape": [], "rmse": []})

            observed = record.get("y_actual", record.get("y", 0))
            err = record["yhat"] - observed

            bucket["mae"].append(abs(err))
            bucket["rmse"].append(err ** 2)
            # MAPE is undefined for zero actuals; skip those points.
            if observed != 0:
                bucket["mape"].append(abs(err / observed) * 100)

        summary: dict[int, dict[str, float]] = {}
        for level, bucket in per_level.items():
            summary[level] = {
                "mae": np.mean(bucket["mae"]) if bucket["mae"] else 0,
                "rmse": np.sqrt(np.mean(bucket["rmse"])) if bucket["rmse"] else 0,
                "mape": np.mean(bucket["mape"]) if bucket["mape"] else 0,
                "count": len(bucket["mae"]),
            }

        return summary

    def _detect_violations(
        self,
        forecasts: pd.DataFrame,
        tolerance: float = 1e-6,
    ) -> list[CoherenceViolation]:
        """Detect where forecasts violate hierarchical coherence.

        A violation occurs when the sum of child forecasts differs from
        the parent forecast by more than the tolerance.

        Args:
            forecasts: Forecast DataFrame
            tolerance: Numerical tolerance for violations

        Returns:
            List of coherence violations
        """
        # Wide layout: one row per series, one column per timestamp.
        wide = forecasts.pivot(index="unique_id", columns="ds", values="yhat")
        found: list[CoherenceViolation] = []

        for parent, children in self.structure.aggregation_graph.items():
            if parent not in wide.index:
                continue

            present = [child for child in children if child in wide.index]
            if not present:
                continue

            parent_series = wide.loc[parent]
            child_total = wide.loc[present].sum()

            for stamp in parent_series.index:
                parent_value = parent_series[stamp]
                child_value = child_total[stamp]
                gap = abs(parent_value - child_value)
                if gap > tolerance:
                    found.append(
                        CoherenceViolation(
                            parent_node=parent,
                            child_nodes=present,
                            expected_value=float(child_value),
                            actual_value=float(parent_value),
                            difference=float(gap),
                            timestamp=str(stamp),
                        )
                    )

        return found

    def _compute_coherence_score(
        self,
        forecasts: pd.DataFrame,
        tolerance: float = 1e-6,
    ) -> float:
        """Compute overall coherence score.

        Score is 1.0 when perfectly coherent and decreases with the
        total violation magnitude relative to forecast magnitude.

        Args:
            forecasts: Forecast DataFrame
            tolerance: Numerical tolerance

        Returns:
            Coherence score between 0 and 1
        """
        wide = forecasts.pivot(index="unique_id", columns="ds", values="yhat")

        magnitude = 0.0
        violation_mass = 0.0

        for parent, children in self.structure.aggregation_graph.items():
            if parent not in wide.index:
                continue

            present = [child for child in children if child in wide.index]
            if not present:
                continue

            parent_series = wide.loc[parent]
            child_total = wide.loc[present].sum()

            for stamp in parent_series.index:
                parent_value = parent_series[stamp]
                magnitude += abs(parent_value)
                gap = abs(parent_value - child_total[stamp])
                if gap > tolerance:
                    violation_mass += gap

        # No checked mass at all counts as perfectly coherent.
        if magnitude == 0:
            return 1.0

        return max(0.0, 1.0 - (violation_mass / magnitude))

    def _count_total_checks(self, forecasts: pd.DataFrame) -> int:
        """Count total number of coherence checks performed.

        Args:
            forecasts: Forecast DataFrame

        Returns:
            Total number of parent-timestamp pairs checked
        """
        wide = forecasts.pivot(index="unique_id", columns="ds", values="yhat")
        timestamps = len(wide.columns)

        total = 0
        for parent, children in self.structure.aggregation_graph.items():
            if parent not in wide.index:
                continue
            # A parent contributes one check per timestamp, but only when
            # at least one of its children has a forecast.
            if any(child in wide.index for child in children):
                total += timestamps

        return total

    def compute_improvement(
        self,
        base_forecasts: pd.DataFrame,
        reconciled_forecasts: pd.DataFrame,
        actuals: pd.DataFrame,
    ) -> dict[str, float]:
        """Compute improvement of reconciled vs base forecasts.

        Args:
            base_forecasts: Base (unreconciled) forecasts
            reconciled_forecasts: Reconciled forecasts
            actuals: Actual values

        Returns:
            Dictionary with percentage improvement per metric
        """
        baseline = self._compute_overall_metrics(base_forecasts, actuals)
        revised = self._compute_overall_metrics(reconciled_forecasts, actuals)

        gains: dict[str, float] = {}
        for name in ("mae", "rmse", "mape"):
            if name in baseline and baseline[name] > 0:
                gains[name] = (baseline[name] - revised[name]) / baseline[name] * 100
            else:
                # Undefined or zero baseline: report no improvement.
                gains[name] = 0.0

        return gains

    def _compute_overall_metrics(
        self,
        forecasts: pd.DataFrame,
        actuals: pd.DataFrame,
    ) -> dict[str, float]:
        """Compute overall metrics for forecasts.

        Args:
            forecasts: Forecast DataFrame
            actuals: Actual values DataFrame

        Returns:
            Dictionary of metrics
        """
        joined = forecasts.merge(
            actuals,
            on=["unique_id", "ds"],
            suffixes=("_forecast", "_actual"),
        )

        if len(joined) == 0:
            return {"mae": 0, "rmse": 0, "mape": 0}

        # Column named "y" keeps its name unless both frames carried it.
        target = "y_actual" if "y_actual" in joined.columns else "y"
        residuals = joined["yhat"] - joined[target]

        mae = np.mean(np.abs(residuals))
        rmse = np.sqrt(np.mean(residuals ** 2))

        # MAPE only over non-zero actuals to avoid division by zero.
        nonzero = joined[target] != 0
        if nonzero.any():
            mape = np.mean(
                np.abs(residuals[nonzero] / joined.loc[nonzero, target])
            ) * 100
        else:
            mape = 0.0

        return {"mae": mae, "rmse": rmse, "mape": mape}