tscli-darts 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tscli/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ """tscli package."""
2
+
3
+ __all__ = ["__version__"]
4
+
5
+ __version__ = "0.1.0"
tscli/__main__.py ADDED
@@ -0,0 +1,5 @@
1
"""Module entry point so the CLI can be launched with ``python -m tscli``."""
from tscli.main import main


if __name__ == "__main__":
    main()
tscli/analysis.py ADDED
@@ -0,0 +1,63 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import pandas as pd
6
+
7
+ from tscli.data import LoadedSeries
8
+
9
+
10
@dataclass
class SeriesSummary:
    """Descriptive statistics computed from a loaded time series."""

    # Total number of rows, including rows with a missing target value.
    row_count: int
    # First and last time values (or synthetic index), rendered as strings.
    start: str
    end: str
    # Count of NaN entries in the target column.
    missing_target: int
    # Statistics over the target after dropping missing values.
    mean: float
    median: float
    minimum: float
    maximum: float
    # Population standard deviation (computed with ddof=0).
    std_dev: float
    # Pandas-inferred frequency string, or "not available".
    inferred_frequency: str
    # "upward", "downward", or "flat" from first vs. last clean value.
    trend_direction: str
23
+
24
+
25
def summarize_series(dataset: LoadedSeries) -> SeriesSummary:
    """Compute descriptive statistics for a loaded series.

    Parameters
    ----------
    dataset:
        Loaded series whose ``frame`` holds the time and target columns.

    Returns
    -------
    SeriesSummary
        Row counts, basic target statistics, the inferred frequency, and a
        coarse trend direction.

    Raises
    ------
    ValueError
        If the target column is empty after dropping missing values.
    """
    frame = dataset.frame.copy()
    target = frame[dataset.target_col]

    inferred_frequency = "not available"
    # pd.infer_freq raises ValueError when given fewer than 3 timestamps, so
    # only attempt inference on frames long enough to support it.
    if dataset.time_col != "__index__" and len(frame) >= 3:
        inferred = pd.infer_freq(frame[dataset.time_col])
        if inferred:
            inferred_frequency = inferred

    clean_target = target.dropna()
    if clean_target.empty:
        raise ValueError("The target series is empty after dropping missing values.")

    # Coarse trend: compare the first and last non-missing observations.
    trend_delta = clean_target.iloc[-1] - clean_target.iloc[0]
    if trend_delta > 0:
        trend_direction = "upward"
    elif trend_delta < 0:
        trend_direction = "downward"
    else:
        trend_direction = "flat"

    return SeriesSummary(
        row_count=len(frame),
        start=str(frame[dataset.time_col].iloc[0]),
        end=str(frame[dataset.time_col].iloc[-1]),
        missing_target=int(target.isna().sum()),
        mean=float(clean_target.mean()),
        median=float(clean_target.median()),
        minimum=float(clean_target.min()),
        maximum=float(clean_target.max()),
        # ddof=0: population standard deviation.
        std_dev=float(clean_target.std(ddof=0)),
        inferred_frequency=inferred_frequency,
        trend_direction=trend_direction,
    )
60
+
61
+
62
def recent_observations(dataset: LoadedSeries, rows: int = 5) -> pd.DataFrame:
    """Return a copy of the last ``rows`` time/target observations."""
    wanted = [dataset.time_col, dataset.target_col]
    tail_view = dataset.frame[wanted].tail(rows)
    return tail_view.copy()
tscli/data.py ADDED
@@ -0,0 +1,64 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+
6
+ import pandas as pd
7
+
8
+ from tscli.preprocessing import (
9
+ PreprocessingReport,
10
+ clean_numeric_column,
11
+ finalize_time_series,
12
+ normalize_columns,
13
+ parse_time_column,
14
+ )
15
+
16
+
17
@dataclass
class LoadedSeries:
    """A CSV time series after normalization and cleaning."""

    # Path of the CSV file the frame was loaded from.
    source: Path
    # Cleaned dataframe containing at least the time and target columns.
    frame: pd.DataFrame
    # Name of the time column, or "__index__" for a synthetic integer index.
    time_col: str
    # Name of the numeric target column.
    target_col: str
    # Record of the preprocessing fixes applied while loading.
    report: PreprocessingReport
