plopp 25.7.1__py3-none-any.whl → 25.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,30 +15,61 @@ from matplotlib.colors import Colormap, LinearSegmentedColormap, LogNorm, Normal
15
15
  from ..backends.matplotlib.utils import fig_to_bytes
16
16
  from ..core.limits import find_limits, fix_empty_range
17
17
  from ..core.utils import maybe_variable_to_number, merge_masks
18
+ from ..utils import parse_mutually_exclusive
18
19
 
19
20
 
20
- def _get_cmap(name: str, nan_color: str | None = None) -> Colormap:
21
+ def _shift_color(color: float, delta: float) -> float:
22
+ """
23
+ Shift a color (number from 0 to 1) by delta. If the result is out of bounds,
24
+ the color is shifted in the opposite direction.
25
+ """
26
+ shifted = color + delta
27
+ if shifted > 1.0 or shifted < 0.0:
28
+ return color - delta
29
+ return shifted
30
+
31
+
32
+ def _get_cmap(colormap: str | Colormap, nan_color: str | None = None) -> Colormap:
21
33
  """
22
34
  Get a colormap object from a colormap name.
35
+ We also set the 'over', 'under' and 'bad' colors. The 'bad' color is set to
36
+ ``nan_color`` if it is not None. The 'over' and 'under' colors are set to be
37
+ slightly lighter or darker than the first and last colors in the colormap.
23
38
 
24
39
  Parameters
25
40
  ----------
26
- name:
41
+ colormap:
27
42
  Name of the colormap. If the name is just a single html color, this will
28
- create a colormap with that single color.
43
+ create a colormap with that single color. If ``cmap`` is already a Colormap,
44
+ it will be used as is.
29
45
  nan_color:
30
46
  The color to use for NAN values.
31
47
  """
32
48
 
49
+ if isinstance(colormap, Colormap):
50
+ return colormap
51
+
33
52
  try:
34
53
  if hasattr(mpl, 'colormaps'):
35
- cmap = copy(mpl.colormaps[name])
54
+ cmap = copy(mpl.colormaps[colormap])
36
55
  else:
37
- cmap = mpl.cm.get_cmap(name)
56
+ cmap = mpl.cm.get_cmap(colormap)
38
57
  except (KeyError, ValueError):
39
- cmap = LinearSegmentedColormap.from_list('tmp', [name, name])
40
- # TODO: we need to set under and over values for the cmap with
41
- # `cmap.set_under` and `cmap.set_over`. Ideally these should come from a config?
58
+ cmap = LinearSegmentedColormap.from_list('tmp', [colormap, colormap])
59
+
60
+ # Add under and over values to the cmap
61
+ delta = 0.15
62
+ cmap = cmap.copy()
63
+ over = cmap.get_over()
64
+ under = cmap.get_under()
65
+ # Note that we only shift the first 3 RGB values, leaving alpha unchanged.
66
+ cmap.set_over(
67
+ [_shift_color(c, delta * (-1 + 2 * (np.mean(over) > 0.5))) for c in over[:3]]
68
+ )
69
+ cmap.set_under(
70
+ [_shift_color(c, delta * (-1 + 2 * (np.mean(under) > 0.5))) for c in under[:3]]
71
+ )
72
+
42
73
  if nan_color is not None:
43
74
  cmap.set_bad(color=nan_color)
44
75
  return cmap
@@ -72,46 +103,75 @@ class ColorMapper:
72
103
  a colormap with that single color will be used.
73
104
  mask_cmap:
74
105
  The name of the colormap for masked data.
75
- norm:
76
- The colorscale normalization.
77
- vmin:
106
+ cmin:
78
107
  The minimum value for the colorscale range. If a number (without a unit) is
79
108
  supplied, it is assumed that the unit is the same as the data unit.
80
- vmax:
109
+ cmax:
81
110
  The maximum value for the colorscale range. If a number (without a unit) is
82
111
  supplied, it is assumed that the unit is the same as the data unit.
112
+ logc:
113
+ If ``True``, use a logarithmic colorscale.
83
114
  nan_color:
84
115
  The color used for representing NAN values.
85
116
  figsize:
86
117
  The size of the figure next to which the colorbar will be displayed.
