tscli-darts 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tscli/__init__.py +5 -0
- tscli/__main__.py +5 -0
- tscli/analysis.py +63 -0
- tscli/data.py +64 -0
- tscli/forecasting.py +531 -0
- tscli/main.py +422 -0
- tscli/preprocessing.py +113 -0
- tscli_darts-0.1.0.dist-info/METADATA +204 -0
- tscli_darts-0.1.0.dist-info/RECORD +12 -0
- tscli_darts-0.1.0.dist-info/WHEEL +5 -0
- tscli_darts-0.1.0.dist-info/entry_points.txt +2 -0
- tscli_darts-0.1.0.dist-info/top_level.txt +1 -0
tscli/__init__.py
ADDED
tscli/__main__.py
ADDED
tscli/analysis.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
import pandas as pd
|
|
6
|
+
|
|
7
|
+
from tscli.data import LoadedSeries
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class SeriesSummary:
    """Descriptive statistics for one loaded series, as built by ``summarize_series``."""

    row_count: int  # rows in the frame, including rows whose target is missing
    start: str  # first time value, rendered with str()
    end: str  # last time value, rendered with str()
    missing_target: int  # number of missing (NaN) target values
    mean: float
    median: float
    minimum: float
    maximum: float
    std_dev: float  # population standard deviation (ddof=0)
    inferred_frequency: str  # pandas frequency string, or "not available"
    trend_direction: str  # "upward", "downward" or "flat"


def summarize_series(dataset: LoadedSeries) -> SeriesSummary:
    """Summarize the target column of *dataset*.

    Statistics are computed over the non-missing target values, while
    ``row_count`` and ``missing_target`` describe the raw frame. The trend
    direction only compares the first and last non-missing observations.

    Raises:
        ValueError: If the target has no non-missing values.
    """
    frame = dataset.frame
    target = frame[dataset.target_col]
    times = frame[dataset.time_col]

    inferred_frequency = "not available"
    # pd.infer_freq raises ValueError for fewer than three timestamps (and
    # TypeError for non-datetime input). A short series is still worth
    # summarizing; it simply has no inferable frequency.
    if dataset.time_col != "__index__" and len(times) >= 3:
        try:
            inferred = pd.infer_freq(times)
        except (TypeError, ValueError):
            inferred = None
        if inferred:
            inferred_frequency = inferred

    clean_target = target.dropna()
    if clean_target.empty:
        raise ValueError("The target series is empty after dropping missing values.")

    trend_delta = clean_target.iloc[-1] - clean_target.iloc[0]
    if trend_delta > 0:
        trend_direction = "upward"
    elif trend_delta < 0:
        trend_direction = "downward"
    else:
        trend_direction = "flat"

    return SeriesSummary(
        row_count=len(frame),
        start=str(times.iloc[0]),
        end=str(times.iloc[-1]),
        missing_target=int(target.isna().sum()),
        mean=float(clean_target.mean()),
        median=float(clean_target.median()),
        minimum=float(clean_target.min()),
        maximum=float(clean_target.max()),
        std_dev=float(clean_target.std(ddof=0)),
        inferred_frequency=inferred_frequency,
        trend_direction=trend_direction,
    )
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def recent_observations(dataset: LoadedSeries, rows: int = 5) -> pd.DataFrame:
    """Return a copy of the last *rows* rows of the time and target columns."""
    selected_columns = [dataset.time_col, dataset.target_col]
    latest = dataset.frame[selected_columns].tail(rows)
    return latest.copy()
