myplotlib 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. myplotlib/__init__.py +79 -0
  2. myplotlib/assets/classic.dark.mplstyle +69 -0
  3. myplotlib/assets/classic.light.mplstyle +60 -0
  4. myplotlib/assets/colormaps/bipolar.csv +1023 -0
  5. myplotlib/assets/colormaps/colt.csv +1024 -0
  6. myplotlib/assets/colormaps/fire.csv +256 -0
  7. myplotlib/assets/colormaps/idl.csv +256 -0
  8. myplotlib/assets/colormaps/sunrise.csv +512 -0
  9. myplotlib/assets/colormaps/thermal.csv +512 -0
  10. myplotlib/assets/colormaps/vanilla.csv +512 -0
  11. myplotlib/assets/fancy.dark.mplstyle +43 -0
  12. myplotlib/assets/fancy.light.mplstyle +32 -0
  13. myplotlib/assets/fonts/AppleChancery.ttf +0 -0
  14. myplotlib/assets/fonts/EBGaramond/EBGaramond-Bold.ttf +0 -0
  15. myplotlib/assets/fonts/EBGaramond/EBGaramond-BoldItalic.ttf +0 -0
  16. myplotlib/assets/fonts/EBGaramond/EBGaramond-ExtraBold.ttf +0 -0
  17. myplotlib/assets/fonts/EBGaramond/EBGaramond-ExtraBoldItalic.ttf +0 -0
  18. myplotlib/assets/fonts/EBGaramond/EBGaramond-Italic.ttf +0 -0
  19. myplotlib/assets/fonts/EBGaramond/EBGaramond-Medium.ttf +0 -0
  20. myplotlib/assets/fonts/EBGaramond/EBGaramond-MediumItalic.ttf +0 -0
  21. myplotlib/assets/fonts/EBGaramond/EBGaramond-Regular.ttf +0 -0
  22. myplotlib/assets/fonts/EBGaramond/EBGaramond-SemiBold.ttf +0 -0
  23. myplotlib/assets/fonts/EBGaramond/EBGaramond-SemiBoldItalic.ttf +0 -0
  24. myplotlib/assets/fonts/Hershey/AVHersheyComplexHeavy.ttf +0 -0
  25. myplotlib/assets/fonts/Hershey/AVHersheyComplexHeavyItalic.ttf +0 -0
  26. myplotlib/assets/fonts/Hershey/AVHersheyComplexLight.ttf +0 -0
  27. myplotlib/assets/fonts/Hershey/AVHersheyComplexLightItalic.ttf +0 -0
  28. myplotlib/assets/fonts/Hershey/AVHersheyComplexMedium.ttf +0 -0
  29. myplotlib/assets/fonts/Hershey/AVHersheyComplexMediumItalic.ttf +0 -0
  30. myplotlib/assets/fonts/Hershey/AVHersheyDuplexHeavy.ttf +0 -0
  31. myplotlib/assets/fonts/Hershey/AVHersheyDuplexHeavyItalic.ttf +0 -0
  32. myplotlib/assets/fonts/Hershey/AVHersheyDuplexLight.ttf +0 -0
  33. myplotlib/assets/fonts/Hershey/AVHersheyDuplexLightItalic.ttf +0 -0
  34. myplotlib/assets/fonts/Hershey/AVHersheyDuplexMedium.ttf +0 -0
  35. myplotlib/assets/fonts/Hershey/AVHersheyDuplexMediumItalic.ttf +0 -0
  36. myplotlib/assets/fonts/Hershey/AVHersheySimplexHeavy.ttf +0 -0
  37. myplotlib/assets/fonts/Hershey/AVHersheySimplexHeavyItalic.ttf +0 -0
  38. myplotlib/assets/fonts/Hershey/AVHersheySimplexLight.ttf +0 -0
  39. myplotlib/assets/fonts/Hershey/AVHersheySimplexLightItalic.ttf +0 -0
  40. myplotlib/assets/fonts/Hershey/AVHersheySimplexMedium.ttf +0 -0
  41. myplotlib/assets/fonts/Hershey/AVHersheySimplexMediumItalic.ttf +0 -0
  42. myplotlib/assets/latex.mplstyle +2 -0
  43. myplotlib/assets/mono.dark.mplstyle +21 -0
  44. myplotlib/assets/mono.light.mplstyle +11 -0
  45. myplotlib/plots.py +658 -0
  46. myplotlib/tests.py +246 -0
  47. myplotlib/tools/__init__.py +0 -0
  48. myplotlib/tools/lic.py +99 -0
  49. myplotlib-1.7.0.dist-info/METADATA +137 -0
  50. myplotlib-1.7.0.dist-info/RECORD +52 -0
  51. myplotlib-1.7.0.dist-info/WHEEL +4 -0
  52. myplotlib-1.7.0.dist-info/licenses/LICENSE +28 -0
