dataeval 0.69.4__py3-none-any.whl → 0.70.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. dataeval/__init__.py +8 -8
  2. dataeval/_internal/datasets.py +235 -131
  3. dataeval/_internal/detectors/clusterer.py +2 -0
  4. dataeval/_internal/detectors/drift/base.py +7 -8
  5. dataeval/_internal/detectors/drift/mmd.py +4 -4
  6. dataeval/_internal/detectors/duplicates.py +64 -45
  7. dataeval/_internal/detectors/merged_stats.py +23 -54
  8. dataeval/_internal/detectors/ood/ae.py +8 -6
  9. dataeval/_internal/detectors/ood/aegmm.py +6 -4
  10. dataeval/_internal/detectors/ood/base.py +12 -7
  11. dataeval/_internal/detectors/ood/llr.py +6 -4
  12. dataeval/_internal/detectors/ood/vae.py +5 -3
  13. dataeval/_internal/detectors/ood/vaegmm.py +6 -4
  14. dataeval/_internal/detectors/outliers.py +137 -63
  15. dataeval/_internal/interop.py +11 -7
  16. dataeval/_internal/metrics/balance.py +13 -11
  17. dataeval/_internal/metrics/ber.py +5 -3
  18. dataeval/_internal/metrics/coverage.py +4 -0
  19. dataeval/_internal/metrics/divergence.py +9 -5
  20. dataeval/_internal/metrics/diversity.py +14 -12
  21. dataeval/_internal/metrics/parity.py +32 -22
  22. dataeval/_internal/metrics/stats/base.py +231 -0
  23. dataeval/_internal/metrics/stats/boxratiostats.py +159 -0
  24. dataeval/_internal/metrics/stats/datasetstats.py +99 -0
  25. dataeval/_internal/metrics/stats/dimensionstats.py +113 -0
  26. dataeval/_internal/metrics/stats/hashstats.py +75 -0
  27. dataeval/_internal/metrics/stats/labelstats.py +125 -0
  28. dataeval/_internal/metrics/stats/pixelstats.py +119 -0
  29. dataeval/_internal/metrics/stats/visualstats.py +124 -0
  30. dataeval/_internal/metrics/uap.py +8 -4
  31. dataeval/_internal/metrics/utils.py +30 -15
  32. dataeval/_internal/models/pytorch/autoencoder.py +5 -5
  33. dataeval/_internal/models/tensorflow/pixelcnn.py +1 -4
  34. dataeval/_internal/output.py +3 -18
  35. dataeval/_internal/utils.py +11 -16
  36. dataeval/_internal/workflows/sufficiency.py +152 -151
  37. dataeval/detectors/__init__.py +4 -0
  38. dataeval/detectors/drift/__init__.py +8 -3
  39. dataeval/detectors/drift/kernels/__init__.py +4 -0
  40. dataeval/detectors/drift/updates/__init__.py +4 -0
  41. dataeval/detectors/linters/__init__.py +15 -4
  42. dataeval/detectors/ood/__init__.py +14 -2
  43. dataeval/metrics/__init__.py +5 -0
  44. dataeval/metrics/bias/__init__.py +13 -4
  45. dataeval/metrics/estimators/__init__.py +8 -8
  46. dataeval/metrics/stats/__init__.py +25 -3
  47. dataeval/utils/__init__.py +16 -3
  48. dataeval/utils/tensorflow/__init__.py +11 -0
  49. dataeval/utils/torch/__init__.py +12 -0
  50. dataeval/utils/torch/datasets/__init__.py +7 -0
  51. dataeval/workflows/__init__.py +6 -2
  52. {dataeval-0.69.4.dist-info → dataeval-0.70.1.dist-info}/METADATA +12 -4
  53. dataeval-0.70.1.dist-info/RECORD +80 -0
  54. {dataeval-0.69.4.dist-info → dataeval-0.70.1.dist-info}/WHEEL +1 -1
  55. dataeval/_internal/flags.py +0 -77
  56. dataeval/_internal/metrics/stats.py +0 -397
  57. dataeval/flags/__init__.py +0 -3
  58. dataeval/tensorflow/__init__.py +0 -3
  59. dataeval/torch/__init__.py +0 -3
  60. dataeval-0.69.4.dist-info/RECORD +0 -74
  61. /dataeval/{tensorflow → utils/tensorflow}/loss/__init__.py +0 -0
  62. /dataeval/{tensorflow → utils/tensorflow}/models/__init__.py +0 -0
  63. /dataeval/{tensorflow → utils/tensorflow}/recon/__init__.py +0 -0
  64. /dataeval/{torch → utils/torch}/models/__init__.py +0 -0
  65. /dataeval/{torch → utils/torch}/trainer/__init__.py +0 -0
  66. {dataeval-0.69.4.dist-info → dataeval-0.70.1.dist-info}/LICENSE.txt +0 -0
