tsagentkit 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tsagentkit/__init__.py +126 -0
- tsagentkit/anomaly/__init__.py +130 -0
- tsagentkit/backtest/__init__.py +48 -0
- tsagentkit/backtest/engine.py +788 -0
- tsagentkit/backtest/metrics.py +244 -0
- tsagentkit/backtest/report.py +342 -0
- tsagentkit/calibration/__init__.py +136 -0
- tsagentkit/contracts/__init__.py +133 -0
- tsagentkit/contracts/errors.py +275 -0
- tsagentkit/contracts/results.py +418 -0
- tsagentkit/contracts/schema.py +44 -0
- tsagentkit/contracts/task_spec.py +300 -0
- tsagentkit/covariates/__init__.py +340 -0
- tsagentkit/eval/__init__.py +285 -0
- tsagentkit/features/__init__.py +20 -0
- tsagentkit/features/covariates.py +328 -0
- tsagentkit/features/extra/__init__.py +5 -0
- tsagentkit/features/extra/native.py +179 -0
- tsagentkit/features/factory.py +187 -0
- tsagentkit/features/matrix.py +159 -0
- tsagentkit/features/tsfeatures_adapter.py +115 -0
- tsagentkit/features/versioning.py +203 -0
- tsagentkit/hierarchy/__init__.py +39 -0
- tsagentkit/hierarchy/aggregation.py +62 -0
- tsagentkit/hierarchy/evaluator.py +400 -0
- tsagentkit/hierarchy/reconciliation.py +232 -0
- tsagentkit/hierarchy/structure.py +453 -0
- tsagentkit/models/__init__.py +182 -0
- tsagentkit/models/adapters/__init__.py +83 -0
- tsagentkit/models/adapters/base.py +321 -0
- tsagentkit/models/adapters/chronos.py +387 -0
- tsagentkit/models/adapters/moirai.py +256 -0
- tsagentkit/models/adapters/registry.py +171 -0
- tsagentkit/models/adapters/timesfm.py +440 -0
- tsagentkit/models/baselines.py +207 -0
- tsagentkit/models/sktime.py +307 -0
- tsagentkit/monitoring/__init__.py +51 -0
- tsagentkit/monitoring/alerts.py +302 -0
- tsagentkit/monitoring/coverage.py +203 -0
- tsagentkit/monitoring/drift.py +330 -0
- tsagentkit/monitoring/report.py +214 -0
- tsagentkit/monitoring/stability.py +275 -0
- tsagentkit/monitoring/triggers.py +423 -0
- tsagentkit/qa/__init__.py +347 -0
- tsagentkit/router/__init__.py +37 -0
- tsagentkit/router/bucketing.py +489 -0
- tsagentkit/router/fallback.py +132 -0
- tsagentkit/router/plan.py +23 -0
- tsagentkit/router/router.py +271 -0
- tsagentkit/series/__init__.py +26 -0
- tsagentkit/series/alignment.py +206 -0
- tsagentkit/series/dataset.py +449 -0
- tsagentkit/series/sparsity.py +261 -0
- tsagentkit/series/validation.py +393 -0
- tsagentkit/serving/__init__.py +39 -0
- tsagentkit/serving/orchestration.py +943 -0
- tsagentkit/serving/packaging.py +73 -0
- tsagentkit/serving/provenance.py +317 -0
- tsagentkit/serving/tsfm_cache.py +214 -0
- tsagentkit/skill/README.md +135 -0
- tsagentkit/skill/__init__.py +8 -0
- tsagentkit/skill/recipes.md +429 -0
- tsagentkit/skill/tool_map.md +21 -0
- tsagentkit/time/__init__.py +134 -0
- tsagentkit/utils/__init__.py +20 -0
- tsagentkit/utils/quantiles.py +83 -0
- tsagentkit/utils/signature.py +47 -0
- tsagentkit/utils/temporal.py +41 -0
- tsagentkit-1.0.2.dist-info/METADATA +371 -0
- tsagentkit-1.0.2.dist-info/RECORD +72 -0
- tsagentkit-1.0.2.dist-info/WHEEL +4 -0
- tsagentkit-1.0.2.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,943 @@
|
|
|
1
|
+
"""Main forecasting orchestration.
|
|
2
|
+
|
|
3
|
+
Provides the unified entry point run_forecast() for executing
|
|
4
|
+
the complete forecasting pipeline.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import time
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from typing import TYPE_CHECKING, Any, Literal
|
|
12
|
+
|
|
13
|
+
import pandas as pd
|
|
14
|
+
|
|
15
|
+
from tsagentkit.backtest import rolling_backtest
|
|
16
|
+
from tsagentkit.contracts import (
|
|
17
|
+
AnomalySpec,
|
|
18
|
+
CalibratorSpec,
|
|
19
|
+
EAnomalyFail,
|
|
20
|
+
ECalibrationFail,
|
|
21
|
+
ECovariateIncompleteKnown,
|
|
22
|
+
ECovariateLeakage,
|
|
23
|
+
ECovariateStaticInvalid,
|
|
24
|
+
EFallbackExhausted,
|
|
25
|
+
EQACriticalIssue,
|
|
26
|
+
ETaskSpecInvalid,
|
|
27
|
+
ForecastResult,
|
|
28
|
+
PanelContract,
|
|
29
|
+
TaskSpec,
|
|
30
|
+
ValidationReport,
|
|
31
|
+
validate_contract,
|
|
32
|
+
)
|
|
33
|
+
from tsagentkit.covariates import AlignedDataset, CovariateBundle, align_covariates
|
|
34
|
+
from tsagentkit.qa import QAReport, run_qa
|
|
35
|
+
from tsagentkit.router import make_plan
|
|
36
|
+
from tsagentkit.series import TSDataset
|
|
37
|
+
from tsagentkit.time import infer_freq
|
|
38
|
+
from tsagentkit.utils import drop_future_rows, normalize_quantile_columns
|
|
39
|
+
|
|
40
|
+
from .packaging import package_run
|
|
41
|
+
from .provenance import create_provenance, log_event
|
|
42
|
+
|
|
43
|
+
if TYPE_CHECKING:
|
|
44
|
+
from tsagentkit.contracts import RunArtifact
|
|
45
|
+
from tsagentkit.features import FeatureConfig, FeatureMatrix
|
|
46
|
+
from tsagentkit.hierarchy import HierarchyStructure
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
|
|
50
|
+
class MonitoringConfig:
|
|
51
|
+
"""Configuration for monitoring during forecasting.
|
|
52
|
+
|
|
53
|
+
Attributes:
|
|
54
|
+
enabled: Whether monitoring is enabled
|
|
55
|
+
drift_method: Drift detection method ("psi" or "ks")
|
|
56
|
+
drift_threshold: Threshold for drift detection
|
|
57
|
+
check_stability: Whether to compute stability metrics
|
|
58
|
+
jitter_threshold: Threshold for jitter warnings
|
|
59
|
+
|
|
60
|
+
Example:
|
|
61
|
+
>>> config = MonitoringConfig(
|
|
62
|
+
... enabled=True,
|
|
63
|
+
... drift_method="psi",
|
|
64
|
+
... drift_threshold=0.2,
|
|
65
|
+
... )
|
|
66
|
+
"""
|
|
67
|
+
|
|
68
|
+
enabled: bool = False
|
|
69
|
+
drift_method: Literal["psi", "ks"] = "psi"
|
|
70
|
+
drift_threshold: float | None = None
|
|
71
|
+
check_stability: bool = False
|
|
72
|
+
jitter_threshold: float = 0.1
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def run_forecast(
    data: pd.DataFrame,
    task_spec: TaskSpec,
    covariates: CovariateBundle | None = None,
    mode: Literal["quick", "standard", "strict"] = "standard",
    fit_func: Any | None = None,
    predict_func: Any | None = None,
    monitoring_config: MonitoringConfig | None = None,
    reference_data: pd.DataFrame | None = None,
    repair_strategy: dict[str, Any] | None = None,
    hierarchy: HierarchyStructure | None = None,
    feature_config: FeatureConfig | None = None,
    calibrator_spec: CalibratorSpec | None = None,
    anomaly_spec: AnomalySpec | None = None,
) -> RunArtifact:
    """Execute the complete forecasting pipeline.

    This is the main entry point for tsagentkit. It orchestrates the
    entire workflow: validation -> QA -> dataset creation -> planning ->
    (backtest for standard/strict) -> fit -> predict -> package.

    Args:
        data: Input DataFrame with columns [unique_id, ds, y]
        task_spec: Task specification with horizon, freq, etc.
        covariates: Optional covariate bundle (bundle mode)
        mode: Execution mode:
            - "quick": Skip backtest, fit on all data
            - "standard": Full pipeline with backtest (default)
            - "strict": Fail on any QA issue (no auto-repair); failures in
              optional steps (covariate alignment, features, backtest,
              calibration, anomaly detection) are re-raised
        fit_func: Optional custom model fit function (fit(dataset, plan))
        predict_func: Optional custom model predict function (predict(dataset, artifact, spec))
        monitoring_config: Optional monitoring configuration (v0.2)
        reference_data: Optional reference data for drift detection (v0.2)
        repair_strategy: Optional QA repair configuration (overrides TaskSpec)
        hierarchy: Optional hierarchy structure for reconciliation
        feature_config: Optional feature configuration for feature engineering (v1.0)
        calibrator_spec: Optional calibration specification
        anomaly_spec: Optional anomaly detection specification

    Returns:
        RunArtifact with forecast, metrics, and provenance

    Raises:
        EContractMissingColumn: If required columns missing
        EContractInvalidType: If columns have wrong types
        ESplitRandomForbidden: If data is not temporally ordered
        ETaskSpecInvalid: If ``task_spec.freq`` is empty and
            ``task_spec.infer_freq`` is False
        ECovariateLeakage / EQACriticalIssue: In strict mode, when QA finds
            leakage or critical issues
        EFallbackExhausted: If all models fail
        Exception: In non-strict mode, the covariate error encountered during
            QA or covariate alignment is re-raised when the plan does not
            allow dropping covariates (``plan.allow_drop_covariates``)
    """
    events: list[dict[str, Any]] = []
    qa_repairs: list[dict[str, Any]] = []
    fallbacks_triggered: list[dict[str, Any]] = []
    start_time = time.time()
    column_map: dict[str, str] | None = None
    # Captured before normalization so provenance can record the caller's
    # original column naming.
    original_panel_contract = task_spec.panel_contract

    # Step 1: Validate
    data = data.copy()
    step_start = time.time()
    validation, data = _step_validate(data, task_spec)
    events.append(
        log_event(
            step_name="validate",
            status="success" if validation.valid else "failed",
            duration_ms=(time.time() - step_start) * 1000,
            error_code=_get_error_code(validation) if not validation.valid else None,
        )
    )

    if not validation.valid:
        validation.raise_if_errors()

    # Normalize panel columns to canonical names if needed
    data, column_map = TSDataset._normalize_panel_columns(
        data,
        task_spec.panel_contract,
    )
    if column_map:
        # Columns were renamed, so every later step must use the default contract.
        task_spec = task_spec.model_copy(update={"panel_contract": PanelContract()})

    # Infer frequency if missing
    if not task_spec.freq:
        if not task_spec.infer_freq:
            raise ETaskSpecInvalid(
                "TaskSpec.freq is required when infer_freq=False.",
                context={"freq": task_spec.freq},
            )
        step_start = time.time()
        try:
            inferred = infer_freq(
                data,
                id_col=task_spec.panel_contract.unique_id_col,
                ds_col=task_spec.panel_contract.ds_col,
            )
            task_spec = task_spec.model_copy(update={"freq": inferred})
            events.append(
                log_event(
                    step_name="infer_freq",
                    status="success",
                    duration_ms=(time.time() - step_start) * 1000,
                    context={"freq": inferred},
                )
            )
        except Exception as e:
            events.append(
                log_event(
                    step_name="infer_freq",
                    status="failed",
                    duration_ms=(time.time() - step_start) * 1000,
                    error_code=type(e).__name__,
                )
            )
            raise

    # Step 2: QA
    step_start = time.time()
    effective_repair_strategy = repair_strategy
    # Set when covariates must be dropped; enforced against the plan in Step 4.
    covariate_error: Exception | None = None
    try:
        qa_report = _step_qa(
            data,
            task_spec,
            mode,
            apply_repairs=mode != "strict",
            repair_strategy=effective_repair_strategy,
        )
        qa_repairs = qa_report.repairs
        events.append(
            log_event(
                step_name="qa",
                status="success",
                duration_ms=(time.time() - step_start) * 1000,
            )
        )
    except (ECovariateLeakage, ECovariateIncompleteKnown, ECovariateStaticInvalid) as e:
        covariate_error = e
        if mode == "strict":
            events.append(
                log_event(
                    step_name="qa",
                    status="failed",
                    duration_ms=(time.time() - step_start) * 1000,
                    error_code=type(e).__name__,
                )
            )
            raise
        # Non-strict: re-run QA without the covariate checks and record the
        # guardrail failure as a critical issue; covariates are dropped below.
        qa_report = _step_qa(
            data,
            task_spec,
            mode,
            apply_repairs=mode != "strict",
            repair_strategy=effective_repair_strategy,
            skip_covariate_checks=True,
        )
        qa_repairs = qa_report.repairs
        issues = list(qa_report.issues)
        issues.append(
            {
                "type": "covariate_guardrail",
                "error": str(e),
                "severity": "critical",
                "action": "dropped_covariates",
            }
        )
        qa_report = QAReport(
            issues=issues,
            repairs=qa_report.repairs,
            leakage_detected=isinstance(e, ECovariateLeakage),
        )
        events.append(
            log_event(
                step_name="qa",
                status="success",
                duration_ms=(time.time() - step_start) * 1000,
                context={"covariates_dropped": True},
            )
        )
    except Exception as e:
        events.append(
            log_event(
                step_name="qa",
                status="failed",
                duration_ms=(time.time() - step_start) * 1000,
                error_code=type(e).__name__,
            )
        )
        raise

    # Step 2b: Align covariates before dropping future rows (preserve single-table future-known)
    step_start = time.time()
    aligned_dataset: AlignedDataset | None = None
    panel_with_covariates = data.copy()
    if covariate_error is None:
        try:
            aligned_dataset = align_covariates(
                panel_with_covariates,
                task_spec,
                covariates=covariates,
            )
            events.append(
                log_event(
                    step_name="align_covariates",
                    status="success",
                    duration_ms=(time.time() - step_start) * 1000,
                    artifacts_generated=["aligned_covariates"],
                )
            )
        except Exception as e:
            # Guardrail errors (leakage / incomplete known / invalid static) and
            # any other alignment failure all leave the run without covariates.
            # Record the error in every case so that Step 4 enforces the plan's
            # ``allow_drop_covariates`` policy and reports the drop in
            # ``fallbacks_triggered``. Previously, non-guardrail errors silently
            # dropped covariates and bypassed both.
            covariate_error = e
            events.append(
                log_event(
                    step_name="align_covariates",
                    status="failed",
                    duration_ms=(time.time() - step_start) * 1000,
                    error_code=type(e).__name__,
                )
            )
            if mode == "strict":
                raise
    else:
        events.append(
            log_event(
                step_name="align_covariates",
                status="skipped",
                duration_ms=(time.time() - step_start) * 1000,
                context={"covariates_dropped": True},
            )
        )

    # Step 2c: Drop future rows (y is null beyond last observed per series)
    step_start = time.time()
    data, drop_info = drop_future_rows(
        data,
        id_col=task_spec.panel_contract.unique_id_col,
        ds_col=task_spec.panel_contract.ds_col,
        y_col=task_spec.panel_contract.y_col,
    )
    if drop_info:
        qa_repairs.append(drop_info)
        events.append(
            log_event(
                step_name="drop_future_rows",
                status="success",
                duration_ms=(time.time() - step_start) * 1000,
                artifacts_generated=["clean_data"],
            )
        )

    # Step 3: Build Dataset
    step_start = time.time()
    dataset = TSDataset.from_dataframe(data, task_spec, validate=False)
    if hierarchy is not None:
        dataset = dataset.with_hierarchy(hierarchy)
    if aligned_dataset is not None:
        uid_col = task_spec.panel_contract.unique_id_col
        ds_col = task_spec.panel_contract.ds_col
        y_col = task_spec.panel_contract.y_col
        # Rebuild the aligned dataset so its panel matches the trimmed history
        # (future rows were removed in Step 2c), while keeping the covariates.
        aligned_dataset = AlignedDataset(
            panel=data[[uid_col, ds_col, y_col]].copy(),
            static_x=aligned_dataset.static_x,
            past_x=aligned_dataset.past_x,
            future_x=aligned_dataset.future_x,
            covariate_spec=aligned_dataset.covariate_spec,
            future_index=aligned_dataset.future_index,
        )
        dataset = dataset.with_covariates(
            aligned_dataset,
            panel_with_covariates=panel_with_covariates,
            covariate_bundle=covariates,
        )
    events.append(
        log_event(
            step_name="build_dataset",
            status="success",
            duration_ms=(time.time() - step_start) * 1000,
        )
    )

    # Step 3b: Feature Engineering (v1.0)
    feature_matrix: FeatureMatrix | None = None
    if feature_config is not None:
        step_start = time.time()
        try:
            from tsagentkit.features import FeatureFactory

            factory = FeatureFactory(feature_config)
            feature_matrix = factory.create_features(dataset)
            events.append(
                log_event(
                    step_name="feature_engineering",
                    status="success",
                    duration_ms=(time.time() - step_start) * 1000,
                    artifacts_generated=["feature_matrix"],
                    context={
                        "n_features": len(feature_matrix.feature_cols),
                        "feature_hash": feature_matrix.config_hash,
                    },
                )
            )
        except Exception as e:
            events.append(
                log_event(
                    step_name="feature_engineering",
                    status="failed",
                    duration_ms=(time.time() - step_start) * 1000,
                    error_code=type(e).__name__,
                )
            )
            if mode == "strict":
                raise

    # Step 4: Make Plan
    step_start = time.time()
    plan, route_decision = make_plan(dataset, task_spec, qa_report)
    events.append(
        log_event(
            step_name="make_plan",
            status="success",
            duration_ms=(time.time() - step_start) * 1000,
            artifacts_generated=["plan", "route_decision"],
            context={
                "buckets": route_decision.buckets,
                "reasons": route_decision.reasons,
            },
        )
    )

    # Covariates are being dropped (QA guardrail or alignment failure): this is
    # only acceptable when the plan allows it, and it is reported as a fallback.
    if covariate_error is not None:
        if not plan.allow_drop_covariates:
            raise covariate_error
        fallbacks_triggered.append(
            {
                "type": "covariates_dropped",
                "error": str(covariate_error),
            }
        )

    # Step 5: Backtest (if standard or strict mode)
    backtest_report = None
    if mode in ("standard", "strict"):
        step_start = time.time()
        try:
            from tsagentkit.models import fit as default_fit
            from tsagentkit.models import predict as default_predict

            if fit_func is None:
                fit_func = default_fit
            if predict_func is None:
                predict_func = default_predict

            backtest_cfg = task_spec.backtest
            n_windows = backtest_cfg.n_windows
            min_train_size = backtest_cfg.min_train_size

            backtest_report = rolling_backtest(
                dataset=dataset,
                spec=task_spec,
                plan=plan,
                fit_func=fit_func,
                predict_func=predict_func,
                n_windows=n_windows,
                step_size=task_spec.horizon,
                min_train_size=min_train_size,
                route_decision=route_decision,
            )
            events.append(
                log_event(
                    step_name="rolling_backtest",
                    status="success",
                    duration_ms=(time.time() - step_start) * 1000,
                    artifacts_generated=["backtest_report"],
                )
            )
        except Exception as e:
            events.append(
                log_event(
                    step_name="rolling_backtest",
                    status="failed",
                    duration_ms=(time.time() - step_start) * 1000,
                    error_code=type(e).__name__,
                )
            )
            if mode == "strict":
                raise

    # Step 6: Fit Model
    step_start = time.time()

    def on_fallback(from_model: str, to_model: str, error: Exception) -> None:
        # Collect every model switch so it ends up in provenance.
        fallbacks_triggered.append(
            {
                "from": from_model,
                "to": to_model,
                "error": str(error),
            }
        )

    model_artifact = _step_fit(
        dataset=dataset,
        plan=plan,
        fit_func=fit_func,
        on_fallback=on_fallback,
        covariates=aligned_dataset,
    )

    events.append(
        log_event(
            step_name="fit",
            status="success",
            duration_ms=(time.time() - step_start) * 1000,
            artifacts_generated=["model_artifact"],
        )
    )

    # Step 7: Predict
    step_start = time.time()
    try:
        forecast_df = _step_predict(
            artifact=model_artifact,
            dataset=dataset,
            task_spec=task_spec,
            predict_func=predict_func,
            plan=plan,
            covariates=aligned_dataset,
        )
        events.append(
            log_event(
                step_name="predict",
                status="success",
                duration_ms=(time.time() - step_start) * 1000,
                artifacts_generated=["forecast"],
            )
        )
    except Exception as e:
        events.append(
            log_event(
                step_name="predict",
                status="failed",
                duration_ms=(time.time() - step_start) * 1000,
                error_code=type(e).__name__,
            )
        )
        # Prediction failed for the fitted model: walk the remaining fallback
        # candidates, starting after the model that just failed.
        try:
            model_artifact, forecast_df = _fit_predict_with_fallback(
                dataset=dataset,
                plan=plan,
                task_spec=task_spec,
                fit_func=fit_func,
                predict_func=predict_func,
                covariates=aligned_dataset,
                start_after=model_artifact.model_name,
                initial_error=e,
                on_fallback=on_fallback,
            )
            events.append(
                log_event(
                    step_name="predict_fallback",
                    status="success",
                    duration_ms=(time.time() - step_start) * 1000,
                    artifacts_generated=["forecast"],
                )
            )
        except Exception as fallback_error:
            events.append(
                log_event(
                    step_name="predict_fallback",
                    status="failed",
                    duration_ms=(time.time() - step_start) * 1000,
                    error_code=type(fallback_error).__name__,
                )
            )
            raise

    # Step 8: Calibration (optional)
    calibration_artifact = None
    if calibrator_spec is not None:
        step_start = time.time()
        try:
            from tsagentkit.calibration import apply_calibrator, fit_calibrator

            # Calibration is fit on backtest residuals, so quick mode (no
            # backtest) always fails here.
            if backtest_report is None or backtest_report.cv_frame is None:
                raise ECalibrationFail(
                    "Calibration requires CV residuals from backtest.",
                    context={"mode": mode},
                )

            cv_frame = backtest_report.cv_frame
            if hasattr(cv_frame, "df"):
                cv_frame = cv_frame.df

            calibration_artifact = fit_calibrator(
                cv_frame,
                method=calibrator_spec.method,
                level=calibrator_spec.level,
                by=calibrator_spec.by,
            )
            forecast_df = apply_calibrator(forecast_df, calibration_artifact)
            events.append(
                log_event(
                    step_name="calibration",
                    status="success",
                    duration_ms=(time.time() - step_start) * 1000,
                    artifacts_generated=["calibration_artifact"],
                )
            )
        except Exception as e:
            events.append(
                log_event(
                    step_name="calibration",
                    status="failed",
                    duration_ms=(time.time() - step_start) * 1000,
                    error_code=type(e).__name__,
                )
            )
            if mode == "strict":
                raise

    # Step 9: Anomaly Detection (optional)
    anomaly_report = None
    if anomaly_spec is not None:
        step_start = time.time()
        try:
            from tsagentkit.anomaly import detect_anomalies

            uid_col = task_spec.panel_contract.unique_id_col
            ds_col = task_spec.panel_contract.ds_col
            y_col = task_spec.panel_contract.y_col

            # Attach observed values (where any exist) to the forecast rows.
            actuals = data[[uid_col, ds_col, y_col]].copy()
            merged = forecast_df.merge(
                actuals,
                on=[uid_col, ds_col],
                how="left",
            )
            if merged[y_col].notna().any():
                anomaly_report = detect_anomalies(
                    merged,
                    method=anomaly_spec.method,
                    level=anomaly_spec.level,
                    score=anomaly_spec.score,
                    calibrator=calibration_artifact,
                    strict=(mode == "strict"),
                )
                events.append(
                    log_event(
                        step_name="anomaly_detection",
                        status="success",
                        duration_ms=(time.time() - step_start) * 1000,
                        artifacts_generated=["anomaly_report"],
                    )
                )
            elif mode == "strict":
                raise EAnomalyFail(
                    "No actuals available for anomaly detection.",
                    context={"mode": mode},
                )
        except Exception as e:
            events.append(
                log_event(
                    step_name="anomaly_detection",
                    status="failed",
                    duration_ms=(time.time() - step_start) * 1000,
                    error_code=type(e).__name__,
                )
            )
            if mode == "strict":
                raise

    # Step 10: Drift Detection (v0.2)
    # Drift failures are logged but never re-raised, even in strict mode.
    drift_report = None
    if monitoring_config and monitoring_config.enabled and reference_data is not None:
        step_start = time.time()
        try:
            from tsagentkit.monitoring import DriftDetector

            detector = DriftDetector(
                method=monitoring_config.drift_method,
                threshold=monitoring_config.drift_threshold,
            )
            drift_report = detector.detect(
                reference_data=reference_data,
                current_data=data,
            )
            events.append(
                log_event(
                    step_name="drift_detection",
                    status="success",
                    duration_ms=(time.time() - step_start) * 1000,
                    artifacts_generated=["drift_report"],
                )
            )
        except Exception as e:
            events.append(
                log_event(
                    step_name="drift_detection",
                    status="failed",
                    duration_ms=(time.time() - step_start) * 1000,
                    error_code=type(e).__name__,
                )
            )

    # Step 11: Create Provenance
    provenance = create_provenance(
        data=data,
        task_spec=task_spec,
        plan=plan,
        model_config=plan.model_dump() if hasattr(plan, "model_dump") else None,
        qa_repairs=qa_repairs,
        fallbacks_triggered=fallbacks_triggered,
        feature_matrix=feature_matrix,
        drift_report=drift_report,
        column_map=column_map,
        original_panel_contract=(
            original_panel_contract.model_dump()
            if hasattr(original_panel_contract, "model_dump")
            else None
        ),
        route_decision=route_decision,
    )

    # Step 12: Package
    forecast_result = ForecastResult(
        df=forecast_df,
        provenance=provenance,
        model_name=model_artifact.model_name,
        horizon=task_spec.horizon,
    )
    artifact = package_run(
        forecast=forecast_result,
        plan=plan,
        task_spec=task_spec.model_dump() if hasattr(task_spec, "model_dump") else None,
        validation_report=validation.to_dict() if validation else None,
        backtest_report=backtest_report,
        qa_report=qa_report,
        model_artifact=model_artifact,
        provenance=provenance,
        calibration_artifact=calibration_artifact,
        anomaly_report=anomaly_report,
        metadata={
            "mode": mode,
            "total_duration_ms": (time.time() - start_time) * 1000,
            "events": events,
        },
    )

    return artifact
|
|
731
|
+
|
|
732
|
+
|
|
733
|
+
def _step_validate(
    data: pd.DataFrame,
    task_spec: TaskSpec,
) -> tuple[ValidationReport, pd.DataFrame]:
    """Check ``data`` against the task's panel contract.

    Delegates to ``validate_contract`` with aggregation enabled and asks it to
    return the (possibly aggregated) frame alongside the report.

    Returns:
        A ``(report, frame)`` pair: the validation report and the data frame
        produced by ``validate_contract``.
    """
    contract = task_spec.panel_contract
    validation_report, validated_frame = validate_contract(
        data,
        panel_contract=contract,
        apply_aggregation=True,
        return_data=True,
    )
    return validation_report, validated_frame
|
|
745
|
+
|
|
746
|
+
|
|
747
|
+
def _step_qa(
    data: pd.DataFrame,
    task_spec: TaskSpec,
    mode: Literal["quick", "standard", "strict"],
    apply_repairs: bool = False,
    repair_strategy: dict[str, Any] | None = None,
    skip_covariate_checks: bool = False,
) -> QAReport:
    """Run data-quality checks and escalate problems in strict mode.

    For v0.1, this is a minimal implementation: it forwards every option to
    ``run_qa`` and, only when ``mode == "strict"``, turns findings into
    exceptions.

    Raises:
        ECovariateLeakage: Strict mode and the report flags leakage.
        EQACriticalIssue: Strict mode and the report has critical issues.
    """
    qa_result = run_qa(
        data,
        task_spec,
        mode,
        apply_repairs=apply_repairs,
        repair_strategy=repair_strategy,
        skip_covariate_checks=skip_covariate_checks,
    )

    # Outside strict mode the report is returned as-is, issues included.
    if mode != "strict":
        return qa_result

    # Leakage is checked before generic critical issues.
    if qa_result.leakage_detected:
        raise ECovariateLeakage("Covariate leakage detected")
    if qa_result.has_critical_issues():
        raise EQACriticalIssue("Critical QA issues detected")
    return qa_result
|
|
775
|
+
|
|
776
|
+
|
|
777
|
+
def _step_fit(
    dataset: TSDataset,
    plan: Any,
    fit_func: Any | None,
    on_fallback: Any | None = None,
    covariates: AlignedDataset | None = None,
) -> Any:
    """Fit a model for ``plan``, using the library default when none is given.

    The built-in ``tsagentkit.models.fit`` receives ``on_fallback`` so model
    switches are reported. A custom ``fit_func`` is invoked through
    ``_call_with_optional_kwargs`` and does not receive ``on_fallback``.
    ``covariates`` is passed only when it is not ``None``.
    """
    from tsagentkit.models import fit as default_fit

    extra = {} if covariates is None else {"covariates": covariates}
    chosen_fit = default_fit if fit_func is None else fit_func

    if chosen_fit is not default_fit:
        return _call_with_optional_kwargs(chosen_fit, dataset, plan, **extra)
    return chosen_fit(dataset, plan, on_fallback=on_fallback, **extra)
|
|
794
|
+
|
|
795
|
+
|
|
796
|
+
def _step_predict(
    artifact: Any,
    dataset: TSDataset,
    task_spec: TaskSpec,
    predict_func: Any | None,
    plan: Any | None = None,
    covariates: AlignedDataset | None = None,
) -> pd.DataFrame:
    """Produce the forecast frame for a fitted model artifact.

    Steps:
        1. Call ``predict_func`` (default: ``tsagentkit.models.predict``),
           passing ``covariates`` only when it is not ``None``.
        2. Unwrap a ``ForecastResult`` into its DataFrame.
        3. Add a ``model`` column if missing. The name comes from
           ``artifact.model_name``, then ``artifact.metadata["model_name"]``,
           and finally the literal ``"model"``.
        4. Reconcile across the hierarchy when a plan is given and the dataset
           is hierarchical. Bottom-up is the method currently used.
        5. Normalize quantile columns and sort by ``unique_id``/``ds`` when
           both columns exist.
    """
    if predict_func is None:
        from tsagentkit.models import predict as default_predict

        predict_func = default_predict

    extra_kwargs = {} if covariates is None else {"covariates": covariates}
    raw = _call_with_optional_kwargs(predict_func, dataset, artifact, task_spec, **extra_kwargs)
    forecast = raw.df if isinstance(raw, ForecastResult) else raw

    if "model" not in forecast.columns:
        label = getattr(artifact, "model_name", None)
        if label is None and hasattr(artifact, "metadata"):
            meta = artifact.metadata
            label = meta.get("model_name") if meta else None
        # Copy before adding the column so the predictor's frame is not mutated.
        forecast = forecast.copy()
        forecast["model"] = label or "model"

    if plan and dataset.is_hierarchical() and dataset.hierarchy:
        from tsagentkit.hierarchy import ReconciliationMethod, reconcile_forecasts

        method_by_name = {
            "bottom_up": ReconciliationMethod.BOTTOM_UP,
            "top_down": ReconciliationMethod.TOP_DOWN,
            "middle_out": ReconciliationMethod.MIDDLE_OUT,
            "ols": ReconciliationMethod.OLS,
            "wls": ReconciliationMethod.WLS,
            "min_trace": ReconciliationMethod.MIN_TRACE,
        }
        # The selected method name is fixed to "bottom_up" for now; the table
        # lists the methods the hierarchy module exposes.
        chosen_method = method_by_name.get("bottom_up", ReconciliationMethod.BOTTOM_UP)
        forecast = reconcile_forecasts(
            base_forecasts=forecast,
            structure=dataset.hierarchy,
            method=chosen_method,
        )

    forecast = normalize_quantile_columns(forecast)
    order_by = ["unique_id", "ds"]
    if set(order_by) <= set(forecast.columns):
        forecast = forecast.sort_values(order_by).reset_index(drop=True)

    return forecast
|
|
849
|
+
|
|
850
|
+
|
|
851
|
+
def _fit_predict_with_fallback(
    dataset: TSDataset,
    plan: Any,
    task_spec: TaskSpec,
    fit_func: Any | None,
    predict_func: Any | None,
    covariates: AlignedDataset | None = None,
    start_after: str | None = None,
    initial_error: Exception | None = None,
    on_fallback: Any | None = None,
) -> tuple[Any, pd.DataFrame]:
    """Fit and predict with fallback across remaining candidates.

    Walks ``plan.candidate_models`` in order. If ``start_after`` is one of
    the candidates, the walk starts just after it; otherwise it starts from
    the first candidate. Each candidate is fitted with a copy of the plan
    restricted to that single model (when the plan provides ``model_copy``),
    then forecast through ``_step_predict``. The first candidate whose fit
    and predict both succeed is returned.

    Args:
        dataset: Dataset to fit and forecast.
        plan: Model plan; its ``candidate_models`` define the fallback order.
        task_spec: Task specification forwarded to prediction.
        fit_func: Custom fit callable, or ``None`` for
            ``tsagentkit.models.fit``.
        predict_func: Custom predict callable, or ``None`` for
            ``tsagentkit.models.predict``.
        covariates: Aligned covariates forwarded to fit and predict.
        start_after: Name of a candidate that already failed.
        initial_error: Exception raised by ``start_after``. It is reported
            to ``on_fallback`` and is also the reported error if no
            candidate remains to try.
        on_fallback: Optional callback ``(failed_model, next_model, error)``
            invoked whenever the search moves on to another candidate.

    Returns:
        ``(artifact, forecast)`` from the first successful candidate.

    Raises:
        EFallbackExhausted: If every remaining candidate fails, or none
            remain. The last underlying error is chained as ``__cause__``.
    """
    from tsagentkit.models import fit as default_fit
    from tsagentkit.models import predict as default_predict

    fit_callable = fit_func or default_fit
    predict_callable = predict_func or default_predict

    candidates = list(getattr(plan, "candidate_models", []) or [])
    start_idx = 0
    if start_after in candidates:
        start_idx = candidates.index(start_after) + 1
    remaining = candidates[start_idx:]

    # Seed with the error that triggered this fallback. If no candidates
    # remain, the raised error then still explains the original failure
    # instead of reporting "None".
    last_error: Exception | None = initial_error

    if start_after and remaining and on_fallback and initial_error is not None:
        on_fallback(start_after, remaining[0], initial_error)

    for i, model_name in enumerate(remaining):
        plan_for_model = plan
        if hasattr(plan, "model_copy"):
            plan_for_model = plan.model_copy(update={"candidate_models": [model_name]})

        try:
            artifact = _call_with_optional_kwargs(
                fit_callable,
                dataset,
                plan_for_model,
                covariates=covariates,
            )
            forecast = _step_predict(
                artifact=artifact,
                dataset=dataset,
                task_spec=task_spec,
                predict_func=predict_callable,
                plan=plan,
                covariates=covariates,
            )
        except Exception as e:
            # A fit or predict failure on this candidate: record it and
            # move on, announcing the switch if another candidate exists.
            last_error = e
            if on_fallback and i < len(remaining) - 1:
                on_fallback(model_name, remaining[i + 1], e)
            continue

        return artifact, forecast

    raise EFallbackExhausted(
        f"All models failed during predict fallback. Last error: {last_error}",
        context={
            "models_attempted": remaining,
            "last_error": str(last_error),
        },
    ) from last_error
|
|
921
|
+
|
|
922
|
+
|
|
923
|
+
def _get_error_code(validation: ValidationReport) -> str | None:
|
|
924
|
+
"""Extract error code from validation report."""
|
|
925
|
+
if validation.errors:
|
|
926
|
+
return validation.errors[0].get("code")
|
|
927
|
+
return None
|
|
928
|
+
|
|
929
|
+
|
|
930
|
+
def _call_with_optional_kwargs(func: Any, *args: Any, **kwargs: Any) -> Any:
|
|
931
|
+
"""Call a function with only supported keyword arguments."""
|
|
932
|
+
if not kwargs:
|
|
933
|
+
return func(*args)
|
|
934
|
+
|
|
935
|
+
try:
|
|
936
|
+
import inspect
|
|
937
|
+
|
|
938
|
+
params = inspect.signature(func).parameters
|
|
939
|
+
accepted = {k: v for k, v in kwargs.items() if k in params}
|
|
940
|
+
return func(*args, **accepted)
|
|
941
|
+
except Exception:
|
|
942
|
+
# Fall back to direct call if signature inspection fails
|
|
943
|
+
return func(*args, **kwargs)
|