dataeval 0.82.0__py3-none-any.whl → 0.82.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. dataeval/__init__.py +1 -1
  2. dataeval/config.py +68 -11
  3. dataeval/detectors/drift/_mmd.py +9 -9
  4. dataeval/detectors/drift/_torch.py +7 -7
  5. dataeval/detectors/drift/_uncertainty.py +4 -4
  6. dataeval/detectors/linters/duplicates.py +3 -3
  7. dataeval/detectors/linters/outliers.py +3 -3
  8. dataeval/detectors/ood/ae.py +5 -4
  9. dataeval/detectors/ood/base.py +2 -2
  10. dataeval/detectors/ood/metadata_ood_mi.py +4 -6
  11. dataeval/detectors/ood/mixin.py +1 -1
  12. dataeval/detectors/ood/vae.py +2 -1
  13. dataeval/metadata/_distance.py +11 -44
  14. dataeval/metadata/_ood.py +9 -7
  15. dataeval/metrics/bias/_balance.py +7 -3
  16. dataeval/metrics/bias/_diversity.py +3 -0
  17. dataeval/metrics/bias/_parity.py +2 -0
  18. dataeval/metrics/stats/_base.py +3 -3
  19. dataeval/metrics/stats/_boxratiostats.py +1 -1
  20. dataeval/metrics/stats/_imagestats.py +4 -4
  21. dataeval/outputs/__init__.py +4 -0
  22. dataeval/outputs/_base.py +50 -21
  23. dataeval/outputs/_bias.py +1 -1
  24. dataeval/outputs/_linters.py +4 -2
  25. dataeval/outputs/_metadata.py +54 -0
  26. dataeval/outputs/_stats.py +12 -6
  27. dataeval/utils/data/_embeddings.py +8 -9
  28. dataeval/utils/data/_metadata.py +16 -7
  29. dataeval/utils/data/_selection.py +4 -8
  30. dataeval/utils/data/_split.py +3 -2
  31. dataeval/utils/data/selections/_classfilter.py +5 -3
  32. dataeval/utils/torch/_internal.py +5 -5
  33. dataeval/utils/torch/trainer.py +8 -8
  34. {dataeval-0.82.0.dist-info → dataeval-0.82.1.dist-info}/METADATA +1 -1
  35. {dataeval-0.82.0.dist-info → dataeval-0.82.1.dist-info}/RECORD +37 -36
  36. {dataeval-0.82.0.dist-info → dataeval-0.82.1.dist-info}/LICENSE.txt +0 -0
  37. {dataeval-0.82.0.dist-info → dataeval-0.82.1.dist-info}/WHEEL +0 -0
dataeval/metrics/stats/_base.py CHANGED
@@ -248,13 +248,13 @@ def add_stats(a: TStatsOutput, b: TStatsOutput) -> TStatsOutput:
     if type(a) is not type(b):
         raise TypeError(f"Types {type(a)} and {type(b)} cannot be added.")
 
-    sum_dict = deepcopy(a.dict())
+    sum_dict = deepcopy(a.data())
 
     for k in sum_dict:
         if isinstance(sum_dict[k], list):
-            sum_dict[k].extend(b.dict()[k])
+            sum_dict[k].extend(b.data()[k])
         else:
-            sum_dict[k] = np.concatenate((sum_dict[k], b.dict()[k]))
+            sum_dict[k] = np.concatenate((sum_dict[k], b.data()[k]))
 
     return type(a)(**sum_dict)
 
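To illustrate the `dict()` → `data()` rename above, a minimal sketch of merging two stats outputs; `images_a`/`images_b` are placeholder datasets, and `add_stats` lives in the private module shown in this hunk:

    from dataeval.metrics.stats import pixelstats
    from dataeval.metrics.stats._base import add_stats  # private module per this diff

    stats_a = pixelstats(images_a)  # images_a / images_b are hypothetical datasets
    stats_b = pixelstats(images_b)

    combined = add_stats(stats_a, stats_b)  # list fields extended, arrays concatenated
    print(combined.data().keys())           # was combined.dict() in 0.82.0
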
dataeval/metrics/stats/_boxratiostats.py CHANGED
@@ -153,7 +153,7 @@ def boxratiostats(
         raise ValueError("Input for boxstats and imgstats must have matching channel information.")
 
     output_dict = {}
-    for key in boxstats.dict():
+    for key in boxstats.data():
         output_dict[key] = calculate_ratios(key, boxstats, imgstats)
 
     return output_cls(**output_dict)
dataeval/metrics/stats/_imagestats.py CHANGED
@@ -42,8 +42,8 @@ def imagestats(
     Calculates various :term:`statistics<Statistics>` for each image.
 
     This function computes dimension, pixel and visual metrics
-    on the images or individual bounding boxes for each image as
-    well as label statistics if provided.
+    on the images or individual bounding boxes for each image. If
+    performing calculations per channel dimension stats are excluded.
 
     Parameters
     ----------
@@ -61,7 +61,7 @@ def imagestats(
 
     See Also
     --------
-    dimensionstats, labelstats, pixelstats, visualstats, Outliers
+    dimensionstats, pixelstats, visualstats
 
     Examples
     --------
@@ -91,4 +91,4 @@ def imagestats(
     output_cls = ImageStatsOutput
 
     outputs = run_stats(dataset, per_box, per_channel, processors)
-    return output_cls(**{k: v for d in outputs for k, v in d.dict().items()})
+    return output_cls(**{k: v for d in outputs for k, v in d.data().items()})
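
A sketch of the documented behavior; the dataset name is a placeholder and the keyword usage is assumed from the `run_stats(dataset, per_box, per_channel, processors)` call above:

    from dataeval.metrics.stats import imagestats

    results = imagestats(dataset, per_channel=True)  # dimension stats excluded per channel
    print(list(results.data()))                      # field names of the computed stats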
dataeval/outputs/__init__.py CHANGED
@@ -8,6 +8,7 @@ from ._bias import BalanceOutput, CoverageOutput, DiversityOutput, LabelParityOu
 from ._drift import DriftMMDOutput, DriftOutput
 from ._estimators import BEROutput, ClustererOutput, DivergenceOutput, UAPOutput
 from ._linters import DuplicatesOutput, OutliersOutput
+from ._metadata import MetadataDistanceOutput, MetadataDistanceValues, MostDeviatedFactorsOutput
 from ._ood import OODOutput, OODScoreOutput
 from ._stats import (
     ChannelStatsOutput,
@@ -39,6 +40,9 @@ __all__ = [
     "ImageStatsOutput",
     "LabelParityOutput",
     "LabelStatsOutput",
+    "MetadataDistanceOutput",
+    "MetadataDistanceValues",
+    "MostDeviatedFactorsOutput",
     "OODOutput",
     "OODScoreOutput",
     "OutliersOutput",
dataeval/outputs/_base.py CHANGED
@@ -4,11 +4,11 @@ __all__ = []
 
 import inspect
 import logging
-from collections.abc import Mapping
+from collections.abc import Collection, Mapping, Sequence
 from dataclasses import dataclass
 from datetime import datetime, timezone
 from functools import partial, wraps
-from typing import Any, Callable, Iterator, TypeVar
+from typing import Any, Callable, Generic, Iterator, TypeVar, overload
 
 import numpy as np
 from typing_extensions import ParamSpec
@@ -56,16 +56,13 @@ class ExecutionMetadata:
     )
 
 
-class Output:
-    _meta: ExecutionMetadata | None = None
+T = TypeVar("T", covariant=True)
 
-    def __str__(self) -> str:
-        return f"{self.__class__.__name__}: {str(self.dict())}"
 
-    def dict(self) -> dict[str, Any]:
-        return {k: v for k, v in self.__dict__.items() if k != "_meta"}
+class GenericOutput(Generic[T]):
+    _meta: ExecutionMetadata | None = None
 
-    @property
+    def data(self) -> T: ...
     def meta(self) -> ExecutionMetadata:
         """
         Metadata about the execution of the function or method for the Output class.
@@ -73,34 +70,66 @@ class Output:
         return self._meta or ExecutionMetadata.empty()
 
 
-TKey = TypeVar("TKey", str, int, float, set)
-TValue = TypeVar("TValue")
+class Output(GenericOutput[dict[str, Any]]):
+    def data(self) -> dict[str, Any]:
+        return {k: v for k, v in self.__dict__.items() if k != "_meta"}
 
+    def __repr__(self) -> str:
+        return str(self)
 
-class MappingOutput(Mapping[TKey, TValue], Output):
+    def __str__(self) -> str:
+        return f"{self.__class__.__name__}({', '.join([f'{k}={v}' for k, v in self.data().items()])})"
+
+
+class BaseCollectionMixin(Collection[Any]):
     __slots__ = ["_data"]
 
+    def data(self) -> Any:
+        return self._data
+
+    def __len__(self) -> int:
+        return len(self._data)
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({repr(self._data)})"
+
+    def __str__(self) -> str:
+        return str(self._data)
+
+
+TKey = TypeVar("TKey", str, int, float, set)
+TValue = TypeVar("TValue")
+
+
+class MappingOutput(Mapping[TKey, TValue], BaseCollectionMixin, GenericOutput[Mapping[TKey, TValue]]):
     def __init__(self, data: Mapping[TKey, TValue]):
         self._data = data
 
     def __getitem__(self, key: TKey) -> TValue:
-        return self._data.__getitem__(key)
+        return self._data[key]
 
     def __iter__(self) -> Iterator[TKey]:
-        return self._data.__iter__()
+        return iter(self._data)
 
-    def __len__(self) -> int:
-        return self._data.__len__()
 
-    def dict(self) -> dict[str, TValue]:
-        return {str(k): v for k, v in self._data.items()}
+class SequenceOutput(Sequence[TValue], BaseCollectionMixin, GenericOutput[Sequence[TValue]]):
+    def __init__(self, data: Sequence[TValue]):
+        self._data = data
+
+    @overload
+    def __getitem__(self, index: int) -> TValue: ...
+    @overload
+    def __getitem__(self, index: slice) -> Sequence[TValue]: ...
 
-    def __str__(self) -> str:
-        return str(self.dict())
+    def __getitem__(self, index: int | slice) -> TValue | Sequence[TValue]:
+        return self._data[index]
+
+    def __iter__(self) -> Iterator[TValue]:
+        return iter(self._data)
 
 
 P = ParamSpec("P")
-R = TypeVar("R", bound=Output)
+R = TypeVar("R", bound=GenericOutput)
 
 
 def set_metadata(fn: Callable[P, R] | None = None, *, state: list[str] | None = None) -> Callable[P, R]:
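
A minimal sketch of the reworked hierarchy: `Output.data()` returns the public fields as a dict, while the new collection outputs return the wrapped container and behave as standard `Mapping`/`Sequence` types. The subclass below is hypothetical, for illustration only:

    from dataeval.outputs._base import MappingOutput

    class WordCounts(MappingOutput[str, int]):  # hypothetical subclass
        pass

    out = WordCounts({"cat": 3, "dog": 1})
    assert out["cat"] == 3                      # Mapping protocol via __getitem__
    assert len(out) == 2                        # __len__ from BaseCollectionMixin
    assert out.data() == {"cat": 3, "dog": 1}   # replaces the old dict() accessor
    print(repr(out))                            # WordCounts({'cat': 3, 'dog': 1})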
dataeval/outputs/_bias.py CHANGED
@@ -364,7 +364,7 @@ class DiversityOutput(Output):
                 col_labels,
                 xlabel="Factors",
                 ylabel="Class",
-                cbarlabel=f"Normalized {asdict(self.meta)['arguments']['method'].title()} Index",
+                cbarlabel=f"Normalized {asdict(self.meta())['arguments']['method'].title()} Index",
             )
 
         else:
dataeval/outputs/_linters.py CHANGED
@@ -24,7 +24,7 @@ TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, list[IndexIssueMap])
 
 
 @dataclass(frozen=True)
-class DuplicatesOutput(Generic[TIndexCollection], Output):
+class DuplicatesOutput(Output, Generic[TIndexCollection]):
     """
     Output class for :class:`.Duplicates` lint detector.
 
@@ -35,6 +35,8 @@ class DuplicatesOutput(Generic[TIndexCollection], Output):
     near: list[list[int] | dict[int, list[int]]]
         Indices of images that are near matches
 
+    Notes
+    -----
     - For a single dataset, indices are returned as a list of index groups.
     - For multiple datasets, indices are returned as dictionaries where the key is the
       index of the dataset, and the value is the list index groups from that dataset.
@@ -99,7 +101,7 @@ def _create_pandas_dataframe(class_wise):
 
 
 @dataclass(frozen=True)
-class OutliersOutput(Generic[TIndexIssueMap], Output):
+class OutliersOutput(Output, Generic[TIndexIssueMap]):
     """
     Output class for :class:`.Outliers` lint detector.
 
dataeval/outputs/_metadata.py ADDED
@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+__all__ = []
+
+from typing import NamedTuple
+
+from dataeval.outputs._base import MappingOutput, SequenceOutput
+
+
+class MostDeviatedFactorsOutput(SequenceOutput[tuple[str, float]]):
+    """
+    Output class for results of :func:`.most_deviated_factors` for OOD samples with metadata.
+
+    Attributes
+    ----------
+    value : tuple[str, float]
+        A tuple of the factor name and deviation of the highest metadata deviation
+    """
+
+
+class MetadataDistanceValues(NamedTuple):
+    """
+    Statistics comparing metadata distance.
+
+    Attributes
+    ----------
+    statistic : float
+        the KS statistic
+    location : float
+        The value at which the KS statistic has its maximum, measured in IQR-normalized units relative
+        to the median of the reference distribution.
+    dist : float
+        The Earth Mover's Distance normalized by the interquartile range (IQR) of the reference
+    pvalue : float
+        The p-value from the KS two-sample test
+    """
+
+    statistic: float
+    location: float
+    dist: float
+    pvalue: float
+
+
+class MetadataDistanceOutput(MappingOutput[str, MetadataDistanceValues]):
+    """
+    Output class for results of ks_2samp featurewise comparisons of new metadata to reference metadata.
+
+    Attributes
+    ----------
+    key : str
+        Metadata feature names
+    value : :class:`.MetadataDistanceValues`
+        Output per feature name containing the statistic, statistic location, distance, and pvalue.
+    """
dataeval/outputs/_stats.py CHANGED
@@ -4,7 +4,7 @@ __all__ = []
 
 import contextlib
 from dataclasses import dataclass
-from typing import Iterable, Optional, Union
+from typing import Any, Iterable, Optional, Union
 
 import numpy as np
 from numpy.typing import NDArray
@@ -63,7 +63,7 @@ class BaseStatsOutput(Output):
 
     def __post_init__(self) -> None:
         length = len(self.source_index)
-        bad = {k: len(v) for k, v in self.dict().items() if k not in [SOURCE_INDEX, BOX_COUNT] and len(v) != length}
+        bad = {k: len(v) for k, v in self.data().items() if k not in [SOURCE_INDEX, BOX_COUNT] and len(v) != length}
        if bad:
            raise ValueError(f"All values must have the same length as source_index. Bad values: {str(bad)}.")
 
@@ -105,7 +105,7 @@ class BaseStatsOutput(Output):
     def _get_channels(
         self, channel_limit: int | None = None, channel_index: int | Iterable[int] | None = None
     ) -> tuple[int, list[bool] | None]:
-        source_index = self.dict()[SOURCE_INDEX]
+        source_index = self.data()[SOURCE_INDEX]
         raw_channels = int(max([si.channel or 0 for si in source_index])) + 1
         if isinstance(channel_index, int):
             max_channels = 1 if channel_index < raw_channels else raw_channels
@@ -127,15 +127,21 @@ class BaseStatsOutput(Output):
 
         return max_channels, ch_mask
 
+    def factors(self) -> dict[str, NDArray[Any]]:
+        return {
+            k: v
+            for k, v in self.data().items()
+            if k not in (SOURCE_INDEX, BOX_COUNT) and isinstance(v, np.ndarray) and v[v != 0].size > 0 and v.ndim == 1
+        }
+
     def plot(
         self, log: bool, channel_limit: int | None = None, channel_index: int | Iterable[int] | None = None
     ) -> None:
         max_channels, ch_mask = self._get_channels(channel_limit, channel_index)
-        d = {k: v for k, v in self.dict().items() if isinstance(v, np.ndarray) and v[v != 0].size > 0 and v.ndim == 1}
         if max_channels == 1:
-            histogram_plot(d, log)
+            histogram_plot(self.factors(), log)
         else:
-            channel_histogram_plot(d, log, max_channels, ch_mask)
+            channel_histogram_plot(self.factors(), log, max_channels, ch_mask)
 
 
 @dataclass(frozen=True)
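
The new `factors()` helper extracted from `plot()` is also callable directly; a sketch, with the dataset as a placeholder:

    from dataeval.metrics.stats import pixelstats

    stats = pixelstats(dataset)     # dataset is a hypothetical image dataset
    plottable = stats.factors()     # only 1-D, nonzero ndarray fields survive the filter
    print(sorted(plottable))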
dataeval/utils/data/_embeddings.py CHANGED
@@ -9,7 +9,7 @@ import torch
 from torch.utils.data import DataLoader, Subset
 from tqdm import tqdm
 
-from dataeval.config import get_device
+from dataeval.config import DeviceLike, get_device
 from dataeval.typing import Array, Dataset
 from dataeval.utils.torch.models import SupportsEncode
 
@@ -24,13 +24,14 @@ class Embeddings:
     ----------
     dataset : ImageClassificationDataset or ObjectDetectionDataset
         Dataset to access original images from.
-    batch_size : int, optional
+    batch_size : int
         Batch size to use when encoding images.
-    model : torch.nn.Module, optional
+    model : torch.nn.Module or None, default None
         Model to use for encoding images.
-    device : torch.device, optional
-        Device to use for encoding images.
-    verbose : bool, optional
+    device : DeviceLike or None, default None
+        The hardware device to use if specified, otherwise uses the DataEval
+        default or torch default.
+    verbose : bool, default False
        Whether to print progress bar when encoding images.
     """
 
@@ -42,9 +43,8 @@ class Embeddings:
         self,
         dataset: Dataset[tuple[Array, Any, Any]],
         batch_size: int,
-        indices: Sequence[int] | None = None,
         model: torch.nn.Module | None = None,
-        device: torch.device | str | None = None,
+        device: DeviceLike | None = None,
         verbose: bool = False,
     ) -> None:
         self.device = get_device(device)
@@ -52,7 +52,6 @@ class Embeddings:
         self.verbose = verbose
 
         self._dataset = dataset
-        self._indices = indices if indices is not None else range(len(dataset))
         model = torch.nn.Flatten() if model is None else model
         self._model = model.to(self.device).eval()
         self._encoder = model.encode if isinstance(model, SupportsEncode) else model
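
A sketch of the updated constructor, assuming `Embeddings` is re-exported from `dataeval.utils.data`; note the `indices` parameter is gone and `device` now accepts any `DeviceLike`:

    from dataeval.utils.data import Embeddings

    embeddings = Embeddings(
        dataset,        # placeholder dataset of (image, target, metadata) tuples
        batch_size=64,
        model=None,     # defaults to torch.nn.Flatten()
        device=None,    # resolved via get_device: DataEval default or torch default
        verbose=True,
    )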
dataeval/utils/data/_metadata.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 __all__ = []
 
 import warnings
-from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, cast
 
 import numpy as np
 from numpy.typing import NDArray
@@ -11,6 +11,7 @@ from numpy.typing import NDArray
 from dataeval.typing import (
     AnnotatedDataset,
     Array,
+    ArrayLike,
     ObjectDetectionTarget,
 )
 from dataeval.utils._array import as_numpy, to_numpy
@@ -276,16 +277,12 @@ class Metadata:
         if self._processed and not force:
             return
 
-        # Trigger collate and merge if not yet done
-        self._collate()
-        self._merge()
+        # Create image indices from targets
+        self._image_indices = np.arange(len(self.raw)) if self.targets.source is None else self.targets.source
 
         # Validate the metadata dimensions
         self._validate()
 
-        # Create image indices from targets
-        self._image_indices = np.arange(len(self.raw)) if self.targets.source is None else self.targets.source
-
         # Include specified metadata keys
         if self.include:
             metadata = {i: self.merged[i] for i in self.include if i in self.merged}
@@ -358,3 +355,15 @@ class Metadata:
         )
         self._total_num_factors = len(self._discrete_factor_names + self._continuous_factor_names) + 1
         self._processed = True
+
+    def add_factors(self, factors: Mapping[str, ArrayLike]) -> None:
+        self._merge()
+        self._processed = False
+        target_len = len(self.targets.source) if self.targets.source is not None else len(self.targets)
+        if any(len(v) != target_len for v in factors.values()):
+            raise ValueError(
+                "The lists/arrays in the provided factors have a different length than the current metadata factors."
+            )
+        merged = cast(tuple[dict[str, ArrayLike], dict[str, list[str]]], self._merged)[0]
+        for k, v in factors.items():
+            merged[k] = v
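
A sketch of the new `add_factors` API; `metadata` stands in for a `Metadata` instance built from an annotated dataset, and the factor name is invented:

    import numpy as np

    # Each new factor must match the number of targets already in the metadata.
    altitudes = np.random.default_rng(0).uniform(0, 1000, len(metadata.targets))
    metadata.add_factors({"altitude": altitudes})  # raises ValueError on length mismatch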
dataeval/utils/data/_selection.py CHANGED
@@ -3,11 +3,11 @@ from __future__ import annotations
 __all__ = []
 
 from enum import IntEnum
-from typing import Any, Generic, Iterator, Sequence, TypeVar
+from typing import Generic, Iterator, Sequence, TypeVar
 
 from dataeval.typing import AnnotatedDataset, DatasetMetadata
 
-_TDatum = TypeVar("_TDatum")
+_TDatum = TypeVar("_TDatum", covariant=True)
 
 
 class SelectionStage(IntEnum):
@@ -69,11 +69,11 @@ class Select(AnnotatedDataset[_TDatum]):
         dataset: AnnotatedDataset[_TDatum],
         selections: Selection[_TDatum] | list[Selection[_TDatum]] | None = None,
     ) -> None:
+        self.__dict__.update(dataset.__dict__)
         self._dataset = dataset
         self._size_limit = len(dataset)
         self._selection = list(range(self._size_limit))
         self._selections = self._sort_selections(selections)
-        self.__dict__.update(dataset.__dict__)
 
         # Ensure metadata is populated correctly as DatasetMetadata TypedDict
         _metadata = getattr(dataset, "metadata", {})
@@ -93,7 +93,7 @@ class Select(AnnotatedDataset[_TDatum]):
         title = f"{self.__class__.__name__} Dataset"
         sep = "-" * len(title)
         selections = f"Selections: [{', '.join([str(s) for s in self._sort_selections(self._selections)])}]"
-        return f"{title}\n{sep}{nt}{selections}\n\n{self._dataset}"
+        return f"{title}\n{sep}{nt}{selections}{nt}Selected Size: {len(self)}\n\n{self._dataset}"
 
     def _sort_selections(self, selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None) -> list[Selection]:
         if not selections:
@@ -111,10 +111,6 @@ class Select(AnnotatedDataset[_TDatum]):
             selection(self)
         self._selection = self._selection[: self._size_limit]
 
-    def __getattr__(self, name: str, /) -> Any:
-        selfattr = getattr(self._dataset, name, None)
-        return selfattr if selfattr is not None else getattr(self._dataset, name)
-
     def __getitem__(self, index: int) -> _TDatum:
         return self._dataset[self._selection[index]]
 
dataeval/utils/data/_split.py CHANGED
@@ -12,6 +12,7 @@ from sklearn.metrics import silhouette_score
 from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, StratifiedKFold
 from sklearn.utils.multiclass import type_of_target
 
+from dataeval.config import get_seed
 from dataeval.outputs._base import set_metadata
 from dataeval.outputs._utils import SplitDatasetOutput, TrainValSplit
 
@@ -212,9 +213,9 @@ def bin_kmeans(array: NDArray[Any]) -> NDArray[np.intp]:
     best_score = 0.50
     bin_index = np.zeros(len(array), dtype=np.intp)
     for k in range(2, 20):
-        clusterer = KMeans(n_clusters=k)
+        clusterer = KMeans(n_clusters=k, random_state=get_seed())
         cluster_labels = clusterer.fit_predict(array)
-        score = silhouette_score(array, cluster_labels, sample_size=25_000)
+        score = silhouette_score(array, cluster_labels, sample_size=25_000, random_state=get_seed())
         if score > best_score:
             best_score = score
             bin_index = cluster_labels.astype(np.intp)
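
The same `get_seed()` hook can seed any scikit-learn estimator the way `bin_kmeans` now does; a minimal sketch using only imports known from this diff:

    from sklearn.cluster import KMeans
    from dataeval.config import get_seed

    clusterer = KMeans(n_clusters=5, random_state=get_seed())  # reproducible when a seed is configured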
dataeval/utils/data/selections/_classfilter.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 __all__ = []
 
-from typing import Sequence
+from typing import Sequence, TypeVar
 
 import numpy as np
 
@@ -10,8 +10,10 @@ from dataeval.typing import Array, ImageClassificationDatum
 from dataeval.utils._array import as_numpy
 from dataeval.utils.data._selection import Select, Selection, SelectionStage
 
+TImageClassificationDatum = TypeVar("TImageClassificationDatum", bound=ImageClassificationDatum, covariant=True)
 
-class ClassFilter(Selection[ImageClassificationDatum]):
+
+class ClassFilter(Selection[TImageClassificationDatum]):
     """
     Filter and balance the dataset by class.
 
@@ -34,7 +36,7 @@ class ClassFilter(Selection[ImageClassificationDatum]):
         self.classes = classes
         self.balance = balance
 
-    def __call__(self, dataset: Select[ImageClassificationDatum]) -> None:
+    def __call__(self, dataset: Select[TImageClassificationDatum]) -> None:
         if self.classes is None and not self.balance:
             return
 
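A sketch of the now-generic filter in use; `train_ds` is a placeholder annotated image-classification dataset, and the re-export paths and keyword names are assumptions based on the attributes shown above:

    from dataeval.utils.data import Select
    from dataeval.utils.data.selections import ClassFilter

    subset = Select(train_ds, selections=[ClassFilter(classes=[0, 1], balance=True)])
    print(subset)  # repr now reports "Selected Size: <n>" per the _selection.py change
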
dataeval/utils/torch/_internal.py CHANGED
@@ -11,13 +11,13 @@ from numpy.typing import NDArray
 from torch.utils.data import DataLoader, TensorDataset
 from tqdm import tqdm
 
-from dataeval.config import get_device
+from dataeval.config import DeviceLike, get_device
 
 
 def predict_batch(
     x: NDArray[Any] | torch.Tensor,
     model: Callable | torch.nn.Module | torch.nn.Sequential,
-    device: torch.device | None = None,
+    device: DeviceLike | None = None,
     batch_size: int = int(1e10),
     preprocess_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
     dtype: type[np.generic] | torch.dtype = np.float32,
@@ -31,9 +31,9 @@ def predict_batch(
         Batch of instances.
     model : Callable | nn.Module | nn.Sequential
         PyTorch model.
-    device : torch.device | None, default None
-        Device type used. The default None tries to use the GPU and falls back on CPU.
-        Can be specified by passing either torch.device('cuda') or torch.device('cpu').
+    device : DeviceLike or None, default None
+        The hardware device to use if specified, otherwise uses the DataEval
+        default or torch default.
     batch_size : int, default 1e10
         Batch size used during prediction.
     preprocess_fn : Callable | None, default None
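
Both `DeviceLike` and `get_device` come from `dataeval.config` per this diff; the accepted forms below are assumptions based on the `str | torch.device` annotations this release replaces:

    import torch
    from dataeval.config import get_device

    device = get_device(None)                 # DataEval default, else torch default
    device = get_device(torch.device("cpu"))  # pass-through of an explicit device
    device = get_device("cpu")                # assumed: string specs count as DeviceLike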
dataeval/utils/torch/trainer.py CHANGED
@@ -2,6 +2,8 @@
 
 from __future__ import annotations
 
+from dataeval.config import DeviceLike, get_device
+
 __all__ = ["AETrainer"]
 
 from typing import Any
@@ -25,9 +27,9 @@ class AETrainer:
     ----------
     model : nn.Module
         The model to be trained.
-    device : str or torch.device, default "auto"
-        The hardware device to use for training.
-        If "auto", the device will be set to "cuda" if available, otherwise "cpu".
+    device : DeviceLike or None, default None
+        The hardware device to use if specified, otherwise uses the DataEval
+        default or torch default.
     batch_size : int, default 8
         The number of images to process in a batch.
     """
@@ -35,13 +37,11 @@ class AETrainer:
     def __init__(
         self,
         model: nn.Module,
-        device: str | torch.device = "auto",
+        device: DeviceLike | None = None,
        batch_size: int = 8,
     ):
-        if device == "auto":
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.device: torch.device = torch.device(device)
-        self.model: nn.Module = model.to(device)
+        self.device: torch.device = get_device(device)
+        self.model: nn.Module = model.to(self.device)
        self.batch_size = batch_size
 
    def train(self, dataset: Dataset[Any], epochs: int = 25) -> list[float]:
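
A sketch of the migrated trainer construction; the autoencoder model here is a stand-in, and `device=None` replaces the old `"auto"` sentinel:

    import torch.nn as nn
    from dataeval.utils.torch.trainer import AETrainer

    model = nn.Sequential(nn.Flatten(), nn.Linear(784, 32))  # stand-in autoencoder
    trainer = AETrainer(model, device=None, batch_size=8)    # None -> get_device default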
{dataeval-0.82.0.dist-info → dataeval-0.82.1.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: dataeval
-Version: 0.82.0
+Version: 0.82.1
 Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
 Home-page: https://dataeval.ai/
 License: MIT