dataeval 0.72.1__py3-none-any.whl → 0.73.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. dataeval/__init__.py +4 -4
  2. dataeval/detectors/__init__.py +4 -3
  3. dataeval/detectors/drift/__init__.py +9 -10
  4. dataeval/{_internal/detectors → detectors}/drift/base.py +39 -91
  5. dataeval/{_internal/detectors → detectors}/drift/cvm.py +4 -3
  6. dataeval/{_internal/detectors → detectors}/drift/ks.py +4 -3
  7. dataeval/{_internal/detectors → detectors}/drift/mmd.py +23 -25
  8. dataeval/{_internal/detectors → detectors}/drift/torch.py +13 -11
  9. dataeval/{_internal/detectors → detectors}/drift/uncertainty.py +7 -5
  10. dataeval/detectors/drift/updates.py +61 -0
  11. dataeval/detectors/linters/__init__.py +3 -3
  12. dataeval/{_internal/detectors → detectors/linters}/clusterer.py +41 -39
  13. dataeval/{_internal/detectors → detectors/linters}/duplicates.py +19 -9
  14. dataeval/{_internal/detectors → detectors/linters}/merged_stats.py +3 -1
  15. dataeval/{_internal/detectors → detectors/linters}/outliers.py +14 -21
  16. dataeval/detectors/ood/__init__.py +6 -6
  17. dataeval/{_internal/detectors → detectors}/ood/ae.py +20 -12
  18. dataeval/detectors/ood/aegmm.py +66 -0
  19. dataeval/{_internal/detectors → detectors}/ood/base.py +33 -21
  20. dataeval/{_internal/detectors → detectors}/ood/llr.py +43 -33
  21. dataeval/detectors/ood/metadata_ks_compare.py +99 -0
  22. dataeval/detectors/ood/metadata_least_likely.py +119 -0
  23. dataeval/detectors/ood/metadata_ood_mi.py +92 -0
  24. dataeval/{_internal/detectors → detectors}/ood/vae.py +23 -17
  25. dataeval/detectors/ood/vaegmm.py +75 -0
  26. dataeval/interop.py +56 -0
  27. dataeval/metrics/__init__.py +1 -1
  28. dataeval/metrics/bias/__init__.py +4 -4
  29. dataeval/{_internal/metrics → metrics/bias}/balance.py +75 -13
  30. dataeval/{_internal/metrics → metrics/bias}/coverage.py +41 -7
  31. dataeval/{_internal/metrics → metrics/bias}/diversity.py +75 -18
  32. dataeval/metrics/bias/metadata.py +358 -0
  33. dataeval/{_internal/metrics → metrics/bias}/parity.py +54 -44
  34. dataeval/metrics/estimators/__init__.py +3 -3
  35. dataeval/{_internal/metrics → metrics/estimators}/ber.py +25 -22
  36. dataeval/{_internal/metrics → metrics/estimators}/divergence.py +11 -12
  37. dataeval/{_internal/metrics → metrics/estimators}/uap.py +5 -3
  38. dataeval/metrics/stats/__init__.py +7 -7
  39. dataeval/{_internal/metrics → metrics}/stats/base.py +59 -35
  40. dataeval/{_internal/metrics → metrics}/stats/boxratiostats.py +18 -14
  41. dataeval/{_internal/metrics → metrics}/stats/datasetstats.py +18 -16
  42. dataeval/{_internal/metrics → metrics}/stats/dimensionstats.py +9 -7
  43. dataeval/metrics/stats/hashstats.py +156 -0
  44. dataeval/{_internal/metrics → metrics}/stats/labelstats.py +5 -3
  45. dataeval/{_internal/metrics → metrics}/stats/pixelstats.py +9 -8
  46. dataeval/{_internal/metrics → metrics}/stats/visualstats.py +10 -9
  47. dataeval/{_internal/output.py → output.py} +26 -6
  48. dataeval/utils/__init__.py +8 -3
  49. dataeval/utils/image.py +71 -0
  50. dataeval/utils/lazy.py +26 -0
  51. dataeval/utils/metadata.py +258 -0
  52. dataeval/utils/shared.py +151 -0
  53. dataeval/{_internal → utils}/split_dataset.py +98 -33
  54. dataeval/utils/tensorflow/__init__.py +7 -6
  55. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/gmm.py +8 -2
  56. dataeval/{_internal/models/tensorflow/losses.py → utils/tensorflow/_internal/loss.py} +28 -18
  57. dataeval/{_internal/models/tensorflow/pixelcnn.py → utils/tensorflow/_internal/models.py} +387 -97
  58. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/trainer.py +15 -6
  59. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/utils.py +84 -85
  60. dataeval/utils/tensorflow/loss/__init__.py +6 -2
  61. dataeval/utils/torch/__init__.py +7 -3
  62. dataeval/{_internal/models/pytorch → utils/torch}/blocks.py +19 -14
  63. dataeval/{_internal → utils/torch}/datasets.py +48 -42
  64. dataeval/utils/torch/models.py +138 -0
  65. dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py} +7 -136
  66. dataeval/{_internal → utils/torch}/utils.py +3 -1
  67. dataeval/workflows/__init__.py +1 -1
  68. dataeval/{_internal/workflows → workflows}/sufficiency.py +39 -34
  69. {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/METADATA +4 -3
  70. dataeval-0.73.0.dist-info/RECORD +73 -0
  71. dataeval/_internal/detectors/__init__.py +0 -0
  72. dataeval/_internal/detectors/drift/__init__.py +0 -0
  73. dataeval/_internal/detectors/ood/__init__.py +0 -0
  74. dataeval/_internal/detectors/ood/aegmm.py +0 -78
  75. dataeval/_internal/detectors/ood/vaegmm.py +0 -89
  76. dataeval/_internal/interop.py +0 -49
  77. dataeval/_internal/metrics/__init__.py +0 -0
  78. dataeval/_internal/metrics/stats/hashstats.py +0 -75
  79. dataeval/_internal/metrics/utils.py +0 -447
  80. dataeval/_internal/models/__init__.py +0 -0
  81. dataeval/_internal/models/pytorch/__init__.py +0 -0
  82. dataeval/_internal/models/pytorch/utils.py +0 -67
  83. dataeval/_internal/models/tensorflow/__init__.py +0 -0
  84. dataeval/_internal/models/tensorflow/autoencoder.py +0 -320
  85. dataeval/_internal/workflows/__init__.py +0 -0
  86. dataeval/detectors/drift/kernels/__init__.py +0 -10
  87. dataeval/detectors/drift/updates/__init__.py +0 -8
  88. dataeval/utils/tensorflow/models/__init__.py +0 -9
  89. dataeval/utils/tensorflow/recon/__init__.py +0 -3
  90. dataeval/utils/torch/datasets/__init__.py +0 -12
  91. dataeval/utils/torch/models/__init__.py +0 -11
  92. dataeval/utils/torch/trainer/__init__.py +0 -7
  93. dataeval-0.72.1.dist-info/RECORD +0 -81
  94. {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/LICENSE.txt +0 -0
  95. {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,258 @@
+ from __future__ import annotations
+
+ __all__ = ["merge_metadata"]
+
+ import warnings
+ from typing import Any, Iterable, Mapping, TypeVar, overload
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ T = TypeVar("T")
+
+
+ def _try_cast(v: Any, t: type[T]) -> T | None:
+     """Casts a value to a type or returns None if unable"""
+     try:
+         return t(v)  # type: ignore
+     except (TypeError, ValueError):
+         return None
+
+
+ @overload
+ def _convert_type(data: list[str]) -> list[int] | list[float] | list[str]: ...
+ @overload
+ def _convert_type(data: str) -> int | float | str: ...
+
+
+ def _convert_type(data: list[str] | str) -> list[int] | list[float] | list[str] | int | float | str:
+     """
+     Converts a value or a list of values to the simplest form possible, in preferred order of `int`,
+     `float`, or `string`.
+
+     Parameters
+     ----------
+     data : list[str] | str
+         A list of values or a single value
+
+     Returns
+     -------
+     list[int | float | str] | int | float | str
+         The same values converted to the numerical type if possible
+     """
+     if not isinstance(data, list):
+         value = _try_cast(data, float)
+         return str(data) if value is None else int(value) if value.is_integer() else value
+
+     converted = []
+     TYPE_MAP = {int: 0, float: 1, str: 2}
+     max_type = 0
+     for value in data:
+         value = _convert_type(value)
+         max_type = max(max_type, TYPE_MAP.get(type(value), 2))
+         converted.append(value)
+     for i in range(len(converted)):
+         converted[i] = list(TYPE_MAP)[max_type](converted[i])
+     return converted
+
+
+ def _get_key_indices(keys: Iterable[tuple[str, ...]]) -> dict[tuple[str, ...], int]:
+     """
+     Finds indices to minimize unique tuple keys
+
+     Parameters
+     ----------
+     keys : Iterable[tuple[str, ...]]
+         Collection of unique expanded tuple keys
+
+     Returns
+     -------
+     dict[tuple[str, ...], int]
+         Mapping of tuple keys to starting index
+     """
+     indices = {k: -1 for k in keys}
+     ks = list(keys)
+     while len(ks) > 0:
+         seen: dict[tuple[str, ...], list[tuple[str, ...]]] = {}
+         for k in ks:
+             seen.setdefault(k[indices[k] :], []).append(k)
+         ks.clear()
+         for sk in seen.values():
+             if len(sk) > 1:
+                 ks.extend(sk)
+                 for k in sk:
+                     indices[k] -= 1
+     return indices
+
+
+ def _flatten_dict_inner(
+     d: Mapping[str, Any], parent_keys: tuple[str, ...], size: int | None = None, nested: bool = False
+ ) -> tuple[dict[tuple[str, ...], Any], int | None]:
+     """
+     Recursive internal function for flattening a dictionary.
+
+     Parameters
+     ----------
+     d : dict[str, Any]
+         Dictionary to flatten
+     parent_keys : tuple[str, ...]
+         Parent keys to the current dictionary being flattened
+     size : int or None, default None
+         Tracking int for length of lists
+     nested : bool, default False
+         Tracking if inside a list
+
+     Returns
+     -------
+     tuple[dict[tuple[str, ...], Any], int | None]
+         - [0]: Dictionary of flattened values with the keys reformatted as a hierarchical tuple of strings
+         - [1]: Size, if any, of the current list of values
+     """
+     items: dict[tuple[str, ...], Any] = {}
+     for k, v in d.items():
+         new_keys: tuple[str, ...] = parent_keys + (k,)
+         if isinstance(v, dict):
+             fd, size = _flatten_dict_inner(v, new_keys, size=size, nested=nested)
+             items.update(fd)
+         elif isinstance(v, (list, tuple)):
+             if not nested and (size is None or size == len(v)):
+                 size = len(v)
+                 if all(isinstance(i, dict) for i in v):
+                     for sub_dict in v:
+                         fd, size = _flatten_dict_inner(sub_dict, new_keys, size=size, nested=True)
+                         for fk, fv in fd.items():
+                             items.setdefault(fk, []).append(fv)
+                 else:
+                     items[new_keys] = v
+             else:
+                 warnings.warn(f"Dropping nested list found in '{parent_keys + (k, )}'.")
+         else:
+             items[new_keys] = v
+     return items, size
+
+
+ def _flatten_dict(d: Mapping[str, Any], sep: str, ignore_lists: bool, fully_qualified: bool) -> dict[str, Any]:
+     """
+     Flattens a dictionary and converts values to numeric values when possible.
+
+     Parameters
+     ----------
+     d : dict[str, Any]
+         Dictionary to flatten
+     sep : str
+         String separator to use when concatenating key names
+     ignore_lists : bool
+         Option to skip expanding lists within metadata
+     fully_qualified : bool
+         Option to return dictionary keys full qualified instead of minimized
+
+     Returns
+     -------
+     dict[str, Any]
+         A flattened dictionary
+     """
+     expanded, size = _flatten_dict_inner(d, parent_keys=(), nested=ignore_lists)
+
+     output = {}
+     if fully_qualified:
+         expanded = {sep.join(k): v for k, v in expanded.items()}
+     else:
+         keys = _get_key_indices(expanded)
+         expanded = {sep.join(k[keys[k] :]): v for k, v in expanded.items()}
+     for k, v in expanded.items():
+         cv = _convert_type(v)
+         if isinstance(cv, list) and len(cv) == size:
+             output[k] = cv
+         elif not isinstance(cv, list):
+             output[k] = cv if not size else [cv] * size
+     return output
+
+
+ def _is_metadata_dict_of_dicts(metadata: Mapping) -> bool:
+     """EXPERIMENTAL: Attempt to detect if metadata is a dict of dicts"""
+     # single dict
+     if len(metadata) < 2:
+         return False
+
+     # dict of non dicts
+     keys = list(metadata)
+     if not isinstance(metadata[keys[0]], Mapping):
+         return False
+
+     # dict of dicts with matching keys
+     return set(metadata[keys[0]]) == set(metadata[keys[1]])
+
+
+ def merge_metadata(
+     metadata: Iterable[Mapping[str, Any]],
+     ignore_lists: bool = False,
+     fully_qualified: bool = False,
+     as_numpy: bool = False,
+ ) -> dict[str, list[Any]] | dict[str, NDArray[Any]]:
+     """
+     Merges a collection of metadata dictionaries into a single flattened dictionary of keys and values.
+
+     Nested dictionaries are flattened, and lists are expanded. Nested lists are dropped as the
+     expanding into multiple hierarchical trees is not supported.
+
+     Parameters
+     ----------
+     metadata : Iterable[Mapping[str, Any]]
+         Iterable collection of metadata dictionaries to flatten and merge
+     ignore_lists : bool, default False
+         Option to skip expanding lists within metadata
+     fully_qualified : bool, default False
+         Option to return dictionary keys full qualified instead of minimized
+     as_numpy : bool, default False
+         Option to return results as lists or NumPy arrays
+
+     Returns
+     -------
+     dict[str, list[Any]] | dict[str, NDArray[Any]]
+         A single dictionary containing the flattened data as lists or NumPy arrays
+
+     Note
+     ----
+     Nested lists of values and inconsistent keys are dropped in the merged metadata dictionary
+
+     Example
+     -------
+     >>> list_metadata = [{"common": 1, "target": [{"a": 1, "b": 3}, {"a": 2, "b": 4}], "source": "example"}]
+     >>> merge_metadata(list_metadata)
+     {'common': [1, 1], 'a': [1, 2], 'b': [3, 4], 'source': ['example', 'example']}
+     """
+     merged: dict[str, list[Any]] = {}
+     isect: set[str] = set()
+     union: set[str] = set()
+     keys: list[str] | None = None
+     dicts: list[Mapping[str, Any]]
+
+     # EXPERIMENTAL
+     if isinstance(metadata, Mapping) and _is_metadata_dict_of_dicts(metadata):
+         warnings.warn("Experimental processing for dict of dicts.")
+         keys = [str(k) for k in metadata]
+         dicts = list(metadata.values())
+         ignore_lists = True
+     else:
+         dicts = list(metadata)
+
+     for d in dicts:
+         flattened = _flatten_dict(d, sep="_", ignore_lists=ignore_lists, fully_qualified=fully_qualified)
+         isect = isect.intersection(flattened.keys()) if isect else set(flattened.keys())
+         union = union.union(flattened.keys())
+         for k, v in flattened.items():
+             merged.setdefault(k, []).extend(flattened[k]) if isinstance(v, list) else merged.setdefault(k, []).append(v)
+
+     if len(union) > len(isect):
+         warnings.warn(f"Inconsistent metadata keys found. Dropping {union - isect} from metadata.")
+
+     output: dict[str, Any] = {}
+
+     if keys:
+         output["keys"] = np.array(keys) if as_numpy else keys
+
+     for k in (key for key in merged if key in isect):
+         cv = _convert_type(merged[k])
+         output[k] = np.array(cv) if as_numpy else cv
+
+     return output
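
To make the behavior of the new `merge_metadata` helper concrete, here is a brief usage sketch that is not part of the diff. The input dictionaries and values are made up; the expected output follows the flattening, type-conversion, and key-minimization logic shown above, and the import path assumes the new module location `dataeval/utils/metadata.py` from the file list.

```python
from dataeval.utils.metadata import merge_metadata

# Illustrative per-sample metadata with a nested "camera" dictionary.
metadata = [
    {"id": 1, "camera": {"model": "A", "gain": 2.5}},
    {"id": 2, "camera": {"model": "B", "gain": 3.5}},
]

# Keys are minimized by default, so the unique leaf names ("model", "gain") are kept.
print(merge_metadata(metadata))
# {'id': [1, 2], 'model': ['A', 'B'], 'gain': [2.5, 3.5]}

# fully_qualified=True joins the full key path with "_" ('camera_model', 'camera_gain'),
# and as_numpy=True returns NumPy arrays instead of lists.
merged = merge_metadata(metadata, fully_qualified=True, as_numpy=True)
```

Note that `_convert_type` down-converts numeric strings and integral floats, so each merged column resolves to the simplest common type of `int`, `float`, or `str`.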
@@ -0,0 +1,151 @@
+ from __future__ import annotations
+
+ __all__ = []
+
+ import sys
+ from typing import Any, Callable, Literal, TypeVar
+
+ import numpy as np
+ from numpy.typing import ArrayLike, NDArray
+ from scipy.sparse import csr_matrix
+ from scipy.sparse.csgraph import minimum_spanning_tree as mst
+ from scipy.spatial.distance import pdist, squareform
+ from sklearn.neighbors import NearestNeighbors
+
+ if sys.version_info >= (3, 10):
+     from typing import ParamSpec
+ else:
+     from typing_extensions import ParamSpec
+
+ from dataeval.interop import as_numpy
+
+ EPSILON = 1e-5
+ HASH_SIZE = 8
+ MAX_FACTOR = 4
+
+
+ P = ParamSpec("P")
+ R = TypeVar("R")
+
+
+ def get_method(method_map: dict[str, Callable[P, R]], method: str) -> Callable[P, R]:
+     if method not in method_map:
+         raise ValueError(f"Specified method {method} is not a valid method: {method_map}.")
+     return method_map[method]
+
+
+ def flatten(array: ArrayLike) -> NDArray[Any]:
+     """
+     Flattens input array from (N, ... ) to (N, -1) where all samples N have all data in their last dimension
+
+     Parameters
+     ----------
+     X : NDArray, shape - (N, ... )
+         Input array
+
+     Returns
+     -------
+     NDArray, shape - (N, -1)
+     """
+     nparr = as_numpy(array)
+     return nparr.reshape((nparr.shape[0], -1))
+
+
+ def minimum_spanning_tree(X: NDArray[Any]) -> Any:
+     """
+     Returns the minimum spanning tree from a :term:`NumPy` image array.
+
+     Parameters
+     ----------
+     X : NDArray
+         Numpy image array
+
+     Returns
+     -------
+     Data representing the minimum spanning tree
+     """
+     # All features belong on second dimension
+     X = flatten(X)
+     # We add a small constant to the distance matrix to ensure scipy interprets
+     # the input graph as fully-connected.
+     dense_eudist = squareform(pdist(X)) + EPSILON
+     eudist_csr = csr_matrix(dense_eudist)
+     return mst(eudist_csr)
+
+
+ def get_classes_counts(labels: NDArray[np.int_]) -> tuple[int, int]:
+     """
+     Returns the classes and counts of from an array of labels
+
+     Parameters
+     ----------
+     label : NDArray
+         Numpy labels array
+
+     Returns
+     -------
+     Classes and counts
+
+     Raises
+     ------
+     ValueError
+         If the number of unique classes is less than 2
+     """
+     classes, counts = np.unique(labels, return_counts=True)
+     M = len(classes)
+     if M < 2:
+         raise ValueError("Label vector contains less than 2 classes!")
+     N = np.sum(counts).astype(int)
+     return M, N
+
+
+ def compute_neighbors(
+     A: NDArray[Any],
+     B: NDArray[Any],
+     k: int = 1,
+     algorithm: Literal["auto", "ball_tree", "kd_tree"] = "auto",
+ ) -> NDArray[Any]:
+     """
+     For each sample in A, compute the nearest neighbor in B
+
+     Parameters
+     ----------
+     A, B : NDArray
+         The n_samples and n_features respectively
+     k : int
+         The number of neighbors to find
+     algorithm : Literal
+         Tree method for nearest neighbor (auto, ball_tree or kd_tree)
+
+     Note
+     ----
+     Do not use kd_tree if n_features > 20
+
+     Returns
+     -------
+     List:
+         Closest points to each point in A and B
+
+     Raises
+     ------
+     ValueError
+         If algorithm is not "auto", "ball_tree", or "kd_tree"
+
+     See Also
+     --------
+     sklearn.neighbors.NearestNeighbors
+     """
+
+     if k < 1:
+         raise ValueError("k must be >= 1")
+     if algorithm not in ["auto", "ball_tree", "kd_tree"]:
+         raise ValueError("Algorithm must be 'auto', 'ball_tree', or 'kd_tree'")
+
+     A = flatten(A)
+     B = flatten(B)
+
+     nbrs = NearestNeighbors(n_neighbors=k + 1, algorithm=algorithm).fit(B)
+     nns = nbrs.kneighbors(A)[1]
+     nns = nns[:, 1:].squeeze()
+
+     return nns
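
The new `dataeval/utils/shared.py` module gathers internal helpers used by several metrics (its `__all__` is empty, so these are not part of the public API). Below is a minimal sketch of how the utilities compose, not taken from the package; the synthetic data and shapes are illustrative only.

```python
import numpy as np

from dataeval.utils.shared import compute_neighbors, flatten, minimum_spanning_tree

rng = np.random.default_rng(0)
images = rng.random((16, 3, 8, 8))  # N samples with arbitrary trailing dimensions

X = flatten(images)                 # reshaped to (16, 192): one row per sample
tree = minimum_spanning_tree(X)     # sparse matrix of MST edge weights over pairwise distances

# Nearest neighbors within the same set: the implementation queries k+1 neighbors and
# drops the first column, so each sample's self-match is excluded here.
neighbors = compute_neighbors(X, X, k=1)
```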
@@ -1,20 +1,26 @@
  from __future__ import annotations
 
+ __all__ = ["split_dataset"]
+
  import warnings
+ from typing import Any
 
  import numpy as np
+ from numpy.typing import NDArray
  from sklearn.cluster import KMeans
  from sklearn.metrics import silhouette_score
  from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, StratifiedKFold
  from sklearn.utils.multiclass import type_of_target
 
 
- def check_args(num_folds: int = 1, test_frac: float | None = None, val_frac: float | None = None):
-     """Check input arguments to ensure unambiguous splitting arguments are passed.
+ def validate_test_val(num_folds: int, test_frac: float | None, val_frac: float | None) -> tuple[float, float]:
+     """Check input fractions to ensure unambiguous splitting arguments are passed return calculated
+     test and validation fractions.
+
 
      Parameters
      ----------
-     num_folds : int, default 1
+     num_folds : int
          number of [train, val] cross-validation folds to generate
      test_frac : float, optional
          If specified, also generate a test set containing (test_frac*100)% of the data
@@ -36,19 +42,23 @@ def check_args(num_folds: int = 1, test_frac: float | None = None, val_frac: flo
 
      Returns
      -------
-     None
+     tuple[float, float]
+         Tuple of the validated and calculated values as appropriate for test and validation fractions
      """
      if (num_folds > 1) and (val_frac is not None):
          raise ValueError("If specifying val_frac, num_folds must be None or 1")
      if (num_folds == 1) and (val_frac is None):
-         raise UnboundLocalError("If num_folds is None or 1, must assign a value to val_frac")
+         raise ValueError("If num_folds is None or 1, must assign a value to val_frac")
      t_frac = 0.0 if test_frac is None else test_frac
      v_frac = 1.0 / num_folds * (1.0 - t_frac) if val_frac is None else val_frac * (1.0 - t_frac)
      if (t_frac + v_frac) >= 1.0:
          raise ValueError(f"val_frac + test_frac must be less that 1.0, currently {v_frac+t_frac}")
+     return t_frac, v_frac
 
 
- def check_labels(labels: list | np.ndarray, total_partitions: int):
+ def check_labels(
+     labels: list[int] | NDArray[np.int_], total_partitions: int
+ ) -> tuple[NDArray[np.int_], NDArray[np.int_]]:
      """Check to make sure there are more input data than the total number of partitions requested
      Also converts labels to a numpy array, if it isn't already
 
@@ -88,7 +98,7 @@ def check_labels(labels: list | np.ndarray, total_partitions: int):
      return index, labels
 
 
- def check_stratifiable(labels: np.ndarray, total_partitions: int):
+ def check_stratifiable(labels: NDArray[np.int_], total_partitions: int) -> bool:
      """
      Very basic check to see if dataset can be stratified by class label. This is not a
      comprehensive test, as factors such as grouping also affect the ability to stratify by label
@@ -124,7 +134,7 @@ def check_stratifiable(labels: np.ndarray, total_partitions: int):
      return stratifiable
 
 
- def check_groups(group_ids: np.ndarray, num_partitions: int):
+ def check_groups(group_ids: NDArray[np.int_], num_partitions: int) -> bool:
      """
      Warns user if the number of unique group_ids is incompatible with a grouped partition containing
      num_folds folds. If this is the case, returns groups=None, which tells the partitioner not to
@@ -162,7 +172,7 @@ def check_groups(group_ids: np.ndarray, num_partitions: int):
      return groupable
 
 
- def bin_kmeans(array: np.ndarray):
+ def bin_kmeans(array: NDArray[Any]) -> NDArray[np.int_]:
      """
      Find bins of continuous data by iteratively applying k-means clustering, and keeping the
      clustering with the highest silhouette score.
@@ -182,18 +192,18 @@
          best_score = 0.60
      else:
          best_score = 0.50
-     bin_index = np.zeros(len(array))
+     bin_index = np.zeros(len(array), dtype=np.int_)
      for k in range(2, 20):
          clusterer = KMeans(n_clusters=k)
          cluster_labels = clusterer.fit_predict(array)
          score = silhouette_score(array, cluster_labels, sample_size=25_000)
          if score > best_score:
              best_score = score
-             bin_index = cluster_labels
+             bin_index = cluster_labels.astype(np.int_)
      return bin_index
 
 
- def angle2xy(angles: np.ndarray):
+ def angle2xy(angles: NDArray[Any]) -> NDArray[Any]:
      """
      Converts angle measurements to xy coordinates on the unit circle. Needed for binning angle data.
 
@@ -213,7 +223,7 @@ def angle2xy(angles: np.ndarray):
      return xy
 
 
- def get_group_ids(metadata: dict, groupnames: list, num_samples: int):
+ def get_group_ids(metadata: dict[str, Any], group_names: list[str], num_samples: int) -> NDArray[np.int_]:
      """Returns individual group numbers based on a subset of metadata defined by groupnames
 
      Parameters
@@ -235,7 +245,7 @@ def get_group_ids(metadata: dict, groupnames: list, num_samples: int):
      group_ids: np.ndarray
          group identifiers from metadata
      """
-     features2group = {k: np.array(v) for k, v in metadata.items() if k in groupnames}
+     features2group = {k: np.array(v) for k, v in metadata.items() if k in group_names}
      if not features2group:
          return np.zeros(num_samples, dtype=int)
      for name, feature in features2group.items():
@@ -252,8 +262,12 @@ def get_group_ids(metadata: dict, groupnames: list, num_samples: int):
 
 
  def make_splits(
-     index: np.ndarray, labels: np.ndarray, n_folds: int, groups: np.ndarray | None = None, stratified: bool = False
- ):
+     index: NDArray[np.int_],
+     labels: NDArray[np.int_],
+     n_folds: int,
+     groups: NDArray[np.int_] | None = None,
+     stratified: bool = False,
+ ) -> list[dict[str, NDArray[np.int_]]]:
      """Split data into n_folds partitions of training and validation data.
 
      Parameters
@@ -290,9 +304,59 @@ def make_splits(
      return split_defs
 
 
+ def find_best_split(
+     labels: NDArray[np.int_], split_defs: list[dict[str, NDArray[np.int_]]], stratified: bool, eval_frac: float
+ ) -> tuple[NDArray[np.int_], NDArray[np.int_]]:
+     """Finds the split that most closely satisfies a criterion determined by the arguments passed.
+     If stratified is True, returns the split whose class balance most closely resembles the overall
+     class balance. If false, returns the split with the size closest to the desired eval_frac
+
+     Parameters
+     ----------
+     labels : np.ndarray
+         Labels upon which splits are (optionally) stratified
+     split_defs : list[dict]
+         List of dictionaries, which specifying train index, validation index, and the ratio of
+         validation to all data.
+     stratified: bool
+         If True, maintain dataset class balance within each train/val split
+     eval_frac: float
+         Desired fraction of the dataset sequestered for evaluation
+
+     Returns
+     -------
+     train_index : np.ndarray
+         indices of data partitioned for training
+     eval_index : np.ndarray
+         indices of data partitioned for evaluation
+     """
+
+     def class_freq_diff(split):
+         train_labels = labels[split["train"]]
+         _, train_counts = np.unique(train_labels, return_counts=True)
+         train_freq = train_counts / train_counts.sum()
+         return np.square(train_freq - class_freq).sum()
+
+     if stratified:
+         _, class_counts = np.unique(labels, return_counts=True)
+         class_freq = class_counts / class_counts.sum()
+         best_split = min(split_defs, key=class_freq_diff)
+         return best_split["train"], best_split["eval"]
+     elif eval_frac <= 2 / 3:
+         best_split = min(split_defs, key=lambda x: abs(eval_frac - x["eval_frac"])) # type: ignore
+         return best_split["train"], best_split["eval"]
+     else:
+         best_split = min(split_defs, key=lambda x: abs(eval_frac - (1 - x["eval_frac"]))) # type: ignore
+         return best_split["eval"], best_split["train"]
+
+
  def single_split(
-     index: np.ndarray, labels: np.ndarray, eval_frac: float, groups: np.ndarray | None = None, stratified: bool = False
- ):
+     index: NDArray[np.int_],
+     labels: NDArray[np.int_],
+     eval_frac: float,
+     groups: NDArray[np.int_] | None = None,
+     stratified: bool = False,
+ ) -> tuple[NDArray[np.int_], NDArray[np.int_]]:
      """Handles the special case where only 1 partition of the data is desired (such as when
      generating the test holdout split). In this case, the desired fraction of the data to be
      partitioned into the test data must be specified, and a single [train, eval] pair are returned.
@@ -317,27 +381,28 @@ def single_split(
      eval_index : np.ndarray
          indices of data partitioned for evaluation
      """
-     if eval_frac <= 2 / 3:
+     if groups is not None:
+         n_unique_groups = np.unique(groups).shape[0]
+         _, label_counts = np.unique(labels, return_counts=True)
+         n_folds = min(n_unique_groups, label_counts.min())
+     elif eval_frac <= 2 / 3:
          n_folds = max(2, int(round(1 / (eval_frac + 1e-6))))
-         split_candidates = make_splits(index, labels, n_folds, groups, stratified)
-         best_split = min(split_candidates, key=lambda x: abs(eval_frac - x["eval_frac"]))
-         return best_split["train"], best_split["eval"]
      else:
-         n_folds = max(2, int(round(1 / (1 - eval_frac + 1e-6))))
-         split_candidates = make_splits(index, labels, n_folds, groups, stratified)
-         best_split = min(split_candidates, key=lambda x: abs(eval_frac - (1 - x["eval_frac"])))
-         return best_split["eval"], best_split["train"]
+         n_folds = max(2, int(round(1 / (1 - eval_frac - 1e-6))))
+     split_candidates = make_splits(index, labels, n_folds, groups, stratified)
+     best_train, best_eval = find_best_split(labels, split_candidates, stratified, eval_frac)
+     return best_train, best_eval
 
 
  def split_dataset(
-     labels: list | np.ndarray,
+     labels: list[int] | NDArray[np.int_],
      num_folds: int = 1,
      stratify: bool = False,
-     split_on: list | None = None,
-     metadata: dict | None = None,
+     split_on: list[str] | None = None,
+     metadata: dict[str, Any] | None = None,
      test_frac: float | None = None,
      val_frac: float | None = None,
- ):
+ ) -> dict[str, dict[str, NDArray[np.int_]] | NDArray[np.int_]]:
      """Top level splitting function. Returns a dict with each key-value pair containing
      train and validation indices. Indices for a test holdout may also be optionally included
 
@@ -386,7 +451,7 @@ def split_dataset(
          }
      """
 
-     check_args(num_folds, test_frac, val_frac)
+     test_frac, val_frac = validate_test_val(num_folds, test_frac, val_frac)
      total_partitions = num_folds + 1 if test_frac else num_folds
      index, labels = check_labels(labels, total_partitions)
      stratify &= check_stratifiable(labels, total_partitions)
@@ -399,7 +464,7 @@
              groups = None
      else:
          groups = None
-     split_defs = {}
+     split_defs: dict[str, dict[str, NDArray[np.int_]] | NDArray[np.int_]] = {}
      if test_frac:
          tv_idx, test_idx = single_split(index, labels, test_frac, groups, stratify)
          tv_labels = labels[tv_idx]
@@ -410,7 +475,7 @@
          tv_labels = labels
          tv_groups = groups
      if num_folds == 1:
-         train_idx, val_idx = single_split(tv_idx, tv_labels, val_frac, tv_groups, stratify) # type: ignore
+         train_idx, val_idx = single_split(tv_idx, tv_labels, val_frac, tv_groups, stratify)
          split_defs["fold_0"] = {"train": tv_idx[train_idx].squeeze(), "val": tv_idx[val_idx].squeeze()}
      else:
          tv_splits = make_splits(tv_idx, tv_labels, num_folds, tv_groups, stratify)
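
A short usage sketch for the relocated splitter, not part of the diff. The synthetic labels and the 20% holdout are assumptions; the `"train"`/`"val"` keys and the per-fold dictionary structure follow the `fold_0` assignment and return type visible in the diff, while the import path assumes the new module location `dataeval/utils/split_dataset.py`.

```python
import numpy as np

from dataeval.utils.split_dataset import split_dataset

rng = np.random.default_rng(0)
labels = rng.integers(0, 3, size=300)  # synthetic class labels for 300 samples

# Two stratified train/val folds plus a 20% test holdout.
splits = split_dataset(labels, num_folds=2, stratify=True, test_frac=0.2)

for name, part in splits.items():
    if isinstance(part, dict):
        print(name, part["train"].shape, part["val"].shape)  # per-fold index arrays
    else:
        print(name, part.shape)  # test holdout index array
```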
@@ -2,17 +2,18 @@
  TensorFlow models are used in :term:`out of distribution<Out-of-distribution (OOD)>` detectors in the
  :mod:`dataeval.detectors.ood` module.
 
- DataEval provides both basic default models through the utility :func:`dataeval.utils.tensorflow.models.create_model`
- as well as constructors which allow for customization of the encoder, decoder and any other applicable
- layers used by the model.
+ DataEval provides basic default models through the utility :func:`dataeval.utils.tensorflow.create_model`.
  """
 
  from dataeval import _IS_TENSORFLOW_AVAILABLE
 
- from . import loss, models, recon
-
  __all__ = []
 
 
  if _IS_TENSORFLOW_AVAILABLE:
-     __all__ = ["loss", "models", "recon"]
+     import dataeval.utils.tensorflow.loss as loss
+     from dataeval.utils.tensorflow._internal.utils import create_model
+
+     __all__ = ["create_model", "loss"]
+
+ del _IS_TENSORFLOW_AVAILABLE
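
With this reorganization, the TensorFlow helpers are exposed directly from `dataeval.utils.tensorflow` rather than the removed `models` and `recon` submodules. A quick check of the resulting public surface (the exported names are taken from the diff above; whether they appear depends on the optional TensorFlow dependency being installed):

```python
import dataeval.utils.tensorflow as tf_utils

# ['create_model', 'loss'] when TensorFlow is available, [] otherwise.
print(tf_utils.__all__)
```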