dataeval-0.63.0-py3-none-any.whl → dataeval-0.65.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. dataeval/__init__.py +4 -4
  2. dataeval/_internal/detectors/clusterer.py +47 -34
  3. dataeval/_internal/detectors/drift/base.py +53 -35
  4. dataeval/_internal/detectors/drift/cvm.py +5 -4
  5. dataeval/_internal/detectors/drift/ks.py +7 -6
  6. dataeval/_internal/detectors/drift/mmd.py +39 -19
  7. dataeval/_internal/detectors/drift/torch.py +6 -5
  8. dataeval/_internal/detectors/drift/uncertainty.py +7 -8
  9. dataeval/_internal/detectors/duplicates.py +57 -30
  10. dataeval/_internal/detectors/linter.py +40 -24
  11. dataeval/_internal/detectors/ood/ae.py +2 -1
  12. dataeval/_internal/detectors/ood/aegmm.py +2 -1
  13. dataeval/_internal/detectors/ood/base.py +37 -15
  14. dataeval/_internal/detectors/ood/llr.py +9 -8
  15. dataeval/_internal/detectors/ood/vae.py +2 -1
  16. dataeval/_internal/detectors/ood/vaegmm.py +2 -1
  17. dataeval/_internal/flags.py +42 -21
  18. dataeval/_internal/interop.py +3 -12
  19. dataeval/_internal/metrics/balance.py +188 -0
  20. dataeval/_internal/metrics/ber.py +123 -48
  21. dataeval/_internal/metrics/coverage.py +90 -74
  22. dataeval/_internal/metrics/divergence.py +101 -67
  23. dataeval/_internal/metrics/diversity.py +211 -0
  24. dataeval/_internal/metrics/parity.py +287 -155
  25. dataeval/_internal/metrics/stats.py +198 -317
  26. dataeval/_internal/metrics/uap.py +40 -29
  27. dataeval/_internal/metrics/utils.py +430 -0
  28. dataeval/_internal/models/tensorflow/losses.py +3 -3
  29. dataeval/_internal/models/tensorflow/trainer.py +3 -2
  30. dataeval/_internal/models/tensorflow/utils.py +4 -3
  31. dataeval/_internal/output.py +82 -0
  32. dataeval/_internal/utils.py +64 -0
  33. dataeval/_internal/workflows/sufficiency.py +96 -107
  34. dataeval/flags/__init__.py +2 -2
  35. dataeval/metrics/__init__.py +26 -7
  36. dataeval/utils/__init__.py +9 -0
  37. {dataeval-0.63.0.dist-info → dataeval-0.65.0.dist-info}/METADATA +1 -1
  38. dataeval-0.65.0.dist-info/RECORD +60 -0
  39. dataeval/_internal/functional/__init__.py +0 -0
  40. dataeval/_internal/functional/ber.py +0 -63
  41. dataeval/_internal/functional/coverage.py +0 -75
  42. dataeval/_internal/functional/divergence.py +0 -16
  43. dataeval/_internal/functional/hash.py +0 -79
  44. dataeval/_internal/functional/metadata.py +0 -136
  45. dataeval/_internal/functional/metadataparity.py +0 -190
  46. dataeval/_internal/functional/uap.py +0 -6
  47. dataeval/_internal/functional/utils.py +0 -158
  48. dataeval/_internal/maite/__init__.py +0 -0
  49. dataeval/_internal/maite/utils.py +0 -30
  50. dataeval/_internal/metrics/base.py +0 -92
  51. dataeval/_internal/metrics/metadata.py +0 -610
  52. dataeval/_internal/metrics/metadataparity.py +0 -67
  53. dataeval-0.63.0.dist-info/RECORD +0 -68
  54. {dataeval-0.63.0.dist-info → dataeval-0.65.0.dist-info}/LICENSE.txt +0 -0
  55. {dataeval-0.63.0.dist-info → dataeval-0.65.0.dist-info}/WHEEL +0 -0
@@ -1,8 +1,27 @@
1
- from typing import Dict, Iterable, List, Literal
1
+ from dataclasses import dataclass
2
+ from typing import Dict, Iterable, List
2
3
 
