tsagentkit-1.0.2-py3-none-any.whl
This diff shows the contents of a publicly released package version as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- tsagentkit/__init__.py +126 -0
- tsagentkit/anomaly/__init__.py +130 -0
- tsagentkit/backtest/__init__.py +48 -0
- tsagentkit/backtest/engine.py +788 -0
- tsagentkit/backtest/metrics.py +244 -0
- tsagentkit/backtest/report.py +342 -0
- tsagentkit/calibration/__init__.py +136 -0
- tsagentkit/contracts/__init__.py +133 -0
- tsagentkit/contracts/errors.py +275 -0
- tsagentkit/contracts/results.py +418 -0
- tsagentkit/contracts/schema.py +44 -0
- tsagentkit/contracts/task_spec.py +300 -0
- tsagentkit/covariates/__init__.py +340 -0
- tsagentkit/eval/__init__.py +285 -0
- tsagentkit/features/__init__.py +20 -0
- tsagentkit/features/covariates.py +328 -0
- tsagentkit/features/extra/__init__.py +5 -0
- tsagentkit/features/extra/native.py +179 -0
- tsagentkit/features/factory.py +187 -0
- tsagentkit/features/matrix.py +159 -0
- tsagentkit/features/tsfeatures_adapter.py +115 -0
- tsagentkit/features/versioning.py +203 -0
- tsagentkit/hierarchy/__init__.py +39 -0
- tsagentkit/hierarchy/aggregation.py +62 -0
- tsagentkit/hierarchy/evaluator.py +400 -0
- tsagentkit/hierarchy/reconciliation.py +232 -0
- tsagentkit/hierarchy/structure.py +453 -0
- tsagentkit/models/__init__.py +182 -0
- tsagentkit/models/adapters/__init__.py +83 -0
- tsagentkit/models/adapters/base.py +321 -0
- tsagentkit/models/adapters/chronos.py +387 -0
- tsagentkit/models/adapters/moirai.py +256 -0
- tsagentkit/models/adapters/registry.py +171 -0
- tsagentkit/models/adapters/timesfm.py +440 -0
- tsagentkit/models/baselines.py +207 -0
- tsagentkit/models/sktime.py +307 -0
- tsagentkit/monitoring/__init__.py +51 -0
- tsagentkit/monitoring/alerts.py +302 -0
- tsagentkit/monitoring/coverage.py +203 -0
- tsagentkit/monitoring/drift.py +330 -0
- tsagentkit/monitoring/report.py +214 -0
- tsagentkit/monitoring/stability.py +275 -0
- tsagentkit/monitoring/triggers.py +423 -0
- tsagentkit/qa/__init__.py +347 -0
- tsagentkit/router/__init__.py +37 -0
- tsagentkit/router/bucketing.py +489 -0
- tsagentkit/router/fallback.py +132 -0
- tsagentkit/router/plan.py +23 -0
- tsagentkit/router/router.py +271 -0
- tsagentkit/series/__init__.py +26 -0
- tsagentkit/series/alignment.py +206 -0
- tsagentkit/series/dataset.py +449 -0
- tsagentkit/series/sparsity.py +261 -0
- tsagentkit/series/validation.py +393 -0
- tsagentkit/serving/__init__.py +39 -0
- tsagentkit/serving/orchestration.py +943 -0
- tsagentkit/serving/packaging.py +73 -0
- tsagentkit/serving/provenance.py +317 -0
- tsagentkit/serving/tsfm_cache.py +214 -0
- tsagentkit/skill/README.md +135 -0
- tsagentkit/skill/__init__.py +8 -0
- tsagentkit/skill/recipes.md +429 -0
- tsagentkit/skill/tool_map.md +21 -0
- tsagentkit/time/__init__.py +134 -0
- tsagentkit/utils/__init__.py +20 -0
- tsagentkit/utils/quantiles.py +83 -0
- tsagentkit/utils/signature.py +47 -0
- tsagentkit/utils/temporal.py +41 -0
- tsagentkit-1.0.2.dist-info/METADATA +371 -0
- tsagentkit-1.0.2.dist-info/RECORD +72 -0
- tsagentkit-1.0.2.dist-info/WHEEL +4 -0
- tsagentkit-1.0.2.dist-info/licenses/LICENSE +201 -0
tsagentkit/backtest/metrics.py (new file)
@@ -0,0 +1,244 @@

"""Metrics calculation for time series forecasting.

Deprecated: use ``tsagentkit.eval.evaluate_forecasts`` (utilsforecast-backed)
for new evaluation paths. This module is retained for backward compatibility
and will be removed in a future phase.
"""

from __future__ import annotations

import warnings
from collections.abc import Iterable

import numpy as np
import pandas as pd

from tsagentkit.eval import evaluate_forecasts
from tsagentkit.utils import quantile_col_name

_DEPRECATION_MESSAGE = (
    "tsagentkit.backtest.metrics is deprecated; "
    "use tsagentkit.eval.evaluate_forecasts instead."
)


def _warn_deprecated(name: str) -> None:
    warnings.warn(
        f"{name} is deprecated. {_DEPRECATION_MESSAGE}",
        DeprecationWarning,
        stacklevel=2,
    )


def _as_array(values: Iterable[float]) -> np.ndarray:
    return np.asarray(values, dtype=float)


def _build_eval_frame(
    y_true: Iterable[float],
    y_pred: Iterable[float],
    y_quantiles: dict[float, Iterable[float]] | None = None,
) -> pd.DataFrame:
    y_true_arr = _as_array(y_true)
    y_pred_arr = _as_array(y_pred)
    if y_true_arr.shape != y_pred_arr.shape:
        raise ValueError("y_true and y_pred must have the same shape.")

    n = y_true_arr.shape[0]
    df = pd.DataFrame(
        {
            "unique_id": ["series"] * n,
            "ds": pd.date_range("2000-01-01", periods=n, freq="D"),
            "y": y_true_arr,
            "yhat": y_pred_arr,
            "model": "model",
        }
    )

    if y_quantiles:
        for q, values in y_quantiles.items():
            q_values = _as_array(values)
            if q_values.shape != y_true_arr.shape:
                raise ValueError("Quantile predictions must match y_true shape.")
            df[quantile_col_name(float(q))] = q_values

    return df


def _build_train_frame(
    y_train: Iterable[float],
) -> pd.DataFrame:
    y_train_arr = _as_array(y_train)
    n = y_train_arr.shape[0]
    return pd.DataFrame(
        {
            "unique_id": ["series"] * n,
            "ds": pd.date_range("1999-01-01", periods=n, freq="D"),
            "y": y_train_arr,
        }
    )


def _summary_to_metrics(summary_df: pd.DataFrame, model_name: str = "model") -> dict[str, float]:
    if summary_df.empty:
        return {}
    df = summary_df
    if "model" in df.columns:
        df = df[df["model"] == model_name]
    if df.empty or "metric" not in df.columns or "value" not in df.columns:
        return {}
    return {row["metric"]: float(row["value"]) for _, row in df.iterrows()}


def wape(y_true: Iterable[float], y_pred: Iterable[float]) -> float:
    """Weighted Absolute Percentage Error (utilsforecast ND)."""
    _warn_deprecated("wape")
    df = _build_eval_frame(y_true, y_pred)
    _, summary = evaluate_forecasts(df)
    metrics = _summary_to_metrics(summary.df)
    return float(metrics.get("wape", np.nan))


def smape(y_true: Iterable[float], y_pred: Iterable[float]) -> float:
    """Symmetric Mean Absolute Percentage Error."""
    _warn_deprecated("smape")
    df = _build_eval_frame(y_true, y_pred)
    _, summary = evaluate_forecasts(df)
    metrics = _summary_to_metrics(summary.df)
    return float(metrics.get("smape", np.nan))


def mase(
    y_true: Iterable[float],
    y_pred: Iterable[float],
    y_train: Iterable[float],
    season_length: int = 1,
) -> float:
    """Mean Absolute Scaled Error."""
    _warn_deprecated("mase")
    df = _build_eval_frame(y_true, y_pred)
    train_df = _build_train_frame(y_train)
    _, summary = evaluate_forecasts(df, train_df=train_df, season_length=season_length)
    metrics = _summary_to_metrics(summary.df)
    value = float(metrics.get("mase", np.nan))
    if np.isinf(value):
        return float("nan")
    return value


def pinball_loss(
    y_true: Iterable[float],
    y_quantile: Iterable[float],
    tau: float,
) -> float:
    """Pinball loss for quantile forecasts."""
    _warn_deprecated("pinball_loss")
    df = _build_eval_frame(y_true, y_quantile, y_quantiles={tau: y_quantile})
    _, summary = evaluate_forecasts(df)
    metrics = _summary_to_metrics(summary.df)
    return float(metrics.get(f"pinball_{tau:.3f}", np.nan))


def wql(y_true: Iterable[float], y_quantiles: dict[float, Iterable[float]]) -> float:
    """Weighted Quantile Loss (average pinball loss across quantiles)."""
    _warn_deprecated("wql")
    df = _build_eval_frame(y_true, y_true, y_quantiles=y_quantiles)
    _, summary = evaluate_forecasts(df)
    metrics = _summary_to_metrics(summary.df)
    return float(metrics.get("wql", np.nan))


def mae(y_true: Iterable[float], y_pred: Iterable[float]) -> float:
    """Mean Absolute Error."""
    _warn_deprecated("mae")
    df = _build_eval_frame(y_true, y_pred)
    _, summary = evaluate_forecasts(df)
    metrics = _summary_to_metrics(summary.df)
    return float(metrics.get("mae", np.nan))


def rmse(y_true: Iterable[float], y_pred: Iterable[float]) -> float:
    """Root Mean Squared Error."""
    _warn_deprecated("rmse")
    df = _build_eval_frame(y_true, y_pred)
    _, summary = evaluate_forecasts(df)
    metrics = _summary_to_metrics(summary.df)
    return float(metrics.get("rmse", np.nan))


def compute_all_metrics(
    y_true: Iterable[float],
    y_pred: Iterable[float],
    y_train: Iterable[float] | None = None,
    season_length: int = 1,
    y_quantiles: dict[float, Iterable[float]] | None = None,
) -> dict[str, float]:
    """Compute all standard metrics via utilsforecast.evaluate."""
    _warn_deprecated("compute_all_metrics")
    df = _build_eval_frame(y_true, y_pred, y_quantiles=y_quantiles)
    train_df = _build_train_frame(y_train) if y_train is not None else None
    _, summary = evaluate_forecasts(
        df,
        train_df=train_df,
        season_length=season_length if train_df is not None else None,
    )
    metrics = _summary_to_metrics(summary.df)

    if "mase" not in metrics:
        metrics["mase"] = float("nan")

    if y_quantiles:
        for q in y_quantiles:
            source_key = f"pinball_{q:.3f}"
            target_key = f"pinball_{q:.2f}"
            if source_key in metrics and target_key not in metrics:
                metrics[target_key] = metrics[source_key]

    return metrics


def compute_metrics_by_series(
    df: pd.DataFrame,
    id_col: str = "unique_id",
    actual_col: str = "y",
    pred_col: str = "yhat",
) -> dict[str, dict[str, float]]:
    """Compute metrics for each series separately via utilsforecast.evaluate."""
    _warn_deprecated("compute_metrics_by_series")
    if df.empty:
        return {}

    working = df.copy()
    if "model" not in working.columns:
        working["model"] = "model"
    if "ds" not in working.columns:
        working["ds"] = working.groupby(id_col).cumcount()

    metric_frame, _ = evaluate_forecasts(
        working,
        id_col=id_col,
        ds_col="ds",
        target_col=actual_col,
        model_col="model",
        pred_col=pred_col,
        cutoff_col=None,
    )

    metrics_df = metric_frame.df
    if metrics_df is None or metrics_df.empty:
        return {}

    if "model" in metrics_df.columns:
        metrics_df = metrics_df[metrics_df["model"] == "model"]

    if id_col not in metrics_df.columns or "metric" not in metrics_df.columns:
        return {}

    grouped = metrics_df.groupby([id_col, "metric"])["value"].mean().reset_index()
    result: dict[str, dict[str, float]] = {}
    for uid, group in grouped.groupby(id_col):
        result[str(uid)] = {
            row["metric"]: float(row["value"]) for _, row in group.iterrows()
        }

    return result
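For callers migrating off this shim, the sketch below shows the replacement path. It is inferred only from how the wrappers above build their frames (the unique_id/ds/y/yhat/model column layout, the (per_series, summary) return pair, and the long-format model/metric/value rows in summary.df), so treat it as an assumption about tsagentkit.eval.evaluate_forecasts rather than its documented contract; the data values are illustrative.

import pandas as pd

from tsagentkit.eval import evaluate_forecasts

# Input layout mirrors the frames the deprecated wrappers build internally:
# one row per (unique_id, ds) with actuals in "y" and predictions in "yhat".
df = pd.DataFrame(
    {
        "unique_id": ["sku_1"] * 4,
        "ds": pd.date_range("2024-01-01", periods=4, freq="D"),
        "y": [10.0, 12.0, 9.0, 11.0],
        "yhat": [11.0, 11.0, 10.0, 10.0],
        "model": "model",
    }
)

# Before (deprecated shim, emits DeprecationWarning):
#   from tsagentkit.backtest.metrics import wape
#   score = wape(df["y"], df["yhat"])

# After: one evaluator call; the summary frame holds (model, metric, value) rows.
per_series, summary = evaluate_forecasts(df)
scores = {row["metric"]: float(row["value"]) for _, row in summary.df.iterrows()}
print(scores.get("wape"), scores.get("smape"))

One call computes every supported metric at once, which is also why each deprecated wrapper above pays the full evaluate_forecasts cost just to pluck a single key.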
tsagentkit/backtest/report.py (new file)
@@ -0,0 +1,342 @@

"""Backtest report structures.

Defines data classes for backtest results and diagnostics.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

import pandas as pd

from tsagentkit.contracts import CVFrame


@dataclass(frozen=True)
class WindowResult:
    """Results from a single backtest window.

    Attributes:
        window_index: Index of the window (0-based)
        train_start: Start date of training set
        train_end: End date of training set
        test_start: Start date of test set
        test_end: End date of test set
        metrics: Dictionary of metrics for this window
        num_series: Number of series evaluated
        num_observations: Number of observations in test set
    """

    window_index: int
    train_start: str
    train_end: str
    test_start: str
    test_end: str
    metrics: dict[str, float] = field(default_factory=dict)
    num_series: int = 0
    num_observations: int = 0

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return {
            "window_index": self.window_index,
            "train_start": self.train_start,
            "train_end": self.train_end,
            "test_start": self.test_start,
            "test_end": self.test_end,
            "metrics": self.metrics,
            "num_series": self.num_series,
            "num_observations": self.num_observations,
        }


@dataclass(frozen=True)
class SeriesMetrics:
    """Metrics aggregated by series.

    Attributes:
        series_id: Unique identifier for the series
        metrics: Dictionary of metric name to value
        num_windows: Number of windows this series appeared in
    """

    series_id: str
    metrics: dict[str, float] = field(default_factory=dict)
    num_windows: int = 0

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return {
            "series_id": self.series_id,
            "metrics": self.metrics,
            "num_windows": self.num_windows,
        }


@dataclass(frozen=True)
class SegmentMetrics:
    """Metrics aggregated by segment (e.g., sparsity class).

    Attributes:
        segment_name: Name of the segment (e.g., "intermittent", "regular")
        series_ids: List of series IDs in this segment
        metrics: Dictionary of metric name to value
        n_series: Number of series in this segment
    """

    segment_name: str
    series_ids: list[str] = field(default_factory=list)
    metrics: dict[str, float] = field(default_factory=dict)
    n_series: int = 0

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return {
            "segment_name": self.segment_name,
            "series_ids": self.series_ids,
            "metrics": self.metrics,
            "n_series": self.n_series,
        }


@dataclass(frozen=True)
class TemporalMetrics:
    """Metrics aggregated by temporal dimension.

    Attributes:
        dimension: Temporal dimension (e.g., "hour", "dayofweek")
        metrics_by_value: Dictionary mapping dimension value to metrics
    """

    dimension: str
    metrics_by_value: dict[str, dict[str, float]] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return {
            "dimension": self.dimension,
            "metrics_by_value": self.metrics_by_value,
        }


@dataclass(frozen=True)
class BacktestReport:
    """Complete backtest results.

    Contains window-level results, aggregate metrics, and series-level
    diagnostics for a backtest run.

    Attributes:
        n_windows: Number of backtest windows
        strategy: Window strategy ("expanding" or "sliding")
        window_results: List of results per window
        aggregate_metrics: Metrics aggregated across all windows
        series_metrics: Metrics aggregated by series
        segment_metrics: Metrics aggregated by segment (sparsity class)
        temporal_metrics: Metrics aggregated by temporal dimension
        errors: List of errors encountered during backtest
        metadata: Additional backtest metadata
    """

    n_windows: int
    strategy: str
    window_results: list[WindowResult] = field(default_factory=list)
    aggregate_metrics: dict[str, float] = field(default_factory=dict)
    series_metrics: dict[str, SeriesMetrics] = field(default_factory=dict)
    segment_metrics: dict[str, SegmentMetrics] = field(default_factory=dict)
    temporal_metrics: dict[str, TemporalMetrics] = field(default_factory=dict)
    errors: list[dict[str, Any]] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)
    cv_frame: pd.DataFrame | CVFrame | None = None

    def _cv_dataframe(self) -> pd.DataFrame | None:
        if isinstance(self.cv_frame, CVFrame):
            return self.cv_frame.df
        return self.cv_frame

    def get_metric(self, metric_name: str) -> float:
        """Get an aggregate metric by name.

        Args:
            metric_name: Name of the metric (e.g., "wape", "smape")

        Returns:
            Metric value or NaN if not found
        """
        return self.aggregate_metrics.get(metric_name, float("nan"))

    def get_series_metric(self, series_id: str, metric_name: str) -> float:
        """Get a metric for a specific series.

        Args:
            series_id: Series identifier
            metric_name: Name of the metric

        Returns:
            Metric value or NaN if not found
        """
        if series_id not in self.series_metrics:
            return float("nan")
        return self.series_metrics[series_id].metrics.get(metric_name, float("nan"))

    def get_best_series(self, metric_name: str = "wape") -> str | None:
        """Get the series with the best performance for a metric.

        Args:
            metric_name: Metric to optimize (lower is better)

        Returns:
            Series ID with lowest metric value, or None if no data
        """
        if not self.series_metrics:
            return None

        best_id = None
        best_value = float("inf")

        for sid, sm in self.series_metrics.items():
            value = sm.metrics.get(metric_name, float("inf"))
            if value < best_value:
                best_value = value
                best_id = sid

        return best_id

    def get_worst_series(self, metric_name: str = "wape") -> str | None:
        """Get the series with the worst performance for a metric.

        Args:
            metric_name: Metric to check (higher is worse)

        Returns:
            Series ID with highest metric value, or None if no data
        """
        if not self.series_metrics:
            return None

        worst_id = None
        worst_value = float("-inf")

        for sid, sm in self.series_metrics.items():
            value = sm.metrics.get(metric_name, float("-inf"))
            if value > worst_value and not pd.isna(value):
                worst_value = value
                worst_id = sid

        return worst_id

    def get_segment_metric(self, segment_name: str, metric_name: str) -> float:
        """Get a metric for a specific segment.

        Args:
            segment_name: Segment name (e.g., "intermittent", "regular")
            metric_name: Name of the metric

        Returns:
            Metric value or NaN if not found
        """
        if segment_name not in self.segment_metrics:
            return float("nan")
        return self.segment_metrics[segment_name].metrics.get(metric_name, float("nan"))

    def get_temporal_metric(self, dimension: str, value: str, metric_name: str) -> float:
        """Get a metric for a specific temporal dimension value.

        Args:
            dimension: Temporal dimension ("hour", "dayofweek", etc.)
            value: Dimension value (e.g., "0" for Sunday, "14" for 2pm)
            metric_name: Name of the metric

        Returns:
            Metric value or NaN if not found
        """
        if dimension not in self.temporal_metrics:
            return float("nan")
        dim_metrics = self.temporal_metrics[dimension]
        if value not in dim_metrics.metrics_by_value:
            return float("nan")
        return dim_metrics.metrics_by_value[value].get(metric_name, float("nan"))

    def compare_segments(self, metric_name: str = "wape") -> dict[str, float]:
        """Compare a metric across all segments.

        Args:
            metric_name: Metric to compare

        Returns:
            Dictionary mapping segment name to metric value
        """
        return {
            name: sm.metrics.get(metric_name, float("nan"))
            for name, sm in self.segment_metrics.items()
        }

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for serialization."""
        return {
            "n_windows": self.n_windows,
            "strategy": self.strategy,
            "aggregate_metrics": self.aggregate_metrics,
            "series_metrics": {
                k: v.to_dict() for k, v in self.series_metrics.items()
            },
            "segment_metrics": {
                k: v.to_dict() for k, v in self.segment_metrics.items()
            },
            "temporal_metrics": {
                k: v.to_dict() for k, v in self.temporal_metrics.items()
            },
            "window_results": [w.to_dict() for w in self.window_results],
            "errors": self.errors,
            "metadata": self.metadata,
            "cv_frame": (
                self._cv_dataframe().to_dict("records")
                if self._cv_dataframe() is not None
                else None
            ),
        }

    def summary(self) -> str:
        """Generate a human-readable summary.

        Returns:
            Summary string with key metrics
        """
        lines = [
            f"Backtest Report: {self.n_windows} windows ({self.strategy})",
            "=" * 50,
        ]

        # Aggregate metrics
        lines.append("\nAggregate Metrics:")
        for name, value in sorted(self.aggregate_metrics.items()):
            lines.append(f"  {name}: {value:.4f}")

        # Segment metrics
        if self.segment_metrics:
            lines.append("\nSegment Metrics:")
            for name, sm in sorted(self.segment_metrics.items()):
                metrics_str = ", ".join(
                    f"{k}={v:.4f}" for k, v in sorted(sm.metrics.items())[:2]
                )
                lines.append(f"  {name} ({sm.n_series} series): {metrics_str}")

        # Temporal metrics
        if self.temporal_metrics:
            lines.append("\nTemporal Metrics:")
            for dim, tm in sorted(self.temporal_metrics.items()):
                lines.append(f"  {dim}: {len(tm.metrics_by_value)} unique values")

        # Best/worst series
        best = self.get_best_series("wape")
        worst = self.get_worst_series("wape")
        if best and worst:
            lines.append(f"\nBest Series: {best}")
            lines.append(f"Worst Series: {worst}")

        # Errors
        if self.errors:
            lines.append(f"\nErrors: {len(self.errors)}")

        return "\n".join(lines)