dataeval 0.81.0__py3-none-any.whl → 0.82.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86) hide show
  1. dataeval/__init__.py +1 -1
  2. dataeval/config.py +68 -11
  3. dataeval/detectors/drift/__init__.py +2 -2
  4. dataeval/detectors/drift/_base.py +8 -64
  5. dataeval/detectors/drift/_mmd.py +12 -38
  6. dataeval/detectors/drift/_torch.py +7 -7
  7. dataeval/detectors/drift/_uncertainty.py +6 -5
  8. dataeval/detectors/drift/updates.py +20 -3
  9. dataeval/detectors/linters/__init__.py +3 -2
  10. dataeval/detectors/linters/duplicates.py +14 -46
  11. dataeval/detectors/linters/outliers.py +25 -159
  12. dataeval/detectors/ood/__init__.py +1 -1
  13. dataeval/detectors/ood/ae.py +6 -5
  14. dataeval/detectors/ood/base.py +2 -2
  15. dataeval/detectors/ood/metadata_ood_mi.py +4 -6
  16. dataeval/detectors/ood/mixin.py +3 -4
  17. dataeval/detectors/ood/vae.py +3 -2
  18. dataeval/metadata/__init__.py +2 -1
  19. dataeval/metadata/_distance.py +134 -0
  20. dataeval/metadata/_ood.py +30 -49
  21. dataeval/metadata/_utils.py +44 -0
  22. dataeval/metrics/bias/__init__.py +5 -4
  23. dataeval/metrics/bias/_balance.py +17 -149
  24. dataeval/metrics/bias/_coverage.py +4 -106
  25. dataeval/metrics/bias/_diversity.py +12 -107
  26. dataeval/metrics/bias/_parity.py +7 -71
  27. dataeval/metrics/estimators/__init__.py +5 -4
  28. dataeval/metrics/estimators/_ber.py +2 -20
  29. dataeval/metrics/estimators/_clusterer.py +1 -61
  30. dataeval/metrics/estimators/_divergence.py +2 -19
  31. dataeval/metrics/estimators/_uap.py +2 -16
  32. dataeval/metrics/stats/__init__.py +15 -12
  33. dataeval/metrics/stats/_base.py +41 -128
  34. dataeval/metrics/stats/_boxratiostats.py +13 -13
  35. dataeval/metrics/stats/_dimensionstats.py +17 -58
  36. dataeval/metrics/stats/_hashstats.py +19 -35
  37. dataeval/metrics/stats/_imagestats.py +94 -0
  38. dataeval/metrics/stats/_labelstats.py +42 -121
  39. dataeval/metrics/stats/_pixelstats.py +19 -51
  40. dataeval/metrics/stats/_visualstats.py +19 -51
  41. dataeval/outputs/__init__.py +57 -0
  42. dataeval/outputs/_base.py +182 -0
  43. dataeval/outputs/_bias.py +381 -0
  44. dataeval/outputs/_drift.py +83 -0
  45. dataeval/outputs/_estimators.py +114 -0
  46. dataeval/outputs/_linters.py +186 -0
  47. dataeval/outputs/_metadata.py +54 -0
  48. dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
  49. dataeval/outputs/_stats.py +393 -0
  50. dataeval/outputs/_utils.py +44 -0
  51. dataeval/outputs/_workflows.py +364 -0
  52. dataeval/typing.py +187 -7
  53. dataeval/utils/_method.py +1 -5
  54. dataeval/utils/_plot.py +2 -2
  55. dataeval/utils/data/__init__.py +5 -1
  56. dataeval/utils/data/_dataset.py +217 -0
  57. dataeval/utils/data/_embeddings.py +12 -14
  58. dataeval/utils/data/_images.py +30 -27
  59. dataeval/utils/data/_metadata.py +28 -11
  60. dataeval/utils/data/_selection.py +25 -22
  61. dataeval/utils/data/_split.py +5 -29
  62. dataeval/utils/data/_targets.py +14 -2
  63. dataeval/utils/data/datasets/_base.py +5 -5
  64. dataeval/utils/data/datasets/_cifar10.py +1 -1
  65. dataeval/utils/data/datasets/_milco.py +1 -1
  66. dataeval/utils/data/datasets/_mnist.py +1 -1
  67. dataeval/utils/data/datasets/_ships.py +1 -1
  68. dataeval/utils/data/{_types.py → datasets/_types.py} +10 -16
  69. dataeval/utils/data/datasets/_voc.py +1 -1
  70. dataeval/utils/data/selections/_classfilter.py +4 -5
  71. dataeval/utils/data/selections/_indices.py +2 -2
  72. dataeval/utils/data/selections/_limit.py +2 -2
  73. dataeval/utils/data/selections/_reverse.py +2 -2
  74. dataeval/utils/data/selections/_shuffle.py +2 -2
  75. dataeval/utils/torch/_internal.py +5 -5
  76. dataeval/utils/torch/trainer.py +8 -8
  77. dataeval/workflows/__init__.py +2 -1
  78. dataeval/workflows/sufficiency.py +6 -342
  79. {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/METADATA +2 -2
  80. dataeval-0.82.1.dist-info/RECORD +105 -0
  81. dataeval/_output.py +0 -137
  82. dataeval/detectors/ood/metadata_ks_compare.py +0 -129
  83. dataeval/metrics/stats/_datasetstats.py +0 -198
  84. dataeval-0.81.0.dist-info/RECORD +0 -94
  85. {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/LICENSE.txt +0 -0
  86. {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/WHEEL +0 -0
@@ -2,108 +2,21 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
- import contextlib
6
5
  from collections import Counter, defaultdict
7
- from dataclasses import dataclass
8
- from typing import Any, Iterable, Mapping, TypeVar
6
+ from typing import Any, Mapping, TypeVar
9
7
 
10
8
  import numpy as np
11
9
 
12
- from dataeval._output import Output, set_metadata
13
- from dataeval.typing import ArrayLike
10
+ from dataeval.outputs import LabelStatsOutput
11
+ from dataeval.outputs._base import set_metadata
12
+ from dataeval.typing import AnnotatedDataset, ArrayLike
14
13
  from dataeval.utils._array import as_numpy
14
+ from dataeval.utils.data._metadata import Metadata
15
15
 
16
- with contextlib.suppress(ImportError):
17
- import pandas as pd
16
+ TValue = TypeVar("TValue")
18
17
 
19
18
 
20
- @dataclass(frozen=True)
21
- class LabelStatsOutput(Output):
22
- """
23
- Output class for :func:`.labelstats` stats metric.
24
-
25
- Attributes
26
- ----------
27
- label_counts_per_class : dict[str | int, int]
28
- Dictionary whose keys are the different label classes and
29
- values are total counts of each class
30
- label_counts_per_image : list[int]
31
- Number of labels per image
32
- image_counts_per_label : dict[str | int, int]
33
- Dictionary whose keys are the different label classes and
34
- values are total counts of each image the class is present in
35
- image_indices_per_label : dict[str | int, list]
36
- Dictionary whose keys are the different label classes and
37
- values are lists containing the images that have that label
38
- image_count : int
39
- Total number of images present
40
- class_count : int
41
- Total number of classes present
42
- label_count : int
43
- Total number of labels present
44
- """
45
-
46
- label_counts_per_class: dict[str | int, int]
47
- label_counts_per_image: list[int]
48
- image_counts_per_label: dict[str | int, int]
49
- image_indices_per_label: dict[str | int, list[int]]
50
- image_count: int
51
- class_count: int
52
- label_count: int
53
-
54
- def to_table(self) -> str:
55
- max_char = max(len(key) if isinstance(key, str) else key // 10 + 1 for key in self.label_counts_per_class)
56
- max_char = max(max_char, 5)
57
- max_label = max(list(self.label_counts_per_class.values()))
58
- max_img = max(list(self.image_counts_per_label.values()))
59
- max_num = int(np.ceil(np.log10(max(max_label, max_img))))
60
- max_num = max(max_num, 11)
61
-
62
- # Display basic counts
63
- table_str = f"Class Count: {self.class_count}\n"
64
- table_str += f"Label Count: {self.label_count}\n"
65
- table_str += f"Average # Labels per Image: {round(np.mean(self.label_counts_per_image), 2)}\n"
66
- table_str += "--------------------------------------\n"
67
-
68
- # Display counts per class
69
- table_str += f"{'Label':>{max_char}}: Total Count - Image Count\n"
70
- for cls in self.label_counts_per_class:
71
- table_str += f"{cls:>{max_char}}: {self.label_counts_per_class[cls]:^{max_num}} "
72
- table_str += f"- {self.image_counts_per_label[cls]:^{max_num}}\n"
73
-
74
- return table_str
75
-
76
- def to_dataframe(self) -> pd.DataFrame:
77
- import pandas as pd
78
-
79
- class_list = []
80
- total_count = []
81
- image_count = []
82
- for cls in self.label_counts_per_class:
83
- class_list.append(cls)
84
- total_count.append(self.label_counts_per_class[cls])
85
- image_count.append(self.image_counts_per_label[cls])
86
-
87
- return pd.DataFrame(
88
- {
89
- "Label": class_list,
90
- "Total Count": total_count,
91
- "Image Count": image_count,
92
- }
93
- )
94
-
95
-
96
- TKey = TypeVar("TKey", int, str)
97
-
98
-
99
- def sort(d: Mapping[TKey, Any]) -> dict[TKey, Any]:
100
- """
101
- Sort mappings by key in increasing order
102
- """
103
- return dict(sorted(d.items(), key=lambda x: x[0]))
104
-
105
-
106
- def _ensure_2d(labels: Iterable[ArrayLike]) -> Iterable[ArrayLike]:
19
+ def _ensure_2d(labels: ArrayLike) -> ArrayLike:
107
20
  if isinstance(labels, np.ndarray):
108
21
  return labels[:, None]
109
22
  else:
@@ -116,7 +29,7 @@ def _get_list_depth(lst):
116
29
  return 0
117
30
 
118
31
 
119
- def _check_labels_dimension(labels: Iterable[ArrayLike]) -> Iterable[ArrayLike]:
32
+ def _check_labels_dimension(labels: ArrayLike) -> ArrayLike:
120
33
  # Check for nested lists beyond 2 levels
121
34
 
122
35
  if isinstance(labels, np.ndarray):
@@ -138,10 +51,12 @@ def _check_labels_dimension(labels: Iterable[ArrayLike]) -> Iterable[ArrayLike]:
138
51
  raise TypeError("Labels must be either a NumPy array or a list.")
139
52
 
140
53
 
54
+ def _sort_to_list(d: Mapping[int, TValue]) -> list[TValue]:
55
+ return [v for _, v in sorted(d.items())]
56
+
57
+
141
58
  @set_metadata
142
- def labelstats(
143
- labels: Iterable[ArrayLike],
144
- ) -> LabelStatsOutput:
59
+ def labelstats(dataset: Metadata | AnnotatedDataset[Any]) -> LabelStatsOutput:
145
60
  """
146
61
  Calculates :term:`statistics<Statistics>` for data labels.
147
62
 
@@ -150,40 +65,45 @@ def labelstats(
150
65
 
151
66
  Parameters
152
67
  ----------
153
- labels : ArrayLike, shape - [label] | [[label]] or (N,M) | (N,)
154
- Lists or :term:`NumPy` array of labels.
155
- A set of lists where each list contains all labels per image -
156
- (e.g. [[label1, label2], [label2], [label1, label3]] or [label1, label2, label1, label3]).
157
- If a numpy array, N is the number of images, M is the number of labels per image.
68
+ dataset : Metadata or ImageClassificationDataset or ObjectDetect
158
69
 
159
70
  Returns
160
71
  -------
161
72
  LabelStatsOutput
162
- A dictionary-like object containing the computed counting metrics for the labels.
73
+ A dataclass containing the computed counting metrics for the labels.
163
74
 
164
75
  Examples
165
76
  --------
166
- Calculating the :term:`statistics<Statistics>` on labels for a set of data
167
-
168
- >>> stats = labelstats(labels)
169
- >>> stats.label_counts_per_class
170
- {'chicken': 12, 'cow': 5, 'horse': 4, 'pig': 7, 'sheep': 4}
171
- >>> stats.label_counts_per_image
172
- [3, 3, 5, 3, 2, 5, 5, 2, 2, 2]
173
- >>> stats.image_counts_per_label
174
- {'chicken': 8, 'cow': 4, 'horse': 4, 'pig': 7, 'sheep': 4}
175
- >>> (stats.image_count, stats.class_count, stats.label_count)
176
- (10, 5, 32)
77
+ Calculate basic :term:`statistics<Statistics>` on labels for a dataset.
78
+
79
+ >>> from dataeval.utils.data import Metadata
80
+ >>> stats = labelstats(Metadata(dataset))
81
+ >>> print(stats.to_table())
82
+ Class Count: 5
83
+ Label Count: 15
84
+ Average # Labels per Image: 1.88
85
+ --------------------------------------
86
+ Label: Total Count - Image Count
87
+ horse: 2 - 2
88
+ cow: 4 - 3
89
+ sheep: 2 - 2
90
+ pig: 2 - 2
91
+ chicken: 5 - 5
177
92
  """
178
- label_counts = Counter()
179
- image_counts = Counter()
93
+ dataset = Metadata(dataset) if isinstance(dataset, AnnotatedDataset) else dataset
94
+
95
+ label_counts: Counter[int] = Counter()
96
+ image_counts: Counter[int] = Counter()
180
97
  index_location = defaultdict(list[int])
181
98
  label_per_image: list[int] = []
182
99
 
100
+ index2label = dict(enumerate(dataset.class_names))
101
+ labels = [target.labels.tolist() for target in dataset.targets]
102
+
183
103
  labels_2d = _check_labels_dimension(labels)
184
104
 
185
105
  for i, group in enumerate(labels_2d):
186
- group = as_numpy(group)
106
+ group = as_numpy(group).tolist()
187
107
 
188
108
  # Count occurrences of each label in all sublists
189
109
  label_counts.update(group)
@@ -200,11 +120,12 @@ def labelstats(
200
120
  index_location[item].append(i)
201
121
 
202
122
  return LabelStatsOutput(
203
- label_counts_per_class=sort(label_counts),
123
+ label_counts_per_class=_sort_to_list(label_counts),
204
124
  label_counts_per_image=label_per_image,
205
- image_counts_per_label=sort(image_counts),
206
- image_indices_per_label=sort(index_location),
125
+ image_counts_per_class=_sort_to_list(image_counts),
126
+ image_indices_per_class=_sort_to_list(index_location),
207
127
  image_count=len(label_per_image),
208
128
  class_count=len(label_counts),
209
129
  label_count=sum(label_counts.values()),
130
+ class_names=list(index2label.values()),
210
131
  )
@@ -2,50 +2,15 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
- from dataclasses import dataclass
6
- from typing import Any, Callable, Iterable
5
+ from typing import Any, Callable
7
6
 
8
7
  import numpy as np
9
- from numpy.typing import NDArray
10
8
  from scipy.stats import entropy, kurtosis, skew
11
9
 
12
- from dataeval._output import set_metadata
13
- from dataeval.metrics.stats._base import BaseStatsOutput, HistogramPlotMixin, StatsProcessor, run_stats
14
- from dataeval.typing import ArrayLike
15
-
16
-
17
- @dataclass(frozen=True)
18
- class PixelStatsOutput(BaseStatsOutput, HistogramPlotMixin):
19
- """
20
- Output class for :func:`.pixelstats` stats metric.
21
-
22
- Attributes
23
- ----------
24
- mean : NDArray[np.float16]
25
- Mean of the pixel values of the images
26
- std : NDArray[np.float16]
27
- Standard deviation of the pixel values of the images
28
- var : NDArray[np.float16]
29
- :term:`Variance` of the pixel values of the images
30
- skew : NDArray[np.float16]
31
- Skew of the pixel values of the images
32
- kurtosis : NDArray[np.float16]
33
- Kurtosis of the pixel values of the images
34
- histogram : NDArray[np.uint32]
35
- Histogram of the pixel values of the images across 256 bins scaled between 0 and 1
36
- entropy : NDArray[np.float16]
37
- Entropy of the pixel values of the images
38
- """
39
-
40
- mean: NDArray[np.float16]
41
- std: NDArray[np.float16]
42
- var: NDArray[np.float16]
43
- skew: NDArray[np.float16]
44
- kurtosis: NDArray[np.float16]
45
- histogram: NDArray[np.uint32]
46
- entropy: NDArray[np.float16]
47
-
48
- _excluded_keys = ["histogram"]
10
+ from dataeval.metrics.stats._base import StatsProcessor, run_stats
11
+ from dataeval.outputs import PixelStatsOutput
12
+ from dataeval.outputs._base import set_metadata
13
+ from dataeval.typing import ArrayLike, Dataset
49
14
 
50
15
 
51
16
  class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
@@ -72,8 +37,9 @@ class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
72
37
 
73
38
  @set_metadata
74
39
  def pixelstats(
75
- images: Iterable[ArrayLike],
76
- bboxes: Iterable[ArrayLike] | None = None,
40
+ dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
41
+ *,
42
+ per_box: bool = False,
77
43
  per_channel: bool = False,
78
44
  ) -> PixelStatsOutput:
79
45
  """
@@ -84,10 +50,12 @@ def pixelstats(
84
50
 
85
51
  Parameters
86
52
  ----------
87
- images : Iterable[ArrayLike]
88
- Images to perform calculations on
89
- bboxes : Iterable[ArrayLike] or None
90
- Bounding boxes in `xyxy` format for each image to perform calculations
53
+ dataset : Dataset
54
+ Dataset to perform calculations on.
55
+ per_box : bool, default False
56
+ If True, perform calculations on each bounding box.
57
+ per_channel : bool, default False
58
+ If True, perform calculations on each channel.
91
59
 
92
60
  Returns
93
61
  -------
@@ -107,12 +75,12 @@ def pixelstats(
107
75
 
108
76
  Examples
109
77
  --------
110
- Calculating the statistics on the images, whose shape is (C, H, W)
78
+ Calculate the pixel statistics of a dataset of 8 images, whose shape is (C, H, W).
111
79
 
112
- >>> results = pixelstats(stats_images)
80
+ >>> results = pixelstats(dataset)
113
81
  >>> print(results.mean)
114
- [0.2903 0.2108 0.397 0.596 0.743 ]
82
+ [0.181 0.132 0.248 0.373 0.464 0.613 0.734 0.854]
115
83
  >>> print(results.entropy)
116
- [4.99 2.371 1.179 2.406 0.668]
84
+ [4.527 1.883 0.811 1.883 0.298 1.883 1.883 1.883]
117
85
  """
118
- return run_stats(images, bboxes, per_channel, [PixelStatsProcessor])[0]
86
+ return run_stats(dataset, per_box, per_channel, [PixelStatsProcessor])[0]
@@ -2,54 +2,19 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
- from dataclasses import dataclass
6
- from typing import Any, Callable, Iterable
5
+ from typing import Any, Callable
7
6
 
8
7
  import numpy as np
9
- from numpy.typing import NDArray
10
8
 
11
- from dataeval._output import set_metadata
12
- from dataeval.metrics.stats._base import BaseStatsOutput, HistogramPlotMixin, StatsProcessor, run_stats
13
- from dataeval.typing import ArrayLike
9
+ from dataeval.metrics.stats._base import StatsProcessor, run_stats
10
+ from dataeval.outputs import VisualStatsOutput
11
+ from dataeval.outputs._base import set_metadata
12
+ from dataeval.typing import ArrayLike, Dataset
14
13
  from dataeval.utils._image import edge_filter
15
14
 
16
15
  QUARTILES = (0, 25, 50, 75, 100)
17
16
 
18
17
 
19
- @dataclass(frozen=True)
20
- class VisualStatsOutput(BaseStatsOutput, HistogramPlotMixin):
21
- """
22
- Output class for :func:`.visualstats` stats metric.
23
-
24
- Attributes
25
- ----------
26
- brightness : NDArray[np.float16]
27
- Brightness of the images
28
- contrast : NDArray[np.float16]
29
- Image contrast ratio
30
- darkness : NDArray[np.float16]
31
- Darkness of the images
32
- missing : NDArray[np.float16]
33
- Percentage of the images with missing pixels
34
- sharpness : NDArray[np.float16]
35
- Sharpness of the images
36
- zeros : NDArray[np.float16]
37
- Percentage of the images with zero value pixels
38
- percentiles : NDArray[np.float16]
39
- Percentiles of the pixel values of the images with quartiles of (0, 25, 50, 75, 100)
40
- """
41
-
42
- brightness: NDArray[np.float16]
43
- contrast: NDArray[np.float16]
44
- darkness: NDArray[np.float16]
45
- missing: NDArray[np.float16]
46
- sharpness: NDArray[np.float16]
47
- zeros: NDArray[np.float16]
48
- percentiles: NDArray[np.float16]
49
-
50
- _excluded_keys = ["percentiles"]
51
-
52
-
53
18
  class VisualStatsProcessor(StatsProcessor[VisualStatsOutput]):
54
19
  output_class: type = VisualStatsOutput
55
20
  image_function_map: dict[str, Callable[[StatsProcessor[VisualStatsOutput]], Any]] = {
@@ -79,8 +44,9 @@ class VisualStatsProcessor(StatsProcessor[VisualStatsOutput]):
79
44
 
80
45
  @set_metadata
81
46
  def visualstats(
82
- images: Iterable[ArrayLike],
83
- bboxes: Iterable[ArrayLike] | None = None,
47
+ dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
48
+ *,
49
+ per_box: bool = False,
84
50
  per_channel: bool = False,
85
51
  ) -> VisualStatsOutput:
86
52
  """
@@ -91,10 +57,12 @@ def visualstats(
91
57
 
92
58
  Parameters
93
59
  ----------
94
- images : Iterable[ArrayLike]
95
- Images to perform calculations on
96
- bboxes : Iterable[ArrayLike] or None
97
- Bounding boxes in `xyxy` format for each image to perform calculations on
60
+ dataset : Dataset
61
+ Dataset to perform calculations on.
62
+ per_box : bool, default False
63
+ If True, perform calculations on each bounding box.
64
+ per_channel : bool, default False
65
+ If True, perform calculations on each channel.
98
66
 
99
67
  Returns
100
68
  -------
@@ -113,12 +81,12 @@ def visualstats(
113
81
 
114
82
  Examples
115
83
  --------
116
- Calculating the :term:`statistics<Statistics>` on the images, whose shape is (C, H, W)
84
+ Calculate the visual statistics of a dataset of 8 images, whose shape is (C, H, W).
117
85
 
118
- >>> results = visualstats(stats_images)
86
+ >>> results = visualstats(dataset)
119
87
  >>> print(results.brightness)
120
- [0.1353 0.2085 0.4143 0.6084 0.8135]
88
+ [0.084 0.13 0.259 0.38 0.508 0.63 0.755 0.88 ]
121
89
  >>> print(results.contrast)
122
- [2.04 1.331 1.261 1.279 1.253]
90
+ [2.04 1.331 1.261 1.279 1.253 1.268 1.265 1.263]
123
91
  """
124
- return run_stats(images, bboxes, per_channel, [VisualStatsProcessor])[0]
92
+ return run_stats(dataset, per_box, per_channel, [VisualStatsProcessor])[0]
@@ -0,0 +1,57 @@
1
+ """
2
+ Output classes for DataEval to store function and method outputs
3
+ as well as runtime metadata for reproducibility and logging.
4
+ """
5
+
6
+ from ._base import ExecutionMetadata
7
+ from ._bias import BalanceOutput, CoverageOutput, DiversityOutput, LabelParityOutput, ParityOutput
8
+ from ._drift import DriftMMDOutput, DriftOutput
9
+ from ._estimators import BEROutput, ClustererOutput, DivergenceOutput, UAPOutput
10
+ from ._linters import DuplicatesOutput, OutliersOutput
11
+ from ._metadata import MetadataDistanceOutput, MetadataDistanceValues, MostDeviatedFactorsOutput
12
+ from ._ood import OODOutput, OODScoreOutput
13
+ from ._stats import (
14
+ ChannelStatsOutput,
15
+ DimensionStatsOutput,
16
+ HashStatsOutput,
17
+ ImageStatsOutput,
18
+ LabelStatsOutput,
19
+ PixelStatsOutput,
20
+ SourceIndex,
21
+ VisualStatsOutput,
22
+ )
23
+ from ._utils import SplitDatasetOutput, TrainValSplit
24
+ from ._workflows import SufficiencyOutput
25
+
26
+ __all__ = [
27
+ "BEROutput",
28
+ "BalanceOutput",
29
+ "ChannelStatsOutput",
30
+ "ClustererOutput",
31
+ "CoverageOutput",
32
+ "DimensionStatsOutput",
33
+ "DivergenceOutput",
34
+ "DiversityOutput",
35
+ "DriftMMDOutput",
36
+ "DriftOutput",
37
+ "DuplicatesOutput",
38
+ "ExecutionMetadata",
39
+ "HashStatsOutput",
40
+ "ImageStatsOutput",
41
+ "LabelParityOutput",
42
+ "LabelStatsOutput",
43
+ "MetadataDistanceOutput",
44
+ "MetadataDistanceValues",
45
+ "MostDeviatedFactorsOutput",
46
+ "OODOutput",
47
+ "OODScoreOutput",
48
+ "OutliersOutput",
49
+ "ParityOutput",
50
+ "PixelStatsOutput",
51
+ "SourceIndex",
52
+ "SplitDatasetOutput",
53
+ "SufficiencyOutput",
54
+ "TrainValSplit",
55
+ "UAPOutput",
56
+ "VisualStatsOutput",
57
+ ]
@@ -0,0 +1,182 @@
1
+ from __future__ import annotations
2
+
3
+ __all__ = []
4
+
5
+ import inspect
6
+ import logging
7
+ from collections.abc import Collection, Mapping, Sequence
8
+ from dataclasses import dataclass
9
+ from datetime import datetime, timezone
10
+ from functools import partial, wraps
11
+ from typing import Any, Callable, Generic, Iterator, TypeVar, overload
12
+
13
+ import numpy as np
14
+ from typing_extensions import ParamSpec
15
+
16
+ from dataeval import __version__
17
+
18
+
19
+ @dataclass(frozen=True)
20
+ class ExecutionMetadata:
21
+ """
22
+ Metadata about the execution of the function or method for the Output class.
23
+
24
+ Attributes
25
+ ----------
26
+ name: str
27
+ Name of the function or method
28
+ execution_time: datetime
29
+ Time of execution
30
+ execution_duration: float
31
+ Duration of execution in seconds
32
+ arguments: dict[str, Any]
33
+ Arguments passed to the function or method
34
+ state: dict[str, Any]
35
+ State attributes of the executing class
36
+ version: str
37
+ Version of DataEval
38
+ """
39
+
40
+ name: str
41
+ execution_time: datetime
42
+ execution_duration: float
43
+ arguments: dict[str, Any]
44
+ state: dict[str, Any]
45
+ version: str
46
+
47
+ @classmethod
48
+ def empty(cls) -> ExecutionMetadata:
49
+ return ExecutionMetadata(
50
+ name="",
51
+ execution_time=datetime.min,
52
+ execution_duration=0.0,
53
+ arguments={},
54
+ state={},
55
+ version=__version__,
56
+ )
57
+
58
+
59
+ T = TypeVar("T", covariant=True)
60
+
61
+
62
+ class GenericOutput(Generic[T]):
63
+ _meta: ExecutionMetadata | None = None
64
+
65
+ def data(self) -> T: ...
66
+ def meta(self) -> ExecutionMetadata:
67
+ """
68
+ Metadata about the execution of the function or method for the Output class.
69
+ """
70
+ return self._meta or ExecutionMetadata.empty()
71
+
72
+
73
+ class Output(GenericOutput[dict[str, Any]]):
74
+ def data(self) -> dict[str, Any]:
75
+ return {k: v for k, v in self.__dict__.items() if k != "_meta"}
76
+
77
+ def __repr__(self) -> str:
78
+ return str(self)
79
+
80
+ def __str__(self) -> str:
81
+ return f"{self.__class__.__name__}({', '.join([f'{k}={v}' for k, v in self.data().items()])})"
82
+
83
+
84
+ class BaseCollectionMixin(Collection[Any]):
85
+ __slots__ = ["_data"]
86
+
87
+ def data(self) -> Any:
88
+ return self._data
89
+
90
+ def __len__(self) -> int:
91
+ return len(self._data)
92
+
93
+ def __repr__(self) -> str:
94
+ return f"{self.__class__.__name__}({repr(self._data)})"
95
+
96
+ def __str__(self) -> str:
97
+ return str(self._data)
98
+
99
+
100
+ TKey = TypeVar("TKey", str, int, float, set)
101
+ TValue = TypeVar("TValue")
102
+
103
+
104
+ class MappingOutput(Mapping[TKey, TValue], BaseCollectionMixin, GenericOutput[Mapping[TKey, TValue]]):
105
+ def __init__(self, data: Mapping[TKey, TValue]):
106
+ self._data = data
107
+
108
+ def __getitem__(self, key: TKey) -> TValue:
109
+ return self._data[key]
110
+
111
+ def __iter__(self) -> Iterator[TKey]:
112
+ return iter(self._data)
113
+
114
+
115
+ class SequenceOutput(Sequence[TValue], BaseCollectionMixin, GenericOutput[Sequence[TValue]]):
116
+ def __init__(self, data: Sequence[TValue]):
117
+ self._data = data
118
+
119
+ @overload
120
+ def __getitem__(self, index: int) -> TValue: ...
121
+ @overload
122
+ def __getitem__(self, index: slice) -> Sequence[TValue]: ...
123
+
124
+ def __getitem__(self, index: int | slice) -> TValue | Sequence[TValue]:
125
+ return self._data[index]
126
+
127
+ def __iter__(self) -> Iterator[TValue]:
128
+ return iter(self._data)
129
+
130
+
131
+ P = ParamSpec("P")
132
+ R = TypeVar("R", bound=GenericOutput)
133
+
134
+
135
+ def set_metadata(fn: Callable[P, R] | None = None, *, state: list[str] | None = None) -> Callable[P, R]:
136
+ """Decorator to stamp Output classes with runtime metadata"""
137
+
138
+ if fn is None:
139
+ return partial(set_metadata, state=state) # type: ignore
140
+
141
+ @wraps(fn)
142
+ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
143
+ def fmt(v):
144
+ if np.isscalar(v):
145
+ return v
146
+ if hasattr(v, "shape"):
147
+ return f"{v.__class__.__name__}: shape={getattr(v, 'shape')}"
148
+ if hasattr(v, "__len__"):
149
+ return f"{v.__class__.__name__}: len={len(v)}"
150
+ return f"{v.__class__.__name__}"
151
+
152
+ # Collect function metadata
153
+ # set all params with defaults then update params with mapped arguments and explicit keyword args
154
+ fn_params = inspect.signature(fn).parameters
155
+ arguments = {k: None if v.default is inspect.Parameter.empty else v.default for k, v in fn_params.items()}
156
+ arguments.update(zip(fn_params, args))
157
+ arguments.update(kwargs)
158
+ arguments = {k: fmt(v) for k, v in arguments.items()}
159
+ is_method = "self" in arguments
160
+ state_attrs = {k: fmt(getattr(args[0], k)) for k in state or []} if is_method else {}
161
+ module = args[0].__class__.__module__ if is_method else fn.__module__.removeprefix("src.")
162
+ class_prefix = f".{args[0].__class__.__name__}." if is_method else "."
163
+ name = f"{module}{class_prefix}{fn.__name__}"
164
+ arguments = {k: v for k, v in arguments.items() if k != "self"}
165
+
166
+ _logger = logging.getLogger(module)
167
+ time = datetime.now(timezone.utc)
168
+ _logger.log(logging.INFO, f">>> Executing '{name}': args={arguments} state={state} <<<")
169
+
170
+ ##### EXECUTE FUNCTION #####
171
+ result = fn(*args, **kwargs)
172
+ ############################
173
+
174
+ duration = (datetime.now(timezone.utc) - time).total_seconds()
175
+ _logger.log(logging.INFO, f">>> Completed '{name}': args={arguments} state={state} duration={duration} <<<")
176
+
177
+ # Update output with recorded metadata
178
+ metadata = ExecutionMetadata(name, time, duration, arguments, state_attrs, __version__)
179
+ object.__setattr__(result, "_meta", metadata)
180
+ return result
181
+
182
+ return wrapper