dataeval 0.73.1__py3-none-any.whl → 0.74.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +3 -9
- dataeval/detectors/__init__.py +2 -10
- dataeval/detectors/drift/base.py +3 -3
- dataeval/detectors/drift/mmd.py +1 -1
- dataeval/detectors/drift/torch.py +1 -101
- dataeval/detectors/linters/clusterer.py +3 -3
- dataeval/detectors/linters/duplicates.py +4 -4
- dataeval/detectors/linters/outliers.py +4 -4
- dataeval/detectors/ood/__init__.py +9 -9
- dataeval/detectors/ood/{ae.py → ae_torch.py} +22 -27
- dataeval/detectors/ood/base.py +63 -113
- dataeval/detectors/ood/base_torch.py +109 -0
- dataeval/detectors/ood/metadata_ks_compare.py +52 -14
- dataeval/interop.py +1 -1
- dataeval/metrics/bias/__init__.py +3 -0
- dataeval/metrics/bias/balance.py +73 -70
- dataeval/metrics/bias/coverage.py +4 -4
- dataeval/metrics/bias/diversity.py +67 -136
- dataeval/metrics/bias/metadata_preprocessing.py +285 -0
- dataeval/metrics/bias/metadata_utils.py +229 -0
- dataeval/metrics/bias/parity.py +51 -161
- dataeval/metrics/estimators/ber.py +3 -3
- dataeval/metrics/estimators/divergence.py +3 -3
- dataeval/metrics/estimators/uap.py +3 -3
- dataeval/metrics/stats/base.py +2 -2
- dataeval/metrics/stats/boxratiostats.py +1 -1
- dataeval/metrics/stats/datasetstats.py +6 -6
- dataeval/metrics/stats/dimensionstats.py +1 -1
- dataeval/metrics/stats/hashstats.py +1 -1
- dataeval/metrics/stats/labelstats.py +3 -3
- dataeval/metrics/stats/pixelstats.py +1 -1
- dataeval/metrics/stats/visualstats.py +1 -1
- dataeval/output.py +77 -53
- dataeval/utils/__init__.py +1 -7
- dataeval/utils/gmm.py +26 -0
- dataeval/utils/metadata.py +29 -9
- dataeval/utils/torch/gmm.py +98 -0
- dataeval/utils/torch/models.py +192 -0
- dataeval/utils/torch/trainer.py +84 -5
- dataeval/utils/torch/utils.py +107 -1
- dataeval/workflows/sufficiency.py +4 -4
- {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/METADATA +3 -9
- dataeval-0.74.1.dist-info/RECORD +65 -0
- dataeval/detectors/ood/aegmm.py +0 -66
- dataeval/detectors/ood/llr.py +0 -302
- dataeval/detectors/ood/vae.py +0 -97
- dataeval/detectors/ood/vaegmm.py +0 -75
- dataeval/metrics/bias/metadata.py +0 -440
- dataeval/utils/lazy.py +0 -26
- dataeval/utils/tensorflow/__init__.py +0 -19
- dataeval/utils/tensorflow/_internal/gmm.py +0 -123
- dataeval/utils/tensorflow/_internal/loss.py +0 -121
- dataeval/utils/tensorflow/_internal/models.py +0 -1394
- dataeval/utils/tensorflow/_internal/trainer.py +0 -114
- dataeval/utils/tensorflow/_internal/utils.py +0 -256
- dataeval/utils/tensorflow/loss/__init__.py +0 -11
- dataeval-0.73.1.dist-info/RECORD +0 -73
- {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/WHEEL +0 -0
dataeval/utils/torch/trainer.py
CHANGED
@@ -1,15 +1,15 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
|
4
|
-
|
5
|
-
from typing import Any
|
3
|
+
from typing import Any, Callable
|
6
4
|
|
7
5
|
import torch
|
8
6
|
import torch.nn as nn
|
7
|
+
from numpy.typing import NDArray
|
9
8
|
from torch.optim import Adam
|
10
|
-
from torch.utils.data import DataLoader, Dataset
|
9
|
+
from torch.utils.data import DataLoader, Dataset, TensorDataset
|
10
|
+
from tqdm import tqdm
|
11
11
|
|
12
|
-
|
12
|
+
__all__ = ["AETrainer", "trainer"]
|
13
13
|
|
14
14
|
|
15
15
|
def get_images_from_batch(batch: Any) -> Any:
|
@@ -176,3 +176,82 @@ class AETrainer:
|
|
176
176
|
encodings = torch.vstack((encodings, embeddings)) if len(encodings) else embeddings
|
177
177
|
|
178
178
|
return encodings
|
179
|
+
|
180
|
+
|
181
|
+
def trainer(
    model: torch.nn.Module,
    x_train: NDArray[Any],
    y_train: NDArray[Any] | None,
    loss_fn: Callable[..., torch.Tensor | torch.nn.Module] | None,
    optimizer: torch.optim.Optimizer | None,
    preprocess_fn: Callable[[torch.Tensor], torch.Tensor] | None,
    epochs: int,
    batch_size: int,
    device: torch.device,
    verbose: bool,
) -> None:
    """
    Train Pytorch model.

    Parameters
    ----------
    model
        Model to train. It is moved to ``device`` before training starts.
    x_train
        Training data. Converted to a float32 tensor.
    y_train
        Training labels. Converted to a float32 tensor. If ``None``, the
        (preprocessed) input batch itself is used as the target.
    loss_fn
        Loss function used for training, called as ``loss_fn(target, prediction)``.
    optimizer
        Optimizer used for training. If ``None``, ``Adam`` with a learning rate
        of 0.001 over ``model.parameters()`` is used.
    preprocess_fn
        Preprocessing function applied to each training batch. Ignored when it
        is not callable.
    epochs
        Number of training epochs.
    batch_size
        Batch size used for training.
    device
        Device the model and every batch are moved to.
    verbose
        Whether to print training progress.

    Raises
    ------
    TypeError
        If ``loss_fn`` is ``None``.
    """
    if loss_fn is None:
        # Fail up front with a clear message instead of "NoneType is not callable"
        # in the middle of the first training step.
        raise TypeError("loss_fn must be a callable taking (target, prediction); got None.")

    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    tensors = [torch.from_numpy(x_train).to(torch.float32)]
    if y_train is not None:
        tensors.append(torch.from_numpy(y_train).to(torch.float32))

    # batch_size must be forwarded explicitly: DataLoader otherwise falls back
    # to its own default and the caller's value is silently ignored.
    loader = DataLoader(dataset=TensorDataset(*tensors), batch_size=batch_size)

    model = model.to(device)

    apply_preprocess = callable(preprocess_fn)
    # The progress bar is only created when progress is actually reported.
    pbar = tqdm(range(epochs)) if verbose else None

    # iterate over epochs
    loss = torch.nan
    for epoch in range(epochs) if pbar is None else pbar:
        epoch_loss = loss  # loss as it stood when this epoch started
        for step, data in enumerate(loader):
            if pbar is not None and step % 250 == 0:
                pbar.set_description(f"Epoch: {epoch} ({epoch_loss:.3f}), loss: {loss:.3f}")

            x = data[0].to(device)
            y = data[1].to(device) if len(data) > 1 else None

            if apply_preprocess:
                x = preprocess_fn(x)  # type: ignore

            y_hat = model(x)
            if y is None:
                # No labels given: train against the (preprocessed) input.
                y = x

            batch_loss = loss_fn(y, y_hat)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            # Keep only the value for reporting; do not hold on to the graph.
            loss = batch_loss.detach()
|
dataeval/utils/torch/utils.py
CHANGED
@@ -3,8 +3,12 @@ from __future__ import annotations
|
|
3
3
|
__all__ = ["read_dataset"]
|
4
4
|
|
5
5
|
from collections import defaultdict
|
6
|
-
from
|
6
|
+
from functools import partial
|
7
|
+
from typing import Any, Callable
|
7
8
|
|
9
|
+
import numpy as np
|
10
|
+
import torch
|
11
|
+
from numpy.typing import NDArray
|
8
12
|
from torch.utils.data import Dataset
|
9
13
|
|
10
14
|
|
@@ -61,3 +65,105 @@ def read_dataset(dataset: Dataset[Any]) -> list[list[Any]]:
|
|
61
65
|
ddict[i].append(d)
|
62
66
|
|
63
67
|
return list(ddict.values())
|
68
|
+
|
69
|
+
|
70
|
+
def get_device(device: str | torch.device | None = None) -> torch.device:
    """
    Instantiates a PyTorch device object.

    Parameters
    ----------
    device : str | torch.device | None, default None
        Either ``None``, a str indicating the device to choose, or an already
        instantiated device object. Recognised strings (case-insensitive,
        surrounding whitespace ignored) are ``'gpu'``, ``'cuda'`` and
        ``'cuda:<index>'``; any other string selects the CPU. If ``None`` or a
        GPU string is given, the GPU is selected if it is detected, otherwise
        the CPU is used as a fallback.

    Returns
    -------
    The instantiated device object.
    """
    if isinstance(device, torch.device):  # Already a torch device
        return device

    cpu = torch.device("cpu")
    if device is None:
        return torch.device("cuda") if torch.cuda.is_available() else cpu

    name = device.strip().lower()
    # An explicit ordinal such as "cuda:1" is a GPU request as well; it used to
    # fall through to the CPU branch silently.
    indexed = name.startswith("cuda:") and name[len("cuda:"):].isdigit()
    if name in ("gpu", "cuda") or indexed:
        if not torch.cuda.is_available():
            return cpu
        return torch.device(name if indexed else "cuda")

    return cpu
|
93
|
+
|
94
|
+
|
95
|
+
def _as_output(value: Any, return_np: bool) -> Any:
    """Move a tensor to the CPU and, when NumPy output is requested, convert it to an array."""
    if isinstance(value, torch.Tensor):
        value = value.cpu()
    if return_np and not isinstance(value, np.ndarray):
        value = value.numpy()
    return value


def predict_batch(
    x: NDArray[Any] | torch.Tensor,
    model: Callable | torch.nn.Module | torch.nn.Sequential,
    device: torch.device | None = None,
    batch_size: int = int(1e10),
    preprocess_fn: Callable | None = None,
    dtype: type[np.generic] | torch.dtype = np.float32,
) -> NDArray[Any] | torch.Tensor | tuple[Any, ...]:
    """
    Make batch predictions on a model.

    Parameters
    ----------
    x : np.ndarray | torch.Tensor
        Batch of instances.
    model : Callable | nn.Module | nn.Sequential
        PyTorch model.
    device : torch.device | None, default None
        Device type used. The default None tries to use the GPU and falls back on CPU.
        Can be specified by passing either torch.device('cuda') or torch.device('cpu').
    batch_size : int, default 1e10
        Batch size used during prediction.
    preprocess_fn : Callable | None, default None
        Optional preprocessing function for each batch.
    dtype : np.dtype | torch.dtype, default np.float32
        Selects the output container: a torch dtype keeps the outputs as torch
        tensors, anything else converts them to :term:`NumPy` arrays.

    Returns
    -------
    NDArray | torch.Tensor | tuple
        Numpy array, torch tensor or tuples of those with model outputs,
        concatenated over all batches along the first axis.

    Raises
    ------
    TypeError
        If the model returns something other than a list, tuple, NDArray or
        torch.Tensor.
    """
    device = get_device(device)
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x).to(device)
    n = len(x)
    n_minibatch = int(np.ceil(n / batch_size))
    return_np = not isinstance(dtype, torch.dtype)
    apply_preprocess = callable(preprocess_fn)

    # Either a flat list (single-output model) or a tuple of lists, one list
    # per model output (multi-output model).
    preds: Any = []
    with torch.no_grad():
        for i in range(n_minibatch):
            istart, istop = i * batch_size, min((i + 1) * batch_size, n)
            x_batch = x[istart:istop]
            if apply_preprocess:
                x_batch = preprocess_fn(x_batch)  # type: ignore

            preds_tmp = model(x_batch.to(torch.float32).to(device))
            if isinstance(preds_tmp, (list, tuple)):
                if len(preds) == 0:  # init tuple with lists to store predictions
                    preds = tuple([] for _ in range(len(preds_tmp)))
                for j, p in enumerate(preds_tmp):
                    preds[j].append(_as_output(p, return_np))
            elif isinstance(preds_tmp, (np.ndarray, torch.Tensor)):
                if isinstance(preds, tuple):
                    preds = list(preds)
                preds.append(_as_output(preds_tmp, return_np))
            else:
                # Adjacent literals instead of backslash continuations, so the
                # source indentation does not end up inside the message.
                raise TypeError(
                    f"Model output type {type(preds_tmp)} not supported. The model "
                    "output type needs to be one of list, tuple, NDArray or torch.Tensor."
                )

    concat = partial(np.concatenate, axis=0) if return_np else partial(torch.cat, dim=0)
    out: tuple | np.ndarray | torch.Tensor = (
        tuple(concat(p) for p in preds) if isinstance(preds, tuple) else concat(preds)  # type: ignore
    )
    return out
|
@@ -16,11 +16,11 @@ from scipy.optimize import basinhopping
|
|
16
16
|
from torch.utils.data import Dataset
|
17
17
|
|
18
18
|
from dataeval.interop import as_numpy
|
19
|
-
from dataeval.output import
|
19
|
+
from dataeval.output import Output, set_metadata
|
20
20
|
|
21
21
|
|
22
22
|
@dataclass(frozen=True)
|
23
|
-
class SufficiencyOutput(
|
23
|
+
class SufficiencyOutput(Output):
|
24
24
|
"""
|
25
25
|
Output class for :class:`Sufficiency` workflow
|
26
26
|
|
@@ -47,7 +47,7 @@ class SufficiencyOutput(OutputMetadata):
|
|
47
47
|
if c != c_v:
|
48
48
|
raise ValueError(f"{m} does not contain the expected number ({c}) of data points.")
|
49
49
|
|
50
|
-
@set_metadata
|
50
|
+
@set_metadata
|
51
51
|
def project(
|
52
52
|
self,
|
53
53
|
projection: int | Iterable[int],
|
@@ -484,7 +484,7 @@ class Sufficiency(Generic[T]):
|
|
484
484
|
def eval_kwargs(self, value: Mapping[str, Any] | None) -> None:
|
485
485
|
self._eval_kwargs = {} if value is None else value
|
486
486
|
|
487
|
-
@set_metadata(["runs", "substeps"])
|
487
|
+
@set_metadata(state=["runs", "substeps"])
|
488
488
|
def evaluate(self, eval_at: int | Iterable[int] | None = None, niter: int = 1000) -> SufficiencyOutput:
|
489
489
|
"""
|
490
490
|
Creates data indices, trains models, and returns plotting data
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: dataeval
|
3
|
-
Version: 0.73.1
|
3
|
+
Version: 0.74.1
|
4
4
|
Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
|
5
5
|
Home-page: https://dataeval.ai/
|
6
6
|
License: MIT
|
@@ -21,18 +21,12 @@ Classifier: Programming Language :: Python :: 3.12
|
|
21
21
|
Classifier: Programming Language :: Python :: 3 :: Only
|
22
22
|
Classifier: Topic :: Scientific/Engineering
|
23
23
|
Provides-Extra: all
|
24
|
-
Provides-Extra: tensorflow
|
25
24
|
Provides-Extra: torch
|
26
|
-
Requires-Dist:
|
27
|
-
Requires-Dist:
|
28
|
-
Requires-Dist: matplotlib ; extra == "torch" or extra == "all"
|
29
|
-
Requires-Dist: numpy (>1.24.3)
|
25
|
+
Requires-Dist: matplotlib ; extra == "all"
|
26
|
+
Requires-Dist: numpy (>=1.24.3)
|
30
27
|
Requires-Dist: pillow (>=10.3.0)
|
31
28
|
Requires-Dist: scikit-learn (>=1.5.0)
|
32
29
|
Requires-Dist: scipy (>=1.10)
|
33
|
-
Requires-Dist: tensorflow (>=2.16,<2.18) ; extra == "tensorflow" or extra == "all"
|
34
|
-
Requires-Dist: tensorflow_probability (>=0.24,<0.25) ; extra == "tensorflow" or extra == "all"
|
35
|
-
Requires-Dist: tf-keras (>=2.16,<2.18) ; extra == "tensorflow" or extra == "all"
|
36
30
|
Requires-Dist: torch (>=2.2.0) ; extra == "torch" or extra == "all"
|
37
31
|
Requires-Dist: torchvision (>=0.17.0) ; extra == "torch" or extra == "all"
|
38
32
|
Requires-Dist: tqdm
|
@@ -0,0 +1,65 @@
|
|
1
|
+
dataeval/__init__.py,sha256=HNOjwnFIQCD7vwBBo0xMexlnNG3xRZ3s3VUMsA4Qozw,392
|
2
|
+
dataeval/detectors/__init__.py,sha256=Y-0bbyWyuMvZU80bCx6WPt3IV_r2hu9ymzpA8uzMqoI,206
|
3
|
+
dataeval/detectors/drift/__init__.py,sha256=BSXm21y7cAawHep-ZldCJ5HOvzYjPzYGKGrmoEs3i0E,737
|
4
|
+
dataeval/detectors/drift/base.py,sha256=QDGHMu1WADD-38MEIOwjQMEQM3DE7B0yFHO3hsMbV-E,14481
|
5
|
+
dataeval/detectors/drift/cvm.py,sha256=kc59w2_wtxFGNnLcaJRvX5v_38gPXiebSGNiFVdunEQ,4142
|
6
|
+
dataeval/detectors/drift/ks.py,sha256=gcpe1WIQeNeZdLYkdMZCFLXUp1bHMQUxwJE6-RLVOXs,4229
|
7
|
+
dataeval/detectors/drift/mmd.py,sha256=C0FX5v9ZJzmKNYEcYUaC7sDtMpJ2dZpwikNDu-AEWiI,7584
|
8
|
+
dataeval/detectors/drift/torch.py,sha256=igEQ2DV9JmcpTdUKCOHBi5LxtoNeCAslJS2Ldulg1hw,7585
|
9
|
+
dataeval/detectors/drift/uncertainty.py,sha256=Xz2yzJjtJfw1vLag234jwRvaa_HK36nMajGx8bQaNRs,5322
|
10
|
+
dataeval/detectors/drift/updates.py,sha256=UJ0z5hlunRi7twnkLABfdJG3tT2EqX4y9IGx8_USYvo,1780
|
11
|
+
dataeval/detectors/linters/__init__.py,sha256=BvpaB1RUpkEhhXk3Mqi5NYoOcJKZRFSBOJCmQOIfYRU,483
|
12
|
+
dataeval/detectors/linters/clusterer.py,sha256=hK-ak02GaxwWuufesZMKDsvoE5fMdXO7UWsLiK8hfY0,21008
|
13
|
+
dataeval/detectors/linters/duplicates.py,sha256=2bmPTFqoefeiAQV9y4CGlHV_mJNrysJSEFLXLd2DO4I,5661
|
14
|
+
dataeval/detectors/linters/merged_stats.py,sha256=X-bDTwjyR8RuVmzxLaHZmQ5nI3oOWvsqVlitdSncapk,1355
|
15
|
+
dataeval/detectors/linters/outliers.py,sha256=X48bzTfTr1LqC6WKVKBRfvpjcQRgmb93cNLT7Oipe3M,10113
|
16
|
+
dataeval/detectors/ood/__init__.py,sha256=-D4Fq-ysFylNNMqjHG1ALbB9qBCm_UinkCAgsK9HGg0,408
|
17
|
+
dataeval/detectors/ood/ae_torch.py,sha256=pO9w5221bXR9lEBkE7oakXeE7PXUUR--xcTpmHvOCSk,2142
|
18
|
+
dataeval/detectors/ood/base.py,sha256=UzcDbXl8Gv43VFzjrOegTnKSIoEYmfDP7fAySeWyWPw,6955
|
19
|
+
dataeval/detectors/ood/base_torch.py,sha256=yFbSfQsBMwZeVf8mrixmkZYBGChhV5oAHtkgzWnMzsA,3405
|
20
|
+
dataeval/detectors/ood/metadata_ks_compare.py,sha256=LNDNWGEDKTW8_-djgmK53sn9EZzzXq1Sgwc47k0QI-Y,5380
|
21
|
+
dataeval/detectors/ood/metadata_least_likely.py,sha256=nxMCXUOjOfWHDTGT2SLE7OYBCydRq8zHLd8t17k7hMM,5193
|
22
|
+
dataeval/detectors/ood/metadata_ood_mi.py,sha256=KLay2BmgHrStBV92VpIs_B1yEfQKllsMTgzOQEng01I,4065
|
23
|
+
dataeval/interop.py,sha256=SB5Nca12rluZeXrpmmlfY7LFJbN5opYM7jmAb2c29hM,1748
|
24
|
+
dataeval/metrics/__init__.py,sha256=fPBNLd-T6mCErZBBJrxWmXIL0jCk7fNUYIcNEBkMa80,238
|
25
|
+
dataeval/metrics/bias/__init__.py,sha256=dYiPHenS8J7pgRMMW2jNkTBmTbPoYTxT04fZu9PFats,747
|
26
|
+
dataeval/metrics/bias/balance.py,sha256=_TZEe17AT-qOvPp-QFrQfTqNwh8uVVCYjC4Sv6JBx9o,9118
|
27
|
+
dataeval/metrics/bias/coverage.py,sha256=o65_IgrWSlGnYeYZFABjwKaxq09uqyy5esHJM67PJ-k,4528
|
28
|
+
dataeval/metrics/bias/diversity.py,sha256=WL1NbZiRrv0SIq97FY3womZNCSl_EBMVlBWQZAUtjk8,7701
|
29
|
+
dataeval/metrics/bias/metadata_preprocessing.py,sha256=ekUFiirkmaHDiH7nJjkNpiUQD7OolAPhHorjLxpXv_Y,12248
|
30
|
+
dataeval/metrics/bias/metadata_utils.py,sha256=HmTjlRRTdM9566oKUDDdVMJ8luss4DYykFOiS2FQzhM,6558
|
31
|
+
dataeval/metrics/bias/parity.py,sha256=hnA7qQH4Uy3tl_krluZ9BPD5zYjjagUxZt2fEiIa2yE,12745
|
32
|
+
dataeval/metrics/estimators/__init__.py,sha256=O6ocxJq8XDkfJWwXeJnnnzbOyRnFPKF4kTIVTTZYOA8,380
|
33
|
+
dataeval/metrics/estimators/ber.py,sha256=fs3_e9pgu7I50QIALWtF2aidkBZhTCKVE2pA7PyB5Go,5019
|
34
|
+
dataeval/metrics/estimators/divergence.py,sha256=r_SKSurf1TdI5E1ivENqDnz8cQ3_sxVGKAqmF9cqcT4,4275
|
35
|
+
dataeval/metrics/estimators/uap.py,sha256=Aw5ReoWNK73Tq96r__qN_-cvHrELauqtDX3Af_QxX4s,2157
|
36
|
+
dataeval/metrics/stats/__init__.py,sha256=igLRaAt1nX6yRwC4xI0zNPBADi3u7EsSxWP3OZ8AqcU,1086
|
37
|
+
dataeval/metrics/stats/base.py,sha256=_C05KUAuDrfX3N-19o25V3vmXr0-45A5fc57cXyV8qs,12161
|
38
|
+
dataeval/metrics/stats/boxratiostats.py,sha256=bZunY-b8Y2IQqHlTusQN77ujLOHftogEQIARDpdVv6A,6463
|
39
|
+
dataeval/metrics/stats/datasetstats.py,sha256=rZUDiciHwEpnXmkI8-uJNiYwUuTL9ssZMKMx73hVX-Y,6219
|
40
|
+
dataeval/metrics/stats/dimensionstats.py,sha256=xITgQF_oomb6Ty_dJcbT3ARGGNp4QRcYSgnkjB4f-YE,4054
|
41
|
+
dataeval/metrics/stats/hashstats.py,sha256=vxw_K74EJM9CZy-EV617vdrysFO8nEspVWqIYsIHC-c,4958
|
42
|
+
dataeval/metrics/stats/labelstats.py,sha256=K0hJTphMe7htSjyss8GPtKDiHepTuU60_hX0xRA-uAg,4096
|
43
|
+
dataeval/metrics/stats/pixelstats.py,sha256=2zr9i3GLNx1i_SCtbfdtZNxXBEc_9wCe4qDpmXLVbKY,4576
|
44
|
+
dataeval/metrics/stats/visualstats.py,sha256=vLIC4sMo796axWl-4e4RzT33ll-_6ki54Dirn3V-EL8,4948
|
45
|
+
dataeval/output.py,sha256=SmzH9W9yewdL9SBKVBkUUvOo45oA5lHphE2DYvJJMu0,3573
|
46
|
+
dataeval/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
47
|
+
dataeval/utils/__init__.py,sha256=z7HxSijjycey-rGdQkgVOdpvT0oO2pKAuT4uYyxYGMs,555
|
48
|
+
dataeval/utils/gmm.py,sha256=YuLsJKsVWgH_wHr1u_hSRH5Yeexdj8exht8h99L7bLo,561
|
49
|
+
dataeval/utils/image.py,sha256=KgC_1nW__nGN5q6bVZNvG4U_qIBdjcPATz9qe8f2XuA,1928
|
50
|
+
dataeval/utils/metadata.py,sha256=0A--iru0zEmi044mKz5P35q69KrI30yoiRSlvs7TSdQ,9418
|
51
|
+
dataeval/utils/shared.py,sha256=xvF3VLfyheVwJtdtDrneOobkKf7t-JTmf_w91FWXmqo,3616
|
52
|
+
dataeval/utils/split_dataset.py,sha256=Ot1ZJhbIhVfcShYXF9MkWXak5odBXyuBdRh-noXh-MI,19555
|
53
|
+
dataeval/utils/torch/__init__.py,sha256=lpkqfgyARUxgrV94cZESQv8PIP2p-UnwItZ_wIr0XzQ,675
|
54
|
+
dataeval/utils/torch/blocks.py,sha256=HVhBTMMD5NA4qheMUgyol1KWiKZDIuc8k5j4RcMKmhk,1466
|
55
|
+
dataeval/utils/torch/datasets.py,sha256=10elNgLuH_FDX_CHE3y2Z215JN4-PQovQm5brcIJOeM,15021
|
56
|
+
dataeval/utils/torch/gmm.py,sha256=VbLlUQohwToApT493_tjQBWy2UM5R-3ppS9Dp-eP7BA,3240
|
57
|
+
dataeval/utils/torch/models.py,sha256=sdGeo7a8vshCTGA4lYyVxxb_aDWUlxdtIVxrddS-_ls,8542
|
58
|
+
dataeval/utils/torch/trainer.py,sha256=8BEXr6xtk-CHJTcNxOBnWgkFWfJUAiBy28cEdBhLMRU,7883
|
59
|
+
dataeval/utils/torch/utils.py,sha256=nWRcT6z6DbFVrL1RyxCOX3DPoCrv9G0B-VI_9LdGCQQ,5784
|
60
|
+
dataeval/workflows/__init__.py,sha256=ef1MiVL5IuhlDXXbwsiAfafhnr7tD3TXF9GRusy9_O8,290
|
61
|
+
dataeval/workflows/sufficiency.py,sha256=v9AV3BZT0NW-zD2VNIL_5aWspvoscrxRIUKcUdpy7HI,18540
|
62
|
+
dataeval-0.74.1.dist-info/LICENSE.txt,sha256=Kpzcfobf1HlqafF-EX6dQLw9TlJiaJzfgvLQFukyXYw,1060
|
63
|
+
dataeval-0.74.1.dist-info/METADATA,sha256=nd7os3kaLfp-A5HWH0QYVxe-gQdj5q3dIn9d0fPf-Lk,4298
|
64
|
+
dataeval-0.74.1.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
65
|
+
dataeval-0.74.1.dist-info/RECORD,,
|
dataeval/detectors/ood/aegmm.py
DELETED
@@ -1,66 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
Source code derived from Alibi-Detect 0.11.4
|
3
|
-
https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
|
4
|
-
|
5
|
-
Original code Copyright (c) 2023 Seldon Technologies Ltd
|
6
|
-
Licensed under Apache Software License (Apache 2.0)
|
7
|
-
"""
|
8
|
-
|
9
|
-
from __future__ import annotations
|
10
|
-
|
11
|
-
__all__ = ["OOD_AEGMM"]
|
12
|
-
|
13
|
-
from typing import TYPE_CHECKING, Callable
|
14
|
-
|
15
|
-
from numpy.typing import ArrayLike
|
16
|
-
|
17
|
-
from dataeval.detectors.ood.base import OODGMMBase, OODScoreOutput
|
18
|
-
from dataeval.interop import to_numpy
|
19
|
-
from dataeval.utils.lazy import lazyload
|
20
|
-
from dataeval.utils.tensorflow._internal.gmm import gmm_energy
|
21
|
-
from dataeval.utils.tensorflow._internal.loss import LossGMM
|
22
|
-
from dataeval.utils.tensorflow._internal.utils import predict_batch
|
23
|
-
|
24
|
-
if TYPE_CHECKING:
|
25
|
-
import tensorflow as tf
|
26
|
-
import tf_keras as keras
|
27
|
-
|
28
|
-
import dataeval.utils.tensorflow._internal.models as tf_models
|
29
|
-
else:
|
30
|
-
tf = lazyload("tensorflow")
|
31
|
-
keras = lazyload("tf_keras")
|
32
|
-
tf_models = lazyload("dataeval.utils.tensorflow._internal.models")
|
33
|
-
|
34
|
-
|
35
|
-
class OOD_AEGMM(OODGMMBase):
    """
    AE with Gaussian Mixture Model based outlier detector.

    Parameters
    ----------
    model : AEGMM
        An AEGMM model.
    """

    def __init__(self, model: tf_models.AEGMM) -> None:
        super().__init__(model)

    def fit(
        self,
        x_ref: ArrayLike,
        threshold_perc: float = 100.0,
        loss_fn: Callable[..., tf.Tensor] | None = None,
        optimizer: keras.optimizers.Optimizer | None = None,
        epochs: int = 20,
        batch_size: int = 64,
        verbose: bool = True,
    ) -> None:
        # Use a fresh LossGMM instance when the caller did not supply a loss.
        chosen_loss = LossGMM() if loss_fn is None else loss_fn
        super().fit(x_ref, threshold_perc, chosen_loss, optimizer, epochs, batch_size, verbose)

    def _score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
        X = to_numpy(X)
        self._validate(X)
        # The model yields three outputs; only the middle one feeds the GMM energy.
        _, middle, _ = predict_batch(X, self.model, batch_size=batch_size)
        sample_energy, _ = gmm_energy(middle, self.gmm_params, return_mean=False)
        return OODScoreOutput(sample_energy.numpy())  # type: ignore
|