scitex 2.3.0__py3-none-any.whl → 2.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99) hide show
  1. scitex/ai/classification/reporters/reporter_utils/_Plotter.py +1 -1
  2. scitex/ai/plt/__init__.py +2 -2
  3. scitex/ai/plt/{_plot_conf_mat.py → _stx_conf_mat.py} +3 -3
  4. scitex/config/PriorityConfig.py +195 -0
  5. scitex/config/__init__.py +24 -0
  6. scitex/io/_save.py +125 -34
  7. scitex/io/_save_modules/_image.py +37 -20
  8. scitex/plt/__init__.py +470 -17
  9. scitex/plt/_subplots/_AxisWrapper.py +98 -50
  10. scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin.py +559 -124
  11. scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin.py +49 -8
  12. scitex/plt/_subplots/_SubplotsWrapper.py +76 -91
  13. scitex/plt/_subplots/_export_as_csv.py +127 -58
  14. scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +25 -16
  15. scitex/plt/_subplots/_export_as_csv_formatters/_format_contourf.py +54 -0
  16. scitex/plt/_subplots/_export_as_csv_formatters/_format_hexbin.py +41 -0
  17. scitex/plt/_subplots/_export_as_csv_formatters/_format_hist2d.py +41 -0
  18. scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +59 -47
  19. scitex/plt/_subplots/_export_as_csv_formatters/_format_matshow.py +42 -0
  20. scitex/plt/_subplots/_export_as_csv_formatters/_format_pie.py +42 -0
  21. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +72 -35
  22. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +1 -1
  23. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +2 -2
  24. scitex/plt/_subplots/_export_as_csv_formatters/_format_quiver.py +53 -0
  25. scitex/plt/_subplots/_export_as_csv_formatters/_format_stem.py +42 -0
  26. scitex/plt/_subplots/_export_as_csv_formatters/_format_step.py +42 -0
  27. scitex/plt/_subplots/_export_as_csv_formatters/_format_streamplot.py +48 -0
  28. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_conf_mat.py → _format_stx_conf_mat.py} +2 -2
  29. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_ecdf.py → _format_stx_ecdf.py} +2 -2
  30. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_fillv.py → _format_stx_fillv.py} +2 -2
  31. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_heatmap.py → _format_stx_heatmap.py} +2 -2
  32. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_image.py → _format_stx_image.py} +2 -2
  33. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_joyplot.py → _format_stx_joyplot.py} +2 -2
  34. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_line.py → _format_stx_line.py} +3 -3
  35. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_mean_ci.py → _format_stx_mean_ci.py} +2 -2
  36. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_mean_std.py → _format_stx_mean_std.py} +2 -2
  37. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_median_iqr.py → _format_stx_median_iqr.py} +2 -2
  38. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_raster.py → _format_stx_raster.py} +2 -2
  39. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_rectangle.py → _format_stx_rectangle.py} +1 -1
  40. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_scatter_hist.py → _format_stx_scatter_hist.py} +2 -2
  41. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_shaded_line.py → _format_stx_shaded_line.py} +2 -2
  42. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_violin.py → _format_stx_violin.py} +2 -2
  43. scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py +23 -23
  44. scitex/plt/ax/__init__.py +16 -15
  45. scitex/plt/ax/_plot/__init__.py +30 -30
  46. scitex/plt/ax/_plot/_add_fitted_line.py +65 -11
  47. scitex/plt/ax/_plot/_plot_statistical_shaded_line.py +104 -76
  48. scitex/plt/ax/_plot/{_plot_conf_mat.py → _stx_conf_mat.py} +10 -10
  49. scitex/plt/ax/_plot/_stx_ecdf.py +109 -0
  50. scitex/plt/ax/_plot/{_plot_fillv.py → _stx_fillv.py} +7 -7
  51. scitex/plt/ax/_plot/_stx_heatmap.py +366 -0
  52. scitex/plt/ax/_plot/{_plot_image.py → _stx_image.py} +1 -1
  53. scitex/plt/ax/_plot/_stx_joyplot.py +113 -0
  54. scitex/plt/ax/_plot/{_plot_raster.py → _stx_raster.py} +37 -25
  55. scitex/plt/ax/_plot/{_plot_rectangle.py → _stx_rectangle.py} +10 -9
  56. scitex/plt/ax/_plot/{_plot_scatter_hist.py → _stx_scatter_hist.py} +1 -1
  57. scitex/plt/ax/_plot/_stx_shaded_line.py +215 -0
  58. scitex/plt/ax/_plot/{_plot_violin.py → _stx_violin.py} +13 -6
  59. scitex/plt/ax/_style/__init__.py +3 -0
  60. scitex/plt/ax/_style/_style_barplot.py +13 -2
  61. scitex/plt/ax/_style/_style_boxplot.py +78 -32
  62. scitex/plt/ax/_style/_style_errorbar.py +17 -3
  63. scitex/plt/ax/_style/_style_scatter.py +17 -3
  64. scitex/plt/ax/_style/_style_violinplot.py +109 -0
  65. scitex/plt/color/_vizualize_colors.py +3 -3
  66. scitex/plt/styles/SCITEX_STYLE.yaml +104 -0
  67. scitex/plt/styles/__init__.py +57 -0
  68. scitex/plt/styles/_plot_defaults.py +209 -0
  69. scitex/plt/styles/_plot_postprocess.py +518 -0
  70. scitex/plt/styles/_style_loader.py +268 -0
  71. scitex/plt/styles/presets.py +208 -0
  72. scitex/plt/utils/_collect_figure_metadata.py +160 -18
  73. scitex/plt/utils/_colorbar.py +72 -10
  74. scitex/plt/utils/_configure_mpl.py +108 -52
  75. scitex/plt/utils/_crop.py +21 -7
  76. scitex/plt/utils/_figure_mm.py +21 -7
  77. scitex/stats/__init__.py +13 -1
  78. scitex/stats/_schema.py +578 -0
  79. scitex/stats/tests/__init__.py +13 -0
  80. scitex/stats/tests/correlation/__init__.py +13 -0
  81. scitex/stats/tests/correlation/_test_pearson.py +262 -0
  82. scitex/vis/__init__.py +6 -0
  83. scitex/vis/editor/__init__.py +23 -0
  84. scitex/vis/editor/_defaults.py +205 -0
  85. scitex/vis/editor/_edit.py +342 -0
  86. scitex/vis/editor/_mpl_editor.py +231 -0
  87. scitex/vis/editor/_tkinter_editor.py +466 -0
  88. scitex/vis/editor/_web_editor.py +1440 -0
  89. scitex/vis/model/plot_types.py +15 -15
  90. {scitex-2.3.0.dist-info → scitex-2.4.1.dist-info}/METADATA +2 -1
  91. {scitex-2.3.0.dist-info → scitex-2.4.1.dist-info}/RECORD +94 -67
  92. {scitex-2.3.0.dist-info → scitex-2.4.1.dist-info}/WHEEL +1 -1
  93. scitex/plt/ax/_plot/_plot_ecdf.py +0 -84
  94. scitex/plt/ax/_plot/_plot_heatmap.py +0 -277
  95. scitex/plt/ax/_plot/_plot_joyplot.py +0 -77
  96. scitex/plt/ax/_plot/_plot_shaded_line.py +0 -142
  97. scitex/plt/presets.py +0 -224
  98. {scitex-2.3.0.dist-info → scitex-2.4.1.dist-info}/entry_points.txt +0 -0
  99. {scitex-2.3.0.dist-info → scitex-2.4.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,366 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Timestamp: "2025-12-01 13:00:00 (ywatanabe)"
4
+ # File: ./src/scitex/plt/ax/_plot/_stx_heatmap.py
5
+
6
+ """Heatmap plotting with automatic annotation color switching."""
7
+
8
+ from typing import Any, List, Optional, Tuple, Union
9
+
10
+ import matplotlib
11
+ import matplotlib.pyplot as plt
12
+ import numpy as np
13
+ from matplotlib.axes import Axes
14
+ from matplotlib.colorbar import Colorbar
15
+ from matplotlib.image import AxesImage
16
+
17
+
18
def stx_heatmap(
    ax: Union[Axes, "AxisWrapper"],
    values_2d: np.ndarray,
    x_labels: Optional[List[str]] = None,
    y_labels: Optional[List[str]] = None,
    cmap: str = "viridis",
    cbar_label: str = "ColorBar Label",
    annot_format: str = "{x:.1f}",
    show_annot: bool = True,
    annot_color_lighter: str = "black",
    annot_color_darker: str = "white",
    **kwargs: Any,
) -> Tuple[Union[Axes, "AxisWrapper"], AxesImage, Colorbar]:
    """Plot a heatmap on the given axes with automatic annotation colors.

    Creates a heatmap visualization with optional cell annotations. Annotation
    text colors are automatically switched based on background brightness for
    optimal readability.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or AxisWrapper
        The axes to plot on.
    values_2d : array-like, shape (n_rows, n_cols)
        2D array of data to display as heatmap.
    x_labels : list of str, optional
        Labels for the x-axis (columns); length must equal n_cols.
    y_labels : list of str, optional
        Labels for the y-axis (rows); length must equal n_rows.
    cmap : str, default "viridis"
        Colormap name to use.
    cbar_label : str, default "ColorBar Label"
        Label for the colorbar.
    annot_format : str, default "{x:.1f}"
        Format string for cell annotations.
    show_annot : bool, default True
        Whether to annotate the heatmap with values.
    annot_color_lighter : str, default "black"
        Text color for annotations on lighter backgrounds.
    annot_color_darker : str, default "white"
        Text color for annotations on darker backgrounds.
    **kwargs : dict
        Additional keyword arguments forwarded to ``_mpl_heatmap`` and from
        there to ``imshow()`` (``cbar_kw`` is consumed by ``_mpl_heatmap``).

    Returns
    -------
    ax : matplotlib.axes.Axes or AxisWrapper
        The axes with the heatmap.
    im : matplotlib.image.AxesImage
        The image object created by imshow.
    cbar : matplotlib.colorbar.Colorbar
        The colorbar object.

    Examples
    --------
    >>> import numpy as np
    >>> import scitex as stx
    >>> data = np.random.rand(5, 10)
    >>> fig, ax = stx.plt.subplots()
    >>> ax, im, cbar = stx.plt.ax.stx_heatmap(
    ...     ax, data,
    ...     x_labels=[f"X{i}" for i in range(10)],
    ...     y_labels=[f"Y{i}" for i in range(5)],
    ...     cmap="Blues"
    ... )
    """
    # asanyarray lets list input reach `.shape` downstream while keeping
    # masked arrays masked.
    values_2d = np.asanyarray(values_2d)

    # _mpl_heatmap labels rows (y-axis) and columns (x-axis); pass by keyword
    # so y_labels go on the rows and x_labels on the columns.
    im, cbar = _mpl_heatmap(
        values_2d,
        row_labels=y_labels,
        col_labels=x_labels,
        ax=ax,
        cmap=cmap,
        cbarlabel=cbar_label,
        **kwargs,
    )

    if show_annot:
        textcolors = _switch_annot_colors(cmap, annot_color_lighter, annot_color_darker)
        _mpl_annotate_heatmap(
            im,
            valfmt=annot_format,
            textcolors=textcolors,
        )

    return ax, im, cbar
103
+
104
+
105
def _switch_annot_colors(
    cmap: str,
    annot_color_lighter: str,
    annot_color_darker: str,
) -> Tuple[str, str]:
    """Pick annotation text colors for the low and high ends of a colormap.

    Uses perceived brightness (ITU-R BT.709 luma coefficients) of the
    colormap sampled near its low end. Assumes the colormap's brightness
    changes monotonically, so the high end is treated as the opposite of
    the low end.

    Parameters
    ----------
    cmap : str
        Colormap name.
    annot_color_lighter : str
        Text color to use on lighter backgrounds.
    annot_color_darker : str
        Text color to use on darker backgrounds.

    Returns
    -------
    tuple of str
        ``(color_for_low_values, color_for_high_values)`` -- the order in
        which ``_mpl_annotate_heatmap`` consumes ``textcolors`` (index 0 for
        cells below the threshold, index 1 for cells above it).
    """
    # pyplot.get_cmap is available across matplotlib versions; the
    # matplotlib.cm.get_cmap function reached via plt.cm was removed in 3.9.
    cmap_obj = plt.get_cmap(cmap)

    # Sample near the low end (0.1 avoids the extreme edge color)
    red, green, blue = cmap_obj(0.1)[:3]
    low_end_brightness = 0.2126 * red + 0.7152 * green + 0.0722 * blue

    if low_end_brightness < 0.5:
        # Low values sit on dark cells, high values on light cells
        return (annot_color_darker, annot_color_lighter)
    # Low values sit on light cells, high values on dark cells
    return (annot_color_lighter, annot_color_darker)
145
+
146
+
147
def _mpl_heatmap(
    data: np.ndarray,
    row_labels: Optional[List[str]],
    col_labels: Optional[List[str]],
    ax: Optional[Axes] = None,
    cbar_kw: Optional[dict] = None,
    cbarlabel: str = "",
    **kwargs: Any,
) -> Tuple[AxesImage, Colorbar]:
    """Create a heatmap with imshow and add a colorbar.

    Parameters
    ----------
    data : array-like
        2D array of data to display (first axis = rows, second = columns).
    row_labels : list of str or None
        Labels for the rows (y-axis).
    col_labels : list of str or None
        Labels for the columns (x-axis).
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, uses current axes.
    cbar_kw : dict, optional
        Keyword arguments for colorbar creation.
    cbarlabel : str, default ""
        Label for the colorbar.
    **kwargs : dict
        Additional keyword arguments passed to imshow().

    Returns
    -------
    im : matplotlib.image.AxesImage
        The image object.
    cbar : matplotlib.colorbar.Colorbar
        The colorbar object.
    """
    from matplotlib.ticker import MaxNLocator

    # Millimetres -> points conversion factor (1 mm ~= 2.83465 pt)
    mm_to_pt = 2.83465
    line_width_pt = 0.2 * mm_to_pt  # 0.2 mm, matching axes spines
    tick_length_pt = 0.8 * mm_to_pt  # 0.8 mm tick marks

    if ax is None:
        ax = plt.gca()
    if cbar_kw is None:
        cbar_kw = {}

    # np.shape works for both ndarrays and nested lists
    n_rows, n_cols = np.shape(data)[:2]

    # Plot the heatmap
    im = ax.imshow(data, **kwargs)

    # Colorbar with label, border and tick styling matched to the axes
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")
    cbar.outline.set_linewidth(line_width_pt)
    cbar.ax.yaxis.set_major_locator(MaxNLocator(nbins=4, min_n_ticks=3))
    cbar.ax.tick_params(width=line_width_pt, length=tick_length_pt)

    # One major tick per cell, labelled with the given row/column names
    ax.set_xticks(range(n_cols), labels=col_labels)
    ax.set_yticks(range(n_rows), labels=row_labels)

    # Keep column ticks and labels below the heatmap only
    ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)

    # Show all 4 spines for heatmap
    ax.spines[:].set_visible(True)

    # Square cells (1:1)
    ax.set_aspect("equal", adjustable="box")

    # Minor ticks on cell boundaries; hide their bottom/left tick marks
    ax.set_xticks(np.arange(n_cols + 1) - 0.5, minor=True)
    ax.set_yticks(np.arange(n_rows + 1) - 0.5, minor=True)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im, cbar