@@ -1,13 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import Generic, Iterable, Sequence, TypeVar, cast
4
+ from typing import Generic, Iterable, Sequence, TypeVar
5
5
 
6
6
  from numpy.typing import ArrayLike
7
7
 
8
8
  from dataeval._internal.detectors.merged_stats import combine_stats, get_dataset_step_from_idx
9
- from dataeval._internal.flags import ImageStat
10
- from dataeval._internal.metrics.stats import StatsOutput, imagestats
9
+ from dataeval._internal.metrics.stats.hashstats import HashStatsOutput, hashstats
11
10
  from dataeval._internal.output import OutputMetadata, set_metadata
12
11
 
13
12
  DuplicateGroup = list[int]
@@ -18,6 +17,8 @@ TIndexCollection = TypeVar("TIndexCollection", DuplicateGroup, DatasetDuplicateG
18
17
  @dataclass(frozen=True)
19
18
  class DuplicatesOutput(Generic[TIndexCollection], OutputMetadata):
20
19
  """
20
+ Output class for :class:`Duplicates` lint detector
21
+
21
22
  Attributes
22
23
  ----------
23
24
  exact : list[list[int] | dict[int, list[int]]]
@@ -53,26 +54,23 @@ class Duplicates:
53
54
  -------
54
55
  Initialize the Duplicates class:
55
56
 
56
- >>> dups = Duplicates()
57
+ >>> all_dupes = Duplicates()
58
+ >>> exact_dupes = Duplicates(only_exact=True)
57
59
  """
58
60
 
59
61
  def __init__(self, only_exact: bool = False):
60
- self.stats: StatsOutput
62
+ self.stats: HashStatsOutput
61
63
  self.only_exact = only_exact
62
64
 
63
- def _get_duplicates(self) -> dict[str, list[list[int]]]:
64
- stats_dict = self.stats.dict()
65
- if "xxhash" in stats_dict:
66
- exact_dict: dict[int, list] = {}
67
- for i, value in enumerate(stats_dict["xxhash"]):
68
- exact_dict.setdefault(value, []).append(i)
69
- exact = [sorted(v) for v in exact_dict.values() if len(v) > 1]
70
- else:
71
- exact = []
65
+ def _get_duplicates(self, stats: dict) -> dict[str, list[list[int]]]:
66
+ exact_dict: dict[int, list] = {}
67
+ for i, value in enumerate(stats["xxhash"]):
68
+ exact_dict.setdefault(value, []).append(i)
69
+ exact = [sorted(v) for v in exact_dict.values() if len(v) > 1]
72
70
 
73
- if "pchash" in stats_dict and not self.only_exact:
71
+ if not self.only_exact:
74
72
  near_dict: dict[int, list] = {}
75
- for i, value in enumerate(stats_dict["pchash"]):
73
+ for i, value in enumerate(stats["pchash"]):
76
74
  near_dict.setdefault(value, []).append(i)
77
75
  near = [sorted(v) for v in near_dict.values() if len(v) > 1 and not any(set(v).issubset(x) for x in exact)]
78
76
  else:
@@ -84,14 +82,14 @@ class Duplicates:
84
82
  }
85
83
 
86
84
  @set_metadata("dataeval.detectors", ["only_exact"])