118
+ norm:
119
+ The colorscale normalization. This is an old parameter name.
120
+ Prefer using ``logc`` instead.
121
+ vmin:
122
+ The minimum value for the colorscale range. If a number (without a unit) is
123
+ supplied, it is assumed that the unit is the same as the data unit.
124
+ This is an old parameter name. Prefer using ``cmin`` instead.
125
+ vmax:
126
+ The maximum value for the colorscale range. If a number (without a unit) is
127
+ supplied, it is assumed that the unit is the same as the data unit.
128
+ This is an old parameter name. Prefer using ``cmax`` instead.
87
129
  """
88
130
 
89
131
  def __init__(
90
132
  self,
91
133
  canvas: Any | None = None,
92
134
  cbar: bool = True,
93
- cmap: str = 'viridis',
94
- mask_cmap: str = 'gray',
95
- norm: Literal['linear', 'log'] = 'linear',
96
- vmin: sc.Variable | float | None = None,
97
- vmax: sc.Variable | float | None = None,
135
+ cmap: str | Colormap = 'viridis',
136
+ mask_cmap: str | Colormap = 'gray',
137
+ mask_color: str | None = None,
138
+ cmin: sc.Variable | float | None = None,
139
+ cmax: sc.Variable | float | None = None,
140
+ logc: bool | None = None,
141
+ clabel: str | None = None,
98
142
  nan_color: str | None = None,
99
143
  figsize: tuple[float, float] | None = None,
144
+ norm: Literal['linear', 'log', None] = None,
145
+ vmin: sc.Variable | float | None = None,
146
+ vmax: sc.Variable | float | None = None,
100
147
  ):
148
+ cmin = parse_mutually_exclusive(vmin=vmin, cmin=cmin)
149
+ cmax = parse_mutually_exclusive(vmax=vmax, cmax=cmax)
150
+ logc = parse_mutually_exclusive(norm=norm, logc=logc)
151
+
101
152
  self._canvas = canvas
102
153
  self.cax = self._canvas.cax if hasattr(self._canvas, 'cax') else None
103
154
  self.cmap = _get_cmap(cmap, nan_color=nan_color)
104
- self.mask_cmap = _get_cmap(mask_cmap, nan_color=nan_color)
105
- self.user_vmin = vmin
106
- self.user_vmax = vmax
107
- self._vmin = np.inf
108
- self._vmax = -np.inf
109
- self.norm = norm
155
+ self.mask_cmap = _get_cmap(
156
+ mask_cmap if mask_color is None else mask_color, nan_color=nan_color
157
+ )
158
+
159
+ # Inside the autoscale, we need to distinguish between a min value that was set
160
+ # by the user and one that was found by looping over all the data.
161
+ # So we basically always need to keep a backup of the user-set value.
162
+ self._cmin = {"data": np.inf}
163
+ self._cmax = {"data": -np.inf}
164
+ if cmin is not None:
165
+ self._cmin["user"] = cmin
166
+ if cmax is not None:
167
+ self._cmax["user"] = cmax
168
+ self._clabel = clabel
169
+ self._logc = False if logc is None else logc
110
170
  self.set_colors_on_update = True
111
171
 
112
- # Note that we need to set vmin/vmax for the LogNorm, if not an error is
172
+ # Note that we need to set cmin/cmax for the LogNorm, if not an error is
113
173
  # raised when making the colorbar before any call to update is made.
114
- self.normalizer = _get_normalizer(self.norm)
174
+ self.normalizer = _get_normalizer('log' if self._logc else 'linear')
115
175
  self.colorbar = None
116
176
  self._unit = None
117
177
  self.empty = True
@@ -127,6 +187,8 @@ class ColorMapper:
127
187
  self.cax = fig.add_axes([0.05, 0.02, 0.2, 0.98])
128
188
  self.colorbar = ColorbarBase(self.cax, cmap=self.cmap, norm=self.normalizer)
129
189
  self.cax.yaxis.set_label_coords(-0.9, 0.5)
190
+ if self._clabel is not None:
191
+ self.cax.set_ylabel(self._clabel)
130
192
 
131
193
  def add_artist(self, key: str, artist: Any):
132
194
  self.artists[key] = artist
@@ -171,41 +233,47 @@ class ColorMapper:
171
233
  Re-compute the global min and max range of values by iterating over all the
172
234
  artists and adjust the limits.
173
235
  """
