flixopt 3.0.3__py3-none-any.whl → 3.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flixopt might be problematic. Click here for more details.

flixopt/results.py CHANGED
@@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Literal
10
10
  import linopy
11
11
  import numpy as np
12
12
  import pandas as pd
13
- import plotly
14
13
  import xarray as xr
15
14
  import yaml
16
15
 
@@ -20,6 +19,7 @@ from .flow_system import FlowSystem
20
19
 
21
20
  if TYPE_CHECKING:
22
21
  import matplotlib.pyplot as plt
22
+ import plotly
23
23
  import pyvis
24
24
 
25
25
  from .calculation import Calculation, SegmentedCalculation
@@ -195,8 +195,8 @@ class CalculationResults:
195
195
  if 'flow_system' in kwargs and flow_system_data is None:
196
196
  flow_system_data = kwargs.pop('flow_system')
197
197
  warnings.warn(
198
- "The 'flow_system' parameter is deprecated. Use 'flow_system_data' instead."
199
- "Acess is now by '.flow_system_data', while '.flow_system' returns the restored FlowSystem.",
198
+ "The 'flow_system' parameter is deprecated. Use 'flow_system_data' instead. "
199
+ "Access is now via '.flow_system_data', while '.flow_system' returns the restored FlowSystem.",
200
200
  DeprecationWarning,
201
201
  stacklevel=2,
202
202
  )
@@ -230,6 +230,7 @@ class CalculationResults:
230
230
  self.timesteps_extra = self.solution.indexes['time']
231
231
  self.hours_per_timestep = FlowSystem.calculate_hours_per_timestep(self.timesteps_extra)
232
232
  self.scenarios = self.solution.indexes['scenario'] if 'scenario' in self.solution.indexes else None
233
+ self.periods = self.solution.indexes['period'] if 'period' in self.solution.indexes else None
233
234
 
234
235
  self._effect_share_factors = None
235
236
  self._flow_system = None
@@ -619,6 +620,30 @@ class CalculationResults:
619
620
  total = xr.DataArray(np.nan)
620
621
  return total.rename(f'{element}->{effect}({mode})')
621
622
 
623
+ def _create_template_for_mode(self, mode: Literal['temporal', 'periodic', 'total']) -> xr.DataArray:
624
+ """Create a template DataArray with the correct dimensions for a given mode.
625
+
626
+ Args:
627
+ mode: The calculation mode ('temporal', 'periodic', or 'total').
628
+
629
+ Returns:
630
+ A DataArray filled with NaN, with dimensions appropriate for the mode.
631
+ """
632
+ coords = {}
633
+ if mode == 'temporal':
634
+ coords['time'] = self.timesteps_extra
635
+ if self.periods is not None:
636
+ coords['period'] = self.periods
637
+ if self.scenarios is not None:
638
+ coords['scenario'] = self.scenarios
639
+
640
+ # Create template with appropriate shape
641
+ if coords:
642
+ shape = tuple(len(coords[dim]) for dim in coords)
643
+ return xr.DataArray(np.full(shape, np.nan, dtype=float), coords=coords, dims=list(coords.keys()))
644
+ else:
645
+ return xr.DataArray(np.nan)
646
+
622
647
  def _create_effects_dataset(self, mode: Literal['temporal', 'periodic', 'total']) -> xr.Dataset:
623
648
  """Creates a dataset containing effect totals for all components (including their flows).
624
649
  The dataset does contain the direct as well as the indirect effects of each component.
@@ -629,32 +654,23 @@ class CalculationResults:
629
654
  Returns:
630
655
  An xarray Dataset with components as dimension and effects as variables.
631
656
  """
657
+ # Create template with correct dimensions for this mode
658
+ template = self._create_template_for_mode(mode)
659
+
632
660
  ds = xr.Dataset()
633
661
  all_arrays = {}
634
- template = None # Template is needed to determine the dimensions of the arrays. This handles the case of no shares for an effect
635
-
636
662
  components_list = list(self.components)
637
663
 
638
- # First pass: collect arrays and find template
664
+ # Collect arrays for all effects and components
639
665
  for effect in self.effects:
640
666
  effect_arrays = []
641
667
  for component in components_list:
642
668
  da = self._compute_effect_total(element=component, effect=effect, mode=mode, include_flows=True)
643
669
  effect_arrays.append(da)
644
670
 
645
- if template is None and (da.dims or not da.isnull().all()):
646
- template = da
647
-
648
671
  all_arrays[effect] = effect_arrays
649
672
 
650
- # Ensure we have a template
651
- if template is None:
652
- raise ValueError(
653
- f"No template with proper dimensions found for mode '{mode}'. "
654
- f'All computed arrays are scalars, which indicates a data issue.'
655
- )
656
-
657
- # Second pass: process all effects (guaranteed to include all)
673
+ # Process all effects: expand scalar NaN arrays to match template dimensions
658
674
  for effect in self.effects:
659
675
  dataarrays = all_arrays[effect]
660
676
  component_arrays = []
@@ -687,68 +703,117 @@ class CalculationResults:
687
703
 
688
704
  def plot_heatmap(
689
705
  self,
690
- variable_name: str,
691
- heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] = 'D',
692
- heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] = 'h',
693
- color_map: str = 'portland',
706
+ variable_name: str | list[str],
694
707
  save: bool | pathlib.Path = False,
695
708
  show: bool = True,
709
+ colors: plotting.ColorType = 'viridis',
696
710
  engine: plotting.PlottingEngine = 'plotly',
711
+ select: dict[FlowSystemDimensions, Any] | None = None,
712
+ facet_by: str | list[str] | None = 'scenario',
713
+ animate_by: str | None = 'period',
714
+ facet_cols: int = 3,
715
+ reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
716
+ | Literal['auto']
717
+ | None = 'auto',
718
+ fill: Literal['ffill', 'bfill'] | None = 'ffill',
719
+ # Deprecated parameters (kept for backwards compatibility)
697
720
  indexer: dict[FlowSystemDimensions, Any] | None = None,
721
+ heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] | None = None,
722
+ heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] | None = None,
723
+ color_map: str | None = None,
698
724
  ) -> plotly.graph_objs.Figure | tuple[plt.Figure, plt.Axes]:
