dataeval 0.72.0__py3-none-any.whl → 0.72.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. dataeval/__init__.py +4 -4
  2. dataeval/detectors/__init__.py +4 -3
  3. dataeval/detectors/drift/__init__.py +10 -11
  4. dataeval/{_internal/detectors → detectors}/drift/base.py +51 -102
  5. dataeval/{_internal/detectors → detectors}/drift/cvm.py +9 -8
  6. dataeval/{_internal/detectors → detectors}/drift/ks.py +11 -10
  7. dataeval/{_internal/detectors → detectors}/drift/mmd.py +33 -34
  8. dataeval/{_internal/detectors → detectors}/drift/torch.py +15 -13
  9. dataeval/{_internal/detectors → detectors}/drift/uncertainty.py +12 -9
  10. dataeval/detectors/drift/updates.py +61 -0
  11. dataeval/detectors/linters/__init__.py +3 -3
  12. dataeval/{_internal/detectors → detectors/linters}/clusterer.py +47 -45
  13. dataeval/{_internal/detectors → detectors/linters}/duplicates.py +20 -10
  14. dataeval/{_internal/detectors → detectors/linters}/merged_stats.py +3 -1
  15. dataeval/{_internal/detectors → detectors/linters}/outliers.py +19 -26
  16. dataeval/detectors/ood/__init__.py +8 -16
  17. dataeval/{_internal/detectors → detectors}/ood/ae.py +9 -9
  18. dataeval/{_internal/detectors → detectors}/ood/aegmm.py +10 -30
  19. dataeval/{_internal/detectors → detectors}/ood/base.py +27 -21
  20. dataeval/{_internal/detectors → detectors}/ood/llr.py +27 -23
  21. dataeval/detectors/ood/metadata_ks_compare.py +99 -0
  22. dataeval/detectors/ood/metadata_least_likely.py +119 -0
  23. dataeval/detectors/ood/metadata_ood_mi.py +92 -0
  24. dataeval/{_internal/detectors → detectors}/ood/vae.py +11 -13
  25. dataeval/{_internal/detectors → detectors}/ood/vaegmm.py +10 -32
  26. dataeval/{_internal/interop.py → interop.py} +12 -7
  27. dataeval/metrics/__init__.py +1 -1
  28. dataeval/metrics/bias/__init__.py +4 -4
  29. dataeval/{_internal/metrics → metrics/bias}/balance.py +70 -4
  30. dataeval/{_internal/metrics → metrics/bias}/coverage.py +10 -8
  31. dataeval/{_internal/metrics → metrics/bias}/diversity.py +54 -20
  32. dataeval/metrics/bias/metadata.py +275 -0
  33. dataeval/{_internal/metrics → metrics/bias}/parity.py +21 -17
  34. dataeval/metrics/estimators/__init__.py +3 -3
  35. dataeval/{_internal/metrics → metrics/estimators}/ber.py +31 -28
  36. dataeval/{_internal/metrics → metrics/estimators}/divergence.py +15 -16
  37. dataeval/{_internal/metrics → metrics/estimators}/uap.py +8 -6
  38. dataeval/metrics/stats/__init__.py +7 -7
  39. dataeval/{_internal/metrics → metrics}/stats/base.py +66 -40
  40. dataeval/{_internal/metrics → metrics}/stats/boxratiostats.py +19 -15
  41. dataeval/{_internal/metrics → metrics}/stats/datasetstats.py +19 -17
  42. dataeval/{_internal/metrics → metrics}/stats/dimensionstats.py +12 -10
  43. dataeval/metrics/stats/hashstats.py +156 -0
  44. dataeval/{_internal/metrics → metrics}/stats/labelstats.py +8 -6
  45. dataeval/{_internal/metrics → metrics}/stats/pixelstats.py +12 -11
  46. dataeval/{_internal/metrics → metrics}/stats/visualstats.py +14 -13
  47. dataeval/{_internal/output.py → output.py} +26 -6
  48. dataeval/utils/__init__.py +8 -4
  49. dataeval/utils/image.py +71 -0
  50. dataeval/utils/shared.py +151 -0
  51. dataeval/utils/split_dataset.py +486 -0
  52. dataeval/utils/tensorflow/__init__.py +9 -7
  53. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/autoencoder.py +64 -68
  54. dataeval/{_internal/models/tensorflow/losses.py → utils/tensorflow/_internal/loss.py} +10 -9
  55. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/pixelcnn.py +18 -22
  56. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/trainer.py +3 -1
  57. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/utils.py +18 -18
  58. dataeval/utils/tensorflow/loss/__init__.py +6 -2
  59. dataeval/utils/torch/__init__.py +7 -3
  60. dataeval/{_internal/models/pytorch → utils/torch}/blocks.py +19 -14
  61. dataeval/{_internal → utils/torch}/datasets.py +49 -43
  62. dataeval/utils/torch/models.py +138 -0
  63. dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py} +12 -141
  64. dataeval/{_internal → utils/torch}/utils.py +3 -1
  65. dataeval/workflows/__init__.py +1 -1
  66. dataeval/{_internal/workflows → workflows}/sufficiency.py +42 -37
  67. {dataeval-0.72.0.dist-info → dataeval-0.72.2.dist-info}/METADATA +7 -5
  68. dataeval-0.72.2.dist-info/RECORD +72 -0
  69. dataeval/_internal/detectors/__init__.py +0 -0
  70. dataeval/_internal/detectors/drift/__init__.py +0 -0
  71. dataeval/_internal/detectors/ood/__init__.py +0 -0
  72. dataeval/_internal/metrics/__init__.py +0 -0
  73. dataeval/_internal/metrics/stats/hashstats.py +0 -75
  74. dataeval/_internal/metrics/utils.py +0 -447
  75. dataeval/_internal/models/__init__.py +0 -0
  76. dataeval/_internal/models/pytorch/__init__.py +0 -0
  77. dataeval/_internal/models/pytorch/utils.py +0 -67
  78. dataeval/_internal/models/tensorflow/__init__.py +0 -0
  79. dataeval/_internal/workflows/__init__.py +0 -0
  80. dataeval/detectors/drift/kernels/__init__.py +0 -10
  81. dataeval/detectors/drift/updates/__init__.py +0 -7
  82. dataeval/utils/tensorflow/models/__init__.py +0 -9
  83. dataeval/utils/tensorflow/recon/__init__.py +0 -3
  84. dataeval/utils/torch/datasets/__init__.py +0 -12
  85. dataeval/utils/torch/models/__init__.py +0 -11
  86. dataeval/utils/torch/trainer/__init__.py +0 -7
  87. dataeval-0.72.0.dist-info/RECORD +0 -80
  88. /dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/gmm.py +0 -0
  89. {dataeval-0.72.0.dist-info → dataeval-0.72.2.dist-info}/LICENSE.txt +0 -0
  90. {dataeval-0.72.0.dist-info → dataeval-0.72.2.dist-info}/WHEEL +0 -0
dataeval/{_internal/detectors → detectors}/drift/torch.py

@@ -8,8 +8,10 @@ Licensed under Apache Software License (Apache 2.0)

  from __future__ import annotations

+ __all__ = []
+
  from functools import partial
- from typing import Callable
+ from typing import Any, Callable

  import numpy as np
  import torch
@@ -42,7 +44,7 @@ def get_device(device: str | torch.device | None = None) -> torch.device:
      return torch_device


- def mmd2_from_kernel_matrix(
+ def _mmd2_from_kernel_matrix(
      kernel_mat: torch.Tensor, m: int, permute: bool = False, zero_diag: bool = True
  ) -> torch.Tensor:
      """
@@ -78,13 +80,13 @@ def mmd2_from_kernel_matrix(


  def predict_batch(
-     x: NDArray | torch.Tensor,
+     x: NDArray[Any] | torch.Tensor,
      model: Callable | nn.Module | nn.Sequential,
      device: torch.device | None = None,
      batch_size: int = int(1e10),
      preprocess_fn: Callable | None = None,
      dtype: type[np.generic] | torch.dtype = np.float32,
- ) -> NDArray | torch.Tensor | tuple:
+ ) -> NDArray[Any] | torch.Tensor | tuple[Any, ...]:
      """
      Make batch predictions on a model.

@@ -102,7 +104,7 @@ def predict_batch(
      preprocess_fn : Callable | None, default None
          Optional preprocessing function for each batch.
      dtype : np.dtype | torch.dtype, default np.float32
-         Model output type, either a numpy or torch dtype, e.g. np.float32 or torch.float32.
+         Model output type, either a :term:`NumPy` or torch dtype, e.g. np.float32 or torch.float32.

      Returns
      -------
@@ -154,13 +156,13 @@ def predict_batch(


  def preprocess_drift(
-     x: NDArray,
+     x: NDArray[Any],
      model: nn.Module,
-     device: torch.device | None = None,
+     device: str | torch.device | None = None,
      preprocess_batch_fn: Callable | None = None,
      batch_size: int = int(1e10),
      dtype: type[np.generic] | torch.dtype = np.float32,
- ) -> NDArray | torch.Tensor | tuple:
+ ) -> NDArray[Any] | torch.Tensor | tuple[Any, ...]:
      """
      Prediction function used for preprocessing step of drift detector.

@@ -179,7 +181,7 @@ def preprocess_drift(
      batch_size : int, default 1e10
          Batch size used during prediction.
      dtype : np.dtype | torch.dtype, default np.float32
-         Model output type, either a numpy or torch dtype, e.g. np.float32 or torch.float32.
+         Model output type, either a :term:`NumPy` or torch dtype, e.g. np.float32 or torch.float32.

      Returns
      -------
@@ -189,7 +191,7 @@ def preprocess_drift(
      return predict_batch(
          x,
          model,
-         device=device,
+         device=get_device(device),
          batch_size=batch_size,
          preprocess_fn=preprocess_batch_fn,
          dtype=dtype,
@@ -197,7 +199,7 @@


  @torch.jit.script
- def squared_pairwise_distance(
+ def _squared_pairwise_distance(
      x: torch.Tensor, y: torch.Tensor, a_min: float = 1e-30
  ) -> torch.Tensor:  # pragma: no cover - torch.jit.script code is compiled and copied
      """
@@ -249,7 +251,7 @@ def sigma_median(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.
      return sigma


- class GaussianRBF(nn.Module):
+ class _GaussianRBF(nn.Module):
      """
      Gaussian RBF kernel: k(x,y) = exp(-(1/(2*sigma^2)||x-y||^2).

@@ -303,7 +305,7 @@ class GaussianRBF(nn.Module):
          infer_sigma: bool = False,
      ) -> torch.Tensor:
          x, y = torch.as_tensor(x), torch.as_tensor(y)
-         dist = squared_pairwise_distance(x.flatten(1), y.flatten(1))  # [Nx, Ny]
+         dist = _squared_pairwise_distance(x.flatten(1), y.flatten(1))  # [Nx, Ny]

          if infer_sigma or self.init_required:
              if self.trainable and infer_sigma:
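Worth noting for callers: `preprocess_drift` now accepts a plain device string and resolves it through `get_device` before delegating to `predict_batch`. A minimal usage sketch, not part of the diff, assuming the module stays importable from its new `dataeval.detectors.drift.torch` location and using a stand-in encoder:

```python
# Hypothetical sketch: drift detectors usually receive this function partially
# applied as their preprocess_fn so raw images are embedded before testing.
from functools import partial

import numpy as np
import torch.nn as nn

from dataeval.detectors.drift.torch import preprocess_drift  # new module path per this diff

encoder = nn.Sequential(nn.Flatten(), nn.LazyLinear(32))  # stand-in feature extractor
preprocess_fn = partial(preprocess_drift, model=encoder, device="cpu", batch_size=64)

x = np.random.rand(16, 1, 28, 28).astype(np.float32)
embeddings = preprocess_fn(x)  # np.float32 output by default, shape (16, 32)
```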
dataeval/{_internal/detectors → detectors}/drift/uncertainty.py

@@ -8,6 +8,8 @@ Licensed under Apache Software License (Apache 2.0)

  from __future__ import annotations

+ __all__ = ["DriftUncertainty"]
+
  from functools import partial
  from typing import Callable, Literal

@@ -16,16 +18,16 @@ from numpy.typing import ArrayLike, NDArray
  from scipy.special import softmax
  from scipy.stats import entropy

- from .base import DriftOutput, UpdateStrategy
- from .ks import DriftKS
- from .torch import get_device, preprocess_drift
+ from dataeval.detectors.drift.base import DriftOutput, UpdateStrategy
+ from dataeval.detectors.drift.ks import DriftKS
+ from dataeval.detectors.drift.torch import get_device, preprocess_drift


  def classifier_uncertainty(
-     x: NDArray,
+     x: NDArray[np.float64],
      model_fn: Callable,
      preds_type: Literal["probs", "logits"] = "probs",
- ) -> NDArray:
+ ) -> NDArray[np.float64]:
      """
      Evaluate model_fn on x and transform predictions to prediction uncertainties.

@@ -34,7 +36,7 @@ def classifier_uncertainty(
      x : np.ndarray
          Batch of instances.
      model_fn : Callable
-         Function that evaluates a classification model on x in a single call (contains
+         Function that evaluates a :term:`classification<Classification>` model on x in a single call (contains
          batching logic if necessary).
      preds_type : "probs" | "logits", default "probs"
          Type of prediction output by the model. Options are 'probs' (in [0,1]) or
@@ -73,9 +75,9 @@ class DriftUncertainty:
      x_ref : ArrayLike
          Data used as reference distribution.
      model : Callable
-         Classification model outputting class probabilities (or logits)
+         :term:`Classification` model outputting class probabilities (or logits)
      p_val : float, default 0.05
-         p-value used for the significance of the test.
+         :term:`P-Value` used for the significance of the test.
      x_ref_preprocessed : bool, default False
          Whether the given reference data ``x_ref`` has been preprocessed yet.
          If ``True``, only the test data ``x`` will be preprocessed at prediction time.
@@ -145,6 +147,7 @@ class DriftUncertainty:
          Returns
          -------
          DriftUnvariateOutput
-             Dictionary containing the drift prediction, p-value, and threshold statistics.
+             Dictionary containing the drift prediction, :term:`p-value<P-Value>`, and threshold
+             statistics.
          """
          return self._detector.predict(x)
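For context, `DriftUncertainty` feeds classifier uncertainties into a univariate KS test (`DriftKS`) behind the scenes. A hedged sketch of the documented constructor arguments follows; the classifier and data are placeholders, and the import assumes `DriftUncertainty` remains re-exported from `dataeval.detectors.drift` as in prior releases:

```python
import numpy as np
import torch.nn as nn

from dataeval.detectors.drift import DriftUncertainty  # assumed re-export

# Placeholder classifier emitting class probabilities ("probs" is the default preds_type).
clf = nn.Sequential(nn.Flatten(), nn.LazyLinear(10), nn.Softmax(dim=-1))

x_ref = np.random.rand(64, 1, 28, 28).astype(np.float32)   # reference distribution
x_test = np.random.rand(32, 1, 28, 28).astype(np.float32)  # data to check for drift

detector = DriftUncertainty(x_ref, model=clf, p_val=0.05)
result = detector.predict(x_test)  # drift prediction, p-value, and threshold statistics
```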
dataeval/detectors/drift/updates.py

@@ -0,0 +1,61 @@
+ """
+ Update strategies inform how the :term:`drift<Drift>` detector classes update the reference data when monitoring
+ for drift.
+ """
+
+ from __future__ import annotations
+
+ __all__ = ["LastSeenUpdate", "ReservoirSamplingUpdate"]
+
+ from typing import Any
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from dataeval.detectors.drift.base import UpdateStrategy
+
+
+ class LastSeenUpdate(UpdateStrategy):
+     """
+     Updates reference dataset for :term:`drift<Drift>` detector using last seen method.
+
+     Parameters
+     ----------
+     n : int
+         Update with last n instances seen by the detector.
+     """
+
+     def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]:
+         x_updated = np.concatenate([x_ref, x], axis=0)
+         return x_updated[-self.n :]
+
+
+ class ReservoirSamplingUpdate(UpdateStrategy):
+     """
+     Updates reference dataset for :term:`drift<Drift>` detector using reservoir sampling method.
+
+     Parameters
+     ----------
+     n : int
+         Update with last n instances seen by the detector.
+     """
+
+     def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]:
+         if x.shape[0] + count <= self.n:
+             return np.concatenate([x_ref, x], axis=0)
+
+         n_ref = x_ref.shape[0]
+         output_size = min(self.n, n_ref + x.shape[0])
+         shape = (output_size,) + x.shape[1:]
+         x_reservoir = np.zeros(shape, dtype=x_ref.dtype)
+         x_reservoir[:n_ref] = x_ref
+         for item in x:
+             count += 1
+             if n_ref < self.n:
+                 x_reservoir[n_ref, :] = item
+                 n_ref += 1
+             else:
+                 r = np.random.randint(0, count)
+                 if r < self.n:
+                     x_reservoir[r, :] = item
+         return x_reservoir
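The reservoir strategy keeps the reference set bounded at `n` rows while sampling approximately uniformly from everything seen so far, whereas `LastSeenUpdate` simply keeps the newest rows. A small sketch exercising the `__call__` contract shown above (assuming the `UpdateStrategy` base class takes the documented `n` in its constructor; the drift detectors normally invoke these strategies themselves after each prediction):

```python
import numpy as np

from dataeval.detectors.drift.updates import LastSeenUpdate, ReservoirSamplingUpdate

x_ref = np.random.rand(100, 8)  # current reference set
x_new = np.random.rand(30, 8)   # batch observed at prediction time

keep_last = LastSeenUpdate(100)
print(keep_last(x_ref, x_new, count=100).shape)   # (100, 8): only the newest 100 rows survive

reservoir = ReservoirSamplingUpdate(100)
print(reservoir(x_ref, x_new, count=100).shape)   # (100, 8): roughly uniform sample of all 130 rows seen
```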
dataeval/detectors/linters/__init__.py

@@ -2,9 +2,9 @@
  Linters help identify potential issues in training and test data and are an important aspect of data cleaning.
  """

- from dataeval._internal.detectors.clusterer import Clusterer, ClustererOutput
- from dataeval._internal.detectors.duplicates import Duplicates, DuplicatesOutput
- from dataeval._internal.detectors.outliers import Outliers, OutliersOutput
+ from dataeval.detectors.linters.clusterer import Clusterer, ClustererOutput
+ from dataeval.detectors.linters.duplicates import Duplicates, DuplicatesOutput
+ from dataeval.detectors.linters.outliers import Outliers, OutliersOutput

  __all__ = [
      "Clusterer",
dataeval/{_internal/detectors → detectors/linters}/clusterer.py

@@ -1,16 +1,18 @@
  from __future__ import annotations

+ __all__ = ["ClustererOutput", "Clusterer"]
+
  from dataclasses import dataclass
- from typing import Iterable, NamedTuple, cast
+ from typing import Any, Iterable, NamedTuple, cast

  import numpy as np
  from numpy.typing import ArrayLike, NDArray
  from scipy.cluster.hierarchy import linkage
  from scipy.spatial.distance import pdist, squareform

- from dataeval._internal.interop import to_numpy
- from dataeval._internal.metrics.utils import flatten
- from dataeval._internal.output import OutputMetadata, set_metadata
+ from dataeval.interop import to_numpy
+ from dataeval.output import OutputMetadata, set_metadata
+ from dataeval.utils.shared import flatten


  @dataclass(frozen=True)
@@ -25,7 +27,7 @@ class ClustererOutput(OutputMetadata):
      potential_outliers : List[int]
          Indices which are near the border between belonging in the cluster and being an outlier
      duplicates : List[List[int]]
-         Groups of indices that are exact duplicates
+         Groups of indices that are exact :term:`duplicates<Duplicates>`
      potential_duplicates : List[List[int]]
          Groups of indices which are not exact but closely related data points
      """
@@ -36,7 +38,7 @@ class ClustererOutput(OutputMetadata):
      potential_duplicates: list[list[int]]


- def extend_linkage(link_arr: NDArray) -> NDArray:
+ def _extend_linkage(link_arr: NDArray) -> NDArray:
      """
      Adds a column to the linkage matrix link_arr that tracks the new id assigned
      to each row
@@ -60,10 +62,10 @@ def extend_linkage(link_arr: NDArray) -> NDArray:
      return arr


- class Cluster:
+ class _Cluster:
      __slots__ = "merged", "samples", "sample_dist", "is_copy", "count", "dist_avg", "dist_std", "out1", "out2"

-     def __init__(self, merged: int, samples: NDArray, sample_dist: float | NDArray, is_copy: bool = False):
+     def __init__(self, merged: int, samples: NDArray, sample_dist: float | NDArray, is_copy: bool = False) -> None:
          self.merged = merged
          self.samples = np.array(samples, dtype=np.int32)
          self.sample_dist = np.array([sample_dist] if np.isscalar(sample_dist) else sample_dist)
@@ -85,8 +87,8 @@ class Cluster:
          self.out1 = dist > out1
          self.out2 = dist > out2

-     def copy(self) -> Cluster:
-         return Cluster(False, self.samples, self.sample_dist, True)
+     def copy(self) -> _Cluster:
+         return _Cluster(False, self.samples, self.sample_dist, True)

      def __repr__(self) -> str:
          _params = {
@@ -98,38 +100,38 @@ class Cluster:
          return f"{self.__class__.__name__}(**{repr(_params)})"


- class Clusters(dict[int, dict[int, Cluster]]):
-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs)
+ class _Clusters(dict[int, dict[int, _Cluster]]):
+     def __init__(self, *args: dict[int, dict[int, _Cluster]]) -> None:
+         super().__init__(*args)
          self.max_level: int = 1


- class ClusterPosition(NamedTuple):
+ class _ClusterPosition(NamedTuple):
      """Keeps track of a cluster's level and ID"""

      level: int
      cid: int


- class ClusterMergeEntry:
+ class _ClusterMergeEntry:
      __slots__ = "level", "outer_cluster", "inner_cluster", "status"

-     def __init__(self, level: int, outer_cluster: int, inner_cluster: int, status: int):
+     def __init__(self, level: int, outer_cluster: int, inner_cluster: int, status: int) -> None:
          self.level = level
          self.outer_cluster = outer_cluster
          self.inner_cluster = inner_cluster
          self.status = status

-     def __lt__(self, value: ClusterMergeEntry) -> bool:
+     def __lt__(self, value: _ClusterMergeEntry) -> bool:
          return self.level.__lt__(value.level)

-     def __gt__(self, value: ClusterMergeEntry) -> bool:
+     def __gt__(self, value: _ClusterMergeEntry) -> bool:
          return self.level.__gt__(value.level)


  class Clusterer:
      """
-     Uses hierarchical clustering to flag dataset properties of interest like outliers and duplicates
+     Uses hierarchical clustering to flag dataset properties of interest like Outliers and :term:`duplicates<Duplicates>`

      Parameters
      ----------
@@ -153,36 +155,36 @@ class Clusterer:
      >>> cluster = Clusterer(dataset)
      """

-     def __init__(self, dataset: ArrayLike):
+     def __init__(self, dataset: ArrayLike) -> None:
          # Allows an update to dataset to reset the state rather than instantiate a new class
          self._on_init(dataset)

      def _on_init(self, dataset: ArrayLike):
-         self._data: NDArray = flatten(to_numpy(dataset))
+         self._data: NDArray[Any] = flatten(to_numpy(dataset))
          self._validate_data(self._data)
          self._num_samples = len(self._data)

-         self._darr: NDArray = pdist(self._data, metric="euclidean")
-         self._sqdmat: NDArray = squareform(self._darr)
-         self._larr: NDArray = extend_linkage(linkage(self._darr))
+         self._darr: NDArray[np.floating[Any]] = pdist(self._data, metric="euclidean")
+         self._sqdmat: NDArray[np.floating[Any]] = squareform(self._darr)
+         self._larr: NDArray[np.floating[Any]] = _extend_linkage(linkage(self._darr))
          self._max_clusters: int = np.count_nonzero(self._larr[:, 3] == 2)

          min_num = int(self._num_samples * 0.05)
-         self._min_num_samples_per_cluster = min(max(2, min_num), 100)
+         self._min_num_samples_per_cluster: int = min(max(2, min_num), 100)

-         self._clusters = None
-         self._last_good_merge_levels = None
+         self._clusters: _Clusters | None = None
+         self._last_good_merge_levels: dict[int, int] | None = None

      @property
-     def data(self) -> NDArray:
+     def data(self) -> NDArray[Any]:
          return self._data

      @data.setter
-     def data(self, x: ArrayLike):
+     def data(self, x: ArrayLike) -> None:
          self._on_init(x)

      @property
-     def clusters(self) -> Clusters:
+     def clusters(self) -> _Clusters:
          if self._clusters is None:
              self._clusters = self._create_clusters()
          return self._clusters
@@ -209,11 +211,11 @@ class Clusterer:
          if features < 1:
              raise ValueError(f"Samples should have at least 1 feature; got {features}")

-     def _create_clusters(self) -> Clusters:
+     def _create_clusters(self) -> _Clusters:
          """Generates clusters based on linkage matrix"""
          next_cluster_id = 0
-         cluster_map: dict[int, ClusterPosition] = {}  # Dictionary to associate new cluster ids with actual clusters
-         clusters: Clusters = Clusters()
+         cluster_map: dict[int, _ClusterPosition] = {}  # Dictionary to associate new cluster ids with actual clusters
+         clusters: _Clusters = _Clusters()

          # Walking through the linkage array to generate clusters
          for arr_i in self._larr:
@@ -240,7 +242,7 @@ class Clusterer:
                  # Update clusters to include previously skipped levels
                  clusters = self._fill_levels(clusters, left, right)
              elif left or right:
-                 child, other_id = cast(tuple[ClusterPosition, int], (left, right_id) if left else (right, left_id))
+                 child, other_id = cast(tuple[_ClusterPosition, int], (left, right_id) if left else (right, left_id))
                  cc = clusters[child.level][child.cid]
                  samples = np.concatenate([cc.samples, [other_id]])
                  sample_dist = np.concatenate([cc.sample_dist, sample_dist])
@@ -254,12 +256,12 @@ class Clusterer:
              if level not in clusters:
                  clusters[level] = {}

-             clusters[level][cid] = Cluster(merged, samples, sample_dist)
-             cluster_map[int(arr_i[-1])] = ClusterPosition(level, cid)
+             clusters[level][cid] = _Cluster(merged, samples, sample_dist)
+             cluster_map[int(arr_i[-1])] = _ClusterPosition(level, cid)

          return clusters

-     def _fill_levels(self, clusters: Clusters, left: ClusterPosition, right: ClusterPosition) -> Clusters:
+     def _fill_levels(self, clusters: _Clusters, left: _ClusterPosition, right: _ClusterPosition) -> _Clusters:
          # Sets each level's cluster info if it does not exist
          if left.level != right.level:
              (level, cid), max_level = (left, right[0]) if left[0] < right[0] else (right, left[0])
@@ -312,7 +314,7 @@ class Clusterer:
          mask2 = mask2_vals < one_std_check
          return np.logical_or(desired_merge, mask2)

-     def _generate_merge_list(self, cluster_matrix: NDArray) -> list[ClusterMergeEntry]:
+     def _generate_merge_list(self, cluster_matrix: NDArray) -> list[_ClusterMergeEntry]:
          """
          Runs through the clusters dictionary determining when clusters merge,
          and how close are those clusters when they merge.
@@ -329,7 +331,7 @@ class Clusterer:
          """
          intra_max = []
          merge_mean = []
-         merge_list: list[ClusterMergeEntry] = []
+         merge_list: list[_ClusterMergeEntry] = []

          for level, cluster_set in self.clusters.items():
              for outer_cluster, cluster in cluster_set.items():
@@ -356,7 +358,7 @@ class Clusterer:
                  # Calculate the corresponding distance stats
                  distance_stats_arr = aggregate_func(distances)
                  merge_mean.append(distance_stats_arr)
-                 merge_list.append(ClusterMergeEntry(level, outer_cluster, inner_cluster, 0))
+                 merge_list.append(_ClusterMergeEntry(level, outer_cluster, inner_cluster, 0))

          all_merge_indices = self._calc_merge_indices(merge_mean=merge_mean, intra_max=intra_max)

@@ -401,7 +403,7 @@ class Clusterer:

      def find_outliers(self, last_merge_levels: dict[int, int]) -> tuple[list[int], list[int]]:
          """
-         Retrieves outliers based on when the sample was added to the cluster
+         Retrieves Outliers based on when the sample was added to the cluster
          and how far it was from the cluster when it was added

          Parameters
@@ -470,7 +472,7 @@ class Clusterer:
          Returns
          -------
          Tuple[List[List[int]], List[List[int]]]
-             The exact duplicates and near duplicates as lists of related indices
+             The exact :term:`duplicates<Duplicates>` and near duplicates as lists of related indices
          """

          duplicates_std = []
@@ -493,14 +495,14 @@ class Clusterer:
          return exact_dupes, near_dupes

      # TODO: Move data input to evaluate from class
-     @set_metadata("dataeval.detectors", ["data"])
+     @set_metadata(["data"])
      def evaluate(self) -> ClustererOutput:
-         """Finds and flags indices of the data for outliers and duplicates
+         """Finds and flags indices of the data for Outliers and :term:`duplicates<Duplicates>`

          Returns
          -------
          ClustererOutput
-             The outliers and duplicate indices found in the data
+             The Outliers and duplicate indices found in the data

          Example
          -------
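As the docstring's example hints, typical use is to construct the `Clusterer` on an array of images and call `evaluate()`. A short sketch follows; the random data is illustrative only, and field access follows the `ClustererOutput` docstring above:

```python
import numpy as np

from dataeval.detectors.linters import Clusterer

dataset = np.random.rand(64, 3, 16, 16)  # images are flattened internally before clustering

clusterer = Clusterer(dataset)
results = clusterer.evaluate()

print(results.outliers)    # indices flagged as outliers
print(results.duplicates)  # groups of indices that are exact duplicates
```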
dataeval/{_internal/detectors → detectors/linters}/duplicates.py

@@ -1,13 +1,15 @@
  from __future__ import annotations

+ __all__ = ["DuplicatesOutput", "Duplicates"]
+
  from dataclasses import dataclass
- from typing import Generic, Iterable, Sequence, TypeVar
+ from typing import Generic, Iterable, Sequence, TypeVar, overload

  from numpy.typing import ArrayLike

- from dataeval._internal.detectors.merged_stats import combine_stats, get_dataset_step_from_idx
- from dataeval._internal.metrics.stats.hashstats import HashStatsOutput, hashstats
- from dataeval._internal.output import OutputMetadata, set_metadata
+ from dataeval.detectors.linters.merged_stats import combine_stats, get_dataset_step_from_idx
+ from dataeval.metrics.stats.hashstats import HashStatsOutput, hashstats
+ from dataeval.output import OutputMetadata, set_metadata

  DuplicateGroup = list[int]
  DatasetDuplicateGroupMap = dict[int, DuplicateGroup]
@@ -37,7 +39,7 @@ class DuplicatesOutput(Generic[TIndexCollection], OutputMetadata):

  class Duplicates:
      """
-     Finds the duplicate images in a dataset using xxhash for exact duplicates
+     Finds the duplicate images in a dataset using xxhash for exact :term:`duplicates<Duplicates>`
      and pchash for near duplicates

      Attributes
@@ -58,7 +60,7 @@ class Duplicates:
      >>> exact_dupes = Duplicates(only_exact=True)
      """

-     def __init__(self, only_exact: bool = False):
+     def __init__(self, only_exact: bool = False) -> None:
          self.stats: HashStatsOutput
          self.only_exact = only_exact

@@ -81,8 +83,16 @@ class Duplicates:
              "near": sorted(near),
          }

-     @set_metadata("dataeval.detectors", ["only_exact"])
-     def from_stats(self, hashes: HashStatsOutput | Sequence[HashStatsOutput]) -> DuplicatesOutput:
+     @overload
+     def from_stats(self, hashes: HashStatsOutput) -> DuplicatesOutput[DuplicateGroup]: ...
+
+     @overload
+     def from_stats(self, hashes: Sequence[HashStatsOutput]) -> DuplicatesOutput[DatasetDuplicateGroupMap]: ...
+
+     @set_metadata(["only_exact"])
+     def from_stats(
+         self, hashes: HashStatsOutput | Sequence[HashStatsOutput]
+     ) -> DuplicatesOutput[DuplicateGroup] | DuplicatesOutput[DatasetDuplicateGroupMap]:
          """
          Returns duplicate image indices for both exact matches and near matches

@@ -128,8 +138,8 @@ class Duplicates:

          return DuplicatesOutput(**duplicates)

-     @set_metadata("dataeval.detectors", ["only_exact"])
-     def evaluate(self, data: Iterable[ArrayLike]) -> DuplicatesOutput:
+     @set_metadata(["only_exact"])
+     def evaluate(self, data: Iterable[ArrayLike]) -> DuplicatesOutput[DuplicateGroup]:
          """
          Returns duplicate image indices for both exact matches and near matches

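The new `@overload` signatures make the return type of `from_stats` track its input: a single `HashStatsOutput` yields flat duplicate groups, while a sequence of them yields a per-dataset map. A hedged sketch of both call paths (placeholder images; assumes `hashstats` is re-exported from `dataeval.metrics.stats`):

```python
import numpy as np

from dataeval.detectors.linters import Duplicates
from dataeval.metrics.stats import hashstats  # assumed re-export of the hashstats metric

images = np.random.randint(0, 255, size=(32, 1, 64, 64), dtype=np.uint8)

dupes = Duplicates()
direct = dupes.evaluate(images)        # hashes computed internally
print(direct.exact, direct.near)       # exact (xxhash) and near (pchash) duplicate groups

stats = hashstats(images)              # precompute hashes once, reuse across runs
print(dupes.from_stats(stats).exact)   # same exact-duplicate groups without rehashing
```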
dataeval/{_internal/detectors → detectors/linters}/merged_stats.py

@@ -1,11 +1,13 @@
  from __future__ import annotations

+ __all__ = []
+
  from copy import deepcopy
  from typing import Sequence, TypeVar

  import numpy as np

- from dataeval._internal.metrics.stats.base import BaseStatsOutput
+ from dataeval.metrics.stats.base import BaseStatsOutput

  TStatsOutput = TypeVar("TStatsOutput", bound=BaseStatsOutput)