scitex 2.3.0__py3-none-any.whl → 2.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scitex/ai/classification/reporters/reporter_utils/_Plotter.py +1 -1
- scitex/ai/plt/__init__.py +2 -2
- scitex/ai/plt/{_plot_conf_mat.py → _stx_conf_mat.py} +3 -3
- scitex/config/PriorityConfig.py +195 -0
- scitex/config/__init__.py +24 -0
- scitex/io/_save.py +125 -34
- scitex/io/_save_modules/_image.py +37 -20
- scitex/plt/__init__.py +470 -17
- scitex/plt/_subplots/_AxisWrapper.py +98 -50
- scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin.py +559 -124
- scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin.py +49 -8
- scitex/plt/_subplots/_SubplotsWrapper.py +76 -91
- scitex/plt/_subplots/_export_as_csv.py +127 -58
- scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +25 -16
- scitex/plt/_subplots/_export_as_csv_formatters/_format_contourf.py +54 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_hexbin.py +41 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_hist2d.py +41 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +59 -47
- scitex/plt/_subplots/_export_as_csv_formatters/_format_matshow.py +42 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_pie.py +42 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +72 -35
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +1 -1
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +2 -2
- scitex/plt/_subplots/_export_as_csv_formatters/_format_quiver.py +53 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_stem.py +42 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_step.py +42 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_streamplot.py +48 -0
- scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_conf_mat.py → _format_stx_conf_mat.py} +2 -2
- scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_ecdf.py → _format_stx_ecdf.py} +2 -2
- scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_fillv.py → _format_stx_fillv.py} +2 -2
- scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_heatmap.py → _format_stx_heatmap.py} +2 -2
- scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_image.py → _format_stx_image.py} +2 -2
- scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_joyplot.py → _format_stx_joyplot.py} +2 -2
- scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_line.py → _format_stx_line.py} +3 -3
- scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_mean_ci.py → _format_stx_mean_ci.py} +2 -2
- scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_mean_std.py → _format_stx_mean_std.py} +2 -2
- scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_median_iqr.py → _format_stx_median_iqr.py} +2 -2
- scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_raster.py → _format_stx_raster.py} +2 -2
- scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_rectangle.py → _format_stx_rectangle.py} +1 -1
- scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_scatter_hist.py → _format_stx_scatter_hist.py} +2 -2
- scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_shaded_line.py → _format_stx_shaded_line.py} +2 -2
- scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_violin.py → _format_stx_violin.py} +2 -2
- scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py +23 -23
- scitex/plt/ax/__init__.py +16 -15
- scitex/plt/ax/_plot/__init__.py +30 -30
- scitex/plt/ax/_plot/_add_fitted_line.py +65 -11
- scitex/plt/ax/_plot/_plot_statistical_shaded_line.py +104 -76
- scitex/plt/ax/_plot/{_plot_conf_mat.py → _stx_conf_mat.py} +10 -10
- scitex/plt/ax/_plot/_stx_ecdf.py +109 -0
- scitex/plt/ax/_plot/{_plot_fillv.py → _stx_fillv.py} +7 -7
- scitex/plt/ax/_plot/_stx_heatmap.py +366 -0
- scitex/plt/ax/_plot/{_plot_image.py → _stx_image.py} +1 -1
- scitex/plt/ax/_plot/_stx_joyplot.py +113 -0
- scitex/plt/ax/_plot/{_plot_raster.py → _stx_raster.py} +37 -25
- scitex/plt/ax/_plot/{_plot_rectangle.py → _stx_rectangle.py} +10 -9
- scitex/plt/ax/_plot/{_plot_scatter_hist.py → _stx_scatter_hist.py} +1 -1
- scitex/plt/ax/_plot/_stx_shaded_line.py +215 -0
- scitex/plt/ax/_plot/{_plot_violin.py → _stx_violin.py} +13 -6
- scitex/plt/ax/_style/__init__.py +3 -0
- scitex/plt/ax/_style/_style_barplot.py +13 -2
- scitex/plt/ax/_style/_style_boxplot.py +78 -32
- scitex/plt/ax/_style/_style_errorbar.py +17 -3
- scitex/plt/ax/_style/_style_scatter.py +17 -3
- scitex/plt/ax/_style/_style_violinplot.py +109 -0
- scitex/plt/color/_vizualize_colors.py +3 -3
- scitex/plt/styles/SCITEX_STYLE.yaml +104 -0
- scitex/plt/styles/__init__.py +57 -0
- scitex/plt/styles/_plot_defaults.py +209 -0
- scitex/plt/styles/_plot_postprocess.py +518 -0
- scitex/plt/styles/_style_loader.py +268 -0
- scitex/plt/styles/presets.py +208 -0
- scitex/plt/utils/_collect_figure_metadata.py +160 -18
- scitex/plt/utils/_colorbar.py +72 -10
- scitex/plt/utils/_configure_mpl.py +108 -52
- scitex/plt/utils/_crop.py +21 -7
- scitex/plt/utils/_figure_mm.py +21 -7
- scitex/stats/__init__.py +13 -1
- scitex/stats/_schema.py +578 -0
- scitex/stats/tests/__init__.py +13 -0
- scitex/stats/tests/correlation/__init__.py +13 -0
- scitex/stats/tests/correlation/_test_pearson.py +262 -0
- scitex/vis/__init__.py +6 -0
- scitex/vis/editor/__init__.py +23 -0
- scitex/vis/editor/_defaults.py +205 -0
- scitex/vis/editor/_edit.py +342 -0
- scitex/vis/editor/_mpl_editor.py +231 -0
- scitex/vis/editor/_tkinter_editor.py +466 -0
- scitex/vis/editor/_web_editor.py +1440 -0
- scitex/vis/model/plot_types.py +15 -15
- {scitex-2.3.0.dist-info → scitex-2.4.1.dist-info}/METADATA +2 -1
- {scitex-2.3.0.dist-info → scitex-2.4.1.dist-info}/RECORD +94 -67
- {scitex-2.3.0.dist-info → scitex-2.4.1.dist-info}/WHEEL +1 -1
- scitex/plt/ax/_plot/_plot_ecdf.py +0 -84
- scitex/plt/ax/_plot/_plot_heatmap.py +0 -277
- scitex/plt/ax/_plot/_plot_joyplot.py +0 -77
- scitex/plt/ax/_plot/_plot_shaded_line.py +0 -142
- scitex/plt/presets.py +0 -224
- {scitex-2.3.0.dist-info → scitex-2.4.1.dist-info}/entry_points.txt +0 -0
- {scitex-2.3.0.dist-info → scitex-2.4.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,366 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Timestamp: "2025-12-01 13:00:00 (ywatanabe)"
|
|
4
|
+
# File: ./src/scitex/plt/ax/_plot/_stx_heatmap.py
|
|
5
|
+
|
|
6
|
+
"""Heatmap plotting with automatic annotation color switching."""
|
|
7
|
+
|
|
8
|
+
from typing import Any, List, Optional, Tuple, Union
|
|
9
|
+
|
|
10
|
+
import matplotlib
|
|
11
|
+
import matplotlib.pyplot as plt
|
|
12
|
+
import numpy as np
|
|
13
|
+
from matplotlib.axes import Axes
|
|
14
|
+
from matplotlib.colorbar import Colorbar
|
|
15
|
+
from matplotlib.image import AxesImage
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def stx_heatmap(
    ax: Union[Axes, "AxisWrapper"],
    values_2d: np.ndarray,
    x_labels: Optional[List[str]] = None,
    y_labels: Optional[List[str]] = None,
    cmap: str = "viridis",
    cbar_label: str = "ColorBar Label",
    annot_format: str = "{x:.1f}",
    show_annot: bool = True,
    annot_color_lighter: str = "black",
    annot_color_darker: str = "white",
    **kwargs: Any,
) -> Tuple[Union[Axes, "AxisWrapper"], AxesImage, Colorbar]:
    """Plot a heatmap on the given axes with automatic annotation colors.

    Creates a heatmap visualization (``imshow`` + colorbar) with optional
    cell annotations. Annotation text colors are chosen from the colormap's
    brightness so that text stays readable on both ends of the color scale.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or AxisWrapper
        The axes to plot on.
    values_2d : np.ndarray, shape (n_rows, n_cols)
        2D array of data to display as heatmap.
    x_labels : list of str, optional
        Labels for the x-axis, one per column (``n_cols`` entries).
    y_labels : list of str, optional
        Labels for the y-axis, one per row (``n_rows`` entries).
    cmap : str, default "viridis"
        Colormap name to use.
    cbar_label : str, default "ColorBar Label"
        Label for the colorbar.
    annot_format : str, default "{x:.1f}"
        ``str.format``-style format string for cell annotations (the value
        is bound to ``x``).
    show_annot : bool, default True
        Whether to annotate the heatmap with values.
    annot_color_lighter : str, default "black"
        Text color for annotations on lighter backgrounds.
    annot_color_darker : str, default "white"
        Text color for annotations on darker backgrounds.
    **kwargs : dict
        Additional keyword arguments passed to ``imshow()`` (e.g. ``vmin``,
        ``vmax``, ``interpolation``). A ``cbar_kw`` dict, if given, is used
        for the colorbar call instead of being sent to ``imshow``.

    Returns
    -------
    ax : matplotlib.axes.Axes or AxisWrapper
        The axes with the heatmap.
    im : matplotlib.image.AxesImage
        The image object created by imshow.
    cbar : matplotlib.colorbar.Colorbar
        The colorbar object.

    Examples
    --------
    >>> import numpy as np
    >>> import scitex as stx
    >>> data = np.random.rand(5, 10)
    >>> fig, ax = stx.plt.subplots()
    >>> ax, im, cbar = stx.plt.ax.stx_heatmap(
    ...     ax, data,
    ...     x_labels=[f"X{i}" for i in range(10)],
    ...     y_labels=[f"Y{i}" for i in range(5)],
    ...     cmap="Blues"
    ... )
    """
    # x_labels name the columns and y_labels the rows. Pass them by keyword so
    # they cannot be swapped against _mpl_heatmap's (row_labels, col_labels)
    # positional order.
    im, cbar = _mpl_heatmap(
        values_2d,
        row_labels=y_labels,
        col_labels=x_labels,
        ax=ax,
        cmap=cmap,
        cbarlabel=cbar_label,
        **kwargs,
    )

    if show_annot:
        textcolors = _switch_annot_colors(
            cmap, annot_color_lighter, annot_color_darker
        )
        _mpl_annotate_heatmap(im, valfmt=annot_format, textcolors=textcolors)

    return ax, im, cbar
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _switch_annot_colors(
|
|
106
|
+
cmap: str,
|
|
107
|
+
annot_color_lighter: str,
|
|
108
|
+
annot_color_darker: str,
|
|
109
|
+
) -> Tuple[str, str]:
|
|
110
|
+
"""Determine annotation text colors based on colormap brightness.
|
|
111
|
+
|
|
112
|
+
Uses perceived brightness (ITU-R BT.709) to select appropriate text
|
|
113
|
+
colors for light vs dark backgrounds in the colormap.
|
|
114
|
+
|
|
115
|
+
Parameters
|
|
116
|
+
----------
|
|
117
|
+
cmap : str
|
|
118
|
+
Colormap name.
|
|
119
|
+
annot_color_lighter : str
|
|
120
|
+
Color to use on lighter backgrounds.
|
|
121
|
+
annot_color_darker : str
|
|
122
|
+
Color to use on darker backgrounds.
|
|
123
|
+
|
|
124
|
+
Returns
|
|
125
|
+
-------
|
|
126
|
+
tuple of str
|
|
127
|
+
(color_for_dark_bg, color_for_light_bg) text colors.
|
|
128
|
+
"""
|
|
129
|
+
cmap_obj = plt.cm.get_cmap(cmap)
|
|
130
|
+
|
|
131
|
+
# Sample colormap at extremes (avoiding edge effects)
|
|
132
|
+
dark_color = cmap_obj(0.1)
|
|
133
|
+
light_color = cmap_obj(0.9)
|
|
134
|
+
|
|
135
|
+
# Calculate perceived brightness using ITU-R BT.709 coefficients
|
|
136
|
+
dark_brightness = (
|
|
137
|
+
0.2126 * dark_color[0] + 0.7152 * dark_color[1] + 0.0722 * dark_color[2]
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
# Choose text colors based on background brightness
|
|
141
|
+
if dark_brightness < 0.5:
|
|
142
|
+
return (annot_color_lighter, annot_color_darker)
|
|
143
|
+
else:
|
|
144
|
+
return (annot_color_darker, annot_color_lighter)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _mpl_heatmap(
    data: np.ndarray,
    row_labels: Optional[List[str]],
    col_labels: Optional[List[str]],
    ax: Optional[Axes] = None,
    cbar_kw: Optional[dict] = None,
    cbarlabel: str = "",
    **kwargs: Any,
) -> Tuple[AxesImage, Colorbar]:
    """Create a heatmap with imshow and add a styled colorbar.

    Parameters
    ----------
    data : np.ndarray or nested sequence
        2D array of data to display.
    row_labels : list of str or None
        Labels for the rows (y-axis), one per row.
    col_labels : list of str or None
        Labels for the columns (x-axis), one per column.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, uses current axes.
    cbar_kw : dict, optional
        Keyword arguments for colorbar creation.
    cbarlabel : str, default ""
        Label for the colorbar.
    **kwargs : dict
        Additional keyword arguments passed to imshow().

    Returns
    -------
    im : matplotlib.image.AxesImage
        The image object.
    cbar : matplotlib.colorbar.Colorbar
        The colorbar object.
    """
    from matplotlib.ticker import MaxNLocator

    # Points per millimetre, used to express line widths/lengths in mm.
    _PT_PER_MM = 2.83465

    if ax is None:
        ax = plt.gca()
    if cbar_kw is None:
        cbar_kw = {}

    # np.shape also works for nested lists, unlike data.shape.
    n_rows, n_cols = np.shape(data)[:2]

    im = ax.imshow(data, **kwargs)

    # Colorbar with a rotated label, 0.2 mm outline to match the axes spines,
    # a small number of ticks, and tick marks styled in mm.
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")
    cbar.outline.set_linewidth(0.2 * _PT_PER_MM)
    cbar.ax.yaxis.set_major_locator(MaxNLocator(nbins=4, min_n_ticks=3))
    cbar.ax.tick_params(width=0.2 * _PT_PER_MM, length=0.8 * _PT_PER_MM)

    # One major tick per cell, labelled with the given lists (None -> numbers).
    ax.set_xticks(range(n_cols), labels=col_labels)
    ax.set_yticks(range(n_rows), labels=row_labels)

    # Keep column labels below the heatmap (bottom ticks/labels only).
    ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)

    # Frame the heatmap with all four spines and keep cells square.
    ax.spines[:].set_visible(True)
    ax.set_aspect("equal", adjustable="box")

    # Minor ticks on the cell boundaries, with their tick marks hidden.
    ax.set_xticks(np.arange(n_cols + 1) - 0.5, minor=True)
    ax.set_yticks(np.arange(n_rows + 1) - 0.5, minor=True)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im, cbar
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def _calc_annot_fontsize(n_rows: int, n_cols: int) -> float:
|
|
231
|
+
"""Calculate dynamic annotation font size based on cell count.
|
|
232
|
+
|
|
233
|
+
Uses a base size of 6pt for small heatmaps and scales down for larger ones.
|
|
234
|
+
|
|
235
|
+
Parameters
|
|
236
|
+
----------
|
|
237
|
+
n_rows : int
|
|
238
|
+
Number of rows in the heatmap.
|
|
239
|
+
n_cols : int
|
|
240
|
+
Number of columns in the heatmap.
|
|
241
|
+
|
|
242
|
+
Returns
|
|
243
|
+
-------
|
|
244
|
+
float
|
|
245
|
+
Font size in points.
|
|
246
|
+
"""
|
|
247
|
+
# Base font size for small heatmaps (e.g., 5x5)
|
|
248
|
+
BASE_FONTSIZE = 6.0
|
|
249
|
+
BASE_CELLS = 5 # Reference dimension
|
|
250
|
+
|
|
251
|
+
# Use the larger dimension to scale
|
|
252
|
+
max_dim = max(n_rows, n_cols)
|
|
253
|
+
|
|
254
|
+
if max_dim <= BASE_CELLS:
|
|
255
|
+
return BASE_FONTSIZE
|
|
256
|
+
elif max_dim <= 10:
|
|
257
|
+
# Linear interpolation: 6pt at 5 cells, 5pt at 10 cells
|
|
258
|
+
return BASE_FONTSIZE - (max_dim - BASE_CELLS) * 0.2
|
|
259
|
+
elif max_dim <= 20:
|
|
260
|
+
# 5pt at 10 cells, 4pt at 20 cells
|
|
261
|
+
return 5.0 - (max_dim - 10) * 0.1
|
|
262
|
+
else:
|
|
263
|
+
# Minimum 3pt for very large heatmaps
|
|
264
|
+
return max(3.0, 4.0 - (max_dim - 20) * 0.05)
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def _mpl_annotate_heatmap(
    im: AxesImage,
    data: Optional[np.ndarray] = None,
    valfmt: str = "{x:.2f}",
    textcolors: Tuple[str, str] = ("lightgray", "black"),
    threshold: Optional[float] = None,
    fontsize: Optional[float] = None,
    **textkw: Any,
) -> List:
    """Annotate a heatmap with cell values.

    Parameters
    ----------
    im : matplotlib.image.AxesImage
        The image to be annotated.
    data : array-like, optional
        2D data used for the annotations (nested lists are accepted). If
        None, uses the image's array.
    valfmt : str or callable, default "{x:.2f}"
        ``str.format``-style format string (value bound to ``x``), or a
        matplotlib Formatter / callable taking ``(value, pos)``.
    textcolors : tuple of str, default ("lightgray", "black")
        Annotation colors: the first for cells whose normalized value is at
        or below the threshold, the second for cells above it.
    threshold : float, optional
        Threshold in *data units*; it is normalized with ``im.norm`` before
        being compared to each normalized cell value. If None, uses 0.7
        times the normalized maximum of the data.
    fontsize : float, optional
        Font size in points. If None, dynamically calculated based on
        cell count (6pt base, scaling down for larger heatmaps).
    **textkw : dict
        Additional keyword arguments passed to ax.text(). ``color`` is
        always overridden per cell by ``textcolors``.

    Returns
    -------
    texts : list of matplotlib.text.Text
        The annotation text objects. Masked cells are not annotated.
    """
    if data is None:
        data = im.get_array()
    # asanyarray turns nested lists into an array (so .shape and 2D
    # indexing work) while keeping a MaskedArray masked.
    data = np.asanyarray(data)
    n_rows, n_cols = data.shape[:2]

    if fontsize is None:
        fontsize = _calc_annot_fontsize(n_rows, n_cols)

    # Bring the threshold into the image's normalized color range.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        # 0.7 rather than 0.5 gives better contrast with most colormaps.
        threshold = im.norm(data.max()) * 0.7

    # Centered text by default; textkw may override these.
    kw = dict(
        horizontalalignment="center",
        verticalalignment="center",
        fontsize=fontsize,
    )
    kw.update(textkw)

    if isinstance(valfmt, str):
        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

    # One Text per cell, colored by whether the cell is above the threshold.
    texts = []
    for ii in range(n_rows):
        for jj in range(n_cols):
            value = data[ii, jj]
            if value is np.ma.masked:
                continue  # no meaningful value to print for masked cells
            kw.update(color=textcolors[int(im.norm(value) > threshold)])
            texts.append(im.axes.text(jj, ii, valfmt(value, None), **kw))

    return texts
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
if __name__ == "__main__":
    # Demo: annotated heatmap of random data with per-row/per-column labels.
    data = np.random.rand(6, 6)
    x_labels = [f"X{ii + 1}" for ii in range(data.shape[1])]  # one per column
    y_labels = [f"Y{ii + 1}" for ii in range(data.shape[0])]  # one per row

    fig, ax = plt.subplots()

    # stx_heatmap returns (ax, im, cbar).
    ax, im, cbar = stx_heatmap(
        ax,
        data,
        x_labels=x_labels,
        y_labels=y_labels,
        show_annot=True,
        cmap="Blues",
    )

    fig.tight_layout()
    plt.show()
|
|
364
|
+
# EOF
|
|
365
|
+
|
|
366
|
+
# EOF
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Timestamp: "2025-05-02 09:03:23 (ywatanabe)"
|
|
4
|
+
# File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_plot/_stx_joyplot.py
|
|
5
|
+
# ----------------------------------------
|
|
6
|
+
import os
|
|
7
|
+
|
|
8
|
+
# Module path constants; the file was renamed from _plot_joyplot.py.
__FILE__ = "./src/scitex/plt/ax/_plot/_stx_joyplot.py"
__DIR__ = os.path.dirname(__FILE__)
|
|
10
|
+
# ----------------------------------------
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
from scipy import stats
|
|
14
|
+
|
|
15
|
+
from ....plt.utils import assert_valid_axis
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def stx_joyplot(ax, arrays, overlap=0.5, fill_alpha=0.7, line_alpha=1.0,
                colors=None, **kwargs):
    """
    Create a joyplot (ridgeline plot) on the provided axes.

    One Gaussian KDE is drawn per input array, stacked vertically with the
    first array at the bottom. All ridges share a single x grid (the global
    data range padded by 10%) and are scaled by a common maximum density,
    so ridge heights are comparable across distributions.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or scitex axis wrapper
        The axes to plot on
    arrays : iterable of array-like
        1D arrays, one per ridge. Arrays with fewer than two values, or whose
        KDE is singular (e.g. all values identical), are drawn as flat ridges.
    overlap : float, default 0.5
        Amount of overlap between ridges (0 = no overlap, 1 = full overlap)
    fill_alpha : float, default 0.7
        Alpha for the filled KDE area
    line_alpha : float, default 1.0
        Alpha for the KDE line
    colors : list, optional
        Colors for each ridge, cycled if shorter than ``arrays``. If None or
        empty, uses the scitex palette.
    **kwargs
        Only ``label`` is used. If given, it is suffixed with the sample
        size per distribution (a range when sizes differ) and attached to
        the bottom ridge's filled area so it shows up in a legend. Other
        keys are ignored.

    Returns
    -------
    matplotlib.axes.Axes
        The axes with the joyplot

    Raises
    ------
    ValueError
        If ``arrays`` is empty or contains no data points at all.
    """
    assert_valid_axis(ax, "First argument must be a matplotlib axis or scitex axis wrapper")

    # Materialize once: the inputs are traversed several times below.
    arrays = [np.asarray(arr) for arr in arrays]
    if len(arrays) == 0:
        raise ValueError("arrays must contain at least one array")
    all_data = np.concatenate(arrays)
    if all_data.size == 0:
        raise ValueError("arrays contain no data points")
    n_ridges = len(arrays)

    # Add sample size per distribution to label if provided (range if variable)
    label = kwargs.get("label")
    if label:
        n_per_dist = [len(arr) for arr in arrays]
        n_min, n_max = min(n_per_dist), max(n_per_dist)
        n_str = str(n_min) if n_min == n_max else f"{n_min}-{n_max}"
        label = f"{label} ($n$={n_str})"

    # Default colors from scitex palette
    if colors is None or len(colors) == 0:
        from scitex.plt.color._PARAMS import HEX

        colors = [
            HEX["blue"], HEX["red"], HEX["green"], HEX["yellow"],
            HEX["purple"], HEX["orange"], HEX["lightblue"], HEX["pink"],
        ]

    # Global x grid. Constant data has zero range, so fall back to a fixed
    # padding instead of a degenerate grid.
    x_min, x_max = np.min(all_data), np.max(all_data)
    x_range = x_max - x_min
    x_padding = x_range * 0.1 if x_range > 0 else 1.0
    x = np.linspace(x_min - x_padding, x_max + x_padding, 200)

    # KDE per ridge; flat when it cannot be estimated.
    kdes = []
    for arr in arrays:
        density = np.zeros_like(x)
        if len(arr) > 1:
            try:
                density = stats.gaussian_kde(arr)(x)
            except np.linalg.LinAlgError:
                # Singular covariance (e.g. identical values): keep it flat.
                pass
        kdes.append(density)
    max_density = max(float(np.max(density)) for density in kdes)

    # Scale factor for ridge height
    ridge_height = 1.0 / (1.0 - overlap * 0.5) if overlap < 1 else 2.0

    # Draw from the top ridge down so lower ridges are painted in front.
    for i in range(n_ridges - 1, -1, -1):
        color = colors[i % len(colors)]
        baseline = i * (1.0 - overlap)

        if max_density > 0:
            scaled_density = kdes[i] / max_density * ridge_height
        else:
            scaled_density = kdes[i]

        # Attach the legend label once, to the bottom (front-most) ridge.
        fill_kw = {"label": label} if (label and i == 0) else {}
        ax.fill_between(x, baseline, baseline + scaled_density,
                        facecolor=color, edgecolor='none', alpha=fill_alpha,
                        **fill_kw)
        # Outline on top of the fill
        ax.plot(x, baseline + scaled_density, color=color, alpha=line_alpha,
                linewidth=1.0)

    ax.set_ylim(-0.1, n_ridges * (1.0 - overlap) + ridge_height)

    # Joyplot y values are not meaningful, so hide the y ticks.
    ax.set_yticks([])

    return ax
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
# EOF
|
|
@@ -18,15 +18,16 @@ import pandas as pd
|
|
|
18
18
|
from ....plt.utils import assert_valid_axis
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
def
|
|
21
|
+
def stx_raster(
|
|
22
22
|
ax,
|
|
23
|
-
|
|
23
|
+
spike_times_list,
|
|
24
24
|
time=None,
|
|
25
25
|
labels=None,
|
|
26
26
|
colors=None,
|
|
27
27
|
orientation="horizontal",
|
|
28
28
|
y_offset=None,
|
|
29
29
|
lineoffsets=None,
|
|
30
|
+
linelengths=None,
|
|
30
31
|
apply_set_n_ticks=True,
|
|
31
32
|
n_xticks=4,
|
|
32
33
|
n_yticks=None,
|
|
@@ -39,8 +40,8 @@ def plot_raster(
|
|
|
39
40
|
----------
|
|
40
41
|
ax : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper
|
|
41
42
|
The axes on which to draw the raster plot.
|
|
42
|
-
|
|
43
|
-
|
|
43
|
+
spike_times_list : list of array-like, shape (n_trials,) where each element is (n_spikes,)
|
|
44
|
+
List of spike/event time arrays, one per trial/channel
|
|
44
45
|
time : array-like, optional
|
|
45
46
|
The time indices for the events (default: np.linspace(0, max(event_times))).
|
|
46
47
|
labels : list, optional
|
|
@@ -50,9 +51,11 @@ def plot_raster(
|
|
|
50
51
|
orientation: str, optional
|
|
51
52
|
Orientation of raster plot (default: horizontal).
|
|
52
53
|
y_offset : float, optional
|
|
53
|
-
Vertical spacing between trials/channels.
|
|
54
|
+
Vertical spacing between trials/channels (default: 1.0).
|
|
54
55
|
lineoffsets : array-like, optional
|
|
55
56
|
Y-positions for each trial/channel (overrides automatic positioning).
|
|
57
|
+
linelengths : float, optional
|
|
58
|
+
Height of each spike mark (default: 0.8, slightly less than y_offset to prevent overlap).
|
|
56
59
|
apply_set_n_ticks : bool, optional
|
|
57
60
|
Whether to apply set_n_ticks for cleaner axis (default: True).
|
|
58
61
|
n_xticks : int, optional
|
|
@@ -71,37 +74,46 @@ def plot_raster(
|
|
|
71
74
|
"""
|
|
72
75
|
assert_valid_axis(ax, "First argument must be a matplotlib axis or scitex axis wrapper")
|
|
73
76
|
|
|
74
|
-
# Format
|
|
75
|
-
|
|
77
|
+
# Format spike_times_list data
|
|
78
|
+
spike_times_list = _ensure_list(spike_times_list)
|
|
79
|
+
|
|
80
|
+
# Add sample size (number of trials) to label if provided
|
|
81
|
+
if kwargs.get("label"):
|
|
82
|
+
n_trials = len(spike_times_list)
|
|
83
|
+
kwargs["label"] = f"{kwargs['label']} ($n$={n_trials})"
|
|
76
84
|
|
|
77
85
|
# Handle colors and labels
|
|
78
|
-
colors = _handle_colors(colors,
|
|
79
|
-
|
|
86
|
+
colors = _handle_colors(colors, spike_times_list)
|
|
87
|
+
|
|
80
88
|
# Handle lineoffsets for positioning between trials/channels
|
|
89
|
+
if y_offset is None:
|
|
90
|
+
y_offset = 1.0 # Default spacing
|
|
81
91
|
if lineoffsets is None:
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
92
|
+
lineoffsets = np.arange(len(spike_times_list)) * y_offset
|
|
93
|
+
|
|
94
|
+
# Set linelengths to prevent overlap (80% of y_offset by default)
|
|
95
|
+
if linelengths is None:
|
|
96
|
+
linelengths = y_offset * 0.8
|
|
97
|
+
|
|
98
|
+
# Ensure lineoffsets is iterable and matches spike_times_list length
|
|
87
99
|
if np.isscalar(lineoffsets):
|
|
88
100
|
lineoffsets = [lineoffsets]
|
|
89
|
-
if len(lineoffsets) < len(
|
|
90
|
-
lineoffsets = list(lineoffsets) + list(range(len(lineoffsets), len(
|
|
101
|
+
if len(lineoffsets) < len(spike_times_list):
|
|
102
|
+
lineoffsets = list(lineoffsets) + list(range(len(lineoffsets), len(spike_times_list)))
|
|
91
103
|
|
|
92
|
-
# Plotting as eventplot using
|
|
93
|
-
for ii, (pos, color, offset) in enumerate(zip(
|
|
104
|
+
# Plotting as eventplot using spike_times_list with proper positioning
|
|
105
|
+
for ii, (pos, color, offset) in enumerate(zip(spike_times_list, colors, lineoffsets)):
|
|
94
106
|
label = _define_label(labels, ii)
|
|
95
|
-
ax.eventplot(pos, lineoffsets=offset,
|
|
96
|
-
colors=color, label=label, **kwargs)
|
|
107
|
+
ax.eventplot(pos, lineoffsets=offset, linelengths=linelengths,
|
|
108
|
+
orientation=orientation, colors=color, label=label, **kwargs)
|
|
97
109
|
|
|
98
110
|
# Apply set_n_ticks for cleaner axes if requested
|
|
99
111
|
if apply_set_n_ticks:
|
|
100
112
|
from scitex.plt.ax._style._set_n_ticks import set_n_ticks
|
|
101
|
-
|
|
113
|
+
|
|
102
114
|
# For categorical y-axis (trials/channels), use appropriate tick count
|
|
103
115
|
if n_yticks is None:
|
|
104
|
-
n_yticks = min(len(
|
|
116
|
+
n_yticks = min(len(spike_times_list), 8) # Max 8 ticks for readability
|
|
105
117
|
|
|
106
118
|
# Only apply if we have reasonable numeric ranges
|
|
107
119
|
try:
|
|
@@ -124,10 +136,10 @@ def plot_raster(
|
|
|
124
136
|
if labels is not None:
|
|
125
137
|
ax.legend()
|
|
126
138
|
|
|
127
|
-
# Return
|
|
128
|
-
|
|
139
|
+
# Return spike_times in a useful format
|
|
140
|
+
spike_times_digital_df = _event_times_to_digital_df(spike_times_list, time, lineoffsets)
|
|
129
141
|
|
|
130
|
-
return ax,
|
|
142
|
+
return ax, spike_times_digital_df
|
|
131
143
|
|
|
132
144
|
|
|
133
145
|
def _ensure_list(event_times):
|
|
@@ -12,11 +12,12 @@ __DIR__ = os.path.dirname(__FILE__)
|
|
|
12
12
|
from matplotlib.patches import Rectangle
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
def
|
|
15
|
+
def stx_rectangle(ax, xx, yy, ww, hh, **kwargs):
|
|
16
16
|
"""Add a rectangle patch to an axes.
|
|
17
17
|
|
|
18
18
|
Convenience function for adding rectangular patches to plots, useful for
|
|
19
19
|
highlighting regions, creating box annotations, or drawing geometric shapes.
|
|
20
|
+
By default, rectangles have no edge (border) for cleaner publication figures.
|
|
20
21
|
|
|
21
22
|
Parameters
|
|
22
23
|
----------
|
|
@@ -34,7 +35,7 @@ def plot_rectangle(ax, xx, yy, ww, hh, **kwargs):
|
|
|
34
35
|
Additional keyword arguments passed to matplotlib.patches.Rectangle.
|
|
35
36
|
Common options include:
|
|
36
37
|
- facecolor/fc : fill color
|
|
37
|
-
- edgecolor/ec : edge color
|
|
38
|
+
- edgecolor/ec : edge color (default: 'none')
|
|
38
39
|
- linewidth/lw : edge line width
|
|
39
40
|
- alpha : transparency (0-1)
|
|
40
41
|
- linestyle/ls : edge line style
|
|
@@ -48,20 +49,20 @@ def plot_rectangle(ax, xx, yy, ww, hh, **kwargs):
|
|
|
48
49
|
--------
|
|
49
50
|
>>> fig, ax = plt.subplots()
|
|
50
51
|
>>> ax.plot([0, 10], [0, 10])
|
|
51
|
-
>>> # Highlight a region
|
|
52
|
-
>>>
|
|
52
|
+
>>> # Highlight a region (no border by default)
|
|
53
|
+
>>> stx_rectangle(ax, 2, 3, 4, 3, facecolor='yellow', alpha=0.3)
|
|
53
54
|
|
|
54
|
-
>>> # Draw a box
|
|
55
|
-
>>>
|
|
56
|
-
|
|
57
|
-
>>> # Create a filled rectangle
|
|
58
|
-
>>> plot_rectangle(ax, 0, 0, 1, 1, facecolor='blue', edgecolor='black')
|
|
55
|
+
>>> # Draw a box with explicit edge
|
|
56
|
+
>>> stx_rectangle(ax, 5, 5, 2, 2, facecolor='none', edgecolor='red', linewidth=2)
|
|
59
57
|
|
|
60
58
|
See Also
|
|
61
59
|
--------
|
|
62
60
|
matplotlib.patches.Rectangle : The underlying Rectangle class
|
|
63
61
|
matplotlib.axes.Axes.add_patch : Method used to add the patch
|
|
64
62
|
"""
|
|
63
|
+
# Default to no edge for cleaner publication figures
|
|
64
|
+
if 'edgecolor' not in kwargs and 'ec' not in kwargs:
|
|
65
|
+
kwargs['edgecolor'] = 'none'
|
|
65
66
|
ax.add_patch(Rectangle((xx, yy), ww, hh, **kwargs))
|
|
66
67
|
return ax
|
|
67
68
|
|