flixopt 3.0.3__py3-none-any.whl → 3.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flixopt might be problematic. Click here for more details.
- flixopt/__init__.py +1 -1
- flixopt/components.py +12 -10
- flixopt/effects.py +11 -13
- flixopt/elements.py +4 -0
- flixopt/interface.py +5 -0
- flixopt/plotting.py +668 -318
- flixopt/results.py +681 -156
- flixopt/structure.py +3 -6
- {flixopt-3.0.3.dist-info → flixopt-3.1.1.dist-info}/METADATA +4 -1
- {flixopt-3.0.3.dist-info → flixopt-3.1.1.dist-info}/RECORD +13 -13
- {flixopt-3.0.3.dist-info → flixopt-3.1.1.dist-info}/WHEEL +0 -0
- {flixopt-3.0.3.dist-info → flixopt-3.1.1.dist-info}/licenses/LICENSE +0 -0
- {flixopt-3.0.3.dist-info → flixopt-3.1.1.dist-info}/top_level.txt +0 -0
flixopt/plotting.py
CHANGED
|
@@ -39,6 +39,7 @@ import pandas as pd
|
|
|
39
39
|
import plotly.express as px
|
|
40
40
|
import plotly.graph_objects as go
|
|
41
41
|
import plotly.offline
|
|
42
|
+
import xarray as xr
|
|
42
43
|
from plotly.exceptions import PlotlyError
|
|
43
44
|
|
|
44
45
|
if TYPE_CHECKING:
|
|
@@ -326,143 +327,269 @@ class ColorProcessor:
|
|
|
326
327
|
|
|
327
328
|
|
|
328
329
|
def with_plotly(
|
|
329
|
-
data: pd.DataFrame,
|
|
330
|
+
data: pd.DataFrame | xr.DataArray | xr.Dataset,
|
|
330
331
|
mode: Literal['stacked_bar', 'line', 'area', 'grouped_bar'] = 'stacked_bar',
|
|
331
332
|
colors: ColorType = 'viridis',
|
|
332
333
|
title: str = '',
|
|
333
334
|
ylabel: str = '',
|
|
334
335
|
xlabel: str = 'Time in h',
|
|
335
336
|
fig: go.Figure | None = None,
|
|
337
|
+
facet_by: str | list[str] | None = None,
|
|
338
|
+
animate_by: str | None = None,
|
|
339
|
+
facet_cols: int = 3,
|
|
340
|
+
shared_yaxes: bool = True,
|
|
341
|
+
shared_xaxes: bool = True,
|
|
336
342
|
) -> go.Figure:
|
|
337
343
|
"""
|
|
338
|
-
Plot
|
|
344
|
+
Plot data with Plotly using facets (subplots) and/or animation for multidimensional data.
|
|
345
|
+
|
|
346
|
+
Uses Plotly Express for convenient faceting and animation with automatic styling.
|
|
347
|
+
For simple plots without faceting, can optionally add to an existing figure.
|
|
339
348
|
|
|
340
349
|
Args:
|
|
341
|
-
data: A DataFrame
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
- A string with a colorscale name (e.g., 'viridis', 'plasma')
|
|
347
|
-
- A list of color strings (e.g., ['#ff0000', '#00ff00'])
|
|
348
|
-
- A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'})
|
|
349
|
-
title: The title of the plot.
|
|
350
|
+
data: A DataFrame or xarray DataArray/Dataset to plot.
|
|
351
|
+
mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for lines,
|
|
352
|
+
'area' for stacked area charts, or 'grouped_bar' for grouped bar charts.
|
|
353
|
+
colors: Color specification (colormap, list, or dict mapping labels to colors).
|
|
354
|
+
title: The main title of the plot.
|
|
350
355
|
ylabel: The label for the y-axis.
|
|
351
356
|
xlabel: The label for the x-axis.
|
|
352
|
-
fig: A Plotly figure object to plot on
|
|
357
|
+
fig: A Plotly figure object to plot on (only for simple plots without faceting).
|
|
358
|
+
If not provided, a new figure will be created.
|
|
359
|
+
facet_by: Dimension(s) to create facets for. Creates a subplot grid.
|
|
360
|
+
Can be a single dimension name or list of dimensions (max 2 for facet_row and facet_col).
|
|
361
|
+
If the dimension doesn't exist in the data, it will be silently ignored.
|
|
362
|
+
animate_by: Dimension to animate over. Creates animation frames.
|
|
363
|
+
If the dimension doesn't exist in the data, it will be silently ignored.
|
|
364
|
+
facet_cols: Number of columns in the facet grid (used when facet_by is single dimension).
|
|
365
|
+
shared_yaxes: Whether subplots share y-axes.
|
|
366
|
+
shared_xaxes: Whether subplots share x-axes.
|
|
353
367
|
|
|
354
368
|
Returns:
|
|
355
|
-
A Plotly figure object containing the
|
|
369
|
+
A Plotly figure object containing the faceted/animated plot.
|
|
370
|
+
|
|
371
|
+
Examples:
|
|
372
|
+
Simple plot:
|
|
373
|
+
|
|
374
|
+
```python
|
|
375
|
+
fig = with_plotly(df, mode='area', title='Energy Mix')
|
|
376
|
+
```
|
|
377
|
+
|
|
378
|
+
Facet by scenario:
|
|
379
|
+
|
|
380
|
+
```python
|
|
381
|
+
fig = with_plotly(ds, facet_by='scenario', facet_cols=2)
|
|
382
|
+
```
|
|
383
|
+
|
|
384
|
+
Animate by period:
|
|
385
|
+
|
|
386
|
+
```python
|
|
387
|
+
fig = with_plotly(ds, animate_by='period')
|
|
388
|
+
```
|
|
389
|
+
|
|
390
|
+
Facet and animate:
|
|
391
|
+
|
|
392
|
+
```python
|
|
393
|
+
fig = with_plotly(ds, facet_by='scenario', animate_by='period')
|
|
394
|
+
```
|
|
356
395
|
"""
|
|
357
396
|
if mode not in ('stacked_bar', 'line', 'area', 'grouped_bar'):
|
|
358
397
|
raise ValueError(f"'mode' must be one of {{'stacked_bar','line','area', 'grouped_bar'}}, got {mode!r}")
|
|
359
|
-
if data.empty:
|
|
360
|
-
return go.Figure()
|
|
361
398
|
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
399
|
+
# Handle empty data
|
|
400
|
+
if isinstance(data, pd.DataFrame) and data.empty:
|
|
401
|
+
return go.Figure()
|
|
402
|
+
elif isinstance(data, xr.DataArray) and data.size == 0:
|
|
403
|
+
return go.Figure()
|
|
404
|
+
elif isinstance(data, xr.Dataset) and len(data.data_vars) == 0:
|
|
405
|
+
return go.Figure()
|
|
365
406
|
|
|
366
|
-
if
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
407
|
+
# Warn if fig parameter is used with faceting
|
|
408
|
+
if fig is not None and (facet_by is not None or animate_by is not None):
|
|
409
|
+
logger.warning('The fig parameter is ignored when using faceting or animation. Creating a new figure.')
|
|
410
|
+
fig = None
|
|
411
|
+
|
|
412
|
+
# Convert xarray to long-form DataFrame for Plotly Express
|
|
413
|
+
if isinstance(data, (xr.DataArray, xr.Dataset)):
|
|
414
|
+
# Convert to long-form (tidy) DataFrame
|
|
415
|
+
# Structure: time, variable, value, scenario, period, ... (all dims as columns)
|
|
416
|
+
if isinstance(data, xr.Dataset):
|
|
417
|
+
# Stack all data variables into long format
|
|
418
|
+
df_long = data.to_dataframe().reset_index()
|
|
419
|
+
# Melt to get: time, scenario, period, ..., variable, value
|
|
420
|
+
id_vars = [dim for dim in data.dims]
|
|
421
|
+
value_vars = list(data.data_vars)
|
|
422
|
+
df_long = df_long.melt(id_vars=id_vars, value_vars=value_vars, var_name='variable', value_name='value')
|
|
423
|
+
else:
|
|
424
|
+
# DataArray
|
|
425
|
+
df_long = data.to_dataframe().reset_index()
|
|
426
|
+
if data.name:
|
|
427
|
+
df_long = df_long.rename(columns={data.name: 'value'})
|
|
428
|
+
else:
|
|
429
|
+
# Unnamed DataArray, find the value column
|
|
430
|
+
non_dim_cols = [col for col in df_long.columns if col not in data.dims]
|
|
431
|
+
if len(non_dim_cols) != 1:
|
|
432
|
+
raise ValueError(
|
|
433
|
+
f'Expected exactly one non-dimension column for unnamed DataArray, '
|
|
434
|
+
f'but found {len(non_dim_cols)}: {non_dim_cols}'
|
|
435
|
+
)
|
|
436
|
+
value_col = non_dim_cols[0]
|
|
437
|
+
df_long = df_long.rename(columns={value_col: 'value'})
|
|
438
|
+
df_long['variable'] = data.name or 'data'
|
|
439
|
+
else:
|
|
440
|
+
# Already a DataFrame - convert to long format for Plotly Express
|
|
441
|
+
df_long = data.reset_index()
|
|
442
|
+
if 'time' not in df_long.columns:
|
|
443
|
+
# First column is probably time
|
|
444
|
+
df_long = df_long.rename(columns={df_long.columns[0]: 'time'})
|
|
445
|
+
# Melt to long format
|
|
446
|
+
id_vars = [
|
|
447
|
+
col
|
|
448
|
+
for col in df_long.columns
|
|
449
|
+
if col in ['time', 'scenario', 'period']
|
|
450
|
+
or col in (facet_by if isinstance(facet_by, list) else [facet_by] if facet_by else [])
|
|
451
|
+
]
|
|
452
|
+
value_vars = [col for col in df_long.columns if col not in id_vars]
|
|
453
|
+
df_long = df_long.melt(id_vars=id_vars, value_vars=value_vars, var_name='variable', value_name='value')
|
|
454
|
+
|
|
455
|
+
# Validate facet_by and animate_by dimensions exist in the data
|
|
456
|
+
available_dims = [col for col in df_long.columns if col not in ['variable', 'value']]
|
|
457
|
+
|
|
458
|
+
# Check facet_by dimensions
|
|
459
|
+
if facet_by is not None:
|
|
460
|
+
if isinstance(facet_by, str):
|
|
461
|
+
if facet_by not in available_dims:
|
|
462
|
+
logger.debug(
|
|
463
|
+
f"Dimension '{facet_by}' not found in data. Available dimensions: {available_dims}. "
|
|
464
|
+
f'Ignoring facet_by parameter.'
|
|
376
465
|
)
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
466
|
+
facet_by = None
|
|
467
|
+
elif isinstance(facet_by, list):
|
|
468
|
+
# Filter out dimensions that don't exist
|
|
469
|
+
missing_dims = [dim for dim in facet_by if dim not in available_dims]
|
|
470
|
+
facet_by = [dim for dim in facet_by if dim in available_dims]
|
|
471
|
+
if missing_dims:
|
|
472
|
+
logger.debug(
|
|
473
|
+
f'Dimensions {missing_dims} not found in data. Available dimensions: {available_dims}. '
|
|
474
|
+
f'Using only existing dimensions: {facet_by if facet_by else "none"}.'
|
|
475
|
+
)
|
|
476
|
+
if len(facet_by) == 0:
|
|
477
|
+
facet_by = None
|
|
478
|
+
|
|
479
|
+
# Check animate_by dimension
|
|
480
|
+
if animate_by is not None and animate_by not in available_dims:
|
|
481
|
+
logger.debug(
|
|
482
|
+
f"Dimension '{animate_by}' not found in data. Available dimensions: {available_dims}. "
|
|
483
|
+
f'Ignoring animate_by parameter.'
|
|
383
484
|
)
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
485
|
+
animate_by = None
|
|
486
|
+
|
|
487
|
+
# Setup faceting parameters for Plotly Express
|
|
488
|
+
facet_row = None
|
|
489
|
+
facet_col = None
|
|
490
|
+
if facet_by:
|
|
491
|
+
if isinstance(facet_by, str):
|
|
492
|
+
# Single facet dimension - use facet_col with facet_col_wrap
|
|
493
|
+
facet_col = facet_by
|
|
494
|
+
elif len(facet_by) == 1:
|
|
495
|
+
facet_col = facet_by[0]
|
|
496
|
+
elif len(facet_by) == 2:
|
|
497
|
+
# Two facet dimensions - use facet_row and facet_col
|
|
498
|
+
facet_row = facet_by[0]
|
|
499
|
+
facet_col = facet_by[1]
|
|
500
|
+
else:
|
|
501
|
+
raise ValueError(f'facet_by can have at most 2 dimensions, got {len(facet_by)}')
|
|
502
|
+
|
|
503
|
+
# Process colors
|
|
504
|
+
all_vars = df_long['variable'].unique().tolist()
|
|
505
|
+
processed_colors = ColorProcessor(engine='plotly').process_colors(colors, all_vars)
|
|
506
|
+
color_discrete_map = {var: color for var, color in zip(all_vars, processed_colors, strict=True)}
|
|
507
|
+
|
|
508
|
+
# Create plot using Plotly Express based on mode
|
|
509
|
+
common_args = {
|
|
510
|
+
'data_frame': df_long,
|
|
511
|
+
'x': 'time',
|
|
512
|
+
'y': 'value',
|
|
513
|
+
'color': 'variable',
|
|
514
|
+
'facet_row': facet_row,
|
|
515
|
+
'facet_col': facet_col,
|
|
516
|
+
'animation_frame': animate_by,
|
|
517
|
+
'color_discrete_map': color_discrete_map,
|
|
518
|
+
'title': title,
|
|
519
|
+
'labels': {'value': ylabel, 'time': xlabel, 'variable': ''},
|
|
520
|
+
}
|
|
387
521
|
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
522
|
+
# Add facet_col_wrap for single facet dimension
|
|
523
|
+
if facet_col and not facet_row:
|
|
524
|
+
common_args['facet_col_wrap'] = facet_cols
|
|
525
|
+
|
|
526
|
+
if mode == 'stacked_bar':
|
|
527
|
+
fig = px.bar(**common_args)
|
|
528
|
+
fig.update_traces(marker_line_width=0)
|
|
529
|
+
fig.update_layout(barmode='relative', bargap=0, bargroupgap=0)
|
|
530
|
+
elif mode == 'grouped_bar':
|
|
531
|
+
fig = px.bar(**common_args)
|
|
532
|
+
fig.update_layout(barmode='group', bargap=0.2, bargroupgap=0)
|
|
393
533
|
elif mode == 'line':
|
|
394
|
-
|
|
395
|
-
fig.add_trace(
|
|
396
|
-
go.Scatter(
|
|
397
|
-
x=data.index,
|
|
398
|
-
y=data[column],
|
|
399
|
-
mode='lines',
|
|
400
|
-
name=column,
|
|
401
|
-
line=dict(shape='hv', color=processed_colors[i]),
|
|
402
|
-
)
|
|
403
|
-
)
|
|
534
|
+
fig = px.line(**common_args, line_shape='hv') # Stepped lines
|
|
404
535
|
elif mode == 'area':
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
# Split columns into positive, negative, and mixed categories
|
|
408
|
-
positive_columns = list(data.columns[(data >= 0).where(~np.isnan(data), True).all()])
|
|
409
|
-
negative_columns = list(data.columns[(data <= 0).where(~np.isnan(data), True).all()])
|
|
410
|
-
negative_columns = [column for column in negative_columns if column not in positive_columns]
|
|
411
|
-
mixed_columns = list(set(data.columns) - set(positive_columns + negative_columns))
|
|
412
|
-
|
|
413
|
-
if mixed_columns:
|
|
414
|
-
logger.error(
|
|
415
|
-
f'Data for plotting stacked lines contains columns with both positive and negative values:'
|
|
416
|
-
f' {mixed_columns}. These can not be stacked, and are printed as simple lines'
|
|
417
|
-
)
|
|
536
|
+
# Use Plotly Express to create the area plot (preserves animation, legends, faceting)
|
|
537
|
+
fig = px.area(**common_args, line_shape='hv')
|
|
418
538
|
|
|
419
|
-
#
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
go.Scatter(
|
|
425
|
-
x=data.index,
|
|
426
|
-
y=data[column],
|
|
427
|
-
mode='lines',
|
|
428
|
-
name=column,
|
|
429
|
-
line=dict(shape='hv', color=colors_stacked[column]),
|
|
430
|
-
fill='tonexty',
|
|
431
|
-
stackgroup='pos' if column in positive_columns else 'neg',
|
|
432
|
-
)
|
|
433
|
-
)
|
|
539
|
+
# Classify each variable based on its values
|
|
540
|
+
variable_classification = {}
|
|
541
|
+
for var in all_vars:
|
|
542
|
+
var_data = df_long[df_long['variable'] == var]['value']
|
|
543
|
+
var_data_clean = var_data[(var_data < -1e-5) | (var_data > 1e-5)]
|
|
434
544
|
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
name=column,
|
|
442
|
-
line=dict(shape='hv', color=colors_stacked[column], dash='dash'),
|
|
545
|
+
if len(var_data_clean) == 0:
|
|
546
|
+
variable_classification[var] = 'zero'
|
|
547
|
+
else:
|
|
548
|
+
has_pos, has_neg = (var_data_clean > 0).any(), (var_data_clean < 0).any()
|
|
549
|
+
variable_classification[var] = (
|
|
550
|
+
'mixed' if has_pos and has_neg else ('negative' if has_neg else 'positive')
|
|
443
551
|
)
|
|
444
|
-
)
|
|
445
552
|
|
|
446
|
-
|
|
553
|
+
# Log warning for mixed variables
|
|
554
|
+
mixed_vars = [v for v, c in variable_classification.items() if c == 'mixed']
|
|
555
|
+
if mixed_vars:
|
|
556
|
+
logger.warning(f'Variables with both positive and negative values: {mixed_vars}. Plotted as dashed lines.')
|
|
557
|
+
|
|
558
|
+
all_traces = list(fig.data)
|
|
559
|
+
for frame in fig.frames:
|
|
560
|
+
all_traces.extend(frame.data)
|
|
561
|
+
|
|
562
|
+
for trace in all_traces:
|
|
563
|
+
cls = variable_classification.get(trace.name, None)
|
|
564
|
+
# Only stack positive and negative, not mixed or zero
|
|
565
|
+
trace.stackgroup = cls if cls in ('positive', 'negative') else None
|
|
566
|
+
|
|
567
|
+
if cls in ('positive', 'negative'):
|
|
568
|
+
# Stacked area: add opacity to avoid hiding layers, remove line border
|
|
569
|
+
if hasattr(trace, 'line') and trace.line.color:
|
|
570
|
+
trace.fillcolor = trace.line.color
|
|
571
|
+
trace.line.width = 0
|
|
572
|
+
elif cls == 'mixed':
|
|
573
|
+
# Mixed variables: show as dashed line, not stacked
|
|
574
|
+
if hasattr(trace, 'line'):
|
|
575
|
+
trace.line.width = 2
|
|
576
|
+
trace.line.dash = 'dash'
|
|
577
|
+
if hasattr(trace, 'fill'):
|
|
578
|
+
trace.fill = None
|
|
579
|
+
|
|
580
|
+
# Update layout with basic styling (Plotly Express handles sizing automatically)
|
|
447
581
|
fig.update_layout(
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
showgrid=True, # Enable grid lines on the y-axis
|
|
452
|
-
gridcolor='lightgrey', # Customize grid line color
|
|
453
|
-
gridwidth=0.5, # Customize grid line width
|
|
454
|
-
),
|
|
455
|
-
xaxis=dict(
|
|
456
|
-
title=xlabel,
|
|
457
|
-
showgrid=True, # Enable grid lines on the x-axis
|
|
458
|
-
gridcolor='lightgrey', # Customize grid line color
|
|
459
|
-
gridwidth=0.5, # Customize grid line width
|
|
460
|
-
),
|
|
461
|
-
plot_bgcolor='rgba(0,0,0,0)', # Transparent background
|
|
462
|
-
paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background
|
|
463
|
-
font=dict(size=14), # Increase font size for better readability
|
|
582
|
+
plot_bgcolor='rgba(0,0,0,0)',
|
|
583
|
+
paper_bgcolor='rgba(0,0,0,0)',
|
|
584
|
+
font=dict(size=12),
|
|
464
585
|
)
|
|
465
586
|
|
|
587
|
+
# Update axes to share if requested (Plotly Express already handles this, but we can customize)
|
|
588
|
+
if not shared_yaxes:
|
|
589
|
+
fig.update_yaxes(matches=None)
|
|
590
|
+
if not shared_xaxes:
|
|
591
|
+
fig.update_xaxes(matches=None)
|
|
592
|
+
|
|
466
593
|
return fig
|
|
467
594
|
|
|
468
595
|
|
|
@@ -562,213 +689,110 @@ def with_matplotlib(
|
|
|
562
689
|
return fig, ax
|
|
563
690
|
|
|
564
691
|
|
|
565
|
-
def
|
|
566
|
-
data:
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
Plots a DataFrame as a heatmap using Matplotlib. The columns of the DataFrame will be displayed on the x-axis,
|
|
575
|
-
the index will be displayed on the y-axis, and the values will represent the 'heat' intensity in the plot.
|
|
576
|
-
|
|
577
|
-
Args:
|
|
578
|
-
data: A DataFrame containing the data to be visualized. The index will be used for the y-axis, and columns will be used for the x-axis.
|
|
579
|
-
The values in the DataFrame will be represented as colors in the heatmap.
|
|
580
|
-
color_map: The colormap to use for the heatmap. Default is 'viridis'. Matplotlib supports various colormaps like 'plasma', 'inferno', 'cividis', etc.
|
|
581
|
-
title: The title of the plot.
|
|
582
|
-
xlabel: The label for the x-axis.
|
|
583
|
-
ylabel: The label for the y-axis.
|
|
584
|
-
figsize: The size of the figure to create. Default is (12, 6), which results in a width of 12 inches and a height of 6 inches.
|
|
585
|
-
|
|
586
|
-
Returns:
|
|
587
|
-
A tuple containing the Matplotlib `Figure` and `Axes` objects. The `Figure` contains the overall plot, while the `Axes` is the area
|
|
588
|
-
where the heatmap is drawn. These can be used for further customization or saving the plot to a file.
|
|
589
|
-
|
|
590
|
-
Notes:
|
|
591
|
-
- The y-axis is flipped so that the first row of the DataFrame is displayed at the top of the plot.
|
|
592
|
-
- The color scale is normalized based on the minimum and maximum values in the DataFrame.
|
|
593
|
-
- The x-axis labels (periods) are placed at the top of the plot.
|
|
594
|
-
- The colorbar is added horizontally at the bottom of the plot, with a label.
|
|
595
|
-
"""
|
|
596
|
-
|
|
597
|
-
# Get the min and max values for color normalization
|
|
598
|
-
color_bar_min, color_bar_max = data.min().min(), data.max().max()
|
|
599
|
-
|
|
600
|
-
# Create the heatmap plot
|
|
601
|
-
fig, ax = plt.subplots(figsize=figsize)
|
|
602
|
-
ax.pcolormesh(data.values, cmap=color_map, shading='auto')
|
|
603
|
-
ax.invert_yaxis() # Flip the y-axis to start at the top
|
|
604
|
-
|
|
605
|
-
# Adjust ticks and labels for x and y axes
|
|
606
|
-
ax.set_xticks(np.arange(len(data.columns)) + 0.5)
|
|
607
|
-
ax.set_xticklabels(data.columns, ha='center')
|
|
608
|
-
ax.set_yticks(np.arange(len(data.index)) + 0.5)
|
|
609
|
-
ax.set_yticklabels(data.index, va='center')
|
|
610
|
-
|
|
611
|
-
# Add labels to the axes
|
|
612
|
-
ax.set_xlabel(xlabel, ha='center')
|
|
613
|
-
ax.set_ylabel(ylabel, va='center')
|
|
614
|
-
ax.set_title(title)
|
|
615
|
-
|
|
616
|
-
# Position x-axis labels at the top
|
|
617
|
-
ax.xaxis.set_label_position('top')
|
|
618
|
-
ax.xaxis.set_ticks_position('top')
|
|
619
|
-
|
|
620
|
-
# Add the colorbar
|
|
621
|
-
sm1 = plt.cm.ScalarMappable(cmap=color_map, norm=plt.Normalize(vmin=color_bar_min, vmax=color_bar_max))
|
|
622
|
-
sm1.set_array([])
|
|
623
|
-
fig.colorbar(sm1, ax=ax, pad=0.12, aspect=15, fraction=0.2, orientation='horizontal')
|
|
624
|
-
|
|
625
|
-
fig.tight_layout()
|
|
626
|
-
|
|
627
|
-
return fig, ax
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
def heat_map_plotly(
|
|
631
|
-
data: pd.DataFrame,
|
|
632
|
-
color_map: str = 'viridis',
|
|
633
|
-
title: str = '',
|
|
634
|
-
xlabel: str = 'Period',
|
|
635
|
-
ylabel: str = 'Step',
|
|
636
|
-
categorical_labels: bool = True,
|
|
637
|
-
) -> go.Figure:
|
|
638
|
-
"""
|
|
639
|
-
Plots a DataFrame as a heatmap using Plotly. The columns of the DataFrame will be mapped to the x-axis,
|
|
640
|
-
and the index will be displayed on the y-axis. The values in the DataFrame will represent the 'heat' in the plot.
|
|
641
|
-
|
|
642
|
-
Args:
|
|
643
|
-
data: A DataFrame with the data to be visualized. The index will be used for the y-axis, and columns will be used for the x-axis.
|
|
644
|
-
The values in the DataFrame will be represented as colors in the heatmap.
|
|
645
|
-
color_map: The color scale to use for the heatmap. Default is 'viridis'. Plotly supports various color scales like 'Cividis', 'Inferno', etc.
|
|
646
|
-
title: The title of the heatmap. Default is an empty string.
|
|
647
|
-
xlabel: The label for the x-axis. Default is 'Period'.
|
|
648
|
-
ylabel: The label for the y-axis. Default is 'Step'.
|
|
649
|
-
categorical_labels: If True, the x and y axes are treated as categorical data (i.e., the index and columns will not be interpreted as continuous data).
|
|
650
|
-
Default is True. If False, the axes are treated as continuous, which may be useful for time series or numeric data.
|
|
651
|
-
|
|
652
|
-
Returns:
|
|
653
|
-
A Plotly figure object containing the heatmap. This can be further customized and saved
|
|
654
|
-
or displayed using `fig.show()`.
|
|
655
|
-
|
|
656
|
-
Notes:
|
|
657
|
-
The color bar is automatically scaled to the minimum and maximum values in the data.
|
|
658
|
-
The y-axis is reversed to display the first row at the top.
|
|
692
|
+
def reshape_data_for_heatmap(
|
|
693
|
+
data: xr.DataArray,
|
|
694
|
+
reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
|
|
695
|
+
| Literal['auto']
|
|
696
|
+
| None = 'auto',
|
|
697
|
+
facet_by: str | list[str] | None = None,
|
|
698
|
+
animate_by: str | None = None,
|
|
699
|
+
fill: Literal['ffill', 'bfill'] | None = 'ffill',
|
|
700
|
+
) -> xr.DataArray:
|
|
659
701
|
"""
|
|
702
|
+
Reshape data for heatmap visualization, handling time dimension intelligently.
|
|
660
703
|
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
z=data.values,
|
|
666
|
-
x=data.columns,
|
|
667
|
-
y=data.index,
|
|
668
|
-
colorscale=color_map,
|
|
669
|
-
zmin=color_bar_min,
|
|
670
|
-
zmax=color_bar_max,
|
|
671
|
-
colorbar=dict(
|
|
672
|
-
title=dict(text='Color Bar Label', side='right'),
|
|
673
|
-
orientation='h',
|
|
674
|
-
xref='container',
|
|
675
|
-
yref='container',
|
|
676
|
-
len=0.8, # Color bar length relative to plot
|
|
677
|
-
x=0.5,
|
|
678
|
-
y=0.1,
|
|
679
|
-
),
|
|
680
|
-
)
|
|
681
|
-
)
|
|
682
|
-
|
|
683
|
-
# Set axis labels and style
|
|
684
|
-
fig.update_layout(
|
|
685
|
-
title=title,
|
|
686
|
-
xaxis=dict(title=xlabel, side='top', type='category' if categorical_labels else None),
|
|
687
|
-
yaxis=dict(title=ylabel, autorange='reversed', type='category' if categorical_labels else None),
|
|
688
|
-
)
|
|
689
|
-
|
|
690
|
-
return fig
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
def reshape_to_2d(data_1d: np.ndarray, nr_of_steps_per_column: int) -> np.ndarray:
|
|
694
|
-
"""
|
|
695
|
-
Reshapes a 1D numpy array into a 2D array suitable for plotting as a colormap.
|
|
704
|
+
This function decides whether to reshape the 'time' dimension based on the reshape_time parameter:
|
|
705
|
+
- 'auto': Automatically reshapes if only 'time' dimension would remain for heatmap
|
|
706
|
+
- Tuple: Explicitly reshapes time with specified parameters
|
|
707
|
+
- None: No reshaping (returns data as-is)
|
|
696
708
|
|
|
697
|
-
|
|
698
|
-
(e.g., 24 hours per day) and columns representing time periods (e.g., days or months).
|
|
709
|
+
All non-time dimensions are preserved during reshaping.
|
|
699
710
|
|
|
700
711
|
Args:
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
712
|
+
data: DataArray to reshape for heatmap visualization.
|
|
713
|
+
reshape_time: Reshaping configuration:
|
|
714
|
+
- 'auto' (default): Auto-reshape if needed based on facet_by/animate_by
|
|
715
|
+
- Tuple (timeframes, timesteps_per_frame): Explicit time reshaping
|
|
716
|
+
- None: No reshaping
|
|
717
|
+
facet_by: Dimension(s) used for faceting (used in 'auto' decision).
|
|
718
|
+
animate_by: Dimension used for animation (used in 'auto' decision).
|
|
719
|
+
fill: Method to fill missing values: 'ffill' or 'bfill'. Default is 'ffill'.
|
|
704
720
|
|
|
705
721
|
Returns:
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
"""
|
|
709
|
-
|
|
710
|
-
# Step 1: Ensure the input is a 1D array.
|
|
711
|
-
if data_1d.ndim != 1:
|
|
712
|
-
raise ValueError('Input must be a 1D array')
|
|
713
|
-
|
|
714
|
-
# Step 2: Convert data to float type to allow NaN padding
|
|
715
|
-
if data_1d.dtype != np.float64:
|
|
716
|
-
data_1d = data_1d.astype(np.float64)
|
|
722
|
+
Reshaped DataArray. If time reshaping is applied, 'time' dimension is replaced
|
|
723
|
+
by 'timestep' and 'timeframe'. All other dimensions are preserved.
|
|
717
724
|
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
cols = len(data_1d) // nr_of_steps_per_column # Base number of columns
|
|
721
|
-
|
|
722
|
-
# If there's a remainder, add an extra column to hold the remaining values
|
|
723
|
-
if total_steps % nr_of_steps_per_column != 0:
|
|
724
|
-
cols += 1
|
|
725
|
+
Examples:
|
|
726
|
+
Auto-reshaping:
|
|
725
727
|
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
728
|
+
```python
|
|
729
|
+
# Will auto-reshape because only 'time' remains after faceting/animation
|
|
730
|
+
data = reshape_data_for_heatmap(data, reshape_time='auto', facet_by='scenario', animate_by='period')
|
|
731
|
+
```
|
|
730
732
|
|
|
731
|
-
|
|
732
|
-
data_2d = padded_data.reshape(cols, nr_of_steps_per_column)
|
|
733
|
+
Explicit reshaping:
|
|
733
734
|
|
|
734
|
-
|
|
735
|
+
```python
|
|
736
|
+
# Explicitly reshape to daily pattern
|
|
737
|
+
data = reshape_data_for_heatmap(data, reshape_time=('D', 'h'))
|
|
738
|
+
```
|
|
735
739
|
|
|
740
|
+
No reshaping:
|
|
736
741
|
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
fill: Literal['ffill', 'bfill'] | None = None,
|
|
742
|
-
) -> pd.DataFrame:
|
|
742
|
+
```python
|
|
743
|
+
# Keep data as-is
|
|
744
|
+
data = reshape_data_for_heatmap(data, reshape_time=None)
|
|
745
|
+
```
|
|
743
746
|
"""
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
747
|
+
# If no time dimension, return data as-is
|
|
748
|
+
if 'time' not in data.dims:
|
|
749
|
+
return data
|
|
750
|
+
|
|
751
|
+
# Handle None (disabled) - return data as-is
|
|
752
|
+
if reshape_time is None:
|
|
753
|
+
return data
|
|
754
|
+
|
|
755
|
+
# Determine timeframes and timesteps_per_frame based on reshape_time parameter
|
|
756
|
+
if reshape_time == 'auto':
|
|
757
|
+
# Check if we need automatic time reshaping
|
|
758
|
+
facet_dims_used = []
|
|
759
|
+
if facet_by:
|
|
760
|
+
facet_dims_used = [facet_by] if isinstance(facet_by, str) else list(facet_by)
|
|
761
|
+
if animate_by:
|
|
762
|
+
facet_dims_used.append(animate_by)
|
|
763
|
+
|
|
764
|
+
# Get dimensions that would remain for heatmap
|
|
765
|
+
potential_heatmap_dims = [dim for dim in data.dims if dim not in facet_dims_used]
|
|
766
|
+
|
|
767
|
+
# Auto-reshape if only 'time' dimension remains
|
|
768
|
+
if len(potential_heatmap_dims) == 1 and potential_heatmap_dims[0] == 'time':
|
|
769
|
+
logger.debug(
|
|
770
|
+
"Auto-applying time reshaping: Only 'time' dimension remains after faceting/animation. "
|
|
771
|
+
"Using default timeframes='D' and timesteps_per_frame='h'. "
|
|
772
|
+
"To customize, use reshape_time=('D', 'h') or disable with reshape_time=None."
|
|
773
|
+
)
|
|
774
|
+
timeframes, timesteps_per_frame = 'D', 'h'
|
|
775
|
+
else:
|
|
776
|
+
# No reshaping needed
|
|
777
|
+
return data
|
|
778
|
+
elif isinstance(reshape_time, tuple):
|
|
779
|
+
# Explicit reshaping
|
|
780
|
+
timeframes, timesteps_per_frame = reshape_time
|
|
781
|
+
else:
|
|
782
|
+
raise ValueError(f"reshape_time must be 'auto', a tuple like ('D', 'h'), or None. Got: {reshape_time}")
|
|
755
783
|
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
"""
|
|
760
|
-
assert pd.api.types.is_datetime64_any_dtype(df.index), (
|
|
761
|
-
'The index of the DataFrame must be datetime to transform it properly for a heatmap plot'
|
|
762
|
-
)
|
|
784
|
+
# Validate that time is datetime
|
|
785
|
+
if not np.issubdtype(data.coords['time'].dtype, np.datetime64):
|
|
786
|
+
raise ValueError(f'Time dimension must be datetime-based, got {data.coords["time"].dtype}')
|
|
763
787
|
|
|
764
|
-
# Define formats for different combinations
|
|
788
|
+
# Define formats for different combinations
|
|
765
789
|
formats = {
|
|
766
790
|
('YS', 'W'): ('%Y', '%W'),
|
|
767
791
|
('YS', 'D'): ('%Y', '%j'), # day of year
|
|
768
792
|
('YS', 'h'): ('%Y', '%j %H:00'),
|
|
769
793
|
('MS', 'D'): ('%Y-%m', '%d'), # day of month
|
|
770
794
|
('MS', 'h'): ('%Y-%m', '%d %H:00'),
|
|
771
|
-
('W', 'D'): ('%Y-w%W', '%w_%A'), # week and day of week
|
|
795
|
+
('W', 'D'): ('%Y-w%W', '%w_%A'), # week and day of week
|
|
772
796
|
('W', 'h'): ('%Y-w%W', '%w_%A %H:00'),
|
|
773
797
|
('D', 'h'): ('%Y-%m-%d', '%H:00'), # Day and hour
|
|
774
798
|
('D', '15min'): ('%Y-%m-%d', '%H:%M'), # Day and minute
|
|
@@ -776,43 +800,64 @@ def heat_map_data_from_df(
|
|
|
776
800
|
('h', 'min'): ('%Y-%m-%d %H:00', '%M'), # minute of hour
|
|
777
801
|
}
|
|
778
802
|
|
|
779
|
-
|
|
780
|
-
raise ValueError('DataFrame is empty.')
|
|
781
|
-
diffs = df.index.to_series().diff().dropna()
|
|
782
|
-
minimum_time_diff_in_min = diffs.min().total_seconds() / 60
|
|
783
|
-
time_intervals = {'min': 1, '15min': 15, 'h': 60, 'D': 24 * 60, 'W': 7 * 24 * 60}
|
|
784
|
-
if time_intervals[steps_per_period] > minimum_time_diff_in_min:
|
|
785
|
-
logger.error(
|
|
786
|
-
f'To compute the heatmap, the data was aggregated from {minimum_time_diff_in_min:.2f} min to '
|
|
787
|
-
f'{time_intervals[steps_per_period]:.2f} min. Mean values are displayed.'
|
|
788
|
-
)
|
|
789
|
-
|
|
790
|
-
# Select the format based on the `periods` and `steps_per_period` combination
|
|
791
|
-
format_pair = (periods, steps_per_period)
|
|
803
|
+
format_pair = (timeframes, timesteps_per_frame)
|
|
792
804
|
if format_pair not in formats:
|
|
793
805
|
raise ValueError(f'{format_pair} is not a valid format. Choose from {list(formats.keys())}')
|
|
794
806
|
period_format, step_format = formats[format_pair]
|
|
795
807
|
|
|
796
|
-
|
|
808
|
+
# Check if resampling is needed
|
|
809
|
+
if data.sizes['time'] > 1:
|
|
810
|
+
# Use NumPy for more efficient timedelta computation
|
|
811
|
+
time_values = data.coords['time'].values # Already numpy datetime64[ns]
|
|
812
|
+
# Calculate differences and convert to minutes
|
|
813
|
+
time_diffs = np.diff(time_values).astype('timedelta64[s]').astype(float) / 60.0
|
|
814
|
+
if time_diffs.size > 0:
|
|
815
|
+
min_time_diff_min = np.nanmin(time_diffs)
|
|
816
|
+
time_intervals = {'min': 1, '15min': 15, 'h': 60, 'D': 24 * 60, 'W': 7 * 24 * 60}
|
|
817
|
+
if time_intervals[timesteps_per_frame] > min_time_diff_min:
|
|
818
|
+
logger.warning(
|
|
819
|
+
f'Resampling data from {min_time_diff_min:.2f} min to '
|
|
820
|
+
f'{time_intervals[timesteps_per_frame]:.2f} min. Mean values are displayed.'
|
|
821
|
+
)
|
|
797
822
|
|
|
798
|
-
|
|
823
|
+
# Resample along time dimension
|
|
824
|
+
resampled = data.resample(time=timesteps_per_frame).mean()
|
|
799
825
|
|
|
800
|
-
|
|
801
|
-
|
|
826
|
+
# Apply fill if specified
|
|
827
|
+
if fill == 'ffill':
|
|
828
|
+
resampled = resampled.ffill(dim='time')
|
|
802
829
|
elif fill == 'bfill':
|
|
803
|
-
|
|
830
|
+
resampled = resampled.bfill(dim='time')
|
|
831
|
+
|
|
832
|
+
# Create period and step labels
|
|
833
|
+
time_values = pd.to_datetime(resampled.coords['time'].values)
|
|
834
|
+
period_labels = time_values.strftime(period_format)
|
|
835
|
+
step_labels = time_values.strftime(step_format)
|
|
836
|
+
|
|
837
|
+
# Handle special case for weekly day format
|
|
838
|
+
if '%w_%A' in step_format:
|
|
839
|
+
step_labels = pd.Series(step_labels).replace('0_Sunday', '7_Sunday').values
|
|
840
|
+
|
|
841
|
+
# Add period and step as coordinates
|
|
842
|
+
resampled = resampled.assign_coords(
|
|
843
|
+
{
|
|
844
|
+
'timeframe': ('time', period_labels),
|
|
845
|
+
'timestep': ('time', step_labels),
|
|
846
|
+
}
|
|
847
|
+
)
|
|
804
848
|
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
resampled_data['step'] = resampled_data['step'].apply(
|
|
809
|
-
lambda x: x.replace('0_Sunday', '7_Sunday') if '0_Sunday' in x else x
|
|
810
|
-
)
|
|
849
|
+
# Convert to multi-index and unstack
|
|
850
|
+
resampled = resampled.set_index(time=['timeframe', 'timestep'])
|
|
851
|
+
result = resampled.unstack('time')
|
|
811
852
|
|
|
812
|
-
#
|
|
813
|
-
|
|
853
|
+
# Ensure timestep and timeframe come first in dimension order
|
|
854
|
+
# Get other dimensions
|
|
855
|
+
other_dims = [d for d in result.dims if d not in ['timestep', 'timeframe']]
|
|
814
856
|
|
|
815
|
-
|
|
857
|
+
# Reorder: timestep, timeframe, then other dimensions
|
|
858
|
+
result = result.transpose('timestep', 'timeframe', *other_dims)
|
|
859
|
+
|
|
860
|
+
return result
|
|
816
861
|
|
|
817
862
|
|
|
818
863
|
def plot_network(
|
|
@@ -1413,6 +1458,311 @@ def dual_pie_with_matplotlib(
|
|
|
1413
1458
|
return fig, axes
|
|
1414
1459
|
|
|
1415
1460
|
|
|
1461
|
+
def heatmap_with_plotly(
    data: xr.DataArray,
    colors: ColorType = 'viridis',
    title: str = '',
    facet_by: str | list[str] | None = None,
    animate_by: str | None = None,
    facet_cols: int = 3,
    reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
    | Literal['auto']
    | None = 'auto',
    fill: Literal['ffill', 'bfill'] | None = 'ffill',
) -> go.Figure:
    """
    Plot a heatmap visualization using Plotly's imshow with faceting and animation support.

    This function creates heatmap visualizations from xarray DataArrays, supporting
    multi-dimensional data through faceting (subplots) and animation. It automatically
    handles dimension reduction and data reshaping for optimal heatmap display.

    Automatic Time Reshaping:
        If only the 'time' dimension remains after faceting/animation (making the data 1D),
        the function automatically reshapes time into a 2D format using default values
        (timeframes='D', timesteps_per_frame='h'). This creates a daily pattern heatmap
        showing hours vs days.

    Args:
        data: An xarray DataArray containing the data to visualize. Should have at least
              2 dimensions, or a 'time' dimension that can be reshaped into 2D.
        colors: Color specification. Only a colormap name (str) is passed to Plotly, e.g.
                'viridis', 'plasma', 'RdBu', 'portland'; any other value falls back to 'viridis'.
        title: The main title of the heatmap.
        facet_by: Dimension to create facets for. Creates a subplot grid.
                  Can be a single dimension name or list (only first dimension used).
                  Note: px.imshow only supports single-dimension faceting. Further
                  dimensions in the list (other than ``animate_by``) are reduced to their
                  first slice so the remaining array still fits px.imshow.
                  Dimensions that don't exist in the data are silently ignored.
        animate_by: Dimension to animate over. Creates animation frames.
                    If the dimension doesn't exist in the data, it will be silently ignored.
        facet_cols: Number of columns in the facet grid (used with facet_by).
        reshape_time: Time reshaping configuration:
                     - 'auto' (default): Automatically applies ('D', 'h') if only 'time' dimension remains
                     - Tuple like ('D', 'h'): Explicit time reshaping (days vs hours)
                     - None: Disable time reshaping (will error if only 1D time data)
        fill: Method to fill missing values when reshaping time: 'ffill' or 'bfill'. Default is 'ffill'.

    Returns:
        A Plotly figure object containing the heatmap visualization. An empty figure is
        returned when ``data`` is empty or when fewer than 2 dimensions remain for the
        heatmap after faceting/animation. If more than 2 remain, the extra ones are
        reduced to their first slice (a warning is logged).

    Examples:
        Simple heatmap:

        ```python
        fig = heatmap_with_plotly(data_array, colors='RdBu', title='Temperature Map')
        ```

        Facet by scenario:

        ```python
        fig = heatmap_with_plotly(data_array, facet_by='scenario', facet_cols=2)
        ```

        Animate by period:

        ```python
        fig = heatmap_with_plotly(data_array, animate_by='period')
        ```

        Automatic time reshaping (when only time dimension remains):

        ```python
        # Data with dims ['time', 'scenario', 'period']
        # After faceting and animation, only 'time' remains -> auto-reshapes to (timestep, timeframe)
        fig = heatmap_with_plotly(data_array, facet_by='scenario', animate_by='period')
        ```

        Explicit time reshaping:

        ```python
        fig = heatmap_with_plotly(data_array, facet_by='scenario', animate_by='period', reshape_time=('W', 'D'))
        ```
    """
    # Nothing to draw for empty input.
    if data.size == 0:
        return go.Figure()

    # Apply time reshaping (may turn a lone 'time' dimension into timestep x timeframe).
    data = reshape_data_for_heatmap(
        data, reshape_time=reshape_time, facet_by=facet_by, animate_by=animate_by, fill=fill
    )

    available_dims = list(data.dims)

    # Validate and filter facet_by dimensions (builds new objects; the caller's list is never mutated).
    if isinstance(facet_by, str):
        if facet_by not in available_dims:
            logger.debug(
                f"Dimension '{facet_by}' not found in data. Available dimensions: {available_dims}. "
                'Ignoring facet_by parameter.'
            )
            facet_by = None
    elif facet_by is not None:
        missing_dims = [dim for dim in facet_by if dim not in available_dims]
        facet_by = [dim for dim in facet_by if dim in available_dims]
        if missing_dims:
            logger.debug(
                f'Dimensions {missing_dims} not found in data. Available dimensions: {available_dims}. '
                f'Using only existing dimensions: {facet_by if facet_by else "none"}.'
            )
        if not facet_by:
            facet_by = None

    # Validate animate_by dimension
    if animate_by is not None and animate_by not in available_dims:
        logger.debug(
            f"Dimension '{animate_by}' not found in data. Available dimensions: {available_dims}. "
            'Ignoring animate_by parameter.'
        )
        animate_by = None

    # px.imshow only supports facet_col (no facet_row), so at most one dimension can be faceted.
    facet_col_param = None
    if facet_by:
        # Copy into a fresh list so that nothing below can alias/mutate facet_by.
        facet_list = [facet_by] if isinstance(facet_by, str) else list(facet_by)
        facet_col_param = facet_list[0]
        if len(facet_list) >= 2:
            logger.warning(
                'px.imshow only supports faceting by a single dimension. '
                f'Using {facet_list[0]} for faceting. Dimensions {facet_list[1:]} will be ignored. '
                'Consider using animate_by for additional dimensions.'
            )
            # Really drop the ignored dimensions (first slice): left in the array they would give
            # px.imshow more dimensions than it can consume and make it fail.
            ignored_dims = [dim for dim in facet_list[1:] if dim not in (facet_col_param, animate_by)]
            if ignored_dims:
                data = data.isel({dim: 0 for dim in ignored_dims})

    # Dimensions consumed by faceting/animation; whatever remains forms the heatmap itself.
    slice_dims = {dim for dim in (facet_col_param, animate_by) if dim is not None}
    heatmap_dims = [dim for dim in data.dims if dim not in slice_dims]

    if len(heatmap_dims) < 2:
        # Need at least 2 dimensions for a heatmap
        logger.error(
            'Heatmap requires at least 2 dimensions for rows and columns. '
            f'After faceting/animation, only {len(heatmap_dims)} dimension(s) remain: {heatmap_dims}'
        )
        return go.Figure()

    if len(heatmap_dims) > 2:
        # px.imshow would reject (or misread as RGB channels) a third image dimension;
        # keep the first two, like heatmap_with_matplotlib does.
        extra_dims = heatmap_dims[2:]
        logger.warning(
            f'Heatmap can only display 2 dimensions, but {heatmap_dims} remain after faceting/animation. '
            f'Using the first slice of {extra_dims}. Use facet_by/animate_by to display them.'
        )
        data = data.isel({dim: 0 for dim in extra_dims})

    color_scale = colors if isinstance(colors, str) else 'viridis'

    # px.imshow can work directly with xarray DataArrays (dims become axis labels).
    imshow_kwargs = {
        'img': data,
        'color_continuous_scale': color_scale,
        'title': title,
    }
    if facet_col_param:
        imshow_kwargs['facet_col'] = facet_col_param
        if facet_cols:
            imshow_kwargs['facet_col_wrap'] = facet_cols
    if animate_by:
        imshow_kwargs['animation_frame'] = animate_by

    try:
        fig = px.imshow(**imshow_kwargs)
    except Exception as e:
        logger.error(f'Error creating imshow plot: {e}. Falling back to basic heatmap.')
        # Fallback: a plain 2D heatmap without faceting/animation. Collapse those dimensions to
        # their first slice so the raw array really is 2D; otherwise the fallback fails as well.
        fallback = data.isel({dim: 0 for dim in slice_dims})
        fig = px.imshow(
            fallback.values,
            color_continuous_scale=color_scale,
            title=title,
        )

    # Update layout with basic styling
    fig.update_layout(
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)',
        font=dict(size=12),
    )

    return fig
|
|
1653
|
+
|
|
1654
|
+
|
|
1655
|
+
def _resolve_matplotlib_fig_ax(
    fig: plt.Figure | None, ax: plt.Axes | None, figsize: tuple[float, float]
) -> tuple[plt.Figure, plt.Axes]:
    """Return a (figure, axes) pair, reusing whichever of the two the caller supplied."""
    if ax is not None:
        # Keep the caller's axes; derive its figure if none was passed.
        return (fig if fig is not None else ax.figure), ax
    if fig is not None:
        # Keep the caller's figure; use (or create) its current axes.
        return fig, fig.gca()
    return plt.subplots(figsize=figsize)


def heatmap_with_matplotlib(
    data: xr.DataArray,
    colors: ColorType = 'viridis',
    title: str = '',
    figsize: tuple[float, float] = (12, 6),
    fig: plt.Figure | None = None,
    ax: plt.Axes | None = None,
    reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
    | Literal['auto']
    | None = 'auto',
    fill: Literal['ffill', 'bfill'] | None = 'ffill',
) -> tuple[plt.Figure, plt.Axes]:
    """
    Plot a heatmap visualization using Matplotlib's imshow.

    This function creates a basic 2D heatmap from an xarray DataArray using matplotlib's
    imshow function. For multi-dimensional data, only the first two dimensions are used.

    Args:
        data: An xarray DataArray containing the data to visualize. Should have at least
              2 dimensions. If more than 2 dimensions exist, additional dimensions will
              be reduced by taking the first slice.
        colors: Color specification. Should be a colormap name (e.g., 'viridis', 'RdBu');
                any non-string value falls back to 'viridis'.
        title: The title of the heatmap.
        figsize: The size of the figure (width, height) in inches. Only used when a new
                 figure has to be created.
        fig: A Matplotlib figure object to plot on. If given without ``ax``, the figure's
             current axes (``fig.gca()``) is used.
        ax: A Matplotlib axes object to plot on. If given without ``fig``, the axes'
            parent figure is used. If neither is given, a new figure and axes are created.
        reshape_time: Time reshaping configuration:
                     - 'auto' (default): Automatically applies ('D', 'h') if only 'time' dimension
                     - Tuple like ('D', 'h'): Explicit time reshaping (days vs hours)
                     - None: Disable time reshaping
        fill: Method to fill missing values when reshaping time: 'ffill' or 'bfill'. Default is 'ffill'.

    Returns:
        A tuple containing the Matplotlib figure and axes objects used for the plot.

    Notes:
        - Matplotlib backend doesn't support faceting or animation. Use plotly engine for those features.
        - Rows are drawn with ``origin='upper'``, so the first row appears at the top.
        - A horizontal colorbar is added to show the value scale.
        - If fewer than 2 dimensions remain (e.g. 1D time data with ``reshape_time=None``),
          an error is logged and the empty figure/axes are returned.

    Examples:
        ```python
        fig, ax = heatmap_with_matplotlib(data_array, colors='RdBu', title='Temperature')
        plt.savefig('heatmap.png')
        ```

        Time reshaping:

        ```python
        fig, ax = heatmap_with_matplotlib(data_array, reshape_time=('D', 'h'))
        ```
    """
    # Handle empty data
    if data.size == 0:
        return _resolve_matplotlib_fig_ax(fig, ax, figsize)

    # Apply time reshaping before creating any figure, so a reshape error doesn't leave an orphan figure.
    # Matplotlib doesn't support faceting/animation, so we pass None for those
    data = reshape_data_for_heatmap(data, reshape_time=reshape_time, facet_by=None, animate_by=None, fill=fill)

    fig, ax = _resolve_matplotlib_fig_ax(fig, ax, figsize)

    # Extract data values; reduce anything beyond 2 dimensions to its first slice.
    if isinstance(data, xr.DataArray):
        dims = list(data.dims)
        if len(dims) > 2:
            logger.warning(
                f'Data has {len(dims)} dimensions: {dims}. '
                'Only the first 2 will be used for the heatmap. '
                'Use the plotly engine for faceting/animation support.'
            )
            data = data.isel({dim: 0 for dim in dims[2:]})

        values = data.values
        x_labels = data.dims[1] if len(data.dims) > 1 else 'x'
        y_labels = data.dims[0] if len(data.dims) > 0 else 'y'
    else:
        values = data
        x_labels = 'x'
        y_labels = 'y'

    # imshow needs a 2D array; report clearly instead of crashing inside matplotlib.
    if np.ndim(values) < 2:
        logger.error(
            f'Heatmap requires at least 2 dimensions for rows and columns, got {np.ndim(values)}. '
            "Use reshape_time (e.g. ('D', 'h')) to reshape 1D time series data."
        )
        return fig, ax

    cmap = colors if isinstance(colors, str) else 'viridis'

    im = ax.imshow(values, cmap=cmap, aspect='auto', origin='upper')

    # Attach the colorbar through the figure owning `ax`. plt.colorbar would go through the
    # global "current figure", which may not be ax's figure when the caller supplied fig/ax.
    cbar = ax.figure.colorbar(im, ax=ax, orientation='horizontal', pad=0.1, aspect=15, fraction=0.05)
    cbar.set_label('Value')

    ax.set_xlabel(str(x_labels).capitalize())
    ax.set_ylabel(str(y_labels).capitalize())
    ax.set_title(title)

    fig.tight_layout()

    return fig, ax
|
|
1764
|
+
|
|
1765
|
+
|
|
1416
1766
|
def export_figure(
|
|
1417
1767
|
figure_like: go.Figure | tuple[plt.Figure, plt.Axes],
|
|
1418
1768
|
default_path: pathlib.Path,
|