tsagentkit 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tsagentkit/__init__.py +126 -0
- tsagentkit/anomaly/__init__.py +130 -0
- tsagentkit/backtest/__init__.py +48 -0
- tsagentkit/backtest/engine.py +788 -0
- tsagentkit/backtest/metrics.py +244 -0
- tsagentkit/backtest/report.py +342 -0
- tsagentkit/calibration/__init__.py +136 -0
- tsagentkit/contracts/__init__.py +133 -0
- tsagentkit/contracts/errors.py +275 -0
- tsagentkit/contracts/results.py +418 -0
- tsagentkit/contracts/schema.py +44 -0
- tsagentkit/contracts/task_spec.py +300 -0
- tsagentkit/covariates/__init__.py +340 -0
- tsagentkit/eval/__init__.py +285 -0
- tsagentkit/features/__init__.py +20 -0
- tsagentkit/features/covariates.py +328 -0
- tsagentkit/features/extra/__init__.py +5 -0
- tsagentkit/features/extra/native.py +179 -0
- tsagentkit/features/factory.py +187 -0
- tsagentkit/features/matrix.py +159 -0
- tsagentkit/features/tsfeatures_adapter.py +115 -0
- tsagentkit/features/versioning.py +203 -0
- tsagentkit/hierarchy/__init__.py +39 -0
- tsagentkit/hierarchy/aggregation.py +62 -0
- tsagentkit/hierarchy/evaluator.py +400 -0
- tsagentkit/hierarchy/reconciliation.py +232 -0
- tsagentkit/hierarchy/structure.py +453 -0
- tsagentkit/models/__init__.py +182 -0
- tsagentkit/models/adapters/__init__.py +83 -0
- tsagentkit/models/adapters/base.py +321 -0
- tsagentkit/models/adapters/chronos.py +387 -0
- tsagentkit/models/adapters/moirai.py +256 -0
- tsagentkit/models/adapters/registry.py +171 -0
- tsagentkit/models/adapters/timesfm.py +440 -0
- tsagentkit/models/baselines.py +207 -0
- tsagentkit/models/sktime.py +307 -0
- tsagentkit/monitoring/__init__.py +51 -0
- tsagentkit/monitoring/alerts.py +302 -0
- tsagentkit/monitoring/coverage.py +203 -0
- tsagentkit/monitoring/drift.py +330 -0
- tsagentkit/monitoring/report.py +214 -0
- tsagentkit/monitoring/stability.py +275 -0
- tsagentkit/monitoring/triggers.py +423 -0
- tsagentkit/qa/__init__.py +347 -0
- tsagentkit/router/__init__.py +37 -0
- tsagentkit/router/bucketing.py +489 -0
- tsagentkit/router/fallback.py +132 -0
- tsagentkit/router/plan.py +23 -0
- tsagentkit/router/router.py +271 -0
- tsagentkit/series/__init__.py +26 -0
- tsagentkit/series/alignment.py +206 -0
- tsagentkit/series/dataset.py +449 -0
- tsagentkit/series/sparsity.py +261 -0
- tsagentkit/series/validation.py +393 -0
- tsagentkit/serving/__init__.py +39 -0
- tsagentkit/serving/orchestration.py +943 -0
- tsagentkit/serving/packaging.py +73 -0
- tsagentkit/serving/provenance.py +317 -0
- tsagentkit/serving/tsfm_cache.py +214 -0
- tsagentkit/skill/README.md +135 -0
- tsagentkit/skill/__init__.py +8 -0
- tsagentkit/skill/recipes.md +429 -0
- tsagentkit/skill/tool_map.md +21 -0
- tsagentkit/time/__init__.py +134 -0
- tsagentkit/utils/__init__.py +20 -0
- tsagentkit/utils/quantiles.py +83 -0
- tsagentkit/utils/signature.py +47 -0
- tsagentkit/utils/temporal.py +41 -0
- tsagentkit-1.0.2.dist-info/METADATA +371 -0
- tsagentkit-1.0.2.dist-info/RECORD +72 -0
- tsagentkit-1.0.2.dist-info/WHEEL +4 -0
- tsagentkit-1.0.2.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,300 @@
|
|
|
1
|
+
"""Pydantic specs for tsagentkit contracts and configuration.
|
|
2
|
+
|
|
3
|
+
These models are the JSON-serializable configuration and artifact contracts
|
|
4
|
+
used by agents and orchestration layers. They mirror docs/PRD.md Appendix B.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from datetime import datetime
|
|
10
|
+
from typing import Any, Literal
|
|
11
|
+
|
|
12
|
+
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
|
13
|
+
|
|
14
|
+
# ---------------------------
|
|
15
|
+
# Common
|
|
16
|
+
# ---------------------------
|
|
17
|
+
|
|
18
|
+
class BaseSpec(BaseModel):
    """Shared base class for the spec models in this module."""

    # extra="forbid": unknown input keys fail validation.
    # frozen=True: instances cannot be mutated through normal attribute assignment.
    model_config = ConfigDict(extra="forbid", frozen=True)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Role a covariate column can be given in CovariateSpec.roles.
CovariateRole = Literal["static", "past", "future_known"]
# Allowed values for PanelContract.aggregation.
AggregationMode = Literal["reject", "sum", "mean", "median", "last"]
# Allowed values for CovariateSpec.missing_policy.
MissingPolicy = Literal["error", "ffill", "bfill", "zero", "mean"]
# Allowed values for the interval_mode fields (ForecastContract, PlanSpec).
IntervalMode = Literal["level", "quantiles"]
# Allowed values for AnomalySpec.method.
AnomalyMethod = Literal["interval_breach", "conformal", "mad_residual"]
# Allowed values for RouterThresholds.seasonality_method.
SeasonalityMethod = Literal["acf", "stl", "periodogram"]
# Allowed values for TaskSpec.covariate_policy.
CovariatePolicy = Literal["ignore", "known", "observed", "auto", "spec"]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# ---------------------------
|
|
32
|
+
# Data contracts (column-level)
|
|
33
|
+
# ---------------------------
|
|
34
|
+
|
|
35
|
+
class PanelContract(BaseSpec):
    """Column-name contract for the input panel plus its aggregation mode."""

    # Name of the series-identifier column.
    unique_id_col: str = "unique_id"
    # Name of the timestamp column.
    ds_col: str = "ds"
    # Name of the target column.
    y_col: str = "y"
    # One of the AggregationMode values; defaults to "reject".
    aggregation: AggregationMode = "reject"
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class ForecastContract(BaseSpec):
    """Column-name and interval settings for forecast output."""

    long_format: bool = True
    model_col: str = "model"
    yhat_col: str = "yhat"
    cutoff_col: str = "cutoff"  # required for CV output
    # Selects between the "level" and "quantiles" interval settings below.
    interval_mode: IntervalMode = "level"
    # default_factory gives every instance its own list.
    levels: list[int] = Field(default_factory=lambda: [80, 95])
    quantiles: list[float] = Field(default_factory=lambda: [0.1, 0.5, 0.9])
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class CovariateSpec(BaseSpec):
    """Explicit covariate typing: column name -> role, plus a missing-value policy."""

    # Explicit typing strongly preferred for agent safety.
    # Maps covariate column name to "static", "past" or "future_known".
    roles: dict[str, CovariateRole] = Field(default_factory=dict)
    # One of the MissingPolicy values; defaults to "error".
    missing_policy: MissingPolicy = "error"
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# ---------------------------
|
|
59
|
+
# Task / execution specs
|
|
60
|
+
# ---------------------------
|
|
61
|
+
|
|
62
|
+
class BacktestSpec(BaseSpec):
    """Backtest settings: horizon, window count, step and minimum train size."""

    # Optional horizon; when given it must be positive. TaskSpec fills a
    # missing value with its own ``h`` after validation.
    h: int | None = Field(default=None, gt=0)
    # Number of backtest windows (positive).
    n_windows: int = Field(default=5, gt=0)
    # Step between windows (positive).
    step: int = Field(default=1, gt=0)
    # Minimum training size; must be greater than 1.
    min_train_size: int = Field(default=56, gt=1)
    regularize_grid: bool = True
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class TaskSpec(BaseSpec):
    """Forecasting task configuration: horizon, frequency, contracts, covariates.

    The legacy input keys ``horizon``, ``rolling_step``, ``quantiles`` and
    ``levels`` are accepted and mapped onto the current field layout before
    validation (see ``_normalize_inputs``).
    """

    # Forecast horizon
    h: int = Field(..., gt=0)

    # Frequency handling
    freq: str | None = None
    infer_freq: bool = True

    # Contracts
    panel_contract: PanelContract = Field(default_factory=PanelContract)
    forecast_contract: ForecastContract = Field(default_factory=ForecastContract)

    # Covariates
    covariates: CovariateSpec | None = None
    covariate_policy: CovariatePolicy = "auto"

    # Backtest defaults (can be overridden by the caller)
    backtest: BacktestSpec = Field(default_factory=BacktestSpec)

    @model_validator(mode="before")
    @classmethod
    def _normalize_inputs(cls, data: Any) -> Any:
        """Map backward-compatible aliases onto the current fields.

        Non-dict input is returned unchanged. The caller's dict is never
        mutated; a shallow copy is normalized and returned.

        * ``horizon`` becomes ``h`` when ``h`` is absent.
        * ``rolling_step`` becomes ``backtest.step`` unless a step was given
          explicitly; in that case the alias is left in place and is then
          rejected as an unknown field.
        * ``quantiles`` / ``levels`` are moved into ``forecast_contract``.
        """
        if not isinstance(data, dict):
            return data

        payload = dict(data)

        # Backward-compat aliases
        if "horizon" in payload and "h" not in payload:
            payload["h"] = payload.pop("horizon")
        if "rolling_step" in payload:
            backtest = payload.get("backtest", {})
            if isinstance(backtest, BacktestSpec):
                # A full dump always contains "step" (its default), which
                # would make the alias impossible to apply. Only treat the
                # step as given when it was set explicitly on the instance.
                step_was_set = "step" in backtest.model_fields_set
                backtest = backtest.model_dump()
                if not step_was_set:
                    backtest.pop("step", None)
            if isinstance(backtest, dict):
                backtest = dict(backtest)
                if "step" not in backtest:
                    backtest["step"] = payload.pop("rolling_step")
                payload["backtest"] = backtest

        # Legacy quantiles/levels mapping to forecast_contract
        if "quantiles" in payload or "levels" in payload:
            fc = payload.get("forecast_contract", {})
            if isinstance(fc, ForecastContract):
                fc = fc.model_dump()
            if isinstance(fc, dict):
                fc = dict(fc)
                if "quantiles" in payload:
                    fc["quantiles"] = payload.pop("quantiles")
                if "levels" in payload:
                    fc["levels"] = payload.pop("levels")
                payload["forecast_contract"] = fc

        return payload

    @model_validator(mode="after")
    def _apply_backtest_defaults(self) -> TaskSpec:
        """Fill ``backtest.h`` with the task horizon when it was left unset."""
        if self.backtest.h is None:
            # The model is frozen, so regular attribute assignment is not
            # available; write the replacement through object.__setattr__.
            object.__setattr__(self, "backtest", self.backtest.model_copy(update={"h": self.h}))
        return self

    @property
    def horizon(self) -> int:
        """Alias for ``h``."""
        return self.h

    @property
    def quantiles(self) -> list[float]:
        """Quantiles taken from ``forecast_contract``."""
        return self.forecast_contract.quantiles

    @property
    def levels(self) -> list[int]:
        """Levels taken from ``forecast_contract``."""
        return self.forecast_contract.levels

    @property
    def season_length(self) -> int | None:
        """Season length looked up from ``freq``; None when unknown."""
        return self._infer_season_length(self.freq)

    @staticmethod
    def _infer_season_length(freq: str | None) -> int | None:
        """Look up a default season length for a frequency alias.

        A leading integer multiple is ignored (``"2H"`` is treated as ``"H"``).
        When the remaining alias is not in the table, an anchor suffix is
        dropped and the lookup retried (``"W-SUN"`` -> ``"W"``).

        Returns None for a missing/empty ``freq`` or an unrecognized alias.
        """
        if not freq:
            return None
        freq_map: dict[str, int] = {
            "D": 7,
            "B": 5,
            "H": 24,
            "h": 24,
            "T": 60,
            "min": 60,
            "M": 12,
            "ME": 12,
            "MS": 12,
            "Q": 4,
            "QE": 4,
            "QS": 4,
            "W": 52,
        }
        base_freq = freq.lstrip("0123456789")
        if base_freq in freq_map:
            return freq_map[base_freq]
        # Anchored aliases such as "W-SUN" or "Q-DEC": retry without the anchor.
        return freq_map.get(base_freq.split("-", 1)[0])

    def model_hash(self) -> str:
        """Return a 16-hex-character fingerprint of this spec.

        The spec is dumped with None-valued fields excluded, serialized as
        compact JSON with sorted keys, hashed with SHA-256 and truncated.
        """
        import hashlib
        import json

        data = self.model_dump(exclude_none=True)
        json_str = json.dumps(data, sort_keys=True, separators=(",", ":"))
        return hashlib.sha256(json_str.encode()).hexdigest()[:16]
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
# ---------------------------
|
|
176
|
+
# Router / planning
|
|
177
|
+
# ---------------------------
|
|
178
|
+
|
|
179
|
+
class RouterThresholds(BaseSpec):
    """Numeric thresholds and settings used for routing decisions."""

    # Must be greater than 1.
    min_train_size: int = Field(default=56, gt=1)
    # Ratio constrained to [0, 1].
    max_missing_ratio: float = Field(default=0.15, ge=0.0, le=1.0)

    # Intermittency classification (heuristic, deterministic)
    max_intermittency_adi: float = Field(default=1.32, gt=0.0)
    max_intermittency_cv2: float = Field(default=0.49, ge=0.0)

    # Seasonality
    seasonality_method: SeasonalityMethod = "acf"
    # Confidence constrained to [0, 1].
    min_seasonality_conf: float = Field(default=0.70, ge=0.0, le=1.0)

    # Practical routing guardrails
    max_series_count_for_tsfm: int = Field(default=20000, gt=0)
    max_points_per_series_for_tsfm: int = Field(default=5000, gt=0)
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class PlanSpec(BaseSpec):
    """A named plan: candidate models plus covariate, training, output and fallback rules."""

    plan_name: str
    # Required; must contain at least one model name.
    candidate_models: list[str] = Field(min_length=1)

    # Covariate usage rules
    use_static: bool = True
    use_past: bool = True
    use_future_known: bool = True

    # Training policy
    min_train_size: int = Field(default=56, gt=1)
    max_train_size: int | None = None  # if set, truncate oldest points deterministically

    # Output policy
    interval_mode: IntervalMode = "level"
    levels: list[int] = Field(default_factory=lambda: [80, 95])
    quantiles: list[float] = Field(default_factory=lambda: [0.1, 0.5, 0.9])

    # Fallback policy
    allow_drop_covariates: bool = True
    allow_baseline: bool = True
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
class RouteDecision(BaseSpec):
    """Outcome of routing: the statistics, bucket tags, chosen plan and reasons."""

    # Series statistics used in routing (computed deterministically)
    stats: dict[str, Any] = Field(default_factory=dict)

    # Bucket tags
    buckets: list[str] = Field(default_factory=list)

    # Which plan template was selected (required field)
    selected_plan: PlanSpec

    # Human-readable deterministic reasons (safe for logs)
    reasons: list[str] = Field(default_factory=list)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
class RouterConfig(BaseSpec):
    """Router configuration: thresholds and the bucket-to-plan mapping."""

    # Defaults to a RouterThresholds built from its own defaults.
    thresholds: RouterThresholds = Field(default_factory=RouterThresholds)

    # Mapping bucket -> plan template name, resolved by registry
    bucket_to_plan: dict[str, str] = Field(default_factory=dict)

    # Default plan when no bucket matches
    default_plan: str = "default"
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
# ---------------------------
|
|
244
|
+
# Calibration + anomaly
|
|
245
|
+
# ---------------------------
|
|
246
|
+
|
|
247
|
+
class CalibratorSpec(BaseSpec):
    """Calibration settings: method, level and grouping."""

    method: Literal["none", "conformal"] = "conformal"
    # Integer level constrained to the range 50..99.
    level: int = Field(default=99, ge=50, le=99)
    # Grouping key: "unique_id" or "global".
    by: Literal["unique_id", "global"] = "unique_id"
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
class AnomalySpec(BaseSpec):
    """Anomaly-detection settings: method, level and score type."""

    method: AnomalyMethod = "conformal"
    # Integer level constrained to the range 50..99.
    level: int = Field(default=99, ge=50, le=99)
    score: Literal["margin", "normalized_margin", "zscore"] = "normalized_margin"
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
# ---------------------------
|
|
260
|
+
# Provenance artifacts (config-level, serializable)
|
|
261
|
+
# ---------------------------
|
|
262
|
+
|
|
263
|
+
class RunArtifactSpec(BaseSpec):
    """Serializable record of one run: identity, configuration, signatures and outputs."""

    # Required run identity and creation timestamp.
    run_id: str
    created_at: datetime

    # The task spec is required; router config and decision are optional.
    task_spec: TaskSpec
    router_config: RouterConfig | None = None
    route_decision: RouteDecision | None = None

    # Identifiers / hashes for reproducibility (implementation-defined)
    data_signature: str | None = None
    code_signature: str | None = None

    # Output references (implementation-defined; typically file paths or object-store keys)
    outputs: dict[str, str] = Field(default_factory=dict)
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
__all__ = [
|
|
280
|
+
"AggregationMode",
|
|
281
|
+
"AnomalyMethod",
|
|
282
|
+
"AnomalySpec",
|
|
283
|
+
"BacktestSpec",
|
|
284
|
+
"BaseSpec",
|
|
285
|
+
"CalibratorSpec",
|
|
286
|
+
"CovariatePolicy",
|
|
287
|
+
"CovariateRole",
|
|
288
|
+
"CovariateSpec",
|
|
289
|
+
"ForecastContract",
|
|
290
|
+
"IntervalMode",
|
|
291
|
+
"MissingPolicy",
|
|
292
|
+
"PanelContract",
|
|
293
|
+
"PlanSpec",
|
|
294
|
+
"RouteDecision",
|
|
295
|
+
"RouterConfig",
|
|
296
|
+
"RouterThresholds",
|
|
297
|
+
"RunArtifactSpec",
|
|
298
|
+
"SeasonalityMethod",
|
|
299
|
+
"TaskSpec",
|
|
300
|
+
]
|
|
@@ -0,0 +1,340 @@
|
|
|
1
|
+
"""Covariate typing, alignment, and guardrails."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
|
|
7
|
+
import pandas as pd
|
|
8
|
+
|
|
9
|
+
from tsagentkit.contracts import (
|
|
10
|
+
CovariateSpec,
|
|
11
|
+
ECovariateIncompleteKnown,
|
|
12
|
+
ECovariateLeakage,
|
|
13
|
+
ECovariateStaticInvalid,
|
|
14
|
+
ETaskSpecInvalid,
|
|
15
|
+
TaskSpec,
|
|
16
|
+
)
|
|
17
|
+
from tsagentkit.time import make_future_index
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass(frozen=True)
class CovariateBundle:
    """Optional, caller-supplied covariate frames, one per role."""

    # Each frame is optional; align_covariates validates whichever are given.
    static_x: pd.DataFrame | None = None
    past_x: pd.DataFrame | None = None
    future_x: pd.DataFrame | None = None
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass(frozen=True)
class AlignedDataset:
    """Result of align_covariates: the base panel plus covariate frames by role."""

    # Panel restricted to the id / timestamp / target columns.
    panel: pd.DataFrame
    # Covariate frames by role; None when there is nothing for that role.
    static_x: pd.DataFrame | None
    past_x: pd.DataFrame | None
    future_x: pd.DataFrame | None
    # Covariate spec taken from the task spec (may be None).
    covariate_spec: CovariateSpec | None
    # Frame produced by make_future_index for the forecast horizon.
    future_index: pd.DataFrame | None
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def align_covariates(
    panel: pd.DataFrame,
    task_spec: TaskSpec,
    covariates: CovariateBundle | None = None,
) -> AlignedDataset:
    """Align covariates and enforce coverage/leakage guardrails.

    When ``covariates`` is given, its frames are validated against the future
    index and returned as-is (copied); extra panel columns are not inspected.
    Otherwise every panel column other than the id, timestamp and target
    columns is treated as a covariate and assigned a role according to
    ``task_spec.covariate_policy``:

    * ``"ignore"`` (or no covariate columns): no covariates are returned.
    * ``"spec"``: roles come from ``task_spec.covariates``.
    * ``"known"``: all covariates are future-known.
    * ``"observed"``: all covariates are past-only.
    * ``"auto"``: a column with any value inside the horizon is future-known,
      otherwise past-only.

    Args:
        panel: Input panel; it is copied, never mutated.
        task_spec: Task configuration supplying column names, horizon,
            frequency, policy and optional covariate spec.
        covariates: Optional explicit covariate frames.

    Returns:
        An ``AlignedDataset`` with the base panel, per-role covariate frames,
        the covariate spec and the future index.

    Raises:
        ETaskSpecInvalid: policy ``"spec"`` without usable role definitions.
        ECovariateIncompleteKnown: a future-known covariate has gaps in the
            horizon.
        ECovariateLeakage: a past covariate has values in the horizon.
        ECovariateStaticInvalid: a static covariate is not constant per series.
    """
    contract = task_spec.panel_contract
    uid_col = contract.unique_id_col
    ds_col = contract.ds_col
    y_col = contract.y_col

    panel = panel.copy()
    future_index = make_future_index(panel, task_spec.h, task_spec.freq, uid_col, ds_col, y_col)

    # Explicit bundle: validate each frame and return early.
    if covariates is not None:
        static_x = _validate_static_covariates(covariates.static_x, uid_col)
        past_x = _validate_past_covariates(covariates.past_x, future_index, uid_col, ds_col)
        future_x = _validate_future_covariates(
            covariates.future_x, future_index, uid_col, ds_col
        )
        return AlignedDataset(
            panel=_panel_base(panel, uid_col, ds_col, y_col),
            static_x=static_x,
            past_x=past_x,
            future_x=future_x,
            covariate_spec=task_spec.covariates,
            future_index=future_index,
        )

    covariate_cols = [c for c in panel.columns if c not in {uid_col, ds_col, y_col}]
    if not covariate_cols or task_spec.covariate_policy == "ignore":
        return AlignedDataset(
            panel=_panel_base(panel, uid_col, ds_col, y_col),
            static_x=None,
            past_x=None,
            future_x=None,
            covariate_spec=task_spec.covariates,
            future_index=future_index,
        )

    static_cols: list[str] = []
    past_cols: list[str] = []
    future_cols: list[str] = []

    policy = task_spec.covariate_policy
    spec = task_spec.covariates

    if policy == "spec" and spec is None and covariate_cols:
        raise ETaskSpecInvalid(
            "covariate_policy='spec' requires task_spec.covariates.",
            context={"missing": "covariates"},
        )

    if policy == "spec" and spec is not None:
        _validate_spec_roles(spec, covariate_cols)
        for col, role in spec.roles.items():
            if role == "static":
                static_cols.append(col)
            elif role == "past":
                past_cols.append(col)
            elif role == "future_known":
                future_cols.append(col)
    elif policy == "known":
        future_cols = covariate_cols
    elif policy == "observed":
        past_cols = covariate_cols
    elif policy == "auto":
        for col in covariate_cols:
            if _has_future_values(panel, future_index, col, uid_col, ds_col):
                # Candidate future-known: full coverage is enforced once for
                # all future columns in the loop below.
                future_cols.append(col)
            else:
                past_cols.append(col)
    else:
        past_cols = covariate_cols

    for col in future_cols:
        _enforce_future_coverage(panel, future_index, col, uid_col, ds_col)

    static_x = _extract_static(panel, uid_col, static_cols)
    past_x = _extract_time_covariates(panel, uid_col, ds_col, past_cols)
    future_x = _extract_future_covariates(panel, future_index, uid_col, ds_col, future_cols)

    _enforce_past_leakage(panel, future_index, uid_col, ds_col, past_cols)

    return AlignedDataset(
        panel=_panel_base(panel, uid_col, ds_col, y_col),
        static_x=static_x,
        past_x=past_x,
        future_x=future_x,
        covariate_spec=spec,
        future_index=future_index,
    )
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def _panel_base(panel: pd.DataFrame, uid_col: str, ds_col: str, y_col: str) -> pd.DataFrame:
|
|
134
|
+
cols = [c for c in [uid_col, ds_col, y_col] if c in panel.columns]
|
|
135
|
+
return panel[cols].copy()
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _validate_spec_roles(spec: CovariateSpec, covariate_cols: list[str]) -> None:
|
|
139
|
+
if not spec.roles:
|
|
140
|
+
if covariate_cols:
|
|
141
|
+
raise ETaskSpecInvalid(
|
|
142
|
+
"covariate_policy='spec' requires explicit roles for all covariate columns.",
|
|
143
|
+
context={"missing_roles_for": sorted(covariate_cols)},
|
|
144
|
+
)
|
|
145
|
+
return
|
|
146
|
+
|
|
147
|
+
missing_in_panel = [col for col in spec.roles if col not in covariate_cols]
|
|
148
|
+
if missing_in_panel:
|
|
149
|
+
raise ETaskSpecInvalid(
|
|
150
|
+
"CovariateSpec roles include columns not present in panel data.",
|
|
151
|
+
context={"missing_in_panel": sorted(missing_in_panel)},
|
|
152
|
+
)
|
|
153
|
+
|
|
154
|
+
extra_in_panel = [col for col in covariate_cols if col not in spec.roles]
|
|
155
|
+
if extra_in_panel:
|
|
156
|
+
raise ETaskSpecInvalid(
|
|
157
|
+
"covariate_policy='spec' requires roles for all panel covariates.",
|
|
158
|
+
context={"missing_roles_for": sorted(extra_in_panel)},
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _has_future_values(
|
|
163
|
+
panel: pd.DataFrame,
|
|
164
|
+
future_index: pd.DataFrame,
|
|
165
|
+
col: str,
|
|
166
|
+
uid_col: str,
|
|
167
|
+
ds_col: str,
|
|
168
|
+
) -> bool:
|
|
169
|
+
merged = future_index.merge(
|
|
170
|
+
panel[[uid_col, ds_col, col]],
|
|
171
|
+
on=[uid_col, ds_col],
|
|
172
|
+
how="left",
|
|
173
|
+
)
|
|
174
|
+
return merged[col].notna().any()
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def _enforce_future_coverage(
|
|
178
|
+
panel: pd.DataFrame,
|
|
179
|
+
future_index: pd.DataFrame,
|
|
180
|
+
col: str,
|
|
181
|
+
uid_col: str,
|
|
182
|
+
ds_col: str,
|
|
183
|
+
) -> None:
|
|
184
|
+
merged = future_index.merge(
|
|
185
|
+
panel[[uid_col, ds_col, col]],
|
|
186
|
+
on=[uid_col, ds_col],
|
|
187
|
+
how="left",
|
|
188
|
+
)
|
|
189
|
+
if merged[col].isna().any():
|
|
190
|
+
missing = int(merged[col].isna().sum())
|
|
191
|
+
raise ECovariateIncompleteKnown(
|
|
192
|
+
f"Future-known covariate '{col}' missing {missing} values in horizon.",
|
|
193
|
+
context={"covariate": col, "missing": missing},
|
|
194
|
+
)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def _enforce_past_leakage(
|
|
198
|
+
panel: pd.DataFrame,
|
|
199
|
+
future_index: pd.DataFrame,
|
|
200
|
+
uid_col: str,
|
|
201
|
+
ds_col: str,
|
|
202
|
+
past_cols: list[str],
|
|
203
|
+
) -> None:
|
|
204
|
+
if not past_cols:
|
|
205
|
+
return
|
|
206
|
+
|
|
207
|
+
merged = future_index.merge(
|
|
208
|
+
panel[[uid_col, ds_col] + past_cols],
|
|
209
|
+
on=[uid_col, ds_col],
|
|
210
|
+
how="left",
|
|
211
|
+
)
|
|
212
|
+
for col in past_cols:
|
|
213
|
+
if merged[col].notna().any():
|
|
214
|
+
count = int(merged[col].notna().sum())
|
|
215
|
+
raise ECovariateLeakage(
|
|
216
|
+
f"Past covariate '{col}' has {count} values in forecast horizon.",
|
|
217
|
+
context={"covariate": col, "future_values_count": count},
|
|
218
|
+
)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def _extract_static(panel: pd.DataFrame, uid_col: str, cols: list[str]) -> pd.DataFrame | None:
|
|
222
|
+
if not cols:
|
|
223
|
+
return None
|
|
224
|
+
df = panel[[uid_col] + cols].dropna(subset=[uid_col]).copy()
|
|
225
|
+
# Validate constant per unique_id
|
|
226
|
+
for col in cols:
|
|
227
|
+
counts = df.groupby(uid_col)[col].nunique(dropna=True)
|
|
228
|
+
if (counts > 1).any():
|
|
229
|
+
bad = counts[counts > 1].index.tolist()[:5]
|
|
230
|
+
raise ECovariateStaticInvalid(
|
|
231
|
+
f"Static covariate '{col}' varies within series.",
|
|
232
|
+
context={"covariate": col, "unique_id_examples": bad},
|
|
233
|
+
)
|
|
234
|
+
return df.groupby(uid_col, as_index=False).first()
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def _extract_time_covariates(
|
|
238
|
+
panel: pd.DataFrame,
|
|
239
|
+
uid_col: str,
|
|
240
|
+
ds_col: str,
|
|
241
|
+
cols: list[str],
|
|
242
|
+
) -> pd.DataFrame | None:
|
|
243
|
+
if not cols:
|
|
244
|
+
return None
|
|
245
|
+
return panel[[uid_col, ds_col] + cols].copy()
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def _extract_future_covariates(
|
|
249
|
+
panel: pd.DataFrame,
|
|
250
|
+
future_index: pd.DataFrame,
|
|
251
|
+
uid_col: str,
|
|
252
|
+
ds_col: str,
|
|
253
|
+
cols: list[str],
|
|
254
|
+
) -> pd.DataFrame | None:
|
|
255
|
+
if not cols:
|
|
256
|
+
return None
|
|
257
|
+
merged = future_index.merge(
|
|
258
|
+
panel[[uid_col, ds_col] + cols],
|
|
259
|
+
on=[uid_col, ds_col],
|
|
260
|
+
how="left",
|
|
261
|
+
)
|
|
262
|
+
for col in cols:
|
|
263
|
+
if merged[col].isna().any():
|
|
264
|
+
missing = int(merged[col].isna().sum())
|
|
265
|
+
raise ECovariateIncompleteKnown(
|
|
266
|
+
f"Future-known covariate '{col}' missing {missing} values in horizon.",
|
|
267
|
+
context={"covariate": col, "missing": missing},
|
|
268
|
+
)
|
|
269
|
+
return merged
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def _validate_static_covariates(
|
|
273
|
+
static_x: pd.DataFrame | None,
|
|
274
|
+
uid_col: str,
|
|
275
|
+
) -> pd.DataFrame | None:
|
|
276
|
+
if static_x is None or static_x.empty:
|
|
277
|
+
return None
|
|
278
|
+
counts = static_x.groupby(uid_col).size()
|
|
279
|
+
if (counts != 1).any():
|
|
280
|
+
bad = counts[counts != 1].index.tolist()[:5]
|
|
281
|
+
raise ECovariateStaticInvalid(
|
|
282
|
+
"Static covariates must have exactly one row per unique_id.",
|
|
283
|
+
context={"unique_id_examples": bad},
|
|
284
|
+
)
|
|
285
|
+
return static_x.copy()
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def _validate_past_covariates(
|
|
289
|
+
past_x: pd.DataFrame | None,
|
|
290
|
+
future_index: pd.DataFrame,
|
|
291
|
+
uid_col: str,
|
|
292
|
+
ds_col: str,
|
|
293
|
+
) -> pd.DataFrame | None:
|
|
294
|
+
if past_x is None or past_x.empty:
|
|
295
|
+
return None
|
|
296
|
+
merged = future_index.merge(
|
|
297
|
+
past_x,
|
|
298
|
+
on=[uid_col, ds_col],
|
|
299
|
+
how="left",
|
|
300
|
+
)
|
|
301
|
+
covariate_cols = [c for c in past_x.columns if c not in {uid_col, ds_col}]
|
|
302
|
+
for col in covariate_cols:
|
|
303
|
+
if merged[col].notna().any():
|
|
304
|
+
count = int(merged[col].notna().sum())
|
|
305
|
+
raise ECovariateLeakage(
|
|
306
|
+
f"Past covariate '{col}' has {count} values in forecast horizon.",
|
|
307
|
+
context={"covariate": col, "future_values_count": count},
|
|
308
|
+
)
|
|
309
|
+
return past_x.copy()
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def _validate_future_covariates(
|
|
313
|
+
future_x: pd.DataFrame | None,
|
|
314
|
+
future_index: pd.DataFrame,
|
|
315
|
+
uid_col: str,
|
|
316
|
+
ds_col: str,
|
|
317
|
+
) -> pd.DataFrame | None:
|
|
318
|
+
if future_x is None or future_x.empty:
|
|
319
|
+
return None
|
|
320
|
+
merged = future_index.merge(
|
|
321
|
+
future_x,
|
|
322
|
+
on=[uid_col, ds_col],
|
|
323
|
+
how="left",
|
|
324
|
+
)
|
|
325
|
+
covariate_cols = [c for c in future_x.columns if c not in {uid_col, ds_col}]
|
|
326
|
+
for col in covariate_cols:
|
|
327
|
+
if merged[col].isna().any():
|
|
328
|
+
missing = int(merged[col].isna().sum())
|
|
329
|
+
raise ECovariateIncompleteKnown(
|
|
330
|
+
f"Future-known covariate '{col}' missing {missing} values in horizon.",
|
|
331
|
+
context={"covariate": col, "missing": missing},
|
|
332
|
+
)
|
|
333
|
+
return future_x.copy()
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
__all__ = [
|
|
337
|
+
"AlignedDataset",
|
|
338
|
+
"CovariateBundle",
|
|
339
|
+
"align_covariates",
|
|
340
|
+
]
|