699
725
  """
700
- Plots a heatmap of the solution of a variable.
726
+ Plots a heatmap visualization of a variable using imshow or time-based reshaping.
727
+
728
+ Supports multiple visualization features that can be combined:
729
+ - **Multi-variable**: Plot multiple variables on a single heatmap (creates 'variable' dimension)
730
+ - **Time reshaping**: Converts 'time' dimension into 2D (e.g., hours vs days)
731
+ - **Faceting**: Creates subplots for different dimension values
732
+ - **Animation**: Animates through dimension values (Plotly only)
701
733
 
702
734
  Args:
703
- variable_name: The name of the variable to plot.
704
- heatmap_timeframes: The timeframes to use for the heatmap.
705
- heatmap_timesteps_per_frame: The timesteps per frame to use for the heatmap.
706
- color_map: The color map to use for the heatmap.
735
+ variable_name: The name of the variable to plot, or a list of variable names.
736
+ When a list is provided, variables are combined into a single DataArray
737
+ with a new 'variable' dimension.
707
738
  save: Whether to save the plot or not. If a path is provided, the plot will be saved at that location.
708
739
  show: Whether to show the plot or not.
740
+ colors: Color scheme for the heatmap. See `flixopt.plotting.ColorType` for options.
709
741
  engine: The engine to use for plotting. Can be either 'plotly' or 'matplotlib'.
710
- indexer: Optional selection dict, e.g., {'scenario': 'base', 'period': 2024}.
711
- If None, uses first value for each dimension.
712
- If empty dict {}, uses all values.
742
+ select: Optional data selection dict. Supports single values, lists, slices, and index arrays.
743
+ Applied BEFORE faceting/animation/reshaping.
744
+ facet_by: Dimension(s) to create facets (subplots) for. Can be a single dimension name (str)
745
+ or list of dimensions. Each unique value combination creates a subplot. Ignored if not found.
746
+ animate_by: Dimension to animate over (Plotly only). Creates animation frames that cycle through
747
+ dimension values. Only one dimension can be animated. Ignored if not found.
748
+ facet_cols: Number of columns in the facet grid layout (default: 3).
749
+ reshape_time: Time reshaping configuration (default: 'auto'):
750
+ - 'auto': Automatically applies ('D', 'h') when only 'time' dimension remains
751
+ - Tuple: Explicit reshaping, e.g. ('D', 'h') for days vs hours,
752
+ ('MS', 'D') for months vs days, ('W', 'h') for weeks vs hours
753
+ - None: Disable auto-reshaping (will error if only 1D time data)
754
+ Supported timeframes: 'YS', 'MS', 'W', 'D', 'h', '15min', 'min'
755
+ fill: Method to fill missing values after reshape: 'ffill' (forward fill) or 'bfill' (backward fill).
756
+ Default is 'ffill'.
713
757
 
714
758
  Examples:
715
- Basic usage (uses first scenario, first period, all time):
759
+ Direct imshow mode (default):
760
+
761
+ >>> results.plot_heatmap('Battery|charge_state', select={'scenario': 'base'})
762
+
763
+ Facet by scenario:
716
764
 
717
- >>> results.plot_heatmap('Battery|charge_state')
765
+ >>> results.plot_heatmap('Boiler(Qth)|flow_rate', facet_by='scenario', facet_cols=2)
718
766
 
719
- Select specific scenario and period:
767
+ Animate by period:
720
768
 
721
- >>> results.plot_heatmap('Boiler(Qth)|flow_rate', indexer={'scenario': 'base', 'period': 2024})
769
+ >>> results.plot_heatmap('Boiler(Qth)|flow_rate', select={'scenario': 'base'}, animate_by='period')
722
770
 
723
- Time filtering (summer months only):
771
+ Time reshape mode - daily patterns:
772
+
773
+ >>> results.plot_heatmap('Boiler(Qth)|flow_rate', select={'scenario': 'base'}, reshape_time=('D', 'h'))
774
+
775
+ Combined: time reshaping with faceting and animation:
724
776
 
725
777
  >>> results.plot_heatmap(
726
- ... 'Boiler(Qth)|flow_rate',
727
- ... indexer={
728
- ... 'scenario': 'base',
729
- ... 'time': results.solution.time[results.solution.time.dt.month.isin([6, 7, 8])],
730
- ... },
778
+ ... 'Boiler(Qth)|flow_rate', facet_by='scenario', animate_by='period', reshape_time=('D', 'h')
731
779
  ... )
732
780
 
733
- Save to specific location:
781
+ Multi-variable heatmap (variables as one axis):
734
782
 
735
783
  >>> results.plot_heatmap(
736
- ... 'Boiler(Qth)|flow_rate', indexer={'scenario': 'base'}, save='path/to/my_heatmap.html'
784
+ ... ['Boiler(Q_th)|flow_rate', 'CHP(Q_th)|flow_rate', 'HeatStorage|charge_state'],
785
+ ... select={'scenario': 'base', 'period': 1},
786
+ ... reshape_time=None,
737
787
  ... )
738
- """
739
- dataarray = self.solution[variable_name]
740
788
 
789
+ Multi-variable with time reshaping:
790
+
791
+ >>> results.plot_heatmap(
792
+ ... ['Boiler(Q_th)|flow_rate', 'CHP(Q_th)|flow_rate'],
793
+ ... facet_by='scenario',
794
+ ... animate_by='period',
795
+ ... reshape_time=('D', 'h'),
796
+ ... )
797
+ """
798
+ # Delegate to module-level plot_heatmap function
741
799
  return plot_heatmap(
742
- dataarray=dataarray,
743
- name=variable_name,
800
+ data=self.solution[variable_name],
801
+ name=variable_name if isinstance(variable_name, str) else None,
744
802
  folder=self.folder,
745
- heatmap_timeframes=heatmap_timeframes,
746
- heatmap_timesteps_per_frame=heatmap_timesteps_per_frame,
747
- color_map=color_map,
803
+ colors=colors,
748
804
  save=save,
749
805
  show=show,
750
806
  engine=engine,
807
+ select=select,
808
+ facet_by=facet_by,
809
+ animate_by=animate_by,
810
+ facet_cols=facet_cols,
811
+ reshape_time=reshape_time,
812
+ fill=fill,
751
813
  indexer=indexer,
814
+ heatmap_timeframes=heatmap_timeframes,
815
+ heatmap_timesteps_per_frame=heatmap_timesteps_per_frame,
816
+ color_map=color_map,
752
817
  )
753
818
 
754
819
  def plot_network(
@@ -920,30 +985,107 @@ class _NodeResults(_ElementResults):
920
985
  show: bool = True,
921
986
  colors: plotting.ColorType = 'viridis',
922
987
  engine: plotting.PlottingEngine = 'plotly',
923
- indexer: dict[FlowSystemDimensions, Any] | None = None,
988
+ select: dict[FlowSystemDimensions, Any] | None = None,
924
989
  unit_type: Literal['flow_rate', 'flow_hours'] = 'flow_rate',
925
990
  mode: Literal['area', 'stacked_bar', 'line'] = 'stacked_bar',
926
991
  drop_suffix: bool = True,
992
+ facet_by: str | list[str] | None = 'scenario',
993
+ animate_by: str | None = 'period',
994
+ facet_cols: int = 3,
995
+ # Deprecated parameter (kept for backwards compatibility)
996
+ indexer: dict[FlowSystemDimensions, Any] | None = None,
927
997
  ) -> plotly.graph_objs.Figure | tuple[plt.Figure, plt.Axes]:
928
998
  """
929
- Plots the node balance of the Component or Bus.
999
+ Plots the node balance of the Component or Bus with optional faceting and animation.
1000
+
930
1001
  Args:
931
1002
  save: Whether to save the plot or not. If a path is provided, the plot will be saved at that location.
932
1003
  show: Whether to show the plot or not.
933
1004
  colors: The colors to use for the plot. See `flixopt.plotting.ColorType` for options.
934
1005
  engine: The engine to use for plotting. Can be either 'plotly' or 'matplotlib'.
935
- indexer: Optional selection dict, e.g., {'scenario': 'base', 'period': 2024}.
936
- If None, uses first value for each dimension (except time).
937
- If empty dict {}, uses all values.
1006
+ select: Optional data selection dict. Supports:
1007
+ - Single values: {'scenario': 'base', 'period': 2024}
1008
+ - Multiple values: {'scenario': ['base', 'high', 'renewable']}
1009
+ - Slices: {'time': slice('2024-01', '2024-06')}
1010
+ - Index arrays: {'time': time_array}
1011
+ Note: Applied BEFORE faceting/animation.
938
1012
  unit_type: The unit type to use for the dataset. Can be 'flow_rate' or 'flow_hours'.
