dataeval-0.74.2-py3-none-any.whl → dataeval-0.75.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. dataeval/__init__.py +27 -23
  2. dataeval/detectors/__init__.py +2 -2
  3. dataeval/detectors/drift/__init__.py +14 -12
  4. dataeval/detectors/drift/base.py +1 -1
  5. dataeval/detectors/drift/cvm.py +1 -1
  6. dataeval/detectors/drift/ks.py +1 -1
  7. dataeval/detectors/drift/mmd.py +6 -5
  8. dataeval/detectors/drift/torch.py +12 -12
  9. dataeval/detectors/drift/uncertainty.py +3 -2
  10. dataeval/detectors/linters/__init__.py +4 -4
  11. dataeval/detectors/linters/clusterer.py +2 -7
  12. dataeval/detectors/linters/duplicates.py +6 -10
  13. dataeval/detectors/linters/outliers.py +4 -2
  14. dataeval/detectors/ood/__init__.py +3 -10
  15. dataeval/detectors/ood/{ae_torch.py → ae.py} +6 -4
  16. dataeval/detectors/ood/base.py +64 -161
  17. dataeval/detectors/ood/metadata_ks_compare.py +34 -42
  18. dataeval/detectors/ood/metadata_least_likely.py +3 -3
  19. dataeval/detectors/ood/metadata_ood_mi.py +6 -5
  20. dataeval/detectors/ood/mixin.py +146 -0
  21. dataeval/detectors/ood/output.py +63 -0
  22. dataeval/interop.py +6 -5
  23. dataeval/{logging.py → log.py} +2 -0
  24. dataeval/metrics/__init__.py +2 -2
  25. dataeval/metrics/bias/__init__.py +9 -12
  26. dataeval/metrics/bias/balance.py +10 -8
  27. dataeval/metrics/bias/coverage.py +52 -4
  28. dataeval/metrics/bias/diversity.py +42 -14
  29. dataeval/metrics/bias/parity.py +15 -12
  30. dataeval/metrics/estimators/__init__.py +2 -2
  31. dataeval/metrics/estimators/ber.py +3 -1
  32. dataeval/metrics/estimators/divergence.py +1 -1
  33. dataeval/metrics/estimators/uap.py +1 -1
  34. dataeval/metrics/stats/__init__.py +18 -18
  35. dataeval/metrics/stats/base.py +4 -4
  36. dataeval/metrics/stats/boxratiostats.py +8 -9
  37. dataeval/metrics/stats/datasetstats.py +10 -14
  38. dataeval/metrics/stats/dimensionstats.py +4 -4
  39. dataeval/metrics/stats/hashstats.py +12 -8
  40. dataeval/metrics/stats/labelstats.py +5 -5
  41. dataeval/metrics/stats/pixelstats.py +4 -9
  42. dataeval/metrics/stats/visualstats.py +4 -9
  43. dataeval/utils/__init__.py +4 -13
  44. dataeval/utils/dataset/__init__.py +7 -0
  45. dataeval/utils/{torch → dataset}/datasets.py +2 -0
  46. dataeval/utils/dataset/read.py +63 -0
  47. dataeval/utils/{split_dataset.py → dataset/split.py} +38 -30
  48. dataeval/utils/image.py +2 -2
  49. dataeval/utils/metadata.py +310 -5
  50. dataeval/{metrics/bias/metadata_utils.py → utils/plot.py} +1 -104
  51. dataeval/utils/torch/__init__.py +2 -17
  52. dataeval/utils/torch/gmm.py +29 -6
  53. dataeval/utils/torch/{utils.py → internal.py} +82 -58
  54. dataeval/utils/torch/models.py +10 -8
  55. dataeval/utils/torch/trainer.py +6 -85
  56. dataeval/workflows/__init__.py +2 -5
  57. dataeval/workflows/sufficiency.py +16 -6
  58. dataeval-0.75.0.dist-info/METADATA +136 -0
  59. dataeval-0.75.0.dist-info/RECORD +67 -0
  60. dataeval/detectors/ood/base_torch.py +0 -109
  61. dataeval/metrics/bias/metadata_preprocessing.py +0 -285
  62. dataeval/utils/gmm.py +0 -26
  63. dataeval-0.74.2.dist-info/METADATA +0 -120
  64. dataeval-0.74.2.dist-info/RECORD +0 -66
  65. {dataeval-0.74.2.dist-info → dataeval-0.75.0.dist-info}/LICENSE.txt +0 -0
  66. {dataeval-0.74.2.dist-info → dataeval-0.75.0.dist-info}/WHEEL +0 -0
