dataeval 0.76.0__py3-none-any.whl → 0.81.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. dataeval/__init__.py +3 -3
  2. dataeval/{output.py → _output.py} +14 -0
  3. dataeval/config.py +77 -0
  4. dataeval/detectors/__init__.py +1 -1
  5. dataeval/detectors/drift/__init__.py +6 -6
  6. dataeval/detectors/drift/{base.py → _base.py} +41 -30
  7. dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
  8. dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
  9. dataeval/detectors/drift/{mmd.py → _mmd.py} +33 -19
  10. dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
  11. dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +23 -7
  12. dataeval/detectors/drift/updates.py +1 -1
  13. dataeval/detectors/linters/__init__.py +0 -3
  14. dataeval/detectors/linters/duplicates.py +17 -8
  15. dataeval/detectors/linters/outliers.py +52 -43
  16. dataeval/detectors/ood/ae.py +29 -8
  17. dataeval/detectors/ood/base.py +5 -4
  18. dataeval/detectors/ood/metadata_ks_compare.py +1 -1
  19. dataeval/detectors/ood/mixin.py +20 -5
  20. dataeval/detectors/ood/output.py +1 -1
  21. dataeval/detectors/ood/vae.py +73 -0
  22. dataeval/metadata/__init__.py +5 -0
  23. dataeval/metadata/_ood.py +238 -0
  24. dataeval/metrics/__init__.py +1 -1
  25. dataeval/metrics/bias/__init__.py +5 -4
  26. dataeval/metrics/bias/{balance.py → _balance.py} +67 -17
  27. dataeval/metrics/bias/{coverage.py → _coverage.py} +41 -35
  28. dataeval/metrics/bias/{diversity.py → _diversity.py} +17 -12
  29. dataeval/metrics/bias/{parity.py → _parity.py} +89 -63
  30. dataeval/metrics/estimators/__init__.py +14 -4
  31. dataeval/metrics/estimators/{ber.py → _ber.py} +42 -11
  32. dataeval/metrics/estimators/_clusterer.py +104 -0
  33. dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -13
  34. dataeval/metrics/estimators/{uap.py → _uap.py} +4 -4
  35. dataeval/metrics/stats/__init__.py +7 -7
  36. dataeval/metrics/stats/{base.py → _base.py} +52 -16
  37. dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +6 -9
  38. dataeval/metrics/stats/{datasetstats.py → _datasetstats.py} +10 -14
  39. dataeval/metrics/stats/{dimensionstats.py → _dimensionstats.py} +6 -5
  40. dataeval/metrics/stats/{hashstats.py → _hashstats.py} +6 -6
  41. dataeval/metrics/stats/{labelstats.py → _labelstats.py} +25 -25
  42. dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +5 -4
  43. dataeval/metrics/stats/{visualstats.py → _visualstats.py} +9 -8
  44. dataeval/typing.py +54 -0
  45. dataeval/utils/__init__.py +2 -2
  46. dataeval/utils/_array.py +169 -0
  47. dataeval/utils/_bin.py +199 -0
  48. dataeval/utils/_clusterer.py +144 -0
  49. dataeval/utils/_fast_mst.py +189 -0
  50. dataeval/utils/{image.py → _image.py} +6 -4
  51. dataeval/utils/_method.py +18 -0
  52. dataeval/utils/{shared.py → _mst.py} +3 -65
  53. dataeval/utils/{plot.py → _plot.py} +4 -4
  54. dataeval/utils/data/__init__.py +22 -0
  55. dataeval/utils/data/_embeddings.py +105 -0
  56. dataeval/utils/data/_images.py +65 -0
  57. dataeval/utils/data/_metadata.py +352 -0
  58. dataeval/utils/data/_selection.py +119 -0
  59. dataeval/utils/{dataset/split.py → data/_split.py} +13 -14
  60. dataeval/utils/data/_targets.py +73 -0
  61. dataeval/utils/data/_types.py +58 -0
  62. dataeval/utils/data/collate.py +103 -0
  63. dataeval/utils/data/datasets/__init__.py +17 -0
  64. dataeval/utils/data/datasets/_base.py +254 -0
  65. dataeval/utils/data/datasets/_cifar10.py +134 -0
  66. dataeval/utils/data/datasets/_fileio.py +168 -0
  67. dataeval/utils/data/datasets/_milco.py +153 -0
  68. dataeval/utils/data/datasets/_mixin.py +56 -0
  69. dataeval/utils/data/datasets/_mnist.py +183 -0
  70. dataeval/utils/data/datasets/_ships.py +123 -0
  71. dataeval/utils/data/datasets/_voc.py +352 -0
  72. dataeval/utils/data/selections/__init__.py +15 -0
  73. dataeval/utils/data/selections/_classfilter.py +60 -0
  74. dataeval/utils/data/selections/_indices.py +26 -0
  75. dataeval/utils/data/selections/_limit.py +26 -0
  76. dataeval/utils/data/selections/_reverse.py +18 -0
  77. dataeval/utils/data/selections/_shuffle.py +29 -0
  78. dataeval/utils/metadata.py +198 -376
  79. dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
  80. dataeval/utils/torch/{internal.py → _internal.py} +21 -51
  81. dataeval/utils/torch/models.py +43 -2
  82. dataeval/workflows/sufficiency.py +10 -9
  83. {dataeval-0.76.0.dist-info → dataeval-0.81.0.dist-info}/METADATA +44 -15
  84. dataeval-0.81.0.dist-info/RECORD +94 -0
  85. dataeval/detectors/linters/clusterer.py +0 -512
  86. dataeval/detectors/linters/merged_stats.py +0 -49
  87. dataeval/detectors/ood/metadata_least_likely.py +0 -119
  88. dataeval/interop.py +0 -69
  89. dataeval/utils/dataset/__init__.py +0 -7
  90. dataeval/utils/dataset/datasets.py +0 -412
  91. dataeval/utils/dataset/read.py +0 -63
  92. dataeval-0.76.0.dist-info/RECORD +0 -67
  93. /dataeval/{log.py → _log.py} +0 -0
  94. /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
  95. {dataeval-0.76.0.dist-info → dataeval-0.81.0.dist-info}/LICENSE.txt +0 -0
  96. {dataeval-0.76.0.dist-info → dataeval-0.81.0.dist-info}/WHEEL +0 -0