87
- def evaluate(self, data: Iterable[ArrayLike] | StatsOutput | Sequence[StatsOutput]) -> DuplicatesOutput:
85
+ def from_stats(self, hashes: HashStatsOutput | Sequence[HashStatsOutput]) -> DuplicatesOutput:
88
86
  """
89
87
  Returns duplicate image indices for both exact matches and near matches
90
88
 
91
89
  Parameters
92
90
  ----------
93
- data : Iterable[ArrayLike], shape - (N, C, H, W) | StatsOutput | Sequence[StatsOutput]
94
- A dataset of images in an ArrayLike format or the output(s) from an imagestats metric analysis
91
+ data : HashStatsOutput | Sequence[HashStatsOutput]
92
+ The output(s) from a hashstats analysis
95
93
 
96
94
  Returns
97
95
  -------
@@ -100,39 +98,60 @@ class Duplicates:
100
98
 
101
99
  See Also
102
100
  --------
103
- imagestats
101
+ hashstats
104
102
 
105
103
  Example
106
104
  -------
107
- >>> dups.evaluate(images)
108
- DuplicatesOutput(exact=[[3, 20], [16, 37]], near=[[3, 20, 22], [12, 18], [13, 36], [14, 31], [17, 27], [19, 38, 47]])
109
- """ # noqa: E501
105
+ >>> exact_dupes.from_stats([hashes1, hashes2])
106
+ DuplicatesOutput(exact=[{0: [3, 20]}, {0: [16], 1: [12]}], near=[])
107
+ """
110
108
 
111
- stats, dataset_steps = combine_stats(data)
109
+ if isinstance(hashes, HashStatsOutput):
110
+ return DuplicatesOutput(**self._get_duplicates(hashes.dict()))
112
111
 
113
- if isinstance(stats, StatsOutput):
114
- if not stats.xxhash:
115
- raise ValueError("StatsOutput must include xxhash information of the images.")
116
- if not self.only_exact and not stats.pchash:
117
- raise ValueError("StatsOutput must include pchash information of the images for near matches.")
118
- self.stats = stats
119
- else:
120
- flags = ImageStat.XXHASH | (ImageStat(0) if self.only_exact else ImageStat.PCHASH)
121
- self.stats = imagestats(cast(Iterable[ArrayLike], data), flags)
112
+ if not isinstance(hashes, Sequence):
113
+ raise TypeError("Invalid stats output type; only use output from hashstats.")
122
114
 
123
- duplicates = self._get_duplicates()
115
+ combined, dataset_steps = combine_stats(hashes)
116
+ duplicates = self._get_duplicates(combined.dict())
124
117
 
125
118
  # split up results from combined dataset into individual dataset buckets
126
- if dataset_steps:
127
- dup_list: list[list[int]]
128
- for dup_type, dup_list in duplicates.items():
129
- dup_list_dict = []
130
- for idxs in dup_list:
131
- dup_dict = {}
132
- for idx in idxs:
133
- k, v = get_dataset_step_from_idx(idx, dataset_steps)
134
- dup_dict.setdefault(k, []).append(v)
135
- dup_list_dict.append(dup_dict)
136
- duplicates[dup_type] = dup_list_dict
119
+ for dup_type, dup_list in duplicates.items():
120
+ dup_list_dict = []
121
+ for idxs in dup_list:
122
+ dup_dict = {}
123
+ for idx in idxs:
124
+ k, v = get_dataset_step_from_idx(idx, dataset_steps)
125
+ dup_dict.setdefault(k, []).append(v)
126
+ dup_list_dict.append(dup_dict)
127
+ duplicates[dup_type] = dup_list_dict
137
128
 
138
129
  return DuplicatesOutput(**duplicates)
130
+
131
+ @set_metadata("dataeval.detectors", ["only_exact"])
132
+ def evaluate(self, data: Iterable[ArrayLike]) -> DuplicatesOutput:
133
+ """
134
+ Returns duplicate image indices for both exact matches and near matches
135
+
136
+ Parameters
137
+ ----------
138
+ data : Iterable[ArrayLike], shape - (N, C, H, W) | StatsOutput | Sequence[StatsOutput]
139
+ A dataset of images in an ArrayLike format or the output(s) from a hashstats analysis
140
+
141
+ Returns
142
+ -------
143
+ DuplicatesOutput
144
+ List of groups of indices that are exact and near matches
145
+
146
+ See Also
147
+ --------
148
+ hashstats
149
+
150
+ Example
151
+ -------
152
+ >>> all_dupes.evaluate(images)
153
+ DuplicatesOutput(exact=[[3, 20], [16, 37]], near=[[3, 20, 22], [12, 18], [13, 36], [14, 31], [17, 27], [19, 38, 47]])
154
+ """ # noqa: E501
155
+ self.stats = hashstats(data)
156
+ duplicates = self._get_duplicates(self.stats.dict())
157
+ return DuplicatesOutput(**duplicates)
@@ -1,71 +1,40 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Sequence, cast
4
- from warnings import warn
3
+ from copy import deepcopy
4
+ from typing import Sequence, TypeVar
5
5
 
6
6
  import numpy as np
7
7
 
8
- from dataeval._internal.metrics.stats import StatsOutput
9
- from dataeval._internal.output import populate_defaults
8
+ from dataeval._internal.metrics.stats.base import BaseStatsOutput
10
9
 
10
+ TStatsOutput = TypeVar("TStatsOutput", bound=BaseStatsOutput)
11
11
 
12
- def add_stats(a: StatsOutput, b: StatsOutput) -> StatsOutput:
13
- if not isinstance(a, StatsOutput) or not isinstance(b, StatsOutput):
14
- raise TypeError(f"Cannot add object of type {type(a)} and type {type(b)}.")
15
12
 
16
- a_dict = a.dict()
17
- b_dict = b.dict()
18
- a_keys = set(a_dict)
19
- b_keys = set(b_dict)
13
+ def add_stats(a: TStatsOutput, b: TStatsOutput) -> TStatsOutput:
14
+ if type(a) is not type(b):
15
+ raise TypeError(f"Types {type(a)} and {type(b)} cannot be added.")
20
16
 
21
- missing_keys = a_keys - b_keys
22
- if missing_keys:
23
- raise ValueError(f"Required keys are missing: {missing_keys}.")
17
+ sum_dict = deepcopy(a.dict())
24
18
 
25
- extra_keys = b_keys - a_keys
26
- if extra_keys:
27
- warn(f"Extraneous keys will be dropped: {extra_keys}.")
19
+ for k in sum_dict:
20
+ if isinstance(sum_dict[k], list):
21
+ sum_dict[k].extend(b.dict()[k])
22
+ else:
23
+ sum_dict[k] = np.concatenate((sum_dict[k], b.dict()[k]))
28
24
 
29
- # perform add of multi-channel stats
30
- if "ch_idx_map" in a_dict:
31
- for k, v in a_dict.items():
32
- if k == "ch_idx_map":
33
- offset = sum([len(idxs) for idxs in v.values()])
34
- for ch_k, ch_v in b_dict[k].items():
35
- if ch_k not in v:
36
- v[ch_k] = []
37
- a_dict[k][ch_k].extend([idx + offset for idx in ch_v])
38
- else:
39
- for ch_k in b_dict[k]:
40
- if ch_k not in v:
41
- v[ch_k] = b_dict[k][ch_k]
42
- else:
43
- v[ch_k] = np.concatenate((v[ch_k], b_dict[k][ch_k]), axis=1)
44
- else:
45
- for k in a_dict:
46
- if isinstance(a_dict[k], list):
47
- a_dict[k].extend(b_dict[k])
48
- else:
49
- a_dict[k] = np.concatenate((a_dict[k], b_dict[k]))
25
+ return type(a)(**sum_dict)
50
26
 
51
- return StatsOutput(**populate_defaults(a_dict, StatsOutput))
52
-
53
-
54
- def combine_stats(stats) -> tuple[StatsOutput | None, list[int]]:
55
- dataset_steps = []
56
-
57
- if isinstance(stats, StatsOutput):
58
- return stats, dataset_steps
59
27
 
28
+ def combine_stats(stats: Sequence[TStatsOutput]) -> tuple[TStatsOutput, list[int]]:
60
29
  output = None
61
- if isinstance(stats, Sequence) and isinstance(stats[0], StatsOutput):
62
- stats = cast(Sequence[StatsOutput], stats)
63
- cur_len = 0
64
- for s in stats:
65
- output = s if output is None else add_stats(output, s)
66
- cur_len += len(s)
67
- dataset_steps.append(cur_len)
68
-
30
+ dataset_steps = []
31
+ cur_len = 0
32
+ for s in stats:
33
+ output = s if output is None else add_stats(output, s)
34
+ cur_len += len(s)
35
+ dataset_steps.append(cur_len)
36
+ if output is None:
37
+ raise TypeError("Cannot combine empty sequence of stats.")
69
38
  return output, dataset_steps
70
39
 
71
40
 
@@ -15,10 +15,11 @@ import numpy as np
15
15
  import tensorflow as tf
16
16
  from numpy.typing import ArrayLike
17
17
 
18
- from dataeval._internal.detectors.ood.base import OODBase, OODScore
19
- from dataeval._internal.interop import to_numpy
18
+ from dataeval._internal.detectors.ood.base import OODBase, OODScoreOutput
19
+ from dataeval._internal.interop import as_numpy
20
20
  from dataeval._internal.models.tensorflow.autoencoder import AE
21
21
  from dataeval._internal.models.tensorflow.utils import predict_batch
22
+ from dataeval._internal.output import set_metadata
22
23
 
23
24
 
24
25
  class OOD_AE(OODBase):
@@ -46,10 +47,11 @@ class OOD_AE(OODBase):
46
47
  ) -> None:
47
48
  if loss_fn is None:
48
49
  loss_fn = keras.losses.MeanSquaredError()
49
- super().fit(to_numpy(x_ref), threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
50
+ super().fit(as_numpy(x_ref), threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
50
51
 
51
- def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScore:
52
- self._validate(X := to_numpy(X))
52
+ @set_metadata("dataeval.detectors")
53
+ def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
54
+ self._validate(X := as_numpy(X))
53
55
 
54
56
  # reconstruct instances
55
57
  X_recon = predict_batch(X, self.model, batch_size=batch_size)
@@ -62,4 +64,4 @@ class OOD_AE(OODBase):
62
64
  sorted_fscore_perc = sorted_fscore[:, -n_score_features:]
63
65
  iscore = np.mean(sorted_fscore_perc, axis=1)
64
66
 
65
- return OODScore(iscore, fscore)
67
+ return OODScoreOutput(iscore, fscore)
@@ -14,12 +14,13 @@ import keras
14
14
  import tensorflow as tf
15
15
  from numpy.typing import ArrayLike
16
16
 
17
- from dataeval._internal.detectors.ood.base import OODGMMBase, OODScore
17
+ from dataeval._internal.detectors.ood.base import OODGMMBase, OODScoreOutput
18
18
  from dataeval._internal.interop import to_numpy
19
19
  from dataeval._internal.models.tensorflow.autoencoder import AEGMM
20
20
  from dataeval._internal.models.tensorflow.gmm import gmm_energy
21
21
  from dataeval._internal.models.tensorflow.losses import LossGMM
22
22
  from dataeval._internal.models.tensorflow.utils import predict_batch
23
+ from dataeval._internal.output import set_metadata
23
24
 
24
25
 
25
26
  class OOD_AEGMM(OODGMMBase):
@@ -49,7 +50,8 @@ class OOD_AEGMM(OODGMMBase):
49
50
  loss_fn = LossGMM()
50
51
  super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
51
52
 
52
- def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScore:
53
+ @set_metadata("dataeval.detectors")
54
+ def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
53
55
  """
