dataeval 0.76.1__tar.gz → 0.82.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. {dataeval-0.76.1 → dataeval-0.82.0}/PKG-INFO +5 -2
  2. {dataeval-0.76.1 → dataeval-0.82.0}/README.md +1 -1
  3. {dataeval-0.76.1 → dataeval-0.82.0}/pyproject.toml +17 -5
  4. {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/__init__.py +3 -3
  5. dataeval-0.82.0/src/dataeval/config.py +77 -0
  6. {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/detectors/__init__.py +1 -1
  7. dataeval-0.82.0/src/dataeval/detectors/drift/__init__.py +22 -0
  8. dataeval-0.76.1/src/dataeval/detectors/drift/base.py → dataeval-0.82.0/src/dataeval/detectors/drift/_base.py +40 -85
  9. dataeval-0.76.1/src/dataeval/detectors/drift/cvm.py → dataeval-0.82.0/src/dataeval/detectors/drift/_cvm.py +21 -28
  10. dataeval-0.76.1/src/dataeval/detectors/drift/ks.py → dataeval-0.82.0/src/dataeval/detectors/drift/_ks.py +20 -26
  11. dataeval-0.76.1/src/dataeval/detectors/drift/mmd.py → dataeval-0.82.0/src/dataeval/detectors/drift/_mmd.py +31 -43
  12. dataeval-0.76.1/src/dataeval/detectors/drift/torch.py → dataeval-0.82.0/src/dataeval/detectors/drift/_torch.py +2 -1
  13. dataeval-0.76.1/src/dataeval/detectors/drift/uncertainty.py → dataeval-0.82.0/src/dataeval/detectors/drift/_uncertainty.py +24 -7
  14. {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/detectors/drift/updates.py +20 -3
  15. dataeval-0.82.0/src/dataeval/detectors/linters/__init__.py +14 -0
  16. {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/detectors/linters/duplicates.py +13 -36
  17. {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/detectors/linters/outliers.py +23 -148
  18. {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/detectors/ood/__init__.py +1 -1
  19. dataeval-0.82.0/src/dataeval/detectors/ood/ae.py +93 -0
  20. {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/detectors/ood/base.py +5 -4
  21. {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/detectors/ood/mixin.py +21 -7
  22. dataeval-0.76.1/src/dataeval/detectors/ood/ae.py → dataeval-0.82.0/src/dataeval/detectors/ood/vae.py +14 -13
  23. dataeval-0.82.0/src/dataeval/metadata/__init__.py +6 -0
  24. dataeval-0.82.0/src/dataeval/metadata/_distance.py +167 -0
  25. dataeval-0.82.0/src/dataeval/metadata/_ood.py +217 -0
  26. dataeval-0.82.0/src/dataeval/metadata/_utils.py +44 -0
  27. {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/metrics/__init__.py +1 -1
  28. dataeval-0.82.0/src/dataeval/metrics/bias/__init__.py +23 -0
  29. dataeval-0.76.1/src/dataeval/metrics/bias/balance.py → dataeval-0.82.0/src/dataeval/metrics/bias/_balance.py +15 -101
  30. dataeval-0.82.0/src/dataeval/metrics/bias/_coverage.py +98 -0
  31. dataeval-0.76.1/src/dataeval/metrics/bias/diversity.py → dataeval-0.82.0/src/dataeval/metrics/bias/_diversity.py +18 -111
  32. dataeval-0.76.1/src/dataeval/metrics/bias/parity.py → dataeval-0.82.0/src/dataeval/metrics/bias/_parity.py +39 -77
  33. dataeval-0.82.0/src/dataeval/metrics/estimators/__init__.py +20 -0
  34. dataeval-0.76.1/src/dataeval/metrics/estimators/ber.py → dataeval-0.82.0/src/dataeval/metrics/estimators/_ber.py +42 -29
  35. dataeval-0.82.0/src/dataeval/metrics/estimators/_clusterer.py +44 -0
  36. dataeval-0.76.1/src/dataeval/metrics/estimators/divergence.py → dataeval-0.82.0/src/dataeval/metrics/estimators/_divergence.py +18 -30
  37. dataeval-0.76.1/src/dataeval/metrics/estimators/uap.py → dataeval-0.82.0/src/dataeval/metrics/estimators/_uap.py +4 -18
  38. dataeval-0.82.0/src/dataeval/metrics/stats/__init__.py +38 -0
  39. dataeval-0.76.1/src/dataeval/metrics/stats/base.py → dataeval-0.82.0/src/dataeval/metrics/stats/_base.py +82 -133
  40. dataeval-0.76.1/src/dataeval/metrics/stats/boxratiostats.py → dataeval-0.82.0/src/dataeval/metrics/stats/_boxratiostats.py +15 -18
  41. dataeval-0.82.0/src/dataeval/metrics/stats/_dimensionstats.py +75 -0
  42. dataeval-0.76.1/src/dataeval/metrics/stats/hashstats.py → dataeval-0.82.0/src/dataeval/metrics/stats/_hashstats.py +21 -37
  43. dataeval-0.82.0/src/dataeval/metrics/stats/_imagestats.py +94 -0
  44. dataeval-0.82.0/src/dataeval/metrics/stats/_labelstats.py +131 -0
  45. dataeval-0.76.1/src/dataeval/metrics/stats/pixelstats.py → dataeval-0.82.0/src/dataeval/metrics/stats/_pixelstats.py +19 -50
  46. dataeval-0.76.1/src/dataeval/metrics/stats/visualstats.py → dataeval-0.82.0/src/dataeval/metrics/stats/_visualstats.py +23 -54
  47. dataeval-0.82.0/src/dataeval/outputs/__init__.py +53 -0
  48. dataeval-0.76.1/src/dataeval/output.py → dataeval-0.82.0/src/dataeval/outputs/_base.py +55 -25
  49. dataeval-0.82.0/src/dataeval/outputs/_bias.py +381 -0
  50. dataeval-0.82.0/src/dataeval/outputs/_drift.py +83 -0
  51. dataeval-0.82.0/src/dataeval/outputs/_estimators.py +114 -0
  52. dataeval-0.82.0/src/dataeval/outputs/_linters.py +184 -0
  53. dataeval-0.76.1/src/dataeval/detectors/ood/output.py → dataeval-0.82.0/src/dataeval/outputs/_ood.py +22 -22
  54. dataeval-0.82.0/src/dataeval/outputs/_stats.py +387 -0
  55. dataeval-0.82.0/src/dataeval/outputs/_utils.py +44 -0
  56. dataeval-0.76.1/src/dataeval/workflows/sufficiency.py → dataeval-0.82.0/src/dataeval/outputs/_workflows.py +210 -418
  57. dataeval-0.82.0/src/dataeval/typing.py +234 -0
  58. {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/utils/__init__.py +2 -2
  59. dataeval-0.82.0/src/dataeval/utils/_array.py +169 -0
  60. dataeval-0.82.0/src/dataeval/utils/_bin.py +199 -0
  61. dataeval-0.82.0/src/dataeval/utils/_clusterer.py +144 -0
  62. dataeval-0.82.0/src/dataeval/utils/_fast_mst.py +189 -0
  63. dataeval-0.76.1/src/dataeval/utils/image.py → dataeval-0.82.0/src/dataeval/utils/_image.py +6 -4
  64. dataeval-0.82.0/src/dataeval/utils/_method.py +14 -0
  65. dataeval-0.76.1/src/dataeval/utils/shared.py → dataeval-0.82.0/src/dataeval/utils/_mst.py +3 -65
  66. dataeval-0.76.1/src/dataeval/utils/plot.py → dataeval-0.82.0/src/dataeval/utils/_plot.py +6 -6
  67. dataeval-0.82.0/src/dataeval/utils/data/__init__.py +26 -0
  68. dataeval-0.82.0/src/dataeval/utils/data/_dataset.py +217 -0
  69. dataeval-0.82.0/src/dataeval/utils/data/_embeddings.py +104 -0
  70. dataeval-0.82.0/src/dataeval/utils/data/_images.py +68 -0
  71. dataeval-0.82.0/src/dataeval/utils/data/_metadata.py +360 -0
  72. dataeval-0.82.0/src/dataeval/utils/data/_selection.py +126 -0
  73. dataeval-0.76.1/src/dataeval/utils/dataset/split.py → dataeval-0.82.0/src/dataeval/utils/data/_split.py +12 -38
  74. dataeval-0.82.0/src/dataeval/utils/data/_targets.py +85 -0
  75. dataeval-0.82.0/src/dataeval/utils/data/collate.py +103 -0
  76. dataeval-0.82.0/src/dataeval/utils/data/datasets/__init__.py +17 -0
  77. dataeval-0.82.0/src/dataeval/utils/data/datasets/_base.py +254 -0
  78. dataeval-0.82.0/src/dataeval/utils/data/datasets/_cifar10.py +134 -0
  79. dataeval-0.82.0/src/dataeval/utils/data/datasets/_fileio.py +168 -0
  80. dataeval-0.82.0/src/dataeval/utils/data/datasets/_milco.py +153 -0
  81. dataeval-0.82.0/src/dataeval/utils/data/datasets/_mixin.py +56 -0
  82. dataeval-0.82.0/src/dataeval/utils/data/datasets/_mnist.py +183 -0
  83. dataeval-0.82.0/src/dataeval/utils/data/datasets/_ships.py +123 -0
  84. dataeval-0.82.0/src/dataeval/utils/data/datasets/_types.py +52 -0
  85. dataeval-0.82.0/src/dataeval/utils/data/datasets/_voc.py +352 -0
  86. dataeval-0.82.0/src/dataeval/utils/data/selections/__init__.py +15 -0
  87. dataeval-0.82.0/src/dataeval/utils/data/selections/_classfilter.py +57 -0
  88. dataeval-0.82.0/src/dataeval/utils/data/selections/_indices.py +26 -0
  89. dataeval-0.82.0/src/dataeval/utils/data/selections/_limit.py +26 -0
  90. dataeval-0.82.0/src/dataeval/utils/data/selections/_reverse.py +18 -0
  91. dataeval-0.82.0/src/dataeval/utils/data/selections/_shuffle.py +29 -0
  92. dataeval-0.82.0/src/dataeval/utils/metadata.py +403 -0
  93. dataeval-0.76.1/src/dataeval/utils/torch/gmm.py → dataeval-0.82.0/src/dataeval/utils/torch/_gmm.py +4 -2
  94. dataeval-0.76.1/src/dataeval/utils/torch/internal.py → dataeval-0.82.0/src/dataeval/utils/torch/_internal.py +21 -51
  95. {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/utils/torch/models.py +43 -2
  96. {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/workflows/__init__.py +2 -1
  97. dataeval-0.82.0/src/dataeval/workflows/sufficiency.py +237 -0
  98. dataeval-0.76.1/src/dataeval/detectors/drift/__init__.py +0 -22
  99. dataeval-0.76.1/src/dataeval/detectors/linters/__init__.py +0 -16
  100. dataeval-0.76.1/src/dataeval/detectors/linters/clusterer.py +0 -512
  101. dataeval-0.76.1/src/dataeval/detectors/linters/merged_stats.py +0 -49
  102. dataeval-0.76.1/src/dataeval/detectors/ood/metadata_ks_compare.py +0 -129
  103. dataeval-0.76.1/src/dataeval/detectors/ood/metadata_least_likely.py +0 -119
  104. dataeval-0.76.1/src/dataeval/interop.py +0 -69
  105. dataeval-0.76.1/src/dataeval/metrics/bias/__init__.py +0 -21
  106. dataeval-0.76.1/src/dataeval/metrics/bias/coverage.py +0 -194
  107. dataeval-0.76.1/src/dataeval/metrics/estimators/__init__.py +0 -9
  108. dataeval-0.76.1/src/dataeval/metrics/stats/__init__.py +0 -35
  109. dataeval-0.76.1/src/dataeval/metrics/stats/datasetstats.py +0 -202
  110. dataeval-0.76.1/src/dataeval/metrics/stats/dimensionstats.py +0 -115
  111. dataeval-0.76.1/src/dataeval/metrics/stats/labelstats.py +0 -210
  112. dataeval-0.76.1/src/dataeval/utils/dataset/__init__.py +0 -7
  113. dataeval-0.76.1/src/dataeval/utils/dataset/datasets.py +0 -412
  114. dataeval-0.76.1/src/dataeval/utils/dataset/read.py +0 -63
  115. dataeval-0.76.1/src/dataeval/utils/metadata.py +0 -728
  116. {dataeval-0.76.1 → dataeval-0.82.0}/LICENSE.txt +0 -0
  117. /dataeval-0.76.1/src/dataeval/log.py → /dataeval-0.82.0/src/dataeval/_log.py +0 -0
  118. {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/detectors/ood/metadata_ood_mi.py +0 -0
  119. {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/py.typed +0 -0
  120. {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/utils/torch/__init__.py +0 -0
  121. /dataeval-0.76.1/src/dataeval/utils/torch/blocks.py → /dataeval-0.82.0/src/dataeval/utils/torch/_blocks.py +0 -0
  122. {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/utils/torch/trainer.py +0 -0
{dataeval-0.76.1 → dataeval-0.82.0}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: dataeval
- Version: 0.76.1
+ Version: 0.82.0
  Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
  Home-page: https://dataeval.ai/
  License: MIT
@@ -21,7 +21,10 @@ Classifier: Programming Language :: Python :: 3.12
  Classifier: Programming Language :: Python :: 3 :: Only
  Classifier: Topic :: Scientific/Engineering
  Provides-Extra: all
+ Requires-Dist: defusedxml (>=0.7.1)
+ Requires-Dist: fast_hdbscan (==0.2.0)
  Requires-Dist: matplotlib (>=3.7.1) ; extra == "all"
+ Requires-Dist: numba (>=0.59.1)
  Requires-Dist: numpy (>=1.24.2)
  Requires-Dist: pandas (>=2.0) ; extra == "all"
  Requires-Dist: pillow (>=10.3.0)
@@ -71,7 +74,7 @@ DataEval is easy to install, supports a wide range of Python versions, and is
  compatible with many of the most popular packages in the scientific and T&E
  communities.

- DataEval also has native interopability between JATIC's suite of tools when
+ DataEval also has native interoperability between JATIC's suite of tools when
  using MAITE-compliant datasets and models.
  <!-- end JATIC interop -->


{dataeval-0.76.1 → dataeval-0.82.0}/README.md
@@ -32,7 +32,7 @@ DataEval is easy to install, supports a wide range of Python versions, and is
  compatible with many of the most popular packages in the scientific and T&E
  communities.

- DataEval also has native interopability between JATIC's suite of tools when
+ DataEval also has native interoperability between JATIC's suite of tools when
  using MAITE-compliant datasets and models.
  <!-- end JATIC interop -->


{dataeval-0.76.1 → dataeval-0.82.0}/pyproject.toml
@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "dataeval"
- version = "0.76.1" # dynamic
+ version = "0.82.0" # dynamic
  description = "DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks"
  license = "MIT"
  readme = "README.md"
@@ -42,6 +42,9 @@ packages = [
  [tool.poetry.dependencies]
  # required
  python = ">=3.9,<3.13"
+ defusedxml = {version = ">=0.7.1"}
+ fast_hdbscan = {version = "0.2.0"} # 0.2.1 hits a bug in condense_tree comparing float to none
+ numba = {version = ">=0.59.1"}
  numpy = {version = ">=1.24.2"}
  pillow = {version = ">=10.3.0"}
  requests = {version = "*"}
@@ -88,7 +91,7 @@ certifi = {version = ">=2024.07.04"}
  enum_tools = {version = ">=0.12.0", extras = ["sphinx"]}
  ipykernel = {version = ">=6.26.0"}
  ipywidgets = {version = ">=8.1.1"}
- jinja2 = {version = ">=3.1.5"}
+ jinja2 = {version = ">=3.1.6"}
  jupyter-client = {version = ">=8.6.0"}
  jupyter-cache = {version = "*"}
  myst-nb = {version = ">=1.0.0"}
@@ -129,6 +132,11 @@ reportMissingImports = false
  norecursedirs = ["prototype"]
  testpaths = ["tests"]
  addopts = ["--pythonwarnings=ignore::DeprecationWarning", "--verbose", "--durations=20", "--durations-min=1.0"]
+ markers = [
+ "required: marks tests for required features",
+ "optional: marks tests for optional features",
+ "requires_all: marks tests that require the all extras",
+ ]

  [tool.coverage.run]
  source = ["src/dataeval"]
@@ -143,8 +151,9 @@ exclude_also = [
  ]
  include = ["*/src/dataeval/*"]
  omit = [
- "*/torch/blocks.py",
- "*/torch/utils.py",
+ "*/torch/_blocks.py",
+ "*/_clusterer.py",
+ "*/_fast_mst.py",
  ]
  fail_under = 90

@@ -178,6 +187,9 @@ per-file-ignores = { "*.ipynb" = ["E402"] }
  [tool.ruff.lint.isort]
  known-first-party = ["dataeval"]

+ [tool.ruff.lint.flake8-builtins]
+ builtins-strict-checking = false
+
  [tool.ruff.format]
  quote-style = "double"
  indent-style = "space"
@@ -187,7 +199,7 @@ docstring-code-format = true
  docstring-code-line-length = "dynamic"

  [tool.codespell]
- skip = './*env*,./prototype,./output,./docs/build,./docs/source/.jupyter_cache,CHANGELOG.md,poetry.lock,*.html'
+ skip = './*env*,./prototype,./output,./docs/build,./docs/source/.jupyter_cache,CHANGELOG.md,poetry.lock,*.html,./docs/source/*/data'
  ignore-words-list = ["Hart"]

  [build-system]


{dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/__init__.py
@@ -7,12 +7,12 @@ shifts that impact performance of deployed models.

  from __future__ import annotations

- __all__ = ["detectors", "log", "metrics", "utils", "workflows"]
- __version__ = "0.76.1"
+ __all__ = ["config", "detectors", "log", "metrics", "typing", "utils", "workflows"]
+ __version__ = "0.82.0"

  import logging

- from dataeval import detectors, metrics, utils, workflows
+ from dataeval import config, detectors, metrics, typing, utils, workflows

  logging.getLogger(__name__).addHandler(logging.NullHandler())


dataeval-0.82.0/src/dataeval/config.py
@@ -0,0 +1,77 @@
+ """
+ Global configuration settings for DataEval.
+ """
+
+ from __future__ import annotations
+
+ __all__ = ["get_device", "set_device", "get_max_processes", "set_max_processes"]
+
+ import torch
+ from torch import device
+
+ _device: device | None = None
+ _processes: int | None = None
+
+
+ def set_device(device: str | device | int) -> None:
+ """
+ Sets the default device to use when executing against a PyTorch backend.
+
+ Parameters
+ ----------
+ device : str or int or `torch.device`
+ The default device to use. See `torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
+ documentation for more information.
+ """
+ global _device
+ _device = torch.device(device)
+
+
+ def get_device(override: str | device | int | None = None) -> torch.device:
+ """
+ Returns the PyTorch device to use.
+
+ Parameters
+ ----------
+ override : str or int or `torch.device` or None, default None
+ The user specified override if provided, otherwise returns the default device.
+
+ Returns
+ -------
+ `torch.device`
+ """
+ if override is None:
+ global _device
+ return torch.get_default_device() if _device is None else _device
+ else:
+ return torch.device(override)
+
+
+ def set_max_processes(processes: int | None) -> None:
+ """
+ Sets the maximum number of worker processes to use when running tasks that support parallel processing.
+
+ Parameters
+ ----------
+ processes : int or None
+ The maximum number of worker processes to use, or None to use
+ `os.process_cpu_count <https://docs.python.org/3/library/os.html#os.process_cpu_count>`_
+ to determine the number of worker processes.
+ """
+ global _processes
+ _processes = processes
+
+
+ def get_max_processes() -> int | None:
+ """
+ Returns the maximum number of worker processes to use when running tasks that support parallel processing.
+
+ Returns
+ -------
+ int or None
+ The maximum number of worker processes to use, or None to use
+ `os.process_cpu_count <https://docs.python.org/3/library/os.html#os.process_cpu_count>`_
+ to determine the number of worker processes.
+ """
+ global _processes
+ return _processes
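
Aside, not part of the diff: a minimal usage sketch of the dataeval.config module added above, assembled from its own source; the device string and worker count are illustrative values only.

    from dataeval.config import get_device, get_max_processes, set_device, set_max_processes

    # Set process-wide defaults once, near program startup.
    set_device("cuda:0")        # default torch device for DataEval components
    set_max_processes(4)        # cap worker processes for parallelizable tasks

    print(get_device())         # cuda:0 -- the configured default
    print(get_device("cpu"))    # cpu -- an explicit override always wins
    print(get_max_processes())  # 4; None would mean fall back to os.process_cpu_count()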
{dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/detectors/__init__.py
@@ -4,4 +4,4 @@ Detectors can determine if a dataset or individual images in a dataset are indic

  __all__ = ["drift", "linters", "ood"]

- from dataeval.detectors import drift, linters, ood
+ from . import drift, linters, ood


dataeval-0.82.0/src/dataeval/detectors/drift/__init__.py
@@ -0,0 +1,22 @@
+ """
+ :term:`Drift` detectors identify if the statistical properties of the data has changed.
+ """
+
+ __all__ = [
+ "DriftCVM",
+ "DriftKS",
+ "DriftMMD",
+ "DriftMMDOutput",
+ "DriftOutput",
+ "DriftUncertainty",
+ "preprocess_drift",
+ "updates",
+ ]
+
+ from dataeval.detectors.drift import updates
+ from dataeval.detectors.drift._cvm import DriftCVM
+ from dataeval.detectors.drift._ks import DriftKS
+ from dataeval.detectors.drift._mmd import DriftMMD
+ from dataeval.detectors.drift._torch import preprocess_drift
+ from dataeval.detectors.drift._uncertainty import DriftUncertainty
+ from dataeval.outputs._drift import DriftMMDOutput, DriftOutput


dataeval-0.76.1/src/dataeval/detectors/drift/base.py → dataeval-0.82.0/src/dataeval/detectors/drift/_base.py
@@ -10,86 +10,29 @@ from __future__ import annotations

  __all__ = []

- from abc import ABC, abstractmethod
- from dataclasses import dataclass
+ import math
+ from abc import abstractmethod
  from functools import wraps
- from typing import Any, Callable, Literal, TypeVar
+ from typing import Any, Callable, Literal, Protocol, TypeVar, runtime_checkable

  import numpy as np
- from numpy.typing import ArrayLike, NDArray
+ from numpy.typing import NDArray

- from dataeval.interop import as_numpy
- from dataeval.output import Output, set_metadata
+ from dataeval.outputs import DriftOutput
+ from dataeval.outputs._base import set_metadata
+ from dataeval.typing import Array, ArrayLike
+ from dataeval.utils._array import as_numpy, to_numpy

  R = TypeVar("R")


- class UpdateStrategy(ABC):
+ @runtime_checkable
+ class UpdateStrategy(Protocol):
  """
- Updates reference dataset for drift detector
-
- Parameters
- ----------
- n : int
- Update with last n instances seen by the detector.
- """
-
- def __init__(self, n: int) -> None:
- self.n = n
-
- @abstractmethod
- def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]:
- """Abstract implementation of update strategy"""
-
-
- @dataclass(frozen=True)
- class DriftBaseOutput(Output):
- """
- Base output class for Drift Detector classes
-
- Attributes
- ----------
- is_drift : bool
- Drift prediction for the images
- threshold : float
- Threshold after multivariate correction if needed
+ Protocol for reference dataset update strategy for drift detectors
  """

- is_drift: bool
- threshold: float
- p_val: float
- distance: float
-
-
- @dataclass(frozen=True)
- class DriftOutput(DriftBaseOutput):
- """
- Output class for :class:`DriftCVM`, :class:`DriftKS`, and :class:`DriftUncertainty` drift detectors.
-
- Attributes
- ----------
- is_drift : bool
- :term:`Drift` prediction for the images
- threshold : float
- Threshold after multivariate correction if needed
- feature_drift : NDArray
- Feature-level array of images detected to have drifted
- feature_threshold : float
- Feature-level threshold to determine drift
- p_vals : NDArray
- Feature-level p-values
- distances : NDArray
- Feature-level distances
- """
-
- # is_drift: bool
- # threshold: float
- # p_val: float
- # distance: float
- feature_drift: NDArray[np.bool_]
- feature_threshold: float
- p_vals: NDArray[np.float32]
- distances: NDArray[np.float32]
+ def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]: ...


  def update_x_ref(fn: Callable[..., R]) -> Callable[..., R]:
@@ -196,7 +139,7 @@ class BaseDrift:
  if correction not in ["bonferroni", "fdr"]:
  raise ValueError("`correction` must be `bonferroni` or `fdr`.")

- self._x_ref = as_numpy(x_ref)
+ self._x_ref = x_ref
  self.x_ref_preprocessed: bool = x_ref_preprocessed

  # Other attributes
@@ -204,25 +147,25 @@
  self.update_x_ref = update_x_ref
  self.preprocess_fn = preprocess_fn
  self.correction = correction
- self.n: int = len(self._x_ref)
+ self.n: int = len(x_ref)

  # Ref counter for preprocessed x
  self._x_refcount = 0

  @property
- def x_ref(self) -> NDArray[Any]:
+ def x_ref(self) -> ArrayLike:
  """
  Retrieve the reference data, applying preprocessing if not already done.

  Returns
  -------
- NDArray
+ ArrayLike
  The reference dataset (`x_ref`), preprocessed if needed.
  """
  if not self.x_ref_preprocessed:
  self.x_ref_preprocessed = True
  if self.preprocess_fn is not None:
- self._x_ref = as_numpy(self.preprocess_fn(self._x_ref))
+ self._x_ref = self.preprocess_fn(self._x_ref)

  return self._x_ref

@@ -323,32 +266,44 @@ class BaseDriftUnivariate(BaseDrift):
  # lazy process n_features as needed
  if not isinstance(self._n_features, int):
  # compute number of features for the univariate tests
- if not isinstance(self.preprocess_fn, Callable) or self.x_ref_preprocessed:
- # infer features from preprocessed reference data
- self._n_features = self.x_ref.reshape(self.x_ref.shape[0], -1).shape[-1]
- else:
- # infer number of features after applying preprocessing step
- x = as_numpy(self.preprocess_fn(self._x_ref[0:1])) # type: ignore
- self._n_features = x.reshape(x.shape[0], -1).shape[-1]
+ x_ref = (
+ self.x_ref
+ if self.preprocess_fn is None or self.x_ref_preprocessed
+ else self.preprocess_fn(self._x_ref[0:1])
+ )
+ # infer features from preprocessed reference data
+ shape = x_ref.shape if isinstance(x_ref, Array) else as_numpy(x_ref).shape
+ self._n_features = int(math.prod(shape[1:])) # Multiplies all channel sizes after first

  return self._n_features

  @preprocess_x
- @abstractmethod
  def score(self, x: ArrayLike) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
  """
- Abstract method to calculate feature scores after preprocessing.
+ Calculates p-values and test statistics per feature.

  Parameters
  ----------
  x : ArrayLike
- The batch of data to calculate univariate :term:`drift<Drift>` scores for each feature.
+ Batch of instances

  Returns
  -------
  tuple[NDArray, NDArray]
- A tuple containing p-values and distance :term:`statistics<Statistics>` for each feature.
+ Feature level p-values and test statistics
  """
+ x_np = to_numpy(x)
+ x_np = x_np.reshape(x_np.shape[0], -1)
+ x_ref_np = as_numpy(self.x_ref)
+ x_ref_np = x_ref_np.reshape(x_ref_np.shape[0], -1)
+ p_val = np.zeros(self.n_features, dtype=np.float32)
+ dist = np.zeros_like(p_val)
+ for f in range(self.n_features):
+ dist[f], p_val[f] = self._score_fn(x_ref_np[:, f], x_np[:, f])
+ return p_val, dist
+
+ @abstractmethod
+ def _score_fn(self, x: NDArray[np.float32], y: NDArray[np.float32]) -> tuple[np.float32, np.float32]: ...

  def _apply_correction(self, p_vals: NDArray) -> tuple[bool, float]:
  """
dataeval-0.76.1/src/dataeval/detectors/drift/cvm.py → dataeval-0.82.0/src/dataeval/detectors/drift/_cvm.py
@@ -13,11 +13,11 @@ __all__ = []
  from typing import Callable, Literal

  import numpy as np
- from numpy.typing import ArrayLike, NDArray
+ from numpy.typing import NDArray
  from scipy.stats import cramervonmises_2samp

- from dataeval.detectors.drift.base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
- from dataeval.interop import to_numpy
+ from dataeval.detectors.drift._base import BaseDriftUnivariate, UpdateStrategy
+ from dataeval.typing import ArrayLike


  class DriftCVM(BaseDriftUnivariate):
@@ -55,6 +55,21 @@ class DriftCVM(BaseDriftUnivariate):
  Number of features used in the statistical test. No need to pass it if no
  preprocessing takes place. In case of a preprocessing step, this can also
  be inferred automatically but could be more expensive to compute.
+
+ Example
+ -------
+ >>> from functools import partial
+ >>> from dataeval.detectors.drift import preprocess_drift
+
+ Use a preprocess function to encode images before testing for drift
+
+ >>> preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
+ >>> drift = DriftCVM(train_images, preprocess_fn=preprocess_fn)
+
+ Test incoming images for drift
+
+ >>> drift.predict(test_images).drifted
+ True
  """

  def __init__(
@@ -77,28 +92,6 @@
  n_features=n_features,
  )

- @preprocess_x
- def score(self, x: ArrayLike) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
- """
- Performs the two-sample Cramér-von Mises test(s), computing the :term:`p-value<P-value>` and
- test statistic per feature.
-
- Parameters
- ----------
- x : ArrayLike
- Batch of instances.
-
- Returns
- -------
- tuple[NDArray, NDArray]
- Feature level p-values and CVM statistic
- """
- x_np = to_numpy(x)
- x_np = x_np.reshape(x_np.shape[0], -1)
- x_ref = self.x_ref.reshape(self.x_ref.shape[0], -1)
- p_val = np.zeros(self.n_features, dtype=np.float32)
- dist = np.zeros_like(p_val)
- for f in range(self.n_features):
- result = cramervonmises_2samp(x_ref[:, f], x_np[:, f], method="auto")
- p_val[f], dist[f] = result.pvalue, result.statistic
- return p_val, dist
+ def _score_fn(self, x: NDArray[np.float32], y: NDArray[np.float32]) -> tuple[np.float32, np.float32]:
+ result = cramervonmises_2samp(x, y, method="auto")
+ return np.float32(result.statistic), np.float32(result.pvalue)


dataeval-0.76.1/src/dataeval/detectors/drift/ks.py → dataeval-0.82.0/src/dataeval/detectors/drift/_ks.py
@@ -13,11 +13,11 @@ __all__ = []
  from typing import Callable, Literal

  import numpy as np
- from numpy.typing import ArrayLike, NDArray
+ from numpy.typing import NDArray
  from scipy.stats import ks_2samp

- from dataeval.detectors.drift.base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
- from dataeval.interop import to_numpy
+ from dataeval.detectors.drift._base import BaseDriftUnivariate, UpdateStrategy
+ from dataeval.typing import ArrayLike


  class DriftKS(BaseDriftUnivariate):
@@ -58,6 +58,21 @@ class DriftKS(BaseDriftUnivariate):
  Number of features used in the statistical test. No need to pass it if no
  preprocessing takes place. In case of a preprocessing step, this can also
  be inferred automatically but could be more expensive to compute.
+
+ Example
+ -------
+ >>> from functools import partial
+ >>> from dataeval.detectors.drift import preprocess_drift
+
+ Use a preprocess function to encode images before testing for drift
+
+ >>> preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
+ >>> drift = DriftKS(train_images, preprocess_fn=preprocess_fn)
+
+ Test incoming images for drift
+
+ >>> drift.predict(test_images).drifted
+ True
  """

  def __init__(
@@ -84,26 +99,5 @@
  # Other attributes
  self.alternative = alternative

- @preprocess_x
- def score(self, x: ArrayLike) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
- """
- Compute KS scores and :term:Statistics` per feature.
-
- Parameters
- ----------
- x : ArrayLike
- Batch of instances.
-
- Returns
- -------
- tuple[NDArray, NDArray]
- Feature level :term:p-values and KS statistic
- """
- x = to_numpy(x)
- x = x.reshape(x.shape[0], -1)
- x_ref = self.x_ref.reshape(self.x_ref.shape[0], -1)
- p_val = np.zeros(self.n_features, dtype=np.float32)
- dist = np.zeros_like(p_val)
- for f in range(self.n_features):
- dist[f], p_val[f] = ks_2samp(x_ref[:, f], x[:, f], alternative=self.alternative, method="exact")
- return p_val, dist
+ def _score_fn(self, x: NDArray[np.float32], y: NDArray[np.float32]) -> tuple[np.float32, np.float32]:
+ return ks_2samp(x, y, alternative=self.alternative, method="exact")


dataeval-0.76.1/src/dataeval/detectors/drift/mmd.py → dataeval-0.82.0/src/dataeval/detectors/drift/_mmd.py
@@ -10,43 +10,16 @@ from __future__ import annotations

  __all__ = []

- from dataclasses import dataclass
  from typing import Callable

  import torch
- from numpy.typing import ArrayLike

- from dataeval.detectors.drift.base import BaseDrift, DriftBaseOutput, UpdateStrategy, preprocess_x, update_x_ref
- from dataeval.detectors.drift.torch import GaussianRBF, mmd2_from_kernel_matrix
- from dataeval.interop import as_numpy
- from dataeval.output import set_metadata
- from dataeval.utils.torch.internal import get_device
-
-
- @dataclass(frozen=True)
- class DriftMMDOutput(DriftBaseOutput):
- """
- Output class for :class:`DriftMMD` :term:`drift<Drift>` detector.
-
- Attributes
- ----------
- is_drift : bool
- Drift prediction for the images
- threshold : float
- :term:`P-Value` used for significance of the permutation test
- p_val : float
- P-value obtained from the permutation test
- distance : float
- MMD^2 between the reference and test set
- distance_threshold : float
- MMD^2 threshold above which drift is flagged
- """
-
- # is_drift: bool
- # threshold: float
- # p_val: float
- # distance: float
- distance_threshold: float
+ from dataeval.config import get_device
+ from dataeval.detectors.drift._base import BaseDrift, UpdateStrategy, preprocess_x, update_x_ref
+ from dataeval.detectors.drift._torch import GaussianRBF, mmd2_from_kernel_matrix
+ from dataeval.outputs import DriftMMDOutput
+ from dataeval.outputs._base import set_metadata
+ from dataeval.typing import ArrayLike


  class DriftMMD(BaseDrift):
@@ -84,6 +57,21 @@ class DriftMMD(BaseDrift):
  device : str | None, default None
  Device type used. The default None uses the GPU and falls back on CPU.
  Can be specified by passing either 'cuda', 'gpu' or 'cpu'.
+
+ Example
+ -------
+ >>> from functools import partial
+ >>> from dataeval.detectors.drift import preprocess_drift
+
+ Use a preprocess function to encode images before testing for drift
+
+ >>> preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
+ >>> drift = DriftMMD(train_images, preprocess_fn=preprocess_fn)
+
+ Test incoming images for drift
+
+ >>> drift.predict(test_images).drifted
+ True
  """

  def __init__(
@@ -110,12 +98,12 @@
  self.device: torch.device = get_device(device)

  # initialize kernel
- sigma_tensor = torch.from_numpy(as_numpy(sigma)).to(self.device) if sigma is not None else None
+ sigma_tensor = torch.as_tensor(sigma, device=self.device) if sigma is not None else None
  self._kernel = GaussianRBF(sigma_tensor).to(self.device)

  # compute kernel matrix for the reference data
  if self._infer_sigma or isinstance(sigma_tensor, torch.Tensor):
- x = torch.from_numpy(self.x_ref).to(self.device)
+ x = torch.as_tensor(self.x_ref, device=self.device)
  self._k_xx = self._kernel(x, x, infer_sigma=self._infer_sigma)
  self._infer_sigma = False
  else:
@@ -147,21 +135,21 @@
  p-value obtained from the permutation test, MMD^2 between the reference and test set,
  and MMD^2 threshold above which :term:`drift<Drift>` is flagged
  """
- x = as_numpy(x)
- x_ref = torch.from_numpy(self.x_ref).to(self.device)
- n = x.shape[0]
- kernel_mat = self._kernel_matrix(x_ref, torch.from_numpy(x).to(self.device))
+ x_ref = torch.as_tensor(self.x_ref, device=self.device)
+ x_test = torch.as_tensor(x, device=self.device)
+ n = x_test.shape[0]
+ kernel_mat = self._kernel_matrix(x_ref, x_test)
  kernel_mat = kernel_mat - torch.diag(kernel_mat.diag()) # zero diagonal
  mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False)
- mmd2_permuted = torch.Tensor(
- [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)]
+ mmd2_permuted = torch.tensor(
+ [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False)] * self.n_permutations,
+ device=self.device,
  )
- mmd2, mmd2_permuted = mmd2.detach().cpu(), mmd2_permuted.detach().cpu()
  p_val = (mmd2 <= mmd2_permuted).float().mean()
  # compute distance threshold
  idx_threshold = int(self.p_val * len(mmd2_permuted))
  distance_threshold = torch.sort(mmd2_permuted, descending=True).values[idx_threshold]
- return p_val.numpy().item(), mmd2.numpy().item(), distance_threshold.numpy().item()
+ return float(p_val.item()), float(mmd2.item()), float(distance_threshold.item())

  @set_metadata
  @preprocess_x
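
Closing aside, not part of the diff: DriftMMD now resolves its default device through dataeval.config.get_device (see the import change above), so the docstring example can be paired with the new configuration API; encoder, train_images, and test_images are the same placeholders used in those docstrings.

    from functools import partial

    from dataeval.config import set_device
    from dataeval.detectors.drift import DriftMMD, preprocess_drift

    set_device("cuda")  # DriftMMD(..., device=None) falls back to this default

    preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
    drift = DriftMMD(train_images, preprocess_fn=preprocess_fn)
    print(drift.predict(test_images).drifted)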