dataeval 0.86.0__tar.gz → 0.86.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. {dataeval-0.86.0 → dataeval-0.86.2}/PKG-INFO +2 -1
  2. {dataeval-0.86.0 → dataeval-0.86.2}/pyproject.toml +6 -4
  3. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/__init__.py +1 -1
  4. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/_log.py +1 -1
  5. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/config.py +21 -4
  6. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/data/_embeddings.py +2 -2
  7. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/data/_images.py +2 -3
  8. dataeval-0.86.2/src/dataeval/data/_metadata.py +392 -0
  9. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/data/_selection.py +1 -2
  10. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/data/_split.py +4 -5
  11. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/data/_targets.py +17 -13
  12. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/data/selections/_classfilter.py +2 -5
  13. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/data/selections/_prioritize.py +6 -9
  14. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/data/selections/_shuffle.py +3 -1
  15. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/detectors/drift/_base.py +4 -5
  16. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/detectors/drift/_mmd.py +3 -6
  17. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/detectors/drift/_nml/_base.py +4 -2
  18. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/detectors/drift/_nml/_chunk.py +11 -19
  19. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/detectors/drift/_nml/_domainclassifier.py +8 -19
  20. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/detectors/drift/_nml/_result.py +8 -9
  21. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/detectors/drift/_nml/_thresholds.py +66 -77
  22. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/detectors/linters/outliers.py +7 -7
  23. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metadata/_distance.py +10 -7
  24. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metadata/_ood.py +11 -103
  25. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metrics/bias/_balance.py +23 -33
  26. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metrics/bias/_diversity.py +16 -14
  27. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metrics/bias/_parity.py +18 -18
  28. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metrics/estimators/_divergence.py +2 -4
  29. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metrics/stats/_base.py +103 -42
  30. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metrics/stats/_boxratiostats.py +21 -19
  31. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metrics/stats/_dimensionstats.py +14 -10
  32. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metrics/stats/_hashstats.py +1 -1
  33. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metrics/stats/_pixelstats.py +6 -6
  34. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metrics/stats/_visualstats.py +3 -3
  35. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/outputs/_base.py +22 -7
  36. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/outputs/_bias.py +24 -70
  37. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/outputs/_drift.py +1 -9
  38. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/outputs/_linters.py +11 -11
  39. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/outputs/_stats.py +82 -23
  40. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/outputs/_workflows.py +2 -2
  41. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/_array.py +6 -9
  42. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/_bin.py +1 -2
  43. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/_clusterer.py +7 -4
  44. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/_fast_mst.py +27 -13
  45. dataeval-0.86.2/src/dataeval/utils/_image.py +127 -0
  46. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/_mst.py +1 -3
  47. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/_plot.py +15 -10
  48. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/data/_dataset.py +54 -28
  49. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/data/metadata.py +104 -82
  50. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/datasets/__init__.py +2 -0
  51. dataeval-0.86.2/src/dataeval/utils/datasets/_antiuav.py +189 -0
  52. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/datasets/_base.py +11 -8
  53. dataeval-0.86.2/src/dataeval/utils/datasets/_cifar10.py +201 -0
  54. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/datasets/_fileio.py +21 -47
  55. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/datasets/_milco.py +22 -12
  56. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/datasets/_mixin.py +2 -4
  57. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/datasets/_mnist.py +3 -4
  58. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/datasets/_ships.py +14 -7
  59. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/datasets/_voc.py +229 -42
  60. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/torch/models.py +5 -10
  61. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/torch/trainer.py +3 -3
  62. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/workflows/sufficiency.py +2 -2
  63. dataeval-0.86.0/src/dataeval/data/_metadata.py +0 -382
  64. dataeval-0.86.0/src/dataeval/detectors/ood/vae.py +0 -74
  65. dataeval-0.86.0/src/dataeval/utils/_image.py +0 -73
  66. dataeval-0.86.0/src/dataeval/utils/datasets/_cifar10.py +0 -142
  67. {dataeval-0.86.0 → dataeval-0.86.2}/LICENSE.txt +0 -0
  68. {dataeval-0.86.0 → dataeval-0.86.2}/README.md +0 -0
  69. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/data/__init__.py +0 -0
  70. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/data/selections/__init__.py +0 -0
  71. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/data/selections/_classbalance.py +0 -0
  72. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/data/selections/_indices.py +0 -0
  73. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/data/selections/_limit.py +0 -0
  74. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/data/selections/_reverse.py +0 -0
  75. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/detectors/__init__.py +0 -0
  76. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/detectors/drift/__init__.py +0 -0
  77. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/detectors/drift/_cvm.py +0 -0
  78. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/detectors/drift/_ks.py +0 -0
  79. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/detectors/drift/_mvdc.py +0 -0
  80. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/detectors/drift/_nml/__init__.py +0 -0
  81. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/detectors/drift/_uncertainty.py +0 -0
  82. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/detectors/drift/updates.py +0 -0
  83. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/detectors/linters/__init__.py +0 -0
  84. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/detectors/linters/duplicates.py +0 -0
  85. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/detectors/ood/__init__.py +0 -0
  86. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/detectors/ood/ae.py +0 -0
  87. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/detectors/ood/base.py +0 -0
  88. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/detectors/ood/mixin.py +0 -0
  89. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metadata/__init__.py +0 -0
  90. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metadata/_utils.py +0 -0
  91. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metrics/__init__.py +0 -0
  92. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metrics/bias/__init__.py +0 -0
  93. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metrics/bias/_completeness.py +0 -0
  94. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metrics/bias/_coverage.py +0 -0
  95. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metrics/estimators/__init__.py +0 -0
  96. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metrics/estimators/_ber.py +0 -0
  97. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metrics/estimators/_clusterer.py +0 -0
  98. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metrics/estimators/_uap.py +0 -0
  99. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metrics/stats/__init__.py +0 -0
  100. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metrics/stats/_imagestats.py +0 -0
  101. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/metrics/stats/_labelstats.py +0 -0
  102. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/outputs/__init__.py +0 -0
  103. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/outputs/_estimators.py +0 -0
  104. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/outputs/_metadata.py +0 -0
  105. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/outputs/_ood.py +0 -0
  106. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/outputs/_utils.py +0 -0
  107. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/py.typed +0 -0
  108. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/typing.py +0 -0
  109. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/__init__.py +0 -0
  110. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/_method.py +0 -0
  111. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/data/__init__.py +0 -0
  112. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/data/collate.py +0 -0
  113. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/datasets/_types.py +0 -0
  114. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/torch/__init__.py +0 -0
  115. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/torch/_blocks.py +0 -0
  116. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/torch/_gmm.py +0 -0
  117. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/utils/torch/_internal.py +0 -0
  118. {dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/workflows/__init__.py +0 -0
{dataeval-0.86.0 → dataeval-0.86.2}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: dataeval
-Version: 0.86.0
+Version: 0.86.2
 Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
 Home-page: https://dataeval.ai/
 License: MIT
@@ -29,6 +29,7 @@ Requires-Dist: numba (>=0.59.1)
 Requires-Dist: numpy (>=1.24.2)
 Requires-Dist: pandas (>=2.0)
 Requires-Dist: pillow (>=10.3.0)
+Requires-Dist: polars (>=1.0.0)
 Requires-Dist: requests
 Requires-Dist: scikit-learn (>=1.5.0)
 Requires-Dist: scipy (>=1.10)
{dataeval-0.86.0 → dataeval-0.86.2}/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "dataeval"
-version = "0.86.0" # dynamic
+version = "0.86.2" # dynamic
 description = "DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks"
 license = "MIT"
 readme = "README.md"
@@ -49,6 +49,7 @@ numba = {version = ">=0.59.1"}
 numpy = {version = ">=1.24.2"}
 pandas = {version = ">=2.0"}
 pillow = {version = ">=10.3.0"}
+polars = {version = ">=1.0.0"}
 requests = {version = "*"}
 scipy = {version = ">=1.10"}
 scikit-learn = {version = ">=1.5.0"}
@@ -134,6 +135,7 @@ markers = [
     "optional: marks tests for optional features",
     "requires_all: marks tests that require the all extras",
     "cuda: marks tests that require cuda",
+    "year: marks tests that need a specified dataset year",
 ]

 [tool.coverage.run]
@@ -175,12 +177,12 @@ target-version = "py38"
 extend-include = ["*.ipynb"]

 [tool.ruff.lint]
-select = ["A", "E", "F", "C4", "I", "UP", "NPY", "SIM", "RUF100"]
-ignore = ["NPY002"]
+select = ["A", "ANN", "C4", "C90", "E", "F", "I", "NPY", "S", "SIM", "RET", "RUF100", "UP"]
+ignore = ["ANN401", "NPY002"]
 fixable = ["ALL"]
 unfixable = []
 dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
-per-file-ignores = { "*.ipynb" = ["E402"] }
+per-file-ignores = { "*.ipynb" = ["E402"], "!src/*" = ["ANN", "S", "RET"]}

 [tool.ruff.lint.isort]
 known-first-party = ["dataeval"]
{dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/__init__.py

@@ -8,7 +8,7 @@ shifts that impact performance of deployed models.
 from __future__ import annotations

 __all__ = ["config", "detectors", "log", "metrics", "typing", "utils", "workflows"]
-__version__ = "0.86.0"
+__version__ = "0.86.2"

 import logging

{dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/_log.py

@@ -8,7 +8,7 @@ class LogMessage:
     Deferred message callback for logging expensive messages.
     """

-    def __init__(self, fn: Callable[..., str]):
+    def __init__(self, fn: Callable[..., str]) -> None:
         self._fn = fn
         self._str = None

{dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/config.py

@@ -4,10 +4,10 @@ Global configuration settings for DataEval.

 from __future__ import annotations

-__all__ = ["get_device", "set_device", "get_max_processes", "set_max_processes", "DeviceLike"]
+__all__ = ["get_device", "set_device", "get_max_processes", "set_max_processes", "use_max_processes", "DeviceLike"]

 import sys
-from typing import Union
+from typing import Any, Union

 if sys.version_info >= (3, 10):
     from typing import TypeAlias
@@ -78,8 +78,7 @@ def get_device(override: DeviceLike | None = None) -> torch.device:
     if override is None:
         global _device
         return torch.get_default_device() if _device is None else _device
-    else:
-        return _todevice(override)
+    return _todevice(override)


 def set_max_processes(processes: int | None) -> None:
@@ -112,6 +111,24 @@ def get_max_processes() -> int | None:
     return _processes


+class MaxProcessesContextManager:
+    def __init__(self, processes: int) -> None:
+        self._processes = processes
+
+    def __enter__(self) -> None:
+        global _processes
+        self._old = _processes
+        set_max_processes(self._processes)
+
+    def __exit__(self, *args: tuple[Any, ...]) -> None:
+        global _processes
+        _processes = self._old
+
+
+def use_max_processes(processes: int) -> MaxProcessesContextManager:
+    return MaxProcessesContextManager(processes)
+
+
 def set_seed(seed: int | None, all_generators: bool = False) -> None:
     """
     Sets the seed for use by classes that allow for a random state or seed.
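Note: the new `use_max_processes` helper scopes the global process cap and restores the prior value on exit. A minimal usage sketch (assuming dataeval 0.86.2; the surrounding workload is illustrative only):

    # Temporarily cap worker processes for a block of work.
    from dataeval.config import get_max_processes, use_max_processes

    with use_max_processes(4):
        assert get_max_processes() == 4
        # ... run stats/metrics that honor get_max_processes() ...
    # the previous setting is restored when the block exits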
{dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/data/_embeddings.py

@@ -144,8 +144,7 @@ class Embeddings:
         """
         if indices is not None:
             return torch.vstack(list(self._batch(indices))).to(self.device)
-        else:
-            return self[:]
+        return self[:]

     def to_numpy(self, indices: Sequence[int] | None = None) -> NDArray[Any]:
         """
@@ -248,6 +247,7 @@ class Embeddings:
             _logger.log(logging.DEBUG, f"Saved embeddings cache from {path}")
         except Exception as e:
             _logger.log(logging.ERROR, f"Failed to save embeddings cache: {e}")
+            raise e

     @classmethod
     def load(cls, path: Path | str) -> Embeddings:
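Note: with the added `raise e`, cache-save failures now propagate instead of being only logged. A hedged caller-side sketch (assuming `Embeddings.save(path)` mirrors the `load` signature shown above; the path is illustrative):

    # Cache-save failures are now re-raised after being logged.
    try:
        embeddings.save("embeddings.pt")
    except Exception:
        ...  # handle or surface the failure; in 0.86.0 it was swallowed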
{dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/data/_images.py

@@ -73,15 +73,14 @@ class Images(Generic[T]):
     def __getitem__(self, key: int | slice, /) -> Sequence[T] | T:
         if isinstance(key, slice):
             return [self._get_image(k) for k in range(len(self._dataset))[key]]
-        elif hasattr(key, "__int__"):
+        if hasattr(key, "__int__"):
             return self._get_image(int(key))
         raise TypeError(f"Key must be integers or slices, not {type(key)}")

     def _get_image(self, index: int) -> T:
         if self._is_tuple_datum:
             return cast(Dataset[tuple[T, Any, Any]], self._dataset)[index][0]
-        else:
-            return cast(Dataset[T], self._dataset)[index]
+        return cast(Dataset[T], self._dataset)[index]

     def __iter__(self) -> Iterator[T]:
         for i in range(len(self._dataset)):
dataeval-0.86.2/src/dataeval/data/_metadata.py (new file)

@@ -0,0 +1,392 @@
+from __future__ import annotations
+
+__all__ = []
+
+import warnings
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Iterable, Literal, Mapping, Sequence, Sized
+
+import numpy as np
+import polars as pl
+from numpy.typing import NDArray
+
+from dataeval.typing import (
+    AnnotatedDataset,
+    Array,
+    ObjectDetectionTarget,
+)
+from dataeval.utils._array import as_numpy
+from dataeval.utils._bin import bin_data, digitize_data
+from dataeval.utils.data.metadata import merge
+
+if TYPE_CHECKING:
+    from dataeval.data import Targets
+else:
+    from dataeval.data._targets import Targets
+
+
+@dataclass
+class FactorInfo:
+    factor_type: Literal["categorical", "continuous", "discrete"] | None = None
+    discretized_col: str | None = None
+
+
+class Metadata:
+    """
+    Class containing binned metadata using Polars DataFrames.
+
+    Parameters
+    ----------
+    dataset : ImageClassificationDataset or ObjectDetectionDataset
+        Dataset to access original targets and metadata from.
+    continuous_factor_bins : Mapping[str, int | Sequence[float]] | None, default None
+        Mapping from continuous factor name to the number of bins or bin edges
+    auto_bin_method : Literal["uniform_width", "uniform_count", "clusters"], default "uniform_width"
+        Method for automatically determining the number of bins for continuous factors
+    exclude : Sequence[str] | None, default None
+        Filter metadata factors to exclude the specified factors, cannot be set with `include`
+    include : Sequence[str] | None, default None
+        Filter metadata factors to include the specified factors, cannot be set with `exclude`
+    """
+
+    def __init__(
+        self,
+        dataset: AnnotatedDataset[tuple[Any, Any, dict[str, Any]]],
+        *,
+        continuous_factor_bins: Mapping[str, int | Sequence[float]] | None = None,
+        auto_bin_method: Literal["uniform_width", "uniform_count", "clusters"] = "uniform_width",
+        exclude: Sequence[str] | None = None,
+        include: Sequence[str] | None = None,
+    ) -> None:
+        self._targets: Targets
+        self._class_labels: NDArray[np.intp]
+        self._class_names: list[str]
+        self._image_indices: NDArray[np.intp]
+        self._factors: dict[str, FactorInfo]
+        self._dropped_factors: dict[str, list[str]]
+        self._dataframe: pl.DataFrame
+
+        self._is_structured = False
+        self._is_binned = False
+
+        self._dataset = dataset
+        self._continuous_factor_bins = dict(continuous_factor_bins) if continuous_factor_bins else {}
+        self._auto_bin_method: Literal["uniform_width", "uniform_count", "clusters"] = auto_bin_method
+
+        if exclude is not None and include is not None:
+            raise ValueError("Filters for `exclude` and `include` are mutually exclusive.")
+
+        self._exclude = set(exclude or ())
+        self._include = set(include or ())
+
+    @property
+    def targets(self) -> Targets:
+        """Target information for the dataset."""
+        self._structure()
+        return self._targets
+
+    @property
+    def raw(self) -> list[dict[str, Any]]:
+        """The raw list of metadata dictionaries for the dataset."""
+        self._structure()
+        return self._raw
+
+    @property
+    def exclude(self) -> set[str]:
+        """Factors to exclude from the metadata."""
+        return self._exclude
+
+    @exclude.setter
+    def exclude(self, value: Sequence[str]) -> None:
+        exclude = set(value)
+        if self._exclude != exclude:
+            self._exclude = exclude
+            self._include = set()
+            self._is_binned = False
+
+    @property
+    def include(self) -> set[str]:
+        """Factors to include from the metadata."""
+        return self._include
+
+    @include.setter
+    def include(self, value: Sequence[str]) -> None:
+        include = set(value)
+        if self._include != include:
+            self._include = include
+            self._exclude = set()
+            self._is_binned = False
+
+    @property
+    def continuous_factor_bins(self) -> Mapping[str, int | Sequence[float]]:
+        """Map of factor names to bin counts or bin edges."""
+        return self._continuous_factor_bins
+
+    @continuous_factor_bins.setter
+    def continuous_factor_bins(self, bins: Mapping[str, int | Sequence[float]]) -> None:
+        if self._continuous_factor_bins != bins:
+            self._continuous_factor_bins = dict(bins)
+            self._reset_bins(bins)
+
+    @property
+    def auto_bin_method(self) -> Literal["uniform_width", "uniform_count", "clusters"]:
+        """Binning method to use when continuous_factor_bins is not defined."""
+        return self._auto_bin_method
+
+    @auto_bin_method.setter
+    def auto_bin_method(self, method: Literal["uniform_width", "uniform_count", "clusters"]) -> None:
+        if self._auto_bin_method != method:
+            self._auto_bin_method = method
+            self._reset_bins()
+
+    @property
+    def dataframe(self) -> pl.DataFrame:
+        """Dataframe containing target information and metadata factors."""
+        self._structure()
+        return self._dataframe
+
+    @property
+    def dropped_factors(self) -> dict[str, list[str]]:
+        """Factors that were dropped during preprocessing and the reasons why they were dropped."""
+        self._structure()
+        return self._dropped_factors
+
+    @property
+    def discretized_data(self) -> NDArray[np.int64]:
+        """Factor data with continuous data discretized."""
+        if not self.factor_names:
+            return np.array([], dtype=np.int64)
+
+        self._bin()
+        return (
+            self.dataframe.select([info.discretized_col or name for name, info in self.factor_info.items()])
+            .to_numpy()
+            .astype(np.int64)
+        )
+
+    @property
+    def factor_names(self) -> list[str]:
+        """Factor names of the metadata."""
+        self._structure()
+        return list(self._factors)
+
+    @property
+    def factor_info(self) -> dict[str, FactorInfo]:
+        """Factor types of the metadata."""
+        self._bin()
+        return self._factors
+
+    @property
+    def factor_data(self) -> NDArray[Any]:
+        """Factor data as a NumPy array."""
+        if not self.factor_names:
+            return np.array([], dtype=np.float64)
+
+        # Extract continuous columns and convert to NumPy array
+        return self.dataframe.select(self.factor_names).to_numpy()
+
+    @property
+    def class_labels(self) -> NDArray[np.intp]:
+        """Class labels as a NumPy array."""
+        self._structure()
+        return self._class_labels
+
+    @property
+    def class_names(self) -> list[str]:
+        """Class names as a list of strings."""
+        self._structure()
+        return self._class_names
+
+    @property
+    def image_indices(self) -> NDArray[np.intp]:
+        """Indices of images as a NumPy array."""
+        self._bin()
+        return self._image_indices
+
+    @property
+    def image_count(self) -> int:
+        self._bin()
+        return int(self._image_indices.max() + 1)
+
+    def _reset_bins(self, cols: Iterable[str] | None = None) -> None:
+        if self._is_binned:
+            columns = self._dataframe.columns
+            for col in (col for col in cols or columns if f"{col}[|]" in columns):
+                self._dataframe.drop_in_place(f"{col}[|]")
+                self._factors[col] = FactorInfo()
+            self._is_binned = False
+
+    def _structure(self) -> None:
+        if self._is_structured:
+            return
+
+        raw: list[dict[str, Any]] = []
+
+        labels = []
+        bboxes = []
+        scores = []
+        srcidx = []
+        is_od = None
+        for i in range(len(self._dataset)):
+            _, target, metadata = self._dataset[i]
+
+            raw.append(metadata)
+
+            if is_od_target := isinstance(target, ObjectDetectionTarget):
+                target_labels = as_numpy(target.labels)
+                target_len = len(target_labels)
+                labels.extend(target_labels.tolist())
+                bboxes.extend(as_numpy(target.boxes).tolist())
+                scores.extend(as_numpy(target.scores).tolist())
+                srcidx.extend([i] * target_len)
+            elif isinstance(target, Array):
+                target_len = 1
+                labels.append(int(np.argmax(as_numpy(target))))
+                scores.append(target)
+            else:
+                raise TypeError("Encountered unsupported target type in dataset")
+
+            is_od = is_od_target if is_od is None else is_od
+            if is_od != is_od_target:
+                raise ValueError("Encountered unexpected target type in dataset")
+
+        labels = as_numpy(labels).astype(np.intp)
+        scores = as_numpy(scores).astype(np.float32)
+        bboxes = as_numpy(bboxes).astype(np.float32) if is_od else None
+        srcidx = as_numpy(srcidx).astype(np.intp) if is_od else None
+
+        target_dict = {
+            "image_index": srcidx if srcidx is not None else np.arange(len(labels)),
+            "class_label": labels,
+            "score": scores,
+            "box": bboxes if bboxes is not None else [None] * len(labels),
+        }
+
+        self._targets = Targets(labels, scores, bboxes, srcidx)
+        self._raw = raw
+
+        index2label = self._dataset.metadata.get("index2label", {})
+        self._class_labels = labels
+        self._class_names = [index2label.get(i, str(i)) for i in np.unique(self._class_labels)]
+        self._image_indices = target_dict["image_index"]
+
+        targets_per_image = None if srcidx is None else np.unique(srcidx, return_counts=True)[1].tolist()
+        merged = merge(raw, return_dropped=True, ignore_lists=False, targets_per_image=targets_per_image)
+
+        reserved = ["image_index", "class_label", "score", "box"]
+        factor_dict = {f"metadata_{k}" if k in reserved else k: v for k, v in merged[0].items() if k != "_image_index"}
+
+        self._factors = dict.fromkeys(factor_dict, FactorInfo())
+        self._dataframe = pl.DataFrame({**target_dict, **factor_dict})
+        self._dropped_factors = merged[1]
+        self._is_structured = True
+
+    def _bin(self) -> None:
+        """Populate factor info and bin non-categorical factors."""
+        if self._is_binned:
+            return
+
+        # Start with an empty set of factor info
+        factor_info: dict[str, FactorInfo] = {}
+
+        # Create a mutable DataFrame for updates
+        df = self.dataframe.clone()
+        factor_bins = self.continuous_factor_bins
+
+        # Check for invalid keys
+        invalid_keys = set(factor_bins.keys()) - set(df.columns)
+        if invalid_keys:
+            warnings.warn(
+                f"The keys - {invalid_keys} - are present in the `continuous_factor_bins` dictionary "
+                "but are not columns in the metadata DataFrame. Unknown keys will be ignored."
+            )
+
+        column_set = set(df.columns)
+        for col in (col for col in self.factor_names if f"{col}[|]" not in column_set):
+            # Get data as numpy array for processing
+            data = df[col].to_numpy()
+            col_dz = f"{col}[|]"
+            if col in factor_bins:
+                # User provided binning
+                bins = factor_bins[col]
+                df = df.with_columns(pl.Series(name=col_dz, values=digitize_data(data, bins).astype(np.int64)))
+                factor_info[col] = FactorInfo("continuous", col_dz)
+            else:
+                # Check if data is numeric
+                unique, ordinal = np.unique(data, return_inverse=True)
+                if not np.issubdtype(data.dtype, np.number) or unique.size <= max(20, data.size * 0.01):
+                    # Non-numeric data or small number of unique values - convert to categorical
+                    df = df.with_columns(pl.Series(name=col_dz, values=ordinal.astype(np.int64)))
+                    factor_info[col] = FactorInfo("categorical", col_dz)
+                elif data.dtype == float:
+                    # Many unique values - discretize by binning
+                    warnings.warn(
+                        f"A user defined binning was not provided for {col}. "
+                        f"Using the {self.auto_bin_method} method to discretize the data. "
+                        "It is recommended that the user rerun and supply the desired "
+                        "bins using the continuous_factor_bins parameter.",
+                        UserWarning,
+                    )
+                    # Create binned version
+                    binned_data = bin_data(data, self.auto_bin_method)
+                    df = df.with_columns(pl.Series(name=col_dz, values=binned_data.astype(np.int64)))
+                    factor_info[col] = FactorInfo("continuous", col_dz)
+                else:
+                    factor_info[col] = FactorInfo("discrete", col_dz)
+
+        # Store the results
+        self._dataframe = df
+        self._factors.update(factor_info)
+        self._is_binned = True
+
+    def get_factors_by_type(self, factor_type: Literal["categorical", "continuous", "discrete"]) -> list[str]:
+        """
+        Get the names of factors of a specific type.
+
+        Parameters
+        ----------
+        factor_type : Literal["categorical", "continuous", "discrete"]
+            The type of factors to retrieve.
+
+        Returns
+        -------
+        list[str]
+            List of factor names of the specified type.
+        """
+        self._bin()
+        return [name for name, info in self.factor_info.items() if info.factor_type == factor_type]
+
+    def add_factors(self, factors: Mapping[str, Any]) -> None:
+        """
+        Add additional factors to the metadata.
+
+        The number of measures per factor must match the number of images
+        in the dataset or the number of detections in the dataset.
+
+        Parameters
+        ----------
+        factors : Mapping[str, ArrayLike]
+            Dictionary of factors to add to the metadata.
+        """
+        self._structure()
+
+        targets = len(self.targets.source) if self.targets.source is not None else len(self.targets)
+        images = self.image_count
+        lengths = {k: len(v if isinstance(v, Sized) else np.atleast_1d(as_numpy(v))) for k, v in factors.items()}
+        targets_match = all(f == targets for f in lengths.values())
+        images_match = targets_match if images == targets else all(f == images for f in lengths.values())
+        if not targets_match and not images_match:
+            raise ValueError(
+                "The lists/arrays in the provided factors have a different length than the current metadata factors."
+            )
+
+        new_columns = []
+        for k, v in factors.items():
+            v = as_numpy(v)
+            data = v if (self.targets.source is None or lengths[k] == targets) else v[self.targets.source]
+            new_columns.append(pl.Series(name=k, values=data))
+            self._factors[k] = FactorInfo()
+
+        if new_columns:
+            self._dataframe = self.dataframe.with_columns(new_columns)
+            self._is_binned = False
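Note: the rewritten `Metadata` class stores targets and factors in a Polars DataFrame, structuring and binning lazily on first property access. A usage sketch (assuming `Metadata` is exported from `dataeval.data`; the `dataset` object and the "altitude" factor are hypothetical):

    from dataeval.data import Metadata

    # Bin a hypothetical continuous "altitude" factor into 4 bins; other
    # factors are auto-typed as categorical/continuous/discrete on access.
    metadata = Metadata(dataset, continuous_factor_bins={"altitude": 4})
    print(metadata.factor_names)                       # all retained factor names
    print(metadata.get_factors_by_type("continuous"))  # e.g. ["altitude"]
    binned = metadata.discretized_data                 # int64 array, one column per factor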
{dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/data/_selection.py

@@ -110,8 +110,7 @@ class Select(AnnotatedDataset[_TDatum]):
         grouped: dict[int, list[Selection[_TDatum]]] = {}
         for selection in selections_list:
             grouped.setdefault(selection.stage, []).append(selection)
-        selection_list = [selection for category in sorted(grouped) for selection in grouped[category]]
-        return selection_list
+        return [selection for category in sorted(grouped) for selection in grouped[category]]

     def _apply_selections(self) -> None:
         for selection in self._selections:
{dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/data/_split.py

@@ -23,7 +23,7 @@ _logger = logging.getLogger(__name__)
 class KFoldSplitter(Protocol):
     """Protocol covering sklearn KFold variant splitters"""

-    def __init__(self, n_splits: int): ...
+    def __init__(self, n_splits: int) -> None: ...
     def split(self, X: Any, y: Any, groups: Any) -> Iterator[tuple[NDArray[Any], NDArray[Any]]]: ...


@@ -207,10 +207,9 @@ def get_groups(metadata: Metadata, split_on: Sequence[str] | None) -> NDArray[np
         return None

     split_set = set(split_on)
-    indices = [i for i, name in enumerate(metadata.discrete_factor_names) if name in split_set]
-    binned_features = metadata.discrete_data[:, indices]
-    group_ids = np.unique(binned_features, axis=0, return_inverse=True)[1]
-    return group_ids
+    indices = [i for i, name in enumerate(metadata.factor_names) if name in split_set]
+    binned_features = metadata.discretized_data[:, indices]
+    return np.unique(binned_features, axis=0, return_inverse=True)[1]


 def make_splits(
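Note: `get_groups` now reads the renamed `factor_names`/`discretized_data` properties of the new `Metadata` class (replacing `discrete_factor_names`/`discrete_data`). The grouping trick itself is plain NumPy; a self-contained sketch with illustrative values:

    import numpy as np

    # Rows with identical binned factor values collapse to the same group id.
    binned_features = np.array([[0, 1], [0, 1], [2, 1]])
    group_ids = np.unique(binned_features, axis=0, return_inverse=True)[1]
    print(group_ids)  # [0 0 1]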
{dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/data/_targets.py

@@ -24,11 +24,13 @@ class Targets:
     labels : NDArray[np.intp]
         Labels (N,) for N images or objects
     scores : NDArray[np.float32]
-        Probability scores (N,M) for N images of M classes or confidence score (N,) of objects
+        Probability scores (N, M) for N images of M classes or confidence score (N,) of objects
     bboxes : NDArray[np.float32] | None
-        Bounding boxes (N,4) for N objects in (x0,y0,x1,y1) format
+        Bounding boxes (N, 4) for N objects in (x0, y0, x1, y1) format
     source : NDArray[np.intp] | None
         Source image index (N,) for N objects
+    size : int
+        Count of objects
     """

     labels: NDArray[np.intp]
@@ -55,13 +57,16 @@ class Targets:
         )

         if self.bboxes is not None and len(self.bboxes) > 0 and self.bboxes.shape[-1] != 4:
-            raise ValueError("Bounding boxes must be in (x0,y0,x1,y1) format.")
+            raise ValueError("Bounding boxes must be in (x0, y0, x1, y1) format.")
+
+    @property
+    def size(self) -> int:
+        return len(self.labels)

     def __len__(self) -> int:
         if self.source is None:
             return len(self.labels)
-        else:
-            return len(np.unique(self.source))
+        return len(np.unique(self.source))

     def __getitem__(self, idx: int, /) -> Targets:
         if self.source is None or self.bboxes is None:
@@ -71,14 +76,13 @@ class Targets:
                 None,
                 None,
             )
-        else:
-            mask = np.where(self.source == idx, True, False)
-            return Targets(
-                np.atleast_1d(self.labels[mask]),
-                np.atleast_1d(self.scores[mask]),
-                np.atleast_2d(self.bboxes[mask]),
-                np.atleast_1d(self.source[mask]),
-            )
+        mask = np.where(self.source == idx, True, False)
+        return Targets(
+            np.atleast_1d(self.labels[mask]),
+            np.atleast_1d(self.scores[mask]),
+            np.atleast_2d(self.bboxes[mask]),
+            np.atleast_1d(self.source[mask]),
+        )

     def __iter__(self) -> Iterator[Targets]:
         for i in range(len(self.labels)) if self.source is None else np.unique(self.source):
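Note: the new `size` property counts individual detections, while `__len__` still counts unique source images for object-detection targets. A hedged sketch (assuming `Targets` is exported from `dataeval.data`; values are illustrative):

    import numpy as np
    from dataeval.data import Targets

    targets = Targets(
        labels=np.array([0, 1, 1], dtype=np.intp),
        scores=np.array([0.9, 0.8, 0.7], dtype=np.float32),
        bboxes=np.zeros((3, 4), dtype=np.float32),  # (x0, y0, x1, y1) per object
        source=np.array([0, 0, 1], dtype=np.intp),
    )
    assert targets.size == 3  # three detections
    assert len(targets) == 2  # two unique source images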
{dataeval-0.86.0 → dataeval-0.86.2}/src/dataeval/data/selections/_classfilter.py

@@ -68,11 +68,8 @@ _TTarget = TypeVar("_TTarget", ObjectDetectionTarget, SegmentationTarget)


 def _try_mask_object(obj: _T, mask: NDArray[np.bool_]) -> _T:
-    if isinstance(obj, Sized) and not isinstance(obj, (str, bytes, bytearray)) and len(obj) == len(mask):
-        if isinstance(obj, Array):
-            return obj[mask]
-        elif isinstance(obj, Sequence):
-            return cast(_T, [item for i, item in enumerate(obj) if mask[i]])
+    if not isinstance(obj, (str, bytes, bytearray)) and isinstance(obj, (Sequence, Array)) and len(obj) == len(mask):
+        return obj[mask] if isinstance(obj, Array) else cast(_T, [item for i, item in enumerate(obj) if mask[i]])
     return obj


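Note: the consolidated `_try_mask_object` passes strings and bytes through untouched and masks `Array` and `Sequence` objects whose length matches the mask. An illustrative sketch (this imports a private helper, so the behavior shown is an assumption subject to change):

    import numpy as np
    from dataeval.data.selections._classfilter import _try_mask_object

    mask = np.array([True, False, True])
    print(_try_mask_object(np.array([1, 2, 3]), mask))  # [1 3]
    print(_try_mask_object([10, 20, 30], mask))         # [10, 30]
    print(_try_mask_object("abc", mask))                # abc (unchanged)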