54
56
  Compute the out-of-distribution (OOD) score for a given dataset.
55
57
 
@@ -63,7 +65,7 @@ class OOD_AEGMM(OODGMMBase):
63
65
 
64
66
  Returns
65
67
  -------
66
- OODScore
68
+ OODScoreOutput
67
69
  An object containing the instance-level OOD score.
68
70
 
69
71
  Note
@@ -73,4 +75,4 @@ class OOD_AEGMM(OODGMMBase):
73
75
  self._validate(X := to_numpy(X))
74
76
  _, z, _ = predict_batch(X, self.model, batch_size=batch_size)
75
77
  energy, _ = gmm_energy(z, self.gmm_params, return_mean=False)
76
- return OODScore(energy.numpy()) # type: ignore
78
+ return OODScoreOutput(energy.numpy()) # type: ignore
@@ -10,7 +10,7 @@ from __future__ import annotations
10
10
 
11
11
  from abc import ABC, abstractmethod
12
12
  from dataclasses import dataclass
13
- from typing import Callable, Literal, NamedTuple, cast
13
+ from typing import Callable, Literal, cast
14
14
 
15
15
  import keras
16
16
  import numpy as np
@@ -26,6 +26,9 @@ from dataeval._internal.output import OutputMetadata, set_metadata
26
26
  @dataclass(frozen=True)