@@ -1,12 +1,27 @@
1
+ """
2
+ Metadata related utility functions that help organize raw metadata into :class:`Metadata` objects
3
+ for use within `DataEval`.
4
+ """
5
+
1
6
  from __future__ import annotations
2
7
 
3
- __all__ = ["merge_metadata"]
8
+ __all__ = ["Metadata", "preprocess", "merge"]
4
9
 
5
10
  import warnings
6
- from typing import Any, Iterable, Mapping, TypeVar, overload
11
+ from dataclasses import dataclass
12
+ from typing import Any, Iterable, Literal, Mapping, TypeVar, overload
7
13
 
8
14
  import numpy as np
9
- from numpy.typing import NDArray
15
+ from numpy.typing import ArrayLike, NDArray
16
+ from scipy.stats import wasserstein_distance as wd
17
+
18
+ from dataeval.interop import as_numpy, to_numpy
19
+ from dataeval.output import Output, set_metadata
20
+
21
+ TNum = TypeVar("TNum", int, float)
22
+ DISCRETE_MIN_WD = 0.054
23
+ CONTINUOUS_MIN_SAMPLE_SIZE = 20
24
+
10
25
 
11
26
  T = TypeVar("T")
12
27
 
@@ -185,7 +200,7 @@ def _is_metadata_dict_of_dicts(metadata: Mapping) -> bool:
185
200
  return set(metadata[keys[0]]) == set(metadata[keys[1]])
186
201
 
187
202
 
