dataeval 0.73.1__py3-none-any.whl → 0.74.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. dataeval/__init__.py +3 -9
  2. dataeval/detectors/__init__.py +2 -10
  3. dataeval/detectors/drift/base.py +3 -3
  4. dataeval/detectors/drift/mmd.py +1 -1
  5. dataeval/detectors/drift/torch.py +1 -101
  6. dataeval/detectors/linters/clusterer.py +3 -3
  7. dataeval/detectors/linters/duplicates.py +4 -4
  8. dataeval/detectors/linters/outliers.py +4 -4
  9. dataeval/detectors/ood/__init__.py +9 -9
  10. dataeval/detectors/ood/{ae.py → ae_torch.py} +22 -27
  11. dataeval/detectors/ood/base.py +63 -113
  12. dataeval/detectors/ood/base_torch.py +109 -0
  13. dataeval/detectors/ood/metadata_ks_compare.py +52 -14
  14. dataeval/interop.py +1 -1
  15. dataeval/metrics/bias/__init__.py +3 -0
  16. dataeval/metrics/bias/balance.py +73 -70
  17. dataeval/metrics/bias/coverage.py +4 -4
  18. dataeval/metrics/bias/diversity.py +67 -136
  19. dataeval/metrics/bias/metadata_preprocessing.py +285 -0
  20. dataeval/metrics/bias/metadata_utils.py +229 -0
  21. dataeval/metrics/bias/parity.py +51 -161
  22. dataeval/metrics/estimators/ber.py +3 -3
  23. dataeval/metrics/estimators/divergence.py +3 -3
  24. dataeval/metrics/estimators/uap.py +3 -3
  25. dataeval/metrics/stats/base.py +2 -2
  26. dataeval/metrics/stats/boxratiostats.py +1 -1
  27. dataeval/metrics/stats/datasetstats.py +6 -6
  28. dataeval/metrics/stats/dimensionstats.py +1 -1
  29. dataeval/metrics/stats/hashstats.py +1 -1
  30. dataeval/metrics/stats/labelstats.py +3 -3
  31. dataeval/metrics/stats/pixelstats.py +1 -1
  32. dataeval/metrics/stats/visualstats.py +1 -1
  33. dataeval/output.py +77 -53
  34. dataeval/utils/__init__.py +1 -7
  35. dataeval/utils/gmm.py +26 -0
  36. dataeval/utils/metadata.py +29 -9
  37. dataeval/utils/torch/gmm.py +98 -0
  38. dataeval/utils/torch/models.py +192 -0
  39. dataeval/utils/torch/trainer.py +84 -5
  40. dataeval/utils/torch/utils.py +107 -1
  41. dataeval/workflows/sufficiency.py +4 -4
  42. {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/METADATA +3 -9
  43. dataeval-0.74.1.dist-info/RECORD +65 -0
  44. dataeval/detectors/ood/aegmm.py +0 -66
  45. dataeval/detectors/ood/llr.py +0 -302
  46. dataeval/detectors/ood/vae.py +0 -97
  47. dataeval/detectors/ood/vaegmm.py +0 -75
  48. dataeval/metrics/bias/metadata.py +0 -440
  49. dataeval/utils/lazy.py +0 -26
  50. dataeval/utils/tensorflow/__init__.py +0 -19
  51. dataeval/utils/tensorflow/_internal/gmm.py +0 -123
  52. dataeval/utils/tensorflow/_internal/loss.py +0 -121
  53. dataeval/utils/tensorflow/_internal/models.py +0 -1394
  54. dataeval/utils/tensorflow/_internal/trainer.py +0 -114
  55. dataeval/utils/tensorflow/_internal/utils.py +0 -256
  56. dataeval/utils/tensorflow/loss/__init__.py +0 -11
  57. dataeval-0.73.1.dist-info/RECORD +0 -73
  58. {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/LICENSE.txt +0 -0
  59. {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/WHEEL +0 -0

dataeval/utils/torch/trainer.py
@@ -1,15 +1,15 @@
  from __future__ import annotations

- __all__ = ["AETrainer"]
-
- from typing import Any
+ from typing import Any, Callable

  import torch
  import torch.nn as nn
+ from numpy.typing import NDArray
  from torch.optim import Adam
- from torch.utils.data import DataLoader, Dataset
+ from torch.utils.data import DataLoader, Dataset, TensorDataset
+ from tqdm import tqdm

- torch.manual_seed(0)
+ __all__ = ["AETrainer", "trainer"]


  def get_images_from_batch(batch: Any) -> Any:
@@ -176,3 +176,82 @@ class AETrainer:
                  encodings = torch.vstack((encodings, embeddings)) if len(encodings) else embeddings

          return encodings
+
+
+ def trainer(
+     model: torch.nn.Module,
+     x_train: NDArray[Any],
+     y_train: NDArray[Any] | None,
+     loss_fn: Callable[..., torch.Tensor | torch.nn.Module] | None,
+     optimizer: torch.optim.Optimizer | None,
+     preprocess_fn: Callable[[torch.Tensor], torch.Tensor] | None,
+     epochs: int,
+     batch_size: int,
+     device: torch.device,
+     verbose: bool,
+ ) -> None:
+     """
+     Train Pytorch model.
+
+     Parameters
+     ----------
+     model
+         Model to train.
+     loss_fn
+         Loss function used for training.
+     x_train
+         Training data.
+     y_train
+         Training labels.
+     optimizer
+         Optimizer used for training.
+     preprocess_fn
+         Preprocessing function applied to each training batch.
+     epochs
+         Number of training epochs.
+     reg_loss_fn
+         Allows an additional regularisation term to be defined as reg_loss_fn(model)
+     batch_size
+         Batch size used for training.
+     buffer_size
+         Maximum number of elements that will be buffered when prefetching.
+     verbose
+         Whether to print training progress.
+     """
+     if optimizer is None:
+         optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+
+     if y_train is None:
+         dataset = TensorDataset(torch.from_numpy(x_train).to(torch.float32))
+
+     else:
+         dataset = TensorDataset(
+             torch.from_numpy(x_train).to(torch.float32), torch.from_numpy(y_train).to(torch.float32)
+         )
+
+     loader = DataLoader(dataset=dataset)
+
+     model = model.to(device)
+
+     # iterate over epochs
+     loss = torch.nan
+     disable_tqdm = not verbose
+     for epoch in (pbar := tqdm(range(epochs), disable=disable_tqdm)):
+         epoch_loss = loss
+         for step, data in enumerate(loader):
+             if step % 250 == 0:
+                 pbar.set_description(f"Epoch: {epoch} ({epoch_loss:.3f}), loss: {loss:.3f}")
+
+             x, y = [d.to(device) for d in data] if len(data) > 1 else (data[0].to(device), None)
+
+             if isinstance(preprocess_fn, Callable):
+                 x = preprocess_fn(x)
+
+             y_hat = model(x)
+             y = x if y is None else y
+
+             loss = loss_fn(y, y_hat)  # type: ignore
+
+             optimizer.zero_grad()
+             loss.backward()
+             optimizer.step()
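
The new trainer() helper appears to take over the role of the removed TensorFlow training loop for the torch-based OOD detectors. A minimal usage sketch follows; the toy autoencoder and random data are illustrative only and are not part of the package:

import numpy as np
import torch

from dataeval.utils.torch.trainer import trainer

# Toy reconstruction model and data, for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.ReLU(), torch.nn.Linear(8, 16))
x_train = np.random.rand(128, 16).astype(np.float32)

trainer(
    model=model,
    x_train=x_train,
    y_train=None,                # None trains the model to reconstruct x_train
    loss_fn=torch.nn.MSELoss(),  # called as loss_fn(y, y_hat)
    optimizer=None,              # None falls back to Adam(lr=0.001)
    preprocess_fn=None,
    epochs=5,
    batch_size=32,
    device=torch.device("cpu"),
    verbose=False,
)

Note that every argument is required (there are no defaults), and that in the hunk above the DataLoader is constructed without a batch_size, so iteration proceeds one sample at a time regardless of the batch_size argument.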

dataeval/utils/torch/utils.py
@@ -3,8 +3,12 @@ from __future__ import annotations
  __all__ = ["read_dataset"]

  from collections import defaultdict
- from typing import Any
+ from functools import partial
+ from typing import Any, Callable

+ import numpy as np
+ import torch
+ from numpy.typing import NDArray
  from torch.utils.data import Dataset


@@ -61,3 +65,105 @@ def read_dataset(dataset: Dataset[Any]) -> list[list[Any]]:
              ddict[i].append(d)

      return list(ddict.values())
+
+
+ def get_device(device: str | torch.device | None = None) -> torch.device:
+     """
+     Instantiates a PyTorch device object.
+
+     Parameters
+     ----------
+     device : str | torch.device | None, default None
+         Either ``None``, a str ('gpu' or 'cpu') indicating the device to choose, or an
+         already instantiated device object. If ``None``, the GPU is selected if it is
+         detected, otherwise the CPU is used as a fallback.
+
+     Returns
+     -------
+     The instantiated device object.
+     """
+     if isinstance(device, torch.device):  # Already a torch device
+         return device
+     else:  # Instantiate device
+         if device is None or device.lower() in ["gpu", "cuda"]:
+             torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         else:
+             torch_device = torch.device("cpu")
+     return torch_device
+
+
+ def predict_batch(
+     x: NDArray[Any] | torch.Tensor,
+     model: Callable | torch.nn.Module | torch.nn.Sequential,
+     device: torch.device | None = None,
+     batch_size: int = int(1e10),
+     preprocess_fn: Callable | None = None,
+     dtype: type[np.generic] | torch.dtype = np.float32,
+ ) -> NDArray[Any] | torch.Tensor | tuple[Any, ...]:
+     """
+     Make batch predictions on a model.
+
+     Parameters
+     ----------
+     x : np.ndarray | torch.Tensor
+         Batch of instances.
+     model : Callable | nn.Module | nn.Sequential
+         PyTorch model.
+     device : torch.device | None, default None
+         Device type used. The default None tries to use the GPU and falls back on CPU.
+         Can be specified by passing either torch.device('cuda') or torch.device('cpu').
+     batch_size : int, default 1e10
+         Batch size used during prediction.
+     preprocess_fn : Callable | None, default None
+         Optional preprocessing function for each batch.
+     dtype : np.dtype | torch.dtype, default np.float32
+         Model output type, either a :term:`NumPy` or torch dtype, e.g. np.float32 or torch.float32.
+
+     Returns
+     -------
+     NDArray | torch.Tensor | tuple
+         Numpy array, torch tensor or tuples of those with model outputs.
+     """
+     device = get_device(device)
+     if isinstance(x, np.ndarray):
+         x = torch.from_numpy(x).to(device)
+     n = len(x)
+     n_minibatch = int(np.ceil(n / batch_size))
+     return_np = not isinstance(dtype, torch.dtype)
+     preds = []
+     with torch.no_grad():
+         for i in range(n_minibatch):
+             istart, istop = i * batch_size, min((i + 1) * batch_size, n)
+             x_batch = x[istart:istop]
+             if isinstance(preprocess_fn, Callable):
+                 x_batch = preprocess_fn(x_batch)
+
+             preds_tmp = model(x_batch.to(torch.float32).to(device))
+             if isinstance(preds_tmp, (list, tuple)):
+                 if len(preds) == 0:  # init tuple with lists to store predictions
+                     preds = tuple([] for _ in range(len(preds_tmp)))
+                 for j, p in enumerate(preds_tmp):
+                     if isinstance(p, torch.Tensor):
+                         p = p.cpu()
+                     preds[j].append(p if not return_np or isinstance(p, np.ndarray) else p.numpy())
+             elif isinstance(preds_tmp, (np.ndarray, torch.Tensor)):
+                 if isinstance(preds_tmp, torch.Tensor):
+                     preds_tmp = preds_tmp.cpu()
+                 if isinstance(preds, tuple):
+                     preds = list(preds)
+                 preds.append(
+                     preds_tmp
+                     if not return_np or isinstance(preds_tmp, np.ndarray)  # type: ignore
+                     else preds_tmp.numpy()
+                 )
+             else:
+                 raise TypeError(
+                     f"Model output type {type(preds_tmp)} not supported. The model \
+                     output type needs to be one of list, tuple, NDArray or \
+                     torch.Tensor."
+                 )
+     concat = partial(np.concatenate, axis=0) if return_np else partial(torch.cat, dim=0)
+     out: tuple | np.ndarray | torch.Tensor = (
+         tuple(concat(p) for p in preds) if isinstance(preds, tuple) else concat(preds)  # type: ignore
+     )
+     return out
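
These helpers mirror the removed TensorFlow utilities. A minimal calling sketch, with a toy model and random inputs that are illustrative only:

import numpy as np
import torch

from dataeval.utils.torch.utils import get_device, predict_batch

device = get_device(None)        # picks cuda when available, otherwise cpu
model = torch.nn.Linear(16, 4).to(device)
x = np.random.rand(1000, 16).astype(np.float32)

# dtype defaults to np.float32, so the result comes back as a NumPy array;
# pass dtype=torch.float32 to keep torch tensors instead.
preds = predict_batch(x, model, device=device, batch_size=256)
print(preds.shape)               # (1000, 4)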

dataeval/workflows/sufficiency.py
@@ -16,11 +16,11 @@ from scipy.optimize import basinhopping
  from torch.utils.data import Dataset

  from dataeval.interop import as_numpy
- from dataeval.output import OutputMetadata, set_metadata
+ from dataeval.output import Output, set_metadata


  @dataclass(frozen=True)
- class SufficiencyOutput(OutputMetadata):
+ class SufficiencyOutput(Output):
      """
      Output class for :class:`Sufficiency` workflow

@@ -47,7 +47,7 @@ class SufficiencyOutput(OutputMetadata):
              if c != c_v:
                  raise ValueError(f"{m} does not contain the expected number ({c}) of data points.")

-     @set_metadata()
+     @set_metadata
      def project(
          self,
          projection: int | Iterable[int],
@@ -484,7 +484,7 @@ class Sufficiency(Generic[T]):
      def eval_kwargs(self, value: Mapping[str, Any] | None) -> None:
          self._eval_kwargs = {} if value is None else value

-     @set_metadata(["runs", "substeps"])
+     @set_metadata(state=["runs", "substeps"])
      def evaluate(self, eval_at: int | Iterable[int] | None = None, niter: int = 1000) -> SufficiencyOutput:
          """
          Creates data indices, trains models, and returns plotting data
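
The call-site changes (@set_metadata() becoming bare @set_metadata, and the positional state list becoming state=[...]) imply that dataeval.output.set_metadata now supports both bare and keyword-argument use. The sketch below illustrates that decorator pattern only; it is not dataeval's actual implementation:

from __future__ import annotations

import functools
from typing import Any, Callable


def set_metadata(fn: Callable[..., Any] | None = None, *, state: list[str] | None = None) -> Callable[..., Any]:
    """Illustrative decorator usable as @set_metadata or @set_metadata(state=[...])."""

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def wrapper(self, *args: Any, **kwargs: Any) -> Any:
            # Capture the requested attributes of `self`, then run the method;
            # a real implementation would attach this to the returned Output.
            captured = {name: getattr(self, name) for name in (state or [])}
            output = func(self, *args, **kwargs)
            print(f"{func.__name__} ran with state {captured}")
            return output

        return wrapper

    # Bare use receives the function directly; parameterized use returns the decorator.
    return decorator(fn) if callable(fn) else decorator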

{dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: dataeval
- Version: 0.73.1
+ Version: 0.74.1
  Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
  Home-page: https://dataeval.ai/
  License: MIT
@@ -21,18 +21,12 @@ Classifier: Programming Language :: Python :: 3.12
  Classifier: Programming Language :: Python :: 3 :: Only
  Classifier: Topic :: Scientific/Engineering
  Provides-Extra: all
- Provides-Extra: tensorflow
  Provides-Extra: torch
- Requires-Dist: hdbscan (>=0.8.36)
- Requires-Dist: markupsafe (<3.0.2) ; extra == "tensorflow" or extra == "all"
- Requires-Dist: matplotlib ; extra == "torch" or extra == "all"
- Requires-Dist: numpy (>1.24.3)
+ Requires-Dist: matplotlib ; extra == "all"
+ Requires-Dist: numpy (>=1.24.3)
  Requires-Dist: pillow (>=10.3.0)
  Requires-Dist: scikit-learn (>=1.5.0)
  Requires-Dist: scipy (>=1.10)
- Requires-Dist: tensorflow (>=2.16,<2.18) ; extra == "tensorflow" or extra == "all"
- Requires-Dist: tensorflow_probability (>=0.24,<0.25) ; extra == "tensorflow" or extra == "all"
- Requires-Dist: tf-keras (>=2.16,<2.18) ; extra == "tensorflow" or extra == "all"
  Requires-Dist: torch (>=2.2.0) ; extra == "torch" or extra == "all"
  Requires-Dist: torchvision (>=0.17.0) ; extra == "torch" or extra == "all"
  Requires-Dist: tqdm

dataeval-0.74.1.dist-info/RECORD
@@ -0,0 +1,65 @@
+ dataeval/__init__.py,sha256=HNOjwnFIQCD7vwBBo0xMexlnNG3xRZ3s3VUMsA4Qozw,392
+ dataeval/detectors/__init__.py,sha256=Y-0bbyWyuMvZU80bCx6WPt3IV_r2hu9ymzpA8uzMqoI,206
+ dataeval/detectors/drift/__init__.py,sha256=BSXm21y7cAawHep-ZldCJ5HOvzYjPzYGKGrmoEs3i0E,737
+ dataeval/detectors/drift/base.py,sha256=QDGHMu1WADD-38MEIOwjQMEQM3DE7B0yFHO3hsMbV-E,14481
+ dataeval/detectors/drift/cvm.py,sha256=kc59w2_wtxFGNnLcaJRvX5v_38gPXiebSGNiFVdunEQ,4142
+ dataeval/detectors/drift/ks.py,sha256=gcpe1WIQeNeZdLYkdMZCFLXUp1bHMQUxwJE6-RLVOXs,4229
+ dataeval/detectors/drift/mmd.py,sha256=C0FX5v9ZJzmKNYEcYUaC7sDtMpJ2dZpwikNDu-AEWiI,7584
+ dataeval/detectors/drift/torch.py,sha256=igEQ2DV9JmcpTdUKCOHBi5LxtoNeCAslJS2Ldulg1hw,7585
+ dataeval/detectors/drift/uncertainty.py,sha256=Xz2yzJjtJfw1vLag234jwRvaa_HK36nMajGx8bQaNRs,5322
+ dataeval/detectors/drift/updates.py,sha256=UJ0z5hlunRi7twnkLABfdJG3tT2EqX4y9IGx8_USYvo,1780
+ dataeval/detectors/linters/__init__.py,sha256=BvpaB1RUpkEhhXk3Mqi5NYoOcJKZRFSBOJCmQOIfYRU,483
+ dataeval/detectors/linters/clusterer.py,sha256=hK-ak02GaxwWuufesZMKDsvoE5fMdXO7UWsLiK8hfY0,21008
+ dataeval/detectors/linters/duplicates.py,sha256=2bmPTFqoefeiAQV9y4CGlHV_mJNrysJSEFLXLd2DO4I,5661
+ dataeval/detectors/linters/merged_stats.py,sha256=X-bDTwjyR8RuVmzxLaHZmQ5nI3oOWvsqVlitdSncapk,1355
+ dataeval/detectors/linters/outliers.py,sha256=X48bzTfTr1LqC6WKVKBRfvpjcQRgmb93cNLT7Oipe3M,10113
+ dataeval/detectors/ood/__init__.py,sha256=-D4Fq-ysFylNNMqjHG1ALbB9qBCm_UinkCAgsK9HGg0,408
+ dataeval/detectors/ood/ae_torch.py,sha256=pO9w5221bXR9lEBkE7oakXeE7PXUUR--xcTpmHvOCSk,2142
+ dataeval/detectors/ood/base.py,sha256=UzcDbXl8Gv43VFzjrOegTnKSIoEYmfDP7fAySeWyWPw,6955
+ dataeval/detectors/ood/base_torch.py,sha256=yFbSfQsBMwZeVf8mrixmkZYBGChhV5oAHtkgzWnMzsA,3405
+ dataeval/detectors/ood/metadata_ks_compare.py,sha256=LNDNWGEDKTW8_-djgmK53sn9EZzzXq1Sgwc47k0QI-Y,5380
+ dataeval/detectors/ood/metadata_least_likely.py,sha256=nxMCXUOjOfWHDTGT2SLE7OYBCydRq8zHLd8t17k7hMM,5193
+ dataeval/detectors/ood/metadata_ood_mi.py,sha256=KLay2BmgHrStBV92VpIs_B1yEfQKllsMTgzOQEng01I,4065
+ dataeval/interop.py,sha256=SB5Nca12rluZeXrpmmlfY7LFJbN5opYM7jmAb2c29hM,1748
+ dataeval/metrics/__init__.py,sha256=fPBNLd-T6mCErZBBJrxWmXIL0jCk7fNUYIcNEBkMa80,238
+ dataeval/metrics/bias/__init__.py,sha256=dYiPHenS8J7pgRMMW2jNkTBmTbPoYTxT04fZu9PFats,747
+ dataeval/metrics/bias/balance.py,sha256=_TZEe17AT-qOvPp-QFrQfTqNwh8uVVCYjC4Sv6JBx9o,9118
+ dataeval/metrics/bias/coverage.py,sha256=o65_IgrWSlGnYeYZFABjwKaxq09uqyy5esHJM67PJ-k,4528
+ dataeval/metrics/bias/diversity.py,sha256=WL1NbZiRrv0SIq97FY3womZNCSl_EBMVlBWQZAUtjk8,7701
+ dataeval/metrics/bias/metadata_preprocessing.py,sha256=ekUFiirkmaHDiH7nJjkNpiUQD7OolAPhHorjLxpXv_Y,12248
+ dataeval/metrics/bias/metadata_utils.py,sha256=HmTjlRRTdM9566oKUDDdVMJ8luss4DYykFOiS2FQzhM,6558
+ dataeval/metrics/bias/parity.py,sha256=hnA7qQH4Uy3tl_krluZ9BPD5zYjjagUxZt2fEiIa2yE,12745
+ dataeval/metrics/estimators/__init__.py,sha256=O6ocxJq8XDkfJWwXeJnnnzbOyRnFPKF4kTIVTTZYOA8,380
+ dataeval/metrics/estimators/ber.py,sha256=fs3_e9pgu7I50QIALWtF2aidkBZhTCKVE2pA7PyB5Go,5019
+ dataeval/metrics/estimators/divergence.py,sha256=r_SKSurf1TdI5E1ivENqDnz8cQ3_sxVGKAqmF9cqcT4,4275
+ dataeval/metrics/estimators/uap.py,sha256=Aw5ReoWNK73Tq96r__qN_-cvHrELauqtDX3Af_QxX4s,2157
+ dataeval/metrics/stats/__init__.py,sha256=igLRaAt1nX6yRwC4xI0zNPBADi3u7EsSxWP3OZ8AqcU,1086
+ dataeval/metrics/stats/base.py,sha256=_C05KUAuDrfX3N-19o25V3vmXr0-45A5fc57cXyV8qs,12161
+ dataeval/metrics/stats/boxratiostats.py,sha256=bZunY-b8Y2IQqHlTusQN77ujLOHftogEQIARDpdVv6A,6463
+ dataeval/metrics/stats/datasetstats.py,sha256=rZUDiciHwEpnXmkI8-uJNiYwUuTL9ssZMKMx73hVX-Y,6219
+ dataeval/metrics/stats/dimensionstats.py,sha256=xITgQF_oomb6Ty_dJcbT3ARGGNp4QRcYSgnkjB4f-YE,4054
+ dataeval/metrics/stats/hashstats.py,sha256=vxw_K74EJM9CZy-EV617vdrysFO8nEspVWqIYsIHC-c,4958
+ dataeval/metrics/stats/labelstats.py,sha256=K0hJTphMe7htSjyss8GPtKDiHepTuU60_hX0xRA-uAg,4096
+ dataeval/metrics/stats/pixelstats.py,sha256=2zr9i3GLNx1i_SCtbfdtZNxXBEc_9wCe4qDpmXLVbKY,4576
+ dataeval/metrics/stats/visualstats.py,sha256=vLIC4sMo796axWl-4e4RzT33ll-_6ki54Dirn3V-EL8,4948
+ dataeval/output.py,sha256=SmzH9W9yewdL9SBKVBkUUvOo45oA5lHphE2DYvJJMu0,3573
+ dataeval/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ dataeval/utils/__init__.py,sha256=z7HxSijjycey-rGdQkgVOdpvT0oO2pKAuT4uYyxYGMs,555
+ dataeval/utils/gmm.py,sha256=YuLsJKsVWgH_wHr1u_hSRH5Yeexdj8exht8h99L7bLo,561
+ dataeval/utils/image.py,sha256=KgC_1nW__nGN5q6bVZNvG4U_qIBdjcPATz9qe8f2XuA,1928
+ dataeval/utils/metadata.py,sha256=0A--iru0zEmi044mKz5P35q69KrI30yoiRSlvs7TSdQ,9418
+ dataeval/utils/shared.py,sha256=xvF3VLfyheVwJtdtDrneOobkKf7t-JTmf_w91FWXmqo,3616
+ dataeval/utils/split_dataset.py,sha256=Ot1ZJhbIhVfcShYXF9MkWXak5odBXyuBdRh-noXh-MI,19555
+ dataeval/utils/torch/__init__.py,sha256=lpkqfgyARUxgrV94cZESQv8PIP2p-UnwItZ_wIr0XzQ,675
+ dataeval/utils/torch/blocks.py,sha256=HVhBTMMD5NA4qheMUgyol1KWiKZDIuc8k5j4RcMKmhk,1466
+ dataeval/utils/torch/datasets.py,sha256=10elNgLuH_FDX_CHE3y2Z215JN4-PQovQm5brcIJOeM,15021
+ dataeval/utils/torch/gmm.py,sha256=VbLlUQohwToApT493_tjQBWy2UM5R-3ppS9Dp-eP7BA,3240
+ dataeval/utils/torch/models.py,sha256=sdGeo7a8vshCTGA4lYyVxxb_aDWUlxdtIVxrddS-_ls,8542
+ dataeval/utils/torch/trainer.py,sha256=8BEXr6xtk-CHJTcNxOBnWgkFWfJUAiBy28cEdBhLMRU,7883
+ dataeval/utils/torch/utils.py,sha256=nWRcT6z6DbFVrL1RyxCOX3DPoCrv9G0B-VI_9LdGCQQ,5784
+ dataeval/workflows/__init__.py,sha256=ef1MiVL5IuhlDXXbwsiAfafhnr7tD3TXF9GRusy9_O8,290
+ dataeval/workflows/sufficiency.py,sha256=v9AV3BZT0NW-zD2VNIL_5aWspvoscrxRIUKcUdpy7HI,18540
+ dataeval-0.74.1.dist-info/LICENSE.txt,sha256=Kpzcfobf1HlqafF-EX6dQLw9TlJiaJzfgvLQFukyXYw,1060
+ dataeval-0.74.1.dist-info/METADATA,sha256=nd7os3kaLfp-A5HWH0QYVxe-gQdj5q3dIn9d0fPf-Lk,4298
+ dataeval-0.74.1.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
+ dataeval-0.74.1.dist-info/RECORD,,

dataeval/detectors/ood/aegmm.py
@@ -1,66 +0,0 @@
- """
- Source code derived from Alibi-Detect 0.11.4
- https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
-
- Original code Copyright (c) 2023 Seldon Technologies Ltd
- Licensed under Apache Software License (Apache 2.0)
- """
-
- from __future__ import annotations
-
- __all__ = ["OOD_AEGMM"]
-
- from typing import TYPE_CHECKING, Callable
-
- from numpy.typing import ArrayLike
-
- from dataeval.detectors.ood.base import OODGMMBase, OODScoreOutput
- from dataeval.interop import to_numpy
- from dataeval.utils.lazy import lazyload
- from dataeval.utils.tensorflow._internal.gmm import gmm_energy
- from dataeval.utils.tensorflow._internal.loss import LossGMM
- from dataeval.utils.tensorflow._internal.utils import predict_batch
-
- if TYPE_CHECKING:
-     import tensorflow as tf
-     import tf_keras as keras
-
-     import dataeval.utils.tensorflow._internal.models as tf_models
- else:
-     tf = lazyload("tensorflow")
-     keras = lazyload("tf_keras")
-     tf_models = lazyload("dataeval.utils.tensorflow._internal.models")
-
-
- class OOD_AEGMM(OODGMMBase):
-     """
-     AE with Gaussian Mixture Model based outlier detector.
-
-     Parameters
-     ----------
-     model : AEGMM
-         An AEGMM model.
-     """
-
-     def __init__(self, model: tf_models.AEGMM) -> None:
-         super().__init__(model)
-
-     def fit(
-         self,
-         x_ref: ArrayLike,
-         threshold_perc: float = 100.0,
-         loss_fn: Callable[..., tf.Tensor] | None = None,
-         optimizer: keras.optimizers.Optimizer | None = None,
-         epochs: int = 20,
-         batch_size: int = 64,
-         verbose: bool = True,
-     ) -> None:
-         if loss_fn is None:
-             loss_fn = LossGMM()
-         super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
-
-     def _score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
-         self._validate(X := to_numpy(X))
-         _, z, _ = predict_batch(X, self.model, batch_size=batch_size)
-         energy, _ = gmm_energy(z, self.gmm_params, return_mean=False)
-         return OODScoreOutput(energy.numpy())  # type: ignore
- return OODScoreOutput(energy.numpy()) # type: ignore