xarray-plotly 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xarray_plotly/__init__.py CHANGED
@@ -53,13 +53,19 @@ from xarray import DataArray, Dataset, register_dataarray_accessor, register_dat
53
53
  from xarray_plotly import config
54
54
  from xarray_plotly.accessor import DataArrayPlotlyAccessor, DatasetPlotlyAccessor
55
55
  from xarray_plotly.common import SLOT_ORDERS, auto
56
+ from xarray_plotly.figures import (
57
+ add_secondary_y,
58
+ overlay,
59
+ update_traces,
60
+ )
56
61
 
57
62
  __all__ = [
58
63
  "SLOT_ORDERS",
59
- "DataArrayPlotlyAccessor",
60
- "DatasetPlotlyAccessor",
64
+ "add_secondary_y",
61
65
  "auto",
62
66
  "config",
67
+ "overlay",
68
+ "update_traces",
63
69
  "xpx",
64
70
  ]
65
71
 
xarray_plotly/accessor.py CHANGED
@@ -34,7 +34,7 @@ class DataArrayPlotlyAccessor:
34
34
  ```
35
35
  """
36
36
 
37
- __all__: ClassVar = ["line", "bar", "area", "scatter", "box", "imshow", "pie"]
37
+ __all__: ClassVar = ["line", "bar", "fast_bar", "area", "scatter", "box", "imshow", "pie"]
38
38
 
39
39
  def __init__(self, darray: DataArray) -> None:
40
40
  self._da = darray
@@ -160,6 +160,41 @@ class DataArrayPlotlyAccessor:
160
160
  **px_kwargs,
161
161
  )
162
162
 
163
def fast_bar(
    self,
    *,
    x: SlotValue = auto,
    color: SlotValue = auto,
    facet_col: SlotValue = auto,
    facet_row: SlotValue = auto,
    animation_frame: SlotValue = auto,
    **px_kwargs: Any,
) -> go.Figure:
    """Draw the DataArray as a bar-like chart built from stacked areas.

    Thin wrapper that forwards the wrapped DataArray and every slot
    argument to `plotting.fast_bar`.

    Slot order: x -> color -> facet_col -> facet_row -> animation_frame

    Args:
        x: Dimension for x-axis. Default: first dimension.
        color: Dimension for color/stacking. Default: second dimension.
        facet_col: Dimension for subplot columns. Default: third dimension.
        facet_row: Dimension for subplot rows. Default: fourth dimension.
        animation_frame: Dimension for animation. Default: fifth dimension.
        **px_kwargs: Additional arguments passed to `plotly.express.area()`.

    Returns:
        Interactive Plotly Figure.
    """
    # Gather the slot assignments once so the delegation stays a one-liner.
    slot_args = {
        "x": x,
        "color": color,
        "facet_col": facet_col,
        "facet_row": facet_row,
        "animation_frame": animation_frame,
    }
    return plotting.fast_bar(self._da, **slot_args, **px_kwargs)
197
+
163
198
  def scatter(
164
199
  self,
165
200
  *,
@@ -257,6 +292,11 @@ class DataArrayPlotlyAccessor:
257
292
 
258
293
  Slot order: y (rows) -> x (columns) -> facet_col -> animation_frame
259
294
 
295
+ Note:
296
+ **Difference from px.imshow**: Color bounds are computed from the
297
+ entire dataset by default, ensuring consistent coloring across
298
+ animation frames. Use `zmin`/`zmax` to override.
299
+
260
300
  Args:
261
301
  x: Dimension for x-axis (columns). Default: second dimension.
262
302
  y: Dimension for y-axis (rows). Default: first dimension.
@@ -344,7 +384,7 @@ class DatasetPlotlyAccessor:
344
384
  ```
