flixopt 3.0.1__py3-none-any.whl → 6.0.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. flixopt/__init__.py +57 -49
  2. flixopt/carrier.py +159 -0
  3. flixopt/clustering/__init__.py +51 -0
  4. flixopt/clustering/base.py +1746 -0
  5. flixopt/clustering/intercluster_helpers.py +201 -0
  6. flixopt/color_processing.py +372 -0
  7. flixopt/comparison.py +819 -0
  8. flixopt/components.py +848 -270
  9. flixopt/config.py +853 -496
  10. flixopt/core.py +111 -98
  11. flixopt/effects.py +294 -284
  12. flixopt/elements.py +484 -223
  13. flixopt/features.py +220 -118
  14. flixopt/flow_system.py +2026 -389
  15. flixopt/interface.py +504 -286
  16. flixopt/io.py +1718 -55
  17. flixopt/linear_converters.py +291 -230
  18. flixopt/modeling.py +304 -181
  19. flixopt/network_app.py +2 -1
  20. flixopt/optimization.py +788 -0
  21. flixopt/optimize_accessor.py +373 -0
  22. flixopt/plot_result.py +143 -0
  23. flixopt/plotting.py +1177 -1034
  24. flixopt/results.py +1331 -372
  25. flixopt/solvers.py +12 -4
  26. flixopt/statistics_accessor.py +2412 -0
  27. flixopt/stats_accessor.py +75 -0
  28. flixopt/structure.py +954 -120
  29. flixopt/topology_accessor.py +676 -0
  30. flixopt/transform_accessor.py +2277 -0
  31. flixopt/types.py +120 -0
  32. flixopt-6.0.0rc7.dist-info/METADATA +290 -0
  33. flixopt-6.0.0rc7.dist-info/RECORD +36 -0
  34. {flixopt-3.0.1.dist-info → flixopt-6.0.0rc7.dist-info}/WHEEL +1 -1
  35. flixopt/aggregation.py +0 -382
  36. flixopt/calculation.py +0 -672
  37. flixopt/commons.py +0 -51
  38. flixopt/utils.py +0 -86
  39. flixopt-3.0.1.dist-info/METADATA +0 -209
  40. flixopt-3.0.1.dist-info/RECORD +0 -26
  41. {flixopt-3.0.1.dist-info → flixopt-6.0.0rc7.dist-info}/licenses/LICENSE +0 -0
  42. {flixopt-3.0.1.dist-info → flixopt-6.0.0rc7.dist-info}/top_level.txt +0 -0
flixopt/clustering/base.py
@@ -0,0 +1,1746 @@
+"""
+Clustering classes for time series aggregation.
+
+This module provides wrapper classes around tsam's clustering functionality:
+- `ClusteringResults`: Collection of tsam ClusteringResult objects for multi-dim (period, scenario) data
+- `Clustering`: Top-level class stored on FlowSystem after clustering
+"""
+
+from __future__ import annotations
+
+import functools
+import json
+from collections import Counter
+from typing import TYPE_CHECKING, Any
+
+import numpy as np
+import pandas as pd
+import xarray as xr
+
+if TYPE_CHECKING:
+    from pathlib import Path
+
+    from tsam import AggregationResult
+    from tsam import ClusteringResult as TsamClusteringResult
+
+    from ..color_processing import ColorType
+    from ..plot_result import PlotResult
+    from ..statistics_accessor import SelectType
+
+from ..statistics_accessor import _build_color_kwargs
+
+
+def _apply_slot_defaults(plotly_kwargs: dict, defaults: dict[str, str | None]) -> None:
+    """Apply default slot assignments to plotly kwargs.
+
+    Args:
+        plotly_kwargs: The kwargs dict to update (modified in place).
+        defaults: Default slot assignments. None values block slots.
+    """
+    for slot, value in defaults.items():
+        plotly_kwargs.setdefault(slot, value)
+
+
+def _select_dims(da: xr.DataArray, period: Any = None, scenario: Any = None) -> xr.DataArray:
+    """Select from DataArray by period/scenario if those dimensions exist."""
+    if 'period' in da.dims and period is not None:
+        da = da.sel(period=period)
+    if 'scenario' in da.dims and scenario is not None:
+        da = da.sel(scenario=scenario)
+    return da
+
+
+def combine_slices(
+    slices: dict[tuple, np.ndarray],
+    extra_dims: list[str],
+    dim_coords: dict[str, list],
+    output_dim: str,
+    output_coord: Any,
+    attrs: dict | None = None,
+) -> xr.DataArray:
+    """Combine {(dim_values): 1D_array} dict into a DataArray.
+
+    This utility simplifies the common pattern of iterating over extra dimensions
+    (like period, scenario), processing each slice, and combining results.
+
+    Args:
+        slices: Dict mapping dimension value tuples to 1D numpy arrays.
+            Keys are tuples like ('period1', 'scenario1') matching extra_dims order.
+        extra_dims: Dimension names in order (e.g., ['period', 'scenario']).
+        dim_coords: Dict mapping dimension names to coordinate values.
+        output_dim: Name of the output dimension (typically 'time').
+        output_coord: Coordinate values for output dimension.
+        attrs: Optional DataArray attributes.
+
+    Returns:
+        DataArray with dims [output_dim, *extra_dims].
+
+    Raises:
+        ValueError: If slices is empty.
+        KeyError: If a required key is missing from slices.
+
+    Example:
+        >>> slices = {
+        ...     ('P1', 'base'): np.array([1, 2, 3]),
+        ...     ('P1', 'high'): np.array([4, 5, 6]),
+        ...     ('P2', 'base'): np.array([7, 8, 9]),
+        ...     ('P2', 'high'): np.array([10, 11, 12]),
+        ... }
+        >>> result = combine_slices(
+        ...     slices,
+        ...     extra_dims=['period', 'scenario'],
+        ...     dim_coords={'period': ['P1', 'P2'], 'scenario': ['base', 'high']},
+        ...     output_dim='time',
+        ...     output_coord=[0, 1, 2],
+        ... )
+        >>> result.dims
+        ('time', 'period', 'scenario')
+    """
+    if not slices:
+        raise ValueError('slices cannot be empty')
+
+    first = next(iter(slices.values()))
+    n_output = len(first)
+    shape = [n_output] + [len(dim_coords[d]) for d in extra_dims]
+    data = np.empty(shape, dtype=first.dtype)
+
+    for combo in np.ndindex(*shape[1:]):
+        key = tuple(dim_coords[d][i] for d, i in zip(extra_dims, combo, strict=True))
+        try:
+            data[(slice(None),) + combo] = slices[key]
+        except KeyError:
+            raise KeyError(f'Missing slice for key {key} (extra_dims={extra_dims})') from None
+
+    return xr.DataArray(
+        data,
+        dims=[output_dim] + extra_dims,
+        coords={output_dim: output_coord, **dim_coords},
+        attrs=attrs or {},
+    )
+
+
+def _cluster_occurrences(cr: TsamClusteringResult) -> np.ndarray:
+    """Compute cluster occurrences from ClusteringResult."""
+    counts = Counter(cr.cluster_assignments)
+    return np.array([counts.get(i, 0) for i in range(cr.n_clusters)])
+
+
+def _build_timestep_mapping(cr: TsamClusteringResult, n_timesteps: int) -> np.ndarray:
+    """Build mapping from original timesteps to representative timestep indices.
+
+    For segmented systems, the mapping uses segment_assignments from tsam to map
+    each original timestep position to its corresponding segment index.
+    """
+    timesteps_per_cluster = cr.n_timesteps_per_period
+    # For segmented systems, representative time dimension has n_segments entries.
+    # For non-segmented, it has timesteps_per_cluster entries.
+    n_segments = cr.n_segments
+    is_segmented = n_segments is not None
+    time_dim_size = n_segments if is_segmented else timesteps_per_cluster
+
+    # For segmented systems, tsam provides segment_assignments, which maps
+    # each position within a period to its segment index.
+    segment_assignments = cr.segment_assignments if is_segmented else None
+
+    mapping = np.zeros(n_timesteps, dtype=np.int32)
+    for period_idx, cluster_id in enumerate(cr.cluster_assignments):
+        for pos in range(timesteps_per_cluster):
+            orig_idx = period_idx * timesteps_per_cluster + pos
+            if orig_idx < n_timesteps:
+                if is_segmented and segment_assignments is not None:
+                    # For segmented: use tsam's segment_assignments to get the segment index.
+                    # segment_assignments[cluster_id][pos] gives the segment index.
+                    segment_idx = segment_assignments[cluster_id][pos]
+                    mapping[orig_idx] = int(cluster_id) * time_dim_size + segment_idx
+                else:
+                    # Non-segmented: direct position mapping
+                    mapping[orig_idx] = int(cluster_id) * time_dim_size + pos
+    return mapping
+
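The flat encoding built here, `representative_index = cluster_id * time_dim_size + position`, is the invariant that `expand_data()` and `build_expansion_divisor()` further down invert with `//` and `%`. A minimal sketch of the non-segmented case, with made-up assignments (not taken from tsam):

```python
import numpy as np

# Hypothetical example: 3 original periods of 2 timesteps each,
# clustered into 2 typical periods.
timesteps_per_cluster = 2
cluster_assignments = [0, 1, 0]  # original period -> typical cluster

# Same encoding as _build_timestep_mapping (non-segmented branch)
mapping = np.array([
    cluster_id * timesteps_per_cluster + pos
    for cluster_id in cluster_assignments
    for pos in range(timesteps_per_cluster)
])
print(mapping)                           # [0 1 2 3 0 1]
# Decoding recovers (cluster, position), as expand_data() does later:
print(mapping // timesteps_per_cluster)  # [0 0 1 1 0 0]
print(mapping % timesteps_per_cluster)   # [0 1 0 1 0 1]
```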
+class ClusteringResults:
+    """Collection of tsam ClusteringResult objects for multi-dimensional data.
+
+    Manages multiple ClusteringResult objects keyed by (period, scenario) tuples
+    and provides convenient access and multi-dimensional DataArray building.
+
+    Follows xarray-like patterns with `.dims`, `.coords`, `.sel()`, and `.isel()`.
+
+    Attributes:
+        dims: Tuple of dimension names, e.g., ('period', 'scenario').
+        coords: Dict mapping dimension names to their coordinate values.
+
+    Example:
+        >>> results = ClusteringResults({(): cr}, dim_names=[])
+        >>> results.n_clusters
+        2
+        >>> results.cluster_assignments  # Returns DataArray
+        <xarray.DataArray (original_cluster: 3)>
+
+        >>> # Multi-dimensional case
+        >>> results = ClusteringResults(
+        ...     {(2024, 'high'): cr1, (2024, 'low'): cr2},
+        ...     dim_names=['period', 'scenario'],
+        ... )
+        >>> results.dims
+        ('period', 'scenario')
+        >>> results.coords
+        {'period': [2024], 'scenario': ['high', 'low']}
+        >>> results.sel(period=2024, scenario='high')  # Label-based
+        <tsam ClusteringResult>
+        >>> results.isel(period=0, scenario=1)  # Index-based
+        <tsam ClusteringResult>
+    """
+
+    def __init__(
+        self,
+        results: dict[tuple, TsamClusteringResult],
+        dim_names: list[str],
+    ):
+        """Initialize ClusteringResults.
+
+        Args:
+            results: Dict mapping (period, scenario) tuples to tsam ClusteringResult objects.
+                For simple cases without periods/scenarios, use {(): result}.
+            dim_names: Names of extra dimensions, e.g., ['period', 'scenario'].
+        """
+        if not results:
+            raise ValueError('results cannot be empty')
+        self._results = results
+        self._dim_names = dim_names
+
+    # ==========================================================================
+    # xarray-like interface
+    # ==========================================================================
+
+    @property
+    def dims(self) -> tuple[str, ...]:
+        """Dimension names as tuple (xarray-like)."""
+        return tuple(self._dim_names)
+
+    @property
+    def dim_names(self) -> list[str]:
+        """Dimension names as list (backwards compatibility)."""
+        return list(self._dim_names)
+
+    @property
+    def coords(self) -> dict[str, list]:
+        """Coordinate values for each dimension (xarray-like).
+
+        Returns:
+            Dict mapping dimension names to lists of coordinate values.
+        """
+        return {dim: self._get_dim_values(dim) for dim in self._dim_names}
+
+    def sel(self, **kwargs: Any) -> TsamClusteringResult:
+        """Select result by dimension labels (xarray-like).
+
+        Args:
+            **kwargs: Dimension name=value pairs, e.g., period=2024, scenario='high'.
+
+        Returns:
+            The tsam ClusteringResult for the specified combination.
+
+        Raises:
+            KeyError: If no result is found for the specified combination.
+
+        Example:
+            >>> results.sel(period=2024, scenario='high')
+            <tsam ClusteringResult>
+        """
+        key = self._make_key(**kwargs)
+        if key not in self._results:
+            raise KeyError(f'No result found for {kwargs}')
+        return self._results[key]
+
+    def isel(self, **kwargs: int) -> TsamClusteringResult:
+        """Select result by dimension indices (xarray-like).
+
+        Args:
+            **kwargs: Dimension name=index pairs, e.g., period=0, scenario=1.
+
+        Returns:
+            The tsam ClusteringResult for the specified combination.
+
+        Raises:
+            IndexError: If an index is out of range for a dimension.
+
+        Example:
+            >>> results.isel(period=0, scenario=1)
+            <tsam ClusteringResult>
+        """
+        label_kwargs = {}
+        for dim, idx in kwargs.items():
+            coord_values = self._get_dim_values(dim)
+            if coord_values is None:
+                raise KeyError(f"Dimension '{dim}' not found in dims {self.dims}")
+            if idx < 0 or idx >= len(coord_values):
+                raise IndexError(f"Index {idx} out of range for dimension '{dim}' with {len(coord_values)} values")
+            label_kwargs[dim] = coord_values[idx]
+        return self.sel(**label_kwargs)
+
+    def __getitem__(self, key: tuple) -> TsamClusteringResult:
+        """Get result by key tuple."""
+        return self._results[key]
+
+    # === Iteration ===
+
+    def __iter__(self):
+        """Iterate over ClusteringResult objects."""
+        return iter(self._results.values())
+
+    def __len__(self) -> int:
+        """Number of ClusteringResult objects."""
+        return len(self._results)
+
+    def items(self):
+        """Iterate over (key, ClusteringResult) pairs."""
+        return self._results.items()
+
+    def keys(self):
+        """Iterate over keys."""
+        return self._results.keys()
+
+    def values(self):
+        """Iterate over ClusteringResult objects."""
+        return self._results.values()
+
+    # === Properties from first result ===
+
+    @property
+    def _first_result(self) -> TsamClusteringResult:
+        """Get the first ClusteringResult (for structure info)."""
+        return next(iter(self._results.values()))
+
+    @property
+    def n_clusters(self) -> int:
+        """Number of clusters (same for all results)."""
+        return self._first_result.n_clusters
+
+    @property
+    def timesteps_per_cluster(self) -> int:
+        """Number of timesteps per cluster (same for all results)."""
+        return self._first_result.n_timesteps_per_period
+
+    @property
+    def n_original_periods(self) -> int:
+        """Number of original periods (same for all results)."""
+        return self._first_result.n_original_periods
+
+    @property
+    def n_segments(self) -> int | None:
+        """Number of segments per cluster, or None if not segmented."""
+        return self._first_result.n_segments
+
+    # === Multi-dim DataArrays ===
+
+    @property
+    def cluster_assignments(self) -> xr.DataArray:
+        """Maps each original cluster to its typical cluster index.
+
+        Returns:
+            DataArray with dims [original_cluster, period?, scenario?].
+        """
+        # Note: no coords on original_cluster - they cause issues when used as an isel() indexer
+        return self._build_property_array(
+            lambda cr: np.array(cr.cluster_assignments),
+            base_dims=['original_cluster'],
+            name='cluster_assignments',
+        )
+
+    @property
+    def cluster_occurrences(self) -> xr.DataArray:
+        """How many original clusters map to each typical cluster.
+
+        Returns:
+            DataArray with dims [cluster, period?, scenario?].
+        """
+        return self._build_property_array(
+            _cluster_occurrences,
+            base_dims=['cluster'],
+            base_coords={'cluster': range(self.n_clusters)},
+            name='cluster_occurrences',
+        )
+
+    @property
+    def cluster_centers(self) -> xr.DataArray:
+        """Which original cluster is the representative (center) for each typical cluster.
+
+        Returns:
+            DataArray with dims [cluster, period?, scenario?].
+        """
+        return self._build_property_array(
+            lambda cr: np.array(cr.cluster_centers),
+            base_dims=['cluster'],
+            base_coords={'cluster': range(self.n_clusters)},
+            name='cluster_centers',
+        )
+
+    @property
+    def segment_assignments(self) -> xr.DataArray | None:
+        """For each timestep within a cluster, which segment it belongs to.
+
+        Returns:
+            DataArray with dims [cluster, time, period?, scenario?], or None if not segmented.
+        """
+        if self._first_result.segment_assignments is None:
+            return None
+        timesteps = self._first_result.n_timesteps_per_period
+        return self._build_property_array(
+            lambda cr: np.array(cr.segment_assignments),
+            base_dims=['cluster', 'time'],
+            base_coords={'cluster': range(self.n_clusters), 'time': range(timesteps)},
+            name='segment_assignments',
+        )
+
+    @property
+    def segment_durations(self) -> xr.DataArray | None:
+        """Duration of each segment in timesteps.
+
+        Returns:
+            DataArray with dims [cluster, segment, period?, scenario?], or None if not segmented.
+        """
+        if self._first_result.segment_durations is None:
+            return None
+        n_segments = self._first_result.n_segments
+
+        def _get_padded_durations(cr: TsamClusteringResult) -> np.ndarray:
+            """Pad ragged segment durations to a uniform shape."""
+            return np.array([list(d) + [np.nan] * (n_segments - len(d)) for d in cr.segment_durations])
+
+        return self._build_property_array(
+            _get_padded_durations,
+            base_dims=['cluster', 'segment'],
+            base_coords={'cluster': range(self.n_clusters), 'segment': range(n_segments)},
+            name='segment_durations',
+        )
+
+    @property
+    def segment_centers(self) -> xr.DataArray | None:
+        """Center of each intra-period segment.
+
+        Only available if segmentation was configured during clustering.
+
+        Returns:
+            DataArray or None if no segmentation.
+        """
+        first = self._first_result
+        if first.segment_centers is None:
+            return None
+
+        n_segments = first.n_segments
+        return self._build_property_array(
+            lambda cr: np.array(cr.segment_centers),
+            base_dims=['cluster', 'segment'],
+            base_coords={'cluster': range(self.n_clusters), 'segment': range(n_segments)},
+            name='segment_centers',
+        )
+
+    @property
+    def position_within_segment(self) -> xr.DataArray | None:
+        """Position of each timestep within its segment (0-indexed).
+
+        For each (cluster, time) position, returns how many timesteps into the
+        segment that position is. Used for interpolation within segments.
+
+        Returns:
+            DataArray with dims [cluster, time] or [cluster, time, period?, scenario?].
+            Returns None if no segmentation.
+        """
+        segment_assignments = self.segment_assignments
+        if segment_assignments is None:
+            return None
+
+        def _compute_positions(seg_assigns: np.ndarray) -> np.ndarray:
+            """Compute position within segment for each (cluster, time)."""
+            n_clusters, n_times = seg_assigns.shape
+            positions = np.zeros_like(seg_assigns)
+            for c in range(n_clusters):
+                pos = 0
+                prev_seg = -1
+                for t in range(n_times):
+                    seg = seg_assigns[c, t]
+                    if seg != prev_seg:
+                        pos = 0
+                        prev_seg = seg
+                    positions[c, t] = pos
+                    pos += 1
+            return positions
+
+        # Handle extra dimensions by applying _compute_positions to each slice
+        extra_dims = [d for d in segment_assignments.dims if d not in ('cluster', 'time')]
+
+        if not extra_dims:
+            positions = _compute_positions(segment_assignments.values)
+            return xr.DataArray(
+                positions,
+                dims=['cluster', 'time'],
+                coords=segment_assignments.coords,
+                name='position_within_segment',
+            )
+
+        # Multi-dimensional case: compute for each period/scenario slice
+        result = xr.apply_ufunc(
+            _compute_positions,
+            segment_assignments,
+            input_core_dims=[['cluster', 'time']],
+            output_core_dims=[['cluster', 'time']],
+            vectorize=True,
+        )
+        return result.rename('position_within_segment')
+
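For intuition, a runnable vectorized equivalent of the per-timestep loop above, on made-up assignments (one cluster; values invented for this sketch):

```python
import numpy as np

# One cluster, 6 timesteps, assigned to segments [0, 0, 0, 1, 1, 2]
seg_assigns = np.array([[0, 0, 0, 1, 1, 2]])

# The position within a segment is the distance to the most recent
# segment boundary at or before each timestep.
idx = np.arange(seg_assigns.shape[1])
boundary = np.concatenate(
    [np.zeros((1, 1), dtype=bool), seg_assigns[:, 1:] != seg_assigns[:, :-1]], axis=1
)
last_start = np.maximum.accumulate(np.where(boundary, idx, 0), axis=1)
positions = idx - last_start
print(positions)  # [[0 1 2 0 1 0]]
```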
+    # === Serialization ===
+
+    def to_dict(self) -> dict:
+        """Serialize to dict.
+
+        The dict can be used to reconstruct via from_dict().
+        """
+        return {
+            'dim_names': list(self._dim_names),
+            'results': {self._key_to_str(key): result.to_dict() for key, result in self._results.items()},
+        }
+
+    @classmethod
+    def from_dict(cls, d: dict) -> ClusteringResults:
+        """Reconstruct from dict.
+
+        Args:
+            d: Dict from to_dict().
+
+        Returns:
+            Reconstructed ClusteringResults.
+        """
+        from tsam import ClusteringResult
+
+        dim_names = d['dim_names']
+        results = {}
+        for key_str, result_dict in d['results'].items():
+            key = cls._str_to_key(key_str, dim_names)
+            results[key] = ClusteringResult.from_dict(result_dict)
+        return cls(results, dim_names)
+
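The string keys used in the serialized 'results' mapping follow the `_key_to_str`/`_str_to_key` helpers defined further down. A quick sketch of the round-trip (assuming the module is importable as `flixopt.clustering.base`):

```python
from flixopt.clustering.base import ClusteringResults

assert ClusteringResults._key_to_str(()) == '__single__'
assert ClusteringResults._key_to_str((2024, 'high')) == '2024|high'
# Numeric parts are converted back to int, so period years survive the
# round-trip; note that a scenario literally named '2024' would also come
# back as an int, which this scheme does not distinguish.
assert ClusteringResults._str_to_key('2024|high', ['period', 'scenario']) == (2024, 'high')
```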
+    # === Private helpers ===
+
+    def _make_key(self, **kwargs: Any) -> tuple:
+        """Create a key tuple from dimension keyword arguments."""
+        key_parts = []
+        for dim in self._dim_names:
+            if dim in kwargs:
+                key_parts.append(kwargs[dim])
+        return tuple(key_parts)
+
+    def _get_dim_values(self, dim: str) -> list | None:
+        """Get unique values for a dimension, or None if the dimension is not present.
+
+        Preserves insertion order to ensure .isel() positional indexing matches
+        the original FlowSystem dimension order.
+        """
+        if dim not in self._dim_names:
+            return None
+        idx = self._dim_names.index(dim)
+        # Use dict.fromkeys to preserve insertion order while removing duplicates
+        values = [k[idx] for k in self._results.keys()]
+        return list(dict.fromkeys(values))
+
+    def _build_property_array(
+        self,
+        get_data: callable,
+        base_dims: list[str],
+        base_coords: dict | None = None,
+        name: str | None = None,
+    ) -> xr.DataArray:
+        """Build a DataArray property, handling both single and multi-dimensional cases."""
+        base_coords = base_coords or {}
+        periods = self._get_dim_values('period')
+        scenarios = self._get_dim_values('scenario')
+
+        # Build list of (dim_name, values) for dimensions that exist
+        extra_dims = []
+        if periods is not None:
+            extra_dims.append(('period', periods))
+        if scenarios is not None:
+            extra_dims.append(('scenario', scenarios))
+
+        # Simple case: no extra dimensions
+        if not extra_dims:
+            return xr.DataArray(get_data(self._results[()]), dims=base_dims, coords=base_coords, name=name)
+
+        # Multi-dimensional: stack data for each combination
+        first_data = get_data(next(iter(self._results.values())))
+        shape = list(first_data.shape) + [len(vals) for _, vals in extra_dims]
+        data = np.empty(shape, dtype=first_data.dtype)  # Preserve dtype
+
+        for combo in np.ndindex(*[len(vals) for _, vals in extra_dims]):
+            key = tuple(extra_dims[i][1][idx] for i, idx in enumerate(combo))
+            data[(...,) + combo] = get_data(self._results[key])
+
+        dims = base_dims + [dim_name for dim_name, _ in extra_dims]
+        coords = {**base_coords, **{dim_name: vals for dim_name, vals in extra_dims}}
+        return xr.DataArray(data, dims=dims, coords=coords, name=name)
+
+    @staticmethod
+    def _key_to_str(key: tuple) -> str:
+        """Convert a key tuple to a string for serialization."""
+        if not key:
+            return '__single__'
+        return '|'.join(str(k) for k in key)
+
+    @staticmethod
+    def _str_to_key(key_str: str, dim_names: list[str]) -> tuple:
+        """Convert a string back to a key tuple."""
+        if key_str == '__single__':
+            return ()
+        parts = key_str.split('|')
+        # Try to convert to int if possible (for period years)
+        result = []
+        for part in parts:
+            try:
+                result.append(int(part))
+            except ValueError:
+                result.append(part)
+        return tuple(result)
+
+    def __repr__(self) -> str:
+        if not self.dims:
+            return f'ClusteringResults(n_clusters={self.n_clusters})'
+        coords_str = ', '.join(f'{k}: {len(v)}' for k, v in self.coords.items())
+        return f'ClusteringResults(dims={self.dims}, coords=({coords_str}), n_clusters={self.n_clusters})'
+
+    def apply(self, data: xr.Dataset) -> Clustering:
+        """Apply clustering to a dataset for all (period, scenario) combinations.
+
+        Args:
+            data: Dataset with time-varying data. Must have 'time' dimension.
+                May have 'period' and/or 'scenario' dimensions matching this object.
+
+        Returns:
+            Clustering with full access to the aggregated data.
+            Use `.results` on the returned Clustering to get ClusteringResults for IO.
+
+        Example:
+            >>> clustering = clustering_results.apply(dataset)
+            >>> clustering.results  # Get ClusteringResults for IO
+            >>> for key, result in clustering:
+            ...     print(result.cluster_representatives)
+        """
+        from ..core import drop_constant_arrays
+
+        results = {}
+        for key, cr in self._results.items():
+            # Build selector for this key
+            selector = dict(zip(self._dim_names, key, strict=False))
+
+            # Select the slice for this (period, scenario)
+            data_slice = data.sel(**selector, drop=True) if selector else data
+
+            # Drop constant arrays and convert to DataFrame
+            time_varying = drop_constant_arrays(data_slice, dim='time')
+            df = time_varying.to_dataframe()
+
+            # Apply clustering
+            results[key] = cr.apply(df)
+
+        return Clustering._from_aggregation_results(results, self._dim_names)
+
+
+class Clustering:
+    """Clustering information for a FlowSystem.
+
+    Thin wrapper around tsam 3.0's AggregationResult objects, providing:
+    1. Multi-dimensional access for (period, scenario) combinations
+    2. Structure properties (n_clusters, dims, coords, cluster_assignments)
+    3. JSON persistence via ClusteringResults
+
+    Use ``sel()`` to access individual tsam AggregationResult objects for
+    detailed analysis (cluster_representatives, accuracy, plotting).
+
+    Attributes:
+        results: ClusteringResults for structure access (works after JSON load).
+        original_timesteps: Original timesteps before clustering.
+        dims: Dimension names, e.g., ('period', 'scenario').
+        coords: Coordinate values, e.g., {'period': [2024, 2025]}.
+
+    Example:
+        >>> clustering = fs_clustered.clustering
+        >>> clustering.n_clusters
+        8
+        >>> clustering.dims
+        ('period',)
+
+        # Access tsam AggregationResult for detailed analysis
+        >>> result = clustering.sel(period=2024)
+        >>> result.cluster_representatives  # DataFrame
+        >>> result.accuracy  # AccuracyMetrics
+        >>> result.plot.compare()  # tsam's built-in plotting
+    """
+
+    # ==========================================================================
+    # Core properties (delegated to ClusteringResults)
+    # ==========================================================================
+
+    @property
+    def n_clusters(self) -> int:
+        """Number of clusters (typical periods)."""
+        return self.results.n_clusters
+
+    @property
+    def timesteps_per_cluster(self) -> int:
+        """Number of timesteps in each cluster."""
+        return self.results.timesteps_per_cluster
+
+    @property
+    def timesteps_per_period(self) -> int:
+        """Alias for timesteps_per_cluster."""
+        return self.timesteps_per_cluster
+
+    @property
+    def n_original_clusters(self) -> int:
+        """Number of original periods (before clustering)."""
+        return self.results.n_original_periods
+
+    @property
+    def dim_names(self) -> list[str]:
+        """Names of extra dimensions, e.g., ['period', 'scenario']."""
+        return self.results.dim_names
+
+    @property
+    def dims(self) -> tuple[str, ...]:
+        """Dimension names as tuple (xarray-like)."""
+        return self.results.dims
+
+    @property
+    def coords(self) -> dict[str, list]:
+        """Coordinate values for each dimension (xarray-like).
+
+        Returns:
+            Dict mapping dimension names to lists of coordinate values.
+
+        Example:
+            >>> clustering.coords
+            {'period': [2024, 2025], 'scenario': ['low', 'high']}
+        """
+        return self.results.coords
+
+    def sel(
+        self,
+        period: int | str | None = None,
+        scenario: str | None = None,
+    ) -> AggregationResult:
+        """Select AggregationResult by period and/or scenario.
+
+        Access individual tsam AggregationResult objects for detailed analysis.
+
+        Note:
+            This method is only available before saving/loading the FlowSystem.
+            After IO (to_dataset/from_dataset or to_json), the full AggregationResult
+            data is not preserved. Use `results.sel()` for structure-only access
+            after loading.
+
+        Args:
+            period: Period value (e.g., 2024). Required if the clustering has periods.
+            scenario: Scenario name (e.g., 'high'). Required if the clustering has scenarios.
+
+        Returns:
+            The tsam AggregationResult for the specified combination.
+            Access its properties like `cluster_representatives`, `accuracy`, etc.
+
+        Raises:
+            KeyError: If no result is found for the specified combination.
+            ValueError: If accessed on a Clustering loaded from JSON/NetCDF.
+
+        Example:
+            >>> result = clustering.sel(period=2024, scenario='high')
+            >>> result.cluster_representatives  # DataFrame with aggregated data
+            >>> result.accuracy  # AccuracyMetrics
+            >>> result.plot.compare()  # tsam's built-in comparison plot
+        """
+        self._require_full_data('sel()')
+        # Build key from provided args in dim order
+        key_parts = []
+        if 'period' in self._dim_names:
+            if period is None:
+                raise KeyError(f"'period' is required. Available: {self.coords.get('period', [])}")
+            key_parts.append(period)
+        if 'scenario' in self._dim_names:
+            if scenario is None:
+                raise KeyError(f"'scenario' is required. Available: {self.coords.get('scenario', [])}")
+            key_parts.append(scenario)
+        key = tuple(key_parts)
+        if key not in self._aggregation_results:
+            raise KeyError(f'No result found for period={period}, scenario={scenario}')
+        return self._aggregation_results[key]
+
+    @property
+    def is_segmented(self) -> bool:
+        """Whether intra-period segmentation was used.
+
+        Segmented systems have variable timestep durations within each cluster,
+        where each segment represents a different number of original timesteps.
+        """
+        return self.results.n_segments is not None
+
+    @property
+    def n_segments(self) -> int | None:
+        """Number of segments per cluster, or None if not segmented."""
+        return self.results.n_segments
+
+    @property
+    def cluster_assignments(self) -> xr.DataArray:
+        """Mapping from original periods to cluster IDs.
+
+        Returns:
+            DataArray with dims [original_cluster] or [original_cluster, period?, scenario?].
+        """
+        return self.results.cluster_assignments
+
+    @property
+    def n_representatives(self) -> int:
+        """Number of representative timesteps after clustering."""
+        if self.is_segmented:
+            return self.n_clusters * self.n_segments
+        return self.n_clusters * self.timesteps_per_cluster
+
+    # ==========================================================================
+    # Derived properties
+    # ==========================================================================
+
+    @property
+    def cluster_occurrences(self) -> xr.DataArray:
+        """Count of how many original periods each cluster represents.
+
+        Returns:
+            DataArray with dims [cluster] or [cluster, period?, scenario?].
+        """
+        return self.results.cluster_occurrences
+
+    @property
+    def representative_weights(self) -> xr.DataArray:
+        """Weight for each cluster (number of original periods it represents).
+
+        This is the same as cluster_occurrences but named for API consistency.
+        Used as cluster_weight in FlowSystem.
+        """
+        return self.cluster_occurrences.rename('representative_weights')
+
+    @functools.cached_property
+    def timestep_mapping(self) -> xr.DataArray:
+        """Mapping from original timesteps to representative timestep indices.
+
+        Each value indicates which representative timestep index (0 to n_representatives - 1)
+        corresponds to each original timestep.
+
+        Note: This property is cached for performance, since it is accessed frequently
+        during expand() operations.
+        """
+        return self._build_timestep_mapping()
+
+    @property
+    def metrics(self) -> xr.Dataset:
+        """Clustering quality metrics (RMSE, MAE, etc.).
+
+        Returns:
+            Dataset with dims [time_series, period?, scenario?], or an empty Dataset if no metrics.
+        """
+        if self._metrics is None:
+            return xr.Dataset()
+        return self._metrics
+
+    @property
+    def cluster_start_positions(self) -> np.ndarray:
+        """Integer positions where clusters start in the reduced timesteps.
+
+        Returns:
+            1D array: [0, T, 2T, ...] where T = timesteps_per_cluster (or n_segments if segmented).
+        """
+        if self.is_segmented:
+            n_timesteps = self.n_clusters * self.n_segments
+            return np.arange(0, n_timesteps, self.n_segments)
+        n_timesteps = self.n_clusters * self.timesteps_per_cluster
+        return np.arange(0, n_timesteps, self.timesteps_per_cluster)
+
+    @property
+    def cluster_centers(self) -> xr.DataArray:
+        """Which original period is the representative (center) for each cluster.
+
+        Returns:
+            DataArray with dims [cluster] containing original period indices.
+        """
+        return self.results.cluster_centers
+
+    @property
+    def segment_assignments(self) -> xr.DataArray | None:
+        """For each timestep within a cluster, which intra-period segment it belongs to.
+
+        Only available if segmentation was configured during clustering.
+
+        Returns:
+            DataArray with dims [cluster, time] or None if no segmentation.
+        """
+        return self.results.segment_assignments
+
+    @property
+    def segment_durations(self) -> xr.DataArray | None:
+        """Duration of each intra-period segment in hours.
+
+        Only available if segmentation was configured during clustering.
+
+        Returns:
+            DataArray with dims [cluster, segment] or None if no segmentation.
+        """
+        return self.results.segment_durations
+
+    @property
+    def segment_centers(self) -> xr.DataArray | None:
+        """Center of each intra-period segment.
+
+        Only available if segmentation was configured during clustering.
+
+        Returns:
+            DataArray with dims [cluster, segment] or None if no segmentation.
+        """
+        return self.results.segment_centers
+
+    # ==========================================================================
+    # Methods
+    # ==========================================================================
+
+    def expand_data(
+        self,
+        aggregated: xr.DataArray,
+        original_time: pd.DatetimeIndex | None = None,
+    ) -> xr.DataArray:
+        """Expand aggregated data back to the original timesteps.
+
+        Uses the timestep_mapping to map each original timestep to its
+        representative value from the aggregated data. Fully vectorized using
+        xarray's advanced indexing - no loops over period/scenario dimensions.
+
+        Args:
+            aggregated: DataArray with aggregated (cluster, time) or (time,) dimension.
+            original_time: Original time coordinates. Defaults to self.original_timesteps.
+
+        Returns:
+            DataArray expanded to the original timesteps.
+        """
+        if original_time is None:
+            original_time = self.original_timesteps
+
+        timestep_mapping = self.timestep_mapping  # Already a multi-dimensional DataArray
+
+        if 'cluster' not in aggregated.dims:
+            # No cluster dimension: use the mapping directly as a time index
+            expanded = aggregated.isel(time=timestep_mapping)
+        else:
+            # Has a cluster dimension: compute cluster and time indices from the mapping.
+            # For segmented systems, the time dimension is n_segments, not timesteps_per_cluster.
+            if self.is_segmented and self.n_segments is not None:
+                time_dim_size = self.n_segments
+            else:
+                time_dim_size = self.timesteps_per_cluster
+
+            cluster_indices = timestep_mapping // time_dim_size
+            time_indices = timestep_mapping % time_dim_size
+
+            # xarray's advanced indexing handles broadcasting across period/scenario dims
+            expanded = aggregated.isel(cluster=cluster_indices, time=time_indices)
+
+        # Clean up: drop coordinate artifacts from isel, then rename original_time -> time.
+        # The isel operation may leave 'cluster' and 'time' as non-dimension coordinates.
+        expanded = expanded.drop_vars(['cluster', 'time'], errors='ignore')
+        expanded = expanded.rename({'original_time': 'time'}).assign_coords(time=original_time)
+
+        return expanded.transpose('time', ...).assign_attrs(aggregated.attrs)
+
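The core of `expand_data` is xarray's pointwise (vectorized) indexing: two integer `DataArray` indexers sharing the `original_time` dimension select one representative value per original timestep. A self-contained sketch with invented numbers:

```python
import xarray as xr

# Aggregated data: 2 clusters x 2 representative timesteps
aggregated = xr.DataArray([[1.0, 2.0], [10.0, 20.0]], dims=['cluster', 'time'])

# Flat mapping for 6 original timesteps (encoding: cluster_id * 2 + pos)
mapping = xr.DataArray([0, 1, 2, 3, 0, 1], dims=['original_time'])

timesteps_per_cluster = 2
expanded = aggregated.isel(
    cluster=mapping // timesteps_per_cluster,
    time=mapping % timesteps_per_cluster,
)
print(expanded.values)  # [ 1.  2. 10. 20.  1.  2.]
```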
+    def build_expansion_divisor(
+        self,
+        original_time: pd.DatetimeIndex | None = None,
+    ) -> xr.DataArray:
+        """Build a divisor for correcting segment totals when expanding to hourly resolution.
+
+        For segmented systems, each segment value is a total that gets repeated N times
+        when expanded to hourly resolution (where N = segment duration in timesteps).
+        This divisor allows converting those totals back to hourly rates during expansion.
+
+        For each original timestep, returns the number of original timesteps that map
+        to the same (cluster, segment) - i.e., the segment duration in timesteps.
+
+        Fully vectorized using xarray's advanced indexing - no loops over period/scenario.
+
+        Args:
+            original_time: Original time coordinates. Defaults to self.original_timesteps.
+
+        Returns:
+            DataArray with dims ['time'] or ['time', 'period'?, 'scenario'?] containing
+            the number of timesteps in each segment, aligned to the original timesteps.
+        """
+        if not self.is_segmented or self.n_segments is None:
+            raise ValueError('build_expansion_divisor requires a segmented clustering')
+
+        if original_time is None:
+            original_time = self.original_timesteps
+
+        timestep_mapping = self.timestep_mapping  # Already multi-dimensional
+        segment_durations = self.results.segment_durations  # [cluster, segment, period?, scenario?]
+
+        # Decode cluster and segment indices from timestep_mapping.
+        # For segmented systems, the encoding is: cluster_id * n_segments + segment_idx
+        time_dim_size = self.n_segments
+        cluster_indices = timestep_mapping // time_dim_size
+        segment_indices = timestep_mapping % time_dim_size  # This IS the segment index
+
+        # Get the duration for each segment directly:
+        # segment_durations[cluster, segment] -> duration
+        divisor = segment_durations.isel(cluster=cluster_indices, segment=segment_indices)
+
+        # Clean up coordinates and rename
+        divisor = divisor.drop_vars(['cluster', 'time', 'segment'], errors='ignore')
+        divisor = divisor.rename({'original_time': 'time'}).assign_coords(time=original_time)
+
+        return divisor.transpose('time', ...).rename('expansion_divisor')
+
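Numeric intuition for the divisor (made-up values): after `expand_data` repeats a segment total over the segment's original timesteps, dividing by the segment duration turns the totals back into per-timestep rates.

```python
import xarray as xr

# A 3-timestep segment whose total of 12.0 was repeated by expand_data()
expanded_totals = xr.DataArray([12.0, 12.0, 12.0], dims=['time'])
divisor = xr.DataArray([3.0, 3.0, 3.0], dims=['time'])  # segment duration

print((expanded_totals / divisor).values)  # [4. 4. 4.] -> per-timestep rate
```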
+    def get_result(
+        self,
+        period: Any = None,
+        scenario: Any = None,
+    ) -> TsamClusteringResult:
+        """Get the tsam ClusteringResult for a specific (period, scenario).
+
+        Args:
+            period: Period label (if applicable).
+            scenario: Scenario label (if applicable).
+
+        Returns:
+            The tsam ClusteringResult for the specified combination.
+        """
+        return self.results.sel(period=period, scenario=scenario)
+
+    def apply(
+        self,
+        data: pd.DataFrame,
+        period: Any = None,
+        scenario: Any = None,
+    ) -> AggregationResult:
+        """Apply the saved clustering to new data.
+
+        Args:
+            data: DataFrame with time series data to cluster.
+            period: Period label (if applicable).
+            scenario: Scenario label (if applicable).
+
+        Returns:
+            tsam AggregationResult with the clustering applied.
+        """
+        return self.results.sel(period=period, scenario=scenario).apply(data)
+
+    def to_json(self, path: str | Path) -> None:
+        """Save the clustering for reuse.
+
+        Uses ClusteringResults.to_dict(), which preserves the full tsam ClusteringResult.
+        Can be loaded later with Clustering.from_json() and used with
+        flow_system.transform.apply_clustering().
+
+        Args:
+            path: Path to save the JSON file.
+        """
+        data = {
+            'results': self.results.to_dict(),
+            'original_timesteps': [ts.isoformat() for ts in self.original_timesteps],
+        }
+
+        with open(path, 'w') as f:
+            json.dump(data, f, indent=2)
+
+    @classmethod
+    def from_json(
+        cls,
+        path: str | Path,
+        original_timesteps: pd.DatetimeIndex | None = None,
+    ) -> Clustering:
+        """Load a clustering from JSON.
+
+        The loaded Clustering has full apply() support because the ClusteringResult
+        is fully preserved via tsam's serialization.
+
+        Args:
+            path: Path to the JSON file.
+            original_timesteps: Original timesteps for the new FlowSystem.
+                If None, uses the timesteps stored in the JSON.
+
+        Returns:
+            A Clustering that can be used with apply_clustering().
+        """
+        with open(path) as f:
+            data = json.load(f)
+
+        results = ClusteringResults.from_dict(data['results'])
+
+        if original_timesteps is None:
+            original_timesteps = pd.DatetimeIndex([pd.Timestamp(ts) for ts in data['original_timesteps']])
+
+        return cls(
+            results=results,
+            original_timesteps=original_timesteps,
+        )
+
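A usage sketch of the persistence round-trip (hypothetical path; assumes `clustering` came from a clustered FlowSystem and that `Clustering` is exported from `flixopt.clustering`):

```python
from flixopt.clustering import Clustering

clustering.to_json('clustering.json')

# Reload later: apply() keeps working because tsam's ClusteringResult is
# fully serialized; sel() and iteration raise instead, since the full
# AggregationResult data is not preserved across IO.
restored = Clustering.from_json('clustering.json')
assert restored.n_clusters == clustering.n_clusters
```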
1086
+ # ==========================================================================
1087
+ # Visualization
1088
+ # ==========================================================================
1089
+
1090
+ @property
1091
+ def plot(self) -> ClusteringPlotAccessor:
1092
+ """Access plotting methods for clustering visualization.
1093
+
1094
+ Returns:
1095
+ ClusteringPlotAccessor with compare(), heatmap(), and clusters() methods.
1096
+ """
1097
+ return ClusteringPlotAccessor(self)
1098
+
1099
+ # ==========================================================================
1100
+ # Private helpers
1101
+ # ==========================================================================
1102
+
1103
+ def _build_timestep_mapping(self) -> xr.DataArray:
1104
+ """Build timestep_mapping DataArray."""
1105
+ n_original = len(self.original_timesteps)
1106
+ original_time_coord = self.original_timesteps.rename('original_time')
1107
+ return self.results._build_property_array(
1108
+ lambda cr: _build_timestep_mapping(cr, n_original),
1109
+ base_dims=['original_time'],
1110
+ base_coords={'original_time': original_time_coord},
1111
+ name='timestep_mapping',
1112
+ )
1113
+
1114
+ def _create_reference_structure(self, include_original_data: bool = True) -> tuple[dict, dict[str, xr.DataArray]]:
1115
+ """Create serialization structure for to_dataset().
1116
+
1117
+ Args:
1118
+ include_original_data: Whether to include original_data in serialization.
1119
+ Set to False for smaller files when plot.compare() isn't needed after IO.
1120
+ Defaults to True.
1121
+
1122
+ Returns:
1123
+ Tuple of (reference_dict, arrays_dict).
1124
+ """
1125
+ arrays = {}
1126
+
1127
+ # Collect original_data arrays
1128
+ # Rename 'time' to 'original_time' to avoid conflict with clustered FlowSystem's time coord
1129
+ original_data_refs = None
1130
+ if include_original_data and self.original_data is not None:
1131
+ original_data_refs = []
1132
+ # Use variables for faster access (avoids _construct_dataarray overhead)
1133
+ variables = self.original_data.variables
1134
+ for name in self.original_data.data_vars:
1135
+ var = variables[name]
1136
+ ref_name = f'original_data|{name}'
1137
+ # Rename time dim to avoid xarray alignment issues
1138
+ if 'time' in var.dims:
1139
+ new_dims = tuple('original_time' if d == 'time' else d for d in var.dims)
1140
+ arrays[ref_name] = xr.Variable(new_dims, var.values, attrs=var.attrs)
1141
+ else:
1142
+ arrays[ref_name] = var
1143
+ original_data_refs.append(f':::{ref_name}')
1144
+
1145
+ # NOTE: aggregated_data is NOT serialized - it's identical to the FlowSystem's
1146
+ # main data arrays and would be redundant. After loading, aggregated_data is
1147
+ # reconstructed from the FlowSystem's dataset.
1148
+
1149
+ # Collect metrics arrays
1150
+ metrics_refs = None
1151
+ if self._metrics is not None:
1152
+ metrics_refs = []
1153
+ # Use variables for faster access (avoids _construct_dataarray overhead)
1154
+ metrics_vars = self._metrics.variables
1155
+ for name in self._metrics.data_vars:
1156
+ ref_name = f'metrics|{name}'
1157
+ arrays[ref_name] = metrics_vars[name]
1158
+ metrics_refs.append(f':::{ref_name}')
1159
+
1160
+ reference = {
1161
+ '__class__': 'Clustering',
1162
+ 'results': self.results.to_dict(), # Full ClusteringResults serialization
1163
+ 'original_timesteps': [ts.isoformat() for ts in self.original_timesteps],
1164
+ '_original_data_refs': original_data_refs,
1165
+ '_metrics_refs': metrics_refs,
1166
+ }
1167
+
1168
+ return reference, arrays
1169
+
1170
+ def __init__(
1171
+ self,
1172
+ results: ClusteringResults | dict | None = None,
1173
+ original_timesteps: pd.DatetimeIndex | list[str] | None = None,
1174
+ original_data: xr.Dataset | None = None,
1175
+ aggregated_data: xr.Dataset | None = None,
1176
+ _metrics: xr.Dataset | None = None,
1177
+ # These are for reconstruction from serialization
1178
+ _original_data_refs: list[str] | None = None,
1179
+ _metrics_refs: list[str] | None = None,
1180
+ # Internal: AggregationResult dict for full data access
1181
+ _aggregation_results: dict[tuple, AggregationResult] | None = None,
1182
+ _dim_names: list[str] | None = None,
1183
+ ):
1184
+ """Initialize Clustering object.
1185
+
1186
+ Args:
1187
+ results: ClusteringResults instance, or dict from to_dict() (for deserialization).
1188
+ Not needed if _aggregation_results is provided.
1189
+ original_timesteps: Original timesteps before clustering.
1190
+ original_data: Original dataset before clustering (for expand/plotting).
1191
+ aggregated_data: Aggregated dataset after clustering (for plotting).
1192
+ After loading from file, this is reconstructed from FlowSystem data.
1193
+ _metrics: Pre-computed metrics dataset.
1194
+ _original_data_refs: Internal: resolved DataArrays from serialization.
1195
+ _metrics_refs: Internal: resolved DataArrays from serialization.
1196
+ _aggregation_results: Internal: dict of AggregationResult for full data access.
1197
+ _dim_names: Internal: dimension names when using _aggregation_results.
1198
+ """
1199
+ # Handle ISO timestamp strings from serialization
1200
+ if (
1201
+ isinstance(original_timesteps, list)
1202
+ and len(original_timesteps) > 0
1203
+ and isinstance(original_timesteps[0], str)
1204
+ ):
1205
+ original_timesteps = pd.DatetimeIndex([pd.Timestamp(ts) for ts in original_timesteps])
1206
+
1207
+ # Store AggregationResults if provided (full data access)
1208
+ self._aggregation_results = _aggregation_results
1209
+ self._dim_names = _dim_names or []
1210
+
1211
+ # Handle results - only needed for serialization path
1212
+ if results is not None:
1213
+ if isinstance(results, dict):
1214
+ results = ClusteringResults.from_dict(results)
1215
+ self._results_cache = results
1216
+ else:
1217
+ self._results_cache = None
1218
+
1219
+ # Flag indicating this was loaded from serialization (missing full AggregationResult data)
1220
+ self._from_serialization = _aggregation_results is None and results is not None
1221
+
1222
+ self.original_timesteps = original_timesteps if original_timesteps is not None else pd.DatetimeIndex([])
1223
+ self._metrics = _metrics
1224
+
1225
+ # Handle reconstructed data from refs (list of DataArrays)
1226
+ if _original_data_refs is not None and isinstance(_original_data_refs, list):
1227
+ # These are resolved DataArrays from the structure resolver
1228
+ if all(isinstance(da, xr.DataArray) for da in _original_data_refs):
1229
+ # Rename 'original_time' back to 'time' and strip 'original_data|' prefix
1230
+ data_vars = {}
1231
+ for da in _original_data_refs:
1232
+ if 'original_time' in da.dims:
1233
+ da = da.rename({'original_time': 'time'})
1234
+ # Strip 'original_data|' prefix from name (added during serialization)
1235
+ name = da.name
1236
+ if name.startswith('original_data|'):
1237
+ name = name[14:] # len('original_data|') = 14
1238
+ data_vars[name] = da.rename(name)
1239
+ self.original_data = xr.Dataset(data_vars)
1240
+ else:
1241
+ self.original_data = original_data
1242
+ else:
1243
+ self.original_data = original_data
1244
+
1245
+ self.aggregated_data = aggregated_data
1246
+
1247
+ if _metrics_refs is not None and isinstance(_metrics_refs, list):
1248
+ if all(isinstance(da, xr.DataArray) for da in _metrics_refs):
1249
+ # Strip 'metrics|' prefix from name (added during serialization)
1250
+ data_vars = {}
1251
+ for da in _metrics_refs:
1252
+ name = da.name
1253
+ if name.startswith('metrics|'):
1254
+ name = name[8:] # len('metrics|') = 8
1255
+ data_vars[name] = da.rename(name)
1256
+ self._metrics = xr.Dataset(data_vars)
1257
+
1258
+ @property
1259
+ def results(self) -> ClusteringResults:
1260
+ """ClusteringResults for structure access (derived from AggregationResults or cached)."""
1261
+ if self._results_cache is not None:
1262
+ return self._results_cache
1263
+ if self._aggregation_results is not None:
1264
+ # Derive from AggregationResults (cached on first access)
1265
+ self._results_cache = ClusteringResults(
1266
+ {k: r.clustering for k, r in self._aggregation_results.items()},
1267
+ self._dim_names,
1268
+ )
1269
+ return self._results_cache
1270
+ raise ValueError('No results available - neither AggregationResults nor ClusteringResults set')
1271
+
1272
+ @classmethod
1273
+ def _from_aggregation_results(
1274
+ cls,
1275
+ aggregation_results: dict[tuple, AggregationResult],
1276
+ dim_names: list[str],
1277
+ original_timesteps: pd.DatetimeIndex | None = None,
1278
+ original_data: xr.Dataset | None = None,
1279
+ ) -> Clustering:
1280
+ """Create Clustering from AggregationResult dict.
1281
+
1282
+ This is the primary way to create a Clustering with full data access.
1283
+ Called by ClusteringResults.apply() and TransformAccessor.
1284
+
1285
+ Args:
1286
+ aggregation_results: Dict mapping (period, scenario) tuples to AggregationResult.
1287
+ dim_names: Dimension names, e.g., ['period', 'scenario'].
1288
+ original_timesteps: Original timesteps (optional, for expand).
1289
+ original_data: Original dataset (optional, for plotting).
1290
+
1291
+ Returns:
1292
+ Clustering with full AggregationResult access.
1293
+ """
1294
+ return cls(
1295
+ original_timesteps=original_timesteps,
1296
+ original_data=original_data,
1297
+ _aggregation_results=aggregation_results,
1298
+ _dim_names=dim_names,
1299
+ )
1300
+
1301
+ # ==========================================================================
1302
+ # Iteration over AggregationResults (for direct access to tsam results)
1303
+ # ==========================================================================
1304
+
1305
+ def __iter__(self):
1306
+ """Iterate over (key, AggregationResult) pairs.
1307
+
1308
+ Raises:
1309
+ ValueError: If accessed on a Clustering loaded from JSON.
1310
+ """
1311
+ self._require_full_data('iteration')
1312
+ return iter(self._aggregation_results.items())
1313
+
1314
+ def __len__(self) -> int:
1315
+ """Number of (period, scenario) combinations."""
1316
+ if self._aggregation_results is not None:
1317
+ return len(self._aggregation_results)
1318
+ return len(list(self.results.keys()))
1319
+
1320
+ def __getitem__(self, key: tuple) -> AggregationResult:
1321
+ """Get AggregationResult by (period, scenario) key.
1322
+
1323
+ Raises:
1324
+ ValueError: If accessed on a Clustering loaded from JSON.
1325
+ """
1326
+ self._require_full_data('item access')
1327
+ return self._aggregation_results[key]
1328
+
1329
+ def items(self):
1330
+ """Iterate over (key, AggregationResult) pairs.
1331
+
1332
+ Raises:
1333
+ ValueError: If accessed on a Clustering loaded from JSON.
1334
+ """
1335
+ self._require_full_data('items()')
1336
+ return self._aggregation_results.items()
1337
+
1338
+ def keys(self):
1339
+ """Iterate over (period, scenario) keys."""
1340
+ if self._aggregation_results is not None:
1341
+ return self._aggregation_results.keys()
1342
+ return self.results.keys()
1343
+
1344
+ def values(self):
1345
+ """Iterate over AggregationResult objects.
1346
+
1347
+ Raises:
1348
+ ValueError: If accessed on a Clustering loaded from JSON.
1349
+ """
1350
+ self._require_full_data('values()')
1351
+ return self._aggregation_results.values()
1352
+
1353
+ def _require_full_data(self, operation: str) -> None:
1354
+ """Raise error if full AggregationResult data is not available."""
1355
+ if self._from_serialization:
1356
+ raise ValueError(
1357
+ f'{operation} requires full AggregationResult data, '
1358
+ f'but this Clustering was loaded from JSON. '
1359
+ f'Use apply_clustering() to get full results.'
1360
+ )
1361
+
1362
+ def __repr__(self) -> str:
1363
+ return (
1364
+ f'Clustering(\n'
1365
+ f' {self.n_original_clusters} periods → {self.n_clusters} clusters\n'
1366
+ f' timesteps_per_cluster={self.timesteps_per_cluster}\n'
1367
+ f' dims={self.dim_names}\n'
1368
+ f')'
1369
+ )
1370
+
1371
+
1372
+ class ClusteringPlotAccessor:
1373
+ """Plot accessor for Clustering objects.
1374
+
1375
+ Provides visualization methods for comparing original vs aggregated data
1376
+ and understanding the clustering structure.
1377
+ """
1378
+
1379
+ def __init__(self, clustering: Clustering):
1380
+ self._clustering = clustering
1381
+
1382
+ def compare(
1383
+ self,
1384
+ kind: str = 'timeseries',
1385
+ variables: str | list[str] | None = None,
1386
+ *,
1387
+ select: SelectType | None = None,
1388
+ colors: ColorType | None = None,
1389
+ show: bool | None = None,
1390
+ data_only: bool = False,
1391
+ **plotly_kwargs: Any,
1392
+ ) -> PlotResult:
1393
+ """Compare original vs aggregated data.
1394
+
1395
+ Args:
1396
+ kind: Type of comparison plot.
1397
+ - 'timeseries': Time series comparison (default)
1398
+ - 'duration_curve': Sorted duration curve comparison
1399
+ variables: Variable(s) to plot. Can be a string, list of strings,
1400
+ or None to plot all time-varying variables.
1401
+ select: xarray-style selection dict, e.g. {'scenario': 'Base Case'}.
1402
+ colors: Color specification (colorscale name, color list, or label-to-color dict).
1403
+ show: Whether to display the figure.
1404
+ Defaults to CONFIG.Plotting.default_show.
1405
+ data_only: If True, skip figure creation and return only data.
1406
+ **plotly_kwargs: Additional arguments passed to plotly (e.g., color, line_dash,
1407
+ facet_col, facet_row). Defaults: x='time'/'duration', color='variable',
1408
+ line_dash='representation', symbol=None.
1409
+
1410
+ Returns:
1411
+ PlotResult containing the comparison figure and underlying data.
1412
+ """
1413
+ import plotly.graph_objects as go
1414
+
1415
+ from ..config import CONFIG
1416
+ from ..plot_result import PlotResult
1417
+ from ..statistics_accessor import _apply_selection
1418
+
1419
+ if kind not in ('timeseries', 'duration_curve'):
1420
+ raise ValueError(f"Unknown kind '{kind}'. Use 'timeseries' or 'duration_curve'.")
1421
+
1422
+ clustering = self._clustering
1423
+ if clustering.original_data is None or clustering.aggregated_data is None:
1424
+ raise ValueError('No original/aggregated data available for comparison')
1425
+
1426
+ resolved_variables = self._resolve_variables(variables)
1427
+
1428
+ # Build Dataset with variables as data_vars
1429
+ data_vars = {}
1430
+ for var in resolved_variables:
1431
+ original = clustering.original_data[var]
1432
+ clustered = clustering.expand_data(clustering.aggregated_data[var])
1433
+ combined = xr.concat([original, clustered], dim=pd.Index(['Original', 'Clustered'], name='representation'))
1434
+ data_vars[var] = combined
1435
+ ds = xr.Dataset(data_vars)
1436
+
1437
+ ds = _apply_selection(ds, select)
1438
+
1439
+         if kind == 'duration_curve':
+             sorted_vars = {}
+             # Use ds.variables for faster access (avoids _construct_dataarray overhead);
+             # renamed to ds_variables so it does not shadow the `variables` parameter
+             ds_variables = ds.variables
+             rep_values = ds.coords['representation'].values
+             rep_idx = {rep: i for i, rep in enumerate(rep_values)}
+             for var in ds.data_vars:
+                 data = ds_variables[var].values
+                 for rep in rep_values:
+                     # Direct numpy indexing instead of .sel()
+                     values = np.sort(data[rep_idx[rep]].flatten())[::-1]
+                     sorted_vars[(var, rep)] = values
+             # Get length from first sorted array
+             n = len(next(iter(sorted_vars.values())))
+             ds = xr.Dataset(
+                 {
+                     var: xr.DataArray(
+                         [sorted_vars[(var, r)] for r in ['Original', 'Clustered']],
+                         dims=['representation', 'duration'],
+                         coords={'representation': ['Original', 'Clustered'], 'duration': range(n)},
+                     )
+                     for var in resolved_variables
+                 }
+             )
1463
+
1464
+ title = (
1465
+ (
1466
+ 'Original vs Clustered'
1467
+ if len(resolved_variables) > 1
1468
+ else f'Original vs Clustered: {resolved_variables[0]}'
1469
+ )
1470
+ if kind == 'timeseries'
1471
+ else ('Duration Curve' if len(resolved_variables) > 1 else f'Duration Curve: {resolved_variables[0]}')
1472
+ )
1473
+
1474
+ # Early return for data_only mode
1475
+ if data_only:
1476
+ return PlotResult(data=ds, figure=go.Figure())
1477
+
1478
+ # Apply slot defaults
1479
+ defaults = {
1480
+ 'x': 'duration' if kind == 'duration_curve' else 'time',
1481
+ 'color': 'variable',
1482
+ 'line_dash': 'representation',
1483
+ 'line_dash_map': {'Original': 'dot', 'Clustered': 'solid'},
1484
+ 'symbol': None, # Block symbol slot
1485
+ }
1486
+ _apply_slot_defaults(plotly_kwargs, defaults)
1487
+
1488
+ color_kwargs = _build_color_kwargs(colors, list(ds.data_vars))
1489
+ fig = ds.plotly.line(
1490
+ title=title,
1491
+ **color_kwargs,
1492
+ **plotly_kwargs,
1493
+ )
1494
+ fig.update_yaxes(matches=None)
1495
+ fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1]))
1496
+
1497
+ plot_result = PlotResult(data=ds, figure=fig)
1498
+
1499
+ if show is None:
1500
+ show = CONFIG.Plotting.default_show
1501
+ if show:
1502
+ plot_result.show()
1503
+
1504
+ return plot_result
1505
+
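A usage sketch for `compare` under stated assumptions: the access path `fs.clustering.plot` and the variable name 'HeatDemand' are hypothetical; only the method signature itself comes from this diff.

    # Hypothetical access path to the accessor defined above
    plot = fs.clustering.plot

    # Overlay original vs expanded clustered series (dotted vs solid lines)
    res = plot.compare('timeseries', variables='HeatDemand', show=False)

    # Same data as sorted duration curves, faceted per scenario via plotly_kwargs
    res = plot.compare('duration_curve', variables='HeatDemand',
                       facet_col='scenario', show=False)
    res.figure.show()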
1506
+     def _get_time_varying_variables(self) -> list[str]:
+         """Get list of time-varying variables from original data that also exist in aggregated data."""
+         if self._clustering.original_data is None:
+             return []
+         # Get variables that exist in both original and aggregated data
+         aggregated_vars = (
+             set(self._clustering.aggregated_data.data_vars)
+             if self._clustering.aggregated_data is not None
+             else set(self._clustering.original_data.data_vars)
+         )
+         return [
+             name
+             for name in self._clustering.original_data.data_vars
+             if name in aggregated_vars
+             and 'time' in self._clustering.original_data[name].dims
+             and not np.isclose(
+                 self._clustering.original_data[name].min(),
+                 self._clustering.original_data[name].max(),
+             )
+         ]
+
+     def _resolve_variables(self, variables: str | list[str] | None) -> list[str]:
+         """Resolve variables parameter to a list of valid variable names."""
+         time_vars = self._get_time_varying_variables()
+         if not time_vars:
+             raise ValueError('No time-varying variables found')
+
+         if variables is None:
+             return time_vars
+         elif isinstance(variables, str):
+             if variables not in time_vars:
+                 raise ValueError(f"Variable '{variables}' not found. Available: {time_vars}")
+             return [variables]
+         else:
+             invalid = [v for v in variables if v not in time_vars]
+             if invalid:
+                 raise ValueError(f'Variables {invalid} not found. Available: {time_vars}')
+             return list(variables)
+
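The time-varying filter above treats a series as constant when its minimum and maximum coincide (via np.isclose). A standalone sketch of the same test, with hypothetical variable names:

    import numpy as np
    import xarray as xr

    ds = xr.Dataset(
        {
            'demand': ('time', np.sin(np.linspace(0.0, 6.28, 24))),  # varies -> kept
            'capacity': ('time', np.full(24, 100.0)),                # constant -> dropped
        }
    )
    time_varying = [
        name
        for name in ds.data_vars
        if 'time' in ds[name].dims and not np.isclose(ds[name].min(), ds[name].max())
    ]
    print(time_varying)  # ['demand']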
1545
+     def heatmap(
+         self,
+         *,
+         select: SelectType | None = None,
+         colors: str | list[str] | None = None,
+         show: bool | None = None,
+         data_only: bool = False,
+         **plotly_kwargs: Any,
+     ) -> PlotResult:
+         """Plot cluster assignments over time as a heatmap timeline.
+
+         Shows which cluster each timestep belongs to as a horizontal color bar.
+         The x-axis is time; color indicates cluster assignment. This visualization
+         aligns with time series data, making it easy to correlate cluster
+         assignments with other plots.
+
+         For multi-period/scenario data, uses faceting and/or animation.
+
+         Args:
+             select: xarray-style selection dict, e.g. {'scenario': 'Base Case'}.
+             colors: Colorscale name (str) or list of colors for heatmap coloring.
+                 Dicts are not supported for heatmaps.
+                 Defaults to plotly template's sequential colorscale.
+             show: Whether to display the figure.
+                 Defaults to CONFIG.Plotting.default_show.
+             data_only: If True, skip figure creation and return only data.
+             **plotly_kwargs: Additional arguments passed to plotly (e.g., facet_col, animation_frame).
+
+         Returns:
+             PlotResult containing the heatmap figure and cluster assignment data.
+             The data has a 'cluster' variable with a time dimension matching the original timesteps.
+         """
+         import plotly.graph_objects as go
+
+         from ..config import CONFIG
+         from ..plot_result import PlotResult
+         from ..statistics_accessor import _apply_selection
+
+         clustering = self._clustering
+         cluster_assignments = clustering.cluster_assignments
+         timesteps_per_cluster = clustering.timesteps_per_cluster
+         original_time = clustering.original_timesteps
+
+         if select:
+             cluster_assignments = _apply_selection(cluster_assignments.to_dataset(name='cluster'), select)['cluster']
+
+         # Expand cluster_assignments to per-timestep
+         extra_dims = [d for d in cluster_assignments.dims if d != 'original_cluster']
+         expanded_values = np.repeat(cluster_assignments.values, timesteps_per_cluster, axis=0)
+
+         coords = {'time': original_time}
+         coords.update({d: cluster_assignments.coords[d].values for d in extra_dims})
+         cluster_da = xr.DataArray(expanded_values, dims=['time'] + extra_dims, coords=coords)
+         cluster_da.name = 'cluster'
+
+         # Early return for data_only mode
+         if data_only:
+             return PlotResult(data=xr.Dataset({'cluster': cluster_da}), figure=go.Figure())
+
+         heatmap_da = cluster_da.expand_dims('y', axis=-1).assign_coords(y=['Cluster'])
+         heatmap_da.name = 'cluster_assignment'
+         heatmap_da = heatmap_da.transpose('time', 'y', ...)
+
+         # Use plotly.imshow for heatmap
+         # Only pass color_continuous_scale if explicitly provided (template handles default)
+         if colors is not None:
+             plotly_kwargs.setdefault('color_continuous_scale', colors)
+         fig = heatmap_da.plotly.imshow(
+             title='Cluster Assignments',
+             aspect='auto',
+             **plotly_kwargs,
+         )
+
+         fig.update_yaxes(showticklabels=False)
+         fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1]))
+
+         # Data is exactly what we plotted (without the dummy y dimension)
+         data = xr.Dataset({'cluster': cluster_da})
+         plot_result = PlotResult(data=data, figure=fig)
+
+         if show is None:
+             show = CONFIG.Plotting.default_show
+         if show:
+             plot_result.show()
+
+         return plot_result
+
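The per-timestep expansion in heatmap() hinges on np.repeat along the leading cluster axis. A standalone sketch of that step, with hypothetical assignment values:

    import numpy as np

    # One cluster index per original period, e.g. 4 days mapped onto 2 clusters
    cluster_assignments = np.array([0, 1, 1, 0])
    timesteps_per_cluster = 3  # e.g. 3 timesteps per day

    # Repeat each period's cluster index once per timestep within that period
    per_timestep = np.repeat(cluster_assignments, timesteps_per_cluster, axis=0)
    print(per_timestep)  # [0 0 0 1 1 1 1 1 1 0 0 0]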
1632
+     def clusters(
+         self,
+         variables: str | list[str] | None = None,
+         *,
+         select: SelectType | None = None,
+         colors: ColorType | None = None,
+         show: bool | None = None,
+         data_only: bool = False,
+         **plotly_kwargs: Any,
+     ) -> PlotResult:
+         """Plot each cluster's typical period profile.
+
+         Shows each cluster as a separate faceted subplot with all variables
+         colored differently. Useful for understanding what each cluster represents.
+
+         Args:
+             variables: Variable(s) to plot. Can be a string, list of strings,
+                 or None to plot all time-varying variables.
+             select: xarray-style selection dict, e.g. {'scenario': 'Base Case'}.
+             colors: Color specification (colorscale name, color list, or label-to-color dict).
+             show: Whether to display the figure.
+                 Defaults to CONFIG.Plotting.default_show.
+             data_only: If True, skip figure creation and return only data.
+             **plotly_kwargs: Additional arguments passed to plotly (e.g., color, facet_col,
+                 facet_col_wrap). Defaults: x='time', color='variable', symbol=None.
+
+         Returns:
+             PlotResult containing the figure and underlying data.
+         """
+         import plotly.graph_objects as go
+
+         from ..config import CONFIG
+         from ..plot_result import PlotResult
+         from ..statistics_accessor import _apply_selection
+
+         clustering = self._clustering
+         if clustering.aggregated_data is None:
+             raise ValueError('No aggregated data available')
+
+         aggregated_data = _apply_selection(clustering.aggregated_data, select)
+         resolved_variables = self._resolve_variables(variables)
+
+         n_clusters = clustering.n_clusters
+         timesteps_per_cluster = clustering.timesteps_per_cluster
+         cluster_occurrences = clustering.cluster_occurrences
+
+         # Build cluster labels
+         occ_extra_dims = [d for d in cluster_occurrences.dims if d != 'cluster']
+         if occ_extra_dims:
+             cluster_labels = [f'Cluster {c}' for c in range(n_clusters)]
+         else:
+             cluster_labels = [
+                 f'Cluster {c} (×{int(cluster_occurrences.sel(cluster=c).values)})' for c in range(n_clusters)
+             ]
+
+         data_vars = {}
+         for var in resolved_variables:
+             da = aggregated_data[var]
+             if 'cluster' in da.dims:
+                 data_by_cluster = da.values
+             else:
+                 data_by_cluster = da.values.reshape(n_clusters, timesteps_per_cluster)
+             data_vars[var] = xr.DataArray(
+                 data_by_cluster,
+                 dims=['cluster', 'time'],
+                 coords={'cluster': cluster_labels, 'time': range(timesteps_per_cluster)},
+             )
+
+         ds = xr.Dataset(data_vars)
+
+         # Early return for data_only mode (include occurrences in result)
+         if data_only:
+             data_vars['occurrences'] = cluster_occurrences
+             return PlotResult(data=xr.Dataset(data_vars), figure=go.Figure())
+
+         title = 'Clusters' if len(resolved_variables) > 1 else f'Clusters: {resolved_variables[0]}'
+
+         # Apply slot defaults
+         defaults = {
+             'x': 'time',
+             'color': 'variable',
+             'symbol': None,  # Block symbol slot
+         }
+         _apply_slot_defaults(plotly_kwargs, defaults)
+
+         color_kwargs = _build_color_kwargs(colors, list(ds.data_vars))
+         fig = ds.plotly.line(
+             title=title,
+             **color_kwargs,
+             **plotly_kwargs,
+         )
+         fig.update_yaxes(matches=None)
+         fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1]))
+
+         data_vars['occurrences'] = cluster_occurrences
+         result_data = xr.Dataset(data_vars)
+         plot_result = PlotResult(data=result_data, figure=fig)
+
+         if show is None:
+             show = CONFIG.Plotting.default_show
+         if show:
+             plot_result.show()
+
+         return plot_result
+
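A combined usage sketch for the two structure plots above, under the same hypothetical access path as before:

    plot = fs.clustering.plot  # attribute name assumed, not from this diff

    # Timeline of cluster membership for one scenario; 'Viridis' is just an example scale
    plot.heatmap(select={'scenario': 'Base Case'}, colors='Viridis', show=False)

    # One subplot per typical period; occurrence counts ride along in the result data
    res = plot.clusters(show=False)
    print(res.data['occurrences'])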
1737
+
+ # Backwards compatibility alias
+ AggregationResults = Clustering
+
+
+ def _register_clustering_classes():
+     """Register clustering classes for IO."""
+     from ..structure import CLASS_REGISTRY
+
+     CLASS_REGISTRY['Clustering'] = Clustering
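The registration hook follows a plain name-to-class registry pattern. A generic sketch of how such a registry can drive deserialization; the payload layout and the from_dict protocol here are illustrative assumptions, not flixopt's actual IO contract:

    CLASS_REGISTRY: dict[str, type] = {}

    def restore(payload: dict):
        # Look up the concrete class by the name stored alongside the data,
        # then let it rebuild itself (method name assumed for illustration)
        cls = CLASS_REGISTRY[payload['__class__']]
        return cls.from_dict(payload['data'])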