flixopt 3.0.1__py3-none-any.whl → 6.0.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. flixopt/__init__.py +57 -49
  2. flixopt/carrier.py +159 -0
  3. flixopt/clustering/__init__.py +51 -0
  4. flixopt/clustering/base.py +1746 -0
  5. flixopt/clustering/intercluster_helpers.py +201 -0
  6. flixopt/color_processing.py +372 -0
  7. flixopt/comparison.py +819 -0
  8. flixopt/components.py +848 -270
  9. flixopt/config.py +853 -496
  10. flixopt/core.py +111 -98
  11. flixopt/effects.py +294 -284
  12. flixopt/elements.py +484 -223
  13. flixopt/features.py +220 -118
  14. flixopt/flow_system.py +2026 -389
  15. flixopt/interface.py +504 -286
  16. flixopt/io.py +1718 -55
  17. flixopt/linear_converters.py +291 -230
  18. flixopt/modeling.py +304 -181
  19. flixopt/network_app.py +2 -1
  20. flixopt/optimization.py +788 -0
  21. flixopt/optimize_accessor.py +373 -0
  22. flixopt/plot_result.py +143 -0
  23. flixopt/plotting.py +1177 -1034
  24. flixopt/results.py +1331 -372
  25. flixopt/solvers.py +12 -4
  26. flixopt/statistics_accessor.py +2412 -0
  27. flixopt/stats_accessor.py +75 -0
  28. flixopt/structure.py +954 -120
  29. flixopt/topology_accessor.py +676 -0
  30. flixopt/transform_accessor.py +2277 -0
  31. flixopt/types.py +120 -0
  32. flixopt-6.0.0rc7.dist-info/METADATA +290 -0
  33. flixopt-6.0.0rc7.dist-info/RECORD +36 -0
  34. {flixopt-3.0.1.dist-info → flixopt-6.0.0rc7.dist-info}/WHEEL +1 -1
  35. flixopt/aggregation.py +0 -382
  36. flixopt/calculation.py +0 -672
  37. flixopt/commons.py +0 -51
  38. flixopt/utils.py +0 -86
  39. flixopt-3.0.1.dist-info/METADATA +0 -209
  40. flixopt-3.0.1.dist-info/RECORD +0 -26
  41. {flixopt-3.0.1.dist-info → flixopt-6.0.0rc7.dist-info}/licenses/LICENSE +0 -0
  42. {flixopt-3.0.1.dist-info → flixopt-6.0.0rc7.dist-info}/top_level.txt +0 -0
flixopt/plotting.py CHANGED
@@ -25,9 +25,7 @@ accessible for standalone data visualization tasks.
25
25
 
26
26
  from __future__ import annotations
27
27
 
28
- import itertools
29
28
  import logging
30
- import os
31
29
  import pathlib
32
30
  from typing import TYPE_CHECKING, Any, Literal
33
31
 
@@ -39,14 +37,17 @@ import pandas as pd
39
37
  import plotly.express as px
40
38
  import plotly.graph_objects as go
41
39
  import plotly.offline
42
- from plotly.exceptions import PlotlyError
40
+ import xarray as xr
41
+
42
+ from .color_processing import ColorType, process_colors
43
+ from .config import CONFIG
43
44
 
44
45
  if TYPE_CHECKING:
45
46
  import pyvis
46
47
 
47
48
  logger = logging.getLogger('flixopt')
48
49
 
