tsagentkit 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72):
  1. tsagentkit/__init__.py +126 -0
  2. tsagentkit/anomaly/__init__.py +130 -0
  3. tsagentkit/backtest/__init__.py +48 -0
  4. tsagentkit/backtest/engine.py +788 -0
  5. tsagentkit/backtest/metrics.py +244 -0
  6. tsagentkit/backtest/report.py +342 -0
  7. tsagentkit/calibration/__init__.py +136 -0
  8. tsagentkit/contracts/__init__.py +133 -0
  9. tsagentkit/contracts/errors.py +275 -0
  10. tsagentkit/contracts/results.py +418 -0
  11. tsagentkit/contracts/schema.py +44 -0
  12. tsagentkit/contracts/task_spec.py +300 -0
  13. tsagentkit/covariates/__init__.py +340 -0
  14. tsagentkit/eval/__init__.py +285 -0
  15. tsagentkit/features/__init__.py +20 -0
  16. tsagentkit/features/covariates.py +328 -0
  17. tsagentkit/features/extra/__init__.py +5 -0
  18. tsagentkit/features/extra/native.py +179 -0
  19. tsagentkit/features/factory.py +187 -0
  20. tsagentkit/features/matrix.py +159 -0
  21. tsagentkit/features/tsfeatures_adapter.py +115 -0
  22. tsagentkit/features/versioning.py +203 -0
  23. tsagentkit/hierarchy/__init__.py +39 -0
  24. tsagentkit/hierarchy/aggregation.py +62 -0
  25. tsagentkit/hierarchy/evaluator.py +400 -0
  26. tsagentkit/hierarchy/reconciliation.py +232 -0
  27. tsagentkit/hierarchy/structure.py +453 -0
  28. tsagentkit/models/__init__.py +182 -0
  29. tsagentkit/models/adapters/__init__.py +83 -0
  30. tsagentkit/models/adapters/base.py +321 -0
  31. tsagentkit/models/adapters/chronos.py +387 -0
  32. tsagentkit/models/adapters/moirai.py +256 -0
  33. tsagentkit/models/adapters/registry.py +171 -0
  34. tsagentkit/models/adapters/timesfm.py +440 -0
  35. tsagentkit/models/baselines.py +207 -0
  36. tsagentkit/models/sktime.py +307 -0
  37. tsagentkit/monitoring/__init__.py +51 -0
  38. tsagentkit/monitoring/alerts.py +302 -0
  39. tsagentkit/monitoring/coverage.py +203 -0
  40. tsagentkit/monitoring/drift.py +330 -0
  41. tsagentkit/monitoring/report.py +214 -0
  42. tsagentkit/monitoring/stability.py +275 -0
  43. tsagentkit/monitoring/triggers.py +423 -0
  44. tsagentkit/qa/__init__.py +347 -0
  45. tsagentkit/router/__init__.py +37 -0
  46. tsagentkit/router/bucketing.py +489 -0
  47. tsagentkit/router/fallback.py +132 -0
  48. tsagentkit/router/plan.py +23 -0
  49. tsagentkit/router/router.py +271 -0
  50. tsagentkit/series/__init__.py +26 -0
  51. tsagentkit/series/alignment.py +206 -0
  52. tsagentkit/series/dataset.py +449 -0
  53. tsagentkit/series/sparsity.py +261 -0
  54. tsagentkit/series/validation.py +393 -0
  55. tsagentkit/serving/__init__.py +39 -0
  56. tsagentkit/serving/orchestration.py +943 -0
  57. tsagentkit/serving/packaging.py +73 -0
  58. tsagentkit/serving/provenance.py +317 -0
  59. tsagentkit/serving/tsfm_cache.py +214 -0
  60. tsagentkit/skill/README.md +135 -0
  61. tsagentkit/skill/__init__.py +8 -0
  62. tsagentkit/skill/recipes.md +429 -0
  63. tsagentkit/skill/tool_map.md +21 -0
  64. tsagentkit/time/__init__.py +134 -0
  65. tsagentkit/utils/__init__.py +20 -0
  66. tsagentkit/utils/quantiles.py +83 -0
  67. tsagentkit/utils/signature.py +47 -0
  68. tsagentkit/utils/temporal.py +41 -0
  69. tsagentkit-1.0.2.dist-info/METADATA +371 -0
  70. tsagentkit-1.0.2.dist-info/RECORD +72 -0
  71. tsagentkit-1.0.2.dist-info/WHEEL +4 -0
  72. tsagentkit-1.0.2.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,418 @@