345
385
  """
346
386
 
347
- __all__: ClassVar = ["line", "bar", "area", "scatter", "box", "pie"]
387
+ __all__: ClassVar = ["line", "bar", "fast_bar", "area", "scatter", "box", "pie"]
348
388
 
349
389
  def __init__(self, dataset: Dataset) -> None:
350
390
  self._ds = dataset
@@ -496,6 +536,42 @@ class DatasetPlotlyAccessor:
496
536
  **px_kwargs,
497
537
  )
498
538
 
539
def fast_bar(
    self,
    var: str | None = None,
    *,
    x: SlotValue = auto,
    color: SlotValue = auto,
    facet_col: SlotValue = auto,
    facet_row: SlotValue = auto,
    animation_frame: SlotValue = auto,
    **px_kwargs: Any,
) -> go.Figure:
    """Draw a Dataset variable as a bar-like chart built from stacked areas.

    The Dataset is first reduced to a single DataArray via
    `self._get_dataarray(var)`; the result and every slot argument are
    then forwarded to `plotting.fast_bar`.

    Args:
        var: Variable to plot. If None, plots all variables with "variable" dimension.
        x: Dimension for x-axis.
        color: Dimension for color/stacking.
        facet_col: Dimension for subplot columns.
        facet_row: Dimension for subplot rows.
        animation_frame: Dimension for animation.
        **px_kwargs: Additional arguments passed to `plotly.express.area()`.

    Returns:
        Interactive Plotly Figure.
    """
    slot_args = {
        "x": x,
        "color": color,
        "facet_col": facet_col,
        "facet_row": facet_row,
        "animation_frame": animation_frame,
    }
    return plotting.fast_bar(self._get_dataarray(var), **slot_args, **px_kwargs)
574
+
499
575
  def scatter(
500
576
  self,
501
577
  var: str | None = None,
xarray_plotly/config.py CHANGED
@@ -26,6 +26,7 @@ DEFAULT_SLOT_ORDERS: dict[str, tuple[str, ...]] = {
26
26
  "animation_frame",
27
27
  ),
28
28
  "bar": ("x", "color", "pattern_shape", "facet_col", "facet_row", "animation_frame"),
29
+ "fast_bar": ("x", "color", "facet_col", "facet_row", "animation_frame"),
29
30
  "area": (
30
31
  "x",
31
32
  "color",
@@ -0,0 +1,448 @@
1
+ """
2
+ Helper functions for combining and manipulating Plotly figures.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import copy
8
+ from typing import TYPE_CHECKING
9
+
10
+ if TYPE_CHECKING:
11
+ from collections.abc import Iterator
12
+
13
+ import plotly.graph_objects as go
14
+
15
+
16
def _iter_all_traces(fig: go.Figure) -> Iterator:
    """Walk every trace of a figure, animation frames included.

    The traces in ``fig.data`` are produced first, followed by the traces
    of each frame in ``fig.frames`` in frame order. A missing or empty
    ``fig.frames`` simply contributes nothing.

    Args:
        fig: Plotly Figure.

    Yields:
        Each trace object from the figure.
    """
    for trace in fig.data:
        yield trace

    frames = fig.frames
    if not frames:
        # Static figure: nothing beyond the top-level traces.
        return

    for frame in frames:
        for trace in frame.data:
            yield trace
31
+
32
+
33
def _get_subplot_axes(fig: go.Figure) -> set[tuple[str, str]]:
    """Collect the distinct (xaxis, yaxis) pairs referenced by ``fig.data``.

    A trace whose ``xaxis``/``yaxis`` attribute is missing or falsy is
    counted as using the default ``"x"``/``"y"`` axis. Only top-level
    traces are inspected; animation frames are not.

    Args:
        fig: A Plotly figure.

    Returns:
        Set of (xaxis, yaxis) tuples, e.g., {('x', 'y'), ('x2', 'y2')}.
    """
    return {
        (
            getattr(trace, "xaxis", None) or "x",
            getattr(trace, "yaxis", None) or "y",
        )
        for trace in fig.data
    }
48
+
49
+
50
def _validate_compatible_structure(base: go.Figure, overlay: go.Figure) -> None:
    """Check that every subplot used by ``overlay`` also exists in ``base``.

    Subplots are compared as (xaxis, yaxis) pairs taken from the traces of
    each figure. The base may have subplots the overlay does not use.

    Args:
        base: The base figure.
        overlay: The overlay figure to check.

    Raises:
        ValueError: If overlay has subplots not present in base.
    """
    known_pairs = _get_subplot_axes(base)
    unexpected = _get_subplot_axes(overlay) - known_pairs
    if not unexpected:
        return

    raise ValueError(
        f"Overlay figure has subplots not present in base figure: {unexpected}. "
        "Ensure both figures have the same facet structure."
    )