24
+
25
+
26
def load_csv(csv_path: Path, time_col: str | None, target_col: str) -> LoadedSeries:
    """Load a CSV file and preprocess it into a :class:`LoadedSeries`.

    When ``time_col`` is None, a handful of conventional column names are
    tried; if none is present, a synthetic integer index is created.
    Raises ValueError when the target or time column is missing, cannot be
    parsed, or holds no numeric values.
    """
    report = PreprocessingReport()
    frame = normalize_columns(pd.read_csv(csv_path), report)
    if target_col not in frame.columns:
        raise ValueError(f"Target column '{target_col}' was not found in the CSV.")

    resolved_time_col = time_col
    if resolved_time_col is None:
        # Auto-detect a time column from common naming conventions.
        candidates = ("date", "datetime", "timestamp", "ds", "time")
        resolved_time_col = next((name for name in candidates if name in frame.columns), None)

    if resolved_time_col is None:
        # No usable time column: fall back to a synthetic running index.
        resolved_time_col = "__index__"
        frame[resolved_time_col] = pd.RangeIndex(start=0, stop=len(frame), step=1)
        report.add_fix("Created a synthetic integer time index because no time column was provided.")
    else:
        if resolved_time_col not in frame.columns:
            raise ValueError(f"Time column '{resolved_time_col}' was not found in the CSV.")
        frame = parse_time_column(frame, resolved_time_col, report)
        if frame[resolved_time_col].isna().any():
            raise ValueError(
                f"Time column '{resolved_time_col}' contains values that could not be parsed as datetime."
            )

    frame = clean_numeric_column(frame, target_col, report)
    if frame[target_col].isna().all():
        raise ValueError(f"Target column '{target_col}' does not contain numeric values.")
    frame = finalize_time_series(frame, resolved_time_col, target_col, report)

    return LoadedSeries(
        source=csv_path,
        frame=frame,
        time_col=resolved_time_col,
        target_col=target_col,
        report=report,
    )
tscli/forecasting.py ADDED
@@ -0,0 +1,531 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from importlib import import_module
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ import pandas as pd
11
+ from darts import TimeSeries
12
+ from statsmodels.tsa.arima.model import ARIMA as StatsmodelsARIMA
13
+ from statsmodels.tsa.statespace.sarimax import SARIMAX
14
+ from tqdm import tqdm
15
+
16
+ from tscli.data import LoadedSeries
17
+
18
+
19
@dataclass(frozen=True)
class ModelSpec:
    """Metadata describing one supported forecasting model."""

    # Human-readable one-line description shown in the CLI.
    description: str
    # Implementation family: "built-in" heuristics or "darts-classical".
    family: str
23
+
24
+
25
# Registry of every model name the CLI accepts, mapped to its metadata.
MODEL_SPECS = {
    "naive-last": ModelSpec("Repeats the last observed value.", "built-in"),
    "naive-drift": ModelSpec("Extends the line from the first to the last observation.", "built-in"),
    "naive-seasonal": ModelSpec("Repeats the last seasonal pattern.", "built-in"),
    "moving-average": ModelSpec("Forecasts with the mean of the latest seasonal window.", "built-in"),
    "weighted-moving-average": ModelSpec(
        "Forecasts with a linearly weighted average of the latest seasonal window.", "built-in"
    ),
    "exp-smoothing": ModelSpec("Forecasts with an exponentially weighted moving average level.", "built-in"),
    "seasonal-average": ModelSpec(
        "Forecasts each seasonal position with the average of past matching positions.", "built-in"
    ),
    "seasonal-median": ModelSpec(
        "Forecasts each seasonal position with the median of past matching positions.", "built-in"
    ),
    "linear-trend": ModelSpec("Fits a straight trend line across the series.", "built-in"),
    "quadratic-trend": ModelSpec("Fits a quadratic trend curve across the series.", "built-in"),
    "arima": ModelSpec("DARTS ARIMA model for classical forecasting.", "darts-classical"),
    "theta": ModelSpec("DARTS Theta model for classical univariate forecasting.", "darts-classical"),
    "exponential-smoothing": ModelSpec(
        "DARTS ExponentialSmoothing model for level, trend, and seasonality.", "darts-classical"
    ),
    "auto-arima": ModelSpec("DARTS AutoARIMA model with automatic order selection.", "darts-classical"),
    "sarima": ModelSpec("DARTS ARIMA model configured with seasonal ARIMA defaults.", "darts-classical"),
}

