tsagentkit-1.0.2-py3-none-any.whl

Files changed (72)
  1. tsagentkit/__init__.py +126 -0
  2. tsagentkit/anomaly/__init__.py +130 -0
  3. tsagentkit/backtest/__init__.py +48 -0
  4. tsagentkit/backtest/engine.py +788 -0
  5. tsagentkit/backtest/metrics.py +244 -0
  6. tsagentkit/backtest/report.py +342 -0
  7. tsagentkit/calibration/__init__.py +136 -0
  8. tsagentkit/contracts/__init__.py +133 -0
  9. tsagentkit/contracts/errors.py +275 -0
  10. tsagentkit/contracts/results.py +418 -0
  11. tsagentkit/contracts/schema.py +44 -0
  12. tsagentkit/contracts/task_spec.py +300 -0
  13. tsagentkit/covariates/__init__.py +340 -0
  14. tsagentkit/eval/__init__.py +285 -0
  15. tsagentkit/features/__init__.py +20 -0
  16. tsagentkit/features/covariates.py +328 -0
  17. tsagentkit/features/extra/__init__.py +5 -0
  18. tsagentkit/features/extra/native.py +179 -0
  19. tsagentkit/features/factory.py +187 -0
  20. tsagentkit/features/matrix.py +159 -0
  21. tsagentkit/features/tsfeatures_adapter.py +115 -0
  22. tsagentkit/features/versioning.py +203 -0
  23. tsagentkit/hierarchy/__init__.py +39 -0
  24. tsagentkit/hierarchy/aggregation.py +62 -0
  25. tsagentkit/hierarchy/evaluator.py +400 -0
  26. tsagentkit/hierarchy/reconciliation.py +232 -0
  27. tsagentkit/hierarchy/structure.py +453 -0
  28. tsagentkit/models/__init__.py +182 -0
  29. tsagentkit/models/adapters/__init__.py +83 -0
  30. tsagentkit/models/adapters/base.py +321 -0
  31. tsagentkit/models/adapters/chronos.py +387 -0
  32. tsagentkit/models/adapters/moirai.py +256 -0
  33. tsagentkit/models/adapters/registry.py +171 -0
  34. tsagentkit/models/adapters/timesfm.py +440 -0
  35. tsagentkit/models/baselines.py +207 -0
  36. tsagentkit/models/sktime.py +307 -0
  37. tsagentkit/monitoring/__init__.py +51 -0
  38. tsagentkit/monitoring/alerts.py +302 -0
  39. tsagentkit/monitoring/coverage.py +203 -0
  40. tsagentkit/monitoring/drift.py +330 -0
  41. tsagentkit/monitoring/report.py +214 -0
  42. tsagentkit/monitoring/stability.py +275 -0
  43. tsagentkit/monitoring/triggers.py +423 -0
  44. tsagentkit/qa/__init__.py +347 -0
  45. tsagentkit/router/__init__.py +37 -0
  46. tsagentkit/router/bucketing.py +489 -0
  47. tsagentkit/router/fallback.py +132 -0
  48. tsagentkit/router/plan.py +23 -0
  49. tsagentkit/router/router.py +271 -0
  50. tsagentkit/series/__init__.py +26 -0
  51. tsagentkit/series/alignment.py +206 -0
  52. tsagentkit/series/dataset.py +449 -0
  53. tsagentkit/series/sparsity.py +261 -0
  54. tsagentkit/series/validation.py +393 -0
  55. tsagentkit/serving/__init__.py +39 -0
  56. tsagentkit/serving/orchestration.py +943 -0
  57. tsagentkit/serving/packaging.py +73 -0
  58. tsagentkit/serving/provenance.py +317 -0
  59. tsagentkit/serving/tsfm_cache.py +214 -0
  60. tsagentkit/skill/README.md +135 -0
  61. tsagentkit/skill/__init__.py +8 -0
  62. tsagentkit/skill/recipes.md +429 -0
  63. tsagentkit/skill/tool_map.md +21 -0
  64. tsagentkit/time/__init__.py +134 -0
  65. tsagentkit/utils/__init__.py +20 -0
  66. tsagentkit/utils/quantiles.py +83 -0
  67. tsagentkit/utils/signature.py +47 -0
  68. tsagentkit/utils/temporal.py +41 -0
  69. tsagentkit-1.0.2.dist-info/METADATA +371 -0
  70. tsagentkit-1.0.2.dist-info/RECORD +72 -0
  71. tsagentkit-1.0.2.dist-info/WHEEL +4 -0
  72. tsagentkit-1.0.2.dist-info/licenses/LICENSE +201 -0
tsagentkit/monitoring/stability.py
@@ -0,0 +1,275 @@
+ """Stability monitoring for prediction jitter and quantile coverage."""
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Literal
+
+ import numpy as np
+ import pandas as pd
+
+ from tsagentkit.monitoring.report import CalibrationReport, StabilityReport
+ from tsagentkit.utils import normalize_quantile_columns, quantile_col_name
+
+ if TYPE_CHECKING:
+     pass
+
+
+ class StabilityMonitor:
+     """Monitor prediction stability and calibration.
+
+     This class tracks:
+     - Prediction jitter: Variance in point predictions across runs
+     - Quantile coverage: Whether empirical coverage matches target quantiles
+
+     Example:
+         >>> monitor = StabilityMonitor(jitter_threshold=0.1)
+         >>>
+         >>> # Check jitter across multiple forecast runs
+         >>> jitter = monitor.compute_jitter([forecast1, forecast2, forecast3])
+         >>>
+         >>> # Check quantile calibration
+         >>> coverage = monitor.compute_coverage(actuals, forecasts, [0.1, 0.5, 0.9])
+     """
+
+     def __init__(
+         self,
+         jitter_threshold: float = 0.1,
+         coverage_tolerance: float = 0.05,
+     ):
+         """Initialize stability monitor.
+
+         Args:
+             jitter_threshold: Coefficient of variation threshold for jitter warnings
+             coverage_tolerance: Allowed deviation from target coverage
+         """
+         self.jitter_threshold = jitter_threshold
+         self.coverage_tolerance = coverage_tolerance
+
+     def compute_jitter(
+         self,
+         predictions: list[pd.DataFrame],
+         method: Literal["cv", "mad"] = "cv",
+     ) -> dict[str, float]:
+         """Compute prediction jitter across multiple forecast runs.
+
+         Jitter measures how much point predictions vary across different
+         runs or model versions. High jitter indicates unstable predictions.
+
+         Args:
+             predictions: List of forecast DataFrames from different runs.
+                 Each should have columns [unique_id, ds, yhat]
+             method: "cv" for coefficient of variation, "mad" for median absolute deviation
+
+         Returns:
+             Dict mapping unique_id to jitter metric
+
+         Example:
+             >>> forecasts = [model.predict(data) for model in model_versions]
+             >>> jitter = monitor.compute_jitter(forecasts, method="cv")
+             >>> print(jitter["series_A"])
+             0.05
+         """
+         if not predictions or len(predictions) < 2:
+             return {}
+
+         # Combine all predictions
+         combined = predictions[0][["unique_id", "ds", "yhat"]].copy()
+         combined.rename(columns={"yhat": "yhat_0"}, inplace=True)
+
+         for i, pred in enumerate(predictions[1:], 1):
+             combined = combined.merge(
+                 pred[["unique_id", "ds", "yhat"]].rename(columns={"yhat": f"yhat_{i}"}),
+                 on=["unique_id", "ds"],
+                 how="outer",
+             )
+
+         jitter_metrics = {}
+
+         for uid in combined["unique_id"].unique():
+             series_data = combined[combined["unique_id"] == uid]
+             yhat_cols = [c for c in series_data.columns if c.startswith("yhat_")]
+
+             if len(yhat_cols) < 2:
+                 continue
+
+             values = series_data[yhat_cols].values
+
+             if method == "cv":
+                 # Coefficient of variation (std / mean)
+                 means = np.nanmean(values, axis=1)
+                 stds = np.nanstd(values, axis=1)
+                 # Avoid division by zero
+                 cvs = np.where(
+                     np.abs(means) > 1e-10,
+                     stds / np.abs(means),
+                     0.0
+                 )
+                 jitter = float(np.nanmean(cvs))
+             else:  # mad
+                 # Median absolute deviation
+                 medians = np.nanmedian(values, axis=1, keepdims=True)
+                 mads = np.nanmedian(np.abs(values - medians), axis=1)
+                 jitter = float(np.nanmean(mads))
+
+             jitter_metrics[uid] = jitter
+
+         return jitter_metrics
+
+     def compute_coverage(
+         self,
+         actuals: pd.DataFrame,
+         forecasts: pd.DataFrame,
+         quantiles: list[float],
+     ) -> dict[float, float]:
+         """Compute empirical coverage for each quantile.
+
+         Coverage is the proportion of actual values that fall below
+         the predicted quantile. For a well-calibrated model:
+         - q=0.1 should have ~10% coverage
+         - q=0.5 should have ~50% coverage
+         - q=0.9 should have ~90% coverage
+
+         Args:
+             actuals: DataFrame with actual values [unique_id, ds, y]
+             forecasts: DataFrame with quantile forecasts [unique_id, ds, q_0.1, ...]
+             quantiles: List of quantile levels (e.g., [0.1, 0.5, 0.9])
+
+         Returns:
+             Dict mapping quantile to empirical coverage (0-1)
+
+         Example:
+             >>> coverage = monitor.compute_coverage(
+             ...     actuals=df[["unique_id", "ds", "y"]],
+             ...     forecasts=pred_df,
+             ...     quantiles=[0.1, 0.5, 0.9]
+             ... )
+             >>> print(coverage[0.5])  # Should be ~0.5 for well-calibrated model
+             0.52
+         """
+         forecasts = normalize_quantile_columns(forecasts)
+         # Merge actuals with forecasts
+         merged = actuals.merge(forecasts, on=["unique_id", "ds"], how="inner")
+
+         coverage = {}
+         for q in quantiles:
+             col_name = quantile_col_name(q)
+             if col_name not in merged.columns:
+                 continue
+
+             # Compute coverage: proportion of actuals <= quantile prediction
+             below_quantile = merged["y"] <= merged[col_name]
+             coverage[q] = float(below_quantile.mean())
+
+         return coverage
+
+     def check_calibration(
+         self,
+         actuals: pd.DataFrame,
+         forecasts: pd.DataFrame,
+         quantiles: list[float],
+     ) -> CalibrationReport:
+         """Check if quantiles are well-calibrated.
+
+         Args:
+             actuals: DataFrame with actual values [unique_id, ds, y]
+             forecasts: DataFrame with quantile forecasts
+             quantiles: List of target quantile levels
+
+         Returns:
+             CalibrationReport with coverage metrics and warnings
+         """
+         coverage = self.compute_coverage(actuals, forecasts, quantiles)
+
+         # Compute calibration errors
+         errors = {}
+         for q in quantiles:
+             if q in coverage:
+                 errors[q] = abs(coverage[q] - q)
+             else:
+                 errors[q] = float("inf")
+
+         # Check if all quantiles are well-calibrated
+         well_calibrated = all(
+             err <= self.coverage_tolerance
+             for err in errors.values()
+         )
+
+         return CalibrationReport(
+             target_quantiles=quantiles,
+             empirical_coverage=coverage,
+             calibration_errors=errors,
+             well_calibrated=well_calibrated,
+             tolerance=self.coverage_tolerance,
+         )
+
+     def generate_stability_report(
+         self,
+         predictions: list[pd.DataFrame] | None = None,
+         actuals: pd.DataFrame | None = None,
+         forecasts: pd.DataFrame | None = None,
+         quantiles: list[float] | None = None,
+     ) -> StabilityReport:
+         """Generate a comprehensive stability report.
+
+         Args:
+             predictions: List of forecast DataFrames for jitter calculation
+             actuals: DataFrame with actual values for coverage analysis
+             forecasts: DataFrame with quantile forecasts for coverage analysis
+             quantiles: List of quantile levels for coverage analysis
+
+         Returns:
+             StabilityReport with jitter and coverage metrics
+         """
+         # Compute jitter if predictions provided
+         jitter_metrics = {}
+         if predictions and len(predictions) >= 2:
+             jitter_metrics = self.compute_jitter(predictions)
+
+         overall_jitter = np.mean(list(jitter_metrics.values())) if jitter_metrics else 0.0
+
+         # Identify high jitter series
+         high_jitter_series = [
+             uid for uid, jit in jitter_metrics.items()
+             if jit > self.jitter_threshold
+         ]
+
+         # Compute coverage report if data provided
+         coverage_report = None
+         if actuals is not None and forecasts is not None and quantiles:
+             coverage_report = self.check_calibration(actuals, forecasts, quantiles)
+
+         return StabilityReport(
+             jitter_metrics=jitter_metrics,
+             overall_jitter=float(overall_jitter),
+             jitter_threshold=self.jitter_threshold,
+             high_jitter_series=high_jitter_series,
+             coverage_report=coverage_report,
+         )
+
+
+ def compute_prediction_interval_coverage(
+     actuals: pd.Series,
+     lower_bound: pd.Series,
+     upper_bound: pd.Series,
+ ) -> float:
+     """Compute coverage for a prediction interval.
+
+     Args:
+         actuals: Actual values
+         lower_bound: Lower bound of prediction interval
+         upper_bound: Upper bound of prediction interval
+
+     Returns:
+         Proportion of actuals within the interval (0-1)
+
+     Example:
+         >>> coverage = compute_prediction_interval_coverage(
+         ...     actuals=df["y"],
+         ...     lower_bound=forecasts["q_0.1"],
+         ...     upper_bound=forecasts["q_0.9"],
+         ... )
+         >>> print(coverage)  # Should be ~0.8 for 80% PI
+         0.82
+     """
+     within_interval = (actuals >= lower_bound) & (actuals <= upper_bound)
+     return float(within_interval.mean())
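A minimal usage sketch for the module above (not part of the package diff): it drives StabilityMonitor and compute_prediction_interval_coverage with synthetic data, assuming forecast frames with [unique_id, ds, yhat] and quantile columns named q_0.1/q_0.5/q_0.9 as in the docstring examples, and assuming StabilityReport exposes the fields it is constructed with.

import numpy as np
import pandas as pd

from tsagentkit.monitoring.stability import (
    StabilityMonitor,
    compute_prediction_interval_coverage,
)

rng = np.random.default_rng(0)
ds = pd.date_range("2024-01-01", periods=30, freq="D")
actuals = pd.DataFrame({"unique_id": "series_A", "ds": ds, "y": rng.normal(100.0, 5.0, 30)})

def make_run(noise: float) -> pd.DataFrame:
    # One synthetic forecast run: a noisy point forecast plus fixed-width quantile bands.
    yhat = actuals["y"].to_numpy() + rng.normal(0.0, noise, 30)
    return pd.DataFrame({
        "unique_id": "series_A",
        "ds": ds,
        "yhat": yhat,
        "q_0.1": yhat - 8.0,
        "q_0.5": yhat,
        "q_0.9": yhat + 8.0,
    })

runs = [make_run(noise) for noise in (1.0, 1.5, 2.0)]

monitor = StabilityMonitor(jitter_threshold=0.1, coverage_tolerance=0.05)
print(monitor.compute_jitter(runs, method="cv"))                   # per-series CV across runs
print(monitor.compute_coverage(actuals, runs[0], [0.1, 0.5, 0.9])) # empirical coverage per quantile
report = monitor.generate_stability_report(
    predictions=runs, actuals=actuals, forecasts=runs[0], quantiles=[0.1, 0.5, 0.9]
)
print(report)
print(compute_prediction_interval_coverage(actuals["y"], runs[0]["q_0.1"], runs[0]["q_0.9"]))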
tsagentkit/monitoring/triggers.py
@@ -0,0 +1,423 @@
+ """Retrain triggers for model monitoring."""
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from datetime import UTC, datetime, timedelta
+ from enum import Enum
+ from typing import TYPE_CHECKING
+
+ from tsagentkit.monitoring.report import TriggerResult
+
+ if TYPE_CHECKING:
+     from tsagentkit.monitoring.report import DriftReport, StabilityReport
+
+
+ class TriggerType(Enum):
+     """Types of retrain triggers.
+
+     - DRIFT: Data drift detected (PSI/KS above threshold)
+     - SCHEDULE: Time-based trigger (e.g., daily, weekly)
+     - PERFORMANCE: Metric degradation (accuracy drop)
+     - MANUAL: Explicit manual trigger
+     """
+
+     DRIFT = "drift"
+     SCHEDULE = "schedule"
+     PERFORMANCE = "performance"
+     MANUAL = "manual"
+
+
+ @dataclass
+ class RetrainTrigger:
+     """Configuration for a retrain trigger.
+
+     Attributes:
+         trigger_type: Type of trigger
+         threshold: Threshold value for trigger activation
+         metric_name: For PERFORMANCE triggers, the metric to monitor
+         schedule_interval: For SCHEDULE triggers, the interval (e.g., "1d", "7d")
+         enabled: Whether this trigger is active
+
+     Example:
+         >>> drift_trigger = RetrainTrigger(
+         ...     trigger_type=TriggerType.DRIFT,
+         ...     threshold=0.2,
+         ... )
+         >>> schedule_trigger = RetrainTrigger(
+         ...     trigger_type=TriggerType.SCHEDULE,
+         ...     schedule_interval="7d",
+         ... )
+         >>> perf_trigger = RetrainTrigger(
+         ...     trigger_type=TriggerType.PERFORMANCE,
+         ...     metric_name="wape",
+         ...     threshold=0.15,
+         ... )
+     """
+
+     trigger_type: TriggerType
+     threshold: float | None = None
+     metric_name: str | None = None
+     schedule_interval: str | None = None
+     enabled: bool = True
+
+     def __post_init__(self) -> None:
+         """Validate trigger configuration."""
+         if self.trigger_type == TriggerType.DRIFT and self.threshold is None:
+             self.threshold = 0.2  # Default PSI threshold
+         elif self.trigger_type == TriggerType.PERFORMANCE and self.threshold is None:
+             self.threshold = 0.1  # Default 10% degradation
+
+
+ class TriggerEvaluator:
+     """Evaluate retrain triggers based on monitoring data.
+
+     Example:
+         >>> triggers = [
+         ...     RetrainTrigger(TriggerType.DRIFT, threshold=0.2),
+         ...     RetrainTrigger(TriggerType.SCHEDULE, schedule_interval="7d"),
+         ... ]
+         >>> evaluator = TriggerEvaluator(triggers)
+         >>>
+         >>> # Evaluate with monitoring data
+         >>> results = evaluator.evaluate(
+         ...     drift_report=drift_report,
+         ...     last_train_time=datetime(2024, 1, 1),
+         ... )
+         >>>
+         >>> if evaluator.should_retrain(...):
+         ...     print("Retraining needed!")
+     """
+
+     def __init__(self, triggers: list[RetrainTrigger]):
+         """Initialize with trigger configurations.
+
+         Args:
+             triggers: List of trigger configurations to evaluate
+         """
+         self.triggers = [t for t in triggers if t.enabled]
+
+     def evaluate(
+         self,
+         drift_report: DriftReport | None = None,
+         stability_report: StabilityReport | None = None,
+         current_metrics: dict[str, float] | None = None,
+         last_train_time: datetime | None = None,
+     ) -> list[TriggerResult]:
+         """Evaluate all triggers and return a result for each.
+
+         Args:
+             drift_report: Latest drift detection results
+             stability_report: Latest stability metrics
+             current_metrics: Current model performance metrics
+             last_train_time: When model was last trained
+
+         Returns:
+             List of TriggerResult for all evaluated triggers
+         """
+         results = []
+
+         for trigger in self.triggers:
+             result = self._evaluate_single_trigger(
+                 trigger,
+                 drift_report,
+                 stability_report,
+                 current_metrics,
+                 last_train_time,
+             )
+             results.append(result)
+
+         return results
+
+     def _evaluate_single_trigger(
+         self,
+         trigger: RetrainTrigger,
+         drift_report: DriftReport | None,
+         stability_report: StabilityReport | None,
+         current_metrics: dict[str, float] | None,
+         last_train_time: datetime | None,
+     ) -> TriggerResult:
+         """Evaluate a single trigger.
+
+         Args:
+             trigger: Trigger configuration
+             drift_report: Drift detection results
+             stability_report: Stability metrics
+             current_metrics: Current performance metrics
+             last_train_time: Last training timestamp
+
+         Returns:
+             TriggerResult with evaluation outcome
+         """
+         if trigger.trigger_type == TriggerType.DRIFT:
+             return self._evaluate_drift_trigger(trigger, drift_report)
+         elif trigger.trigger_type == TriggerType.SCHEDULE:
+             return self._evaluate_schedule_trigger(trigger, last_train_time)
+         elif trigger.trigger_type == TriggerType.PERFORMANCE:
+             return self._evaluate_performance_trigger(trigger, current_metrics)
+         elif trigger.trigger_type == TriggerType.MANUAL:
+             return self._evaluate_manual_trigger(trigger)
+         else:
+             return TriggerResult(
+                 trigger_type=trigger.trigger_type.value,
+                 fired=False,
+                 reason=f"Unknown trigger type: {trigger.trigger_type}",
+             )
+
+     def _evaluate_drift_trigger(
+         self,
+         trigger: RetrainTrigger,
+         drift_report: DriftReport | None,
+     ) -> TriggerResult:
+         """Evaluate drift-based trigger."""
+         if drift_report is None:
+             return TriggerResult(
+                 trigger_type=TriggerType.DRIFT.value,
+                 fired=False,
+                 reason="No drift report provided",
+             )
+
+         threshold = trigger.threshold or 0.2
+
+         if drift_report.overall_drift_score > threshold:
+             return TriggerResult(
+                 trigger_type=TriggerType.DRIFT.value,
+                 fired=True,
+                 reason=(
+                     f"PSI drift score {drift_report.overall_drift_score:.3f} "
+                     f"exceeded threshold {threshold}"
+                 ),
+                 metadata={
+                     "psi_score": drift_report.overall_drift_score,
+                     "threshold": threshold,
+                     "drifting_features": drift_report.get_drifting_features(),
+                 },
+             )
+         else:
+             return TriggerResult(
+                 trigger_type=TriggerType.DRIFT.value,
+                 fired=False,
+                 reason=(
+                     f"PSI drift score {drift_report.overall_drift_score:.3f} "
+                     f"below threshold {threshold}"
+                 ),
+                 metadata={"psi_score": drift_report.overall_drift_score},
+             )
+
+     def _evaluate_schedule_trigger(
+         self,
+         trigger: RetrainTrigger,
+         last_train_time: datetime | None,
+     ) -> TriggerResult:
+         """Evaluate schedule-based trigger."""
+         if last_train_time is None:
+             return TriggerResult(
+                 trigger_type=TriggerType.SCHEDULE.value,
+                 fired=False,
+                 reason="No last_train_time provided",
+             )
+
+         if trigger.schedule_interval is None:
+             return TriggerResult(
+                 trigger_type=TriggerType.SCHEDULE.value,
+                 fired=False,
+                 reason="No schedule_interval configured",
+             )
+
+         # Parse interval (e.g., "7d", "1h")
+         interval = self._parse_interval(trigger.schedule_interval)
+         if interval is None:
+             return TriggerResult(
+                 trigger_type=TriggerType.SCHEDULE.value,
+                 fired=False,
+                 reason=f"Invalid schedule_interval: {trigger.schedule_interval}",
+             )
+
+         # Ensure last_train_time is timezone-aware
+         if last_train_time.tzinfo is None:
+             last_train_time = last_train_time.replace(tzinfo=UTC)
+
+         next_train_time = last_train_time + interval
+         now = datetime.now(UTC)
+
+         if now >= next_train_time:
+             return TriggerResult(
+                 trigger_type=TriggerType.SCHEDULE.value,
+                 fired=True,
+                 reason=(
+                     f"Schedule interval {trigger.schedule_interval} elapsed "
+                     f"since last training at {last_train_time.isoformat()}"
+                 ),
+                 metadata={
+                     "last_train_time": last_train_time.isoformat(),
+                     "next_train_time": next_train_time.isoformat(),
+                     "interval": trigger.schedule_interval,
+                 },
+             )
+         else:
+             return TriggerResult(
+                 trigger_type=TriggerType.SCHEDULE.value,
+                 fired=False,
+                 reason=(
+                     f"Schedule interval {trigger.schedule_interval} not yet elapsed"
+                 ),
+                 metadata={
+                     "next_train_time": next_train_time.isoformat(),
+                 },
+             )
+
+     def _evaluate_performance_trigger(
+         self,
+         trigger: RetrainTrigger,
+         current_metrics: dict[str, float] | None,
+     ) -> TriggerResult:
+         """Evaluate performance-based trigger."""
+         if current_metrics is None:
+             return TriggerResult(
+                 trigger_type=TriggerType.PERFORMANCE.value,
+                 fired=False,
+                 reason="No current metrics provided",
+             )
+
+         if trigger.metric_name is None:
+             return TriggerResult(
+                 trigger_type=TriggerType.PERFORMANCE.value,
+                 fired=False,
+                 reason="No metric_name configured",
+             )
+
+         if trigger.metric_name not in current_metrics:
+             return TriggerResult(
+                 trigger_type=TriggerType.PERFORMANCE.value,
+                 fired=False,
+                 reason=f"Metric '{trigger.metric_name}' not found in current metrics",
+             )
+
+         metric_value = current_metrics[trigger.metric_name]
+         threshold = trigger.threshold or 0.1
+
+         # For error metrics like MAPE, higher is worse
+         if metric_value > threshold:
+             return TriggerResult(
+                 trigger_type=TriggerType.PERFORMANCE.value,
+                 fired=True,
+                 reason=(
+                     f"Metric '{trigger.metric_name}' value {metric_value:.4f} "
+                     f"exceeded threshold {threshold}"
+                 ),
+                 metadata={
+                     "metric_name": trigger.metric_name,
+                     "metric_value": metric_value,
+                     "threshold": threshold,
+                 },
+             )
+         else:
+             return TriggerResult(
+                 trigger_type=TriggerType.PERFORMANCE.value,
+                 fired=False,
+                 reason=(
+                     f"Metric '{trigger.metric_name}' value {metric_value:.4f} "
+                     f"within threshold {threshold}"
+                 ),
+                 metadata={
+                     "metric_name": trigger.metric_name,
+                     "metric_value": metric_value,
+                 },
+             )
+
+     def _evaluate_manual_trigger(
+         self,
+         trigger: RetrainTrigger,
+     ) -> TriggerResult:
+         """Evaluate manual trigger.
+
+         Manual triggers are always "no-op" unless explicitly set to fire.
+         In practice, a manual trigger would be checked via an external signal.
+         """
+         return TriggerResult(
+             trigger_type=TriggerType.MANUAL.value,
+             fired=False,
+             reason="Manual trigger not activated (use manual retrain API)",
+         )
+
+     def _parse_interval(self, interval_str: str) -> timedelta | None:
+         """Parse interval string into timedelta.
+
+         Args:
+             interval_str: Interval string like "7d", "1h", "30m"
+
+         Returns:
+             timedelta or None if invalid
+         """
+         if not interval_str:
+             return None
+
+         try:
+             # Extract number and unit
+             num = int("".join(filter(str.isdigit, interval_str)))
+             unit = "".join(filter(str.isalpha, interval_str)).lower()
+
+             if unit in ("d", "day", "days"):
+                 return timedelta(days=num)
+             elif unit in ("h", "hour", "hours"):
+                 return timedelta(hours=num)
+             elif unit in ("m", "min", "minute", "minutes"):
+                 return timedelta(minutes=num)
+             elif unit in ("w", "week", "weeks"):
+                 return timedelta(weeks=num)
+             else:
+                 return None
+         except (ValueError, TypeError):
+             return None
+
+     def should_retrain(
+         self,
+         drift_report: DriftReport | None = None,
+         stability_report: StabilityReport | None = None,
+         current_metrics: dict[str, float] | None = None,
+         last_train_time: datetime | None = None,
+     ) -> bool:
+         """Check if any trigger indicates retraining is needed.
+
+         Args:
+             drift_report: Drift detection results
+             stability_report: Stability metrics
+             current_metrics: Current performance metrics
+             last_train_time: Last training timestamp
+
+         Returns:
+             True if any trigger fired, False otherwise
+         """
+         results = self.evaluate(
+             drift_report,
+             stability_report,
+             current_metrics,
+             last_train_time,
+         )
+         return any(r.fired for r in results)
+
+     def get_fired_triggers(
+         self,
+         drift_report: DriftReport | None = None,
+         stability_report: StabilityReport | None = None,
+         current_metrics: dict[str, float] | None = None,
+         last_train_time: datetime | None = None,
+     ) -> list[TriggerResult]:
+         """Get only the triggers that fired.
+
+         Args:
+             drift_report: Drift detection results
+             stability_report: Stability metrics
+             current_metrics: Current performance metrics
+             last_train_time: Last training timestamp
+
+         Returns:
+             List of TriggerResult for triggers that fired
+         """
+         results = self.evaluate(
+             drift_report,
+             stability_report,
+             current_metrics,
+             last_train_time,
+         )
+         return [r for r in results if r.fired]
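A matching usage sketch for the trigger module (again, not part of the package diff): it exercises only the schedule and performance paths so that no DriftReport has to be constructed, and it assumes TriggerResult exposes the trigger_type, fired, and reason fields it is constructed with.

from datetime import UTC, datetime, timedelta

from tsagentkit.monitoring.triggers import RetrainTrigger, TriggerEvaluator, TriggerType

triggers = [
    RetrainTrigger(trigger_type=TriggerType.SCHEDULE, schedule_interval="7d"),
    RetrainTrigger(trigger_type=TriggerType.PERFORMANCE, metric_name="wape", threshold=0.15),
]
evaluator = TriggerEvaluator(triggers)

# Ten days since the last fit: the 7d schedule trigger should fire.
last_train = datetime.now(UTC) - timedelta(days=10)
# WAPE below the 0.15 threshold: the performance trigger should stay quiet.
metrics = {"wape": 0.12}

if evaluator.should_retrain(current_metrics=metrics, last_train_time=last_train):
    for result in evaluator.get_fired_triggers(
        current_metrics=metrics, last_train_time=last_train
    ):
        print(result.trigger_type, result.fired, result.reason)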