69
+
70
+
71
def _validate_animation_compatibility(base: go.Figure, overlay: go.Figure) -> None:
    """Check that ``overlay`` can be animated together with ``base``.

    A static overlay is always accepted. An animated overlay requires an
    animated base whose set of frame names is exactly the same.

    Args:
        base: The base figure.
        overlay: The overlay figure to check.

    Raises:
        ValueError: If overlay has animation but base doesn't, or frame names don't match.
    """
    if not overlay.frames:
        # Static overlays are compatible with both static and animated bases.
        return

    if not base.frames:
        raise ValueError(
            "Overlay figure has animation frames but base figure does not. "
            "Cannot add animated overlay to static base figure."
        )

    base_names = {frame.name for frame in base.frames}
    overlay_names = {frame.name for frame in overlay.frames}
    if base_names == overlay_names:
        return

    absent = base_names - overlay_names
    surplus = overlay_names - base_names
    parts = ["Animation frame names don't match between base and overlay."]
    if absent:
        parts.append(f" Missing in overlay: {absent}.")
    if surplus:
        parts.append(f" Extra in overlay: {surplus}.")
    raise ValueError("".join(parts))
103
+
104
+
105
def _merge_frames(
    base: go.Figure,
    overlays: list[go.Figure],
    base_trace_count: int,
    overlay_trace_counts: list[int],
) -> list:
    """Merge animation frames from base and overlay figures.

    One merged frame is produced per base frame, keeping the base frame's
    name. Its data is the base frame's traces followed, overlay by overlay,
    by either the overlay's frame of the same name (animated overlay) or
    the overlay's top-level traces (static overlay). An animated overlay
    without a frame of that name contributes nothing to that frame.

    Args:
        base: The base figure with animation frames.
        overlays: List of overlay figures (may or may not have frames).
        base_trace_count: Number of traces in the base figure.
        overlay_trace_counts: Number of traces in each overlay figure.

    Returns:
        List of merged frames.
    """
    import plotly.graph_objects as go

    # Loop-invariant: every merged frame addresses all traces of the
    # combined figure.
    total_traces = base_trace_count + sum(overlay_trace_counts)

    # Index each animated overlay's frames by name once, so the per-frame
    # lookup below is O(1) instead of a scan. ``setdefault`` keeps the
    # first frame when names repeat, matching a first-match search.
    # ``None`` marks a static overlay.
    frame_lookups: list[dict | None] = []
    for figure in overlays:
        if figure.frames:
            by_name: dict = {}
            for frame in figure.frames:
                by_name.setdefault(frame.name, frame)
            frame_lookups.append(by_name)
        else:
            frame_lookups.append(None)

    merged_frames = []
    for base_frame in base.frames:
        frame_name = base_frame.name
        merged_data = list(base_frame.data)

        for figure, by_name in zip(overlays, frame_lookups, strict=True):
            if by_name is None:
                # Static overlay: replicate its traces into this frame.
                merged_data.extend(figure.data)
                continue
            matching = by_name.get(frame_name)
            if matching is not None:
                merged_data.extend(matching.data)

        merged_frames.append(
            go.Frame(
                data=merged_data,
                name=frame_name,
                # Fresh list per frame so frames never share one object.
                traces=list(range(total_traces)),
            )
        )

    return merged_frames
149
+
150
+
151
def overlay(base: go.Figure, *overlays: go.Figure) -> go.Figure:
    """Overlay multiple Plotly figures on the same axes.

    Builds a new figure that reuses a copy of the base figure's layout
    (including its sliders and buttons) and contains the base traces
    followed by the traces of each overlay, in argument order. Faceted
    figures and animation frames are supported. The inputs are not
    modified; with no overlays a deep copy of ``base`` is returned.

    Args:
        base: The base figure whose layout is preserved.
        *overlays: One or more figures to overlay on the base.

    Returns:
        A new combined figure.

    Raises:
        ValueError: If overlay has subplots not in base, animation frames don't match,
            or overlay has animation but base doesn't.

    Example:
        >>> import numpy as np
        >>> import xarray as xr
        >>> from xarray_plotly import xpx, overlay
        >>>
        >>> da = xr.DataArray(np.random.rand(10, 3), dims=["time", "cat"])
        >>> area_fig = xpx(da).area()
        >>> line_fig = xpx(da).line()
        >>> combined = overlay(area_fig, line_fig)
        >>>
        >>> # With animation
        >>> da3d = xr.DataArray(np.random.rand(10, 3, 4), dims=["x", "cat", "time"])
        >>> area = xpx(da3d).area(animation_frame="time")
        >>> line = xpx(da3d).line(animation_frame="time")
        >>> combined = overlay(area, line)
    """
    import plotly.graph_objects as go

    if len(overlays) == 0:
        return copy.deepcopy(base)

    # Reject incompatible inputs before building anything.
    for candidate in overlays:
        _validate_compatible_structure(base, candidate)
        _validate_animation_compatibility(base, candidate)

    combined = go.Figure(layout=copy.deepcopy(base.layout))

    # Base traces first, then each overlay's traces, all as copies.
    for source in (base, *overlays):
        for trace in source.data:
            combined.add_trace(copy.deepcopy(trace))

    if base.frames:
        counts = [len(figure.data) for figure in overlays]
        combined.frames = _merge_frames(base, list(overlays), len(base.data), counts)

    return combined
