tsam 2.3.9__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tsam/plot.py ADDED
@@ -0,0 +1,513 @@
+"""Plotting accessor for tsam aggregation results.
+
+Provides convenient plotting methods directly on the result object for
+validation and visualization of aggregation quality.
+
+Usage:
+    >>> result = tsam.aggregate(df, n_clusters=8)
+    >>> result.plot.compare()  # Compare original vs reconstructed
+    >>> result.plot.residuals()  # View reconstruction errors
+    >>> result.plot.cluster_representatives()
+    >>> result.plot.cluster_weights()
+    >>> result.plot.accuracy()
+
+For exploring raw data before aggregation, use plotly directly with
+``tsam.unstack_to_periods()`` to reshape data for heatmaps:
+    >>> import plotly.express as px
+    >>> unstacked = tsam.unstack_to_periods(df, period_duration=24)
+    >>> px.imshow(unstacked["Load"].values.T)
+
+Note: This module requires the 'plotly' optional dependency.
+Install with: pip install tsam[plot]
+"""
+
+from __future__ import annotations
+
+import warnings
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pandas as pd
+
+try:
+    import plotly.express as px
+    import plotly.graph_objects as go
+except ImportError as e:
+    raise ImportError(
+        "The tsam.plot module requires plotly. Install it with: pip install tsam[plot]"
+    ) from e
+
+if TYPE_CHECKING:
+    from tsam.result import AggregationResult
+
+
+def _validate_columns(
+    requested: list[str] | None,
+    available: list[str],
+    context: str = "data",
+) -> list[str]:
+    """Validate and filter column names, warning about invalid ones.
+
+    Parameters
+    ----------
+    requested : list[str] | None
+        Columns requested by user. If None, returns all available.
+    available : list[str]
+        Columns available in the data.
+    context : str
+        Description for error messages (e.g., "original data").
+
+    Returns
+    -------
+    list[str]
+        Valid columns to use.
+
+    Raises
+    ------
+    ValueError
+        If no valid columns remain after filtering.
+    """
+    if requested is None:
+        return available
+
+    valid = [c for c in requested if c in available]
+    invalid = [c for c in requested if c not in available]
+
+    if invalid:
+        warnings.warn(
+            f"Columns not found in {context} and will be ignored: {invalid}. "
+            f"Available columns: {available}",
+            UserWarning,
+            stacklevel=3,
+        )
+
+    if not valid:
+        raise ValueError(
+            f"None of the requested columns {requested} exist in {context}. "
+            f"Available columns: {available}"
+        )
+
+    return valid
+
+
+def _duration_curve_figure(
+    results: dict[str, pd.DataFrame],
+    columns: list[str],
+    title: str | None = None,
+) -> go.Figure:
+    """Create duration curve comparison figure (internal helper)."""
+    frames = []
+    for name, data in results.items():
+        for col in columns:
+            sorted_vals = data[col].sort_values(ascending=False).reset_index(drop=True)
+            frames.append(
+                pd.DataFrame(
+                    {
+                        "Hour": range(len(sorted_vals)),
+                        "Value": sorted_vals.values,
+                        "Method": name,
+                        "Column": col,
+                    }
+                )
+            )
+    long_df = pd.concat(frames, ignore_index=True)
+    return px.line(
+        long_df,
+        x="Hour",
+        y="Value",
+        color="Column",
+        line_dash="Method",
+        title=title or "Duration Curve Comparison",
+    )
+
+
+class ResultPlotAccessor:
+    """Plotting accessor for AggregationResult.
+
+    Provides convenient plotting methods directly on the result object.
+
+    Examples
+    --------
+    >>> result = tsam.aggregate(df, n_clusters=8)
+    >>> result.plot.compare()  # Compare original vs reconstructed
+    >>> result.plot.residuals()  # View reconstruction errors
+    >>> result.plot.cluster_representatives()
+    >>> result.plot.cluster_weights()
+    """
+
+    def __init__(self, result: AggregationResult):
+        self._result = result
+
+    def cluster_representatives(
+        self,
+        columns: list[str] | None = None,
+        title: str = "Cluster Representatives",
+    ) -> go.Figure:
+        """Plot all cluster representatives (typical periods).
+
+        Parameters
+        ----------
+        columns : list[str], optional
+            Columns to plot.
+        title : str, default "Cluster Representatives"
+            Plot title.
+
+        Returns
+        -------
+        go.Figure
+        """
+        typ = self._result.cluster_representatives
+        weights = self._result.cluster_weights
+
+        available_columns = [c for c in typ.columns if c not in ["cluster", "timestep"]]
+        columns = _validate_columns(
+            columns, available_columns, "cluster_representatives"
+        )
+
+        # Reset index to get period/timestep as columns
+        df = typ[columns].reset_index()
+        df.columns = pd.Index(["Period", "Timestep", *columns])
+
+        # Map period IDs to labels with weights
+        df["Period"] = df["Period"].map(lambda p: f"Period {p} (n={weights.get(p, 1)})")
+
+        long_df = df.melt(
+            id_vars=["Period", "Timestep"],
+            var_name="Column",
+            value_name="Value",
+        )
+
+        fig = px.line(
+            long_df,
+            x="Timestep",
+            y="Value",
+            color="Period",
+            facet_col="Column" if len(columns) > 1 else None,
+            title=title,
+        )
+
+        return fig
+
+    def cluster_weights(self, title: str = "Cluster Weights") -> go.Figure:
+        """Plot cluster weight distribution.
+
+        Parameters
+        ----------
+        title : str, default "Cluster Weights"
+            Plot title.
+
+        Returns
+        -------
+        go.Figure
+        """
+        weights = self._result.cluster_weights
+        df = pd.DataFrame(
+            {
+                "Period": [f"Period {p}" for p in weights],
+                "Count": list(weights.values()),
+            }
+        )
+
+        fig = px.bar(
+            df,
+            x="Period",
+            y="Count",
+            title=title,
+            text="Count",
+            color="Count",
+            color_continuous_scale="Viridis",
+        )
+        fig.update_traces(textposition="auto")
+        fig.update_layout(showlegend=False)
+
+        return fig
+
+    def accuracy(self, title: str = "Accuracy Metrics") -> go.Figure:
+        """Plot accuracy metrics by column.
+
+        Parameters
+        ----------
+        title : str, default "Accuracy Metrics"
+            Plot title.
+
+        Returns
+        -------
+        go.Figure
+        """
+        acc = self._result.accuracy
+        columns = list(acc.rmse.index)
+
+        records = []
+        for col in columns:
+            records.append({"Column": col, "Metric": "RMSE", "Value": acc.rmse[col]})
+            records.append({"Column": col, "Metric": "MAE", "Value": acc.mae[col]})
+            records.append(
+                {
+                    "Column": col,
+                    "Metric": "RMSE (Duration)",
+                    "Value": acc.rmse_duration[col],
+                }
+            )
+
+        df = pd.DataFrame(records)
+
+        fig = px.bar(
+            df,
+            x="Column",
+            y="Value",
+            color="Metric",
+            barmode="group",
+            title=title,
+        )
+
+        return fig
+
+    def segment_durations(self, title: str = "Segment Durations") -> go.Figure:
+        """Plot segment durations (if segmentation was used).
+
+        Parameters
+        ----------
+        title : str, default "Segment Durations"
+            Plot title.
+
+        Returns
+        -------
+        go.Figure
+
+        Raises
+        ------
+        ValueError
+            If no segmentation was used.
+        """
+        if self._result.segment_durations is None:
+            raise ValueError("No segmentation was used in this aggregation")
+
+        # segment_durations is tuple[tuple[int, ...], ...] - one tuple per period
+        # Average durations across all typical periods for the bar chart
+        durations = self._result.segment_durations
+
+        # Validate uniform structure across periods
+        segment_counts = {len(period) for period in durations}
+        if len(segment_counts) != 1:
+            raise ValueError(
+                f"Inconsistent segment counts across periods: {segment_counts}. "
+                "Cannot compute average durations."
+            )
+
+        n_segments = len(durations[0])
+        avg_durations = [
+            sum(period[s] for period in durations) / len(durations)
+            for s in range(n_segments)
+        ]
+
+        df = pd.DataFrame(
+            {
+                "Segment": [f"Segment {s}" for s in range(n_segments)],
+                "Duration": avg_durations,
+            }
+        )
+
+        fig = px.bar(
+            df,
+            x="Segment",
+            y="Duration",
+            title=title,
+            text="Duration",
+            color="Duration",
+            color_continuous_scale="Viridis",
+        )
+        fig.update_traces(texttemplate="%{text:.1f}", textposition="auto")
+        fig.update_layout(showlegend=False, yaxis_title="Duration (timesteps)")
+
+        return fig
+
+    def compare(
+        self,
+        columns: list[str] | None = None,
+        mode: str = "overlay",
+        title: str | None = None,
+    ) -> go.Figure:
+        """Compare original vs reconstructed time series.
+
+        Parameters
+        ----------
+        columns : list[str], optional
+            Columns to compare. If None, compares all columns.
+        mode : str, default "overlay"
+            Comparison mode:
+            - "overlay": Both series on same axes
+            - "side_by_side": Separate subplots
+            - "duration_curve": Compare sorted values
+        title : str, optional
+            Plot title.
+
+        Returns
+        -------
+        go.Figure
+
+        Examples
+        --------
+        >>> result.plot.compare()  # Compare all columns
+        >>> result.plot.compare(columns=["Load"])  # Compare specific column
+        >>> result.plot.compare(mode="duration_curve")
+        """
+        orig = self._result.original
+        recon = self._result.reconstructed
+
+        columns = _validate_columns(columns, list(orig.columns), "original data")
+
+        if mode == "duration_curve":
+            return _duration_curve_figure(
+                {"Original": orig, "Reconstructed": recon},
+                columns=columns,
+                title=title,
+            )
+
+        elif mode in ("overlay", "side_by_side"):
+            # Build long-form data with Source (Original/Reconstructed) and Column
+            orig_df = orig[columns].copy()
+            orig_df["Source"] = "Original"
+            recon_df = recon[columns].copy()
+            recon_df["Source"] = "Reconstructed"
+
+            combined = pd.concat([orig_df, recon_df])
+            combined.index.name = "Time"
+            long_df = combined.reset_index().melt(
+                id_vars=["Time", "Source"],
+                var_name="Column",
+                value_name="Value",
+            )
+
+            if mode == "overlay":
+                # Color by Column, dash by Source (Original/Reconstructed)
+                fig = px.line(
+                    long_df,
+                    x="Time",
+                    y="Value",
+                    color="Column",
+                    line_dash="Source",
+                    title=title or "Original vs Reconstructed",
+                )
+            else:  # side_by_side
+                fig = px.line(
+                    long_df,
+                    x="Time",
+                    y="Value",
+                    color="Column",
+                    facet_row="Source",
+                    title=title or "Original vs Reconstructed",
+                )
+                fig.update_layout(height=600)
+
+            return fig
+
+        else:
+            raise ValueError(
+                f"Unknown mode: {mode}. Use 'overlay', 'side_by_side', or 'duration_curve'."
+            )
+
+    def residuals(
+        self,
+        columns: list[str] | None = None,
+        mode: str = "time_series",
+        title: str | None = None,
+    ) -> go.Figure:
+        """Plot residuals (original - reconstructed).
+
+        Parameters
+        ----------
+        columns : list[str], optional
+            Columns to plot. If None, plots all.
+        mode : str, default "time_series"
+            Display mode:
+            - "time_series": Residuals over time
+            - "histogram": Distribution of residuals
+            - "by_period": Mean absolute error per period (bar chart)
+            - "by_timestep": Mean absolute error by timestep within period
+        title : str, optional
+            Plot title.
+
+        Returns
+        -------
+        go.Figure
+
+        Examples
+        --------
+        >>> result.plot.residuals()  # Time series of residuals
+        >>> result.plot.residuals(mode="histogram")  # Error distribution
+        >>> result.plot.residuals(mode="by_period")  # Which periods have highest error
+        >>> result.plot.residuals(mode="by_timestep")  # Error pattern within day
+        """
+        resid = self._result.residuals
+        columns = _validate_columns(columns, list(resid.columns), "residuals")
+
+        if mode == "time_series":
+            df_plot = resid[columns].copy()
+            df_plot.index.name = "Time"
+            long_df = df_plot.reset_index().melt(
+                id_vars=["Time"],
+                var_name="Column",
+                value_name="Residual",
+            )
+            fig = px.line(
+                long_df,
+                x="Time",
+                y="Residual",
+                color="Column",
+                title=title or "Residuals Over Time",
+            )
+            fig.add_hline(y=0, line_dash="dash", line_color="gray")
+            return fig
+
+        elif mode == "histogram":
+            long_df = resid[columns].melt(var_name="Column", value_name="Residual")
+            fig = px.histogram(
+                long_df,
+                x="Residual",
+                color="Column",
+                barmode="overlay",
+                opacity=0.7,
+                title=title or "Residual Distribution",
+            )
+            fig.add_vline(x=0, line_dash="dash", line_color="red")
+            return fig
+
+        elif mode == "by_period":
+            n_timesteps = self._result.n_timesteps_per_period
+            abs_resid = resid[columns].abs().copy()
+            abs_resid["Period"] = np.arange(len(abs_resid)) // n_timesteps
+
+            df = abs_resid.groupby("Period")[columns].mean().reset_index()
+            long_df = df.melt(id_vars="Period", var_name="Column", value_name="MAE")
+
+            fig = px.bar(
+                long_df,
+                x="Period",
+                y="MAE",
+                color="Column",
+                barmode="group",
+                title=title or "Mean Absolute Error by Period",
+            )
+            return fig
+
+        elif mode == "by_timestep":
+            n_timesteps = self._result.n_timesteps_per_period
+            abs_resid = resid[columns].abs().copy()
+            abs_resid["Timestep"] = np.arange(len(abs_resid)) % n_timesteps
+
+            df = abs_resid.groupby("Timestep")[columns].mean().reset_index()
+            long_df = df.melt(id_vars="Timestep", var_name="Column", value_name="MAE")
+
+            fig = px.line(
+                long_df,
+                x="Timestep",
+                y="MAE",
+                color="Column",
+                title=title or "Mean Absolute Error by Timestep",
+            )
+            return fig
+
+        else:
+            raise ValueError(
+                f"Unknown mode: {mode}. Use 'time_series', 'histogram', 'by_period', or 'by_timestep'."
+            )
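
For orientation, a minimal usage sketch of the new accessor, based only on the docstrings in the module above. The call tsam.aggregate(df, n_clusters=8) and the result.plot methods are taken from those docstrings; the sample DataFrame, its "Load" column, and the .show() calls are illustrative assumptions, not part of the release.

import numpy as np
import pandas as pd
import tsam

# Illustrative input: one year of hourly data (any DataFrame with one column per series)
index = pd.date_range("2023-01-01", periods=8760, freq="h")
df = pd.DataFrame({"Load": np.random.rand(8760)}, index=index)

# Aggregate, then validate the result with the plotting accessor added in this release
result = tsam.aggregate(df, n_clusters=8)
result.plot.compare(columns=["Load"], mode="duration_curve").show()
result.plot.residuals(mode="by_timestep").show()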
tsam/py.typed ADDED
File without changes (empty py.typed marker file)