188
- def merge_metadata(
203
+ def merge(
189
204
  metadata: Iterable[Mapping[str, Any]],
190
205
  ignore_lists: bool = False,
191
206
  fully_qualified: bool = False,
@@ -222,7 +237,7 @@ def merge_metadata(
222
237
  Example
223
238
  -------
224
239
  >>> list_metadata = [{"common": 1, "target": [{"a": 1, "b": 3, "c": 5}, {"a": 2, "b": 4}], "source": "example"}]
225
- >>> reorganized_metadata, image_indicies = merge_metadata(list_metadata)
240
+ >>> reorganized_metadata, image_indicies = merge(list_metadata)
226
241
  >>> reorganized_metadata
227
242
  {'common': [1, 1], 'a': [1, 2], 'b': [3, 4], 'source': ['example', 'example']}
228
243
  >>> image_indicies
@@ -276,3 +291,293 @@ def merge_metadata(
276
291
  output[k] = np.array(cv) if as_numpy else cv
277
292
 
278
293
  return output, image_indicies
294
+
295
+
296
@dataclass(frozen=True)
class Metadata(Output):
    """
    Dataclass containing binned metadata from the :func:`preprocess` function.

    Attributes
    ----------
    discrete_factor_names : list[str]
        Names of the discrete factors: factors that were originally discrete plus
        continuous factors that were binned into discrete values.
    discrete_data : NDArray[np.int_]
        Integer array with one column per entry in ``discrete_factor_names``.
    continuous_factor_names : list[str]
        Names of the factors that were originally continuous.
    continuous_data : NDArray[np.int_ | np.double] | None
        Array with one column per entry in ``continuous_factor_names``, or None if
        there was no continuous data.
    class_labels : NDArray[np.int_]
        Numerical class labels for the images/objects.
    class_names : NDArray[Any]
        Array of unique class names (for use with plotting).
    total_num_factors : int
        ``len(discrete_factor_names) + len(continuous_factor_names)``, plus 1 for the class label.
    """

    discrete_factor_names: list[str]
    discrete_data: NDArray[np.int_]
    continuous_factor_names: list[str]
    continuous_data: NDArray[np.int_ | np.double] | None
    class_labels: NDArray[np.int_]
    class_names: NDArray[Any]
    total_num_factors: int
+
327
+
328
@set_metadata
def preprocess(
    raw_metadata: Iterable[Mapping[str, Any]],
    class_labels: ArrayLike | str,
    continuous_factor_bins: Mapping[str, int | list[tuple[TNum, TNum]]] | None = None,
    auto_bin_method: Literal["uniform_width", "uniform_count", "clusters"] = "uniform_width",
    exclude: Iterable[str] | None = None,
) -> Metadata:
    """
    Restructures the metadata to be in the correct format for the bias functions.

    This identifies whether the incoming metadata is discrete or continuous,
    and whether the data is already binned or still needs binning.
    It accepts a list of dictionaries containing the per image metadata and
    automatically adjusts for multiple targets in an image.

    Parameters
    ----------
    raw_metadata : Iterable[Mapping[str, Any]]
        Iterable collection of metadata dictionaries to flatten and merge.
    class_labels : ArrayLike or string
        If arraylike, expects the labels for each image (image classification) or each object (object detection).
        If the labels are included in the metadata dictionary, pass in the key value.
    continuous_factor_bins : Mapping[str, int] or Mapping[str, list[tuple[TNum, TNum]]] or None, default None
        User provided dictionary specifying how to bin the continuous metadata factors.
    auto_bin_method : "uniform_width" or "uniform_count" or "clusters", default "uniform_width"
        Method by which the function will automatically bin continuous metadata factors. It is recommended
        that the user provide the bins through the `continuous_factor_bins`.
    exclude : Iterable[str] or None, default None
        User provided collection of metadata keys to exclude when processing metadata.
        A single string is treated as one key.

    Returns
    -------
    Metadata
        Output class containing the binned metadata. Factors named in `continuous_factor_bins`
        come first, followed by the remaining factors in the merged metadata's key order.

    Raises
    ------
    ValueError
        If the class labels are not 1-dimensional.
    KeyError
        If `continuous_factor_bins` names a key that is not in the metadata, or if
        `class_labels` is a string key that is not in the metadata.
    """
    # Transform metadata into single, flattened dictionary
    metadata, image_repeats = merge(raw_metadata)

    # Drop any excluded metadata keys. `exclude` is materialized into a set once so a
    # one-shot iterator is not exhausted by the first membership test, and a bare string
    # is matched as a whole key instead of by substring.
    if exclude:
        excluded = {exclude} if isinstance(exclude, str) else set(exclude)
        metadata = {k: v for k, v in metadata.items() if k not in excluded}

    # Get the class label array in numeric form
    class_array = as_numpy(metadata.pop(class_labels)) if isinstance(class_labels, str) else as_numpy(class_labels)
    if class_array.ndim > 1:
        raise ValueError(
            f"Got class labels with {class_array.ndim}-dimensional "
            f"shape {class_array.shape}, but expected a 1-dimensional array."
        )
    if not np.issubdtype(class_array.dtype, np.int_):
        unique_classes, numerical_labels = np.unique(class_array, return_inverse=True)
    else:
        numerical_labels = class_array
        unique_classes = np.unique(class_array)

    # Bin according to user supplied bins
    continuous_metadata = {}
    discrete_metadata = {}
    if continuous_factor_bins:
        invalid_keys = set(continuous_factor_bins) - set(metadata)
        if invalid_keys:
            raise KeyError(
                f"The keys - {invalid_keys} - are present in the `continuous_factor_bins` dictionary "
                "but are not keys in the `metadata` dictionary. Delete these keys from `continuous_factor_bins` "
                "or add corresponding entries to the `metadata` dictionary."
            )
        for factor, grouping in continuous_factor_bins.items():
            discrete_metadata[factor] = _user_defined_bin(metadata[factor], grouping)
            continuous_metadata[factor] = metadata[factor]

    # Determine category of the rest of the keys. Walk the metadata's own key order rather
    # than a set difference so the resulting factor order is reproducible across runs.
    remaining_keys = [k for k in metadata if k not in continuous_metadata]
    for key in remaining_keys:
        data = to_numpy(metadata[key])
        if np.issubdtype(data.dtype, np.number):
            result = _is_continuous(data, image_repeats)
            if result:
                continuous_metadata[key] = data
            unique_samples, ordinal_data = np.unique(data, return_inverse=True)
            if unique_samples.size <= np.max([20, data.size * 0.01]):
                discrete_metadata[key] = ordinal_data
            else:
                warnings.warn(
                    f"A user defined binning was not provided for {key}. "
                    f"Using the {auto_bin_method} method to discretize the data. "
                    "It is recommended that the user rerun and supply the desired "
                    "bins using the continuous_factor_bins parameter.",
                    UserWarning,
                )
                discrete_metadata[key] = _binning_function(data, auto_bin_method)
        else:
            _, discrete_metadata[key] = np.unique(data, return_inverse=True)

    # splitting out the dictionaries into the keys and values
    discrete_factor_names = list(discrete_metadata)
    # np.stack rejects an empty sequence, so fall back to an (N, 0) array when no factors remain
    discrete_data = (
        np.stack(list(discrete_metadata.values()), axis=-1)
        if discrete_metadata
        else np.empty((len(numerical_labels), 0), dtype=np.int_)
    )
    continuous_factor_names = list(continuous_metadata)
    continuous_data = np.stack(list(continuous_metadata.values()), axis=-1) if continuous_metadata else None
    total_num_factors = len(discrete_factor_names) + len(continuous_factor_names) + 1

    return Metadata(
        discrete_factor_names,
        discrete_data,
        continuous_factor_names,
        continuous_data,
        numerical_labels,
        unique_classes,
        total_num_factors,
    )
440
+
441
+
442
+ def _user_defined_bin(data: list[Any] | NDArray[Any], binning: int | list[tuple[TNum, TNum]]) -> NDArray[np.intp]:
443
+ """
444
+ Digitizes a list of values into a given number of bins.
445
+
446
+ Parameters
447
+ ----------
448
+ data : list | NDArray
449
+ The values to be digitized.
450
+ binning : int | list[tuple[TNum, TNum]]
451
+ The number of bins for the discrete values that data will be digitized into.
452
+
453
+ Returns
454
+ -------
455
+ NDArray[np.intp]
456
+ The digitized values
457
+ """
458
+
459
+ if not np.all([np.issubdtype(type(n), np.number) for n in data]):
460
+ raise TypeError(
461
+ "Encountered a data value with non-numeric type when digitizing a factor. "
462
+ "Ensure all occurrences of continuous factors are numeric types."
463
+ )
464
+ if type(binning) is int:
465
+ _, bin_edges = np.histogram(data, bins=binning)
466
+ bin_edges[-1] = np.inf
467
+ bin_edges[0] = -np.inf
468
+ else:
469
+ bin_edges = binning
470
+ return np.digitize(data, bin_edges)
471
+
472
+
473
+ def _binning_function(data: NDArray[Any], bin_method: str) -> NDArray[np.int_]:
474
+ """
475
+ Bins continuous data through either equal width bins, equal amounts in each bin, or by clusters.
476
+ """
477
+ if bin_method == "clusters":
478
+ # bin_edges = _binning_by_clusters(data)
479
+ warnings.warn(
480
+ "Binning by clusters is currently unavailable until changes to the clustering function go through.",
481
+ UserWarning,
482
+ )
483
+ bin_method = "uniform_width"
484
+
485
+ if bin_method != "clusters":
486
+ counts, bin_edges = np.histogram(data, bins="auto")
487
+ n_bins = counts.size
488
+ if counts[counts > 0].min() < 10:
489
+ for _ in range(20):
490
+ n_bins -= 1
491
+ counts, bin_edges = np.histogram(data, bins=n_bins)
492
+ if counts[counts > 0].min() >= 10 or n_bins < 2:
493
+ break
494
+
495
+ if bin_method == "uniform_count":
496
+ quantiles = np.linspace(0, 100, n_bins + 1)
497
+ bin_edges = np.asarray(np.percentile(data, quantiles))
498
+
499
+ bin_edges[0] = -np.inf # type: ignore # until the clusters speed up is merged
500
+ bin_edges[-1] = np.inf # type: ignore # and the _binning_by_clusters can be uncommented
501
+ return np.digitize(data, bin_edges) # type: ignore
502
+
503
+
504
def _is_continuous(data: NDArray[np.number], image_indicies: NDArray[np.number]) -> bool:
    """
    Decide whether a numeric factor is continuous, using a Wasserstein-distance test.

    Sort the sample and, for each interior point, express its position between its left and
    right neighbours as a fraction in [0, 1] (0 when both neighbours are equal). For continuous
    data these Normalized Near Neighbor (NNN) locations are close to uniformly distributed; for
    discrete data they are not. The factor is reported as continuous when the Wasserstein distance
    between the NNN locations and an evenly spaced grid on [0, 1] is below ``DISCRETE_MIN_WD``.

    Empirically the test separates discrete from continuous samples reliably when there are at
    least ``CONTINUOUS_MIN_SAMPLE_SIZE`` points and at least half as many points as discrete
    values. Smaller samples are always reported as discrete, with a UserWarning.

    Parameters
    ----------
    data : NDArray
        Values of a single factor.
    image_indicies : NDArray
        Compared against the sorted first-occurrence positions of each unique value; presumably
        the first row of each image as produced by ``merge`` (verify against caller). When they
        match, only those rows are tested.

    Returns
    -------
    bool
        True if the factor looks continuous; False otherwise, including when fewer than
        three unique values are present.
    """
    # Collapse image-level values (one per image) before testing, when first occurrences line up.
    _, first_seen = np.unique(data, return_index=True)
    if first_seen.size == image_indicies.size:
        first_seen = np.sort(first_seen)
        if np.all(first_seen == image_indicies):
            data = data[first_seen]

    n_samples = len(data)
    if n_samples < CONTINUOUS_MIN_SAMPLE_SIZE:
        warnings.warn(
            f"All samples look discrete with so few data points (< {CONTINUOUS_MIN_SAMPLE_SIZE})", UserWarning
        )
        return False

    # NNN needs at least 3 distinct values to be meaningful.
    if np.unique(data, axis=None).size < 3:
        return False

    ordered = np.sort(data)
    left, middle, right = ordered[:-2], ordered[1:-1], ordered[2:]
    span = right - left
    has_gap = span > 0

    # Locations for zero-width neighbourhoods (duplicates) stay at 0.
    nnn = np.zeros(n_samples - 2)
    nnn[has_gap] = (middle - left)[has_gap] / span[has_gap]

    distance = wd(nnn, np.linspace(0, 1, nnn.size))
    return distance < DISCRETE_MIN_WD
560
+
561
+
562
def get_counts(data: NDArray[np.int_], min_num_bins: int | None = None) -> NDArray[np.int_]:
    """
    Returns columnwise unique counts for discrete data.

    Parameters
    ----------
    data : NDArray
        2-D array of non-negative integers; each column holds the values of one metadata factor.
    min_num_bins : int | None, default None
        Minimum number of bins (rows of the result); helps force consistency across runs.
        It is raised to ``data.max() + 1`` when smaller, so every observed value is counted.

    Returns
    -------
    NDArray[np.int_]
        Array of shape ``(num_bins, data.shape[1])`` where entry ``[v, j]`` is the number of
        times value ``v`` occurs in column ``j``.
    """
    # Empty data has no maximum; it contributes zero required bins.
    observed_bins = int(data.max()) + 1 if data.size else 0
    num_bins = observed_bins if min_num_bins is None else max(min_num_bins, observed_bins)
    cnt_array = np.zeros((num_bins, data.shape[1]), dtype=np.int_)
    for idx in range(data.shape[1]):
        cnt_array[:, idx] = np.bincount(data[:, idx], minlength=num_bins)

    return cnt_array
@@ -3,10 +3,9 @@ from __future__ import annotations
3
3
  __all__ = []
4
4
 
5
5
  import contextlib
6
- from typing import Any
7
6
 
8
7
  import numpy as np
9
- from numpy.typing import ArrayLike, NDArray
8
+ from numpy.typing import ArrayLike
10
9
 
11
10
  from dataeval.interop import to_numpy
12
11
 
@@ -14,30 +13,6 @@ with contextlib.suppress(ImportError):
14
13
  from matplotlib.figure import Figure
15
14
 
16
15
 
17
- def get_counts(data: NDArray[np.int_], min_num_bins: int | None = None) -> NDArray[np.int_]:
18
- """
19
- Returns columnwise unique counts for discrete data.
20
-
21
- Parameters
22
- ----------
23
- data : NDArray
24
- Array containing integer values for metadata factors
25
- min_num_bins : int | None, default None
26
- Minimum number of bins for bincount, helps force consistency across runs
27
-
28
- Returns
29
- -------
30
- NDArray[np.int_]
31
- Bin counts per column of data.
32
- """
33
- max_value = data.max() + 1 if min_num_bins is None else min_num_bins
34
- cnt_array = np.zeros((max_value, data.shape[1]), dtype=np.int_)
35
- for idx in range(data.shape[1]):
36
- cnt_array[:, idx] = np.bincount(data[:, idx], minlength=max_value)
37
-
38
- return cnt_array
39
-
40
-
41
16
  def heatmap(
42
17
  data: ArrayLike,
43
18
  row_labels: list[str] | ArrayLike,
@@ -149,81 +124,3 @@ def format_text(*args: str) -> str:
149
124
  """
150
125
  x = args[0]
151
126
  return f"{x:.2f}".replace("0.00", "0").replace("0.", ".").replace("nan", "")
152
-
153
-
154
- def diversity_bar_plot(labels: NDArray[Any], bar_heights: NDArray[Any]) -> Figure:
155
- """
156
- Plots a formatted bar plot
157
-
158
- Parameters
159
- ----------
160
- labels : NDArray
161
- Array containing the labels for each bar
162
- bar_heights : NDArray
163
- Array containing the values for each bar
164
-
165
- Returns
166
- -------
167
- matplotlib.figure.Figure
168
- Bar plot figure
169
- """
170
- import matplotlib.pyplot as plt
171
-
172
- fig, ax = plt.subplots(figsize=(10, 10))
173
-
174
- ax.bar(labels, bar_heights)
175
- ax.set_xlabel("Factors")
176
-
177
- plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
178
-
179
- fig.tight_layout()
180
- return fig
181
-
182
-
183
- def coverage_plot(images: NDArray[Any], num_images: int) -> Figure:
184
- """
185
- Creates a single plot of all of the provided images
186
-
187
- Parameters
188
- ----------
189
- images : NDArray
190
- Array containing only the desired images to plot
191
-
192
- Returns
193
- -------
194
- matplotlib.figure.Figure
195
- Plot of all provided images
196
- """
197
- import matplotlib.pyplot as plt
198
-
199
- num_images = min(num_images, len(images))
200
-
201
- if images.ndim == 4:
202
- images = np.moveaxis(images, 1, -1)
203
- elif images.ndim == 3:
204
- images = np.repeat(images[:, :, :, np.newaxis], 3, axis=-1)
205
- else:
206
- raise ValueError(
207
- f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {images.ndim}-dimensional set of images."
208
- )
209
-
210
- rows = int(np.ceil(num_images / 3))
211
- fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
212
-
213
- if rows == 1:
214
- for j in range(3):
215
- if j >= len(images):
216
- continue
217
- axs[j].imshow(images[j])
218
- axs[j].axis("off")
219
- else:
220
- for i in range(rows):
221
- for j in range(3):
222
- i_j = i * 3 + j
223
- if i_j >= len(images):
224
- continue
225
- axs[i, j].imshow(images[i_j])
226
- axs[i, j].axis("off")
227
-
228
- fig.tight_layout()
229
- return fig
@@ -5,21 +5,6 @@ While these metrics can take in custom models, DataEval provides utility classes
5
5
  to create a seamless integration between custom models and DataEval's metrics.
6
6
  """
7
7
 
8
- from dataeval import _IS_TORCH_AVAILABLE, _IS_TORCHVISION_AVAILABLE
8
+ __all__ = ["models", "trainer"]
9
9
 
10
- __all__ = []
11
-
12
- if _IS_TORCH_AVAILABLE:
13
- from dataeval.utils.torch import models, trainer
14
- from dataeval.utils.torch.utils import read_dataset
15
-
16
- __all__ += ["read_dataset", "models", "trainer"]
17
-
18
- if _IS_TORCHVISION_AVAILABLE:
19
- from dataeval.utils.torch import datasets
20
-
21
- __all__ += ["datasets"]
22
-
23
-
24
- del _IS_TORCH_AVAILABLE
25
- del _IS_TORCHVISION_AVAILABLE
10
+ from dataeval.utils.torch import models, trainer
@@ -1,7 +1,5 @@
1
1
  """
2
- Adapted for Pytorch from:
3
-
4
- Source code derived from Alibi-Detect 0.11.4
2
+ Adapted for Pytorch from Alibi-Detect 0.11.4
5
3
  https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
6
4
 
7
5
  Original code Copyright (c) 2023 Seldon Technologies Ltd
@@ -10,13 +8,38 @@ Licensed under Apache Software License (Apache 2.0)
10
8
 
11
9
  from __future__ import annotations
12
10
 
11
+ __all__ = []
12
+
13
+ from typing import NamedTuple, TypeVar
14
+
13
15
  import numpy as np
14
16
  import torch
15
17
 
16
- from dataeval.utils.gmm import GaussianMixtureModelParams
18
+ TGMMData = TypeVar("TGMMData")
19
+
20
+
21
class GaussianMixtureModelParams(NamedTuple):
    """
    Parameters of a Gaussian Mixture Model, as computed by :func:`gmm_params`.

    Attributes
    ----------
    phi : torch.Tensor
        Mixture component distribution weights.
    mu : torch.Tensor
        Mixture means.
    cov : torch.Tensor
        Mixture covariance.
    L : torch.Tensor
        Cholesky decomposition of `cov`.
    log_det_cov : torch.Tensor
        Log of the determinant of `cov`.
    """

    phi: torch.Tensor
    mu: torch.Tensor
    cov: torch.Tensor
    L: torch.Tensor
    log_det_cov: torch.Tensor
17
40
 
18
41
 
19
- def gmm_params(z: torch.Tensor, gamma: torch.Tensor) -> GaussianMixtureModelParams[torch.Tensor]:
42
+ def gmm_params(z: torch.Tensor, gamma: torch.Tensor) -> GaussianMixtureModelParams:
20
43
  """
21
44
  Compute parameters of Gaussian Mixture Model.
22
45
 
@@ -58,7 +81,7 @@ def gmm_params(z: torch.Tensor, gamma: torch.Tensor) -> GaussianMixtureModelPara
58
81
 
59
82
  def gmm_energy(
60
83
  z: torch.Tensor,
61
- params: GaussianMixtureModelParams[torch.Tensor],
84
+ params: GaussianMixtureModelParams,
62
85
  return_mean: bool = True,
63
86
  ) -> tuple[torch.Tensor, torch.Tensor]:
64
87
  """
@@ -1,70 +1,15 @@
1
1
  from __future__ import annotations
2
2
 
3
- __all__ = ["read_dataset"]
3
+ __all__ = []
4
4
 
5
- from collections import defaultdict
6
5
  from functools import partial
7
6
  from typing import Any, Callable
8
7
 
9
8
  import numpy as np
10
9
  import torch
11
10
  from numpy.typing import NDArray
12
- from torch.utils.data import Dataset
13
-
14
-
15
- def read_dataset(dataset: Dataset[Any]) -> list[list[Any]]:
16
- """
17
- Extract information from a dataset at each index into individual lists of each information position
18
-
19
- Parameters
20
- ----------
21
- dataset : torch.utils.data.Dataset
22
- Input dataset
23
-
24
- Returns
25
- -------
26
- List[List[Any]]
27
- All objects in individual lists based on return position from dataset
28
-
29
- Warning
30
- -------
31
- No type checking is done between lists or data inside lists
32
-
33
- See Also
34
- --------
35
- torch.utils.data.Dataset
36
-
37
- Examples
38
- --------
39
- >>> import numpy as np
40
- >>> data = np.ones((10, 1, 3, 3))
41
- >>> labels = np.ones((10,))
42
- >>> class ICDataset:
43
- ... def __init__(self, data, labels):
44
- ... self.data = data
45
- ... self.labels = labels
46
- ...
47
- ... def __getitem__(self, idx):
48
- ... return self.data[idx], self.labels[idx]
49
-
50
- >>> ds = ICDataset(data, labels)
51
-
52
- >>> result = read_dataset(ds)
53
- >>> len(result) # images and labels
54
- 2
55
- >>> np.asarray(result[0]).shape # images
56
- (10, 1, 3, 3)
57
- >>> np.asarray(result[1]).shape # labels
58
- (10,)
59
- """
60
-
61
- ddict: dict[int, list[Any]] = defaultdict(list[Any])
62
-
63
- for data in dataset:
64
- for i, d in enumerate(data if isinstance(data, tuple) else (data,)):
65
- ddict[i].append(d)
66
-
67
- return list(ddict.values())
11
+ from torch.utils.data import DataLoader, TensorDataset
12
+ from tqdm import tqdm
68
13
 
69
14
 
70
15
  def get_device(device: str | torch.device | None = None) -> torch.device:
@@ -167,3 +112,82 @@ def predict_batch(
167
112
  tuple(concat(p) for p in preds) if isinstance(preds, tuple) else concat(preds) # type: ignore
168
113
  )
169
114
  return out
115
+
116
+
117
def trainer(
    model: torch.nn.Module,
    x_train: NDArray[Any],
    y_train: NDArray[Any] | None,
    loss_fn: Callable[..., torch.Tensor | torch.nn.Module] | None,
    optimizer: torch.optim.Optimizer | None,
    preprocess_fn: Callable[[torch.Tensor], torch.Tensor] | None,
    epochs: int,
    batch_size: int,
    device: torch.device,
    verbose: bool,
) -> None:
    """
    Train a Pytorch model in place.

    Parameters
    ----------
    model : torch.nn.Module
        Model to train. It is moved to `device` before training.
    x_train : NDArray
        Training data, converted to float32 tensors.
    y_train : NDArray or None
        Training targets, converted to float32 tensors. If None, the (preprocessed) inputs
        are used as targets, i.e. the model is trained to reconstruct its input.
    loss_fn : Callable or None
        Loss function, called as ``loss_fn(target, prediction)``. Defaults to
        :class:`torch.nn.MSELoss` when None.
    optimizer : torch.optim.Optimizer or None
        Optimizer used for training. Defaults to Adam with learning rate 0.001 over
        ``model.parameters()`` when None.
    preprocess_fn : Callable or None
        Function applied to each input batch before the forward pass.
    epochs : int
        Number of passes over the training data.
    batch_size : int
        Number of samples per training batch.
    device : torch.device
        Device on which the model and each batch are placed.
    verbose : bool
        Whether to display a tqdm progress bar with the latest loss.
    """
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    if loss_fn is None:
        # Without a loss the loop cannot run; mean-squared error suits the reconstruction
        # objective used when no targets are given.
        loss_fn = torch.nn.MSELoss()

    tensors = [torch.from_numpy(x_train).to(torch.float32)]
    if y_train is not None:
        tensors.append(torch.from_numpy(y_train).to(torch.float32))

    # Honor the requested batch size; DataLoader otherwise defaults to batches of 1.
    loader = DataLoader(dataset=TensorDataset(*tensors), batch_size=batch_size)

    model = model.to(device)

    # iterate over epochs
    loss = torch.nan
    for epoch in (pbar := tqdm(range(epochs), disable=not verbose)):
        # Loss of the last step of the previous epoch, shown alongside the running loss
        epoch_loss = loss
        for step, data in enumerate(loader):
            if step % 250 == 0:
                pbar.set_description(f"Epoch: {epoch} ({epoch_loss:.3f}), loss: {loss:.3f}")

            x, y = [d.to(device) for d in data] if len(data) > 1 else (data[0].to(device), None)

            if callable(preprocess_fn):
                x = preprocess_fn(x)

            y_hat = model(x)
            y = x if y is None else y

            loss = loss_fn(y, y_hat)  # type: ignore

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()