dataeval 0.73.0__py3-none-any.whl → 0.73.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dataeval/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.73.0"
+ __version__ = "0.73.1"
 
  from importlib.util import find_spec
 
@@ -12,12 +12,12 @@ from dataeval import detectors, metrics # noqa: E402
  __all__ = ["detectors", "metrics"]
 
- if _IS_TORCH_AVAILABLE: # pragma: no cover
+ if _IS_TORCH_AVAILABLE:
      from dataeval import workflows
 
      __all__ += ["workflows"]
 
- if _IS_TENSORFLOW_AVAILABLE or _IS_TORCH_AVAILABLE: # pragma: no cover
+ if _IS_TENSORFLOW_AVAILABLE or _IS_TORCH_AVAILABLE:
      from dataeval import utils
 
      __all__ += ["utils"]
@@ -7,7 +7,7 @@ from dataeval.detectors import drift, linters
  __all__ = ["drift", "linters"]
 
- if _IS_TENSORFLOW_AVAILABLE: # pragma: no cover
+ if _IS_TENSORFLOW_AVAILABLE:
      from dataeval.detectors import ood
 
      __all__ += ["ood"]
@@ -10,7 +10,7 @@ from dataeval.detectors.drift.ks import DriftKS
  __all__ = ["DriftCVM", "DriftKS", "DriftOutput", "updates"]
 
- if _IS_TORCH_AVAILABLE: # pragma: no cover
+ if _IS_TORCH_AVAILABLE:
      from dataeval.detectors.drift.mmd import DriftMMD, DriftMMDOutput
      from dataeval.detectors.drift.torch import preprocess_drift
      from dataeval.detectors.drift.uncertainty import DriftUncertainty
@@ -18,7 +18,7 @@ from typing import Any, Callable, Literal, TypeVar
  import numpy as np
  from numpy.typing import ArrayLike, NDArray
 
- from dataeval.interop import as_numpy, to_numpy
+ from dataeval.interop import as_numpy
  from dataeval.output import OutputMetadata, set_metadata
 
  R = TypeVar("R")
@@ -196,7 +196,7 @@ class BaseDrift:
          if correction not in ["bonferroni", "fdr"]:
              raise ValueError("`correction` must be `bonferroni` or `fdr`.")
 
-         self._x_ref = to_numpy(x_ref)
+         self._x_ref = as_numpy(x_ref)
          self.x_ref_preprocessed: bool = x_ref_preprocessed
 
          # Other attributes
@@ -480,7 +480,7 @@ class Clusterer:
                  samples = self.clusters[level][cluster_id].samples
                  if len(samples) >= self._min_num_samples_per_cluster:
                      duplicates_std.append(self.clusters[level][cluster_id].dist_std)
-         diag_mask = np.ones_like(self._sqdmat, dtype=bool)
+         diag_mask = np.ones_like(self._sqdmat, dtype=np.bool_)
          np.fill_diagonal(diag_mask, 0)
          diag_mask = np.triu(diag_mask)
 
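
The only change here is the dtype spelling (bool becomes np.bool_). A small self-contained sketch of what this mask construction does, using a stand-in array for self._sqdmat:

import numpy as np

# Stand-in for self._sqdmat: a square matrix of pairwise squared distances.
sqdmat = np.arange(16, dtype=np.float64).reshape(4, 4)

diag_mask = np.ones_like(sqdmat, dtype=np.bool_)  # start with everything selected
np.fill_diagonal(diag_mask, 0)                    # drop self-distances
diag_mask = np.triu(diag_mask)                    # keep each unordered pair once

print(sqdmat[diag_mask])  # the 6 unique off-diagonal entries of the 4x4 matrix
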
@@ -4,7 +4,7 @@ Out-of-distribution (OOD)` detectors identify data that is different from the da
 
  from dataeval import _IS_TENSORFLOW_AVAILABLE
 
- if _IS_TENSORFLOW_AVAILABLE: # pragma: no cover
+ if _IS_TENSORFLOW_AVAILABLE:
      from dataeval.detectors.ood.ae import OOD_AE
      from dataeval.detectors.ood.aegmm import OOD_AEGMM
      from dataeval.detectors.ood.base import OODOutput, OODScoreOutput
@@ -11,7 +11,7 @@ import numpy as np
  from numpy.typing import ArrayLike, NDArray
  from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
 
- from dataeval.metrics.bias.metadata import entropy, heatmap, preprocess_metadata
+ from dataeval.metrics.bias.metadata import CLASS_LABEL, entropy, heatmap, preprocess_metadata
  from dataeval.output import OutputMetadata, set_metadata
 
  with contextlib.suppress(ImportError):
@@ -31,9 +31,9 @@ class BalanceOutput(OutputMetadata):
          Estimate of inter/intra-factor mutual information
      classwise : NDArray[np.float64]
          Estimate of mutual information between metadata factors and individual class labels
-     class_list: NDArray
+     class_list : NDArray
          Array of the class labels present in the dataset
-     metadata_names: list[str]
+     metadata_names : list[str]
          Names of each metadata factor
      """
 
@@ -54,9 +54,9 @@ class BalanceOutput(OutputMetadata):
 
          Parameters
          ----------
-         row_labels : ArrayLike | None, default None
+         row_labels : ArrayLike or None, default None
              List/Array containing the labels for rows in the histogram
-         col_labels : ArrayLike | None, default None
+         col_labels : ArrayLike or None, default None
              List/Array containing the labels for columns in the histogram
          plot_classwise : bool, default False
             Whether to plot per-class balance instead of global balance
@@ -116,19 +116,29 @@ def validate_num_neighbors(num_neighbors: int) -> int:
 
 
  @set_metadata("dataeval.metrics")
- def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neighbors: int = 5) -> BalanceOutput:
+ def balance(
+     class_labels: ArrayLike,
+     metadata: Mapping[str, ArrayLike],
+     num_neighbors: int = 5,
+     continuous_factor_bincounts: Mapping[str, int] | None = None,
+ ) -> BalanceOutput:
      """
      Mutual information (MI) between factors (class label, metadata, label/image properties)
 
      Parameters
      ----------
-     class_labels: ArrayLike
+     class_labels : ArrayLike
          List of class labels for each image
-     metadata: Mapping[str, ArrayLike]
+     metadata : Mapping[str, ArrayLike]
          Dict of lists of metadata factors for each image
-     num_neighbors: int, default 5
+     num_neighbors : int, default 5
          Number of nearest neighbors to use for computing MI between discrete
          and continuous variables.
+     continuous_factor_bincounts : Mapping[str, int] or None, default None
+         The factors in metadata that have continuous values and the array of bin counts to
+         discretize values into. All factors are treated as having discrete values unless they
+         are specified as keys in this dictionary. Each element of this array must occur as a key
+         in metadata.
 
      Returns
      -------
@@ -148,7 +158,7 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
      -------
      Return balance (mutual information) of factors with class_labels
 
-     >>> bal = balance(class_labels, metadata)
+     >>> bal = balance(class_labels, metadata, continuous_factor_bincounts=continuous_factor_bincounts)
      >>> bal.balance
      array([0.99999822, 0.13363788, 0.04505382, 0.02994455])
 
@@ -165,6 +175,7 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
      array([[0.99999822, 0.13363788, 0.        , 0.        ],
             [0.99999822, 0.13363788, 0.        , 0.        ]])
 
+
      See Also
      --------
      sklearn.feature_selection.mutual_info_classif
@@ -178,9 +189,9 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
      mi[:] = np.nan
 
      for idx in range(num_factors):
-         tgt = data[:, idx].astype(int)
+         tgt = data[:, idx].astype(np.intp)
 
-         if is_categorical[idx]:
+         if continuous_factor_bincounts and names[idx] not in continuous_factor_bincounts:
              mi[idx, :] = mutual_info_classif(
                  data,
                  tgt,
@@ -197,7 +208,7 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
                  random_state=0,
              )
 
-     ent_all = entropy(data, names, is_categorical, normalized=False)
+     ent_all = entropy(data, names, continuous_factor_bincounts, normalized=False)
      norm_factor = 0.5 * np.add.outer(ent_all, ent_all) + 1e-6
      # in principle MI should be symmetric, but it is not in practice.
      nmi = 0.5 * (mi + mi.T) / norm_factor
@@ -205,7 +216,7 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
      factors = nmi[1:, 1:]
 
      # unique class labels
-     class_idx = names.index("class_label")
+     class_idx = names.index(CLASS_LABEL)
      u_cls = np.unique(data[:, class_idx])
      num_classes = len(u_cls)
 
@@ -214,12 +225,11 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
      classwise_mi[:] = np.nan
 
      # categorical variables, excluding class label
-     cat_mask = np.concatenate((is_categorical[:class_idx], is_categorical[(class_idx + 1) :]), axis=0).astype(int)
+     cat_mask = np.concatenate((is_categorical[:class_idx], is_categorical[(class_idx + 1) :]), axis=0).astype(np.intp)
 
-     tgt_bin = np.stack([data[:, class_idx] == cls for cls in u_cls]).T.astype(int)
-     ent_tgt_bin = entropy(
-         tgt_bin, names=[str(idx) for idx in range(num_classes)], is_categorical=[True for idx in range(num_classes)]
-     )
+     tgt_bin = np.stack([data[:, class_idx] == cls for cls in u_cls]).T.astype(np.intp)
+     names = [str(idx) for idx in range(num_classes)]
+     ent_tgt_bin = entropy(tgt_bin, names, continuous_factor_bincounts)
 
      # classification MI for discrete/categorical features
      for idx in range(num_classes):
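
A hypothetical call showing the new keyword in use. The labels and metadata below are made up for illustration, the import path assumes balance is re-exported from dataeval.metrics.bias as in earlier releases, and the resulting values will differ from the docstring example above:

import numpy as np
from dataeval.metrics.bias import balance

rng = np.random.default_rng(0)
class_labels = rng.integers(0, 2, size=100)
metadata = {
    "altitude": rng.normal(1000.0, 100.0, size=100),      # continuous factor
    "weather": rng.choice(["sunny", "rainy"], size=100),   # discrete factor
}

# Only "altitude" is declared continuous; it is discretized into 5 bins.
bal = balance(class_labels, metadata, continuous_factor_bincounts={"altitude": 5})
print(bal.balance)    # mutual information of each factor with the class label
print(bal.classwise)  # per-class mutual information estimates
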
@@ -5,7 +5,7 @@ __all__ = ["CoverageOutput", "coverage"]
  import contextlib
  import math
  from dataclasses import dataclass
- from typing import Any, Literal
+ from typing import Literal
 
  import numpy as np
  from numpy.typing import ArrayLike, NDArray
@@ -27,9 +27,9 @@ class CoverageOutput(OutputMetadata):
 
      Attributes
      ----------
-     indices : NDArray
+     indices : NDArray[np.intp]
          Array of uncovered indices
-     radii : NDArray
+     radii : NDArray[np.float64]
          Array of critical value radii
      critical_value : float
          Radius for :term:`coverage<Coverage>`
@@ -39,11 +39,7 @@ class CoverageOutput(OutputMetadata):
      radii: NDArray[np.float64]
      critical_value: float
 
-     def plot(
-         self,
-         images: NDArray[Any],
-         top_k: int = 6,
-     ) -> Figure:
+     def plot(self, images: ArrayLike, top_k: int = 6) -> Figure:
          """
          Plot the top k images together for visualization
 
@@ -53,6 +49,10 @@ class CoverageOutput(OutputMetadata):
              Original images (not embeddings) in (N, C, H, W) or (N, H, W) format
          top_k : int, default 6
              Number of images to plot (plotting assumes groups of 3)
+
+         Returns
+         -------
+         matplotlib.figure.Figure
          """
          # Determine which images to plot
          highest_uncovered_indices = self.indices[:top_k]
@@ -82,12 +82,12 @@ def coverage(
      embeddings : ArrayLike, shape - (N, P)
          A dataset in an ArrayLike format.
          Function expects the data to have 2 dimensions, N number of observations in a P-dimesionial space.
-     radius_type : Literal["adaptive", "naive"], default "adaptive"
+     radius_type : {"adaptive", "naive"}, default "adaptive"
          The function used to determine radius.
-     k: int, default 20
+     k : int, default 20
          Number of observations required in order to be covered.
          [1] suggests that a minimum of 20-50 samples is necessary.
-     percent: float, default 0.01
+     percent : float, default 0.01
          Percent of observations to be considered uncovered. Only applies to adaptive radius.
 
      Returns
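
A hedged usage sketch for the documented signature; the random embeddings are placeholders and the import path assumes coverage is re-exported from dataeval.metrics.bias:

import numpy as np
from dataeval.metrics.bias import coverage

embeddings = np.random.default_rng(0).random((250, 16))  # (N, P) observations
cvg = coverage(embeddings, radius_type="adaptive", k=20, percent=0.01)
print(cvg.indices[:5])     # observations flagged as uncovered
print(cvg.critical_value)  # radius used for the coverage decision
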
@@ -10,6 +10,7 @@ import numpy as np
  from numpy.typing import ArrayLike, NDArray
 
  from dataeval.metrics.bias.metadata import (
+     CLASS_LABEL,
      diversity_bar_plot,
      entropy,
      get_counts,
@@ -35,9 +36,9 @@ class DiversityOutput(OutputMetadata):
          :term:`Diversity` index for classes and factors
      classwise : NDArray[np.float64]
          Classwise diversity index [n_class x n_factor]
-     class_list: NDArray[np.int64]
+     class_list : NDArray[np.int64]
          Class labels for each value in the dataset
-     metadata_names: list[str]
+     metadata_names : list[str]
          Names of each metadata factor
      """
 
@@ -45,12 +46,11 @@ class DiversityOutput(OutputMetadata):
      classwise: NDArray[np.float64]
      class_list: NDArray[Any]
      metadata_names: list[str]
-     method: Literal["shannon", "simpson"]
 
      def plot(
          self,
-         row_labels: list[Any] | NDArray[Any] | None = None,
-         col_labels: list[Any] | NDArray[Any] | None = None,
+         row_labels: ArrayLike | list[Any] | None = None,
+         col_labels: ArrayLike | list[Any] | None = None,
          plot_classwise: bool = False,
      ) -> Figure:
          """
@@ -58,9 +58,9 @@ class DiversityOutput(OutputMetadata):
 
          Parameters
          ----------
-         row_labels : ArrayLike | None, default None
+         row_labels : ArrayLike or None, default None
              List/Array containing the labels for rows in the histogram
-         col_labels : ArrayLike | None, default None
+         col_labels : ArrayLike or None, default None
              List/Array containing the labels for columns in the histogram
          plot_classwise : bool, default False
              Whether to plot per-class balance instead of global balance
@@ -77,7 +77,7 @@ class DiversityOutput(OutputMetadata):
                  col_labels,
                  xlabel="Factors",
                  ylabel="Class",
-                 cbarlabel=f"Normalized {self.method.title()} Index",
+                 cbarlabel=f"Normalized {self.meta()['arguments']['method'].title()} Index",
              )
 
          else:
@@ -92,7 +92,7 @@ class DiversityOutput(OutputMetadata):
  def diversity_shannon(
      data: NDArray[Any],
      names: list[str],
-     is_categorical: list[bool],
+     continuous_factor_bincounts: Mapping[str, int] | None = None,
      subset_mask: NDArray[np.bool_] | None = None,
  ) -> NDArray[np.float64]:
      """
@@ -106,14 +106,16 @@ def diversity_shannon(
 
      Parameters
      ----------
-     data: NDArray
+     data : NDArray
          Array containing numerical values for metadata factors
-     names: list[str]
+     names : list[str]
          Names of metadata factors -- keys of the metadata dictionary
-     is_categorical: list[bool]
-         List of flags to identify whether variables are categorical (True) or
-         continuous (False)
-     subset_mask: NDArray[np.bool_] | None
+     continuous_factor_bincounts : Mapping[str, int] or None, default None
+         The factors in names that have continuous values and the array of bin counts to
+         discretize values into. All factors are treated as having discrete values unless they
+         are specified as keys in this dictionary. Each element of this array must occur as a key
+         in names.
+     subset_mask : NDArray[np.bool_] or None, default None
          Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
 
      Note
@@ -122,18 +124,32 @@ def diversity_shannon(
 
      Returns
      -------
-     diversity_index: NDArray
+     diversity_index : NDArray[np.float64]
          Diversity index per column of X
 
      See Also
      --------
      numpy.histogram
      """
+     hist_cache = {}
 
      # entropy computed using global auto bins so that we can properly normalize
-     ent_unnormalized = entropy(data, names, is_categorical, normalized=False, subset_mask=subset_mask)
+     ent_unnormalized = entropy(
+         data,
+         names,
+         continuous_factor_bincounts,
+         normalized=False,
+         subset_mask=subset_mask,
+         hist_cache=hist_cache,
+     )
      # normalize by global counts rather than classwise counts
-     num_bins = get_num_bins(data, names, is_categorical=is_categorical, subset_mask=subset_mask)
+     num_bins = get_num_bins(
+         data,
+         names,
+         continuous_factor_bincounts=continuous_factor_bincounts,
+         subset_mask=subset_mask,
+         hist_cache=hist_cache,
+     )
      ent_norm = np.empty(ent_unnormalized.shape)
      ent_norm[num_bins != 1] = ent_unnormalized[num_bins != 1] / np.log(num_bins[num_bins != 1])
      ent_norm[num_bins == 1] = 0
@@ -143,7 +159,7 @@ def diversity_shannon(
  def diversity_simpson(
      data: NDArray[Any],
      names: list[str],
-     is_categorical: list[bool],
+     continuous_factor_bincounts: Mapping[str, int] | None = None,
      subset_mask: NDArray[np.bool_] | None = None,
  ) -> NDArray[np.float64]:
      """
@@ -157,14 +173,16 @@ def diversity_simpson(
 
      Parameters
      ----------
-     data: NDArray
+     data : NDArray
          Array containing numerical values for metadata factors
-     names: list[str]
+     names : list[str]
          Names of metadata factors -- keys of the metadata dictionary
-     is_categorical: list[bool]
-         List of flags to identify whether variables are categorical (True) or
-         continuous (False)
-     subset_mask: NDArray[np.bool_] | None
+     continuous_factor_bincounts : Mapping[str, int] or None, default None
+         The factors in names that have continuous values and the array of bin counts to
+         discretize values into. All factors are treated as having discrete values unless they
+         are specified as keys in this dictionary. Each element of this array must occur as a key
+         in names.
+     subset_mask : NDArray[np.bool_] or None, default None
          Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
 
      Note
@@ -175,35 +193,39 @@ def diversity_simpson(
 
      Returns
      -------
-     NDArray
+     diversity_index : NDArray[np.float64]
          Diversity index per column of X
 
      See Also
      --------
      numpy.histogram
      """
+     hist_cache = {}
 
-     hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)
+     hist_counts = get_counts(data, names, continuous_factor_bincounts, subset_mask, hist_cache=hist_cache)
      # normalize by global counts, not classwise counts
-     num_bins = get_num_bins(data, names, is_categorical)
+     num_bins = get_num_bins(data, names, continuous_factor_bincounts, hist_cache=hist_cache)
 
      ev_index = np.empty(len(names))
      # loop over columns for convenience
      for col, cnts in enumerate(hist_counts.values()):
          # relative frequencies
-         p_i = cnts / cnts.sum()
+         p_i = cnts / np.sum(cnts)
          # inverse Simpson index normalized by (number of bins)
-         s_0 = 1 / np.sum(p_i**2) / num_bins[col]
+         s_0 = 1 / np.sum(p_i**2)  # / num_bins[col]
          if num_bins[col] == 1:
              ev_index[col] = 0
          else:
-             ev_index[col] = (s_0 * num_bins[col] - 1) / (num_bins[col] - 1)
+             ev_index[col] = (s_0 - 1) / (num_bins[col] - 1)
      return ev_index
 
 
  @set_metadata()
  def diversity(
-     class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], method: Literal["shannon", "simpson"] = "simpson"
+     class_labels: ArrayLike,
+     metadata: Mapping[str, ArrayLike],
+     continuous_factor_bincounts: Mapping[str, int] | None = None,
+     method: Literal["simpson", "shannon"] = "simpson",
  ) -> DiversityOutput:
      """
      Compute :term:`diversity<Diversity>` and classwise diversity for discrete/categorical variables and,
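
The normalization of the Simpson index moved out of s_0 and into the evenness expression above. A standalone sketch of the updated arithmetic on one set of bin counts:

import numpy as np

cnts = np.array([10, 5, 5])   # histogram counts for a single factor
p_i = cnts / np.sum(cnts)     # relative frequencies
s_0 = 1 / np.sum(p_i**2)      # inverse Simpson index, ranges from 1 to num_bins
num_bins = len(cnts)

# Evenness rescales s_0 to [0, 1]; a single-bin factor is defined as 0.
ev = (s_0 - 1) / (num_bins - 1) if num_bins > 1 else 0.0
print(round(ev, 4))  # 0.8333 for these counts
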
@@ -216,11 +238,16 @@ def diversity(
 
      Parameters
      ----------
-     class_labels: ArrayLike
+     class_labels : ArrayLike
          List of class labels for each image
-     metadata: Mapping[str, ArrayLike]
+     metadata : Mapping[str, ArrayLike]
          Dict of list of metadata factors for each image
-     method: Literal["shannon", "simpson"], default "simpson"
+     continuous_factor_bincounts : Mapping[str, int] or None, default None
+         The factors in metadata that have continuous values and the array of bin counts to
+         discretize values into. All factors are treated as having discrete values unless they
+         are specified as keys in this dictionary. Each element of this array must occur as a key
+         in metadata.
+     method : {"simpson", "shannon"}, default "simpson"
          Indicates which diversity index should be computed
 
      Note
@@ -239,40 +266,42 @@ def diversity(
      -------
      Compute Simpson diversity index of metadata and class labels
 
-     >>> div_simp = diversity(class_labels, metadata, method="simpson")
+     >>> div_simp = diversity(class_labels, metadata, continuous_factor_bincounts, method="simpson")
      >>> div_simp.diversity_index
-     array([0.18103448, 0.18103448, 0.88636364])
+     array([0.72413793, 0.72413793, 0.88636364])
 
      >>> div_simp.classwise
-     array([[0.17241379, 0.39473684],
-            [0.2       , 0.2       ]])
+     array([[0.68965517, 0.69230769],
+            [0.8       , 1.        ]])
 
      Compute Shannon diversity index of metadata and class labels
 
-     >>> div_shan = diversity(class_labels, metadata, method="shannon")
+     >>> div_shan = diversity(class_labels, metadata, continuous_factor_bincounts, method="shannon")
      >>> div_shan.diversity_index
-     array([0.37955133, 0.37955133, 0.96748876])
+     array([0.8812909 , 0.8812909 , 0.96748876])
 
      >>> div_shan.classwise
-     array([[0.43156028, 0.83224889],
-            [0.57938016, 0.57938016]])
+     array([[0.86312057, 0.91651644],
+            [0.91829583, 1.        ]])
 
      See Also
      --------
      numpy.histogram
      """
      diversity_fn = get_method({"simpson": diversity_simpson, "shannon": diversity_shannon}, method)
-     data, names, is_categorical, unique_labels = preprocess_metadata(class_labels, metadata)
-     diversity_index = diversity_fn(data, names, is_categorical, None).astype(np.float64)
+     data, names, _, unique_labels = preprocess_metadata(class_labels, metadata)
+     diversity_index = diversity_fn(data, names, continuous_factor_bincounts)
+
+     class_idx = names.index(CLASS_LABEL)
+     class_lbl = data[:, class_idx]
 
-     class_idx = names.index("class_label")
-     u_classes = np.unique(data[:, class_idx])
+     u_classes = np.unique(class_lbl)
      num_factors = len(names)
      diversity = np.empty((len(u_classes), num_factors))
      diversity[:] = np.nan
      for idx, cls in enumerate(u_classes):
-         subset_mask = data[:, class_idx] == cls
-         diversity[idx, :] = diversity_fn(data, names, is_categorical, subset_mask)
+         subset_mask = class_lbl == cls
+         diversity[idx, :] = diversity_fn(data, names, continuous_factor_bincounts, subset_mask)
      div_no_class = np.concatenate((diversity[:, :class_idx], diversity[:, (class_idx + 1) :]), axis=1)
 
-     return DiversityOutput(diversity_index, div_no_class, unique_labels, list(metadata.keys()), method)
+     return DiversityOutput(diversity_index, div_no_class, unique_labels, list(metadata.keys()))
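
A hypothetical end-to-end call matching the new signature; the data is fabricated for illustration and the import path assumes diversity is re-exported from dataeval.metrics.bias:

import numpy as np
from dataeval.metrics.bias import diversity

rng = np.random.default_rng(0)
class_labels = rng.integers(0, 2, size=100)
metadata = {
    "altitude": rng.normal(1000.0, 100.0, size=100),
    "weather": rng.choice(["sunny", "rainy"], size=100),
}
continuous_factor_bincounts = {"altitude": 5}

div = diversity(class_labels, metadata, continuous_factor_bincounts, method="simpson")
print(div.diversity_index)  # one index per factor, including the class label
print(div.classwise)        # per-class, per-factor diversity (class label excluded)
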
@@ -18,52 +18,80 @@ CLASS_LABEL = "class_label"
 
 
  def get_counts(
-     data: NDArray[np.int_], names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
- ) -> tuple[dict[str, NDArray[np.int_]], dict[str, NDArray[np.int_]]]:
+     data: NDArray[Any],
+     names: list[str],
+     continuous_factor_bincounts: Mapping[str, int] | None = None,
+     subset_mask: NDArray[np.bool_] | None = None,
+     hist_cache: dict[str, NDArray[np.intp]] | None = None,
+ ) -> dict[str, NDArray[np.intp]]:
      """
      Initialize dictionary of histogram counts --- treat categorical values
      as histogram bins.
 
      Parameters
      ----------
-     subset_mask: NDArray[np.bool_] | None
+     data : NDArray
+         Array containing numerical values for metadata factors
+     names : list[str]
+         Names of metadata factors -- keys of the metadata dictionary
+     continuous_factor_bincounts : Mapping[str, int] or None, default None
+         The factors in names that have continuous values and the array of bin counts to
+         discretize values into. All factors are treated as having discrete values unless they
+         are specified as keys in this dictionary. Each element of this array must occur as a key
+         in names.
+         Names of metadata factors -- keys of the metadata dictionary
+     subset_mask : NDArray[np.bool_] or None, default None
          Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
+     hist_cache : dict[str, NDArray[np.intp]] or None, default None
+         Optional cache to store histogram counts
 
      Returns
      -------
-     counts: Dict
+     dict[str, NDArray[np.intp]]
          histogram counts per metadata factor in `factors`. Each
          factor will have a different number of bins. Counts get reused
          across metrics, so hist_counts are cached but only if computed
          globally, i.e. without masked samples.
      """
 
-     hist_counts, hist_bins = {}, {}
-     # np.where needed to satisfy linter
-     mask = np.where(subset_mask if subset_mask is not None else np.ones(data.shape[0], dtype=bool))
+     hist_counts = {}
+
+     mask = subset_mask if subset_mask is not None else np.ones(data.shape[0], dtype=np.bool_)
 
      for cdx, fn in enumerate(names):
-         # linter doesn't like double indexing
-         col_data = data[mask, cdx].squeeze()
-         if is_categorical[cdx]:
-             # if discrete, use unique values as bins
-             bins, cnts = np.unique(col_data, return_counts=True)
+         if hist_cache is not None and fn in hist_cache:
+             cnts = hist_cache[fn]
          else:
-             bins = hist_bins.get(fn, "auto")
-             cnts, bins = np.histogram(col_data, bins=bins, density=True)
+             hist_edges = np.array([-np.inf, np.inf])
+             cnts = np.array([len(data[:, cdx].squeeze())])
+             # linter doesn't like double indexing
+             col_data = np.array(data[mask, cdx].squeeze(), dtype=np.float64)
+
+             if continuous_factor_bincounts and fn in continuous_factor_bincounts:
+                 num_bins = continuous_factor_bincounts[fn]
+                 _, hist_edges = np.histogram(data[:, cdx].squeeze(), bins=num_bins, density=True)
+                 hist_edges[-1] = np.inf
+                 hist_edges[0] = -np.inf
+                 disc_col_data = np.digitize(col_data, np.array(hist_edges))
+                 _, cnts = np.unique(disc_col_data, return_counts=True)
+             else:
+                 _, cnts = np.unique(col_data, return_counts=True)
+
+             if hist_cache is not None:
+                 hist_cache[fn] = cnts
 
          hist_counts[fn] = cnts
-         hist_bins[fn] = bins
 
-     return hist_counts, hist_bins
+     return hist_counts
 
 
  def entropy(
      data: NDArray[Any],
      names: list[str],
-     is_categorical: list[bool],
+     continuous_factor_bincounts: Mapping[str, int] | None = None,
      normalized: bool = False,
      subset_mask: NDArray[np.bool_] | None = None,
+     hist_cache: dict[str, NDArray[np.intp]] | None = None,
  ) -> NDArray[np.float64]:
      """
      Meant for use with :term:`bias<Bias>` metrics, :term:`balance<Balance>`, :term:`diversity<Diversity>`,
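
The rewritten get_counts above replaces the old auto-binned histogram with explicit digitization for declared continuous factors. A self-contained sketch of that discretization step, independent of DataEval:

import numpy as np

col = np.random.default_rng(0).normal(size=200)  # one continuous metadata column
subset_mask = np.ones(col.shape[0], dtype=np.bool_)
num_bins = 5

# Edges come from the full column, then the outer edges are opened to +/-inf so
# every (possibly masked) value falls into some bin before counting.
_, hist_edges = np.histogram(col, bins=num_bins, density=True)
hist_edges[0], hist_edges[-1] = -np.inf, np.inf
disc = np.digitize(col[subset_mask], hist_edges)
_, cnts = np.unique(disc, return_counts=True)
print(cnts)  # per-bin counts, the values get_counts caches per factor
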
@@ -74,19 +102,30 @@ def entropy(
 
      Parameters
      ----------
-     normalized: bool
+     data : NDArray
+         Array containing numerical values for metadata factors
+     names : list[str]
+         Names of metadata factors -- keys of the metadata dictionary
+     continuous_factor_bincounts : Mapping[str, int] or None, default None
+         The factors in names that have continuous values and the array of bin counts to
+         discretize values into. All factors are treated as having discrete values unless they
+         are specified as keys in this dictionary. Each element of this array must occur as a key
+         in names.
+     normalized : bool, default False
          Flag that determines whether or not to normalize entropy by log(num_bins)
-     subset_mask: NDArray[np.bool_] | None
+     subset_mask : NDArray[np.bool_] or None, default None
          Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
+     hist_cache : dict[str, NDArray[np.intp]] or None, default None
+         Optional cache to store histogram counts
 
-     Note
-     ----
+     Notes
+     -----
      For continuous variables, histogram bins are chosen automatically. See
      numpy.histogram for details.
 
      Returns
      -------
-     ent: NDArray[np.float64]
+     NDArray[np.float64]
          Entropy estimate per column of X
 
      See Also
@@ -96,47 +135,64 @@ def entropy(
      """
 
      num_factors = len(names)
-     hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)
+     hist_counts = get_counts(data, names, continuous_factor_bincounts, subset_mask, hist_cache)
 
      ev_index = np.empty(num_factors)
      for col, cnts in enumerate(hist_counts.values()):
          # entropy in nats, normalizes counts
          ev_index[col] = sp_entropy(cnts)
          if normalized:
-             if len(cnts) == 1:
+             cnt_len = np.size(cnts, 0)
+             if cnt_len == 1:
                  # log(0)
                  ev_index[col] = 0
              else:
-                 ev_index[col] /= np.log(len(cnts))
+                 ev_index[col] /= np.log(cnt_len)
      return ev_index
 
 
  def get_num_bins(
-     data: NDArray[Any], names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
+     data: NDArray[Any],
+     names: list[str],
+     continuous_factor_bincounts: Mapping[str, int] | None = None,
+     subset_mask: NDArray[np.bool_] | None = None,
+     hist_cache: dict[str, NDArray[np.intp]] | None = None,
  ) -> NDArray[np.float64]:
      """
      Number of bins or unique values for each metadata factor, used to
-     normalize entropy/:term:`diversity<Diversity>`.
+     normalize entropy/diversity.
 
      Parameters
      ----------
-     subset_mask: NDArray[np.bool_] | None
+     data : NDArray
+         Array containing numerical values for metadata factors
+     names : list[str]
+         Names of metadata factors -- keys of the metadata dictionary
+     continuous_factor_bincounts : Mapping[str, int] or None, default None
+         The factors in names that have continuous values and the array of bin counts to
+         discretize values into. All factors are treated as having discrete values unless they
+         are specified as keys in this dictionary. Each element of this array must occur as a key
+         in names.
+     subset_mask : NDArray[np.bool_] or None, default None
          Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
+     hist_cache : dict[str, NDArray[np.intp]] or None, default None
+         Optional cache to store histogram counts
 
      Returns
      -------
      NDArray[np.float64]
+         Number of bins used in the discretization for each value in names.
      """
      # likely cached
-     hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)
+     hist_counts = get_counts(data, names, continuous_factor_bincounts, subset_mask, hist_cache)
      num_bins = np.empty(len(hist_counts))
      for idx, cnts in enumerate(hist_counts.values()):
-         num_bins[idx] = len(cnts)
+         num_bins[idx] = np.size(cnts, 0)
 
      return num_bins
 
 
- def infer_categorical(arr: NDArray[Any], threshold: float = 0.2) -> NDArray[Any]:
+ def infer_categorical(arr: NDArray[np.float64], threshold: float = 0.2) -> NDArray[np.bool_]:
      """
      Compute fraction of feature values that are unique --- intended to be used
      for inferring whether variables are categorical.
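
For reference, the normalization path in entropy() above reduces to a few lines of NumPy/SciPy; this sketch mirrors that arithmetic on a single factor's bin counts:

import numpy as np
from scipy.stats import entropy as sp_entropy

cnts = np.array([40, 30, 20, 10])  # bin counts for one factor
ent = sp_entropy(cnts)             # entropy in nats; counts are normalized internally
cnt_len = np.size(cnts, 0)
ent_norm = ent / np.log(cnt_len) if cnt_len > 1 else 0.0
print(round(ent_norm, 4))          # 0.9232 for these counts
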
@@ -154,12 +210,16 @@ def infer_categorical(arr: NDArray[Any], threshold: float = 0.2) -> NDArray[Any]
  def preprocess_metadata(
      class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], cat_thresh: float = 0.2
  ) -> tuple[NDArray[Any], list[str], list[bool], NDArray[np.str_]]:
+     """
+     Formats metadata by organizing factor names, converting labels to numeric values,
+     adds class labels to the dataset structure, and marks which factors are categorical.
+     """
      # if class_labels is not numeric
      class_array = to_numpy(class_labels)
-     if not np.issubdtype(class_array.dtype, np.number):
+     if not np.issubdtype(class_array.dtype, np.integer):
          unique_classes, numerical_labels = np.unique(class_array, return_inverse=True)
      else:
-         numerical_labels = np.asarray(class_array, dtype=int)
+         numerical_labels = np.asarray(class_array, dtype=np.intp)
          unique_classes = np.unique(class_array)
 
      # convert class_labels and dict of lists to matrix of metadata values
@@ -170,7 +230,7 @@ def preprocess_metadata(
      # unique string receives a unique integer value.
      for k, v in metadata.items():
          if k == CLASS_LABEL:
-             k = "label_class"
+             continue
          # if not numeric
          v = to_numpy(v)
          if not np.issubdtype(v.dtype, np.number):
@@ -181,15 +241,18 @@ def preprocess_metadata(
 
      data = np.stack(list(preprocessed_metadata.values()), axis=-1)
      names = list(preprocessed_metadata.keys())
-     is_categorical = [infer_categorical(preprocessed_metadata[var], cat_thresh)[0] for var in names]
+     is_categorical = [
+         var == CLASS_LABEL or infer_categorical(preprocessed_metadata[var].astype(np.float64), cat_thresh)[0]
+         for var in names
+     ]
 
      return data, names, is_categorical, unique_classes
 
 
  def heatmap(
-     data: NDArray[Any],
-     row_labels: list[str] | NDArray[Any],
-     col_labels: list[str] | NDArray[Any],
+     data: ArrayLike,
+     row_labels: list[str] | ArrayLike,
+     col_labels: list[str] | ArrayLike,
      xlabel: str = "",
      ylabel: str = "",
      cbarlabel: str = "",
@@ -211,14 +274,23 @@ def heatmap(
          Y-axis label
      cbarlabel : str, default ""
          Label for the colorbar
+
+     Returns
+     -------
+     matplotlib.figure.Figure
+         Formatted heatmap
      """
-     import matplotlib
      import matplotlib.pyplot as plt
+     from matplotlib.ticker import FuncFormatter
+
+     np_data = to_numpy(data)
+     rows = row_labels if isinstance(row_labels, list) else to_numpy(row_labels)
+     cols = col_labels if isinstance(col_labels, list) else to_numpy(col_labels)
 
      fig, ax = plt.subplots(figsize=(10, 10))
 
      # Plot the heatmap
-     im = ax.imshow(data, vmin=0, vmax=1.0)
+     im = ax.imshow(np_data, vmin=0, vmax=1.0)
 
      # Create colorbar
      cbar = fig.colorbar(im, shrink=0.5)
@@ -227,8 +299,8 @@ def heatmap(
      cbar.set_label(cbarlabel, loc="center")
 
      # Show all ticks and label them with the respective list entries.
-     ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
-     ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)
+     ax.set_xticks(np.arange(np_data.shape[1]), labels=cols)
+     ax.set_yticks(np.arange(np_data.shape[0]), labels=rows)
 
      ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)
      # Rotate the tick labels and set their alignment.
@@ -237,8 +309,8 @@ def heatmap(
      # Turn spines off and create white grid.
      ax.spines[:].set_visible(False)
 
-     ax.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True)
-     ax.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True)
+     ax.set_xticks(np.arange(np_data.shape[1] + 1) - 0.5, minor=True)
+     ax.set_yticks(np.arange(np_data.shape[0] + 1) - 0.5, minor=True)
      ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
      ax.tick_params(which="minor", bottom=False, left=False)
 
@@ -247,7 +319,7 @@ def heatmap(
      if ylabel:
          ax.set_ylabel(ylabel)
 
-     valfmt = matplotlib.ticker.FuncFormatter(format_text)  # type: ignore
+     valfmt = FuncFormatter(format_text)
 
      # Normalize the threshold to the images color range.
      threshold = im.norm(1.0) / 2.0
@@ -260,10 +332,10 @@ def heatmap(
      # Change the text's color depending on the data.
      textcolors = ("white", "black")
      texts = []
-     for i in range(data.shape[0]):
-         for j in range(data.shape[1]):
-             kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
-             text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)  # type: ignore
+     for i in range(np_data.shape[0]):
+         for j in range(np_data.shape[1]):
+             kw.update(color=textcolors[int(im.norm(np_data[i, j]) > threshold)])
+             text = im.axes.text(j, i, valfmt(np_data[i, j], None), **kw)  # type: ignore
              texts.append(text)
 
      fig.tight_layout()
@@ -277,7 +349,7 @@ def format_text(*args: str) -> str:
 
      Parameters
      ----------
-     *args: Tuple (str, str)
+     *args : tuple[str, str]
          Text to be formatted. Second element is ignored, but is a
          mandatory pass-through argument as per matplotlib.ticket.FuncFormatter
 
@@ -300,6 +372,11 @@ def diversity_bar_plot(labels: NDArray[Any], bar_heights: NDArray[Any]) -> Figur
          Array containing the labels for each bar
      bar_heights : NDArray
          Array containing the values for each bar
+
+     Returns
+     -------
+     matplotlib.figure.Figure
+         Bar plot figure
      """
      import matplotlib.pyplot as plt
 
@@ -322,6 +399,11 @@ def coverage_plot(images: NDArray[Any], num_images: int) -> Figure:
      ----------
      images : NDArray
          Array containing only the desired images to plot
+
+     Returns
+     -------
+     matplotlib.figure.Figure
+         Plot of all provided images
      """
      import matplotlib.pyplot as plt
 
@@ -336,7 +418,7 @@ def coverage_plot(images: NDArray[Any], num_images: int) -> Figure:
              f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {images.ndim}-dimensional set of images."
          )
 
-     rows = np.ceil(num_images / 3).astype(int)
+     rows = int(np.ceil(num_images / 3))
      fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
 
      if rows == 1:
@@ -28,7 +28,7 @@ class ParityOutput(Generic[TData], OutputMetadata):
          chi-squared score(s) of the test
      p_value : np.float64 | NDArray[np.float64]
          p-value(s) of the test
-     metadata_names: list[str] | None
+     metadata_names : list[str] | None
          Names of each metadata factor
      """
 
@@ -43,16 +43,16 @@ def digitize_factor_bins(continuous_values: NDArray[Any], bins: int, factor_name
 
      Parameters
      ----------
-     continuous_values: NDArray
+     continuous_values : NDArray
          The values to be digitized.
-     bins: int
+     bins : int
          The number of bins for the discrete values that continuous_values will be digitized into.
-     factor_name: str
+     factor_name : str
          The name of the factor to be digitized.
 
      Returns
      -------
-     NDArray
+     NDArray[np.intp]
          The digitized values
      """
 
@@ -70,17 +70,21 @@ def digitize_factor_bins(continuous_values: NDArray[Any], bins: int, factor_name
 
 
  def format_discretize_factors(
-     data: NDArray[Any], names: list[str], is_categorical: list[bool], continuous_factor_bincounts: Mapping[str, int]
+     data: NDArray[Any],
+     names: list[str],
+     is_categorical: list[bool],
+     continuous_factor_bincounts: Mapping[str, int] | None,
  ) -> dict[str, NDArray[Any]]:
      """
      Sets up the internal list of metadata factors.
 
      Parameters
      ----------
-     data_factors: Dict[str, NDArray]
+     data : NDArray
          The dataset factors, which are per-image attributes including class label and metadata.
-         Each key of dataset_factors is a factor, whose value is the per-image factor values.
-     continuous_factor_bincounts : Dict[str, int]
+     names : list[str]
+         The class label
+     continuous_factor_bincounts : Mapping[str, int] or None
          The factors in data_factors that have continuous values and the array of bin counts to
          discretize values into. All factors are treated as having discrete values unless they
          are specified as keys in this dictionary. Each element of this array must occur as a key
@@ -93,19 +97,20 @@ def format_discretize_factors(
          Each key is a metadata factor, whose value is the discrete per-image factor values.
      """
 
-     invalid_keys = set(continuous_factor_bincounts.keys()) - set(names)
-     if invalid_keys:
-         raise KeyError(
-             f"The continuous factor(s) {invalid_keys} do not exist in data_factors. Delete these "
-             "keys from `continuous_factor_names` or add corresponding entries to `data_factors`."
-         )
+     if continuous_factor_bincounts:
+         invalid_keys = set(continuous_factor_bincounts.keys()) - set(names)
+         if invalid_keys:
+             raise KeyError(
+                 f"The continuous factor(s) {invalid_keys} do not exist in data_factors. Delete these "
+                 "keys from `continuous_factor_names` or add corresponding entries to `data_factors`."
+             )
 
      warn = []
      metadata_factors = {}
      for i, name in enumerate(names):
          if name == CLASS_LABEL:
              continue
-         if name in continuous_factor_bincounts:
+         if continuous_factor_bincounts and name in continuous_factor_bincounts:
              metadata_factors[name] = digitize_factor_bins(data[:, i], continuous_factor_bincounts[name], name)
          elif not is_categorical[i]:
              warn.append(name)
@@ -132,14 +137,14 @@ def normalize_expected_dist(expected_dist: NDArray[Any], observed_dist: NDArray[
 
      Parameters
      ----------
-     expected_dist : np.ndarray
+     expected_dist : NDArray
          The expected label distribution. This array represents the anticipated distribution of labels.
-     observed_dist : np.ndarray
+     observed_dist : NDArray
          The observed label distribution. This array represents the actual distribution of labels in the dataset.
 
      Returns
      -------
-     np.ndarray
+     NDArray
          The normalized expected distribution, scaled to have the same sum as the observed distribution.
 
      Raises
@@ -179,6 +184,8 @@ def validate_dist(label_dist: NDArray[Any], label_name: str) -> None:
      ----------
      label_dist : NDArray
          Array representing label distributions
+     label_name : str
+         String representing label name
 
      Raises
      ------
@@ -219,7 +226,7 @@ def label_parity(
          List of class labels in the expected dataset
      observed_labels : ArrayLike
          List of class labels in the observed dataset
-     num_classes : int | None, default None
+     num_classes : int or None, default None
          The number of unique classes in the datasets. If not provided, the function will infer it
          from the set of unique labels in expected_labels and observed_labels
 
@@ -303,12 +310,12 @@ def parity(
 
      Parameters
      ----------
-     class_labels: ArrayLike
+     class_labels : ArrayLike
          List of class labels for each image
-     metadata: Mapping[str, ArrayLike]
+     metadata : Mapping[str, ArrayLike]
          The dataset factors, which are per-image metadata attributes.
          Each key of dataset_factors is a factor, whose value is the per-image factor values.
-     continuous_factor_bincounts : Mapping[str, int] | None, default None
+     continuous_factor_bincounts : Mapping[str, int] or None, default None
          A dictionary specifying the number of bins for discretizing the continuous factors.
          The keys should correspond to the names of continuous factors in `metadata`,
          and the values should be the number of bins to use for discretization.
@@ -359,7 +366,6 @@ def parity(
      )
 
      data, names, is_categorical, _ = preprocess_metadata(class_labels, metadata)
-     continuous_factor_bincounts = continuous_factor_bincounts if continuous_factor_bincounts else {}
 
 
      factors = format_discretize_factors(data, names, is_categorical, continuous_factor_bincounts)
@@ -10,12 +10,12 @@ from dataeval.utils.split_dataset import split_dataset
 
  __all__ = ["split_dataset", "merge_metadata"]
 
- if _IS_TORCH_AVAILABLE: # pragma: no cover
+ if _IS_TORCH_AVAILABLE:
      from dataeval.utils import torch
 
      __all__ += ["torch"]
 
- if _IS_TENSORFLOW_AVAILABLE: # pragma: no cover
+ if _IS_TENSORFLOW_AVAILABLE:
      from dataeval.utils import tensorflow
 
      __all__ += ["tensorflow"]
dataeval/utils/shared.py CHANGED
@@ -95,7 +95,7 @@ def get_classes_counts(labels: NDArray[np.int_]) -> tuple[int, int]:
      M = len(classes)
      if M < 2:
          raise ValueError("Label vector contains less than 2 classes!")
-     N = np.sum(counts).astype(int)
+     N = int(np.sum(counts))
      return M, N
 
 
@@ -144,7 +144,7 @@ def check_groups(group_ids: NDArray[np.int_], num_partitions: int) -> bool:
      ----------
      group_ids : np.ndarray
          Identifies the group to which a sample at the same index belongs.
-     num_partitions: int
+     num_partitions : int
          How many total (train, val) folds will be generated (+1 if also specifying a test fold).
 
      Warns
@@ -242,12 +242,12 @@ def get_group_ids(metadata: dict[str, Any], group_names: list[str], num_samples:
 
      Returns
      -------
-     group_ids: np.ndarray
+     group_ids : np.ndarray
          group identifiers from metadata
      """
      features2group = {k: np.array(v) for k, v in metadata.items() if k in group_names}
      if not features2group:
-         return np.zeros(num_samples, dtype=int)
+         return np.zeros(num_samples, dtype=np.int_)
      for name, feature in features2group.items():
          if len(feature) != num_samples:
              raise IndexError(f"""Feature length does not match number of labels.
@@ -300,7 +300,13 @@ def make_splits(
          splits = splitter.split(index, labels)
          for train_idx, eval_idx in splits:
              test_ratio = len(eval_idx) / index.shape[0]
-             split_defs.append({"train": train_idx.astype(int), "eval": eval_idx.astype(int), "eval_frac": test_ratio})
+             split_defs.append(
+                 {
+                     "train": train_idx.astype(np.int_),
+                     "eval": eval_idx.astype(np.int_),
+                     "eval_frac": test_ratio,
+                 }
+             )
      return split_defs
 
 
@@ -318,9 +324,9 @@ def find_best_split(
      split_defs : list[dict]
          List of dictionaries, which specifying train index, validation index, and the ratio of
          validation to all data.
-     stratified: bool
+     stratified : bool
          If True, maintain dataset class balance within each train/val split
-     eval_frac: float
+     eval_frac : float
          Desired fraction of the dataset sequestered for evaluation
 
      Returns
@@ -206,7 +206,7 @@ class MNIST(Dataset[tuple[NDArray[np.float64], int]]):
          Option to select specific classes from dataset.
      balance : bool, default True
          If True, returns equal number of samples for each class.
-     randomize : bool, default False
+     randomize : bool, default True
          If True, shuffles the data prior to selection - uses a set seed for reproducibility.
      slice_back : bool, default False
          If True and size has a value greater than 0, then grabs selection starting at the last image.
@@ -251,7 +251,7 @@ class MNIST(Dataset[tuple[NDArray[np.float64], int]]):
          corruption: CorruptionStringMap | None = None,
          classes: TClassMap | None = None,
          balance: bool = True,
-         randomize: bool = False,
+         randomize: bool = True,
          slice_back: bool = False,
          verbose: bool = True,
      ) -> None:
@@ -4,7 +4,7 @@ Workflows perform a sequence of actions to analyze the dataset and make predicti
 
  from dataeval import _IS_TORCH_AVAILABLE
 
- if _IS_TORCH_AVAILABLE: # pragma: no cover
+ if _IS_TORCH_AVAILABLE:
      from dataeval.workflows.sufficiency import Sufficiency, SufficiencyOutput
 
      __all__ = ["Sufficiency", "SufficiencyOutput"]
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: dataeval
- Version: 0.73.0
+ Version: 0.73.1
  Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
  Home-page: https://dataeval.ai/
  License: MIT
@@ -1,7 +1,7 @@
1
- dataeval/__init__.py,sha256=cAgMAbawI3EC6HdLfV_g_mMpH5Y-zy-n2qzrRKBH_6s,641
2
- dataeval/detectors/__init__.py,sha256=xdp8LYOFjV5tVbAwu0Y03KU9EajHkSFy_M3raqbxpDc,383
3
- dataeval/detectors/drift/__init__.py,sha256=MRPWFOaoVoqAHW36nA5F3wk7QXJU4oecND2RbtgG9oY,757
4
- dataeval/detectors/drift/base.py,sha256=0S-0MFpIFaJ4_8IGreFKSmyna2L50FBn7DVaoNWmw8E,14509
1
+ dataeval/__init__.py,sha256=SdXxst_wmjSoQkYzGdR-JXSV-iJmKynWsiwkpmGDDPE,601
2
+ dataeval/detectors/__init__.py,sha256=mwAyY54Hvp6N4D57cde3_besOinK8jVF43k0Mw4XZi8,363
3
+ dataeval/detectors/drift/__init__.py,sha256=BSXm21y7cAawHep-ZldCJ5HOvzYjPzYGKGrmoEs3i0E,737
4
+ dataeval/detectors/drift/base.py,sha256=xwI6C-PEH0ZjpSqP6No6WDZp42DnE16OHi_mXe2JSvI,14499
5
5
  dataeval/detectors/drift/cvm.py,sha256=kc59w2_wtxFGNnLcaJRvX5v_38gPXiebSGNiFVdunEQ,4142
6
6
  dataeval/detectors/drift/ks.py,sha256=gcpe1WIQeNeZdLYkdMZCFLXUp1bHMQUxwJE6-RLVOXs,4229
7
7
  dataeval/detectors/drift/mmd.py,sha256=TqGOnUNYKwpS0GQPV3dSl-_qRa0g2flmoQ-dxzW_JfY,7586
@@ -9,11 +9,11 @@ dataeval/detectors/drift/torch.py,sha256=D46J72OPW8-PpP3w9ODMBfcDSdailIgVjgHVFpb
9
9
  dataeval/detectors/drift/uncertainty.py,sha256=Xz2yzJjtJfw1vLag234jwRvaa_HK36nMajGx8bQaNRs,5322
10
10
  dataeval/detectors/drift/updates.py,sha256=UJ0z5hlunRi7twnkLABfdJG3tT2EqX4y9IGx8_USYvo,1780
11
11
  dataeval/detectors/linters/__init__.py,sha256=BvpaB1RUpkEhhXk3Mqi5NYoOcJKZRFSBOJCmQOIfYRU,483
12
- dataeval/detectors/linters/clusterer.py,sha256=OtBE5rglAGdTTQRmKUHP6J-uWmnh2E3lZxeqJCnc87U,21014
12
+ dataeval/detectors/linters/clusterer.py,sha256=sau5A9YcQ6VDjbZGOIaCaRHW_63opaA31pqHo5Rm-hQ,21018
13
13
  dataeval/detectors/linters/duplicates.py,sha256=tOD43rJkvheIA3mznbUqHhft2yD3xRZQdCt61daIca4,5665
14
14
  dataeval/detectors/linters/merged_stats.py,sha256=X-bDTwjyR8RuVmzxLaHZmQ5nI3oOWvsqVlitdSncapk,1355
15
15
  dataeval/detectors/linters/outliers.py,sha256=BUVvtbKHo04KnRmrgb84MBr0l1gtcY3-xNCHjetFrEQ,10117
16
- dataeval/detectors/ood/__init__.py,sha256=FVyVuaxVKAOgSTaaBf-j2OXXDarSBFcJ7CTlMV6w88s,661
16
+ dataeval/detectors/ood/__init__.py,sha256=yzvCszJ0KrX9Eu4S_ykC_jwC0uYGPjxY3Vyx9fU3zQk,641
17
17
  dataeval/detectors/ood/ae.py,sha256=XQ_rCsf0VWg_2YXt33XGe6ZgxEud1PfIl7TmBVP1GkM,2347
18
18
  dataeval/detectors/ood/aegmm.py,sha256=6UKv0uJYWAzu1F-cITFGly4w9y_t7wqg3OmVyCN365o,2041
19
19
  dataeval/detectors/ood/base.py,sha256=a_d52pJMWVmduSt8OvUWYwHE8mpCaI6pIAE4_ib_GOs,8841
@@ -26,11 +26,11 @@ dataeval/detectors/ood/vaegmm.py,sha256=_wwmT37URs0MyhbORk91XJExClv-4e15LH_Bj60P
26
26
  dataeval/interop.py,sha256=TZCkZo844DvzHoxuRo-YsBhT6GvKmyQTHtUEQZPly1M,1728
27
27
  dataeval/metrics/__init__.py,sha256=fPBNLd-T6mCErZBBJrxWmXIL0jCk7fNUYIcNEBkMa80,238
28
28
  dataeval/metrics/bias/__init__.py,sha256=puf645-hAO5hFHNHlZ239TPopqWIoN-uLGXFB8-hA_o,599
29
- dataeval/metrics/bias/balance.py,sha256=Uz7RHf3UuiAxfYlZpKMg4jMzXwXcEfYj7BUnUjzgkw0,8579
30
- dataeval/metrics/bias/coverage.py,sha256=eB8PacN_uJ19pMd5SVI3N98NC2KJMgE3tgI-DJFNHYs,4497
31
- dataeval/metrics/bias/diversity.py,sha256=v9fiuySovMajW9Re0EH_FdbuJryAAdVKkvOuNngO5nc,9618
32
- dataeval/metrics/bias/metadata.py,sha256=OZB9BzPW6JMq2kTp_a9ucqRNcPpfqOexINax1jH5vVQ,11318
33
- dataeval/metrics/bias/parity.py,sha256=vfGnt_GoGMjMfWgY1FjqNV-gjqVq13tsTTmVkNtRfDM,17120
29
+ dataeval/metrics/bias/balance.py,sha256=n4SM2Z46dzps_SPgHV8Q69msZ507AP9neebsQ45cNxc,9170
30
+ dataeval/metrics/bias/coverage.py,sha256=7nDufCmQwZ8QG3Me5UiY0N5YoTByjcwK2zOYuMOHkJ0,4540
31
+ dataeval/metrics/bias/diversity.py,sha256=BKGpyJ1K3S5RS_VxXN5DusB2gfRidOksL7r0L3SFa0Y,11018
32
+ dataeval/metrics/bias/metadata.py,sha256=tPvyfFkfqWBFMX6v8i1ZLAA3DZfF6M4O7qXDdKzhQ6g,15040
33
+ dataeval/metrics/bias/parity.py,sha256=_-WdKRWPlKHLNbjq-4mIhVdR1MI3NEabbMWblAmmVRM,17145
34
34
  dataeval/metrics/estimators/__init__.py,sha256=O6ocxJq8XDkfJWwXeJnnnzbOyRnFPKF4kTIVTTZYOA8,380
35
35
  dataeval/metrics/estimators/ber.py,sha256=SVT-BIC_GLs0l2l2NhWu4OpRbgn96w-OwTSoPHTnQbE,5037
36
36
  dataeval/metrics/estimators/divergence.py,sha256=pImaa216-YYTgGWDCSTcpJrC-dfl7150yVrPfW_TyGc,4293
@@ -46,12 +46,12 @@ dataeval/metrics/stats/pixelstats.py,sha256=x90O10IqVjEORtYwueFLvJnVYTxhPBOOx5HM
46
46
  dataeval/metrics/stats/visualstats.py,sha256=y0xIvst7epcajk8vz2jngiAiz0T7DZC-M97Rs1-vV9I,4950
47
47
  dataeval/output.py,sha256=jWXXNxFNBEaY1rN7Z-6LZl6bQT-I7z_wqr91Rhrdt_0,3061
48
48
  dataeval/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
49
- dataeval/utils/__init__.py,sha256=Qr-D0yHnDE8qit0-Wf6xmdMX9Wle2p_mXKgTueTy5GA,753
49
+ dataeval/utils/__init__.py,sha256=FZLWDA7nMbHOcdg3701cVJpQmUp1Wxxk8h_qIrUQQjY,713
50
50
  dataeval/utils/image.py,sha256=KgC_1nW__nGN5q6bVZNvG4U_qIBdjcPATz9qe8f2XuA,1928
51
51
  dataeval/utils/lazy.py,sha256=M0iBHuJh4UPrSJPHZ0jhFwRSZhyjHJQx_KEf1OCkHD8,588
52
52
  dataeval/utils/metadata.py,sha256=A6VN7KbdiOA6rUQvUGKwDcvtOyjBer8bRW_wFxNhmW0,8556
53
- dataeval/utils/shared.py,sha256=BvEeYPMNQTmx4LSaImGeC0VkvcbEY3Byqtxa-jQ3xgc,3623
54
- dataeval/utils/split_dataset.py,sha256=IopyxwC3FaZwgVriW4OXze-mDMpOlvRr83OADA5Jydk,19454
53
+ dataeval/utils/shared.py,sha256=xvF3VLfyheVwJtdtDrneOobkKf7t-JTmf_w91FWXmqo,3616
54
+ dataeval/utils/split_dataset.py,sha256=Ot1ZJhbIhVfcShYXF9MkWXak5odBXyuBdRh-noXh-MI,19555
55
55
  dataeval/utils/tensorflow/__init__.py,sha256=l4OjIA75JJXeNWDCkST1xtDMVYsw97lZ-9JXFBlyuYg,539
56
56
  dataeval/utils/tensorflow/_internal/gmm.py,sha256=RIFx8asEpi2kMf8JVzq9M3aAvNe9fjpJPf3BzWE-aeE,3787
57
57
  dataeval/utils/tensorflow/_internal/loss.py,sha256=TFhoNPgqeJtdpIHYobZPyzMpeWjzlFqzu5LCtthEUi4,4463
@@ -61,13 +61,13 @@ dataeval/utils/tensorflow/_internal/utils.py,sha256=lr5hKkAPbjMCUNIzMUIqbEddwbWQ
61
61
  dataeval/utils/tensorflow/loss/__init__.py,sha256=Q-66vt91Oe1ByYfo28tW32zXDq2MqQ2gngWgmIVmof8,227
62
62
  dataeval/utils/torch/__init__.py,sha256=lpkqfgyARUxgrV94cZESQv8PIP2p-UnwItZ_wIr0XzQ,675
63
63
  dataeval/utils/torch/blocks.py,sha256=HVhBTMMD5NA4qheMUgyol1KWiKZDIuc8k5j4RcMKmhk,1466
64
- dataeval/utils/torch/datasets.py,sha256=9YV9-Uhq6NCMuu1hPhMnQXjmeI-Ld8ve1z_haxre88o,15023
64
+ dataeval/utils/torch/datasets.py,sha256=10elNgLuH_FDX_CHE3y2Z215JN4-PQovQm5brcIJOeM,15021
65
65
  dataeval/utils/torch/models.py,sha256=0BsXmLK8W1OZ8nnEGb1f9LzIeCgtevQC37dvKS1v1vA,3236
66
66
  dataeval/utils/torch/trainer.py,sha256=EraOKiXxiMNiycStZNMR5yRz3ehgp87d9ewR9a9dV4w,5559
67
67
  dataeval/utils/torch/utils.py,sha256=FI4LJ6DvXFQJVff8fxSCP7LRkp8H9BIUgYX0kk7_Cuo,1537
68
- dataeval/workflows/__init__.py,sha256=x2JnOoKmLUCZOsB6RNPqMdVvxEb6Hpda5GPJnD_k0v0,310
68
+ dataeval/workflows/__init__.py,sha256=ef1MiVL5IuhlDXXbwsiAfafhnr7tD3TXF9GRusy9_O8,290
69
69
  dataeval/workflows/sufficiency.py,sha256=1jSYhH9i4oesmJYs5PZvWS1LGXf8ekOgNhpFtMPLPXk,18552
70
- dataeval-0.73.0.dist-info/LICENSE.txt,sha256=Kpzcfobf1HlqafF-EX6dQLw9TlJiaJzfgvLQFukyXYw,1060
71
- dataeval-0.73.0.dist-info/METADATA,sha256=YVw0z5C5BZs-9gCxCmbo4aNIN7Ph3rZsel7FofFrMKY,4714
72
- dataeval-0.73.0.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
73
- dataeval-0.73.0.dist-info/RECORD,,
70
+ dataeval-0.73.1.dist-info/LICENSE.txt,sha256=Kpzcfobf1HlqafF-EX6dQLw9TlJiaJzfgvLQFukyXYw,1060
71
+ dataeval-0.73.1.dist-info/METADATA,sha256=C7xThIWgHNoZEdSiGEZr3VgDLRSzeT3TkFbn4nQgrK0,4714
72
+ dataeval-0.73.1.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
73
+ dataeval-0.73.1.dist-info/RECORD,,