xarray-plotly 0.0.7__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xarray_plotly/accessor.py CHANGED
@@ -34,7 +34,7 @@ class DataArrayPlotlyAccessor:
34
34
  ```
35
35
  """
36
36
 
37
- __all__: ClassVar = ["line", "bar", "area", "scatter", "box", "imshow", "pie"]
37
+ __all__: ClassVar = ["line", "bar", "fast_bar", "area", "scatter", "box", "imshow", "pie"]
38
38
 
39
39
  def __init__(self, darray: DataArray) -> None:
40
40
  self._da = darray
@@ -160,6 +160,41 @@ class DataArrayPlotlyAccessor:
160
160
  **px_kwargs,
161
161
  )
162
162
 
163
+ def fast_bar(
164
+ self,
165
+ *,
166
+ x: SlotValue = auto,
167
+ color: SlotValue = auto,
168
+ facet_col: SlotValue = auto,
169
+ facet_row: SlotValue = auto,
170
+ animation_frame: SlotValue = auto,
171
+ **px_kwargs: Any,
172
+ ) -> go.Figure:
173
+ """Create a bar-like chart using stacked areas for better performance.
174
+
175
+ Slot order: x -> color -> facet_col -> facet_row -> animation_frame
176
+
177
+ Args:
178
+ x: Dimension for x-axis. Default: first dimension.
179
+ color: Dimension for color/stacking. Default: second dimension.
180
+ facet_col: Dimension for subplot columns. Default: third dimension.
181
+ facet_row: Dimension for subplot rows. Default: fourth dimension.
182
+ animation_frame: Dimension for animation. Default: fifth dimension.
183
+ **px_kwargs: Additional arguments passed to `plotly.express.area()`.
184
+
185
+ Returns:
186
+ Interactive Plotly Figure.
187
+ """
188
+ return plotting.fast_bar(
189
+ self._da,
190
+ x=x,
191
+ color=color,
192
+ facet_col=facet_col,
193
+ facet_row=facet_row,
194
+ animation_frame=animation_frame,
195
+ **px_kwargs,
196
+ )
197
+
163
198
  def scatter(
164
199
  self,
165
200
  *,
@@ -349,7 +384,7 @@ class DatasetPlotlyAccessor:
349
384
  ```
350
385
  """
351
386
 
352
- __all__: ClassVar = ["line", "bar", "area", "scatter", "box", "pie"]
387
+ __all__: ClassVar = ["line", "bar", "fast_bar", "area", "scatter", "box", "pie"]
353
388
 
354
389
  def __init__(self, dataset: Dataset) -> None:
355
390
  self._ds = dataset
@@ -501,6 +536,42 @@ class DatasetPlotlyAccessor:
501
536
  **px_kwargs,
502
537
  )
503
538
 
539
+ def fast_bar(
540
+ self,
541
+ var: str | None = None,
542
+ *,
543
+ x: SlotValue = auto,
544
+ color: SlotValue = auto,
545
+ facet_col: SlotValue = auto,
546
+ facet_row: SlotValue = auto,
547
+ animation_frame: SlotValue = auto,
548
+ **px_kwargs: Any,
549
+ ) -> go.Figure:
550
+ """Create a bar-like chart using stacked areas for better performance.
551
+
552
+ Args:
553
+ var: Variable to plot. If None, plots all variables with "variable" dimension.
554
+ x: Dimension for x-axis.
555
+ color: Dimension for color/stacking.
556
+ facet_col: Dimension for subplot columns.
557
+ facet_row: Dimension for subplot rows.
558
+ animation_frame: Dimension for animation.
559
+ **px_kwargs: Additional arguments passed to `plotly.express.area()`.
560
+
561
+ Returns:
562
+ Interactive Plotly Figure.
563
+ """
564
+ da = self._get_dataarray(var)
565
+ return plotting.fast_bar(
566
+ da,
567
+ x=x,
568
+ color=color,
569
+ facet_col=facet_col,
570
+ facet_row=facet_row,
571
+ animation_frame=animation_frame,
572
+ **px_kwargs,
573
+ )
574
+
504
575
  def scatter(
505
576
  self,
506
577
  var: str | None = None,
xarray_plotly/config.py CHANGED
@@ -26,6 +26,7 @@ DEFAULT_SLOT_ORDERS: dict[str, tuple[str, ...]] = {
26
26
  "animation_frame",
27
27
  ),
28
28
  "bar": ("x", "color", "pattern_shape", "facet_col", "facet_row", "animation_frame"),
29
+ "fast_bar": ("x", "color", "facet_col", "facet_row", "animation_frame"),
29
30
  "area": (
30
31
  "x",
31
32
  "color",
xarray_plotly/figures.py CHANGED
@@ -8,9 +8,28 @@ import copy
8
8
  from typing import TYPE_CHECKING
9
9
 
10
10
  if TYPE_CHECKING:
11
+ from collections.abc import Iterator
12
+
11
13
  import plotly.graph_objects as go
12
14
 
13
15
 
16
+ def _iter_all_traces(fig: go.Figure) -> Iterator:
17
+ """Iterate over all traces in a figure, including animation frames.
18
+
19
+ Yields traces from fig.data first, then from each frame in fig.frames.
20
+ Useful for applying styling to all traces including those in animations.
21
+
22
+ Args:
23
+ fig: Plotly Figure.
24
+
25
+ Yields:
26
+ Each trace object from the figure.
27
+ """
28
+ yield from fig.data
29
+ for frame in fig.frames or []:
30
+ yield from frame.data
31
+
32
+
14
33
  def _get_subplot_axes(fig: go.Figure) -> set[tuple[str, str]]:
15
34
  """Extract (xaxis, yaxis) pairs from figure traces.
