dataeval 0.84.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (67)
  1. dataeval/__init__.py +1 -1
  2. dataeval/data/__init__.py +19 -0
  3. dataeval/data/_embeddings.py +345 -0
  4. dataeval/{utils/data → data}/_images.py +2 -2
  5. dataeval/{utils/data → data}/_metadata.py +8 -7
  6. dataeval/{utils/data → data}/_selection.py +22 -9
  7. dataeval/{utils/data → data}/_split.py +1 -1
  8. dataeval/data/selections/__init__.py +19 -0
  9. dataeval/data/selections/_classbalance.py +37 -0
  10. dataeval/data/selections/_classfilter.py +109 -0
  11. dataeval/{utils/data → data}/selections/_indices.py +1 -1
  12. dataeval/{utils/data → data}/selections/_limit.py +1 -1
  13. dataeval/{utils/data → data}/selections/_prioritize.py +3 -3
  14. dataeval/{utils/data → data}/selections/_reverse.py +1 -1
  15. dataeval/{utils/data → data}/selections/_shuffle.py +3 -3
  16. dataeval/detectors/drift/__init__.py +2 -2
  17. dataeval/detectors/drift/_base.py +55 -203
  18. dataeval/detectors/drift/_cvm.py +19 -30
  19. dataeval/detectors/drift/_ks.py +18 -30
  20. dataeval/detectors/drift/_mmd.py +189 -53
  21. dataeval/detectors/drift/_uncertainty.py +52 -56
  22. dataeval/detectors/drift/updates.py +13 -12
  23. dataeval/detectors/linters/duplicates.py +6 -4
  24. dataeval/detectors/linters/outliers.py +3 -3
  25. dataeval/detectors/ood/ae.py +1 -1
  26. dataeval/metadata/_distance.py +1 -1
  27. dataeval/metadata/_ood.py +4 -4
  28. dataeval/metrics/bias/_balance.py +1 -1
  29. dataeval/metrics/bias/_diversity.py +1 -1
  30. dataeval/metrics/bias/_parity.py +1 -1
  31. dataeval/metrics/stats/_base.py +7 -7
  32. dataeval/metrics/stats/_dimensionstats.py +2 -2
  33. dataeval/metrics/stats/_hashstats.py +2 -2
  34. dataeval/metrics/stats/_imagestats.py +4 -4
  35. dataeval/metrics/stats/_labelstats.py +2 -2
  36. dataeval/metrics/stats/_pixelstats.py +2 -2
  37. dataeval/metrics/stats/_visualstats.py +2 -2
  38. dataeval/outputs/_bias.py +1 -1
  39. dataeval/typing.py +53 -19
  40. dataeval/utils/__init__.py +2 -2
  41. dataeval/utils/_array.py +18 -7
  42. dataeval/utils/data/__init__.py +5 -20
  43. dataeval/utils/data/_dataset.py +6 -4
  44. dataeval/utils/data/collate.py +2 -0
  45. dataeval/utils/datasets/__init__.py +17 -0
  46. dataeval/utils/{data/datasets → datasets}/_base.py +10 -7
  47. dataeval/utils/{data/datasets → datasets}/_cifar10.py +11 -11
  48. dataeval/utils/{data/datasets → datasets}/_milco.py +44 -16
  49. dataeval/utils/{data/datasets → datasets}/_mnist.py +11 -7
  50. dataeval/utils/{data/datasets → datasets}/_ships.py +10 -6
  51. dataeval/utils/{data/datasets → datasets}/_voc.py +43 -22
  52. dataeval/utils/torch/_internal.py +12 -35
  53. {dataeval-0.84.0.dist-info → dataeval-1.0.0.dist-info}/METADATA +2 -3
  54. dataeval-1.0.0.dist-info/RECORD +107 -0
  55. dataeval/detectors/drift/_torch.py +0 -222
  56. dataeval/utils/data/_embeddings.py +0 -186
  57. dataeval/utils/data/datasets/__init__.py +0 -17
  58. dataeval/utils/data/selections/__init__.py +0 -17
  59. dataeval/utils/data/selections/_classfilter.py +0 -59
  60. dataeval-0.84.0.dist-info/RECORD +0 -106
  61. /dataeval/{utils/data → data}/_targets.py +0 -0
  62. /dataeval/utils/{metadata.py → data/metadata.py} +0 -0
  63. /dataeval/utils/{data/datasets → datasets}/_fileio.py +0 -0
  64. /dataeval/utils/{data/datasets → datasets}/_mixin.py +0 -0
  65. /dataeval/utils/{data/datasets → datasets}/_types.py +0 -0
  66. {dataeval-0.84.0.dist-info → dataeval-1.0.0.dist-info}/LICENSE.txt +0 -0
  67. {dataeval-0.84.0.dist-info → dataeval-1.0.0.dist-info}/WHEEL +0 -0
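Taken together, the list shows the shape of the 1.0.0 restructure: the dataeval.utils.data core (images, metadata, selection, split, targets, embeddings) moves up to dataeval.data, and the bundled datasets move out to dataeval.utils.datasets; per-file migration sketches appear after the relevant diffs below. A hedged overview of the import changes (public re-exports from the new __init__ modules are assumed, not confirmed by this diff):

    # Hypothetical import migration inferred from the renames above; the exact
    # names re-exported by the new __init__ modules are assumptions.
    # 0.84.0
    # from dataeval.utils.data import Embeddings, Images, Metadata
    # from dataeval.utils.data.datasets import MNIST
    # 1.0.0
    from dataeval.data import Embeddings, Images, Metadata
    from dataeval.utils.datasets import MNIST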
dataeval/utils/torch/_internal.py

@@ -2,7 +2,6 @@ from __future__ import annotations

 __all__ = []

-from functools import partial
 from typing import Any, Callable

 import numpy as np
@@ -12,16 +11,16 @@ from torch.utils.data import DataLoader, TensorDataset
 from tqdm import tqdm

 from dataeval.config import DeviceLike, get_device
+from dataeval.typing import Array


 def predict_batch(
-    x: NDArray[Any] | torch.Tensor,
-    model: Callable | torch.nn.Module | torch.nn.Sequential,
+    x: Array,
+    model: torch.nn.Module,
     device: DeviceLike | None = None,
     batch_size: int = int(1e10),
     preprocess_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
-    dtype: type[np.generic] | torch.dtype = np.float32,
-) -> NDArray[Any] | torch.Tensor | tuple[Any, ...]:
+) -> torch.Tensor:
     """
     Make batch predictions on a model.

@@ -29,7 +28,7 @@ def predict_batch(
     ----------
     x : np.ndarray | torch.Tensor
         Batch of instances.
-    model : Callable | nn.Module | nn.Sequential
+    model : nn.Module
         PyTorch model.
     device : DeviceLike or None, default None
         The hardware device to use if specified, otherwise uses the DataEval
@@ -38,21 +37,18 @@ def predict_batch(
         Batch size used during prediction.
     preprocess_fn : Callable | None, default None
         Optional preprocessing function for each batch.
-    dtype : np.dtype | torch.dtype, default np.float32
-        Model output type, either a :term:`NumPy` or torch dtype, e.g. np.float32 or torch.float32.

     Returns
     -------
-    NDArray | torch.Tensor | tuple
-        Numpy array, torch tensor or tuples of those with model outputs.
+    torch.Tensor
+        PyTorch tensor with model outputs.
     """
     device = get_device(device)
-    if isinstance(x, np.ndarray):
-        x = torch.tensor(x, device=device)
+    if isinstance(model, torch.nn.Module):
+        model = model.to(device).eval()
+    x = torch.tensor(x, device=device)
     n = len(x)
     n_minibatch = int(np.ceil(n / batch_size))
-    return_np = not isinstance(dtype, torch.dtype)
-    preds_tuple = None
     preds_array = []
     with torch.no_grad():
         for i in range(n_minibatch):
@@ -60,28 +56,9 @@
             x_batch = x[istart:istop]
             if isinstance(preprocess_fn, Callable):
                 x_batch = preprocess_fn(x_batch)
+            preds_array.append(model(x_batch.to(dtype=torch.float32)).cpu())

-            preds_tmp = model(x_batch.to(dtype=torch.float32))
-            if isinstance(preds_tmp, (list, tuple)):
-                if preds_tuple is None:  # init tuple with lists to store predictions
-                    preds_tuple = tuple([] for _ in range(len(preds_tmp)))
-                for j, p in enumerate(preds_tmp):
-                    p = p.cpu() if isinstance(p, torch.Tensor) else p
-                    preds_tuple[j].append(p if not return_np or isinstance(p, np.ndarray) else p.numpy())
-            elif isinstance(preds_tmp, (np.ndarray, torch.Tensor)):
-                preds_tmp = preds_tmp.cpu() if isinstance(preds_tmp, torch.Tensor) else preds_tmp
-                preds_array.append(
-                    preds_tmp if not return_np or isinstance(preds_tmp, np.ndarray) else preds_tmp.numpy()
-                )
-            else:
-                raise TypeError(
-                    f"Model output type {type(preds_tmp)} not supported. The model \
-                    output type needs to be one of list, tuple, NDArray or \
-                    torch.Tensor."
-                )
-    concat = partial(np.concatenate, axis=0) if return_np else partial(torch.cat, dim=0)
-    out = tuple(concat(p) for p in preds_tuple) if preds_tuple is not None else concat(preds_array)
-    return out
+    return torch.cat(preds_array, dim=0)


 def trainer(
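Note on the rewritten predict_batch: the 1.0.0 version drops NumPy dtype handling and tuple outputs, converts Array inputs internally, moves the model to the resolved device in eval mode, and always returns a torch.Tensor. A minimal usage sketch of the new contract (the toy model and input shapes are illustrative assumptions, not from the package):

    # Minimal sketch of the 1.0.0 predict_batch contract; the toy model and
    # shapes are illustrative assumptions, not package code.
    import numpy as np
    import torch

    from dataeval.utils.torch._internal import predict_batch

    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 4))
    x = np.random.randn(32, 3, 8, 8).astype(np.float32)  # converted to a tensor internally

    preds = predict_batch(x, model, batch_size=8)  # model is moved to device and set to eval()
    print(type(preds), preds.shape)  # <class 'torch.Tensor'> torch.Size([32, 4])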
{dataeval-0.84.0.dist-info → dataeval-1.0.0.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: dataeval
-Version: 0.84.0
+Version: 1.0.0
 Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
 Home-page: https://dataeval.ai/
 License: MIT
@@ -82,8 +82,7 @@ using MAITE-compliant datasets and models.

 **Python versions:** 3.9 - 3.12

-**Supported packages**: *NumPy*, *Pandas*, *Sci-kit learn*, *MAITE*, *NRTK*,
-*Gradient*
+**Supported packages**: *NumPy*, *Pandas*, *Sci-kit learn*, *MAITE*, *NRTK*

 Choose your preferred method of installation below or follow our
 [installation guide](https://dataeval.readthedocs.io/en/v0.74.2/installation.html).
dataeval-1.0.0.dist-info/RECORD (new file)

@@ -0,0 +1,107 @@
+dataeval/__init__.py,sha256=xd1GfD7QmzBG-WN7K6BMJSzV9_UZlX5OiKICdQ5xGfU,1635
+dataeval/_log.py,sha256=Mn5bRWO0cgtAYd5VGYSFiPgu57ta3zoktrtHAZ1m3dU,357
+dataeval/config.py,sha256=lD1YDH8HosFeRU5rQEYRBcmXMZy-csWaMlJTRZGd9iU,3582
+dataeval/data/__init__.py,sha256=qNnRRiVP_sLthkkHpUrMgI_r8dQK-cC-xoGrrjQeRKc,544
+dataeval/data/_embeddings.py,sha256=6Medqj_JCQt1iwZwWGSs1OeX-bHB8bg5BJqADY1N2s8,12883
+dataeval/data/_images.py,sha256=WF9XJRka8ohUdyI2IKBMAy3JoJhOm1iC-8tbYl8woRM,2642
+dataeval/data/_metadata.py,sha256=hNgsCEN8EyfDDX7zLKcQnsaDl-9xvvs5tUzqMjVLvI4,14457
+dataeval/data/_selection.py,sha256=V61_pTFj0hSzmltA6CV5t51Znqw2dIQZ71Iu46bLm44,4486
+dataeval/data/_split.py,sha256=6Jtm_i__CcPtNE3eSeBdPxc7gn7Cp-GM7g9wJWFlVus,16761
+dataeval/data/_targets.py,sha256=ws5d9wRiDkIuOV7GSAKNxzgSm6AWTgb0BFroQK5nAmM,3057
+dataeval/data/selections/__init__.py,sha256=2m8ZB53wXzqLcqmc6p5atO6graB6ZyiRSNJFxf11X_g,613
+dataeval/data/selections/_classbalance.py,sha256=7v8ApoL3X8eCZ6fGDNTehE_bZ1loaP3TlhsJLaICVWg,1458
+dataeval/data/selections/_classfilter.py,sha256=rEeq959p_SLl_etS7pcM8ZxK4yzEYlYZAQ3FlcLV0R8,4330
+dataeval/data/selections/_indices.py,sha256=RFsR9z10aM3N0gJSfKrukFpi-LkiQGXoOwXhmOQ5cpg,630
+dataeval/data/selections/_limit.py,sha256=JG4GmEiNKt3sk4PbOUbBnGGzNlyz72H-kQrt8COMm4Y,512
+dataeval/data/selections/_prioritize.py,sha256=yw51ZQk6FPvyC38M4_pS_Se2Dq0LDFcdDhfbsELzTZc,11306
+dataeval/data/selections/_reverse.py,sha256=b67kNC43A5KpQOic5gifjo9HpJ7FMh4LFCrfovPiJ-M,368
+dataeval/data/selections/_shuffle.py,sha256=gVz_2T4rlucq8Ytqz5jvmmZdTrZDaIv43jJbq97tLjQ,1173
+dataeval/detectors/__init__.py,sha256=3Sg-XWlwr75zEEH3hZKA4nWMtGvaRlnfzTWvZG_Ak6U,189
+dataeval/detectors/drift/__init__.py,sha256=gD8aY5PotS-S2ot7iB_z_zzSOjIbQLw5znFBNj0jtHE,646
+dataeval/detectors/drift/_base.py,sha256=amGqzUAe8fU5qwM5lq1p8PCuhjGh9MHkdW1zeBF1LEE,7574
+dataeval/detectors/drift/_cvm.py,sha256=cS33zWJmFY1fft1XcANcP2jSD5ou7TxvIU2AldhTynM,3004
+dataeval/detectors/drift/_ks.py,sha256=uMc5-NA-lSV1IODrY8uJe87ll3uRJT_oXLJFXy95M1w,3186
+dataeval/detectors/drift/_mmd.py,sha256=wHUy_vUafCikrZ_WX8qQXpxFwzw07-5zVutloR6hl1k,11589
+dataeval/detectors/drift/_uncertainty.py,sha256=BHlykJ-r7TGLJxdPfoazXnoAJ1qVDzbk5HjAMdsnHz8,5847
+dataeval/detectors/drift/updates.py,sha256=L1PnrPlIE1x6ujCc5mCwjcAZwadVTn-Zjb6MnTDvzJQ,2251
+dataeval/detectors/linters/__init__.py,sha256=xn2zPwUcmsuf-Jd9uw6AVI11C9z1b1Y9fYtuFnXenZ0,404
+dataeval/detectors/linters/duplicates.py,sha256=X5WSEvI_BHkLoXjkaHK6wTnSkx4IjpO_exMRjSlhc70,4963
+dataeval/detectors/linters/outliers.py,sha256=D8A-Fov5iUrlU9xMX5Ht33FqUY8Lk5ulC6BlHbUoLwU,9048
+dataeval/detectors/ood/__init__.py,sha256=juCYBDs7CQEAtMhnEpPqF6uTrOIH9kTBSuQ_GRw6a8o,283
+dataeval/detectors/ood/ae.py,sha256=fTrUfFxv6xUqzKpwMC8rW3JrizA16M_bgzqLuBKMrS0,2944
+dataeval/detectors/ood/base.py,sha256=9b-Ljznf0lB1SXF4F_Aj3eJ4Y3ijGEDPMjucUsWOGJM,3051
+dataeval/detectors/ood/mixin.py,sha256=0_o-1HPvgf3-Lf1MSOIfjj5UB8LTLEBGYtJJfyCCzwc,5431
+dataeval/detectors/ood/vae.py,sha256=Fcq0-WbLhzYCgYOAJPBklHm7yuXmFJuEpBkhgwM5kiA,2291
+dataeval/metadata/__init__.py,sha256=XDDmJbOZBNM6pL0r6Nbu6oMRoyAh22IDkPYGndNlkZU,316
+dataeval/metadata/_distance.py,sha256=T1Umju_QwBiLmn1iUbxZagzBS2VnHaDIdp6j-NpaZuk,4076
+dataeval/metadata/_ood.py,sha256=lnKtKModArnUrAhH_XswEtUAhUkh1U_oNsLt1UmNP44,12748
+dataeval/metadata/_utils.py,sha256=r8qBJT83RblobD5W5zyTVi6vYi51Dwkqswizdbzss-M,1169
+dataeval/metrics/__init__.py,sha256=8VC8q3HuJN3o_WN51Ae2_wXznl3RMXIvA5GYVcy7vr8,225
+dataeval/metrics/bias/__init__.py,sha256=329S1_3WnWqeU4-qVcbe0fMy4lDrj9uKslWHIQf93yg,839
+dataeval/metrics/bias/_balance.py,sha256=l1hTVkVwD85bP20MTthA-I5BkvbytylQkJu3Q6iTuPA,6152
+dataeval/metrics/bias/_completeness.py,sha256=BysXU2Jpw33n5dl3acJFEqF3mFGiJLsfG4n5Q2fkTaY,4608
+dataeval/metrics/bias/_coverage.py,sha256=PeUoOiaghUEdn6Ov8z2-am7-fnBVIPcFbJK7Ty5JObA,3647
+dataeval/metrics/bias/_diversity.py,sha256=B_qWVDMZfh818U0qVm8yidquB0H0XvW8N75OWVWXy2g,5814
+dataeval/metrics/bias/_parity.py,sha256=ea1D-eJh6cJxQ11XD6VbDXBKecE0jJJwptGD7LQJmBw,11529
+dataeval/metrics/estimators/__init__.py,sha256=Pnds8uIyAovt2fKqZjiHCIP_kVoBWlVllekYuK5UmmU,568
+dataeval/metrics/estimators/_ber.py,sha256=C30E5LiGGTAfo31zWFYDptDg0R7CTJGJ-a60YgzSkYY,5382
+dataeval/metrics/estimators/_clusterer.py,sha256=1HrpihGTJ63IkNSOy4Ibw633Gllkm1RxKmoKT5MOgt0,1434
+dataeval/metrics/estimators/_divergence.py,sha256=QDWl1lyAYoO9D3Ho7qOHSk6ud8Gi2MGuXEsYwO1HxvA,4043
+dataeval/metrics/estimators/_uap.py,sha256=BULEBbJ9BQ1IcTeZf0x7iI60QHAWCccBOM97FIu9VXA,1928
+dataeval/metrics/stats/__init__.py,sha256=6tA_9nbbM5ObJ6cds8Y1VBtTQiTOxrpGQSFLu_lWGGA,1098
+dataeval/metrics/stats/_base.py,sha256=YIfOVGd7E19B4dpAnzDYRQkaikvRRyJIpznJNfVtPdw,10750
+dataeval/metrics/stats/_boxratiostats.py,sha256=8Kd2FTZ5PLNYZfdAjU_R385gb0Z16JY0L9H_d5ZhgQs,6341
+dataeval/metrics/stats/_dimensionstats.py,sha256=73mFP-Myxne0peFliwvTntc0kk4cpq0krzMvSLDSIMM,2702
+dataeval/metrics/stats/_hashstats.py,sha256=gp9X_pnTT3mPH9YNrWLdn2LQPK_epJ3dQRoyOCwmKlg,4758
+dataeval/metrics/stats/_imagestats.py,sha256=gUPNgN5Zwzdr7WnSwbve1NXNsyxd5dy3cSnlR_7guCg,3007
+dataeval/metrics/stats/_labelstats.py,sha256=lz8I6eSd8tFkmQqy5cOG8hn9yxs0mP-Ic9ratFHiuoU,2813
+dataeval/metrics/stats/_pixelstats.py,sha256=SfergRbjNJE4h0xqe-0c8RnKtZmEkZ9MwExdipLSGvg,3247
+dataeval/metrics/stats/_visualstats.py,sha256=cq4AbF2B50Ihbzb86FphcnKQ1TSwNnP3PsnbpiPQZWw,3698
+dataeval/outputs/__init__.py,sha256=ciK-RdXgtn_s7MSCUW1UXvrXltMbltqbpfe9_V7xGrI,1701
+dataeval/outputs/_base.py,sha256=aZFbgybnZSQ3ws7QYRLTbDFqUfBFRVtIwX2LZfeGFUA,5703
+dataeval/outputs/_bias.py,sha256=7L-d3DUWY6Vud7iX_VoQT0HG0KaV1U35gvmRApqzyB0,12401
+dataeval/outputs/_drift.py,sha256=gOiu2C-ERTWiRqlP0auMYxPBGdm9HecWPqWfg7I4tZg,2015
+dataeval/outputs/_estimators.py,sha256=a2oAIxxEDZ9WLGfMWH8KD-BVUS_SnULRPR-iI9hFPoQ,3047
+dataeval/outputs/_linters.py,sha256=YOdjrfm8ypdRrqYOaPM9nc6wVJI3-ita3Haj7LHDNaw,6416
+dataeval/outputs/_metadata.py,sha256=ffZgpX8KWURPHXpOWjbvJ2KRqWQkS2nWuIjKUzoHhMI,1710
+dataeval/outputs/_ood.py,sha256=suLKVXULGtXH0rq9eXHI1d3d2jhGmItJtz4QiQd47A4,1718
+dataeval/outputs/_stats.py,sha256=c73Yc3Kkrl-MN6BGKe1V0Yr6Ix2Yp_DZZfFSp8fZMZ0,13180
+dataeval/outputs/_utils.py,sha256=HHlGC7sk416m_3Bgn075Qdblz_aPup_UOafJpB0RuXY,893
+dataeval/outputs/_workflows.py,sha256=MkRD6ubI4NCBXb9v3kjXy64cUGs3G-JKkBdOpRD9XVE,10750
+dataeval/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+dataeval/typing.py,sha256=GDMuef-oFFukNtsiKFmsExHdNvYR_j-tQcsCwZ9reow,7198
+dataeval/utils/__init__.py,sha256=hRvyUK7b3d6JBEV5u47rFcOHEcmDYqAvZQw_T5pDAWw,264
+dataeval/utils/_array.py,sha256=KqAdXEMjcXYvdWdYEEoEbigwQJ4S9VYxQS3sRFeY5XY,5929
+dataeval/utils/_bin.py,sha256=nylthmsC3vzLHLhlUMACvZs--h7xvAh9Pt75InaQJW8,7322
+dataeval/utils/_clusterer.py,sha256=fw5x-2QN0TIbiodDKHZxRgxKHINedpPcOklzce0Rbjg,5436
+dataeval/utils/_fast_mst.py,sha256=4_7ykVihCL5jWtxcGnrecIsDQo65kUml9SZ1JxgBZYY,7172
+dataeval/utils/_image.py,sha256=capzF_X5H0jy0PmTP3Hf52GFgLqrnfU6gS4tiwck9jo,1939
+dataeval/utils/_method.py,sha256=9B9JQbgqWJBRhQJb7glajUtWaQzUTIUuvrZ9_bisxsM,394
+dataeval/utils/_mst.py,sha256=f0vXytTUjlOS6AyL7c6PkXmaHuuGUK-vMLpq-5xMgxk,2183
+dataeval/utils/_plot.py,sha256=mTRQNbJsA42QMiOwZbJaH8sNYgP996QFDEGVVE9HSgY,7076
+dataeval/utils/data/__init__.py,sha256=xGzrjrOxOP2DP1tU84AWMKPnSxFvSjM81CTlDg4rNM8,331
+dataeval/utils/data/_dataset.py,sha256=MHY582yRm4FxQkkLWUhKZBb7ZyvWypM6ldUG89vd3uE,7936
+dataeval/utils/data/collate.py,sha256=5egEEKhNNCGeNLChO1p6dZ4Wg6x51VEaMNHz7hEZUxI,3936
+dataeval/utils/data/metadata.py,sha256=1XeGYj_e97-nJ_IrWEHPhWICmouYU5qbXWbp7uhZrIE,14171
+dataeval/utils/datasets/__init__.py,sha256=Jfe7XI_9U5S4wuI_2QCoeuWNOxz4j0nAQvxc5wG5mWY,486
+dataeval/utils/datasets/_base.py,sha256=TpmgPzF3EShCLAF5S4Zf9lFN78q17bTZF6AUE1qKdlk,8857
+dataeval/utils/datasets/_cifar10.py,sha256=oSX5JEzbBM4zGC9kC7-hVTOglms3rYaUuYiA00_DUJ4,5439
+dataeval/utils/datasets/_fileio.py,sha256=SixIk5nIlIwJdX9zjNXS10vHA3hL8aaYbqHsDg1xSpY,6447
+dataeval/utils/datasets/_milco.py,sha256=BF2XvyzuOop1mg5pFZcRfYmZcezlbpZWHyd_TtEHFF4,7573
+dataeval/utils/datasets/_mixin.py,sha256=FJgZP_cpJkgAHA3j3ai_j3Wt7aFSEjIMVmt9NpvVXzg,1757
+dataeval/utils/datasets/_mnist.py,sha256=4WOkQTORYMs6KEeyyJgChTnH03797y4ezgaZtYqplh4,8102
+dataeval/utils/datasets/_ships.py,sha256=RMdX2KlnXJYOTzBb6euA5TAqxs-S8b56pAGiyQhNMuo,4870
+dataeval/utils/datasets/_types.py,sha256=iSKyHXRlGuomXs0FHK6md8lXLQrQQ4fxgVOwr4o81bo,1089
+dataeval/utils/datasets/_voc.py,sha256=kif6ms_romK6VElP4pf2SK4cJ5dEHDOkxSaSaeP3c5k,15565
+dataeval/utils/torch/__init__.py,sha256=dn5mjCrFp0b1aL_UEURhONU0Ag0cmXoTOBSGagpkTiA,325
+dataeval/utils/torch/_blocks.py,sha256=HVhBTMMD5NA4qheMUgyol1KWiKZDIuc8k5j4RcMKmhk,1466
+dataeval/utils/torch/_gmm.py,sha256=XM68GNEP97EjaB1U49-ZXRb81d0CEFnPS910alrcB3g,3740
+dataeval/utils/torch/_internal.py,sha256=vHy-DzPhmvE8h3wmWc3aciBJ8nDGzQ1z1jTZgGjmDyM,4154
+dataeval/utils/torch/models.py,sha256=hmroEs6C6jQ5tAoZa71RFeIvXLxfXrTJSFH_jG2LGQU,9749
+dataeval/utils/torch/trainer.py,sha256=iUotX4OdirH8-ZtjdpU8gbJavkYW9YY9qpA2mAlFy1Y,5520
+dataeval/workflows/__init__.py,sha256=ou8y0KO-d6W5lgmcyLjKlf-J_ckP3vilW7wHkgiDlZ4,255
+dataeval/workflows/sufficiency.py,sha256=mjKmfRrAjShLUFIARv5o8yT5fnFvDsS5Qu6ujIPUgQg,8497
+dataeval-1.0.0.dist-info/LICENSE.txt,sha256=uAooygKWvX6NbU9Ran9oG2msttoG8aeTeHSTe5JeCnY,1061
+dataeval-1.0.0.dist-info/METADATA,sha256=ma_TquWQQl0QETiK4-wH1jfAe2my33Cl37GswNe0ZM8,5307
+dataeval-1.0.0.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
+dataeval-1.0.0.dist-info/RECORD,,
dataeval/detectors/drift/_torch.py (deleted)

@@ -1,222 +0,0 @@
-"""
-Source code derived from Alibi-Detect 0.11.4
-https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
-
-Original code Copyright (c) 2023 Seldon Technologies Ltd
-Licensed under Apache Software License (Apache 2.0)
-"""
-
-from __future__ import annotations
-
-__all__ = []
-
-from typing import Any, Callable
-
-import numpy as np
-import torch
-import torch.nn as nn
-from numpy.typing import NDArray
-
-from dataeval.config import DeviceLike, get_device
-from dataeval.utils.torch._internal import predict_batch
-
-
-def mmd2_from_kernel_matrix(
-    kernel_mat: torch.Tensor, m: int, permute: bool = False, zero_diag: bool = True
-) -> torch.Tensor:
-    """
-    Compute maximum mean discrepancy (MMD^2) between 2 samples x and y from the
-    full kernel matrix between the samples.
-
-    Parameters
-    ----------
-    kernel_mat : torch.Tensor
-        Kernel matrix between samples x and y.
-    m : int
-        Number of instances in y.
-    permute : bool, default False
-        Whether to permute the row indices. Used for permutation tests.
-    zero_diag : bool, default True
-        Whether to zero out the diagonal of the kernel matrix.
-
-    Returns
-    -------
-    torch.Tensor
-        MMD^2 between the samples from the kernel matrix.
-    """
-    n = kernel_mat.shape[0] - m
-    if zero_diag:
-        kernel_mat = kernel_mat - torch.diag(kernel_mat.diag())
-    if permute:
-        idx = torch.randperm(kernel_mat.shape[0])
-        kernel_mat = kernel_mat[idx][:, idx]
-    k_xx, k_yy, k_xy = kernel_mat[:-m, :-m], kernel_mat[-m:, -m:], kernel_mat[-m:, :-m]
-    c_xx, c_yy = 1 / (n * (n - 1)), 1 / (m * (m - 1))
-    mmd2 = c_xx * k_xx.sum() + c_yy * k_yy.sum() - 2.0 * k_xy.mean()
-    return mmd2
-
-
-def preprocess_drift(
-    x: NDArray[Any],
-    model: nn.Module,
-    device: DeviceLike | None = None,
-    preprocess_batch_fn: Callable | None = None,
-    batch_size: int = int(1e10),
-    dtype: type[np.generic] | torch.dtype = np.float32,
-) -> NDArray[Any] | torch.Tensor | tuple[Any, ...]:
-    """
-    Prediction function used for preprocessing step of drift detector.
-
-    Parameters
-    ----------
-    x : NDArray
-        Batch of instances.
-    model : nn.Module
-        Model used for preprocessing.
-    device : DeviceLike or None, default None
-        The hardware device to use if specified, otherwise uses the DataEval
-        default or torch default.
-    preprocess_batch_fn : Callable or None, default None
-        Optional batch preprocessing function. For example to convert a list of objects
-        to a batch which can be processed by the PyTorch model.
-    batch_size : int, default 1e10
-        Batch size used during prediction.
-    dtype : np.dtype or torch.dtype, default np.float32
-        Model output type, either a :term:`NumPy` or torch dtype, e.g. np.float32 or torch.float32.
-
-    Returns
-    -------
-    NDArray | torch.Tensor | tuple
-        Numpy array, torch tensor or tuples of those with model outputs.
-    """
-    return predict_batch(
-        x,
-        model,
-        device=get_device(device),
-        batch_size=batch_size,
-        preprocess_fn=preprocess_batch_fn,
-        dtype=dtype,
-    )
-
-
-@torch.jit.script
-def _squared_pairwise_distance(
-    x: torch.Tensor, y: torch.Tensor, a_min: float = 1e-30
-) -> torch.Tensor:  # pragma: no cover - torch.jit.script code is compiled and copied
-    """
-    PyTorch pairwise squared Euclidean distance between samples x and y.
-
-    Parameters
-    ----------
-    x : torch.Tensor
-        Batch of instances of shape [Nx, features].
-    y : torch.Tensor
-        Batch of instances of shape [Ny, features].
-    a_min : float
-        Lower bound to clip distance values.
-
-    Returns
-    -------
-    torch.Tensor
-        Pairwise squared Euclidean distance [Nx, Ny].
-    """
-    x2 = x.pow(2).sum(dim=-1, keepdim=True)
-    y2 = y.pow(2).sum(dim=-1, keepdim=True)
-    dist = torch.addmm(y2.transpose(-2, -1), x, y.transpose(-2, -1), alpha=-2).add_(x2)
-    return dist.clamp_min_(a_min)
-
-
-def sigma_median(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.Tensor:
-    """
-    Bandwidth estimation using the median heuristic `Gretton2012`
-
-    Parameters
-    ----------
-    x : torch.Tensor
-        Tensor of instances with dimension [Nx, features].
-    y : torch.Tensor
-        Tensor of instances with dimension [Ny, features].
-    dist : torch.Tensor
-        Tensor with dimensions [Nx, Ny], containing the pairwise distances
-        between `x` and `y`.
-
-    Returns
-    -------
-    torch.Tensor
-        The computed bandwidth, `sigma`.
-    """
-    n = min(x.shape[0], y.shape[0])
-    n = n if (x[:n] == y[:n]).all() and x.shape == y.shape else 0
-    n_median = n + (np.prod(dist.shape) - n) // 2 - 1
-    sigma = (0.5 * dist.flatten().sort().values[int(n_median)].unsqueeze(dim=-1)) ** 0.5
-    return sigma
-
-
-class GaussianRBF(nn.Module):
-    """
-    Gaussian RBF kernel: k(x,y) = exp(-(1/(2*sigma^2))||x-y||^2).
-
-    A forward pass takes a batch of instances x [Nx, features] and
-    y [Ny, features] and returns the kernel matrix [Nx, Ny].
-
-    Parameters
-    ----------
-    sigma : torch.Tensor | None, default None
-        Bandwidth used for the kernel. Needn't be specified if being inferred or
-        trained. Can pass multiple values to eval kernel with and then average.
-    init_sigma_fn : Callable | None, default None
-        Function used to compute the bandwidth ``sigma``. Used when ``sigma`` is to be
-        inferred. The function's signature should take in the tensors ``x``, ``y`` and
-        ``dist`` and return ``sigma``. If ``None``, it is set to ``sigma_median``.
-    trainable : bool, default False
-        Whether or not to track gradients w.r.t. `sigma` to allow it to be trained.
-    """
-
-    def __init__(
-        self,
-        sigma: torch.Tensor | None = None,
-        init_sigma_fn: Callable | None = None,
-        trainable: bool = False,
-    ) -> None:
-        super().__init__()
-        init_sigma_fn = sigma_median if init_sigma_fn is None else init_sigma_fn
-        self.config: dict[str, Any] = {
-            "sigma": sigma,
-            "trainable": trainable,
-            "init_sigma_fn": init_sigma_fn,
-        }
-        if sigma is None:
-            self.log_sigma: nn.Parameter = nn.Parameter(torch.empty(1), requires_grad=trainable)
-            self.init_required: bool = True
-        else:
-            sigma = sigma.reshape(-1)  # [Ns,]
-            self.log_sigma: nn.Parameter = nn.Parameter(sigma.log(), requires_grad=trainable)
-            self.init_required: bool = False
-        self.init_sigma_fn = init_sigma_fn
-        self.trainable = trainable
-
-    @property
-    def sigma(self) -> torch.Tensor:
-        return self.log_sigma.exp()
-
-    def forward(
-        self,
-        x: np.ndarray[Any, Any] | torch.Tensor,
-        y: np.ndarray[Any, Any] | torch.Tensor,
-        infer_sigma: bool = False,
-    ) -> torch.Tensor:
-        x, y = torch.as_tensor(x), torch.as_tensor(y)
-        dist = _squared_pairwise_distance(x.flatten(1), y.flatten(1))  # [Nx, Ny]
-
-        if infer_sigma or self.init_required:
-            if self.trainable and infer_sigma:
-                raise ValueError("Gradients cannot be computed w.r.t. an inferred sigma value")
-            sigma = self.init_sigma_fn(x, y, dist)
-            with torch.no_grad():
-                self.log_sigma.copy_(sigma.log().clone())
-            self.init_required: bool = False
-
-        gamma = 1.0 / (2.0 * self.sigma**2)  # [Ns,]
-        # TODO: do matrix multiplication after all?
-        kernel_mat = torch.exp(-torch.cat([(g * dist)[None, :, :] for g in gamma], dim=0))  # [Ns, Nx, Ny]
-        return kernel_mat.mean(dim=0)  # [Nx, Ny]
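With _torch.py deleted, its MMD machinery is folded into the much larger _mmd.py (+189 -53 in the file list). For reference, a standalone sketch of what the removed helpers computed, using torch.cdist in place of the jitted _squared_pairwise_distance; this is a reimplementation for illustration, not the package's code:

    # Reference sketch of the removed MMD^2 helpers (illustrative, not package code).
    import torch

    def gaussian_rbf(x: torch.Tensor, y: torch.Tensor, sigma: float) -> torch.Tensor:
        # k(x, y) = exp(-||x - y||^2 / (2 * sigma^2)), evaluated pairwise
        sq_dist = torch.cdist(x, y).pow(2)  # [Nx, Ny] squared Euclidean distances
        return torch.exp(-sq_dist / (2.0 * sigma**2))

    def mmd2(x: torch.Tensor, y: torch.Tensor, sigma: float) -> torch.Tensor:
        # Same estimate as mmd2_from_kernel_matrix with zero_diag=True:
        # zero the within-sample diagonals, then combine the three blocks.
        n, m = x.shape[0], y.shape[0]
        k_xx, k_yy, k_xy = gaussian_rbf(x, x, sigma), gaussian_rbf(y, y, sigma), gaussian_rbf(x, y, sigma)
        c_xx, c_yy = 1.0 / (n * (n - 1)), 1.0 / (m * (m - 1))
        return (
            c_xx * (k_xx.sum() - k_xx.diag().sum())
            + c_yy * (k_yy.sum() - k_yy.diag().sum())
            - 2.0 * k_xy.mean()
        )

    x, y = torch.randn(64, 16), torch.randn(48, 16) + 0.5
    print(mmd2(x, y, sigma=1.0))  # larger values suggest distributional drift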
dataeval/utils/data/_embeddings.py (deleted)

@@ -1,186 +0,0 @@
-from __future__ import annotations
-
-__all__ = []
-
-import math
-from typing import Any, Iterator, Sequence, cast
-
-import torch
-from torch.utils.data import DataLoader, Subset
-from tqdm import tqdm
-
-from dataeval.config import DeviceLike, get_device
-from dataeval.typing import Array, Dataset, Transform
-from dataeval.utils.torch.models import SupportsEncode
-
-
-class Embeddings:
-    """
-    Collection of image embeddings from a dataset.
-
-    Embeddings are accessed by index or slice and are only loaded on-demand.
-
-    Parameters
-    ----------
-    dataset : ImageClassificationDataset or ObjectDetectionDataset
-        Dataset to access original images from.
-    batch_size : int
-        Batch size to use when encoding images.
-    transforms : Transform or Sequence[Transform] or None, default None
-        Transforms to apply to images before encoding.
-    model : torch.nn.Module or None, default None
-        Model to use for encoding images.
-    device : DeviceLike or None, default None
-        The hardware device to use if specified, otherwise uses the DataEval
-        default or torch default.
-    cache : bool, default False
-        Whether to cache the embeddings in memory.
-    verbose : bool, default False
-        Whether to print progress bar when encoding images.
-    """
-
-    device: torch.device
-    batch_size: int
-    verbose: bool
-
-    def __init__(
-        self,
-        dataset: Dataset[tuple[Array, Any, Any]],
-        batch_size: int,
-        transforms: Transform[torch.Tensor] | Sequence[Transform[torch.Tensor]] | None = None,
-        model: torch.nn.Module | None = None,
-        device: DeviceLike | None = None,
-        cache: bool = False,
-        verbose: bool = False,
-    ) -> None:
-        self.device = get_device(device)
-        self.cache = cache
-        self.batch_size = batch_size if batch_size > 0 else 1
-        self.verbose = verbose
-
-        self._dataset = dataset
-        self._length = len(dataset)
-        model = torch.nn.Flatten() if model is None else model
-        self._transforms = [transforms] if isinstance(transforms, Transform) else transforms
-        self._model = model.to(self.device).eval()
-        self._encoder = model.encode if isinstance(model, SupportsEncode) else model
-        self._collate_fn = lambda datum: [torch.as_tensor(i) for i, _, _ in datum]
-        self._cached_idx = set()
-        self._embeddings: torch.Tensor = torch.empty(())
-        self._shallow: bool = False
-
-    def to_tensor(self, indices: Sequence[int] | None = None) -> torch.Tensor:
-        """
-        Converts dataset to embeddings.
-
-        Parameters
-        ----------
-        indices : Sequence[int] or None, default None
-            The indices to convert to embeddings
-
-        Returns
-        -------
-        torch.Tensor
-
-        Warning
-        -------
-        Processing large quantities of data can be resource intensive.
-        """
-        if indices is not None:
-            return torch.vstack(list(self._batch(indices))).to(self.device)
-        else:
-            return self[:]
-
-    @classmethod
-    def from_array(cls, array: Array, device: DeviceLike | None = None) -> Embeddings:
-        """
-        Instantiates a shallow Embeddings object using an array.
-
-        Parameters
-        ----------
-        array : Array
-            The array to convert to embeddings.
-        device : DeviceLike or None, default None
-            The hardware device to use if specified, otherwise uses the DataEval
-            default or torch default.
-
-        Returns
-        -------
-        Embeddings
-
-        Example
-        -------
-        >>> import numpy as np
-        >>> from dataeval.utils.data._embeddings import Embeddings
-        >>> array = np.random.randn(100, 3, 224, 224)
-        >>> embeddings = Embeddings.from_array(array)
-        >>> print(embeddings.to_tensor().shape)
-        torch.Size([100, 3, 224, 224])
-        """
-        embeddings = Embeddings([], 0, None, None, device, True, False)
-        embeddings._length = len(array)
-        embeddings._cached_idx = set(range(len(array)))
-        embeddings._embeddings = torch.as_tensor(array).to(get_device(device))
-        embeddings._shallow = True
-        return embeddings
-
-    def _encode(self, images: list[torch.Tensor]) -> torch.Tensor:
-        if self._transforms:
-            images = [transform(image) for transform in self._transforms for image in images]
-        return self._encoder(torch.stack(images).to(self.device))
-
-    @torch.no_grad()  # Reduce overhead cost by not tracking tensor gradients
-    def _batch(self, indices: Sequence[int]) -> Iterator[torch.Tensor]:
-        dataset = cast(torch.utils.data.Dataset[tuple[Array, Any, Any]], self._dataset)
-        total_batches = math.ceil(len(indices) / self.batch_size)
-
-        # If not caching, process all indices normally
-        if not self.cache:
-            for images in tqdm(
-                DataLoader(Subset(dataset, indices), self.batch_size, collate_fn=self._collate_fn),
-                total=total_batches,
-                desc="Batch embedding",
-                disable=not self.verbose,
-            ):
-                yield self._encode(images)
-            return
-
-        # If caching, process each batch of indices at a time, preserving original order
-        for i in tqdm(range(0, len(indices), self.batch_size), desc="Batch embedding", disable=not self.verbose):
-            batch = indices[i : i + self.batch_size]
-            uncached = [idx for idx in batch if idx not in self._cached_idx]
-
-            if uncached:
-                # Process uncached indices as a single batch
-                for images in DataLoader(Subset(dataset, uncached), len(uncached), collate_fn=self._collate_fn):
-                    embeddings = self._encode(images)
-
-                    if not self._embeddings.shape:
-                        full_shape = (len(self._dataset), *embeddings.shape[1:])
-                        self._embeddings = torch.empty(full_shape, dtype=embeddings.dtype, device=self.device)
-
-                    self._embeddings[uncached] = embeddings
-                    self._cached_idx.update(uncached)
-
-            yield self._embeddings[batch]
-
-    def __getitem__(self, key: int | slice, /) -> torch.Tensor:
-        if not isinstance(key, slice) and not hasattr(key, "__int__"):
-            raise TypeError("Invalid argument type.")
-
-        if self._shallow:
-            if not self._embeddings.shape:
-                raise ValueError("Embeddings not initialized.")
-            return self._embeddings[key]
-
-        indices = list(range(len(self._dataset))[key]) if isinstance(key, slice) else [int(key)]
-        result = torch.vstack(list(self._batch(indices))).to(self.device)
-        return result.squeeze(0) if len(indices) == 1 else result
-
-    def __iter__(self) -> Iterator[torch.Tensor]:
-        # process in batches while yielding individual embeddings
-        for batch in self._batch(range(self._length)):
-            yield from batch
-
-    def __len__(self) -> int:
-        return self._length
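The class is not gone: an expanded version lands at dataeval/data/_embeddings.py (+345 lines in the file list). Assuming the from_array/to_tensor API shown in the removed docstring carries over to the new location, 1.0.0 usage would look like:

    # Hedged sketch: new import path per the 1.0.0 file layout; the
    # from_array/to_tensor signatures are assumed to match the 0.84.0
    # docstring example above.
    import numpy as np

    from dataeval.data._embeddings import Embeddings

    array = np.random.randn(100, 3, 224, 224)
    embeddings = Embeddings.from_array(array)  # shallow: array is cached as-is, no model applied
    print(embeddings.to_tensor().shape)        # torch.Size([100, 3, 224, 224])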
dataeval/utils/data/datasets/__init__.py (deleted)

@@ -1,17 +0,0 @@
-"""Provides access to common Computer Vision datasets."""
-
-from dataeval.utils.data.datasets._cifar10 import CIFAR10
-from dataeval.utils.data.datasets._milco import MILCO
-from dataeval.utils.data.datasets._mnist import MNIST
-from dataeval.utils.data.datasets._ships import Ships
-from dataeval.utils.data.datasets._voc import VOCDetection, VOCDetectionTorch, VOCSegmentation
-
-__all__ = [
-    "MNIST",
-    "Ships",
-    "CIFAR10",
-    "MILCO",
-    "VOCDetection",
-    "VOCDetectionTorch",
-    "VOCSegmentation",
-]
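The module itself moves to dataeval/utils/datasets/__init__.py (+17 lines in the file list), so only the import path changes. A hedged sketch (the same exports are assumed at the new location):

    # 0.84.0
    # from dataeval.utils.data.datasets import CIFAR10, MILCO, MNIST, Ships, VOCDetection
    # 1.0.0 (assumed re-exports at the relocated module)
    from dataeval.utils.datasets import CIFAR10, MILCO, MNIST, Ships, VOCDetection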
dataeval/utils/data/selections/__init__.py (deleted)

@@ -1,17 +0,0 @@
-"""Provides selection classes for selecting subsets of Computer Vision datasets."""
-
-__all__ = [
-    "ClassFilter",
-    "Indices",
-    "Limit",
-    "Prioritize",
-    "Reverse",
-    "Shuffle",
-]
-
-from dataeval.utils.data.selections._classfilter import ClassFilter
-from dataeval.utils.data.selections._indices import Indices
-from dataeval.utils.data.selections._limit import Limit
-from dataeval.utils.data.selections._prioritize import Prioritize
-from dataeval.utils.data.selections._reverse import Reverse
-from dataeval.utils.data.selections._shuffle import Shuffle
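Selections similarly relocate to dataeval/data/selections/__init__.py (+19 lines in the file list), which also gains a class-balancing selection via the new _classbalance.py. A hedged sketch (the ClassBalance name is inferred from the filename, and re-exports at the new location are assumed):

    # 0.84.0
    # from dataeval.utils.data.selections import ClassFilter, Limit, Shuffle
    # 1.0.0 (assumed re-exports; ClassBalance inferred from _classbalance.py)
    from dataeval.data.selections import ClassBalance, ClassFilter, Limit, Shuffle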