tsagentkit 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72) hide show
  1. tsagentkit/__init__.py +126 -0
  2. tsagentkit/anomaly/__init__.py +130 -0
  3. tsagentkit/backtest/__init__.py +48 -0
  4. tsagentkit/backtest/engine.py +788 -0
  5. tsagentkit/backtest/metrics.py +244 -0
  6. tsagentkit/backtest/report.py +342 -0
  7. tsagentkit/calibration/__init__.py +136 -0
  8. tsagentkit/contracts/__init__.py +133 -0
  9. tsagentkit/contracts/errors.py +275 -0
  10. tsagentkit/contracts/results.py +418 -0
  11. tsagentkit/contracts/schema.py +44 -0
  12. tsagentkit/contracts/task_spec.py +300 -0
  13. tsagentkit/covariates/__init__.py +340 -0
  14. tsagentkit/eval/__init__.py +285 -0
  15. tsagentkit/features/__init__.py +20 -0
  16. tsagentkit/features/covariates.py +328 -0
  17. tsagentkit/features/extra/__init__.py +5 -0
  18. tsagentkit/features/extra/native.py +179 -0
  19. tsagentkit/features/factory.py +187 -0
  20. tsagentkit/features/matrix.py +159 -0
  21. tsagentkit/features/tsfeatures_adapter.py +115 -0
  22. tsagentkit/features/versioning.py +203 -0
  23. tsagentkit/hierarchy/__init__.py +39 -0
  24. tsagentkit/hierarchy/aggregation.py +62 -0
  25. tsagentkit/hierarchy/evaluator.py +400 -0
  26. tsagentkit/hierarchy/reconciliation.py +232 -0
  27. tsagentkit/hierarchy/structure.py +453 -0
  28. tsagentkit/models/__init__.py +182 -0
  29. tsagentkit/models/adapters/__init__.py +83 -0
  30. tsagentkit/models/adapters/base.py +321 -0
  31. tsagentkit/models/adapters/chronos.py +387 -0
  32. tsagentkit/models/adapters/moirai.py +256 -0
  33. tsagentkit/models/adapters/registry.py +171 -0
  34. tsagentkit/models/adapters/timesfm.py +440 -0
  35. tsagentkit/models/baselines.py +207 -0
  36. tsagentkit/models/sktime.py +307 -0
  37. tsagentkit/monitoring/__init__.py +51 -0
  38. tsagentkit/monitoring/alerts.py +302 -0
  39. tsagentkit/monitoring/coverage.py +203 -0
  40. tsagentkit/monitoring/drift.py +330 -0
  41. tsagentkit/monitoring/report.py +214 -0
  42. tsagentkit/monitoring/stability.py +275 -0
  43. tsagentkit/monitoring/triggers.py +423 -0
  44. tsagentkit/qa/__init__.py +347 -0
  45. tsagentkit/router/__init__.py +37 -0
  46. tsagentkit/router/bucketing.py +489 -0
  47. tsagentkit/router/fallback.py +132 -0
  48. tsagentkit/router/plan.py +23 -0
  49. tsagentkit/router/router.py +271 -0
  50. tsagentkit/series/__init__.py +26 -0
  51. tsagentkit/series/alignment.py +206 -0
  52. tsagentkit/series/dataset.py +449 -0
  53. tsagentkit/series/sparsity.py +261 -0
  54. tsagentkit/series/validation.py +393 -0
  55. tsagentkit/serving/__init__.py +39 -0
  56. tsagentkit/serving/orchestration.py +943 -0
  57. tsagentkit/serving/packaging.py +73 -0
  58. tsagentkit/serving/provenance.py +317 -0
  59. tsagentkit/serving/tsfm_cache.py +214 -0
  60. tsagentkit/skill/README.md +135 -0
  61. tsagentkit/skill/__init__.py +8 -0
  62. tsagentkit/skill/recipes.md +429 -0
  63. tsagentkit/skill/tool_map.md +21 -0
  64. tsagentkit/time/__init__.py +134 -0
  65. tsagentkit/utils/__init__.py +20 -0
  66. tsagentkit/utils/quantiles.py +83 -0
  67. tsagentkit/utils/signature.py +47 -0
  68. tsagentkit/utils/temporal.py +41 -0
  69. tsagentkit-1.0.2.dist-info/METADATA +371 -0
  70. tsagentkit-1.0.2.dist-info/RECORD +72 -0
  71. tsagentkit-1.0.2.dist-info/WHEEL +4 -0
  72. tsagentkit-1.0.2.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,330 @@