16
35
 
@@ -418,15 +437,12 @@ def update_traces(fig: go.Figure, selector: dict | None = None, **kwargs) -> go.
418
437
  >>> # Update specific trace by name
419
438
  >>> update_traces(fig, selector={"name": "Germany"}, line_width=5, line_dash="dot")
420
439
  """
421
- fig.update_traces(selector=selector, **kwargs)
422
-
423
- for frame in fig.frames:
424
- for trace in frame.data:
425
- if selector is None:
440
+ for trace in _iter_all_traces(fig):
441
+ if selector is None:
442
+ trace.update(**kwargs)
443
+ else:
444
+ # Check if trace matches all selector criteria
445
+ if all(getattr(trace, k, None) == v for k, v in selector.items()):
426
446
  trace.update(**kwargs)
427
- else:
428
- # Check if trace matches all selector criteria
429
- if all(getattr(trace, k, None) == v for k, v in selector.items()):
430
- trace.update(**kwargs)
431
447
 
432
448
  return fig
xarray_plotly/plotting.py CHANGED
@@ -4,6 +4,7 @@ Plotly Express plotting functions for DataArray objects.
4
4
 
5
5
  from __future__ import annotations
6
6
 
7
+ import warnings
7
8
  from typing import TYPE_CHECKING, Any
8
9
 
9
10
  import numpy as np
@@ -18,6 +19,9 @@ from xarray_plotly.common import (
18
19
  get_value_col,
19
20
  to_dataframe,
20
21
  )
22
+ from xarray_plotly.figures import (
23
+ _iter_all_traces,
24
+ )
21
25
 
22
26
  if TYPE_CHECKING:
23
27
  import plotly.graph_objects as go
@@ -167,6 +171,167 @@ def bar(
167
171
  )
168
172
 
169
173
 
174
+ def _classify_trace_sign(y_values: np.ndarray) -> str:
175
+ """Classify a trace as 'positive', 'negative', or 'mixed' based on its values."""
176
+ y_arr = np.asarray(y_values)
177
+ y_clean = y_arr[np.isfinite(y_arr) & (np.abs(y_arr) > 1e-9)]
178
+ if len(y_clean) == 0:
179
+ return "zero"
180
+ has_pos = bool(np.any(y_clean > 0))
181
+ has_neg = bool(np.any(y_clean < 0))
182
+ if has_pos and has_neg:
183
+ return "mixed"
184
+ elif has_neg:
185
+ return "negative"
186
+ elif has_pos:
187
+ return "positive"
188
+ return "zero"
189
+
190
+
191
+ def _style_traces_as_bars(fig: go.Figure) -> None:
192
+ """Style area chart traces to look like bar charts with proper pos/neg stacking.
193
+
194
+ Classifies each trace (by name) across all data and animation frames,
195
+ then assigns stackgroups: positive traces stack upward, negative stack downward.
196
+ """
197
+ # Collect all traces (main + animation frames)
198
+ all_traces = list(_iter_all_traces(fig))
199
+
200
+ # Classify each trace name by aggregating sign info across all occurrences
201
+ sign_flags: dict[str, dict[str, bool]] = {}
202
+ for trace in all_traces:
203
+ if trace.name not in sign_flags:
204
+ sign_flags[trace.name] = {"has_pos": False, "has_neg": False}
205
+ if trace.y is not None and len(trace.y) > 0:
206
+ y_arr = np.asarray(trace.y)
207
+ y_clean = y_arr[np.isfinite(y_arr) & (np.abs(y_arr) > 1e-9)]
208
+ if len(y_clean) > 0:
209
+ if np.any(y_clean > 0):
210
+ sign_flags[trace.name]["has_pos"] = True
211
+ if np.any(y_clean < 0):
212
+ sign_flags[trace.name]["has_neg"] = True
213
+
214
+ # Build classification map
215
+ class_map: dict[str, str] = {}
216
+ mixed_traces: list[str] = []
217
+ for name, flags in sign_flags.items():
218
+ if flags["has_pos"] and flags["has_neg"]:
219
+ class_map[name] = "mixed"
220
+ mixed_traces.append(name)
221
+ elif flags["has_neg"]:
222
+ class_map[name] = "negative"
223
+ elif flags["has_pos"]:
224
+ class_map[name] = "positive"
225
+ else:
226
+ class_map[name] = "zero"
227
+
228
+ # Warn about mixed traces
229
+ if mixed_traces:
230
+ warnings.warn(
231
+ f"fast_bar: traces {mixed_traces} have mixed positive/negative values "
232
+ "and cannot be stacked. They are shown as dashed lines. "
233
+ "Consider using bar() for proper stacking of mixed data.",
234
+ UserWarning,
235
+ stacklevel=3,
236
+ )
237
+
238
+ # Apply styling to all traces
239
+ for trace in all_traces:
240
+ color = trace.line.color
241
+ cls = class_map.get(trace.name, "positive")
242
+
243
+ if cls in ("positive", "negative"):
244
+ trace.stackgroup = cls
245
+ trace.fillcolor = color
246
+ trace.line = {"width": 0, "color": color, "shape": "hv"}
247
+ elif cls == "mixed":
248
+ # Mixed: no stacking, show as dashed line
249
+ trace.stackgroup = None
250
+ trace.fill = None
251
+ trace.line = {"width": 2, "color": color, "shape": "hv", "dash": "dash"}
252
+ else: # zero
253
+ trace.stackgroup = None
254
+ trace.fill = None
255
+ trace.line = {"width": 0, "color": color, "shape": "hv"}
256
+
257
+
258
+ def fast_bar(
259
+ darray: DataArray,
260
+ *,
261
+ x: SlotValue = auto,
262
+ color: SlotValue = auto,
263
+ facet_col: SlotValue = auto,
264
+ facet_row: SlotValue = auto,
265
+ animation_frame: SlotValue = auto,
266
+ **px_kwargs: Any,
267
+ ) -> go.Figure:
268
+ """
269
+ Create a bar-like chart using stacked areas for better performance.
270
+
271
+ Uses `px.area` with stepped lines and no outline to create a bar-like
272
+ appearance. Renders faster than `bar()` for large datasets because it
273
+ uses a single polygon per trace instead of individual rectangles.
274
+
275
+ The y-axis shows DataArray values. Dimensions fill slots in order:
276
+ x -> color -> facet_col -> facet_row -> animation_frame
277
+
278
+ Traces are classified by their values: purely positive traces stack upward,
279
+ purely negative traces stack downward. Traces with mixed signs are shown
280
+ as dashed lines without stacking.
281
+
282
+ Parameters
283
+ ----------
284
+ darray
285
+ The DataArray to plot.
286
+ x
287
+ Dimension for x-axis. Default: first dimension.
288
+ color
289
+ Dimension for color/stacking. Default: second dimension.
290
+ facet_col
291
+ Dimension for subplot columns. Default: third dimension.
292
+ facet_row
293
+ Dimension for subplot rows. Default: fourth dimension.
294
+ animation_frame
295
+ Dimension for animation. Default: fifth dimension.
296
+ **px_kwargs
297
+ Additional arguments passed to `plotly.express.area()`.
298
+
299
+ Returns
300
+ -------
301
+ plotly.graph_objects.Figure
302
+ """
303
+ slots = assign_slots(
304
+ list(darray.dims),
305
+ "fast_bar",
306
+ x=x,
307
+ color=color,
308
+ facet_col=facet_col,
309
+ facet_row=facet_row,
310
+ animation_frame=animation_frame,
311
+ )
312
+
313
+ df = to_dataframe(darray)
314
+ value_col = get_value_col(darray)
315
+ labels = {**build_labels(darray, slots, value_col), **px_kwargs.pop("labels", {})}
316
+
317
+ fig = px.area(
318
+ df,
319
+ x=slots.get("x"),
320
+ y=value_col,
321
+ color=slots.get("color"),
322
+ facet_col=slots.get("facet_col"),
323
+ facet_row=slots.get("facet_row"),
324
+ animation_frame=slots.get("animation_frame"),
325
+ line_shape="hv",
326
+ labels=labels,
327
+ **px_kwargs,
328
+ )
329
+
330
+ _style_traces_as_bars(fig)
331
+
332
+ return fig
333
+
334
+
170
335
  def area(
171
336
  darray: DataArray,
172
337
  *,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xarray_plotly
3
- Version: 0.0.7
3
+ Version: 0.0.8
4
4
  Summary: Interactive Plotly Express plotting accessor for xarray
5
5
  Author: Felix
6
6
  License: MIT
@@ -0,0 +1,12 @@
1
+ xarray_plotly/__init__.py,sha256=vAM2TCeLnSkSjwCDETlBI_1SukhVeth7HFMaFIWl_ps,3384
2
+ xarray_plotly/accessor.py,sha256=eP4_GPBCN-c32M8s1LJv4Zn4PvHHO9HsP5oPQMDDd6I,23150
3
+ xarray_plotly/common.py,sha256=YTiaPLJ0Gh20mHV8-72J8DjWk-XaSYaT_ZiqXp6cecU,7192
4
+ xarray_plotly/config.py,sha256=gS6IqWdx82PQs1yZzl_diGHlnmVTladuWZzk1lYgIsQ,6569
5
+ xarray_plotly/figures.py,sha256=5SxTTPCgaEMRC5Bz7Ew1leJXDNMA1ZchbY_dO82SHY4,15574
6
+ xarray_plotly/plotting.py,sha256=Fl_UklCN8AYtPj0t3m8zXwcz6MsFLdmQrKZC68IH1OM,20668
7
+ xarray_plotly/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
+ xarray_plotly-0.0.8.dist-info/licenses/LICENSE,sha256=AvVEfNqbhIm9jHvt0acJNjW1JUKa2a70Zb5rJdEXCJI,1064
9
+ xarray_plotly-0.0.8.dist-info/METADATA,sha256=ezKDOXZ7CklqOB5KDBh32Lm1nri63OF_OoeCKhuIt9c,3415
10
+ xarray_plotly-0.0.8.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
11
+ xarray_plotly-0.0.8.dist-info/top_level.txt,sha256=GtMkvuZvLAYTjYXtwoNUa0ag42CJARZJK1CZemYD7pg,14
12
+ xarray_plotly-0.0.8.dist-info/RECORD,,
@@ -1,12 +0,0 @@
1
- xarray_plotly/__init__.py,sha256=vAM2TCeLnSkSjwCDETlBI_1SukhVeth7HFMaFIWl_ps,3384
2
- xarray_plotly/accessor.py,sha256=rbVqzi29nvQUM8b5jE5ivw-CPoUxQjrl5Fnh-LQvtWo,20758
3
- xarray_plotly/common.py,sha256=YTiaPLJ0Gh20mHV8-72J8DjWk-XaSYaT_ZiqXp6cecU,7192
4
- xarray_plotly/config.py,sha256=-Sp6TNx-8Zk6x0WmJ5j_KGHPRY639ju3_RSemrlmuXY,6492
5
- xarray_plotly/figures.py,sha256=fE6aLqcNuCZ5GA-3x4xTl1Bw9Fg9fGCoyjKdE7jM-eo,15161
6
- xarray_plotly/plotting.py,sha256=i06kz7HSloH_SfkmkjRnafz1auADqFi7X5N3tZKZ1_E,15239
7
- xarray_plotly/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
- xarray_plotly-0.0.7.dist-info/licenses/LICENSE,sha256=AvVEfNqbhIm9jHvt0acJNjW1JUKa2a70Zb5rJdEXCJI,1064
9
- xarray_plotly-0.0.7.dist-info/METADATA,sha256=6YUJZEytGX_sjeAXz-jeW9Wpuct2yicNBeo8bauGemc,3415
10
- xarray_plotly-0.0.7.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
11
- xarray_plotly-0.0.7.dist-info/top_level.txt,sha256=GtMkvuZvLAYTjYXtwoNUa0ag42CJARZJK1CZemYD7pg,14
12
- xarray_plotly-0.0.7.dist-info/RECORD,,