dataeval 0.73.1__tar.gz → 0.74.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85) hide show
  1. {dataeval-0.73.1 → dataeval-0.74.1}/PKG-INFO +3 -9
  2. {dataeval-0.73.1 → dataeval-0.74.1}/pyproject.toml +6 -15
  3. dataeval-0.74.1/src/dataeval/__init__.py +17 -0
  4. dataeval-0.74.1/src/dataeval/detectors/__init__.py +7 -0
  5. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/detectors/drift/base.py +3 -3
  6. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/detectors/drift/mmd.py +1 -1
  7. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/detectors/drift/torch.py +1 -101
  8. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/detectors/linters/clusterer.py +3 -3
  9. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/detectors/linters/duplicates.py +4 -4
  10. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/detectors/linters/outliers.py +4 -4
  11. dataeval-0.74.1/src/dataeval/detectors/ood/__init__.py +15 -0
  12. dataeval-0.73.1/src/dataeval/detectors/ood/ae.py → dataeval-0.74.1/src/dataeval/detectors/ood/ae_torch.py +22 -27
  13. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/detectors/ood/base.py +63 -113
  14. dataeval-0.74.1/src/dataeval/detectors/ood/base_torch.py +109 -0
  15. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/detectors/ood/metadata_ks_compare.py +52 -14
  16. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/interop.py +1 -1
  17. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/metrics/bias/__init__.py +3 -0
  18. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/metrics/bias/balance.py +73 -70
  19. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/metrics/bias/coverage.py +4 -4
  20. dataeval-0.74.1/src/dataeval/metrics/bias/diversity.py +238 -0
  21. dataeval-0.74.1/src/dataeval/metrics/bias/metadata_preprocessing.py +285 -0
  22. dataeval-0.74.1/src/dataeval/metrics/bias/metadata_utils.py +229 -0
  23. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/metrics/bias/parity.py +51 -161
  24. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/metrics/estimators/ber.py +3 -3
  25. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/metrics/estimators/divergence.py +3 -3
  26. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/metrics/estimators/uap.py +3 -3
  27. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/metrics/stats/base.py +2 -2
  28. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/metrics/stats/boxratiostats.py +1 -1
  29. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/metrics/stats/datasetstats.py +6 -6
  30. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/metrics/stats/dimensionstats.py +1 -1
  31. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/metrics/stats/hashstats.py +1 -1
  32. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/metrics/stats/labelstats.py +3 -3
  33. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/metrics/stats/pixelstats.py +1 -1
  34. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/metrics/stats/visualstats.py +1 -1
  35. dataeval-0.74.1/src/dataeval/output.py +114 -0
  36. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/utils/__init__.py +1 -7
  37. dataeval-0.74.1/src/dataeval/utils/gmm.py +26 -0
  38. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/utils/metadata.py +29 -9
  39. dataeval-0.74.1/src/dataeval/utils/torch/gmm.py +98 -0
  40. dataeval-0.74.1/src/dataeval/utils/torch/models.py +330 -0
  41. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/utils/torch/trainer.py +84 -5
  42. dataeval-0.74.1/src/dataeval/utils/torch/utils.py +169 -0
  43. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/workflows/sufficiency.py +4 -4
  44. dataeval-0.73.1/src/dataeval/__init__.py +0 -23
  45. dataeval-0.73.1/src/dataeval/detectors/__init__.py +0 -15
  46. dataeval-0.73.1/src/dataeval/detectors/ood/__init__.py +0 -15
  47. dataeval-0.73.1/src/dataeval/detectors/ood/aegmm.py +0 -66
  48. dataeval-0.73.1/src/dataeval/detectors/ood/llr.py +0 -302
  49. dataeval-0.73.1/src/dataeval/detectors/ood/vae.py +0 -97
  50. dataeval-0.73.1/src/dataeval/detectors/ood/vaegmm.py +0 -75
  51. dataeval-0.73.1/src/dataeval/metrics/bias/diversity.py +0 -307
  52. dataeval-0.73.1/src/dataeval/metrics/bias/metadata.py +0 -440
  53. dataeval-0.73.1/src/dataeval/output.py +0 -90
  54. dataeval-0.73.1/src/dataeval/utils/lazy.py +0 -26
  55. dataeval-0.73.1/src/dataeval/utils/tensorflow/__init__.py +0 -19
  56. dataeval-0.73.1/src/dataeval/utils/tensorflow/_internal/gmm.py +0 -123
  57. dataeval-0.73.1/src/dataeval/utils/tensorflow/_internal/loss.py +0 -121
  58. dataeval-0.73.1/src/dataeval/utils/tensorflow/_internal/models.py +0 -1394
  59. dataeval-0.73.1/src/dataeval/utils/tensorflow/_internal/trainer.py +0 -114
  60. dataeval-0.73.1/src/dataeval/utils/tensorflow/_internal/utils.py +0 -256
  61. dataeval-0.73.1/src/dataeval/utils/tensorflow/loss/__init__.py +0 -11
  62. dataeval-0.73.1/src/dataeval/utils/torch/models.py +0 -138
  63. dataeval-0.73.1/src/dataeval/utils/torch/utils.py +0 -63
  64. {dataeval-0.73.1 → dataeval-0.74.1}/LICENSE.txt +0 -0
  65. {dataeval-0.73.1 → dataeval-0.74.1}/README.md +0 -0
  66. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/detectors/drift/__init__.py +0 -0
  67. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/detectors/drift/cvm.py +0 -0
  68. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/detectors/drift/ks.py +0 -0
  69. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/detectors/drift/uncertainty.py +0 -0
  70. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/detectors/drift/updates.py +0 -0
  71. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/detectors/linters/__init__.py +0 -0
  72. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/detectors/linters/merged_stats.py +0 -0
  73. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/detectors/ood/metadata_least_likely.py +0 -0
  74. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/detectors/ood/metadata_ood_mi.py +0 -0
  75. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/metrics/__init__.py +0 -0
  76. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/metrics/estimators/__init__.py +0 -0
  77. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/metrics/stats/__init__.py +0 -0
  78. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/py.typed +0 -0
  79. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/utils/image.py +0 -0
  80. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/utils/shared.py +0 -0
  81. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/utils/split_dataset.py +0 -0
  82. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/utils/torch/__init__.py +0 -0
  83. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/utils/torch/blocks.py +0 -0
  84. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/utils/torch/datasets.py +0 -0
  85. {dataeval-0.73.1 → dataeval-0.74.1}/src/dataeval/workflows/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: dataeval