216
+
217
+
218
def _build_secondary_y_mapping(base_axes: set[tuple[str, str]]) -> dict[str, str]:
    """Build mapping from primary y-axes to secondary y-axes.

    Secondary axis ids are allocated consecutively, starting right after
    the highest y-axis number found in ``base_axes``. Primary axes are
    processed in numeric order (``y``, ``y2``, ..., ``y10``), so the
    mapping follows subplot order even with ten or more y-axes; a plain
    string sort would put ``y10`` before ``y2``.

    Args:
        base_axes: Set of (xaxis, yaxis) pairs from base figure.

    Returns:
        Dict mapping primary yaxis names to secondary yaxis names.
        E.g., {'y': 'y4', 'y2': 'y5', 'y3': 'y6'}. Empty if ``base_axes``
        is empty.
    """

    def _axis_number(yaxis: str) -> int:
        # The bare id "y" is axis number 1; "yN" is axis number N.
        return 1 if yaxis == "y" else int(yaxis[1:])

    primary_y_axes = sorted({yaxis for _, yaxis in base_axes}, key=_axis_number)
    first_free = max((_axis_number(yaxis) for yaxis in primary_y_axes), default=1) + 1

    return {
        yaxis: f"y{first_free + offset}" for offset, yaxis in enumerate(primary_y_axes)
    }
244
+
245
+
246
def add_secondary_y(
    base: go.Figure,
    secondary: go.Figure,
    *,
    secondary_y_title: str | None = None,
) -> go.Figure:
    """Add a secondary y-axis with traces from another figure.

    Creates a new figure with the base figure's layout and secondary y-axes
    on the right side. All traces from the secondary figure are plotted against
    the secondary y-axes. Supports faceted figures when both have matching
    facet structure. Neither input figure is modified.

    Each secondary axis overlays its primary axis and is anchored to the
    x-axis of the same subplot, so it is drawn on the right edge of that
    subplot.

    Args:
        base: The base figure (left y-axis).
        secondary: The figure whose traces use the secondary y-axis (right).
        secondary_y_title: Optional title for the secondary y-axis.
            If not provided, uses the secondary figure's y-axis title.
            The title is only placed on the secondary axis of the first
            subplot (the one overlaying ``y``) to avoid repetition.

    Returns:
        A new figure with both primary and secondary y-axes.

    Raises:
        ValueError: If facet structures don't match, or if animation
            frames don't match.

    Example:
        >>> import numpy as np
        >>> import xarray as xr
        >>> from xarray_plotly import xpx, add_secondary_y
        >>>
        >>> # Two variables with different scales
        >>> temp = xr.DataArray([20, 22, 25, 23], dims=["time"], name="Temperature (°C)")
        >>> precip = xr.DataArray([0, 5, 12, 2], dims=["time"], name="Precipitation (mm)")
        >>>
        >>> temp_fig = xpx(temp).line()
        >>> precip_fig = xpx(precip).bar()
        >>> combined = add_secondary_y(temp_fig, precip_fig)
        >>>
        >>> # With facets
        >>> data = xr.DataArray(np.random.rand(10, 3), dims=["x", "facet"])
        >>> fig1 = xpx(data).line(facet_col="facet")
        >>> fig2 = xpx(data * 100).bar(facet_col="facet")  # Different scale
        >>> combined = add_secondary_y(fig1, fig2)
    """
    import plotly.graph_objects as go

    base_axes = _get_subplot_axes(base)
    secondary_axes = _get_subplot_axes(secondary)

    # Both figures must address exactly the same subplots.
    if base_axes != secondary_axes:
        raise ValueError(
            f"Base and secondary figures must have the same facet structure. "
            f"Base has {base_axes}, secondary has {secondary_axes}."
        )

    _validate_animation_compatibility(base, secondary)

    y_mapping = _build_secondary_y_mapping(base_axes)

    # Record which x-axis each primary y-axis is paired with. Iterating in
    # sorted order makes the choice deterministic if a y-axis is paired
    # with more than one x-axis.
    x_for_y: dict[str, str] = {}
    for xaxis, yaxis in sorted(base_axes):
        x_for_y.setdefault(yaxis, xaxis)

    combined = go.Figure(layout=copy.deepcopy(base.layout))

    # Base traces keep their primary y-axes.
    for trace in base.data:
        combined.add_trace(copy.deepcopy(trace))

    # Secondary traces are copied and redirected to the new y-axes.
    for trace in secondary.data:
        trace_copy = copy.deepcopy(trace)
        original_yaxis = getattr(trace_copy, "yaxis", None) or "y"
        trace_copy.yaxis = y_mapping[original_yaxis]
        combined.add_trace(trace_copy)

    for primary_yaxis, secondary_yaxis in y_mapping.items():
        # Only the axis overlaying "y" carries a title; an explicit title
        # wins over the one taken from the secondary figure's layout.
        title = None
        if primary_yaxis == "y":
            if secondary_y_title is not None:
                title = secondary_y_title
            elif secondary.layout.yaxis and secondary.layout.yaxis.title:
                title = secondary.layout.yaxis.title.text

        axis_config = {
            "overlaying": primary_yaxis,
            "side": "right",
            # Anchor to the subplot's own x-axis. A "free" anchor without
            # an explicit position would put the axis at position 0 (the
            # left edge) instead of the subplot's right edge.
            "anchor": x_for_y[primary_yaxis],
        }
        if title is not None:
            axis_config["title"] = title

        # Trace axis id "y4" corresponds to layout property "yaxis4".
        layout_prop = "yaxis" if secondary_yaxis == "y" else f"yaxis{secondary_yaxis[1:]}"
        combined.update_layout(**{layout_prop: axis_config})

    if base.frames:
        combined.frames = _merge_secondary_y_frames(base, secondary, y_mapping)

    return combined
