dataeval 0.85.0__py3-none-any.whl → 0.86.1__py3-none-any.whl

This diff shows the content of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (66)
  1. dataeval/__init__.py +1 -1
  2. dataeval/_log.py +1 -1
  3. dataeval/config.py +21 -4
  4. dataeval/data/_embeddings.py +2 -2
  5. dataeval/data/_images.py +2 -3
  6. dataeval/data/_metadata.py +65 -42
  7. dataeval/data/_selection.py +2 -3
  8. dataeval/data/_split.py +2 -3
  9. dataeval/data/_targets.py +17 -13
  10. dataeval/data/selections/_classfilter.py +6 -8
  11. dataeval/data/selections/_prioritize.py +6 -9
  12. dataeval/data/selections/_shuffle.py +3 -1
  13. dataeval/detectors/drift/__init__.py +4 -1
  14. dataeval/detectors/drift/_base.py +4 -5
  15. dataeval/detectors/drift/_mmd.py +3 -6
  16. dataeval/detectors/drift/_mvdc.py +92 -0
  17. dataeval/detectors/drift/_nml/__init__.py +6 -0
  18. dataeval/detectors/drift/_nml/_base.py +70 -0
  19. dataeval/detectors/drift/_nml/_chunk.py +396 -0
  20. dataeval/detectors/drift/_nml/_domainclassifier.py +181 -0
  21. dataeval/detectors/drift/_nml/_result.py +97 -0
  22. dataeval/detectors/drift/_nml/_thresholds.py +269 -0
  23. dataeval/detectors/linters/outliers.py +7 -7
  24. dataeval/metrics/bias/_parity.py +10 -13
  25. dataeval/metrics/estimators/_divergence.py +2 -4
  26. dataeval/metrics/stats/_base.py +103 -42
  27. dataeval/metrics/stats/_boxratiostats.py +21 -19
  28. dataeval/metrics/stats/_dimensionstats.py +14 -10
  29. dataeval/metrics/stats/_hashstats.py +1 -1
  30. dataeval/metrics/stats/_pixelstats.py +6 -6
  31. dataeval/metrics/stats/_visualstats.py +3 -3
  32. dataeval/outputs/__init__.py +2 -1
  33. dataeval/outputs/_base.py +22 -7
  34. dataeval/outputs/_bias.py +27 -31
  35. dataeval/outputs/_drift.py +60 -0
  36. dataeval/outputs/_linters.py +12 -17
  37. dataeval/outputs/_stats.py +83 -29
  38. dataeval/outputs/_workflows.py +2 -2
  39. dataeval/utils/_array.py +6 -9
  40. dataeval/utils/_bin.py +1 -2
  41. dataeval/utils/_clusterer.py +7 -4
  42. dataeval/utils/_fast_mst.py +27 -13
  43. dataeval/utils/_image.py +65 -11
  44. dataeval/utils/_mst.py +1 -3
  45. dataeval/utils/_plot.py +15 -10
  46. dataeval/utils/data/_dataset.py +32 -20
  47. dataeval/utils/data/metadata.py +104 -82
  48. dataeval/utils/datasets/__init__.py +2 -0
  49. dataeval/utils/datasets/_antiuav.py +189 -0
  50. dataeval/utils/datasets/_base.py +11 -8
  51. dataeval/utils/datasets/_cifar10.py +104 -45
  52. dataeval/utils/datasets/_fileio.py +21 -47
  53. dataeval/utils/datasets/_milco.py +19 -11
  54. dataeval/utils/datasets/_mixin.py +2 -4
  55. dataeval/utils/datasets/_mnist.py +3 -4
  56. dataeval/utils/datasets/_ships.py +14 -7
  57. dataeval/utils/datasets/_voc.py +229 -42
  58. dataeval/utils/torch/models.py +5 -10
  59. dataeval/utils/torch/trainer.py +3 -3
  60. dataeval/workflows/sufficiency.py +2 -2
  61. {dataeval-0.85.0.dist-info → dataeval-0.86.1.dist-info}/METADATA +3 -2
  62. dataeval-0.86.1.dist-info/RECORD +114 -0
  63. dataeval/detectors/ood/vae.py +0 -74
  64. dataeval-0.85.0.dist-info/RECORD +0 -107
  65. {dataeval-0.85.0.dist-info → dataeval-0.86.1.dist-info}/LICENSE.txt +0 -0
  66. {dataeval-0.85.0.dist-info → dataeval-0.86.1.dist-info}/WHEEL +0 -0
dataeval/detectors/drift/_nml/_result.py (new file)
@@ -0,0 +1,97 @@
+ """
+ Contains the results of the data reconstruction drift calculation and provides plotting functionality.
+
+ Source code derived from NannyML 0.13.0
+ https://github.com/NannyML/nannyml/blob/main/nannyml/base.py
+
+ Licensed under Apache Software License (Apache 2.0)
+ """
+
+ from __future__ import annotations
+
+ import copy
+ from abc import ABC, abstractmethod
+ from typing import NamedTuple, Sequence
+
+ import pandas as pd
+ from typing_extensions import Self
+
+ from dataeval.outputs._base import GenericOutput
+
+
+ class Metric(NamedTuple):
+     display_name: str
+     column_name: str
+
+
+ class AbstractResult(GenericOutput[pd.DataFrame]):
+     def __init__(self, results_data: pd.DataFrame) -> None:
+         self._data = results_data.copy(deep=True)
+
+     def data(self) -> pd.DataFrame:
+         return self.to_df()
+
+     @property
+     def empty(self) -> bool:
+         return self._data is None or self._data.empty
+
+     def __len__(self) -> int:
+         return 0 if self.empty else len(self._data)
+
+     def to_df(self, multilevel: bool = True) -> pd.DataFrame:
+         """Export results to pandas dataframe."""
+         if multilevel:
+             return self._data
+         column_names = [
+             "_".join(col).replace("chunk_chunk_chunk", "chunk").replace("chunk_chunk", "chunk")
+             for col in self._data.columns.values
+         ]
+         single_level_data = self._data.copy(deep=True)
+         single_level_data.columns = column_names
+         return single_level_data
+
+     def filter(self, period: str = "all", metrics: str | Sequence[str] | None = None) -> Self:
+         """Returns filtered result metric data."""
+         if metrics and not isinstance(metrics, (str, Sequence)):
+             raise ValueError("metrics value provided is not a valid metric or sequence of metrics")
+         if isinstance(metrics, str):
+             metrics = [metrics]
+         return self._filter(period, metrics)
+
+     @abstractmethod
+     def _filter(self, period: str, metrics: Sequence[str] | None = None) -> Self: ...
+
+
+ class Abstract1DResult(AbstractResult, ABC):
+     def __init__(self, results_data: pd.DataFrame) -> None:
+         super().__init__(results_data)
+
+     def _filter(self, period: str, metrics: Sequence[str] | None = None) -> Self:
+         data = self._data
+         if period != "all":
+             data = self._data.loc[self._data.loc[:, ("chunk", "period")] == period, :]  # type: ignore | dataframe loc
+         data = data.reset_index(drop=True)
+
+         res = copy.deepcopy(self)
+         res._data = data
+         return res
+
+
+ class PerMetricResult(Abstract1DResult):
+     def __init__(self, results_data: pd.DataFrame, metrics: Sequence[Metric] = []) -> None:
+         super().__init__(results_data)
+         self.metrics = metrics
+
+     def _filter(self, period: str, metrics: Sequence[str] | None = None) -> Self:
+         if metrics is None:
+             metrics = [metric.column_name for metric in self.metrics]
+
+         res = super()._filter(period)
+
+         data = pd.concat([res._data.loc[:, (["chunk"])], res._data.loc[:, (metrics,)]], axis=1)  # type: ignore | dataframe loc
+         data = data.reset_index(drop=True)
+
+         res._data = data
+         res.metrics = [metric for metric in self.metrics if metric.column_name in metrics]
+
+         return res
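The flattening rule in `AbstractResult.to_df(multilevel=False)` is easiest to see on a toy frame. A minimal standalone sketch of that logic (the two-level columns here are hypothetical, chosen to resemble the chunked results these classes wrap):

import pandas as pd

# Hypothetical two-level result frame like the ones AbstractResult wraps.
columns = pd.MultiIndex.from_tuples(
    [("chunk", "key"), ("chunk", "period"), ("domain_classifier_auroc", "value")]
)
df = pd.DataFrame(
    [["[0:99]", "reference", 0.51], ["[100:199]", "analysis", 0.93]],
    columns=columns,
)

# Same flattening rule as to_df(multilevel=False): join the levels with "_"
# and collapse repeated "chunk" prefixes left over from the chunk columns.
flat = [
    "_".join(col).replace("chunk_chunk_chunk", "chunk").replace("chunk_chunk", "chunk")
    for col in df.columns.values
]
print(flat)  # ['chunk_key', 'chunk_period', 'domain_classifier_auroc_value']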
dataeval/detectors/drift/_nml/_thresholds.py (new file)
@@ -0,0 +1,269 @@
+ """
+ Source code derived from NannyML 0.13.0
+ https://github.com/NannyML/nannyml/blob/main/nannyml/thresholds.py
+
+ Licensed under Apache Software License (Apache 2.0)
+ """
+
+ from __future__ import annotations
+
+ import logging
+ from abc import ABC, abstractmethod
+ from typing import Any, Callable, ClassVar
+
+ import numpy as np
+
+
+ class Threshold(ABC):
+     """A base class used to calculate lower and upper threshold values given one or multiple arrays.
+
+     Any subclass should implement the abstract `_thresholds` method.
+     It takes an array or list of arrays and converts them into lower and upper threshold values, represented
+     as a tuple of optional floats.
+
+     A `None` threshold value is interpreted as if there is no upper or lower threshold.
+     One or both values might be `None`.
+     """
+
+     _registry: ClassVar[dict[str, type[Threshold]]] = {}
+     """Class registry lookup to get threshold subclass from threshold_type string"""
+
+     def __str__(self) -> str:
+         return f"{self.__class__.__name__}({str(vars(self))})"
+
+     def __repr__(self) -> str:
+         return str(self)
+
+     def __eq__(self, other: object) -> bool:
+         return isinstance(other, self.__class__) and other.__dict__ == self.__dict__
+
+     def __init_subclass__(cls, threshold_type: str) -> None:
+         Threshold._registry[threshold_type] = cls
+
+     @abstractmethod
+     def _thresholds(self, data: np.ndarray) -> tuple[float | None, float | None]:
+         """Returns lower and upper threshold values when given one or more np.ndarray instances.
+
+         Parameters:
+             data: np.ndarray
+                 An array of values used to calculate the thresholds from. This will most often represent a metric
+                 calculated on one or more sets of data, e.g. a list of F1 scores of multiple data chunks.
+
+         Returns:
+             lower, upper: tuple[Optional[float], Optional[float]]
+                 The lower and upper threshold values. One or both might be `None`.
+         """
+
+     @classmethod
+     def parse_object(cls, obj: dict[str, Any]) -> Threshold:
+         """Parse object as :class:`Threshold`"""
+         class_name = obj.pop("type", "")
+
+         try:
+             threshold_cls = cls._registry[class_name]
+         except KeyError:
+             accepted_values = ", ".join(map(repr, cls._registry))
+             raise ValueError(f"Expected one of {accepted_values} for threshold type, but received '{class_name}'")
+
+         return threshold_cls(**obj)
+
+     def calculate(
+         self,
+         data: np.ndarray,
+         lower_limit: float | None = None,
+         upper_limit: float | None = None,
+         override_using_none: bool = False,
+         logger: logging.Logger | None = None,
+     ) -> tuple[float | None, float | None]:
+         """
+         Calculate lower and upper threshold values with respect to the provided Threshold and value limits.
+
+         Parameters
+         ----------
+         data : np.ndarray
+             The data used by the Threshold instance to calculate the lower and upper threshold values.
+             This will often be the values of a drift detection method or performance metric on chunks of reference
+             data.
+         lower_limit : float or None, default None
+             An optional value that serves as a limit for the lower threshold value. Any calculated lower threshold
+             values that end up below this limit will be replaced by this limit value.
+             The limit is often a theoretical constraint enforced by a specific drift detection method or performance
+             metric.
+         upper_limit : float or None, default None
+             An optional value that serves as a limit for the upper threshold value. Any calculated upper threshold
+             values that end up above this limit will be replaced by this limit value.
+             The limit is often a theoretical constraint enforced by a specific drift detection method or performance
+             metric.
+         override_using_none: bool, default False
+             When set to True use None to override threshold values that exceed value limits.
+             This will prevent them from being rendered on plots.
+         logger: Optional[logging.Logger], default=None
+             An optional Logger instance. When provided a warning will be logged when a calculated threshold value
+             gets overridden by a threshold value limit.
+         """
+
+         lower_value, upper_value = self._thresholds(data)
+
+         if lower_limit is not None and lower_value is not None and lower_value <= lower_limit:
+             override_value = None if override_using_none else lower_limit
+             if logger:
+                 logger.warning(
+                     f"lower threshold value {lower_value} overridden by lower threshold value limit {override_value}"
+                 )
+             lower_value = override_value
+
+         if upper_limit is not None and upper_value is not None and upper_value >= upper_limit:
+             override_value = None if override_using_none else upper_limit
+             if logger:
+                 logger.warning(
+                     f"upper threshold value {upper_value} overridden by upper threshold value limit {override_value}"
+                 )
+             upper_value = override_value
+
+         return lower_value, upper_value
+
+
+ class ConstantThreshold(Threshold, threshold_type="constant"):
+     """A `Thresholder` implementation that returns a constant lower and or upper threshold value.
+
+     Attributes:
+         lower: Optional[float]
+             The constant lower threshold value. Defaults to `None`, meaning there is no lower threshold.
+         upper: Optional[float]
+             The constant upper threshold value. Defaults to `None`, meaning there is no upper threshold.
+
+     Raises:
+         ValueError: raised when an argument was given using an incorrect type or name
+         ValueError: raised when the ConstantThreshold could not be created using the given argument values
+
+     Examples:
+         >>> data = np.array(range(10))
+         >>> t = ConstantThreshold(lower=None, upper=0.1)
+         >>> lower, upper = t.calculate(data)
+         >>> print(lower, upper)
+         None 0.1
+     """
+
+     def __init__(self, lower: float | int | None = None, upper: float | int | None = None) -> None:
+         """Creates a new ConstantThreshold instance.
+
+         Args:
+             lower: Optional[Union[float, int]], default=None
+                 The constant lower threshold value. Defaults to `None`, meaning there is no lower threshold.
+             upper: Optional[Union[float, int]], default=None
+                 The constant upper threshold value. Defaults to `None`, meaning there is no upper threshold.
+
+         Raises:
+             ValueError: raised when an argument was given using an incorrect type or name
+             ValueError: raised when the ConstantThreshold could not be created using the given argument values
+         """
+         self._validate_inputs(lower, upper)
+
+         self.lower = lower
+         self.upper = upper
+
+     def _thresholds(self, data: np.ndarray) -> tuple[float | None, float | None]:
+         return self.lower, self.upper
+
+     @staticmethod
+     def _validate_inputs(lower: float | int | None = None, upper: float | int | None = None) -> None:
+         if lower is not None and not isinstance(lower, (float, int)) or isinstance(lower, bool):
+             raise ValueError(f"expected type of 'lower' to be 'float', 'int' or None but got '{type(lower).__name__}'")
+
+         if upper is not None and not isinstance(upper, (float, int)) or isinstance(upper, bool):
+             raise ValueError(f"expected type of 'upper' to be 'float', 'int' or None but got '{type(upper).__name__}'")
+
+         # explicit None check is required due to special interpretation of the value 0.0 as False
+         if lower is not None and upper is not None and lower >= upper:
+             raise ValueError(f"lower threshold {lower} must be less than upper threshold {upper}")
+
+
+ class StandardDeviationThreshold(Threshold, threshold_type="standard_deviation"):
+     """A Thresholder that offsets the mean of an array by a multiple of the standard deviation of the array values.
+
+     This thresholder will take the aggregate of an array of values, the mean by default, and add or subtract an offset
+     to get the upper and lower threshold values.
+     This offset is calculated as a multiplier, by default 3, times the standard deviation of the given array.
+
+     Attributes:
+         std_lower_multiplier: float
+         std_upper_multiplier: float
+
+     Examples:
+         >>> data = np.array(range(10))
+         >>> t = StandardDeviationThreshold()
+         >>> lower, upper = t.calculate(data)
+         >>> print(lower, upper)
+         -4.116843969807043 13.116843969807043
+     """
+
+     def __init__(
+         self,
+         std_lower_multiplier: float | int | None = 3,
+         std_upper_multiplier: float | int | None = 3,
+         offset_from: Callable[[np.ndarray], Any] = np.nanmean,
+     ) -> None:
+         """Creates a new StandardDeviationThreshold instance.
+
+         Args:
+             std_lower_multiplier: float, default=3
+                 The number the standard deviation of the input array will be multiplied with to form the lower offset.
+                 This value will be subtracted from the aggregate of the input array.
+                 Defaults to 3.
+             std_upper_multiplier: float, default=3
+                 The number the standard deviation of the input array will be multiplied with to form the upper offset.
+                 This value will be added to the aggregate of the input array.
+                 Defaults to 3.
+             offset_from: Callable[[np.ndarray], Any], default=np.nanmean
+                 A function that will be applied to the input array to aggregate it into a single value.
+                 Adding the upper offset to this value will yield the upper threshold, subtracting the lower offset
+                 will yield the lower threshold.
+         """
+
+         self._validate_inputs(std_lower_multiplier, std_upper_multiplier)
+
+         self.std_lower_multiplier = std_lower_multiplier
+         self.std_upper_multiplier = std_upper_multiplier
+         self.offset_from = offset_from
+
+     def _thresholds(self, data: np.ndarray) -> tuple[float | None, float | None]:
+         aggregate = self.offset_from(data)
+         std = np.nanstd(data)
+
+         lower_threshold = aggregate - std * self.std_lower_multiplier if self.std_lower_multiplier is not None else None
+
+         upper_threshold = aggregate + std * self.std_upper_multiplier if self.std_upper_multiplier is not None else None
+
+         return lower_threshold, upper_threshold
+
+     @staticmethod
+     def _validate_inputs(
+         std_lower_multiplier: float | int | None = 3, std_upper_multiplier: float | int | None = 3
+     ) -> None:
+         if (
+             std_lower_multiplier is not None
+             and not isinstance(std_lower_multiplier, (float, int))
+             or isinstance(std_lower_multiplier, bool)
+         ):
+             raise ValueError(
+                 f"expected type of 'std_lower_multiplier' to be 'float', 'int' or None "
+                 f"but got '{type(std_lower_multiplier).__name__}'"
+             )
+
+         if std_lower_multiplier and std_lower_multiplier < 0:
+             raise ValueError(f"'std_lower_multiplier' should be greater than 0 but got value {std_lower_multiplier}")
+
+         if (
+             std_upper_multiplier is not None
+             and not isinstance(std_upper_multiplier, (float, int))
+             or isinstance(std_upper_multiplier, bool)
+         ):
+             raise ValueError(
+                 f"expected type of 'std_upper_multiplier' to be 'float', 'int' or None "
+                 f"but got '{type(std_upper_multiplier).__name__}'"
+             )
+
+         if std_upper_multiplier and std_upper_multiplier < 0:
+             raise ValueError(f"'std_upper_multiplier' should be greater than 0 but got value {std_upper_multiplier}")
dataeval/detectors/linters/outliers.py
@@ -13,31 +13,31 @@ from dataeval.metrics.stats._imagestats import imagestats
  from dataeval.outputs import DimensionStatsOutput, ImageStatsOutput, OutliersOutput, PixelStatsOutput, VisualStatsOutput
  from dataeval.outputs._base import set_metadata
  from dataeval.outputs._linters import IndexIssueMap, OutlierStatsOutput
- from dataeval.outputs._stats import BOX_COUNT, SOURCE_INDEX
+ from dataeval.outputs._stats import BASE_ATTRS
  from dataeval.typing import ArrayLike, Dataset


  def _get_outlier_mask(
      values: NDArray, method: Literal["zscore", "modzscore", "iqr"], threshold: float | None
  ) -> NDArray:
+     values = values.astype(np.float64)
      if method == "zscore":
          threshold = threshold if threshold else 3.0
          std = np.std(values)
          abs_diff = np.abs(values - np.mean(values))
          return std != 0 and (abs_diff / std) > threshold
-     elif method == "modzscore":
+     if method == "modzscore":
          threshold = threshold if threshold else 3.5
          abs_diff = np.abs(values - np.median(values))
          med_abs_diff = np.median(abs_diff) if np.median(abs_diff) != 0 else np.mean(abs_diff)
          mod_z_score = 0.6745 * abs_diff / med_abs_diff
          return mod_z_score > threshold
-     elif method == "iqr":
+     if method == "iqr":
          threshold = threshold if threshold else 1.5
          qrt = np.percentile(values, q=(25, 75), method="midpoint")
          iqr = (qrt[1] - qrt[0]) * threshold
          return (values < (qrt[0] - iqr)) | (values > (qrt[1] + iqr))
-     else:
-         raise ValueError("Outlier method must be 'zscore' 'modzscore' or 'iqr'.")
+     raise ValueError("Outlier method must be 'zscore' 'modzscore' or 'iqr'.")


  class Outliers:
@@ -103,7 +103,7 @@ class Outliers:
          use_visual: bool = True,
          outlier_method: Literal["zscore", "modzscore", "iqr"] = "modzscore",
          outlier_threshold: float | None = None,
-     ):
+     ) -> None:
          self.stats: ImageStatsOutput
          self.use_dimension = use_dimension
          self.use_pixel = use_pixel
@@ -114,7 +114,7 @@
      def _get_outliers(self, stats: dict) -> dict[int, dict[str, float]]:
          flagged_images: dict[int, dict[str, float]] = {}
          for stat, values in stats.items():
-             if stat in (SOURCE_INDEX, BOX_COUNT):
+             if stat in BASE_ATTRS:
                  continue
              if values.ndim == 1:
                  mask = _get_outlier_mask(values.astype(np.float64), self.outlier_method, self.outlier_threshold)
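The new `values.astype(np.float64)` cast at the top of `_get_outlier_mask` keeps the arithmetic below in floating point even for integer stat arrays. A standalone replica of the default `modzscore` branch (the helper name is ours, for illustration):

import numpy as np

def modzscore_mask(values: np.ndarray, threshold: float = 3.5) -> np.ndarray:
    # Mirrors the "modzscore" branch above: the median absolute deviation,
    # scaled by 0.6745, approximates a z-score that is robust to outliers.
    values = values.astype(np.float64)
    abs_diff = np.abs(values - np.median(values))
    med_abs_diff = np.median(abs_diff) if np.median(abs_diff) != 0 else np.mean(abs_diff)
    mod_z_score = 0.6745 * abs_diff / med_abs_diff
    return mod_z_score > threshold

print(modzscore_mask(np.array([1, 2, 2, 3, 2, 2, 50])))  # only the 50 is flagged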
dataeval/metrics/bias/_parity.py
@@ -3,6 +3,7 @@ from __future__ import annotations
  __all__ = []

  import warnings
+ from collections import defaultdict
  from typing import Any

  import numpy as np
@@ -246,7 +247,7 @@ def parity(metadata: Metadata) -> ParityOutput:

      chi_scores = np.zeros(metadata.discrete_data.shape[1])
      p_values = np.zeros_like(chi_scores)
-     insufficient_data = {}
+     insufficient_data: defaultdict[str, defaultdict[int, dict[str, int]]] = defaultdict(lambda: defaultdict(dict))
      for i, col_data in enumerate(metadata.discrete_data.T):
          # Builds a contingency matrix where entry at index (r,c) represents
          # the frequency of current_factor_name achieving value unique_factor_values[r]
@@ -261,26 +262,22 @@
          for int_factor, int_class in zip(counts[0], counts[1]):
              if contingency_matrix[int_factor, int_class] > 0:
                  factor_category = unique_factor_values[int_factor].item()
-                 if current_factor_name not in insufficient_data:
-                     insufficient_data[current_factor_name] = {}
-                 if factor_category not in insufficient_data[current_factor_name]:
-                     insufficient_data[current_factor_name][factor_category] = {}
                  class_name = metadata.class_names[int_class]
                  class_count = contingency_matrix[int_factor, int_class].item()
                  insufficient_data[current_factor_name][factor_category][class_name] = class_count

          # This deletes rows containing only zeros,
          # because scipy.stats.chi2_contingency fails when there are rows containing only zeros.
-         rowsums = np.sum(contingency_matrix, axis=1)
-         rowmask = np.nonzero(rowsums)[0]
-         contingency_matrix = contingency_matrix[rowmask]
+         contingency_matrix = contingency_matrix[np.any(contingency_matrix, axis=1)]

-         chi2, p, _, _ = chi2_contingency(contingency_matrix)
-
-         chi_scores[i] = chi2
-         p_values[i] = p
+         chi_scores[i], p_values[i] = chi2_contingency(contingency_matrix)[:2]

      if insufficient_data:
          warnings.warn("Some factors did not meet the recommended 5 occurrences for each value-label combination.")

-     return ParityOutput(chi_scores, p_values, metadata.discrete_factor_names, insufficient_data)
+     return ParityOutput(
+         score=chi_scores,
+         p_value=p_values,
+         factor_names=metadata.discrete_factor_names,
+         insufficient_data={k: dict(v) for k, v in insufficient_data.items()},
+     )
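The `defaultdict(lambda: defaultdict(dict))` swap removes the per-level membership checks deleted above: missing keys now materialize on first write. A tiny sketch of the pattern (the keys are made up):

from collections import defaultdict

insufficient_data: defaultdict = defaultdict(lambda: defaultdict(dict))

# No existence checks needed; intermediate dicts are created on demand.
insufficient_data["factor_a"][0]["cat"] = 3
insufficient_data["factor_a"][1]["dog"] = 2

# Convert the outer layers back to plain dicts before returning, as the diff does.
print({k: dict(v) for k, v in insufficient_data.items()})
# {'factor_a': {0: {'cat': 3}, 1: {'dog': 2}}}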
dataeval/metrics/estimators/_divergence.py
@@ -38,8 +38,7 @@ def divergence_mst(data: NDArray[np.float64], labels: NDArray[np.int_]) -> int:
      """
      mst = minimum_spanning_tree(data).toarray()
      edgelist = np.transpose(np.nonzero(mst))
-     errors = np.sum(labels[edgelist[:, 0]] != labels[edgelist[:, 1]])
-     return errors
+     return np.sum(labels[edgelist[:, 0]] != labels[edgelist[:, 1]])


  def divergence_fnn(data: NDArray[np.float64], labels: NDArray[np.int_]) -> int:
@@ -59,8 +58,7 @@ def divergence_fnn(data: NDArray[np.float64], labels: NDArray[np.int_]) -> int:
      Number of label errors when finding nearest neighbors
      """
      nn_indices = compute_neighbors(data, data)
-     errors = np.sum(np.abs(labels[nn_indices] - labels))
-     return errors
+     return np.sum(np.abs(labels[nn_indices] - labels))


  _DIVERGENCE_FN_MAP = {"FNN": divergence_fnn, "MST": divergence_mst}
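Both helpers now return the mismatch count directly instead of assigning it to a temporary. A toy check of the MST edge-counting step: the `minimum_spanning_tree` used above is dataeval's internal helper, so this standalone replica substitutes `scipy.sparse.csgraph.minimum_spanning_tree` on an explicit distance matrix (an assumption of equivalent behavior, for illustration only):

import numpy as np
from scipy.sparse.csgraph import minimum_spanning_tree

# Two well-separated clusters with distinct labels: the MST of the full
# graph should contain exactly one edge that crosses the label boundary.
data = np.array([[0.0], [0.1], [5.0], [5.1]])
labels = np.array([0, 0, 1, 1])

dist = np.abs(data - data.T)  # dense pairwise distance matrix
mst = minimum_spanning_tree(dist).toarray()
edgelist = np.transpose(np.nonzero(mst))
print(np.sum(labels[edgelist[:, 0]] != labels[edgelist[:, 1]]))  # 1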