dataeval 0.85.0__tar.gz → 0.86.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. {dataeval-0.85.0 → dataeval-0.86.0}/PKG-INFO +3 -2
  2. {dataeval-0.85.0 → dataeval-0.86.0}/pyproject.toml +5 -3
  3. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/__init__.py +1 -1
  4. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/_metadata.py +17 -5
  5. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/_selection.py +1 -1
  6. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/selections/_classfilter.py +4 -3
  7. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/drift/__init__.py +4 -1
  8. dataeval-0.86.0/src/dataeval/detectors/drift/_mvdc.py +92 -0
  9. dataeval-0.86.0/src/dataeval/detectors/drift/_nml/__init__.py +6 -0
  10. dataeval-0.86.0/src/dataeval/detectors/drift/_nml/_base.py +68 -0
  11. dataeval-0.86.0/src/dataeval/detectors/drift/_nml/_chunk.py +404 -0
  12. dataeval-0.86.0/src/dataeval/detectors/drift/_nml/_domainclassifier.py +192 -0
  13. dataeval-0.86.0/src/dataeval/detectors/drift/_nml/_result.py +98 -0
  14. dataeval-0.86.0/src/dataeval/detectors/drift/_nml/_thresholds.py +280 -0
  15. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/outputs/__init__.py +2 -1
  16. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/outputs/_bias.py +1 -3
  17. dataeval-0.86.0/src/dataeval/outputs/_drift.py +151 -0
  18. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/outputs/_linters.py +1 -6
  19. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/outputs/_stats.py +1 -6
  20. dataeval-0.85.0/src/dataeval/outputs/_drift.py +0 -83
  21. {dataeval-0.85.0 → dataeval-0.86.0}/LICENSE.txt +0 -0
  22. {dataeval-0.85.0 → dataeval-0.86.0}/README.md +0 -0
  23. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/_log.py +0 -0
  24. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/config.py +0 -0
  25. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/__init__.py +0 -0
  26. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/_embeddings.py +0 -0
  27. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/_images.py +0 -0
  28. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/_split.py +0 -0
  29. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/_targets.py +0 -0
  30. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/selections/__init__.py +0 -0
  31. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/selections/_classbalance.py +0 -0
  32. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/selections/_indices.py +0 -0
  33. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/selections/_limit.py +0 -0
  34. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/selections/_prioritize.py +0 -0
  35. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/selections/_reverse.py +0 -0
  36. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/data/selections/_shuffle.py +0 -0
  37. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/__init__.py +0 -0
  38. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/drift/_base.py +0 -0
  39. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/drift/_cvm.py +0 -0
  40. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/drift/_ks.py +0 -0
  41. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/drift/_mmd.py +0 -0
  42. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/drift/_uncertainty.py +0 -0
  43. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/drift/updates.py +0 -0
  44. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/linters/__init__.py +0 -0
  45. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/linters/duplicates.py +0 -0
  46. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/linters/outliers.py +0 -0
  47. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/ood/__init__.py +0 -0
  48. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/ood/ae.py +0 -0
  49. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/ood/base.py +0 -0
  50. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/ood/mixin.py +0 -0
  51. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/detectors/ood/vae.py +0 -0
  52. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metadata/__init__.py +0 -0
  53. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metadata/_distance.py +0 -0
  54. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metadata/_ood.py +0 -0
  55. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metadata/_utils.py +0 -0
  56. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/__init__.py +0 -0
  57. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/bias/__init__.py +0 -0
  58. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/bias/_balance.py +0 -0
  59. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/bias/_completeness.py +0 -0
  60. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/bias/_coverage.py +0 -0
  61. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/bias/_diversity.py +0 -0
  62. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/bias/_parity.py +0 -0
  63. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/estimators/__init__.py +0 -0
  64. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/estimators/_ber.py +0 -0
  65. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/estimators/_clusterer.py +0 -0
  66. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/estimators/_divergence.py +0 -0
  67. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/estimators/_uap.py +0 -0
  68. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/stats/__init__.py +0 -0
  69. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/stats/_base.py +0 -0
  70. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/stats/_boxratiostats.py +0 -0
  71. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/stats/_dimensionstats.py +0 -0
  72. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/stats/_hashstats.py +0 -0
  73. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/stats/_imagestats.py +0 -0
  74. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/stats/_labelstats.py +0 -0
  75. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/stats/_pixelstats.py +0 -0
  76. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/metrics/stats/_visualstats.py +0 -0
  77. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/outputs/_base.py +0 -0
  78. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/outputs/_estimators.py +0 -0
  79. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/outputs/_metadata.py +0 -0
  80. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/outputs/_ood.py +0 -0
  81. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/outputs/_utils.py +0 -0
  82. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/outputs/_workflows.py +0 -0
  83. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/py.typed +0 -0
  84. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/typing.py +0 -0
  85. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/__init__.py +0 -0
  86. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/_array.py +0 -0
  87. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/_bin.py +0 -0
  88. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/_clusterer.py +0 -0
  89. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/_fast_mst.py +0 -0
  90. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/_image.py +0 -0
  91. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/_method.py +0 -0
  92. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/_mst.py +0 -0
  93. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/_plot.py +0 -0
  94. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/data/__init__.py +0 -0
  95. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/data/_dataset.py +0 -0
  96. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/data/collate.py +0 -0
  97. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/data/metadata.py +0 -0
  98. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/datasets/__init__.py +0 -0
  99. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/datasets/_base.py +0 -0
  100. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/datasets/_cifar10.py +0 -0
  101. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/datasets/_fileio.py +0 -0
  102. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/datasets/_milco.py +0 -0
  103. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/datasets/_mixin.py +0 -0
  104. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/datasets/_mnist.py +0 -0
  105. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/datasets/_ships.py +0 -0
  106. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/datasets/_types.py +0 -0
  107. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/datasets/_voc.py +0 -0
  108. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/torch/__init__.py +0 -0
  109. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/torch/_blocks.py +0 -0
  110. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/torch/_gmm.py +0 -0
  111. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/torch/_internal.py +0 -0
  112. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/torch/models.py +0 -0
  113. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/utils/torch/trainer.py +0 -0
  114. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/workflows/__init__.py +0 -0
  115. {dataeval-0.85.0 → dataeval-0.86.0}/src/dataeval/workflows/sufficiency.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: dataeval
