dataeval 0.86.0__py3-none-any.whl → 0.86.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. dataeval/__init__.py +1 -1
  2. dataeval/_log.py +1 -1
  3. dataeval/config.py +21 -4
  4. dataeval/data/_embeddings.py +2 -2
  5. dataeval/data/_images.py +2 -3
  6. dataeval/data/_metadata.py +188 -178
  7. dataeval/data/_selection.py +1 -2
  8. dataeval/data/_split.py +4 -5
  9. dataeval/data/_targets.py +17 -13
  10. dataeval/data/selections/_classfilter.py +2 -5
  11. dataeval/data/selections/_prioritize.py +6 -9
  12. dataeval/data/selections/_shuffle.py +3 -1
  13. dataeval/detectors/drift/_base.py +4 -5
  14. dataeval/detectors/drift/_mmd.py +3 -6
  15. dataeval/detectors/drift/_nml/_base.py +4 -2
  16. dataeval/detectors/drift/_nml/_chunk.py +11 -19
  17. dataeval/detectors/drift/_nml/_domainclassifier.py +8 -19
  18. dataeval/detectors/drift/_nml/_result.py +8 -9
  19. dataeval/detectors/drift/_nml/_thresholds.py +66 -77
  20. dataeval/detectors/linters/outliers.py +7 -7
  21. dataeval/metadata/_distance.py +10 -7
  22. dataeval/metadata/_ood.py +11 -103
  23. dataeval/metrics/bias/_balance.py +23 -33
  24. dataeval/metrics/bias/_diversity.py +16 -14
  25. dataeval/metrics/bias/_parity.py +18 -18
  26. dataeval/metrics/estimators/_divergence.py +2 -4
  27. dataeval/metrics/stats/_base.py +103 -42
  28. dataeval/metrics/stats/_boxratiostats.py +21 -19
  29. dataeval/metrics/stats/_dimensionstats.py +14 -10
  30. dataeval/metrics/stats/_hashstats.py +1 -1
  31. dataeval/metrics/stats/_pixelstats.py +6 -6
  32. dataeval/metrics/stats/_visualstats.py +3 -3
  33. dataeval/outputs/_base.py +22 -7
  34. dataeval/outputs/_bias.py +24 -70
  35. dataeval/outputs/_drift.py +1 -9
  36. dataeval/outputs/_linters.py +11 -11
  37. dataeval/outputs/_stats.py +82 -23
  38. dataeval/outputs/_workflows.py +2 -2
  39. dataeval/utils/_array.py +6 -9
  40. dataeval/utils/_bin.py +1 -2
  41. dataeval/utils/_clusterer.py +7 -4
  42. dataeval/utils/_fast_mst.py +27 -13
  43. dataeval/utils/_image.py +65 -11
  44. dataeval/utils/_mst.py +1 -3
  45. dataeval/utils/_plot.py +15 -10
  46. dataeval/utils/data/_dataset.py +54 -28
  47. dataeval/utils/data/metadata.py +104 -82
  48. dataeval/utils/datasets/__init__.py +2 -0
  49. dataeval/utils/datasets/_antiuav.py +189 -0
  50. dataeval/utils/datasets/_base.py +11 -8
  51. dataeval/utils/datasets/_cifar10.py +104 -45
  52. dataeval/utils/datasets/_fileio.py +21 -47
  53. dataeval/utils/datasets/_milco.py +22 -12
  54. dataeval/utils/datasets/_mixin.py +2 -4
  55. dataeval/utils/datasets/_mnist.py +3 -4
  56. dataeval/utils/datasets/_ships.py +14 -7
  57. dataeval/utils/datasets/_voc.py +229 -42
  58. dataeval/utils/torch/models.py +5 -10
  59. dataeval/utils/torch/trainer.py +3 -3
  60. dataeval/workflows/sufficiency.py +2 -2
  61. {dataeval-0.86.0.dist-info → dataeval-0.86.2.dist-info}/METADATA +2 -1
  62. dataeval-0.86.2.dist-info/RECORD +114 -0
  63. dataeval/detectors/ood/vae.py +0 -74
  64. dataeval-0.86.0.dist-info/RECORD +0 -114
  65. {dataeval-0.86.0.dist-info → dataeval-0.86.2.dist-info}/LICENSE.txt +0 -0
  66. {dataeval-0.86.0.dist-info → dataeval-0.86.2.dist-info}/WHEEL +0 -0

dataeval/detectors/linters/outliers.py CHANGED
@@ -13,31 +13,31 @@ from dataeval.metrics.stats._imagestats import imagestats
 from dataeval.outputs import DimensionStatsOutput, ImageStatsOutput, OutliersOutput, PixelStatsOutput, VisualStatsOutput
 from dataeval.outputs._base import set_metadata
 from dataeval.outputs._linters import IndexIssueMap, OutlierStatsOutput
-from dataeval.outputs._stats import BOX_COUNT, SOURCE_INDEX
+from dataeval.outputs._stats import BASE_ATTRS
 from dataeval.typing import ArrayLike, Dataset
 
 
 def _get_outlier_mask(
     values: NDArray, method: Literal["zscore", "modzscore", "iqr"], threshold: float | None
 ) -> NDArray:
+    values = values.astype(np.float64)
     if method == "zscore":
         threshold = threshold if threshold else 3.0
         std = np.std(values)
         abs_diff = np.abs(values - np.mean(values))
         return std != 0 and (abs_diff / std) > threshold
-    elif method == "modzscore":
+    if method == "modzscore":
         threshold = threshold if threshold else 3.5
         abs_diff = np.abs(values - np.median(values))
         med_abs_diff = np.median(abs_diff) if np.median(abs_diff) != 0 else np.mean(abs_diff)
         mod_z_score = 0.6745 * abs_diff / med_abs_diff
         return mod_z_score > threshold
-    elif method == "iqr":
+    if method == "iqr":
         threshold = threshold if threshold else 1.5
         qrt = np.percentile(values, q=(25, 75), method="midpoint")
         iqr = (qrt[1] - qrt[0]) * threshold
         return (values < (qrt[0] - iqr)) | (values > (qrt[1] + iqr))
-    else:
-        raise ValueError("Outlier method must be 'zscore' 'modzscore' or 'iqr'.")
+    raise ValueError("Outlier method must be 'zscore' 'modzscore' or 'iqr'.")
 
 
 class Outliers:
@@ -103,7 +103,7 @@ class Outliers:
         use_visual: bool = True,
         outlier_method: Literal["zscore", "modzscore", "iqr"] = "modzscore",
         outlier_threshold: float | None = None,
-    ):
+    ) -> None:
         self.stats: ImageStatsOutput
         self.use_dimension = use_dimension
         self.use_pixel = use_pixel
@@ -114,7 +114,7 @@ class Outliers:
     def _get_outliers(self, stats: dict) -> dict[int, dict[str, float]]:
         flagged_images: dict[int, dict[str, float]] = {}
         for stat, values in stats.items():
-            if stat in (SOURCE_INDEX, BOX_COUNT):
+            if stat in BASE_ATTRS:
                 continue
             if values.ndim == 1:
                 mask = _get_outlier_mask(values.astype(np.float64), self.outlier_method, self.outlier_threshold)
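
The refactored _get_outlier_mask above keeps the same three rules (z-score, modified z-score, IQR) and only flattens the branching. A minimal standalone sketch of the modified z-score branch, assuming plain NumPy (the function name here is illustrative, not a dataeval API):

import numpy as np

def modzscore_outliers(values: np.ndarray, threshold: float = 3.5) -> np.ndarray:
    # Modified z-score: 0.6745 * |x - median| / MAD, falling back to the mean
    # absolute deviation when the MAD is zero (mirrors the hunk above).
    values = values.astype(np.float64)
    abs_diff = np.abs(values - np.median(values))
    med_abs_diff = np.median(abs_diff) if np.median(abs_diff) != 0 else np.mean(abs_diff)
    mod_z_score = 0.6745 * abs_diff / med_abs_diff
    return mod_z_score > threshold

# Example: only the last value is flagged.
print(modzscore_outliers(np.array([1.0, 1.1, 0.9, 1.0, 12.0])))
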
dataeval/metadata/_distance.py CHANGED
@@ -80,14 +80,17 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> MetadataDista
     MetadataDistanceValues(statistic=1.0, location=0.44354838709677413, dist=2.7, pvalue=0.0)
     """
 
-    _compare_keys(metadata1.continuous_factor_names, metadata2.continuous_factor_names)
-    fnames = metadata1.continuous_factor_names
+    _compare_keys(metadata1.factor_names, metadata2.factor_names)
+    cont_fnames = metadata1.get_factors_by_type("continuous")
 
-    cont1 = np.atleast_2d(metadata1.continuous_data)  # (S, F)
-    cont2 = np.atleast_2d(metadata2.continuous_data)  # (S, F)
+    if not cont_fnames:
+        return MetadataDistanceOutput({})
 
-    _validate_factors_and_data(fnames, cont1)
-    _validate_factors_and_data(fnames, cont2)
+    cont1 = np.atleast_2d(metadata1.dataframe[cont_fnames].to_numpy())  # (S, F)
+    cont2 = np.atleast_2d(metadata2.dataframe[cont_fnames].to_numpy())  # (S, F)
+
+    _validate_factors_and_data(cont_fnames, cont1)
+    _validate_factors_and_data(cont_fnames, cont2)
 
     N = len(cont1)
     M = len(cont2)
@@ -104,7 +107,7 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> MetadataDista
     results: dict[str, MetadataDistanceValues] = {}
 
     # Per factor
-    for i, fname in enumerate(fnames):
+    for i, fname in enumerate(cont_fnames):
         fdata1 = cont1[:, i]  # (S, 1)
         fdata2 = cont2[:, i]  # (S, 1)
 
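
The metadata_distance change above switches from continuous_data arrays to selecting continuous columns out of Metadata.dataframe and iterating per factor. A hypothetical sketch of that per-factor comparison pattern using a two-sample KS test (the KS choice, helper name, and factor names are assumptions; the hunk does not show the distance computation itself):

import numpy as np
from scipy.stats import ks_2samp

def per_factor_ks(cont1: np.ndarray, cont2: np.ndarray, names: list[str]) -> dict[str, tuple[float, float]]:
    # Compare each continuous factor column of two (samples, factors) arrays
    # with a two-sample KS test; returns {factor: (statistic, pvalue)}.
    results: dict[str, tuple[float, float]] = {}
    for i, name in enumerate(names):
        res = ks_2samp(cont1[:, i], cont2[:, i])
        results[name] = (float(res.statistic), float(res.pvalue))
    return results

rng = np.random.default_rng(0)
a = rng.normal(size=(100, 2))
b = rng.normal(loc=0.5, size=(120, 2))
print(per_factor_ks(a, b, ["altitude", "temperature"]))
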
dataeval/metadata/_ood.py CHANGED
@@ -15,95 +15,6 @@ from dataeval.outputs import MostDeviatedFactorsOutput, OODOutput, OODPredictorO
 from dataeval.outputs._base import set_metadata
 
 
-def _combine_discrete_continuous(metadata: Metadata) -> tuple[list[str], NDArray[np.float64]]:
-    """Combines the discrete and continuous data of a :class:`Metadata` object
-
-    Returns
-    -------
-    Tuple[list[str], NDArray]
-        The combined list of factors names and the combined discrete and continuous data
-
-    Note
-    ----
-    Discrete and continuous data must have the same number of samples
-    """
-    names = []
-    data = []
-
-    if metadata.discrete_factor_names and metadata.discrete_data.size != 0:
-        names.extend(metadata.discrete_factor_names)
-        data.append(metadata.discrete_data)
-
-    if metadata.continuous_factor_names and metadata.continuous_data.size != 0:
-        names.extend(metadata.continuous_factor_names)
-        data.append(metadata.continuous_data)
-
-    return names, np.hstack(data, dtype=np.float64) if data else np.array([], dtype=np.float64)
-
-
-def _combine_metadata(
-    metadata_1: Metadata, metadata_2: Metadata
-) -> tuple[list[str], list[NDArray[np.float64 | np.int64]], list[NDArray[np.int64 | np.float64]]]:
-    """
-    Combines the factor names and data arrays of metadata_1 and metadata_2 when the names
-    match exactly and data has the same number of columns (factors).
-
-    Parameters
-    ----------
-    metadata_1 : Metadata
-        The set of factor names used as reference to determine the correct factor names and length of data
-    metadata_2 : Metadata
-        The compared set of factor names and data that must match metadata_1
-
-    Returns
-    -------
-    list[str]
-        The combined discrete and continuous factor names in that order.
-    list[NDArray]
-        Combined discrete and continuous data of metadata_1
-    list[NDArray]
-        Combined discrete and continuous data of metadata_2
-
-    Raises
-    ------
-    ValueError
-        If keys do not match in metadata_1 and metadata_2
-    ValueError
-        If the length of keys do not match the length of the data
-    """
-    factor_names: list[str] = []
-    m1_data: list[NDArray[np.int64 | np.float64]] = []
-    m2_data: list[NDArray[np.int64 | np.float64]] = []
-
-    # Both metadata must have the same number of factors (cols), but not necessarily samples (row)
-    if metadata_1.total_num_factors != metadata_2.total_num_factors:
-        raise ValueError(
-            f"Number of factors differs between metadata_1 ({metadata_1.total_num_factors}) "
-            f"and metadata_2 ({metadata_2.total_num_factors})"
-        )
-
-    # Validate and attach discrete data
-    if metadata_1.discrete_factor_names:
-        _compare_keys(metadata_1.discrete_factor_names, metadata_2.discrete_factor_names)
-        _validate_factors_and_data(metadata_1.discrete_factor_names, metadata_1.discrete_data)
-
-        factor_names.extend(metadata_1.discrete_factor_names)
-        m1_data.append(metadata_1.discrete_data)
-        m2_data.append(metadata_2.discrete_data)
-
-    # Validate and attach continuous data
-    if metadata_1.continuous_factor_names:
-        _compare_keys(metadata_1.continuous_factor_names, metadata_2.continuous_factor_names)
-        _validate_factors_and_data(metadata_1.continuous_factor_names, metadata_1.continuous_data)
-
-        factor_names.extend(metadata_1.continuous_factor_names)
-        m1_data.append(metadata_1.continuous_data)
-        m2_data.append(metadata_2.continuous_data)
-
-    # Turns list of discrete and continuous into one array
-    return factor_names, m1_data, m2_data
-
-
 def _calc_median_deviations(reference: NDArray, test: NDArray) -> NDArray:
     """
     Calculates deviations of the test data from the median of the reference data
@@ -207,16 +118,13 @@ def find_most_deviated_factors(
     if not any(ood_mask):
         return MostDeviatedFactorsOutput([])
 
-    # Combines reference and test factor names and data if exists and match exactly
-    # shape -> (samples, factors)
-    factor_names, md_1, md_2 = _combine_metadata(
-        metadata_1=metadata_ref,
-        metadata_2=metadata_tst,
-    )
+    factor_names = metadata_ref.factor_names
+    ref_data = metadata_ref.factor_data
+    tst_data = metadata_tst.factor_data
 
-    # Stack discrete and continuous factors as separate factors. Must have equal sample counts
-    ref_data = np.hstack(md_1) if md_1 else np.array([])  # (S, Fd + Fc)
-    tst_data = np.hstack(md_2) if md_2 else np.array([])  # (S, Fd + Fc)
+    _compare_keys(factor_names, metadata_tst.factor_names)
+    _validate_factors_and_data(factor_names, ref_data)
+    _validate_factors_and_data(factor_names, tst_data)
 
     if len(ref_data) < 3:
         warnings.warn(
@@ -256,6 +164,7 @@ which is what many library functions return, multiply it by _NATS2BITS to get it
 """
 
 
+@set_metadata
 def find_ood_predictors(
     metadata: Metadata,
     ood: OODOutput,
@@ -305,8 +214,8 @@ def find_ood_predictors(
 
     ood_mask: NDArray[np.bool_] = ood.is_ood
 
-    discrete_features_count = len(metadata.discrete_factor_names)
-    factors, data = _combine_discrete_continuous(metadata)  # (F, ), (S, F) => F = Fd + Fc
+    factors = metadata.factor_names
+    data = metadata.factor_data
 
     # No metadata correlated with out of distribution data, return 0.0 for all factors
    if not any(ood_mask):
@@ -320,14 +229,13 @@
     # Calculate mean, std of each factor over all samples
     scaled_data = (data - np.mean(data, axis=0)) / np.std(data, axis=0, ddof=1)  # (S, F)
 
-    discrete_features = np.zeros_like(factors, dtype=np.bool_)
-    discrete_features[:discrete_features_count] = True
+    discrete_features = [info.factor_type != "continuous" for info in metadata.factor_info.values()]
 
     mutual_info_values = (
         mutual_info_classif(
             X=scaled_data,
             y=ood_mask,
-            discrete_features=discrete_features,  # type: ignore -> sklearn issue - NDArray[bool] not of accepted type Union[ArrayLike, 'auto']
+            discrete_features=discrete_features,  # type: ignore - sklearn function not typed
             random_state=get_seed(),
         )
         * _NATS2BITS
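
find_ood_predictors now standardizes metadata.factor_data and scores each factor against the OOD mask with sklearn's mutual_info_classif, converting nats to bits. A minimal sketch of that scoring over a plain array of factors (the helper name and the log2(e) value used for _NATS2BITS are supplied here for illustration):

import numpy as np
from sklearn.feature_selection import mutual_info_classif

_NATS2BITS = 1.442695  # log2(e): converts nats (sklearn's unit) to bits

def score_ood_predictors(data: np.ndarray, ood_mask: np.ndarray, discrete: list[bool]) -> np.ndarray:
    # Standardize each factor, then measure how much information each one
    # carries about the OOD flag (mirrors the find_ood_predictors hunk).
    scaled = (data - np.mean(data, axis=0)) / np.std(data, axis=0, ddof=1)
    return mutual_info_classif(scaled, ood_mask, discrete_features=discrete, random_state=0) * _NATS2BITS

rng = np.random.default_rng(0)
factors = rng.normal(size=(200, 3))
is_ood = factors[:, 0] > 1.0  # toy OOD flag driven by the first factor only
print(score_ood_predictors(factors, is_ood, discrete=[False, False, False]))
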
dataeval/metrics/bias/_balance.py CHANGED
@@ -68,22 +68,20 @@ def balance(
 
     >>> bal = balance(metadata)
     >>> bal.balance
-    array([1. , 0.249, 0.03 , 0.134, 0. , 0. ])
+    array([1. , 0.134, 0. , 0. ])
 
     Return intra/interfactor balance (mutual information)
 
     >>> bal.factors
-    array([[1. , 0.314, 0.269, 0.852, 0.367],
-           [0.314, 1. , 0.097, 0.158, 1.98 ],
-           [0.269, 0.097, 1. , 0.037, 0.015],
-           [0.852, 0.158, 0.037, 0.475, 0.255],
-           [0.367, 1.98 , 0.015, 0.255, 1.063]])
+    array([[1. , 0.017, 0.015],
+           [0.017, 0.445, 0.245],
+           [0.015, 0.245, 1.063]])
 
     Return classwise balance (mutual information) of factors with individual class_labels
 
     >>> bal.classwise
-    array([[1. , 0.249, 0.03 , 0.134, 0. , 0. ],
-           [1. , 0.249, 0.03 , 0.134, 0. , 0. ]])
+    array([[1. , 0.134, 0. , 0. ],
+           [1. , 0.134, 0. , 0. ]])
 
 
     See Also
@@ -92,41 +90,39 @@ def balance(
     sklearn.feature_selection.mutual_info_regression
     sklearn.metrics.mutual_info_score
     """
-    if not metadata.discrete_factor_names and not metadata.continuous_factor_names:
+    if not metadata.factor_names:
         raise ValueError("No factors found in provided metadata.")
 
     num_neighbors = _validate_num_neighbors(num_neighbors)
 
-    num_factors = metadata.total_num_factors
-    is_discrete = [True] * (len(metadata.discrete_factor_names) + 1) + [False] * len(metadata.continuous_factor_names)
+    data = metadata.discretized_data
+    factor_types = {"class_label": "categorical"} | {k: v.factor_type for k, v in metadata.factor_info.items()}
+    is_discrete = [factor_type != "continuous" for factor_type in factor_types.values()]
+    num_factors = len(factor_types)
+
     mi = np.full((num_factors, num_factors), np.nan, dtype=np.float32)
-    data = np.hstack((metadata.class_labels[:, np.newaxis], metadata.discrete_data))
-    discretized_data = data
-    if len(metadata.continuous_data):
-        data = np.hstack((data, metadata.continuous_data))
-        discrete_idx = [metadata.discrete_factor_names.index(name) for name in metadata.continuous_factor_names]
-        discretized_data = np.hstack((discretized_data, metadata.discrete_data[:, discrete_idx]))
-
-    for idx in range(num_factors):
-        if idx >= len(metadata.discrete_factor_names) + 1:
-            mi[idx, :] = mutual_info_regression(
+    data = np.hstack((metadata.class_labels[:, np.newaxis], data))
+
+    for idx, factor_type in enumerate(factor_types.values()):
+        if factor_type != "continuous":
+            mi[idx, :] = mutual_info_classif(
                 data,
                 data[:, idx],
-                discrete_features=is_discrete,  # type: ignore
+                discrete_features=is_discrete,  # type: ignore - sklearn function not typed
                 n_neighbors=num_neighbors,
                 random_state=get_seed(),
             )
         else:
-            mi[idx, :] = mutual_info_classif(
+            mi[idx, :] = mutual_info_regression(
                 data,
                 data[:, idx],
-                discrete_features=is_discrete,  # type: ignore
+                discrete_features=is_discrete,  # type: ignore - sklearn function not typed
                 n_neighbors=num_neighbors,
                 random_state=get_seed(),
             )
 
     # Normalization via entropy
-    bin_cnts = get_counts(discretized_data)
+    bin_cnts = get_counts(data)
     ent_factor = sp.stats.entropy(bin_cnts, axis=0)
     norm_factor = 0.5 * np.add.outer(ent_factor, ent_factor) + EPSILON
 
@@ -149,7 +145,7 @@ def balance(
         classwise_mi[idx, :] = mutual_info_classif(
             data,
             tgt_bin[:, idx],
-            discrete_features=is_discrete,  # type: ignore
+            discrete_features=is_discrete,  # type: ignore - sklearn function not typed
            n_neighbors=num_neighbors,
             random_state=get_seed(),
         )
@@ -161,12 +157,6 @@ def balance(
     classwise = classwise_mi / norm_factor
 
     # Grabbing factor names for plotting function
-    factor_names = ["class"]
-    for name in metadata.discrete_factor_names:
-        if name in metadata.continuous_factor_names:
-            name = name + "-discrete"
-        factor_names.append(name)
-    for name in metadata.continuous_factor_names:
-        factor_names.append(name + "-continuous")
+    factor_names = ["class_label"] + metadata.factor_names
 
     return BalanceOutput(balance, factors, classwise, factor_names, metadata.class_names)
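
balance() normalizes the mutual-information matrix by the pairwise average entropy (0.5 * np.add.outer(ent_factor, ent_factor) + EPSILON). A two-factor sketch of that normalization using sklearn and scipy (the toy data and the epsilon value are assumptions):

import numpy as np
from scipy.stats import entropy
from sklearn.metrics import mutual_info_score

# Two discretized factors observed over the same samples (toy values).
x = np.array([0, 0, 1, 1, 2, 2, 0, 1])
y = np.array([0, 0, 1, 1, 1, 1, 0, 1])

# Raw mutual information (in nats) between the two factors.
mi = mutual_info_score(x, y)

# Normalize by the average entropy of the pair, with a small epsilon
# (value assumed here) to avoid division by zero; result lands in [0, 1].
ent_x = entropy(np.bincount(x))
ent_y = entropy(np.bincount(y))
print(mi / (0.5 * (ent_x + ent_y) + 1e-10))
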
dataeval/metrics/bias/_diversity.py CHANGED
@@ -138,43 +138,45 @@ def diversity(
 
     >>> div_simp = diversity(metadata, method="simpson")
     >>> div_simp.diversity_index
-    array([0.6 , 0.809, 1. , 0.8 ])
+    array([0.6 , 0.8 , 0.809, 1. ])
 
     >>> div_simp.classwise
-    array([[0.5 , 0.8 , 0.8 ],
-           [0.63 , 0.976, 0.528]])
+    array([[0.8 , 0.5 , 0.8 ],
+           [0.528, 0.63 , 0.976]])
 
     Compute Shannon diversity index of metadata and class labels
 
     >>> div_shan = diversity(metadata, method="shannon")
     >>> div_shan.diversity_index
-    array([0.811, 0.943, 1. , 0.918])
+    array([0.811, 0.918, 0.943, 1. ])
 
     >>> div_shan.classwise
-    array([[0.683, 0.918, 0.918],
-           [0.814, 0.991, 0.764]])
+    array([[0.918, 0.683, 0.918],
+           [0.764, 0.814, 0.991]])
 
     See Also
     --------
     scipy.stats.entropy
     """
-    if not metadata.discrete_factor_names and not metadata.continuous_factor_names:
+    if not metadata.factor_names:
         raise ValueError("No factors found in provided metadata.")
 
     diversity_fn = get_method(_DIVERSITY_FN_MAP, method)
-    discretized_data = np.hstack((metadata.class_labels[:, np.newaxis], metadata.discrete_data))
-    cnts = get_counts(discretized_data)
+    discretized_data = metadata.discretized_data
+    factor_names = metadata.factor_names
+    class_lbl = metadata.class_labels
+
+    class_labels_with_discretized_data = np.hstack((class_lbl[:, np.newaxis], discretized_data))
+    cnts = get_counts(class_labels_with_discretized_data)
     num_bins = np.bincount(np.nonzero(cnts)[1])
     diversity_index = diversity_fn(cnts, num_bins)
 
-    class_lbl = metadata.class_labels
-
     u_classes = np.unique(class_lbl)
-    num_factors = len(metadata.discrete_factor_names)
+    num_factors = len(factor_names)
     classwise_div = np.full((len(u_classes), num_factors), np.nan)
     for idx, cls in enumerate(u_classes):
         subset_mask = class_lbl == cls
-        cls_cnts = get_counts(metadata.discrete_data[subset_mask], min_num_bins=cnts.shape[0])
+        cls_cnts = get_counts(discretized_data[subset_mask], min_num_bins=cnts.shape[0])
         classwise_div[idx, :] = diversity_fn(cls_cnts, num_bins[1:])
 
-    return DiversityOutput(diversity_index, classwise_div, metadata.discrete_factor_names, metadata.class_names)
+    return DiversityOutput(diversity_index, classwise_div, factor_names, metadata.class_names)
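
diversity() dispatches to Simpson or Shannon index functions via _DIVERSITY_FN_MAP, which this hunk does not show. For orientation only, a sketch of the textbook definitions over one factor's bin counts; the package's exact normalization may differ:

import numpy as np

def shannon_evenness(counts: np.ndarray) -> float:
    # Shannon entropy of the empirical distribution, normalized by log(#bins)
    # so a perfectly even factor scores 1.0 (assumes at least two bins).
    p = counts / counts.sum()
    p = p[p > 0]
    return float(-(p * np.log(p)).sum() / np.log(len(counts)))

def simpson_index(counts: np.ndarray) -> float:
    # Probability that two samples drawn at random fall in different bins.
    p = counts / counts.sum()
    return float(1.0 - (p**2).sum())

counts = np.array([4, 4, 2])  # bin counts for one discretized factor
print(shannon_evenness(counts), simpson_index(counts))
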
dataeval/metrics/bias/_parity.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
 __all__ = []
 
 import warnings
+from collections import defaultdict
 from typing import Any
 
 import numpy as np
@@ -241,13 +242,13 @@ def parity(metadata: Metadata) -> ParityOutput:
     >>> parity(metadata)
     ParityOutput(score=array([7.357, 5.467, 0.515]), p_value=array([0.289, 0.243, 0.773]), factor_names=['age', 'income', 'gender'], insufficient_data={'age': {3: {'artist': 4}, 4: {'artist': 4, 'teacher': 3}}, 'income': {1: {'artist': 3}}})
     """  # noqa: E501
-    if not metadata.discrete_factor_names and not metadata.continuous_factor_names:
+    if not metadata.factor_names:
         raise ValueError("No factors found in provided metadata.")
 
-    chi_scores = np.zeros(metadata.discrete_data.shape[1])
+    chi_scores = np.zeros(metadata.discretized_data.shape[1])
     p_values = np.zeros_like(chi_scores)
-    insufficient_data = {}
-    for i, col_data in enumerate(metadata.discrete_data.T):
+    insufficient_data: defaultdict[str, defaultdict[int, dict[str, int]]] = defaultdict(lambda: defaultdict(dict))
+    for i, col_data in enumerate(metadata.discretized_data.T):
         # Builds a contingency matrix where entry at index (r,c) represents
         # the frequency of current_factor_name achieving value unique_factor_values[r]
         # at a data point with class c.
@@ -257,30 +258,29 @@ def parity(metadata: Metadata) -> ParityOutput:
         # Determines if any frequencies are too low
         counts = np.nonzero(contingency_matrix < 5)
         unique_factor_values = np.unique(col_data)
-        current_factor_name = metadata.discrete_factor_names[i]
+        current_factor_name = metadata.factor_names[i]
         for int_factor, int_class in zip(counts[0], counts[1]):
             if contingency_matrix[int_factor, int_class] > 0:
                 factor_category = unique_factor_values[int_factor].item()
-                if current_factor_name not in insufficient_data:
-                    insufficient_data[current_factor_name] = {}
-                if factor_category not in insufficient_data[current_factor_name]:
-                    insufficient_data[current_factor_name][factor_category] = {}
                 class_name = metadata.class_names[int_class]
                 class_count = contingency_matrix[int_factor, int_class].item()
                 insufficient_data[current_factor_name][factor_category][class_name] = class_count
 
         # This deletes rows containing only zeros,
         # because scipy.stats.chi2_contingency fails when there are rows containing only zeros.
-        rowsums = np.sum(contingency_matrix, axis=1)
-        rowmask = np.nonzero(rowsums)[0]
-        contingency_matrix = contingency_matrix[rowmask]
+        contingency_matrix = contingency_matrix[np.any(contingency_matrix, axis=1)]
 
-        chi2, p, _, _ = chi2_contingency(contingency_matrix)
-
-        chi_scores[i] = chi2
-        p_values[i] = p
+        chi_scores[i], p_values[i] = chi2_contingency(contingency_matrix)[:2]
 
     if insufficient_data:
-        warnings.warn("Some factors did not meet the recommended 5 occurrences for each value-label combination.")
+        warnings.warn(
+            f"Factors {list(insufficient_data)} did not meet the recommended "
+            "5 occurrences for each value-label combination."
+        )
 
-    return ParityOutput(chi_scores, p_values, metadata.discrete_factor_names, insufficient_data)
+    return ParityOutput(
+        score=chi_scores,
+        p_value=p_values,
+        factor_names=metadata.factor_names,
+        insufficient_data={k: dict(v) for k, v in insufficient_data.items()},
+    )
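
parity() builds a factor-by-class contingency matrix, drops all-zero rows, and keeps the first two values returned by chi2_contingency. A standalone sketch of that step with made-up counts:

import numpy as np
from scipy.stats import chi2_contingency

# Contingency matrix: rows are factor values, columns are class labels.
contingency = np.array([
    [12, 3],
    [8, 9],
    [0, 0],  # a factor value that never occurs
])

# Drop all-zero rows first; chi2_contingency rejects them (same trick as the hunk).
contingency = contingency[np.any(contingency, axis=1)]

chi2, p = chi2_contingency(contingency)[:2]
print(chi2, p)
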
dataeval/metrics/estimators/_divergence.py CHANGED
@@ -38,8 +38,7 @@ def divergence_mst(data: NDArray[np.float64], labels: NDArray[np.int_]) -> int:
     """
     mst = minimum_spanning_tree(data).toarray()
     edgelist = np.transpose(np.nonzero(mst))
-    errors = np.sum(labels[edgelist[:, 0]] != labels[edgelist[:, 1]])
-    return errors
+    return np.sum(labels[edgelist[:, 0]] != labels[edgelist[:, 1]])
 
 
 def divergence_fnn(data: NDArray[np.float64], labels: NDArray[np.int_]) -> int:
@@ -59,8 +58,7 @@ def divergence_fnn(data: NDArray[np.float64], labels: NDArray[np.int_]) -> int:
     Number of label errors when finding nearest neighbors
     """
     nn_indices = compute_neighbors(data, data)
-    errors = np.sum(np.abs(labels[nn_indices] - labels))
-    return errors
+    return np.sum(np.abs(labels[nn_indices] - labels))
 
 
 _DIVERGENCE_FN_MAP = {"FNN": divergence_fnn, "MST": divergence_mst}
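
divergence_mst counts minimum-spanning-tree edges whose endpoints carry different labels. A sketch of the same quantity computed from raw features, assuming a pairwise Euclidean distance matrix is built first (whether divergence_mst itself receives features or precomputed distances is not shown in this hunk):

import numpy as np
from scipy.sparse.csgraph import minimum_spanning_tree
from scipy.spatial.distance import cdist

def mst_label_errors(data: np.ndarray, labels: np.ndarray) -> int:
    # Build an MST over the pairwise-distance graph and count edges whose
    # endpoints carry different labels (the quantity divergence_mst returns).
    dists = cdist(data, data)
    mst = minimum_spanning_tree(dists).toarray()
    edges = np.transpose(np.nonzero(mst))
    return int(np.sum(labels[edges[:, 0]] != labels[edges[:, 1]]))

rng = np.random.default_rng(0)
a = rng.normal(loc=0.0, size=(50, 2))
b = rng.normal(loc=3.0, size=(50, 2))
data = np.vstack([a, b])
labels = np.array([0] * 50 + [1] * 50)
print(mst_label_errors(data, labels))  # typically 1 for well-separated clusters
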