tsagentkit 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72) hide show
  1. tsagentkit/__init__.py +126 -0
  2. tsagentkit/anomaly/__init__.py +130 -0
  3. tsagentkit/backtest/__init__.py +48 -0
  4. tsagentkit/backtest/engine.py +788 -0
  5. tsagentkit/backtest/metrics.py +244 -0
  6. tsagentkit/backtest/report.py +342 -0
  7. tsagentkit/calibration/__init__.py +136 -0
  8. tsagentkit/contracts/__init__.py +133 -0
  9. tsagentkit/contracts/errors.py +275 -0
  10. tsagentkit/contracts/results.py +418 -0
  11. tsagentkit/contracts/schema.py +44 -0
  12. tsagentkit/contracts/task_spec.py +300 -0
  13. tsagentkit/covariates/__init__.py +340 -0
  14. tsagentkit/eval/__init__.py +285 -0
  15. tsagentkit/features/__init__.py +20 -0
  16. tsagentkit/features/covariates.py +328 -0
  17. tsagentkit/features/extra/__init__.py +5 -0
  18. tsagentkit/features/extra/native.py +179 -0
  19. tsagentkit/features/factory.py +187 -0
  20. tsagentkit/features/matrix.py +159 -0
  21. tsagentkit/features/tsfeatures_adapter.py +115 -0
  22. tsagentkit/features/versioning.py +203 -0
  23. tsagentkit/hierarchy/__init__.py +39 -0
  24. tsagentkit/hierarchy/aggregation.py +62 -0
  25. tsagentkit/hierarchy/evaluator.py +400 -0
  26. tsagentkit/hierarchy/reconciliation.py +232 -0
  27. tsagentkit/hierarchy/structure.py +453 -0
  28. tsagentkit/models/__init__.py +182 -0
  29. tsagentkit/models/adapters/__init__.py +83 -0
  30. tsagentkit/models/adapters/base.py +321 -0
  31. tsagentkit/models/adapters/chronos.py +387 -0
  32. tsagentkit/models/adapters/moirai.py +256 -0
  33. tsagentkit/models/adapters/registry.py +171 -0
  34. tsagentkit/models/adapters/timesfm.py +440 -0
  35. tsagentkit/models/baselines.py +207 -0
  36. tsagentkit/models/sktime.py +307 -0
  37. tsagentkit/monitoring/__init__.py +51 -0
  38. tsagentkit/monitoring/alerts.py +302 -0
  39. tsagentkit/monitoring/coverage.py +203 -0
  40. tsagentkit/monitoring/drift.py +330 -0
  41. tsagentkit/monitoring/report.py +214 -0
  42. tsagentkit/monitoring/stability.py +275 -0
  43. tsagentkit/monitoring/triggers.py +423 -0
  44. tsagentkit/qa/__init__.py +347 -0
  45. tsagentkit/router/__init__.py +37 -0
  46. tsagentkit/router/bucketing.py +489 -0
  47. tsagentkit/router/fallback.py +132 -0
  48. tsagentkit/router/plan.py +23 -0
  49. tsagentkit/router/router.py +271 -0
  50. tsagentkit/series/__init__.py +26 -0
  51. tsagentkit/series/alignment.py +206 -0
  52. tsagentkit/series/dataset.py +449 -0
  53. tsagentkit/series/sparsity.py +261 -0
  54. tsagentkit/series/validation.py +393 -0
  55. tsagentkit/serving/__init__.py +39 -0
  56. tsagentkit/serving/orchestration.py +943 -0
  57. tsagentkit/serving/packaging.py +73 -0
  58. tsagentkit/serving/provenance.py +317 -0
  59. tsagentkit/serving/tsfm_cache.py +214 -0
  60. tsagentkit/skill/README.md +135 -0
  61. tsagentkit/skill/__init__.py +8 -0
  62. tsagentkit/skill/recipes.md +429 -0
  63. tsagentkit/skill/tool_map.md +21 -0
  64. tsagentkit/time/__init__.py +134 -0
  65. tsagentkit/utils/__init__.py +20 -0
  66. tsagentkit/utils/quantiles.py +83 -0
  67. tsagentkit/utils/signature.py +47 -0
  68. tsagentkit/utils/temporal.py +41 -0
  69. tsagentkit-1.0.2.dist-info/METADATA +371 -0
  70. tsagentkit-1.0.2.dist-info/RECORD +72 -0
  71. tsagentkit-1.0.2.dist-info/WHEEL +4 -0
  72. tsagentkit-1.0.2.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,300 @@
