dataeval 0.86.1__py3-none-any.whl → 0.86.3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only.
dataeval/metadata/_ood.py CHANGED
@@ -15,95 +15,6 @@ from dataeval.outputs import MostDeviatedFactorsOutput, OODOutput, OODPredictorO
 from dataeval.outputs._base import set_metadata


-def _combine_discrete_continuous(metadata: Metadata) -> tuple[list[str], NDArray[np.float64]]:
-    """Combines the discrete and continuous data of a :class:`Metadata` object
-
-    Returns
-    -------
-    Tuple[list[str], NDArray]
-        The combined list of factors names and the combined discrete and continuous data
-
-    Note
-    ----
-    Discrete and continuous data must have the same number of samples
-    """
-    names = []
-    data = []
-
-    if metadata.discrete_factor_names and metadata.discrete_data.size != 0:
-        names.extend(metadata.discrete_factor_names)
-        data.append(metadata.discrete_data)
-
-    if metadata.continuous_factor_names and metadata.continuous_data.size != 0:
-        names.extend(metadata.continuous_factor_names)
-        data.append(metadata.continuous_data)
-
-    return names, np.hstack(data, dtype=np.float64) if data else np.array([], dtype=np.float64)
-
-
-def _combine_metadata(
-    metadata_1: Metadata, metadata_2: Metadata
-) -> tuple[list[str], list[NDArray[np.float64 | np.int64]], list[NDArray[np.int64 | np.float64]]]:
-    """
-    Combines the factor names and data arrays of metadata_1 and metadata_2 when the names
-    match exactly and data has the same number of columns (factors).
-
-    Parameters
-    ----------
-    metadata_1 : Metadata
-        The set of factor names used as reference to determine the correct factor names and length of data
-    metadata_2 : Metadata
-        The compared set of factor names and data that must match metadata_1
-
-    Returns
-    -------
-    list[str]
-        The combined discrete and continuous factor names in that order.
-    list[NDArray]
-        Combined discrete and continuous data of metadata_1
-    list[NDArray]
-        Combined discrete and continuous data of metadata_2
-
-    Raises
-    ------
-    ValueError
-        If keys do not match in metadata_1 and metadata_2
-    ValueError
-        If the length of keys do not match the length of the data
-    """
-    factor_names: list[str] = []
-    m1_data: list[NDArray[np.int64 | np.float64]] = []
-    m2_data: list[NDArray[np.int64 | np.float64]] = []
-
-    # Both metadata must have the same number of factors (cols), but not necessarily samples (row)
-    if metadata_1.total_num_factors != metadata_2.total_num_factors:
-        raise ValueError(
-            f"Number of factors differs between metadata_1 ({metadata_1.total_num_factors}) "
-            f"and metadata_2 ({metadata_2.total_num_factors})"
-        )
-
-    # Validate and attach discrete data
-    if metadata_1.discrete_factor_names:
-        _compare_keys(metadata_1.discrete_factor_names, metadata_2.discrete_factor_names)
-        _validate_factors_and_data(metadata_1.discrete_factor_names, metadata_1.discrete_data)
-
-        factor_names.extend(metadata_1.discrete_factor_names)
-        m1_data.append(metadata_1.discrete_data)
-        m2_data.append(metadata_2.discrete_data)
-
-    # Validate and attach continuous data
-    if metadata_1.continuous_factor_names:
-        _compare_keys(metadata_1.continuous_factor_names, metadata_2.continuous_factor_names)
-        _validate_factors_and_data(metadata_1.continuous_factor_names, metadata_1.continuous_data)
-
-        factor_names.extend(metadata_1.continuous_factor_names)
-        m1_data.append(metadata_1.continuous_data)
-        m2_data.append(metadata_2.continuous_data)
-
-    # Turns list of discrete and continuous into one array
-    return factor_names, m1_data, m2_data
-
-
 def _calc_median_deviations(reference: NDArray, test: NDArray) -> NDArray:
     """
     Calculates deviations of the test data from the median of the reference data
@@ -207,16 +118,13 @@ def find_most_deviated_factors(
     if not any(ood_mask):
         return MostDeviatedFactorsOutput([])

-    # Combines reference and test factor names and data if exists and match exactly
-    # shape -> (samples, factors)
-    factor_names, md_1, md_2 = _combine_metadata(
-        metadata_1=metadata_ref,
-        metadata_2=metadata_tst,
-    )
+    factor_names = metadata_ref.factor_names
+    ref_data = metadata_ref.factor_data
+    tst_data = metadata_tst.factor_data

-    # Stack discrete and continuous factors as separate factors. Must have equal sample counts
-    ref_data = np.hstack(md_1) if md_1 else np.array([])  # (S, Fd + Fc)
-    tst_data = np.hstack(md_2) if md_2 else np.array([])  # (S, Fd + Fc)
+    _compare_keys(factor_names, metadata_tst.factor_names)
+    _validate_factors_and_data(factor_names, ref_data)
+    _validate_factors_and_data(factor_names, tst_data)

     if len(ref_data) < 3:
         warnings.warn(
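The removed _combine_discrete_continuous and _combine_metadata helpers are superseded by the consolidated Metadata.factor_names and Metadata.factor_data accessors used above. A minimal numpy sketch of the behavior those helpers provided (the factor names and arrays below are made up for illustration): discrete and continuous columns are stacked side by side into one (samples, factors) array.

    import numpy as np

    # Toy stand-ins for the former metadata.discrete_data and metadata.continuous_data
    discrete = np.array([[0, 1], [1, 0], [2, 1]], dtype=np.float64)  # (3 samples, 2 discrete factors)
    continuous = np.array([[0.5], [0.7], [0.2]], dtype=np.float64)   # (3 samples, 1 continuous factor)

    names = ["weather", "vehicle", "speed"]
    data = np.hstack((discrete, continuous))  # single (samples, factors) array, like factor_data
    assert data.shape == (len(discrete), len(names))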
@@ -256,6 +164,7 @@ which is what many library functions return, multiply it by _NATS2BITS to get it
 """


+@set_metadata
 def find_ood_predictors(
     metadata: Metadata,
     ood: OODOutput,
@@ -305,8 +214,8 @@ def find_ood_predictors(

     ood_mask: NDArray[np.bool_] = ood.is_ood

-    discrete_features_count = len(metadata.discrete_factor_names)
-    factors, data = _combine_discrete_continuous(metadata)  # (F, ), (S, F) => F = Fd + Fc
+    factors = metadata.factor_names
+    data = metadata.factor_data

     # No metadata correlated with out of distribution data, return 0.0 for all factors
     if not any(ood_mask):
@@ -320,14 +229,13 @@ def find_ood_predictors(
     # Calculate mean, std of each factor over all samples
     scaled_data = (data - np.mean(data, axis=0)) / np.std(data, axis=0, ddof=1)  # (S, F)

-    discrete_features = np.zeros_like(factors, dtype=np.bool_)
-    discrete_features[:discrete_features_count] = True
+    discrete_features = [info.factor_type != "continuous" for info in metadata.factor_info.values()]

     mutual_info_values = (
         mutual_info_classif(
             X=scaled_data,
             y=ood_mask,
-            discrete_features=discrete_features,  # type: ignore -> sklearn issue - NDArray[bool] not of accepted type Union[ArrayLike, 'auto']
+            discrete_features=discrete_features,  # type: ignore - sklearn function not typed
             random_state=get_seed(),
         )
         * _NATS2BITS
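For context, the discrete_features argument of sklearn's mutual_info_classif accepts a per-column boolean mask, which is exactly what the factor_type check above now builds. A standalone sketch with made-up data (X, y, and mask are illustrative, not taken from the package):

    import numpy as np
    from sklearn.feature_selection import mutual_info_classif

    rng = np.random.default_rng(0)
    X = np.column_stack([rng.integers(0, 3, 100), rng.normal(size=100)])  # one discrete, one continuous column
    y = (X[:, 0] > 0).astype(int)                                         # binary "is OOD" style target

    mask = [True, False]  # per-column flags, analogous to factor_type != "continuous"
    mi = mutual_info_classif(X, y, discrete_features=mask, random_state=0)
    print(mi)  # mutual information in nats; multiply by a nats-to-bits factor to report bits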
@@ -1,9 +1,11 @@
 __all__ = []

+from typing import Sequence
+
 from numpy.typing import NDArray


-def _compare_keys(keys1: list[str], keys2: list[str]) -> None:
+def _compare_keys(keys1: Sequence[str], keys2: Sequence[str]) -> None:
     """
     Raises error when two lists are not equivalent including ordering

@@ -24,7 +26,7 @@ def _compare_keys(keys1: list[str], keys2: list[str]) -> None:
         raise ValueError(f"Metadata keys must be identical, got {keys1} and {keys2}")


-def _validate_factors_and_data(factors: list[str], data: NDArray) -> None:
+def _validate_factors_and_data(factors: Sequence[str], data: NDArray) -> None:
     """
     Raises error when the number of factors and number of rows do not match

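Widening the annotations from list[str] to Sequence[str] lets callers pass factor names as either a list or a tuple. A hypothetical sketch of the idea (check_same_keys is a stand-in and not the package's implementation):

    from typing import Sequence

    def check_same_keys(keys1: Sequence[str], keys2: Sequence[str]) -> None:
        # Normalize so a tuple and a list with equal contents compare as equal
        if list(keys1) != list(keys2):
            raise ValueError(f"Metadata keys must be identical, got {keys1} and {keys2}")

    check_same_keys(("age", "income"), ["age", "income"])  # accepted under the wider annotation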
@@ -68,22 +68,20 @@ def balance(

     >>> bal = balance(metadata)
     >>> bal.balance
-    array([1. , 0.249, 0.03 , 0.134, 0. , 0. ])
+    array([1. , 0.134, 0. , 0. ])

     Return intra/interfactor balance (mutual information)

     >>> bal.factors
-    array([[1. , 0.314, 0.269, 0.852, 0.367],
-           [0.314, 1. , 0.097, 0.158, 1.98 ],
-           [0.269, 0.097, 1. , 0.037, 0.015],
-           [0.852, 0.158, 0.037, 0.475, 0.255],
-           [0.367, 1.98 , 0.015, 0.255, 1.063]])
+    array([[1. , 0.017, 0.015],
+           [0.017, 0.445, 0.245],
+           [0.015, 0.245, 1.063]])

     Return classwise balance (mutual information) of factors with individual class_labels

     >>> bal.classwise
-    array([[1. , 0.249, 0.03 , 0.134, 0. , 0. ],
-           [1. , 0.249, 0.03 , 0.134, 0. , 0. ]])
+    array([[1. , 0.134, 0. , 0. ],
+           [1. , 0.134, 0. , 0. ]])


     See Also
@@ -92,41 +90,39 @@ def balance(
     sklearn.feature_selection.mutual_info_regression
     sklearn.metrics.mutual_info_score
     """
-    if not metadata.discrete_factor_names and not metadata.continuous_factor_names:
+    if not metadata.factor_names:
         raise ValueError("No factors found in provided metadata.")

     num_neighbors = _validate_num_neighbors(num_neighbors)

-    num_factors = metadata.total_num_factors
-    is_discrete = [True] * (len(metadata.discrete_factor_names) + 1) + [False] * len(metadata.continuous_factor_names)
+    data = metadata.discretized_data
+    factor_types = {"class_label": "categorical"} | {k: v.factor_type for k, v in metadata.factor_info.items()}
+    is_discrete = [factor_type != "continuous" for factor_type in factor_types.values()]
+    num_factors = len(factor_types)
+
     mi = np.full((num_factors, num_factors), np.nan, dtype=np.float32)
-    data = np.hstack((metadata.class_labels[:, np.newaxis], metadata.discrete_data))
-    discretized_data = data
-    if len(metadata.continuous_data):
-        data = np.hstack((data, metadata.continuous_data))
-        discrete_idx = [metadata.discrete_factor_names.index(name) for name in metadata.continuous_factor_names]
-        discretized_data = np.hstack((discretized_data, metadata.discrete_data[:, discrete_idx]))
-
-    for idx in range(num_factors):
-        if idx >= len(metadata.discrete_factor_names) + 1:
-            mi[idx, :] = mutual_info_regression(
+    data = np.hstack((metadata.class_labels[:, np.newaxis], data))
+
+    for idx, factor_type in enumerate(factor_types.values()):
+        if factor_type != "continuous":
+            mi[idx, :] = mutual_info_classif(
                 data,
                 data[:, idx],
-                discrete_features=is_discrete,  # type: ignore
+                discrete_features=is_discrete,  # type: ignore - sklearn function not typed
                 n_neighbors=num_neighbors,
                 random_state=get_seed(),
             )
         else:
-            mi[idx, :] = mutual_info_classif(
+            mi[idx, :] = mutual_info_regression(
                 data,
                 data[:, idx],
-                discrete_features=is_discrete,  # type: ignore
+                discrete_features=is_discrete,  # type: ignore - sklearn function not typed
                 n_neighbors=num_neighbors,
                 random_state=get_seed(),
             )

     # Normalization via entropy
-    bin_cnts = get_counts(discretized_data)
+    bin_cnts = get_counts(data)
     ent_factor = sp.stats.entropy(bin_cnts, axis=0)
     norm_factor = 0.5 * np.add.outer(ent_factor, ent_factor) + EPSILON

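The normalization step above divides each mutual-information entry by the average entropy of the two factors involved. A small self-contained sketch with made-up counts and MI values (not package data):

    import numpy as np
    from scipy.stats import entropy

    bin_cnts = np.array([[10, 5, 7], [3, 9, 2], [6, 6, 6]])  # bin counts for 3 factors (columns)
    ent_factor = entropy(bin_cnts, axis=0)                    # per-factor entropy H_i
    norm_factor = 0.5 * np.add.outer(ent_factor, ent_factor) + 1e-10

    mi = np.array([[1.0, 0.2, 0.1], [0.2, 1.0, 0.3], [0.1, 0.3, 1.0]])  # assumed MI matrix
    balance_matrix = mi / norm_factor  # divide MI by average entropy so entries share a comparable scale
    print(balance_matrix.round(3))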
@@ -149,7 +145,7 @@ def balance(
         classwise_mi[idx, :] = mutual_info_classif(
             data,
             tgt_bin[:, idx],
-            discrete_features=is_discrete,  # type: ignore
+            discrete_features=is_discrete,  # type: ignore - sklearn function not typed
             n_neighbors=num_neighbors,
             random_state=get_seed(),
         )
@@ -161,12 +157,6 @@ def balance(
     classwise = classwise_mi / norm_factor

     # Grabbing factor names for plotting function
-    factor_names = ["class"]
-    for name in metadata.discrete_factor_names:
-        if name in metadata.continuous_factor_names:
-            name = name + "-discrete"
-        factor_names.append(name)
-    for name in metadata.continuous_factor_names:
-        factor_names.append(name + "-continuous")
+    factor_names = ["class_label"] + list(metadata.factor_names)

     return BalanceOutput(balance, factors, classwise, factor_names, metadata.class_names)
@@ -138,43 +138,45 @@ def diversity(

     >>> div_simp = diversity(metadata, method="simpson")
     >>> div_simp.diversity_index
-    array([0.6 , 0.809, 1. , 0.8 ])
+    array([0.6 , 0.8 , 0.809, 1. ])

     >>> div_simp.classwise
-    array([[0.5 , 0.8 , 0.8 ],
-           [0.63 , 0.976, 0.528]])
+    array([[0.8 , 0.5 , 0.8 ],
+           [0.528, 0.63 , 0.976]])

     Compute Shannon diversity index of metadata and class labels

     >>> div_shan = diversity(metadata, method="shannon")
     >>> div_shan.diversity_index
-    array([0.811, 0.943, 1. , 0.918])
+    array([0.811, 0.918, 0.943, 1. ])

     >>> div_shan.classwise
-    array([[0.683, 0.918, 0.918],
-           [0.814, 0.991, 0.764]])
+    array([[0.918, 0.683, 0.918],
+           [0.764, 0.814, 0.991]])

     See Also
     --------
     scipy.stats.entropy
     """
-    if not metadata.discrete_factor_names and not metadata.continuous_factor_names:
+    if not metadata.factor_names:
         raise ValueError("No factors found in provided metadata.")

     diversity_fn = get_method(_DIVERSITY_FN_MAP, method)
-    discretized_data = np.hstack((metadata.class_labels[:, np.newaxis], metadata.discrete_data))
-    cnts = get_counts(discretized_data)
+    discretized_data = metadata.discretized_data
+    factor_names = metadata.factor_names
+    class_lbl = metadata.class_labels
+
+    class_labels_with_discretized_data = np.hstack((class_lbl[:, np.newaxis], discretized_data))
+    cnts = get_counts(class_labels_with_discretized_data)
     num_bins = np.bincount(np.nonzero(cnts)[1])
     diversity_index = diversity_fn(cnts, num_bins)

-    class_lbl = metadata.class_labels
-
     u_classes = np.unique(class_lbl)
-    num_factors = len(metadata.discrete_factor_names)
+    num_factors = len(factor_names)
     classwise_div = np.full((len(u_classes), num_factors), np.nan)
     for idx, cls in enumerate(u_classes):
         subset_mask = class_lbl == cls
-        cls_cnts = get_counts(metadata.discrete_data[subset_mask], min_num_bins=cnts.shape[0])
+        cls_cnts = get_counts(discretized_data[subset_mask], min_num_bins=cnts.shape[0])
         classwise_div[idx, :] = diversity_fn(cls_cnts, num_bins[1:])

-    return DiversityOutput(diversity_index, classwise_div, metadata.discrete_factor_names, metadata.class_names)
+    return DiversityOutput(diversity_index, classwise_div, factor_names, metadata.class_names)
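Both diversity methods are computed from per-factor value counts, which is why the refactor feeds metadata.discretized_data into get_counts. A generic sketch of the two index formulas on made-up counts (illustrative only, not the package's exact implementation):

    import numpy as np

    counts = np.array([5, 3, 2], dtype=float)  # occurrences of each bin for a single factor
    p = counts / counts.sum()
    num_bins = len(counts)

    shannon = -np.sum(p * np.log(p)) / np.log(num_bins)  # normalized Shannon entropy
    simpson = (1 / np.sum(p**2) - 1) / (num_bins - 1)    # normalized inverse Simpson index
    print(round(shannon, 3), round(simpson, 3))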
@@ -242,13 +242,13 @@ def parity(metadata: Metadata) -> ParityOutput:
     >>> parity(metadata)
     ParityOutput(score=array([7.357, 5.467, 0.515]), p_value=array([0.289, 0.243, 0.773]), factor_names=['age', 'income', 'gender'], insufficient_data={'age': {3: {'artist': 4}, 4: {'artist': 4, 'teacher': 3}}, 'income': {1: {'artist': 3}}})
     """  # noqa: E501
-    if not metadata.discrete_factor_names and not metadata.continuous_factor_names:
+    if not metadata.factor_names:
         raise ValueError("No factors found in provided metadata.")

-    chi_scores = np.zeros(metadata.discrete_data.shape[1])
+    chi_scores = np.zeros(metadata.discretized_data.shape[1])
     p_values = np.zeros_like(chi_scores)
     insufficient_data: defaultdict[str, defaultdict[int, dict[str, int]]] = defaultdict(lambda: defaultdict(dict))
-    for i, col_data in enumerate(metadata.discrete_data.T):
+    for i, col_data in enumerate(metadata.discretized_data.T):
         # Builds a contingency matrix where entry at index (r,c) represents
         # the frequency of current_factor_name achieving value unique_factor_values[r]
         # at a data point with class c.
@@ -258,8 +258,9 @@ def parity(metadata: Metadata) -> ParityOutput:
         # Determines if any frequencies are too low
         counts = np.nonzero(contingency_matrix < 5)
         unique_factor_values = np.unique(col_data)
-        current_factor_name = metadata.discrete_factor_names[i]
-        for int_factor, int_class in zip(counts[0], counts[1]):
+        current_factor_name = metadata.factor_names[i]
+        for _factor, _class in zip(counts[0], counts[1]):
+            int_factor, int_class = int(_factor), int(_class)
             if contingency_matrix[int_factor, int_class] > 0:
                 factor_category = unique_factor_values[int_factor].item()
                 class_name = metadata.class_names[int_class]
@@ -273,11 +274,14 @@ def parity(metadata: Metadata) -> ParityOutput:
         chi_scores[i], p_values[i] = chi2_contingency(contingency_matrix)[:2]

     if insufficient_data:
-        warnings.warn("Some factors did not meet the recommended 5 occurrences for each value-label combination.")
+        warnings.warn(
+            f"Factors {list(insufficient_data)} did not meet the recommended "
+            "5 occurrences for each value-label combination."
+        )

     return ParityOutput(
         score=chi_scores,
         p_value=p_values,
-        factor_names=metadata.discrete_factor_names,
+        factor_names=metadata.factor_names,
         insufficient_data={k: dict(v) for k, v in insufficient_data.items()},
     )
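A toy sketch of the chi-square step above, with assumed counts (not package data): each factor column yields a factor-value by class-label contingency matrix, and cells with fewer than 5 observations are the ones flagged in insufficient_data.

    import numpy as np
    from scipy.stats import chi2_contingency

    # rows: unique factor values, columns: class labels (made-up frequencies)
    contingency_matrix = np.array([[12, 7], [9, 15], [4, 11]])
    chi2, p_value = chi2_contingency(contingency_matrix)[:2]

    low_cells = np.nonzero(contingency_matrix < 5)  # indices of under-populated cells
    print(chi2, p_value, list(zip(low_cells[0].tolist(), low_cells[1].tolist())))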
@@ -2,9 +2,10 @@ from __future__ import annotations

 __all__ = []

-from collections import Counter, defaultdict
 from typing import Any, Mapping, TypeVar

+import polars as pl
+
 from dataeval.data._metadata import Metadata
 from dataeval.outputs import LabelStatsOutput
 from dataeval.outputs._base import set_metadata
@@ -52,39 +53,34 @@ def labelstats(dataset: Metadata | AnnotatedDataset[Any]) -> LabelStatsOutput:
     pig: 2 - 2
     chicken: 5 - 5
     """
-    dataset = Metadata(dataset) if isinstance(dataset, AnnotatedDataset) else dataset
-
-    label_counts: Counter[int] = Counter()
-    image_counts: Counter[int] = Counter()
-    index_location = defaultdict(list[int])
-    label_per_image: list[int] = []
-
-    index2label = dict(enumerate(dataset.class_names))
-
-    for i, target in enumerate(dataset.targets):
-        group = target.labels.tolist()
+    metadata = Metadata(dataset) if isinstance(dataset, AnnotatedDataset) else dataset
+    metadata_df = metadata.dataframe

-        # Count occurrences of each label in all sublists
-        label_counts.update(group)
+    # Count occurrences of each label across all images
+    label_counts_df = metadata_df.group_by("class_label").len()
+    label_counts = label_counts_df.sort("class_label")["len"].to_list()

-        # Get the number of labels per image
-        label_per_image.append(len(group))
+    # Count unique images per label (how many images contain each label)
+    image_counts_df = metadata_df.select(["image_index", "class_label"]).unique().group_by("class_label").len()
+    image_counts = image_counts_df.sort("class_label")["len"].to_list()

-        # Create a set of unique items in the current sublist
-        unique_items: set[int] = set(group)
+    # Create index_location mapping (which images contain each label)
+    index_location: list[list[int]] = [[] for _ in range(len(metadata.class_names))]
+    for row in metadata_df.group_by("class_label").agg(pl.col("image_index")).to_dicts():
+        indices = row["image_index"]
+        index_location[row["class_label"]] = sorted(dict.fromkeys(indices)) if isinstance(indices, list) else [indices]

-        # Update image counts and index locations
-        image_counts.update(unique_items)
-        for item in unique_items:
-            index_location[item].append(i)
+    # Count labels per image
+    label_per_image_df = metadata_df.group_by("image_index").agg(pl.count().alias("label_count"))
+    label_per_image = label_per_image_df.sort("image_index")["label_count"].to_list()

     return LabelStatsOutput(
-        label_counts_per_class=_sort_to_list(label_counts),
+        label_counts_per_class=label_counts,
         label_counts_per_image=label_per_image,
-        image_counts_per_class=_sort_to_list(image_counts),
-        image_indices_per_class=_sort_to_list(index_location),
+        image_counts_per_class=image_counts,
+        image_indices_per_class=index_location,
         image_count=len(label_per_image),
-        class_count=len(label_counts),
-        label_count=sum(label_counts.values()),
-        class_names=list(index2label.values()),
+        class_count=len(metadata.class_names),
+        label_count=sum(label_counts),
+        class_names=metadata.class_names,
     )
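The Counter-based loop is replaced by polars aggregations over Metadata.dataframe. A self-contained polars sketch on a toy frame (the column names follow the ones used above; the frame itself is made up):

    import polars as pl

    df = pl.DataFrame({"image_index": [0, 0, 1, 2, 2], "class_label": [1, 0, 1, 1, 1]})

    # Occurrences of each label across all rows (detections)
    label_counts = df.group_by("class_label").len().sort("class_label")["len"].to_list()

    # Number of distinct images containing each label
    image_counts = (
        df.select(["image_index", "class_label"]).unique().group_by("class_label").len().sort("class_label")["len"].to_list()
    )
    print(label_counts, image_counts)  # [1, 4] and [1, 3]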
dataeval/outputs/_base.py CHANGED
@@ -147,7 +147,7 @@ P = ParamSpec("P")
 R = TypeVar("R", bound=GenericOutput)


-def set_metadata(fn: Callable[P, R] | None = None, *, state: list[str] | None = None) -> Callable[P, R]:
+def set_metadata(fn: Callable[P, R] | None = None, *, state: Sequence[str] | None = None) -> Callable[P, R]:
     """Decorator to stamp Output classes with runtime metadata"""

     if fn is None:
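set_metadata follows the fn=None plus keyword-only options pattern, so it works both bare and parameterized, and widening state to Sequence[str] lets callers pass a tuple or a list. A simplified, hypothetical sketch of that decorator pattern (stamp is a stand-in, not the package's implementation):

    from functools import wraps
    from typing import Callable, Sequence, TypeVar

    R = TypeVar("R")

    def stamp(fn: Callable[..., R] | None = None, *, state: Sequence[str] | None = None):
        def decorator(func: Callable[..., R]) -> Callable[..., R]:
            @wraps(func)
            def wrapper(*args, **kwargs) -> R:
                wrapper.requested_state = tuple(state or ())  # record the requested state names
                return func(*args, **kwargs)
            return wrapper
        return decorator if fn is None else decorator(fn)

    @stamp
    def evaluate() -> int:  # used bare: @stamp
        return 1

    @stamp(state=("num_samples",))
    def score() -> int:  # used parameterized: @stamp(state=...)
        return 2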