dataeval 0.82.1__py3-none-any.whl → 0.84.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. dataeval/__init__.py +7 -2
  2. dataeval/config.py +13 -3
  3. dataeval/metadata/__init__.py +2 -2
  4. dataeval/metadata/_ood.py +144 -27
  5. dataeval/metrics/bias/__init__.py +11 -1
  6. dataeval/metrics/bias/_balance.py +3 -3
  7. dataeval/metrics/bias/_completeness.py +130 -0
  8. dataeval/metrics/estimators/_ber.py +2 -1
  9. dataeval/metrics/stats/_base.py +31 -36
  10. dataeval/metrics/stats/_dimensionstats.py +2 -2
  11. dataeval/metrics/stats/_hashstats.py +2 -2
  12. dataeval/metrics/stats/_imagestats.py +4 -4
  13. dataeval/metrics/stats/_labelstats.py +4 -45
  14. dataeval/metrics/stats/_pixelstats.py +2 -2
  15. dataeval/metrics/stats/_visualstats.py +2 -2
  16. dataeval/outputs/__init__.py +4 -2
  17. dataeval/outputs/_bias.py +31 -22
  18. dataeval/outputs/_metadata.py +7 -0
  19. dataeval/outputs/_stats.py +2 -3
  20. dataeval/typing.py +43 -12
  21. dataeval/utils/_array.py +26 -1
  22. dataeval/utils/_mst.py +1 -2
  23. dataeval/utils/data/_dataset.py +2 -0
  24. dataeval/utils/data/_embeddings.py +115 -32
  25. dataeval/utils/data/_images.py +38 -15
  26. dataeval/utils/data/_selection.py +7 -8
  27. dataeval/utils/data/_split.py +76 -129
  28. dataeval/utils/data/datasets/_base.py +4 -2
  29. dataeval/utils/data/datasets/_cifar10.py +17 -9
  30. dataeval/utils/data/datasets/_milco.py +18 -12
  31. dataeval/utils/data/datasets/_mnist.py +24 -8
  32. dataeval/utils/data/datasets/_ships.py +18 -8
  33. dataeval/utils/data/datasets/_types.py +1 -5
  34. dataeval/utils/data/datasets/_voc.py +47 -24
  35. dataeval/utils/data/selections/__init__.py +2 -0
  36. dataeval/utils/data/selections/_classfilter.py +1 -1
  37. dataeval/utils/data/selections/_prioritize.py +296 -0
  38. dataeval/utils/data/selections/_shuffle.py +13 -4
  39. dataeval/utils/metadata.py +1 -1
  40. dataeval/utils/torch/_gmm.py +3 -2
  41. {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/METADATA +4 -4
  42. {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/RECORD +44 -43
  43. dataeval/detectors/ood/metadata_ood_mi.py +0 -91
  44. {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/LICENSE.txt +0 -0
  45. {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/WHEEL +0 -0
dataeval/__init__.py CHANGED
@@ -8,7 +8,7 @@ shifts that impact performance of deployed models.
 from __future__ import annotations
 
 __all__ = ["config", "detectors", "log", "metrics", "typing", "utils", "workflows"]
-__version__ = "0.82.1"
+__version__ = "0.84.0"
 
 import logging
 
@@ -34,7 +34,12 @@ def log(level: int = logging.DEBUG, handler: logging.Handler | None = None) -> N
     logger = logging.getLogger(__name__)
     if handler is None:
         handler = logging.StreamHandler() if handler is None else handler
-        handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
+        handler.setFormatter(
+            logging.Formatter(
+                "%(asctime)s %(levelname)-8s %(name)s.%(filename)s:%(lineno)s - %(funcName)10s() | %(message)s"
+            )
+        )
     logger.addHandler(handler)
     logger.setLevel(level)
+    logging.DEBUG
     logger.debug(f"Added logging handler {handler} to logger: {__name__}")
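The new formatter adds the logger name, source file, line number, and function to every record. A minimal usage sketch, assuming only what this diff shows (the log() signature is unchanged; only the output format differs):

    import logging

    import dataeval

    # Attaches a StreamHandler with the new verbose record format;
    # DEBUG remains the default level if none is given.
    dataeval.log(level=logging.INFO)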
dataeval/config.py CHANGED
@@ -17,10 +17,18 @@ else:
 import numpy as np
 import torch
 
+### GLOBALS ###
+
 _device: torch.device | None = None
 _processes: int | None = None
 _seed: int | None = None
 
+### CONSTS ###
+
+EPSILON = 1e-10
+
+### TYPES ###
+
 DeviceLike: TypeAlias = Union[int, str, tuple[str, int], torch.device]
 """
 Type alias for types that are acceptable for specifying a torch.device.
@@ -30,18 +38,20 @@ See Also
 `torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
 """
 
+### FUNCS ###
+
 
 def _todevice(device: DeviceLike) -> torch.device:
     return torch.device(*device) if isinstance(device, tuple) else torch.device(device)
 
 
-def set_device(device: DeviceLike) -> None:
+def set_device(device: DeviceLike | None) -> None:
     """
     Sets the default device to use when executing against a PyTorch backend.
 
     Parameters
     ----------
-    device : DeviceLike
+    device : DeviceLike or None
         The default device to use. See documentation for more information.
 
     See Also
@@ -49,7 +59,7 @@ def set_device(device: DeviceLike) -> None:
     `torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
     """
     global _device
-    _device = _todevice(device)
+    _device = None if device is None else _todevice(device)
 
 
 def get_device(override: DeviceLike | None = None) -> torch.device:
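set_device now accepts None to clear a previously pinned device, and the new shared EPSILON constant replaces scattered stability literals elsewhere in this release. A minimal sketch of the new behavior, using only names this diff shows:

    from dataeval.config import EPSILON, get_device, set_device

    set_device("cpu")    # pin DataEval's torch execution to a device
    print(get_device())  # cpu

    set_device(None)     # new in 0.84.0: clears the override
    print(get_device())  # falls back to the default device resolution

    print(EPSILON)       # 1e-10, the shared numerical-stability constant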
dataeval/metadata/__init__.py CHANGED
@@ -1,6 +1,6 @@
 """Explanatory functions using metadata and additional features such as ood or drift"""
 
-__all__ = ["most_deviated_factors", "metadata_distance"]
+__all__ = ["find_ood_predictors", "metadata_distance", "find_most_deviated_factors"]
 
 from dataeval.metadata._distance import metadata_distance
-from dataeval.metadata._ood import most_deviated_factors
+from dataeval.metadata._ood import find_most_deviated_factors, find_ood_predictors
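For callers upgrading from 0.82.x, the rename is a mechanical substitution; a sketch of the import change:

    # 0.82.x
    # from dataeval.metadata import most_deviated_factors

    # 0.84.0: renamed, plus the new mutual-information predictor
    from dataeval.metadata import find_most_deviated_factors, find_ood_predictors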
dataeval/metadata/_ood.py CHANGED
@@ -6,14 +6,44 @@ import warnings
 
 import numpy as np
 from numpy.typing import NDArray
+from sklearn.feature_selection import mutual_info_classif
 
+from dataeval.config import get_seed
 from dataeval.metadata._utils import _compare_keys, _validate_factors_and_data
-from dataeval.outputs import MostDeviatedFactorsOutput, OODOutput
+from dataeval.outputs import MostDeviatedFactorsOutput, OODOutput, OODPredictorOutput
 from dataeval.outputs._base import set_metadata
 from dataeval.utils.data import Metadata
 
 
-def _combine_metadata(metadata_1: Metadata, metadata_2: Metadata) -> tuple[list[str], list[NDArray], list[NDArray]]:
+def _combine_discrete_continuous(metadata: Metadata) -> tuple[list[str], NDArray[np.float64]]:
+    """Combines the discrete and continuous data of a :class:`Metadata` object
+
+    Returns
+    -------
+    Tuple[list[str], NDArray]
+        The combined list of factor names and the combined discrete and continuous data
+
+    Note
+    ----
+    Discrete and continuous data must have the same number of samples
+    """
+    names = []
+    data = []
+
+    if metadata.discrete_factor_names and metadata.discrete_data.size != 0:
+        names.extend(metadata.discrete_factor_names)
+        data.append(metadata.discrete_data)
+
+    if metadata.continuous_factor_names and metadata.continuous_data.size != 0:
+        names.extend(metadata.continuous_factor_names)
+        data.append(metadata.continuous_data)
+
+    return names, np.hstack(data, dtype=np.float64) if data else np.array([], dtype=np.float64)
+
+
+def _combine_metadata(
+    metadata_1: Metadata, metadata_2: Metadata
+) -> tuple[list[str], list[NDArray[np.float64 | np.int64]], list[NDArray[np.int64 | np.float64]]]:
     """
     Combines the factor names and data arrays of metadata_1 and metadata_2 when the names
     match exactly and data has the same number of columns (factors).
@@ -42,8 +72,8 @@ def _combine_metadata(metadata_1: Metadata, metadata_2: Metadata) -> tuple[list[
         If the length of keys do not match the length of the data
     """
     factor_names: list[str] = []
-    m1_data: list[NDArray] = []
-    m2_data: list[NDArray] = []
+    m1_data: list[NDArray[np.int64 | np.float64]] = []
+    m2_data: list[NDArray[np.int64 | np.float64]] = []
 
     # Both metadata must have the same number of factors (cols), but not necessarily samples (row)
     if metadata_1.total_num_factors != metadata_2.total_num_factors:
@@ -121,36 +151,37 @@ def _calc_median_deviations(reference: NDArray, test: NDArray) -> NDArray:
 
 
 @set_metadata
-def most_deviated_factors(
-    metadata_1: Metadata,
-    metadata_2: Metadata,
+def find_most_deviated_factors(
+    metadata_ref: Metadata,
+    metadata_tst: Metadata,
     ood: OODOutput,
 ) -> MostDeviatedFactorsOutput:
     """
-    Determines greatest deviation in metadata features per out of distribution sample in metadata_2.
+    Determine greatest deviation in metadata features per out of distribution sample in test metadata.
 
     Parameters
     ----------
-    metadata_1 : Metadata
+    metadata_ref : Metadata
         A reference set of Metadata containing factor names and samples
        with discrete and/or continuous values per factor
-    metadata_2 : Metadata
+    metadata_tst : Metadata
         The set of Metadata that is tested against the reference metadata.
        This set must have the same number of features but does not require the same number of samples.
     ood : OODOutput
-        A class output by the DataEval's OOD functions that contains which examples are OOD.
+        A class output by DataEval's OOD functions that contains which examples are OOD.
 
     Returns
     -------
-    list[tuple[str, float]]
-        An array of the factor name and deviation of the highest metadata deviation for each OOD example in metadata_2.
+    MostDeviatedFactorsOutput
+        An output class containing the factor name and deviation of the highest metadata deviations for each
+        OOD example in the test metadata.
 
     Notes
    -----
     1. Both :class:`.Metadata` inputs must have discrete and continuous data in the shape (samples, factors)
        and have equivalent factor names and lengths
     2. The flag at index `i` in :attr:`.OODOutput.is_ood` must correspond
-       directly to sample `i` of `metadata_2` being out-of-distribution from `metadata_1`
+       directly to sample `i` of `metadata_tst` being out-of-distribution from `metadata_ref`
 
     Examples
     --------
@@ -160,13 +191,13 @@ def most_deviated_factors(
     All samples are out-of-distribution
 
     >>> is_ood = OODOutput(np.array([True, True, True]), np.array([]), np.array([]))
-    >>> most_deviated_factors(metadata1, metadata2, is_ood)
+    >>> find_most_deviated_factors(metadata1, metadata2, is_ood)
     MostDeviatedFactorsOutput([('time', 2.0), ('time', 2.592), ('time', 3.51)])
 
-    If there are no out-of-distribution samples, a list is returned
+    No samples are out-of-distribution
 
     >>> is_ood = OODOutput(np.array([False, False, False]), np.array([]), np.array([]))
-    >>> most_deviated_factors(metadata1, metadata2, is_ood)
+    >>> find_most_deviated_factors(metadata1, metadata2, is_ood)
     MostDeviatedFactorsOutput([])
     """
 
@@ -179,31 +210,30 @@ def most_deviated_factors(
     # Combines reference and test factor names and data if exists and match exactly
     # shape -> (samples, factors)
     factor_names, md_1, md_2 = _combine_metadata(
-        metadata_1=metadata_1,
-        metadata_2=metadata_2,
+        metadata_1=metadata_ref,
+        metadata_2=metadata_tst,
     )
 
     # Stack discrete and continuous factors as separate factors. Must have equal sample counts
-    metadata_ref = np.hstack(md_1) if md_1 else np.array([])
-    metadata_tst = np.hstack(md_2) if md_2 else np.array([])
+    ref_data = np.hstack(md_1) if md_1 else np.array([])  # (S, Fd + Fc)
+    tst_data = np.hstack(md_2) if md_2 else np.array([])  # (S, Fd + Fc)
 
-    if len(metadata_ref) < 3:
+    if len(ref_data) < 3:
         warnings.warn(
-            f"At least 3 reference metadata samples are needed, got {len(metadata_ref)}",
+            f"At least 3 reference metadata samples are needed, got {len(ref_data)}",
             UserWarning,
         )
         return MostDeviatedFactorsOutput([])
 
-    if len(metadata_tst) != len(ood_mask):
+    if len(tst_data) != len(ood_mask):
         raise ValueError(
-            f"ood and test metadata must have the same length, "
-            f"got {len(ood_mask)} and {len(metadata_tst)} respectively."
+            f"ood and test metadata must have the same length, got {len(ood_mask)} and {len(tst_data)} respectively."
         )
 
     # Calculates deviations of all samples in m2_data
     # from the median values of the corresponding index in m1_data
     # Guaranteed for inputs to not be empty
-    deviations = _calc_median_deviations(metadata_ref, metadata_tst)
+    deviations = _calc_median_deviations(ref_data, tst_data)
 
     # Get most impactful factor deviation of each sample for ood samples only
     deviation = np.max(deviations, axis=1)[ood_mask].astype(np.float16)
@@ -217,3 +247,90 @@ def most_deviated_factors(
     # List of tuples matching the factor name with its deviation
 
     return MostDeviatedFactorsOutput([(factor, dev) for factor, dev in zip(most_ood_factors, deviation)])
+
+
+_NATS2BITS = 1.442695
+"""
+_NATS2BITS is the reciprocal of natural log of 2. If you have an information/entropy-type quantity measured in nats,
+which is what many library functions return, multiply it by _NATS2BITS to get it in bits.
+"""
+
+
+def find_ood_predictors(
+    metadata: Metadata,
+    ood: OODOutput,
+) -> OODPredictorOutput:
+    """Computes mutual information between a set of metadata features and per sample out-of-distribution flags.
+
+    Given a set of metadata features per sample and a corresponding OODOutput that indicates whether a sample was
+    determined to be out of distribution, this function calculates the mutual information between each factor and being
+    out of distribution. In other words, it finds which metadata factors most likely correlate to an
+    out of distribution sample.
+
+    Note
+    ----
+    A high mutual information between a factor and ood samples is an indication of correlation, but not causation.
+    Additional analysis should be done to determine how to handle factors with a high mutual information.
+
+
+    Parameters
+    ----------
+    metadata : Metadata
+        A set of arrays of values, indexed by metadata feature names, with one value per data example per feature.
+    ood : OODOutput
+        A class output by DataEval's OOD functions that contains which examples are OOD.
+
+    Returns
+    -------
+    OODPredictorOutput
+        A dictionary with keys corresponding to metadata feature names, and values indicating the strength of
+        association between each named feature and the OOD flag, as mutual information measured in bits.
+
+    Examples
+    --------
+    >>> from dataeval.outputs import OODOutput
+
+    All samples are out-of-distribution
+
+    >>> is_ood = OODOutput(np.array([True, True, True]), np.array([]), np.array([]))
+    >>> find_ood_predictors(metadata1, is_ood)
+    OODPredictorOutput({'time': 8.008566032557951e-17, 'altitude': 8.008566032557951e-17})
+
+    No out-of-distribution samples
+
+    >> is_ood = OODOutput(np.array([False, False, False]), np.array([]), np.array([]))
+    >> find_ood_predictors(metadata1, is_ood)
+    OODPredictorOutput({})
+    """
+
+    ood_mask: NDArray[np.bool] = ood.is_ood
+
+    discrete_features_count = len(metadata.discrete_factor_names)
+    factors, data = _combine_discrete_continuous(metadata)  # (F, ), (S, F) => F = Fd + Fc
+
+    # No metadata correlated with out of distribution data, return 0.0 for all factors
+    if not any(ood_mask):
+        return OODPredictorOutput(dict.fromkeys(factors, 0.0))
+
+    if len(data) != len(ood_mask):
+        raise ValueError(
+            f"ood and metadata must have the same length, got {len(ood_mask)} and {len(data)} respectively."
+        )
+
+    # Calculate mean, std of each factor over all samples
+    scaled_data = (data - np.mean(data, axis=0)) / np.std(data, axis=0, ddof=1)  # (S, F)
+
+    discrete_features = np.zeros_like(factors, dtype=np.bool)
+    discrete_features[:discrete_features_count] = True
+
+    mutual_info_values = (
+        mutual_info_classif(
+            X=scaled_data,
+            y=ood_mask,
+            discrete_features=discrete_features,  # type: ignore -> sklearn issue - NDArray[bool] not of accepted type Union[ArrayLike, 'auto']
+            random_state=get_seed(),
+        )
+        * _NATS2BITS
+    )
+
+    return OODPredictorOutput({k: mutual_info_values[i] for i, k in enumerate(factors)})
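Taken together, the renamed deviation function answers "which factor, per OOD sample" and the new predictor answers "which factors, overall". A sketch of using both; metadata_ref and metadata_tst stand in for Metadata objects built elsewhere (their construction is outside this diff):

    import numpy as np

    from dataeval.metadata import find_most_deviated_factors, find_ood_predictors
    from dataeval.outputs import OODOutput

    # metadata_ref / metadata_tst: assumed prebuilt dataeval.utils.data.Metadata objects;
    # is_ood flags each test sample as in or out of distribution.
    is_ood = OODOutput(np.array([False, True, True]), np.array([]), np.array([]))

    # Per OOD sample: the single factor deviating most from the reference medians.
    worst = find_most_deviated_factors(metadata_ref, metadata_tst, is_ood)

    # Per factor: mutual information (in bits) between the factor and the OOD flag.
    predictors = find_ood_predictors(metadata_tst, is_ood)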
dataeval/metrics/bias/__init__.py CHANGED
@@ -6,10 +6,12 @@ representation which may impact model performance.
 __all__ = [
     "BalanceOutput",
     "CoverageOutput",
+    "CompletenessOutput",
     "DiversityOutput",
     "LabelParityOutput",
     "ParityOutput",
     "balance",
+    "completeness",
     "coverage",
     "diversity",
     "label_parity",
@@ -17,7 +19,15 @@ __all__ = [
 ]
 
 from dataeval.metrics.bias._balance import balance
+from dataeval.metrics.bias._completeness import completeness
 from dataeval.metrics.bias._coverage import coverage
 from dataeval.metrics.bias._diversity import diversity
 from dataeval.metrics.bias._parity import label_parity, parity
-from dataeval.outputs._bias import BalanceOutput, CoverageOutput, DiversityOutput, LabelParityOutput, ParityOutput
+from dataeval.outputs._bias import (
+    BalanceOutput,
+    CompletenessOutput,
+    CoverageOutput,
+    DiversityOutput,
+    LabelParityOutput,
+    ParityOutput,
+)
dataeval/metrics/bias/_balance.py CHANGED
@@ -8,7 +8,7 @@ import numpy as np
 import scipy as sp
 from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
 
-from dataeval.config import get_seed
+from dataeval.config import EPSILON, get_seed
 from dataeval.outputs import BalanceOutput
 from dataeval.outputs._base import set_metadata
 from dataeval.utils._bin import get_counts
@@ -128,7 +128,7 @@ def balance(
     # Normalization via entropy
     bin_cnts = get_counts(discretized_data)
     ent_factor = sp.stats.entropy(bin_cnts, axis=0)
-    norm_factor = 0.5 * np.add.outer(ent_factor, ent_factor) + 1e-6
+    norm_factor = 0.5 * np.add.outer(ent_factor, ent_factor) + EPSILON
 
     # in principle MI should be symmetric, but it is not in practice.
     nmi = 0.5 * (mi + mi.T) / norm_factor
@@ -157,7 +157,7 @@ def balance(
     # Classwise normalization via entropy
     classwise_bin_cnts = get_counts(tgt_bin)
     ent_tgt_bin = sp.stats.entropy(classwise_bin_cnts, axis=0)
-    norm_factor = 0.5 * np.add.outer(ent_tgt_bin, ent_factor) + 1e-6
+    norm_factor = 0.5 * np.add.outer(ent_tgt_bin, ent_factor) + EPSILON
     classwise = classwise_mi / norm_factor
 
     # Grabbing factor names for plotting function
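The only functional change here is the guard in the entropy normalization, which drops from a local 1e-6 to the shared 1e-10. A self-contained sketch of the normalization step these hunks touch, with toy values standing in for the real counts:

    import numpy as np

    EPSILON = 1e-10  # value of dataeval.config.EPSILON per this diff

    # Symmetrize a (possibly asymmetric) mutual-information matrix and normalize
    # by the average per-factor entropy, as the changed lines do.
    mi = np.array([[1.0, 0.3], [0.4, 0.8]])
    ent_factor = np.array([1.0, 0.9])
    norm_factor = 0.5 * np.add.outer(ent_factor, ent_factor) + EPSILON
    nmi = 0.5 * (mi + mi.T) / norm_factor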
dataeval/metrics/bias/_completeness.py ADDED
@@ -0,0 +1,130 @@
+from __future__ import annotations
+
+import itertools
+
+__all__ = []
+
+
+import numpy as np
+
+from dataeval.config import EPSILON
+from dataeval.outputs import CompletenessOutput
+from dataeval.typing import ArrayLike
+from dataeval.utils._array import ensure_embeddings
+
+
+def completeness(embeddings: ArrayLike, quantiles: int) -> CompletenessOutput:
+    """
+    Calculate the fraction of boxes in a grid defined by quantiles that
+    contain at least one data point.
+    Also returns the center coordinates of each empty box.
+
+    Parameters
+    ----------
+    embeddings : ArrayLike
+        Embedded dataset (or other low-dimensional data) (nxp)
+    quantiles : int
+        Number of quantile values to use for partitioning each dimension,
+        e.g., 1 creates a grid of 2^p boxes, 2 creates 3^p boxes, etc.
+
+    Returns
+    -------
+    CompletenessOutput
+        - fraction_filled: float - Fraction of boxes that contain at least one
+          data point
+        - empty_box_centers: List[np.ndarray] - List of coordinates for centers of empty
+          boxes
+
+    Raises
+    ------
+    ValueError
+        If embeddings are too high-dimensional (>10)
+    ValueError
+        If there are too many quantiles (>2)
+    ValueError
+        If embedding is invalid shape
+
+    Example
+    -------
+    >>> embs = np.array([[1, 0], [0, 1], [1, 1]])
+    >>> quantiles = 1
+    >>> result = completeness(embs, quantiles)
+    >>> result.fraction_filled
+    0.75
+
+    Reference
+    ---------
+    This implementation is based on https://arxiv.org/abs/2002.03147.
+
+    [1] Byun, Taejoon, and Sanjai Rayadurgam. "Manifold for Machine Learning Assurance."
+    Proceedings of the ACM/IEEE 42nd International Conference on Software Engineering
+    """
+    # Ensure proper data format
+    embeddings = ensure_embeddings(embeddings, dtype=np.float64, unit_interval=False)
+
+    # Get data dimensions
+    n, p = embeddings.shape
+    if quantiles > 2 or quantiles <= 0:
+        raise ValueError(
+            f"Number of quantiles ({quantiles}) is greater than 2 or is nonpositive. \
+            The metric scales exponentially in this value. Please use 1 or 2 quantiles."
+        )
+    if p > 10:
+        raise ValueError(
+            f"Dimension of embeddings ({p}) is greater than 10. \
+            The metric scales exponentially in this value. Please reduce the embedding dimension."
+        )
+    if n == 0 or p == 0:
+        raise ValueError("Your provided embeddings do not contain any data!")
+    # quantiles + 2 edges partition each embedding dimension (e.g. [0, 0.5, 1] for quantiles = 1)
+    quantile_vec = np.linspace(0, 1, quantiles + 2)
+
+    # Calculate the bin edges for each dimension based on quantiles
+    bin_edges = []
+    for dim in range(p):
+        # Calculate the quantile values for this feature
+        edges = np.array(np.quantile(embeddings[:, dim], quantile_vec))
+        # Make sure the last bin contains all the remaining points
+        edges[-1] += EPSILON
+        bin_edges.append(edges)
+    # Convert each data point into its corresponding grid cell indices
+    grid_indices = []
+    for dim in range(p):
+        # For each dimension, find which bin each data point belongs to
+        # Digitize is 1 indexed so we subtract 1
+        indices = np.digitize(embeddings[:, dim], bin_edges[dim]) - 1
+        grid_indices.append(indices)
+
+    # Make the rows the data point and the column the grid index
+    grid_coords = np.array(grid_indices).T
+
+    # Use set to find unique tuple of grid coordinates
+    occupied_cells = set(map(tuple, grid_coords))
+
+    # For the fraction
+    num_occupied_cells = len(occupied_cells)
+
+    # Calculate total possible cells in the grid
+    num_bins_per_dim = [len(edges) - 1 for edges in bin_edges]
+    total_possible_cells = np.prod(num_bins_per_dim)
+
+    # Generate all possible grid cells
+    all_cells = set(itertools.product(*[range(bins) for bins in num_bins_per_dim]))
+
+    # Find the empty cells (cells with no data points)
+    empty_cells = all_cells - occupied_cells
+
+    # Calculate center points of empty boxes
+    empty_box_centers = []
+    for cell in empty_cells:
+        center_coords = []
+        for dim, idx in enumerate(cell):
+            # Calculate center of the bin as midpoint between edges
+            center = (bin_edges[dim][idx] + bin_edges[dim][idx + 1]) / 2
+            center_coords.append(center)
+        empty_box_centers.append(np.array(center_coords))
+
+    # Calculate the fraction
+    fraction = float(num_occupied_cells / total_possible_cells)
+    empty_box_centers = np.array(empty_box_centers)
+    return CompletenessOutput(fraction, empty_box_centers)
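A sketch that reproduces the docstring example and also inspects the empty-box output; fraction_filled and empty_box_centers are the two fields the diff shows on CompletenessOutput:

    import numpy as np

    from dataeval.metrics.bias import completeness

    embs = np.array([[1, 0], [0, 1], [1, 1]])
    out = completeness(embs, quantiles=1)

    print(out.fraction_filled)    # 0.75 -> 3 of the 2^2 quantile boxes contain a point
    print(out.empty_box_centers)  # midpoint coordinates of the one unoccupied box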
dataeval/metrics/estimators/_ber.py CHANGED
@@ -19,6 +19,7 @@ from numpy.typing import NDArray
 from scipy.sparse import coo_matrix
 from scipy.stats import mode
 
+from dataeval.config import EPSILON
 from dataeval.outputs import BEROutput
 from dataeval.outputs._base import set_metadata
 from dataeval.typing import ArrayLike
@@ -82,7 +83,7 @@ def ber_knn(images: NDArray[np.float64], labels: NDArray[np.int_], k: int) -> tu
 
 def knn_lowerbound(value: float, classes: int, k: int) -> float:
     """Several cases for computing the BER lower bound"""
-    if value <= 1e-10:
+    if value <= EPSILON:
        return 0.0
 
    if classes == 2 and k != 1: