dataeval 0.69.3__py3-none-any.whl → 0.70.0__py3-none-any.whl

Files changed (35)
  1. dataeval/__init__.py +3 -3
  2. dataeval/_internal/datasets.py +300 -0
  3. dataeval/_internal/detectors/drift/base.py +5 -6
  4. dataeval/_internal/detectors/drift/mmd.py +3 -3
  5. dataeval/_internal/detectors/duplicates.py +62 -45
  6. dataeval/_internal/detectors/merged_stats.py +23 -54
  7. dataeval/_internal/detectors/ood/ae.py +3 -3
  8. dataeval/_internal/detectors/outliers.py +133 -61
  9. dataeval/_internal/interop.py +11 -7
  10. dataeval/_internal/metrics/balance.py +9 -9
  11. dataeval/_internal/metrics/ber.py +3 -3
  12. dataeval/_internal/metrics/divergence.py +3 -3
  13. dataeval/_internal/metrics/diversity.py +6 -6
  14. dataeval/_internal/metrics/parity.py +24 -16
  15. dataeval/_internal/metrics/stats/base.py +231 -0
  16. dataeval/_internal/metrics/stats/boxratiostats.py +159 -0
  17. dataeval/_internal/metrics/stats/datasetstats.py +97 -0
  18. dataeval/_internal/metrics/stats/dimensionstats.py +111 -0
  19. dataeval/_internal/metrics/stats/hashstats.py +73 -0
  20. dataeval/_internal/metrics/stats/labelstats.py +125 -0
  21. dataeval/_internal/metrics/stats/pixelstats.py +117 -0
  22. dataeval/_internal/metrics/stats/visualstats.py +122 -0
  23. dataeval/_internal/metrics/uap.py +2 -2
  24. dataeval/_internal/metrics/utils.py +28 -13
  25. dataeval/_internal/output.py +3 -18
  26. dataeval/_internal/workflows/sufficiency.py +123 -133
  27. dataeval/metrics/stats/__init__.py +14 -3
  28. dataeval/workflows/__init__.py +2 -2
  29. {dataeval-0.69.3.dist-info → dataeval-0.70.0.dist-info}/METADATA +3 -2
  30. {dataeval-0.69.3.dist-info → dataeval-0.70.0.dist-info}/RECORD +32 -26
  31. {dataeval-0.69.3.dist-info → dataeval-0.70.0.dist-info}/WHEEL +1 -1
  32. dataeval/_internal/flags.py +0 -77
  33. dataeval/_internal/metrics/stats.py +0 -397
  34. dataeval/flags/__init__.py +0 -3
  35. {dataeval-0.69.3.dist-info → dataeval-0.70.0.dist-info}/LICENSE.txt +0 -0

--- a/dataeval/_internal/detectors/outliers.py
+++ b/dataeval/_internal/detectors/outliers.py
@@ -1,39 +1,45 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Iterable, Literal, Sequence, cast
-from warnings import warn
+from typing import Generic, Iterable, Literal, Sequence, TypeVar, Union, overload
 
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
 
 from dataeval._internal.detectors.merged_stats import combine_stats, get_dataset_step_from_idx
-from dataeval._internal.flags import ImageStat, to_distinct, verify_supported
-from dataeval._internal.metrics.stats import StatsOutput, imagestats
+from dataeval._internal.metrics.stats.base import BOX_COUNT, SOURCE_INDEX
+from dataeval._internal.metrics.stats.datasetstats import DatasetStatsOutput, datasetstats
+from dataeval._internal.metrics.stats.dimensionstats import DimensionStatsOutput
+from dataeval._internal.metrics.stats.pixelstats import PixelStatsOutput
+from dataeval._internal.metrics.stats.visualstats import VisualStatsOutput
 from dataeval._internal.output import OutputMetadata, set_metadata
 
 IndexIssueMap = dict[int, dict[str, float]]
-DatasetIndexIssueMap = dict[int, IndexIssueMap]
-"""
-Mapping of image indices to a dictionary of issue types and calculated values
-"""
+OutlierStatsOutput = Union[DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput]
+TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, list[IndexIssueMap])
 
 
 @dataclass(frozen=True)
-class OutliersOutput(OutputMetadata):
+class OutliersOutput(Generic[TIndexIssueMap], OutputMetadata):
     """
     Attributes
     ----------
-    issues : dict[int, dict[str, float]] | dict[int, dict[int, dict[str, float]]]
+    issues : dict[int, dict[str, float]] | list[dict[int, dict[str, float]]]
         Indices of image outliers with their associated issue type and calculated values.
 
         - For a single dataset, a dictionary containing the indices of outliers and
           a dictionary showing the issues and calculated values for the given index.
-        - For multiple datasets, a map of dataset indices to the indices of outliers
-          and their associated issues and calculated values.
+        - For multiple stats outputs, a list of dictionaries containing the indices of
+          outliers and their associated issues and calculated values.
     """
 
-    issues: IndexIssueMap | DatasetIndexIssueMap
+    issues: TIndexIssueMap
+
+    def __len__(self):
+        if isinstance(self.issues, dict):
+            return len(self.issues)
+        else:
+            return sum(len(d) for d in self.issues)
 
 
 def _get_outlier_mask(
@@ -43,7 +49,7 @@ def _get_outlier_mask(
         threshold = threshold if threshold else 3.0
         std = np.std(values)
         abs_diff = np.abs(values - np.mean(values))
-        return (abs_diff / std) > threshold
+        return std != 0 and (abs_diff / std) > threshold
     elif method == "modzscore":
         threshold = threshold if threshold else 3.5
         abs_diff = np.abs(values - np.median(values))
@@ -65,9 +71,6 @@ class Outliers:
 
     Parameters
     ----------
-    flags : ImageStat, default ImageStat.ALL_PROPERTIES | ImageStat.ALL_VISUALS
-        Metric(s) to calculate for each image - calculates all metrics if None
-        Only supports ImageStat.ALL_STATS
     outlier_method : ["modzscore" | "zscore" | "iqr"], optional - default "modzscore"
         Statistical method used to identify outliers
     outlier_threshold : float, optional - default None
@@ -76,8 +79,8 @@ class Outliers:
 
     Attributes
     ----------
-    stats : dict[str, Any]
-        Dictionary to hold the value of each metric for each image
+    stats : tuple[DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput]
+        Various stats output classes that hold the value of each metric for each image
 
     See Also
    --------
@@ -109,52 +112,61 @@ class Outliers:
 
     >>> outliers = Outliers()
 
-    Specifying specific metrics to analyze:
-
-    >>> outliers = Outliers(flags=ImageStat.SIZE | ImageStat.ALL_VISUALS)
-
     Specifying an outlier method:
 
     >>> outliers = Outliers(outlier_method="iqr")
 
     Specifying an outlier method and threshold:
 
-    >>> outliers = Outliers(outlier_method="zscore", outlier_threshold=2.75)
+    >>> outliers = Outliers(outlier_method="zscore", outlier_threshold=3.5)
     """
 
     def __init__(
         self,
-        flags: ImageStat = ImageStat.ALL_PROPERTIES | ImageStat.ALL_VISUALS,
+        use_dimension: bool = True,
+        use_pixel: bool = True,
+        use_visual: bool = True,
         outlier_method: Literal["zscore", "modzscore", "iqr"] = "modzscore",
         outlier_threshold: float | None = None,
     ):
-        verify_supported(flags, ImageStat.ALL_STATS)
-        self.flags = flags
+        self.stats: DatasetStatsOutput
+        self.use_dimension = use_dimension
+        self.use_pixel = use_pixel
+        self.use_visual = use_visual
        self.outlier_method: Literal["zscore", "modzscore", "iqr"] = outlier_method
         self.outlier_threshold = outlier_threshold
 
-    def _get_outliers(self) -> dict:
-        flagged_images = {}
-        stats_dict = self.stats.dict()
-        supported = to_distinct(ImageStat.ALL_STATS)
-        for stat, values in stats_dict.items():
-            if stat in supported.values() and values.ndim == 1 and np.std(values) != 0:
-                mask = _get_outlier_mask(values, self.outlier_method, self.outlier_threshold)
+    def _get_outliers(self, stats: dict) -> dict[int, dict[str, float]]:
+        flagged_images: dict[int, dict[str, float]] = {}
+        for stat, values in stats.items():
+            if stat in (SOURCE_INDEX, BOX_COUNT):
+                continue
+            if values.ndim == 1:
+                mask = _get_outlier_mask(values.astype(np.float64), self.outlier_method, self.outlier_threshold)
                 indices = np.flatnonzero(mask)
                 for i, value in zip(indices, values[mask]):
-                    flagged_images.setdefault(i, {}).update({stat: np.round(value, 2)})
+                    flagged_images.setdefault(i, {}).update({stat: value})
 
         return dict(sorted(flagged_images.items()))
 
-    @set_metadata("dataeval.detectors", ["flags", "outlier_method", "outlier_threshold"])
-    def evaluate(self, data: Iterable[ArrayLike] | StatsOutput | Sequence[StatsOutput]) -> OutliersOutput:
+    @overload
+    def from_stats(self, stats: OutlierStatsOutput | DatasetStatsOutput) -> OutliersOutput[IndexIssueMap]: ...
+
+    @overload
+    def from_stats(self, stats: Sequence[OutlierStatsOutput]) -> OutliersOutput[list[IndexIssueMap]]: ...
+
+    @set_metadata("dataeval.detectors", ["outlier_method", "outlier_threshold"])
+    def from_stats(
+        self, stats: OutlierStatsOutput | DatasetStatsOutput | Sequence[OutlierStatsOutput]
+    ) -> OutliersOutput:
         """
         Returns indices of outliers with the issues identified for each
 
         Parameters
         ----------
-        data : Iterable[ArrayLike], shape - (C, H, W) | StatsOutput | Sequence[StatsOutput]
-            A dataset of images in an ArrayLike format or the output(s) from an imagestats metric analysis
+        stats : OutlierStatsOutput | DatasetStatsOutput | Sequence[OutlierStatsOutput]
+            The output(s) from a dimensionstats, pixelstats, or visualstats metric
+            analysis or an aggregate DatasetStatsOutput
 
         Returns
         -------
@@ -162,36 +174,96 @@ class Outliers:
             Output class containing the indices of outliers and a dictionary showing
             the issues and calculated values for the given index.
 
+        See Also
+        --------
+        dimensionstats
+        pixelstats
+        visualstats
+
         Example
         -------
         Evaluate the dataset:
 
-        >>> outliers.evaluate(images)
-        OutliersOutput(issues={10: {'blurriness': 1.26, 'contrast': 1.06, 'zeros': 0.05}, 12: {'blurriness': 1.51, 'contrast': 1.06, 'zeros': 0.05}})
+        >>> results = outliers.from_stats([stats1, stats2])
+        >>> len(results)
+        2
+        >>> results.issues[0]
+        {10: {'skew': -3.906, 'kurtosis': 13.266, 'entropy': 0.2128}, 12: {'std': 0.00536, 'var': 2.87e-05, 'skew': -3.906, 'kurtosis': 13.266, 'entropy': 0.2128}}
+        >>> results.issues[1]
+        {}
         """  # noqa: E501
-        stats, dataset_steps = combine_stats(data)
-
-        if isinstance(stats, StatsOutput):
-            selected_flags = set(to_distinct(self.flags).values())
-            provided = set(stats.dict())
-            missing = selected_flags - provided
-            if missing:
-                warn(
-                    f"StatsOutput provided {provided} and is missing {missing} \
-                        from the selected stat flags: {selected_flags}."
+        if isinstance(stats, DatasetStatsOutput):
+            outliers = self._get_outliers({k: v for o in stats.outputs() for k, v in o.dict().items()})
+            return OutliersOutput(outliers)
+
+        if isinstance(stats, (DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)):
+            return OutliersOutput(self._get_outliers(stats.dict()))
+
+        if not isinstance(stats, Sequence):
+            raise TypeError(
+                "Invalid stats output type; only use output from dimensionstats, pixelstats or visualstats."
+            )
+
+        stats_map: dict[type, list[int]] = {}
+        for i, stats_output in enumerate(stats):
+            if not isinstance(
+                stats_output, (DatasetStatsOutput, DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)
+            ):
+                raise TypeError(
+                    "Invalid stats output type; only use output from dimensionstats, pixelstats or visualstats."
                 )
-            self.stats = stats
-        else:
-            self.stats = imagestats(cast(Iterable[ArrayLike], data), self.flags)
-
-        outliers = self._get_outliers()
+            stats_map.setdefault(type(stats_output), []).append(i)
 
-        # split up results from combined dataset into individual dataset buckets
-        if dataset_steps:
-            out_dict = {}
+        output_list: list[dict[int, dict[str, float]]] = [{} for _ in stats]
+        for _, indices in stats_map.items():
+            substats, dataset_steps = combine_stats([stats[i] for i in indices])
+            outliers = self._get_outliers(substats.dict())
             for idx, issue in outliers.items():
                 k, v = get_dataset_step_from_idx(idx, dataset_steps)
-                out_dict.setdefault(k, {})[v] = issue
-            outliers = out_dict
+                output_list[indices[k]][v] = issue
+
+        return OutliersOutput(output_list)
+
+    @set_metadata(
+        "dataeval.detectors",
+        [
+            "use_dimension",
+            "use_pixel",
+            "use_visual",
+            "outlier_method",
+            "outlier_threshold",
+        ],
+    )
+    def evaluate(self, data: Iterable[ArrayLike]) -> OutliersOutput[IndexIssueMap]:
+        """
+        Returns indices of outliers with the issues identified for each
 
+        Parameters
+        ----------
+        data : Iterable[ArrayLike], shape - (C, H, W)
+            A dataset of images in an ArrayLike format
+
+        Returns
+        -------
+        OutliersOutput
+            Output class containing the indices of outliers and a dictionary showing
+            the issues and calculated values for the given index.
+
+        Example
+        -------
+        Evaluate the dataset:
+
+        >>> results = outliers.evaluate(images)
+        >>> list(results.issues)
+        [10, 12]
+        >>> results.issues[10]
+        {'skew': -3.906, 'kurtosis': 13.266, 'entropy': 0.2128, 'contrast': 1.25, 'zeros': 0.05493}
+        """
+        self.stats = datasetstats(
+            images=data,
+            use_dimension=self.use_dimension,
+            use_pixel=self.use_pixel,
+            use_visual=self.use_visual,
+        )
+        outliers = self._get_outliers({k: v for o in self.stats.outputs() for k, v in o.dict().items()})
        return OutliersOutput(outliers)
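
The hunks above replace the flag-driven `Outliers` API with per-stat toggles and a new `from_stats` entry point. A minimal migration sketch under stated assumptions: `Outliers` and `pixelstats` are assumed to be re-exported from the public `dataeval.detectors` and `dataeval.metrics.stats` packages listed in the files-changed table, and the image data is synthetic:

import numpy as np
from dataeval.detectors import Outliers          # assumed public re-export
from dataeval.metrics.stats import pixelstats    # assumed public re-export

# Synthetic (N, C, H, W) uint8 images
images = np.random.default_rng(0).integers(0, 256, (64, 3, 32, 32), dtype=np.uint8)
detector = Outliers(use_visual=False, outlier_method="modzscore")

# Path 1: compute dimension/pixel stats internally via datasetstats
result = detector.evaluate(images)               # OutliersOutput[IndexIssueMap]

# Path 2 (new in 0.70.0): reuse a precomputed stats output
result = detector.from_stats(pixelstats(images)) # issues keyed by image index
print(len(result), result.issues)
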
--- a/dataeval/_internal/interop.py
+++ b/dataeval/_internal/interop.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from importlib import import_module
-from typing import Iterable
+from typing import Any, Iterable, Iterator
 
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
@@ -22,24 +22,28 @@ def try_import(module_name):
     return module
 
 
-def to_numpy(array: ArrayLike | None) -> NDArray:
+def as_numpy(array: ArrayLike | None) -> NDArray[Any]:
+    return to_numpy(array, copy=False)
+
+
+def to_numpy(array: ArrayLike | None, copy: bool = True) -> NDArray[Any]:
     if array is None:
         return np.ndarray([])
 
     if isinstance(array, np.ndarray):
-        return array
+        return array.copy() if copy else array
 
     tf = try_import("tensorflow")
     if tf and tf.is_tensor(array):
-        return array.numpy()  # type: ignore
+        return array.numpy().copy() if copy else array.numpy()  # type: ignore
 
     torch = try_import("torch")
     if torch and isinstance(array, torch.Tensor):
-        return array.detach().cpu().numpy()  # type: ignore
+        return array.detach().cpu().numpy().copy() if copy else array.detach().cpu().numpy()  # type: ignore
 
-    return np.asarray(array)
+    return np.array(array, copy=copy)
 
 
-def to_numpy_iter(iterable: Iterable[ArrayLike]):
+def to_numpy_iter(iterable: Iterable[ArrayLike]) -> Iterator[NDArray[Any]]:
     for array in iterable:
         yield to_numpy(array)
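
The interop change makes `to_numpy` copy by default and adds `as_numpy` as the zero-copy variant. A small sketch of the new semantics, using the internal import path shown in this diff:

import numpy as np
from dataeval._internal.interop import as_numpy, to_numpy

x = np.zeros(3)
copied = to_numpy(x)     # independent buffer (copy=True by default)
aliased = as_numpy(x)    # shares memory with x (copy=False)
copied[0] = 1.0
aliased[1] = 2.0
assert x[0] == 0.0 and x[1] == 2.0  # mutation through `copied` never reaches x
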
--- a/dataeval/_internal/metrics/balance.py
+++ b/dataeval/_internal/metrics/balance.py
@@ -2,10 +2,10 @@ from __future__ import annotations
 
 import warnings
 from dataclasses import dataclass
-from typing import Sequence
+from typing import Mapping
 
 import numpy as np
-from numpy.typing import NDArray
+from numpy.typing import ArrayLike, NDArray
 from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
 
 from dataeval._internal.metrics.utils import entropy, preprocess_metadata
@@ -51,16 +51,16 @@ def validate_num_neighbors(num_neighbors: int) -> int:
 
 
 @set_metadata("dataeval.metrics")
-def balance(class_labels: Sequence[int], metadata: list[dict], num_neighbors: int = 5) -> BalanceOutput:
+def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neighbors: int = 5) -> BalanceOutput:
     """
     Mutual information (MI) between factors (class label, metadata, label/image properties)
 
     Parameters
     ----------
-    class_labels: Sequence[int]
+    class_labels: ArrayLike
         List of class labels for each image
-    metadata: List[Dict]
-        List of metadata factors for each image
+    metadata: Mapping[str, ArrayLike]
+        Dict of lists of metadata factors for each image
     num_neighbors: int, default 5
         Number of nearest neighbors to use for computing MI between discrete
         and continuous variables.
@@ -90,9 +90,9 @@ def balance(class_labels: Sequence[int], metadata: list[dict], num_neighbors: in
     Return intra/interfactor balance (mutual information)
 
     >>> bal.factors
-    array([[0.99999843, 0.03510422, 0.09725766],
-           [0.03510422, 0.08433558, 0.15621459],
-           [0.09725766, 0.15621459, 0.99999856]])
+    array([[0.99999843, 0.04133555, 0.09725766],
+           [0.04133555, 0.08433558, 0.1301489 ],
+           [0.09725766, 0.1301489 , 0.99999856]])
 
     Return classwise balance (mutual information) of factors with individual class_labels
 
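
`balance` now takes class labels as a separate ArrayLike and metadata as a mapping of factor name to per-image values, rather than a list of per-image dicts. A sketch of the new call shape with synthetic data; the `dataeval.metrics` import path is an assumption, not shown in this diff:

import numpy as np
from dataeval.metrics import balance  # assumed public re-export

rng = np.random.default_rng(0)
class_labels = rng.choice([0, 1], 100)
metadata = {
    "angle": rng.normal(size=100),          # continuous factor, one value per image
    "sensor": rng.choice(["a", "b"], 100),  # categorical factor, one value per image
}
bal = balance(class_labels, metadata, num_neighbors=5)
print(bal.factors)  # intra/inter-factor mutual information matrix
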
--- a/dataeval/_internal/metrics/ber.py
+++ b/dataeval/_internal/metrics/ber.py
@@ -17,7 +17,7 @@ from numpy.typing import ArrayLike, NDArray
 from scipy.sparse import coo_matrix
 from scipy.stats import mode
 
-from dataeval._internal.interop import to_numpy
+from dataeval._internal.interop import as_numpy
 from dataeval._internal.metrics.utils import compute_neighbors, get_classes_counts, get_method, minimum_spanning_tree
 from dataeval._internal.output import OutputMetadata, set_metadata
 
@@ -145,7 +145,7 @@ def ber(images: ArrayLike, labels: ArrayLike, k: int = 1, method: Literal["KNN",
     BEROutput(ber=0.04, ber_lower=0.020416847668728033)
     """
     ber_fn = get_method(BER_FN_MAP, method)
-    X = to_numpy(images)
-    y = to_numpy(labels)
+    X = as_numpy(images)
+    y = as_numpy(labels)
     upper, lower = ber_fn(X, y, k) if method == "KNN" else ber_fn(X, y)
     return BEROutput(upper, lower)
--- a/dataeval/_internal/metrics/divergence.py
+++ b/dataeval/_internal/metrics/divergence.py
@@ -9,7 +9,7 @@ from typing import Literal
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
 
-from dataeval._internal.interop import to_numpy
+from dataeval._internal.interop import as_numpy
 from dataeval._internal.metrics.utils import compute_neighbors, get_method, minimum_spanning_tree
 from dataeval._internal.output import OutputMetadata, set_metadata
 
@@ -123,8 +123,8 @@ def divergence(data_a: ArrayLike, data_b: ArrayLike, method: Literal["FNN", "MST
     DivergenceOutput(divergence=0.28, errors=36.0)
     """
     div_fn = get_method(DIVERGENCE_FN_MAP, method)
-    a = to_numpy(data_a)
-    b = to_numpy(data_b)
+    a = as_numpy(data_a)
+    b = as_numpy(data_b)
     N = a.shape[0]
     M = b.shape[0]
 
--- a/dataeval/_internal/metrics/diversity.py
+++ b/dataeval/_internal/metrics/diversity.py
@@ -1,10 +1,10 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Literal, Sequence
+from typing import Literal, Mapping
 
 import numpy as np
-from numpy.typing import NDArray
+from numpy.typing import ArrayLike, NDArray
 
 from dataeval._internal.metrics.utils import entropy, get_counts, get_method, get_num_bins, preprocess_metadata
 from dataeval._internal.output import OutputMetadata, set_metadata
@@ -142,7 +142,7 @@ DIVERSITY_FN_MAP = {"simpson": diversity_simpson, "shannon": diversity_shannon}
 
 @set_metadata("dataeval.metrics")
 def diversity(
-    class_labels: Sequence[int], metadata: list[dict], method: Literal["shannon", "simpson"] = "simpson"
+    class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], method: Literal["shannon", "simpson"] = "simpson"
 ) -> DiversityOutput:
     """
     Compute diversity and classwise diversity for discrete/categorical variables and, through standard
@@ -155,10 +155,10 @@ def diversity(
 
     Parameters
     ----------
-    class_labels: Sequence[int]
+    class_labels: ArrayLike
         List of class labels for each image
-    metadata: List[Dict]
-        List of metadata factors for each image
+    metadata: Mapping[str, ArrayLike]
+        Dict of list of metadata factors for each image
     method: Literal["shannon", "simpson"], default "simpson"
         Indicates which diversity index should be computed
 
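
`diversity` adopts the same signature change as `balance` above: ArrayLike labels plus a factor-name mapping. A minimal call under the same assumptions (`dataeval.metrics` import path assumed):

import numpy as np
from dataeval.metrics import diversity  # assumed public re-export

rng = np.random.default_rng(0)
class_labels = rng.choice([0, 1, 2], 100)
metadata = {"angle": rng.choice([0, 45, 90], 100)}  # one categorical factor per image
print(diversity(class_labels, metadata, method="simpson"))
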
--- a/dataeval/_internal/metrics/parity.py
+++ b/dataeval/_internal/metrics/parity.py
@@ -62,8 +62,8 @@ def digitize_factor_bins(continuous_values: NDArray, bins: int, factor_name: str
 
 
 def format_discretize_factors(
-    data_factors: dict[str, NDArray], continuous_factor_bincounts: dict[str, int]
-) -> tuple[dict[str, NDArray], NDArray]:
+    data_factors: Mapping[str, NDArray], continuous_factor_bincounts: Mapping[str, int]
+) -> dict[str, NDArray]:
     """
     Sets up the internal list of metadata factors.
 
@@ -80,10 +80,9 @@ def format_discretize_factors(
 
     Returns
     -------
-    Tuple[Dict[str, NDArray], NDArray]
+    Dict[str, NDArray]
         - Intrinsic per-image metadata information with the formatting that input data_factors uses.
          Each key is a metadata factor, whose value is the discrete per-image factor values.
-        - Per-image labels, whose ith element is the label for the ith element of the dataset.
     """
 
     invalid_keys = set(continuous_factor_bincounts.keys()) - set(data_factors.keys())
@@ -103,8 +102,6 @@ def format_discretize_factors(
     if lengths[1:] != lengths[:-1]:
         raise ValueError("The lengths of each entry in the dictionary are not equal." f" Found lengths {lengths}")
 
-    labels = data_factors["class"]
-
     metadata_factors = {
         name: val
         if name not in continuous_factor_bincounts
@@ -113,7 +110,7 @@ def format_discretize_factors(
         if name != "class"
     }
 
-    return metadata_factors, labels
+    return metadata_factors
 
 
 def normalize_expected_dist(expected_dist: NDArray, observed_dist: NDArray) -> NDArray:
@@ -187,7 +184,8 @@ def validate_dist(label_dist: NDArray, label_name: str):
     warnings.warn(
         f"Labels {np.where(label_dist<5)[0]} in {label_name}"
         " dataset have frequencies less than 5. This may lead"
-        " to invalid chi-squared evaluation."
+        " to invalid chi-squared evaluation.",
+        UserWarning,
     )
 
 
@@ -280,8 +278,9 @@ def label_parity(
 
 @set_metadata("dataeval.metrics")
 def parity(
+    class_labels: ArrayLike,
     data_factors: Mapping[str, ArrayLike],
-    continuous_factor_bincounts: dict[str, int] | None = None,
+    continuous_factor_bincounts: Mapping[str, int] | None = None,
 ) -> ParityOutput[NDArray[np.float64]]:
     """
     Calculate chi-square statistics to assess the relationship between multiple factors and class labels.
@@ -292,10 +291,12 @@ def parity(
 
     Parameters
     ----------
+    class_labels: ArrayLike
+        List of class labels for each image
     data_factors: Mapping[str, ArrayLike]
-        The dataset factors, which are per-image attributes including class label and metadata.
+        The dataset factors, which are per-image metadata attributes.
         Each key of dataset_factors is a factor, whose value is the per-image factor values.
-    continuous_factor_bincounts : Dict[str, int] | None, default None
+    continuous_factor_bincounts : Mapping[str, int] | None, default None
         A dictionary specifying the number of bins for discretizing the continuous factors.
         The keys should correspond to the names of continuous factors in `data_factors`,
         and the values should be the number of bins to use for discretization.
@@ -329,21 +330,27 @@ def parity(
     --------
     Randomly creating some "continuous" and categorical variables using ``np.random.default_rng``
 
+    >>> labels = np_random_gen.choice([0, 1, 2], (100))
     >>> data_factors = {
     ...     "age": np_random_gen.choice([25, 30, 35, 45], (100)),
     ...     "income": np_random_gen.choice([50000, 65000, 80000], (100)),
     ...     "gender": np_random_gen.choice(["M", "F"], (100)),
-    ...     "class": np_random_gen.choice([0, 1, 2], (100)),
     ... }
     >>> continuous_factor_bincounts = {"age": 4, "income": 3}
-    >>> parity(data_factors, continuous_factor_bincounts)
-    ParityOutput(score=array([2.82329785, 1.60625584, 1.38377236]), p_value=array([0.83067563, 0.80766733, 0.5006309 ]))
+    >>> parity(labels, data_factors, continuous_factor_bincounts)
+    ParityOutput(score=array([7.35731943, 5.46711299, 0.51506212]), p_value=array([0.28906231, 0.24263543, 0.77295762]))
     """
+    if len(np.shape(class_labels)) > 1:
+        raise ValueError(
+            f"Got class labels with {len(np.shape(class_labels))}-dimensional",
+            f" shape {np.shape(class_labels)}, but expected a 1-dimensional array.",
+        )
 
     data_factors_np = {k: to_numpy(v) for k, v in data_factors.items()}
     continuous_factor_bincounts = continuous_factor_bincounts if continuous_factor_bincounts else {}
 
-    factors, labels = format_discretize_factors(data_factors_np, continuous_factor_bincounts)
+    labels = to_numpy(class_labels)
+    factors = format_discretize_factors(data_factors_np, continuous_factor_bincounts)
 
     chi_scores = np.zeros(len(factors))
     p_values = np.zeros(len(factors))
@@ -396,7 +403,8 @@ def parity(
     message = "\n".join(factor_msg)
 
     warnings.warn(
-        f"The following factors did not meet the recommended 5 occurrences for each value-label combination. \nRecommend rerunning parity after adjusting the following factor-value-label combinations: \n{message}",  # noqa: E501
+        f"The following factors did not meet the recommended 5 occurrences for each value-label combination. \n\
+Recommend rerunning parity after adjusting the following factor-value-label combinations: \n{message}",
         UserWarning,
     )
 
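
With this change, `parity` takes the class labels as its first argument instead of reading a "class" key out of `data_factors`. A migration sketch mirroring the updated doctest (the `dataeval.metrics` import path is assumed; the seed is arbitrary, so scores will differ from the doctest):

import numpy as np
from dataeval.metrics import parity  # assumed public re-export

rng = np.random.default_rng(175)
labels = rng.choice([0, 1, 2], (100))
data_factors = {
    "age": rng.choice([25, 30, 35, 45], (100)),
    "income": rng.choice([50000, 65000, 80000], (100)),
    "gender": rng.choice(["M", "F"], (100)),
    # "class" no longer belongs here; pass it as the first argument instead
}
continuous_factor_bincounts = {"age": 4, "income": 3}
output = parity(labels, data_factors, continuous_factor_bincounts)
print(output.score, output.p_value)  # one chi-square score and p-value per factor
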