dataeval 0.73.1__py3-none-any.whl → 0.74.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their respective public registries. It is provided for informational purposes only.
- dataeval/__init__.py +1 -1
- dataeval/detectors/drift/torch.py +1 -101
- dataeval/detectors/ood/__init__.py +10 -3
- dataeval/detectors/ood/ae.py +2 -1
- dataeval/detectors/ood/ae_torch.py +70 -0
- dataeval/detectors/ood/aegmm.py +4 -3
- dataeval/detectors/ood/base.py +58 -108
- dataeval/detectors/ood/base_tf.py +109 -0
- dataeval/detectors/ood/base_torch.py +109 -0
- dataeval/detectors/ood/llr.py +2 -2
- dataeval/detectors/ood/metadata_ks_compare.py +53 -14
- dataeval/detectors/ood/vae.py +3 -2
- dataeval/detectors/ood/vaegmm.py +5 -4
- dataeval/metrics/bias/__init__.py +3 -0
- dataeval/metrics/bias/balance.py +70 -67
- dataeval/metrics/bias/coverage.py +1 -1
- dataeval/metrics/bias/diversity.py +64 -133
- dataeval/metrics/bias/metadata_preprocessing.py +285 -0
- dataeval/metrics/bias/metadata_utils.py +229 -0
- dataeval/metrics/bias/parity.py +47 -157
- dataeval/utils/gmm.py +26 -0
- dataeval/utils/metadata.py +29 -9
- dataeval/utils/tensorflow/_internal/gmm.py +4 -24
- dataeval/utils/torch/gmm.py +98 -0
- dataeval/utils/torch/models.py +192 -0
- dataeval/utils/torch/trainer.py +84 -5
- dataeval/utils/torch/utils.py +107 -1
- {dataeval-0.73.1.dist-info → dataeval-0.74.0.dist-info}/METADATA +1 -2
- {dataeval-0.73.1.dist-info → dataeval-0.74.0.dist-info}/RECORD +31 -25
- dataeval/metrics/bias/metadata.py +0 -440
- {dataeval-0.73.1.dist-info → dataeval-0.74.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.73.1.dist-info → dataeval-0.74.0.dist-info}/WHEEL +0 -0
dataeval/utils/torch/gmm.py
ADDED
@@ -0,0 +1,98 @@
+"""
+Adapted for Pytorch from:
+
+Source code derived from Alibi-Detect 0.11.4
+https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
+
+Original code Copyright (c) 2023 Seldon Technologies Ltd
+Licensed under Apache Software License (Apache 2.0)
+"""
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+
+from dataeval.utils.gmm import GaussianMixtureModelParams
+
+
+def gmm_params(z: torch.Tensor, gamma: torch.Tensor) -> GaussianMixtureModelParams[torch.Tensor]:
+    """
+    Compute parameters of Gaussian Mixture Model.
+
+    Parameters
+    ----------
+    z : torch.Tensor
+        Observations.
+    gamma : torch.Tensor
+        Mixture probabilities to derive mixture distribution weights from.
+
+    Returns
+    -------
+    GaussianMixtureModelParams(phi, mu, cov, L, log_det_cov)
+        The parameters used to calculate energy.
+    """
+
+    # compute gmm parameters phi, mu and cov
+    N = gamma.shape[0]  # nb of samples in batch
+    sum_gamma = torch.sum(gamma, 0)  # K
+    phi = sum_gamma / N  # K
+    # K x D (D = latent_dim)
+    mu = torch.sum(torch.unsqueeze(gamma, -1) * torch.unsqueeze(z, 1), 0) / torch.unsqueeze(sum_gamma, -1)
+    z_mu = torch.unsqueeze(z, 1) - torch.unsqueeze(mu, 0)  # N x K x D
+    z_mu_outer = torch.unsqueeze(z_mu, -1) * torch.unsqueeze(z_mu, -2)  # N x K x D x D
+
+    # K x D x D
+    cov = torch.sum(torch.unsqueeze(torch.unsqueeze(gamma, -1), -1) * z_mu_outer, 0) / torch.unsqueeze(
+        torch.unsqueeze(sum_gamma, -1), -1
+    )
+
+    # cholesky decomposition of covariance and determinant derivation
+    D = cov.shape[1]
+    eps = 1e-6
+    L = torch.linalg.cholesky(cov + torch.eye(D) * eps)  # K x D x D
+    log_det_cov = 2.0 * torch.sum(torch.log(torch.diagonal(L, dim1=-2, dim2=-1)), 1)  # K
+
+    return GaussianMixtureModelParams(phi, mu, cov, L, log_det_cov)
+
+
+def gmm_energy(
+    z: torch.Tensor,
+    params: GaussianMixtureModelParams[torch.Tensor],
+    return_mean: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Compute sample energy from Gaussian Mixture Model.
+
+    Parameters
+    ----------
+    params : GaussianMixtureModelParams
+        The gaussian mixture model parameters.
+    return_mean : bool, default True
+        Take mean across all sample energies in a batch.
+
+    Returns
+    -------
+    sample_energy
+        The sample energy of the GMM.
+    cov_diag
+        The inverse sum of the diagonal components of the covariance matrix.
+    """
+    D = params.cov.shape[1]
+    z_mu = torch.unsqueeze(z, 1) - torch.unsqueeze(params.mu, 0)  # N x K x D
+    z_mu_T = torch.permute(z_mu, dims=[1, 2, 0])  # K x D x N
+    v = torch.linalg.solve_triangular(params.L, z_mu_T, upper=False)  # K x D x D
+
+    # rewrite sample energy in logsumexp format for numerical stability
+    logits = torch.log(torch.unsqueeze(params.phi, -1)) - 0.5 * (
+        torch.sum(torch.square(v), 1) + float(D) * np.log(2.0 * np.pi) + torch.unsqueeze(params.log_det_cov, -1)
+    )  # K x N
+    sample_energy = -torch.logsumexp(logits, 0)  # N
+
+    if return_mean:
+        sample_energy = torch.mean(sample_energy)
+
+    # inverse sum of variances
+    cov_diag = torch.sum(torch.divide(torch.tensor(1), torch.diagonal(params.cov, dim1=-2, dim2=-1)))
+
+    return sample_energy, cov_diag
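
Taken together, the new torch gmm_params/gmm_energy pair ports the TensorFlow helpers in dataeval/utils/tensorflow/_internal/gmm.py: gmm_params turns latent observations z and soft mixture assignments gamma into GaussianMixtureModelParams, and gmm_energy scores samples against those parameters with a logsumexp over components (the eps jitter keeps near-singular covariances Cholesky-decomposable). A minimal sketch of how the two compose, assuming dataeval 0.74.0 is installed and substituting random soft assignments for the gamma a trained AEGMM/VAEGMM estimation network would produce:

import torch

from dataeval.utils.torch.gmm import gmm_energy, gmm_params

N, D, K = 32, 4, 3                               # samples, latent dim, mixture components
z = torch.randn(N, D)                            # latent observations
gamma = torch.softmax(torch.randn(N, K), dim=1)  # stand-in soft mixture assignments

params = gmm_params(z, gamma)                    # phi (K), mu (K x D), cov (K x D x D), L, log_det_cov
energy, cov_diag = gmm_energy(z, params, return_mean=True)
print(energy.item(), cov_diag.item())            # scalar mean energy; higher energy = less likely under the GMM
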
dataeval/utils/torch/models.py
CHANGED
@@ -2,8 +2,10 @@ from __future__ import annotations
 
 __all__ = ["AriaAutoencoder", "Encoder", "Decoder"]
 
+import math
 from typing import Any
 
+import torch
 import torch.nn as nn
 
 
@@ -136,3 +138,193 @@ class Decoder(nn.Module):
             The reconstructed output tensor.
         """
         return self.decoder(x)
+
+
+class AE(nn.Module):
+    """
+    An autoencoder model with a separate encoder and decoder. Meant to replace the TensorFlow model called AE, which we
+    used as the core of an autoencoder-based OOD detector, i.e. as an argument to OOD_AE().
+
+    Parameters
+    ----------
+    input_shape : tuple[int, int, int]
+        Number of input channels, number of rows, number of columns. (Number of examples per batch will be inferred
+        at runtime.)
+    """
+
+    def __init__(self, input_shape: tuple[int, int, int]) -> None:
+        super().__init__()
+
+        input_dim = math.prod(input_shape)
+
+        # following is lifted from src/dataeval/utils/tensorflow/_internal/utils.py. It makes an odd staircase that is
+        # basically proportional to the number of numbers in the image to the 0.8 power.
+        encoding_dim = int(math.pow(2, int(input_dim.bit_length() * 0.8)))
+
+        self.encoder: Encoder_AE = Encoder_AE(input_shape, encoding_dim)
+
+        self.decoder: Decoder_AE = Decoder_AE(input_shape, encoding_dim, self.encoder.post_op_shape)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Perform a forward pass through the encoder and decoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor
+
+        Returns
+        -------
+        torch.Tensor
+            The reconstructed output tensor.
+        """
+        x = self.encoder(x)
+        x = self.decoder(x)
+        return x
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Encode the input tensor using the encoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor
+
+        Returns
+        -------
+        torch.Tensor
+            The encoded representation of the input tensor.
+        """
+        return self.encoder(x)
+
+
+class Encoder_AE(nn.Module):
+    """
+    A simple encoder to be used in an autoencoder model.
+
+    This is the encoder used to replicate AE, which was a TF function. It consists of a CNN followed by a fully
+    connected layer.
+
+    Parameters
+    ----------
+    channels : int
+        Number of input channels
+
+    input_shape : tuple[int, int, int]
+        number of channels, number of rows, number of columns in input images.
+
+    encoding_dim : the size of the 1D array that emerges from the fully connected layer.
+
+    """
+
+    def __init__(
+        self,
+        input_shape: tuple[int, int, int],
+        encoding_dim: int,
+    ) -> None:
+        super().__init__()
+
+        channels = input_shape[0]
+        nc_in, nc_mid, nc_done = 256, 128, 64
+
+        conv_in = nn.Conv2d(channels, nc_in, 2, stride=1, padding=1)
+        conv_mid = nn.Conv2d(nc_in, nc_mid, 2, stride=1, padding=1)
+        conv_done = nn.Conv2d(nc_mid, nc_done, 2, stride=1)
+
+        self.encoding_ops: nn.Sequential = nn.Sequential(
+            conv_in,
+            nn.LeakyReLU(),
+            nn.MaxPool2d(2),
+            conv_mid,
+            nn.LeakyReLU(),
+            nn.MaxPool2d(2),
+            conv_done,
+        )
+
+        ny, nx = input_shape[1:]
+        self.post_op_shape: tuple[int, int, int] = (nc_done, ny // 4 - 1, nx // 4 - 1)
+        self.flatcon: int = math.prod(self.post_op_shape)
+        self.flatten: nn.Sequential = nn.Sequential(
+            nn.Flatten(),
+            nn.Linear(
+                self.flatcon,
+                encoding_dim,
+            ),
+        )
+
+    def forward(self, x: Any) -> Any:
+        """
+        Perform a forward pass through the AE_torch encoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor
+
+        Returns
+        -------
+        torch.Tensor
+            The encoded representation of the input tensor.
+        """
+        x = self.encoding_ops(x)
+
+        x = self.flatten(x)
+
+        return x
+
+
+class Decoder_AE(nn.Module):
+    """
+    A simple decoder to be used in an autoencoder model.
+
+    This is the decoder used by the AriaAutoencoder model.
+
+    Parameters
+    ----------
+    channels : int
+        Number of output channels
+    """
+
+    def __init__(
+        self,
+        input_shape: tuple[int, int, int],
+        encoding_dim: int,
+        post_op_shape: tuple[int, int, int],
+    ) -> None:
+        super().__init__()
+
+        self.post_op_shape = post_op_shape
+        self.input_shape = input_shape  # need to store this for use in forward().
+        channels = input_shape[0]
+
+        self.input: nn.Linear = nn.Linear(encoding_dim, math.prod(post_op_shape))
+
+        self.decoder: nn.Sequential = nn.Sequential(
+            nn.ConvTranspose2d(64, 128, 2, stride=1),
+            nn.LeakyReLU(),
+            nn.ConvTranspose2d(128, 256, 2, stride=2),
+            nn.LeakyReLU(),
+            nn.ConvTranspose2d(256, channels, 2, stride=2),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Perform a forward pass through the decoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            The encoded tensor.
+
+        Returns
+        -------
+        torch.Tensor
+            The reconstructed output tensor.
+        """
+        x = self.input(x)
+        x = x.reshape((-1, *self.post_op_shape))
+        x = self.decoder(x)
+        x = x.reshape((-1, *self.input_shape))
+        return x
dataeval/utils/torch/trainer.py
CHANGED
@@ -1,15 +1,15 @@
 from __future__ import annotations
 
-
-
-from typing import Any
+from typing import Any, Callable
 
 import torch
 import torch.nn as nn
+from numpy.typing import NDArray
 from torch.optim import Adam
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import DataLoader, Dataset, TensorDataset
+from tqdm import tqdm
 
-
+__all__ = ["AETrainer", "trainer"]
 
 
 def get_images_from_batch(batch: Any) -> Any:
@@ -176,3 +176,82 @@ class AETrainer:
             encodings = torch.vstack((encodings, embeddings)) if len(encodings) else embeddings
 
         return encodings
+
+
+def trainer(
+    model: torch.nn.Module,
+    x_train: NDArray[Any],
+    y_train: NDArray[Any] | None,
+    loss_fn: Callable[..., torch.Tensor | torch.nn.Module] | None,
+    optimizer: torch.optim.Optimizer | None,
+    preprocess_fn: Callable[[torch.Tensor], torch.Tensor] | None,
+    epochs: int,
+    batch_size: int,
+    device: torch.device,
+    verbose: bool,
+) -> None:
+    """
+    Train Pytorch model.
+
+    Parameters
+    ----------
+    model
+        Model to train.
+    loss_fn
+        Loss function used for training.
+    x_train
+        Training data.
+    y_train
+        Training labels.
+    optimizer
+        Optimizer used for training.
+    preprocess_fn
+        Preprocessing function applied to each training batch.
+    epochs
+        Number of training epochs.
+    reg_loss_fn
+        Allows an additional regularisation term to be defined as reg_loss_fn(model)
+    batch_size
+        Batch size used for training.
+    buffer_size
+        Maximum number of elements that will be buffered when prefetching.
+    verbose
+        Whether to print training progress.
+    """
+    if optimizer is None:
+        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+
+    if y_train is None:
+        dataset = TensorDataset(torch.from_numpy(x_train).to(torch.float32))
+
+    else:
+        dataset = TensorDataset(
+            torch.from_numpy(x_train).to(torch.float32), torch.from_numpy(y_train).to(torch.float32)
+        )
+
+    loader = DataLoader(dataset=dataset)
+
+    model = model.to(device)
+
+    # iterate over epochs
+    loss = torch.nan
+    disable_tqdm = not verbose
+    for epoch in (pbar := tqdm(range(epochs), disable=disable_tqdm)):
+        epoch_loss = loss
+        for step, data in enumerate(loader):
+            if step % 250 == 0:
+                pbar.set_description(f"Epoch: {epoch} ({epoch_loss:.3f}), loss: {loss:.3f}")
+
+            x, y = [d.to(device) for d in data] if len(data) > 1 else (data[0].to(device), None)
+
+            if isinstance(preprocess_fn, Callable):
+                x = preprocess_fn(x)
+
+            y_hat = model(x)
+            y = x if y is None else y
+
+            loss = loss_fn(y, y_hat)  # type: ignore
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
dataeval/utils/torch/utils.py
CHANGED
@@ -3,8 +3,12 @@ from __future__ import annotations
 __all__ = ["read_dataset"]
 
 from collections import defaultdict
-from typing import Any
+from functools import partial
+from typing import Any, Callable
 
+import numpy as np
+import torch
+from numpy.typing import NDArray
 from torch.utils.data import Dataset
 
 
@@ -61,3 +65,105 @@ def read_dataset(dataset: Dataset[Any]) -> list[list[Any]]:
             ddict[i].append(d)
 
     return list(ddict.values())
+
+
+def get_device(device: str | torch.device | None = None) -> torch.device:
+    """
+    Instantiates a PyTorch device object.
+
+    Parameters
+    ----------
+    device : str | torch.device | None, default None
+        Either ``None``, a str ('gpu' or 'cpu') indicating the device to choose, or an
+        already instantiated device object. If ``None``, the GPU is selected if it is
+        detected, otherwise the CPU is used as a fallback.
+
+    Returns
+    -------
+    The instantiated device object.
+    """
+    if isinstance(device, torch.device):  # Already a torch device
+        return device
+    else:  # Instantiate device
+        if device is None or device.lower() in ["gpu", "cuda"]:
+            torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        else:
+            torch_device = torch.device("cpu")
+    return torch_device
+
+
+def predict_batch(
+    x: NDArray[Any] | torch.Tensor,
+    model: Callable | torch.nn.Module | torch.nn.Sequential,
+    device: torch.device | None = None,
+    batch_size: int = int(1e10),
+    preprocess_fn: Callable | None = None,
+    dtype: type[np.generic] | torch.dtype = np.float32,
+) -> NDArray[Any] | torch.Tensor | tuple[Any, ...]:
+    """
+    Make batch predictions on a model.
+
+    Parameters
+    ----------
+    x : np.ndarray | torch.Tensor
+        Batch of instances.
+    model : Callable | nn.Module | nn.Sequential
+        PyTorch model.
+    device : torch.device | None, default None
+        Device type used. The default None tries to use the GPU and falls back on CPU.
+        Can be specified by passing either torch.device('cuda') or torch.device('cpu').
+    batch_size : int, default 1e10
+        Batch size used during prediction.
+    preprocess_fn : Callable | None, default None
+        Optional preprocessing function for each batch.
+    dtype : np.dtype | torch.dtype, default np.float32
+        Model output type, either a :term:`NumPy` or torch dtype, e.g. np.float32 or torch.float32.
+
+    Returns
+    -------
+    NDArray | torch.Tensor | tuple
+        Numpy array, torch tensor or tuples of those with model outputs.
+    """
+    device = get_device(device)
+    if isinstance(x, np.ndarray):
+        x = torch.from_numpy(x).to(device)
+    n = len(x)
+    n_minibatch = int(np.ceil(n / batch_size))
+    return_np = not isinstance(dtype, torch.dtype)
+    preds = []
+    with torch.no_grad():
+        for i in range(n_minibatch):
+            istart, istop = i * batch_size, min((i + 1) * batch_size, n)
+            x_batch = x[istart:istop]
+            if isinstance(preprocess_fn, Callable):
+                x_batch = preprocess_fn(x_batch)
+
+            preds_tmp = model(x_batch.to(torch.float32).to(device))
+            if isinstance(preds_tmp, (list, tuple)):
+                if len(preds) == 0:  # init tuple with lists to store predictions
+                    preds = tuple([] for _ in range(len(preds_tmp)))
+                for j, p in enumerate(preds_tmp):
+                    if isinstance(p, torch.Tensor):
+                        p = p.cpu()
+                    preds[j].append(p if not return_np or isinstance(p, np.ndarray) else p.numpy())
+            elif isinstance(preds_tmp, (np.ndarray, torch.Tensor)):
+                if isinstance(preds_tmp, torch.Tensor):
+                    preds_tmp = preds_tmp.cpu()
+                if isinstance(preds, tuple):
+                    preds = list(preds)
+                preds.append(
+                    preds_tmp
+                    if not return_np or isinstance(preds_tmp, np.ndarray)  # type: ignore
+                    else preds_tmp.numpy()
+                )
+            else:
+                raise TypeError(
+                    f"Model output type {type(preds_tmp)} not supported. The model \
+                    output type needs to be one of list, tuple, NDArray or \
+                    torch.Tensor."
+                )
+    concat = partial(np.concatenate, axis=0) if return_np else partial(torch.cat, dim=0)
+    out: tuple | np.ndarray | torch.Tensor = (
+        tuple(concat(p) for p in preds) if isinstance(preds, tuple) else concat(preds)  # type: ignore
+    )
+    return out
{dataeval-0.73.1.dist-info → dataeval-0.74.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: dataeval
-Version: 0.73.1
+Version: 0.74.0
 Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
 Home-page: https://dataeval.ai/
 License: MIT
@@ -23,7 +23,6 @@ Classifier: Topic :: Scientific/Engineering
 Provides-Extra: all
 Provides-Extra: tensorflow
 Provides-Extra: torch
-Requires-Dist: hdbscan (>=0.8.36)
 Requires-Dist: markupsafe (<3.0.2) ; extra == "tensorflow" or extra == "all"
 Requires-Dist: matplotlib ; extra == "torch" or extra == "all"
 Requires-Dist: numpy (>1.24.3)
{dataeval-0.73.1.dist-info → dataeval-0.74.0.dist-info}/RECORD
CHANGED
@@ -1,11 +1,11 @@
-dataeval/__init__.py,sha256=
+dataeval/__init__.py,sha256=bwKFegCsdGFydqDvza_wSvJgRGr-0pQ59UpcePQ1mNs,601
 dataeval/detectors/__init__.py,sha256=mwAyY54Hvp6N4D57cde3_besOinK8jVF43k0Mw4XZi8,363
 dataeval/detectors/drift/__init__.py,sha256=BSXm21y7cAawHep-ZldCJ5HOvzYjPzYGKGrmoEs3i0E,737
 dataeval/detectors/drift/base.py,sha256=xwI6C-PEH0ZjpSqP6No6WDZp42DnE16OHi_mXe2JSvI,14499
 dataeval/detectors/drift/cvm.py,sha256=kc59w2_wtxFGNnLcaJRvX5v_38gPXiebSGNiFVdunEQ,4142
 dataeval/detectors/drift/ks.py,sha256=gcpe1WIQeNeZdLYkdMZCFLXUp1bHMQUxwJE6-RLVOXs,4229
 dataeval/detectors/drift/mmd.py,sha256=TqGOnUNYKwpS0GQPV3dSl-_qRa0g2flmoQ-dxzW_JfY,7586
-dataeval/detectors/drift/torch.py,sha256=
+dataeval/detectors/drift/torch.py,sha256=igEQ2DV9JmcpTdUKCOHBi5LxtoNeCAslJS2Ldulg1hw,7585
 dataeval/detectors/drift/uncertainty.py,sha256=Xz2yzJjtJfw1vLag234jwRvaa_HK36nMajGx8bQaNRs,5322
 dataeval/detectors/drift/updates.py,sha256=UJ0z5hlunRi7twnkLABfdJG3tT2EqX4y9IGx8_USYvo,1780
 dataeval/detectors/linters/__init__.py,sha256=BvpaB1RUpkEhhXk3Mqi5NYoOcJKZRFSBOJCmQOIfYRU,483
@@ -13,24 +13,28 @@ dataeval/detectors/linters/clusterer.py,sha256=sau5A9YcQ6VDjbZGOIaCaRHW_63opaA31
 dataeval/detectors/linters/duplicates.py,sha256=tOD43rJkvheIA3mznbUqHhft2yD3xRZQdCt61daIca4,5665
 dataeval/detectors/linters/merged_stats.py,sha256=X-bDTwjyR8RuVmzxLaHZmQ5nI3oOWvsqVlitdSncapk,1355
 dataeval/detectors/linters/outliers.py,sha256=BUVvtbKHo04KnRmrgb84MBr0l1gtcY3-xNCHjetFrEQ,10117
-dataeval/detectors/ood/__init__.py,sha256=
-dataeval/detectors/ood/ae.py,sha256=
-dataeval/detectors/ood/
-dataeval/detectors/ood/
-dataeval/detectors/ood/
-dataeval/detectors/ood/
+dataeval/detectors/ood/__init__.py,sha256=XckkWVhYbbg9iWVsCPEQN-t7FFSt2a4jmCwAAempkM4,793
+dataeval/detectors/ood/ae.py,sha256=km7buF8LbMmwsyfu1xMOI5CJDnQX1x8_-c04zTGMXRI,2389
+dataeval/detectors/ood/ae_torch.py,sha256=pO9w5221bXR9lEBkE7oakXeE7PXUUR--xcTpmHvOCSk,2142
+dataeval/detectors/ood/aegmm.py,sha256=CI2HEkRMJSEFTVLZEhz4CStkaS7i66yTPtnbkbCqTes,2084
+dataeval/detectors/ood/base.py,sha256=u9S7z7zJ8wuPqrtn63ePdAa8DdI579EbCy8Tn0M3XI8,6983
+dataeval/detectors/ood/base_tf.py,sha256=ppj8rAjXjHEab2oGfQO2olXyN4aGZH8_QHIEghOoeFQ,3297
+dataeval/detectors/ood/base_torch.py,sha256=yFbSfQsBMwZeVf8mrixmkZYBGChhV5oAHtkgzWnMzsA,3405
+dataeval/detectors/ood/llr.py,sha256=IrOam-kqUU4bftolR3MvhcEq-NNj2euyI-lYvMuXYn8,10645
+dataeval/detectors/ood/metadata_ks_compare.py,sha256=Ka6MABdJH5ZlHF66mENpSOLCE8H9xdQ_wWNwMYVO_Q0,5352
 dataeval/detectors/ood/metadata_least_likely.py,sha256=nxMCXUOjOfWHDTGT2SLE7OYBCydRq8zHLd8t17k7hMM,5193
 dataeval/detectors/ood/metadata_ood_mi.py,sha256=KLay2BmgHrStBV92VpIs_B1yEfQKllsMTgzOQEng01I,4065
-dataeval/detectors/ood/vae.py,sha256=
-dataeval/detectors/ood/vaegmm.py,sha256=
+dataeval/detectors/ood/vae.py,sha256=yjK4p-XYhnH3wWPiwAclb3eyZE0wpTazLLuKhzurcWY,3203
+dataeval/detectors/ood/vaegmm.py,sha256=FhPJBzs7wyEPQUUMxOMsdPpCdAZwN82vztjt05cSrds,2459
 dataeval/interop.py,sha256=TZCkZo844DvzHoxuRo-YsBhT6GvKmyQTHtUEQZPly1M,1728
 dataeval/metrics/__init__.py,sha256=fPBNLd-T6mCErZBBJrxWmXIL0jCk7fNUYIcNEBkMa80,238
-dataeval/metrics/bias/__init__.py,sha256=
-dataeval/metrics/bias/balance.py,sha256=
-dataeval/metrics/bias/coverage.py,sha256=
-dataeval/metrics/bias/diversity.py,sha256=
-dataeval/metrics/bias/
-dataeval/metrics/bias/
+dataeval/metrics/bias/__init__.py,sha256=dYiPHenS8J7pgRMMW2jNkTBmTbPoYTxT04fZu9PFats,747
+dataeval/metrics/bias/balance.py,sha256=BH644D_xN7rRUdJMNgVcGHWq3TTnehYjSBhSMhmAFyY,9154
+dataeval/metrics/bias/coverage.py,sha256=LBrNG6GIrvMJjZckr72heyCTMCke_p5BT8NJWi-noEY,4546
+dataeval/metrics/bias/diversity.py,sha256=__7I934sVoymXqgHoneXglJhIU5iHRIuklFwC2ks84w,7719
+dataeval/metrics/bias/metadata_preprocessing.py,sha256=DbtzsiHjkCxs411okb6s2B_H2TqfvwJ4xyt9m_OsqJo,12266
+dataeval/metrics/bias/metadata_utils.py,sha256=HmTjlRRTdM9566oKUDDdVMJ8luss4DYykFOiS2FQzhM,6558
+dataeval/metrics/bias/parity.py,sha256=lLa2zN0AK-zWzlXmvLCbMxTZFodAKLs8wSGl_YZdNFo,12765
 dataeval/metrics/estimators/__init__.py,sha256=O6ocxJq8XDkfJWwXeJnnnzbOyRnFPKF4kTIVTTZYOA8,380
 dataeval/metrics/estimators/ber.py,sha256=SVT-BIC_GLs0l2l2NhWu4OpRbgn96w-OwTSoPHTnQbE,5037
 dataeval/metrics/estimators/divergence.py,sha256=pImaa216-YYTgGWDCSTcpJrC-dfl7150yVrPfW_TyGc,4293
@@ -47,13 +51,14 @@ dataeval/metrics/stats/visualstats.py,sha256=y0xIvst7epcajk8vz2jngiAiz0T7DZC-M97
 dataeval/output.py,sha256=jWXXNxFNBEaY1rN7Z-6LZl6bQT-I7z_wqr91Rhrdt_0,3061
 dataeval/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dataeval/utils/__init__.py,sha256=FZLWDA7nMbHOcdg3701cVJpQmUp1Wxxk8h_qIrUQQjY,713
+dataeval/utils/gmm.py,sha256=YuLsJKsVWgH_wHr1u_hSRH5Yeexdj8exht8h99L7bLo,561
 dataeval/utils/image.py,sha256=KgC_1nW__nGN5q6bVZNvG4U_qIBdjcPATz9qe8f2XuA,1928
 dataeval/utils/lazy.py,sha256=M0iBHuJh4UPrSJPHZ0jhFwRSZhyjHJQx_KEf1OCkHD8,588
-dataeval/utils/metadata.py,sha256=
+dataeval/utils/metadata.py,sha256=0A--iru0zEmi044mKz5P35q69KrI30yoiRSlvs7TSdQ,9418
 dataeval/utils/shared.py,sha256=xvF3VLfyheVwJtdtDrneOobkKf7t-JTmf_w91FWXmqo,3616
 dataeval/utils/split_dataset.py,sha256=Ot1ZJhbIhVfcShYXF9MkWXak5odBXyuBdRh-noXh-MI,19555
 dataeval/utils/tensorflow/__init__.py,sha256=l4OjIA75JJXeNWDCkST1xtDMVYsw97lZ-9JXFBlyuYg,539
-dataeval/utils/tensorflow/_internal/gmm.py,sha256=
+dataeval/utils/tensorflow/_internal/gmm.py,sha256=XvjhWM3ppP-R9nCZGs80WphmQR3u7wb-VtoCQYeXZlQ,3404
 dataeval/utils/tensorflow/_internal/loss.py,sha256=TFhoNPgqeJtdpIHYobZPyzMpeWjzlFqzu5LCtthEUi4,4463
 dataeval/utils/tensorflow/_internal/models.py,sha256=TzQYRrFe5XomhnPw05v-HBODQdFIqWg21WH1xS0XBlg,59868
 dataeval/utils/tensorflow/_internal/trainer.py,sha256=uBFTnAy9o2T_FoT3RSX-AA7T-2FScyOdYEg9_7Dpd28,4314
@@ -62,12 +67,13 @@ dataeval/utils/tensorflow/loss/__init__.py,sha256=Q-66vt91Oe1ByYfo28tW32zXDq2MqQ
 dataeval/utils/torch/__init__.py,sha256=lpkqfgyARUxgrV94cZESQv8PIP2p-UnwItZ_wIr0XzQ,675
 dataeval/utils/torch/blocks.py,sha256=HVhBTMMD5NA4qheMUgyol1KWiKZDIuc8k5j4RcMKmhk,1466
 dataeval/utils/torch/datasets.py,sha256=10elNgLuH_FDX_CHE3y2Z215JN4-PQovQm5brcIJOeM,15021
-dataeval/utils/torch/
-dataeval/utils/torch/
-dataeval/utils/torch/
+dataeval/utils/torch/gmm.py,sha256=VbLlUQohwToApT493_tjQBWy2UM5R-3ppS9Dp-eP7BA,3240
+dataeval/utils/torch/models.py,sha256=sdGeo7a8vshCTGA4lYyVxxb_aDWUlxdtIVxrddS-_ls,8542
+dataeval/utils/torch/trainer.py,sha256=8BEXr6xtk-CHJTcNxOBnWgkFWfJUAiBy28cEdBhLMRU,7883
+dataeval/utils/torch/utils.py,sha256=nWRcT6z6DbFVrL1RyxCOX3DPoCrv9G0B-VI_9LdGCQQ,5784
 dataeval/workflows/__init__.py,sha256=ef1MiVL5IuhlDXXbwsiAfafhnr7tD3TXF9GRusy9_O8,290
 dataeval/workflows/sufficiency.py,sha256=1jSYhH9i4oesmJYs5PZvWS1LGXf8ekOgNhpFtMPLPXk,18552
-dataeval-0.
-dataeval-0.
-dataeval-0.
-dataeval-0.
+dataeval-0.74.0.dist-info/LICENSE.txt,sha256=Kpzcfobf1HlqafF-EX6dQLw9TlJiaJzfgvLQFukyXYw,1060
+dataeval-0.74.0.dist-info/METADATA,sha256=OPnkHZTm8R1LHqLxcSnOHjqj5GuHmjUVI3dddTVsBAc,4680
+dataeval-0.74.0.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
+dataeval-0.74.0.dist-info/RECORD,,