1
+ """Drift detection using PSI and KS tests."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING, Literal
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
+ from tsagentkit.monitoring.report import DriftReport, FeatureDriftResult
11
+
12
+ if TYPE_CHECKING:
13
+ pass
14
+
15
+
16
class DriftDetector:
    """Detect data drift between reference and current distributions.

    Supports two methods:
    - PSI (Population Stability Index): Industry standard for distribution drift
    - KS (Kolmogorov-Smirnov): Statistical test for distribution differences

    PSI interpretation:
    - < 0.1: No significant change
    - 0.1 - 0.2: Moderate change
    - > 0.2: Significant change (drift detected)

    KS interpretation:
    - p-value < 0.05: Statistically significant difference (drift detected)

    Example:
        >>> detector = DriftDetector(method="psi", threshold=0.2)
        >>> report = detector.detect(reference_data, current_data, features=["sales"])
        >>> if report.drift_detected:
        ...     print(f"Drift detected in features: {report.get_drifting_features()}")
    """

    def __init__(
        self,
        method: Literal["psi", "ks"] = "psi",
        threshold: float | None = None,
        n_bins: int = 10,
    ):
        """Initialize drift detector.

        Args:
            method: Drift detection method ("psi" or "ks")
            threshold: Threshold for drift detection.
                Default: 0.2 for PSI, 0.05 for KS (p-value)
            n_bins: Number of bins for PSI calculation

        Raises:
            ValueError: If invalid method specified
        """
        if method not in ("psi", "ks"):
            raise ValueError(f"Method must be 'psi' or 'ks', got {method}")

        self.method = method
        self.n_bins = n_bins

        # Method-appropriate default: PSI thresholds the statistic itself,
        # KS thresholds the test's p-value.
        if threshold is None:
            self.threshold = 0.2 if method == "psi" else 0.05
        else:
            self.threshold = threshold

    def detect(
        self,
        reference_data: pd.DataFrame,
        current_data: pd.DataFrame,
        features: list[str] | None = None,
    ) -> DriftReport:
        """Detect drift between reference and current datasets.

        Args:
            reference_data: Baseline/reference distribution (training data)
            current_data: Current data to compare (recent observations)
            features: List of features to check (defaults to numeric columns)

        Returns:
            DriftReport with per-feature and overall results

        Example:
            >>> detector = DriftDetector(method="psi")
            >>> report = detector.detect(
            ...     reference_data=train_df,
            ...     current_data=recent_df,
            ...     features=["sales", "price"]
            ... )
            >>> print(report.overall_drift_score)
            0.15
        """
        # Auto-select numeric features if not specified
        if features is None:
            numeric_cols = reference_data.select_dtypes(
                include=[np.number]
            ).columns.tolist()
            # Exclude common non-feature columns (ids, timestamps, target)
            exclude = {"unique_id", "ds", "y", "timestamp"}
            features = [f for f in numeric_cols if f not in exclude]

        feature_drifts: dict[str, FeatureDriftResult] = {}
        drift_scores: list[float] = []

        for feature in features:
            # Skip features missing from either frame rather than raising.
            if feature not in reference_data.columns or feature not in current_data.columns:
                continue

            result = self._analyze_feature(
                reference_data[feature],
                current_data[feature],
                feature,
            )
            feature_drifts[feature] = result
            drift_scores.append(result.statistic)

        # Overall drift score is the mean statistic across analyzed features
        # (for KS this is the mean KS statistic, reported for context only).
        overall_drift = float(np.mean(drift_scores)) if drift_scores else 0.0

        # Determine if drift detected based on method
        if self.method == "psi":
            drift_detected = overall_drift > self.threshold
        else:  # ks
            # Drift if any feature's p-value fell below the threshold; the
            # per-feature flags from _analyze_feature already encode that.
            drift_detected = any(r.drift_detected for r in feature_drifts.values())

        return DriftReport(
            drift_detected=drift_detected,
            feature_drifts=feature_drifts,
            overall_drift_score=overall_drift,
            threshold_used=self.threshold,
        )

    def _analyze_feature(
        self,
        reference: pd.Series,
        current: pd.Series,
        feature_name: str,
    ) -> FeatureDriftResult:
        """Analyze drift for a single feature.

        Args:
            reference: Reference distribution
            current: Current distribution
            feature_name: Name of the feature

        Returns:
            FeatureDriftResult with drift statistics
        """
        # Remove NaN values before computing statistics
        ref_values = reference.dropna().values
        cur_values = current.dropna().values

        if self.method == "psi":
            statistic = self._compute_psi(ref_values, cur_values, self.n_bins)
            p_value = None
            drift_detected = statistic > self.threshold
        else:  # ks
            statistic, p_value = self._compute_ks_test(ref_values, cur_values)
            # For KS, drift detected if p-value < threshold
            drift_detected = p_value < self.threshold

        return FeatureDriftResult(
            feature_name=feature_name,
            metric=self.method,
            statistic=float(statistic),
            p_value=float(p_value) if p_value is not None else None,
            drift_detected=drift_detected,
            reference_distribution=self._summarize_distribution(ref_values),
            current_distribution=self._summarize_distribution(cur_values),
        )

    def _compute_psi(
        self,
        reference: np.ndarray,
        current: np.ndarray,
        n_bins: int = 10,
    ) -> float:
        """Compute Population Stability Index.

        PSI = sum((Actual% - Expected%) * ln(Actual% / Expected%))

        PSI interpretation:
        - < 0.1: No significant change
        - 0.1 - 0.2: Moderate change
        - > 0.2: Significant change

        Args:
            reference: Reference distribution values
            current: Current distribution values
            n_bins: Number of bins for discretization

        Returns:
            PSI value (float)
        """
        if len(reference) == 0 or len(current) == 0:
            return 0.0

        # Bins are derived from the reference distribution's range
        min_val, max_val = reference.min(), reference.max()

        # Constant reference: zero drift only if current is the same constant
        if min_val == max_val:
            return 0.0 if current.min() == current.max() == min_val else 1.0

        bins = np.linspace(min_val, max_val, n_bins + 1)
        bins[-1] += 1e-10  # Ensure max value is included in the last bin

        ref_hist, _ = np.histogram(reference, bins=bins)
        cur_hist, _ = np.histogram(current, bins=bins)

        # Convert to probabilities; epsilon avoids log(0) for empty bins
        ref_pct = ref_hist / len(reference) + 1e-10
        cur_pct = cur_hist / len(current) + 1e-10

        psi = np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct))
        return float(psi)

    def _compute_ks_test(
        self,
        reference: np.ndarray,
        current: np.ndarray,
    ) -> tuple[float, float]:
        """Compute Kolmogorov-Smirnov test.

        Args:
            reference: Reference distribution values
            current: Current distribution values

        Returns:
            Tuple of (statistic, p_value)
        """
        # Local import keeps scipy an optional dependency of this module
        from scipy import stats

        if len(reference) == 0 or len(current) == 0:
            return 0.0, 1.0

        statistic, p_value = stats.ks_2samp(reference, current)
        return float(statistic), float(p_value)

    def _summarize_distribution(self, values: np.ndarray) -> dict:
        """Create a summary of a distribution.

        Args:
            values: Array of values

        Returns:
            Dictionary with distribution statistics
        """
        if len(values) == 0:
            return {"count": 0}

        return {
            "count": int(len(values)),
            "mean": float(np.mean(values)),
            "std": float(np.std(values)),
            "min": float(np.min(values)),
            "max": float(np.max(values)),
            "median": float(np.median(values)),
        }
274
+
275
+
276
def compute_psi_summary(
    reference: pd.Series | np.ndarray,
    current: pd.Series | np.ndarray,
    n_bins: int = 10,
) -> dict:
    """Compute detailed PSI breakdown by bin.

    Args:
        reference: Reference distribution
        current: Current distribution
        n_bins: Number of bins

    Returns:
        Dictionary with PSI breakdown
    """
    # Accept either pandas objects (NaNs dropped) or raw arrays.
    ref_arr = np.asarray(reference.dropna() if hasattr(reference, "dropna") else reference)
    cur_arr = np.asarray(current.dropna() if hasattr(current, "dropna") else current)

    if len(ref_arr) == 0 or len(cur_arr) == 0:
        return {"psi": 0.0, "bins": []}

    # Bin edges span the reference distribution's range.
    lo, hi = ref_arr.min(), ref_arr.max()
    if lo == hi:
        return {"psi": 0.0, "bins": [], "message": "Constant reference distribution"}

    edges = np.linspace(lo, hi, n_bins + 1)
    edges[-1] += 1e-10  # make the top edge inclusive

    ref_counts, bin_edges = np.histogram(ref_arr, bins=edges)
    cur_counts, _ = np.histogram(cur_arr, bins=edges)

    eps = 1e-10  # guards against log(0) on empty bins
    ref_frac = ref_counts / len(ref_arr) + eps
    cur_frac = cur_counts / len(cur_arr) + eps

    # Per-bin PSI contributions
    per_bin = (cur_frac - ref_frac) * np.log(cur_frac / ref_frac)

    bins_data = [
        {
            "bin_start": float(bin_edges[i]),
            "bin_end": float(bin_edges[i + 1]),
            "reference_count": int(ref_counts[i]),
            "reference_pct": float(ref_frac[i] - eps),
            "current_count": int(cur_counts[i]),
            "current_pct": float(cur_frac[i] - eps),
            "psi": float(per_bin[i]),
        }
        for i in range(n_bins)
    ]

    return {"psi": float(np.sum(per_bin)), "bins": bins_data}
@@ -0,0 +1,214 @@
1
+ """Report dataclasses for monitoring results."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from datetime import UTC, datetime
7
+
8
+
9
@dataclass(frozen=True)
class FeatureDriftResult:
    """Per-feature outcome of a drift analysis.

    Attributes:
        feature_name: Name of the feature analyzed
        metric: Drift metric used ("psi" or "ks")
        statistic: Drift statistic value
        p_value: P-value for statistical test (KS only; None for PSI)
        drift_detected: Whether drift was detected for this feature
        reference_distribution: Summary of reference distribution
        current_distribution: Summary of current distribution

    Example:
        >>> result = FeatureDriftResult(
        ...     feature_name="sales",
        ...     metric="psi",
        ...     statistic=0.25,
        ...     p_value=None,
        ...     drift_detected=True,
        ...     reference_distribution={"mean": 100.0, "std": 15.0},
        ...     current_distribution={"mean": 120.0, "std": 20.0},
        ... )
        >>> print(result)
        FeatureDriftResult(sales, psi=0.250, drift=True)
    """

    feature_name: str
    metric: str
    statistic: float
    p_value: float | None
    drift_detected: bool
    reference_distribution: dict
    current_distribution: dict

    def __repr__(self) -> str:
        # Compact one-line form, e.g. "FeatureDriftResult(sales, psi=0.250, drift=True)"
        return f"FeatureDriftResult({self.feature_name}, {self.metric}={self.statistic:.3f}, drift={self.drift_detected})"
