flixopt 3.0.1__py3-none-any.whl → 6.0.0rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flixopt/__init__.py +57 -49
- flixopt/carrier.py +159 -0
- flixopt/clustering/__init__.py +51 -0
- flixopt/clustering/base.py +1746 -0
- flixopt/clustering/intercluster_helpers.py +201 -0
- flixopt/color_processing.py +372 -0
- flixopt/comparison.py +819 -0
- flixopt/components.py +848 -270
- flixopt/config.py +853 -496
- flixopt/core.py +111 -98
- flixopt/effects.py +294 -284
- flixopt/elements.py +484 -223
- flixopt/features.py +220 -118
- flixopt/flow_system.py +2026 -389
- flixopt/interface.py +504 -286
- flixopt/io.py +1718 -55
- flixopt/linear_converters.py +291 -230
- flixopt/modeling.py +304 -181
- flixopt/network_app.py +2 -1
- flixopt/optimization.py +788 -0
- flixopt/optimize_accessor.py +373 -0
- flixopt/plot_result.py +143 -0
- flixopt/plotting.py +1177 -1034
- flixopt/results.py +1331 -372
- flixopt/solvers.py +12 -4
- flixopt/statistics_accessor.py +2412 -0
- flixopt/stats_accessor.py +75 -0
- flixopt/structure.py +954 -120
- flixopt/topology_accessor.py +676 -0
- flixopt/transform_accessor.py +2277 -0
- flixopt/types.py +120 -0
- flixopt-6.0.0rc7.dist-info/METADATA +290 -0
- flixopt-6.0.0rc7.dist-info/RECORD +36 -0
- {flixopt-3.0.1.dist-info → flixopt-6.0.0rc7.dist-info}/WHEEL +1 -1
- flixopt/aggregation.py +0 -382
- flixopt/calculation.py +0 -672
- flixopt/commons.py +0 -51
- flixopt/utils.py +0 -86
- flixopt-3.0.1.dist-info/METADATA +0 -209
- flixopt-3.0.1.dist-info/RECORD +0 -26
- {flixopt-3.0.1.dist-info → flixopt-6.0.0rc7.dist-info}/licenses/LICENSE +0 -0
- {flixopt-3.0.1.dist-info → flixopt-6.0.0rc7.dist-info}/top_level.txt +0 -0
flixopt/plotting.py
CHANGED
|
@@ -25,9 +25,7 @@ accessible for standalone data visualization tasks.
|
|
|
25
25
|
|
|
26
26
|
from __future__ import annotations
|
|
27
27
|
|
|
28
|
-
import itertools
|
|
29
28
|
import logging
|
|
30
|
-
import os
|
|
31
29
|
import pathlib
|
|
32
30
|
from typing import TYPE_CHECKING, Any, Literal
|
|
33
31
|
|
|
@@ -39,14 +37,17 @@ import pandas as pd
|
|
|
39
37
|
import plotly.express as px
|
|
40
38
|
import plotly.graph_objects as go
|
|
41
39
|
import plotly.offline
|
|
42
|
-
|
|
40
|
+
import xarray as xr
|
|
41
|
+
|
|
42
|
+
from .color_processing import ColorType, process_colors
|
|
43
|
+
from .config import CONFIG
|
|
43
44
|
|
|
44
45
|
if TYPE_CHECKING:
|
|
45
46
|
import pyvis
|
|
46
47
|
|
|
47
48
|
logger = logging.getLogger('flixopt')
|
|
48
49
|
|
|
49
|
-
# Define the colors for the 'portland'
|
|
50
|
+
# Define the colors for the 'portland' colorscale in matplotlib
|
|
50
51
|
_portland_colors = [
|
|
51
52
|
[12 / 255, 51 / 255, 131 / 255], # Dark blue
|
|
52
53
|
[10 / 255, 136 / 255, 186 / 255], # Light blue
|
|
@@ -55,7 +56,7 @@ _portland_colors = [
|
|
|
55
56
|
[217 / 255, 30 / 255, 30 / 255], # Red
|
|
56
57
|
]
|
|
57
58
|
|
|
58
|
-
# Check if the
|
|
59
|
+
# Check if the colorscale already exists before registering it
|
|
59
60
|
if hasattr(plt, 'colormaps'): # Matplotlib >= 3.7
|
|
60
61
|
registry = plt.colormaps
|
|
61
62
|
if 'portland' not in registry:
|
|
@@ -65,486 +66,524 @@ else: # Matplotlib < 3.7
|
|
|
65
66
|
plt.register_cmap(name='portland', cmap=mcolors.LinearSegmentedColormap.from_list('portland', _portland_colors))
|
|
66
67
|
|
|
67
68
|
|
|
68
|
-
|
|
69
|
-
"""
|
|
70
|
-
|
|
71
|
-
Color specifications can take several forms to accommodate different use cases:
|
|
72
|
-
|
|
73
|
-
**Named Colormaps** (str):
|
|
74
|
-
- Standard colormaps: 'viridis', 'plasma', 'cividis', 'tab10', 'Set1'
|
|
75
|
-
- Energy-focused: 'portland' (custom flixopt colormap for energy systems)
|
|
76
|
-
- Backend-specific maps available in Plotly and Matplotlib
|
|
77
|
-
|
|
78
|
-
**Color Lists** (list[str]):
|
|
79
|
-
- Explicit color sequences: ['red', 'blue', 'green', 'orange']
|
|
80
|
-
- HEX codes: ['#FF0000', '#0000FF', '#00FF00', '#FFA500']
|
|
81
|
-
- Mixed formats: ['red', '#0000FF', 'green', 'orange']
|
|
69
|
+
PlottingEngine = Literal['plotly', 'matplotlib']
|
|
70
|
+
"""Identifier for the plotting engine to use."""
|
|
82
71
|
|
|
83
|
-
**Label-to-Color Mapping** (dict[str, str]):
|
|
84
|
-
- Explicit associations: {'Wind': 'skyblue', 'Solar': 'gold', 'Gas': 'brown'}
|
|
85
|
-
- Ensures consistent colors across different plots and datasets
|
|
86
|
-
- Ideal for energy system components with semantic meaning
|
|
87
72
|
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
73
|
+
def _ensure_dataset(data: xr.Dataset | pd.DataFrame | pd.Series) -> xr.Dataset:
|
|
74
|
+
"""Convert DataFrame or Series to Dataset if needed."""
|
|
75
|
+
if isinstance(data, xr.Dataset):
|
|
76
|
+
return data
|
|
77
|
+
elif isinstance(data, pd.DataFrame):
|
|
78
|
+
# Convert DataFrame to Dataset
|
|
79
|
+
return data.to_xarray()
|
|
80
|
+
elif isinstance(data, pd.Series):
|
|
81
|
+
# Convert Series to DataFrame first, then to Dataset
|
|
82
|
+
return data.to_frame().to_xarray()
|
|
83
|
+
else:
|
|
84
|
+
raise TypeError(f'Data must be xr.Dataset, pd.DataFrame, or pd.Series, got {type(data).__name__}')
|
|
92
85
|
|
|
93
|
-
# Explicit color list
|
|
94
|
-
colors = ['red', 'blue', 'green', '#FFD700']
|
|
95
86
|
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
'
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
87
|
+
def _validate_plotting_data(data: xr.Dataset, allow_empty: bool = False) -> None:
|
|
88
|
+
"""Validate dataset for plotting (checks for empty data, non-numeric types, etc.)."""
|
|
89
|
+
# Check for empty data
|
|
90
|
+
if not allow_empty and len(data.data_vars) == 0:
|
|
91
|
+
raise ValueError('Empty Dataset provided (no variables). Cannot create plot.')
|
|
92
|
+
|
|
93
|
+
# Check if dataset has any data (xarray uses nbytes for total size)
|
|
94
|
+
if all(data[var].size == 0 for var in data.data_vars) if len(data.data_vars) > 0 else True:
|
|
95
|
+
if not allow_empty and len(data.data_vars) > 0:
|
|
96
|
+
raise ValueError('Dataset has zero size. Cannot create plot.')
|
|
97
|
+
if len(data.data_vars) == 0:
|
|
98
|
+
return # Empty dataset, nothing to validate
|
|
99
|
+
return
|
|
100
|
+
|
|
101
|
+
# Check for non-numeric data types
|
|
102
|
+
for var in data.data_vars:
|
|
103
|
+
dtype = data[var].dtype
|
|
104
|
+
if not np.issubdtype(dtype, np.number):
|
|
105
|
+
raise TypeError(
|
|
106
|
+
f"Variable '{var}' has non-numeric dtype '{dtype}'. "
|
|
107
|
+
f'Plotting requires numeric data types (int, float, etc.).'
|
|
108
|
+
)
|
|
117
109
|
|
|
118
|
-
|
|
119
|
-
|
|
110
|
+
# Warn about NaN/Inf values
|
|
111
|
+
for var in data.data_vars:
|
|
112
|
+
if np.isnan(data[var].values).any():
|
|
113
|
+
logger.debug(f"Variable '{var}' contains NaN values which may affect visualization.")
|
|
114
|
+
if np.isinf(data[var].values).any():
|
|
115
|
+
logger.debug(f"Variable '{var}' contains Inf values which may affect visualization.")
|
|
120
116
|
|
|
121
117
|
|
|
122
|
-
|
|
123
|
-
|
|
118
|
+
def with_plotly(
|
|
119
|
+
data: xr.Dataset | pd.DataFrame | pd.Series,
|
|
120
|
+
mode: Literal['stacked_bar', 'line', 'area', 'grouped_bar'] = 'stacked_bar',
|
|
121
|
+
colors: ColorType | None = None,
|
|
122
|
+
title: str = '',
|
|
123
|
+
ylabel: str = '',
|
|
124
|
+
xlabel: str = '',
|
|
125
|
+
facet_by: str | list[str] | None = None,
|
|
126
|
+
animate_by: str | None = None,
|
|
127
|
+
facet_cols: int | None = None,
|
|
128
|
+
shared_yaxes: bool = True,
|
|
129
|
+
shared_xaxes: bool = True,
|
|
130
|
+
**px_kwargs: Any,
|
|
131
|
+
) -> go.Figure:
|
|
132
|
+
"""
|
|
133
|
+
Plot data with Plotly using facets (subplots) and/or animation for multidimensional data.
|
|
124
134
|
|
|
125
|
-
|
|
126
|
-
ensuring consistent visual appearance regardless of the plotting engine used.
|
|
127
|
-
It handles color palette generation, named colormap translation, and intelligent
|
|
128
|
-
color cycling for complex datasets with many categories.
|
|
135
|
+
Uses Plotly Express for convenient faceting and animation with automatic styling.
|
|
129
136
|
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
137
|
+
Args:
|
|
138
|
+
data: An xarray Dataset, pandas DataFrame, or pandas Series to plot.
|
|
139
|
+
mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for lines,
|
|
140
|
+
'area' for stacked area charts, or 'grouped_bar' for grouped bar charts.
|
|
141
|
+
colors: Color specification (colorscale, list, or dict mapping labels to colors).
|
|
142
|
+
title: The main title of the plot.
|
|
143
|
+
ylabel: The label for the y-axis.
|
|
144
|
+
xlabel: The label for the x-axis.
|
|
145
|
+
facet_by: Dimension(s) to create facets for. Creates a subplot grid.
|
|
146
|
+
Can be a single dimension name or list of dimensions (max 2 for facet_row and facet_col).
|
|
147
|
+
If the dimension doesn't exist in the data, it will be silently ignored.
|
|
148
|
+
animate_by: Dimension to animate over. Creates animation frames.
|
|
149
|
+
If the dimension doesn't exist in the data, it will be silently ignored.
|
|
150
|
+
facet_cols: Number of columns in the facet grid (used when facet_by is single dimension).
|
|
151
|
+
shared_yaxes: Whether subplots share y-axes.
|
|
152
|
+
shared_xaxes: Whether subplots share x-axes.
|
|
153
|
+
**px_kwargs: Additional keyword arguments passed to the underlying Plotly Express function
|
|
154
|
+
(px.bar, px.line, px.area). These override default arguments if provided.
|
|
155
|
+
Examples: range_x=[0, 100], range_y=[0, 50], category_orders={...}, line_shape='linear'
|
|
136
156
|
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
- **Label Dictionaries**: {'Generator': 'red', 'Storage': 'blue', 'Load': 'green'}
|
|
157
|
+
Returns:
|
|
158
|
+
A Plotly figure object containing the faceted/animated plot. You can further customize
|
|
159
|
+
the returned figure using Plotly's methods (e.g., fig.update_traces(), fig.update_layout()).
|
|
141
160
|
|
|
142
161
|
Examples:
|
|
143
|
-
|
|
162
|
+
Simple plot:
|
|
144
163
|
|
|
145
164
|
```python
|
|
146
|
-
|
|
147
|
-
processor = ColorProcessor(engine='plotly', default_colormap='viridis')
|
|
148
|
-
|
|
149
|
-
# Process different color specifications
|
|
150
|
-
colors = processor.process_colors('plasma', ['Gen1', 'Gen2', 'Storage'])
|
|
151
|
-
colors = processor.process_colors(['red', 'blue', 'green'], ['A', 'B', 'C'])
|
|
152
|
-
colors = processor.process_colors({'Wind': 'skyblue', 'Solar': 'gold'}, ['Wind', 'Solar', 'Gas'])
|
|
153
|
-
|
|
154
|
-
# Switch to Matplotlib
|
|
155
|
-
processor = ColorProcessor(engine='matplotlib')
|
|
156
|
-
mpl_colors = processor.process_colors('tab10', component_labels)
|
|
165
|
+
fig = with_plotly(dataset, mode='area', title='Energy Mix')
|
|
157
166
|
```
|
|
158
167
|
|
|
159
|
-
|
|
168
|
+
Facet by scenario:
|
|
160
169
|
|
|
161
170
|
```python
|
|
162
|
-
|
|
163
|
-
energy_colors = {
|
|
164
|
-
'Natural_Gas': '#8B4513', # Brown
|
|
165
|
-
'Electricity': '#FFD700', # Gold
|
|
166
|
-
'Heat': '#FF4500', # Red-orange
|
|
167
|
-
'Cooling': '#87CEEB', # Sky blue
|
|
168
|
-
'Hydrogen': '#E6E6FA', # Lavender
|
|
169
|
-
'Battery': '#32CD32', # Lime green
|
|
170
|
-
}
|
|
171
|
-
|
|
172
|
-
processor = ColorProcessor('plotly')
|
|
173
|
-
flow_colors = processor.process_colors(energy_colors, flow_labels)
|
|
171
|
+
fig = with_plotly(dataset, facet_by='scenario', facet_cols=2)
|
|
174
172
|
```
|
|
175
173
|
|
|
176
|
-
|
|
177
|
-
engine: Plotting backend ('plotly' or 'matplotlib'). Determines output color format.
|
|
178
|
-
default_colormap: Fallback colormap when requested palettes are unavailable.
|
|
179
|
-
Common options: 'viridis', 'plasma', 'tab10', 'portland'.
|
|
174
|
+
Animate by period:
|
|
180
175
|
|
|
181
|
-
|
|
176
|
+
```python
|
|
177
|
+
fig = with_plotly(dataset, animate_by='period')
|
|
178
|
+
```
|
|
182
179
|
|
|
183
|
-
|
|
184
|
-
"""Initialize the color processor with specified backend and defaults."""
|
|
185
|
-
if engine not in ['plotly', 'matplotlib']:
|
|
186
|
-
raise TypeError(f'engine must be "plotly" or "matplotlib", but is {engine}')
|
|
187
|
-
self.engine = engine
|
|
188
|
-
self.default_colormap = default_colormap
|
|
189
|
-
|
|
190
|
-
def _generate_colors_from_colormap(self, colormap_name: str, num_colors: int) -> list[Any]:
|
|
191
|
-
"""
|
|
192
|
-
Generate colors from a named colormap.
|
|
193
|
-
|
|
194
|
-
Args:
|
|
195
|
-
colormap_name: Name of the colormap
|
|
196
|
-
num_colors: Number of colors to generate
|
|
197
|
-
|
|
198
|
-
Returns:
|
|
199
|
-
list of colors in the format appropriate for the engine
|
|
200
|
-
"""
|
|
201
|
-
if self.engine == 'plotly':
|
|
202
|
-
try:
|
|
203
|
-
colorscale = px.colors.get_colorscale(colormap_name)
|
|
204
|
-
except PlotlyError as e:
|
|
205
|
-
logger.error(f"Colorscale '{colormap_name}' not found in Plotly. Using {self.default_colormap}: {e}")
|
|
206
|
-
colorscale = px.colors.get_colorscale(self.default_colormap)
|
|
207
|
-
|
|
208
|
-
# Generate evenly spaced points
|
|
209
|
-
color_points = [i / (num_colors - 1) for i in range(num_colors)] if num_colors > 1 else [0]
|
|
210
|
-
return px.colors.sample_colorscale(colorscale, color_points)
|
|
211
|
-
|
|
212
|
-
else: # matplotlib
|
|
213
|
-
try:
|
|
214
|
-
cmap = plt.get_cmap(colormap_name, num_colors)
|
|
215
|
-
except ValueError as e:
|
|
216
|
-
logger.error(f"Colormap '{colormap_name}' not found in Matplotlib. Using {self.default_colormap}: {e}")
|
|
217
|
-
cmap = plt.get_cmap(self.default_colormap, num_colors)
|
|
218
|
-
|
|
219
|
-
return [cmap(i) for i in range(num_colors)]
|
|
220
|
-
|
|
221
|
-
def _handle_color_list(self, colors: list[str], num_labels: int) -> list[str]:
|
|
222
|
-
"""
|
|
223
|
-
Handle a list of colors, cycling if necessary.
|
|
224
|
-
|
|
225
|
-
Args:
|
|
226
|
-
colors: list of color strings
|
|
227
|
-
num_labels: Number of labels that need colors
|
|
228
|
-
|
|
229
|
-
Returns:
|
|
230
|
-
list of colors matching the number of labels
|
|
231
|
-
"""
|
|
232
|
-
if len(colors) == 0:
|
|
233
|
-
logger.error(f'Empty color list provided. Using {self.default_colormap} instead.')
|
|
234
|
-
return self._generate_colors_from_colormap(self.default_colormap, num_labels)
|
|
235
|
-
|
|
236
|
-
if len(colors) < num_labels:
|
|
237
|
-
logger.warning(
|
|
238
|
-
f'Not enough colors provided ({len(colors)}) for all labels ({num_labels}). Colors will cycle.'
|
|
239
|
-
)
|
|
240
|
-
# Cycle through the colors
|
|
241
|
-
color_iter = itertools.cycle(colors)
|
|
242
|
-
return [next(color_iter) for _ in range(num_labels)]
|
|
243
|
-
else:
|
|
244
|
-
# Trim if necessary
|
|
245
|
-
if len(colors) > num_labels:
|
|
246
|
-
logger.warning(
|
|
247
|
-
f'More colors provided ({len(colors)}) than labels ({num_labels}). Extra colors will be ignored.'
|
|
248
|
-
)
|
|
249
|
-
return colors[:num_labels]
|
|
250
|
-
|
|
251
|
-
def _handle_color_dict(self, colors: dict[str, str], labels: list[str]) -> list[str]:
|
|
252
|
-
"""
|
|
253
|
-
Handle a dictionary mapping labels to colors.
|
|
254
|
-
|
|
255
|
-
Args:
|
|
256
|
-
colors: Dictionary mapping labels to colors
|
|
257
|
-
labels: list of labels that need colors
|
|
258
|
-
|
|
259
|
-
Returns:
|
|
260
|
-
list of colors in the same order as labels
|
|
261
|
-
"""
|
|
262
|
-
if len(colors) == 0:
|
|
263
|
-
logger.warning(f'Empty color dictionary provided. Using {self.default_colormap} instead.')
|
|
264
|
-
return self._generate_colors_from_colormap(self.default_colormap, len(labels))
|
|
265
|
-
|
|
266
|
-
# Find missing labels
|
|
267
|
-
missing_labels = sorted(set(labels) - set(colors.keys()))
|
|
268
|
-
if missing_labels:
|
|
269
|
-
logger.warning(
|
|
270
|
-
f'Some labels have no color specified: {missing_labels}. Using {self.default_colormap} for these.'
|
|
271
|
-
)
|
|
180
|
+
Facet and animate:
|
|
272
181
|
|
|
273
|
-
|
|
274
|
-
|
|
182
|
+
```python
|
|
183
|
+
fig = with_plotly(dataset, facet_by='scenario', animate_by='period')
|
|
184
|
+
```
|
|
275
185
|
|
|
276
|
-
|
|
277
|
-
colors_copy = colors.copy()
|
|
278
|
-
for i, label in enumerate(missing_labels):
|
|
279
|
-
colors_copy[label] = missing_colors[i]
|
|
280
|
-
else:
|
|
281
|
-
colors_copy = colors
|
|
282
|
-
|
|
283
|
-
# Create color list in the same order as labels
|
|
284
|
-
return [colors_copy[label] for label in labels]
|
|
285
|
-
|
|
286
|
-
def process_colors(
|
|
287
|
-
self,
|
|
288
|
-
colors: ColorType,
|
|
289
|
-
labels: list[str],
|
|
290
|
-
return_mapping: bool = False,
|
|
291
|
-
) -> list[Any] | dict[str, Any]:
|
|
292
|
-
"""
|
|
293
|
-
Process colors for the specified labels.
|
|
294
|
-
|
|
295
|
-
Args:
|
|
296
|
-
colors: Color specification (colormap name, list of colors, or label-to-color mapping)
|
|
297
|
-
labels: list of data labels that need colors assigned
|
|
298
|
-
return_mapping: If True, returns a dictionary mapping labels to colors;
|
|
299
|
-
if False, returns a list of colors in the same order as labels
|
|
300
|
-
|
|
301
|
-
Returns:
|
|
302
|
-
Either a list of colors or a dictionary mapping labels to colors
|
|
303
|
-
"""
|
|
304
|
-
if len(labels) == 0:
|
|
305
|
-
logger.error('No labels provided for color assignment.')
|
|
306
|
-
return {} if return_mapping else []
|
|
307
|
-
|
|
308
|
-
# Process based on type of colors input
|
|
309
|
-
if isinstance(colors, str):
|
|
310
|
-
color_list = self._generate_colors_from_colormap(colors, len(labels))
|
|
311
|
-
elif isinstance(colors, list):
|
|
312
|
-
color_list = self._handle_color_list(colors, len(labels))
|
|
313
|
-
elif isinstance(colors, dict):
|
|
314
|
-
color_list = self._handle_color_dict(colors, labels)
|
|
315
|
-
else:
|
|
316
|
-
logger.error(
|
|
317
|
-
f'Unsupported color specification type: {type(colors)}. Using {self.default_colormap} instead.'
|
|
318
|
-
)
|
|
319
|
-
color_list = self._generate_colors_from_colormap(self.default_colormap, len(labels))
|
|
186
|
+
Customize with Plotly Express kwargs:
|
|
320
187
|
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
else:
|
|
325
|
-
return color_list
|
|
188
|
+
```python
|
|
189
|
+
fig = with_plotly(dataset, range_y=[0, 100], line_shape='linear')
|
|
190
|
+
```
|
|
326
191
|
|
|
192
|
+
Further customize the returned figure:
|
|
327
193
|
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
ylabel: str = '',
|
|
334
|
-
xlabel: str = 'Time in h',
|
|
335
|
-
fig: go.Figure | None = None,
|
|
336
|
-
) -> go.Figure:
|
|
194
|
+
```python
|
|
195
|
+
fig = with_plotly(dataset, mode='line')
|
|
196
|
+
fig.update_traces(line={'width': 5, 'dash': 'dot'})
|
|
197
|
+
fig.update_layout(template='plotly_dark', width=1200, height=600)
|
|
198
|
+
```
|
|
337
199
|
"""
|
|
338
|
-
|
|
200
|
+
if colors is None:
|
|
201
|
+
colors = CONFIG.Plotting.default_qualitative_colorscale
|
|
339
202
|
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
and each column represents a separate data series.
|
|
343
|
-
style: The plotting style. Use 'stacked_bar' for stacked bar charts, 'line' for stepped lines,
|
|
344
|
-
or 'area' for stacked area charts.
|
|
345
|
-
colors: Color specification, can be:
|
|
346
|
-
- A string with a colorscale name (e.g., 'viridis', 'plasma')
|
|
347
|
-
- A list of color strings (e.g., ['#ff0000', '#00ff00'])
|
|
348
|
-
- A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'})
|
|
349
|
-
title: The title of the plot.
|
|
350
|
-
ylabel: The label for the y-axis.
|
|
351
|
-
xlabel: The label for the x-axis.
|
|
352
|
-
fig: A Plotly figure object to plot on. If not provided, a new figure will be created.
|
|
203
|
+
if mode not in ('stacked_bar', 'line', 'area', 'grouped_bar'):
|
|
204
|
+
raise ValueError(f"'mode' must be one of {{'stacked_bar','line','area', 'grouped_bar'}}, got {mode!r}")
|
|
353
205
|
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
if style not in ('stacked_bar', 'line', 'area', 'grouped_bar'):
|
|
358
|
-
raise ValueError(f"'style' must be one of {{'stacked_bar','line','area', 'grouped_bar'}}, got {style!r}")
|
|
359
|
-
if data.empty:
|
|
360
|
-
return go.Figure()
|
|
206
|
+
# Apply CONFIG defaults if not explicitly set
|
|
207
|
+
if facet_cols is None:
|
|
208
|
+
facet_cols = CONFIG.Plotting.default_facet_cols
|
|
361
209
|
|
|
362
|
-
|
|
210
|
+
# Ensure data is a Dataset and validate it
|
|
211
|
+
data = _ensure_dataset(data)
|
|
212
|
+
_validate_plotting_data(data, allow_empty=True)
|
|
363
213
|
|
|
364
|
-
|
|
214
|
+
# Handle empty data
|
|
215
|
+
if len(data.data_vars) == 0:
|
|
216
|
+
logger.error('with_plotly() got an empty Dataset.')
|
|
217
|
+
return go.Figure()
|
|
365
218
|
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
name=column,
|
|
373
|
-
marker=dict(
|
|
374
|
-
color=processed_colors[i], line=dict(width=0, color='rgba(0,0,0,0)')
|
|
375
|
-
), # Transparent line with 0 width
|
|
376
|
-
)
|
|
377
|
-
)
|
|
219
|
+
# Handle all-scalar datasets (where all variables have no dimensions)
|
|
220
|
+
# This occurs when all variables are scalar values with dims=()
|
|
221
|
+
if all(len(data[var].dims) == 0 for var in data.data_vars):
|
|
222
|
+
# Create a simple DataFrame with variable names as x-axis
|
|
223
|
+
variables = list(data.data_vars.keys())
|
|
224
|
+
values = [float(data[var].values) for var in data.data_vars]
|
|
378
225
|
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
bargroupgap=0, # No space between grouped bars
|
|
383
|
-
)
|
|
384
|
-
if style == 'grouped_bar':
|
|
385
|
-
for i, column in enumerate(data.columns):
|
|
386
|
-
fig.add_trace(go.Bar(x=data.index, y=data[column], name=column, marker=dict(color=processed_colors[i])))
|
|
387
|
-
|
|
388
|
-
fig.update_layout(
|
|
389
|
-
barmode='group',
|
|
390
|
-
bargap=0.2, # No space between bars
|
|
391
|
-
bargroupgap=0, # space between grouped bars
|
|
226
|
+
# Resolve colors
|
|
227
|
+
color_discrete_map = process_colors(
|
|
228
|
+
colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale
|
|
392
229
|
)
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
230
|
+
marker_colors = [color_discrete_map.get(var, '#636EFA') for var in variables]
|
|
231
|
+
|
|
232
|
+
# Create simple plot based on mode using go (not px) for better color control
|
|
233
|
+
if mode in ('stacked_bar', 'grouped_bar'):
|
|
234
|
+
fig = go.Figure(data=[go.Bar(x=variables, y=values, marker_color=marker_colors)])
|
|
235
|
+
elif mode == 'line':
|
|
236
|
+
fig = go.Figure(
|
|
237
|
+
data=[
|
|
238
|
+
go.Scatter(
|
|
239
|
+
x=variables,
|
|
240
|
+
y=values,
|
|
241
|
+
mode='lines+markers',
|
|
242
|
+
marker=dict(color=marker_colors, size=8),
|
|
243
|
+
line=dict(color='lightgray'),
|
|
244
|
+
)
|
|
245
|
+
]
|
|
403
246
|
)
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
f'Data for plotting stacked lines contains columns with both positive and negative values:'
|
|
416
|
-
f' {mixed_columns}. These can not be stacked, and are printed as simple lines'
|
|
247
|
+
elif mode == 'area':
|
|
248
|
+
fig = go.Figure(
|
|
249
|
+
data=[
|
|
250
|
+
go.Scatter(
|
|
251
|
+
x=variables,
|
|
252
|
+
y=values,
|
|
253
|
+
fill='tozeroy',
|
|
254
|
+
marker=dict(color=marker_colors, size=8),
|
|
255
|
+
line=dict(color='lightgray'),
|
|
256
|
+
)
|
|
257
|
+
]
|
|
417
258
|
)
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
259
|
+
else:
|
|
260
|
+
raise ValueError('"mode" must be one of "stacked_bar", "grouped_bar", "line", "area"')
|
|
261
|
+
|
|
262
|
+
fig.update_layout(title=title, xaxis_title=xlabel, yaxis_title=ylabel, showlegend=False)
|
|
263
|
+
return fig
|
|
264
|
+
|
|
265
|
+
# Convert Dataset to long-form DataFrame for Plotly Express
|
|
266
|
+
# Structure: time, variable, value, scenario, period, ... (all dims as columns)
|
|
267
|
+
dim_names = list(data.dims)
|
|
268
|
+
df_long = data.to_dataframe().reset_index().melt(id_vars=dim_names, var_name='variable', value_name='value')
|
|
269
|
+
|
|
270
|
+
# Validate facet_by and animate_by dimensions exist in the data
|
|
271
|
+
available_dims = [col for col in df_long.columns if col not in ['variable', 'value']]
|
|
272
|
+
|
|
273
|
+
# Check facet_by dimensions
|
|
274
|
+
if facet_by is not None:
|
|
275
|
+
if isinstance(facet_by, str):
|
|
276
|
+
if facet_by not in available_dims:
|
|
277
|
+
logger.debug(
|
|
278
|
+
f"Dimension '{facet_by}' not found in data. Available dimensions: {available_dims}. "
|
|
279
|
+
f'Ignoring facet_by parameter.'
|
|
432
280
|
)
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
line=dict(shape='hv', color=colors_stacked[column], dash='dash'),
|
|
281
|
+
facet_by = None
|
|
282
|
+
elif isinstance(facet_by, list):
|
|
283
|
+
# Filter out dimensions that don't exist
|
|
284
|
+
missing_dims = [dim for dim in facet_by if dim not in available_dims]
|
|
285
|
+
facet_by = [dim for dim in facet_by if dim in available_dims]
|
|
286
|
+
if missing_dims:
|
|
287
|
+
logger.debug(
|
|
288
|
+
f'Dimensions {missing_dims} not found in data. Available dimensions: {available_dims}. '
|
|
289
|
+
f'Using only existing dimensions: {facet_by if facet_by else "none"}.'
|
|
443
290
|
)
|
|
444
|
-
)
|
|
291
|
+
if len(facet_by) == 0:
|
|
292
|
+
facet_by = None
|
|
293
|
+
|
|
294
|
+
# Check animate_by dimension
|
|
295
|
+
if animate_by is not None and animate_by not in available_dims:
|
|
296
|
+
logger.debug(
|
|
297
|
+
f"Dimension '{animate_by}' not found in data. Available dimensions: {available_dims}. "
|
|
298
|
+
f'Ignoring animate_by parameter.'
|
|
299
|
+
)
|
|
300
|
+
animate_by = None
|
|
301
|
+
|
|
302
|
+
# Setup faceting parameters for Plotly Express
|
|
303
|
+
facet_row = None
|
|
304
|
+
facet_col = None
|
|
305
|
+
if facet_by:
|
|
306
|
+
if isinstance(facet_by, str):
|
|
307
|
+
# Single facet dimension - use facet_col with facet_col_wrap
|
|
308
|
+
facet_col = facet_by
|
|
309
|
+
elif len(facet_by) == 1:
|
|
310
|
+
facet_col = facet_by[0]
|
|
311
|
+
elif len(facet_by) == 2:
|
|
312
|
+
# Two facet dimensions - use facet_row and facet_col
|
|
313
|
+
facet_row = facet_by[0]
|
|
314
|
+
facet_col = facet_by[1]
|
|
315
|
+
else:
|
|
316
|
+
raise ValueError(f'facet_by can have at most 2 dimensions, got {len(facet_by)}')
|
|
445
317
|
|
|
446
|
-
#
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
title=ylabel,
|
|
451
|
-
showgrid=True, # Enable grid lines on the y-axis
|
|
452
|
-
gridcolor='lightgrey', # Customize grid line color
|
|
453
|
-
gridwidth=0.5, # Customize grid line width
|
|
454
|
-
),
|
|
455
|
-
xaxis=dict(
|
|
456
|
-
title=xlabel,
|
|
457
|
-
showgrid=True, # Enable grid lines on the x-axis
|
|
458
|
-
gridcolor='lightgrey', # Customize grid line color
|
|
459
|
-
gridwidth=0.5, # Customize grid line width
|
|
460
|
-
),
|
|
461
|
-
plot_bgcolor='rgba(0,0,0,0)', # Transparent background
|
|
462
|
-
paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background
|
|
463
|
-
font=dict(size=14), # Increase font size for better readability
|
|
318
|
+
# Process colors
|
|
319
|
+
all_vars = df_long['variable'].unique().tolist()
|
|
320
|
+
color_discrete_map = process_colors(
|
|
321
|
+
colors, all_vars, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale
|
|
464
322
|
)
|
|
465
323
|
|
|
324
|
+
# Determine which dimension to use for x-axis
|
|
325
|
+
# Collect dimensions used for faceting and animation
|
|
326
|
+
used_dims = set()
|
|
327
|
+
if facet_row:
|
|
328
|
+
used_dims.add(facet_row)
|
|
329
|
+
if facet_col:
|
|
330
|
+
used_dims.add(facet_col)
|
|
331
|
+
if animate_by:
|
|
332
|
+
used_dims.add(animate_by)
|
|
333
|
+
|
|
334
|
+
# Find available dimensions for x-axis (not used for faceting/animation)
|
|
335
|
+
x_candidates = [d for d in available_dims if d not in used_dims]
|
|
336
|
+
|
|
337
|
+
# Use 'time' if available, otherwise use the first available dimension
|
|
338
|
+
if 'time' in x_candidates:
|
|
339
|
+
x_dim = 'time'
|
|
340
|
+
elif len(x_candidates) > 0:
|
|
341
|
+
x_dim = x_candidates[0]
|
|
342
|
+
else:
|
|
343
|
+
# Fallback: use the first dimension (shouldn't happen in normal cases)
|
|
344
|
+
x_dim = available_dims[0] if available_dims else 'time'
|
|
345
|
+
|
|
346
|
+
# Create plot using Plotly Express based on mode
|
|
347
|
+
common_args = {
|
|
348
|
+
'data_frame': df_long,
|
|
349
|
+
'x': x_dim,
|
|
350
|
+
'y': 'value',
|
|
351
|
+
'color': 'variable',
|
|
352
|
+
'facet_row': facet_row,
|
|
353
|
+
'facet_col': facet_col,
|
|
354
|
+
'animation_frame': animate_by,
|
|
355
|
+
'color_discrete_map': color_discrete_map,
|
|
356
|
+
'title': title,
|
|
357
|
+
'labels': {'value': ylabel, x_dim: xlabel, 'variable': ''},
|
|
358
|
+
}
|
|
359
|
+
|
|
360
|
+
# Add facet_col_wrap for single facet dimension
|
|
361
|
+
if facet_col and not facet_row:
|
|
362
|
+
common_args['facet_col_wrap'] = facet_cols
|
|
363
|
+
|
|
364
|
+
# Add mode-specific defaults (before px_kwargs so they can be overridden)
|
|
365
|
+
if mode in ('line', 'area'):
|
|
366
|
+
common_args['line_shape'] = 'hv' # Stepped lines by default
|
|
367
|
+
|
|
368
|
+
# Allow callers to pass any px.* keyword args (e.g., category_orders, range_x/y, line_shape)
|
|
369
|
+
# These will override the defaults set above
|
|
370
|
+
if px_kwargs:
|
|
371
|
+
common_args.update(px_kwargs)
|
|
372
|
+
|
|
373
|
+
if mode == 'stacked_bar':
|
|
374
|
+
fig = px.bar(**common_args)
|
|
375
|
+
fig.update_traces(marker_line_width=0)
|
|
376
|
+
fig.update_layout(barmode='relative', bargap=0, bargroupgap=0)
|
|
377
|
+
elif mode == 'grouped_bar':
|
|
378
|
+
fig = px.bar(**common_args)
|
|
379
|
+
fig.update_layout(barmode='group', bargap=0.2, bargroupgap=0)
|
|
380
|
+
elif mode == 'line':
|
|
381
|
+
fig = px.line(**common_args)
|
|
382
|
+
elif mode == 'area':
|
|
383
|
+
# Use Plotly Express to create the area plot (preserves animation, legends, faceting)
|
|
384
|
+
fig = px.area(**common_args)
|
|
385
|
+
|
|
386
|
+
# Classify each variable based on its values
|
|
387
|
+
variable_classification = {}
|
|
388
|
+
for var in all_vars:
|
|
389
|
+
var_data = df_long[df_long['variable'] == var]['value']
|
|
390
|
+
var_data_clean = var_data[(var_data < -1e-5) | (var_data > 1e-5)]
|
|
391
|
+
|
|
392
|
+
if len(var_data_clean) == 0:
|
|
393
|
+
variable_classification[var] = 'zero'
|
|
394
|
+
else:
|
|
395
|
+
has_pos, has_neg = (var_data_clean > 0).any(), (var_data_clean < 0).any()
|
|
396
|
+
variable_classification[var] = (
|
|
397
|
+
'mixed' if has_pos and has_neg else ('negative' if has_neg else 'positive')
|
|
398
|
+
)
|
|
399
|
+
|
|
400
|
+
# Log warning for mixed variables
|
|
401
|
+
mixed_vars = [v for v, c in variable_classification.items() if c == 'mixed']
|
|
402
|
+
if mixed_vars:
|
|
403
|
+
logger.warning(f'Variables with both positive and negative values: {mixed_vars}. Plotted as dashed lines.')
|
|
404
|
+
|
|
405
|
+
all_traces = list(fig.data)
|
|
406
|
+
for frame in fig.frames:
|
|
407
|
+
all_traces.extend(frame.data)
|
|
408
|
+
|
|
409
|
+
for trace in all_traces:
|
|
410
|
+
cls = variable_classification.get(trace.name, None)
|
|
411
|
+
# Only stack positive and negative, not mixed or zero
|
|
412
|
+
trace.stackgroup = cls if cls in ('positive', 'negative') else None
|
|
413
|
+
|
|
414
|
+
if cls in ('positive', 'negative'):
|
|
415
|
+
# Stacked area: add opacity to avoid hiding layers, remove line border
|
|
416
|
+
if hasattr(trace, 'line') and trace.line.color:
|
|
417
|
+
trace.fillcolor = trace.line.color
|
|
418
|
+
trace.line.width = 0
|
|
419
|
+
elif cls == 'mixed':
|
|
420
|
+
# Mixed variables: show as dashed line, not stacked
|
|
421
|
+
if hasattr(trace, 'line'):
|
|
422
|
+
trace.line.width = 2
|
|
423
|
+
trace.line.dash = 'dash'
|
|
424
|
+
if hasattr(trace, 'fill'):
|
|
425
|
+
trace.fill = None
|
|
426
|
+
|
|
427
|
+
# Update axes to share if requested (Plotly Express already handles this, but we can customize)
|
|
428
|
+
if not shared_yaxes:
|
|
429
|
+
fig.update_yaxes(matches=None)
|
|
430
|
+
if not shared_xaxes:
|
|
431
|
+
fig.update_xaxes(matches=None)
|
|
432
|
+
|
|
466
433
|
return fig
|
|
467
434
|
|
|
468
435
|
|
|
469
436
|
def with_matplotlib(
|
|
470
|
-
data: pd.DataFrame,
|
|
471
|
-
|
|
472
|
-
colors: ColorType =
|
|
437
|
+
data: xr.Dataset | pd.DataFrame | pd.Series,
|
|
438
|
+
mode: Literal['stacked_bar', 'line'] = 'stacked_bar',
|
|
439
|
+
colors: ColorType | None = None,
|
|
473
440
|
title: str = '',
|
|
474
441
|
ylabel: str = '',
|
|
475
442
|
xlabel: str = 'Time in h',
|
|
476
443
|
figsize: tuple[int, int] = (12, 6),
|
|
477
|
-
|
|
478
|
-
ax: plt.Axes | None = None,
|
|
444
|
+
plot_kwargs: dict[str, Any] | None = None,
|
|
479
445
|
) -> tuple[plt.Figure, plt.Axes]:
|
|
480
446
|
"""
|
|
481
|
-
Plot
|
|
447
|
+
Plot data with Matplotlib using stacked bars or stepped lines.
|
|
482
448
|
|
|
483
449
|
Args:
|
|
484
|
-
data:
|
|
485
|
-
and each column represents a separate data series.
|
|
486
|
-
|
|
487
|
-
colors: Color specification
|
|
488
|
-
- A
|
|
450
|
+
data: An xarray Dataset, pandas DataFrame, or pandas Series to plot. After conversion to DataFrame,
|
|
451
|
+
the index represents time and each column represents a separate data series (variables).
|
|
452
|
+
mode: Plotting mode. Use 'stacked_bar' for stacked bar charts or 'line' for stepped lines.
|
|
453
|
+
colors: Color specification. Can be:
|
|
454
|
+
- A colorscale name (e.g., 'turbo', 'plasma')
|
|
489
455
|
- A list of color strings (e.g., ['#ff0000', '#00ff00'])
|
|
490
|
-
- A
|
|
456
|
+
- A dict mapping column names to colors (e.g., {'Column1': '#ff0000'})
|
|
491
457
|
title: The title of the plot.
|
|
492
458
|
ylabel: The ylabel of the plot.
|
|
493
459
|
xlabel: The xlabel of the plot.
|
|
494
|
-
figsize: Specify the size of the figure
|
|
495
|
-
|
|
496
|
-
|
|
460
|
+
figsize: Specify the size of the figure (width, height) in inches.
|
|
461
|
+
plot_kwargs: Optional dict of parameters to pass to ax.bar() or ax.step() plotting calls.
|
|
462
|
+
Use this to customize plot properties (e.g., linewidth, alpha, edgecolor).
|
|
497
463
|
|
|
498
464
|
Returns:
|
|
499
465
|
A tuple containing the Matplotlib figure and axes objects used for the plot.
|
|
500
466
|
|
|
501
467
|
Notes:
|
|
502
|
-
- If `
|
|
468
|
+
- If `mode` is 'stacked_bar', bars are stacked for both positive and negative values.
|
|
503
469
|
Negative values are stacked separately without extra labels in the legend.
|
|
504
|
-
- If `
|
|
470
|
+
- If `mode` is 'line', stepped lines are drawn for each data series.
|
|
505
471
|
"""
|
|
506
|
-
if
|
|
507
|
-
|
|
472
|
+
if colors is None:
|
|
473
|
+
colors = CONFIG.Plotting.default_qualitative_colorscale
|
|
508
474
|
|
|
509
|
-
if
|
|
510
|
-
|
|
475
|
+
if mode not in ('stacked_bar', 'line'):
|
|
476
|
+
raise ValueError(f"'mode' must be one of {{'stacked_bar','line'}} for matplotlib, got {mode!r}")
|
|
477
|
+
|
|
478
|
+
# Ensure data is a Dataset and validate it
|
|
479
|
+
data = _ensure_dataset(data)
|
|
480
|
+
_validate_plotting_data(data, allow_empty=True)
|
|
481
|
+
|
|
482
|
+
# Create new figure and axes
|
|
483
|
+
fig, ax = plt.subplots(figsize=figsize)
|
|
511
484
|
|
|
512
|
-
|
|
485
|
+
# Initialize plot_kwargs if not provided
|
|
486
|
+
if plot_kwargs is None:
|
|
487
|
+
plot_kwargs = {}
|
|
513
488
|
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
489
|
+
# Handle all-scalar datasets (where all variables have no dimensions)
|
|
490
|
+
# This occurs when all variables are scalar values with dims=()
|
|
491
|
+
if all(len(data[var].dims) == 0 for var in data.data_vars):
|
|
492
|
+
# Create simple bar/line plot with variable names as x-axis
|
|
493
|
+
variables = list(data.data_vars.keys())
|
|
494
|
+
values = [float(data[var].values) for var in data.data_vars]
|
|
518
495
|
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
496
|
+
# Resolve colors
|
|
497
|
+
color_discrete_map = process_colors(
|
|
498
|
+
colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale
|
|
499
|
+
)
|
|
500
|
+
colors_list = [color_discrete_map.get(var, '#808080') for var in variables]
|
|
501
|
+
|
|
502
|
+
# Create plot based on mode
|
|
503
|
+
if mode == 'stacked_bar':
|
|
504
|
+
ax.bar(variables, values, color=colors_list, **plot_kwargs)
|
|
505
|
+
elif mode == 'line':
|
|
506
|
+
ax.plot(
|
|
507
|
+
variables,
|
|
508
|
+
values,
|
|
509
|
+
marker='o',
|
|
510
|
+
color=colors_list[0] if len(set(colors_list)) == 1 else None,
|
|
511
|
+
**plot_kwargs,
|
|
512
|
+
)
|
|
513
|
+
# If different colors, plot each point separately
|
|
514
|
+
if len(set(colors_list)) > 1:
|
|
515
|
+
ax.clear()
|
|
516
|
+
for i, (var, val) in enumerate(zip(variables, values, strict=False)):
|
|
517
|
+
ax.plot([i], [val], marker='o', color=colors_list[i], label=var, **plot_kwargs)
|
|
518
|
+
ax.set_xticks(range(len(variables)))
|
|
519
|
+
ax.set_xticklabels(variables)
|
|
520
|
+
|
|
521
|
+
ax.set_xlabel(xlabel, ha='center')
|
|
522
|
+
ax.set_ylabel(ylabel, va='center')
|
|
523
|
+
ax.set_title(title)
|
|
524
|
+
ax.grid(color='lightgrey', linestyle='-', linewidth=0.5, axis='y')
|
|
525
|
+
fig.tight_layout()
|
|
526
|
+
|
|
527
|
+
return fig, ax
|
|
528
|
+
|
|
529
|
+
# Resolve colors first (includes validation)
|
|
530
|
+
color_discrete_map = process_colors(
|
|
531
|
+
colors, list(data.data_vars), default_colorscale=CONFIG.Plotting.default_qualitative_colorscale
|
|
532
|
+
)
|
|
533
|
+
|
|
534
|
+
# Convert Dataset to DataFrame for matplotlib plotting (naturally wide-form)
|
|
535
|
+
df = data.to_dataframe()
|
|
536
|
+
|
|
537
|
+
# Get colors in column order
|
|
538
|
+
processed_colors = [color_discrete_map.get(str(col), '#808080') for col in df.columns]
|
|
539
|
+
|
|
540
|
+
if mode == 'stacked_bar':
|
|
541
|
+
cumulative_positive = np.zeros(len(df))
|
|
542
|
+
cumulative_negative = np.zeros(len(df))
|
|
543
|
+
|
|
544
|
+
# Robust bar width: handle datetime-like, numeric, and single-point indexes
|
|
545
|
+
if len(df.index) > 1:
|
|
546
|
+
delta = pd.Index(df.index).to_series().diff().dropna().min()
|
|
547
|
+
if hasattr(delta, 'total_seconds'): # datetime-like
|
|
548
|
+
width = delta.total_seconds() / 86400.0 # Matplotlib date units = days
|
|
549
|
+
else:
|
|
550
|
+
width = float(delta)
|
|
551
|
+
else:
|
|
552
|
+
width = 0.8 # reasonable default for a single bar
|
|
553
|
+
|
|
554
|
+
for i, column in enumerate(df.columns):
|
|
555
|
+
# Fill NaNs to avoid breaking stacking math
|
|
556
|
+
series = df[column].fillna(0)
|
|
557
|
+
positive_values = np.clip(series, 0, None) # Keep only positive values
|
|
558
|
+
negative_values = np.clip(series, None, 0) # Keep only negative values
|
|
522
559
|
# Plot positive bars
|
|
523
560
|
ax.bar(
|
|
524
|
-
|
|
561
|
+
df.index,
|
|
525
562
|
positive_values,
|
|
526
563
|
bottom=cumulative_positive,
|
|
527
564
|
color=processed_colors[i],
|
|
528
565
|
label=column,
|
|
529
566
|
width=width,
|
|
530
567
|
align='center',
|
|
568
|
+
**plot_kwargs,
|
|
531
569
|
)
|
|
532
570
|
cumulative_positive += positive_values.values
|
|
533
571
|
# Plot negative bars
|
|
534
572
|
ax.bar(
|
|
535
|
-
|
|
573
|
+
df.index,
|
|
536
574
|
negative_values,
|
|
537
575
|
bottom=cumulative_negative,
|
|
538
576
|
color=processed_colors[i],
|
|
539
577
|
label='', # No label for negative bars
|
|
540
578
|
width=width,
|
|
541
579
|
align='center',
|
|
580
|
+
**plot_kwargs,
|
|
542
581
|
)
|
|
543
582
|
cumulative_negative += negative_values.values
|
|
544
583
|
|
|
545
|
-
elif
|
|
546
|
-
for i, column in enumerate(
|
|
547
|
-
ax.step(
|
|
584
|
+
elif mode == 'line':
|
|
585
|
+
for i, column in enumerate(df.columns):
|
|
586
|
+
ax.step(df.index, df[column], where='post', color=processed_colors[i], label=column, **plot_kwargs)
|
|
548
587
|
|
|
549
588
|
# Aesthetics
|
|
550
589
|
ax.set_xlabel(xlabel, ha='center')
|
|
@@ -562,213 +601,110 @@ def with_matplotlib(
|
|
|
562
601
|
return fig, ax
|
|
563
602
|
|
|
564
603
|
|
|
565
|
-
def
|
|
566
|
-
data:
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
Plots a DataFrame as a heatmap using Matplotlib. The columns of the DataFrame will be displayed on the x-axis,
|
|
575
|
-
the index will be displayed on the y-axis, and the values will represent the 'heat' intensity in the plot.
|
|
576
|
-
|
|
577
|
-
Args:
|
|
578
|
-
data: A DataFrame containing the data to be visualized. The index will be used for the y-axis, and columns will be used for the x-axis.
|
|
579
|
-
The values in the DataFrame will be represented as colors in the heatmap.
|
|
580
|
-
color_map: The colormap to use for the heatmap. Default is 'viridis'. Matplotlib supports various colormaps like 'plasma', 'inferno', 'cividis', etc.
|
|
581
|
-
title: The title of the plot.
|
|
582
|
-
xlabel: The label for the x-axis.
|
|
583
|
-
ylabel: The label for the y-axis.
|
|
584
|
-
figsize: The size of the figure to create. Default is (12, 6), which results in a width of 12 inches and a height of 6 inches.
|
|
585
|
-
|
|
586
|
-
Returns:
|
|
587
|
-
A tuple containing the Matplotlib `Figure` and `Axes` objects. The `Figure` contains the overall plot, while the `Axes` is the area
|
|
588
|
-
where the heatmap is drawn. These can be used for further customization or saving the plot to a file.
|
|
589
|
-
|
|
590
|
-
Notes:
|
|
591
|
-
- The y-axis is flipped so that the first row of the DataFrame is displayed at the top of the plot.
|
|
592
|
-
- The color scale is normalized based on the minimum and maximum values in the DataFrame.
|
|
593
|
-
- The x-axis labels (periods) are placed at the top of the plot.
|
|
594
|
-
- The colorbar is added horizontally at the bottom of the plot, with a label.
|
|
595
|
-
"""
|
|
596
|
-
|
|
597
|
-
# Get the min and max values for color normalization
|
|
598
|
-
color_bar_min, color_bar_max = data.min().min(), data.max().max()
|
|
599
|
-
|
|
600
|
-
# Create the heatmap plot
|
|
601
|
-
fig, ax = plt.subplots(figsize=figsize)
|
|
602
|
-
ax.pcolormesh(data.values, cmap=color_map, shading='auto')
|
|
603
|
-
ax.invert_yaxis() # Flip the y-axis to start at the top
|
|
604
|
-
|
|
605
|
-
# Adjust ticks and labels for x and y axes
|
|
606
|
-
ax.set_xticks(np.arange(len(data.columns)) + 0.5)
|
|
607
|
-
ax.set_xticklabels(data.columns, ha='center')
|
|
608
|
-
ax.set_yticks(np.arange(len(data.index)) + 0.5)
|
|
609
|
-
ax.set_yticklabels(data.index, va='center')
|
|
610
|
-
|
|
611
|
-
# Add labels to the axes
|
|
612
|
-
ax.set_xlabel(xlabel, ha='center')
|
|
613
|
-
ax.set_ylabel(ylabel, va='center')
|
|
614
|
-
ax.set_title(title)
|
|
615
|
-
|
|
616
|
-
# Position x-axis labels at the top
|
|
617
|
-
ax.xaxis.set_label_position('top')
|
|
618
|
-
ax.xaxis.set_ticks_position('top')
|
|
619
|
-
|
|
620
|
-
# Add the colorbar
|
|
621
|
-
sm1 = plt.cm.ScalarMappable(cmap=color_map, norm=plt.Normalize(vmin=color_bar_min, vmax=color_bar_max))
|
|
622
|
-
sm1.set_array([])
|
|
623
|
-
fig.colorbar(sm1, ax=ax, pad=0.12, aspect=15, fraction=0.2, orientation='horizontal')
|
|
624
|
-
|
|
625
|
-
fig.tight_layout()
|
|
626
|
-
|
|
627
|
-
return fig, ax
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
def heat_map_plotly(
|
|
631
|
-
data: pd.DataFrame,
|
|
632
|
-
color_map: str = 'viridis',
|
|
633
|
-
title: str = '',
|
|
634
|
-
xlabel: str = 'Period',
|
|
635
|
-
ylabel: str = 'Step',
|
|
636
|
-
categorical_labels: bool = True,
|
|
637
|
-
) -> go.Figure:
|
|
604
|
+
def reshape_data_for_heatmap(
|
|
605
|
+
data: xr.DataArray,
|
|
606
|
+
reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
|
|
607
|
+
| Literal['auto']
|
|
608
|
+
| None = 'auto',
|
|
609
|
+
facet_by: str | list[str] | None = None,
|
|
610
|
+
animate_by: str | None = None,
|
|
611
|
+
fill: Literal['ffill', 'bfill'] | None = 'ffill',
|
|
612
|
+
) -> xr.DataArray:
|
|
638
613
|
"""
|
|
639
|
-
|
|
640
|
-
and the index will be displayed on the y-axis. The values in the DataFrame will represent the 'heat' in the plot.
|
|
614
|
+
Reshape data for heatmap visualization, handling time dimension intelligently.
|
|
641
615
|
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
title: The title of the heatmap. Default is an empty string.
|
|
647
|
-
xlabel: The label for the x-axis. Default is 'Period'.
|
|
648
|
-
ylabel: The label for the y-axis. Default is 'Step'.
|
|
649
|
-
categorical_labels: If True, the x and y axes are treated as categorical data (i.e., the index and columns will not be interpreted as continuous data).
|
|
650
|
-
Default is True. If False, the axes are treated as continuous, which may be useful for time series or numeric data.
|
|
616
|
+
This function decides whether to reshape the 'time' dimension based on the reshape_time parameter:
|
|
617
|
+
- 'auto': Automatically reshapes if only 'time' dimension would remain for heatmap
|
|
618
|
+
- Tuple: Explicitly reshapes time with specified parameters
|
|
619
|
+
- None: No reshaping (returns data as-is)
|
|
651
620
|
|
|
652
|
-
|
|
653
|
-
A Plotly figure object containing the heatmap. This can be further customized and saved
|
|
654
|
-
or displayed using `fig.show()`.
|
|
655
|
-
|
|
656
|
-
Notes:
|
|
657
|
-
The color bar is automatically scaled to the minimum and maximum values in the data.
|
|
658
|
-
The y-axis is reversed to display the first row at the top.
|
|
659
|
-
"""
|
|
660
|
-
|
|
661
|
-
color_bar_min, color_bar_max = data.min().min(), data.max().max() # Min and max values for color scaling
|
|
662
|
-
# Define the figure
|
|
663
|
-
fig = go.Figure(
|
|
664
|
-
data=go.Heatmap(
|
|
665
|
-
z=data.values,
|
|
666
|
-
x=data.columns,
|
|
667
|
-
y=data.index,
|
|
668
|
-
colorscale=color_map,
|
|
669
|
-
zmin=color_bar_min,
|
|
670
|
-
zmax=color_bar_max,
|
|
671
|
-
colorbar=dict(
|
|
672
|
-
title=dict(text='Color Bar Label', side='right'),
|
|
673
|
-
orientation='h',
|
|
674
|
-
xref='container',
|
|
675
|
-
yref='container',
|
|
676
|
-
len=0.8, # Color bar length relative to plot
|
|
677
|
-
x=0.5,
|
|
678
|
-
y=0.1,
|
|
679
|
-
),
|
|
680
|
-
)
|
|
681
|
-
)
|
|
682
|
-
|
|
683
|
-
# Set axis labels and style
|
|
684
|
-
fig.update_layout(
|
|
685
|
-
title=title,
|
|
686
|
-
xaxis=dict(title=xlabel, side='top', type='category' if categorical_labels else None),
|
|
687
|
-
yaxis=dict(title=ylabel, autorange='reversed', type='category' if categorical_labels else None),
|
|
688
|
-
)
|
|
689
|
-
|
|
690
|
-
return fig
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
def reshape_to_2d(data_1d: np.ndarray, nr_of_steps_per_column: int) -> np.ndarray:
|
|
694
|
-
"""
|
|
695
|
-
Reshapes a 1D numpy array into a 2D array suitable for plotting as a colormap.
|
|
696
|
-
|
|
697
|
-
The reshaped array will have the number of rows corresponding to the steps per column
|
|
698
|
-
(e.g., 24 hours per day) and columns representing time periods (e.g., days or months).
|
|
621
|
+
All non-time dimensions are preserved during reshaping.
|
|
699
622
|
|
|
700
623
|
Args:
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
624
|
+
data: DataArray to reshape for heatmap visualization.
|
|
625
|
+
reshape_time: Reshaping configuration:
|
|
626
|
+
- 'auto' (default): Auto-reshape if needed based on facet_by/animate_by
|
|
627
|
+
- Tuple (timeframes, timesteps_per_frame): Explicit time reshaping
|
|
628
|
+
- None: No reshaping
|
|
629
|
+
facet_by: Dimension(s) used for faceting (used in 'auto' decision).
|
|
630
|
+
animate_by: Dimension used for animation (used in 'auto' decision).
|
|
631
|
+
fill: Method to fill missing values: 'ffill' or 'bfill'. Default is 'ffill'.
|
|
704
632
|
|
|
705
633
|
Returns:
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
"""
|
|
709
|
-
|
|
710
|
-
# Step 1: Ensure the input is a 1D array.
|
|
711
|
-
if data_1d.ndim != 1:
|
|
712
|
-
raise ValueError('Input must be a 1D array')
|
|
634
|
+
Reshaped DataArray. If time reshaping is applied, 'time' dimension is replaced
|
|
635
|
+
by 'timestep' and 'timeframe'. All other dimensions are preserved.
|
|
713
636
|
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
data_1d = data_1d.astype(np.float64)
|
|
717
|
-
|
|
718
|
-
# Step 3: Calculate the number of columns required
|
|
719
|
-
total_steps = len(data_1d)
|
|
720
|
-
cols = len(data_1d) // nr_of_steps_per_column # Base number of columns
|
|
721
|
-
|
|
722
|
-
# If there's a remainder, add an extra column to hold the remaining values
|
|
723
|
-
if total_steps % nr_of_steps_per_column != 0:
|
|
724
|
-
cols += 1
|
|
637
|
+
Examples:
|
|
638
|
+
Auto-reshaping:
|
|
725
639
|
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
640
|
+
```python
|
|
641
|
+
# Will auto-reshape because only 'time' remains after faceting/animation
|
|
642
|
+
data = reshape_data_for_heatmap(data, reshape_time='auto', facet_by='scenario', animate_by='period')
|
|
643
|
+
```
|
|
730
644
|
|
|
731
|
-
|
|
732
|
-
data_2d = padded_data.reshape(cols, nr_of_steps_per_column)
|
|
645
|
+
Explicit reshaping:
|
|
733
646
|
|
|
734
|
-
|
|
647
|
+
```python
|
|
648
|
+
# Explicitly reshape to daily pattern
|
|
649
|
+
data = reshape_data_for_heatmap(data, reshape_time=('D', 'h'))
|
|
650
|
+
```
|
|
735
651
|
|
|
652
|
+
No reshaping:
|
|
736
653
|
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
fill: Literal['ffill', 'bfill'] | None = None,
|
|
742
|
-
) -> pd.DataFrame:
|
|
654
|
+
```python
|
|
655
|
+
# Keep data as-is
|
|
656
|
+
data = reshape_data_for_heatmap(data, reshape_time=None)
|
|
657
|
+
```
|
|
743
658
|
"""
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
659
|
+
# If no time dimension, return data as-is
|
|
660
|
+
if 'time' not in data.dims:
|
|
661
|
+
return data
|
|
662
|
+
|
|
663
|
+
# Handle None (disabled) - return data as-is
|
|
664
|
+
if reshape_time is None:
|
|
665
|
+
return data
|
|
666
|
+
|
|
667
|
+
# Determine timeframes and timesteps_per_frame based on reshape_time parameter
|
|
668
|
+
if reshape_time == 'auto':
|
|
669
|
+
# Check if we need automatic time reshaping
|
|
670
|
+
facet_dims_used = []
|
|
671
|
+
if facet_by:
|
|
672
|
+
facet_dims_used = [facet_by] if isinstance(facet_by, str) else list(facet_by)
|
|
673
|
+
if animate_by:
|
|
674
|
+
facet_dims_used.append(animate_by)
|
|
675
|
+
|
|
676
|
+
# Get dimensions that would remain for heatmap
|
|
677
|
+
potential_heatmap_dims = [dim for dim in data.dims if dim not in facet_dims_used]
|
|
678
|
+
|
|
679
|
+
# Auto-reshape if only 'time' dimension remains
|
|
680
|
+
if len(potential_heatmap_dims) == 1 and potential_heatmap_dims[0] == 'time':
|
|
681
|
+
logger.debug(
|
|
682
|
+
"Auto-applying time reshaping: Only 'time' dimension remains after faceting/animation. "
|
|
683
|
+
"Using default timeframes='D' and timesteps_per_frame='h'. "
|
|
684
|
+
"To customize, use reshape_time=('D', 'h') or disable with reshape_time=None."
|
|
685
|
+
)
|
|
686
|
+
timeframes, timesteps_per_frame = 'D', 'h'
|
|
687
|
+
else:
|
|
688
|
+
# No reshaping needed
|
|
689
|
+
return data
|
|
690
|
+
elif isinstance(reshape_time, tuple):
|
|
691
|
+
# Explicit reshaping
|
|
692
|
+
timeframes, timesteps_per_frame = reshape_time
|
|
693
|
+
else:
|
|
694
|
+
raise ValueError(f"reshape_time must be 'auto', a tuple like ('D', 'h'), or None. Got: {reshape_time}")
|
|
755
695
|
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
"""
|
|
760
|
-
assert pd.api.types.is_datetime64_any_dtype(df.index), (
|
|
761
|
-
'The index of the DataFrame must be datetime to transform it properly for a heatmap plot'
|
|
762
|
-
)
|
|
696
|
+
# Validate that time is datetime
|
|
697
|
+
if not np.issubdtype(data.coords['time'].dtype, np.datetime64):
|
|
698
|
+
raise ValueError(f'Time dimension must be datetime-based, got {data.coords["time"].dtype}')
|
|
763
699
|
|
|
764
|
-
# Define formats for different combinations
|
|
700
|
+
# Define formats for different combinations
|
|
765
701
|
formats = {
|
|
766
702
|
('YS', 'W'): ('%Y', '%W'),
|
|
767
703
|
('YS', 'D'): ('%Y', '%j'), # day of year
|
|
768
704
|
('YS', 'h'): ('%Y', '%j %H:00'),
|
|
769
705
|
('MS', 'D'): ('%Y-%m', '%d'), # day of month
|
|
770
706
|
('MS', 'h'): ('%Y-%m', '%d %H:00'),
|
|
771
|
-
('W', 'D'): ('%Y-w%W', '%w_%A'), # week and day of week
|
|
707
|
+
('W', 'D'): ('%Y-w%W', '%w_%A'), # week and day of week
|
|
772
708
|
('W', 'h'): ('%Y-w%W', '%w_%A %H:00'),
|
|
773
709
|
('D', 'h'): ('%Y-%m-%d', '%H:00'), # Day and hour
|
|
774
710
|
('D', '15min'): ('%Y-%m-%d', '%H:%M'), # Day and minute
|
|
@@ -776,43 +712,64 @@ def heat_map_data_from_df(
|
|
|
776
712
|
('h', 'min'): ('%Y-%m-%d %H:00', '%M'), # minute of hour
|
|
777
713
|
}
|
|
778
714
|
|
|
779
|
-
|
|
780
|
-
raise ValueError('DataFrame is empty.')
|
|
781
|
-
diffs = df.index.to_series().diff().dropna()
|
|
782
|
-
minimum_time_diff_in_min = diffs.min().total_seconds() / 60
|
|
783
|
-
time_intervals = {'min': 1, '15min': 15, 'h': 60, 'D': 24 * 60, 'W': 7 * 24 * 60}
|
|
784
|
-
if time_intervals[steps_per_period] > minimum_time_diff_in_min:
|
|
785
|
-
logger.error(
|
|
786
|
-
f'To compute the heatmap, the data was aggregated from {minimum_time_diff_in_min:.2f} min to '
|
|
787
|
-
f'{time_intervals[steps_per_period]:.2f} min. Mean values are displayed.'
|
|
788
|
-
)
|
|
789
|
-
|
|
790
|
-
# Select the format based on the `periods` and `steps_per_period` combination
|
|
791
|
-
format_pair = (periods, steps_per_period)
|
|
715
|
+
format_pair = (timeframes, timesteps_per_frame)
|
|
792
716
|
if format_pair not in formats:
|
|
793
717
|
raise ValueError(f'{format_pair} is not a valid format. Choose from {list(formats.keys())}')
|
|
794
718
|
period_format, step_format = formats[format_pair]
|
|
795
719
|
|
|
796
|
-
|
|
720
|
+
# Check if resampling is needed
|
|
721
|
+
if data.sizes['time'] > 1:
|
|
722
|
+
# Use NumPy for more efficient timedelta computation
|
|
723
|
+
time_values = data.coords['time'].values # Already numpy datetime64[ns]
|
|
724
|
+
# Calculate differences and convert to minutes
|
|
725
|
+
time_diffs = np.diff(time_values).astype('timedelta64[s]').astype(float) / 60.0
|
|
726
|
+
if time_diffs.size > 0:
|
|
727
|
+
min_time_diff_min = np.nanmin(time_diffs)
|
|
728
|
+
time_intervals = {'min': 1, '15min': 15, 'h': 60, 'D': 24 * 60, 'W': 7 * 24 * 60}
|
|
729
|
+
if time_intervals[timesteps_per_frame] > min_time_diff_min:
|
|
730
|
+
logger.warning(
|
|
731
|
+
f'Resampling data from {min_time_diff_min:.2f} min to '
|
|
732
|
+
f'{time_intervals[timesteps_per_frame]:.2f} min. Mean values are displayed.'
|
|
733
|
+
)
|
|
797
734
|
|
|
798
|
-
|
|
735
|
+
# Resample along time dimension
|
|
736
|
+
resampled = data.resample(time=timesteps_per_frame).mean()
|
|
799
737
|
|
|
800
|
-
|
|
801
|
-
|
|
738
|
+
# Apply fill if specified
|
|
739
|
+
if fill == 'ffill':
|
|
740
|
+
resampled = resampled.ffill(dim='time')
|
|
802
741
|
elif fill == 'bfill':
|
|
803
|
-
|
|
742
|
+
resampled = resampled.bfill(dim='time')
|
|
743
|
+
|
|
744
|
+
# Create period and step labels
|
|
745
|
+
time_values = pd.to_datetime(resampled.coords['time'].values)
|
|
746
|
+
period_labels = time_values.strftime(period_format)
|
|
747
|
+
step_labels = time_values.strftime(step_format)
|
|
748
|
+
|
|
749
|
+
# Handle special case for weekly day format
|
|
750
|
+
if '%w_%A' in step_format:
|
|
751
|
+
step_labels = pd.Series(step_labels).replace('0_Sunday', '7_Sunday').values
|
|
752
|
+
|
|
753
|
+
# Add period and step as coordinates
|
|
754
|
+
resampled = resampled.assign_coords(
|
|
755
|
+
{
|
|
756
|
+
'timeframe': ('time', period_labels),
|
|
757
|
+
'timestep': ('time', step_labels),
|
|
758
|
+
}
|
|
759
|
+
)
|
|
804
760
|
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
761
|
+
# Convert to multi-index and unstack
|
|
762
|
+
resampled = resampled.set_index(time=['timeframe', 'timestep'])
|
|
763
|
+
result = resampled.unstack('time')
|
|
764
|
+
|
|
765
|
+
# Ensure timestep and timeframe come first in dimension order
|
|
766
|
+
# Get other dimensions
|
|
767
|
+
other_dims = [d for d in result.dims if d not in ['timestep', 'timeframe']]
|
|
811
768
|
|
|
812
|
-
#
|
|
813
|
-
|
|
769
|
+
# Reorder: timestep, timeframe, then other dimensions
|
|
770
|
+
result = result.transpose('timestep', 'timeframe', *other_dims)
|
|
814
771
|
|
|
815
|
-
return
|
|
772
|
+
return result
|
|
816
773
|
|
|
817
774
|
|
|
818
775
|
def plot_network(
|
|
@@ -899,518 +856,704 @@ def plot_network(
|
|
|
899
856
|
)
|
|
900
857
|
|
|
901
858
|
|
|
902
|
-
def
|
|
903
|
-
data: pd.DataFrame,
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
legend_title: str = '',
|
|
907
|
-
hole: float = 0.0,
|
|
908
|
-
fig: go.Figure | None = None,
|
|
909
|
-
) -> go.Figure:
|
|
859
|
+
def preprocess_data_for_pie(
|
|
860
|
+
data: xr.Dataset | pd.DataFrame | pd.Series,
|
|
861
|
+
lower_percentage_threshold: float = 5.0,
|
|
862
|
+
) -> pd.Series:
|
|
910
863
|
"""
|
|
911
|
-
|
|
864
|
+
Preprocess data for pie chart display.
|
|
865
|
+
|
|
866
|
+
Groups items that are individually below the threshold percentage into an "Other" category.
|
|
867
|
+
Converts various input types to a pandas Series for uniform handling.
|
|
912
868
|
|
|
913
869
|
Args:
|
|
914
|
-
data:
|
|
915
|
-
|
|
916
|
-
colors: Color specification, can be:
|
|
917
|
-
- A string with a colorscale name (e.g., 'viridis', 'plasma')
|
|
918
|
-
- A list of color strings (e.g., ['#ff0000', '#00ff00'])
|
|
919
|
-
- A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'})
|
|
920
|
-
title: The title of the plot.
|
|
921
|
-
legend_title: The title for the legend.
|
|
922
|
-
hole: Size of the hole in the center for creating a donut chart (0.0 to 1.0).
|
|
923
|
-
fig: A Plotly figure object to plot on. If not provided, a new figure will be created.
|
|
870
|
+
data: Input data (xarray Dataset, DataFrame, or Series)
|
|
871
|
+
lower_percentage_threshold: Percentage threshold - items below this are grouped into "Other"
|
|
924
872
|
|
|
925
873
|
Returns:
|
|
926
|
-
|
|
874
|
+
Processed pandas Series with small items grouped into "Other"
|
|
875
|
+
"""
|
|
876
|
+
# Convert to Series
|
|
877
|
+
if isinstance(data, xr.Dataset):
|
|
878
|
+
# Sum all dimensions for each variable to get total values
|
|
879
|
+
values = {}
|
|
880
|
+
for var in data.data_vars:
|
|
881
|
+
var_data = data[var]
|
|
882
|
+
if len(var_data.dims) > 0:
|
|
883
|
+
total_value = float(var_data.sum().item())
|
|
884
|
+
else:
|
|
885
|
+
total_value = float(var_data.item())
|
|
927
886
|
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
- By default, the sum of all columns is used for the pie chart. For time series data, consider preprocessing.
|
|
887
|
+
# Handle negative values
|
|
888
|
+
if total_value < 0:
|
|
889
|
+
logger.warning(f'Negative value for {var}: {total_value}. Using absolute value.')
|
|
890
|
+
total_value = abs(total_value)
|
|
933
891
|
|
|
934
|
-
|
|
935
|
-
if data.empty:
|
|
936
|
-
logger.error('Empty DataFrame provided for pie chart. Returning empty figure.')
|
|
937
|
-
return go.Figure()
|
|
892
|
+
values[var] = total_value
|
|
938
893
|
|
|
939
|
-
|
|
940
|
-
data_copy = data.copy()
|
|
894
|
+
series = pd.Series(values)
|
|
941
895
|
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
896
|
+
elif isinstance(data, pd.DataFrame):
|
|
897
|
+
# Sum across all columns if DataFrame
|
|
898
|
+
series = data.sum(axis=0)
|
|
899
|
+
# Handle negative values
|
|
900
|
+
negative_mask = series < 0
|
|
901
|
+
if negative_mask.any():
|
|
902
|
+
logger.warning(f'Negative values found: {series[negative_mask].to_dict()}. Using absolute values.')
|
|
903
|
+
series = series.abs()
|
|
946
904
|
|
|
947
|
-
#
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
|
|
905
|
+
else: # pd.Series
|
|
906
|
+
series = data.copy()
|
|
907
|
+
# Handle negative values
|
|
908
|
+
negative_mask = series < 0
|
|
909
|
+
if negative_mask.any():
|
|
910
|
+
logger.warning(f'Negative values found: {series[negative_mask].to_dict()}. Using absolute values.')
|
|
911
|
+
series = series.abs()
|
|
952
912
|
|
|
953
|
-
#
|
|
954
|
-
|
|
955
|
-
values = data_sum.values.tolist()
|
|
913
|
+
# Only keep positive values
|
|
914
|
+
series = series[series > 0]
|
|
956
915
|
|
|
957
|
-
|
|
958
|
-
|
|
916
|
+
if series.empty or lower_percentage_threshold <= 0:
|
|
917
|
+
return series
|
|
959
918
|
|
|
960
|
-
#
|
|
961
|
-
|
|
919
|
+
# Calculate percentages
|
|
920
|
+
total = series.sum()
|
|
921
|
+
percentages = (series / total) * 100
|
|
962
922
|
|
|
963
|
-
#
|
|
964
|
-
|
|
965
|
-
|
|
966
|
-
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
|
|
972
|
-
|
|
923
|
+
# Find items below and above threshold
|
|
924
|
+
below_threshold = series[percentages < lower_percentage_threshold]
|
|
925
|
+
above_threshold = series[percentages >= lower_percentage_threshold]
|
|
926
|
+
|
|
927
|
+
# Only group if there are at least 2 items below threshold
|
|
928
|
+
if len(below_threshold) > 1:
|
|
929
|
+
# Create new series with items above threshold + "Other"
|
|
930
|
+
result = above_threshold.copy()
|
|
931
|
+
result['Other'] = below_threshold.sum()
|
|
932
|
+
return result
|
|
933
|
+
|
|
934
|
+
return series
|
|
935
|
+
|
|
936
|
+
|
|
937
|
+
def dual_pie_with_plotly(
|
|
938
|
+
data_left: xr.Dataset | pd.DataFrame | pd.Series,
|
|
939
|
+
data_right: xr.Dataset | pd.DataFrame | pd.Series,
|
|
940
|
+
colors: ColorType | None = None,
|
|
941
|
+
title: str = '',
|
|
942
|
+
subtitles: tuple[str, str] = ('Left Chart', 'Right Chart'),
|
|
943
|
+
legend_title: str = '',
|
|
944
|
+
hole: float = 0.2,
|
|
945
|
+
lower_percentage_group: float = 5.0,
|
|
946
|
+
text_info: str = 'percent+label',
|
|
947
|
+
text_position: str = 'inside',
|
|
948
|
+
hover_template: str = '%{label}: %{value} (%{percent})',
|
|
949
|
+
) -> go.Figure:
|
|
950
|
+
"""
|
|
951
|
+
Create two pie charts side by side with Plotly.
|
|
952
|
+
|
|
953
|
+
Args:
|
|
954
|
+
data_left: Data for the left pie chart. Variables are summed across all dimensions.
|
|
955
|
+
data_right: Data for the right pie chart. Variables are summed across all dimensions.
|
|
956
|
+
colors: Color specification (colorscale name, list of colors, or dict mapping)
|
|
957
|
+
title: The main title of the plot.
|
|
958
|
+
subtitles: Tuple containing the subtitles for (left, right) charts.
|
|
959
|
+
legend_title: The title for the legend.
|
|
960
|
+
hole: Size of the hole in the center for creating donut charts (0.0 to 1.0).
|
|
961
|
+
lower_percentage_group: Group segments whose cumulative share is below this percentage (0–100) into "Other".
|
|
962
|
+
hover_template: Template for hover text. Use %{label}, %{value}, %{percent}.
|
|
963
|
+
text_info: What to show on pie segments: 'label', 'percent', 'value', 'label+percent',
|
|
964
|
+
'label+value', 'percent+value', 'label+percent+value', or 'none'.
|
|
965
|
+
text_position: Position of text: 'inside', 'outside', 'auto', or 'none'.
|
|
966
|
+
|
|
967
|
+
Returns:
|
|
968
|
+
Plotly Figure object
|
|
969
|
+
"""
|
|
970
|
+
if colors is None:
|
|
971
|
+
colors = CONFIG.Plotting.default_qualitative_colorscale
|
|
972
|
+
|
|
973
|
+
# Preprocess data to Series
|
|
974
|
+
left_series = preprocess_data_for_pie(data_left, lower_percentage_group)
|
|
975
|
+
right_series = preprocess_data_for_pie(data_right, lower_percentage_group)
|
|
976
|
+
|
|
977
|
+
# Extract labels and values
|
|
978
|
+
left_labels = left_series.index.tolist()
|
|
979
|
+
left_values = left_series.values.tolist()
|
|
980
|
+
|
|
981
|
+
right_labels = right_series.index.tolist()
|
|
982
|
+
right_values = right_series.values.tolist()
|
|
983
|
+
|
|
984
|
+
# Get all unique labels for consistent coloring
|
|
985
|
+
all_labels = sorted(set(left_labels) | set(right_labels))
|
|
986
|
+
|
|
987
|
+
# Create color map
|
|
988
|
+
color_map = process_colors(colors, all_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale)
|
|
989
|
+
|
|
990
|
+
# Create figure
|
|
991
|
+
fig = go.Figure()
|
|
992
|
+
|
|
993
|
+
# Add left pie
|
|
994
|
+
if left_labels:
|
|
995
|
+
fig.add_trace(
|
|
996
|
+
go.Pie(
|
|
997
|
+
labels=left_labels,
|
|
998
|
+
values=left_values,
|
|
999
|
+
name=subtitles[0],
|
|
1000
|
+
marker=dict(colors=[color_map.get(label, '#636EFA') for label in left_labels]),
|
|
1001
|
+
hole=hole,
|
|
1002
|
+
textinfo=text_info,
|
|
1003
|
+
textposition=text_position,
|
|
1004
|
+
hovertemplate=hover_template,
|
|
1005
|
+
domain=dict(x=[0, 0.48]),
|
|
1006
|
+
)
|
|
973
1007
|
)
|
|
974
|
-
)
|
|
975
1008
|
|
|
976
|
-
#
|
|
1009
|
+
# Add right pie
|
|
1010
|
+
if right_labels:
|
|
1011
|
+
fig.add_trace(
|
|
1012
|
+
go.Pie(
|
|
1013
|
+
labels=right_labels,
|
|
1014
|
+
values=right_values,
|
|
1015
|
+
name=subtitles[1],
|
|
1016
|
+
marker=dict(colors=[color_map.get(label, '#636EFA') for label in right_labels]),
|
|
1017
|
+
hole=hole,
|
|
1018
|
+
textinfo=text_info,
|
|
1019
|
+
textposition=text_position,
|
|
1020
|
+
hovertemplate=hover_template,
|
|
1021
|
+
domain=dict(x=[0.52, 1]),
|
|
1022
|
+
)
|
|
1023
|
+
)
|
|
1024
|
+
|
|
1025
|
+
# Update layout
|
|
977
1026
|
fig.update_layout(
|
|
978
1027
|
title=title,
|
|
979
1028
|
legend_title=legend_title,
|
|
980
|
-
|
|
981
|
-
paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background
|
|
982
|
-
font=dict(size=14), # Increase font size for better readability
|
|
1029
|
+
margin=dict(t=80, b=50, l=30, r=30),
|
|
983
1030
|
)
|
|
984
1031
|
|
|
985
1032
|
return fig
|
|
986
1033
|
|
|
987
1034
|
|
|
988
|
-
def
|
|
989
|
-
|
|
990
|
-
|
|
1035
|
+
def dual_pie_with_matplotlib(
|
|
1036
|
+
data_left: xr.Dataset | pd.DataFrame | pd.Series,
|
|
1037
|
+
data_right: xr.Dataset | pd.DataFrame | pd.Series,
|
|
1038
|
+
colors: ColorType | None = None,
|
|
991
1039
|
title: str = '',
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
) -> tuple[plt.Figure, plt.Axes]:
|
|
1040
|
+
subtitles: tuple[str, str] = ('Left Chart', 'Right Chart'),
|
|
1041
|
+
legend_title: str = '',
|
|
1042
|
+
hole: float = 0.2,
|
|
1043
|
+
lower_percentage_group: float = 5.0,
|
|
1044
|
+
figsize: tuple[int, int] = (14, 7),
|
|
1045
|
+
) -> tuple[plt.Figure, list[plt.Axes]]:
|
|
998
1046
|
"""
|
|
999
|
-
Create
|
|
1047
|
+
Create two pie charts side by side with Matplotlib.
|
|
1000
1048
|
|
|
1001
1049
|
Args:
|
|
1002
|
-
|
|
1003
|
-
|
|
1004
|
-
colors: Color specification,
|
|
1005
|
-
|
|
1006
|
-
|
|
1007
|
-
- A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'})
|
|
1008
|
-
title: The title of the plot.
|
|
1050
|
+
data_left: Data for the left pie chart.
|
|
1051
|
+
data_right: Data for the right pie chart.
|
|
1052
|
+
colors: Color specification (colorscale name, list of colors, or dict mapping)
|
|
1053
|
+
title: The main title of the plot.
|
|
1054
|
+
subtitles: Tuple containing the subtitles for (left, right) charts.
|
|
1009
1055
|
legend_title: The title for the legend.
|
|
1010
|
-
hole: Size of the hole in the center for creating
|
|
1056
|
+
hole: Size of the hole in the center for creating donut charts (0.0 to 1.0).
|
|
1057
|
+
lower_percentage_group: Whether to group small segments (below percentage) into an "Other" category.
|
|
1011
1058
|
figsize: The size of the figure (width, height) in inches.
|
|
1012
|
-
fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created.
|
|
1013
|
-
ax: A Matplotlib axes object to plot on. If not provided, a new axes will be created.
|
|
1014
1059
|
|
|
1015
1060
|
Returns:
|
|
1016
|
-
|
|
1061
|
+
Tuple of (Figure, list of Axes)
|
|
1062
|
+
"""
|
|
1063
|
+
if colors is None:
|
|
1064
|
+
colors = CONFIG.Plotting.default_qualitative_colorscale
|
|
1017
1065
|
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
for better readability.
|
|
1022
|
-
- By default, the sum of all columns is used for the pie chart. For time series data, consider preprocessing.
|
|
1066
|
+
# Preprocess data to Series
|
|
1067
|
+
left_series = preprocess_data_for_pie(data_left, lower_percentage_group)
|
|
1068
|
+
right_series = preprocess_data_for_pie(data_right, lower_percentage_group)
|
|
1023
1069
|
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
if fig is None or ax is None:
|
|
1028
|
-
fig, ax = plt.subplots(figsize=figsize)
|
|
1029
|
-
return fig, ax
|
|
1070
|
+
# Extract labels and values
|
|
1071
|
+
left_labels = left_series.index.tolist()
|
|
1072
|
+
left_values = left_series.values.tolist()
|
|
1030
1073
|
|
|
1031
|
-
|
|
1032
|
-
|
|
1074
|
+
right_labels = right_series.index.tolist()
|
|
1075
|
+
right_values = right_series.values.tolist()
|
|
1033
1076
|
|
|
1034
|
-
#
|
|
1035
|
-
|
|
1036
|
-
logger.error('Negative values detected in data. Using absolute values for pie chart.')
|
|
1037
|
-
data_copy = data_copy.abs()
|
|
1077
|
+
# Get all unique labels for consistent coloring
|
|
1078
|
+
all_labels = sorted(set(left_labels) | set(right_labels))
|
|
1038
1079
|
|
|
1039
|
-
#
|
|
1040
|
-
|
|
1041
|
-
data_sum = data_copy.sum()
|
|
1042
|
-
else:
|
|
1043
|
-
data_sum = data_copy.iloc[0]
|
|
1080
|
+
# Create color map (process_colors always returns a dict)
|
|
1081
|
+
color_map = process_colors(colors, all_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale)
|
|
1044
1082
|
|
|
1045
|
-
#
|
|
1046
|
-
|
|
1047
|
-
values = data_sum.values.tolist()
|
|
1083
|
+
# Create figure
|
|
1084
|
+
fig, axes = plt.subplots(1, 2, figsize=figsize)
|
|
1048
1085
|
|
|
1049
|
-
|
|
1050
|
-
|
|
1086
|
+
def draw_pie(ax, labels, values, subtitle):
|
|
1087
|
+
"""Draw a single pie chart."""
|
|
1088
|
+
if not labels:
|
|
1089
|
+
ax.set_title(subtitle)
|
|
1090
|
+
ax.axis('off')
|
|
1091
|
+
return
|
|
1051
1092
|
|
|
1052
|
-
|
|
1053
|
-
if fig is None or ax is None:
|
|
1054
|
-
fig, ax = plt.subplots(figsize=figsize)
|
|
1093
|
+
chart_colors = [color_map[label] for label in labels]
|
|
1055
1094
|
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1095
|
+
# Draw pie
|
|
1096
|
+
wedges, texts, autotexts = ax.pie(
|
|
1097
|
+
values,
|
|
1098
|
+
labels=labels,
|
|
1099
|
+
colors=chart_colors,
|
|
1100
|
+
autopct='%1.1f%%',
|
|
1101
|
+
startangle=90,
|
|
1102
|
+
wedgeprops=dict(width=1 - hole) if hole > 0 else None,
|
|
1103
|
+
)
|
|
1104
|
+
|
|
1105
|
+
# Style text
|
|
1106
|
+
for autotext in autotexts:
|
|
1107
|
+
autotext.set_fontsize(10)
|
|
1108
|
+
autotext.set_color('white')
|
|
1109
|
+
autotext.set_weight('bold')
|
|
1066
1110
|
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
|
|
1072
|
-
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
for wedge in wedges:
|
|
1076
|
-
wedge.set_width(wedge_width)
|
|
1077
|
-
|
|
1078
|
-
# Customize the appearance
|
|
1079
|
-
# Make autopct text more visible
|
|
1080
|
-
for autotext in autotexts:
|
|
1081
|
-
autotext.set_fontsize(10)
|
|
1082
|
-
autotext.set_color('white')
|
|
1083
|
-
|
|
1084
|
-
# Set aspect ratio to be equal to ensure a circular pie
|
|
1085
|
-
ax.set_aspect('equal')
|
|
1086
|
-
|
|
1087
|
-
# Add title
|
|
1111
|
+
ax.set_aspect('equal')
|
|
1112
|
+
ax.set_title(subtitle, fontsize=14, pad=20)
|
|
1113
|
+
|
|
1114
|
+
# Draw both pies
|
|
1115
|
+
draw_pie(axes[0], left_labels, left_values, subtitles[0])
|
|
1116
|
+
draw_pie(axes[1], right_labels, right_values, subtitles[1])
|
|
1117
|
+
|
|
1118
|
+
# Add main title
|
|
1088
1119
|
if title:
|
|
1089
|
-
|
|
1120
|
+
fig.suptitle(title, fontsize=16, y=0.98)
|
|
1090
1121
|
|
|
1091
|
-
# Create
|
|
1092
|
-
if
|
|
1093
|
-
|
|
1122
|
+
# Create unified legend
|
|
1123
|
+
if left_labels or right_labels:
|
|
1124
|
+
handles = [
|
|
1125
|
+
plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color_map[label], markersize=10)
|
|
1126
|
+
for label in all_labels
|
|
1127
|
+
]
|
|
1128
|
+
|
|
1129
|
+
fig.legend(
|
|
1130
|
+
handles=handles,
|
|
1131
|
+
labels=all_labels,
|
|
1132
|
+
title=legend_title,
|
|
1133
|
+
loc='lower center',
|
|
1134
|
+
bbox_to_anchor=(0.5, -0.02),
|
|
1135
|
+
ncol=min(len(all_labels), 5),
|
|
1136
|
+
)
|
|
1137
|
+
|
|
1138
|
+
fig.subplots_adjust(bottom=0.15)
|
|
1094
1139
|
|
|
1095
|
-
# Apply tight layout
|
|
1096
1140
|
fig.tight_layout()
|
|
1097
1141
|
|
|
1098
|
-
return fig,
|
|
1142
|
+
return fig, axes
|
|
1099
1143
|
|
|
1100
1144
|
|
|
1101
|
-
def
|
|
1102
|
-
|
|
1103
|
-
|
|
1104
|
-
colors: ColorType = 'viridis',
|
|
1145
|
+
def heatmap_with_plotly_v2(
|
|
1146
|
+
data: xr.DataArray,
|
|
1147
|
+
colors: ColorType | None = None,
|
|
1105
1148
|
title: str = '',
|
|
1106
|
-
|
|
1107
|
-
|
|
1108
|
-
|
|
1109
|
-
|
|
1110
|
-
hover_template: str = '%{label}: %{value} (%{percent})',
|
|
1111
|
-
text_info: str = 'percent+label',
|
|
1112
|
-
text_position: str = 'inside',
|
|
1149
|
+
facet_col: str | None = None,
|
|
1150
|
+
animation_frame: str | None = None,
|
|
1151
|
+
facet_col_wrap: int | None = None,
|
|
1152
|
+
**imshow_kwargs: Any,
|
|
1113
1153
|
) -> go.Figure:
|
|
1114
1154
|
"""
|
|
1115
|
-
|
|
1155
|
+
Plot a heatmap using Plotly's imshow.
|
|
1156
|
+
|
|
1157
|
+
Data should be prepared with dims in order: (y_axis, x_axis, [facet_col], [animation_frame]).
|
|
1158
|
+
Use reshape_data_for_heatmap() to prepare time-series data before calling this.
|
|
1116
1159
|
|
|
1117
1160
|
Args:
|
|
1118
|
-
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
|
|
1123
|
-
|
|
1124
|
-
|
|
1125
|
-
subtitles: Tuple containing the subtitles for (left, right) charts.
|
|
1126
|
-
legend_title: The title for the legend.
|
|
1127
|
-
hole: Size of the hole in the center for creating donut charts (0.0 to 1.0).
|
|
1128
|
-
lower_percentage_group: Group segments whose cumulative share is below this percentage (0–100) into "Other".
|
|
1129
|
-
hover_template: Template for hover text. Use %{label}, %{value}, %{percent}.
|
|
1130
|
-
text_info: What to show on pie segments: 'label', 'percent', 'value', 'label+percent',
|
|
1131
|
-
'label+value', 'percent+value', 'label+percent+value', or 'none'.
|
|
1132
|
-
text_position: Position of text: 'inside', 'outside', 'auto', or 'none'.
|
|
1161
|
+
data: DataArray with 2-4 dimensions. First two are heatmap axes.
|
|
1162
|
+
colors: Colorscale name ('viridis', 'plasma', etc.).
|
|
1163
|
+
title: Plot title.
|
|
1164
|
+
facet_col: Dimension name for subplot columns (3rd dim).
|
|
1165
|
+
animation_frame: Dimension name for animation (4th dim).
|
|
1166
|
+
facet_col_wrap: Max columns before wrapping (only if < n_facets).
|
|
1167
|
+
**imshow_kwargs: Additional args for px.imshow.
|
|
1133
1168
|
|
|
1134
1169
|
Returns:
|
|
1135
|
-
|
|
1170
|
+
Plotly Figure object.
|
|
1136
1171
|
"""
|
|
1137
|
-
|
|
1138
|
-
|
|
1139
|
-
# Check for empty data
|
|
1140
|
-
if data_left.empty and data_right.empty:
|
|
1141
|
-
logger.error('Both datasets are empty. Returning empty figure.')
|
|
1172
|
+
if data.size == 0:
|
|
1142
1173
|
return go.Figure()
|
|
1143
1174
|
|
|
1144
|
-
|
|
1145
|
-
|
|
1146
|
-
rows=1, cols=2, specs=[[{'type': 'pie'}, {'type': 'pie'}]], subplot_titles=subtitles, horizontal_spacing=0.05
|
|
1147
|
-
)
|
|
1175
|
+
colors = colors or CONFIG.Plotting.default_sequential_colorscale
|
|
1176
|
+
facet_col_wrap = facet_col_wrap or CONFIG.Plotting.default_facet_cols
|
|
1148
1177
|
|
|
1149
|
-
|
|
1150
|
-
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1178
|
+
imshow_args: dict[str, Any] = {
|
|
1179
|
+
'img': data,
|
|
1180
|
+
'color_continuous_scale': colors,
|
|
1181
|
+
'title': title,
|
|
1182
|
+
**imshow_kwargs,
|
|
1183
|
+
}
|
|
1155
1184
|
|
|
1156
|
-
|
|
1157
|
-
|
|
1185
|
+
if facet_col and facet_col in data.dims:
|
|
1186
|
+
imshow_args['facet_col'] = facet_col
|
|
1187
|
+
if facet_col_wrap < data.sizes[facet_col]:
|
|
1188
|
+
imshow_args['facet_col_wrap'] = facet_col_wrap
|
|
1158
1189
|
|
|
1159
|
-
|
|
1160
|
-
|
|
1161
|
-
"""
|
|
1162
|
-
# Handle negative values
|
|
1163
|
-
if (series < 0).any():
|
|
1164
|
-
logger.error('Negative values detected in data. Using absolute values for pie chart.')
|
|
1165
|
-
series = series.abs()
|
|
1190
|
+
if animation_frame and animation_frame in data.dims:
|
|
1191
|
+
imshow_args['animation_frame'] = animation_frame
|
|
1166
1192
|
|
|
1167
|
-
|
|
1168
|
-
series = series[series > 0]
|
|
1193
|
+
return px.imshow(**imshow_args)
|
|
1169
1194
|
|
|
1170
|
-
# Apply minimum percentage threshold if needed
|
|
1171
|
-
if lower_percentage_group and not series.empty:
|
|
1172
|
-
total = series.sum()
|
|
1173
|
-
if total > 0:
|
|
1174
|
-
# Sort series by value (ascending)
|
|
1175
|
-
sorted_series = series.sort_values()
|
|
1176
1195
|
|
|
1177
|
-
|
|
1178
|
-
|
|
1196
|
+
def heatmap_with_plotly(
|
|
1197
|
+
data: xr.DataArray,
|
|
1198
|
+
colors: ColorType | None = None,
|
|
1199
|
+
title: str = '',
|
|
1200
|
+
facet_by: str | list[str] | None = None,
|
|
1201
|
+
animate_by: str | None = None,
|
|
1202
|
+
facet_cols: int | None = None,
|
|
1203
|
+
reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
|
|
1204
|
+
| Literal['auto']
|
|
1205
|
+
| None = 'auto',
|
|
1206
|
+
fill: Literal['ffill', 'bfill'] | None = 'ffill',
|
|
1207
|
+
**imshow_kwargs: Any,
|
|
1208
|
+
) -> go.Figure:
|
|
1209
|
+
"""
|
|
1210
|
+
Plot a heatmap visualization using Plotly's imshow with faceting and animation support.
|
|
1179
1211
|
|
|
1180
|
-
|
|
1181
|
-
|
|
1212
|
+
This function creates heatmap visualizations from xarray DataArrays, supporting
|
|
1213
|
+
multi-dimensional data through faceting (subplots) and animation. It automatically
|
|
1214
|
+
handles dimension reduction and data reshaping for optimal heatmap display.
|
|
1182
1215
|
|
|
1183
|
-
|
|
1184
|
-
|
|
1185
|
-
|
|
1216
|
+
Automatic Time Reshaping:
|
|
1217
|
+
If only the 'time' dimension remains after faceting/animation (making the data 1D),
|
|
1218
|
+
the function automatically reshapes time into a 2D format using default values
|
|
1219
|
+
(timeframes='D', timesteps_per_frame='h'). This creates a daily pattern heatmap
|
|
1220
|
+
showing hours vs days.
|
|
1186
1221
|
|
|
1187
|
-
|
|
1188
|
-
|
|
1222
|
+
Args:
|
|
1223
|
+
data: An xarray DataArray containing the data to visualize. Should have at least
|
|
1224
|
+
2 dimensions, or a 'time' dimension that can be reshaped into 2D.
|
|
1225
|
+
colors: Color specification (colorscale name, list, or dict). Common options:
|
|
1226
|
+
'turbo', 'plasma', 'RdBu', 'portland'.
|
|
1227
|
+
title: The main title of the heatmap.
|
|
1228
|
+
facet_by: Dimension to create facets for. Creates a subplot grid.
|
|
1229
|
+
Can be a single dimension name or list (only first dimension used).
|
|
1230
|
+
Note: px.imshow only supports single-dimension faceting.
|
|
1231
|
+
If the dimension doesn't exist in the data, it will be silently ignored.
|
|
1232
|
+
animate_by: Dimension to animate over. Creates animation frames.
|
|
1233
|
+
If the dimension doesn't exist in the data, it will be silently ignored.
|
|
1234
|
+
facet_cols: Number of columns in the facet grid (used with facet_by).
|
|
1235
|
+
reshape_time: Time reshaping configuration:
|
|
1236
|
+
- 'auto' (default): Automatically applies ('D', 'h') if only 'time' dimension remains
|
|
1237
|
+
- Tuple like ('D', 'h'): Explicit time reshaping (days vs hours)
|
|
1238
|
+
- None: Disable time reshaping (will error if only 1D time data)
|
|
1239
|
+
fill: Method to fill missing values when reshaping time: 'ffill' or 'bfill'. Default is 'ffill'.
|
|
1240
|
+
**imshow_kwargs: Additional keyword arguments to pass to plotly.express.imshow.
|
|
1241
|
+
Common options include:
|
|
1242
|
+
- aspect: 'auto', 'equal', or a number for aspect ratio
|
|
1243
|
+
- zmin, zmax: Minimum and maximum values for color scale
|
|
1244
|
+
- labels: Dict to customize axis labels
|
|
1189
1245
|
|
|
1190
|
-
|
|
1191
|
-
|
|
1192
|
-
result_series['Other'] = other_sum
|
|
1246
|
+
Returns:
|
|
1247
|
+
A Plotly figure object containing the heatmap visualization.
|
|
1193
1248
|
|
|
1194
|
-
|
|
1249
|
+
Examples:
|
|
1250
|
+
Simple heatmap:
|
|
1195
1251
|
|
|
1196
|
-
|
|
1252
|
+
```python
|
|
1253
|
+
fig = heatmap_with_plotly(data_array, colors='RdBu', title='Temperature Map')
|
|
1254
|
+
```
|
|
1197
1255
|
|
|
1198
|
-
|
|
1199
|
-
data_right_processed = preprocess_series(data_right)
|
|
1256
|
+
Facet by scenario:
|
|
1200
1257
|
|
|
1201
|
-
|
|
1202
|
-
|
|
1258
|
+
```python
|
|
1259
|
+
fig = heatmap_with_plotly(data_array, facet_by='scenario', facet_cols=2)
|
|
1260
|
+
```
|
|
1203
1261
|
|
|
1204
|
-
|
|
1205
|
-
color_map = ColorProcessor(engine='plotly').process_colors(colors, all_labels, return_mapping=True)
|
|
1262
|
+
Animate by period:
|
|
1206
1263
|
|
|
1207
|
-
|
|
1208
|
-
|
|
1209
|
-
|
|
1210
|
-
return None
|
|
1264
|
+
```python
|
|
1265
|
+
fig = heatmap_with_plotly(data_array, animate_by='period')
|
|
1266
|
+
```
|
|
1211
1267
|
|
|
1212
|
-
|
|
1213
|
-
values = data_series.values.tolist()
|
|
1214
|
-
trace_colors = [color_map[label] for label in labels]
|
|
1268
|
+
Automatic time reshaping (when only time dimension remains):
|
|
1215
1269
|
|
|
1216
|
-
|
|
1217
|
-
|
|
1218
|
-
|
|
1219
|
-
|
|
1220
|
-
|
|
1221
|
-
hole=hole,
|
|
1222
|
-
textinfo=text_info,
|
|
1223
|
-
textposition=text_position,
|
|
1224
|
-
insidetextorientation='radial',
|
|
1225
|
-
hovertemplate=hover_template,
|
|
1226
|
-
sort=True, # Sort values by default (largest first)
|
|
1227
|
-
)
|
|
1270
|
+
```python
|
|
1271
|
+
# Data with dims ['time', 'period','scenario']
|
|
1272
|
+
# After faceting and animation, only 'time' remains -> auto-reshapes to (timestep, timeframe)
|
|
1273
|
+
fig = heatmap_with_plotly(data_array, facet_by='scenario', animate_by='period')
|
|
1274
|
+
```
|
|
1228
1275
|
|
|
1229
|
-
|
|
1230
|
-
left_trace = create_pie_trace(data_left_processed, subtitles[0])
|
|
1231
|
-
if left_trace:
|
|
1232
|
-
left_trace.domain = dict(x=[0, 0.48])
|
|
1233
|
-
fig.add_trace(left_trace, row=1, col=1)
|
|
1276
|
+
Explicit time reshaping:
|
|
1234
1277
|
|
|
1235
|
-
|
|
1236
|
-
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
|
|
1278
|
+
```python
|
|
1279
|
+
fig = heatmap_with_plotly(data_array, facet_by='scenario', animate_by='period', reshape_time=('W', 'D'))
|
|
1280
|
+
```
|
|
1281
|
+
"""
|
|
1282
|
+
if colors is None:
|
|
1283
|
+
colors = CONFIG.Plotting.default_sequential_colorscale
|
|
1240
1284
|
|
|
1241
|
-
#
|
|
1242
|
-
|
|
1243
|
-
|
|
1244
|
-
|
|
1245
|
-
|
|
1246
|
-
|
|
1247
|
-
|
|
1248
|
-
|
|
1285
|
+
# Apply CONFIG defaults if not explicitly set
|
|
1286
|
+
if facet_cols is None:
|
|
1287
|
+
facet_cols = CONFIG.Plotting.default_facet_cols
|
|
1288
|
+
|
|
1289
|
+
# Handle empty data
|
|
1290
|
+
if data.size == 0:
|
|
1291
|
+
return go.Figure()
|
|
1292
|
+
|
|
1293
|
+
# Apply time reshaping using the new unified function
|
|
1294
|
+
data = reshape_data_for_heatmap(
|
|
1295
|
+
data, reshape_time=reshape_time, facet_by=facet_by, animate_by=animate_by, fill=fill
|
|
1249
1296
|
)
|
|
1250
1297
|
|
|
1251
|
-
|
|
1298
|
+
# Get available dimensions
|
|
1299
|
+
available_dims = list(data.dims)
|
|
1252
1300
|
|
|
1301
|
+
# Validate and filter facet_by dimensions
|
|
1302
|
+
if facet_by is not None:
|
|
1303
|
+
if isinstance(facet_by, str):
|
|
1304
|
+
if facet_by not in available_dims:
|
|
1305
|
+
logger.debug(
|
|
1306
|
+
f"Dimension '{facet_by}' not found in data. Available dimensions: {available_dims}. "
|
|
1307
|
+
f'Ignoring facet_by parameter.'
|
|
1308
|
+
)
|
|
1309
|
+
facet_by = None
|
|
1310
|
+
elif isinstance(facet_by, list):
|
|
1311
|
+
missing_dims = [dim for dim in facet_by if dim not in available_dims]
|
|
1312
|
+
facet_by = [dim for dim in facet_by if dim in available_dims]
|
|
1313
|
+
if missing_dims:
|
|
1314
|
+
logger.debug(
|
|
1315
|
+
f'Dimensions {missing_dims} not found in data. Available dimensions: {available_dims}. '
|
|
1316
|
+
f'Using only existing dimensions: {facet_by if facet_by else "none"}.'
|
|
1317
|
+
)
|
|
1318
|
+
if len(facet_by) == 0:
|
|
1319
|
+
facet_by = None
|
|
1320
|
+
|
|
1321
|
+
# Validate animate_by dimension
|
|
1322
|
+
if animate_by is not None and animate_by not in available_dims:
|
|
1323
|
+
logger.debug(
|
|
1324
|
+
f"Dimension '{animate_by}' not found in data. Available dimensions: {available_dims}. "
|
|
1325
|
+
f'Ignoring animate_by parameter.'
|
|
1326
|
+
)
|
|
1327
|
+
animate_by = None
|
|
1253
1328
|
|
|
1254
|
-
|
|
1255
|
-
|
|
1256
|
-
|
|
1257
|
-
|
|
1258
|
-
|
|
1259
|
-
|
|
1260
|
-
legend_title: str = '',
|
|
1261
|
-
hole: float = 0.2,
|
|
1262
|
-
lower_percentage_group: float = 5.0,
|
|
1263
|
-
figsize: tuple[int, int] = (14, 7),
|
|
1264
|
-
fig: plt.Figure | None = None,
|
|
1265
|
-
axes: list[plt.Axes] | None = None,
|
|
1266
|
-
) -> tuple[plt.Figure, list[plt.Axes]]:
|
|
1267
|
-
"""
|
|
1268
|
-
Create two pie charts side by side with Matplotlib, with consistent coloring across both charts.
|
|
1269
|
-
Leverages the existing pie_with_matplotlib function.
|
|
1329
|
+
# Determine which dimensions are used for faceting/animation
|
|
1330
|
+
facet_dims = []
|
|
1331
|
+
if facet_by:
|
|
1332
|
+
facet_dims = [facet_by] if isinstance(facet_by, str) else facet_by
|
|
1333
|
+
if animate_by:
|
|
1334
|
+
facet_dims.append(animate_by)
|
|
1270
1335
|
|
|
1271
|
-
|
|
1272
|
-
|
|
1273
|
-
data_right: Series for the right pie chart.
|
|
1274
|
-
colors: Color specification, can be:
|
|
1275
|
-
- A string with a colormap name (e.g., 'viridis', 'plasma')
|
|
1276
|
-
- A list of color strings (e.g., ['#ff0000', '#00ff00'])
|
|
1277
|
-
- A dictionary mapping category names to colors (e.g., {'Category1': '#ff0000'})
|
|
1278
|
-
title: The main title of the plot.
|
|
1279
|
-
subtitles: Tuple containing the subtitles for (left, right) charts.
|
|
1280
|
-
legend_title: The title for the legend.
|
|
1281
|
-
hole: Size of the hole in the center for creating donut charts (0.0 to 1.0).
|
|
1282
|
-
lower_percentage_group: Whether to group small segments (below percentage) into an "Other" category.
|
|
1283
|
-
figsize: The size of the figure (width, height) in inches.
|
|
1284
|
-
fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created.
|
|
1285
|
-
axes: A list of Matplotlib axes objects to plot on. If not provided, new axes will be created.
|
|
1336
|
+
# Get remaining dimensions for the heatmap itself
|
|
1337
|
+
heatmap_dims = [dim for dim in available_dims if dim not in facet_dims]
|
|
1286
1338
|
|
|
1287
|
-
|
|
1288
|
-
|
|
1289
|
-
|
|
1290
|
-
|
|
1291
|
-
|
|
1292
|
-
logger.error('Both datasets are empty. Returning empty figure.')
|
|
1293
|
-
if fig is None:
|
|
1294
|
-
fig, axes = plt.subplots(1, 2, figsize=figsize)
|
|
1295
|
-
return fig, axes
|
|
1296
|
-
|
|
1297
|
-
# Create figure and axes if not provided
|
|
1298
|
-
if fig is None or axes is None:
|
|
1299
|
-
fig, axes = plt.subplots(1, 2, figsize=figsize)
|
|
1300
|
-
|
|
1301
|
-
# Process series to handle negative values and apply minimum percentage threshold
|
|
1302
|
-
def preprocess_series(series: pd.Series):
|
|
1303
|
-
"""
|
|
1304
|
-
Preprocess a series for pie chart display by handling negative values
|
|
1305
|
-
and grouping the smallest parts together if they collectively represent
|
|
1306
|
-
less than the specified percentage threshold.
|
|
1307
|
-
"""
|
|
1308
|
-
# Handle negative values
|
|
1309
|
-
if (series < 0).any():
|
|
1310
|
-
logger.error('Negative values detected in data. Using absolute values for pie chart.')
|
|
1311
|
-
series = series.abs()
|
|
1339
|
+
if len(heatmap_dims) < 2:
|
|
1340
|
+
# Handle single-dimension case by adding variable name as a dimension
|
|
1341
|
+
if len(heatmap_dims) == 1:
|
|
1342
|
+
# Get the variable name, or use a default
|
|
1343
|
+
var_name = data.name if data.name else 'value'
|
|
1312
1344
|
|
|
1313
|
-
|
|
1314
|
-
|
|
1345
|
+
# Expand the DataArray by adding a new dimension with the variable name
|
|
1346
|
+
data = data.expand_dims({'variable': [var_name]})
|
|
1315
1347
|
|
|
1316
|
-
|
|
1317
|
-
|
|
1318
|
-
|
|
1319
|
-
if total > 0:
|
|
1320
|
-
# Sort series by value (ascending)
|
|
1321
|
-
sorted_series = series.sort_values()
|
|
1348
|
+
# Update available dimensions
|
|
1349
|
+
available_dims = list(data.dims)
|
|
1350
|
+
heatmap_dims = [dim for dim in available_dims if dim not in facet_dims]
|
|
1322
1351
|
|
|
1323
|
-
|
|
1324
|
-
|
|
1352
|
+
logger.debug(f'Only 1 dimension remaining for heatmap. Added variable dimension: {var_name}')
|
|
1353
|
+
else:
|
|
1354
|
+
# No dimensions at all - cannot create a heatmap
|
|
1355
|
+
logger.error(
|
|
1356
|
+
f'Heatmap requires at least 1 dimension. '
|
|
1357
|
+
f'After faceting/animation, {len(heatmap_dims)} dimension(s) remain: {heatmap_dims}'
|
|
1358
|
+
)
|
|
1359
|
+
return go.Figure()
|
|
1360
|
+
|
|
1361
|
+
# Setup faceting parameters for Plotly Express
|
|
1362
|
+
# Note: px.imshow only supports facet_col, not facet_row
|
|
1363
|
+
facet_col_param = None
|
|
1364
|
+
if facet_by:
|
|
1365
|
+
if isinstance(facet_by, str):
|
|
1366
|
+
facet_col_param = facet_by
|
|
1367
|
+
elif len(facet_by) == 1:
|
|
1368
|
+
facet_col_param = facet_by[0]
|
|
1369
|
+
elif len(facet_by) >= 2:
|
|
1370
|
+
# px.imshow doesn't support facet_row, so we can only facet by one dimension
|
|
1371
|
+
# Use the first dimension and warn about the rest
|
|
1372
|
+
facet_col_param = facet_by[0]
|
|
1373
|
+
logger.warning(
|
|
1374
|
+
f'px.imshow only supports faceting by a single dimension. '
|
|
1375
|
+
f'Using {facet_by[0]} for faceting. Dimensions {facet_by[1:]} will be ignored. '
|
|
1376
|
+
f'Consider using animate_by for additional dimensions.'
|
|
1377
|
+
)
|
|
1325
1378
|
|
|
1326
|
-
|
|
1327
|
-
|
|
1379
|
+
# Create the imshow plot - px.imshow can work directly with xarray DataArrays
|
|
1380
|
+
common_args = {
|
|
1381
|
+
'img': data,
|
|
1382
|
+
'color_continuous_scale': colors,
|
|
1383
|
+
'title': title,
|
|
1384
|
+
}
|
|
1328
1385
|
|
|
1329
|
-
|
|
1330
|
-
|
|
1331
|
-
|
|
1386
|
+
# Add faceting if specified
|
|
1387
|
+
if facet_col_param:
|
|
1388
|
+
common_args['facet_col'] = facet_col_param
|
|
1389
|
+
if facet_cols:
|
|
1390
|
+
common_args['facet_col_wrap'] = facet_cols
|
|
1332
1391
|
|
|
1333
|
-
|
|
1334
|
-
|
|
1392
|
+
# Add animation if specified
|
|
1393
|
+
if animate_by:
|
|
1394
|
+
common_args['animation_frame'] = animate_by
|
|
1335
1395
|
|
|
1336
|
-
|
|
1337
|
-
|
|
1338
|
-
result_series['Other'] = other_sum
|
|
1396
|
+
# Merge in additional imshow kwargs
|
|
1397
|
+
common_args.update(imshow_kwargs)
|
|
1339
1398
|
|
|
1340
|
-
|
|
1399
|
+
try:
|
|
1400
|
+
fig = px.imshow(**common_args)
|
|
1401
|
+
except Exception as e:
|
|
1402
|
+
logger.error(f'Error creating imshow plot: {e}. Falling back to basic heatmap.')
|
|
1403
|
+
# Fallback: create a simple heatmap without faceting
|
|
1404
|
+
fallback_args = {
|
|
1405
|
+
'img': data.values,
|
|
1406
|
+
'color_continuous_scale': colors,
|
|
1407
|
+
'title': title,
|
|
1408
|
+
}
|
|
1409
|
+
fallback_args.update(imshow_kwargs)
|
|
1410
|
+
fig = px.imshow(**fallback_args)
|
|
1341
1411
|
|
|
1342
|
-
|
|
1412
|
+
return fig
|
|
1343
1413
|
|
|
1344
|
-
# Preprocess data
|
|
1345
|
-
data_left_processed = preprocess_series(data_left)
|
|
1346
|
-
data_right_processed = preprocess_series(data_right)
|
|
1347
1414
|
|
|
1348
|
-
|
|
1349
|
-
|
|
1350
|
-
|
|
1415
|
+
def heatmap_with_matplotlib(
|
|
1416
|
+
data: xr.DataArray,
|
|
1417
|
+
colors: ColorType | None = None,
|
|
1418
|
+
title: str = '',
|
|
1419
|
+
figsize: tuple[float, float] = (12, 6),
|
|
1420
|
+
reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
|
|
1421
|
+
| Literal['auto']
|
|
1422
|
+
| None = 'auto',
|
|
1423
|
+
fill: Literal['ffill', 'bfill'] | None = 'ffill',
|
|
1424
|
+
vmin: float | None = None,
|
|
1425
|
+
vmax: float | None = None,
|
|
1426
|
+
imshow_kwargs: dict[str, Any] | None = None,
|
|
1427
|
+
cbar_kwargs: dict[str, Any] | None = None,
|
|
1428
|
+
**kwargs: Any,
|
|
1429
|
+
) -> tuple[plt.Figure, plt.Axes]:
|
|
1430
|
+
"""
|
|
1431
|
+
Plot a heatmap visualization using Matplotlib's imshow.
|
|
1351
1432
|
|
|
1352
|
-
|
|
1353
|
-
|
|
1433
|
+
This function creates a basic 2D heatmap from an xarray DataArray using matplotlib's
|
|
1434
|
+
imshow function. For multi-dimensional data, only the first two dimensions are used.
|
|
1354
1435
|
|
|
1355
|
-
|
|
1356
|
-
|
|
1436
|
+
Args:
|
|
1437
|
+
data: An xarray DataArray containing the data to visualize. Should have at least
|
|
1438
|
+
2 dimensions. If more than 2 dimensions exist, additional dimensions will
|
|
1439
|
+
be reduced by taking the first slice.
|
|
1440
|
+
colors: Color specification. Should be a colorscale name (e.g., 'turbo', 'RdBu').
|
|
1441
|
+
title: The title of the heatmap.
|
|
1442
|
+
figsize: The size of the figure (width, height) in inches.
|
|
1443
|
+
reshape_time: Time reshaping configuration:
|
|
1444
|
+
- 'auto' (default): Automatically applies ('D', 'h') if only 'time' dimension
|
|
1445
|
+
- Tuple like ('D', 'h'): Explicit time reshaping (days vs hours)
|
|
1446
|
+
- None: Disable time reshaping
|
|
1447
|
+
fill: Method to fill missing values when reshaping time: 'ffill' or 'bfill'. Default is 'ffill'.
|
|
1448
|
+
vmin: Minimum value for color scale. If None, uses data minimum.
|
|
1449
|
+
vmax: Maximum value for color scale. If None, uses data maximum.
|
|
1450
|
+
imshow_kwargs: Optional dict of parameters to pass to ax.imshow().
|
|
1451
|
+
Use this to customize image properties (e.g., interpolation, aspect).
|
|
1452
|
+
cbar_kwargs: Optional dict of parameters to pass to plt.colorbar().
|
|
1453
|
+
Use this to customize colorbar properties (e.g., orientation, label).
|
|
1454
|
+
**kwargs: Additional keyword arguments passed to ax.imshow().
|
|
1455
|
+
Common options include:
|
|
1456
|
+
- interpolation: 'nearest', 'bilinear', 'bicubic', etc.
|
|
1457
|
+
- alpha: Transparency level (0-1)
|
|
1458
|
+
- extent: [left, right, bottom, top] for axis limits
|
|
1357
1459
|
|
|
1358
|
-
|
|
1359
|
-
|
|
1360
|
-
right_colors = [color_map[col] for col in df_right.columns] if not df_right.empty else []
|
|
1460
|
+
Returns:
|
|
1461
|
+
A tuple containing the Matplotlib figure and axes objects used for the plot.
|
|
1361
1462
|
|
|
1362
|
-
|
|
1363
|
-
|
|
1364
|
-
|
|
1365
|
-
|
|
1366
|
-
axes[0].set_title(subtitles[0])
|
|
1367
|
-
axes[0].axis('off')
|
|
1463
|
+
Notes:
|
|
1464
|
+
- Matplotlib backend doesn't support faceting or animation. Use plotly engine for those features.
|
|
1465
|
+
- The y-axis is automatically inverted to display data with origin at top-left.
|
|
1466
|
+
- A colorbar is added to show the value scale.
|
|
1368
1467
|
|
|
1369
|
-
|
|
1370
|
-
|
|
1371
|
-
|
|
1372
|
-
|
|
1373
|
-
|
|
1374
|
-
axes[1].axis('off')
|
|
1468
|
+
Examples:
|
|
1469
|
+
```python
|
|
1470
|
+
fig, ax = heatmap_with_matplotlib(data_array, colors='RdBu', title='Temperature')
|
|
1471
|
+
plt.savefig('heatmap.png')
|
|
1472
|
+
```
|
|
1375
1473
|
|
|
1376
|
-
|
|
1377
|
-
if title:
|
|
1378
|
-
fig.suptitle(title, fontsize=16, y=0.98)
|
|
1474
|
+
Time reshaping:
|
|
1379
1475
|
|
|
1380
|
-
|
|
1381
|
-
|
|
1476
|
+
```python
|
|
1477
|
+
fig, ax = heatmap_with_matplotlib(data_array, reshape_time=('D', 'h'))
|
|
1478
|
+
```
|
|
1479
|
+
"""
|
|
1480
|
+
if colors is None:
|
|
1481
|
+
colors = CONFIG.Plotting.default_sequential_colorscale
|
|
1382
1482
|
|
|
1383
|
-
#
|
|
1384
|
-
if
|
|
1385
|
-
|
|
1386
|
-
|
|
1387
|
-
|
|
1388
|
-
ax.get_legend().remove()
|
|
1483
|
+
# Initialize kwargs if not provided
|
|
1484
|
+
if imshow_kwargs is None:
|
|
1485
|
+
imshow_kwargs = {}
|
|
1486
|
+
if cbar_kwargs is None:
|
|
1487
|
+
cbar_kwargs = {}
|
|
1389
1488
|
|
|
1390
|
-
|
|
1391
|
-
|
|
1392
|
-
|
|
1489
|
+
# Merge any additional kwargs into imshow_kwargs
|
|
1490
|
+
# This allows users to pass imshow options directly
|
|
1491
|
+
imshow_kwargs.update(kwargs)
|
|
1393
1492
|
|
|
1394
|
-
|
|
1395
|
-
|
|
1396
|
-
|
|
1397
|
-
|
|
1398
|
-
labels_for_legend.append(label)
|
|
1493
|
+
# Handle empty data
|
|
1494
|
+
if data.size == 0:
|
|
1495
|
+
fig, ax = plt.subplots(figsize=figsize)
|
|
1496
|
+
return fig, ax
|
|
1399
1497
|
|
|
1400
|
-
|
|
1401
|
-
|
|
1402
|
-
|
|
1403
|
-
labels=labels_for_legend,
|
|
1404
|
-
title=legend_title,
|
|
1405
|
-
loc='lower center',
|
|
1406
|
-
bbox_to_anchor=(0.5, 0),
|
|
1407
|
-
ncol=min(len(all_labels), 5), # Limit columns to 5 for readability
|
|
1408
|
-
)
|
|
1498
|
+
# Apply time reshaping using the new unified function
|
|
1499
|
+
# Matplotlib doesn't support faceting/animation, so we pass None for those
|
|
1500
|
+
data = reshape_data_for_heatmap(data, reshape_time=reshape_time, facet_by=None, animate_by=None, fill=fill)
|
|
1409
1501
|
|
|
1410
|
-
|
|
1411
|
-
|
|
1502
|
+
# Handle single-dimension case by adding variable name as a dimension
|
|
1503
|
+
if isinstance(data, xr.DataArray) and len(data.dims) == 1:
|
|
1504
|
+
var_name = data.name if data.name else 'value'
|
|
1505
|
+
data = data.expand_dims({'variable': [var_name]})
|
|
1506
|
+
logger.debug(f'Only 1 dimension in data. Added variable dimension: {var_name}')
|
|
1412
1507
|
|
|
1413
|
-
|
|
1508
|
+
# Create figure and axes
|
|
1509
|
+
fig, ax = plt.subplots(figsize=figsize)
|
|
1510
|
+
|
|
1511
|
+
# Extract data values
|
|
1512
|
+
# If data has more than 2 dimensions, we need to reduce it
|
|
1513
|
+
if isinstance(data, xr.DataArray):
|
|
1514
|
+
# Get the first 2 dimensions
|
|
1515
|
+
dims = list(data.dims)
|
|
1516
|
+
if len(dims) > 2:
|
|
1517
|
+
logger.warning(
|
|
1518
|
+
f'Data has {len(dims)} dimensions: {dims}. '
|
|
1519
|
+
f'Only the first 2 will be used for the heatmap. '
|
|
1520
|
+
f'Use the plotly engine for faceting/animation support.'
|
|
1521
|
+
)
|
|
1522
|
+
# Select only the first 2 dimensions by taking first slice of others
|
|
1523
|
+
selection = {dim: 0 for dim in dims[2:]}
|
|
1524
|
+
data = data.isel(selection)
|
|
1525
|
+
|
|
1526
|
+
values = data.values
|
|
1527
|
+
x_labels = data.dims[1] if len(data.dims) > 1 else 'x'
|
|
1528
|
+
y_labels = data.dims[0] if len(data.dims) > 0 else 'y'
|
|
1529
|
+
else:
|
|
1530
|
+
values = data
|
|
1531
|
+
x_labels = 'x'
|
|
1532
|
+
y_labels = 'y'
|
|
1533
|
+
|
|
1534
|
+
# Create the heatmap using imshow with user customizations
|
|
1535
|
+
imshow_defaults = {'cmap': colors, 'aspect': 'auto', 'origin': 'upper', 'vmin': vmin, 'vmax': vmax}
|
|
1536
|
+
imshow_defaults.update(imshow_kwargs) # User kwargs override defaults
|
|
1537
|
+
im = ax.imshow(values, **imshow_defaults)
|
|
1538
|
+
|
|
1539
|
+
# Add colorbar with user customizations
|
|
1540
|
+
cbar_defaults = {'ax': ax, 'orientation': 'horizontal', 'pad': 0.1, 'aspect': 15, 'fraction': 0.05}
|
|
1541
|
+
cbar_defaults.update(cbar_kwargs) # User kwargs override defaults
|
|
1542
|
+
cbar = plt.colorbar(im, **cbar_defaults)
|
|
1543
|
+
|
|
1544
|
+
# Set colorbar label if not overridden by user
|
|
1545
|
+
if 'label' not in cbar_kwargs:
|
|
1546
|
+
cbar.set_label('Value')
|
|
1547
|
+
|
|
1548
|
+
# Set labels and title
|
|
1549
|
+
ax.set_xlabel(str(x_labels).capitalize())
|
|
1550
|
+
ax.set_ylabel(str(y_labels).capitalize())
|
|
1551
|
+
ax.set_title(title)
|
|
1552
|
+
|
|
1553
|
+
# Apply tight layout
|
|
1554
|
+
fig.tight_layout()
|
|
1555
|
+
|
|
1556
|
+
return fig, ax
|
|
1414
1557
|
|
|
1415
1558
|
|
|
1416
1559
|
def export_figure(
|
|
@@ -1418,8 +1561,9 @@ def export_figure(
|
|
|
1418
1561
|
default_path: pathlib.Path,
|
|
1419
1562
|
default_filetype: str | None = None,
|
|
1420
1563
|
user_path: pathlib.Path | None = None,
|
|
1421
|
-
show: bool =
|
|
1564
|
+
show: bool | None = None,
|
|
1422
1565
|
save: bool = False,
|
|
1566
|
+
dpi: int | None = None,
|
|
1423
1567
|
) -> go.Figure | tuple[plt.Figure, plt.Axes]:
|
|
1424
1568
|
"""
|
|
1425
1569
|
Export a figure to a file and or show it.
|
|
@@ -1429,13 +1573,21 @@ def export_figure(
|
|
|
1429
1573
|
default_path: The default file path if no user filename is provided.
|
|
1430
1574
|
default_filetype: The default filetype if the path doesnt end with a filetype.
|
|
1431
1575
|
user_path: An optional user-specified file path.
|
|
1432
|
-
show: Whether to display the figure (default:
|
|
1576
|
+
show: Whether to display the figure. If None, uses CONFIG.Plotting.default_show (default: None).
|
|
1433
1577
|
save: Whether to save the figure (default: False).
|
|
1578
|
+
dpi: DPI (dots per inch) for saving Matplotlib figures. If None, uses CONFIG.Plotting.default_dpi.
|
|
1434
1579
|
|
|
1435
1580
|
Raises:
|
|
1436
1581
|
ValueError: If no default filetype is provided and the path doesn't specify a filetype.
|
|
1437
1582
|
TypeError: If the figure type is not supported.
|
|
1438
1583
|
"""
|
|
1584
|
+
# Apply CONFIG defaults if not explicitly set
|
|
1585
|
+
if show is None:
|
|
1586
|
+
show = CONFIG.Plotting.default_show
|
|
1587
|
+
|
|
1588
|
+
if dpi is None:
|
|
1589
|
+
dpi = CONFIG.Plotting.default_dpi
|
|
1590
|
+
|
|
1439
1591
|
filename = user_path or default_path
|
|
1440
1592
|
filename = filename.with_name(filename.name.replace('|', '__'))
|
|
1441
1593
|
if filename.suffix == '':
|
|
@@ -1450,25 +1602,17 @@ def export_figure(
|
|
|
1450
1602
|
filename = filename.with_suffix('.html')
|
|
1451
1603
|
|
|
1452
1604
|
try:
|
|
1453
|
-
|
|
1454
|
-
|
|
1455
|
-
|
|
1456
|
-
|
|
1457
|
-
|
|
1458
|
-
|
|
1459
|
-
|
|
1460
|
-
|
|
1461
|
-
#
|
|
1462
|
-
|
|
1463
|
-
|
|
1464
|
-
plotly.offline.plot(fig, filename=str(filename))
|
|
1465
|
-
elif save and not show:
|
|
1466
|
-
# Save without opening
|
|
1467
|
-
fig.write_html(str(filename))
|
|
1468
|
-
elif show and not save:
|
|
1469
|
-
# Show interactively without saving
|
|
1470
|
-
fig.show()
|
|
1471
|
-
# If neither save nor show: do nothing
|
|
1605
|
+
# Respect show and save flags (tests should set CONFIG.Plotting.default_show=False)
|
|
1606
|
+
if save and show:
|
|
1607
|
+
# Save and auto-open in browser
|
|
1608
|
+
plotly.offline.plot(fig, filename=str(filename))
|
|
1609
|
+
elif save and not show:
|
|
1610
|
+
# Save without opening
|
|
1611
|
+
fig.write_html(str(filename))
|
|
1612
|
+
elif show and not save:
|
|
1613
|
+
# Show interactively without saving
|
|
1614
|
+
fig.show()
|
|
1615
|
+
# If neither save nor show: do nothing
|
|
1472
1616
|
finally:
|
|
1473
1617
|
# Cleanup to prevent socket warnings
|
|
1474
1618
|
if hasattr(fig, '_renderer'):
|
|
@@ -1479,16 +1623,15 @@ def export_figure(
|
|
|
1479
1623
|
elif isinstance(figure_like, tuple):
|
|
1480
1624
|
fig, ax = figure_like
|
|
1481
1625
|
if show:
|
|
1482
|
-
# Only show if using interactive backend
|
|
1626
|
+
# Only show if using interactive backend (tests should set CONFIG.Plotting.default_show=False)
|
|
1483
1627
|
backend = matplotlib.get_backend().lower()
|
|
1484
1628
|
is_interactive = backend not in {'agg', 'pdf', 'ps', 'svg', 'template'}
|
|
1485
|
-
is_test_env = 'PYTEST_CURRENT_TEST' in os.environ
|
|
1486
1629
|
|
|
1487
|
-
if is_interactive
|
|
1630
|
+
if is_interactive:
|
|
1488
1631
|
plt.show()
|
|
1489
1632
|
|
|
1490
1633
|
if save:
|
|
1491
|
-
fig.savefig(str(filename), dpi=
|
|
1634
|
+
fig.savefig(str(filename), dpi=dpi)
|
|
1492
1635
|
plt.close(fig) # Close figure to free memory
|
|
1493
1636
|
|
|
1494
1637
|
return fig, ax
|