939
1013
  - 'flow_rate': Returns the flow_rates of the Node.
940
1014
  - 'flow_hours': Returns the flow_hours of the Node. [flow_hours(t) = flow_rate(t) * dt(t)]. Renames suffixes to |flow_hours.
941
1015
  mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for stepped lines, or 'area' for stacked area charts.
942
1016
  drop_suffix: Whether to drop the suffix from the variable names.
1017
+ facet_by: Dimension(s) to create facets (subplots) for. Can be a single dimension name (str)
1018
+ or list of dimensions. Each unique value combination creates a subplot. Ignored if not found.
1019
+ Example: 'scenario' creates one subplot per scenario.
1020
+ Example: ['scenario', 'period'] creates a grid of subplots for each scenario-period combination.
1021
+ animate_by: Dimension to animate over (Plotly only). Creates animation frames that cycle through
1022
+ dimension values. Only one dimension can be animated. Ignored if not found.
1023
+ facet_cols: Number of columns in the facet grid layout (default: 3).
1024
+
1025
+ Examples:
1026
+ Basic plot (current behavior):
1027
+
1028
+ >>> results['Boiler'].plot_node_balance()
1029
+
1030
+ Facet by scenario:
1031
+
1032
+ >>> results['Boiler'].plot_node_balance(facet_by='scenario', facet_cols=2)
1033
+
1034
+ Animate by period:
1035
+
1036
+ >>> results['Boiler'].plot_node_balance(animate_by='period')
1037
+
1038
+ Facet by scenario AND animate by period:
1039
+
1040
+ >>> results['Boiler'].plot_node_balance(facet_by='scenario', animate_by='period')
1041
+
1042
+ Select single scenario, then facet by period:
1043
+
1044
+ >>> results['Boiler'].plot_node_balance(select={'scenario': 'base'}, facet_by='period')
1045
+
1046
+ Select multiple scenarios and facet by them:
1047
+
1048
+ >>> results['Boiler'].plot_node_balance(
1049
+ ... select={'scenario': ['base', 'high', 'renewable']}, facet_by='scenario'
1050
+ ... )
1051
+
1052
+ Time range selection (summer months only):
1053
+
1054
+ >>> results['Boiler'].plot_node_balance(select={'time': slice('2024-06', '2024-08')}, facet_by='scenario')
943
1055
  """
944
- ds = self.node_balance(with_last_timestep=True, unit_type=unit_type, drop_suffix=drop_suffix, indexer=indexer)
1056
+ # Handle deprecated indexer parameter
1057
+ if indexer is not None:
1058
+ # Check for conflict with new parameter
1059
+ if select is not None:
1060
+ raise ValueError(
1061
+ "Cannot use both deprecated parameter 'indexer' and new parameter 'select'. Use only 'select'."
1062
+ )
1063
+
1064
+ import warnings
1065
+
1066
+ warnings.warn(
1067
+ "The 'indexer' parameter is deprecated and will be removed in a future version. Use 'select' instead.",
1068
+ DeprecationWarning,
1069
+ stacklevel=2,
1070
+ )
1071
+ select = indexer
1072
+
1073
+ if engine not in {'plotly', 'matplotlib'}:
1074
+ raise ValueError(f'Engine "{engine}" not supported. Use one of ["plotly", "matplotlib"]')
1075
+
1076
+ # Don't pass select/indexer to node_balance - we'll apply it afterwards
1077
+ ds = self.node_balance(with_last_timestep=True, unit_type=unit_type, drop_suffix=drop_suffix)
945
1078
 
946
- ds, suffix_parts = _apply_indexer_to_data(ds, indexer, drop=True)
1079
+ ds, suffix_parts = _apply_selection_to_data(ds, select=select, drop=True)
1080
+
1081
+ # Matplotlib requires only 'time' dimension; check for extras after selection
1082
+ if engine == 'matplotlib':
1083
+ extra_dims = [d for d in ds.dims if d != 'time']
1084
+ if extra_dims:
1085
+ raise ValueError(
1086
+ f'Matplotlib engine only supports a single time axis, but found extra dimensions: {extra_dims}. '
1087
+ f'Please use select={{...}} to reduce dimensions or switch to engine="plotly" for faceting/animation.'
1088
+ )
947
1089
  suffix = '--' + '-'.join(suffix_parts) if suffix_parts else ''
948
1090
 
949
1091
  title = (
@@ -952,13 +1094,16 @@ class _NodeResults(_ElementResults):
952
1094
 
953
1095
  if engine == 'plotly':
954
1096
  figure_like = plotting.with_plotly(
955
- ds.to_dataframe(),
1097
+ ds,
1098
+ facet_by=facet_by,
1099
+ animate_by=animate_by,
956
1100
  colors=colors,
957
1101
  mode=mode,
958
1102
  title=title,
1103
+ facet_cols=facet_cols,
959
1104
  )
960
1105
  default_filetype = '.html'
961
- elif engine == 'matplotlib':
1106
+ else:
962
1107
  figure_like = plotting.with_matplotlib(
963
1108
  ds.to_dataframe(),
964
1109
  colors=colors,
@@ -966,8 +1111,6 @@ class _NodeResults(_ElementResults):
966
1111
  title=title,
967
1112
  )
968
1113
  default_filetype = '.png'
969
- else:
970
- raise ValueError(f'Engine "{engine}" not supported. Use "plotly" or "matplotlib"')
971
1114
 
972
1115
  return plotting.export_figure(
973
1116
  figure_like=figure_like,
@@ -986,9 +1129,19 @@ class _NodeResults(_ElementResults):
986
1129
  save: bool | pathlib.Path = False,
987
1130
  show: bool = True,
988
1131
  engine: plotting.PlottingEngine = 'plotly',
1132
+ select: dict[FlowSystemDimensions, Any] | None = None,
1133
+ # Deprecated parameter (kept for backwards compatibility)
989
1134
  indexer: dict[FlowSystemDimensions, Any] | None = None,
990
1135
  ) -> plotly.graph_objs.Figure | tuple[plt.Figure, list[plt.Axes]]:
991
1136
  """Plot pie chart of flow hours distribution.
1137
+
1138
+ Note:
1139
+ Pie charts require scalar data (no extra dimensions beyond time).
1140
+ If your data has dimensions like 'scenario' or 'period', either:
1141
+
1142
+ - Use `select` to choose specific values: `select={'scenario': 'base', 'period': 2024}`
1143
+ - Let auto-selection choose the first value (a warning will be logged)
1144
+
992
1145
  Args:
993
1146
  lower_percentage_group: Percentage threshold for "Others" grouping.
994
1147
  colors: Color scheme. Also see plotly.
@@ -996,10 +1149,35 @@ class _NodeResults(_ElementResults):
996
1149
  save: Whether to save plot.
997
1150
  show: Whether to display plot.
998
1151
  engine: Plotting engine ('plotly' or 'matplotlib').