1
+ """Pydantic specs for tsagentkit contracts and configuration.
2
+
3
+ These models are the JSON-serializable configuration and artifact contracts
4
+ used by agents and orchestration layers. They mirror docs/PRD.md Appendix B.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from datetime import datetime
10
+ from typing import Any, Literal
11
+
12
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
13
+
14
+ # ---------------------------
15
+ # Common
16
+ # ---------------------------
17
+
18
class BaseSpec(BaseModel):
    """Common base class for every spec model in this module.

    The shared config rejects unknown fields (``extra="forbid"``) and makes
    instances immutable after construction (``frozen=True``).
    """

    model_config = ConfigDict(extra="forbid", frozen=True)
20
+
21
+
22
# Closed sets of string options used as field types by the spec models below.
# Pydantic rejects any value outside the listed literals.
CovariateRole = Literal["static", "past", "future_known"]
AggregationMode = Literal["reject", "sum", "mean", "median", "last"]
MissingPolicy = Literal["error", "ffill", "bfill", "zero", "mean"]
IntervalMode = Literal["level", "quantiles"]
AnomalyMethod = Literal["interval_breach", "conformal", "mad_residual"]
SeasonalityMethod = Literal["acf", "stl", "periodogram"]
CovariatePolicy = Literal["ignore", "known", "observed", "auto", "spec"]
29
+
30
+
31
+ # ---------------------------
32
+ # Data contracts (column-level)
33
+ # ---------------------------
34
+
35
class PanelContract(BaseSpec):
    """Column-name contract for the input panel data.

    The defaults name the series id, timestamp and target columns
    ``unique_id`` / ``ds`` / ``y``.
    """

    unique_id_col: str = "unique_id"
    ds_col: str = "ds"
    y_col: str = "y"
    # One of AggregationMode; defaults to "reject".
    # NOTE(review): presumably governs how duplicate rows are handled -
    # confirm against the consumer of this field.
    aggregation: AggregationMode = "reject"
40
+
41
+
42
class ForecastContract(BaseSpec):
    """Column-name and interval contract for forecast output.

    ``levels`` and ``quantiles`` use default factories so every instance gets
    its own list.
    """

    long_format: bool = True
    model_col: str = "model"
    yhat_col: str = "yhat"
    cutoff_col: str = "cutoff"  # required for CV output
    # Selects between the ``levels`` and ``quantiles`` style of interval.
    interval_mode: IntervalMode = "level"
    levels: list[int] = Field(default_factory=lambda: [80, 95])
    quantiles: list[float] = Field(default_factory=lambda: [0.1, 0.5, 0.9])
50
+
51
+
52
class CovariateSpec(BaseSpec):
    """Explicit covariate typing: maps column name to its role."""

    # Explicit typing strongly preferred for agent safety.
    # Keys are covariate column names; values are one of CovariateRole.
    roles: dict[str, CovariateRole] = Field(default_factory=dict)
    # One of MissingPolicy; defaults to "error".
    missing_policy: MissingPolicy = "error"
56
+
57
+
58
+ # ---------------------------
59
+ # Task / execution specs
60
+ # ---------------------------
61
+
62
class BacktestSpec(BaseSpec):
    """Backtest settings; all integer fields are validated as positive."""

    # Backtest horizon. When left as None inside a TaskSpec, TaskSpec's
    # after-validator fills it with the task horizon ``h``.
    h: int | None = Field(None, gt=0)
    n_windows: int = Field(5, gt=0)
    step: int = Field(1, gt=0)
    # Must be greater than 1 (``gt=1``).
    min_train_size: int = Field(56, gt=1)
    regularize_grid: bool = True
68
+
69
+
70
class TaskSpec(BaseSpec):
    """Forecasting task configuration: horizon, frequency, contracts, covariates.

    Legacy keyword aliases (``horizon``, ``rolling_step``, ``quantiles``,
    ``levels``) are remapped onto the current fields before validation by
    ``_normalize_inputs``. After validation, ``backtest.h`` defaults to the
    task horizon ``h`` when it was not given explicitly.
    """

    # Forecast horizon
    h: int = Field(..., gt=0)

    # Frequency handling
    freq: str | None = None
    infer_freq: bool = True

    # Contracts
    panel_contract: PanelContract = Field(default_factory=PanelContract)
    forecast_contract: ForecastContract = Field(default_factory=ForecastContract)

    # Covariates
    covariates: CovariateSpec | None = None
    covariate_policy: CovariatePolicy = "auto"

    # Backtest defaults (can be overridden by the caller)
    backtest: BacktestSpec = Field(default_factory=BacktestSpec)

    @model_validator(mode="before")
    @classmethod
    def _normalize_inputs(cls, data: Any) -> Any:
        """Map backward-compatible input keys onto the current schema.

        Non-dict input is returned untouched. The incoming dict is copied, so
        the caller's object is never mutated. An alias that cannot be applied
        (for example ``horizon`` given together with ``h``) is left in the
        payload and is then rejected by ``extra="forbid"``.
        """
        if not isinstance(data, dict):
            return data

        payload = dict(data)

        # Backward-compat aliases
        if "horizon" in payload and "h" not in payload:
            payload["h"] = payload.pop("horizon")
        if "rolling_step" in payload:
            backtest = payload.get("backtest", {})
            if isinstance(backtest, BacktestSpec):
                backtest = backtest.model_dump()
            if isinstance(backtest, dict):
                backtest = dict(backtest)
                # An explicit backtest.step wins; the alias then stays in the
                # payload and fails validation as an unknown field.
                if "step" not in backtest:
                    backtest["step"] = payload.pop("rolling_step")
                payload["backtest"] = backtest

        # Legacy quantiles/levels mapping to forecast_contract
        if "quantiles" in payload or "levels" in payload:
            fc = payload.get("forecast_contract", {})
            if isinstance(fc, ForecastContract):
                fc = fc.model_dump()
            if isinstance(fc, dict):
                fc = dict(fc)
                if "quantiles" in payload:
                    fc["quantiles"] = payload.pop("quantiles")
                if "levels" in payload:
                    fc["levels"] = payload.pop("levels")
                payload["forecast_contract"] = fc

        return payload

    @model_validator(mode="after")
    def _apply_backtest_defaults(self) -> TaskSpec:
        """Default ``backtest.h`` to the task horizon when it is unset."""
        if self.backtest.h is None:
            # The model is frozen, so bypass pydantic's __setattr__ guard.
            object.__setattr__(self, "backtest", self.backtest.model_copy(update={"h": self.h}))
        return self

    @property
    def horizon(self) -> int:
        """Read-only alias for ``h``."""
        return self.h

    @property
    def quantiles(self) -> list[float]:
        """Quantiles configured on ``forecast_contract``."""
        return self.forecast_contract.quantiles

    @property
    def levels(self) -> list[int]:
        """Interval levels configured on ``forecast_contract``."""
        return self.forecast_contract.levels

    @property
    def season_length(self) -> int | None:
        """Season length looked up from ``freq``, or None if unknown."""
        return self._infer_season_length(self.freq)

    @staticmethod
    def _infer_season_length(freq: str | None) -> int | None:
        """Return the season length for a frequency string, or None.

        A leading integer multiple (``"2D"``) and a trailing anchor
        (``"W-SUN"``, ``"Q-DEC"``) are ignored. Both the legacy pandas
        aliases (``H``, ``M``, ``Q``, ``T``) and their current replacements
        (``h``, ``ME``, ``QE``, ``min``) are recognised.
        """
        if not freq:
            return None
        freq_map: dict[str, int] = {
            "D": 7,
            "B": 5,
            "H": 24,
            "h": 24,
            "T": 60,
            "min": 60,
            "M": 12,
            "ME": 12,
            "MS": 12,
            "Q": 4,
            "QE": 4,
            "QS": 4,
            "W": 52,
        }
        base_freq = freq.lstrip("0123456789")
        # Drop an anchor suffix such as "-SUN" or "-DEC"; no mapped key has "-".
        base_freq = base_freq.partition("-")[0]
        return freq_map.get(base_freq)

    def model_hash(self) -> str:
        """Return a 16-hex-character SHA-256 digest of the spec.

        The spec is dumped with ``None`` fields excluded and serialised as
        compact JSON with sorted keys before hashing.
        """
        import hashlib
        import json

        data = self.model_dump(exclude_none=True)
        json_str = json.dumps(data, sort_keys=True, separators=(",", ":"))
        return hashlib.sha256(json_str.encode()).hexdigest()[:16]
173
+
174
+
175
+ # ---------------------------
176
+ # Router / planning
177
+ # ---------------------------
178
+
179
class RouterThresholds(BaseSpec):
    """Numeric thresholds consulted by the router; each field is range-validated."""

    min_train_size: int = Field(56, gt=1)
    # Ratio constrained to [0, 1].
    max_missing_ratio: float = Field(0.15, ge=0.0, le=1.0)

    # Intermittency classification (heuristic, deterministic)
    max_intermittency_adi: float = Field(1.32, gt=0.0)
    max_intermittency_cv2: float = Field(0.49, ge=0.0)

    # Seasonality
    seasonality_method: SeasonalityMethod = "acf"
    # Confidence constrained to [0, 1].
    min_seasonality_conf: float = Field(0.70, ge=0.0, le=1.0)

    # Practical routing guardrails
    max_series_count_for_tsfm: int = Field(20000, gt=0)
    max_points_per_series_for_tsfm: int = Field(5000, gt=0)
194
+
195
+
196
class PlanSpec(BaseSpec):
    """A named execution plan: candidate models plus covariate, training,
    output and fallback policies.
    """

    plan_name: str
    # Required; must contain at least one model name.
    candidate_models: list[str] = Field(..., min_length=1)

    # Covariate usage rules
    use_static: bool = True
    use_past: bool = True
    use_future_known: bool = True

    # Training policy
    min_train_size: int = Field(56, gt=1)
    max_train_size: int | None = None  # if set, truncate oldest points deterministically

    # Output policy
    interval_mode: IntervalMode = "level"
    levels: list[int] = Field(default_factory=lambda: [80, 95])
    quantiles: list[float] = Field(default_factory=lambda: [0.1, 0.5, 0.9])

    # Fallback policy
    allow_drop_covariates: bool = True
    allow_baseline: bool = True
217
+
218
+
219
class RouteDecision(BaseSpec):
    """Record of a routing outcome: inputs considered, plan chosen, and why."""

    # Series statistics used in routing (computed deterministically)
    stats: dict[str, Any] = Field(default_factory=dict)

    # Bucket tags
    buckets: list[str] = Field(default_factory=list)

    # Which plan template was selected
    selected_plan: PlanSpec

    # Human-readable deterministic reasons (safe for logs)
    reasons: list[str] = Field(default_factory=list)
231
+
232
+
233
class RouterConfig(BaseSpec):
    """Router configuration: thresholds plus the bucket-to-plan mapping."""

    thresholds: RouterThresholds = Field(default_factory=RouterThresholds)

    # Mapping bucket -> plan template name, resolved by registry
    bucket_to_plan: dict[str, str] = Field(default_factory=dict)

    # Default plan when no bucket matches
    default_plan: str = "default"
241
+
242
+
243
+ # ---------------------------
244
+ # Calibration + anomaly
245
+ # ---------------------------
246
+
247
class CalibratorSpec(BaseSpec):
    """Calibration settings."""

    # "none" or "conformal"; defaults to "conformal".
    method: Literal["none", "conformal"] = "conformal"
    # Integer level validated to lie in [50, 99]; defaults to 99.
    level: int = Field(99, ge=50, le=99)
    # Grouping key: "unique_id" or "global".
    by: Literal["unique_id", "global"] = "unique_id"
251
+
252
+
253
class AnomalySpec(BaseSpec):
    """Anomaly-detection settings."""

    # One of AnomalyMethod; defaults to "conformal".
    method: AnomalyMethod = "conformal"
    # Integer level validated to lie in [50, 99]; defaults to 99.
    level: int = Field(99, ge=50, le=99)
    # Name of the score to report; defaults to "normalized_margin".
    score: Literal["margin", "normalized_margin", "zscore"] = "normalized_margin"
257
+
258
+
259
+ # ---------------------------
260
+ # Provenance artifacts (config-level, serializable)
261
+ # ---------------------------
262
+
263
class RunArtifactSpec(BaseSpec):
    """Serializable provenance record for a single run."""

    run_id: str
    created_at: datetime

    # The task spec is required; router config and decision are optional.
    task_spec: TaskSpec
    router_config: RouterConfig | None = None
    route_decision: RouteDecision | None = None

    # Identifiers / hashes for reproducibility (implementation-defined)
    data_signature: str | None = None
    code_signature: str | None = None

    # Output references (implementation-defined; typically file paths or object-store keys)
    outputs: dict[str, str] = Field(default_factory=dict)
277
+
278
+
279
# Names exported by ``from <module> import *``: every spec model and type
# alias defined above, in alphabetical order.
__all__ = [
    "AggregationMode",
    "AnomalyMethod",
    "AnomalySpec",
    "BacktestSpec",
    "BaseSpec",
    "CalibratorSpec",
    "CovariatePolicy",
    "CovariateRole",
    "CovariateSpec",
    "ForecastContract",
    "IntervalMode",
    "MissingPolicy",
    "PanelContract",
    "PlanSpec",
    "RouteDecision",
    "RouterConfig",
    "RouterThresholds",
    "RunArtifactSpec",
    "SeasonalityMethod",
    "TaskSpec",
]
@@ -0,0 +1,340 @@
1
+ """Covariate typing, alignment, and guardrails."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+
7
+ import pandas as pd
8
+
9
+ from tsagentkit.contracts import (
10
+ CovariateSpec,
11
+ ECovariateIncompleteKnown,
12
+ ECovariateLeakage,
13
+ ECovariateStaticInvalid,
14
+ ETaskSpecInvalid,
15
+ TaskSpec,
16
+ )
17
+ from tsagentkit.time import make_future_index
18
+
19
+
20
@dataclass(frozen=True)
class CovariateBundle:
    """Immutable container for covariate frames supplied separately from the panel.

    Every member is optional; ``align_covariates`` validates each one that is
    present.
    """

    # Validated to contain exactly one row per unique id.
    static_x: pd.DataFrame | None = None
    # Validated to have no values on the forecast horizon.
    past_x: pd.DataFrame | None = None
    # Validated to have no missing values on the forecast horizon.
    future_x: pd.DataFrame | None = None
