spotforecast2-safe 1.0.0__py3-none-any.whl → 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spotforecast2_safe/preprocessing/__init__.py +1 -6
- {spotforecast2_safe-1.0.0.dist-info → spotforecast2_safe-1.0.2.dist-info}/METADATA +1 -1
- {spotforecast2_safe-1.0.0.dist-info → spotforecast2_safe-1.0.2.dist-info}/RECORD +4 -18
- spotforecast2_safe/forecaster/metrics.py +0 -527
- spotforecast2_safe/model_selection/__init__.py +0 -5
- spotforecast2_safe/model_selection/bayesian_search.py +0 -453
- spotforecast2_safe/model_selection/grid_search.py +0 -314
- spotforecast2_safe/model_selection/random_search.py +0 -151
- spotforecast2_safe/model_selection/split_base.py +0 -357
- spotforecast2_safe/model_selection/split_one_step.py +0 -248
- spotforecast2_safe/model_selection/split_ts_cv.py +0 -687
- spotforecast2_safe/model_selection/utils_common.py +0 -718
- spotforecast2_safe/model_selection/utils_metrics.py +0 -103
- spotforecast2_safe/model_selection/validation.py +0 -685
- spotforecast2_safe/preprocessing/time_series_visualization.py +0 -815
- spotforecast2_safe/stats/__init__.py +0 -7
- spotforecast2_safe/stats/autocorrelation.py +0 -173
- {spotforecast2_safe-1.0.0.dist-info → spotforecast2_safe-1.0.2.dist-info}/WHEEL +0 -0
|
@@ -1,815 +0,0 @@
|
|
|
1
|
-
"""Time series visualization."""
|
|
2
|
-
|
|
3
|
-
from typing import Dict, List, Optional, Any, Union
|
|
4
|
-
|
|
5
|
-
import pandas as pd
|
|
6
|
-
import matplotlib.pyplot as plt
|
|
7
|
-
import numpy as np
|
|
8
|
-
|
|
9
|
-
try:
|
|
10
|
-
import plotly.graph_objects as go
|
|
11
|
-
except ImportError:
|
|
12
|
-
go = None
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
def visualize_ts_plotly(
    dataframes: Dict[str, pd.DataFrame],
    columns: Optional[List[str]] = None,
    title_suffix: str = "",
    figsize: tuple[int, int] = (1000, 500),
    template: str = "plotly_white",
    colors: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> None:
    """Visualize multiple time series datasets interactively with Plotly.

    Creates one interactive Plotly figure per column. Within each figure every
    dataset (e.g., train, validation, test splits) is drawn as its own line with
    a distinct color and a legend entry named after the dataset.

    Args:
        dataframes: Dictionary mapping dataset names to pandas DataFrames with datetime
            index. Example: {'Train': df_train, 'Validation': df_val, 'Test': df_test}
        columns: List of column names to visualize. If None, the sorted union of
            the columns of all dataframes is used. Default: None.
        title_suffix: Suffix appended (after a space) to the column name in the
            title. Useful for adding units or descriptions. Default: "".
        figsize: Figure size as (width, height) in pixels. Default: (1000, 500).
        template: Plotly template name for styling. Options include 'plotly_white',
            'plotly_dark', 'plotly', 'ggplot2', etc. Default: 'plotly_white'.
        colors: Dictionary mapping dataset names to colors. Datasets missing from
            the mapping (or all datasets when None) fall back to a built-in
            10-color palette. Example: {'Train': 'blue', 'Validation': 'orange'}.
            Default: None.
        **kwargs: Additional keyword arguments passed to go.Scatter() (e.g.,
            mode='lines+markers', line=dict(dash='dash')). ``mode`` overrides the
            default "lines". A ``line`` dict is merged over the per-dataset color,
            so e.g. ``dash`` keeps the dataset color unless ``color`` is given
            explicitly. ``x``, ``y`` and ``name`` are set by this function and
            must not be passed.

    Returns:
        None. Displays Plotly figures.

    Raises:
        ImportError: If plotly is not installed.
        TypeError: If dataframes parameter is not a dictionary.
        ValueError: If the dataframes dict is empty, if a DataFrame has no
            columns or no rows, or if a requested column is missing from any
            DataFrame.

    Examples:
        >>> import pandas as pd
        >>> import numpy as np
        >>> from spotforecast2.preprocessing.time_series_visualization import visualize_ts_plotly
        >>>
        >>> # Create sample time series data
        >>> np.random.seed(42)
        >>> dates_train = pd.date_range('2024-01-01', periods=100, freq='h')
        >>> dates_val = pd.date_range('2024-05-11', periods=50, freq='h')
        >>> dates_test = pd.date_range('2024-07-01', periods=30, freq='h')
        >>>
        >>> data_train = pd.DataFrame({
        ...     'temperature': np.random.normal(20, 5, 100),
        ...     'humidity': np.random.normal(60, 10, 100)
        ... }, index=dates_train)
        >>>
        >>> data_val = pd.DataFrame({
        ...     'temperature': np.random.normal(22, 5, 50),
        ...     'humidity': np.random.normal(55, 10, 50)
        ... }, index=dates_val)
        >>>
        >>> data_test = pd.DataFrame({
        ...     'temperature': np.random.normal(25, 5, 30),
        ...     'humidity': np.random.normal(50, 10, 30)
        ... }, index=dates_test)
        >>>
        >>> # Visualize all datasets
        >>> dataframes = {
        ...     'Train': data_train,
        ...     'Validation': data_val,
        ...     'Test': data_test
        ... }
        >>> visualize_ts_plotly(dataframes)

        Single dataset example:

        >>> # Visualize single dataset
        >>> dataframes = {'Data': data_train}
        >>> visualize_ts_plotly(dataframes, columns=['temperature'])

        Custom styling:

        >>> visualize_ts_plotly(
        ...     dataframes,
        ...     columns=['temperature'],
        ...     template='plotly_dark',
        ...     colors={'Train': 'blue', 'Validation': 'green', 'Test': 'red'},
        ...     mode='lines+markers'
        ... )
    """
    if go is None:
        raise ImportError(
            "plotly is required for this function. " "Install with: pip install plotly"
        )

    if not isinstance(dataframes, dict):
        raise TypeError("dataframes parameter must be a dictionary")

    if not dataframes:
        raise ValueError("dataframes dictionary is empty")

    # Validate all dataframes have data. The column check must come first:
    # a DataFrame without columns is also reported as ``empty`` by pandas.
    for name, df in dataframes.items():
        if len(df.columns) == 0:
            raise ValueError(f"DataFrame '{name}' contains no columns")
        if df.empty:
            raise ValueError(f"DataFrame '{name}' is empty")

    # Determine columns to plot
    if columns is None:
        all_columns = set()
        for df in dataframes.values():
            all_columns.update(df.columns)
        columns_to_plot = sorted(all_columns)
    else:
        columns_to_plot = columns

    # Validate columns exist in all dataframes
    for col in columns_to_plot:
        for name, df in dataframes.items():
            if col not in df.columns:
                raise ValueError(f"Column '{col}' not found in dataframe '{name}'")

    # Resolve one color per dataset: explicit mapping first, palette fallback.
    default_colors = [
        "#1f77b4",  # blue
        "#ff7f0e",  # orange
        "#2ca02c",  # green
        "#d62728",  # red
        "#9467bd",  # purple
        "#8c564b",  # brown
        "#e377c2",  # pink
        "#7f7f7f",  # gray
        "#bcbd22",  # olive
        "#17becf",  # cyan
    ]
    user_colors = colors or {}
    resolved_colors = {
        name: user_colors.get(name, default_colors[i % len(default_colors)])
        for i, name in enumerate(dataframes)
    }

    # ``line`` is merged with the per-dataset color and ``mode`` only overrides
    # the default, so neither is passed to go.Scatter twice (which would raise
    # TypeError: got multiple values for keyword argument).
    trace_kwargs = dict(kwargs)
    user_line = trace_kwargs.pop("line", None)

    # Create figures for each column
    for col in columns_to_plot:
        fig = go.Figure()

        # Add trace for each dataset
        for dataset_name, df in dataframes.items():
            line = {"color": resolved_colors[dataset_name]}
            if isinstance(user_line, dict):
                line.update(user_line)
            elif user_line is not None:
                # e.g. a go.scatter.Line object: use it as given.
                line = user_line
            fig.add_trace(
                go.Scatter(
                    x=df.index,
                    y=df[col],
                    name=dataset_name,
                    line=line,
                    **{"mode": "lines", **trace_kwargs},
                )
            )

        # Create title
        title = col
        if title_suffix:
            title = f"{col} {title_suffix}"

        # Update layout
        fig.update_layout(
            title=title,
            xaxis_title="Time",
            yaxis_title=col,
            width=figsize[0],
            height=figsize[1],
            template=template,
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1,
            ),
            hovermode="x unified",
        )

        fig.show()
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
def visualize_ts_comparison(
    dataframes: Dict[str, pd.DataFrame],
    columns: Optional[List[str]] = None,
    title_suffix: str = "",
    figsize: tuple[int, int] = (1000, 500),
    template: str = "plotly_white",
    colors: Optional[Dict[str, str]] = None,
    show_mean: bool = False,
    **kwargs: Any,
) -> None:
    """Visualize time series with optional statistical overlays.

    Always renders the same figures as visualize_ts_plotly. When ``show_mean``
    is True, it additionally renders, per column, a figure with every dataset
    drawn thin and semi-transparent plus a dashed black line with the mean
    across datasets at each timestamp.

    Args:
        dataframes: Dictionary mapping dataset names to pandas DataFrames.
        columns: List of column names to visualize. If None, all columns are used.
            Default: None.
        title_suffix: Suffix appended (after a space) to the figure titles.
            Default: "".
        figsize: Figure size as (width, height) in pixels. Default: (1000, 500).
        template: Plotly template. Default: 'plotly_white'.
        colors: Dictionary mapping dataset names to colors. Datasets missing from
            the mapping fall back to a built-in 10-color palette. Default: None.
        show_mean: If True, overlay the mean of all datasets. Default: False.
        **kwargs: Additional keyword arguments for go.Scatter(). In the mean
            figure, ``mode`` and ``opacity`` override the defaults ("lines", 0.5).
            A ``line`` dict is merged over the per-dataset color/width. The
            kwargs are not applied to the mean trace itself.

    Returns:
        None. Displays Plotly figures.

    Raises:
        ValueError: If dataframes is empty, or for any invalid input rejected by
            visualize_ts_plotly.
        ImportError: If plotly is not installed.

    Examples:
        >>> import pandas as pd
        >>> import numpy as np
        >>> from spotforecast2.preprocessing.time_series_visualization import visualize_ts_comparison
        >>>
        >>> # Create sample data
        >>> np.random.seed(42)
        >>> dates1 = pd.date_range('2024-01-01', periods=100, freq='h')
        >>> dates2 = pd.date_range('2024-05-11', periods=100, freq='h')
        >>>
        >>> df1 = pd.DataFrame({
        ...     'temperature': np.random.normal(20, 5, 100)
        ... }, index=dates1)
        >>>
        >>> df2 = pd.DataFrame({
        ...     'temperature': np.random.normal(22, 5, 100)
        ... }, index=dates2)
        >>>
        >>> # Compare with mean overlay
        >>> visualize_ts_comparison(
        ...     {'Dataset1': df1, 'Dataset2': df2},
        ...     show_mean=True
        ... )
    """
    if go is None:
        raise ImportError(
            "plotly is required for this function. " "Install with: pip install plotly"
        )

    if not dataframes:
        raise ValueError("dataframes dictionary is empty")

    # First visualize normally (this also validates inputs and columns).
    visualize_ts_plotly(
        dataframes,
        columns=columns,
        title_suffix=title_suffix,
        figsize=figsize,
        template=template,
        colors=colors,
        **kwargs,
    )

    if not show_mean:
        return

    # Determine columns to plot
    if columns is None:
        all_columns = set()
        for df in dataframes.values():
            all_columns.update(df.columns)
        columns_to_plot = sorted(all_columns)
    else:
        columns_to_plot = columns

    # Same palette as visualize_ts_plotly so a dataset keeps its color in both
    # figures; explicit ``colors`` entries take precedence. Computed once,
    # outside the column loop.
    default_colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]
    user_colors = colors or {}
    colors_dict = {
        name: user_colors.get(name, default_colors[i % len(default_colors)])
        for i, name in enumerate(dataframes)
    }

    # Merge ``line`` with the per-dataset style and let ``mode``/``opacity``
    # override the defaults, instead of passing them to go.Scatter twice.
    trace_kwargs = dict(kwargs)
    user_line = trace_kwargs.pop("line", None)

    for col in columns_to_plot:
        fig = go.Figure()

        # Add individual traces
        for dataset_name, df in dataframes.items():
            line = {"color": colors_dict[dataset_name], "width": 1}
            if isinstance(user_line, dict):
                line.update(user_line)
            elif user_line is not None:
                line = user_line
            fig.add_trace(
                go.Scatter(
                    x=df.index,
                    y=df[col],
                    name=dataset_name,
                    line=line,
                    **{"mode": "lines", "opacity": 0.5, **trace_kwargs},
                )
            )

        # Align all datasets on the union of their indexes (outer concat). The
        # row-wise mean skips the NaNs introduced where a dataset has no value.
        combined = pd.concat(
            [dataframes[name][[col]].rename(columns={col: name}) for name in dataframes],
            axis=1,
        )
        mean_values = combined.mean(axis=1)

        fig.add_trace(
            go.Scatter(
                x=mean_values.index,
                y=mean_values,
                mode="lines",
                name="Mean",
                line=dict(color="black", width=3, dash="dash"),
            )
        )

        # Separate the suffix with a space, matching visualize_ts_plotly titles.
        title = f"{col} (with mean)"
        if title_suffix:
            title = f"{title} {title_suffix}"

        fig.update_layout(
            title=title,
            xaxis_title="Time",
            yaxis_title=col,
            width=figsize[0],
            height=figsize[1],
            template=template,
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1,
            ),
            hovermode="x unified",
        )

        fig.show()
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
def plot_zoomed_timeseries(
    data: pd.DataFrame,
    target: str,
    zoom: tuple[str, str],
    title: Optional[str] = None,
    figsize: tuple[int, int] = (8, 4),
    show: bool = True,
) -> plt.Figure:
    """Plot a time series with a zoomed-in focus area.

    Creates a two-panel plot:
    1. Top panel: Full time series with the zoom area highlighted.
    2. Bottom panel: Zoomed-in view of the specified time range.

    Args:
        data: DataFrame containing the time series data. Must have a DatetimeIndex
            (or PeriodIndex) or an index convertible to datetime. Other indexes
            are converted with ``pd.to_datetime`` on a copy; ``data`` itself is
            never modified.
        target: Name of the column to plot.
        zoom: Tuple of (start_date, end_date) strings defining the zoom range.
            Both ends are inclusive (label-based slicing).
        title: Optional title for the plot. If None, defaults to target name.
        figsize: Figure dimensions (width, height). Defaults to (8, 4).
        show: Whether to display the plot immediately. Defaults to True.

    Returns:
        plt.Figure: The matplotlib Figure object.

    Examples:
        >>> import pandas as pd
        >>> import matplotlib.pyplot as plt
        >>> from spotforecast2.preprocessing.time_series_visualization import plot_zoomed_timeseries
        >>> # Create sample data
        >>> dates = pd.date_range("2023-01-01", periods=100, freq="h")
        >>> df = pd.DataFrame({"value": range(100)}, index=dates)
        >>> # Plot with zoom
        >>> fig = plot_zoomed_timeseries(
        ...     data=df,
        ...     target="value",
        ...     zoom=("2023-01-02 00:00", "2023-01-03 00:00"),
        ...     show=False
        ... )
        >>> plt.close(fig)
    """
    if title is None:
        title = target

    series = data[target]
    if not isinstance(series.index, (pd.DatetimeIndex, pd.PeriodIndex)):
        # Honour the documented contract: convert (on a new object) so that
        # date-string slicing and the time axis work on e.g. string indexes.
        series = series.set_axis(pd.to_datetime(series.index))

    fig, axs = plt.subplots(
        2, 1, figsize=figsize, gridspec_kw={"height_ratios": [1, 2]}
    )

    # Top plot: Full series with highlighted zoom area
    series.plot(ax=axs[0], color="black", alpha=0.5)
    axs[0].axvspan(zoom[0], zoom[1], color="blue", alpha=0.7)
    axs[0].set_title(f"{title}")
    axs[0].set_xlabel("")
    axs[0].grid(True)

    # Bottom plot: Zoomed view
    series.loc[zoom[0] : zoom[1]].plot(ax=axs[1], color="blue")
    axs[1].set_title(f"Zoom: {zoom[0]} to {zoom[1]}", fontsize=10)
    axs[1].grid(True)

    # Lay out this figure explicitly rather than whatever figure is "current".
    fig.tight_layout()

    if show:
        plt.show()

    return fig
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
def plot_seasonality(
    data: pd.DataFrame,
    target: str,
    figsize: tuple[int, int] = (8, 5),
    show: bool = True,
    logscale: Union[bool, list[bool]] = False,
) -> plt.Figure:
    """Plot seasonal patterns (annual, weekly, daily) for a given target.

    Creates a 2x2 grid of plots:
    1. Distribution by month (boxplot + median).
    2. Distribution by week day (boxplot + median).
    3. Distribution by hour of day (boxplot + median).
    4. Mean target value by day of week and hour.

    Args:
        data: DataFrame containing the time series data. Must have a DatetimeIndex
            (or PeriodIndex) or an index convertible to datetime. Other indexes
            are converted with ``pd.to_datetime`` on a copy; ``data`` itself is
            never modified.
        target: Name of the column to plot.
        figsize: Figure dimensions (width, height). Defaults to (8, 5).
        show: Whether to display the plot immediately. Defaults to True.
        logscale: Whether to use a log scale for the y-axis.
            Can be a single boolean (applies to all 4 plots) or a list of 4
            booleans (applies to each plot individually). Defaults to False.

    Returns:
        plt.Figure: The matplotlib Figure object.

    Raises:
        ValueError: If ``logscale`` is a sequence whose length is not 4.

    Examples:
        >>> import pandas as pd
        >>> import matplotlib.pyplot as plt
        >>> from spotforecast2.preprocessing.time_series_visualization import plot_seasonality
        >>> # Create sample data
        >>> dates = pd.date_range("2023-01-01", periods=1000, freq="h")
        >>> df = pd.DataFrame({"value": range(1, 1001)}, index=dates)
        >>> # Plot seasonality with log scale for all plots
        >>> fig = plot_seasonality(data=df, target="value", logscale=True, show=False)
        >>> plt.close(fig)
        >>> # Plot seasonality with log scale for the first plot only
        >>> fig = plot_seasonality(
        ...     data=df,
        ...     target="value",
        ...     logscale=[True, False, False, False],
        ...     show=False
        ... )
        >>> plt.close(fig)
    """
    # Expand ``logscale`` into one flag per panel.
    if isinstance(logscale, bool):
        logscales = [logscale] * 4
        sharey = True
    else:
        if len(logscale) != 4:
            raise ValueError("logscale list must have length 4.")
        logscales = list(logscale)
        # If different scales are used, we should not share y-axis
        sharey = len(set(logscales)) == 1

    # Copy only the target column: the helper calendar columns added below must
    # not leak into the caller's DataFrame, and the other columns are unused.
    df = data[[target]].copy()
    if not isinstance(df.index, (pd.DatetimeIndex, pd.PeriodIndex)):
        df.index = pd.to_datetime(df.index)

    # Create temporal features (1-based labels)
    df["month"] = df.index.month
    df["week_day"] = df.index.day_of_week + 1
    df["hour_day"] = df.index.hour + 1

    fig, axs = plt.subplots(2, 2, figsize=figsize, sharex=False, sharey=sharey)
    axs = axs.ravel()

    # Panels 1-3: boxplot per calendar group with the median overlaid.
    distribution_panels = (
        ("month", "month"),
        ("week_day", "week day"),
        ("hour_day", "the hour of the day"),
    )
    for ax, (group_col, description) in zip(axs[:3], distribution_panels):
        df.boxplot(
            column=target,
            by=group_col,
            ax=ax,
            flierprops={"markersize": 3, "alpha": 0.3},
        )
        medians = df.groupby(group_col)[target].median()
        # The boxplot draws one box per observed group at positions 1..k (groups
        # sorted). Plot the medians at those same positions; the raw group values
        # only coincide with them when every group (e.g. all 12 months) exists.
        ax.plot(
            np.arange(1, len(medians) + 1),
            medians.to_numpy(),
            "o-",
            linewidth=0.8,
            color="C0",
        )
        ax.set_ylabel(target)
        ax.set_title(f"{target} distribution by {description}", fontsize=9)

    # 4. Mean by week day and hour of the day, drawn at positions 0..n-1.
    mean_day_hour = df.groupby(["week_day", "hour_day"])[target].mean()
    positions = np.arange(len(mean_day_hour))
    axs[3].plot(positions, mean_day_hour.to_numpy(), color="C0")

    # Place each weekday tick at that day's first observed hour, so labels stay
    # correct even when some (day, hour) combinations are missing. With complete
    # data this yields ticks 0, 24, ..., 144 labelled Mon..Sun.
    day_names = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
    week_days = mean_day_hour.index.get_level_values("week_day").to_numpy()
    is_day_start = np.ones(len(week_days), dtype=bool)
    is_day_start[1:] = week_days[1:] != week_days[:-1]
    tick_positions = positions[is_day_start]
    tick_labels = [day_names[int(day) - 1] for day in week_days[is_day_start]]

    axs[3].set(
        title=f"Mean {target} during week",
        xticks=tick_positions,
        xticklabels=tick_labels,
        xlabel="Day and hour",
        ylabel=f"Number of {target}",
    )
    axs[3].grid(True)
    axs[3].title.set_size(10)

    # Apply logscale
    for ax, use_log in zip(axs, logscales):
        if use_log:
            ax.set_yscale("log")

    fig.suptitle(f"Seasonality plots: {target}", fontsize=12)
    fig.tight_layout()

    if show:
        plt.show()

    return fig
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
def plot_predictions(
    y_true: Union[pd.Series, pd.DataFrame],
    predictions: Dict[str, Union[pd.Series, pd.DataFrame, np.ndarray]],
    slice_seq: Optional[slice] = None,
    title: str = "Predictions vs Actuals",
    figsize: Optional[tuple] = None,
    show: bool = True,
    nrows: Optional[int] = None,
    ncols: int = 1,
    sharex: bool = True,
) -> plt.Figure:
    """Plot actual values against one or more prediction series.

    Allows visualizing model performance by overlaying predictions on top of
    actual data. Supports slicing to focus on a specific time range (e.g.,
    the recent test set). Handles both univariate and multivariate targets
    by creating subplots for multiple targets.

    Args:
        y_true: Series or DataFrame containing the actual target values.
        predictions: Dictionary where keys are labels (e.g., model names) and
            values are the corresponding predictions. Predictions are aligned
            by position, not by index. A prediction whose length equals the
            full `y_true` is sliced with `slice_seq`; otherwise it must have the
            same length as the sliced `y_true`. Predictions that cannot be
            matched to a target (e.g. a Series for a multivariate target) are
            skipped.
        slice_seq: Optional slice object to select a subset of the data.
            If None, the entire series is plotted.
            Example: `slice(-96, None)` to select the last 96 points.
        title: Title of the plot. Defaults to "Predictions vs Actuals".
        figsize: Tuple defining figure width and height. If None, automatically
            calculated based on number of subplots.
        show: Whether to display the plot. Defaults to True.
        nrows: Number of rows for subplots (multivariate). Defaults to n_targets.
            Increased automatically if nrows * ncols cannot hold all targets.
        ncols: Number of columns for subplots (multivariate). Defaults to 1.
        sharex: Whether to share x-axis for subplots. Defaults to True.

    Returns:
        plt.Figure: The matplotlib Figure object containing the plot.

    Raises:
        ValueError: If `y_true` has no target columns, or if a matched
            prediction has neither the full nor the sliced length of `y_true`.

    Examples:
        >>> import matplotlib.pyplot as plt
        >>> import pandas as pd
        >>> import numpy as np
        >>> from spotforecast2.preprocessing.time_series_visualization import plot_predictions
        >>> # Create sample data
        >>> dates = pd.date_range("2023-01-01", periods=10, freq="D")
        >>> y_true = pd.Series(np.arange(10), index=dates, name="Target")
        >>> predictions = {"Model A": y_true + 0.5}
        >>> # Plot predictions
        >>> fig = plot_predictions(y_true, predictions, show=False)
        >>> plt.close(fig)
    """
    if slice_seq is None:
        slice_seq = slice(None)

    # Handle y_true slicing
    y_plot = y_true.iloc[slice_seq]

    # Determine targets; a Series becomes a one-column DataFrame.
    if isinstance(y_plot, pd.Series):
        # Only a missing/empty name falls back to "Target"; falsy but valid
        # names such as 0 are kept.
        series_name = y_plot.name
        if series_name is None or series_name == "":
            series_name = "Target"
        targets = [series_name]
        y_plot = y_plot.to_frame(name=series_name)
    else:
        targets = y_plot.columns.tolist()

    n_targets = len(targets)
    if n_targets == 0:
        raise ValueError("y_true has no target columns to plot")

    # Setup layout
    if nrows is None:
        nrows = n_targets

    # Auto-adjust if nrows * ncols cannot hold all targets
    if nrows * ncols < n_targets:
        nrows = (n_targets + ncols - 1) // ncols

    if figsize is None:
        figsize = (12, 4 * nrows)

    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=figsize,
        sharex=sharex,
        squeeze=False,  # Ensure axes is always 2D array
    )
    fig.suptitle(title)

    axes_flat = axes.flatten()
    full_len = len(y_true)
    sliced_len = len(y_plot)

    for i, (target, ax) in enumerate(zip(targets, axes_flat)):
        # Plot Actuals
        target_actuals = y_plot[target]
        ax.plot(
            target_actuals.index,
            target_actuals.values,
            "x-",
            alpha=0.5,
            label="Actual",
            color="black",
            linewidth=2,
        )

        # Plot Predictions
        for label, y_pred in predictions.items():
            pred_part = _select_target_prediction(y_pred, target, i, n_targets)
            if pred_part is None:
                continue

            if isinstance(pred_part, (pd.Series, pd.DataFrame)):
                vals_to_plot = pred_part.values
            else:
                vals_to_plot = pred_part

            # Full-length predictions are sliced like y_true; sliced-length ones
            # are used as-is. Anything else cannot be aligned with the x values.
            if len(vals_to_plot) == full_len:
                vals_to_plot = vals_to_plot[slice_seq]
            elif len(vals_to_plot) != sliced_len:
                raise ValueError(
                    f"Prediction '{label}' for target '{target}' has length "
                    f"{len(vals_to_plot)}; expected {full_len} (full y_true) "
                    f"or {sliced_len} (sliced y_true)."
                )

            ax.plot(target_actuals.index, vals_to_plot, "x-", label=label, alpha=0.8)

        ax.set_title(target)
        ax.legend()
        ax.grid(True, alpha=0.3)
        if not sharex:
            plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)

    # Hide unused subplots
    for ax in axes_flat[n_targets:]:
        ax.axis("off")

    if sharex:
        # Rotate labels for bottom row axes
        for ax in axes[-1, :]:
            plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)

    fig.tight_layout()

    if show:
        plt.show()

    return fig