50
+
51
+
52
+ @dataclass(frozen=True)
53
+ class DriftReport:
54
+ """Report from drift detection analysis.
55
+
56
+ Attributes:
57
+ drift_detected: Whether any drift was detected
58
+ feature_drifts: Dict mapping feature names to drift results
59
+ overall_drift_score: Aggregated drift score across all features
60
+ threshold_used: Threshold used for drift detection
61
+ reference_timestamp: Timestamp of reference data
62
+ current_timestamp: Timestamp of current data
63
+
64
+ Example:
65
+ >>> report = DriftReport(
66
+ ... drift_detected=True,
67
+ ... feature_drifts={"sales": feature_result},
68
+ ... overall_drift_score=0.25,
69
+ ... threshold_used=0.2,
70
+ ... )
71
+ >>> print(report.summary())
72
+ Drift detected in 1/1 features. Overall score: 0.250
73
+ """
74
+
75
+ drift_detected: bool
76
+ feature_drifts: dict[str, FeatureDriftResult]
77
+ overall_drift_score: float
78
+ threshold_used: float
79
+ reference_timestamp: str = field(
80
+ default_factory=lambda: datetime.now(UTC).isoformat()
81
+ )
82
+ current_timestamp: str = field(
83
+ default_factory=lambda: datetime.now(UTC).isoformat()
84
+ )
85
+
86
+ def summary(self) -> str:
87
+ """Generate a human-readable summary of the drift report."""
88
+ n_drifting = sum(1 for r in self.feature_drifts.values() if r.drift_detected)
89
+ n_total = len(self.feature_drifts)
90
+ return (
91
+ f"Drift detected in {n_drifting}/{n_total} features. "
92
+ f"Overall score: {self.overall_drift_score:.3f} "
93
+ f"(threshold: {self.threshold_used})"
94
+ )
95
+
96
+ def get_drifting_features(self) -> list[str]:
97
+ """Return list of feature names with detected drift."""
98
+ return [
99
+ name for name, result in self.feature_drifts.items()
100
+ if result.drift_detected
101
+ ]
102
+
103
+
104
@dataclass(frozen=True)
class CalibrationReport:
    """Report on quantile calibration.

    Attributes:
        target_quantiles: List of target quantile levels
        empirical_coverage: Dict mapping quantile to empirical coverage
        calibration_errors: Dict mapping quantile to calibration error
        well_calibrated: Whether all quantiles are well-calibrated
        tolerance: Tolerance used for calibration check

    Example:
        >>> report = CalibrationReport(
        ...     target_quantiles=[0.1, 0.5, 0.9],
        ...     empirical_coverage={0.1: 0.08, 0.5: 0.52, 0.9: 0.91},
        ...     calibration_errors={0.1: 0.02, 0.5: 0.02, 0.9: 0.01},
        ...     well_calibrated=True,
        ...     tolerance=0.05,
        ... )
    """

    target_quantiles: list[float]
    empirical_coverage: dict[float, float]
    calibration_errors: dict[float, float]
    well_calibrated: bool
    tolerance: float

    def summary(self) -> str:
        """Generate a human-readable summary."""
        parts = [f"q={q:.2f}: err={e:.3f}" for q, e in self.calibration_errors.items()]
        label = "well-calibrated" if self.well_calibrated else "poorly-calibrated"
        return f"Calibration ({label}): " + ", ".join(parts)
