flixopt 3.1.1__py3-none-any.whl → 3.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flixopt might be problematic. Click here for more details.
- flixopt/aggregation.py +13 -4
- flixopt/calculation.py +2 -3
- flixopt/color_processing.py +261 -0
- flixopt/config.py +59 -4
- flixopt/flow_system.py +5 -3
- flixopt/interface.py +2 -1
- flixopt/io.py +239 -22
- flixopt/plotting.py +583 -789
- flixopt/results.py +445 -56
- flixopt/structure.py +1 -3
- {flixopt-3.1.1.dist-info → flixopt-3.2.0.dist-info}/METADATA +2 -2
- flixopt-3.2.0.dist-info/RECORD +26 -0
- flixopt/utils.py +0 -86
- flixopt-3.1.1.dist-info/RECORD +0 -26
- {flixopt-3.1.1.dist-info → flixopt-3.2.0.dist-info}/WHEEL +0 -0
- {flixopt-3.1.1.dist-info → flixopt-3.2.0.dist-info}/licenses/LICENSE +0 -0
- {flixopt-3.1.1.dist-info → flixopt-3.2.0.dist-info}/top_level.txt +0 -0
flixopt/plotting.py
CHANGED
|
@@ -40,14 +40,16 @@ import plotly.express as px
|
|
|
40
40
|
import plotly.graph_objects as go
|
|
41
41
|
import plotly.offline
|
|
42
42
|
import xarray as xr
|
|
43
|
-
|
|
43
|
+
|
|
44
|
+
from .color_processing import process_colors
|
|
45
|
+
from .config import CONFIG
|
|
44
46
|
|
|
45
47
|
if TYPE_CHECKING:
|
|
46
48
|
import pyvis
|
|
47
49
|
|
|
48
50
|
logger = logging.getLogger('flixopt')
|
|
49
51
|
|
|
50
|
-
# Define the colors for the 'portland'
|
|
52
|
+
# Define the colors for the 'portland' colorscale in matplotlib
|
|
51
53
|
_portland_colors = [
|
|
52
54
|
[12 / 255, 51 / 255, 131 / 255], # Dark blue
|
|
53
55
|
[10 / 255, 136 / 255, 186 / 255], # Light blue
|
|
@@ -56,7 +58,7 @@ _portland_colors = [
|
|
|
56
58
|
[217 / 255, 30 / 255, 30 / 255], # Red
|
|
57
59
|
]
|
|
58
60
|
|
|
59
|
-
# Check if the
|
|
61
|
+
# Check if the colorscale already exists before registering it
|
|
60
62
|
if hasattr(plt, 'colormaps'): # Matplotlib >= 3.7
|
|
61
63
|
registry = plt.colormaps
|
|
62
64
|
if 'portland' not in registry:
|
|
@@ -71,9 +73,9 @@ ColorType = str | list[str] | dict[str, str]
|
|
|
71
73
|
|
|
72
74
|
Color specifications can take several forms to accommodate different use cases:
|
|
73
75
|
|
|
74
|
-
**Named
|
|
75
|
-
- Standard
|
|
76
|
-
- Energy-focused: 'portland' (custom flixopt
|
|
76
|
+
**Named colorscales** (str):
|
|
77
|
+
- Standard colorscales: 'turbo', 'plasma', 'cividis', 'tab10', 'Set1'
|
|
78
|
+
- Energy-focused: 'portland' (custom flixopt colorscale for energy systems)
|
|
77
79
|
- Backend-specific maps available in Plotly and Matplotlib
|
|
78
80
|
|
|
79
81
|
**Color Lists** (list[str]):
|
|
@@ -88,8 +90,8 @@ Color specifications can take several forms to accommodate different use cases:
|
|
|
88
90
|
|
|
89
91
|
Examples:
|
|
90
92
|
```python
|
|
91
|
-
# Named
|
|
92
|
-
colors = '
|
|
93
|
+
# Named colorscale
|
|
94
|
+
colors = 'turbo' # Automatic color generation
|
|
93
95
|
|
|
94
96
|
# Explicit color list
|
|
95
97
|
colors = ['red', 'blue', 'green', '#FFD700']
|
|
@@ -112,7 +114,7 @@ Color Format Support:
|
|
|
112
114
|
|
|
113
115
|
References:
|
|
114
116
|
- HTML Color Names: https://htmlcolorcodes.com/color-names/
|
|
115
|
-
- Matplotlib
|
|
117
|
+
- Matplotlib colorscales: https://matplotlib.org/stable/tutorials/colors/colorscales.html
|
|
116
118
|
- Plotly Built-in Colorscales: https://plotly.com/python/builtin-colorscales/
|
|
117
119
|
"""
|
|
118
120
|
|
|
@@ -120,242 +122,78 @@ PlottingEngine = Literal['plotly', 'matplotlib']
|
|
|
120
122
|
"""Identifier for the plotting engine to use."""
|
|
121
123
|
|
|
122
124
|
|
|
123
|
-
|
|
124
|
-
"""
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
**Fallback Handling**: Graceful degradation when requested colormaps are unavailable
|
|
136
|
-
**Energy System Colors**: Built-in palettes optimized for energy system visualization
|
|
137
|
-
|
|
138
|
-
Color Input Types:
|
|
139
|
-
- **Named Colormaps**: 'viridis', 'plasma', 'portland', 'tab10', etc.
|
|
140
|
-
- **Color Lists**: ['red', 'blue', 'green'] or ['#FF0000', '#0000FF', '#00FF00']
|
|
141
|
-
- **Label Dictionaries**: {'Generator': 'red', 'Storage': 'blue', 'Load': 'green'}
|
|
142
|
-
|
|
143
|
-
Examples:
|
|
144
|
-
Basic color processing:
|
|
145
|
-
|
|
146
|
-
```python
|
|
147
|
-
# Initialize for Plotly backend
|
|
148
|
-
processor = ColorProcessor(engine='plotly', default_colormap='viridis')
|
|
149
|
-
|
|
150
|
-
# Process different color specifications
|
|
151
|
-
colors = processor.process_colors('plasma', ['Gen1', 'Gen2', 'Storage'])
|
|
152
|
-
colors = processor.process_colors(['red', 'blue', 'green'], ['A', 'B', 'C'])
|
|
153
|
-
colors = processor.process_colors({'Wind': 'skyblue', 'Solar': 'gold'}, ['Wind', 'Solar', 'Gas'])
|
|
154
|
-
|
|
155
|
-
# Switch to Matplotlib
|
|
156
|
-
processor = ColorProcessor(engine='matplotlib')
|
|
157
|
-
mpl_colors = processor.process_colors('tab10', component_labels)
|
|
158
|
-
```
|
|
159
|
-
|
|
160
|
-
Energy system visualization:
|
|
161
|
-
|
|
162
|
-
```python
|
|
163
|
-
# Specialized energy system palette
|
|
164
|
-
energy_colors = {
|
|
165
|
-
'Natural_Gas': '#8B4513', # Brown
|
|
166
|
-
'Electricity': '#FFD700', # Gold
|
|
167
|
-
'Heat': '#FF4500', # Red-orange
|
|
168
|
-
'Cooling': '#87CEEB', # Sky blue
|
|
169
|
-
'Hydrogen': '#E6E6FA', # Lavender
|
|
170
|
-
'Battery': '#32CD32', # Lime green
|
|
171
|
-
}
|
|
172
|
-
|
|
173
|
-
processor = ColorProcessor('plotly')
|
|
174
|
-
flow_colors = processor.process_colors(energy_colors, flow_labels)
|
|
175
|
-
```
|
|
176
|
-
|
|
177
|
-
Args:
|
|
178
|
-
engine: Plotting backend ('plotly' or 'matplotlib'). Determines output color format.
|
|
179
|
-
default_colormap: Fallback colormap when requested palettes are unavailable.
|
|
180
|
-
Common options: 'viridis', 'plasma', 'tab10', 'portland'.
|
|
181
|
-
|
|
182
|
-
"""
|
|
183
|
-
|
|
184
|
-
def __init__(self, engine: PlottingEngine = 'plotly', default_colormap: str = 'viridis'):
|
|
185
|
-
"""Initialize the color processor with specified backend and defaults."""
|
|
186
|
-
if engine not in ['plotly', 'matplotlib']:
|
|
187
|
-
raise TypeError(f'engine must be "plotly" or "matplotlib", but is {engine}')
|
|
188
|
-
self.engine = engine
|
|
189
|
-
self.default_colormap = default_colormap
|
|
190
|
-
|
|
191
|
-
def _generate_colors_from_colormap(self, colormap_name: str, num_colors: int) -> list[Any]:
|
|
192
|
-
"""
|
|
193
|
-
Generate colors from a named colormap.
|
|
194
|
-
|
|
195
|
-
Args:
|
|
196
|
-
colormap_name: Name of the colormap
|
|
197
|
-
num_colors: Number of colors to generate
|
|
198
|
-
|
|
199
|
-
Returns:
|
|
200
|
-
list of colors in the format appropriate for the engine
|
|
201
|
-
"""
|
|
202
|
-
if self.engine == 'plotly':
|
|
203
|
-
try:
|
|
204
|
-
colorscale = px.colors.get_colorscale(colormap_name)
|
|
205
|
-
except PlotlyError as e:
|
|
206
|
-
logger.error(f"Colorscale '{colormap_name}' not found in Plotly. Using {self.default_colormap}: {e}")
|
|
207
|
-
colorscale = px.colors.get_colorscale(self.default_colormap)
|
|
208
|
-
|
|
209
|
-
# Generate evenly spaced points
|
|
210
|
-
color_points = [i / (num_colors - 1) for i in range(num_colors)] if num_colors > 1 else [0]
|
|
211
|
-
return px.colors.sample_colorscale(colorscale, color_points)
|
|
212
|
-
|
|
213
|
-
else: # matplotlib
|
|
214
|
-
try:
|
|
215
|
-
cmap = plt.get_cmap(colormap_name, num_colors)
|
|
216
|
-
except ValueError as e:
|
|
217
|
-
logger.error(f"Colormap '{colormap_name}' not found in Matplotlib. Using {self.default_colormap}: {e}")
|
|
218
|
-
cmap = plt.get_cmap(self.default_colormap, num_colors)
|
|
219
|
-
|
|
220
|
-
return [cmap(i) for i in range(num_colors)]
|
|
221
|
-
|
|
222
|
-
def _handle_color_list(self, colors: list[str], num_labels: int) -> list[str]:
|
|
223
|
-
"""
|
|
224
|
-
Handle a list of colors, cycling if necessary.
|
|
225
|
-
|
|
226
|
-
Args:
|
|
227
|
-
colors: list of color strings
|
|
228
|
-
num_labels: Number of labels that need colors
|
|
229
|
-
|
|
230
|
-
Returns:
|
|
231
|
-
list of colors matching the number of labels
|
|
232
|
-
"""
|
|
233
|
-
if len(colors) == 0:
|
|
234
|
-
logger.error(f'Empty color list provided. Using {self.default_colormap} instead.')
|
|
235
|
-
return self._generate_colors_from_colormap(self.default_colormap, num_labels)
|
|
236
|
-
|
|
237
|
-
if len(colors) < num_labels:
|
|
238
|
-
logger.warning(
|
|
239
|
-
f'Not enough colors provided ({len(colors)}) for all labels ({num_labels}). Colors will cycle.'
|
|
240
|
-
)
|
|
241
|
-
# Cycle through the colors
|
|
242
|
-
color_iter = itertools.cycle(colors)
|
|
243
|
-
return [next(color_iter) for _ in range(num_labels)]
|
|
244
|
-
else:
|
|
245
|
-
# Trim if necessary
|
|
246
|
-
if len(colors) > num_labels:
|
|
247
|
-
logger.warning(
|
|
248
|
-
f'More colors provided ({len(colors)}) than labels ({num_labels}). Extra colors will be ignored.'
|
|
249
|
-
)
|
|
250
|
-
return colors[:num_labels]
|
|
251
|
-
|
|
252
|
-
def _handle_color_dict(self, colors: dict[str, str], labels: list[str]) -> list[str]:
|
|
253
|
-
"""
|
|
254
|
-
Handle a dictionary mapping labels to colors.
|
|
255
|
-
|
|
256
|
-
Args:
|
|
257
|
-
colors: Dictionary mapping labels to colors
|
|
258
|
-
labels: list of labels that need colors
|
|
259
|
-
|
|
260
|
-
Returns:
|
|
261
|
-
list of colors in the same order as labels
|
|
262
|
-
"""
|
|
263
|
-
if len(colors) == 0:
|
|
264
|
-
logger.warning(f'Empty color dictionary provided. Using {self.default_colormap} instead.')
|
|
265
|
-
return self._generate_colors_from_colormap(self.default_colormap, len(labels))
|
|
266
|
-
|
|
267
|
-
# Find missing labels
|
|
268
|
-
missing_labels = sorted(set(labels) - set(colors.keys()))
|
|
269
|
-
if missing_labels:
|
|
270
|
-
logger.warning(
|
|
271
|
-
f'Some labels have no color specified: {missing_labels}. Using {self.default_colormap} for these.'
|
|
272
|
-
)
|
|
125
|
+
def _ensure_dataset(data: xr.Dataset | pd.DataFrame | pd.Series) -> xr.Dataset:
|
|
126
|
+
"""Convert DataFrame or Series to Dataset if needed."""
|
|
127
|
+
if isinstance(data, xr.Dataset):
|
|
128
|
+
return data
|
|
129
|
+
elif isinstance(data, pd.DataFrame):
|
|
130
|
+
# Convert DataFrame to Dataset
|
|
131
|
+
return data.to_xarray()
|
|
132
|
+
elif isinstance(data, pd.Series):
|
|
133
|
+
# Convert Series to DataFrame first, then to Dataset
|
|
134
|
+
return data.to_frame().to_xarray()
|
|
135
|
+
else:
|
|
136
|
+
raise TypeError(f'Data must be xr.Dataset, pd.DataFrame, or pd.Series, got {type(data).__name__}')
|
|
273
137
|
|
|
274
|
-
# Generate colors for missing labels
|
|
275
|
-
missing_colors = self._generate_colors_from_colormap(self.default_colormap, len(missing_labels))
|
|
276
138
|
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
labels: list of data labels that need colors assigned
|
|
299
|
-
return_mapping: If True, returns a dictionary mapping labels to colors;
|
|
300
|
-
if False, returns a list of colors in the same order as labels
|
|
301
|
-
|
|
302
|
-
Returns:
|
|
303
|
-
Either a list of colors or a dictionary mapping labels to colors
|
|
304
|
-
"""
|
|
305
|
-
if len(labels) == 0:
|
|
306
|
-
logger.error('No labels provided for color assignment.')
|
|
307
|
-
return {} if return_mapping else []
|
|
308
|
-
|
|
309
|
-
# Process based on type of colors input
|
|
310
|
-
if isinstance(colors, str):
|
|
311
|
-
color_list = self._generate_colors_from_colormap(colors, len(labels))
|
|
312
|
-
elif isinstance(colors, list):
|
|
313
|
-
color_list = self._handle_color_list(colors, len(labels))
|
|
314
|
-
elif isinstance(colors, dict):
|
|
315
|
-
color_list = self._handle_color_dict(colors, labels)
|
|
316
|
-
else:
|
|
317
|
-
logger.error(
|
|
318
|
-
f'Unsupported color specification type: {type(colors)}. Using {self.default_colormap} instead.'
|
|
139
|
+
def _validate_plotting_data(data: xr.Dataset, allow_empty: bool = False) -> None:
|
|
140
|
+
"""Validate dataset for plotting (checks for empty data, non-numeric types, etc.)."""
|
|
141
|
+
# Check for empty data
|
|
142
|
+
if not allow_empty and len(data.data_vars) == 0:
|
|
143
|
+
raise ValueError('Empty Dataset provided (no variables). Cannot create plot.')
|
|
144
|
+
|
|
145
|
+
# Check if dataset has any data (xarray uses nbytes for total size)
|
|
146
|
+
if all(data[var].size == 0 for var in data.data_vars) if len(data.data_vars) > 0 else True:
|
|
147
|
+
if not allow_empty and len(data.data_vars) > 0:
|
|
148
|
+
raise ValueError('Dataset has zero size. Cannot create plot.')
|
|
149
|
+
if len(data.data_vars) == 0:
|
|
150
|
+
return # Empty dataset, nothing to validate
|
|
151
|
+
return
|
|
152
|
+
|
|
153
|
+
# Check for non-numeric data types
|
|
154
|
+
for var in data.data_vars:
|
|
155
|
+
dtype = data[var].dtype
|
|
156
|
+
if not np.issubdtype(dtype, np.number):
|
|
157
|
+
raise TypeError(
|
|
158
|
+
f"Variable '{var}' has non-numeric dtype '{dtype}'. "
|
|
159
|
+
f'Plotting requires numeric data types (int, float, etc.).'
|
|
319
160
|
)
|
|
320
|
-
color_list = self._generate_colors_from_colormap(self.default_colormap, len(labels))
|
|
321
161
|
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
162
|
+
# Warn about NaN/Inf values
|
|
163
|
+
for var in data.data_vars:
|
|
164
|
+
if np.isnan(data[var].values).any():
|
|
165
|
+
logger.debug(f"Variable '{var}' contains NaN values which may affect visualization.")
|
|
166
|
+
if np.isinf(data[var].values).any():
|
|
167
|
+
logger.debug(f"Variable '{var}' contains Inf values which may affect visualization.")
|
|
327
168
|
|
|
328
169
|
|
|
329
170
|
def with_plotly(
|
|
330
|
-
data:
|
|
171
|
+
data: xr.Dataset | pd.DataFrame | pd.Series,
|
|
331
172
|
mode: Literal['stacked_bar', 'line', 'area', 'grouped_bar'] = 'stacked_bar',
|
|
332
|
-
colors: ColorType =
|
|
173
|
+
colors: ColorType | None = None,
|
|
333
174
|
title: str = '',
|
|
334
175
|
ylabel: str = '',
|
|
335
|
-
xlabel: str = '
|
|
336
|
-
fig: go.Figure | None = None,
|
|
176
|
+
xlabel: str = '',
|
|
337
177
|
facet_by: str | list[str] | None = None,
|
|
338
178
|
animate_by: str | None = None,
|
|
339
|
-
facet_cols: int =
|
|
179
|
+
facet_cols: int | None = None,
|
|
340
180
|
shared_yaxes: bool = True,
|
|
341
181
|
shared_xaxes: bool = True,
|
|
182
|
+
**px_kwargs: Any,
|
|
342
183
|
) -> go.Figure:
|
|
343
184
|
"""
|
|
344
185
|
Plot data with Plotly using facets (subplots) and/or animation for multidimensional data.
|
|
345
186
|
|
|
346
187
|
Uses Plotly Express for convenient faceting and animation with automatic styling.
|
|
347
|
-
For simple plots without faceting, can optionally add to an existing figure.
|
|
348
188
|
|
|
349
189
|
Args:
|
|
350
|
-
data:
|
|
190
|
+
data: An xarray Dataset, pandas DataFrame, or pandas Series to plot.
|
|
351
191
|
mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for lines,
|
|
352
192
|
'area' for stacked area charts, or 'grouped_bar' for grouped bar charts.
|
|
353
|
-
colors: Color specification (
|
|
193
|
+
colors: Color specification (colorscale, list, or dict mapping labels to colors).
|
|
354
194
|
title: The main title of the plot.
|
|
355
195
|
ylabel: The label for the y-axis.
|
|
356
196
|
xlabel: The label for the x-axis.
|
|
357
|
-
fig: A Plotly figure object to plot on (only for simple plots without faceting).
|
|
358
|
-
If not provided, a new figure will be created.
|
|
359
197
|
facet_by: Dimension(s) to create facets for. Creates a subplot grid.
|
|
360
198
|
Can be a single dimension name or list of dimensions (max 2 for facet_row and facet_col).
|
|
361
199
|
If the dimension doesn't exist in the data, it will be silently ignored.
|
|
@@ -364,93 +202,122 @@ def with_plotly(
|
|
|
364
202
|
facet_cols: Number of columns in the facet grid (used when facet_by is single dimension).
|
|
365
203
|
shared_yaxes: Whether subplots share y-axes.
|
|
366
204
|
shared_xaxes: Whether subplots share x-axes.
|
|
205
|
+
**px_kwargs: Additional keyword arguments passed to the underlying Plotly Express function
|
|
206
|
+
(px.bar, px.line, px.area). These override default arguments if provided.
|
|
207
|
+
Examples: range_x=[0, 100], range_y=[0, 50], category_orders={...}, line_shape='linear'
|
|
367
208
|
|
|
368
209
|
Returns:
|
|
369
|
-
A Plotly figure object containing the faceted/animated plot.
|
|
210
|
+
A Plotly figure object containing the faceted/animated plot. You can further customize
|
|
211
|
+
the returned figure using Plotly's methods (e.g., fig.update_traces(), fig.update_layout()).
|
|
370
212
|
|
|
371
213
|
Examples:
|
|
372
214
|
Simple plot:
|
|
373
215
|
|
|
374
216
|
```python
|
|
375
|
-
fig = with_plotly(
|
|
217
|
+
fig = with_plotly(dataset, mode='area', title='Energy Mix')
|
|
376
218
|
```
|
|
377
219
|
|
|
378
220
|
Facet by scenario:
|
|
379
221
|
|
|
380
222
|
```python
|
|
381
|
-
fig = with_plotly(
|
|
223
|
+
fig = with_plotly(dataset, facet_by='scenario', facet_cols=2)
|
|
382
224
|
```
|
|
383
225
|
|
|
384
226
|
Animate by period:
|
|
385
227
|
|
|
386
228
|
```python
|
|
387
|
-
fig = with_plotly(
|
|
229
|
+
fig = with_plotly(dataset, animate_by='period')
|
|
388
230
|
```
|
|
389
231
|
|
|
390
232
|
Facet and animate:
|
|
391
233
|
|
|
392
234
|
```python
|
|
393
|
-
fig = with_plotly(
|
|
235
|
+
fig = with_plotly(dataset, facet_by='scenario', animate_by='period')
|
|
236
|
+
```
|
|
237
|
+
|
|
238
|
+
Customize with Plotly Express kwargs:
|
|
239
|
+
|
|
240
|
+
```python
|
|
241
|
+
fig = with_plotly(dataset, range_y=[0, 100], line_shape='linear')
|
|
242
|
+
```
|
|
243
|
+
|
|
244
|
+
Further customize the returned figure:
|
|
245
|
+
|
|
246
|
+
```python
|
|
247
|
+
fig = with_plotly(dataset, mode='line')
|
|
248
|
+
fig.update_traces(line={'width': 5, 'dash': 'dot'})
|
|
249
|
+
fig.update_layout(template='plotly_dark', width=1200, height=600)
|
|
394
250
|
```
|
|
395
251
|
"""
|
|
252
|
+
if colors is None:
|
|
253
|
+
colors = CONFIG.Plotting.default_qualitative_colorscale
|
|
254
|
+
|
|
396
255
|
if mode not in ('stacked_bar', 'line', 'area', 'grouped_bar'):
|
|
397
256
|
raise ValueError(f"'mode' must be one of {{'stacked_bar','line','area', 'grouped_bar'}}, got {mode!r}")
|
|
398
257
|
|
|
258
|
+
# Apply CONFIG defaults if not explicitly set
|
|
259
|
+
if facet_cols is None:
|
|
260
|
+
facet_cols = CONFIG.Plotting.default_facet_cols
|
|
261
|
+
|
|
262
|
+
# Ensure data is a Dataset and validate it
|
|
263
|
+
data = _ensure_dataset(data)
|
|
264
|
+
_validate_plotting_data(data, allow_empty=True)
|
|
265
|
+
|
|
399
266
|
# Handle empty data
|
|
400
|
-
if
|
|
401
|
-
|
|
402
|
-
elif isinstance(data, xr.DataArray) and data.size == 0:
|
|
403
|
-
return go.Figure()
|
|
404
|
-
elif isinstance(data, xr.Dataset) and len(data.data_vars) == 0:
|
|
267
|
+
if len(data.data_vars) == 0:
|
|
268
|
+
logger.error('with_plotly() got an empty Dataset.')
|
|
405
269
|
return go.Figure()
|
|
406
270
|
|
|
407
|
-
#
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
#
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
raise ValueError(
|
|
433
|
-
f'Expected exactly one non-dimension column for unnamed DataArray, '
|
|
434
|
-
f'but found {len(non_dim_cols)}: {non_dim_cols}'
|
|
271
|
+
# Handle all-scalar datasets (where all variables have no dimensions)
|
|
272
|
+
# This occurs when all variables are scalar values with dims=()
|
|
273
|
+
if all(len(data[var].dims) == 0 for var in data.data_vars):
|
|
274
|
+
# Create a simple DataFrame with variable names as x-axis
|
|
275
|
+
variables = list(data.data_vars.keys())
|
|
276
|
+
values = [float(data[var].values) for var in data.data_vars]
|
|
277
|
+
|
|
278
|
+
# Resolve colors
|
|
279
|
+
color_discrete_map = process_colors(
|
|
280
|
+
colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale
|
|
281
|
+
)
|
|
282
|
+
marker_colors = [color_discrete_map.get(var, '#636EFA') for var in variables]
|
|
283
|
+
|
|
284
|
+
# Create simple plot based on mode using go (not px) for better color control
|
|
285
|
+
if mode in ('stacked_bar', 'grouped_bar'):
|
|
286
|
+
fig = go.Figure(data=[go.Bar(x=variables, y=values, marker_color=marker_colors)])
|
|
287
|
+
elif mode == 'line':
|
|
288
|
+
fig = go.Figure(
|
|
289
|
+
data=[
|
|
290
|
+
go.Scatter(
|
|
291
|
+
x=variables,
|
|
292
|
+
y=values,
|
|
293
|
+
mode='lines+markers',
|
|
294
|
+
marker=dict(color=marker_colors, size=8),
|
|
295
|
+
line=dict(color='lightgray'),
|
|
435
296
|
)
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
297
|
+
]
|
|
298
|
+
)
|
|
299
|
+
elif mode == 'area':
|
|
300
|
+
fig = go.Figure(
|
|
301
|
+
data=[
|
|
302
|
+
go.Scatter(
|
|
303
|
+
x=variables,
|
|
304
|
+
y=values,
|
|
305
|
+
fill='tozeroy',
|
|
306
|
+
marker=dict(color=marker_colors, size=8),
|
|
307
|
+
line=dict(color='lightgray'),
|
|
308
|
+
)
|
|
309
|
+
]
|
|
310
|
+
)
|
|
311
|
+
else:
|
|
312
|
+
raise ValueError('"mode" must be one of "stacked_bar", "grouped_bar", "line", "area"')
|
|
313
|
+
|
|
314
|
+
fig.update_layout(title=title, xaxis_title=xlabel, yaxis_title=ylabel, showlegend=False)
|
|
315
|
+
return fig
|
|
316
|
+
|
|
317
|
+
# Convert Dataset to long-form DataFrame for Plotly Express
|
|
318
|
+
# Structure: time, variable, value, scenario, period, ... (all dims as columns)
|
|
319
|
+
dim_names = list(data.dims)
|
|
320
|
+
df_long = data.to_dataframe().reset_index().melt(id_vars=dim_names, var_name='variable', value_name='value')
|
|
454
321
|
|
|
455
322
|
# Validate facet_by and animate_by dimensions exist in the data
|
|
456
323
|
available_dims = [col for col in df_long.columns if col not in ['variable', 'value']]
|
|
@@ -502,13 +369,36 @@ def with_plotly(
|
|
|
502
369
|
|
|
503
370
|
# Process colors
|
|
504
371
|
all_vars = df_long['variable'].unique().tolist()
|
|
505
|
-
|
|
506
|
-
|
|
372
|
+
color_discrete_map = process_colors(
|
|
373
|
+
colors, all_vars, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale
|
|
374
|
+
)
|
|
375
|
+
|
|
376
|
+
# Determine which dimension to use for x-axis
|
|
377
|
+
# Collect dimensions used for faceting and animation
|
|
378
|
+
used_dims = set()
|
|
379
|
+
if facet_row:
|
|
380
|
+
used_dims.add(facet_row)
|
|
381
|
+
if facet_col:
|
|
382
|
+
used_dims.add(facet_col)
|
|
383
|
+
if animate_by:
|
|
384
|
+
used_dims.add(animate_by)
|
|
385
|
+
|
|
386
|
+
# Find available dimensions for x-axis (not used for faceting/animation)
|
|
387
|
+
x_candidates = [d for d in available_dims if d not in used_dims]
|
|
388
|
+
|
|
389
|
+
# Use 'time' if available, otherwise use the first available dimension
|
|
390
|
+
if 'time' in x_candidates:
|
|
391
|
+
x_dim = 'time'
|
|
392
|
+
elif len(x_candidates) > 0:
|
|
393
|
+
x_dim = x_candidates[0]
|
|
394
|
+
else:
|
|
395
|
+
# Fallback: use the first dimension (shouldn't happen in normal cases)
|
|
396
|
+
x_dim = available_dims[0] if available_dims else 'time'
|
|
507
397
|
|
|
508
398
|
# Create plot using Plotly Express based on mode
|
|
509
399
|
common_args = {
|
|
510
400
|
'data_frame': df_long,
|
|
511
|
-
'x':
|
|
401
|
+
'x': x_dim,
|
|
512
402
|
'y': 'value',
|
|
513
403
|
'color': 'variable',
|
|
514
404
|
'facet_row': facet_row,
|
|
@@ -516,13 +406,22 @@ def with_plotly(
|
|
|
516
406
|
'animation_frame': animate_by,
|
|
517
407
|
'color_discrete_map': color_discrete_map,
|
|
518
408
|
'title': title,
|
|
519
|
-
'labels': {'value': ylabel,
|
|
409
|
+
'labels': {'value': ylabel, x_dim: xlabel, 'variable': ''},
|
|
520
410
|
}
|
|
521
411
|
|
|
522
412
|
# Add facet_col_wrap for single facet dimension
|
|
523
413
|
if facet_col and not facet_row:
|
|
524
414
|
common_args['facet_col_wrap'] = facet_cols
|
|
525
415
|
|
|
416
|
+
# Add mode-specific defaults (before px_kwargs so they can be overridden)
|
|
417
|
+
if mode in ('line', 'area'):
|
|
418
|
+
common_args['line_shape'] = 'hv' # Stepped lines by default
|
|
419
|
+
|
|
420
|
+
# Allow callers to pass any px.* keyword args (e.g., category_orders, range_x/y, line_shape)
|
|
421
|
+
# These will override the defaults set above
|
|
422
|
+
if px_kwargs:
|
|
423
|
+
common_args.update(px_kwargs)
|
|
424
|
+
|
|
526
425
|
if mode == 'stacked_bar':
|
|
527
426
|
fig = px.bar(**common_args)
|
|
528
427
|
fig.update_traces(marker_line_width=0)
|
|
@@ -531,10 +430,10 @@ def with_plotly(
|
|
|
531
430
|
fig = px.bar(**common_args)
|
|
532
431
|
fig.update_layout(barmode='group', bargap=0.2, bargroupgap=0)
|
|
533
432
|
elif mode == 'line':
|
|
534
|
-
fig = px.line(**common_args
|
|
433
|
+
fig = px.line(**common_args)
|
|
535
434
|
elif mode == 'area':
|
|
536
435
|
# Use Plotly Express to create the area plot (preserves animation, legends, faceting)
|
|
537
|
-
fig = px.area(**common_args
|
|
436
|
+
fig = px.area(**common_args)
|
|
538
437
|
|
|
539
438
|
# Classify each variable based on its values
|
|
540
439
|
variable_classification = {}
|
|
@@ -577,13 +476,6 @@ def with_plotly(
|
|
|
577
476
|
if hasattr(trace, 'fill'):
|
|
578
477
|
trace.fill = None
|
|
579
478
|
|
|
580
|
-
# Update layout with basic styling (Plotly Express handles sizing automatically)
|
|
581
|
-
fig.update_layout(
|
|
582
|
-
plot_bgcolor='rgba(0,0,0,0)',
|
|
583
|
-
paper_bgcolor='rgba(0,0,0,0)',
|
|
584
|
-
font=dict(size=12),
|
|
585
|
-
)
|
|
586
|
-
|
|
587
479
|
# Update axes to share if requested (Plotly Express already handles this, but we can customize)
|
|
588
480
|
if not shared_yaxes:
|
|
589
481
|
fig.update_yaxes(matches=None)
|
|
@@ -594,33 +486,32 @@ def with_plotly(
|
|
|
594
486
|
|
|
595
487
|
|
|
596
488
|
def with_matplotlib(
|
|
597
|
-
data: pd.DataFrame,
|
|
489
|
+
data: xr.Dataset | pd.DataFrame | pd.Series,
|
|
598
490
|
mode: Literal['stacked_bar', 'line'] = 'stacked_bar',
|
|
599
|
-
colors: ColorType =
|
|
491
|
+
colors: ColorType | None = None,
|
|
600
492
|
title: str = '',
|
|
601
493
|
ylabel: str = '',
|
|
602
494
|
xlabel: str = 'Time in h',
|
|
603
495
|
figsize: tuple[int, int] = (12, 6),
|
|
604
|
-
|
|
605
|
-
ax: plt.Axes | None = None,
|
|
496
|
+
plot_kwargs: dict[str, Any] | None = None,
|
|
606
497
|
) -> tuple[plt.Figure, plt.Axes]:
|
|
607
498
|
"""
|
|
608
|
-
Plot
|
|
499
|
+
Plot data with Matplotlib using stacked bars or stepped lines.
|
|
609
500
|
|
|
610
501
|
Args:
|
|
611
|
-
data:
|
|
612
|
-
and each column represents a separate data series.
|
|
502
|
+
data: An xarray Dataset, pandas DataFrame, or pandas Series to plot. After conversion to DataFrame,
|
|
503
|
+
the index represents time and each column represents a separate data series (variables).
|
|
613
504
|
mode: Plotting mode. Use 'stacked_bar' for stacked bar charts or 'line' for stepped lines.
|
|
614
|
-
colors: Color specification
|
|
615
|
-
- A
|
|
505
|
+
colors: Color specification. Can be:
|
|
506
|
+
- A colorscale name (e.g., 'turbo', 'plasma')
|
|
616
507
|
- A list of color strings (e.g., ['#ff0000', '#00ff00'])
|
|
617
|
-
- A
|
|
508
|
+
- A dict mapping column names to colors (e.g., {'Column1': '#ff0000'})
|
|
618
509
|
title: The title of the plot.
|
|
619
510
|
ylabel: The ylabel of the plot.
|
|
620
511
|
xlabel: The xlabel of the plot.
|
|
621
|
-
figsize: Specify the size of the figure
|
|
622
|
-
|
|
623
|
-
|
|
512
|
+
figsize: Specify the size of the figure (width, height) in inches.
|
|
513
|
+
plot_kwargs: Optional dict of parameters to pass to ax.bar() or ax.step() plotting calls.
|
|
514
|
+
Use this to customize plot properties (e.g., linewidth, alpha, edgecolor).
|
|
624
515
|
|
|
625
516
|
Returns:
|
|
626
517
|
A tuple containing the Matplotlib figure and axes objects used for the plot.
|
|
@@ -630,48 +521,121 @@ def with_matplotlib(
|
|
|
630
521
|
Negative values are stacked separately without extra labels in the legend.
|
|
631
522
|
- If `mode` is 'line', stepped lines are drawn for each data series.
|
|
632
523
|
"""
|
|
524
|
+
if colors is None:
|
|
525
|
+
colors = CONFIG.Plotting.default_qualitative_colorscale
|
|
526
|
+
|
|
633
527
|
if mode not in ('stacked_bar', 'line'):
|
|
634
528
|
raise ValueError(f"'mode' must be one of {{'stacked_bar','line'}} for matplotlib, got {mode!r}")
|
|
635
529
|
|
|
636
|
-
|
|
637
|
-
|
|
530
|
+
# Ensure data is a Dataset and validate it
|
|
531
|
+
data = _ensure_dataset(data)
|
|
532
|
+
_validate_plotting_data(data, allow_empty=True)
|
|
533
|
+
|
|
534
|
+
# Create new figure and axes
|
|
535
|
+
fig, ax = plt.subplots(figsize=figsize)
|
|
536
|
+
|
|
537
|
+
# Initialize plot_kwargs if not provided
|
|
538
|
+
if plot_kwargs is None:
|
|
539
|
+
plot_kwargs = {}
|
|
540
|
+
|
|
541
|
+
# Handle all-scalar datasets (where all variables have no dimensions)
|
|
542
|
+
# This occurs when all variables are scalar values with dims=()
|
|
543
|
+
if all(len(data[var].dims) == 0 for var in data.data_vars):
|
|
544
|
+
# Create simple bar/line plot with variable names as x-axis
|
|
545
|
+
variables = list(data.data_vars.keys())
|
|
546
|
+
values = [float(data[var].values) for var in data.data_vars]
|
|
547
|
+
|
|
548
|
+
# Resolve colors
|
|
549
|
+
color_discrete_map = process_colors(
|
|
550
|
+
colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale
|
|
551
|
+
)
|
|
552
|
+
colors_list = [color_discrete_map.get(var, '#808080') for var in variables]
|
|
553
|
+
|
|
554
|
+
# Create plot based on mode
|
|
555
|
+
if mode == 'stacked_bar':
|
|
556
|
+
ax.bar(variables, values, color=colors_list, **plot_kwargs)
|
|
557
|
+
elif mode == 'line':
|
|
558
|
+
ax.plot(
|
|
559
|
+
variables,
|
|
560
|
+
values,
|
|
561
|
+
marker='o',
|
|
562
|
+
color=colors_list[0] if len(set(colors_list)) == 1 else None,
|
|
563
|
+
**plot_kwargs,
|
|
564
|
+
)
|
|
565
|
+
# If different colors, plot each point separately
|
|
566
|
+
if len(set(colors_list)) > 1:
|
|
567
|
+
ax.clear()
|
|
568
|
+
for i, (var, val) in enumerate(zip(variables, values, strict=False)):
|
|
569
|
+
ax.plot([i], [val], marker='o', color=colors_list[i], label=var, **plot_kwargs)
|
|
570
|
+
ax.set_xticks(range(len(variables)))
|
|
571
|
+
ax.set_xticklabels(variables)
|
|
572
|
+
|
|
573
|
+
ax.set_xlabel(xlabel, ha='center')
|
|
574
|
+
ax.set_ylabel(ylabel, va='center')
|
|
575
|
+
ax.set_title(title)
|
|
576
|
+
ax.grid(color='lightgrey', linestyle='-', linewidth=0.5, axis='y')
|
|
577
|
+
fig.tight_layout()
|
|
578
|
+
|
|
579
|
+
return fig, ax
|
|
580
|
+
|
|
581
|
+
# Resolve colors first (includes validation)
|
|
582
|
+
color_discrete_map = process_colors(
|
|
583
|
+
colors, list(data.data_vars), default_colorscale=CONFIG.Plotting.default_qualitative_colorscale
|
|
584
|
+
)
|
|
585
|
+
|
|
586
|
+
# Convert Dataset to DataFrame for matplotlib plotting (naturally wide-form)
|
|
587
|
+
df = data.to_dataframe()
|
|
638
588
|
|
|
639
|
-
|
|
589
|
+
# Get colors in column order
|
|
590
|
+
processed_colors = [color_discrete_map.get(str(col), '#808080') for col in df.columns]
|
|
640
591
|
|
|
641
592
|
if mode == 'stacked_bar':
|
|
642
|
-
cumulative_positive = np.zeros(len(
|
|
643
|
-
cumulative_negative = np.zeros(len(
|
|
644
|
-
|
|
593
|
+
cumulative_positive = np.zeros(len(df))
|
|
594
|
+
cumulative_negative = np.zeros(len(df))
|
|
595
|
+
|
|
596
|
+
# Robust bar width: handle datetime-like, numeric, and single-point indexes
|
|
597
|
+
if len(df.index) > 1:
|
|
598
|
+
delta = pd.Index(df.index).to_series().diff().dropna().min()
|
|
599
|
+
if hasattr(delta, 'total_seconds'): # datetime-like
|
|
600
|
+
width = delta.total_seconds() / 86400.0 # Matplotlib date units = days
|
|
601
|
+
else:
|
|
602
|
+
width = float(delta)
|
|
603
|
+
else:
|
|
604
|
+
width = 0.8 # reasonable default for a single bar
|
|
645
605
|
|
|
646
|
-
for i, column in enumerate(
|
|
647
|
-
|
|
648
|
-
|
|
606
|
+
for i, column in enumerate(df.columns):
|
|
607
|
+
# Fill NaNs to avoid breaking stacking math
|
|
608
|
+
series = df[column].fillna(0)
|
|
609
|
+
positive_values = np.clip(series, 0, None) # Keep only positive values
|
|
610
|
+
negative_values = np.clip(series, None, 0) # Keep only negative values
|
|
649
611
|
# Plot positive bars
|
|
650
612
|
ax.bar(
|
|
651
|
-
|
|
613
|
+
df.index,
|
|
652
614
|
positive_values,
|
|
653
615
|
bottom=cumulative_positive,
|
|
654
616
|
color=processed_colors[i],
|
|
655
617
|
label=column,
|
|
656
618
|
width=width,
|
|
657
619
|
align='center',
|
|
620
|
+
**plot_kwargs,
|
|
658
621
|
)
|
|
659
622
|
cumulative_positive += positive_values.values
|
|
660
623
|
# Plot negative bars
|
|
661
624
|
ax.bar(
|
|
662
|
-
|
|
625
|
+
df.index,
|
|
663
626
|
negative_values,
|
|
664
627
|
bottom=cumulative_negative,
|
|
665
628
|
color=processed_colors[i],
|
|
666
629
|
label='', # No label for negative bars
|
|
667
630
|
width=width,
|
|
668
631
|
align='center',
|
|
632
|
+
**plot_kwargs,
|
|
669
633
|
)
|
|
670
634
|
cumulative_negative += negative_values.values
|
|
671
635
|
|
|
672
636
|
elif mode == 'line':
|
|
673
|
-
for i, column in enumerate(
|
|
674
|
-
ax.step(
|
|
637
|
+
for i, column in enumerate(df.columns):
|
|
638
|
+
ax.step(df.index, df[column], where='post', color=processed_colors[i], label=column, **plot_kwargs)
|
|
675
639
|
|
|
676
640
|
# Aesthetics
|
|
677
641
|
ax.set_xlabel(xlabel, ha='center')
|
|
@@ -944,228 +908,104 @@ def plot_network(
|
|
|
944
908
|
)
|
|
945
909
|
|
|
946
910
|
|
|
947
|
-
def
|
|
948
|
-
data: pd.DataFrame,
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
legend_title: str = '',
|
|
952
|
-
hole: float = 0.0,
|
|
953
|
-
fig: go.Figure | None = None,
|
|
954
|
-
) -> go.Figure:
|
|
955
|
-
"""
|
|
956
|
-
Create a pie chart with Plotly to visualize the proportion of values in a DataFrame.
|
|
957
|
-
|
|
958
|
-
Args:
|
|
959
|
-
data: A DataFrame containing the data to plot. If multiple rows exist,
|
|
960
|
-
they will be summed unless a specific index value is passed.
|
|
961
|
-
colors: Color specification, can be:
|
|
962
|
-
- A string with a colorscale name (e.g., 'viridis', 'plasma')
|
|
963
|
-
- A list of color strings (e.g., ['#ff0000', '#00ff00'])
|
|
964
|
-
- A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'})
|
|
965
|
-
title: The title of the plot.
|
|
966
|
-
legend_title: The title for the legend.
|
|
967
|
-
hole: Size of the hole in the center for creating a donut chart (0.0 to 1.0).
|
|
968
|
-
fig: A Plotly figure object to plot on. If not provided, a new figure will be created.
|
|
969
|
-
|
|
970
|
-
Returns:
|
|
971
|
-
A Plotly figure object containing the generated pie chart.
|
|
972
|
-
|
|
973
|
-
Notes:
|
|
974
|
-
- Negative values are not appropriate for pie charts and will be converted to absolute values with a warning.
|
|
975
|
-
- If the data contains very small values (less than 1% of the total), they can be grouped into an "Other" category
|
|
976
|
-
for better readability.
|
|
977
|
-
- By default, the sum of all columns is used for the pie chart. For time series data, consider preprocessing.
|
|
978
|
-
|
|
911
|
+
def preprocess_data_for_pie(
|
|
912
|
+
data: xr.Dataset | pd.DataFrame | pd.Series,
|
|
913
|
+
lower_percentage_threshold: float = 5.0,
|
|
914
|
+
) -> pd.Series:
|
|
979
915
|
"""
|
|
980
|
-
|
|
981
|
-
logger.error('Empty DataFrame provided for pie chart. Returning empty figure.')
|
|
982
|
-
return go.Figure()
|
|
983
|
-
|
|
984
|
-
# Create a copy to avoid modifying the original DataFrame
|
|
985
|
-
data_copy = data.copy()
|
|
986
|
-
|
|
987
|
-
# Check if any negative values and warn
|
|
988
|
-
if (data_copy < 0).any().any():
|
|
989
|
-
logger.error('Negative values detected in data. Using absolute values for pie chart.')
|
|
990
|
-
data_copy = data_copy.abs()
|
|
991
|
-
|
|
992
|
-
# If data has multiple rows, sum them to get total for each column
|
|
993
|
-
if len(data_copy) > 1:
|
|
994
|
-
data_sum = data_copy.sum()
|
|
995
|
-
else:
|
|
996
|
-
data_sum = data_copy.iloc[0]
|
|
916
|
+
Preprocess data for pie chart display.
|
|
997
917
|
|
|
998
|
-
|
|
999
|
-
|
|
1000
|
-
values = data_sum.values.tolist()
|
|
1001
|
-
|
|
1002
|
-
# Apply color mapping using the unified color processor
|
|
1003
|
-
processed_colors = ColorProcessor(engine='plotly').process_colors(colors, labels)
|
|
1004
|
-
|
|
1005
|
-
# Create figure if not provided
|
|
1006
|
-
fig = fig if fig is not None else go.Figure()
|
|
1007
|
-
|
|
1008
|
-
# Add pie trace
|
|
1009
|
-
fig.add_trace(
|
|
1010
|
-
go.Pie(
|
|
1011
|
-
labels=labels,
|
|
1012
|
-
values=values,
|
|
1013
|
-
hole=hole,
|
|
1014
|
-
marker=dict(colors=processed_colors),
|
|
1015
|
-
textinfo='percent+label+value',
|
|
1016
|
-
textposition='inside',
|
|
1017
|
-
insidetextorientation='radial',
|
|
1018
|
-
)
|
|
1019
|
-
)
|
|
1020
|
-
|
|
1021
|
-
# Update layout for better aesthetics
|
|
1022
|
-
fig.update_layout(
|
|
1023
|
-
title=title,
|
|
1024
|
-
legend_title=legend_title,
|
|
1025
|
-
plot_bgcolor='rgba(0,0,0,0)', # Transparent background
|
|
1026
|
-
paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background
|
|
1027
|
-
font=dict(size=14), # Increase font size for better readability
|
|
1028
|
-
)
|
|
1029
|
-
|
|
1030
|
-
return fig
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
def pie_with_matplotlib(
|
|
1034
|
-
data: pd.DataFrame,
|
|
1035
|
-
colors: ColorType = 'viridis',
|
|
1036
|
-
title: str = '',
|
|
1037
|
-
legend_title: str = 'Categories',
|
|
1038
|
-
hole: float = 0.0,
|
|
1039
|
-
figsize: tuple[int, int] = (10, 8),
|
|
1040
|
-
fig: plt.Figure | None = None,
|
|
1041
|
-
ax: plt.Axes | None = None,
|
|
1042
|
-
) -> tuple[plt.Figure, plt.Axes]:
|
|
1043
|
-
"""
|
|
1044
|
-
Create a pie chart with Matplotlib to visualize the proportion of values in a DataFrame.
|
|
918
|
+
Groups items that are individually below the threshold percentage into an "Other" category.
|
|
919
|
+
Converts various input types to a pandas Series for uniform handling.
|
|
1045
920
|
|
|
1046
921
|
Args:
|
|
1047
|
-
data:
|
|
1048
|
-
|
|
1049
|
-
colors: Color specification, can be:
|
|
1050
|
-
- A string with a colormap name (e.g., 'viridis', 'plasma')
|
|
1051
|
-
- A list of color strings (e.g., ['#ff0000', '#00ff00'])
|
|
1052
|
-
- A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'})
|
|
1053
|
-
title: The title of the plot.
|
|
1054
|
-
legend_title: The title for the legend.
|
|
1055
|
-
hole: Size of the hole in the center for creating a donut chart (0.0 to 1.0).
|
|
1056
|
-
figsize: The size of the figure (width, height) in inches.
|
|
1057
|
-
fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created.
|
|
1058
|
-
ax: A Matplotlib axes object to plot on. If not provided, a new axes will be created.
|
|
922
|
+
data: Input data (xarray Dataset, DataFrame, or Series)
|
|
923
|
+
lower_percentage_threshold: Percentage threshold - items below this are grouped into "Other"
|
|
1059
924
|
|
|
1060
925
|
Returns:
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
Notes:
|
|
1064
|
-
- Negative values are not appropriate for pie charts and will be converted to absolute values with a warning.
|
|
1065
|
-
- If the data contains very small values (less than 1% of the total), they can be grouped into an "Other" category
|
|
1066
|
-
for better readability.
|
|
1067
|
-
- By default, the sum of all columns is used for the pie chart. For time series data, consider preprocessing.
|
|
1068
|
-
|
|
926
|
+
Processed pandas Series with small items grouped into "Other"
|
|
1069
927
|
"""
|
|
1070
|
-
|
|
1071
|
-
|
|
1072
|
-
|
|
1073
|
-
|
|
1074
|
-
|
|
928
|
+
# Convert to Series
|
|
929
|
+
if isinstance(data, xr.Dataset):
|
|
930
|
+
# Sum all dimensions for each variable to get total values
|
|
931
|
+
values = {}
|
|
932
|
+
for var in data.data_vars:
|
|
933
|
+
var_data = data[var]
|
|
934
|
+
if len(var_data.dims) > 0:
|
|
935
|
+
total_value = float(var_data.sum().item())
|
|
936
|
+
else:
|
|
937
|
+
total_value = float(var_data.item())
|
|
1075
938
|
|
|
1076
|
-
|
|
1077
|
-
|
|
939
|
+
# Handle negative values
|
|
940
|
+
if total_value < 0:
|
|
941
|
+
logger.warning(f'Negative value for {var}: {total_value}. Using absolute value.')
|
|
942
|
+
total_value = abs(total_value)
|
|
1078
943
|
|
|
1079
|
-
|
|
1080
|
-
if (data_copy < 0).any().any():
|
|
1081
|
-
logger.error('Negative values detected in data. Using absolute values for pie chart.')
|
|
1082
|
-
data_copy = data_copy.abs()
|
|
944
|
+
values[var] = total_value
|
|
1083
945
|
|
|
1084
|
-
|
|
1085
|
-
if len(data_copy) > 1:
|
|
1086
|
-
data_sum = data_copy.sum()
|
|
1087
|
-
else:
|
|
1088
|
-
data_sum = data_copy.iloc[0]
|
|
946
|
+
series = pd.Series(values)
|
|
1089
947
|
|
|
1090
|
-
|
|
1091
|
-
|
|
1092
|
-
|
|
948
|
+
elif isinstance(data, pd.DataFrame):
|
|
949
|
+
# Sum across all columns if DataFrame
|
|
950
|
+
series = data.sum(axis=0)
|
|
951
|
+
# Handle negative values
|
|
952
|
+
negative_mask = series < 0
|
|
953
|
+
if negative_mask.any():
|
|
954
|
+
logger.warning(f'Negative values found: {series[negative_mask].to_dict()}. Using absolute values.')
|
|
955
|
+
series = series.abs()
|
|
1093
956
|
|
|
1094
|
-
#
|
|
1095
|
-
|
|
957
|
+
else: # pd.Series
|
|
958
|
+
series = data.copy()
|
|
959
|
+
# Handle negative values
|
|
960
|
+
negative_mask = series < 0
|
|
961
|
+
if negative_mask.any():
|
|
962
|
+
logger.warning(f'Negative values found: {series[negative_mask].to_dict()}. Using absolute values.')
|
|
963
|
+
series = series.abs()
|
|
1096
964
|
|
|
1097
|
-
#
|
|
1098
|
-
|
|
1099
|
-
fig, ax = plt.subplots(figsize=figsize)
|
|
965
|
+
# Only keep positive values
|
|
966
|
+
series = series[series > 0]
|
|
1100
967
|
|
|
1101
|
-
|
|
1102
|
-
|
|
1103
|
-
values,
|
|
1104
|
-
labels=labels,
|
|
1105
|
-
colors=processed_colors,
|
|
1106
|
-
autopct='%1.1f%%',
|
|
1107
|
-
startangle=90,
|
|
1108
|
-
shadow=False,
|
|
1109
|
-
wedgeprops=dict(width=0.5) if hole > 0 else None, # Set width for donut
|
|
1110
|
-
)
|
|
968
|
+
if series.empty or lower_percentage_threshold <= 0:
|
|
969
|
+
return series
|
|
1111
970
|
|
|
1112
|
-
#
|
|
1113
|
-
|
|
1114
|
-
|
|
1115
|
-
if hole > 0:
|
|
1116
|
-
# Adjust hole size to match plotly's hole parameter
|
|
1117
|
-
# In matplotlib, wedge width is relative to the radius (which is 1)
|
|
1118
|
-
# For plotly, hole is a fraction of the radius
|
|
1119
|
-
wedge_width = 1 - hole
|
|
1120
|
-
for wedge in wedges:
|
|
1121
|
-
wedge.set_width(wedge_width)
|
|
1122
|
-
|
|
1123
|
-
# Customize the appearance
|
|
1124
|
-
# Make autopct text more visible
|
|
1125
|
-
for autotext in autotexts:
|
|
1126
|
-
autotext.set_fontsize(10)
|
|
1127
|
-
autotext.set_color('white')
|
|
1128
|
-
|
|
1129
|
-
# Set aspect ratio to be equal to ensure a circular pie
|
|
1130
|
-
ax.set_aspect('equal')
|
|
1131
|
-
|
|
1132
|
-
# Add title
|
|
1133
|
-
if title:
|
|
1134
|
-
ax.set_title(title, fontsize=16)
|
|
971
|
+
# Calculate percentages
|
|
972
|
+
total = series.sum()
|
|
973
|
+
percentages = (series / total) * 100
|
|
1135
974
|
|
|
1136
|
-
#
|
|
1137
|
-
|
|
1138
|
-
|
|
975
|
+
# Find items below and above threshold
|
|
976
|
+
below_threshold = series[percentages < lower_percentage_threshold]
|
|
977
|
+
above_threshold = series[percentages >= lower_percentage_threshold]
|
|
1139
978
|
|
|
1140
|
-
#
|
|
1141
|
-
|
|
979
|
+
# Only group if there are at least 2 items below threshold
|
|
980
|
+
if len(below_threshold) > 1:
|
|
981
|
+
# Create new series with items above threshold + "Other"
|
|
982
|
+
result = above_threshold.copy()
|
|
983
|
+
result['Other'] = below_threshold.sum()
|
|
984
|
+
return result
|
|
1142
985
|
|
|
1143
|
-
return
|
|
986
|
+
return series
|
|
1144
987
|
|
|
1145
988
|
|
|
1146
989
|
def dual_pie_with_plotly(
|
|
1147
|
-
data_left: pd.Series,
|
|
1148
|
-
data_right: pd.Series,
|
|
1149
|
-
colors: ColorType =
|
|
990
|
+
data_left: xr.Dataset | pd.DataFrame | pd.Series,
|
|
991
|
+
data_right: xr.Dataset | pd.DataFrame | pd.Series,
|
|
992
|
+
colors: ColorType | None = None,
|
|
1150
993
|
title: str = '',
|
|
1151
994
|
subtitles: tuple[str, str] = ('Left Chart', 'Right Chart'),
|
|
1152
995
|
legend_title: str = '',
|
|
1153
996
|
hole: float = 0.2,
|
|
1154
997
|
lower_percentage_group: float = 5.0,
|
|
1155
|
-
hover_template: str = '%{label}: %{value} (%{percent})',
|
|
1156
998
|
text_info: str = 'percent+label',
|
|
1157
999
|
text_position: str = 'inside',
|
|
1000
|
+
hover_template: str = '%{label}: %{value} (%{percent})',
|
|
1158
1001
|
) -> go.Figure:
|
|
1159
1002
|
"""
|
|
1160
|
-
Create two pie charts side by side with Plotly
|
|
1003
|
+
Create two pie charts side by side with Plotly.
|
|
1161
1004
|
|
|
1162
1005
|
Args:
|
|
1163
|
-
data_left:
|
|
1164
|
-
data_right:
|
|
1165
|
-
colors: Color specification,
|
|
1166
|
-
- A string with a colorscale name (e.g., 'viridis', 'plasma')
|
|
1167
|
-
- A list of color strings (e.g., ['#ff0000', '#00ff00'])
|
|
1168
|
-
- A dictionary mapping category names to colors (e.g., {'Category1': '#ff0000'})
|
|
1006
|
+
data_left: Data for the left pie chart. Variables are summed across all dimensions.
|
|
1007
|
+
data_right: Data for the right pie chart. Variables are summed across all dimensions.
|
|
1008
|
+
colors: Color specification (colorscale name, list of colors, or dict mapping)
|
|
1169
1009
|
title: The main title of the plot.
|
|
1170
1010
|
subtitles: Tuple containing the subtitles for (left, right) charts.
|
|
1171
1011
|
legend_title: The title for the legend.
|
|
@@ -1177,119 +1017,67 @@ def dual_pie_with_plotly(
|
|
|
1177
1017
|
text_position: Position of text: 'inside', 'outside', 'auto', or 'none'.
|
|
1178
1018
|
|
|
1179
1019
|
Returns:
|
|
1180
|
-
|
|
1020
|
+
Plotly Figure object
|
|
1181
1021
|
"""
|
|
1182
|
-
|
|
1183
|
-
|
|
1184
|
-
|
|
1185
|
-
|
|
1186
|
-
|
|
1187
|
-
|
|
1188
|
-
|
|
1189
|
-
#
|
|
1190
|
-
|
|
1191
|
-
|
|
1192
|
-
|
|
1193
|
-
|
|
1194
|
-
|
|
1195
|
-
|
|
1196
|
-
|
|
1197
|
-
|
|
1198
|
-
|
|
1199
|
-
|
|
1200
|
-
|
|
1201
|
-
|
|
1202
|
-
|
|
1203
|
-
|
|
1204
|
-
|
|
1205
|
-
|
|
1206
|
-
|
|
1207
|
-
|
|
1208
|
-
|
|
1209
|
-
|
|
1210
|
-
|
|
1211
|
-
|
|
1212
|
-
|
|
1213
|
-
|
|
1214
|
-
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
|
|
1218
|
-
|
|
1219
|
-
# Sort series by value (ascending)
|
|
1220
|
-
sorted_series = series.sort_values()
|
|
1221
|
-
|
|
1222
|
-
# Calculate cumulative percentage contribution
|
|
1223
|
-
cumulative_percent = (sorted_series.cumsum() / total) * 100
|
|
1224
|
-
|
|
1225
|
-
# Find entries that collectively make up less than lower_percentage_group
|
|
1226
|
-
to_group = cumulative_percent <= lower_percentage_group
|
|
1227
|
-
|
|
1228
|
-
if to_group.sum() > 1:
|
|
1229
|
-
# Create "Other" category for the smallest values that together are < threshold
|
|
1230
|
-
other_sum = sorted_series[to_group].sum()
|
|
1231
|
-
|
|
1232
|
-
# Keep only values that aren't in the "Other" group
|
|
1233
|
-
result_series = series[~series.index.isin(sorted_series[to_group].index)]
|
|
1234
|
-
|
|
1235
|
-
# Add the "Other" category if it has a value
|
|
1236
|
-
if other_sum > 0:
|
|
1237
|
-
result_series['Other'] = other_sum
|
|
1238
|
-
|
|
1239
|
-
return result_series
|
|
1240
|
-
|
|
1241
|
-
return series
|
|
1242
|
-
|
|
1243
|
-
data_left_processed = preprocess_series(data_left)
|
|
1244
|
-
data_right_processed = preprocess_series(data_right)
|
|
1245
|
-
|
|
1246
|
-
# Get unique set of all labels for consistent coloring
|
|
1247
|
-
all_labels = sorted(set(data_left_processed.index) | set(data_right_processed.index))
|
|
1248
|
-
|
|
1249
|
-
# Get consistent color mapping for both charts using our unified function
|
|
1250
|
-
color_map = ColorProcessor(engine='plotly').process_colors(colors, all_labels, return_mapping=True)
|
|
1251
|
-
|
|
1252
|
-
# Function to create a pie trace with consistently mapped colors
|
|
1253
|
-
def create_pie_trace(data_series, side):
|
|
1254
|
-
if data_series.empty:
|
|
1255
|
-
return None
|
|
1256
|
-
|
|
1257
|
-
labels = data_series.index.tolist()
|
|
1258
|
-
values = data_series.values.tolist()
|
|
1259
|
-
trace_colors = [color_map[label] for label in labels]
|
|
1260
|
-
|
|
1261
|
-
return go.Pie(
|
|
1262
|
-
labels=labels,
|
|
1263
|
-
values=values,
|
|
1264
|
-
name=side,
|
|
1265
|
-
marker=dict(colors=trace_colors),
|
|
1266
|
-
hole=hole,
|
|
1267
|
-
textinfo=text_info,
|
|
1268
|
-
textposition=text_position,
|
|
1269
|
-
insidetextorientation='radial',
|
|
1270
|
-
hovertemplate=hover_template,
|
|
1271
|
-
sort=True, # Sort values by default (largest first)
|
|
1022
|
+
if colors is None:
|
|
1023
|
+
colors = CONFIG.Plotting.default_qualitative_colorscale
|
|
1024
|
+
|
|
1025
|
+
# Preprocess data to Series
|
|
1026
|
+
left_series = preprocess_data_for_pie(data_left, lower_percentage_group)
|
|
1027
|
+
right_series = preprocess_data_for_pie(data_right, lower_percentage_group)
|
|
1028
|
+
|
|
1029
|
+
# Extract labels and values
|
|
1030
|
+
left_labels = left_series.index.tolist()
|
|
1031
|
+
left_values = left_series.values.tolist()
|
|
1032
|
+
|
|
1033
|
+
right_labels = right_series.index.tolist()
|
|
1034
|
+
right_values = right_series.values.tolist()
|
|
1035
|
+
|
|
1036
|
+
# Get all unique labels for consistent coloring
|
|
1037
|
+
all_labels = sorted(set(left_labels) | set(right_labels))
|
|
1038
|
+
|
|
1039
|
+
# Create color map
|
|
1040
|
+
color_map = process_colors(colors, all_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale)
|
|
1041
|
+
|
|
1042
|
+
# Create figure
|
|
1043
|
+
fig = go.Figure()
|
|
1044
|
+
|
|
1045
|
+
# Add left pie
|
|
1046
|
+
if left_labels:
|
|
1047
|
+
fig.add_trace(
|
|
1048
|
+
go.Pie(
|
|
1049
|
+
labels=left_labels,
|
|
1050
|
+
values=left_values,
|
|
1051
|
+
name=subtitles[0],
|
|
1052
|
+
marker=dict(colors=[color_map.get(label, '#636EFA') for label in left_labels]),
|
|
1053
|
+
hole=hole,
|
|
1054
|
+
textinfo=text_info,
|
|
1055
|
+
textposition=text_position,
|
|
1056
|
+
hovertemplate=hover_template,
|
|
1057
|
+
domain=dict(x=[0, 0.48]),
|
|
1058
|
+
)
|
|
1272
1059
|
)
|
|
1273
1060
|
|
|
1274
|
-
# Add
|
|
1275
|
-
|
|
1276
|
-
|
|
1277
|
-
|
|
1278
|
-
|
|
1279
|
-
|
|
1280
|
-
|
|
1281
|
-
|
|
1282
|
-
|
|
1283
|
-
|
|
1284
|
-
|
|
1061
|
+
# Add right pie
|
|
1062
|
+
if right_labels:
|
|
1063
|
+
fig.add_trace(
|
|
1064
|
+
go.Pie(
|
|
1065
|
+
labels=right_labels,
|
|
1066
|
+
values=right_values,
|
|
1067
|
+
name=subtitles[1],
|
|
1068
|
+
marker=dict(colors=[color_map.get(label, '#636EFA') for label in right_labels]),
|
|
1069
|
+
hole=hole,
|
|
1070
|
+
textinfo=text_info,
|
|
1071
|
+
textposition=text_position,
|
|
1072
|
+
hovertemplate=hover_template,
|
|
1073
|
+
domain=dict(x=[0.52, 1]),
|
|
1074
|
+
)
|
|
1075
|
+
)
|
|
1285
1076
|
|
|
1286
1077
|
# Update layout
|
|
1287
1078
|
fig.update_layout(
|
|
1288
1079
|
title=title,
|
|
1289
1080
|
legend_title=legend_title,
|
|
1290
|
-
plot_bgcolor='rgba(0,0,0,0)', # Transparent background
|
|
1291
|
-
paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background
|
|
1292
|
-
font=dict(size=14),
|
|
1293
1081
|
margin=dict(t=80, b=50, l=30, r=30),
|
|
1294
1082
|
)
|
|
1295
1083
|
|
|
@@ -1297,178 +1085,127 @@ def dual_pie_with_plotly(
|
|
|
1297
1085
|
|
|
1298
1086
|
|
|
1299
1087
|
def dual_pie_with_matplotlib(
|
|
1300
|
-
data_left: pd.Series,
|
|
1301
|
-
data_right: pd.Series,
|
|
1302
|
-
colors: ColorType =
|
|
1088
|
+
data_left: xr.Dataset | pd.DataFrame | pd.Series,
|
|
1089
|
+
data_right: xr.Dataset | pd.DataFrame | pd.Series,
|
|
1090
|
+
colors: ColorType | None = None,
|
|
1303
1091
|
title: str = '',
|
|
1304
1092
|
subtitles: tuple[str, str] = ('Left Chart', 'Right Chart'),
|
|
1305
1093
|
legend_title: str = '',
|
|
1306
1094
|
hole: float = 0.2,
|
|
1307
1095
|
lower_percentage_group: float = 5.0,
|
|
1308
1096
|
figsize: tuple[int, int] = (14, 7),
|
|
1309
|
-
fig: plt.Figure | None = None,
|
|
1310
|
-
axes: list[plt.Axes] | None = None,
|
|
1311
1097
|
) -> tuple[plt.Figure, list[plt.Axes]]:
|
|
1312
1098
|
"""
|
|
1313
|
-
Create two pie charts side by side with Matplotlib
|
|
1314
|
-
Leverages the existing pie_with_matplotlib function.
|
|
1099
|
+
Create two pie charts side by side with Matplotlib.
|
|
1315
1100
|
|
|
1316
1101
|
Args:
|
|
1317
|
-
data_left:
|
|
1318
|
-
data_right:
|
|
1319
|
-
colors: Color specification,
|
|
1320
|
-
- A string with a colormap name (e.g., 'viridis', 'plasma')
|
|
1321
|
-
- A list of color strings (e.g., ['#ff0000', '#00ff00'])
|
|
1322
|
-
- A dictionary mapping category names to colors (e.g., {'Category1': '#ff0000'})
|
|
1102
|
+
data_left: Data for the left pie chart.
|
|
1103
|
+
data_right: Data for the right pie chart.
|
|
1104
|
+
colors: Color specification (colorscale name, list of colors, or dict mapping)
|
|
1323
1105
|
title: The main title of the plot.
|
|
1324
1106
|
subtitles: Tuple containing the subtitles for (left, right) charts.
|
|
1325
1107
|
legend_title: The title for the legend.
|
|
1326
1108
|
hole: Size of the hole in the center for creating donut charts (0.0 to 1.0).
|
|
1327
1109
|
lower_percentage_group: Whether to group small segments (below percentage) into an "Other" category.
|
|
1328
1110
|
figsize: The size of the figure (width, height) in inches.
|
|
1329
|
-
fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created.
|
|
1330
|
-
axes: A list of Matplotlib axes objects to plot on. If not provided, new axes will be created.
|
|
1331
1111
|
|
|
1332
1112
|
Returns:
|
|
1333
|
-
|
|
1113
|
+
Tuple of (Figure, list of Axes)
|
|
1334
1114
|
"""
|
|
1335
|
-
|
|
1336
|
-
|
|
1337
|
-
logger.error('Both datasets are empty. Returning empty figure.')
|
|
1338
|
-
if fig is None:
|
|
1339
|
-
fig, axes = plt.subplots(1, 2, figsize=figsize)
|
|
1340
|
-
return fig, axes
|
|
1341
|
-
|
|
1342
|
-
# Create figure and axes if not provided
|
|
1343
|
-
if fig is None or axes is None:
|
|
1344
|
-
fig, axes = plt.subplots(1, 2, figsize=figsize)
|
|
1345
|
-
|
|
1346
|
-
# Process series to handle negative values and apply minimum percentage threshold
|
|
1347
|
-
def preprocess_series(series: pd.Series):
|
|
1348
|
-
"""
|
|
1349
|
-
Preprocess a series for pie chart display by handling negative values
|
|
1350
|
-
and grouping the smallest parts together if they collectively represent
|
|
1351
|
-
less than the specified percentage threshold.
|
|
1352
|
-
"""
|
|
1353
|
-
# Handle negative values
|
|
1354
|
-
if (series < 0).any():
|
|
1355
|
-
logger.error('Negative values detected in data. Using absolute values for pie chart.')
|
|
1356
|
-
series = series.abs()
|
|
1115
|
+
if colors is None:
|
|
1116
|
+
colors = CONFIG.Plotting.default_qualitative_colorscale
|
|
1357
1117
|
|
|
1358
|
-
|
|
1359
|
-
|
|
1118
|
+
# Preprocess data to Series
|
|
1119
|
+
left_series = preprocess_data_for_pie(data_left, lower_percentage_group)
|
|
1120
|
+
right_series = preprocess_data_for_pie(data_right, lower_percentage_group)
|
|
1360
1121
|
|
|
1361
|
-
|
|
1362
|
-
|
|
1363
|
-
|
|
1364
|
-
if total > 0:
|
|
1365
|
-
# Sort series by value (ascending)
|
|
1366
|
-
sorted_series = series.sort_values()
|
|
1122
|
+
# Extract labels and values
|
|
1123
|
+
left_labels = left_series.index.tolist()
|
|
1124
|
+
left_values = left_series.values.tolist()
|
|
1367
1125
|
|
|
1368
|
-
|
|
1369
|
-
|
|
1126
|
+
right_labels = right_series.index.tolist()
|
|
1127
|
+
right_values = right_series.values.tolist()
|
|
1370
1128
|
|
|
1371
|
-
|
|
1372
|
-
|
|
1129
|
+
# Get all unique labels for consistent coloring
|
|
1130
|
+
all_labels = sorted(set(left_labels) | set(right_labels))
|
|
1373
1131
|
|
|
1374
|
-
|
|
1375
|
-
|
|
1376
|
-
other_sum = sorted_series[to_group].sum()
|
|
1132
|
+
# Create color map (process_colors always returns a dict)
|
|
1133
|
+
color_map = process_colors(colors, all_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale)
|
|
1377
1134
|
|
|
1378
|
-
|
|
1379
|
-
|
|
1135
|
+
# Create figure
|
|
1136
|
+
fig, axes = plt.subplots(1, 2, figsize=figsize)
|
|
1380
1137
|
|
|
1381
|
-
|
|
1382
|
-
|
|
1383
|
-
|
|
1138
|
+
def draw_pie(ax, labels, values, subtitle):
|
|
1139
|
+
"""Draw a single pie chart."""
|
|
1140
|
+
if not labels:
|
|
1141
|
+
ax.set_title(subtitle)
|
|
1142
|
+
ax.axis('off')
|
|
1143
|
+
return
|
|
1384
1144
|
|
|
1385
|
-
|
|
1145
|
+
chart_colors = [color_map[label] for label in labels]
|
|
1386
1146
|
|
|
1387
|
-
|
|
1388
|
-
|
|
1389
|
-
|
|
1390
|
-
|
|
1391
|
-
|
|
1392
|
-
|
|
1393
|
-
|
|
1394
|
-
|
|
1395
|
-
|
|
1396
|
-
|
|
1397
|
-
# Get unique set of all labels for consistent coloring
|
|
1398
|
-
all_labels = sorted(set(data_left_processed.index) | set(data_right_processed.index))
|
|
1399
|
-
|
|
1400
|
-
# Get consistent color mapping for both charts using our unified function
|
|
1401
|
-
color_map = ColorProcessor(engine='matplotlib').process_colors(colors, all_labels, return_mapping=True)
|
|
1147
|
+
# Draw pie
|
|
1148
|
+
wedges, texts, autotexts = ax.pie(
|
|
1149
|
+
values,
|
|
1150
|
+
labels=labels,
|
|
1151
|
+
colors=chart_colors,
|
|
1152
|
+
autopct='%1.1f%%',
|
|
1153
|
+
startangle=90,
|
|
1154
|
+
wedgeprops=dict(width=1 - hole) if hole > 0 else None,
|
|
1155
|
+
)
|
|
1402
1156
|
|
|
1403
|
-
|
|
1404
|
-
|
|
1405
|
-
|
|
1157
|
+
# Style text
|
|
1158
|
+
for autotext in autotexts:
|
|
1159
|
+
autotext.set_fontsize(10)
|
|
1160
|
+
autotext.set_color('white')
|
|
1161
|
+
autotext.set_weight('bold')
|
|
1406
1162
|
|
|
1407
|
-
|
|
1408
|
-
|
|
1409
|
-
pie_with_matplotlib(data=df_left, colors=left_colors, title=subtitles[0], hole=hole, fig=fig, ax=axes[0])
|
|
1410
|
-
else:
|
|
1411
|
-
axes[0].set_title(subtitles[0])
|
|
1412
|
-
axes[0].axis('off')
|
|
1163
|
+
ax.set_aspect('equal')
|
|
1164
|
+
ax.set_title(subtitle, fontsize=14, pad=20)
|
|
1413
1165
|
|
|
1414
|
-
#
|
|
1415
|
-
|
|
1416
|
-
|
|
1417
|
-
else:
|
|
1418
|
-
axes[1].set_title(subtitles[1])
|
|
1419
|
-
axes[1].axis('off')
|
|
1166
|
+
# Draw both pies
|
|
1167
|
+
draw_pie(axes[0], left_labels, left_values, subtitles[0])
|
|
1168
|
+
draw_pie(axes[1], right_labels, right_values, subtitles[1])
|
|
1420
1169
|
|
|
1421
1170
|
# Add main title
|
|
1422
1171
|
if title:
|
|
1423
1172
|
fig.suptitle(title, fontsize=16, y=0.98)
|
|
1424
1173
|
|
|
1425
|
-
#
|
|
1426
|
-
|
|
1427
|
-
|
|
1428
|
-
|
|
1429
|
-
|
|
1430
|
-
|
|
1431
|
-
for ax in axes:
|
|
1432
|
-
if ax.get_legend():
|
|
1433
|
-
ax.get_legend().remove()
|
|
1434
|
-
|
|
1435
|
-
# Create handles for the unified legend
|
|
1436
|
-
handles = []
|
|
1437
|
-
labels_for_legend = []
|
|
1438
|
-
|
|
1439
|
-
for label in all_labels:
|
|
1440
|
-
color = color_map[label]
|
|
1441
|
-
patch = plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10, label=label)
|
|
1442
|
-
handles.append(patch)
|
|
1443
|
-
labels_for_legend.append(label)
|
|
1174
|
+
# Create unified legend
|
|
1175
|
+
if left_labels or right_labels:
|
|
1176
|
+
handles = [
|
|
1177
|
+
plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color_map[label], markersize=10)
|
|
1178
|
+
for label in all_labels
|
|
1179
|
+
]
|
|
1444
1180
|
|
|
1445
|
-
# Add unified legend
|
|
1446
1181
|
fig.legend(
|
|
1447
1182
|
handles=handles,
|
|
1448
|
-
labels=
|
|
1183
|
+
labels=all_labels,
|
|
1449
1184
|
title=legend_title,
|
|
1450
1185
|
loc='lower center',
|
|
1451
|
-
bbox_to_anchor=(0.5, 0),
|
|
1452
|
-
ncol=min(len(all_labels), 5),
|
|
1186
|
+
bbox_to_anchor=(0.5, -0.02),
|
|
1187
|
+
ncol=min(len(all_labels), 5),
|
|
1453
1188
|
)
|
|
1454
1189
|
|
|
1455
|
-
|
|
1456
|
-
|
|
1190
|
+
fig.subplots_adjust(bottom=0.15)
|
|
1191
|
+
|
|
1192
|
+
fig.tight_layout()
|
|
1457
1193
|
|
|
1458
1194
|
return fig, axes
|
|
1459
1195
|
|
|
1460
1196
|
|
|
1461
1197
|
def heatmap_with_plotly(
|
|
1462
1198
|
data: xr.DataArray,
|
|
1463
|
-
colors: ColorType =
|
|
1199
|
+
colors: ColorType | None = None,
|
|
1464
1200
|
title: str = '',
|
|
1465
1201
|
facet_by: str | list[str] | None = None,
|
|
1466
1202
|
animate_by: str | None = None,
|
|
1467
|
-
facet_cols: int =
|
|
1203
|
+
facet_cols: int | None = None,
|
|
1468
1204
|
reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
|
|
1469
1205
|
| Literal['auto']
|
|
1470
1206
|
| None = 'auto',
|
|
1471
1207
|
fill: Literal['ffill', 'bfill'] | None = 'ffill',
|
|
1208
|
+
**imshow_kwargs: Any,
|
|
1472
1209
|
) -> go.Figure:
|
|
1473
1210
|
"""
|
|
1474
1211
|
Plot a heatmap visualization using Plotly's imshow with faceting and animation support.
|
|
@@ -1486,8 +1223,8 @@ def heatmap_with_plotly(
|
|
|
1486
1223
|
Args:
|
|
1487
1224
|
data: An xarray DataArray containing the data to visualize. Should have at least
|
|
1488
1225
|
2 dimensions, or a 'time' dimension that can be reshaped into 2D.
|
|
1489
|
-
colors: Color specification (
|
|
1490
|
-
'
|
|
1226
|
+
colors: Color specification (colorscale name, list, or dict). Common options:
|
|
1227
|
+
'turbo', 'plasma', 'RdBu', 'portland'.
|
|
1491
1228
|
title: The main title of the heatmap.
|
|
1492
1229
|
facet_by: Dimension to create facets for. Creates a subplot grid.
|
|
1493
1230
|
Can be a single dimension name or list (only first dimension used).
|
|
@@ -1501,6 +1238,11 @@ def heatmap_with_plotly(
|
|
|
1501
1238
|
- Tuple like ('D', 'h'): Explicit time reshaping (days vs hours)
|
|
1502
1239
|
- None: Disable time reshaping (will error if only 1D time data)
|
|
1503
1240
|
fill: Method to fill missing values when reshaping time: 'ffill' or 'bfill'. Default is 'ffill'.
|
|
1241
|
+
**imshow_kwargs: Additional keyword arguments to pass to plotly.express.imshow.
|
|
1242
|
+
Common options include:
|
|
1243
|
+
- aspect: 'auto', 'equal', or a number for aspect ratio
|
|
1244
|
+
- zmin, zmax: Minimum and maximum values for color scale
|
|
1245
|
+
- labels: Dict to customize axis labels
|
|
1504
1246
|
|
|
1505
1247
|
Returns:
|
|
1506
1248
|
A Plotly figure object containing the heatmap visualization.
|
|
@@ -1538,6 +1280,13 @@ def heatmap_with_plotly(
|
|
|
1538
1280
|
fig = heatmap_with_plotly(data_array, facet_by='scenario', animate_by='period', reshape_time=('W', 'D'))
|
|
1539
1281
|
```
|
|
1540
1282
|
"""
|
|
1283
|
+
if colors is None:
|
|
1284
|
+
colors = CONFIG.Plotting.default_sequential_colorscale
|
|
1285
|
+
|
|
1286
|
+
# Apply CONFIG defaults if not explicitly set
|
|
1287
|
+
if facet_cols is None:
|
|
1288
|
+
facet_cols = CONFIG.Plotting.default_facet_cols
|
|
1289
|
+
|
|
1541
1290
|
# Handle empty data
|
|
1542
1291
|
if data.size == 0:
|
|
1543
1292
|
return go.Figure()
|
|
@@ -1589,12 +1338,26 @@ def heatmap_with_plotly(
|
|
|
1589
1338
|
heatmap_dims = [dim for dim in available_dims if dim not in facet_dims]
|
|
1590
1339
|
|
|
1591
1340
|
if len(heatmap_dims) < 2:
|
|
1592
|
-
#
|
|
1593
|
-
|
|
1594
|
-
|
|
1595
|
-
|
|
1596
|
-
|
|
1597
|
-
|
|
1341
|
+
# Handle single-dimension case by adding variable name as a dimension
|
|
1342
|
+
if len(heatmap_dims) == 1:
|
|
1343
|
+
# Get the variable name, or use a default
|
|
1344
|
+
var_name = data.name if data.name else 'value'
|
|
1345
|
+
|
|
1346
|
+
# Expand the DataArray by adding a new dimension with the variable name
|
|
1347
|
+
data = data.expand_dims({'variable': [var_name]})
|
|
1348
|
+
|
|
1349
|
+
# Update available dimensions
|
|
1350
|
+
available_dims = list(data.dims)
|
|
1351
|
+
heatmap_dims = [dim for dim in available_dims if dim not in facet_dims]
|
|
1352
|
+
|
|
1353
|
+
logger.debug(f'Only 1 dimension remaining for heatmap. Added variable dimension: {var_name}')
|
|
1354
|
+
else:
|
|
1355
|
+
# No dimensions at all - cannot create a heatmap
|
|
1356
|
+
logger.error(
|
|
1357
|
+
f'Heatmap requires at least 1 dimension. '
|
|
1358
|
+
f'After faceting/animation, {len(heatmap_dims)} dimension(s) remain: {heatmap_dims}'
|
|
1359
|
+
)
|
|
1360
|
+
return go.Figure()
|
|
1598
1361
|
|
|
1599
1362
|
# Setup faceting parameters for Plotly Express
|
|
1600
1363
|
# Note: px.imshow only supports facet_col, not facet_row
|
|
@@ -1617,7 +1380,7 @@ def heatmap_with_plotly(
|
|
|
1617
1380
|
# Create the imshow plot - px.imshow can work directly with xarray DataArrays
|
|
1618
1381
|
common_args = {
|
|
1619
1382
|
'img': data,
|
|
1620
|
-
'color_continuous_scale': colors
|
|
1383
|
+
'color_continuous_scale': colors,
|
|
1621
1384
|
'title': title,
|
|
1622
1385
|
}
|
|
1623
1386
|
|
|
@@ -1631,38 +1394,39 @@ def heatmap_with_plotly(
|
|
|
1631
1394
|
if animate_by:
|
|
1632
1395
|
common_args['animation_frame'] = animate_by
|
|
1633
1396
|
|
|
1397
|
+
# Merge in additional imshow kwargs
|
|
1398
|
+
common_args.update(imshow_kwargs)
|
|
1399
|
+
|
|
1634
1400
|
try:
|
|
1635
1401
|
fig = px.imshow(**common_args)
|
|
1636
1402
|
except Exception as e:
|
|
1637
1403
|
logger.error(f'Error creating imshow plot: {e}. Falling back to basic heatmap.')
|
|
1638
1404
|
# Fallback: create a simple heatmap without faceting
|
|
1639
|
-
|
|
1640
|
-
data.values,
|
|
1641
|
-
color_continuous_scale
|
|
1642
|
-
title
|
|
1643
|
-
|
|
1644
|
-
|
|
1645
|
-
|
|
1646
|
-
fig.update_layout(
|
|
1647
|
-
plot_bgcolor='rgba(0,0,0,0)',
|
|
1648
|
-
paper_bgcolor='rgba(0,0,0,0)',
|
|
1649
|
-
font=dict(size=12),
|
|
1650
|
-
)
|
|
1405
|
+
fallback_args = {
|
|
1406
|
+
'img': data.values,
|
|
1407
|
+
'color_continuous_scale': colors,
|
|
1408
|
+
'title': title,
|
|
1409
|
+
}
|
|
1410
|
+
fallback_args.update(imshow_kwargs)
|
|
1411
|
+
fig = px.imshow(**fallback_args)
|
|
1651
1412
|
|
|
1652
1413
|
return fig
|
|
1653
1414
|
|
|
1654
1415
|
|
|
1655
1416
|
def heatmap_with_matplotlib(
|
|
1656
1417
|
data: xr.DataArray,
|
|
1657
|
-
colors: ColorType =
|
|
1418
|
+
colors: ColorType | None = None,
|
|
1658
1419
|
title: str = '',
|
|
1659
1420
|
figsize: tuple[float, float] = (12, 6),
|
|
1660
|
-
fig: plt.Figure | None = None,
|
|
1661
|
-
ax: plt.Axes | None = None,
|
|
1662
1421
|
reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
|
|
1663
1422
|
| Literal['auto']
|
|
1664
1423
|
| None = 'auto',
|
|
1665
1424
|
fill: Literal['ffill', 'bfill'] | None = 'ffill',
|
|
1425
|
+
vmin: float | None = None,
|
|
1426
|
+
vmax: float | None = None,
|
|
1427
|
+
imshow_kwargs: dict[str, Any] | None = None,
|
|
1428
|
+
cbar_kwargs: dict[str, Any] | None = None,
|
|
1429
|
+
**kwargs: Any,
|
|
1666
1430
|
) -> tuple[plt.Figure, plt.Axes]:
|
|
1667
1431
|
"""
|
|
1668
1432
|
Plot a heatmap visualization using Matplotlib's imshow.
|
|
@@ -1674,16 +1438,25 @@ def heatmap_with_matplotlib(
|
|
|
1674
1438
|
data: An xarray DataArray containing the data to visualize. Should have at least
|
|
1675
1439
|
2 dimensions. If more than 2 dimensions exist, additional dimensions will
|
|
1676
1440
|
be reduced by taking the first slice.
|
|
1677
|
-
colors: Color specification. Should be a
|
|
1441
|
+
colors: Color specification. Should be a colorscale name (e.g., 'turbo', 'RdBu').
|
|
1678
1442
|
title: The title of the heatmap.
|
|
1679
1443
|
figsize: The size of the figure (width, height) in inches.
|
|
1680
|
-
fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created.
|
|
1681
|
-
ax: A Matplotlib axes object to plot on. If not provided, a new axes will be created.
|
|
1682
1444
|
reshape_time: Time reshaping configuration:
|
|
1683
1445
|
- 'auto' (default): Automatically applies ('D', 'h') if only 'time' dimension
|
|
1684
1446
|
- Tuple like ('D', 'h'): Explicit time reshaping (days vs hours)
|
|
1685
1447
|
- None: Disable time reshaping
|
|
1686
1448
|
fill: Method to fill missing values when reshaping time: 'ffill' or 'bfill'. Default is 'ffill'.
|
|
1449
|
+
vmin: Minimum value for color scale. If None, uses data minimum.
|
|
1450
|
+
vmax: Maximum value for color scale. If None, uses data maximum.
|
|
1451
|
+
imshow_kwargs: Optional dict of parameters to pass to ax.imshow().
|
|
1452
|
+
Use this to customize image properties (e.g., interpolation, aspect).
|
|
1453
|
+
cbar_kwargs: Optional dict of parameters to pass to plt.colorbar().
|
|
1454
|
+
Use this to customize colorbar properties (e.g., orientation, label).
|
|
1455
|
+
**kwargs: Additional keyword arguments passed to ax.imshow().
|
|
1456
|
+
Common options include:
|
|
1457
|
+
- interpolation: 'nearest', 'bilinear', 'bicubic', etc.
|
|
1458
|
+
- alpha: Transparency level (0-1)
|
|
1459
|
+
- extent: [left, right, bottom, top] for axis limits
|
|
1687
1460
|
|
|
1688
1461
|
Returns:
|
|
1689
1462
|
A tuple containing the Matplotlib figure and axes objects used for the plot.
|
|
@@ -1705,19 +1478,36 @@ def heatmap_with_matplotlib(
|
|
|
1705
1478
|
fig, ax = heatmap_with_matplotlib(data_array, reshape_time=('D', 'h'))
|
|
1706
1479
|
```
|
|
1707
1480
|
"""
|
|
1481
|
+
if colors is None:
|
|
1482
|
+
colors = CONFIG.Plotting.default_sequential_colorscale
|
|
1483
|
+
|
|
1484
|
+
# Initialize kwargs if not provided
|
|
1485
|
+
if imshow_kwargs is None:
|
|
1486
|
+
imshow_kwargs = {}
|
|
1487
|
+
if cbar_kwargs is None:
|
|
1488
|
+
cbar_kwargs = {}
|
|
1489
|
+
|
|
1490
|
+
# Merge any additional kwargs into imshow_kwargs
|
|
1491
|
+
# This allows users to pass imshow options directly
|
|
1492
|
+
imshow_kwargs.update(kwargs)
|
|
1493
|
+
|
|
1708
1494
|
# Handle empty data
|
|
1709
1495
|
if data.size == 0:
|
|
1710
|
-
|
|
1711
|
-
fig, ax = plt.subplots(figsize=figsize)
|
|
1496
|
+
fig, ax = plt.subplots(figsize=figsize)
|
|
1712
1497
|
return fig, ax
|
|
1713
1498
|
|
|
1714
1499
|
# Apply time reshaping using the new unified function
|
|
1715
1500
|
# Matplotlib doesn't support faceting/animation, so we pass None for those
|
|
1716
1501
|
data = reshape_data_for_heatmap(data, reshape_time=reshape_time, facet_by=None, animate_by=None, fill=fill)
|
|
1717
1502
|
|
|
1718
|
-
#
|
|
1719
|
-
if
|
|
1720
|
-
|
|
1503
|
+
# Handle single-dimension case by adding variable name as a dimension
|
|
1504
|
+
if isinstance(data, xr.DataArray) and len(data.dims) == 1:
|
|
1505
|
+
var_name = data.name if data.name else 'value'
|
|
1506
|
+
data = data.expand_dims({'variable': [var_name]})
|
|
1507
|
+
logger.debug(f'Only 1 dimension in data. Added variable dimension: {var_name}')
|
|
1508
|
+
|
|
1509
|
+
# Create figure and axes
|
|
1510
|
+
fig, ax = plt.subplots(figsize=figsize)
|
|
1721
1511
|
|
|
1722
1512
|
# Extract data values
|
|
1723
1513
|
# If data has more than 2 dimensions, we need to reduce it
|
|
@@ -1742,15 +1532,19 @@ def heatmap_with_matplotlib(
|
|
|
1742
1532
|
x_labels = 'x'
|
|
1743
1533
|
y_labels = 'y'
|
|
1744
1534
|
|
|
1745
|
-
#
|
|
1746
|
-
|
|
1535
|
+
# Create the heatmap using imshow with user customizations
|
|
1536
|
+
imshow_defaults = {'cmap': colors, 'aspect': 'auto', 'origin': 'upper', 'vmin': vmin, 'vmax': vmax}
|
|
1537
|
+
imshow_defaults.update(imshow_kwargs) # User kwargs override defaults
|
|
1538
|
+
im = ax.imshow(values, **imshow_defaults)
|
|
1747
1539
|
|
|
1748
|
-
#
|
|
1749
|
-
|
|
1540
|
+
# Add colorbar with user customizations
|
|
1541
|
+
cbar_defaults = {'ax': ax, 'orientation': 'horizontal', 'pad': 0.1, 'aspect': 15, 'fraction': 0.05}
|
|
1542
|
+
cbar_defaults.update(cbar_kwargs) # User kwargs override defaults
|
|
1543
|
+
cbar = plt.colorbar(im, **cbar_defaults)
|
|
1750
1544
|
|
|
1751
|
-
#
|
|
1752
|
-
|
|
1753
|
-
|
|
1545
|
+
# Set colorbar label if not overridden by user
|
|
1546
|
+
if 'label' not in cbar_kwargs:
|
|
1547
|
+
cbar.set_label('Value')
|
|
1754
1548
|
|
|
1755
1549
|
# Set labels and title
|
|
1756
1550
|
ax.set_xlabel(str(x_labels).capitalize())
|
|
@@ -1768,8 +1562,9 @@ def export_figure(
|
|
|
1768
1562
|
default_path: pathlib.Path,
|
|
1769
1563
|
default_filetype: str | None = None,
|
|
1770
1564
|
user_path: pathlib.Path | None = None,
|
|
1771
|
-
show: bool =
|
|
1565
|
+
show: bool | None = None,
|
|
1772
1566
|
save: bool = False,
|
|
1567
|
+
dpi: int | None = None,
|
|
1773
1568
|
) -> go.Figure | tuple[plt.Figure, plt.Axes]:
|
|
1774
1569
|
"""
|
|
1775
1570
|
Export a figure to a file and or show it.
|
|
@@ -1779,13 +1574,21 @@ def export_figure(
|
|
|
1779
1574
|
default_path: The default file path if no user filename is provided.
|
|
1780
1575
|
default_filetype: The default filetype if the path doesnt end with a filetype.
|
|
1781
1576
|
user_path: An optional user-specified file path.
|
|
1782
|
-
show: Whether to display the figure (default:
|
|
1577
|
+
show: Whether to display the figure. If None, uses CONFIG.Plotting.default_show (default: None).
|
|
1783
1578
|
save: Whether to save the figure (default: False).
|
|
1579
|
+
dpi: DPI (dots per inch) for saving Matplotlib figures. If None, uses CONFIG.Plotting.default_dpi.
|
|
1784
1580
|
|
|
1785
1581
|
Raises:
|
|
1786
1582
|
ValueError: If no default filetype is provided and the path doesn't specify a filetype.
|
|
1787
1583
|
TypeError: If the figure type is not supported.
|
|
1788
1584
|
"""
|
|
1585
|
+
# Apply CONFIG defaults if not explicitly set
|
|
1586
|
+
if show is None:
|
|
1587
|
+
show = CONFIG.Plotting.default_show
|
|
1588
|
+
|
|
1589
|
+
if dpi is None:
|
|
1590
|
+
dpi = CONFIG.Plotting.default_dpi
|
|
1591
|
+
|
|
1789
1592
|
filename = user_path or default_path
|
|
1790
1593
|
filename = filename.with_name(filename.name.replace('|', '__'))
|
|
1791
1594
|
if filename.suffix == '':
|
|
@@ -1800,25 +1603,17 @@ def export_figure(
|
|
|
1800
1603
|
filename = filename.with_suffix('.html')
|
|
1801
1604
|
|
|
1802
1605
|
try:
|
|
1803
|
-
|
|
1804
|
-
|
|
1805
|
-
|
|
1806
|
-
|
|
1807
|
-
|
|
1808
|
-
|
|
1809
|
-
|
|
1810
|
-
|
|
1811
|
-
#
|
|
1812
|
-
|
|
1813
|
-
|
|
1814
|
-
plotly.offline.plot(fig, filename=str(filename))
|
|
1815
|
-
elif save and not show:
|
|
1816
|
-
# Save without opening
|
|
1817
|
-
fig.write_html(str(filename))
|
|
1818
|
-
elif show and not save:
|
|
1819
|
-
# Show interactively without saving
|
|
1820
|
-
fig.show()
|
|
1821
|
-
# If neither save nor show: do nothing
|
|
1606
|
+
# Respect show and save flags (tests should set CONFIG.Plotting.default_show=False)
|
|
1607
|
+
if save and show:
|
|
1608
|
+
# Save and auto-open in browser
|
|
1609
|
+
plotly.offline.plot(fig, filename=str(filename))
|
|
1610
|
+
elif save and not show:
|
|
1611
|
+
# Save without opening
|
|
1612
|
+
fig.write_html(str(filename))
|
|
1613
|
+
elif show and not save:
|
|
1614
|
+
# Show interactively without saving
|
|
1615
|
+
fig.show()
|
|
1616
|
+
# If neither save nor show: do nothing
|
|
1822
1617
|
finally:
|
|
1823
1618
|
# Cleanup to prevent socket warnings
|
|
1824
1619
|
if hasattr(fig, '_renderer'):
|
|
@@ -1829,16 +1624,15 @@ def export_figure(
|
|
|
1829
1624
|
elif isinstance(figure_like, tuple):
|
|
1830
1625
|
fig, ax = figure_like
|
|
1831
1626
|
if show:
|
|
1832
|
-
# Only show if using interactive backend
|
|
1627
|
+
# Only show if using interactive backend (tests should set CONFIG.Plotting.default_show=False)
|
|
1833
1628
|
backend = matplotlib.get_backend().lower()
|
|
1834
1629
|
is_interactive = backend not in {'agg', 'pdf', 'ps', 'svg', 'template'}
|
|
1835
|
-
is_test_env = 'PYTEST_CURRENT_TEST' in os.environ
|
|
1836
1630
|
|
|
1837
|
-
if is_interactive
|
|
1631
|
+
if is_interactive:
|
|
1838
1632
|
plt.show()
|
|
1839
1633
|
|
|
1840
1634
|
if save:
|
|
1841
|
-
fig.savefig(str(filename), dpi=
|
|
1635
|
+
fig.savefig(str(filename), dpi=dpi)
|
|
1842
1636
|
plt.close(fig) # Close figure to free memory
|
|
1843
1637
|
|
|
1844
1638
|
return fig, ax
|