236
+ if ("user" in self._cmin) and ("user" in self._cmax):
237
+ if self._cmin["user"] >= self._cmax["user"]:
238
+ raise ValueError('User-set limits: cmin must be smaller than cmax.')
239
+ self._cmin["data"] = self._cmin["user"]
240
+ self._cmax["data"] = self._cmax["user"]
241
+ self.apply_limits()
242
+ return
243
+
174
244
  limits = [
175
- fix_empty_range(find_limits(artist._data, scale=self.norm))
245
+ fix_empty_range(
246
+ find_limits(artist._data, scale='log' if self._logc else 'linear')
247
+ )
176
248
  for artist in self.artists.values()
177
249
  ]
178
- vmin = reduce(min, [v[0] for v in limits])
179
- vmax = reduce(max, [v[1] for v in limits])
180
- if self.user_vmin is not None:
181
- self._vmin = self.user_vmin
250
+ if "user" not in self._cmin:
251
+ self._cmin["data"] = reduce(min, [v[0] for v in limits]).value
182
252
  else:
183
- self._vmin = vmin.value
184
- if self.user_vmax is not None:
185
- self._vmax = self.user_vmax
253
+ self._cmin["data"] = self._cmin["user"]
254
+ if "user" not in self._cmax:
255
+ self._cmax["data"] = reduce(max, [v[1] for v in limits]).value
186
256
  else:
187
- self._vmax = vmax.value
257
+ self._cmax["data"] = self._cmax["user"]
188
258
 
189
- if self._vmin >= self._vmax:
190
- if self.user_vmax is not None:
191
- self._vmax = self.user_vmax
192
- self._vmin = self.user_vmax - abs(self.user_vmax) * 0.1
259
+ if self._cmin["data"] >= self._cmax["data"]:
260
+ if "user" in self._cmax:
261
+ self._cmin["data"] = self._cmax["data"] - abs(self._cmax["data"]) * 0.1
193
262
  else:
194
- self._vmin = self.user_vmin
195
- self._vmax = self.user_vmin + abs(self.user_vmin) * 0.1
263
+ self._cmax["data"] = self._cmin["data"] + abs(self._cmin["data"]) * 0.1
196
264
 
197
265
  self.apply_limits()
198
266
 
199
267
  def apply_limits(self):
200
268
  # Synchronize the underlying normalizer limits to the current state.
201
- # Note that the order matters here, as for a normalizer vmin cannot be set above
202
- # the current vmax.
203
- if self._vmin >= self.normalizer.vmax:
204
- self.normalizer.vmax = self._vmax
205
- self.normalizer.vmin = self._vmin
269
+ # Note that the order matters here, as for a normalizer cmin cannot be set above
270
+ # the current cmax.
271
+ if self._cmin["data"] >= self.normalizer.vmax:
272
+ self.normalizer.vmax = self._cmax["data"]
273
+ self.normalizer.vmin = self._cmin["data"]
206
274
  else:
207
- self.normalizer.vmin = self._vmin
208
- self.normalizer.vmax = self._vmax
275
+ self.normalizer.vmin = self._cmin["data"]
276
+ self.normalizer.vmax = self._cmax["data"]
209
277
 
210
278
  if self.colorbar is not None:
211
279
  self._update_colorbar_widget()
@@ -222,24 +290,48 @@ class ColorMapper:
222
290
  def vmin(self) -> float:
223
291
  """
224
292
  Get or set the minimum value of the colorbar.
293
+ This is an old property name. Prefer using ``cmin`` instead.
225
294
  """
226
- return self._vmin
295
+ return self.cmin
227
296
 
228
297
  @vmin.setter
229
- def vmin(self, vmin: sc.Variable | float):
230
- self._vmin = maybe_variable_to_number(vmin, unit=self._unit)
231
- self.apply_limits()
298
+ def vmin(self, value: sc.Variable | float):
299
+ self.cmin = value
232
300
 
233
301
  @property
234
302
  def vmax(self) -> float:
235
303
  """
236
304
  Get or set the maximum value of the colorbar.
305
+ This is an old property name. Prefer using ``cmax`` instead.
237
306
  """
238
- return self._vmax
307
+ return self.cmax
239
308
 
240
309
  @vmax.setter
