dataeval 0.72.1__py3-none-any.whl → 0.72.2__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (90)
  1. dataeval/__init__.py +4 -4
  2. dataeval/detectors/__init__.py +4 -3
  3. dataeval/detectors/drift/__init__.py +9 -10
  4. dataeval/{_internal/detectors → detectors}/drift/base.py +39 -91
  5. dataeval/{_internal/detectors → detectors}/drift/cvm.py +4 -3
  6. dataeval/{_internal/detectors → detectors}/drift/ks.py +4 -3
  7. dataeval/{_internal/detectors → detectors}/drift/mmd.py +23 -25
  8. dataeval/{_internal/detectors → detectors}/drift/torch.py +13 -11
  9. dataeval/{_internal/detectors → detectors}/drift/uncertainty.py +7 -5
  10. dataeval/detectors/drift/updates.py +61 -0
  11. dataeval/detectors/linters/__init__.py +3 -3
  12. dataeval/{_internal/detectors → detectors/linters}/clusterer.py +41 -39
  13. dataeval/{_internal/detectors → detectors/linters}/duplicates.py +19 -9
  14. dataeval/{_internal/detectors → detectors/linters}/merged_stats.py +3 -1
  15. dataeval/{_internal/detectors → detectors/linters}/outliers.py +14 -21
  16. dataeval/detectors/ood/__init__.py +6 -6
  17. dataeval/{_internal/detectors → detectors}/ood/ae.py +7 -7
  18. dataeval/{_internal/detectors → detectors}/ood/aegmm.py +9 -29
  19. dataeval/{_internal/detectors → detectors}/ood/base.py +24 -18
  20. dataeval/{_internal/detectors → detectors}/ood/llr.py +24 -20
  21. dataeval/detectors/ood/metadata_ks_compare.py +99 -0
  22. dataeval/detectors/ood/metadata_least_likely.py +119 -0
  23. dataeval/detectors/ood/metadata_ood_mi.py +92 -0
  24. dataeval/{_internal/detectors → detectors}/ood/vae.py +10 -12
  25. dataeval/{_internal/detectors → detectors}/ood/vaegmm.py +10 -32
  26. dataeval/{_internal/interop.py → interop.py} +12 -7
  27. dataeval/metrics/__init__.py +1 -1
  28. dataeval/metrics/bias/__init__.py +4 -4
  29. dataeval/{_internal/metrics → metrics/bias}/balance.py +75 -9
  30. dataeval/{_internal/metrics → metrics/bias}/coverage.py +6 -4
  31. dataeval/{_internal/metrics → metrics/bias}/diversity.py +48 -14
  32. dataeval/metrics/bias/metadata.py +275 -0
  33. dataeval/{_internal/metrics → metrics/bias}/parity.py +12 -10
  34. dataeval/metrics/estimators/__init__.py +3 -3
  35. dataeval/{_internal/metrics → metrics/estimators}/ber.py +25 -22
  36. dataeval/{_internal/metrics → metrics/estimators}/divergence.py +11 -12
  37. dataeval/{_internal/metrics → metrics/estimators}/uap.py +5 -3
  38. dataeval/metrics/stats/__init__.py +7 -7
  39. dataeval/{_internal/metrics → metrics}/stats/base.py +59 -35
  40. dataeval/{_internal/metrics → metrics}/stats/boxratiostats.py +18 -14
  41. dataeval/{_internal/metrics → metrics}/stats/datasetstats.py +18 -16
  42. dataeval/{_internal/metrics → metrics}/stats/dimensionstats.py +9 -7
  43. dataeval/metrics/stats/hashstats.py +156 -0
  44. dataeval/{_internal/metrics → metrics}/stats/labelstats.py +5 -3
  45. dataeval/{_internal/metrics → metrics}/stats/pixelstats.py +9 -8
  46. dataeval/{_internal/metrics → metrics}/stats/visualstats.py +10 -9
  47. dataeval/{_internal/output.py → output.py} +26 -6
  48. dataeval/utils/__init__.py +7 -3
  49. dataeval/utils/image.py +71 -0
  50. dataeval/utils/shared.py +151 -0
  51. dataeval/{_internal → utils}/split_dataset.py +98 -33
  52. dataeval/utils/tensorflow/__init__.py +7 -6
  53. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/autoencoder.py +60 -64
  54. dataeval/{_internal/models/tensorflow/losses.py → utils/tensorflow/_internal/loss.py} +9 -8
  55. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/pixelcnn.py +16 -20
  56. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/trainer.py +3 -1
  57. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/utils.py +17 -17
  58. dataeval/utils/tensorflow/loss/__init__.py +6 -2
  59. dataeval/utils/torch/__init__.py +7 -3
  60. dataeval/{_internal/models/pytorch → utils/torch}/blocks.py +19 -14
  61. dataeval/{_internal → utils/torch}/datasets.py +48 -42
  62. dataeval/utils/torch/models.py +138 -0
  63. dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py} +7 -136
  64. dataeval/{_internal → utils/torch}/utils.py +3 -1
  65. dataeval/workflows/__init__.py +1 -1
  66. dataeval/{_internal/workflows → workflows}/sufficiency.py +39 -34
  67. {dataeval-0.72.1.dist-info → dataeval-0.72.2.dist-info}/METADATA +2 -1
  68. dataeval-0.72.2.dist-info/RECORD +72 -0
  69. dataeval/_internal/detectors/__init__.py +0 -0
  70. dataeval/_internal/detectors/drift/__init__.py +0 -0
  71. dataeval/_internal/detectors/ood/__init__.py +0 -0
  72. dataeval/_internal/metrics/__init__.py +0 -0
  73. dataeval/_internal/metrics/stats/hashstats.py +0 -75
  74. dataeval/_internal/metrics/utils.py +0 -447
  75. dataeval/_internal/models/__init__.py +0 -0
  76. dataeval/_internal/models/pytorch/__init__.py +0 -0
  77. dataeval/_internal/models/pytorch/utils.py +0 -67
  78. dataeval/_internal/models/tensorflow/__init__.py +0 -0
  79. dataeval/_internal/workflows/__init__.py +0 -0
  80. dataeval/detectors/drift/kernels/__init__.py +0 -10
  81. dataeval/detectors/drift/updates/__init__.py +0 -8
  82. dataeval/utils/tensorflow/models/__init__.py +0 -9
  83. dataeval/utils/tensorflow/recon/__init__.py +0 -3
  84. dataeval/utils/torch/datasets/__init__.py +0 -12
  85. dataeval/utils/torch/models/__init__.py +0 -11
  86. dataeval/utils/torch/trainer/__init__.py +0 -7
  87. dataeval-0.72.1.dist-info/RECORD +0 -81
  88. /dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/gmm.py +0 -0
  89. {dataeval-0.72.1.dist-info → dataeval-0.72.2.dist-info}/LICENSE.txt +0 -0
  90. {dataeval-0.72.1.dist-info → dataeval-0.72.2.dist-info}/WHEEL +0 -0
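Most of the entries above are renames that promote modules out of dataeval._internal into the public package tree. For orientation, a few public import paths that exist in 0.72.2 per the renames and the new RECORD below (a sketch for orientation, not official upgrade guidance):

# Public import paths in 0.72.2, inferred from the file list above.
from dataeval.workflows import Sufficiency
from dataeval.utils.torch.models import AriaAutoencoder
from dataeval.utils.torch.trainer import AETrainer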
dataeval/utils/torch/models.py
@@ -0,0 +1,138 @@
+ from __future__ import annotations
+
+ __all__ = ["AriaAutoencoder", "Encoder", "Decoder"]
+
+ from typing import Any
+
+ import torch.nn as nn
+
+
+ class AriaAutoencoder(nn.Module):
+     """
+     An autoencoder model with a separate encoder and decoder.
+
+     Parameters
+     ----------
+     channels : int, default 3
+         Number of input channels
+     """
+
+     def __init__(self, channels: int = 3) -> None:
+         super().__init__()
+         self.encoder: Encoder = Encoder(channels)
+         self.decoder: Decoder = Decoder(channels)
+
+     def forward(self, x: Any) -> Any:
+         """
+         Perform a forward pass through the encoder and decoder.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor
+
+         Returns
+         -------
+         torch.Tensor
+             The reconstructed output tensor.
+         """
+         x = self.encoder(x)
+         x = self.decoder(x)
+         return x
+
+     def encode(self, x: Any) -> Any:
+         """
+         Encode the input tensor using the encoder.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor
+
+         Returns
+         -------
+         torch.Tensor
+             The encoded representation of the input tensor.
+         """
+         return self.encoder(x)
+
+
+ class Encoder(nn.Module):
+     """
+     A simple encoder to be used in an autoencoder model.
+
+     This is the encoder used by the AriaAutoencoder model.
+
+     Parameters
+     ----------
+     channels : int, default 3
+         Number of input channels
+     """
+
+     def __init__(self, channels: int = 3) -> None:
+         super().__init__()
+         self.encoder: nn.Sequential = nn.Sequential(
+             nn.Conv2d(channels, 256, 2, stride=1, padding=1),
+             nn.ReLU(),
+             nn.MaxPool2d(2),
+             nn.Conv2d(256, 128, 2, stride=1, padding=1),
+             nn.ReLU(),
+             nn.MaxPool2d(2),
+             nn.Conv2d(128, 64, 2, stride=1),
+         )
+
+     def forward(self, x: Any) -> Any:
+         """
+         Perform a forward pass through the encoder.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor
+
+         Returns
+         -------
+         torch.Tensor
+             The encoded representation of the input tensor.
+         """
+         return self.encoder(x)
+
+
+ class Decoder(nn.Module):
+     """
+     A simple decoder to be used in an autoencoder model.
+
+     This is the decoder used by the AriaAutoencoder model.
+
+     Parameters
+     ----------
+     channels : int
+         Number of output channels
+     """
+
+     def __init__(self, channels: int) -> None:
+         super().__init__()
+         self.decoder: nn.Sequential = nn.Sequential(
+             nn.ConvTranspose2d(64, 128, 2, stride=1),
+             nn.ReLU(),
+             nn.ConvTranspose2d(128, 256, 2, stride=2),
+             nn.ReLU(),
+             nn.ConvTranspose2d(256, channels, 2, stride=2),
+             nn.Sigmoid(),
+         )
+
+     def forward(self, x: Any) -> Any:
+         """
+         Perform a forward pass through the decoder.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             The encoded tensor.
+
+         Returns
+         -------
+         torch.Tensor
+             The reconstructed output tensor.
+         """
+         return self.decoder(x)
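For orientation, a minimal usage sketch of the relocated model (import path per the new RECORD below; the 32x32 input size is an illustrative choice that happens to round-trip exactly through the encoder's two MaxPool2d stages and the decoder's transposed convolutions):

# Hypothetical usage sketch, not an official example.
import torch
from dataeval.utils.torch.models import AriaAutoencoder

model = AriaAutoencoder(channels=3)
images = torch.randn(8, 3, 32, 32)   # batch of 3-channel 32x32 images
latent = model.encode(images)        # output of the final Conv2d(128, 64, 2) layer
recon = model(images)                # decoder mirrors the encoder back to the input size
assert recon.shape == images.shape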
dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py}
@@ -1,5 +1,7 @@
  from __future__ import annotations

+ __all__ = ["AETrainer"]
+
  from typing import Any

  import torch
@@ -38,11 +40,11 @@ class AETrainer:
      ):
          if device == "auto":
              device = "cuda" if torch.cuda.is_available() else "cpu"
-         self.device = device
-         self.model = model.to(device)
+         self.device: torch.device = torch.device(device)
+         self.model: nn.Module = model.to(device)
          self.batch_size = batch_size

-     def train(self, dataset: Dataset, epochs: int = 25) -> list[float]:
+     def train(self, dataset: Dataset[Any], epochs: int = 25) -> list[float]:
          """
          Basic image reconstruction training function for :term:`Autoencoder` models

@@ -101,7 +103,7 @@ class AETrainer:
          return loss_history

      @torch.no_grad
-     def eval(self, dataset: Dataset) -> float:
+     def eval(self, dataset: Dataset[Any]) -> float:
          """
          Basic image reconstruction evaluation function for :term:`autoencoder<Autoencoder>` models

@@ -137,7 +139,7 @@ class AETrainer:
          return total_loss / len(dataloader)

      @torch.no_grad
-     def encode(self, dataset: Dataset) -> torch.Tensor:
+     def encode(self, dataset: Dataset[Any]) -> torch.Tensor:
          """
          Create image :term:`embeddings<Embeddings>` for the dataset using the model's encoder.

@@ -174,134 +176,3 @@ class AETrainer:
              encodings = torch.vstack((encodings, embeddings)) if len(encodings) else embeddings

          return encodings
-
-
- class AriaAutoencoder(nn.Module):
-     """
-     An :term:`autoencoder<Autoencoder>` model with a separate encoder and decoder.
-
-     Parameters
-     ----------
-     channels : int, default 3
-         Number of input channels
-     """
-
-     def __init__(self, channels=3):
-         super().__init__()
-         self.encoder = Encoder(channels)
-         self.decoder = Decoder(channels)
-
-     def forward(self, x):
-         """
-         Perform a forward pass through the encoder and decoder.
-
-         Parameters
-         ----------
-         x : torch.Tensor
-             Input tensor
-
-         Returns
-         -------
-         torch.Tensor
-             The reconstructed output tensor.
-         """
-         x = self.encoder(x)
-         x = self.decoder(x)
-         return x
-
-     def encode(self, x):
-         """
-         Encode the input tensor using the encoder.
-
-         Parameters
-         ----------
-         x : torch.Tensor
-             Input tensor
-
-         Returns
-         -------
-         torch.Tensor
-             The encoded representation of the input tensor.
-         """
-         return self.encoder(x)
-
-
- class Encoder(nn.Module):
-     """
-     A simple encoder to be used in an :term:`autoencoder<Autoencoder>` model.
-
-     This is the encoder used by the AriaAutoencoder model.
-
-     Parameters
-     ----------
-     channels : int, default 3
-         Number of input channels
-     """
-
-     def __init__(self, channels=3):
-         super().__init__()
-         self.encoder = nn.Sequential(
-             nn.Conv2d(channels, 256, 2, stride=1, padding=1),
-             nn.ReLU(),
-             nn.MaxPool2d(2),
-             nn.Conv2d(256, 128, 2, stride=1, padding=1),
-             nn.ReLU(),
-             nn.MaxPool2d(2),
-             nn.Conv2d(128, 64, 2, stride=1),
-         )
-
-     def forward(self, x):
-         """
-         Perform a forward pass through the encoder.
-
-         Parameters
-         ----------
-         x : torch.Tensor
-             Input tensor
-
-         Returns
-         -------
-         torch.Tensor
-             The encoded representation of the input tensor.
-         """
-         return self.encoder(x)
-
-
- class Decoder(nn.Module):
-     """
-     A simple decoder to be used in an :term:`autoencoder<Autoencoder>` model.
-
-     This is the decoder used by the AriaAutoencoder model.
-
-     Parameters
-     ----------
-     channels : int
-         Number of output channels
-     """
-
-     def __init__(self, channels):
-         super().__init__()
-         self.decoder = nn.Sequential(
-             nn.ConvTranspose2d(64, 128, 2, stride=1),
-             nn.ReLU(),
-             nn.ConvTranspose2d(128, 256, 2, stride=2),
-             nn.ReLU(),
-             nn.ConvTranspose2d(256, channels, 2, stride=2),
-             nn.Sigmoid(),
-         )
-
-     def forward(self, x):
-         """
-         Perform a forward pass through the decoder.
-
-         Parameters
-         ----------
-         x : torch.Tensor
-             The encoded tensor.
-
-         Returns
-         -------
-         torch.Tensor
-             The reconstructed output tensor.
-         """
-         return self.decoder(x)
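A matching sketch of the retyped AETrainer. Only train/eval/encode and the "auto" device handling are visible in the hunks above; the constructor keyword names device and batch_size are assumptions inferred from the device check and `self.batch_size = batch_size`:

# Hypothetical sketch; dataset construction is illustrative only.
import torch
from torch.utils.data import TensorDataset
from dataeval.utils.torch.models import AriaAutoencoder
from dataeval.utils.torch.trainer import AETrainer

dataset = TensorDataset(torch.rand(64, 3, 32, 32))  # values in [0, 1] to match the Sigmoid output
trainer = AETrainer(AriaAutoencoder(channels=3), device="auto", batch_size=8)
history = trainer.train(dataset, epochs=5)  # list[float] loss history
score = trainer.eval(dataset)               # mean reconstruction loss (float)
embeddings = trainer.encode(dataset)        # torch.Tensor of encoder outputs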
dataeval/{_internal → utils/torch}/utils.py
@@ -1,12 +1,14 @@
  from __future__ import annotations

+ __all__ = ["read_dataset"]
+
  from collections import defaultdict
  from typing import Any

  from torch.utils.data import Dataset


- def read_dataset(dataset: Dataset) -> list[list[Any]]:
+ def read_dataset(dataset: Dataset[Any]) -> list[list[Any]]:
      """
      Extract information from a dataset at each index into individual lists of each information position

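Based only on the docstring above, read_dataset appears to transpose per-index tuples into per-position lists; a hedged illustration:

# Illustrative only; behavior inferred from the docstring, not from the diff body.
from torch.utils.data import Dataset
from dataeval.utils.torch.utils import read_dataset

class Pairs(Dataset):
    def __init__(self):
        self.items = [("img0", 0), ("img1", 1)]
    def __len__(self):
        return len(self.items)
    def __getitem__(self, idx):
        return self.items[idx]

print(read_dataset(Pairs()))  # expected: [["img0", "img1"], [0, 1]]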
dataeval/workflows/__init__.py
@@ -5,6 +5,6 @@ Workflows perform a sequence of actions to analyze the dataset and make predicti
  from dataeval import _IS_TORCH_AVAILABLE

  if _IS_TORCH_AVAILABLE:  # pragma: no cover
-     from dataeval._internal.workflows.sufficiency import Sufficiency, SufficiencyOutput
+     from dataeval.workflows.sufficiency import Sufficiency, SufficiencyOutput

  __all__ = ["Sufficiency", "SufficiencyOutput"]
dataeval/{_internal/workflows → workflows}/sufficiency.py
@@ -1,8 +1,10 @@
  from __future__ import annotations

+ __all__ = ["SufficiencyOutput", "Sufficiency"]
+
  import warnings
  from dataclasses import dataclass
- from typing import Any, Callable, Iterable, Mapping, Sequence, cast
+ from typing import Any, Callable, Generic, Iterable, Mapping, Sequence, TypeVar, cast

  import matplotlib.pyplot as plt
  import numpy as np
@@ -13,8 +15,8 @@ from numpy.typing import ArrayLike, NDArray
  from scipy.optimize import basinhopping
  from torch.utils.data import Dataset

- from dataeval._internal.interop import as_numpy
- from dataeval._internal.output import OutputMetadata, set_metadata
+ from dataeval.interop import as_numpy
+ from dataeval.output import OutputMetadata, set_metadata


  @dataclass(frozen=True)
@@ -36,7 +38,7 @@ class SufficiencyOutput(OutputMetadata):
      params: dict[str, NDArray[np.float64]]
      measures: dict[str, NDArray[np.float64]]

-     def __post_init__(self):
+     def __post_init__(self) -> None:
          c = len(self.steps)
          if set(self.params) != set(self.measures):
              raise ValueError("params and measures have a key mismatch")
@@ -45,7 +47,7 @@ class SufficiencyOutput(OutputMetadata):
              if c != c_v:
                  raise ValueError(f"{m} does not contain the expected number ({c}) of data points.")

-     @set_metadata("dataeval.workflows.SufficiencyOutput")
+     @set_metadata()
      def project(
          self,
          projection: int | Iterable[int],
@@ -170,7 +172,7 @@ class SufficiencyOutput(OutputMetadata):
          return projection


- def f_out(n_i: NDArray, x: NDArray) -> NDArray:
+ def f_out(n_i: NDArray[Any], x: NDArray[Any]) -> NDArray[Any]:
      """
      Calculates the line of best fit based on its free parameters

@@ -189,7 +191,7 @@ def f_out(n_i: NDArray, x: NDArray) -> NDArray:
      return x[0] * n_i ** (-x[1]) + x[2]


- def f_inv_out(y_i: NDArray, x: NDArray) -> NDArray[np.uint64]:
+ def f_inv_out(y_i: NDArray[Any], x: NDArray[Any]) -> NDArray[np.uint64]:
      """
      Inverse function for f_out()

@@ -209,7 +211,7 @@ def f_inv_out(y_i: NDArray, x: NDArray) -> NDArray[np.uint64]:
      return np.asarray(n_i, dtype=np.uint64)


- def calc_params(p_i: NDArray, n_i: NDArray, niter: int) -> NDArray:
+ def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[Any]:
      """
      Retrieves the inverse power curve coefficients for the line of best fit.
      Global minimization is done via basin hopping. More info on this algorithm
@@ -254,7 +256,7 @@ def calc_params(p_i: NDArray, n_i: NDArray, niter: int) -> NDArray:
      return res.x


- def reset_parameters(model: nn.Module):
+ def reset_parameters(model: nn.Module) -> nn.Module:
      """
      Re-initializes each layer in the model using
      the layer's defined weight_init function
@@ -272,7 +274,7 @@ def reset_parameters(model: nn.Module):
      return model.apply(fn=weight_reset)


- def validate_dataset_len(dataset: Dataset) -> int:
+ def validate_dataset_len(dataset: Dataset[Any]) -> int:
      if not hasattr(dataset, "__len__"):
          raise TypeError("Must provide a dataset with a length attribute")
      length: int = dataset.__len__()  # type: ignore
@@ -281,7 +283,7 @@ def validate_dataset_len(dataset: Dataset) -> int:
      return length


- def project_steps(params: NDArray, projection: NDArray) -> NDArray:
+ def project_steps(params: NDArray[Any], projection: NDArray[Any]) -> NDArray[Any]:
      """Projects the measures for each value of X

      Parameters
@@ -300,7 +302,7 @@ def project_steps(params: NDArray, projection: NDArray) -> NDArray:
      return 1 - f_out(projection, params)


- def inv_project_steps(params: NDArray, targets: NDArray) -> NDArray[np.uint64]:
+ def inv_project_steps(params: NDArray[Any], targets: NDArray[Any]) -> NDArray[np.uint64]:
      """Inverse function for project_steps()

      Parameters
@@ -320,7 +322,7 @@ def inv_project_steps(params: NDArray, targets: NDArray) -> NDArray[np.uint64]:
      return np.ceil(steps)


- def get_curve_params(measures: dict[str, NDArray], ranges: NDArray, niter: int) -> dict[str, NDArray]:
+ def get_curve_params(measures: dict[str, NDArray[Any]], ranges: NDArray[Any], niter: int) -> dict[str, NDArray[Any]]:
      """Calculates and aggregates parameters for both single and multi-class metrics"""
      output = {}
      for name, measure in measures.items():
@@ -337,10 +339,10 @@ def get_curve_params(measures: dict[str, NDArray], ranges: NDArray, niter: int)

  def plot_measure(
      name: str,
-     steps: NDArray,
-     measure: NDArray,
-     params: NDArray,
-     projection: NDArray,
+     steps: NDArray[Any],
+     measure: NDArray[Any],
+     params: NDArray[Any],
+     projection: NDArray[Any],
  ) -> Figure:
      fig = plt.figure()
      fig = cast(Figure, fig)
@@ -367,7 +369,10 @@ def plot_measure(
      return fig


- class Sufficiency:
+ T = TypeVar("T")
+
+
+ class Sufficiency(Generic[T]):
      """
      Project dataset :term:`sufficiency<Sufficiency>` using given a model and evaluation criteria

@@ -401,10 +406,10 @@ class Sufficiency:
      def __init__(
          self,
          model: nn.Module,
-         train_ds: Dataset,
-         test_ds: Dataset,
-         train_fn: Callable[[nn.Module, Dataset, Sequence[int]], None],
-         eval_fn: Callable[[nn.Module, Dataset], Mapping[str, float] | Mapping[str, ArrayLike]],
+         train_ds: Dataset[T],
+         test_ds: Dataset[T],
+         train_fn: Callable[[nn.Module, Dataset[T], Sequence[int]], None],
+         eval_fn: Callable[[nn.Module, Dataset[T]], Mapping[str, float] | Mapping[str, ArrayLike]],
          runs: int = 1,
          substeps: int = 5,
          train_kwargs: Mapping[str, Any] | None = None,
@@ -421,29 +426,29 @@ class Sufficiency:
          self.eval_kwargs = eval_kwargs

      @property
-     def train_ds(self):
+     def train_ds(self) -> Dataset[T]:
          return self._train_ds

      @train_ds.setter
-     def train_ds(self, value: Dataset):
+     def train_ds(self, value: Dataset[T]) -> None:
          self._train_ds = value
          self._length = validate_dataset_len(value)

      @property
-     def test_ds(self):
+     def test_ds(self) -> Dataset[T]:
          return self._test_ds

      @test_ds.setter
-     def test_ds(self, value: Dataset):
+     def test_ds(self, value: Dataset[T]) -> None:
          validate_dataset_len(value)
          self._test_ds = value

      @property
-     def train_fn(self) -> Callable[[nn.Module, Dataset, Sequence[int]], None]:
+     def train_fn(self) -> Callable[[nn.Module, Dataset[T], Sequence[int]], None]:
          return self._train_fn

      @train_fn.setter
-     def train_fn(self, value: Callable[[nn.Module, Dataset, Sequence[int]], None]):
+     def train_fn(self, value: Callable[[nn.Module, Dataset[T], Sequence[int]], None]) -> None:
          if not callable(value):
              raise TypeError("Must provide a callable for train_fn.")
          self._train_fn = value
@@ -451,14 +456,14 @@ class Sufficiency:
      @property
      def eval_fn(
          self,
-     ) -> Callable[[nn.Module, Dataset], dict[str, float] | Mapping[str, ArrayLike]]:
+     ) -> Callable[[nn.Module, Dataset[T]], dict[str, float] | Mapping[str, ArrayLike]]:
          return self._eval_fn

      @eval_fn.setter
      def eval_fn(
          self,
-         value: Callable[[nn.Module, Dataset], dict[str, float] | Mapping[str, ArrayLike]],
-     ):
+         value: Callable[[nn.Module, Dataset[T]], dict[str, float] | Mapping[str, ArrayLike]],
+     ) -> None:
          if not callable(value):
              raise TypeError("Must provide a callable for eval_fn.")
          self._eval_fn = value
@@ -468,7 +473,7 @@ class Sufficiency:
          return self._train_kwargs

      @train_kwargs.setter
-     def train_kwargs(self, value: Mapping[str, Any] | None):
+     def train_kwargs(self, value: Mapping[str, Any] | None) -> None:
          self._train_kwargs = {} if value is None else value

      @property
@@ -476,10 +481,10 @@ class Sufficiency:
          return self._eval_kwargs

      @eval_kwargs.setter
-     def eval_kwargs(self, value: Mapping[str, Any] | None):
+     def eval_kwargs(self, value: Mapping[str, Any] | None) -> None:
          self._eval_kwargs = {} if value is None else value

-     @set_metadata("dataeval.workflows", ["runs", "substeps"])
+     @set_metadata(["runs", "substeps"])
      def evaluate(self, eval_at: int | Iterable[int] | None = None, niter: int = 1000) -> SufficiencyOutput:
          """
          Creates data indices, trains models, and returns plotting data
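To show the new Generic[T] typing end to end, a hedged sketch of the retyped workflow. Only the signatures in the hunks above come from the diff; the model, datasets, and metric are stand-ins:

# Hypothetical sketch of the 0.72.2 Sufficiency API.
from typing import Any, Mapping, Sequence

import torch
import torch.nn as nn
from torch.utils.data import Dataset, TensorDataset

from dataeval.workflows import Sufficiency

def custom_train(model: nn.Module, dataset: Dataset[Any], indices: Sequence[int]) -> None:
    pass  # train `model` on the subset of `dataset` selected by `indices` (no-op placeholder)

def custom_eval(model: nn.Module, dataset: Dataset[Any]) -> Mapping[str, float]:
    return {"Accuracy": 0.5}  # placeholder metric

suff = Sufficiency(
    model=nn.Linear(16, 2),
    train_ds=TensorDataset(torch.rand(100, 16)),
    test_ds=TensorDataset(torch.rand(20, 16)),
    train_fn=custom_train,
    eval_fn=custom_eval,
    runs=3,
    substeps=5,
)
output = suff.evaluate(niter=1000)              # SufficiencyOutput with steps/params/measures
projected = output.project([1000, 2000, 4000])  # extrapolate along the fitted power-law curve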
{dataeval-0.72.1.dist-info → dataeval-0.72.2.dist-info}/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: dataeval
- Version: 0.72.1
+ Version: 0.72.2
  Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
  Home-page: https://dataeval.ai/
  License: MIT
@@ -36,6 +36,7 @@ Requires-Dist: tf-keras (>=2.16) ; extra == "tensorflow" or extra == "all"
  Requires-Dist: torch (>=2.2.0) ; extra == "torch" or extra == "all"
  Requires-Dist: torchvision (>=0.17.0) ; extra == "torch" or extra == "all"
  Requires-Dist: tqdm
+ Requires-Dist: typing-extensions (>=4.12) ; python_version >= "3.9" and python_version < "3.10"
  Requires-Dist: xxhash (>=3.3)
  Project-URL: Documentation, https://dataeval.readthedocs.io/
  Project-URL: Repository, https://github.com/aria-ml/dataeval/
dataeval-0.72.2.dist-info/RECORD
@@ -0,0 +1,72 @@
+ dataeval/__init__.py,sha256=UYhkwned7TR5hiU_c8I_qUaKogO1EODTBgT-9_t0ofI,641
+ dataeval/detectors/__init__.py,sha256=xdp8LYOFjV5tVbAwu0Y03KU9EajHkSFy_M3raqbxpDc,383
+ dataeval/detectors/drift/__init__.py,sha256=MRPWFOaoVoqAHW36nA5F3wk7QXJU4oecND2RbtgG9oY,757
+ dataeval/detectors/drift/base.py,sha256=0S-0MFpIFaJ4_8IGreFKSmyna2L50FBn7DVaoNWmw8E,14509
+ dataeval/detectors/drift/cvm.py,sha256=kc59w2_wtxFGNnLcaJRvX5v_38gPXiebSGNiFVdunEQ,4142
+ dataeval/detectors/drift/ks.py,sha256=gcpe1WIQeNeZdLYkdMZCFLXUp1bHMQUxwJE6-RLVOXs,4229
+ dataeval/detectors/drift/mmd.py,sha256=TqGOnUNYKwpS0GQPV3dSl-_qRa0g2flmoQ-dxzW_JfY,7586
+ dataeval/detectors/drift/torch.py,sha256=D46J72OPW8-PpP3w9ODMBfcDSdailIgVjgHVFpbYfws,11649
+ dataeval/detectors/drift/uncertainty.py,sha256=Xz2yzJjtJfw1vLag234jwRvaa_HK36nMajGx8bQaNRs,5322
+ dataeval/detectors/drift/updates.py,sha256=UJ0z5hlunRi7twnkLABfdJG3tT2EqX4y9IGx8_USYvo,1780
+ dataeval/detectors/linters/__init__.py,sha256=BvpaB1RUpkEhhXk3Mqi5NYoOcJKZRFSBOJCmQOIfYRU,483
+ dataeval/detectors/linters/clusterer.py,sha256=OtBE5rglAGdTTQRmKUHP6J-uWmnh2E3lZxeqJCnc87U,21014
+ dataeval/detectors/linters/duplicates.py,sha256=tOD43rJkvheIA3mznbUqHhft2yD3xRZQdCt61daIca4,5665
+ dataeval/detectors/linters/merged_stats.py,sha256=X-bDTwjyR8RuVmzxLaHZmQ5nI3oOWvsqVlitdSncapk,1355
+ dataeval/detectors/linters/outliers.py,sha256=BUVvtbKHo04KnRmrgb84MBr0l1gtcY3-xNCHjetFrEQ,10117
+ dataeval/detectors/ood/__init__.py,sha256=FVyVuaxVKAOgSTaaBf-j2OXXDarSBFcJ7CTlMV6w88s,661
+ dataeval/detectors/ood/ae.py,sha256=cdwrgCpQkueK_HQoQbeXw7s0oTE-6FKVtXe9vETDe5M,2117
+ dataeval/detectors/ood/aegmm.py,sha256=jK5aN1UjwwZaSLB3BpzH25eLp5wBqzlgylsfphaoZaE,1814
+ dataeval/detectors/ood/base.py,sha256=S9jl4xH2zB_-ixalysQJZEvRCGOqMQSruacvfd4Dnfc,8687
+ dataeval/detectors/ood/llr.py,sha256=HUNsro-cV7RR5Mht6pJ4NWCRR7aWeVdjwkBNurs5LbM,10378
+ dataeval/detectors/ood/metadata_ks_compare.py,sha256=jH7uDwyyBIIcTrRhQEdnLAdrwf7LfNczKBw0CpJyF5c,4282
+ dataeval/detectors/ood/metadata_least_likely.py,sha256=nxMCXUOjOfWHDTGT2SLE7OYBCydRq8zHLd8t17k7hMM,5193
+ dataeval/detectors/ood/metadata_ood_mi.py,sha256=KLay2BmgHrStBV92VpIs_B1yEfQKllsMTgzOQEng01I,4065
+ dataeval/detectors/ood/vae.py,sha256=O1jpGkpavtJAqn4WrmocPRMtkX4iSdkpiCDUPBF1Ano,2925
+ dataeval/detectors/ood/vaegmm.py,sha256=37epPiQKeicy6SZD0D7O7hCFQSajZ-8wvga1pmJiq2s,2183
+ dataeval/interop.py,sha256=CFtGyVTwTqkJFkNfhHYhnBRVwxKIQ9f-9Zuuz_uQDqo,1589
+ dataeval/metrics/__init__.py,sha256=fPBNLd-T6mCErZBBJrxWmXIL0jCk7fNUYIcNEBkMa80,238
+ dataeval/metrics/bias/__init__.py,sha256=puf645-hAO5hFHNHlZ239TPopqWIoN-uLGXFB8-hA_o,599
+ dataeval/metrics/bias/balance.py,sha256=pgxaIqFvRcygYlAUbM_BKrbi45WU7fRV08HBrI7Z5q4,8569
+ dataeval/metrics/bias/coverage.py,sha256=Ku9l-qvc6YrRiQ0PRzkpfjInyOhkAKKSO_bf_LnOwNg,3623
+ dataeval/metrics/bias/diversity.py,sha256=-cmh-vyAUrn4rbn6-ZXvLuaO43Ncj28GKyeTmhWRzfE,8973
+ dataeval/metrics/bias/metadata.py,sha256=nUZRwhcKaJM0GVwXn5k11Fa1s56_OtOBF7tmXjMDpsM,8919
+ dataeval/metrics/bias/parity.py,sha256=uJ3p8m6id5mZpDNnS1NmxCThb5V6v75lJv_0TGAhCRA,16668
+ dataeval/metrics/estimators/__init__.py,sha256=O6ocxJq8XDkfJWwXeJnnnzbOyRnFPKF4kTIVTTZYOA8,380
+ dataeval/metrics/estimators/ber.py,sha256=SVT-BIC_GLs0l2l2NhWu4OpRbgn96w-OwTSoPHTnQbE,5037
+ dataeval/metrics/estimators/divergence.py,sha256=pImaa216-YYTgGWDCSTcpJrC-dfl7150yVrPfW_TyGc,4293
+ dataeval/metrics/estimators/uap.py,sha256=Tz1VZOyUa68HlTh94Rl-wnXCWdTAVjTQc3LtSPEWVu4,2175
+ dataeval/metrics/stats/__init__.py,sha256=igLRaAt1nX6yRwC4xI0zNPBADi3u7EsSxWP3OZ8AqcU,1086
+ dataeval/metrics/stats/base.py,sha256=9M5g2FAWvd50HT-T2h-MCmYLpvk--em_yWro1qWGHFs,12177
+ dataeval/metrics/stats/boxratiostats.py,sha256=iNr-FdppiJ7XAeeLY-o7gL_PSxvT8j86iwRijKca2Eg,6465
+ dataeval/metrics/stats/datasetstats.py,sha256=LAMFCIS9v0RjLrdKUFuo8nY-3HLVvRlqQIXGMKtsHEw,6255
+ dataeval/metrics/stats/dimensionstats.py,sha256=xdTp2AbGH3xefUUsB4sDjgSKiojJ73DCHyuCOPKsErc,4056
+ dataeval/metrics/stats/hashstats.py,sha256=X6aSouaMhDcGZMLuCTje3G4QOr2i-Td6H3SyBFDF6mA,4960
+ dataeval/metrics/stats/labelstats.py,sha256=BKwSmyxCr2wYq8IMraCUS-b5wqacfT_BukJUYNfqeCo,4114
+ dataeval/metrics/stats/pixelstats.py,sha256=x90O10IqVjEORtYwueFLvJnVYTxhPBOOx5HMweBQnJY,4578
+ dataeval/metrics/stats/visualstats.py,sha256=y0xIvst7epcajk8vz2jngiAiz0T7DZC-M97Rs1-vV9I,4950
+ dataeval/output.py,sha256=jWXXNxFNBEaY1rN7Z-6LZl6bQT-I7z_wqr91Rhrdt_0,3061
+ dataeval/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ dataeval/utils/__init__.py,sha256=zTgPsmloPy0qZMzb4xipNNdIWpaHtseGph68pIAD-hQ,684
+ dataeval/utils/image.py,sha256=KgC_1nW__nGN5q6bVZNvG4U_qIBdjcPATz9qe8f2XuA,1928
+ dataeval/utils/shared.py,sha256=BvEeYPMNQTmx4LSaImGeC0VkvcbEY3Byqtxa-jQ3xgc,3623
+ dataeval/utils/split_dataset.py,sha256=IopyxwC3FaZwgVriW4OXze-mDMpOlvRr83OADA5Jydk,19454
+ dataeval/utils/tensorflow/__init__.py,sha256=l4OjIA75JJXeNWDCkST1xtDMVYsw97lZ-9JXFBlyuYg,539
+ dataeval/utils/tensorflow/_internal/autoencoder.py,sha256=-pm4VqMEjHcrgre-K8uhMvaEVHyeqZsZbejrnlM6OtY,10430
+ dataeval/utils/tensorflow/_internal/gmm.py,sha256=QoEgbeax1GETqRmUF7A2ih9uFOZfFAjGzgH2ljExlAc,3669
+ dataeval/utils/tensorflow/_internal/loss.py,sha256=IXW_kxovLaTLd6UkMOIQLPEAGrOMILHDKagvRYgj-DE,4065
+ dataeval/utils/tensorflow/_internal/pixelcnn.py,sha256=Aa7koa7YxqhHmFequpsfMw2-61KO03evWWcvvFTuaco,48518
+ dataeval/utils/tensorflow/_internal/trainer.py,sha256=ld7pisl4ZXjEA6nxBStRNDEuNJme0IPo08oWqal6bYc,4167
+ dataeval/utils/tensorflow/_internal/utils.py,sha256=k1mjy44oE63SIkckvU8BTlqtWsCnGynJF4eYyw1pebQ,8799
+ dataeval/utils/tensorflow/loss/__init__.py,sha256=Q-66vt91Oe1ByYfo28tW32zXDq2MqQ2gngWgmIVmof8,227
+ dataeval/utils/torch/__init__.py,sha256=lpkqfgyARUxgrV94cZESQv8PIP2p-UnwItZ_wIr0XzQ,675
+ dataeval/utils/torch/blocks.py,sha256=HVhBTMMD5NA4qheMUgyol1KWiKZDIuc8k5j4RcMKmhk,1466
+ dataeval/utils/torch/datasets.py,sha256=9YV9-Uhq6NCMuu1hPhMnQXjmeI-Ld8ve1z_haxre88o,15023
+ dataeval/utils/torch/models.py,sha256=0BsXmLK8W1OZ8nnEGb1f9LzIeCgtevQC37dvKS1v1vA,3236
+ dataeval/utils/torch/trainer.py,sha256=EraOKiXxiMNiycStZNMR5yRz3ehgp87d9ewR9a9dV4w,5559
+ dataeval/utils/torch/utils.py,sha256=FI4LJ6DvXFQJVff8fxSCP7LRkp8H9BIUgYX0kk7_Cuo,1537
+ dataeval/workflows/__init__.py,sha256=x2JnOoKmLUCZOsB6RNPqMdVvxEb6Hpda5GPJnD_k0v0,310
+ dataeval/workflows/sufficiency.py,sha256=1jSYhH9i4oesmJYs5PZvWS1LGXf8ekOgNhpFtMPLPXk,18552
+ dataeval-0.72.2.dist-info/LICENSE.txt,sha256=Kpzcfobf1HlqafF-EX6dQLw9TlJiaJzfgvLQFukyXYw,1060
+ dataeval-0.72.2.dist-info/METADATA,sha256=ddOmTZA6nX7VceQhOmyQ-cQ1aBv2VU9Za32vnmjP-VE,4702
+ dataeval-0.72.2.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
+ dataeval-0.72.2.dist-info/RECORD,,