# Flat name -> description mapping derived from the registry (used for
# validation and help text).
SUPPORTED_MODELS = {name: spec.description for name, spec in MODEL_SPECS.items()}
52
+
53
+
54
@dataclass
class ForecastResult:
    """A single model's forward forecast."""

    # Name of the model that produced the forecast.
    model_name: str
    # Future time/target rows, one per forecast step.
    forecast_frame: pd.DataFrame
58
+
59
+
60
@dataclass
class EvaluationResult:
    """Holdout error metrics for one model."""

    model_name: str
    # Mean absolute error.
    mae: float
    # Root mean squared error.
    rmse: float
    # Mean absolute percentage error (NaN when every actual value is zero).
    mape: float
66
+
67
+
68
@dataclass
class BenchmarkResult:
    """Outcome of benchmarking several models on a holdout window."""

    # Per-model metrics, sorted best-first by (rmse, mae, name).
    scores: list[EvaluationResult]
    # The holdout rows the forecasts were scored against.
    actual_frame: pd.DataFrame
    # Forecast frames keyed by model name (successful models only).
    forecasts: dict[str, pd.DataFrame]
    # Name of the model with the best (lowest) RMSE.
    best_model: str
    # Models that failed, mapped to the failure message.
    skipped_models: dict[str, str]
75
+
76
+
77
def build_series(dataset: LoadedSeries) -> TimeSeries:
    """Convert the loaded dataframe into a DARTS ``TimeSeries``."""
    columns = [dataset.time_col, dataset.target_col]
    frame = dataset.frame[columns].dropna().copy()

    kwargs: dict[str, Any] = {
        "time_col": dataset.time_col,
        "value_cols": dataset.target_col,
    }
    if dataset.time_col == "__index__":
        # Rebuild a contiguous integer index; dropna() may have left gaps.
        frame[dataset.time_col] = pd.RangeIndex(start=0, stop=len(frame), step=1)
    else:
        kwargs["fill_missing_dates"] = False
    return TimeSeries.from_dataframe(frame, **kwargs)
94
+
95
+
96
+ def _future_index(frame: pd.DataFrame, time_col: str, horizon: int) -> pd.Index:
97
+ if time_col == "__index__":
98
+ start = int(frame[time_col].iloc[-1]) + 1
99
+ return pd.RangeIndex(start=start, stop=start + horizon, step=1)
100
+
101
+ inferred = pd.infer_freq(frame[time_col])
102
+ if inferred:
103
+ offset = pd.tseries.frequencies.to_offset(inferred)
104
+ start = frame[time_col].iloc[-1] + offset
105
+ return pd.date_range(start=start, periods=horizon, freq=offset)
106
+
107
+ if len(frame) >= 2:
108
+ delta = frame[time_col].iloc[-1] - frame[time_col].iloc[-2]
109
+ else:
110
+ delta = pd.Timedelta(days=1)
111
+ start = frame[time_col].iloc[-1] + delta
112
+ return pd.Index([start + delta * step for step in range(horizon)])
113
+
114
+
115
+ def _coerce_output_frame(frame: pd.DataFrame, time_col: str) -> pd.DataFrame:
116
+ export = frame.copy()
117
+ if time_col != "__index__" and pd.api.types.is_datetime64_any_dtype(export[time_col]):
118
+ export[time_col] = export[time_col].dt.strftime("%Y-%m-%d")
119
+ return export
120
+
121
+
122
def _dataset_from_frame(dataset: LoadedSeries, frame: pd.DataFrame) -> LoadedSeries:
    """Clone ``dataset`` with ``frame`` (reindexed from 0) as its data."""
    return LoadedSeries(
        source=dataset.source,
        frame=frame.reset_index(drop=True),
        time_col=dataset.time_col,
        target_col=dataset.target_col,
        report=dataset.report,
    )
