dataeval 0.73.0__tar.gz → 0.74.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84) hide show
  1. {dataeval-0.73.0 → dataeval-0.74.0}/PKG-INFO +1 -2
  2. {dataeval-0.73.0 → dataeval-0.74.0}/pyproject.toml +6 -5
  3. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/__init__.py +3 -3
  4. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/detectors/__init__.py +1 -1
  5. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/detectors/drift/__init__.py +1 -1
  6. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/detectors/drift/base.py +2 -2
  7. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/detectors/drift/torch.py +1 -101
  8. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/detectors/linters/clusterer.py +1 -1
  9. dataeval-0.74.0/src/dataeval/detectors/ood/__init__.py +22 -0
  10. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/detectors/ood/ae.py +2 -1
  11. dataeval-0.74.0/src/dataeval/detectors/ood/ae_torch.py +70 -0
  12. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/detectors/ood/aegmm.py +4 -3
  13. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/detectors/ood/base.py +58 -108
  14. dataeval-0.74.0/src/dataeval/detectors/ood/base_tf.py +109 -0
  15. dataeval-0.74.0/src/dataeval/detectors/ood/base_torch.py +109 -0
  16. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/detectors/ood/llr.py +2 -2
  17. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/detectors/ood/metadata_ks_compare.py +53 -14
  18. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/detectors/ood/vae.py +3 -2
  19. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/detectors/ood/vaegmm.py +5 -4
  20. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/metrics/bias/__init__.py +3 -0
  21. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/metrics/bias/balance.py +77 -64
  22. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/metrics/bias/coverage.py +12 -12
  23. dataeval-0.74.0/src/dataeval/metrics/bias/diversity.py +238 -0
  24. dataeval-0.74.0/src/dataeval/metrics/bias/metadata_preprocessing.py +285 -0
  25. dataeval-0.74.0/src/dataeval/metrics/bias/metadata_utils.py +229 -0
  26. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/metrics/bias/parity.py +54 -158
  27. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/utils/__init__.py +2 -2
  28. dataeval-0.74.0/src/dataeval/utils/gmm.py +26 -0
  29. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/utils/metadata.py +29 -9
  30. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/utils/shared.py +1 -1
  31. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/utils/split_dataset.py +12 -6
  32. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/utils/tensorflow/_internal/gmm.py +4 -24
  33. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/utils/torch/datasets.py +2 -2
  34. dataeval-0.74.0/src/dataeval/utils/torch/gmm.py +98 -0
  35. dataeval-0.74.0/src/dataeval/utils/torch/models.py +330 -0
  36. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/utils/torch/trainer.py +84 -5
  37. dataeval-0.74.0/src/dataeval/utils/torch/utils.py +169 -0
  38. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/workflows/__init__.py +1 -1
  39. dataeval-0.73.0/src/dataeval/detectors/ood/__init__.py +0 -15
  40. dataeval-0.73.0/src/dataeval/metrics/bias/diversity.py +0 -278
  41. dataeval-0.73.0/src/dataeval/metrics/bias/metadata.py +0 -358
  42. dataeval-0.73.0/src/dataeval/utils/torch/models.py +0 -138
  43. dataeval-0.73.0/src/dataeval/utils/torch/utils.py +0 -63
  44. {dataeval-0.73.0 → dataeval-0.74.0}/LICENSE.txt +0 -0
  45. {dataeval-0.73.0 → dataeval-0.74.0}/README.md +0 -0
  46. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/detectors/drift/cvm.py +0 -0
  47. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/detectors/drift/ks.py +0 -0
  48. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/detectors/drift/mmd.py +0 -0
  49. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/detectors/drift/uncertainty.py +0 -0
  50. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/detectors/drift/updates.py +0 -0
  51. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/detectors/linters/__init__.py +0 -0
  52. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/detectors/linters/duplicates.py +0 -0
  53. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/detectors/linters/merged_stats.py +0 -0
  54. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/detectors/linters/outliers.py +0 -0
  55. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/detectors/ood/metadata_least_likely.py +0 -0
  56. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/detectors/ood/metadata_ood_mi.py +0 -0
  57. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/interop.py +0 -0
  58. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/metrics/__init__.py +0 -0
  59. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/metrics/estimators/__init__.py +0 -0
  60. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/metrics/estimators/ber.py +0 -0
  61. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/metrics/estimators/divergence.py +0 -0
  62. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/metrics/estimators/uap.py +0 -0
  63. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/metrics/stats/__init__.py +0 -0
  64. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/metrics/stats/base.py +0 -0
  65. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/metrics/stats/boxratiostats.py +0 -0
  66. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/metrics/stats/datasetstats.py +0 -0
  67. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/metrics/stats/dimensionstats.py +0 -0
  68. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/metrics/stats/hashstats.py +0 -0
  69. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/metrics/stats/labelstats.py +0 -0
  70. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/metrics/stats/pixelstats.py +0 -0
  71. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/metrics/stats/visualstats.py +0 -0
  72. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/output.py +0 -0
  73. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/py.typed +0 -0
  74. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/utils/image.py +0 -0
  75. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/utils/lazy.py +0 -0
  76. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/utils/tensorflow/__init__.py +0 -0
  77. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/utils/tensorflow/_internal/loss.py +0 -0
  78. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/utils/tensorflow/_internal/models.py +0 -0
  79. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/utils/tensorflow/_internal/trainer.py +0 -0
  80. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/utils/tensorflow/_internal/utils.py +0 -0
  81. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/utils/tensorflow/loss/__init__.py +0 -0
  82. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/utils/torch/__init__.py +0 -0
  83. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/utils/torch/blocks.py +0 -0
  84. {dataeval-0.73.0 → dataeval-0.74.0}/src/dataeval/workflows/sufficiency.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: dataeval
