dataeval 0.73.1__py3-none-any.whl → 0.74.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. dataeval/__init__.py +3 -9
  2. dataeval/detectors/__init__.py +2 -10
  3. dataeval/detectors/drift/base.py +3 -3
  4. dataeval/detectors/drift/mmd.py +1 -1
  5. dataeval/detectors/drift/torch.py +1 -101
  6. dataeval/detectors/linters/clusterer.py +3 -3
  7. dataeval/detectors/linters/duplicates.py +4 -4
  8. dataeval/detectors/linters/outliers.py +4 -4
  9. dataeval/detectors/ood/__init__.py +9 -9
  10. dataeval/detectors/ood/{ae.py → ae_torch.py} +22 -27
  11. dataeval/detectors/ood/base.py +63 -113
  12. dataeval/detectors/ood/base_torch.py +109 -0
  13. dataeval/detectors/ood/metadata_ks_compare.py +52 -14
  14. dataeval/interop.py +1 -1
  15. dataeval/metrics/bias/__init__.py +3 -0
  16. dataeval/metrics/bias/balance.py +73 -70
  17. dataeval/metrics/bias/coverage.py +4 -4
  18. dataeval/metrics/bias/diversity.py +67 -136
  19. dataeval/metrics/bias/metadata_preprocessing.py +285 -0
  20. dataeval/metrics/bias/metadata_utils.py +229 -0
  21. dataeval/metrics/bias/parity.py +51 -161
  22. dataeval/metrics/estimators/ber.py +3 -3
  23. dataeval/metrics/estimators/divergence.py +3 -3
  24. dataeval/metrics/estimators/uap.py +3 -3
  25. dataeval/metrics/stats/base.py +2 -2
  26. dataeval/metrics/stats/boxratiostats.py +1 -1
  27. dataeval/metrics/stats/datasetstats.py +6 -6
  28. dataeval/metrics/stats/dimensionstats.py +1 -1
  29. dataeval/metrics/stats/hashstats.py +1 -1
  30. dataeval/metrics/stats/labelstats.py +3 -3
  31. dataeval/metrics/stats/pixelstats.py +1 -1
  32. dataeval/metrics/stats/visualstats.py +1 -1
  33. dataeval/output.py +77 -53
  34. dataeval/utils/__init__.py +1 -7
  35. dataeval/utils/gmm.py +26 -0
  36. dataeval/utils/metadata.py +29 -9
  37. dataeval/utils/torch/gmm.py +98 -0
  38. dataeval/utils/torch/models.py +192 -0
  39. dataeval/utils/torch/trainer.py +84 -5
  40. dataeval/utils/torch/utils.py +107 -1
  41. dataeval/workflows/sufficiency.py +4 -4
  42. {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/METADATA +3 -9
  43. dataeval-0.74.1.dist-info/RECORD +65 -0
  44. dataeval/detectors/ood/aegmm.py +0 -66
  45. dataeval/detectors/ood/llr.py +0 -302
  46. dataeval/detectors/ood/vae.py +0 -97
  47. dataeval/detectors/ood/vaegmm.py +0 -75
  48. dataeval/metrics/bias/metadata.py +0 -440
  49. dataeval/utils/lazy.py +0 -26
  50. dataeval/utils/tensorflow/__init__.py +0 -19
  51. dataeval/utils/tensorflow/_internal/gmm.py +0 -123
  52. dataeval/utils/tensorflow/_internal/loss.py +0 -121
  53. dataeval/utils/tensorflow/_internal/models.py +0 -1394
  54. dataeval/utils/tensorflow/_internal/trainer.py +0 -114
  55. dataeval/utils/tensorflow/_internal/utils.py +0 -256
  56. dataeval/utils/tensorflow/loss/__init__.py +0 -11
  57. dataeval-0.73.1.dist-info/RECORD +0 -73
  58. {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/LICENSE.txt +0 -0
  59. {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/WHEEL +0 -0
dataeval/__init__.py CHANGED
@@ -1,10 +1,9 @@
-__version__ = "0.73.1"
+__version__ = "0.74.1"
 
 from importlib.util import find_spec
 
 _IS_TORCH_AVAILABLE = find_spec("torch") is not None
 _IS_TORCHVISION_AVAILABLE = find_spec("torchvision") is not None
-_IS_TENSORFLOW_AVAILABLE = find_spec("tensorflow") is not None and find_spec("tensorflow_probability") is not None
 
 del find_spec
 
@@ -13,11 +12,6 @@ from dataeval import detectors, metrics  # noqa: E402
 __all__ = ["detectors", "metrics"]
 
 if _IS_TORCH_AVAILABLE:
-    from dataeval import workflows
+    from dataeval import utils, workflows
 
-    __all__ += ["workflows"]
-
-if _IS_TENSORFLOW_AVAILABLE or _IS_TORCH_AVAILABLE:
-    from dataeval import utils
-
-    __all__ += ["utils"]
+    __all__ += ["utils", "workflows"]
dataeval/detectors/__init__.py CHANGED
@@ -2,14 +2,6 @@
 Detectors can determine if a dataset or individual images in a dataset are indicative of a specific issue.
 """
 
-from dataeval import _IS_TENSORFLOW_AVAILABLE
-from dataeval.detectors import drift, linters
+from dataeval.detectors import drift, linters, ood
 
-__all__ = ["drift", "linters"]
-
-if _IS_TENSORFLOW_AVAILABLE:
-    from dataeval.detectors import ood
-
-    __all__ += ["ood"]
-
-del _IS_TENSORFLOW_AVAILABLE
+__all__ = ["drift", "linters", "ood"]
dataeval/detectors/drift/base.py CHANGED
@@ -19,7 +19,7 @@ import numpy as np
 from numpy.typing import ArrayLike, NDArray
 
 from dataeval.interop import as_numpy
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.output import Output, set_metadata
 
 R = TypeVar("R")
 
@@ -43,7 +43,7 @@ class UpdateStrategy(ABC):
 
 
 @dataclass(frozen=True)
-class DriftBaseOutput(OutputMetadata):
+class DriftBaseOutput(Output):
     """
     Base output class for Drift detector classes
 
@@ -387,7 +387,7 @@ class BaseDriftUnivariate(BaseDrift):
         else:
             raise ValueError("`correction` needs to be either `bonferroni` or `fdr`.")
 
-    @set_metadata()
+    @set_metadata
     @preprocess_x
     @update_x_ref
     def predict(
dataeval/detectors/drift/mmd.py CHANGED
@@ -161,7 +161,7 @@ class DriftMMD(BaseDrift):
         distance_threshold = torch.sort(mmd2_permuted, descending=True).values[idx_threshold]
         return p_val.numpy().item(), mmd2.numpy().item(), distance_threshold.numpy().item()
 
-    @set_metadata()
+    @set_metadata
     @preprocess_x
     @update_x_ref
     def predict(self, x: ArrayLike) -> DriftMMDOutput:
dataeval/detectors/drift/torch.py CHANGED
@@ -10,7 +10,6 @@ from __future__ import annotations
 
 __all__ = []
 
-from functools import partial
 from typing import Any, Callable
 
 import numpy as np
@@ -18,30 +17,7 @@ import torch
 import torch.nn as nn
 from numpy.typing import NDArray
 
-
-def get_device(device: str | torch.device | None = None) -> torch.device:
-    """
-    Instantiates a PyTorch device object.
-
-    Parameters
-    ----------
-    device : str | torch.device | None, default None
-        Either ``None``, a str ('gpu' or 'cpu') indicating the device to choose, or an
-        already instantiated device object. If ``None``, the GPU is selected if it is
-        detected, otherwise the CPU is used as a fallback.
-
-    Returns
-    -------
-    The instantiated device object.
-    """
-    if isinstance(device, torch.device):  # Already a torch device
-        return device
-    else:  # Instantiate device
-        if device is None or device.lower() in ["gpu", "cuda"]:
-            torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        else:
-            torch_device = torch.device("cpu")
-        return torch_device
+from dataeval.utils.torch.utils import get_device, predict_batch
 
 
 def _mmd2_from_kernel_matrix(
@@ -79,82 +55,6 @@ def _mmd2_from_kernel_matrix(
     return mmd2
 
 
-def predict_batch(
-    x: NDArray[Any] | torch.Tensor,
-    model: Callable | nn.Module | nn.Sequential,
-    device: torch.device | None = None,
-    batch_size: int = int(1e10),
-    preprocess_fn: Callable | None = None,
-    dtype: type[np.generic] | torch.dtype = np.float32,
-) -> NDArray[Any] | torch.Tensor | tuple[Any, ...]:
-    """
-    Make batch predictions on a model.
-
-    Parameters
-    ----------
-    x : np.ndarray | torch.Tensor
-        Batch of instances.
-    model : Callable | nn.Module | nn.Sequential
-        PyTorch model.
-    device : torch.device | None, default None
-        Device type used. The default None tries to use the GPU and falls back on CPU.
-        Can be specified by passing either torch.device('cuda') or torch.device('cpu').
-    batch_size : int, default 1e10
-        Batch size used during prediction.
-    preprocess_fn : Callable | None, default None
-        Optional preprocessing function for each batch.
-    dtype : np.dtype | torch.dtype, default np.float32
-        Model output type, either a :term:`NumPy` or torch dtype, e.g. np.float32 or torch.float32.
-
-    Returns
-    -------
-    NDArray | torch.Tensor | tuple
-        Numpy array, torch tensor or tuples of those with model outputs.
-    """
-    device = get_device(device)
-    if isinstance(x, np.ndarray):
-        x = torch.from_numpy(x)
-    n = len(x)
-    n_minibatch = int(np.ceil(n / batch_size))
-    return_np = not isinstance(dtype, torch.dtype)
-    preds = []
-    with torch.no_grad():
-        for i in range(n_minibatch):
-            istart, istop = i * batch_size, min((i + 1) * batch_size, n)
-            x_batch = x[istart:istop]
-            if isinstance(preprocess_fn, Callable):
-                x_batch = preprocess_fn(x_batch)
-            preds_tmp = model(x_batch.to(device))
-            if isinstance(preds_tmp, (list, tuple)):
-                if len(preds) == 0:  # init tuple with lists to store predictions
-                    preds = tuple([] for _ in range(len(preds_tmp)))
-                for j, p in enumerate(preds_tmp):
-                    if isinstance(p, torch.Tensor):
-                        p = p.cpu()
-                    preds[j].append(p if not return_np or isinstance(p, np.ndarray) else p.numpy())
-            elif isinstance(preds_tmp, (np.ndarray, torch.Tensor)):
-                if isinstance(preds_tmp, torch.Tensor):
-                    preds_tmp = preds_tmp.cpu()
-                if isinstance(preds, tuple):
-                    preds = list(preds)
-                preds.append(
-                    preds_tmp
-                    if not return_np or isinstance(preds_tmp, np.ndarray)  # type: ignore
-                    else preds_tmp.numpy()
-                )
-            else:
-                raise TypeError(
-                    f"Model output type {type(preds_tmp)} not supported. The model \
-                    output type needs to be one of list, tuple, NDArray or \
-                    torch.Tensor."
-                )
-    concat = partial(np.concatenate, axis=0) if return_np else partial(torch.cat, dim=0)
-    out: tuple | np.ndarray | torch.Tensor = (
-        tuple(concat(p) for p in preds) if isinstance(preds, tuple) else concat(preds)  # type: ignore
-    )
-    return out
-
-
 def preprocess_drift(
     x: NDArray[Any],
     model: nn.Module,
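Note: `get_device` and `predict_batch` are not deleted here, only relocated to `dataeval.utils.torch.utils` (file 40, +107 -1). A minimal sketch of calling them from their new home, assuming the signatures shown in the removed code carry over unchanged:

    import numpy as np
    import torch

    from dataeval.utils.torch.utils import get_device, predict_batch

    device = get_device(None)  # CUDA when available, else CPU
    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 8)).to(device)

    x = np.random.rand(256, 3, 32, 32).astype(np.float32)
    preds = predict_batch(x, model, device=device, batch_size=64)  # NDArray of shape (256, 8)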
dataeval/detectors/linters/clusterer.py CHANGED
@@ -11,12 +11,12 @@ from scipy.cluster.hierarchy import linkage
 from scipy.spatial.distance import pdist, squareform
 
 from dataeval.interop import to_numpy
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.output import Output, set_metadata
 from dataeval.utils.shared import flatten
 
 
 @dataclass(frozen=True)
-class ClustererOutput(OutputMetadata):
+class ClustererOutput(Output):
     """
     Output class for :class:`Clusterer` lint detector
 
@@ -495,7 +495,7 @@ class Clusterer:
         return exact_dupes, near_dupes
 
     # TODO: Move data input to evaluate from class
-    @set_metadata(["data"])
+    @set_metadata(state=["data"])
     def evaluate(self) -> ClustererOutput:
         """Finds and flags indices of the data for Outliers and :term:`duplicates<Duplicates>`
 
dataeval/detectors/linters/duplicates.py CHANGED
@@ -9,7 +9,7 @@ from numpy.typing import ArrayLike
 
 from dataeval.detectors.linters.merged_stats import combine_stats, get_dataset_step_from_idx
 from dataeval.metrics.stats.hashstats import HashStatsOutput, hashstats
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.output import Output, set_metadata
 
 DuplicateGroup = list[int]
 DatasetDuplicateGroupMap = dict[int, DuplicateGroup]
@@ -17,7 +17,7 @@ TIndexCollection = TypeVar("TIndexCollection", DuplicateGroup, DatasetDuplicateG
 
 
 @dataclass(frozen=True)
-class DuplicatesOutput(Generic[TIndexCollection], OutputMetadata):
+class DuplicatesOutput(Generic[TIndexCollection], Output):
     """
     Output class for :class:`Duplicates` lint detector
 
@@ -89,7 +89,7 @@ class Duplicates:
     @overload
     def from_stats(self, hashes: Sequence[HashStatsOutput]) -> DuplicatesOutput[DatasetDuplicateGroupMap]: ...
 
-    @set_metadata(["only_exact"])
+    @set_metadata(state=["only_exact"])
     def from_stats(
         self, hashes: HashStatsOutput | Sequence[HashStatsOutput]
     ) -> DuplicatesOutput[DuplicateGroup] | DuplicatesOutput[DatasetDuplicateGroupMap]:
@@ -138,7 +138,7 @@ class Duplicates:
 
         return DuplicatesOutput(**duplicates)
 
-    @set_metadata(["only_exact"])
+    @set_metadata(state=["only_exact"])
     def evaluate(self, data: Iterable[ArrayLike]) -> DuplicatesOutput[DuplicateGroup]:
         """
         Returns duplicate image indices for both exact matches and near matches
dataeval/detectors/linters/outliers.py CHANGED
@@ -14,7 +14,7 @@ from dataeval.metrics.stats.datasetstats import DatasetStatsOutput, datasetstats
 from dataeval.metrics.stats.dimensionstats import DimensionStatsOutput
 from dataeval.metrics.stats.pixelstats import PixelStatsOutput
 from dataeval.metrics.stats.visualstats import VisualStatsOutput
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.output import Output, set_metadata
 
 IndexIssueMap = dict[int, dict[str, float]]
 OutlierStatsOutput = Union[DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput]
@@ -22,7 +22,7 @@ TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, list[IndexIssueMap])
 
 
 @dataclass(frozen=True)
-class OutliersOutput(Generic[TIndexIssueMap], OutputMetadata):
+class OutliersOutput(Generic[TIndexIssueMap], Output):
     """
     Output class for :class:`Outliers` lint detector
 
@@ -159,7 +159,7 @@ class Outliers:
     @overload
     def from_stats(self, stats: Sequence[OutlierStatsOutput]) -> OutliersOutput[list[IndexIssueMap]]: ...
 
-    @set_metadata(["outlier_method", "outlier_threshold"])
+    @set_metadata(state=["outlier_method", "outlier_threshold"])
     def from_stats(
         self, stats: OutlierStatsOutput | DatasetStatsOutput | Sequence[OutlierStatsOutput]
    ) -> OutliersOutput[IndexIssueMap] | OutliersOutput[list[IndexIssueMap]]:
@@ -228,7 +228,7 @@ class Outliers:
 
         return OutliersOutput(output_list)
 
-    @set_metadata(["use_dimension", "use_pixel", "use_visual", "outlier_method", "outlier_threshold"])
+    @set_metadata(state=["use_dimension", "use_pixel", "use_visual", "outlier_method", "outlier_threshold"])
     def evaluate(self, data: Iterable[ArrayLike]) -> OutliersOutput[IndexIssueMap]:
         """
         Returns indices of Outliers with the issues identified for each
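Note: the decorator changes repeated throughout the linters track the rework of dataeval/output.py (file 33, +77 -53): `OutputMetadata` becomes `Output`, bare `@set_metadata()` becomes `@set_metadata`, and the positional list of captured attributes becomes an explicit `state=` keyword. A decorator usable both bare and with keyword arguments typically follows the pattern below; this is an illustrative sketch, not the actual dataeval/output.py implementation:

    from functools import wraps
    from typing import Callable

    def set_metadata(fn: Callable | None = None, *, state: list[str] | None = None) -> Callable:
        def decorator(func: Callable) -> Callable:
            @wraps(func)
            def wrapper(*args, **kwargs):
                # Hypothetical bookkeeping; the real capture logic lives in dataeval/output.py.
                return func(*args, **kwargs)
            wrapper._state = state or []  # which instance attributes to record
            return wrapper
        # Bare form (@set_metadata) passes the function directly;
        # the parameterized form (@set_metadata(state=[...])) passes fn=None.
        return decorator(fn) if callable(fn) else decorator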
dataeval/detectors/ood/__init__.py CHANGED
@@ -2,14 +2,14 @@
 Out-of-distribution (OOD) detectors identify data that is different from the data used to train a particular model.
 """
 
-from dataeval import _IS_TENSORFLOW_AVAILABLE
+from dataeval import _IS_TORCH_AVAILABLE
+from dataeval.detectors.ood.base import OODOutput, OODScoreOutput
 
-if _IS_TENSORFLOW_AVAILABLE:
-    from dataeval.detectors.ood.ae import OOD_AE
-    from dataeval.detectors.ood.aegmm import OOD_AEGMM
-    from dataeval.detectors.ood.base import OODOutput, OODScoreOutput
-    from dataeval.detectors.ood.llr import OOD_LLR
-    from dataeval.detectors.ood.vae import OOD_VAE
-    from dataeval.detectors.ood.vaegmm import OOD_VAEGMM
+__all__ = ["OODOutput", "OODScoreOutput"]
 
-    __all__ = ["OOD_AE", "OOD_AEGMM", "OOD_LLR", "OOD_VAE", "OOD_VAEGMM", "OODOutput", "OODScoreOutput"]
+if _IS_TORCH_AVAILABLE:
+    from dataeval.detectors.ood.ae_torch import OOD_AE
+
+    __all__ += ["OOD_AE"]
+
+del _IS_TORCH_AVAILABLE
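Note: `OODOutput` and `OODScoreOutput` are now importable unconditionally, while the detector itself still requires torch. Under 0.74.1 the only detector ported from the removed TensorFlow set is `OOD_AE`:

    # Always available: the output dataclasses.
    from dataeval.detectors.ood import OODOutput, OODScoreOutput

    # Available only when torch is installed.
    from dataeval.detectors.ood import OOD_AE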
dataeval/detectors/ood/{ae.py → ae_torch.py} RENAMED
@@ -1,4 +1,6 @@
 """
+Adapted for Pytorch from
+
 Source code derived from Alibi-Detect 0.11.4
 https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
 
@@ -8,55 +10,48 @@ Licensed under Apache Software License (Apache 2.0)
 
 from __future__ import annotations
 
-__all__ = ["OOD_AE"]
-
-from typing import TYPE_CHECKING, Callable
+from typing import Callable
 
 import numpy as np
+import torch
 from numpy.typing import ArrayLike
 
-from dataeval.detectors.ood.base import OODBase, OODScoreOutput
+from dataeval.detectors.ood.base import OODScoreOutput
+from dataeval.detectors.ood.base_torch import OODBase
 from dataeval.interop import as_numpy
-from dataeval.utils.lazy import lazyload
-from dataeval.utils.tensorflow._internal.utils import predict_batch
-
-if TYPE_CHECKING:
-    import tensorflow as tf
-    import tf_keras as keras
-
-    import dataeval.utils.tensorflow._internal.models as tf_models
-else:
-    tf = lazyload("tensorflow")
-    keras = lazyload("tf_keras")
-    tf_models = lazyload("dataeval.utils.tensorflow._internal.models")
+from dataeval.utils.torch.utils import predict_batch
 
 
 class OOD_AE(OODBase):
     """
-    Autoencoder-based :term:`out of distribution<Out-of-distribution (OOD)>` detector.
+    Autoencoder based out-of-distribution detector.
 
     Parameters
     ----------
-    model : AE
-        An :term:`autoencoder<Autoencoder>` model.
+    model : AriaAutoencoder
+        An Autoencoder model.
     """
 
-    def __init__(self, model: tf_models.AE) -> None:
-        super().__init__(model)
+    def __init__(self, model: torch.nn.Module, device: str | torch.device | None = None) -> None:
+        super().__init__(model, device)
 
     def fit(
         self,
         x_ref: ArrayLike,
-        threshold_perc: float = 100.0,
-        loss_fn: Callable[..., tf.Tensor] | None = None,
-        optimizer: keras.optimizers.Optimizer | None = None,
+        threshold_perc: float,
+        loss_fn: Callable[..., torch.nn.Module] | None = None,
+        optimizer: torch.optim.Optimizer | None = None,
         epochs: int = 20,
         batch_size: int = 64,
-        verbose: bool = True,
+        verbose: bool = False,
     ) -> None:
         if loss_fn is None:
-            loss_fn = keras.losses.MeanSquaredError()
-        super().fit(as_numpy(x_ref), threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
+            loss_fn = torch.nn.MSELoss()
+
+        if optimizer is None:
+            optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
+
+        super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
 
     def _score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
         self._validate(X := as_numpy(X))
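Note: a minimal end-to-end sketch of the ported detector. `TinyAE` is a hypothetical stand-in for a user-supplied autoencoder (the new docstring points at `AriaAutoencoder`), and `threshold_perc` is now a required argument to `fit`:

    import numpy as np
    import torch

    from dataeval.detectors.ood import OOD_AE

    class TinyAE(torch.nn.Module):  # hypothetical model, for illustration only
        def __init__(self) -> None:
            super().__init__()
            self.encoder = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3, padding=1), torch.nn.ReLU())
            self.decoder = torch.nn.Conv2d(8, 3, 3, padding=1)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.decoder(self.encoder(x))

    x_ref = np.random.rand(128, 3, 32, 32).astype(np.float32)

    detector = OOD_AE(TinyAE(), device="cpu")
    detector.fit(x_ref, threshold_perc=99.0)  # MSELoss and Adam(lr=0.001) by default
    print(detector.predict(x_ref).is_ood)     # boolean array over the 128 instances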
dataeval/detectors/ood/base.py CHANGED
@@ -12,27 +12,18 @@ __all__ = ["OODOutput", "OODScoreOutput"]
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Literal, cast
+from typing import Callable, Generic, Literal, TypeVar
 
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
 
 from dataeval.interop import to_numpy
-from dataeval.output import OutputMetadata, set_metadata
-from dataeval.utils.lazy import lazyload
-from dataeval.utils.tensorflow._internal.gmm import GaussianMixtureModelParams, gmm_params
-from dataeval.utils.tensorflow._internal.trainer import trainer
-
-if TYPE_CHECKING:
-    import tensorflow as tf
-    import tf_keras as keras
-else:
-    tf = lazyload("tensorflow")
-    keras = lazyload("tf_keras")
+from dataeval.output import Output, set_metadata
+from dataeval.utils.gmm import GaussianMixtureModelParams
 
 
 @dataclass(frozen=True)
-class OODOutput(OutputMetadata):
+class OODOutput(Output):
     """
     Output class for predictions from :class:`OOD_AE`, :class:`OOD_AEGMM`, :class:`OOD_LLR`,
     :class:`OOD_VAE`, and :class:`OOD_VAEGMM` out-of-distribution detectors
@@ -53,7 +44,7 @@ class OODOutput(OutputMetadata):
 
 
 @dataclass(frozen=True)
-class OODScoreOutput(OutputMetadata):
+class OODScoreOutput(Output):
     """
     Output class for instance and feature scores from :class:`OOD_AE`, :class:`OOD_AEGMM`,
     :class:`OOD_LLR`, :class:`OOD_VAE`, and :class:`OOD_VAEGMM` out-of-distribution detectors
@@ -85,16 +76,62 @@ class OODScoreOutput(OutputMetadata):
         return self.instance_score if ood_type == "instance" or self.feature_score is None else self.feature_score
 
 
-class OODBase(ABC):
-    def __init__(self, model: keras.Model) -> None:
-        self.model = model
+TGMMData = TypeVar("TGMMData")
+
+
+class OODGMMMixin(Generic[TGMMData]):
+    _gmm_params: GaussianMixtureModelParams[TGMMData]
+
 
-        self._ref_score: OODScoreOutput
-        self._threshold_perc: float
-        self._data_info: tuple[tuple, type] | None = None
+TModel = TypeVar("TModel", bound=Callable)
+TLossFn = TypeVar("TLossFn", bound=Callable)
+TOptimizer = TypeVar("TOptimizer")
+
+
+class OODFitMixin(Generic[TLossFn, TOptimizer], ABC):
+    @abstractmethod
+    def fit(
+        self,
+        x_ref: ArrayLike,
+        threshold_perc: float,
+        loss_fn: TLossFn | None,
+        optimizer: TOptimizer | None,
+        epochs: int,
+        batch_size: int,
+        verbose: bool,
+    ) -> None:
+        """
+        Train the model and infer the threshold value.
 
-        if not isinstance(model, keras.Model):
-            raise TypeError("Model should be of type 'keras.Model'.")
+        Parameters
+        ----------
+        x_ref : ArrayLike
+            Training data.
+        threshold_perc : float, default 100.0
+            Percentage of reference data that is normal.
+        loss_fn : TLossFn
+            Loss function used for training.
+        optimizer : TOptimizer
+            Optimizer used for training.
+        epochs : int, default 20
+            Number of training epochs.
+        batch_size : int, default 64
+            Batch size used for training.
+        verbose : bool, default True
+            Whether to print training progress.
+        """
+
+
+class OODBaseMixin(Generic[TModel], ABC):
+    _ref_score: OODScoreOutput
+    _threshold_perc: float
+    _data_info: tuple[tuple, type] | None = None
+
+    def __init__(
+        self,
+        model: TModel,
+    ) -> None:
+        self.model = model
 
     def _get_data_info(self, X: NDArray) -> tuple[tuple, type]:
         if not isinstance(X, np.ndarray):
@@ -107,9 +144,8 @@ class OODBase(ABC):
             raise RuntimeError(f"Expect data of type: {self._data_info[1]} and shape: {self._data_info[0]}. \
                 Provided data is type: {check_data_info[1]} and shape: {check_data_info[0]}.")
 
-    def _validate_state(self, X: NDArray, additional_attrs: list[str] | None = None) -> None:
-        attrs = ["_data_info", "_threshold_perc", "_ref_score"]
-        attrs = attrs if additional_attrs is None else attrs + additional_attrs
+    def _validate_state(self, X: NDArray) -> None:
+        attrs = [k for c in self.__class__.mro()[:-1][::-1] if hasattr(c, "__annotations__") for k in c.__annotations__]
         if not all(hasattr(self, attr) for attr in attrs) or any(getattr(self, attr) for attr in attrs) is None:
             raise RuntimeError("Metric needs to be `fit` before method call.")
         self._validate(X)
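Note: the rewritten `_validate_state` replaces the hard-coded attribute list plus `additional_attrs` hook with introspection: it walks the MRO (minus `object`) base-most first and collects every class-level annotation, so the `_gmm_params` annotation contributed by `OODGMMMixin` is picked up automatically by any detector that mixes it in. The idiom in isolation:

    class Base:
        _ref_score: str
        _threshold_perc: float

    class GMMMixin:
        _gmm_params: dict

    class Detector(GMMMixin, Base): ...

    attrs = [
        k
        for c in Detector.mro()[:-1][::-1]  # drop object, walk base-most first
        if hasattr(c, "__annotations__")
        for k in c.__annotations__
    ]
    print(attrs)  # ['_ref_score', '_threshold_perc', '_gmm_params'] on Python 3.10+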
@@ -117,7 +153,7 @@ class OODBase(ABC):
     @abstractmethod
     def _score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput: ...
 
-    @set_metadata()
+    @set_metadata
     def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
         """
         Compute the :term:`out of distribution<Out-of-distribution (OOD)>` scores for a given dataset.
@@ -140,53 +176,7 @@ class OODBase(ABC):
     def _threshold_score(self, ood_type: Literal["feature", "instance"] = "instance") -> np.floating:
         return np.percentile(self._ref_score.get(ood_type), self._threshold_perc)
 
-    def fit(
-        self,
-        x_ref: ArrayLike,
-        threshold_perc: float,
-        loss_fn: Callable[..., tf.Tensor],
-        optimizer: keras.optimizers.Optimizer,
-        epochs: int,
-        batch_size: int,
-        verbose: bool,
-    ) -> None:
-        """
-        Train the model and infer the threshold value.
-
-        Parameters
-        ----------
-        x_ref : ArrayLike
-            Training data.
-        threshold_perc : float, default 100.0
-            Percentage of reference data that is normal.
-        loss_fn : Callable | None, default None
-            Loss function used for training.
-        optimizer : Optimizer, default keras.optimizers.Adam
-            Optimizer used for training.
-        epochs : int, default 20
-            Number of training epochs.
-        batch_size : int, default 64
-            Batch size used for training.
-        verbose : bool, default True
-            Whether to print training progress.
-        """
-
-        # Train the model
-        trainer(
-            model=self.model,
-            loss_fn=loss_fn,
-            x_train=to_numpy(x_ref),
-            optimizer=optimizer,
-            epochs=epochs,
-            batch_size=batch_size,
-            verbose=verbose,
-        )
-
-        # Infer the threshold values
-        self._ref_score = self.score(x_ref, batch_size)
-        self._threshold_perc = threshold_perc
-
-    @set_metadata()
+    @set_metadata
     def predict(
         self,
         X: ArrayLike,
@@ -215,43 +205,3 @@ class OODBase(ABC):
         score = self.score(X, batch_size=batch_size)
         ood_pred = score.get(ood_type) > self._threshold_score(ood_type)
         return OODOutput(is_ood=ood_pred, **score.dict())
-
-
-class OODGMMBase(OODBase):
-    def __init__(self, model: keras.Model) -> None:
-        super().__init__(model)
-        self.gmm_params: GaussianMixtureModelParams
-
-    def _validate_state(self, X: NDArray, additional_attrs: list[str] | None = None) -> None:
-        if additional_attrs is None:
-            additional_attrs = ["gmm_params"]
-        super()._validate_state(X, additional_attrs)
-
-    def fit(
-        self,
-        x_ref: ArrayLike,
-        threshold_perc: float,
-        loss_fn: Callable[..., tf.Tensor],
-        optimizer: keras.optimizers.Optimizer,
-        epochs: int,
-        batch_size: int,
-        verbose: bool,
-    ) -> None:
-        # Train the model
-        trainer(
-            model=self.model,
-            loss_fn=loss_fn,
-            x_train=to_numpy(x_ref),
-            optimizer=optimizer,
-            epochs=epochs,
-            batch_size=batch_size,
-            verbose=verbose,
-        )
-
-        # Calculate the GMM parameters
-        _, z, gamma = cast(tuple[tf.Tensor, tf.Tensor, tf.Tensor], self.model(x_ref))
-        self.gmm_params = gmm_params(z, gamma)
-
-        # Infer the threshold values
-        self._ref_score = self.score(x_ref, batch_size)
-        self._threshold_perc = threshold_perc