flixopt 3.0.2__py3-none-any.whl → 3.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flixopt might be problematic. Click here for more details.

flixopt/plotting.py CHANGED
@@ -39,6 +39,7 @@ import pandas as pd
39
39
  import plotly.express as px
40
40
  import plotly.graph_objects as go
41
41
  import plotly.offline
42
+ import xarray as xr
42
43
  from plotly.exceptions import PlotlyError
43
44
 
44
45
  if TYPE_CHECKING:
@@ -326,149 +327,275 @@ class ColorProcessor:
326
327
 
327
328
 
328
329
  def with_plotly(
329
- data: pd.DataFrame,
330
- style: Literal['stacked_bar', 'line', 'area', 'grouped_bar'] = 'stacked_bar',
330
+ data: pd.DataFrame | xr.DataArray | xr.Dataset,
331
+ mode: Literal['stacked_bar', 'line', 'area', 'grouped_bar'] = 'stacked_bar',
331
332
  colors: ColorType = 'viridis',
332
333
  title: str = '',
333
334
  ylabel: str = '',
334
335
  xlabel: str = 'Time in h',
335
336
  fig: go.Figure | None = None,
337
+ facet_by: str | list[str] | None = None,
338
+ animate_by: str | None = None,
339
+ facet_cols: int = 3,
340
+ shared_yaxes: bool = True,
341
+ shared_xaxes: bool = True,
336
342
  ) -> go.Figure:
337
343
  """
338
- Plot a DataFrame with Plotly, using either stacked bars or stepped lines.
344
+ Plot data with Plotly using facets (subplots) and/or animation for multidimensional data.
345
+
346
+ Uses Plotly Express for convenient faceting and animation with automatic styling.
347
+ For simple plots without faceting, can optionally add to an existing figure.
339
348
 
340
349
  Args:
341
- data: A DataFrame containing the data to plot, where the index represents time (e.g., hours),
342
- and each column represents a separate data series.
343
- style: The plotting style. Use 'stacked_bar' for stacked bar charts, 'line' for stepped lines,
344
- or 'area' for stacked area charts.
345
- colors: Color specification, can be:
346
- - A string with a colorscale name (e.g., 'viridis', 'plasma')
347
- - A list of color strings (e.g., ['#ff0000', '#00ff00'])
348
- - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'})
349
- title: The title of the plot.
350
+ data: A DataFrame or xarray DataArray/Dataset to plot.
351
+ mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for lines,
352
+ 'area' for stacked area charts, or 'grouped_bar' for grouped bar charts.
353
+ colors: Color specification (colormap, list, or dict mapping labels to colors).
354
+ title: The main title of the plot.
350
355
  ylabel: The label for the y-axis.
351
356
  xlabel: The label for the x-axis.
352
- fig: A Plotly figure object to plot on. If not provided, a new figure will be created.
357
+ fig: A Plotly figure object to plot on (only for simple plots without faceting).
358
+ If not provided, a new figure will be created.
359
+ facet_by: Dimension(s) to create facets for. Creates a subplot grid.
360
+ Can be a single dimension name or list of dimensions (max 2 for facet_row and facet_col).
361
+ If the dimension doesn't exist in the data, it will be silently ignored.
362
+ animate_by: Dimension to animate over. Creates animation frames.
363
+ If the dimension doesn't exist in the data, it will be silently ignored.
364
+ facet_cols: Number of columns in the facet grid (used when facet_by is single dimension).
365
+ shared_yaxes: Whether subplots share y-axes.
366
+ shared_xaxes: Whether subplots share x-axes.
353
367
 
354
368
  Returns:
355
- A Plotly figure object containing the generated plot.
356
- """
357
- if style not in ('stacked_bar', 'line', 'area', 'grouped_bar'):
358
- raise ValueError(f"'style' must be one of {{'stacked_bar','line','area', 'grouped_bar'}}, got {style!r}")
359
- if data.empty:
360
- return go.Figure()
369
+ A Plotly figure object containing the faceted/animated plot.
361
370
 
362
- processed_colors = ColorProcessor(engine='plotly').process_colors(colors, list(data.columns))
371
+ Examples:
372
+ Simple plot:
363
373
 
364
- fig = fig if fig is not None else go.Figure()
374
+ ```python
375
+ fig = with_plotly(df, mode='area', title='Energy Mix')
376
+ ```
365
377
 
366
- if style == 'stacked_bar':
367
- for i, column in enumerate(data.columns):
368
- fig.add_trace(
369
- go.Bar(
370
- x=data.index,
371
- y=data[column],
372
- name=column,
373
- marker=dict(
374
- color=processed_colors[i], line=dict(width=0, color='rgba(0,0,0,0)')
375
- ), # Transparent line with 0 width
376
- )
377
- )
378
+ Facet by scenario:
378
379
 
379
- fig.update_layout(
380
- barmode='relative',
381
- bargap=0, # No space between bars
382
- bargroupgap=0, # No space between grouped bars
383
- )
384
- if style == 'grouped_bar':
385
- for i, column in enumerate(data.columns):
386
- fig.add_trace(go.Bar(x=data.index, y=data[column], name=column, marker=dict(color=processed_colors[i])))
380
+ ```python
381
+ fig = with_plotly(ds, facet_by='scenario', facet_cols=2)
382
+ ```
387
383
 