3
- from dataeval._internal.flags import ImageHash
4
- from dataeval._internal.interop import ArrayLike
5
- from dataeval._internal.metrics.stats import ImageStats
4
+ from numpy.typing import ArrayLike
5
+
6
+ from dataeval._internal.metrics.stats import StatsOutput
7
+ from dataeval._internal.output import OutputMetadata, set_metadata
8
+ from dataeval.flags import ImageStat
9
+ from dataeval.metrics import imagestats
10
+
11
+
12
+ @dataclass(frozen=True)
13
+ class DuplicatesOutput(OutputMetadata):
14
+ """
15
+ Attributes
16
+ ----------
17
+ exact : List[List[int]]
18
+ Indices of images that are exact matches
19
+ near: List[List[int]]
20
+ Indices of images that are near matches
21
+ """
22
+
23
+ exact: List[List[int]]
24
+ near: List[List[int]]
6
25
 
7
26
 
8
27
  class Duplicates:
@@ -12,8 +31,8 @@ class Duplicates:
12
31
 
13
32
  Attributes
14
33
  ----------
15
- stats : ImageStats(flags=ImageHash.ALL)
16
- Base stats class with the flags for checking duplicates
34
+ stats : StatsOutput
35
+ Output class of stats
17
36
 
18
37
  Example
19
38
  -------
@@ -22,25 +41,36 @@ class Duplicates:
22
41
  >>> dups = Duplicates()
23
42
  """
24
43
 
25
- def __init__(self):
26
- self.stats = ImageStats(ImageHash.ALL)
44
+ def __init__(self, find_exact: bool = True, find_near: bool = True):
45
+ self.stats: StatsOutput
46
+ self.find_exact = find_exact
47
+ self.find_near = find_near
27
48
 
28
- def _get_duplicates(self) -> dict:
29
- exact = {}
30
- near = {}
31
- for i, value in enumerate(self.results["xxhash"]):
32
- exact.setdefault(value, []).append(i)
33
- for i, value in enumerate(self.results["pchash"]):
34
- near.setdefault(value, []).append(i)
35
- exact = [v for v in exact.values() if len(v) > 1]
36
- near = [v for v in near.values() if len(v) > 1 and not any(set(v).issubset(x) for x in exact)]
49
+ def _get_duplicates(self) -> Dict[str, List[List[int]]]:
50
+ stats_dict = self.stats.dict()
51
+ if "xxhash" in stats_dict:
52
+ exact = {}
53
+ for i, value in enumerate(stats_dict["xxhash"]):
54
+ exact.setdefault(value, []).append(i)
55
+ exact = [v for v in exact.values() if len(v) > 1]
56
+ else:
57
+ exact = []
58
+
59
+ if "pchash" in stats_dict:
60
+ near = {}
61
+ for i, value in enumerate(stats_dict["pchash"]):
62
+ near.setdefault(value, []).append(i)
63
+ near = [v for v in near.values() if len(v) > 1 and not any(set(v).issubset(x) for x in exact)]
64
+ else:
65
+ near = []
37
66
 
38
67
  return {
39
68
  "exact": sorted(exact),
40
69
  "near": sorted(near),
41
70
  }
42
71
 
43
- def evaluate(self, images: Iterable[ArrayLike]) -> Dict[Literal["exact", "near"], List[int]]:
72
+ @set_metadata("dataeval.detectors", ["find_exact", "find_near"])
73
+ def evaluate(self, images: Iterable[ArrayLike]) -> DuplicatesOutput:
44
74
  """
45
75
  Returns duplicate image indices for both exact matches and near matches
46
76
 
@@ -51,22 +81,19 @@ class Duplicates:
51
81
 
52
82
  Returns
53
83
  -------
54
- Dict[str, List[int]]
55
- exact :
56
- List of groups of indices that are exact matches
57
- near :
58
- List of groups of indices that are near matches
84
+ DuplicatesOutput
85
+ List of groups of indices that are exact and near matches
59
86
 
60
87
  See Also
61
88
  --------
62
- ImageStats
89
+ imagestats
63
90
 