25
+
26
+
27
@dataclass(frozen=True)
class AlignedDataset:
    """Immutable result of ``align_covariates``."""

    # Panel restricted to the id / timestamp / target columns.
    panel: pd.DataFrame
    # Covariate frames by role; None when no covariate of that role exists.
    static_x: pd.DataFrame | None
    past_x: pd.DataFrame | None
    future_x: pd.DataFrame | None
    # The covariate spec carried over from the task spec (may be None).
    covariate_spec: CovariateSpec | None
    # Frame produced by ``make_future_index`` for the forecast horizon.
    future_index: pd.DataFrame | None
35
+
36
+
37
def align_covariates(
    panel: pd.DataFrame,
    task_spec: TaskSpec,
    covariates: CovariateBundle | None = None,
) -> AlignedDataset:
    """Align covariates and enforce coverage/leakage guardrails.

    Args:
        panel: Input frame holding the id, timestamp and target columns named
            by ``task_spec.panel_contract``; any other column is treated as a
            covariate. The frame is copied, never mutated.
        task_spec: Supplies the column contract, horizon, frequency,
            covariate policy and optional covariate spec.
        covariates: Optional pre-split covariate frames. When given, they are
            validated and used as-is, and extra panel columns are ignored.

    Returns:
        An ``AlignedDataset`` with the base panel, the static / past / future
        covariate frames (None where absent) and the future index.

    Raises:
        ETaskSpecInvalid: policy ``"spec"`` without a usable covariate spec.
        ECovariateIncompleteKnown: a future-known covariate has gaps on the
            horizon.
        ECovariateLeakage: a past covariate has values on the horizon.
        ECovariateStaticInvalid: a static covariate varies within a series.
    """
    contract = task_spec.panel_contract
    uid_col = contract.unique_id_col
    ds_col = contract.ds_col
    y_col = contract.y_col

    panel = panel.copy()
    # NOTE(review): make_future_index lives in tsagentkit.time; presumably it
    # returns the (unique_id, ds) rows of the forecast horizon - confirm.
    future_index = make_future_index(panel, task_spec.h, task_spec.freq, uid_col, ds_col, y_col)

    # Explicit bundle: validate each member and skip column classification.
    if covariates is not None:
        static_x = _validate_static_covariates(covariates.static_x, uid_col)
        past_x = _validate_past_covariates(covariates.past_x, future_index, uid_col, ds_col)
        future_x = _validate_future_covariates(
            covariates.future_x, future_index, uid_col, ds_col
        )
        return AlignedDataset(
            panel=_panel_base(panel, uid_col, ds_col, y_col),
            static_x=static_x,
            past_x=past_x,
            future_x=future_x,
            covariate_spec=task_spec.covariates,
            future_index=future_index,
        )

    covariate_cols = [c for c in panel.columns if c not in {uid_col, ds_col, y_col}]
    if not covariate_cols or task_spec.covariate_policy == "ignore":
        return AlignedDataset(
            panel=_panel_base(panel, uid_col, ds_col, y_col),
            static_x=None,
            past_x=None,
            future_x=None,
            covariate_spec=task_spec.covariates,
            future_index=future_index,
        )

    static_cols: list[str] = []
    past_cols: list[str] = []
    future_cols: list[str] = []

    policy = task_spec.covariate_policy
    spec = task_spec.covariates

    if policy == "spec" and spec is None and covariate_cols:
        raise ETaskSpecInvalid(
            "covariate_policy='spec' requires task_spec.covariates.",
            context={"missing": "covariates"},
        )

    if policy == "spec" and spec is not None:
        _validate_spec_roles(spec, covariate_cols)
        for col, role in spec.roles.items():
            if role == "static":
                static_cols.append(col)
            elif role == "past":
                past_cols.append(col)
            elif role == "future_known":
                future_cols.append(col)
    elif policy == "known":
        future_cols = covariate_cols
    elif policy == "observed":
        past_cols = covariate_cols
    elif policy == "auto":
        # A column with any value on the horizon is a future-known candidate;
        # full coverage is enforced once for all future columns just below.
        for col in covariate_cols:
            if _has_future_values(panel, future_index, col, uid_col, ds_col):
                future_cols.append(col)
            else:
                past_cols.append(col)
    else:
        past_cols = covariate_cols

    # Coverage is checked before static extraction so that an incomplete
    # future-known column is reported ahead of a static-covariate error.
    for col in future_cols:
        _enforce_future_coverage(panel, future_index, col, uid_col, ds_col)

    static_x = _extract_static(panel, uid_col, static_cols)
    past_x = _extract_time_covariates(panel, uid_col, ds_col, past_cols)
    future_x = _extract_future_covariates(panel, future_index, uid_col, ds_col, future_cols)

    _enforce_past_leakage(panel, future_index, uid_col, ds_col, past_cols)

    return AlignedDataset(
        panel=_panel_base(panel, uid_col, ds_col, y_col),
        static_x=static_x,
        past_x=past_x,
        future_x=future_x,
        covariate_spec=spec,
        future_index=future_index,
    )