388
- fig.update_layout(
389
- barmode='group',
390
- bargap=0.2, # No space between bars
391
- bargroupgap=0, # space between grouped bars
392
- )
393
- elif style == 'line':
394
- for i, column in enumerate(data.columns):
395
- fig.add_trace(
396
- go.Scatter(
397
- x=data.index,
398
- y=data[column],
399
- mode='lines',
400
- name=column,
401
- line=dict(shape='hv', color=processed_colors[i]),
402
- )
403
- )
404
- elif style == 'area':
405
- data = data.copy()
406
- data[(data > -1e-5) & (data < 1e-5)] = 0 # Preventing issues with plotting
407
- # Split columns into positive, negative, and mixed categories
408
- positive_columns = list(data.columns[(data >= 0).where(~np.isnan(data), True).all()])
409
- negative_columns = list(data.columns[(data <= 0).where(~np.isnan(data), True).all()])
410
- negative_columns = [column for column in negative_columns if column not in positive_columns]
411
- mixed_columns = list(set(data.columns) - set(positive_columns + negative_columns))
412
-
413
- if mixed_columns:
414
- logger.error(
415
- f'Data for plotting stacked lines contains columns with both positive and negative values:'
416
- f' {mixed_columns}. These can not be stacked, and are printed as simple lines'
417
- )
384
+ Animate by period:
385
+
386
+ ```python
387
+ fig = with_plotly(ds, animate_by='period')
388
+ ```
389
+
390
+ Facet and animate:
391
+
392
+ ```python
393
+ fig = with_plotly(ds, facet_by='scenario', animate_by='period')
394
+ ```
395
+ """
396
+ if mode not in ('stacked_bar', 'line', 'area', 'grouped_bar'):
397
+ raise ValueError(f"'mode' must be one of {{'stacked_bar','line','area', 'grouped_bar'}}, got {mode!r}")
398
+
399
+ # Handle empty data
400
+ if isinstance(data, pd.DataFrame) and data.empty:
401
+ return go.Figure()
402
+ elif isinstance(data, xr.DataArray) and data.size == 0:
403
+ return go.Figure()
404
+ elif isinstance(data, xr.Dataset) and len(data.data_vars) == 0:
405
+ return go.Figure()
418
406
 
419
- # Get color mapping for all columns
420
- colors_stacked = {column: processed_colors[i] for i, column in enumerate(data.columns)}
421
-
422
- for column in positive_columns + negative_columns:
423
- fig.add_trace(
424
- go.Scatter(
425
- x=data.index,
426
- y=data[column],
427
- mode='lines',
428
- name=column,
429
- line=dict(shape='hv', color=colors_stacked[column]),
430
- fill='tonexty',
431
- stackgroup='pos' if column in positive_columns else 'neg',
407
+ # Warn if fig parameter is used with faceting
408
+ if fig is not None and (facet_by is not None or animate_by is not None):
409
+ logger.warning('The fig parameter is ignored when using faceting or animation. Creating a new figure.')
410
+ fig = None
411
+
412
+ # Convert xarray to long-form DataFrame for Plotly Express
413
+ if isinstance(data, (xr.DataArray, xr.Dataset)):
414
+ # Convert to long-form (tidy) DataFrame
415
+ # Structure: time, variable, value, scenario, period, ... (all dims as columns)
416
+ if isinstance(data, xr.Dataset):
417
+ # Stack all data variables into long format
418
+ df_long = data.to_dataframe().reset_index()
419
+ # Melt to get: time, scenario, period, ..., variable, value
420
+ id_vars = [dim for dim in data.dims]
421
+ value_vars = list(data.data_vars)
422
+ df_long = df_long.melt(id_vars=id_vars, value_vars=value_vars, var_name='variable', value_name='value')
423
+ else:
424
+ # DataArray
425
+ df_long = data.to_dataframe().reset_index()
426
+ if data.name:
427
+ df_long = df_long.rename(columns={data.name: 'value'})
428
+ else:
429
+ # Unnamed DataArray, find the value column
430
+ non_dim_cols = [col for col in df_long.columns if col not in data.dims]
431
+ if len(non_dim_cols) != 1:
432
+ raise ValueError(
433
+ f'Expected exactly one non-dimension column for unnamed DataArray, '
434
+ f'but found {len(non_dim_cols)}: {non_dim_cols}'
435
+ )
436
+ value_col = non_dim_cols[0]
437
+ df_long = df_long.rename(columns={value_col: 'value'})
438
+ df_long['variable'] = data.name or 'data'
439
+ else:
440
+ # Already a DataFrame - convert to long format for Plotly Express
441
+ df_long = data.reset_index()
442
+ if 'time' not in df_long.columns:
443
+ # First column is probably time
444
+ df_long = df_long.rename(columns={df_long.columns[0]: 'time'})
445
+ # Melt to long format
446
+ id_vars = [
447
+ col
448
+ for col in df_long.columns
449
+ if col in ['time', 'scenario', 'period']
450
+ or col in (facet_by if isinstance(facet_by, list) else [facet_by] if facet_by else [])
451
+ ]
452
+ value_vars = [col for col in df_long.columns if col not in id_vars]
453
+ df_long = df_long.melt(id_vars=id_vars, value_vars=value_vars, var_name='variable', value_name='value')
454
+
455
+ # Validate facet_by and animate_by dimensions exist in the data
456
+ available_dims = [col for col in df_long.columns if col not in ['variable', 'value']]
457
+
458
+ # Check facet_by dimensions
459
+ if facet_by is not None:
460
+ if isinstance(facet_by, str):
461
+ if facet_by not in available_dims:
462
+ logger.debug(
463
+ f"Dimension '{facet_by}' not found in data. Available dimensions: {available_dims}. "
464
+ f'Ignoring facet_by parameter.'
432
465
  )
433
- )
466
+ facet_by = None
467
+ elif isinstance(facet_by, list):
468
+ # Filter out dimensions that don't exist
469
+ missing_dims = [dim for dim in facet_by if dim not in available_dims]
470
+ facet_by = [dim for dim in facet_by if dim in available_dims]
471
+ if missing_dims:
472
+ logger.debug(
473
+ f'Dimensions {missing_dims} not found in data. Available dimensions: {available_dims}. '
474
+ f'Using only existing dimensions: {facet_by if facet_by else "none"}.'
475
+ )
476
+ if len(facet_by) == 0:
477
+ facet_by = None
478
+
479
+ # Check animate_by dimension
480
+ if animate_by is not None and animate_by not in available_dims:
481
+ logger.debug(
482
+ f"Dimension '{animate_by}' not found in data. Available dimensions: {available_dims}. "
483
+ f'Ignoring animate_by parameter.'
484
+ )
485
+ animate_by = None
486
+
487
+ # Setup faceting parameters for Plotly Express
488
+ facet_row = None
489
+ facet_col = None
490
+ if facet_by:
491
+ if isinstance(facet_by, str):
492
+ # Single facet dimension - use facet_col with facet_col_wrap
493
+ facet_col = facet_by
494
+ elif len(facet_by) == 1:
495
+ facet_col = facet_by[0]
496
+ elif len(facet_by) == 2:
497
+ # Two facet dimensions - use facet_row and facet_col
498
+ facet_row = facet_by[0]
499
+ facet_col = facet_by[1]
500
+ else:
501
+ raise ValueError(f'facet_by can have at most 2 dimensions, got {len(facet_by)}')
502
+
503
+ # Process colors
504
+ all_vars = df_long['variable'].unique().tolist()
505
+ processed_colors = ColorProcessor(engine='plotly').process_colors(colors, all_vars)
506
+ color_discrete_map = {var: color for var, color in zip(all_vars, processed_colors, strict=True)}
507
+
508
+ # Create plot using Plotly Express based on mode
509
+ common_args = {
510
+ 'data_frame': df_long,
511
+ 'x': 'time',
512
+ 'y': 'value',
513
+ 'color': 'variable',
514
+ 'facet_row': facet_row,
515
+ 'facet_col': facet_col,
516
+ 'animation_frame': animate_by,
517
+ 'color_discrete_map': color_discrete_map,
518
+ 'title': title,
519
+ 'labels': {'value': ylabel, 'time': xlabel, 'variable': ''},
520
+ }
434
521
 
435
- for column in mixed_columns:
436
- fig.add_trace(
437
- go.Scatter(
438
- x=data.index,
439
- y=data[column],
440
- mode='lines',
441
- name=column,
442
- line=dict(shape='hv', color=colors_stacked[column], dash='dash'),
522
+ # Add facet_col_wrap for single facet dimension
523
+ if facet_col and not facet_row:
524
+ common_args['facet_col_wrap'] = facet_cols
525
+
526
+ if mode == 'stacked_bar':
527
+ fig = px.bar(**common_args)
528
+ fig.update_traces(marker_line_width=0)
529
+ fig.update_layout(barmode='relative', bargap=0, bargroupgap=0)
530
+ elif mode == 'grouped_bar':
531
+ fig = px.bar(**common_args)
532
+ fig.update_layout(barmode='group', bargap=0.2, bargroupgap=0)
533
+ elif mode == 'line':
534
+ fig = px.line(**common_args, line_shape='hv') # Stepped lines
535
+ elif mode == 'area':
536
+ # Use Plotly Express to create the area plot (preserves animation, legends, faceting)
537
+ fig = px.area(**common_args, line_shape='hv')
538
+
539
+ # Classify each variable based on its values
540
+ variable_classification = {}
541
+ for var in all_vars:
542
+ var_data = df_long[df_long['variable'] == var]['value']
543
+ var_data_clean = var_data[(var_data < -1e-5) | (var_data > 1e-5)]
544
+
545
+ if len(var_data_clean) == 0:
546
+ variable_classification[var] = 'zero'
547
+ else:
548
+ has_pos, has_neg = (var_data_clean > 0).any(), (var_data_clean < 0).any()
549
+ variable_classification[var] = (
550
+ 'mixed' if has_pos and has_neg else ('negative' if has_neg else 'positive')
443
551
  )
444
- )
445
552
 
446
- # Update layout for better aesthetics
553
+ # Log warning for mixed variables
554
+ mixed_vars = [v for v, c in variable_classification.items() if c == 'mixed']
555
+ if mixed_vars:
556
+ logger.warning(f'Variables with both positive and negative values: {mixed_vars}. Plotted as dashed lines.')
557
+
558
+ all_traces = list(fig.data)
559
+ for frame in fig.frames:
560
+ all_traces.extend(frame.data)
561
+
562
+ for trace in all_traces:
563
+ cls = variable_classification.get(trace.name, None)
564
+ # Only stack positive and negative, not mixed or zero
565
+ trace.stackgroup = cls if cls in ('positive', 'negative') else None
566
+
567
+ if cls in ('positive', 'negative'):
568
+ # Stacked area: fill with the trace's line color and remove the line border
569
+ if hasattr(trace, 'line') and trace.line.color:
570
+ trace.fillcolor = trace.line.color
571
+ trace.line.width = 0
572
+ elif cls == 'mixed':
573
+ # Mixed variables: show as dashed line, not stacked
574
+ if hasattr(trace, 'line'):
575
+ trace.line.width = 2
576
+ trace.line.dash = 'dash'
577
+ if hasattr(trace, 'fill'):
578
+ trace.fill = None
579
+
580
+ # Update layout with basic styling (Plotly Express handles sizing automatically)
447
581
  fig.update_layout(
448
- title=title,
449
- yaxis=dict(
450
- title=ylabel,
451
- showgrid=True, # Enable grid lines on the y-axis
452
- gridcolor='lightgrey', # Customize grid line color
453
- gridwidth=0.5, # Customize grid line width
454
- ),
455
- xaxis=dict(
456
- title=xlabel,
457
- showgrid=True, # Enable grid lines on the x-axis
458
- gridcolor='lightgrey', # Customize grid line color
459
- gridwidth=0.5, # Customize grid line width
460
- ),
461
- plot_bgcolor='rgba(0,0,0,0)', # Transparent background
462
- paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background
463
- font=dict(size=14), # Increase font size for better readability
582
+ plot_bgcolor='rgba(0,0,0,0)',
583
+ paper_bgcolor='rgba(0,0,0,0)',
584
+ font=dict(size=12),
464
585
  )
465
586
 
587
+ # Plotly Express already links the axes across subplots; un-link them when sharing is not requested
588
+ if not shared_yaxes:
589
+ fig.update_yaxes(matches=None)
590
+ if not shared_xaxes:
591
+ fig.update_xaxes(matches=None)
592
+
466
593
  return fig
467
594
 
468
595
 
469
596
  def with_matplotlib(
470
597
  data: pd.DataFrame,
471
- style: Literal['stacked_bar', 'line'] = 'stacked_bar',
598
+ mode: Literal['stacked_bar', 'line'] = 'stacked_bar',
472
599
  colors: ColorType = 'viridis',
473
600
  title: str = '',
474
601
  ylabel: str = '',
@@ -483,7 +610,7 @@ def with_matplotlib(
483
610
  Args:
484
611
  data: A DataFrame containing the data to plot. The index should represent time (e.g., hours),
485
612
  and each column represents a separate data series.
486
- style: Plotting style. Use 'stacked_bar' for stacked bar charts or 'line' for stepped lines.
613
+ mode: Plotting mode. Use 'stacked_bar' for stacked bar charts or 'line' for stepped lines.
487
614
  colors: Color specification, can be:
488
615
  - A string with a colormap name (e.g., 'viridis', 'plasma')
489
616
  - A list of color strings (e.g., ['#ff0000', '#00ff00'])
@@ -499,19 +626,19 @@ def with_matplotlib(
499
626
  A tuple containing the Matplotlib figure and axes objects used for the plot.
500
627
 
501
628
  Notes:
502
- - If `style` is 'stacked_bar', bars are stacked for both positive and negative values.
629
+ - If `mode` is 'stacked_bar', bars are stacked for both positive and negative values.
503
630
  Negative values are stacked separately without extra labels in the legend.
504
- - If `style` is 'line', stepped lines are drawn for each data series.
631
+ - If `mode` is 'line', stepped lines are drawn for each data series.
505
632
  """