999
- indexer: Optional selection dict, e.g., {'scenario': 'base', 'period': 2024}.
1000
- If None, uses first value for each dimension.
1001
- If empty dict {}, uses all values.
1152
+ select: Optional data selection dict. Supports single values, lists, slices, and index arrays.
1153
+ Use this to select specific scenario/period before creating the pie chart.
1154
+
1155
+ Examples:
1156
+ Basic usage (auto-selects first scenario/period if present):
1157
+
1158
+ >>> results['Bus'].plot_node_balance_pie()
1159
+
1160
+ Explicitly select a scenario and period:
1161
+
1162
+ >>> results['Bus'].plot_node_balance_pie(select={'scenario': 'high_demand', 'period': 2030})
1002
1163
  """
1164
+ # Handle deprecated indexer parameter
1165
+ if indexer is not None:
1166
+ # Check for conflict with new parameter
1167
+ if select is not None:
1168
+ raise ValueError(
1169
+ "Cannot use both deprecated parameter 'indexer' and new parameter 'select'. Use only 'select'."
1170
+ )
1171
+
1172
+ import warnings
1173
+
1174
+ warnings.warn(
1175
+ "The 'indexer' parameter is deprecated and will be removed in a future version. Use 'select' instead.",
1176
+ DeprecationWarning,
1177
+ stacklevel=2,
1178
+ )
1179
+ select = indexer
1180
+
1003
1181
  inputs = sanitize_dataset(
1004
1182
  ds=self.solution[self.inputs] * self._calculation_results.hours_per_timestep,
1005
1183
  threshold=1e-5,
@@ -1015,15 +1193,46 @@ class _NodeResults(_ElementResults):
1015
1193
  drop_suffix='|',
1016
1194
  )
1017
1195
 
1018
- inputs, suffix_parts = _apply_indexer_to_data(inputs, indexer, drop=True)
1019
- outputs, suffix_parts = _apply_indexer_to_data(outputs, indexer, drop=True)
1020
- suffix = '--' + '-'.join(suffix_parts) if suffix_parts else ''
1021
-
1022
- title = f'{self.label} (total flow hours){suffix}'
1196
+ inputs, suffix_parts = _apply_selection_to_data(inputs, select=select, drop=True)
1197
+ outputs, suffix_parts = _apply_selection_to_data(outputs, select=select, drop=True)
1023
1198
 
1199
+ # Sum over time dimension
1024
1200
  inputs = inputs.sum('time')
1025
1201
  outputs = outputs.sum('time')
1026
1202
 
1203
+ # Auto-select first value for any remaining dimensions (scenario, period, etc.)
1204
+ # Pie charts need scalar data, so we automatically reduce extra dimensions
1205
+ extra_dims_inputs = [dim for dim in inputs.dims if dim != 'time']
1206
+ extra_dims_outputs = [dim for dim in outputs.dims if dim != 'time']
1207
+ extra_dims = list(set(extra_dims_inputs + extra_dims_outputs))
1208
+
1209
+ if extra_dims:
1210
+ auto_select = {}
1211
+ for dim in extra_dims:
1212
+ # Get first value of this dimension
1213
+ if dim in inputs.coords:
1214
+ first_val = inputs.coords[dim].values[0]
1215
+ elif dim in outputs.coords:
1216
+ first_val = outputs.coords[dim].values[0]
1217
+ else:
1218
+ continue
1219
+ auto_select[dim] = first_val
1220
+ logger.info(
1221
+ f'Pie chart auto-selected {dim}={first_val} (first value). '
1222
+ f'Use select={{"{dim}": value}} to choose a different value.'
1223
+ )
1224
+
1225
+ # Apply auto-selection
1226
+ inputs = inputs.sel(auto_select)
1227
+ outputs = outputs.sel(auto_select)
1228
+
1229
+ # Update suffix with auto-selected values
1230
+ auto_suffix_parts = [f'{dim}={val}' for dim, val in auto_select.items()]
1231
+ suffix_parts.extend(auto_suffix_parts)
1232
+
1233
+ suffix = '--' + '-'.join(suffix_parts) if suffix_parts else ''
1234
+ title = f'{self.label} (total flow hours){suffix}'
1235
+
1027
1236
  if engine == 'plotly':
1028
1237
  figure_like = plotting.dual_pie_with_plotly(
1029
1238
  data_left=inputs.to_pandas(),
@@ -1068,6 +1277,8 @@ class _NodeResults(_ElementResults):
1068
1277
  with_last_timestep: bool = False,
1069
1278
  unit_type: Literal['flow_rate', 'flow_hours'] = 'flow_rate',
1070
1279
  drop_suffix: bool = False,
1280
+ select: dict[FlowSystemDimensions, Any] | None = None,
1281
+ # Deprecated parameter (kept for backwards compatibility)
1071
1282
  indexer: dict[FlowSystemDimensions, Any] | None = None,
1072
1283
  ) -> xr.Dataset:
1073
1284
  """
@@ -1081,10 +1292,25 @@ class _NodeResults(_ElementResults):
1081
1292
  - 'flow_rate': Returns the flow_rates of the Node.
1082
1293
  - 'flow_hours': Returns the flow_hours of the Node. [flow_hours(t) = flow_rate(t) * dt(t)]. Renames suffixes to |flow_hours.
1083
1294
  drop_suffix: Whether to drop the suffix from the variable names.
1084
- indexer: Optional selection dict, e.g., {'scenario': 'base', 'period': 2024}.
1085
- If None, uses first value for each dimension.
1086
- If empty dict {}, uses all values.
1295
+ select: Optional data selection dict. Supports single values, lists, slices, and index arrays.
1087
1296
  """
1297
+ # Handle deprecated indexer parameter
1298
+ if indexer is not None:
1299
+ # Check for conflict with new parameter
1300
+ if select is not None:
1301
+ raise ValueError(
1302
+ "Cannot use both deprecated parameter 'indexer' and new parameter 'select'. Use only 'select'."
1303
+ )
1304
+
1305
+ import warnings
1306
+
1307
+ warnings.warn(
1308
+ "The 'indexer' parameter is deprecated and will be removed in a future version. Use 'select' instead.",
1309
+ DeprecationWarning,
1310
+ stacklevel=2,
1311
+ )
1312
+ select = indexer
1313
+
1088
1314
  ds = self.solution[self.inputs + self.outputs]
1089
1315
 
1090
1316
  ds = sanitize_dataset(
@@ -1103,7 +1329,7 @@ class _NodeResults(_ElementResults):
1103
1329
  drop_suffix='|' if drop_suffix else None,
1104
1330
  )
1105
1331
 
1106
- ds, _ = _apply_indexer_to_data(ds, indexer, drop=True)
1332
+ ds, _ = _apply_selection_to_data(ds, select=select, drop=True)
1107
1333
 
1108
1334
  if unit_type == 'flow_hours':
1109
1335
  ds = ds * self._calculation_results.hours_per_timestep
@@ -1140,10 +1366,15 @@ class ComponentResults(_NodeResults):
1140
1366
  show: bool = True,
1141
1367
  colors: plotting.ColorType = 'viridis',
1142
1368
  engine: plotting.PlottingEngine = 'plotly',
1143
- mode: Literal['area', 'stacked_bar', 'line'] = 'stacked_bar',
1369
+ mode: Literal['area', 'stacked_bar', 'line'] = 'area',
1370
+ select: dict[FlowSystemDimensions, Any] | None = None,
1371
+ facet_by: str | list[str] | None = 'scenario',
1372
+ animate_by: str | None = 'period',
1373
+ facet_cols: int = 3,
1374
+ # Deprecated parameter (kept for backwards compatibility)
1144
1375
  indexer: dict[FlowSystemDimensions, Any] | None = None,
1145
1376
  ) -> plotly.graph_objs.Figure:
1146
- """Plot storage charge state over time, combined with the node balance.
1377
+ """Plot storage charge state over time, combined with the node balance with optional faceting and animation.
1147
1378
 
1148
1379
  Args:
1149
1380
  save: Whether to save the plot or not. If a path is provided, the plot will be saved at that location.
@@ -1151,42 +1382,120 @@ class ComponentResults(_NodeResults):
1151
1382
  colors: Color scheme. Also see plotly.
1152
1383
  engine: Plotting engine to use. Only 'plotly' is implemented atm.
1153
1384
  mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for stepped lines, or 'area' for stacked area charts.
1154
- indexer: Optional selection dict, e.g., {'scenario': 'base', 'period': 2024}.
1155
- If None, uses first value for each dimension.
1156
- If empty dict {}, uses all values.
1385
+ select: Optional data selection dict. Supports single values, lists, slices, and index arrays.
1386
+ Applied BEFORE faceting/animation.
1387
+ facet_by: Dimension(s) to create facets (subplots) for. Can be a single dimension name (str)
1388
+ or list of dimensions. Each unique value combination creates a subplot. Ignored if not found.
1389
+ animate_by: Dimension to animate over (Plotly only). Creates animation frames that cycle through
1390
+ dimension values. Only one dimension can be animated. Ignored if not found.
1391
+ facet_cols: Number of columns in the facet grid layout (default: 3).
1157
1392
 
1158
1393
  Raises:
1159
1394
  ValueError: If component is not a storage.
1395
+
1396
+ Examples:
1397
+ Basic plot:
1398
+
1399
+ >>> results['Storage'].plot_charge_state()
1400
+
1401
+ Facet by scenario:
1402
+
1403
+ >>> results['Storage'].plot_charge_state(facet_by='scenario', facet_cols=2)
1404
+
1405
+ Animate by period:
1406
+
1407
+ >>> results['Storage'].plot_charge_state(animate_by='period')
1408
+
1409
+ Facet by scenario AND animate by period:
1410
+
1411
+ >>> results['Storage'].plot_charge_state(facet_by='scenario', animate_by='period')
1160
1412
  """