241
- def vmax(self, vmax: sc.Variable | float):
242
- self._vmax = maybe_variable_to_number(vmax, unit=self._unit)
310
+ def vmax(self, value: sc.Variable | float):
311
+ self.cmax = value
312
+
313
+ @property
314
+ def cmin(self) -> float:
315
+ """
316
+ Get or set the minimum value of the colorbar.
317
+ """
318
+ return self._cmin.get("user", self._cmin["data"])
319
+
320
+ @cmin.setter
321
+ def cmin(self, value: sc.Variable | float):
322
+ self._cmin["user"] = maybe_variable_to_number(value, unit=self._unit)
323
+ self.apply_limits()
324
+
325
+ @property
326
+ def cmax(self) -> float:
327
+ """
328
+ Get or set the maximum value of the colorbar.
329
+ """
330
+ return self._cmax.get("user", self._cmax["data"])
331
+
332
+ @cmax.setter
333
+ def cmax(self, value: sc.Variable | float):
334
+ self._cmax["user"] = maybe_variable_to_number(value, unit=self._unit)
243
335
  self.apply_limits()
244
336
 
245
337
  @property
@@ -252,34 +344,70 @@ class ColorMapper:
252
344
  @unit.setter
253
345
  def unit(self, unit: str | None):
254
346
  self._unit = unit
255
- if self.user_vmin is not None:
256
- self.user_vmin = maybe_variable_to_number(self.user_vmin, unit=self._unit)
257
- if self.user_vmax is not None:
258
- self.user_vmax = maybe_variable_to_number(self.user_vmax, unit=self._unit)
347
+ if "user" in self._cmin:
348
+ self._cmin["user"] = maybe_variable_to_number(
349
+ self._cmin["user"], unit=self._unit
350
+ )
351
+ if "user" in self._cmax:
352
+ self._cmax["user"] = maybe_variable_to_number(
353
+ self._cmax["user"], unit=self._unit
354
+ )
259
355
 
260
356
  @property
261
- def ylabel(self) -> str | None:
357
+ def clabel(self) -> str | None:
262
358
  """
263
359
  Get or set the label of the colorbar axis.
264
360
  """
265
361
  if self.cax is not None:
266
362
  return self.cax.get_ylabel()
267
363
 
268
- @ylabel.setter
269
- def ylabel(self, lab: str):
364
+ @clabel.setter
365
+ def clabel(self, lab: str):
270
366
  if self.cax is not None:
271
367
  self.cax.set_ylabel(lab)
272
368
 
369
+ @property
370
+ def ylabel(self) -> str | None:
371
+ """
372
+ Get or set the label of the colorbar axis.
373
+ This is an old property name. Prefer using ``clabel`` instead.
374
+ """
375
+ return self.clabel
376
+
377
+ @ylabel.setter
378
+ def ylabel(self, lab: str):
379
+ self.clabel = lab
380
+
273
381
  def toggle_norm(self):
274
382
  """
275
383
  Toggle the norm flag, between `linear` and `log`.
276
384
  """
277
- self.norm = "log" if self.norm == 'linear' else 'linear'
278
- self.normalizer = _get_normalizer(self.norm)
279
- self._vmin = np.inf
280
- self._vmax = -np.inf
385
+ self._logc = not self._logc
386
+ self.normalizer = _get_normalizer('log' if self._logc else 'linear')
387
+ self._cmin["data"] = np.inf
388
+ self._cmax["data"] = -np.inf
281
389
  if self.colorbar is not None:
282
390
  self.colorbar.mappable.norm = self.normalizer
283
391
  self.autoscale()
284
392
  if self._canvas is not None:
285
393
  self._canvas.draw()
394
+
395
+ @property
396
+ def norm(self) -> Literal['linear', 'log']:
397
+ """
398
+ Get or set the colorscale normalization.
399
+ """
400
+ return 'log' if self._logc else 'linear'
401
+
402
+ @norm.setter
403
+ def norm(self, value: Literal['linear', 'log']):
404
+ if value not in ['linear', 'log']:
405
+ raise ValueError('norm must be either "linear" or "log".')
406
+ if value != self.norm:
407
+ self.toggle_norm()
408
+
409
+ def has_user_clabel(self) -> bool:
410
+ """
411
+ Return ``True`` if the user has set a colorbar label.
412
+ """
413
+ return self._clabel is not None
@@ -49,8 +49,9 @@ class GraphicalView(View):
49
49
  colormapper: bool = False,