myplotlib/plots.py ADDED
@@ -0,0 +1,658 @@
1
"""
`myplotlib.plots`

a collection of handy plotting functions built around `matplotlib` with lots of nice perks.

* dataPlot .................. : plot generic x & y 1d data (pass an `ax` method)
* scatter ................... : scatter plot (`dataPlot` with `ax.scatter`)
* plot ...................... : regular plot (`dataPlot` with `ax.plot`)
* plot2d .................... : 2d plot using `imshow`
* plotVectorField ........... : 2d plot with a vector field overplotted (line integral convolution)
* plot2dGrid ................ : new figure with a grid of 2d plots sharing axes

docstrings are available for all of the functions. in IPython/Jupyter, type, e.g., `dataPlot?` to read about the arguments.
"""
14
+
15
+ from typing import TypeAlias, Any, TypedDict, Callable
16
+ import numpy as np
17
+ import matplotlib.colors as mcolors
18
+ from matplotlib.axes._axes import Axes as pltAxes
19
+
20
+ LimTypeWithNone: TypeAlias = tuple[float | None, float | None] | None
21
+ LimType: TypeAlias = tuple[float, float]
22
+
23
+
24
def __stretch(
    left: float,
    right: float,
    pad: float,
) -> LimType:
    """Scale the interval ``[left, right]`` about its midpoint by ``pad``.

    ``pad == 1`` leaves the interval unchanged, ``pad > 1`` widens it
    symmetrically and ``pad < 1`` narrows it.
    """
    midpoint = 0.5 * (left + right)
    half_width = 0.5 * (right - left)
    return (midpoint - pad * half_width, midpoint + pad * half_width)
33
+
34
+
35
def __setMinMax(
    lims: LimTypeWithNone,
    data: np.ndarray,
) -> LimType:
    """Resolve axis limits, filling every unspecified end from the data.

    Args
    ----
    lims : tuple[float | None, float | None] | None
        Requested ``(low, high)`` limits. ``None`` for the whole value, or for
        one entry, means "use the NaN-aware min (low) / max (high) of `data`".
        A 2-element list is accepted as well as a tuple.
    data : np.ndarray
        Values that missing limits are derived from.

    Returns
    -------
    tuple
        The resolved ``(low, high)`` pair.

    Raises
    ------
    AssertionError
        If `lims` is neither ``None`` nor a 2-element tuple/list.
    """
    if lims is None:
        lims = (None, None)
    assert (
        isinstance(lims, (tuple, list)) and len(lims) == 2
    ), "lims must be a tuple of length 2"
    low, high = lims
    # only scan `data` for the ends that were not given explicitly
    if low is None:
        low = np.nanmin(data)
    if high is None:
        high = np.nanmax(data)
    return (low, high)
53
+
54
+
55
def __setAxLims(
    ax: pltAxes,
    coords,
    log: bool,
    pad: float,
    lims: LimTypeWithNone,
    spines: str,
):
    """Set the limits (and optionally a log scale) of one axis of `ax`.

    Args
    ----
    ax : pltAxes
        The matplotlib axis object.
    coords : array-like
        Data along this axis; used for any limit left unspecified in `lims`.
    log : bool
        Switch the axis to a log scale; the padding is then applied in log10 space.
        Automatically determined limits are taken from the positive values only
        (when there are any), since non-positive values have no logarithm.
    pad : float
        Relative padding (0 = none). When positive, the axis spine is trimmed to
        the resolved (unpadded) limits.
    lims : tuple[float | None, float | None] | None
        Requested limits (None = determine from `coords`).
    spines : str
        "bottom" for the x-axis, "left" for the y-axis.

    Raises
    ------
    ValueError
        If `spines` is neither "bottom" nor "left".
    """
    # validate before touching `ax.spines`, which would raise KeyError instead
    if spines == "bottom":
        func_setscale, func_setlim = ax.set_xscale, ax.set_xlim
    elif spines == "left":
        func_setscale, func_setlim = ax.set_yscale, ax.set_ylim
    else:
        raise ValueError(f"invalid `spines` value: {spines}")

    data = np.asarray(coords)
    if log:
        # log10 of zeros/negatives gives -inf/NaN limits; derive automatic
        # limits from the positive part of the data instead
        positive = data[data > 0]
        if positive.size > 0:
            data = positive
    p1, p2 = __setMinMax(lims, data)

    if pad > 0:
        ax.spines[spines].set_bounds(p1, p2)
    if log:
        func_setscale("log")
        lo, hi = __stretch(np.log10(p1), np.log10(p2), 1.0 + pad)
        func_setlim(10**lo, 10**hi)
    else:
        func_setlim(*__stretch(p1, p2, 1.0 + pad))