1413
+ # Handle deprecated indexer parameter
1414
+ if indexer is not None:
1415
+ # Check for conflict with new parameter
1416
+ if select is not None:
1417
+ raise ValueError(
1418
+ "Cannot use both deprecated parameter 'indexer' and new parameter 'select'. Use only 'select'."
1419
+ )
1420
+
1421
+ import warnings
1422
+
1423
+ warnings.warn(
1424
+ "The 'indexer' parameter is deprecated and will be removed in a future version. Use 'select' instead.",
1425
+ DeprecationWarning,
1426
+ stacklevel=2,
1427
+ )
1428
+ select = indexer
1429
+
1161
1430
  if not self.is_storage:
1162
1431
  raise ValueError(f'Cant plot charge_state. "{self.label}" is not a storage')
1163
1432
 
1164
- ds = self.node_balance(with_last_timestep=True, indexer=indexer)
1165
- charge_state = self.charge_state
1433
+ # Get node balance and charge state
1434
+ ds = self.node_balance(with_last_timestep=True)
1435
+ charge_state_da = self.charge_state
1166
1436
 
1167
- ds, suffix_parts = _apply_indexer_to_data(ds, indexer, drop=True)
1168
- charge_state, suffix_parts = _apply_indexer_to_data(charge_state, indexer, drop=True)
1437
+ # Apply select filtering
1438
+ ds, suffix_parts = _apply_selection_to_data(ds, select=select, drop=True)
1439
+ charge_state_da, _ = _apply_selection_to_data(charge_state_da, select=select, drop=True)
1169
1440
  suffix = '--' + '-'.join(suffix_parts) if suffix_parts else ''
1170
1441
 
1171
1442
  title = f'Operation Balance of {self.label}{suffix}'
1172
1443
 
1173
1444
  if engine == 'plotly':
1174
- fig = plotting.with_plotly(
1175
- ds.to_dataframe(),
1445
+ # Plot flows (node balance) with the specified mode
1446
+ figure_like = plotting.with_plotly(
1447
+ ds,
1448
+ facet_by=facet_by,
1449
+ animate_by=animate_by,
1176
1450
  colors=colors,
1177
1451
  mode=mode,
1178
1452
  title=title,
1453
+ facet_cols=facet_cols,
1179
1454
  )
1180
1455
 
1181
- # TODO: Use colors for charge state?
1456
+ # Create a dataset with just charge_state and plot it as lines
1457
+ # This ensures proper handling of facets and animation
1458
+ charge_state_ds = charge_state_da.to_dataset(name=self._charge_state)
1182
1459
 
1183
- charge_state = charge_state.to_dataframe()
1184
- fig.add_trace(
1185
- plotly.graph_objs.Scatter(
1186
- x=charge_state.index, y=charge_state.values.flatten(), mode='lines', name=self._charge_state
1187
- )
1460
+ # Plot charge_state with mode='line' to get Scatter traces
1461
+ charge_state_fig = plotting.with_plotly(
1462
+ charge_state_ds,
1463
+ facet_by=facet_by,
1464
+ animate_by=animate_by,
1465
+ colors=colors,
1466
+ mode='line', # Always line for charge_state
1467
+ title='', # No title needed for this temp figure
1468
+ facet_cols=facet_cols,
1188
1469
  )
1470
+
1471
+ # Add charge_state traces to the main figure
1472
+ # This preserves subplot assignments and animation frames
1473
+ for trace in charge_state_fig.data:
1474
+ trace.line.width = 2 # Make charge_state line more prominent
1475
+ trace.line.shape = 'linear' # Smooth line for charge state (not stepped like flows)
1476
+ figure_like.add_trace(trace)
1477
+
1478
+ # Also add traces from animation frames if they exist
1479
+ # Both figures use the same animate_by parameter, so they should have matching frames
1480
+ if hasattr(charge_state_fig, 'frames') and charge_state_fig.frames:
1481
+ # Add charge_state traces to each frame
1482
+ for i, frame in enumerate(charge_state_fig.frames):
1483
+ if i < len(figure_like.frames):
1484
+ for trace in frame.data:
1485
+ trace.line.width = 2
1486
+ trace.line.shape = 'linear' # Smooth line for charge state
1487
+ figure_like.frames[i].data = figure_like.frames[i].data + (trace,)
1488
+
1489
+ default_filetype = '.html'
1189
1490
  elif engine == 'matplotlib':
1491
+ # Matplotlib requires only 'time' dimension; check for extras after selection
1492
+ extra_dims = [d for d in ds.dims if d != 'time']
1493
+ if extra_dims:
1494
+ raise ValueError(
1495
+ f'Matplotlib engine only supports a single time axis, but found extra dimensions: {extra_dims}. '
1496
+ f'Please use select={{...}} to reduce dimensions or switch to engine="plotly" for faceting/animation.'
1497
+ )
1498
+ # For matplotlib, plot flows (node balance), then add charge_state as line
1190
1499
  fig, ax = plotting.with_matplotlib(
1191
1500
  ds.to_dataframe(),
1192
1501
  colors=colors,
@@ -1194,15 +1503,25 @@ class ComponentResults(_NodeResults):
1194
1503
  title=title,
1195
1504
  )
1196
1505
 
1197
- charge_state = charge_state.to_dataframe()
1198
- ax.plot(charge_state.index, charge_state.values.flatten(), label=self._charge_state)
1506
+ # Add charge_state as a line overlay
1507
+ charge_state_df = charge_state_da.to_dataframe()
1508
+ ax.plot(
1509
+ charge_state_df.index,
1510
+ charge_state_df.values.flatten(),
1511
+ label=self._charge_state,
1512
+ linewidth=2,
1513
+ color='black',
1514
+ )
1515
+ ax.legend()
1199
1516
  fig.tight_layout()
1200
- fig = fig, ax
1517
+
1518
+ figure_like = fig, ax
1519
+ default_filetype = '.png'
1201
1520
 
1202
1521
  return plotting.export_figure(
1203
- fig,
1522
+ figure_like=figure_like,
1204
1523
  default_path=self._calculation_results.folder / title,
1205
- default_filetype='.html',
1524
+ default_filetype=default_filetype,
1206
1525
  user_path=None if isinstance(save, bool) else pathlib.Path(save),
1207
1526
  show=show,
1208
1527
  save=True if save else False,
@@ -1476,37 +1795,95 @@ class SegmentedCalculationResults:
1476
1795
  def plot_heatmap(
1477
1796
  self,
1478
1797
  variable_name: str,
1479
- heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] = 'D',
1480
- heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] = 'h',
1481
- color_map: str = 'portland',
1798
+ reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
1799
+ | Literal['auto']
1800
+ | None = 'auto',
1801
+ colors: str = 'portland',
1482
1802
  save: bool | pathlib.Path = False,
1483
1803
  show: bool = True,
1484
1804
  engine: plotting.PlottingEngine = 'plotly',
1805
+ facet_by: str | list[str] | None = None,
1806
+ animate_by: str | None = None,
1807
+ facet_cols: int = 3,
1808
+ fill: Literal['ffill', 'bfill'] | None = 'ffill',
1809
+ # Deprecated parameters (kept for backwards compatibility)
1810
+ heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] | None = None,
1811
+ heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] | None = None,
1812
+ color_map: str | None = None,
1485
1813
  ) -> plotly.graph_objs.Figure | tuple[plt.Figure, plt.Axes]:
