scitex 2.3.0__py3-none-any.whl → 2.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99) hide show
  1. scitex/ai/classification/reporters/reporter_utils/_Plotter.py +1 -1
  2. scitex/ai/plt/__init__.py +2 -2
  3. scitex/ai/plt/{_plot_conf_mat.py → _stx_conf_mat.py} +3 -3
  4. scitex/config/PriorityConfig.py +195 -0
  5. scitex/config/__init__.py +24 -0
  6. scitex/io/_save.py +125 -34
  7. scitex/io/_save_modules/_image.py +37 -20
  8. scitex/plt/__init__.py +470 -17
  9. scitex/plt/_subplots/_AxisWrapper.py +98 -50
  10. scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin.py +559 -124
  11. scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin.py +49 -8
  12. scitex/plt/_subplots/_SubplotsWrapper.py +76 -91
  13. scitex/plt/_subplots/_export_as_csv.py +127 -58
  14. scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +25 -16
  15. scitex/plt/_subplots/_export_as_csv_formatters/_format_contourf.py +54 -0
  16. scitex/plt/_subplots/_export_as_csv_formatters/_format_hexbin.py +41 -0
  17. scitex/plt/_subplots/_export_as_csv_formatters/_format_hist2d.py +41 -0
  18. scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +59 -47
  19. scitex/plt/_subplots/_export_as_csv_formatters/_format_matshow.py +42 -0
  20. scitex/plt/_subplots/_export_as_csv_formatters/_format_pie.py +42 -0
  21. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +72 -35
  22. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +1 -1
  23. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +2 -2
  24. scitex/plt/_subplots/_export_as_csv_formatters/_format_quiver.py +53 -0
  25. scitex/plt/_subplots/_export_as_csv_formatters/_format_stem.py +42 -0
  26. scitex/plt/_subplots/_export_as_csv_formatters/_format_step.py +42 -0
  27. scitex/plt/_subplots/_export_as_csv_formatters/_format_streamplot.py +48 -0
  28. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_conf_mat.py → _format_stx_conf_mat.py} +2 -2
  29. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_ecdf.py → _format_stx_ecdf.py} +2 -2
  30. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_fillv.py → _format_stx_fillv.py} +2 -2
  31. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_heatmap.py → _format_stx_heatmap.py} +2 -2
  32. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_image.py → _format_stx_image.py} +2 -2
  33. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_joyplot.py → _format_stx_joyplot.py} +2 -2
  34. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_line.py → _format_stx_line.py} +3 -3
  35. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_mean_ci.py → _format_stx_mean_ci.py} +2 -2
  36. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_mean_std.py → _format_stx_mean_std.py} +2 -2
  37. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_median_iqr.py → _format_stx_median_iqr.py} +2 -2
  38. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_raster.py → _format_stx_raster.py} +2 -2
  39. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_rectangle.py → _format_stx_rectangle.py} +1 -1
  40. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_scatter_hist.py → _format_stx_scatter_hist.py} +2 -2
  41. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_shaded_line.py → _format_stx_shaded_line.py} +2 -2
  42. scitex/plt/_subplots/_export_as_csv_formatters/{_format_plot_violin.py → _format_stx_violin.py} +2 -2
  43. scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py +23 -23
  44. scitex/plt/ax/__init__.py +16 -15
  45. scitex/plt/ax/_plot/__init__.py +30 -30
  46. scitex/plt/ax/_plot/_add_fitted_line.py +65 -11
  47. scitex/plt/ax/_plot/_plot_statistical_shaded_line.py +104 -76
  48. scitex/plt/ax/_plot/{_plot_conf_mat.py → _stx_conf_mat.py} +10 -10
  49. scitex/plt/ax/_plot/_stx_ecdf.py +109 -0
  50. scitex/plt/ax/_plot/{_plot_fillv.py → _stx_fillv.py} +7 -7
  51. scitex/plt/ax/_plot/_stx_heatmap.py +366 -0
  52. scitex/plt/ax/_plot/{_plot_image.py → _stx_image.py} +1 -1
  53. scitex/plt/ax/_plot/_stx_joyplot.py +113 -0
  54. scitex/plt/ax/_plot/{_plot_raster.py → _stx_raster.py} +37 -25
  55. scitex/plt/ax/_plot/{_plot_rectangle.py → _stx_rectangle.py} +10 -9
  56. scitex/plt/ax/_plot/{_plot_scatter_hist.py → _stx_scatter_hist.py} +1 -1
  57. scitex/plt/ax/_plot/_stx_shaded_line.py +215 -0
  58. scitex/plt/ax/_plot/{_plot_violin.py → _stx_violin.py} +13 -6
  59. scitex/plt/ax/_style/__init__.py +3 -0
  60. scitex/plt/ax/_style/_style_barplot.py +13 -2
  61. scitex/plt/ax/_style/_style_boxplot.py +78 -32
  62. scitex/plt/ax/_style/_style_errorbar.py +17 -3
  63. scitex/plt/ax/_style/_style_scatter.py +17 -3
  64. scitex/plt/ax/_style/_style_violinplot.py +109 -0
  65. scitex/plt/color/_vizualize_colors.py +3 -3
  66. scitex/plt/styles/SCITEX_STYLE.yaml +104 -0
  67. scitex/plt/styles/__init__.py +57 -0
  68. scitex/plt/styles/_plot_defaults.py +209 -0
  69. scitex/plt/styles/_plot_postprocess.py +518 -0
  70. scitex/plt/styles/_style_loader.py +268 -0
  71. scitex/plt/styles/presets.py +208 -0
  72. scitex/plt/utils/_collect_figure_metadata.py +160 -18
  73. scitex/plt/utils/_colorbar.py +72 -10
  74. scitex/plt/utils/_configure_mpl.py +108 -52
  75. scitex/plt/utils/_crop.py +21 -7
  76. scitex/plt/utils/_figure_mm.py +21 -7
  77. scitex/stats/__init__.py +13 -1
  78. scitex/stats/_schema.py +578 -0
  79. scitex/stats/tests/__init__.py +13 -0
  80. scitex/stats/tests/correlation/__init__.py +13 -0
  81. scitex/stats/tests/correlation/_test_pearson.py +262 -0
  82. scitex/vis/__init__.py +6 -0
  83. scitex/vis/editor/__init__.py +23 -0
  84. scitex/vis/editor/_defaults.py +205 -0
  85. scitex/vis/editor/_edit.py +342 -0
  86. scitex/vis/editor/_mpl_editor.py +231 -0
  87. scitex/vis/editor/_tkinter_editor.py +466 -0
  88. scitex/vis/editor/_web_editor.py +1440 -0
  89. scitex/vis/model/plot_types.py +15 -15
  90. {scitex-2.3.0.dist-info → scitex-2.4.1.dist-info}/METADATA +2 -1
  91. {scitex-2.3.0.dist-info → scitex-2.4.1.dist-info}/RECORD +94 -67
  92. {scitex-2.3.0.dist-info → scitex-2.4.1.dist-info}/WHEEL +1 -1
  93. scitex/plt/ax/_plot/_plot_ecdf.py +0 -84
  94. scitex/plt/ax/_plot/_plot_heatmap.py +0 -277
  95. scitex/plt/ax/_plot/_plot_joyplot.py +0 -77
  96. scitex/plt/ax/_plot/_plot_shaded_line.py +0 -142
  97. scitex/plt/presets.py +0 -224
  98. {scitex-2.3.0.dist-info → scitex-2.4.1.dist-info}/entry_points.txt +0 -0
  99. {scitex-2.3.0.dist-info → scitex-2.4.1.dist-info}/licenses/LICENSE +0 -0
