dataeval 0.86.9__py3-none-any.whl → 0.88.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. dataeval/__init__.py +1 -1
  2. dataeval/_log.py +1 -1
  3. dataeval/_version.py +2 -2
  4. dataeval/config.py +4 -19
  5. dataeval/data/_embeddings.py +78 -35
  6. dataeval/data/_images.py +41 -8
  7. dataeval/data/_metadata.py +348 -66
  8. dataeval/data/_selection.py +22 -7
  9. dataeval/data/_split.py +3 -2
  10. dataeval/data/selections/_classbalance.py +4 -3
  11. dataeval/data/selections/_classfilter.py +9 -8
  12. dataeval/data/selections/_indices.py +4 -3
  13. dataeval/data/selections/_prioritize.py +249 -29
  14. dataeval/data/selections/_reverse.py +1 -1
  15. dataeval/data/selections/_shuffle.py +5 -4
  16. dataeval/detectors/drift/_base.py +2 -1
  17. dataeval/detectors/drift/_mmd.py +2 -1
  18. dataeval/detectors/drift/_nml/_base.py +1 -1
  19. dataeval/detectors/drift/_nml/_chunk.py +2 -1
  20. dataeval/detectors/drift/_nml/_result.py +3 -2
  21. dataeval/detectors/drift/_nml/_thresholds.py +6 -5
  22. dataeval/detectors/drift/_uncertainty.py +2 -1
  23. dataeval/detectors/linters/duplicates.py +2 -1
  24. dataeval/detectors/linters/outliers.py +4 -3
  25. dataeval/detectors/ood/__init__.py +2 -1
  26. dataeval/detectors/ood/ae.py +1 -1
  27. dataeval/detectors/ood/base.py +39 -1
  28. dataeval/detectors/ood/knn.py +95 -0
  29. dataeval/detectors/ood/mixin.py +2 -1
  30. dataeval/metadata/_utils.py +1 -1
  31. dataeval/metrics/bias/_balance.py +29 -22
  32. dataeval/metrics/bias/_diversity.py +4 -4
  33. dataeval/metrics/bias/_parity.py +2 -2
  34. dataeval/metrics/stats/_base.py +3 -29
  35. dataeval/metrics/stats/_boxratiostats.py +2 -1
  36. dataeval/metrics/stats/_dimensionstats.py +2 -1
  37. dataeval/metrics/stats/_hashstats.py +21 -3
  38. dataeval/metrics/stats/_pixelstats.py +2 -1
  39. dataeval/metrics/stats/_visualstats.py +2 -1
  40. dataeval/outputs/_base.py +2 -3
  41. dataeval/outputs/_bias.py +2 -1
  42. dataeval/outputs/_estimators.py +1 -1
  43. dataeval/outputs/_linters.py +3 -3
  44. dataeval/outputs/_stats.py +3 -3
  45. dataeval/outputs/_utils.py +1 -1
  46. dataeval/outputs/_workflows.py +49 -31
  47. dataeval/typing.py +23 -9
  48. dataeval/utils/__init__.py +2 -2
  49. dataeval/utils/_array.py +3 -2
  50. dataeval/utils/_bin.py +9 -7
  51. dataeval/utils/_method.py +2 -3
  52. dataeval/utils/_multiprocessing.py +34 -0
  53. dataeval/utils/_plot.py +2 -1
  54. dataeval/utils/data/__init__.py +6 -5
  55. dataeval/utils/data/{metadata.py → _merge.py} +3 -2
  56. dataeval/utils/data/_validate.py +170 -0
  57. dataeval/utils/data/collate.py +2 -1
  58. dataeval/utils/torch/_internal.py +2 -1
  59. dataeval/utils/torch/trainer.py +1 -1
  60. dataeval/workflows/sufficiency.py +13 -9
  61. {dataeval-0.86.9.dist-info → dataeval-0.88.0.dist-info}/METADATA +8 -21
  62. dataeval-0.88.0.dist-info/RECORD +105 -0
  63. dataeval/utils/data/_dataset.py +0 -246
  64. dataeval/utils/datasets/__init__.py +0 -21
  65. dataeval/utils/datasets/_antiuav.py +0 -189
  66. dataeval/utils/datasets/_base.py +0 -266
  67. dataeval/utils/datasets/_cifar10.py +0 -201
  68. dataeval/utils/datasets/_fileio.py +0 -142
  69. dataeval/utils/datasets/_milco.py +0 -197
  70. dataeval/utils/datasets/_mixin.py +0 -54
  71. dataeval/utils/datasets/_mnist.py +0 -202
  72. dataeval/utils/datasets/_seadrone.py +0 -512
  73. dataeval/utils/datasets/_ships.py +0 -144
  74. dataeval/utils/datasets/_types.py +0 -48
  75. dataeval/utils/datasets/_voc.py +0 -583
  76. dataeval-0.86.9.dist-info/RECORD +0 -115
  77. {dataeval-0.86.9.dist-info → dataeval-0.88.0.dist-info}/WHEEL +0 -0
  78. /dataeval-0.86.9.dist-info/licenses/LICENSE.txt → /dataeval-0.88.0.dist-info/licenses/LICENSE +0 -0
@@ -10,7 +10,8 @@ from __future__ import annotations

  __all__ = []

- from typing import Any, Callable
+ from collections.abc import Callable
+ from typing import Any

  import torch

@@ -9,8 +9,8 @@ from __future__ import annotations

  import logging
  from abc import ABC, abstractmethod
+ from collections.abc import Sequence
  from logging import Logger
- from typing import Sequence

  import pandas as pd
  from typing_extensions import Self
@@ -13,7 +13,8 @@ import copy
  import logging
  import warnings
  from abc import ABC, abstractmethod
- from typing import Any, Generic, Literal, Sequence, TypeVar, cast
+ from collections.abc import Sequence
+ from typing import Any, Generic, Literal, TypeVar, cast

  import pandas as pd
  from pandas import Index, Period
@@ -11,7 +11,8 @@ from __future__ import annotations

  import copy
  from abc import ABC, abstractmethod
- from typing import NamedTuple, Sequence
+ from collections.abc import Sequence
+ from typing import NamedTuple

  import pandas as pd
  from typing_extensions import Self
@@ -52,7 +53,7 @@ class AbstractResult(GenericOutput[pd.DataFrame]):

      def filter(self, period: str = "all", metrics: str | Sequence[str] | None = None) -> Self:
          """Returns filtered result metric data."""
-         if metrics and not isinstance(metrics, (str, Sequence)):
+         if metrics and not isinstance(metrics, str | Sequence):
              raise ValueError("metrics value provided is not a valid metric or sequence of metrics")
          if isinstance(metrics, str):
              metrics = [metrics]
@@ -9,7 +9,8 @@ from __future__ import annotations

  import logging
  from abc import ABC, abstractmethod
- from typing import Any, Callable, ClassVar
+ from collections.abc import Callable
+ from typing import Any, ClassVar

  import numpy as np

@@ -169,10 +170,10 @@ class ConstantThreshold(Threshold, threshold_type="constant"):

      @staticmethod
      def _validate_inputs(lower: float | int | None = None, upper: float | int | None = None) -> None:
-         if lower is not None and not isinstance(lower, (float, int)) or isinstance(lower, bool):
+         if lower is not None and not isinstance(lower, float | int) or isinstance(lower, bool):
              raise ValueError(f"expected type of 'lower' to be 'float', 'int' or None but got '{type(lower).__name__}'")

-         if upper is not None and not isinstance(upper, (float, int)) or isinstance(upper, bool):
+         if upper is not None and not isinstance(upper, float | int) or isinstance(upper, bool):
              raise ValueError(f"expected type of 'upper' to be 'float', 'int' or None but got '{type(upper).__name__}'")

          # explicit None check is required due to special interpretation of the value 0.0 as False
@@ -244,7 +245,7 @@ class StandardDeviationThreshold(Threshold, threshold_type="standard_deviation")
244
245
  ) -> None:
245
246
  if (
246
247
  std_lower_multiplier is not None
247
- and not isinstance(std_lower_multiplier, (float, int))
248
+ and not isinstance(std_lower_multiplier, float | int)
248
249
  or isinstance(std_lower_multiplier, bool)
249
250
  ):
250
251
  raise ValueError(
@@ -257,7 +258,7 @@ class StandardDeviationThreshold(Threshold, threshold_type="standard_deviation")
257
258
 
258
259
  if (
259
260
  std_upper_multiplier is not None
260
- and not isinstance(std_upper_multiplier, (float, int))
261
+ and not isinstance(std_upper_multiplier, float | int)
261
262
  or isinstance(std_upper_multiplier, bool)
262
263
  ):
263
264
  raise ValueError(
@@ -10,7 +10,8 @@ from __future__ import annotations
10
10
 
11
11
  __all__ = []
12
12
 
13
- from typing import Literal, Sequence, cast
13
+ from collections.abc import Sequence
14
+ from typing import Literal, cast
14
15
 
15
16
  import numpy as np
16
17
  import torch
@@ -2,7 +2,8 @@ from __future__ import annotations

  __all__ = []

- from typing import Any, Sequence, overload
+ from collections.abc import Sequence
+ from typing import Any, overload

  from dataeval.data._images import Images
  from dataeval.metrics.stats import hashstats
@@ -2,7 +2,8 @@ from __future__ import annotations

  __all__ = []

- from typing import Any, Literal, Sequence, overload
+ from collections.abc import Sequence
+ from typing import Any, Literal, overload

  import numpy as np
  from numpy.typing import NDArray
@@ -201,7 +202,7 @@ class Outliers:
          >>> results.issues[1]
          {}
          """
-         if isinstance(stats, (ImageStatsOutput, DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)):
+         if isinstance(stats, ImageStatsOutput | DimensionStatsOutput | PixelStatsOutput | VisualStatsOutput):
              return OutliersOutput(self._get_outliers(stats.data()))

          if not isinstance(stats, Sequence):
@@ -212,7 +213,7 @@ class Outliers:
          stats_map: dict[type, list[int]] = {}
          for i, stats_output in enumerate(stats):
              if not isinstance(
-                 stats_output, (ImageStatsOutput, DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)
+                 stats_output, ImageStatsOutput | DimensionStatsOutput | PixelStatsOutput | VisualStatsOutput
              ):
                  raise TypeError(
                      "Invalid stats output type; only use output from dimensionstats, pixelstats or visualstats."
@@ -2,7 +2,8 @@
  Out-of-distribution (OOD) detectors identify data that is different from the data used to train a particular model.
  """

- __all__ = ["OODOutput", "OODScoreOutput", "OOD_AE"]
+ __all__ = ["OODOutput", "OODScoreOutput", "OOD_AE", "OOD_KNN"]

  from dataeval.detectors.ood.ae import OOD_AE
+ from dataeval.detectors.ood.knn import OOD_KNN
  from dataeval.outputs._ood import OODOutput, OODScoreOutput
@@ -12,7 +12,7 @@ from __future__ import annotations

  __all__ = []

- from typing import Callable
+ from collections.abc import Callable

  import numpy as np
  import torch
@@ -10,11 +10,16 @@ from __future__ import annotations

  __all__ = []

- from typing import Callable, cast
+ from abc import ABC, abstractmethod
+ from collections.abc import Callable
+ from typing import Any, cast

+ import numpy as np
  import torch
+ from numpy.typing import NDArray

  from dataeval.config import DeviceLike, get_device
+ from dataeval.data import Embeddings
  from dataeval.detectors.ood.mixin import OODBaseMixin, OODFitMixin, OODGMMMixin
  from dataeval.typing import ArrayLike
  from dataeval.utils._array import to_numpy
@@ -93,3 +98,36 @@ class OODBaseGMM(OODBase, OODGMMMixin[GaussianMixtureModelParams]):
          # Calculate the GMM parameters
          _, z, gamma = cast(tuple[torch.Tensor, torch.Tensor, torch.Tensor], self.model(x_ref))
          self._gmm_params = gmm_params(z, gamma)
+
+
+ class EmbeddingBasedOODBase(OODBaseMixin[Callable[[Any], Any]], ABC):
+     """
+     Base class for embedding-based OOD detection methods.
+
+     These methods work directly on embedding representations,
+     using distance metrics or density estimation in embedding space.
+     Inherits from OODBaseMixin to get automatic thresholding.
+     """
+
+     def __init__(self) -> None:
+         """Initialize embedding-based OOD detector."""
+         # Pass a dummy callable as model since we don't use it
+         super().__init__(lambda x: x)
+
+     def _get_data_info(self, X: NDArray) -> tuple[tuple, type]:
+         """Override to skip [0-1] validation for embeddings."""
+         if not isinstance(X, np.ndarray):
+             raise TypeError("Dataset should of type: `NDArray`.")
+         # Skip the [0-1] range check for embeddings
+         return X.shape[1:], X.dtype.type
+
+     @abstractmethod
+     def fit_embeddings(self, embeddings: Embeddings, threshold_perc: float = 95.0) -> None:
+         """
+         Fit using reference embeddings.
+
+         Args:
+             embeddings: Reference (in-distribution) embeddings
+             threshold_perc: Percentage of reference data considered normal
+         """
+         pass
@@ -0,0 +1,95 @@
+ from typing import Literal
+
+ import numpy as np
+ from sklearn.neighbors import NearestNeighbors
+
+ from dataeval.data import Embeddings
+ from dataeval.detectors.ood.base import EmbeddingBasedOODBase
+ from dataeval.outputs._ood import OODScoreOutput
+ from dataeval.typing import ArrayLike
+
+
+ class OOD_KNN(EmbeddingBasedOODBase):
+     """
+     K-Nearest Neighbors Out-of-Distribution detector.
+
+     Uses average cosine distance to k nearest neighbors in embedding space to detect OOD samples.
+     Samples with larger average distances to their k nearest neighbors in the
+     reference (in-distribution) set are considered more likely to be OOD.
+
+     Based on the methodology from:
+     "Back to the Basics: Revisiting Out-of-Distribution Detection Baselines"
+     (Kuan & Mueller, 2022)
+
+     As referenced in:
+     "Safe AI for coral reefs: Benchmarking out-of-distribution detection
+     algorithms for coral reef image surveys"
+     """
+
+     def __init__(self, k: int = 10, distance_metric: Literal["cosine", "euclidean"] = "cosine") -> None:
+         """
+         Initialize KNN OOD detector.
+
+         Args:
+             k: Number of nearest neighbors to consider (default: 10)
+             distance_metric: Distance metric to use ('cosine' or 'euclidean')
+         """
+         super().__init__()
+         self.k = k
+         self.distance_metric = distance_metric
+         self._nn_model: NearestNeighbors
+         self.reference_embeddings: ArrayLike
+
+     def fit_embeddings(self, embeddings: Embeddings, threshold_perc: float = 95.0) -> None:
+         """
+         Fit the detector using reference (in-distribution) embeddings.
+
+         Builds a k-NN index for efficient nearest neighbor search and
+         computes reference scores for automatic thresholding.
+
+         Args:
+             embeddings: Reference embeddings from in-distribution data
+             threshold_perc: Percentage of reference data considered normal
+         """
+         self.reference_embeddings = embeddings.to_numpy()
+
+         if self.k >= len(self.reference_embeddings):
+             raise ValueError(
+                 f"k ({self.k}) must be less than number of reference embeddings ({len(self.reference_embeddings)})"
+             )
+
+         # Build k-NN index using sklearn
+         self._nn_model = NearestNeighbors(
+             n_neighbors=self.k,
+             metric=self.distance_metric,
+             algorithm="auto",  # Let sklearn choose the best algorithm
+         )
+         self._nn_model.fit(self.reference_embeddings)
+
+         # efficiently compute reference scores for automatic thresholding
+         ref_scores = self._compute_reference_scores()
+         self._ref_score = OODScoreOutput(instance_score=ref_scores)
+         self._threshold_perc = threshold_perc
+         self._data_info = self._get_data_info(self.reference_embeddings)
+
+     def _compute_reference_scores(self) -> np.ndarray:
+         """Efficiently compute reference scores by excluding self-matches."""
+         # Find k+1 neighbors (including self) for reference points
+         distances, _ = self._nn_model.kneighbors(self.reference_embeddings, n_neighbors=self.k + 1)
+         # Skip first neighbor (self with distance 0) and average the rest
+         return np.mean(distances[:, 1:], axis=1)
+
+     def _score(self, X: np.ndarray, batch_size: int = int(1e10)) -> OODScoreOutput:
+         """
+         Compute OOD scores for input embeddings.
+
+         Args:
+             X: Input embeddings to score
+             batch_size: Batch size (not used, kept for interface compatibility)
+
+         Returns:
+             OODScoreOutput containing instance-level scores
+         """
+         # Compute OOD scores using sklearn's efficient k-NN search
+         distances, _ = self._nn_model.kneighbors(X)
+         return OODScoreOutput(instance_score=np.mean(distances, axis=1))
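As a usage illustration, the new detector is fit on in-distribution embeddings and then scores new embeddings by their average distance to the k nearest reference neighbors. A minimal sketch, assuming `ref_embeddings` is a dataeval `Embeddings` object built elsewhere and `test_embeddings` is an (N, D) array of embeddings to score; only `OOD_KNN`, `fit_embeddings`, `_score`, and `instance_score` come from the code above, and the public thresholded prediction path lives in `OODBaseMixin`, which is not part of this hunk:

import numpy as np

from dataeval.detectors.ood import OOD_KNN

# k nearest neighbors in cosine-distance embedding space
detector = OOD_KNN(k=10, distance_metric="cosine")

# Builds the sklearn NearestNeighbors index and caches reference scores
# that the mixin later uses for automatic thresholding.
detector.fit_embeddings(ref_embeddings, threshold_perc=95.0)

# Instance scores are the average distance to the k nearest reference neighbors;
# larger scores indicate samples more likely to be out-of-distribution.
scores = detector._score(np.asarray(test_embeddings)).instance_score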
@@ -3,7 +3,8 @@ from __future__ import annotations
  __all__ = []

  from abc import ABC, abstractmethod
- from typing import Callable, Generic, Literal, TypeVar
+ from collections.abc import Callable
+ from typing import Generic, Literal, TypeVar

  import numpy as np
  from numpy.typing import NDArray
@@ -1,6 +1,6 @@
  __all__ = []

- from typing import Sequence
+ from collections.abc import Sequence

  from numpy.typing import NDArray

@@ -16,7 +16,7 @@ from dataeval.utils._bin import get_counts


  def _validate_num_neighbors(num_neighbors: int) -> int:
-     if not isinstance(num_neighbors, (int, float)):
+     if not isinstance(num_neighbors, int | float):
          raise TypeError(
              f"Variable {num_neighbors} is not real-valued numeric type."
              "num_neighbors should be an int, greater than 0 and less than"
@@ -73,9 +73,9 @@ def balance(
      Return intra/interfactor balance (mutual information)

      >>> bal.factors
-     array([[1. , 0.017, 0.015],
-            [0.017, 0.445, 0.245],
-            [0.015, 0.245, 1.063]])
+     array([[1. , 0. , 0.015],
+            [0. , 0.08 , 0.011],
+            [0.015, 0.011, 1.063]])

      Return classwise balance (mutual information) of factors with individual class_labels

@@ -95,32 +95,39 @@ def balance(

      num_neighbors = _validate_num_neighbors(num_neighbors)

-     data = metadata.discretized_data
      factor_types = {"class_label": "categorical"} | {k: v.factor_type for k, v in metadata.factor_info.items()}
      is_discrete = [factor_type != "continuous" for factor_type in factor_types.values()]
      num_factors = len(factor_types)
      class_labels = metadata.class_labels

      mi = np.full((num_factors, num_factors), np.nan, dtype=np.float32)
-     data = np.hstack((class_labels[:, np.newaxis], data))
+
+     # Use numeric data for MI
+     data = np.hstack((class_labels[:, np.newaxis], metadata.digitized_data))
+
+     # Present discrete features composed of distinct values as continuous for `mutual_info_classif`
+     for i, factor_type in enumerate(factor_types):
+         if len(data) == len(np.unique(data[:, i])):
+             is_discrete[i] = False
+             factor_types[factor_type] = "continuous"
+
+     mutual_info_fn_map = {
+         "categorical": mutual_info_classif,
+         "discrete": mutual_info_classif,
+         "continuous": mutual_info_regression,
+     }

      for idx, factor_type in enumerate(factor_types.values()):
-         if factor_type != "continuous":
-             mi[idx, :] = mutual_info_classif(
-                 data,
-                 data[:, idx],
-                 discrete_features=is_discrete,  # type: ignore - sklearn function not typed
-                 n_neighbors=num_neighbors,
-                 random_state=get_seed(),
-             )
-         else:
-             mi[idx, :] = mutual_info_regression(
-                 data,
-                 data[:, idx],
-                 discrete_features=is_discrete,  # type: ignore - sklearn function not typed
-                 n_neighbors=num_neighbors,
-                 random_state=get_seed(),
-             )
+         mi[idx, :] = mutual_info_fn_map[factor_type](
+             data,
+             data[:, idx],
+             discrete_features=is_discrete,
+             n_neighbors=num_neighbors,
+             random_state=get_seed(),
+         )
+
+     # Use binned data for classwise MI
+     data = np.hstack((class_labels[:, np.newaxis], metadata.binned_data))

      # Normalization via entropy
      bin_cnts = get_counts(data)
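The rewrite above replaces the if/else on factor type with a lookup table that routes each factor to the matching sklearn estimator: categorical and discrete factors go through mutual_info_classif, continuous factors through mutual_info_regression. A self-contained toy sketch of that dispatch pattern (the data, factor names, and hyperparameters here are illustrative, not taken from dataeval):

import numpy as np
from sklearn.feature_selection import mutual_info_classif, mutual_info_regression

rng = np.random.default_rng(0)
# Toy "factors": column 0 is a discrete label, column 1 is a continuous value.
data = np.column_stack([rng.integers(0, 3, 100), rng.normal(size=100)])
factor_types = {"label": "categorical", "value": "continuous"}
is_discrete = [t != "continuous" for t in factor_types.values()]

mutual_info_fn_map = {
    "categorical": mutual_info_classif,
    "discrete": mutual_info_classif,
    "continuous": mutual_info_regression,
}

# One row of pairwise mutual information per factor, estimator chosen by factor type.
mi = np.full((len(factor_types), len(factor_types)), np.nan)
for idx, factor_type in enumerate(factor_types.values()):
    mi[idx, :] = mutual_info_fn_map[factor_type](
        data, data[:, idx], discrete_features=is_discrete, n_neighbors=5, random_state=0
    )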
@@ -162,12 +162,12 @@ def diversity(
          raise ValueError("No factors found in provided metadata.")

      diversity_fn = get_method(_DIVERSITY_FN_MAP, method)
-     discretized_data = metadata.discretized_data
+     binned_data = metadata.binned_data
      factor_names = metadata.factor_names
      class_lbl = metadata.class_labels

-     class_labels_with_discretized_data = np.hstack((class_lbl[:, np.newaxis], discretized_data))
-     cnts = get_counts(class_labels_with_discretized_data)
+     class_labels_with_binned_data = np.hstack((class_lbl[:, np.newaxis], binned_data))
+     cnts = get_counts(class_labels_with_binned_data)
      num_bins = np.bincount(np.nonzero(cnts)[1])
      diversity_index = diversity_fn(cnts, num_bins)

@@ -176,7 +176,7 @@ def diversity(
      classwise_div = np.full((len(u_classes), num_factors), np.nan)
      for idx, cls in enumerate(u_classes):
          subset_mask = class_lbl == cls
-         cls_cnts = get_counts(discretized_data[subset_mask], min_num_bins=cnts.shape[0])
+         cls_cnts = get_counts(binned_data[subset_mask], min_num_bins=cnts.shape[0])
          classwise_div[idx, :] = diversity_fn(cls_cnts, num_bins[1:])

      return DiversityOutput(diversity_index, classwise_div, factor_names, metadata.class_names)
@@ -245,10 +245,10 @@ def parity(metadata: Metadata) -> ParityOutput:
      if not metadata.factor_names:
          raise ValueError("No factors found in provided metadata.")

-     chi_scores = np.zeros(metadata.discretized_data.shape[1])
+     chi_scores = np.zeros(metadata.binned_data.shape[1])
      p_values = np.zeros_like(chi_scores)
      insufficient_data: defaultdict[str, defaultdict[int, dict[str, int]]] = defaultdict(lambda: defaultdict(dict))
-     for i, col_data in enumerate(metadata.discretized_data.T):
+     for i, col_data in enumerate(metadata.binned_data.T):
          # Builds a contingency matrix where entry at index (r,c) represents
          # the frequency of current_factor_name achieving value unique_factor_values[r]
          # at a data point with class c.
@@ -6,11 +6,11 @@ import math
  import re
  import warnings
  from collections import ChainMap
+ from collections.abc import Callable, Iterable, Iterator, Sequence
  from copy import deepcopy
  from dataclasses import dataclass
  from functools import partial
- from multiprocessing import Pool
- from typing import Any, Callable, Generic, Iterable, Iterator, Sequence, TypeVar
+ from typing import Any, Generic, TypeVar

  import numpy as np
  from numpy.typing import NDArray
@@ -21,14 +21,12 @@ from dataeval.outputs._stats import BASE_ATTRS, BaseStatsOutput, SourceIndex
  from dataeval.typing import Array, ArrayLike, Dataset, ObjectDetectionTarget
  from dataeval.utils._array import as_numpy, to_numpy
  from dataeval.utils._image import clip_and_pad, clip_box, is_valid_box, normalize_image_shape, rescale
+ from dataeval.utils._multiprocessing import PoolWrapper

  DTYPE_REGEX = re.compile(r"NDArray\[np\.(.*?)\]")

  TStatsOutput = TypeVar("TStatsOutput", bound=BaseStatsOutput, covariant=True)

- _S = TypeVar("_S")
- _T = TypeVar("_T")
-


  @dataclass
@@ -67,30 +65,6 @@ class BoundingBox:
          return x0_int, y0_int, x1_int, y1_int


- class PoolWrapper:
-     """
-     Wraps `multiprocessing.Pool` to allow for easy switching between
-     multiprocessing and single-threaded execution.
-
-     This helps with debugging and profiling, as well as usage with Jupyter notebooks
-     in VS Code, which does not support subprocess debugging.
-     """
-
-     def __init__(self, processes: int | None) -> None:
-         self.pool = Pool(processes) if processes is None or processes > 1 else None
-
-     def imap(self, func: Callable[[_S], _T], iterable: Iterable[_S]) -> Iterator[_T]:
-         return map(func, iterable) if self.pool is None else self.pool.imap(func, iterable)
-
-     def __enter__(self, *args: Any, **kwargs: Any) -> PoolWrapper:
-         return self
-
-     def __exit__(self, *args: Any) -> None:
-         if self.pool is not None:
-             self.pool.close()
-             self.pool.join()
-
-
  class StatsProcessor(Generic[TStatsOutput]):
      output_class: type[TStatsOutput]
      cache_keys: set[str] = set()
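The PoolWrapper class removed here reappears as the new dataeval/utils/_multiprocessing.py module (listed above with +34 lines and imported in the preceding hunk). That file's contents are not included in the hunks shown, but a sketch of the relocated module, reconstructed from the removed code (the module docstring, dunder header, and exact layout are assumptions), would look roughly like:

from __future__ import annotations

from collections.abc import Callable, Iterable, Iterator
from multiprocessing import Pool
from typing import Any, TypeVar

_S = TypeVar("_S")
_T = TypeVar("_T")


class PoolWrapper:
    """
    Wraps `multiprocessing.Pool` to allow for easy switching between
    multiprocessing and single-threaded execution.
    """

    def __init__(self, processes: int | None) -> None:
        # Only spin up a real pool when more than one process is requested
        self.pool = Pool(processes) if processes is None or processes > 1 else None

    def imap(self, func: Callable[[_S], _T], iterable: Iterable[_S]) -> Iterator[_T]:
        # Fall back to the built-in map for single-threaded execution
        return map(func, iterable) if self.pool is None else self.pool.imap(func, iterable)

    def __enter__(self, *args: Any, **kwargs: Any) -> PoolWrapper:
        return self

    def __exit__(self, *args: Any) -> None:
        if self.pool is not None:
            self.pool.close()
            self.pool.join()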
@@ -3,7 +3,8 @@ from __future__ import annotations
  __all__ = []

  import copy
- from typing import Any, Callable, Generic, TypeVar, cast
+ from collections.abc import Callable
+ from typing import Any, Generic, TypeVar, cast

  import numpy as np
  from numpy.typing import NDArray
@@ -2,7 +2,8 @@ from __future__ import annotations

  __all__ = []

- from typing import Any, Callable
+ from collections.abc import Callable
+ from typing import Any

  import numpy as np

@@ -4,12 +4,14 @@ import warnings

  __all__ = []

- from typing import Any, Callable
+ from collections.abc import Callable
+ from typing import Any

  import numpy as np
  import xxhash as xxh
- from PIL import Image
+ from numpy.typing import NDArray
  from scipy.fftpack import dct
+ from scipy.ndimage import zoom

  from dataeval.metrics.stats._base import StatsProcessor, run_stats
  from dataeval.outputs import HashStatsOutput
@@ -18,10 +20,26 @@ from dataeval.typing import ArrayLike, Dataset
  from dataeval.utils._array import as_numpy
  from dataeval.utils._image import normalize_image_shape, rescale

+ try:
+     from PIL import Image
+ except ImportError:
+     Image = None
+
  HASH_SIZE = 8
  MAX_FACTOR = 4


+ def _resize(image: NDArray[np.uint8], resize_dim: int, use_pil: bool = True) -> NDArray[np.uint8]:
+     """Resizes a grayscale (HxW) 8-bit image using PIL or scipy.ndimage.zoom."""
+
+     # Use PIL if available, otherwise resize and resample with scipy.ndimage.zoom
+     if use_pil and Image is not None:
+         return np.array(Image.fromarray(image).resize((resize_dim, resize_dim), Image.Resampling.LANCZOS))
+
+     zoom_factors = (resize_dim / image.shape[0], resize_dim / image.shape[1])
+     return np.clip(zoom(image, zoom_factors, order=5, mode="reflect"), 0, 255, dtype=np.uint8)
+
+
  def pchash(image: ArrayLike) -> str:
      """
      Performs a perceptual hash on an image by resizing to a square NxN image
@@ -59,7 +77,7 @@ def pchash(image: ArrayLike) -> str:
      rescaled = rescale(normalized, 8).astype(np.uint8)

      # Resizes the image using the Lanczos algorithm to a square image
-     im = np.array(Image.fromarray(rescaled).resize((resize_dim, resize_dim), Image.Resampling.LANCZOS))
+     im = _resize(rescaled, resize_dim)

      # Performs discrete cosine transforms to compress the image information and takes the lowest frequency component
      transform = dct(dct(im.T).T)[:HASH_SIZE, :HASH_SIZE]
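The practical effect of this change is that perceptual hashing no longer hard-requires Pillow: when PIL is missing, scipy.ndimage.zoom resamples the image instead. A small sketch of that fallback, importing the private _resize helper shown above purely for illustration (the toy image and target dimension are made up; the module path is the hashstats file from the list above):

import numpy as np

from dataeval.metrics.stats._hashstats import _resize

# Toy 28x28 grayscale uint8 image; any HxW 8-bit array works with _resize.
image = (np.random.default_rng(0).random((28, 28)) * 255).astype(np.uint8)

# With Pillow installed this uses Image.Resampling.LANCZOS; without it (or with
# use_pil=False) scipy.ndimage.zoom resamples to the same target shape.
resized_default = _resize(image, resize_dim=32)
resized_scipy = _resize(image, resize_dim=32, use_pil=False)

assert resized_default.shape == resized_scipy.shape == (32, 32)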
@@ -2,7 +2,8 @@ from __future__ import annotations

  __all__ = []

- from typing import Any, Callable
+ from collections.abc import Callable
+ from typing import Any

  import numpy as np
  from scipy.stats import entropy, kurtosis, skew
@@ -2,7 +2,8 @@ from __future__ import annotations

  __all__ = []

- from typing import Any, Callable
+ from collections.abc import Callable
+ from typing import Any

  import numpy as np

dataeval/outputs/_base.py CHANGED
@@ -4,14 +4,13 @@ __all__ = []

  import inspect
  import logging
- from collections.abc import Collection, Mapping, Sequence
+ from collections.abc import Callable, Collection, Iterator, Mapping, Sequence
  from dataclasses import dataclass
  from datetime import datetime, timezone
  from functools import partial, wraps
- from typing import Any, Callable, Generic, Iterator, TypeVar, overload
+ from typing import Any, Generic, ParamSpec, TypeVar, overload

  import numpy as np
- from typing_extensions import ParamSpec

  from dataeval import __version__

dataeval/outputs/_bias.py CHANGED
@@ -3,8 +3,9 @@ from __future__ import annotations
  __all__ = []

  import contextlib
+ from collections.abc import Mapping, Sequence
  from dataclasses import asdict, dataclass
- from typing import Any, Mapping, Sequence, TypeVar
+ from typing import Any, TypeVar

  import numpy as np
  import pandas as pd
@@ -2,8 +2,8 @@ from __future__ import annotations

  __all__ = []

+ from collections.abc import Sequence
  from dataclasses import dataclass
- from typing import Sequence

  import numpy as np
  from numpy.typing import NDArray