1486
1814
  """Plot heatmap of variable solution across segments.
1487
1815
 
1488
1816
  Args:
1489
1817
  variable_name: Variable to plot.
1490
- heatmap_timeframes: Time aggregation level.
1491
- heatmap_timesteps_per_frame: Timesteps per frame.
1492
- color_map: Color scheme. Also see plotly.
1818
+ reshape_time: Time reshaping configuration (default: 'auto'):
1819
+ - 'auto': Automatically applies ('D', 'h') when only 'time' dimension remains
1820
+ - Tuple like ('D', 'h'): Explicit reshaping (days vs hours)
1821
+ - None: Disable time reshaping
1822
+ colors: Color scheme. See plotting.ColorType for options.
1493
1823
  save: Whether to save plot.
1494
1824
  show: Whether to display plot.
1495
1825
  engine: Plotting engine.
1826
+ facet_by: Dimension(s) to create facets (subplots) for.
1827
+ animate_by: Dimension to animate over (Plotly only).
1828
+ facet_cols: Number of columns in the facet grid layout.
1829
+ fill: Method to fill missing values: 'ffill' or 'bfill'.
1830
+ heatmap_timeframes: (Deprecated) Use reshape_time instead.
1831
+ heatmap_timesteps_per_frame: (Deprecated) Use reshape_time instead.
1832
+ color_map: (Deprecated) Use colors instead.
1496
1833
 
1497
1834
  Returns:
1498
1835
  Figure object.
1499
1836
  """
1837
+ # Handle deprecated parameters
1838
+ if heatmap_timeframes is not None or heatmap_timesteps_per_frame is not None:
1839
+ # Check for conflict with new parameter
1840
+ if reshape_time != 'auto': # Check if user explicitly set reshape_time
1841
+ raise ValueError(
1842
+ "Cannot use both deprecated parameters 'heatmap_timeframes'/'heatmap_timesteps_per_frame' "
1843
+ "and new parameter 'reshape_time'. Use only 'reshape_time'."
1844
+ )
1845
+
1846
+ import warnings
1847
+
1848
+ warnings.warn(
1849
+ "The 'heatmap_timeframes' and 'heatmap_timesteps_per_frame' parameters are deprecated. "
1850
+ "Use 'reshape_time=(timeframes, timesteps_per_frame)' instead.",
1851
+ DeprecationWarning,
1852
+ stacklevel=2,
1853
+ )
1854
+ # Override reshape_time only if both old parameters are provided
1855
+ if heatmap_timeframes is not None and heatmap_timesteps_per_frame is not None:
1856
+ reshape_time = (heatmap_timeframes, heatmap_timesteps_per_frame)
1857
+
1858
+ if color_map is not None:
1859
+ # Check for conflict with new parameter
1860
+ if colors != 'portland': # Check if user explicitly set colors
1861
+ raise ValueError(
1862
+ "Cannot use both deprecated parameter 'color_map' and new parameter 'colors'. Use only 'colors'."
1863
+ )
1864
+
1865
+ import warnings
1866
+
1867
+ warnings.warn(
1868
+ "The 'color_map' parameter is deprecated. Use 'colors' instead.",
1869
+ DeprecationWarning,
1870
+ stacklevel=2,
1871
+ )
1872
+ colors = color_map
1873
+
1500
1874
  return plot_heatmap(
1501
- dataarray=self.solution_without_overlap(variable_name),
1875
+ data=self.solution_without_overlap(variable_name),
1502
1876
  name=variable_name,
1503
1877
  folder=self.folder,
1504
- heatmap_timeframes=heatmap_timeframes,
1505
- heatmap_timesteps_per_frame=heatmap_timesteps_per_frame,
1506
- color_map=color_map,
1878
+ reshape_time=reshape_time,
1879
+ colors=colors,
1507
1880
  save=save,
1508
1881
  show=show,
1509
1882
  engine=engine,
1883
+ facet_by=facet_by,
1884
+ animate_by=animate_by,
1885
+ facet_cols=facet_cols,
1886
+ fill=fill,
1510
1887
  )
1511
1888
 
1512
1889
  def to_file(self, folder: str | pathlib.Path | None = None, name: str | None = None, compression: int = 5):
@@ -1536,59 +1913,212 @@ class SegmentedCalculationResults:
1536
1913
 
1537
1914
 
1538
1915
  def plot_heatmap(
1539
- dataarray: xr.DataArray,
1540
- name: str,
1541
- folder: pathlib.Path,
1542
- heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] = 'D',
1543
- heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] = 'h',
1544
- color_map: str = 'portland',
1916
+ data: xr.DataArray | xr.Dataset,
1917
+ name: str | None = None,
1918
+ folder: pathlib.Path | None = None,
1919
+ colors: plotting.ColorType = 'viridis',
1545
1920
  save: bool | pathlib.Path = False,
1546
1921
  show: bool = True,
1547
1922
  engine: plotting.PlottingEngine = 'plotly',
1923
+ select: dict[str, Any] | None = None,
1924
+ facet_by: str | list[str] | None = None,
1925
+ animate_by: str | None = None,
1926
+ facet_cols: int = 3,
1927
+ reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
1928
+ | Literal['auto']
1929
+ | None = 'auto',
1930
+ fill: Literal['ffill', 'bfill'] | None = 'ffill',
1931
+ # Deprecated parameters (kept for backwards compatibility)
1548
1932
  indexer: dict[str, Any] | None = None,
