dataeval 0.75.0__py3-none-any.whl → 0.76.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. dataeval/__init__.py +3 -3
  2. dataeval/detectors/drift/base.py +2 -2
  3. dataeval/detectors/drift/ks.py +2 -1
  4. dataeval/detectors/drift/mmd.py +3 -2
  5. dataeval/detectors/drift/uncertainty.py +2 -2
  6. dataeval/detectors/drift/updates.py +1 -1
  7. dataeval/detectors/linters/clusterer.py +3 -2
  8. dataeval/detectors/linters/duplicates.py +4 -4
  9. dataeval/detectors/linters/outliers.py +96 -3
  10. dataeval/detectors/ood/__init__.py +1 -1
  11. dataeval/detectors/ood/base.py +1 -17
  12. dataeval/detectors/ood/output.py +1 -1
  13. dataeval/interop.py +1 -1
  14. dataeval/metrics/__init__.py +1 -1
  15. dataeval/metrics/bias/__init__.py +1 -1
  16. dataeval/metrics/bias/balance.py +3 -3
  17. dataeval/metrics/bias/coverage.py +1 -1
  18. dataeval/metrics/bias/diversity.py +14 -10
  19. dataeval/metrics/bias/parity.py +5 -5
  20. dataeval/metrics/estimators/ber.py +4 -3
  21. dataeval/metrics/estimators/divergence.py +3 -3
  22. dataeval/metrics/estimators/uap.py +3 -3
  23. dataeval/metrics/stats/__init__.py +1 -1
  24. dataeval/metrics/stats/base.py +24 -8
  25. dataeval/metrics/stats/boxratiostats.py +5 -5
  26. dataeval/metrics/stats/datasetstats.py +39 -6
  27. dataeval/metrics/stats/dimensionstats.py +4 -4
  28. dataeval/metrics/stats/hashstats.py +2 -2
  29. dataeval/metrics/stats/labelstats.py +89 -6
  30. dataeval/metrics/stats/pixelstats.py +7 -5
  31. dataeval/metrics/stats/visualstats.py +6 -4
  32. dataeval/output.py +23 -14
  33. dataeval/utils/__init__.py +2 -2
  34. dataeval/utils/dataset/read.py +1 -1
  35. dataeval/utils/dataset/split.py +1 -1
  36. dataeval/utils/metadata.py +42 -44
  37. dataeval/utils/plot.py +129 -6
  38. dataeval/workflows/sufficiency.py +2 -2
  39. {dataeval-0.75.0.dist-info → dataeval-0.76.0.dist-info}/LICENSE.txt +2 -2
  40. {dataeval-0.75.0.dist-info → dataeval-0.76.0.dist-info}/METADATA +18 -17
  41. dataeval-0.76.0.dist-info/RECORD +67 -0
  42. dataeval-0.75.0.dist-info/RECORD +0 -67
  43. {dataeval-0.75.0.dist-info → dataeval-0.76.0.dist-info}/WHEEL +0 -0
dataeval/metrics/stats/datasetstats.py CHANGED
@@ -7,7 +7,7 @@ from typing import Any, Iterable
 
 from numpy.typing import ArrayLike
 
-from dataeval.metrics.stats.base import BaseStatsOutput, run_stats
+from dataeval.metrics.stats.base import BaseStatsOutput, HistogramPlotMixin, _is_plottable, run_stats
 from dataeval.metrics.stats.dimensionstats import (
     DimensionStatsOutput,
     DimensionStatsProcessor,
@@ -16,12 +16,13 @@ from dataeval.metrics.stats.labelstats import LabelStatsOutput, labelstats
 from dataeval.metrics.stats.pixelstats import PixelStatsOutput, PixelStatsProcessor
 from dataeval.metrics.stats.visualstats import VisualStatsOutput, VisualStatsProcessor
 from dataeval.output import Output, set_metadata
+from dataeval.utils.plot import channel_histogram_plot
 
 
 @dataclass(frozen=True)
-class DatasetStatsOutput(Output):
+class DatasetStatsOutput(Output, HistogramPlotMixin):
     """
-    Output class for :func:`datasetstats` stats metric
+    Output class for :func:`datasetstats` stats metric.
 
     This class represents the outputs of various stats functions against a single
     dataset, such that each index across all stat outputs are representative of
@@ -41,6 +42,8 @@ class DatasetStatsOutput(Output):
     visualstats: VisualStatsOutput
     labelstats: LabelStatsOutput | None = None
 
+    _excluded_keys = ["histogram", "percentiles"]
+
     def _outputs(self) -> list[Output]:
         return [s for s in (self.dimensionstats, self.pixelstats, self.visualstats, self.labelstats) if s is not None]
 
@@ -53,10 +56,33 @@ class DatasetStatsOutput(Output):
         raise ValueError("All StatsOutput classes must contain the same number of image sources.")
 
 
+def _get_channels(cls, channel_limit: int | None = None, channel_index: int | Iterable[int] | None = None):
+    raw_channels = max([si.channel for si in cls.dict()["source_index"]]) + 1
+    if isinstance(channel_index, int):
+        max_channels = 1 if channel_index < raw_channels else raw_channels
+        ch_mask = cls.pixelstats.get_channel_mask(channel_index)
+    elif isinstance(channel_index, Iterable) and all(isinstance(val, int) for val in list(channel_index)):
+        max_channels = len(list(channel_index))
+        ch_mask = cls.pixelstats.get_channel_mask(channel_index)
+    elif isinstance(channel_limit, int):
+        max_channels = channel_limit
+        ch_mask = cls.pixelstats.get_channel_mask(None, channel_limit)
+    else:
+        max_channels = raw_channels
+        ch_mask = None
+
+    if max_channels > raw_channels:
+        max_channels = raw_channels
+    if ch_mask is not None and not any(ch_mask):
+        ch_mask = None
+
+    return max_channels, ch_mask
+
+
 @dataclass(frozen=True)
 class ChannelStatsOutput(Output):
     """
-    Output class for :func:`channelstats` stats metric
+    Output class for :func:`channelstats` stats metric.
 
     This class represents the outputs of various per-channel stats functions against
     a single dataset, such that each index across all stat outputs are representative
@@ -83,6 +109,13 @@ class ChannelStatsOutput(Output):
         if not all(length == lengths[0] for length in lengths):
             raise ValueError("All StatsOutput classes must contain the same number of image sources.")
 
+    def plot(
+        self, log: bool, channel_limit: int | None = None, channel_index: int | Iterable[int] | None = None
+    ) -> None:
+        max_channels, ch_mask = _get_channels(self, channel_limit, channel_index)
+        data_dict = {k: v for k, v in self.dict().items() if _is_plottable(k, v, ("histogram", "percentiles"))}
+        channel_histogram_plot(data_dict, log, max_channels, ch_mask)
+
 
 @set_metadata
 def datasetstats(
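The `plot` method added above ties `_get_channels` to the new `channel_histogram_plot` utility. A minimal usage sketch (the image data is hypothetical, and matplotlib is assumed to be available as the plotting backend):

    import numpy as np
    from dataeval.metrics.stats.datasetstats import channelstats

    # hypothetical batch of eight 3-channel images in CHW layout
    images = np.random.randint(0, 256, (8, 3, 64, 64), dtype=np.uint8)

    stats = channelstats(images)
    stats.plot(log=True, channel_limit=3)  # per-channel histograms, log-scaled, capped at 3 channels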
@@ -91,7 +124,7 @@ def datasetstats(
     labels: Iterable[ArrayLike] | None = None,
 ) -> DatasetStatsOutput:
     """
-    Calculates various :term:`statistics<Statistics>` for each image
+    Calculates various :term:`statistics<Statistics>` for each image.
 
     This function computes dimension, pixel and visual metrics
     on the images or individual bounding boxes for each image as
@@ -135,7 +168,7 @@ def channelstats(
     bboxes: Iterable[ArrayLike] | None = None,
 ) -> ChannelStatsOutput:
     """
-    Calculates various per-channel statistics for each image
+    Calculates various per-channel :term:`statistics` for each image.
 
     This function computes pixel and visual metrics on the images
     or individual bounding boxes for each image.
dataeval/metrics/stats/dimensionstats.py CHANGED
@@ -8,15 +8,15 @@ from typing import Any, Callable, Iterable
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
 
-from dataeval.metrics.stats.base import BaseStatsOutput, StatsProcessor, run_stats
+from dataeval.metrics.stats.base import BaseStatsOutput, HistogramPlotMixin, StatsProcessor, run_stats
 from dataeval.output import set_metadata
 from dataeval.utils.image import get_bitdepth
 
 
 @dataclass(frozen=True)
-class DimensionStatsOutput(BaseStatsOutput):
+class DimensionStatsOutput(BaseStatsOutput, HistogramPlotMixin):
     """
-    Output class for :func:`dimensionstats` stats metric
+    Output class for :func:`dimensionstats` stats metric.
 
     Attributes
     ----------
@@ -79,7 +79,7 @@ def dimensionstats(
     bboxes: Iterable[ArrayLike] | None = None,
 ) -> DimensionStatsOutput:
     """
-    Calculates dimension :term:`statistics<Statistics>` for each image
+    Calculates dimension :term:`statistics<Statistics>` for each image.
 
     This function computes various dimensional metrics (e.g., width, height, channels)
     on the images or individual bounding boxes for each image.
dataeval/metrics/stats/hashstats.py CHANGED
@@ -25,7 +25,7 @@ MAX_FACTOR = 4
 @dataclass(frozen=True)
 class HashStatsOutput(BaseStatsOutput):
     """
-    Output class for :func:`hashstats` stats metric
+    Output class for :func:`hashstats` stats metric.
 
     Attributes
     ----------
@@ -126,7 +126,7 @@ def hashstats(
     bboxes: Iterable[ArrayLike] | None = None,
 ) -> HashStatsOutput:
     """
-    Calculates hashes for each image
+    Calculates hashes for each image.
 
     This function computes hashes from the images including exact hashes and perception-based
     hashes. These hash values can be used to determine if images are exact or near matches.
dataeval/metrics/stats/labelstats.py CHANGED
@@ -2,20 +2,25 @@ from __future__ import annotations
 
 __all__ = []
 
+# import contextlib
 from collections import Counter, defaultdict
 from dataclasses import dataclass
 from typing import Any, Iterable, Mapping, TypeVar
 
+import numpy as np
 from numpy.typing import ArrayLike
 
-from dataeval.interop import to_numpy
+from dataeval.interop import as_numpy
 from dataeval.output import Output, set_metadata
 
+# with contextlib.suppress(ImportError):
+#     import pandas as pd
+
 
 @dataclass(frozen=True)
 class LabelStatsOutput(Output):
     """
-    Output class for :func:`labelstats` stats metric
+    Output class for :func:`labelstats` stats metric.
 
     Attributes
     ----------
@@ -46,6 +51,47 @@ class LabelStatsOutput(Output):
     class_count: int
     label_count: int
 
+    def to_table(self) -> str:
+        max_char = max(len(key) if isinstance(key, str) else key // 10 + 1 for key in self.label_counts_per_class)
+        max_char = max(max_char, 5)
+        max_label = max(list(self.label_counts_per_class.values()))
+        max_img = max(list(self.image_counts_per_label.values()))
+        max_num = int(np.ceil(np.log10(max(max_label, max_img))))
+        max_num = max(max_num, 11)
+
+        # Display basic counts
+        table_str = f"Class Count: {self.class_count}\n"
+        table_str += f"Label Count: {self.label_count}\n"
+        table_str += f"Average # Labels per Image: {round(np.mean(self.label_counts_per_image), 2)}\n"
+        table_str += "--------------------------------------\n"
+
+        # Display counts per class
+        table_str += f"{'Label':>{max_char}}: Total Count - Image Count\n"
+        for cls in self.label_counts_per_class:
+            table_str += f"{cls:>{max_char}}: {self.label_counts_per_class[cls]:^{max_num}} "
+            table_str += f"- {self.image_counts_per_label[cls]:^{max_num}}\n"
+
+        return table_str
+
+    # def to_dataframe(self) -> pd.DataFrame:
+    #     import pandas as pd
+
+    #     class_list = []
+    #     total_count = []
+    #     image_count = []
+    #     for cls in self.label_counts_per_class:
+    #         class_list.append(cls)
+    #         total_count.append(self.label_counts_per_class[cls])
+    #         image_count.append(self.image_counts_per_label[cls])
+
+    #     return pd.DataFrame(
+    #         {
+    #             "Label": class_list,
+    #             "Total Count": total_count,
+    #             "Image Count": image_count,
+    #         }
+    #     )
+
 
 TKey = TypeVar("TKey", int, str)
 
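The new `to_table` method gives a plain-text summary without optional dependencies (the pandas `to_dataframe` variant above ships commented out). A quick sketch with hypothetical labels:

    from dataeval.metrics.stats.labelstats import labelstats

    stats = labelstats([["cat", "dog"], ["cat"]])  # two hypothetical images
    print(stats.to_table())
    # prints the Class Count / Label Count / Average # Labels per Image header,
    # then one "<label>: <total count> - <image count>" row per class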
@@ -57,12 +103,47 @@ def sort(d: Mapping[TKey, Any]) -> dict[TKey, Any]:
     return dict(sorted(d.items(), key=lambda x: x[0]))
 
 
+def _ensure_2d(labels: Iterable[ArrayLike]) -> Iterable[ArrayLike]:
+    if isinstance(labels, np.ndarray):
+        return labels[:, None]
+    else:
+        return [[lbl] for lbl in labels]  # type: ignore
+
+
+def _get_list_depth(lst):
+    if isinstance(lst, list) and lst:
+        return 1 + max(_get_list_depth(item) for item in lst)
+    return 0
+
+
+def _check_labels_dimension(labels: Iterable[ArrayLike]) -> Iterable[ArrayLike]:
+    # Check for nested lists beyond 2 levels
+
+    if isinstance(labels, np.ndarray):
+        if labels.ndim == 1:
+            return _ensure_2d(labels)
+        elif labels.ndim == 2:
+            return labels
+        else:
+            raise ValueError("The label array must not have more than 2 dimensions.")
+    elif isinstance(labels, list):
+        depth = _get_list_depth(labels)
+        if depth == 1:
+            return _ensure_2d(labels)
+        elif depth == 2:
+            return labels
+        else:
+            raise ValueError("The label list must not be empty or have more than 2 levels of nesting.")
+    else:
+        raise TypeError("Labels must be either a NumPy array or a list.")
+
+
 @set_metadata
 def labelstats(
     labels: Iterable[ArrayLike],
 ) -> LabelStatsOutput:
     """
-    Calculates :term:`statistics<Statistics>` for data labels
+    Calculates :term:`statistics<Statistics>` for data labels.
 
     This function computes counting metrics (e.g., total per class, total per image)
     on the labels.
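These private helpers normalize label input before counting; illustratively (hypothetical inputs, evaluated in the module's scope):

    import numpy as np

    _check_labels_dimension(np.array([0, 1, 2]))  # 1-D array -> column of shape (3, 1)
    _check_labels_dimension([[0, 1], [2]])        # depth-2 list -> returned unchanged
    _check_labels_dimension([[[0]]])              # depth-3 list -> ValueError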
@@ -99,10 +180,12 @@ def labelstats(
     index_location = defaultdict(list[int])
     label_per_image: list[int] = []
 
-    for i, group in enumerate(labels):
-        # Count occurrences of each label in all sublists
-        group = to_numpy(group)
+    labels_2d = _check_labels_dimension(labels)
+
+    for i, group in enumerate(labels_2d):
+        group = as_numpy(group)
 
+        # Count occurrences of each label in all sublists
         label_counts.update(group)
 
         # Get the number of labels per image
dataeval/metrics/stats/pixelstats.py CHANGED
@@ -9,14 +9,14 @@ import numpy as np
 from numpy.typing import ArrayLike, NDArray
 from scipy.stats import entropy, kurtosis, skew
 
-from dataeval.metrics.stats.base import BaseStatsOutput, StatsProcessor, run_stats
+from dataeval.metrics.stats.base import BaseStatsOutput, HistogramPlotMixin, StatsProcessor, run_stats
 from dataeval.output import set_metadata
 
 
 @dataclass(frozen=True)
-class PixelStatsOutput(BaseStatsOutput):
+class PixelStatsOutput(BaseStatsOutput, HistogramPlotMixin):
     """
-    Output class for :func:`pixelstats` stats metric
+    Output class for :func:`pixelstats` stats metric.
 
     Attributes
     ----------
@@ -44,11 +44,13 @@ class PixelStatsOutput(BaseStatsOutput):
     histogram: NDArray[np.uint32]
     entropy: NDArray[np.float16]
 
+    _excluded_keys = ["histogram"]
+
 
 class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
     output_class: type = PixelStatsOutput
     image_function_map: dict[str, Callable[[StatsProcessor[PixelStatsOutput]], Any]] = {
-        "mean": lambda self: np.mean(self.scaled),
+        "mean": lambda x: np.mean(x.scaled),
         "std": lambda x: np.std(x.scaled),
         "var": lambda x: np.var(x.scaled),
         "skew": lambda x: np.nan_to_num(skew(x.scaled.ravel())),
@@ -74,7 +76,7 @@ def pixelstats(
     per_channel: bool = False,
 ) -> PixelStatsOutput:
     """
-    Calculates pixel :term:`statistics<Statistics>` for each image
+    Calculates pixel :term:`statistics<Statistics>` for each image.
 
     This function computes various statistical metrics (e.g., mean, standard deviation, entropy)
     on the images as a whole.
dataeval/metrics/stats/visualstats.py CHANGED
@@ -8,7 +8,7 @@ from typing import Any, Callable, Iterable
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
 
-from dataeval.metrics.stats.base import BaseStatsOutput, StatsProcessor, run_stats
+from dataeval.metrics.stats.base import BaseStatsOutput, HistogramPlotMixin, StatsProcessor, run_stats
 from dataeval.output import set_metadata
 from dataeval.utils.image import edge_filter
 
@@ -16,9 +16,9 @@ QUARTILES = (0, 25, 50, 75, 100)
 
 
 @dataclass(frozen=True)
-class VisualStatsOutput(BaseStatsOutput):
+class VisualStatsOutput(BaseStatsOutput, HistogramPlotMixin):
     """
-    Output class for :func:`visualstats` stats metric
+    Output class for :func:`visualstats` stats metric.
 
     Attributes
     ----------
@@ -46,6 +46,8 @@ class VisualStatsOutput(BaseStatsOutput):
     zeros: NDArray[np.float16]
     percentiles: NDArray[np.float16]
 
+    _excluded_keys = ["percentiles"]
+
 
 class VisualStatsProcessor(StatsProcessor[VisualStatsOutput]):
     output_class: type = VisualStatsOutput
@@ -81,7 +83,7 @@ def visualstats(
     per_channel: bool = False,
 ) -> VisualStatsOutput:
     """
-    Calculates visual statistics for each image
+    Calculates visual :term:`statistics` for each image.
 
     This function computes various visual metrics (e.g., :term:`brightness<Brightness>`, darkness, contrast, blurriness)
     on the images as a whole.
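Between them, `HistogramPlotMixin` and `_excluded_keys` give the dimension, pixel, and visual outputs a shared histogram plot while keeping the raw `histogram`/`percentiles` arrays out of the grid. A rough sketch, assuming the mixin exposes a `plot(log)` method (its definition lives in base.py, which is not shown in this diff):

    import numpy as np
    from dataeval.metrics.stats.pixelstats import pixelstats

    images = np.random.randint(0, 256, (8, 3, 64, 64), dtype=np.uint8)  # hypothetical data
    pixelstats(images).plot(log=True)  # one histogram per retained statistic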
dataeval/output.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
 __all__ = []
 
 import inspect
+import logging
 import sys
 from collections.abc import Mapping
 from datetime import datetime, timezone
@@ -81,29 +82,37 @@ def set_metadata(fn: Callable[P, R] | None = None, *, state: list[str] | None =
                 return f"{v.__class__.__name__}: len={len(v)}"
             return f"{v.__class__.__name__}"
 
-        time = datetime.now(timezone.utc)
-        result = fn(*args, **kwargs)
-        duration = (datetime.now(timezone.utc) - time).total_seconds()
-        fn_params = inspect.signature(fn).parameters
-
+        # Collect function metadata
         # set all params with defaults then update params with mapped arguments and explicit keyword args
+        fn_params = inspect.signature(fn).parameters
         arguments = {k: None if v.default is inspect.Parameter.empty else v.default for k, v in fn_params.items()}
         arguments.update(zip(fn_params, args))
         arguments.update(kwargs)
         arguments = {k: fmt(v) for k, v in arguments.items()}
-        state_attrs = (
-            {k: fmt(getattr(args[0], k)) for k in state if "self" in arguments} if "self" in arguments and state else {}
-        )
-        name = (
-            f"{args[0].__class__.__module__}.{args[0].__class__.__name__}.{fn.__name__}"
-            if "self" in arguments
-            else f"{fn.__module__}.{fn.__qualname__}"
-        )
+        is_method = "self" in arguments
+        state_attrs = {k: fmt(getattr(args[0], k)) for k in state or []} if is_method else {}
+        module = args[0].__class__.__module__ if is_method else fn.__module__.removeprefix("src.")
+        class_prefix = f".{args[0].__class__.__name__}." if is_method else "."
+        name = f"{module}{class_prefix}{fn.__name__}"
+        arguments = {k: v for k, v in arguments.items() if k != "self"}
+
+        _logger = logging.getLogger(module)
+        time = datetime.now(timezone.utc)
+        _logger.log(logging.INFO, f">>> Executing '{name}': args={arguments} state={state} <<<")
+
+        ##### EXECUTE FUNCTION #####
+        result = fn(*args, **kwargs)
+        ############################
+
+        duration = (datetime.now(timezone.utc) - time).total_seconds()
+        _logger.log(logging.INFO, f">>> Completed '{name}': args={arguments} state={state} duration={duration} <<<")
+
+        # Update output with recorded metadata
         metadata = {
             "_name": name,
             "_execution_time": time,
             "_execution_duration": duration,
-            "_arguments": {k: v for k, v in arguments.items() if k != "self"},
+            "_arguments": arguments,
             "_state": state_attrs,
             "_version": __version__,
         }
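Because the decorator now routes through the standard `logging` module, the new execution records can be surfaced without any DataEval-specific hooks; a minimal sketch:

    import logging

    logging.basicConfig(level=logging.INFO)  # show the ">>> Executing/Completed ..." records

    from dataeval.metrics.stats.labelstats import labelstats
    labelstats([[0, 1], [1]])  # INFO log lines include name, args, state, and duration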
dataeval/utils/__init__.py CHANGED
@@ -1,6 +1,6 @@
 """
-The utility classes and functions are provided by DataEval to assist users
-in setting up data and architectures that are guaranteed to work with applicable
+The utility classes and functions are provided by DataEval to assist users \
+in setting up data and architectures that are guaranteed to work with applicable \
 DataEval metrics.
 """
 
dataeval/utils/dataset/read.py CHANGED
@@ -10,7 +10,7 @@ from torch.utils.data import Dataset
 
 def read_dataset(dataset: Dataset[Any]) -> list[list[Any]]:
     """
-    Extract information from a dataset at each index into individual lists of each information position
+    Extract information from a dataset at each index into individual lists of each information position.
 
     Parameters
     ----------
dataeval/utils/dataset/split.py CHANGED
@@ -26,7 +26,7 @@ class TrainValSplit(NamedTuple):
 @dataclass(frozen=True)
 class SplitDatasetOutput(Output):
     """
-    Output class containing test indices and a list of TrainValSplits
+    Output class containing test indices and a list of TrainValSplits.
 
     Attributes
     ----------
dataeval/utils/metadata.py CHANGED
@@ -1,11 +1,11 @@
 """
-Metadata related utility functions that help organize raw metadata into :class:`Metadata` objects
-for use within `DataEval`.
+Metadata related utility functions that help organize raw metadata into \
+:class:`Metadata` objects for use within `DataEval`.
 """
 
 from __future__ import annotations
 
-__all__ = ["Metadata", "preprocess", "merge"]
+__all__ = ["Metadata", "preprocess", "merge", "flatten"]
 
 import warnings
 from dataclasses import dataclass
@@ -18,7 +18,6 @@ from scipy.stats import wasserstein_distance as wd
 from dataeval.interop import as_numpy, to_numpy
 from dataeval.output import Output, set_metadata
 
-TNum = TypeVar("TNum", int, float)
 DISCRETE_MIN_WD = 0.054
 CONTINUOUS_MIN_SAMPLE_SIZE = 20
 
@@ -146,9 +145,7 @@ def _flatten_dict_inner(
     return items, size
 
 
-def _flatten_dict(
-    d: Mapping[str, Any], sep: str, ignore_lists: bool, fully_qualified: bool
-) -> tuple[dict[str, Any], int]:
+def flatten(d: Mapping[str, Any], sep: str, ignore_lists: bool, fully_qualified: bool) -> tuple[dict[str, Any], int]:
     """
     Flattens a dictionary and converts values to numeric values when possible.
 
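With the rename, the helper joins the public `__all__`; a sketch with a hypothetical record:

    from dataeval.utils.metadata import flatten

    record = {"image": {"width": 640, "height": 480}, "tags": ["a", "b"]}
    flat, size = flatten(record, sep="_", ignore_lists=False, fully_qualified=False)
    # flat: flattened key/value dict; size: length of lists detected while expanding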
@@ -161,12 +158,12 @@
     ignore_lists : bool
         Option to skip expanding lists within metadata
     fully_qualified : bool
-        Option to return dictionary keys full qualified instead of minimized
+        Option to return dictionary keys full qualified instead of reduced
 
     Returns
     -------
-    dict[str, Any]
-        A flattened dictionary
+    tuple[dict[str, Any], int]
+        A tuple of the flattened dictionary and the length of detected lists in metadata
     """
     expanded, size = _flatten_dict_inner(d, parent_keys=(), nested=ignore_lists)
 
@@ -260,9 +257,7 @@ def merge(
 
     image_repeats = np.zeros(len(dicts))
     for i, d in enumerate(dicts):
-        flattened, image_repeats[i] = _flatten_dict(
-            d, sep="_", ignore_lists=ignore_lists, fully_qualified=fully_qualified
-        )
+        flattened, image_repeats[i] = flatten(d, sep="_", ignore_lists=ignore_lists, fully_qualified=fully_qualified)
         isect = isect.intersection(flattened.keys()) if isect else set(flattened.keys())
         union = union.union(flattened.keys())
         for k, v in flattened.items():
@@ -296,7 +291,7 @@ def merge(
 @dataclass(frozen=True)
 class Metadata(Output):
     """
-    Dataclass containing binned metadata from the :func:`preprocess` function
+    Dataclass containing binned metadata from the :func:`preprocess` function.
 
     Attributes
     ----------
@@ -329,7 +324,7 @@ class Metadata(Output):
 def preprocess(
     raw_metadata: Iterable[Mapping[str, Any]],
     class_labels: ArrayLike | str,
-    continuous_factor_bins: Mapping[str, int | list[tuple[TNum, TNum]]] | None = None,
+    continuous_factor_bins: Mapping[str, int | Iterable[float]] | None = None,
     auto_bin_method: Literal["uniform_width", "uniform_count", "clusters"] = "uniform_width",
     exclude: Iterable[str] | None = None,
 ) -> Metadata:
@@ -348,8 +343,9 @@
     class_labels : ArrayLike or string
         If arraylike, expects the labels for each image (image classification) or each object (object detection).
         If the labels are included in the metadata dictionary, pass in the key value.
-    continuous_factor_bins : Mapping[str, int] or Mapping[str, list[tuple[TNum, TNum]]] or None, default None
-        User provided dictionary specifying how to bin the continuous metadata factors
+    continuous_factor_bins : Mapping[str, int or Iterable[float]] or None, default None
+        User provided dictionary specifying how to bin the continuous metadata factors where the value is either
+        an int to represent the number of bins, or a list of floats representing the edges for each bin.
     auto_bin_method : "uniform_width" or "uniform_count" or "clusters", default "uniform_width"
         Method by which the function will automatically bin continuous metadata factors. It is recommended
         that the user provide the bins through the `continuous_factor_bins`.
@@ -364,11 +360,13 @@ def preprocess(
     # Transform metadata into single, flattened dictionary
     metadata, image_repeats = merge(raw_metadata)
 
+    continuous_factor_bins = dict(continuous_factor_bins) if continuous_factor_bins else None
+
     # Drop any excluded metadata keys
-    if exclude:
-        for k in list(metadata):
-            if k in exclude:
-                metadata.pop(k)
+    for k in exclude or ():
+        metadata.pop(k, None)
+        if continuous_factor_bins:
+            continuous_factor_bins.pop(k, None)
 
     # Get the class label array in numeric form
     class_array = as_numpy(metadata.pop(class_labels)) if isinstance(class_labels, str) else as_numpy(class_labels)
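The reworked signature and exclude handling together look roughly like this in use (keys and values are hypothetical):

    from dataeval.utils.metadata import preprocess

    raw = [{"altitude": a, "sensor": "EO"} for a in range(100)]
    md = preprocess(
        raw,
        class_labels=[i % 2 for i in range(100)],
        continuous_factor_bins={"altitude": 5},  # an int bin count or explicit float edges
        exclude=["sensor"],  # now dropped from both metadata and continuous_factor_bins
    )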
@@ -394,8 +392,8 @@
                 "but are not keys in the `metadata` dictionary. Delete these keys from `continuous_factor_bins` "
                 "or add corresponding entries to the `metadata` dictionary."
             )
-        for factor, grouping in continuous_factor_bins.items():
-            discrete_metadata[factor] = _user_defined_bin(metadata[factor], grouping)
+        for factor, bins in continuous_factor_bins.items():
+            discrete_metadata[factor] = _digitize_data(metadata[factor], bins)
             continuous_metadata[factor] = metadata[factor]
 
     # Determine category of the rest of the keys
@@ -417,7 +415,7 @@
                 "bins using the continuous_factor_bins parameter.",
                 UserWarning,
             )
            discrete_metadata[key] = _bin_data(data, auto_bin_method)
        else:
            _, discrete_metadata[key] = np.unique(data, return_inverse=True)
 
@@ -439,7 +437,7 @@
     )
 
 
-def _user_defined_bin(data: list[Any] | NDArray[Any], binning: int | list[tuple[TNum, TNum]]) -> NDArray[np.intp]:
+def _digitize_data(data: list[Any] | NDArray[Any], bins: int | Iterable[float]) -> NDArray[np.intp]:
     """
     Digitizes a list of values into a given number of bins.
 
@@ -447,8 +445,8 @@ def _user_defined_bin(data: list[Any] | NDArray[Any], binning: int | list[tuple[
     ----------
     data : list | NDArray
         The values to be digitized.
-    binning : int | list[tuple[TNum, TNum]]
-        The number of bins for the discrete values that data will be digitized into.
+    bins : int | Iterable[float]
+        The number of bins or list of bin edges for the discrete values that data will be digitized into.
 
     Returns
     -------
@@ -461,16 +459,16 @@ def _user_defined_bin(data: list[Any] | NDArray[Any], binning: int | list[tuple[
             "Encountered a data value with non-numeric type when digitizing a factor. "
             "Ensure all occurrences of continuous factors are numeric types."
         )
-    if type(binning) is int:
-        _, bin_edges = np.histogram(data, bins=binning)
+    if isinstance(bins, int):
+        _, bin_edges = np.histogram(data, bins=bins)
         bin_edges[-1] = np.inf
         bin_edges[0] = -np.inf
     else:
-        bin_edges = binning
+        bin_edges = list(bins)
     return np.digitize(data, bin_edges)
 
 
-def _binning_function(data: NDArray[Any], bin_method: str) -> NDArray[np.int_]:
+def _bin_data(data: NDArray[Any], bin_method: str) -> NDArray[np.int_]:
     """
     Bins continuous data through either equal width bins, equal amounts in each bin, or by clusters.
     """
@@ -482,19 +480,19 @@ def _binning_function(data: NDArray[Any], bin_method: str) -> NDArray[np.int_]:
         )
         bin_method = "uniform_width"
 
-    if bin_method != "clusters":
-        counts, bin_edges = np.histogram(data, bins="auto")
-        n_bins = counts.size
-        if counts[counts > 0].min() < 10:
-            for _ in range(20):
-                n_bins -= 1
-                counts, bin_edges = np.histogram(data, bins=n_bins)
-                if counts[counts > 0].min() >= 10 or n_bins < 2:
-                    break
-
-        if bin_method == "uniform_count":
-            quantiles = np.linspace(0, 100, n_bins + 1)
-            bin_edges = np.asarray(np.percentile(data, quantiles))
+    # if bin_method != "clusters":  # restore this when clusters bin_method is available
+    counts, bin_edges = np.histogram(data, bins="auto")
+    n_bins = counts.size
+    if counts[counts > 0].min() < 10:
+        counter = 20
+        while counts[counts > 0].min() < 10 and n_bins >= 2 and counter > 0:
+            counter -= 1
+            n_bins -= 1
+            counts, bin_edges = np.histogram(data, bins=n_bins)
+
+    if bin_method == "uniform_count":
+        quantiles = np.linspace(0, 100, n_bins + 1)
+        bin_edges = np.asarray(np.percentile(data, quantiles))
 
     bin_edges[0] = -np.inf  # type: ignore # until the clusters speed up is merged
     bin_edges[-1] = np.inf  # type: ignore # and the _binning_by_clusters can be uncommented
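The rewritten shrink loop is behaviorally close to the old `for`/`break` form: drop one bin at a time, at most 20 times, until every populated bin holds at least 10 samples or only one bin would remain. Extracted as a standalone, runnable sketch (hypothetical data):

    import numpy as np

    data = np.random.default_rng(0).normal(size=60)  # hypothetical continuous factor
    counts, bin_edges = np.histogram(data, bins="auto")
    n_bins, counter = counts.size, 20
    while counts[counts > 0].min() < 10 and n_bins >= 2 and counter > 0:
        counter -= 1
        n_bins -= 1
        counts, bin_edges = np.histogram(data, bins=n_bins)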