dataeval 0.64.0__py3-none-any.whl → 0.66.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. dataeval/__init__.py +13 -9
  2. dataeval/_internal/detectors/clusterer.py +63 -49
  3. dataeval/_internal/detectors/drift/base.py +248 -51
  4. dataeval/_internal/detectors/drift/cvm.py +28 -26
  5. dataeval/_internal/detectors/drift/ks.py +31 -28
  6. dataeval/_internal/detectors/drift/mmd.py +62 -42
  7. dataeval/_internal/detectors/drift/torch.py +69 -60
  8. dataeval/_internal/detectors/drift/uncertainty.py +32 -32
  9. dataeval/_internal/detectors/duplicates.py +67 -31
  10. dataeval/_internal/detectors/ood/ae.py +15 -29
  11. dataeval/_internal/detectors/ood/aegmm.py +33 -27
  12. dataeval/_internal/detectors/ood/base.py +86 -47
  13. dataeval/_internal/detectors/ood/llr.py +34 -31
  14. dataeval/_internal/detectors/ood/vae.py +32 -31
  15. dataeval/_internal/detectors/ood/vaegmm.py +34 -28
  16. dataeval/_internal/detectors/{linter.py → outliers.py} +60 -38
  17. dataeval/_internal/flags.py +44 -21
  18. dataeval/_internal/interop.py +5 -3
  19. dataeval/_internal/metrics/balance.py +42 -5
  20. dataeval/_internal/metrics/ber.py +11 -8
  21. dataeval/_internal/metrics/coverage.py +15 -8
  22. dataeval/_internal/metrics/divergence.py +41 -7
  23. dataeval/_internal/metrics/diversity.py +57 -19
  24. dataeval/_internal/metrics/parity.py +141 -66
  25. dataeval/_internal/metrics/stats.py +330 -313
  26. dataeval/_internal/metrics/uap.py +33 -4
  27. dataeval/_internal/metrics/utils.py +79 -40
  28. dataeval/_internal/models/pytorch/autoencoder.py +127 -22
  29. dataeval/_internal/models/tensorflow/autoencoder.py +33 -30
  30. dataeval/_internal/models/tensorflow/gmm.py +4 -2
  31. dataeval/_internal/models/tensorflow/losses.py +17 -13
  32. dataeval/_internal/models/tensorflow/pixelcnn.py +19 -18
  33. dataeval/_internal/models/tensorflow/trainer.py +10 -7
  34. dataeval/_internal/models/tensorflow/utils.py +23 -20
  35. dataeval/_internal/output.py +85 -0
  36. dataeval/_internal/utils.py +5 -3
  37. dataeval/_internal/workflows/sufficiency.py +122 -121
  38. dataeval/detectors/__init__.py +6 -25
  39. dataeval/detectors/drift/__init__.py +16 -0
  40. dataeval/detectors/drift/kernels/__init__.py +6 -0
  41. dataeval/detectors/drift/updates/__init__.py +3 -0
  42. dataeval/detectors/linters/__init__.py +5 -0
  43. dataeval/detectors/ood/__init__.py +11 -0
  44. dataeval/flags/__init__.py +2 -2
  45. dataeval/metrics/__init__.py +2 -26
  46. dataeval/metrics/bias/__init__.py +14 -0
  47. dataeval/metrics/estimators/__init__.py +9 -0
  48. dataeval/metrics/stats/__init__.py +6 -0
  49. dataeval/tensorflow/__init__.py +3 -0
  50. dataeval/tensorflow/loss/__init__.py +3 -0
  51. dataeval/tensorflow/models/__init__.py +5 -0
  52. dataeval/tensorflow/recon/__init__.py +3 -0
  53. dataeval/torch/__init__.py +3 -0
  54. dataeval/{models/torch → torch/models}/__init__.py +1 -2
  55. dataeval/torch/trainer/__init__.py +3 -0
  56. dataeval/utils/__init__.py +3 -6
  57. dataeval/workflows/__init__.py +2 -4
  58. {dataeval-0.64.0.dist-info → dataeval-0.66.0.dist-info}/METADATA +1 -1
  59. dataeval-0.66.0.dist-info/RECORD +72 -0
  60. dataeval/_internal/metrics/base.py +0 -10
  61. dataeval/models/__init__.py +0 -15
  62. dataeval/models/tensorflow/__init__.py +0 -6
  63. dataeval-0.64.0.dist-info/RECORD +0 -60
  64. {dataeval-0.64.0.dist-info → dataeval-0.66.0.dist-info}/LICENSE.txt +0 -0
  65. {dataeval-0.64.0.dist-info → dataeval-0.66.0.dist-info}/WHEEL +0 -0
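The file list shows the public namespace being split from flat `detectors`/`metrics` modules into focused subpackages (`detectors.drift`, `detectors.linters`, `detectors.ood`, `metrics.bias`, `metrics.estimators`, `metrics.stats`, plus framework-specific `tensorflow`/`torch` trees), with `detectors/linter.py` renamed to `outliers.py` and `dataeval/models/torch` moved to `dataeval/torch/models`. A hedged sketch of what an import migration likely looks like; only the `__init__.py` paths are visible in this diff, so the exact re-exported names are assumptions:

# 0.64.0 (flat namespaces)
# from dataeval.detectors import Duplicates

# 0.66.0 (split namespaces) -- assumed re-exports based on the paths above
from dataeval.detectors.linters import Duplicates  # the renamed outliers.py lives here too
from dataeval.metrics.stats import imagestats      # stats computation consumed by Duplicates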
dataeval/_internal/detectors/duplicates.py

@@ -1,9 +1,28 @@
-from typing import Dict, Iterable, List, Literal
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Iterable
 
 from numpy.typing import ArrayLike
 
-from dataeval._internal.flags import ImageHash
-from dataeval._internal.metrics.stats import ImageStats
+from dataeval._internal.flags import ImageStat
+from dataeval._internal.metrics.stats import StatsOutput, imagestats
+from dataeval._internal.output import OutputMetadata, set_metadata
+
+
+@dataclass(frozen=True)
+class DuplicatesOutput(OutputMetadata):
+    """
+    Attributes
+    ----------
+    exact : List[List[int]]
+        Indices of images that are exact matches
+    near: List[List[int]]
+        Indices of images that are near matches
+    """
+
+    exact: list[list[int]]
+    near: list[list[int]]
 
 
 class Duplicates:
@@ -13,8 +32,13 @@ class Duplicates:
 
     Attributes
    ----------
-    stats : ImageStats(flags=ImageHash.ALL)
-        Base stats class with the flags for checking duplicates
+    stats : StatsOutput
+        Output class of stats
+
+    Parameters
+    ----------
+    only_exact : bool, default False
+        Only inspect the dataset for exact image matches
 
     Example
     -------
@@ -23,51 +47,63 @@ class Duplicates:
     >>> dups = Duplicates()
     """
 
-    def __init__(self):
-        self.stats = ImageStats(ImageHash.ALL)
+    def __init__(self, only_exact: bool = False):
+        self.stats: StatsOutput
+        self.only_exact = only_exact
+
+    def _get_duplicates(self) -> dict[str, list[list[int]]]:
+        stats_dict = self.stats.dict()
+        if "xxhash" in stats_dict:
+            exact = {}
+            for i, value in enumerate(stats_dict["xxhash"]):
+                exact.setdefault(value, []).append(i)
+            exact = [v for v in exact.values() if len(v) > 1]
+        else:
+            exact = []
 
-    def _get_duplicates(self) -> dict:
-        exact = {}
-        near = {}
-        for i, value in enumerate(self.results["xxhash"]):
-            exact.setdefault(value, []).append(i)
-        for i, value in enumerate(self.results["pchash"]):
-            near.setdefault(value, []).append(i)
-        exact = [v for v in exact.values() if len(v) > 1]
-        near = [v for v in near.values() if len(v) > 1 and not any(set(v).issubset(x) for x in exact)]
+        if "pchash" in stats_dict and not self.only_exact:
+            near = {}
+            for i, value in enumerate(stats_dict["pchash"]):
+                near.setdefault(value, []).append(i)
+            near = [v for v in near.values() if len(v) > 1 and not any(set(v).issubset(x) for x in exact)]
+        else:
+            near = []
 
         return {
             "exact": sorted(exact),
            "near": sorted(near),
         }
 
-    def evaluate(self, images: Iterable[ArrayLike]) -> Dict[Literal["exact", "near"], List[int]]:
+    @set_metadata("dataeval.detectors", ["only_exact"])
+    def evaluate(self, data: Iterable[ArrayLike] | StatsOutput) -> DuplicatesOutput:
         """
         Returns duplicate image indices for both exact matches and near matches
 
         Parameters
         ----------
-        images : Iterable[ArrayLike], shape - (N, C, H, W)
-            A set of images in an ArrayLike format
+        data : Iterable[ArrayLike], shape - (N, C, H, W) | StatsOutput
+            A dataset of images in an ArrayLike format or the output from an imagestats metric analysis
 
         Returns
         -------
-        Dict[str, List[int]]
-            exact :
-                List of groups of indices that are exact matches
-            near :
-                List of groups of indices that are near matches
+        DuplicatesOutput
+            List of groups of indices that are exact and near matches
 
         See Also
         --------
-        ImageStats
+        imagestats
 
         Example
         -------
         >>> dups.evaluate(images)
-        {'exact': [[3, 20], [16, 37]], 'near': [[3, 20, 22], [12, 18], [13, 36], [14, 31], [17, 27], [19, 38, 47]]}
-        """
-        self.stats.reset()
-        self.stats.update(images)
-        self.results = self.stats.compute()
-        return self._get_duplicates()
+        DuplicatesOutput(exact=[[3, 20], [16, 37]], near=[[3, 20, 22], [12, 18], [13, 36], [14, 31], [17, 27], [19, 38, 47]])
+        """  # noqa: E501
+        if isinstance(data, StatsOutput):
+            if not data.xxhash:
+                raise ValueError("StatsOutput must include xxhash information of the images.")
+            if not self.only_exact and not data.pchash:
+                raise ValueError("StatsOutput must include pchash information of the images for near matches.")
+            self.stats = data
+        else:
+            self.stats = imagestats(data, ImageStat.XXHASH | (ImageStat(0) if self.only_exact else ImageStat.PCHASH))
+        return DuplicatesOutput(**self._get_duplicates())
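`evaluate` now accepts either raw images or a precomputed `StatsOutput`, so hashing can be done once via `imagestats` and shared across detectors. A minimal usage sketch built from the names in this hunk (internal import paths as shown in the diff; the random `images` array is an illustrative stand-in for a real dataset):

import numpy as np

from dataeval._internal.detectors.duplicates import Duplicates
from dataeval._internal.flags import ImageStat
from dataeval._internal.metrics.stats import imagestats

rng = np.random.default_rng(0)
images = rng.integers(0, 255, size=(8, 3, 16, 16), dtype=np.uint8)
images[7] = images[0]  # plant one exact duplicate

# Pass images directly: only XXHASH is computed when only_exact=True
dups = Duplicates(only_exact=True)
print(dups.evaluate(images).exact)  # expect [[0, 7]]

# Or reuse a precomputed StatsOutput that already carries both hash types
stats = imagestats(images, ImageStat.XXHASH | ImageStat.PCHASH)
result = Duplicates().evaluate(stats)
print(result.exact, result.near)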
dataeval/_internal/detectors/ood/ae.py

@@ -6,10 +6,13 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
 Licensed under Apache Software License (Apache 2.0)
 """
 
+from __future__ import annotations
+
 from typing import Callable
 
 import keras
 import numpy as np
+import tensorflow as tf
 from numpy.typing import ArrayLike
 
 from dataeval._internal.detectors.ood.base import OODBase, OODScore
@@ -19,47 +22,30 @@ from dataeval._internal.models.tensorflow.utils import predict_batch
 
 
 class OOD_AE(OODBase):
-    def __init__(self, model: AE) -> None:
-        """
-        Autoencoder based out-of-distribution detector.
+    """
+    Autoencoder based out-of-distribution detector.
+
+    Parameters
+    ----------
+    model : AE
+        An Autoencoder model.
+    """
 
-        Parameters
-        ----------
-        model : AE
-            An Autoencoder model.
-        """
+    def __init__(self, model: AE) -> None:
         super().__init__(model)
 
     def fit(
         self,
         x_ref: ArrayLike,
         threshold_perc: float = 100.0,
-        loss_fn: Callable = keras.losses.MeanSquaredError(),
+        loss_fn: Callable[..., tf.Tensor] | None = None,
         optimizer: keras.optimizers.Optimizer = keras.optimizers.Adam,
         epochs: int = 20,
         batch_size: int = 64,
         verbose: bool = True,
     ) -> None:
-        """
-        Train the AE model with recommended loss function and optimizer.
-
-        Parameters
-        ----------
-        x_ref : ArrayLike
-            Training batch.
-        threshold_perc : float, default 100.0
-            Percentage of reference data that is normal.
-        loss_fn : Callable, default keras.losses.MeanSquaredError()
-            Loss function used for training.
-        optimizer : keras.optimizers.Optimizer, default keras.optimizers.Adam
-            Optimizer used for training.
-        epochs : int, default 20
-            Number of training epochs.
-        batch_size : int, default 64
-            Batch size used for training.
-        verbose : bool, default True
-            Whether to print training progress.
-        """
+        if loss_fn is None:
+            loss_fn = keras.losses.MeanSquaredError()
         super().fit(to_numpy(x_ref), threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
 
     def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScore:
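The notable change here is dropping the mutable call-time default `loss_fn=keras.losses.MeanSquaredError()` in favor of a `None` sentinel resolved inside `fit`. Python evaluates default arguments once at function definition, so the old signature constructed a single loss object at import time and shared it across every `fit` call. A self-contained sketch of the difference, using illustrative names rather than dataeval API:

from __future__ import annotations


class StatefulLoss:
    """Stand-in for a Keras loss object that accumulates internal state."""

    def __init__(self) -> None:
        self.calls = 0

    def __call__(self) -> int:
        self.calls += 1
        return self.calls


def fit_shared(loss_fn: StatefulLoss = StatefulLoss()) -> int:
    # Default is created once at definition time and reused on every call.
    return loss_fn()


def fit_fresh(loss_fn: StatefulLoss | None = None) -> int:
    # Sentinel default: a new loss object is created per call, as in OOD_AE.fit.
    if loss_fn is None:
        loss_fn = StatefulLoss()
    return loss_fn()


print(fit_shared(), fit_shared())  # 1 2 -- state leaks between calls
print(fit_fresh(), fit_fresh())    # 1 1 -- each call starts clean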
dataeval/_internal/detectors/ood/aegmm.py

@@ -6,9 +6,12 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
 Licensed under Apache Software License (Apache 2.0)
 """
 
+from __future__ import annotations
+
 from typing import Callable
 
 import keras
+import tensorflow as tf
 from numpy.typing import ArrayLike
 
 from dataeval._internal.detectors.ood.base import OODGMMBase, OODScore
@@ -20,50 +23,53 @@ from dataeval._internal.models.tensorflow.utils import predict_batch
 
 
 class OOD_AEGMM(OODGMMBase):
-    def __init__(self, model: AEGMM) -> None:
-        """
-        AE with Gaussian Mixture Model based outlier detector.
+    """
+    AE with Gaussian Mixture Model based outlier detector.
 
-        Parameters
-        ----------
-        model : AEGMM
-            An AEGMM model.
-        """
+    Parameters
+    ----------
+    model : AEGMM
+        An AEGMM model.
+    """
+
+    def __init__(self, model: AEGMM) -> None:
         super().__init__(model)
 
     def fit(
         self,
         x_ref: ArrayLike,
         threshold_perc: float = 100.0,
-        loss_fn: Callable = LossGMM(),
+        loss_fn: Callable[..., tf.Tensor] | None = None,
         optimizer: keras.optimizers.Optimizer = keras.optimizers.Adam,
         epochs: int = 20,
         batch_size: int = 64,
         verbose: bool = True,
     ) -> None:
+        if loss_fn is None:
+            loss_fn = LossGMM()
+        super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
+
+    def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScore:
         """
-        Train the AEGMM model with recommended loss function and optimizer.
+        Compute the out-of-distribution (OOD) score for a given dataset.
 
         Parameters
         ----------
-        x_ref : ArrayLike
-            Training batch.
-        threshold_perc : float, default 100.0
-            Percentage of reference data that is normal.
-        loss_fn : Callable, default LossGMM()
-            Loss function used for training.
-        optimizer : keras.optimizers.Optimizer, default keras.optimizers.Adam
-            Optimizer used for training.
-        epochs : int, default 20
-            Number of training epochs.
-        batch_size : int, default 64
-            Batch size used for training.
-        verbose : bool, default True
-            Whether to print training progress.
-        """
-        super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
+        X : ArrayLike
+            Input data to score.
+        batch_size : int, default 1e10
+            Number of instances to process in each batch.
+            Use a smaller batch size if your dataset is large or if you encounter memory issues.
 
-    def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScore:
+        Returns
+        -------
+        OODScore
+            An object containing the instance-level OOD score.
+
+        Note
+        ----
+        This model does not produce a feature level score like the OOD_AE or OOD_VAE models.
+        """
         self._validate(X := to_numpy(X))
         _, z, _ = predict_batch(X, self.model, batch_size=batch_size)
         energy, _ = gmm_energy(z, self.gmm_params, return_mean=False)
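As the new `score` docstring notes, AEGMM produces only an instance-level score. The `OODScore.get` method added in the base.py hunks below makes this safe to consume uniformly: requesting the "feature" score falls back to the instance score when `feature_score` is `None`. A small sketch of that behavior:

import numpy as np

from dataeval._internal.detectors.ood.base import OODScore

score = OODScore(instance_score=np.array([0.1, 0.9], dtype=np.float32))

# feature_score defaults to None for AEGMM-style detectors,
# so "feature" transparently falls back to the instance score
assert score.get("instance") is score.instance_score
assert score.get("feature") is score.instance_score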
dataeval/_internal/detectors/ood/base.py

@@ -6,17 +6,39 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
 Licensed under Apache Software License (Apache 2.0)
 """
 
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
-from typing import Callable, Dict, List, Literal, NamedTuple, Optional, Tuple, cast
+from dataclasses import dataclass
+from typing import Callable, Literal, NamedTuple, cast
 
 import keras
 import numpy as np
 import tensorflow as tf
-from numpy.typing import ArrayLike
+from numpy.typing import ArrayLike, NDArray
 
 from dataeval._internal.interop import to_numpy
 from dataeval._internal.models.tensorflow.gmm import GaussianMixtureModelParams, gmm_params
 from dataeval._internal.models.tensorflow.trainer import trainer
+from dataeval._internal.output import OutputMetadata, set_metadata
+
+
+@dataclass(frozen=True)
+class OODOutput(OutputMetadata):
+    """
+    Attributes
+    ----------
+    is_ood : NDArray
+        Array of images that are detected as out of distribution
+    instance_score : NDArray
+        Instance score of the evaluated dataset
+    feature_score : NDArray | None
+        Feature score, if available, of the evaluated dataset
+    """
+
+    is_ood: NDArray[np.bool_]
+    instance_score: NDArray[np.float32]
+    feature_score: NDArray[np.float32] | None
 
 
 class OODScore(NamedTuple):
@@ -25,16 +47,28 @@ class OODScore(NamedTuple):
 
     Parameters
     ----------
-    instance_score : np.ndarray
+    instance_score : NDArray
         Instance score of the evaluated dataset.
-    feature_score : Optional[np.ndarray], default None
+    feature_score : NDArray | None, default None
         Feature score, if available, of the evaluated dataset.
     """
 
-    instance_score: np.ndarray
-    feature_score: Optional[np.ndarray] = None
+    instance_score: NDArray[np.float32]
+    feature_score: NDArray[np.float32] | None = None
+
+    def get(self, ood_type: Literal["instance", "feature"]) -> NDArray:
+        """
+        Returns either the instance or feature score
+
+        Parameters
+        ----------
+        ood_type : "instance" | "feature"
 
-    def get(self, ood_type: Literal["instance", "feature"]) -> np.ndarray:
+        Returns
+        -------
+        NDArray
+            Either the instance or feature score based on input selection
+        """
         return self.instance_score if ood_type == "instance" or self.feature_score is None else self.feature_score
 
 
@@ -44,23 +78,23 @@ class OODBase(ABC):
 
         self._ref_score: OODScore
         self._threshold_perc: float
-        self._data_info: Optional[Tuple[tuple, type]] = None
+        self._data_info: tuple[tuple, type] | None = None
 
         if not isinstance(model, keras.Model):
             raise TypeError("Model should be of type 'keras.Model'.")
 
-    def _get_data_info(self, X: np.ndarray) -> Tuple[tuple, type]:
+    def _get_data_info(self, X: NDArray) -> tuple[tuple, type]:
         if not isinstance(X, np.ndarray):
-            raise TypeError("Dataset should of type: `np.ndarray`.")
+            raise TypeError("Dataset should of type: `NDArray`.")
         return X.shape[1:], X.dtype.type
 
-    def _validate(self, X: np.ndarray) -> None:
+    def _validate(self, X: NDArray) -> None:
         check_data_info = self._get_data_info(X)
         if self._data_info is not None and check_data_info != self._data_info:
             raise RuntimeError(f"Expect data of type: {self._data_info[1]} and shape: {self._data_info[0]}. \
             Provided data is type: {check_data_info[1]} and shape: {check_data_info[0]}.")
 
-    def _validate_state(self, X: np.ndarray, additional_attrs: Optional[List[str]] = None) -> None:
+    def _validate_state(self, X: NDArray, additional_attrs: list[str] | None = None) -> None:
         attrs = ["_data_info", "_threshold_perc", "_ref_score"]
         attrs = attrs if additional_attrs is None else attrs + additional_attrs
         if not all(hasattr(self, attr) for attr in attrs) or any(getattr(self, attr) for attr in attrs) is None:
@@ -70,18 +104,20 @@ class OODBase(ABC):
     @abstractmethod
     def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScore:
         """
-        Compute instance and (optionally) feature level outlier scores.
+        Compute the out-of-distribution (OOD) scores for a given dataset.
 
         Parameters
         ----------
         X : ArrayLike
-            Batch of instances.
-        batch_size : int, default int(1e10)
-            Batch size used when making predictions with the autoencoder.
+            Input data to score.
+        batch_size : int, default 1e10
+            Number of instances to process in each batch.
+            Use a smaller batch size if your dataset is large or if you encounter memory issues.
 
         Returns
         -------
-        Instance and feature level outlier scores.
+        OODScore
+            An object containing the instance-level and feature-level OOD scores.
         """
 
     def _threshold_score(self, ood_type: Literal["feature", "instance"] = "instance") -> np.floating:
@@ -90,33 +126,34 @@
     def fit(
         self,
         x_ref: ArrayLike,
-        threshold_perc: float,
-        loss_fn: Callable,
-        optimizer: keras.optimizers.Optimizer,
-        epochs: int,
-        batch_size: int,
-        verbose: bool,
+        threshold_perc: float = 100.0,
+        loss_fn: Callable[..., tf.Tensor] | None = None,
+        optimizer: keras.optimizers.Optimizer = keras.optimizers.Adam,
+        epochs: int = 20,
+        batch_size: int = 64,
+        verbose: bool = True,
     ) -> None:
         """
         Train the model and infer the threshold value.
 
         Parameters
         ----------
-        x_ref: : ArrayLike
-            Training batch.
-        threshold_perc : float
+        x_ref : ArrayLike
+            Training data.
+        threshold_perc : float, default 100.0
             Percentage of reference data that is normal.
-        loss_fn : Callable
+        loss_fn : Callable | None, default None
             Loss function used for training.
-        optimizer : keras.optimizers.Optimizer
+        optimizer : Optimizer, default keras.optimizers.Adam
             Optimizer used for training.
-        epochs : int
+        epochs : int, default 20
             Number of training epochs.
-        batch_size : int
+        batch_size : int, default 64
             Batch size used for training.
-        verbose : bool
+        verbose : bool, default True
             Whether to print training progress.
         """
+
         # Train the model
         trainer(
             model=self.model,
@@ -132,33 +169,35 @@
         self._ref_score = self.score(x_ref, batch_size)
         self._threshold_perc = threshold_perc
 
+    @set_metadata("dataeval.detectors")
     def predict(
         self,
         X: ArrayLike,
         batch_size: int = int(1e10),
         ood_type: Literal["feature", "instance"] = "instance",
-    ) -> Dict[str, np.ndarray]:
+    ) -> OODOutput:
         """
         Predict whether instances are out-of-distribution or not.
 
         Parameters
         ----------
         X : ArrayLike
-            Batch of instances.
-        batch_size : int, default int(1e10)
-            Batch size used when making predictions with the autoencoder.
-        ood_type : Literal["feature", "instance"], default "instance"
+            Input data for out-of-distribution prediction.
+        batch_size : int, default 1e10
+            Number of instances to process in each batch.
+        ood_type : "feature" | "instance", default "instance"
             Predict out-of-distribution at the 'feature' or 'instance' level.
 
         Returns
         -------
-        Dictionary containing the outlier predictions and both feature and instance level outlier scores.
+        Dictionary containing the outlier predictions for the selected level,
+        and the OOD scores for the data including both 'instance' and 'feature' (if present) level scores.
         """
         self._validate_state(X := to_numpy(X))
         # compute outlier scores
         score = self.score(X, batch_size=batch_size)
-        ood_pred = (score.get(ood_type) > self._threshold_score(ood_type)).astype(int)
-        return {**{"is_ood": ood_pred}, **score._asdict()}
+        ood_pred = score.get(ood_type) > self._threshold_score(ood_type)
+        return OODOutput(is_ood=ood_pred, **score._asdict())
 
 
 class OODGMMBase(OODBase):
@@ -166,7 +205,7 @@ class OODGMMBase(OODBase):
         super().__init__(model)
         self.gmm_params: GaussianMixtureModelParams
 
-    def _validate_state(self, X: np.ndarray, additional_attrs: Optional[List[str]] = None) -> None:
+    def _validate_state(self, X: NDArray, additional_attrs: list[str] | None = None) -> None:
         if additional_attrs is None:
             additional_attrs = ["gmm_params"]
         super()._validate_state(X, additional_attrs)
@@ -174,12 +213,12 @@ class OODGMMBase(OODBase):
     def fit(
         self,
         x_ref: ArrayLike,
-        threshold_perc: float,
-        loss_fn: Callable[[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor], tf.Tensor],
-        optimizer: keras.optimizers.Optimizer,
-        epochs: int,
-        batch_size: int,
-        verbose: bool,
+        threshold_perc: float = 100.0,
+        loss_fn: Callable[..., tf.Tensor] | None = None,
+        optimizer: keras.optimizers.Optimizer = keras.optimizers.Adam,
+        epochs: int = 20,
+        batch_size: int = 64,
+        verbose: bool = True,
     ) -> None:
         # Train the model
         trainer(
@@ -193,7 +232,7 @@ class OODGMMBase(OODBase):
         )
 
         # Calculate the GMM parameters
-        _, z, gamma = cast(Tuple[tf.Tensor, tf.Tensor, tf.Tensor], self.model(x_ref))
+        _, z, gamma = cast(tuple[tf.Tensor, tf.Tensor, tf.Tensor], self.model(x_ref))
         self.gmm_params = gmm_params(z, gamma)
 
         # Infer the threshold values
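`predict` now returns a frozen `OODOutput` dataclass with a boolean `is_ood` mask instead of a `dict` of arrays with 0/1 integers, so downstream code indexes attributes rather than string keys. A minimal consumption sketch; `detector` stands for any fitted `OODBase` subclass and `x_test` for an evaluation array (construction and training are elided, since they depend on model classes not shown in this diff):

import numpy as np

# detector: a fitted OODBase subclass (e.g. OOD_AE); x_test: NDArray of images
out = detector.predict(x_test, batch_size=256, ood_type="instance")

flagged = np.flatnonzero(out.is_ood)  # is_ood is a boolean NDArray, not 0/1 ints
print(f"{flagged.size} out-of-distribution instances")
print(out.instance_score[flagged])    # scores travel with the prediction

if out.feature_score is not None:     # None for AEGMM-style detectors
    print(out.feature_score.shape)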