dataeval 0.72.2__py3-none-any.whl → 0.73.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +3 -3
- dataeval/detectors/__init__.py +1 -1
- dataeval/detectors/drift/__init__.py +1 -1
- dataeval/detectors/drift/base.py +2 -2
- dataeval/detectors/linters/clusterer.py +1 -1
- dataeval/detectors/ood/__init__.py +1 -1
- dataeval/detectors/ood/ae.py +14 -6
- dataeval/detectors/ood/aegmm.py +14 -6
- dataeval/detectors/ood/base.py +9 -3
- dataeval/detectors/ood/llr.py +22 -16
- dataeval/detectors/ood/vae.py +14 -6
- dataeval/detectors/ood/vaegmm.py +14 -6
- dataeval/interop.py +9 -7
- dataeval/metrics/bias/balance.py +50 -44
- dataeval/metrics/bias/coverage.py +38 -6
- dataeval/metrics/bias/diversity.py +117 -65
- dataeval/metrics/bias/metadata.py +225 -60
- dataeval/metrics/bias/parity.py +68 -54
- dataeval/utils/__init__.py +4 -3
- dataeval/utils/lazy.py +26 -0
- dataeval/utils/metadata.py +258 -0
- dataeval/utils/shared.py +1 -1
- dataeval/utils/split_dataset.py +12 -6
- dataeval/utils/tensorflow/_internal/gmm.py +8 -2
- dataeval/utils/tensorflow/_internal/loss.py +20 -11
- dataeval/utils/tensorflow/_internal/{pixelcnn.py → models.py} +371 -77
- dataeval/utils/tensorflow/_internal/trainer.py +12 -5
- dataeval/utils/tensorflow/_internal/utils.py +70 -71
- dataeval/utils/torch/datasets.py +2 -2
- dataeval/workflows/__init__.py +1 -1
- {dataeval-0.72.2.dist-info → dataeval-0.73.1.dist-info}/METADATA +3 -3
- {dataeval-0.72.2.dist-info → dataeval-0.73.1.dist-info}/RECORD +34 -33
- dataeval/utils/tensorflow/_internal/autoencoder.py +0 -316
- {dataeval-0.72.2.dist-info → dataeval-0.73.1.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.72.2.dist-info → dataeval-0.73.1.dist-info}/WHEEL +0 -0
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
+
import contextlib
|
5
6
|
from typing import Any, Mapping
|
6
7
|
|
7
8
|
import numpy as np
|
@@ -10,54 +11,87 @@ from scipy.stats import entropy as sp_entropy
|
|
10
11
|
|
11
12
|
from dataeval.interop import to_numpy
|
12
13
|
|
14
|
+
with contextlib.suppress(ImportError):
|
15
|
+
from matplotlib.figure import Figure
|
16
|
+
|
17
|
+
CLASS_LABEL = "class_label"
|
18
|
+
|
13
19
|
|
14
20
|
def get_counts(
|
15
|
-
data: NDArray[
|
16
|
-
|
21
|
+
data: NDArray[Any],
|
22
|
+
names: list[str],
|
23
|
+
continuous_factor_bincounts: Mapping[str, int] | None = None,
|
24
|
+
subset_mask: NDArray[np.bool_] | None = None,
|
25
|
+
hist_cache: dict[str, NDArray[np.intp]] | None = None,
|
26
|
+
) -> dict[str, NDArray[np.intp]]:
|
17
27
|
"""
|
18
28
|
Initialize dictionary of histogram counts --- treat categorical values
|
19
29
|
as histogram bins.
|
20
30
|
|
21
31
|
Parameters
|
22
32
|
----------
|
23
|
-
|
33
|
+
data : NDArray
|
34
|
+
Array containing numerical values for metadata factors
|
35
|
+
names : list[str]
|
36
|
+
Names of metadata factors -- keys of the metadata dictionary
|
37
|
+
continuous_factor_bincounts : Mapping[str, int] or None, default None
|
38
|
+
The factors in names that have continuous values and the array of bin counts to
|
39
|
+
discretize values into. All factors are treated as having discrete values unless they
|
40
|
+
are specified as keys in this dictionary. Each element of this array must occur as a key
|
41
|
+
in names.
|
42
|
+
Names of metadata factors -- keys of the metadata dictionary
|
43
|
+
subset_mask : NDArray[np.bool_] or None, default None
|
24
44
|
Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
|
45
|
+
hist_cache : dict[str, NDArray[np.intp]] or None, default None
|
46
|
+
Optional cache to store histogram counts
|
25
47
|
|
26
48
|
Returns
|
27
49
|
-------
|
28
|
-
|
50
|
+
dict[str, NDArray[np.intp]]
|
29
51
|
histogram counts per metadata factor in `factors`. Each
|
30
52
|
factor will have a different number of bins. Counts get reused
|
31
53
|
across metrics, so hist_counts are cached but only if computed
|
32
54
|
globally, i.e. without masked samples.
|
33
55
|
"""
|
34
56
|
|
35
|
-
hist_counts
|
36
|
-
|
37
|
-
mask =
|
57
|
+
hist_counts = {}
|
58
|
+
|
59
|
+
mask = subset_mask if subset_mask is not None else np.ones(data.shape[0], dtype=np.bool_)
|
38
60
|
|
39
61
|
for cdx, fn in enumerate(names):
|
40
|
-
|
41
|
-
|
42
|
-
if is_categorical[cdx]:
|
43
|
-
# if discrete, use unique values as bins
|
44
|
-
bins, cnts = np.unique(col_data, return_counts=True)
|
62
|
+
if hist_cache is not None and fn in hist_cache:
|
63
|
+
cnts = hist_cache[fn]
|
45
64
|
else:
|
46
|
-
|
47
|
-
cnts
|
65
|
+
hist_edges = np.array([-np.inf, np.inf])
|
66
|
+
cnts = np.array([len(data[:, cdx].squeeze())])
|
67
|
+
# linter doesn't like double indexing
|
68
|
+
col_data = np.array(data[mask, cdx].squeeze(), dtype=np.float64)
|
69
|
+
|
70
|
+
if continuous_factor_bincounts and fn in continuous_factor_bincounts:
|
71
|
+
num_bins = continuous_factor_bincounts[fn]
|
72
|
+
_, hist_edges = np.histogram(data[:, cdx].squeeze(), bins=num_bins, density=True)
|
73
|
+
hist_edges[-1] = np.inf
|
74
|
+
hist_edges[0] = -np.inf
|
75
|
+
disc_col_data = np.digitize(col_data, np.array(hist_edges))
|
76
|
+
_, cnts = np.unique(disc_col_data, return_counts=True)
|
77
|
+
else:
|
78
|
+
_, cnts = np.unique(col_data, return_counts=True)
|
79
|
+
|
80
|
+
if hist_cache is not None:
|
81
|
+
hist_cache[fn] = cnts
|
48
82
|
|
49
83
|
hist_counts[fn] = cnts
|
50
|
-
hist_bins[fn] = bins
|
51
84
|
|
52
|
-
return hist_counts
|
85
|
+
return hist_counts
|
53
86
|
|
54
87
|
|
55
88
|
def entropy(
|
56
89
|
data: NDArray[Any],
|
57
90
|
names: list[str],
|
58
|
-
|
91
|
+
continuous_factor_bincounts: Mapping[str, int] | None = None,
|
59
92
|
normalized: bool = False,
|
60
93
|
subset_mask: NDArray[np.bool_] | None = None,
|
94
|
+
hist_cache: dict[str, NDArray[np.intp]] | None = None,
|
61
95
|
) -> NDArray[np.float64]:
|
62
96
|
"""
|
63
97
|
Meant for use with :term:`bias<Bias>` metrics, :term:`balance<Balance>`, :term:`diversity<Diversity>`,
|
@@ -68,19 +102,30 @@ def entropy(
|
|
68
102
|
|
69
103
|
Parameters
|
70
104
|
----------
|
71
|
-
|
105
|
+
data : NDArray
|
106
|
+
Array containing numerical values for metadata factors
|
107
|
+
names : list[str]
|
108
|
+
Names of metadata factors -- keys of the metadata dictionary
|
109
|
+
continuous_factor_bincounts : Mapping[str, int] or None, default None
|
110
|
+
The factors in names that have continuous values and the array of bin counts to
|
111
|
+
discretize values into. All factors are treated as having discrete values unless they
|
112
|
+
are specified as keys in this dictionary. Each element of this array must occur as a key
|
113
|
+
in names.
|
114
|
+
normalized : bool, default False
|
72
115
|
Flag that determines whether or not to normalize entropy by log(num_bins)
|
73
|
-
subset_mask: NDArray[np.bool_]
|
116
|
+
subset_mask : NDArray[np.bool_] or None, default None
|
74
117
|
Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
|
118
|
+
hist_cache : dict[str, NDArray[np.intp]] or None, default None
|
119
|
+
Optional cache to store histogram counts
|
75
120
|
|
76
|
-
|
77
|
-
|
121
|
+
Notes
|
122
|
+
-----
|
78
123
|
For continuous variables, histogram bins are chosen automatically. See
|
79
124
|
numpy.histogram for details.
|
80
125
|
|
81
126
|
Returns
|
82
127
|
-------
|
83
|
-
|
128
|
+
NDArray[np.float64]
|
84
129
|
Entropy estimate per column of X
|
85
130
|
|
86
131
|
See Also
|
@@ -90,47 +135,64 @@ def entropy(
|
|
90
135
|
"""
|
91
136
|
|
92
137
|
num_factors = len(names)
|
93
|
-
hist_counts
|
138
|
+
hist_counts = get_counts(data, names, continuous_factor_bincounts, subset_mask, hist_cache)
|
94
139
|
|
95
140
|
ev_index = np.empty(num_factors)
|
96
141
|
for col, cnts in enumerate(hist_counts.values()):
|
97
142
|
# entropy in nats, normalizes counts
|
98
143
|
ev_index[col] = sp_entropy(cnts)
|
99
144
|
if normalized:
|
100
|
-
|
145
|
+
cnt_len = np.size(cnts, 0)
|
146
|
+
if cnt_len == 1:
|
101
147
|
# log(0)
|
102
148
|
ev_index[col] = 0
|
103
149
|
else:
|
104
|
-
ev_index[col] /= np.log(
|
150
|
+
ev_index[col] /= np.log(cnt_len)
|
105
151
|
return ev_index
|
106
152
|
|
107
153
|
|
108
154
|
def get_num_bins(
|
109
|
-
data: NDArray[Any],
|
155
|
+
data: NDArray[Any],
|
156
|
+
names: list[str],
|
157
|
+
continuous_factor_bincounts: Mapping[str, int] | None = None,
|
158
|
+
subset_mask: NDArray[np.bool_] | None = None,
|
159
|
+
hist_cache: dict[str, NDArray[np.intp]] | None = None,
|
110
160
|
) -> NDArray[np.float64]:
|
111
161
|
"""
|
112
162
|
Number of bins or unique values for each metadata factor, used to
|
113
|
-
normalize entropy
|
163
|
+
normalize entropy/diversity.
|
114
164
|
|
115
165
|
Parameters
|
116
166
|
----------
|
117
|
-
|
167
|
+
data : NDArray
|
168
|
+
Array containing numerical values for metadata factors
|
169
|
+
names : list[str]
|
170
|
+
Names of metadata factors -- keys of the metadata dictionary
|
171
|
+
continuous_factor_bincounts : Mapping[str, int] or None, default None
|
172
|
+
The factors in names that have continuous values and the array of bin counts to
|
173
|
+
discretize values into. All factors are treated as having discrete values unless they
|
174
|
+
are specified as keys in this dictionary. Each element of this array must occur as a key
|
175
|
+
in names.
|
176
|
+
subset_mask : NDArray[np.bool_] or None, default None
|
118
177
|
Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
|
178
|
+
hist_cache : dict[str, NDArray[np.intp]] or None, default None
|
179
|
+
Optional cache to store histogram counts
|
119
180
|
|
120
181
|
Returns
|
121
182
|
-------
|
122
183
|
NDArray[np.float64]
|
184
|
+
Number of bins used in the discretization for each value in names.
|
123
185
|
"""
|
124
186
|
# likely cached
|
125
|
-
hist_counts
|
187
|
+
hist_counts = get_counts(data, names, continuous_factor_bincounts, subset_mask, hist_cache)
|
126
188
|
num_bins = np.empty(len(hist_counts))
|
127
189
|
for idx, cnts in enumerate(hist_counts.values()):
|
128
|
-
num_bins[idx] =
|
190
|
+
num_bins[idx] = np.size(cnts, 0)
|
129
191
|
|
130
192
|
return num_bins
|
131
193
|
|
132
194
|
|
133
|
-
def infer_categorical(arr: NDArray[
|
195
|
+
def infer_categorical(arr: NDArray[np.float64], threshold: float = 0.2) -> NDArray[np.bool_]:
|
134
196
|
"""
|
135
197
|
Compute fraction of feature values that are unique --- intended to be used
|
136
198
|
for inferring whether variables are categorical.
|
@@ -147,14 +209,28 @@ def infer_categorical(arr: NDArray[Any], threshold: float = 0.2) -> NDArray[Any]
|
|
147
209
|
|
148
210
|
def preprocess_metadata(
|
149
211
|
class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], cat_thresh: float = 0.2
|
150
|
-
) -> tuple[NDArray[Any], list[str], list[bool]]:
|
212
|
+
) -> tuple[NDArray[Any], list[str], list[bool], NDArray[np.str_]]:
|
213
|
+
"""
|
214
|
+
Formats metadata by organizing factor names, converting labels to numeric values,
|
215
|
+
adds class labels to the dataset structure, and marks which factors are categorical.
|
216
|
+
"""
|
217
|
+
# if class_labels is not numeric
|
218
|
+
class_array = to_numpy(class_labels)
|
219
|
+
if not np.issubdtype(class_array.dtype, np.integer):
|
220
|
+
unique_classes, numerical_labels = np.unique(class_array, return_inverse=True)
|
221
|
+
else:
|
222
|
+
numerical_labels = np.asarray(class_array, dtype=np.intp)
|
223
|
+
unique_classes = np.unique(class_array)
|
224
|
+
|
151
225
|
# convert class_labels and dict of lists to matrix of metadata values
|
152
|
-
preprocessed_metadata = {
|
226
|
+
preprocessed_metadata = {CLASS_LABEL: numerical_labels}
|
153
227
|
|
154
228
|
# map columns of dict that are not numeric (e.g. string) to numeric values
|
155
229
|
# that mutual information and diversity functions can accommodate. Each
|
156
230
|
# unique string receives a unique integer value.
|
157
231
|
for k, v in metadata.items():
|
232
|
+
if k == CLASS_LABEL:
|
233
|
+
continue
|
158
234
|
# if not numeric
|
159
235
|
v = to_numpy(v)
|
160
236
|
if not np.issubdtype(v.dtype, np.number):
|
@@ -165,45 +241,56 @@ def preprocess_metadata(
|
|
165
241
|
|
166
242
|
data = np.stack(list(preprocessed_metadata.values()), axis=-1)
|
167
243
|
names = list(preprocessed_metadata.keys())
|
168
|
-
is_categorical = [
|
244
|
+
is_categorical = [
|
245
|
+
var == CLASS_LABEL or infer_categorical(preprocessed_metadata[var].astype(np.float64), cat_thresh)[0]
|
246
|
+
for var in names
|
247
|
+
]
|
169
248
|
|
170
|
-
return data, names, is_categorical
|
249
|
+
return data, names, is_categorical, unique_classes
|
171
250
|
|
172
251
|
|
173
252
|
def heatmap(
|
174
|
-
data:
|
175
|
-
row_labels:
|
176
|
-
col_labels:
|
253
|
+
data: ArrayLike,
|
254
|
+
row_labels: list[str] | ArrayLike,
|
255
|
+
col_labels: list[str] | ArrayLike,
|
177
256
|
xlabel: str = "",
|
178
257
|
ylabel: str = "",
|
179
258
|
cbarlabel: str = "",
|
180
|
-
) ->
|
259
|
+
) -> Figure:
|
181
260
|
"""
|
182
261
|
Plots a formatted heatmap
|
183
262
|
|
184
263
|
Parameters
|
185
264
|
----------
|
186
|
-
data: NDArray
|
265
|
+
data : NDArray
|
187
266
|
Array containing numerical values for factors to plot
|
188
|
-
row_labels:
|
189
|
-
Array containing the labels for rows in the histogram
|
190
|
-
col_labels:
|
191
|
-
Array containing the labels for columns in the histogram
|
192
|
-
xlabel: str, default ""
|
267
|
+
row_labels : ArrayLike
|
268
|
+
List/Array containing the labels for rows in the histogram
|
269
|
+
col_labels : ArrayLike
|
270
|
+
List/Array containing the labels for columns in the histogram
|
271
|
+
xlabel : str, default ""
|
193
272
|
X-axis label
|
194
|
-
ylabel: str, default ""
|
273
|
+
ylabel : str, default ""
|
195
274
|
Y-axis label
|
196
|
-
cbarlabel: str, default ""
|
275
|
+
cbarlabel : str, default ""
|
197
276
|
Label for the colorbar
|
198
277
|
|
278
|
+
Returns
|
279
|
+
-------
|
280
|
+
matplotlib.figure.Figure
|
281
|
+
Formatted heatmap
|
199
282
|
"""
|
200
|
-
import matplotlib
|
201
283
|
import matplotlib.pyplot as plt
|
284
|
+
from matplotlib.ticker import FuncFormatter
|
285
|
+
|
286
|
+
np_data = to_numpy(data)
|
287
|
+
rows = row_labels if isinstance(row_labels, list) else to_numpy(row_labels)
|
288
|
+
cols = col_labels if isinstance(col_labels, list) else to_numpy(col_labels)
|
202
289
|
|
203
290
|
fig, ax = plt.subplots(figsize=(10, 10))
|
204
291
|
|
205
292
|
# Plot the heatmap
|
206
|
-
im = ax.imshow(
|
293
|
+
im = ax.imshow(np_data, vmin=0, vmax=1.0)
|
207
294
|
|
208
295
|
# Create colorbar
|
209
296
|
cbar = fig.colorbar(im, shrink=0.5)
|
@@ -212,8 +299,8 @@ def heatmap(
|
|
212
299
|
cbar.set_label(cbarlabel, loc="center")
|
213
300
|
|
214
301
|
# Show all ticks and label them with the respective list entries.
|
215
|
-
ax.set_xticks(np.arange(
|
216
|
-
ax.set_yticks(np.arange(
|
302
|
+
ax.set_xticks(np.arange(np_data.shape[1]), labels=cols)
|
303
|
+
ax.set_yticks(np.arange(np_data.shape[0]), labels=rows)
|
217
304
|
|
218
305
|
ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)
|
219
306
|
# Rotate the tick labels and set their alignment.
|
@@ -222,8 +309,8 @@ def heatmap(
|
|
222
309
|
# Turn spines off and create white grid.
|
223
310
|
ax.spines[:].set_visible(False)
|
224
311
|
|
225
|
-
ax.set_xticks(np.arange(
|
226
|
-
ax.set_yticks(np.arange(
|
312
|
+
ax.set_xticks(np.arange(np_data.shape[1] + 1) - 0.5, minor=True)
|
313
|
+
ax.set_yticks(np.arange(np_data.shape[0] + 1) - 0.5, minor=True)
|
227
314
|
ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
|
228
315
|
ax.tick_params(which="minor", bottom=False, left=False)
|
229
316
|
|
@@ -232,7 +319,7 @@ def heatmap(
|
|
232
319
|
if ylabel:
|
233
320
|
ax.set_ylabel(ylabel)
|
234
321
|
|
235
|
-
valfmt =
|
322
|
+
valfmt = FuncFormatter(format_text)
|
236
323
|
|
237
324
|
# Normalize the threshold to the images color range.
|
238
325
|
threshold = im.norm(1.0) / 2.0
|
@@ -245,14 +332,14 @@ def heatmap(
|
|
245
332
|
# Change the text's color depending on the data.
|
246
333
|
textcolors = ("white", "black")
|
247
334
|
texts = []
|
248
|
-
for i in range(
|
249
|
-
for j in range(
|
250
|
-
kw.update(color=textcolors[int(im.norm(
|
251
|
-
text = im.axes.text(j, i, valfmt(
|
335
|
+
for i in range(np_data.shape[0]):
|
336
|
+
for j in range(np_data.shape[1]):
|
337
|
+
kw.update(color=textcolors[int(im.norm(np_data[i, j]) > threshold)])
|
338
|
+
text = im.axes.text(j, i, valfmt(np_data[i, j], None), **kw) # type: ignore
|
252
339
|
texts.append(text)
|
253
340
|
|
254
341
|
fig.tight_layout()
|
255
|
-
|
342
|
+
return fig
|
256
343
|
|
257
344
|
|
258
345
|
# Function to define how the text is displayed in the heatmap
|
@@ -262,7 +349,7 @@ def format_text(*args: str) -> str:
|
|
262
349
|
|
263
350
|
Parameters
|
264
351
|
----------
|
265
|
-
*args:
|
352
|
+
*args : tuple[str, str]
|
266
353
|
Text to be formatted. Second element is ignored, but is a
|
267
354
|
mandatory pass-through argument as per matplotlib.ticket.FuncFormatter
|
268
355
|
|
@@ -273,3 +360,81 @@ def format_text(*args: str) -> str:
|
|
273
360
|
"""
|
274
361
|
x = args[0]
|
275
362
|
return f"{x:.2f}".replace("0.00", "0").replace("0.", ".").replace("nan", "")
|
363
|
+
|
364
|
+
|
365
|
+
def diversity_bar_plot(labels: NDArray[Any], bar_heights: NDArray[Any]) -> Figure:
|
366
|
+
"""
|
367
|
+
Plots a formatted bar plot
|
368
|
+
|
369
|
+
Parameters
|
370
|
+
----------
|
371
|
+
labels : NDArray
|
372
|
+
Array containing the labels for each bar
|
373
|
+
bar_heights : NDArray
|
374
|
+
Array containing the values for each bar
|
375
|
+
|
376
|
+
Returns
|
377
|
+
-------
|
378
|
+
matplotlib.figure.Figure
|
379
|
+
Bar plot figure
|
380
|
+
"""
|
381
|
+
import matplotlib.pyplot as plt
|
382
|
+
|
383
|
+
fig, ax = plt.subplots(figsize=(10, 10))
|
384
|
+
|
385
|
+
ax.bar(labels, bar_heights)
|
386
|
+
ax.set_xlabel("Factors")
|
387
|
+
|
388
|
+
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
|
389
|
+
|
390
|
+
fig.tight_layout()
|
391
|
+
return fig
|
392
|
+
|
393
|
+
|
394
|
+
def coverage_plot(images: NDArray[Any], num_images: int) -> Figure:
|
395
|
+
"""
|
396
|
+
Creates a single plot of all of the provided images
|
397
|
+
|
398
|
+
Parameters
|
399
|
+
----------
|
400
|
+
images : NDArray
|
401
|
+
Array containing only the desired images to plot
|
402
|
+
|
403
|
+
Returns
|
404
|
+
-------
|
405
|
+
matplotlib.figure.Figure
|
406
|
+
Plot of all provided images
|
407
|
+
"""
|
408
|
+
import matplotlib.pyplot as plt
|
409
|
+
|
410
|
+
num_images = min(num_images, len(images))
|
411
|
+
|
412
|
+
if images.ndim == 4:
|
413
|
+
images = np.moveaxis(images, 1, -1)
|
414
|
+
elif images.ndim == 3:
|
415
|
+
images = np.repeat(images[:, :, :, np.newaxis], 3, axis=-1)
|
416
|
+
else:
|
417
|
+
raise ValueError(
|
418
|
+
f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {images.ndim}-dimensional set of images."
|
419
|
+
)
|
420
|
+
|
421
|
+
rows = int(np.ceil(num_images / 3))
|
422
|
+
fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
|
423
|
+
|
424
|
+
if rows == 1:
|
425
|
+
for j in range(3):
|
426
|
+
if j >= len(images):
|
427
|
+
continue
|
428
|
+
axs[j].imshow(images[j])
|
429
|
+
axs[j].axis("off")
|
430
|
+
else:
|
431
|
+
for i in range(rows):
|
432
|
+
for j in range(3):
|
433
|
+
i_j = i * 3 + j
|
434
|
+
if i_j >= len(images):
|
435
|
+
continue
|
436
|
+
axs[i, j].imshow(images[i_j])
|
437
|
+
axs[i, j].axis("off")
|
438
|
+
|
439
|
+
fig.tight_layout()
|
440
|
+
return fig
|