dataeval 0.74.2__py3-none-any.whl → 0.76.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +27 -23
- dataeval/detectors/__init__.py +2 -2
- dataeval/detectors/drift/__init__.py +14 -12
- dataeval/detectors/drift/base.py +3 -3
- dataeval/detectors/drift/cvm.py +1 -1
- dataeval/detectors/drift/ks.py +3 -2
- dataeval/detectors/drift/mmd.py +9 -7
- dataeval/detectors/drift/torch.py +12 -12
- dataeval/detectors/drift/uncertainty.py +5 -4
- dataeval/detectors/drift/updates.py +1 -1
- dataeval/detectors/linters/__init__.py +4 -4
- dataeval/detectors/linters/clusterer.py +5 -9
- dataeval/detectors/linters/duplicates.py +10 -14
- dataeval/detectors/linters/outliers.py +100 -5
- dataeval/detectors/ood/__init__.py +4 -11
- dataeval/detectors/ood/{ae_torch.py → ae.py} +6 -4
- dataeval/detectors/ood/base.py +47 -160
- dataeval/detectors/ood/metadata_ks_compare.py +34 -42
- dataeval/detectors/ood/metadata_least_likely.py +3 -3
- dataeval/detectors/ood/metadata_ood_mi.py +6 -5
- dataeval/detectors/ood/mixin.py +146 -0
- dataeval/detectors/ood/output.py +63 -0
- dataeval/interop.py +7 -6
- dataeval/{logging.py → log.py} +2 -0
- dataeval/metrics/__init__.py +3 -3
- dataeval/metrics/bias/__init__.py +10 -13
- dataeval/metrics/bias/balance.py +13 -11
- dataeval/metrics/bias/coverage.py +53 -5
- dataeval/metrics/bias/diversity.py +56 -24
- dataeval/metrics/bias/parity.py +20 -17
- dataeval/metrics/estimators/__init__.py +2 -2
- dataeval/metrics/estimators/ber.py +7 -4
- dataeval/metrics/estimators/divergence.py +4 -4
- dataeval/metrics/estimators/uap.py +4 -4
- dataeval/metrics/stats/__init__.py +19 -19
- dataeval/metrics/stats/base.py +28 -12
- dataeval/metrics/stats/boxratiostats.py +13 -14
- dataeval/metrics/stats/datasetstats.py +49 -20
- dataeval/metrics/stats/dimensionstats.py +8 -8
- dataeval/metrics/stats/hashstats.py +14 -10
- dataeval/metrics/stats/labelstats.py +94 -11
- dataeval/metrics/stats/pixelstats.py +11 -14
- dataeval/metrics/stats/visualstats.py +10 -13
- dataeval/output.py +23 -14
- dataeval/utils/__init__.py +5 -14
- dataeval/utils/dataset/__init__.py +7 -0
- dataeval/utils/{torch → dataset}/datasets.py +2 -0
- dataeval/utils/dataset/read.py +63 -0
- dataeval/utils/{split_dataset.py → dataset/split.py} +38 -30
- dataeval/utils/image.py +2 -2
- dataeval/utils/metadata.py +317 -14
- dataeval/{metrics/bias/metadata_utils.py → utils/plot.py} +91 -71
- dataeval/utils/torch/__init__.py +2 -17
- dataeval/utils/torch/gmm.py +29 -6
- dataeval/utils/torch/{utils.py → internal.py} +82 -58
- dataeval/utils/torch/models.py +10 -8
- dataeval/utils/torch/trainer.py +6 -85
- dataeval/workflows/__init__.py +2 -5
- dataeval/workflows/sufficiency.py +18 -8
- {dataeval-0.74.2.dist-info → dataeval-0.76.0.dist-info}/LICENSE.txt +2 -2
- dataeval-0.76.0.dist-info/METADATA +137 -0
- dataeval-0.76.0.dist-info/RECORD +67 -0
- dataeval/detectors/ood/base_torch.py +0 -109
- dataeval/metrics/bias/metadata_preprocessing.py +0 -285
- dataeval/utils/gmm.py +0 -26
- dataeval-0.74.2.dist-info/METADATA +0 -120
- dataeval-0.74.2.dist-info/RECORD +0 -66
- {dataeval-0.74.2.dist-info → dataeval-0.76.0.dist-info}/WHEEL +0 -0
dataeval/utils/metadata.py
CHANGED
@@ -1,12 +1,26 @@
+"""
+Metadata related utility functions that help organize raw metadata into \
+:class:`Metadata` objects for use within `DataEval`.
+"""
+
 from __future__ import annotations
 
-__all__ = ["
+__all__ = ["Metadata", "preprocess", "merge", "flatten"]
 
 import warnings
-from
+from dataclasses import dataclass
+from typing import Any, Iterable, Literal, Mapping, TypeVar, overload
 
 import numpy as np
-from numpy.typing import NDArray
+from numpy.typing import ArrayLike, NDArray
+from scipy.stats import wasserstein_distance as wd
+
+from dataeval.interop import as_numpy, to_numpy
+from dataeval.output import Output, set_metadata
+
+DISCRETE_MIN_WD = 0.054
+CONTINUOUS_MIN_SAMPLE_SIZE = 20
+
 
 T = TypeVar("T")
 
@@ -131,9 +145,7 @@ def _flatten_dict_inner(
     return items, size
 
 
-def
-    d: Mapping[str, Any], sep: str, ignore_lists: bool, fully_qualified: bool
-) -> tuple[dict[str, Any], int]:
+def flatten(d: Mapping[str, Any], sep: str, ignore_lists: bool, fully_qualified: bool) -> tuple[dict[str, Any], int]:
     """
     Flattens a dictionary and converts values to numeric values when possible.
 
@@ -146,12 +158,12 @@ def _flatten_dict(
     ignore_lists : bool
         Option to skip expanding lists within metadata
     fully_qualified : bool
-        Option to return dictionary keys full qualified instead of
+        Option to return dictionary keys full qualified instead of reduced
 
     Returns
     -------
-    dict[str, Any]
-        A flattened dictionary
+    tuple[dict[str, Any], int]
+        A tuple of the flattened dictionary and the length of detected lists in metadata
     """
     expanded, size = _flatten_dict_inner(d, parent_keys=(), nested=ignore_lists)
 
@@ -185,7 +197,7 @@ def _is_metadata_dict_of_dicts(metadata: Mapping) -> bool:
     return set(metadata[keys[0]]) == set(metadata[keys[1]])
 
 
-def
+def merge(
     metadata: Iterable[Mapping[str, Any]],
     ignore_lists: bool = False,
     fully_qualified: bool = False,
@@ -222,7 +234,7 @@ def merge_metadata(
     Example
     -------
     >>> list_metadata = [{"common": 1, "target": [{"a": 1, "b": 3, "c": 5}, {"a": 2, "b": 4}], "source": "example"}]
-    >>> reorganized_metadata, image_indicies =
+    >>> reorganized_metadata, image_indicies = merge(list_metadata)
     >>> reorganized_metadata
     {'common': [1, 1], 'a': [1, 2], 'b': [3, 4], 'source': ['example', 'example']}
     >>> image_indicies
@@ -245,9 +257,7 @@ def merge_metadata(
 
     image_repeats = np.zeros(len(dicts))
     for i, d in enumerate(dicts):
-        flattened, image_repeats[i] =
-            d, sep="_", ignore_lists=ignore_lists, fully_qualified=fully_qualified
-        )
+        flattened, image_repeats[i] = flatten(d, sep="_", ignore_lists=ignore_lists, fully_qualified=fully_qualified)
         isect = isect.intersection(flattened.keys()) if isect else set(flattened.keys())
         union = union.union(flattened.keys())
         for k, v in flattened.items():
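For orientation, a minimal usage sketch of the renamed merge helper, assuming the 0.76.0 signature shown in this hunk; the sample metadata and the printed result mirror the doctest above.

from dataeval.utils.metadata import merge

# One dict per image; the nested "target" list is expanded per detected object.
list_metadata = [{"common": 1, "target": [{"a": 1, "b": 3, "c": 5}, {"a": 2, "b": 4}], "source": "example"}]

# merge() flattens each dict and lines the factors up across images/objects.
# Per the doctest above, "c" (missing from the second target) does not appear in the merged output.
reorganized_metadata, image_indicies = merge(list_metadata)
print(reorganized_metadata)
# {'common': [1, 1], 'a': [1, 2], 'b': [3, 4], 'source': ['example', 'example']}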
@@ -276,3 +286,296 @@ def merge_metadata(
         output[k] = np.array(cv) if as_numpy else cv
 
     return output, image_indicies
+
+
+@dataclass(frozen=True)
+class Metadata(Output):
+    """
+    Dataclass containing binned metadata from the :func:`preprocess` function.
+
+    Attributes
+    ----------
+    discrete_factor_names : list[str]
+        List containing factor names for the original data that was discrete and the binned continuous data
+    discrete_data : NDArray[np.int]
+        Array containing values for the original data that was discrete and the binned continuous data
+    continuous_factor_names : list[str]
+        List containing factor names for the original continuous data
+    continuous_data : NDArray[np.int or np.double] | None
+        Array containing values for the original continuous data or None if there was no continuous data
+    class_labels : NDArray[np.int]
+        Numerical class labels for the images/objects
+    class_names : NDArray[Any]
+        Array of unique class names (for use with plotting)
+    total_num_factors : int
+        Sum of discrete_factor_names and continuous_factor_names plus 1 for class
+    """
+
+    discrete_factor_names: list[str]
+    discrete_data: NDArray[np.int_]
+    continuous_factor_names: list[str]
+    continuous_data: NDArray[np.int_ | np.double] | None
+    class_labels: NDArray[np.int_]
+    class_names: NDArray[Any]
+    total_num_factors: int
+
+
+@set_metadata
+def preprocess(
+    raw_metadata: Iterable[Mapping[str, Any]],
+    class_labels: ArrayLike | str,
+    continuous_factor_bins: Mapping[str, int | Iterable[float]] | None = None,
+    auto_bin_method: Literal["uniform_width", "uniform_count", "clusters"] = "uniform_width",
+    exclude: Iterable[str] | None = None,
+) -> Metadata:
+    """
+    Restructures the metadata to be in the correct format for the bias functions.
+
+    This identifies whether the incoming metadata is discrete or continuous,
+    and whether the data is already binned or still needs binning.
+    It accepts a list of dictionaries containing the per image metadata and
+    automatically adjusts for multiple targets in an image.
+
+    Parameters
+    ----------
+    raw_metadata : Iterable[Mapping[str, Any]]
+        Iterable collection of metadata dictionaries to flatten and merge.
+    class_labels : ArrayLike or string
+        If arraylike, expects the labels for each image (image classification) or each object (object detection).
+        If the labels are included in the metadata dictionary, pass in the key value.
+    continuous_factor_bins : Mapping[str, int or Iterable[float]] or None, default None
+        User provided dictionary specifying how to bin the continuous metadata factors where the value is either
+        an int to represent the number of bins, or a list of floats representing the edges for each bin.
+    auto_bin_method : "uniform_width" or "uniform_count" or "clusters", default "uniform_width"
+        Method by which the function will automatically bin continuous metadata factors. It is recommended
+        that the user provide the bins through the `continuous_factor_bins`.
+    exclude : Iterable[str] or None, default None
+        User provided collection of metadata keys to exclude when processing metadata.
+
+    Returns
+    -------
+    Metadata
+        Output class containing the binned metadata
+    """
+    # Transform metadata into single, flattened dictionary
+    metadata, image_repeats = merge(raw_metadata)
+
+    continuous_factor_bins = dict(continuous_factor_bins) if continuous_factor_bins else None
+
+    # Drop any excluded metadata keys
+    for k in exclude or ():
+        metadata.pop(k, None)
+        if continuous_factor_bins:
+            continuous_factor_bins.pop(k, None)
+
+    # Get the class label array in numeric form
+    class_array = as_numpy(metadata.pop(class_labels)) if isinstance(class_labels, str) else as_numpy(class_labels)
+    if class_array.ndim > 1:
+        raise ValueError(
+            f"Got class labels with {class_array.ndim}-dimensional "
+            f"shape {class_array.shape}, but expected a 1-dimensional array."
+        )
+    if not np.issubdtype(class_array.dtype, np.int_):
+        unique_classes, numerical_labels = np.unique(class_array, return_inverse=True)
+    else:
+        numerical_labels = class_array
+        unique_classes = np.unique(class_array)
+
+    # Bin according to user supplied bins
+    continuous_metadata = {}
+    discrete_metadata = {}
+    if continuous_factor_bins is not None and continuous_factor_bins != {}:
+        invalid_keys = set(continuous_factor_bins.keys()) - set(metadata.keys())
+        if invalid_keys:
+            raise KeyError(
+                f"The keys - {invalid_keys} - are present in the `continuous_factor_bins` dictionary "
+                "but are not keys in the `metadata` dictionary. Delete these keys from `continuous_factor_bins` "
+                "or add corresponding entries to the `metadata` dictionary."
+            )
+        for factor, bins in continuous_factor_bins.items():
+            discrete_metadata[factor] = _digitize_data(metadata[factor], bins)
+            continuous_metadata[factor] = metadata[factor]
+
+    # Determine category of the rest of the keys
+    remaining_keys = set(metadata.keys()) - set(continuous_metadata.keys())
+    for key in remaining_keys:
+        data = to_numpy(metadata[key])
+        if np.issubdtype(data.dtype, np.number):
+            result = _is_continuous(data, image_repeats)
+            if result:
+                continuous_metadata[key] = data
+            unique_samples, ordinal_data = np.unique(data, return_inverse=True)
+            if unique_samples.size <= np.max([20, data.size * 0.01]):
+                discrete_metadata[key] = ordinal_data
+            else:
+                warnings.warn(
+                    f"A user defined binning was not provided for {key}. "
+                    f"Using the {auto_bin_method} method to discretize the data. "
+                    "It is recommended that the user rerun and supply the desired "
+                    "bins using the continuous_factor_bins parameter.",
+                    UserWarning,
+                )
+                discrete_metadata[key] = _bin_data(data, auto_bin_method)
+        else:
+            _, discrete_metadata[key] = np.unique(data, return_inverse=True)
+
+    # splitting out the dictionaries into the keys and values
+    discrete_factor_names = list(discrete_metadata.keys())
+    discrete_data = np.stack(list(discrete_metadata.values()), axis=-1)
+    continuous_factor_names = list(continuous_metadata.keys())
+    continuous_data = np.stack(list(continuous_metadata.values()), axis=-1) if continuous_metadata else None
+    total_num_factors = len(discrete_factor_names + continuous_factor_names) + 1
+
+    return Metadata(
+        discrete_factor_names,
+        discrete_data,
+        continuous_factor_names,
+        continuous_data,
+        numerical_labels,
+        unique_classes,
+        total_num_factors,
+    )
+
+
+def _digitize_data(data: list[Any] | NDArray[Any], bins: int | Iterable[float]) -> NDArray[np.intp]:
+    """
+    Digitizes a list of values into a given number of bins.
+
+    Parameters
+    ----------
+    data : list | NDArray
+        The values to be digitized.
+    bins : int | Iterable[float]
+        The number of bins or list of bin edges for the discrete values that data will be digitized into.
+
+    Returns
+    -------
+    NDArray[np.intp]
+        The digitized values
+    """
+
+    if not np.all([np.issubdtype(type(n), np.number) for n in data]):
+        raise TypeError(
+            "Encountered a data value with non-numeric type when digitizing a factor. "
+            "Ensure all occurrences of continuous factors are numeric types."
+        )
+    if isinstance(bins, int):
+        _, bin_edges = np.histogram(data, bins=bins)
+        bin_edges[-1] = np.inf
+        bin_edges[0] = -np.inf
+    else:
+        bin_edges = list(bins)
+    return np.digitize(data, bin_edges)
+
+
+def _bin_data(data: NDArray[Any], bin_method: str) -> NDArray[np.int_]:
+    """
+    Bins continuous data through either equal width bins, equal amounts in each bin, or by clusters.
+    """
+    if bin_method == "clusters":
+        # bin_edges = _binning_by_clusters(data)
+        warnings.warn(
+            "Binning by clusters is currently unavailable until changes to the clustering function go through.",
+            UserWarning,
+        )
+        bin_method = "uniform_width"
+
+    # if bin_method != "clusters":  # restore this when clusters bin_method is available
+    counts, bin_edges = np.histogram(data, bins="auto")
+    n_bins = counts.size
+    if counts[counts > 0].min() < 10:
+        counter = 20
+        while counts[counts > 0].min() < 10 and n_bins >= 2 and counter > 0:
+            counter -= 1
+            n_bins -= 1
+            counts, bin_edges = np.histogram(data, bins=n_bins)
+
+    if bin_method == "uniform_count":
+        quantiles = np.linspace(0, 100, n_bins + 1)
+        bin_edges = np.asarray(np.percentile(data, quantiles))
+
+    bin_edges[0] = -np.inf  # type: ignore # until the clusters speed up is merged
+    bin_edges[-1] = np.inf  # type: ignore # and the _binning_by_clusters can be uncommented
+    return np.digitize(data, bin_edges)  # type: ignore
+
+
+def _is_continuous(data: NDArray[np.number], image_indicies: NDArray[np.number]) -> bool:
+    """
+    Determines whether the data is continuous or discrete using the Wasserstein distance.
+
+    Given a 1D sample, we consider the intervals between adjacent points. For a continuous distribution,
+    a point is equally likely to lie anywhere in the interval bounded by its two neighbors. Furthermore,
+    we can put all "between neighbor" locations on the same scale of 0 to 1 by subtracting the smaller
+    neighbor and dividing out the length of the interval. (Duplicates are either assigned to zero or
+    ignored, depending on context). These normalized locations will be much more uniformly distributed
+    for continuous data than for discrete, and this gives us a way to distinguish them. Call this the
+    Normalized Near Neighbor distribution (NNN), defined on the interval [0,1].
+
+    The Wasserstein distance is available in scipy.stats.wasserstein_distance. We can use it to measure
+    how close the NNN is to a uniform distribution over [0,1]. We found that as long as a sample has at
+    least 20 points, and furthermore at least half as many points as there are discrete values, we can
+    reliably distinguish discrete from continuous samples by testing that the Wasserstein distance
+    measured from a uniform distribution is greater or less than 0.054, respectively.
+    """
+    # Check if the metadata is image specific
+    _, data_indicies_unsorted = np.unique(data, return_index=True)
+    if data_indicies_unsorted.size == image_indicies.size:
+        data_indicies = np.sort(data_indicies_unsorted)
+        if (data_indicies == image_indicies).all():
+            data = data[data_indicies]
+
+    # OLD METHOD
+    # uvals = np.unique(data)
+    # pct_unique = uvals.size / data.size
+    # return pct_unique < threshold
+
+    n_examples = len(data)
+
+    if n_examples < CONTINUOUS_MIN_SAMPLE_SIZE:
+        warnings.warn(
+            f"All samples look discrete with so few data points (< {CONTINUOUS_MIN_SAMPLE_SIZE})", UserWarning
+        )
+        return False
+
+    # Require at least 3 unique values before bothering with NNN
+    xu = np.unique(data, axis=None)
+    if xu.size < 3:
+        return False
+
+    Xs = np.sort(data)
+
+    X0, X1 = Xs[0:-2], Xs[2:]  # left and right neighbors
+
+    dx = np.zeros(n_examples - 2)  # no dx at end points
+    gtz = (X1 - X0) > 0  # check for dups; dx will be zero for them
+    dx[np.logical_not(gtz)] = 0.0
+
+    dx[gtz] = (Xs[1:-1] - X0)[gtz] / (X1 - X0)[gtz]  # the core idea: dx is NNN samples.
+
+    shift = wd(dx, np.linspace(0, 1, dx.size))  # how far is dx from uniform, for this feature?
+
+    return shift < DISCRETE_MIN_WD  # if NNN is close enough to uniform, consider the sample continuous.
+
+
+def get_counts(data: NDArray[np.int_], min_num_bins: int | None = None) -> NDArray[np.int_]:
+    """
+    Returns columnwise unique counts for discrete data.
+
+    Parameters
+    ----------
+    data : NDArray
+        Array containing integer values for metadata factors
+    min_num_bins : int | None, default None
+        Minimum number of bins for bincount, helps force consistency across runs
+
+    Returns
+    -------
+    NDArray[np.int_]
+        Bin counts per column of data.
+    """
+    max_value = data.max() + 1 if min_num_bins is None else min_num_bins
+    cnt_array = np.zeros((max_value, data.shape[1]), dtype=np.int_)
+    for idx in range(data.shape[1]):
+        cnt_array[:, idx] = np.bincount(data[:, idx], minlength=max_value)
+
+    return cnt_array
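A hedged end-to-end sketch of the new preprocess entry point added in this hunk, assuming only the signature and Metadata fields shown above; the factor names, bin edges, and labels below are made up for illustration.

from dataeval.utils.metadata import preprocess

raw_metadata = [
    {"time": 1.2, "altitude": 235, "weather": "sunny"},
    {"time": 3.4, "altitude": 6789, "weather": "cloudy"},
]
class_labels = [0, 1]  # one label per image

# Supplying continuous_factor_bins avoids the auto-binning UserWarning shown above.
md = preprocess(
    raw_metadata,
    class_labels=class_labels,
    continuous_factor_bins={"time": 2, "altitude": [0.0, 1000.0, 10000.0]},
)
print(md.discrete_factor_names, md.continuous_factor_names, md.total_num_factors)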
dataeval/{metrics/bias/metadata_utils.py → utils/plot.py}
RENAMED
@@ -6,7 +6,7 @@ import contextlib
 from typing import Any
 
 import numpy as np
-from numpy.typing import ArrayLike
+from numpy.typing import ArrayLike
 
 from dataeval.interop import to_numpy
 
@@ -14,30 +14,6 @@ with contextlib.suppress(ImportError):
     from matplotlib.figure import Figure
 
 
-def get_counts(data: NDArray[np.int_], min_num_bins: int | None = None) -> NDArray[np.int_]:
-    """
-    Returns columnwise unique counts for discrete data.
-
-    Parameters
-    ----------
-    data : NDArray
-        Array containing integer values for metadata factors
-    min_num_bins : int | None, default None
-        Minimum number of bins for bincount, helps force consistency across runs
-
-    Returns
-    -------
-    NDArray[np.int_]
-        Bin counts per column of data.
-    """
-    max_value = data.max() + 1 if min_num_bins is None else min_num_bins
-    cnt_array = np.zeros((max_value, data.shape[1]), dtype=np.int_)
-    for idx in range(data.shape[1]):
-        cnt_array[:, idx] = np.bincount(data[:, idx], minlength=max_value)
-
-    return cnt_array
-
-
 def heatmap(
     data: ArrayLike,
     row_labels: list[str] | ArrayLike,
@@ -95,12 +71,17 @@ def heatmap(
     # Rotate the tick labels and set their alignment.
     plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
 
-
-
-
-
-
-
+    light_gray = "0.9"
+    # Turn spines on and create light gray easily visible grid.
+    for spine in ax.spines.values():
+        spine.set_visible(True)
+        spine.set_color(light_gray)
+
+    xticks = np.arange(np_data.shape[1] + 1) - 0.5
+    yticks = np.arange(np_data.shape[0] + 1) - 0.5
+    ax.set_xticks(xticks, minor=True)
+    ax.set_yticks(yticks, minor=True)
+    ax.grid(which="minor", color=light_gray, linestyle="-", linewidth=3)
     ax.tick_params(which="minor", bottom=False, left=False)
 
     if xlabel:
@@ -151,79 +132,118 @@ def format_text(*args: str) -> str:
     return f"{x:.2f}".replace("0.00", "0").replace("0.", ".").replace("nan", "")
 
 
-def
+def histogram_plot(
+    data_dict: dict[str, Any],
+    log: bool = True,
+    xlabel: str = "values",
+    ylabel: str = "counts",
+) -> Figure:
     """
-    Plots a formatted
+    Plots a formatted histogram
 
     Parameters
     ----------
-
-
-
-
+    data_dict : dict
+        Dictionary containing the metrics and their value arrays
+    log : bool, default True
+        If True, plots the histogram on a semi-log scale (y axis)
+    xlabel : str, default "values"
+        X-axis label
+    ylabel : str, default "counts"
+        Y-axis label
 
     Returns
     -------
     matplotlib.figure.Figure
-
+        Formatted plot of histograms
     """
     import matplotlib.pyplot as plt
 
-
+    num_metrics = len(data_dict)
+    if num_metrics > 2:
+        rows = int(len(data_dict) / 3)
+        fig, axs = plt.subplots(rows, 3, figsize=(10, rows * 2.5))
+    else:
+        fig, axs = plt.subplots(1, num_metrics, figsize=(4 * num_metrics, 4))
 
-    ax
-
+    for ax, metric in zip(
+        axs.flat,
+        data_dict,
+    ):
+        # Plot the histogram for the chosen metric
+        ax.hist(data_dict[metric], bins=20, log=log)
 
-
+        # Add labels to the histogram
+        ax.set_title(metric)
+        ax.set_ylabel(ylabel)
+        ax.set_xlabel(xlabel)
 
     fig.tight_layout()
     return fig
 
 
-def
+def channel_histogram_plot(
+    data_dict: dict[str, Any],
+    log: bool = True,
+    max_channels: int = 3,
+    ch_mask: list[bool] | None = None,
+    xlabel: str = "values",
+    ylabel: str = "counts",
+) -> Figure:
     """
-
+    Plots a formatted heatmap
 
     Parameters
     ----------
-
-
+    data_dict : dict
+        Dictionary containing the metrics and their value arrays
+    log : bool, default True
+        If True, plots the histogram on a semi-log scale (y axis)
+    xlabel : str, default "values"
+        X-axis label
+    ylabel : str, default "counts"
+        Y-axis label
 
     Returns
     -------
     matplotlib.figure.Figure
-
+        Formatted plot of histograms
     """
     import matplotlib.pyplot as plt
 
-
+    channelwise_metrics = ["mean", "std", "var", "skew", "zeros", "brightness", "contrast", "darkness", "entropy"]
+    data_keys = [key for key in data_dict if key in channelwise_metrics]
+    label_kwargs = {"label": [f"Channel {i}" for i in range(max_channels)]}
 
-
-
-
-
+    num_metrics = len(data_keys)
+    if num_metrics > 2:
+        rows = int(len(data_keys) / 3)
+        fig, axs = plt.subplots(rows, 3, figsize=(10, rows * 2.5))
     else:
-
-
+        fig, axs = plt.subplots(1, num_metrics, figsize=(4 * num_metrics, 4))
+
+    for ax, metric in zip(
+        axs.flat,
+        data_keys,
+    ):
+        # Plot the histogram for the chosen metric
+        data = data_dict[metric][ch_mask].reshape(-1, max_channels)
+        ax.hist(
+            data,
+            bins=20,
+            density=True,
+            log=log,
+            **label_kwargs,
         )
+        # Only plot the labels once for channels
+        if label_kwargs:
+            ax.legend()
+            label_kwargs = {}
 
-
-
-
-
-        for j in range(3):
-            if j >= len(images):
-                continue
-            axs[j].imshow(images[j])
-            axs[j].axis("off")
-    else:
-        for i in range(rows):
-            for j in range(3):
-                i_j = i * 3 + j
-                if i_j >= len(images):
-                    continue
-                axs[i, j].imshow(images[i_j])
-                axs[i, j].axis("off")
+        # Add labels to the histogram
+        ax.set_title(metric)
+        ax.set_ylabel(ylabel)
+        ax.set_xlabel(xlabel)
 
     fig.tight_layout()
     return fig
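Since the rewritten histogram_plot above takes a plain dictionary of metric arrays, a minimal call might look like the following sketch; the import path assumes the rename to dataeval/utils/plot.py listed at the top of this diff, and the stats dictionary is fabricated for illustration.

import numpy as np
from dataeval.utils.plot import histogram_plot  # assumed module path after the rename

rng = np.random.default_rng(0)
stats = {"brightness": rng.random(100), "contrast": rng.random(100), "entropy": rng.random(100)}

# With three metrics this lays the histograms out on a single row of three axes.
fig = histogram_plot(stats, log=False, xlabel="value", ylabel="count")
fig.savefig("stats_hist.png")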
dataeval/utils/torch/__init__.py
CHANGED
@@ -5,21 +5,6 @@ While these metrics can take in custom models, DataEval provides utility classes
 to create a seamless integration between custom models and DataEval's metrics.
 """
 
-
+__all__ = ["models", "trainer"]
 
-
-
-if _IS_TORCH_AVAILABLE:
-    from dataeval.utils.torch import models, trainer
-    from dataeval.utils.torch.utils import read_dataset
-
-    __all__ += ["read_dataset", "models", "trainer"]
-
-if _IS_TORCHVISION_AVAILABLE:
-    from dataeval.utils.torch import datasets
-
-    __all__ += ["datasets"]
-
-
-del _IS_TORCH_AVAILABLE
-del _IS_TORCHVISION_AVAILABLE
+from dataeval.utils.torch import models, trainer
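With the availability-flag guards removed, the torch submodule imports are now unconditional, and the torchvision-backed datasets module moved out of utils.torch (see dataeval/utils/{torch → dataset}/datasets.py in the file listing). A hedged sketch of the new import locations; the dataset module path is inferred from that rename, not confirmed by this diff.

# torch model/trainer utilities are imported unconditionally in 0.76.0
from dataeval.utils.torch import models, trainer

# datasets now lives under dataeval.utils.dataset (module path inferred from the rename above)
from dataeval.utils.dataset import datasets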
dataeval/utils/torch/gmm.py
CHANGED
@@ -1,7 +1,5 @@
 """
-Adapted for Pytorch from
-
-Source code derived from Alibi-Detect 0.11.4
+Adapted for Pytorch from Alibi-Detect 0.11.4
 https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
 
 Original code Copyright (c) 2023 Seldon Technologies Ltd
@@ -10,13 +8,38 @@ Licensed under Apache Software License (Apache 2.0)
 
 from __future__ import annotations
 
+__all__ = []
+
+from typing import NamedTuple, TypeVar
+
 import numpy as np
 import torch
 
-
+TGMMData = TypeVar("TGMMData")
+
+
+class GaussianMixtureModelParams(NamedTuple):
+    """
+    phi : torch.Tensor
+        Mixture component distribution weights.
+    mu : torch.Tensor
+        Mixture means.
+    cov : torch.Tensor
+        Mixture covariance.
+    L : torch.Tensor
+        Cholesky decomposition of `cov`.
+    log_det_cov : torch.Tensor
+        Log of the determinant of `cov`.
+    """
+
+    phi: torch.Tensor
+    mu: torch.Tensor
+    cov: torch.Tensor
+    L: torch.Tensor
+    log_det_cov: torch.Tensor
 
 
-def gmm_params(z: torch.Tensor, gamma: torch.Tensor) -> GaussianMixtureModelParams
+def gmm_params(z: torch.Tensor, gamma: torch.Tensor) -> GaussianMixtureModelParams:
    """
    Compute parameters of Gaussian Mixture Model.

@@ -58,7 +81,7 @@ def gmm_params(z: torch.Tensor, gamma: torch.Tensor) -> GaussianMixtureModelPara
 
 def gmm_energy(
     z: torch.Tensor,
-    params: GaussianMixtureModelParams
+    params: GaussianMixtureModelParams,
     return_mean: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
|