130
+
131
+
132
+ def _load_optional_model(model_name: str) -> Any:
133
+ loaders = {
134
+ "arima": ("darts.models.forecasting.arima", "ARIMA"),
135
+ "theta": ("darts.models.forecasting.theta", "Theta"),
136
+ "exponential-smoothing": ("darts.models.forecasting.exponential_smoothing", "ExponentialSmoothing"),
137
+ "auto-arima": ("darts.models.forecasting.sf_auto_arima", "AutoARIMA"),
138
+ "sarima": ("darts.models.forecasting.arima", "ARIMA"),
139
+ }
140
+ module_name, class_name = loaders[model_name]
141
+ try:
142
+ module = import_module(module_name)
143
+ return getattr(module, class_name)
144
+ except Exception as exc: # pragma: no cover - environment dependent
145
+ extra_hint = "Install tscli with the right optional extra."
146
+ if model_name in {"arima", "theta", "exponential-smoothing", "sarima"}:
147
+ extra_hint = "Install tscli with the 'classical' extra: pip install -e .[classical]"
148
+ elif model_name == "auto-arima":
149
+ extra_hint = (
150
+ "Install tscli with the 'autoarima' extra, or 'full' for everything: "
151
+ "pip install -e .[autoarima]"
152
+ )
153
+ raise ValueError(
154
+ f"Model '{model_name}' is unavailable in this environment. {extra_hint}. "
155
+ f"Original error: {exc}"
156
+ ) from exc
157
+
158
+
159
+ def _heuristic_forecast_values(
160
+ history: np.ndarray,
161
+ model_name: str,
162
+ horizon: int,
163
+ seasonal_period: int,
164
+ ) -> np.ndarray:
165
+ if len(history) == 0:
166
+ raise ValueError("The target series is empty after dropping missing values.")
167
+
168
+ if model_name == "naive-last":
169
+ return np.repeat(history[-1], horizon)
170
+
171
+ if model_name == "naive-drift":
172
+ if len(history) == 1:
173
+ return np.repeat(history[-1], horizon)
174
+ slope = (history[-1] - history[0]) / (len(history) - 1)
175
+ return np.array([history[-1] + slope * step for step in range(1, horizon + 1)])
176
+
177
+ if model_name == "naive-seasonal":
178
+ if len(history) < seasonal_period:
179
+ raise ValueError(
180
+ f"Naive seasonal needs at least {seasonal_period} observations, but found {len(history)}."
181
+ )
182
+ pattern = history[-seasonal_period:]
183
+ return np.array([pattern[step % seasonal_period] for step in range(horizon)])
184
+
185
+ if model_name == "moving-average":
186
+ window = min(seasonal_period, len(history))
187
+ average = float(np.mean(history[-window:]))
188
+ return np.repeat(average, horizon)
189
+
190
+ if model_name == "weighted-moving-average":
191
+ window = min(seasonal_period, len(history))
192
+ weights = np.arange(1, window + 1, dtype=float)
193
+ average = float(np.average(history[-window:], weights=weights))
194
+ return np.repeat(average, horizon)
195
+
196
+ if model_name == "exp-smoothing":
197
+ span = max(2, min(seasonal_period, len(history)))
198
+ level = float(pd.Series(history).ewm(span=span, adjust=False).mean().iloc[-1])
199
+ return np.repeat(level, horizon)
200
+
201
+ if model_name == "seasonal-average":
202
+ if len(history) < seasonal_period:
203
+ raise ValueError(
204
+ f"Seasonal average needs at least {seasonal_period} observations, but found {len(history)}."
205
+ )
206
+ values = []
207
+ for step in range(horizon):
208
+ position = step % seasonal_period
209
+ seasonal_slice = history[position::seasonal_period]
210
+ values.append(float(np.mean(seasonal_slice)))
211
+ return np.array(values)
212
+
213
+ if model_name == "seasonal-median":
214
+ if len(history) < seasonal_period:
215
+ raise ValueError(
216
+ f"Seasonal median needs at least {seasonal_period} observations, but found {len(history)}."
217
+ )
218
+ values = []
219
+ for step in range(horizon):
220
+ position = step % seasonal_period
221
+ seasonal_slice = history[position::seasonal_period]
222
+ values.append(float(np.median(seasonal_slice)))
223
+ return np.array(values)
224
+
225
+ if model_name == "quadratic-trend":
226
+ if len(history) < 3:
227
+ raise ValueError("Quadratic trend needs at least 3 observations.")
228
+ x = np.arange(len(history), dtype=float)
229
+ a, b, c = np.polyfit(x, history.astype(float), 2)
230
+ future_x = np.arange(len(history), len(history) + horizon, dtype=float)
231
+ return a * np.square(future_x) + b * future_x + c
232
+
233
+ x = np.arange(len(history), dtype=float)
234
+ slope, intercept = np.polyfit(x, history.astype(float), 1)
235
+ future_x = np.arange(len(history), len(history) + horizon, dtype=float)
236
+ return intercept + slope * future_x
237
+
238
+
239
def _darts_classical_forecast(
    dataset: LoadedSeries,
    model_name: str,
    horizon: int,
    seasonal_period: int,
) -> pd.DataFrame:
    """Fit a DARTS classical model and return its forecast as a dataframe.

    Raises ValueError when the model cannot be imported, fails to fit or
    predict, or its output cannot be converted to a table.
    """
    series = build_series(dataset)
    model_class = _load_optional_model(model_name)

    # Configure each model; seasonal settings are only enabled when the
    # series covers at least two full seasons.
    if model_name == "theta":
        model = model_class()
    elif model_name == "exponential-smoothing":
        kwargs: dict[str, Any] = {}
        if len(series) >= seasonal_period * 2:
            kwargs["seasonal_periods"] = seasonal_period
        model = model_class(**kwargs)
    elif model_name == "auto-arima":
        kwargs = {}
        if len(series) >= seasonal_period * 2:
            kwargs["season_length"] = seasonal_period
        model = model_class(**kwargs)
    elif model_name == "sarima":
        kwargs = {"p": 1, "d": 1, "q": 1}
        if len(series) >= seasonal_period * 2:
            kwargs["seasonal_order"] = (1, 1, 1, seasonal_period)
        model = model_class(**kwargs)
    else:
        # Plain ARIMA with fixed (1, 1, 1) order.
        model = model_class(p=1, d=1, q=1)

    try:
        model.fit(series)
        forecast = model.predict(horizon)
    except Exception as exc:  # pragma: no cover - environment dependent
        raise ValueError(f"Model '{model_name}' failed to fit or predict. Original error: {exc}") from exc

    # Different DARTS versions expose different conversion methods; try the
    # known ones in order of preference.
    try:
        if hasattr(forecast, "pd_dataframe"):
            forecast_frame = forecast.pd_dataframe().reset_index()
        elif hasattr(forecast, "to_dataframe"):
            forecast_frame = forecast.to_dataframe().reset_index()
        elif hasattr(forecast, "pd_series"):
            forecast_frame = forecast.pd_series().to_frame(name=dataset.target_col).reset_index()
        else:
            raise AttributeError("No supported dataframe conversion method was found on the DARTS TimeSeries result.")
    except Exception as exc:  # pragma: no cover - environment dependent
        raise ValueError(
            f"Model '{model_name}' produced a forecast, but tscli could not convert it to a table. "
            f"Original error: {exc}"
        ) from exc

    # Normalize to the dataset's own column names.
    forecast_frame.columns = [dataset.time_col, dataset.target_col]
    return forecast_frame