228
+
229
+
230
def _calc_annot_fontsize(n_rows: int, n_cols: int) -> float:
    """Return an annotation font size (pt) that shrinks as the grid grows.

    The larger of the two grid dimensions drives a piecewise-linear scale:
    6 pt up to 5 cells, 6 -> 5 pt from 5 to 10 cells, 5 -> 4 pt from 10 to
    20 cells, then 0.05 pt less per extra cell with a 3 pt floor.

    Parameters
    ----------
    n_rows : int
        Number of rows in the heatmap.
    n_cols : int
        Number of columns in the heatmap.

    Returns
    -------
    float
        Font size in points.
    """
    largest = max(n_rows, n_cols)

    if largest > 20:
        # Very large grids: keep shrinking but never below 3 pt
        return max(3.0, 4.0 - (largest - 20) * 0.05)
    if largest > 10:
        return 5.0 - (largest - 10) * 0.1
    if largest > 5:
        return 6.0 - (largest - 5) * 0.2
    return 6.0
265
+
266
+
267
def _mpl_annotate_heatmap(
    im: AxesImage,
    data: Optional[np.ndarray] = None,
    valfmt: str = "{x:.2f}",
    textcolors: Tuple[str, str] = ("lightgray", "black"),
    threshold: Optional[float] = None,
    fontsize: Optional[float] = None,
    **textkw: Any,
) -> List:
    """Annotate a heatmap with cell values.

    Parameters
    ----------
    im : matplotlib.image.AxesImage
        The image to be annotated.
    data : list or np.ndarray, optional
        Data used to annotate. If not a list/ndarray (e.g. None), the
        image's own array is used.
    valfmt : str or callable, default "{x:.2f}"
        Format string for the annotations, or a formatter called as
        ``valfmt(value, None)``.
    textcolors : tuple of str, default ("lightgray", "black")
        Colors for annotations. First color for values at or below the
        threshold, second for values above.
    threshold : float, optional
        Threshold in data units; it is normalized with ``im.norm`` before
        comparison. If None, uses 0.7 times the normalized value of
        ``data.max()``.
    fontsize : float, optional
        Font size in points. If None, dynamically calculated based on
        cell count (6pt base, scaling down for larger heatmaps).
    **textkw : dict
        Additional keyword arguments passed to ax.text(). The text color is
        always set from ``textcolors``.

    Returns
    -------
    texts : list of matplotlib.text.Text
        The annotation text objects (masked cells are not annotated).
    """
    if isinstance(data, (list, np.ndarray)):
        # asanyarray makes lists indexable as 2-D and keeps masked arrays masked
        data = np.asanyarray(data)
    else:
        data = im.get_array()

    # Calculate dynamic font size if not specified
    if fontsize is None:
        fontsize = _calc_annot_fontsize(data.shape[0], data.shape[1])

    # Normalize the threshold to the image's color range.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        # Use 0.7 instead of 0.5 for better visibility with most colormaps
        threshold = im.norm(data.max()) * 0.7

    # Centered text by default; textkw may override alignment and size.
    kw = dict(horizontalalignment="center", verticalalignment="center", fontsize=fontsize)
    kw.update(textkw)

    # Get the formatter in case a string is supplied
    if isinstance(valfmt, str):
        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

    # One Text per cell, colored by which side of the threshold it falls on.
    texts = []
    for ii in range(data.shape[0]):
        for jj in range(data.shape[1]):
            value = data[ii, jj]
            # Masked cells have no value to print, and int() on a masked
            # comparison raises, so leave them unannotated.
            if np.ma.is_masked(value):
                continue
            kw.update(color=textcolors[int(im.norm(value) > threshold)])
            texts.append(im.axes.text(jj, ii, valfmt(value, None), **kw))

    return texts