3
- Version: 0.73.1
3
+ Version: 0.74.1
4
4
  Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
5
5
  Home-page: https://dataeval.ai/
6
6
  License: MIT
@@ -21,18 +21,12 @@ Classifier: Programming Language :: Python :: 3.12
21
21
  Classifier: Programming Language :: Python :: 3 :: Only
22
22
  Classifier: Topic :: Scientific/Engineering
23
23
  Provides-Extra: all
24
- Provides-Extra: tensorflow
25
24
  Provides-Extra: torch
26
- Requires-Dist: hdbscan (>=0.8.36)
27
- Requires-Dist: markupsafe (<3.0.2) ; extra == "tensorflow" or extra == "all"
28
- Requires-Dist: matplotlib ; extra == "torch" or extra == "all"
29
- Requires-Dist: numpy (>1.24.3)
25
+ Requires-Dist: matplotlib ; extra == "all"
26
+ Requires-Dist: numpy (>=1.24.3)
30
27
  Requires-Dist: pillow (>=10.3.0)
31
28
  Requires-Dist: scikit-learn (>=1.5.0)
32
29
  Requires-Dist: scipy (>=1.10)
33
- Requires-Dist: tensorflow (>=2.16,<2.18) ; extra == "tensorflow" or extra == "all"
34
- Requires-Dist: tensorflow_probability (>=0.24,<0.25) ; extra == "tensorflow" or extra == "all"
35
- Requires-Dist: tf-keras (>=2.16,<2.18) ; extra == "tensorflow" or extra == "all"
36
30
  Requires-Dist: torch (>=2.2.0) ; extra == "torch" or extra == "all"
37
31
  Requires-Dist: torchvision (>=0.17.0) ; extra == "torch" or extra == "all"
38
32
  Requires-Dist: tqdm
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "dataeval"
3
- version = "0.73.1" # dynamic
3
+ version = "0.74.1" # dynamic
4
4
  description = "DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks"