27
27
  class OODOutput(OutputMetadata):
28
28
  """
29
+ Output class for predictions from :class:`OOD_AE`, :class:`OOD_AEGMM`, :class:`OOD_LLR`,
30
+ :class:`OOD_VAE`, and :class:`OOD_VAEGMM` out-of-distribution detectors
31
+
29
32
  Attributes
30
33
  ----------
31
34
  is_ood : NDArray
@@ -41,9 +44,11 @@ class OODOutput(OutputMetadata):
41
44
  feature_score: NDArray[np.float32] | None
42
45
 
43
46
 
44
- class OODScore(NamedTuple):
47
+ @dataclass(frozen=True)
48
+ class OODScoreOutput(OutputMetadata):
45
49
  """
46
- NamedTuple containing the instance and (optionally) feature score.
50
+ Output class for instance and feature scores from :class:`OOD_AE`, :class:`OOD_AEGMM`,
51
+ :class:`OOD_LLR`, :class:`OOD_VAE`, and :class:`OOD_VAEGMM` out-of-distribution detectors
47
52
 
48
53
  Parameters
49
54
  ----------
@@ -76,7 +81,7 @@ class OODBase(ABC):
76
81
  def __init__(self, model: keras.Model) -> None:
77
82
  self.model = model
78
83
 
