dataeval 0.73.0__py3-none-any.whl → 0.74.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. dataeval/__init__.py +3 -3
  2. dataeval/detectors/__init__.py +1 -1
  3. dataeval/detectors/drift/__init__.py +1 -1
  4. dataeval/detectors/drift/base.py +2 -2
  5. dataeval/detectors/drift/torch.py +1 -101
  6. dataeval/detectors/linters/clusterer.py +1 -1
  7. dataeval/detectors/ood/__init__.py +11 -4
  8. dataeval/detectors/ood/ae.py +2 -1
  9. dataeval/detectors/ood/ae_torch.py +70 -0
  10. dataeval/detectors/ood/aegmm.py +4 -3
  11. dataeval/detectors/ood/base.py +58 -108
  12. dataeval/detectors/ood/base_tf.py +109 -0
  13. dataeval/detectors/ood/base_torch.py +109 -0
  14. dataeval/detectors/ood/llr.py +2 -2
  15. dataeval/detectors/ood/metadata_ks_compare.py +53 -14
  16. dataeval/detectors/ood/vae.py +3 -2
  17. dataeval/detectors/ood/vaegmm.py +5 -4
  18. dataeval/metrics/bias/__init__.py +3 -0
  19. dataeval/metrics/bias/balance.py +77 -64
  20. dataeval/metrics/bias/coverage.py +12 -12
  21. dataeval/metrics/bias/diversity.py +74 -114
  22. dataeval/metrics/bias/metadata_preprocessing.py +285 -0
  23. dataeval/metrics/bias/metadata_utils.py +229 -0
  24. dataeval/metrics/bias/parity.py +54 -158
  25. dataeval/utils/__init__.py +2 -2
  26. dataeval/utils/gmm.py +26 -0
  27. dataeval/utils/metadata.py +29 -9
  28. dataeval/utils/shared.py +1 -1
  29. dataeval/utils/split_dataset.py +12 -6
  30. dataeval/utils/tensorflow/_internal/gmm.py +4 -24
  31. dataeval/utils/torch/datasets.py +2 -2
  32. dataeval/utils/torch/gmm.py +98 -0
  33. dataeval/utils/torch/models.py +192 -0
  34. dataeval/utils/torch/trainer.py +84 -5
  35. dataeval/utils/torch/utils.py +107 -1
  36. dataeval/workflows/__init__.py +1 -1
  37. {dataeval-0.73.0.dist-info → dataeval-0.74.0.dist-info}/METADATA +1 -2
  38. {dataeval-0.73.0.dist-info → dataeval-0.74.0.dist-info}/RECORD +40 -34
  39. dataeval/metrics/bias/metadata.py +0 -358
  40. {dataeval-0.73.0.dist-info → dataeval-0.74.0.dist-info}/LICENSE.txt +0 -0
  41. {dataeval-0.73.0.dist-info → dataeval-0.74.0.dist-info}/WHEEL +0 -0
dataeval/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = "0.73.0"
1
+ __version__ = "0.74.0"
2
2
 
3
3
  from importlib.util import find_spec
4
4
 
@@ -12,12 +12,12 @@ from dataeval import detectors, metrics # noqa: E402
12
12
 
13
13
  __all__ = ["detectors", "metrics"]
14
14
 
15
- if _IS_TORCH_AVAILABLE: # pragma: no cover
15
+ if _IS_TORCH_AVAILABLE:
16
16
  from dataeval import workflows
17
17
 
18
18
  __all__ += ["workflows"]
19
19
 
20
- if _IS_TENSORFLOW_AVAILABLE or _IS_TORCH_AVAILABLE: # pragma: no cover
20
+ if _IS_TENSORFLOW_AVAILABLE or _IS_TORCH_AVAILABLE:
21
21
  from dataeval import utils
22
22
 
23
23
  __all__ += ["utils"]
@@ -7,7 +7,7 @@ from dataeval.detectors import drift, linters
7
7
 
8
8
  __all__ = ["drift", "linters"]
9
9
 
10
- if _IS_TENSORFLOW_AVAILABLE: # pragma: no cover
10
+ if _IS_TENSORFLOW_AVAILABLE:
11
11
  from dataeval.detectors import ood
12
12
 
13
13
  __all__ += ["ood"]
@@ -10,7 +10,7 @@ from dataeval.detectors.drift.ks import DriftKS
10
10
 