def _select_target_prediction(
    y_pred: Union[pd.Series, pd.DataFrame, np.ndarray],
    target: Any,
    position: int,
    n_targets: int,
) -> Optional[Union[pd.Series, pd.DataFrame, np.ndarray]]:
    """Return the part of ``y_pred`` belonging to ``target``, or None if unmatched.

    Args:
        y_pred: One prediction value from the ``predictions`` dict.
        target: Name of the target being plotted.
        position: Position of ``target`` among all targets.
        n_targets: Total number of targets.

    Returns:
        The matching 1-D prediction, or None when ``y_pred`` cannot be mapped
        to ``target`` (the caller then skips it).
    """
    if isinstance(y_pred, pd.DataFrame):
        if target in y_pred.columns:
            return y_pred[target]
        if len(y_pred.columns) == n_targets:
            # Fallback: assume prediction columns follow the order of the targets.
            return y_pred.iloc[:, position]
        return None

    if isinstance(y_pred, np.ndarray):
        if y_pred.ndim > 1 and y_pred.shape[1] == n_targets:
            return y_pred[:, position]
        if y_pred.ndim == 1 and n_targets == 1:
            return y_pred
        return None

    if isinstance(y_pred, pd.Series):
        # A single Series can only describe a univariate target.
        return y_pred if n_targets == 1 else None

    return None
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
def plot_forecast(
    model: Any,
    X: pd.DataFrame,
    y: Union[pd.Series, pd.DataFrame],
    cv_results: Optional[Dict[str, Any]] = None,
    title: str = "Forecast",
    figsize: Optional[tuple] = None,
    show: bool = True,
    nrows: Optional[int] = None,
    ncols: int = 1,
    sharex: bool = True,
) -> plt.Figure:
    """Plot model forecast against actuals and display CV metrics.

    Calls ``model.predict(X)`` and delegates plotting to ``plot_predictions``
    with the predictions labelled "Forecast".

    Args:
        model: Fitted scikit-learn model.
        X: Feature matrix (e.g., test set).
        y: Target series or DataFrame (e.g., test set).
        cv_results: Optional dictionary of cross-validation results from
            `evaluate()` or `sklearn.model_selection.cross_validate()`. The keys
            "test_neg_mean_absolute_error" and "test_neg_root_mean_squared_error"
            are read if present (array-like of negated per-fold scores). They are
            shown in the title as MAE/RMSE mean (±std).
        title: Title of the plot. Defaults to "Forecast".
        figsize: Figure dimensions.
        show: Whether to display the plot. Defaults to True.
        nrows: Number of rows for subplots (multivariate).
        ncols: Number of columns for subplots (multivariate).
        sharex: Whether to share x-axis for subplots. Defaults to True.

    Returns:
        plt.Figure: The matplotlib Figure object.

    Examples:
        >>> import matplotlib.pyplot as plt
        >>> import pandas as pd
        >>> import numpy as np
        >>> from sklearn.linear_model import LinearRegression
        >>> from spotforecast2.preprocessing.time_series_visualization import plot_forecast
        >>> # Create sample data
        >>> dates = pd.date_range("2023-01-01", periods=10, freq="D")
        >>> X = pd.DataFrame({"feat": np.arange(10)}, index=dates)
        >>> y = pd.Series(np.arange(10), index=dates)
        >>> model = LinearRegression().fit(X, y)
        >>> # Plot forecast
        >>> fig = plot_forecast(model, X, y, show=False)
        >>> plt.close(fig)
    """
    # 1. Generate predictions/forecast (the model is assumed to be fitted).
    y_pred = model.predict(X)

    # 2. Format title with metrics if available
    if cv_results:
        metrics_str = []
        for key, label in (
            ("test_neg_mean_absolute_error", "MAE"),
            ("test_neg_root_mean_squared_error", "RMSE"),
        ):
            if key in cv_results:
                # Scores are stored negated (sklearn "neg_" scorers); flip the
                # sign. asarray also accepts list-valued results.
                errors = -np.asarray(cv_results[key], dtype=float)
                metrics_str.append(
                    f"{label}: {np.mean(errors):.3f} (±{np.std(errors):.3f})"
                )

        if metrics_str:
            title += "\n" + " | ".join(metrics_str)

    # 3. Plot
    return plot_predictions(
        y,
        {"Forecast": y_pred},
        slice_seq=None,
        title=title,
        figsize=figsize,
        show=show,
        nrows=nrows,
        ncols=ncols,
        sharex=sharex,
    )
|