@@ -14,24 +14,50 @@ import numpy as np
14
14
  import pandas as pd
15
15
  from ....plt.utils import assert_valid_axis
16
16
 
17
- from ._plot_shaded_line import plot_shaded_line as scitex_plt_plot_shaded_line
17
+ from ._stx_shaded_line import stx_shaded_line as scitex_plt_plot_shaded_line
18
18
 
19
19
 
20
- def plot_line(axis, data, xx=None, **kwargs):
20
+ def _format_sample_size(values_2d):
21
+ """Format sample size string, showing range if variable due to NaN.
22
+
23
+ Parameters
24
+ ----------
25
+ values_2d : np.ndarray, shape (n_samples, n_points)
26
+ 2D array where sample count may vary per column due to NaN.
27
+
28
+ Returns
29
+ -------
30
+ str
31
+ Formatted sample size string, e.g., "20" or "18-20".
32
+ """
33
+ if values_2d.ndim == 1:
34
+ return "1"
35
+
36
+ # Count non-NaN values per column (timepoint)
37
+ n_per_point = np.sum(~np.isnan(values_2d), axis=0)
38
+ n_min, n_max = int(n_per_point.min()), int(n_per_point.max())
39
+
40
+ if n_min == n_max:
41
+ return str(n_min)
42
+ else:
43
+ return f"{n_min}-{n_max}"
44
+
45
+
46
+ def stx_line(axis, values_1d, xx=None, **kwargs):
21
47
  """
22
48
  Plot a simple line.
23
-
49
+
24
50
  Parameters
25
51
  ----------
26
52
  axis : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper
27
53
  The axis to plot on
28
- data : array-like
29
- Data to plot
30
- xx : array-like, optional
31
- X coordinates for the data. If None, will use np.arange(len(data))
54
+ values_1d : array-like, shape (n_points,)
55
+ 1D array of y-values to plot
56
+ xx : array-like, shape (n_points,), optional
57
+ X coordinates for the data. If None, will use np.arange(len(values_1d))
32
58
  **kwargs
33
59
  Additional keyword arguments passed to axis.plot()
34
-
60
+
35
61
  Returns
36
62
  -------
37
63
  axis : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper
@@ -40,36 +66,38 @@ def plot_line(axis, data, xx=None, **kwargs):
40
66
  DataFrame with x and y values
41
67
  """