3
- Version: 0.85.0
3
+ Version: 0.86.0
4
4
  Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
5
5
  Home-page: https://dataeval.ai/
6
6
  License: MIT
@@ -23,10 +23,11 @@ Classifier: Topic :: Scientific/Engineering
23
23
  Provides-Extra: all
24
24
  Requires-Dist: defusedxml (>=0.7.1)
25
25
  Requires-Dist: fast_hdbscan (==0.2.0)
26
+ Requires-Dist: lightgbm (>=4)
26
27
  Requires-Dist: matplotlib (>=3.7.1) ; extra == "all"
27
28
  Requires-Dist: numba (>=0.59.1)
28
29
  Requires-Dist: numpy (>=1.24.2)
29
- Requires-Dist: pandas (>=2.0) ; extra == "all"
30
+ Requires-Dist: pandas (>=2.0)
30
31
  Requires-Dist: pillow (>=10.3.0)
31
32
  Requires-Dist: requests
32
33
  Requires-Dist: scikit-learn (>=1.5.0)
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "dataeval"
3
- version = "0.85.0" # dynamic
3
+ version = "0.86.0" # dynamic
4
4
  description = "DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks"
5
5
  license = "MIT"
6
6
  readme = "README.md"
@@ -44,8 +44,10 @@ packages = [
44
44
  python = ">=3.9,<3.13"
45
45
  defusedxml = {version = ">=0.7.1"}
46
46
  fast_hdbscan = {version = "0.2.0"} # 0.2.1 hits a bug in condense_tree comparing float to none
47
+ lightgbm = {version = ">=4"}
47
48
  numba = {version = ">=0.59.1"}
48
49
  numpy = {version = ">=1.24.2"}
50
+ pandas = {version = ">=2.0"}
49
51
  pillow = {version = ">=10.3.0"}
50
52
  requests = {version = "*"}
51
53
  scipy = {version = ">=1.10"}
@@ -58,10 +60,9 @@ xxhash = {version = ">=3.3"}
58
60
 
59
61
  # optional
60
62
  matplotlib = {version = ">=3.7.1", optional = true}
61
- pandas = {version = ">=2.0", optional = true}
62
63
 
63
64
  [tool.poetry.extras]
64
- all = ["matplotlib", "pandas"]
65
+ all = ["matplotlib"]
65
66
 
66
67
  [tool.poetry.group.dev]
67
68
  optional = true
@@ -132,6 +133,7 @@ markers = [
132
133
  "required: marks tests for required features",
133
134
  "optional: marks tests for optional features",
134
135
  "requires_all: marks tests that require the all extras",
136
+ "cuda: marks tests that require cuda",
135
137
  ]
136
138
 
137
139
  [tool.coverage.run]
@@ -8,7 +8,7 @@ shifts that impact performance of deployed models.
8
8
  from __future__ import annotations
9
9
 
10
10
  __all__ = ["config", "detectors", "log", "metrics", "typing", "utils", "workflows"]
11
- __version__ = "0.85.0"
11
+ __version__ = "0.86.0"
12
12
 
13
13
  import logging
14
14
 
@@ -191,6 +191,11 @@ class Metadata:
191
191
  self._process()
192
192
  return self._image_indices
193
193
 
194
+ @property
195
+ def image_count(self) -> int:
196
+ self._process()
197
+ return int(self._image_indices.max() + 1)
198
+
194
199
  def _collate(self, force: bool = False):
195
200
  if self._collated and not force:
196
201
  return
@@ -359,12 +364,19 @@ class Metadata:
359
364
 
360
365
  def add_factors(self, factors: Mapping[str, ArrayLike]) -> None:
361
366
  self._merge()
362
- self._processed = False
363
- target_len = len(self.targets.source) if self.targets.source is not None else len(self.targets)
364
- if any(len(v if isinstance(v, Sized) else as_numpy(v)) != target_len for v in factors.values()):
367
+
368
+ targets = len(self.targets.source) if self.targets.source is not None else len(self.targets)
369
+ images = self.image_count
370
+ lengths = {k: len(v if isinstance(v, Sized) else np.atleast_1d(as_numpy(v))) for k, v in factors.items()}
371
+ targets_match = all(f == targets for f in lengths.values())
372
+ images_match = targets_match if images == targets else all(f == images for f in lengths.values())
373
+ if not targets_match and not images_match:
365
374
  raise ValueError(
366
375
  "The lists/arrays in the provided factors have a different length than the current metadata factors."
367
376
  )
368
- merged = cast(tuple[dict[str, ArrayLike], dict[str, list[str]]], self._merged)[0]
377
+ merged = cast(dict[str, ArrayLike], self._merged[0] if self._merged is not None else {})
369
378
  for k, v in factors.items():
370
- merged[k] = v
379
+ v = as_numpy(v)
380
+ merged[k] = v if (self.targets.source is None or lengths[k] == targets) else v[self.targets.source]
381
+
382
+ self._processed = False
@@ -120,7 +120,7 @@ class Select(AnnotatedDataset[_TDatum]):
120
120
 
121
121
  def _apply_subselection(self, datum: _TDatum, index: int) -> _TDatum:
122
122
  for subselection, indices in self._subselections:
123
- datum = subselection(datum) if index in indices else datum
123
+ datum = subselection(datum) if self._selection[index] in indices else datum
124
124
  return datum
125
125
 
126
126
  def __getitem__(self, index: int) -> _TDatum:
@@ -10,7 +10,6 @@ from numpy.typing import NDArray
10
10
  from dataeval.data._selection import Select, Selection, SelectionStage, Subselection
11
11
  from dataeval.typing import Array, ObjectDetectionDatum, ObjectDetectionTarget, SegmentationDatum, SegmentationTarget
12
12
  from dataeval.utils._array import as_numpy
13
- from dataeval.utils.data.metadata import flatten
14
13
 
15
14
 
16
15
  class ClassFilter(Selection[Any]):
@@ -96,13 +95,15 @@ class ClassFilterSubSelection(Subselection[Any]):
96
95
  def __init__(self, classes: Sequence[int]) -> None:
97
96
  self.classes = classes
98
97
 
98
+ def _filter(self, d: dict[str, Any], mask: NDArray[np.bool_]) -> dict[str, Any]:
99
+ return {k: self._filter(v, mask) if isinstance(v, dict) else _try_mask_object(v, mask) for k, v in d.items()}
100
+
99
101
  def __call__(self, datum: _TDatum) -> _TDatum:
100
102
  # build a mask for any arrays
101
103
  image, target, metadata = datum
102
104
 
103
105
  mask = np.isin(as_numpy(target.labels), self.classes)
104
- flattened_metadata = flatten(metadata)[0]
105
- filtered_metadata = {k: _try_mask_object(v, mask) for k, v in flattened_metadata.items()}
106
+ filtered_metadata = self._filter(metadata, mask)
106
107
 
107
108
  # return a masked datum
108
109
  filtered_datum = image, ClassFilterTarget(target, mask), filtered_metadata
@@ -7,6 +7,8 @@ __all__ = [
7
7
  "DriftKS",
8
8
  "DriftMMD",
9
9
  "DriftMMDOutput",
10
+ "DriftMVDC",
11
+ "DriftMVDCOutput",
10
12
  "DriftOutput",
11
13
  "DriftUncertainty",
12
14
  "UpdateStrategy",
@@ -18,5 +20,6 @@ from dataeval.detectors.drift._base import UpdateStrategy
18
20
  from dataeval.detectors.drift._cvm import DriftCVM
19
21
  from dataeval.detectors.drift._ks import DriftKS
20
22
  from dataeval.detectors.drift._mmd import DriftMMD
23
+ from dataeval.detectors.drift._mvdc import DriftMVDC
21
24
  from dataeval.detectors.drift._uncertainty import DriftUncertainty
22
- from dataeval.outputs._drift import DriftMMDOutput, DriftOutput
25
+ from dataeval.outputs._drift import DriftMMDOutput, DriftMVDCOutput, DriftOutput
@@ -0,0 +1,92 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ from numpy.typing import ArrayLike
8
+
9
+ if TYPE_CHECKING:
10
+ from typing import Self
11
+ else:
12
+ from typing_extensions import Self
13
+
14
+ from dataeval.detectors.drift._nml._chunk import CountBasedChunker, SizeBasedChunker
15
+ from dataeval.detectors.drift._nml._domainclassifier import DomainClassifierCalculator
16
+ from dataeval.detectors.drift._nml._thresholds import ConstantThreshold
17
+ from dataeval.outputs._drift import DriftMVDCOutput
18
+ from dataeval.utils._array import flatten
19
+
20
+
21
class DriftMVDC:
    """Multivariate Domain Classifier.

    Parameters
    ----------
    n_folds : int, default 5
        Number of cross-validation (CV) folds.
    chunk_size : int or None, default None
        Number of samples in a chunk used in CV, will get one metric & prediction per chunk.
    chunk_count : int or None, default None
        Number of total chunks used in CV, will get one metric & prediction per chunk.
    threshold : tuple[float, float], default (0.45, 0.65)
        (lower, upper) metric bounds on roc_auc for identifying :term:`drift<Drift>`.
    """

    def __init__(
        self,
        n_folds: int = 5,
        chunk_size: int | None = None,
        chunk_count: int | None = None,
        threshold: tuple[float, float] = (0.45, 0.65),
    ) -> None:
        # Normalize the threshold into (lower, upper) order and clamp to [0, 1].
        self.threshold: tuple[float, float] = max(0.0, min(threshold)), min(1.0, max(threshold))
        # chunk_size takes precedence; otherwise fall back to a count-based
        # chunker (defaulting to 10 chunks when neither is specified).
        chunker = (
            CountBasedChunker(10 if chunk_count is None else chunk_count)
            if chunk_size is None
            else SizeBasedChunker(chunk_size)
        )
        self._calc = DomainClassifierCalculator(
            cv_folds_num=n_folds,
            chunker=chunker,
            threshold=ConstantThreshold(lower=self.threshold[0], upper=self.threshold[1]),
        )

    def fit(self, x_ref: ArrayLike) -> Self:
        """
        Fit the domain classifier on the training dataframe.

        Parameters
        ----------
        x_ref : ArrayLike
            Reference data with dim[n_samples, n_features].

        Returns
        -------
        Self
        """
        # for 1D input, assume that is 1 sample: dim[1, n_features]
        self.x_ref: pd.DataFrame = pd.DataFrame(flatten(np.atleast_2d(np.asarray(x_ref))))
        self.n_features: int = self.x_ref.shape[-1]
        self._calc.fit(self.x_ref)
        return self

    def predict(self, x: ArrayLike) -> DriftMVDCOutput:
        """
        Perform :term:`inference<Inference>` on the test dataframe.

        Parameters
        ----------
        x : ArrayLike
            Test (analysis) data with dim[n_samples, n_features].

        Returns
        -------
        DriftMVDCOutput

        Raises
        ------
        ValueError
            If the test data has a different number of features than the reference data.
        """
        self.x_test: pd.DataFrame = pd.DataFrame(flatten(np.atleast_2d(np.asarray(x))))
        if self.x_test.shape[-1] != self.n_features:
            raise ValueError("Reference and test embeddings have different number of features")

        return self._calc.calculate(self.x_test)
@@ -0,0 +1,6 @@
1
+ """
2
+ Source code derived from NannyML 0.13.0
3
+ https://github.com/NannyML/nannyml/
4
+
5
+ Licensed under Apache Software License (Apache 2.0)
6
+ """
@@ -0,0 +1,68 @@
1
+ """
2
+ Source code derived from NannyML 0.13.0
3
+ https://github.com/NannyML/nannyml/blob/main/nannyml/base.py
4
+
5
+ Licensed under Apache Software License (Apache 2.0)
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import logging
11
+ from abc import ABC, abstractmethod
12
+ from logging import Logger
13
+ from typing import Sequence
14
+
15
+ import pandas as pd
16
+ from typing_extensions import Self
17
+
18
+ from dataeval.detectors.drift._nml._chunk import Chunk, Chunker, CountBasedChunker
19
+ from dataeval.outputs._drift import DriftMVDCOutput
20
+
21
+
22
+ def _validate(data: pd.DataFrame, expected_features: int | None = None) -> int:
23
+ if data.empty:
24
+ raise ValueError("data contains no rows. Please provide a valid data set.")
25
+ if expected_features is not None and data.shape[-1] != expected_features:
26
+ raise ValueError(f"expected '{expected_features}' features in data set:\n\t{data}")
27
+ return data.shape[-1]
28
+
29
+
30
+ def _create_multilevel_index(chunks: Sequence[Chunk], result_group_name: str, result_column_names: Sequence[str]):
31
+ chunk_column_names = (*chunks[0].KEYS, "period")
32
+ chunk_tuples = [("chunk", chunk_column_name) for chunk_column_name in chunk_column_names]
33
+ result_tuples = [(result_group_name, column_name) for column_name in result_column_names]
34
+ return pd.MultiIndex.from_tuples(chunk_tuples + result_tuples)
35
+
36
+
37
class AbstractCalculator(ABC):
    """Base class for drift calculation."""

    def __init__(self, chunker: Chunker | None = None, logger: Logger | None = None):
        # Fall back to a 10-chunk count-based chunker when none is provided.
        self.chunker = chunker if isinstance(chunker, Chunker) else CountBasedChunker(10)
        self.result: DriftMVDCOutput | None = None
        self.n_features: int | None = None
        self._logger = logger if isinstance(logger, Logger) else logging.getLogger(__name__)

    def fit(self, reference_data: pd.DataFrame) -> Self:
        """Trains the calculator using reference data."""
        self.n_features = _validate(reference_data)

        # Lazy %-style args avoid formatting the message when DEBUG is disabled.
        self._logger.debug("fitting %s", self)
        self.result = self._fit(reference_data)
        return self

    def calculate(self, data: pd.DataFrame) -> DriftMVDCOutput:
        """Performs a calculation on the provided data.

        Raises
        ------
        RuntimeError
            If called before ``fit``.
        """
        if self.result is None:
            raise RuntimeError("must run fit with reference data before running calculate")
        _validate(data, self.n_features)

        self._logger.debug("calculating %s", self)
        self.result = self._calculate(data)
        return self.result

    @abstractmethod
    def _fit(self, reference_data: pd.DataFrame) -> DriftMVDCOutput: ...

    @abstractmethod
    def _calculate(self, data: pd.DataFrame) -> DriftMVDCOutput: ...