87
+
88
+
89
def __checkDimensions2d(
    x: np.ndarray,
    y: np.ndarray,
    zz: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Validate coordinate/field shapes and reduce the coordinates to 1d.

    Singleton axes are squeezed away from all three inputs. 2d coordinate
    grids are collapsed to 1d: `x` keeps its first row, `y` its first column.
    `zz` must be 2d, or 3d with a trailing axis of length 3 or 4 (RGB/RGBA),
    and `zz.shape[:2]` must equal `(len(y), len(x))`.

    Returns
    -------
    tuple
        The squeezed `(x, y, zz)`.

    Raises
    ------
    AssertionError
        On any dimension mismatch.
    """
    x = np.array(np.squeeze(x))
    y = np.array(np.squeeze(y))
    zz = np.array(np.squeeze(zz))
    shapes_msg = f"`x.shape={x.shape}`, `y.shape={y.shape}`, `zz.shape={zz.shape}`"

    assert x.ndim == y.ndim, (
        f"Shapes of `x` and `y` must be of the same dimension: {shapes_msg}."
    )
    # meshgrid-style coordinates: x varies along columns, y along rows
    x = x[0, ...] if x.ndim > 1 else x
    y = y[..., 0] if y.ndim > 1 else y

    is_scalar_field = zz.ndim == 2
    is_color_image = zz.ndim == 3 and zz.shape[-1] in (3, 4)
    assert is_scalar_field or is_color_image, (
        f"`zz` must have exactly 2 non-trivial axes: {shapes_msg}."
    )
    assert zz.shape[1] == x.shape[0], (
        f"incompatible dimensions between `x` and `zz`: {shapes_msg}."
    )
    assert zz.shape[0] == y.shape[0], (
        f"incompatible dimensions between `y` and `zz`: {shapes_msg}."
    )
    return (x, y, zz)
118
+
119
+
120
def __findExtent(
    x: np.ndarray,
    y: np.ndarray,
    centering: str,
) -> tuple[float, float, float, float]:
    """Compute the `imshow` extent `(left, right, bottom, top)` of a grid.

    The grid step is taken from the first two nodes of each axis, so `x` and
    `y` are assumed to be uniformly spaced, ascending, and at least 2 long.

    Args
    ----
    x, y : np.ndarray
        1d node coordinates.
    centering : str
        "edge": nodes are the lower/left edges of the cells;
        "center": nodes are the cell centers.

    Raises
    ------
    ValueError
        If `centering` is neither "edge" nor "center".
    """
    if centering not in ("edge", "center"):
        raise ValueError(
            f"invalid `centering` value: {centering!r} (expected 'edge' or 'center')"
        )
    dx = x[1] - x[0]
    dy = y[1] - y[0]
    if centering == "edge":
        # the last cell extends one full step beyond the last node
        return (x.min(), x.max() + dx, y.min(), y.max() + dy)
    # cell centers: half a step of margin on every side
    return (
        x.min() - dx * 0.5,
        x.max() + dx * 0.5,
        y.min() - dy * 0.5,
        y.max() + dy * 0.5,
    )
141
+
142
+
143
def dataPlot(
    function: Callable,
    ax: pltAxes,
    x: np.ndarray,
    y: np.ndarray,
    xlog: bool = False,
    ylog: bool = False,
    xlim: LimTypeWithNone = None,
    ylim: LimTypeWithNone = None,
    padx: float = 0.0,
    pady: float = 0.0,
    **kwargs,
):
    """Add a plot according to a passed function

    Args
    ----
    function : Callable
        The function to call on the axis (e.g., `ax.plot`, `ax.scatter`); it is
        called as `function(x, y, **kwargs)`.
    ax : pltAxes
        The matplotlib axis object.
    x, y : np.ndarray
        The data to plot.
    xlog : bool, optional
        Use logarithmic scale for x-axis (default is False).
    ylog : bool, optional
        Use logarithmic scale for y-axis (default is False).
    xlim : tuple[float | None, float | None] | None, optional
        Tuple of x limits (None = determine from data) (default is None).
    ylim : tuple[float | None, float | None] | None, optional
        Tuple of y limits (None = determine from data) (default is None).
    padx : float, optional
        Add whitespace to axes in each direction (0 = no additional space) (default is 0.0).
        When positive, the top spine is hidden and the bottom spine is trimmed to the limits.
    pady : float, optional
        Add whitespace to axes in each direction (0 = no additional space) (default is 0.0).
        When positive, the right spine is hidden and the left spine is trimmed to the limits.
    **kwargs : dict, optional
        Standard matplotlib kwargs passed to `function`.

    Returns
    -------
    Any
        Whatever `function` returns (e.g. the artists created by `ax.plot` or
        `ax.scatter`), so callers can build legends or colorbars from it.
    """
    # `__setAxLims` trims the bottom/left spines to the limits only for a positive
    # pad; hide the opposite spines under the same condition
    if padx > 0:
        ax.spines["top"].set_visible(False)
    if pady > 0:
        ax.spines["right"].set_visible(False)
    artist = function(x, y, **kwargs)
    __setAxLims(ax, x, xlog, padx, xlim, "bottom")
    __setAxLims(ax, y, ylog, pady, ylim, "left")
    return artist
189
+
190
+
191
def scatter(
    ax: pltAxes,
    x: np.ndarray,
    y: np.ndarray,
    xlog: bool = False,
    ylog: bool = False,
    xlim: LimTypeWithNone = None,
    ylim: LimTypeWithNone = None,
    padx: float = 0.0,
    pady: float = 0.0,
    **kwargs,
):
    """Add a scatter plot to a given axis (same as `dataPlot(ax.scatter, ...)`)

    Args
    ----
    ax : pltAxes
        The matplotlib axis object.
    x, y : np.ndarray
        The data to plot.
    xlog, ylog : bool, optional
        Use logarithmic scale for the x / y axis (default is False).
    xlim, ylim : LimTypeWithNone, optional
        Tuples of x / y limits (None = determine from data) (default is None).
    padx, pady : float, optional
        Add whitespace to the x / y axis in each direction (0 = no additional space)
        (default is 0.0).
    **kwargs : dict, optional
        Standard matplotlib kwargs passed to `ax.scatter`.

    Returns
    -------
    Whatever `dataPlot` returns.
    """
    return dataPlot(
        ax.scatter,
        ax,
        x,
        y,
        xlog=xlog,
        ylog=ylog,
        xlim=xlim,
        ylim=ylim,
        padx=padx,
        pady=pady,
        **kwargs,
    )
227
+
228
+
229
def plot(
    ax: pltAxes,
    x: np.ndarray,
    y: np.ndarray,
    xlog: bool = False,
    ylog: bool = False,
    xlim: LimTypeWithNone = None,
    ylim: LimTypeWithNone = None,
    padx: float = 0.0,
    pady: float = 0.0,
    **kwargs,
):
    """Add a plot to a given axis (same as `dataPlot(ax.plot, ...)`)

    Args
    ----
    ax : pltAxes
        The matplotlib axis object.
    x, y : np.ndarray
        The data to plot.
    xlog : bool, optional
        Use logarithmic scale for x-axis (default is False).
    ylog : bool, optional
        Use logarithmic scale for y-axis (default is False).
    xlim : LimTypeWithNone, optional
        Tuple of x limits (None = determine from data) (default is None).
    ylim : LimTypeWithNone, optional
        Tuple of y limits (None = determine from data) (default is None).
    padx : float, optional
        Add whitespace to axes in each direction (0 = no additional space) (default is 0.0).
    pady : float, optional
        Add whitespace to axes in each direction (0 = no additional space) (default is 0.0).
    **kwargs : dict, optional
        Standard matplotlib kwargs passed to `ax.plot`.

    Returns
    -------
    Whatever `dataPlot` returns (consistent with `scatter`).
    """
    return dataPlot(
        ax.plot,
        ax,
        x,
        y,
        xlog=xlog,
        ylog=ylog,
        xlim=xlim,
        ylim=ylim,
        padx=padx,
        pady=pady,
        **kwargs,
    )
265
+
266
+
267
def plot2d(
    ax: pltAxes,
    x: np.ndarray,
    y: np.ndarray,
    zz: np.ndarray,
    force_aspect: bool = True,
    centering: str = "edge",
    xlim: LimTypeWithNone = None,
    ylim: LimTypeWithNone = None,
    zlog: bool = False,
    zlim: LimTypeWithNone = None,
    padx: float = 0.0,
    pady: float = 0.0,
    cbar: str | None = "5%",
    cbar_pad: float = 0.05,
    cbar_pos: str = "right",
    **kwargs,
):
    """Add a 2d plot to a given axis

    Args
    ----
    ax : matplotlib axis object
        The axis to plot on.
    x, y : 1d or 2d arrays of coordinates
        The coordinates of the data to plot (2d grids are reduced to their first
        row / column).
    zz : 2d array, or 3d RGB/RGBA array
        The data to plot; `zz.shape[:2]` must be `(len(y), len(x))`.
    force_aspect : bool, optional
        Force equal aspect ratio according to axes (default is True).
    centering : str, optional
        Centering of x & y nodes for the data ('edge', 'center') (default is 'edge').
    xlim : tuple of float, optional
        Tuple of x limits (None = determine from x) (default is None).
    ylim : tuple of float, optional
        Tuple of y limits (None = determine from y) (default is None).
    zlog : bool, optional
        Use a logarithmic color scale (default is False).
    zlim : tuple of float, optional
        Tuple of z limits (default is None). A None entry (or None for both) falls
        back to the 5% / 95% quantile of the finite values of `zz` -- of its
        positive values when `zlog` is set and any exist.
    padx : float, optional
        Add whitespace to axes in each direction (0 = no additional space) (default is 0.0).
    pady : float, optional
        Add whitespace to axes in each direction (0 = no additional space) (default is 0.0).
    cbar : str or None, optional
        Size of the colorbar in percent of x-axis (None = no colorbar) (default is '5%').
    cbar_pad : float, optional
        Padding of the colorbar (default is 0.05).
    cbar_pos : str, optional
        Position of the colorbar ('left', 'right', 'top', 'bottom') (default is 'right').
    **kwargs : dict, optional
        Standard matplotlib kwargs passed to `ax.imshow`. An explicit `norm` is used
        as-is, in which case `zlog` and `zlim` are ignored.

    Returns
    -------
    None or colorbar handle
        Returns `None` if `cbar` is `None`, otherwise the colorbar of the image
        drawn by this call.

    Raises
    ------
    AssertionError
        If `centering` is not 'edge' or 'center', or if `cbar_pos` is not one of
        'left', 'right', 'top', or 'bottom'.
    """
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    import matplotlib.pyplot as plt

    assert centering in ["edge", "center"], "invalid `centering`"
    assert cbar_pos in ["left", "right", "top", "bottom"], "invalid `cbar_pos`"

    x, y, zz = __checkDimensions2d(x, y, zz)
    ax.grid(False)
    extent = __findExtent(x, y, centering)
    # None lets `imshow` use its default aspect handling
    aspect = None if force_aspect else "auto"

    if "norm" in kwargs:
        norm = kwargs.pop("norm")
    else:
        vmin, vmax = (None, None) if zlim is None else zlim
        # quantiles are only needed for the ends `zlim` leaves open
        if vmin is None or vmax is None:
            finite = zz[np.isfinite(zz)]
            if zlog:
                # LogNorm needs strictly positive bounds: when possible, take the
                # quantiles over the positive samples only
                positive = finite[finite > 0]
                if positive.size > 0:
                    finite = positive
            if vmin is None:
                vmin = np.quantile(finite, 0.05)
            if vmax is None:
                vmax = np.quantile(finite, 0.95)
        norm_type = mcolors.LogNorm if zlog else mcolors.Normalize
        norm = norm_type(vmin=float(vmin), vmax=float(vmax))

    image = ax.imshow(
        zz, origin="lower", extent=extent, aspect=aspect, norm=norm, **kwargs
    )
    __setAxLims(ax, np.linspace(extent[0], extent[1]), False, padx, xlim, "bottom")
    __setAxLims(ax, np.linspace(extent[2], extent[3]), False, pady, ylim, "left")

    if cbar is None:
        return None

    divider = make_axes_locatable(ax)
    cax = divider.append_axes(cbar_pos, size=cbar, pad=cbar_pad)
    # use the image drawn by *this* call: `ax` may already hold other images
    colorbar = plt.colorbar(
        image,
        cax=cax,
        orientation="vertical" if cbar_pos in ["left", "right"] else "horizontal",
    )
    # move ticks/labels out of the colorbar's way
    if cbar_pos == "left":
        cax.yaxis.set_ticks_position("left")
        cax.yaxis.set_label_position("left")
        ax.yaxis.set_ticks_position("right")
        ax.yaxis.set_label_position("right")
    if cbar_pos == "top":
        cax.xaxis.set_ticks_position("top")
        cax.xaxis.set_label_position("top")
    if cbar_pos == "bottom":
        ax.xaxis.set_ticks_position("top")
        ax.xaxis.set_label_position("top")
    return colorbar
385
+
386
+
387
def plotVectorField(
    ax: pltAxes,
    x: np.ndarray,
    y: np.ndarray,
    fx: np.ndarray,
    fy: np.ndarray,
    background: np.ndarray | None = None,
    texture_seed: int | None = None,
    kernel_len: int = 31,
    kernel_pow: int = 1,
    lic_alphamin: float = 0.5,
    lic_alphamax: float = 0.75,
    lic_contrast: float = 0.33,
    lic_opacity: float = 0.75,
    lic_cmap: str = "binary_r",
    force_aspect: bool = True,
    centering: str = "edge",
    xlim: LimTypeWithNone = None,
    ylim: LimTypeWithNone = None,
    padx: float = 0.0,
    pady: float = 0.0,
    cbar: str | None = "5%",
    cbar_pad: float = 0.05,
    **kwargs,
):
    """Add a 2D plot with a vector field overplotted

    The background is drawn with `plot2d`. On top of it goes a texture obtained
    by line integral convolution (LIC) of a random texture along `(fx, fy)`,
    colored with `lic_cmap` and made partly transparent.

    Args
    ----
    ax : pltAxes
        The matplotlib axis object.
    x, y : np.ndarray
        1D or 2D arrays of coordinates.
    fx, fy : np.ndarray
        2D arrays of the vector field components.
    background : np.ndarray | None, optional
        2D array of the image background (None = `sqrt(fx^2 + fy^2)`).

    Line integral convolution (LIC) parameters
    ------------------------------------------
    texture_seed : int | None, optional
        Random seed used to generate the texture, useful when rendering movies
        (None = random).
    kernel_len : int, optional
        Kernel resolution for the LIC algorithm (default is 31).
    kernel_pow : int, optional
        Kernel sharpness for the LIC algorithm: the kernel is raised to this
        power (default is 1).
    lic_alphamin : float, optional
        Normalized texture values below this become fully transparent (default is 0.5).
    lic_alphamax : float, optional
        Normalized texture values above this become fully opaque (default is 0.75).
    lic_contrast : float, optional
        Exponent of the signed power law applied to the texture before coloring
        (default is 0.33).
    lic_opacity : float, optional
        Overall opacity (`alpha`) of the LIC overlay (default is 0.75).
    lic_cmap : str, optional
        Colormap used for the LIC texture (default is 'binary_r').

    The rest of the arguments are the same as for `plot2d`
    ------------------------------------------------------
    force_aspect : bool, optional
        Force equal aspect ratio according to axes (default is True).
    centering : str, optional
        Centering of x & y nodes for the data ('edge', 'center') (default is 'edge').
    xlim, ylim : tuple of float, optional
        Tuples of x and y limits (None = determine from x & y) (default is None).
    padx, pady : float, optional
        Add whitespace to axes in each direction (0 = no additional space) (default is 0.0).
    cbar : str or None, optional
        Size of the colorbar in percent of x-axis (None = no colorbar) (default is '5%').
    cbar_pad : float, optional
        Padding of the colorbar (default is 0.05).
    **kwargs : dict, optional
        Standard matplotlib kwargs passed to `ax.imshow` for the background.

    Returns
    -------
    None or colorbar handle
        The colorbar of the background image (`None` if `cbar` is `None`).
    """
    import myplotlib.tools.lic as lic
    import matplotlib

    kernel = lic.generate_kernel(kernel_len) ** kernel_pow
    x, y, fx = __checkDimensions2d(x, y, fx)
    x, y, fy = __checkDimensions2d(x, y, fy)
    if background is None:
        background = np.sqrt(fx**2 + fy**2)

    # line integral convolution doesn't like zeros: push every component away
    # from zero by a tiny amount, keeping its sign (exact zeros go positive).
    # The nudge scales with the largest magnitude so that it stays non-zero even
    # when both components contain exact zeros.
    nudge = (np.nanmax(np.abs(fx)) + np.nanmax(np.abs(fy))) / 1e10
    fx = np.where(fx >= 0, 1.0, -1.0) * (np.abs(fx) + nudge)
    fy = np.where(fy >= 0, 1.0, -1.0) * (np.abs(fy) + nudge)

    x, y, background = __checkDimensions2d(x, y, background)
    texture = lic.generate_texture(background.shape, texture_seed)
    # convolve along the field and along the reversed field, then average
    img = 0.5 * (
        lic.line_integral_convolution(fx, fy, texture, kernel)
        + lic.line_integral_convolution(-fx, -fy, texture, kernel)
    )

    deviation = img - np.average(img)
    # alpha channel: signed sqrt of the deviation from the mean, normalized to
    # [0, 1], then thresholded to fully transparent / fully opaque
    alphas = mcolors.Normalize(None, None, clip=True)(
        np.sign(deviation) * np.sqrt(np.abs(deviation))
    )
    alphas[alphas < lic_alphamin] = 0
    alphas[alphas > lic_alphamax] = 1
    # color channel: signed power law (`lic_contrast`) of the same deviation
    shade = mcolors.Normalize(None, None)(
        np.sign(deviation) * np.abs(deviation) ** lic_contrast
    )
    colors = matplotlib.colormaps[lic_cmap](shade)
    colors[..., -1] = alphas

    colorbar = plot2d(
        ax,
        x,
        y,
        background,
        force_aspect=force_aspect,
        centering=centering,
        xlim=xlim,
        ylim=ylim,
        padx=padx,
        pady=pady,
        cbar=cbar,
        cbar_pad=cbar_pad,
        **kwargs,
    )
    # the RGBA overlay must not get a colorbar of its own: the colorbar
    # describes the background only
    plot2d(
        ax,
        x,
        y,
        colors,
        force_aspect=force_aspect,
        centering=centering,
        xlim=xlim,
        ylim=ylim,
        padx=padx,
        pady=pady,
        cbar=None,
        alpha=lic_opacity,
    )
    return colorbar
532
+
533
+
534
class PanelDict(TypedDict):
    """Specification of one panel of `plot2dGrid`."""

    # panel label, placed according to `plot2dGrid(label_pos=...)`; None = no label
    label: str | None
    # called with the `fields` dict; its return value is the array handed to `plot2d`
    # (`plot2dGrid` rejects None here)
    field: Callable | None
    # colormap forwarded to `plot2d` (and on to `imshow`) as `cmap`
    cmap: str | None
    # normalization forwarded to `plot2d` as `norm` (a None value is forwarded as-is)
    norm: mcolors.Normalize | None
539
+
540
+
541
def plot2dGrid(
    x: np.ndarray,
    y: np.ndarray,
    fields: dict[str, np.ndarray],
    panels: list[list[PanelDict]],
    label_pos: str | None = "title",
    label_args: dict[str, Any] | None = None,
    width: float = 10,
    dpi: int = 150,
    wspace: float = 0.05,
    hspace: float = 0.05,
    **kwargs,
):
    """
    create a new figure with a grid of 2d plots with shared axes

    args
    ----------
    x, y ........................ : 1d or 2d arrays of coordinates
    fields ...................... : dictionary of all the fields
    panels ...................... : array of array of dictionaries indicating the panels to plot (see note below)
    label_pos ['title'] ......... : position of the label ('title', 'cbar', 'text', None = no label)
    label_args [None] ........... : arguments for the label (color, fontsize, etc.), passed to `ax.set_title`,
                                    `cbar.set_label` or `ax.text` depending on `label_pos`; an optional
                                    'position' entry gives the axes-fraction coordinates used with 'text'
                                    (default (0.05, 0.95)). the passed dict is not modified.

    arguments for the figure
    ----------
    width [10] .................. : width of the figure in inches
    dpi [150] ................... : resolution of the figure [dots per inch]
    wspace [0.05] ............... : width space between the panels (as fraction of the panel width)
    hspace [0.05] ............... : height space between the panels (as fraction of the panel height)

    the rest of the args are the same as for the `plot2d`
    ----------
    force_aspect [True] ......... : force equal aspect ratio according to axes
    centering ['edge'] .......... : centering of x & y nodes for the data ('edge', 'center')
    xlim [None], ylim [None] .... : tuples of x and y limits (None = determine from x & y); also used
                                    to compute the figure height
    padx [0], pady [0] .......... : add whitespace to axes in each direction (0 = no additional space)
    cbar ['5%'] ................. : size of the colorbar in percent of x-axis (None = no colorbar)
    cbar_pad [0.05] ............. : padding of the colorbar
    **kwargs .................... : standard matplotlib kwargs passed to `ax.imshow`

    note
    ----------
    the `panels` is an `n x m` array, where `n` is the number of rows and `m` is the number of columns.
    each element of the array is a dictionary with the following keys:
        - 'label' ............... : label for the field
        - 'field' ............... : lambda function which takes the `fields` dictionary and returns the quantity to plot
        - 'cmap' ................ : colormap of the panel
        - 'norm' ................ : normalization object
    """
    import matplotlib.pyplot as plt

    assert len(panels) > 0, "no panels to plot"
    assert len(panels[0]) > 0, "no panels to plot"
    assert all(
        len(row) == len(panels[0]) for row in panels
    ), "all rows must have the same number of panels"
    assert label_pos in ["title", "cbar", "text", None], "invalid label position"

    nrows, ncols = len(panels), len(panels[0])

    # the figure height follows the aspect ratio of the displayed region;
    # missing (None) limits fall back to the NaN-aware coordinate range
    x0, x1 = __setMinMax(kwargs.get("xlim"), np.asarray(x))
    y0, y1 = __setMinMax(kwargs.get("ylim"), np.asarray(y))
    aspect = (x1 - x0) / (y1 - y0)
    height = (
        width
        * ((nrows + hspace * (nrows - 1)) / (ncols + wspace * (ncols - 1)))
        / aspect
    )

    fig = plt.figure(figsize=(width, height), dpi=dpi)
    gs = fig.add_gridspec(nrows, ncols, wspace=wspace, hspace=hspace)
    axs = [[fig.add_subplot(gs[i, j]) for j in range(ncols)] for i in range(nrows)]

    # work on a copy: popping "position" must not alter the caller's dict
    label_args = {} if label_args is None else dict(label_args)
    label_coords = label_args.pop("position", (0.05, 0.95))

    for i, (ax_row, panel_row) in enumerate(zip(axs, panels)):
        for j, (ax, panel) in enumerate(zip(ax_row, panel_row)):
            assert "field" in panel, "panel must have a 'field' key"
            field_func = panel["field"]
            assert field_func is not None, "field must be a callable function"
            cbar = plot2d(
                ax,
                x,
                y,
                field_func(fields),
                norm=panel["norm"],
                cmap=panel["cmap"],
                **kwargs,
            )

            # shared axes: keep y tick labels only in the first column and
            # x tick labels only in the bottom row
            if j != 0:
                ax.set(ylabel=None, yticklabels=[])
            if i != nrows - 1:
                ax.set(xlabel=None, xticklabels=[])

            if label_pos == "title":
                assert "label" in panel, "panel must have a 'label' key"
                if panel["label"] is not None:
                    ax.set_title(panel["label"], **label_args)
            elif label_pos == "cbar":
                if cbar is not None:
                    assert "label" in panel, "panel must have a 'label' key"
                    if panel["label"] is not None:
                        cbar.set_label(panel["label"], **label_args)
            elif label_pos == "text":
                assert "label" in panel, "panel must have a 'label' key"
                if panel["label"] is not None:
                    ax.text(
                        *label_coords,
                        s=panel["label"],
                        transform=ax.transAxes,
                        **label_args,
                    )