dataeval 0.65.0__py3-none-any.whl → 0.66.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. dataeval/__init__.py +13 -9
  2. dataeval/_internal/detectors/clusterer.py +24 -22
  3. dataeval/_internal/detectors/drift/base.py +206 -26
  4. dataeval/_internal/detectors/drift/cvm.py +25 -23
  5. dataeval/_internal/detectors/drift/ks.py +28 -25
  6. dataeval/_internal/detectors/drift/mmd.py +30 -29
  7. dataeval/_internal/detectors/drift/torch.py +66 -58
  8. dataeval/_internal/detectors/drift/uncertainty.py +28 -28
  9. dataeval/_internal/detectors/duplicates.py +28 -18
  10. dataeval/_internal/detectors/ood/ae.py +15 -29
  11. dataeval/_internal/detectors/ood/aegmm.py +33 -27
  12. dataeval/_internal/detectors/ood/base.py +61 -43
  13. dataeval/_internal/detectors/ood/llr.py +27 -24
  14. dataeval/_internal/detectors/ood/vae.py +32 -31
  15. dataeval/_internal/detectors/ood/vaegmm.py +34 -28
  16. dataeval/_internal/detectors/{linter.py → outliers.py} +33 -27
  17. dataeval/_internal/flags.py +5 -3
  18. dataeval/_internal/interop.py +4 -2
  19. dataeval/_internal/metrics/balance.py +33 -4
  20. dataeval/_internal/metrics/ber.py +6 -4
  21. dataeval/_internal/metrics/diversity.py +45 -12
  22. dataeval/_internal/metrics/parity.py +114 -26
  23. dataeval/_internal/metrics/stats.py +154 -16
  24. dataeval/_internal/metrics/uap.py +28 -2
  25. dataeval/_internal/metrics/utils.py +20 -18
  26. dataeval/_internal/models/pytorch/autoencoder.py +127 -22
  27. dataeval/_internal/models/tensorflow/autoencoder.py +33 -30
  28. dataeval/_internal/models/tensorflow/gmm.py +4 -2
  29. dataeval/_internal/models/tensorflow/losses.py +15 -11
  30. dataeval/_internal/models/tensorflow/pixelcnn.py +19 -18
  31. dataeval/_internal/models/tensorflow/trainer.py +8 -6
  32. dataeval/_internal/models/tensorflow/utils.py +21 -19
  33. dataeval/_internal/output.py +13 -10
  34. dataeval/_internal/utils.py +5 -3
  35. dataeval/_internal/workflows/sufficiency.py +42 -30
  36. dataeval/detectors/__init__.py +6 -25
  37. dataeval/detectors/drift/__init__.py +16 -0
  38. dataeval/detectors/drift/kernels/__init__.py +6 -0
  39. dataeval/detectors/drift/updates/__init__.py +3 -0
  40. dataeval/detectors/linters/__init__.py +5 -0
  41. dataeval/detectors/ood/__init__.py +11 -0
  42. dataeval/metrics/__init__.py +2 -26
  43. dataeval/metrics/bias/__init__.py +14 -0
  44. dataeval/metrics/estimators/__init__.py +9 -0
  45. dataeval/metrics/stats/__init__.py +6 -0
  46. dataeval/tensorflow/__init__.py +3 -0
  47. dataeval/tensorflow/loss/__init__.py +3 -0
  48. dataeval/tensorflow/models/__init__.py +5 -0
  49. dataeval/tensorflow/recon/__init__.py +3 -0
  50. dataeval/torch/__init__.py +3 -0
  51. dataeval/{models/torch → torch/models}/__init__.py +1 -2
  52. dataeval/torch/trainer/__init__.py +3 -0
  53. dataeval/utils/__init__.py +3 -6
  54. dataeval/workflows/__init__.py +2 -4
  55. {dataeval-0.65.0.dist-info → dataeval-0.66.0.dist-info}/METADATA +1 -1
  56. dataeval-0.66.0.dist-info/RECORD +72 -0
  57. dataeval/models/__init__.py +0 -15
  58. dataeval/models/tensorflow/__init__.py +0 -6
  59. dataeval-0.65.0.dist-info/RECORD +0 -60
  60. {dataeval-0.65.0.dist-info → dataeval-0.66.0.dist-info}/LICENSE.txt +0 -0
  61. {dataeval-0.65.0.dist-info → dataeval-0.66.0.dist-info}/WHEEL +0 -0
dataeval/_internal/metrics/parity.py

@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 import warnings
 from dataclasses import dataclass
-from typing import Dict, Generic, Mapping, Optional, Tuple, TypeVar
+from typing import Generic, Mapping, TypeVar
 
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
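Both files in this diff gain `from __future__ import annotations` alongside the switch from `typing.Dict`/`Optional` to built-in generics and `X | Y` unions. The future import is what keeps the new annotation syntax importable on the older Python versions; a minimal, self-contained sketch of the mechanism (the module and function names here are illustrative, not dataeval's):

```python
# illustrative only -- not part of dataeval
from __future__ import annotations  # PEP 563: annotations stored as strings


def bin_counts(factors: dict[str, int] | None = None) -> tuple[str, ...]:
    # Without the __future__ import, `dict[str, int] | None` would be
    # evaluated at definition time and raise TypeError on Python 3.8/3.9.
    return tuple(factors) if factors else ()


print(bin_counts({"age": 4}))  # ('age',)
```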
@@ -44,8 +46,8 @@ def digitize_factor_bins(continuous_values: NDArray, bins: int, factor_name: str
     -------
     NDArray
         The digitized values
-
     """
+
     if not np.all([np.issubdtype(type(n), np.number) for n in continuous_values]):
         raise TypeError(
             f"Encountered a non-numeric value for factor {factor_name}, but the factor"
@@ -60,8 +62,8 @@
 
 
 def format_discretize_factors(
-    data_factors: Dict[str, NDArray], continuous_factor_bincounts: Dict[str, int]
-) -> Tuple[Dict[str, NDArray], NDArray]:
+    data_factors: dict[str, NDArray], continuous_factor_bincounts: dict[str, int]
+) -> tuple[dict[str, NDArray], NDArray]:
     """
     Sets up the internal list of metadata factors.
 
@@ -83,6 +85,7 @@ def format_discretize_factors(
         Each key is a metadata factor, whose value is the discrete per-image factor values.
         - Per-image labels, whose ith element is the label for the ith element of the dataset.
     """
+
     invalid_keys = set(continuous_factor_bincounts.keys()) - set(data_factors.keys())
     if invalid_keys:
         raise KeyError(
@@ -114,6 +117,35 @@
 
 
 def normalize_expected_dist(expected_dist: NDArray, observed_dist: NDArray) -> NDArray:
+    """
+    Normalize the expected label distribution to match the total number of labels in the observed distribution.
+
+    This function adjusts the expected distribution so that its sum equals the sum of the observed distribution.
+    If the expected distribution is all zeros, an error is raised.
+
+    Parameters
+    ----------
+    expected_dist : np.ndarray
+        The expected label distribution. This array represents the anticipated distribution of labels.
+    observed_dist : np.ndarray
+        The observed label distribution. This array represents the actual distribution of labels in the dataset.
+
+    Returns
+    -------
+    np.ndarray
+        The normalized expected distribution, scaled to have the same sum as the observed distribution.
+
+    Raises
+    ------
+    ValueError
+        If the expected distribution is all zeros.
+
+    Notes
+    -----
+    The function ensures that the total number of labels in the expected distribution matches the total
+    number of labels in the observed distribution by scaling the expected distribution.
+    """
+
     exp_sum = np.sum(expected_dist)
     obs_sum = np.sum(observed_dist)
 
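The scaling this new docstring describes is simple to restate; a standalone sketch with illustrative values:

```python
import numpy as np

# Scale expected_dist by the ratio of label totals so that both
# distributions sum to the same number of labels.
expected = np.array([10, 20, 30])  # 60 labels expected
observed = np.array([5, 5, 110])   # 120 labels observed
normalized = expected * (observed.sum() / expected.sum())
print(normalized)        # [20. 40. 60.]
print(normalized.sum())  # 120.0, matching observed.sum()
```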
@@ -148,6 +180,7 @@ def validate_dist(label_dist: NDArray, label_name: str):
     Warning
         If any elements of label_dist are less than 5
     """
+
     if not len(label_dist):
         raise ValueError(f"No labels found in the {label_name} dataset")
     if np.any(label_dist < 5):
@@ -159,17 +192,17 @@
 
 
 @set_metadata("dataeval.metrics")
-def parity(
+def label_parity(
     expected_labels: ArrayLike,
     observed_labels: ArrayLike,
-    num_classes: Optional[int] = None,
+    num_classes: int | None = None,
 ) -> ParityOutput[np.float64]:
     """
-    Perform a one-way chi-squared test between observation frequencies and expected frequencies that
-    tests the null hypothesis that the observed data has the expected frequencies.
+    Calculate the chi-square statistic to assess the parity between expected and observed label distributions.
 
-    This function acts as an interface to the scipy.stats.chisquare method, which is documented at
-    https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.chisquare.html
+    This function computes the frequency distribution of classes in both expected and observed labels, normalizes
+    the expected distribution to match the total number of observed labels, and then calculates the chi-square
+    statistic to determine if there is a significant difference between the two distributions.
 
     Parameters
     ----------
@@ -177,9 +210,9 @@ def parity(
         List of class labels in the expected dataset
     observed_labels : ArrayLike
         List of class labels in the observed dataset
-    num_classes : Optional[int]
-        The number of unique classes in the datasets. If this is not specified, it will
-        be inferred from the set of unique labels in expected_labels and observed_labels
+    num_classes : int | None, default None
+        The number of unique classes in the datasets. If not provided, the function will infer it
+        from the set of unique labels in expected_labels and observed_labels
 
     Returns
     -------
@@ -189,8 +222,31 @@
     Raises
     ------
     ValueError
-        If x is empty
+        If expected label distribution is empty, is all zeros, or if there is a mismatch in the number
+        of unique classes between the observed and expected distributions.
+
+
+    Notes
+    -----
+    - Providing ``num_classes`` can be helpful if there are classes with zero instances in one of the distributions.
+    - The function first validates the observed distribution and normalizes the expected distribution so that it
+      has the same total number of labels as the observed distribution.
+    - It then performs a chi-square test to determine if there is a statistically significant difference between
+      the observed and expected label distributions.
+    - This function acts as an interface to the scipy.stats.chisquare method, which is documented at
+      https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.chisquare.html
+
+
+    Examples
+    --------
+    Randomly creating some label distributions using ``np.random.default_rng``
+
+    >>> expected_labels = np_random_gen.choice([0, 1, 2, 3, 4], (100))
+    >>> observed_labels = np_random_gen.choice([2, 3, 0, 4, 1], (100))
+    >>> label_parity(expected_labels, observed_labels)
+    ParityOutput(score=14.007374204742625, p_value=0.0072715574616218)
     """
+
    # Calculate
    if not num_classes:
        num_classes = 0
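The ``np_random_gen`` name in the new doctest is a seeded generator supplied by the documentation's test setup, not something label_parity provides. A stand-in for readers running the example by hand (the seed here is arbitrary, so the resulting score and p-value will differ from the doctest output):

```python
import numpy as np

# Hypothetical stand-in for the docs' fixture; the actual seed is not
# shown in this diff.
np_random_gen = np.random.default_rng(0)

expected_labels = np_random_gen.choice([0, 1, 2, 3, 4], (100))
observed_labels = np_random_gen.choice([2, 3, 0, 4, 1], (100))
# label_parity(expected_labels, observed_labels) -> ParityOutput(score=..., p_value=...)
```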
@@ -223,27 +279,27 @@
 
 
 @set_metadata("dataeval.metrics")
-def parity_metadata(
+def parity(
     data_factors: Mapping[str, ArrayLike],
-    continuous_factor_bincounts: Optional[Dict[str, int]] = None,
+    continuous_factor_bincounts: dict[str, int] | None = None,
 ) -> ParityOutput[NDArray[np.float64]]:
     """
-    Evaluates the statistical independence of metadata factors from class labels.
-    This performs a chi-square test, which provides a score and a p-value for
-    statistical independence between each pair of a metadata factor and a class label.
-    A high score with a low p-value suggests that a metadata factor is strongly
-    correlated with a class label.
+    Calculate chi-square statistics to assess the relationship between multiple factors and class labels.
+
+    This function computes the chi-square statistic for each metadata factor to determine if there is
+    a significant relationship between the factor values and class labels. The function handles both categorical
+    and discretized continuous factors.
 
     Parameters
     ----------
     data_factors: Mapping[str, ArrayLike]
         The dataset factors, which are per-image attributes including class label and metadata.
         Each key of dataset_factors is a factor, whose value is the per-image factor values.
-    continuous_factor_bincounts : Optional[Dict[str, int]], default None
-        The factors in data_factors that have continuous values and the array of bin counts to
-        discretize values into. All factors are treated as having discrete values unless they
-        are specified as keys in this dictionary. Each element of this array must occur as a key
-        in data_factors.
+    continuous_factor_bincounts : Dict[str, int] | None, default None
+        A dictionary specifying the number of bins for discretizing the continuous factors.
+        The keys should correspond to the names of continuous factors in `data_factors`,
+        and the values should be the number of bins to use for discretization.
+        If not provided, no discretization is applied.
 
     Returns
     -------
@@ -251,7 +307,39 @@ def parity_metadata(
         Arrays of length (num_factors) whose (i)th element corresponds to the
         chi-square score and p-value for the relationship between factor i and
         the class labels in the dataset.
+
+    Raises
+    ------
+    Warning
+        If any cell in the contingency matrix has a value between 0 and 5, a warning is issued because this can
+        lead to inaccurate chi-square calculations. It is recommended to ensure that each label co-occurs with
+        factor values either 0 times or at least 5 times. Alternatively, continuous-valued factors can be digitized
+        into fewer bins.
+
+    Notes
+    -----
+    - Each key of the ``continuous_factor_bincounts`` dictionary must occur as a key in data_factors.
+    - A high score with a low p-value suggests that a metadata factor is strongly correlated with a class label.
+    - The function creates a contingency matrix for each factor, where each entry represents the frequency of a
+      specific factor value co-occurring with a particular class label.
+    - Rows containing only zeros in the contingency matrix are removed before performing the chi-square test
+      to prevent errors in the calculation.
+
+    Examples
+    --------
+    Randomly creating some "continuous" and categorical variables using ``np.random.default_rng``
+
+    >>> data_factors = {
+    ...     "age": np_random_gen.choice([25, 30, 35, 45], (100)),
+    ...     "income": np_random_gen.choice([50000, 65000, 80000], (100)),
+    ...     "gender": np_random_gen.choice(["M", "F"], (100)),
+    ...     "class": np_random_gen.choice([0, 1, 2], (100)),
+    ... }
+    >>> continuous_factor_bincounts = {"age": 4, "income": 3}
+    >>> parity(data_factors, continuous_factor_bincounts)
+    ParityOutput(score=array([2.82329785, 1.60625584, 1.38377236]), p_value=array([0.83067563, 0.80766733, 0.5006309 ]))
     """
+
     data_factors_np = {k: to_numpy(v) for k, v in data_factors.items()}
     continuous_factor_bincounts = continuous_factor_bincounts if continuous_factor_bincounts else {}
 
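Taken together, these two hunks rename the bias metrics' public API: 0.65.0's parity (label-distribution comparison) becomes label_parity, and parity_metadata (factor/label independence) becomes parity. A migration sketch; the import path below is inferred from the new dataeval/metrics/bias package in the file list above, so verify it against the released wheel:

```python
# 0.65.0:  from dataeval.metrics import parity, parity_metadata
# 0.66.0:  parity -> label_parity, parity_metadata -> parity
import numpy as np

from dataeval.metrics.bias import label_parity, parity  # path inferred

rng = np.random.default_rng(12345)  # illustrative seed
expected = rng.choice(5, 100)
observed = rng.choice(5, 100)

print(label_parity(expected, observed))  # was parity() in 0.65.0
# parity(data_factors, ...) now fills parity_metadata()'s old role
```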
dataeval/_internal/metrics/stats.py

@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, Iterable, List
+from typing import Any, Callable, Iterable
 
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
@@ -62,8 +64,8 @@ class StatsOutput(OutputMetadata):
         Per-channel mapping of indices for each metric
     """
 
-    xxhash: List[str]
-    pchash: List[str]
+    xxhash: list[str]
+    pchash: list[str]
     width: NDArray[np.uint16]
     height: NDArray[np.uint16]
     channels: NDArray[np.uint8]
@@ -82,7 +84,7 @@
     percentiles: NDArray[np.float16]
     histogram: NDArray[np.uint32]
     entropy: NDArray[np.float16]
-    ch_idx_map: Dict[int, List[int]]
+    ch_idx_map: dict[int, list[int]]
 
     def dict(self):
         return {k: v for k, v in self.__dict__.items() if not k.startswith("_") and len(v) > 0}
@@ -90,7 +92,7 @@
 
 
 QUARTILES = (0, 25, 50, 75, 100)
-IMAGESTATS_FN_MAP: Dict[ImageStat, Callable[[NDArray], Any]] = {
+IMAGESTATS_FN_MAP: dict[ImageStat, Callable[[NDArray], Any]] = {
     ImageStat.XXHASH: lambda x: xxhash(x),
     ImageStat.PCHASH: lambda x: pchash(x),
     ImageStat.WIDTH: lambda x: np.uint16(x.shape[-1]),
@@ -113,7 +115,7 @@ IMAGESTATS_FN_MAP: Dict[ImageStat, Callable[[NDArray], Any]] = {
     ImageStat.ENTROPY: lambda x: np.float16(entropy(x)),
 }
 
-CHANNELSTATS_FN_MAP: Dict[ImageStat, Callable[[NDArray], Any]] = {
+CHANNELSTATS_FN_MAP: dict[ImageStat, Callable[[NDArray], Any]] = {
     ImageStat.MEAN: lambda x: np.float16(np.mean(x, axis=1)),
     ImageStat.STD: lambda x: np.float16(np.std(x, axis=1)),
     ImageStat.VAR: lambda x: np.float16(np.var(x, axis=1)),
@@ -128,18 +130,62 @@ CHANNELSTATS_FN_MAP: Dict[ImageStat, Callable[[NDArray], Any]] = {
 def run_stats(
     images: Iterable[ArrayLike],
     flags: ImageStat,
-    fn_map: Dict[ImageStat, Callable[[NDArray], Any]],
+    fn_map: dict[ImageStat, Callable[[NDArray], Any]],
     flatten: bool,
 ):
+    """
+    Compute specified statistics on a set of images.
+
+    This function applies a set of statistical operations to each image in the input iterable,
+    based on the specified flags. The function dynamically determines which statistics to apply
+    using a flag system and a corresponding function map. It also supports optional image
+    flattening for pixel-wise calculations.
+
+    Parameters
+    ----------
+    images : ArrayLike
+        An iterable of images (e.g., list of arrays), where each image is represented as an
+        array-like structure (e.g., NumPy arrays).
+    flags : ImageStat
+        A bitwise flag or set of flags specifying the statistics to compute for each image.
+        These flags determine which functions in `fn_map` to apply.
+    fn_map : dict[ImageStat, Callable]
+        A dictionary mapping `ImageStat` flags to functions that compute the corresponding statistics.
+        Each function accepts a NumPy array (representing an image or rescaled pixel data) and returns a result.
+    flatten : bool
+        If True, the image is flattened into a 2D array for pixel-wise operations. Otherwise, the
+        original image dimensions are preserved.
+
+    Returns
+    -------
+    list[dict[str, NDArray]]
+        A list of dictionaries, where each dictionary contains the computed statistics for an image.
+        The dictionary keys correspond to the names of the statistics, and the values are NumPy arrays
+        with the results of the computations.
+
+    Raises
+    ------
+    ValueError
+        If unsupported flags are provided that are not present in `fn_map`.
+
+    Notes
+    -----
+    - The function performs image normalization (rescaling the image values)
+      before applying some of the statistics.
+    - Pixel-level statistics (e.g., brightness, entropy) are computed after
+      rescaling and, optionally, flattening the images.
+    - For statistics like histograms and entropy, intermediate results may
+      be reused to avoid redundant computation.
+    """
     verify_supported(flags, fn_map)
     flag_dict = to_distinct(flags)
 
-    results_list: List[Dict[str, NDArray]] = []
+    results_list: list[dict[str, NDArray]] = []
     for image in to_numpy_iter(images):
         normalized = normalize_image_shape(image)
         scaled = None
         hist = None
-        output: Dict[str, NDArray] = {}
+        output: dict[str, NDArray] = {}
         for flag, stat in flag_dict.items():
             if flag & (ImageStat.ALL_PIXELSTATS | ImageStat.BRIGHTNESS):
                 if scaled is None:
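The flag-to-function dispatch this docstring describes boils down to an IntFlag plus a function map; a self-contained sketch of the pattern (names are illustrative, not dataeval's internals):

```python
from __future__ import annotations

from enum import IntFlag, auto
from typing import Any, Callable

import numpy as np


class Stat(IntFlag):  # illustrative stand-in for ImageStat
    MEAN = auto()
    STD = auto()


FN_MAP: dict[Stat, Callable[[np.ndarray], Any]] = {
    Stat.MEAN: lambda x: np.float16(np.mean(x)),
    Stat.STD: lambda x: np.float16(np.std(x)),
}


def run(images: list[np.ndarray], flags: Stat) -> list[dict[str, Any]]:
    # Apply only the functions whose flag bit is set in `flags`.
    return [
        {s.name.lower(): fn(img) for s, fn in FN_MAP.items() if flags & s}
        for img in images
    ]


print(run([np.ones((3, 8, 8)), np.zeros((3, 8, 8))], Stat.MEAN | Stat.STD))
```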
@@ -161,16 +207,53 @@ def imagestats(images: Iterable[ArrayLike], flags: ImageStat = ImageStat.ALL_STA
     """
     Calculates image and pixel statistics for each image
 
+    This function computes various statistical metrics (e.g., mean, standard deviation, entropy)
+    on the images as a whole, based on the specified flags. It supports multiple types of statistics
+    that can be selected using the `flags` argument.
+
     Parameters
     ----------
-    images : Iterable[ArrayLike]
+    images : ArrayLike
         Images to run statistical tests on
     flags : ImageStat, default ImageStat.ALL_STATS
-        Metric(s) to calculate for each image
+        Metric(s) to calculate for each image. The default flag ``ImageStat.ALL_STATS``
+        computes all available statistics.
 
     Returns
     -------
-    Dict[str, Any]
+    StatsOutput
+        A dictionary-like object containing the computed statistics for each image. The keys correspond
+        to the names of the statistics (e.g., 'mean', 'std'), and the values are lists of results for
+        each image or numpy arrays when the results are multi-dimensional.
+
+    Notes
+    -----
+    - All metrics in the ImageStat.ALL_PIXELSTATS flag are scaled based on the perceived bit depth
+      (which is derived from the largest pixel value) to allow for better comparison
+      between images stored in different formats and different resolutions.
+    - ImageStat.ZERO and ImageStat.MISSING are presented as a percentage of total pixel counts
+
+    Examples
+    --------
+    Calculating the statistics on the images, whose shape is (C, H, W)
+
+    >>> results = imagestats(images, flags=ImageStat.MEAN | ImageStat.ALL_VISUALS)
+    >>> print(results.mean)
+    [0.16650391 0.52050781 0.05471802 0.07702637 0.09875488 0.12188721
+     0.14440918 0.16711426 0.18859863 0.21264648 0.2355957  0.25854492
+     0.27978516 0.3046875  0.32788086 0.35131836 0.37255859 0.39819336
+     0.42163086 0.4453125  0.46630859 0.49267578 0.51660156 0.54052734
+     0.56152344 0.58837891 0.61230469 0.63671875 0.65771484 0.68505859
+     0.70947266 0.73388672 0.75488281 0.78271484 0.80712891 0.83203125
+     0.85302734 0.88134766 0.90625    0.93115234]
+    >>> print(results.zero)
+    [0.12561035 0.         0.         0.         0.11730957 0.
+     0.         0.         0.10986328 0.         0.         0.
+     0.10266113 0.         0.         0.         0.09570312 0.
+     0.         0.         0.08898926 0.         0.         0.
+     0.08251953 0.         0.         0.         0.07629395 0.
+     0.         0.         0.0703125  0.         0.         0.
+     0.0645752  0.         0.         0.        ]
     """
     stats = run_stats(images, flags, IMAGESTATS_FN_MAP, False)
     output = {}
@@ -190,17 +273,72 @@ def channelstats(images: Iterable[ArrayLike], flags=ImageStat.ALL_PIXELSTATS) ->
     """
     Calculates pixel statistics for each image per channel
 
+    This function computes pixel-level statistics (e.g., mean, variance, etc.) on a per-channel basis
+    for each image. The statistics can be selected using the `flags` argument, and the results will
+    be grouped by the number of channels (e.g., RGB channels) in each image.
+
     Parameters
     ----------
-    images : Iterable[ArrayLike]
+    images : ArrayLike
         Images to run statistical tests on
     flags: ImageStat, default ImageStat.ALL_PIXELSTATS
-        Statistic(s) to calculate for each image per channel
-        Only flags in the ImageStat.ALL_PIXELSTATS category are supported
+        Metric(s) to calculate for each image per channel.
+        Only flags within the ``ImageStat.ALL_PIXELSTATS`` category are supported.
 
     Returns
     -------
-    Dict[str, Any]
+    StatsOutput
+        A dictionary-like object containing the computed statistics for each image per channel. The keys
+        correspond to the names of the statistics (e.g., 'mean', 'variance'), and the values are numpy arrays
+        with results for each channel of each image.
+
+    Notes
+    -----
+    - All metrics in the ImageStat.ALL_PIXELSTATS flag are scaled based on the perceived bit depth
+      (which is derived from the largest pixel value) to allow for better comparison
+      between images stored in different formats and different resolutions.
+
+    Examples
+    --------
+    Calculating the statistics on a per channel basis for images, whose shape is (N, C, H, W)
+
+    >>> results = channelstats(images, flags=ImageStat.MEAN | ImageStat.VAR)
+    >>> print(results.mean)
+    {3: array([[0.01617, 0.5303 , 0.06525, 0.09735, 0.1295 , 0.1616 , 0.1937 ,
+            0.2258 , 0.2578 , 0.29   , 0.322  , 0.3542 , 0.3865 , 0.4185 ,
+            0.4507 , 0.4827 , 0.5146 , 0.547  , 0.579  , 0.6113 , 0.643  ,
+            0.6753 , 0.7075 , 0.7397 , 0.7715 , 0.8037 , 0.836  , 0.868  ,
+            0.9004 , 0.932  ],
+           [0.04828, 0.562  , 0.06726, 0.09937, 0.1315 , 0.1636 , 0.1957 ,
+            0.2278 , 0.26   , 0.292  , 0.3242 , 0.3562 , 0.3884 , 0.4204 ,
+            0.4526 , 0.4846 , 0.5166 , 0.549  , 0.581  , 0.6133 , 0.6455 ,
+            0.6772 , 0.7095 , 0.7417 , 0.774  , 0.8057 , 0.838  , 0.87   ,
+            0.9023 , 0.934  ],
+           [0.0804 , 0.594  , 0.0693 , 0.1014 , 0.1334 , 0.1656 , 0.1978 ,
+            0.2299 , 0.262  , 0.294  , 0.3262 , 0.3584 , 0.3904 , 0.4226 ,
+            0.4546 , 0.4868 , 0.519  , 0.551  , 0.583  , 0.615  , 0.6475 ,
+            0.679  , 0.7114 , 0.7437 , 0.776  , 0.808  , 0.84   , 0.872  ,
+            0.9043 , 0.9365 ]], dtype=float16)}
+    >>> print(results.var)
+    {3: array([[0.00010103, 0.01077   , 0.0001621 , 0.0003605 , 0.0006375 ,
+            0.000993  , 0.001427  , 0.001939  , 0.00253   , 0.003199  ,
+            0.003944  , 0.004772  , 0.005676  , 0.006657  , 0.007717  ,
+            0.00886   , 0.01008   , 0.01137   , 0.01275   , 0.0142    ,
+            0.01573   , 0.01733   , 0.01903   , 0.0208    , 0.02264   ,
+            0.02457   , 0.02657   , 0.02864   , 0.0308    , 0.03305   ],
+           [0.0001798 , 0.0121    , 0.0001721 , 0.0003753 , 0.0006566 ,
+            0.001017  , 0.001455  , 0.001972  , 0.002565  , 0.003239  ,
+            0.00399   , 0.00482   , 0.00573   , 0.006714  , 0.007782  ,
+            0.00893   , 0.01015   , 0.011444  , 0.012825  , 0.01428   ,
+            0.01581   , 0.01743   , 0.01912   , 0.02089   , 0.02274   ,
+            0.02466   , 0.02667   , 0.02875   , 0.03091   , 0.03314   ],
+           [0.000337  , 0.0135    , 0.0001824 , 0.0003903 , 0.0006766 ,
+            0.00104   , 0.001484  , 0.002005  , 0.002604  , 0.00328   ,
+            0.004036  , 0.00487   , 0.005783  , 0.006775  , 0.00784   ,
+            0.00899   , 0.010216  , 0.01152   , 0.0129    , 0.01436   ,
+            0.0159    , 0.01752   , 0.01921   , 0.02098   , 0.02283   ,
+            0.02477   , 0.02676   , 0.02885   , 0.03102   , 0.03326   ]],
+          dtype=float16)}
     """
     stats = run_stats(images, flags, CHANNELSTATS_FN_MAP, True)
 
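Note the shape of the doctest output: per-channel results come back keyed by channel count (here 3), consistent with the ch_idx_map: dict[int, list[int]] field added to StatsOutput earlier in this diff. A sketch of that grouping idea, inferred from the field types rather than dataeval's actual code:

```python
from __future__ import annotations

import numpy as np

# Group image indices by channel count, as ch_idx_map's type suggests.
images = [np.zeros((3, 8, 8)), np.zeros((1, 8, 8)), np.ones((3, 8, 8))]
ch_idx_map: dict[int, list[int]] = {}
for i, img in enumerate(images):
    ch_idx_map.setdefault(img.shape[0], []).append(i)
print(ch_idx_map)  # {3: [0, 2], 1: [1]}
```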
dataeval/_internal/metrics/uap.py

@@ -40,13 +40,39 @@ def uap(labels: ArrayLike, scores: ArrayLike) -> UAPOutput:
 
     Returns
     -------
-    Dict[str, float]
-        uap : The empirical mean precision estimate
+    UAPOutput
+        The empirical mean precision estimate, float
 
     Raises
     ------
     ValueError
         If unique classes M < 2
+
+    Notes
+    -----
+    This function calculates the empirical mean precision using the
+    ``average_precision_score`` from scikit-learn, weighted by the class distribution.
+
+    Examples
+    --------
+    >>> y_true = np.array([0, 0, 1, 1])
+    >>> y_scores = np.array([0.1, 0.4, 0.35, 0.8])
+    >>> uap(y_true, y_scores)
+    UAPOutput(uap=0.8333333333333333)
+
+    >>> y_true = np.array([0, 0, 1, 1, 2, 2])
+    >>> y_scores = np.array(
+    ...     [
+    ...         [0.7, 0.2, 0.1],
+    ...         [0.4, 0.3, 0.3],
+    ...         [0.1, 0.8, 0.1],
+    ...         [0.2, 0.3, 0.5],
+    ...         [0.4, 0.4, 0.2],
+    ...         [0.1, 0.2, 0.7],
+    ...     ]
+    ... )
+    >>> uap(y_true, y_scores)
+    UAPOutput(uap=0.7777777777777777)
     """
 
     precision = float(average_precision_score(to_numpy(labels), to_numpy(scores), average="weighted"))
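Since the new Notes section states that uap wraps scikit-learn's average_precision_score, the first doctest can be sanity-checked directly against sklearn:

```python
import numpy as np
from sklearn.metrics import average_precision_score

y_true = np.array([0, 0, 1, 1])
y_scores = np.array([0.1, 0.4, 0.35, 0.8])

# Matches UAPOutput(uap=0.8333333333333333) from the doctest above.
print(average_precision_score(y_true, y_scores, average="weighted"))
```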
dataeval/_internal/metrics/utils.py

@@ -1,4 +1,6 @@
-from typing import Any, Callable, Dict, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union
+from __future__ import annotations
+
+from typing import Any, Callable, Literal, NamedTuple, Sequence
 
 import numpy as np
 import xxhash as xxh
@@ -19,22 +21,22 @@ HASH_SIZE = 8
 MAX_FACTOR = 4
 
 
-def get_method(method_map: Dict[str, Callable], method: str) -> Callable:
+def get_method(method_map: dict[str, Callable], method: str) -> Callable:
     if method not in method_map:
         raise ValueError(f"Specified method {method} is not a valid method: {method_map}.")
     return method_map[method]
 
 
 def get_counts(
-    data: NDArray, names: List[str], is_categorical: List[bool], subset_mask: Optional[NDArray[np.bool_]] = None
-) -> tuple[Dict, Dict]:
+    data: NDArray, names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
+) -> tuple[dict, dict]:
     """
     Initialize dictionary of histogram counts --- treat categorical values
     as histogram bins.
 
     Parameters
     ----------
-    subset_mask: Optional[NDArray[np.bool_]]
+    subset_mask: NDArray[np.bool_] | None
         Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
 
     Returns
@@ -68,10 +70,10 @@ def get_counts(
 
 def entropy(
     data: NDArray,
-    names: List[str],
-    is_categorical: List[bool],
+    names: list[str],
+    is_categorical: list[bool],
     normalized: bool = False,
-    subset_mask: Optional[NDArray[np.bool_]] = None,
+    subset_mask: NDArray[np.bool_] | None = None,
 ) -> NDArray[np.float64]:
     """
     Meant for use with Bias metrics, Balance, Diversity, ClasswiseBalance,
@@ -84,7 +86,7 @@ def entropy(
     ----------
     normalized: bool
         Flag that determines whether or not to normalize entropy by log(num_bins)
-    subset_mask: Optional[NDArray[np.bool_]]
+    subset_mask: NDArray[np.bool_] | None
         Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
 
     Notes
@@ -120,7 +122,7 @@
 
 
 def get_num_bins(
-    data: NDArray, names: List[str], is_categorical: List[bool], subset_mask: Optional[NDArray[np.bool_]] = None
+    data: NDArray, names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
 ) -> NDArray[np.float64]:
     """
     Number of bins or unique values for each metadata factor, used to
@@ -128,7 +130,7 @@ def get_num_bins(
 
     Parameters
     ----------
-    subset_mask: Optional[NDArray[np.bool_]]
+    subset_mask: NDArray[np.bool_] | None
         Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
 
     Returns
@@ -144,7 +146,7 @@ def get_num_bins(
     return num_bins
 
 
-def infer_categorical(X: NDArray, threshold: float = 0.5) -> NDArray:
+def infer_categorical(X: NDArray, threshold: float = 0.2) -> NDArray:
     """
     Compute fraction of feature values that are unique --- intended to be used
     for inferring whether variables are categorical.
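This hunk tightens infer_categorical's default threshold from 0.5 to 0.2, which also matches the cat_thresh=0.2 default on preprocess_metadata below. Restating the unique-fraction test the docstring describes (the below-threshold-means-categorical reading is an assumption from the docstring, not verified against the function body):

```python
import numpy as np

def unique_fraction(column: np.ndarray) -> float:
    # Fraction of values in a feature column that are distinct.
    return np.unique(column).size / column.size

ages = np.array([25, 30, 35, 45, 25, 30, 35, 45, 25, 30])
print(unique_fraction(ages))  # 0.4
# Presumably categorical when the fraction falls below the threshold:
# 0.4 < 0.5 (old default) -> categorical; 0.4 > 0.2 (new) -> continuous.
```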
@@ -160,10 +162,10 @@ def infer_categorical(X: NDArray, threshold: float = 0.5) -> NDArray:
 
 
 def preprocess_metadata(
-    class_labels: Sequence[int], metadata: List[Dict], cat_thresh: float = 0.2
-) -> Tuple[NDArray, List[str], List[bool]]:
+    class_labels: Sequence[int], metadata: list[dict], cat_thresh: float = 0.2
+) -> tuple[NDArray, list[str], list[bool]]:
     # convert class_labels and list of metadata dicts to dict of ndarrays
-    metadata_dict: Dict[str, NDArray] = {
+    metadata_dict: dict[str, NDArray] = {
         "class_label": np.asarray(class_labels, dtype=int),
         **{k: np.array([d[k] for d in metadata]) for k in metadata[0]},
     }
@@ -223,7 +225,7 @@ def minimum_spanning_tree(X: NDArray) -> Any:
     return mst(eudist_csr)
 
 
-def get_classes_counts(labels: NDArray) -> Tuple[int, int]:
+def get_classes_counts(labels: NDArray) -> tuple[int, int]:
     """
     Returns the classes and counts of from an array of labels
 
@@ -303,8 +305,8 @@ def compute_neighbors(
 
 class BitDepth(NamedTuple):
     depth: int
-    pmin: Union[float, int]
-    pmax: Union[float, int]
+    pmin: float | int
+    pmax: float | int
 
 
 def get_bitdepth(image: NDArray) -> BitDepth: