spotforecast2-safe 1.0.0__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,815 +0,0 @@
1
- """Time series visualization."""
2
-
3
- from typing import Dict, List, Optional, Any, Union
4
-
5
- import pandas as pd
6
- import matplotlib.pyplot as plt
7
- import numpy as np
8
-
9
- try:
10
- import plotly.graph_objects as go
11
- except ImportError:
12
- go = None
13
-
14
-
15
def visualize_ts_plotly(
    dataframes: Dict[str, pd.DataFrame],
    columns: Optional[List[str]] = None,
    title_suffix: str = "",
    figsize: tuple[int, int] = (1000, 500),
    template: str = "plotly_white",
    colors: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> None:
    """Visualize multiple time series datasets interactively with Plotly.

    One figure is created and shown per column. Inside each figure every
    dataset (e.g. train / validation / test split) is drawn as its own trace,
    named after its key in ``dataframes``.

    Args:
        dataframes: Dictionary mapping dataset names to pandas DataFrames.
            Example: {'Train': df_train, 'Validation': df_val, 'Test': df_test}
        columns: Column names to visualize. If None, the sorted union of the
            columns of all dataframes is used. Default: None.
        title_suffix: Text appended (after a space) to the column name in the
            figure title. Default: "".
        figsize: Figure size as (width, height) in pixels. Default: (1000, 500).
        template: Plotly template name. Default: 'plotly_white'.
        colors: Dictionary mapping dataset names to colors. Datasets missing
            from the mapping (or all datasets, if None) receive a color from a
            built-in 10-color palette based on their position. Default: None.
        **kwargs: Additional keyword arguments passed to go.Scatter(). A
            user-supplied ``mode`` replaces the default ``"lines"``; a
            user-supplied ``line`` dict is merged over the default
            ``{"color": ...}``.

    Returns:
        None. Displays Plotly figures.

    Raises:
        ImportError: If plotly is not installed.
        TypeError: If dataframes is not a dictionary.
        ValueError: If dataframes is empty, if a dataframe is empty, or if a
            requested column is missing from any dataframe.

    Examples:
        >>> import pandas as pd
        >>> import numpy as np
        >>> from spotforecast2.preprocessing.time_series_visualization import visualize_ts_plotly
        >>>
        >>> np.random.seed(42)
        >>> dates_train = pd.date_range('2024-01-01', periods=100, freq='h')
        >>> dates_test = pd.date_range('2024-07-01', periods=30, freq='h')
        >>> data_train = pd.DataFrame(
        ...     {'temperature': np.random.normal(20, 5, 100)}, index=dates_train
        ... )
        >>> data_test = pd.DataFrame(
        ...     {'temperature': np.random.normal(25, 5, 30)}, index=dates_test
        ... )
        >>> dataframes = {'Train': data_train, 'Test': data_test}
        >>> visualize_ts_plotly(dataframes)

        Custom styling:

        >>> visualize_ts_plotly(
        ...     dataframes,
        ...     columns=['temperature'],
        ...     template='plotly_dark',
        ...     colors={'Train': 'blue', 'Test': 'red'},
        ...     mode='lines+markers'
        ... )
    """
    if go is None:
        raise ImportError(
            "plotly is required for this function. Install with: pip install plotly"
        )

    if not isinstance(dataframes, dict):
        raise TypeError("dataframes parameter must be a dictionary")

    if not dataframes:
        raise ValueError("dataframes dictionary is empty")

    # Validate all dataframes have data
    for name, df in dataframes.items():
        if df.empty:
            raise ValueError(f"DataFrame '{name}' is empty")
        if len(df.columns) == 0:
            raise ValueError(f"DataFrame '{name}' contains no columns")

    # Determine columns to plot
    all_columns = set()
    for df in dataframes.values():
        all_columns.update(df.columns)

    if not all_columns:
        raise ValueError("No columns found in any dataframe")

    columns_to_plot = columns if columns is not None else sorted(all_columns)

    # Validate columns exist in all dataframes
    for col in columns_to_plot:
        for name, df in dataframes.items():
            if col not in df.columns:
                raise ValueError(f"Column '{col}' not found in dataframe '{name}'")

    default_colors = [
        "#1f77b4",  # blue
        "#ff7f0e",  # orange
        "#2ca02c",  # green
        "#d62728",  # red
        "#9467bd",  # purple
        "#8c564b",  # brown
        "#e377c2",  # pink
        "#7f7f7f",  # gray
        "#bcbd22",  # olive
        "#17becf",  # cyan
    ]
    # A partial user mapping must not raise KeyError: fall back to the palette
    # for every dataset the caller did not assign a color to.
    user_colors = colors or {}
    resolved_colors = {
        name: user_colors.get(name, default_colors[i % len(default_colors)])
        for i, name in enumerate(dataframes)
    }

    # Merge user kwargs over the defaults instead of passing both, which would
    # raise "got multiple values for keyword argument" for mode / line.
    trace_kwargs = {"mode": "lines", **kwargs}
    user_line = trace_kwargs.pop("line", None)

    # Create one figure per column
    for col in columns_to_plot:
        fig = go.Figure()

        for dataset_name, df in dataframes.items():
            if user_line is None:
                line = {"color": resolved_colors[dataset_name]}
            elif isinstance(user_line, dict):
                line = {"color": resolved_colors[dataset_name], **user_line}
            else:
                # Non-dict line specification: pass through untouched.
                line = user_line

            fig.add_trace(
                go.Scatter(
                    x=df.index,
                    y=df[col],
                    name=dataset_name,
                    line=line,
                    **trace_kwargs,
                )
            )

        title = f"{col} {title_suffix}" if title_suffix else col

        fig.update_layout(
            title=title,
            xaxis_title="Time",
            yaxis_title=col,
            width=figsize[0],
            height=figsize[1],
            template=template,
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1,
            ),
            hovermode="x unified",
        )

        fig.show()