|
tscli/data.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import pandas as pd
|
|
7
|
+
|
|
8
|
+
from tscli.preprocessing import (
|
|
9
|
+
PreprocessingReport,
|
|
10
|
+
clean_numeric_column,
|
|
11
|
+
finalize_time_series,
|
|
12
|
+
normalize_columns,
|
|
13
|
+
parse_time_column,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class LoadedSeries:
    """A time series read from a CSV file together with its loading metadata."""

    source: Path  # path the CSV was read from
    frame: pd.DataFrame  # cleaned data holding the time and target columns
    time_col: str  # time column name, or "__index__" for a synthetic integer index
    target_col: str  # name of the numeric target column
    report: PreprocessingReport  # record of the fixes applied while loading
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def load_csv(csv_path: Path, time_col: str | None, target_col: str) -> LoadedSeries:
    """Read *csv_path* and turn it into a cleaned ``LoadedSeries``.

    Column names go through ``normalize_columns`` before *target_col* and
    *time_col* are looked up. When *time_col* is ``None``, the first of
    ``date``, ``datetime``, ``timestamp``, ``ds`` or ``time`` present in the
    file is used. If none is present, a synthetic integer column named
    ``__index__`` is added instead.

    Raises:
        ValueError: If the target or requested time column is missing, the
            time column holds unparseable values, or the target has no
            numeric values.
    """
    raw = pd.read_csv(csv_path)
    report = PreprocessingReport()
    frame = normalize_columns(raw, report)
    if target_col not in frame.columns:
        raise ValueError(f"Target column '{target_col}' was not found in the CSV.")

    resolved_time_col = time_col
    if resolved_time_col is None:
        candidates = ("date", "datetime", "timestamp", "ds", "time")
        resolved_time_col = next((name for name in candidates if name in frame.columns), None)

    if resolved_time_col is None:
        resolved_time_col = "__index__"
        frame[resolved_time_col] = pd.RangeIndex(start=0, stop=len(frame), step=1)
        report.add_fix("Created a synthetic integer time index because no time column was provided.")
    else:
        if resolved_time_col not in frame.columns:
            raise ValueError(f"Time column '{resolved_time_col}' was not found in the CSV.")
        frame = parse_time_column(frame, resolved_time_col, report)
        has_unparsed = frame[resolved_time_col].isna().any()
        if has_unparsed:
            raise ValueError(
                f"Time column '{resolved_time_col}' contains values that could not be parsed as datetime."
            )

    frame = clean_numeric_column(frame, target_col, report)
    if frame[target_col].isna().all():
        raise ValueError(f"Target column '{target_col}' does not contain numeric values.")
    frame = finalize_time_series(frame, resolved_time_col, target_col, report)

    return LoadedSeries(
        source=csv_path,
        frame=frame,
        time_col=resolved_time_col,
        target_col=target_col,
        report=report,
    )
|
tscli/forecasting.py
ADDED
|
@@ -0,0 +1,531 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from importlib import import_module
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import matplotlib.pyplot as plt
|
|
9
|
+
import numpy as np
|
|
10
|
+
import pandas as pd
|
|
11
|
+
from darts import TimeSeries
|
|
12
|
+
from statsmodels.tsa.arima.model import ARIMA as StatsmodelsARIMA
|
|
13
|
+
from statsmodels.tsa.statespace.sarimax import SARIMAX
|
|
14
|
+
from tqdm import tqdm
|
|
15
|
+
|
|
16
|
+
from tscli.data import LoadedSeries
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(frozen=True)
class ModelSpec:
    """Static, immutable metadata describing one forecasting model offered by tscli."""

    description: str  # one-line, user-facing summary of the model
    family: str  # "built-in" (local heuristics) or "darts-classical"


# Grouped by family; the insertion order of the resulting dict follows the
# order of the entries below (all built-in models first).
MODEL_SPECS = {
    name: ModelSpec(description, family)
    for family, entries in (
        (
            "built-in",
            (
                ("naive-last", "Repeats the last observed value."),
                ("naive-drift", "Extends the line from the first to the last observation."),
                ("naive-seasonal", "Repeats the last seasonal pattern."),
                ("moving-average", "Forecasts with the mean of the latest seasonal window."),
                (
                    "weighted-moving-average",
                    "Forecasts with a linearly weighted average of the latest seasonal window.",
                ),
                ("exp-smoothing", "Forecasts with an exponentially weighted moving average level."),
                (
                    "seasonal-average",
                    "Forecasts each seasonal position with the average of past matching positions.",
                ),
                (
                    "seasonal-median",
                    "Forecasts each seasonal position with the median of past matching positions.",
                ),
                ("linear-trend", "Fits a straight trend line across the series."),
                ("quadratic-trend", "Fits a quadratic trend curve across the series."),
            ),
        ),
        (
            "darts-classical",
            (
                ("arima", "DARTS ARIMA model for classical forecasting."),
                ("theta", "DARTS Theta model for classical univariate forecasting."),
                (
                    "exponential-smoothing",
                    "DARTS ExponentialSmoothing model for level, trend, and seasonality.",
                ),
                ("auto-arima", "DARTS AutoARIMA model with automatic order selection."),
                ("sarima", "DARTS ARIMA model configured with seasonal ARIMA defaults."),
            ),
        ),
    )
    for name, description in entries
}

# Model name -> user-facing description, in the same order as MODEL_SPECS.
SUPPORTED_MODELS = {name: MODEL_SPECS[name].description for name in MODEL_SPECS}
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@dataclass
class ForecastResult:
    """The forecast table produced by a single model."""

    model_name: str  # name of the model that produced the forecast
    forecast_frame: pd.DataFrame  # future time values alongside forecast target values
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@dataclass
class EvaluationResult:
    """Holdout accuracy metrics for one model."""

    model_name: str
    mae: float  # mean absolute error
    rmse: float  # root mean squared error
    mape: float  # mean absolute percentage error, in percent
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dataclass
class BenchmarkResult:
    """Outcome of backtesting several models on the same holdout window."""

    scores: list[EvaluationResult]  # one entry per successfully scored model
    actual_frame: pd.DataFrame  # the held-out observations
    forecasts: dict[str, pd.DataFrame]  # model name -> forecast over the holdout
    best_model: str  # name of the top-ranked model
    skipped_models: dict[str, str]  # model name -> error message for models that failed
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def build_series(dataset: LoadedSeries) -> TimeSeries:
    """Build a univariate DARTS ``TimeSeries`` from the dataset's non-missing rows.

    For the synthetic ``__index__`` time column, the remaining rows are
    renumbered 0..n-1 before conversion. Presumably this gives DARTS an evenly
    spaced integer index even after rows with missing values were dropped.
    """
    observed = dataset.frame[[dataset.time_col, dataset.target_col]].dropna().copy()
    if dataset.time_col == "__index__":
        observed[dataset.time_col] = pd.RangeIndex(start=0, stop=len(observed), step=1)

    # fill_missing_dates=False is also the DARTS default, so both kinds of
    # time column can share a single conversion call.
    return TimeSeries.from_dataframe(
        observed,
        time_col=dataset.time_col,
        value_cols=dataset.target_col,
        fill_missing_dates=False,
    )
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _future_index(frame: pd.DataFrame, time_col: str, horizon: int) -> pd.Index:
|
|
97
|
+
if time_col == "__index__":
|
|
98
|
+
start = int(frame[time_col].iloc[-1]) + 1
|
|
99
|
+
return pd.RangeIndex(start=start, stop=start + horizon, step=1)
|
|
100
|
+
|
|
101
|
+
inferred = pd.infer_freq(frame[time_col])
|
|
102
|
+
if inferred:
|
|
103
|
+
offset = pd.tseries.frequencies.to_offset(inferred)
|
|
104
|
+
start = frame[time_col].iloc[-1] + offset
|
|
105
|
+
return pd.date_range(start=start, periods=horizon, freq=offset)
|
|
106
|
+
|
|
107
|
+
if len(frame) >= 2:
|
|
108
|
+
delta = frame[time_col].iloc[-1] - frame[time_col].iloc[-2]
|
|
109
|
+
else:
|
|
110
|
+
delta = pd.Timedelta(days=1)
|
|
111
|
+
start = frame[time_col].iloc[-1] + delta
|
|
112
|
+
return pd.Index([start + delta * step for step in range(horizon)])
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _coerce_output_frame(frame: pd.DataFrame, time_col: str) -> pd.DataFrame:
|
|
116
|
+
export = frame.copy()
|
|
117
|
+
if time_col != "__index__" and pd.api.types.is_datetime64_any_dtype(export[time_col]):
|
|
118
|
+
export[time_col] = export[time_col].dt.strftime("%Y-%m-%d")
|
|
119
|
+
return export
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def _dataset_from_frame(dataset: LoadedSeries, frame: pd.DataFrame) -> LoadedSeries:
    """Return a ``LoadedSeries`` that carries *dataset*'s metadata but holds *frame*.

    *frame* is given a fresh 0..n-1 row index. The source path and column names
    are copied over, and the preprocessing report object is shared, not copied.
    """
    renumbered = frame.reset_index(drop=True)
    return LoadedSeries(
        report=dataset.report,
        target_col=dataset.target_col,
        time_col=dataset.time_col,
        frame=renumbered,
        source=dataset.source,
    )
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _load_optional_model(model_name: str) -> Any:
|
|
133
|
+
loaders = {
|
|
134
|
+
"arima": ("darts.models.forecasting.arima", "ARIMA"),
|
|
135
|
+
"theta": ("darts.models.forecasting.theta", "Theta"),
|
|
136
|
+
"exponential-smoothing": ("darts.models.forecasting.exponential_smoothing", "ExponentialSmoothing"),
|
|
137
|
+
"auto-arima": ("darts.models.forecasting.sf_auto_arima", "AutoARIMA"),
|
|
138
|
+
"sarima": ("darts.models.forecasting.arima", "ARIMA"),
|
|
139
|
+
}
|
|
140
|
+
module_name, class_name = loaders[model_name]
|
|
141
|
+
try:
|
|
142
|
+
module = import_module(module_name)
|
|
143
|
+
return getattr(module, class_name)
|
|
144
|
+
except Exception as exc: # pragma: no cover - environment dependent
|
|
145
|
+
extra_hint = "Install tscli with the right optional extra."
|
|
146
|
+
if model_name in {"arima", "theta", "exponential-smoothing", "sarima"}:
|
|
147
|
+
extra_hint = "Install tscli with the 'classical' extra: pip install -e .[classical]"
|
|
148
|
+
elif model_name == "auto-arima":
|
|
149
|
+
extra_hint = (
|
|
150
|
+
"Install tscli with the 'autoarima' extra, or 'full' for everything: "
|
|
151
|
+
"pip install -e .[autoarima]"
|
|
152
|
+
)
|
|
153
|
+
raise ValueError(
|
|
154
|
+
f"Model '{model_name}' is unavailable in this environment. {extra_hint}. "
|
|
155
|
+
f"Original error: {exc}"
|
|
156
|
+
) from exc
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _heuristic_forecast_values(
|
|
160
|
+
history: np.ndarray,
|
|
161
|
+
model_name: str,
|
|
162
|
+
horizon: int,
|
|
163
|
+
seasonal_period: int,
|
|
164
|
+
) -> np.ndarray:
|
|
165
|
+
if len(history) == 0:
|
|
166
|
+
raise ValueError("The target series is empty after dropping missing values.")
|
|
167
|
+
|
|
168
|
+
if model_name == "naive-last":
|
|
169
|
+
return np.repeat(history[-1], horizon)
|
|
170
|
+
|
|
171
|
+
if model_name == "naive-drift":
|
|
172
|
+
if len(history) == 1:
|
|
173
|
+
return np.repeat(history[-1], horizon)
|
|
174
|
+
slope = (history[-1] - history[0]) / (len(history) - 1)
|
|
175
|
+
return np.array([history[-1] + slope * step for step in range(1, horizon + 1)])
|
|
176
|
+
|
|
177
|
+
if model_name == "naive-seasonal":
|
|
178
|
+
if len(history) < seasonal_period:
|
|
179
|
+
raise ValueError(
|
|
180
|
+
f"Naive seasonal needs at least {seasonal_period} observations, but found {len(history)}."
|
|
181
|
+
)
|
|
182
|
+
pattern = history[-seasonal_period:]
|
|
183
|
+
return np.array([pattern[step % seasonal_period] for step in range(horizon)])
|
|
184
|
+
|
|
185
|
+
if model_name == "moving-average":
|
|
186
|
+
window = min(seasonal_period, len(history))
|
|
187
|
+
average = float(np.mean(history[-window:]))
|
|
188
|
+
return np.repeat(average, horizon)
|
|
189
|
+
|
|
190
|
+
if model_name == "weighted-moving-average":
|
|
191
|
+
window = min(seasonal_period, len(history))
|
|
192
|
+
weights = np.arange(1, window + 1, dtype=float)
|
|
193
|
+
average = float(np.average(history[-window:], weights=weights))
|
|
194
|
+
return np.repeat(average, horizon)
|
|
195
|
+
|
|
196
|
+
if model_name == "exp-smoothing":
|
|
197
|
+
span = max(2, min(seasonal_period, len(history)))
|
|
198
|
+
level = float(pd.Series(history).ewm(span=span, adjust=False).mean().iloc[-1])
|
|
199
|
+
return np.repeat(level, horizon)
|
|
200
|
+
|
|
201
|
+
if model_name == "seasonal-average":
|
|
202
|
+
if len(history) < seasonal_period:
|
|
203
|
+
raise ValueError(
|
|
204
|
+
f"Seasonal average needs at least {seasonal_period} observations, but found {len(history)}."
|
|
205
|
+
)
|
|
206
|
+
values = []
|
|
207
|
+
for step in range(horizon):
|
|
208
|
+
position = step % seasonal_period
|
|
209
|
+
seasonal_slice = history[position::seasonal_period]
|
|
210
|
+
values.append(float(np.mean(seasonal_slice)))
|
|
211
|
+
return np.array(values)
|
|
212
|
+
|
|
213
|
+
if model_name == "seasonal-median":
|
|
214
|
+
if len(history) < seasonal_period:
|
|
215
|
+
raise ValueError(
|
|
216
|
+
f"Seasonal median needs at least {seasonal_period} observations, but found {len(history)}."
|
|
217
|
+
)
|
|
218
|
+
values = []
|
|
219
|
+
for step in range(horizon):
|
|
220
|
+
position = step % seasonal_period
|
|
221
|
+
seasonal_slice = history[position::seasonal_period]
|
|
222
|
+
values.append(float(np.median(seasonal_slice)))
|
|
223
|
+
return np.array(values)
|
|
224
|
+
|
|
225
|
+
if model_name == "quadratic-trend":
|
|
226
|
+
if len(history) < 3:
|
|
227
|
+
raise ValueError("Quadratic trend needs at least 3 observations.")
|
|
228
|
+
x = np.arange(len(history), dtype=float)
|
|
229
|
+
a, b, c = np.polyfit(x, history.astype(float), 2)
|
|
230
|
+
future_x = np.arange(len(history), len(history) + horizon, dtype=float)
|
|
231
|
+
return a * np.square(future_x) + b * future_x + c
|
|
232
|
+
|
|
233
|
+
x = np.arange(len(history), dtype=float)
|
|
234
|
+
slope, intercept = np.polyfit(x, history.astype(float), 1)
|
|
235
|
+
future_x = np.arange(len(history), len(history) + horizon, dtype=float)
|
|
236
|
+
return intercept + slope * future_x
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def _darts_classical_forecast(
    dataset: LoadedSeries,
    model_name: str,
    horizon: int,
    seasonal_period: int,
) -> pd.DataFrame:
    """Fit the DARTS model behind *model_name* and forecast *horizon* steps.

    Returns:
        A two-column frame named after the dataset's time and target columns.

    Raises:
        ValueError: If the model cannot be loaded (see ``_load_optional_model``),
            fails to fit or predict, or its output cannot be converted to a table.
    """
    series = build_series(dataset)
    model_class = _load_optional_model(model_name)

    # Seasonal terms need a real season of at least two steps (statsmodels
    # rejects a seasonal period of 1 for SARIMA / Holt-Winters). They are also
    # only requested when the series covers at least two full seasons.
    use_season = seasonal_period >= 2 and len(series) >= seasonal_period * 2

    kwargs: dict[str, Any] = {}
    if model_name == "exponential-smoothing":
        if use_season:
            kwargs["seasonal_periods"] = seasonal_period
    elif model_name == "auto-arima":
        if use_season:
            kwargs["season_length"] = seasonal_period
    elif model_name == "sarima":
        kwargs.update(p=1, d=1, q=1)
        if use_season:
            kwargs["seasonal_order"] = (1, 1, 1, seasonal_period)
    elif model_name != "theta":
        # Plain ARIMA(1, 1, 1); Theta runs with its defaults.
        kwargs.update(p=1, d=1, q=1)
    model = model_class(**kwargs)

    try:
        model.fit(series)
        forecast = model.predict(horizon)
    except Exception as exc:  # pragma: no cover - environment dependent
        raise ValueError(f"Model '{model_name}' failed to fit or predict. Original error: {exc}") from exc

    try:
        if hasattr(forecast, "pd_dataframe"):
            forecast_frame = forecast.pd_dataframe().reset_index()
        elif hasattr(forecast, "to_dataframe"):
            forecast_frame = forecast.to_dataframe().reset_index()
        elif hasattr(forecast, "pd_series"):
            forecast_frame = forecast.pd_series().to_frame(name=dataset.target_col).reset_index()
        else:
            raise AttributeError("No supported dataframe conversion method was found on the DARTS TimeSeries result.")
    except Exception as exc:  # pragma: no cover - environment dependent
        raise ValueError(
            f"Model '{model_name}' produced a forecast, but tscli could not convert it to a table. "
            f"Original error: {exc}"
        ) from exc

    forecast_frame.columns = [dataset.time_col, dataset.target_col]

    if dataset.time_col == "__index__":
        # build_series renumbers the synthetic index to 0..n-1, so DARTS'
        # positions are relative to that renumbering. Re-anchor them after the
        # dataset's own last non-missing index, as the built-in models and the
        # statsmodels paths do via _future_index.
        observed = dataset.frame[[dataset.time_col, dataset.target_col]].dropna()
        future = _future_index(observed, dataset.time_col, len(forecast_frame))
        forecast_frame[dataset.time_col] = np.asarray(future)

    return forecast_frame
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
def _statsmodels_fallback_forecast(
    dataset: LoadedSeries,
    model_name: str,
    horizon: int,
    seasonal_period: int,
) -> pd.DataFrame:
    """Forecast with statsmodels, used when the DARTS model is unavailable.

    ``arima`` fits ARIMA(1, 1, 1). Any other name fits SARIMAX(1, 1, 1) with a
    seasonal part:

    - (1, 1, 1, s) normally;
    - (0, 1, 1, s) on short histories;
    - none at all when *seasonal_period* is below 2.

    Raises:
        ValueError: If statsmodels fails to fit or forecast.
    """
    frame = dataset.frame[[dataset.time_col, dataset.target_col]].dropna().copy()
    history = frame[dataset.target_col].astype(float).to_numpy()

    try:
        if model_name == "arima":
            fitted = StatsmodelsARIMA(history, order=(1, 1, 1)).fit()
        else:
            if seasonal_period < 2:
                # statsmodels rejects a seasonal periodicity below 2; fall back
                # to the non-seasonal model instead of failing outright.
                seasonal_order = (0, 0, 0, 0)
            elif len(history) < max(24, seasonal_period * 2 + 6):
                # Small datasets often cannot support a full seasonal parameterization.
                seasonal_order = (0, 1, 1, seasonal_period)
            else:
                seasonal_order = (1, 1, 1, seasonal_period)
            fitted = SARIMAX(
                history,
                order=(1, 1, 1),
                seasonal_order=seasonal_order,
                enforce_stationarity=False,
                enforce_invertibility=False,
            ).fit(disp=False)
        forecast_values = np.asarray(fitted.forecast(steps=horizon), dtype=float)
    except Exception as exc:
        raise ValueError(
            f"Model '{model_name}' is unavailable through DARTS and the fallback implementation also failed. "
            f"Original error: {exc}"
        ) from exc

    return pd.DataFrame(
        {
            dataset.time_col: _future_index(frame, dataset.time_col, horizon),
            dataset.target_col: forecast_values,
        }
    )
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def _statsmodels_classical_forecast(
    dataset: LoadedSeries,
    model_name: str,
    horizon: int,
    seasonal_period: int,
) -> pd.DataFrame:
    """Forecast *horizon* steps with statsmodels.

    ``arima`` fits ARIMA(1, 1, 1). Every other *model_name* fits
    SARIMAX(1, 1, 1) with a seasonal part:

    - (1, 1, 1, s) normally;
    - (0, 1, 1, s) on short histories;
    - none at all when *seasonal_period* is below 2.

    Raises:
        ValueError: If statsmodels fails to fit or forecast.
    """
    frame = dataset.frame[[dataset.time_col, dataset.target_col]].dropna().copy()
    history = frame[dataset.target_col].astype(float).to_numpy()

    try:
        if model_name == "arima":
            fitted = StatsmodelsARIMA(history, order=(1, 1, 1)).fit()
        else:
            if seasonal_period < 2:
                # statsmodels rejects a seasonal periodicity below 2; fall back
                # to the non-seasonal model instead of failing outright.
                seasonal_order = (0, 0, 0, 0)
            elif len(history) < max(24, seasonal_period * 2 + 6):
                # Small datasets often cannot support a full seasonal parameterization.
                seasonal_order = (0, 1, 1, seasonal_period)
            else:
                seasonal_order = (1, 1, 1, seasonal_period)
            fitted = SARIMAX(
                history,
                order=(1, 1, 1),
                seasonal_order=seasonal_order,
                enforce_stationarity=False,
                enforce_invertibility=False,
            ).fit(disp=False)
        forecast_values = np.asarray(fitted.forecast(steps=horizon), dtype=float)
    except Exception as exc:
        raise ValueError(f"Model '{model_name}' failed to fit or predict. Original error: {exc}") from exc

    return pd.DataFrame(
        {
            dataset.time_col: _future_index(frame, dataset.time_col, horizon),
            dataset.target_col: forecast_values,
        }
    )
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def generate_forecast(
    dataset: LoadedSeries,
    model_name: str,
    horizon: int,
    seasonal_period: int = 12,
) -> ForecastResult:
    """Forecast *horizon* steps past the end of *dataset* with *model_name*.

    Built-in heuristics run directly on the non-missing target values. The
    ``darts-classical`` models go through DARTS. ``arima`` and ``sarima`` fall
    back to statsmodels when their DARTS model is unavailable in this
    environment.

    Raises:
        ValueError: If *model_name* is unknown, *horizon* is not positive, or
            the chosen model cannot produce a forecast.
    """
    if model_name not in SUPPORTED_MODELS:
        allowed = ", ".join(sorted(SUPPORTED_MODELS))
        raise ValueError(f"Unsupported model '{model_name}'. Choose from: {allowed}.")
    if horizon < 1:
        raise ValueError(f"Forecast horizon must be a positive integer, but got {horizon}.")

    if MODEL_SPECS[model_name].family == "built-in":
        frame = dataset.frame[[dataset.time_col, dataset.target_col]].dropna()
        # The heuristics only need the observed values, so read them straight
        # from the frame. Building a DARTS TimeSeries here would subject them
        # to DARTS' time-index validation (e.g. a regular frequency), which
        # they do not depend on.
        history = frame[dataset.target_col].to_numpy(dtype=float)
        forecast_values = _heuristic_forecast_values(history, model_name, horizon, seasonal_period)
        forecast_frame = pd.DataFrame(
            {
                dataset.time_col: _future_index(frame, dataset.time_col, horizon),
                dataset.target_col: forecast_values,
            }
        )
    else:
        try:
            forecast_frame = _darts_classical_forecast(dataset, model_name, horizon, seasonal_period)
        except ValueError as exc:
            # Only a missing DARTS install is recoverable, and only for the
            # ARIMA family; genuine fit failures propagate unchanged.
            recoverable = model_name in {"arima", "sarima"} and "unavailable in this environment" in str(exc)
            if not recoverable:
                raise
            forecast_frame = _statsmodels_fallback_forecast(dataset, model_name, horizon, seasonal_period)

    return ForecastResult(model_name=model_name, forecast_frame=forecast_frame)
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
def evaluate_forecast(actual: np.ndarray, predicted: np.ndarray, model_name: str) -> EvaluationResult:
    """Score *predicted* against *actual* with MAE, RMSE and MAPE.

    Both inputs may be arrays or any array-like of numbers and must have the
    same shape. MAPE is expressed in percent and skips positions where the
    actual value is zero. It is NaN when every actual value is zero.

    Raises:
        ValueError: If the inputs are empty or their shapes differ.
    """
    actual = np.asarray(actual, dtype=float)
    predicted = np.asarray(predicted, dtype=float)
    # Guard explicitly: NumPy would silently broadcast a length-1 prediction
    # and return NaN (with a warning) for empty inputs.
    if actual.shape != predicted.shape:
        raise ValueError(
            f"Cannot score model '{model_name}': actual values have shape {actual.shape} "
            f"but predictions have shape {predicted.shape}."
        )
    if actual.size == 0:
        raise ValueError(f"Cannot score model '{model_name}' on an empty holdout.")

    errors = actual - predicted
    mae = float(np.mean(np.abs(errors)))
    rmse = float(np.sqrt(np.mean(np.square(errors))))

    non_zero_mask = actual != 0
    if non_zero_mask.any():
        mape = float(np.mean(np.abs(errors[non_zero_mask] / actual[non_zero_mask])) * 100)
    else:
        mape = float("nan")

    return EvaluationResult(model_name=model_name, mae=mae, rmse=rmse, mape=mape)
|
|
417
|
+
|
|
418
|
+
|
|
419
|
+
def benchmark_models(
    dataset: LoadedSeries,
    model_names: list[str],
    horizon: int,
    seasonal_period: int = 12,
) -> BenchmarkResult:
    """Backtest each model on a holdout made of the last *horizon* observations.

    Each model is fit on the non-missing rows before the holdout and scored
    against it. A model that raises is recorded in ``skipped_models``
    (name -> error text) instead of aborting the run. Scores are ordered by
    RMSE, then MAE, then name, with NaN metrics ranked last so they can never
    beat a real score.

    Raises:
        ValueError: If *horizon* is not positive, there are not more than
            *horizon* non-missing observations, or no model could be scored.
    """
    if horizon < 1:
        # With horizon == 0, iloc[:-0] is empty and iloc[-0:] is the whole
        # series, so the benchmark would silently train on nothing.
        raise ValueError(f"Benchmark horizon must be a positive integer, but got {horizon}.")

    clean_frame = dataset.frame[[dataset.time_col, dataset.target_col]].dropna().reset_index(drop=True)
    if len(clean_frame) <= horizon:
        raise ValueError(
            f"Need more than {horizon} observations to benchmark models, but found {len(clean_frame)}."
        )

    train_frame = clean_frame.iloc[:-horizon].copy()
    actual_frame = clean_frame.iloc[-horizon:].copy().reset_index(drop=True)
    train_dataset = _dataset_from_frame(dataset, train_frame)
    # The holdout is the same for every model, so extract it once.
    actual = actual_frame[dataset.target_col].to_numpy(dtype=float)

    scores: list[EvaluationResult] = []
    forecasts: dict[str, pd.DataFrame] = {}
    skipped_models: dict[str, str] = {}
    for model_name in tqdm(model_names, desc="Benchmarking models", unit="model"):
        # Best effort: one failing model must not abort the whole benchmark.
        try:
            result = generate_forecast(
                train_dataset,
                model_name=model_name,
                horizon=horizon,
                seasonal_period=seasonal_period,
            )
            predicted = result.forecast_frame[dataset.target_col].to_numpy(dtype=float)
            scores.append(evaluate_forecast(actual, predicted, model_name))
            forecasts[model_name] = result.forecast_frame
        except Exception as exc:
            skipped_models[model_name] = str(exc)

    if not scores:
        raise ValueError("No models could be benchmarked successfully with the current data and settings.")

    def nan_last(value: float) -> float:
        # NaN is unordered relative to everything, which would make the sort
        # position (and therefore best_model) arbitrary.
        return float("inf") if np.isnan(value) else value

    scores.sort(key=lambda item: (nan_last(item.rmse), nan_last(item.mae), item.model_name))
    return BenchmarkResult(
        scores=scores,
        actual_frame=actual_frame,
        forecasts=forecasts,
        best_model=scores[0].model_name,
        skipped_models=skipped_models,
    )
|
|
464
|
+
|
|
465
|
+
|
|
466
|
+
def export_frame(frame: pd.DataFrame, output_path: Path, time_col: str) -> None:
    """Write *frame* to *output_path* as CSV, without the row index.

    The time column is first prepared by ``_coerce_output_frame``, and any
    missing parent directories are created.
    """
    formatted = _coerce_output_frame(frame, time_col)
    target_dir = output_path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    formatted.to_csv(output_path, index=False)
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
def export_scores(scores: list[EvaluationResult], output_path: Path) -> None:
    """Write one CSV row per score, with columns model, mae, rmse and mape.

    Missing parent directories of *output_path* are created first.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    records = [
        {"model": item.model_name, "mae": item.mae, "rmse": item.rmse, "mape": item.mape}
        for item in scores
    ]
    pd.DataFrame(records).to_csv(output_path, index=False)
|
|
486
|
+
|
|
487
|
+
|
|
488
|
+
def export_forecast_plot(
    history_frame: pd.DataFrame,
    forecast_frame: pd.DataFrame,
    time_col: str,
    target_col: str,
    output_path: Path,
    model_name: str,
) -> None:
    """Save a line chart of the history followed by the forecast to *output_path*.

    Missing parent directories are created. The figure is always closed, even
    if plotting or saving fails.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(history_frame[time_col], history_frame[target_col], label="history", linewidth=2)
        ax.plot(forecast_frame[time_col], forecast_frame[target_col], label=f"forecast ({model_name})", linewidth=2)
        ax.set_title(f"Forecast - {model_name}")
        ax.set_xlabel(time_col)
        ax.set_ylabel(target_col)
        ax.legend()
        ax.grid(alpha=0.3)
        fig.autofmt_xdate()
        fig.tight_layout()
        fig.savefig(output_path, dpi=160)
    finally:
        # pyplot keeps a reference to every open figure; without this, an
        # exception above would leak the figure for the life of the process.
        plt.close(fig)
|
|
509
|
+
|
|
510
|
+
|
|
511
|
+
def export_benchmark_plot(
    actual_frame: pd.DataFrame,
    forecast_frame: pd.DataFrame,
    time_col: str,
    target_col: str,
    output_path: Path,
    model_name: str,
) -> None:
    """Save a chart comparing holdout actuals with one model's predictions.

    Missing parent directories are created. The figure is always closed, even
    if plotting or saving fails.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(actual_frame[time_col], actual_frame[target_col], label="actual", linewidth=2)
        ax.plot(forecast_frame[time_col], forecast_frame[target_col], label=f"predicted ({model_name})", linewidth=2)
        ax.set_title(f"Benchmark Holdout - {model_name}")
        ax.set_xlabel(time_col)
        ax.set_ylabel(target_col)
        ax.legend()
        ax.grid(alpha=0.3)
        fig.autofmt_xdate()
        fig.tight_layout()
        fig.savefig(output_path, dpi=160)
    finally:
        # pyplot keeps a reference to every open figure; without this, an
        # exception above would leak the figure for the life of the process.
        plt.close(fig)
|