506
- if style not in ('stacked_bar', 'line'):
507
- raise ValueError(f"'style' must be one of {{'stacked_bar','line'}} for matplotlib, got {style!r}")
633
+ if mode not in ('stacked_bar', 'line'):
634
+ raise ValueError(f"'mode' must be one of {{'stacked_bar','line'}} for matplotlib, got {mode!r}")
508
635
 
509
636
  if fig is None or ax is None:
510
637
  fig, ax = plt.subplots(figsize=figsize)
511
638
 
512
639
  processed_colors = ColorProcessor(engine='matplotlib').process_colors(colors, list(data.columns))
513
640
 
514
- if style == 'stacked_bar':
641
+ if mode == 'stacked_bar':
515
642
  cumulative_positive = np.zeros(len(data))
516
643
  cumulative_negative = np.zeros(len(data))
517
644
  width = data.index.to_series().diff().dropna().min() # Minimum time difference
@@ -542,7 +669,7 @@ def with_matplotlib(
542
669
  )
543
670
  cumulative_negative += negative_values.values
544
671
 
545
- elif style == 'line':
672
+ elif mode == 'line':
546
673
  for i, column in enumerate(data.columns):
547
674
  ax.step(data.index, data[column], where='post', color=processed_colors[i], label=column)
548
675
 