79
- self._ref_score: OODScore
84
+ self._ref_score: OODScoreOutput
80
85
  self._threshold_perc: float
81
86
  self._data_info: tuple[tuple, type] | None = None
82
87
 
@@ -102,7 +107,7 @@ class OODBase(ABC):
102
107
  self._validate(X)
103
108
 
104
109
  @abstractmethod
105
- def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScore:
110
+ def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
106
111
  """
107
112
  Compute the out-of-distribution (OOD) scores for a given dataset.
108
113
 
@@ -116,7 +121,7 @@ class OODBase(ABC):
116
121
 
117
122
  Returns
118
123
  -------
119
- OODScore
124
+ OODScoreOutput
120
125
  An object containing the instance-level and feature-level OOD scores.
121
126
  """
122
127
 
@@ -197,7 +202,7 @@ class OODBase(ABC):
197
202
  # compute outlier scores
198
203
  score = self.score(X, batch_size=batch_size)
199
204
  ood_pred = score.get(ood_type) > self._threshold_score(ood_type)
200
- return OODOutput(is_ood=ood_pred, **score._asdict())
205
+ return OODOutput(is_ood=ood_pred, **score.dict())
201
206
 
202
207
 
203
208
  class OODGMMBase(OODBase):
@@ -18,11 +18,12 @@ from keras.layers import Input
18
18
  from keras.models import Model
19
19
  from numpy.typing import ArrayLike, NDArray
20
20
 
21
- from dataeval._internal.detectors.ood.base import OODBase, OODScore
21
+ from dataeval._internal.detectors.ood.base import OODBase, OODScoreOutput
22
22
  from dataeval._internal.interop import to_numpy
23
23
  from dataeval._internal.models.tensorflow.pixelcnn import PixelCNN
24
24
  from dataeval._internal.models.tensorflow.trainer import trainer
25
25
  from dataeval._internal.models.tensorflow.utils import predict_batch
26
+ from dataeval._internal.output import set_metadata
26
27
 
27
28
 
28
29
  def build_model(
@@ -124,7 +125,7 @@ class OOD_LLR(OODBase):
124
125
  self.sequential = sequential
125
126
  self.log_prob = log_prob
126
127
 
127
- self._ref_score: OODScore
128
+ self._ref_score: OODScoreOutput
128
129
  self._threshold_perc: float
129
130
  self._data_info: tuple[tuple, type] | None = None
130
131
 
@@ -279,12 +280,13 @@ class OOD_LLR(OODBase):
279
280
  logp_b = logp_fn(self.dist_b, X, return_per_feature=return_per_feature, batch_size=batch_size)
280
281
  return logp_s - logp_b
281
282
 
283
+ @set_metadata("dataeval.detectors")
282
284
  def score(
283
285
  self,
284
286
  X: ArrayLike,
285
287
  batch_size: int = int(1e10),
286
- ) -> OODScore:
288
+ ) -> OODScoreOutput:
287
289
  self._validate(X := to_numpy(X))
