dataeval 0.84.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/data/__init__.py +19 -0
- dataeval/data/_embeddings.py +345 -0
- dataeval/{utils/data → data}/_images.py +2 -2
- dataeval/{utils/data → data}/_metadata.py +8 -7
- dataeval/{utils/data → data}/_selection.py +22 -9
- dataeval/{utils/data → data}/_split.py +1 -1
- dataeval/data/selections/__init__.py +19 -0
- dataeval/data/selections/_classbalance.py +37 -0
- dataeval/data/selections/_classfilter.py +109 -0
- dataeval/{utils/data → data}/selections/_indices.py +1 -1
- dataeval/{utils/data → data}/selections/_limit.py +1 -1
- dataeval/{utils/data → data}/selections/_prioritize.py +3 -3
- dataeval/{utils/data → data}/selections/_reverse.py +1 -1
- dataeval/{utils/data → data}/selections/_shuffle.py +3 -3
- dataeval/detectors/drift/__init__.py +2 -2
- dataeval/detectors/drift/_base.py +55 -203
- dataeval/detectors/drift/_cvm.py +19 -30
- dataeval/detectors/drift/_ks.py +18 -30
- dataeval/detectors/drift/_mmd.py +189 -53
- dataeval/detectors/drift/_uncertainty.py +52 -56
- dataeval/detectors/drift/updates.py +13 -12
- dataeval/detectors/linters/duplicates.py +6 -4
- dataeval/detectors/linters/outliers.py +3 -3
- dataeval/detectors/ood/ae.py +1 -1
- dataeval/metadata/_distance.py +1 -1
- dataeval/metadata/_ood.py +4 -4
- dataeval/metrics/bias/_balance.py +1 -1
- dataeval/metrics/bias/_diversity.py +1 -1
- dataeval/metrics/bias/_parity.py +1 -1
- dataeval/metrics/stats/_base.py +7 -7
- dataeval/metrics/stats/_dimensionstats.py +2 -2
- dataeval/metrics/stats/_hashstats.py +2 -2
- dataeval/metrics/stats/_imagestats.py +4 -4
- dataeval/metrics/stats/_labelstats.py +2 -2
- dataeval/metrics/stats/_pixelstats.py +2 -2
- dataeval/metrics/stats/_visualstats.py +2 -2
- dataeval/outputs/_bias.py +1 -1
- dataeval/typing.py +53 -19
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/_array.py +18 -7
- dataeval/utils/data/__init__.py +5 -20
- dataeval/utils/data/_dataset.py +6 -4
- dataeval/utils/data/collate.py +2 -0
- dataeval/utils/datasets/__init__.py +17 -0
- dataeval/utils/{data/datasets → datasets}/_base.py +10 -7
- dataeval/utils/{data/datasets → datasets}/_cifar10.py +11 -11
- dataeval/utils/{data/datasets → datasets}/_milco.py +44 -16
- dataeval/utils/{data/datasets → datasets}/_mnist.py +11 -7
- dataeval/utils/{data/datasets → datasets}/_ships.py +10 -6
- dataeval/utils/{data/datasets → datasets}/_voc.py +43 -22
- dataeval/utils/torch/_internal.py +12 -35
- {dataeval-0.84.0.dist-info → dataeval-1.0.0.dist-info}/METADATA +2 -3
- dataeval-1.0.0.dist-info/RECORD +107 -0
- dataeval/detectors/drift/_torch.py +0 -222
- dataeval/utils/data/_embeddings.py +0 -186
- dataeval/utils/data/datasets/__init__.py +0 -17
- dataeval/utils/data/selections/__init__.py +0 -17
- dataeval/utils/data/selections/_classfilter.py +0 -59
- dataeval-0.84.0.dist-info/RECORD +0 -106
- /dataeval/{utils/data → data}/_targets.py +0 -0
- /dataeval/utils/{metadata.py → data/metadata.py} +0 -0
- /dataeval/utils/{data/datasets → datasets}/_fileio.py +0 -0
- /dataeval/utils/{data/datasets → datasets}/_mixin.py +0 -0
- /dataeval/utils/{data/datasets → datasets}/_types.py +0 -0
- {dataeval-0.84.0.dist-info → dataeval-1.0.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.84.0.dist-info → dataeval-1.0.0.dist-info}/WHEEL +0 -0
@@ -2,7 +2,6 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
-
from functools import partial
|
6
5
|
from typing import Any, Callable
|
7
6
|
|
8
7
|
import numpy as np
|
@@ -12,16 +11,16 @@ from torch.utils.data import DataLoader, TensorDataset
|
|
12
11
|
from tqdm import tqdm
|
13
12
|
|
14
13
|
from dataeval.config import DeviceLike, get_device
|
14
|
+
from dataeval.typing import Array
|
15
15
|
|
16
16
|
|
17
17
|
def predict_batch(
|
18
|
-
x:
|
19
|
-
model:
|
18
|
+
x: Array,
|
19
|
+
model: torch.nn.Module,
|
20
20
|
device: DeviceLike | None = None,
|
21
21
|
batch_size: int = int(1e10),
|
22
22
|
preprocess_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
|
23
|
-
|
24
|
-
) -> NDArray[Any] | torch.Tensor | tuple[Any, ...]:
|
23
|
+
) -> torch.Tensor:
|
25
24
|
"""
|
26
25
|
Make batch predictions on a model.
|
27
26
|
|
@@ -29,7 +28,7 @@ def predict_batch(
|
|
29
28
|
----------
|
30
29
|
x : np.ndarray | torch.Tensor
|
31
30
|
Batch of instances.
|
32
|
-
model :
|
31
|
+
model : nn.Module
|
33
32
|
PyTorch model.
|
34
33
|
device : DeviceLike or None, default None
|
35
34
|
The hardware device to use if specified, otherwise uses the DataEval
|
@@ -38,21 +37,18 @@ def predict_batch(
|
|
38
37
|
Batch size used during prediction.
|
39
38
|
preprocess_fn : Callable | None, default None
|
40
39
|
Optional preprocessing function for each batch.
|
41
|
-
dtype : np.dtype | torch.dtype, default np.float32
|
42
|
-
Model output type, either a :term:`NumPy` or torch dtype, e.g. np.float32 or torch.float32.
|
43
40
|
|
44
41
|
Returns
|
45
42
|
-------
|
46
|
-
|
47
|
-
|
43
|
+
torch.Tensor
|
44
|
+
PyTorch tensor with model outputs.
|
48
45
|
"""
|
49
46
|
device = get_device(device)
|
50
|
-
if isinstance(
|
51
|
-
|
47
|
+
if isinstance(model, torch.nn.Module):
|
48
|
+
model = model.to(device).eval()
|
49
|
+
x = torch.tensor(x, device=device)
|
52
50
|
n = len(x)
|
53
51
|
n_minibatch = int(np.ceil(n / batch_size))
|
54
|
-
return_np = not isinstance(dtype, torch.dtype)
|
55
|
-
preds_tuple = None
|
56
52
|
preds_array = []
|
57
53
|
with torch.no_grad():
|
58
54
|
for i in range(n_minibatch):
|
@@ -60,28 +56,9 @@ def predict_batch(
|
|
60
56
|
x_batch = x[istart:istop]
|
61
57
|
if isinstance(preprocess_fn, Callable):
|
62
58
|
x_batch = preprocess_fn(x_batch)
|
59
|
+
preds_array.append(model(x_batch.to(dtype=torch.float32)).cpu())
|
63
60
|
|
64
|
-
|
65
|
-
if isinstance(preds_tmp, (list, tuple)):
|
66
|
-
if preds_tuple is None: # init tuple with lists to store predictions
|
67
|
-
preds_tuple = tuple([] for _ in range(len(preds_tmp)))
|
68
|
-
for j, p in enumerate(preds_tmp):
|
69
|
-
p = p.cpu() if isinstance(p, torch.Tensor) else p
|
70
|
-
preds_tuple[j].append(p if not return_np or isinstance(p, np.ndarray) else p.numpy())
|
71
|
-
elif isinstance(preds_tmp, (np.ndarray, torch.Tensor)):
|
72
|
-
preds_tmp = preds_tmp.cpu() if isinstance(preds_tmp, torch.Tensor) else preds_tmp
|
73
|
-
preds_array.append(
|
74
|
-
preds_tmp if not return_np or isinstance(preds_tmp, np.ndarray) else preds_tmp.numpy()
|
75
|
-
)
|
76
|
-
else:
|
77
|
-
raise TypeError(
|
78
|
-
f"Model output type {type(preds_tmp)} not supported. The model \
|
79
|
-
output type needs to be one of list, tuple, NDArray or \
|
80
|
-
torch.Tensor."
|
81
|
-
)
|
82
|
-
concat = partial(np.concatenate, axis=0) if return_np else partial(torch.cat, dim=0)
|
83
|
-
out = tuple(concat(p) for p in preds_tuple) if preds_tuple is not None else concat(preds_array)
|
84
|
-
return out
|
61
|
+
return torch.cat(preds_array, dim=0)
|
85
62
|
|
86
63
|
|
87
64
|
def trainer(
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: dataeval
|
3
|
-
Version: 0.
|
3
|
+
Version: 1.0.0
|
4
4
|
Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
|
5
5
|
Home-page: https://dataeval.ai/
|
6
6
|
License: MIT
|
@@ -82,8 +82,7 @@ using MAITE-compliant datasets and models.
|
|
82
82
|
|
83
83
|
**Python versions:** 3.9 - 3.12
|
84
84
|
|
85
|
-
**Supported packages**: *NumPy*, *Pandas*, *Sci-kit learn*, *MAITE*, *NRTK
|
86
|
-
*Gradient*
|
85
|
+
**Supported packages**: *NumPy*, *Pandas*, *Sci-kit learn*, *MAITE*, *NRTK*
|
87
86
|
|
88
87
|
Choose your preferred method of installation below or follow our
|
89
88
|
[installation guide](https://dataeval.readthedocs.io/en/v0.74.2/installation.html).
|
@@ -0,0 +1,107 @@
|
|
1
|
+
dataeval/__init__.py,sha256=xd1GfD7QmzBG-WN7K6BMJSzV9_UZlX5OiKICdQ5xGfU,1635
|
2
|
+
dataeval/_log.py,sha256=Mn5bRWO0cgtAYd5VGYSFiPgu57ta3zoktrtHAZ1m3dU,357
|
3
|
+
dataeval/config.py,sha256=lD1YDH8HosFeRU5rQEYRBcmXMZy-csWaMlJTRZGd9iU,3582
|
4
|
+
dataeval/data/__init__.py,sha256=qNnRRiVP_sLthkkHpUrMgI_r8dQK-cC-xoGrrjQeRKc,544
|
5
|
+
dataeval/data/_embeddings.py,sha256=6Medqj_JCQt1iwZwWGSs1OeX-bHB8bg5BJqADY1N2s8,12883
|
6
|
+
dataeval/data/_images.py,sha256=WF9XJRka8ohUdyI2IKBMAy3JoJhOm1iC-8tbYl8woRM,2642
|
7
|
+
dataeval/data/_metadata.py,sha256=hNgsCEN8EyfDDX7zLKcQnsaDl-9xvvs5tUzqMjVLvI4,14457
|
8
|
+
dataeval/data/_selection.py,sha256=V61_pTFj0hSzmltA6CV5t51Znqw2dIQZ71Iu46bLm44,4486
|
9
|
+
dataeval/data/_split.py,sha256=6Jtm_i__CcPtNE3eSeBdPxc7gn7Cp-GM7g9wJWFlVus,16761
|
10
|
+
dataeval/data/_targets.py,sha256=ws5d9wRiDkIuOV7GSAKNxzgSm6AWTgb0BFroQK5nAmM,3057
|
11
|
+
dataeval/data/selections/__init__.py,sha256=2m8ZB53wXzqLcqmc6p5atO6graB6ZyiRSNJFxf11X_g,613
|
12
|
+
dataeval/data/selections/_classbalance.py,sha256=7v8ApoL3X8eCZ6fGDNTehE_bZ1loaP3TlhsJLaICVWg,1458
|
13
|
+
dataeval/data/selections/_classfilter.py,sha256=rEeq959p_SLl_etS7pcM8ZxK4yzEYlYZAQ3FlcLV0R8,4330
|
14
|
+
dataeval/data/selections/_indices.py,sha256=RFsR9z10aM3N0gJSfKrukFpi-LkiQGXoOwXhmOQ5cpg,630
|
15
|
+
dataeval/data/selections/_limit.py,sha256=JG4GmEiNKt3sk4PbOUbBnGGzNlyz72H-kQrt8COMm4Y,512
|
16
|
+
dataeval/data/selections/_prioritize.py,sha256=yw51ZQk6FPvyC38M4_pS_Se2Dq0LDFcdDhfbsELzTZc,11306
|
17
|
+
dataeval/data/selections/_reverse.py,sha256=b67kNC43A5KpQOic5gifjo9HpJ7FMh4LFCrfovPiJ-M,368
|
18
|
+
dataeval/data/selections/_shuffle.py,sha256=gVz_2T4rlucq8Ytqz5jvmmZdTrZDaIv43jJbq97tLjQ,1173
|
19
|
+
dataeval/detectors/__init__.py,sha256=3Sg-XWlwr75zEEH3hZKA4nWMtGvaRlnfzTWvZG_Ak6U,189
|
20
|
+
dataeval/detectors/drift/__init__.py,sha256=gD8aY5PotS-S2ot7iB_z_zzSOjIbQLw5znFBNj0jtHE,646
|
21
|
+
dataeval/detectors/drift/_base.py,sha256=amGqzUAe8fU5qwM5lq1p8PCuhjGh9MHkdW1zeBF1LEE,7574
|
22
|
+
dataeval/detectors/drift/_cvm.py,sha256=cS33zWJmFY1fft1XcANcP2jSD5ou7TxvIU2AldhTynM,3004
|
23
|
+
dataeval/detectors/drift/_ks.py,sha256=uMc5-NA-lSV1IODrY8uJe87ll3uRJT_oXLJFXy95M1w,3186
|
24
|
+
dataeval/detectors/drift/_mmd.py,sha256=wHUy_vUafCikrZ_WX8qQXpxFwzw07-5zVutloR6hl1k,11589
|
25
|
+
dataeval/detectors/drift/_uncertainty.py,sha256=BHlykJ-r7TGLJxdPfoazXnoAJ1qVDzbk5HjAMdsnHz8,5847
|
26
|
+
dataeval/detectors/drift/updates.py,sha256=L1PnrPlIE1x6ujCc5mCwjcAZwadVTn-Zjb6MnTDvzJQ,2251
|
27
|
+
dataeval/detectors/linters/__init__.py,sha256=xn2zPwUcmsuf-Jd9uw6AVI11C9z1b1Y9fYtuFnXenZ0,404
|
28
|
+
dataeval/detectors/linters/duplicates.py,sha256=X5WSEvI_BHkLoXjkaHK6wTnSkx4IjpO_exMRjSlhc70,4963
|
29
|
+
dataeval/detectors/linters/outliers.py,sha256=D8A-Fov5iUrlU9xMX5Ht33FqUY8Lk5ulC6BlHbUoLwU,9048
|
30
|
+
dataeval/detectors/ood/__init__.py,sha256=juCYBDs7CQEAtMhnEpPqF6uTrOIH9kTBSuQ_GRw6a8o,283
|
31
|
+
dataeval/detectors/ood/ae.py,sha256=fTrUfFxv6xUqzKpwMC8rW3JrizA16M_bgzqLuBKMrS0,2944
|
32
|
+
dataeval/detectors/ood/base.py,sha256=9b-Ljznf0lB1SXF4F_Aj3eJ4Y3ijGEDPMjucUsWOGJM,3051
|
33
|
+
dataeval/detectors/ood/mixin.py,sha256=0_o-1HPvgf3-Lf1MSOIfjj5UB8LTLEBGYtJJfyCCzwc,5431
|
34
|
+
dataeval/detectors/ood/vae.py,sha256=Fcq0-WbLhzYCgYOAJPBklHm7yuXmFJuEpBkhgwM5kiA,2291
|
35
|
+
dataeval/metadata/__init__.py,sha256=XDDmJbOZBNM6pL0r6Nbu6oMRoyAh22IDkPYGndNlkZU,316
|
36
|
+
dataeval/metadata/_distance.py,sha256=T1Umju_QwBiLmn1iUbxZagzBS2VnHaDIdp6j-NpaZuk,4076
|
37
|
+
dataeval/metadata/_ood.py,sha256=lnKtKModArnUrAhH_XswEtUAhUkh1U_oNsLt1UmNP44,12748
|
38
|
+
dataeval/metadata/_utils.py,sha256=r8qBJT83RblobD5W5zyTVi6vYi51Dwkqswizdbzss-M,1169
|
39
|
+
dataeval/metrics/__init__.py,sha256=8VC8q3HuJN3o_WN51Ae2_wXznl3RMXIvA5GYVcy7vr8,225
|
40
|
+
dataeval/metrics/bias/__init__.py,sha256=329S1_3WnWqeU4-qVcbe0fMy4lDrj9uKslWHIQf93yg,839
|
41
|
+
dataeval/metrics/bias/_balance.py,sha256=l1hTVkVwD85bP20MTthA-I5BkvbytylQkJu3Q6iTuPA,6152
|
42
|
+
dataeval/metrics/bias/_completeness.py,sha256=BysXU2Jpw33n5dl3acJFEqF3mFGiJLsfG4n5Q2fkTaY,4608
|
43
|
+
dataeval/metrics/bias/_coverage.py,sha256=PeUoOiaghUEdn6Ov8z2-am7-fnBVIPcFbJK7Ty5JObA,3647
|
44
|
+
dataeval/metrics/bias/_diversity.py,sha256=B_qWVDMZfh818U0qVm8yidquB0H0XvW8N75OWVWXy2g,5814
|
45
|
+
dataeval/metrics/bias/_parity.py,sha256=ea1D-eJh6cJxQ11XD6VbDXBKecE0jJJwptGD7LQJmBw,11529
|
46
|
+
dataeval/metrics/estimators/__init__.py,sha256=Pnds8uIyAovt2fKqZjiHCIP_kVoBWlVllekYuK5UmmU,568
|
47
|
+
dataeval/metrics/estimators/_ber.py,sha256=C30E5LiGGTAfo31zWFYDptDg0R7CTJGJ-a60YgzSkYY,5382
|
48
|
+
dataeval/metrics/estimators/_clusterer.py,sha256=1HrpihGTJ63IkNSOy4Ibw633Gllkm1RxKmoKT5MOgt0,1434
|
49
|
+
dataeval/metrics/estimators/_divergence.py,sha256=QDWl1lyAYoO9D3Ho7qOHSk6ud8Gi2MGuXEsYwO1HxvA,4043
|
50
|
+
dataeval/metrics/estimators/_uap.py,sha256=BULEBbJ9BQ1IcTeZf0x7iI60QHAWCccBOM97FIu9VXA,1928
|
51
|
+
dataeval/metrics/stats/__init__.py,sha256=6tA_9nbbM5ObJ6cds8Y1VBtTQiTOxrpGQSFLu_lWGGA,1098
|
52
|
+
dataeval/metrics/stats/_base.py,sha256=YIfOVGd7E19B4dpAnzDYRQkaikvRRyJIpznJNfVtPdw,10750
|
53
|
+
dataeval/metrics/stats/_boxratiostats.py,sha256=8Kd2FTZ5PLNYZfdAjU_R385gb0Z16JY0L9H_d5ZhgQs,6341
|
54
|
+
dataeval/metrics/stats/_dimensionstats.py,sha256=73mFP-Myxne0peFliwvTntc0kk4cpq0krzMvSLDSIMM,2702
|
55
|
+
dataeval/metrics/stats/_hashstats.py,sha256=gp9X_pnTT3mPH9YNrWLdn2LQPK_epJ3dQRoyOCwmKlg,4758
|
56
|
+
dataeval/metrics/stats/_imagestats.py,sha256=gUPNgN5Zwzdr7WnSwbve1NXNsyxd5dy3cSnlR_7guCg,3007
|
57
|
+
dataeval/metrics/stats/_labelstats.py,sha256=lz8I6eSd8tFkmQqy5cOG8hn9yxs0mP-Ic9ratFHiuoU,2813
|
58
|
+
dataeval/metrics/stats/_pixelstats.py,sha256=SfergRbjNJE4h0xqe-0c8RnKtZmEkZ9MwExdipLSGvg,3247
|
59
|
+
dataeval/metrics/stats/_visualstats.py,sha256=cq4AbF2B50Ihbzb86FphcnKQ1TSwNnP3PsnbpiPQZWw,3698
|
60
|
+
dataeval/outputs/__init__.py,sha256=ciK-RdXgtn_s7MSCUW1UXvrXltMbltqbpfe9_V7xGrI,1701
|
61
|
+
dataeval/outputs/_base.py,sha256=aZFbgybnZSQ3ws7QYRLTbDFqUfBFRVtIwX2LZfeGFUA,5703
|
62
|
+
dataeval/outputs/_bias.py,sha256=7L-d3DUWY6Vud7iX_VoQT0HG0KaV1U35gvmRApqzyB0,12401
|
63
|
+
dataeval/outputs/_drift.py,sha256=gOiu2C-ERTWiRqlP0auMYxPBGdm9HecWPqWfg7I4tZg,2015
|
64
|
+
dataeval/outputs/_estimators.py,sha256=a2oAIxxEDZ9WLGfMWH8KD-BVUS_SnULRPR-iI9hFPoQ,3047
|
65
|
+
dataeval/outputs/_linters.py,sha256=YOdjrfm8ypdRrqYOaPM9nc6wVJI3-ita3Haj7LHDNaw,6416
|
66
|
+
dataeval/outputs/_metadata.py,sha256=ffZgpX8KWURPHXpOWjbvJ2KRqWQkS2nWuIjKUzoHhMI,1710
|
67
|
+
dataeval/outputs/_ood.py,sha256=suLKVXULGtXH0rq9eXHI1d3d2jhGmItJtz4QiQd47A4,1718
|
68
|
+
dataeval/outputs/_stats.py,sha256=c73Yc3Kkrl-MN6BGKe1V0Yr6Ix2Yp_DZZfFSp8fZMZ0,13180
|
69
|
+
dataeval/outputs/_utils.py,sha256=HHlGC7sk416m_3Bgn075Qdblz_aPup_UOafJpB0RuXY,893
|
70
|
+
dataeval/outputs/_workflows.py,sha256=MkRD6ubI4NCBXb9v3kjXy64cUGs3G-JKkBdOpRD9XVE,10750
|
71
|
+
dataeval/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
72
|
+
dataeval/typing.py,sha256=GDMuef-oFFukNtsiKFmsExHdNvYR_j-tQcsCwZ9reow,7198
|
73
|
+
dataeval/utils/__init__.py,sha256=hRvyUK7b3d6JBEV5u47rFcOHEcmDYqAvZQw_T5pDAWw,264
|
74
|
+
dataeval/utils/_array.py,sha256=KqAdXEMjcXYvdWdYEEoEbigwQJ4S9VYxQS3sRFeY5XY,5929
|
75
|
+
dataeval/utils/_bin.py,sha256=nylthmsC3vzLHLhlUMACvZs--h7xvAh9Pt75InaQJW8,7322
|
76
|
+
dataeval/utils/_clusterer.py,sha256=fw5x-2QN0TIbiodDKHZxRgxKHINedpPcOklzce0Rbjg,5436
|
77
|
+
dataeval/utils/_fast_mst.py,sha256=4_7ykVihCL5jWtxcGnrecIsDQo65kUml9SZ1JxgBZYY,7172
|
78
|
+
dataeval/utils/_image.py,sha256=capzF_X5H0jy0PmTP3Hf52GFgLqrnfU6gS4tiwck9jo,1939
|
79
|
+
dataeval/utils/_method.py,sha256=9B9JQbgqWJBRhQJb7glajUtWaQzUTIUuvrZ9_bisxsM,394
|
80
|
+
dataeval/utils/_mst.py,sha256=f0vXytTUjlOS6AyL7c6PkXmaHuuGUK-vMLpq-5xMgxk,2183
|
81
|
+
dataeval/utils/_plot.py,sha256=mTRQNbJsA42QMiOwZbJaH8sNYgP996QFDEGVVE9HSgY,7076
|
82
|
+
dataeval/utils/data/__init__.py,sha256=xGzrjrOxOP2DP1tU84AWMKPnSxFvSjM81CTlDg4rNM8,331
|
83
|
+
dataeval/utils/data/_dataset.py,sha256=MHY582yRm4FxQkkLWUhKZBb7ZyvWypM6ldUG89vd3uE,7936
|
84
|
+
dataeval/utils/data/collate.py,sha256=5egEEKhNNCGeNLChO1p6dZ4Wg6x51VEaMNHz7hEZUxI,3936
|
85
|
+
dataeval/utils/data/metadata.py,sha256=1XeGYj_e97-nJ_IrWEHPhWICmouYU5qbXWbp7uhZrIE,14171
|
86
|
+
dataeval/utils/datasets/__init__.py,sha256=Jfe7XI_9U5S4wuI_2QCoeuWNOxz4j0nAQvxc5wG5mWY,486
|
87
|
+
dataeval/utils/datasets/_base.py,sha256=TpmgPzF3EShCLAF5S4Zf9lFN78q17bTZF6AUE1qKdlk,8857
|
88
|
+
dataeval/utils/datasets/_cifar10.py,sha256=oSX5JEzbBM4zGC9kC7-hVTOglms3rYaUuYiA00_DUJ4,5439
|
89
|
+
dataeval/utils/datasets/_fileio.py,sha256=SixIk5nIlIwJdX9zjNXS10vHA3hL8aaYbqHsDg1xSpY,6447
|
90
|
+
dataeval/utils/datasets/_milco.py,sha256=BF2XvyzuOop1mg5pFZcRfYmZcezlbpZWHyd_TtEHFF4,7573
|
91
|
+
dataeval/utils/datasets/_mixin.py,sha256=FJgZP_cpJkgAHA3j3ai_j3Wt7aFSEjIMVmt9NpvVXzg,1757
|
92
|
+
dataeval/utils/datasets/_mnist.py,sha256=4WOkQTORYMs6KEeyyJgChTnH03797y4ezgaZtYqplh4,8102
|
93
|
+
dataeval/utils/datasets/_ships.py,sha256=RMdX2KlnXJYOTzBb6euA5TAqxs-S8b56pAGiyQhNMuo,4870
|
94
|
+
dataeval/utils/datasets/_types.py,sha256=iSKyHXRlGuomXs0FHK6md8lXLQrQQ4fxgVOwr4o81bo,1089
|
95
|
+
dataeval/utils/datasets/_voc.py,sha256=kif6ms_romK6VElP4pf2SK4cJ5dEHDOkxSaSaeP3c5k,15565
|
96
|
+
dataeval/utils/torch/__init__.py,sha256=dn5mjCrFp0b1aL_UEURhONU0Ag0cmXoTOBSGagpkTiA,325
|
97
|
+
dataeval/utils/torch/_blocks.py,sha256=HVhBTMMD5NA4qheMUgyol1KWiKZDIuc8k5j4RcMKmhk,1466
|
98
|
+
dataeval/utils/torch/_gmm.py,sha256=XM68GNEP97EjaB1U49-ZXRb81d0CEFnPS910alrcB3g,3740
|
99
|
+
dataeval/utils/torch/_internal.py,sha256=vHy-DzPhmvE8h3wmWc3aciBJ8nDGzQ1z1jTZgGjmDyM,4154
|
100
|
+
dataeval/utils/torch/models.py,sha256=hmroEs6C6jQ5tAoZa71RFeIvXLxfXrTJSFH_jG2LGQU,9749
|
101
|
+
dataeval/utils/torch/trainer.py,sha256=iUotX4OdirH8-ZtjdpU8gbJavkYW9YY9qpA2mAlFy1Y,5520
|
102
|
+
dataeval/workflows/__init__.py,sha256=ou8y0KO-d6W5lgmcyLjKlf-J_ckP3vilW7wHkgiDlZ4,255
|
103
|
+
dataeval/workflows/sufficiency.py,sha256=mjKmfRrAjShLUFIARv5o8yT5fnFvDsS5Qu6ujIPUgQg,8497
|
104
|
+
dataeval-1.0.0.dist-info/LICENSE.txt,sha256=uAooygKWvX6NbU9Ran9oG2msttoG8aeTeHSTe5JeCnY,1061
|
105
|
+
dataeval-1.0.0.dist-info/METADATA,sha256=ma_TquWQQl0QETiK4-wH1jfAe2my33Cl37GswNe0ZM8,5307
|
106
|
+
dataeval-1.0.0.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
107
|
+
dataeval-1.0.0.dist-info/RECORD,,
|
@@ -1,222 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
Source code derived from Alibi-Detect 0.11.4
|
3
|
-
https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
|
4
|
-
|
5
|
-
Original code Copyright (c) 2023 Seldon Technologies Ltd
|
6
|
-
Licensed under Apache Software License (Apache 2.0)
|
7
|
-
"""
|
8
|
-
|
9
|
-
from __future__ import annotations
|
10
|
-
|
11
|
-
__all__ = []
|
12
|
-
|
13
|
-
from typing import Any, Callable
|
14
|
-
|
15
|
-
import numpy as np
|
16
|
-
import torch
|
17
|
-
import torch.nn as nn
|
18
|
-
from numpy.typing import NDArray
|
19
|
-
|
20
|
-
from dataeval.config import DeviceLike, get_device
|
21
|
-
from dataeval.utils.torch._internal import predict_batch
|
22
|
-
|
23
|
-
|
24
|
-
def mmd2_from_kernel_matrix(
|
25
|
-
kernel_mat: torch.Tensor, m: int, permute: bool = False, zero_diag: bool = True
|
26
|
-
) -> torch.Tensor:
|
27
|
-
"""
|
28
|
-
Compute maximum mean discrepancy (MMD^2) between 2 samples x and y from the
|
29
|
-
full kernel matrix between the samples.
|
30
|
-
|
31
|
-
Parameters
|
32
|
-
----------
|
33
|
-
kernel_mat : torch.Tensor
|
34
|
-
Kernel matrix between samples x and y.
|
35
|
-
m : int
|
36
|
-
Number of instances in y.
|
37
|
-
permute : bool, default False
|
38
|
-
Whether to permute the row indices. Used for permutation tests.
|
39
|
-
zero_diag : bool, default True
|
40
|
-
Whether to zero out the diagonal of the kernel matrix.
|
41
|
-
|
42
|
-
Returns
|
43
|
-
-------
|
44
|
-
torch.Tensor
|
45
|
-
MMD^2 between the samples from the kernel matrix.
|
46
|
-
"""
|
47
|
-
n = kernel_mat.shape[0] - m
|
48
|
-
if zero_diag:
|
49
|
-
kernel_mat = kernel_mat - torch.diag(kernel_mat.diag())
|
50
|
-
if permute:
|
51
|
-
idx = torch.randperm(kernel_mat.shape[0])
|
52
|
-
kernel_mat = kernel_mat[idx][:, idx]
|
53
|
-
k_xx, k_yy, k_xy = kernel_mat[:-m, :-m], kernel_mat[-m:, -m:], kernel_mat[-m:, :-m]
|
54
|
-
c_xx, c_yy = 1 / (n * (n - 1)), 1 / (m * (m - 1))
|
55
|
-
mmd2 = c_xx * k_xx.sum() + c_yy * k_yy.sum() - 2.0 * k_xy.mean()
|
56
|
-
return mmd2
|
57
|
-
|
58
|
-
|
59
|
-
def preprocess_drift(
|
60
|
-
x: NDArray[Any],
|
61
|
-
model: nn.Module,
|
62
|
-
device: DeviceLike | None = None,
|
63
|
-
preprocess_batch_fn: Callable | None = None,
|
64
|
-
batch_size: int = int(1e10),
|
65
|
-
dtype: type[np.generic] | torch.dtype = np.float32,
|
66
|
-
) -> NDArray[Any] | torch.Tensor | tuple[Any, ...]:
|
67
|
-
"""
|
68
|
-
Prediction function used for preprocessing step of drift detector.
|
69
|
-
|
70
|
-
Parameters
|
71
|
-
----------
|
72
|
-
x : NDArray
|
73
|
-
Batch of instances.
|
74
|
-
model : nn.Module
|
75
|
-
Model used for preprocessing.
|
76
|
-
device : DeviceLike or None, default None
|
77
|
-
The hardware device to use if specified, otherwise uses the DataEval
|
78
|
-
default or torch default.
|
79
|
-
preprocess_batch_fn : Callable or None, default None
|
80
|
-
Optional batch preprocessing function. For example to convert a list of objects
|
81
|
-
to a batch which can be processed by the PyTorch model.
|
82
|
-
batch_size : int, default 1e10
|
83
|
-
Batch size used during prediction.
|
84
|
-
dtype : np.dtype or torch.dtype, default np.float32
|
85
|
-
Model output type, either a :term:`NumPy` or torch dtype, e.g. np.float32 or torch.float32.
|
86
|
-
|
87
|
-
Returns
|
88
|
-
-------
|
89
|
-
NDArray | torch.Tensor | tuple
|
90
|
-
Numpy array, torch tensor or tuples of those with model outputs.
|
91
|
-
"""
|
92
|
-
return predict_batch(
|
93
|
-
x,
|
94
|
-
model,
|
95
|
-
device=get_device(device),
|
96
|
-
batch_size=batch_size,
|
97
|
-
preprocess_fn=preprocess_batch_fn,
|
98
|
-
dtype=dtype,
|
99
|
-
)
|
100
|
-
|
101
|
-
|
102
|
-
@torch.jit.script
|
103
|
-
def _squared_pairwise_distance(
|
104
|
-
x: torch.Tensor, y: torch.Tensor, a_min: float = 1e-30
|
105
|
-
) -> torch.Tensor: # pragma: no cover - torch.jit.script code is compiled and copied
|
106
|
-
"""
|
107
|
-
PyTorch pairwise squared Euclidean distance between samples x and y.
|
108
|
-
|
109
|
-
Parameters
|
110
|
-
----------
|
111
|
-
x : torch.Tensor
|
112
|
-
Batch of instances of shape [Nx, features].
|
113
|
-
y : torch.Tensor
|
114
|
-
Batch of instances of shape [Ny, features].
|
115
|
-
a_min : float
|
116
|
-
Lower bound to clip distance values.
|
117
|
-
|
118
|
-
Returns
|
119
|
-
-------
|
120
|
-
torch.Tensor
|
121
|
-
Pairwise squared Euclidean distance [Nx, Ny].
|
122
|
-
"""
|
123
|
-
x2 = x.pow(2).sum(dim=-1, keepdim=True)
|
124
|
-
y2 = y.pow(2).sum(dim=-1, keepdim=True)
|
125
|
-
dist = torch.addmm(y2.transpose(-2, -1), x, y.transpose(-2, -1), alpha=-2).add_(x2)
|
126
|
-
return dist.clamp_min_(a_min)
|
127
|
-
|
128
|
-
|
129
|
-
def sigma_median(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.Tensor:
|
130
|
-
"""
|
131
|
-
Bandwidth estimation using the median heuristic `Gretton2012`
|
132
|
-
|
133
|
-
Parameters
|
134
|
-
----------
|
135
|
-
x : torch.Tensor
|
136
|
-
Tensor of instances with dimension [Nx, features].
|
137
|
-
y : torch.Tensor
|
138
|
-
Tensor of instances with dimension [Ny, features].
|
139
|
-
dist : torch.Tensor
|
140
|
-
Tensor with dimensions [Nx, Ny], containing the pairwise distances
|
141
|
-
between `x` and `y`.
|
142
|
-
|
143
|
-
Returns
|
144
|
-
-------
|
145
|
-
torch.Tensor
|
146
|
-
The computed bandwidth, `sigma`.
|
147
|
-
"""
|
148
|
-
n = min(x.shape[0], y.shape[0])
|
149
|
-
n = n if (x[:n] == y[:n]).all() and x.shape == y.shape else 0
|
150
|
-
n_median = n + (np.prod(dist.shape) - n) // 2 - 1
|
151
|
-
sigma = (0.5 * dist.flatten().sort().values[int(n_median)].unsqueeze(dim=-1)) ** 0.5
|
152
|
-
return sigma
|
153
|
-
|
154
|
-
|
155
|
-
class GaussianRBF(nn.Module):
|
156
|
-
"""
|
157
|
-
Gaussian RBF kernel: k(x,y) = exp(-(1/(2*sigma^2)||x-y||^2).
|
158
|
-
|
159
|
-
A forward pass takes a batch of instances x [Nx, features] and
|
160
|
-
y [Ny, features] and returns the kernel matrix [Nx, Ny].
|
161
|
-
|
162
|
-
Parameters
|
163
|
-
----------
|
164
|
-
sigma : torch.Tensor | None, default None
|
165
|
-
Bandwidth used for the kernel. Needn't be specified if being inferred or
|
166
|
-
trained. Can pass multiple values to eval kernel with and then average.
|
167
|
-
init_sigma_fn : Callable | None, default None
|
168
|
-
Function used to compute the bandwidth ``sigma``. Used when ``sigma`` is to be
|
169
|
-
inferred. The function's signature should take in the tensors ``x``, ``y`` and
|
170
|
-
``dist`` and return ``sigma``. If ``None``, it is set to ``sigma_median``.
|
171
|
-
trainable : bool, default False
|
172
|
-
Whether or not to track gradients w.r.t. `sigma` to allow it to be trained.
|
173
|
-
"""
|
174
|
-
|
175
|
-
def __init__(
|
176
|
-
self,
|
177
|
-
sigma: torch.Tensor | None = None,
|
178
|
-
init_sigma_fn: Callable | None = None,
|
179
|
-
trainable: bool = False,
|
180
|
-
) -> None:
|
181
|
-
super().__init__()
|
182
|
-
init_sigma_fn = sigma_median if init_sigma_fn is None else init_sigma_fn
|
183
|
-
self.config: dict[str, Any] = {
|
184
|
-
"sigma": sigma,
|
185
|
-
"trainable": trainable,
|
186
|
-
"init_sigma_fn": init_sigma_fn,
|
187
|
-
}
|
188
|
-
if sigma is None:
|
189
|
-
self.log_sigma: nn.Parameter = nn.Parameter(torch.empty(1), requires_grad=trainable)
|
190
|
-
self.init_required: bool = True
|
191
|
-
else:
|
192
|
-
sigma = sigma.reshape(-1) # [Ns,]
|
193
|
-
self.log_sigma: nn.Parameter = nn.Parameter(sigma.log(), requires_grad=trainable)
|
194
|
-
self.init_required: bool = False
|
195
|
-
self.init_sigma_fn = init_sigma_fn
|
196
|
-
self.trainable = trainable
|
197
|
-
|
198
|
-
@property
|
199
|
-
def sigma(self) -> torch.Tensor:
|
200
|
-
return self.log_sigma.exp()
|
201
|
-
|
202
|
-
def forward(
|
203
|
-
self,
|
204
|
-
x: np.ndarray[Any, Any] | torch.Tensor,
|
205
|
-
y: np.ndarray[Any, Any] | torch.Tensor,
|
206
|
-
infer_sigma: bool = False,
|
207
|
-
) -> torch.Tensor:
|
208
|
-
x, y = torch.as_tensor(x), torch.as_tensor(y)
|
209
|
-
dist = _squared_pairwise_distance(x.flatten(1), y.flatten(1)) # [Nx, Ny]
|
210
|
-
|
211
|
-
if infer_sigma or self.init_required:
|
212
|
-
if self.trainable and infer_sigma:
|
213
|
-
raise ValueError("Gradients cannot be computed w.r.t. an inferred sigma value")
|
214
|
-
sigma = self.init_sigma_fn(x, y, dist)
|
215
|
-
with torch.no_grad():
|
216
|
-
self.log_sigma.copy_(sigma.log().clone())
|
217
|
-
self.init_required: bool = False
|
218
|
-
|
219
|
-
gamma = 1.0 / (2.0 * self.sigma**2) # [Ns,]
|
220
|
-
# TODO: do matrix multiplication after all?
|
221
|
-
kernel_mat = torch.exp(-torch.cat([(g * dist)[None, :, :] for g in gamma], dim=0)) # [Ns, Nx, Ny]
|
222
|
-
return kernel_mat.mean(dim=0) # [Nx, Ny]
|
@@ -1,186 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
__all__ = []
|
4
|
-
|
5
|
-
import math
|
6
|
-
from typing import Any, Iterator, Sequence, cast
|
7
|
-
|
8
|
-
import torch
|
9
|
-
from torch.utils.data import DataLoader, Subset
|
10
|
-
from tqdm import tqdm
|
11
|
-
|
12
|
-
from dataeval.config import DeviceLike, get_device
|
13
|
-
from dataeval.typing import Array, Dataset, Transform
|
14
|
-
from dataeval.utils.torch.models import SupportsEncode
|
15
|
-
|
16
|
-
|
17
|
-
class Embeddings:
|
18
|
-
"""
|
19
|
-
Collection of image embeddings from a dataset.
|
20
|
-
|
21
|
-
Embeddings are accessed by index or slice and are only loaded on-demand.
|
22
|
-
|
23
|
-
Parameters
|
24
|
-
----------
|
25
|
-
dataset : ImageClassificationDataset or ObjectDetectionDataset
|
26
|
-
Dataset to access original images from.
|
27
|
-
batch_size : int
|
28
|
-
Batch size to use when encoding images.
|
29
|
-
transforms : Transform or Sequence[Transform] or None, default None
|
30
|
-
Transforms to apply to images before encoding.
|
31
|
-
model : torch.nn.Module or None, default None
|
32
|
-
Model to use for encoding images.
|
33
|
-
device : DeviceLike or None, default None
|
34
|
-
The hardware device to use if specified, otherwise uses the DataEval
|
35
|
-
default or torch default.
|
36
|
-
cache : bool, default False
|
37
|
-
Whether to cache the embeddings in memory.
|
38
|
-
verbose : bool, default False
|
39
|
-
Whether to print progress bar when encoding images.
|
40
|
-
"""
|
41
|
-
|
42
|
-
device: torch.device
|
43
|
-
batch_size: int
|
44
|
-
verbose: bool
|
45
|
-
|
46
|
-
def __init__(
|
47
|
-
self,
|
48
|
-
dataset: Dataset[tuple[Array, Any, Any]],
|
49
|
-
batch_size: int,
|
50
|
-
transforms: Transform[torch.Tensor] | Sequence[Transform[torch.Tensor]] | None = None,
|
51
|
-
model: torch.nn.Module | None = None,
|
52
|
-
device: DeviceLike | None = None,
|
53
|
-
cache: bool = False,
|
54
|
-
verbose: bool = False,
|
55
|
-
) -> None:
|
56
|
-
self.device = get_device(device)
|
57
|
-
self.cache = cache
|
58
|
-
self.batch_size = batch_size if batch_size > 0 else 1
|
59
|
-
self.verbose = verbose
|
60
|
-
|
61
|
-
self._dataset = dataset
|
62
|
-
self._length = len(dataset)
|
63
|
-
model = torch.nn.Flatten() if model is None else model
|
64
|
-
self._transforms = [transforms] if isinstance(transforms, Transform) else transforms
|
65
|
-
self._model = model.to(self.device).eval()
|
66
|
-
self._encoder = model.encode if isinstance(model, SupportsEncode) else model
|
67
|
-
self._collate_fn = lambda datum: [torch.as_tensor(i) for i, _, _ in datum]
|
68
|
-
self._cached_idx = set()
|
69
|
-
self._embeddings: torch.Tensor = torch.empty(())
|
70
|
-
self._shallow: bool = False
|
71
|
-
|
72
|
-
def to_tensor(self, indices: Sequence[int] | None = None) -> torch.Tensor:
|
73
|
-
"""
|
74
|
-
Converts dataset to embeddings.
|
75
|
-
|
76
|
-
Parameters
|
77
|
-
----------
|
78
|
-
indices : Sequence[int] or None, default None
|
79
|
-
The indices to convert to embeddings
|
80
|
-
|
81
|
-
Returns
|
82
|
-
-------
|
83
|
-
torch.Tensor
|
84
|
-
|
85
|
-
Warning
|
86
|
-
-------
|
87
|
-
Processing large quantities of data can be resource intensive.
|
88
|
-
"""
|
89
|
-
if indices is not None:
|
90
|
-
return torch.vstack(list(self._batch(indices))).to(self.device)
|
91
|
-
else:
|
92
|
-
return self[:]
|
93
|
-
|
94
|
-
@classmethod
|
95
|
-
def from_array(cls, array: Array, device: DeviceLike | None = None) -> Embeddings:
|
96
|
-
"""
|
97
|
-
Instantiates a shallow Embeddings object using an array.
|
98
|
-
|
99
|
-
Parameters
|
100
|
-
----------
|
101
|
-
array : Array
|
102
|
-
The array to convert to embeddings.
|
103
|
-
device : DeviceLike or None, default None
|
104
|
-
The hardware device to use if specified, otherwise uses the DataEval
|
105
|
-
default or torch default.
|
106
|
-
|
107
|
-
Returns
|
108
|
-
-------
|
109
|
-
Embeddings
|
110
|
-
|
111
|
-
Example
|
112
|
-
-------
|
113
|
-
>>> import numpy as np
|
114
|
-
>>> from dataeval.utils.data._embeddings import Embeddings
|
115
|
-
>>> array = np.random.randn(100, 3, 224, 224)
|
116
|
-
>>> embeddings = Embeddings.from_array(array)
|
117
|
-
>>> print(embeddings.to_tensor().shape)
|
118
|
-
torch.Size([100, 3, 224, 224])
|
119
|
-
"""
|
120
|
-
embeddings = Embeddings([], 0, None, None, device, True, False)
|
121
|
-
embeddings._length = len(array)
|
122
|
-
embeddings._cached_idx = set(range(len(array)))
|
123
|
-
embeddings._embeddings = torch.as_tensor(array).to(get_device(device))
|
124
|
-
embeddings._shallow = True
|
125
|
-
return embeddings
|
126
|
-
|
127
|
-
def _encode(self, images: list[torch.Tensor]) -> torch.Tensor:
|
128
|
-
if self._transforms:
|
129
|
-
images = [transform(image) for transform in self._transforms for image in images]
|
130
|
-
return self._encoder(torch.stack(images).to(self.device))
|
131
|
-
|
132
|
-
@torch.no_grad() # Reduce overhead cost by not tracking tensor gradients
|
133
|
-
def _batch(self, indices: Sequence[int]) -> Iterator[torch.Tensor]:
|
134
|
-
dataset = cast(torch.utils.data.Dataset[tuple[Array, Any, Any]], self._dataset)
|
135
|
-
total_batches = math.ceil(len(indices) / self.batch_size)
|
136
|
-
|
137
|
-
# If not caching, process all indices normally
|
138
|
-
if not self.cache:
|
139
|
-
for images in tqdm(
|
140
|
-
DataLoader(Subset(dataset, indices), self.batch_size, collate_fn=self._collate_fn),
|
141
|
-
total=total_batches,
|
142
|
-
desc="Batch embedding",
|
143
|
-
disable=not self.verbose,
|
144
|
-
):
|
145
|
-
yield self._encode(images)
|
146
|
-
return
|
147
|
-
|
148
|
-
# If caching, process each batch of indices at a time, preserving original order
|
149
|
-
for i in tqdm(range(0, len(indices), self.batch_size), desc="Batch embedding", disable=not self.verbose):
|
150
|
-
batch = indices[i : i + self.batch_size]
|
151
|
-
uncached = [idx for idx in batch if idx not in self._cached_idx]
|
152
|
-
|
153
|
-
if uncached:
|
154
|
-
# Process uncached indices as as single batch
|
155
|
-
for images in DataLoader(Subset(dataset, uncached), len(uncached), collate_fn=self._collate_fn):
|
156
|
-
embeddings = self._encode(images)
|
157
|
-
|
158
|
-
if not self._embeddings.shape:
|
159
|
-
full_shape = (len(self._dataset), *embeddings.shape[1:])
|
160
|
-
self._embeddings = torch.empty(full_shape, dtype=embeddings.dtype, device=self.device)
|
161
|
-
|
162
|
-
self._embeddings[uncached] = embeddings
|
163
|
-
self._cached_idx.update(uncached)
|
164
|
-
|
165
|
-
yield self._embeddings[batch]
|
166
|
-
|
167
|
-
def __getitem__(self, key: int | slice, /) -> torch.Tensor:
|
168
|
-
if not isinstance(key, slice) and not hasattr(key, "__int__"):
|
169
|
-
raise TypeError("Invalid argument type.")
|
170
|
-
|
171
|
-
if self._shallow:
|
172
|
-
if not self._embeddings.shape:
|
173
|
-
raise ValueError("Embeddings not initialized.")
|
174
|
-
return self._embeddings[key]
|
175
|
-
|
176
|
-
indices = list(range(len(self._dataset))[key]) if isinstance(key, slice) else [int(key)]
|
177
|
-
result = torch.vstack(list(self._batch(indices))).to(self.device)
|
178
|
-
return result.squeeze(0) if len(indices) == 1 else result
|
179
|
-
|
180
|
-
def __iter__(self) -> Iterator[torch.Tensor]:
|
181
|
-
# process in batches while yielding individual embeddings
|
182
|
-
for batch in self._batch(range(self._length)):
|
183
|
-
yield from batch
|
184
|
-
|
185
|
-
def __len__(self) -> int:
|
186
|
-
return self._length
|
@@ -1,17 +0,0 @@
|
|
1
|
-
"""Provides access to common Computer Vision datasets."""
|
2
|
-
|
3
|
-
from dataeval.utils.data.datasets._cifar10 import CIFAR10
|
4
|
-
from dataeval.utils.data.datasets._milco import MILCO
|
5
|
-
from dataeval.utils.data.datasets._mnist import MNIST
|
6
|
-
from dataeval.utils.data.datasets._ships import Ships
|
7
|
-
from dataeval.utils.data.datasets._voc import VOCDetection, VOCDetectionTorch, VOCSegmentation
|
8
|
-
|
9
|
-
__all__ = [
|
10
|
-
"MNIST",
|
11
|
-
"Ships",
|
12
|
-
"CIFAR10",
|
13
|
-
"MILCO",
|
14
|
-
"VOCDetection",
|
15
|
-
"VOCDetectionTorch",
|
16
|
-
"VOCSegmentation",
|
17
|
-
]
|
@@ -1,17 +0,0 @@
|
|
1
|
-
"""Provides selection classes for selecting subsets of Computer Vision datasets."""
|
2
|
-
|
3
|
-
__all__ = [
|
4
|
-
"ClassFilter",
|
5
|
-
"Indices",
|
6
|
-
"Limit",
|
7
|
-
"Prioritize",
|
8
|
-
"Reverse",
|
9
|
-
"Shuffle",
|
10
|
-
]
|
11
|
-
|
12
|
-
from dataeval.utils.data.selections._classfilter import ClassFilter
|
13
|
-
from dataeval.utils.data.selections._indices import Indices
|
14
|
-
from dataeval.utils.data.selections._limit import Limit
|
15
|
-
from dataeval.utils.data.selections._prioritize import Prioritize
|
16
|
-
from dataeval.utils.data.selections._reverse import Reverse
|
17
|
-
from dataeval.utils.data.selections._shuffle import Shuffle
|