flixopt 3.0.1__py3-none-any.whl → 6.0.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. flixopt/__init__.py +57 -49
  2. flixopt/carrier.py +159 -0
  3. flixopt/clustering/__init__.py +51 -0
  4. flixopt/clustering/base.py +1746 -0
  5. flixopt/clustering/intercluster_helpers.py +201 -0
  6. flixopt/color_processing.py +372 -0
  7. flixopt/comparison.py +819 -0
  8. flixopt/components.py +848 -270
  9. flixopt/config.py +853 -496
  10. flixopt/core.py +111 -98
  11. flixopt/effects.py +294 -284
  12. flixopt/elements.py +484 -223
  13. flixopt/features.py +220 -118
  14. flixopt/flow_system.py +2026 -389
  15. flixopt/interface.py +504 -286
  16. flixopt/io.py +1718 -55
  17. flixopt/linear_converters.py +291 -230
  18. flixopt/modeling.py +304 -181
  19. flixopt/network_app.py +2 -1
  20. flixopt/optimization.py +788 -0
  21. flixopt/optimize_accessor.py +373 -0
  22. flixopt/plot_result.py +143 -0
  23. flixopt/plotting.py +1177 -1034
  24. flixopt/results.py +1331 -372
  25. flixopt/solvers.py +12 -4
  26. flixopt/statistics_accessor.py +2412 -0
  27. flixopt/stats_accessor.py +75 -0
  28. flixopt/structure.py +954 -120
  29. flixopt/topology_accessor.py +676 -0
  30. flixopt/transform_accessor.py +2277 -0
  31. flixopt/types.py +120 -0
  32. flixopt-6.0.0rc7.dist-info/METADATA +290 -0
  33. flixopt-6.0.0rc7.dist-info/RECORD +36 -0
  34. {flixopt-3.0.1.dist-info → flixopt-6.0.0rc7.dist-info}/WHEEL +1 -1
  35. flixopt/aggregation.py +0 -382
  36. flixopt/calculation.py +0 -672
  37. flixopt/commons.py +0 -51
  38. flixopt/utils.py +0 -86
  39. flixopt-3.0.1.dist-info/METADATA +0 -209
  40. flixopt-3.0.1.dist-info/RECORD +0 -26
  41. {flixopt-3.0.1.dist-info → flixopt-6.0.0rc7.dist-info}/licenses/LICENSE +0 -0
  42. {flixopt-3.0.1.dist-info → flixopt-6.0.0rc7.dist-info}/top_level.txt +0 -0
flixopt/transform_accessor.py
@@ -0,0 +1,2277 @@
1
+ """
2
+ Transform accessor for FlowSystem.
3
+
4
+ This module provides the TransformAccessor class that enables
5
+ transformations on FlowSystem like clustering, selection, and resampling.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import logging
11
+ import warnings
12
+ from collections import defaultdict
13
+ from typing import TYPE_CHECKING, Any, Literal
14
+
15
+ import numpy as np
16
+ import pandas as pd
17
+ import xarray as xr
18
+
19
+ from .modeling import _scalar_safe_reduce
20
+ from .structure import EXPAND_DIVIDE, EXPAND_INTERPOLATE, VariableCategory
21
+
22
+ if TYPE_CHECKING:
23
+ from tsam import ClusterConfig, ExtremeConfig, SegmentConfig
24
+
25
+ from .clustering import Clustering
26
+ from .flow_system import FlowSystem
27
+
28
+ logger = logging.getLogger('flixopt')
29
+
30
+
31
+ class TransformAccessor:
32
+ """
33
+ Accessor for transformation methods on FlowSystem.
34
+
35
+ This class provides transformations that create new FlowSystem instances
36
+ with modified structure or data, accessible via `flow_system.transform`.
37
+
38
+ Examples:
39
+ Time series aggregation (8 typical days):
40
+
41
+ >>> reduced_fs = flow_system.transform.cluster(n_clusters=8, cluster_duration='1D')
42
+ >>> reduced_fs.optimize(solver)
43
+ >>> expanded_fs = reduced_fs.transform.expand()
44
+
45
+ Future MGA:
46
+
47
+ >>> mga_fs = flow_system.transform.mga(alternatives=5)
48
+ >>> mga_fs.optimize(solver)
49
+ """
50
+
51
+ def __init__(self, flow_system: FlowSystem) -> None:
52
+ """
53
+ Initialize the accessor with a reference to the FlowSystem.
54
+
55
+ Args:
56
+ flow_system: The FlowSystem to transform.
57
+ """
58
+ self._fs = flow_system
59
+
60
+ @staticmethod
61
+ def _calculate_clustering_weights(ds) -> dict[str, float]:
62
+ """Calculate weights for clustering based on dataset attributes."""
63
+ from collections import Counter
64
+
65
+ import numpy as np
66
+
67
+ groups = [da.attrs.get('clustering_group') for da in ds.data_vars.values() if 'clustering_group' in da.attrs]
68
+ group_counts = Counter(groups)
69
+
70
+ # Calculate weight for each group (1/count)
71
+ group_weights = {group: 1 / count for group, count in group_counts.items()}
72
+
73
+ weights = {}
74
+ variables = ds.variables
75
+ for name in ds.data_vars:
76
+ var_attrs = variables[name].attrs
77
+ clustering_group = var_attrs.get('clustering_group')
78
+ group_weight = group_weights.get(clustering_group)
79
+ if group_weight is not None:
80
+ weights[name] = group_weight
81
+ else:
82
+ weights[name] = var_attrs.get('clustering_weight', 1)
83
+
84
+ if np.all(np.isclose(list(weights.values()), 1, atol=1e-6)):
85
+ logger.debug('All Clustering weights were set to 1')
86
+
87
+ return weights
88
+
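# Illustrative sketch (annotation, not part of the package diff): variables that share a
# 'clustering_group' attribute split a total weight of 1 between them, while ungrouped
# variables fall back to their 'clustering_weight' attribute (default 1). Names below
# are hypothetical.
#
#   >>> ds['HeatDemand(Q)|profile'].attrs['clustering_group'] = 'heat'
#   >>> ds['HeatDemandNight(Q)|profile'].attrs['clustering_group'] = 'heat'
#   >>> ds['Price|profile'].attrs['clustering_weight'] = 2.0
#   >>> TransformAccessor._calculate_clustering_weights(ds)
#   {'HeatDemand(Q)|profile': 0.5, 'HeatDemandNight(Q)|profile': 0.5, 'Price|profile': 2.0}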
89
+ @staticmethod
90
+ def _build_cluster_config_with_weights(
91
+ cluster: ClusterConfig | None,
92
+ auto_weights: dict[str, float],
93
+ ) -> ClusterConfig:
94
+ """Merge auto-calculated weights into ClusterConfig.
95
+
96
+ Args:
97
+ cluster: Optional user-provided ClusterConfig.
98
+ auto_weights: Automatically calculated weights based on data variance.
99
+
100
+ Returns:
101
+ ClusterConfig with weights set (either user-provided or auto-calculated).
102
+ """
103
+ from tsam import ClusterConfig
104
+
105
+ # User provided ClusterConfig with weights - use as-is
106
+ if cluster is not None and cluster.weights is not None:
107
+ return cluster
108
+
109
+ # No ClusterConfig provided - use defaults with auto-calculated weights
110
+ if cluster is None:
111
+ return ClusterConfig(weights=auto_weights)
112
+
113
+ # ClusterConfig provided without weights - add auto-calculated weights
114
+ return ClusterConfig(
115
+ method=cluster.method,
116
+ representation=cluster.representation,
117
+ weights=auto_weights,
118
+ normalize_column_means=cluster.normalize_column_means,
119
+ use_duration_curves=cluster.use_duration_curves,
120
+ include_period_sums=cluster.include_period_sums,
121
+ solver=cluster.solver,
122
+ )
123
+
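# Illustrative precedence (assuming the tsam ClusterConfig import above): explicit user
# weights win, otherwise the auto-calculated weights are merged into a (possibly default)
# ClusterConfig.
#
#   >>> from tsam import ClusterConfig
#   >>> auto = {'demand': 0.5, 'price': 1.0}
#   >>> TransformAccessor._build_cluster_config_with_weights(None, auto).weights
#   {'demand': 0.5, 'price': 1.0}
#   >>> user = ClusterConfig(weights={'demand': 1.0})
#   >>> TransformAccessor._build_cluster_config_with_weights(user, auto) is user
#   True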
124
+ @staticmethod
125
+ def _accuracy_to_dataframe(accuracy) -> pd.DataFrame:
126
+ """Convert tsam AccuracyMetrics to DataFrame.
127
+
128
+ Args:
129
+ accuracy: tsam AccuracyMetrics object.
130
+
131
+ Returns:
132
+ DataFrame with RMSE, MAE, and RMSE_duration columns.
133
+ """
134
+ return pd.DataFrame(
135
+ {
136
+ 'RMSE': accuracy.rmse,
137
+ 'MAE': accuracy.mae,
138
+ 'RMSE_duration': accuracy.rmse_duration,
139
+ }
140
+ )
141
+
142
+ def _build_cluster_weight_da(
143
+ self,
144
+ cluster_occurrences_all: dict[tuple, dict],
145
+ n_clusters: int,
146
+ cluster_coords: np.ndarray,
147
+ periods: list,
148
+ scenarios: list,
149
+ ) -> xr.DataArray:
150
+ """Build cluster_weight DataArray from occurrence counts.
151
+
152
+ Args:
153
+ cluster_occurrences_all: Dict mapping (period, scenario) tuples to
154
+ dicts of {cluster_id: occurrence_count}.
155
+ n_clusters: Number of clusters.
156
+ cluster_coords: Cluster coordinate values.
157
+ periods: List of period labels ([None] if no periods dimension).
158
+ scenarios: List of scenario labels ([None] if no scenarios dimension).
159
+
160
+ Returns:
161
+ DataArray with dims [cluster] or [cluster, period?, scenario?].
162
+ """
163
+
164
+ def _weight_for_key(key: tuple) -> xr.DataArray:
165
+ occurrences = cluster_occurrences_all[key]
166
+ # Missing clusters contribute zero weight (not 1)
167
+ weights = np.array([occurrences.get(c, 0) for c in range(n_clusters)])
168
+ return xr.DataArray(weights, dims=['cluster'], coords={'cluster': cluster_coords})
169
+
170
+ weight_slices = {key: _weight_for_key(key) for key in cluster_occurrences_all}
171
+ return self._combine_slices_to_dataarray_generic(
172
+ weight_slices, ['cluster'], periods, scenarios, 'cluster_weight'
173
+ )
174
+
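# Illustrative sketch: occurrence counts become per-cluster weights, and a cluster that
# never occurs in a given (period, scenario) slice gets weight 0 rather than 1. With
# n_clusters=3 and occurrences {0: 20, 2: 10}, that slice's weights along 'cluster'
# are [20, 0, 10].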
175
+ def _build_typical_das(
176
+ self,
177
+ tsam_aggregation_results: dict[tuple, Any],
178
+ actual_n_clusters: int,
179
+ n_time_points: int,
180
+ cluster_coords: np.ndarray,
181
+ time_coords: pd.DatetimeIndex | pd.RangeIndex,
182
+ is_segmented: bool = False,
183
+ ) -> dict[str, dict[tuple, xr.DataArray]]:
184
+ """Build typical periods DataArrays with (cluster, time) shape.
185
+
186
+ Args:
187
+ tsam_aggregation_results: Dict mapping (period, scenario) to tsam results.
188
+ actual_n_clusters: Number of clusters.
189
+ n_time_points: Number of time points per cluster (timesteps or segments).
190
+ cluster_coords: Cluster coordinate values.
191
+ time_coords: Time coordinate values.
192
+ is_segmented: Whether segmentation was used.
193
+
194
+ Returns:
195
+ Nested dict: {column_name: {(period, scenario): DataArray}}.
196
+ """
197
+ typical_das: dict[str, dict[tuple, xr.DataArray]] = {}
198
+ for key, tsam_result in tsam_aggregation_results.items():
199
+ typical_df = tsam_result.cluster_representatives
200
+ if is_segmented:
201
+ # Segmented data: MultiIndex with cluster as first level
202
+ # Each cluster has exactly n_time_points rows (segments)
203
+ # Extract all data at once using numpy reshape, avoiding slow .loc calls
204
+ columns = typical_df.columns.tolist()
205
+
206
+ # Get all values as numpy array: (n_clusters * n_time_points, n_columns)
207
+ all_values = typical_df.values
208
+
209
+ # Reshape to (n_clusters, n_time_points, n_columns)
210
+ reshaped = all_values.reshape(actual_n_clusters, n_time_points, -1)
211
+
212
+ for col_idx, col in enumerate(columns):
213
+ # reshaped[:, :, col_idx] selects all clusters, all time points, single column
214
+ # Result shape: (n_clusters, n_time_points)
215
+ typical_das.setdefault(col, {})[key] = xr.DataArray(
216
+ reshaped[:, :, col_idx],
217
+ dims=['cluster', 'time'],
218
+ coords={'cluster': cluster_coords, 'time': time_coords},
219
+ )
220
+ else:
221
+ # Non-segmented: flat data that can be reshaped
222
+ for col in typical_df.columns:
223
+ flat_data = typical_df[col].values
224
+ reshaped = flat_data.reshape(actual_n_clusters, n_time_points)
225
+ typical_das.setdefault(col, {})[key] = xr.DataArray(
226
+ reshaped,
227
+ dims=['cluster', 'time'],
228
+ coords={'cluster': cluster_coords, 'time': time_coords},
229
+ )
230
+ return typical_das
231
+
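# Illustrative reshape: tsam returns cluster representatives as a flat table of
# n_clusters * n_time_points rows per column; the code above folds each column into a
# (cluster, time) array.
#
#   >>> import numpy as np
#   >>> flat = np.arange(8 * 24)   # one column: 8 clusters x 24 timesteps, flattened
#   >>> flat.reshape(8, 24).shape  # row i is the typical profile of cluster i
#   (8, 24)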
232
+ def _build_segment_durations_da(
233
+ self,
234
+ tsam_aggregation_results: dict[tuple, Any],
235
+ actual_n_clusters: int,
236
+ n_segments: int,
237
+ cluster_coords: np.ndarray,
238
+ time_coords: pd.RangeIndex,
239
+ dt: float,
240
+ periods: list,
241
+ scenarios: list,
242
+ ) -> xr.DataArray:
243
+ """Build timestep_duration DataArray from segment durations.
244
+
245
+ For segmented systems, each segment represents multiple original timesteps.
246
+ The duration is segment_duration_in_original_timesteps * dt (hours per original timestep).
247
+
248
+ Args:
249
+ tsam_aggregation_results: Dict mapping (period, scenario) to tsam results.
250
+ actual_n_clusters: Number of clusters.
251
+ n_segments: Number of segments per cluster.
252
+ cluster_coords: Cluster coordinate values.
253
+ time_coords: Time coordinate values (RangeIndex for segments).
254
+ dt: Hours per original timestep.
255
+ periods: List of period labels ([None] if no periods dimension).
256
+ scenarios: List of scenario labels ([None] if no scenarios dimension).
257
+
258
+ Returns:
259
+ DataArray with dims [cluster, time] or [cluster, time, period?, scenario?]
260
+ containing duration in hours for each segment.
261
+ """
262
+ segment_duration_slices: dict[tuple, xr.DataArray] = {}
263
+
264
+ for key, tsam_result in tsam_aggregation_results.items():
265
+ # segment_durations is tuple of tuples: ((dur1, dur2, ...), (dur1, dur2, ...), ...)
266
+ # Each inner tuple is durations for one cluster
267
+ seg_durs = tsam_result.segment_durations
268
+
269
+ # Build 2D array (cluster, segment) of durations in hours
270
+ data = np.zeros((actual_n_clusters, n_segments))
271
+ for cluster_id in range(actual_n_clusters):
272
+ cluster_seg_durs = seg_durs[cluster_id]
273
+ for seg_id in range(n_segments):
274
+ # Duration in hours = number of original timesteps * dt
275
+ data[cluster_id, seg_id] = cluster_seg_durs[seg_id] * dt
276
+
277
+ segment_duration_slices[key] = xr.DataArray(
278
+ data,
279
+ dims=['cluster', 'time'],
280
+ coords={'cluster': cluster_coords, 'time': time_coords},
281
+ )
282
+
283
+ return self._combine_slices_to_dataarray_generic(
284
+ segment_duration_slices, ['cluster', 'time'], periods, scenarios, 'timestep_duration'
285
+ )
286
+
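# Illustrative arithmetic: with dt = 1.0 h per original timestep and one cluster's
# segment durations (2, 10, 8, 4) in original timesteps, the stored values are
# (2.0, 10.0, 8.0, 4.0) hours, which sum back to the 24 h cluster length.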
287
+ def _build_clustering_metrics(
288
+ self,
289
+ clustering_metrics_all: dict[tuple, pd.DataFrame],
290
+ periods: list,
291
+ scenarios: list,
292
+ ) -> xr.Dataset:
293
+ """Build clustering metrics Dataset from per-slice DataFrames.
294
+
295
+ Args:
296
+ clustering_metrics_all: Dict mapping (period, scenario) to metric DataFrames.
297
+ periods: List of period labels ([None] if no periods dimension).
298
+ scenarios: List of scenario labels ([None] if no scenarios dimension).
299
+
300
+ Returns:
301
+ Dataset with RMSE, MAE, RMSE_duration metrics.
302
+ """
303
+ non_empty_metrics = {k: v for k, v in clustering_metrics_all.items() if not v.empty}
304
+
305
+ if not non_empty_metrics:
306
+ return xr.Dataset()
307
+
308
+ first_key = (periods[0], scenarios[0])
309
+
310
+ if len(clustering_metrics_all) == 1 and len(non_empty_metrics) == 1:
311
+ metrics_df = non_empty_metrics.get(first_key)
312
+ if metrics_df is None:
313
+ metrics_df = next(iter(non_empty_metrics.values()))
314
+ return xr.Dataset(
315
+ {
316
+ col: xr.DataArray(
317
+ metrics_df[col].values,
318
+ dims=['time_series'],
319
+ coords={'time_series': metrics_df.index},
320
+ )
321
+ for col in metrics_df.columns
322
+ }
323
+ )
324
+
325
+ # Multi-dim case
326
+ sample_df = next(iter(non_empty_metrics.values()))
327
+ metric_names = list(sample_df.columns)
328
+ data_vars = {}
329
+
330
+ for metric in metric_names:
331
+ slices = {}
332
+ for (p, s), df in clustering_metrics_all.items():
333
+ if df.empty:
334
+ slices[(p, s)] = xr.DataArray(
335
+ np.full(len(sample_df.index), np.nan),
336
+ dims=['time_series'],
337
+ coords={'time_series': list(sample_df.index)},
338
+ )
339
+ else:
340
+ slices[(p, s)] = xr.DataArray(
341
+ df[metric].values,
342
+ dims=['time_series'],
343
+ coords={'time_series': list(df.index)},
344
+ )
345
+ data_vars[metric] = self._combine_slices_to_dataarray_generic(
346
+ slices, ['time_series'], periods, scenarios, metric
347
+ )
348
+
349
+ return xr.Dataset(data_vars)
350
+
351
+ def _build_reduced_flow_system(
352
+ self,
353
+ ds: xr.Dataset,
354
+ tsam_aggregation_results: dict[tuple, Any],
355
+ cluster_occurrences_all: dict[tuple, dict],
356
+ clustering_metrics_all: dict[tuple, pd.DataFrame],
357
+ timesteps_per_cluster: int,
358
+ dt: float,
359
+ periods: list,
360
+ scenarios: list,
361
+ n_clusters_requested: int | None = None,
362
+ ) -> FlowSystem:
363
+ """Build a reduced FlowSystem from tsam aggregation results.
364
+
365
+ This is the shared implementation used by both cluster() and apply_clustering().
366
+
367
+ Args:
368
+ ds: Original dataset.
369
+ tsam_aggregation_results: Dict mapping (period, scenario) to tsam AggregationResult.
370
+ cluster_occurrences_all: Dict mapping (period, scenario) to cluster occurrence counts.
371
+ clustering_metrics_all: Dict mapping (period, scenario) to accuracy metrics.
372
+ timesteps_per_cluster: Number of timesteps per cluster.
373
+ dt: Hours per timestep.
374
+ periods: List of period labels ([None] if no periods).
375
+ scenarios: List of scenario labels ([None] if no scenarios).
376
+ n_clusters_requested: Requested number of clusters (for logging). None to skip.
377
+
378
+ Returns:
379
+ Reduced FlowSystem with clustering metadata attached.
380
+ """
381
+ from .clustering import Clustering
382
+ from .core import drop_constant_arrays
383
+ from .flow_system import FlowSystem
384
+
385
+ has_periods = periods != [None]
386
+ has_scenarios = scenarios != [None]
387
+
388
+ # Build dim_names for Clustering
389
+ dim_names = []
390
+ if has_periods:
391
+ dim_names.append('period')
392
+ if has_scenarios:
393
+ dim_names.append('scenario')
394
+
395
+ # Build dict keyed by (period?, scenario?) tuples (without None)
396
+ aggregation_results: dict[tuple, Any] = {}
397
+ for (p, s), result in tsam_aggregation_results.items():
398
+ key_parts = []
399
+ if has_periods:
400
+ key_parts.append(p)
401
+ if has_scenarios:
402
+ key_parts.append(s)
403
+ key = tuple(key_parts)
404
+ aggregation_results[key] = result
405
+
406
+ # Use first result for structure
407
+ first_key = (periods[0], scenarios[0])
408
+ first_tsam = tsam_aggregation_results[first_key]
409
+
410
+ # Build metrics
411
+ clustering_metrics = self._build_clustering_metrics(clustering_metrics_all, periods, scenarios)
412
+
413
+ n_reduced_timesteps = len(first_tsam.cluster_representatives)
414
+ actual_n_clusters = len(first_tsam.cluster_weights)
415
+
416
+ # Create coordinates for the 2D cluster structure
417
+ cluster_coords = np.arange(actual_n_clusters)
418
+
419
+ # Detect if segmentation was used
420
+ is_segmented = first_tsam.n_segments is not None
421
+ n_segments = first_tsam.n_segments if is_segmented else None
422
+
423
+ # Determine time dimension based on segmentation
424
+ if is_segmented:
425
+ n_time_points = n_segments
426
+ time_coords = pd.RangeIndex(n_time_points, name='time')
427
+ else:
428
+ n_time_points = timesteps_per_cluster
429
+ time_coords = pd.date_range(
430
+ start='2000-01-01',
431
+ periods=timesteps_per_cluster,
432
+ freq=pd.Timedelta(hours=dt),
433
+ name='time',
434
+ )
435
+
436
+ # Build cluster_weight
437
+ cluster_weight = self._build_cluster_weight_da(
438
+ cluster_occurrences_all, actual_n_clusters, cluster_coords, periods, scenarios
439
+ )
440
+
441
+ # Logging
442
+ if is_segmented:
443
+ logger.info(
444
+ f'Reduced from {len(self._fs.timesteps)} to {actual_n_clusters} clusters × {n_segments} segments'
445
+ )
446
+ else:
447
+ logger.info(
448
+ f'Reduced from {len(self._fs.timesteps)} to {actual_n_clusters} clusters × {timesteps_per_cluster} timesteps'
449
+ )
450
+
451
+ # Build typical periods DataArrays with (cluster, time) shape
452
+ typical_das = self._build_typical_das(
453
+ tsam_aggregation_results, actual_n_clusters, n_time_points, cluster_coords, time_coords, is_segmented
454
+ )
455
+
456
+ # Build reduced dataset with (cluster, time) dimensions
457
+ ds_new = self._build_reduced_dataset(
458
+ ds,
459
+ typical_das,
460
+ actual_n_clusters,
461
+ n_reduced_timesteps,
462
+ n_time_points,
463
+ cluster_coords,
464
+ time_coords,
465
+ periods,
466
+ scenarios,
467
+ )
468
+
469
+ # For segmented systems, build timestep_duration from segment_durations
470
+ if is_segmented:
471
+ segment_durations = self._build_segment_durations_da(
472
+ tsam_aggregation_results,
473
+ actual_n_clusters,
474
+ n_segments,
475
+ cluster_coords,
476
+ time_coords,
477
+ dt,
478
+ periods,
479
+ scenarios,
480
+ )
481
+ ds_new['timestep_duration'] = segment_durations
482
+
483
+ reduced_fs = FlowSystem.from_dataset(ds_new)
484
+ reduced_fs.cluster_weight = cluster_weight
485
+
486
+ # Remove 'equals_final' from storages - doesn't make sense on reduced timesteps
487
+ for storage in reduced_fs.storages.values():
488
+ ics = storage.initial_charge_state
489
+ if isinstance(ics, str) and ics == 'equals_final':
490
+ storage.initial_charge_state = None
491
+
492
+ # Create Clustering object with full AggregationResult access
493
+ # Only store time-varying data (constant arrays are clutter for plotting)
494
+ reduced_fs.clustering = Clustering(
495
+ original_timesteps=self._fs.timesteps,
496
+ original_data=drop_constant_arrays(ds, dim='time'),
497
+ aggregated_data=drop_constant_arrays(ds_new, dim='time'),
498
+ _metrics=clustering_metrics if clustering_metrics.data_vars else None,
499
+ _aggregation_results=aggregation_results,
500
+ _dim_names=dim_names,
501
+ )
502
+
503
+ return reduced_fs
504
+
505
+ def _build_reduced_dataset(
506
+ self,
507
+ ds: xr.Dataset,
508
+ typical_das: dict[str, dict[tuple, xr.DataArray]],
509
+ actual_n_clusters: int,
510
+ n_reduced_timesteps: int,
511
+ n_time_points: int,
512
+ cluster_coords: np.ndarray,
513
+ time_coords: pd.DatetimeIndex | pd.RangeIndex,
514
+ periods: list,
515
+ scenarios: list,
516
+ ) -> xr.Dataset:
517
+ """Build the reduced dataset with (cluster, time) structure.
518
+
519
+ Args:
520
+ ds: Original dataset.
521
+ typical_das: Typical periods DataArrays from _build_typical_das().
522
+ actual_n_clusters: Number of clusters.
523
+ n_reduced_timesteps: Total reduced timesteps (n_clusters * n_time_points).
524
+ n_time_points: Number of time points per cluster (timesteps or segments).
525
+ cluster_coords: Cluster coordinate values.
526
+ time_coords: Time coordinate values.
527
+ periods: List of period labels.
528
+ scenarios: List of scenario labels.
529
+
530
+ Returns:
531
+ Dataset with reduced timesteps and (cluster, time) structure.
532
+ """
533
+ from .core import TimeSeriesData
534
+
535
+ all_keys = {(p, s) for p in periods for s in scenarios}
536
+ ds_new_vars = {}
537
+
538
+ # Use ds.variables to avoid _construct_dataarray overhead
539
+ variables = ds.variables
540
+ coord_cache = {k: ds.coords[k].values for k in ds.coords}
541
+
542
+ for name in ds.data_vars:
543
+ var = variables[name]
544
+ if 'time' not in var.dims:
545
+ # No time dimension - wrap Variable in DataArray
546
+ coords = {d: coord_cache[d] for d in var.dims if d in coord_cache}
547
+ ds_new_vars[name] = xr.DataArray(var.values, dims=var.dims, coords=coords, attrs=var.attrs, name=name)
548
+ elif name not in typical_das:
549
+ # Time-dependent but constant: reshape to (cluster, time, ...)
550
+ # Use numpy slicing instead of .isel()
551
+ time_idx = var.dims.index('time')
552
+ slices = [slice(None)] * len(var.dims)
553
+ slices[time_idx] = slice(0, n_reduced_timesteps)
554
+ sliced_values = var.values[tuple(slices)]
555
+
556
+ other_dims = [d for d in var.dims if d != 'time']
557
+ other_shape = [var.sizes[d] for d in other_dims]
558
+ new_shape = [actual_n_clusters, n_time_points] + other_shape
559
+ reshaped = sliced_values.reshape(new_shape)
560
+ new_coords = {'cluster': cluster_coords, 'time': time_coords}
561
+ for dim in other_dims:
562
+ if dim in coord_cache:
563
+ new_coords[dim] = coord_cache[dim]
564
+ ds_new_vars[name] = xr.DataArray(
565
+ reshaped,
566
+ dims=['cluster', 'time'] + other_dims,
567
+ coords=new_coords,
568
+ attrs=var.attrs,
569
+ )
570
+ elif set(typical_das[name].keys()) != all_keys:
571
+ # Partial typical slices: fill missing keys with constant values
572
+ time_idx = var.dims.index('time')
573
+ slices_list = [slice(None)] * len(var.dims)
574
+ slices_list[time_idx] = slice(0, n_reduced_timesteps)
575
+ sliced_values = var.values[tuple(slices_list)]
576
+
577
+ other_dims = [d for d in var.dims if d != 'time']
578
+ other_shape = [var.sizes[d] for d in other_dims]
579
+ new_shape = [actual_n_clusters, n_time_points] + other_shape
580
+ reshaped_constant = sliced_values.reshape(new_shape)
581
+
582
+ new_coords = {'cluster': cluster_coords, 'time': time_coords}
583
+ for dim in other_dims:
584
+ if dim in coord_cache:
585
+ new_coords[dim] = coord_cache[dim]
586
+
587
+ # Build filled slices dict: use typical where available, constant otherwise
588
+ filled_slices = {}
589
+ for key in all_keys:
590
+ if key in typical_das[name]:
591
+ filled_slices[key] = typical_das[name][key]
592
+ else:
593
+ filled_slices[key] = xr.DataArray(
594
+ reshaped_constant,
595
+ dims=['cluster', 'time'] + other_dims,
596
+ coords=new_coords,
597
+ )
598
+
599
+ da = self._combine_slices_to_dataarray_2d(
600
+ slices=filled_slices,
601
+ attrs=var.attrs,
602
+ periods=periods,
603
+ scenarios=scenarios,
604
+ )
605
+ if var.attrs.get('__timeseries_data__', False):
606
+ da = TimeSeriesData.from_dataarray(da.assign_attrs(var.attrs))
607
+ ds_new_vars[name] = da
608
+ else:
609
+ # Time-varying: combine per-(period, scenario) slices
610
+ da = self._combine_slices_to_dataarray_2d(
611
+ slices=typical_das[name],
612
+ attrs=var.attrs,
613
+ periods=periods,
614
+ scenarios=scenarios,
615
+ )
616
+ if var.attrs.get('__timeseries_data__', False):
617
+ da = TimeSeriesData.from_dataarray(da.assign_attrs(var.attrs))
618
+ ds_new_vars[name] = da
619
+
620
+ # Copy attrs but remove cluster_weight
621
+ new_attrs = dict(ds.attrs)
622
+ new_attrs.pop('cluster_weight', None)
623
+ return xr.Dataset(ds_new_vars, attrs=new_attrs)
624
+
625
+ def _build_cluster_assignments_da(
626
+ self,
627
+ cluster_assignmentss: dict[tuple, np.ndarray],
628
+ periods: list,
629
+ scenarios: list,
630
+ ) -> xr.DataArray:
631
+ """Build cluster_assignments DataArray from cluster assignments.
632
+
633
+ Args:
634
+ cluster_assignmentss: Dict mapping (period, scenario) to cluster assignment arrays.
635
+ periods: List of period labels ([None] if no periods dimension).
636
+ scenarios: List of scenario labels ([None] if no scenarios dimension).
637
+
638
+ Returns:
639
+ DataArray with dims [original_cluster] or [original_cluster, period?, scenario?].
640
+ """
641
+ has_periods = periods != [None]
642
+ has_scenarios = scenarios != [None]
643
+
644
+ if has_periods or has_scenarios:
645
+ # Multi-dimensional case
646
+ cluster_assignments_slices = {}
647
+ for p in periods:
648
+ for s in scenarios:
649
+ key = (p, s)
650
+ cluster_assignments_slices[key] = xr.DataArray(
651
+ cluster_assignmentss[key], dims=['original_cluster'], name='cluster_assignments'
652
+ )
653
+ return self._combine_slices_to_dataarray_generic(
654
+ cluster_assignments_slices, ['original_cluster'], periods, scenarios, 'cluster_assignments'
655
+ )
656
+ else:
657
+ # Simple case
658
+ first_key = (periods[0], scenarios[0])
659
+ return xr.DataArray(cluster_assignmentss[first_key], dims=['original_cluster'], name='cluster_assignments')
660
+
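# Illustrative content: for 10 original days mapped onto 3 typical days, an assignment
# array might read [0, 0, 2, 1, 1, 0, 2, 2, 1, 0]; original period i is represented by
# cluster assignments[i], and counting each value reproduces the occurrence counts used
# for the cluster weights above.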
661
+ def sel(
662
+ self,
663
+ time: str | slice | list[str] | pd.Timestamp | pd.DatetimeIndex | None = None,
664
+ period: int | slice | list[int] | pd.Index | None = None,
665
+ scenario: str | slice | list[str] | pd.Index | None = None,
666
+ ) -> FlowSystem:
667
+ """
668
+ Select a subset of the FlowSystem by label.
669
+
670
+ Creates a new FlowSystem with data selected along the specified dimensions.
671
+ The returned FlowSystem has no solution (it must be re-optimized).
672
+
673
+ Args:
674
+ time: Time selection (e.g., slice('2023-01-01', '2023-12-31'), '2023-06-15')
675
+ period: Period selection (e.g., slice(2023, 2024), or list of periods)
676
+ scenario: Scenario selection (e.g., 'scenario1', or list of scenarios)
677
+
678
+ Returns:
679
+ FlowSystem: New FlowSystem with selected data (no solution).
680
+
681
+ Examples:
682
+ >>> # Select specific time range
683
+ >>> fs_jan = flow_system.transform.sel(time=slice('2023-01-01', '2023-01-31'))
684
+ >>> fs_jan.optimize(solver)
685
+
686
+ >>> # Select single scenario
687
+ >>> fs_base = flow_system.transform.sel(scenario='Base Case')
688
+ """
689
+ from .flow_system import FlowSystem
690
+
691
+ if time is None and period is None and scenario is None:
692
+ result = self._fs.copy()
693
+ result.solution = None
694
+ return result
695
+
696
+ if not self._fs.connected_and_transformed:
697
+ self._fs.connect_and_transform()
698
+
699
+ ds = self._fs.to_dataset()
700
+ ds = self._dataset_sel(ds, time=time, period=period, scenario=scenario)
701
+ return FlowSystem.from_dataset(ds) # from_dataset doesn't include solution
702
+
703
+ def isel(
704
+ self,
705
+ time: int | slice | list[int] | None = None,
706
+ period: int | slice | list[int] | None = None,
707
+ scenario: int | slice | list[int] | None = None,
708
+ ) -> FlowSystem:
709
+ """
710
+ Select a subset of the FlowSystem by integer indices.
711
+
712
+ Creates a new FlowSystem with data selected along the specified dimensions.
713
+ The returned FlowSystem has no solution (it must be re-optimized).
714
+
715
+ Args:
716
+ time: Time selection by integer index (e.g., slice(0, 100), 50, or [0, 5, 10])
717
+ period: Period selection by integer index
718
+ scenario: Scenario selection by integer index
719
+
720
+ Returns:
721
+ FlowSystem: New FlowSystem with selected data (no solution).
722
+
723
+ Examples:
724
+ >>> # Select first 24 timesteps
725
+ >>> fs_day1 = flow_system.transform.isel(time=slice(0, 24))
726
+ >>> fs_day1.optimize(solver)
727
+
728
+ >>> # Select first scenario
729
+ >>> fs_first = flow_system.transform.isel(scenario=0)
730
+ """
731
+ from .flow_system import FlowSystem
732
+
733
+ if time is None and period is None and scenario is None:
734
+ result = self._fs.copy()
735
+ result.solution = None
736
+ return result
737
+
738
+ if not self._fs.connected_and_transformed:
739
+ self._fs.connect_and_transform()
740
+
741
+ ds = self._fs.to_dataset()
742
+ ds = self._dataset_isel(ds, time=time, period=period, scenario=scenario)
743
+ return FlowSystem.from_dataset(ds) # from_dataset doesn't include solution
744
+
745
+ def resample(
746
+ self,
747
+ time: str,
748
+ method: Literal['mean', 'sum', 'max', 'min', 'first', 'last', 'std', 'var', 'median', 'count'] = 'mean',
749
+ hours_of_last_timestep: int | float | None = None,
750
+ hours_of_previous_timesteps: int | float | np.ndarray | None = None,
751
+ fill_gaps: Literal['ffill', 'bfill', 'interpolate'] | None = None,
752
+ **kwargs: Any,
753
+ ) -> FlowSystem:
754
+ """
755
+ Create a resampled FlowSystem by resampling data along the time dimension.
756
+
757
+ Creates a new FlowSystem with resampled time series data.
758
+ The returned FlowSystem has no solution (it must be re-optimized).
759
+
760
+ Args:
761
+ time: Resampling frequency (e.g., '3h', '2D', '1M')
762
+ method: Resampling method. Recommended: 'mean', 'first', 'last', 'max', 'min'
763
+ hours_of_last_timestep: Duration of the last timestep after resampling.
764
+ If None, computed from the last time interval.
765
+ hours_of_previous_timesteps: Duration of previous timesteps after resampling.
766
+ If None, computed from the first time interval. Can be a scalar or array.
767
+ fill_gaps: Strategy for filling gaps (NaN values) that arise when resampling
768
+ irregular timesteps to regular intervals. Options: 'ffill' (forward fill),
769
+ 'bfill' (backward fill), 'interpolate' (linear interpolation).
770
+ If None (default), raises an error when gaps are detected.
771
+ **kwargs: Additional arguments passed to xarray.resample()
772
+
773
+ Returns:
774
+ FlowSystem: New resampled FlowSystem (no solution).
775
+
776
+ Raises:
777
+ ValueError: If resampling creates gaps and fill_gaps is not specified.
778
+
779
+ Examples:
780
+ >>> # Resample to 4-hour intervals
781
+ >>> fs_4h = flow_system.transform.resample(time='4h', method='mean')
782
+ >>> fs_4h.optimize(solver)
783
+
784
+ >>> # Resample to daily with max values
785
+ >>> fs_daily = flow_system.transform.resample(time='1D', method='max')
786
+ """
787
+ from .flow_system import FlowSystem
788
+
789
+ if not self._fs.connected_and_transformed:
790
+ self._fs.connect_and_transform()
791
+
792
+ ds = self._fs.to_dataset()
793
+ ds = self._dataset_resample(
794
+ ds,
795
+ freq=time,
796
+ method=method,
797
+ hours_of_last_timestep=hours_of_last_timestep,
798
+ hours_of_previous_timesteps=hours_of_previous_timesteps,
799
+ fill_gaps=fill_gaps,
800
+ **kwargs,
801
+ )
802
+ return FlowSystem.from_dataset(ds) # from_dataset doesn't include solution
803
+
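# Illustrative usage (hypothetical flow_system/solver objects): resampling irregular
# timesteps to a regular grid can leave empty bins, which must be filled explicitly.
#
#   >>> fs_hourly = flow_system.transform.resample(time='1h', method='mean', fill_gaps='interpolate')
#   >>> fs_hourly.optimize(solver)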
804
+ # --- Class methods for dataset operations (can be called without instance) ---
805
+
806
+ @classmethod
807
+ def _dataset_sel(
808
+ cls,
809
+ dataset: xr.Dataset,
810
+ time: str | slice | list[str] | pd.Timestamp | pd.DatetimeIndex | None = None,
811
+ period: int | slice | list[int] | pd.Index | None = None,
812
+ scenario: str | slice | list[str] | pd.Index | None = None,
813
+ hours_of_last_timestep: int | float | None = None,
814
+ hours_of_previous_timesteps: int | float | np.ndarray | None = None,
815
+ ) -> xr.Dataset:
816
+ """
817
+ Select subset of dataset by label.
818
+
819
+ Args:
820
+ dataset: xarray Dataset from FlowSystem.to_dataset()
821
+ time: Time selection (e.g., '2020-01', slice('2020-01-01', '2020-06-30'))
822
+ period: Period selection (e.g., 2020, slice(2020, 2022))
823
+ scenario: Scenario selection (e.g., 'Base Case', ['Base Case', 'High Demand'])
824
+ hours_of_last_timestep: Duration of the last timestep.
825
+ hours_of_previous_timesteps: Duration of previous timesteps.
826
+
827
+ Returns:
828
+ xr.Dataset: Selected dataset
829
+ """
830
+ from .flow_system import FlowSystem
831
+
832
+ indexers = {}
833
+ if time is not None:
834
+ indexers['time'] = time
835
+ if period is not None:
836
+ indexers['period'] = period
837
+ if scenario is not None:
838
+ indexers['scenario'] = scenario
839
+
840
+ if not indexers:
841
+ return dataset
842
+
843
+ result = dataset.sel(**indexers)
844
+
845
+ if 'time' in indexers:
846
+ result = FlowSystem._update_time_metadata(result, hours_of_last_timestep, hours_of_previous_timesteps)
847
+
848
+ if 'period' in indexers:
849
+ result = FlowSystem._update_period_metadata(result)
850
+
851
+ if 'scenario' in indexers:
852
+ result = FlowSystem._update_scenario_metadata(result)
853
+
854
+ return result
855
+
856
+ @classmethod
857
+ def _dataset_isel(
858
+ cls,
859
+ dataset: xr.Dataset,
860
+ time: int | slice | list[int] | None = None,
861
+ period: int | slice | list[int] | None = None,
862
+ scenario: int | slice | list[int] | None = None,
863
+ hours_of_last_timestep: int | float | None = None,
864
+ hours_of_previous_timesteps: int | float | np.ndarray | None = None,
865
+ ) -> xr.Dataset:
866
+ """
867
+ Select subset of dataset by integer index.
868
+
869
+ Args:
870
+ dataset: xarray Dataset from FlowSystem.to_dataset()
871
+ time: Time selection by index
872
+ period: Period selection by index
873
+ scenario: Scenario selection by index
874
+ hours_of_last_timestep: Duration of the last timestep.
875
+ hours_of_previous_timesteps: Duration of previous timesteps.
876
+
877
+ Returns:
878
+ xr.Dataset: Selected dataset
879
+ """
880
+ from .flow_system import FlowSystem
881
+
882
+ indexers = {}
883
+ if time is not None:
884
+ indexers['time'] = time
885
+ if period is not None:
886
+ indexers['period'] = period
887
+ if scenario is not None:
888
+ indexers['scenario'] = scenario
889
+
890
+ if not indexers:
891
+ return dataset
892
+
893
+ result = dataset.isel(**indexers)
894
+
895
+ if 'time' in indexers:
896
+ result = FlowSystem._update_time_metadata(result, hours_of_last_timestep, hours_of_previous_timesteps)
897
+
898
+ if 'period' in indexers:
899
+ result = FlowSystem._update_period_metadata(result)
900
+
901
+ if 'scenario' in indexers:
902
+ result = FlowSystem._update_scenario_metadata(result)
903
+
904
+ return result
905
+
906
+ @classmethod
907
+ def _dataset_resample(
908
+ cls,
909
+ dataset: xr.Dataset,
910
+ freq: str,
911
+ method: Literal['mean', 'sum', 'max', 'min', 'first', 'last', 'std', 'var', 'median', 'count'] = 'mean',
912
+ hours_of_last_timestep: int | float | None = None,
913
+ hours_of_previous_timesteps: int | float | np.ndarray | None = None,
914
+ fill_gaps: Literal['ffill', 'bfill', 'interpolate'] | None = None,
915
+ **kwargs: Any,
916
+ ) -> xr.Dataset:
917
+ """
918
+ Resample dataset along time dimension.
919
+
920
+ Args:
921
+ dataset: xarray Dataset from FlowSystem.to_dataset()
922
+ freq: Resampling frequency (e.g., '2h', '1D', '1M')
923
+ method: Resampling method (e.g., 'mean', 'sum', 'first')
924
+ hours_of_last_timestep: Duration of the last timestep after resampling.
925
+ hours_of_previous_timesteps: Duration of previous timesteps after resampling.
926
+ fill_gaps: Strategy for filling gaps (NaN values) that arise when resampling
927
+ irregular timesteps to regular intervals. Options: 'ffill' (forward fill),
928
+ 'bfill' (backward fill), 'interpolate' (linear interpolation).
929
+ If None (default), raises an error when gaps are detected.
930
+ **kwargs: Additional arguments passed to xarray.resample()
931
+
932
+ Returns:
933
+ xr.Dataset: Resampled dataset
934
+
935
+ Raises:
936
+ ValueError: If resampling creates gaps and fill_gaps is not specified.
937
+ """
938
+ from .flow_system import FlowSystem
939
+
940
+ available_methods = ['mean', 'sum', 'max', 'min', 'first', 'last', 'std', 'var', 'median', 'count']
941
+ if method not in available_methods:
942
+ raise ValueError(f'Unsupported resampling method: {method}. Available: {available_methods}')
943
+
944
+ original_attrs = dict(dataset.attrs)
945
+
946
+ time_var_names = [v for v in dataset.data_vars if 'time' in dataset[v].dims]
947
+ non_time_var_names = [v for v in dataset.data_vars if v not in time_var_names]
948
+
949
+ # Handle case where no data variables have time dimension (all scalars)
950
+ # We still need to resample the time coordinate itself
951
+ if not time_var_names:
952
+ if 'time' not in dataset.coords:
953
+ raise ValueError('Dataset has no time dimension to resample')
954
+ # Create a dummy variable to resample the time coordinate
955
+ dummy = xr.DataArray(
956
+ np.zeros(len(dataset.coords['time'])), dims=['time'], coords={'time': dataset.coords['time']}
957
+ )
958
+ dummy_ds = xr.Dataset({'__dummy__': dummy})
959
+ resampled_dummy = getattr(dummy_ds.resample(time=freq, **kwargs), method)()
960
+ # Get the resampled time coordinate
961
+ resampled_time = resampled_dummy.coords['time']
962
+ # Create result with all original scalar data and resampled time coordinate
963
+ # Keep all existing coordinates (period, scenario, etc.) except time which gets resampled
964
+ result = dataset.copy()
965
+ result = result.assign_coords(time=resampled_time)
966
+ result.attrs.update(original_attrs)
967
+ return FlowSystem._update_time_metadata(result, hours_of_last_timestep, hours_of_previous_timesteps)
968
+
969
+ time_dataset = dataset[time_var_names]
970
+ resampled_time_dataset = cls._resample_by_dimension_groups(time_dataset, freq, method, **kwargs)
971
+
972
+ # Handle NaN values that may arise from resampling irregular timesteps to regular intervals.
973
+ # When irregular data (e.g., [00:00, 01:00, 03:00]) is resampled to regular intervals (e.g., '1h'),
974
+ # bins without data (e.g., 02:00) get NaN.
975
+ if resampled_time_dataset.isnull().any().to_array().any():
976
+ if fill_gaps is None:
977
+ # Find which variables have NaN values for a helpful error message
978
+ vars_with_nans = [
979
+ name for name in resampled_time_dataset.data_vars if resampled_time_dataset[name].isnull().any()
980
+ ]
981
+ raise ValueError(
982
+ f'Resampling created gaps (NaN values) in variables: {vars_with_nans}. '
983
+ f'This typically happens when resampling irregular timesteps to regular intervals. '
984
+ f"Specify fill_gaps='ffill', 'bfill', or 'interpolate' to handle gaps, "
985
+ f'or resample to a coarser frequency.'
986
+ )
987
+ elif fill_gaps == 'ffill':
988
+ resampled_time_dataset = resampled_time_dataset.ffill(dim='time').bfill(dim='time')
989
+ elif fill_gaps == 'bfill':
990
+ resampled_time_dataset = resampled_time_dataset.bfill(dim='time').ffill(dim='time')
991
+ elif fill_gaps == 'interpolate':
992
+ resampled_time_dataset = resampled_time_dataset.interpolate_na(dim='time', method='linear')
993
+ # Handle edges that can't be interpolated
994
+ resampled_time_dataset = resampled_time_dataset.ffill(dim='time').bfill(dim='time')
995
+
996
+ if non_time_var_names:
997
+ non_time_dataset = dataset[non_time_var_names]
998
+ result = xr.merge([resampled_time_dataset, non_time_dataset])
999
+ else:
1000
+ result = resampled_time_dataset
1001
+
1002
+ # Preserve all original coordinates that aren't 'time' (e.g., period, scenario, cluster)
1003
+ # These may be lost during merge if no data variable uses them
1004
+ for coord_name, coord_val in dataset.coords.items():
1005
+ if coord_name != 'time' and coord_name not in result.coords:
1006
+ result = result.assign_coords({coord_name: coord_val})
1007
+
1008
+ result.attrs.update(original_attrs)
1009
+ return FlowSystem._update_time_metadata(result, hours_of_last_timestep, hours_of_previous_timesteps)
1010
+
1011
+ @staticmethod
1012
+ def _resample_by_dimension_groups(
1013
+ time_dataset: xr.Dataset,
1014
+ time: str,
1015
+ method: str,
1016
+ **kwargs: Any,
1017
+ ) -> xr.Dataset:
1018
+ """
1019
+ Resample variables grouped by their dimension structure to avoid broadcasting.
1020
+
1021
+ Groups variables by their non-time dimensions before resampling for performance
1022
+ and to prevent xarray from broadcasting variables with different dimensions.
1023
+
1024
+ Args:
1025
+ time_dataset: Dataset containing only variables with time dimension
1026
+ time: Resampling frequency (e.g., '2h', '1D', '1M')
1027
+ method: Resampling method name (e.g., 'mean', 'sum', 'first')
1028
+ **kwargs: Additional arguments passed to xarray.resample()
1029
+
1030
+ Returns:
1031
+ Resampled dataset with original dimension structure preserved
1032
+ """
1033
+ dim_groups = defaultdict(list)
1034
+ variables = time_dataset.variables
1035
+ for var_name in time_dataset.data_vars:
1036
+ dims_key = tuple(sorted(d for d in variables[var_name].dims if d != 'time'))
1037
+ dim_groups[dims_key].append(var_name)
1038
+
1039
+ # Note: check the length explicitly to guard against an empty mapping (an empty defaultdict is falsy, like dict)
1040
+ if len(dim_groups) == 0:
1041
+ return getattr(time_dataset.resample(time=time, **kwargs), method)()
1042
+
1043
+ resampled_groups = []
1044
+ for var_names in dim_groups.values():
1045
+ if not var_names:
1046
+ continue
1047
+
1048
+ stacked = xr.concat(
1049
+ [time_dataset[name] for name in var_names],
1050
+ dim=pd.Index(var_names, name='variable'),
1051
+ combine_attrs='drop_conflicts',
1052
+ )
1053
+ resampled = getattr(stacked.resample(time=time, **kwargs), method)()
1054
+ resampled_dataset = resampled.to_dataset(dim='variable')
1055
+ resampled_groups.append(resampled_dataset)
1056
+
1057
+ if not resampled_groups:
1058
+ # No data variables to resample, but still resample coordinates
1059
+ return getattr(time_dataset.resample(time=time, **kwargs), method)()
1060
+
1061
+ if len(resampled_groups) == 1:
1062
+ return resampled_groups[0]
1063
+
1064
+ return xr.merge(resampled_groups, combine_attrs='drop_conflicts')
1065
+
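# Illustrative grouping: variables are bucketed by their non-time dimensions before
# resampling, so a (time,)-shaped profile and a (time, scenario)-shaped profile are
# resampled in separate groups and merged afterwards, instead of being broadcast
# against each other's extra dimensions.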
1066
+ def fix_sizes(
1067
+ self,
1068
+ sizes: xr.Dataset | dict[str, float] | None = None,
1069
+ decimal_rounding: int | None = 5,
1070
+ ) -> FlowSystem:
1071
+ """
1072
+ Create a new FlowSystem with investment sizes fixed to specified values.
1073
+
1074
+ This is useful for two-stage optimization workflows:
1075
+ 1. Solve a sizing problem (possibly resampled for speed)
1076
+ 2. Fix sizes and solve dispatch at full resolution
1077
+
1078
+ The returned FlowSystem has InvestParameters with fixed_size set,
1079
+ making those sizes mandatory rather than decision variables.
1080
+
1081
+ Args:
1082
+ sizes: The sizes to fix. Can be:
1083
+ - None: Uses sizes from this FlowSystem's solution (must be solved)
1084
+ - xr.Dataset: Dataset with size variables (e.g., from statistics.sizes)
1085
+ - dict: Mapping of component names to sizes (e.g., {'Boiler(Q_fu)': 100})
1086
+ decimal_rounding: Number of decimal places to round sizes to.
1087
+ Rounding helps avoid numerical infeasibility. Set to None to disable.
1088
+
1089
+ Returns:
1090
+ FlowSystem: New FlowSystem with fixed sizes (no solution).
1091
+
1092
+ Raises:
1093
+ ValueError: If no sizes provided and FlowSystem has no solution.
1094
+ KeyError: If a specified size doesn't match any InvestParameters.
1095
+
1096
+ Examples:
1097
+ Two-stage optimization:
1098
+
1099
+ >>> # Stage 1: Size with resampled data
1100
+ >>> fs_sizing = flow_system.transform.resample('2h')
1101
+ >>> fs_sizing.optimize(solver)
1102
+ >>>
1103
+ >>> # Stage 2: Fix sizes and optimize at full resolution
1104
+ >>> fs_dispatch = flow_system.transform.fix_sizes(fs_sizing.statistics.sizes)
1105
+ >>> fs_dispatch.optimize(solver)
1106
+
1107
+ Using a dict:
1108
+
1109
+ >>> fs_fixed = flow_system.transform.fix_sizes(
1110
+ ... {
1111
+ ... 'Boiler(Q_fu)': 100,
1112
+ ... 'Storage': 500,
1113
+ ... }
1114
+ ... )
1115
+ >>> fs_fixed.optimize(solver)
1116
+ """
1117
+ from .flow_system import FlowSystem
1118
+ from .interface import InvestParameters
1119
+
1120
+ # Get sizes from solution if not provided
1121
+ if sizes is None:
1122
+ if self._fs.solution is None:
1123
+ raise ValueError(
1124
+ 'No sizes provided and FlowSystem has no solution. '
1125
+ 'Either provide sizes or optimize the FlowSystem first.'
1126
+ )
1127
+ sizes = self._fs.statistics.sizes
1128
+
1129
+ # Convert dict to Dataset format
1130
+ if isinstance(sizes, dict):
1131
+ sizes = xr.Dataset({k: xr.DataArray(v) for k, v in sizes.items()})
1132
+
1133
+ # Apply rounding
1134
+ if decimal_rounding is not None:
1135
+ sizes = sizes.round(decimal_rounding)
1136
+
1137
+ # Create copy of FlowSystem
1138
+ if not self._fs.connected_and_transformed:
1139
+ self._fs.connect_and_transform()
1140
+
1141
+ ds = self._fs.to_dataset()
1142
+ new_fs = FlowSystem.from_dataset(ds)
1143
+
1144
+ # Fix sizes in the new FlowSystem's InvestParameters
1145
+ # Note: statistics.sizes returns keys without '|size' suffix (e.g., 'Boiler(Q_fu)')
1146
+ # but dicts may have either format
1147
+ for size_var in sizes.data_vars:
1148
+ # Normalize: strip '|size' suffix if present
1149
+ base_name = size_var.replace('|size', '') if size_var.endswith('|size') else size_var
1150
+ fixed_value = float(sizes[size_var].item())
1151
+
1152
+ # Find matching element with InvestParameters
1153
+ found = False
1154
+
1155
+ # Check flows
1156
+ for flow in new_fs.flows.values():
1157
+ if flow.label_full == base_name and isinstance(flow.size, InvestParameters):
1158
+ flow.size.fixed_size = fixed_value
1159
+ flow.size.mandatory = True
1160
+ found = True
1161
+ logger.debug(f'Fixed size of {base_name} to {fixed_value}')
1162
+ break
1163
+
1164
+ # Check storage capacity
1165
+ if not found:
1166
+ for component in new_fs.components.values():
1167
+ if hasattr(component, 'capacity_in_flow_hours'):
1168
+ if component.label == base_name and isinstance(
1169
+ component.capacity_in_flow_hours, InvestParameters
1170
+ ):
1171
+ component.capacity_in_flow_hours.fixed_size = fixed_value
1172
+ component.capacity_in_flow_hours.mandatory = True
1173
+ found = True
1174
+ logger.debug(f'Fixed size of {base_name} to {fixed_value}')
1175
+ break
1176
+
1177
+ if not found:
1178
+ logger.warning(
1179
+ f'Size variable "{base_name}" not found as InvestParameters in FlowSystem. '
1180
+ f'It may be a fixed-size component or the name may not match.'
1181
+ )
1182
+
1183
+ return new_fs
1184
+
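# Illustrative normalization (hypothetical names): both key styles resolve to the same
# flow, since a trailing '|size' suffix is stripped before matching.
#
#   >>> flow_system.transform.fix_sizes({'Boiler(Q_fu)|size': 100})
#   >>> flow_system.transform.fix_sizes({'Boiler(Q_fu)': 100})  # equivalent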
1185
+ def clustering_data(
1186
+ self,
1187
+ period: Any | None = None,
1188
+ scenario: Any | None = None,
1189
+ ) -> xr.Dataset:
1190
+ """
1191
+ Get the time-varying data that would be used for clustering.
1192
+
1193
+ This method extracts only the data arrays that vary over time, which is
1194
+ the data that clustering algorithms use to identify typical periods.
1195
+ Constant arrays (same value for all timesteps) are excluded since they
1196
+ don't contribute to pattern identification.
1197
+
1198
+ Use this to inspect or pre-process the data before clustering, or to
1199
+ understand which variables influence the clustering result.
1200
+
1201
+ Args:
1202
+ period: Optional period label to select. If None and the FlowSystem
1203
+ has multiple periods, returns data for all periods.
1204
+ scenario: Optional scenario label to select. If None and the FlowSystem
1205
+ has multiple scenarios, returns data for all scenarios.
1206
+
1207
+ Returns:
1208
+ xr.Dataset containing only time-varying data arrays. The dataset
1209
+ includes arrays like demand profiles, price profiles, and other
1210
+ time series that vary over the time dimension.
1211
+
1212
+ Examples:
1213
+ Inspect clustering input data:
1214
+
1215
+ >>> data = flow_system.transform.clustering_data()
1216
+ >>> print(f'Variables used for clustering: {list(data.data_vars)}')
1217
+ >>> data['HeatDemand(Q)|fixed_relative_profile'].plot()
1218
+
1219
+ Get data for a specific period/scenario:
1220
+
1221
+ >>> data_2024 = flow_system.transform.clustering_data(period=2024)
1222
+ >>> data_high = flow_system.transform.clustering_data(scenario='high')
1223
+
1224
+ Convert to DataFrame for external tools:
1225
+
1226
+ >>> df = flow_system.transform.clustering_data().to_dataframe()
1227
+ """
1228
+ from .core import drop_constant_arrays
1229
+
1230
+ if not self._fs.connected_and_transformed:
1231
+ self._fs.connect_and_transform()
1232
+
1233
+ ds = self._fs.to_dataset(include_solution=False)
1234
+
1235
+ # Build selector for period/scenario
1236
+ selector = {}
1237
+ if period is not None:
1238
+ selector['period'] = period
1239
+ if scenario is not None:
1240
+ selector['scenario'] = scenario
1241
+
1242
+ # Apply selection if specified
1243
+ if selector:
1244
+ ds = ds.sel(**selector, drop=True)
1245
+
1246
+ # Filter to only time-varying arrays
1247
+ result = drop_constant_arrays(ds, dim='time')
1248
+
1249
+ # Guard against empty dataset (all variables are constant)
1250
+ if not result.data_vars:
1251
+ selector_info = f' for {selector}' if selector else ''
1252
+ raise ValueError(
1253
+ f'No time-varying data found{selector_info}. '
1254
+ f'All variables are constant over time. Check your period/scenario filter or input data.'
1255
+ )
1256
+
1257
+ # Remove attrs for cleaner output
1258
+ result.attrs = {}
1259
+ for var in result.data_vars:
1260
+ result[var].attrs = {}
1261
+
1262
+ return result
1263
+
1264
+ def cluster(
1265
+ self,
1266
+ n_clusters: int,
1267
+ cluster_duration: str | float,
1268
+ data_vars: list[str] | None = None,
1269
+ cluster: ClusterConfig | None = None,
1270
+ extremes: ExtremeConfig | None = None,
1271
+ segments: SegmentConfig | None = None,
1272
+ preserve_column_means: bool = True,
1273
+ rescale_exclude_columns: list[str] | None = None,
1274
+ round_decimals: int | None = None,
1275
+ numerical_tolerance: float = 1e-13,
1276
+ **tsam_kwargs: Any,
1277
+ ) -> FlowSystem:
1278
+ """
1279
+ Create a FlowSystem with reduced timesteps using typical clusters.
1280
+
1281
+ This method creates a new FlowSystem optimized for sizing studies by reducing
1282
+ the number of timesteps to only the typical (representative) clusters identified
1283
+ through time series aggregation using the tsam package.
1284
+
1285
+ The method:
1286
+ 1. Performs time series clustering using tsam (hierarchical by default)
1287
+ 2. Extracts only the typical clusters (not all original timesteps)
1288
+ 3. Applies timestep weighting for accurate cost representation
1289
+ 4. Handles storage states between clusters based on each Storage's ``cluster_mode``
1290
+
1291
+ Use this for initial sizing optimization, then use ``fix_sizes()`` to re-optimize
1292
+ at full resolution for accurate dispatch results.
1293
+
1294
+ To reuse an existing clustering on different data, use ``apply_clustering()`` instead.
1295
+
1296
+ Args:
1297
+ n_clusters: Number of clusters (typical periods) to extract (e.g., 8 typical days).
1298
+ cluster_duration: Duration of each cluster. Can be a pandas-style string
1299
+ ('1D', '24h', '6h') or a numeric value in hours.
1300
+ data_vars: Optional list of variable names to use for clustering. If specified,
1301
+ only these variables are used to determine cluster assignments, but the
1302
+ clustering is then applied to ALL time-varying data in the FlowSystem.
1303
+ Use ``transform.clustering_data()`` to see available variables.
1304
+ Example: ``data_vars=['HeatDemand(Q)|fixed_relative_profile']`` to cluster
1305
+ based only on heat demand patterns.
1306
+ cluster: Optional tsam ``ClusterConfig`` object specifying clustering algorithm,
1307
+ representation method, and weights. If None, uses default settings (hierarchical
1308
+ clustering with medoid representation) and automatically calculated weights
1309
+ based on data variance.
1310
+ extremes: Optional tsam ``ExtremeConfig`` object specifying how to handle
1311
+ extreme periods (peaks). Use this to ensure peak demand days are captured.
1312
+ Example: ``ExtremeConfig(method='new_cluster', max_value=['demand'])``.
1313
+ segments: Optional tsam ``SegmentConfig`` object specifying intra-period
1314
+ segmentation. Segments divide each cluster period into variable-duration
1315
+ sub-segments. Example: ``SegmentConfig(n_segments=4)``.
1316
+ preserve_column_means: Rescale typical periods so each column's weighted mean
1317
+ matches the original data's mean. Ensures total energy/load is preserved
1318
+ when weights represent occurrence counts. Default is True.
1319
+ rescale_exclude_columns: Column names to exclude from rescaling when
1320
+ ``preserve_column_means=True``. Useful for binary/indicator columns (0/1 values)
1321
+ that should not be rescaled.
1322
+ round_decimals: Round output values to this many decimal places.
1323
+ If None (default), no rounding is applied.
1324
+ numerical_tolerance: Tolerance for numerical precision issues. Controls when
1325
+ warnings are raised for aggregated values exceeding original time series bounds.
1326
+ Default is 1e-13.
1327
+ **tsam_kwargs: Additional keyword arguments passed to ``tsam.aggregate()``
1328
+ for forward compatibility. See tsam documentation for all options.
1329
+
1330
+ Returns:
1331
+ A new FlowSystem with reduced timesteps (only typical clusters).
1332
+ The FlowSystem has metadata stored in ``clustering`` for expansion.
1333
+
1334
+ Raises:
1335
+ ValueError: If timestep sizes are inconsistent.
1336
+ ValueError: If cluster_duration is not a multiple of timestep size.
1337
+
1338
+ Examples:
1339
+ Basic clustering with peak preservation:
1340
+
1341
+ >>> from tsam import ExtremeConfig
1342
+ >>> fs_clustered = flow_system.transform.cluster(
1343
+ ... n_clusters=8,
1344
+ ... cluster_duration='1D',
1345
+ ... extremes=ExtremeConfig(
1346
+ ... method='new_cluster',
1347
+ ... max_value=['HeatDemand(Q_th)|fixed_relative_profile'],
1348
+ ... ),
1349
+ ... )
1350
+ >>> fs_clustered.optimize(solver)
1351
+
1352
+ Clustering based on specific variables only:
1353
+
1354
+ >>> # See available variables for clustering
1355
+ >>> print(flow_system.transform.clustering_data().data_vars)
1356
+ >>>
1357
+ >>> # Cluster based only on demand profile
1358
+ >>> fs_clustered = flow_system.transform.cluster(
1359
+ ... n_clusters=8,
1360
+ ... cluster_duration='1D',
1361
+ ... data_vars=['HeatDemand(Q)|fixed_relative_profile'],
1362
+ ... )
1363
+
1364
+ Note:
1365
+ - This is best suited for initial sizing, not final dispatch optimization
1366
+ - Use ``extremes`` to ensure peak demand clusters are captured
1367
+ - A 5-10% safety margin on sizes is recommended for the dispatch stage
1368
+ - For seasonal storage (e.g., hydrogen, thermal storage), set
1369
+ ``Storage.cluster_mode='intercluster'`` or ``'intercluster_cyclic'``
1370
+ """
1371
+ import tsam
1372
+
1373
+ from .clustering import ClusteringResults
1374
+ from .core import drop_constant_arrays
1375
+
1376
+ # Parse cluster_duration to hours
1377
+ hours_per_cluster = (
1378
+ pd.Timedelta(cluster_duration).total_seconds() / 3600
1379
+ if isinstance(cluster_duration, str)
1380
+ else float(cluster_duration)
1381
+ )
1382
+
1383
+ # Validation
1384
+ dt = float(self._fs.timestep_duration.min().item())
1385
+ if not np.isclose(dt, float(self._fs.timestep_duration.max().item())):
1386
+ raise ValueError(
1387
+ f'cluster() requires uniform timestep sizes, got min={dt}h, '
1388
+ f'max={float(self._fs.timestep_duration.max().item())}h.'
1389
+ )
1390
+ if not np.isclose(hours_per_cluster / dt, round(hours_per_cluster / dt), atol=1e-9):
1391
+ raise ValueError(f'cluster_duration={hours_per_cluster}h must be a multiple of timestep size ({dt}h).')
1392
+
1393
+ timesteps_per_cluster = int(round(hours_per_cluster / dt))
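# Illustrative arithmetic: cluster_duration='1D' parses to 24.0 h; with dt = 0.25 h
# (15-minute timesteps) this gives timesteps_per_cluster = 96, whereas dt = 0.7 h
# would be rejected because 24 / 0.7 is not an integer.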
1394
+ has_periods = self._fs.periods is not None
1395
+ has_scenarios = self._fs.scenarios is not None
1396
+
1397
+ # Determine iteration dimensions
1398
+ periods = list(self._fs.periods) if has_periods else [None]
1399
+ scenarios = list(self._fs.scenarios) if has_scenarios else [None]
1400
+
1401
+ ds = self._fs.to_dataset(include_solution=False)
1402
+
1403
+ # Validate and prepare data_vars for clustering
1404
+ if data_vars is not None:
1405
+ missing = set(data_vars) - set(ds.data_vars)
1406
+ if missing:
1407
+ raise ValueError(
1408
+ f'data_vars not found in FlowSystem: {missing}. '
1409
+ f'Available time-varying variables can be found via transform.clustering_data().'
1410
+ )
1411
+ ds_for_clustering = ds[list(data_vars)]
1412
+ else:
1413
+ ds_for_clustering = ds
1414
+
1415
+ # Validate tsam_kwargs doesn't override explicit parameters
1416
+ reserved_tsam_keys = {
1417
+ 'n_clusters',
1418
+ 'period_duration', # exposed as cluster_duration
1419
+ 'timestep_duration', # computed automatically
1420
+ 'cluster',
1421
+ 'segments',
1422
+ 'extremes',
1423
+ 'preserve_column_means',
1424
+ 'rescale_exclude_columns',
1425
+ 'round_decimals',
1426
+ 'numerical_tolerance',
1427
+ }
1428
+ conflicts = reserved_tsam_keys & set(tsam_kwargs.keys())
1429
+ if conflicts:
1430
+ raise ValueError(
1431
+ f'Cannot override explicit parameters via tsam_kwargs: {conflicts}. '
1432
+ f'Use the corresponding cluster() parameters instead.'
1433
+ )
1434
+
1435
+ # Cluster each (period, scenario) combination using tsam directly
1436
+ tsam_aggregation_results: dict[tuple, Any] = {} # AggregationResult objects
1437
+ tsam_clustering_results: dict[tuple, Any] = {} # ClusteringResult objects for persistence
1438
+         cluster_assignments_all: dict[tuple, np.ndarray] = {}
1439
+ cluster_occurrences_all: dict[tuple, dict] = {}
1440
+
1441
+ # Collect metrics per (period, scenario) slice
1442
+ clustering_metrics_all: dict[tuple, pd.DataFrame] = {}
1443
+
1444
+ for period_label in periods:
1445
+ for scenario_label in scenarios:
1446
+ key = (period_label, scenario_label)
1447
+ selector = {k: v for k, v in [('period', period_label), ('scenario', scenario_label)] if v is not None}
1448
+
1449
+ # Select data for clustering (may be subset if data_vars specified)
1450
+ ds_slice_for_clustering = (
1451
+ ds_for_clustering.sel(**selector, drop=True) if selector else ds_for_clustering
1452
+ )
1453
+                 temporally_changing_ds_for_clustering = drop_constant_arrays(ds_slice_for_clustering, dim='time')
1454
+
1455
+ # Guard against empty dataset after removing constant arrays
1456
+                 if not temporally_changing_ds_for_clustering.data_vars:
1457
+ filter_info = f'data_vars={data_vars}' if data_vars else 'all variables'
1458
+ selector_info = f', selector={selector}' if selector else ''
1459
+ raise ValueError(
1460
+ f'No time-varying data found for clustering ({filter_info}{selector_info}). '
1461
+ f'All variables are constant over time. Check your data_vars filter or input data.'
1462
+ )
1463
+
1464
+                 df_for_clustering = temporally_changing_ds_for_clustering.to_dataframe()
1465
+
1466
+ if selector:
1467
+ logger.info(f'Clustering {", ".join(f"{k}={v}" for k, v in selector.items())}...')
1468
+
1469
+ # Suppress tsam warning about minimal value constraints (informational, not actionable)
1470
+ with warnings.catch_warnings():
1471
+ warnings.filterwarnings('ignore', category=UserWarning, message='.*minimal value.*exceeds.*')
1472
+
1473
+ # Build ClusterConfig with auto-calculated weights
1474
+                     clustering_weights = self._calculate_clustering_weights(temporally_changing_ds_for_clustering)
1475
+ filtered_weights = {
1476
+ name: w for name, w in clustering_weights.items() if name in df_for_clustering.columns
1477
+ }
1478
+ cluster_config = self._build_cluster_config_with_weights(cluster, filtered_weights)
1479
+
1480
+ # Perform clustering based on selected data_vars (or all if not specified)
1481
+ tsam_result = tsam.aggregate(
1482
+ df_for_clustering,
1483
+ n_clusters=n_clusters,
1484
+ period_duration=hours_per_cluster,
1485
+ timestep_duration=dt,
1486
+ cluster=cluster_config,
1487
+ extremes=extremes,
1488
+ segments=segments,
1489
+ preserve_column_means=preserve_column_means,
1490
+ rescale_exclude_columns=rescale_exclude_columns,
1491
+ round_decimals=round_decimals,
1492
+ numerical_tolerance=numerical_tolerance,
1493
+ **tsam_kwargs,
1494
+ )
1495
+
1496
+ tsam_aggregation_results[key] = tsam_result
1497
+ tsam_clustering_results[key] = tsam_result.clustering
1498
+                 cluster_assignments_all[key] = tsam_result.cluster_assignments
1499
+ cluster_occurrences_all[key] = tsam_result.cluster_weights
1500
+ try:
1501
+ clustering_metrics_all[key] = self._accuracy_to_dataframe(tsam_result.accuracy)
1502
+ except Exception as e:
1503
+ logger.warning(f'Failed to compute clustering metrics for {key}: {e}')
1504
+ clustering_metrics_all[key] = pd.DataFrame()
1505
+
1506
+ # If data_vars was specified, apply clustering to FULL data
1507
+ if data_vars is not None:
1508
+ # Build dim_names for ClusteringResults
1509
+ dim_names = []
1510
+ if has_periods:
1511
+ dim_names.append('period')
1512
+ if has_scenarios:
1513
+ dim_names.append('scenario')
1514
+
1515
+ # Convert (period, scenario) keys to ClusteringResults format
1516
+ def to_cr_key(p, s):
1517
+ key_parts = []
1518
+ if has_periods:
1519
+ key_parts.append(p)
1520
+ if has_scenarios:
1521
+ key_parts.append(s)
1522
+ return tuple(key_parts)
1523
+
1524
+ # Build ClusteringResults from subset clustering
1525
+ clustering_results = ClusteringResults(
1526
+ {to_cr_key(p, s): cr for (p, s), cr in tsam_clustering_results.items()},
1527
+ dim_names,
1528
+ )
1529
+
1530
+ # Apply to full data - this returns AggregationResults
1531
+ agg_results = clustering_results.apply(ds)
1532
+
1533
+ # Update tsam_aggregation_results with full data results
1534
+ for cr_key, result in agg_results:
1535
+ # Convert back to (period, scenario) format
1536
+ if has_periods and has_scenarios:
1537
+ full_key = (cr_key[0], cr_key[1])
1538
+ elif has_periods:
1539
+ full_key = (cr_key[0], None)
1540
+ elif has_scenarios:
1541
+ full_key = (None, cr_key[0])
1542
+ else:
1543
+ full_key = (None, None)
1544
+ tsam_aggregation_results[full_key] = result
1545
+ cluster_occurrences_all[full_key] = result.cluster_weights
1546
+
1547
+ # Build and return the reduced FlowSystem
1548
+ return self._build_reduced_flow_system(
1549
+ ds=ds,
1550
+ tsam_aggregation_results=tsam_aggregation_results,
1551
+ cluster_occurrences_all=cluster_occurrences_all,
1552
+ clustering_metrics_all=clustering_metrics_all,
1553
+ timesteps_per_cluster=timesteps_per_cluster,
1554
+ dt=dt,
1555
+ periods=periods,
1556
+ scenarios=scenarios,
1557
+ n_clusters_requested=n_clusters,
1558
+ )
1559
+
1560
+ def apply_clustering(
1561
+ self,
1562
+ clustering: Clustering,
1563
+ ) -> FlowSystem:
1564
+ """
1565
+ Apply an existing clustering to this FlowSystem.
1566
+
1567
+ This method applies a previously computed clustering (from another FlowSystem)
1568
+ to the current FlowSystem's data. The clustering structure (cluster assignments,
1569
+ number of clusters, etc.) is preserved while the time series data is aggregated
1570
+ according to the existing cluster assignments.
1571
+
1572
+ Use this to:
1573
+ - Compare different scenarios with identical cluster assignments
1574
+ - Apply a reference clustering to new data
1575
+
1576
+ Args:
1577
+ clustering: A ``Clustering`` object from a previously clustered FlowSystem.
1578
+ Obtain this via ``fs.clustering`` from a clustered FlowSystem.
1579
+
1580
+ Returns:
1581
+ A new FlowSystem with reduced timesteps (only typical clusters).
1582
+ The FlowSystem has metadata stored in ``clustering`` for expansion.
1583
+
1584
+ Raises:
1585
+ ValueError: If the clustering dimensions don't match this FlowSystem's
1586
+ periods/scenarios.
1587
+
1588
+ Examples:
1589
+ Apply clustering from one FlowSystem to another:
1590
+
1591
+ >>> fs_reference = fs_base.transform.cluster(n_clusters=8, cluster_duration='1D')
1592
+ >>> fs_other = fs_high.transform.apply_clustering(fs_reference.clustering)
1593
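+ 
+         Note:
+             The FlowSystem must cover ``n_original_clusters * timesteps_per_cluster``
+             timesteps of the given clustering; for example, a clustering of 365 daily
+             clusters with 24 hourly timesteps each expects 8760 timesteps.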
+ """
1594
+ # Validation
1595
+ dt = float(self._fs.timestep_duration.min().item())
1596
+ if not np.isclose(dt, float(self._fs.timestep_duration.max().item())):
1597
+ raise ValueError(
1598
+ f'apply_clustering() requires uniform timestep sizes, got min={dt}h, '
1599
+ f'max={float(self._fs.timestep_duration.max().item())}h.'
1600
+ )
1601
+
1602
+ # Get timesteps_per_cluster from the clustering object (survives serialization)
1603
+ timesteps_per_cluster = clustering.timesteps_per_cluster
1604
+ has_periods = self._fs.periods is not None
1605
+ has_scenarios = self._fs.scenarios is not None
1606
+
1607
+ # Determine iteration dimensions
1608
+ periods = list(self._fs.periods) if has_periods else [None]
1609
+ scenarios = list(self._fs.scenarios) if has_scenarios else [None]
1610
+
1611
+ ds = self._fs.to_dataset(include_solution=False)
1612
+
1613
+ # Validate that timesteps match the clustering expectations
1614
+ current_timesteps = len(self._fs.timesteps)
1615
+ expected_timesteps = clustering.n_original_clusters * clustering.timesteps_per_cluster
1616
+ if current_timesteps != expected_timesteps:
1617
+ raise ValueError(
1618
+ f'Timestep count mismatch in apply_clustering(): '
1619
+ f'FlowSystem has {current_timesteps} timesteps, but clustering expects '
1620
+ f'{expected_timesteps} timesteps ({clustering.n_original_clusters} clusters × '
1621
+ f'{clustering.timesteps_per_cluster} timesteps/cluster). '
1622
+                 f'Ensure the FlowSystem timesteps match the original data the clustering was created from.'
1623
+ )
1624
+
1625
+ # Apply existing clustering to all (period, scenario) combinations at once
1626
+ logger.info('Applying clustering...')
1627
+ with warnings.catch_warnings():
1628
+ warnings.filterwarnings('ignore', category=UserWarning, message='.*minimal value.*exceeds.*')
1629
+ agg_results = clustering.results.apply(ds)
1630
+
1631
+ # Convert AggregationResults to the dict format expected by _build_reduced_flow_system
1632
+ tsam_aggregation_results: dict[tuple, Any] = {}
1633
+ cluster_occurrences_all: dict[tuple, dict] = {}
1634
+ clustering_metrics_all: dict[tuple, pd.DataFrame] = {}
1635
+
1636
+ for cr_key, result in agg_results:
1637
+ # Convert ClusteringResults key to (period, scenario) format
1638
+ if has_periods and has_scenarios:
1639
+ full_key = (cr_key[0], cr_key[1])
1640
+ elif has_periods:
1641
+ full_key = (cr_key[0], None)
1642
+ elif has_scenarios:
1643
+ full_key = (None, cr_key[0])
1644
+ else:
1645
+ full_key = (None, None)
1646
+
1647
+ tsam_aggregation_results[full_key] = result
1648
+ cluster_occurrences_all[full_key] = result.cluster_weights
1649
+ try:
1650
+ clustering_metrics_all[full_key] = self._accuracy_to_dataframe(result.accuracy)
1651
+ except Exception as e:
1652
+ logger.warning(f'Failed to compute clustering metrics for {full_key}: {e}')
1653
+ clustering_metrics_all[full_key] = pd.DataFrame()
1654
+
1655
+ # Build and return the reduced FlowSystem
1656
+ return self._build_reduced_flow_system(
1657
+ ds=ds,
1658
+ tsam_aggregation_results=tsam_aggregation_results,
1659
+ cluster_occurrences_all=cluster_occurrences_all,
1660
+ clustering_metrics_all=clustering_metrics_all,
1661
+ timesteps_per_cluster=timesteps_per_cluster,
1662
+ dt=dt,
1663
+ periods=periods,
1664
+ scenarios=scenarios,
1665
+ )
1666
+
1667
+ @staticmethod
1668
+ def _combine_slices_to_dataarray_generic(
1669
+ slices: dict[tuple, xr.DataArray],
1670
+ base_dims: list[str],
1671
+ periods: list,
1672
+ scenarios: list,
1673
+ name: str,
1674
+ ) -> xr.DataArray:
1675
+ """Combine per-(period, scenario) slices into a multi-dimensional DataArray.
1676
+
1677
+ Generic version that works with any base dimension (not just 'time').
1678
+
1679
+ Args:
1680
+ slices: Dict mapping (period, scenario) tuples to DataArrays.
1681
+ base_dims: Base dimensions of each slice (e.g., ['original_cluster'] or ['original_time']).
1682
+ periods: List of period labels ([None] if no periods dimension).
1683
+ scenarios: List of scenario labels ([None] if no scenarios dimension).
1684
+ name: Name for the resulting DataArray.
1685
+
1686
+ Returns:
1687
+ DataArray with dimensions [base_dims..., period?, scenario?].
1688
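+ 
+         Examples:
+             Minimal sketch with hypothetical period labels and no scenario dimension:
+ 
+             >>> import xarray as xr
+             >>> slices = {
+             ...     ('2030', None): xr.DataArray([1.0, 2.0], dims=['original_cluster']),
+             ...     ('2040', None): xr.DataArray([3.0, 4.0], dims=['original_cluster']),
+             ... }
+             >>> combined = TransformAccessor._combine_slices_to_dataarray_generic(
+             ...     slices, ['original_cluster'], ['2030', '2040'], [None], 'assignments'
+             ... )
+             >>> combined.dims
+             ('original_cluster', 'period')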
+ """
1689
+ first_key = (periods[0], scenarios[0])
1690
+ has_periods = periods != [None]
1691
+ has_scenarios = scenarios != [None]
1692
+
1693
+ # Simple case: no period/scenario dimensions
1694
+ if not has_periods and not has_scenarios:
1695
+ return slices[first_key].rename(name)
1696
+
1697
+ # Multi-dimensional: use xr.concat to stack along period/scenario dims
1698
+ # Use join='outer' to handle cases where different periods/scenarios have different
1699
+ # coordinate values (e.g., different time_series after drop_constant_arrays)
1700
+ if has_periods and has_scenarios:
1701
+ # Stack scenarios first, then periods
1702
+ period_arrays = []
1703
+ for p in periods:
1704
+ scenario_arrays = [slices[(p, s)] for s in scenarios]
1705
+ period_arrays.append(
1706
+ xr.concat(
1707
+ scenario_arrays, dim=pd.Index(scenarios, name='scenario'), join='outer', fill_value=np.nan
1708
+ )
1709
+ )
1710
+ result = xr.concat(period_arrays, dim=pd.Index(periods, name='period'), join='outer', fill_value=np.nan)
1711
+ elif has_periods:
1712
+ result = xr.concat(
1713
+ [slices[(p, None)] for p in periods],
1714
+ dim=pd.Index(periods, name='period'),
1715
+ join='outer',
1716
+ fill_value=np.nan,
1717
+ )
1718
+ else:
1719
+ result = xr.concat(
1720
+ [slices[(None, s)] for s in scenarios],
1721
+ dim=pd.Index(scenarios, name='scenario'),
1722
+ join='outer',
1723
+ fill_value=np.nan,
1724
+ )
1725
+
1726
+ # Put base dimension first (standard order)
1727
+ result = result.transpose(base_dims[0], ...)
1728
+
1729
+ return result.rename(name)
1730
+
1731
+ @staticmethod
1732
+ def _combine_slices_to_dataarray_2d(
1733
+ slices: dict[tuple, xr.DataArray],
1734
+ attrs: dict,
1735
+ periods: list,
1736
+ scenarios: list,
1737
+ ) -> xr.DataArray:
1738
+ """Combine per-(period, scenario) slices into a multi-dimensional DataArray with (cluster, time) dims.
1739
+
1740
+ Args:
1741
+ slices: Dict mapping (period, scenario) tuples to DataArrays with (cluster, time) dims.
1742
+ attrs: Attributes to assign to the result.
1743
+ periods: List of period labels ([None] if no periods dimension).
1744
+ scenarios: List of scenario labels ([None] if no scenarios dimension).
1745
+
1746
+ Returns:
1747
+ DataArray with dimensions (cluster, time, period?, scenario?).
1748
+ """
1749
+ first_key = (periods[0], scenarios[0])
1750
+ has_periods = periods != [None]
1751
+ has_scenarios = scenarios != [None]
1752
+
1753
+ # Simple case: no period/scenario dimensions
1754
+ if not has_periods and not has_scenarios:
1755
+ return slices[first_key].assign_attrs(attrs)
1756
+
1757
+ # Multi-dimensional: use xr.concat to stack along period/scenario dims
1758
+ if has_periods and has_scenarios:
1759
+ # Stack scenarios first, then periods
1760
+ period_arrays = []
1761
+ for p in periods:
1762
+ scenario_arrays = [slices[(p, s)] for s in scenarios]
1763
+ period_arrays.append(xr.concat(scenario_arrays, dim=pd.Index(scenarios, name='scenario')))
1764
+ result = xr.concat(period_arrays, dim=pd.Index(periods, name='period'))
1765
+ elif has_periods:
1766
+ result = xr.concat([slices[(p, None)] for p in periods], dim=pd.Index(periods, name='period'))
1767
+ else:
1768
+ result = xr.concat([slices[(None, s)] for s in scenarios], dim=pd.Index(scenarios, name='scenario'))
1769
+
1770
+ # Put cluster and time first (standard order for clustered data)
1771
+ result = result.transpose('cluster', 'time', ...)
1772
+
1773
+ return result.assign_attrs(attrs)
1774
+
1775
+ def _validate_for_expansion(self) -> Clustering:
1776
+ """Validate FlowSystem can be expanded and return clustering info.
1777
+
1778
+ Returns:
1779
+ The Clustering object.
1780
+
1781
+ Raises:
1782
+ ValueError: If FlowSystem wasn't created with cluster() or has no solution.
1783
+ """
1784
+
1785
+ if self._fs.clustering is None:
1786
+ raise ValueError(
1787
+                 'expand() requires a FlowSystem created with cluster(). This FlowSystem has no clustering metadata.'
1788
+ )
1789
+ if self._fs.solution is None:
1790
+ raise ValueError('FlowSystem has no solution. Run optimize() or solve() first.')
1791
+
1792
+ return self._fs.clustering
1793
+
1794
+ def _combine_intercluster_charge_states(
1795
+ self,
1796
+ expanded_fs: FlowSystem,
1797
+ reduced_solution: xr.Dataset,
1798
+ clustering: Clustering,
1799
+ original_timesteps_extra: pd.DatetimeIndex,
1800
+ timesteps_per_cluster: int,
1801
+ n_original_clusters: int,
1802
+ ) -> None:
1803
+ """Combine charge_state with SOC_boundary for intercluster storages (in-place).
1804
+
1805
+ For intercluster storages, charge_state is relative (delta-E) and can be negative.
1806
+ Per Blanke et al. (2022) Eq. 9, actual SOC at time t in period d is:
1807
+ SOC(t) = SOC_boundary[d] * (1 - loss)^t_within_period + charge_state(t)
1808
+ where t_within_period is hours from period start (accounts for self-discharge decay).
1809
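+ 
+         For example, with SOC_boundary[d] = 100, zero self-discharge, and a relative
+         charge_state(t) of -20, the combined SOC at t is 100 * 1 + (-20) = 80.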
+
1810
+ Args:
1811
+ expanded_fs: The expanded FlowSystem (modified in-place).
1812
+ reduced_solution: The original reduced solution dataset.
1813
+ clustering: Clustering with cluster order info.
1814
+ original_timesteps_extra: Original timesteps including the extra final timestep.
1815
+ timesteps_per_cluster: Number of timesteps per cluster.
1816
+ n_original_clusters: Number of original clusters before aggregation.
1817
+ """
1818
+ n_original_timesteps_extra = len(original_timesteps_extra)
1819
+ soc_boundary_vars = self._fs.get_variables_by_category(VariableCategory.SOC_BOUNDARY)
1820
+
1821
+ for soc_boundary_name in soc_boundary_vars:
1822
+ storage_name = soc_boundary_name.rsplit('|', 1)[0]
1823
+ charge_state_name = f'{storage_name}|charge_state'
1824
+ if charge_state_name not in expanded_fs._solution:
1825
+ continue
1826
+
1827
+ soc_boundary = reduced_solution[soc_boundary_name]
1828
+ expanded_charge_state = expanded_fs._solution[charge_state_name]
1829
+
1830
+ # Map each original timestep to its original period index
1831
+ original_cluster_indices = np.minimum(
1832
+ np.arange(n_original_timesteps_extra) // timesteps_per_cluster,
1833
+ n_original_clusters - 1,
1834
+ )
1835
+
1836
+ # Select SOC_boundary for each timestep
1837
+ soc_boundary_per_timestep = soc_boundary.isel(
1838
+ cluster_boundary=xr.DataArray(original_cluster_indices, dims=['time'])
1839
+ ).assign_coords(time=original_timesteps_extra)
1840
+
1841
+ # Apply self-discharge decay
1842
+ soc_boundary_per_timestep = self._apply_soc_decay(
1843
+ soc_boundary_per_timestep,
1844
+ storage_name,
1845
+ clustering,
1846
+ original_timesteps_extra,
1847
+ original_cluster_indices,
1848
+ timesteps_per_cluster,
1849
+ )
1850
+
1851
+ # Combine and clip to non-negative
1852
+ combined = (expanded_charge_state + soc_boundary_per_timestep).clip(min=0)
1853
+ expanded_fs._solution[charge_state_name] = combined.assign_attrs(expanded_charge_state.attrs)
1854
+
1855
+ # Clean up SOC_boundary variables and orphaned coordinates
1856
+ for soc_boundary_name in soc_boundary_vars:
1857
+ if soc_boundary_name in expanded_fs._solution:
1858
+ del expanded_fs._solution[soc_boundary_name]
1859
+ if 'cluster_boundary' in expanded_fs._solution.coords:
1860
+ expanded_fs._solution = expanded_fs._solution.drop_vars('cluster_boundary')
1861
+
1862
+ def _apply_soc_decay(
1863
+ self,
1864
+ soc_boundary_per_timestep: xr.DataArray,
1865
+ storage_name: str,
1866
+ clustering: Clustering,
1867
+ original_timesteps_extra: pd.DatetimeIndex,
1868
+ original_cluster_indices: np.ndarray,
1869
+ timesteps_per_cluster: int,
1870
+ ) -> xr.DataArray:
1871
+ """Apply self-discharge decay to SOC_boundary values.
1872
+
1873
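+         For example, with ``relative_loss_per_hour = 0.01`` and 5 hours into the period,
+         the boundary SOC is scaled by ``(1 - 0.01) ** 5 ≈ 0.951`` (illustrative values).
+ 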
+ Args:
1874
+ soc_boundary_per_timestep: SOC boundary values mapped to each timestep.
1875
+ storage_name: Name of the storage component.
1876
+ clustering: Clustering with cluster order info.
1877
+ original_timesteps_extra: Original timesteps including final extra timestep.
1878
+ original_cluster_indices: Mapping of timesteps to original cluster indices.
1879
+ timesteps_per_cluster: Number of timesteps per cluster.
1880
+
1881
+ Returns:
1882
+ SOC boundary values with decay applied.
1883
+ """
1884
+ storage = self._fs.storages.get(storage_name)
1885
+ if storage is None:
1886
+ return soc_boundary_per_timestep
1887
+
1888
+ n_timesteps = len(original_timesteps_extra)
1889
+
1890
+ # Time within period for each timestep (0, 1, 2, ..., T-1, 0, 1, ...)
1891
+ time_within_period = np.arange(n_timesteps) % timesteps_per_cluster
1892
+ time_within_period[-1] = timesteps_per_cluster # Extra timestep gets full decay
1893
+ time_within_period_da = xr.DataArray(
1894
+ time_within_period, dims=['time'], coords={'time': original_timesteps_extra}
1895
+ )
1896
+
1897
+ # Decay factor: (1 - loss)^t
1898
+ loss_value = _scalar_safe_reduce(storage.relative_loss_per_hour, 'time', 'mean')
1899
+ if not np.any(loss_value.values > 0):
1900
+ return soc_boundary_per_timestep
1901
+
1902
+ decay_da = (1 - loss_value) ** time_within_period_da
1903
+
1904
+ # Handle cluster dimension if present
1905
+ if 'cluster' in decay_da.dims:
1906
+ cluster_assignments = clustering.cluster_assignments
1907
+ if cluster_assignments.ndim == 1:
1908
+ cluster_per_timestep = xr.DataArray(
1909
+ cluster_assignments.values[original_cluster_indices],
1910
+ dims=['time'],
1911
+ coords={'time': original_timesteps_extra},
1912
+ )
1913
+ else:
1914
+ cluster_per_timestep = cluster_assignments.isel(
1915
+ original_cluster=xr.DataArray(original_cluster_indices, dims=['time'])
1916
+ ).assign_coords(time=original_timesteps_extra)
1917
+ decay_da = decay_da.isel(cluster=cluster_per_timestep).drop_vars('cluster', errors='ignore')
1918
+
1919
+ return soc_boundary_per_timestep * decay_da
1920
+
1921
+ def _build_segment_total_varnames(self) -> set[str]:
1922
+ """Build segment total variable names - BACKWARDS COMPATIBILITY FALLBACK.
1923
+
1924
+ This method is only used when variable_categories is empty (old FlowSystems
1925
+ saved before category registration was implemented). New FlowSystems use
1926
+ the VariableCategory registry with EXPAND_DIVIDE categories (PER_TIMESTEP, SHARE).
1927
+
1928
+ For segmented systems, these variables contain values that are summed over
1929
+ segments. When expanded to hourly resolution, they need to be divided by
1930
+ segment duration to get correct hourly rates.
1931
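+ 
+         For an effect named ``costs`` and a flow ``Boiler(Q_th)`` (illustrative names),
+         this includes e.g. ``costs(temporal)|per_timestep`` and
+         ``Boiler(Q_th)->costs(temporal)``.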
+
1932
+ Returns:
1933
+ Set of variable names that should be divided by expansion divisor.
1934
+ """
1935
+ segment_total_vars: set[str] = set()
1936
+
1937
+ # Get all effect names
1938
+ effect_names = list(self._fs.effects.keys())
1939
+
1940
+ # 1. Per-timestep totals for each effect: {effect}(temporal)|per_timestep
1941
+ for effect in effect_names:
1942
+ segment_total_vars.add(f'{effect}(temporal)|per_timestep')
1943
+
1944
+ # 2. Flow contributions to effects: {flow}->{effect}(temporal)
1945
+ # (from effects_per_flow_hour on Flow elements)
1946
+ for flow_label in self._fs.flows:
1947
+ for effect in effect_names:
1948
+ segment_total_vars.add(f'{flow_label}->{effect}(temporal)')
1949
+
1950
+ # 3. Component contributions to effects: {component}->{effect}(temporal)
1951
+ # (from effects_per_startup, effects_per_active_hour on OnOffParameters)
1952
+ for component_label in self._fs.components:
1953
+ for effect in effect_names:
1954
+ segment_total_vars.add(f'{component_label}->{effect}(temporal)')
1955
+
1956
+ # 4. Effect-to-effect contributions (from share_from_temporal)
1957
+ # {source_effect}(temporal)->{target_effect}(temporal)
1958
+ for target_effect_name, target_effect in self._fs.effects.items():
1959
+ if target_effect.share_from_temporal:
1960
+ for source_effect_name in target_effect.share_from_temporal:
1961
+ segment_total_vars.add(f'{source_effect_name}(temporal)->{target_effect_name}(temporal)')
1962
+
1963
+ return segment_total_vars
1964
+
1965
+ def _interpolate_charge_state_segmented(
1966
+ self,
1967
+ da: xr.DataArray,
1968
+ clustering: Clustering,
1969
+ original_timesteps: pd.DatetimeIndex,
1970
+ ) -> xr.DataArray:
1971
+ """Interpolate charge_state values within segments for segmented systems.
1972
+
1973
+ For segmented systems, charge_state has values at segment boundaries (n_segments+1).
1974
+ Instead of repeating the start boundary value for all timesteps in a segment,
1975
+ this method interpolates between start and end boundary values to show the
1976
+ actual charge trajectory as the storage charges/discharges.
1977
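+         For example, for a segment covering four original timesteps, the interpolation
+         factors ``(position + 0.5) / duration`` are 0.125, 0.375, 0.625 and 0.875, moving
+         linearly from the segment's start boundary value towards its end boundary value.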
+
1978
+ Uses vectorized xarray operations via Clustering class properties.
1979
+
1980
+ Args:
1981
+ da: charge_state DataArray with dims (cluster, time) where time has n_segments+1 entries.
1982
+ clustering: Clustering object with segment info.
1983
+ original_timesteps: Original timesteps to expand to.
1984
+
1985
+ Returns:
1986
+ Interpolated charge_state with dims (time, ...) for original timesteps.
1987
+ """
1988
+ # Get multi-dimensional properties from Clustering
1989
+ timestep_mapping = clustering.timestep_mapping
1990
+ segment_assignments = clustering.results.segment_assignments
1991
+ segment_durations = clustering.results.segment_durations
1992
+ position_within_segment = clustering.results.position_within_segment
1993
+
1994
+ # Decode timestep_mapping into cluster and time indices
1995
+ # timestep_mapping encodes original timestep -> (cluster, position_within_cluster)
1996
+ # where position_within_cluster indexes into segment_assignments/position_within_segment
1997
+ # which have shape (cluster, timesteps_per_cluster)
1998
+ timesteps_per_cluster = clustering.timesteps_per_cluster
1999
+ cluster_indices = timestep_mapping // timesteps_per_cluster
2000
+ time_indices = timestep_mapping % timesteps_per_cluster
2001
+
2002
+ # Get segment index and position for each original timestep
2003
+ seg_indices = segment_assignments.isel(cluster=cluster_indices, time=time_indices)
2004
+ positions = position_within_segment.isel(cluster=cluster_indices, time=time_indices)
2005
+ durations = segment_durations.isel(cluster=cluster_indices, segment=seg_indices)
2006
+
2007
+ # Calculate interpolation factor: position within segment (0 to 1)
2008
+ # At position=0, factor=0.5/duration (start of segment)
2009
+ # At position=duration-1, factor approaches 1 (end of segment)
2010
+ factor = xr.where(durations > 1, (positions + 0.5) / durations, 0.5)
2011
+
2012
+ # Get start and end boundary values from charge_state
2013
+ # charge_state has dims (cluster, time) where time = segment boundaries (n_segments+1)
2014
+ start_vals = da.isel(cluster=cluster_indices, time=seg_indices)
2015
+ end_vals = da.isel(cluster=cluster_indices, time=seg_indices + 1)
2016
+
2017
+ # Linear interpolation
2018
+ interpolated = start_vals + (end_vals - start_vals) * factor
2019
+
2020
+ # Clean up coordinate artifacts and rename
2021
+ interpolated = interpolated.drop_vars(['cluster', 'time', 'segment'], errors='ignore')
2022
+ interpolated = interpolated.rename({'original_time': 'time'}).assign_coords(time=original_timesteps)
2023
+
2024
+ return interpolated.transpose('time', ...).assign_attrs(da.attrs)
2025
+
2026
+ def expand(self) -> FlowSystem:
2027
+ """Expand a clustered FlowSystem back to full original timesteps.
2028
+
2029
+ After solving a FlowSystem created with ``cluster()``, this method
2030
+ disaggregates the FlowSystem by:
2031
+ 1. Expanding all time series data from typical clusters to full timesteps
2032
+ 2. Expanding the solution by mapping each typical cluster back to all
2033
+ original clusters it represents
2034
+
2035
+ For FlowSystems with periods and/or scenarios, each (period, scenario)
2036
+ combination is expanded using its own cluster assignment.
2037
+
2038
+ This enables using all existing solution accessors (``statistics``, ``plot``, etc.)
2039
+ with full time resolution, where both the data and solution are consistently
2040
+ expanded from the typical clusters.
2041
+
2042
+ Returns:
2043
+ FlowSystem: A new FlowSystem with full timesteps and expanded solution.
2044
+
2045
+ Raises:
2046
+ ValueError: If the FlowSystem was not created with ``cluster()``.
2047
+ ValueError: If the FlowSystem has no solution.
2048
+
2049
+ Examples:
2050
+ Two-stage optimization with expansion:
2051
+
2052
+ >>> # Stage 1: Size with reduced timesteps
2053
+ >>> fs_reduced = flow_system.transform.cluster(
2054
+ ... n_clusters=8,
2055
+ ... cluster_duration='1D',
2056
+ ... )
2057
+ >>> fs_reduced.optimize(solver)
2058
+ >>>
2059
+ >>> # Expand to full resolution FlowSystem
2060
+ >>> fs_expanded = fs_reduced.transform.expand()
2061
+ >>>
2062
+ >>> # Use all existing accessors with full timesteps
2063
+ >>> fs_expanded.statistics.flow_rates # Full 8760 timesteps
2064
+ >>> fs_expanded.statistics.plot.balance('HeatBus') # Full resolution plots
2065
+ >>> fs_expanded.statistics.plot.heatmap('Boiler(Q_th)|flow_rate')
2066
+
2067
+ Note:
2068
+             The expanded FlowSystem repeats each typical cluster's values for all
2069
+             original clusters assigned to that cluster. Both input data and solution
2070
+             are expanded consistently, so they match. This is an approximation:
2071
+             the actual dispatch at full resolution would differ due to
2072
+             intra-cluster variations in the time series data.
2073
+
2074
+ For accurate dispatch results, use ``fix_sizes()`` to fix the sizes
2075
+ from the reduced optimization and re-optimize at full resolution.
2076
+
2077
+ **Segmented Systems Variable Handling:**
2078
+
2079
+ For systems clustered with ``SegmentConfig``, special handling is applied
2080
+ to time-varying solution variables. Variables without a ``time`` dimension
2081
+ are unaffected by segment expansion. This includes:
2082
+
2083
+ - Investment: ``{component}|size``, ``{component}|exists``
2084
+ - Storage boundaries: ``{storage}|SOC_boundary``
2085
+ - Aggregated totals: ``{flow}|total_flow_hours``, ``{flow}|active_hours``
2086
+ - Effect totals: ``{effect}``, ``{effect}(temporal)``, ``{effect}(periodic)``
2087
+
2088
+ Time-varying variables are categorized and handled as follows:
2089
+
2090
+ 1. **State variables** - Interpolated within segments:
2091
+
2092
+ - ``{storage}|charge_state``: Linear interpolation between segment
2093
+ boundary values to show the charge trajectory during charge/discharge.
2094
+
2095
+ 2. **Segment totals** - Divided by segment duration:
2096
+
2097
+ These variables represent values summed over the segment. Division
2098
+ converts them back to hourly rates for correct plotting and analysis.
2099
+
2100
+ - ``{effect}(temporal)|per_timestep``: Per-timestep effect contributions
2101
+ - ``{flow}->{effect}(temporal)``: Flow contributions (includes both
2102
+ ``effects_per_flow_hour`` and ``effects_per_startup``)
2103
+ - ``{component}->{effect}(temporal)``: Component-level contributions
2104
+ - ``{source}(temporal)->{target}(temporal)``: Effect-to-effect shares
2105
+
2106
+ 3. **Rate/average variables** - Expanded as-is:
2107
+
2108
+ These variables represent average values within the segment. tsam
2109
+                already provides properly averaged values, so no correction is needed.
2110
+
2111
+ - ``{flow}|flow_rate``: Average flow rate during segment
2112
+ - ``{storage}|netto_discharge``: Net discharge rate (discharge - charge)
2113
+
2114
+ 4. **Binary status variables** - Constant within segment:
2115
+
2116
+ These variables cannot be meaningfully interpolated. They indicate
2117
+ the dominant state or whether an event occurred during the segment.
2118
+
2119
+ - ``{flow}|status``: On/off status (0 or 1)
2120
+ - ``{flow}|startup``: Startup event occurred in segment
2121
+ - ``{flow}|shutdown``: Shutdown event occurred in segment
2122
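+ 
+             As an illustration of the segment-total division above: a 6-hour segment
+             whose summed ``{effect}(temporal)|per_timestep`` value is 12 is expanded to
+             a rate of 2 per hour at each of the 6 original timesteps it covers
+             (illustrative numbers).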
+ """
2123
+ from .flow_system import FlowSystem
2124
+
2125
+ # Validate and extract clustering info
2126
+ clustering = self._validate_for_expansion()
2127
+
2128
+ timesteps_per_cluster = clustering.timesteps_per_cluster
2129
+ # For segmented systems, the time dimension has n_segments entries
2130
+ n_segments = clustering.n_segments
2131
+ time_dim_size = n_segments if n_segments is not None else timesteps_per_cluster
2132
+ n_clusters = clustering.n_clusters
2133
+ n_original_clusters = clustering.n_original_clusters
2134
+
2135
+ # Get original timesteps and dimensions
2136
+ original_timesteps = clustering.original_timesteps
2137
+ n_original_timesteps = len(original_timesteps)
2138
+ original_timesteps_extra = FlowSystem._create_timesteps_with_extra(original_timesteps, None)
2139
+
2140
+ # For charge_state expansion: index of last valid original cluster
2141
+ last_original_cluster_idx = min(
2142
+ (n_original_timesteps - 1) // timesteps_per_cluster,
2143
+ n_original_clusters - 1,
2144
+ )
2145
+
2146
+ # For segmented systems: build expansion divisor and identify segment total variables
2147
+ expansion_divisor = None
2148
+ segment_total_vars: set[str] = set()
2149
+ variable_categories = getattr(self._fs, '_variable_categories', {})
2150
+ if clustering.is_segmented:
2151
+ expansion_divisor = clustering.build_expansion_divisor(original_time=original_timesteps)
2152
+ # Build segment total vars using registry first, fall back to pattern matching
2153
+ segment_total_vars = {name for name, cat in variable_categories.items() if cat in EXPAND_DIVIDE}
2154
+ # Fall back to pattern matching for backwards compatibility (old FlowSystems without categories)
2155
+ if not segment_total_vars:
2156
+ segment_total_vars = self._build_segment_total_varnames()
2157
+
2158
+ def _is_state_variable(var_name: str) -> bool:
2159
+ """Check if a variable is a state variable (should be interpolated)."""
2160
+ if var_name in variable_categories:
2161
+ return variable_categories[var_name] in EXPAND_INTERPOLATE
2162
+ # Fall back to pattern matching for backwards compatibility
2163
+ return var_name.endswith('|charge_state')
2164
+
2165
+ def _append_final_state(expanded: xr.DataArray, da: xr.DataArray) -> xr.DataArray:
2166
+ """Append final state value from original data to expanded data."""
2167
+ cluster_assignments = clustering.cluster_assignments
2168
+ if cluster_assignments.ndim == 1:
2169
+ last_cluster = int(cluster_assignments.values[last_original_cluster_idx])
2170
+ extra_val = da.isel(cluster=last_cluster, time=-1)
2171
+ else:
2172
+ last_clusters = cluster_assignments.isel(original_cluster=last_original_cluster_idx)
2173
+ extra_val = da.isel(cluster=last_clusters, time=-1)
2174
+ extra_val = extra_val.drop_vars(['cluster', 'time'], errors='ignore')
2175
+ extra_val = extra_val.expand_dims(time=[original_timesteps_extra[-1]])
2176
+ return xr.concat([expanded, extra_val], dim='time')
2177
+
2178
+ def expand_da(da: xr.DataArray, var_name: str = '', is_solution: bool = False) -> xr.DataArray:
2179
+ """Expand a DataArray from clustered to original timesteps."""
2180
+ if 'time' not in da.dims:
2181
+ return da.copy()
2182
+
2183
+ is_state = _is_state_variable(var_name) and 'cluster' in da.dims
2184
+
2185
+ # State variables in segmented systems: interpolate within segments
2186
+ if is_state and clustering.is_segmented:
2187
+ expanded = self._interpolate_charge_state_segmented(da, clustering, original_timesteps)
2188
+ return _append_final_state(expanded, da)
2189
+
2190
+ expanded = clustering.expand_data(da, original_time=original_timesteps)
2191
+
2192
+ # Segment totals: divide by expansion divisor
2193
+ if is_solution and expansion_divisor is not None and var_name in segment_total_vars:
2194
+ expanded = expanded / expansion_divisor
2195
+
2196
+ # State variables: append final state
2197
+ if is_state:
2198
+ expanded = _append_final_state(expanded, da)
2199
+
2200
+ return expanded
2201
+
2202
+ # Helper to construct DataArray without slow _construct_dataarray
2203
+ def _fast_get_da(ds: xr.Dataset, name: str, coord_cache: dict) -> xr.DataArray:
2204
+ variable = ds.variables[name]
2205
+ var_dims = set(variable.dims)
2206
+ coords = {k: v for k, v in coord_cache.items() if set(v.dims).issubset(var_dims)}
2207
+ return xr.DataArray(variable, coords=coords, name=name)
2208
+
2209
+ # 1. Expand FlowSystem data
2210
+ reduced_ds = self._fs.to_dataset(include_solution=False)
2211
+ clustering_attrs = {'is_clustered', 'n_clusters', 'timesteps_per_cluster', 'clustering', 'cluster_weight'}
2212
+ skip_vars = {'cluster_weight', 'timestep_duration'} # These have special handling
2213
+ data_vars = {}
2214
+ # Use ds.variables pattern to avoid slow _construct_dataarray calls
2215
+ coord_cache = {k: v for k, v in reduced_ds.coords.items()}
2216
+ coord_names = set(coord_cache)
2217
+ for name in reduced_ds.variables:
2218
+ if name in coord_names:
2219
+ continue
2220
+ if name in skip_vars or name.startswith('clustering|'):
2221
+ continue
2222
+ da = _fast_get_da(reduced_ds, name, coord_cache)
2223
+ # Skip vars with cluster dim but no time dim - they don't make sense after expansion
2224
+ # (e.g., representative_weights with dims ('cluster',) or ('cluster', 'period'))
2225
+ if 'cluster' in da.dims and 'time' not in da.dims:
2226
+ continue
2227
+ data_vars[name] = expand_da(da, name)
2228
+ # Remove timestep_duration reference from attrs - let FlowSystem compute it from timesteps_extra
2229
+ # This ensures proper time coordinates for xarray alignment with N+1 solution timesteps
2230
+ attrs = {k: v for k, v in reduced_ds.attrs.items() if k not in clustering_attrs and k != 'timestep_duration'}
2231
+ expanded_ds = xr.Dataset(data_vars, attrs=attrs)
2232
+
2233
+ expanded_fs = FlowSystem.from_dataset(expanded_ds)
2234
+
2235
+ # 2. Expand solution (with segment total correction for segmented systems)
2236
+ reduced_solution = self._fs.solution
2237
+ # Use ds.variables pattern to avoid slow _construct_dataarray calls
2238
+ sol_coord_cache = {k: v for k, v in reduced_solution.coords.items()}
2239
+ sol_coord_names = set(sol_coord_cache)
2240
+ expanded_sol_vars = {}
2241
+ for name in reduced_solution.variables:
2242
+ if name in sol_coord_names:
2243
+ continue
2244
+ da = _fast_get_da(reduced_solution, name, sol_coord_cache)
2245
+ expanded_sol_vars[name] = expand_da(da, name, is_solution=True)
2246
+ expanded_fs._solution = xr.Dataset(expanded_sol_vars, attrs=reduced_solution.attrs)
2247
+ expanded_fs._solution = expanded_fs._solution.reindex(time=original_timesteps_extra)
2248
+
2249
+ # 3. Combine charge_state with SOC_boundary for intercluster storages
2250
+ self._combine_intercluster_charge_states(
2251
+ expanded_fs,
2252
+ reduced_solution,
2253
+ clustering,
2254
+ original_timesteps_extra,
2255
+ timesteps_per_cluster,
2256
+ n_original_clusters,
2257
+ )
2258
+
2259
+ # Log expansion info
2260
+ has_periods = self._fs.periods is not None
2261
+ has_scenarios = self._fs.scenarios is not None
2262
+ n_combinations = (len(self._fs.periods) if has_periods else 1) * (
2263
+ len(self._fs.scenarios) if has_scenarios else 1
2264
+ )
2265
+ n_reduced_timesteps = n_clusters * time_dim_size
2266
+ segmented_info = f' ({n_segments} segments)' if n_segments else ''
2267
+ logger.info(
2268
+ f'Expanded FlowSystem from {n_reduced_timesteps} to {n_original_timesteps} timesteps '
2269
+ f'({n_clusters} clusters{segmented_info}'
2270
+ + (
2271
+ f', {n_combinations} period/scenario combinations)'
2272
+ if n_combinations > 1
2273
+ else f' → {n_original_clusters} original clusters)'
2274
+ )
2275
+ )
2276
+
2277
+ return expanded_fs