11
11
  __all__ = ["DriftCVM", "DriftKS", "DriftOutput", "updates"]
12
12
 
13
- if _IS_TORCH_AVAILABLE: # pragma: no cover
13
+ if _IS_TORCH_AVAILABLE:
14
14
  from dataeval.detectors.drift.mmd import DriftMMD, DriftMMDOutput
15
15
  from dataeval.detectors.drift.torch import preprocess_drift
16
16
  from dataeval.detectors.drift.uncertainty import DriftUncertainty
@@ -18,7 +18,7 @@ from typing import Any, Callable, Literal, TypeVar
18
18
  import numpy as np
19
19
  from numpy.typing import ArrayLike, NDArray
20
20
 
21
- from dataeval.interop import as_numpy, to_numpy
21
+ from dataeval.interop import as_numpy
22
22
  from dataeval.output import OutputMetadata, set_metadata
23
23
 
24
24
  R = TypeVar("R")
@@ -196,7 +196,7 @@ class BaseDrift:
196
196
  if correction not in ["bonferroni", "fdr"]:
197
197
  raise ValueError("`correction` must be `bonferroni` or `fdr`.")
198
198
 
199
- self._x_ref = to_numpy(x_ref)
199
+ self._x_ref = as_numpy(x_ref)
200
200
  self.x_ref_preprocessed: bool = x_ref_preprocessed
201
201
 
202
202
  # Other attributes
@@ -10,7 +10,6 @@ from __future__ import annotations
10
10
 
11
11
  __all__ = []
12
12
 
13
- from functools import partial
14
13
  from typing import Any, Callable
15
14
 
16
15
  import numpy as np
@@ -18,30 +17,7 @@ import torch
18
17
  import torch.nn as nn
19
18
  from numpy.typing import NDArray
20
19
 
21
-
22
- def get_device(device: str | torch.device | None = None) -> torch.device:
23
- """
24
- Instantiates a PyTorch device object.
25
-
26
- Parameters
27
- ----------
28
- device : str | torch.device | None, default None
29
- Either ``None``, a str ('gpu' or 'cpu') indicating the device to choose, or an
30
- already instantiated device object. If ``None``, the GPU is selected if it is
31
- detected, otherwise the CPU is used as a fallback.
32
-
33
- Returns
34
- -------
35
- The instantiated device object.
36
- """
37
- if isinstance(device, torch.device): # Already a torch device
38
- return device
39
- else: # Instantiate device
40
- if device is None or device.lower() in ["gpu", "cuda"]:
41
- torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
- else:
43
- torch_device = torch.device("cpu")
44
- return torch_device
20
+ from dataeval.utils.torch.utils import get_device, predict_batch
45
21
 
46
22
 