356
+
357
+
358
def _merge_secondary_y_frames(
    base: go.Figure,
    secondary: go.Figure,
    y_mapping: dict[str, str],
) -> list:
    """Merge animation frames for secondary y-axis combination.

    One merged frame is produced per base frame, keeping the base frame's
    name. Its data is the base frame's traces followed by copies of the
    secondary traces with their y-axis remapped through ``y_mapping``.
    An animated secondary contributes its frame of the same name (nothing
    if it has none); a static secondary contributes its top-level traces
    to every frame. Y-axes missing from ``y_mapping`` are left unchanged.

    Args:
        base: The base figure with animation frames.
        secondary: The secondary figure (may or may not have frames).
        y_mapping: Mapping from primary y-axis names to secondary y-axis names.

    Returns:
        List of merged frames with secondary traces assigned to secondary y-axes.
    """
    import plotly.graph_objects as go

    def _remapped(trace):
        # Copy first so the secondary figure's own traces stay untouched.
        trace_copy = copy.deepcopy(trace)
        original_yaxis = getattr(trace_copy, "yaxis", None) or "y"
        trace_copy.yaxis = y_mapping.get(original_yaxis, original_yaxis)
        return trace_copy

    # Loop-invariant: every merged frame addresses all combined traces.
    total_traces = len(base.data) + len(secondary.data)

    # Index the secondary frames by name once; ``setdefault`` keeps the
    # first frame when names repeat, matching a first-match search.
    secondary_is_animated = bool(secondary.frames)
    frames_by_name: dict = {}
    if secondary_is_animated:
        for frame in secondary.frames:
            frames_by_name.setdefault(frame.name, frame)

    merged_frames = []
    for base_frame in base.frames:
        frame_name = base_frame.name

        if secondary_is_animated:
            matching = frames_by_name.get(frame_name)
            source_traces = matching.data if matching is not None else ()
        else:
            # Static secondary: replicate its traces into this frame.
            source_traces = secondary.data

        merged_data = list(base_frame.data)
        merged_data.extend(_remapped(trace) for trace in source_traces)

        merged_frames.append(
            go.Frame(
                data=merged_data,
                name=frame_name,
                traces=list(range(total_traces)),
            )
        )

    return merged_frames
