dataeval 0.85.0__tar.gz → 0.86.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dataeval-0.85.0 → dataeval-0.86.0}/PKG-INFO +3 -2
- {dataeval-0.85.0 → dataeval-0.86.0}/pyproject.toml +5 -3
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/__init__.py +1 -1
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/_metadata.py +17 -5
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/_selection.py +1 -1
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/selections/_classfilter.py +4 -3
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/drift/__init__.py +4 -1
- dataeval-0.86.0/src/dataeval/detectors/drift/_mvdc.py +92 -0
- dataeval-0.86.0/src/dataeval/detectors/drift/_nml/__init__.py +6 -0
- dataeval-0.86.0/src/dataeval/detectors/drift/_nml/_base.py +68 -0
- dataeval-0.86.0/src/dataeval/detectors/drift/_nml/_chunk.py +404 -0
- dataeval-0.86.0/src/dataeval/detectors/drift/_nml/_domainclassifier.py +192 -0
- dataeval-0.86.0/src/dataeval/detectors/drift/_nml/_result.py +98 -0
- dataeval-0.86.0/src/dataeval/detectors/drift/_nml/_thresholds.py +280 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/outputs/__init__.py +2 -1
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/outputs/_bias.py +1 -3
- dataeval-0.86.0/src/dataeval/outputs/_drift.py +151 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/outputs/_linters.py +1 -6
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/outputs/_stats.py +1 -6
- dataeval-0.85.0/src/dataeval/outputs/_drift.py +0 -83
- {dataeval-0.85.0 → dataeval-0.86.0}/LICENSE.txt +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/README.md +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/_log.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/config.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/__init__.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/_embeddings.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/_images.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/_split.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/_targets.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/selections/__init__.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/selections/_classbalance.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/selections/_indices.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/selections/_limit.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/selections/_prioritize.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/selections/_reverse.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/selections/_shuffle.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/__init__.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/drift/_base.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/drift/_cvm.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/drift/_ks.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/drift/_mmd.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/drift/_uncertainty.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/drift/updates.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/linters/__init__.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/linters/duplicates.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/linters/outliers.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/ood/__init__.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/ood/ae.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/ood/base.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/ood/mixin.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/ood/vae.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metadata/__init__.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metadata/_distance.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metadata/_ood.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metadata/_utils.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/__init__.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/bias/__init__.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/bias/_balance.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/bias/_completeness.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/bias/_coverage.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/bias/_diversity.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/bias/_parity.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/estimators/__init__.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/estimators/_ber.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/estimators/_clusterer.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/estimators/_divergence.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/estimators/_uap.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/stats/__init__.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/stats/_base.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/stats/_boxratiostats.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/stats/_dimensionstats.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/stats/_hashstats.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/stats/_imagestats.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/stats/_labelstats.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/stats/_pixelstats.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/stats/_visualstats.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/outputs/_base.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/outputs/_estimators.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/outputs/_metadata.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/outputs/_ood.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/outputs/_utils.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/outputs/_workflows.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/py.typed +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/typing.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/__init__.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/_array.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/_bin.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/_clusterer.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/_fast_mst.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/_image.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/_method.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/_mst.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/_plot.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/data/__init__.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/data/_dataset.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/data/collate.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/data/metadata.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/datasets/__init__.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/datasets/_base.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/datasets/_cifar10.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/datasets/_fileio.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/datasets/_milco.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/datasets/_mixin.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/datasets/_mnist.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/datasets/_ships.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/datasets/_types.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/datasets/_voc.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/torch/__init__.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/torch/_blocks.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/torch/_gmm.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/torch/_internal.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/torch/models.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/torch/trainer.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/workflows/__init__.py +0 -0
- {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/workflows/sufficiency.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: dataeval
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.86.0
|
4
4
|
Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
|
5
5
|
Home-page: https://dataeval.ai/
|
6
6
|
License: MIT
|
@@ -23,10 +23,11 @@ Classifier: Topic :: Scientific/Engineering
|
|
23
23
|
Provides-Extra: all
|
24
24
|
Requires-Dist: defusedxml (>=0.7.1)
|
25
25
|
Requires-Dist: fast_hdbscan (==0.2.0)
|
26
|
+
Requires-Dist: lightgbm (>=4)
|
26
27
|
Requires-Dist: matplotlib (>=3.7.1) ; extra == "all"
|
27
28
|
Requires-Dist: numba (>=0.59.1)
|
28
29
|
Requires-Dist: numpy (>=1.24.2)
|
29
|
-
Requires-Dist: pandas (>=2.0)
|
30
|
+
Requires-Dist: pandas (>=2.0)
|
30
31
|
Requires-Dist: pillow (>=10.3.0)
|
31
32
|
Requires-Dist: requests
|
32
33
|
Requires-Dist: scikit-learn (>=1.5.0)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
[tool.poetry]
|
2
2
|
name = "dataeval"
|
3
|
-
version = "0.
|
3
|
+
version = "0.86.0" # dynamic
|
4
4
|
description = "DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks"
|
5
5
|
license = "MIT"
|
6
6
|
readme = "README.md"
|
@@ -44,8 +44,10 @@ packages = [
|
|
44
44
|
python = ">=3.9,<3.13"
|
45
45
|
defusedxml = {version = ">=0.7.1"}
|
46
46
|
fast_hdbscan = {version = "0.2.0"} # 0.2.1 hits a bug in condense_tree comparing float to none
|
47
|
+
lightgbm = {version = ">=4"}
|
47
48
|
numba = {version = ">=0.59.1"}
|
48
49
|
numpy = {version = ">=1.24.2"}
|
50
|
+
pandas = {version = ">=2.0"}
|
49
51
|
pillow = {version = ">=10.3.0"}
|
50
52
|
requests = {version = "*"}
|
51
53
|
scipy = {version = ">=1.10"}
|
@@ -58,10 +60,9 @@ xxhash = {version = ">=3.3"}
|
|
58
60
|
|
59
61
|
# optional
|
60
62
|
matplotlib = {version = ">=3.7.1", optional = true}
|
61
|
-
pandas = {version = ">=2.0", optional = true}
|
62
63
|
|
63
64
|
[tool.poetry.extras]
|
64
|
-
all = ["matplotlib"
|
65
|
+
all = ["matplotlib"]
|
65
66
|
|
66
67
|
[tool.poetry.group.dev]
|
67
68
|
optional = true
|
@@ -132,6 +133,7 @@ markers = [
|
|
132
133
|
"required: marks tests for required features",
|
133
134
|
"optional: marks tests for optional features",
|
134
135
|
"requires_all: marks tests that require the all extras",
|
136
|
+
"cuda: marks tests that require cuda",
|
135
137
|
]
|
136
138
|
|
137
139
|
[tool.coverage.run]
|
@@ -191,6 +191,11 @@ class Metadata:
|
|
191
191
|
self._process()
|
192
192
|
return self._image_indices
|
193
193
|
|
194
|
+
@property
|
195
|
+
def image_count(self) -> int:
|
196
|
+
self._process()
|
197
|
+
return int(self._image_indices.max() + 1)
|
198
|
+
|
194
199
|
def _collate(self, force: bool = False):
|
195
200
|
if self._collated and not force:
|
196
201
|
return
|
@@ -359,12 +364,19 @@ class Metadata:
|
|
359
364
|
|
360
365
|
def add_factors(self, factors: Mapping[str, ArrayLike]) -> None:
|
361
366
|
self._merge()
|
362
|
-
|
363
|
-
|
364
|
-
|
367
|
+
|
368
|
+
targets = len(self.targets.source) if self.targets.source is not None else len(self.targets)
|
369
|
+
images = self.image_count
|
370
|
+
lengths = {k: len(v if isinstance(v, Sized) else np.atleast_1d(as_numpy(v))) for k, v in factors.items()}
|
371
|
+
targets_match = all(f == targets for f in lengths.values())
|
372
|
+
images_match = targets_match if images == targets else all(f == images for f in lengths.values())
|
373
|
+
if not targets_match and not images_match:
|
365
374
|
raise ValueError(
|
366
375
|
"The lists/arrays in the provided factors have a different length than the current metadata factors."
|
367
376
|
)
|
368
|
-
merged = cast(
|
377
|
+
merged = cast(dict[str, ArrayLike], self._merged[0] if self._merged is not None else {})
|
369
378
|
for k, v in factors.items():
|
370
|
-
|
379
|
+
v = as_numpy(v)
|
380
|
+
merged[k] = v if (self.targets.source is None or lengths[k] == targets) else v[self.targets.source]
|
381
|
+
|
382
|
+
self._processed = False
|
@@ -120,7 +120,7 @@ class Select(AnnotatedDataset[_TDatum]):
|
|
120
120
|
|
121
121
|
def _apply_subselection(self, datum: _TDatum, index: int) -> _TDatum:
|
122
122
|
for subselection, indices in self._subselections:
|
123
|
-
datum = subselection(datum) if index in indices else datum
|
123
|
+
datum = subselection(datum) if self._selection[index] in indices else datum
|
124
124
|
return datum
|
125
125
|
|
126
126
|
def __getitem__(self, index: int) -> _TDatum:
|
@@ -10,7 +10,6 @@ from numpy.typing import NDArray
|
|
10
10
|
from dataeval.data._selection import Select, Selection, SelectionStage, Subselection
|
11
11
|
from dataeval.typing import Array, ObjectDetectionDatum, ObjectDetectionTarget, SegmentationDatum, SegmentationTarget
|
12
12
|
from dataeval.utils._array import as_numpy
|
13
|
-
from dataeval.utils.data.metadata import flatten
|
14
13
|
|
15
14
|
|
16
15
|
class ClassFilter(Selection[Any]):
|
@@ -96,13 +95,15 @@ class ClassFilterSubSelection(Subselection[Any]):
|
|
96
95
|
def __init__(self, classes: Sequence[int]) -> None:
|
97
96
|
self.classes = classes
|
98
97
|
|
98
|
+
def _filter(self, d: dict[str, Any], mask: NDArray[np.bool_]) -> dict[str, Any]:
|
99
|
+
return {k: self._filter(v, mask) if isinstance(v, dict) else _try_mask_object(v, mask) for k, v in d.items()}
|
100
|
+
|
99
101
|
def __call__(self, datum: _TDatum) -> _TDatum:
|
100
102
|
# build a mask for any arrays
|
101
103
|
image, target, metadata = datum
|
102
104
|
|
103
105
|
mask = np.isin(as_numpy(target.labels), self.classes)
|
104
|
-
|
105
|
-
filtered_metadata = {k: _try_mask_object(v, mask) for k, v in flattened_metadata.items()}
|
106
|
+
filtered_metadata = self._filter(metadata, mask)
|
106
107
|
|
107
108
|
# return a masked datum
|
108
109
|
filtered_datum = image, ClassFilterTarget(target, mask), filtered_metadata
|
@@ -7,6 +7,8 @@ __all__ = [
|
|
7
7
|
"DriftKS",
|
8
8
|
"DriftMMD",
|
9
9
|
"DriftMMDOutput",
|
10
|
+
"DriftMVDC",
|
11
|
+
"DriftMVDCOutput",
|
10
12
|
"DriftOutput",
|
11
13
|
"DriftUncertainty",
|
12
14
|
"UpdateStrategy",
|
@@ -18,5 +20,6 @@ from dataeval.detectors.drift._base import UpdateStrategy
|
|
18
20
|
from dataeval.detectors.drift._cvm import DriftCVM
|
19
21
|
from dataeval.detectors.drift._ks import DriftKS
|
20
22
|
from dataeval.detectors.drift._mmd import DriftMMD
|
23
|
+
from dataeval.detectors.drift._mvdc import DriftMVDC
|
21
24
|
from dataeval.detectors.drift._uncertainty import DriftUncertainty
|
22
|
-
from dataeval.outputs._drift import DriftMMDOutput, DriftOutput
|
25
|
+
from dataeval.outputs._drift import DriftMMDOutput, DriftMVDCOutput, DriftOutput
|
@@ -0,0 +1,92 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
import pandas as pd
|
7
|
+
from numpy.typing import ArrayLike
|
8
|
+
|
9
|
+
if TYPE_CHECKING:
|
10
|
+
from typing import Self
|
11
|
+
else:
|
12
|
+
from typing_extensions import Self
|
13
|
+
|
14
|
+
from dataeval.detectors.drift._nml._chunk import CountBasedChunker, SizeBasedChunker
|
15
|
+
from dataeval.detectors.drift._nml._domainclassifier import DomainClassifierCalculator
|
16
|
+
from dataeval.detectors.drift._nml._thresholds import ConstantThreshold
|
17
|
+
from dataeval.outputs._drift import DriftMVDCOutput
|
18
|
+
from dataeval.utils._array import flatten
|
19
|
+
|
20
|
+
|
21
|
+
class DriftMVDC:
|
22
|
+
"""Multivariant Domain Classifier
|
23
|
+
|
24
|
+
Parameters
|
25
|
+
----------
|
26
|
+
n_folds : int, default 5
|
27
|
+
Number of cross-validation (CV) folds.
|
28
|
+
chunk_size : int or None, default None
|
29
|
+
Number of samples in a chunk used in CV, will get one metric & prediction per chunk.
|
30
|
+
chunk_count : int or None, default None
|
31
|
+
Number of total chunks used in CV, will get one metric & prediction per chunk.
|
32
|
+
threshold : Tuple[float, float], default (0.45, 0.65)
|
33
|
+
(lower, upper) metric bounds on roc_auc for identifying :term:`drift<Drift>`.
|
34
|
+
"""
|
35
|
+
|
36
|
+
def __init__(
|
37
|
+
self,
|
38
|
+
n_folds: int = 5,
|
39
|
+
chunk_size: int | None = None,
|
40
|
+
chunk_count: int | None = None,
|
41
|
+
threshold: tuple[float, float] = (0.45, 0.65),
|
42
|
+
) -> None:
|
43
|
+
self.threshold: tuple[float, float] = max(0.0, min(threshold)), min(1.0, max(threshold))
|
44
|
+
chunker = (
|
45
|
+
CountBasedChunker(10 if chunk_count is None else chunk_count)
|
46
|
+
if chunk_size is None
|
47
|
+
else SizeBasedChunker(chunk_size)
|
48
|
+
)
|
49
|
+
self._calc = DomainClassifierCalculator(
|
50
|
+
cv_folds_num=n_folds,
|
51
|
+
chunker=chunker,
|
52
|
+
threshold=ConstantThreshold(lower=self.threshold[0], upper=self.threshold[1]),
|
53
|
+
)
|
54
|
+
|
55
|
+
def fit(self, x_ref: ArrayLike) -> Self:
|
56
|
+
"""
|
57
|
+
Fit the domain classifier on the training dataframe
|
58
|
+
|
59
|
+
Parameters
|
60
|
+
----------
|
61
|
+
x_ref : ArrayLike
|
62
|
+
Reference data with dim[n_samples, n_features].
|
63
|
+
|
64
|
+
Returns
|
65
|
+
-------
|
66
|
+
Self
|
67
|
+
|
68
|
+
"""
|
69
|
+
# for 1D input, assume that is 1 sample: dim[1,n_features]
|
70
|
+
self.x_ref: pd.DataFrame = pd.DataFrame(flatten(np.atleast_2d(np.asarray(x_ref))))
|
71
|
+
self.n_features: int = self.x_ref.shape[-1]
|
72
|
+
self._calc.fit(self.x_ref)
|
73
|
+
return self
|
74
|
+
|
75
|
+
def predict(self, x: ArrayLike) -> DriftMVDCOutput:
|
76
|
+
"""
|
77
|
+
Perform :term:`inference<Inference>` on the test dataframe
|
78
|
+
|
79
|
+
Parameters
|
80
|
+
----------
|
81
|
+
x : ArrayLike
|
82
|
+
Test (analysis) data with dim[n_samples, n_features].
|
83
|
+
|
84
|
+
Returns
|
85
|
+
-------
|
86
|
+
DomainClassifierDriftResult
|
87
|
+
"""
|
88
|
+
self.x_test: pd.DataFrame = pd.DataFrame(flatten(np.atleast_2d(np.asarray(x))))
|
89
|
+
if self.x_test.shape[-1] != self.n_features:
|
90
|
+
raise ValueError("Reference and test embeddings have different number of features")
|
91
|
+
|
92
|
+
return self._calc.calculate(self.x_test)
|
@@ -0,0 +1,68 @@
|
|
1
|
+
"""
|
2
|
+
Source code derived from NannyML 0.13.0
|
3
|
+
https://github.com/NannyML/nannyml/blob/main/nannyml/base.py
|
4
|
+
|
5
|
+
Licensed under Apache Software License (Apache 2.0)
|
6
|
+
"""
|
7
|
+
|
8
|
+
from __future__ import annotations
|
9
|
+
|
10
|
+
import logging
|
11
|
+
from abc import ABC, abstractmethod
|
12
|
+
from logging import Logger
|
13
|
+
from typing import Sequence
|
14
|
+
|
15
|
+
import pandas as pd
|
16
|
+
from typing_extensions import Self
|
17
|
+
|
18
|
+
from dataeval.detectors.drift._nml._chunk import Chunk, Chunker, CountBasedChunker
|
19
|
+
from dataeval.outputs._drift import DriftMVDCOutput
|
20
|
+
|
21
|
+
|
22
|
+
def _validate(data: pd.DataFrame, expected_features: int | None = None) -> int:
|
23
|
+
if data.empty:
|
24
|
+
raise ValueError("data contains no rows. Please provide a valid data set.")
|
25
|
+
if expected_features is not None and data.shape[-1] != expected_features:
|
26
|
+
raise ValueError(f"expected '{expected_features}' features in data set:\n\t{data}")
|
27
|
+
return data.shape[-1]
|
28
|
+
|
29
|
+
|
30
|
+
def _create_multilevel_index(chunks: Sequence[Chunk], result_group_name: str, result_column_names: Sequence[str]):
|
31
|
+
chunk_column_names = (*chunks[0].KEYS, "period")
|
32
|
+
chunk_tuples = [("chunk", chunk_column_name) for chunk_column_name in chunk_column_names]
|
33
|
+
result_tuples = [(result_group_name, column_name) for column_name in result_column_names]
|
34
|
+
return pd.MultiIndex.from_tuples(chunk_tuples + result_tuples)
|
35
|
+
|
36
|
+
|
37
|
+
class AbstractCalculator(ABC):
|
38
|
+
"""Base class for drift calculation."""
|
39
|
+
|
40
|
+
def __init__(self, chunker: Chunker | None = None, logger: Logger | None = None):
|
41
|
+
self.chunker = chunker if isinstance(chunker, Chunker) else CountBasedChunker(10)
|
42
|
+
self.result: DriftMVDCOutput | None = None
|
43
|
+
self.n_features: int | None = None
|
44
|
+
self._logger = logger if isinstance(logger, Logger) else logging.getLogger(__name__)
|
45
|
+
|
46
|
+
def fit(self, reference_data: pd.DataFrame) -> Self:
|
47
|
+
"""Trains the calculator using reference data."""
|
48
|
+
self.n_features = _validate(reference_data)
|
49
|
+
|
50
|
+
self._logger.debug(f"fitting {str(self)}")
|
51
|
+
self.result = self._fit(reference_data)
|
52
|
+
return self
|
53
|
+
|
54
|
+
def calculate(self, data: pd.DataFrame) -> DriftMVDCOutput:
|
55
|
+
"""Performs a calculation on the provided data."""
|
56
|
+
if self.result is None:
|
57
|
+
raise RuntimeError("must run fit with reference data before running calculate")
|
58
|
+
_validate(data, self.n_features)
|
59
|
+
|
60
|
+
self._logger.debug(f"calculating {str(self)}")
|
61
|
+
self.result = self._calculate(data)
|
62
|
+
return self.result
|
63
|
+
|
64
|
+
@abstractmethod
|
65
|
+
def _fit(self, reference_data: pd.DataFrame) -> DriftMVDCOutput: ...
|
66
|
+
|
67
|
+
@abstractmethod
|
68
|
+
def _calculate(self, data: pd.DataFrame) -> DriftMVDCOutput: ...
|