flixopt 3.1.1__py3-none-any.whl → 3.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flixopt might be problematic. Click here for more details.

flixopt/plotting.py CHANGED
@@ -40,14 +40,16 @@ import plotly.express as px
40
40
  import plotly.graph_objects as go
41
41
  import plotly.offline
42
42
  import xarray as xr
43
- from plotly.exceptions import PlotlyError
43
+
44
+ from .color_processing import process_colors
45
+ from .config import CONFIG
44
46
 
45
47
  if TYPE_CHECKING:
46
48
  import pyvis
47
49
 
48
50
  logger = logging.getLogger('flixopt')
49
51
 
50
- # Define the colors for the 'portland' colormap in matplotlib
52
+ # Define the colors for the 'portland' colorscale in matplotlib
51
53
  _portland_colors = [
52
54
  [12 / 255, 51 / 255, 131 / 255], # Dark blue
53
55
  [10 / 255, 136 / 255, 186 / 255], # Light blue
@@ -56,7 +58,7 @@ _portland_colors = [
56
58
  [217 / 255, 30 / 255, 30 / 255], # Red
57
59
  ]
58
60
 
59
- # Check if the colormap already exists before registering it
61
+ # Check if the colorscale already exists before registering it
60
62
  if hasattr(plt, 'colormaps'): # Matplotlib >= 3.7
61
63
  registry = plt.colormaps
62
64
  if 'portland' not in registry:
@@ -71,9 +73,9 @@ ColorType = str | list[str] | dict[str, str]
71
73
 
72
74
  Color specifications can take several forms to accommodate different use cases:
73
75
 
74
- **Named Colormaps** (str):
75
- - Standard colormaps: 'viridis', 'plasma', 'cividis', 'tab10', 'Set1'
76
- - Energy-focused: 'portland' (custom flixopt colormap for energy systems)
76
+ **Named colorscales** (str):
77
+ - Standard colorscales: 'turbo', 'plasma', 'cividis', 'tab10', 'Set1'
78
+ - Energy-focused: 'portland' (custom flixopt colorscale for energy systems)
77
79
  - Backend-specific maps available in Plotly and Matplotlib
78
80
 
79
81
  **Color Lists** (list[str]):
@@ -88,8 +90,8 @@ Color specifications can take several forms to accommodate different use cases:
88
90
 
89
91
  Examples:
90
92
  ```python
91
- # Named colormap
92
- colors = 'viridis' # Automatic color generation
93
+ # Named colorscale
94
+ colors = 'turbo' # Automatic color generation
93
95
 
94
96
  # Explicit color list
95
97
  colors = ['red', 'blue', 'green', '#FFD700']
@@ -112,7 +114,7 @@ Color Format Support:
112
114
 
113
115
  References:
114
116
  - HTML Color Names: https://htmlcolorcodes.com/color-names/
115
- - Matplotlib Colormaps: https://matplotlib.org/stable/tutorials/colors/colormaps.html
117
+ - Matplotlib colorscales: https://matplotlib.org/stable/tutorials/colors/colorscales.html
116
118
  - Plotly Built-in Colorscales: https://plotly.com/python/builtin-colorscales/
117
119
  """
118
120
 
@@ -120,242 +122,78 @@ PlottingEngine = Literal['plotly', 'matplotlib']
120
122
  """Identifier for the plotting engine to use."""
121
123
 
122
124
 
123
- class ColorProcessor:
124
- """Intelligent color management system for consistent multi-backend visualization.
125
-
126
- This class provides unified color processing across Plotly and Matplotlib backends,
127
- ensuring consistent visual appearance regardless of the plotting engine used.
128
- It handles color palette generation, named colormap translation, and intelligent
129
- color cycling for complex datasets with many categories.
130
-
131
- Key Features:
132
- **Backend Agnostic**: Automatic color format conversion between engines
133
- **Palette Management**: Support for named colormaps, custom palettes, and color lists
134
- **Intelligent Cycling**: Smart color assignment for datasets with many categories
135
- **Fallback Handling**: Graceful degradation when requested colormaps are unavailable
136
- **Energy System Colors**: Built-in palettes optimized for energy system visualization
137
-
138
- Color Input Types:
139
- - **Named Colormaps**: 'viridis', 'plasma', 'portland', 'tab10', etc.
140
- - **Color Lists**: ['red', 'blue', 'green'] or ['#FF0000', '#0000FF', '#00FF00']
141
- - **Label Dictionaries**: {'Generator': 'red', 'Storage': 'blue', 'Load': 'green'}
142
-
143
- Examples:
144
- Basic color processing:
145
-
146
- ```python
147
- # Initialize for Plotly backend
148
- processor = ColorProcessor(engine='plotly', default_colormap='viridis')
149
-
150
- # Process different color specifications
151
- colors = processor.process_colors('plasma', ['Gen1', 'Gen2', 'Storage'])
152
- colors = processor.process_colors(['red', 'blue', 'green'], ['A', 'B', 'C'])
153
- colors = processor.process_colors({'Wind': 'skyblue', 'Solar': 'gold'}, ['Wind', 'Solar', 'Gas'])
154
-
155
- # Switch to Matplotlib
156
- processor = ColorProcessor(engine='matplotlib')
157
- mpl_colors = processor.process_colors('tab10', component_labels)
158
- ```
159
-
160
- Energy system visualization:
161
-
162
- ```python
163
- # Specialized energy system palette
164
- energy_colors = {
165
- 'Natural_Gas': '#8B4513', # Brown
166
- 'Electricity': '#FFD700', # Gold
167
- 'Heat': '#FF4500', # Red-orange
168
- 'Cooling': '#87CEEB', # Sky blue
169
- 'Hydrogen': '#E6E6FA', # Lavender
170
- 'Battery': '#32CD32', # Lime green
171
- }
172
-
173
- processor = ColorProcessor('plotly')
174
- flow_colors = processor.process_colors(energy_colors, flow_labels)
175
- ```
176
-
177
- Args:
178
- engine: Plotting backend ('plotly' or 'matplotlib'). Determines output color format.
179
- default_colormap: Fallback colormap when requested palettes are unavailable.
180
- Common options: 'viridis', 'plasma', 'tab10', 'portland'.
181
-
182
- """
183
-
184
- def __init__(self, engine: PlottingEngine = 'plotly', default_colormap: str = 'viridis'):
185
- """Initialize the color processor with specified backend and defaults."""
186
- if engine not in ['plotly', 'matplotlib']:
187
- raise TypeError(f'engine must be "plotly" or "matplotlib", but is {engine}')
188
- self.engine = engine
189
- self.default_colormap = default_colormap
190
-
191
- def _generate_colors_from_colormap(self, colormap_name: str, num_colors: int) -> list[Any]:
192
- """
193
- Generate colors from a named colormap.
194
-
195
- Args:
196
- colormap_name: Name of the colormap
197
- num_colors: Number of colors to generate
198
-
199
- Returns:
200
- list of colors in the format appropriate for the engine
201
- """
202
- if self.engine == 'plotly':
203
- try:
204
- colorscale = px.colors.get_colorscale(colormap_name)
205
- except PlotlyError as e:
206
- logger.error(f"Colorscale '{colormap_name}' not found in Plotly. Using {self.default_colormap}: {e}")
207
- colorscale = px.colors.get_colorscale(self.default_colormap)
208
-
209
- # Generate evenly spaced points
210
- color_points = [i / (num_colors - 1) for i in range(num_colors)] if num_colors > 1 else [0]
211
- return px.colors.sample_colorscale(colorscale, color_points)
212
-
213
- else: # matplotlib
214
- try:
215
- cmap = plt.get_cmap(colormap_name, num_colors)
216
- except ValueError as e:
217
- logger.error(f"Colormap '{colormap_name}' not found in Matplotlib. Using {self.default_colormap}: {e}")
218
- cmap = plt.get_cmap(self.default_colormap, num_colors)
219
-
220
- return [cmap(i) for i in range(num_colors)]
221
-
222
- def _handle_color_list(self, colors: list[str], num_labels: int) -> list[str]:
223
- """
224
- Handle a list of colors, cycling if necessary.
225
-
226
- Args:
227
- colors: list of color strings
228
- num_labels: Number of labels that need colors
229
-
230
- Returns:
231
- list of colors matching the number of labels
232
- """
233
- if len(colors) == 0:
234
- logger.error(f'Empty color list provided. Using {self.default_colormap} instead.')
235
- return self._generate_colors_from_colormap(self.default_colormap, num_labels)
236
-
237
- if len(colors) < num_labels:
238
- logger.warning(
239
- f'Not enough colors provided ({len(colors)}) for all labels ({num_labels}). Colors will cycle.'
240
- )
241
- # Cycle through the colors
242
- color_iter = itertools.cycle(colors)
243
- return [next(color_iter) for _ in range(num_labels)]
244
- else:
245
- # Trim if necessary
246
- if len(colors) > num_labels:
247
- logger.warning(
248
- f'More colors provided ({len(colors)}) than labels ({num_labels}). Extra colors will be ignored.'
249
- )
250
- return colors[:num_labels]
251
-
252
- def _handle_color_dict(self, colors: dict[str, str], labels: list[str]) -> list[str]:
253
- """
254
- Handle a dictionary mapping labels to colors.
255
-
256
- Args:
257
- colors: Dictionary mapping labels to colors
258
- labels: list of labels that need colors
259
-
260
- Returns:
261
- list of colors in the same order as labels
262
- """
263
- if len(colors) == 0:
264
- logger.warning(f'Empty color dictionary provided. Using {self.default_colormap} instead.')
265
- return self._generate_colors_from_colormap(self.default_colormap, len(labels))
266
-
267
- # Find missing labels
268
- missing_labels = sorted(set(labels) - set(colors.keys()))
269
- if missing_labels:
270
- logger.warning(
271
- f'Some labels have no color specified: {missing_labels}. Using {self.default_colormap} for these.'
272
- )
125
+ def _ensure_dataset(data: xr.Dataset | pd.DataFrame | pd.Series) -> xr.Dataset:
126
+ """Convert DataFrame or Series to Dataset if needed."""
127
+ if isinstance(data, xr.Dataset):
128
+ return data
129
+ elif isinstance(data, pd.DataFrame):
130
+ # Convert DataFrame to Dataset
131
+ return data.to_xarray()
132
+ elif isinstance(data, pd.Series):
133
+ # Convert Series to DataFrame first, then to Dataset
134
+ return data.to_frame().to_xarray()
135
+ else:
136
+ raise TypeError(f'Data must be xr.Dataset, pd.DataFrame, or pd.Series, got {type(data).__name__}')
273
137
 
274
- # Generate colors for missing labels
275
- missing_colors = self._generate_colors_from_colormap(self.default_colormap, len(missing_labels))
276
138
 
277
- # Create a copy to avoid modifying the original
278
- colors_copy = colors.copy()
279
- for i, label in enumerate(missing_labels):
280
- colors_copy[label] = missing_colors[i]
281
- else:
282
- colors_copy = colors
283
-
284
- # Create color list in the same order as labels
285
- return [colors_copy[label] for label in labels]
286
-
287
- def process_colors(
288
- self,
289
- colors: ColorType,
290
- labels: list[str],
291
- return_mapping: bool = False,
292
- ) -> list[Any] | dict[str, Any]:
293
- """
294
- Process colors for the specified labels.
295
-
296
- Args:
297
- colors: Color specification (colormap name, list of colors, or label-to-color mapping)
298
- labels: list of data labels that need colors assigned
299
- return_mapping: If True, returns a dictionary mapping labels to colors;
300
- if False, returns a list of colors in the same order as labels
301
-
302
- Returns:
303
- Either a list of colors or a dictionary mapping labels to colors
304
- """
305
- if len(labels) == 0:
306
- logger.error('No labels provided for color assignment.')
307
- return {} if return_mapping else []
308
-
309
- # Process based on type of colors input
310
- if isinstance(colors, str):
311
- color_list = self._generate_colors_from_colormap(colors, len(labels))
312
- elif isinstance(colors, list):
313
- color_list = self._handle_color_list(colors, len(labels))
314
- elif isinstance(colors, dict):
315
- color_list = self._handle_color_dict(colors, labels)
316
- else:
317
- logger.error(
318
- f'Unsupported color specification type: {type(colors)}. Using {self.default_colormap} instead.'
139
+ def _validate_plotting_data(data: xr.Dataset, allow_empty: bool = False) -> None:
140
+ """Validate dataset for plotting (checks for empty data, non-numeric types, etc.)."""
141
+ # Check for empty data
142
+ if not allow_empty and len(data.data_vars) == 0:
143
+ raise ValueError('Empty Dataset provided (no variables). Cannot create plot.')
144
+
145
+ # Check if dataset has any data (xarray uses nbytes for total size)
146
+ if all(data[var].size == 0 for var in data.data_vars) if len(data.data_vars) > 0 else True:
147
+ if not allow_empty and len(data.data_vars) > 0:
148
+ raise ValueError('Dataset has zero size. Cannot create plot.')
149
+ if len(data.data_vars) == 0:
150
+ return # Empty dataset, nothing to validate
151
+ return
152
+
153
+ # Check for non-numeric data types
154
+ for var in data.data_vars:
155
+ dtype = data[var].dtype
156
+ if not np.issubdtype(dtype, np.number):
157
+ raise TypeError(
158
+ f"Variable '{var}' has non-numeric dtype '{dtype}'. "
159
+ f'Plotting requires numeric data types (int, float, etc.).'
319
160
  )
320
- color_list = self._generate_colors_from_colormap(self.default_colormap, len(labels))
321
161
 
322
- # Return either a list or a mapping
323
- if return_mapping:
324
- return {label: color_list[i] for i, label in enumerate(labels)}
325
- else:
326
- return color_list
162
+ # Warn about NaN/Inf values
163
+ for var in data.data_vars:
164
+ if np.isnan(data[var].values).any():
165
+ logger.debug(f"Variable '{var}' contains NaN values which may affect visualization.")
166
+ if np.isinf(data[var].values).any():
167
+ logger.debug(f"Variable '{var}' contains Inf values which may affect visualization.")
327
168
 
328
169
 
329
170
  def with_plotly(
330
- data: pd.DataFrame | xr.DataArray | xr.Dataset,
171
+ data: xr.Dataset | pd.DataFrame | pd.Series,
331
172
  mode: Literal['stacked_bar', 'line', 'area', 'grouped_bar'] = 'stacked_bar',
332
- colors: ColorType = 'viridis',
173
+ colors: ColorType | None = None,
333
174
  title: str = '',
334
175
  ylabel: str = '',
335
- xlabel: str = 'Time in h',
336
- fig: go.Figure | None = None,
176
+ xlabel: str = '',
337
177
  facet_by: str | list[str] | None = None,
338
178
  animate_by: str | None = None,
339
- facet_cols: int = 3,
179
+ facet_cols: int | None = None,
340
180
  shared_yaxes: bool = True,
341
181
  shared_xaxes: bool = True,
182
+ **px_kwargs: Any,
342
183
  ) -> go.Figure:
343
184
  """
344
185
  Plot data with Plotly using facets (subplots) and/or animation for multidimensional data.
345
186
 
346
187
  Uses Plotly Express for convenient faceting and animation with automatic styling.
347
- For simple plots without faceting, can optionally add to an existing figure.
348
188
 
349
189
  Args:
350
- data: A DataFrame or xarray DataArray/Dataset to plot.
190
+ data: An xarray Dataset, pandas DataFrame, or pandas Series to plot.
351
191
  mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for lines,
352
192
  'area' for stacked area charts, or 'grouped_bar' for grouped bar charts.
353
- colors: Color specification (colormap, list, or dict mapping labels to colors).
193
+ colors: Color specification (colorscale, list, or dict mapping labels to colors).
354
194
  title: The main title of the plot.
355
195
  ylabel: The label for the y-axis.
356
196
  xlabel: The label for the x-axis.
357
- fig: A Plotly figure object to plot on (only for simple plots without faceting).
358
- If not provided, a new figure will be created.
359
197
  facet_by: Dimension(s) to create facets for. Creates a subplot grid.
360
198
  Can be a single dimension name or list of dimensions (max 2 for facet_row and facet_col).
361
199
  If the dimension doesn't exist in the data, it will be silently ignored.
@@ -364,93 +202,122 @@ def with_plotly(
364
202
  facet_cols: Number of columns in the facet grid (used when facet_by is single dimension).
365
203
  shared_yaxes: Whether subplots share y-axes.
366
204
  shared_xaxes: Whether subplots share x-axes.
205
+ **px_kwargs: Additional keyword arguments passed to the underlying Plotly Express function
206
+ (px.bar, px.line, px.area). These override default arguments if provided.
207
+ Examples: range_x=[0, 100], range_y=[0, 50], category_orders={...}, line_shape='linear'
367
208
 
368
209
  Returns:
369
- A Plotly figure object containing the faceted/animated plot.
210
+ A Plotly figure object containing the faceted/animated plot. You can further customize
211
+ the returned figure using Plotly's methods (e.g., fig.update_traces(), fig.update_layout()).
370
212
 
371
213
  Examples:
372
214
  Simple plot:
373
215
 
374
216
  ```python
375
- fig = with_plotly(df, mode='area', title='Energy Mix')
217
+ fig = with_plotly(dataset, mode='area', title='Energy Mix')
376
218
  ```
377
219
 
378
220
  Facet by scenario:
379
221
 
380
222
  ```python
381
- fig = with_plotly(ds, facet_by='scenario', facet_cols=2)
223
+ fig = with_plotly(dataset, facet_by='scenario', facet_cols=2)
382
224
  ```
383
225
 
384
226
  Animate by period:
385
227
 
386
228
  ```python
387
- fig = with_plotly(ds, animate_by='period')
229
+ fig = with_plotly(dataset, animate_by='period')
388
230
  ```
389
231
 
390
232
  Facet and animate:
391
233
 
392
234
  ```python
393
- fig = with_plotly(ds, facet_by='scenario', animate_by='period')
235
+ fig = with_plotly(dataset, facet_by='scenario', animate_by='period')
236
+ ```
237
+
238
+ Customize with Plotly Express kwargs:
239
+
240
+ ```python
241
+ fig = with_plotly(dataset, range_y=[0, 100], line_shape='linear')
242
+ ```
243
+
244
+ Further customize the returned figure:
245
+
246
+ ```python
247
+ fig = with_plotly(dataset, mode='line')
248
+ fig.update_traces(line={'width': 5, 'dash': 'dot'})
249
+ fig.update_layout(template='plotly_dark', width=1200, height=600)
394
250
  ```
395
251
  """
252
+ if colors is None:
253
+ colors = CONFIG.Plotting.default_qualitative_colorscale
254
+
396
255
  if mode not in ('stacked_bar', 'line', 'area', 'grouped_bar'):
397
256
  raise ValueError(f"'mode' must be one of {{'stacked_bar','line','area', 'grouped_bar'}}, got {mode!r}")
398
257
 
258
+ # Apply CONFIG defaults if not explicitly set
259
+ if facet_cols is None:
260
+ facet_cols = CONFIG.Plotting.default_facet_cols
261
+
262
+ # Ensure data is a Dataset and validate it
263
+ data = _ensure_dataset(data)
264
+ _validate_plotting_data(data, allow_empty=True)
265
+
399
266
  # Handle empty data
400
- if isinstance(data, pd.DataFrame) and data.empty:
401
- return go.Figure()
402
- elif isinstance(data, xr.DataArray) and data.size == 0:
403
- return go.Figure()
404
- elif isinstance(data, xr.Dataset) and len(data.data_vars) == 0:
267
+ if len(data.data_vars) == 0:
268
+ logger.error('with_plotly() got an empty Dataset.')
405
269
  return go.Figure()
406
270
 
407
- # Warn if fig parameter is used with faceting
408
- if fig is not None and (facet_by is not None or animate_by is not None):
409
- logger.warning('The fig parameter is ignored when using faceting or animation. Creating a new figure.')
410
- fig = None
411
-
412
- # Convert xarray to long-form DataFrame for Plotly Express
413
- if isinstance(data, (xr.DataArray, xr.Dataset)):
414
- # Convert to long-form (tidy) DataFrame
415
- # Structure: time, variable, value, scenario, period, ... (all dims as columns)
416
- if isinstance(data, xr.Dataset):
417
- # Stack all data variables into long format
418
- df_long = data.to_dataframe().reset_index()
419
- # Melt to get: time, scenario, period, ..., variable, value
420
- id_vars = [dim for dim in data.dims]
421
- value_vars = list(data.data_vars)
422
- df_long = df_long.melt(id_vars=id_vars, value_vars=value_vars, var_name='variable', value_name='value')
423
- else:
424
- # DataArray
425
- df_long = data.to_dataframe().reset_index()
426
- if data.name:
427
- df_long = df_long.rename(columns={data.name: 'value'})
428
- else:
429
- # Unnamed DataArray, find the value column
430
- non_dim_cols = [col for col in df_long.columns if col not in data.dims]
431
- if len(non_dim_cols) != 1:
432
- raise ValueError(
433
- f'Expected exactly one non-dimension column for unnamed DataArray, '
434
- f'but found {len(non_dim_cols)}: {non_dim_cols}'
271
+ # Handle all-scalar datasets (where all variables have no dimensions)
272
+ # This occurs when all variables are scalar values with dims=()
273
+ if all(len(data[var].dims) == 0 for var in data.data_vars):
274
+ # Create a simple DataFrame with variable names as x-axis
275
+ variables = list(data.data_vars.keys())
276
+ values = [float(data[var].values) for var in data.data_vars]
277
+
278
+ # Resolve colors
279
+ color_discrete_map = process_colors(
280
+ colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale
281
+ )
282
+ marker_colors = [color_discrete_map.get(var, '#636EFA') for var in variables]
283
+
284
+ # Create simple plot based on mode using go (not px) for better color control
285
+ if mode in ('stacked_bar', 'grouped_bar'):
286
+ fig = go.Figure(data=[go.Bar(x=variables, y=values, marker_color=marker_colors)])
287
+ elif mode == 'line':
288
+ fig = go.Figure(
289
+ data=[
290
+ go.Scatter(
291
+ x=variables,
292
+ y=values,
293
+ mode='lines+markers',
294
+ marker=dict(color=marker_colors, size=8),
295
+ line=dict(color='lightgray'),
435
296
  )
436
- value_col = non_dim_cols[0]
437
- df_long = df_long.rename(columns={value_col: 'value'})
438
- df_long['variable'] = data.name or 'data'
439
- else:
440
- # Already a DataFrame - convert to long format for Plotly Express
441
- df_long = data.reset_index()
442
- if 'time' not in df_long.columns:
443
- # First column is probably time
444
- df_long = df_long.rename(columns={df_long.columns[0]: 'time'})
445
- # Melt to long format
446
- id_vars = [
447
- col
448
- for col in df_long.columns
449
- if col in ['time', 'scenario', 'period']
450
- or col in (facet_by if isinstance(facet_by, list) else [facet_by] if facet_by else [])
451
- ]
452
- value_vars = [col for col in df_long.columns if col not in id_vars]
453
- df_long = df_long.melt(id_vars=id_vars, value_vars=value_vars, var_name='variable', value_name='value')
297
+ ]
298
+ )
299
+ elif mode == 'area':
300
+ fig = go.Figure(
301
+ data=[
302
+ go.Scatter(
303
+ x=variables,
304
+ y=values,
305
+ fill='tozeroy',
306
+ marker=dict(color=marker_colors, size=8),
307
+ line=dict(color='lightgray'),
308
+ )
309
+ ]
310
+ )
311
+ else:
312
+ raise ValueError('"mode" must be one of "stacked_bar", "grouped_bar", "line", "area"')
313
+
314
+ fig.update_layout(title=title, xaxis_title=xlabel, yaxis_title=ylabel, showlegend=False)
315
+ return fig
316
+
317
+ # Convert Dataset to long-form DataFrame for Plotly Express
318
+ # Structure: time, variable, value, scenario, period, ... (all dims as columns)
319
+ dim_names = list(data.dims)
320
+ df_long = data.to_dataframe().reset_index().melt(id_vars=dim_names, var_name='variable', value_name='value')
454
321
 
455
322
  # Validate facet_by and animate_by dimensions exist in the data
456
323
  available_dims = [col for col in df_long.columns if col not in ['variable', 'value']]
@@ -502,13 +369,36 @@ def with_plotly(
502
369
 
503
370
  # Process colors
504
371
  all_vars = df_long['variable'].unique().tolist()
505
- processed_colors = ColorProcessor(engine='plotly').process_colors(colors, all_vars)
506
- color_discrete_map = {var: color for var, color in zip(all_vars, processed_colors, strict=True)}
372
+ color_discrete_map = process_colors(
373
+ colors, all_vars, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale
374
+ )
375
+
376
+ # Determine which dimension to use for x-axis
377
+ # Collect dimensions used for faceting and animation
378
+ used_dims = set()
379
+ if facet_row:
380
+ used_dims.add(facet_row)
381
+ if facet_col:
382
+ used_dims.add(facet_col)
383
+ if animate_by:
384
+ used_dims.add(animate_by)
385
+
386
+ # Find available dimensions for x-axis (not used for faceting/animation)
387
+ x_candidates = [d for d in available_dims if d not in used_dims]
388
+
389
+ # Use 'time' if available, otherwise use the first available dimension
390
+ if 'time' in x_candidates:
391
+ x_dim = 'time'
392
+ elif len(x_candidates) > 0:
393
+ x_dim = x_candidates[0]
394
+ else:
395
+ # Fallback: use the first dimension (shouldn't happen in normal cases)
396
+ x_dim = available_dims[0] if available_dims else 'time'
507
397
 
508
398
  # Create plot using Plotly Express based on mode
509
399
  common_args = {
510
400
  'data_frame': df_long,
511
- 'x': 'time',
401
+ 'x': x_dim,
512
402
  'y': 'value',
513
403
  'color': 'variable',
514
404
  'facet_row': facet_row,
@@ -516,13 +406,22 @@ def with_plotly(
516
406
  'animation_frame': animate_by,
517
407
  'color_discrete_map': color_discrete_map,
518
408
  'title': title,
519
- 'labels': {'value': ylabel, 'time': xlabel, 'variable': ''},
409
+ 'labels': {'value': ylabel, x_dim: xlabel, 'variable': ''},
520
410
  }
521
411
 
522
412
  # Add facet_col_wrap for single facet dimension
523
413
  if facet_col and not facet_row:
524
414
  common_args['facet_col_wrap'] = facet_cols
525
415
 
416
+ # Add mode-specific defaults (before px_kwargs so they can be overridden)
417
+ if mode in ('line', 'area'):
418
+ common_args['line_shape'] = 'hv' # Stepped lines by default
419
+
420
+ # Allow callers to pass any px.* keyword args (e.g., category_orders, range_x/y, line_shape)
421
+ # These will override the defaults set above
422
+ if px_kwargs:
423
+ common_args.update(px_kwargs)
424
+
526
425
  if mode == 'stacked_bar':
527
426
  fig = px.bar(**common_args)
528
427
  fig.update_traces(marker_line_width=0)
@@ -531,10 +430,10 @@ def with_plotly(
531
430
  fig = px.bar(**common_args)
532
431
  fig.update_layout(barmode='group', bargap=0.2, bargroupgap=0)
533
432
  elif mode == 'line':
534
- fig = px.line(**common_args, line_shape='hv') # Stepped lines
433
+ fig = px.line(**common_args)
535
434
  elif mode == 'area':
536
435
  # Use Plotly Express to create the area plot (preserves animation, legends, faceting)
537
- fig = px.area(**common_args, line_shape='hv')
436
+ fig = px.area(**common_args)
538
437
 
539
438
  # Classify each variable based on its values
540
439
  variable_classification = {}
@@ -577,13 +476,6 @@ def with_plotly(
577
476
  if hasattr(trace, 'fill'):
578
477
  trace.fill = None
579
478
 
580
- # Update layout with basic styling (Plotly Express handles sizing automatically)
581
- fig.update_layout(
582
- plot_bgcolor='rgba(0,0,0,0)',
583
- paper_bgcolor='rgba(0,0,0,0)',
584
- font=dict(size=12),
585
- )
586
-
587
479
  # Update axes to share if requested (Plotly Express already handles this, but we can customize)
588
480
  if not shared_yaxes:
589
481
  fig.update_yaxes(matches=None)
@@ -594,33 +486,32 @@ def with_plotly(
594
486
 
595
487
 
596
488
  def with_matplotlib(
597
- data: pd.DataFrame,
489
+ data: xr.Dataset | pd.DataFrame | pd.Series,
598
490
  mode: Literal['stacked_bar', 'line'] = 'stacked_bar',
599
- colors: ColorType = 'viridis',
491
+ colors: ColorType | None = None,
600
492
  title: str = '',
601
493
  ylabel: str = '',
602
494
  xlabel: str = 'Time in h',
603
495
  figsize: tuple[int, int] = (12, 6),
604
- fig: plt.Figure | None = None,
605
- ax: plt.Axes | None = None,
496
+ plot_kwargs: dict[str, Any] | None = None,
606
497
  ) -> tuple[plt.Figure, plt.Axes]:
607
498
  """
608
- Plot a DataFrame with Matplotlib using stacked bars or stepped lines.
499
+ Plot data with Matplotlib using stacked bars or stepped lines.
609
500
 
610
501
  Args:
611
- data: A DataFrame containing the data to plot. The index should represent time (e.g., hours),
612
- and each column represents a separate data series.
502
+ data: An xarray Dataset, pandas DataFrame, or pandas Series to plot. After conversion to DataFrame,
503
+ the index represents time and each column represents a separate data series (variables).
613
504
  mode: Plotting mode. Use 'stacked_bar' for stacked bar charts or 'line' for stepped lines.
614
- colors: Color specification, can be:
615
- - A string with a colormap name (e.g., 'viridis', 'plasma')
505
+ colors: Color specification. Can be:
506
+ - A colorscale name (e.g., 'turbo', 'plasma')
616
507
  - A list of color strings (e.g., ['#ff0000', '#00ff00'])
617
- - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'})
508
+ - A dict mapping column names to colors (e.g., {'Column1': '#ff0000'})
618
509
  title: The title of the plot.
619
510
  ylabel: The ylabel of the plot.
620
511
  xlabel: The xlabel of the plot.
621
- figsize: Specify the size of the figure
622
- fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created.
623
- ax: A Matplotlib axes object to plot on. If not provided, a new axes will be created.
512
+ figsize: Specify the size of the figure (width, height) in inches.
513
+ plot_kwargs: Optional dict of parameters to pass to ax.bar() or ax.step() plotting calls.
514
+ Use this to customize plot properties (e.g., linewidth, alpha, edgecolor).
624
515
 
625
516
  Returns:
626
517
  A tuple containing the Matplotlib figure and axes objects used for the plot.
@@ -630,48 +521,121 @@ def with_matplotlib(
630
521
  Negative values are stacked separately without extra labels in the legend.
631
522
  - If `mode` is 'line', stepped lines are drawn for each data series.
632
523
  """
524
+ if colors is None:
525
+ colors = CONFIG.Plotting.default_qualitative_colorscale
526
+
633
527
  if mode not in ('stacked_bar', 'line'):
634
528
  raise ValueError(f"'mode' must be one of {{'stacked_bar','line'}} for matplotlib, got {mode!r}")
635
529
 
636
- if fig is None or ax is None:
637
- fig, ax = plt.subplots(figsize=figsize)
530
+ # Ensure data is a Dataset and validate it
531
+ data = _ensure_dataset(data)
532
+ _validate_plotting_data(data, allow_empty=True)
533
+
534
+ # Create new figure and axes
535
+ fig, ax = plt.subplots(figsize=figsize)
536
+
537
+ # Initialize plot_kwargs if not provided
538
+ if plot_kwargs is None:
539
+ plot_kwargs = {}
540
+
541
+ # Handle all-scalar datasets (where all variables have no dimensions)
542
+ # This occurs when all variables are scalar values with dims=()
543
+ if all(len(data[var].dims) == 0 for var in data.data_vars):
544
+ # Create simple bar/line plot with variable names as x-axis
545
+ variables = list(data.data_vars.keys())
546
+ values = [float(data[var].values) for var in data.data_vars]
547
+
548
+ # Resolve colors
549
+ color_discrete_map = process_colors(
550
+ colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale
551
+ )
552
+ colors_list = [color_discrete_map.get(var, '#808080') for var in variables]
553
+
554
+ # Create plot based on mode
555
+ if mode == 'stacked_bar':
556
+ ax.bar(variables, values, color=colors_list, **plot_kwargs)
557
+ elif mode == 'line':
558
+ ax.plot(
559
+ variables,
560
+ values,
561
+ marker='o',
562
+ color=colors_list[0] if len(set(colors_list)) == 1 else None,
563
+ **plot_kwargs,
564
+ )
565
+ # If different colors, plot each point separately
566
+ if len(set(colors_list)) > 1:
567
+ ax.clear()
568
+ for i, (var, val) in enumerate(zip(variables, values, strict=False)):
569
+ ax.plot([i], [val], marker='o', color=colors_list[i], label=var, **plot_kwargs)
570
+ ax.set_xticks(range(len(variables)))
571
+ ax.set_xticklabels(variables)
572
+
573
+ ax.set_xlabel(xlabel, ha='center')
574
+ ax.set_ylabel(ylabel, va='center')
575
+ ax.set_title(title)
576
+ ax.grid(color='lightgrey', linestyle='-', linewidth=0.5, axis='y')
577
+ fig.tight_layout()
578
+
579
+ return fig, ax
580
+
581
+ # Resolve colors first (includes validation)
582
+ color_discrete_map = process_colors(
583
+ colors, list(data.data_vars), default_colorscale=CONFIG.Plotting.default_qualitative_colorscale
584
+ )
585
+
586
+ # Convert Dataset to DataFrame for matplotlib plotting (naturally wide-form)
587
+ df = data.to_dataframe()
638
588
 
639
- processed_colors = ColorProcessor(engine='matplotlib').process_colors(colors, list(data.columns))
589
+ # Get colors in column order
590
+ processed_colors = [color_discrete_map.get(str(col), '#808080') for col in df.columns]
640
591
 
641
592
  if mode == 'stacked_bar':
642
- cumulative_positive = np.zeros(len(data))
643
- cumulative_negative = np.zeros(len(data))
644
- width = data.index.to_series().diff().dropna().min() # Minimum time difference
593
+ cumulative_positive = np.zeros(len(df))
594
+ cumulative_negative = np.zeros(len(df))
595
+
596
+ # Robust bar width: handle datetime-like, numeric, and single-point indexes
597
+ if len(df.index) > 1:
598
+ delta = pd.Index(df.index).to_series().diff().dropna().min()
599
+ if hasattr(delta, 'total_seconds'): # datetime-like
600
+ width = delta.total_seconds() / 86400.0 # Matplotlib date units = days
601
+ else:
602
+ width = float(delta)
603
+ else:
604
+ width = 0.8 # reasonable default for a single bar
645
605
 
646
- for i, column in enumerate(data.columns):
647
- positive_values = np.clip(data[column], 0, None) # Keep only positive values
648
- negative_values = np.clip(data[column], None, 0) # Keep only negative values
606
+ for i, column in enumerate(df.columns):
607
+ # Fill NaNs to avoid breaking stacking math
608
+ series = df[column].fillna(0)
609
+ positive_values = np.clip(series, 0, None) # Keep only positive values
610
+ negative_values = np.clip(series, None, 0) # Keep only negative values
649
611
  # Plot positive bars
650
612
  ax.bar(
651
- data.index,
613
+ df.index,
652
614
  positive_values,
653
615
  bottom=cumulative_positive,
654
616
  color=processed_colors[i],
655
617
  label=column,
656
618
  width=width,
657
619
  align='center',
620
+ **plot_kwargs,
658
621
  )
659
622
  cumulative_positive += positive_values.values
660
623
  # Plot negative bars
661
624
  ax.bar(
662
- data.index,
625
+ df.index,
663
626
  negative_values,
664
627
  bottom=cumulative_negative,
665
628
  color=processed_colors[i],
666
629
  label='', # No label for negative bars
667
630
  width=width,
668
631
  align='center',
632
+ **plot_kwargs,
669
633
  )
670
634
  cumulative_negative += negative_values.values
671
635
 
672
636
  elif mode == 'line':
673
- for i, column in enumerate(data.columns):
674
- ax.step(data.index, data[column], where='post', color=processed_colors[i], label=column)
637
+ for i, column in enumerate(df.columns):
638
+ ax.step(df.index, df[column], where='post', color=processed_colors[i], label=column, **plot_kwargs)
675
639
 
676
640
  # Aesthetics
677
641
  ax.set_xlabel(xlabel, ha='center')
@@ -944,228 +908,104 @@ def plot_network(
944
908
  )
945
909
 
946
910
 
947
- def pie_with_plotly(
948
- data: pd.DataFrame,
949
- colors: ColorType = 'viridis',
950
- title: str = '',
951
- legend_title: str = '',
952
- hole: float = 0.0,
953
- fig: go.Figure | None = None,
954
- ) -> go.Figure:
955
- """
956
- Create a pie chart with Plotly to visualize the proportion of values in a DataFrame.
957
-
958
- Args:
959
- data: A DataFrame containing the data to plot. If multiple rows exist,
960
- they will be summed unless a specific index value is passed.
961
- colors: Color specification, can be:
962
- - A string with a colorscale name (e.g., 'viridis', 'plasma')
963
- - A list of color strings (e.g., ['#ff0000', '#00ff00'])
964
- - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'})
965
- title: The title of the plot.
966
- legend_title: The title for the legend.
967
- hole: Size of the hole in the center for creating a donut chart (0.0 to 1.0).
968
- fig: A Plotly figure object to plot on. If not provided, a new figure will be created.
969
-
970
- Returns:
971
- A Plotly figure object containing the generated pie chart.
972
-
973
- Notes:
974
- - Negative values are not appropriate for pie charts and will be converted to absolute values with a warning.
975
- - If the data contains very small values (less than 1% of the total), they can be grouped into an "Other" category
976
- for better readability.
977
- - By default, the sum of all columns is used for the pie chart. For time series data, consider preprocessing.
978
-
911
+ def preprocess_data_for_pie(
912
+ data: xr.Dataset | pd.DataFrame | pd.Series,
913
+ lower_percentage_threshold: float = 5.0,
914
+ ) -> pd.Series:
979
915
  """
980
- if data.empty:
981
- logger.error('Empty DataFrame provided for pie chart. Returning empty figure.')
982
- return go.Figure()
983
-
984
- # Create a copy to avoid modifying the original DataFrame
985
- data_copy = data.copy()
986
-
987
- # Check if any negative values and warn
988
- if (data_copy < 0).any().any():
989
- logger.error('Negative values detected in data. Using absolute values for pie chart.')
990
- data_copy = data_copy.abs()
991
-
992
- # If data has multiple rows, sum them to get total for each column
993
- if len(data_copy) > 1:
994
- data_sum = data_copy.sum()
995
- else:
996
- data_sum = data_copy.iloc[0]
916
+ Preprocess data for pie chart display.
997
917
 
998
- # Get labels (column names) and values
999
- labels = data_sum.index.tolist()
1000
- values = data_sum.values.tolist()
1001
-
1002
- # Apply color mapping using the unified color processor
1003
- processed_colors = ColorProcessor(engine='plotly').process_colors(colors, labels)
1004
-
1005
- # Create figure if not provided
1006
- fig = fig if fig is not None else go.Figure()
1007
-
1008
- # Add pie trace
1009
- fig.add_trace(
1010
- go.Pie(
1011
- labels=labels,
1012
- values=values,
1013
- hole=hole,
1014
- marker=dict(colors=processed_colors),
1015
- textinfo='percent+label+value',
1016
- textposition='inside',
1017
- insidetextorientation='radial',
1018
- )
1019
- )
1020
-
1021
- # Update layout for better aesthetics
1022
- fig.update_layout(
1023
- title=title,
1024
- legend_title=legend_title,
1025
- plot_bgcolor='rgba(0,0,0,0)', # Transparent background
1026
- paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background
1027
- font=dict(size=14), # Increase font size for better readability
1028
- )
1029
-
1030
- return fig
1031
-
1032
-
1033
- def pie_with_matplotlib(
1034
- data: pd.DataFrame,
1035
- colors: ColorType = 'viridis',
1036
- title: str = '',
1037
- legend_title: str = 'Categories',
1038
- hole: float = 0.0,
1039
- figsize: tuple[int, int] = (10, 8),
1040
- fig: plt.Figure | None = None,
1041
- ax: plt.Axes | None = None,
1042
- ) -> tuple[plt.Figure, plt.Axes]:
1043
- """
1044
- Create a pie chart with Matplotlib to visualize the proportion of values in a DataFrame.
918
+ Groups items that are individually below the threshold percentage into an "Other" category.
919
+ Converts various input types to a pandas Series for uniform handling.
1045
920
 
1046
921
  Args:
1047
- data: A DataFrame containing the data to plot. If multiple rows exist,
1048
- they will be summed unless a specific index value is passed.
1049
- colors: Color specification, can be:
1050
- - A string with a colormap name (e.g., 'viridis', 'plasma')
1051
- - A list of color strings (e.g., ['#ff0000', '#00ff00'])
1052
- - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'})
1053
- title: The title of the plot.
1054
- legend_title: The title for the legend.
1055
- hole: Size of the hole in the center for creating a donut chart (0.0 to 1.0).
1056
- figsize: The size of the figure (width, height) in inches.
1057
- fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created.
1058
- ax: A Matplotlib axes object to plot on. If not provided, a new axes will be created.
922
+ data: Input data (xarray Dataset, DataFrame, or Series)
923
+ lower_percentage_threshold: Percentage threshold - items below this are grouped into "Other"
1059
924
 
1060
925
  Returns:
1061
- A tuple containing the Matplotlib figure and axes objects used for the plot.
1062
-
1063
- Notes:
1064
- - Negative values are not appropriate for pie charts and will be converted to absolute values with a warning.
1065
- - If the data contains very small values (less than 1% of the total), they can be grouped into an "Other" category
1066
- for better readability.
1067
- - By default, the sum of all columns is used for the pie chart. For time series data, consider preprocessing.
1068
-
926
+ Processed pandas Series with small items grouped into "Other"
1069
927
  """
1070
- if data.empty:
1071
- logger.error('Empty DataFrame provided for pie chart. Returning empty figure.')
1072
- if fig is None or ax is None:
1073
- fig, ax = plt.subplots(figsize=figsize)
1074
- return fig, ax
928
+ # Convert to Series
929
+ if isinstance(data, xr.Dataset):
930
+ # Sum all dimensions for each variable to get total values
931
+ values = {}
932
+ for var in data.data_vars:
933
+ var_data = data[var]
934
+ if len(var_data.dims) > 0:
935
+ total_value = float(var_data.sum().item())
936
+ else:
937
+ total_value = float(var_data.item())
1075
938
 
1076
- # Create a copy to avoid modifying the original DataFrame
1077
- data_copy = data.copy()
939
+ # Handle negative values
940
+ if total_value < 0:
941
+ logger.warning(f'Negative value for {var}: {total_value}. Using absolute value.')
942
+ total_value = abs(total_value)
1078
943
 
1079
- # Check if any negative values and warn
1080
- if (data_copy < 0).any().any():
1081
- logger.error('Negative values detected in data. Using absolute values for pie chart.')
1082
- data_copy = data_copy.abs()
944
+ values[var] = total_value
1083
945
 
1084
- # If data has multiple rows, sum them to get total for each column
1085
- if len(data_copy) > 1:
1086
- data_sum = data_copy.sum()
1087
- else:
1088
- data_sum = data_copy.iloc[0]
946
+ series = pd.Series(values)
1089
947
 
1090
- # Get labels (column names) and values
1091
- labels = data_sum.index.tolist()
1092
- values = data_sum.values.tolist()
948
+ elif isinstance(data, pd.DataFrame):
949
+ # Sum across all columns if DataFrame
950
+ series = data.sum(axis=0)
951
+ # Handle negative values
952
+ negative_mask = series < 0
953
+ if negative_mask.any():
954
+ logger.warning(f'Negative values found: {series[negative_mask].to_dict()}. Using absolute values.')
955
+ series = series.abs()
1093
956
 
1094
- # Apply color mapping using the unified color processor
1095
- processed_colors = ColorProcessor(engine='matplotlib').process_colors(colors, labels)
957
+ else: # pd.Series
958
+ series = data.copy()
959
+ # Handle negative values
960
+ negative_mask = series < 0
961
+ if negative_mask.any():
962
+ logger.warning(f'Negative values found: {series[negative_mask].to_dict()}. Using absolute values.')
963
+ series = series.abs()
1096
964
 
1097
- # Create figure and axis if not provided
1098
- if fig is None or ax is None:
1099
- fig, ax = plt.subplots(figsize=figsize)
965
+ # Only keep positive values
966
+ series = series[series > 0]
1100
967
 
1101
- # Draw the pie chart
1102
- wedges, texts, autotexts = ax.pie(
1103
- values,
1104
- labels=labels,
1105
- colors=processed_colors,
1106
- autopct='%1.1f%%',
1107
- startangle=90,
1108
- shadow=False,
1109
- wedgeprops=dict(width=0.5) if hole > 0 else None, # Set width for donut
1110
- )
968
+ if series.empty or lower_percentage_threshold <= 0:
969
+ return series
1111
970
 
1112
- # Adjust the wedgeprops to make donut hole size consistent with plotly
1113
- # For matplotlib, the hole size is determined by the wedge width
1114
- # Convert hole parameter to wedge width
1115
- if hole > 0:
1116
- # Adjust hole size to match plotly's hole parameter
1117
- # In matplotlib, wedge width is relative to the radius (which is 1)
1118
- # For plotly, hole is a fraction of the radius
1119
- wedge_width = 1 - hole
1120
- for wedge in wedges:
1121
- wedge.set_width(wedge_width)
1122
-
1123
- # Customize the appearance
1124
- # Make autopct text more visible
1125
- for autotext in autotexts:
1126
- autotext.set_fontsize(10)
1127
- autotext.set_color('white')
1128
-
1129
- # Set aspect ratio to be equal to ensure a circular pie
1130
- ax.set_aspect('equal')
1131
-
1132
- # Add title
1133
- if title:
1134
- ax.set_title(title, fontsize=16)
971
+ # Calculate percentages
972
+ total = series.sum()
973
+ percentages = (series / total) * 100
1135
974
 
1136
- # Create a legend if there are many segments
1137
- if len(labels) > 6:
1138
- ax.legend(wedges, labels, title=legend_title, loc='center left', bbox_to_anchor=(1, 0, 0.5, 1))
975
+ # Find items below and above threshold
976
+ below_threshold = series[percentages < lower_percentage_threshold]
977
+ above_threshold = series[percentages >= lower_percentage_threshold]
1139
978
 
1140
- # Apply tight layout
1141
- fig.tight_layout()
979
+ # Only group if there are at least 2 items below threshold
980
+ if len(below_threshold) > 1:
981
+ # Create new series with items above threshold + "Other"
982
+ result = above_threshold.copy()
983
+ result['Other'] = below_threshold.sum()
984
+ return result
1142
985
 
1143
- return fig, ax
986
+ return series
1144
987
 
1145
988
 
1146
989
  def dual_pie_with_plotly(
1147
- data_left: pd.Series,
1148
- data_right: pd.Series,
1149
- colors: ColorType = 'viridis',
990
+ data_left: xr.Dataset | pd.DataFrame | pd.Series,
991
+ data_right: xr.Dataset | pd.DataFrame | pd.Series,
992
+ colors: ColorType | None = None,
1150
993
  title: str = '',
1151
994
  subtitles: tuple[str, str] = ('Left Chart', 'Right Chart'),
1152
995
  legend_title: str = '',
1153
996
  hole: float = 0.2,
1154
997
  lower_percentage_group: float = 5.0,
1155
- hover_template: str = '%{label}: %{value} (%{percent})',
1156
998
  text_info: str = 'percent+label',
1157
999
  text_position: str = 'inside',
1000
+ hover_template: str = '%{label}: %{value} (%{percent})',
1158
1001
  ) -> go.Figure:
1159
1002
  """
1160
- Create two pie charts side by side with Plotly, with consistent coloring across both charts.
1003
+ Create two pie charts side by side with Plotly.
1161
1004
 
1162
1005
  Args:
1163
- data_left: Series for the left pie chart.
1164
- data_right: Series for the right pie chart.
1165
- colors: Color specification, can be:
1166
- - A string with a colorscale name (e.g., 'viridis', 'plasma')
1167
- - A list of color strings (e.g., ['#ff0000', '#00ff00'])
1168
- - A dictionary mapping category names to colors (e.g., {'Category1': '#ff0000'})
1006
+ data_left: Data for the left pie chart. Variables are summed across all dimensions.
1007
+ data_right: Data for the right pie chart. Variables are summed across all dimensions.
1008
+ colors: Color specification (colorscale name, list of colors, or dict mapping)
1169
1009
  title: The main title of the plot.
1170
1010
  subtitles: Tuple containing the subtitles for (left, right) charts.
1171
1011
  legend_title: The title for the legend.
@@ -1177,119 +1017,67 @@ def dual_pie_with_plotly(
1177
1017
  text_position: Position of text: 'inside', 'outside', 'auto', or 'none'.
1178
1018
 
1179
1019
  Returns:
1180
- A Plotly figure object containing the generated dual pie chart.
1020
+ Plotly Figure object
1181
1021
  """
1182
- from plotly.subplots import make_subplots
1183
-
1184
- # Check for empty data
1185
- if data_left.empty and data_right.empty:
1186
- logger.error('Both datasets are empty. Returning empty figure.')
1187
- return go.Figure()
1188
-
1189
- # Create a subplot figure
1190
- fig = make_subplots(
1191
- rows=1, cols=2, specs=[[{'type': 'pie'}, {'type': 'pie'}]], subplot_titles=subtitles, horizontal_spacing=0.05
1192
- )
1193
-
1194
- # Process series to handle negative values and apply minimum percentage threshold
1195
- def preprocess_series(series: pd.Series):
1196
- """
1197
- Preprocess a series for pie chart display by handling negative values
1198
- and grouping the smallest parts together if they collectively represent
1199
- less than the specified percentage threshold.
1200
-
1201
- Args:
1202
- series: The series to preprocess
1203
-
1204
- Returns:
1205
- A preprocessed pandas Series
1206
- """
1207
- # Handle negative values
1208
- if (series < 0).any():
1209
- logger.error('Negative values detected in data. Using absolute values for pie chart.')
1210
- series = series.abs()
1211
-
1212
- # Remove zeros
1213
- series = series[series > 0]
1214
-
1215
- # Apply minimum percentage threshold if needed
1216
- if lower_percentage_group and not series.empty:
1217
- total = series.sum()
1218
- if total > 0:
1219
- # Sort series by value (ascending)
1220
- sorted_series = series.sort_values()
1221
-
1222
- # Calculate cumulative percentage contribution
1223
- cumulative_percent = (sorted_series.cumsum() / total) * 100
1224
-
1225
- # Find entries that collectively make up less than lower_percentage_group
1226
- to_group = cumulative_percent <= lower_percentage_group
1227
-
1228
- if to_group.sum() > 1:
1229
- # Create "Other" category for the smallest values that together are < threshold
1230
- other_sum = sorted_series[to_group].sum()
1231
-
1232
- # Keep only values that aren't in the "Other" group
1233
- result_series = series[~series.index.isin(sorted_series[to_group].index)]
1234
-
1235
- # Add the "Other" category if it has a value
1236
- if other_sum > 0:
1237
- result_series['Other'] = other_sum
1238
-
1239
- return result_series
1240
-
1241
- return series
1242
-
1243
- data_left_processed = preprocess_series(data_left)
1244
- data_right_processed = preprocess_series(data_right)
1245
-
1246
- # Get unique set of all labels for consistent coloring
1247
- all_labels = sorted(set(data_left_processed.index) | set(data_right_processed.index))
1248
-
1249
- # Get consistent color mapping for both charts using our unified function
1250
- color_map = ColorProcessor(engine='plotly').process_colors(colors, all_labels, return_mapping=True)
1251
-
1252
- # Function to create a pie trace with consistently mapped colors
1253
- def create_pie_trace(data_series, side):
1254
- if data_series.empty:
1255
- return None
1256
-
1257
- labels = data_series.index.tolist()
1258
- values = data_series.values.tolist()
1259
- trace_colors = [color_map[label] for label in labels]
1260
-
1261
- return go.Pie(
1262
- labels=labels,
1263
- values=values,
1264
- name=side,
1265
- marker=dict(colors=trace_colors),
1266
- hole=hole,
1267
- textinfo=text_info,
1268
- textposition=text_position,
1269
- insidetextorientation='radial',
1270
- hovertemplate=hover_template,
1271
- sort=True, # Sort values by default (largest first)
1022
+ if colors is None:
1023
+ colors = CONFIG.Plotting.default_qualitative_colorscale
1024
+
1025
+ # Preprocess data to Series
1026
+ left_series = preprocess_data_for_pie(data_left, lower_percentage_group)
1027
+ right_series = preprocess_data_for_pie(data_right, lower_percentage_group)
1028
+
1029
+ # Extract labels and values
1030
+ left_labels = left_series.index.tolist()
1031
+ left_values = left_series.values.tolist()
1032
+
1033
+ right_labels = right_series.index.tolist()
1034
+ right_values = right_series.values.tolist()
1035
+
1036
+ # Get all unique labels for consistent coloring
1037
+ all_labels = sorted(set(left_labels) | set(right_labels))
1038
+
1039
+ # Create color map
1040
+ color_map = process_colors(colors, all_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale)
1041
+
1042
+ # Create figure
1043
+ fig = go.Figure()
1044
+
1045
+ # Add left pie
1046
+ if left_labels:
1047
+ fig.add_trace(
1048
+ go.Pie(
1049
+ labels=left_labels,
1050
+ values=left_values,
1051
+ name=subtitles[0],
1052
+ marker=dict(colors=[color_map.get(label, '#636EFA') for label in left_labels]),
1053
+ hole=hole,
1054
+ textinfo=text_info,
1055
+ textposition=text_position,
1056
+ hovertemplate=hover_template,
1057
+ domain=dict(x=[0, 0.48]),
1058
+ )
1272
1059
  )
1273
1060
 
1274
- # Add left pie if data exists
1275
- left_trace = create_pie_trace(data_left_processed, subtitles[0])
1276
- if left_trace:
1277
- left_trace.domain = dict(x=[0, 0.48])
1278
- fig.add_trace(left_trace, row=1, col=1)
1279
-
1280
- # Add right pie if data exists
1281
- right_trace = create_pie_trace(data_right_processed, subtitles[1])
1282
- if right_trace:
1283
- right_trace.domain = dict(x=[0.52, 1])
1284
- fig.add_trace(right_trace, row=1, col=2)
1061
+ # Add right pie
1062
+ if right_labels:
1063
+ fig.add_trace(
1064
+ go.Pie(
1065
+ labels=right_labels,
1066
+ values=right_values,
1067
+ name=subtitles[1],
1068
+ marker=dict(colors=[color_map.get(label, '#636EFA') for label in right_labels]),
1069
+ hole=hole,
1070
+ textinfo=text_info,
1071
+ textposition=text_position,
1072
+ hovertemplate=hover_template,
1073
+ domain=dict(x=[0.52, 1]),
1074
+ )
1075
+ )
1285
1076
 
1286
1077
  # Update layout
1287
1078
  fig.update_layout(
1288
1079
  title=title,
1289
1080
  legend_title=legend_title,
1290
- plot_bgcolor='rgba(0,0,0,0)', # Transparent background
1291
- paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background
1292
- font=dict(size=14),
1293
1081
  margin=dict(t=80, b=50, l=30, r=30),
1294
1082
  )
1295
1083
 
@@ -1297,178 +1085,127 @@ def dual_pie_with_plotly(
1297
1085
 
1298
1086
 
1299
1087
  def dual_pie_with_matplotlib(
1300
- data_left: pd.Series,
1301
- data_right: pd.Series,
1302
- colors: ColorType = 'viridis',
1088
+ data_left: xr.Dataset | pd.DataFrame | pd.Series,
1089
+ data_right: xr.Dataset | pd.DataFrame | pd.Series,
1090
+ colors: ColorType | None = None,
1303
1091
  title: str = '',
1304
1092
  subtitles: tuple[str, str] = ('Left Chart', 'Right Chart'),
1305
1093
  legend_title: str = '',
1306
1094
  hole: float = 0.2,
1307
1095
  lower_percentage_group: float = 5.0,
1308
1096
  figsize: tuple[int, int] = (14, 7),
1309
- fig: plt.Figure | None = None,
1310
- axes: list[plt.Axes] | None = None,
1311
1097
  ) -> tuple[plt.Figure, list[plt.Axes]]:
1312
1098
  """
1313
- Create two pie charts side by side with Matplotlib, with consistent coloring across both charts.
1314
- Leverages the existing pie_with_matplotlib function.
1099
+ Create two pie charts side by side with Matplotlib.
1315
1100
 
1316
1101
  Args:
1317
- data_left: Series for the left pie chart.
1318
- data_right: Series for the right pie chart.
1319
- colors: Color specification, can be:
1320
- - A string with a colormap name (e.g., 'viridis', 'plasma')
1321
- - A list of color strings (e.g., ['#ff0000', '#00ff00'])
1322
- - A dictionary mapping category names to colors (e.g., {'Category1': '#ff0000'})
1102
+ data_left: Data for the left pie chart.
1103
+ data_right: Data for the right pie chart.
1104
+ colors: Color specification (colorscale name, list of colors, or dict mapping)
1323
1105
  title: The main title of the plot.
1324
1106
  subtitles: Tuple containing the subtitles for (left, right) charts.
1325
1107
  legend_title: The title for the legend.
1326
1108
  hole: Size of the hole in the center for creating donut charts (0.0 to 1.0).
1327
1109
  lower_percentage_group: Whether to group small segments (below percentage) into an "Other" category.
1328
1110
  figsize: The size of the figure (width, height) in inches.
1329
- fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created.
1330
- axes: A list of Matplotlib axes objects to plot on. If not provided, new axes will be created.
1331
1111
 
1332
1112
  Returns:
1333
- A tuple containing the Matplotlib figure and list of axes objects used for the plot.
1113
+ Tuple of (Figure, list of Axes)
1334
1114
  """
1335
- # Check for empty data
1336
- if data_left.empty and data_right.empty:
1337
- logger.error('Both datasets are empty. Returning empty figure.')
1338
- if fig is None:
1339
- fig, axes = plt.subplots(1, 2, figsize=figsize)
1340
- return fig, axes
1341
-
1342
- # Create figure and axes if not provided
1343
- if fig is None or axes is None:
1344
- fig, axes = plt.subplots(1, 2, figsize=figsize)
1345
-
1346
- # Process series to handle negative values and apply minimum percentage threshold
1347
- def preprocess_series(series: pd.Series):
1348
- """
1349
- Preprocess a series for pie chart display by handling negative values
1350
- and grouping the smallest parts together if they collectively represent
1351
- less than the specified percentage threshold.
1352
- """
1353
- # Handle negative values
1354
- if (series < 0).any():
1355
- logger.error('Negative values detected in data. Using absolute values for pie chart.')
1356
- series = series.abs()
1115
+ if colors is None:
1116
+ colors = CONFIG.Plotting.default_qualitative_colorscale
1357
1117
 
1358
- # Remove zeros
1359
- series = series[series > 0]
1118
+ # Preprocess data to Series
1119
+ left_series = preprocess_data_for_pie(data_left, lower_percentage_group)
1120
+ right_series = preprocess_data_for_pie(data_right, lower_percentage_group)
1360
1121
 
1361
- # Apply minimum percentage threshold if needed
1362
- if lower_percentage_group and not series.empty:
1363
- total = series.sum()
1364
- if total > 0:
1365
- # Sort series by value (ascending)
1366
- sorted_series = series.sort_values()
1122
+ # Extract labels and values
1123
+ left_labels = left_series.index.tolist()
1124
+ left_values = left_series.values.tolist()
1367
1125
 
1368
- # Calculate cumulative percentage contribution
1369
- cumulative_percent = (sorted_series.cumsum() / total) * 100
1126
+ right_labels = right_series.index.tolist()
1127
+ right_values = right_series.values.tolist()
1370
1128
 
1371
- # Find entries that collectively make up less than lower_percentage_group
1372
- to_group = cumulative_percent <= lower_percentage_group
1129
+ # Get all unique labels for consistent coloring
1130
+ all_labels = sorted(set(left_labels) | set(right_labels))
1373
1131
 
1374
- if to_group.sum() > 1:
1375
- # Create "Other" category for the smallest values that together are < threshold
1376
- other_sum = sorted_series[to_group].sum()
1132
+ # Create color map (process_colors always returns a dict)
1133
+ color_map = process_colors(colors, all_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale)
1377
1134
 
1378
- # Keep only values that aren't in the "Other" group
1379
- result_series = series[~series.index.isin(sorted_series[to_group].index)]
1135
+ # Create figure
1136
+ fig, axes = plt.subplots(1, 2, figsize=figsize)
1380
1137
 
1381
- # Add the "Other" category if it has a value
1382
- if other_sum > 0:
1383
- result_series['Other'] = other_sum
1138
+ def draw_pie(ax, labels, values, subtitle):
1139
+ """Draw a single pie chart."""
1140
+ if not labels:
1141
+ ax.set_title(subtitle)
1142
+ ax.axis('off')
1143
+ return
1384
1144
 
1385
- return result_series
1145
+ chart_colors = [color_map[label] for label in labels]
1386
1146
 
1387
- return series
1388
-
1389
- # Preprocess data
1390
- data_left_processed = preprocess_series(data_left)
1391
- data_right_processed = preprocess_series(data_right)
1392
-
1393
- # Convert Series to DataFrames for pie_with_matplotlib
1394
- df_left = pd.DataFrame(data_left_processed).T if not data_left_processed.empty else pd.DataFrame()
1395
- df_right = pd.DataFrame(data_right_processed).T if not data_right_processed.empty else pd.DataFrame()
1396
-
1397
- # Get unique set of all labels for consistent coloring
1398
- all_labels = sorted(set(data_left_processed.index) | set(data_right_processed.index))
1399
-
1400
- # Get consistent color mapping for both charts using our unified function
1401
- color_map = ColorProcessor(engine='matplotlib').process_colors(colors, all_labels, return_mapping=True)
1147
+ # Draw pie
1148
+ wedges, texts, autotexts = ax.pie(
1149
+ values,
1150
+ labels=labels,
1151
+ colors=chart_colors,
1152
+ autopct='%1.1f%%',
1153
+ startangle=90,
1154
+ wedgeprops=dict(width=1 - hole) if hole > 0 else None,
1155
+ )
1402
1156
 
1403
- # Configure colors for each DataFrame based on the consistent mapping
1404
- left_colors = [color_map[col] for col in df_left.columns] if not df_left.empty else []
1405
- right_colors = [color_map[col] for col in df_right.columns] if not df_right.empty else []
1157
+ # Style text
1158
+ for autotext in autotexts:
1159
+ autotext.set_fontsize(10)
1160
+ autotext.set_color('white')
1161
+ autotext.set_weight('bold')
1406
1162
 
1407
- # Create left pie chart
1408
- if not df_left.empty:
1409
- pie_with_matplotlib(data=df_left, colors=left_colors, title=subtitles[0], hole=hole, fig=fig, ax=axes[0])
1410
- else:
1411
- axes[0].set_title(subtitles[0])
1412
- axes[0].axis('off')
1163
+ ax.set_aspect('equal')
1164
+ ax.set_title(subtitle, fontsize=14, pad=20)
1413
1165
 
1414
- # Create right pie chart
1415
- if not df_right.empty:
1416
- pie_with_matplotlib(data=df_right, colors=right_colors, title=subtitles[1], hole=hole, fig=fig, ax=axes[1])
1417
- else:
1418
- axes[1].set_title(subtitles[1])
1419
- axes[1].axis('off')
1166
+ # Draw both pies
1167
+ draw_pie(axes[0], left_labels, left_values, subtitles[0])
1168
+ draw_pie(axes[1], right_labels, right_values, subtitles[1])
1420
1169
 
1421
1170
  # Add main title
1422
1171
  if title:
1423
1172
  fig.suptitle(title, fontsize=16, y=0.98)
1424
1173
 
1425
- # Adjust layout
1426
- fig.tight_layout()
1427
-
1428
- # Create a unified legend if both charts have data
1429
- if not df_left.empty and not df_right.empty:
1430
- # Remove individual legends
1431
- for ax in axes:
1432
- if ax.get_legend():
1433
- ax.get_legend().remove()
1434
-
1435
- # Create handles for the unified legend
1436
- handles = []
1437
- labels_for_legend = []
1438
-
1439
- for label in all_labels:
1440
- color = color_map[label]
1441
- patch = plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10, label=label)
1442
- handles.append(patch)
1443
- labels_for_legend.append(label)
1174
+ # Create unified legend
1175
+ if left_labels or right_labels:
1176
+ handles = [
1177
+ plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color_map[label], markersize=10)
1178
+ for label in all_labels
1179
+ ]
1444
1180
 
1445
- # Add unified legend
1446
1181
  fig.legend(
1447
1182
  handles=handles,
1448
- labels=labels_for_legend,
1183
+ labels=all_labels,
1449
1184
  title=legend_title,
1450
1185
  loc='lower center',
1451
- bbox_to_anchor=(0.5, 0),
1452
- ncol=min(len(all_labels), 5), # Limit columns to 5 for readability
1186
+ bbox_to_anchor=(0.5, -0.02),
1187
+ ncol=min(len(all_labels), 5),
1453
1188
  )
1454
1189
 
1455
- # Add padding at the bottom for the legend
1456
- fig.subplots_adjust(bottom=0.2)
1190
+ fig.subplots_adjust(bottom=0.15)
1191
+
1192
+ fig.tight_layout()
1457
1193
 
1458
1194
  return fig, axes
1459
1195
 
1460
1196
 
1461
1197
  def heatmap_with_plotly(
1462
1198
  data: xr.DataArray,
1463
- colors: ColorType = 'viridis',
1199
+ colors: ColorType | None = None,
1464
1200
  title: str = '',
1465
1201
  facet_by: str | list[str] | None = None,
1466
1202
  animate_by: str | None = None,
1467
- facet_cols: int = 3,
1203
+ facet_cols: int | None = None,
1468
1204
  reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
1469
1205
  | Literal['auto']
1470
1206
  | None = 'auto',
1471
1207
  fill: Literal['ffill', 'bfill'] | None = 'ffill',
1208
+ **imshow_kwargs: Any,
1472
1209
  ) -> go.Figure:
1473
1210
  """
1474
1211
  Plot a heatmap visualization using Plotly's imshow with faceting and animation support.
@@ -1486,8 +1223,8 @@ def heatmap_with_plotly(
1486
1223
  Args:
1487
1224
  data: An xarray DataArray containing the data to visualize. Should have at least
1488
1225
  2 dimensions, or a 'time' dimension that can be reshaped into 2D.
1489
- colors: Color specification (colormap name, list, or dict). Common options:
1490
- 'viridis', 'plasma', 'RdBu', 'portland'.
1226
+ colors: Color specification (colorscale name, list, or dict). Common options:
1227
+ 'turbo', 'plasma', 'RdBu', 'portland'.
1491
1228
  title: The main title of the heatmap.
1492
1229
  facet_by: Dimension to create facets for. Creates a subplot grid.
1493
1230
  Can be a single dimension name or list (only first dimension used).
@@ -1501,6 +1238,11 @@ def heatmap_with_plotly(
1501
1238
  - Tuple like ('D', 'h'): Explicit time reshaping (days vs hours)
1502
1239
  - None: Disable time reshaping (will error if only 1D time data)
1503
1240
  fill: Method to fill missing values when reshaping time: 'ffill' or 'bfill'. Default is 'ffill'.
1241
+ **imshow_kwargs: Additional keyword arguments to pass to plotly.express.imshow.
1242
+ Common options include:
1243
+ - aspect: 'auto', 'equal', or a number for aspect ratio
1244
+ - zmin, zmax: Minimum and maximum values for color scale
1245
+ - labels: Dict to customize axis labels
1504
1246
 
1505
1247
  Returns:
1506
1248
  A Plotly figure object containing the heatmap visualization.
@@ -1538,6 +1280,13 @@ def heatmap_with_plotly(
1538
1280
  fig = heatmap_with_plotly(data_array, facet_by='scenario', animate_by='period', reshape_time=('W', 'D'))
1539
1281
  ```
1540
1282
  """
1283
+ if colors is None:
1284
+ colors = CONFIG.Plotting.default_sequential_colorscale
1285
+
1286
+ # Apply CONFIG defaults if not explicitly set
1287
+ if facet_cols is None:
1288
+ facet_cols = CONFIG.Plotting.default_facet_cols
1289
+
1541
1290
  # Handle empty data
1542
1291
  if data.size == 0:
1543
1292
  return go.Figure()
@@ -1589,12 +1338,26 @@ def heatmap_with_plotly(
1589
1338
  heatmap_dims = [dim for dim in available_dims if dim not in facet_dims]
1590
1339
 
1591
1340
  if len(heatmap_dims) < 2:
1592
- # Need at least 2 dimensions for a heatmap
1593
- logger.error(
1594
- f'Heatmap requires at least 2 dimensions for rows and columns. '
1595
- f'After faceting/animation, only {len(heatmap_dims)} dimension(s) remain: {heatmap_dims}'
1596
- )
1597
- return go.Figure()
1341
+ # Handle single-dimension case by adding variable name as a dimension
1342
+ if len(heatmap_dims) == 1:
1343
+ # Get the variable name, or use a default
1344
+ var_name = data.name if data.name else 'value'
1345
+
1346
+ # Expand the DataArray by adding a new dimension with the variable name
1347
+ data = data.expand_dims({'variable': [var_name]})
1348
+
1349
+ # Update available dimensions
1350
+ available_dims = list(data.dims)
1351
+ heatmap_dims = [dim for dim in available_dims if dim not in facet_dims]
1352
+
1353
+ logger.debug(f'Only 1 dimension remaining for heatmap. Added variable dimension: {var_name}')
1354
+ else:
1355
+ # No dimensions at all - cannot create a heatmap
1356
+ logger.error(
1357
+ f'Heatmap requires at least 1 dimension. '
1358
+ f'After faceting/animation, {len(heatmap_dims)} dimension(s) remain: {heatmap_dims}'
1359
+ )
1360
+ return go.Figure()
1598
1361
 
1599
1362
  # Setup faceting parameters for Plotly Express
1600
1363
  # Note: px.imshow only supports facet_col, not facet_row
@@ -1617,7 +1380,7 @@ def heatmap_with_plotly(
1617
1380
  # Create the imshow plot - px.imshow can work directly with xarray DataArrays
1618
1381
  common_args = {
1619
1382
  'img': data,
1620
- 'color_continuous_scale': colors if isinstance(colors, str) else 'viridis',
1383
+ 'color_continuous_scale': colors,
1621
1384
  'title': title,
1622
1385
  }
1623
1386
 
@@ -1631,38 +1394,39 @@ def heatmap_with_plotly(
1631
1394
  if animate_by:
1632
1395
  common_args['animation_frame'] = animate_by
1633
1396
 
1397
+ # Merge in additional imshow kwargs
1398
+ common_args.update(imshow_kwargs)
1399
+
1634
1400
  try:
1635
1401
  fig = px.imshow(**common_args)
1636
1402
  except Exception as e:
1637
1403
  logger.error(f'Error creating imshow plot: {e}. Falling back to basic heatmap.')
1638
1404
  # Fallback: create a simple heatmap without faceting
1639
- fig = px.imshow(
1640
- data.values,
1641
- color_continuous_scale=colors if isinstance(colors, str) else 'viridis',
1642
- title=title,
1643
- )
1644
-
1645
- # Update layout with basic styling
1646
- fig.update_layout(
1647
- plot_bgcolor='rgba(0,0,0,0)',
1648
- paper_bgcolor='rgba(0,0,0,0)',
1649
- font=dict(size=12),
1650
- )
1405
+ fallback_args = {
1406
+ 'img': data.values,
1407
+ 'color_continuous_scale': colors,
1408
+ 'title': title,
1409
+ }
1410
+ fallback_args.update(imshow_kwargs)
1411
+ fig = px.imshow(**fallback_args)
1651
1412
 
1652
1413
  return fig
1653
1414
 
1654
1415
 
1655
1416
  def heatmap_with_matplotlib(
1656
1417
  data: xr.DataArray,
1657
- colors: ColorType = 'viridis',
1418
+ colors: ColorType | None = None,
1658
1419
  title: str = '',
1659
1420
  figsize: tuple[float, float] = (12, 6),
1660
- fig: plt.Figure | None = None,
1661
- ax: plt.Axes | None = None,
1662
1421
  reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
1663
1422
  | Literal['auto']
1664
1423
  | None = 'auto',
1665
1424
  fill: Literal['ffill', 'bfill'] | None = 'ffill',
1425
+ vmin: float | None = None,
1426
+ vmax: float | None = None,
1427
+ imshow_kwargs: dict[str, Any] | None = None,
1428
+ cbar_kwargs: dict[str, Any] | None = None,
1429
+ **kwargs: Any,
1666
1430
  ) -> tuple[plt.Figure, plt.Axes]:
1667
1431
  """
1668
1432
  Plot a heatmap visualization using Matplotlib's imshow.
@@ -1674,16 +1438,25 @@ def heatmap_with_matplotlib(
1674
1438
  data: An xarray DataArray containing the data to visualize. Should have at least
1675
1439
  2 dimensions. If more than 2 dimensions exist, additional dimensions will
1676
1440
  be reduced by taking the first slice.
1677
- colors: Color specification. Should be a colormap name (e.g., 'viridis', 'RdBu').
1441
+ colors: Color specification. Should be a colorscale name (e.g., 'turbo', 'RdBu').
1678
1442
  title: The title of the heatmap.
1679
1443
  figsize: The size of the figure (width, height) in inches.
1680
- fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created.
1681
- ax: A Matplotlib axes object to plot on. If not provided, a new axes will be created.
1682
1444
  reshape_time: Time reshaping configuration:
1683
1445
  - 'auto' (default): Automatically applies ('D', 'h') if only 'time' dimension
1684
1446
  - Tuple like ('D', 'h'): Explicit time reshaping (days vs hours)
1685
1447
  - None: Disable time reshaping
1686
1448
  fill: Method to fill missing values when reshaping time: 'ffill' or 'bfill'. Default is 'ffill'.
1449
+ vmin: Minimum value for color scale. If None, uses data minimum.
1450
+ vmax: Maximum value for color scale. If None, uses data maximum.
1451
+ imshow_kwargs: Optional dict of parameters to pass to ax.imshow().
1452
+ Use this to customize image properties (e.g., interpolation, aspect).
1453
+ cbar_kwargs: Optional dict of parameters to pass to plt.colorbar().
1454
+ Use this to customize colorbar properties (e.g., orientation, label).
1455
+ **kwargs: Additional keyword arguments passed to ax.imshow().
1456
+ Common options include:
1457
+ - interpolation: 'nearest', 'bilinear', 'bicubic', etc.
1458
+ - alpha: Transparency level (0-1)
1459
+ - extent: [left, right, bottom, top] for axis limits
1687
1460
 
1688
1461
  Returns:
1689
1462
  A tuple containing the Matplotlib figure and axes objects used for the plot.
@@ -1705,19 +1478,36 @@ def heatmap_with_matplotlib(
1705
1478
  fig, ax = heatmap_with_matplotlib(data_array, reshape_time=('D', 'h'))
1706
1479
  ```
1707
1480
  """
1481
+ if colors is None:
1482
+ colors = CONFIG.Plotting.default_sequential_colorscale
1483
+
1484
+ # Initialize kwargs if not provided
1485
+ if imshow_kwargs is None:
1486
+ imshow_kwargs = {}
1487
+ if cbar_kwargs is None:
1488
+ cbar_kwargs = {}
1489
+
1490
+ # Merge any additional kwargs into imshow_kwargs
1491
+ # This allows users to pass imshow options directly
1492
+ imshow_kwargs.update(kwargs)
1493
+
1708
1494
  # Handle empty data
1709
1495
  if data.size == 0:
1710
- if fig is None or ax is None:
1711
- fig, ax = plt.subplots(figsize=figsize)
1496
+ fig, ax = plt.subplots(figsize=figsize)
1712
1497
  return fig, ax
1713
1498
 
1714
1499
  # Apply time reshaping using the new unified function
1715
1500
  # Matplotlib doesn't support faceting/animation, so we pass None for those
1716
1501
  data = reshape_data_for_heatmap(data, reshape_time=reshape_time, facet_by=None, animate_by=None, fill=fill)
1717
1502
 
1718
- # Create figure and axes if not provided
1719
- if fig is None or ax is None:
1720
- fig, ax = plt.subplots(figsize=figsize)
1503
+ # Handle single-dimension case by adding variable name as a dimension
1504
+ if isinstance(data, xr.DataArray) and len(data.dims) == 1:
1505
+ var_name = data.name if data.name else 'value'
1506
+ data = data.expand_dims({'variable': [var_name]})
1507
+ logger.debug(f'Only 1 dimension in data. Added variable dimension: {var_name}')
1508
+
1509
+ # Create figure and axes
1510
+ fig, ax = plt.subplots(figsize=figsize)
1721
1511
 
1722
1512
  # Extract data values
1723
1513
  # If data has more than 2 dimensions, we need to reduce it
@@ -1742,15 +1532,19 @@ def heatmap_with_matplotlib(
1742
1532
  x_labels = 'x'
1743
1533
  y_labels = 'y'
1744
1534
 
1745
- # Process colormap
1746
- cmap = colors if isinstance(colors, str) else 'viridis'
1535
+ # Create the heatmap using imshow with user customizations
1536
+ imshow_defaults = {'cmap': colors, 'aspect': 'auto', 'origin': 'upper', 'vmin': vmin, 'vmax': vmax}
1537
+ imshow_defaults.update(imshow_kwargs) # User kwargs override defaults
1538
+ im = ax.imshow(values, **imshow_defaults)
1747
1539
 
1748
- # Create the heatmap using imshow
1749
- im = ax.imshow(values, cmap=cmap, aspect='auto', origin='upper')
1540
+ # Add colorbar with user customizations
1541
+ cbar_defaults = {'ax': ax, 'orientation': 'horizontal', 'pad': 0.1, 'aspect': 15, 'fraction': 0.05}
1542
+ cbar_defaults.update(cbar_kwargs) # User kwargs override defaults
1543
+ cbar = plt.colorbar(im, **cbar_defaults)
1750
1544
 
1751
- # Add colorbar
1752
- cbar = plt.colorbar(im, ax=ax, orientation='horizontal', pad=0.1, aspect=15, fraction=0.05)
1753
- cbar.set_label('Value')
1545
+ # Set colorbar label if not overridden by user
1546
+ if 'label' not in cbar_kwargs:
1547
+ cbar.set_label('Value')
1754
1548
 
1755
1549
  # Set labels and title
1756
1550
  ax.set_xlabel(str(x_labels).capitalize())
@@ -1768,8 +1562,9 @@ def export_figure(
1768
1562
  default_path: pathlib.Path,
1769
1563
  default_filetype: str | None = None,
1770
1564
  user_path: pathlib.Path | None = None,
1771
- show: bool = True,
1565
+ show: bool | None = None,
1772
1566
  save: bool = False,
1567
+ dpi: int | None = None,
1773
1568
  ) -> go.Figure | tuple[plt.Figure, plt.Axes]:
1774
1569
  """
1775
1570
  Export a figure to a file and or show it.
@@ -1779,13 +1574,21 @@ def export_figure(
1779
1574
  default_path: The default file path if no user filename is provided.
1780
1575
  default_filetype: The default filetype if the path doesnt end with a filetype.
1781
1576
  user_path: An optional user-specified file path.
1782
- show: Whether to display the figure (default: True).
1577
+ show: Whether to display the figure. If None, uses CONFIG.Plotting.default_show (default: None).
1783
1578
  save: Whether to save the figure (default: False).
1579
+ dpi: DPI (dots per inch) for saving Matplotlib figures. If None, uses CONFIG.Plotting.default_dpi.
1784
1580
 
1785
1581
  Raises:
1786
1582
  ValueError: If no default filetype is provided and the path doesn't specify a filetype.
1787
1583
  TypeError: If the figure type is not supported.
1788
1584
  """
1585
+ # Apply CONFIG defaults if not explicitly set
1586
+ if show is None:
1587
+ show = CONFIG.Plotting.default_show
1588
+
1589
+ if dpi is None:
1590
+ dpi = CONFIG.Plotting.default_dpi
1591
+
1789
1592
  filename = user_path or default_path
1790
1593
  filename = filename.with_name(filename.name.replace('|', '__'))
1791
1594
  if filename.suffix == '':
@@ -1800,25 +1603,17 @@ def export_figure(
1800
1603
  filename = filename.with_suffix('.html')
1801
1604
 
1802
1605
  try:
1803
- is_test_env = 'PYTEST_CURRENT_TEST' in os.environ
1804
-
1805
- if is_test_env:
1806
- # Test environment: never open browser, only save if requested
1807
- if save:
1808
- fig.write_html(str(filename))
1809
- # Ignore show flag in tests
1810
- else:
1811
- # Production environment: respect show and save flags
1812
- if save and show:
1813
- # Save and auto-open in browser
1814
- plotly.offline.plot(fig, filename=str(filename))
1815
- elif save and not show:
1816
- # Save without opening
1817
- fig.write_html(str(filename))
1818
- elif show and not save:
1819
- # Show interactively without saving
1820
- fig.show()
1821
- # If neither save nor show: do nothing
1606
+ # Respect show and save flags (tests should set CONFIG.Plotting.default_show=False)
1607
+ if save and show:
1608
+ # Save and auto-open in browser
1609
+ plotly.offline.plot(fig, filename=str(filename))
1610
+ elif save and not show:
1611
+ # Save without opening
1612
+ fig.write_html(str(filename))
1613
+ elif show and not save:
1614
+ # Show interactively without saving
1615
+ fig.show()
1616
+ # If neither save nor show: do nothing
1822
1617
  finally:
1823
1618
  # Cleanup to prevent socket warnings
1824
1619
  if hasattr(fig, '_renderer'):
@@ -1829,16 +1624,15 @@ def export_figure(
1829
1624
  elif isinstance(figure_like, tuple):
1830
1625
  fig, ax = figure_like
1831
1626
  if show:
1832
- # Only show if using interactive backend and not in test environment
1627
+ # Only show if using interactive backend (tests should set CONFIG.Plotting.default_show=False)
1833
1628
  backend = matplotlib.get_backend().lower()
1834
1629
  is_interactive = backend not in {'agg', 'pdf', 'ps', 'svg', 'template'}
1835
- is_test_env = 'PYTEST_CURRENT_TEST' in os.environ
1836
1630
 
1837
- if is_interactive and not is_test_env:
1631
+ if is_interactive:
1838
1632
  plt.show()
1839
1633
 
1840
1634
  if save:
1841
- fig.savefig(str(filename), dpi=300)
1635
+ fig.savefig(str(filename), dpi=dpi)
1842
1636
  plt.close(fig) # Close figure to free memory
1843
1637
 
1844
1638
  return fig, ax