50
50
  cmap: str = 'viridis',
51
51
  mask_cmap: str = 'gray',
52
+ mask_color: str | None = None,
52
53
  cbar: bool = False,
53
- norm: Literal['linear', 'log'] = 'linear',
54
+ norm: Literal['linear', 'log', None] = None,
54
55
  vmin: sc.Variable | float | None = None,
55
56
  vmax: sc.Variable | float | None = None,
56
57
  scale: dict[str, str] | None = None,
@@ -64,9 +65,27 @@ class GraphicalView(View):
64
65
  autoscale: bool = True,
65
66
  ax: Any = None,
66
67
  cax: Any = None,
68
+ xmin: sc.Variable | float | None = None,
69
+ xmax: sc.Variable | float | None = None,
70
+ ymin: sc.Variable | float | None = None,
71
+ ymax: sc.Variable | float | None = None,
72
+ zmin: sc.Variable | float | None = None,
73
+ zmax: sc.Variable | float | None = None,
74
+ cmin: sc.Variable | float | None = None,
75
+ cmax: sc.Variable | float | None = None,
76
+ logx: bool | None = None,
77
+ logy: bool | None = None,
78
+ logz: bool | None = None,
79
+ logc: bool | None = None,
80
+ xlabel: str | None = None,
81
+ ylabel: str | None = None,
82
+ zlabel: str | None = None,
83
+ clabel: str | None = None,
84
+ nan_color: str | None = None,
67
85
  **kwargs,
68
86
  ):
69
87
  super().__init__(*nodes)
88
+
70
89
  self._dims = dims
71
90
  self._scale = {} if scale is None else scale
72
91
  self.artists = {}
@@ -88,6 +107,19 @@ class GraphicalView(View):
88
107
  camera=camera,
89
108
  ax=ax,
90
109
  cax=cax,
110
+ xmin=xmin,
111
+ xmax=xmax,
112
+ ymin=ymin,
113
+ ymax=ymax,
114
+ zmin=zmin,
115
+ zmax=zmax,
116
+ logx=logx,
117
+ logy=logy,
118
+ logz=logz,
119
+ xlabel=xlabel,
120
+ ylabel=ylabel,
121
+ zlabel=zlabel,
122
+ norm=norm if len(dims) == 1 else None,
91
123
  )
92
124
 
93
125
  if colormapper:
@@ -95,11 +127,17 @@ class GraphicalView(View):
95
127
  cmap=cmap,
96
128
  cbar=cbar,
97
129
  mask_cmap=mask_cmap,
130
+ mask_color=mask_color,
98
131
  norm=norm,
99
132
  vmin=vmin,
100
133
  vmax=vmax,
134
+ cmin=cmin,
135
+ cmax=cmax,
136
+ clabel=clabel,
137
+ logc=logc,
101
138
  canvas=self.canvas,
102
139
  figsize=getattr(self.canvas, "figsize", None),
140
+ nan_color=nan_color,
103
141
  )
104
142
  self._kwargs['colormapper'] = self.colormapper
105
143
  if self._autoscale:
@@ -108,9 +146,10 @@ class GraphicalView(View):
108
146
  self.colormapper.set_colors_on_update = False
109
147
  else:
110
148
  self.colormapper = None
149
+ # If there is not colormapper, we need to forward the mask_color to the
150
+ # artist maker.
151
+ self._kwargs['mask_color'] = mask_color
111
152
 
112
- if len(self._dims) == 1:
113
- self.canvas.yscale = norm
114
153
  self.render()
115
154
 
116
155
  def autoscale(self):
@@ -145,19 +184,20 @@ class GraphicalView(View):
145
184
  for i, direction in enumerate(self._dims):
146
185
  if self._dims[direction] is None:
147
186
  self._dims[direction] = new_values.dims[i]
148
- try:
149
- coords[direction] = new_values.coords[self._dims[direction]]
150
- except KeyError as e:
187
+ if self._dims[direction] not in new_values.coords:
151
188
  raise KeyError(
152
189
  "Supplied data is incompatible with this view: "
153
190
  f"coordinate '{self._dims[direction]}' was not found in data."
154
- ) from e
191
+ )
192
+ coords[direction] = new_values.coords[self._dims[direction]]
155
193
 