410
+
411
+
412
def update_traces(fig: go.Figure, selector: dict | None = None, **kwargs) -> go.Figure:
    """Apply trace property updates to the figure and to every animation frame.

    Plotly's `update_traces()` only updates the base figure, not animation frames.
    This helper visits the top-level traces and the traces stored in each
    frame, so the styling persists while the animation plays.

    Args:
        fig: A Plotly figure, optionally with animation frames.
        selector: Dict to match specific traces, e.g. ``{"name": "Germany"}``.
            A trace matches when every key's attribute equals the given
            value. If None, updates all traces.
        **kwargs: Trace properties to update, e.g. ``line_width=4``, ``line_dash="dot"``.

    Returns:
        The modified figure (same object, mutated in place).

    Example:
        >>> import plotly.express as px
        >>> from xarray_plotly import update_traces
        >>>
        >>> df = px.data.gapminder()
        >>> fig = px.line(df, x="year", y="gdpPercap", color="country", animation_frame="continent")
        >>>
        >>> # Update all traces
        >>> update_traces(fig, line_width=3)
        >>>
        >>> # Update specific trace by name
        >>> update_traces(fig, selector={"name": "Germany"}, line_width=5, line_dash="dot")
    """
    for trace in _iter_all_traces(fig):
        if selector is not None and not all(
            getattr(trace, key, None) == wanted for key, wanted in selector.items()
        ):
            # At least one selector criterion does not match this trace.
            continue
        trace.update(**kwargs)

    return fig
xarray_plotly/plotting.py CHANGED
@@ -4,6 +4,7 @@ Plotly Express plotting functions for DataArray objects.
4
4
 
5
5
  from __future__ import annotations
6
6
 
7
+ import warnings
7
8
  from typing import TYPE_CHECKING, Any
8
9
 
9
10
  import numpy as np
@@ -18,6 +19,9 @@ from xarray_plotly.common import (
18
19
  get_value_col,
19
20
  to_dataframe,
20
21
  )
22
+ from xarray_plotly.figures import (
23
+ _iter_all_traces,
24
+ )
21
25
 
22
26
  if TYPE_CHECKING:
23
27
  import plotly.graph_objects as go
@@ -167,6 +171,167 @@ def bar(
167
171
  )
168
172
 
169
173
 
174
def _classify_trace_sign(y_values: np.ndarray) -> str:
    """Classify a trace by the sign of its significant values.

    Non-finite entries (NaN, +/-inf) and entries whose magnitude is at
    most 1e-9 are ignored before the signs are inspected.

    Args:
        y_values: Array-like of numeric y values.

    Returns:
        ``"positive"`` if only positive values remain, ``"negative"`` if
        only negative values remain, ``"mixed"`` if both signs occur, and
        ``"zero"`` if nothing significant remains (including empty input).
    """
    y_arr = np.asarray(y_values)
    significant = y_arr[np.isfinite(y_arr) & (np.abs(y_arr) > 1e-9)]

    # ``np.any`` of an empty selection is False, so empty or all-zero
    # input falls through to "zero" without a separate length check.
    has_pos = bool(np.any(significant > 0))
    has_neg = bool(np.any(significant < 0))

    if has_pos and has_neg:
        return "mixed"
    if has_neg:
        return "negative"
    if has_pos:
        return "positive"
    return "zero"
189
+
190
+
191
def _style_traces_as_bars(fig: go.Figure) -> None:
    """Restyle area traces in place so the chart reads as a stacked bar chart.

    Traces are grouped by name across the top-level data and all animation
    frames. For each name the signs of its significant values (finite and
    larger than 1e-9 in magnitude) decide the styling:

    * only positive or only negative: filled, outline-free, stacked in a
      ``"positive"`` or ``"negative"`` stackgroup;
    * both signs: unstacked, unfilled dashed line, and a ``UserWarning``
      listing the affected trace names is emitted;
    * nothing significant: unstacked, unfilled, zero-width line.
    """
    traces = list(_iter_all_traces(fig))

    # Pass 1: per trace name, record which signs occur in any occurrence.
    seen: dict[str, dict[str, bool]] = {}
    for tr in traces:
        flags = seen.setdefault(tr.name, {"has_pos": False, "has_neg": False})
        if tr.y is None or len(tr.y) == 0:
            continue
        values = np.asarray(tr.y)
        significant = values[np.isfinite(values) & (np.abs(values) > 1e-9)]
        if len(significant) == 0:
            continue
        if np.any(significant > 0):
            flags["has_pos"] = True
        if np.any(significant < 0):
            flags["has_neg"] = True

    # Pass 2: turn the sign flags into one class label per name.
    class_map: dict[str, str] = {}
    for name, flags in seen.items():
        pos, neg = flags["has_pos"], flags["has_neg"]
        if pos and neg:
            class_map[name] = "mixed"
        elif neg:
            class_map[name] = "negative"
        elif pos:
            class_map[name] = "positive"
        else:
            class_map[name] = "zero"
    mixed_traces = [name for name, label in class_map.items() if label == "mixed"]

    if mixed_traces:
        warnings.warn(
            f"fast_bar: traces {mixed_traces} have mixed positive/negative values "
            "and cannot be stacked. They are shown as dashed lines. "
            "Consider using bar() for proper stacking of mixed data.",
            UserWarning,
            stacklevel=3,
        )

    # Pass 3: apply the styling for each trace's class.
    for tr in traces:
        color = tr.line.color
        label = class_map.get(tr.name, "positive")

        if label == "mixed":
            # Cannot be stacked: draw as a dashed step line instead.
            tr.stackgroup = None
            tr.fill = None
            tr.line = {"width": 2, "color": color, "shape": "hv", "dash": "dash"}
        elif label in ("positive", "negative"):
            # The class label doubles as the stackgroup name.
            tr.stackgroup = label
            tr.fillcolor = color
            tr.line = {"width": 0, "color": color, "shape": "hv"}
        else:
            tr.stackgroup = None
            tr.fill = None
            tr.line = {"width": 0, "color": color, "shape": "hv"}