42
68
  assert_valid_axis(axis, "First argument must be a matplotlib axis or scitex axis wrapper")
43
- data = np.asarray(data)
44
- assert data.ndim <= 2, f"Data must be 1D or 2D, got {data.ndim}D"
69
+ values_1d = np.asarray(values_1d)
70
+ assert values_1d.ndim <= 2, f"Data must be 1D or 2D, got {values_1d.ndim}D"
45
71
  if xx is None:
46
- xx = np.arange(len(data))
72
+ xx = np.arange(len(values_1d))
47
73
  else:
48
74
  xx = np.asarray(xx)
49
75
  assert len(xx) == len(
50
- data
51
- ), f"xx length ({len(xx)}) must match data length ({len(data)})"
52
- axis.plot(xx, data, **kwargs)
53
- return axis, pd.DataFrame({"x": xx, "y": data})
76
+ values_1d
77
+ ), f"xx length ({len(xx)}) must match values_1d length ({len(values_1d)})"
78
+
79
+ axis.plot(xx, values_1d, **kwargs)
80
+ return axis, pd.DataFrame({"x": xx, "y": values_1d})
54
81
 
55
82
 
56
- def plot_mean_std(axis, data, xx=None, sd=1, **kwargs):
83
+ def stx_mean_std(axis, values_2d, xx=None, sd=1, **kwargs):
57
84
  """
58
85
  Plot mean line with standard deviation shading.
59
-
86
+
60
87
  Parameters
61
88
  ----------
62
89
  axis : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper
63
90
  The axis to plot on
64
- data : array-like
65
- Data to plot, can be 1D or 2D. If 2D, mean and std are calculated across the first dimension
66
- xx : array-like, optional
67
- X coordinates for the data. If None, will use np.arange(len(data))
91
+ values_2d : array-like, shape (n_samples, n_points) or (n_points,)
92
+ 2D array where mean and std are calculated across axis=0 (samples).
93
+ Can also be 1D for a single line without shading.
94
+ xx : array-like, shape (n_points,), optional
95
+ X coordinates for the data. If None, will use np.arange(n_points)
68
96
  sd : float, optional
69
97
  Number of standard deviations for the shaded region. Default is 1
70
98
  **kwargs
71
- Additional keyword arguments passed to plot_shaded_line()
72
-
99
+ Additional keyword arguments passed to stx_shaded_line()
100
+
73
101
  Returns
74
102
  -------
75
103
  axis : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper
@@ -78,51 +106,52 @@ def plot_mean_std(axis, data, xx=None, sd=1, **kwargs):
78
106
  assert_valid_axis(axis, "First argument must be a matplotlib axis or scitex axis wrapper")
79
107
  assert isinstance(sd, (int, float)), f"sd must be a number, got {type(sd)}"
80
108
  assert sd >= 0, f"sd must be non-negative, got {sd}"
81
- data = np.asarray(data)
82
- assert data.ndim <= 2, f"Data must be 1D or 2D, got {data.ndim}D"
109
+ values_2d = np.asarray(values_2d)
110
+ assert values_2d.ndim <= 2, f"Data must be 1D or 2D, got {values_2d.ndim}D"
83
111
  if xx is None:
84
- xx = np.arange(data.shape[1] if data.ndim > 1 else len(data))
112
+ xx = np.arange(values_2d.shape[1] if values_2d.ndim > 1 else len(values_2d))
85
113
  else:
86
114
  xx = np.asarray(xx)
87
- expected_len = data.shape[1] if data.ndim > 1 else len(data)
115
+ expected_len = values_2d.shape[1] if values_2d.ndim > 1 else len(values_2d)
88
116
  assert (
89
117
  len(xx) == expected_len
90
- ), f"xx length ({len(xx)}) must match data length ({expected_len})"
118
+ ), f"xx length ({len(xx)}) must match values_2d length ({expected_len})"
91
119
 
92
- if data.ndim == 1:
93
- central = data
120
+ if values_2d.ndim == 1:
121
+ central = values_2d
94
122
  error = np.zeros_like(central)
95
123
  else:
96
- central = np.nanmean(data, axis=0)
97
- error = np.nanstd(data, axis=0) * sd
124
+ central = np.nanmean(values_2d, axis=0)
125
+ error = np.nanstd(values_2d, axis=0) * sd
98
126
 
99
127
  y_lower = central - error
100
128
  y_upper = central + error
101
- n_samples = data.shape[0] if data.ndim > 1 else 1
102
129
 
103
130
  if "label" in kwargs and kwargs["label"]:
104
- kwargs["label"] = f"{kwargs['label']} (n={n_samples})"
131
+ n_str = _format_sample_size(values_2d)
132
+ kwargs["label"] = f"{kwargs['label']} ($n$={n_str})"
105
133
 
106
134
  return scitex_plt_plot_shaded_line(axis, xx, y_lower, central, y_upper, **kwargs)
107
135
 
108
136
 
109
- def plot_mean_ci(axis, data, xx=None, perc=95, **kwargs):
137
+ def stx_mean_ci(axis, values_2d, xx=None, perc=95, **kwargs):
110
138
  """
