dataeval 0.74.2__py3-none-any.whl → 0.76.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. dataeval/__init__.py +27 -23
  2. dataeval/detectors/__init__.py +2 -2
  3. dataeval/detectors/drift/__init__.py +14 -12
  4. dataeval/detectors/drift/base.py +3 -3
  5. dataeval/detectors/drift/cvm.py +1 -1
  6. dataeval/detectors/drift/ks.py +3 -2
  7. dataeval/detectors/drift/mmd.py +9 -7
  8. dataeval/detectors/drift/torch.py +12 -12
  9. dataeval/detectors/drift/uncertainty.py +5 -4
  10. dataeval/detectors/drift/updates.py +1 -1
  11. dataeval/detectors/linters/__init__.py +4 -4
  12. dataeval/detectors/linters/clusterer.py +5 -9
  13. dataeval/detectors/linters/duplicates.py +10 -14
  14. dataeval/detectors/linters/outliers.py +100 -5
  15. dataeval/detectors/ood/__init__.py +4 -11
  16. dataeval/detectors/ood/{ae_torch.py → ae.py} +6 -4
  17. dataeval/detectors/ood/base.py +47 -160
  18. dataeval/detectors/ood/metadata_ks_compare.py +34 -42
  19. dataeval/detectors/ood/metadata_least_likely.py +3 -3
  20. dataeval/detectors/ood/metadata_ood_mi.py +6 -5
  21. dataeval/detectors/ood/mixin.py +146 -0
  22. dataeval/detectors/ood/output.py +63 -0
  23. dataeval/interop.py +7 -6
  24. dataeval/{logging.py → log.py} +2 -0
  25. dataeval/metrics/__init__.py +3 -3
  26. dataeval/metrics/bias/__init__.py +10 -13
  27. dataeval/metrics/bias/balance.py +13 -11
  28. dataeval/metrics/bias/coverage.py +53 -5
  29. dataeval/metrics/bias/diversity.py +56 -24
  30. dataeval/metrics/bias/parity.py +20 -17
  31. dataeval/metrics/estimators/__init__.py +2 -2
  32. dataeval/metrics/estimators/ber.py +7 -4
  33. dataeval/metrics/estimators/divergence.py +4 -4
  34. dataeval/metrics/estimators/uap.py +4 -4
  35. dataeval/metrics/stats/__init__.py +19 -19
  36. dataeval/metrics/stats/base.py +28 -12
  37. dataeval/metrics/stats/boxratiostats.py +13 -14
  38. dataeval/metrics/stats/datasetstats.py +49 -20
  39. dataeval/metrics/stats/dimensionstats.py +8 -8
  40. dataeval/metrics/stats/hashstats.py +14 -10
  41. dataeval/metrics/stats/labelstats.py +94 -11
  42. dataeval/metrics/stats/pixelstats.py +11 -14
  43. dataeval/metrics/stats/visualstats.py +10 -13
  44. dataeval/output.py +23 -14
  45. dataeval/utils/__init__.py +5 -14
  46. dataeval/utils/dataset/__init__.py +7 -0
  47. dataeval/utils/{torch → dataset}/datasets.py +2 -0
  48. dataeval/utils/dataset/read.py +63 -0
  49. dataeval/utils/{split_dataset.py → dataset/split.py} +38 -30
  50. dataeval/utils/image.py +2 -2
  51. dataeval/utils/metadata.py +317 -14
  52. dataeval/{metrics/bias/metadata_utils.py → utils/plot.py} +91 -71
  53. dataeval/utils/torch/__init__.py +2 -17
  54. dataeval/utils/torch/gmm.py +29 -6
  55. dataeval/utils/torch/{utils.py → internal.py} +82 -58
  56. dataeval/utils/torch/models.py +10 -8
  57. dataeval/utils/torch/trainer.py +6 -85
  58. dataeval/workflows/__init__.py +2 -5
  59. dataeval/workflows/sufficiency.py +18 -8
  60. {dataeval-0.74.2.dist-info → dataeval-0.76.0.dist-info}/LICENSE.txt +2 -2
  61. dataeval-0.76.0.dist-info/METADATA +137 -0
  62. dataeval-0.76.0.dist-info/RECORD +67 -0
  63. dataeval/detectors/ood/base_torch.py +0 -109
  64. dataeval/metrics/bias/metadata_preprocessing.py +0 -285
  65. dataeval/utils/gmm.py +0 -26
  66. dataeval-0.74.2.dist-info/METADATA +0 -120
  67. dataeval-0.74.2.dist-info/RECORD +0 -66
  68. {dataeval-0.74.2.dist-info → dataeval-0.76.0.dist-info}/WHEEL +0 -0
dataeval/__init__.py CHANGED
@@ -1,36 +1,40 @@
-__version__ = "0.74.2"
+"""
+DataEval provides a simple interface to characterize image data and its impact on model performance
+across classification and object-detection tasks. It also provides capabilities to select and curate
+datasets to test and train performant, robust, unbiased and reliable AI models and monitor for data
+shifts that impact performance of deployed models.
+"""
+
+from __future__ import annotations
+
+__all__ = ["detectors", "log", "metrics", "utils", "workflows"]
+__version__ = "0.76.0"
 
 import logging
-from importlib.util import find_spec
+
+from dataeval import detectors, metrics, utils, workflows
 
 logging.getLogger(__name__).addHandler(logging.NullHandler())
 
 
-def log_stderr(level: int = logging.DEBUG) -> None:
+def log(level: int = logging.DEBUG, handler: logging.Handler | None = None) -> None:
     """
-    Helper for quickly adding a StreamHandler to the logger. Useful for
-    debugging.
+    Helper for quickly adding a StreamHandler to the logger. Useful for debugging.
+
+    Parameters
+    ----------
+    level : int, default logging.DEBUG(10)
+        Set the logging level for the logger.
+    handler : logging.Handler, optional
+        Sets the logging handler for the logger if provided, otherwise logger will be
+        provided with a StreamHandler.
     """
     import logging
 
     logger = logging.getLogger(__name__)
-    handler = logging.StreamHandler()
-    handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
+    if handler is None:
+        handler = logging.StreamHandler() if handler is None else handler
+        handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
     logger.addHandler(handler)
     logger.setLevel(level)
-    logger.debug("Added a stderr logging handler to logger: %s", __name__)
-
-
-_IS_TORCH_AVAILABLE = find_spec("torch") is not None
-_IS_TORCHVISION_AVAILABLE = find_spec("torchvision") is not None
-
-del find_spec
-
-from dataeval import detectors, metrics  # noqa: E402
-
-__all__ = ["log_stderr", "detectors", "metrics"]
-
-if _IS_TORCH_AVAILABLE:
-    from dataeval import utils, workflows
-
-    __all__ += ["utils", "workflows"]
+    logger.debug(f"Added logging handler {handler} to logger: {__name__}")
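The log_stderr() helper above is renamed to log() and gains an optional handler argument. A minimal usage sketch based on the new signature shown in this hunk (the file-handler target is illustrative):

import logging

import dataeval

# Default: attach a formatted StreamHandler at DEBUG level
dataeval.log()

# Or route DataEval logs through a custom handler at a chosen level
dataeval.log(level=logging.INFO, handler=logging.FileHandler("dataeval.log"))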
dataeval/detectors/__init__.py CHANGED
@@ -2,6 +2,6 @@
 Detectors can determine if a dataset or individual images in a dataset are indicative of a specific issue.
 """
 
-from dataeval.detectors import drift, linters, ood
-
 __all__ = ["drift", "linters", "ood"]
+
+from dataeval.detectors import drift, linters, ood
dataeval/detectors/drift/__init__.py CHANGED
@@ -2,19 +2,21 @@
 :term:`Drift` detectors identify if the statistical properties of the data has changed.
 """
 
-from dataeval import _IS_TORCH_AVAILABLE
+__all__ = [
+    "DriftCVM",
+    "DriftKS",
+    "DriftMMD",
+    "DriftMMDOutput",
+    "DriftOutput",
+    "DriftUncertainty",
+    "preprocess_drift",
+    "updates",
+]
+
 from dataeval.detectors.drift import updates
 from dataeval.detectors.drift.base import DriftOutput
 from dataeval.detectors.drift.cvm import DriftCVM
 from dataeval.detectors.drift.ks import DriftKS
-
-__all__ = ["DriftCVM", "DriftKS", "DriftOutput", "updates"]
-
-if _IS_TORCH_AVAILABLE:
-    from dataeval.detectors.drift.mmd import DriftMMD, DriftMMDOutput
-    from dataeval.detectors.drift.torch import preprocess_drift
-    from dataeval.detectors.drift.uncertainty import DriftUncertainty
-
-    __all__ += ["DriftMMD", "DriftMMDOutput", "DriftUncertainty", "preprocess_drift"]
-
-del _IS_TORCH_AVAILABLE
+from dataeval.detectors.drift.mmd import DriftMMD, DriftMMDOutput
+from dataeval.detectors.drift.torch import preprocess_drift
+from dataeval.detectors.drift.uncertainty import DriftUncertainty
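With the _IS_TORCH_AVAILABLE gate removed, the torch-backed detectors now import unconditionally alongside the univariate ones. A sketch of exercising them together, assuming the detectors keep their reference-data constructor and predict() API from prior releases (the arrays and shapes are illustrative):

import numpy as np

from dataeval.detectors.drift import DriftKS, DriftMMD

rng = np.random.default_rng(0)
x_ref = rng.normal(size=(256, 64)).astype(np.float32)            # reference sample
x_test = rng.normal(loc=0.5, size=(256, 64)).astype(np.float32)  # shifted sample

ks = DriftKS(x_ref)
mmd = DriftMMD(x_ref)
# is_drift is assumed from the DriftOutput attributes documented above
print(ks.predict(x_test).is_drift, mmd.predict(x_test).is_drift)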
dataeval/detectors/drift/base.py CHANGED
@@ -8,7 +8,7 @@ Licensed under Apache Software License (Apache 2.0)
 
 from __future__ import annotations
 
-__all__ = ["DriftOutput"]
+__all__ = []
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
@@ -45,7 +45,7 @@ class UpdateStrategy(ABC):
 @dataclass(frozen=True)
 class DriftBaseOutput(Output):
     """
-    Base output class for Drift detector classes
+    Base output class for Drift Detector classes
 
     Attributes
     ----------
@@ -64,7 +64,7 @@ class DriftBaseOutput(Output):
 @dataclass(frozen=True)
 class DriftOutput(DriftBaseOutput):
     """
-    Output class for :class:`DriftCVM`, :class:`DriftKS`, and :class:`DriftUncertainty` drift detectors
+    Output class for :class:`DriftCVM`, :class:`DriftKS`, and :class:`DriftUncertainty` drift detectors.
 
     Attributes
     ----------
dataeval/detectors/drift/cvm.py CHANGED
@@ -8,7 +8,7 @@ Licensed under Apache Software License (Apache 2.0)
 
 from __future__ import annotations
 
-__all__ = ["DriftCVM"]
+__all__ = []
 
 from typing import Callable, Literal
 
dataeval/detectors/drift/ks.py CHANGED
@@ -8,7 +8,7 @@ Licensed under Apache Software License (Apache 2.0)
 
 from __future__ import annotations
 
-__all__ = ["DriftKS"]
+__all__ = []
 
 from typing import Callable, Literal
 
@@ -22,7 +22,8 @@ from dataeval.interop import to_numpy
 
 class DriftKS(BaseDriftUnivariate):
     """
-    :term:`Drift` detector employing the Kolmogorov-Smirnov (KS) distribution test.
+    :term:`Drift` detector employing the :term:`Kolmogorov-Smirnov (KS) \
+    distribution<Kolmogorov-Smirnov (K-S) test>` test.
 
     The KS test detects changes in the maximum distance between two data
     distributions with Bonferroni or :term:`False Discovery Rate (FDR)` correction
dataeval/detectors/drift/mmd.py CHANGED
@@ -8,7 +8,7 @@ Licensed under Apache Software License (Apache 2.0)
 
 from __future__ import annotations
 
-__all__ = ["DriftMMD", "DriftMMDOutput"]
+__all__ = []
 
 from dataclasses import dataclass
 from typing import Callable
@@ -17,15 +17,16 @@ import torch
 from numpy.typing import ArrayLike
 
 from dataeval.detectors.drift.base import BaseDrift, DriftBaseOutput, UpdateStrategy, preprocess_x, update_x_ref
-from dataeval.detectors.drift.torch import _GaussianRBF, _mmd2_from_kernel_matrix, get_device
+from dataeval.detectors.drift.torch import GaussianRBF, mmd2_from_kernel_matrix
 from dataeval.interop import as_numpy
 from dataeval.output import set_metadata
+from dataeval.utils.torch.internal import get_device
 
 
 @dataclass(frozen=True)
 class DriftMMDOutput(DriftBaseOutput):
     """
-    Output class for :class:`DriftMMD` :term:`drift<Drift>` detector
+    Output class for :class:`DriftMMD` :term:`drift<Drift>` detector.
 
     Attributes
     ----------
@@ -50,7 +51,8 @@ class DriftMMDOutput(DriftBaseOutput):
 
 class DriftMMD(BaseDrift):
     """
-    :term:`Maximum Mean Discrepancy (MMD) Drift Detection` algorithm using a permutation test.
+    :term:`Maximum Mean Discrepancy (MMD) Drift Detection` algorithm \
+    using a permutation test.
 
     Parameters
     ----------
@@ -109,7 +111,7 @@ class DriftMMD(BaseDrift):
 
         # initialize kernel
         sigma_tensor = torch.from_numpy(as_numpy(sigma)).to(self.device) if sigma is not None else None
-        self._kernel = _GaussianRBF(sigma_tensor).to(self.device)
+        self._kernel = GaussianRBF(sigma_tensor).to(self.device)
 
         # compute kernel matrix for the reference data
         if self._infer_sigma or isinstance(sigma_tensor, torch.Tensor):
@@ -150,9 +152,9 @@ class DriftMMD(BaseDrift):
         n = x.shape[0]
         kernel_mat = self._kernel_matrix(x_ref, torch.from_numpy(x).to(self.device))
         kernel_mat = kernel_mat - torch.diag(kernel_mat.diag())  # zero diagonal
-        mmd2 = _mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False)
+        mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False)
         mmd2_permuted = torch.Tensor(
-            [_mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)]
+            [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)]
        )
         mmd2, mmd2_permuted = mmd2.detach().cpu(), mmd2_permuted.detach().cpu()
         p_val = (mmd2 <= mmd2_permuted).float().mean()
dataeval/detectors/drift/torch.py CHANGED
@@ -17,10 +17,10 @@ import torch
 import torch.nn as nn
 from numpy.typing import NDArray
 
-from dataeval.utils.torch.utils import get_device, predict_batch
+from dataeval.utils.torch.internal import get_device, predict_batch
 
 
-def _mmd2_from_kernel_matrix(
+def mmd2_from_kernel_matrix(
     kernel_mat: torch.Tensor, m: int, permute: bool = False, zero_diag: bool = True
 ) -> torch.Tensor:
     """
@@ -127,7 +127,7 @@ def _squared_pairwise_distance(
 
 def sigma_median(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.Tensor:
     """
-    Bandwidth estimation using the median heuristic :cite:t:`Gretton2012`.
+    Bandwidth estimation using the median heuristic `Gretton2012`
 
     Parameters
     ----------
@@ -151,7 +151,7 @@ def sigma_median(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.
     return sigma
 
 
-class _GaussianRBF(nn.Module):
+class GaussianRBF(nn.Module):
     """
     Gaussian RBF kernel: k(x,y) = exp(-(1/(2*sigma^2)||x-y||^2).
 
@@ -179,18 +179,18 @@ class _GaussianRBF(nn.Module):
     ) -> None:
         super().__init__()
         init_sigma_fn = sigma_median if init_sigma_fn is None else init_sigma_fn
-        self.config = {
+        self.config: dict[str, Any] = {
             "sigma": sigma,
             "trainable": trainable,
             "init_sigma_fn": init_sigma_fn,
         }
         if sigma is None:
-            self.log_sigma = nn.Parameter(torch.empty(1), requires_grad=trainable)
-            self.init_required = True
+            self.log_sigma: nn.Parameter = nn.Parameter(torch.empty(1), requires_grad=trainable)
+            self.init_required: bool = True
         else:
             sigma = sigma.reshape(-1)  # [Ns,]
-            self.log_sigma = nn.Parameter(sigma.log(), requires_grad=trainable)
-            self.init_required = False
+            self.log_sigma: nn.Parameter = nn.Parameter(sigma.log(), requires_grad=trainable)
+            self.init_required: bool = False
         self.init_sigma_fn = init_sigma_fn
         self.trainable = trainable
 
@@ -200,8 +200,8 @@ class _GaussianRBF(nn.Module):
 
     def forward(
         self,
-        x: np.ndarray | torch.Tensor,
-        y: np.ndarray | torch.Tensor,
+        x: np.ndarray[Any, Any] | torch.Tensor,
+        y: np.ndarray[Any, Any] | torch.Tensor,
         infer_sigma: bool = False,
     ) -> torch.Tensor:
         x, y = torch.as_tensor(x), torch.as_tensor(y)
@@ -213,7 +213,7 @@ class _GaussianRBF(nn.Module):
             sigma = self.init_sigma_fn(x, y, dist)
             with torch.no_grad():
                 self.log_sigma.copy_(sigma.log().clone())
-            self.init_required = False
+            self.init_required: bool = False
 
         gamma = 1.0 / (2.0 * self.sigma**2)  # [Ns,]
         # TODO: do matrix multiplication after all?
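GaussianRBF and mmd2_from_kernel_matrix lose their underscore prefixes here, matching the updated imports in mmd.py above. A sketch of computing a squared-MMD estimate directly with these helpers (the shapes and the fixed bandwidth are illustrative, not from the diff):

import torch

from dataeval.detectors.drift.torch import GaussianRBF, mmd2_from_kernel_matrix

x = torch.randn(32, 16)  # reference sample
y = torch.randn(32, 16)  # test sample

kernel = GaussianRBF(sigma=torch.tensor([1.0]))  # fixed bandwidth, no median heuristic

# Full kernel matrix over the stacked samples, as DriftMMD builds internally
z = torch.cat([x, y])
kernel_mat = kernel(z, z)

# m mirrors the n = x.shape[0] argument used in the DriftMMD._kernel_matrix call above
mmd2 = mmd2_from_kernel_matrix(kernel_mat, y.shape[0], permute=False, zero_diag=True)
print(float(mmd2))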
dataeval/detectors/drift/uncertainty.py CHANGED
@@ -8,7 +8,7 @@ Licensed under Apache Software License (Apache 2.0)
 
 from __future__ import annotations
 
-__all__ = ["DriftUncertainty"]
+__all__ = []
 
 from functools import partial
 from typing import Callable, Literal
@@ -20,7 +20,8 @@ from scipy.stats import entropy
 
 from dataeval.detectors.drift.base import DriftOutput, UpdateStrategy
 from dataeval.detectors.drift.ks import DriftKS
-from dataeval.detectors.drift.torch import get_device, preprocess_drift
+from dataeval.detectors.drift.torch import preprocess_drift
+from dataeval.utils.torch.internal import get_device
 
 
 def classifier_uncertainty(
@@ -65,8 +66,8 @@ def classifier_uncertainty(
 
 class DriftUncertainty:
     """
-    Test for a change in the number of instances falling into regions on which the
-    model is uncertain.
+    Test for a change in the number of instances falling into regions on which \
+    the model is uncertain.
 
     Performs a K-S test on prediction entropies.
 
dataeval/detectors/drift/updates.py CHANGED
@@ -1,5 +1,5 @@
 """
-Update strategies inform how the :term:`drift<Drift>` detector classes update the reference data when monitoring
+Update strategies inform how the :term:`drift<Drift>` detector classes update the reference data when monitoring.
 for drift.
 """
 
dataeval/detectors/linters/__init__.py CHANGED
@@ -2,10 +2,6 @@
 Linters help identify potential issues in training and test data and are an important aspect of data cleaning.
 """
 
-from dataeval.detectors.linters.clusterer import Clusterer, ClustererOutput
-from dataeval.detectors.linters.duplicates import Duplicates, DuplicatesOutput
-from dataeval.detectors.linters.outliers import Outliers, OutliersOutput
-
 __all__ = [
     "Clusterer",
     "ClustererOutput",
@@ -14,3 +10,7 @@ __all__ = [
     "Outliers",
     "OutliersOutput",
 ]
+
+from dataeval.detectors.linters.clusterer import Clusterer, ClustererOutput
+from dataeval.detectors.linters.duplicates import Duplicates, DuplicatesOutput
+from dataeval.detectors.linters.outliers import Outliers, OutliersOutput
dataeval/detectors/linters/clusterer.py CHANGED
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-__all__ = ["ClustererOutput", "Clusterer"]
+__all__ = []
 
 from dataclasses import dataclass
 from typing import Any, Iterable, NamedTuple, cast
@@ -18,7 +18,7 @@ from dataeval.utils.shared import flatten
 @dataclass(frozen=True)
 class ClustererOutput(Output):
     """
-    Output class for :class:`Clusterer` lint detector
+    Output class for :class:`Clusterer` lint detector.
 
     Attributes
     ----------
@@ -131,7 +131,8 @@ class _ClusterMergeEntry:
 
 class Clusterer:
     """
-    Uses hierarchical clustering to flag dataset properties of interest like Outliers and :term:`duplicates<Duplicates>`
+    Uses hierarchical clustering to flag dataset properties of interest like outliers \
+    and :term:`duplicates<Duplicates>`.
 
     Parameters
     ----------
@@ -147,12 +148,6 @@ class Clusterer:
     ----
     The Clusterer works best when the length of the feature dimension, P, is less than 500.
     If flattening a CxHxW image results in a dimension larger than 500, then it is recommended to reduce the dimensions.
-
-    Example
-    -------
-    Initialize the Clusterer class:
-
-    >>> cluster = Clusterer(dataset)
     """
 
     def __init__(self, dataset: ArrayLike) -> None:
@@ -506,6 +501,7 @@ class Clusterer:
 
         Example
         -------
+        >>> cluster = Clusterer(clusterer_images)
         >>> cluster.evaluate()
         ClustererOutput(outliers=[18, 21, 34, 35, 45], potential_outliers=[13, 15, 42], duplicates=[[9, 24], [23, 48]], potential_duplicates=[[1, 11]])
         """  # noqa: E501
dataeval/detectors/linters/duplicates.py CHANGED
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-__all__ = ["DuplicatesOutput", "Duplicates"]
+__all__ = []
 
 from dataclasses import dataclass
 from typing import Generic, Iterable, Sequence, TypeVar, overload
@@ -19,7 +19,7 @@ TIndexCollection = TypeVar("TIndexCollection", DuplicateGroup, DatasetDuplicateG
 @dataclass(frozen=True)
 class DuplicatesOutput(Generic[TIndexCollection], Output):
     """
-    Output class for :class:`Duplicates` lint detector
+    Output class for :class:`Duplicates` lint detector.
 
     Attributes
     ----------
@@ -39,8 +39,8 @@ class DuplicatesOutput(Generic[TIndexCollection], Output):
 
 class Duplicates:
     """
-    Finds the duplicate images in a dataset using xxhash for exact :term:`duplicates<Duplicates>`
-    and pchash for near duplicates
+    Finds the duplicate images in a dataset using xxhash for exact \
+    :term:`duplicates<Duplicates>` and pchash for near duplicates.
 
     Attributes
     ----------
@@ -51,13 +51,6 @@ class Duplicates:
     ----------
     only_exact : bool, default False
         Only inspect the dataset for exact image matches
-
-    Example
-    -------
-    Initialize the Duplicates class:
-
-    >>> all_dupes = Duplicates()
-    >>> exact_dupes = Duplicates(only_exact=True)
     """
 
     def __init__(self, only_exact: bool = False) -> None:
@@ -73,7 +66,8 @@ class Duplicates:
         if not self.only_exact:
             near_dict: dict[int, list] = {}
             for i, value in enumerate(stats["pchash"]):
-                near_dict.setdefault(value, []).append(i)
+                if value:
+                    near_dict.setdefault(value, []).append(i)
             near = [sorted(v) for v in near_dict.values() if len(v) > 1 and not any(set(v).issubset(x) for x in exact)]
         else:
             near = []
@@ -98,7 +92,7 @@ class Duplicates:
 
         Parameters
         ----------
-        data : HashStatsOutput | Sequence[HashStatsOutput]
+        hashes : HashStatsOutput | Sequence[HashStatsOutput]
             The output(s) from a hashstats analysis
 
         Returns
@@ -112,6 +106,7 @@ class Duplicates:
 
         Example
         -------
+        >>> exact_dupes = Duplicates(only_exact=True)
        >>> exact_dupes.from_stats([hashes1, hashes2])
        DuplicatesOutput(exact=[{0: [3, 20]}, {0: [16], 1: [12]}], near=[])
        """
@@ -159,7 +154,8 @@ class Duplicates:
 
         Example
         -------
-        >>> all_dupes.evaluate(images)
+        >>> all_dupes = Duplicates()
+        >>> all_dupes.evaluate(duplicate_images)
        DuplicatesOutput(exact=[[3, 20], [16, 37]], near=[[3, 20, 22], [12, 18], [13, 36], [14, 31], [17, 27], [19, 38, 47]])
        """  # noqa: E501
        self.stats = hashstats(data)
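The from_stats() docstring above now constructs its own Duplicates instance, and empty pchash values are skipped when grouping near duplicates. A sketch of the two-dataset flow, assuming hashstats() is exported from dataeval.metrics.stats and ds1/ds2 are ArrayLike image sets (names are illustrative):

from dataeval.detectors.linters import Duplicates
from dataeval.metrics.stats import hashstats

hashes1 = hashstats(ds1)  # per-image xxhash/pchash values
hashes2 = hashstats(ds2)

exact_dupes = Duplicates(only_exact=True)
result = exact_dupes.from_stats([hashes1, hashes2])
print(result.exact)  # duplicate groups within and across the two datasets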
dataeval/detectors/linters/outliers.py CHANGED
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
-__all__ = ["OutliersOutput", "Outliers"]
+__all__ = []
 
+# import contextlib
 from dataclasses import dataclass
 from typing import Generic, Iterable, Literal, Sequence, TypeVar, Union, overload
 
@@ -12,19 +13,78 @@ from dataeval.detectors.linters.merged_stats import combine_stats, get_dataset_s
 from dataeval.metrics.stats.base import BOX_COUNT, SOURCE_INDEX
 from dataeval.metrics.stats.datasetstats import DatasetStatsOutput, datasetstats
 from dataeval.metrics.stats.dimensionstats import DimensionStatsOutput
+from dataeval.metrics.stats.labelstats import LabelStatsOutput
 from dataeval.metrics.stats.pixelstats import PixelStatsOutput
 from dataeval.metrics.stats.visualstats import VisualStatsOutput
 from dataeval.output import Output, set_metadata
 
+# with contextlib.suppress(ImportError):
+#     import pandas as pd
+
+
 IndexIssueMap = dict[int, dict[str, float]]
 OutlierStatsOutput = Union[DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput]
 TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, list[IndexIssueMap])
 
 
+def _reorganize_by_class_and_metric(result, lstats):
+    """Flip result from grouping by image to grouping by class and metric"""
+    metrics = {}
+    class_wise = {label: {} for label in lstats.image_indices_per_label}
+
+    # Group metrics and calculate class-wise counts
+    for img, group in result.items():
+        for extreme in group:
+            metrics.setdefault(extreme, []).append(img)
+            for label, images in lstats.image_indices_per_label.items():
+                if img in images:
+                    class_wise[label][extreme] = class_wise[label].get(extreme, 0) + 1
+
+    return metrics, class_wise
+
+
+def _create_table(metrics, class_wise):
+    """Create table for displaying the results"""
+    max_class_length = max(len(str(label)) for label in class_wise) + 2
+    max_total = max(len(metrics[group]) for group in metrics) + 2
+
+    table_header = " | ".join(
+        [f"{'Class':>{max_class_length}}"]
+        + [f"{group:^{max(5, len(str(group))) + 2}}" for group in sorted(metrics.keys())]
+        + [f"{'Total':<{max_total}}"]
+    )
+    table_rows = []
+
+    for class_cat, results in class_wise.items():
+        table_value = [f"{class_cat:>{max_class_length}}"]
+        total = 0
+        for group in sorted(metrics.keys()):
+            count = results.get(group, 0)
+            table_value.append(f"{count:^{max(5, len(str(group))) + 2}}")
+            total += count
+        table_value.append(f"{total:^{max_total}}")
+        table_rows.append(" | ".join(table_value))
+
+    table = [table_header] + table_rows
+    return table
+
+
+# def _create_pandas_dataframe(class_wise):
+#     """Create data for pandas dataframe"""
+#     data = []
+#     for label, metrics_dict in class_wise.items():
+#         row = {"Class": label}
+#         total = sum(metrics_dict.values())
+#         row.update(metrics_dict)  # Add metric counts
#         row["Total"] = total
#         data.append(row)
#     return data
+
+
 @dataclass(frozen=True)
 class OutliersOutput(Generic[TIndexIssueMap], Output):
     """
-    Output class for :class:`Outliers` lint detector
+    Output class for :class:`Outliers` lint detector.
 
     Attributes
     ----------
@@ -45,6 +105,39 @@ class OutliersOutput(Generic[TIndexIssueMap], Output):
         else:
             return sum(len(d) for d in self.issues)
 
+    def to_table(self, labelstats: LabelStatsOutput) -> str:
+        if isinstance(self.issues, dict):
+            metrics, classwise = _reorganize_by_class_and_metric(self.issues, labelstats)
+            listed_table = _create_table(metrics, classwise)
+            table = "\n".join(listed_table)
+        else:
+            outertable = []
+            for d in self.issues:
+                metrics, classwise = _reorganize_by_class_and_metric(d, labelstats)
+                listed_table = _create_table(metrics, classwise)
+                str_table = "\n".join(listed_table)
+                outertable.append(str_table)
+            table = "\n\n".join(outertable)
+        return table
+
+    # def to_dataframe(self, labelstats: LabelStatsOutput) -> pd.DataFrame:
+    #     import pandas as pd
+
+    #     if isinstance(self.issues, dict):
+    #         _, classwise = _reorganize_by_class_and_metric(self.issues, labelstats)
+    #         data = _create_pandas_dataframe(classwise)
+    #         df = pd.DataFrame(data)
+    #     else:
+    #         df_list = []
+    #         for i, d in enumerate(self.issues):
+    #             _, classwise = _reorganize_by_class_and_metric(d, labelstats)
+    #             data = _create_pandas_dataframe(classwise)
+    #             single_df = pd.DataFrame(data)
+    #             single_df["Dataset"] = i
+    #             df_list.append(single_df)
+    #         df = pd.concat(df_list)
+    #     return df
+
 
 def _get_outlier_mask(
     values: NDArray, method: Literal["zscore", "modzscore", "iqr"], threshold: float | None
@@ -71,7 +164,7 @@ def _get_outlier_mask(
 
 class Outliers:
     r"""
-    Calculates statistical Outliers of a dataset using various statistical tests applied to each image
+    Calculates statistical outliers of a dataset using various statistical tests applied to each image.
 
     Parameters
     ----------
@@ -164,7 +257,7 @@ class Outliers:
         self, stats: OutlierStatsOutput | DatasetStatsOutput | Sequence[OutlierStatsOutput]
     ) -> OutliersOutput[IndexIssueMap] | OutliersOutput[list[IndexIssueMap]]:
         """
-        Returns indices of Outliers with the issues identified for each
+        Returns indices of Outliers with the issues identified for each.
 
         Parameters
         ----------
@@ -188,6 +281,7 @@ class Outliers:
         -------
         Evaluate the dataset:
 
+        >>> outliers = Outliers(outlier_method="zscore", outlier_threshold=3.5)
        >>> results = outliers.from_stats([stats1, stats2])
        >>> len(results)
        2
@@ -248,7 +342,8 @@ class Outliers:
         -------
         Evaluate the dataset:
 
-        >>> results = outliers.evaluate(images)
+        >>> outliers = Outliers(outlier_method="zscore", outlier_threshold=3.5)
+        >>> results = outliers.evaluate(outlier_images)
        >>> list(results.issues)
        [10, 12]
        >>> results.issues[10]
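The new to_table() method above renders the per-class counts built by _reorganize_by_class_and_metric() and _create_table(). A sketch of pairing it with label statistics, assuming labelstats() is exported from dataeval.metrics.stats and that images/labels stand in for a real dataset:

from dataeval.detectors.linters import Outliers
from dataeval.metrics.stats import labelstats

outliers = Outliers(outlier_method="zscore", outlier_threshold=3.5)
results = outliers.evaluate(images)  # images: ArrayLike of CxHxW images

lstats = labelstats(labels)          # labels: per-image class labels
print(results.to_table(lstats))      # "Class | <metric> ... | Total" rows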
dataeval/detectors/ood/__init__.py CHANGED
@@ -1,15 +1,8 @@
 """
-Out-of-distribution (OOD)` detectors identify data that is different from the data used to train a particular model.
+Out-of-distribution (OOD) detectors identify data that is different from the data used to train a particular model.
 """
 
-from dataeval import _IS_TORCH_AVAILABLE
-from dataeval.detectors.ood.base import OODOutput, OODScoreOutput
+__all__ = ["OODOutput", "OODScoreOutput", "OOD_AE"]
 
-__all__ = ["OODOutput", "OODScoreOutput"]
-
-if _IS_TORCH_AVAILABLE:
-    from dataeval.detectors.ood.ae_torch import OOD_AE
-
-    __all__ += ["OOD_AE"]
-
-del _IS_TORCH_AVAILABLE
+from dataeval.detectors.ood.ae import OOD_AE
+from dataeval.detectors.ood.output import OODOutput, OODScoreOutput
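After this reorganization the public OOD surface is unchanged: OOD_AE moves from ae_torch.py to ae.py, and the output dataclasses move to output.py. A sketch of the import path; the fit/predict calls are shown only as hedged comments because their signatures are not part of this diff:

from dataeval.detectors.ood import OOD_AE, OODOutput, OODScoreOutput

# Construction and training are assumed to follow prior releases (arguments illustrative):
# detector = OOD_AE(model)                    # model: a torch autoencoder
# detector.fit(train_images)                  # fit on in-distribution data
# result: OODOutput = detector.predict(test_images)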