200
-
201
-
202
def visualize_ts_comparison(
    dataframes: Dict[str, pd.DataFrame],
    columns: Optional[List[str]] = None,
    title_suffix: str = "",
    figsize: tuple[int, int] = (1000, 500),
    template: str = "plotly_white",
    colors: Optional[Dict[str, str]] = None,
    show_mean: bool = False,
    **kwargs: Any,
) -> None:
    """Visualize time series with an optional mean overlay.

    First delegates to ``visualize_ts_plotly`` with the same arguments. If
    ``show_mean`` is True, a second figure per column is shown in which the
    individual datasets are drawn as thin, semi-transparent traces and the
    row-wise mean over all datasets (aligned on the index) is drawn as a
    dashed black line.

    Args:
        dataframes: Dictionary mapping dataset names to pandas DataFrames.
        columns: Column names to visualize. If None, the sorted union of the
            columns of all dataframes is used. Default: None.
        title_suffix: Text appended (after a space) to the figure titles.
            Default: "".
        figsize: Figure size as (width, height) in pixels. Default: (1000, 500).
        template: Plotly template. Default: 'plotly_white'.
        colors: Dictionary mapping dataset names to colors. In the mean
            figures, datasets missing from the mapping fall back to a built-in
            10-color palette. Default: None.
        show_mean: If True, additionally show the mean-overlay figures.
            Default: False.
        **kwargs: Additional keyword arguments forwarded to
            ``visualize_ts_plotly`` and to go.Scatter() for the per-dataset
            traces of the mean figures. In the mean figures, user-supplied
            ``mode`` / ``opacity`` replace the defaults and a ``line`` dict is
            merged over the default ``{"color": ..., "width": 1}``.

    Returns:
        None. Displays Plotly figures.

    Raises:
        ValueError: If dataframes is empty.
        ImportError: If plotly is not installed.

    Examples:
        >>> import pandas as pd
        >>> import numpy as np
        >>> from spotforecast2.preprocessing.time_series_visualization import visualize_ts_comparison
        >>>
        >>> np.random.seed(42)
        >>> dates1 = pd.date_range('2024-01-01', periods=100, freq='h')
        >>> dates2 = pd.date_range('2024-05-11', periods=100, freq='h')
        >>> df1 = pd.DataFrame({'temperature': np.random.normal(20, 5, 100)}, index=dates1)
        >>> df2 = pd.DataFrame({'temperature': np.random.normal(22, 5, 100)}, index=dates2)
        >>> visualize_ts_comparison(
        ...     {'Dataset1': df1, 'Dataset2': df2},
        ...     show_mean=True
        ... )
    """
    if go is None:
        raise ImportError(
            "plotly is required for this function. Install with: pip install plotly"
        )

    if not dataframes:
        raise ValueError("dataframes dictionary is empty")

    # First visualize normally
    visualize_ts_plotly(
        dataframes,
        columns=columns,
        title_suffix=title_suffix,
        figsize=figsize,
        template=template,
        colors=colors,
        **kwargs,
    )

    if not show_mean:
        return

    if columns is not None:
        columns_to_plot = columns
    else:
        columns_to_plot = sorted(
            set().union(*(df.columns for df in dataframes.values()))
        )

    # Resolve colors once (loop-invariant). The palette has ten entries so
    # that position-based default colors do not start repeating after five
    # datasets.
    default_colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]
    user_colors = colors or {}
    colors_dict = {
        name: user_colors.get(name, default_colors[i % len(default_colors)])
        for i, name in enumerate(dataframes)
    }

    # Merge user kwargs over the overlay defaults; passing both explicitly
    # would raise "got multiple values for keyword argument".
    overlay_kwargs = {"mode": "lines", "opacity": 0.5, **kwargs}
    user_line = overlay_kwargs.pop("line", None)

    for col in columns_to_plot:
        fig = go.Figure()

        # Individual datasets, de-emphasized
        for dataset_name, df in dataframes.items():
            base_line = {"color": colors_dict[dataset_name], "width": 1}
            if user_line is None:
                line = base_line
            elif isinstance(user_line, dict):
                line = {**base_line, **user_line}
            else:
                # Non-dict line specification: pass through untouched.
                line = user_line

            fig.add_trace(
                go.Scatter(
                    x=df.index,
                    y=df[col],
                    name=dataset_name,
                    line=line,
                    **overlay_kwargs,
                )
            )

        # Align all datasets on their index (one column per dataset) and
        # average across datasets for every index value.
        combined = pd.concat(
            {name: df[col] for name, df in dataframes.items()}, axis=1
        )
        mean_values = combined.mean(axis=1)

        fig.add_trace(
            go.Scatter(
                x=mean_values.index,
                y=mean_values,
                mode="lines",
                name="Mean",
                line=dict(color="black", width=3, dash="dash"),
            )
        )

        # Separate the suffix with a space, as the plain per-column titles do.
        title = f"{col} (with mean)"
        if title_suffix:
            title = f"{title} {title_suffix}"

        fig.update_layout(
            title=title,
            xaxis_title="Time",
            yaxis_title=col,
            width=figsize[0],
            height=figsize[1],
            template=template,
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1,
            ),
            hovermode="x unified",
        )

        fig.show()