111
139
  Plot mean line with confidence interval shading.
112
-
140
+
113
141
  Parameters
114
142
  ----------
115
143
  axis : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper
116
144
  The axis to plot on
117
- data : array-like
118
- Data to plot, can be 1D or 2D. If 2D, mean and percentiles are calculated across the first dimension
119
- xx : array-like, optional
120
- X coordinates for the data. If None, will use np.arange(len(data))
145
+ values_2d : array-like, shape (n_samples, n_points) or (n_points,)
146
+ 2D array where mean and percentiles are calculated across axis=0 (samples).
147
+ Can also be 1D for a single line without shading.
148
+ xx : array-like, shape (n_points,), optional
149
+ X coordinates for the data. If None, will use np.arange(n_points)
121
150
  perc : float, optional
122
151
  Confidence interval percentage (0-100). Default is 95
123
152
  **kwargs
124
- Additional keyword arguments passed to plot_shaded_line()
125
-
153
+ Additional keyword arguments passed to stx_shaded_line()
154
+
126
155
  Returns
127
156
  -------
128
157
  axis : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper
@@ -133,87 +162,86 @@ def plot_mean_ci(axis, data, xx=None, perc=95, **kwargs):
133
162
  perc, (int, float)
134
163
  ), f"perc must be a number, got {type(perc)}"