156
194
  if self.canvas.empty:
157
195
  self._data_name = new_values.name
158
196
  axes_units = {k: coord.unit for k, coord in coords.items()}
159
197
  axes_dtypes = {k: coord.dtype for k, coord in coords.items()}
160
198
 
199
+ data_label = name_with_unit(var=new_values.data, name=self._data_name)
200
+
161
201
  if set(self._dims) == {'x'}:
162
202
  axes_units['data'] = new_values.unit
163
203
  axes_dtypes['data'] = new_values.dtype
@@ -165,28 +205,31 @@ class GraphicalView(View):
165
205
  self.colormapper.unit = new_values.unit
166
206
  axes_units['data'] = new_values.unit
167
207
  axes_dtypes['data'] = new_values.dtype
168
- self._data_axis = self.colormapper
208
+ if not self.colormapper.has_user_clabel():
209
+ self.colormapper.clabel = data_label
169
210
  else:
170
- self._data_axis = self.canvas
211
+ if not self.canvas.has_user_ylabel():
212
+ self.canvas.ylabel = data_label
171
213
 
172
214
  self.canvas.set_axes(
173
215
  dims=self._dims, units=axes_units, dtypes=axes_dtypes
174
216
  )
175
217
 
176
218
  for xyz, dim in self._dims.items():
177
- setattr(
178
- self.canvas,
179
- f'{xyz}label',
180
- name_with_unit(var=coords[xyz], name=dim),
181
- )
219
+ if not getattr(self.canvas, f'has_user_{xyz}label')():
220
+ setattr(
221
+ self.canvas,
222
+ f'{xyz}label',
223
+ name_with_unit(var=coords[xyz], name=dim),
224
+ )
225
+ # Note that setting the scale is handled here as well as in the
226
+ # canvas for historical purposes. We kept the scale argument for
227
+ # backward compatibility, but it is now also possible to set the
228
+ # axes scales with logx, logy, logz in the constructor of the
229
+ # canvas.
182
230
  if dim in self._scale:
183
231
  setattr(self.canvas, f'{xyz}scale', self._scale[dim])
184
232
 
185
- if self._data_axis is not None:
186
- self._data_axis.ylabel = name_with_unit(
187
- var=new_values.data, name=self._data_name
188
- )
189
-
190
233
  else:
191
234
  for xy, dim in self._dims.items():
192
235
  new_values.coords[dim] = make_compatible(
@@ -198,9 +241,15 @@ class GraphicalView(View):
198
241
  )
199
242
  if self._data_name and (new_values.name != self._data_name):
200
243
  self._data_name = None
