tsagentkit 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tsagentkit/__init__.py +126 -0
- tsagentkit/anomaly/__init__.py +130 -0
- tsagentkit/backtest/__init__.py +48 -0
- tsagentkit/backtest/engine.py +788 -0
- tsagentkit/backtest/metrics.py +244 -0
- tsagentkit/backtest/report.py +342 -0
- tsagentkit/calibration/__init__.py +136 -0
- tsagentkit/contracts/__init__.py +133 -0
- tsagentkit/contracts/errors.py +275 -0
- tsagentkit/contracts/results.py +418 -0
- tsagentkit/contracts/schema.py +44 -0
- tsagentkit/contracts/task_spec.py +300 -0
- tsagentkit/covariates/__init__.py +340 -0
- tsagentkit/eval/__init__.py +285 -0
- tsagentkit/features/__init__.py +20 -0
- tsagentkit/features/covariates.py +328 -0
- tsagentkit/features/extra/__init__.py +5 -0
- tsagentkit/features/extra/native.py +179 -0
- tsagentkit/features/factory.py +187 -0
- tsagentkit/features/matrix.py +159 -0
- tsagentkit/features/tsfeatures_adapter.py +115 -0
- tsagentkit/features/versioning.py +203 -0
- tsagentkit/hierarchy/__init__.py +39 -0
- tsagentkit/hierarchy/aggregation.py +62 -0
- tsagentkit/hierarchy/evaluator.py +400 -0
- tsagentkit/hierarchy/reconciliation.py +232 -0
- tsagentkit/hierarchy/structure.py +453 -0
- tsagentkit/models/__init__.py +182 -0
- tsagentkit/models/adapters/__init__.py +83 -0
- tsagentkit/models/adapters/base.py +321 -0
- tsagentkit/models/adapters/chronos.py +387 -0
- tsagentkit/models/adapters/moirai.py +256 -0
- tsagentkit/models/adapters/registry.py +171 -0
- tsagentkit/models/adapters/timesfm.py +440 -0
- tsagentkit/models/baselines.py +207 -0
- tsagentkit/models/sktime.py +307 -0
- tsagentkit/monitoring/__init__.py +51 -0
- tsagentkit/monitoring/alerts.py +302 -0
- tsagentkit/monitoring/coverage.py +203 -0
- tsagentkit/monitoring/drift.py +330 -0
- tsagentkit/monitoring/report.py +214 -0
- tsagentkit/monitoring/stability.py +275 -0
- tsagentkit/monitoring/triggers.py +423 -0
- tsagentkit/qa/__init__.py +347 -0
- tsagentkit/router/__init__.py +37 -0
- tsagentkit/router/bucketing.py +489 -0
- tsagentkit/router/fallback.py +132 -0
- tsagentkit/router/plan.py +23 -0
- tsagentkit/router/router.py +271 -0
- tsagentkit/series/__init__.py +26 -0
- tsagentkit/series/alignment.py +206 -0
- tsagentkit/series/dataset.py +449 -0
- tsagentkit/series/sparsity.py +261 -0
- tsagentkit/series/validation.py +393 -0
- tsagentkit/serving/__init__.py +39 -0
- tsagentkit/serving/orchestration.py +943 -0
- tsagentkit/serving/packaging.py +73 -0
- tsagentkit/serving/provenance.py +317 -0
- tsagentkit/serving/tsfm_cache.py +214 -0
- tsagentkit/skill/README.md +135 -0
- tsagentkit/skill/__init__.py +8 -0
- tsagentkit/skill/recipes.md +429 -0
- tsagentkit/skill/tool_map.md +21 -0
- tsagentkit/time/__init__.py +134 -0
- tsagentkit/utils/__init__.py +20 -0
- tsagentkit/utils/quantiles.py +83 -0
- tsagentkit/utils/signature.py +47 -0
- tsagentkit/utils/temporal.py +41 -0
- tsagentkit-1.0.2.dist-info/METADATA +371 -0
- tsagentkit-1.0.2.dist-info/RECORD +72 -0
- tsagentkit-1.0.2.dist-info/WHEEL +4 -0
- tsagentkit-1.0.2.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,418 @@
|
|
|
1
|
+
"""Forecast result structures.
|
|
2
|
+
|
|
3
|
+
Defines the data structures for forecast outputs including provenance tracking.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import re
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
import pandas as pd
|
|
13
|
+
|
|
14
|
+
# Quantile column names are the letter "q", an optional underscore, and a
# non-negative decimal number, e.g. "q0.1", "q_0.1", "q10" or "q_10".
_QUANTILE_PATTERN = re.compile(r"^q[_]?([0-9]+(?:\.[0-9]+)?)$")


def _parse_quantile_column(col: str) -> float | None:
    """Return the quantile level encoded in a column name, if any.

    Numbers greater than 1 are read as percentages and divided by 100, so
    ``"q10"`` and ``"q0.1"`` both yield ``0.1``.

    Args:
        col: Column label to inspect. Labels that are not strings are
            treated as non-quantile columns.

    Returns:
        The quantile level strictly between 0 and 1, or ``None`` when the
        label does not match the quantile naming pattern or the resulting
        level falls outside the open interval (0, 1).
    """
    # Column labels are not guaranteed to be strings (integer labels are
    # common); the regex would raise TypeError on them, so bail out early.
    if not isinstance(col, str):
        return None

    match = _QUANTILE_PATTERN.match(col)
    if match is None:
        return None

    value = float(match.group(1))
    if value > 1:
        # Percent-style name such as "q90".
        value = value / 100.0
    return value if 0 < value < 1 else None
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _is_datetime_like(series: Any) -> bool:
    """Report whether ``series`` holds datetime values.

    The pandas dtype helper is tried first. If importing pandas or running
    the check raises, the decision falls back to comparing the ``kind`` of
    the object's ``dtype`` attribute with ``"M"`` (NumPy's datetime64 kind).

    Args:
        series: Object to inspect, typically a dataframe column.

    Returns:
        ``True`` when the values are recognised as datetimes, else ``False``.
    """
    try:
        import pandas as pd  # Imported here so a failure reaches the fallback.

        answer = bool(pd.api.types.is_datetime64_any_dtype(series))
    except Exception:
        # Best-effort fallback that only relies on attribute access.
        dtype_kind = getattr(getattr(series, "dtype", None), "kind", None)
        return dtype_kind == "M"
    else:
        return answer
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass(frozen=True)
class ForecastFrame:
    """Immutable wrapper around a long-format forecast table.

    The wrapped table is expected to carry the columns ``unique_id``, ``ds``,
    ``model`` and ``yhat``, optionally followed by interval/quantile columns.
    The wrapper itself performs no validation.
    """

    # Wrapped table; annotated as Any, so no dataframe type is enforced.
    df: Any
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass(frozen=True)
class CVFrame:
    """Immutable wrapper around a long-format cross-validation table.

    The wrapped table is expected to carry the columns ``unique_id``, ``ds``,
    ``cutoff``, ``model``, ``y`` and ``yhat``, optionally followed by
    interval/quantile columns. The wrapper itself performs no validation.
    """

    # Wrapped table; annotated as Any, so no dataframe type is enforced.
    df: Any
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@dataclass(frozen=True)
class Provenance:
    """Traceability record for a single forecast run.

    Bundles the identifiers and signatures that describe how a forecast was
    produced, together with the repairs and fallbacks recorded along the way.

    Attributes:
        run_id: Unique identifier for this run (UUID)
        timestamp: ISO 8601 timestamp of execution
        data_signature: Hash of input data
        task_signature: Hash of task specification
        plan_signature: Hash of execution plan
        model_signature: Hash of model configuration
        qa_repairs: List of data repairs applied
        fallbacks_triggered: List of fallback events
        metadata: Additional execution metadata
    """

    run_id: str
    timestamp: str
    data_signature: str
    task_signature: str
    plan_signature: str
    model_signature: str
    qa_repairs: list[dict[str, Any]] = field(default_factory=list)
    fallbacks_triggered: list[dict[str, Any]] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return the record as a plain dictionary keyed by field name.

        The list and dict values are the same objects held by this instance,
        not copies.
        """
        field_names = (
            "run_id",
            "timestamp",
            "data_signature",
            "task_signature",
            "plan_signature",
            "model_signature",
            "qa_repairs",
            "fallbacks_triggered",
            "metadata",
        )
        return {name: getattr(self, name) for name in field_names}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Provenance:
        """Build an instance by passing ``data`` as keyword arguments."""
        instance = cls(**data)
        return instance
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@dataclass(frozen=True)
class ForecastResult:
    """Forecast values bundled with the provenance of the run.

    Attributes:
        df: DataFrame with columns [unique_id, ds, model, yhat] + quantile columns
        provenance: Full provenance information
        model_name: Name of the model that produced this forecast
        horizon: Forecast horizon
    """

    df: Any
    provenance: Provenance
    model_name: str
    horizon: int

    def __post_init__(self) -> None:
        """Check that ``df`` has the required columns and a datetime ``ds``.

        Raises:
            ValueError: If a required column is absent or ``ds`` is not
                recognised as a datetime column.
        """
        required_cols = {"unique_id", "ds", "model", "yhat"}
        present_cols = set(self.df.columns)
        missing = required_cols - present_cols
        if missing:
            raise ValueError(f"ForecastResult df missing columns: {missing}")

        # Type validation is limited to the timestamp column.
        ds_is_datetime = _is_datetime_like(self.df["ds"])
        if not ds_is_datetime:
            raise ValueError("Column 'ds' must be datetime")

    def get_quantile_columns(self) -> list[str]:
        """List the columns whose names parse as quantile levels.

        Returns:
            Column names, in frame order, for which
            ``_parse_quantile_column`` returns a value.
        """
        quantile_cols = []
        for name in self.df.columns:
            if _parse_quantile_column(name) is not None:
                quantile_cols.append(name)
        return quantile_cols

    def get_series(self, unique_id: str) -> pd.DataFrame:
        """Extract the forecast rows of one series.

        Args:
            unique_id: The series identifier

        Returns:
            A copy of the rows whose ``unique_id`` equals the given value
        """
        row_mask = self.df["unique_id"] == unique_id
        return self.df[row_mask].copy()

    def to_dict(self) -> dict[str, Any]:
        """Serialize the result to a dictionary.

        The frame is exported through ``to_dict("records")`` and the
        provenance through its own ``to_dict``.
        """
        records = self.df.to_dict("records")
        provenance = self.provenance.to_dict()
        return {
            "df": records,
            "provenance": provenance,
            "model_name": self.model_name,
            "horizon": self.horizon,
        }
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
@dataclass(frozen=True)
class ValidationReport:
    """Outcome of validating input data.

    Attributes:
        valid: Whether validation passed
        errors: List of validation errors (if any)
        warnings: List of validation warnings
        stats: Statistics about the data
    """

    valid: bool
    errors: list[dict[str, Any]] = field(default_factory=list)
    warnings: list[dict[str, Any]] = field(default_factory=list)
    stats: dict[str, Any] = field(default_factory=dict)

    def has_errors(self) -> bool:
        """Return ``True`` when at least one error was recorded."""
        return len(self.errors) != 0

    def has_warnings(self) -> bool:
        """Return ``True`` when at least one warning was recorded."""
        return len(self.warnings) != 0

    def raise_if_errors(self) -> None:
        """Raise an exception built from the first recorded error, if any.

        The exception class is looked up from the error's ``code`` entry via
        ``get_error_class`` and instantiated with the error's message and
        context. Missing entries fall back to the defaults used below.
        """
        from .errors import get_error_class

        if not self.errors:
            return

        first_error = self.errors[0]
        error_code = first_error.get("code", "E_CONTRACT_MISSING_COLUMN")
        message = first_error.get("message", "Validation failed")
        context = first_error.get("context", {})

        raise get_error_class(error_code)(message, context)

    def to_dict(self) -> dict[str, Any]:
        """Return the report as a plain dictionary keyed by field name."""
        field_names = ("valid", "errors", "warnings", "stats")
        return {name: getattr(self, name) for name in field_names}
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
@dataclass(frozen=True)
class ModelArtifact:
    """Container pairing a fitted model with its configuration and metadata.

    Attributes:
        model: The fitted model (type depends on implementation)
        model_name: Name of the model
        config: Model configuration dictionary
        signature: Hash of model configuration; derived from ``config`` when
            left empty
        fit_timestamp: ISO 8601 timestamp of fitting
        metadata: Additional model metadata
    """

    model: Any
    model_name: str
    config: dict[str, Any] = field(default_factory=dict)
    signature: str = ""
    fit_timestamp: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Fill in ``signature`` from ``config`` when none was supplied."""
        if self.signature:
            return

        import hashlib
        import json

        # Sorted keys and fixed separators make the JSON text, and therefore
        # the hash, independent of dict insertion order.
        canonical = json.dumps(self.config, sort_keys=True, separators=(",", ":"))
        digest = hashlib.sha256(canonical.encode()).hexdigest()
        # The dataclass is frozen, so the field is set via object.__setattr__.
        object.__setattr__(self, "signature", digest[:16])
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
@dataclass(frozen=True)
class RepairReport:
    """Audit-trail entry describing one data repair applied during QA.

    Attributes:
        repair_type: Type of repair ("missing_values", "winsorize", "median_filter")
        column: Column that was repaired
        count: Number of values repaired
        method: Method used for repair
        scope: Scope of repair ("observed_history", "future")
        before_sample: Sample statistics before repair (optional)
        after_sample: Sample statistics after repair (optional)
        time_range: Time range of repair as (start, end) ISO strings (optional)
        pit_safe: Whether repair is PIT-safe
        validation_passed: Whether validation passed
    """

    repair_type: str
    column: str
    count: int
    method: str
    scope: str = "observed_history"

    # Optional evidence supporting the PIT-safety assessment.
    before_sample: dict[str, Any] | None = None
    after_sample: dict[str, Any] | None = None
    time_range: tuple[str, str] | None = None

    # Validation flags.
    pit_safe: bool = True
    validation_passed: bool = True

    def to_dict(self) -> dict[str, Any]:
        """Return the report as a plain dictionary keyed by field name."""
        # Key order is part of the serialized form, so it is spelled out
        # rather than taken from the field declaration order.
        ordered_names = (
            "repair_type",
            "column",
            "count",
            "method",
            "scope",
            "pit_safe",
            "validation_passed",
            "before_sample",
            "after_sample",
            "time_range",
        )
        payload: dict[str, Any] = {}
        for name in ordered_names:
            payload[name] = getattr(self, name)
        return payload
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
@dataclass(frozen=True)
class RunArtifact:
    """Complete artifact from a forecasting run.

    The comprehensive output of the forecasting pipeline containing
    all results, reports, and provenance information.

    Attributes:
        forecast: The forecast result
        plan: Execution plan that was used
        task_spec: Optional dictionary; presumably the serialized task
            specification (confirm against the producer)
        plan_spec: Optional dictionary; presumably the serialized plan
            specification (confirm against the producer)
        validation_report: Optional dictionary; presumably a serialized
            validation report (confirm against the producer)
        backtest_report: Backtest results (if performed)
        qa_report: QA report (if available)
        model_artifact: The fitted model artifact
        provenance: Full provenance information
        calibration_artifact: Optional dictionary with calibration output
        anomaly_report: Optional dictionary with anomaly output
        metadata: Additional run metadata
    """

    forecast: ForecastResult
    plan: dict[str, Any] | None = None
    task_spec: dict[str, Any] | None = None
    plan_spec: dict[str, Any] | None = None
    validation_report: dict[str, Any] | None = None
    backtest_report: dict[str, Any] | None = None
    qa_report: dict[str, Any] | None = None
    model_artifact: ModelArtifact | None = None
    provenance: Provenance | None = None
    calibration_artifact: dict[str, Any] | None = None
    anomaly_report: dict[str, Any] | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for serialization.

        The fitted model object itself is not exported; only the model
        artifact's name, signature and fit timestamp are included.
        """
        model_info = None
        if self.model_artifact:
            model_info = {
                "model_name": self.model_artifact.model_name,
                "signature": self.model_artifact.signature,
                "fit_timestamp": self.model_artifact.fit_timestamp,
            }
        return {
            "forecast": self.forecast.to_dict() if self.forecast else None,
            "plan": self.plan,
            "task_spec": self.task_spec,
            "plan_spec": self.plan_spec,
            "validation_report": self.validation_report,
            "backtest_report": self.backtest_report,
            "qa_report": self.qa_report,
            "model_artifact": model_info,
            "provenance": self.provenance.to_dict() if self.provenance else None,
            "calibration_artifact": self.calibration_artifact,
            "anomaly_report": self.anomaly_report,
            "metadata": self.metadata,
        }

    def _describe_plan(self) -> str:
        """Return a one-line description of ``plan`` for the summary."""
        plan = self.plan
        if plan is None:
            # No plan recorded: show the placeholder instead of "None".
            return "N/A"
        if not isinstance(plan, dict):
            return str(plan)

        candidates = plan.get("candidate_models")
        if candidates:
            chain = "->".join(candidates)
            return f"Plan({chain})"

        primary = plan.get("primary_model")
        if primary:
            fallback = plan.get("fallback_chain", [])
            chain = "->".join([primary] + list(fallback)) if fallback else primary
            return f"Plan({chain})"

        # No model names available; fall back to the signature or raw dict.
        return str(plan.get("signature") or plan)

    def summary(self) -> str:
        """Generate a human-readable summary.

        Returns:
            A multi-line string with the model name, plan description and
            forecast row count, followed by backtest and provenance details
            when those are present.
        """
        model_name = self.forecast.model_name if self.forecast else "N/A"
        forecast_rows = len(self.forecast.df) if self.forecast else 0

        lines = [
            "Run Artifact Summary",
            "=" * 40,
            f"Model: {model_name}",
            f"Plan: {self._describe_plan()}",
            f"Forecast rows: {forecast_rows}",
        ]

        if self.backtest_report:
            n_windows = self.backtest_report.get("n_windows")
            if n_windows is not None:
                lines.append(f"Backtest windows: {n_windows}")
            metrics = self.backtest_report.get("aggregate_metrics", {})
            if metrics:
                lines.append("Aggregate Metrics:")
                for name, value in sorted(metrics.items()):
                    try:
                        rendered = f"{value:.4f}"
                    except (TypeError, ValueError):
                        # Non-numeric metric (e.g. None or a string): show it
                        # verbatim rather than failing the whole summary.
                        rendered = str(value)
                    lines.append(f" {name}: {rendered}")

        if self.provenance:
            lines.append("\nProvenance:")
            lines.append(f" Data signature: {self.provenance.data_signature}")
            lines.append(f" Timestamp: {self.provenance.timestamp}")

        return "\n".join(lines)
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
# Names exported by ``from <module> import *``.
__all__ = [
    "CVFrame",
    "ForecastFrame",
    "ForecastResult",
    "ModelArtifact",
    "Provenance",
    "RepairReport",
    "RunArtifact",
    "ValidationReport",
]
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""Data validation schemas (compat wrapper).
|
|
2
|
+
|
|
3
|
+
This module preserves the stable API while keeping contracts free of
|
|
4
|
+
non-stdlib dependencies. The implementation lives in tsagentkit.series.validation.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from importlib import import_module
|
|
10
|
+
from typing import TYPE_CHECKING, Any
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from tsagentkit.contracts.results import ValidationReport
|
|
14
|
+
from tsagentkit.contracts.task_spec import PanelContract
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _impl():
    """Import and return the module that implements the validation logic.

    The import happens at call time, so importing this wrapper module does
    not by itself import ``tsagentkit.series.validation``.
    """
    module_name = "tsagentkit.series.validation"
    return import_module(module_name)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def normalize_panel_columns(
    df: Any,
    contract: PanelContract,
) -> tuple[Any, dict[str, str] | None]:
    """Normalize panel columns to the canonical contract names.

    Thin wrapper: both arguments are forwarded unchanged to
    ``normalize_panel_columns`` in the implementation module and its result
    is returned as-is.
    """
    validation = _impl()
    return validation.normalize_panel_columns(df, contract)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def validate_contract(
    data: Any,
    panel_contract: PanelContract | None = None,
    apply_aggregation: bool = False,
    return_data: bool = False,
) -> ValidationReport | tuple[ValidationReport, Any]:
    """Validate input data against the required schema.

    Thin wrapper: ``data`` is passed positionally and the remaining
    arguments by keyword to ``validate_contract`` in the implementation
    module; its result is returned as-is.
    """
    forwarded = {
        "panel_contract": panel_contract,
        "apply_aggregation": apply_aggregation,
        "return_data": return_data,
    }
    return _impl().validate_contract(data, **forwarded)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Names exported by ``from <module> import *``.
__all__ = ["validate_contract", "normalize_panel_columns"]
|