131
+
132
+
133
def _panel_base(panel: pd.DataFrame, uid_col: str, ds_col: str, y_col: str) -> pd.DataFrame:
    """Return a copy of ``panel`` limited to the id, timestamp and target columns.

    Columns absent from ``panel`` are skipped; the order is id, timestamp,
    target.
    """
    keep: list[str] = []
    for name in (uid_col, ds_col, y_col):
        if name in panel.columns:
            keep.append(name)
    return panel[keep].copy()
136
+
137
+
138
def _validate_spec_roles(spec: CovariateSpec, covariate_cols: list[str]) -> None:
    """Check that ``spec.roles`` and the panel covariate columns match exactly.

    Raises:
        ETaskSpecInvalid: if roles are empty while covariates exist, if a role
            names a column missing from the panel, or if a panel covariate
            has no role.
    """
    roles = spec.roles

    if not roles:
        if not covariate_cols:
            return
        raise ETaskSpecInvalid(
            "covariate_policy='spec' requires explicit roles for all covariate columns.",
            context={"missing_roles_for": sorted(covariate_cols)},
        )

    # Roles that point at columns the panel does not have.
    unknown = sorted(name for name in roles if name not in covariate_cols)
    if unknown:
        raise ETaskSpecInvalid(
            "CovariateSpec roles include columns not present in panel data.",
            context={"missing_in_panel": unknown},
        )

    # Panel covariates that were never given a role.
    unassigned = sorted(name for name in covariate_cols if name not in roles)
    if unassigned:
        raise ETaskSpecInvalid(
            "covariate_policy='spec' requires roles for all panel covariates.",
            context={"missing_roles_for": unassigned},
        )