3
- Version: 0.73.0
3
+ Version: 0.74.0
4
4
  Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
5
5
  Home-page: https://dataeval.ai/
6
6
  License: MIT
@@ -23,7 +23,6 @@ Classifier: Topic :: Scientific/Engineering
23
23
  Provides-Extra: all
24
24
  Provides-Extra: tensorflow
25
25
  Provides-Extra: torch
26
- Requires-Dist: hdbscan (>=0.8.36)
27
26
  Requires-Dist: markupsafe (<3.0.2) ; extra == "tensorflow" or extra == "all"
28
27
  Requires-Dist: matplotlib ; extra == "torch" or extra == "all"
29
28
  Requires-Dist: numpy (>1.24.3)
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "dataeval"
3
- version = "0.73.0" # dynamic
3
+ version = "0.74.0" # dynamic
4
4
  description = "DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks"
5
5
  license = "MIT"
6
6
  readme = "README.md"
@@ -42,7 +42,6 @@ packages = [
42
42
  [tool.poetry.dependencies]
43
43
  # required
44
44
  python = ">=3.9,<3.13"
45
- hdbscan = {version = ">=0.8.36"}
46
45
  numpy = {version = ">1.24.3"}
47
46
  pillow = {version = ">=10.3.0"}
48
47
  scipy = {version = ">=1.10"}
@@ -69,8 +68,7 @@ all = ["matplotlib", "markupsafe", "tensorflow", "tensorflow_probability", "tf-k
69
68
  optional = true
70
69
 
71
70
  [tool.poetry.group.dev.dependencies]
72
- tox = {version = "*"}
73
- tox-uv = {version = "*"}
71
+ nox = {version = "*", extras = ["uv"]}
74
72
  uv = {version = "*"}
75
73
  poetry = {version = "*"}
76
74
  poetry-lock-groups-plugin = {version = "*"}
@@ -122,7 +120,6 @@ files = ["src/dataeval/__init__.py"]
122
120
  name = "dataeval"
123
121
 
124
122
  [tool.poetry2conda.dependencies]
125
- nvidia-cudnn-cu11 = { name = "cudnn" }
126
123
  tensorflow_probability = { name = "tensorflow-probability" }
127
124
  torch = { name = "pytorch" }
128
125
  xxhash = { name = "python-xxhash" }
@@ -145,6 +142,9 @@ parallel = true
145
142
  exclude_also = [
146
143
  "raise NotImplementedError",
147
144
  "if TYPE_CHECKING:",
145
+ "if _IS_TENSORFLOW_AVAILABLE",
146
+ "if _IS_TORCH_AVAILABLE",
147
+ "if _IS_TORCHVISION_AVAILABLE",
148
148
  ]
149
149
  include = ["*/src/dataeval/*"]
150
150
  omit = [
@@ -164,6 +164,7 @@ exclude = [
164
164
  "*env*",
165
165
  "output",
166
166
  "_build",
167
+ ".nox",
167
168
  ".tox",
168
169
  "prototype",
169
170
  ]
@@ -1,4 +1,4 @@
1
- __version__ = "0.73.0"
1
+ __version__ = "0.74.0"
2
2
 
3
3
  from importlib.util import find_spec
4
4
 
@@ -12,12 +12,12 @@ from dataeval import detectors, metrics # noqa: E402
12
12
 
13
13
  __all__ = ["detectors", "metrics"]
14
14
 
15
- if _IS_TORCH_AVAILABLE: # pragma: no cover
15
+ if _IS_TORCH_AVAILABLE:
16
16
  from dataeval import workflows
17
17
 
18
18
  __all__ += ["workflows"]
19
19
 
20
- if _IS_TENSORFLOW_AVAILABLE or _IS_TORCH_AVAILABLE: # pragma: no cover
20
+ if _IS_TENSORFLOW_AVAILABLE or _IS_TORCH_AVAILABLE:
21
21
  from dataeval import utils
22
22
 
23
23
  __all__ += ["utils"]
@@ -7,7 +7,7 @@ from dataeval.detectors import drift, linters
7
7
 
8
8
  __all__ = ["drift", "linters"]
9
9
 
10
- if _IS_TENSORFLOW_AVAILABLE: # pragma: no cover
10
+ if _IS_TENSORFLOW_AVAILABLE:
11
11
  from dataeval.detectors import ood
12
12
 
13
13
  __all__ += ["ood"]
@@ -10,7 +10,7 @@ from dataeval.detectors.drift.ks import DriftKS
10
10
 
11
11
  __all__ = ["DriftCVM", "DriftKS", "DriftOutput", "updates"]
12
12
 
13
- if _IS_TORCH_AVAILABLE: # pragma: no cover
13
+ if _IS_TORCH_AVAILABLE:
14
14
  from dataeval.detectors.drift.mmd import DriftMMD, DriftMMDOutput
15
15
  from dataeval.detectors.drift.torch import preprocess_drift
16
16
  from dataeval.detectors.drift.uncertainty import DriftUncertainty
@@ -18,7 +18,7 @@ from typing import Any, Callable, Literal, TypeVar
18
18
  import numpy as np
19
19
  from numpy.typing import ArrayLike, NDArray
20
20
 
21
- from dataeval.interop import as_numpy, to_numpy
21
+ from dataeval.interop import as_numpy
22
22
  from dataeval.output import OutputMetadata, set_metadata
23
23
 
24
24
  R = TypeVar("R")
@@ -196,7 +196,7 @@ class BaseDrift:
196
196
  if correction not in ["bonferroni", "fdr"]:
197
197
  raise ValueError("`correction` must be `bonferroni` or `fdr`.")
198
198
 
199
- self._x_ref = to_numpy(x_ref)
199
+ self._x_ref = as_numpy(x_ref)
200
200
  self.x_ref_preprocessed: bool = x_ref_preprocessed
201
201
 
202
202
  # Other attributes
@@ -10,7 +10,6 @@ from __future__ import annotations
10
10
 
11
11
  __all__ = []
12
12
 
13
- from functools import partial
14
13
  from typing import Any, Callable
15
14
 
16
15
  import numpy as np
@@ -18,30 +17,7 @@ import torch
18
17
  import torch.nn as nn
19
18
  from numpy.typing import NDArray
20
19
 
21
-
22
- def get_device(device: str | torch.device | None = None) -> torch.device:
23
- """
24
- Instantiates a PyTorch device object.
25
-
26
- Parameters
27
- ----------
28
- device : str | torch.device | None, default None
29
- Either ``None``, a str ('gpu' or 'cpu') indicating the device to choose, or an
30
- already instantiated device object. If ``None``, the GPU is selected if it is
31
- detected, otherwise the CPU is used as a fallback.
32
-
33
- Returns
34
- -------
35
- The instantiated device object.
36
- """
37
- if isinstance(device, torch.device): # Already a torch device
38
- return device
39
- else: # Instantiate device
40
- if device is None or device.lower() in ["gpu", "cuda"]:
41
- torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
- else:
43
- torch_device = torch.device("cpu")
44
- return torch_device
20
+ from dataeval.utils.torch.utils import get_device, predict_batch
45
21
 
46
22
 
47
23
  def _mmd2_from_kernel_matrix(
@@ -79,82 +55,6 @@ def _mmd2_from_kernel_matrix(
79
55
  return mmd2
80
56
 
81
57
 
82
- def predict_batch(
83
- x: NDArray[Any] | torch.Tensor,
84
- model: Callable | nn.Module | nn.Sequential,
85
- device: torch.device | None = None,
86
- batch_size: int = int(1e10),
87
- preprocess_fn: Callable | None = None,
88
- dtype: type[np.generic] | torch.dtype = np.float32,
89
- ) -> NDArray[Any] | torch.Tensor | tuple[Any, ...]:
90
- """
91
- Make batch predictions on a model.
92
-
93
- Parameters
94
- ----------
95
- x : np.ndarray | torch.Tensor
96
- Batch of instances.
97
- model : Callable | nn.Module | nn.Sequential
98
- PyTorch model.
99
- device : torch.device | None, default None
100
- Device type used. The default None tries to use the GPU and falls back on CPU.
101
- Can be specified by passing either torch.device('cuda') or torch.device('cpu').
102
- batch_size : int, default 1e10
103
- Batch size used during prediction.
104
- preprocess_fn : Callable | None, default None
105
- Optional preprocessing function for each batch.
106
- dtype : np.dtype | torch.dtype, default np.float32
107
- Model output type, either a :term:`NumPy` or torch dtype, e.g. np.float32 or torch.float32.
108
-
109
- Returns
110
- -------
111
- NDArray | torch.Tensor | tuple
112
- Numpy array, torch tensor or tuples of those with model outputs.
113
- """
114
- device = get_device(device)
115
- if isinstance(x, np.ndarray):
116
- x = torch.from_numpy(x)
117
- n = len(x)
118
- n_minibatch = int(np.ceil(n / batch_size))
119
- return_np = not isinstance(dtype, torch.dtype)
120
- preds = []
121
- with torch.no_grad():
122
- for i in range(n_minibatch):
123
- istart, istop = i * batch_size, min((i + 1) * batch_size, n)
124
- x_batch = x[istart:istop]
125
- if isinstance(preprocess_fn, Callable):
126
- x_batch = preprocess_fn(x_batch)
127
- preds_tmp = model(x_batch.to(device))
128
- if isinstance(preds_tmp, (list, tuple)):
129
- if len(preds) == 0: # init tuple with lists to store predictions
130
- preds = tuple([] for _ in range(len(preds_tmp)))
131
- for j, p in enumerate(preds_tmp):
132
- if isinstance(p, torch.Tensor):
133
- p = p.cpu()
134
- preds[j].append(p if not return_np or isinstance(p, np.ndarray) else p.numpy())
135
- elif isinstance(preds_tmp, (np.ndarray, torch.Tensor)):
136
- if isinstance(preds_tmp, torch.Tensor):
137
- preds_tmp = preds_tmp.cpu()
138
- if isinstance(preds, tuple):
139
- preds = list(preds)
140
- preds.append(
141
- preds_tmp
142
- if not return_np or isinstance(preds_tmp, np.ndarray) # type: ignore
143
- else preds_tmp.numpy()
144
- )
145
- else:
146
- raise TypeError(
147
- f"Model output type {type(preds_tmp)} not supported. The model \
148
- output type needs to be one of list, tuple, NDArray or \
149
- torch.Tensor."
150
- )
151
- concat = partial(np.concatenate, axis=0) if return_np else partial(torch.cat, dim=0)
152
- out: tuple | np.ndarray | torch.Tensor = (
153
- tuple(concat(p) for p in preds) if isinstance(preds, tuple) else concat(preds) # type: ignore
154
- )
155
- return out
156
-
157
-
158
58
  def preprocess_drift(
159
59
  x: NDArray[Any],
160
60
  model: nn.Module,
@@ -480,7 +480,7 @@ class Clusterer:
480
480
  samples = self.clusters[level][cluster_id].samples
481
481
  if len(samples) >= self._min_num_samples_per_cluster:
482
482
  duplicates_std.append(self.clusters[level][cluster_id].dist_std)
483
- diag_mask = np.ones_like(self._sqdmat, dtype=bool)
483
+ diag_mask = np.ones_like(self._sqdmat, dtype=np.bool_)
484
484
  np.fill_diagonal(diag_mask, 0)
485
485
  diag_mask = np.triu(diag_mask)
486
486
 
@@ -0,0 +1,22 @@
1
+ """
2
+ Out-of-distribution (OOD)` detectors identify data that is different from the data used to train a particular model.
3
+ """
4
+
5
+ from dataeval import _IS_TENSORFLOW_AVAILABLE, _IS_TORCH_AVAILABLE
6
+ from dataeval.detectors.ood.base import OODOutput, OODScoreOutput
7
+
8
+ __all__ = ["OODOutput", "OODScoreOutput"]
9
+
10
+ if _IS_TENSORFLOW_AVAILABLE:
11
+ from dataeval.detectors.ood.ae import OOD_AE
12
+ from dataeval.detectors.ood.aegmm import OOD_AEGMM
13
+ from dataeval.detectors.ood.llr import OOD_LLR
14
+ from dataeval.detectors.ood.vae import OOD_VAE
15
+ from dataeval.detectors.ood.vaegmm import OOD_VAEGMM
16
+
17
+ __all__ += ["OOD_AE", "OOD_AEGMM", "OOD_LLR", "OOD_VAE", "OOD_VAEGMM"]
18
+
19
+ elif _IS_TORCH_AVAILABLE:
20
+ from dataeval.detectors.ood.ae_torch import OOD_AE
21
+
22
+ __all__ += ["OOD_AE", "OODOutput"]
@@ -15,7 +15,8 @@ from typing import TYPE_CHECKING, Callable
15
15
  import numpy as np
16
16
  from numpy.typing import ArrayLike
17
17
 
18
- from dataeval.detectors.ood.base import OODBase, OODScoreOutput
18
+ from dataeval.detectors.ood.base import OODScoreOutput
19
+ from dataeval.detectors.ood.base_tf import OODBase
19
20
  from dataeval.interop import as_numpy
20
21
  from dataeval.utils.lazy import lazyload
21
22
  from dataeval.utils.tensorflow._internal.utils import predict_batch
@@ -0,0 +1,70 @@
1
+ """
2
+ Adapted for Pytorch from
3
+
4
+ Source code derived from Alibi-Detect 0.11.4
5
+ https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
6
+
7
+ Original code Copyright (c) 2023 Seldon Technologies Ltd
8
+ Licensed under Apache Software License (Apache 2.0)
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from typing import Callable
14
+
15
+ import numpy as np
16
+ import torch
17
+ from numpy.typing import ArrayLike
18
+
19
+ from dataeval.detectors.ood.base import OODScoreOutput
20
+ from dataeval.detectors.ood.base_torch import OODBase
21
+ from dataeval.interop import as_numpy
22
+ from dataeval.utils.torch.utils import predict_batch
23
+
24
+
25
+ class OOD_AE(OODBase):
26
+ """
27
+ Autoencoder based out-of-distribution detector.
28
+
29
+ Parameters
30
+ ----------
31
+ model : AriaAutoencoder
32
+ An Autoencoder model.
33
+ """
34
+
35
+ def __init__(self, model: torch.nn.Module, device: str | torch.device | None = None) -> None:
36
+ super().__init__(model, device)
37
+
38
+ def fit(
39
+ self,
40
+ x_ref: ArrayLike,
41
+ threshold_perc: float,
42
+ loss_fn: Callable[..., torch.nn.Module] | None = None,
43
+ optimizer: torch.optim.Optimizer | None = None,
44
+ epochs: int = 20,
45
+ batch_size: int = 64,
46
+ verbose: bool = False,
47
+ ) -> None:
48
+ if loss_fn is None:
49
+ loss_fn = torch.nn.MSELoss()
50
+
51
+ if optimizer is None:
52
+ optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
53
+
54
+ super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
55
+
56
+ def _score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
57
+ self._validate(X := as_numpy(X))
58
+
59
+ # reconstruct instances
60
+ X_recon = predict_batch(X, self.model, batch_size=batch_size)
61
+
62
+ # compute feature and instance level scores
63
+ fscore = np.power(X - X_recon, 2)
64
+ fscore_flat = fscore.reshape(fscore.shape[0], -1).copy()
65
+ n_score_features = int(np.ceil(fscore_flat.shape[1]))
66
+ sorted_fscore = np.sort(fscore_flat, axis=1)
67
+ sorted_fscore_perc = sorted_fscore[:, -n_score_features:]
68
+ iscore = np.mean(sorted_fscore_perc, axis=1)
69
+
70
+ return OODScoreOutput(iscore, fscore)
@@ -14,7 +14,8 @@ from typing import TYPE_CHECKING, Callable
14
14
 
15
15
  from numpy.typing import ArrayLike
16
16
 
17
- from dataeval.detectors.ood.base import OODGMMBase, OODScoreOutput
17
+ from dataeval.detectors.ood.base import OODScoreOutput
18
+ from dataeval.detectors.ood.base_tf import OODBaseGMM
18
19
  from dataeval.interop import to_numpy
19
20
  from dataeval.utils.lazy import lazyload
20
21
  from dataeval.utils.tensorflow._internal.gmm import gmm_energy
@@ -32,7 +33,7 @@ else:
32
33
  tf_models = lazyload("dataeval.utils.tensorflow._internal.models")
33
34
 
34
35
 
35
- class OOD_AEGMM(OODGMMBase):
36
+ class OOD_AEGMM(OODBaseGMM):
36
37
  """
37
38
  AE with Gaussian Mixture Model based outlier detector.
38
39
 
@@ -62,5 +63,5 @@ class OOD_AEGMM(OODGMMBase):
62
63
  def _score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
63
64
  self._validate(X := to_numpy(X))
64
65
  _, z, _ = predict_batch(X, self.model, batch_size=batch_size)
65
- energy, _ = gmm_energy(z, self.gmm_params, return_mean=False)
66
+ energy, _ = gmm_energy(z, self._gmm_params, return_mean=False)
66
67
  return OODScoreOutput(energy.numpy()) # type: ignore
@@ -12,23 +12,14 @@ __all__ = ["OODOutput", "OODScoreOutput"]
12
12
 
13
13
  from abc import ABC, abstractmethod
14
14
  from dataclasses import dataclass
15
- from typing import TYPE_CHECKING, Callable, Literal, cast
15
+ from typing import Callable, Generic, Literal, TypeVar
16
16
 
17
17
  import numpy as np
18
18
  from numpy.typing import ArrayLike, NDArray
19
19
 
20
20
  from dataeval.interop import to_numpy
21
21
  from dataeval.output import OutputMetadata, set_metadata
22
- from dataeval.utils.lazy import lazyload
23
- from dataeval.utils.tensorflow._internal.gmm import GaussianMixtureModelParams, gmm_params
24
- from dataeval.utils.tensorflow._internal.trainer import trainer
25
-
26
- if TYPE_CHECKING:
27
- import tensorflow as tf
28
- import tf_keras as keras
29
- else:
30
- tf = lazyload("tensorflow")
31
- keras = lazyload("tf_keras")
22
+ from dataeval.utils.gmm import GaussianMixtureModelParams
32
23
 
33
24
 
34
25
  @dataclass(frozen=True)
@@ -85,16 +76,62 @@ class OODScoreOutput(OutputMetadata):
85
76
  return self.instance_score if ood_type == "instance" or self.feature_score is None else self.feature_score
86
77
 
87
78
 
88
- class OODBase(ABC):
89
- def __init__(self, model: keras.Model) -> None:
90
- self.model = model
79
+ TGMMData = TypeVar("TGMMData")
80
+
81
+
82
+ class OODGMMMixin(Generic[TGMMData]):
83
+ _gmm_params: GaussianMixtureModelParams[TGMMData]
84
+
85
+
86
+ TModel = TypeVar("TModel", bound=Callable)
87
+ TLossFn = TypeVar("TLossFn", bound=Callable)
88
+ TOptimizer = TypeVar("TOptimizer")
89
+
90
+
91
+ class OODFitMixin(Generic[TLossFn, TOptimizer], ABC):
92
+ @abstractmethod
93
+ def fit(
94
+ self,
95
+ x_ref: ArrayLike,
96
+ threshold_perc: float,
97
+ loss_fn: TLossFn | None,
98
+ optimizer: TOptimizer | None,
99
+ epochs: int,
100
+ batch_size: int,
101
+ verbose: bool,
102
+ ) -> None:
103
+ """
104
+ Train the model and infer the threshold value.
105
+
106
+ Parameters
107
+ ----------
108
+ x_ref : ArrayLike
109
+ Training data.
110
+ threshold_perc : float, default 100.0
111
+ Percentage of reference data that is normal.
112
+ loss_fn : TLossFn
113
+ Loss function used for training.
114
+ optimizer : TOptimizer
115
+ Optimizer used for training.
116
+ epochs : int, default 20
117
+ Number of training epochs.
118
+ batch_size : int, default 64
119
+ Batch size used for training.
120
+ verbose : bool, default True
121
+ Whether to print training progress.
122
+ """
91
123
 
92
- self._ref_score: OODScoreOutput
93
- self._threshold_perc: float
94
- self._data_info: tuple[tuple, type] | None = None
95
124
 
96
- if not isinstance(model, keras.Model):
97
- raise TypeError("Model should be of type 'keras.Model'.")
125
+ class OODBaseMixin(Generic[TModel], ABC):
126
+ _ref_score: OODScoreOutput
127
+ _threshold_perc: float
128
+ _data_info: tuple[tuple, type] | None = None
129
+
130
+ def __init__(
131
+ self,
132
+ model: TModel,
133
+ ) -> None:
134
+ self.model = model
98
135
 
99
136
  def _get_data_info(self, X: NDArray) -> tuple[tuple, type]:
100
137
  if not isinstance(X, np.ndarray):
@@ -107,9 +144,8 @@ class OODBase(ABC):
107
144
  raise RuntimeError(f"Expect data of type: {self._data_info[1]} and shape: {self._data_info[0]}. \
108
145
  Provided data is type: {check_data_info[1]} and shape: {check_data_info[0]}.")
109
146
 
110
- def _validate_state(self, X: NDArray, additional_attrs: list[str] | None = None) -> None:
111
- attrs = ["_data_info", "_threshold_perc", "_ref_score"]
112
- attrs = attrs if additional_attrs is None else attrs + additional_attrs
147
+ def _validate_state(self, X: NDArray) -> None:
148
+ attrs = [k for c in self.__class__.mro()[:-1][::-1] if hasattr(c, "__annotations__") for k in c.__annotations__]
113
149
  if not all(hasattr(self, attr) for attr in attrs) or any(getattr(self, attr) for attr in attrs) is None:
114
150
  raise RuntimeError("Metric needs to be `fit` before method call.")
115
151
  self._validate(X)
@@ -140,52 +176,6 @@ class OODBase(ABC):
140
176
  def _threshold_score(self, ood_type: Literal["feature", "instance"] = "instance") -> np.floating:
141
177
  return np.percentile(self._ref_score.get(ood_type), self._threshold_perc)
142
178
 
143
- def fit(
144
- self,
145
- x_ref: ArrayLike,
146
- threshold_perc: float,
147
- loss_fn: Callable[..., tf.Tensor],
148
- optimizer: keras.optimizers.Optimizer,
149
- epochs: int,
150
- batch_size: int,
151
- verbose: bool,
152
- ) -> None:
153
- """
154
- Train the model and infer the threshold value.
155
-
156
- Parameters
157
- ----------
158
- x_ref : ArrayLike
159
- Training data.
160
- threshold_perc : float, default 100.0
161
- Percentage of reference data that is normal.
162
- loss_fn : Callable | None, default None
163
- Loss function used for training.
164
- optimizer : Optimizer, default keras.optimizers.Adam
165
- Optimizer used for training.
166
- epochs : int, default 20
167
- Number of training epochs.
168
- batch_size : int, default 64
169
- Batch size used for training.
170
- verbose : bool, default True
171
- Whether to print training progress.
172
- """
173
-
174
- # Train the model
175
- trainer(
176
- model=self.model,
177
- loss_fn=loss_fn,
178
- x_train=to_numpy(x_ref),
179
- optimizer=optimizer,
180
- epochs=epochs,
181
- batch_size=batch_size,
182
- verbose=verbose,
183
- )
184
-
185
- # Infer the threshold values
186
- self._ref_score = self.score(x_ref, batch_size)
187
- self._threshold_perc = threshold_perc
188
-
189
179
  @set_metadata()
190
180
  def predict(
191
181
  self,
@@ -215,43 +205,3 @@ class OODBase(ABC):
215
205
  score = self.score(X, batch_size=batch_size)
216
206
  ood_pred = score.get(ood_type) > self._threshold_score(ood_type)
217
207
  return OODOutput(is_ood=ood_pred, **score.dict())
218
-
219
-
220
- class OODGMMBase(OODBase):
221
- def __init__(self, model: keras.Model) -> None:
222
- super().__init__(model)
223
- self.gmm_params: GaussianMixtureModelParams
224
-
225
- def _validate_state(self, X: NDArray, additional_attrs: list[str] | None = None) -> None:
226
- if additional_attrs is None:
227
- additional_attrs = ["gmm_params"]
228
- super()._validate_state(X, additional_attrs)
229
-
230
- def fit(
231
- self,
232
- x_ref: ArrayLike,
233
- threshold_perc: float,
234
- loss_fn: Callable[..., tf.Tensor],
235
- optimizer: keras.optimizers.Optimizer,
236
- epochs: int,
237
- batch_size: int,
238
- verbose: bool,
239
- ) -> None:
240
- # Train the model
241
- trainer(
242
- model=self.model,
243
- loss_fn=loss_fn,
244
- x_train=to_numpy(x_ref),
245
- optimizer=optimizer,
246
- epochs=epochs,
247
- batch_size=batch_size,
248
- verbose=verbose,
249
- )
250
-
251
- # Calculate the GMM parameters
252
- _, z, gamma = cast(tuple[tf.Tensor, tf.Tensor, tf.Tensor], self.model(x_ref))
253
- self.gmm_params = gmm_params(z, gamma)
254
-
255
- # Infer the threshold values
256
- self._ref_score = self.score(x_ref, batch_size)
257
- self._threshold_perc = threshold_perc