1
+ """Forecast result structures.
2
+
3
+ Defines the data structures for forecast outputs including provenance tracking.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import re
9
+ from dataclasses import dataclass, field
10
+ from typing import Any
11
+
12
+ import pandas as pd
13
+
14
+ _QUANTILE_PATTERN = re.compile(r"^q[_]?([0-9]+(?:\.[0-9]+)?)$")
15
+
16
+
17
+ def _parse_quantile_column(col: str) -> float | None:
18
+ match = _QUANTILE_PATTERN.match(col)
19
+ if not match:
20
+ return None
21
+ value = float(match.group(1))
22
+ if value > 1:
23
+ value = value / 100.0
24
+ if not 0 < value < 1:
25
+ return None
26
+ return value
27
+
28
+
29
+ def _is_datetime_like(series: Any) -> bool:
30
+ try:
31
+ import pandas as pd # Optional import for accurate dtype checks.
32
+
33
+ return bool(pd.api.types.is_datetime64_any_dtype(series))
34
+ except Exception:
35
+ dtype = getattr(series, "dtype", None)
36
+ kind = getattr(dtype, "kind", None)
37
+ return kind == "M"
38
+
39
+
40
@dataclass(frozen=True)
class ForecastFrame:
    """Forecast frame in long format.

    Expected columns: unique_id, ds, model, yhat (+ intervals/quantiles).
    """

    # Long-format forecast table; typed Any so contracts avoid a hard pandas
    # dependency (presumably a pandas DataFrame in practice — see module imports).
    df: Any
48
+
49
+
50
@dataclass(frozen=True)
class CVFrame:
    """Cross-validation frame in long format.

    Expected columns: unique_id, ds, cutoff, model, y, yhat (+ intervals/quantiles).
    """

    # Long-format CV table holding actuals (y) and predictions (yhat) per
    # cutoff; typed Any to keep contracts dependency-light.
    df: Any
58
+
59
+
60
@dataclass(frozen=True)
class Provenance:
    """Provenance record for a single forecast run.

    Captures everything needed to trace a run end-to-end: signatures of the
    input data, task, plan and model, plus any repairs or fallbacks applied.

    Attributes:
        run_id: Unique identifier for this run (UUID)
        timestamp: ISO 8601 timestamp of execution
        data_signature: Hash of input data
        task_signature: Hash of task specification
        plan_signature: Hash of execution plan
        model_signature: Hash of model configuration
        qa_repairs: List of data repairs applied
        fallbacks_triggered: List of fallback events
        metadata: Additional execution metadata
    """

    run_id: str
    timestamp: str
    data_signature: str
    task_signature: str
    plan_signature: str
    model_signature: str
    qa_repairs: list[dict[str, Any]] = field(default_factory=list)
    fallbacks_triggered: list[dict[str, Any]] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dictionary (nested lists/dicts by reference)."""
        return dict(
            run_id=self.run_id,
            timestamp=self.timestamp,
            data_signature=self.data_signature,
            task_signature=self.task_signature,
            plan_signature=self.plan_signature,
            model_signature=self.model_signature,
            qa_repairs=self.qa_repairs,
            fallbacks_triggered=self.fallbacks_triggered,
            metadata=self.metadata,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Provenance:
        """Rebuild a Provenance from a dict produced by :meth:`to_dict`."""
        return cls(**data)
107
+
108
+
109
@dataclass(frozen=True)
class ForecastResult:
    """Result of a forecast operation.

    Bundles the forecast table (optionally with quantile columns) with the
    provenance needed to reproduce it.

    Attributes:
        df: DataFrame with columns [unique_id, ds, model, yhat] + quantile columns
        provenance: Full provenance information
        model_name: Name of the model that produced this forecast
        horizon: Forecast horizon
    """

    df: Any
    provenance: Provenance
    model_name: str
    horizon: int

    def __post_init__(self) -> None:
        """Validate the dataframe structure (required columns, datetime ds)."""
        missing = {"unique_id", "ds", "model", "yhat"} - set(self.df.columns)
        if missing:
            raise ValueError(f"ForecastResult df missing columns: {missing}")
        # The timestamp column must be datetime-typed for downstream ops.
        if not _is_datetime_like(self.df["ds"]):
            raise ValueError("Column 'ds' must be datetime")

    def get_quantile_columns(self) -> list[str]:
        """Return the column names that encode quantiles (e.g. 'q50', 'q0.9').

        Returns:
            List of quantile column names present in ``df``.
        """
        return [col for col in self.df.columns if _parse_quantile_column(col) is not None]

    def get_series(self, unique_id: str) -> pd.DataFrame:
        """Return a copy of the forecast rows for one series.

        Args:
            unique_id: The series identifier.

        Returns:
            DataFrame with the forecast restricted to that series.
        """
        selected = self.df[self.df["unique_id"] == unique_id]
        return selected.copy()

    def to_dict(self) -> dict[str, Any]:
        """Serialize for transport; the DataFrame becomes a list of records."""
        return {
            "df": self.df.to_dict("records"),
            "provenance": self.provenance.to_dict(),
            "model_name": self.model_name,
            "horizon": self.horizon,
        }
169
+
170
+
171
@dataclass(frozen=True)
class ValidationReport:
    """Outcome of validating input data against the required schema.

    Attributes:
        valid: Whether validation passed
        errors: List of validation errors (if any)
        warnings: List of validation warnings
        stats: Statistics about the data
    """

    valid: bool
    errors: list[dict[str, Any]] = field(default_factory=list)
    warnings: list[dict[str, Any]] = field(default_factory=list)
    stats: dict[str, Any] = field(default_factory=dict)

    def has_errors(self) -> bool:
        """Return True when at least one error was recorded."""
        return bool(self.errors)

    def has_warnings(self) -> bool:
        """Return True when at least one warning was recorded."""
        return bool(self.warnings)

    def raise_if_errors(self) -> None:
        """Raise the registered error class for the first recorded error."""
        from .errors import get_error_class

        if not self.errors:
            return
        first = self.errors[0]
        error_class = get_error_class(first.get("code", "E_CONTRACT_MISSING_COLUMN"))
        raise error_class(
            first.get("message", "Validation failed"),
            first.get("context", {}),
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialize the report to a plain dictionary."""
        return {
            "valid": self.valid,
            "errors": self.errors,
            "warnings": self.warnings,
            "stats": self.stats,
        }
219
+
220
+
221
@dataclass(frozen=True)
class ModelArtifact:
    """Container for a fitted model.

    Wraps the fitted estimator with its configuration and metadata so it can
    be reused for prediction and tracked in provenance.

    Attributes:
        model: The fitted model (type depends on implementation)
        model_name: Name of the model
        config: Model configuration dictionary
        signature: Hash of model configuration
        fit_timestamp: ISO 8601 timestamp of fitting
        metadata: Additional model metadata
    """

    model: Any
    model_name: str
    config: dict[str, Any] = field(default_factory=dict)
    signature: str = ""
    fit_timestamp: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Derive a deterministic config signature when none was supplied."""
        if self.signature:
            return
        import hashlib
        import json

        # Canonical JSON (sorted keys, no whitespace) makes the hash stable.
        canonical = json.dumps(self.config, sort_keys=True, separators=(",", ":"))
        digest = hashlib.sha256(canonical.encode()).hexdigest()[:16]
        # object.__setattr__ sidesteps the frozen-dataclass guard.
        object.__setattr__(self, "signature", digest)
257
+
258
+
259
@dataclass(frozen=True)
class RepairReport:
    """Audit-trail record for one data repair applied during QA.

    Documents what was repaired, how, and whether the repair respected
    point-in-time (PIT) safety, so runs remain fully traceable.

    Attributes:
        repair_type: Type of repair ("missing_values", "winsorize", "median_filter")
        column: Column that was repaired
        count: Number of values repaired
        method: Method used for repair
        scope: Scope of repair ("observed_history", "future")
        before_sample: Sample statistics before repair (optional)
        after_sample: Sample statistics after repair (optional)
        time_range: Time range of repair as (start, end) ISO strings (optional)
        pit_safe: Whether repair is PIT-safe
        validation_passed: Whether validation passed
    """

    repair_type: str
    column: str
    count: int
    method: str
    scope: str = "observed_history"

    # PIT safety information
    before_sample: dict[str, Any] | None = None
    after_sample: dict[str, Any] | None = None
    time_range: tuple[str, str] | None = None

    # Validation flags
    pit_safe: bool = True
    validation_passed: bool = True

    def to_dict(self) -> dict[str, Any]:
        """Serialize the report to a plain dictionary."""
        return dict(
            repair_type=self.repair_type,
            column=self.column,
            count=self.count,
            method=self.method,
            scope=self.scope,
            pit_safe=self.pit_safe,
            validation_passed=self.validation_passed,
            before_sample=self.before_sample,
            after_sample=self.after_sample,
            time_range=self.time_range,
        )
308
+
309
+
310
@dataclass(frozen=True)
class RunArtifact:
    """Complete artifact from a forecasting run.

    Aggregates every output of the pipeline — forecast, plan, reports, fitted
    model and provenance — into one serializable object.

    Attributes:
        forecast: The forecast result
        plan: Execution plan that was used
        backtest_report: Backtest results (if performed)
        qa_report: QA report (if available)
        model_artifact: The fitted model artifact
        provenance: Full provenance information
        metadata: Additional run metadata
    """

    forecast: ForecastResult
    plan: dict[str, Any] | None = None
    task_spec: dict[str, Any] | None = None
    plan_spec: dict[str, Any] | None = None
    validation_report: dict[str, Any] | None = None
    backtest_report: dict[str, Any] | None = None
    qa_report: dict[str, Any] | None = None
    model_artifact: ModelArtifact | None = None
    provenance: Provenance | None = None
    calibration_artifact: dict[str, Any] | None = None
    anomaly_report: dict[str, Any] | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize the artifact; the fitted model is reduced to metadata."""
        model_info = None
        if self.model_artifact:
            model_info = {
                "model_name": self.model_artifact.model_name,
                "signature": self.model_artifact.signature,
                "fit_timestamp": self.model_artifact.fit_timestamp,
            }
        return {
            "forecast": self.forecast.to_dict() if self.forecast else None,
            "plan": self.plan,
            "task_spec": self.task_spec,
            "plan_spec": self.plan_spec,
            "validation_report": self.validation_report,
            "backtest_report": self.backtest_report,
            "qa_report": self.qa_report,
            "model_artifact": model_info,
            "provenance": self.provenance.to_dict() if self.provenance else None,
            "calibration_artifact": self.calibration_artifact,
            "anomaly_report": self.anomaly_report,
            "metadata": self.metadata,
        }

    def _describe_plan(self) -> str:
        """Render the plan as 'Plan(model->fallback->...)' when possible."""
        if not isinstance(self.plan, dict):
            return str(self.plan)
        candidates = self.plan.get("candidate_models")
        if candidates:
            return f"Plan({'->'.join(candidates)})"
        primary = self.plan.get("primary_model")
        if primary:
            fallback = self.plan.get("fallback_chain", [])
            chain = "->".join([primary, *fallback]) if fallback else primary
            return f"Plan({chain})"
        return str(self.plan.get("signature") or self.plan)

    def summary(self) -> str:
        """Generate a human-readable multi-line summary of the run."""
        lines = [
            "Run Artifact Summary",
            "=" * 40,
            f"Model: {self.forecast.model_name if self.forecast else 'N/A'}",
            f"Plan: {self._describe_plan()}",
            f"Forecast rows: {len(self.forecast.df) if self.forecast else 0}",
        ]

        if self.backtest_report:
            n_windows = self.backtest_report.get("n_windows")
            if n_windows is not None:
                lines.append(f"Backtest windows: {n_windows}")
            metrics = self.backtest_report.get("aggregate_metrics", {})
            if metrics:
                lines.append("Aggregate Metrics:")
                lines.extend(
                    f"  {name}: {value:.4f}" for name, value in sorted(metrics.items())
                )

        if self.provenance:
            lines.append("\nProvenance:")
            lines.append(f"  Data signature: {self.provenance.data_signature}")
            lines.append(f"  Timestamp: {self.provenance.timestamp}")

        return "\n".join(lines)
407
+
408
+
409
# Public API of the contracts.results module (kept alphabetically sorted).
__all__ = [
    "CVFrame",
    "ForecastFrame",
    "ForecastResult",
    "ModelArtifact",
    "Provenance",
    "RepairReport",
    "RunArtifact",
    "ValidationReport",
]
@@ -0,0 +1,44 @@
1
+ """Data validation schemas (compat wrapper).
2
+
3
+ This module preserves the stable API while keeping contracts free of
4
+ non-stdlib dependencies. The implementation lives in tsagentkit.series.validation.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from importlib import import_module
10
+ from typing import TYPE_CHECKING, Any
11
+
12
+ if TYPE_CHECKING:
13
+ from tsagentkit.contracts.results import ValidationReport
14
+ from tsagentkit.contracts.task_spec import PanelContract
15
+
16
+
17
def _impl():
    # Resolve the implementation module lazily so importing this compat
    # wrapper stays free of non-stdlib dependencies (see module docstring).
    return import_module("tsagentkit.series.validation")
19
+
20
+
21
def normalize_panel_columns(
    df: Any,
    contract: PanelContract,
) -> tuple[Any, dict[str, str] | None]:
    """Normalize panel columns to the canonical contract names.

    Thin compatibility shim: delegates to
    ``tsagentkit.series.validation.normalize_panel_columns``.
    """
    impl = _impl()
    return impl.normalize_panel_columns(df, contract)
27
+
28
+
29
def validate_contract(
    data: Any,
    panel_contract: PanelContract | None = None,
    apply_aggregation: bool = False,
    return_data: bool = False,
) -> ValidationReport | tuple[ValidationReport, Any]:
    """Validate input data against the required schema.

    Thin compatibility shim: delegates to
    ``tsagentkit.series.validation.validate_contract`` with identical
    keyword arguments.
    """
    impl = _impl()
    return impl.validate_contract(
        data,
        panel_contract=panel_contract,
        apply_aggregation=apply_aggregation,
        return_data=return_data,
    )
42
+
43
+
44
# Stable public API re-exported by this compatibility wrapper.
__all__ = ["validate_contract", "normalize_panel_columns"]