160
+
161
+
162
def _has_future_values(
    panel: pd.DataFrame,
    future_index: pd.DataFrame,
    col: str,
    uid_col: str,
    ds_col: str,
) -> bool:
    """Return True if ``col`` has at least one non-null value on the horizon.

    The panel is left-joined onto ``future_index`` by id and timestamp, so
    only rows on the forecast horizon are inspected.
    """
    merged = future_index.merge(
        panel[[uid_col, ds_col, col]],
        on=[uid_col, ds_col],
        how="left",
    )
    # Series.any() yields numpy.bool_; coerce so the annotated ``bool`` holds.
    return bool(merged[col].notna().any())
175
+
176
+
177
def _enforce_future_coverage(
    panel: pd.DataFrame,
    future_index: pd.DataFrame,
    col: str,
    uid_col: str,
    ds_col: str,
) -> None:
    """Require ``col`` to be non-null on every row of the forecast horizon.

    Raises:
        ECovariateIncompleteKnown: if any horizon row lacks a value.
    """
    lookup = panel[[uid_col, ds_col, col]]
    horizon = future_index.merge(lookup, on=[uid_col, ds_col], how="left")
    gaps = horizon[col].isna()
    if not gaps.any():
        return
    missing = int(gaps.sum())
    raise ECovariateIncompleteKnown(
        f"Future-known covariate '{col}' missing {missing} values in horizon.",
        context={"covariate": col, "missing": missing},
    )