1933
+ heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] | None = None,
1934
+ heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] | None = None,
1935
+ color_map: str | None = None,
1549
1936
  ):
1550
- """Plot heatmap of time series data.
1937
+ """Plot heatmap visualization with support for multi-variable, faceting, and animation.
1938
+
1939
+ This function provides a standalone interface to the heatmap plotting capabilities,
1940
+ supporting the same modern features as CalculationResults.plot_heatmap().
1551
1941
 
1552
1942
  Args:
1553
- dataarray: Data to plot.
1554
- name: Variable name for title.
1555
- folder: Save folder.
1556
- heatmap_timeframes: Time aggregation level.
1557
- heatmap_timesteps_per_frame: Timesteps per frame.
1558
- color_map: Color scheme. Also see plotly.
1559
- save: Whether to save plot.
1560
- show: Whether to display plot.
1561
- engine: Plotting engine.
1562
- indexer: Optional selection dict, e.g., {'scenario': 'base', 'period': 2024}.
1563
- If None, uses first value for each dimension.
1564
- If empty dict {}, uses all values.
1943
+ data: Data to plot. Can be a single DataArray or an xarray Dataset.
1944
+ When a Dataset is provided, all data variables are combined along a new 'variable' dimension.
1945
+ name: Optional name for the title. If not provided, uses the DataArray name or
1946
+ generates a default title for Datasets.
1947
+ folder: Save folder for the plot. Defaults to current directory if not provided.
1948
+ colors: Color scheme for the heatmap. See `flixopt.plotting.ColorType` for options.
1949
+ save: Whether to save the plot or not. If a path is provided, the plot will be saved at that location.
1950
+ show: Whether to show the plot or not.
1951
+ engine: The engine to use for plotting. Can be either 'plotly' or 'matplotlib'.
1952
+ select: Optional data selection dict. Supports single values, lists, slices, and index arrays.
1953
+ facet_by: Dimension(s) to create facets (subplots) for. Can be a single dimension name (str)
1954
+ or list of dimensions. Each unique value combination creates a subplot.
1955
+ animate_by: Dimension to animate over (Plotly only). Creates animation frames.
1956
+ facet_cols: Number of columns in the facet grid layout (default: 3).
1957
+ reshape_time: Time reshaping configuration (default: 'auto'):
1958
+ - 'auto': Automatically applies ('D', 'h') when only 'time' dimension remains
1959
+ - Tuple: Explicit reshaping, e.g. ('D', 'h') for days vs hours
1960
+ - None: Disable auto-reshaping
1961
+ fill: Method to fill missing values after reshape: 'ffill' (forward fill) or 'bfill' (backward fill).
1962
+ Default is 'ffill'.
1963
+
1964
+ Examples:
1965
+ Single DataArray with time reshaping:
1966
+
1967
+ >>> plot_heatmap(data, name='Temperature', folder=Path('.'), reshape_time=('D', 'h'))
1968
+
1969
+ Dataset with multiple variables (facet by variable):
1970
+
1971
+ >>> dataset = xr.Dataset({'Boiler': data1, 'CHP': data2, 'Storage': data3})
1972
+ >>> plot_heatmap(
1973
+ ... dataset,
1974
+ ... folder=Path('.'),
1975
+ ... facet_by='variable',
1976
+ ... reshape_time=('D', 'h'),
1977
+ ... )
1978
+
1979
+ Dataset with animation by variable:
1980
+
1981
+ >>> plot_heatmap(dataset, animate_by='variable', reshape_time=('D', 'h'))
1565
1982
  """
1566
- dataarray, suffix_parts = _apply_indexer_to_data(dataarray, indexer, drop=True)
1983
+ # Handle deprecated heatmap time parameters
1984
+ if heatmap_timeframes is not None or heatmap_timesteps_per_frame is not None:
1985
+ # Check for conflict with new parameter
1986
+ if reshape_time != 'auto': # User explicitly set reshape_time
1987
+ raise ValueError(
1988
+ "Cannot use both deprecated parameters 'heatmap_timeframes'/'heatmap_timesteps_per_frame' "
1989
+ "and new parameter 'reshape_time'. Use only 'reshape_time'."
1990
+ )
1991
+
1992
+ import warnings
1993
+
1994
+ warnings.warn(
1995
+ "The 'heatmap_timeframes' and 'heatmap_timesteps_per_frame' parameters are deprecated. "
1996
+ "Use 'reshape_time=(timeframes, timesteps_per_frame)' instead.",
1997
+ DeprecationWarning,
1998
+ stacklevel=2,
1999
+ )
2000
+ # Override reshape_time if both old parameters provided
2001
+ if heatmap_timeframes is not None and heatmap_timesteps_per_frame is not None:
2002
+ reshape_time = (heatmap_timeframes, heatmap_timesteps_per_frame)
2003
+
2004
+ # Handle deprecated color_map parameter
2005
+ if color_map is not None:
2006
+ # Check for conflict with new parameter
2007
+ if colors != 'viridis': # User explicitly set colors
2008
+ raise ValueError(
2009
+ "Cannot use both deprecated parameter 'color_map' and new parameter 'colors'. Use only 'colors'."
2010
+ )
2011
+
2012
+ import warnings
2013
+
2014
+ warnings.warn(
2015
+ "The 'color_map' parameter is deprecated. Use 'colors' instead.",
2016
+ DeprecationWarning,
2017
+ stacklevel=2,
2018
+ )
2019
+ colors = color_map
2020
+
2021
+ # Handle deprecated indexer parameter
2022
+ if indexer is not None:
2023
+ # Check for conflict with new parameter
2024
+ if select is not None: # User explicitly set select
2025
+ raise ValueError(
2026
+ "Cannot use both deprecated parameter 'indexer' and new parameter 'select'. Use only 'select'."
2027
+ )
2028
+
2029
+ import warnings
2030
+
2031
+ warnings.warn(
2032
+ "The 'indexer' parameter is deprecated. Use 'select' instead.",
2033
+ DeprecationWarning,
2034
+ stacklevel=2,
2035
+ )
2036
+ select = indexer
2037
+
2038
+ # Convert Dataset to DataArray with 'variable' dimension
2039
+ if isinstance(data, xr.Dataset):
2040
+ # Extract all data variables from the Dataset
2041
+ variable_names = list(data.data_vars)
2042
+ dataarrays = [data[var] for var in variable_names]
2043
+
2044
+ # Combine into single DataArray with 'variable' dimension
2045
+ data = xr.concat(dataarrays, dim='variable')
2046
+ data = data.assign_coords(variable=variable_names)
2047
+
2048
+ # Fall back to a generic title with the variable count if name not provided
2049
+ if name is None:
2050
+ title_name = f'Heatmap of {len(variable_names)} variables'
2051
+ else:
2052
+ title_name = name
2053
+ else:
2054
+ # Single DataArray
2055
+ if name is None:
2056
+ title_name = data.name if data.name else 'Heatmap'
2057
+ else:
2058
+ title_name = name
2059
+
2060
+ # Apply select filtering
2061
+ data, suffix_parts = _apply_selection_to_data(data, select=select, drop=True)
1567
2062
  suffix = '--' + '-'.join(suffix_parts) if suffix_parts else ''
1568
- name = name if not suffix_parts else name + suffix
1569
2063
 
1570
- heatmap_data = plotting.heat_map_data_from_df(
1571
- dataarray.to_dataframe(name), heatmap_timeframes, heatmap_timesteps_per_frame, 'ffill'
1572
- )
2064
+ # Matplotlib heatmaps require at most 2D data
2065
+ # Time dimension will be reshaped to 2D (timeframe × timestep), so can't have other dims alongside it
2066
+ if engine == 'matplotlib':
2067
+ dims = list(data.dims)
2068
+
2069
+ # If 'time' dimension exists and will be reshaped, we can't have any other dimensions
2070
+ if 'time' in dims and len(dims) > 1 and reshape_time is not None:
2071
+ extra_dims = [d for d in dims if d != 'time']
2072
+ raise ValueError(
2073
+ f'Matplotlib heatmaps with time reshaping cannot have additional dimensions. '
2074
+ f'Found extra dimensions: {extra_dims}. '
2075
+ f'Use select={{...}} to reduce to time only, use "reshape_time=None" or switch to engine="plotly" or use for multi-dimensional support.'
2076
+ )
2077
+ # If no 'time' dimension (already reshaped or different data), allow at most 2 dimensions
2078
+ elif 'time' not in dims and len(dims) > 2:
2079
+ raise ValueError(
2080
+ f'Matplotlib heatmaps support at most 2 dimensions, but data has {len(dims)}: {dims}. '
2081
+ f'Use select={{...}} to reduce dimensions or switch to engine="plotly".'
2082
+ )
1573
2083
 
1574
- xlabel, ylabel = f'timeframe [{heatmap_timeframes}]', f'timesteps [{heatmap_timesteps_per_frame}]'
2084
+ # Build title
2085
+ title = f'{title_name}{suffix}'
2086
+ if isinstance(reshape_time, tuple):
2087
+ timeframes, timesteps_per_frame = reshape_time
2088
+ title += f' ({timeframes} vs {timesteps_per_frame})'
1575
2089
 
2090
+ # Plot with appropriate engine
1576
2091
  if engine == 'plotly':
1577
- figure_like = plotting.heat_map_plotly(
1578
- heatmap_data, title=name, color_map=color_map, xlabel=xlabel, ylabel=ylabel
2092
+ figure_like = plotting.heatmap_with_plotly(
2093
+ data=data,
2094
+ facet_by=facet_by,
2095
+ animate_by=animate_by,
2096
+ colors=colors,
2097
+ title=title,
2098
+ facet_cols=facet_cols,
2099
+ reshape_time=reshape_time,
2100
+ fill=fill,
1579
2101
  )
1580
2102
  default_filetype = '.html'
1581
2103
  elif engine == 'matplotlib':
1582
- figure_like = plotting.heat_map_matplotlib(
1583
- heatmap_data, title=name, color_map=color_map, xlabel=xlabel, ylabel=ylabel
2104
+ figure_like = plotting.heatmap_with_matplotlib(
2105
+ data=data,
2106
+ colors=colors,
2107
+ title=title,
2108
+ reshape_time=reshape_time,
2109
+ fill=fill,
1584
2110
  )
1585
2111
  default_filetype = '.png'
1586
2112
  else:
1587
2113
  raise ValueError(f'Engine "{engine}" not supported. Use "plotly" or "matplotlib"')
1588
2114
 
2115
+ # Set default folder if not provided
2116
+ if folder is None:
2117
+ folder = pathlib.Path('.')
2118
+
1589
2119
  return plotting.export_figure(
1590
2120
  figure_like=figure_like,
1591
- default_path=folder / f'{name} ({heatmap_timeframes}-{heatmap_timesteps_per_frame})',
2121
+ default_path=folder / title,
1592
2122
  default_filetype=default_filetype,
1593
2123
  user_path=None if isinstance(save, bool) else pathlib.Path(save),
1594
2124
  show=show,
@@ -1790,8 +2320,13 @@ def filter_dataarray_by_coord(da: xr.DataArray, **kwargs: str | list[str] | None
1790
2320
  if coord_name not in array.coords:
1791
2321
  raise AttributeError(f"Missing required coordinate '{coord_name}'")
1792
2322
 
1793
- # Convert single value to list
1794
- val_list = [coord_values] if isinstance(coord_values, str) else coord_values
2323
+ # Normalize to list for sequence-like inputs (excluding strings)
2324
+ if isinstance(coord_values, str):
2325
+ val_list = [coord_values]
2326
+ elif isinstance(coord_values, (list, tuple, np.ndarray, pd.Index)):
2327
+ val_list = list(coord_values)
2328
+ else:
2329
+ val_list = [coord_values]
1795
2330
 
1796
2331
  # Verify coord_values exist
1797
2332
  available = set(array[coord_name].values)
@@ -1801,7 +2336,7 @@ def filter_dataarray_by_coord(da: xr.DataArray, **kwargs: str | list[str] | None
1801
2336
 
1802
2337
  # Apply filter
1803
2338
  return array.where(
1804
- array[coord_name].isin(val_list) if isinstance(coord_values, list) else array[coord_name] == coord_values,
2339
+ array[coord_name].isin(val_list) if len(val_list) > 1 else array[coord_name] == val_list[0],
1805
2340
  drop=True,
1806
2341
  )
1807
2342
 
@@ -1820,36 +2355,26 @@ def filter_dataarray_by_coord(da: xr.DataArray, **kwargs: str | list[str] | None
1820
2355
  return da
1821
2356
 
1822
2357
 
1823
- def _apply_indexer_to_data(
1824
- data: xr.DataArray | xr.Dataset, indexer: dict[str, Any] | None = None, drop=False
2358
+ def _apply_selection_to_data(
2359
+ data: xr.DataArray | xr.Dataset,
2360
+ select: dict[str, Any] | None = None,
2361
+ drop=False,
1825
2362
  ) -> tuple[xr.DataArray | xr.Dataset, list[str]]:
1826
2363
  """