135
164
  assert 0 <= perc <= 100, f"perc must be between 0 and 100, got {perc}"
136
- data = np.asarray(data)
137
- assert data.ndim <= 2, f"Data must be 1D or 2D, got {data.ndim}D"
165
+ values_2d = np.asarray(values_2d)
166
+ assert values_2d.ndim <= 2, f"Data must be 1D or 2D, got {values_2d.ndim}D"
138
167
 
139
168
  if xx is None:
140
- xx = np.arange(data.shape[1] if data.ndim > 1 else len(data))
169
+ xx = np.arange(values_2d.shape[1] if values_2d.ndim > 1 else len(values_2d))
141
170
  else:
142
171
  xx = np.asarray(xx)
143
172
 
144
- expected_len = data.shape[1] if data.ndim > 1 else len(data)
173
+ expected_len = values_2d.shape[1] if values_2d.ndim > 1 else len(values_2d)
145
174
  assert (
146
175
  len(xx) == expected_len
147
- ), f"xx length ({len(xx)}) must match data length ({expected_len})"
176
+ ), f"xx length ({len(xx)}) must match values_2d length ({expected_len})"
148
177
 
149
- if data.ndim == 1:
150
- central = data
178
+ if values_2d.ndim == 1:
179
+ central = values_2d
151
180
  y_lower = central
152
181
  y_upper = central
153
182
  else:
154
- central = np.nanmean(data, axis=0)
183
+ central = np.nanmean(values_2d, axis=0)
155
184
  # Calculate CI bounds
156
185
  alpha = 1 - perc / 100
157
186
  y_lower_perc = alpha / 2 * 100
158
187
  y_upper_perc = (1 - alpha / 2) * 100
159
- y_lower = np.nanpercentile(data, y_lower_perc, axis=0)
160
- y_upper = np.nanpercentile(data, y_upper_perc, axis=0)
161
-
162
- n_samples = data.shape[0] if data.ndim > 1 else 1
188
+ y_lower = np.nanpercentile(values_2d, y_lower_perc, axis=0)
189
+ y_upper = np.nanpercentile(values_2d, y_upper_perc, axis=0)
163
190
 
164
191
  if "label" in kwargs and kwargs["label"]:
165
- kwargs["label"] = f"{kwargs['label']} (n={n_samples}, CI={perc}%)"
192
+ n_str = _format_sample_size(values_2d)
193
+ kwargs["label"] = f"{kwargs['label']} ($n$={n_str}, CI={perc}%)"
166
194
 
167
195
  return scitex_plt_plot_shaded_line(axis, xx, y_lower, central, y_upper, **kwargs)
168
196
 
169
197
 
170
- def plot_median_iqr(axis, data, xx=None, **kwargs):
198
+ def stx_median_iqr(axis, values_2d, xx=None, **kwargs):
171
199
  """
172
200
  Plot median line with interquartile range shading.
173
-
201
+
174
202
  Parameters
175
203
  ----------
176
204
  axis : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper
177
205
  The axis to plot on
178
- data : array-like
179
- Data to plot, can be 1D or 2D. If 2D, median and IQR are calculated across the first dimension
180
- xx : array-like, optional
181
- X coordinates for the data. If None, will use np.arange(len(data))
206
+ values_2d : array-like, shape (n_samples, n_points) or (n_points,)
207
+ 2D array where median and IQR are calculated across axis=0 (samples).
208
+ Can also be 1D for a single line without shading.
209
+ xx : array-like, shape (n_points,), optional
210
+ X coordinates for the data. If None, will use np.arange(n_points)
182
211
  **kwargs
183
- Additional keyword arguments passed to plot_shaded_line()
184
-
212
+ Additional keyword arguments passed to stx_shaded_line()
213
+
185
214
  Returns
186
215
  -------
187
216
  axis : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper
188
217
  The axis with the plot
189
218
  """
