dataeval 0.64.0__py3-none-any.whl → 0.65.0__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (36)
  1. dataeval/__init__.py +2 -2
  2. dataeval/_internal/detectors/clusterer.py +46 -34
  3. dataeval/_internal/detectors/drift/base.py +52 -35
  4. dataeval/_internal/detectors/drift/cvm.py +4 -4
  5. dataeval/_internal/detectors/drift/ks.py +6 -6
  6. dataeval/_internal/detectors/drift/mmd.py +35 -16
  7. dataeval/_internal/detectors/drift/torch.py +6 -5
  8. dataeval/_internal/detectors/drift/uncertainty.py +7 -7
  9. dataeval/_internal/detectors/duplicates.py +55 -29
  10. dataeval/_internal/detectors/linter.py +40 -24
  11. dataeval/_internal/detectors/ood/base.py +36 -15
  12. dataeval/_internal/detectors/ood/llr.py +7 -7
  13. dataeval/_internal/flags.py +42 -21
  14. dataeval/_internal/interop.py +2 -2
  15. dataeval/_internal/metrics/balance.py +10 -2
  16. dataeval/_internal/metrics/ber.py +6 -5
  17. dataeval/_internal/metrics/coverage.py +15 -8
  18. dataeval/_internal/metrics/divergence.py +41 -7
  19. dataeval/_internal/metrics/diversity.py +17 -12
  20. dataeval/_internal/metrics/parity.py +30 -43
  21. dataeval/_internal/metrics/stats.py +196 -317
  22. dataeval/_internal/metrics/uap.py +5 -2
  23. dataeval/_internal/metrics/utils.py +70 -33
  24. dataeval/_internal/models/tensorflow/losses.py +3 -3
  25. dataeval/_internal/models/tensorflow/trainer.py +3 -2
  26. dataeval/_internal/models/tensorflow/utils.py +4 -3
  27. dataeval/_internal/output.py +82 -0
  28. dataeval/_internal/workflows/sufficiency.py +96 -107
  29. dataeval/flags/__init__.py +2 -2
  30. dataeval/metrics/__init__.py +3 -3
  31. {dataeval-0.64.0.dist-info → dataeval-0.65.0.dist-info}/METADATA +1 -1
  32. dataeval-0.65.0.dist-info/RECORD +60 -0
  33. dataeval/_internal/metrics/base.py +0 -10
  34. dataeval-0.64.0.dist-info/RECORD +0 -60
  35. {dataeval-0.64.0.dist-info → dataeval-0.65.0.dist-info}/LICENSE.txt +0 -0
  36. {dataeval-0.64.0.dist-info → dataeval-0.65.0.dist-info}/WHEEL +0 -0
dataeval/_internal/detectors/duplicates.py

@@ -1,9 +1,27 @@
-from typing import Dict, Iterable, List, Literal
+from dataclasses import dataclass
+from typing import Dict, Iterable, List
 
 from numpy.typing import ArrayLike
 
-from dataeval._internal.flags import ImageHash
-from dataeval._internal.metrics.stats import ImageStats
+from dataeval._internal.metrics.stats import StatsOutput
+from dataeval._internal.output import OutputMetadata, set_metadata
+from dataeval.flags import ImageStat
+from dataeval.metrics import imagestats
+
+
+@dataclass(frozen=True)
+class DuplicatesOutput(OutputMetadata):
+    """
+    Attributes
+    ----------
+    exact : List[List[int]]
+        Indices of images that are exact matches
+    near: List[List[int]]
+        Indices of images that are near matches
+    """
+
+    exact: List[List[int]]
+    near: List[List[int]]
 
 
 class Duplicates:
@@ -13,8 +31,8 @@ class Duplicates:
 
     Attributes
    ----------
-    stats : ImageStats(flags=ImageHash.ALL)
-        Base stats class with the flags for checking duplicates
+    stats : StatsOutput
+        Output class of stats
 
     Example
    -------
@@ -23,25 +41,36 @@ class Duplicates:
     >>> dups = Duplicates()
     """
 
-    def __init__(self):
-        self.stats = ImageStats(ImageHash.ALL)
+    def __init__(self, find_exact: bool = True, find_near: bool = True):
+        self.stats: StatsOutput
+        self.find_exact = find_exact
+        self.find_near = find_near
 
-    def _get_duplicates(self) -> dict:
-        exact = {}
-        near = {}
-        for i, value in enumerate(self.results["xxhash"]):
-            exact.setdefault(value, []).append(i)
-        for i, value in enumerate(self.results["pchash"]):
-            near.setdefault(value, []).append(i)
-        exact = [v for v in exact.values() if len(v) > 1]
-        near = [v for v in near.values() if len(v) > 1 and not any(set(v).issubset(x) for x in exact)]
+    def _get_duplicates(self) -> Dict[str, List[List[int]]]:
+        stats_dict = self.stats.dict()
+        if "xxhash" in stats_dict:
+            exact = {}
+            for i, value in enumerate(stats_dict["xxhash"]):
+                exact.setdefault(value, []).append(i)
+            exact = [v for v in exact.values() if len(v) > 1]
+        else:
+            exact = []
+
+        if "pchash" in stats_dict:
+            near = {}
+            for i, value in enumerate(stats_dict["pchash"]):
+                near.setdefault(value, []).append(i)
+            near = [v for v in near.values() if len(v) > 1 and not any(set(v).issubset(x) for x in exact)]
+        else:
+            near = []
 
         return {
             "exact": sorted(exact),
             "near": sorted(near),
         }
 
-    def evaluate(self, images: Iterable[ArrayLike]) -> Dict[Literal["exact", "near"], List[int]]:
+    @set_metadata("dataeval.detectors", ["find_exact", "find_near"])
+    def evaluate(self, images: Iterable[ArrayLike]) -> DuplicatesOutput:
         """
         Returns duplicate image indices for both exact matches and near matches
 
@@ -52,22 +81,19 @@ class Duplicates:
 
         Returns
        -------
-        Dict[str, List[int]]
-            exact :
-                List of groups of indices that are exact matches
-            near :
-                List of groups of indices that are near matches
+        DuplicatesOutput
+            List of groups of indices that are exact and near matches
 
         See Also
        --------
-        ImageStats
+        imagestats
 
         Example
        -------
        >>> dups.evaluate(images)
-        {'exact': [[3, 20], [16, 37]], 'near': [[3, 20, 22], [12, 18], [13, 36], [14, 31], [17, 27], [19, 38, 47]]}
-        """
-        self.stats.reset()
-        self.stats.update(images)
-        self.results = self.stats.compute()
-        return self._get_duplicates()
+        DuplicatesOutput(exact=[[3, 20], [16, 37]], near=[[3, 20, 22], [12, 18], [13, 36], [14, 31], [17, 27], [19, 38, 47]])
+        """  # noqa: E501
+        flag_exact = ImageStat.XXHASH if self.find_exact else ImageStat(0)
+        flag_near = ImageStat.PCHASH if self.find_near else ImageStat(0)
+        self.stats = imagestats(images, flag_exact | flag_near)
+        return DuplicatesOutput(**self._get_duplicates())
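Note on the change above: evaluate() now returns a frozen DuplicatesOutput dataclass instead of a plain dict, and hashing can be limited through the constructor. A minimal usage sketch; the public import path dataeval.detectors is an assumption not confirmed by this diff:

    import numpy as np
    from dataeval.detectors import Duplicates  # assumed public path

    # Hypothetical dataset with one deliberate exact duplicate
    images = np.random.randint(0, 255, size=(8, 3, 16, 16), dtype=np.uint8)
    images[5] = images[2]

    dups = Duplicates(find_exact=True, find_near=False)  # find_near=False maps to ImageStat(0)
    result = dups.evaluate(images)
    print(result.exact)  # [[2, 5]] -- groups of indices sharing an xxhash
    print(result.near)   # []      -- pchash was never computed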
dataeval/_internal/detectors/linter.py

@@ -1,15 +1,31 @@
-from typing import Iterable, Literal, Optional, Sequence, Union
+from dataclasses import dataclass
+from typing import Dict, Iterable, Literal, Optional
 
 import numpy as np
-from numpy.typing import ArrayLike
+from numpy.typing import ArrayLike, NDArray
 
-from dataeval._internal.flags import ImageProperty, ImageVisuals, LinterFlags
-from dataeval._internal.metrics.stats import ImageStats
+from dataeval._internal.flags import verify_supported
+from dataeval._internal.output import OutputMetadata, set_metadata
+from dataeval.flags import ImageStat
+from dataeval.metrics import imagestats
+
+
+@dataclass(frozen=True)
+class LinterOutput(OutputMetadata):
+    """
+    Attributes
+    ----------
+    issues : Dict[int, Dict[str, float]]
+        Dictionary containing the indices of outliers and a dictionary showing
+        the issues and calculated values for the given index.
+    """
+
+    issues: Dict[int, Dict[str, float]]
 
 
 def _get_outlier_mask(
-    values: np.ndarray, method: Literal["zscore", "modzscore", "iqr"], threshold: Optional[float]
-) -> np.ndarray:
+    values: NDArray, method: Literal["zscore", "modzscore", "iqr"], threshold: Optional[float]
+) -> NDArray:
     if method == "zscore":
         threshold = threshold if threshold else 3.0
         std = np.std(values)
@@ -18,7 +34,7 @@ def _get_outlier_mask(
     elif method == "modzscore":
         threshold = threshold if threshold else 3.5
         abs_diff = np.abs(values - np.median(values))
-        med_abs_diff = np.median(abs_diff)
+        med_abs_diff = np.median(abs_diff) if np.median(abs_diff) != 0 else np.mean(abs_diff)
         mod_z_score = 0.6745 * abs_diff / med_abs_diff
         return mod_z_score > threshold
     elif method == "iqr":
@@ -36,8 +52,9 @@ class Linter:
 
     Parameters
    ----------
-    flags : [ImageProperty | ImageStatistics | ImageVisuals], default None
+    flags : ImageStat, default ImageStat.ALL_PROPERTIES | ImageStat.ALL_VISUALS
         Metric(s) to calculate for each image - calculates all metrics if None
+        Only supports ImageStat.ALL_STATS
     outlier_method : ["modzscore" | "zscore" | "iqr"], optional - default "modzscore"
         Statistical method used to identify outliers
     outlier_threshold : float, optional - default None
@@ -46,8 +63,8 @@ class Linter:
 
     Attributes
    ----------
-    stats : ImageStats
-        Class to hold the value of each metric for each image
+    stats : Dict[str, Any]
+        Dictionary to hold the value of each metric for each image
 
     See Also
    --------
@@ -81,7 +98,7 @@ class Linter:
 
     Specifying specific metrics to analyze:
 
-    >>> lint = Linter(flags=[ImageProperty.SIZE, ImageVisuals.ALL])
+    >>> lint = Linter(flags=ImageStat.SIZE | ImageStat.ALL_VISUALS)
 
     Specifying an outlier method:
 
@@ -94,19 +111,19 @@
 
     def __init__(
         self,
-        flags: Optional[Union[LinterFlags, Sequence[LinterFlags]]] = None,
+        flags: ImageStat = ImageStat.ALL_PROPERTIES | ImageStat.ALL_VISUALS,
         outlier_method: Literal["zscore", "modzscore", "iqr"] = "modzscore",
         outlier_threshold: Optional[float] = None,
     ):
-        flags = flags if flags is not None else (ImageProperty.ALL, ImageVisuals.ALL)
-        self.stats = ImageStats(flags)
+        verify_supported(flags, ImageStat.ALL_STATS)
+        self.flags = flags
         self.outlier_method: Literal["zscore", "modzscore", "iqr"] = outlier_method
         self.outlier_threshold = outlier_threshold
 
     def _get_outliers(self) -> dict:
         flagged_images = {}
-
-        for stat, values in self.results.items():
+        stats_dict = self.stats.dict()
+        for stat, values in stats_dict.items():
             if not isinstance(values, np.ndarray):
                 continue
 
@@ -118,7 +135,8 @@ class Linter:
 
         return dict(sorted(flagged_images.items()))
 
-    def evaluate(self, images: Iterable[ArrayLike]) -> dict:
+    @set_metadata("dataeval.detectors", ["flags", "outlier_method", "outlier_threshold"])
+    def evaluate(self, images: Iterable[ArrayLike]) -> LinterOutput:
         """
         Returns indices of outliers with the issues identified for each
 
@@ -130,8 +148,8 @@ class Linter:
 
         Returns
        -------
-        Dict[int, Dict[str, float]]
-            Dictionary containing the indices of outliers and a dictionary showing
+        LinterOutput
+            Output class containing the indices of outliers and a dictionary showing
             the issues and calculated values for the given index.
 
         Example
@@ -139,9 +157,7 @@ class Linter:
        Evaluate the dataset:
 
        >>> lint.evaluate(images)
-        {18: {'brightness': 0.78}, 25: {'brightness': 0.98}}
+        LinterOutput(issues={18: {'brightness': 0.78}, 25: {'brightness': 0.98}})
        """
-        self.stats.reset()
-        self.stats.update(images)
-        self.results = self.stats.compute()
-        return self._get_outliers()
+        self.stats = imagestats(images, self.flags)
+        return LinterOutput(self._get_outliers())
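The Linter now takes a single ImageStat bitmask (validated against ImageStat.ALL_STATS) rather than a sequence of separate flag enums. A sketch under the same assumed dataeval.detectors import path:

    import numpy as np
    from dataeval.detectors import Linter  # assumed public path
    from dataeval.flags import ImageStat

    images = np.random.random((40, 3, 64, 64))

    lint = Linter(flags=ImageStat.BRIGHTNESS | ImageStat.BLURRINESS, outlier_method="modzscore")
    result = lint.evaluate(images)
    print(result.issues)  # e.g. {18: {'brightness': 0.78}, 25: {'brightness': 0.98}}

    # Hash flags are rejected up front since only ImageStat.ALL_STATS members are supported:
    # Linter(flags=ImageStat.XXHASH)  ->  ValueError raised by verify_supported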
dataeval/_internal/detectors/ood/base.py

@@ -7,16 +7,36 @@ Licensed under Apache Software License (Apache 2.0)
 """
 
 from abc import ABC, abstractmethod
-from typing import Callable, Dict, List, Literal, NamedTuple, Optional, Tuple, cast
+from dataclasses import dataclass
+from typing import Callable, List, Literal, NamedTuple, Optional, Tuple, cast
 
 import keras
 import numpy as np
 import tensorflow as tf
-from numpy.typing import ArrayLike
+from numpy.typing import ArrayLike, NDArray
 
 from dataeval._internal.interop import to_numpy
 from dataeval._internal.models.tensorflow.gmm import GaussianMixtureModelParams, gmm_params
 from dataeval._internal.models.tensorflow.trainer import trainer
+from dataeval._internal.output import OutputMetadata, set_metadata
+
+
+@dataclass(frozen=True)
+class OODOutput(OutputMetadata):
+    """
+    Attributes
+    ----------
+    is_ood : NDArray[np.bool_]
+        Array of images that are detected as out of distribution
+    instance_score : NDArray[np.float32]
+        Instance score of the evaluated dataset
+    feature_score : Optional[NDArray[np.float32]]
+        Feature score, if available, of the evaluated dataset
+    """
+
+    is_ood: NDArray[np.bool_]
+    instance_score: NDArray[np.float32]
+    feature_score: Optional[NDArray[np.float32]]
 
 
 class OODScore(NamedTuple):
@@ -25,16 +45,16 @@ class OODScore(NamedTuple):
 
     Parameters
    ----------
-    instance_score : np.ndarray
+    instance_score : NDArray[np.float32]
         Instance score of the evaluated dataset.
-    feature_score : Optional[np.ndarray], default None
+    feature_score : Optional[NDArray[np.float32]], default None
         Feature score, if available, of the evaluated dataset.
     """
 
-    instance_score: np.ndarray
-    feature_score: Optional[np.ndarray] = None
+    instance_score: NDArray[np.float32]
+    feature_score: Optional[NDArray[np.float32]] = None
 
-    def get(self, ood_type: Literal["instance", "feature"]) -> np.ndarray:
+    def get(self, ood_type: Literal["instance", "feature"]) -> NDArray:
         return self.instance_score if ood_type == "instance" or self.feature_score is None else self.feature_score
 
 
@@ -49,18 +69,18 @@ class OODBase(ABC):
         if not isinstance(model, keras.Model):
             raise TypeError("Model should be of type 'keras.Model'.")
 
-    def _get_data_info(self, X: np.ndarray) -> Tuple[tuple, type]:
+    def _get_data_info(self, X: NDArray) -> Tuple[tuple, type]:
         if not isinstance(X, np.ndarray):
-            raise TypeError("Dataset should of type: `np.ndarray`.")
+            raise TypeError("Dataset should of type: `NDArray`.")
         return X.shape[1:], X.dtype.type
 
-    def _validate(self, X: np.ndarray) -> None:
+    def _validate(self, X: NDArray) -> None:
         check_data_info = self._get_data_info(X)
         if self._data_info is not None and check_data_info != self._data_info:
             raise RuntimeError(f"Expect data of type: {self._data_info[1]} and shape: {self._data_info[0]}. \
             Provided data is type: {check_data_info[1]} and shape: {check_data_info[0]}.")
 
-    def _validate_state(self, X: np.ndarray, additional_attrs: Optional[List[str]] = None) -> None:
+    def _validate_state(self, X: NDArray, additional_attrs: Optional[List[str]] = None) -> None:
         attrs = ["_data_info", "_threshold_perc", "_ref_score"]
         attrs = attrs if additional_attrs is None else attrs + additional_attrs
         if not all(hasattr(self, attr) for attr in attrs) or any(getattr(self, attr) for attr in attrs) is None:
@@ -132,12 +152,13 @@ class OODBase(ABC):
         self._ref_score = self.score(x_ref, batch_size)
         self._threshold_perc = threshold_perc
 
+    @set_metadata("dataeval.detectors")
     def predict(
         self,
         X: ArrayLike,
         batch_size: int = int(1e10),
         ood_type: Literal["feature", "instance"] = "instance",
-    ) -> Dict[str, np.ndarray]:
+    ) -> OODOutput:
         """
         Predict whether instances are out-of-distribution or not.
 
@@ -157,8 +178,8 @@
         self._validate_state(X := to_numpy(X))
         # compute outlier scores
         score = self.score(X, batch_size=batch_size)
-        ood_pred = (score.get(ood_type) > self._threshold_score(ood_type)).astype(int)
-        return {**{"is_ood": ood_pred}, **score._asdict()}
+        ood_pred = score.get(ood_type) > self._threshold_score(ood_type)
+        return OODOutput(is_ood=ood_pred, **score._asdict())
 
 
 class OODGMMBase(OODBase):
@@ -166,7 +187,7 @@ class OODGMMBase(OODBase):
         super().__init__(model)
         self.gmm_params: GaussianMixtureModelParams
 
-    def _validate_state(self, X: np.ndarray, additional_attrs: Optional[List[str]] = None) -> None:
+    def _validate_state(self, X: NDArray, additional_attrs: Optional[List[str]] = None) -> None:
         if additional_attrs is None:
             additional_attrs = ["gmm_params"]
         super()._validate_state(X, additional_attrs)
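The OODScore.get fallback and the new boolean OODOutput contract are easy to miss in diff form. A minimal sketch, importing from the internal module shown above (a public re-export may exist but is not confirmed by this diff):

    import numpy as np
    from dataeval._internal.detectors.ood.base import OODScore

    score = OODScore(instance_score=np.array([1.5, 0.2], dtype=np.float32))
    # get() falls back to the instance score whenever no feature score was computed
    assert score.get("feature") is score.instance_score

    # In 0.65.0 OODBase.predict wraps scores in OODOutput with a boolean mask,
    # where 0.64.0 returned a dict with is_ood cast to int:
    #   output = detector.predict(x_test)         # OODOutput
    #   flagged = np.flatnonzero(output.is_ood)   # indices of detected OOD images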
dataeval/_internal/detectors/ood/llr.py

@@ -14,7 +14,7 @@ import numpy as np
 import tensorflow as tf
 from keras.layers import Input
 from keras.models import Model
-from numpy.typing import ArrayLike
+from numpy.typing import ArrayLike, NDArray
 
 from dataeval._internal.detectors.ood.base import OODBase, OODScore
 from dataeval._internal.interop import to_numpy
@@ -52,7 +52,7 @@ def build_model(
 
 
 def mutate_categorical(
-    X: np.ndarray,
+    X: NDArray,
     rate: float,
     seed: int = 0,
     feature_range: tuple = (0, 255),
@@ -221,10 +221,10 @@ class OOD_LLR(OODBase):
     def _logp(
         self,
         dist,
-        X: np.ndarray,
+        X: NDArray,
         return_per_feature: bool = False,
         batch_size: int = int(1e10),
-    ) -> np.ndarray:
+    ) -> NDArray:
         """
         Compute log probability of a batch of instances under the generative model.
         """
@@ -235,10 +235,10 @@ class OOD_LLR(OODBase):
     def _logp_alt(
         self,
         model: keras.Model,
-        X: np.ndarray,
+        X: NDArray,
         return_per_feature: bool = False,
         batch_size: int = int(1e10),
-    ) -> np.ndarray:
+    ) -> NDArray:
         """
         Compute log probability of a batch of instances with the user defined log_prob function.
         """
@@ -254,7 +254,7 @@ class OOD_LLR(OODBase):
         axis = tuple(np.arange(len(logp.shape))[1:])
         return np.mean(logp, axis=axis)
 
-    def _llr(self, X: np.ndarray, return_per_feature: bool, batch_size: int = int(1e10)) -> np.ndarray:
+    def _llr(self, X: NDArray, return_per_feature: bool, batch_size: int = int(1e10)) -> NDArray:
         """
         Compute likelihood ratios.
 
dataeval/_internal/flags.py

@@ -1,37 +1,31 @@
-from enum import Flag, auto
-from typing import Union
+from enum import IntFlag, auto
+from functools import reduce
+from typing import Dict, Iterable, TypeVar, Union, cast
 
+TFlag = TypeVar("TFlag", bound=IntFlag)
 
-class auto_all:
-    def __get__(self, _, cls):
-        return ~cls(0)
 
+class ImageStat(IntFlag):
+    """
+    Flags for calculating image and channel statistics
+    """
 
-class ImageHash(Flag):
+    # HASHES
     XXHASH = auto()
     PCHASH = auto()
-    ALL = auto_all()
-
-
-class ImageProperty(Flag):
+    # PROPERTIES
     WIDTH = auto()
     HEIGHT = auto()
     SIZE = auto()
     ASPECT_RATIO = auto()
     CHANNELS = auto()
     DEPTH = auto()
-    ALL = auto_all()
-
-
-class ImageVisuals(Flag):
+    # VISUALS
     BRIGHTNESS = auto()
     BLURRINESS = auto()
     MISSING = auto()
     ZERO = auto()
-    ALL = auto_all()
-
-
-class ImageStatistics(Flag):
+    # PIXEL STATS
     MEAN = auto()
     STD = auto()
     VAR = auto()
@@ -40,8 +34,35 @@ class ImageStatistics(Flag):
     ENTROPY = auto()
     PERCENTILES = auto()
     HISTOGRAM = auto()
-    ALL = auto_all()
+    # JOINT FLAGS
+    ALL_HASHES = XXHASH | PCHASH
+    ALL_PROPERTIES = WIDTH | HEIGHT | SIZE | ASPECT_RATIO | CHANNELS | DEPTH
+    ALL_VISUALS = BRIGHTNESS | BLURRINESS | MISSING | ZERO
+    ALL_PIXELSTATS = MEAN | STD | VAR | SKEW | KURTOSIS | ENTROPY | PERCENTILES | HISTOGRAM
+    ALL_STATS = ALL_PROPERTIES | ALL_VISUALS | ALL_PIXELSTATS
+    ALL = ALL_HASHES | ALL_STATS
+
+
+def is_distinct(flag: IntFlag) -> bool:
+    return (flag & (flag - 1) == 0) and flag != 0
+
+
+def to_distinct(flag: TFlag) -> Dict[TFlag, str]:
+    """
+    Returns a distinct set of all flags set on the input flag and their names
+
+    NOTE: this is supported natively in Python 3.11, but for earlier versions we need
+    to use a combination of list comprehension and bit fiddling to determine distinct
+    flag values from joint aliases.
+    """
+    if isinstance(flag, Iterable):  # >= py311
+        return {f: f.name.lower() for f in flag if f.name}
+    else:  # < py311
+        return {f: f.name.lower() for f in list(flag.__class__) if f & flag and is_distinct(f) and f.name}
 
 
-ImageStatsFlags = Union[ImageHash, ImageProperty, ImageVisuals, ImageStatistics]
-LinterFlags = Union[ImageProperty, ImageVisuals, ImageStatistics]
+def verify_supported(flag: TFlag, flags: Union[TFlag, Iterable[TFlag]]):
+    supported = flags if isinstance(flags, flag.__class__) else cast(TFlag, reduce(lambda a, b: a | b, flags))  # type: ignore
+    unsupported = flag & ~supported
+    if unsupported:
+        raise ValueError(f"Unsupported flags {unsupported} called. Only {supported} flags are supported.")
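The single ImageStat IntFlag replaces the four separate Flag enums, and the helpers below it do the version-dependent bit fiddling. A short runnable sketch of their behavior:

    from dataeval._internal.flags import ImageStat, is_distinct, to_distinct, verify_supported

    flags = ImageStat.ALL_HASHES | ImageStat.BRIGHTNESS

    assert is_distinct(ImageStat.XXHASH)  # exactly one bit set
    assert not is_distinct(flags)         # composite value

    # Joint aliases expand to their single-bit members on any supported Python version:
    print(to_distinct(flags))  # maps XXHASH, PCHASH, BRIGHTNESS to 'xxhash', 'pchash', 'brightness'

    verify_supported(ImageStat.BRIGHTNESS, ImageStat.ALL_STATS)  # ok: within the supported mask
    # verify_supported(ImageStat.XXHASH, ImageStat.ALL_STATS)    # ValueError: hashes are not stats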
dataeval/_internal/interop.py

@@ -2,7 +2,7 @@ from importlib import import_module
 from typing import Iterable, Optional
 
 import numpy as np
-from numpy.typing import ArrayLike
+from numpy.typing import ArrayLike, NDArray
 
 module_cache = {}
 
@@ -20,7 +20,7 @@ def try_import(module_name):
     return module
 
 
-def to_numpy(array: Optional[ArrayLike]) -> np.ndarray:
+def to_numpy(array: Optional[ArrayLike]) -> NDArray:
     if array is None:
         return np.ndarray([])
 
dataeval/_internal/metrics/balance.py

@@ -1,14 +1,17 @@
 import warnings
-from typing import Dict, List, NamedTuple, Sequence
+from dataclasses import dataclass
+from typing import Dict, List, Sequence
 
 import numpy as np
 from numpy.typing import NDArray
 from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
 
 from dataeval._internal.metrics.utils import entropy, preprocess_metadata
+from dataeval._internal.output import OutputMetadata, set_metadata
 
 
-class BalanceOutput(NamedTuple):
+@dataclass(frozen=True)
+class BalanceOutput(OutputMetadata):
     """
     Attributes
    ----------
@@ -39,6 +42,7 @@ def validate_num_neighbors(num_neighbors: int) -> int:
     return num_neighbors
 
 
+@set_metadata("dataeval.metrics")
 def balance(class_labels: Sequence[int], metadata: List[Dict], num_neighbors: int = 5) -> BalanceOutput:
     """
     Mutual information (MI) between factors (class label, metadata, label/image properties)
@@ -83,6 +87,9 @@ def balance(class_labels: Sequence[int], metadata: List[Dict], num_neighbors: in
         tgt = data[:, idx]
 
         if is_categorical[idx]:
+            if tgt.dtype == float:
+                # map to unique integers if categorical
+                _, tgt = np.unique(tgt, return_inverse=True)
             # categorical target
             mi[idx, :] = mutual_info_classif(
                 data,
@@ -107,6 +114,7 @@ def balance(class_labels: Sequence[int], metadata: List[Dict], num_neighbors: in
     return BalanceOutput(nmi)
 
 
+@set_metadata("dataeval.metrics")
 def balance_classwise(class_labels: Sequence[int], metadata: List[Dict], num_neighbors: int = 5) -> BalanceOutput:
     """
     Compute mutual information (analogous to correlation) between metadata factors
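The three lines added in balance() guard mutual_info_classif against categorical targets stored as floats. The np.unique(return_inverse=True) idiom in isolation:

    import numpy as np

    tgt = np.array([0.25, 0.5, 0.25, 1.0, 0.5])  # categorical, but stored as float
    _, tgt = np.unique(tgt, return_inverse=True)
    print(tgt)  # [0 1 0 2 1] -- contiguous integer codes, as mutual_info_classif expects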
dataeval/_internal/metrics/ber.py

@@ -7,7 +7,8 @@ Learning to Bound the Multi-class Bayes Error (Th. 3 and Th. 4)
 https://arxiv.org/abs/1811.06419
 """
 
-from typing import Literal, NamedTuple, Tuple
+from dataclasses import dataclass
+from typing import Literal, Tuple
 
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
@@ -16,9 +17,11 @@ from scipy.stats import mode
 
 from dataeval._internal.interop import to_numpy
 from dataeval._internal.metrics.utils import compute_neighbors, get_classes_counts, get_method, minimum_spanning_tree
+from dataeval._internal.output import OutputMetadata, set_metadata
 
 
-class BEROutput(NamedTuple):
+@dataclass(frozen=True)
+class BEROutput(OutputMetadata):
     """
     Attributes
    ----------
@@ -73,9 +76,6 @@ def ber_knn(X: NDArray, y: NDArray, k: int) -> Tuple[float, float]:
         The upper and lower bounds of the bayes error rate
     """
     M, N = get_classes_counts(y)
-
-    # All features belong on second dimension
-    X = X.reshape((X.shape[0], -1))
     nn_indices = compute_neighbors(X, X, k=k)
     nn_indices = np.expand_dims(nn_indices, axis=1) if nn_indices.ndim == 1 else nn_indices
     modal_class = mode(y[nn_indices], axis=1, keepdims=True).mode.squeeze()
@@ -107,6 +107,7 @@ def knn_lowerbound(value: float, classes: int, k: int) -> float:
 BER_FN_MAP = {"KNN": ber_knn, "MST": ber_mst}
 
 
+@set_metadata("dataeval.metrics")
 def ber(images: ArrayLike, labels: ArrayLike, k: int = 1, method: Literal["KNN", "MST"] = "KNN") -> BEROutput:
     """
     An estimator for Multi-class Bayes Error Rate using FR or KNN test statistic basis
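With BEROutput now a frozen dataclass and the metadata decorator applied, a call looks like the sketch below; the dataeval.metrics import path and the random data are assumptions. Since ber_knn no longer reshapes internally (see the removed lines above), this sketch passes pre-flattened features:

    import numpy as np
    from dataeval.metrics import ber  # assumed public path

    images = np.random.random((100, 16))    # n_samples x n_features
    labels = np.random.randint(0, 3, 100)   # three classes

    result = ber(images, labels, k=1, method="KNN")
    print(result)  # BEROutput dataclass holding the upper and lower BER bound estimates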