47
23
  def _mmd2_from_kernel_matrix(
@@ -79,82 +55,6 @@ def _mmd2_from_kernel_matrix(
79
55
  return mmd2
80
56
 
81
57
 
82
- def predict_batch(
83
- x: NDArray[Any] | torch.Tensor,
84
- model: Callable | nn.Module | nn.Sequential,
85
- device: torch.device | None = None,
86
- batch_size: int = int(1e10),
87
- preprocess_fn: Callable | None = None,
88
- dtype: type[np.generic] | torch.dtype = np.float32,
89
- ) -> NDArray[Any] | torch.Tensor | tuple[Any, ...]:
90
- """
91
- Make batch predictions on a model.
92
-
93
- Parameters
94
- ----------
95
- x : np.ndarray | torch.Tensor
96
- Batch of instances.
97
- model : Callable | nn.Module | nn.Sequential
98
- PyTorch model.
99
- device : torch.device | None, default None
100
- Device type used. The default None tries to use the GPU and falls back on CPU.
101
- Can be specified by passing either torch.device('cuda') or torch.device('cpu').
102
- batch_size : int, default 1e10
103
- Batch size used during prediction.
104
- preprocess_fn : Callable | None, default None
105
- Optional preprocessing function for each batch.
106
- dtype : np.dtype | torch.dtype, default np.float32
107
- Model output type, either a :term:`NumPy` or torch dtype, e.g. np.float32 or torch.float32.
108
-
109
- Returns
110
- -------
111
- NDArray | torch.Tensor | tuple
112
- Numpy array, torch tensor or tuples of those with model outputs.
113
- """
114
- device = get_device(device)
115
- if isinstance(x, np.ndarray):
116
- x = torch.from_numpy(x)
117
- n = len(x)
118
- n_minibatch = int(np.ceil(n / batch_size))
119
- return_np = not isinstance(dtype, torch.dtype)
120
- preds = []
121
- with torch.no_grad():
122
- for i in range(n_minibatch):
123
- istart, istop = i * batch_size, min((i + 1) * batch_size, n)
124
- x_batch = x[istart:istop]
125
- if isinstance(preprocess_fn, Callable):
126
- x_batch = preprocess_fn(x_batch)
127
- preds_tmp = model(x_batch.to(device))
128
- if isinstance(preds_tmp, (list, tuple)):
129
- if len(preds) == 0: # init tuple with lists to store predictions
130
- preds = tuple([] for _ in range(len(preds_tmp)))
131
- for j, p in enumerate(preds_tmp):
132
- if isinstance(p, torch.Tensor):
133
- p = p.cpu()
134
- preds[j].append(p if not return_np or isinstance(p, np.ndarray) else p.numpy())
135
- elif isinstance(preds_tmp, (np.ndarray, torch.Tensor)):
136
- if isinstance(preds_tmp, torch.Tensor):
137
- preds_tmp = preds_tmp.cpu()
138
- if isinstance(preds, tuple):
139
- preds = list(preds)
140
- preds.append(
141
- preds_tmp
142
- if not return_np or isinstance(preds_tmp, np.ndarray) # type: ignore
143
- else preds_tmp.numpy()
144
- )
145
- else:
146
- raise TypeError(
147
- f"Model output type {type(preds_tmp)} not supported. The model \
148
- output type needs to be one of list, tuple, NDArray or \
149
- torch.Tensor."
150
- )
151
- concat = partial(np.concatenate, axis=0) if return_np else partial(torch.cat, dim=0)
152
- out: tuple | np.ndarray | torch.Tensor = (
153
- tuple(concat(p) for p in preds) if isinstance(preds, tuple) else concat(preds) # type: ignore
154
- )
155
- return out
156
-
157
-
158
58
  def preprocess_drift(
159
59
  x: NDArray[Any],
160
60
  model: nn.Module,
@@ -480,7 +480,7 @@ class Clusterer:
480
480
  samples = self.clusters[level][cluster_id].samples
481
481
  if len(samples) >= self._min_num_samples_per_cluster:
482
482
  duplicates_std.append(self.clusters[level][cluster_id].dist_std)
483
- diag_mask = np.ones_like(self._sqdmat, dtype=bool)
483
+ diag_mask = np.ones_like(self._sqdmat, dtype=np.bool_)
484
484
  np.fill_diagonal(diag_mask, 0)
485
485
  diag_mask = np.triu(diag_mask)
486
486
 
@@ -2,14 +2,21 @@
2
2
  Out-of-distribution (OOD)` detectors identify data that is different from the data used to train a particular model.
3
3
  """
4
4
 
5
- from dataeval import _IS_TENSORFLOW_AVAILABLE
5
+ from dataeval import _IS_TENSORFLOW_AVAILABLE, _IS_TORCH_AVAILABLE
6
+ from dataeval.detectors.ood.base import OODOutput, OODScoreOutput
6
7
 
7
- if _IS_TENSORFLOW_AVAILABLE: # pragma: no cover
8
+ __all__ = ["OODOutput", "OODScoreOutput"]
9
+
10
+ if _IS_TENSORFLOW_AVAILABLE:
8
11
  from dataeval.detectors.ood.ae import OOD_AE
9
12
  from dataeval.detectors.ood.aegmm import OOD_AEGMM
10
- from dataeval.detectors.ood.base import OODOutput, OODScoreOutput
11
13
  from dataeval.detectors.ood.llr import OOD_LLR
12
14
  from dataeval.detectors.ood.vae import OOD_VAE
13
15
  from dataeval.detectors.ood.vaegmm import OOD_VAEGMM
14
16
 
15
- __all__ = ["OOD_AE", "OOD_AEGMM", "OOD_LLR", "OOD_VAE", "OOD_VAEGMM", "OODOutput", "OODScoreOutput"]
17
+ __all__ += ["OOD_AE", "OOD_AEGMM", "OOD_LLR", "OOD_VAE", "OOD_VAEGMM"]
18
+
19
+ elif _IS_TORCH_AVAILABLE:
20
+ from dataeval.detectors.ood.ae_torch import OOD_AE
21
+
22
+ __all__ += ["OOD_AE", "OODOutput"]
@@ -15,7 +15,8 @@ from typing import TYPE_CHECKING, Callable
15
15
  import numpy as np
16
16
  from numpy.typing import ArrayLike
17
17
 
18
- from dataeval.detectors.ood.base import OODBase, OODScoreOutput
18
+ from dataeval.detectors.ood.base import OODScoreOutput
19
+ from dataeval.detectors.ood.base_tf import OODBase
19
20
  from dataeval.interop import as_numpy
20
21
  from dataeval.utils.lazy import lazyload
21
22
  from dataeval.utils.tensorflow._internal.utils import predict_batch
@@ -0,0 +1,70 @@
1
+ """
2
+ Adapted for Pytorch from
3
+
4
+ Source code derived from Alibi-Detect 0.11.4
5
+ https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
6
+
7
+ Original code Copyright (c) 2023 Seldon Technologies Ltd
8
+ Licensed under Apache Software License (Apache 2.0)
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from typing import Callable
14
+
15
+ import numpy as np
16
+ import torch
17
+ from numpy.typing import ArrayLike
18
+
19
+ from dataeval.detectors.ood.base import OODScoreOutput
20
+ from dataeval.detectors.ood.base_torch import OODBase
21
+ from dataeval.interop import as_numpy
22
+ from dataeval.utils.torch.utils import predict_batch
23
+
24
+
25
+ class OOD_AE(OODBase):
26
+ """
27
+ Autoencoder based out-of-distribution detector.
28
+
29
+ Parameters
30
+ ----------
31
+ model : AriaAutoencoder
32
+ An Autoencoder model.
33
+ """
34
+
35
+ def __init__(self, model: torch.nn.Module, device: str | torch.device | None = None) -> None:
36
+ super().__init__(model, device)
37
+
38
+ def fit(
39
+ self,
40
+ x_ref: ArrayLike,
41
+ threshold_perc: float,
42
+ loss_fn: Callable[..., torch.nn.Module] | None = None,
43
+ optimizer: torch.optim.Optimizer | None = None,
44
+ epochs: int = 20,
45
+ batch_size: int = 64,
46
+ verbose: bool = False,
47
+ ) -> None:
48
+ if loss_fn is None:
49
+ loss_fn = torch.nn.MSELoss()
50
+
51
+ if optimizer is None:
52
+ optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
53
+
54
+ super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
55
+
56
+ def _score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
57
+ self._validate(X := as_numpy(X))
58
+
59
+ # reconstruct instances
60
+ X_recon = predict_batch(X, self.model, batch_size=batch_size)
61
+
62
+ # compute feature and instance level scores
63
+ fscore = np.power(X - X_recon, 2)
64
+ fscore_flat = fscore.reshape(fscore.shape[0], -1).copy()
65
+ n_score_features = int(np.ceil(fscore_flat.shape[1]))
66
+ sorted_fscore = np.sort(fscore_flat, axis=1)
67
+ sorted_fscore_perc = sorted_fscore[:, -n_score_features:]
68
+ iscore = np.mean(sorted_fscore_perc, axis=1)
69
+
70
+ return OODScoreOutput(iscore, fscore)
@@ -14,7 +14,8 @@ from typing import TYPE_CHECKING, Callable
14
14
 
15
15
  from numpy.typing import ArrayLike
16
16
 
17
- from dataeval.detectors.ood.base import OODGMMBase, OODScoreOutput
17
+ from dataeval.detectors.ood.base import OODScoreOutput
18
+ from dataeval.detectors.ood.base_tf import OODBaseGMM
18
19
  from dataeval.interop import to_numpy
19
20
  from dataeval.utils.lazy import lazyload
20
21
  from dataeval.utils.tensorflow._internal.gmm import gmm_energy
@@ -32,7 +33,7 @@ else:
32
33
  tf_models = lazyload("dataeval.utils.tensorflow._internal.models")
33
34
 
34
35
 
35
- class OOD_AEGMM(OODGMMBase):
36
+ class OOD_AEGMM(OODBaseGMM):
36
37
  """
37
38
  AE with Gaussian Mixture Model based outlier detector.
38
39
 
@@ -62,5 +63,5 @@ class OOD_AEGMM(OODGMMBase):
62
63
  def _score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
63
64
  self._validate(X := to_numpy(X))
64
65
  _, z, _ = predict_batch(X, self.model, batch_size=batch_size)
65
- energy, _ = gmm_energy(z, self.gmm_params, return_mean=False)
66
+ energy, _ = gmm_energy(z, self._gmm_params, return_mean=False)
66
67
  return OODScoreOutput(energy.numpy()) # type: ignore
@@ -12,23 +12,14 @@ __all__ = ["OODOutput", "OODScoreOutput"]
12
12
 
13
13
  from abc import ABC, abstractmethod
14
14
  from dataclasses import dataclass
15
- from typing import TYPE_CHECKING, Callable, Literal, cast
15
+ from typing import Callable, Generic, Literal, TypeVar
16
16
 
17
17
  import numpy as np
18
18
  from numpy.typing import ArrayLike, NDArray
19
19
 
20
20
  from dataeval.interop import to_numpy
21
21
  from dataeval.output import OutputMetadata, set_metadata
22
- from dataeval.utils.lazy import lazyload
23
- from dataeval.utils.tensorflow._internal.gmm import GaussianMixtureModelParams, gmm_params
24
- from dataeval.utils.tensorflow._internal.trainer import trainer
25
-
26
- if TYPE_CHECKING:
27
- import tensorflow as tf
28
- import tf_keras as keras
29
- else:
30
- tf = lazyload("tensorflow")
31
- keras = lazyload("tf_keras")
22
+ from dataeval.utils.gmm import GaussianMixtureModelParams
32
23
 
33
24
 
34
25
  @dataclass(frozen=True)
@@ -85,16 +76,62 @@ class OODScoreOutput(OutputMetadata):
85
76
  return self.instance_score if ood_type == "instance" or self.feature_score is None else self.feature_score
86
77
 
87
78
 
88
- class OODBase(ABC):
89
- def __init__(self, model: keras.Model) -> None:
90
- self.model = model
79
+ TGMMData = TypeVar("TGMMData")
80
+
81
+
82
+ class OODGMMMixin(Generic[TGMMData]):
83
+ _gmm_params: GaussianMixtureModelParams[TGMMData]
84
+
85
+
86
+ TModel = TypeVar("TModel", bound=Callable)
87
+ TLossFn = TypeVar("TLossFn", bound=Callable)
88
+ TOptimizer = TypeVar("TOptimizer")
89
+
90
+
91
+ class OODFitMixin(Generic[TLossFn, TOptimizer], ABC):
92
+ @abstractmethod
93
+ def fit(
94
+ self,
95
+ x_ref: ArrayLike,
96
+ threshold_perc: float,
97
+ loss_fn: TLossFn | None,
98
+ optimizer: TOptimizer | None,
99
+ epochs: int,
100
+ batch_size: int,
101
+ verbose: bool,
102
+ ) -> None:
103
+ """
104
+ Train the model and infer the threshold value.
105
+
106
+ Parameters
107
+ ----------
108
+ x_ref : ArrayLike
109
+ Training data.
110
+ threshold_perc : float, default 100.0
111
+ Percentage of reference data that is normal.
112
+ loss_fn : TLossFn
113
+ Loss function used for training.
114
+ optimizer : TOptimizer
115
+ Optimizer used for training.
116
+ epochs : int, default 20
117
+ Number of training epochs.
118
+ batch_size : int, default 64
119
+ Batch size used for training.
120
+ verbose : bool, default True
121
+ Whether to print training progress.
122
+ """
91
123
 
92
- self._ref_score: OODScoreOutput
93
- self._threshold_perc: float
94
- self._data_info: tuple[tuple, type] | None = None
95
124
 
96
- if not isinstance(model, keras.Model):
97
- raise TypeError("Model should be of type 'keras.Model'.")
125
+ class OODBaseMixin(Generic[TModel], ABC):
126
+ _ref_score: OODScoreOutput
127
+ _threshold_perc: float
128
+ _data_info: tuple[tuple, type] | None = None
129
+
130
+ def __init__(
131
+ self,
132
+ model: TModel,
133
+ ) -> None:
134
+ self.model = model
98
135
 
99
136
  def _get_data_info(self, X: NDArray) -> tuple[tuple, type]:
100
137
  if not isinstance(X, np.ndarray):
@@ -107,9 +144,8 @@ class OODBase(ABC):
107
144
  raise RuntimeError(f"Expect data of type: {self._data_info[1]} and shape: {self._data_info[0]}. \
108
145
  Provided data is type: {check_data_info[1]} and shape: {check_data_info[0]}.")
109
146
 
110
- def _validate_state(self, X: NDArray, additional_attrs: list[str] | None = None) -> None:
111
- attrs = ["_data_info", "_threshold_perc", "_ref_score"]
112
- attrs = attrs if additional_attrs is None else attrs + additional_attrs
147
+ def _validate_state(self, X: NDArray) -> None:
148
+ attrs = [k for c in self.__class__.mro()[:-1][::-1] if hasattr(c, "__annotations__") for k in c.__annotations__]
113
149
  if not all(hasattr(self, attr) for attr in attrs) or any(getattr(self, attr) for attr in attrs) is None:
114
150
  raise RuntimeError("Metric needs to be `fit` before method call.")
115
151
  self._validate(X)
@@ -140,52 +176,6 @@ class OODBase(ABC):
140
176
  def _threshold_score(self, ood_type: Literal["feature", "instance"] = "instance") -> np.floating:
141
177
  return np.percentile(self._ref_score.get(ood_type), self._threshold_perc)
142
178
 
143
- def fit(
144
- self,
145
- x_ref: ArrayLike,
146
- threshold_perc: float,
147
- loss_fn: Callable[..., tf.Tensor],
148
- optimizer: keras.optimizers.Optimizer,
149
- epochs: int,
150
- batch_size: int,
151
- verbose: bool,
152
- ) -> None:
153
- """
154
- Train the model and infer the threshold value.
155
-
156
- Parameters
157
- ----------
158
- x_ref : ArrayLike
159
- Training data.
160
- threshold_perc : float, default 100.0
161
- Percentage of reference data that is normal.
162
- loss_fn : Callable | None, default None
163
- Loss function used for training.
164
- optimizer : Optimizer, default keras.optimizers.Adam
165
- Optimizer used for training.
166
- epochs : int, default 20
167
- Number of training epochs.
168
- batch_size : int, default 64
169
- Batch size used for training.
170
- verbose : bool, default True
171
- Whether to print training progress.
172
- """
173
-
174
- # Train the model
175
- trainer(
176
- model=self.model,
177
- loss_fn=loss_fn,
178
- x_train=to_numpy(x_ref),
179
- optimizer=optimizer,
180
- epochs=epochs,
181
- batch_size=batch_size,
182
- verbose=verbose,
183
- )
184
-
185
- # Infer the threshold values
186
- self._ref_score = self.score(x_ref, batch_size)
187
- self._threshold_perc = threshold_perc
188
-
189
179
  @set_metadata()
190
180
  def predict(
191
181
  self,
@@ -215,43 +205,3 @@ class OODBase(ABC):
215
205
  score = self.score(X, batch_size=batch_size)
216
206
  ood_pred = score.get(ood_type) > self._threshold_score(ood_type)
217
207
  return OODOutput(is_ood=ood_pred, **score.dict())
218
-
219
-
220
- class OODGMMBase(OODBase):
221
- def __init__(self, model: keras.Model) -> None:
222
- super().__init__(model)
223
- self.gmm_params: GaussianMixtureModelParams
224
-
225
- def _validate_state(self, X: NDArray, additional_attrs: list[str] | None = None) -> None:
226
- if additional_attrs is None:
227
- additional_attrs = ["gmm_params"]
228
- super()._validate_state(X, additional_attrs)
229
-
230
- def fit(
231
- self,
232
- x_ref: ArrayLike,
233
- threshold_perc: float,
234
- loss_fn: Callable[..., tf.Tensor],
235
- optimizer: keras.optimizers.Optimizer,
236
- epochs: int,
237
- batch_size: int,
238
- verbose: bool,
239
- ) -> None:
240
- # Train the model
241
- trainer(
242
- model=self.model,
243
- loss_fn=loss_fn,
244
- x_train=to_numpy(x_ref),
245
- optimizer=optimizer,
246
- epochs=epochs,
247
- batch_size=batch_size,
248
- verbose=verbose,
249
- )
250
-
251
- # Calculate the GMM parameters
252
- _, z, gamma = cast(tuple[tf.Tensor, tf.Tensor, tf.Tensor], self.model(x_ref))
253
- self.gmm_params = gmm_params(z, gamma)
254
-
255
- # Infer the threshold values
256
- self._ref_score = self.score(x_ref, batch_size)
257
- self._threshold_perc = threshold_perc
@@ -0,0 +1,109 @@
1
+ """
2
+ Source code derived from Alibi-Detect 0.11.4
3
+ https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
4
+
5
+ Original code Copyright (c) 2023 Seldon Technologies Ltd
6
+ Licensed under Apache Software License (Apache 2.0)
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import TYPE_CHECKING, Callable, cast
12
+
13
+ from numpy.typing import ArrayLike
14
+
15
+ from dataeval.detectors.ood.base import OODBaseMixin, OODFitMixin, OODGMMMixin
16
+ from dataeval.interop import to_numpy
17
+ from dataeval.utils.lazy import lazyload
18
+ from dataeval.utils.tensorflow._internal.gmm import gmm_params
19
+ from dataeval.utils.tensorflow._internal.trainer import trainer
20
+
21
+ if TYPE_CHECKING:
22
+ import tensorflow as tf
23
+ import tf_keras as keras
24
+ else:
25
+ tf = lazyload("tensorflow")
26
+ keras = lazyload("tf_keras")
27
+
28
+
29
+ class OODBase(OODBaseMixin[keras.Model], OODFitMixin[Callable[..., tf.Tensor], keras.optimizers.Optimizer]):
30
+ def __init__(self, model: keras.Model) -> None:
31
+ super().__init__(model)
32
+
33
+ def fit(
34
+ self,
35
+ x_ref: ArrayLike,
36
+ threshold_perc: float,
37
+ loss_fn: Callable[..., tf.Tensor] | None,
38
+ optimizer: keras.optimizers.Optimizer | None,
39
+ epochs: int,
40
+ batch_size: int,
41
+ verbose: bool,
42
+ ) -> None:
43
+ """
44
+ Train the model and infer the threshold value.
45
+
46
+ Parameters
47
+ ----------
48
+ x_ref : ArrayLike
49
+ Training data.
50
+ threshold_perc : float, default 100.0
51
+ Percentage of reference data that is normal.
52
+ loss_fn : Callable | None, default None
53
+ Loss function used for training.
54
+ optimizer : Optimizer, default keras.optimizers.Adam
55
+ Optimizer used for training.
56
+ epochs : int, default 20
57
+ Number of training epochs.
58
+ batch_size : int, default 64
59
+ Batch size used for training.
60
+ verbose : bool, default True
61
+ Whether to print training progress.
62
+ """
63
+
64
+ # Train the model
65
+ trainer(
66
+ model=self.model,
67
+ loss_fn=loss_fn,
68
+ x_train=to_numpy(x_ref),
69
+ y_train=None,
70
+ optimizer=optimizer,
71
+ epochs=epochs,
72
+ batch_size=batch_size,
73
+ verbose=verbose,
74
+ )
75
+
76
+ # Infer the threshold values
77
+ self._ref_score = self.score(x_ref, batch_size)
78
+ self._threshold_perc = threshold_perc
79
+
80
+
81
+ class OODBaseGMM(OODBase, OODGMMMixin[tf.Tensor]):
82
+ def fit(
83
+ self,
84
+ x_ref: ArrayLike,
85
+ threshold_perc: float,
86
+ loss_fn: Callable[..., tf.Tensor] | None,
87
+ optimizer: keras.optimizers.Optimizer | None,
88
+ epochs: int,
89
+ batch_size: int,
90
+ verbose: bool,
91
+ ) -> None:
92
+ # Train the model
93
+ trainer(
94
+ model=self.model,
95
+ loss_fn=loss_fn,
96
+ x_train=to_numpy(x_ref),
97
+ optimizer=optimizer,
98
+ epochs=epochs,
99
+ batch_size=batch_size,
100
+ verbose=verbose,
101
+ )
102
+
103
+ # Calculate the GMM parameters
104
+ _, z, gamma = cast(tuple[tf.Tensor, tf.Tensor, tf.Tensor], self.model(x_ref))
105
+ self._gmm_params = gmm_params(z, gamma)
106
+
107
+ # Infer the threshold values
108
+ self._ref_score = self.score(x_ref, batch_size)
109
+ self._threshold_perc = threshold_perc