190
219
  assert_valid_axis(axis, "First argument must be a matplotlib axis or scitex axis wrapper")
191
- data = np.asarray(data)
192
- assert data.ndim <= 2, f"Data must be 1D or 2D, got {data.ndim}D"
220
+ values_2d = np.asarray(values_2d)
221
+ assert values_2d.ndim <= 2, f"Data must be 1D or 2D, got {values_2d.ndim}D"
193
222
 
194
223
  if xx is None:
195
- xx = np.arange(data.shape[1] if data.ndim > 1 else len(data))
224
+ xx = np.arange(values_2d.shape[1] if values_2d.ndim > 1 else len(values_2d))
196
225
  else:
197
226
  xx = np.asarray(xx)
198
227
 
199
- expected_len = data.shape[1] if data.ndim > 1 else len(data)
228
+ expected_len = values_2d.shape[1] if values_2d.ndim > 1 else len(values_2d)
200
229
  assert (
201
230
  len(xx) == expected_len
202
- ), f"xx length ({len(xx)}) must match data length ({expected_len})"
231
+ ), f"xx length ({len(xx)}) must match values_2d length ({expected_len})"
203
232
 
204
- if data.ndim == 1:
205
- central = data
233
+ if values_2d.ndim == 1:
234
+ central = values_2d
206
235
  y_lower = central
207
236
  y_upper = central
208
237
  else:
209
- central = np.nanmedian(data, axis=0)
210
- y_lower = np.nanpercentile(data, 25, axis=0)
211
- y_upper = np.nanpercentile(data, 75, axis=0)
212
-
213
- n_samples = data.shape[0] if data.ndim > 1 else 1
238
+ central = np.nanmedian(values_2d, axis=0)
239
+ y_lower = np.nanpercentile(values_2d, 25, axis=0)
240
+ y_upper = np.nanpercentile(values_2d, 75, axis=0)
214
241
 
215
242
  if "label" in kwargs and kwargs["label"]:
216
- kwargs["label"] = f"{kwargs['label']} (n={n_samples}, IQR)"
243
+ n_str = _format_sample_size(values_2d)
244
+ kwargs["label"] = f"{kwargs['label']} ($n$={n_str}, IQR)"
217
245
 
218
246
  return scitex_plt_plot_shaded_line(axis, xx, y_lower, central, y_upper, **kwargs)
219
247
 
@@ -22,9 +22,9 @@ from scitex.plt.utils import assert_valid_axis
22
22
  from .._style._extend import extend as scitex_plt_extend
23
23
 
24
24
 