@@ -562,213 +689,110 @@ def with_matplotlib(
562
689
  return fig, ax
563
690
 
564
691
 
565
- def heat_map_matplotlib(
566
- data: pd.DataFrame,
567
- color_map: str = 'viridis',
568
- title: str = '',
569
- xlabel: str = 'Period',
570
- ylabel: str = 'Step',
571
- figsize: tuple[float, float] = (12, 6),
572
- ) -> tuple[plt.Figure, plt.Axes]:
692
+ def reshape_data_for_heatmap(
693
+ data: xr.DataArray,
694
+ reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
695
+ | Literal['auto']
696
+ | None = 'auto',
697
+ facet_by: str | list[str] | None = None,
698
+ animate_by: str | None = None,
699
+ fill: Literal['ffill', 'bfill'] | None = 'ffill',
700
+ ) -> xr.DataArray:
573
701
  """
574
- Plots a DataFrame as a heatmap using Matplotlib. The columns of the DataFrame will be displayed on the x-axis,
575
- the index will be displayed on the y-axis, and the values will represent the 'heat' intensity in the plot.
702
+ Reshape data for heatmap visualization, handling time dimension intelligently.
576
703
 
577
- Args:
578
- data: A DataFrame containing the data to be visualized. The index will be used for the y-axis, and columns will be used for the x-axis.
579
- The values in the DataFrame will be represented as colors in the heatmap.
580
- color_map: The colormap to use for the heatmap. Default is 'viridis'. Matplotlib supports various colormaps like 'plasma', 'inferno', 'cividis', etc.
581
- title: The title of the plot.
582
- xlabel: The label for the x-axis.
583
- ylabel: The label for the y-axis.
584
- figsize: The size of the figure to create. Default is (12, 6), which results in a width of 12 inches and a height of 6 inches.
585
-
586
- Returns:
587
- A tuple containing the Matplotlib `Figure` and `Axes` objects. The `Figure` contains the overall plot, while the `Axes` is the area
588
- where the heatmap is drawn. These can be used for further customization or saving the plot to a file.
589
-
590
- Notes:
591
- - The y-axis is flipped so that the first row of the DataFrame is displayed at the top of the plot.
592
- - The color scale is normalized based on the minimum and maximum values in the DataFrame.
593
- - The x-axis labels (periods) are placed at the top of the plot.
594
- - The colorbar is added horizontally at the bottom of the plot, with a label.
595
- """
596
-
597
- # Get the min and max values for color normalization
598
- color_bar_min, color_bar_max = data.min().min(), data.max().max()
599
-
600
- # Create the heatmap plot
601
- fig, ax = plt.subplots(figsize=figsize)
602
- ax.pcolormesh(data.values, cmap=color_map, shading='auto')
603
- ax.invert_yaxis() # Flip the y-axis to start at the top
604
-
605
- # Adjust ticks and labels for x and y axes
606
- ax.set_xticks(np.arange(len(data.columns)) + 0.5)
607
- ax.set_xticklabels(data.columns, ha='center')
608
- ax.set_yticks(np.arange(len(data.index)) + 0.5)
609
- ax.set_yticklabels(data.index, va='center')
610
-
611
- # Add labels to the axes
612
- ax.set_xlabel(xlabel, ha='center')
613
- ax.set_ylabel(ylabel, va='center')
614
- ax.set_title(title)
615
-
616
- # Position x-axis labels at the top
617
- ax.xaxis.set_label_position('top')
618
- ax.xaxis.set_ticks_position('top')
619
-
620
- # Add the colorbar
621
- sm1 = plt.cm.ScalarMappable(cmap=color_map, norm=plt.Normalize(vmin=color_bar_min, vmax=color_bar_max))
622
- sm1.set_array([])
623
- fig.colorbar(sm1, ax=ax, pad=0.12, aspect=15, fraction=0.2, orientation='horizontal')
704
+ This function decides whether to reshape the 'time' dimension based on the reshape_time parameter:
705
+ - 'auto': Automatically reshapes if only 'time' dimension would remain for heatmap
706
+ - Tuple: Explicitly reshapes time with specified parameters
707
+ - None: No reshaping (returns data as-is)
624
708
 
625
- fig.tight_layout()
626
-
627
- return fig, ax
628
-
629
-
630
- def heat_map_plotly(
631
- data: pd.DataFrame,
632
- color_map: str = 'viridis',
633
- title: str = '',
634
- xlabel: str = 'Period',
635
- ylabel: str = 'Step',
636
- categorical_labels: bool = True,
637
- ) -> go.Figure:
638
- """
639
- Plots a DataFrame as a heatmap using Plotly. The columns of the DataFrame will be mapped to the x-axis,
640
- and the index will be displayed on the y-axis. The values in the DataFrame will represent the 'heat' in the plot.
709
+ All non-time dimensions are preserved during reshaping.
641
710
 
642
711
  Args:
643
- data: A DataFrame with the data to be visualized. The index will be used for the y-axis, and columns will be used for the x-axis.
644
- The values in the DataFrame will be represented as colors in the heatmap.
645
- color_map: The color scale to use for the heatmap. Default is 'viridis'. Plotly supports various color scales like 'Cividis', 'Inferno', etc.
646
- title: The title of the heatmap. Default is an empty string.
647
- xlabel: The label for the x-axis. Default is 'Period'.
648
- ylabel: The label for the y-axis. Default is 'Step'.
649
- categorical_labels: If True, the x and y axes are treated as categorical data (i.e., the index and columns will not be interpreted as continuous data).
650
- Default is True. If False, the axes are treated as continuous, which may be useful for time series or numeric data.
712
+ data: DataArray to reshape for heatmap visualization.
713
+ reshape_time: Reshaping configuration:
714
+ - 'auto' (default): Auto-reshape if needed based on facet_by/animate_by
715
+ - Tuple (timeframes, timesteps_per_frame): Explicit time reshaping
716
+ - None: No reshaping
717
+ facet_by: Dimension(s) used for faceting (used in 'auto' decision).
718
+ animate_by: Dimension used for animation (used in 'auto' decision).
719
+ fill: Method to fill missing values: 'ffill' or 'bfill'. Default is 'ffill'.
651
720
 
652
721
  Returns:
653
- A Plotly figure object containing the heatmap. This can be further customized and saved
654
- or displayed using `fig.show()`.
655
-
656
- Notes:
657
- The color bar is automatically scaled to the minimum and maximum values in the data.
658
- The y-axis is reversed to display the first row at the top.
659
- """
722
+ Reshaped DataArray. If time reshaping is applied, 'time' dimension is replaced
723
+ by 'timestep' and 'timeframe'. All other dimensions are preserved.
660
724
 
661
- color_bar_min, color_bar_max = data.min().min(), data.max().max() # Min and max values for color scaling
662
- # Define the figure
663
- fig = go.Figure(
664
- data=go.Heatmap(
665
- z=data.values,
666
- x=data.columns,
667
- y=data.index,
668
- colorscale=color_map,
669
- zmin=color_bar_min,
670
- zmax=color_bar_max,
671
- colorbar=dict(
672
- title=dict(text='Color Bar Label', side='right'),
673
- orientation='h',
674
- xref='container',
675
- yref='container',
676
- len=0.8, # Color bar length relative to plot
677
- x=0.5,
678
- y=0.1,
679
- ),
680
- )
681
- )
682
-
683
- # Set axis labels and style
684
- fig.update_layout(
685
- title=title,
686
- xaxis=dict(title=xlabel, side='top', type='category' if categorical_labels else None),
687
- yaxis=dict(title=ylabel, autorange='reversed', type='category' if categorical_labels else None),
688
- )
689
-
690
- return fig
691
-
692
-
693
- def reshape_to_2d(data_1d: np.ndarray, nr_of_steps_per_column: int) -> np.ndarray:
694
- """
695
- Reshapes a 1D numpy array into a 2D array suitable for plotting as a colormap.
696
-
697
- The reshaped array will have the number of rows corresponding to the steps per column
698
- (e.g., 24 hours per day) and columns representing time periods (e.g., days or months).
699
-
700
- Args:
701
- data_1d: A 1D numpy array with the data to reshape.
702
- nr_of_steps_per_column: The number of steps (rows) per column in the resulting 2D array. For example,
703
- this could be 24 (for hours) or 31 (for days in a month).
704
-
705
- Returns:
706
- The reshaped 2D array. Each internal array corresponds to one column, with the specified number of steps.
707
- Each column might represents a time period (e.g., day, month, etc.).
708
- """
709
-
710
- # Step 1: Ensure the input is a 1D array.
711
- if data_1d.ndim != 1:
712
- raise ValueError('Input must be a 1D array')
713
-
714
- # Step 2: Convert data to float type to allow NaN padding
715
- if data_1d.dtype != np.float64:
716
- data_1d = data_1d.astype(np.float64)
717
-
718
- # Step 3: Calculate the number of columns required
719
- total_steps = len(data_1d)
720
- cols = len(data_1d) // nr_of_steps_per_column # Base number of columns
721
-
722
- # If there's a remainder, add an extra column to hold the remaining values
723
- if total_steps % nr_of_steps_per_column != 0:
724
- cols += 1
725
+ Examples:
726
+ Auto-reshaping:
725
727
 
726
- # Step 4: Pad the 1D data to match the required number of rows and columns
727
- padded_data = np.pad(
728
- data_1d, (0, cols * nr_of_steps_per_column - total_steps), mode='constant', constant_values=np.nan
729
- )
728
+ ```python
729
+ # Will auto-reshape because only 'time' remains after faceting/animation
730
+ data = reshape_data_for_heatmap(data, reshape_time='auto', facet_by='scenario', animate_by='period')
731
+ ```
730
732
 
731
- # Step 5: Reshape the padded data into a 2D array
732
- data_2d = padded_data.reshape(cols, nr_of_steps_per_column)
733
+ Explicit reshaping:
733
734
 
734
- return data_2d.T
735
+ ```python
736
+ # Explicitly reshape to daily pattern
737
+ data = reshape_data_for_heatmap(data, reshape_time=('D', 'h'))
738
+ ```
735
739
 
740
+ No reshaping:
736
741
 
737
- def heat_map_data_from_df(
738
- df: pd.DataFrame,
739
- periods: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'],
740
- steps_per_period: Literal['W', 'D', 'h', '15min', 'min'],
741
- fill: Literal['ffill', 'bfill'] | None = None,
742
- ) -> pd.DataFrame:
742
+ ```python
743
+ # Keep data as-is
744
+ data = reshape_data_for_heatmap(data, reshape_time=None)
745
+ ```
743
746
  """
744
- Reshapes a DataFrame with a DateTime index into a 2D array for heatmap plotting,
745
- based on a specified sample rate.
746
- Only specific combinations of `periods` and `steps_per_period` are supported; invalid combinations raise an assertion.
747
+ # If no time dimension, return data as-is
748
+ if 'time' not in data.dims:
749
+ return data
750
+
751
+ # Handle None (disabled) - return data as-is
752
+ if reshape_time is None:
753
+ return data
754
+
755
+ # Determine timeframes and timesteps_per_frame based on reshape_time parameter
756
+ if reshape_time == 'auto':
757
+ # Check if we need automatic time reshaping
758
+ facet_dims_used = []
759
+ if facet_by:
760
+ facet_dims_used = [facet_by] if isinstance(facet_by, str) else list(facet_by)
761
+ if animate_by:
762
+ facet_dims_used.append(animate_by)
763
+
764
+ # Get dimensions that would remain for heatmap
765
+ potential_heatmap_dims = [dim for dim in data.dims if dim not in facet_dims_used]
766
+
767
+ # Auto-reshape if only 'time' dimension remains
768
+ if len(potential_heatmap_dims) == 1 and potential_heatmap_dims[0] == 'time':
769
+ logger.debug(
770
+ "Auto-applying time reshaping: Only 'time' dimension remains after faceting/animation. "
771
+ "Using default timeframes='D' and timesteps_per_frame='h'. "
772
+ "To customize, use reshape_time=('D', 'h') or disable with reshape_time=None."
773
+ )
774
+ timeframes, timesteps_per_frame = 'D', 'h'
775
+ else:
776
+ # No reshaping needed
777
+ return data
778
+ elif isinstance(reshape_time, tuple):
779
+ # Explicit reshaping
780
+ timeframes, timesteps_per_frame = reshape_time
781
+ else:
782
+ raise ValueError(f"reshape_time must be 'auto', a tuple like ('D', 'h'), or None. Got: {reshape_time}")
747
783
 
748
- Args:
749
- df: A DataFrame with a DateTime index containing the data to reshape.
750
- periods: The time interval of each period (columns of the heatmap),
751
- such as 'YS' (year start), 'W' (weekly), 'D' (daily), 'h' (hourly) etc.
752
- steps_per_period: The time interval within each period (rows in the heatmap),
753
- such as 'YS' (year start), 'W' (weekly), 'D' (daily), 'h' (hourly) etc.
754
- fill: Method to fill missing values: 'ffill' for forward fill or 'bfill' for backward fill.
784
+ # Validate that time is datetime
785
+ if not np.issubdtype(data.coords['time'].dtype, np.datetime64):
786
+ raise ValueError(f'Time dimension must be datetime-based, got {data.coords["time"].dtype}')
755
787
 
756
- Returns:
757
- A DataFrame suitable for heatmap plotting, with rows representing steps within each period
758
- and columns representing each period.
759
- """
760
- assert pd.api.types.is_datetime64_any_dtype(df.index), (
761
- 'The index of the DataFrame must be datetime to transform it properly for a heatmap plot'
762
- )
763
-
764
- # Define formats for different combinations of `periods` and `steps_per_period`
788
+ # Define formats for different combinations
765
789
  formats = {
766
790
  ('YS', 'W'): ('%Y', '%W'),
767
791
  ('YS', 'D'): ('%Y', '%j'), # day of year
768
792
  ('YS', 'h'): ('%Y', '%j %H:00'),
769
793
  ('MS', 'D'): ('%Y-%m', '%d'), # day of month
770
794
  ('MS', 'h'): ('%Y-%m', '%d %H:00'),
771
- ('W', 'D'): ('%Y-w%W', '%w_%A'), # week and day of week (with prefix for proper sorting)
795
+ ('W', 'D'): ('%Y-w%W', '%w_%A'), # week and day of week
772
796
  ('W', 'h'): ('%Y-w%W', '%w_%A %H:00'),
773
797
  ('D', 'h'): ('%Y-%m-%d', '%H:00'), # Day and hour
774
798
  ('D', '15min'): ('%Y-%m-%d', '%H:%M'), # Day and minute
@@ -776,43 +800,64 @@ def heat_map_data_from_df(
776
800
  ('h', 'min'): ('%Y-%m-%d %H:00', '%M'), # minute of hour
777
801
  }
778
802
 
779
- if df.empty:
780
- raise ValueError('DataFrame is empty.')
781
- diffs = df.index.to_series().diff().dropna()
782
- minimum_time_diff_in_min = diffs.min().total_seconds() / 60
783
- time_intervals = {'min': 1, '15min': 15, 'h': 60, 'D': 24 * 60, 'W': 7 * 24 * 60}
784
- if time_intervals[steps_per_period] > minimum_time_diff_in_min:
785
- logger.error(
786
- f'To compute the heatmap, the data was aggregated from {minimum_time_diff_in_min:.2f} min to '
787
- f'{time_intervals[steps_per_period]:.2f} min. Mean values are displayed.'
788
- )
789
-
790
- # Select the format based on the `periods` and `steps_per_period` combination
791
- format_pair = (periods, steps_per_period)
803
+ format_pair = (timeframes, timesteps_per_frame)
792
804
  if format_pair not in formats:
793
805
  raise ValueError(f'{format_pair} is not a valid format. Choose from {list(formats.keys())}')
794
806
  period_format, step_format = formats[format_pair]
795
807
 
796
- df = df.sort_index() # Ensure DataFrame is sorted by time index
808
+ # Check if resampling is needed
809
+ if data.sizes['time'] > 1:
810
+ # Use NumPy for more efficient timedelta computation
811
+ time_values = data.coords['time'].values # Already numpy datetime64[ns]
812
+ # Calculate differences and convert to minutes
813
+ time_diffs = np.diff(time_values).astype('timedelta64[s]').astype(float) / 60.0
814
+ if time_diffs.size > 0:
815
+ min_time_diff_min = np.nanmin(time_diffs)
816
+ time_intervals = {'min': 1, '15min': 15, 'h': 60, 'D': 24 * 60, 'W': 7 * 24 * 60}
817
+ if time_intervals[timesteps_per_frame] > min_time_diff_min:
818
+ logger.warning(
819
+ f'Resampling data from {min_time_diff_min:.2f} min to '
820
+ f'{time_intervals[timesteps_per_frame]:.2f} min. Mean values are displayed.'
821
+ )
797
822
 
798
- resampled_data = df.resample(steps_per_period).mean() # Resample and fill any gaps with NaN
823
+ # Resample along time dimension
824
+ resampled = data.resample(time=timesteps_per_frame).mean()
799
825
 
800
- if fill == 'ffill': # Apply fill method if specified
801
- resampled_data = resampled_data.ffill()
826
+ # Apply fill if specified
827
+ if fill == 'ffill':
828
+ resampled = resampled.ffill(dim='time')
802
829
  elif fill == 'bfill':
803
- resampled_data = resampled_data.bfill()
830
+ resampled = resampled.bfill(dim='time')
831
+
832
+ # Create period and step labels
833
+ time_values = pd.to_datetime(resampled.coords['time'].values)
834
+ period_labels = time_values.strftime(period_format)
835
+ step_labels = time_values.strftime(step_format)
836
+
837
+ # Handle special case for weekly day format
838
+ if '%w_%A' in step_format:
839
+ step_labels = pd.Series(step_labels).replace('0_Sunday', '7_Sunday').values
840
+
841
+ # Add period and step as coordinates
842
+ resampled = resampled.assign_coords(
843
+ {
844
+ 'timeframe': ('time', period_labels),
845
+ 'timestep': ('time', step_labels),
846
+ }
847
+ )
804
848
 
805
- resampled_data['period'] = resampled_data.index.strftime(period_format)
806
- resampled_data['step'] = resampled_data.index.strftime(step_format)
807
- if '%w_%A' in step_format: # Shift index of strings to ensure proper sorting
808
- resampled_data['step'] = resampled_data['step'].apply(
809
- lambda x: x.replace('0_Sunday', '7_Sunday') if '0_Sunday' in x else x
810
- )
849
+ # Convert to multi-index and unstack
850
+ resampled = resampled.set_index(time=['timeframe', 'timestep'])
851
+ result = resampled.unstack('time')
811
852
 
812
- # Pivot the table so periods are columns and steps are indices
813
- df_pivoted = resampled_data.pivot(columns='period', index='step', values=df.columns[0])
853
+ # Ensure timestep and timeframe come first in dimension order
854
+ # Get other dimensions
855
+ other_dims = [d for d in result.dims if d not in ['timestep', 'timeframe']]
814
856
 
815
- return df_pivoted
857
+ # Reorder: timestep, timeframe, then other dimensions
858
+ result = result.transpose('timestep', 'timeframe', *other_dims)
859
+
860
+ return result
816
861
 
817
862
 
818
863
  def plot_network(
@@ -1413,6 +1458,311 @@ def dual_pie_with_matplotlib(
1413
1458
  return fig, axes
1414
1459
 
1415
1460
 
1461
def heatmap_with_plotly(
    data: xr.DataArray,
    colors: ColorType = 'viridis',
    title: str = '',
    facet_by: str | list[str] | None = None,
    animate_by: str | None = None,
    facet_cols: int = 3,
    reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
    | Literal['auto']
    | None = 'auto',
    fill: Literal['ffill', 'bfill'] | None = 'ffill',
) -> go.Figure:
    """
    Plot a heatmap visualization using Plotly's imshow with faceting and animation support.

    This function creates heatmap visualizations from xarray DataArrays, supporting
    multi-dimensional data through faceting (subplots) and animation. The data is first
    passed through ``reshape_data_for_heatmap`` so that a plain time series can be
    displayed as a 2D pattern.

    Automatic Time Reshaping:
        If only the 'time' dimension remains after faceting/animation (making the data 1D),
        the function automatically reshapes time into a 2D format using default values
        (timeframes='D', timesteps_per_frame='h'). This creates a daily pattern heatmap
        showing hours vs days.

    Args:
        data: An xarray DataArray containing the data to visualize. Should have at least
            2 dimensions, or a 'time' dimension that can be reshaped into 2D.
        colors: Color specification. Only a colormap name (string) is used, e.g.
            'viridis', 'plasma', 'RdBu', 'portland'; any other type falls back to 'viridis'.
        title: The main title of the heatmap.
        facet_by: Dimension to create facets for. Creates a subplot grid.
            Can be a single dimension name or a list. px.imshow only supports
            single-dimension faceting, so only the first existing dimension is used for
            the facets; for every further dimension only its first entry is plotted.
            Dimensions that don't exist in the data are ignored (logged at debug level).
        animate_by: Dimension to animate over. Creates animation frames.
            If the dimension doesn't exist in the data, it is ignored (logged at debug level).
        facet_cols: Number of columns in the facet grid (used with facet_by).
        reshape_time: Time reshaping configuration:
            - 'auto' (default): Automatically applies ('D', 'h') if only 'time' dimension remains
            - Tuple like ('D', 'h'): Explicit time reshaping (days vs hours)
            - None: Disable time reshaping
        fill: Method to fill missing values when reshaping time: 'ffill' or 'bfill'. Default is 'ffill'.

    Returns:
        A Plotly figure object containing the heatmap visualization. An empty figure is
        returned if the data is empty or if fewer than 2 dimensions remain for the
        heatmap itself (the latter is logged as an error).

    Examples:
        Simple heatmap:

        ```python
        fig = heatmap_with_plotly(data_array, colors='RdBu', title='Temperature Map')
        ```

        Facet by scenario:

        ```python
        fig = heatmap_with_plotly(data_array, facet_by='scenario', facet_cols=2)
        ```

        Animate by period:

        ```python
        fig = heatmap_with_plotly(data_array, animate_by='period')
        ```

        Automatic time reshaping (when only time dimension remains):

        ```python
        # Data with dims ['time', 'scenario', 'period']
        # After faceting and animation, only 'time' remains -> auto-reshapes to (timestep, timeframe)
        fig = heatmap_with_plotly(data_array, facet_by='scenario', animate_by='period')
        ```

        Explicit time reshaping:

        ```python
        fig = heatmap_with_plotly(data_array, facet_by='scenario', animate_by='period', reshape_time=('W', 'D'))
        ```
    """
    # Handle empty data
    if data.size == 0:
        return go.Figure()

    # Apply time reshaping using the unified reshaping function
    data = reshape_data_for_heatmap(
        data, reshape_time=reshape_time, facet_by=facet_by, animate_by=animate_by, fill=fill
    )

    available_dims = list(data.dims)

    # Validate facet_by and normalize it into a *fresh* list, so that nothing below
    # can mutate the caller's argument or alias it with another working list.
    facet_dims: list[str] = []
    if isinstance(facet_by, str):
        if facet_by in available_dims:
            facet_dims = [facet_by]
        else:
            logger.debug(
                f"Dimension '{facet_by}' not found in data. Available dimensions: {available_dims}. "
                f'Ignoring facet_by parameter.'
            )
    elif facet_by is not None:
        missing_dims = [dim for dim in facet_by if dim not in available_dims]
        facet_dims = [dim for dim in facet_by if dim in available_dims]
        if missing_dims:
            logger.debug(
                f'Dimensions {missing_dims} not found in data. Available dimensions: {available_dims}. '
                f'Using only existing dimensions: {facet_dims if facet_dims else "none"}.'
            )

    # Validate animate_by dimension
    if animate_by is not None and animate_by not in available_dims:
        logger.debug(
            f"Dimension '{animate_by}' not found in data. Available dimensions: {available_dims}. "
            f'Ignoring animate_by parameter.'
        )
        animate_by = None

    # Dimensions consumed by faceting/animation. Built as a new list: appending
    # animate_by to facet_dims itself would make a single facet look like two.
    consumed_dims = [*facet_dims, animate_by] if animate_by else list(facet_dims)

    # Remaining dimensions form the rows/columns of the heatmap itself
    heatmap_dims = [dim for dim in available_dims if dim not in consumed_dims]

    if len(heatmap_dims) < 2:
        # Need at least 2 dimensions for a heatmap
        logger.error(
            f'Heatmap requires at least 2 dimensions for rows and columns. '
            f'After faceting/animation, only {len(heatmap_dims)} dimension(s) remain: {heatmap_dims}'
        )
        return go.Figure()

    # px.imshow only supports facet_col (no facet_row), i.e. a single facet dimension.
    facet_col_param = facet_dims[0] if facet_dims else None
    ignored_facet_dims = facet_dims[1:]
    if ignored_facet_dims:
        logger.warning(
            f'px.imshow only supports faceting by a single dimension. '
            f'Using {facet_col_param} for faceting. Dimensions {ignored_facet_dims} will be ignored '
            f'(only their first entry is plotted). '
            f'Consider using animate_by for additional dimensions.'
        )
        # The surplus dimensions must actually be removed from the array, otherwise
        # px.imshow receives more dimensions than it can lay out.
        data = data.isel({dim: 0 for dim in ignored_facet_dims})

    color_scale = colors if isinstance(colors, str) else 'viridis'

    # px.imshow can work directly with xarray DataArrays
    imshow_args = {
        'img': data,
        'color_continuous_scale': color_scale,
        'title': title,
    }

    # Add faceting if specified
    if facet_col_param:
        imshow_args['facet_col'] = facet_col_param
        if facet_cols:
            imshow_args['facet_col_wrap'] = facet_cols

    # Add animation if specified
    if animate_by:
        imshow_args['animation_frame'] = animate_by

    try:
        fig = px.imshow(**imshow_args)
    except Exception as e:  # deliberate best-effort: any plotting failure triggers the fallback
        logger.error(f'Error creating imshow plot: {e}. Falling back to basic heatmap.')
        # Fallback: a plain heatmap without faceting/animation. Reduce the facet and
        # animation dimensions to their first entry so the raw array is not handed
        # over with those extra axes still attached.
        fallback = data.isel({dim: 0 for dim in data.dims if dim not in heatmap_dims})
        fig = px.imshow(
            fallback.values,
            color_continuous_scale=color_scale,
            title=title,
        )

    # Update layout with basic styling (transparent backgrounds)
    fig.update_layout(
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)',
        font=dict(size=12),
    )

    return fig
1653
+
1654
+
1655
def heatmap_with_matplotlib(
    data: xr.DataArray,
    colors: ColorType = 'viridis',
    title: str = '',
    figsize: tuple[float, float] = (12, 6),
    fig: plt.Figure | None = None,
    ax: plt.Axes | None = None,
    reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
    | Literal['auto']
    | None = 'auto',
    fill: Literal['ffill', 'bfill'] | None = 'ffill',
) -> tuple[plt.Figure, plt.Axes]:
    """
    Plot a heatmap visualization using Matplotlib's imshow.

    This function creates a basic 2D heatmap from an xarray DataArray using matplotlib's
    imshow function. For multi-dimensional data, only the first two dimensions are used.

    Args:
        data: An xarray DataArray containing the data to visualize. Should have at least
            2 dimensions. If more than 2 dimensions exist, additional dimensions will
            be reduced by taking the first slice.
        colors: Color specification. Should be a colormap name (e.g., 'viridis', 'RdBu').
            Any non-string value falls back to 'viridis'.
        title: The title of the heatmap.
        figsize: The size of the figure (width, height) in inches. Only used when a new
            figure has to be created.
        fig: A Matplotlib figure object to plot on. If neither ``fig`` nor ``ax`` is given,
            a new figure is created. If only ``fig`` is given, its current axes are used.
        ax: A Matplotlib axes object to plot on. If only ``ax`` is given, the figure it
            belongs to is used.
        reshape_time: Time reshaping configuration:
            - 'auto' (default): Automatically applies ('D', 'h') if only 'time' dimension
            - Tuple like ('D', 'h'): Explicit time reshaping (days vs hours)
            - None: Disable time reshaping
        fill: Method to fill missing values when reshaping time: 'ffill' or 'bfill'. Default is 'ffill'.

    Returns:
        A tuple containing the Matplotlib figure and axes objects used for the plot.

    Notes:
        - Matplotlib backend doesn't support faceting or animation. Use plotly engine for those features.
        - The image is drawn with ``origin='upper'``, i.e. the first row is at the top.
        - A horizontal colorbar is added to show the value scale.

    Examples:
        ```python
        fig, ax = heatmap_with_matplotlib(data_array, colors='RdBu', title='Temperature')
        plt.savefig('heatmap.png')
        ```

        Time reshaping:

        ```python
        fig, ax = heatmap_with_matplotlib(data_array, reshape_time=('D', 'h'))
        ```
    """
    # Handle empty data: hand back an (empty) figure/axes pair without plotting
    if data.size == 0:
        return _resolve_heatmap_fig_ax(fig, ax, figsize)

    # Apply time reshaping using the unified reshaping function.
    # Matplotlib doesn't support faceting/animation, so we pass None for those.
    data = reshape_data_for_heatmap(data, reshape_time=reshape_time, facet_by=None, animate_by=None, fill=fill)

    # Resolve the figure/axes only after reshaping succeeded, so a reshaping error
    # does not leave a freshly created, unused figure behind.
    fig, ax = _resolve_heatmap_fig_ax(fig, ax, figsize)

    # Extract data values; if data has more than 2 dimensions, reduce it
    if isinstance(data, xr.DataArray):
        dims = list(data.dims)
        if len(dims) > 2:
            logger.warning(
                f'Data has {len(dims)} dimensions: {dims}. '
                f'Only the first 2 will be used for the heatmap. '
                f'Use the plotly engine for faceting/animation support.'
            )
            # Keep only the first 2 dimensions by taking the first slice of all others
            data = data.isel({dim: 0 for dim in dims[2:]})

        values = data.values
        x_labels = data.dims[1] if len(data.dims) > 1 else 'x'
        y_labels = data.dims[0] if len(data.dims) > 0 else 'y'
    else:
        values = data
        x_labels = 'x'
        y_labels = 'y'

    # Process colormap
    cmap = colors if isinstance(colors, str) else 'viridis'

    # Create the heatmap using imshow
    im = ax.imshow(values, cmap=cmap, aspect='auto', origin='upper')

    # Attach the colorbar through the owning figure instead of pyplot's global
    # "current figure" state, which may point at a different figure.
    cbar = fig.colorbar(im, ax=ax, orientation='horizontal', pad=0.1, aspect=15, fraction=0.05)
    cbar.set_label('Value')

    # Set labels and title
    ax.set_xlabel(str(x_labels).capitalize())
    ax.set_ylabel(str(y_labels).capitalize())
    ax.set_title(title)

    # Apply tight layout
    fig.tight_layout()

    return fig, ax


def _resolve_heatmap_fig_ax(
    fig: plt.Figure | None,
    ax: plt.Axes | None,
    figsize: tuple[float, float],
) -> tuple[plt.Figure, plt.Axes]:
    """Return a usable (figure, axes) pair, honoring whichever of the two the caller supplied."""
    if ax is None:
        if fig is None:
            return plt.subplots(figsize=figsize)
        # Only a figure was given: draw on its current axes (created if it has none)
        return fig, fig.gca()
    if fig is None:
        # Only an axes was given: use the figure that owns it
        fig = ax.get_figure()
    return fig, ax
1764
+
1765
+
1416
1766
  def export_figure(
1417
1767
  figure_like: go.Figure | tuple[plt.Figure, plt.Axes],
1418
1768
  default_path: pathlib.Path,