1827
- Apply indexer selection or auto-select first values for non-time dimensions.
2364
+ Apply selection to data.
1828
2365
 
1829
2366
  Args:
1830
2367
  data: xarray Dataset or DataArray
1831
- indexer: Optional selection dict
1832
- If None, uses first value for each dimension (except time).
1833
- If empty dict {}, uses all values.
2368
+ select: Optional selection dict
2369
+ drop: Whether to drop dimensions after selection
1834
2370
 
1835
2371
  Returns:
1836
2372
  Tuple of (selected_data, selection_string)
1837
2373
  """
1838
2374
  selection_string = []
1839
2375
 
1840
- if indexer is not None:
1841
- # User provided indexer
1842
- data = data.sel(indexer, drop=drop)
1843
- selection_string.extend(f'{v}[{k}]' for k, v in indexer.items())
1844
- else:
1845
- # Auto-select first value for each dimension except 'time'
1846
- selection = {}
1847
- for dim in data.dims:
1848
- if dim != 'time' and dim in data.coords:
1849
- first_value = data.coords[dim].values[0]
1850
- selection[dim] = first_value
1851
- selection_string.append(f'{first_value}[{dim}]')
1852
- if selection:
1853
- data = data.sel(selection, drop=drop)
2376
+ if select:
2377
+ data = data.sel(select, drop=drop)
2378
+ selection_string.extend(f'{dim}={val}' for dim, val in select.items())
1854
2379
 
1855
2380
  return data, selection_string