tsagentkit-1.0.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tsagentkit/__init__.py +126 -0
- tsagentkit/anomaly/__init__.py +130 -0
- tsagentkit/backtest/__init__.py +48 -0
- tsagentkit/backtest/engine.py +788 -0
- tsagentkit/backtest/metrics.py +244 -0
- tsagentkit/backtest/report.py +342 -0
- tsagentkit/calibration/__init__.py +136 -0
- tsagentkit/contracts/__init__.py +133 -0
- tsagentkit/contracts/errors.py +275 -0
- tsagentkit/contracts/results.py +418 -0
- tsagentkit/contracts/schema.py +44 -0
- tsagentkit/contracts/task_spec.py +300 -0
- tsagentkit/covariates/__init__.py +340 -0
- tsagentkit/eval/__init__.py +285 -0
- tsagentkit/features/__init__.py +20 -0
- tsagentkit/features/covariates.py +328 -0
- tsagentkit/features/extra/__init__.py +5 -0
- tsagentkit/features/extra/native.py +179 -0
- tsagentkit/features/factory.py +187 -0
- tsagentkit/features/matrix.py +159 -0
- tsagentkit/features/tsfeatures_adapter.py +115 -0
- tsagentkit/features/versioning.py +203 -0
- tsagentkit/hierarchy/__init__.py +39 -0
- tsagentkit/hierarchy/aggregation.py +62 -0
- tsagentkit/hierarchy/evaluator.py +400 -0
- tsagentkit/hierarchy/reconciliation.py +232 -0
- tsagentkit/hierarchy/structure.py +453 -0
- tsagentkit/models/__init__.py +182 -0
- tsagentkit/models/adapters/__init__.py +83 -0
- tsagentkit/models/adapters/base.py +321 -0
- tsagentkit/models/adapters/chronos.py +387 -0
- tsagentkit/models/adapters/moirai.py +256 -0
- tsagentkit/models/adapters/registry.py +171 -0
- tsagentkit/models/adapters/timesfm.py +440 -0
- tsagentkit/models/baselines.py +207 -0
- tsagentkit/models/sktime.py +307 -0
- tsagentkit/monitoring/__init__.py +51 -0
- tsagentkit/monitoring/alerts.py +302 -0
- tsagentkit/monitoring/coverage.py +203 -0
- tsagentkit/monitoring/drift.py +330 -0
- tsagentkit/monitoring/report.py +214 -0
- tsagentkit/monitoring/stability.py +275 -0
- tsagentkit/monitoring/triggers.py +423 -0
- tsagentkit/qa/__init__.py +347 -0
- tsagentkit/router/__init__.py +37 -0
- tsagentkit/router/bucketing.py +489 -0
- tsagentkit/router/fallback.py +132 -0
- tsagentkit/router/plan.py +23 -0
- tsagentkit/router/router.py +271 -0
- tsagentkit/series/__init__.py +26 -0
- tsagentkit/series/alignment.py +206 -0
- tsagentkit/series/dataset.py +449 -0
- tsagentkit/series/sparsity.py +261 -0
- tsagentkit/series/validation.py +393 -0
- tsagentkit/serving/__init__.py +39 -0
- tsagentkit/serving/orchestration.py +943 -0
- tsagentkit/serving/packaging.py +73 -0
- tsagentkit/serving/provenance.py +317 -0
- tsagentkit/serving/tsfm_cache.py +214 -0
- tsagentkit/skill/README.md +135 -0
- tsagentkit/skill/__init__.py +8 -0
- tsagentkit/skill/recipes.md +429 -0
- tsagentkit/skill/tool_map.md +21 -0
- tsagentkit/time/__init__.py +134 -0
- tsagentkit/utils/__init__.py +20 -0
- tsagentkit/utils/quantiles.py +83 -0
- tsagentkit/utils/signature.py +47 -0
- tsagentkit/utils/temporal.py +41 -0
- tsagentkit-1.0.2.dist-info/METADATA +371 -0
- tsagentkit-1.0.2.dist-info/RECORD +72 -0
- tsagentkit-1.0.2.dist-info/WHEEL +4 -0
- tsagentkit-1.0.2.dist-info/licenses/LICENSE +201 -0
tsagentkit/monitoring/drift.py
@@ -0,0 +1,330 @@
"""Drift detection using PSI and KS tests."""

from __future__ import annotations

from typing import TYPE_CHECKING, Literal

import numpy as np
import pandas as pd

from tsagentkit.monitoring.report import DriftReport, FeatureDriftResult

if TYPE_CHECKING:
    pass


class DriftDetector:
    """Detect data drift between reference and current distributions.

    Supports two methods:
    - PSI (Population Stability Index): industry standard for distribution drift
    - KS (Kolmogorov-Smirnov): statistical test for distribution differences

    PSI interpretation:
    - < 0.1: no significant change
    - 0.1 - 0.2: moderate change
    - > 0.2: significant change (drift detected)

    KS interpretation:
    - p-value < 0.05: statistically significant difference (drift detected)

    Example:
        >>> detector = DriftDetector(method="psi", threshold=0.2)
        >>> report = detector.detect(reference_data, current_data, features=["sales"])
        >>> if report.drift_detected:
        ...     print(f"Drift detected in features: {report.get_drifting_features()}")
    """

    def __init__(
        self,
        method: Literal["psi", "ks"] = "psi",
        threshold: float | None = None,
        n_bins: int = 10,
    ):
        """Initialize drift detector.

        Args:
            method: Drift detection method ("psi" or "ks")
            threshold: Threshold for drift detection.
                Default: 0.2 for PSI, 0.05 for KS (p-value)
            n_bins: Number of bins for PSI calculation

        Raises:
            ValueError: If an invalid method is specified
        """
        if method not in ("psi", "ks"):
            raise ValueError(f"Method must be 'psi' or 'ks', got {method}")

        self.method = method
        self.n_bins = n_bins

        # Set default thresholds
        if threshold is None:
            self.threshold = 0.2 if method == "psi" else 0.05
        else:
            self.threshold = threshold

    def detect(
        self,
        reference_data: pd.DataFrame,
        current_data: pd.DataFrame,
        features: list[str] | None = None,
    ) -> DriftReport:
        """Detect drift between reference and current datasets.

        Args:
            reference_data: Baseline/reference distribution (training data)
            current_data: Current data to compare (recent observations)
            features: List of features to check (defaults to numeric columns)

        Returns:
            DriftReport with per-feature and overall results

        Example:
            >>> detector = DriftDetector(method="psi")
            >>> report = detector.detect(
            ...     reference_data=train_df,
            ...     current_data=recent_df,
            ...     features=["sales", "price"]
            ... )
            >>> print(report.overall_drift_score)
            0.15
        """
        # Auto-select numeric features if not specified
        if features is None:
            features = reference_data.select_dtypes(
                include=[np.number]
            ).columns.tolist()
            # Exclude common non-feature columns
            exclude = {"unique_id", "ds", "y", "timestamp"}
            features = [f for f in features if f not in exclude]

        feature_drifts: dict[str, FeatureDriftResult] = {}
        drift_scores = []

        for feature in features:
            if feature not in reference_data.columns:
                continue
            if feature not in current_data.columns:
                continue

            result = self._analyze_feature(
                reference_data[feature],
                current_data[feature],
                feature,
            )
            feature_drifts[feature] = result
            drift_scores.append(result.statistic)

        # Calculate overall drift score (mean of individual statistics)
        overall_drift = np.mean(drift_scores) if drift_scores else 0.0

        # Determine if drift detected based on method
        if self.method == "psi":
            drift_detected = overall_drift > self.threshold
        else:  # ks
            # For KS, drift is detected if any feature has p-value < threshold;
            # the overall score remains the mean KS statistic computed above
            p_values = [r.p_value for r in feature_drifts.values() if r.p_value is not None]
            drift_detected = any(p < self.threshold for p in p_values) if p_values else False

        return DriftReport(
            drift_detected=drift_detected,
            feature_drifts=feature_drifts,
            overall_drift_score=float(overall_drift),
            threshold_used=self.threshold,
        )

    def _analyze_feature(
        self,
        reference: pd.Series,
        current: pd.Series,
        feature_name: str,
    ) -> FeatureDriftResult:
        """Analyze drift for a single feature.

        Args:
            reference: Reference distribution
            current: Current distribution
            feature_name: Name of the feature

        Returns:
            FeatureDriftResult with drift statistics
        """
        # Remove NaN values
        ref_values = reference.dropna().values
        cur_values = current.dropna().values

        if self.method == "psi":
            statistic = self._compute_psi(ref_values, cur_values, self.n_bins)
            p_value = None
            drift_detected = statistic > self.threshold
        else:  # ks
            statistic, p_value = self._compute_ks_test(ref_values, cur_values)
            # For KS, drift detected if p-value < threshold
            drift_detected = p_value < self.threshold

        # Compute distribution summaries
        ref_dist = self._summarize_distribution(ref_values)
        cur_dist = self._summarize_distribution(cur_values)

        return FeatureDriftResult(
            feature_name=feature_name,
            metric=self.method,
            statistic=float(statistic),
            p_value=float(p_value) if p_value is not None else None,
            drift_detected=drift_detected,
            reference_distribution=ref_dist,
            current_distribution=cur_dist,
        )

    def _compute_psi(
        self,
        reference: np.ndarray,
        current: np.ndarray,
        n_bins: int = 10,
    ) -> float:
        """Compute Population Stability Index.

        PSI = sum((Actual% - Expected%) * ln(Actual% / Expected%))

        PSI interpretation:
        - < 0.1: no significant change
        - 0.1 - 0.2: moderate change
        - > 0.2: significant change

        Args:
            reference: Reference distribution values
            current: Current distribution values
            n_bins: Number of bins for discretization

        Returns:
            PSI value (float)
        """
        if len(reference) == 0 or len(current) == 0:
            return 0.0

        # Create bins based on reference distribution
        min_val, max_val = reference.min(), reference.max()

        # Handle constant reference
        if min_val == max_val:
            return 0.0 if current.min() == current.max() == min_val else 1.0

        # Create bins
        bins = np.linspace(min_val, max_val, n_bins + 1)
        bins[-1] += 1e-10  # Ensure max value is included

        # Compute histograms
        ref_hist, _ = np.histogram(reference, bins=bins)
        cur_hist, _ = np.histogram(current, bins=bins)

        # Convert to probabilities
        ref_pct = ref_hist / len(reference) + 1e-10  # Add epsilon to avoid log(0)
        cur_pct = cur_hist / len(current) + 1e-10

        # Compute PSI
        psi_values = (cur_pct - ref_pct) * np.log(cur_pct / ref_pct)
        psi = np.sum(psi_values)

        return float(psi)

    def _compute_ks_test(
        self,
        reference: np.ndarray,
        current: np.ndarray,
    ) -> tuple[float, float]:
        """Compute Kolmogorov-Smirnov test.

        Args:
            reference: Reference distribution values
            current: Current distribution values

        Returns:
            Tuple of (statistic, p_value)
        """
        from scipy import stats

        if len(reference) == 0 or len(current) == 0:
            return 0.0, 1.0

        statistic, p_value = stats.ks_2samp(reference, current)
        return float(statistic), float(p_value)

    def _summarize_distribution(self, values: np.ndarray) -> dict:
        """Create a summary of a distribution.

        Args:
            values: Array of values

        Returns:
            Dictionary with distribution statistics
        """
        if len(values) == 0:
            return {"count": 0}

        return {
            "count": int(len(values)),
            "mean": float(np.mean(values)),
            "std": float(np.std(values)),
            "min": float(np.min(values)),
            "max": float(np.max(values)),
            "median": float(np.median(values)),
        }


def compute_psi_summary(
    reference: pd.Series | np.ndarray,
    current: pd.Series | np.ndarray,
    n_bins: int = 10,
) -> dict:
    """Compute detailed PSI breakdown by bin.

    Args:
        reference: Reference distribution
        current: Current distribution
        n_bins: Number of bins

    Returns:
        Dictionary with PSI breakdown
    """
    ref_values = np.asarray(reference.dropna() if hasattr(reference, "dropna") else reference)
    cur_values = np.asarray(current.dropna() if hasattr(current, "dropna") else current)

    if len(ref_values) == 0 or len(cur_values) == 0:
        return {"psi": 0.0, "bins": []}

    # Create bins
    min_val, max_val = ref_values.min(), ref_values.max()
    if min_val == max_val:
        return {"psi": 0.0, "bins": [], "message": "Constant reference distribution"}

    bins = np.linspace(min_val, max_val, n_bins + 1)
    bins[-1] += 1e-10

    # Compute histograms
    ref_hist, bin_edges = np.histogram(ref_values, bins=bins)
    cur_hist, _ = np.histogram(cur_values, bins=bins)

    ref_pct = ref_hist / len(ref_values) + 1e-10
    cur_pct = cur_hist / len(cur_values) + 1e-10

    # Compute per-bin PSI
    bin_psi = (cur_pct - ref_pct) * np.log(cur_pct / ref_pct)

    bins_data = []
    for i in range(n_bins):
        bins_data.append({
            "bin_start": float(bin_edges[i]),
            "bin_end": float(bin_edges[i + 1]),
            "reference_count": int(ref_hist[i]),
            "reference_pct": float(ref_pct[i] - 1e-10),
            "current_count": int(cur_hist[i]),
            "current_pct": float(cur_pct[i] - 1e-10),
            "psi": float(bin_psi[i]),
        })

    return {
        "psi": float(np.sum(bin_psi)),
        "bins": bins_data,
    }
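
To make the detection flow above concrete, here is a minimal, self-contained sketch that exercises `DriftDetector` on synthetic data. It assumes `tsagentkit==1.0.2` is installed and that the public names match this diff; the column names and the deliberate mean shift in `price` are illustrative, not part of the package.

```python
# Hedged sketch: exercises DriftDetector as published in this diff.
# Assumes `pip install tsagentkit==1.0.2` with numpy/pandas/scipy available.
import numpy as np
import pandas as pd

from tsagentkit.monitoring.drift import DriftDetector, compute_psi_summary

rng = np.random.default_rng(42)

# Reference window (e.g., training data) and a current window with a
# deliberate +2-sigma mean shift in `price` but not in `sales`.
reference = pd.DataFrame({
    "sales": rng.normal(100.0, 15.0, size=2_000),
    "price": rng.normal(10.0, 1.0, size=2_000),
})
current = pd.DataFrame({
    "sales": rng.normal(100.0, 15.0, size=500),
    "price": rng.normal(12.0, 1.0, size=500),
})

detector = DriftDetector(method="psi", threshold=0.2)
report = detector.detect(reference, current, features=["sales", "price"])

print(report.summary())                # e.g. "Drift detected in 1/2 features. ..."
print(report.get_drifting_features())  # expected: ["price"]

# Per-bin breakdown for the drifting feature, useful for seeing *where*
# in the distribution the mass moved.
breakdown = compute_psi_summary(reference["price"], current["price"], n_bins=10)
print(f"total PSI: {breakdown['psi']:.3f}")
for b in breakdown["bins"][:3]:
    print(f"  [{b['bin_start']:.2f}, {b['bin_end']:.2f}) psi={b['psi']:.4f}")
```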
tsagentkit/monitoring/report.py
@@ -0,0 +1,214 @@
"""Report dataclasses for monitoring results."""

from __future__ import annotations

from dataclasses import dataclass, field
from datetime import UTC, datetime


@dataclass(frozen=True)
class FeatureDriftResult:
    """Drift result for a single feature.

    Attributes:
        feature_name: Name of the feature analyzed
        metric: Drift metric used ("psi" or "ks")
        statistic: Drift statistic value
        p_value: P-value for statistical test (KS only)
        drift_detected: Whether drift was detected for this feature
        reference_distribution: Summary of reference distribution
        current_distribution: Summary of current distribution

    Example:
        >>> result = FeatureDriftResult(
        ...     feature_name="sales",
        ...     metric="psi",
        ...     statistic=0.25,
        ...     p_value=None,
        ...     drift_detected=True,
        ...     reference_distribution={"mean": 100.0, "std": 15.0},
        ...     current_distribution={"mean": 120.0, "std": 20.0},
        ... )
        >>> print(result)
        FeatureDriftResult(sales, psi=0.250, drift=True)
    """

    feature_name: str
    metric: str
    statistic: float
    p_value: float | None
    drift_detected: bool
    reference_distribution: dict
    current_distribution: dict

    def __repr__(self) -> str:
        return (
            f"FeatureDriftResult({self.feature_name}, "
            f"{self.metric}={self.statistic:.3f}, "
            f"drift={self.drift_detected})"
        )


@dataclass(frozen=True)
class DriftReport:
    """Report from drift detection analysis.

    Attributes:
        drift_detected: Whether any drift was detected
        feature_drifts: Dict mapping feature names to drift results
        overall_drift_score: Aggregated drift score across all features
        threshold_used: Threshold used for drift detection
        reference_timestamp: Timestamp of reference data
        current_timestamp: Timestamp of current data

    Example:
        >>> report = DriftReport(
        ...     drift_detected=True,
        ...     feature_drifts={"sales": feature_result},
        ...     overall_drift_score=0.25,
        ...     threshold_used=0.2,
        ... )
        >>> print(report.summary())
        Drift detected in 1/1 features. Overall score: 0.250 (threshold: 0.2)
    """

    drift_detected: bool
    feature_drifts: dict[str, FeatureDriftResult]
    overall_drift_score: float
    threshold_used: float
    reference_timestamp: str = field(
        default_factory=lambda: datetime.now(UTC).isoformat()
    )
    current_timestamp: str = field(
        default_factory=lambda: datetime.now(UTC).isoformat()
    )

    def summary(self) -> str:
        """Generate a human-readable summary of the drift report."""
        n_drifting = sum(1 for r in self.feature_drifts.values() if r.drift_detected)
        n_total = len(self.feature_drifts)
        return (
            f"Drift detected in {n_drifting}/{n_total} features. "
            f"Overall score: {self.overall_drift_score:.3f} "
            f"(threshold: {self.threshold_used})"
        )

    def get_drifting_features(self) -> list[str]:
        """Return list of feature names with detected drift."""
        return [
            name for name, result in self.feature_drifts.items()
            if result.drift_detected
        ]


@dataclass(frozen=True)
class CalibrationReport:
    """Report on quantile calibration.

    Attributes:
        target_quantiles: List of target quantile levels
        empirical_coverage: Dict mapping quantile to empirical coverage
        calibration_errors: Dict mapping quantile to calibration error
        well_calibrated: Whether all quantiles are well-calibrated
        tolerance: Tolerance used for calibration check

    Example:
        >>> report = CalibrationReport(
        ...     target_quantiles=[0.1, 0.5, 0.9],
        ...     empirical_coverage={0.1: 0.08, 0.5: 0.52, 0.9: 0.91},
        ...     calibration_errors={0.1: 0.02, 0.5: 0.02, 0.9: 0.01},
        ...     well_calibrated=True,
        ...     tolerance=0.05,
        ... )
    """

    target_quantiles: list[float]
    empirical_coverage: dict[float, float]
    calibration_errors: dict[float, float]
    well_calibrated: bool
    tolerance: float

    def summary(self) -> str:
        """Generate a human-readable summary."""
        errors_str = ", ".join(
            f"q={q:.2f}: err={e:.3f}"
            for q, e in self.calibration_errors.items()
        )
        status = "well-calibrated" if self.well_calibrated else "poorly-calibrated"
        return f"Calibration ({status}): {errors_str}"


@dataclass(frozen=True)
class StabilityReport:
    """Report on prediction stability.

    Attributes:
        jitter_metrics: Dict mapping series_id to jitter metric
        overall_jitter: Aggregate jitter across all series
        jitter_threshold: Threshold used for jitter evaluation
        high_jitter_series: List of series with high jitter
        coverage_report: Optional calibration report for quantiles

    Example:
        >>> report = StabilityReport(
        ...     jitter_metrics={"A": 0.05, "B": 0.15},
        ...     overall_jitter=0.10,
        ...     jitter_threshold=0.10,
        ...     high_jitter_series=["B"],
        ... )
        >>> print(report.is_stable)
        False
    """

    jitter_metrics: dict[str, float]
    overall_jitter: float
    jitter_threshold: float
    high_jitter_series: list[str]
    coverage_report: CalibrationReport | None = None

    @property
    def is_stable(self) -> bool:
        """Whether predictions are considered stable."""
        return len(self.high_jitter_series) == 0

    def summary(self) -> str:
        """Generate a human-readable summary."""
        status = "stable" if self.is_stable else "unstable"
        n_high = len(self.high_jitter_series)
        return (
            f"Stability ({status}): overall_jitter={self.overall_jitter:.3f}, "
            f"{n_high} series exceed threshold"
        )


@dataclass(frozen=True)
class TriggerResult:
    """Result of a retrain trigger evaluation.

    Attributes:
        trigger_type: Type of trigger that was evaluated
        fired: Whether the trigger fired
        reason: Human-readable reason for the trigger firing or not
        timestamp: When the trigger was evaluated
        metadata: Additional trigger-specific metadata

    Example:
        >>> result = TriggerResult(
        ...     trigger_type=TriggerType.DRIFT,
        ...     fired=True,
        ...     reason="PSI drift score 0.25 exceeded threshold 0.2",
        ...     metadata={"psi_score": 0.25, "threshold": 0.2},
        ... )
    """

    trigger_type: str
    fired: bool
    reason: str
    timestamp: str = field(
        default_factory=lambda: datetime.now(UTC).isoformat()
    )
    metadata: dict = field(default_factory=dict)

    def __repr__(self) -> str:
        status = "FIRED" if self.fired else "no-op"
        return f"TriggerResult({self.trigger_type}, {status})"