291
+
292
+
293
def _statsmodels_fallback_forecast(
    dataset: LoadedSeries,
    model_name: str,
    horizon: int,
    seasonal_period: int,
) -> pd.DataFrame:
    """Forecast with statsmodels when the DARTS ARIMA/SARIMA models are unavailable."""
    frame = dataset.frame[[dataset.time_col, dataset.target_col]].dropna().copy()
    history = frame[dataset.target_col].astype(float).to_numpy()

    try:
        if model_name == "arima":
            fitted = StatsmodelsARIMA(history, order=(1, 1, 1)).fit()
        else:
            # Small datasets often cannot support a full seasonal
            # parameterization; drop the seasonal AR term in that case.
            too_short = len(history) < max(24, seasonal_period * 2 + 6)
            seasonal_order = (0, 1, 1, seasonal_period) if too_short else (1, 1, 1, seasonal_period)
            fitted = SARIMAX(
                history,
                order=(1, 1, 1),
                seasonal_order=seasonal_order,
                enforce_stationarity=False,
                enforce_invertibility=False,
            ).fit(disp=False)
        forecast_values = np.asarray(fitted.forecast(steps=horizon), dtype=float)
    except Exception as exc:
        raise ValueError(
            f"Model '{model_name}' is unavailable through DARTS and the fallback implementation also failed. "
            f"Original error: {exc}"
        ) from exc

    future = _future_index(frame, dataset.time_col, horizon)
    return pd.DataFrame({dataset.time_col: future, dataset.target_col: forecast_values})