@@ -4,7 +4,7 @@ __all__ = []
4
4
 
5
5
  import warnings
6
6
  from dataclasses import dataclass
7
- from typing import Any, Iterator, NamedTuple, Protocol
7
+ from typing import Any, Iterator, Protocol
8
8
 
9
9
  import numpy as np
10
10
  from numpy.typing import NDArray
@@ -13,10 +13,11 @@ from sklearn.metrics import silhouette_score
13
13
  from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, StratifiedKFold
14
14
  from sklearn.utils.multiclass import type_of_target
15
15
 
16
- from dataeval.output import Output, set_metadata
16
+ from dataeval._output import Output, set_metadata
17
17
 
18
18
 
19
- class TrainValSplit(NamedTuple):
19
+ @dataclass
20
+ class TrainValSplit:
20
21
  """Tuple containing train and validation indices"""
21
22
 
22
23
  train: NDArray[np.intp]
@@ -274,8 +275,7 @@ def get_group_ids(metadata: dict[str, Any], group_names: list[str], num_samples:
274
275
  for name, feature in features2group.items():
275
276
  if len(feature) != num_samples:
276
277
  raise ValueError(
277
- f"Feature length does not match number of labels. "
278
- f"Got {len(feature)} features and {num_samples} samples"
278
+ f"Feature length does not match number of labels. Got {len(feature)} features and {num_samples} samples"
279
279
  )
280
280
 
281
281
  if type_of_target(feature) == "continuous":
@@ -505,23 +505,22 @@ def split_dataset(
505
505
  if is_groupable(possible_groups, group_partitions):
506
506
  groups = possible_groups
507
507
 
508
- test_indices: NDArray[np.intp]
509
508
  index = np.arange(label_length)
510
509
 
511
- tv_indices, test_indices = (
510
+ tvs = (
512
511
  single_split(index=index, labels=labels, split_frac=test_frac, groups=groups, stratified=stratify)
513
512
  if test_frac
514
- else (index, np.array([], dtype=np.intp))
513
+ else TrainValSplit(index, np.array([], dtype=np.intp))
515
514
  )
516
515
 
517
- tv_labels = labels[tv_indices]
518
- tv_groups = groups[tv_indices] if groups is not None else None
516
+ tv_labels = labels[tvs.train]
517
+ tv_groups = groups[tvs.train] if groups is not None else None
519
518
 
520
519
  if num_folds == 1:
521
- tv_splits = [single_split(tv_indices, tv_labels, val_frac, tv_groups, stratify)]
520
+ tv_splits = [single_split(tvs.train, tv_labels, val_frac, tv_groups, stratify)]
522
521
  else:
523
- tv_splits = make_splits(tv_indices, tv_labels, num_folds, tv_groups, stratify)
522
+ tv_splits = make_splits(tvs.train, tv_labels, num_folds, tv_groups, stratify)
524
523
 
525
- folds: list[TrainValSplit] = [TrainValSplit(tv_indices[split.train], tv_indices[split.val]) for split in tv_splits]
524
+ folds: list[TrainValSplit] = [TrainValSplit(tvs.train[split.train], tvs.train[split.val]) for split in tv_splits]
526
525
 
527
- return SplitDatasetOutput(test_indices, folds)
526
+ return SplitDatasetOutput(tvs.val, folds)
@@ -0,0 +1,73 @@
1
+ from __future__ import annotations
2
+
3
+ __all__ = []
4
+
5
+ from dataclasses import dataclass
6
+
7
+ import numpy as np
8
+ from numpy.typing import NDArray
9
+
10
+
11
+ def _len(arr: NDArray, dim: int) -> int:
12
+ return 0 if len(arr) == 0 else len(np.atleast_1d(arr) if dim == 1 else np.atleast_2d(arr))
13
+
14
+
15
@dataclass(frozen=True)
class Targets:
    """
    Dataclass defining targets for image classification or object detection.

    Attributes
    ----------
    labels : NDArray[np.intp]
        Labels (N,) for N images or objects
    scores : NDArray[np.float32]
        Probability scores (N,M) for N images of M classes or confidence score (N,) of objects
    bboxes : NDArray[np.float32] | None
        Bounding boxes (N,4) for N objects in (x0,y0,x1,y1) format
    source : NDArray[np.intp] | None
        Source image index (N,) for N objects
    """

    labels: NDArray[np.intp]
    scores: NDArray[np.float32]
    bboxes: NDArray[np.float32] | None
    source: NDArray[np.intp] | None

    def __post_init__(self) -> None:
        # bboxes and source only make sense together (object-detection mode).
        has_bboxes = self.bboxes is not None
        has_source = self.source is not None
        if has_bboxes != has_source:
            raise ValueError("Either both bboxes and source must be provided or neither.")

        n_labels = _len(self.labels, 1)
        # Classification scores are (N, M); detection confidences are (N,).
        n_scores = _len(self.scores, 1) if has_bboxes else _len(self.scores, 2)
        n_bboxes = _len(self.bboxes, 2) if has_bboxes else n_labels
        n_source = _len(self.source, 1) if has_source else n_labels

        if not (n_labels == n_scores == n_bboxes == n_source):
            raise ValueError(
                "Labels, scores, bboxes and source must be the same length (if provided).\n"
                + f"  labels: {self.labels.shape}\n"
                + f"  scores: {self.scores.shape}\n"
                + f"  bboxes: {None if self.bboxes is None else self.bboxes.shape}\n"
                + f"  source: {None if self.source is None else self.source.shape}\n"
            )

    def __len__(self) -> int:
        # Number of images (classification) or objects (detection).
        return len(self.labels)

    def at(self, idx: int) -> Targets:
        """Return the targets belonging to a single image index."""
        if self.source is None or self.bboxes is None:
            # Classification: one row per image.
            return Targets(
                np.atleast_1d(self.labels[idx]),
                np.atleast_2d(self.scores[idx]),
                None,
                None,
            )
        # Detection: gather every object whose source image is ``idx``.
        mask = np.where(self.source == idx, True, False)
        return Targets(
            np.atleast_1d(self.labels[mask]),
            np.atleast_1d(self.scores[mask]),
            np.atleast_2d(self.bboxes[mask]),
            np.atleast_1d(self.source[mask]),
        )
@@ -0,0 +1,58 @@
1
+ from __future__ import annotations
2
+
3
+ __all__ = []
4
+
5
+ import sys
6
+ from dataclasses import dataclass
7
+ from typing import Any, Generic, Protocol, TypedDict, TypeVar
8
+
9
+ if sys.version_info >= (3, 11):
10
+ from typing import NotRequired, Required
11
+ else:
12
+ from typing_extensions import NotRequired, Required
13
+
14
+ from torch.utils.data import Dataset as _Dataset
15
+
16
+ _TArray = TypeVar("_TArray")
17
+ _TData = TypeVar("_TData", covariant=True)
18
+ _TTarget = TypeVar("_TTarget", covariant=True)
19
+
20
+
21
class DatasetMetadata(TypedDict):
    """Dataset-level metadata dictionary; only ``id`` is required."""

    id: Required[str]  # unique identifier for the dataset
    index2label: NotRequired[dict[int, str]]  # optional class-index -> label-name mapping
    split: NotRequired[str]  # optional split name (e.g. "train")
25
+
26
+
27
class Dataset(_Dataset[tuple[_TData, _TTarget, dict[str, Any]]]):
    """Typed torch ``Dataset`` yielding ``(data, target, datum_metadata)`` triples.

    Subclasses are expected to populate ``metadata`` with at least an ``id``.
    """

    # Dataset-level metadata (see DatasetMetadata).
    metadata: DatasetMetadata

    def __getitem__(self, index: Any) -> tuple[_TData, _TTarget, dict[str, Any]]: ...
    def __len__(self) -> int: ...
32
+
33
+
34
# Image-classification dataset: targets share the array type of the images
# (e.g. per-class score vectors — see BaseICDataset usage; confirm per subclass).
class ImageClassificationDataset(Dataset[_TArray, _TArray]): ...
35
+
36
+
37
@dataclass
class ObjectDetectionTarget(Generic[_TArray]):
    """Target for object detection: bounding boxes with labels and scores."""

    boxes: _TArray  # bounding boxes; (x0, y0, x1, y1) per _base.py docs — confirm per subclass
    labels: _TArray  # class label per box
    scores: _TArray  # score per box (one-hot encoded labels in BaseODDataset)
42
+
43
+
44
# Object-detection dataset: each datum's target is an ObjectDetectionTarget.
class ObjectDetectionDataset(Dataset[_TArray, ObjectDetectionTarget[_TArray]]): ...
45
+
46
+
47
@dataclass
class SegmentationTarget(Generic[_TArray]):
    """Target for segmentation: a ground-truth mask with labels and scores."""

    mask: _TArray  # ground-truth segmentation mask
    labels: _TArray  # class labels present in the annotation
    scores: _TArray  # score per label (one-hot encoded labels in BaseSegDataset)
52
+
53
+
54
# Segmentation dataset: each datum's target is a SegmentationTarget.
class SegmentationDataset(Dataset[_TArray, SegmentationTarget[_TArray]]): ...
55
+
56
+
57
class Transform(Generic[_TArray], Protocol):
    """Structural type for array-to-array transforms (positional-only callable)."""

    def __call__(self, data: _TArray, /) -> _TArray: ...
@@ -0,0 +1,103 @@
1
+ """
2
+ Collate functions used with a PyTorch DataLoader to load data from MAITE compliant datasets.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ from typing import Any, Iterable, Sequence, TypeVar
8
+
9
+ import numpy as np
10
+ import torch
11
+ from numpy.typing import NDArray
12
+
13
+ from dataeval.typing import ArrayLike
14
+ from dataeval.utils._array import as_numpy
15
+
16
+ T_in = TypeVar("T_in")
17
+ T_tgt = TypeVar("T_tgt")
18
+ T_md = TypeVar("T_md")
19
+
20
+
21
def list_collate_fn(
    batch_data_as_singles: Iterable[tuple[T_in, T_tgt, T_md]],
) -> tuple[Sequence[T_in], Sequence[T_tgt], Sequence[T_md]]:
    """
    Collate a batch of (input, target, metadata) tuples into three parallel lists.

    Useful with torch.utils.data.DataLoader when the targets and metadata are
    not tensors and should not be stacked.

    Parameters
    ----------
    batch_data_as_singles : An iterable of (input, target, metadata) tuples.

    Returns
    -------
    tuple[Sequence[T_in], Sequence[T_tgt], Sequence[T_md]]
        A tuple of three lists: the input batch, the target batch, and the metadata batch.
    """
    # Transpose rows into columns; an empty batch yields no columns at all.
    columns = tuple(zip(*batch_data_as_singles))
    if not columns:
        return [], [], []
    return list(columns[0]), list(columns[1]), list(columns[2])
48
+
49
+
50
def numpy_collate_fn(
    batch_data_as_singles: Iterable[tuple[ArrayLike, T_tgt, T_md]],
) -> tuple[NDArray[Any], Sequence[T_tgt], Sequence[T_md]]:
    """
    Collate a batch of (input, target, metadata) tuples, stacking the inputs
    into one NumPy array and returning the targets and metadata as lists.
    The inputs must be homogeneous arrays.

    Parameters
    ----------
    batch_data_as_singles : An iterable of (ArrayLike, target, metadata) tuples.

    Returns
    -------
    tuple[NDArray[Any], Sequence[T_tgt], Sequence[T_md]]
        A tuple of a NumPy array and two lists: the input batch, the target batch, and the metadata batch.
    """
    arrays: list[NDArray[Any]] = []
    targets: list[T_tgt] = []
    metadata: list[T_md] = []
    for image, target, datum_meta in batch_data_as_singles:
        arrays.append(as_numpy(image))
        targets.append(target)
        metadata.append(datum_meta)
    # np.stack rejects an empty sequence, so fall back to an empty array.
    stacked = np.stack(arrays) if arrays else np.array([])
    return stacked, targets, metadata
76
+
77
+
78
def torch_collate_fn(
    batch_data_as_singles: Iterable[tuple[ArrayLike, T_tgt, T_md]],
) -> tuple[torch.Tensor, Sequence[T_tgt], Sequence[T_md]]:
    """
    Collate a batch of (input, target, metadata) tuples, stacking the inputs
    into one torch Tensor and returning the targets and metadata as lists.
    The inputs must be homogeneous arrays.

    Parameters
    ----------
    batch_data_as_singles : An iterable of (ArrayLike, target, metadata) tuples.

    Returns
    -------
    tuple[torch.Tensor, Sequence[T_tgt], Sequence[T_md]]
        A tuple of a torch Tensor and two lists: the input batch, the target batch, and the metadata batch.
    """
    tensors: list[torch.Tensor] = []
    targets: list[T_tgt] = []
    metadata: list[T_md] = []
    for image, target, datum_meta in batch_data_as_singles:
        tensors.append(torch.as_tensor(image))
        targets.append(target)
        metadata.append(datum_meta)
    # torch.stack rejects an empty sequence, so fall back to an empty tensor.
    batched = torch.stack(tensors) if tensors else torch.tensor([])
    return batched, targets, metadata
@@ -0,0 +1,17 @@
1
+ """Provides access to common Computer Vision datasets."""
2
+
3
+ from dataeval.utils.data.datasets._cifar10 import CIFAR10
4
+ from dataeval.utils.data.datasets._milco import MILCO
5
+ from dataeval.utils.data.datasets._mnist import MNIST
6
+ from dataeval.utils.data.datasets._ships import Ships
7
+ from dataeval.utils.data.datasets._voc import VOCDetection, VOCDetectionTorch, VOCSegmentation
8
+
9
+ __all__ = [
10
+ "MNIST",
11
+ "Ships",
12
+ "CIFAR10",
13
+ "MILCO",
14
+ "VOCDetection",
15
+ "VOCDetectionTorch",
16
+ "VOCSegmentation",
17
+ ]
@@ -0,0 +1,254 @@
1
+ from __future__ import annotations
2
+
3
+ __all__ = []
4
+
5
+ from abc import abstractmethod
6
+ from pathlib import Path
7
+ from typing import Any, Generic, Iterator, Literal, NamedTuple, Sequence, TypeVar
8
+
9
+ from dataeval.utils.data._types import (
10
+ Dataset,
11
+ DatasetMetadata,
12
+ ImageClassificationDataset,
13
+ ObjectDetectionDataset,
14
+ ObjectDetectionTarget,
15
+ SegmentationDataset,
16
+ SegmentationTarget,
17
+ Transform,
18
+ )
19
+ from dataeval.utils.data.datasets._fileio import _ensure_exists
20
+ from dataeval.utils.data.datasets._mixin import BaseDatasetMixin
21
+
22
+ _TArray = TypeVar("_TArray")
23
+ _TTarget = TypeVar("_TTarget")
24
+ _TRawTarget = TypeVar("_TRawTarget", list[int], list[str])
25
+
26
+
27
class DataLocation(NamedTuple):
    """Location and integrity information for one downloadable dataset resource."""

    url: str  # URL to download the resource from
    filename: str  # name of the file once downloaded
    md5: bool  # True when ``checksum`` is an MD5 digest
    checksum: str  # checksum used to verify the downloaded file
32
+
33
+
34
class BaseDataset(Dataset[_TArray, _TTarget], Generic[_TArray, _TTarget, _TRawTarget]):
    """
    Base class for internet downloaded datasets.

    Resolves the on-disk dataset directory, downloads/extracts the configured
    resource when the data is not already present, and applies a sequence of
    transforms to each image as it is read.
    """

    # Each subclass should override the attributes below.
    # Each resource tuple must contain:
    #    'url': str, the URL to download from
    #    'filename': str, the name of the file once downloaded
    #    'md5': bool, True when the checksum value is an MD5 digest
    #    'checksum': str, the associated checksum for the downloaded file
    _resources: list[DataLocation]
    # Index into _resources selecting the resource this instance uses.
    _resource_index: int = 0
    # Class-index -> label-name mapping.
    # NOTE(review): read during __init__ to build _label2index, but never set
    # here — assumes the subclass assigns it before super().__init__() runs; confirm.
    index2label: dict[int, str]

    def __init__(
        self,
        root: str | Path,
        download: bool = False,
        image_set: Literal["train", "val", "test", "base"] = "train",
        transforms: Transform[_TArray] | Sequence[Transform[_TArray]] | None = None,
        verbose: bool = False,
    ) -> None:
        """
        Parameters
        ----------
        root : str or Path
            Root directory where the dataset is stored or will be downloaded to.
        download : bool, default False
            When True, missing data may be downloaded (see _load_data).
        image_set : "train", "val", "test" or "base", default "train"
            Which split of the dataset to load.
        transforms : Transform, sequence of Transform, or None, default None
            Transform(s) applied to each image on access.
        verbose : bool, default False
            When True, prints progress information while loading/downloading.
        """
        self._root: Path = root.absolute() if isinstance(root, Path) else Path(root).absolute()
        # Normalize transforms to a sequence so _transform can always iterate.
        transforms = transforms if transforms is not None else []
        self.transforms: Sequence[Transform[_TArray]] = transforms if isinstance(transforms, Sequence) else [transforms]
        self.image_set = image_set
        self._verbose = verbose

        # Internal Attributes
        self._download = download
        self._filepaths: list[str]
        self._targets: _TRawTarget
        self._datum_metadata: dict[str, list[Any]]
        self._resource: DataLocation = self._resources[self._resource_index]
        # Reverse mapping of index2label (see NOTE above on who sets it).
        self._label2index = {v: k for k, v in self.index2label.items()}

        self.metadata: DatasetMetadata = DatasetMetadata(
            id=self._unique_id(),
            index2label=self.index2label,
            split=self.image_set,
        )

        # Load the data
        self.path: Path = self._get_dataset_dir()
        self._filepaths, self._targets, self._datum_metadata = self._load_data()
        self.size: int = len(self._filepaths)

    def __str__(self) -> str:
        """Human-readable summary listing the dataset's public attributes."""
        nt = "\n "
        title = f"{self.__class__.__name__} Dataset"
        sep = "-" * len(title)
        attrs = [f"{k.capitalize()}: {v}" for k, v in self.__dict__.items() if not k.startswith("_")]
        return f"{title}\n{sep}{nt}{nt.join(attrs)}"

    @property
    def label2index(self) -> dict[str, int]:
        """Mapping of label name to class index (inverse of index2label)."""
        return self._label2index

    def __iter__(self) -> Iterator[tuple[_TArray, _TTarget, dict[str, Any]]]:
        """Yield every datum in index order."""
        for i in range(len(self)):
            yield self[i]

    def _get_dataset_dir(self) -> Path:
        # Create a designated folder for this dataset (named after the class),
        # unless root itself is already named after the dataset.
        if self._root.stem in [
            self.__class__.__name__.lower(),
            self.__class__.__name__.upper(),
            self.__class__.__name__,
        ]:
            dataset_dir: Path = self._root
        else:
            dataset_dir: Path = self._root / self.__class__.__name__.lower()
        if not dataset_dir.exists():
            dataset_dir.mkdir(parents=True, exist_ok=True)
        return dataset_dir

    def _unique_id(self) -> str:
        # Dataset id used in DatasetMetadata, e.g. "<ClassName>_<split>".
        unique_id = f"{self.__class__.__name__}_{self.image_set}"
        return unique_id

    def _load_data(self) -> tuple[list[str], _TRawTarget, dict[str, Any]]:
        """
        Function to determine if data can be accessed or if it needs to be downloaded and/or extracted.
        """
        if self._verbose:
            print(f"Determining if {self._resource.filename} needs to be downloaded.")

        try:
            result = self._load_data_inner()
            if self._verbose:
                print("No download needed, loaded data successfully.")
        except FileNotFoundError:
            # Data not found on disk: fetch/extract the resource, then retry once.
            _ensure_exists(*self._resource, self.path, self._root, self._download, self._verbose)
            result = self._load_data_inner()
        return result

    @abstractmethod
    def _load_data_inner(self) -> tuple[list[str], _TRawTarget, dict[str, Any]]:
        """Subclass hook returning (filepaths, raw targets, datum metadata columns)."""
        ...

    def _transform(self, image: _TArray) -> _TArray:
        """Function to transform the image prior to returning based on parameters passed in."""
        for transform in self.transforms:
            image = transform(image)
        return image

    def __len__(self) -> int:
        return self.size
142
+
143
+
144
class BaseICDataset(
    BaseDataset[_TArray, _TArray, list[int]],
    BaseDatasetMixin[_TArray],
    ImageClassificationDataset[_TArray],
):
    """
    Base class for image classification datasets.
    """

    def __getitem__(self, index: int) -> tuple[_TArray, _TArray, dict[str, Any]]:
        """
        Args
        ----
        index : int
            Value of the desired data point

        Returns
        -------
        tuple[TArray, TArray, dict[str, Any]]
            Image, target, datum_metadata - where target is one-hot encoding of class.
        """
        # One-hot encode this index's integer class label.
        one_hot = self._one_hot_encode(self._targets[index])
        # Read the image from disk and apply any configured transforms.
        image = self._transform(self._read_file(self._filepaths[index]))
        # Per-datum metadata is stored column-wise; pick out this index's entry.
        datum_metadata = {field: column[index] for field, column in self._datum_metadata.items()}
        return image, one_hot, datum_metadata
175
+
176
+
177
class BaseODDataset(
    BaseDataset[_TArray, ObjectDetectionTarget[_TArray], list[str]],
    BaseDatasetMixin[_TArray],
    ObjectDetectionDataset[_TArray],
):
    """
    Base class for object detection datasets.
    """

    def __getitem__(self, index: int) -> tuple[_TArray, ObjectDetectionTarget[_TArray], dict[str, Any]]:
        """
        Args
        ----
        index : int
            Value of the desired data point

        Returns
        -------
        tuple[TArray, ObjectDetectionTarget[TArray], dict[str, Any]]
            Image, target, datum_metadata - target.boxes returns boxes in x0, y0, x1, y1 format
        """
        # Parse the annotation file for this index's boxes, labels and extras.
        boxes, labels, extra_metadata = self._read_annotations(self._targets[index])
        # Read the image from disk and apply any configured transforms.
        image = self._transform(self._read_file(self._filepaths[index]))

        target = ObjectDetectionTarget(self._as_array(boxes), self._as_array(labels), self._one_hot_encode(labels))

        # Merge the column-wise datum metadata with the annotation extras.
        datum_metadata = {field: column[index] for field, column in self._datum_metadata.items()}
        return image, target, {**datum_metadata, **extra_metadata}

    @abstractmethod
    def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]: ...
213
+
214
+
215
class BaseSegDataset(
    BaseDataset[_TArray, SegmentationTarget[_TArray], list[str]],
    BaseDatasetMixin[_TArray],
    SegmentationDataset[_TArray],
):
    """
    Base class for segmentation datasets.
    """

    # Paths to the ground-truth mask files, parallel to _filepaths.
    _masks: Sequence[str]

    def __getitem__(self, index: int) -> tuple[_TArray, SegmentationTarget[_TArray], dict[str, Any]]:
        """
        Args
        ----
        index : int
            Value of the desired data point

        Returns
        -------
        tuple[TArray, SegmentationTarget[TArray], dict[str, Any]]
            Image, target, datum_metadata - target.mask returns the ground truth mask
        """
        # Parse the annotation file for this index's labels and extras (boxes unused).
        _, labels, extra_metadata = self._read_annotations(self._targets[index])
        # Read the ground-truth mask from disk.
        mask = self._read_file(self._masks[index])
        # Read the image from disk and apply any configured transforms.
        image = self._transform(self._read_file(self._filepaths[index]))

        target = SegmentationTarget(mask, self._as_array(labels), self._one_hot_encode(labels))

        # Merge the column-wise datum metadata with the annotation extras.
        datum_metadata = {field: column[index] for field, column in self._datum_metadata.items()}
        return image, target, {**datum_metadata, **extra_metadata}

    @abstractmethod
    def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]: ...