dataeval 0.72.2__py3-none-any.whl → 0.73.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,258 @@
+ from __future__ import annotations
+
+ __all__ = ["merge_metadata"]
+
+ import warnings
+ from typing import Any, Iterable, Mapping, TypeVar, overload
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ T = TypeVar("T")
+
+
+ def _try_cast(v: Any, t: type[T]) -> T | None:
+     """Casts a value to a type or returns None if unable"""
+     try:
+         return t(v)  # type: ignore
+     except (TypeError, ValueError):
+         return None
+
+
+ @overload
+ def _convert_type(data: list[str]) -> list[int] | list[float] | list[str]: ...
+ @overload
+ def _convert_type(data: str) -> int | float | str: ...
+
+
+ def _convert_type(data: list[str] | str) -> list[int] | list[float] | list[str] | int | float | str:
+     """
+     Converts a value or a list of values to the simplest form possible, in preferred order of `int`,
+     `float`, or `str`.
+
+     Parameters
+     ----------
+     data : list[str] | str
+         A list of values or a single value
+
+     Returns
+     -------
+     list[int | float | str] | int | float | str
+         The same values converted to a numerical type where possible
+     """
+     if not isinstance(data, list):
+         value = _try_cast(data, float)
+         return str(data) if value is None else int(value) if value.is_integer() else value
+
+     converted = []
+     TYPE_MAP = {int: 0, float: 1, str: 2}
+     max_type = 0
+     for value in data:
+         value = _convert_type(value)
+         max_type = max(max_type, TYPE_MAP.get(type(value), 2))
+         converted.append(value)
+     for i in range(len(converted)):
+         converted[i] = list(TYPE_MAP)[max_type](converted[i])
+     return converted
+
+
+ def _get_key_indices(keys: Iterable[tuple[str, ...]]) -> dict[tuple[str, ...], int]:
+     """
+     Finds indices to minimize unique tuple keys
+
+     Parameters
+     ----------
+     keys : Iterable[tuple[str, ...]]
+         Collection of unique expanded tuple keys
+
+     Returns
+     -------
+     dict[tuple[str, ...], int]
+         Mapping of tuple keys to starting index
+     """
+     indices = {k: -1 for k in keys}
+     ks = list(keys)
+     while len(ks) > 0:
+         seen: dict[tuple[str, ...], list[tuple[str, ...]]] = {}
+         for k in ks:
+             seen.setdefault(k[indices[k] :], []).append(k)
+         ks.clear()
+         for sk in seen.values():
+             if len(sk) > 1:
+                 ks.extend(sk)
+                 for k in sk:
+                     indices[k] -= 1
+     return indices
+
+
+ def _flatten_dict_inner(
+     d: Mapping[str, Any], parent_keys: tuple[str, ...], size: int | None = None, nested: bool = False
+ ) -> tuple[dict[tuple[str, ...], Any], int | None]:
+     """
+     Recursive internal function for flattening a dictionary.
+
+     Parameters
+     ----------
+     d : dict[str, Any]
+         Dictionary to flatten
+     parent_keys : tuple[str, ...]
+         Parent keys to the current dictionary being flattened
+     size : int or None, default None
+         Tracking int for length of lists
+     nested : bool, default False
+         Tracking if inside a list
+
+     Returns
+     -------
+     tuple[dict[tuple[str, ...], Any], int | None]
+         - [0]: Dictionary of flattened values with the keys reformatted as a hierarchical tuple of strings
+         - [1]: Size, if any, of the current list of values
+     """
+     items: dict[tuple[str, ...], Any] = {}
+     for k, v in d.items():
+         new_keys: tuple[str, ...] = parent_keys + (k,)
+         if isinstance(v, dict):
+             fd, size = _flatten_dict_inner(v, new_keys, size=size, nested=nested)
+             items.update(fd)
+         elif isinstance(v, (list, tuple)):
+             if not nested and (size is None or size == len(v)):
+                 size = len(v)
+                 if all(isinstance(i, dict) for i in v):
+                     for sub_dict in v:
+                         fd, size = _flatten_dict_inner(sub_dict, new_keys, size=size, nested=True)
+                         for fk, fv in fd.items():
+                             items.setdefault(fk, []).append(fv)
+                 else:
+                     items[new_keys] = v
+             else:
+                 warnings.warn(f"Dropping nested list found in '{parent_keys + (k, )}'.")
+         else:
+             items[new_keys] = v
+     return items, size
+
+
+ def _flatten_dict(d: Mapping[str, Any], sep: str, ignore_lists: bool, fully_qualified: bool) -> dict[str, Any]:
+     """
+     Flattens a dictionary and converts values to numeric types when possible.
+
+     Parameters
+     ----------
+     d : dict[str, Any]
+         Dictionary to flatten
+     sep : str
+         String separator to use when concatenating key names
+     ignore_lists : bool
+         Option to skip expanding lists within metadata
+     fully_qualified : bool
+         Option to return dictionary keys fully qualified instead of minimized
+
+     Returns
+     -------
+     dict[str, Any]
+         A flattened dictionary
+     """
+     expanded, size = _flatten_dict_inner(d, parent_keys=(), nested=ignore_lists)
+
+     output = {}
+     if fully_qualified:
+         expanded = {sep.join(k): v for k, v in expanded.items()}
+     else:
+         keys = _get_key_indices(expanded)
+         expanded = {sep.join(k[keys[k] :]): v for k, v in expanded.items()}
+     for k, v in expanded.items():
+         cv = _convert_type(v)
+         if isinstance(cv, list) and len(cv) == size:
+             output[k] = cv
+         elif not isinstance(cv, list):
+             output[k] = cv if not size else [cv] * size
+     return output
+
+
+ def _is_metadata_dict_of_dicts(metadata: Mapping) -> bool:
+     """EXPERIMENTAL: Attempt to detect if metadata is a dict of dicts"""
+     # single dict
+     if len(metadata) < 2:
+         return False
+
+     # dict of non dicts
+     keys = list(metadata)
+     if not isinstance(metadata[keys[0]], Mapping):
+         return False
+
+     # dict of dicts with matching keys
+     return set(metadata[keys[0]]) == set(metadata[keys[1]])
+
+
+ def merge_metadata(
+     metadata: Iterable[Mapping[str, Any]],
+     ignore_lists: bool = False,
+     fully_qualified: bool = False,
+     as_numpy: bool = False,
+ ) -> dict[str, list[Any]] | dict[str, NDArray[Any]]:
+     """
+     Merges a collection of metadata dictionaries into a single flattened dictionary of keys and values.
+
+     Nested dictionaries are flattened, and lists are expanded. Nested lists are dropped, as expanding
+     them into multiple hierarchical trees is not supported.
+
+     Parameters
+     ----------
+     metadata : Iterable[Mapping[str, Any]]
+         Iterable collection of metadata dictionaries to flatten and merge
+     ignore_lists : bool, default False
+         Option to skip expanding lists within metadata
+     fully_qualified : bool, default False
+         Option to return dictionary keys fully qualified instead of minimized
+     as_numpy : bool, default False
+         Option to return results as NumPy arrays instead of lists
+
+     Returns
+     -------
+     dict[str, list[Any]] | dict[str, NDArray[Any]]
+         A single dictionary containing the flattened data as lists or NumPy arrays
+
+     Note
+     ----
+     Nested lists of values and inconsistent keys are dropped in the merged metadata dictionary
+
+     Example
+     -------
+     >>> list_metadata = [{"common": 1, "target": [{"a": 1, "b": 3}, {"a": 2, "b": 4}], "source": "example"}]
+     >>> merge_metadata(list_metadata)
+     {'common': [1, 1], 'a': [1, 2], 'b': [3, 4], 'source': ['example', 'example']}
+     """
+     merged: dict[str, list[Any]] = {}
+     isect: set[str] = set()
+     union: set[str] = set()
+     keys: list[str] | None = None
+     dicts: list[Mapping[str, Any]]
+
+     # EXPERIMENTAL
+     if isinstance(metadata, Mapping) and _is_metadata_dict_of_dicts(metadata):
+         warnings.warn("Experimental processing for dict of dicts.")
+         keys = [str(k) for k in metadata]
+         dicts = list(metadata.values())
+         ignore_lists = True
+     else:
+         dicts = list(metadata)
+
+     for d in dicts:
+         flattened = _flatten_dict(d, sep="_", ignore_lists=ignore_lists, fully_qualified=fully_qualified)
+         isect = isect.intersection(flattened.keys()) if isect else set(flattened.keys())
+         union = union.union(flattened.keys())
+         for k, v in flattened.items():
+             merged.setdefault(k, []).extend(flattened[k]) if isinstance(v, list) else merged.setdefault(k, []).append(v)
+
+     if len(union) > len(isect):
+         warnings.warn(f"Inconsistent metadata keys found. Dropping {union - isect} from metadata.")
+
+     output: dict[str, Any] = {}
+
+     if keys:
+         output["keys"] = np.array(keys) if as_numpy else keys
+
+     for k in (key for key in merged if key in isect):
+         cv = _convert_type(merged[k])
+         output[k] = np.array(cv) if as_numpy else cv
+
+     return output
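The docstring example above shows the default behavior. For the non-default flags, here is a sketch of the expected outputs, inferred by tracing the code above rather than taken from the package docs (merge_metadata's import path is not shown in this diff):

    list_metadata = [{"common": 1, "target": [{"a": 1, "b": 3}, {"a": 2, "b": 4}], "source": "example"}]

    # fully_qualified=True keeps the full parent key path instead of the minimized suffix:
    merge_metadata(list_metadata, fully_qualified=True)
    # {'common': [1, 1], 'target_a': [1, 2], 'target_b': [3, 4], 'source': ['example', 'example']}

    # as_numpy=True returns each merged column as a NumPy array instead of a list:
    merge_metadata(list_metadata, as_numpy=True)["a"]
    # array([1, 2])

Note how scalar values such as "common" are broadcast to the tracked list length (2 here), so every column ends up with one entry per expanded target.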
@@ -8,10 +8,16 @@ Licensed under Apache Software License (Apache 2.0)
 
  from __future__ import annotations
 
- from typing import NamedTuple
+ from typing import TYPE_CHECKING, NamedTuple
 
  import numpy as np
- import tensorflow as tf
+
+ from dataeval.utils.lazy import lazyload
+
+ if TYPE_CHECKING:
+     import tensorflow as tf
+ else:
+     tf = lazyload("tensorflow")
 
 
  class GaussianMixtureModelParams(NamedTuple):
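The lazyload helper imported above is new in 0.73.0 and its implementation does not appear in this diff. The usage implies a proxy that defers the real import until first attribute access; a minimal sketch of that pattern (hypothetical, not dataeval's actual code):

    from __future__ import annotations

    import importlib
    from types import ModuleType
    from typing import Any


    class _LazyModule:
        """Stand-in that performs the real import on first attribute access."""

        def __init__(self, name: str) -> None:
            self._name = name
            self._module: ModuleType | None = None

        def __getattr__(self, attr: str) -> Any:
            # Only reached for attributes not set in __init__, so no recursion.
            if self._module is None:
                self._module = importlib.import_module(self._name)
            return getattr(self._module, attr)


    def lazyload(name: str) -> Any:
        return _LazyModule(name)

Under this pattern the `if TYPE_CHECKING` branch hands type checkers the real module, while at runtime `tf` stays a cheap proxy until TensorFlow is genuinely needed.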
@@ -8,18 +8,27 @@ Licensed under Apache Software License (Apache 2.0)
 
  from __future__ import annotations
 
- from typing import Literal, cast
+ from typing import TYPE_CHECKING, Literal, cast
 
  import numpy as np
- import tensorflow as tf
  from numpy.typing import NDArray
- from tensorflow_probability.python.distributions.mvn_diag import MultivariateNormalDiag
- from tensorflow_probability.python.distributions.mvn_tril import MultivariateNormalTriL
- from tensorflow_probability.python.stats import covariance
- from tf_keras.layers import Flatten
 
+ from dataeval.utils.lazy import lazyload
  from dataeval.utils.tensorflow._internal.gmm import gmm_energy, gmm_params
 
+ if TYPE_CHECKING:
+     import tensorflow as tf
+     import tensorflow_probability.python.distributions.mvn_diag as mvn_diag
+     import tensorflow_probability.python.distributions.mvn_tril as mvn_tril
+     import tensorflow_probability.python.stats as tfp_stats
+     import tf_keras as keras
+ else:
+     tf = lazyload("tensorflow")
+     keras = lazyload("tf_keras")
+     mvn_diag = lazyload("tensorflow_probability.python.distributions.mvn_diag")
+     mvn_tril = lazyload("tensorflow_probability.python.distributions.mvn_tril")
+     tfp_stats = lazyload("tensorflow_probability.python.stats")
+
 
  class Elbo:
      """
@@ -46,7 +55,7 @@ class Elbo:
              self._cov = ("sim", cov_type)
          elif cov_type in ["cov_full", "cov_diag"]:
              x_np: NDArray[np.float32] = x.numpy().astype(np.float32) if tf.is_tensor(x) else x  # type: ignore
-             cov = covariance(x_np.reshape(x_np.shape[0], -1))  # type: ignore py38
+             cov = tfp_stats.covariance(x_np.reshape(x_np.shape[0], -1))  # type: ignore py38
              if cov_type == "cov_diag":  # infer standard deviation from covariance matrix
                  cov = tf.math.sqrt(tf.linalg.diag_part(cov))
              self._cov = (cov_type, cov)
@@ -54,15 +63,15 @@ class Elbo:
              raise ValueError("Only cov_full, cov_diag or sim value should be specified.")
 
      def __call__(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
-         y_pred_flat = cast(tf.Tensor, Flatten()(y_pred))
+         y_pred_flat = cast(tf.Tensor, keras.layers.Flatten()(y_pred))
 
          if self._cov[0] == "cov_full":
-             y_mn = MultivariateNormalTriL(y_pred_flat, scale_tril=tf.linalg.cholesky(self._cov[1]))
+             y_mn = mvn_tril.MultivariateNormalTriL(y_pred_flat, scale_tril=tf.linalg.cholesky(self._cov[1]))
          else:  # cov_diag and sim
              cov_diag = self._cov[1] if self._cov[0] == "cov_diag" else self._cov[1] * tf.ones(y_pred_flat.shape[-1])
-             y_mn = MultivariateNormalDiag(y_pred_flat, scale_diag=cov_diag)
+             y_mn = mvn_diag.MultivariateNormalDiag(y_pred_flat, scale_diag=cov_diag)
 
-         loss = -tf.reduce_mean(y_mn.log_prob(Flatten()(y_true)))
+         loss = -tf.reduce_mean(y_mn.log_prob(keras.layers.Flatten()(y_true)))
          return loss
 
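The remaining hunks are mechanical rewrites: each former `from module import name` becomes an attribute access on the lazily loaded module, which resolves to the same object once the import fires. A quick equivalence check (requires the TensorFlow extras installed):

    import tf_keras as keras
    from tf_keras.layers import Flatten

    assert keras.layers.Flatten is Flatten  # attribute access and direct import yield the same class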