dataeval 0.73.1__py3-none-any.whl → 0.74.1__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
- dataeval/__init__.py +3 -9
- dataeval/detectors/__init__.py +2 -10
- dataeval/detectors/drift/base.py +3 -3
- dataeval/detectors/drift/mmd.py +1 -1
- dataeval/detectors/drift/torch.py +1 -101
- dataeval/detectors/linters/clusterer.py +3 -3
- dataeval/detectors/linters/duplicates.py +4 -4
- dataeval/detectors/linters/outliers.py +4 -4
- dataeval/detectors/ood/__init__.py +9 -9
- dataeval/detectors/ood/{ae.py → ae_torch.py} +22 -27
- dataeval/detectors/ood/base.py +63 -113
- dataeval/detectors/ood/base_torch.py +109 -0
- dataeval/detectors/ood/metadata_ks_compare.py +52 -14
- dataeval/interop.py +1 -1
- dataeval/metrics/bias/__init__.py +3 -0
- dataeval/metrics/bias/balance.py +73 -70
- dataeval/metrics/bias/coverage.py +4 -4
- dataeval/metrics/bias/diversity.py +67 -136
- dataeval/metrics/bias/metadata_preprocessing.py +285 -0
- dataeval/metrics/bias/metadata_utils.py +229 -0
- dataeval/metrics/bias/parity.py +51 -161
- dataeval/metrics/estimators/ber.py +3 -3
- dataeval/metrics/estimators/divergence.py +3 -3
- dataeval/metrics/estimators/uap.py +3 -3
- dataeval/metrics/stats/base.py +2 -2
- dataeval/metrics/stats/boxratiostats.py +1 -1
- dataeval/metrics/stats/datasetstats.py +6 -6
- dataeval/metrics/stats/dimensionstats.py +1 -1
- dataeval/metrics/stats/hashstats.py +1 -1
- dataeval/metrics/stats/labelstats.py +3 -3
- dataeval/metrics/stats/pixelstats.py +1 -1
- dataeval/metrics/stats/visualstats.py +1 -1
- dataeval/output.py +77 -53
- dataeval/utils/__init__.py +1 -7
- dataeval/utils/gmm.py +26 -0
- dataeval/utils/metadata.py +29 -9
- dataeval/utils/torch/gmm.py +98 -0
- dataeval/utils/torch/models.py +192 -0
- dataeval/utils/torch/trainer.py +84 -5
- dataeval/utils/torch/utils.py +107 -1
- dataeval/workflows/sufficiency.py +4 -4
- {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/METADATA +3 -9
- dataeval-0.74.1.dist-info/RECORD +65 -0
- dataeval/detectors/ood/aegmm.py +0 -66
- dataeval/detectors/ood/llr.py +0 -302
- dataeval/detectors/ood/vae.py +0 -97
- dataeval/detectors/ood/vaegmm.py +0 -75
- dataeval/metrics/bias/metadata.py +0 -440
- dataeval/utils/lazy.py +0 -26
- dataeval/utils/tensorflow/__init__.py +0 -19
- dataeval/utils/tensorflow/_internal/gmm.py +0 -123
- dataeval/utils/tensorflow/_internal/loss.py +0 -121
- dataeval/utils/tensorflow/_internal/models.py +0 -1394
- dataeval/utils/tensorflow/_internal/trainer.py +0 -114
- dataeval/utils/tensorflow/_internal/utils.py +0 -256
- dataeval/utils/tensorflow/loss/__init__.py +0 -11
- dataeval-0.73.1.dist-info/RECORD +0 -73
- {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/WHEEL +0 -0
dataeval/metrics/stats/labelstats.py
CHANGED
@@ -9,11 +9,11 @@ from typing import Any, Iterable, Mapping, TypeVar
 from numpy.typing import ArrayLike
 
 from dataeval.interop import to_numpy
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.output import Output, set_metadata
 
 
 @dataclass(frozen=True)
-class LabelStatsOutput(OutputMetadata):
+class LabelStatsOutput(Output):
     """
     Output class for :func:`labelstats` stats metric
 
@@ -57,7 +57,7 @@ def sort(d: Mapping[TKey, Any]) -> dict[TKey, Any]:
     return dict(sorted(d.items(), key=lambda x: x[0]))
 
 
-@set_metadata()
+@set_metadata
 def labelstats(
     labels: Iterable[ArrayLike],
 ) -> LabelStatsOutput:
dataeval/output.py
CHANGED
@@ -4,9 +4,10 @@ __all__ = []
 
 import inspect
 import sys
+from collections.abc import Mapping
 from datetime import datetime, timezone
-from functools import wraps
-from typing import Any, Callable, Iterable, TypeVar
+from functools import partial, wraps
+from typing import Any, Callable, Iterator, TypeVar
 
 import numpy as np
 
@@ -18,7 +19,7 @@ else:
 from dataeval import __version__
 
 
-class OutputMetadata:
+class Output:
     _name: str
     _execution_time: datetime
     _execution_duration: float
@@ -26,6 +27,9 @@ class OutputMetadata:
     _state: dict[str, str]
     _version: str
 
+    def __str__(self) -> str:
+        return f"{self.__class__.__name__}: {str(self.dict())}"
+
     def dict(self) -> dict[str, Any]:
         return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
 
@@ -33,58 +37,78 @@ class OutputMetadata:
         return {k.removeprefix("_"): v for k, v in self.__dict__.items() if k.startswith("_")}
 
 
+TKey = TypeVar("TKey", str, int, float, set)
+TValue = TypeVar("TValue")
+
+
+class MappingOutput(Mapping[TKey, TValue], Output):
+    __slots__ = ["_data"]
+
+    def __init__(self, data: Mapping[TKey, TValue]):
+        self._data = data
+
+    def __getitem__(self, key: TKey) -> TValue:
+        return self._data.__getitem__(key)
+
+    def __iter__(self) -> Iterator[TKey]:
+        return self._data.__iter__()
+
+    def __len__(self) -> int:
+        return self._data.__len__()
+
+    def dict(self) -> dict[str, TValue]:
+        return {str(k): v for k, v in self._data.items()}
+
+
 P = ParamSpec("P")
-R = TypeVar("R", bound=OutputMetadata)
+R = TypeVar("R", bound=Output)
 
 
-def set_metadata(
-    state_attr: Iterable[str] | None = None,
-) -> Callable[[Callable[P, R]], Callable[P, R]]:
+def set_metadata(fn: Callable[P, R] | None = None, *, state: list[str] | None = None) -> Callable[P, R]:
     """Decorator to stamp OutputMetadata classes with runtime metadata"""
 
-    )
-    return decorator
+    if fn is None:
+        return partial(set_metadata, state=state)  # type: ignore
+
+    @wraps(fn)
+    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+        def fmt(v):
+            if np.isscalar(v):
+                return v
+            if hasattr(v, "shape"):
+                return f"{v.__class__.__name__}: shape={getattr(v, 'shape')}"
+            if hasattr(v, "__len__"):
+                return f"{v.__class__.__name__}: len={len(v)}"
+            return f"{v.__class__.__name__}"
+
+        time = datetime.now(timezone.utc)
+        result = fn(*args, **kwargs)
+        duration = (datetime.now(timezone.utc) - time).total_seconds()
+        fn_params = inspect.signature(fn).parameters
+
+        # set all params with defaults then update params with mapped arguments and explicit keyword args
+        arguments = {k: None if v.default is inspect.Parameter.empty else v.default for k, v in fn_params.items()}
+        arguments.update(zip(fn_params, args))
+        arguments.update(kwargs)
+        arguments = {k: fmt(v) for k, v in arguments.items()}
+        state_attrs = (
+            {k: fmt(getattr(args[0], k)) for k in state if "self" in arguments} if "self" in arguments and state else {}
+        )
+        name = (
+            f"{args[0].__class__.__module__}.{args[0].__class__.__name__}.{fn.__name__}"
+            if "self" in arguments
+            else f"{fn.__module__}.{fn.__qualname__}"
+        )
+        metadata = {
+            "_name": name,
+            "_execution_time": time,
+            "_execution_duration": duration,
+            "_arguments": {k: v for k, v in arguments.items() if k != "self"},
+            "_state": state_attrs,
+            "_version": __version__,
+        }
+        for k, v in metadata.items():
+            object.__setattr__(result, k, v)
+        return result
+
+    return wrapper
dataeval/utils/__init__.py
CHANGED
@@ -4,7 +4,7 @@ in setting up architectures that are guaranteed to work with applicable DataEval
 metrics. Currently DataEval supports both :term:`TensorFlow` and PyTorch backends.
 """
 
-from dataeval import _IS_TENSORFLOW_AVAILABLE, _IS_TORCH_AVAILABLE
+from dataeval import _IS_TORCH_AVAILABLE
 from dataeval.utils.metadata import merge_metadata
 from dataeval.utils.split_dataset import split_dataset
 
@@ -15,10 +15,4 @@ if _IS_TORCH_AVAILABLE:
 
     __all__ += ["torch"]
 
-if _IS_TENSORFLOW_AVAILABLE:
-    from dataeval.utils import tensorflow
-
-    __all__ += ["tensorflow"]
-
-del _IS_TENSORFLOW_AVAILABLE
 del _IS_TORCH_AVAILABLE
dataeval/utils/gmm.py
ADDED
@@ -0,0 +1,26 @@
+from dataclasses import dataclass
+from typing import Generic, TypeVar
+
+TGMMData = TypeVar("TGMMData")
+
+
+@dataclass
+class GaussianMixtureModelParams(Generic[TGMMData]):
+    """
+    phi : TGMMData
+        Mixture component distribution weights.
+    mu : TGMMData
+        Mixture means.
+    cov : TGMMData
+        Mixture covariance.
+    L : TGMMData
+        Cholesky decomposition of `cov`.
+    log_det_cov : TGMMData
+        Log of the determinant of `cov`.
+    """
+
+    phi: TGMMData
+    mu: TGMMData
+    cov: TGMMData
+    L: TGMMData
+    log_det_cov: TGMMData
dataeval/utils/metadata.py
CHANGED
@@ -131,7 +131,9 @@ def _flatten_dict_inner(
     return items, size
 
 
-def _flatten_dict(d: Mapping[str, Any], sep: str, ignore_lists: bool, fully_qualified: bool) -> dict[str, Any]:
+def _flatten_dict(
+    d: Mapping[str, Any], sep: str, ignore_lists: bool, fully_qualified: bool
+) -> tuple[dict[str, Any], int]:
     """
     Flattens a dictionary and converts values to numeric values when possible.
 
@@ -165,7 +167,7 @@ def _flatten_dict(d: Mapping[str, Any], sep: str, ignore_lists: bool, fully_qualified: bool) -> dict[str, Any]:
             output[k] = cv
         elif not isinstance(cv, list):
             output[k] = cv if not size else [cv] * size
-    return output
+    return output, size if size is not None else 1
 
 
 def _is_metadata_dict_of_dicts(metadata: Mapping) -> bool:
@@ -188,7 +190,7 @@ def merge_metadata(
     ignore_lists: bool = False,
     fully_qualified: bool = False,
     as_numpy: bool = False,
-) -> dict[str, list[Any]] | dict[str, NDArray[Any]]:
+) -> tuple[dict[str, list[Any]] | dict[str, NDArray[Any]], NDArray[np.int_]]:
     """
     Merges a collection of metadata dictionaries into a single flattened dictionary of keys and values.
 
@@ -208,8 +210,10 @@ def merge_metadata(
 
     Returns
     -------
-    dict[str, list[Any]]
+    dict[str, list[Any]] or dict[str, NDArray[Any]]
         A single dictionary containing the flattened data as lists or NumPy arrays
+    NDArray[np.int_]
+        Array defining where individual images start, helpful when working with object detection metadata
 
     Note
     ----
@@ -217,9 +221,12 @@ def merge_metadata(
 
     Example
     -------
-    >>> list_metadata = [{"common": 1, "target": [{"a": 1, "b": 3}, {"a": 2, "b": 4}], "source": "example"}]
-    >>> merge_metadata(list_metadata)
+    >>> list_metadata = [{"common": 1, "target": [{"a": 1, "b": 3, "c": 5}, {"a": 2, "b": 4}], "source": "example"}]
+    >>> reorganized_metadata, image_indicies = merge_metadata(list_metadata)
+    >>> reorganized_metadata
     {'common': [1, 1], 'a': [1, 2], 'b': [3, 4], 'source': ['example', 'example']}
+    >>> image_indicies
+    array([0])
     """
     merged: dict[str, list[Any]] = {}
     isect: set[str] = set()
@@ -236,8 +243,11 @@ def merge_metadata(
     else:
         dicts = list(metadata)
 
-    for d in dicts:
-        flattened = _flatten_dict(d, sep="_", ignore_lists=ignore_lists, fully_qualified=fully_qualified)
+    image_repeats = np.zeros(len(dicts))
+    for i, d in enumerate(dicts):
+        flattened, image_repeats[i] = _flatten_dict(
+            d, sep="_", ignore_lists=ignore_lists, fully_qualified=fully_qualified
+        )
         isect = isect.intersection(flattened.keys()) if isect else set(flattened.keys())
         union = union.union(flattened.keys())
         for k, v in flattened.items():
@@ -248,6 +258,16 @@ def merge_metadata(
 
     output: dict[str, Any] = {}
 
+    if image_repeats.sum() == image_repeats.size:
+        image_indicies = np.arange(image_repeats.size)
+    else:
+        image_ids = np.arange(image_repeats.size)
+        image_data = np.concatenate(
+            [np.repeat(image_ids[i], image_repeats[i]) for i in range(image_ids.size)], dtype=np.int_
+        )
+        _, image_unsorted = np.unique(image_data, return_index=True)
+        image_indicies = np.sort(image_unsorted)
+
     if keys:
         output["keys"] = np.array(keys) if as_numpy else keys
 
@@ -255,4 +275,4 @@ def merge_metadata(
         cv = _convert_type(merged[k])
         output[k] = np.array(cv) if as_numpy else cv
 
-    return output
+    return output, image_indicies
dataeval/utils/torch/gmm.py
ADDED
@@ -0,0 +1,98 @@
+"""
+Adapted for Pytorch from:
+
+Source code derived from Alibi-Detect 0.11.4
+https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
+
+Original code Copyright (c) 2023 Seldon Technologies Ltd
+Licensed under Apache Software License (Apache 2.0)
+"""
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+
+from dataeval.utils.gmm import GaussianMixtureModelParams
+
+
+def gmm_params(z: torch.Tensor, gamma: torch.Tensor) -> GaussianMixtureModelParams[torch.Tensor]:
+    """
+    Compute parameters of Gaussian Mixture Model.
+
+    Parameters
+    ----------
+    z : torch.Tensor
+        Observations.
+    gamma : torch.Tensor
+        Mixture probabilities to derive mixture distribution weights from.
+
+    Returns
+    -------
+    GaussianMixtureModelParams(phi, mu, cov, L, log_det_cov)
+        The parameters used to calculate energy.
+    """
+
+    # compute gmm parameters phi, mu and cov
+    N = gamma.shape[0]  # nb of samples in batch
+    sum_gamma = torch.sum(gamma, 0)  # K
+    phi = sum_gamma / N  # K
+    # K x D (D = latent_dim)
+    mu = torch.sum(torch.unsqueeze(gamma, -1) * torch.unsqueeze(z, 1), 0) / torch.unsqueeze(sum_gamma, -1)
+    z_mu = torch.unsqueeze(z, 1) - torch.unsqueeze(mu, 0)  # N x K x D
+    z_mu_outer = torch.unsqueeze(z_mu, -1) * torch.unsqueeze(z_mu, -2)  # N x K x D x D
+
+    # K x D x D
+    cov = torch.sum(torch.unsqueeze(torch.unsqueeze(gamma, -1), -1) * z_mu_outer, 0) / torch.unsqueeze(
+        torch.unsqueeze(sum_gamma, -1), -1
+    )
+
+    # cholesky decomposition of covariance and determinant derivation
+    D = cov.shape[1]
+    eps = 1e-6
+    L = torch.linalg.cholesky(cov + torch.eye(D) * eps)  # K x D x D
+    log_det_cov = 2.0 * torch.sum(torch.log(torch.diagonal(L, dim1=-2, dim2=-1)), 1)  # K
+
+    return GaussianMixtureModelParams(phi, mu, cov, L, log_det_cov)
+
+
+def gmm_energy(
+    z: torch.Tensor,
+    params: GaussianMixtureModelParams[torch.Tensor],
+    return_mean: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Compute sample energy from Gaussian Mixture Model.
+
+    Parameters
+    ----------
+    params : GaussianMixtureModelParams
+        The gaussian mixture model parameters.
+    return_mean : bool, default True
+        Take mean across all sample energies in a batch.
+
+    Returns
+    -------
+    sample_energy
+        The sample energy of the GMM.
+    cov_diag
+        The inverse sum of the diagonal components of the covariance matrix.
+    """
+    D = params.cov.shape[1]
+    z_mu = torch.unsqueeze(z, 1) - torch.unsqueeze(params.mu, 0)  # N x K x D
+    z_mu_T = torch.permute(z_mu, dims=[1, 2, 0])  # K x D x N
+    v = torch.linalg.solve_triangular(params.L, z_mu_T, upper=False)  # K x D x D
+
+    # rewrite sample energy in logsumexp format for numerical stability
+    logits = torch.log(torch.unsqueeze(params.phi, -1)) - 0.5 * (
+        torch.sum(torch.square(v), 1) + float(D) * np.log(2.0 * np.pi) + torch.unsqueeze(params.log_det_cov, -1)
+    )  # K x N
+    sample_energy = -torch.logsumexp(logits, 0)  # N
+
+    if return_mean:
+        sample_energy = torch.mean(sample_energy)
+
+    # inverse sum of variances
+    cov_diag = torch.sum(torch.divide(torch.tensor(1), torch.diagonal(params.cov, dim1=-2, dim2=-1)))
+
+    return sample_energy, cov_diag
dataeval/utils/torch/models.py
CHANGED
@@ -2,8 +2,10 @@ from __future__ import annotations
 
 __all__ = ["AriaAutoencoder", "Encoder", "Decoder"]
 
+import math
 from typing import Any
 
+import torch
 import torch.nn as nn
 
 
@@ -136,3 +138,193 @@ class Decoder(nn.Module):
         The reconstructed output tensor.
         """
         return self.decoder(x)
+
+
+class AE(nn.Module):
+    """
+    An autoencoder model with a separate encoder and decoder. Meant to replace the TensorFlow model called AE, which we
+    used as the core of an autoencoder-based OOD detector, i.e. as an argument to OOD_AE().
+
+    Parameters
+    ----------
+    input_shape : tuple[int, int, int]
+        Number of input channels, number of rows, number of columns. (Number of examples per batch will be inferred
+        at runtime.)
+    """
+
+    def __init__(self, input_shape: tuple[int, int, int]) -> None:
+        super().__init__()
+
+        input_dim = math.prod(input_shape)
+
+        # following is lifted from src/dataeval/utils/tensorflow/_internal/utils.py. It makes an odd staircase that is
+        # basically proportional to the number of numbers in the image to the 0.8 power.
+        encoding_dim = int(math.pow(2, int(input_dim.bit_length() * 0.8)))
+
+        self.encoder: Encoder_AE = Encoder_AE(input_shape, encoding_dim)
+
+        self.decoder: Decoder_AE = Decoder_AE(input_shape, encoding_dim, self.encoder.post_op_shape)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Perform a forward pass through the encoder and decoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor
+
+        Returns
+        -------
+        torch.Tensor
+            The reconstructed output tensor.
+        """
+        x = self.encoder(x)
+        x = self.decoder(x)
+        return x
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Encode the input tensor using the encoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor
+
+        Returns
+        -------
+        torch.Tensor
+            The encoded representation of the input tensor.
+        """
+        return self.encoder(x)
+
+
+class Encoder_AE(nn.Module):
+    """
+    A simple encoder to be used in an autoencoder model.
+
+    This is the encoder used to replicate AE, which was a TF function. It consists of a CNN followed by a fully
+    connected layer.
+
+    Parameters
+    ----------
+    input_shape : tuple[int, int, int]
+        Number of channels, number of rows, number of columns in input images.
+    encoding_dim : int
+        The size of the 1D array that emerges from the fully connected layer.
+    """
+
+    def __init__(
+        self,
+        input_shape: tuple[int, int, int],
+        encoding_dim: int,
+    ) -> None:
+        super().__init__()
+
+        channels = input_shape[0]
+        nc_in, nc_mid, nc_done = 256, 128, 64
+
+        conv_in = nn.Conv2d(channels, nc_in, 2, stride=1, padding=1)
+        conv_mid = nn.Conv2d(nc_in, nc_mid, 2, stride=1, padding=1)
+        conv_done = nn.Conv2d(nc_mid, nc_done, 2, stride=1)
+
+        self.encoding_ops: nn.Sequential = nn.Sequential(
+            conv_in,
+            nn.LeakyReLU(),
+            nn.MaxPool2d(2),
+            conv_mid,
+            nn.LeakyReLU(),
+            nn.MaxPool2d(2),
+            conv_done,
+        )
+
+        ny, nx = input_shape[1:]
+        self.post_op_shape: tuple[int, int, int] = (nc_done, ny // 4 - 1, nx // 4 - 1)
+        self.flatcon: int = math.prod(self.post_op_shape)
+        self.flatten: nn.Sequential = nn.Sequential(
+            nn.Flatten(),
+            nn.Linear(
+                self.flatcon,
+                encoding_dim,
+            ),
+        )
+
+    def forward(self, x: Any) -> Any:
+        """
+        Perform a forward pass through the AE_torch encoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor
+
+        Returns
+        -------
+        torch.Tensor
+            The encoded representation of the input tensor.
+        """
+        x = self.encoding_ops(x)
+
+        x = self.flatten(x)
+
+        return x
+
+
+class Decoder_AE(nn.Module):
+    """
+    A simple decoder to be used in an autoencoder model.
+
+    This is the decoder used to replicate the TF AE model.
+
+    Parameters
+    ----------
+    input_shape : tuple[int, int, int]
+        Number of channels, rows, and columns of the reconstructed output.
+    encoding_dim : int
+        The size of the 1D encoded array.
+    post_op_shape : tuple[int, int, int]
+        Shape of the encoder's convolutional output, before flattening.
+    """
+
+    def __init__(
+        self,
+        input_shape: tuple[int, int, int],
+        encoding_dim: int,
+        post_op_shape: tuple[int, int, int],
+    ) -> None:
+        super().__init__()
+
+        self.post_op_shape = post_op_shape
+        self.input_shape = input_shape  # need to store this for use in forward().
+        channels = input_shape[0]
+
+        self.input: nn.Linear = nn.Linear(encoding_dim, math.prod(post_op_shape))
+
+        self.decoder: nn.Sequential = nn.Sequential(
+            nn.ConvTranspose2d(64, 128, 2, stride=1),
+            nn.LeakyReLU(),
+            nn.ConvTranspose2d(128, 256, 2, stride=2),
+            nn.LeakyReLU(),
+            nn.ConvTranspose2d(256, channels, 2, stride=2),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Perform a forward pass through the decoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            The encoded tensor.
+
+        Returns
+        -------
+        torch.Tensor
+            The reconstructed output tensor.
+        """
+        x = self.input(x)
+        x = x.reshape((-1, *self.post_op_shape))
+        x = self.decoder(x)
+        x = x.reshape((-1, *self.input_shape))
+        return x