5
5
  license = "MIT"
6
6
  readme = "README.md"
@@ -42,8 +42,7 @@ packages = [
42
42
  [tool.poetry.dependencies]
43
43
  # required
44
44
  python = ">=3.9,<3.13"
45
- hdbscan = {version = ">=0.8.36"}
46
- numpy = {version = ">1.24.3"}
45
+ numpy = {version = ">=1.24.3"}
47
46
  pillow = {version = ">=10.3.0"}
48
47
  scipy = {version = ">=1.10"}
49
48
  scikit-learn = {version = ">=1.5.0"}
@@ -53,17 +52,12 @@ xxhash = {version = ">=3.3"}
53
52
 
54
53
  # optional
55
54
  matplotlib = {version = "*", optional = true}
56
- markupsafe = {version = "<3.0.2", optional = true}
57
- tensorflow = {version = ">=2.16,<2.18", optional = true}
58
- tensorflow_probability = {version = ">=0.24,<0.25", optional = true}
59
- tf-keras = {version = ">=2.16,<2.18", optional = true}
60
55
  torch = {version = ">=2.2.0", source = "pytorch", optional = true}
61
56
  torchvision = {version = ">=0.17.0", source = "pytorch", optional = true}
62
57
 
63
58
  [tool.poetry.extras]
64
- tensorflow = ["markupsafe", "tensorflow", "tensorflow_probability", "tf-keras"]
65
- torch = ["torch", "torchvision", "matplotlib"]
66
- all = ["matplotlib", "markupsafe", "tensorflow", "tensorflow_probability", "tf-keras", "torch", "torchvision"]
59
+ torch = ["torch", "torchvision"]
60
+ all = ["matplotlib", "torch", "torchvision"]
67
61
 
68
62
  [tool.poetry.group.dev]
69
63
  optional = true
@@ -89,6 +83,7 @@ pyright = {version = "*", extras = ["nodejs"]}
89
83
  maite = {version = "*"}
90
84
  pandas = {version = "*"}
91
85
  seaborn = {version = "*"}
86
+ numpy = {version = ">=2.0.2"}
92
87
  # docs
93
88
  certifi = {version = ">=2024.07.04"}
94
89
  enum_tools = {version = ">=0.12.0", extras = ["sphinx"]}
@@ -105,7 +100,7 @@ markupsafe = {version = "<3.0.2", optional = true}
105
100
 
106
101
  [[tool.poetry.source]]
107
102
  name = "pytorch"
108
- url = "https://download.pytorch.org/whl/cu124"
103
+ url = "https://download.pytorch.org/whl/cu118"
109
104
  priority = "explicit"
110
105
 
111
106
  [tool.poetry-dynamic-versioning]
@@ -121,7 +116,6 @@ files = ["src/dataeval/__init__.py"]
121
116
  name = "dataeval"
122
117
 
123
118
  [tool.poetry2conda.dependencies]
124
- tensorflow_probability = { name = "tensorflow-probability" }
125
119
  torch = { name = "pytorch" }
126
120
  xxhash = { name = "python-xxhash" }
127
121
 
@@ -142,8 +136,6 @@ parallel = true
142
136
  [tool.coverage.report]
143
137
  exclude_also = [
144
138
  "raise NotImplementedError",
145
- "if TYPE_CHECKING:",
146
- "if _IS_TENSORFLOW_AVAILABLE",
147
139
  "if _IS_TORCH_AVAILABLE",
148
140
  "if _IS_TORCHVISION_AVAILABLE",
149
141
  ]
@@ -151,7 +143,6 @@ include = ["*/src/dataeval/*"]
151
143
  omit = [
152
144
  "*/torch/blocks.py",
153
145
  "*/torch/utils.py",
154
- "*/tensorflow/_internal/models.py",
155
146
  ]
156
147
  fail_under = 90
157
148
 
@@ -0,0 +1,17 @@
1
+ __version__ = "0.74.1"
2
+
3
+ from importlib.util import find_spec
4
+
5
+ _IS_TORCH_AVAILABLE = find_spec("torch") is not None
6
+ _IS_TORCHVISION_AVAILABLE = find_spec("torchvision") is not None
7
+
8
+ del find_spec
9
+
10
+ from dataeval import detectors, metrics # noqa: E402
11
+
12
+ __all__ = ["detectors", "metrics"]
13
+
14
+ if _IS_TORCH_AVAILABLE:
15
+ from dataeval import utils, workflows
16
+
17
+ __all__ += ["utils", "workflows"]
@@ -0,0 +1,7 @@
1
+ """
2
+ Detectors can determine if a dataset or individual images in a dataset are indicative of a specific issue.
3
+ """
4
+
5
+ from dataeval.detectors import drift, linters, ood
6
+
7
+ __all__ = ["drift", "linters", "ood"]
@@ -19,7 +19,7 @@ import numpy as np
19
19
  from numpy.typing import ArrayLike, NDArray
20
20
 
21
21
  from dataeval.interop import as_numpy
22
- from dataeval.output import OutputMetadata, set_metadata
22
+ from dataeval.output import Output, set_metadata
23
23
 
24
24
  R = TypeVar("R")
25
25
 
@@ -43,7 +43,7 @@ class UpdateStrategy(ABC):
43
43
 
44
44
 
45
45
  @dataclass(frozen=True)
46
- class DriftBaseOutput(OutputMetadata):
46
+ class DriftBaseOutput(Output):
47
47
  """
48
48
  Base output class for Drift detector classes
49
49
 
@@ -387,7 +387,7 @@ class BaseDriftUnivariate(BaseDrift):
387
387
  else:
388
388
  raise ValueError("`correction` needs to be either `bonferroni` or `fdr`.")
389
389
 
390
- @set_metadata()
390
+ @set_metadata
391
391
  @preprocess_x
392
392
  @update_x_ref
393
393
  def predict(
@@ -161,7 +161,7 @@ class DriftMMD(BaseDrift):
161
161
  distance_threshold = torch.sort(mmd2_permuted, descending=True).values[idx_threshold]
162
162
  return p_val.numpy().item(), mmd2.numpy().item(), distance_threshold.numpy().item()
163
163
 
164
- @set_metadata()
164
+ @set_metadata
165
165
  @preprocess_x
166
166
  @update_x_ref
167
167
  def predict(self, x: ArrayLike) -> DriftMMDOutput:
@@ -10,7 +10,6 @@ from __future__ import annotations
10
10
 
11
11
  __all__ = []
12
12
 
13
- from functools import partial
14
13
  from typing import Any, Callable
15
14
 
16
15
  import numpy as np
@@ -18,30 +17,7 @@ import torch
18
17
  import torch.nn as nn
19
18
  from numpy.typing import NDArray
20
19
 
21
-
22
- def get_device(device: str | torch.device | None = None) -> torch.device:
23
- """
24
- Instantiates a PyTorch device object.
25
-
26
- Parameters
27
- ----------
28
- device : str | torch.device | None, default None
29
- Either ``None``, a str ('gpu' or 'cpu') indicating the device to choose, or an
30
- already instantiated device object. If ``None``, the GPU is selected if it is
31
- detected, otherwise the CPU is used as a fallback.
32
-
33
- Returns
34
- -------
35
- The instantiated device object.
36
- """
37
- if isinstance(device, torch.device): # Already a torch device
38
- return device
39
- else: # Instantiate device
40
- if device is None or device.lower() in ["gpu", "cuda"]:
41
- torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
- else:
43
- torch_device = torch.device("cpu")
44
- return torch_device
20
+ from dataeval.utils.torch.utils import get_device, predict_batch
45
21
 
46
22
 
47
23
  def _mmd2_from_kernel_matrix(
@@ -79,82 +55,6 @@ def _mmd2_from_kernel_matrix(
79
55
  return mmd2
80
56
 
81
57
 
82
- def predict_batch(
83
- x: NDArray[Any] | torch.Tensor,
84
- model: Callable | nn.Module | nn.Sequential,
85
- device: torch.device | None = None,
86
- batch_size: int = int(1e10),
87
- preprocess_fn: Callable | None = None,
88
- dtype: type[np.generic] | torch.dtype = np.float32,
89
- ) -> NDArray[Any] | torch.Tensor | tuple[Any, ...]:
90
- """
91
- Make batch predictions on a model.
92
-
93
- Parameters
94
- ----------
95
- x : np.ndarray | torch.Tensor
96
- Batch of instances.
97
- model : Callable | nn.Module | nn.Sequential
98
- PyTorch model.
99
- device : torch.device | None, default None
100
- Device type used. The default None tries to use the GPU and falls back on CPU.
101
- Can be specified by passing either torch.device('cuda') or torch.device('cpu').
102
- batch_size : int, default 1e10
103
- Batch size used during prediction.
104
- preprocess_fn : Callable | None, default None
105
- Optional preprocessing function for each batch.
106
- dtype : np.dtype | torch.dtype, default np.float32
107
- Model output type, either a :term:`NumPy` or torch dtype, e.g. np.float32 or torch.float32.
108
-
109
- Returns
110
- -------
111
- NDArray | torch.Tensor | tuple
112
- Numpy array, torch tensor or tuples of those with model outputs.
113
- """
114
- device = get_device(device)
115
- if isinstance(x, np.ndarray):
116
- x = torch.from_numpy(x)
117
- n = len(x)
118
- n_minibatch = int(np.ceil(n / batch_size))
119
- return_np = not isinstance(dtype, torch.dtype)
120
- preds = []
121
- with torch.no_grad():
122
- for i in range(n_minibatch):
123
- istart, istop = i * batch_size, min((i + 1) * batch_size, n)
124
- x_batch = x[istart:istop]
125
- if isinstance(preprocess_fn, Callable):
126
- x_batch = preprocess_fn(x_batch)
127
- preds_tmp = model(x_batch.to(device))
128
- if isinstance(preds_tmp, (list, tuple)):
129
- if len(preds) == 0: # init tuple with lists to store predictions
130
- preds = tuple([] for _ in range(len(preds_tmp)))
131
- for j, p in enumerate(preds_tmp):
132
- if isinstance(p, torch.Tensor):
133
- p = p.cpu()
134
- preds[j].append(p if not return_np or isinstance(p, np.ndarray) else p.numpy())
135
- elif isinstance(preds_tmp, (np.ndarray, torch.Tensor)):
136
- if isinstance(preds_tmp, torch.Tensor):
137
- preds_tmp = preds_tmp.cpu()
138
- if isinstance(preds, tuple):
139
- preds = list(preds)
140
- preds.append(
141
- preds_tmp
142
- if not return_np or isinstance(preds_tmp, np.ndarray) # type: ignore
143
- else preds_tmp.numpy()
144
- )
145
- else:
146
- raise TypeError(
147
- f"Model output type {type(preds_tmp)} not supported. The model \
148
- output type needs to be one of list, tuple, NDArray or \
149
- torch.Tensor."
150
- )
151
- concat = partial(np.concatenate, axis=0) if return_np else partial(torch.cat, dim=0)
152
- out: tuple | np.ndarray | torch.Tensor = (
153
- tuple(concat(p) for p in preds) if isinstance(preds, tuple) else concat(preds) # type: ignore
154
- )
155
- return out
156
-
157
-
158
58
  def preprocess_drift(
159
59
  x: NDArray[Any],
160
60
  model: nn.Module,
@@ -11,12 +11,12 @@ from scipy.cluster.hierarchy import linkage
11
11
  from scipy.spatial.distance import pdist, squareform
12
12
 
13
13
  from dataeval.interop import to_numpy
14
- from dataeval.output import OutputMetadata, set_metadata
14
+ from dataeval.output import Output, set_metadata
15
15
  from dataeval.utils.shared import flatten
16
16
 
17
17
 
18
18
  @dataclass(frozen=True)
19
- class ClustererOutput(OutputMetadata):
19
+ class ClustererOutput(Output):
20
20
  """
21
21
  Output class for :class:`Clusterer` lint detector
22
22
 
@@ -495,7 +495,7 @@ class Clusterer:
495
495
  return exact_dupes, near_dupes
496
496
 
497
497
  # TODO: Move data input to evaluate from class
498
- @set_metadata(["data"])
498
+ @set_metadata(state=["data"])
499
499
  def evaluate(self) -> ClustererOutput:
500
500
  """Finds and flags indices of the data for Outliers and :term:`duplicates<Duplicates>`
501
501
 
@@ -9,7 +9,7 @@ from numpy.typing import ArrayLike
9
9
 
10
10
  from dataeval.detectors.linters.merged_stats import combine_stats, get_dataset_step_from_idx
11
11
  from dataeval.metrics.stats.hashstats import HashStatsOutput, hashstats
12
- from dataeval.output import OutputMetadata, set_metadata
12
+ from dataeval.output import Output, set_metadata
13
13
 
14
14
  DuplicateGroup = list[int]
15
15
  DatasetDuplicateGroupMap = dict[int, DuplicateGroup]
@@ -17,7 +17,7 @@ TIndexCollection = TypeVar("TIndexCollection", DuplicateGroup, DatasetDuplicateG
17
17
 
18
18
 
19
19
  @dataclass(frozen=True)
20
- class DuplicatesOutput(Generic[TIndexCollection], OutputMetadata):
20
+ class DuplicatesOutput(Generic[TIndexCollection], Output):
21
21
  """
22
22
  Output class for :class:`Duplicates` lint detector
23
23
 
@@ -89,7 +89,7 @@ class Duplicates:
89
89
  @overload
90
90
  def from_stats(self, hashes: Sequence[HashStatsOutput]) -> DuplicatesOutput[DatasetDuplicateGroupMap]: ...
91
91
 
92
- @set_metadata(["only_exact"])
92
+ @set_metadata(state=["only_exact"])
93
93
  def from_stats(
94
94
  self, hashes: HashStatsOutput | Sequence[HashStatsOutput]
95
95
  ) -> DuplicatesOutput[DuplicateGroup] | DuplicatesOutput[DatasetDuplicateGroupMap]:
@@ -138,7 +138,7 @@ class Duplicates:
138
138
 
139
139
  return DuplicatesOutput(**duplicates)
140
140
 
141
- @set_metadata(["only_exact"])
141
+ @set_metadata(state=["only_exact"])
142
142
  def evaluate(self, data: Iterable[ArrayLike]) -> DuplicatesOutput[DuplicateGroup]:
143
143
  """
144
144
  Returns duplicate image indices for both exact matches and near matches
@@ -14,7 +14,7 @@ from dataeval.metrics.stats.datasetstats import DatasetStatsOutput, datasetstats
14
14
  from dataeval.metrics.stats.dimensionstats import DimensionStatsOutput
15
15
  from dataeval.metrics.stats.pixelstats import PixelStatsOutput
16
16
  from dataeval.metrics.stats.visualstats import VisualStatsOutput
17
- from dataeval.output import OutputMetadata, set_metadata
17
+ from dataeval.output import Output, set_metadata
18
18
 
19
19
  IndexIssueMap = dict[int, dict[str, float]]
20
20
  OutlierStatsOutput = Union[DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput]
@@ -22,7 +22,7 @@ TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, list[IndexIssueMap])
22
22
 
23
23
 
24
24
  @dataclass(frozen=True)
25
- class OutliersOutput(Generic[TIndexIssueMap], OutputMetadata):
25
+ class OutliersOutput(Generic[TIndexIssueMap], Output):
26
26
  """
27
27
  Output class for :class:`Outliers` lint detector
28
28
 
@@ -159,7 +159,7 @@ class Outliers:
159
159
  @overload
160
160
  def from_stats(self, stats: Sequence[OutlierStatsOutput]) -> OutliersOutput[list[IndexIssueMap]]: ...
161
161
 
162
- @set_metadata(["outlier_method", "outlier_threshold"])
162
+ @set_metadata(state=["outlier_method", "outlier_threshold"])
163
163
  def from_stats(
164
164
  self, stats: OutlierStatsOutput | DatasetStatsOutput | Sequence[OutlierStatsOutput]
165
165
  ) -> OutliersOutput[IndexIssueMap] | OutliersOutput[list[IndexIssueMap]]:
@@ -228,7 +228,7 @@ class Outliers:
228
228
 
229
229
  return OutliersOutput(output_list)
230
230
 
231
- @set_metadata(["use_dimension", "use_pixel", "use_visual", "outlier_method", "outlier_threshold"])
231
+ @set_metadata(state=["use_dimension", "use_pixel", "use_visual", "outlier_method", "outlier_threshold"])
232
232
  def evaluate(self, data: Iterable[ArrayLike]) -> OutliersOutput[IndexIssueMap]:
233
233
  """
234
234
  Returns indices of Outliers with the issues identified for each
@@ -0,0 +1,15 @@
1
+ """
2
+ Out-of-distribution (OOD) detectors identify data that is different from the data used to train a particular model.
3
+ """
4
+
5
+ from dataeval import _IS_TORCH_AVAILABLE
6
+ from dataeval.detectors.ood.base import OODOutput, OODScoreOutput
7
+
8
+ __all__ = ["OODOutput", "OODScoreOutput"]
9
+
10
+ if _IS_TORCH_AVAILABLE:
11
+ from dataeval.detectors.ood.ae_torch import OOD_AE
12
+
13
+ __all__ += ["OOD_AE"]
14
+
15
+ del _IS_TORCH_AVAILABLE
@@ -1,4 +1,6 @@
1
1
  """
2
+ Adapted for PyTorch from
3
+
2
4
  Source code derived from Alibi-Detect 0.11.4
3
5
  https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
4
6
 
@@ -8,55 +10,48 @@ Licensed under Apache Software License (Apache 2.0)
8
10
 
9
11
  from __future__ import annotations
10
12
 
11
- __all__ = ["OOD_AE"]
12
-
13
- from typing import TYPE_CHECKING, Callable
13
+ from typing import Callable
14
14
 
15
15
  import numpy as np
16
+ import torch
16
17
  from numpy.typing import ArrayLike
17
18
 
18
- from dataeval.detectors.ood.base import OODBase, OODScoreOutput
19
+ from dataeval.detectors.ood.base import OODScoreOutput
20
+ from dataeval.detectors.ood.base_torch import OODBase
19
21
  from dataeval.interop import as_numpy
20
- from dataeval.utils.lazy import lazyload
21
- from dataeval.utils.tensorflow._internal.utils import predict_batch
22
-
23
- if TYPE_CHECKING:
24
- import tensorflow as tf
25
- import tf_keras as keras
26
-
27
- import dataeval.utils.tensorflow._internal.models as tf_models
28
- else:
29
- tf = lazyload("tensorflow")
30
- keras = lazyload("tf_keras")
31
- tf_models = lazyload("dataeval.utils.tensorflow._internal.models")
22
+ from dataeval.utils.torch.utils import predict_batch
32
23
 
33
24
 
34
25
  class OOD_AE(OODBase):
35
26
  """
36
- Autoencoder-based :term:`out of distribution<Out-of-distribution (OOD)>` detector.
27
+ Autoencoder-based out-of-distribution detector.
37
28
 
38
29
  Parameters
39
30
  ----------
40
- model : AE
41
- An :term:`autoencoder<Autoencoder>` model.
31
+ model : torch.nn.Module
32
+ An Autoencoder model.
42
33
  """
43
34
 
44
- def __init__(self, model: tf_models.AE) -> None:
45
- super().__init__(model)
35
+ def __init__(self, model: torch.nn.Module, device: str | torch.device | None = None) -> None:
36
+ super().__init__(model, device)
46
37
 
47
38
  def fit(
48
39
  self,
49
40
  x_ref: ArrayLike,
50
- threshold_perc: float = 100.0,
51
- loss_fn: Callable[..., tf.Tensor] | None = None,
52
- optimizer: keras.optimizers.Optimizer | None = None,
41
+ threshold_perc: float,
42
+ loss_fn: Callable[..., torch.nn.Module] | None = None,
43
+ optimizer: torch.optim.Optimizer | None = None,
53
44
  epochs: int = 20,
54
45
  batch_size: int = 64,
55
- verbose: bool = True,
46
+ verbose: bool = False,
56
47
  ) -> None:
57
48
  if loss_fn is None:
58
- loss_fn = keras.losses.MeanSquaredError()
59
- super().fit(as_numpy(x_ref), threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
49
+ loss_fn = torch.nn.MSELoss()
50
+
51
+ if optimizer is None:
52
+ optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
53
+
54
+ super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
60
55
 
61
56
  def _score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
62
57
  self._validate(X := as_numpy(X))