dataeval 0.86.0__py3-none-any.whl → 0.86.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. dataeval/__init__.py +1 -1
  2. dataeval/_log.py +1 -1
  3. dataeval/config.py +21 -4
  4. dataeval/data/_embeddings.py +2 -2
  5. dataeval/data/_images.py +2 -3
  6. dataeval/data/_metadata.py +188 -178
  7. dataeval/data/_selection.py +1 -2
  8. dataeval/data/_split.py +4 -5
  9. dataeval/data/_targets.py +17 -13
  10. dataeval/data/selections/_classfilter.py +2 -5
  11. dataeval/data/selections/_prioritize.py +6 -9
  12. dataeval/data/selections/_shuffle.py +3 -1
  13. dataeval/detectors/drift/_base.py +4 -5
  14. dataeval/detectors/drift/_mmd.py +3 -6
  15. dataeval/detectors/drift/_nml/_base.py +4 -2
  16. dataeval/detectors/drift/_nml/_chunk.py +11 -19
  17. dataeval/detectors/drift/_nml/_domainclassifier.py +8 -19
  18. dataeval/detectors/drift/_nml/_result.py +8 -9
  19. dataeval/detectors/drift/_nml/_thresholds.py +66 -77
  20. dataeval/detectors/linters/outliers.py +7 -7
  21. dataeval/metadata/_distance.py +10 -7
  22. dataeval/metadata/_ood.py +11 -103
  23. dataeval/metrics/bias/_balance.py +23 -33
  24. dataeval/metrics/bias/_diversity.py +16 -14
  25. dataeval/metrics/bias/_parity.py +18 -18
  26. dataeval/metrics/estimators/_divergence.py +2 -4
  27. dataeval/metrics/stats/_base.py +103 -42
  28. dataeval/metrics/stats/_boxratiostats.py +21 -19
  29. dataeval/metrics/stats/_dimensionstats.py +14 -10
  30. dataeval/metrics/stats/_hashstats.py +1 -1
  31. dataeval/metrics/stats/_pixelstats.py +6 -6
  32. dataeval/metrics/stats/_visualstats.py +3 -3
  33. dataeval/outputs/_base.py +22 -7
  34. dataeval/outputs/_bias.py +24 -70
  35. dataeval/outputs/_drift.py +1 -9
  36. dataeval/outputs/_linters.py +11 -11
  37. dataeval/outputs/_stats.py +82 -23
  38. dataeval/outputs/_workflows.py +2 -2
  39. dataeval/utils/_array.py +6 -9
  40. dataeval/utils/_bin.py +1 -2
  41. dataeval/utils/_clusterer.py +7 -4
  42. dataeval/utils/_fast_mst.py +27 -13
  43. dataeval/utils/_image.py +65 -11
  44. dataeval/utils/_mst.py +1 -3
  45. dataeval/utils/_plot.py +15 -10
  46. dataeval/utils/data/_dataset.py +54 -28
  47. dataeval/utils/data/metadata.py +104 -82
  48. dataeval/utils/datasets/__init__.py +2 -0
  49. dataeval/utils/datasets/_antiuav.py +189 -0
  50. dataeval/utils/datasets/_base.py +11 -8
  51. dataeval/utils/datasets/_cifar10.py +104 -45
  52. dataeval/utils/datasets/_fileio.py +21 -47
  53. dataeval/utils/datasets/_milco.py +22 -12
  54. dataeval/utils/datasets/_mixin.py +2 -4
  55. dataeval/utils/datasets/_mnist.py +3 -4
  56. dataeval/utils/datasets/_ships.py +14 -7
  57. dataeval/utils/datasets/_voc.py +229 -42
  58. dataeval/utils/torch/models.py +5 -10
  59. dataeval/utils/torch/trainer.py +3 -3
  60. dataeval/workflows/sufficiency.py +2 -2
  61. {dataeval-0.86.0.dist-info → dataeval-0.86.2.dist-info}/METADATA +2 -1
  62. dataeval-0.86.2.dist-info/RECORD +114 -0
  63. dataeval/detectors/ood/vae.py +0 -74
  64. dataeval-0.86.0.dist-info/RECORD +0 -114
  65. {dataeval-0.86.0.dist-info → dataeval-0.86.2.dist-info}/LICENSE.txt +0 -0
  66. {dataeval-0.86.0.dist-info → dataeval-0.86.2.dist-info}/WHEEL +0 -0
dataeval/__init__.py CHANGED
@@ -8,7 +8,7 @@ shifts that impact performance of deployed models.
  from __future__ import annotations

  __all__ = ["config", "detectors", "log", "metrics", "typing", "utils", "workflows"]
- __version__ = "0.86.0"
+ __version__ = "0.86.2"

  import logging
dataeval/_log.py CHANGED
@@ -8,7 +8,7 @@ class LogMessage:
  Deferred message callback for logging expensive messages.
  """

- def __init__(self, fn: Callable[..., str]):
+ def __init__(self, fn: Callable[..., str]) -> None:
  self._fn = fn
  self._str = None
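The `LogMessage` wrapper defers building an expensive log string until the record is actually formatted. A minimal usage sketch (the call site and the lazy evaluation via `__str__` are assumptions based on the class docstring, not code shown in this hunk):

    import logging
    from dataeval._log import LogMessage

    logger = logging.getLogger("dataeval")

    def expensive_summary() -> str:
        # stand-in for a costly computation, e.g. summarizing a large array
        return "summary"

    # The callable runs only if the DEBUG record is actually emitted and formatted.
    logger.debug(LogMessage(expensive_summary))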
dataeval/config.py CHANGED
@@ -4,10 +4,10 @@ Global configuration settings for DataEval.

  from __future__ import annotations

- __all__ = ["get_device", "set_device", "get_max_processes", "set_max_processes", "DeviceLike"]
+ __all__ = ["get_device", "set_device", "get_max_processes", "set_max_processes", "use_max_processes", "DeviceLike"]

  import sys
- from typing import Union
+ from typing import Any, Union

  if sys.version_info >= (3, 10):
  from typing import TypeAlias
@@ -78,8 +78,7 @@ def get_device(override: DeviceLike | None = None) -> torch.device:
  if override is None:
  global _device
  return torch.get_default_device() if _device is None else _device
- else:
- return _todevice(override)
+ return _todevice(override)


  def set_max_processes(processes: int | None) -> None:
@@ -112,6 +111,24 @@ def get_max_processes() -> int | None:
  return _processes


+ class MaxProcessesContextManager:
+ def __init__(self, processes: int) -> None:
+ self._processes = processes
+
+ def __enter__(self) -> None:
+ global _processes
+ self._old = _processes
+ set_max_processes(self._processes)
+
+ def __exit__(self, *args: tuple[Any, ...]) -> None:
+ global _processes
+ _processes = self._old
+
+
+ def use_max_processes(processes: int) -> MaxProcessesContextManager:
+ return MaxProcessesContextManager(processes)
+
+
  def set_seed(seed: int | None, all_generators: bool = False) -> None:
  """
  Sets the seed for use by classes that allow for a random state or seed.
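The new `use_max_processes` helper returns a `MaxProcessesContextManager`, which temporarily overrides the global process limit and restores the previous value on exit. A minimal usage sketch based on the `__enter__`/`__exit__` behavior shown above:

    from dataeval.config import get_max_processes, set_max_processes, use_max_processes

    set_max_processes(8)                  # configure the global limit

    with use_max_processes(2):            # limit is temporarily lowered inside the block
        assert get_max_processes() == 2

    assert get_max_processes() == 8       # previous limit is restored on exit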
dataeval/data/_embeddings.py CHANGED
@@ -144,8 +144,7 @@ class Embeddings:
  """
  if indices is not None:
  return torch.vstack(list(self._batch(indices))).to(self.device)
- else:
- return self[:]
+ return self[:]

  def to_numpy(self, indices: Sequence[int] | None = None) -> NDArray[Any]:
  """
@@ -248,6 +247,7 @@ class Embeddings:
  _logger.log(logging.DEBUG, f"Saved embeddings cache from {path}")
  except Exception as e:
  _logger.log(logging.ERROR, f"Failed to save embeddings cache: {e}")
+ raise e

  @classmethod
  def load(cls, path: Path | str) -> Embeddings:
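With the added `raise e`, a failed cache write is now logged and re-raised instead of being silently swallowed, so callers can handle it explicitly. A hedged sketch (the `save(path)` signature is inferred from the surrounding context; `embeddings` and `path` are placeholders):

    try:
        embeddings.save(path)
    except Exception as err:
        # fall back to recomputing embeddings on the next run instead of using the cache
        print(f"could not cache embeddings: {err}")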
dataeval/data/_images.py CHANGED
@@ -73,15 +73,14 @@ class Images(Generic[T]):
  def __getitem__(self, key: int | slice, /) -> Sequence[T] | T:
  if isinstance(key, slice):
  return [self._get_image(k) for k in range(len(self._dataset))[key]]
- elif hasattr(key, "__int__"):
+ if hasattr(key, "__int__"):
  return self._get_image(int(key))
  raise TypeError(f"Key must be integers or slices, not {type(key)}")

  def _get_image(self, index: int) -> T:
  if self._is_tuple_datum:
  return cast(Dataset[tuple[T, Any, Any]], self._dataset)[index][0]
- else:
- return cast(Dataset[T], self._dataset)[index]
+ return cast(Dataset[T], self._dataset)[index]

  def __iter__(self) -> Iterator[T]:
  for i in range(len(self._dataset)):
dataeval/data/_metadata.py CHANGED
@@ -3,19 +3,20 @@ from __future__ import annotations
  __all__ = []

  import warnings
- from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, Sized, cast
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, Any, Iterable, Literal, Mapping, Sequence, Sized

  import numpy as np
+ import polars as pl
  from numpy.typing import NDArray

  from dataeval.typing import (
  AnnotatedDataset,
  Array,
- ArrayLike,
  ObjectDetectionTarget,
  )
- from dataeval.utils._array import as_numpy, to_numpy
- from dataeval.utils._bin import bin_data, digitize_data, is_continuous
+ from dataeval.utils._array import as_numpy
+ from dataeval.utils._bin import bin_data, digitize_data
  from dataeval.utils.data.metadata import merge

  if TYPE_CHECKING:
@@ -24,31 +25,15 @@ else:
  from dataeval.data._targets import Targets


+ @dataclass
+ class FactorInfo:
+ factor_type: Literal["categorical", "continuous", "discrete"] | None = None
+ discretized_col: str | None = None
+
+
  class Metadata:
  """
- Class containing binned metadata.
-
- Attributes
- ----------
- discrete_factor_names : list[str]
- List containing factor names for the original data that was discrete and
- the binned continuous data
- discrete_data : NDArray[np.int64]
- Array containing values for the original data that was discrete and the
- binned continuous data
- continuous_factor_names : list[str]
- List containing factor names for the original continuous data
- continuous_data : NDArray[np.float64] | None
- Array containing values for the original continuous data or None if there
- was no continuous data
- class_labels : NDArray[np.int]
- Numerical class labels for the images/objects
- class_names : list[str]
- List of unique class names
- total_num_factors : int
- Sum of discrete_factor_names and continuous_factor_names plus 1 for class
- image_indices : NDArray[np.intp]
- Array of the image index that is mapped by the index of the factor
+ Class containing binned metadata using Polars DataFrames.

  Parameters
  ----------
@@ -73,13 +58,20 @@ class Metadata:
  exclude: Sequence[str] | None = None,
  include: Sequence[str] | None = None,
  ) -> None:
- self._collated = False
- self._merged = None
- self._processed = False
+ self._targets: Targets
+ self._class_labels: NDArray[np.intp]
+ self._class_names: list[str]
+ self._image_indices: NDArray[np.intp]
+ self._factors: dict[str, FactorInfo]
+ self._dropped_factors: dict[str, list[str]]
+ self._dataframe: pl.DataFrame
+
+ self._is_structured = False
+ self._is_binned = False

  self._dataset = dataset
  self._continuous_factor_bins = dict(continuous_factor_bins) if continuous_factor_bins else {}
- self._auto_bin_method = auto_bin_method
+ self._auto_bin_method: Literal["uniform_width", "uniform_count", "clusters"] = auto_bin_method

  if exclude is not None and include is not None:
  raise ValueError("Filters for `exclude` and `include` are mutually exclusive.")
@@ -89,16 +81,19 @@ class Metadata:

  @property
  def targets(self) -> Targets:
- self._collate()
+ """Target information for the dataset."""
+ self._structure()
  return self._targets

  @property
  def raw(self) -> list[dict[str, Any]]:
- self._collate()
+ """The raw list of metadata dictionaries for the dataset."""
+ self._structure()
  return self._raw

  @property
  def exclude(self) -> set[str]:
+ """Factors to exclude from the metadata."""
  return self._exclude

  @exclude.setter
@@ -107,10 +102,11 @@ class Metadata:
  if self._exclude != exclude:
  self._exclude = exclude
  self._include = set()
- self._processed = False
+ self._is_binned = False

  @property
  def include(self) -> set[str]:
+ """Factors to include from the metadata."""
  return self._include

  @include.setter
@@ -119,85 +115,109 @@ class Metadata:
  if self._include != include:
  self._include = include
  self._exclude = set()
- self._processed = False
+ self._is_binned = False

  @property
  def continuous_factor_bins(self) -> Mapping[str, int | Sequence[float]]:
+ """Map of factor names to bin counts or bin edges."""
  return self._continuous_factor_bins

  @continuous_factor_bins.setter
  def continuous_factor_bins(self, bins: Mapping[str, int | Sequence[float]]) -> None:
  if self._continuous_factor_bins != bins:
  self._continuous_factor_bins = dict(bins)
- self._processed = False
+ self._reset_bins(bins)

  @property
- def auto_bin_method(self) -> str:
+ def auto_bin_method(self) -> Literal["uniform_width", "uniform_count", "clusters"]:
+ """Binning method to use when continuous_factor_bins is not defined."""
  return self._auto_bin_method

  @auto_bin_method.setter
  def auto_bin_method(self, method: Literal["uniform_width", "uniform_count", "clusters"]) -> None:
  if self._auto_bin_method != method:
  self._auto_bin_method = method
- self._processed = False
+ self._reset_bins()

  @property
- def merged(self) -> dict[str, Any]:
- self._merge()
- return {} if self._merged is None else self._merged[0]
+ def dataframe(self) -> pl.DataFrame:
+ """Dataframe containing target information and metadata factors."""
+ self._structure()
+ return self._dataframe

  @property
  def dropped_factors(self) -> dict[str, list[str]]:
- self._merge()
- return {} if self._merged is None else self._merged[1]
+ """Factors that were dropped during preprocessing and the reasons why they were dropped."""
+ self._structure()
+ return self._dropped_factors

  @property
- def discrete_factor_names(self) -> list[str]:
- self._process()
- return self._discrete_factor_names
+ def discretized_data(self) -> NDArray[np.int64]:
+ """Factor data with continuous data discretized."""
+ if not self.factor_names:
+ return np.array([], dtype=np.int64)
+
+ self._bin()
+ return (
+ self.dataframe.select([info.discretized_col or name for name, info in self.factor_info.items()])
+ .to_numpy()
+ .astype(np.int64)
+ )

  @property
- def discrete_data(self) -> NDArray[np.int64]:
- self._process()
- return self._discrete_data
+ def factor_names(self) -> list[str]:
+ """Factor names of the metadata."""
+ self._structure()
+ return list(self._factors)

  @property
- def continuous_factor_names(self) -> list[str]:
- self._process()
- return self._continuous_factor_names
+ def factor_info(self) -> dict[str, FactorInfo]:
+ """Factor types of the metadata."""
+ self._bin()
+ return self._factors

  @property
- def continuous_data(self) -> NDArray[np.float64]:
- self._process()
- return self._continuous_data
+ def factor_data(self) -> NDArray[Any]:
+ """Factor data as a NumPy array."""
+ if not self.factor_names:
+ return np.array([], dtype=np.float64)
+
+ # Extract continuous columns and convert to NumPy array
+ return self.dataframe.select(self.factor_names).to_numpy()

  @property
  def class_labels(self) -> NDArray[np.intp]:
- self._collate()
+ """Class labels as a NumPy array."""
+ self._structure()
  return self._class_labels

  @property
  def class_names(self) -> list[str]:
- self._collate()
+ """Class names as a list of strings."""
+ self._structure()
  return self._class_names

- @property
- def total_num_factors(self) -> int:
- self._process()
- return self._total_num_factors
-
  @property
  def image_indices(self) -> NDArray[np.intp]:
- self._process()
+ """Indices of images as a NumPy array."""
+ self._bin()
  return self._image_indices

  @property
  def image_count(self) -> int:
- self._process()
+ self._bin()
  return int(self._image_indices.max() + 1)

- def _collate(self, force: bool = False):
- if self._collated and not force:
+ def _reset_bins(self, cols: Iterable[str] | None = None) -> None:
+ if self._is_binned:
+ columns = self._dataframe.columns
+ for col in (col for col in cols or columns if f"{col}[|]" in columns):
+ self._dataframe.drop_in_place(f"{col}[|]")
+ self._factors[col] = FactorInfo()
+ self._is_binned = False
+
+ def _structure(self) -> None:
+ if self._is_structured:
  return

  raw: list[dict[str, Any]] = []
@@ -235,135 +255,120 @@ class Metadata:
  bboxes = as_numpy(bboxes).astype(np.float32) if is_od else None
  srcidx = as_numpy(srcidx).astype(np.intp) if is_od else None

+ target_dict = {
+ "image_index": srcidx if srcidx is not None else np.arange(len(labels)),
+ "class_label": labels,
+ "score": scores,
+ "box": bboxes if bboxes is not None else [None] * len(labels),
+ }
+
  self._targets = Targets(labels, scores, bboxes, srcidx)
  self._raw = raw

  index2label = self._dataset.metadata.get("index2label", {})
- self._class_labels = self._targets.labels
+ self._class_labels = labels
  self._class_names = [index2label.get(i, str(i)) for i in np.unique(self._class_labels)]
- self._collated = True
+ self._image_indices = target_dict["image_index"]

- def _merge(self, force: bool = False):
- if self._merged is not None and not force:
- return
+ targets_per_image = None if srcidx is None else np.unique(srcidx, return_counts=True)[1].tolist()
+ merged = merge(raw, return_dropped=True, ignore_lists=False, targets_per_image=targets_per_image)

- targets_per_image = (
- None if self.targets.source is None else np.unique(self.targets.source, return_counts=True)[1].tolist()
- )
- self._merged = merge(self.raw, return_dropped=True, ignore_lists=False, targets_per_image=targets_per_image)
+ reserved = ["image_index", "class_label", "score", "box"]
+ factor_dict = {f"metadata_{k}" if k in reserved else k: v for k, v in merged[0].items() if k != "_image_index"}

- def _validate(self) -> None:
- # Check that metadata is a single, flattened dictionary with uniform array lengths
- check_length = None
- if self._targets.labels.ndim > 1:
- raise ValueError(
- f"Got class labels with {self._targets.labels.ndim}-dimensional "
- f"shape {self._targets.labels.shape}, but expected a 1-dimensional array."
- )
- for v in self.merged.values():
- if not isinstance(v, (list, tuple, np.ndarray)):
- raise TypeError(
- "Metadata dictionary needs to be a single dictionary whose values "
- "are arraylike containing the metadata on a per image or per object basis."
- )
- else:
- check_length = len(v) if check_length is None else check_length
- if check_length != len(v):
- raise ValueError(
- "The lists/arrays in the metadata dict have varying lengths. "
- "Metadata requires them to be uniform in length."
- )
- if len(self._class_labels) != check_length:
- raise ValueError(
- f"The length of the label array {len(self._class_labels)} is not the same as "
- f"the length of the metadata arrays {check_length}."
- )
+ self._factors = dict.fromkeys(factor_dict, FactorInfo())
+ self._dataframe = pl.DataFrame({**target_dict, **factor_dict})
+ self._dropped_factors = merged[1]
+ self._is_structured = True

- def _process(self, force: bool = False) -> None:
- if self._processed and not force:
+ def _bin(self) -> None:
+ """Populate factor info and bin non-categorical factors."""
+ if self._is_binned:
  return

- # Create image indices from targets
- self._image_indices = np.arange(len(self.raw)) if self.targets.source is None else self.targets.source
+ # Start with an empty set of factor info
+ factor_info: dict[str, FactorInfo] = {}

- # Validate the metadata dimensions
- self._validate()
+ # Create a mutable DataFrame for updates
+ df = self.dataframe.clone()
+ factor_bins = self.continuous_factor_bins

- # Include specified metadata keys
- if self.include:
- metadata = {i: self.merged[i] for i in self.include if i in self.merged}
- continuous_factor_bins = (
- {i: self.continuous_factor_bins[i] for i in self.include if i in self.continuous_factor_bins}
- if self.continuous_factor_bins
- else {}
+ # Check for invalid keys
+ invalid_keys = set(factor_bins.keys()) - set(df.columns)
+ if invalid_keys:
+ warnings.warn(
+ f"The keys - {invalid_keys} - are present in the `continuous_factor_bins` dictionary "
+ "but are not columns in the metadata DataFrame. Unknown keys will be ignored."
  )
- else:
- metadata = self.merged
- continuous_factor_bins = dict(self.continuous_factor_bins) if self.continuous_factor_bins else {}
- for k in self.exclude:
- metadata.pop(k, None)
- continuous_factor_bins.pop(k, None)
-
- # Remove generated "_image_index" if present
- if "_image_index" in metadata:
- metadata.pop("_image_index", None)
-
- # Bin according to user supplied bins
- continuous_metadata = {}
- discrete_metadata = {}
- if continuous_factor_bins:
- invalid_keys = set(continuous_factor_bins.keys()) - set(metadata.keys())
- if invalid_keys:
- raise KeyError(
- f"The keys - {invalid_keys} - are present in the `continuous_factor_bins` dictionary "
- "but are not keys in the `metadata` dictionary. Delete these keys from `continuous_factor_bins` "
- "or add corresponding entries to the `metadata` dictionary."
- )
- for factor, bins in continuous_factor_bins.items():
- discrete_metadata[factor] = digitize_data(metadata[factor], bins)
- continuous_metadata[factor] = metadata[factor]
-
- # Determine category of the rest of the keys
- remaining_keys = set(metadata.keys()) - set(continuous_metadata.keys())
- for key in remaining_keys:
- data = to_numpy(metadata[key])
- if np.issubdtype(data.dtype, np.number):
- result = is_continuous(data, self._image_indices)
- if result:
- continuous_metadata[key] = data
- unique_samples, ordinal_data = np.unique(data, return_inverse=True)
- if unique_samples.size <= np.max([20, data.size * 0.01]):
- discrete_metadata[key] = ordinal_data
- else:
+
+ column_set = set(df.columns)
+ for col in (col for col in self.factor_names if f"{col}[|]" not in column_set):
+ # Get data as numpy array for processing
+ data = df[col].to_numpy()
+ col_dz = f"{col}[|]"
+ if col in factor_bins:
+ # User provided binning
+ bins = factor_bins[col]
+ df = df.with_columns(pl.Series(name=col_dz, values=digitize_data(data, bins).astype(np.int64)))
+ factor_info[col] = FactorInfo("continuous", col_dz)
+ else:
+ # Check if data is numeric
+ unique, ordinal = np.unique(data, return_inverse=True)
+ if not np.issubdtype(data.dtype, np.number) or unique.size <= max(20, data.size * 0.01):
+ # Non-numeric data or small number of unique values - convert to categorical
+ df = df.with_columns(pl.Series(name=col_dz, values=ordinal.astype(np.int64)))
+ factor_info[col] = FactorInfo("categorical", col_dz)
+ elif data.dtype == float:
+ # Many unique values - discretize by binning
  warnings.warn(
- f"A user defined binning was not provided for {key}. "
+ f"A user defined binning was not provided for {col}. "
  f"Using the {self.auto_bin_method} method to discretize the data. "
  "It is recommended that the user rerun and supply the desired "
  "bins using the continuous_factor_bins parameter.",
  UserWarning,
  )
- discrete_metadata[key] = bin_data(data, self.auto_bin_method)
- else:
- _, discrete_metadata[key] = np.unique(data, return_inverse=True)
-
- # Split out the dictionaries into the keys and values
- self._discrete_factor_names = list(discrete_metadata.keys())
- self._discrete_data = (
- np.stack(list(discrete_metadata.values()), axis=-1, dtype=np.int64)
- if discrete_metadata
- else np.array([], dtype=np.int64)
- )
- self._continuous_factor_names = list(continuous_metadata.keys())
- self._continuous_data = (
- np.stack(list(continuous_metadata.values()), axis=-1, dtype=np.float64)
- if continuous_metadata
- else np.array([], dtype=np.float64)
- )
- self._total_num_factors = len(self._discrete_factor_names + self._continuous_factor_names) + 1
- self._processed = True
-
- def add_factors(self, factors: Mapping[str, ArrayLike]) -> None:
- self._merge()
+ # Create binned version
+ binned_data = bin_data(data, self.auto_bin_method)
+ df = df.with_columns(pl.Series(name=col_dz, values=binned_data.astype(np.int64)))
+ factor_info[col] = FactorInfo("continuous", col_dz)
+ else:
+ factor_info[col] = FactorInfo("discrete", col_dz)
+
+ # Store the results
+ self._dataframe = df
+ self._factors.update(factor_info)
+ self._is_binned = True
+
+ def get_factors_by_type(self, factor_type: Literal["categorical", "continuous", "discrete"]) -> list[str]:
+ """
+ Get the names of factors of a specific type.
+
+ Parameters
+ ----------
+ factor_type : Literal["categorical", "continuous", "discrete"]
+ The type of factors to retrieve.
+
+ Returns
+ -------
+ list[str]
+ List of factor names of the specified type.
+ """
+ self._bin()
+ return [name for name, info in self.factor_info.items() if info.factor_type == factor_type]
+
+ def add_factors(self, factors: Mapping[str, Any]) -> None:
+ """
+ Add additional factors to the metadata.
+
+ The number of measures per factor must match the number of images
+ in the dataset or the number of detections in the dataset.
+
+ Parameters
+ ----------
+ factors : Mapping[str, ArrayLike]
+ Dictionary of factors to add to the metadata.
+ """
+ self._structure()

  targets = len(self.targets.source) if self.targets.source is not None else len(self.targets)
  images = self.image_count
@@ -374,9 +379,14 @@ class Metadata:
  raise ValueError(
  "The lists/arrays in the provided factors have a different length than the current metadata factors."
  )
- merged = cast(dict[str, ArrayLike], self._merged[0] if self._merged is not None else {})
+
+ new_columns = []
  for k, v in factors.items():
  v = as_numpy(v)
- merged[k] = v if (self.targets.source is None or lengths[k] == targets) else v[self.targets.source]
+ data = v if (self.targets.source is None or lengths[k] == targets) else v[self.targets.source]
+ new_columns.append(pl.Series(name=k, values=data))
+ self._factors[k] = FactorInfo()

- self._processed = False
+ if new_columns:
+ self._dataframe = self.dataframe.with_columns(new_columns)
+ self._is_binned = False
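For orientation, a minimal usage sketch of the reworked Polars-backed `Metadata` API shown above; the import path, the `dataset` object, and the factor names are illustrative placeholders rather than values from the diff:

    from dataeval.data import Metadata  # import path assumed

    metadata = Metadata(dataset, continuous_factor_bins={"altitude": 5})

    metadata.factor_names                        # names of all metadata factor columns
    metadata.factor_info                         # dict[str, FactorInfo] describing each factor
    metadata.get_factors_by_type("continuous")   # factors binned as continuous
    metadata.dataframe                           # Polars DataFrame with targets and factors
    metadata.discretized_data                    # int64 array built from the "[|]" columns

    # Extra per-image (or per-detection) factors can be appended after construction:
    metadata.add_factors({"weather": weather_labels})  # weather_labels is a placeholder array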
dataeval/data/_selection.py CHANGED
@@ -110,8 +110,7 @@ class Select(AnnotatedDataset[_TDatum]):
  grouped: dict[int, list[Selection[_TDatum]]] = {}
  for selection in selections_list:
  grouped.setdefault(selection.stage, []).append(selection)
- selection_list = [selection for category in sorted(grouped) for selection in grouped[category]]
- return selection_list
+ return [selection for category in sorted(grouped) for selection in grouped[category]]

  def _apply_selections(self) -> None:
  for selection in self._selections:
dataeval/data/_split.py CHANGED
@@ -23,7 +23,7 @@ _logger = logging.getLogger(__name__)
  class KFoldSplitter(Protocol):
  """Protocol covering sklearn KFold variant splitters"""

- def __init__(self, n_splits: int): ...
+ def __init__(self, n_splits: int) -> None: ...
  def split(self, X: Any, y: Any, groups: Any) -> Iterator[tuple[NDArray[Any], NDArray[Any]]]: ...
@@ -207,10 +207,9 @@ def get_groups(metadata: Metadata, split_on: Sequence[str] | None) -> NDArray[np
  return None

  split_set = set(split_on)
- indices = [i for i, name in enumerate(metadata.discrete_factor_names) if name in split_set]
- binned_features = metadata.discrete_data[:, indices]
- group_ids = np.unique(binned_features, axis=0, return_inverse=True)[1]
- return group_ids
+ indices = [i for i, name in enumerate(metadata.factor_names) if name in split_set]
+ binned_features = metadata.discretized_data[:, indices]
+ return np.unique(binned_features, axis=0, return_inverse=True)[1]


  def make_splits(
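The rewritten `get_groups` assigns one group ID per unique row of the selected discretized factors via `np.unique(..., axis=0, return_inverse=True)`. A standalone NumPy illustration of that step (not code from the package):

    import numpy as np

    # Each row is one sample's selected discretized factor values.
    binned_features = np.array([[0, 1], [2, 1], [0, 1], [2, 0]])
    group_ids = np.unique(binned_features, axis=0, return_inverse=True)[1]
    # Rows 0 and 2 are identical, so they receive the same group id (e.g. [0, 2, 0, 1]).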