139
+
140
+
141
@dataclass(frozen=True)
class StabilityReport:
    """Report on prediction stability.

    Attributes:
        jitter_metrics: Dict mapping series_id to jitter metric
        overall_jitter: Aggregate jitter across all series
        jitter_threshold: Threshold used for jitter evaluation
        high_jitter_series: List of series with high jitter
        coverage_report: Optional calibration report for quantiles

    Example:
        >>> report = StabilityReport(
        ...     jitter_metrics={"A": 0.05, "B": 0.15},
        ...     overall_jitter=0.10,
        ...     jitter_threshold=0.10,
        ...     high_jitter_series=["B"],
        ... )
        >>> print(report.is_stable)
        False
    """

    jitter_metrics: dict[str, float]
    overall_jitter: float
    jitter_threshold: float
    high_jitter_series: list[str]
    coverage_report: CalibrationReport | None = None

    @property
    def is_stable(self) -> bool:
        """True when no series exceeded the jitter threshold."""
        return not self.high_jitter_series

    def summary(self) -> str:
        """Generate a human-readable summary."""
        label = "stable" if self.is_stable else "unstable"
        return (
            f"Stability ({label}): overall_jitter={self.overall_jitter:.3f}, "
            f"{len(self.high_jitter_series)} series exceed threshold"
        )
182
+
183
+
184
+ @dataclass(frozen=True)
185
+ class TriggerResult:
186
+ """Result of a retrain trigger evaluation.
187
+
188
+ Attributes:
189
+ trigger_type: Type of trigger that was evaluated
190
+ fired: Whether the trigger fired
191
+ reason: Human-readable reason for trigger firing or not
192
+ timestamp: When the trigger was evaluated
193
+ metadata: Additional trigger-specific metadata
194
+
195
+ Example:
196
+ >>> result = TriggerResult(
197
+ ... trigger_type=TriggerType.DRIFT,
198
+ ... fired=True,
199
+ ... reason="PSI drift score 0.25 exceeded threshold 0.2",
200
+ ... metadata={"psi_score": 0.25, "threshold": 0.2},
201
+ ... )
202
+ """
203
+
204
+ trigger_type: str
205
+ fired: bool
206
+ reason: str
207
+ timestamp: str = field(
208
+ default_factory=lambda: datetime.now(UTC).isoformat()
209
+ )
210
+ metadata: dict = field(default_factory=dict)
211
+
212
+ def __repr__(self) -> str:
213
+ status = "FIRED" if self.fired else "no-op"
214
+ return f"TriggerResult({self.trigger_type}, {status})"