329
+
330
+
331
def _statsmodels_classical_forecast(
    dataset: LoadedSeries,
    model_name: str,
    horizon: int,
    seasonal_period: int,
) -> pd.DataFrame:
    """Fit a statsmodels ARIMA/SARIMAX model directly and forecast.

    NOTE(review): this is a near-verbatim duplicate of
    ``_statsmodels_fallback_forecast``; only the failure message differs.
    Consider consolidating the two into one helper with a configurable
    error message.
    """
    frame = dataset.frame[[dataset.time_col, dataset.target_col]].dropna().copy()
    history = frame[dataset.target_col].astype(float).to_numpy()

    try:
        if model_name == "arima":
            fitted = StatsmodelsARIMA(history, order=(1, 1, 1)).fit()
        else:
            seasonal_order = (1, 1, 1, seasonal_period)
            # Small datasets often cannot support a full seasonal parameterization.
            if len(history) < max(24, seasonal_period * 2 + 6):
                seasonal_order = (0, 1, 1, seasonal_period)
            fitted = SARIMAX(
                history,
                order=(1, 1, 1),
                seasonal_order=seasonal_order,
                enforce_stationarity=False,
                enforce_invertibility=False,
            ).fit(disp=False)
        forecast_values = np.asarray(fitted.forecast(steps=horizon), dtype=float)
    except Exception as exc:
        raise ValueError(f"Model '{model_name}' failed to fit or predict. Original error: {exc}") from exc

    forecast_frame = pd.DataFrame(
        {
            dataset.time_col: _future_index(frame, dataset.time_col, horizon),
            dataset.target_col: forecast_values,
        }
    )
    return forecast_frame
366
+
367
+
368
def generate_forecast(
    dataset: LoadedSeries,
    model_name: str,
    horizon: int,
    seasonal_period: int = 12,
) -> ForecastResult:
    """Produce a ``horizon``-step forecast with the requested model.

    Built-in heuristics run locally; classical models go through DARTS,
    with a statsmodels fallback for ARIMA/SARIMA when DARTS is missing.
    Raises ValueError for unknown model names or fit/predict failures.
    """
    if model_name not in SUPPORTED_MODELS:
        allowed = ", ".join(sorted(SUPPORTED_MODELS))
        raise ValueError(f"Unsupported model '{model_name}'. Choose from: {allowed}.")

    frame = dataset.frame[[dataset.time_col, dataset.target_col]].dropna().copy()

    if MODEL_SPECS[model_name].family == "built-in":
        series = build_series(dataset)
        history = series.values(copy=False).reshape(-1)
        predictions = _heuristic_forecast_values(history, model_name, horizon, seasonal_period)
        forecast_frame = pd.DataFrame(
            {
                dataset.time_col: _future_index(frame, dataset.time_col, horizon),
                dataset.target_col: predictions,
            }
        )
        return ForecastResult(model_name=model_name, forecast_frame=forecast_frame)

    try:
        forecast_frame = _darts_classical_forecast(dataset, model_name, horizon, seasonal_period)
    except ValueError as exc:
        can_fall_back = model_name in {"arima", "sarima"} and "unavailable in this environment" in str(exc)
        if not can_fall_back:
            raise
        forecast_frame = _statsmodels_fallback_forecast(dataset, model_name, horizon, seasonal_period)
    return ForecastResult(model_name=model_name, forecast_frame=forecast_frame)
401
+
402
+
403
def evaluate_forecast(actual: np.ndarray, predicted: np.ndarray, model_name: str) -> EvaluationResult:
    """Score ``predicted`` against ``actual`` with MAE, RMSE, and MAPE."""
    actual = actual.astype(float)
    predicted = predicted.astype(float)
    errors = actual - predicted

    mae = float(np.abs(errors).mean())
    rmse = float(np.sqrt(np.square(errors).mean()))

    # MAPE is undefined where the actual value is zero; skip those points,
    # and report NaN when no nonzero actuals remain.
    mask = actual != 0
    if mask.any():
        mape = float(np.mean(np.abs(errors[mask] / actual[mask])) * 100)
    else:
        mape = float("nan")

    return EvaluationResult(model_name=model_name, mae=mae, rmse=rmse, mape=mape)
