dataeval 0.76.1__tar.gz → 0.82.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dataeval-0.76.1 → dataeval-0.82.0}/PKG-INFO +5 -2
- {dataeval-0.76.1 → dataeval-0.82.0}/README.md +1 -1
- {dataeval-0.76.1 → dataeval-0.82.0}/pyproject.toml +17 -5
- {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/__init__.py +3 -3
- dataeval-0.82.0/src/dataeval/config.py +77 -0
- {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/detectors/__init__.py +1 -1
- dataeval-0.82.0/src/dataeval/detectors/drift/__init__.py +22 -0
- dataeval-0.76.1/src/dataeval/detectors/drift/base.py → dataeval-0.82.0/src/dataeval/detectors/drift/_base.py +40 -85
- dataeval-0.76.1/src/dataeval/detectors/drift/cvm.py → dataeval-0.82.0/src/dataeval/detectors/drift/_cvm.py +21 -28
- dataeval-0.76.1/src/dataeval/detectors/drift/ks.py → dataeval-0.82.0/src/dataeval/detectors/drift/_ks.py +20 -26
- dataeval-0.76.1/src/dataeval/detectors/drift/mmd.py → dataeval-0.82.0/src/dataeval/detectors/drift/_mmd.py +31 -43
- dataeval-0.76.1/src/dataeval/detectors/drift/torch.py → dataeval-0.82.0/src/dataeval/detectors/drift/_torch.py +2 -1
- dataeval-0.76.1/src/dataeval/detectors/drift/uncertainty.py → dataeval-0.82.0/src/dataeval/detectors/drift/_uncertainty.py +24 -7
- {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/detectors/drift/updates.py +20 -3
- dataeval-0.82.0/src/dataeval/detectors/linters/__init__.py +14 -0
- {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/detectors/linters/duplicates.py +13 -36
- {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/detectors/linters/outliers.py +23 -148
- {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/detectors/ood/__init__.py +1 -1
- dataeval-0.82.0/src/dataeval/detectors/ood/ae.py +93 -0
- {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/detectors/ood/base.py +5 -4
- {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/detectors/ood/mixin.py +21 -7
- dataeval-0.76.1/src/dataeval/detectors/ood/ae.py → dataeval-0.82.0/src/dataeval/detectors/ood/vae.py +14 -13
- dataeval-0.82.0/src/dataeval/metadata/__init__.py +6 -0
- dataeval-0.82.0/src/dataeval/metadata/_distance.py +167 -0
- dataeval-0.82.0/src/dataeval/metadata/_ood.py +217 -0
- dataeval-0.82.0/src/dataeval/metadata/_utils.py +44 -0
- {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/metrics/__init__.py +1 -1
- dataeval-0.82.0/src/dataeval/metrics/bias/__init__.py +23 -0
- dataeval-0.76.1/src/dataeval/metrics/bias/balance.py → dataeval-0.82.0/src/dataeval/metrics/bias/_balance.py +15 -101
- dataeval-0.82.0/src/dataeval/metrics/bias/_coverage.py +98 -0
- dataeval-0.76.1/src/dataeval/metrics/bias/diversity.py → dataeval-0.82.0/src/dataeval/metrics/bias/_diversity.py +18 -111
- dataeval-0.76.1/src/dataeval/metrics/bias/parity.py → dataeval-0.82.0/src/dataeval/metrics/bias/_parity.py +39 -77
- dataeval-0.82.0/src/dataeval/metrics/estimators/__init__.py +20 -0
- dataeval-0.76.1/src/dataeval/metrics/estimators/ber.py → dataeval-0.82.0/src/dataeval/metrics/estimators/_ber.py +42 -29
- dataeval-0.82.0/src/dataeval/metrics/estimators/_clusterer.py +44 -0
- dataeval-0.76.1/src/dataeval/metrics/estimators/divergence.py → dataeval-0.82.0/src/dataeval/metrics/estimators/_divergence.py +18 -30
- dataeval-0.76.1/src/dataeval/metrics/estimators/uap.py → dataeval-0.82.0/src/dataeval/metrics/estimators/_uap.py +4 -18
- dataeval-0.82.0/src/dataeval/metrics/stats/__init__.py +38 -0
- dataeval-0.76.1/src/dataeval/metrics/stats/base.py → dataeval-0.82.0/src/dataeval/metrics/stats/_base.py +82 -133
- dataeval-0.76.1/src/dataeval/metrics/stats/boxratiostats.py → dataeval-0.82.0/src/dataeval/metrics/stats/_boxratiostats.py +15 -18
- dataeval-0.82.0/src/dataeval/metrics/stats/_dimensionstats.py +75 -0
- dataeval-0.76.1/src/dataeval/metrics/stats/hashstats.py → dataeval-0.82.0/src/dataeval/metrics/stats/_hashstats.py +21 -37
- dataeval-0.82.0/src/dataeval/metrics/stats/_imagestats.py +94 -0
- dataeval-0.82.0/src/dataeval/metrics/stats/_labelstats.py +131 -0
- dataeval-0.76.1/src/dataeval/metrics/stats/pixelstats.py → dataeval-0.82.0/src/dataeval/metrics/stats/_pixelstats.py +19 -50
- dataeval-0.76.1/src/dataeval/metrics/stats/visualstats.py → dataeval-0.82.0/src/dataeval/metrics/stats/_visualstats.py +23 -54
- dataeval-0.82.0/src/dataeval/outputs/__init__.py +53 -0
- dataeval-0.76.1/src/dataeval/output.py → dataeval-0.82.0/src/dataeval/outputs/_base.py +55 -25
- dataeval-0.82.0/src/dataeval/outputs/_bias.py +381 -0
- dataeval-0.82.0/src/dataeval/outputs/_drift.py +83 -0
- dataeval-0.82.0/src/dataeval/outputs/_estimators.py +114 -0
- dataeval-0.82.0/src/dataeval/outputs/_linters.py +184 -0
- dataeval-0.76.1/src/dataeval/detectors/ood/output.py → dataeval-0.82.0/src/dataeval/outputs/_ood.py +22 -22
- dataeval-0.82.0/src/dataeval/outputs/_stats.py +387 -0
- dataeval-0.82.0/src/dataeval/outputs/_utils.py +44 -0
- dataeval-0.76.1/src/dataeval/workflows/sufficiency.py → dataeval-0.82.0/src/dataeval/outputs/_workflows.py +210 -418
- dataeval-0.82.0/src/dataeval/typing.py +234 -0
- {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/utils/__init__.py +2 -2
- dataeval-0.82.0/src/dataeval/utils/_array.py +169 -0
- dataeval-0.82.0/src/dataeval/utils/_bin.py +199 -0
- dataeval-0.82.0/src/dataeval/utils/_clusterer.py +144 -0
- dataeval-0.82.0/src/dataeval/utils/_fast_mst.py +189 -0
- dataeval-0.76.1/src/dataeval/utils/image.py → dataeval-0.82.0/src/dataeval/utils/_image.py +6 -4
- dataeval-0.82.0/src/dataeval/utils/_method.py +14 -0
- dataeval-0.76.1/src/dataeval/utils/shared.py → dataeval-0.82.0/src/dataeval/utils/_mst.py +3 -65
- dataeval-0.76.1/src/dataeval/utils/plot.py → dataeval-0.82.0/src/dataeval/utils/_plot.py +6 -6
- dataeval-0.82.0/src/dataeval/utils/data/__init__.py +26 -0
- dataeval-0.82.0/src/dataeval/utils/data/_dataset.py +217 -0
- dataeval-0.82.0/src/dataeval/utils/data/_embeddings.py +104 -0
- dataeval-0.82.0/src/dataeval/utils/data/_images.py +68 -0
- dataeval-0.82.0/src/dataeval/utils/data/_metadata.py +360 -0
- dataeval-0.82.0/src/dataeval/utils/data/_selection.py +126 -0
- dataeval-0.76.1/src/dataeval/utils/dataset/split.py → dataeval-0.82.0/src/dataeval/utils/data/_split.py +12 -38
- dataeval-0.82.0/src/dataeval/utils/data/_targets.py +85 -0
- dataeval-0.82.0/src/dataeval/utils/data/collate.py +103 -0
- dataeval-0.82.0/src/dataeval/utils/data/datasets/__init__.py +17 -0
- dataeval-0.82.0/src/dataeval/utils/data/datasets/_base.py +254 -0
- dataeval-0.82.0/src/dataeval/utils/data/datasets/_cifar10.py +134 -0
- dataeval-0.82.0/src/dataeval/utils/data/datasets/_fileio.py +168 -0
- dataeval-0.82.0/src/dataeval/utils/data/datasets/_milco.py +153 -0
- dataeval-0.82.0/src/dataeval/utils/data/datasets/_mixin.py +56 -0
- dataeval-0.82.0/src/dataeval/utils/data/datasets/_mnist.py +183 -0
- dataeval-0.82.0/src/dataeval/utils/data/datasets/_ships.py +123 -0
- dataeval-0.82.0/src/dataeval/utils/data/datasets/_types.py +52 -0
- dataeval-0.82.0/src/dataeval/utils/data/datasets/_voc.py +352 -0
- dataeval-0.82.0/src/dataeval/utils/data/selections/__init__.py +15 -0
- dataeval-0.82.0/src/dataeval/utils/data/selections/_classfilter.py +57 -0
- dataeval-0.82.0/src/dataeval/utils/data/selections/_indices.py +26 -0
- dataeval-0.82.0/src/dataeval/utils/data/selections/_limit.py +26 -0
- dataeval-0.82.0/src/dataeval/utils/data/selections/_reverse.py +18 -0
- dataeval-0.82.0/src/dataeval/utils/data/selections/_shuffle.py +29 -0
- dataeval-0.82.0/src/dataeval/utils/metadata.py +403 -0
- dataeval-0.76.1/src/dataeval/utils/torch/gmm.py → dataeval-0.82.0/src/dataeval/utils/torch/_gmm.py +4 -2
- dataeval-0.76.1/src/dataeval/utils/torch/internal.py → dataeval-0.82.0/src/dataeval/utils/torch/_internal.py +21 -51
- {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/utils/torch/models.py +43 -2
- {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/workflows/__init__.py +2 -1
- dataeval-0.82.0/src/dataeval/workflows/sufficiency.py +237 -0
- dataeval-0.76.1/src/dataeval/detectors/drift/__init__.py +0 -22
- dataeval-0.76.1/src/dataeval/detectors/linters/__init__.py +0 -16
- dataeval-0.76.1/src/dataeval/detectors/linters/clusterer.py +0 -512
- dataeval-0.76.1/src/dataeval/detectors/linters/merged_stats.py +0 -49
- dataeval-0.76.1/src/dataeval/detectors/ood/metadata_ks_compare.py +0 -129
- dataeval-0.76.1/src/dataeval/detectors/ood/metadata_least_likely.py +0 -119
- dataeval-0.76.1/src/dataeval/interop.py +0 -69
- dataeval-0.76.1/src/dataeval/metrics/bias/__init__.py +0 -21
- dataeval-0.76.1/src/dataeval/metrics/bias/coverage.py +0 -194
- dataeval-0.76.1/src/dataeval/metrics/estimators/__init__.py +0 -9
- dataeval-0.76.1/src/dataeval/metrics/stats/__init__.py +0 -35
- dataeval-0.76.1/src/dataeval/metrics/stats/datasetstats.py +0 -202
- dataeval-0.76.1/src/dataeval/metrics/stats/dimensionstats.py +0 -115
- dataeval-0.76.1/src/dataeval/metrics/stats/labelstats.py +0 -210
- dataeval-0.76.1/src/dataeval/utils/dataset/__init__.py +0 -7
- dataeval-0.76.1/src/dataeval/utils/dataset/datasets.py +0 -412
- dataeval-0.76.1/src/dataeval/utils/dataset/read.py +0 -63
- dataeval-0.76.1/src/dataeval/utils/metadata.py +0 -728
- {dataeval-0.76.1 → dataeval-0.82.0}/LICENSE.txt +0 -0
- /dataeval-0.76.1/src/dataeval/log.py → /dataeval-0.82.0/src/dataeval/_log.py +0 -0
- {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/detectors/ood/metadata_ood_mi.py +0 -0
- {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/py.typed +0 -0
- {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/utils/torch/__init__.py +0 -0
- /dataeval-0.76.1/src/dataeval/utils/torch/blocks.py → /dataeval-0.82.0/src/dataeval/utils/torch/_blocks.py +0 -0
- {dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/utils/torch/trainer.py +0 -0
{dataeval-0.76.1 → dataeval-0.82.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: dataeval
-Version: 0.76.1
+Version: 0.82.0
 Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
 Home-page: https://dataeval.ai/
 License: MIT
@@ -21,7 +21,10 @@ Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3 :: Only
 Classifier: Topic :: Scientific/Engineering
 Provides-Extra: all
+Requires-Dist: defusedxml (>=0.7.1)
+Requires-Dist: fast_hdbscan (==0.2.0)
 Requires-Dist: matplotlib (>=3.7.1) ; extra == "all"
+Requires-Dist: numba (>=0.59.1)
 Requires-Dist: numpy (>=1.24.2)
 Requires-Dist: pandas (>=2.0) ; extra == "all"
 Requires-Dist: pillow (>=10.3.0)
@@ -71,7 +74,7 @@ DataEval is easy to install, supports a wide range of Python versions, and is
 compatible with many of the most popular packages in the scientific and T&E
 communities.
 
-DataEval also has native
+DataEval also has native interoperability between JATIC's suite of tools when
 using MAITE-compliant datasets and models.
 <!-- end JATIC interop -->
 

{dataeval-0.76.1 → dataeval-0.82.0}/README.md

@@ -32,7 +32,7 @@ DataEval is easy to install, supports a wide range of Python versions, and is
 compatible with many of the most popular packages in the scientific and T&E
 communities.
 
-DataEval also has native
+DataEval also has native interoperability between JATIC's suite of tools when
 using MAITE-compliant datasets and models.
 <!-- end JATIC interop -->
 
{dataeval-0.76.1 → dataeval-0.82.0}/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "dataeval"
-version = "0.76.1"
+version = "0.82.0" # dynamic
 description = "DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks"
 license = "MIT"
 readme = "README.md"
@@ -42,6 +42,9 @@ packages = [
 [tool.poetry.dependencies]
 # required
 python = ">=3.9,<3.13"
+defusedxml = {version = ">=0.7.1"}
+fast_hdbscan = {version = "0.2.0"} # 0.2.1 hits a bug in condense_tree comparing float to none
+numba = {version = ">=0.59.1"}
 numpy = {version = ">=1.24.2"}
 pillow = {version = ">=10.3.0"}
 requests = {version = "*"}
@@ -88,7 +91,7 @@ certifi = {version = ">=2024.07.04"}
 enum_tools = {version = ">=0.12.0", extras = ["sphinx"]}
 ipykernel = {version = ">=6.26.0"}
 ipywidgets = {version = ">=8.1.1"}
-jinja2 = {version = ">=3.1.
+jinja2 = {version = ">=3.1.6"}
 jupyter-client = {version = ">=8.6.0"}
 jupyter-cache = {version = "*"}
 myst-nb = {version = ">=1.0.0"}
@@ -129,6 +132,11 @@ reportMissingImports = false
 norecursedirs = ["prototype"]
 testpaths = ["tests"]
 addopts = ["--pythonwarnings=ignore::DeprecationWarning", "--verbose", "--durations=20", "--durations-min=1.0"]
+markers = [
+    "required: marks tests for required features",
+    "optional: marks tests for optional features",
+    "requires_all: marks tests that require the all extras",
+]
 
 [tool.coverage.run]
 source = ["src/dataeval"]
@@ -143,8 +151,9 @@ exclude_also = [
 ]
 include = ["*/src/dataeval/*"]
 omit = [
-    "*/torch/
-    "*/
+    "*/torch/_blocks.py",
+    "*/_clusterer.py",
+    "*/_fast_mst.py",
 ]
 fail_under = 90
 
@@ -178,6 +187,9 @@ per-file-ignores = { "*.ipynb" = ["E402"] }
 [tool.ruff.lint.isort]
 known-first-party = ["dataeval"]
 
+[tool.ruff.lint.flake8-builtins]
+builtins-strict-checking = false
+
 [tool.ruff.format]
 quote-style = "double"
 indent-style = "space"
@@ -187,7 +199,7 @@ docstring-code-format = true
 docstring-code-line-length = "dynamic"
 
 [tool.codespell]
-skip = './*env*,./prototype,./output,./docs/build,./docs/source/.jupyter_cache,CHANGELOG.md,poetry.lock,*.html'
+skip = './*env*,./prototype,./output,./docs/build,./docs/source/.jupyter_cache,CHANGELOG.md,poetry.lock,*.html,./docs/source/*/data'
 ignore-words-list = ["Hart"]
 
 [build-system]
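Note (illustrative, not part of the diff): the `markers` block added to pyproject.toml above registers `required`, `optional`, and `requires_all` marks. A minimal sketch of how such marks are conventionally applied and selected with pytest follows; the test names and assertions are hypothetical.

# Hypothetical test module; only the marker names come from the diff above.
import pytest


@pytest.mark.required
def test_package_imports() -> None:
    import dataeval

    assert dataeval.__version__ == "0.82.0"


@pytest.mark.requires_all
def test_optional_plotting_dependency() -> None:
    # Skips unless the "all" extras (e.g. matplotlib) are installed.
    pytest.importorskip("matplotlib")


if __name__ == "__main__":
    # Run only the tests carrying the "required" marker.
    pytest.main(["-m", "required", __file__])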
{dataeval-0.76.1 → dataeval-0.82.0}/src/dataeval/__init__.py

@@ -7,12 +7,12 @@ shifts that impact performance of deployed models.
 
 from __future__ import annotations
 
-__all__ = ["detectors", "log", "metrics", "utils", "workflows"]
-__version__ = "0.76.1"
+__all__ = ["config", "detectors", "log", "metrics", "typing", "utils", "workflows"]
+__version__ = "0.82.0"
 
 import logging
 
-from dataeval import detectors, metrics, utils, workflows
+from dataeval import config, detectors, metrics, typing, utils, workflows
 
 logging.getLogger(__name__).addHandler(logging.NullHandler())
 
dataeval-0.82.0/src/dataeval/config.py

@@ -0,0 +1,77 @@
+"""
+Global configuration settings for DataEval.
+"""
+
+from __future__ import annotations
+
+__all__ = ["get_device", "set_device", "get_max_processes", "set_max_processes"]
+
+import torch
+from torch import device
+
+_device: device | None = None
+_processes: int | None = None
+
+
+def set_device(device: str | device | int) -> None:
+    """
+    Sets the default device to use when executing against a PyTorch backend.
+
+    Parameters
+    ----------
+    device : str or int or `torch.device`
+        The default device to use. See `torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
+        documentation for more information.
+    """
+    global _device
+    _device = torch.device(device)
+
+
+def get_device(override: str | device | int | None = None) -> torch.device:
+    """
+    Returns the PyTorch device to use.
+
+    Parameters
+    ----------
+    override : str or int or `torch.device` or None, default None
+        The user specified override if provided, otherwise returns the default device.
+
+    Returns
+    -------
+    `torch.device`
+    """
+    if override is None:
+        global _device
+        return torch.get_default_device() if _device is None else _device
+    else:
+        return torch.device(override)
+
+
+def set_max_processes(processes: int | None) -> None:
+    """
+    Sets the maximum number of worker processes to use when running tasks that support parallel processing.
+
+    Parameters
+    ----------
+    processes : int or None
+        The maximum number of worker processes to use, or None to use
+        `os.process_cpu_count <https://docs.python.org/3/library/os.html#os.process_cpu_count>`_
+        to determine the number of worker processes.
+    """
+    global _processes
+    _processes = processes
+
+
+def get_max_processes() -> int | None:
+    """
+    Returns the maximum number of worker processes to use when running tasks that support parallel processing.
+
+    Returns
+    -------
+    int or None
+        The maximum number of worker processes to use, or None to use
+        `os.process_cpu_count <https://docs.python.org/3/library/os.html#os.process_cpu_count>`_
+        to determine the number of worker processes.
+    """
+    global _processes
+    return _processes
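Note (illustrative, not part of the diff): a minimal usage sketch of the new dataeval.config module, assuming dataeval 0.82.0 and a PyTorch build where "cuda:0" is available.

# Configure the global defaults that torch-backed detectors read via
# get_device() / get_max_processes().
from dataeval import config

config.set_device("cuda:0")        # accepts a str, an int index, or a torch.device
config.set_max_processes(4)        # None defers to os.process_cpu_count()

print(config.get_device())         # -> device(type='cuda', index=0)
print(config.get_max_processes())  # -> 4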
dataeval-0.82.0/src/dataeval/detectors/drift/__init__.py

@@ -0,0 +1,22 @@
+"""
+:term:`Drift` detectors identify if the statistical properties of the data has changed.
+"""
+
+__all__ = [
+    "DriftCVM",
+    "DriftKS",
+    "DriftMMD",
+    "DriftMMDOutput",
+    "DriftOutput",
+    "DriftUncertainty",
+    "preprocess_drift",
+    "updates",
+]
+
+from dataeval.detectors.drift import updates
+from dataeval.detectors.drift._cvm import DriftCVM
+from dataeval.detectors.drift._ks import DriftKS
+from dataeval.detectors.drift._mmd import DriftMMD
+from dataeval.detectors.drift._torch import preprocess_drift
+from dataeval.detectors.drift._uncertainty import DriftUncertainty
+from dataeval.outputs._drift import DriftMMDOutput, DriftOutput
dataeval-0.76.1/src/dataeval/detectors/drift/base.py → dataeval-0.82.0/src/dataeval/detectors/drift/_base.py

@@ -10,86 +10,29 @@ from __future__ import annotations
 
 __all__ = []
 
-
-from
+import math
+from abc import abstractmethod
 from functools import wraps
-from typing import Any, Callable, Literal, TypeVar
+from typing import Any, Callable, Literal, Protocol, TypeVar, runtime_checkable
 
 import numpy as np
-from numpy.typing import
+from numpy.typing import NDArray
 
-from dataeval.
-from dataeval.
+from dataeval.outputs import DriftOutput
+from dataeval.outputs._base import set_metadata
+from dataeval.typing import Array, ArrayLike
+from dataeval.utils._array import as_numpy, to_numpy
 
 R = TypeVar("R")
 
 
-
+@runtime_checkable
+class UpdateStrategy(Protocol):
     """
-
-
-    Parameters
-    ----------
-    n : int
-        Update with last n instances seen by the detector.
-    """
-
-    def __init__(self, n: int) -> None:
-        self.n = n
-
-    @abstractmethod
-    def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]:
-        """Abstract implementation of update strategy"""
-
-
-@dataclass(frozen=True)
-class DriftBaseOutput(Output):
-    """
-    Base output class for Drift Detector classes
-
-    Attributes
-    ----------
-    is_drift : bool
-        Drift prediction for the images
-    threshold : float
-        Threshold after multivariate correction if needed
+    Protocol for reference dataset update strategy for drift detectors
     """
 
-
-    threshold: float
-    p_val: float
-    distance: float
-
-
-@dataclass(frozen=True)
-class DriftOutput(DriftBaseOutput):
-    """
-    Output class for :class:`DriftCVM`, :class:`DriftKS`, and :class:`DriftUncertainty` drift detectors.
-
-    Attributes
-    ----------
-    is_drift : bool
-        :term:`Drift` prediction for the images
-    threshold : float
-        Threshold after multivariate correction if needed
-    feature_drift : NDArray
-        Feature-level array of images detected to have drifted
-    feature_threshold : float
-        Feature-level threshold to determine drift
-    p_vals : NDArray
-        Feature-level p-values
-    distances : NDArray
-        Feature-level distances
-    """
-
-    # is_drift: bool
-    # threshold: float
-    # p_val: float
-    # distance: float
-    feature_drift: NDArray[np.bool_]
-    feature_threshold: float
-    p_vals: NDArray[np.float32]
-    distances: NDArray[np.float32]
+    def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]: ...
 
 
 def update_x_ref(fn: Callable[..., R]) -> Callable[..., R]:
@@ -196,7 +139,7 @@ class BaseDrift:
         if correction not in ["bonferroni", "fdr"]:
             raise ValueError("`correction` must be `bonferroni` or `fdr`.")
 
-        self._x_ref =
+        self._x_ref = x_ref
         self.x_ref_preprocessed: bool = x_ref_preprocessed
 
         # Other attributes
@@ -204,25 +147,25 @@ class BaseDrift:
         self.update_x_ref = update_x_ref
         self.preprocess_fn = preprocess_fn
         self.correction = correction
-        self.n: int = len(
+        self.n: int = len(x_ref)
 
         # Ref counter for preprocessed x
         self._x_refcount = 0
 
     @property
-    def x_ref(self) ->
+    def x_ref(self) -> ArrayLike:
         """
        Retrieve the reference data, applying preprocessing if not already done.
 
        Returns
        -------
-
+        ArrayLike
            The reference dataset (`x_ref`), preprocessed if needed.
        """
        if not self.x_ref_preprocessed:
            self.x_ref_preprocessed = True
            if self.preprocess_fn is not None:
-                self._x_ref =
+                self._x_ref = self.preprocess_fn(self._x_ref)
 
        return self._x_ref
 
@@ -323,32 +266,44 @@ class BaseDriftUnivariate(BaseDrift):
        # lazy process n_features as needed
        if not isinstance(self._n_features, int):
            # compute number of features for the univariate tests
-
-
-            self.
-
-
-
-
+            x_ref = (
+                self.x_ref
+                if self.preprocess_fn is None or self.x_ref_preprocessed
+                else self.preprocess_fn(self._x_ref[0:1])
+            )
+            # infer features from preprocessed reference data
+            shape = x_ref.shape if isinstance(x_ref, Array) else as_numpy(x_ref).shape
+            self._n_features = int(math.prod(shape[1:]))  # Multiplies all channel sizes after first
 
        return self._n_features
 
    @preprocess_x
-    @abstractmethod
    def score(self, x: ArrayLike) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
        """
-
+        Calculates p-values and test statistics per feature.
 
        Parameters
        ----------
        x : ArrayLike
-
+            Batch of instances
 
        Returns
        -------
        tuple[NDArray, NDArray]
-
+            Feature level p-values and test statistics
        """
+        x_np = to_numpy(x)
+        x_np = x_np.reshape(x_np.shape[0], -1)
+        x_ref_np = as_numpy(self.x_ref)
+        x_ref_np = x_ref_np.reshape(x_ref_np.shape[0], -1)
+        p_val = np.zeros(self.n_features, dtype=np.float32)
+        dist = np.zeros_like(p_val)
+        for f in range(self.n_features):
+            dist[f], p_val[f] = self._score_fn(x_ref_np[:, f], x_np[:, f])
+        return p_val, dist
+
+    @abstractmethod
+    def _score_fn(self, x: NDArray[np.float32], y: NDArray[np.float32]) -> tuple[np.float32, np.float32]: ...
 
    def _apply_correction(self, p_vals: NDArray) -> tuple[bool, float]:
        """
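Note (illustrative, not part of the diff): the refactor above moves the per-feature loop into BaseDriftUnivariate.score and leaves subclasses to supply only _score_fn. A hedged sketch of a hypothetical custom univariate detector built on this internal API; BaseDriftUnivariate lives in a private module and its inherited constructor is assumed to match the one DriftKS/DriftCVM forward to.

# Hypothetical detector scoring each flattened feature with the Mann-Whitney U test.
import numpy as np
from numpy.typing import NDArray
from scipy.stats import mannwhitneyu

from dataeval.detectors.drift._base import BaseDriftUnivariate


class DriftMWU(BaseDriftUnivariate):
    """Sketch of a univariate drift detector; the constructor is inherited."""

    def _score_fn(self, x: NDArray[np.float32], y: NDArray[np.float32]) -> tuple[np.float32, np.float32]:
        # The base class score() calls this once per feature column and expects
        # (statistic, p-value), mirroring DriftCVM._score_fn in the diff above.
        result = mannwhitneyu(x, y, alternative="two-sided")
        return np.float32(result.statistic), np.float32(result.pvalue)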
dataeval-0.76.1/src/dataeval/detectors/drift/cvm.py → dataeval-0.82.0/src/dataeval/detectors/drift/_cvm.py

@@ -13,11 +13,11 @@ __all__ = []
 from typing import Callable, Literal
 
 import numpy as np
-from numpy.typing import
+from numpy.typing import NDArray
 from scipy.stats import cramervonmises_2samp
 
-from dataeval.detectors.drift.
-from dataeval.
+from dataeval.detectors.drift._base import BaseDriftUnivariate, UpdateStrategy
+from dataeval.typing import ArrayLike
 
 
 class DriftCVM(BaseDriftUnivariate):
@@ -55,6 +55,21 @@ class DriftCVM(BaseDriftUnivariate):
         Number of features used in the statistical test. No need to pass it if no
         preprocessing takes place. In case of a preprocessing step, this can also
         be inferred automatically but could be more expensive to compute.
+
+    Example
+    -------
+    >>> from functools import partial
+    >>> from dataeval.detectors.drift import preprocess_drift
+
+    Use a preprocess function to encode images before testing for drift
+
+    >>> preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
+    >>> drift = DriftCVM(train_images, preprocess_fn=preprocess_fn)
+
+    Test incoming images for drift
+
+    >>> drift.predict(test_images).drifted
+    True
     """
 
     def __init__(
@@ -77,28 +92,6 @@ class DriftCVM(BaseDriftUnivariate):
             n_features=n_features,
         )
 
-
-
-
-        Performs the two-sample Cramér-von Mises test(s), computing the :term:`p-value<P-value>` and
-        test statistic per feature.
-
-        Parameters
-        ----------
-        x : ArrayLike
-            Batch of instances.
-
-        Returns
-        -------
-        tuple[NDArray, NDArray]
-            Feature level p-values and CVM statistic
-        """
-        x_np = to_numpy(x)
-        x_np = x_np.reshape(x_np.shape[0], -1)
-        x_ref = self.x_ref.reshape(self.x_ref.shape[0], -1)
-        p_val = np.zeros(self.n_features, dtype=np.float32)
-        dist = np.zeros_like(p_val)
-        for f in range(self.n_features):
-            result = cramervonmises_2samp(x_ref[:, f], x_np[:, f], method="auto")
-            p_val[f], dist[f] = result.pvalue, result.statistic
-        return p_val, dist
+    def _score_fn(self, x: NDArray[np.float32], y: NDArray[np.float32]) -> tuple[np.float32, np.float32]:
+        result = cramervonmises_2samp(x, y, method="auto")
+        return np.float32(result.statistic), np.float32(result.pvalue)
dataeval-0.76.1/src/dataeval/detectors/drift/ks.py → dataeval-0.82.0/src/dataeval/detectors/drift/_ks.py

@@ -13,11 +13,11 @@ __all__ = []
 from typing import Callable, Literal
 
 import numpy as np
-from numpy.typing import
+from numpy.typing import NDArray
 from scipy.stats import ks_2samp
 
-from dataeval.detectors.drift.
-from dataeval.
+from dataeval.detectors.drift._base import BaseDriftUnivariate, UpdateStrategy
+from dataeval.typing import ArrayLike
 
 
 class DriftKS(BaseDriftUnivariate):
@@ -58,6 +58,21 @@ class DriftKS(BaseDriftUnivariate):
         Number of features used in the statistical test. No need to pass it if no
         preprocessing takes place. In case of a preprocessing step, this can also
         be inferred automatically but could be more expensive to compute.
+
+    Example
+    -------
+    >>> from functools import partial
+    >>> from dataeval.detectors.drift import preprocess_drift
+
+    Use a preprocess function to encode images before testing for drift
+
+    >>> preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
+    >>> drift = DriftKS(train_images, preprocess_fn=preprocess_fn)
+
+    Test incoming images for drift
+
+    >>> drift.predict(test_images).drifted
+    True
     """
 
     def __init__(
@@ -84,26 +99,5 @@ class DriftKS(BaseDriftUnivariate):
         # Other attributes
         self.alternative = alternative
 
-
-
-        """
-        Compute KS scores and :term:`Statistics` per feature.
-
-        Parameters
-        ----------
-        x : ArrayLike
-            Batch of instances.
-
-        Returns
-        -------
-        tuple[NDArray, NDArray]
-            Feature level :term:`p-values` and KS statistic
-        """
-        x = to_numpy(x)
-        x = x.reshape(x.shape[0], -1)
-        x_ref = self.x_ref.reshape(self.x_ref.shape[0], -1)
-        p_val = np.zeros(self.n_features, dtype=np.float32)
-        dist = np.zeros_like(p_val)
-        for f in range(self.n_features):
-            dist[f], p_val[f] = ks_2samp(x_ref[:, f], x[:, f], alternative=self.alternative, method="exact")
-        return p_val, dist
+    def _score_fn(self, x: NDArray[np.float32], y: NDArray[np.float32]) -> tuple[np.float32, np.float32]:
+        return ks_2samp(x, y, alternative=self.alternative, method="exact")
dataeval-0.76.1/src/dataeval/detectors/drift/mmd.py → dataeval-0.82.0/src/dataeval/detectors/drift/_mmd.py

@@ -10,43 +10,16 @@ from __future__ import annotations
 
 __all__ = []
 
-from dataclasses import dataclass
 from typing import Callable
 
 import torch
-from numpy.typing import ArrayLike
 
-from dataeval.
-from dataeval.detectors.drift.
-from dataeval.
-from dataeval.
-from dataeval.
-
-
-@dataclass(frozen=True)
-class DriftMMDOutput(DriftBaseOutput):
-    """
-    Output class for :class:`DriftMMD` :term:`drift<Drift>` detector.
-
-    Attributes
-    ----------
-    is_drift : bool
-        Drift prediction for the images
-    threshold : float
-        :term:`P-Value` used for significance of the permutation test
-    p_val : float
-        P-value obtained from the permutation test
-    distance : float
-        MMD^2 between the reference and test set
-    distance_threshold : float
-        MMD^2 threshold above which drift is flagged
-    """
-
-    # is_drift: bool
-    # threshold: float
-    # p_val: float
-    # distance: float
-    distance_threshold: float
+from dataeval.config import get_device
+from dataeval.detectors.drift._base import BaseDrift, UpdateStrategy, preprocess_x, update_x_ref
+from dataeval.detectors.drift._torch import GaussianRBF, mmd2_from_kernel_matrix
+from dataeval.outputs import DriftMMDOutput
+from dataeval.outputs._base import set_metadata
+from dataeval.typing import ArrayLike
 
 
 class DriftMMD(BaseDrift):
@@ -84,6 +57,21 @@ class DriftMMD(BaseDrift):
     device : str | None, default None
         Device type used. The default None uses the GPU and falls back on CPU.
         Can be specified by passing either 'cuda', 'gpu' or 'cpu'.
+
+    Example
+    -------
+    >>> from functools import partial
+    >>> from dataeval.detectors.drift import preprocess_drift
+
+    Use a preprocess function to encode images before testing for drift
+
+    >>> preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
+    >>> drift = DriftMMD(train_images, preprocess_fn=preprocess_fn)
+
+    Test incoming images for drift
+
+    >>> drift.predict(test_images).drifted
+    True
     """
 
     def __init__(
@@ -110,12 +98,12 @@ class DriftMMD(BaseDrift):
         self.device: torch.device = get_device(device)
 
         # initialize kernel
-        sigma_tensor = torch.
+        sigma_tensor = torch.as_tensor(sigma, device=self.device) if sigma is not None else None
         self._kernel = GaussianRBF(sigma_tensor).to(self.device)
 
         # compute kernel matrix for the reference data
         if self._infer_sigma or isinstance(sigma_tensor, torch.Tensor):
-            x = torch.
+            x = torch.as_tensor(self.x_ref, device=self.device)
             self._k_xx = self._kernel(x, x, infer_sigma=self._infer_sigma)
             self._infer_sigma = False
         else:
@@ -147,21 +135,21 @@ class DriftMMD(BaseDrift):
             p-value obtained from the permutation test, MMD^2 between the reference and test set,
             and MMD^2 threshold above which :term:`drift<Drift>` is flagged
         """
-
-
-        n =
-        kernel_mat = self._kernel_matrix(x_ref,
+        x_ref = torch.as_tensor(self.x_ref, device=self.device)
+        x_test = torch.as_tensor(x, device=self.device)
+        n = x_test.shape[0]
+        kernel_mat = self._kernel_matrix(x_ref, x_test)
         kernel_mat = kernel_mat - torch.diag(kernel_mat.diag())  # zero diagonal
         mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False)
-        mmd2_permuted = torch.
-            [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False)
+        mmd2_permuted = torch.tensor(
+            [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False)] * self.n_permutations,
+            device=self.device,
         )
-        mmd2, mmd2_permuted = mmd2.detach().cpu(), mmd2_permuted.detach().cpu()
         p_val = (mmd2 <= mmd2_permuted).float().mean()
         # compute distance threshold
         idx_threshold = int(self.p_val * len(mmd2_permuted))
         distance_threshold = torch.sort(mmd2_permuted, descending=True).values[idx_threshold]
-        return p_val.
+        return float(p_val.item()), float(mmd2.item()), float(distance_threshold.item())
 
     @set_metadata
     @preprocess_x