tsagentkit-1.0.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tsagentkit/__init__.py +126 -0
- tsagentkit/anomaly/__init__.py +130 -0
- tsagentkit/backtest/__init__.py +48 -0
- tsagentkit/backtest/engine.py +788 -0
- tsagentkit/backtest/metrics.py +244 -0
- tsagentkit/backtest/report.py +342 -0
- tsagentkit/calibration/__init__.py +136 -0
- tsagentkit/contracts/__init__.py +133 -0
- tsagentkit/contracts/errors.py +275 -0
- tsagentkit/contracts/results.py +418 -0
- tsagentkit/contracts/schema.py +44 -0
- tsagentkit/contracts/task_spec.py +300 -0
- tsagentkit/covariates/__init__.py +340 -0
- tsagentkit/eval/__init__.py +285 -0
- tsagentkit/features/__init__.py +20 -0
- tsagentkit/features/covariates.py +328 -0
- tsagentkit/features/extra/__init__.py +5 -0
- tsagentkit/features/extra/native.py +179 -0
- tsagentkit/features/factory.py +187 -0
- tsagentkit/features/matrix.py +159 -0
- tsagentkit/features/tsfeatures_adapter.py +115 -0
- tsagentkit/features/versioning.py +203 -0
- tsagentkit/hierarchy/__init__.py +39 -0
- tsagentkit/hierarchy/aggregation.py +62 -0
- tsagentkit/hierarchy/evaluator.py +400 -0
- tsagentkit/hierarchy/reconciliation.py +232 -0
- tsagentkit/hierarchy/structure.py +453 -0
- tsagentkit/models/__init__.py +182 -0
- tsagentkit/models/adapters/__init__.py +83 -0
- tsagentkit/models/adapters/base.py +321 -0
- tsagentkit/models/adapters/chronos.py +387 -0
- tsagentkit/models/adapters/moirai.py +256 -0
- tsagentkit/models/adapters/registry.py +171 -0
- tsagentkit/models/adapters/timesfm.py +440 -0
- tsagentkit/models/baselines.py +207 -0
- tsagentkit/models/sktime.py +307 -0
- tsagentkit/monitoring/__init__.py +51 -0
- tsagentkit/monitoring/alerts.py +302 -0
- tsagentkit/monitoring/coverage.py +203 -0
- tsagentkit/monitoring/drift.py +330 -0
- tsagentkit/monitoring/report.py +214 -0
- tsagentkit/monitoring/stability.py +275 -0
- tsagentkit/monitoring/triggers.py +423 -0
- tsagentkit/qa/__init__.py +347 -0
- tsagentkit/router/__init__.py +37 -0
- tsagentkit/router/bucketing.py +489 -0
- tsagentkit/router/fallback.py +132 -0
- tsagentkit/router/plan.py +23 -0
- tsagentkit/router/router.py +271 -0
- tsagentkit/series/__init__.py +26 -0
- tsagentkit/series/alignment.py +206 -0
- tsagentkit/series/dataset.py +449 -0
- tsagentkit/series/sparsity.py +261 -0
- tsagentkit/series/validation.py +393 -0
- tsagentkit/serving/__init__.py +39 -0
- tsagentkit/serving/orchestration.py +943 -0
- tsagentkit/serving/packaging.py +73 -0
- tsagentkit/serving/provenance.py +317 -0
- tsagentkit/serving/tsfm_cache.py +214 -0
- tsagentkit/skill/README.md +135 -0
- tsagentkit/skill/__init__.py +8 -0
- tsagentkit/skill/recipes.md +429 -0
- tsagentkit/skill/tool_map.md +21 -0
- tsagentkit/time/__init__.py +134 -0
- tsagentkit/utils/__init__.py +20 -0
- tsagentkit/utils/quantiles.py +83 -0
- tsagentkit/utils/signature.py +47 -0
- tsagentkit/utils/temporal.py +41 -0
- tsagentkit-1.0.2.dist-info/METADATA +371 -0
- tsagentkit-1.0.2.dist-info/RECORD +72 -0
- tsagentkit-1.0.2.dist-info/WHEEL +4 -0
- tsagentkit-1.0.2.dist-info/licenses/LICENSE +201 -0

tsagentkit/models/baselines.py
@@ -0,0 +1,207 @@
"""Baseline model implementations backed by statsforecast."""

from __future__ import annotations

from typing import Any

import pandas as pd

from tsagentkit.contracts import ModelArtifact
from tsagentkit.utils import normalize_quantile_columns, quantile_col_name


def _normalize_name(name: str) -> str:
    return name.lower().replace("_", "").replace("-", "")


def _get_statsforecast_models() -> dict[str, type]:
    try:
        from statsforecast import models as sf_models
    except ImportError as exc:
        raise ImportError(
            "statsforecast is required for baseline models. "
            "Install with: uv sync (or pip install statsforecast)"
        ) from exc

    model_map: dict[str, type] = {}
    for key, attr in {
        "naive": "Naive",
        "seasonalnaive": "SeasonalNaive",
        "historicaverage": "HistoricAverage",
        "theta": "Theta",
        "windowaverage": "WindowAverage",
        "movingaverage": "WindowAverage",
        "seasonalwindowaverage": "SeasonalWindowAverage",
        "autoets": "AutoETS",
        "ets": "AutoETS",
    }.items():
        model_cls = getattr(sf_models, attr, None)
        if model_cls is not None:
            model_map[key] = model_cls

    croston_cls = None
    for name in ("CrostonClassic", "CrostonOptimized", "CrostonSBA", "CrostonTSB"):
        croston_cls = getattr(sf_models, name, None)
        if croston_cls is not None:
            break
    if croston_cls is not None:
        model_map["croston"] = croston_cls

    return model_map


def is_baseline_model(model_name: str) -> bool:
    """Return True if model_name is a supported baseline."""
    normalized = _normalize_name(model_name)
    return normalized in _get_statsforecast_models()


def _build_model(
    model_name: str,
    config: dict[str, Any],
) -> tuple[Any, str]:
    model_map = _get_statsforecast_models()
    normalized = _normalize_name(model_name)

    if normalized not in model_map:
        raise ValueError(f"Unknown baseline model: {model_name}")

    model_cls = model_map[normalized]
    model_key = model_cls.__name__

    kwargs: dict[str, Any] = {}
    if normalized == "seasonalnaive":
        kwargs["season_length"] = int(config.get("season_length", 1))
    elif normalized in {"windowaverage", "movingaverage"}:
        kwargs["window_size"] = int(config.get("window_size", 3))
    elif normalized == "seasonalwindowaverage":
        kwargs["season_length"] = int(config.get("season_length", 1))
        kwargs["window_size"] = int(config.get("window_size", 3))
    elif normalized in {"autoets", "ets"}:
        kwargs["season_length"] = int(config.get("season_length", 1))
    return model_cls(**kwargs), model_key


def fit_baseline(
    model_name: str,
    dataset: Any,
    config: dict[str, Any],
) -> ModelArtifact:
    """Fit a baseline model using statsforecast."""
    from statsforecast import StatsForecast

    model, model_key = _build_model(model_name, config)

    try:
        sf = StatsForecast(
            models=[model],
            freq=dataset.task_spec.freq,
            n_jobs=1,
        )
        sf.fit(dataset.df)
    except TypeError:
        sf = StatsForecast(
            df=dataset.df,
            models=[model],
            freq=dataset.task_spec.freq,
            n_jobs=1,
        )
        sf.fit()

    return ModelArtifact(
        model=sf,
        model_name=model_name,
        config=config,
        metadata={
            "baseline_model": model_key,
            "freq": dataset.task_spec.freq,
        },
    )


def _extract_point_column(forecast_df: pd.DataFrame) -> str:
    value_cols = [c for c in forecast_df.columns if c not in {"unique_id", "ds"}]
    point_cols = [c for c in value_cols if "lo-" not in c and "hi-" not in c]
    if not point_cols:
        raise ValueError("No point forecast column found in statsforecast output.")
    if len(point_cols) == 1:
        return point_cols[0]
    # Prefer plain yhat if present
    if "yhat" in point_cols:
        return "yhat"
    return point_cols[0]


def _level_for_quantile(q: float) -> int:
    return int(round((1 - 2 * abs(q - 0.5)) * 100))


def _find_interval_column(
    forecast_df: pd.DataFrame,
    level: int,
    kind: str,
) -> str | None:
    suffix = f"{kind}-{level}"
    matches = [c for c in forecast_df.columns if c.endswith(suffix)]
    if matches:
        return matches[0]
    return None


def predict_baseline(
    model_artifact: ModelArtifact,
    dataset: Any | None,
    horizon: int,
    quantiles: list[float] | None = None,
) -> pd.DataFrame:
    """Generate baseline forecasts using statsforecast."""
    sf = model_artifact.model
    levels: list[int] = []

    if quantiles:
        for q in quantiles:
            if q == 0.5:
                continue
            levels.append(_level_for_quantile(q))
        levels = sorted({lvl for lvl in levels if lvl > 0})

    try:
        forecast_df = sf.forecast(h=horizon, level=levels or None)
    except TypeError:
        if dataset is None:
            raise
        forecast_df = sf.forecast(df=dataset.df, h=horizon, level=levels or None)
    if "unique_id" not in forecast_df.columns or "ds" not in forecast_df.columns:
        forecast_df = forecast_df.reset_index()

    point_col = _extract_point_column(forecast_df)
    if point_col != "yhat":
        forecast_df = forecast_df.rename(columns={point_col: "yhat"})

    # Normalize quantile columns
    if quantiles:
        for q in quantiles:
            col_name = quantile_col_name(q)
            if q == 0.5:
                forecast_df[col_name] = forecast_df["yhat"]
                continue
            level = _level_for_quantile(q)
            kind = "lo" if q < 0.5 else "hi"
            interval_col = _find_interval_column(forecast_df, level, kind)
            if interval_col is None:
                forecast_df[col_name] = forecast_df["yhat"]
            else:
                forecast_df[col_name] = forecast_df[interval_col]

    forecast_df = normalize_quantile_columns(forecast_df)

    # Keep standard column order
    if "model" not in forecast_df.columns:
        forecast_df["model"] = model_artifact.model_name
    cols = [
        "unique_id",
        "ds",
        "model",
        "yhat",
    ] + [c for c in forecast_df.columns if c.startswith("q")]
    return forecast_df[cols]
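
Editor's note: a minimal usage sketch of this module's public entry points, fit_baseline and predict_baseline. The concrete dataset class ships elsewhere in the package, so the stand-in below exposes only the attributes this file actually reads (.df and .task_spec.freq) and is an illustrative assumption, not package code.

# Hedged usage sketch (not part of the wheel contents).
from types import SimpleNamespace

import pandas as pd

from tsagentkit.models.baselines import fit_baseline, predict_baseline

# Panel in the statsforecast layout used throughout this module.
panel = pd.DataFrame(
    {
        "unique_id": ["store_1"] * 14,
        "ds": pd.date_range("2024-01-01", periods=14, freq="D"),
        "y": [float(v) for v in range(14)],
    }
)
# Stand-in dataset: only .df and .task_spec.freq are touched by fit/predict.
dataset = SimpleNamespace(df=panel, task_spec=SimpleNamespace(freq="D"))

artifact = fit_baseline("seasonal_naive", dataset, config={"season_length": 7})
forecast = predict_baseline(artifact, dataset, horizon=7, quantiles=[0.1, 0.5, 0.9])
# forecast has columns: unique_id, ds, model, yhat, plus q* quantile columns.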

tsagentkit/models/sktime.py
@@ -0,0 +1,307 @@
"""Sktime forecaster adapter with covariate support."""

from __future__ import annotations

from dataclasses import dataclass
from datetime import UTC
from typing import Any

import pandas as pd

from tsagentkit.contracts import ETaskSpecIncompatible, ForecastResult, ModelArtifact
from tsagentkit.time import make_future_index
from tsagentkit.utils import normalize_quantile_columns


@dataclass(frozen=True)
class SktimeModelBundle:
    """Container for per-series sktime forecasters."""

    model_name: str
    forecasters: dict[str, Any]
    exog_columns: list[str]
    static_columns: list[str]
    future_columns: list[str]


def _require_sktime() -> None:
    try:
        import sktime  # noqa: F401
    except ImportError as exc:
        raise ImportError(
            "sktime is required for sktime adapters. Install with: uv sync --extra features"
        ) from exc


def _make_forecaster(model_key: str, season_length: int | None) -> Any:
    _require_sktime()

    from sktime.forecasting.naive import NaiveForecaster

    if model_key in {"naive", "last"}:
        return NaiveForecaster(strategy="last")
    if model_key in {"seasonal_naive", "seasonal"}:
        sp = season_length or 1
        return NaiveForecaster(strategy="last", sp=sp)

    raise ValueError(f"Unsupported sktime model key: {model_key}")


def _get_exog_columns(
    covariates: Any | None,
    plan: Any,
) -> tuple[list[str], list[str]]:
    static_cols: list[str] = []
    future_cols: list[str] = []

    if covariates is None:
        return static_cols, future_cols

    if getattr(plan, "use_static", True) and covariates.static_x is not None:
        static_cols = [c for c in covariates.static_x.columns if c != "unique_id"]

    if getattr(plan, "use_future_known", True) and covariates.future_x is not None:
        future_cols = [
            c for c in covariates.future_x.columns if c not in {"unique_id", "ds"}
        ]

    return static_cols, future_cols


def _build_train_exog(
    dataset: Any,
    covariates: Any,
    uid: str,
    ds_index: pd.Index,
    static_cols: list[str],
    future_cols: list[str],
) -> pd.DataFrame | None:
    if not static_cols and not future_cols:
        return None

    exog = pd.DataFrame(index=ds_index)

    if future_cols:
        panel = dataset.panel_with_covariates
        if panel is None:
            raise ETaskSpecIncompatible(
                "Future-known covariates require panel_with_covariates for history.",
                context={"unique_id": uid},
            )
        hist = panel[panel["unique_id"] == uid][["ds"] + future_cols].copy()
        hist = hist.set_index("ds")
        exog = exog.join(hist[future_cols], how="left")

    if static_cols:
        static_row = covariates.static_x
        if static_row is None:
            raise ETaskSpecIncompatible(
                "Static covariates requested but none provided.",
                context={"unique_id": uid},
            )
        static_row = static_row[static_row["unique_id"] == uid]
        if static_row.empty:
            raise ETaskSpecIncompatible(
                "Static covariate row missing for series.",
                context={"unique_id": uid},
            )
        for col in static_cols:
            value = static_row.iloc[0][col]
            exog[col] = value

    if exog.isna().any().any():
        raise ETaskSpecIncompatible(
            "Missing exogenous covariate values in training data.",
            context={"unique_id": uid},
        )

    return exog


def _build_future_exog(
    covariates: Any,
    uid: str,
    static_cols: list[str],
    future_cols: list[str],
) -> pd.DataFrame | None:
    if not static_cols and not future_cols:
        return None

    if covariates.future_index is None:
        raise ETaskSpecIncompatible(
            "Future index is required for exogenous forecasting.",
            context={"unique_id": uid},
        )

    future_index = covariates.future_index
    future_index = future_index[future_index["unique_id"] == uid].copy()
    future_index = future_index.set_index("ds")

    exog = pd.DataFrame(index=future_index.index)

    if future_cols:
        future = covariates.future_x
        if future is None:
            raise ETaskSpecIncompatible(
                "Future-known covariates requested but none provided.",
                context={"unique_id": uid},
            )
        future = future[future["unique_id"] == uid][["ds"] + future_cols].copy()
        future = future.set_index("ds")
        exog = exog.join(future[future_cols], how="left")

    if static_cols:
        static_row = covariates.static_x
        if static_row is None:
            raise ETaskSpecIncompatible(
                "Static covariates requested but none provided.",
                context={"unique_id": uid},
            )
        static_row = static_row[static_row["unique_id"] == uid]
        if static_row.empty:
            raise ETaskSpecIncompatible(
                "Static covariate row missing for series.",
                context={"unique_id": uid},
            )
        for col in static_cols:
            value = static_row.iloc[0][col]
            exog[col] = value

    if exog.isna().any().any():
        raise ETaskSpecIncompatible(
            "Missing exogenous covariate values for forecast horizon.",
            context={"unique_id": uid},
        )

    return exog


def fit_sktime(
    model_name: str,
    dataset: Any,
    plan: Any,
    covariates: Any | None = None,
) -> ModelArtifact:
    """Fit per-series sktime forecasters with exogenous covariates."""
    model_key = model_name.split("sktime-", 1)[-1]
    static_cols, future_cols = _get_exog_columns(covariates, plan)
    forecasters: dict[str, Any] = {}

    for uid in dataset.series_ids:
        series_df = dataset.get_series(uid).copy()
        series_df = series_df.sort_values("ds")
        y = series_df.set_index("ds")["y"].astype(float)

        x_train = None
        if covariates is not None:
            x_train = _build_train_exog(
                dataset=dataset,
                covariates=covariates,
                uid=uid,
                ds_index=y.index,
                static_cols=static_cols,
                future_cols=future_cols,
            )

        forecaster = _make_forecaster(model_key, dataset.task_spec.season_length)
        forecaster.fit(y, X=x_train)
        forecasters[uid] = forecaster

    bundle = SktimeModelBundle(
        model_name=model_name,
        forecasters=forecasters,
        exog_columns=sorted(static_cols + future_cols),
        static_columns=static_cols,
        future_columns=future_cols,
    )

    return ModelArtifact(
        model=bundle,
        model_name=model_name,
        config={
            "model_key": model_key,
            "static_columns": static_cols,
            "future_columns": future_cols,
        },
    )


def predict_sktime(
    dataset: Any,
    artifact: ModelArtifact,
    spec: Any,
    covariates: Any | None = None,
) -> ForecastResult:
    """Predict with per-series sktime forecasters."""
    _require_sktime()

    from sktime.forecasting.base import ForecastingHorizon

    bundle: SktimeModelBundle = artifact.model
    rows: list[dict[str, Any]] = []

    if covariates is None:
        future_index = make_future_index(dataset.df, spec.horizon, spec.freq)
    else:
        future_index = covariates.future_index
        if future_index is None:
            future_index = make_future_index(dataset.df, spec.horizon, spec.freq)

    for uid, forecaster in bundle.forecasters.items():
        future_dates = future_index[future_index["unique_id"] == uid]["ds"]
        future_dates = pd.to_datetime(future_dates).sort_values()
        fh = ForecastingHorizon(pd.DatetimeIndex(future_dates), is_relative=False)

        x_future = None
        if covariates is not None and (bundle.static_columns or bundle.future_columns):
            x_future = _build_future_exog(
                covariates=covariates,
                uid=uid,
                static_cols=bundle.static_columns,
                future_cols=bundle.future_columns,
            )

        y_pred = forecaster.predict(fh, X=x_future)
        if isinstance(y_pred, pd.DataFrame):
            y_pred = y_pred.iloc[:, 0]

        for ds, value in y_pred.items():
            rows.append(
                {
                    "unique_id": uid,
                    "ds": pd.Timestamp(ds),
                    "model": artifact.model_name,
                    "yhat": float(value),
                }
            )

    result_df = pd.DataFrame(rows)
    result_df = normalize_quantile_columns(result_df)

    provenance = _basic_provenance(dataset, spec, artifact)
    return ForecastResult(
        df=result_df,
        provenance=provenance,
        model_name=artifact.model_name,
        horizon=spec.horizon,
    )


def _basic_provenance(dataset: Any, spec: Any, artifact: ModelArtifact) -> Any:
    from datetime import datetime

    from tsagentkit.contracts import Provenance
    from tsagentkit.utils import compute_data_signature

    return Provenance(
        run_id=f"sktime_{datetime.now(UTC).strftime('%Y%m%d_%H%M%S')}",
        timestamp=datetime.now(UTC).isoformat(),
        data_signature=compute_data_signature(dataset.df),
        task_signature=spec.model_hash(),
        plan_signature=artifact.signature,
        model_signature=artifact.signature,
        metadata={"adapter": "sktime"},
    )


__all__ = ["SktimeModelBundle", "fit_sktime", "predict_sktime"]
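
Editor's note: a hedged sketch of the sktime adapter's call sequence without covariates. The real dataset, plan, and spec types come from other tsagentkit modules; the stubs below carry only the attributes this file reads (series_ids, get_series, task_spec.season_length, df, horizon, freq, model_hash) and are illustrative assumptions, not package code.

# Hedged usage sketch (not part of the wheel contents).
from types import SimpleNamespace

import pandas as pd

from tsagentkit.models.sktime import fit_sktime, predict_sktime

panel = pd.DataFrame(
    {
        "unique_id": ["store_1"] * 14,
        "ds": pd.date_range("2024-01-01", periods=14, freq="D"),
        "y": [float(v) for v in range(14)],
    }
)


class _StubDataset:
    """Stand-in exposing only what fit_sktime/predict_sktime read."""

    df = panel
    series_ids = ["store_1"]
    task_spec = SimpleNamespace(season_length=7)

    def get_series(self, uid: str) -> pd.DataFrame:
        return panel[panel["unique_id"] == uid]


dataset = _StubDataset()
plan = SimpleNamespace()  # covariate flags are only consulted when covariates are passed
spec = SimpleNamespace(horizon=7, freq="D", model_hash=lambda: "stub-task-hash")

artifact = fit_sktime("sktime-seasonal_naive", dataset, plan)  # no covariates
result = predict_sktime(dataset, artifact, spec)
# result.df has columns: unique_id, ds, model, yhat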

tsagentkit/monitoring/__init__.py
@@ -0,0 +1,51 @@
"""Monitoring module for drift detection, coverage, and model stability.

Provides utilities for detecting data drift, monitoring prediction stability,
checking quantile coverage, and triggering alerts when necessary.
"""

from __future__ import annotations

from tsagentkit.monitoring.alerts import (
    Alert,
    AlertCondition,
    AlertManager,
    create_default_coverage_alerts,
    create_default_drift_alerts,
)
from tsagentkit.monitoring.coverage import CoverageCheck, CoverageMonitor
from tsagentkit.monitoring.drift import DriftDetector
from tsagentkit.monitoring.report import (
    CalibrationReport,
    DriftReport,
    FeatureDriftResult,
    StabilityReport,
    TriggerResult,
)
from tsagentkit.monitoring.stability import StabilityMonitor
from tsagentkit.monitoring.triggers import RetrainTrigger, TriggerEvaluator, TriggerType

__all__ = [
    # Coverage monitoring
    "CoverageMonitor",
    "CoverageCheck",
    # Alerts
    "AlertManager",
    "AlertCondition",
    "Alert",
    "create_default_coverage_alerts",
    "create_default_drift_alerts",
    # Drift detection
    "DriftDetector",
    "DriftReport",
    "FeatureDriftResult",
    # Stability monitoring
    "StabilityMonitor",
    "StabilityReport",
    "CalibrationReport",
    # Retrain triggers
    "TriggerEvaluator",
    "RetrainTrigger",
    "TriggerType",
    "TriggerResult",
]
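
Editor's note: the monitoring package's public surface is exactly the re-exported names above. A minimal import sketch follows, restricted to names confirmed by __all__; constructor signatures live in the submodules (alerts.py, coverage.py, drift.py, stability.py, triggers.py) and are not shown in this diff, so no instantiation is sketched.

# Hedged import sketch (not part of the wheel contents).
from tsagentkit.monitoring import (
    AlertManager,      # alert conditions and routing (monitoring/alerts.py)
    CoverageMonitor,   # quantile coverage checks (monitoring/coverage.py)
    DriftDetector,     # data drift detection (monitoring/drift.py)
    StabilityMonitor,  # prediction stability tracking (monitoring/stability.py)
    TriggerEvaluator,  # retrain trigger evaluation (monitoring/triggers.py)
)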