64
91
  Example
65
92
  -------
66
93
  >>> dups.evaluate(images)
67
- {'exact': [[3, 20], [16, 37]], 'near': [[3, 20, 22], [12, 18], [13, 36], [14, 31], [17, 27], [19, 38, 47]]}
68
- """
69
- self.stats.reset()
70
- self.stats.update(images)
71
- self.results = self.stats.compute()
72
- return self._get_duplicates()
94
+ DuplicatesOutput(exact=[[3, 20], [16, 37]], near=[[3, 20, 22], [12, 18], [13, 36], [14, 31], [17, 27], [19, 38, 47]])
95
+ """ # noqa: E501
96
+ flag_exact = ImageStat.XXHASH if self.find_exact else ImageStat(0)
97
+ flag_near = ImageStat.PCHASH if self.find_near else ImageStat(0)
98
+ self.stats = imagestats(images, flag_exact | flag_near)
99
+ return DuplicatesOutput(**self._get_duplicates())
@@ -1,15 +1,31 @@
1
- from typing import Iterable, Literal, Optional, Sequence, Union
1
+ from dataclasses import dataclass
2
+ from typing import Dict, Iterable, Literal, Optional
2
3
 
3
4
  import numpy as np
5
+ from numpy.typing import ArrayLike, NDArray
4
6
 
5
- from dataeval._internal.flags import ImageProperty, ImageVisuals, LinterFlags
6
- from dataeval._internal.interop import ArrayLike
7
- from dataeval._internal.metrics.stats import ImageStats
7
+ from dataeval._internal.flags import verify_supported
8
+ from dataeval._internal.output import OutputMetadata, set_metadata
9
+ from dataeval.flags import ImageStat
10
+ from dataeval.metrics import imagestats
11
+
12
+
13
+ @dataclass(frozen=True)
14
+ class LinterOutput(OutputMetadata):
15
+ """
16
+ Attributes
17
+ ----------
18
+ issues : Dict[int, Dict[str, float]]
19
+ Dictionary containing the indices of outliers and a dictionary showing
20
+ the issues and calculated values for the given index.
21
+ """
22
+
23
+ issues: Dict[int, Dict[str, float]]
8
24
 
9
25
 
10
26
  def _get_outlier_mask(
11
- values: np.ndarray, method: Literal["zscore", "modzscore", "iqr"], threshold: Optional[float]
12
- ) -> np.ndarray:
27
+ values: NDArray, method: Literal["zscore", "modzscore", "iqr"], threshold: Optional[float]
28
+ ) -> NDArray:
13
29
  if method == "zscore":
14
30
  threshold = threshold if threshold else 3.0
15
31
  std = np.std(values)