256
+
257
+
258
def fast_bar(
    darray: DataArray,
    *,
    x: SlotValue = auto,
    color: SlotValue = auto,
    facet_col: SlotValue = auto,
    facet_row: SlotValue = auto,
    animation_frame: SlotValue = auto,
    **px_kwargs: Any,
) -> go.Figure:
    """
    Create a bar-like chart using stacked areas for better performance.

    The figure is built with `px.area` using stepped ("hv") lines and is
    then restyled by `_style_traces_as_bars` so it resembles a bar chart.
    It is intended as a faster alternative to `bar()` for large datasets.

    The y-axis shows DataArray values. Dimensions fill slots in order:
    x -> color -> facet_col -> facet_row -> animation_frame

    Purely positive traces stack upward and purely negative traces stack
    downward. Traces with mixed signs are shown as dashed lines without
    stacking.

    Parameters
    ----------
    darray
        The DataArray to plot.
    x
        Dimension for x-axis. Default: first dimension.
    color
        Dimension for color/stacking. Default: second dimension.
    facet_col
        Dimension for subplot columns. Default: third dimension.
    facet_row
        Dimension for subplot rows. Default: fourth dimension.
    animation_frame
        Dimension for animation. Default: fifth dimension.
    **px_kwargs
        Additional arguments passed to `plotly.express.area()`.
        A ``labels`` entry is merged over the automatically built labels.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    dims = list(darray.dims)
    slots = assign_slots(
        dims,
        "fast_bar",
        x=x,
        color=color,
        facet_col=facet_col,
        facet_row=facet_row,
        animation_frame=animation_frame,
    )

    frame = to_dataframe(darray)
    value_col = get_value_col(darray)

    # Caller-supplied labels take precedence over the generated ones.
    auto_labels = build_labels(darray, slots, value_col)
    user_labels = px_kwargs.pop("labels", {})
    labels = {**auto_labels, **user_labels}

    # Unassigned slots are passed to px.area as None.
    encodings = {
        slot: slots.get(slot)
        for slot in ("x", "color", "facet_col", "facet_row", "animation_frame")
    }
    fig = px.area(
        frame,
        y=value_col,
        line_shape="hv",
        labels=labels,
        **encodings,
        **px_kwargs,
    )

    _style_traces_as_bars(fig)

    return fig
333
+
334
+
170
335
  def area(
171
336
  darray: DataArray,
172
337
  *,
@@ -408,6 +573,14 @@ def imshow(
408
573
  Both x and y are dimensions. Dimensions fill slots in order:
409
574
  y (rows) -> x (columns) -> facet_col -> animation_frame
410
575
 
576
+ .. note::
577
+ **Difference from plotly.express.imshow**: By default, color bounds
578
+ (zmin/zmax) are computed from the **entire dataset**, ensuring
579
+ consistent coloring across animation frames and facets. In contrast,
580
+ ``px.imshow`` auto-scales each frame independently, which can make
581
+ animations visually confusing. Set ``zmin`` and ``zmax`` explicitly
582
+ to override this behavior.
583
+
411
584
  Parameters
412
585
  ----------
413
586
  darray
@@ -422,7 +595,7 @@ def imshow(
422
595
  Dimension for animation. Default: fourth dimension.
423
596
  robust
424
597
  If True, compute color bounds using 2nd and 98th percentiles
425
- for robustness against outliers. Default: False.
598
+ for robustness against outliers. Default: False (uses min/max).
426
599
  **px_kwargs
427
600
  Additional arguments passed to `plotly.express.imshow()`.
428
601
  Use `zmin` and `zmax` to manually set color scale bounds.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xarray_plotly
3
- Version: 0.0.6
3
+ Version: 0.0.8
4
4
  Summary: Interactive Plotly Express plotting accessor for xarray
5
5
  Author: Felix
6
6
  License: MIT
@@ -28,7 +28,7 @@ Provides-Extra: dev
28
28
  Requires-Dist: pytest==9.0.2; extra == "dev"
29
29
  Requires-Dist: pytest-cov==7.0.0; extra == "dev"
30
30
  Requires-Dist: mypy==1.19.1; extra == "dev"
31
- Requires-Dist: ruff==0.14.11; extra == "dev"
31
+ Requires-Dist: ruff==0.14.13; extra == "dev"
32
32
  Requires-Dist: pre-commit==4.5.1; extra == "dev"
33
33
  Requires-Dist: nbstripout==0.8.2; extra == "dev"
34
34
  Provides-Extra: docs
@@ -0,0 +1,12 @@
1
+ xarray_plotly/__init__.py,sha256=vAM2TCeLnSkSjwCDETlBI_1SukhVeth7HFMaFIWl_ps,3384
2
+ xarray_plotly/accessor.py,sha256=eP4_GPBCN-c32M8s1LJv4Zn4PvHHO9HsP5oPQMDDd6I,23150
3
+ xarray_plotly/common.py,sha256=YTiaPLJ0Gh20mHV8-72J8DjWk-XaSYaT_ZiqXp6cecU,7192
4
+ xarray_plotly/config.py,sha256=gS6IqWdx82PQs1yZzl_diGHlnmVTladuWZzk1lYgIsQ,6569
5
+ xarray_plotly/figures.py,sha256=5SxTTPCgaEMRC5Bz7Ew1leJXDNMA1ZchbY_dO82SHY4,15574
6
+ xarray_plotly/plotting.py,sha256=Fl_UklCN8AYtPj0t3m8zXwcz6MsFLdmQrKZC68IH1OM,20668
7
+ xarray_plotly/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
+ xarray_plotly-0.0.8.dist-info/licenses/LICENSE,sha256=AvVEfNqbhIm9jHvt0acJNjW1JUKa2a70Zb5rJdEXCJI,1064
9
+ xarray_plotly-0.0.8.dist-info/METADATA,sha256=ezKDOXZ7CklqOB5KDBh32Lm1nri63OF_OoeCKhuIt9c,3415
10
+ xarray_plotly-0.0.8.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
11
+ xarray_plotly-0.0.8.dist-info/top_level.txt,sha256=GtMkvuZvLAYTjYXtwoNUa0ag42CJARZJK1CZemYD7pg,14
12
+ xarray_plotly-0.0.8.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.9.0)
2
+ Generator: setuptools (80.10.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,11 +0,0 @@
1
- xarray_plotly/__init__.py,sha256=o8w-prP-J-XWr6Bx5vqqpQEz1maAc0Sycq2q7ll5nuI,3294
2
- xarray_plotly/accessor.py,sha256=2vVvAv6K92wUyhQjAge3v85f2gk5mkrRC3VjBWJTGeo,20529
3
- xarray_plotly/common.py,sha256=YTiaPLJ0Gh20mHV8-72J8DjWk-XaSYaT_ZiqXp6cecU,7192
4
- xarray_plotly/config.py,sha256=-Sp6TNx-8Zk6x0WmJ5j_KGHPRY639ju3_RSemrlmuXY,6492
5
- xarray_plotly/plotting.py,sha256=upUt6620pUKUdIYAvDrxFuxBtQm79ZM11IleSajPScU,14799
6
- xarray_plotly/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
- xarray_plotly-0.0.6.dist-info/licenses/LICENSE,sha256=AvVEfNqbhIm9jHvt0acJNjW1JUKa2a70Zb5rJdEXCJI,1064
8
- xarray_plotly-0.0.6.dist-info/METADATA,sha256=YiTD_OutpKfk-VCxi_eZKEAR8vK2sWoRgaa2DA_Xvr0,3415
9
- xarray_plotly-0.0.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
10
- xarray_plotly-0.0.6.dist-info/top_level.txt,sha256=GtMkvuZvLAYTjYXtwoNUa0ag42CJARZJK1CZemYD7pg,14
11
- xarray_plotly-0.0.6.dist-info/RECORD,,