195
+
196
+
197
def _enforce_past_leakage(
    panel: pd.DataFrame,
    future_index: pd.DataFrame,
    uid_col: str,
    ds_col: str,
    past_cols: list[str],
) -> None:
    """Reject past covariates that carry values on the forecast horizon.

    Does nothing when ``past_cols`` is empty.

    Raises:
        ECovariateLeakage: for the first column, in ``past_cols`` order, that
            has a non-null value on a horizon row.
    """
    if not past_cols:
        return

    lookup = panel[[uid_col, ds_col, *past_cols]]
    horizon = future_index.merge(lookup, on=[uid_col, ds_col], how="left")
    for name in past_cols:
        present = horizon[name].notna()
        if not present.any():
            continue
        count = int(present.sum())
        raise ECovariateLeakage(
            f"Past covariate '{name}' has {count} values in forecast horizon.",
            context={"covariate": name, "future_values_count": count},
        )
219
+
220
+
221
def _extract_static(panel: pd.DataFrame, uid_col: str, cols: list[str]) -> pd.DataFrame | None:
    """Collapse static covariate columns to one row per unique id.

    Returns None when ``cols`` is empty. Rows with a null id are dropped
    before grouping.

    Raises:
        ECovariateStaticInvalid: if a column has more than one distinct
            non-null value within any series.
    """
    if not cols:
        return None

    subset = panel[[uid_col, *cols]].dropna(subset=[uid_col]).copy()

    # Each static column must be constant within every series.
    for name in cols:
        distinct = subset.groupby(uid_col)[name].nunique(dropna=True)
        varying = distinct[distinct > 1]
        if not varying.empty:
            bad = varying.index.tolist()[:5]
            raise ECovariateStaticInvalid(
                f"Static covariate '{name}' varies within series.",
                context={"covariate": name, "unique_id_examples": bad},
            )

    return subset.groupby(uid_col, as_index=False).first()