358
-
359
-
360
def plot_zoomed_timeseries(
    data: pd.DataFrame,
    target: str,
    zoom: tuple[str, str],
    title: Optional[str] = None,
    figsize: tuple[int, int] = (8, 4),
    show: bool = True,
) -> plt.Figure:
    """Plot a time series with a zoomed-in focus area.

    Creates a two-panel plot:
    1. Top panel: Full time series with the zoom area highlighted.
    2. Bottom panel: Zoomed-in view of the specified range.

    Args:
        data: DataFrame containing the time series data. The zoom bounds are
            applied as a label-based slice on its index.
        target: Name of the column to plot.
        zoom: Tuple of (start, end) defining the zoom range.
        title: Optional title for the plot. If None, defaults to target name.
        figsize: Figure dimensions (width, height). Defaults to (8, 4).
        show: Whether to display the plot immediately. Defaults to True.

    Returns:
        plt.Figure: The matplotlib Figure object.

    Raises:
        KeyError: If ``target`` is not a column of ``data``.

    Examples:
        >>> import pandas as pd
        >>> import matplotlib.pyplot as plt
        >>> from spotforecast2.preprocessing.time_series_visualization import plot_zoomed_timeseries
        >>> dates = pd.date_range("2023-01-01", periods=100, freq="h")
        >>> df = pd.DataFrame({"value": range(100)}, index=dates)
        >>> fig = plot_zoomed_timeseries(
        ...     data=df,
        ...     target="value",
        ...     zoom=("2023-01-02 00:00", "2023-01-03 00:00"),
        ...     show=False
        ... )
        >>> plt.close(fig)
    """
    if title is None:
        title = target

    # Resolve the data before creating the figure so that a bad column name or
    # zoom range fails without leaving an orphaned figure in pyplot's registry.
    zoom_start, zoom_end = zoom[0], zoom[1]
    full_series = data[target]
    zoomed_series = full_series.loc[zoom_start:zoom_end]

    fig, axs = plt.subplots(
        2, 1, figsize=figsize, gridspec_kw={"height_ratios": [1, 2]}
    )

    try:
        # Top plot: Full series with highlighted zoom area
        full_series.plot(ax=axs[0], color="black", alpha=0.5)
        axs[0].axvspan(zoom_start, zoom_end, color="blue", alpha=0.7)
        axs[0].set_title(f"{title}")
        axs[0].set_xlabel("")
        axs[0].grid(True)

        # Bottom plot: Zoomed view
        zoomed_series.plot(ax=axs[1], color="blue")
        axs[1].set_title(f"Zoom: {zoom_start} to {zoom_end}", fontsize=10)
        axs[1].grid(True)

        # Lay out this figure explicitly rather than whichever is "current".
        fig.tight_layout()
    except Exception:
        # Do not leak a half-drawn figure when plotting fails.
        plt.close(fig)
        raise

    if show:
        plt.show()

    return fig