25
- def plot_conf_mat(
25
+ def stx_conf_mat(
26
26
  axis: plt.Axes,
27
- data: Union[np.ndarray, pd.DataFrame],
27
+ conf_mat_2d: Union[np.ndarray, pd.DataFrame],
28
28
  x_labels: Optional[List[str]] = None,
29
29
  y_labels: Optional[List[str]] = None,
30
30
  title: str = "Confusion Matrix",
@@ -43,8 +43,8 @@ def plot_conf_mat(
43
43
  ----------
44
44
  axis : plt.Axes or scitex.plt._subplots._AxisWrapper.AxisWrapper
45
45
  Matplotlib axes or scitex axis wrapper to plot on
46
- data : Union[np.ndarray, pd.DataFrame]
47
- Confusion matrix data
46
+ conf_mat_2d : Union[np.ndarray, pd.DataFrame], shape (n_classes, n_classes)
47
+ 2D confusion matrix data (true labels × predicted labels)
48
48
  x_labels : Optional[List[str]], optional
49
49
  Labels for predicted classes
50
50
  y_labels : Optional[List[str]], optional
@@ -75,7 +75,7 @@ def plot_conf_mat(
75
75
  -------
76
76
  >>> data = np.array([[10, 2, 0], [1, 15, 3], [0, 2, 20]])
77
77
  >>> fig, ax = plt.subplots()
78
- >>> ax, bacc = plot_conf_mat(ax, data, x_labels=['A','B','C'],
78
+ >>> ax, bacc = stx_conf_mat(ax, data, x_labels=['A','B','C'],
79
79
  ... y_labels=['X','Y','Z'], calc_bacc=True)
80
80
  >>> print(f"Balanced Accuracy: {bacc:.3f}")
81
81
  Balanced Accuracy: 0.889
@@ -83,14 +83,14 @@ def plot_conf_mat(
83
83
 
84
84
  assert_valid_axis(axis, "First argument must be a matplotlib axis or scitex axis wrapper")
85
85
 
86
- if not isinstance(data, pd.DataFrame):
87
- data = pd.DataFrame(data)
86
+ if not isinstance(conf_mat_2d, pd.DataFrame):
87
+ conf_mat_2d = pd.DataFrame(conf_mat_2d)
88
88
 
89
- bacc_val = calc_bacc_from_conf_mat(data.values)
89
+ bacc_val = calc_bacc_from_conf_mat(conf_mat_2d.values)
90
90
  title = f"{title} (bACC = {bacc_val:.3f})"
91
91
 
92
92
  res = sns.heatmap(
93
- data,
93
+ conf_mat_2d,
94
94
  ax=axis,
95
95
  cmap=cmap,
96
96
  annot=True,
@@ -115,7 +115,7 @@ def plot_conf_mat(
115
115
  axis.set_yticklabels(y_labels)
116
116
 
117
117
  axis = scitex_plt_extend(axis, x_extend_ratio, y_extend_ratio)
118
- if data.shape[0] == data.shape[1]:
118
+ if conf_mat_2d.shape[0] == conf_mat_2d.shape[1]:
119
119
  axis.set_box_aspect(1)
120
120
  axis.set_xticklabels(
121
121
  axis.get_xticklabels(),
@@ -0,0 +1,109 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Timestamp: "2025-12-01 14:00:00 (ywatanabe)"
4
+ # File: ./src/scitex/plt/ax/_plot/_plot_ecdf.py
5
+
6
+ """Empirical Cumulative Distribution Function (ECDF) plotting."""
7
+
8
+ import warnings
9
+ from typing import Any, Tuple, Union
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ from matplotlib.axes import Axes
14
+
15
+ from scitex.pd._force_df import force_df as scitex_pd_force_df
16
+ from ....plt.utils import assert_valid_axis, mm_to_pt
17
+
18
+
19
+ # Default line width (0.2mm for publication)
20
+ DEFAULT_LINE_WIDTH_MM = 0.2
21
+
22
+
23
+ def stx_ecdf(
24
+ axis: Union[Axes, "AxisWrapper"],
25
+ values_1d: np.ndarray,
26
+ **kwargs: Any,
27
+ ) -> Tuple[Union[Axes, "AxisWrapper"], pd.DataFrame]:
28
+ """Plot Empirical Cumulative Distribution Function (ECDF).
29
+
30
+ The ECDF shows the proportion of data points less than or equal to each
31
+ value, representing the empirical estimate of the cumulative distribution
32
+ function.
33
+
34
+ Parameters
35
+ ----------
36
+ axis : matplotlib.axes.Axes or AxisWrapper
37
+ Matplotlib axis or scitex axis wrapper to plot on.
38
+ values_1d : array-like, shape (n_samples,)
39
+ 1D array of values to compute and plot ECDF for. NaN values are automatically ignored.
40
+ **kwargs : dict
41
+ Additional arguments passed to plot function.
42
+
43
+ Returns
44
+ -------
45
+ axis : matplotlib.axes.Axes or AxisWrapper
46
+ The axes with the ECDF plot.
47
+ df : pd.DataFrame
48
+ DataFrame containing ECDF data with columns:
49
+ - x: sorted data values
50
+ - y: cumulative percentages (0-100)
51
+ - n: total number of data points
52
+ - x_step, y_step: step plot coordinates
53
+
54
+ Examples
55
+ --------
56
+ >>> import numpy as np
57
+ >>> import scitex as stx
58
+ >>> data = np.random.randn(100)
59
+ >>> fig, ax = stx.plt.subplots()
60
+ >>> ax, df = stx.plt.ax.stx_ecdf(ax, data)
61
+ """
62
+ assert_valid_axis(axis, "First argument must be a matplotlib axis or scitex axis wrapper")
63
+
64
+ # Flatten and remove NaN values
65
+ values_1d = np.hstack(values_1d)
66
+
67
+ # Warnings
68
+ if np.isnan(values_1d).any():
69
+ warnings.warn("NaN value are ignored for ECDF plot.")
70
+ values_1d = values_1d[~np.isnan(values_1d)]
71
+ nn = len(values_1d)
72
+
73
+ # Sort the data and compute the ECDF values
74
+ data_sorted = np.sort(values_1d)
75
+ ecdf_perc = 100 * np.arange(1, len(data_sorted) + 1) / len(data_sorted)
76
+
77
+ # Create the pseudo x-axis for step plotting
78
+ x_step = np.repeat(data_sorted, 2)[1:]
79
+ y_step = np.repeat(ecdf_perc, 2)[:-1]
80
+
81
+ # Apply default linewidth if not specified
82
+ if 'linewidth' not in kwargs and 'lw' not in kwargs:
83
+ kwargs['linewidth'] = mm_to_pt(DEFAULT_LINE_WIDTH_MM)
84
+
85
+ # Add sample size to label if provided
86
+ if "label" in kwargs and kwargs["label"]:
87
+ kwargs["label"] = f"{kwargs['label']} ($n$={nn})"
88
+
89
+ # Plot the ECDF using steps (no markers - clean line only)
90
+ axis.plot(x_step, y_step, drawstyle="steps-post", **kwargs)
91
+
92
+ # Set ylim (xlim is auto-scaled based on data)
93
+ axis.set_ylim(0, 100)
94
+
95
+ # Create a DataFrame to hold the ECDF data
96
+ df = scitex_pd_force_df(
97
+ {
98
+ "x": data_sorted,
99
+ "y": ecdf_perc,
100
+ "n": nn,
101
+ "x_step": x_step,
102
+ "y_step": y_step,
103
+ }
104
+ )
105
+
106
+ return axis, df
107
+
108
+
109
+ # EOF
@@ -14,7 +14,7 @@ import numpy as np
14
14
  from ....plt.utils import assert_valid_axis
15
15
 
16
16
 
17
- def plot_fillv(axes, starts, ends, color="red", alpha=0.2):
17
+ def stx_fillv(axes, starts_1d, ends_1d, color="red", alpha=0.2):
18
18
  """
19
19
  Fill between specified start and end intervals on an axis or array of axes.
20
20
 
@@ -22,10 +22,10 @@ def plot_fillv(axes, starts, ends, color="red", alpha=0.2):
22
22
  ----------
23
23
  axes : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper or numpy.ndarray of axes
24
24
  The axis object(s) to fill intervals on.
25
- starts : array-like
26
- Array-like of start positions.
27
- ends : array-like
28
- Array-like of end positions.
25
+ starts_1d : array-like, shape (n_regions,)
26
+ 1D array of start x-positions for vertical fill regions.
27
+ ends_1d : array-like, shape (n_regions,)
28
+ 1D array of end x-positions for vertical fill regions.
29
29
  color : str, optional
30
30
  The color to use for the filled regions. Default is "red".
31
31
  alpha : float, optional
@@ -43,8 +43,8 @@ def plot_fillv(axes, starts, ends, color="red", alpha=0.2):
43
43
 
44
44
  for ax in axes:
45
45
  assert_valid_axis(ax, "First argument must be a matplotlib axis or scitex axis wrapper")
46
- for start, end in zip(starts, ends):
47
- ax.axvspan(start, end, color=color, alpha=alpha)
46
+ for start, end in zip(starts_1d, ends_1d):
47
+ ax.axvspan(start, end, facecolor=color, edgecolor='none', alpha=alpha)
48
48
 
49
49
  if not is_axes:
50
50
  return axes[0]