dataeval 0.86.0__py3-none-any.whl → 0.86.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. dataeval/__init__.py +1 -1
  2. dataeval/_log.py +1 -1
  3. dataeval/config.py +21 -4
  4. dataeval/data/_embeddings.py +2 -2
  5. dataeval/data/_images.py +2 -3
  6. dataeval/data/_metadata.py +188 -178
  7. dataeval/data/_selection.py +1 -2
  8. dataeval/data/_split.py +4 -5
  9. dataeval/data/_targets.py +17 -13
  10. dataeval/data/selections/_classfilter.py +2 -5
  11. dataeval/data/selections/_prioritize.py +6 -9
  12. dataeval/data/selections/_shuffle.py +3 -1
  13. dataeval/detectors/drift/_base.py +4 -5
  14. dataeval/detectors/drift/_mmd.py +3 -6
  15. dataeval/detectors/drift/_nml/_base.py +4 -2
  16. dataeval/detectors/drift/_nml/_chunk.py +11 -19
  17. dataeval/detectors/drift/_nml/_domainclassifier.py +8 -19
  18. dataeval/detectors/drift/_nml/_result.py +8 -9
  19. dataeval/detectors/drift/_nml/_thresholds.py +66 -77
  20. dataeval/detectors/linters/outliers.py +7 -7
  21. dataeval/metadata/_distance.py +10 -7
  22. dataeval/metadata/_ood.py +11 -103
  23. dataeval/metrics/bias/_balance.py +23 -33
  24. dataeval/metrics/bias/_diversity.py +16 -14
  25. dataeval/metrics/bias/_parity.py +18 -18
  26. dataeval/metrics/estimators/_divergence.py +2 -4
  27. dataeval/metrics/stats/_base.py +103 -42
  28. dataeval/metrics/stats/_boxratiostats.py +21 -19
  29. dataeval/metrics/stats/_dimensionstats.py +14 -10
  30. dataeval/metrics/stats/_hashstats.py +1 -1
  31. dataeval/metrics/stats/_pixelstats.py +6 -6
  32. dataeval/metrics/stats/_visualstats.py +3 -3
  33. dataeval/outputs/_base.py +22 -7
  34. dataeval/outputs/_bias.py +24 -70
  35. dataeval/outputs/_drift.py +1 -9
  36. dataeval/outputs/_linters.py +11 -11
  37. dataeval/outputs/_stats.py +82 -23
  38. dataeval/outputs/_workflows.py +2 -2
  39. dataeval/utils/_array.py +6 -9
  40. dataeval/utils/_bin.py +1 -2
  41. dataeval/utils/_clusterer.py +7 -4
  42. dataeval/utils/_fast_mst.py +27 -13
  43. dataeval/utils/_image.py +65 -11
  44. dataeval/utils/_mst.py +1 -3
  45. dataeval/utils/_plot.py +15 -10
  46. dataeval/utils/data/_dataset.py +54 -28
  47. dataeval/utils/data/metadata.py +104 -82
  48. dataeval/utils/datasets/__init__.py +2 -0
  49. dataeval/utils/datasets/_antiuav.py +189 -0
  50. dataeval/utils/datasets/_base.py +11 -8
  51. dataeval/utils/datasets/_cifar10.py +104 -45
  52. dataeval/utils/datasets/_fileio.py +21 -47
  53. dataeval/utils/datasets/_milco.py +22 -12
  54. dataeval/utils/datasets/_mixin.py +2 -4
  55. dataeval/utils/datasets/_mnist.py +3 -4
  56. dataeval/utils/datasets/_ships.py +14 -7
  57. dataeval/utils/datasets/_voc.py +229 -42
  58. dataeval/utils/torch/models.py +5 -10
  59. dataeval/utils/torch/trainer.py +3 -3
  60. dataeval/workflows/sufficiency.py +2 -2
  61. {dataeval-0.86.0.dist-info → dataeval-0.86.2.dist-info}/METADATA +2 -1
  62. dataeval-0.86.2.dist-info/RECORD +114 -0
  63. dataeval/detectors/ood/vae.py +0 -74
  64. dataeval-0.86.0.dist-info/RECORD +0 -114
  65. {dataeval-0.86.0.dist-info → dataeval-0.86.2.dist-info}/LICENSE.txt +0 -0
  66. {dataeval-0.86.0.dist-info → dataeval-0.86.2.dist-info}/WHEEL +0 -0
dataeval/data/_targets.py CHANGED
@@ -24,11 +24,13 @@ class Targets:
    labels : NDArray[np.intp]
        Labels (N,) for N images or objects
    scores : NDArray[np.float32]
-        Probability scores (N,M) for N images of M classes or confidence score (N,) of objects
+        Probability scores (N, M) for N images of M classes or confidence score (N,) of objects
    bboxes : NDArray[np.float32] | None
-        Bounding boxes (N,4) for N objects in (x0,y0,x1,y1) format
+        Bounding boxes (N, 4) for N objects in (x0, y0, x1, y1) format
    source : NDArray[np.intp] | None
        Source image index (N,) for N objects
+    size : int
+        Count of objects
    """

    labels: NDArray[np.intp]
@@ -55,13 +57,16 @@ class Targets:
        )

        if self.bboxes is not None and len(self.bboxes) > 0 and self.bboxes.shape[-1] != 4:
-            raise ValueError("Bounding boxes must be in (x0,y0,x1,y1) format.")
+            raise ValueError("Bounding boxes must be in (x0, y0, x1, y1) format.")
+
+    @property
+    def size(self) -> int:
+        return len(self.labels)

    def __len__(self) -> int:
        if self.source is None:
            return len(self.labels)
-        else:
-            return len(np.unique(self.source))
+        return len(np.unique(self.source))

    def __getitem__(self, idx: int, /) -> Targets:
        if self.source is None or self.bboxes is None:
@@ -71,14 +76,13 @@ class Targets:
                None,
                None,
            )
-        else:
-            mask = np.where(self.source == idx, True, False)
-            return Targets(
-                np.atleast_1d(self.labels[mask]),
-                np.atleast_1d(self.scores[mask]),
-                np.atleast_2d(self.bboxes[mask]),
-                np.atleast_1d(self.source[mask]),
-            )
+        mask = np.where(self.source == idx, True, False)
+        return Targets(
+            np.atleast_1d(self.labels[mask]),
+            np.atleast_1d(self.scores[mask]),
+            np.atleast_2d(self.bboxes[mask]),
+            np.atleast_1d(self.source[mask]),
+        )

    def __iter__(self) -> Iterator[Targets]:
        for i in range(len(self.labels)) if self.source is None else np.unique(self.source):
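
For context, the new size property counts objects while __len__ keeps counting images when per-object source indices are present. A minimal sketch, not part of the diff; the import path and array values are assumptions:

import numpy as np
from dataeval.data import Targets  # assumed public import path

# hypothetical detections: 3 boxes spread across 2 source images
targets = Targets(
    labels=np.array([0, 1, 1], dtype=np.intp),
    scores=np.array([0.9, 0.8, 0.7], dtype=np.float32),
    bboxes=np.array([[0, 0, 5, 5], [1, 1, 4, 4], [2, 2, 6, 6]], dtype=np.float32),
    source=np.array([0, 0, 1], dtype=np.intp),
)

print(targets.size)  # 3 -> number of objects (len of labels)
print(len(targets))  # 2 -> number of unique source images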

dataeval/data/selections/_classfilter.py CHANGED
@@ -68,11 +68,8 @@ _TTarget = TypeVar("_TTarget", ObjectDetectionTarget, SegmentationTarget)


def _try_mask_object(obj: _T, mask: NDArray[np.bool_]) -> _T:
-     if isinstance(obj, Sized) and not isinstance(obj, (str, bytes, bytearray)) and len(obj) == len(mask):
-         if isinstance(obj, Array):
-             return obj[mask]
-         elif isinstance(obj, Sequence):
-             return cast(_T, [item for i, item in enumerate(obj) if mask[i]])
+     if not isinstance(obj, (str, bytes, bytearray)) and isinstance(obj, (Sequence, Array)) and len(obj) == len(mask):
+         return obj[mask] if isinstance(obj, Array) else cast(_T, [item for i, item in enumerate(obj) if mask[i]])
    return obj


dataeval/data/selections/_prioritize.py CHANGED
@@ -99,8 +99,7 @@ class _KNNSorter(_Sorter):
            np.fill_diagonal(dists, np.inf)
        else:
            dists = pairwise_distances(embeddings, reference)
-         inds = np.argsort(np.sort(dists, axis=1)[:, self._k])
-         return inds
+         return np.argsort(np.sort(dists, axis=1)[:, self._k])


class _KMeansSorter(_Sorter):
@@ -124,15 +123,13 @@ class _KMeansSorter(_Sorter):
class _KMeansDistanceSorter(_KMeansSorter):
    def _sort(self, embeddings: NDArray[Any], reference: NDArray[Any] | None = None) -> NDArray[np.intp]:
        clst = self._get_clusters(embeddings if reference is None else reference)
-         inds = np.argsort(clst._dist2center(embeddings))
-         return inds
+         return np.argsort(clst._dist2center(embeddings))


class _KMeansComplexitySorter(_KMeansSorter):
    def _sort(self, embeddings: NDArray[Any], reference: NDArray[Any] | None = None) -> NDArray[np.intp]:
        clst = self._get_clusters(embeddings if reference is None else reference)
-         inds = clst._sort_by_weights(embeddings)
-         return inds
+         return clst._sort_by_weights(embeddings)


class Prioritize(Selection[Any]):
@@ -266,10 +263,10 @@ class Prioritize(Selection[Any]):
    def _get_sorter(self, samples: int) -> _Sorter:
        if self._method == "knn":
            return _KNNSorter(samples, self._k)
-         elif self._method == "kmeans_distance":
+         if self._method == "kmeans_distance":
            return _KMeansDistanceSorter(samples, self._c)
-         else:  # self._method == "kmeans_complexity"
-             return _KMeansComplexitySorter(samples, self._c)
+         # self._method == "kmeans_complexity"
+         return _KMeansComplexitySorter(samples, self._c)

    def _to_normalized_ndarray(self, embeddings: Embeddings, selection: list[int] | None = None) -> NDArray[Any]:
        emb: NDArray[Any] = embeddings.to_numpy(selection)
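
As a reading aid for the simplified _KNNSorter._sort above: each sample is scored by the distance to its k-th closest other sample, and the returned order sorts those scores ascending. A standalone sketch with synthetic data; sklearn's pairwise_distances stands in for however the module computes distances:

import numpy as np
from sklearn.metrics import pairwise_distances

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(6, 4))  # 6 samples, 4-dim embeddings
k = 2

dists = pairwise_distances(embeddings)  # (6, 6) pairwise distance matrix
np.fill_diagonal(dists, np.inf)         # exclude self-distance, as in the no-reference branch
kth = np.sort(dists, axis=1)[:, k]      # column k of each row's sorted distances
print(np.argsort(kth))                  # prioritization order by ascending k-NN distance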

dataeval/data/selections/_shuffle.py CHANGED
@@ -30,7 +30,9 @@ class Shuffle(Selection[Any]):
    seed: int | NDArray[Any] | SeedSequence | BitGenerator | Generator | None
    stage = SelectionStage.ORDER

-     def __init__(self, seed: int | Sequence[int] | Array | SeedSequence | BitGenerator | Generator | None = None):
+     def __init__(
+         self, seed: int | Sequence[int] | Array | SeedSequence | BitGenerator | Generator | None = None
+     ) -> None:
        self.seed = as_numpy(seed) if isinstance(seed, (Sequence, Array)) else seed

    def __call__(self, dataset: Select[Any]) -> None:

dataeval/detectors/drift/_base.py CHANGED
@@ -13,7 +13,7 @@ __all__ = []
import math
from abc import abstractmethod
from functools import wraps
- from typing import Callable, Literal, Protocol, TypeVar, runtime_checkable
+ from typing import Any, Callable, Literal, Protocol, TypeVar, runtime_checkable

import numpy as np
from numpy.typing import NDArray
@@ -40,7 +40,7 @@ def update_strategy(fn: Callable[..., R]) -> Callable[..., R]:
    """Decorator to update x_ref with x using selected update methodology"""

    @wraps(fn)
-     def _(self: BaseDrift, data: Embeddings | Array, *args, **kwargs) -> R:
+     def _(self: BaseDrift, data: Embeddings | Array, *args: tuple[Any, ...], **kwargs: dict[str, Any]) -> R:
        output = fn(self, data, *args, **kwargs)

        # update reference dataset
@@ -184,7 +184,7 @@ class BaseDriftUnivariate(BaseDrift):
            threshold = self.p_val / self.n_features
            drift_pred = bool((p_vals < threshold).any())
            return drift_pred, threshold
-         elif self.correction == "fdr":
+         if self.correction == "fdr":
            n = p_vals.shape[0]
            i = np.arange(n) + np.int_(1)
            p_sorted = np.sort(p_vals)
@@ -195,8 +195,7 @@ class BaseDriftUnivariate(BaseDrift):
            except ValueError:  # sorted p-values not below thresholds
                return bool(below_threshold.any()), q_threshold.min()
            return bool(below_threshold.any()), q_threshold[idx_threshold]
-         else:
-             raise ValueError("`correction` needs to be either `bonferroni` or `fdr`.")
+         raise ValueError("`correction` needs to be either `bonferroni` or `fdr`.")

    @set_metadata
    @update_strategy
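
The retained "fdr" branch above is a Benjamini-Hochberg style step-up check: sorted p-values are compared against alpha * i / n. A small sketch of the same arithmetic on made-up p-values, where alpha plays the role of self.p_val:

import numpy as np

p_vals = np.array([0.001, 0.008, 0.039, 0.041, 0.20])  # made-up per-feature p-values
alpha = 0.05

n = p_vals.shape[0]
i = np.arange(n) + 1
p_sorted = np.sort(p_vals)
q_threshold = alpha * i / n               # [0.01, 0.02, 0.03, 0.04, 0.05]
below_threshold = p_sorted < q_threshold  # [ True  True False False False]
print(bool(below_threshold.any()))        # True -> drift flagged when any sorted p-value clears its threshold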

dataeval/detectors/drift/_mmd.py CHANGED
@@ -95,8 +95,7 @@ class DriftMMD(BaseDrift):
        k_xy = self._kernel(x, y)
        k_xx = self._k_xx if self._k_xx is not None and self.update_strategy is None else self._kernel(x, x)
        k_yy = self._kernel(y, y)
-         kernel_mat = torch.cat([torch.cat([k_xx, k_xy], 1), torch.cat([k_xy.T, k_yy], 1)], 0)
-         return kernel_mat
+         return torch.cat([torch.cat([k_xx, k_xy], 1), torch.cat([k_xy.T, k_yy], 1)], 0)

    def score(self, data: Embeddings | Array) -> tuple[float, float, float]:
        """
@@ -205,8 +204,7 @@ def sigma_median(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.
    n = min(x.shape[0], y.shape[0])
    n = n if (x[:n] == y[:n]).all() and x.shape == y.shape else 0
    n_median = n + (torch.prod(torch.as_tensor(dist.shape)) - n) // 2 - 1
-     sigma = (0.5 * dist.flatten().sort().values[int(n_median)].unsqueeze(dim=-1)) ** 0.5
-     return sigma
+     return (0.5 * dist.flatten().sort().values[int(n_median)].unsqueeze(dim=-1)) ** 0.5


class GaussianRBF(torch.nn.Module):
@@ -310,5 +308,4 @@ def mmd2_from_kernel_matrix(
    kernel_mat = kernel_mat[idx][:, idx]
    k_xx, k_yy, k_xy = kernel_mat[:-m, :-m], kernel_mat[-m:, -m:], kernel_mat[-m:, :-m]
    c_xx, c_yy = 1 / (n * (n - 1)), 1 / (m * (m - 1))
-     mmd2 = c_xx * k_xx.sum() + c_yy * k_yy.sum() - 2.0 * k_xy.mean()
-     return mmd2
+     return c_xx * k_xx.sum() + c_yy * k_yy.sum() - 2.0 * k_xy.mean()
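
The quantity returned by mmd2_from_kernel_matrix above is the usual unbiased squared-MMD estimate once the within-sample kernel diagonals have been removed. A self-contained torch sketch of the same combination; the rbf helper here is illustrative, not the module's GaussianRBF:

import torch

def rbf(a: torch.Tensor, b: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    # k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 * sigma^2))
    return torch.exp(-torch.cdist(a, b) ** 2 / (2 * sigma**2))

torch.manual_seed(0)
x, y = torch.randn(8, 3), torch.randn(6, 3)  # reference and test samples
n, m = x.shape[0], y.shape[0]

k_xx, k_yy, k_xy = rbf(x, x), rbf(y, y), rbf(x, y)
k_xx = k_xx - torch.diag(k_xx.diag())  # drop k(x_i, x_i) terms for the unbiased within-sample sums
k_yy = k_yy - torch.diag(k_yy.diag())

mmd2 = k_xx.sum() / (n * (n - 1)) + k_yy.sum() / (m * (m - 1)) - 2.0 * k_xy.mean()
print(float(mmd2))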

dataeval/detectors/drift/_nml/_base.py CHANGED
@@ -27,7 +27,9 @@ def _validate(data: pd.DataFrame, expected_features: int | None = None) -> int:
    return data.shape[-1]


- def _create_multilevel_index(chunks: Sequence[Chunk], result_group_name: str, result_column_names: Sequence[str]):
+ def _create_multilevel_index(
+     chunks: Sequence[Chunk], result_group_name: str, result_column_names: Sequence[str]
+ ) -> pd.MultiIndex:
    chunk_column_names = (*chunks[0].KEYS, "period")
    chunk_tuples = [("chunk", chunk_column_name) for chunk_column_name in chunk_column_names]
    result_tuples = [(result_group_name, column_name) for column_name in result_column_names]
@@ -37,7 +39,7 @@ def _create_multilevel_index(chunks: Sequence[Chunk], result_group_name: str, re
class AbstractCalculator(ABC):
    """Base class for drift calculation."""

-     def __init__(self, chunker: Chunker | None = None, logger: Logger | None = None):
+     def __init__(self, chunker: Chunker | None = None, logger: Logger | None = None) -> None:
        self.chunker = chunker if isinstance(chunker, Chunker) else CountBasedChunker(10)
        self.result: DriftMVDCOutput | None = None
        self.n_features: int | None = None

dataeval/detectors/drift/_nml/_chunk.py CHANGED
@@ -16,7 +16,6 @@ from abc import ABC, abstractmethod
from typing import Any, Generic, Literal, Sequence, TypeVar, cast

import pandas as pd
- from dateutil.parser import ParserError
from pandas import Index, Period
from typing_extensions import Self

@@ -31,7 +30,7 @@ class Chunk(ABC):
    def __init__(
        self,
        data: pd.DataFrame,
-     ):
+     ) -> None:
        self.key: str
        self.data = data

@@ -39,11 +38,11 @@ class Chunk(ABC):
        self.end_index: int = -1
        self.chunk_index: int = -1

-     def __repr__(self):
+     def __repr__(self) -> str:
        attr_str = ", ".join([f"{k}={v}" for k, v in self.dict().items()])
        return f"{self.__class__.__name__}(data=pd.DataFrame(shape={self.data.shape}), {attr_str})"

-     def __len__(self):
+     def __len__(self) -> int:
        return self.data.shape[0]

    @abstractmethod
@@ -76,7 +75,7 @@ class IndexChunk(Chunk):
        data: pd.DataFrame,
        start_index: int,
        end_index: int,
-     ):
+     ) -> None:
        super().__init__(data)
        self.key = f"[{start_index}:{end_index}]"
        self.start_index: int = start_index
@@ -113,7 +112,7 @@ class PeriodChunk(Chunk):

    KEYS = ("key", "chunk_index", "start_date", "end_date", "chunk_size")

-     def __init__(self, data: pd.DataFrame, period: Period, chunk_size: int):
+     def __init__(self, data: pd.DataFrame, period: Period, chunk_size: int) -> None:
        super().__init__(data)
        self.key = str(period)
        self.start_datetime = period.start_time
@@ -127,6 +126,7 @@ class PeriodChunk(Chunk):
        a, b = (self, other) if self < other else (other, self)
        result = copy.deepcopy(a)
        result.data = pd.concat([a.data, b.data])
+         result.end_index = b.end_index
        result.end_datetime = b.end_datetime
        result.chunk_size += b.chunk_size
        return result
@@ -237,13 +237,7 @@ class PeriodBasedChunker(Chunker[PeriodChunk]):
        if self.timestamp_column_name not in data:
            raise ValueError(f"timestamp column '{self.timestamp_column_name}' not in columns")

-         try:
-             grouped = data.groupby(pd.to_datetime(data[self.timestamp_column_name]).dt.to_period(self.offset))
-         except ParserError:
-             raise ValueError(
-                 f"could not parse date_column '{self.timestamp_column_name}' values as dates."
-                 f"Please verify if you've specified the correct date column."
-             )
+         grouped = data.groupby(pd.to_datetime(data[self.timestamp_column_name]).dt.to_period(self.offset))

        for k, v in grouped.groups.items():
            period, index = cast(Period, k), cast(Index, v)
@@ -281,7 +275,7 @@ class SizeBasedChunker(Chunker[IndexChunk]):
        self,
        chunk_size: int,
        incomplete: Literal["append", "drop", "keep"] = "keep",
-     ):
+     ) -> None:
        """Create a new SizeBasedChunker.

        Parameters
@@ -314,12 +308,11 @@ class SizeBasedChunker(Chunker[IndexChunk]):
    def _split(self, data: pd.DataFrame) -> list[IndexChunk]:
        def _create_chunk(index: int, data: pd.DataFrame, chunk_size: int) -> IndexChunk:
            chunk_data = data.iloc[index : index + chunk_size]
-             chunk = IndexChunk(
+             return IndexChunk(
                data=chunk_data,
                start_index=index,
                end_index=index + chunk_size - 1,
            )
-             return chunk

        chunks = [
            _create_chunk(index=i, data=data, chunk_size=self.chunk_size)
@@ -364,7 +357,7 @@ class CountBasedChunker(Chunker[IndexChunk]):
        self,
        chunk_number: int,
        incomplete: Literal["append", "drop", "keep"] = "keep",
-     ):
+     ) -> None:
        """Creates a new CountBasedChunker.

        It will calculate the amount of observations per chunk based on the given chunk count.
@@ -400,5 +393,4 @@ class CountBasedChunker(Chunker[IndexChunk]):
    def _split(self, data: pd.DataFrame) -> list[IndexChunk]:
        chunk_size = data.shape[0] // self.chunk_number
        chunker = SizeBasedChunker(chunk_size, self.incomplete)
-         chunks = chunker.split(data=data)
-         return chunks
+         return chunker.split(data=data)
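
A brief usage sketch for the chunkers touched above; the module is internal (dataeval.detectors.drift._nml._chunk), so treat the import as an assumption rather than public API:

import pandas as pd
from dataeval.detectors.drift._nml._chunk import SizeBasedChunker

df = pd.DataFrame({"value": range(10)})
chunker = SizeBasedChunker(chunk_size=4, incomplete="keep")
for chunk in chunker.split(data=df):
    print(chunk.key, len(chunk))  # keys like "[0:3]", "[4:7]", plus a trailing short chunk under "keep"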

dataeval/detectors/drift/_nml/_domainclassifier.py CHANGED
@@ -20,7 +20,7 @@ from sklearn.model_selection import StratifiedKFold
from dataeval.config import get_max_processes, get_seed
from dataeval.detectors.drift._nml._base import AbstractCalculator, _create_multilevel_index
from dataeval.detectors.drift._nml._chunk import Chunk, Chunker
- from dataeval.detectors.drift._nml._thresholds import ConstantThreshold, Threshold, calculate_threshold_values
+ from dataeval.detectors.drift._nml._thresholds import ConstantThreshold, Threshold
from dataeval.outputs._base import set_metadata
from dataeval.outputs._drift import DriftMVDCOutput

@@ -38,10 +38,8 @@ DEFAULT_LGBM_HYPERPARAMS = {
    "min_child_weight": 0.001,
    "min_split_gain": 0.0,
    "n_estimators": 100,
-     "n_jobs": get_max_processes() or 0,
    "num_leaves": 31,
    "objective": None,
-     "random_state": get_seed(),
    "reg_alpha": 0.0,
    "reg_lambda": 0.0,
    "subsample": 1.0,
@@ -126,7 +124,7 @@ class DomainClassifierCalculator(AbstractCalculator):
        self.result._data = pd.concat([self.result._data, res], ignore_index=True)
        return self.result

-     def _calculate_chunk(self, chunk: Chunk):
+     def _calculate_chunk(self, chunk: Chunk) -> float:
        if self.result is None:
            # Use information from chunk indices to identify reference chunk's location. This is possible because
            # both the internal reference data copy and the chunk data were sorted by timestamp, so these
@@ -151,7 +149,7 @@ class DomainClassifierCalculator(AbstractCalculator):
            _try = y[train_index]
            _tsx = df_X.iloc[test_index]
            _tsy = y[test_index]
-             model = LGBMClassifier(**self.hyperparameters)
+             model = LGBMClassifier(**self.hyperparameters, n_jobs=get_max_processes(), random_state=get_seed())
            model.fit(_trx, _try)
            preds = np.asarray(model.predict_proba(_tsx), dtype=np.float32)[:, 1]
            all_preds.append(preds)
@@ -159,24 +157,15 @@ class DomainClassifierCalculator(AbstractCalculator):

        np_all_preds = np.concatenate(all_preds, axis=0)
        np_all_tgts = np.concatenate(all_tgts, axis=0)
-         try:
-             # catch case where all rows are duplicates
-             result = roc_auc_score(np_all_tgts, np_all_preds)
-         except ValueError as err:
-             if str(err) != "Only one class present in y_true. ROC AUC score is not defined in that case.":
-                 raise
-             else:
-                 # by definition if reference and chunk exactly match we can't discriminate
-                 result = 0.5
-         return result
+         result = roc_auc_score(np_all_tgts, np_all_preds)
+         return 0.5 if result == np.nan else float(result)

    def _populate_alert_thresholds(self, result_data: pd.DataFrame) -> pd.DataFrame:
        if self.result is None:
-             self._threshold_values = calculate_threshold_values(
-                 threshold=self.threshold,
+             self._threshold_values = self.threshold.calculate(
                data=result_data.loc[:, ("domain_classifier_auroc", "value")],  # type: ignore | dataframe loc
-                 lower_threshold_value_limit=0.0,
-                 upper_threshold_value_limit=1.0,
+                 lower_limit=0.0,
+                 upper_limit=1.0,
                logger=self._logger,
            )
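
The calculator above frames drift detection as a classification problem: reference rows are labeled 0, analysis rows 1, and the cross-validated AUROC of a classifier separating them is the drift statistic (about 0.5 means indistinguishable, values near 1.0 suggest drift). A simplified sketch of that idea, with LogisticRegression standing in for the LGBMClassifier the module actually uses:

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold

rng = np.random.default_rng(0)
reference = rng.normal(0.0, 1.0, size=(200, 4))
analysis = rng.normal(0.5, 1.0, size=(200, 4))  # shifted mean, so drift should be visible

X = np.vstack([reference, analysis])
y = np.concatenate([np.zeros(len(reference)), np.ones(len(analysis))])

preds, tgts = [], []
for train_idx, test_idx in StratifiedKFold(n_splits=5, shuffle=True, random_state=0).split(X, y):
    model = LogisticRegression(max_iter=1000).fit(X[train_idx], y[train_idx])
    preds.append(model.predict_proba(X[test_idx])[:, 1])
    tgts.append(y[test_idx])

print(roc_auc_score(np.concatenate(tgts), np.concatenate(preds)))  # well above 0.5 here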

dataeval/detectors/drift/_nml/_result.py CHANGED
@@ -42,14 +42,13 @@ class AbstractResult(GenericOutput[pd.DataFrame]):
        """Export results to pandas dataframe."""
        if multilevel:
            return self._data
-         else:
-             column_names = [
-                 "_".join(col).replace("chunk_chunk_chunk", "chunk").replace("chunk_chunk", "chunk")
-                 for col in self._data.columns.values
-             ]
-             single_level_data = self._data.copy(deep=True)
-             single_level_data.columns = column_names
-             return single_level_data
+         column_names = [
+             "_".join(col).replace("chunk_chunk_chunk", "chunk").replace("chunk_chunk", "chunk")
+             for col in self._data.columns.values
+         ]
+         single_level_data = self._data.copy(deep=True)
+         single_level_data.columns = column_names
+         return single_level_data

    def filter(self, period: str = "all", metrics: str | Sequence[str] | None = None) -> Self:
        """Returns filtered result metric data."""
@@ -67,7 +66,7 @@ class Abstract1DResult(AbstractResult, ABC):
    def __init__(self, results_data: pd.DataFrame) -> None:
        super().__init__(results_data)

-     def _filter(self, period: str, metrics=None) -> Self:
+     def _filter(self, period: str, metrics: Sequence[str] | None = None) -> Self:
        data = self._data
        if period != "all":
            data = self._data.loc[self._data.loc[:, ("chunk", "period")] == period, :]  # type: ignore | dataframe loc

dataeval/detectors/drift/_nml/_thresholds.py CHANGED
@@ -29,10 +29,10 @@ class Threshold(ABC):
    """Class registry lookup to get threshold subclass from threshold_type string"""

    def __str__(self) -> str:
-         return self.__str__()
+         return f"{self.__class__.__name__}({str(vars(self))})"

    def __repr__(self) -> str:
-         return self.__class__.__name__ + str(vars(self))
+         return str(self)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, self.__class__) and other.__dict__ == self.__dict__
@@ -41,7 +41,7 @@ class Threshold(ABC):
        Threshold._registry[threshold_type] = cls

    @abstractmethod
-     def thresholds(self, data: np.ndarray) -> tuple[float | None, float | None]:
+     def _thresholds(self, data: np.ndarray) -> tuple[float | None, float | None]:
        """Returns lower and upper threshold values when given one or more np.ndarray instances.

        Parameters:
@@ -69,6 +69,61 @@ class Threshold(ABC):

        return threshold_cls(**obj)

+     def calculate(
+         self,
+         data: np.ndarray,
+         lower_limit: float | None = None,
+         upper_limit: float | None = None,
+         override_using_none: bool = False,
+         logger: logging.Logger | None = None,
+     ) -> tuple[float | None, float | None]:
+         """
+         Calculate lower and upper threshold values with respect to the provided Threshold and value limits.
+
+         Parameters
+         ----------
+         data : np.ndarray
+             The data used by the Threshold instance to calculate the lower and upper threshold values.
+             This will often be the values of a drift detection method or performance metric on chunks of reference
+             data.
+         lower_limit : float or None, default None
+             An optional value that serves as a limit for the lower threshold value. Any calculated lower threshold
+             values that end up below this limit will be replaced by this limit value.
+             The limit is often a theoretical constraint enforced by a specific drift detection method or performance
+             metric.
+         upper_threshold_value_limit : float or None, default None
+             An optional value that serves as a limit for the lower threshold value. Any calculated lower threshold
+             values that end up below this limit will be replaced by this limit value.
+             The limit is often a theoretical constraint enforced by a specific drift detection method or performance
+             metric.
+         override_using_none: bool, default False
+             When set to True use None to override threshold values that exceed value limits.
+             This will prevent them from being rendered on plots.
+         logger: Optional[logging.Logger], default=None
+             An optional Logger instance. When provided a warning will be logged when a calculated threshold value
+             gets overridden by a threshold value limit.
+         """
+
+         lower_value, upper_value = self._thresholds(data)
+
+         if lower_limit is not None and lower_value is not None and lower_value <= lower_limit:
+             override_value = None if override_using_none else lower_limit
+             if logger:
+                 logger.warning(
+                     f"lower threshold value {lower_value} overridden by lower threshold value limit {override_value}"
+                 )
+             lower_value = override_value
+
+         if upper_limit is not None and upper_value is not None and upper_value >= upper_limit:
+             override_value = None if override_using_none else upper_limit
+             if logger:
+                 logger.warning(
+                     f"upper threshold value {upper_value} overridden by upper threshold value limit {override_value}"
+                 )
+             upper_value = override_value
+
+         return lower_value, upper_value
+

class ConstantThreshold(Threshold, threshold_type="constant"):
    """A `Thresholder` implementation that returns a constant lower and or upper threshold value.
@@ -91,7 +146,7 @@ class ConstantThreshold(Threshold, threshold_type="constant"):
        None 0.1
    """

-     def __init__(self, lower: float | int | None = None, upper: float | int | None = None):
+     def __init__(self, lower: float | int | None = None, upper: float | int | None = None) -> None:
        """Creates a new ConstantThreshold instance.

        Args:
@@ -109,11 +164,11 @@ class ConstantThreshold(Threshold, threshold_type="constant"):
        self.lower = lower
        self.upper = upper

-     def thresholds(self, data: np.ndarray) -> tuple[float | None, float | None]:
+     def _thresholds(self, data: np.ndarray) -> tuple[float | None, float | None]:
        return self.lower, self.upper

    @staticmethod
-     def _validate_inputs(lower: float | int | None = None, upper: float | int | None = None):
+     def _validate_inputs(lower: float | int | None = None, upper: float | int | None = None) -> None:
        if lower is not None and not isinstance(lower, (float, int)) or isinstance(lower, bool):
            raise ValueError(f"expected type of 'lower' to be 'float', 'int' or None but got '{type(lower).__name__}'")

@@ -149,7 +204,7 @@ class StandardDeviationThreshold(Threshold, threshold_type="standard_deviation")
        std_lower_multiplier: float | int | None = 3,
        std_upper_multiplier: float | int | None = 3,
        offset_from: Callable[[np.ndarray], Any] = np.nanmean,
-     ):
+     ) -> None:
        """Creates a new StandardDeviationThreshold instance.

        Args:
@@ -173,7 +228,7 @@ class StandardDeviationThreshold(Threshold, threshold_type="standard_deviation")
        self.std_upper_multiplier = std_upper_multiplier
        self.offset_from = offset_from

-     def thresholds(self, data: np.ndarray) -> tuple[float | None, float | None]:
+     def _thresholds(self, data: np.ndarray) -> tuple[float | None, float | None]:
        aggregate = self.offset_from(data)
        std = np.nanstd(data)

@@ -184,7 +239,9 @@ class StandardDeviationThreshold(Threshold, threshold_type="standard_deviation")
        return lower_threshold, upper_threshold

    @staticmethod
-     def _validate_inputs(std_lower_multiplier: float | int | None = 3, std_upper_multiplier: float | int | None = 3):
+     def _validate_inputs(
+         std_lower_multiplier: float | int | None = 3, std_upper_multiplier: float | int | None = 3
+     ) -> None:
        if (
            std_lower_multiplier is not None
            and not isinstance(std_lower_multiplier, (float, int))
@@ -210,71 +267,3 @@ class StandardDeviationThreshold(Threshold, threshold_type="standard_deviation")

        if std_upper_multiplier and std_upper_multiplier < 0:
            raise ValueError(f"'std_upper_multiplier' should be greater than 0 but got value {std_upper_multiplier}")
-
-
- def calculate_threshold_values(
-     threshold: Threshold,
-     data: np.ndarray,
-     lower_threshold_value_limit: float | None = None,
-     upper_threshold_value_limit: float | None = None,
-     override_using_none: bool = False,
-     logger: logging.Logger | None = None,
-     metric_name: str | None = None,
- ) -> tuple[float | None, float | None]:
-     """Calculate lower and upper threshold values with respect to the provided Threshold and value limits.
-
-     Parameters:
-         threshold: Threshold
-             The Threshold instance that determines how the lower and upper threshold values will be calculated.
-         data: np.ndarray
-             The data used by the Threshold instance to calculate the lower and upper threshold values.
-             This will often be the values of a drift detection method or performance metric on chunks of reference data.
-         lower_threshold_value_limit: Optional[float], default=None
-             An optional value that serves as a limit for the lower threshold value. Any calculated lower threshold
-             values that end up below this limit will be replaced by this limit value.
-             The limit is often a theoretical constraint enforced by a specific drift detection method or performance
-             metric.
-         upper_threshold_value_limit: Optional[float], default=None
-             An optional value that serves as a limit for the lower threshold value. Any calculated lower threshold
-             values that end up below this limit will be replaced by this limit value.
-             The limit is often a theoretical constraint enforced by a specific drift detection method or performance
-             metric.
-         override_using_none: bool, default=False
-             When set to True use None to override threshold values that exceed value limits.
-             This will prevent them from being rendered on plots.
-         logger: Optional[logging.Logger], default=None
-             An optional Logger instance. When provided a warning will be logged when a calculated threshold value
-             gets overridden by a threshold value limit.
-         metric_name: Optional[str], default=None
-             When provided the metric name will be included within any log messages for additional clarity.
-     """
-
-     lower_threshold_value, upper_threshold_value = threshold.thresholds(data)
-
-     if (
-         lower_threshold_value_limit is not None
-         and lower_threshold_value is not None
-         and lower_threshold_value <= lower_threshold_value_limit
-     ):
-         override_value = None if override_using_none else lower_threshold_value_limit
-         if logger:
-             logger.warning(
-                 f"{metric_name + ' ' if metric_name else ''}lower threshold value {lower_threshold_value} "
-                 f"overridden by lower threshold value limit {override_value}"
-             )
-         lower_threshold_value = override_value
-
-     if (
-         upper_threshold_value_limit is not None
-         and upper_threshold_value is not None
-         and upper_threshold_value >= upper_threshold_value_limit
-     ):
-         override_value = None if override_using_none else upper_threshold_value_limit
-         if logger:
-             logger.warning(
-                 f"{metric_name + ' ' if metric_name else ''}upper threshold value {upper_threshold_value} "
-                 f"overridden by upper threshold value limit {override_value}"
-             )
-         upper_threshold_value = override_value
-
-     return lower_threshold_value, upper_threshold_value
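
To close out the thresholds refactor, a usage sketch of the new Threshold.calculate method that replaces the removed module-level helper; the module is internal, so the import path is an assumption:

import numpy as np
from dataeval.detectors.drift._nml._thresholds import ConstantThreshold, StandardDeviationThreshold

reference_values = np.array([0.48, 0.52, 0.50, 0.47, 0.55])  # e.g. per-chunk AUROC on reference data

# fixed bounds, clamped to the metric's theoretical [0, 1] range
lower, upper = ConstantThreshold(lower=0.45, upper=0.65).calculate(
    reference_values, lower_limit=0.0, upper_limit=1.0
)

# mean +/- 3 standard deviations, same clamping behavior
lower_sd, upper_sd = StandardDeviationThreshold().calculate(
    reference_values, lower_limit=0.0, upper_limit=1.0
)
print((lower, upper), (lower_sd, upper_sd))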