dataeval 0.73.1__py3-none-any.whl → 0.74.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. dataeval/__init__.py +3 -9
  2. dataeval/detectors/__init__.py +2 -10
  3. dataeval/detectors/drift/base.py +3 -3
  4. dataeval/detectors/drift/mmd.py +1 -1
  5. dataeval/detectors/drift/torch.py +1 -101
  6. dataeval/detectors/linters/clusterer.py +3 -3
  7. dataeval/detectors/linters/duplicates.py +4 -4
  8. dataeval/detectors/linters/outliers.py +4 -4
  9. dataeval/detectors/ood/__init__.py +9 -9
  10. dataeval/detectors/ood/{ae.py → ae_torch.py} +22 -27
  11. dataeval/detectors/ood/base.py +63 -113
  12. dataeval/detectors/ood/base_torch.py +109 -0
  13. dataeval/detectors/ood/metadata_ks_compare.py +52 -14
  14. dataeval/interop.py +1 -1
  15. dataeval/metrics/bias/__init__.py +3 -0
  16. dataeval/metrics/bias/balance.py +73 -70
  17. dataeval/metrics/bias/coverage.py +4 -4
  18. dataeval/metrics/bias/diversity.py +67 -136
  19. dataeval/metrics/bias/metadata_preprocessing.py +285 -0
  20. dataeval/metrics/bias/metadata_utils.py +229 -0
  21. dataeval/metrics/bias/parity.py +51 -161
  22. dataeval/metrics/estimators/ber.py +3 -3
  23. dataeval/metrics/estimators/divergence.py +3 -3
  24. dataeval/metrics/estimators/uap.py +3 -3
  25. dataeval/metrics/stats/base.py +2 -2
  26. dataeval/metrics/stats/boxratiostats.py +1 -1
  27. dataeval/metrics/stats/datasetstats.py +6 -6
  28. dataeval/metrics/stats/dimensionstats.py +1 -1
  29. dataeval/metrics/stats/hashstats.py +1 -1
  30. dataeval/metrics/stats/labelstats.py +3 -3
  31. dataeval/metrics/stats/pixelstats.py +1 -1
  32. dataeval/metrics/stats/visualstats.py +1 -1
  33. dataeval/output.py +77 -53
  34. dataeval/utils/__init__.py +1 -7
  35. dataeval/utils/gmm.py +26 -0
  36. dataeval/utils/metadata.py +29 -9
  37. dataeval/utils/torch/gmm.py +98 -0
  38. dataeval/utils/torch/models.py +192 -0
  39. dataeval/utils/torch/trainer.py +84 -5
  40. dataeval/utils/torch/utils.py +107 -1
  41. dataeval/workflows/sufficiency.py +4 -4
  42. {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/METADATA +3 -9
  43. dataeval-0.74.1.dist-info/RECORD +65 -0
  44. dataeval/detectors/ood/aegmm.py +0 -66
  45. dataeval/detectors/ood/llr.py +0 -302
  46. dataeval/detectors/ood/vae.py +0 -97
  47. dataeval/detectors/ood/vaegmm.py +0 -75
  48. dataeval/metrics/bias/metadata.py +0 -440
  49. dataeval/utils/lazy.py +0 -26
  50. dataeval/utils/tensorflow/__init__.py +0 -19
  51. dataeval/utils/tensorflow/_internal/gmm.py +0 -123
  52. dataeval/utils/tensorflow/_internal/loss.py +0 -121
  53. dataeval/utils/tensorflow/_internal/models.py +0 -1394
  54. dataeval/utils/tensorflow/_internal/trainer.py +0 -114
  55. dataeval/utils/tensorflow/_internal/utils.py +0 -256
  56. dataeval/utils/tensorflow/loss/__init__.py +0 -11
  57. dataeval-0.73.1.dist-info/RECORD +0 -73
  58. {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/LICENSE.txt +0 -0
  59. {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/WHEEL +0 -0
dataeval/metrics/stats/hashstats.py CHANGED
@@ -116,7 +116,7 @@ class HashStatsProcessor(StatsProcessor[HashStatsOutput]):
     }


-@set_metadata()
+@set_metadata
 def hashstats(
     images: Iterable[ArrayLike],
     bboxes: Iterable[ArrayLike] | None = None,
dataeval/metrics/stats/labelstats.py CHANGED
@@ -9,11 +9,11 @@ from typing import Any, Iterable, Mapping, TypeVar
 from numpy.typing import ArrayLike

 from dataeval.interop import to_numpy
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.output import Output, set_metadata


 @dataclass(frozen=True)
-class LabelStatsOutput(OutputMetadata):
+class LabelStatsOutput(Output):
     """
     Output class for :func:`labelstats` stats metric

@@ -57,7 +57,7 @@ def sort(d: Mapping[TKey, Any]) -> dict[TKey, Any]:
     return dict(sorted(d.items(), key=lambda x: x[0]))


-@set_metadata()
+@set_metadata
 def labelstats(
     labels: Iterable[ArrayLike],
 ) -> LabelStatsOutput:
dataeval/metrics/stats/pixelstats.py CHANGED
@@ -67,7 +67,7 @@ class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
     }


-@set_metadata()
+@set_metadata
 def pixelstats(
     images: Iterable[ArrayLike],
     bboxes: Iterable[ArrayLike] | None = None,
dataeval/metrics/stats/visualstats.py CHANGED
@@ -74,7 +74,7 @@ class VisualStatsProcessor(StatsProcessor[VisualStatsOutput]):
     }


-@set_metadata()
+@set_metadata
 def visualstats(
     images: Iterable[ArrayLike],
     bboxes: Iterable[ArrayLike] | None = None,
dataeval/output.py CHANGED
@@ -4,9 +4,10 @@ __all__ = []

 import inspect
 import sys
+from collections.abc import Mapping
 from datetime import datetime, timezone
-from functools import wraps
-from typing import Any, Callable, Iterable, TypeVar
+from functools import partial, wraps
+from typing import Any, Callable, Iterator, TypeVar

 import numpy as np

@@ -18,7 +19,7 @@ else:
     from dataeval import __version__


-class OutputMetadata:
+class Output:
     _name: str
     _execution_time: datetime
     _execution_duration: float
@@ -26,6 +27,9 @@ class OutputMetadata:
     _state: dict[str, str]
     _version: str

+    def __str__(self) -> str:
+        return f"{self.__class__.__name__}: {str(self.dict())}"
+
     def dict(self) -> dict[str, Any]:
         return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}

@@ -33,58 +37,78 @@
         return {k.removeprefix("_"): v for k, v in self.__dict__.items() if k.startswith("_")}


+TKey = TypeVar("TKey", str, int, float, set)
+TValue = TypeVar("TValue")
+
+
+class MappingOutput(Mapping[TKey, TValue], Output):
+    __slots__ = ["_data"]
+
+    def __init__(self, data: Mapping[TKey, TValue]):
+        self._data = data
+
+    def __getitem__(self, key: TKey) -> TValue:
+        return self._data.__getitem__(key)
+
+    def __iter__(self) -> Iterator[TKey]:
+        return self._data.__iter__()
+
+    def __len__(self) -> int:
+        return self._data.__len__()
+
+    def dict(self) -> dict[str, TValue]:
+        return {str(k): v for k, v in self._data.items()}
+
+
 P = ParamSpec("P")
-R = TypeVar("R", bound=OutputMetadata)
+R = TypeVar("R", bound=Output)


-def set_metadata(
-    state_attr: Iterable[str] | None = None,
-) -> Callable[[Callable[P, R]], Callable[P, R]]:
+def set_metadata(fn: Callable[P, R] | None = None, *, state: list[str] | None = None) -> Callable[P, R]:
     """Decorator to stamp OutputMetadata classes with runtime metadata"""

-    def decorator(fn: Callable[P, R]) -> Callable[P, R]:
-        @wraps(fn)
-        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
-            def fmt(v):
-                if np.isscalar(v):
-                    return v
-                if hasattr(v, "shape"):
-                    return f"{v.__class__.__name__}: shape={getattr(v, 'shape')}"
-                if hasattr(v, "__len__"):
-                    return f"{v.__class__.__name__}: len={len(v)}"
-                return f"{v.__class__.__name__}"
-
-            time = datetime.now(timezone.utc)
-            result = fn(*args, **kwargs)
-            duration = (datetime.now(timezone.utc) - time).total_seconds()
-            fn_params = inspect.signature(fn).parameters
-            # set all params with defaults then update params with mapped arguments and explicit keyword args
-            arguments = {k: None if v.default is inspect.Parameter.empty else v.default for k, v in fn_params.items()}
-            arguments.update(zip(fn_params, args))
-            arguments.update(kwargs)
-            arguments = {k: fmt(v) for k, v in arguments.items()}
-            state = (
-                {k: fmt(getattr(args[0], k)) for k in state_attr if "self" in arguments}
-                if "self" in arguments and state_attr
-                else {}
-            )
-            name = (
-                f"{args[0].__class__.__module__}.{args[0].__class__.__name__}.{fn.__name__}"
-                if "self" in arguments
-                else f"{fn.__module__}.{fn.__qualname__}"
-            )
-            metadata = {
-                "_name": name,
-                "_execution_time": time,
-                "_execution_duration": duration,
-                "_arguments": {k: v for k, v in arguments.items() if k != "self"},
-                "_state": state,
-                "_version": __version__,
-            }
-            for k, v in metadata.items():
-                object.__setattr__(result, k, v)
-            return result
-
-        return wrapper
-
-    return decorator
+    if fn is None:
+        return partial(set_metadata, state=state)  # type: ignore
+
+    @wraps(fn)
+    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+        def fmt(v):
+            if np.isscalar(v):
+                return v
+            if hasattr(v, "shape"):
+                return f"{v.__class__.__name__}: shape={getattr(v, 'shape')}"
+            if hasattr(v, "__len__"):
+                return f"{v.__class__.__name__}: len={len(v)}"
+            return f"{v.__class__.__name__}"
+
+        time = datetime.now(timezone.utc)
+        result = fn(*args, **kwargs)
+        duration = (datetime.now(timezone.utc) - time).total_seconds()
+        fn_params = inspect.signature(fn).parameters
+
+        # set all params with defaults then update params with mapped arguments and explicit keyword args
+        arguments = {k: None if v.default is inspect.Parameter.empty else v.default for k, v in fn_params.items()}
+        arguments.update(zip(fn_params, args))
+        arguments.update(kwargs)
+        arguments = {k: fmt(v) for k, v in arguments.items()}
+        state_attrs = (
+            {k: fmt(getattr(args[0], k)) for k in state if "self" in arguments} if "self" in arguments and state else {}
+        )
+        name = (
+            f"{args[0].__class__.__module__}.{args[0].__class__.__name__}.{fn.__name__}"
+            if "self" in arguments
+            else f"{fn.__module__}.{fn.__qualname__}"
+        )
+        metadata = {
+            "_name": name,
+            "_execution_time": time,
+            "_execution_duration": duration,
+            "_arguments": {k: v for k, v in arguments.items() if k != "self"},
+            "_state": state_attrs,
+            "_version": __version__,
+        }
+        for k, v in metadata.items():
+            object.__setattr__(result, k, v)
+        return result
+
+    return wrapper
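
The reworked decorator now accepts both a bare `@set_metadata` form and a parameterized `@set_metadata(state=[...])` form, thanks to the `functools.partial` re-entry above. A minimal sketch of the bare form, assuming `Output` and `set_metadata` are imported from dataeval.output; `MyOutput` and `compute` are hypothetical names:

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class MyOutput(Output):  # hypothetical Output subclass
        value: int

    @set_metadata  # bare form: the function is passed directly, no parentheses needed
    def compute(scale: float = 1.0) -> MyOutput:
        return MyOutput(value=int(scale * 42))

    result = compute(2.0)
    print(result)         # Output.__str__ -> "MyOutput: {'value': 84}"
    print(result.dict())  # {'value': 84}; stamped runtime metadata lives in _-prefixed fields

The parameterized form, `@set_metadata(state=["threshold"])` on a method, additionally records the named instance attributes into the `_state` field.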
dataeval/utils/__init__.py CHANGED
@@ -4,7 +4,7 @@ in setting up architectures that are guaranteed to work with applicable DataEval
 metrics. Currently DataEval supports both :term:`TensorFlow` and PyTorch backends.
 """

-from dataeval import _IS_TENSORFLOW_AVAILABLE, _IS_TORCH_AVAILABLE
+from dataeval import _IS_TORCH_AVAILABLE
 from dataeval.utils.metadata import merge_metadata
 from dataeval.utils.split_dataset import split_dataset

@@ -15,10 +15,4 @@ if _IS_TORCH_AVAILABLE:

     __all__ += ["torch"]

-if _IS_TENSORFLOW_AVAILABLE:
-    from dataeval.utils import tensorflow
-
-    __all__ += ["tensorflow"]
-
-del _IS_TENSORFLOW_AVAILABLE
 del _IS_TORCH_AVAILABLE
dataeval/utils/gmm.py ADDED
@@ -0,0 +1,26 @@
+from dataclasses import dataclass
+from typing import Generic, TypeVar
+
+TGMMData = TypeVar("TGMMData")
+
+
+@dataclass
+class GaussianMixtureModelParams(Generic[TGMMData]):
+    """
+    phi : TGMMData
+        Mixture component distribution weights.
+    mu : TGMMData
+        Mixture means.
+    cov : TGMMData
+        Mixture covariance.
+    L : TGMMData
+        Cholesky decomposition of `cov`.
+    log_det_cov : TGMMData
+        Log of the determinant of `cov`.
+    """
+
+    phi: TGMMData
+    mu: TGMMData
+    cov: TGMMData
+    L: TGMMData
+    log_det_cov: TGMMData
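
The dataclass is generic over the array type, so each backend can parameterize it with its own tensors; the new PyTorch module later in this diff uses GaussianMixtureModelParams[torch.Tensor]. A hypothetical NumPy-flavored instantiation, purely for illustration:

    import numpy as np

    params = GaussianMixtureModelParams[np.ndarray](
        phi=np.array([0.5, 0.5]),       # two equally weighted components
        mu=np.zeros((2, 3)),            # component means in a 3-dim latent space
        cov=np.stack([np.eye(3)] * 2),  # identity covariances
        L=np.stack([np.eye(3)] * 2),    # Cholesky factors of the covariances
        log_det_cov=np.zeros(2),        # log-determinant of the identity is 0
    )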
dataeval/utils/metadata.py CHANGED
@@ -131,7 +131,9 @@ def _flatten_dict_inner(
     return items, size


-def _flatten_dict(d: Mapping[str, Any], sep: str, ignore_lists: bool, fully_qualified: bool) -> dict[str, Any]:
+def _flatten_dict(
+    d: Mapping[str, Any], sep: str, ignore_lists: bool, fully_qualified: bool
+) -> tuple[dict[str, Any], int]:
     """
     Flattens a dictionary and converts values to numeric values when possible.

@@ -165,7 +167,7 @@ def _flatten_dict(d: Mapping[str, Any], sep: str, ignore_lists: bool, fully_qual
             output[k] = cv
         elif not isinstance(cv, list):
             output[k] = cv if not size else [cv] * size
-    return output
+    return output, size if size is not None else 1


 def _is_metadata_dict_of_dicts(metadata: Mapping) -> bool:
@@ -188,7 +190,7 @@ def merge_metadata(
     ignore_lists: bool = False,
     fully_qualified: bool = False,
     as_numpy: bool = False,
-) -> dict[str, list[Any]] | dict[str, NDArray[Any]]:
+) -> tuple[dict[str, list[Any]] | dict[str, NDArray[Any]], NDArray[np.int_]]:
     """
     Merges a collection of metadata dictionaries into a single flattened dictionary of keys and values.

@@ -208,8 +210,10 @@

     Returns
     -------
-    dict[str, list[Any]] | dict[str, NDArray[Any]]
+    dict[str, list[Any]] or dict[str, NDArray[Any]]
         A single dictionary containing the flattened data as lists or NumPy arrays
+    NDArray[np.int_]
+        Array defining where individual images start, helpful when working with object detection metadata

     Note
     ----
@@ -217,9 +221,12 @@

     Example
     -------
-    >>> list_metadata = [{"common": 1, "target": [{"a": 1, "b": 3}, {"a": 2, "b": 4}], "source": "example"}]
-    >>> merge_metadata(list_metadata)
+    >>> list_metadata = [{"common": 1, "target": [{"a": 1, "b": 3, "c": 5}, {"a": 2, "b": 4}], "source": "example"}]
+    >>> reorganized_metadata, image_indicies = merge_metadata(list_metadata)
+    >>> reorganized_metadata
     {'common': [1, 1], 'a': [1, 2], 'b': [3, 4], 'source': ['example', 'example']}
+    >>> image_indicies
+    array([0])
     """
     merged: dict[str, list[Any]] = {}
     isect: set[str] = set()
@@ -236,8 +243,11 @@
     else:
         dicts = list(metadata)

-    for d in dicts:
-        flattened = _flatten_dict(d, sep="_", ignore_lists=ignore_lists, fully_qualified=fully_qualified)
+    image_repeats = np.zeros(len(dicts))
+    for i, d in enumerate(dicts):
+        flattened, image_repeats[i] = _flatten_dict(
+            d, sep="_", ignore_lists=ignore_lists, fully_qualified=fully_qualified
+        )
         isect = isect.intersection(flattened.keys()) if isect else set(flattened.keys())
         union = union.union(flattened.keys())
         for k, v in flattened.items():
@@ -248,6 +258,16 @@

     output: dict[str, Any] = {}

+    if image_repeats.sum() == image_repeats.size:
+        image_indicies = np.arange(image_repeats.size)
+    else:
+        image_ids = np.arange(image_repeats.size)
+        image_data = np.concatenate(
+            [np.repeat(image_ids[i], image_repeats[i]) for i in range(image_ids.size)], dtype=np.int_
+        )
+        _, image_unsorted = np.unique(image_data, return_index=True)
+        image_indicies = np.sort(image_unsorted)
+
     if keys:
         output["keys"] = np.array(keys) if as_numpy else keys

@@ -255,4 +275,4 @@
         cv = _convert_type(merged[k])
         output[k] = np.array(cv) if as_numpy else cv

-    return output
+    return output, image_indicies
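
The second return value matters mostly for object detection metadata, where one image contributes several detections. A hedged sketch of the behavior, assuming two images whose "target" lists hold 2 and 3 detections respectively (hypothetical data):

    od_metadata = [
        {"id": 0, "target": [{"a": 1}, {"a": 2}]},
        {"id": 1, "target": [{"a": 3}, {"a": 4}, {"a": 5}]},
    ]
    merged, image_indicies = merge_metadata(od_metadata)
    # Per-image values are repeated once per detection:
    #   merged == {'id': [0, 0, 1, 1, 1], 'a': [1, 2, 3, 4, 5]}
    # and image_indicies marks the row where each image's detections begin:
    #   image_indicies == array([0, 2])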
dataeval/utils/torch/gmm.py ADDED
@@ -0,0 +1,98 @@
+"""
+Adapted for PyTorch from:
+
+Source code derived from Alibi-Detect 0.11.4
+https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
+
+Original code Copyright (c) 2023 Seldon Technologies Ltd
+Licensed under Apache Software License (Apache 2.0)
+"""
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+
+from dataeval.utils.gmm import GaussianMixtureModelParams
+
+
+def gmm_params(z: torch.Tensor, gamma: torch.Tensor) -> GaussianMixtureModelParams[torch.Tensor]:
+    """
+    Compute parameters of Gaussian Mixture Model.
+
+    Parameters
+    ----------
+    z : torch.Tensor
+        Observations.
+    gamma : torch.Tensor
+        Mixture probabilities to derive mixture distribution weights from.
+
+    Returns
+    -------
+    GaussianMixtureModelParams(phi, mu, cov, L, log_det_cov)
+        The parameters used to calculate energy.
+    """
+
+    # compute gmm parameters phi, mu and cov
+    N = gamma.shape[0]  # nb of samples in batch
+    sum_gamma = torch.sum(gamma, 0)  # K
+    phi = sum_gamma / N  # K
+    # K x D (D = latent_dim)
+    mu = torch.sum(torch.unsqueeze(gamma, -1) * torch.unsqueeze(z, 1), 0) / torch.unsqueeze(sum_gamma, -1)
+    z_mu = torch.unsqueeze(z, 1) - torch.unsqueeze(mu, 0)  # N x K x D
+    z_mu_outer = torch.unsqueeze(z_mu, -1) * torch.unsqueeze(z_mu, -2)  # N x K x D x D
+
+    # K x D x D
+    cov = torch.sum(torch.unsqueeze(torch.unsqueeze(gamma, -1), -1) * z_mu_outer, 0) / torch.unsqueeze(
+        torch.unsqueeze(sum_gamma, -1), -1
+    )
+
+    # cholesky decomposition of covariance and determinant derivation
+    D = cov.shape[1]
+    eps = 1e-6
+    L = torch.linalg.cholesky(cov + torch.eye(D) * eps)  # K x D x D
+    log_det_cov = 2.0 * torch.sum(torch.log(torch.diagonal(L, dim1=-2, dim2=-1)), 1)  # K
+
+    return GaussianMixtureModelParams(phi, mu, cov, L, log_det_cov)
+
+
+def gmm_energy(
+    z: torch.Tensor,
+    params: GaussianMixtureModelParams[torch.Tensor],
+    return_mean: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Compute sample energy from Gaussian Mixture Model.
+
+    Parameters
+    ----------
+    params : GaussianMixtureModelParams
+        The gaussian mixture model parameters.
+    return_mean : bool, default True
+        Take mean across all sample energies in a batch.
+
+    Returns
+    -------
+    sample_energy
+        The sample energy of the GMM.
+    cov_diag
+        The inverse sum of the diagonal components of the covariance matrix.
+    """
+    D = params.cov.shape[1]
+    z_mu = torch.unsqueeze(z, 1) - torch.unsqueeze(params.mu, 0)  # N x K x D
+    z_mu_T = torch.permute(z_mu, dims=[1, 2, 0])  # K x D x N
+    v = torch.linalg.solve_triangular(params.L, z_mu_T, upper=False)  # K x D x N
+
+    # rewrite sample energy in logsumexp format for numerical stability
+    logits = torch.log(torch.unsqueeze(params.phi, -1)) - 0.5 * (
+        torch.sum(torch.square(v), 1) + float(D) * np.log(2.0 * np.pi) + torch.unsqueeze(params.log_det_cov, -1)
+    )  # K x N
+    sample_energy = -torch.logsumexp(logits, 0)  # N
+
+    if return_mean:
+        sample_energy = torch.mean(sample_energy)
+
+    # inverse sum of variances
+    cov_diag = torch.sum(torch.divide(torch.tensor(1), torch.diagonal(params.cov, dim1=-2, dim2=-1)))
+
+    return sample_energy, cov_diag
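
A small usage sketch (not from the package docs): derive mixture parameters from soft component assignments, then score per-sample energies. The random inputs are placeholders:

    import torch

    N, K, D = 32, 4, 8                               # batch size, mixture components, latent dim
    z = torch.randn(N, D)                            # latent observations
    gamma = torch.softmax(torch.randn(N, K), dim=1)  # soft mixture memberships

    params = gmm_params(z, gamma)
    energy, cov_diag = gmm_energy(z, params, return_mean=True)
    print(energy.item(), cov_diag.item())            # mean sample energy, covariance penalty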
dataeval/utils/torch/models.py CHANGED
@@ -2,8 +2,10 @@ from __future__ import annotations

 __all__ = ["AriaAutoencoder", "Encoder", "Decoder"]

+import math
 from typing import Any

+import torch
 import torch.nn as nn


@@ -136,3 +138,193 @@ class Decoder(nn.Module):
             The reconstructed output tensor.
         """
         return self.decoder(x)
+
+
+class AE(nn.Module):
+    """
+    An autoencoder model with a separate encoder and decoder. Meant to replace the TensorFlow model called AE, which we
+    used as the core of an autoencoder-based OOD detector, i.e. as an argument to OOD_AE().
+
+    Parameters
+    ----------
+    input_shape : tuple[int, int, int]
+        Number of input channels, number of rows, number of columns. (Number of examples per batch will be inferred
+        at runtime.)
+    """
+
+    def __init__(self, input_shape: tuple[int, int, int]) -> None:
+        super().__init__()
+
+        input_dim = math.prod(input_shape)
+
+        # following is lifted from src/dataeval/utils/tensorflow/_internal/utils.py. It makes an odd staircase that is
+        # basically proportional to the number of numbers in the image to the 0.8 power.
+        encoding_dim = int(math.pow(2, int(input_dim.bit_length() * 0.8)))
+
+        self.encoder: Encoder_AE = Encoder_AE(input_shape, encoding_dim)
+
+        self.decoder: Decoder_AE = Decoder_AE(input_shape, encoding_dim, self.encoder.post_op_shape)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Perform a forward pass through the encoder and decoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor
+
+        Returns
+        -------
+        torch.Tensor
+            The reconstructed output tensor.
+        """
+        x = self.encoder(x)
+        x = self.decoder(x)
+        return x
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Encode the input tensor using the encoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor
+
+        Returns
+        -------
+        torch.Tensor
+            The encoded representation of the input tensor.
+        """
+        return self.encoder(x)
+
+
+class Encoder_AE(nn.Module):
+    """
+    A simple encoder to be used in an autoencoder model.
+
+    This is the encoder used to replicate AE, which was a TF function. It consists of a CNN followed by a fully
+    connected layer.
+
+    Parameters
+    ----------
+    input_shape : tuple[int, int, int]
+        Number of channels, number of rows, number of columns in input images.
+    encoding_dim : int
+        The size of the 1D array that emerges from the fully connected layer.
+    """
+
+    def __init__(
+        self,
+        input_shape: tuple[int, int, int],
+        encoding_dim: int,
+    ) -> None:
+        super().__init__()
+
+        channels = input_shape[0]
+        nc_in, nc_mid, nc_done = 256, 128, 64
+
+        conv_in = nn.Conv2d(channels, nc_in, 2, stride=1, padding=1)
+        conv_mid = nn.Conv2d(nc_in, nc_mid, 2, stride=1, padding=1)
+        conv_done = nn.Conv2d(nc_mid, nc_done, 2, stride=1)
+
+        self.encoding_ops: nn.Sequential = nn.Sequential(
+            conv_in,
+            nn.LeakyReLU(),
+            nn.MaxPool2d(2),
+            conv_mid,
+            nn.LeakyReLU(),
+            nn.MaxPool2d(2),
+            conv_done,
+        )
+
+        ny, nx = input_shape[1:]
+        self.post_op_shape: tuple[int, int, int] = (nc_done, ny // 4 - 1, nx // 4 - 1)
+        self.flatcon: int = math.prod(self.post_op_shape)
+        self.flatten: nn.Sequential = nn.Sequential(
+            nn.Flatten(),
+            nn.Linear(
+                self.flatcon,
+                encoding_dim,
+            ),
+        )
+
+    def forward(self, x: Any) -> Any:
+        """
+        Perform a forward pass through the AE_torch encoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor
+
+        Returns
+        -------
+        torch.Tensor
+            The encoded representation of the input tensor.
+        """
+        x = self.encoding_ops(x)
+
+        x = self.flatten(x)
+
+        return x
+
+
+class Decoder_AE(nn.Module):
+    """
+    A simple decoder to be used in an autoencoder model.
+
+    This is the decoder used to replicate the TensorFlow AE.
+
+    Parameters
+    ----------
+    input_shape : tuple[int, int, int]
+        Number of channels, rows, and columns of the reconstructed output.
+    encoding_dim : int
+        Size of the 1D encoded representation.
+    post_op_shape : tuple[int, int, int]
+        Shape of the encoder's final convolutional output, used to un-flatten.
+    """
+
+    def __init__(
+        self,
+        input_shape: tuple[int, int, int],
+        encoding_dim: int,
+        post_op_shape: tuple[int, int, int],
+    ) -> None:
+        super().__init__()
+
+        self.post_op_shape = post_op_shape
+        self.input_shape = input_shape  # need to store this for use in forward().
+        channels = input_shape[0]
+
+        self.input: nn.Linear = nn.Linear(encoding_dim, math.prod(post_op_shape))
+
+        self.decoder: nn.Sequential = nn.Sequential(
+            nn.ConvTranspose2d(64, 128, 2, stride=1),
+            nn.LeakyReLU(),
+            nn.ConvTranspose2d(128, 256, 2, stride=2),
+            nn.LeakyReLU(),
+            nn.ConvTranspose2d(256, channels, 2, stride=2),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Perform a forward pass through the decoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            The encoded tensor.
+
+        Returns
+        -------
+        torch.Tensor
+            The reconstructed output tensor.
+        """
+        x = self.input(x)
+        x = x.reshape((-1, *self.post_op_shape))
+        x = self.decoder(x)
+        x = x.reshape((-1, *self.input_shape))
+        return x
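
A quick shape check of the new model (illustrative, not from the package): for 3x32x32 inputs, post_op_shape works out to (64, 7, 7) and encoding_dim to 2**int(12 * 0.8) = 512.

    import torch

    model = AE(input_shape=(3, 32, 32))
    batch = torch.randn(16, 3, 32, 32)
    reconstructed = model(batch)     # (16, 3, 32, 32): decoder reshapes back to input_shape
    embedding = model.encode(batch)  # (16, 512): flat encoding, roughly input_dim ** 0.8
    print(reconstructed.shape, embedding.shape)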