201
- self._data_axis.ylabel = name_with_unit(
202
- var=sc.scalar(0.0, unit=self.canvas.units['data']), name=''
244
+ data_label = name_with_unit(
245
+ var=sc.scalar(0.0, unit=self.canvas.units['data']),
246
+ name='',
203
247
  )
248
+ if self.colormapper is not None:
249
+ if not self.colormapper.has_user_clabel():
250
+ self.colormapper.clabel = data_label
251
+ elif not self.canvas.has_user_ylabel():
252
+ self.canvas.ylabel = data_label
204
253
 
205
254
  if key not in self.artists:
206
255
  self.artists[key] = self._artist_maker(
plopp/plotting/common.py CHANGED
@@ -3,7 +3,7 @@
3
3
 
4
4
  import warnings
5
5
  from collections.abc import Callable, Iterable
6
- from typing import Any
6
+ from typing import Any, Literal
7
7
 
8
8
  import numpy as np
9
9
  import scipp as sc
@@ -97,7 +97,7 @@ def to_data_array(obj: Plottable | list) -> sc.DataArray:
97
97
  return out
98
98
 
99
99
 
100
- def _check_size(da: sc.DataArray) -> None:
100
+ def check_size(da: sc.DataArray) -> None:
101
101
  """
102
102
  Prevent slow figure rendering by raising an error if the data array exceeds a
103
103
  default size.
@@ -133,11 +133,14 @@ def check_not_binned(da: sc.DataArray) -> None:
133
133
  )
134
134
 
135
135
 
136
- def check_allowed_dtypes(da: sc.DataArray) -> None:
136
+ def to_allowed_dtypes(da: sc.DataArray) -> sc.DataArray:
137
137
  """
138
138
  Currently, Plopp cannot plot data that contains vector and matrix dtypes.
139
139
  This function will raise an error if the input data type is not supported.
140
140
 
141
+ We also convert boolean data to integers, as some operations downstream
142
+ may not support boolean data types.
143
+
141
144
  Parameters
142
145
  ----------
143
146
  da:
@@ -147,6 +150,9 @@ def check_allowed_dtypes(da: sc.DataArray) -> None:
147
150
  raise TypeError(
148
151
  f'The input has dtype {da.dtype} which is not supported by Plopp.'
149
152
  )
153
+ if da.dtype == bool:
154
+ da = da.to(dtype='int32')
155
+ return da
150
156
 
151
157
 
152
158
  def _all_dims_sorted(var, order='ascending') -> bool:
@@ -283,11 +289,11 @@ def preprocess(
283
289
 
284
290
  out = to_data_array(obj)
285
291
  check_not_binned(out)
286
- check_allowed_dtypes(out)
292
+ out = to_allowed_dtypes(out)
287
293
  if name is not None:
288
294
  out.name = str(name)
289
295
  if not ignore_size:
290
- _check_size(out)
296
+ check_size(out)
291
297
  if coords is not None:
292
298
  out = _rename_dims_from_coords(out, coords)
293
299
  out = _add_missing_dimension_coords(out)
@@ -338,3 +344,71 @@ def raise_multiple_inputs_for_2d_plot_error(origin):
338
344
  'want to plot two images onto the same axes, use the lower-level '
339
345
  'plopp.imagefigure function.'
340
346
  )
347
+
348
+
349
+ def categorize_args(
350
+ aspect: Literal['auto', 'equal', None] = None,
351
+ autoscale: bool = True,
352
+ cbar: bool = True,
353
+ clabel: str | None = None,
354
+ cmap: str = 'viridis',
355
+ cmax: sc.Variable | float | None = None,
356
+ cmin: sc.Variable | float | None = None,
357
+ errorbars: bool = True,
358
+ figsize: tuple[float, float] | None = None,
359
+ grid: bool = False,
360
+ legend: bool | tuple[float, float] = True,
361
+ logc: bool | None = None,
362
+ logx: bool | None = None,
363
+ logy: bool | None = None,
364
+ mask_cmap: str = 'gray',
365
+ mask_color: str = 'black',
366
+ nan_color: str | None = None,
367
+ norm: Literal['linear', 'log', None] = None,
368
+ scale: dict[str, str] | None = None,
369
+ title: str | None = None,
370
+ vmax: sc.Variable | float | None = None,
371
+ vmin: sc.Variable | float | None = None,
372
+ xlabel: str | None = None,
373
+ xmax: sc.Variable | float | None = None,
374
+ xmin: sc.Variable | float | None = None,
375
+ ylabel: str | None = None,
376
+ ymax: sc.Variable | float | None = None,
377
+ ymin: sc.Variable | float | None = None,
378
+ **kwargs,
379
+ ) -> dict:
380
+ common_args = {
381
+ 'aspect': aspect,
382
+ 'autoscale': autoscale,
383
+ 'figsize': figsize,
384
+ 'grid': grid,
385
+ 'logx': logx,
386
+ 'logy': logy,
387
+ 'mask_color': mask_color,
388
+ 'norm': norm,
389
+ 'scale': scale,
390
+ 'title': title,
391
+ 'vmax': vmax,
392
+ 'vmin': vmin,
393
+ 'xlabel': xlabel,
394
+ 'xmax': xmax,
395
+ 'xmin': xmin,
396
+ 'ylabel': ylabel,
397
+ 'ymax': ymax,
398
+ 'ymin': ymin,
399
+ **kwargs,
400
+ }
401
+ return {
402
+ "1d": {'errorbars': errorbars, 'legend': legend, **common_args},
403
+ "2d": {
404
+ 'cbar': cbar,
405
+ 'cmap': cmap,
406
+ 'cmin': cmin,
407
+ 'cmax': cmax,
408
+ 'clabel': clabel,
409
+ 'logc': logc,
410
+ 'nan_color': nan_color,
411
+ 'mask_cmap': mask_cmap,
412
+ **common_args,
413
+ },
414
+ }