427
-
428
-
429
def plot_seasonality(
    data: pd.DataFrame,
    target: str,
    figsize: tuple[int, int] = (8, 5),
    show: bool = True,
    logscale: Union[bool, list[bool]] = False,
) -> plt.Figure:
    """Plot seasonal patterns (annual, weekly, daily) for a given target.

    Creates a 2x2 grid of plots:
    1. Distribution by month (boxplot + median).
    2. Distribution by week day (boxplot + median).
    3. Distribution by hour of day (boxplot + median).
    4. Mean target value by day of week and hour.

    Args:
        data: DataFrame containing the time series data. If its index is
            neither a DatetimeIndex nor a PeriodIndex it is converted with
            ``pd.to_datetime``. ``data`` itself is not modified.
        target: Name of the column to plot. Must not be 'month', 'week_day'
            or 'hour_day', which are used for the derived calendar features.
        figsize: Figure dimensions (width, height). Defaults to (8, 5).
        show: Whether to display the plot immediately. Defaults to True.
        logscale: Whether to use a log scale for the y-axis.
            Can be a single boolean (applies to all 4 plots) or a sequence of
            4 booleans (applies to each plot individually). Defaults to False.

    Returns:
        plt.Figure: The matplotlib Figure object.

    Raises:
        ValueError: If ``logscale`` is a sequence whose length is not 4, or if
            ``target`` collides with a derived feature name.
        KeyError: If ``target`` is not a column of ``data``.

    Examples:
        >>> import pandas as pd
        >>> import matplotlib.pyplot as plt
        >>> from spotforecast2.preprocessing.time_series_visualization import plot_seasonality
        >>> dates = pd.date_range("2023-01-01", periods=1000, freq="h")
        >>> df = pd.DataFrame({"value": range(1, 1001)}, index=dates)
        >>> fig = plot_seasonality(data=df, target="value", logscale=True, show=False)
        >>> plt.close(fig)
        >>> fig = plot_seasonality(
        ...     data=df,
        ...     target="value",
        ...     logscale=[True, False, False, False],
        ...     show=False
        ... )
        >>> plt.close(fig)
    """
    # Normalize logscale to one flag per panel
    if isinstance(logscale, bool):
        logscales = [logscale] * 4
    else:
        logscales = list(logscale)
        if len(logscales) != 4:
            raise ValueError("logscale list must have length 4.")
    # If different scales are used, the panels must not share the y-axis
    sharey = len(set(logscales)) == 1

    feature_names = ("month", "week_day", "hour_day")
    if target in feature_names:
        # The derived feature column would silently overwrite the target.
        raise ValueError(
            f"target must not be one of {feature_names}; got '{target}'."
        )

    # Copy only the needed column; this also raises KeyError for an unknown
    # target before any figure is created.
    df = data[[target]].copy()
    if not isinstance(df.index, (pd.DatetimeIndex, pd.PeriodIndex)):
        df.index = pd.to_datetime(df.index)

    # Calendar features, shifted to start at 1 (Monday = 1, first hour = 1)
    df["month"] = df.index.month
    df["week_day"] = df.index.day_of_week + 1
    df["hour_day"] = df.index.hour + 1

    fig, axs = plt.subplots(2, 2, figsize=figsize, sharex=False, sharey=sharey)
    axs = axs.ravel()

    try:
        # Panels 1-3: boxplot per group with the group medians on top
        panels = (
            ("month", f"{target} distribution by month"),
            ("week_day", f"{target} distribution by week day"),
            ("hour_day", f"{target} distribution by the hour of the day"),
        )
        for ax, (feature, panel_title) in zip(axs, panels):
            df.boxplot(
                column=target,
                by=feature,
                ax=ax,
                flierprops={"markersize": 3, "alpha": 0.3},
            )
            medians = df.groupby(feature)[target].median()
            # Boxes are drawn at positions 1..n_groups, not at the group
            # values; plot the medians at the same positions so they stay
            # aligned when some groups (e.g. months) are absent from the data.
            ax.plot(
                np.arange(1, len(medians) + 1),
                medians.to_numpy(),
                "o-",
                linewidth=0.8,
            )
            ax.set_ylabel(target)
            ax.set_title(panel_title, fontsize=9)

        # Panel 4: mean per (week day, hour), placed at its hour-of-week
        # position so the day ticks at multiples of 24 are correct even when
        # some (week day, hour) combinations are missing.
        mean_day_hour = df.groupby(["week_day", "hour_day"])[target].mean()
        week_days = mean_day_hour.index.get_level_values("week_day").to_numpy()
        hours = mean_day_hour.index.get_level_values("hour_day").to_numpy()
        axs[3].plot((week_days - 1) * 24 + (hours - 1), mean_day_hour.to_numpy())
        axs[3].set(
            title=f"Mean {target} during week",
            xticks=[i * 24 for i in range(7)],
            xticklabels=["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"],
            xlabel="Day and hour",
            ylabel=f"Number of {target}",
        )
        axs[3].grid(True)
        axs[3].title.set_size(10)

        # Apply logscale
        for ax, use_log in zip(axs, logscales):
            if use_log:
                ax.set_yscale("log")

        # Set after the boxplots, which put their own suptitle on the figure
        fig.suptitle(f"Seasonality plots: {target}", fontsize=12)
        fig.tight_layout()
    except Exception:
        # Do not leak a half-drawn figure when plotting fails.
        plt.close(fig)
        raise

    if show:
        plt.show()

    return fig
