dataeval 0.86.1__py3-none-any.whl → 0.86.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dataeval/__init__.py CHANGED
@@ -8,7 +8,7 @@ shifts that impact performance of deployed models.
  from __future__ import annotations
  
  __all__ = ["config", "detectors", "log", "metrics", "typing", "utils", "workflows"]
- __version__ = "0.86.1"
+ __version__ = "0.86.2"
  
  import logging
  
dataeval/data/_metadata.py CHANGED
@@ -3,19 +3,20 @@ from __future__ import annotations
  __all__ = []
  
  import warnings
- from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, Sized, cast
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, Any, Iterable, Literal, Mapping, Sequence, Sized
  
  import numpy as np
+ import polars as pl
  from numpy.typing import NDArray
  
  from dataeval.typing import (
      AnnotatedDataset,
      Array,
-     ArrayLike,
      ObjectDetectionTarget,
  )
- from dataeval.utils._array import as_numpy, to_numpy
- from dataeval.utils._bin import bin_data, digitize_data, is_continuous
+ from dataeval.utils._array import as_numpy
+ from dataeval.utils._bin import bin_data, digitize_data
  from dataeval.utils.data.metadata import merge
  
  if TYPE_CHECKING:
@@ -24,31 +25,15 @@ else:
      from dataeval.data._targets import Targets
  
  
+ @dataclass
+ class FactorInfo:
+     factor_type: Literal["categorical", "continuous", "discrete"] | None = None
+     discretized_col: str | None = None
+ 
+ 
  class Metadata:
      """
-     Class containing binned metadata.
- 
-     Attributes
-     ----------
-     discrete_factor_names : list[str]
-         List containing factor names for the original data that was discrete and
-         the binned continuous data
-     discrete_data : NDArray[np.int64]
-         Array containing values for the original data that was discrete and the
-         binned continuous data
-     continuous_factor_names : list[str]
-         List containing factor names for the original continuous data
-     continuous_data : NDArray[np.float64] | None
-         Array containing values for the original continuous data or None if there
-         was no continuous data
-     class_labels : NDArray[np.int]
-         Numerical class labels for the images/objects
-     class_names : list[str]
-         List of unique class names
-     total_num_factors : int
-         Sum of discrete_factor_names and continuous_factor_names plus 1 for class
-     image_indices : NDArray[np.intp]
-         Array of the image index that is mapped by the index of the factor
+     Class containing binned metadata using Polars DataFrames.
  
      Parameters
      ----------
@@ -73,13 +58,20 @@ class Metadata:
          exclude: Sequence[str] | None = None,
          include: Sequence[str] | None = None,
      ) -> None:
-         self._collated = False
-         self._merged = None
-         self._processed = False
+         self._targets: Targets
+         self._class_labels: NDArray[np.intp]
+         self._class_names: list[str]
+         self._image_indices: NDArray[np.intp]
+         self._factors: dict[str, FactorInfo]
+         self._dropped_factors: dict[str, list[str]]
+         self._dataframe: pl.DataFrame
+ 
+         self._is_structured = False
+         self._is_binned = False
  
          self._dataset = dataset
          self._continuous_factor_bins = dict(continuous_factor_bins) if continuous_factor_bins else {}
-         self._auto_bin_method = auto_bin_method
+         self._auto_bin_method: Literal["uniform_width", "uniform_count", "clusters"] = auto_bin_method
  
          if exclude is not None and include is not None:
              raise ValueError("Filters for `exclude` and `include` are mutually exclusive.")
@@ -89,16 +81,19 @@ class Metadata:
  
      @property
      def targets(self) -> Targets:
-         self._collate()
+         """Target information for the dataset."""
+         self._structure()
          return self._targets
  
      @property
      def raw(self) -> list[dict[str, Any]]:
-         self._collate()
+         """The raw list of metadata dictionaries for the dataset."""
+         self._structure()
          return self._raw
  
      @property
      def exclude(self) -> set[str]:
+         """Factors to exclude from the metadata."""
          return self._exclude
  
      @exclude.setter
@@ -107,10 +102,11 @@ class Metadata:
          if self._exclude != exclude:
              self._exclude = exclude
              self._include = set()
-             self._processed = False
+             self._is_binned = False
  
      @property
      def include(self) -> set[str]:
+         """Factors to include from the metadata."""
          return self._include
  
      @include.setter
@@ -119,85 +115,109 @@ class Metadata:
          if self._include != include:
              self._include = include
              self._exclude = set()
-             self._processed = False
+             self._is_binned = False
  
      @property
      def continuous_factor_bins(self) -> Mapping[str, int | Sequence[float]]:
+         """Map of factor names to bin counts or bin edges."""
          return self._continuous_factor_bins
  
      @continuous_factor_bins.setter
      def continuous_factor_bins(self, bins: Mapping[str, int | Sequence[float]]) -> None:
          if self._continuous_factor_bins != bins:
              self._continuous_factor_bins = dict(bins)
-             self._processed = False
+             self._reset_bins(bins)
  
      @property
-     def auto_bin_method(self) -> str:
+     def auto_bin_method(self) -> Literal["uniform_width", "uniform_count", "clusters"]:
+         """Binning method to use when continuous_factor_bins is not defined."""
          return self._auto_bin_method
  
      @auto_bin_method.setter
      def auto_bin_method(self, method: Literal["uniform_width", "uniform_count", "clusters"]) -> None:
          if self._auto_bin_method != method:
              self._auto_bin_method = method
-             self._processed = False
+             self._reset_bins()
  
      @property
-     def merged(self) -> dict[str, Any]:
-         self._merge()
-         return {} if self._merged is None else self._merged[0]
+     def dataframe(self) -> pl.DataFrame:
+         """Dataframe containing target information and metadata factors."""
+         self._structure()
+         return self._dataframe
  
      @property
      def dropped_factors(self) -> dict[str, list[str]]:
-         self._merge()
-         return {} if self._merged is None else self._merged[1]
+         """Factors that were dropped during preprocessing and the reasons why they were dropped."""
+         self._structure()
+         return self._dropped_factors
  
      @property
-     def discrete_factor_names(self) -> list[str]:
-         self._process()
-         return self._discrete_factor_names
+     def discretized_data(self) -> NDArray[np.int64]:
+         """Factor data with continuous data discretized."""
+         if not self.factor_names:
+             return np.array([], dtype=np.int64)
+ 
+         self._bin()
+         return (
+             self.dataframe.select([info.discretized_col or name for name, info in self.factor_info.items()])
+             .to_numpy()
+             .astype(np.int64)
+         )
  
      @property
-     def discrete_data(self) -> NDArray[np.int64]:
-         self._process()
-         return self._discrete_data
+     def factor_names(self) -> list[str]:
+         """Factor names of the metadata."""
+         self._structure()
+         return list(self._factors)
  
      @property
-     def continuous_factor_names(self) -> list[str]:
-         self._process()
-         return self._continuous_factor_names
+     def factor_info(self) -> dict[str, FactorInfo]:
+         """Factor types of the metadata."""
+         self._bin()
+         return self._factors
  
      @property
-     def continuous_data(self) -> NDArray[np.float64]:
-         self._process()
-         return self._continuous_data
+     def factor_data(self) -> NDArray[Any]:
+         """Factor data as a NumPy array."""
+         if not self.factor_names:
+             return np.array([], dtype=np.float64)
+ 
+         # Extract continuous columns and convert to NumPy array
+         return self.dataframe.select(self.factor_names).to_numpy()
  
      @property
      def class_labels(self) -> NDArray[np.intp]:
-         self._collate()
+         """Class labels as a NumPy array."""
+         self._structure()
          return self._class_labels
  
      @property
      def class_names(self) -> list[str]:
-         self._collate()
+         """Class names as a list of strings."""
+         self._structure()
          return self._class_names
  
-     @property
-     def total_num_factors(self) -> int:
-         self._process()
-         return self._total_num_factors
- 
      @property
      def image_indices(self) -> NDArray[np.intp]:
-         self._process()
+         """Indices of images as a NumPy array."""
+         self._bin()
          return self._image_indices
  
      @property
      def image_count(self) -> int:
-         self._process()
+         self._bin()
          return int(self._image_indices.max() + 1)
  
-     def _collate(self, force: bool = False) -> None:
-         if self._collated and not force:
+     def _reset_bins(self, cols: Iterable[str] | None = None) -> None:
+         if self._is_binned:
+             columns = self._dataframe.columns
+             for col in (col for col in cols or columns if f"{col}[|]" in columns):
+                 self._dataframe.drop_in_place(f"{col}[|]")
+                 self._factors[col] = FactorInfo()
+             self._is_binned = False
+ 
+     def _structure(self) -> None:
+         if self._is_structured:
              return
  
          raw: list[dict[str, Any]] = []
@@ -235,134 +255,108 @@ class Metadata:
          bboxes = as_numpy(bboxes).astype(np.float32) if is_od else None
          srcidx = as_numpy(srcidx).astype(np.intp) if is_od else None
  
+         target_dict = {
+             "image_index": srcidx if srcidx is not None else np.arange(len(labels)),
+             "class_label": labels,
+             "score": scores,
+             "box": bboxes if bboxes is not None else [None] * len(labels),
+         }
+ 
          self._targets = Targets(labels, scores, bboxes, srcidx)
          self._raw = raw
  
          index2label = self._dataset.metadata.get("index2label", {})
-         self._class_labels = self._targets.labels
+         self._class_labels = labels
          self._class_names = [index2label.get(i, str(i)) for i in np.unique(self._class_labels)]
-         self._collated = True
+         self._image_indices = target_dict["image_index"]
+ 
+         targets_per_image = None if srcidx is None else np.unique(srcidx, return_counts=True)[1].tolist()
+         merged = merge(raw, return_dropped=True, ignore_lists=False, targets_per_image=targets_per_image)
+ 
+         reserved = ["image_index", "class_label", "score", "box"]
+         factor_dict = {f"metadata_{k}" if k in reserved else k: v for k, v in merged[0].items() if k != "_image_index"}
  
-     def _merge(self, force: bool = False) -> None:
-         if self._merged is not None and not force:
+         self._factors = dict.fromkeys(factor_dict, FactorInfo())
+         self._dataframe = pl.DataFrame({**target_dict, **factor_dict})
+         self._dropped_factors = merged[1]
+         self._is_structured = True
+ 
+     def _bin(self) -> None:
+         """Populate factor info and bin non-categorical factors."""
+         if self._is_binned:
              return
  
-         targets_per_image = (
-             None if self.targets.source is None else np.unique(self.targets.source, return_counts=True)[1].tolist()
-         )
-         self._merged = merge(self.raw, return_dropped=True, ignore_lists=False, targets_per_image=targets_per_image)
+         # Start with an empty set of factor info
+         factor_info: dict[str, FactorInfo] = {}
  
-     def _validate(self) -> None:
-         # Check that metadata is a single, flattened dictionary with uniform array lengths
-         check_length = None
-         if self._targets.labels.ndim > 1:
-             raise ValueError(
-                 f"Got class labels with {self._targets.labels.ndim}-dimensional "
-                 f"shape {self._targets.labels.shape}, but expected a 1-dimensional array."
-             )
-         for v in self.merged.values():
-             if not isinstance(v, (list, tuple, np.ndarray)):
-                 raise TypeError(
-                     "Metadata dictionary needs to be a single dictionary whose values "
-                     "are arraylike containing the metadata on a per image or per object basis."
-                 )
-             check_length = len(v) if check_length is None else check_length
-             if check_length != len(v):
-                 raise ValueError(
-                     "The lists/arrays in the metadata dict have varying lengths. "
-                     "Metadata requires them to be uniform in length."
-                 )
-         if len(self._class_labels) != check_length:
-             raise ValueError(
-                 f"The length of the label array {len(self._class_labels)} is not the same as "
-                 f"the length of the metadata arrays {check_length}."
-             )
+         # Create a mutable DataFrame for updates
+         df = self.dataframe.clone()
+         factor_bins = self.continuous_factor_bins
  
-     def _filter(self, d: Mapping[str, Any]) -> dict[str, Any]:
-         return (
-             {k: d[k] for k in self.include if k in d} if self.include else {k: d[k] for k in d if k not in self.exclude}
-         )
+         # Check for invalid keys
+         invalid_keys = set(factor_bins.keys()) - set(df.columns)
+         if invalid_keys:
+             warnings.warn(
+                 f"The keys - {invalid_keys} - are present in the `continuous_factor_bins` dictionary "
+                 "but are not columns in the metadata DataFrame. Unknown keys will be ignored."
+             )
  
-     def _split_continuous_discrete(
-         self, metadata: dict[str, NDArray[Any]], continuous_factor_bins: dict[str, int | Sequence[float]]
-     ) -> tuple[dict[str, NDArray[Any]], dict[str, NDArray[np.int64]]]:
-         # Bin according to user supplied bins
-         continuous_metadata = {}
-         discrete_metadata = {}
-         if continuous_factor_bins:
-             invalid_keys = set(continuous_factor_bins.keys()) - set(metadata.keys())
-             if invalid_keys:
-                 raise KeyError(
-                     f"The keys - {invalid_keys} - are present in the `continuous_factor_bins` dictionary "
-                     "but are not keys in the `metadata` dictionary. Delete these keys from `continuous_factor_bins` "
-                     "or add corresponding entries to the `metadata` dictionary."
-                 )
-             for factor, bins in continuous_factor_bins.items():
-                 discrete_metadata[factor] = digitize_data(metadata[factor], bins)
-                 continuous_metadata[factor] = metadata[factor]
- 
-         # Determine category of the rest of the keys
-         remaining_keys = set(metadata.keys()) - set(continuous_metadata.keys())
-         for key in remaining_keys:
-             data = to_numpy(metadata[key])
-             if np.issubdtype(data.dtype, np.number):
-                 result = is_continuous(data, self._image_indices)
-                 if result:
-                     continuous_metadata[key] = data
-                 unique_samples, ordinal_data = np.unique(data, return_inverse=True)
-                 if unique_samples.size <= np.max([20, data.size * 0.01]):
-                     discrete_metadata[key] = ordinal_data
-                 else:
+         column_set = set(df.columns)
+         for col in (col for col in self.factor_names if f"{col}[|]" not in column_set):
+             # Get data as numpy array for processing
+             data = df[col].to_numpy()
+             col_dz = f"{col}[|]"
+             if col in factor_bins:
+                 # User provided binning
+                 bins = factor_bins[col]
+                 df = df.with_columns(pl.Series(name=col_dz, values=digitize_data(data, bins).astype(np.int64)))
+                 factor_info[col] = FactorInfo("continuous", col_dz)
+             else:
+                 # Check if data is numeric
+                 unique, ordinal = np.unique(data, return_inverse=True)
+                 if not np.issubdtype(data.dtype, np.number) or unique.size <= max(20, data.size * 0.01):
+                     # Non-numeric data or small number of unique values - convert to categorical
+                     df = df.with_columns(pl.Series(name=col_dz, values=ordinal.astype(np.int64)))
+                     factor_info[col] = FactorInfo("categorical", col_dz)
+                 elif data.dtype == float:
+                     # Many unique values - discretize by binning
                      warnings.warn(
-                         f"A user defined binning was not provided for {key}. "
+                         f"A user defined binning was not provided for {col}. "
                          f"Using the {self.auto_bin_method} method to discretize the data. "
                          "It is recommended that the user rerun and supply the desired "
                          "bins using the continuous_factor_bins parameter.",
                          UserWarning,
                      )
-                     discrete_metadata[key] = bin_data(data, self.auto_bin_method)
-             else:
-                 _, discrete_metadata[key] = np.unique(data, return_inverse=True)
- 
-         return continuous_metadata, discrete_metadata
- 
-     def _process(self, force: bool = False) -> None:
-         if self._processed and not force:
-             return
- 
-         # Create image indices from targets
-         self._image_indices = np.arange(len(self.raw)) if self.targets.source is None else self.targets.source
- 
-         # Validate the metadata dimensions
-         self._validate()
+                     # Create binned version
+                     binned_data = bin_data(data, self.auto_bin_method)
+                     df = df.with_columns(pl.Series(name=col_dz, values=binned_data.astype(np.int64)))
+                     factor_info[col] = FactorInfo("continuous", col_dz)
+                 else:
+                     factor_info[col] = FactorInfo("discrete", col_dz)
  
-         # Filter the merged metadata and continuous factor bins
-         metadata = self._filter(self.merged)
-         continuous_factor_bins = self._filter(self.continuous_factor_bins)
+         # Store the results
+         self._dataframe = df
+         self._factors.update(factor_info)
+         self._is_binned = True
  
-         # Remove generated "_image_index" if present
-         metadata.pop("_image_index", None)
+     def get_factors_by_type(self, factor_type: Literal["categorical", "continuous", "discrete"]) -> list[str]:
+         """
+         Get the names of factors of a specific type.
  
-         # Split the metadata into continuous and discrete
-         continuous_metadata, discrete_metadata = self._split_continuous_discrete(metadata, continuous_factor_bins)
+         Parameters
+         ----------
+         factor_type : Literal["categorical", "continuous", "discrete"]
+             The type of factors to retrieve.
  
-         # Split out the dictionaries into the keys and values
-         self._discrete_factor_names = list(discrete_metadata.keys())
-         self._discrete_data = (
-             np.stack(list(discrete_metadata.values()), axis=-1, dtype=np.int64)
-             if discrete_metadata
-             else np.array([], dtype=np.int64)
-         )
-         self._continuous_factor_names = list(continuous_metadata.keys())
-         self._continuous_data = (
-             np.stack(list(continuous_metadata.values()), axis=-1, dtype=np.float64)
-             if continuous_metadata
-             else np.array([], dtype=np.float64)
-         )
-         self._total_num_factors = len(self._discrete_factor_names + self._continuous_factor_names) + 1
-         self._processed = True
+         Returns
+         -------
+         list[str]
+             List of factor names of the specified type.
+         """
+         self._bin()
+         return [name for name, info in self.factor_info.items() if info.factor_type == factor_type]
  
-     def add_factors(self, factors: Mapping[str, ArrayLike]) -> None:
+     def add_factors(self, factors: Mapping[str, Any]) -> None:
          """
          Add additional factors to the metadata.
  
@@ -374,7 +368,7 @@ class Metadata:
          factors : Mapping[str, ArrayLike]
              Dictionary of factors to add to the metadata.
          """
-         self._merge()
+         self._structure()
  
          targets = len(self.targets.source) if self.targets.source is not None else len(self.targets)
          images = self.image_count
@@ -385,9 +379,14 @@ class Metadata:
              raise ValueError(
                  "The lists/arrays in the provided factors have a different length than the current metadata factors."
              )
-         merged = cast(dict[str, ArrayLike], self._merged[0] if self._merged is not None else {})
+ 
+         new_columns = []
          for k, v in factors.items():
              v = as_numpy(v)
-             merged[k] = v if (self.targets.source is None or lengths[k] == targets) else v[self.targets.source]
+             data = v if (self.targets.source is None or lengths[k] == targets) else v[self.targets.source]
+             new_columns.append(pl.Series(name=k, values=data))
+             self._factors[k] = FactorInfo()
  
-         self._processed = False
+         if new_columns:
+             self._dataframe = self.dataframe.with_columns(new_columns)
+             self._is_binned = False
dataeval/data/_split.py CHANGED
@@ -207,8 +207,8 @@ def get_groups(metadata: Metadata, split_on: Sequence[str] | None) -> NDArray[np
          return None
  
      split_set = set(split_on)
-     indices = [i for i, name in enumerate(metadata.discrete_factor_names) if name in split_set]
-     binned_features = metadata.discrete_data[:, indices]
+     indices = [i for i, name in enumerate(metadata.factor_names) if name in split_set]
+     binned_features = metadata.discretized_data[:, indices]
      return np.unique(binned_features, axis=0, return_inverse=True)[1]
  
  
dataeval/metadata/_distance.py CHANGED
@@ -80,14 +80,17 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> MetadataDista
      MetadataDistanceValues(statistic=1.0, location=0.44354838709677413, dist=2.7, pvalue=0.0)
      """
  
-     _compare_keys(metadata1.continuous_factor_names, metadata2.continuous_factor_names)
-     fnames = metadata1.continuous_factor_names
+     _compare_keys(metadata1.factor_names, metadata2.factor_names)
+     cont_fnames = metadata1.get_factors_by_type("continuous")
  
-     cont1 = np.atleast_2d(metadata1.continuous_data)  # (S, F)
-     cont2 = np.atleast_2d(metadata2.continuous_data)  # (S, F)
+     if not cont_fnames:
+         return MetadataDistanceOutput({})
  
-     _validate_factors_and_data(fnames, cont1)
-     _validate_factors_and_data(fnames, cont2)
+     cont1 = np.atleast_2d(metadata1.dataframe[cont_fnames].to_numpy())  # (S, F)
+     cont2 = np.atleast_2d(metadata2.dataframe[cont_fnames].to_numpy())  # (S, F)
+ 
+     _validate_factors_and_data(cont_fnames, cont1)
+     _validate_factors_and_data(cont_fnames, cont2)
  
      N = len(cont1)
      M = len(cont2)
@@ -104,7 +107,7 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> MetadataDista
      results: dict[str, MetadataDistanceValues] = {}
  
      # Per factor
-     for i, fname in enumerate(fnames):
+     for i, fname in enumerate(cont_fnames):
          fdata1 = cont1[:, i]  # (S, 1)
          fdata2 = cont2[:, i]  # (S, 1)
  
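`metadata_distance` now keys off the per-factor type information instead of a dedicated continuous-data array, and short-circuits to an empty output when no continuous factors exist. A hedged sketch of the call pattern, assuming `metadata_distance` is exported from `dataeval.metadata` and `md_ref`/`md_tst` are `Metadata` objects with identical factor names:

```python
from dataeval.metadata import metadata_distance

# Only factors tagged "continuous" by FactorInfo are compared; a dataset
# with none of them now yields an empty output instead of erroring.
result = metadata_distance(md_ref, md_tst)
for name, values in result.items():  # assumes the output behaves as a mapping
    print(name, values.statistic, values.pvalue)
```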
dataeval/metadata/_ood.py CHANGED
@@ -15,95 +15,6 @@ from dataeval.outputs import MostDeviatedFactorsOutput, OODOutput, OODPredictorO
  from dataeval.outputs._base import set_metadata
  
  
- def _combine_discrete_continuous(metadata: Metadata) -> tuple[list[str], NDArray[np.float64]]:
-     """Combines the discrete and continuous data of a :class:`Metadata` object
- 
-     Returns
-     -------
-     Tuple[list[str], NDArray]
-         The combined list of factors names and the combined discrete and continuous data
- 
-     Note
-     ----
-     Discrete and continuous data must have the same number of samples
-     """
-     names = []
-     data = []
- 
-     if metadata.discrete_factor_names and metadata.discrete_data.size != 0:
-         names.extend(metadata.discrete_factor_names)
-         data.append(metadata.discrete_data)
- 
-     if metadata.continuous_factor_names and metadata.continuous_data.size != 0:
-         names.extend(metadata.continuous_factor_names)
-         data.append(metadata.continuous_data)
- 
-     return names, np.hstack(data, dtype=np.float64) if data else np.array([], dtype=np.float64)
- 
- 
- def _combine_metadata(
-     metadata_1: Metadata, metadata_2: Metadata
- ) -> tuple[list[str], list[NDArray[np.float64 | np.int64]], list[NDArray[np.int64 | np.float64]]]:
-     """
-     Combines the factor names and data arrays of metadata_1 and metadata_2 when the names
-     match exactly and data has the same number of columns (factors).
- 
-     Parameters
-     ----------
-     metadata_1 : Metadata
-         The set of factor names used as reference to determine the correct factor names and length of data
-     metadata_2 : Metadata
-         The compared set of factor names and data that must match metadata_1
- 
-     Returns
-     -------
-     list[str]
-         The combined discrete and continuous factor names in that order.
-     list[NDArray]
-         Combined discrete and continuous data of metadata_1
-     list[NDArray]
-         Combined discrete and continuous data of metadata_2
- 
-     Raises
-     ------
-     ValueError
-         If keys do not match in metadata_1 and metadata_2
-     ValueError
-         If the length of keys do not match the length of the data
-     """
-     factor_names: list[str] = []
-     m1_data: list[NDArray[np.int64 | np.float64]] = []
-     m2_data: list[NDArray[np.int64 | np.float64]] = []
- 
-     # Both metadata must have the same number of factors (cols), but not necessarily samples (row)
-     if metadata_1.total_num_factors != metadata_2.total_num_factors:
-         raise ValueError(
-             f"Number of factors differs between metadata_1 ({metadata_1.total_num_factors}) "
-             f"and metadata_2 ({metadata_2.total_num_factors})"
-         )
- 
-     # Validate and attach discrete data
-     if metadata_1.discrete_factor_names:
-         _compare_keys(metadata_1.discrete_factor_names, metadata_2.discrete_factor_names)
-         _validate_factors_and_data(metadata_1.discrete_factor_names, metadata_1.discrete_data)
- 
-         factor_names.extend(metadata_1.discrete_factor_names)
-         m1_data.append(metadata_1.discrete_data)
-         m2_data.append(metadata_2.discrete_data)
- 
-     # Validate and attach continuous data
-     if metadata_1.continuous_factor_names:
-         _compare_keys(metadata_1.continuous_factor_names, metadata_2.continuous_factor_names)
-         _validate_factors_and_data(metadata_1.continuous_factor_names, metadata_1.continuous_data)
- 
-         factor_names.extend(metadata_1.continuous_factor_names)
-         m1_data.append(metadata_1.continuous_data)
-         m2_data.append(metadata_2.continuous_data)
- 
-     # Turns list of discrete and continuous into one array
-     return factor_names, m1_data, m2_data
- 
- 
  def _calc_median_deviations(reference: NDArray, test: NDArray) -> NDArray:
      """
      Calculates deviations of the test data from the median of the reference data
@@ -207,16 +118,13 @@ def find_most_deviated_factors(
      if not any(ood_mask):
          return MostDeviatedFactorsOutput([])
  
-     # Combines reference and test factor names and data if exists and match exactly
-     # shape -> (samples, factors)
-     factor_names, md_1, md_2 = _combine_metadata(
-         metadata_1=metadata_ref,
-         metadata_2=metadata_tst,
-     )
+     factor_names = metadata_ref.factor_names
+     ref_data = metadata_ref.factor_data
+     tst_data = metadata_tst.factor_data
  
-     # Stack discrete and continuous factors as separate factors. Must have equal sample counts
-     ref_data = np.hstack(md_1) if md_1 else np.array([])  # (S, Fd + Fc)
-     tst_data = np.hstack(md_2) if md_2 else np.array([])  # (S, Fd + Fc)
+     _compare_keys(factor_names, metadata_tst.factor_names)
+     _validate_factors_and_data(factor_names, ref_data)
+     _validate_factors_and_data(factor_names, tst_data)
  
      if len(ref_data) < 3:
          warnings.warn(
@@ -256,6 +164,7 @@ which is what many library functions return, multiply it by _NATS2BITS to get it
  """
  
  
+ @set_metadata
  def find_ood_predictors(
      metadata: Metadata,
      ood: OODOutput,
@@ -305,8 +214,8 @@
  
      ood_mask: NDArray[np.bool_] = ood.is_ood
  
-     discrete_features_count = len(metadata.discrete_factor_names)
-     factors, data = _combine_discrete_continuous(metadata)  # (F, ), (S, F) => F = Fd + Fc
+     factors = metadata.factor_names
+     data = metadata.factor_data
  
      # No metadata correlated with out of distribution data, return 0.0 for all factors
      if not any(ood_mask):
@@ -320,14 +229,13 @@
      # Calculate mean, std of each factor over all samples
      scaled_data = (data - np.mean(data, axis=0)) / np.std(data, axis=0, ddof=1)  # (S, F)
  
-     discrete_features = np.zeros_like(factors, dtype=np.bool_)
-     discrete_features[:discrete_features_count] = True
+     discrete_features = [info.factor_type != "continuous" for info in metadata.factor_info.values()]
  
      mutual_info_values = (
          mutual_info_classif(
              X=scaled_data,
              y=ood_mask,
-             discrete_features=discrete_features,  # type: ignore -> sklearn issue - NDArray[bool] not of accepted type Union[ArrayLike, 'auto']
+             discrete_features=discrete_features,  # type: ignore - sklearn function not typed
              random_state=get_seed(),
          )
          * _NATS2BITS
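The two deleted helpers are not replaced one-for-one; their work collapses into direct property access on `Metadata`. A sketch of the equivalent data gathering with the new API, using only names shown in the diff:

```python
# What _combine_metadata/_combine_discrete_continuous used to assemble:
factor_names = metadata_ref.factor_names  # every factor, one namespace
ref_data = metadata_ref.factor_data       # (S, F) raw factor values
tst_data = metadata_tst.factor_data
# Boolean mask for sklearn's mutual-information estimators; "categorical"
# and "discrete" factors both count as discrete features.
is_discrete = [info.factor_type != "continuous" for info in metadata_ref.factor_info.values()]
```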
dataeval/metrics/bias/_balance.py CHANGED
@@ -68,22 +68,20 @@ def balance(
  
      >>> bal = balance(metadata)
      >>> bal.balance
-     array([1.   , 0.249, 0.03 , 0.134, 0.   , 0.   ])
+     array([1.   , 0.134, 0.   , 0.   ])
  
      Return intra/interfactor balance (mutual information)
  
      >>> bal.factors
-     array([[1.   , 0.314, 0.269, 0.852, 0.367],
-            [0.314, 1.   , 0.097, 0.158, 1.98 ],
-            [0.269, 0.097, 1.   , 0.037, 0.015],
-            [0.852, 0.158, 0.037, 0.475, 0.255],
-            [0.367, 1.98 , 0.015, 0.255, 1.063]])
+     array([[1.   , 0.017, 0.015],
+            [0.017, 0.445, 0.245],
+            [0.015, 0.245, 1.063]])
  
      Return classwise balance (mutual information) of factors with individual class_labels
  
      >>> bal.classwise
-     array([[1.   , 0.249, 0.03 , 0.134, 0.   , 0.   ],
-            [1.   , 0.249, 0.03 , 0.134, 0.   , 0.   ]])
+     array([[1.   , 0.134, 0.   , 0.   ],
+            [1.   , 0.134, 0.   , 0.   ]])
  
  
      See Also
@@ -92,41 +90,39 @@ def balance(
      sklearn.feature_selection.mutual_info_regression
      sklearn.metrics.mutual_info_score
      """
-     if not metadata.discrete_factor_names and not metadata.continuous_factor_names:
+     if not metadata.factor_names:
          raise ValueError("No factors found in provided metadata.")
  
      num_neighbors = _validate_num_neighbors(num_neighbors)
  
-     num_factors = metadata.total_num_factors
-     is_discrete = [True] * (len(metadata.discrete_factor_names) + 1) + [False] * len(metadata.continuous_factor_names)
+     data = metadata.discretized_data
+     factor_types = {"class_label": "categorical"} | {k: v.factor_type for k, v in metadata.factor_info.items()}
+     is_discrete = [factor_type != "continuous" for factor_type in factor_types.values()]
+     num_factors = len(factor_types)
+ 
      mi = np.full((num_factors, num_factors), np.nan, dtype=np.float32)
-     data = np.hstack((metadata.class_labels[:, np.newaxis], metadata.discrete_data))
-     discretized_data = data
-     if len(metadata.continuous_data):
-         data = np.hstack((data, metadata.continuous_data))
-         discrete_idx = [metadata.discrete_factor_names.index(name) for name in metadata.continuous_factor_names]
-         discretized_data = np.hstack((discretized_data, metadata.discrete_data[:, discrete_idx]))
- 
-     for idx in range(num_factors):
-         if idx >= len(metadata.discrete_factor_names) + 1:
-             mi[idx, :] = mutual_info_regression(
+     data = np.hstack((metadata.class_labels[:, np.newaxis], data))
+ 
+     for idx, factor_type in enumerate(factor_types.values()):
+         if factor_type != "continuous":
+             mi[idx, :] = mutual_info_classif(
                  data,
                  data[:, idx],
-                 discrete_features=is_discrete,  # type: ignore
+                 discrete_features=is_discrete,  # type: ignore - sklearn function not typed
                  n_neighbors=num_neighbors,
                  random_state=get_seed(),
              )
          else:
-             mi[idx, :] = mutual_info_classif(
+             mi[idx, :] = mutual_info_regression(
                  data,
                  data[:, idx],
-                 discrete_features=is_discrete,  # type: ignore
+                 discrete_features=is_discrete,  # type: ignore - sklearn function not typed
                  n_neighbors=num_neighbors,
                  random_state=get_seed(),
              )
  
      # Normalization via entropy
-     bin_cnts = get_counts(discretized_data)
+     bin_cnts = get_counts(data)
      ent_factor = sp.stats.entropy(bin_cnts, axis=0)
      norm_factor = 0.5 * np.add.outer(ent_factor, ent_factor) + EPSILON
  
@@ -149,7 +145,7 @@ def balance(
          classwise_mi[idx, :] = mutual_info_classif(
              data,
              tgt_bin[:, idx],
-             discrete_features=is_discrete,  # type: ignore
+             discrete_features=is_discrete,  # type: ignore - sklearn function not typed
              n_neighbors=num_neighbors,
              random_state=get_seed(),
          )
@@ -161,12 +157,6 @@ def balance(
      classwise = classwise_mi / norm_factor
  
      # Grabbing factor names for plotting function
-     factor_names = ["class"]
-     for name in metadata.discrete_factor_names:
-         if name in metadata.continuous_factor_names:
-             name = name + "-discrete"
-         factor_names.append(name)
-     for name in metadata.continuous_factor_names:
-         factor_names.append(name + "-continuous")
+     factor_names = ["class_label"] + metadata.factor_names
  
      return BalanceOutput(balance, factors, classwise, factor_names, metadata.class_names)
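A hedged sketch of calling the updated `balance`, continuing the hypothetical `md` from the `_metadata.py` example; the expected shapes are inferred from the docstring changes above:

```python
from dataeval.metrics.bias import balance

# balance() now derives its discrete/continuous split from FactorInfo
# rather than from separate discrete/continuous name lists.
bal = balance(md)
print(bal.factor_names)     # expected: ['class_label', 'altitude', 'site']
print(bal.balance.shape)    # expected: (3,) - class label vs. itself and each factor
print(bal.factors.shape)    # expected: (2, 2) - interfactor MI, class row/column dropped
print(bal.classwise.shape)  # expected: (2, 3) - one row per class
```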
dataeval/metrics/bias/_diversity.py CHANGED
@@ -138,43 +138,45 @@ def diversity(
  
      >>> div_simp = diversity(metadata, method="simpson")
      >>> div_simp.diversity_index
-     array([0.6  , 0.809, 1.   , 0.8  ])
+     array([0.6  , 0.8  , 0.809, 1.   ])
  
      >>> div_simp.classwise
-     array([[0.5  , 0.8  , 0.8  ],
-            [0.63 , 0.976, 0.528]])
+     array([[0.8  , 0.5  , 0.8  ],
+            [0.528, 0.63 , 0.976]])
  
      Compute Shannon diversity index of metadata and class labels
  
      >>> div_shan = diversity(metadata, method="shannon")
      >>> div_shan.diversity_index
-     array([0.811, 0.943, 1.   , 0.918])
+     array([0.811, 0.918, 0.943, 1.   ])
  
      >>> div_shan.classwise
-     array([[0.683, 0.918, 0.918],
-            [0.814, 0.991, 0.764]])
+     array([[0.918, 0.683, 0.918],
+            [0.764, 0.814, 0.991]])
  
      See Also
      --------
      scipy.stats.entropy
      """
-     if not metadata.discrete_factor_names and not metadata.continuous_factor_names:
+     if not metadata.factor_names:
          raise ValueError("No factors found in provided metadata.")
  
      diversity_fn = get_method(_DIVERSITY_FN_MAP, method)
-     discretized_data = np.hstack((metadata.class_labels[:, np.newaxis], metadata.discrete_data))
-     cnts = get_counts(discretized_data)
+     discretized_data = metadata.discretized_data
+     factor_names = metadata.factor_names
+     class_lbl = metadata.class_labels
+ 
+     class_labels_with_discretized_data = np.hstack((class_lbl[:, np.newaxis], discretized_data))
+     cnts = get_counts(class_labels_with_discretized_data)
      num_bins = np.bincount(np.nonzero(cnts)[1])
      diversity_index = diversity_fn(cnts, num_bins)
  
-     class_lbl = metadata.class_labels
- 
      u_classes = np.unique(class_lbl)
-     num_factors = len(metadata.discrete_factor_names)
+     num_factors = len(factor_names)
      classwise_div = np.full((len(u_classes), num_factors), np.nan)
      for idx, cls in enumerate(u_classes):
          subset_mask = class_lbl == cls
-         cls_cnts = get_counts(metadata.discrete_data[subset_mask], min_num_bins=cnts.shape[0])
+         cls_cnts = get_counts(discretized_data[subset_mask], min_num_bins=cnts.shape[0])
          classwise_div[idx, :] = diversity_fn(cls_cnts, num_bins[1:])
  
-     return DiversityOutput(diversity_index, classwise_div, metadata.discrete_factor_names, metadata.class_names)
+     return DiversityOutput(diversity_index, classwise_div, factor_names, metadata.class_names)
dataeval/metrics/bias/_parity.py CHANGED
@@ -242,13 +242,13 @@ def parity(metadata: Metadata) -> ParityOutput:
      >>> parity(metadata)
      ParityOutput(score=array([7.357, 5.467, 0.515]), p_value=array([0.289, 0.243, 0.773]), factor_names=['age', 'income', 'gender'], insufficient_data={'age': {3: {'artist': 4}, 4: {'artist': 4, 'teacher': 3}}, 'income': {1: {'artist': 3}}})
      """  # noqa: E501
-     if not metadata.discrete_factor_names and not metadata.continuous_factor_names:
+     if not metadata.factor_names:
          raise ValueError("No factors found in provided metadata.")
  
-     chi_scores = np.zeros(metadata.discrete_data.shape[1])
+     chi_scores = np.zeros(metadata.discretized_data.shape[1])
      p_values = np.zeros_like(chi_scores)
      insufficient_data: defaultdict[str, defaultdict[int, dict[str, int]]] = defaultdict(lambda: defaultdict(dict))
-     for i, col_data in enumerate(metadata.discrete_data.T):
+     for i, col_data in enumerate(metadata.discretized_data.T):
          # Builds a contingency matrix where entry at index (r,c) represents
          # the frequency of current_factor_name achieving value unique_factor_values[r]
          # at a data point with class c.
@@ -258,7 +258,7 @@ def parity(metadata: Metadata) -> ParityOutput:
          # Determines if any frequencies are too low
          counts = np.nonzero(contingency_matrix < 5)
          unique_factor_values = np.unique(col_data)
-         current_factor_name = metadata.discrete_factor_names[i]
+         current_factor_name = metadata.factor_names[i]
          for int_factor, int_class in zip(counts[0], counts[1]):
              if contingency_matrix[int_factor, int_class] > 0:
                  factor_category = unique_factor_values[int_factor].item()
@@ -273,11 +273,14 @@ def parity(metadata: Metadata) -> ParityOutput:
          chi_scores[i], p_values[i] = chi2_contingency(contingency_matrix)[:2]
  
      if insufficient_data:
-         warnings.warn("Some factors did not meet the recommended 5 occurrences for each value-label combination.")
+         warnings.warn(
+             f"Factors {list(insufficient_data)} did not meet the recommended "
+             "5 occurrences for each value-label combination."
+         )
  
      return ParityOutput(
          score=chi_scores,
          p_value=p_values,
-         factor_names=metadata.discrete_factor_names,
+         factor_names=metadata.factor_names,
          insufficient_data={k: dict(v) for k, v in insufficient_data.items()},
      )
dataeval/outputs/_bias.py CHANGED
@@ -4,7 +4,7 @@ __all__ = []
  
  import contextlib
  from dataclasses import asdict, dataclass
- from typing import Any, Literal, TypeVar, overload
+ from typing import Any, TypeVar
  
  import numpy as np
  import pandas as pd
@@ -199,53 +199,11 @@ class BalanceOutput(Output):
      factor_names: list[str]
      class_names: list[str]
  
-     @overload
-     def _by_factor_type(
-         self,
-         attr: Literal["factor_names"],
-         factor_type: Literal["discrete", "continuous", "both"],
-     ) -> list[str]: ...
- 
-     @overload
-     def _by_factor_type(
-         self,
-         attr: Literal["balance", "factors", "classwise"],
-         factor_type: Literal["discrete", "continuous", "both"],
-     ) -> NDArray[np.float64]: ...
- 
-     def _by_factor_type(
-         self,
-         attr: Literal["balance", "factors", "classwise", "factor_names"],
-         factor_type: Literal["discrete", "continuous", "both"],
-     ) -> NDArray[np.float64] | list[str]:
-         # if not filtering by factor_type then just return the requested attribute without mask
-         if factor_type == "both":
-             return getattr(self, attr)
- 
-         # create the mask for the selected factor_type
-         mask_lambda = (
-             (lambda x: "-continuous" not in x) if factor_type == "discrete" else (lambda x: "-discrete" not in x)
-         )
- 
-         # return the masked attribute
-         if attr == "factor_names":
-             return [x.replace(f"-{factor_type}", "") for x in self.factor_names if mask_lambda(x)]
-         factor_type_mask = np.asarray([mask_lambda(x) for x in self.factor_names])
-         if attr == "factors":
-             return self.factors[factor_type_mask[1:]][:, factor_type_mask[1:]]
-         if attr == "balance":
-             return self.balance[factor_type_mask]
-         if attr == "classwise":
-             return self.classwise[:, factor_type_mask]
- 
-         raise ValueError(f"Unknown attr {attr} specified.")
- 
      def plot(
          self,
          row_labels: list[Any] | NDArray[Any] | None = None,
          col_labels: list[Any] | NDArray[Any] | None = None,
          plot_classwise: bool = False,
-         factor_type: Literal["discrete", "continuous", "both"] = "discrete",
      ) -> Figure:
          """
          Plot a heatmap of balance information.
@@ -258,8 +216,6 @@ class BalanceOutput(Output):
              List/Array containing the labels for columns in the histogram
          plot_classwise : bool, default False
              Whether to plot per-class balance instead of global balance
-         factor_type : "discrete", "continuous", or "both", default "discrete"
-             Whether to plot discretized values, continuous values, or to include both
  
          Returns
          -------
@@ -273,10 +229,10 @@ class BalanceOutput(Output):
              if row_labels is None:
                  row_labels = self.class_names
              if col_labels is None:
-                 col_labels = self._by_factor_type("factor_names", factor_type)
+                 col_labels = self.factor_names
  
              fig = heatmap(
-                 self._by_factor_type("classwise", factor_type),
+                 self.classwise,
                  row_labels,
                  col_labels,
                  xlabel="Factors",
@@ -287,8 +243,8 @@ class BalanceOutput(Output):
              # Combine balance and factors results
              data = np.concatenate(
                  [
-                     self._by_factor_type("balance", factor_type)[np.newaxis, 1:],
-                     self._by_factor_type("factors", factor_type),
+                     self.balance[np.newaxis, 1:],
+                     self.factors,
                  ],
                  axis=0,
              )
@@ -297,7 +253,7 @@ class BalanceOutput(Output):
              # Finalize the data for the plot, last row is last factor x last factor so it gets dropped
              heat_data = np.where(mask, np.nan, data)[:-1]
              # Creating label array for heat map axes
-             heat_labels = self._by_factor_type("factor_names", factor_type)
+             heat_labels = self.factor_names
  
              if row_labels is None:
                  row_labels = heat_labels[:-1]
@@ -377,7 +333,7 @@ class DiversityOutput(Output):
          import matplotlib.pyplot as plt
  
          fig, ax = plt.subplots(figsize=(8, 8))
-         heat_labels = np.concatenate((["class"], self.factor_names))
+         heat_labels = ["class_labels"] + self.factor_names
          ax.bar(heat_labels, self.diversity_index)
          ax.set_xlabel("Factors")
          plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
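With the `-discrete`/`-continuous` suffixes gone, `BalanceOutput.plot` reads its labels straight from `factor_names`, and the `factor_type` keyword no longer exists. A sketch of the updated call, continuing the hypothetical `bal` from the balance example:

```python
# 0.86.1-style calls such as bal.plot(factor_type="both") must drop the
# removed keyword; a TypeError is the likely failure mode otherwise.
fig = bal.plot(plot_classwise=True)
fig.savefig("balance_classwise.png")
```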
dataeval/utils/data/_dataset.py CHANGED
@@ -19,7 +19,7 @@ def _validate_data(
      images: Array | Sequence[Array],
      labels: Array | Sequence[int] | Sequence[Array] | Sequence[Sequence[int]],
      bboxes: Array | Sequence[Array] | Sequence[Sequence[Array]] | Sequence[Sequence[Sequence[float]]] | None,
-     metadata: Sequence[dict[str, Any]] | None,
+     metadata: Sequence[dict[str, Any]] | dict[str, Sequence[Any]] | None,
  ) -> None:
      # Validate inputs
      dataset_len = len(images)
@@ -30,7 +30,13 @@ def _validate_data(
          raise ValueError(f"Number of labels ({len(labels)}) does not match number of images ({dataset_len}).")
      if bboxes is not None and len(bboxes) != dataset_len:
          raise ValueError(f"Number of bboxes ({len(bboxes)}) does not match number of images ({dataset_len}).")
-     if metadata is not None and len(metadata) != dataset_len:
+     if metadata is not None and (
+         len(metadata) != dataset_len
+         if isinstance(metadata, Sequence)
+         else any(
+             not isinstance(metadatum, Sequence) or len(metadatum) != dataset_len for metadatum in metadata.values()
+         )
+     ):
          raise ValueError(f"Number of metadata ({len(metadata)}) does not match number of images ({dataset_len}).")
  
      if datum_type == "ic":
@@ -56,6 +62,14 @@ def _validate_data(
          raise ValueError(f"Unknown datum type '{datum_type}'. Must be 'ic' or 'od'.")
  
  
+ def _listify_metadata(
+     metadata: Sequence[dict[str, Any]] | dict[str, Sequence[Any]] | None,
+ ) -> Sequence[dict[str, Any]] | None:
+     if isinstance(metadata, dict):
+         return [{k: v[i] for k, v in metadata.items()} for i in range(len(next(iter(metadata.values()))))]
+     return metadata
+ 
+ 
  def _find_max(arr: ArrayLike) -> Any:
      if not isinstance(arr, (bytes, str)) and isinstance(arr, (Iterable, Sequence, Array)):
          if isinstance(arr[0], (Iterable, Sequence, Array)):
@@ -175,7 +189,7 @@ class CustomObjectDetectionDataset(BaseAnnotatedDataset[Sequence[Sequence[int]]]
  def to_image_classification_dataset(
      images: Array | Sequence[Array],
      labels: Array | Sequence[int],
-     metadata: Sequence[dict[str, Any]] | None,
+     metadata: Sequence[dict[str, Any]] | dict[str, Sequence[Any]] | None,
      classes: Sequence[str] | None,
      name: str | None = None,
  ) -> ImageClassificationDataset:
@@ -188,7 +202,7 @@ def to_image_classification_dataset(
          The images to use in the dataset.
      labels : Array | Sequence[int]
          The labels to use in the dataset.
-     metadata : Sequence[dict[str, Any]] | None
+     metadata : Sequence[dict[str, Any]] | dict[str, Sequence[Any]] | None
          The metadata to use in the dataset.
      classes : Sequence[str] | None
          The classes to use in the dataset.
@@ -198,14 +212,14 @@ def to_image_classification_dataset(
      ImageClassificationDataset
      """
      _validate_data("ic", images, labels, None, metadata)
-     return CustomImageClassificationDataset(images, labels, metadata, classes, name)
+     return CustomImageClassificationDataset(images, labels, _listify_metadata(metadata), classes, name)
  
  
  def to_object_detection_dataset(
      images: Array | Sequence[Array],
      labels: Array | Sequence[Array] | Sequence[Sequence[int]],
      bboxes: Array | Sequence[Array] | Sequence[Sequence[Array]] | Sequence[Sequence[Sequence[float]]],
-     metadata: Sequence[dict[str, Any]] | None,
+     metadata: Sequence[dict[str, Any]] | dict[str, Sequence[Any]] | None,
      classes: Sequence[str] | None,
      name: str | None = None,
  ) -> ObjectDetectionDataset:
@@ -220,7 +234,7 @@ def to_object_detection_dataset(
          The labels to use in the dataset.
      bboxes : Array | Sequence[Array] | Sequence[Sequence[Array]] | Sequence[Sequence[Sequence[float]]]
          The bounding boxes (x0,y0,x1,y0) to use in the dataset.
-     metadata : Sequence[dict[str, Any]] | None
+     metadata : Sequence[dict[str, Any]] | dict[str, Sequence[Any]] | None
          The metadata to use in the dataset.
      classes : Sequence[str] | None
          The classes to use in the dataset.
@@ -230,4 +244,4 @@ def to_object_detection_dataset(
      ObjectDetectionDataset
      """
      _validate_data("od", images, labels, bboxes, metadata)
-     return CustomObjectDetectionDataset(images, labels, bboxes, metadata, classes, name)
+     return CustomObjectDetectionDataset(images, labels, bboxes, _listify_metadata(metadata), classes, name)
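The new `_listify_metadata` helper is what lets both dataset constructors accept column-oriented metadata: it transposes the dict-of-sequences form into the per-datum list the dataset classes store. A standalone sketch of that transposition, with the logic copied from the diff:

```python
metadata = {"altitude": [10, 20], "site": ["A", "B"]}  # dict-of-sequences form
listified = [
    {k: v[i] for k, v in metadata.items()}
    for i in range(len(next(iter(metadata.values()))))
]
print(listified)  # [{'altitude': 10, 'site': 'A'}, {'altitude': 20, 'site': 'B'}]
```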
dataeval/utils/datasets/_milco.py CHANGED
@@ -183,9 +183,11 @@ class MILCO(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
          boxes: list[list[float]] = []
          with open(annotation) as f:
              for line in f.readlines():
-                 out = line.strip().split(" ")
+                 out = line.strip().split()
                  labels.append(int(out[0]))
+ 
                  xcenter, ycenter, width, height = [float(out[1]), float(out[2]), float(out[3]), float(out[4])]
+ 
                  x0 = xcenter - width / 2
                  x1 = x0 + width
                  y0 = ycenter - height / 2
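The MILCO parser change is a robustness fix: `str.split()` with no argument collapses any run of whitespace, while `split(" ")` emits empty tokens for doubled spaces and leaves tabs glued to fields. A quick illustration:

```python
line = "1  0.5 0.5\t0.2 0.1\n"  # doubled space and a tab between fields
print(line.strip().split(" "))  # ['1', '', '0.5', '0.5\t0.2', '0.1'] -> float('') would raise
print(line.strip().split())     # ['1', '0.5', '0.5', '0.2', '0.1']
```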
{dataeval-0.86.1.dist-info → dataeval-0.86.2.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: dataeval
- Version: 0.86.1
+ Version: 0.86.2
  Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
  Home-page: https://dataeval.ai/
  License: MIT
@@ -29,6 +29,7 @@ Requires-Dist: numba (>=0.59.1)
  Requires-Dist: numpy (>=1.24.2)
  Requires-Dist: pandas (>=2.0)
  Requires-Dist: pillow (>=10.3.0)
+ Requires-Dist: polars (>=1.0.0)
  Requires-Dist: requests
  Requires-Dist: scikit-learn (>=1.5.0)
  Requires-Dist: scipy (>=1.10)
{dataeval-0.86.1.dist-info → dataeval-0.86.2.dist-info}/RECORD RENAMED
@@ -1,12 +1,12 @@
- dataeval/__init__.py,sha256=oC55_G8B7aR_QRKVy5fQtolW71aKDzMSixWge3cHn3M,1636
+ dataeval/__init__.py,sha256=7Q_nGiQN6g8Le7VtOsemNgn5mC_6gR3NhazolD_arSQ,1636
  dataeval/_log.py,sha256=C7AGkIRzymvYJ0LQXtnShiy3i5Xrp8T58JzIHHguk_Q,365
  dataeval/config.py,sha256=hjad0TK1UmaKQlUuxqxt64_OAUqZkHjicBf06cvTyrQ,4082
  dataeval/data/__init__.py,sha256=qNnRRiVP_sLthkkHpUrMgI_r8dQK-cC-xoGrrjQeRKc,544
  dataeval/data/_embeddings.py,sha256=PFjpdV9bfusCB4taTIYSzx1hP8nJb_KCkZTN8kMw-Hs,12885
  dataeval/data/_images.py,sha256=3d4Cv-xg5z6_LVtw1eL_QdFwzbDI1cwvPNQblkrMEMk,2622
- dataeval/data/_metadata.py,sha256=va5coOR1rRVzFB9SGzmuTj-Oaexs9LobGMA7u2An_eY,15420
+ dataeval/data/_metadata.py,sha256=GzXtecy7EvrB3ZJJbaCQjmpsdHXRL5788ckKbzeI54w,14994
  dataeval/data/_selection.py,sha256=r06xeiyK8nTWPLyItkoPQRWZI1i6LATSue_cuEbCdc4,4463
- dataeval/data/_split.py,sha256=pSyeJVW2sDoTU9wyi0d7UWqDuPhYvDyEgA0BUldS9Vg,16743
+ dataeval/data/_split.py,sha256=nQABR05vxil2Qx7-uX4Fm0_DWpibskBGDJOYj_b1u3I,16737
  dataeval/data/_targets.py,sha256=pXrHBwT4Pi8DauaOxDVnIMwowWWlXuvSb07ShW7O2zk,3119
  dataeval/data/selections/__init__.py,sha256=2m8ZB53wXzqLcqmc6p5atO6graB6ZyiRSNJFxf11X_g,613
  dataeval/data/selections/_classbalance.py,sha256=7v8ApoL3X8eCZ6fGDNTehE_bZ1loaP3TlhsJLaICVWg,1458
@@ -39,16 +39,16 @@ dataeval/detectors/ood/ae.py,sha256=fTrUfFxv6xUqzKpwMC8rW3JrizA16M_bgzqLuBKMrS0,
  dataeval/detectors/ood/base.py,sha256=9b-Ljznf0lB1SXF4F_Aj3eJ4Y3ijGEDPMjucUsWOGJM,3051
  dataeval/detectors/ood/mixin.py,sha256=0_o-1HPvgf3-Lf1MSOIfjj5UB8LTLEBGYtJJfyCCzwc,5431
  dataeval/metadata/__init__.py,sha256=XDDmJbOZBNM6pL0r6Nbu6oMRoyAh22IDkPYGndNlkZU,316
- dataeval/metadata/_distance.py,sha256=T1Umju_QwBiLmn1iUbxZagzBS2VnHaDIdp6j-NpaZuk,4076
- dataeval/metadata/_ood.py,sha256=lnKtKModArnUrAhH_XswEtUAhUkh1U_oNsLt1UmNP44,12748
+ dataeval/metadata/_distance.py,sha256=AABrGoQyD13z9Fqlz3NyfX0Iow_vjBwAugIv6OSRTTE,4187
+ dataeval/metadata/_ood.py,sha256=lNPHouj_9WfM_uTtsaiRaPn46RcVy3YebD1c32vDj-c,8981
  dataeval/metadata/_utils.py,sha256=r8qBJT83RblobD5W5zyTVi6vYi51Dwkqswizdbzss-M,1169
  dataeval/metrics/__init__.py,sha256=8VC8q3HuJN3o_WN51Ae2_wXznl3RMXIvA5GYVcy7vr8,225
  dataeval/metrics/bias/__init__.py,sha256=329S1_3WnWqeU4-qVcbe0fMy4lDrj9uKslWHIQf93yg,839
- dataeval/metrics/bias/_balance.py,sha256=l1hTVkVwD85bP20MTthA-I5BkvbytylQkJu3Q6iTuPA,6152
+ dataeval/metrics/bias/_balance.py,sha256=FcMOA3ge-sQ-0Id2E0K_6hTjNAV3ejJhlB5r4lxlJWI,5519
  dataeval/metrics/bias/_completeness.py,sha256=BysXU2Jpw33n5dl3acJFEqF3mFGiJLsfG4n5Q2fkTaY,4608
  dataeval/metrics/bias/_coverage.py,sha256=PeUoOiaghUEdn6Ov8z2-am7-fnBVIPcFbJK7Ty5JObA,3647
- dataeval/metrics/bias/_diversity.py,sha256=B_qWVDMZfh818U0qVm8yidquB0H0XvW8N75OWVWXy2g,5814
- dataeval/metrics/bias/_parity.py,sha256=PkU3wa77Iyif3McjA510fifTBaph7eJ8iAlI2jQngEM,11374
+ dataeval/metrics/bias/_diversity.py,sha256=25udDKmel9IjeVT5nM4dOa1apda66QdRxBc922yuUvI,5830
+ dataeval/metrics/bias/_parity.py,sha256=OHUSHPOeC8e1I3acALHbQv5bK4V7SqAT7ds9gNVNzSU,11371
  dataeval/metrics/estimators/__init__.py,sha256=Pnds8uIyAovt2fKqZjiHCIP_kVoBWlVllekYuK5UmmU,568
  dataeval/metrics/estimators/_ber.py,sha256=C30E5LiGGTAfo31zWFYDptDg0R7CTJGJ-a60YgzSkYY,5382
  dataeval/metrics/estimators/_clusterer.py,sha256=1HrpihGTJ63IkNSOy4Ibw633Gllkm1RxKmoKT5MOgt0,1434
@@ -65,7 +65,7 @@ dataeval/metrics/stats/_pixelstats.py,sha256=5RCQh0OQkHiCkn3DgCPVxKoFfifX_FOtwsn
  dataeval/metrics/stats/_visualstats.py,sha256=0k6bvAL_d66nQMfG7bydCOFJb7B0dhgG7fqCjVTp1sg,3707
  dataeval/outputs/__init__.py,sha256=geHB5M3QOiFFaQGV4ZwDTTKpqZPvPePbqG7lzaPhaXQ,1741
  dataeval/outputs/_base.py,sha256=7KRWFIEw0UHdhb1em92bPE1YqbMYumAW1QD0QfPwVLc,5900
- dataeval/outputs/_bias.py,sha256=EjJ6jrxDEJYgUj11EyUhdQvdCUSNeefMe5uD3E73GIo,12261
+ dataeval/outputs/_bias.py,sha256=W5QWjtZzMfCaztw6lf0VTZsuSDrNgCcdAvNx6P4fIAo,10254
  dataeval/outputs/_drift.py,sha256=rKn5vqMR6XNujgSqfHsH76oFkoGsUusquZL2Qy4Ae6Y,4581
  dataeval/outputs/_estimators.py,sha256=a2oAIxxEDZ9WLGfMWH8KD-BVUS_SnULRPR-iI9hFPoQ,3047
  dataeval/outputs/_linters.py,sha256=3vI8zsSF-JecQut500A629sICidQLWqhEZcj7o7_cfs,6554
@@ -86,7 +86,7 @@ dataeval/utils/_method.py,sha256=9B9JQbgqWJBRhQJb7glajUtWaQzUTIUuvrZ9_bisxsM,394
  dataeval/utils/_mst.py,sha256=bLmJmu_1Dtj3hC5gQp3oAiJ_7TKtEjahTqusVRRU4eI,2168
  dataeval/utils/_plot.py,sha256=zP0bEvtrLdws7r1Jte8Camq-q5K5F6T8iuv3bStnEJc,7116
  dataeval/utils/data/__init__.py,sha256=xGzrjrOxOP2DP1tU84AWMKPnSxFvSjM81CTlDg4rNM8,331
- dataeval/utils/data/_dataset.py,sha256=5Yt7PzNeeUgm3qy71B_IOW7mKyCfvv8AIqs7Xzv7B9Q,8853
+ dataeval/utils/data/_dataset.py,sha256=CFK9h-XPN7J-iF2nXol6keMDbGm6VIweFAMAjXRUlhg,9527
  dataeval/utils/data/collate.py,sha256=5egEEKhNNCGeNLChO1p6dZ4Wg6x51VEaMNHz7hEZUxI,3936
  dataeval/utils/data/metadata.py,sha256=L1c2bCiMj0aR0QCoKkjwBujIftJDEMgW_3ZbgeS8WHo,14703
  dataeval/utils/datasets/__init__.py,sha256=pAXqHX76yAoBI8XB3m6zGuW-u3s3PCoIXG5GDzxH7Zs,572
@@ -94,7 +94,7 @@ dataeval/utils/datasets/_antiuav.py,sha256=kA_ia1fYNcJiz9SpCvh-Z8iSc7iJrdogjBI3s
  dataeval/utils/datasets/_base.py,sha256=pyfpJda3ku469M3TFRsJn9S2oAiQODOGTlLcdcoEW9U,9031
  dataeval/utils/datasets/_cifar10.py,sha256=hZc_A30yKYBbv2kvVdEkZ9egyEe6XBUnmksoIAoJ-5Y,8265
  dataeval/utils/datasets/_fileio.py,sha256=OASFA9uX3KgfyPb5vza12BlZyAi9Y8Al9lUR_IYPcsM,5449
- dataeval/utils/datasets/_milco.py,sha256=O4w4Z97tdGU-_us09lPrMNpcPLsXXbKkyPYAWzzvPc4,7870
+ dataeval/utils/datasets/_milco.py,sha256=iXf4C1I3Eg_3gHKUe4XPi21yFMBO51zxTIqAkGf9bYg,7869
  dataeval/utils/datasets/_mixin.py,sha256=S8iii-SoYUsFFYNXjw2thlZkpBvRLnZ4XI8wTqOKXgU,1729
  dataeval/utils/datasets/_mnist.py,sha256=uz46sE1Go3TgGjG6x2cXckSVQ0mSg2mhgk8BUvLWjb0,8149
  dataeval/utils/datasets/_ships.py,sha256=6U04HAoM3jgLl1qv-NnxjZeSsBipcqWJBMhBMn5iIUY,5115
@@ -108,7 +108,7 @@ dataeval/utils/torch/models.py,sha256=1idpXyjrYcCBSsbxxRUOto8xr4MJNjDEqQHiIXVU5Z
  dataeval/utils/torch/trainer.py,sha256=Oc2lK13uPGhmLYbmAqlPWyKxgG4YJFlnSXCqFHUZbdA,5528
  dataeval/workflows/__init__.py,sha256=ou8y0KO-d6W5lgmcyLjKlf-J_ckP3vilW7wHkgiDlZ4,255
  dataeval/workflows/sufficiency.py,sha256=j-R8dg4XE6a66p_oTXG2GNzgg3vGk85CTblxhFXaxog,8513
- dataeval-0.86.1.dist-info/LICENSE.txt,sha256=uAooygKWvX6NbU9Ran9oG2msttoG8aeTeHSTe5JeCnY,1061
- dataeval-0.86.1.dist-info/METADATA,sha256=k9tNiWEDBXit4KU6le2vb1CrArZNxssiW5LHXtVXo0A,5321
- dataeval-0.86.1.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
- dataeval-0.86.1.dist-info/RECORD,,
+ dataeval-0.86.2.dist-info/LICENSE.txt,sha256=uAooygKWvX6NbU9Ran9oG2msttoG8aeTeHSTe5JeCnY,1061
+ dataeval-0.86.2.dist-info/METADATA,sha256=6y6bI8GBv_VjBs1mpjAZJ9R5UBTKT7RHQRRUGJdyPCk,5353
+ dataeval-0.86.2.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
+ dataeval-0.86.2.dist-info/RECORD,,