tsagentkit-1.0.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tsagentkit/__init__.py +126 -0
- tsagentkit/anomaly/__init__.py +130 -0
- tsagentkit/backtest/__init__.py +48 -0
- tsagentkit/backtest/engine.py +788 -0
- tsagentkit/backtest/metrics.py +244 -0
- tsagentkit/backtest/report.py +342 -0
- tsagentkit/calibration/__init__.py +136 -0
- tsagentkit/contracts/__init__.py +133 -0
- tsagentkit/contracts/errors.py +275 -0
- tsagentkit/contracts/results.py +418 -0
- tsagentkit/contracts/schema.py +44 -0
- tsagentkit/contracts/task_spec.py +300 -0
- tsagentkit/covariates/__init__.py +340 -0
- tsagentkit/eval/__init__.py +285 -0
- tsagentkit/features/__init__.py +20 -0
- tsagentkit/features/covariates.py +328 -0
- tsagentkit/features/extra/__init__.py +5 -0
- tsagentkit/features/extra/native.py +179 -0
- tsagentkit/features/factory.py +187 -0
- tsagentkit/features/matrix.py +159 -0
- tsagentkit/features/tsfeatures_adapter.py +115 -0
- tsagentkit/features/versioning.py +203 -0
- tsagentkit/hierarchy/__init__.py +39 -0
- tsagentkit/hierarchy/aggregation.py +62 -0
- tsagentkit/hierarchy/evaluator.py +400 -0
- tsagentkit/hierarchy/reconciliation.py +232 -0
- tsagentkit/hierarchy/structure.py +453 -0
- tsagentkit/models/__init__.py +182 -0
- tsagentkit/models/adapters/__init__.py +83 -0
- tsagentkit/models/adapters/base.py +321 -0
- tsagentkit/models/adapters/chronos.py +387 -0
- tsagentkit/models/adapters/moirai.py +256 -0
- tsagentkit/models/adapters/registry.py +171 -0
- tsagentkit/models/adapters/timesfm.py +440 -0
- tsagentkit/models/baselines.py +207 -0
- tsagentkit/models/sktime.py +307 -0
- tsagentkit/monitoring/__init__.py +51 -0
- tsagentkit/monitoring/alerts.py +302 -0
- tsagentkit/monitoring/coverage.py +203 -0
- tsagentkit/monitoring/drift.py +330 -0
- tsagentkit/monitoring/report.py +214 -0
- tsagentkit/monitoring/stability.py +275 -0
- tsagentkit/monitoring/triggers.py +423 -0
- tsagentkit/qa/__init__.py +347 -0
- tsagentkit/router/__init__.py +37 -0
- tsagentkit/router/bucketing.py +489 -0
- tsagentkit/router/fallback.py +132 -0
- tsagentkit/router/plan.py +23 -0
- tsagentkit/router/router.py +271 -0
- tsagentkit/series/__init__.py +26 -0
- tsagentkit/series/alignment.py +206 -0
- tsagentkit/series/dataset.py +449 -0
- tsagentkit/series/sparsity.py +261 -0
- tsagentkit/series/validation.py +393 -0
- tsagentkit/serving/__init__.py +39 -0
- tsagentkit/serving/orchestration.py +943 -0
- tsagentkit/serving/packaging.py +73 -0
- tsagentkit/serving/provenance.py +317 -0
- tsagentkit/serving/tsfm_cache.py +214 -0
- tsagentkit/skill/README.md +135 -0
- tsagentkit/skill/__init__.py +8 -0
- tsagentkit/skill/recipes.md +429 -0
- tsagentkit/skill/tool_map.md +21 -0
- tsagentkit/time/__init__.py +134 -0
- tsagentkit/utils/__init__.py +20 -0
- tsagentkit/utils/quantiles.py +83 -0
- tsagentkit/utils/signature.py +47 -0
- tsagentkit/utils/temporal.py +41 -0
- tsagentkit-1.0.2.dist-info/METADATA +371 -0
- tsagentkit-1.0.2.dist-info/RECORD +72 -0
- tsagentkit-1.0.2.dist-info/WHEEL +4 -0
- tsagentkit-1.0.2.dist-info/licenses/LICENSE +201 -0

tsagentkit/monitoring/stability.py
@@ -0,0 +1,275 @@
"""Stability monitoring for prediction jitter and quantile coverage."""

from __future__ import annotations

from typing import TYPE_CHECKING, Literal

import numpy as np
import pandas as pd

from tsagentkit.monitoring.report import CalibrationReport, StabilityReport
from tsagentkit.utils import normalize_quantile_columns, quantile_col_name

if TYPE_CHECKING:
    pass


class StabilityMonitor:
    """Monitor prediction stability and calibration.

    This class tracks:
    - Prediction jitter: Variance in point predictions across runs
    - Quantile coverage: Whether empirical coverage matches target quantiles

    Example:
        >>> monitor = StabilityMonitor(jitter_threshold=0.1)
        >>>
        >>> # Check jitter across multiple forecast runs
        >>> jitter = monitor.compute_jitter([forecast1, forecast2, forecast3])
        >>>
        >>> # Check quantile calibration
        >>> coverage = monitor.compute_coverage(actuals, forecasts, [0.1, 0.5, 0.9])
    """

    def __init__(
        self,
        jitter_threshold: float = 0.1,
        coverage_tolerance: float = 0.05,
    ):
        """Initialize stability monitor.

        Args:
            jitter_threshold: Coefficient of variation threshold for jitter warnings
            coverage_tolerance: Allowed deviation from target coverage
        """
        self.jitter_threshold = jitter_threshold
        self.coverage_tolerance = coverage_tolerance

    def compute_jitter(
        self,
        predictions: list[pd.DataFrame],
        method: Literal["cv", "mad"] = "cv",
    ) -> dict[str, float]:
        """Compute prediction jitter across multiple forecast runs.

        Jitter measures how much point predictions vary across different
        runs or model versions. High jitter indicates unstable predictions.

        Args:
            predictions: List of forecast DataFrames from different runs.
                Each should have columns [unique_id, ds, yhat]
            method: "cv" for coefficient of variation, "mad" for median absolute deviation

        Returns:
            Dict mapping unique_id to jitter metric

        Example:
            >>> forecasts = [model.predict(data) for model in model_versions]
            >>> jitter = monitor.compute_jitter(forecasts, method="cv")
            >>> print(jitter["series_A"])
            0.05
        """
        if not predictions or len(predictions) < 2:
            return {}

        # Combine all predictions
        combined = predictions[0][["unique_id", "ds", "yhat"]].copy()
        combined.rename(columns={"yhat": "yhat_0"}, inplace=True)

        for i, pred in enumerate(predictions[1:], 1):
            combined = combined.merge(
                pred[["unique_id", "ds", "yhat"]].rename(columns={"yhat": f"yhat_{i}"}),
                on=["unique_id", "ds"],
                how="outer",
            )

        jitter_metrics = {}

        for uid in combined["unique_id"].unique():
            series_data = combined[combined["unique_id"] == uid]
            yhat_cols = [c for c in series_data.columns if c.startswith("yhat_")]

            if len(yhat_cols) < 2:
                continue

            values = series_data[yhat_cols].values

            if method == "cv":
                # Coefficient of variation (std / mean)
                means = np.nanmean(values, axis=1)
                stds = np.nanstd(values, axis=1)
                # Avoid division by zero
                cvs = np.where(
                    np.abs(means) > 1e-10,
                    stds / np.abs(means),
                    0.0
                )
                jitter = float(np.nanmean(cvs))
            else:  # mad
                # Median absolute deviation
                medians = np.nanmedian(values, axis=1, keepdims=True)
                mads = np.nanmedian(np.abs(values - medians), axis=1)
                jitter = float(np.nanmean(mads))

            jitter_metrics[uid] = jitter

        return jitter_metrics

    def compute_coverage(
        self,
        actuals: pd.DataFrame,
        forecasts: pd.DataFrame,
        quantiles: list[float],
    ) -> dict[float, float]:
        """Compute empirical coverage for each quantile.

        Coverage is the proportion of actual values that fall below
        the predicted quantile. For a well-calibrated model:
        - q=0.1 should have ~10% coverage
        - q=0.5 should have ~50% coverage
        - q=0.9 should have ~90% coverage

        Args:
            actuals: DataFrame with actual values [unique_id, ds, y]
            forecasts: DataFrame with quantile forecasts [unique_id, ds, q_0.1, ...]
            quantiles: List of quantile levels (e.g., [0.1, 0.5, 0.9])

        Returns:
            Dict mapping quantile to empirical coverage (0-1)

        Example:
            >>> coverage = monitor.compute_coverage(
            ...     actuals=df[["unique_id", "ds", "y"]],
            ...     forecasts=pred_df,
            ...     quantiles=[0.1, 0.5, 0.9]
            ... )
            >>> print(coverage[0.5])  # Should be ~0.5 for well-calibrated model
            0.52
        """
        forecasts = normalize_quantile_columns(forecasts)
        # Merge actuals with forecasts
        merged = actuals.merge(forecasts, on=["unique_id", "ds"], how="inner")

        coverage = {}
        for q in quantiles:
            col_name = quantile_col_name(q)
            if col_name not in merged.columns:
                continue

            # Compute coverage: proportion of actuals <= quantile prediction
            below_quantile = merged["y"] <= merged[col_name]
            coverage[q] = float(below_quantile.mean())

        return coverage

    def check_calibration(
        self,
        actuals: pd.DataFrame,
        forecasts: pd.DataFrame,
        quantiles: list[float],
    ) -> CalibrationReport:
        """Check if quantiles are well-calibrated.

        Args:
            actuals: DataFrame with actual values [unique_id, ds, y]
            forecasts: DataFrame with quantile forecasts
            quantiles: List of target quantile levels

        Returns:
            CalibrationReport with coverage metrics and warnings
        """
        coverage = self.compute_coverage(actuals, forecasts, quantiles)

        # Compute calibration errors
        errors = {}
        for q in quantiles:
            if q in coverage:
                errors[q] = abs(coverage[q] - q)
            else:
                errors[q] = float("inf")

        # Check if all quantiles are well-calibrated
        well_calibrated = all(
            err <= self.coverage_tolerance
            for err in errors.values()
        )

        return CalibrationReport(
            target_quantiles=quantiles,
            empirical_coverage=coverage,
            calibration_errors=errors,
            well_calibrated=well_calibrated,
            tolerance=self.coverage_tolerance,
        )

    def generate_stability_report(
        self,
        predictions: list[pd.DataFrame] | None = None,
        actuals: pd.DataFrame | None = None,
        forecasts: pd.DataFrame | None = None,
        quantiles: list[float] | None = None,
    ) -> StabilityReport:
        """Generate a comprehensive stability report.

        Args:
            predictions: List of forecast DataFrames for jitter calculation
            actuals: DataFrame with actual values for coverage analysis
            forecasts: DataFrame with quantile forecasts for coverage analysis
            quantiles: List of quantile levels for coverage analysis

        Returns:
            StabilityReport with jitter and coverage metrics
        """
        # Compute jitter if predictions provided
        jitter_metrics = {}
        if predictions and len(predictions) >= 2:
            jitter_metrics = self.compute_jitter(predictions)

        overall_jitter = np.mean(list(jitter_metrics.values())) if jitter_metrics else 0.0

        # Identify high jitter series
        high_jitter_series = [
            uid for uid, jit in jitter_metrics.items()
            if jit > self.jitter_threshold
        ]

        # Compute coverage report if data provided
        coverage_report = None
        if actuals is not None and forecasts is not None and quantiles:
            coverage_report = self.check_calibration(actuals, forecasts, quantiles)

        return StabilityReport(
            jitter_metrics=jitter_metrics,
            overall_jitter=float(overall_jitter),
            jitter_threshold=self.jitter_threshold,
            high_jitter_series=high_jitter_series,
            coverage_report=coverage_report,
        )


def compute_prediction_interval_coverage(
    actuals: pd.Series,
    lower_bound: pd.Series,
    upper_bound: pd.Series,
) -> float:
    """Compute coverage for a prediction interval.

    Args:
        actuals: Actual values
        lower_bound: Lower bound of prediction interval
        upper_bound: Upper bound of prediction interval

    Returns:
        Proportion of actuals within the interval (0-1)

    Example:
        >>> coverage = compute_prediction_interval_coverage(
        ...     actuals=df["y"],
        ...     lower_bound=forecasts["q_0.1"],
        ...     upper_bound=forecasts["q_0.9"],
        ... )
        >>> print(coverage)  # Should be ~0.8 for 80% PI
        0.82
    """
    within_interval = (actuals >= lower_bound) & (actuals <= upper_bound)
    return float(within_interval.mean())
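
For orientation, a minimal usage sketch of the stability API above, run against synthetic data. Only StabilityMonitor and compute_prediction_interval_coverage come from the file itself; the toy DataFrames, the "q_0.1"-style quantile column names, and the attributes read off CalibrationReport are assumptions inferred from the docstrings and constructor calls, not documented behavior.

# Illustrative sketch only; assumes tsagentkit 1.0.2 is installed.
import pandas as pd

from tsagentkit.monitoring.stability import (
    StabilityMonitor,
    compute_prediction_interval_coverage,
)

ds = pd.date_range("2024-01-01", periods=4, freq="D")

# Two point-forecast runs for one series (columns per the compute_jitter docstring).
run_a = pd.DataFrame({"unique_id": "series_A", "ds": ds, "yhat": [10.0, 11.0, 12.0, 13.0]})
run_b = pd.DataFrame({"unique_id": "series_A", "ds": ds, "yhat": [10.5, 10.8, 12.3, 12.9]})

monitor = StabilityMonitor(jitter_threshold=0.1, coverage_tolerance=0.05)
print(monitor.compute_jitter([run_a, run_b], method="cv"))  # small CV for near-identical runs

# Quantile forecasts; "q_0.1"/"q_0.9" naming follows the docstring examples above.
actuals = pd.DataFrame({"unique_id": "series_A", "ds": ds, "y": [10.2, 11.4, 11.9, 13.5]})
forecasts = pd.DataFrame(
    {
        "unique_id": "series_A",
        "ds": ds,
        "q_0.1": [9.0, 10.0, 11.0, 12.0],
        "q_0.5": [10.0, 11.0, 12.0, 13.0],
        "q_0.9": [11.0, 12.0, 13.0, 14.0],
    }
)
report = monitor.check_calibration(actuals, forecasts, quantiles=[0.1, 0.5, 0.9])
print(report.well_calibrated)  # assumed attribute, mirroring the CalibrationReport(...) call

# Interval coverage is a standalone function, independent of the monitor.
print(compute_prediction_interval_coverage(actuals["y"], forecasts["q_0.1"], forecasts["q_0.9"]))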
tsagentkit/monitoring/triggers.py
@@ -0,0 +1,423 @@
"""Retrain triggers for model monitoring."""

from __future__ import annotations

from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from enum import Enum
from typing import TYPE_CHECKING

from tsagentkit.monitoring.report import TriggerResult

if TYPE_CHECKING:
    from tsagentkit.monitoring.report import DriftReport, StabilityReport


class TriggerType(Enum):
    """Types of retrain triggers.

    - DRIFT: Data drift detected (PSI/KS above threshold)
    - SCHEDULE: Time-based trigger (e.g., daily, weekly)
    - PERFORMANCE: Metric degradation (accuracy drop)
    - MANUAL: Explicit manual trigger
    """

    DRIFT = "drift"
    SCHEDULE = "schedule"
    PERFORMANCE = "performance"
    MANUAL = "manual"


@dataclass
class RetrainTrigger:
    """Configuration for a retrain trigger.

    Attributes:
        trigger_type: Type of trigger
        threshold: Threshold value for trigger activation
        metric_name: For PERFORMANCE triggers, the metric to monitor
        schedule_interval: For SCHEDULE triggers, the interval (e.g., "1d", "7d")
        enabled: Whether this trigger is active

    Example:
        >>> drift_trigger = RetrainTrigger(
        ...     trigger_type=TriggerType.DRIFT,
        ...     threshold=0.2,
        ... )
        >>> schedule_trigger = RetrainTrigger(
        ...     trigger_type=TriggerType.SCHEDULE,
        ...     schedule_interval="7d",
        ... )
        >>> perf_trigger = RetrainTrigger(
        ...     trigger_type=TriggerType.PERFORMANCE,
        ...     metric_name="wape",
        ...     threshold=0.15,
        ... )
    """

    trigger_type: TriggerType
    threshold: float | None = None
    metric_name: str | None = None
    schedule_interval: str | None = None
    enabled: bool = True

    def __post_init__(self) -> None:
        """Validate trigger configuration."""
        if self.trigger_type == TriggerType.DRIFT and self.threshold is None:
            self.threshold = 0.2  # Default PSI threshold
        elif self.trigger_type == TriggerType.PERFORMANCE and self.threshold is None:
            self.threshold = 0.1  # Default 10% degradation


class TriggerEvaluator:
    """Evaluate retrain triggers based on monitoring data.

    Example:
        >>> triggers = [
        ...     RetrainTrigger(TriggerType.DRIFT, threshold=0.2),
        ...     RetrainTrigger(TriggerType.SCHEDULE, schedule_interval="7d"),
        ... ]
        >>> evaluator = TriggerEvaluator(triggers)
        >>>
        >>> # Evaluate with monitoring data
        >>> results = evaluator.evaluate(
        ...     drift_report=drift_report,
        ...     last_train_time=datetime(2024, 1, 1),
        ... )
        >>>
        >>> if evaluator.should_retrain(...):
        ...     print("Retraining needed!")
    """

    def __init__(self, triggers: list[RetrainTrigger]):
        """Initialize with trigger configurations.

        Args:
            triggers: List of trigger configurations to evaluate
        """
        self.triggers = [t for t in triggers if t.enabled]

    def evaluate(
        self,
        drift_report: DriftReport | None = None,
        stability_report: StabilityReport | None = None,
        current_metrics: dict[str, float] | None = None,
        last_train_time: datetime | None = None,
    ) -> list[TriggerResult]:
        """Evaluate all triggers and return those that fired.

        Args:
            drift_report: Latest drift detection results
            stability_report: Latest stability metrics
            current_metrics: Current model performance metrics
            last_train_time: When model was last trained

        Returns:
            List of TriggerResult for all evaluated triggers
        """
        results = []

        for trigger in self.triggers:
            result = self._evaluate_single_trigger(
                trigger,
                drift_report,
                stability_report,
                current_metrics,
                last_train_time,
            )
            results.append(result)

        return results

    def _evaluate_single_trigger(
        self,
        trigger: RetrainTrigger,
        drift_report: DriftReport | None,
        stability_report: StabilityReport | None,
        current_metrics: dict[str, float] | None,
        last_train_time: datetime | None,
    ) -> TriggerResult:
        """Evaluate a single trigger.

        Args:
            trigger: Trigger configuration
            drift_report: Drift detection results
            stability_report: Stability metrics
            current_metrics: Current performance metrics
            last_train_time: Last training timestamp

        Returns:
            TriggerResult with evaluation outcome
        """
        if trigger.trigger_type == TriggerType.DRIFT:
            return self._evaluate_drift_trigger(trigger, drift_report)
        elif trigger.trigger_type == TriggerType.SCHEDULE:
            return self._evaluate_schedule_trigger(trigger, last_train_time)
        elif trigger.trigger_type == TriggerType.PERFORMANCE:
            return self._evaluate_performance_trigger(trigger, current_metrics)
        elif trigger.trigger_type == TriggerType.MANUAL:
            return self._evaluate_manual_trigger(trigger)
        else:
            return TriggerResult(
                trigger_type=trigger.trigger_type.value,
                fired=False,
                reason=f"Unknown trigger type: {trigger.trigger_type}",
            )

    def _evaluate_drift_trigger(
        self,
        trigger: RetrainTrigger,
        drift_report: DriftReport | None,
    ) -> TriggerResult:
        """Evaluate drift-based trigger."""
        if drift_report is None:
            return TriggerResult(
                trigger_type=TriggerType.DRIFT.value,
                fired=False,
                reason="No drift report provided",
            )

        threshold = trigger.threshold or 0.2

        if drift_report.overall_drift_score > threshold:
            return TriggerResult(
                trigger_type=TriggerType.DRIFT.value,
                fired=True,
                reason=(
                    f"PSI drift score {drift_report.overall_drift_score:.3f} "
                    f"exceeded threshold {threshold}"
                ),
                metadata={
                    "psi_score": drift_report.overall_drift_score,
                    "threshold": threshold,
                    "drifting_features": drift_report.get_drifting_features(),
                },
            )
        else:
            return TriggerResult(
                trigger_type=TriggerType.DRIFT.value,
                fired=False,
                reason=(
                    f"PSI drift score {drift_report.overall_drift_score:.3f} "
                    f"below threshold {threshold}"
                ),
                metadata={"psi_score": drift_report.overall_drift_score},
            )

    def _evaluate_schedule_trigger(
        self,
        trigger: RetrainTrigger,
        last_train_time: datetime | None,
    ) -> TriggerResult:
        """Evaluate schedule-based trigger."""
        if last_train_time is None:
            return TriggerResult(
                trigger_type=TriggerType.SCHEDULE.value,
                fired=False,
                reason="No last_train_time provided",
            )

        if trigger.schedule_interval is None:
            return TriggerResult(
                trigger_type=TriggerType.SCHEDULE.value,
                fired=False,
                reason="No schedule_interval configured",
            )

        # Parse interval (e.g., "7d", "1h")
        interval = self._parse_interval(trigger.schedule_interval)
        if interval is None:
            return TriggerResult(
                trigger_type=TriggerType.SCHEDULE.value,
                fired=False,
                reason=f"Invalid schedule_interval: {trigger.schedule_interval}",
            )

        # Ensure last_train_time is timezone-aware
        if last_train_time.tzinfo is None:
            last_train_time = last_train_time.replace(tzinfo=UTC)

        next_train_time = last_train_time + interval
        now = datetime.now(UTC)

        if now >= next_train_time:
            return TriggerResult(
                trigger_type=TriggerType.SCHEDULE.value,
                fired=True,
                reason=(
                    f"Schedule interval {trigger.schedule_interval} elapsed "
                    f"since last training at {last_train_time.isoformat()}"
                ),
                metadata={
                    "last_train_time": last_train_time.isoformat(),
                    "next_train_time": next_train_time.isoformat(),
                    "interval": trigger.schedule_interval,
                },
            )
        else:
            return TriggerResult(
                trigger_type=TriggerType.SCHEDULE.value,
                fired=False,
                reason=(
                    f"Schedule interval {trigger.schedule_interval} not yet elapsed"
                ),
                metadata={
                    "next_train_time": next_train_time.isoformat(),
                },
            )

    def _evaluate_performance_trigger(
        self,
        trigger: RetrainTrigger,
        current_metrics: dict[str, float] | None,
    ) -> TriggerResult:
        """Evaluate performance-based trigger."""
        if current_metrics is None:
            return TriggerResult(
                trigger_type=TriggerType.PERFORMANCE.value,
                fired=False,
                reason="No current metrics provided",
            )

        if trigger.metric_name is None:
            return TriggerResult(
                trigger_type=TriggerType.PERFORMANCE.value,
                fired=False,
                reason="No metric_name configured",
            )

        if trigger.metric_name not in current_metrics:
            return TriggerResult(
                trigger_type=TriggerType.PERFORMANCE.value,
                fired=False,
                reason=f"Metric '{trigger.metric_name}' not found in current metrics",
            )

        metric_value = current_metrics[trigger.metric_name]
        threshold = trigger.threshold or 0.1

        # For error metrics like MAPE, higher is worse
        if metric_value > threshold:
            return TriggerResult(
                trigger_type=TriggerType.PERFORMANCE.value,
                fired=True,
                reason=(
                    f"Metric '{trigger.metric_name}' value {metric_value:.4f} "
                    f"exceeded threshold {threshold}"
                ),
                metadata={
                    "metric_name": trigger.metric_name,
                    "metric_value": metric_value,
                    "threshold": threshold,
                },
            )
        else:
            return TriggerResult(
                trigger_type=TriggerType.PERFORMANCE.value,
                fired=False,
                reason=(
                    f"Metric '{trigger.metric_name}' value {metric_value:.4f} "
                    f"within threshold {threshold}"
                ),
                metadata={
                    "metric_name": trigger.metric_name,
                    "metric_value": metric_value,
                },
            )

    def _evaluate_manual_trigger(
        self,
        trigger: RetrainTrigger,
    ) -> TriggerResult:
        """Evaluate manual trigger.

        Manual triggers are always "no-op" unless explicitly set to fire.
        In practice, a manual trigger would be checked via an external signal.
        """
        return TriggerResult(
            trigger_type=TriggerType.MANUAL.value,
            fired=False,
            reason="Manual trigger not activated (use manual retrain API)",
        )

    def _parse_interval(self, interval_str: str) -> timedelta | None:
        """Parse interval string into timedelta.

        Args:
            interval_str: Interval string like "7d", "1h", "30m"

        Returns:
            timedelta or None if invalid
        """
        if not interval_str:
            return None

        try:
            # Extract number and unit
            num = int("".join(filter(str.isdigit, interval_str)))
            unit = "".join(filter(str.isalpha, interval_str)).lower()

            if unit in ("d", "day", "days"):
                return timedelta(days=num)
            elif unit in ("h", "hour", "hours"):
                return timedelta(hours=num)
            elif unit in ("m", "min", "minute", "minutes"):
                return timedelta(minutes=num)
            elif unit in ("w", "week", "weeks"):
                return timedelta(weeks=num)
            else:
                return None
        except (ValueError, TypeError):
            return None

    def should_retrain(
        self,
        drift_report: DriftReport | None = None,
        stability_report: StabilityReport | None = None,
        current_metrics: dict[str, float] | None = None,
        last_train_time: datetime | None = None,
    ) -> bool:
        """Check if any trigger indicates retraining is needed.

        Args:
            drift_report: Drift detection results
            stability_report: Stability metrics
            current_metrics: Current performance metrics
            last_train_time: Last training timestamp

        Returns:
            True if any trigger fired, False otherwise
        """
        results = self.evaluate(
            drift_report,
            stability_report,
            current_metrics,
            last_train_time,
        )
        return any(r.fired for r in results)

    def get_fired_triggers(
        self,
        drift_report: DriftReport | None = None,
        stability_report: StabilityReport | None = None,
        current_metrics: dict[str, float] | None = None,
        last_train_time: datetime | None = None,
    ) -> list[TriggerResult]:
        """Get only the triggers that fired.

        Args:
            drift_report: Drift detection results
            stability_report: Stability metrics
            current_metrics: Current performance metrics
            last_train_time: Last training timestamp

        Returns:
            List of TriggerResult for triggers that fired
        """
        results = self.evaluate(
            drift_report,
            stability_report,
            current_metrics,
            last_train_time,
        )
        return [r for r in results if r.fired]
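
A short end-to-end sketch of TriggerEvaluator, limited to the schedule and performance triggers so it needs no DriftReport object; the "wape" value, thresholds, and timestamps are made up for illustration and the printed TriggerResult fields are assumed to match the constructor calls in the source above.

# Illustrative sketch only; assumes tsagentkit 1.0.2 is installed (Python 3.11+ for datetime.UTC).
from datetime import UTC, datetime, timedelta

from tsagentkit.monitoring.triggers import RetrainTrigger, TriggerEvaluator, TriggerType

evaluator = TriggerEvaluator(
    [
        RetrainTrigger(trigger_type=TriggerType.SCHEDULE, schedule_interval="7d"),
        RetrainTrigger(trigger_type=TriggerType.PERFORMANCE, metric_name="wape", threshold=0.15),
    ]
)

last_train_time = datetime.now(UTC) - timedelta(days=10)  # "7d" interval has elapsed
current_metrics = {"wape": 0.22}                          # above the 0.15 threshold

print(evaluator.should_retrain(current_metrics=current_metrics, last_train_time=last_train_time))
# True: both triggers fire

for result in evaluator.get_fired_triggers(
    current_metrics=current_metrics, last_train_time=last_train_time
):
    print(result.trigger_type, result.reason)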