@@ -18,7 +34,7 @@ def _get_outlier_mask(
18
34
  elif method == "modzscore":
19
35
  threshold = threshold if threshold else 3.5
20
36
  abs_diff = np.abs(values - np.median(values))
21
- med_abs_diff = np.median(abs_diff)
37
+ med_abs_diff = np.median(abs_diff) if np.median(abs_diff) != 0 else np.mean(abs_diff)
22
38
  mod_z_score = 0.6745 * abs_diff / med_abs_diff
23
39
  return mod_z_score > threshold
24
40
  elif method == "iqr":
@@ -36,8 +52,9 @@ class Linter:
36
52
 
37
53
  Parameters
38
54
  ----------
39
- flags : [ImageProperty | ImageStatistics | ImageVisuals], default None
55
+ flags : ImageStat, default ImageStat.ALL_PROPERTIES | ImageStat.ALL_VISUALS
40
56
  Metric(s) to calculate for each image - calculates all metrics if None
57
+ Only supports ImageStat.ALL_STATS
41
58
  outlier_method : ["modzscore" | "zscore" | "iqr"], optional - default "modzscore"
42
59
  Statistical method used to identify outliers
43
60
  outlier_threshold : float, optional - default None
@@ -46,8 +63,8 @@ class Linter:
46
63
 
47
64
  Attributes
48
65
  ----------
49
- stats : ImageStats
50
- Class to hold the value of each metric for each image
66
+ stats : Dict[str, Any]
67
+ Dictionary to hold the value of each metric for each image
51
68
 
52
69
  See Also
53
70
  --------
@@ -81,7 +98,7 @@ class Linter:
81
98
 
82
99
  Specifying specific metrics to analyze:
83
100
 
84
- >>> lint = Linter(flags=[ImageProperty.SIZE, ImageVisuals.ALL])
101
+ >>> lint = Linter(flags=ImageStat.SIZE | ImageStat.ALL_VISUALS)
85
102
 
86
103
  Specifying an outlier method:
87
104
 
@@ -94,19 +111,19 @@ class Linter:
94
111
 
95
112
  def __init__(
96
113
  self,
97
- flags: Optional[Union[LinterFlags, Sequence[LinterFlags]]] = None,
114
+ flags: ImageStat = ImageStat.ALL_PROPERTIES | ImageStat.ALL_VISUALS,
98
115
  outlier_method: Literal["zscore", "modzscore", "iqr"] = "modzscore",
99
116
  outlier_threshold: Optional[float] = None,
100
117
  ):
101
- flags = flags if flags is not None else (ImageProperty.ALL, ImageVisuals.ALL)
102
- self.stats = ImageStats(flags)
118
+ verify_supported(flags, ImageStat.ALL_STATS)
119
+ self.flags = flags
103
120
  self.outlier_method: Literal["zscore", "modzscore", "iqr"] = outlier_method
104
121
  self.outlier_threshold = outlier_threshold
105
122
 
106
123
  def _get_outliers(self) -> dict:
107
124
  flagged_images = {}
108
-
109
- for stat, values in self.results.items():
125
+ stats_dict = self.stats.dict()
126
+ for stat, values in stats_dict.items():
110
127
  if not isinstance(values, np.ndarray):
111
128
  continue
112
129
 
@@ -118,7 +135,8 @@ class Linter:
118
135
 
119
136
  return dict(sorted(flagged_images.items()))
120
137
 
121
- def evaluate(self, images: Iterable[ArrayLike]) -> dict:
138
+ @set_metadata("dataeval.detectors", ["flags", "outlier_method", "outlier_threshold"])
139
+ def evaluate(self, images: Iterable[ArrayLike]) -> LinterOutput:
122
140
  """
123
141
  Returns indices of outliers with the issues identified for each
124
142
 
@@ -130,8 +148,8 @@ class Linter:
130
148
 
131
149
  Returns
132
150
  -------
133
- Dict[int, Dict[str, float]]
134
- Dictionary containing the indices of outliers and a dictionary showing
151
+ LinterOutput
152
+ Output class containing the indices of outliers and a dictionary showing
135
153
  the issues and calculated values for the given index.
136
154
 
137
155
  Example
@@ -139,9 +157,7 @@ class Linter:
139
157
  Evaluate the dataset:
140
158
 
141
159
  >>> lint.evaluate(images)
142
- {18: {'brightness': 0.78}, 25: {'brightness': 0.98}}
160
+ LinterOutput(issues={18: {'brightness': 0.78}, 25: {'brightness': 0.98}})
143
161
  """
144
- self.stats.reset()
145
- self.stats.update(images)
146
- self.results = self.stats.compute()
147
- return self._get_outliers()
162
+ self.stats = imagestats(images, self.flags)
163
+ return LinterOutput(self._get_outliers())
@@ -10,9 +10,10 @@ from typing import Callable
10
10
 
11
11
  import keras
12
12
  import numpy as np
13
+ from numpy.typing import ArrayLike
13
14
 
14
15
  from dataeval._internal.detectors.ood.base import OODBase, OODScore
15
- from dataeval._internal.interop import ArrayLike, to_numpy
16
+ from dataeval._internal.interop import to_numpy
16
17
  from dataeval._internal.models.tensorflow.autoencoder import AE
17
18
  from dataeval._internal.models.tensorflow.utils import predict_batch
18
19
 
@@ -9,9 +9,10 @@ Licensed under Apache Software License (Apache 2.0)
9
9
  from typing import Callable
10
10
 
11
11
  import keras
12
+ from numpy.typing import ArrayLike
12
13
 
13
14
  from dataeval._internal.detectors.ood.base import OODGMMBase, OODScore
14
- from dataeval._internal.interop import ArrayLike, to_numpy
15
+ from dataeval._internal.interop import to_numpy
15
16
  from dataeval._internal.models.tensorflow.autoencoder import AEGMM
16
17
  from dataeval._internal.models.tensorflow.gmm import gmm_energy
17
18
  from dataeval._internal.models.tensorflow.losses import LossGMM
@@ -7,15 +7,36 @@ Licensed under Apache Software License (Apache 2.0)
7
7
  """
8
8
 
9
9
  from abc import ABC, abstractmethod
10
- from typing import Callable, Dict, List, Literal, NamedTuple, Optional, Tuple, cast
10
+ from dataclasses import dataclass
11
+ from typing import Callable, List, Literal, NamedTuple, Optional, Tuple, cast
11
12
 
12
13
  import keras
13
14
  import numpy as np
14
15
  import tensorflow as tf
16
+ from numpy.typing import ArrayLike, NDArray
15
17
 
16
- from dataeval._internal.interop import ArrayLike, to_numpy
18
+ from dataeval._internal.interop import to_numpy
17
19
  from dataeval._internal.models.tensorflow.gmm import GaussianMixtureModelParams, gmm_params
18
20
  from dataeval._internal.models.tensorflow.trainer import trainer
21
+ from dataeval._internal.output import OutputMetadata, set_metadata
22
+
23
+
24
+ @dataclass(frozen=True)
25
+ class OODOutput(OutputMetadata):
26
+ """
27
+ Attributes
28
+ ----------
29
+ is_ood : NDArray[np.bool_]
30
+ Array of images that are detected as out of distribution
31
+ instance_score : NDArray[np.float32]
32
+ Instance score of the evaluated dataset
33
+ feature_score : Optional[NDArray[np.float32]]
34
+ Feature score, if available, of the evaluated dataset
35
+ """
36
+
37
+ is_ood: NDArray[np.bool_]
38
+ instance_score: NDArray[np.float32]
39
+ feature_score: Optional[NDArray[np.float32]]
19
40
 
20
41
 
21
42
  class OODScore(NamedTuple):
@@ -24,16 +45,16 @@ class OODScore(NamedTuple):
24
45
 
25
46
  Parameters
26
47
  ----------
27
- instance_score : np.ndarray
48
+ instance_score : NDArray[np.float32]
28
49
  Instance score of the evaluated dataset.
29
- feature_score : Optional[np.ndarray], default None
50
+ feature_score : Optional[NDArray[np.float32]], default None
30
51
  Feature score, if available, of the evaluated dataset.
31
52
  """
32
53
 
33
- instance_score: np.ndarray
34
- feature_score: Optional[np.ndarray] = None
54
+ instance_score: NDArray[np.float32]
55
+ feature_score: Optional[NDArray[np.float32]] = None
35
56
 
36
- def get(self, ood_type: Literal["instance", "feature"]) -> np.ndarray:
57
+ def get(self, ood_type: Literal["instance", "feature"]) -> NDArray:
37
58
  return self.instance_score if ood_type == "instance" or self.feature_score is None else self.feature_score
38
59
 
39
60
 
@@ -48,18 +69,18 @@ class OODBase(ABC):
48
69
  if not isinstance(model, keras.Model):
49
70
  raise TypeError("Model should be of type 'keras.Model'.")
50
71
 
51
- def _get_data_info(self, X: np.ndarray) -> Tuple[tuple, type]:
72
+ def _get_data_info(self, X: NDArray) -> Tuple[tuple, type]:
52
73
  if not isinstance(X, np.ndarray):
53
- raise TypeError("Dataset should of type: `np.ndarray`.")
74
+ raise TypeError("Dataset should of type: `NDArray`.")
54
75
  return X.shape[1:], X.dtype.type
55
76
 
56
- def _validate(self, X: np.ndarray) -> None:
77
+ def _validate(self, X: NDArray) -> None:
57
78
  check_data_info = self._get_data_info(X)
58
79
  if self._data_info is not None and check_data_info != self._data_info:
59
80
  raise RuntimeError(f"Expect data of type: {self._data_info[1]} and shape: {self._data_info[0]}. \
60
81
  Provided data is type: {check_data_info[1]} and shape: {check_data_info[0]}.")
61
82
 
62
- def _validate_state(self, X: np.ndarray, additional_attrs: Optional[List[str]] = None) -> None:
83
+ def _validate_state(self, X: NDArray, additional_attrs: Optional[List[str]] = None) -> None:
63
84
  attrs = ["_data_info", "_threshold_perc", "_ref_score"]
64
85
  attrs = attrs if additional_attrs is None else attrs + additional_attrs
65
86
  if not all(hasattr(self, attr) for attr in attrs) or any(getattr(self, attr) for attr in attrs) is None:
@@ -131,12 +152,13 @@ class OODBase(ABC):
131
152
  self._ref_score = self.score(x_ref, batch_size)
132
153
  self._threshold_perc = threshold_perc
133
154
 
155
+ @set_metadata("dataeval.detectors")
134
156
  def predict(
135
157
  self,
136
158
  X: ArrayLike,
137
159
  batch_size: int = int(1e10),
138
160
  ood_type: Literal["feature", "instance"] = "instance",
139
- ) -> Dict[str, np.ndarray]:
161
+ ) -> OODOutput:
140
162
  """
141
163
  Predict whether instances are out-of-distribution or not.
142
164
 
@@ -156,8 +178,8 @@ class OODBase(ABC):
156
178
  self._validate_state(X := to_numpy(X))
157
179
  # compute outlier scores
158
180
  score = self.score(X, batch_size=batch_size)
159
- ood_pred = (score.get(ood_type) > self._threshold_score(ood_type)).astype(int)
160
- return {**{"is_ood": ood_pred}, **score._asdict()}
181
+ ood_pred = score.get(ood_type) > self._threshold_score(ood_type)
182
+ return OODOutput(is_ood=ood_pred, **score._asdict())
161
183
 
162
184
 
163
185
  class OODGMMBase(OODBase):
@@ -165,7 +187,7 @@ class OODGMMBase(OODBase):
165
187
  super().__init__(model)
166
188
  self.gmm_params: GaussianMixtureModelParams
167
189
 
168
- def _validate_state(self, X: np.ndarray, additional_attrs: Optional[List[str]] = None) -> None:
190
+ def _validate_state(self, X: NDArray, additional_attrs: Optional[List[str]] = None) -> None:
169
191
  if additional_attrs is None:
170
192
  additional_attrs = ["gmm_params"]
171
193
  super()._validate_state(X, additional_attrs)
@@ -14,9 +14,10 @@ import numpy as np
14
14
  import tensorflow as tf
15
15
  from keras.layers import Input
16
16
  from keras.models import Model
17
+ from numpy.typing import ArrayLike, NDArray
17
18
 
18
19
  from dataeval._internal.detectors.ood.base import OODBase, OODScore
19
- from dataeval._internal.interop import ArrayLike, to_numpy
20
+ from dataeval._internal.interop import to_numpy
20
21
  from dataeval._internal.models.tensorflow.pixelcnn import PixelCNN
21
22
  from dataeval._internal.models.tensorflow.trainer import trainer
22
23
  from dataeval._internal.models.tensorflow.utils import predict_batch
@@ -51,7 +52,7 @@ def build_model(
51
52
 
52
53
 
53
54
  def mutate_categorical(
54
- X: np.ndarray,
55
+ X: NDArray,
55
56
  rate: float,
56
57
  seed: int = 0,
57
58
  feature_range: tuple = (0, 255),
@@ -180,7 +181,7 @@ class OOD_LLR(OODBase):
180
181
 
181
182
  # create background data
182
183
  mutate_fn = partial(mutate_fn, **mutate_fn_kwargs)
183
- X_back = predict_batch(x_ref, mutate_fn, batch_size=mutate_batch_size, dtype=x_ref.dtype)
184
+ X_back = predict_batch(x_ref, mutate_fn, batch_size=mutate_batch_size, dtype=x_ref.dtype) # type: ignore
184
185
 
185
186
  # prepare sequential data
186
187
  if self.sequential and not self.has_log_prob:
@@ -220,10 +221,10 @@ class OOD_LLR(OODBase):
220
221
  def _logp(
221
222
  self,
222
223
  dist,
223
- X: np.ndarray,
224
+ X: NDArray,
224
225
  return_per_feature: bool = False,
225
226
  batch_size: int = int(1e10),
226
- ) -> np.ndarray:
227
+ ) -> NDArray:
227
228
  """
228
229
  Compute log probability of a batch of instances under the generative model.
229
230
  """
@@ -234,10 +235,10 @@ class OOD_LLR(OODBase):
234
235
  def _logp_alt(
235
236
  self,
236
237
  model: keras.Model,
237
- X: np.ndarray,
238
+ X: NDArray,
238
239
  return_per_feature: bool = False,
239
240
  batch_size: int = int(1e10),
240
- ) -> np.ndarray:
241
+ ) -> NDArray:
241
242
  """
242
243
  Compute log probability of a batch of instances with the user defined log_prob function.
243
244
  """
@@ -253,7 +254,7 @@ class OOD_LLR(OODBase):
253
254
  axis = tuple(np.arange(len(logp.shape))[1:])
254
255
  return np.mean(logp, axis=axis)
255
256
 
256
- def _llr(self, X: np.ndarray, return_per_feature: bool, batch_size: int = int(1e10)) -> np.ndarray:
257
+ def _llr(self, X: NDArray, return_per_feature: bool, batch_size: int = int(1e10)) -> NDArray:
257
258
  """
258
259
  Compute likelihood ratios.
259
260
 
@@ -10,9 +10,10 @@ from typing import Callable
10
10
 
11
11
  import keras
12
12
  import numpy as np
13
+ from numpy.typing import ArrayLike
13
14
 
14
15
  from dataeval._internal.detectors.ood.base import OODBase, OODScore
15
- from dataeval._internal.interop import ArrayLike, to_numpy
16
+ from dataeval._internal.interop import to_numpy
16
17
  from dataeval._internal.models.tensorflow.autoencoder import VAE
17
18
  from dataeval._internal.models.tensorflow.losses import Elbo
18
19
  from dataeval._internal.models.tensorflow.utils import predict_batch
@@ -10,9 +10,10 @@ from typing import Callable
10
10
 
11
11
  import keras
12
12
  import numpy as np
13
+ from numpy.typing import ArrayLike
13
14
 
14
15
  from dataeval._internal.detectors.ood.base import OODGMMBase, OODScore
15
- from dataeval._internal.interop import ArrayLike, to_numpy
16
+ from dataeval._internal.interop import to_numpy
16
17
  from dataeval._internal.models.tensorflow.autoencoder import VAEGMM
17
18
  from dataeval._internal.models.tensorflow.gmm import gmm_energy
18
19
  from dataeval._internal.models.tensorflow.losses import Elbo, LossGMM
@@ -1,37 +1,31 @@
1
- from enum import Flag, auto
2
- from typing import Union
1
+ from enum import IntFlag, auto
2
+ from functools import reduce
3
+ from typing import Dict, Iterable, TypeVar, Union, cast
3
4
 
5
+ TFlag = TypeVar("TFlag", bound=IntFlag)
4
6
 
5
- class auto_all:
6
- def __get__(self, _, cls):
7
- return ~cls(0)
8
7
 
8
+ class ImageStat(IntFlag):
9
+ """
10
+ Flags for calculating image and channel statistics
11
+ """
9
12
 
10
- class ImageHash(Flag):
13
+ # HASHES
11
14
  XXHASH = auto()
12
15
  PCHASH = auto()
13
- ALL = auto_all()
14
-
15
-
16
- class ImageProperty(Flag):
16
+ # PROPERTIES
17
17
  WIDTH = auto()
18
18
  HEIGHT = auto()
19
19
  SIZE = auto()
20
20
  ASPECT_RATIO = auto()
21
21
  CHANNELS = auto()
22
22
  DEPTH = auto()
23
- ALL = auto_all()
24
-
25
-
26
- class ImageVisuals(Flag):
23
+ # VISUALS
27
24
  BRIGHTNESS = auto()
28
25
  BLURRINESS = auto()
29
26
  MISSING = auto()
30
27
  ZERO = auto()
31
- ALL = auto_all()
32
-
33
-
34
- class ImageStatistics(Flag):
28
+ # PIXEL STATS
35
29
  MEAN = auto()
36
30
  STD = auto()
37
31
  VAR = auto()
@@ -40,8 +34,35 @@ class ImageStatistics(Flag):
40
34
  ENTROPY = auto()
41
35
  PERCENTILES = auto()
42
36
  HISTOGRAM = auto()
43
- ALL = auto_all()
37
+ # JOINT FLAGS
38
+ ALL_HASHES = XXHASH | PCHASH
39
+ ALL_PROPERTIES = WIDTH | HEIGHT | SIZE | ASPECT_RATIO | CHANNELS | DEPTH
40
+ ALL_VISUALS = BRIGHTNESS | BLURRINESS | MISSING | ZERO
41
+ ALL_PIXELSTATS = MEAN | STD | VAR | SKEW | KURTOSIS | ENTROPY | PERCENTILES | HISTOGRAM
42
+ ALL_STATS = ALL_PROPERTIES | ALL_VISUALS | ALL_PIXELSTATS
43
+ ALL = ALL_HASHES | ALL_STATS
44
+
45
+
46
+ def is_distinct(flag: IntFlag) -> bool:
47
+ return (flag & (flag - 1) == 0) and flag != 0
48
+
49
+
50
+ def to_distinct(flag: TFlag) -> Dict[TFlag, str]:
51
+ """
52
+ Returns a distinct set of all flags set on the input flag and their names
53
+
54
+ NOTE: this is supported natively in Python 3.11, but for earlier versions we need
55
+ to use a combination of list comprehension and bit fiddling to determine distinct
56
+ flag values from joint aliases.
57
+ """
58
+ if isinstance(flag, Iterable): # >= py311
59
+ return {f: f.name.lower() for f in flag if f.name}
60
+ else: # < py311
61
+ return {f: f.name.lower() for f in list(flag.__class__) if f & flag and is_distinct(f) and f.name}
44
62
 
45
63
 
46
- ImageStatsFlags = Union[ImageHash, ImageProperty, ImageVisuals, ImageStatistics]
47
- LinterFlags = Union[ImageProperty, ImageVisuals, ImageStatistics]
64
+ def verify_supported(flag: TFlag, flags: Union[TFlag, Iterable[TFlag]]):
65
+ supported = flags if isinstance(flags, flag.__class__) else cast(TFlag, reduce(lambda a, b: a | b, flags)) # type: ignore
66
+ unsupported = flag & ~supported
67
+ if unsupported:
68
+ raise ValueError(f"Unsupported flags {unsupported} called. Only {supported} flags are supported.")
@@ -1,7 +1,8 @@
1
1
  from importlib import import_module
2
- from typing import Any, Iterable, Optional, runtime_checkable
2
+ from typing import Iterable, Optional
3
3
 
4
4
  import numpy as np
5
+ from numpy.typing import ArrayLike, NDArray
5
6
 
6
7
  module_cache = {}
7
8
 
@@ -19,17 +20,7 @@ def try_import(module_name):
19
20
  return module
20
21
 
21
22
 
22
- try:
23
- from maite.protocols import ArrayLike # type: ignore
24
- except ImportError: # pragma: no cover - covered by test_mindeps.py
25
- from typing import Protocol
26
-
27
- @runtime_checkable
28
- class ArrayLike(Protocol):
29
- def __array__(self) -> Any: ...
30
-
31
-
32
- def to_numpy(array: Optional[ArrayLike]) -> np.ndarray:
23
+ def to_numpy(array: Optional[ArrayLike]) -> NDArray:
33
24
  if array is None:
34
25
  return np.ndarray([])
35
26