tsagentkit-1.0.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tsagentkit/__init__.py +126 -0
- tsagentkit/anomaly/__init__.py +130 -0
- tsagentkit/backtest/__init__.py +48 -0
- tsagentkit/backtest/engine.py +788 -0
- tsagentkit/backtest/metrics.py +244 -0
- tsagentkit/backtest/report.py +342 -0
- tsagentkit/calibration/__init__.py +136 -0
- tsagentkit/contracts/__init__.py +133 -0
- tsagentkit/contracts/errors.py +275 -0
- tsagentkit/contracts/results.py +418 -0
- tsagentkit/contracts/schema.py +44 -0
- tsagentkit/contracts/task_spec.py +300 -0
- tsagentkit/covariates/__init__.py +340 -0
- tsagentkit/eval/__init__.py +285 -0
- tsagentkit/features/__init__.py +20 -0
- tsagentkit/features/covariates.py +328 -0
- tsagentkit/features/extra/__init__.py +5 -0
- tsagentkit/features/extra/native.py +179 -0
- tsagentkit/features/factory.py +187 -0
- tsagentkit/features/matrix.py +159 -0
- tsagentkit/features/tsfeatures_adapter.py +115 -0
- tsagentkit/features/versioning.py +203 -0
- tsagentkit/hierarchy/__init__.py +39 -0
- tsagentkit/hierarchy/aggregation.py +62 -0
- tsagentkit/hierarchy/evaluator.py +400 -0
- tsagentkit/hierarchy/reconciliation.py +232 -0
- tsagentkit/hierarchy/structure.py +453 -0
- tsagentkit/models/__init__.py +182 -0
- tsagentkit/models/adapters/__init__.py +83 -0
- tsagentkit/models/adapters/base.py +321 -0
- tsagentkit/models/adapters/chronos.py +387 -0
- tsagentkit/models/adapters/moirai.py +256 -0
- tsagentkit/models/adapters/registry.py +171 -0
- tsagentkit/models/adapters/timesfm.py +440 -0
- tsagentkit/models/baselines.py +207 -0
- tsagentkit/models/sktime.py +307 -0
- tsagentkit/monitoring/__init__.py +51 -0
- tsagentkit/monitoring/alerts.py +302 -0
- tsagentkit/monitoring/coverage.py +203 -0
- tsagentkit/monitoring/drift.py +330 -0
- tsagentkit/monitoring/report.py +214 -0
- tsagentkit/monitoring/stability.py +275 -0
- tsagentkit/monitoring/triggers.py +423 -0
- tsagentkit/qa/__init__.py +347 -0
- tsagentkit/router/__init__.py +37 -0
- tsagentkit/router/bucketing.py +489 -0
- tsagentkit/router/fallback.py +132 -0
- tsagentkit/router/plan.py +23 -0
- tsagentkit/router/router.py +271 -0
- tsagentkit/series/__init__.py +26 -0
- tsagentkit/series/alignment.py +206 -0
- tsagentkit/series/dataset.py +449 -0
- tsagentkit/series/sparsity.py +261 -0
- tsagentkit/series/validation.py +393 -0
- tsagentkit/serving/__init__.py +39 -0
- tsagentkit/serving/orchestration.py +943 -0
- tsagentkit/serving/packaging.py +73 -0
- tsagentkit/serving/provenance.py +317 -0
- tsagentkit/serving/tsfm_cache.py +214 -0
- tsagentkit/skill/README.md +135 -0
- tsagentkit/skill/__init__.py +8 -0
- tsagentkit/skill/recipes.md +429 -0
- tsagentkit/skill/tool_map.md +21 -0
- tsagentkit/time/__init__.py +134 -0
- tsagentkit/utils/__init__.py +20 -0
- tsagentkit/utils/quantiles.py +83 -0
- tsagentkit/utils/signature.py +47 -0
- tsagentkit/utils/temporal.py +41 -0
- tsagentkit-1.0.2.dist-info/METADATA +371 -0
- tsagentkit-1.0.2.dist-info/RECORD +72 -0
- tsagentkit-1.0.2.dist-info/WHEEL +4 -0
- tsagentkit-1.0.2.dist-info/licenses/LICENSE +201 -0
tsagentkit/monitoring/alerts.py
@@ -0,0 +1,302 @@
"""Alert conditions and triggering for monitoring.

Provides configurable alert conditions for coverage, drift,
and model performance monitoring.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from datetime import UTC
from typing import Any


@dataclass(frozen=True)
class AlertCondition:
    """Alert condition configuration.

    Defines when an alert should be triggered based on a metric
    and threshold.

    Attributes:
        name: Alert name/identifier
        metric: Metric to monitor (e.g., "coverage", "drift_score")
        operator: Comparison operator ("lt", "lte", "gt", "gte", "eq", "ne")
        threshold: Threshold value for triggering
        severity: Alert severity ("info", "warning", "critical")
        message: Optional custom alert message
    """

    name: str
    metric: str
    operator: str  # "lt", "lte", "gt", "gte", "eq", "ne"
    threshold: float
    severity: str = "warning"
    message: str | None = None

    def evaluate(self, value: float) -> bool:
        """Evaluate if condition is met.

        Args:
            value: Current metric value

        Returns:
            True if alert should trigger
        """
        if self.operator == "lt":
            return value < self.threshold
        if self.operator == "gt":
            return value > self.threshold
        if self.operator == "eq":
            return value == self.threshold
        if self.operator == "ne":
            return value != self.threshold
        if self.operator == "lte":
            return value <= self.threshold
        if self.operator == "gte":
            return value >= self.threshold
        return False

    def format_message(self, value: float, context: dict[str, Any] | None = None) -> str:
        """Format alert message with current value.

        Args:
            value: Current metric value
            context: Additional context for message formatting

        Returns:
            Formatted alert message
        """
        if self.message:
            ctx = context or {}
            return self.message.format(
                name=self.name,
                metric=self.metric,
                value=value,
                threshold=self.threshold,
                **ctx,
            )

        op_str = {
            "lt": "below",
            "gt": "above",
            "eq": "equal to",
            "ne": "not equal to",
            "lte": "at or below",
            "gte": "at or above",
        }.get(self.operator, self.operator)

        return f"Alert '{self.name}': {self.metric} ({value:.4f}) is {op_str} threshold ({self.threshold:.4f})"


@dataclass
class Alert:
    """Triggered alert instance.

    Represents an alert that has been triggered with full context.

    Attributes:
        condition: The alert condition that triggered
        value: The metric value that triggered the alert
        timestamp: ISO 8601 timestamp of when alert triggered
        context: Additional context about the alert
    """

    condition: AlertCondition
    value: float
    timestamp: str
    context: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for serialization."""
        return {
            "name": self.condition.name,
            "metric": self.condition.metric,
            "severity": self.condition.severity,
            "value": self.value,
            "threshold": self.condition.threshold,
            "operator": self.condition.operator,
            "timestamp": self.timestamp,
            "message": self.condition.format_message(self.value, self.context),
            "context": self.context,
        }


class AlertManager:
    """Manager for alert conditions and triggering.

    Provides a centralized way to define alert conditions and
    check them against current metrics.

    Example:
        >>> manager = AlertManager()
        >>> manager.add_condition(AlertCondition(
        ...     name="low_coverage",
        ...     metric="coverage_80",
        ...     operator="lt",
        ...     threshold=0.75,
        ...     severity="critical",
        ... ))
        >>> alerts = manager.check_metrics({"coverage_80": 0.70})
        >>> print(len(alerts))
        1
    """

    def __init__(self) -> None:
        """Initialize alert manager."""
        self.conditions: list[AlertCondition] = []
        self._alert_history: list[Alert] = []

    def add_condition(self, condition: AlertCondition) -> None:
        """Add an alert condition.

        Args:
            condition: Alert condition to add
        """
        self.conditions.append(condition)

    def remove_condition(self, name: str) -> bool:
        """Remove an alert condition by name.

        Args:
            name: Name of condition to remove

        Returns:
            True if condition was found and removed
        """
        for i, cond in enumerate(self.conditions):
            if cond.name == name:
                self.conditions.pop(i)
                return True
        return False

    def check_metrics(
        self,
        metrics: dict[str, float],
        context: dict[str, Any] | None = None,
    ) -> list[Alert]:
        """Check all conditions against current metrics.

        Args:
            metrics: Dictionary of metric names to values
            context: Additional context for alert messages

        Returns:
            List of triggered alerts
        """
        from datetime import datetime

        triggered: list[Alert] = []

        for condition in self.conditions:
            if condition.metric not in metrics:
                continue

            value = metrics[condition.metric]
            if condition.evaluate(value):
                alert = Alert(
                    condition=condition,
                    value=value,
                    timestamp=datetime.now(UTC).isoformat(),
                    context=context or {},
                )
                triggered.append(alert)
                self._alert_history.append(alert)

        return triggered

    def check_coverage(
        self,
        coverage_checks: list[Any],
        context: dict[str, Any] | None = None,
    ) -> list[Alert]:
        """Check coverage results against conditions.

        Args:
            coverage_checks: List of CoverageCheck objects
            context: Additional context

        Returns:
            List of triggered alerts
        """
        metrics: dict[str, float] = {}
        for check in coverage_checks:
            if hasattr(check, "is_acceptable") and hasattr(check, "actual_coverage"):
                metric_name = f"coverage_{check.expected_coverage:.0%}"
                metrics[metric_name] = check.actual_coverage

        return self.check_metrics(metrics, context)

    def get_alert_history(
        self,
        severity: str | None = None,
    ) -> list[Alert]:
        """Get history of triggered alerts.

        Args:
            severity: Filter by severity (optional)

        Returns:
            List of historical alerts
        """
        if severity is None:
            return self._alert_history.copy()
        return [a for a in self._alert_history if a.condition.severity == severity]

    def clear_history(self) -> None:
        """Clear alert history."""
        self._alert_history.clear()


def create_default_coverage_alerts(
    coverage_levels: list[float] | None = None,
    tolerance: float = 0.05,
) -> list[AlertCondition]:
    """Create default alert conditions for coverage monitoring.

    Args:
        coverage_levels: Coverage levels to monitor (default: [0.5, 0.8, 0.95])
        tolerance: Tolerance for coverage deviation

    Returns:
        List of alert conditions
    """
    levels = coverage_levels or [0.5, 0.8, 0.95]
    conditions: list[AlertCondition] = []

    for level in levels:
        conditions.append(
            AlertCondition(
                name=f"low_coverage_{level:.0%}",
                metric=f"coverage_{level:.0%}",
                operator="lt",
                threshold=level - tolerance,
                severity="warning" if level < 0.9 else "critical",
                message=f"Coverage for {level:.0%} interval is below acceptable threshold",
            )
        )

    return conditions


def create_default_drift_alerts(
    drift_threshold: float = 0.05,
) -> list[AlertCondition]:
    """Create default alert conditions for drift monitoring.

    Args:
        drift_threshold: Drift score threshold

    Returns:
        List of alert conditions
    """
    return [
        AlertCondition(
            name="drift_detected",
            metric="drift_score",
            operator="gt",
            threshold=drift_threshold,
            severity="warning",
            message="Data drift detected above threshold",
        ),
    ]
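For orientation, a minimal usage sketch of the module above (not part of the package diff; the import path is inferred from the file listing, and the metric names follow the "coverage_<level>%" / "drift_score" convention produced by the default condition factories):

from tsagentkit.monitoring.alerts import (
    AlertManager,
    create_default_coverage_alerts,
    create_default_drift_alerts,
)

# Register the default coverage and drift conditions on a fresh manager.
manager = AlertManager()
for condition in create_default_coverage_alerts() + create_default_drift_alerts():
    manager.add_condition(condition)

# check_metrics() returns only the conditions whose thresholds are breached:
# here the 80% coverage alert (0.70 < 0.75) and the drift alert (0.12 > 0.05).
alerts = manager.check_metrics({"coverage_80%": 0.70, "drift_score": 0.12})
for alert in alerts:
    print(alert.to_dict()["message"])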
tsagentkit/monitoring/coverage.py
@@ -0,0 +1,203 @@
"""Coverage monitoring for quantile forecasts.

Provides interval coverage checks to verify that prediction intervals
are well-calibrated over time.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any

import pandas as pd


@dataclass(frozen=True)
class CoverageCheck:
    """Interval coverage check for a specific quantile.

    Tracks the actual coverage rate compared to expected coverage
    for a given quantile level, with per-horizon breakdown.

    Attributes:
        quantile: The quantile level (e.g., 0.1 for lower bound)
        expected_coverage: Expected coverage rate
        actual_coverage: Actual observed coverage rate
        hit_rate_by_horizon: Coverage rate for each forecast horizon
        tolerance: Acceptable deviation from expected coverage
    """

    quantile: float
    expected_coverage: float
    actual_coverage: float
    hit_rate_by_horizon: dict[int, float]
    tolerance: float = 0.05

    def is_acceptable(self) -> bool:
        """Check if coverage is within tolerance of expected.

        For interval coverage, we want the actual coverage to be at least
        the expected coverage minus tolerance. Being slightly over is fine,
        but being under indicates the intervals are too narrow.
        """
        return self.actual_coverage >= (self.expected_coverage - self.tolerance)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for serialization."""
        return {
            "quantile": self.quantile,
            "expected_coverage": self.expected_coverage,
            "actual_coverage": self.actual_coverage,
            "hit_rate_by_horizon": self.hit_rate_by_horizon,
            "tolerance": self.tolerance,
            "is_acceptable": self.is_acceptable(),
        }


class CoverageMonitor:
    """Monitor quantile coverage over time.

    Provides functionality to check if prediction intervals are
    well-calibrated by comparing actual vs expected coverage rates.

    Example:
        >>> monitor = CoverageMonitor()
        >>> checks = monitor.check(
        ...     forecasts=forecast_df,
        ...     actuals=actual_df,
        ...     quantiles=[0.1, 0.5, 0.9],
        ... )
        >>> for check in checks:
        ...     print(f"Q{check.quantile}: {check.actual_coverage:.2%}")
    """

    def check(
        self,
        forecasts: pd.DataFrame,
        actuals: pd.DataFrame,
        quantiles: list[float],
        tolerance: float = 0.05,
    ) -> list[CoverageCheck]:
        """Check coverage for given quantiles.

        Args:
            forecasts: Forecast dataframe with quantile columns (q_0.1, q_0.9, etc.)
            actuals: Actual values dataframe
            quantiles: List of quantile levels to check
            tolerance: Acceptable deviation from expected coverage

        Returns:
            List of CoverageCheck objects, one per quantile pair
        """
        results: list[CoverageCheck] = []

        # Ensure required columns exist
        if "unique_id" not in forecasts.columns or "ds" not in forecasts.columns:
            return results

        # Merge forecasts with actuals
        merged = forecasts.merge(
            actuals[["unique_id", "ds", "y"]],
            on=["unique_id", "ds"],
            how="inner",
        )

        if merged.empty:
            return results

        # Calculate coverage for each quantile pair
        for i, lower_q in enumerate(quantiles):
            for upper_q in quantiles[i + 1 :]:
                lower_col = f"q_{lower_q}"
                upper_col = f"q_{upper_q}"

                if lower_col not in merged.columns or upper_col not in merged.columns:
                    continue

                # Calculate overall coverage
                in_interval = (merged["y"] >= merged[lower_col]) & (
                    merged["y"] <= merged[upper_col]
                )
                actual_coverage = in_interval.mean()

                # Calculate coverage by horizon if available
                hit_rate_by_horizon: dict[int, float] = {}
                if "h" in merged.columns or "horizon" in merged.columns:
                    h_col = "h" if "h" in merged.columns else "horizon"
                    for h in sorted(merged[h_col].unique()):
                        h_data = merged[merged[h_col] == h]
                        if len(h_data) > 0:
                            h_in_interval = (h_data["y"] >= h_data[lower_col]) & (
                                h_data["y"] <= h_data[upper_col]
                            )
                            hit_rate_by_horizon[int(h)] = float(h_in_interval.mean())

                # Expected coverage is the difference between quantiles
                expected_coverage = upper_q - lower_q

                results.append(
                    CoverageCheck(
                        quantile=lower_q,
                        expected_coverage=expected_coverage,
                        actual_coverage=float(actual_coverage),
                        hit_rate_by_horizon=hit_rate_by_horizon,
                        tolerance=tolerance,
                    )
                )

        return results

    def check_single_quantile(
        self,
        forecasts: pd.DataFrame,
        actuals: pd.DataFrame,
        quantile: float,
        tolerance: float = 0.05,
    ) -> CoverageCheck | None:
        """Check coverage for a single quantile (e.g., median).

        For a single quantile, checks if actuals fall below the quantile
        at the expected rate.

        Args:
            forecasts: Forecast dataframe
            actuals: Actual values dataframe
            quantile: Quantile level to check
            tolerance: Acceptable deviation

        Returns:
            CoverageCheck or None if data not available
        """
        q_col = f"q_{quantile}"
        if q_col not in forecasts.columns:
            return None

        merged = forecasts.merge(
            actuals[["unique_id", "ds", "y"]],
            on=["unique_id", "ds"],
            how="inner",
        )

        if merged.empty:
            return None

        # For single quantile: check if actual <= quantile at quantile rate
        below_quantile = merged["y"] <= merged[q_col]
        actual_coverage = float(below_quantile.mean())

        hit_rate_by_horizon: dict[int, float] = {}
        if "h" in merged.columns or "horizon" in merged.columns:
            h_col = "h" if "h" in merged.columns else "horizon"
            for h in sorted(merged[h_col].unique()):
                h_data = merged[merged[h_col] == h]
                if len(h_data) > 0:
                    h_below = h_data["y"] <= h_data[q_col]
                    hit_rate_by_horizon[int(h)] = float(h_below.mean())

        return CoverageCheck(
            quantile=quantile,
            expected_coverage=quantile,
            actual_coverage=actual_coverage,
            hit_rate_by_horizon=hit_rate_by_horizon,
            tolerance=tolerance,
        )
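Likewise, a minimal sketch of how CoverageMonitor could be exercised (not part of the diff; the import path is inferred from the file listing, and the unique_id/ds/y and q_<level> column convention is taken from the code and docstrings above):

import pandas as pd

from tsagentkit.monitoring.coverage import CoverageMonitor

# Toy data: one series, three timestamps, 10%/90% quantile forecasts plus actuals.
forecasts = pd.DataFrame({
    "unique_id": ["A", "A", "A"],
    "ds": pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03"]),
    "q_0.1": [8.0, 9.0, 10.0],
    "q_0.9": [12.0, 13.0, 14.0],
})
actuals = pd.DataFrame({
    "unique_id": ["A", "A", "A"],
    "ds": pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03"]),
    "y": [11.0, 9.5, 15.0],  # the last point falls outside the 10%-90% interval
})

monitor = CoverageMonitor()
for check in monitor.check(forecasts, actuals, quantiles=[0.1, 0.9]):
    # Expected coverage is 0.8; observed coverage here is 2/3, so is_acceptable() is False.
    print(check.to_dict())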