553
-
554
-
555
def _select_target_prediction(
    y_pred: Any, target: Any, position: int, n_targets: int
) -> Any:
    """Return the prediction values belonging to one target, or None if they cannot be matched."""
    if isinstance(y_pred, pd.DataFrame):
        if target in y_pred.columns:
            return y_pred[target].values
        if len(y_pred.columns) == n_targets:
            # No matching column name: fall back to positional alignment.
            return y_pred.iloc[:, position].values
        return None

    if isinstance(y_pred, np.ndarray):
        if y_pred.ndim > 1 and y_pred.shape[1] == n_targets:
            return y_pred[:, position]
        if y_pred.ndim == 1 and n_targets == 1:
            return y_pred
        return None

    if isinstance(y_pred, pd.Series):
        return y_pred.values if n_targets == 1 else None

    return None


def plot_predictions(
    y_true: Union[pd.Series, pd.DataFrame],
    predictions: Dict[str, Union[pd.Series, pd.DataFrame, np.ndarray]],
    slice_seq: Optional[slice] = None,
    title: str = "Predictions vs Actuals",
    figsize: Optional[tuple] = None,
    show: bool = True,
    nrows: Optional[int] = None,
    ncols: int = 1,
    sharex: bool = True,
) -> plt.Figure:
    """Plot actual values against one or more prediction series.

    Overlays predictions on top of the actual data, one subplot per target.
    Predictions are matched to the actuals by position (their own index, if
    any, is ignored) and drawn against the index of the sliced ``y_true``.

    Args:
        y_true: Series or DataFrame containing the actual target values.
        predictions: Dictionary where keys are labels (e.g., model names) and
            values are the corresponding predictions. Each prediction must
            have either the length of the full ``y_true`` (it is then sliced
            with ``slice_seq``) or the length of the sliced ``y_true``.
            Predictions that cannot be matched to a target (wrong number of
            columns, unsupported type) are skipped with a warning.
        slice_seq: Optional slice object to select a subset of the data.
            If None, the entire series is plotted.
            Example: `slice(-96, None)` to select the last 96 points.
        title: Title of the plot. Defaults to "Predictions vs Actuals".
        figsize: Tuple defining figure width and height. If None, it is
            ``(12, 4 * nrows)``.
        show: Whether to display the plot. Defaults to True.
        nrows: Number of subplot rows. Defaults to the number of targets and
            is increased if ``nrows * ncols`` cannot hold all targets.
        ncols: Number of subplot columns. Defaults to 1.
        sharex: Whether to share x-axis for subplots. Defaults to True.

    Returns:
        plt.Figure: The matplotlib Figure object containing the plot.

    Raises:
        ValueError: If ``y_true`` has no target columns, or if a prediction's
            length matches neither the full nor the sliced ``y_true``.

    Examples:
        >>> import matplotlib.pyplot as plt
        >>> import pandas as pd
        >>> import numpy as np
        >>> from spotforecast2.preprocessing.time_series_visualization import plot_predictions
        >>> dates = pd.date_range("2023-01-01", periods=10, freq="D")
        >>> y_true = pd.Series(np.arange(10), index=dates, name="Target")
        >>> predictions = {"Model A": y_true + 0.5}
        >>> fig = plot_predictions(y_true, predictions, show=False)
        >>> plt.close(fig)
    """
    import warnings

    if slice_seq is None:
        slice_seq = slice(None)

    y_plot = y_true.iloc[slice_seq]

    # Normalize to a DataFrame with one column per target
    if isinstance(y_plot, pd.Series):
        targets = [y_plot.name] if y_plot.name else ["Target"]
        y_plot = y_plot.to_frame(name=targets[0])
    else:
        targets = y_plot.columns.tolist()

    n_targets = len(targets)
    if n_targets == 0:
        raise ValueError("y_true must contain at least one target column.")

    full_len = len(y_true)
    sliced_len = len(y_plot)

    # Resolve and validate every prediction before creating the figure, so an
    # invalid input fails with a clear message and no orphaned figure.
    resolved: List[List[tuple]] = []
    for position, target in enumerate(targets):
        target_series = []
        for label, y_pred in predictions.items():
            values = _select_target_prediction(y_pred, target, position, n_targets)
            if values is None:
                warnings.warn(
                    f"Skipping prediction '{label}' for target '{target}': "
                    "its type or shape does not match the targets.",
                    UserWarning,
                    stacklevel=2,
                )
                continue

            # Full-length predictions are cut down with the same slice as the
            # actuals; already-sliced predictions are used as they are.
            if len(values) == full_len:
                values = values[slice_seq]
            if len(values) != sliced_len:
                raise ValueError(
                    f"Prediction '{label}' for target '{target}' has length "
                    f"{len(values)}; expected {sliced_len} (sliced y_true) or "
                    f"{full_len} (full y_true)."
                )
            target_series.append((label, values))
        resolved.append(target_series)

    # Setup layout
    if nrows is None:
        nrows = n_targets
    if nrows * ncols < n_targets:
        # Grid too small for all targets: grow the number of rows
        nrows = (n_targets + ncols - 1) // ncols
    if figsize is None:
        figsize = (12, 4 * nrows)

    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=figsize,
        sharex=sharex,
        squeeze=False,  # Ensure axes is always a 2D array
    )

    try:
        fig.suptitle(title)
        axes_flat = axes.flatten()

        for ax, target, target_series in zip(axes_flat, targets, resolved):
            target_actuals = y_plot[target]
            ax.plot(
                target_actuals.index,
                target_actuals.values,
                "x-",
                alpha=0.5,
                label="Actual",
                color="black",
                linewidth=2,
            )

            for label, values in target_series:
                ax.plot(target_actuals.index, values, "x-", label=label, alpha=0.8)

            ax.set_title(target)
            ax.legend()
            ax.grid(True, alpha=0.3)
            if not sharex:
                plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)

        # Hide unused subplots
        for unused_ax in axes_flat[n_targets:]:
            unused_ax.axis("off")

        if sharex:
            # With a shared x-axis only the bottom row shows tick labels
            for ax in axes[-1, :]:
                plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)

        fig.tight_layout()
    except Exception:
        # Do not leak a half-drawn figure when plotting fails.
        plt.close(fig)
        raise

    if show:
        plt.show()

    return fig