288
290
  fscore = -self._llr(X, True, batch_size=batch_size)
289
291
  iscore = -self._llr(X, False, batch_size=batch_size)
290
- return OODScore(iscore, fscore)
292
+ return OODScoreOutput(iscore, fscore)
@@ -15,11 +15,12 @@ import numpy as np
15
15
  import tensorflow as tf
16
16
  from numpy.typing import ArrayLike
17
17
 
18
- from dataeval._internal.detectors.ood.base import OODBase, OODScore
18
+ from dataeval._internal.detectors.ood.base import OODBase, OODScoreOutput
19
19
  from dataeval._internal.interop import to_numpy
20
20
  from dataeval._internal.models.tensorflow.autoencoder import VAE
21
21
  from dataeval._internal.models.tensorflow.losses import Elbo
22
22
  from dataeval._internal.models.tensorflow.utils import predict_batch
23
+ from dataeval._internal.output import set_metadata
23
24
 
24
25
 
25
26
  class OOD_VAE(OODBase):
@@ -67,7 +68,8 @@ class OOD_VAE(OODBase):
67
68
  loss_fn = Elbo(0.05)
68
69
  super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
69
70
 
70
- def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScore:
71
+ @set_metadata("dataeval.detectors")
72
+ def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
71
73
  self._validate(X := to_numpy(X))
72
74
 
73
75
  # sample reconstructed instances
@@ -86,4 +88,4 @@ class OOD_VAE(OODBase):
86
88
  sorted_fscore_perc = sorted_fscore[:, -n_score_features:]
87
89
  iscore = np.mean(sorted_fscore_perc, axis=1)
88
90
 
89
- return OODScore(iscore, fscore)
91
+ return OODScoreOutput(iscore, fscore)
@@ -15,12 +15,13 @@ import numpy as np
15
15
  import tensorflow as tf
16
16
  from numpy.typing import ArrayLike
17
17
 
18
- from dataeval._internal.detectors.ood.base import OODGMMBase, OODScore
18
+ from dataeval._internal.detectors.ood.base import OODGMMBase, OODScoreOutput
19
19
  from dataeval._internal.interop import to_numpy
20
20
  from dataeval._internal.models.tensorflow.autoencoder import VAEGMM
21
21
  from dataeval._internal.models.tensorflow.gmm import gmm_energy
22
22
  from dataeval._internal.models.tensorflow.losses import Elbo, LossGMM
23
23
  from dataeval._internal.models.tensorflow.utils import predict_batch
24
+ from dataeval._internal.output import set_metadata
24
25
 
25
26
 
26
27
  class OOD_VAEGMM(OODGMMBase):
@@ -53,7 +54,8 @@ class OOD_VAEGMM(OODGMMBase):
53
54
  loss_fn = LossGMM(elbo=Elbo(0.05))
54
55
  super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
55
56
 
56
- def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScore:
57
+ @set_metadata("dataeval.detectors")
58
+ def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
57
59
  """
58
60
  Compute the out-of-distribution (OOD) score for a given dataset.
59
61
 
@@ -67,7 +69,7 @@ class OOD_VAEGMM(OODGMMBase):
67
69
 
68
70
  Returns
69
71
  -------
70
- OODScore
72
+ OODScoreOutput
71
73
  An object containing the instance-level OOD score.
72
74
 
73
75
  Note
@@ -84,4 +86,4 @@ class OOD_VAEGMM(OODGMMBase):
84
86
  energy, _ = gmm_energy(z, self.gmm_params, return_mean=False)
85
87
  energy_samples = energy.numpy().reshape((-1, self.samples)) # type: ignore
86
88
  iscore = np.mean(energy_samples, axis=-1)
87
- return OODScore(iscore)
89
+ return OODScoreOutput(iscore)