dataeval 0.86.1__py3-none-any.whl → 0.86.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dataeval/__init__.py CHANGED
@@ -8,7 +8,7 @@ shifts that impact performance of deployed models.
 from __future__ import annotations

 __all__ = ["config", "detectors", "log", "metrics", "typing", "utils", "workflows"]
-__version__ = "0.86.1"
+__version__ = "0.86.3"

 import logging

dataeval/data/__init__.py CHANGED
@@ -6,7 +6,6 @@ __all__ = [
     "Metadata",
     "Select",
     "SplitDatasetOutput",
-    "Targets",
     "split_dataset",
 ]

@@ -15,5 +14,4 @@ from dataeval.data._images import Images
 from dataeval.data._metadata import Metadata
 from dataeval.data._selection import Select
 from dataeval.data._split import split_dataset
-from dataeval.data._targets import Targets
 from dataeval.outputs._utils import SplitDatasetOutput
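Note: with Targets removed from the public dataeval.data namespace, per-target data now lives on the Metadata dataframe (see the _metadata.py changes below). A minimal migration sketch, assuming metadata is a Metadata instance built from an annotated dataset; the column names match the target_dict built in _structure():

    # 0.86.1 (removed): per-target arrays lived on a Targets object
    # labels = metadata.targets.labels

    # 0.86.3: the same information is exposed as dataframe columns
    labels = metadata.dataframe["class_label"].to_numpy()
    boxes = metadata.dataframe["box"].to_list()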
dataeval/data/_metadata.py CHANGED
@@ -3,52 +3,32 @@ from __future__ import annotations
 __all__ = []

 import warnings
-from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, Sized, cast
+from dataclasses import dataclass
+from typing import Any, Iterable, Literal, Mapping, Sequence

 import numpy as np
+import polars as pl
 from numpy.typing import NDArray

 from dataeval.typing import (
     AnnotatedDataset,
     Array,
-    ArrayLike,
     ObjectDetectionTarget,
 )
-from dataeval.utils._array import as_numpy, to_numpy
-from dataeval.utils._bin import bin_data, digitize_data, is_continuous
+from dataeval.utils._array import as_numpy
+from dataeval.utils._bin import bin_data, digitize_data
 from dataeval.utils.data.metadata import merge

-if TYPE_CHECKING:
-    from dataeval.data import Targets
-else:
-    from dataeval.data._targets import Targets
+
+@dataclass
+class FactorInfo:
+    factor_type: Literal["categorical", "continuous", "discrete"] | None = None
+    discretized_col: str | None = None


 class Metadata:
     """
-    Class containing binned metadata.
-
-    Attributes
-    ----------
-    discrete_factor_names : list[str]
-        List containing factor names for the original data that was discrete and
-        the binned continuous data
-    discrete_data : NDArray[np.int64]
-        Array containing values for the original data that was discrete and the
-        binned continuous data
-    continuous_factor_names : list[str]
-        List containing factor names for the original continuous data
-    continuous_data : NDArray[np.float64] | None
-        Array containing values for the original continuous data or None if there
-        was no continuous data
-    class_labels : NDArray[np.int]
-        Numerical class labels for the images/objects
-    class_names : list[str]
-        List of unique class names
-    total_num_factors : int
-        Sum of discrete_factor_names and continuous_factor_names plus 1 for class
-    image_indices : NDArray[np.intp]
-        Array of the image index that is mapped by the index of the factor
+    Class containing binned metadata using Polars DataFrames.

     Parameters
     ----------
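The new FactorInfo dataclass replaces the removed Targets import and records, per factor, how it was classified and which dataframe column holds its discretized values. A sketch of the states it takes on, where the factor names are hypothetical and the "[|]" column suffix follows the convention used by _reset_bins() and _bin() below:

    from dataeval.data._metadata import FactorInfo

    FactorInfo()                            # unbinned: type not yet determined
    FactorInfo("categorical", "color[|]")   # ordinal codes stored in a "[|]" column
    FactorInfo("continuous", "angle[|]")    # binned values stored in a "[|]" column
    FactorInfo("discrete", "count")         # integer data used as-is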
@@ -66,20 +46,27 @@ class Metadata:
 
     def __init__(
         self,
-        dataset: AnnotatedDataset[tuple[Any, Any, dict[str, Any]]],
+        dataset: AnnotatedDataset[tuple[Any, Any, Mapping[str, Any]]],
         *,
         continuous_factor_bins: Mapping[str, int | Sequence[float]] | None = None,
         auto_bin_method: Literal["uniform_width", "uniform_count", "clusters"] = "uniform_width",
         exclude: Sequence[str] | None = None,
         include: Sequence[str] | None = None,
     ) -> None:
-        self._collated = False
-        self._merged = None
-        self._processed = False
+        self._class_labels: NDArray[np.intp]
+        self._class_names: list[str]
+        self._image_indices: NDArray[np.intp]
+        self._factors: dict[str, FactorInfo]
+        self._dropped_factors: dict[str, list[str]]
+        self._dataframe: pl.DataFrame
+        self._raw: Sequence[Mapping[str, Any]]
+
+        self._is_structured = False
+        self._is_binned = False

         self._dataset = dataset
         self._continuous_factor_bins = dict(continuous_factor_bins) if continuous_factor_bins else {}
-        self._auto_bin_method = auto_bin_method
+        self._auto_bin_method: Literal["uniform_width", "uniform_count", "clusters"] = auto_bin_method

         if exclude is not None and include is not None:
             raise ValueError("Filters for `exclude` and `include` are mutually exclusive.")
@@ -88,17 +75,14 @@ class Metadata:
         self._include = set(include or ())

     @property
-    def targets(self) -> Targets:
-        self._collate()
-        return self._targets
-
-    @property
-    def raw(self) -> list[dict[str, Any]]:
-        self._collate()
+    def raw(self) -> Sequence[Mapping[str, Any]]:
+        """The raw list of metadata dictionaries for the dataset."""
+        self._structure()
         return self._raw

     @property
     def exclude(self) -> set[str]:
+        """Factors to exclude from the metadata."""
         return self._exclude

     @exclude.setter
@@ -107,10 +91,11 @@ class Metadata:
         if self._exclude != exclude:
             self._exclude = exclude
             self._include = set()
-            self._processed = False
+            self._is_binned = False

     @property
     def include(self) -> set[str]:
+        """Factors to include from the metadata."""
         return self._include

     @include.setter
@@ -119,88 +104,112 @@ class Metadata:
         if self._include != include:
             self._include = include
             self._exclude = set()
-            self._processed = False
+            self._is_binned = False

     @property
     def continuous_factor_bins(self) -> Mapping[str, int | Sequence[float]]:
+        """Map of factor names to bin counts or bin edges."""
         return self._continuous_factor_bins

     @continuous_factor_bins.setter
     def continuous_factor_bins(self, bins: Mapping[str, int | Sequence[float]]) -> None:
         if self._continuous_factor_bins != bins:
             self._continuous_factor_bins = dict(bins)
-            self._processed = False
+            self._reset_bins(bins)

     @property
-    def auto_bin_method(self) -> str:
+    def auto_bin_method(self) -> Literal["uniform_width", "uniform_count", "clusters"]:
+        """Binning method to use when continuous_factor_bins is not defined."""
         return self._auto_bin_method

     @auto_bin_method.setter
     def auto_bin_method(self, method: Literal["uniform_width", "uniform_count", "clusters"]) -> None:
         if self._auto_bin_method != method:
             self._auto_bin_method = method
-            self._processed = False
+            self._reset_bins()

     @property
-    def merged(self) -> dict[str, Any]:
-        self._merge()
-        return {} if self._merged is None else self._merged[0]
+    def dataframe(self) -> pl.DataFrame:
+        """Dataframe containing target information and metadata factors."""
+        self._structure()
+        return self._dataframe

     @property
-    def dropped_factors(self) -> dict[str, list[str]]:
-        self._merge()
-        return {} if self._merged is None else self._merged[1]
+    def dropped_factors(self) -> Mapping[str, Sequence[str]]:
+        """Factors that were dropped during preprocessing and the reasons why they were dropped."""
+        self._structure()
+        return self._dropped_factors

     @property
-    def discrete_factor_names(self) -> list[str]:
-        self._process()
-        return self._discrete_factor_names
+    def discretized_data(self) -> NDArray[np.int64]:
+        """Factor data with continuous data discretized."""
+        if not self.factor_names:
+            return np.array([], dtype=np.int64)
+
+        self._bin()
+        return (
+            self.dataframe.select([info.discretized_col or name for name, info in self.factor_info.items()])
+            .to_numpy()
+            .astype(np.int64)
+        )

     @property
-    def discrete_data(self) -> NDArray[np.int64]:
-        self._process()
-        return self._discrete_data
+    def factor_names(self) -> Sequence[str]:
+        """Factor names of the metadata."""
+        self._structure()
+        return list(self._factors)

     @property
-    def continuous_factor_names(self) -> list[str]:
-        self._process()
-        return self._continuous_factor_names
+    def factor_info(self) -> Mapping[str, FactorInfo]:
+        """Factor types of the metadata."""
+        self._bin()
+        return self._factors

     @property
-    def continuous_data(self) -> NDArray[np.float64]:
-        self._process()
-        return self._continuous_data
+    def factor_data(self) -> NDArray[Any]:
+        """Factor data as a NumPy array."""
+        if not self.factor_names:
+            return np.array([], dtype=np.float64)
+
+        # Extract continuous columns and convert to NumPy array
+        return self.dataframe.select(self.factor_names).to_numpy()

     @property
     def class_labels(self) -> NDArray[np.intp]:
-        self._collate()
+        """Class labels as a NumPy array."""
+        self._structure()
         return self._class_labels

     @property
-    def class_names(self) -> list[str]:
-        self._collate()
+    def class_names(self) -> Sequence[str]:
+        """Class names as a list of strings."""
+        self._structure()
         return self._class_names

-    @property
-    def total_num_factors(self) -> int:
-        self._process()
-        return self._total_num_factors
-
     @property
     def image_indices(self) -> NDArray[np.intp]:
-        self._process()
+        """Indices of images as a NumPy array."""
+        self._bin()
         return self._image_indices

     @property
     def image_count(self) -> int:
-        self._process()
+        self._bin()
         return int(self._image_indices.max() + 1)

-    def _collate(self, force: bool = False) -> None:
-        if self._collated and not force:
+    def _reset_bins(self, cols: Iterable[str] | None = None) -> None:
+        if self._is_binned:
+            columns = self._dataframe.columns
+            for col in (col for col in cols or columns if f"{col}[|]" in columns):
+                self._dataframe.drop_in_place(f"{col}[|]")
+                self._factors[col] = FactorInfo()
+            self._is_binned = False
+
+    def _structure(self) -> None:
+        if self._is_structured:
             return

-        raw: list[dict[str, Any]] = []
+        raw: Sequence[Mapping[str, Any]] = []

         labels = []
         bboxes = []
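The _collate/_merge/_process triple collapses into two lazy stages: _structure() builds the Polars dataframe once, and _bin() adds a discretized "<name>[|]" column per factor on demand, while the setters call _reset_bins() so a settings change drops only the derived columns. A sketch of how the stages are triggered, reusing the metadata instance from above:

    names = metadata.factor_names      # triggers _structure() only
    df = metadata.dataframe            # Polars DataFrame: targets + factor columns
    disc = metadata.discretized_data   # triggers _bin(); adds "<name>[|]" columns

    metadata.auto_bin_method = "clusters"  # _reset_bins() drops the "[|]" columns
    disc = metadata.discretized_data       # recomputed with the new method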
@@ -235,134 +244,106 @@ class Metadata:
         bboxes = as_numpy(bboxes).astype(np.float32) if is_od else None
         srcidx = as_numpy(srcidx).astype(np.intp) if is_od else None

-        self._targets = Targets(labels, scores, bboxes, srcidx)
-        self._raw = raw
-
         index2label = self._dataset.metadata.get("index2label", {})
-        self._class_labels = self._targets.labels
-        self._class_names = [index2label.get(i, str(i)) for i in np.unique(self._class_labels)]
-        self._collated = True

-    def _merge(self, force: bool = False) -> None:
-        if self._merged is not None and not force:
+        targets_per_image = None if srcidx is None else np.unique(srcidx, return_counts=True)[1].tolist()
+        merged = merge(raw, return_dropped=True, ignore_lists=False, targets_per_image=targets_per_image)
+
+        reserved = ["image_index", "class_label", "score", "box"]
+        factor_dict = {f"metadata_{k}" if k in reserved else k: v for k, v in merged[0].items() if k != "_image_index"}
+
+        target_dict = {
+            "image_index": srcidx if srcidx is not None else np.arange(len(labels)),
+            "class_label": labels,
+            "score": scores,
+            "box": bboxes if bboxes is not None else [None] * len(labels),
+        }
+
+        self._raw = raw
+        self._class_labels = labels
+        self._class_names = [index2label.get(i, str(i)) for i in np.unique(labels)]
+        self._image_indices = target_dict["image_index"]
+        self._factors = dict.fromkeys(factor_dict, FactorInfo())
+        self._dataframe = pl.DataFrame({**target_dict, **factor_dict})
+        self._dropped_factors = merged[1]
+        self._is_structured = True
+
+    def _bin(self) -> None:
+        """Populate factor info and bin non-categorical factors."""
+        if self._is_binned:
             return

-        targets_per_image = (
-            None if self.targets.source is None else np.unique(self.targets.source, return_counts=True)[1].tolist()
-        )
-        self._merged = merge(self.raw, return_dropped=True, ignore_lists=False, targets_per_image=targets_per_image)
+        # Start with an empty set of factor info
+        factor_info: dict[str, FactorInfo] = {}

-    def _validate(self) -> None:
-        # Check that metadata is a single, flattened dictionary with uniform array lengths
-        check_length = None
-        if self._targets.labels.ndim > 1:
-            raise ValueError(
-                f"Got class labels with {self._targets.labels.ndim}-dimensional "
-                f"shape {self._targets.labels.shape}, but expected a 1-dimensional array."
-            )
-        for v in self.merged.values():
-            if not isinstance(v, (list, tuple, np.ndarray)):
-                raise TypeError(
-                    "Metadata dictionary needs to be a single dictionary whose values "
-                    "are arraylike containing the metadata on a per image or per object basis."
-                )
-            check_length = len(v) if check_length is None else check_length
-            if check_length != len(v):
-                raise ValueError(
-                    "The lists/arrays in the metadata dict have varying lengths. "
-                    "Metadata requires them to be uniform in length."
-                )
-        if len(self._class_labels) != check_length:
-            raise ValueError(
-                f"The length of the label array {len(self._class_labels)} is not the same as "
-                f"the length of the metadata arrays {check_length}."
-            )
+        # Create a mutable DataFrame for updates
+        df = self.dataframe.clone()
+        factor_bins = self.continuous_factor_bins

-    def _filter(self, d: Mapping[str, Any]) -> dict[str, Any]:
-        return (
-            {k: d[k] for k in self.include if k in d} if self.include else {k: d[k] for k in d if k not in self.exclude}
-        )
+        # Check for invalid keys
+        invalid_keys = set(factor_bins.keys()) - set(df.columns)
+        if invalid_keys:
+            warnings.warn(
+                f"The keys - {invalid_keys} - are present in the `continuous_factor_bins` dictionary "
+                "but are not columns in the metadata DataFrame. Unknown keys will be ignored."
+            )

-    def _split_continuous_discrete(
-        self, metadata: dict[str, NDArray[Any]], continuous_factor_bins: dict[str, int | Sequence[float]]
-    ) -> tuple[dict[str, NDArray[Any]], dict[str, NDArray[np.int64]]]:
-        # Bin according to user supplied bins
-        continuous_metadata = {}
-        discrete_metadata = {}
-        if continuous_factor_bins:
-            invalid_keys = set(continuous_factor_bins.keys()) - set(metadata.keys())
-            if invalid_keys:
-                raise KeyError(
-                    f"The keys - {invalid_keys} - are present in the `continuous_factor_bins` dictionary "
-                    "but are not keys in the `metadata` dictionary. Delete these keys from `continuous_factor_bins` "
-                    "or add corresponding entries to the `metadata` dictionary."
-                )
-            for factor, bins in continuous_factor_bins.items():
-                discrete_metadata[factor] = digitize_data(metadata[factor], bins)
-                continuous_metadata[factor] = metadata[factor]
-
-        # Determine category of the rest of the keys
-        remaining_keys = set(metadata.keys()) - set(continuous_metadata.keys())
-        for key in remaining_keys:
-            data = to_numpy(metadata[key])
-            if np.issubdtype(data.dtype, np.number):
-                result = is_continuous(data, self._image_indices)
-                if result:
-                    continuous_metadata[key] = data
-                unique_samples, ordinal_data = np.unique(data, return_inverse=True)
-                if unique_samples.size <= np.max([20, data.size * 0.01]):
-                    discrete_metadata[key] = ordinal_data
-                else:
+        column_set = set(df.columns)
+        for col in (col for col in self.factor_names if f"{col}[|]" not in column_set):
+            # Get data as numpy array for processing
+            data = df[col].to_numpy()
+            col_dz = f"{col}[|]"
+            if col in factor_bins:
+                # User provided binning
+                bins = factor_bins[col]
+                df = df.with_columns(pl.Series(name=col_dz, values=digitize_data(data, bins).astype(np.int64)))
+                factor_info[col] = FactorInfo("continuous", col_dz)
+            else:
+                # Check if data is numeric
+                unique, ordinal = np.unique(data, return_inverse=True)
+                if not np.issubdtype(data.dtype, np.number) or unique.size <= max(20, data.size * 0.01):
+                    # Non-numeric data or small number of unique values - convert to categorical
+                    df = df.with_columns(pl.Series(name=col_dz, values=ordinal.astype(np.int64)))
+                    factor_info[col] = FactorInfo("categorical", col_dz)
+                elif data.dtype == float:
+                    # Many unique values - discretize by binning
                     warnings.warn(
-                        f"A user defined binning was not provided for {key}. "
+                        f"A user defined binning was not provided for {col}. "
                         f"Using the {self.auto_bin_method} method to discretize the data. "
                         "It is recommended that the user rerun and supply the desired "
                         "bins using the continuous_factor_bins parameter.",
                         UserWarning,
                     )
-                    discrete_metadata[key] = bin_data(data, self.auto_bin_method)
-            else:
-                _, discrete_metadata[key] = np.unique(data, return_inverse=True)
-
-        return continuous_metadata, discrete_metadata
-
-    def _process(self, force: bool = False) -> None:
-        if self._processed and not force:
-            return
-
-        # Create image indices from targets
-        self._image_indices = np.arange(len(self.raw)) if self.targets.source is None else self.targets.source
-
-        # Validate the metadata dimensions
-        self._validate()
+                    # Create binned version
+                    binned_data = bin_data(data, self.auto_bin_method)
+                    df = df.with_columns(pl.Series(name=col_dz, values=binned_data.astype(np.int64)))
+                    factor_info[col] = FactorInfo("continuous", col_dz)
+                else:
+                    factor_info[col] = FactorInfo("discrete", col)

-        # Filter the merged metadata and continuous factor bins
-        metadata = self._filter(self.merged)
-        continuous_factor_bins = self._filter(self.continuous_factor_bins)
+        # Store the results
+        self._dataframe = df
+        self._factors.update(factor_info)
+        self._is_binned = True

-        # Remove generated "_image_index" if present
-        metadata.pop("_image_index", None)
+    def get_factors_by_type(self, factor_type: Literal["categorical", "continuous", "discrete"]) -> Sequence[str]:
+        """
+        Get the names of factors of a specific type.

-        # Split the metadata into continuous and discrete
-        continuous_metadata, discrete_metadata = self._split_continuous_discrete(metadata, continuous_factor_bins)
+        Parameters
+        ----------
+        factor_type : Literal["categorical", "continuous", "discrete"]
+            The type of factors to retrieve.

-        # Split out the dictionaries into the keys and values
-        self._discrete_factor_names = list(discrete_metadata.keys())
-        self._discrete_data = (
-            np.stack(list(discrete_metadata.values()), axis=-1, dtype=np.int64)
-            if discrete_metadata
-            else np.array([], dtype=np.int64)
-        )
-        self._continuous_factor_names = list(continuous_metadata.keys())
-        self._continuous_data = (
-            np.stack(list(continuous_metadata.values()), axis=-1, dtype=np.float64)
-            if continuous_metadata
-            else np.array([], dtype=np.float64)
-        )
-        self._total_num_factors = len(self._discrete_factor_names + self._continuous_factor_names) + 1
-        self._processed = True
+        Returns
+        -------
+        list[str]
+            List of factor names of the specified type.
+        """
+        self._bin()
+        return [name for name, info in self.factor_info.items() if info.factor_type == factor_type]

-    def add_factors(self, factors: Mapping[str, ArrayLike]) -> None:
+    def add_factors(self, factors: Mapping[str, Array | Sequence[Any]]) -> None:
         """
         Add additional factors to the metadata.

@@ -371,23 +352,26 @@ class Metadata:
 
         Parameters
         ----------
-        factors : Mapping[str, ArrayLike]
+        factors : Mapping[str, Array | Sequence[Any]]
             Dictionary of factors to add to the metadata.
         """
-        self._merge()
+        self._structure()

-        targets = len(self.targets.source) if self.targets.source is not None else len(self.targets)
+        targets = len(self.dataframe)
         images = self.image_count
-        lengths = {k: len(v if isinstance(v, Sized) else np.atleast_1d(as_numpy(v))) for k, v in factors.items()}
-        targets_match = all(f == targets for f in lengths.values())
-        images_match = targets_match if images == targets else all(f == images for f in lengths.values())
+        targets_match = all(len(v) == targets for v in factors.values())
+        images_match = targets_match if images == targets else all(len(v) == images for v in factors.values())
         if not targets_match and not images_match:
             raise ValueError(
                 "The lists/arrays in the provided factors have a different length than the current metadata factors."
             )
-        merged = cast(dict[str, ArrayLike], self._merged[0] if self._merged is not None else {})
+
+        new_columns = []
         for k, v in factors.items():
-            v = as_numpy(v)
-            merged[k] = v if (self.targets.source is None or lengths[k] == targets) else v[self.targets.source]
+            data = as_numpy(v)[self.image_indices]
+            new_columns.append(pl.Series(name=k, values=data))
+            self._factors[k] = FactorInfo()

-        self._processed = False
+        if new_columns:
+            self._dataframe = self.dataframe.with_columns(new_columns)
+            self._is_binned = False
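Together, get_factors_by_type and the reworked add_factors round out the new API: added factors enter the dataframe as plain columns with an empty FactorInfo and are typed on the next binning pass. A sketch with a hypothetical per-image factor, reusing the metadata instance from above (per-image values are expanded to per-target rows through image_indices):

    import numpy as np

    weather = np.array([0, 1, 0, 2])             # hypothetical: one value per image
    metadata.add_factors({"weather": weather})   # indexed via metadata.image_indices

    metadata.get_factors_by_type("categorical")  # re-runs _bin(), then filters by type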
dataeval/data/_split.py CHANGED
@@ -207,8 +207,8 @@ def get_groups(metadata: Metadata, split_on: Sequence[str] | None) -> NDArray[np
         return None

     split_set = set(split_on)
-    indices = [i for i, name in enumerate(metadata.discrete_factor_names) if name in split_set]
-    binned_features = metadata.discrete_data[:, indices]
+    indices = [i for i, name in enumerate(metadata.factor_names) if name in split_set]
+    binned_features = metadata.discretized_data[:, indices]
     return np.unique(binned_features, axis=0, return_inverse=True)[1]


@@ -2,7 +2,7 @@ from __future__ import annotations
 
 __all__ = []

-from typing import Any, Generic, Iterable, Sequence, Sized, TypeVar, cast
+from typing import Any, Generic, Iterable, Mapping, Sequence, Sized, TypeVar, cast

 import numpy as np
 from numpy.typing import NDArray
@@ -92,7 +92,7 @@ class ClassFilterSubSelection(Subselection[Any]):
     def __init__(self, classes: Sequence[int]) -> None:
         self.classes = classes

-    def _filter(self, d: dict[str, Any], mask: NDArray[np.bool_]) -> dict[str, Any]:
+    def _filter(self, d: Mapping[str, Any], mask: NDArray[np.bool_]) -> dict[str, Any]:
         return {k: self._filter(v, mask) if isinstance(v, dict) else _try_mask_object(v, mask) for k, v in d.items()}

     def __call__(self, datum: _TDatum) -> _TDatum:
@@ -80,14 +80,17 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> MetadataDistanceOutput:
     MetadataDistanceValues(statistic=1.0, location=0.44354838709677413, dist=2.7, pvalue=0.0)
     """

-    _compare_keys(metadata1.continuous_factor_names, metadata2.continuous_factor_names)
-    fnames = metadata1.continuous_factor_names
+    _compare_keys(metadata1.factor_names, metadata2.factor_names)
+    cont_fnames = metadata1.get_factors_by_type("continuous")

-    cont1 = np.atleast_2d(metadata1.continuous_data)  # (S, F)
-    cont2 = np.atleast_2d(metadata2.continuous_data)  # (S, F)
+    if not cont_fnames:
+        return MetadataDistanceOutput({})

-    _validate_factors_and_data(fnames, cont1)
-    _validate_factors_and_data(fnames, cont2)
+    cont1 = np.atleast_2d(metadata1.dataframe[cont_fnames].to_numpy())  # (S, F)
+    cont2 = np.atleast_2d(metadata2.dataframe[cont_fnames].to_numpy())  # (S, F)
+
+    _validate_factors_and_data(cont_fnames, cont1)
+    _validate_factors_and_data(cont_fnames, cont2)

     N = len(cont1)
     M = len(cont2)
@@ -104,7 +107,7 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> MetadataDistanceOutput:
     results: dict[str, MetadataDistanceValues] = {}

     # Per factor
-    for i, fname in enumerate(fnames):
+    for i, fname in enumerate(cont_fnames):
         fdata1 = cont1[:, i]  # (S, 1)
         fdata2 = cont2[:, i]  # (S, 1)
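With factor typing now owned by Metadata, metadata_distance compares only the factors classified as continuous and short-circuits to an empty result when there are none. A usage sketch, assuming two Metadata instances over comparable datasets and that MetadataDistanceOutput is mapping-like, as its dict construction above suggests:

    # metadata1, metadata2: Metadata instances with matching factor names
    output = metadata_distance(metadata1, metadata2)
    for fname, values in output.items():
        print(fname, values.statistic, values.pvalue)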