235
+
236
+
237
def _extract_time_covariates(
    panel: pd.DataFrame,
    uid_col: str,
    ds_col: str,
    cols: list[str],
) -> pd.DataFrame | None:
    """Return a copy of the id, timestamp and requested covariate columns.

    Returns None when ``cols`` is empty.
    """
    if cols:
        selected = [uid_col, ds_col, *cols]
        return panel[selected].copy()
    return None
246
+
247
+
248
def _extract_future_covariates(
    panel: pd.DataFrame,
    future_index: pd.DataFrame,
    uid_col: str,
    ds_col: str,
    cols: list[str],
) -> pd.DataFrame | None:
    """Return the future-known covariates joined onto the future index.

    Returns None when ``cols`` is empty.

    Raises:
        ECovariateIncompleteKnown: for the first column, in ``cols`` order,
            that has a null on a horizon row.
    """
    if not cols:
        return None

    lookup = panel[[uid_col, ds_col, *cols]]
    horizon = future_index.merge(lookup, on=[uid_col, ds_col], how="left")
    for name in cols:
        gaps = horizon[name].isna()
        if not gaps.any():
            continue
        missing = int(gaps.sum())
        raise ECovariateIncompleteKnown(
            f"Future-known covariate '{name}' missing {missing} values in horizon.",
            context={"covariate": name, "missing": missing},
        )
    return horizon