337
+
338
+
339
if __name__ == "__main__":
    # Demo: a small random heatmap with per-cell annotations.
    data = np.random.rand(5, 5)

    # x labels name the columns, y labels name the rows
    x_labels = [f"X{ii+1}" for ii in range(data.shape[1])]
    y_labels = [f"Y{ii+1}" for ii in range(data.shape[0])]

    fig, ax = plt.subplots()

    # stx_heatmap returns (ax, im, cbar)
    ax, im, cbar = stx_heatmap(
        ax,
        data,
        x_labels=x_labels,
        y_labels=y_labels,
        show_annot=True,
        cmap="Blues",
    )

    fig.tight_layout()
    plt.show()
364
+ # EOF
365
+
366
+ # EOF
@@ -13,7 +13,7 @@ import matplotlib
13
13
  from scitex.plt.utils import assert_valid_axis
14
14
 
15
15
 
16
- def plot_image(
16
+ def stx_image(
17
17
  ax,
18
18
  arr_2d,
19
19
  cbar=True,
@@ -0,0 +1,113 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Timestamp: "2025-05-02 09:03:23 (ywatanabe)"
4
+ # File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_plot/_stx_joyplot.py
5
+ # ----------------------------------------
6
+ import os
7
+
8
# Module path metadata; the module was renamed from _plot_joyplot.py.
__FILE__ = "./src/scitex/plt/ax/_plot/_stx_joyplot.py"
__DIR__ = os.path.dirname(__FILE__)
10
+ # ----------------------------------------
11
+
12
+ import numpy as np
13
+ from scipy import stats
14
+
15
+ from ....plt.utils import assert_valid_axis
16
+
17
+
18
def stx_joyplot(ax, arrays, overlap=0.5, fill_alpha=0.7, line_alpha=1.0,
                colors=None, **kwargs):
    """
    Create a joyplot (ridgeline plot) on the provided axes.

    One Gaussian KDE is drawn per input array, each on its own vertically
    offset baseline; all ridges share one x grid and one density scale.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes to plot on
    arrays : list of array-like
        List of 1D arrays for each ridge
    overlap : float, default 0.5
        Amount of overlap between ridges (0 = no overlap, 1 = full overlap)
    fill_alpha : float, default 0.7
        Alpha for the filled KDE area
    line_alpha : float, default 1.0
        Alpha for the KDE line
    colors : list, optional
        Colors for each ridge (cycled if shorter than ``arrays``).
        If None, uses scitex palette.
    **kwargs
        Additional keyword arguments. Only ``label`` is used: the
        per-distribution sample size is appended to it and it is attached
        to the bottom ridge's line, so a legend shows a single entry.

    Returns
    -------
    matplotlib.axes.Axes
        The axes with the joyplot
    """
    assert_valid_axis(ax, "First argument must be a matplotlib axis or scitex axis wrapper")

    arrays = [np.asarray(arr) for arr in arrays]
    n_ridges = len(arrays)

    # Add sample size per distribution to label if provided (show range if variable)
    label = kwargs.get("label")
    if label and n_ridges:
        n_per_dist = [len(arr) for arr in arrays]
        n_min, n_max = min(n_per_dist), max(n_per_dist)
        n_str = str(n_min) if n_min == n_max else f"{n_min}-{n_max}"
        label = f"{label} ($n$={n_str})"

    # Import scitex colors
    from scitex.plt.color._PARAMS import HEX

    # Default colors from scitex palette
    if colors is None:
        colors = [
            HEX["blue"], HEX["red"], HEX["green"], HEX["yellow"],
            HEX["purple"], HEX["orange"], HEX["lightblue"], HEX["pink"],
        ]

    # Nothing to estimate a density from: leave the axes unchanged
    all_data = np.concatenate(arrays) if n_ridges else np.array([])
    if all_data.size == 0:
        return ax

    # Global x grid: data range padded by 10% on each side; fall back to a
    # fixed pad when every value is identical (zero range)
    x_min, x_max = np.min(all_data), np.max(all_data)
    x_range = x_max - x_min
    x_padding = x_range * 0.1 if x_range > 0 else 0.5
    x = np.linspace(x_min - x_padding, x_max + x_padding, 200)

    # Calculate KDEs and find max density for shared scaling
    kdes = []
    max_density = 0.0
    for arr in arrays:
        density = np.zeros_like(x)
        if len(arr) > 1:
            try:
                density = stats.gaussian_kde(arr)(x)
            except np.linalg.LinAlgError:
                # Zero-variance samples give a singular covariance matrix;
                # draw a flat ridge for this distribution instead of failing.
                density = np.zeros_like(x)
            max_density = max(max_density, np.max(density))
        kdes.append(density)

    # Scale factor for ridge height
    ridge_height = 1.0 / (1.0 - overlap * 0.5) if overlap < 1 else 2.0

    # Plot each ridge from back (top) to front (bottom)
    for i in range(n_ridges - 1, -1, -1):
        color = colors[i % len(colors)]
        baseline = i * (1.0 - overlap)

        # Scale density to fit nicely
        scaled_density = kdes[i] / max_density * ridge_height if max_density > 0 else kdes[i]

        # Fill
        ax.fill_between(x, baseline, baseline + scaled_density,
                        facecolor=color, edgecolor='none', alpha=fill_alpha)
        # Line on top; attach the label only once so the legend has one entry
        line_kw = {"label": label} if (label and i == 0) else {}
        ax.plot(x, baseline + scaled_density, color=color, alpha=line_alpha,
                linewidth=1.0, **line_kw)

    # Set y limits
    ax.set_ylim(-0.1, n_ridges * (1.0 - overlap) + ridge_height)

    # Hide y-axis ticks for cleaner look (joyplots typically don't show y values)
    ax.set_yticks([])

    return ax
111
+
112
+
113
+ # EOF
@@ -18,15 +18,16 @@ import pandas as pd
18
18
  from ....plt.utils import assert_valid_axis
19
19
 
20
20
 
21
- def plot_raster(
21
+ def stx_raster(
22
22
  ax,
23
- event_times,
23
+ spike_times_list,
24
24
  time=None,
25
25
  labels=None,
26
26
  colors=None,
27
27
  orientation="horizontal",
28
28
  y_offset=None,
29
29
  lineoffsets=None,
30
+ linelengths=None,
30
31
  apply_set_n_ticks=True,
31
32
  n_xticks=4,
32
33
  n_yticks=None,
@@ -39,8 +40,8 @@ def plot_raster(
39
40
  ----------
40
41
  ax : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper
41
42
  The axes on which to draw the raster plot.
42
- event_times : Array-like or list of lists
43
- Time points of events by channels/trials
43
+ spike_times_list : list of array-like, shape (n_trials,) where each element is (n_spikes,)
44
+ List of spike/event time arrays, one per trial/channel
44
45
  time : array-like, optional
45
46
  The time indices for the events (default: np.linspace(0, max(event_times))).
46
47
  labels : list, optional
@@ -50,9 +51,11 @@ def plot_raster(
50
51
  orientation: str, optional
51
52
  Orientation of raster plot (default: horizontal).
52
53
  y_offset : float, optional
53
- Vertical spacing between trials/channels.
54
+ Vertical spacing between trials/channels (default: 1.0).
54
55
  lineoffsets : array-like, optional
55
56
  Y-positions for each trial/channel (overrides automatic positioning).
57
+ linelengths : float, optional
58
+ Height of each spike mark (default: 0.8, slightly less than y_offset to prevent overlap).
56
59
  apply_set_n_ticks : bool, optional
57
60
  Whether to apply set_n_ticks for cleaner axis (default: True).
58
61
  n_xticks : int, optional
@@ -71,37 +74,46 @@ def plot_raster(
71
74
  """
72
75
  assert_valid_axis(ax, "First argument must be a matplotlib axis or scitex axis wrapper")
73
76
 
74
- # Format event_times data
75
- event_times_list = _ensure_list(event_times)
77
+ # Format spike_times_list data
78
+ spike_times_list = _ensure_list(spike_times_list)
79
+
80
+ # Add sample size (number of trials) to label if provided
81
+ if kwargs.get("label"):
82
+ n_trials = len(spike_times_list)
83
+ kwargs["label"] = f"{kwargs['label']} ($n$={n_trials})"
76
84
 
77
85
  # Handle colors and labels
78
- colors = _handle_colors(colors, event_times_list)
79
-
86
+ colors = _handle_colors(colors, spike_times_list)
87
+
80
88
  # Handle lineoffsets for positioning between trials/channels
89
+ if y_offset is None:
90
+ y_offset = 1.0 # Default spacing
81
91
  if lineoffsets is None:
82
- if y_offset is None:
83
- y_offset = 1.0 # Default spacing
84
- lineoffsets = np.arange(len(event_times_list)) * y_offset
85
-
86
- # Ensure lineoffsets is iterable and matches event_times_list length
92
+ lineoffsets = np.arange(len(spike_times_list)) * y_offset
93
+
94
+ # Set linelengths to prevent overlap (80% of y_offset by default)
95
+ if linelengths is None:
96
+ linelengths = y_offset * 0.8
97
+
98
+ # Ensure lineoffsets is iterable and matches spike_times_list length
87
99
  if np.isscalar(lineoffsets):
88
100
  lineoffsets = [lineoffsets]
89
- if len(lineoffsets) < len(event_times_list):
90
- lineoffsets = list(lineoffsets) + list(range(len(lineoffsets), len(event_times_list)))
101
+ if len(lineoffsets) < len(spike_times_list):
102
+ lineoffsets = list(lineoffsets) + list(range(len(lineoffsets), len(spike_times_list)))
91
103
 
92
- # Plotting as eventplot using event_times_list with proper positioning
93
- for ii, (pos, color, offset) in enumerate(zip(event_times_list, colors, lineoffsets)):
104
+ # Plotting as eventplot using spike_times_list with proper positioning
105
+ for ii, (pos, color, offset) in enumerate(zip(spike_times_list, colors, lineoffsets)):
94
106
  label = _define_label(labels, ii)
95
- ax.eventplot(pos, lineoffsets=offset, orientation=orientation,
96
- colors=color, label=label, **kwargs)
107
+ ax.eventplot(pos, lineoffsets=offset, linelengths=linelengths,
108
+ orientation=orientation, colors=color, label=label, **kwargs)
97
109
 
98
110
  # Apply set_n_ticks for cleaner axes if requested
99
111
  if apply_set_n_ticks:
100
112
  from scitex.plt.ax._style._set_n_ticks import set_n_ticks
101
-
113
+
102
114
  # For categorical y-axis (trials/channels), use appropriate tick count
103
115
  if n_yticks is None:
104
- n_yticks = min(len(event_times_list), 8) # Max 8 ticks for readability
116
+ n_yticks = min(len(spike_times_list), 8) # Max 8 ticks for readability
105
117
 
106
118
  # Only apply if we have reasonable numeric ranges
107
119
  try:
@@ -124,10 +136,10 @@ def plot_raster(
124
136
  if labels is not None:
125
137
  ax.legend()
126
138
 
127
- # Return event_times in a useful format
128
- event_times_digital_df = _event_times_to_digital_df(event_times_list, time, lineoffsets)
139
+ # Return spike_times in a useful format
140
+ spike_times_digital_df = _event_times_to_digital_df(spike_times_list, time, lineoffsets)
129
141
 
130
- return ax, event_times_digital_df
142
+ return ax, spike_times_digital_df
131
143
 
132
144
 
133
145
  def _ensure_list(event_times):
@@ -12,11 +12,12 @@ __DIR__ = os.path.dirname(__FILE__)
12
12
  from matplotlib.patches import Rectangle
13
13
 
14
14
 
15
- def plot_rectangle(ax, xx, yy, ww, hh, **kwargs):
15
+ def stx_rectangle(ax, xx, yy, ww, hh, **kwargs):
16
16
  """Add a rectangle patch to an axes.
17
17
 
18
18
  Convenience function for adding rectangular patches to plots, useful for
19
19
  highlighting regions, creating box annotations, or drawing geometric shapes.
20
+ By default, rectangles have no edge (border) for cleaner publication figures.
20
21
 
21
22
  Parameters
22
23
  ----------
@@ -34,7 +35,7 @@ def plot_rectangle(ax, xx, yy, ww, hh, **kwargs):
34
35
  Additional keyword arguments passed to matplotlib.patches.Rectangle.
35
36
  Common options include:
36
37
  - facecolor/fc : fill color
37
- - edgecolor/ec : edge color
38
+ - edgecolor/ec : edge color (default: 'none')
38
39
  - linewidth/lw : edge line width
39
40
  - alpha : transparency (0-1)
40
41
  - linestyle/ls : edge line style
@@ -48,20 +49,20 @@ def plot_rectangle(ax, xx, yy, ww, hh, **kwargs):
48
49
  --------
49
50
  >>> fig, ax = plt.subplots()
50
51
  >>> ax.plot([0, 10], [0, 10])
51
- >>> # Highlight a region
52
- >>> plot_rectangle(ax, 2, 3, 4, 3, facecolor='yellow', alpha=0.3)
52
+ >>> # Highlight a region (no border by default)
53
+ >>> stx_rectangle(ax, 2, 3, 4, 3, facecolor='yellow', alpha=0.3)
53
54
 
54
- >>> # Draw a box annotation
55
- >>> plot_rectangle(ax, 5, 5, 2, 2, facecolor='none', edgecolor='red', linewidth=2)
56
-
57
- >>> # Create a filled rectangle
58
- >>> plot_rectangle(ax, 0, 0, 1, 1, facecolor='blue', edgecolor='black')
55
+ >>> # Draw a box with explicit edge
56
+ >>> stx_rectangle(ax, 5, 5, 2, 2, facecolor='none', edgecolor='red', linewidth=2)
59
57
 
60
58
  See Also
61
59
  --------
62
60
  matplotlib.patches.Rectangle : The underlying Rectangle class
63
61
  matplotlib.axes.Axes.add_patch : Method used to add the patch
64
62
  """
63
+ # Default to no edge for cleaner publication figures
64
+ if 'edgecolor' not in kwargs and 'ec' not in kwargs:
65
+ kwargs['edgecolor'] = 'none'
65
66
  ax.add_patch(Rectangle((xx, yy), ww, hh, **kwargs))
66
67
  return ax
67
68
 
@@ -12,7 +12,7 @@ __DIR__ = os.path.dirname(__FILE__)
12
12
  import numpy as np
13
13
 
14
14
 
15
- def plot_scatter_hist(
15
+ def stx_scatter_hist(
16
16
  ax,
17
17
  x,
18
18
  y,