tsam_xarray-0.0.1a0-py3-none-any.whl

tsam_xarray/__init__.py ADDED
@@ -0,0 +1,6 @@
+ """tsam_xarray: Lightweight xarray wrapper for tsam time series aggregation."""
+
+ from tsam_xarray._core import aggregate
+ from tsam_xarray._result import AccuracyMetrics, AggregationResult
+
+ __all__ = ["AccuracyMetrics", "AggregationResult", "aggregate"]
tsam_xarray/_core.py ADDED
@@ -0,0 +1,457 @@
+ """Core aggregation logic for tsam_xarray."""
+
+ from __future__ import annotations
+
+ import itertools
+ from collections.abc import Hashable, Sequence
+ from typing import Any
+
+ import numpy as np
+ import pandas as pd
+ import tsam
+ import xarray as xr
+
+ from tsam_xarray._result import AccuracyMetrics, AggregationResult
+
+ Weights = dict[str, float] | dict[str, dict[str, float]] | None
+
+
+ def aggregate(
+     da: xr.DataArray,
+     *,
+     time_dim: str,
+     cluster_dim: Sequence[str] | str,
+     n_clusters: int,
+     weights: Weights = None,
+     **tsam_kwargs: Any,
+ ) -> AggregationResult:
+     """Aggregate an xarray DataArray using tsam.
+
+     Parameters
+     ----------
+     da : xr.DataArray
+         Input data with a time dimension and optional extra dimensions.
+     time_dim : str
+         Name of the time dimension.
+     cluster_dim : Sequence[str] | str
+         Dimension(s) to cluster together. Multiple dims are stacked
+         internally into a MultiIndex and unstacked in results.
+         All remaining dims are sliced independently.
+     n_clusters : int
+         Number of typical periods.
+     weights : dict[str, float] | dict[str, dict[str, float]] | None
+         Per-coordinate weights for clustering. Missing entries default
+         to 1.0. Two formats:
+
+         - **Simple dict** (single ``cluster_dim``)::
+
+             weights={"solar": 2.0, "wind": 1.0}
+
+         - **Dict-of-dicts** (multiple ``cluster_dim``)::
+
+             weights={"variable": {"solar": 2.0}, "region": {"north": 1.5}}
+
+         Weights are multiplied across dimensions, e.g. ``("solar", "north")``
+         gets weight ``2.0 * 1.5 = 3.0``.
+
+     **tsam_kwargs
+         Additional keyword arguments passed to ``tsam.aggregate()``.
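+
+     Examples
+     --------
+     A minimal sketch for a ``(time, variable)`` DataArray::
+
+         result = aggregate(
+             da, time_dim="time", cluster_dim="variable", n_clusters=4
+         )
+         result.typical_periods  # (cluster, timestep, variable)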
+     """
+     _validate_time_dim(da, time_dim)
+     col_dims = _resolve_cluster_dim(cluster_dim)
+     slice_dims = _infer_slice_dims(da, time_dim, col_dims)
+     _validate(da, time_dim, col_dims, slice_dims)
+     _validate_no_cluster_config_weights(tsam_kwargs)
+     per_dim_weights = _normalize_weights(weights, da, col_dims)
+
+     if not slice_dims:
+         return _aggregate_single(
+             da, n_clusters, time_dim, col_dims, per_dim_weights, tsam_kwargs
+         )
+
+     slice_coords = {d: da.coords[d].values for d in slice_dims}
+     slice_keys = list(itertools.product(*(slice_coords[d] for d in slice_dims)))
+
+     results: list[AggregationResult] = []
+     raw_map: dict[tuple[Hashable, ...], Any] = {}
+
+     for key in slice_keys:
+         sel = dict(zip(slice_dims, key, strict=True))
+         da_slice = da.sel(sel)
+         r = _aggregate_single(
+             da_slice, n_clusters, time_dim, col_dims, per_dim_weights, tsam_kwargs
+         )
+         results.append(r)
+         raw_map[key] = r.raw
+
+     return _concat_results(results, slice_dims, slice_coords, raw_map)
+
+
+ def _resolve_cluster_dim(
+     cluster_dim: Sequence[str] | str,
+ ) -> list[str]:
+     """Resolve cluster_dim to a list of dimension names."""
+     if isinstance(cluster_dim, str):
+         return [cluster_dim]
+     return list(cluster_dim)
+
+
+ def _infer_slice_dims(
+     da: xr.DataArray,
+     time_dim: str,
+     col_dims: list[str],
+ ) -> list[str]:
+     """Infer slice dims: everything not time_dim or column dims."""
+     exclude = {time_dim, *col_dims}
+     return [str(d) for d in da.dims if d not in exclude]
+
+
+ def _validate_time_dim(da: xr.DataArray, time_dim: str) -> None:
+     if time_dim not in da.dims:
+         msg = f"time_dim {time_dim!r} not in DataArray dims {set(da.dims)}"
+         raise ValueError(msg)
+
+
+ def _validate_no_cluster_config_weights(
+     tsam_kwargs: dict[str, Any],
+ ) -> None:
+     """Reject deprecated weights in ClusterConfig."""
+     cluster_config = tsam_kwargs.get("cluster")
+     if cluster_config is not None and cluster_config.weights is not None:
+         msg = (
+             "ClusterConfig.weights is deprecated in tsam and not "
+             "supported by tsam_xarray. Use the top-level 'weights' "
+             "parameter of aggregate() instead."
+         )
+         raise ValueError(msg)
+
+
+ def _validate(
+     da: xr.DataArray,
+     time_dim: str,
+     col_dims: list[str],
+     slice_dims: list[str],
+ ) -> None:
+     dims = set(da.dims)
+     for d in col_dims:
+         if d not in dims:
+             msg = f"cluster_dim entry {d!r} not in DataArray dims {dims}"
+             raise ValueError(msg)
+         if d == time_dim:
+             msg = "cluster_dim and time_dim must not overlap"
+             raise ValueError(msg)
+
+
+ def _to_dataframe(
+     da: xr.DataArray,
+     time_dim: str,
+     col_dims: list[str],
+ ) -> pd.DataFrame:
+     """Convert DataArray to DataFrame for tsam."""
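+     # tsam consumes a 2-D frame (time x columns); multiple cluster dims
+     # are stacked into a single MultiIndex column axis here and unstacked
+     # again when results are converted back to xarray.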
+     if not col_dims:
+         s = da.to_pandas()
+         if isinstance(s, pd.Series):
+             name = da.name or "value"
+             return s.to_frame(name=str(name))
+         return pd.DataFrame(s)
+
+     if len(col_dims) > 1:
+         da = da.stack(_column=col_dims)
+         col_dim = "_column"
+     else:
+         col_dim = col_dims[0]
+
+     da_t = da.transpose(time_dim, col_dim)
+     return pd.DataFrame(da_t.to_pandas())
+
+
+ def _representatives_to_da(
+     df: pd.DataFrame,
+     col_dims: list[str],
+ ) -> xr.DataArray:
+     """Convert cluster_representatives DataFrame to DataArray."""
+     df = df.copy()
+     # With segmentation, index has 3 levels: (cluster, segment_step, segment_duration)
+     # Without: 2 levels: (cluster, timestep)
+     if df.index.nlevels == 3:
+         df.index = df.index.droplevel(2)  # drop segment_duration
+     df.index.names = ["cluster", "timestep"]
+
+     if not col_dims:
+         clusters = df.index.get_level_values(0).unique()
+         timesteps = df.index.get_level_values(1).unique()
+         values = df.values.squeeze(axis=1).reshape(len(clusters), len(timesteps))
+         return xr.DataArray(
+             values,
+             dims=["cluster", "timestep"],
+             coords={"cluster": clusters, "timestep": timesteps},
+         )
+
+     stacked = df.stack(df.columns.names, future_stack=True)
+     da: xr.DataArray = stacked.to_xarray()  # type: ignore[assignment]
+     return da
+
+
+ def _segment_durations_to_da(
+     raw_durations: tuple[tuple[int, ...], ...] | None,
+ ) -> xr.DataArray | None:
+     """Convert tsam segment_durations to DataArray."""
+     if raw_durations is None:
+         return None
+     data = np.array(raw_durations)  # (n_clusters, n_segments)
+     return xr.DataArray(
+         data,
+         dims=["cluster", "timestep"],
+         coords={
+             "cluster": np.arange(data.shape[0]),
+             "timestep": np.arange(data.shape[1]),
+         },
+     )
+
+
+ def _reconstructed_to_da(
+     df: pd.DataFrame,
+     time_dim: str,
+     col_dims: list[str],
+ ) -> xr.DataArray:
+     """Convert reconstructed DataFrame to DataArray."""
+     df = df.copy()
+     df.index.name = time_dim
+
+     if not col_dims:
+         return xr.DataArray(
+             df.values.squeeze(axis=1),
+             dims=[time_dim],
+             coords={time_dim: df.index},
+         )
+
+     stacked = df.stack(df.columns.names, future_stack=True)
+     da: xr.DataArray = stacked.to_xarray()  # type: ignore[assignment]
+     return da
+
+
+ def _metric_to_da(
+     series: pd.Series[float],
+     col_dims: list[str],
+     column_names: list[str] | None = None,
+ ) -> xr.DataArray:
+     """Convert an accuracy metric Series to DataArray."""
+     if not col_dims:
+         return xr.DataArray(float(series.iloc[0]))
+     series = series.copy()
+     if isinstance(series.index, pd.MultiIndex):
+         if column_names is not None:
+             series.index = series.index.set_names(column_names)
+     elif series.index.name is None:
+         series.index.name = col_dims[0]
+     return xr.DataArray(series.to_xarray())
+
+
+ def _normalize_weights(
+     weights: dict[str, float] | dict[str, dict[str, float]] | None,
+     da: xr.DataArray,
+     col_dims: list[str],
+ ) -> dict[str, dict[str, float]] | None:
+     """Normalize weights to dict-of-dicts and validate dims/coords."""
+     if weights is None or not weights:
+         return None
+
+     first_val = next(iter(weights.values()))
+     if isinstance(first_val, dict):
+         # Dict-of-dicts — validate all values are dicts
+         for _key, val in weights.items():
+             if not isinstance(val, dict):
+                 msg = (
+                     "Mixed weights format: all values must be dicts. "
+                     'Use {"dim": {"coord": weight}} for all entries.'
+                 )
+                 raise ValueError(msg)
+         per_dim_weights: dict[str, dict[str, float]] = weights  # type: ignore[assignment]
+     else:
+         # Simple dict — requires single cluster_dim
+         if len(col_dims) != 1:
+             msg = (
+                 "Simple dict weights require a single cluster_dim. "
+                 "For multiple cluster_dim, use dict-of-dicts: "
+                 '{"dim_name": {"coord": weight}}.'
+             )
+             raise ValueError(msg)
+         per_dim_weights = {col_dims[0]: weights}  # type: ignore[dict-item]
+
+     # Validate dim names exist in cluster_dim
+     extra_dims = set(per_dim_weights.keys()) - set(col_dims)
+     if extra_dims:
+         msg = (
+             f"weights has unknown dims {extra_dims}, "
+             f"must be subset of cluster_dim {col_dims}"
+         )
+         raise ValueError(msg)
+
+     # Validate coord values exist in the DataArray
+     for dim_name, coord_weights in per_dim_weights.items():
+         valid_coords = set(str(c) for c in da.coords[dim_name].values)
+         unknown = set(coord_weights.keys()) - valid_coords
+         if unknown:
+             msg = (
+                 f"weights has unknown coords {unknown} for dim {dim_name!r}, "
+                 f"valid coords: {sorted(valid_coords)}"
+             )
+             raise ValueError(msg)
+
+     return per_dim_weights
+
+
+ def _translate_weights(
+     weights: dict[str, dict[str, float]],
+     df: pd.DataFrame,
+     col_dims: list[str],
+ ) -> dict[Hashable, float]:
+     """Translate per-dim weights to flat column weights for tsam."""
+     flat: dict[Hashable, float] = {}
+     for col in df.columns:
+         w = 1.0
+         if isinstance(col, tuple):
+             for dim_name, coord_val in zip(col_dims, col, strict=True):
+                 if dim_name in weights:
+                     w *= weights[dim_name].get(str(coord_val), 1.0)
+         else:
+             dim_name = col_dims[0]
+             if dim_name in weights:
+                 w *= weights[dim_name].get(str(col), 1.0)
+         flat[col] = w
+     return flat
+
+
+ def _aggregate_single(
+     da: xr.DataArray,
+     n_clusters: int,
+     time_dim: str,
+     col_dims: list[str],
+     weights: dict[str, dict[str, float]] | None,
+     tsam_kwargs: dict[str, Any],
+ ) -> AggregationResult:
+     """Run a single tsam aggregation on a DataArray."""
+     df = _to_dataframe(da, time_dim, col_dims)
+
+     tsam_weights: dict[Hashable, float] | None = None
+     if weights is not None:
+         tsam_weights = _translate_weights(weights, df, col_dims)
+
+     tsam_result = tsam.aggregate(
+         df,
+         n_clusters,
+         weights=tsam_weights,  # type: ignore[arg-type]
+         **tsam_kwargs,
+     )
+
+     typical = _representatives_to_da(tsam_result.cluster_representatives, col_dims)
+     reconstructed = _reconstructed_to_da(tsam_result.reconstructed, time_dim, col_dims)
+
+     cw = tsam_result.cluster_weights
+     cluster_ids = np.array(sorted(cw.keys()))
+     cluster_weights_da = xr.DataArray(
+         np.array([cw[k] for k in cluster_ids]),
+         dims=["cluster"],
+         coords={"cluster": cluster_ids},
+     )
+
+     assignments_da = xr.DataArray(tsam_result.cluster_assignments, dims=["period"])
+
+     col_names: list[str] | None = None
+     if isinstance(df.columns, pd.MultiIndex):
+         col_names = [str(n) for n in df.columns.names]
+
+     accuracy = AccuracyMetrics(
+         rmse=_metric_to_da(tsam_result.accuracy.rmse, col_dims, col_names),
+         mae=_metric_to_da(tsam_result.accuracy.mae, col_dims, col_names),
+         rmse_duration=_metric_to_da(
+             tsam_result.accuracy.rmse_duration, col_dims, col_names
+         ),
+     )
+
+     seg_durations = _segment_durations_to_da(tsam_result.segment_durations)
+
+     return AggregationResult(
+         typical_periods=typical,
+         cluster_assignments=assignments_da,
+         cluster_weights=cluster_weights_da,
+         segment_durations=seg_durations,
+         accuracy=accuracy,
+         reconstructed=reconstructed,
+         original=da,
+         raw=tsam_result,
+     )
+
+
+ def _make_dim_index(
+     slice_coords: dict[str, Any],
+     dim: str,
+ ) -> pd.Index:
+     """Create a pd.Index for a slice dimension."""
+     return pd.Index(slice_coords[dim], name=dim)  # type: ignore[no-any-return]
+
+
+ def _concat_along_dims(
+     arrays: list[xr.DataArray],
+     slice_dims: list[str],
+     slice_coords: dict[str, Any],
+ ) -> xr.DataArray:
+     """Concat arrays along one or more slice dims."""
+     if len(slice_dims) == 1:
+         return xr.concat(arrays, dim=_make_dim_index(slice_coords, slice_dims[0]))
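+     # ``arrays`` arrives in ``itertools.product`` order over the slice
+     # coords (built in ``aggregate``), so ``_nest`` below can rebuild the
+     # nested list structure by consuming the iterator in order.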
+     it = iter(arrays)
+
+     def _nest(dims: list[str]) -> list[Any]:
+         if len(dims) == 1:
+             return [next(it) for _ in slice_coords[dims[0]]]
+         return [_nest(dims[1:]) for _ in slice_coords[dims[0]]]
+
+     nested: Any = _nest(slice_dims)
+
+     def _recursive_concat(node: Any, dims: list[str]) -> xr.DataArray:
+         dim = dims[0]
+         idx = _make_dim_index(slice_coords, dim)
+         if len(dims) == 1:
+             return xr.concat(node, dim=idx)  # type: ignore[no-any-return]
+         children = [_recursive_concat(child, dims[1:]) for child in node]
+         return xr.concat(children, dim=idx)
+
+     return _recursive_concat(nested, slice_dims)
+
+
+ def _concat_results(
+     results: list[AggregationResult],
+     slice_dims: list[str],
+     slice_coords: dict[str, Any],
+     raw_map: dict[tuple[Hashable, ...], Any],
+ ) -> AggregationResult:
+     """Concatenate per-slice results along slice dims."""
+
+     def _field(field_name: str) -> xr.DataArray:
+         arrays = [getattr(r, field_name) for r in results]
+         return _concat_along_dims(arrays, slice_dims, slice_coords)
+
+     def _optional_field(field_name: str) -> xr.DataArray | None:
+         arrays = [getattr(r, field_name) for r in results]
+         if arrays[0] is None:
+             return None
+         return _concat_along_dims(arrays, slice_dims, slice_coords)
+
+     def _acc_field(field_name: str) -> xr.DataArray:
+         arrays = [getattr(r.accuracy, field_name) for r in results]
+         return _concat_along_dims(arrays, slice_dims, slice_coords)
+
+     return AggregationResult(
+         typical_periods=_field("typical_periods"),
+         cluster_assignments=_field("cluster_assignments"),
+         cluster_weights=_field("cluster_weights"),
+         segment_durations=_optional_field("segment_durations"),
+         accuracy=AccuracyMetrics(
+             rmse=_acc_field("rmse"),
+             mae=_acc_field("mae"),
+             rmse_duration=_acc_field("rmse_duration"),
+         ),
+         reconstructed=_field("reconstructed"),
+         original=_field("original"),
+         raw=raw_map,
+     )
tsam_xarray/_result.py ADDED
@@ -0,0 +1,226 @@
+ """Result dataclasses for tsam_xarray."""
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Any
+
+ import numpy as np
+ import xarray as xr
+
+
+ @dataclass(frozen=True)
+ class AccuracyMetrics:
+     """Accuracy metrics from time series aggregation."""
+
+     rmse: xr.DataArray
+     mae: xr.DataArray
+     rmse_duration: xr.DataArray
+
+
+ @dataclass(frozen=True)
+ class AggregationResult:
+     """Result of tsam_xarray.aggregate()."""
+
+     typical_periods: xr.DataArray
+     cluster_assignments: xr.DataArray
+     cluster_weights: xr.DataArray
+     segment_durations: xr.DataArray | None
+     accuracy: AccuracyMetrics
+     reconstructed: xr.DataArray
+     original: xr.DataArray
+     raw: Any  # tsam.AggregationResult or dict of them
+
+     @property
+     def n_clusters(self) -> int:
+         """Number of typical period clusters."""
+         return int(self.cluster_weights.sizes["cluster"])
+
+     @property
+     def n_timesteps_per_period(self) -> int:
+         """Number of timesteps per typical period."""
+         return int(self.typical_periods.sizes["timestep"])
+
+     @property
+     def n_segments(self) -> int | None:
+         """Number of segments per period, if segmentation was used."""
+         if isinstance(self.raw, dict):
+             first = next(iter(self.raw.values()))
+             result: int | None = first.n_segments
+         else:
+             result = self.raw.n_segments
+         return result
+
+     @property
+     def clustering_duration(self) -> float:
+         """Time spent on clustering in seconds."""
+         if isinstance(self.raw, dict):
+             total: float = sum(r.clustering_duration for r in self.raw.values())
+             return total
+         duration: float = self.raw.clustering_duration
+         return duration
+
+     @property
+     def is_transferred(self) -> bool:
+         """Whether result was created via ClusteringResult.apply()."""
+         if isinstance(self.raw, dict):
+             return all(r.is_transferred for r in self.raw.values())
+         is_transferred: bool = self.raw.is_transferred
+         return is_transferred
+
+     @property
+     def residuals(self) -> xr.DataArray:
+         """Difference between original and reconstructed data."""
+         return self.original - self.reconstructed
+
+     def disaggregate(self, data: xr.DataArray) -> xr.DataArray:
+         """Map data on ``(cluster, timestep)`` back to original time.
+
+         This is the inverse of ``aggregate()``. Use it to expand
+         external data computed on the compact typical-period grid
+         (e.g., optimization results) back to the full time axis.
+
+         Without segmentation, values are repeated for each timestep
+         in the period. With segmentation, values are placed at segment
+         boundaries and remaining timesteps are NaN — use
+         ``.ffill(dim="time")``, ``.interpolate_na(dim="time")``, etc.
+
+         Parameters
+         ----------
+         data : xr.DataArray
+             Data with ``cluster`` and ``timestep`` dims, matching the
+             shape of ``result.typical_periods``. Additional dims
+             (including auto-sliced dims like scenario) are supported.
+
+         Returns
+         -------
+         xr.DataArray
+             Data with ``cluster`` and ``timestep`` replaced by the
+             original ``time`` dimension.
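+
+         Examples
+         --------
+         A minimal sketch, with ``result`` from ``aggregate()`` and a
+         hypothetical ``costs`` array on the ``(cluster, timestep)`` grid::
+
+             full = result.disaggregate(costs)
+             full = full.ffill(dim="time")  # only needed with segmentation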
+         """
+         # Identify slice dims (dims on data that aren't cluster/timestep
+         # and aren't cluster_dim coords)
+         slice_dims = [
+             str(d)
+             for d in data.dims
+             if d not in ("cluster", "timestep") and d in self.cluster_assignments.dims
+         ]
+
+         if not slice_dims:
+             return self._disaggregate_single(data)
+
+         # Loop over slice dims and concat
+         import itertools
+         import pandas as pd
+
+         slice_coords = {d: data.coords[d].values for d in slice_dims}
+         keys = list(itertools.product(*(slice_coords[d] for d in slice_dims)))
+         results = []
+         for key in keys:
+             sel = dict(zip(slice_dims, key, strict=True))
+             data_slice = data.sel(sel)
+             # Use per-slice raw result for assignments/durations
+             result_slice = self._make_slice_view(sel)
+             results.append(result_slice._disaggregate_single(data_slice))
+
+         # Rebuild the nested list structure implied by itertools.product,
+         # then concat innermost dim first (mirrors _core._concat_along_dims,
+         # which also handles three or more slice dims correctly).
+         it = iter(results)
+
+         def _nest(dims: list[str]) -> list:  # type: ignore[type-arg]
+             if len(dims) == 1:
+                 return [next(it) for _ in slice_coords[dims[0]]]
+             return [_nest(dims[1:]) for _ in slice_coords[dims[0]]]
+
+         def _concat(node: Any, dims: list[str]) -> xr.DataArray:
+             idx = pd.Index(slice_coords[dims[0]], name=dims[0])
+             if len(dims) == 1:
+                 return xr.concat(node, dim=idx)
+             return xr.concat([_concat(child, dims[1:]) for child in node], dim=idx)
+
+         return _concat(_nest(slice_dims), slice_dims)
+
+     def _make_slice_view(self, sel: dict[str, object]) -> AggregationResult:
+         """Create a view of this result for a single slice."""
+         return AggregationResult(
+             typical_periods=self.typical_periods.sel(sel),
+             cluster_assignments=self.cluster_assignments.sel(sel),
+             cluster_weights=self.cluster_weights.sel(sel),
+             segment_durations=(
+                 self.segment_durations.sel(sel)
+                 if self.segment_durations is not None
+                 else None
+             ),
+             accuracy=AccuracyMetrics(
+                 rmse=self.accuracy.rmse.sel(sel),
+                 mae=self.accuracy.mae.sel(sel),
+                 rmse_duration=self.accuracy.rmse_duration.sel(sel),
+             ),
+             reconstructed=self.reconstructed.sel(sel),
+             original=self.original.sel(sel),
+             raw=(
+                 self.raw[tuple(sel.values())]
+                 if isinstance(self.raw, dict)
+                 else self.raw
+             ),
+         )
+
+     def _disaggregate_single(self, data: xr.DataArray) -> xr.DataArray:
+         """Disaggregate without slice dims."""
+         time_coords = self.original.coords["time"]
+         assignments = self.cluster_assignments.values
+         n_original_timesteps = len(time_coords)
+         n_periods = len(assignments)
+         n_per_period = n_original_timesteps // n_periods
+
+         other_dims = [str(d) for d in data.dims if d not in ("cluster", "timestep")]
+
+         if self.segment_durations is None:
+             expanded = data.sel(cluster=xr.DataArray(assignments, dims=["period"]))
+             flat = expanded.values.reshape(-1, *expanded.shape[2:])
+             result = xr.DataArray(
+                 flat[:n_original_timesteps],
+                 dims=["time", *other_dims],
+                 coords={"time": time_coords},
+             )
+             for d in other_dims:
+                 if d in data.coords:
+                     result = result.assign_coords({d: data.coords[d]})
+             return result
+
+         other_shape = [data.sizes[d] for d in other_dims]
+         total_timesteps = n_periods * n_per_period
+         out = np.full([total_timesteps, *other_shape], np.nan)
+
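+         # With segmentation, each segment's value lands on the segment's
+         # first timestep; the timesteps in between stay NaN (see the
+         # ``disaggregate`` docstring for fill options).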
+         for p_idx, cluster in enumerate(assignments):
+             offset = 0
+             durations = self.segment_durations.sel(cluster=int(cluster)).values
+             for seg_idx, dur in enumerate(durations):
+                 t_start = p_idx * n_per_period + offset
+                 vals = data.sel(cluster=int(cluster), timestep=seg_idx).values
+                 out[t_start] = vals
+                 offset += int(dur)
+
+         result = xr.DataArray(
+             out[:n_original_timesteps],
+             dims=["time", *other_dims],
+             coords={"time": time_coords},
+         )
+         for d in other_dims:
+             if d in data.coords:
+                 result = result.assign_coords({d: data.coords[d]})
+         return result
tsam_xarray/_sample_data.py ADDED
@@ -0,0 +1,107 @@
+ """Synthetic sample data for documentation and testing."""
+
+ from __future__ import annotations
+
+ import numpy as np
+ import pandas as pd
+ import xarray as xr
+
+
+ def sample_energy_data(
+     n_days: int = 30,
+     seed: int = 42,
+ ) -> xr.DataArray:
+     """Create a synthetic energy DataArray with realistic profiles.
+
+     Returns an hourly DataArray with dimensions:
+
+     - **time** — hourly timestamps
+     - **variable** — ``solar``, ``wind``, ``demand``
+     - **region** — ``north``, ``south``, ``east``
+     - **scenario** — ``low``, ``high``
+
+     Solar follows a daily bell curve, wind has seasonal variation
+     with autocorrelation, and demand combines a daily commute pattern
+     with weather-driven noise. Scenarios scale the base profiles.
+
+     Parameters
+     ----------
+     n_days : int
+         Number of days of hourly data (default: 30).
+     seed : int
+         Random seed for reproducibility (default: 42).
+
+     Returns
+     -------
+     xr.DataArray
+         Shape ``(n_days * 24, 3, 3, 2)`` with coords on every dim.
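+
+     Examples
+     --------
+     A quick check of the advertised shape::
+
+         da = sample_energy_data(n_days=7)
+         da.sizes  # {'time': 168, 'variable': 3, 'region': 3, 'scenario': 2}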
+     """
+     rng = np.random.default_rng(seed)
+     hours = n_days * 24
+     time = pd.date_range("2020-01-01", periods=hours, freq="h")
+     hour_of_day = np.arange(hours) % 24
+     day_of_year = time.dayofyear.values
+
+     variables = ["solar", "wind", "demand"]
+     regions = ["north", "south", "east"]
+     scenarios = ["low", "high"]
+
+     # --- base profiles (hours,) ---
+     # Solar: bell curve peaking at noon, zero at night
+     solar_base = np.maximum(0, np.sin(np.pi * (hour_of_day - 6) / 12)) ** 1.5
+     # Seasonal envelope: weaker in winter
+     solar_season = 0.6 + 0.4 * np.sin(2 * np.pi * (day_of_year - 80) / 365)
+     solar = solar_base * solar_season
+
+     # Wind: autocorrelated noise with seasonal mean
+     wind = np.empty(hours)
+     wind[0] = 0.5
+     for t in range(1, hours):
+         wind[t] = 0.9 * wind[t - 1] + 0.1 * rng.standard_normal()
+     wind = (wind - wind.min()) / (wind.max() - wind.min())
+     wind_season = 0.7 + 0.3 * np.cos(2 * np.pi * (day_of_year - 1) / 365)
+     wind = wind * wind_season
+
+     # Demand: daily pattern + seasonal + noise
+     demand_daily = 0.5 + 0.3 * np.sin(np.pi * (hour_of_day - 5) / 12)
+     demand_season = 1.0 + 0.2 * np.cos(2 * np.pi * (day_of_year - 1) / 365)
+     demand = demand_daily * demand_season + 0.05 * rng.standard_normal(hours)
+     demand = np.clip(demand, 0, None)
+
+     bases = np.stack([solar, wind, demand], axis=-1)  # (hours, 3)
+
+     # --- region modifiers ---
+     region_scales = np.array(
+         [
+             [0.7, 1.3, 1.1],  # north: less solar, more wind, slightly more demand
+             [1.3, 0.7, 0.9],  # south: more solar, less wind, less demand
+             [1.0, 1.0, 1.0],  # east: baseline
+         ]
+     )  # (3 regions, 3 variables)
+
+     # (hours, variables, regions)
+     data_3d = bases[:, :, np.newaxis] * region_scales.T[np.newaxis, :, :]
+
+     # --- scenario scaling ---
+     scenario_scales = np.array([0.8, 1.2])  # low, high
+     # (hours, variables, regions, scenarios)
+     data_4d = (
+         data_3d[:, :, :, np.newaxis]
+         * scenario_scales[np.newaxis, np.newaxis, np.newaxis, :]
+     )
+
+     # Add a small amount of noise per cell
+     data_4d += 0.02 * rng.standard_normal(data_4d.shape)
+     data_4d = np.clip(data_4d, 0, None)
+
+     return xr.DataArray(
+         data_4d,
+         dims=["time", "variable", "region", "scenario"],
+         coords={
+             "time": time,
+             "variable": variables,
+             "region": regions,
+             "scenario": scenarios,
+         },
+         name="energy",
+     )
tsam_xarray/_version.py ADDED
@@ -0,0 +1,34 @@
+ # file generated by setuptools-scm
+ # don't change, don't track in version control
+
+ __all__ = [
+     "__version__",
+     "__version_tuple__",
+     "version",
+     "version_tuple",
+     "__commit_id__",
+     "commit_id",
+ ]
+
+ TYPE_CHECKING = False
+ if TYPE_CHECKING:
+     from typing import Tuple
+     from typing import Union
+
+     VERSION_TUPLE = Tuple[Union[int, str], ...]
+     COMMIT_ID = Union[str, None]
+ else:
+     VERSION_TUPLE = object
+     COMMIT_ID = object
+
+ version: str
+ __version__: str
+ __version_tuple__: VERSION_TUPLE
+ version_tuple: VERSION_TUPLE
+ commit_id: COMMIT_ID
+ __commit_id__: COMMIT_ID
+
+ __version__ = version = '0.0.1a0'
+ __version_tuple__ = version_tuple = (0, 0, 1, 'a0')
+
+ __commit_id__ = commit_id = None
tsam_xarray-0.0.1a0.dist-info/METADATA ADDED
@@ -0,0 +1,79 @@
+ Metadata-Version: 2.4
+ Name: tsam_xarray
+ Version: 0.0.1a0
+ Summary: Lightweight xarray wrapper for tsam time series aggregation
+ License-Expression: MIT
+ Requires-Python: >=3.12
+ Requires-Dist: bottleneck>=1.4
+ Requires-Dist: tsam>=3.2.0
+ Requires-Dist: xarray>=2024.1
+ Description-Content-Type: text/markdown
+
+ # tsam_xarray
+
+ Lightweight [xarray](https://xarray.dev/) wrapper for [tsam](https://github.com/FZJ-IEK3-VSA/tsam) time series aggregation.
+
+ ## Installation
+
+ ```bash
+ pip install tsam_xarray
+ ```
+
+ ## Quick start
+
+ ```python
+ import numpy as np
+ import pandas as pd
+ import xarray as xr
+ import tsam_xarray
+
+ # Create sample data: 30 days of hourly solar and wind data
+ time = pd.date_range("2020-01-01", periods=30 * 24, freq="h")
+ da = xr.DataArray(
+     np.random.default_rng(42).random((len(time), 2)),
+     dims=["time", "variable"],
+     coords={"time": time, "variable": ["solar", "wind"]},
+ )
+
+ # Aggregate to 4 typical days
+ result = tsam_xarray.aggregate(
+     da, time_dim="time", cluster_dim="variable", n_clusters=4,
+ )
+
+ result.typical_periods  # (cluster, timestep, variable)
+ result.cluster_weights  # (cluster,) — days each represents
+ result.accuracy.rmse    # (variable,) — per-variable RMSE
+ result.reconstructed    # same shape as input
+ ```
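+
+ ## Weights
+
+ The `weights` parameter (documented in the `aggregate()` docstring) lets some
+ coordinates count more during clustering; missing entries default to 1.0.
+ For example, reusing `da` from the quick start:
+
+ ```python
+ # Weight solar twice as heavily as wind when forming clusters
+ result = tsam_xarray.aggregate(
+     da,
+     time_dim="time",
+     cluster_dim="variable",
+     n_clusters=4,
+     weights={"solar": 2.0, "wind": 1.0},
+ )
+ ```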
+
+ ## Multi-dimensional data
+
+ ```python
+ # 4D data: (time, variable, region, scenario)
+ da = xr.DataArray(...)
+
+ # Cluster variable × region together; scenario is sliced independently
+ result = tsam_xarray.aggregate(
+     da,
+     time_dim="time",
+     cluster_dim=["variable", "region"],
+     n_clusters=8,
+ )
+
+ result.typical_periods  # (scenario, cluster, timestep, variable, region)
+ ```
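+
+ Results computed on the compact `(cluster, timestep)` grid can be mapped back
+ to the full time axis with `result.disaggregate()`. A small sketch, using
+ `result.typical_periods` as a stand-in for external data on that grid:
+
+ ```python
+ expanded = result.disaggregate(result.typical_periods)
+ # expanded is back on the original "time" axis; with segmentation,
+ # fill the NaN gaps, e.g. expanded.ffill(dim="time")
+ ```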
+
+ All [tsam.aggregate()](https://github.com/FZJ-IEK3-VSA/tsam) keyword arguments pass through:
+
+ ```python
+ from tsam import ClusterConfig, SegmentConfig
+
+ result = tsam_xarray.aggregate(
+     da,
+     time_dim="time",
+     cluster_dim="variable",
+     n_clusters=8,
+     cluster=ClusterConfig(method="kmeans"),
+     segments=SegmentConfig(n_segments=6),
+ )
+ ```
tsam_xarray-0.0.1a0.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+ tsam_xarray/__init__.py,sha256=99SJnPmXonkPw_ZZSNSca0x1EVNZCs4465KOG4oQX64,253
+ tsam_xarray/_core.py,sha256=ejIrJlJaVQDk1JgVT5SX_rmbOCqP2wZXxLVTEJVLQQ8,14855
+ tsam_xarray/_result.py,sha256=44O5AomGDnvDROMqmZpVRlDBfrVi_jNdwDThJPBhJZc,8187
+ tsam_xarray/_sample_data.py,sha256=yi2f5hPOUV3uUCX2kL4EfS3GV-obgQGnmidKFO5jSG4,3534
+ tsam_xarray/_version.py,sha256=N6jqqryygxntTpQZELt2H0LAGZ4wKgVPTWGAhPzx98U,712
+ tsam_xarray-0.0.1a0.dist-info/METADATA,sha256=PSXRFH12FrJSQfaSymU7PMzUFGMtOW1GepVKSPK9cDc,1980
+ tsam_xarray-0.0.1a0.dist-info/WHEEL,sha256=QccIxa26bgl1E6uMy58deGWi-0aeIkkangHcxk2kWfw,87
+ tsam_xarray-0.0.1a0.dist-info/RECORD,,
tsam_xarray-0.0.1a0.dist-info/WHEEL ADDED
@@ -0,0 +1,4 @@
+ Wheel-Version: 1.0
+ Generator: hatchling 1.29.0
+ Root-Is-Purelib: true
+ Tag: py3-none-any