270
+
271
+
272
def _validate_static_covariates(
    static_x: pd.DataFrame | None,
    uid_col: str,
) -> pd.DataFrame | None:
    """Validate a caller-supplied static covariate frame and return a copy.

    Returns None when the frame is None or empty.

    Raises:
        ECovariateStaticInvalid: if any unique id does not have exactly one
            row.
    """
    if static_x is None or static_x.empty:
        return None

    rows_per_id = static_x.groupby(uid_col).size()
    offenders = rows_per_id[rows_per_id != 1]
    if not offenders.empty:
        bad = offenders.index.tolist()[:5]
        raise ECovariateStaticInvalid(
            "Static covariates must have exactly one row per unique_id.",
            context={"unique_id_examples": bad},
        )
    return static_x.copy()
286
+
287
+
288
def _validate_past_covariates(
    past_x: pd.DataFrame | None,
    future_index: pd.DataFrame,
    uid_col: str,
    ds_col: str,
) -> pd.DataFrame | None:
    """Validate a caller-supplied past covariate frame and return a copy.

    Returns None when the frame is None or empty.

    Raises:
        ECovariateLeakage: if any covariate column has a non-null value on a
            forecast-horizon row.
    """
    if past_x is None or past_x.empty:
        return None

    horizon = future_index.merge(past_x, on=[uid_col, ds_col], how="left")
    key_cols = {uid_col, ds_col}
    for name in past_x.columns:
        if name in key_cols:
            continue
        present = horizon[name].notna()
        if present.any():
            count = int(present.sum())
            raise ECovariateLeakage(
                f"Past covariate '{name}' has {count} values in forecast horizon.",
                context={"covariate": name, "future_values_count": count},
            )
    return past_x.copy()
310
+
311
+
312
def _validate_future_covariates(
    future_x: pd.DataFrame | None,
    future_index: pd.DataFrame,
    uid_col: str,
    ds_col: str,
) -> pd.DataFrame | None:
    """Validate a caller-supplied future covariate frame and return a copy.

    Returns None when the frame is None or empty. The returned frame is a
    copy of the input, not the frame joined onto the future index.

    Raises:
        ECovariateIncompleteKnown: if any covariate column has a null on a
            forecast-horizon row.
    """
    if future_x is None or future_x.empty:
        return None

    horizon = future_index.merge(future_x, on=[uid_col, ds_col], how="left")
    key_cols = {uid_col, ds_col}
    for name in future_x.columns:
        if name in key_cols:
            continue
        gaps = horizon[name].isna()
        if gaps.any():
            missing = int(gaps.sum())
            raise ECovariateIncompleteKnown(
                f"Future-known covariate '{name}' missing {missing} values in horizon.",
                context={"covariate": name, "missing": missing},
            )
    return future_x.copy()
334
+
335
+
336
# Public API of this module; the underscore-prefixed helpers stay private.
__all__ = [
    "AlignedDataset",
    "CovariateBundle",
    "align_covariates",
]