dataeval 0.81.0__py3-none-any.whl → 0.82.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. dataeval/__init__.py +1 -1
  2. dataeval/config.py +68 -11
  3. dataeval/detectors/drift/__init__.py +2 -2
  4. dataeval/detectors/drift/_base.py +8 -64
  5. dataeval/detectors/drift/_mmd.py +12 -38
  6. dataeval/detectors/drift/_torch.py +7 -7
  7. dataeval/detectors/drift/_uncertainty.py +6 -5
  8. dataeval/detectors/drift/updates.py +20 -3
  9. dataeval/detectors/linters/__init__.py +3 -2
  10. dataeval/detectors/linters/duplicates.py +14 -46
  11. dataeval/detectors/linters/outliers.py +25 -159
  12. dataeval/detectors/ood/__init__.py +1 -1
  13. dataeval/detectors/ood/ae.py +6 -5
  14. dataeval/detectors/ood/base.py +2 -2
  15. dataeval/detectors/ood/metadata_ood_mi.py +4 -6
  16. dataeval/detectors/ood/mixin.py +3 -4
  17. dataeval/detectors/ood/vae.py +3 -2
  18. dataeval/metadata/__init__.py +2 -1
  19. dataeval/metadata/_distance.py +134 -0
  20. dataeval/metadata/_ood.py +30 -49
  21. dataeval/metadata/_utils.py +44 -0
  22. dataeval/metrics/bias/__init__.py +5 -4
  23. dataeval/metrics/bias/_balance.py +17 -149
  24. dataeval/metrics/bias/_coverage.py +4 -106
  25. dataeval/metrics/bias/_diversity.py +12 -107
  26. dataeval/metrics/bias/_parity.py +7 -71
  27. dataeval/metrics/estimators/__init__.py +5 -4
  28. dataeval/metrics/estimators/_ber.py +2 -20
  29. dataeval/metrics/estimators/_clusterer.py +1 -61
  30. dataeval/metrics/estimators/_divergence.py +2 -19
  31. dataeval/metrics/estimators/_uap.py +2 -16
  32. dataeval/metrics/stats/__init__.py +15 -12
  33. dataeval/metrics/stats/_base.py +41 -128
  34. dataeval/metrics/stats/_boxratiostats.py +13 -13
  35. dataeval/metrics/stats/_dimensionstats.py +17 -58
  36. dataeval/metrics/stats/_hashstats.py +19 -35
  37. dataeval/metrics/stats/_imagestats.py +94 -0
  38. dataeval/metrics/stats/_labelstats.py +42 -121
  39. dataeval/metrics/stats/_pixelstats.py +19 -51
  40. dataeval/metrics/stats/_visualstats.py +19 -51
  41. dataeval/outputs/__init__.py +57 -0
  42. dataeval/outputs/_base.py +182 -0
  43. dataeval/outputs/_bias.py +381 -0
  44. dataeval/outputs/_drift.py +83 -0
  45. dataeval/outputs/_estimators.py +114 -0
  46. dataeval/outputs/_linters.py +186 -0
  47. dataeval/outputs/_metadata.py +54 -0
  48. dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
  49. dataeval/outputs/_stats.py +393 -0
  50. dataeval/outputs/_utils.py +44 -0
  51. dataeval/outputs/_workflows.py +364 -0
  52. dataeval/typing.py +187 -7
  53. dataeval/utils/_method.py +1 -5
  54. dataeval/utils/_plot.py +2 -2
  55. dataeval/utils/data/__init__.py +5 -1
  56. dataeval/utils/data/_dataset.py +217 -0
  57. dataeval/utils/data/_embeddings.py +12 -14
  58. dataeval/utils/data/_images.py +30 -27
  59. dataeval/utils/data/_metadata.py +28 -11
  60. dataeval/utils/data/_selection.py +25 -22
  61. dataeval/utils/data/_split.py +5 -29
  62. dataeval/utils/data/_targets.py +14 -2
  63. dataeval/utils/data/datasets/_base.py +5 -5
  64. dataeval/utils/data/datasets/_cifar10.py +1 -1
  65. dataeval/utils/data/datasets/_milco.py +1 -1
  66. dataeval/utils/data/datasets/_mnist.py +1 -1
  67. dataeval/utils/data/datasets/_ships.py +1 -1
  68. dataeval/utils/data/{_types.py → datasets/_types.py} +10 -16
  69. dataeval/utils/data/datasets/_voc.py +1 -1
  70. dataeval/utils/data/selections/_classfilter.py +4 -5
  71. dataeval/utils/data/selections/_indices.py +2 -2
  72. dataeval/utils/data/selections/_limit.py +2 -2
  73. dataeval/utils/data/selections/_reverse.py +2 -2
  74. dataeval/utils/data/selections/_shuffle.py +2 -2
  75. dataeval/utils/torch/_internal.py +5 -5
  76. dataeval/utils/torch/trainer.py +8 -8
  77. dataeval/workflows/__init__.py +2 -1
  78. dataeval/workflows/sufficiency.py +6 -342
  79. {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/METADATA +2 -2
  80. dataeval-0.82.1.dist-info/RECORD +105 -0
  81. dataeval/_output.py +0 -137
  82. dataeval/detectors/ood/metadata_ks_compare.py +0 -129
  83. dataeval/metrics/stats/_datasetstats.py +0 -198
  84. dataeval-0.81.0.dist-info/RECORD +0 -94
  85. {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/LICENSE.txt +0 -0
  86. {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/WHEEL +0 -0
dataeval/__init__.py CHANGED
@@ -8,7 +8,7 @@ shifts that impact performance of deployed models.
8
8
  from __future__ import annotations
9
9
 
10
10
  __all__ = ["config", "detectors", "log", "metrics", "typing", "utils", "workflows"]
11
- __version__ = "0.81.0"
11
+ __version__ = "0.82.1"
12
12
 
13
13
  import logging
14
14
 
dataeval/config.py CHANGED
@@ -4,36 +4,61 @@ Global configuration settings for DataEval.
4
4
 
5
5
  from __future__ import annotations
6
6
 
7
- __all__ = ["get_device", "set_device", "get_max_processes", "set_max_processes"]
7
+ __all__ = ["get_device", "set_device", "get_max_processes", "set_max_processes", "DeviceLike"]
8
8
 
9
+ import sys
10
+ from typing import Union
11
+
12
+ if sys.version_info >= (3, 10):
13
+ from typing import TypeAlias
14
+ else:
15
+ from typing_extensions import TypeAlias
16
+
17
+ import numpy as np
9
18
  import torch
10
- from torch import device
11
19
 
12
- _device: device | None = None
20
+ _device: torch.device | None = None
13
21
  _processes: int | None = None
22
+ _seed: int | None = None
23
+
24
+ DeviceLike: TypeAlias = Union[int, str, tuple[str, int], torch.device]
25
+ """
26
+ Type alias for types that are acceptable for specifying a torch.device.
27
+
28
+ See Also
29
+ --------
30
+ `torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
31
+ """
32
+
14
33
 
34
+ def _todevice(device: DeviceLike) -> torch.device:
35
+ return torch.device(*device) if isinstance(device, tuple) else torch.device(device)
15
36
 
16
- def set_device(device: str | device | int) -> None:
37
+
38
+ def set_device(device: DeviceLike) -> None:
17
39
  """
18
40
  Sets the default device to use when executing against a PyTorch backend.
19
41
 
20
42
  Parameters
21
43
  ----------
22
- device : str or int or `torch.device`
23
- The default device to use. See `torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
24
- documentation for more information.
44
+ device : DeviceLike
45
+ The default device to use. See documentation for more information.
46
+
47
+ See Also
48
+ --------
49
+ `torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
25
50
  """
26
51
  global _device
27
- _device = torch.device(device)
52
+ _device = _todevice(device)
28
53
 
29
54
 
30
- def get_device(override: str | device | int | None = None) -> torch.device:
55
+ def get_device(override: DeviceLike | None = None) -> torch.device:
31
56
  """
32
57
  Returns the PyTorch device to use.
33
58
 
34
59
  Parameters
35
60
  ----------
36
- override : str or int or `torch.device` or None, default None
61
+ override : DeviceLike or None, default None
37
62
  The user specified override if provided, otherwise returns the default device.
38
63
 
39
64
  Returns
@@ -44,7 +69,7 @@ def get_device(override: str | device | int | None = None) -> torch.device:
44
69
  global _device
45
70
  return torch.get_default_device() if _device is None else _device
46
71
  else:
47
- return torch.device(override)
72
+ return _todevice(override)
48
73
 
49
74
 
50
75
  def set_max_processes(processes: int | None) -> None:
@@ -75,3 +100,35 @@ def get_max_processes() -> int | None:
75
100
  """
76
101
  global _processes
77
102
  return _processes
103
+
104
+
105
+ def set_seed(seed: int | None, all_generators: bool = False) -> None:
106
+ """
107
+ Sets the seed for use by classes that allow for a random state or seed.
108
+
109
+ Parameters
110
+ ----------
111
+ seed : int or None
112
+ The seed to use.
113
+ all_generators : bool, default False
114
+ Whether to set the seed for all generators, including NumPy and PyTorch.
115
+ """
116
+ global _seed
117
+ _seed = seed
118
+
119
+ if all_generators:
120
+ np.random.seed(seed)
121
+ torch.manual_seed(seed)
122
+
123
+
124
+ def get_seed() -> int | None:
125
+ """
126
+ Returns the seed for random state or seed.
127
+
128
+ Returns
129
+ -------
130
+ int or None
131
+ The seed to use.
132
+ """
133
+ global _seed
134
+ return _seed
dataeval/detectors/drift/__init__.py CHANGED
@@ -14,9 +14,9 @@ __all__ = [
14
14
  ]
15
15
 
16
16
  from dataeval.detectors.drift import updates
17
- from dataeval.detectors.drift._base import DriftOutput
18
17
  from dataeval.detectors.drift._cvm import DriftCVM
19
18
  from dataeval.detectors.drift._ks import DriftKS
20
- from dataeval.detectors.drift._mmd import DriftMMD, DriftMMDOutput
19
+ from dataeval.detectors.drift._mmd import DriftMMD
21
20
  from dataeval.detectors.drift._torch import preprocess_drift
22
21
  from dataeval.detectors.drift._uncertainty import DriftUncertainty
22
+ from dataeval.outputs._drift import DriftMMDOutput, DriftOutput
dataeval/detectors/drift/_base.py CHANGED
@@ -11,84 +11,28 @@ from __future__ import annotations
11
11
  __all__ = []
12
12
 
13
13
  import math
14
- from abc import ABC, abstractmethod
15
- from dataclasses import dataclass
14
+ from abc import abstractmethod
16
15
  from functools import wraps
17
- from typing import Any, Callable, Literal, TypeVar
16
+ from typing import Any, Callable, Literal, Protocol, TypeVar, runtime_checkable
18
17
 
19
18
  import numpy as np
20
19
  from numpy.typing import NDArray
21
20
 
22
- from dataeval._output import Output, set_metadata
21
+ from dataeval.outputs import DriftOutput
22
+ from dataeval.outputs._base import set_metadata
23
23
  from dataeval.typing import Array, ArrayLike
24
24
  from dataeval.utils._array import as_numpy, to_numpy
25
25
 
26
26
  R = TypeVar("R")
27
27
 
28
28
 
29
- class UpdateStrategy(ABC):
29
+ @runtime_checkable
30
+ class UpdateStrategy(Protocol):
30
31
  """
31
- Updates reference dataset for drift detector
32
-
33
- Parameters
34
- ----------
35
- n : int
36
- Update with last n instances seen by the detector.
37
- """
38
-
39
- def __init__(self, n: int) -> None:
40
- self.n = n
41
-
42
- @abstractmethod
43
- def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]:
44
- """Abstract implementation of update strategy"""
45
-
46
-
47
- @dataclass(frozen=True)
48
- class DriftBaseOutput(Output):
49
- """
50
- Base output class for Drift Detector classes
51
- """
52
-
53
- drifted: bool
54
- threshold: float
55
- p_val: float
56
- distance: float
57
-
58
-
59
- @dataclass(frozen=True)
60
- class DriftOutput(DriftBaseOutput):
61
- """
62
- Output class for :class:`.DriftCVM`, :class:`.DriftKS`, and :class:`.DriftUncertainty` drift detectors.
63
-
64
- Attributes
65
- ----------
66
- drifted : bool
67
- :term:`Drift` prediction for the images
68
- threshold : float
69
- Threshold after multivariate correction if needed
70
- p_val : float
71
- Instance-level p-value
72
- distance : float
73
- Instance-level distance
74
- feature_drift : NDArray
75
- Feature-level array of images detected to have drifted
76
- feature_threshold : float
77
- Feature-level threshold to determine drift
78
- p_vals : NDArray
79
- Feature-level p-values
80
- distances : NDArray
81
- Feature-level distances
32
+ Protocol for reference dataset update strategy for drift detectors
82
33
  """
83
34
 
84
- # drifted: bool
85
- # threshold: float
86
- # p_val: float
87
- # distance: float
88
- feature_drift: NDArray[np.bool_]
89
- feature_threshold: float
90
- p_vals: NDArray[np.float32]
91
- distances: NDArray[np.float32]
35
+ def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]: ...
92
36
 
93
37
 
94
38
  def update_x_ref(fn: Callable[..., R]) -> Callable[..., R]:
dataeval/detectors/drift/_mmd.py CHANGED
@@ -10,44 +10,18 @@ from __future__ import annotations
10
10
 
11
11
  __all__ = []
12
12
 
13
- from dataclasses import dataclass
14
13
  from typing import Callable
15
14
 
16
15
  import torch
17
16
 
18
- from dataeval._output import set_metadata
19
- from dataeval.config import get_device
20
- from dataeval.detectors.drift._base import BaseDrift, DriftBaseOutput, UpdateStrategy, preprocess_x, update_x_ref
17
+ from dataeval.config import DeviceLike, get_device
18
+ from dataeval.detectors.drift._base import BaseDrift, UpdateStrategy, preprocess_x, update_x_ref
21
19
  from dataeval.detectors.drift._torch import GaussianRBF, mmd2_from_kernel_matrix
20
+ from dataeval.outputs import DriftMMDOutput
21
+ from dataeval.outputs._base import set_metadata
22
22
  from dataeval.typing import ArrayLike
23
23
 
24
24
 
25
- @dataclass(frozen=True)
26
- class DriftMMDOutput(DriftBaseOutput):
27
- """
28
- Output class for :class:`.DriftMMD` :term:`drift<Drift>` detector.
29
-
30
- Attributes
31
- ----------
32
- drifted : bool
33
- Drift prediction for the images
34
- threshold : float
35
- :term:`P-Value` used for significance of the permutation test
36
- p_val : float
37
- P-value obtained from the permutation test
38
- distance : float
39
- MMD^2 between the reference and test set
40
- distance_threshold : float
41
- MMD^2 threshold above which drift is flagged
42
- """
43
-
44
- # drifted: bool
45
- # threshold: float
46
- # p_val: float
47
- # distance: float
48
- distance_threshold: float
49
-
50
-
51
25
  class DriftMMD(BaseDrift):
52
26
  """
53
27
  :term:`Maximum Mean Discrepancy (MMD) Drift Detection` algorithm \
@@ -57,7 +31,7 @@ class DriftMMD(BaseDrift):
57
31
  ----------
58
32
  x_ref : ArrayLike
59
33
  Data used as reference distribution.
60
- p_val : float | None, default 0.05
34
+ p_val : float or None, default 0.05
61
35
  :term:`P-value` used for significance of the statistical test for each feature.
62
36
  If the FDR correction method is used, this corresponds to the acceptable
63
37
  q-value.
@@ -65,14 +39,14 @@ class DriftMMD(BaseDrift):
65
39
  Whether the given reference data ``x_ref`` has been preprocessed yet.
66
40
  If ``True``, only the test data ``x`` will be preprocessed at prediction time.
67
41
  If ``False``, the reference data will also be preprocessed.
68
- update_x_ref : UpdateStrategy | None, default None
42
+ update_x_ref : UpdateStrategy or None, default None
69
43
  Reference data can optionally be updated using an UpdateStrategy class. Update
70
44
  using the last n instances seen by the detector with LastSeenUpdateStrategy
71
45
  or via reservoir sampling with ReservoirSamplingUpdateStrategy.
72
- preprocess_fn : Callable | None, default None
46
+ preprocess_fn : Callable or None, default None
73
47
  Function to preprocess the data before computing the data drift metrics.
74
48
  Typically a :term:`dimensionality reduction<Dimensionality Reduction>` technique.
75
- sigma : ArrayLike | None, default None
49
+ sigma : ArrayLike or None, default None
76
50
  Optionally set the internal GaussianRBF kernel bandwidth. Can also pass multiple
77
51
  bandwidth values as an array. The kernel evaluation is then averaged over
78
52
  those bandwidths.
@@ -80,9 +54,9 @@ class DriftMMD(BaseDrift):
80
54
  Whether to already configure the kernel bandwidth from the reference data.
81
55
  n_permutations : int, default 100
82
56
  Number of permutations used in the permutation test.
83
- device : str | None, default None
84
- Device type used. The default None uses the GPU and falls back on CPU.
85
- Can be specified by passing either 'cuda', 'gpu' or 'cpu'.
57
+ device : DeviceLike or None, default None
58
+ The hardware device to use if specified, otherwise uses the DataEval
59
+ default or torch default.
86
60
 
87
61
  Example
88
62
  -------
@@ -110,7 +84,7 @@ class DriftMMD(BaseDrift):
110
84
  sigma: ArrayLike | None = None,
111
85
  configure_kernel_from_x_ref: bool = True,
112
86
  n_permutations: int = 100,
113
- device: str | torch.device | None = None,
87
+ device: DeviceLike | None = None,
114
88
  ) -> None:
115
89
  super().__init__(x_ref, p_val, x_ref_preprocessed, update_x_ref, preprocess_fn)
116
90
 
dataeval/detectors/drift/_torch.py CHANGED
@@ -17,7 +17,7 @@ import torch
17
17
  import torch.nn as nn
18
18
  from numpy.typing import NDArray
19
19
 
20
- from dataeval.config import get_device
20
+ from dataeval.config import DeviceLike, get_device
21
21
  from dataeval.utils.torch._internal import predict_batch
22
22
 
23
23
 
@@ -59,7 +59,7 @@ def mmd2_from_kernel_matrix(
59
59
  def preprocess_drift(
60
60
  x: NDArray[Any],
61
61
  model: nn.Module,
62
- device: str | torch.device | None = None,
62
+ device: DeviceLike | None = None,
63
63
  preprocess_batch_fn: Callable | None = None,
64
64
  batch_size: int = int(1e10),
65
65
  dtype: type[np.generic] | torch.dtype = np.float32,
@@ -73,15 +73,15 @@ def preprocess_drift(
73
73
  Batch of instances.
74
74
  model : nn.Module
75
75
  Model used for preprocessing.
76
- device : torch.device | None, default None
77
- Device type used. The default None tries to use the GPU and falls back on CPU.
78
- Can be specified by passing either torch.device('cuda') or torch.device('cpu').
79
- preprocess_batch_fn : Callable | None, default None
76
+ device : DeviceLike or None, default None
77
+ The hardware device to use if specified, otherwise uses the DataEval
78
+ default or torch default.
79
+ preprocess_batch_fn : Callable or None, default None
80
80
  Optional batch preprocessing function. For example to convert a list of objects
81
81
  to a batch which can be processed by the PyTorch model.
82
82
  batch_size : int, default 1e10
83
83
  Batch size used during prediction.
84
- dtype : np.dtype | torch.dtype, default np.float32
84
+ dtype : np.dtype or torch.dtype, default np.float32
85
85
  Model output type, either a :term:`NumPy` or torch dtype, e.g. np.float32 or torch.float32.
86
86
 
87
87
  Returns
dataeval/detectors/drift/_uncertainty.py CHANGED
@@ -19,9 +19,10 @@ from scipy.special import softmax
19
19
  from scipy.stats import entropy
20
20
 
21
21
  from dataeval.config import get_device
22
- from dataeval.detectors.drift._base import DriftOutput, UpdateStrategy
22
+ from dataeval.detectors.drift._base import UpdateStrategy
23
23
  from dataeval.detectors.drift._ks import DriftKS
24
24
  from dataeval.detectors.drift._torch import preprocess_drift
25
+ from dataeval.outputs import DriftOutput
25
26
  from dataeval.typing import ArrayLike
26
27
 
27
28
 
@@ -84,20 +85,20 @@ class DriftUncertainty:
84
85
  Whether the given reference data ``x_ref`` has been preprocessed yet.
85
86
  If ``True``, only the test data ``x`` will be preprocessed at prediction time.
86
87
  If ``False``, the reference data will also be preprocessed.
87
- update_x_ref : UpdateStrategy | None, default None
88
+ update_x_ref : UpdateStrategy or None, default None
88
89
  Reference data can optionally be updated using an UpdateStrategy class. Update
89
90
  using the last n instances seen by the detector with LastSeenUpdateStrategy
90
91
  or via reservoir sampling with ReservoirSamplingUpdateStrategy.
91
- preds_type : "probs" | "logits", default "probs"
92
+ preds_type : "probs" or "logits", default "probs"
92
93
  Type of prediction output by the model. Options are 'probs' (in [0,1]) or
93
94
  'logits' (in [-inf,inf]).
94
95
  batch_size : int, default 32
95
96
  Batch size used to evaluate model. Only relevant when backend has been
96
97
  specified for batch prediction.
97
- preprocess_batch_fn : Callable | None, default None
98
+ preprocess_batch_fn : Callable or None, default None
98
99
  Optional batch preprocessing function. For example to convert a list of
99
100
  objects to a batch which can be processed by the model.
100
- device : str | None, default None
101
+ device : DeviceLike or None, default None
101
102
  Device type used. The default None tries to use the GPU and falls back on
102
103
  CPU if needed. Can be specified by passing either 'cuda' or 'cpu'.
103
104
 
dataeval/detectors/drift/updates.py CHANGED
@@ -7,15 +7,32 @@ from __future__ import annotations
7
7
 
8
8
  __all__ = ["LastSeenUpdate", "ReservoirSamplingUpdate"]
9
9
 
10
+ from abc import ABC, abstractmethod
10
11
  from typing import Any
11
12
 
12
13
  import numpy as np
13
14
  from numpy.typing import NDArray
14
15
 
15
- from dataeval.detectors.drift._base import UpdateStrategy
16
+
17
+ class BaseUpdateStrategy(ABC):
18
+ """
19
+ Updates reference dataset for drift detector
20
+
21
+ Parameters
22
+ ----------
23
+ n : int
24
+ Update with last n instances seen by the detector.
25
+ """
26
+
27
+ def __init__(self, n: int) -> None:
28
+ self.n = n
29
+
30
+ @abstractmethod
31
+ def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]:
32
+ """Abstract implementation of update strategy"""
16
33
 
17
34
 
18
- class LastSeenUpdate(UpdateStrategy):
35
+ class LastSeenUpdate(BaseUpdateStrategy):
19
36
  """
20
37
  Updates reference dataset for :term:`drift<Drift>` detector using last seen method.
21
38
 
@@ -30,7 +47,7 @@ class LastSeenUpdate(UpdateStrategy):
30
47
  return x_updated[-self.n :]
31
48
 
32
49
 
33
- class ReservoirSamplingUpdate(UpdateStrategy):
50
+ class ReservoirSamplingUpdate(BaseUpdateStrategy):
34
51
  """
35
52
  Updates reference dataset for :term:`drift<Drift>` detector using reservoir sampling method.
36
53
 
dataeval/detectors/linters/__init__.py CHANGED
@@ -9,5 +9,6 @@ __all__ = [
9
9
  "OutliersOutput",
10
10
  ]
11
11
 
12
- from dataeval.detectors.linters.duplicates import Duplicates, DuplicatesOutput
13
- from dataeval.detectors.linters.outliers import Outliers, OutliersOutput
12
+ from dataeval.detectors.linters.duplicates import Duplicates
13
+ from dataeval.detectors.linters.outliers import Outliers
14
+ from dataeval.outputs._linters import DuplicatesOutput, OutliersOutput
dataeval/detectors/linters/duplicates.py CHANGED
@@ -2,40 +2,15 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
- from dataclasses import dataclass
6
- from typing import Any, Generic, Iterable, Sequence, TypeVar, overload
5
+ from typing import Any, Sequence, overload
7
6
 
8
- from torch.utils.data import Dataset
9
-
10
- from dataeval._output import Output, set_metadata
7
+ from dataeval.metrics.stats import hashstats
11
8
  from dataeval.metrics.stats._base import combine_stats, get_dataset_step_from_idx
12
- from dataeval.metrics.stats._hashstats import HashStatsOutput, hashstats
13
- from dataeval.typing import ArrayLike
14
-
15
- DuplicateGroup = list[int]
16
- DatasetDuplicateGroupMap = dict[int, DuplicateGroup]
17
- TIndexCollection = TypeVar("TIndexCollection", DuplicateGroup, DatasetDuplicateGroupMap)
18
-
19
-
20
- @dataclass(frozen=True)
21
- class DuplicatesOutput(Generic[TIndexCollection], Output):
22
- """
23
- Output class for :class:`.Duplicates` lint detector.
24
-
25
- Attributes
26
- ----------
27
- exact : list[list[int] | dict[int, list[int]]]
28
- Indices of images that are exact matches
29
- near: list[list[int] | dict[int, list[int]]]
30
- Indices of images that are near matches
31
-
32
- - For a single dataset, indices are returned as a list of index groups.
33
- - For multiple datasets, indices are returned as dictionaries where the key is the
34
- index of the dataset, and the value is the list index groups from that dataset.
35
- """
36
-
37
- exact: list[TIndexCollection]
38
- near: list[TIndexCollection]
9
+ from dataeval.outputs import DuplicatesOutput, HashStatsOutput
10
+ from dataeval.outputs._base import set_metadata
11
+ from dataeval.outputs._linters import DatasetDuplicateGroupMap, DuplicateGroup
12
+ from dataeval.typing import Array, Dataset
13
+ from dataeval.utils.data._images import Images
39
14
 
40
15
 
41
16
  class Duplicates:
@@ -113,13 +88,13 @@ class Duplicates:
113
88
  """
114
89
 
115
90
  if isinstance(hashes, HashStatsOutput):
116
- return DuplicatesOutput(**self._get_duplicates(hashes.dict()))
91
+ return DuplicatesOutput(**self._get_duplicates(hashes.data()))
117
92
 
118
93
  if not isinstance(hashes, Sequence):
119
94
  raise TypeError("Invalid stats output type; only use output from hashstats.")
120
95
 
121
96
  combined, dataset_steps = combine_stats(hashes)
122
- duplicates = self._get_duplicates(combined.dict())
97
+ duplicates = self._get_duplicates(combined.data())
123
98
 
124
99
  # split up results from combined dataset into individual dataset buckets
125
100
  for dup_type, dup_list in duplicates.items():
@@ -134,22 +109,15 @@ class Duplicates:
134
109
 
135
110
  return DuplicatesOutput(**duplicates)
136
111
 
137
- @overload
138
- def evaluate(self, data: Iterable[ArrayLike]) -> DuplicatesOutput[DuplicateGroup]: ...
139
- @overload
140
- def evaluate(self, data: Dataset[tuple[ArrayLike, Any, dict[str, Any]]]) -> DuplicatesOutput[DuplicateGroup]: ...
141
-
142
112
  @set_metadata(state=["only_exact"])
143
- def evaluate(
144
- self, data: Iterable[ArrayLike] | Dataset[tuple[ArrayLike, Any, dict[str, Any]]]
145
- ) -> DuplicatesOutput[DuplicateGroup]:
113
+ def evaluate(self, data: Dataset[Array] | Dataset[tuple[Array, Any, Any]]) -> DuplicatesOutput[DuplicateGroup]:
146
114
  """
147
115
  Returns duplicate image indices for both exact matches and near matches
148
116
 
149
117
  Parameters
150
118
  ----------
151
- data : Iterable[ArrayLike], shape - (N, C, H, W) | StatsOutput | Sequence[StatsOutput]
152
- A dataset of images in an ArrayLike format or the output(s) from a hashstats analysis
119
+ data : Iterable[Array], shape - (N, C, H, W) | Dataset[tuple[Array, Any, Any]]
120
+ A dataset of images in an Array format or the output(s) from a hashstats analysis
153
121
 
154
122
  Returns
155
123
  -------
@@ -166,7 +134,7 @@ class Duplicates:
166
134
  >>> all_dupes.evaluate(duplicate_images)
167
135
  DuplicatesOutput(exact=[[3, 20], [16, 37]], near=[[3, 20, 22], [12, 18], [13, 36], [14, 31], [17, 27], [19, 38, 47]])
168
136
  """ # noqa: E501
169
- images = (d[0] for d in data) if isinstance(data, Dataset) else data
137
+ images = Images(data) if isinstance(data, Dataset) else data
170
138
  self.stats = hashstats(images)
171
- duplicates = self._get_duplicates(self.stats.dict())
139
+ duplicates = self._get_duplicates(self.stats.data())
172
140
  return DuplicatesOutput(**duplicates)