417
+
418
+
419
def benchmark_models(
    dataset: LoadedSeries,
    model_names: list[str],
    horizon: int,
    seasonal_period: int = 12,
) -> BenchmarkResult:
    """Backtest each model on the last ``horizon`` rows and rank by RMSE.

    Models that raise are recorded in ``skipped_models`` instead of
    aborting the run; a ValueError is raised only when the data is too
    short or every model fails.
    """
    clean_frame = dataset.frame[[dataset.time_col, dataset.target_col]].dropna().reset_index(drop=True)
    if len(clean_frame) <= horizon:
        raise ValueError(
            f"Need more than {horizon} observations to benchmark models, but found {len(clean_frame)}."
        )

    # Hold out the final `horizon` rows as the evaluation window.
    train_dataset = _dataset_from_frame(dataset, clean_frame.iloc[:-horizon].copy())
    actual_frame = clean_frame.iloc[-horizon:].copy().reset_index(drop=True)
    actual = actual_frame[dataset.target_col].to_numpy(dtype=float)

    scores: list[EvaluationResult] = []
    forecasts: dict[str, pd.DataFrame] = {}
    skipped_models: dict[str, str] = {}
    for model_name in tqdm(model_names, desc="Benchmarking models", unit="model"):
        try:
            result = generate_forecast(
                train_dataset,
                model_name=model_name,
                horizon=horizon,
                seasonal_period=seasonal_period,
            )
            predicted = result.forecast_frame[dataset.target_col].to_numpy(dtype=float)
            scores.append(evaluate_forecast(actual, predicted, model_name))
            forecasts[model_name] = result.forecast_frame
        except Exception as exc:
            # A failing model is skipped, not fatal.
            skipped_models[model_name] = str(exc)

    if not scores:
        raise ValueError("No models could be benchmarked successfully with the current data and settings.")

    # Best first: lowest RMSE, ties broken by MAE then name.
    scores.sort(key=lambda item: (item.rmse, item.mae, item.model_name))
    return BenchmarkResult(
        scores=scores,
        actual_frame=actual_frame,
        forecasts=forecasts,
        best_model=scores[0].model_name,
        skipped_models=skipped_models,
    )
464
+
465
+
466
def export_frame(frame: pd.DataFrame, output_path: Path, time_col: str) -> None:
    """Write ``frame`` to ``output_path`` as CSV, rendering datetimes as dates."""
    output_path.parent.mkdir(parents=True, exist_ok=True)
    _coerce_output_frame(frame, time_col).to_csv(output_path, index=False)
470
+
471
+
472
def export_scores(scores: list[EvaluationResult], output_path: Path) -> None:
    """Write benchmark scores to ``output_path`` as a CSV table."""
    output_path.parent.mkdir(parents=True, exist_ok=True)
    rows = [
        {
            "model": score.model_name,
            "mae": score.mae,
            "rmse": score.rmse,
            "mape": score.mape,
        }
        for score in scores
    ]
    pd.DataFrame(rows).to_csv(output_path, index=False)
486
+
487
+
488
def export_forecast_plot(
    history_frame: pd.DataFrame,
    forecast_frame: pd.DataFrame,
    time_col: str,
    target_col: str,
    output_path: Path,
    model_name: str,
) -> None:
    """Render history and forecast as a labelled PNG at ``output_path``."""
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig, ax = plt.subplots(figsize=(10, 5))
    layers = (
        (history_frame, "history"),
        (forecast_frame, f"forecast ({model_name})"),
    )
    for data, label in layers:
        ax.plot(data[time_col], data[target_col], label=label, linewidth=2)
    ax.set_title(f"Forecast - {model_name}")
    ax.set_xlabel(time_col)
    ax.set_ylabel(target_col)
    ax.legend()
    ax.grid(alpha=0.3)
    # Slant date labels so long timestamps stay readable.
    fig.autofmt_xdate()
    fig.tight_layout()
    fig.savefig(output_path, dpi=160)
    plt.close(fig)
509
+
510
+
511
def export_benchmark_plot(
    actual_frame: pd.DataFrame,
    forecast_frame: pd.DataFrame,
    time_col: str,
    target_col: str,
    output_path: Path,
    model_name: str,
) -> None:
    """Render holdout actuals against predictions as a PNG at ``output_path``."""
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig, ax = plt.subplots(figsize=(10, 5))
    layers = (
        (actual_frame, "actual"),
        (forecast_frame, f"predicted ({model_name})"),
    )
    for data, label in layers:
        ax.plot(data[time_col], data[target_col], label=label, linewidth=2)
    ax.set_title(f"Benchmark Holdout - {model_name}")
    ax.set_xlabel(time_col)
    ax.set_ylabel(target_col)
    ax.legend()
    ax.grid(alpha=0.3)
    # Slant date labels so long timestamps stay readable.
    fig.autofmt_xdate()
    fig.tight_layout()
    fig.savefig(output_path, dpi=160)
    plt.close(fig)