49
- # Define the colors for the 'portland' colormap in matplotlib
50
+ # Define the colors for the 'portland' colorscale in matplotlib
50
51
  _portland_colors = [
51
52
  [12 / 255, 51 / 255, 131 / 255], # Dark blue
52
53
  [10 / 255, 136 / 255, 186 / 255], # Light blue
@@ -55,7 +56,7 @@ _portland_colors = [
55
56
  [217 / 255, 30 / 255, 30 / 255], # Red
56
57
  ]
57
58
 
58
- # Check if the colormap already exists before registering it
59
+ # Check if the colorscale already exists before registering it
59
60
  if hasattr(plt, 'colormaps'): # Matplotlib >= 3.7
60
61
  registry = plt.colormaps
61
62
  if 'portland' not in registry:
@@ -65,486 +66,524 @@ else: # Matplotlib < 3.7
65
66
  plt.register_cmap(name='portland', cmap=mcolors.LinearSegmentedColormap.from_list('portland', _portland_colors))
66
67
 
67
68
 
68
- ColorType = str | list[str] | dict[str, str]
69
- """Flexible color specification type supporting multiple input formats for visualization.
70
-
71
- Color specifications can take several forms to accommodate different use cases:
72
-
73
- **Named Colormaps** (str):
74
- - Standard colormaps: 'viridis', 'plasma', 'cividis', 'tab10', 'Set1'
75
- - Energy-focused: 'portland' (custom flixopt colormap for energy systems)
76
- - Backend-specific maps available in Plotly and Matplotlib
77
-
78
- **Color Lists** (list[str]):
79
- - Explicit color sequences: ['red', 'blue', 'green', 'orange']
80
- - HEX codes: ['#FF0000', '#0000FF', '#00FF00', '#FFA500']
81
- - Mixed formats: ['red', '#0000FF', 'green', 'orange']
69
+ PlottingEngine = Literal['plotly', 'matplotlib']
70
+ """Identifier for the plotting engine to use."""
82
71
 
83
- **Label-to-Color Mapping** (dict[str, str]):
84
- - Explicit associations: {'Wind': 'skyblue', 'Solar': 'gold', 'Gas': 'brown'}
85
- - Ensures consistent colors across different plots and datasets
86
- - Ideal for energy system components with semantic meaning
87
72
 
88
- Examples:
89
- ```python
90
- # Named colormap
91
- colors = 'viridis' # Automatic color generation
73
+ def _ensure_dataset(data: xr.Dataset | pd.DataFrame | pd.Series) -> xr.Dataset:
74
+ """Convert DataFrame or Series to Dataset if needed."""
75
+ if isinstance(data, xr.Dataset):
76
+ return data
77
+ elif isinstance(data, pd.DataFrame):
78
+ # Convert DataFrame to Dataset
79
+ return data.to_xarray()
80
+ elif isinstance(data, pd.Series):
81
+ # Convert Series to DataFrame first, then to Dataset
82
+ return data.to_frame().to_xarray()
83
+ else:
84
+ raise TypeError(f'Data must be xr.Dataset, pd.DataFrame, or pd.Series, got {type(data).__name__}')
92
85
 
93
- # Explicit color list
94
- colors = ['red', 'blue', 'green', '#FFD700']
95
86
 
96
- # Component-specific mapping
97
- colors = {
98
- 'Wind_Turbine': 'skyblue',
99
- 'Solar_Panel': 'gold',
100
- 'Natural_Gas': 'brown',
101
- 'Battery': 'green',
102
- 'Electric_Load': 'darkred'
103
- }
104
- ```
105
-
106
- Color Format Support:
107
- - **Named Colors**: 'red', 'blue', 'forestgreen', 'darkorange'
108
- - **HEX Codes**: '#FF0000', '#0000FF', '#228B22', '#FF8C00'
109
- - **RGB Tuples**: (255, 0, 0), (0, 0, 255) [Matplotlib only]
110
- - **RGBA**: 'rgba(255,0,0,0.8)' [Plotly only]
111
-
112
- References:
113
- - HTML Color Names: https://htmlcolorcodes.com/color-names/
114
- - Matplotlib Colormaps: https://matplotlib.org/stable/tutorials/colors/colormaps.html
115
- - Plotly Built-in Colorscales: https://plotly.com/python/builtin-colorscales/
116
- """
87
+ def _validate_plotting_data(data: xr.Dataset, allow_empty: bool = False) -> None:
88
+ """Validate dataset for plotting (checks for empty data, non-numeric types, etc.)."""
89
+ # Check for empty data
90
+ if not allow_empty and len(data.data_vars) == 0:
91
+ raise ValueError('Empty Dataset provided (no variables). Cannot create plot.')
92
+
93
+ # Check if dataset has any data (xarray uses nbytes for total size)
94
+ if all(data[var].size == 0 for var in data.data_vars) if len(data.data_vars) > 0 else True:
95
+ if not allow_empty and len(data.data_vars) > 0:
96
+ raise ValueError('Dataset has zero size. Cannot create plot.')
97
+ if len(data.data_vars) == 0:
98
+ return # Empty dataset, nothing to validate
99
+ return
100
+
101
+ # Check for non-numeric data types
102
+ for var in data.data_vars:
103
+ dtype = data[var].dtype
104
+ if not np.issubdtype(dtype, np.number):
105
+ raise TypeError(
106
+ f"Variable '{var}' has non-numeric dtype '{dtype}'. "
107
+ f'Plotting requires numeric data types (int, float, etc.).'
108
+ )
117
109
 
118
- PlottingEngine = Literal['plotly', 'matplotlib']
119
- """Identifier for the plotting engine to use."""
110
+ # Warn about NaN/Inf values
111
+ for var in data.data_vars:
112
+ if np.isnan(data[var].values).any():
113
+ logger.debug(f"Variable '{var}' contains NaN values which may affect visualization.")
114
+ if np.isinf(data[var].values).any():
115
+ logger.debug(f"Variable '{var}' contains Inf values which may affect visualization.")
120
116
 
121
117
 
122
- class ColorProcessor:
123
- """Intelligent color management system for consistent multi-backend visualization.
118
+ def with_plotly(
119
+ data: xr.Dataset | pd.DataFrame | pd.Series,
120
+ mode: Literal['stacked_bar', 'line', 'area', 'grouped_bar'] = 'stacked_bar',
121
+ colors: ColorType | None = None,
122
+ title: str = '',
123
+ ylabel: str = '',
124
+ xlabel: str = '',
125
+ facet_by: str | list[str] | None = None,
126
+ animate_by: str | None = None,
127
+ facet_cols: int | None = None,
128
+ shared_yaxes: bool = True,
129
+ shared_xaxes: bool = True,
130
+ **px_kwargs: Any,
131
+ ) -> go.Figure:
132
+ """
133
+ Plot data with Plotly using facets (subplots) and/or animation for multidimensional data.
124
134
 
125
- This class provides unified color processing across Plotly and Matplotlib backends,
126
- ensuring consistent visual appearance regardless of the plotting engine used.
127
- It handles color palette generation, named colormap translation, and intelligent
128
- color cycling for complex datasets with many categories.
135
+ Uses Plotly Express for convenient faceting and animation with automatic styling.
129
136
 
130
- Key Features:
131
- **Backend Agnostic**: Automatic color format conversion between engines
132
- **Palette Management**: Support for named colormaps, custom palettes, and color lists
133
- **Intelligent Cycling**: Smart color assignment for datasets with many categories
134
- **Fallback Handling**: Graceful degradation when requested colormaps are unavailable
135
- **Energy System Colors**: Built-in palettes optimized for energy system visualization
137
+ Args:
138
+ data: An xarray Dataset, pandas DataFrame, or pandas Series to plot.
139
+ mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for lines,
140
+ 'area' for stacked area charts, or 'grouped_bar' for grouped bar charts.
141
+ colors: Color specification (colorscale, list, or dict mapping labels to colors).
142
+ title: The main title of the plot.
143
+ ylabel: The label for the y-axis.
144
+ xlabel: The label for the x-axis.
145
+ facet_by: Dimension(s) to create facets for. Creates a subplot grid.
146
+ Can be a single dimension name or list of dimensions (max 2 for facet_row and facet_col).
147
+ If the dimension doesn't exist in the data, it will be silently ignored.
148
+ animate_by: Dimension to animate over. Creates animation frames.
149
+ If the dimension doesn't exist in the data, it will be silently ignored.
150
+ facet_cols: Number of columns in the facet grid (used when facet_by is single dimension).
151
+ shared_yaxes: Whether subplots share y-axes.
152
+ shared_xaxes: Whether subplots share x-axes.
153
+ **px_kwargs: Additional keyword arguments passed to the underlying Plotly Express function
154
+ (px.bar, px.line, px.area). These override default arguments if provided.
155
+ Examples: range_x=[0, 100], range_y=[0, 50], category_orders={...}, line_shape='linear'
136
156
 
137
- Color Input Types:
138
- - **Named Colormaps**: 'viridis', 'plasma', 'portland', 'tab10', etc.
139
- - **Color Lists**: ['red', 'blue', 'green'] or ['#FF0000', '#0000FF', '#00FF00']
140
- - **Label Dictionaries**: {'Generator': 'red', 'Storage': 'blue', 'Load': 'green'}
157
+ Returns:
158
+ A Plotly figure object containing the faceted/animated plot. You can further customize
159
+ the returned figure using Plotly's methods (e.g., fig.update_traces(), fig.update_layout()).
141
160
 
142
161
  Examples:
143
- Basic color processing:
162
+ Simple plot:
144
163
 
145
164
  ```python
146
- # Initialize for Plotly backend
147
- processor = ColorProcessor(engine='plotly', default_colormap='viridis')
148
-
149
- # Process different color specifications
150
- colors = processor.process_colors('plasma', ['Gen1', 'Gen2', 'Storage'])
151
- colors = processor.process_colors(['red', 'blue', 'green'], ['A', 'B', 'C'])
152
- colors = processor.process_colors({'Wind': 'skyblue', 'Solar': 'gold'}, ['Wind', 'Solar', 'Gas'])
153
-
154
- # Switch to Matplotlib
155
- processor = ColorProcessor(engine='matplotlib')
156
- mpl_colors = processor.process_colors('tab10', component_labels)
165
+ fig = with_plotly(dataset, mode='area', title='Energy Mix')
157
166
  ```
158
167
 
159
- Energy system visualization:
168
+ Facet by scenario:
160
169
 
161
170
  ```python
162
- # Specialized energy system palette
163
- energy_colors = {
164
- 'Natural_Gas': '#8B4513', # Brown
165
- 'Electricity': '#FFD700', # Gold
166
- 'Heat': '#FF4500', # Red-orange
167
- 'Cooling': '#87CEEB', # Sky blue
168
- 'Hydrogen': '#E6E6FA', # Lavender
169
- 'Battery': '#32CD32', # Lime green
170
- }
171
-
172
- processor = ColorProcessor('plotly')
173
- flow_colors = processor.process_colors(energy_colors, flow_labels)
171
+ fig = with_plotly(dataset, facet_by='scenario', facet_cols=2)
174
172
  ```
175
173
 
176
- Args:
177
- engine: Plotting backend ('plotly' or 'matplotlib'). Determines output color format.
178
- default_colormap: Fallback colormap when requested palettes are unavailable.
179
- Common options: 'viridis', 'plasma', 'tab10', 'portland'.
174
+ Animate by period:
180
175
 
181
- """
176
+ ```python
177
+ fig = with_plotly(dataset, animate_by='period')
178
+ ```
182
179
 
183
- def __init__(self, engine: PlottingEngine = 'plotly', default_colormap: str = 'viridis'):
184
- """Initialize the color processor with specified backend and defaults."""
185
- if engine not in ['plotly', 'matplotlib']:
186
- raise TypeError(f'engine must be "plotly" or "matplotlib", but is {engine}')
187
- self.engine = engine
188
- self.default_colormap = default_colormap
189
-
190
- def _generate_colors_from_colormap(self, colormap_name: str, num_colors: int) -> list[Any]:
191
- """
192
- Generate colors from a named colormap.
193
-
194
- Args:
195
- colormap_name: Name of the colormap
196
- num_colors: Number of colors to generate
197
-
198
- Returns:
199
- list of colors in the format appropriate for the engine
200
- """
201
- if self.engine == 'plotly':
202
- try:
203
- colorscale = px.colors.get_colorscale(colormap_name)
204
- except PlotlyError as e:
205
- logger.error(f"Colorscale '{colormap_name}' not found in Plotly. Using {self.default_colormap}: {e}")
206
- colorscale = px.colors.get_colorscale(self.default_colormap)
207
-
208
- # Generate evenly spaced points
209
- color_points = [i / (num_colors - 1) for i in range(num_colors)] if num_colors > 1 else [0]
210
- return px.colors.sample_colorscale(colorscale, color_points)
211
-
212
- else: # matplotlib
213
- try:
214
- cmap = plt.get_cmap(colormap_name, num_colors)
215
- except ValueError as e:
216
- logger.error(f"Colormap '{colormap_name}' not found in Matplotlib. Using {self.default_colormap}: {e}")
217
- cmap = plt.get_cmap(self.default_colormap, num_colors)
218
-
219
- return [cmap(i) for i in range(num_colors)]
220
-
221
- def _handle_color_list(self, colors: list[str], num_labels: int) -> list[str]:
222
- """
223
- Handle a list of colors, cycling if necessary.
224
-
225
- Args:
226
- colors: list of color strings
227
- num_labels: Number of labels that need colors
228
-
229
- Returns:
230
- list of colors matching the number of labels
231
- """
232
- if len(colors) == 0:
233
- logger.error(f'Empty color list provided. Using {self.default_colormap} instead.')
234
- return self._generate_colors_from_colormap(self.default_colormap, num_labels)
235
-
236
- if len(colors) < num_labels:
237
- logger.warning(
238
- f'Not enough colors provided ({len(colors)}) for all labels ({num_labels}). Colors will cycle.'
239
- )
240
- # Cycle through the colors
241
- color_iter = itertools.cycle(colors)
242
- return [next(color_iter) for _ in range(num_labels)]
243
- else:
244
- # Trim if necessary
245
- if len(colors) > num_labels:
246
- logger.warning(
247
- f'More colors provided ({len(colors)}) than labels ({num_labels}). Extra colors will be ignored.'
248
- )
249
- return colors[:num_labels]
250
-
251
- def _handle_color_dict(self, colors: dict[str, str], labels: list[str]) -> list[str]:
252
- """
253
- Handle a dictionary mapping labels to colors.
254
-
255
- Args:
256
- colors: Dictionary mapping labels to colors
257
- labels: list of labels that need colors
258
-
259
- Returns:
260
- list of colors in the same order as labels
261
- """
262
- if len(colors) == 0:
263
- logger.warning(f'Empty color dictionary provided. Using {self.default_colormap} instead.')
264
- return self._generate_colors_from_colormap(self.default_colormap, len(labels))
265
-
266
- # Find missing labels
267
- missing_labels = sorted(set(labels) - set(colors.keys()))
268
- if missing_labels:
269
- logger.warning(
270
- f'Some labels have no color specified: {missing_labels}. Using {self.default_colormap} for these.'
271
- )
180
+ Facet and animate:
272
181
 
273
- # Generate colors for missing labels
274
- missing_colors = self._generate_colors_from_colormap(self.default_colormap, len(missing_labels))
182
+ ```python
183
+ fig = with_plotly(dataset, facet_by='scenario', animate_by='period')
184
+ ```
275
185
 
276
- # Create a copy to avoid modifying the original
277
- colors_copy = colors.copy()
278
- for i, label in enumerate(missing_labels):
279
- colors_copy[label] = missing_colors[i]
280
- else:
281
- colors_copy = colors
282
-
283
- # Create color list in the same order as labels
284
- return [colors_copy[label] for label in labels]
285
-
286
- def process_colors(
287
- self,
288
- colors: ColorType,
289
- labels: list[str],
290
- return_mapping: bool = False,
291
- ) -> list[Any] | dict[str, Any]:
292
- """
293
- Process colors for the specified labels.
294
-
295
- Args:
296
- colors: Color specification (colormap name, list of colors, or label-to-color mapping)
297
- labels: list of data labels that need colors assigned
298
- return_mapping: If True, returns a dictionary mapping labels to colors;
299
- if False, returns a list of colors in the same order as labels
300
-
301
- Returns:
302
- Either a list of colors or a dictionary mapping labels to colors
303
- """
304
- if len(labels) == 0:
305
- logger.error('No labels provided for color assignment.')
306
- return {} if return_mapping else []
307
-
308
- # Process based on type of colors input
309
- if isinstance(colors, str):
310
- color_list = self._generate_colors_from_colormap(colors, len(labels))
311
- elif isinstance(colors, list):
312
- color_list = self._handle_color_list(colors, len(labels))
313
- elif isinstance(colors, dict):
314
- color_list = self._handle_color_dict(colors, labels)
315
- else:
316
- logger.error(
317
- f'Unsupported color specification type: {type(colors)}. Using {self.default_colormap} instead.'
318
- )
319
- color_list = self._generate_colors_from_colormap(self.default_colormap, len(labels))
186
+ Customize with Plotly Express kwargs:
320
187
 
321
- # Return either a list or a mapping
322
- if return_mapping:
323
- return {label: color_list[i] for i, label in enumerate(labels)}
324
- else:
325
- return color_list
188
+ ```python
189
+ fig = with_plotly(dataset, range_y=[0, 100], line_shape='linear')
190
+ ```
326
191
 
192
+ Further customize the returned figure:
327
193
 
328
- def with_plotly(
329
- data: pd.DataFrame,
330
- style: Literal['stacked_bar', 'line', 'area', 'grouped_bar'] = 'stacked_bar',
331
- colors: ColorType = 'viridis',
332
- title: str = '',
333
- ylabel: str = '',
334
- xlabel: str = 'Time in h',
335
- fig: go.Figure | None = None,
336
- ) -> go.Figure:
194
+ ```python
195
+ fig = with_plotly(dataset, mode='line')
196
+ fig.update_traces(line={'width': 5, 'dash': 'dot'})
197
+ fig.update_layout(template='plotly_dark', width=1200, height=600)
198
+ ```
337
199
  """
338
- Plot a DataFrame with Plotly, using either stacked bars or stepped lines.
200
+ if colors is None:
201
+ colors = CONFIG.Plotting.default_qualitative_colorscale
339
202
 
340
- Args:
341
- data: A DataFrame containing the data to plot, where the index represents time (e.g., hours),
342
- and each column represents a separate data series.
343
- style: The plotting style. Use 'stacked_bar' for stacked bar charts, 'line' for stepped lines,
344
- or 'area' for stacked area charts.
345
- colors: Color specification, can be:
346
- - A string with a colorscale name (e.g., 'viridis', 'plasma')
347
- - A list of color strings (e.g., ['#ff0000', '#00ff00'])
348
- - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'})
349
- title: The title of the plot.
350
- ylabel: The label for the y-axis.
351
- xlabel: The label for the x-axis.
352
- fig: A Plotly figure object to plot on. If not provided, a new figure will be created.
203
+ if mode not in ('stacked_bar', 'line', 'area', 'grouped_bar'):
204
+ raise ValueError(f"'mode' must be one of {{'stacked_bar','line','area', 'grouped_bar'}}, got {mode!r}")
353
205
 
354
- Returns:
355
- A Plotly figure object containing the generated plot.
356
- """
357
- if style not in ('stacked_bar', 'line', 'area', 'grouped_bar'):
358
- raise ValueError(f"'style' must be one of {{'stacked_bar','line','area', 'grouped_bar'}}, got {style!r}")
359
- if data.empty:
360
- return go.Figure()
206
+ # Apply CONFIG defaults if not explicitly set
207
+ if facet_cols is None:
208
+ facet_cols = CONFIG.Plotting.default_facet_cols
361
209
 
362
- processed_colors = ColorProcessor(engine='plotly').process_colors(colors, list(data.columns))
210
+ # Ensure data is a Dataset and validate it
211
+ data = _ensure_dataset(data)
212
+ _validate_plotting_data(data, allow_empty=True)
363
213
 
364
- fig = fig if fig is not None else go.Figure()
214
+ # Handle empty data
215
+ if len(data.data_vars) == 0:
216
+ logger.error('with_plotly() got an empty Dataset.')
217
+ return go.Figure()
365
218
 
366
- if style == 'stacked_bar':
367
- for i, column in enumerate(data.columns):
368
- fig.add_trace(
369
- go.Bar(
370
- x=data.index,
371
- y=data[column],
372
- name=column,
373
- marker=dict(
374
- color=processed_colors[i], line=dict(width=0, color='rgba(0,0,0,0)')
375
- ), # Transparent line with 0 width
376
- )
377
- )
219
+ # Handle all-scalar datasets (where all variables have no dimensions)
220
+ # This occurs when all variables are scalar values with dims=()
221
+ if all(len(data[var].dims) == 0 for var in data.data_vars):
222
+ # Create a simple DataFrame with variable names as x-axis
223
+ variables = list(data.data_vars.keys())
224
+ values = [float(data[var].values) for var in data.data_vars]
378
225
 
379
- fig.update_layout(
380
- barmode='relative',
381
- bargap=0, # No space between bars
382
- bargroupgap=0, # No space between grouped bars
383
- )
384
- if style == 'grouped_bar':
385
- for i, column in enumerate(data.columns):
386
- fig.add_trace(go.Bar(x=data.index, y=data[column], name=column, marker=dict(color=processed_colors[i])))
387
-
388
- fig.update_layout(
389
- barmode='group',
390
- bargap=0.2, # No space between bars
391
- bargroupgap=0, # space between grouped bars
226
+ # Resolve colors
227
+ color_discrete_map = process_colors(
228
+ colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale
392
229
  )
393
- elif style == 'line':
394
- for i, column in enumerate(data.columns):
395
- fig.add_trace(
396
- go.Scatter(
397
- x=data.index,
398
- y=data[column],
399
- mode='lines',
400
- name=column,
401
- line=dict(shape='hv', color=processed_colors[i]),
402
- )
230
+ marker_colors = [color_discrete_map.get(var, '#636EFA') for var in variables]
231
+
232
+ # Create simple plot based on mode using go (not px) for better color control
233
+ if mode in ('stacked_bar', 'grouped_bar'):
234
+ fig = go.Figure(data=[go.Bar(x=variables, y=values, marker_color=marker_colors)])
235
+ elif mode == 'line':
236
+ fig = go.Figure(
237
+ data=[
238
+ go.Scatter(
239
+ x=variables,
240
+ y=values,
241
+ mode='lines+markers',
242
+ marker=dict(color=marker_colors, size=8),
243
+ line=dict(color='lightgray'),
244
+ )
245
+ ]
403
246
  )
404
- elif style == 'area':
405
- data = data.copy()
406
- data[(data > -1e-5) & (data < 1e-5)] = 0 # Preventing issues with plotting
407
- # Split columns into positive, negative, and mixed categories
408
- positive_columns = list(data.columns[(data >= 0).where(~np.isnan(data), True).all()])
409
- negative_columns = list(data.columns[(data <= 0).where(~np.isnan(data), True).all()])
410
- negative_columns = [column for column in negative_columns if column not in positive_columns]
411
- mixed_columns = list(set(data.columns) - set(positive_columns + negative_columns))
412
-
413
- if mixed_columns:
414
- logger.error(
415
- f'Data for plotting stacked lines contains columns with both positive and negative values:'
416
- f' {mixed_columns}. These can not be stacked, and are printed as simple lines'
247
+ elif mode == 'area':
248
+ fig = go.Figure(
249
+ data=[
250
+ go.Scatter(
251
+ x=variables,
252
+ y=values,
253
+ fill='tozeroy',
254
+ marker=dict(color=marker_colors, size=8),
255
+ line=dict(color='lightgray'),
256
+ )
257
+ ]
417
258
  )
418
-
419
- # Get color mapping for all columns
420
- colors_stacked = {column: processed_colors[i] for i, column in enumerate(data.columns)}
421
-
422
- for column in positive_columns + negative_columns:
423
- fig.add_trace(
424
- go.Scatter(
425
- x=data.index,
426
- y=data[column],
427
- mode='lines',
428
- name=column,
429
- line=dict(shape='hv', color=colors_stacked[column]),
430
- fill='tonexty',
431
- stackgroup='pos' if column in positive_columns else 'neg',
259
+ else:
260
+ raise ValueError('"mode" must be one of "stacked_bar", "grouped_bar", "line", "area"')
261
+
262
+ fig.update_layout(title=title, xaxis_title=xlabel, yaxis_title=ylabel, showlegend=False)
263
+ return fig
264
+
265
+ # Convert Dataset to long-form DataFrame for Plotly Express
266
+ # Structure: time, variable, value, scenario, period, ... (all dims as columns)
267
+ dim_names = list(data.dims)
268
+ df_long = data.to_dataframe().reset_index().melt(id_vars=dim_names, var_name='variable', value_name='value')
269
+
270
+ # Validate facet_by and animate_by dimensions exist in the data
271
+ available_dims = [col for col in df_long.columns if col not in ['variable', 'value']]
272
+
273
+ # Check facet_by dimensions
274
+ if facet_by is not None:
275
+ if isinstance(facet_by, str):
276
+ if facet_by not in available_dims:
277
+ logger.debug(
278
+ f"Dimension '{facet_by}' not found in data. Available dimensions: {available_dims}. "
279
+ f'Ignoring facet_by parameter.'
432
280
  )
433
- )
434
-
435
- for column in mixed_columns:
436
- fig.add_trace(
437
- go.Scatter(
438
- x=data.index,
439
- y=data[column],
440
- mode='lines',
441
- name=column,
442
- line=dict(shape='hv', color=colors_stacked[column], dash='dash'),
281
+ facet_by = None
282
+ elif isinstance(facet_by, list):
283
+ # Filter out dimensions that don't exist
284
+ missing_dims = [dim for dim in facet_by if dim not in available_dims]
285
+ facet_by = [dim for dim in facet_by if dim in available_dims]
286
+ if missing_dims:
287
+ logger.debug(
288
+ f'Dimensions {missing_dims} not found in data. Available dimensions: {available_dims}. '
289
+ f'Using only existing dimensions: {facet_by if facet_by else "none"}.'
443
290
  )
444
- )
291
+ if len(facet_by) == 0:
292
+ facet_by = None
293
+
294
+ # Check animate_by dimension
295
+ if animate_by is not None and animate_by not in available_dims:
296
+ logger.debug(
297
+ f"Dimension '{animate_by}' not found in data. Available dimensions: {available_dims}. "
298
+ f'Ignoring animate_by parameter.'
299
+ )
300
+ animate_by = None
301
+
302
+ # Setup faceting parameters for Plotly Express
303
+ facet_row = None
304
+ facet_col = None
305
+ if facet_by:
306
+ if isinstance(facet_by, str):
307
+ # Single facet dimension - use facet_col with facet_col_wrap
308
+ facet_col = facet_by
309
+ elif len(facet_by) == 1:
310
+ facet_col = facet_by[0]
311
+ elif len(facet_by) == 2:
312
+ # Two facet dimensions - use facet_row and facet_col
313
+ facet_row = facet_by[0]
314
+ facet_col = facet_by[1]
315
+ else:
316
+ raise ValueError(f'facet_by can have at most 2 dimensions, got {len(facet_by)}')
445
317
 
446
- # Update layout for better aesthetics
447
- fig.update_layout(
448
- title=title,
449
- yaxis=dict(
450
- title=ylabel,
451
- showgrid=True, # Enable grid lines on the y-axis
452
- gridcolor='lightgrey', # Customize grid line color
453
- gridwidth=0.5, # Customize grid line width
454
- ),
455
- xaxis=dict(
456
- title=xlabel,
457
- showgrid=True, # Enable grid lines on the x-axis
458
- gridcolor='lightgrey', # Customize grid line color
459
- gridwidth=0.5, # Customize grid line width
460
- ),
461
- plot_bgcolor='rgba(0,0,0,0)', # Transparent background
462
- paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background
463
- font=dict(size=14), # Increase font size for better readability
318
+ # Process colors
319
+ all_vars = df_long['variable'].unique().tolist()
320
+ color_discrete_map = process_colors(
321
+ colors, all_vars, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale
464
322
  )
465
323
 
324
+ # Determine which dimension to use for x-axis
325
+ # Collect dimensions used for faceting and animation
326
+ used_dims = set()
327
+ if facet_row:
328
+ used_dims.add(facet_row)
329
+ if facet_col:
330
+ used_dims.add(facet_col)
331
+ if animate_by:
332
+ used_dims.add(animate_by)
333
+
334
+ # Find available dimensions for x-axis (not used for faceting/animation)
335
+ x_candidates = [d for d in available_dims if d not in used_dims]
336
+
337
+ # Use 'time' if available, otherwise use the first available dimension
338
+ if 'time' in x_candidates:
339
+ x_dim = 'time'
340
+ elif len(x_candidates) > 0:
341
+ x_dim = x_candidates[0]
342
+ else:
343
+ # Fallback: use the first dimension (shouldn't happen in normal cases)
344
+ x_dim = available_dims[0] if available_dims else 'time'
345
+
346
+ # Create plot using Plotly Express based on mode
347
+ common_args = {
348
+ 'data_frame': df_long,
349
+ 'x': x_dim,
350
+ 'y': 'value',
351
+ 'color': 'variable',
352
+ 'facet_row': facet_row,
353
+ 'facet_col': facet_col,
354
+ 'animation_frame': animate_by,
355
+ 'color_discrete_map': color_discrete_map,
356
+ 'title': title,
357
+ 'labels': {'value': ylabel, x_dim: xlabel, 'variable': ''},
358
+ }
359
+
360
+ # Add facet_col_wrap for single facet dimension
361
+ if facet_col and not facet_row:
362
+ common_args['facet_col_wrap'] = facet_cols
363
+
364
+ # Add mode-specific defaults (before px_kwargs so they can be overridden)
365
+ if mode in ('line', 'area'):
366
+ common_args['line_shape'] = 'hv' # Stepped lines by default
367
+
368
+ # Allow callers to pass any px.* keyword args (e.g., category_orders, range_x/y, line_shape)
369
+ # These will override the defaults set above
370
+ if px_kwargs:
371
+ common_args.update(px_kwargs)
372
+
373
+ if mode == 'stacked_bar':
374
+ fig = px.bar(**common_args)
375
+ fig.update_traces(marker_line_width=0)
376
+ fig.update_layout(barmode='relative', bargap=0, bargroupgap=0)
377
+ elif mode == 'grouped_bar':
378
+ fig = px.bar(**common_args)
379
+ fig.update_layout(barmode='group', bargap=0.2, bargroupgap=0)
380
+ elif mode == 'line':
381
+ fig = px.line(**common_args)
382
+ elif mode == 'area':
383
+ # Use Plotly Express to create the area plot (preserves animation, legends, faceting)
384
+ fig = px.area(**common_args)
385
+
386
+ # Classify each variable based on its values
387
+ variable_classification = {}
388
+ for var in all_vars:
389
+ var_data = df_long[df_long['variable'] == var]['value']
390
+ var_data_clean = var_data[(var_data < -1e-5) | (var_data > 1e-5)]
391
+
392
+ if len(var_data_clean) == 0:
393
+ variable_classification[var] = 'zero'
394
+ else:
395
+ has_pos, has_neg = (var_data_clean > 0).any(), (var_data_clean < 0).any()
396
+ variable_classification[var] = (
397
+ 'mixed' if has_pos and has_neg else ('negative' if has_neg else 'positive')
398
+ )
399
+
400
+ # Log warning for mixed variables
401
+ mixed_vars = [v for v, c in variable_classification.items() if c == 'mixed']
402
+ if mixed_vars:
403
+ logger.warning(f'Variables with both positive and negative values: {mixed_vars}. Plotted as dashed lines.')
404
+
405
+ all_traces = list(fig.data)
406
+ for frame in fig.frames:
407
+ all_traces.extend(frame.data)
408
+
409
+ for trace in all_traces:
410
+ cls = variable_classification.get(trace.name, None)
411
+ # Only stack positive and negative, not mixed or zero
412
+ trace.stackgroup = cls if cls in ('positive', 'negative') else None
413
+
414
+ if cls in ('positive', 'negative'):
415
+ # Stacked area: add opacity to avoid hiding layers, remove line border
416
+ if hasattr(trace, 'line') and trace.line.color:
417
+ trace.fillcolor = trace.line.color
418
+ trace.line.width = 0
419
+ elif cls == 'mixed':
420
+ # Mixed variables: show as dashed line, not stacked
421
+ if hasattr(trace, 'line'):
422
+ trace.line.width = 2
423
+ trace.line.dash = 'dash'
424
+ if hasattr(trace, 'fill'):
425
+ trace.fill = None
426
+
427
+ # Update axes to share if requested (Plotly Express already handles this, but we can customize)
428
+ if not shared_yaxes:
429
+ fig.update_yaxes(matches=None)
430
+ if not shared_xaxes:
431
+ fig.update_xaxes(matches=None)
432
+
466
433
  return fig
467
434
 
468
435
 
469
436
  def with_matplotlib(
470
- data: pd.DataFrame,
471
- style: Literal['stacked_bar', 'line'] = 'stacked_bar',
472
- colors: ColorType = 'viridis',
437
+ data: xr.Dataset | pd.DataFrame | pd.Series,
438
+ mode: Literal['stacked_bar', 'line'] = 'stacked_bar',
439
+ colors: ColorType | None = None,
473
440
  title: str = '',
474
441
  ylabel: str = '',
475
442
  xlabel: str = 'Time in h',
476
443
  figsize: tuple[int, int] = (12, 6),
477
- fig: plt.Figure | None = None,
478
- ax: plt.Axes | None = None,
444
+ plot_kwargs: dict[str, Any] | None = None,
479
445
  ) -> tuple[plt.Figure, plt.Axes]:
480
446
  """
481
- Plot a DataFrame with Matplotlib using stacked bars or stepped lines.
447
+ Plot data with Matplotlib using stacked bars or stepped lines.
482
448
 
483
449
  Args:
484
- data: A DataFrame containing the data to plot. The index should represent time (e.g., hours),
485
- and each column represents a separate data series.
486
- style: Plotting style. Use 'stacked_bar' for stacked bar charts or 'line' for stepped lines.
487
- colors: Color specification, can be:
488
- - A string with a colormap name (e.g., 'viridis', 'plasma')
450
+ data: An xarray Dataset, pandas DataFrame, or pandas Series to plot. After conversion to DataFrame,
451
+ the index represents time and each column represents a separate data series (variables).
452
+ mode: Plotting mode. Use 'stacked_bar' for stacked bar charts or 'line' for stepped lines.
453
+ colors: Color specification. Can be:
454
+ - A colorscale name (e.g., 'turbo', 'plasma')
489
455
  - A list of color strings (e.g., ['#ff0000', '#00ff00'])
490
- - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'})
456
+ - A dict mapping column names to colors (e.g., {'Column1': '#ff0000'})
491
457
  title: The title of the plot.
492
458
  ylabel: The ylabel of the plot.
493
459
  xlabel: The xlabel of the plot.
494
- figsize: Specify the size of the figure
495
- fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created.
496
- ax: A Matplotlib axes object to plot on. If not provided, a new axes will be created.
460
+ figsize: Specify the size of the figure (width, height) in inches.
461
+ plot_kwargs: Optional dict of parameters to pass to ax.bar() or ax.step() plotting calls.
462
+ Use this to customize plot properties (e.g., linewidth, alpha, edgecolor).
497
463
 
498
464
  Returns:
499
465
  A tuple containing the Matplotlib figure and axes objects used for the plot.
500
466
 
501
467
  Notes:
502
- - If `style` is 'stacked_bar', bars are stacked for both positive and negative values.
468
+ - If `mode` is 'stacked_bar', bars are stacked for both positive and negative values.
503
469
  Negative values are stacked separately without extra labels in the legend.
504
- - If `style` is 'line', stepped lines are drawn for each data series.
470
+ - If `mode` is 'line', stepped lines are drawn for each data series.
505
471
  """
506
- if style not in ('stacked_bar', 'line'):
507
- raise ValueError(f"'style' must be one of {{'stacked_bar','line'}} for matplotlib, got {style!r}")
472
+ if colors is None:
473
+ colors = CONFIG.Plotting.default_qualitative_colorscale
508
474
 
509
- if fig is None or ax is None:
510
- fig, ax = plt.subplots(figsize=figsize)
475
+ if mode not in ('stacked_bar', 'line'):
476
+ raise ValueError(f"'mode' must be one of {{'stacked_bar','line'}} for matplotlib, got {mode!r}")
477
+
478
+ # Ensure data is a Dataset and validate it
479
+ data = _ensure_dataset(data)
480
+ _validate_plotting_data(data, allow_empty=True)
481
+
482
+ # Create new figure and axes
483
+ fig, ax = plt.subplots(figsize=figsize)
511
484
 
512
- processed_colors = ColorProcessor(engine='matplotlib').process_colors(colors, list(data.columns))
485
+ # Initialize plot_kwargs if not provided
486
+ if plot_kwargs is None:
487
+ plot_kwargs = {}
513
488
 
514
- if style == 'stacked_bar':
515
- cumulative_positive = np.zeros(len(data))
516
- cumulative_negative = np.zeros(len(data))
517
- width = data.index.to_series().diff().dropna().min() # Minimum time difference
489
+ # Handle all-scalar datasets (where all variables have no dimensions)
490
+ # This occurs when all variables are scalar values with dims=()
491
+ if all(len(data[var].dims) == 0 for var in data.data_vars):
492
+ # Create simple bar/line plot with variable names as x-axis
493
+ variables = list(data.data_vars.keys())
494
+ values = [float(data[var].values) for var in data.data_vars]
518
495
 
519
- for i, column in enumerate(data.columns):
520
- positive_values = np.clip(data[column], 0, None) # Keep only positive values
521
- negative_values = np.clip(data[column], None, 0) # Keep only negative values
496
+ # Resolve colors
497
+ color_discrete_map = process_colors(
498
+ colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale
499
+ )
500
+ colors_list = [color_discrete_map.get(var, '#808080') for var in variables]
501
+
502
+ # Create plot based on mode
503
+ if mode == 'stacked_bar':
504
+ ax.bar(variables, values, color=colors_list, **plot_kwargs)
505
+ elif mode == 'line':
506
+ ax.plot(
507
+ variables,
508
+ values,
509
+ marker='o',
510
+ color=colors_list[0] if len(set(colors_list)) == 1 else None,
511
+ **plot_kwargs,
512
+ )
513
+ # If different colors, plot each point separately
514
+ if len(set(colors_list)) > 1:
515
+ ax.clear()
516
+ for i, (var, val) in enumerate(zip(variables, values, strict=False)):
517
+ ax.plot([i], [val], marker='o', color=colors_list[i], label=var, **plot_kwargs)
518
+ ax.set_xticks(range(len(variables)))
519
+ ax.set_xticklabels(variables)
520
+
521
+ ax.set_xlabel(xlabel, ha='center')
522
+ ax.set_ylabel(ylabel, va='center')
523
+ ax.set_title(title)
524
+ ax.grid(color='lightgrey', linestyle='-', linewidth=0.5, axis='y')
525
+ fig.tight_layout()
526
+
527
+ return fig, ax
528
+
529
+ # Resolve colors first (includes validation)
530
+ color_discrete_map = process_colors(
531
+ colors, list(data.data_vars), default_colorscale=CONFIG.Plotting.default_qualitative_colorscale
532
+ )
533
+
534
+ # Convert Dataset to DataFrame for matplotlib plotting (naturally wide-form)
535
+ df = data.to_dataframe()
536
+
537
+ # Get colors in column order
538
+ processed_colors = [color_discrete_map.get(str(col), '#808080') for col in df.columns]
539
+
540
+ if mode == 'stacked_bar':
541
+ cumulative_positive = np.zeros(len(df))
542
+ cumulative_negative = np.zeros(len(df))
543
+
544
+ # Robust bar width: handle datetime-like, numeric, and single-point indexes
545
+ if len(df.index) > 1:
546
+ delta = pd.Index(df.index).to_series().diff().dropna().min()
547
+ if hasattr(delta, 'total_seconds'): # datetime-like
548
+ width = delta.total_seconds() / 86400.0 # Matplotlib date units = days
549
+ else:
550
+ width = float(delta)
551
+ else:
552
+ width = 0.8 # reasonable default for a single bar
553
+
554
+ for i, column in enumerate(df.columns):
555
+ # Fill NaNs to avoid breaking stacking math
556
+ series = df[column].fillna(0)
557
+ positive_values = np.clip(series, 0, None) # Keep only positive values
558
+ negative_values = np.clip(series, None, 0) # Keep only negative values
522
559
  # Plot positive bars
523
560
  ax.bar(
524
- data.index,
561
+ df.index,
525
562
  positive_values,
526
563
  bottom=cumulative_positive,
527
564
  color=processed_colors[i],
528
565
  label=column,
529
566
  width=width,
530
567
  align='center',
568
+ **plot_kwargs,
531
569
  )
532
570
  cumulative_positive += positive_values.values
533
571
  # Plot negative bars
534
572
  ax.bar(
535
- data.index,
573
+ df.index,
536
574
  negative_values,
537
575
  bottom=cumulative_negative,
538
576
  color=processed_colors[i],
539
577
  label='', # No label for negative bars
540
578
  width=width,
541
579
  align='center',
580
+ **plot_kwargs,
542
581
  )
543
582
  cumulative_negative += negative_values.values
544
583
 
545
- elif style == 'line':
546
- for i, column in enumerate(data.columns):
547
- ax.step(data.index, data[column], where='post', color=processed_colors[i], label=column)
584
+ elif mode == 'line':
585
+ for i, column in enumerate(df.columns):
586
+ ax.step(df.index, df[column], where='post', color=processed_colors[i], label=column, **plot_kwargs)
548
587
 
549
588
  # Aesthetics
550
589
  ax.set_xlabel(xlabel, ha='center')
@@ -562,213 +601,110 @@ def with_matplotlib(
562
601
  return fig, ax
563
602
 
564
603
 
565
- def heat_map_matplotlib(
566
- data: pd.DataFrame,
567
- color_map: str = 'viridis',
568
- title: str = '',
569
- xlabel: str = 'Period',
570
- ylabel: str = 'Step',
571
- figsize: tuple[float, float] = (12, 6),
572
- ) -> tuple[plt.Figure, plt.Axes]:
573
- """
574
- Plots a DataFrame as a heatmap using Matplotlib. The columns of the DataFrame will be displayed on the x-axis,
575
- the index will be displayed on the y-axis, and the values will represent the 'heat' intensity in the plot.
576
-
577
- Args:
578
- data: A DataFrame containing the data to be visualized. The index will be used for the y-axis, and columns will be used for the x-axis.
579
- The values in the DataFrame will be represented as colors in the heatmap.
580
- color_map: The colormap to use for the heatmap. Default is 'viridis'. Matplotlib supports various colormaps like 'plasma', 'inferno', 'cividis', etc.
581
- title: The title of the plot.
582
- xlabel: The label for the x-axis.
583
- ylabel: The label for the y-axis.
584
- figsize: The size of the figure to create. Default is (12, 6), which results in a width of 12 inches and a height of 6 inches.
585
-
586
- Returns:
587
- A tuple containing the Matplotlib `Figure` and `Axes` objects. The `Figure` contains the overall plot, while the `Axes` is the area
588
- where the heatmap is drawn. These can be used for further customization or saving the plot to a file.
589
-
590
- Notes:
591
- - The y-axis is flipped so that the first row of the DataFrame is displayed at the top of the plot.
592
- - The color scale is normalized based on the minimum and maximum values in the DataFrame.
593
- - The x-axis labels (periods) are placed at the top of the plot.
594
- - The colorbar is added horizontally at the bottom of the plot, with a label.
595
- """
596
-
597
- # Get the min and max values for color normalization
598
- color_bar_min, color_bar_max = data.min().min(), data.max().max()
599
-
600
- # Create the heatmap plot
601
- fig, ax = plt.subplots(figsize=figsize)
602
- ax.pcolormesh(data.values, cmap=color_map, shading='auto')
603
- ax.invert_yaxis() # Flip the y-axis to start at the top
604
-
605
- # Adjust ticks and labels for x and y axes
606
- ax.set_xticks(np.arange(len(data.columns)) + 0.5)
607
- ax.set_xticklabels(data.columns, ha='center')
608
- ax.set_yticks(np.arange(len(data.index)) + 0.5)
609
- ax.set_yticklabels(data.index, va='center')
610
-
611
- # Add labels to the axes
612
- ax.set_xlabel(xlabel, ha='center')
613
- ax.set_ylabel(ylabel, va='center')
614
- ax.set_title(title)
615
-
616
- # Position x-axis labels at the top
617
- ax.xaxis.set_label_position('top')
618
- ax.xaxis.set_ticks_position('top')
619
-
620
- # Add the colorbar
621
- sm1 = plt.cm.ScalarMappable(cmap=color_map, norm=plt.Normalize(vmin=color_bar_min, vmax=color_bar_max))
622
- sm1.set_array([])
623
- fig.colorbar(sm1, ax=ax, pad=0.12, aspect=15, fraction=0.2, orientation='horizontal')
624
-
625
- fig.tight_layout()
626
-
627
- return fig, ax
628
-
629
-
630
- def heat_map_plotly(
631
- data: pd.DataFrame,
632
- color_map: str = 'viridis',
633
- title: str = '',
634
- xlabel: str = 'Period',
635
- ylabel: str = 'Step',
636
- categorical_labels: bool = True,
637
- ) -> go.Figure:
604
+ def reshape_data_for_heatmap(
605
+ data: xr.DataArray,
606
+ reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
607
+ | Literal['auto']
608
+ | None = 'auto',
609
+ facet_by: str | list[str] | None = None,
610
+ animate_by: str | None = None,
611
+ fill: Literal['ffill', 'bfill'] | None = 'ffill',
612
+ ) -> xr.DataArray:
638
613
  """
639
- Plots a DataFrame as a heatmap using Plotly. The columns of the DataFrame will be mapped to the x-axis,
640
- and the index will be displayed on the y-axis. The values in the DataFrame will represent the 'heat' in the plot.
614
+ Reshape data for heatmap visualization, handling time dimension intelligently.
641
615
 
642
- Args:
643
- data: A DataFrame with the data to be visualized. The index will be used for the y-axis, and columns will be used for the x-axis.
644
- The values in the DataFrame will be represented as colors in the heatmap.
645
- color_map: The color scale to use for the heatmap. Default is 'viridis'. Plotly supports various color scales like 'Cividis', 'Inferno', etc.
646
- title: The title of the heatmap. Default is an empty string.
647
- xlabel: The label for the x-axis. Default is 'Period'.
648
- ylabel: The label for the y-axis. Default is 'Step'.
649
- categorical_labels: If True, the x and y axes are treated as categorical data (i.e., the index and columns will not be interpreted as continuous data).
650
- Default is True. If False, the axes are treated as continuous, which may be useful for time series or numeric data.
616
+ This function decides whether to reshape the 'time' dimension based on the reshape_time parameter:
617
+ - 'auto': Automatically reshapes if only 'time' dimension would remain for heatmap
618
+ - Tuple: Explicitly reshapes time with specified parameters
619
+ - None: No reshaping (returns data as-is)
651
620
 
652
- Returns:
653
- A Plotly figure object containing the heatmap. This can be further customized and saved
654
- or displayed using `fig.show()`.
655
-
656
- Notes:
657
- The color bar is automatically scaled to the minimum and maximum values in the data.
658
- The y-axis is reversed to display the first row at the top.
659
- """
660
-
661
- color_bar_min, color_bar_max = data.min().min(), data.max().max() # Min and max values for color scaling
662
- # Define the figure
663
- fig = go.Figure(
664
- data=go.Heatmap(
665
- z=data.values,
666
- x=data.columns,
667
- y=data.index,
668
- colorscale=color_map,
669
- zmin=color_bar_min,
670
- zmax=color_bar_max,
671
- colorbar=dict(
672
- title=dict(text='Color Bar Label', side='right'),
673
- orientation='h',
674
- xref='container',
675
- yref='container',
676
- len=0.8, # Color bar length relative to plot
677
- x=0.5,
678
- y=0.1,
679
- ),
680
- )
681
- )
682
-
683
- # Set axis labels and style
684
- fig.update_layout(
685
- title=title,
686
- xaxis=dict(title=xlabel, side='top', type='category' if categorical_labels else None),
687
- yaxis=dict(title=ylabel, autorange='reversed', type='category' if categorical_labels else None),
688
- )
689
-
690
- return fig
691
-
692
-
693
- def reshape_to_2d(data_1d: np.ndarray, nr_of_steps_per_column: int) -> np.ndarray:
694
- """
695
- Reshapes a 1D numpy array into a 2D array suitable for plotting as a colormap.
696
-
697
- The reshaped array will have the number of rows corresponding to the steps per column
698
- (e.g., 24 hours per day) and columns representing time periods (e.g., days or months).
621
+ All non-time dimensions are preserved during reshaping.
699
622
 
700
623
  Args:
701
- data_1d: A 1D numpy array with the data to reshape.
702
- nr_of_steps_per_column: The number of steps (rows) per column in the resulting 2D array. For example,
703
- this could be 24 (for hours) or 31 (for days in a month).
624
+ data: DataArray to reshape for heatmap visualization.
625
+ reshape_time: Reshaping configuration:
626
+ - 'auto' (default): Auto-reshape if needed based on facet_by/animate_by
627
+ - Tuple (timeframes, timesteps_per_frame): Explicit time reshaping
628
+ - None: No reshaping
629
+ facet_by: Dimension(s) used for faceting (used in 'auto' decision).
630
+ animate_by: Dimension used for animation (used in 'auto' decision).
631
+ fill: Method to fill missing values: 'ffill' or 'bfill'. Default is 'ffill'.
704
632
 
705
633
  Returns:
706
- The reshaped 2D array. Each internal array corresponds to one column, with the specified number of steps.
707
- Each column might represents a time period (e.g., day, month, etc.).
708
- """
709
-
710
- # Step 1: Ensure the input is a 1D array.
711
- if data_1d.ndim != 1:
712
- raise ValueError('Input must be a 1D array')
634
+ Reshaped DataArray. If time reshaping is applied, 'time' dimension is replaced
635
+ by 'timestep' and 'timeframe'. All other dimensions are preserved.
713
636
 
714
- # Step 2: Convert data to float type to allow NaN padding
715
- if data_1d.dtype != np.float64:
716
- data_1d = data_1d.astype(np.float64)
717
-
718
- # Step 3: Calculate the number of columns required
719
- total_steps = len(data_1d)
720
- cols = len(data_1d) // nr_of_steps_per_column # Base number of columns
721
-
722
- # If there's a remainder, add an extra column to hold the remaining values
723
- if total_steps % nr_of_steps_per_column != 0:
724
- cols += 1
637
+ Examples:
638
+ Auto-reshaping:
725
639
 
726
- # Step 4: Pad the 1D data to match the required number of rows and columns
727
- padded_data = np.pad(
728
- data_1d, (0, cols * nr_of_steps_per_column - total_steps), mode='constant', constant_values=np.nan
729
- )
640
+ ```python
641
+ # Will auto-reshape because only 'time' remains after faceting/animation
642
+ data = reshape_data_for_heatmap(data, reshape_time='auto', facet_by='scenario', animate_by='period')
643
+ ```
730
644
 
731
- # Step 5: Reshape the padded data into a 2D array
732
- data_2d = padded_data.reshape(cols, nr_of_steps_per_column)
645
+ Explicit reshaping:
733
646
 
734
- return data_2d.T
647
+ ```python
648
+ # Explicitly reshape to daily pattern
649
+ data = reshape_data_for_heatmap(data, reshape_time=('D', 'h'))
650
+ ```
735
651
 
652
+ No reshaping:
736
653
 
737
- def heat_map_data_from_df(
738
- df: pd.DataFrame,
739
- periods: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'],
740
- steps_per_period: Literal['W', 'D', 'h', '15min', 'min'],
741
- fill: Literal['ffill', 'bfill'] | None = None,
742
- ) -> pd.DataFrame:
654
+ ```python
655
+ # Keep data as-is
656
+ data = reshape_data_for_heatmap(data, reshape_time=None)
657
+ ```
743
658
  """
744
- Reshapes a DataFrame with a DateTime index into a 2D array for heatmap plotting,
745
- based on a specified sample rate.
746
- Only specific combinations of `periods` and `steps_per_period` are supported; invalid combinations raise an assertion.
747
-
748
- Args:
749
- df: A DataFrame with a DateTime index containing the data to reshape.
750
- periods: The time interval of each period (columns of the heatmap),
751
- such as 'YS' (year start), 'W' (weekly), 'D' (daily), 'h' (hourly) etc.
752
- steps_per_period: The time interval within each period (rows in the heatmap),
753
- such as 'YS' (year start), 'W' (weekly), 'D' (daily), 'h' (hourly) etc.
754
- fill: Method to fill missing values: 'ffill' for forward fill or 'bfill' for backward fill.
659
+ # If no time dimension, return data as-is
660
+ if 'time' not in data.dims:
661
+ return data
662
+
663
+ # Handle None (disabled) - return data as-is
664
+ if reshape_time is None:
665
+ return data
666
+
667
+ # Determine timeframes and timesteps_per_frame based on reshape_time parameter
668
+ if reshape_time == 'auto':
669
+ # Check if we need automatic time reshaping
670
+ facet_dims_used = []
671
+ if facet_by:
672
+ facet_dims_used = [facet_by] if isinstance(facet_by, str) else list(facet_by)
673
+ if animate_by:
674
+ facet_dims_used.append(animate_by)
675
+
676
+ # Get dimensions that would remain for heatmap
677
+ potential_heatmap_dims = [dim for dim in data.dims if dim not in facet_dims_used]
678
+
679
+ # Auto-reshape if only 'time' dimension remains
680
+ if len(potential_heatmap_dims) == 1 and potential_heatmap_dims[0] == 'time':
681
+ logger.debug(
682
+ "Auto-applying time reshaping: Only 'time' dimension remains after faceting/animation. "
683
+ "Using default timeframes='D' and timesteps_per_frame='h'. "
684
+ "To customize, use reshape_time=('D', 'h') or disable with reshape_time=None."
685
+ )
686
+ timeframes, timesteps_per_frame = 'D', 'h'
687
+ else:
688
+ # No reshaping needed
689
+ return data
690
+ elif isinstance(reshape_time, tuple):
691
+ # Explicit reshaping
692
+ timeframes, timesteps_per_frame = reshape_time
693
+ else:
694
+ raise ValueError(f"reshape_time must be 'auto', a tuple like ('D', 'h'), or None. Got: {reshape_time}")
755
695
 
756
- Returns:
757
- A DataFrame suitable for heatmap plotting, with rows representing steps within each period
758
- and columns representing each period.
759
- """
760
- assert pd.api.types.is_datetime64_any_dtype(df.index), (
761
- 'The index of the DataFrame must be datetime to transform it properly for a heatmap plot'
762
- )
696
+ # Validate that time is datetime
697
+ if not np.issubdtype(data.coords['time'].dtype, np.datetime64):
698
+ raise ValueError(f'Time dimension must be datetime-based, got {data.coords["time"].dtype}')
763
699
 
764
- # Define formats for different combinations of `periods` and `steps_per_period`
700
+ # Define formats for different combinations
765
701
  formats = {
766
702
  ('YS', 'W'): ('%Y', '%W'),
767
703
  ('YS', 'D'): ('%Y', '%j'), # day of year
768
704
  ('YS', 'h'): ('%Y', '%j %H:00'),
769
705
  ('MS', 'D'): ('%Y-%m', '%d'), # day of month
770
706
  ('MS', 'h'): ('%Y-%m', '%d %H:00'),
771
- ('W', 'D'): ('%Y-w%W', '%w_%A'), # week and day of week (with prefix for proper sorting)
707
+ ('W', 'D'): ('%Y-w%W', '%w_%A'), # week and day of week
772
708
  ('W', 'h'): ('%Y-w%W', '%w_%A %H:00'),
773
709
  ('D', 'h'): ('%Y-%m-%d', '%H:00'), # Day and hour
774
710
  ('D', '15min'): ('%Y-%m-%d', '%H:%M'), # Day and minute
@@ -776,43 +712,64 @@ def heat_map_data_from_df(
776
712
  ('h', 'min'): ('%Y-%m-%d %H:00', '%M'), # minute of hour
777
713
  }
778
714
 
779
- if df.empty:
780
- raise ValueError('DataFrame is empty.')
781
- diffs = df.index.to_series().diff().dropna()
782
- minimum_time_diff_in_min = diffs.min().total_seconds() / 60
783
- time_intervals = {'min': 1, '15min': 15, 'h': 60, 'D': 24 * 60, 'W': 7 * 24 * 60}
784
- if time_intervals[steps_per_period] > minimum_time_diff_in_min:
785
- logger.error(
786
- f'To compute the heatmap, the data was aggregated from {minimum_time_diff_in_min:.2f} min to '
787
- f'{time_intervals[steps_per_period]:.2f} min. Mean values are displayed.'
788
- )
789
-
790
- # Select the format based on the `periods` and `steps_per_period` combination
791
- format_pair = (periods, steps_per_period)
715
+ format_pair = (timeframes, timesteps_per_frame)
792
716
  if format_pair not in formats:
793
717
  raise ValueError(f'{format_pair} is not a valid format. Choose from {list(formats.keys())}')
794
718
  period_format, step_format = formats[format_pair]
795
719
 
796
- df = df.sort_index() # Ensure DataFrame is sorted by time index
720
+ # Check if resampling is needed
721
+ if data.sizes['time'] > 1:
722
+ # Use NumPy for more efficient timedelta computation
723
+ time_values = data.coords['time'].values # Already numpy datetime64[ns]
724
+ # Calculate differences and convert to minutes
725
+ time_diffs = np.diff(time_values).astype('timedelta64[s]').astype(float) / 60.0
726
+ if time_diffs.size > 0:
727
+ min_time_diff_min = np.nanmin(time_diffs)
728
+ time_intervals = {'min': 1, '15min': 15, 'h': 60, 'D': 24 * 60, 'W': 7 * 24 * 60}
729
+ if time_intervals[timesteps_per_frame] > min_time_diff_min:
730
+ logger.warning(
731
+ f'Resampling data from {min_time_diff_min:.2f} min to '
732
+ f'{time_intervals[timesteps_per_frame]:.2f} min. Mean values are displayed.'
733
+ )
797
734
 
798
- resampled_data = df.resample(steps_per_period).mean() # Resample and fill any gaps with NaN
735
+ # Resample along time dimension
736
+ resampled = data.resample(time=timesteps_per_frame).mean()
799
737
 
800
- if fill == 'ffill': # Apply fill method if specified
801
- resampled_data = resampled_data.ffill()
738
+ # Apply fill if specified
739
+ if fill == 'ffill':
740
+ resampled = resampled.ffill(dim='time')
802
741
  elif fill == 'bfill':
803
- resampled_data = resampled_data.bfill()
742
+ resampled = resampled.bfill(dim='time')
743
+
744
+ # Create period and step labels
745
+ time_values = pd.to_datetime(resampled.coords['time'].values)
746
+ period_labels = time_values.strftime(period_format)
747
+ step_labels = time_values.strftime(step_format)
748
+
749
+ # Handle special case for weekly day format
750
+ if '%w_%A' in step_format:
751
+ step_labels = pd.Series(step_labels).replace('0_Sunday', '7_Sunday').values
752
+
753
+ # Add period and step as coordinates
754
+ resampled = resampled.assign_coords(
755
+ {
756
+ 'timeframe': ('time', period_labels),
757
+ 'timestep': ('time', step_labels),
758
+ }
759
+ )
804
760
 
805
- resampled_data['period'] = resampled_data.index.strftime(period_format)
806
- resampled_data['step'] = resampled_data.index.strftime(step_format)
807
- if '%w_%A' in step_format: # Shift index of strings to ensure proper sorting
808
- resampled_data['step'] = resampled_data['step'].apply(
809
- lambda x: x.replace('0_Sunday', '7_Sunday') if '0_Sunday' in x else x
810
- )
761
+ # Convert to multi-index and unstack
762
+ resampled = resampled.set_index(time=['timeframe', 'timestep'])
763
+ result = resampled.unstack('time')
764
+
765
+ # Ensure timestep and timeframe come first in dimension order
766
+ # Get other dimensions
767
+ other_dims = [d for d in result.dims if d not in ['timestep', 'timeframe']]
811
768
 
812
- # Pivot the table so periods are columns and steps are indices
813
- df_pivoted = resampled_data.pivot(columns='period', index='step', values=df.columns[0])
769
+ # Reorder: timestep, timeframe, then other dimensions
770
+ result = result.transpose('timestep', 'timeframe', *other_dims)
814
771
 
815
- return df_pivoted
772
+ return result
816
773
 
817
774
 
818
775
  def plot_network(
@@ -899,518 +856,704 @@ def plot_network(
899
856
  )
900
857
 
901
858
 
902
- def pie_with_plotly(
903
- data: pd.DataFrame,
904
- colors: ColorType = 'viridis',
905
- title: str = '',
906
- legend_title: str = '',
907
- hole: float = 0.0,
908
- fig: go.Figure | None = None,
909
- ) -> go.Figure:
859
+ def preprocess_data_for_pie(
860
+ data: xr.Dataset | pd.DataFrame | pd.Series,
861
+ lower_percentage_threshold: float = 5.0,
862
+ ) -> pd.Series:
910
863
  """
911
- Create a pie chart with Plotly to visualize the proportion of values in a DataFrame.
864
+ Preprocess data for pie chart display.
865
+
866
+ Groups items that are individually below the threshold percentage into an "Other" category.
867
+ Converts various input types to a pandas Series for uniform handling.
912
868
 
913
869
  Args:
914
- data: A DataFrame containing the data to plot. If multiple rows exist,
915
- they will be summed unless a specific index value is passed.
916
- colors: Color specification, can be:
917
- - A string with a colorscale name (e.g., 'viridis', 'plasma')
918
- - A list of color strings (e.g., ['#ff0000', '#00ff00'])
919
- - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'})
920
- title: The title of the plot.
921
- legend_title: The title for the legend.
922
- hole: Size of the hole in the center for creating a donut chart (0.0 to 1.0).
923
- fig: A Plotly figure object to plot on. If not provided, a new figure will be created.
870
+ data: Input data (xarray Dataset, DataFrame, or Series)
871
+ lower_percentage_threshold: Percentage threshold - items below this are grouped into "Other"
924
872
 
925
873
  Returns:
926
- A Plotly figure object containing the generated pie chart.
874
+ Processed pandas Series with small items grouped into "Other"
875
+ """
876
+ # Convert to Series
877
+ if isinstance(data, xr.Dataset):
878
+ # Sum all dimensions for each variable to get total values
879
+ values = {}
880
+ for var in data.data_vars:
881
+ var_data = data[var]
882
+ if len(var_data.dims) > 0:
883
+ total_value = float(var_data.sum().item())
884
+ else:
885
+ total_value = float(var_data.item())
927
886
 
928
- Notes:
929
- - Negative values are not appropriate for pie charts and will be converted to absolute values with a warning.
930
- - If the data contains very small values (less than 1% of the total), they can be grouped into an "Other" category
931
- for better readability.
932
- - By default, the sum of all columns is used for the pie chart. For time series data, consider preprocessing.
887
+ # Handle negative values
888
+ if total_value < 0:
889
+ logger.warning(f'Negative value for {var}: {total_value}. Using absolute value.')
890
+ total_value = abs(total_value)
933
891
 
934
- """
935
- if data.empty:
936
- logger.error('Empty DataFrame provided for pie chart. Returning empty figure.')
937
- return go.Figure()
892
+ values[var] = total_value
938
893
 
939
- # Create a copy to avoid modifying the original DataFrame
940
- data_copy = data.copy()
894
+ series = pd.Series(values)
941
895
 
942
- # Check if any negative values and warn
943
- if (data_copy < 0).any().any():
944
- logger.error('Negative values detected in data. Using absolute values for pie chart.')
945
- data_copy = data_copy.abs()
896
+ elif isinstance(data, pd.DataFrame):
897
+ # Sum each column over all rows to get one total per column
898
+ series = data.sum(axis=0)
899
+ # Handle negative values
900
+ negative_mask = series < 0
901
+ if negative_mask.any():
902
+ logger.warning(f'Negative values found: {series[negative_mask].to_dict()}. Using absolute values.')
903
+ series = series.abs()
946
904
 
947
- # If data has multiple rows, sum them to get total for each column
948
- if len(data_copy) > 1:
949
- data_sum = data_copy.sum()
950
- else:
951
- data_sum = data_copy.iloc[0]
905
+ else: # pd.Series
906
+ series = data.copy()
907
+ # Handle negative values
908
+ negative_mask = series < 0
909
+ if negative_mask.any():
910
+ logger.warning(f'Negative values found: {series[negative_mask].to_dict()}. Using absolute values.')
911
+ series = series.abs()
952
912
 
953
- # Get labels (column names) and values
954
- labels = data_sum.index.tolist()
955
- values = data_sum.values.tolist()
913
+ # Only keep positive values
914
+ series = series[series > 0]
956
915
 
957
- # Apply color mapping using the unified color processor
958
- processed_colors = ColorProcessor(engine='plotly').process_colors(colors, labels)
916
+ if series.empty or lower_percentage_threshold <= 0:
917
+ return series
959
918
 
960
- # Create figure if not provided
961
- fig = fig if fig is not None else go.Figure()
919
+ # Calculate percentages
920
+ total = series.sum()
921
+ percentages = (series / total) * 100
962
922
 
963
- # Add pie trace
964
- fig.add_trace(
965
- go.Pie(
966
- labels=labels,
967
- values=values,
968
- hole=hole,
969
- marker=dict(colors=processed_colors),
970
- textinfo='percent+label+value',
971
- textposition='inside',
972
- insidetextorientation='radial',
923
+ # Find items below and above threshold
924
+ below_threshold = series[percentages < lower_percentage_threshold]
925
+ above_threshold = series[percentages >= lower_percentage_threshold]
926
+
927
+ # Only group if there are at least 2 items below threshold
928
+ if len(below_threshold) > 1:
929
+ # Create new series with items above threshold + "Other"
930
+ result = above_threshold.copy()
931
+ result['Other'] = below_threshold.sum()
932
+ return result
933
+
934
+ return series
935
+
936
+
937
+ def dual_pie_with_plotly(
938
+ data_left: xr.Dataset | pd.DataFrame | pd.Series,
939
+ data_right: xr.Dataset | pd.DataFrame | pd.Series,
940
+ colors: ColorType | None = None,
941
+ title: str = '',
942
+ subtitles: tuple[str, str] = ('Left Chart', 'Right Chart'),
943
+ legend_title: str = '',
944
+ hole: float = 0.2,
945
+ lower_percentage_group: float = 5.0,
946
+ text_info: str = 'percent+label',
947
+ text_position: str = 'inside',
948
+ hover_template: str = '%{label}: %{value} (%{percent})',
949
+ ) -> go.Figure:
950
+ """
951
+ Create two pie charts side by side with Plotly.
952
+
953
+ Args:
954
+ data_left: Data for the left pie chart. Variables are summed across all dimensions.
955
+ data_right: Data for the right pie chart. Variables are summed across all dimensions.
956
+ colors: Color specification (colorscale name, list of colors, or dict mapping)
957
+ title: The main title of the plot.
958
+ subtitles: Tuple containing the subtitles for (left, right) charts.
959
+ legend_title: The title for the legend.
960
+ hole: Size of the hole in the center for creating donut charts (0.0 to 1.0).
961
+ lower_percentage_group: Group segments whose individual share is below this percentage (0–100) into "Other" (only applied when at least two segments fall below it).
962
+ hover_template: Template for hover text. Use %{label}, %{value}, %{percent}.
963
+ text_info: What to show on pie segments: 'label', 'percent', 'value', 'label+percent',
964
+ 'label+value', 'percent+value', 'label+percent+value', or 'none'.
965
+ text_position: Position of text: 'inside', 'outside', 'auto', or 'none'.
966
+
967
+ Returns:
968
+ Plotly Figure object
969
+ """
970
+ if colors is None:
971
+ colors = CONFIG.Plotting.default_qualitative_colorscale
972
+
973
+ # Preprocess data to Series
974
+ left_series = preprocess_data_for_pie(data_left, lower_percentage_group)
975
+ right_series = preprocess_data_for_pie(data_right, lower_percentage_group)
976
+
977
+ # Extract labels and values
978
+ left_labels = left_series.index.tolist()
979
+ left_values = left_series.values.tolist()
980
+
981
+ right_labels = right_series.index.tolist()
982
+ right_values = right_series.values.tolist()
983
+
984
+ # Get all unique labels for consistent coloring
985
+ all_labels = sorted(set(left_labels) | set(right_labels))
986
+
987
+ # Create color map
988
+ color_map = process_colors(colors, all_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale)
989
+
990
+ # Create figure
991
+ fig = go.Figure()
992
+
993
+ # Add left pie
994
+ if left_labels:
995
+ fig.add_trace(
996
+ go.Pie(
997
+ labels=left_labels,
998
+ values=left_values,
999
+ name=subtitles[0],
1000
+ marker=dict(colors=[color_map.get(label, '#636EFA') for label in left_labels]),
1001
+ hole=hole,
1002
+ textinfo=text_info,
1003
+ textposition=text_position,
1004
+ hovertemplate=hover_template,
1005
+ domain=dict(x=[0, 0.48]),
1006
+ )
973
1007
  )
974
- )
975
1008
 
976
- # Update layout for better aesthetics
1009
+ # Add right pie
1010
+ if right_labels:
1011
+ fig.add_trace(
1012
+ go.Pie(
1013
+ labels=right_labels,
1014
+ values=right_values,
1015
+ name=subtitles[1],
1016
+ marker=dict(colors=[color_map.get(label, '#636EFA') for label in right_labels]),
1017
+ hole=hole,
1018
+ textinfo=text_info,
1019
+ textposition=text_position,
1020
+ hovertemplate=hover_template,
1021
+ domain=dict(x=[0.52, 1]),
1022
+ )
1023
+ )
1024
+
1025
+ # Update layout
977
1026
  fig.update_layout(
978
1027
  title=title,
979
1028
  legend_title=legend_title,
980
- plot_bgcolor='rgba(0,0,0,0)', # Transparent background
981
- paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background
982
- font=dict(size=14), # Increase font size for better readability
1029
+ margin=dict(t=80, b=50, l=30, r=30),
983
1030
  )
984
1031
 
985
1032
  return fig
986
1033
 
987
1034
 
988
- def pie_with_matplotlib(
989
- data: pd.DataFrame,
990
- colors: ColorType = 'viridis',
1035
+ def dual_pie_with_matplotlib(
1036
+ data_left: xr.Dataset | pd.DataFrame | pd.Series,
1037
+ data_right: xr.Dataset | pd.DataFrame | pd.Series,
1038
+ colors: ColorType | None = None,
991
1039
  title: str = '',
992
- legend_title: str = 'Categories',
993
- hole: float = 0.0,
994
- figsize: tuple[int, int] = (10, 8),
995
- fig: plt.Figure | None = None,
996
- ax: plt.Axes | None = None,
997
- ) -> tuple[plt.Figure, plt.Axes]:
1040
+ subtitles: tuple[str, str] = ('Left Chart', 'Right Chart'),
1041
+ legend_title: str = '',
1042
+ hole: float = 0.2,
1043
+ lower_percentage_group: float = 5.0,
1044
+ figsize: tuple[int, int] = (14, 7),
1045
+ ) -> tuple[plt.Figure, list[plt.Axes]]:
998
1046
  """
999
- Create a pie chart with Matplotlib to visualize the proportion of values in a DataFrame.
1047
+ Create two pie charts side by side with Matplotlib.
1000
1048
 
1001
1049
  Args:
1002
- data: A DataFrame containing the data to plot. If multiple rows exist,
1003
- they will be summed unless a specific index value is passed.
1004
- colors: Color specification, can be:
1005
- - A string with a colormap name (e.g., 'viridis', 'plasma')
1006
- - A list of color strings (e.g., ['#ff0000', '#00ff00'])
1007
- - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'})
1008
- title: The title of the plot.
1050
+ data_left: Data for the left pie chart.
1051
+ data_right: Data for the right pie chart.
1052
+ colors: Color specification (colorscale name, list of colors, or dict mapping)
1053
+ title: The main title of the plot.
1054
+ subtitles: Tuple containing the subtitles for (left, right) charts.
1009
1055
  legend_title: The title for the legend.
1010
- hole: Size of the hole in the center for creating a donut chart (0.0 to 1.0).
1056
+ hole: Size of the hole in the center for creating donut charts (0.0 to 1.0).
1057
+ lower_percentage_group: Percentage threshold (0–100); segments whose individual share is below it are grouped into an "Other" category.
1011
1058
  figsize: The size of the figure (width, height) in inches.
1012
- fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created.
1013
- ax: A Matplotlib axes object to plot on. If not provided, a new axes will be created.
1014
1059
 
1015
1060
  Returns:
1016
- A tuple containing the Matplotlib figure and axes objects used for the plot.
1061
+ Tuple of (Figure, list of Axes)
1062
+ """
1063
+ if colors is None:
1064
+ colors = CONFIG.Plotting.default_qualitative_colorscale
1017
1065
 
1018
- Notes:
1019
- - Negative values are not appropriate for pie charts and will be converted to absolute values with a warning.
1020
- - If the data contains very small values (less than 1% of the total), they can be grouped into an "Other" category
1021
- for better readability.
1022
- - By default, the sum of all columns is used for the pie chart. For time series data, consider preprocessing.
1066
+ # Preprocess data to Series
1067
+ left_series = preprocess_data_for_pie(data_left, lower_percentage_group)
1068
+ right_series = preprocess_data_for_pie(data_right, lower_percentage_group)
1023
1069
 
1024
- """
1025
- if data.empty:
1026
- logger.error('Empty DataFrame provided for pie chart. Returning empty figure.')
1027
- if fig is None or ax is None:
1028
- fig, ax = plt.subplots(figsize=figsize)
1029
- return fig, ax
1070
+ # Extract labels and values
1071
+ left_labels = left_series.index.tolist()
1072
+ left_values = left_series.values.tolist()
1030
1073
 
1031
- # Create a copy to avoid modifying the original DataFrame
1032
- data_copy = data.copy()
1074
+ right_labels = right_series.index.tolist()
1075
+ right_values = right_series.values.tolist()
1033
1076
 
1034
- # Check if any negative values and warn
1035
- if (data_copy < 0).any().any():
1036
- logger.error('Negative values detected in data. Using absolute values for pie chart.')
1037
- data_copy = data_copy.abs()
1077
+ # Get all unique labels for consistent coloring
1078
+ all_labels = sorted(set(left_labels) | set(right_labels))
1038
1079
 
1039
- # If data has multiple rows, sum them to get total for each column
1040
- if len(data_copy) > 1:
1041
- data_sum = data_copy.sum()
1042
- else:
1043
- data_sum = data_copy.iloc[0]
1080
+ # Create color map (process_colors always returns a dict)
1081
+ color_map = process_colors(colors, all_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale)
1044
1082
 
1045
- # Get labels (column names) and values
1046
- labels = data_sum.index.tolist()
1047
- values = data_sum.values.tolist()
1083
+ # Create figure
1084
+ fig, axes = plt.subplots(1, 2, figsize=figsize)
1048
1085
 
1049
- # Apply color mapping using the unified color processor
1050
- processed_colors = ColorProcessor(engine='matplotlib').process_colors(colors, labels)
1086
+ def draw_pie(ax, labels, values, subtitle):
1087
+ """Draw a single pie chart."""
1088
+ if not labels:
1089
+ ax.set_title(subtitle)
1090
+ ax.axis('off')
1091
+ return
1051
1092
 
1052
- # Create figure and axis if not provided
1053
- if fig is None or ax is None:
1054
- fig, ax = plt.subplots(figsize=figsize)
1093
+ chart_colors = [color_map[label] for label in labels]
1055
1094
 
1056
- # Draw the pie chart
1057
- wedges, texts, autotexts = ax.pie(
1058
- values,
1059
- labels=labels,
1060
- colors=processed_colors,
1061
- autopct='%1.1f%%',
1062
- startangle=90,
1063
- shadow=False,
1064
- wedgeprops=dict(width=0.5) if hole > 0 else None, # Set width for donut
1065
- )
1095
+ # Draw pie
1096
+ wedges, texts, autotexts = ax.pie(
1097
+ values,
1098
+ labels=labels,
1099
+ colors=chart_colors,
1100
+ autopct='%1.1f%%',
1101
+ startangle=90,
1102
+ wedgeprops=dict(width=1 - hole) if hole > 0 else None,
1103
+ )
1104
+
1105
+ # Style text
1106
+ for autotext in autotexts:
1107
+ autotext.set_fontsize(10)
1108
+ autotext.set_color('white')
1109
+ autotext.set_weight('bold')
1066
1110
 
1067
- # Adjust the wedgeprops to make donut hole size consistent with plotly
1068
- # For matplotlib, the hole size is determined by the wedge width
1069
- # Convert hole parameter to wedge width
1070
- if hole > 0:
1071
- # Adjust hole size to match plotly's hole parameter
1072
- # In matplotlib, wedge width is relative to the radius (which is 1)
1073
- # For plotly, hole is a fraction of the radius
1074
- wedge_width = 1 - hole
1075
- for wedge in wedges:
1076
- wedge.set_width(wedge_width)
1077
-
1078
- # Customize the appearance
1079
- # Make autopct text more visible
1080
- for autotext in autotexts:
1081
- autotext.set_fontsize(10)
1082
- autotext.set_color('white')
1083
-
1084
- # Set aspect ratio to be equal to ensure a circular pie
1085
- ax.set_aspect('equal')
1086
-
1087
- # Add title
1111
+ ax.set_aspect('equal')
1112
+ ax.set_title(subtitle, fontsize=14, pad=20)
1113
+
1114
+ # Draw both pies
1115
+ draw_pie(axes[0], left_labels, left_values, subtitles[0])
1116
+ draw_pie(axes[1], right_labels, right_values, subtitles[1])
1117
+
1118
+ # Add main title
1088
1119
  if title:
1089
- ax.set_title(title, fontsize=16)
1120
+ fig.suptitle(title, fontsize=16, y=0.98)
1090
1121
 
1091
- # Create a legend if there are many segments
1092
- if len(labels) > 6:
1093
- ax.legend(wedges, labels, title=legend_title, loc='center left', bbox_to_anchor=(1, 0, 0.5, 1))
1122
+ # Create unified legend
1123
+ if left_labels or right_labels:
1124
+ handles = [
1125
+ plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color_map[label], markersize=10)
1126
+ for label in all_labels
1127
+ ]
1128
+
1129
+ fig.legend(
1130
+ handles=handles,
1131
+ labels=all_labels,
1132
+ title=legend_title,
1133
+ loc='lower center',
1134
+ bbox_to_anchor=(0.5, -0.02),
1135
+ ncol=min(len(all_labels), 5),
1136
+ )
1137
+
1138
+ fig.subplots_adjust(bottom=0.15)
1094
1139
 
1095
- # Apply tight layout
1096
1140
  fig.tight_layout()
1097
1141
 
1098
- return fig, ax
1142
+ return fig, axes
1099
1143
 
1100
1144
 
1101
- def dual_pie_with_plotly(
1102
- data_left: pd.Series,
1103
- data_right: pd.Series,
1104
- colors: ColorType = 'viridis',
1145
+ def heatmap_with_plotly_v2(
1146
+ data: xr.DataArray,
1147
+ colors: ColorType | None = None,
1105
1148
  title: str = '',
1106
- subtitles: tuple[str, str] = ('Left Chart', 'Right Chart'),
1107
- legend_title: str = '',
1108
- hole: float = 0.2,
1109
- lower_percentage_group: float = 5.0,
1110
- hover_template: str = '%{label}: %{value} (%{percent})',
1111
- text_info: str = 'percent+label',
1112
- text_position: str = 'inside',
1149
+ facet_col: str | None = None,
1150
+ animation_frame: str | None = None,
1151
+ facet_col_wrap: int | None = None,
1152
+ **imshow_kwargs: Any,
1113
1153
  ) -> go.Figure:
1114
1154
  """
1115
- Create two pie charts side by side with Plotly, with consistent coloring across both charts.
1155
+ Plot a heatmap using Plotly's imshow.
1156
+
1157
+ Data should be prepared with dims in order: (y_axis, x_axis, [facet_col], [animation_frame]).
1158
+ Use reshape_data_for_heatmap() to prepare time-series data before calling this.
1116
1159
 
1117
1160
  Args:
1118
- data_left: Series for the left pie chart.
1119
- data_right: Series for the right pie chart.
1120
- colors: Color specification, can be:
1121
- - A string with a colorscale name (e.g., 'viridis', 'plasma')
1122
- - A list of color strings (e.g., ['#ff0000', '#00ff00'])
1123
- - A dictionary mapping category names to colors (e.g., {'Category1': '#ff0000'})
1124
- title: The main title of the plot.
1125
- subtitles: Tuple containing the subtitles for (left, right) charts.
1126
- legend_title: The title for the legend.
1127
- hole: Size of the hole in the center for creating donut charts (0.0 to 1.0).
1128
- lower_percentage_group: Group segments whose cumulative share is below this percentage (0–100) into "Other".
1129
- hover_template: Template for hover text. Use %{label}, %{value}, %{percent}.
1130
- text_info: What to show on pie segments: 'label', 'percent', 'value', 'label+percent',
1131
- 'label+value', 'percent+value', 'label+percent+value', or 'none'.
1132
- text_position: Position of text: 'inside', 'outside', 'auto', or 'none'.
1161
+ data: DataArray with 2-4 dimensions. First two are heatmap axes.
1162
+ colors: Colorscale name ('viridis', 'plasma', etc.).
1163
+ title: Plot title.
1164
+ facet_col: Dimension name for subplot columns (3rd dim).
1165
+ animation_frame: Dimension name for animation (4th dim).
1166
+ facet_col_wrap: Max columns before wrapping (only if < n_facets).
1167
+ **imshow_kwargs: Additional args for px.imshow.
1133
1168
 
1134
1169
  Returns:
1135
- A Plotly figure object containing the generated dual pie chart.
1170
+ Plotly Figure object.
1136
1171
  """
1137
- from plotly.subplots import make_subplots
1138
-
1139
- # Check for empty data
1140
- if data_left.empty and data_right.empty:
1141
- logger.error('Both datasets are empty. Returning empty figure.')
1172
+ if data.size == 0:
1142
1173
  return go.Figure()
1143
1174
 
1144
- # Create a subplot figure
1145
- fig = make_subplots(
1146
- rows=1, cols=2, specs=[[{'type': 'pie'}, {'type': 'pie'}]], subplot_titles=subtitles, horizontal_spacing=0.05
1147
- )
1175
+ colors = colors or CONFIG.Plotting.default_sequential_colorscale
1176
+ facet_col_wrap = facet_col_wrap or CONFIG.Plotting.default_facet_cols
1148
1177
 
1149
- # Process series to handle negative values and apply minimum percentage threshold
1150
- def preprocess_series(series: pd.Series):
1151
- """
1152
- Preprocess a series for pie chart display by handling negative values
1153
- and grouping the smallest parts together if they collectively represent
1154
- less than the specified percentage threshold.
1178
+ imshow_args: dict[str, Any] = {
1179
+ 'img': data,
1180
+ 'color_continuous_scale': colors,
1181
+ 'title': title,
1182
+ **imshow_kwargs,
1183
+ }
1155
1184
 
1156
- Args:
1157
- series: The series to preprocess
1185
+ if facet_col and facet_col in data.dims:
1186
+ imshow_args['facet_col'] = facet_col
1187
+ if facet_col_wrap < data.sizes[facet_col]:
1188
+ imshow_args['facet_col_wrap'] = facet_col_wrap
1158
1189
 
1159
- Returns:
1160
- A preprocessed pandas Series
1161
- """
1162
- # Handle negative values
1163
- if (series < 0).any():
1164
- logger.error('Negative values detected in data. Using absolute values for pie chart.')
1165
- series = series.abs()
1190
+ if animation_frame and animation_frame in data.dims:
1191
+ imshow_args['animation_frame'] = animation_frame
1166
1192
 
1167
- # Remove zeros
1168
- series = series[series > 0]
1193
+ return px.imshow(**imshow_args)
1169
1194
 
1170
- # Apply minimum percentage threshold if needed
1171
- if lower_percentage_group and not series.empty:
1172
- total = series.sum()
1173
- if total > 0:
1174
- # Sort series by value (ascending)
1175
- sorted_series = series.sort_values()
1176
1195
 
1177
- # Calculate cumulative percentage contribution
1178
- cumulative_percent = (sorted_series.cumsum() / total) * 100
1196
+ def heatmap_with_plotly(
1197
+ data: xr.DataArray,
1198
+ colors: ColorType | None = None,
1199
+ title: str = '',
1200
+ facet_by: str | list[str] | None = None,
1201
+ animate_by: str | None = None,
1202
+ facet_cols: int | None = None,
1203
+ reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
1204
+ | Literal['auto']
1205
+ | None = 'auto',
1206
+ fill: Literal['ffill', 'bfill'] | None = 'ffill',
1207
+ **imshow_kwargs: Any,
1208
+ ) -> go.Figure:
1209
+ """
1210
+ Plot a heatmap visualization using Plotly's imshow with faceting and animation support.
1179
1211
 
1180
- # Find entries that collectively make up less than lower_percentage_group
1181
- to_group = cumulative_percent <= lower_percentage_group
1212
+ This function creates heatmap visualizations from xarray DataArrays, supporting
1213
+ multi-dimensional data through faceting (subplots) and animation. It automatically
1214
+ handles dimension reduction and data reshaping for optimal heatmap display.
1182
1215
 
1183
- if to_group.sum() > 1:
1184
- # Create "Other" category for the smallest values that together are < threshold
1185
- other_sum = sorted_series[to_group].sum()
1216
+ Automatic Time Reshaping:
1217
+ If only the 'time' dimension remains after faceting/animation (making the data 1D),
1218
+ the function automatically reshapes time into a 2D format using default values
1219
+ (timeframes='D', timesteps_per_frame='h'). This creates a daily pattern heatmap
1220
+ showing hours vs days.
1186
1221
 
1187
- # Keep only values that aren't in the "Other" group
1188
- result_series = series[~series.index.isin(sorted_series[to_group].index)]
1222
+ Args:
1223
+ data: An xarray DataArray containing the data to visualize. Should have at least
1224
+ 2 dimensions, or a 'time' dimension that can be reshaped into 2D.
1225
+ colors: Color specification (colorscale name, list, or dict). Common options:
1226
+ 'turbo', 'plasma', 'RdBu', 'portland'.
1227
+ title: The main title of the heatmap.
1228
+ facet_by: Dimension to create facets for. Creates a subplot grid.
1229
+ Can be a single dimension name or list (only first dimension used).
1230
+ Note: px.imshow only supports single-dimension faceting.
1231
+ If the dimension doesn't exist in the data, it will be silently ignored.
1232
+ animate_by: Dimension to animate over. Creates animation frames.
1233
+ If the dimension doesn't exist in the data, it will be silently ignored.
1234
+ facet_cols: Number of columns in the facet grid (used with facet_by).
1235
+ reshape_time: Time reshaping configuration:
1236
+ - 'auto' (default): Automatically applies ('D', 'h') if only 'time' dimension remains
1237
+ - Tuple like ('D', 'h'): Explicit time reshaping (days vs hours)
1238
+ - None: Disable time reshaping (will error if only 1D time data)
1239
+ fill: Method to fill missing values when reshaping time: 'ffill' or 'bfill'. Default is 'ffill'.
1240
+ **imshow_kwargs: Additional keyword arguments to pass to plotly.express.imshow.
1241
+ Common options include:
1242
+ - aspect: 'auto', 'equal', or a number for aspect ratio
1243
+ - zmin, zmax: Minimum and maximum values for color scale
1244
+ - labels: Dict to customize axis labels
1189
1245
 
1190
- # Add the "Other" category if it has a value
1191
- if other_sum > 0:
1192
- result_series['Other'] = other_sum
1246
+ Returns:
1247
+ A Plotly figure object containing the heatmap visualization.
1193
1248
 
1194
- return result_series
1249
+ Examples:
1250
+ Simple heatmap:
1195
1251
 
1196
- return series
1252
+ ```python
1253
+ fig = heatmap_with_plotly(data_array, colors='RdBu', title='Temperature Map')
1254
+ ```
1197
1255
 
1198
- data_left_processed = preprocess_series(data_left)
1199
- data_right_processed = preprocess_series(data_right)
1256
+ Facet by scenario:
1200
1257
 
1201
- # Get unique set of all labels for consistent coloring
1202
- all_labels = sorted(set(data_left_processed.index) | set(data_right_processed.index))
1258
+ ```python
1259
+ fig = heatmap_with_plotly(data_array, facet_by='scenario', facet_cols=2)
1260
+ ```
1203
1261
 
1204
- # Get consistent color mapping for both charts using our unified function
1205
- color_map = ColorProcessor(engine='plotly').process_colors(colors, all_labels, return_mapping=True)
1262
+ Animate by period:
1206
1263
 
1207
- # Function to create a pie trace with consistently mapped colors
1208
- def create_pie_trace(data_series, side):
1209
- if data_series.empty:
1210
- return None
1264
+ ```python
1265
+ fig = heatmap_with_plotly(data_array, animate_by='period')
1266
+ ```
1211
1267
 
1212
- labels = data_series.index.tolist()
1213
- values = data_series.values.tolist()
1214
- trace_colors = [color_map[label] for label in labels]
1268
+ Automatic time reshaping (when only time dimension remains):
1215
1269
 
1216
- return go.Pie(
1217
- labels=labels,
1218
- values=values,
1219
- name=side,
1220
- marker=dict(colors=trace_colors),
1221
- hole=hole,
1222
- textinfo=text_info,
1223
- textposition=text_position,
1224
- insidetextorientation='radial',
1225
- hovertemplate=hover_template,
1226
- sort=True, # Sort values by default (largest first)
1227
- )
1270
+ ```python
1271
+ # Data with dims ['time', 'period', 'scenario']
1272
+ # After faceting and animation, only 'time' remains -> auto-reshapes to (timestep, timeframe)
1273
+ fig = heatmap_with_plotly(data_array, facet_by='scenario', animate_by='period')
1274
+ ```
1228
1275
 
1229
- # Add left pie if data exists
1230
- left_trace = create_pie_trace(data_left_processed, subtitles[0])
1231
- if left_trace:
1232
- left_trace.domain = dict(x=[0, 0.48])
1233
- fig.add_trace(left_trace, row=1, col=1)
1276
+ Explicit time reshaping:
1234
1277
 
1235
- # Add right pie if data exists
1236
- right_trace = create_pie_trace(data_right_processed, subtitles[1])
1237
- if right_trace:
1238
- right_trace.domain = dict(x=[0.52, 1])
1239
- fig.add_trace(right_trace, row=1, col=2)
1278
+ ```python
1279
+ fig = heatmap_with_plotly(data_array, facet_by='scenario', animate_by='period', reshape_time=('W', 'D'))
1280
+ ```
1281
+ """
1282
+ if colors is None:
1283
+ colors = CONFIG.Plotting.default_sequential_colorscale
1240
1284
 
1241
- # Update layout
1242
- fig.update_layout(
1243
- title=title,
1244
- legend_title=legend_title,
1245
- plot_bgcolor='rgba(0,0,0,0)', # Transparent background
1246
- paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background
1247
- font=dict(size=14),
1248
- margin=dict(t=80, b=50, l=30, r=30),
1285
+ # Apply CONFIG defaults if not explicitly set
1286
+ if facet_cols is None:
1287
+ facet_cols = CONFIG.Plotting.default_facet_cols
1288
+
1289
+ # Handle empty data
1290
+ if data.size == 0:
1291
+ return go.Figure()
1292
+
1293
+ # Apply time reshaping using the new unified function
1294
+ data = reshape_data_for_heatmap(
1295
+ data, reshape_time=reshape_time, facet_by=facet_by, animate_by=animate_by, fill=fill
1249
1296
  )
1250
1297
 
1251
- return fig
1298
+ # Get available dimensions
1299
+ available_dims = list(data.dims)
1252
1300
 
1301
+ # Validate and filter facet_by dimensions
1302
+ if facet_by is not None:
1303
+ if isinstance(facet_by, str):
1304
+ if facet_by not in available_dims:
1305
+ logger.debug(
1306
+ f"Dimension '{facet_by}' not found in data. Available dimensions: {available_dims}. "
1307
+ f'Ignoring facet_by parameter.'
1308
+ )
1309
+ facet_by = None
1310
+ elif isinstance(facet_by, list):
1311
+ missing_dims = [dim for dim in facet_by if dim not in available_dims]
1312
+ facet_by = [dim for dim in facet_by if dim in available_dims]
1313
+ if missing_dims:
1314
+ logger.debug(
1315
+ f'Dimensions {missing_dims} not found in data. Available dimensions: {available_dims}. '
1316
+ f'Using only existing dimensions: {facet_by if facet_by else "none"}.'
1317
+ )
1318
+ if len(facet_by) == 0:
1319
+ facet_by = None
1320
+
1321
+ # Validate animate_by dimension
1322
+ if animate_by is not None and animate_by not in available_dims:
1323
+ logger.debug(
1324
+ f"Dimension '{animate_by}' not found in data. Available dimensions: {available_dims}. "
1325
+ f'Ignoring animate_by parameter.'
1326
+ )
1327
+ animate_by = None
1253
1328
 
1254
- def dual_pie_with_matplotlib(
1255
- data_left: pd.Series,
1256
- data_right: pd.Series,
1257
- colors: ColorType = 'viridis',
1258
- title: str = '',
1259
- subtitles: tuple[str, str] = ('Left Chart', 'Right Chart'),
1260
- legend_title: str = '',
1261
- hole: float = 0.2,
1262
- lower_percentage_group: float = 5.0,
1263
- figsize: tuple[int, int] = (14, 7),
1264
- fig: plt.Figure | None = None,
1265
- axes: list[plt.Axes] | None = None,
1266
- ) -> tuple[plt.Figure, list[plt.Axes]]:
1267
- """
1268
- Create two pie charts side by side with Matplotlib, with consistent coloring across both charts.
1269
- Leverages the existing pie_with_matplotlib function.
1329
+ # Determine which dimensions are used for faceting/animation
1330
+ facet_dims = []
1331
+ if facet_by:
1332
+ facet_dims = [facet_by] if isinstance(facet_by, str) else facet_by
1333
+ if animate_by:
1334
+ facet_dims.append(animate_by)
1270
1335
 
1271
- Args:
1272
- data_left: Series for the left pie chart.
1273
- data_right: Series for the right pie chart.
1274
- colors: Color specification, can be:
1275
- - A string with a colormap name (e.g., 'viridis', 'plasma')
1276
- - A list of color strings (e.g., ['#ff0000', '#00ff00'])
1277
- - A dictionary mapping category names to colors (e.g., {'Category1': '#ff0000'})
1278
- title: The main title of the plot.
1279
- subtitles: Tuple containing the subtitles for (left, right) charts.
1280
- legend_title: The title for the legend.
1281
- hole: Size of the hole in the center for creating donut charts (0.0 to 1.0).
1282
- lower_percentage_group: Whether to group small segments (below percentage) into an "Other" category.
1283
- figsize: The size of the figure (width, height) in inches.
1284
- fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created.
1285
- axes: A list of Matplotlib axes objects to plot on. If not provided, new axes will be created.
1336
+ # Get remaining dimensions for the heatmap itself
1337
+ heatmap_dims = [dim for dim in available_dims if dim not in facet_dims]
1286
1338
 
1287
- Returns:
1288
- A tuple containing the Matplotlib figure and list of axes objects used for the plot.
1289
- """
1290
- # Check for empty data
1291
- if data_left.empty and data_right.empty:
1292
- logger.error('Both datasets are empty. Returning empty figure.')
1293
- if fig is None:
1294
- fig, axes = plt.subplots(1, 2, figsize=figsize)
1295
- return fig, axes
1296
-
1297
- # Create figure and axes if not provided
1298
- if fig is None or axes is None:
1299
- fig, axes = plt.subplots(1, 2, figsize=figsize)
1300
-
1301
- # Process series to handle negative values and apply minimum percentage threshold
1302
- def preprocess_series(series: pd.Series):
1303
- """
1304
- Preprocess a series for pie chart display by handling negative values
1305
- and grouping the smallest parts together if they collectively represent
1306
- less than the specified percentage threshold.
1307
- """
1308
- # Handle negative values
1309
- if (series < 0).any():
1310
- logger.error('Negative values detected in data. Using absolute values for pie chart.')
1311
- series = series.abs()
1339
+ if len(heatmap_dims) < 2:
1340
+ # Handle single-dimension case by adding variable name as a dimension
1341
+ if len(heatmap_dims) == 1:
1342
+ # Get the variable name, or use a default
1343
+ var_name = data.name if data.name else 'value'
1312
1344
 
1313
- # Remove zeros
1314
- series = series[series > 0]
1345
+ # Expand the DataArray by adding a new dimension with the variable name
1346
+ data = data.expand_dims({'variable': [var_name]})
1315
1347
 
1316
- # Apply minimum percentage threshold if needed
1317
- if lower_percentage_group and not series.empty:
1318
- total = series.sum()
1319
- if total > 0:
1320
- # Sort series by value (ascending)
1321
- sorted_series = series.sort_values()
1348
+ # Update available dimensions
1349
+ available_dims = list(data.dims)
1350
+ heatmap_dims = [dim for dim in available_dims if dim not in facet_dims]
1322
1351
 
1323
- # Calculate cumulative percentage contribution
1324
- cumulative_percent = (sorted_series.cumsum() / total) * 100
1352
+ logger.debug(f'Only 1 dimension remaining for heatmap. Added variable dimension: {var_name}')
1353
+ else:
1354
+ # No dimensions at all - cannot create a heatmap
1355
+ logger.error(
1356
+ f'Heatmap requires at least 1 dimension. '
1357
+ f'After faceting/animation, {len(heatmap_dims)} dimension(s) remain: {heatmap_dims}'
1358
+ )
1359
+ return go.Figure()
1360
+
1361
+ # Setup faceting parameters for Plotly Express
1362
+ # Note: px.imshow only supports facet_col, not facet_row
1363
+ facet_col_param = None
1364
+ if facet_by:
1365
+ if isinstance(facet_by, str):
1366
+ facet_col_param = facet_by
1367
+ elif len(facet_by) == 1:
1368
+ facet_col_param = facet_by[0]
1369
+ elif len(facet_by) >= 2:
1370
+ # px.imshow doesn't support facet_row, so we can only facet by one dimension
1371
+ # Use the first dimension and warn about the rest
1372
+ facet_col_param = facet_by[0]
1373
+ logger.warning(
1374
+ f'px.imshow only supports faceting by a single dimension. '
1375
+ f'Using {facet_by[0]} for faceting. Dimensions {facet_by[1:]} will be ignored. '
1376
+ f'Consider using animate_by for additional dimensions.'
1377
+ )
1325
1378
 
1326
- # Find entries that collectively make up less than lower_percentage_group
1327
- to_group = cumulative_percent <= lower_percentage_group
1379
+ # Create the imshow plot - px.imshow can work directly with xarray DataArrays
1380
+ common_args = {
1381
+ 'img': data,
1382
+ 'color_continuous_scale': colors,
1383
+ 'title': title,
1384
+ }
1328
1385
 
1329
- if to_group.sum() > 1:
1330
- # Create "Other" category for the smallest values that together are < threshold
1331
- other_sum = sorted_series[to_group].sum()
1386
+ # Add faceting if specified
1387
+ if facet_col_param:
1388
+ common_args['facet_col'] = facet_col_param
1389
+ if facet_cols:
1390
+ common_args['facet_col_wrap'] = facet_cols
1332
1391
 
1333
- # Keep only values that aren't in the "Other" group
1334
- result_series = series[~series.index.isin(sorted_series[to_group].index)]
1392
+ # Add animation if specified
1393
+ if animate_by:
1394
+ common_args['animation_frame'] = animate_by
1335
1395
 
1336
- # Add the "Other" category if it has a value
1337
- if other_sum > 0:
1338
- result_series['Other'] = other_sum
1396
+ # Merge in additional imshow kwargs
1397
+ common_args.update(imshow_kwargs)
1339
1398
 
1340
- return result_series
1399
+ try:
1400
+ fig = px.imshow(**common_args)
1401
+ except Exception as e:
1402
+ logger.error(f'Error creating imshow plot: {e}. Falling back to basic heatmap.')
1403
+ # Fallback: create a simple heatmap without faceting
1404
+ fallback_args = {
1405
+ 'img': data.values,
1406
+ 'color_continuous_scale': colors,
1407
+ 'title': title,
1408
+ }
1409
+ fallback_args.update(imshow_kwargs)
1410
+ fig = px.imshow(**fallback_args)
1341
1411
 
1342
- return series
1412
+ return fig
1343
1413
 
1344
- # Preprocess data
1345
- data_left_processed = preprocess_series(data_left)
1346
- data_right_processed = preprocess_series(data_right)
1347
1414
 
1348
- # Convert Series to DataFrames for pie_with_matplotlib
1349
- df_left = pd.DataFrame(data_left_processed).T if not data_left_processed.empty else pd.DataFrame()
1350
- df_right = pd.DataFrame(data_right_processed).T if not data_right_processed.empty else pd.DataFrame()
1415
def heatmap_with_matplotlib(
    data: xr.DataArray,
    colors: ColorType | None = None,
    title: str = '',
    figsize: tuple[float, float] = (12, 6),
    reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
    | Literal['auto']
    | None = 'auto',
    fill: Literal['ffill', 'bfill'] | None = 'ffill',
    vmin: float | None = None,
    vmax: float | None = None,
    imshow_kwargs: dict[str, Any] | None = None,
    cbar_kwargs: dict[str, Any] | None = None,
    **kwargs: Any,
) -> tuple[plt.Figure, plt.Axes]:
    """
    Plot a heatmap visualization using Matplotlib's imshow.

    This function creates a basic 2D heatmap from an xarray DataArray using matplotlib's
    imshow function. For multi-dimensional data, only the first two dimensions are used.

    Args:
        data: An xarray DataArray containing the data to visualize. If only one dimension
            remains after time reshaping, a length-1 'variable' dimension (named after
            ``data.name``, or 'value' if unnamed) is added. If more than 2 dimensions
            exist, additional dimensions are reduced by taking the first slice.
        colors: Color specification. Should be a colorscale name (e.g., 'turbo', 'RdBu').
            If None, ``CONFIG.Plotting.default_sequential_colorscale`` is used.
        title: The title of the heatmap.
        figsize: The size of the figure (width, height) in inches.
        reshape_time: Time reshaping configuration:
            - 'auto' (default): Automatically applies ('D', 'h') if only 'time' dimension
            - Tuple like ('D', 'h'): Explicit time reshaping (days vs hours)
            - None: Disable time reshaping
        fill: Method to fill missing values when reshaping time: 'ffill' or 'bfill'. Default is 'ffill'.
        vmin: Minimum value for color scale. If None, uses data minimum.
        vmax: Maximum value for color scale. If None, uses data maximum.
        imshow_kwargs: Optional dict of parameters to pass to ax.imshow().
            Use this to customize image properties (e.g., interpolation, aspect).
            The dict is copied; the caller's object is never modified.
        cbar_kwargs: Optional dict of parameters to pass to plt.colorbar().
            Use this to customize colorbar properties (e.g., orientation, label).
        **kwargs: Additional keyword arguments passed to ax.imshow(). These take
            precedence over entries of ``imshow_kwargs`` with the same key.
            Common options include:
            - interpolation: 'nearest', 'bilinear', 'bicubic', etc.
            - alpha: Transparency level (0-1)
            - extent: [left, right, bottom, top] for axis limits

    Returns:
        A tuple containing the Matplotlib figure and axes objects used for the plot.
        For empty data (``data.size == 0``) an empty figure/axes pair is returned.

    Notes:
        - Matplotlib backend doesn't support faceting or animation. Use plotly engine for those features.
        - The image is drawn with ``origin='upper'``, i.e. with the origin at top-left.
        - A colorbar is added to show the value scale.

    Examples:
        ```python
        fig, ax = heatmap_with_matplotlib(data_array, colors='RdBu', title='Temperature')
        plt.savefig('heatmap.png')
        ```

        Time reshaping:

        ```python
        fig, ax = heatmap_with_matplotlib(data_array, reshape_time=('D', 'h'))
        ```
    """
    if colors is None:
        colors = CONFIG.Plotting.default_sequential_colorscale

    # Build fresh dicts instead of updating in place, so a dict passed in by the caller
    # is never mutated. Direct **kwargs override entries of imshow_kwargs.
    imshow_kwargs = {**(imshow_kwargs or {}), **kwargs}
    cbar_kwargs = dict(cbar_kwargs or {})

    # Handle empty data
    if data.size == 0:
        fig, ax = plt.subplots(figsize=figsize)
        return fig, ax

    # Apply time reshaping using the unified helper.
    # Matplotlib doesn't support faceting/animation, so we pass None for those.
    data = reshape_data_for_heatmap(data, reshape_time=reshape_time, facet_by=None, animate_by=None, fill=fill)

    # Handle single-dimension case by adding variable name as a dimension
    if isinstance(data, xr.DataArray) and len(data.dims) == 1:
        var_name = data.name if data.name else 'value'
        data = data.expand_dims({'variable': [var_name]})
        logger.debug(f'Only 1 dimension in data. Added variable dimension: {var_name}')

    # Create figure and axes
    fig, ax = plt.subplots(figsize=figsize)

    # Extract the 2D values and the axis labels
    if isinstance(data, xr.DataArray):
        dims = list(data.dims)
        if len(dims) > 2:
            logger.warning(
                f'Data has {len(dims)} dimensions: {dims}. '
                f'Only the first 2 will be used for the heatmap. '
                f'Use the plotly engine for faceting/animation support.'
            )
            # Keep only the first 2 dimensions by taking the first slice of all others
            data = data.isel({dim: 0 for dim in dims[2:]})

        values = data.values
        x_labels = data.dims[1] if len(data.dims) > 1 else 'x'
        y_labels = data.dims[0] if len(data.dims) > 0 else 'y'
    else:
        values = data
        x_labels = 'x'
        y_labels = 'y'

    # Create the heatmap using imshow; user-supplied options override the defaults
    imshow_defaults = {'cmap': colors, 'aspect': 'auto', 'origin': 'upper', 'vmin': vmin, 'vmax': vmax}
    imshow_defaults.update(imshow_kwargs)
    im = ax.imshow(values, **imshow_defaults)

    # Add colorbar; user-supplied options override the defaults
    cbar_defaults = {'ax': ax, 'orientation': 'horizontal', 'pad': 0.1, 'aspect': 15, 'fraction': 0.05}
    cbar_defaults.update(cbar_kwargs)
    cbar = plt.colorbar(im, **cbar_defaults)

    # Set colorbar label if not overridden by user
    if 'label' not in cbar_kwargs:
        cbar.set_label('Value')

    # Set labels and title
    ax.set_xlabel(str(x_labels).capitalize())
    ax.set_ylabel(str(y_labels).capitalize())
    ax.set_title(title)

    # Apply tight layout
    fig.tight_layout()

    return fig, ax
1414
1557
 
1415
1558
 
1416
1559
  def export_figure(
@@ -1418,8 +1561,9 @@ def export_figure(
1418
1561
  default_path: pathlib.Path,
1419
1562
  default_filetype: str | None = None,
1420
1563
  user_path: pathlib.Path | None = None,
1421
- show: bool = True,
1564
+ show: bool | None = None,
1422
1565
  save: bool = False,
1566
+ dpi: int | None = None,
1423
1567
  ) -> go.Figure | tuple[plt.Figure, plt.Axes]:
1424
1568
  """
1425
1569
  Export a figure to a file and or show it.
@@ -1429,13 +1573,21 @@ def export_figure(
1429
1573
  default_path: The default file path if no user filename is provided.
1430
1574
  default_filetype: The default filetype if the path doesnt end with a filetype.
1431
1575
  user_path: An optional user-specified file path.
1432
- show: Whether to display the figure (default: True).
1576
+ show: Whether to display the figure. If None, uses CONFIG.Plotting.default_show (default: None).
1433
1577
  save: Whether to save the figure (default: False).
1578
+ dpi: DPI (dots per inch) for saving Matplotlib figures. If None, uses CONFIG.Plotting.default_dpi.
1434
1579
 
1435
1580
  Raises:
1436
1581
  ValueError: If no default filetype is provided and the path doesn't specify a filetype.
1437
1582
  TypeError: If the figure type is not supported.
1438
1583
  """
1584
+ # Apply CONFIG defaults if not explicitly set
1585
+ if show is None:
1586
+ show = CONFIG.Plotting.default_show
1587
+
1588
+ if dpi is None:
1589
+ dpi = CONFIG.Plotting.default_dpi
1590
+
1439
1591
  filename = user_path or default_path
1440
1592
  filename = filename.with_name(filename.name.replace('|', '__'))
1441
1593
  if filename.suffix == '':
@@ -1450,25 +1602,17 @@ def export_figure(
1450
1602
  filename = filename.with_suffix('.html')
1451
1603
 
1452
1604
  try:
1453
- is_test_env = 'PYTEST_CURRENT_TEST' in os.environ
1454
-
1455
- if is_test_env:
1456
- # Test environment: never open browser, only save if requested
1457
- if save:
1458
- fig.write_html(str(filename))
1459
- # Ignore show flag in tests
1460
- else:
1461
- # Production environment: respect show and save flags
1462
- if save and show:
1463
- # Save and auto-open in browser
1464
- plotly.offline.plot(fig, filename=str(filename))
1465
- elif save and not show:
1466
- # Save without opening
1467
- fig.write_html(str(filename))
1468
- elif show and not save:
1469
- # Show interactively without saving
1470
- fig.show()
1471
- # If neither save nor show: do nothing
1605
+ # Respect show and save flags (tests should set CONFIG.Plotting.default_show=False)
1606
+ if save and show:
1607
+ # Save and auto-open in browser
1608
+ plotly.offline.plot(fig, filename=str(filename))
1609
+ elif save and not show:
1610
+ # Save without opening
1611
+ fig.write_html(str(filename))
1612
+ elif show and not save:
1613
+ # Show interactively without saving
1614
+ fig.show()
1615
+ # If neither save nor show: do nothing
1472
1616
  finally:
1473
1617
  # Cleanup to prevent socket warnings
1474
1618
  if hasattr(fig, '_renderer'):
@@ -1479,16 +1623,15 @@ def export_figure(
1479
1623
  elif isinstance(figure_like, tuple):
1480
1624
  fig, ax = figure_like
1481
1625
  if show:
1482
- # Only show if using interactive backend and not in test environment
1626
+ # Only show if using interactive backend (tests should set CONFIG.Plotting.default_show=False)
1483
1627
  backend = matplotlib.get_backend().lower()
1484
1628
  is_interactive = backend not in {'agg', 'pdf', 'ps', 'svg', 'template'}
1485
- is_test_env = 'PYTEST_CURRENT_TEST' in os.environ
1486
1629
 
1487
- if is_interactive and not is_test_env:
1630
+ if is_interactive:
1488
1631
  plt.show()
1489
1632
 
1490
1633
  if save:
1491
- fig.savefig(str(filename), dpi=300)
1634
+ fig.savefig(str(filename), dpi=dpi)
1492
1635
  plt.close(fig) # Close figure to free memory
1493
1636
 
1494
1637
  return fig, ax