tsam 2.3.9-py3-none-any.whl → 3.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tsam/__init__.py +79 -0
- tsam/api.py +602 -0
- tsam/config.py +852 -0
- tsam/exceptions.py +17 -0
- tsam/hyperparametertuning.py +289 -245
- tsam/periodAggregation.py +140 -141
- tsam/plot.py +513 -0
- tsam/py.typed +0 -0
- tsam/representations.py +177 -167
- tsam/result.py +397 -0
- tsam/timeseriesaggregation.py +1446 -1361
- tsam/tuning.py +1038 -0
- tsam/utils/durationRepresentation.py +229 -223
- tsam/utils/k_maxoids.py +138 -145
- tsam/utils/k_medoids_contiguity.py +139 -140
- tsam/utils/k_medoids_exact.py +232 -239
- tsam/utils/segmentation.py +232 -118
- {tsam-2.3.9.dist-info → tsam-3.0.0.dist-info}/METADATA +124 -81
- tsam-3.0.0.dist-info/RECORD +23 -0
- {tsam-2.3.9.dist-info → tsam-3.0.0.dist-info}/WHEEL +1 -1
- {tsam-2.3.9.dist-info → tsam-3.0.0.dist-info}/licenses/LICENSE.txt +21 -21
- tsam-2.3.9.dist-info/RECORD +0 -16
- {tsam-2.3.9.dist-info → tsam-3.0.0.dist-info}/top_level.txt +0 -0
tsam/plot.py
ADDED
@@ -0,0 +1,513 @@
"""Plotting accessor for tsam aggregation results.

Provides convenient plotting methods directly on the result object for
validation and visualization of aggregation quality.

Usage:
    >>> result = tsam.aggregate(df, n_clusters=8)
    >>> result.plot.compare()  # Compare original vs reconstructed
    >>> result.plot.residuals()  # View reconstruction errors
    >>> result.plot.cluster_representatives()
    >>> result.plot.cluster_weights()
    >>> result.plot.accuracy()

For exploring raw data before aggregation, use plotly directly with
``tsam.unstack_to_periods()`` to reshape data for heatmaps:
    >>> import plotly.express as px
    >>> unstacked = tsam.unstack_to_periods(df, period_duration=24)
    >>> px.imshow(unstacked["Load"].values.T)

Note: This module requires the 'plotly' optional dependency.
Install with: pip install tsam[plot]
"""

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd

try:
    import plotly.express as px
    import plotly.graph_objects as go
except ImportError as e:
    raise ImportError(
        "The tsam.plot module requires plotly. Install it with: pip install tsam[plot]"
    ) from e

if TYPE_CHECKING:
    from tsam.result import AggregationResult


def _validate_columns(
    requested: list[str] | None,
    available: list[str],
    context: str = "data",
) -> list[str]:
    """Validate and filter column names, warning about invalid ones.

    Parameters
    ----------
    requested : list[str] | None
        Columns requested by user. If None, returns all available.
    available : list[str]
        Columns available in the data.
    context : str
        Description for error messages (e.g., "original data").

    Returns
    -------
    list[str]
        Valid columns to use.

    Raises
    ------
    ValueError
        If no valid columns remain after filtering.
    """
    if requested is None:
        return available

    valid = [c for c in requested if c in available]
    invalid = [c for c in requested if c not in available]

    if invalid:
        warnings.warn(
            f"Columns not found in {context} and will be ignored: {invalid}. "
            f"Available columns: {available}",
            UserWarning,
            stacklevel=3,
        )

    if not valid:
        raise ValueError(
            f"None of the requested columns {requested} exist in {context}. "
            f"Available columns: {available}"
        )

    return valid


def _duration_curve_figure(
    results: dict[str, pd.DataFrame],
    columns: list[str],
    title: str | None = None,
) -> go.Figure:
    """Create duration curve comparison figure (internal helper)."""
    frames = []
    for name, data in results.items():
        for col in columns:
            sorted_vals = data[col].sort_values(ascending=False).reset_index(drop=True)
            frames.append(
                pd.DataFrame(
                    {
                        "Hour": range(len(sorted_vals)),
                        "Value": sorted_vals.values,
                        "Method": name,
                        "Column": col,
                    }
                )
            )
    long_df = pd.concat(frames, ignore_index=True)
    return px.line(
        long_df,
        x="Hour",
        y="Value",
        color="Column",
        line_dash="Method",
        title=title or "Duration Curve Comparison",
    )


class ResultPlotAccessor:
    """Plotting accessor for AggregationResult.

    Provides convenient plotting methods directly on the result object.

    Examples
    --------
    >>> result = tsam.aggregate(df, n_clusters=8)
    >>> result.plot.compare()  # Compare original vs reconstructed
    >>> result.plot.residuals()  # View reconstruction errors
    >>> result.plot.cluster_representatives()
    >>> result.plot.cluster_weights()
    """

    def __init__(self, result: AggregationResult):
        self._result = result

    def cluster_representatives(
        self,
        columns: list[str] | None = None,
        title: str = "Cluster Representatives",
    ) -> go.Figure:
        """Plot all cluster representatives (typical periods).

        Parameters
        ----------
        columns : list[str], optional
            Columns to plot.
        title : str, default "Cluster Representatives"
            Plot title.

        Returns
        -------
        go.Figure
        """
        typ = self._result.cluster_representatives
        weights = self._result.cluster_weights

        available_columns = [c for c in typ.columns if c not in ["cluster", "timestep"]]
        columns = _validate_columns(
            columns, available_columns, "cluster_representatives"
        )

        # Reset index to get period/timestep as columns
        df = typ[columns].reset_index()
        df.columns = pd.Index(["Period", "Timestep", *columns])

        # Map period IDs to labels with weights
        df["Period"] = df["Period"].map(lambda p: f"Period {p} (n={weights.get(p, 1)})")

        long_df = df.melt(
            id_vars=["Period", "Timestep"],
            var_name="Column",
            value_name="Value",
        )

        fig = px.line(
            long_df,
            x="Timestep",
            y="Value",
            color="Period",
            facet_col="Column" if len(columns) > 1 else None,
            title=title,
        )

        return fig

    def cluster_weights(self, title: str = "Cluster Weights") -> go.Figure:
        """Plot cluster weight distribution.

        Parameters
        ----------
        title : str, default "Cluster Weights"
            Plot title.

        Returns
        -------
        go.Figure
        """
        weights = self._result.cluster_weights
        df = pd.DataFrame(
            {
                "Period": [f"Period {p}" for p in weights],
                "Count": list(weights.values()),
            }
        )

        fig = px.bar(
            df,
            x="Period",
            y="Count",
            title=title,
            text="Count",
            color="Count",
            color_continuous_scale="Viridis",
        )
        fig.update_traces(textposition="auto")
        fig.update_layout(showlegend=False)

        return fig

    def accuracy(self, title: str = "Accuracy Metrics") -> go.Figure:
        """Plot accuracy metrics by column.

        Parameters
        ----------
        title : str, default "Accuracy Metrics"
            Plot title.

        Returns
        -------
        go.Figure
        """
        acc = self._result.accuracy
        columns = list(acc.rmse.index)

        records = []
        for col in columns:
            records.append({"Column": col, "Metric": "RMSE", "Value": acc.rmse[col]})
            records.append({"Column": col, "Metric": "MAE", "Value": acc.mae[col]})
            records.append(
                {
                    "Column": col,
                    "Metric": "RMSE (Duration)",
                    "Value": acc.rmse_duration[col],
                }
            )

        df = pd.DataFrame(records)

        fig = px.bar(
            df,
            x="Column",
            y="Value",
            color="Metric",
            barmode="group",
            title=title,
        )

        return fig

    def segment_durations(self, title: str = "Segment Durations") -> go.Figure:
        """Plot segment durations (if segmentation was used).

        Parameters
        ----------
        title : str, default "Segment Durations"
            Plot title.

        Returns
        -------
        go.Figure

        Raises
        ------
        ValueError
            If no segmentation was used.
        """
        if self._result.segment_durations is None:
            raise ValueError("No segmentation was used in this aggregation")

        # segment_durations is tuple[tuple[int, ...], ...] - one tuple per period
        # Average durations across all typical periods for the bar chart
        durations = self._result.segment_durations

        # Validate uniform structure across periods
        segment_counts = {len(period) for period in durations}
        if len(segment_counts) != 1:
            raise ValueError(
                f"Inconsistent segment counts across periods: {segment_counts}. "
                "Cannot compute average durations."
            )

        n_segments = len(durations[0])
        avg_durations = [
            sum(period[s] for period in durations) / len(durations)
            for s in range(n_segments)
        ]

        df = pd.DataFrame(
            {
                "Segment": [f"Segment {s}" for s in range(n_segments)],
                "Duration": avg_durations,
            }
        )

        fig = px.bar(
            df,
            x="Segment",
            y="Duration",
            title=title,
            text="Duration",
            color="Duration",
            color_continuous_scale="Viridis",
        )
        fig.update_traces(texttemplate="%{text:.1f}", textposition="auto")
        fig.update_layout(showlegend=False, yaxis_title="Duration (timesteps)")

        return fig

    def compare(
        self,
        columns: list[str] | None = None,
        mode: str = "overlay",
        title: str | None = None,
    ) -> go.Figure:
        """Compare original vs reconstructed time series.

        Parameters
        ----------
        columns : list[str], optional
            Columns to compare. If None, compares all columns.
        mode : str, default "overlay"
            Comparison mode:

            - "overlay": Both series on same axes
            - "side_by_side": Separate subplots
            - "duration_curve": Compare sorted values
        title : str, optional
            Plot title.

        Returns
        -------
        go.Figure

        Examples
        --------
        >>> result.plot.compare()  # Compare all columns
        >>> result.plot.compare(columns=["Load"])  # Compare specific column
        >>> result.plot.compare(mode="duration_curve")
        """
        orig = self._result.original
        recon = self._result.reconstructed

        columns = _validate_columns(columns, list(orig.columns), "original data")

        if mode == "duration_curve":
            return _duration_curve_figure(
                {"Original": orig, "Reconstructed": recon},
                columns=columns,
                title=title,
            )

        elif mode in ("overlay", "side_by_side"):
            # Build long-form data with Source (Original/Reconstructed) and Column
            orig_df = orig[columns].copy()
            orig_df["Source"] = "Original"
            recon_df = recon[columns].copy()
            recon_df["Source"] = "Reconstructed"

            combined = pd.concat([orig_df, recon_df])
            combined.index.name = "Time"
            long_df = combined.reset_index().melt(
                id_vars=["Time", "Source"],
                var_name="Column",
                value_name="Value",
            )

            if mode == "overlay":
                # Color by Column, dash by Source (Original/Reconstructed)
                fig = px.line(
                    long_df,
                    x="Time",
                    y="Value",
                    color="Column",
                    line_dash="Source",
                    title=title or "Original vs Reconstructed",
                )
            else:  # side_by_side
                fig = px.line(
                    long_df,
                    x="Time",
                    y="Value",
                    color="Column",
                    facet_row="Source",
                    title=title or "Original vs Reconstructed",
                )
                fig.update_layout(height=600)

            return fig

        else:
            raise ValueError(
                f"Unknown mode: {mode}. Use 'overlay', 'side_by_side', or 'duration_curve'."
            )

    def residuals(
        self,
        columns: list[str] | None = None,
        mode: str = "time_series",
        title: str | None = None,
    ) -> go.Figure:
        """Plot residuals (original - reconstructed).

        Parameters
        ----------
        columns : list[str], optional
            Columns to plot. If None, plots all.
        mode : str, default "time_series"
            Display mode:
            - "time_series": Residuals over time
            - "histogram": Distribution of residuals
            - "by_period": Mean absolute error per period (bar chart)
            - "by_timestep": Mean absolute error by timestep within period
        title : str, optional
            Plot title.

        Returns
        -------
        go.Figure

        Examples
        --------
        >>> result.plot.residuals()  # Time series of residuals
        >>> result.plot.residuals(mode="histogram")  # Error distribution
        >>> result.plot.residuals(mode="by_period")  # Which periods have highest error
        >>> result.plot.residuals(mode="by_timestep")  # Error pattern within day
        """
        resid = self._result.residuals
        columns = _validate_columns(columns, list(resid.columns), "residuals")

        if mode == "time_series":
            df_plot = resid[columns].copy()
            df_plot.index.name = "Time"
            long_df = df_plot.reset_index().melt(
                id_vars=["Time"],
                var_name="Column",
                value_name="Residual",
            )
            fig = px.line(
                long_df,
                x="Time",
                y="Residual",
                color="Column",
                title=title or "Residuals Over Time",
            )
            fig.add_hline(y=0, line_dash="dash", line_color="gray")
            return fig

        elif mode == "histogram":
            long_df = resid[columns].melt(var_name="Column", value_name="Residual")
            fig = px.histogram(
                long_df,
                x="Residual",
                color="Column",
                barmode="overlay",
                opacity=0.7,
                title=title or "Residual Distribution",
            )
            fig.add_vline(x=0, line_dash="dash", line_color="red")
            return fig

        elif mode == "by_period":
            n_timesteps = self._result.n_timesteps_per_period
            abs_resid = resid[columns].abs().copy()
            abs_resid["Period"] = np.arange(len(abs_resid)) // n_timesteps

            df = abs_resid.groupby("Period")[columns].mean().reset_index()
            long_df = df.melt(id_vars="Period", var_name="Column", value_name="MAE")

            fig = px.bar(
                long_df,
                x="Period",
                y="MAE",
                color="Column",
                barmode="group",
                title=title or "Mean Absolute Error by Period",
            )
            return fig

        elif mode == "by_timestep":
            n_timesteps = self._result.n_timesteps_per_period
            abs_resid = resid[columns].abs().copy()
            abs_resid["Timestep"] = np.arange(len(abs_resid)) % n_timesteps

            df = abs_resid.groupby("Timestep")[columns].mean().reset_index()
            long_df = df.melt(id_vars="Timestep", var_name="Column", value_name="MAE")

            fig = px.line(
                long_df,
                x="Timestep",
                y="MAE",
                color="Column",
                title=title or "Mean Absolute Error by Timestep",
            )
            return fig

        else:
            raise ValueError(
                f"Unknown mode: {mode}. Use 'time_series', 'histogram', 'by_period', or 'by_timestep'."
            )
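For orientation, here is a minimal end-to-end sketch of how the new accessor is meant to be used. It assumes the `tsam.aggregate` entry point and the `result.plot` accessor behave as the docstrings above describe (their exact signatures live in tsam/api.py and tsam/result.py, which are not shown here), and the hourly demand data is invented for illustration:

    import numpy as np
    import pandas as pd
    import tsam  # assumes the plotting extra is installed: pip install tsam[plot]

    # Hypothetical one-year hourly profile purely for illustration
    index = pd.date_range("2030-01-01", periods=8760, freq="h")
    df = pd.DataFrame(
        {"Load": np.random.default_rng(0).random(8760) * 100},
        index=index,
    )

    # Assumed API from the module docstring: aggregate into 8 typical periods
    result = tsam.aggregate(df, n_clusters=8)

    # Methods added in this diff
    result.plot.compare(mode="duration_curve").show()   # sorted-value comparison
    result.plot.residuals(mode="by_timestep").show()    # error pattern within a period
    result.plot.cluster_weights().show()                 # how many periods each cluster represents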
tsam/py.typed
ADDED
File without changes