739
-
740
-
741
def plot_forecast(
    model: Any,
    X: pd.DataFrame,
    y: Union[pd.Series, pd.DataFrame],
    cv_results: Optional[Dict[str, Any]] = None,
    title: str = "Forecast",
    figsize: Optional[tuple] = None,
    show: bool = True,
    nrows: Optional[int] = None,
    ncols: int = 1,
    sharex: bool = True,
) -> plt.Figure:
    """Plot model forecast against actuals and display CV metrics.

    Calls ``model.predict(X)`` and hands the result to ``plot_predictions``
    under the label "Forecast". If ``cv_results`` contains the keys
    'test_neg_mean_absolute_error' and/or 'test_neg_root_mean_squared_error',
    the sign-flipped mean and standard deviation of those scores are appended
    to the title on a second line.

    Args:
        model: Fitted model exposing ``predict(X)``.
        X: Feature matrix (e.g., test set).
        y: Target series or DataFrame (e.g., test set).
        cv_results: Optional dictionary of cross-validation results, e.g. from
            `sklearn.model_selection.cross_validate()`. Score values may be
            arrays or plain sequences of numbers.
        title: Title of the plot. Defaults to "Forecast".
        figsize: Figure dimensions.
        show: Whether to display the plot. Defaults to True.
        nrows: Number of rows for subplots (multivariate).
        ncols: Number of columns for subplots (multivariate).
        sharex: Whether to share x-axis for subplots. Defaults to True.

    Returns:
        plt.Figure: The matplotlib Figure object.

    Examples:
        >>> import matplotlib.pyplot as plt
        >>> import pandas as pd
        >>> import numpy as np
        >>> from sklearn.linear_model import LinearRegression
        >>> from spotforecast2.preprocessing.time_series_visualization import plot_forecast
        >>> dates = pd.date_range("2023-01-01", periods=10, freq="D")
        >>> X = pd.DataFrame({"feat": np.arange(10)}, index=dates)
        >>> y = pd.Series(np.arange(10), index=dates)
        >>> model = LinearRegression().fit(X, y)
        >>> fig = plot_forecast(model, X, y, show=False)
        >>> plt.close(fig)
    """
    # 1. Generate predictions (the model is assumed to be fitted already)
    y_pred = model.predict(X)
    if not isinstance(y_pred, (pd.Series, pd.DataFrame, np.ndarray)):
        # plot_predictions only handles pandas / numpy containers.
        y_pred = np.asarray(y_pred)

    # 2. Extend the title with CV metrics if available
    if cv_results:
        metric_keys = (
            ("test_neg_mean_absolute_error", "MAE"),
            ("test_neg_root_mean_squared_error", "RMSE"),
        )
        metrics_str = []
        for key, label in metric_keys:
            if key not in cv_results:
                continue
            # Scores are stored negated; np.asarray makes the sign flip work
            # for plain lists as well as arrays.
            scores = -np.asarray(cv_results[key])
            metrics_str.append(
                f"{label}: {np.mean(scores):.3f} (±{np.std(scores):.3f})"
            )

        if metrics_str:
            title += "\n" + " | ".join(metrics_str)

    # 3. Plot
    return plot_predictions(
        y,
        {"Forecast": y_pred},
        slice_seq=None,
        title=title,
        figsize=figsize,
        show=show,
        nrows=nrows,
        ncols=ncols,
        sharex=sharex,
    )