dataeval 0.74.2__py3-none-any.whl → 0.75.0__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. dataeval/__init__.py +27 -23
  2. dataeval/detectors/__init__.py +2 -2
  3. dataeval/detectors/drift/__init__.py +14 -12
  4. dataeval/detectors/drift/base.py +1 -1
  5. dataeval/detectors/drift/cvm.py +1 -1
  6. dataeval/detectors/drift/ks.py +1 -1
  7. dataeval/detectors/drift/mmd.py +6 -5
  8. dataeval/detectors/drift/torch.py +12 -12
  9. dataeval/detectors/drift/uncertainty.py +3 -2
  10. dataeval/detectors/linters/__init__.py +4 -4
  11. dataeval/detectors/linters/clusterer.py +2 -7
  12. dataeval/detectors/linters/duplicates.py +6 -10
  13. dataeval/detectors/linters/outliers.py +4 -2
  14. dataeval/detectors/ood/__init__.py +3 -10
  15. dataeval/detectors/ood/{ae_torch.py → ae.py} +6 -4
  16. dataeval/detectors/ood/base.py +64 -161
  17. dataeval/detectors/ood/metadata_ks_compare.py +34 -42
  18. dataeval/detectors/ood/metadata_least_likely.py +3 -3
  19. dataeval/detectors/ood/metadata_ood_mi.py +6 -5
  20. dataeval/detectors/ood/mixin.py +146 -0
  21. dataeval/detectors/ood/output.py +63 -0
  22. dataeval/interop.py +6 -5
  23. dataeval/{logging.py → log.py} +2 -0
  24. dataeval/metrics/__init__.py +2 -2
  25. dataeval/metrics/bias/__init__.py +9 -12
  26. dataeval/metrics/bias/balance.py +10 -8
  27. dataeval/metrics/bias/coverage.py +52 -4
  28. dataeval/metrics/bias/diversity.py +42 -14
  29. dataeval/metrics/bias/parity.py +15 -12
  30. dataeval/metrics/estimators/__init__.py +2 -2
  31. dataeval/metrics/estimators/ber.py +3 -1
  32. dataeval/metrics/estimators/divergence.py +1 -1
  33. dataeval/metrics/estimators/uap.py +1 -1
  34. dataeval/metrics/stats/__init__.py +18 -18
  35. dataeval/metrics/stats/base.py +4 -4
  36. dataeval/metrics/stats/boxratiostats.py +8 -9
  37. dataeval/metrics/stats/datasetstats.py +10 -14
  38. dataeval/metrics/stats/dimensionstats.py +4 -4
  39. dataeval/metrics/stats/hashstats.py +12 -8
  40. dataeval/metrics/stats/labelstats.py +5 -5
  41. dataeval/metrics/stats/pixelstats.py +4 -9
  42. dataeval/metrics/stats/visualstats.py +4 -9
  43. dataeval/utils/__init__.py +4 -13
  44. dataeval/utils/dataset/__init__.py +7 -0
  45. dataeval/utils/{torch → dataset}/datasets.py +2 -0
  46. dataeval/utils/dataset/read.py +63 -0
  47. dataeval/utils/{split_dataset.py → dataset/split.py} +38 -30
  48. dataeval/utils/image.py +2 -2
  49. dataeval/utils/metadata.py +310 -5
  50. dataeval/{metrics/bias/metadata_utils.py → utils/plot.py} +1 -104
  51. dataeval/utils/torch/__init__.py +2 -17
  52. dataeval/utils/torch/gmm.py +29 -6
  53. dataeval/utils/torch/{utils.py → internal.py} +82 -58
  54. dataeval/utils/torch/models.py +10 -8
  55. dataeval/utils/torch/trainer.py +6 -85
  56. dataeval/workflows/__init__.py +2 -5
  57. dataeval/workflows/sufficiency.py +16 -6
  58. dataeval-0.75.0.dist-info/METADATA +136 -0
  59. dataeval-0.75.0.dist-info/RECORD +67 -0
  60. dataeval/detectors/ood/base_torch.py +0 -109
  61. dataeval/metrics/bias/metadata_preprocessing.py +0 -285
  62. dataeval/utils/gmm.py +0 -26
  63. dataeval-0.74.2.dist-info/METADATA +0 -120
  64. dataeval-0.74.2.dist-info/RECORD +0 -66
  65. {dataeval-0.74.2.dist-info → dataeval-0.75.0.dist-info}/LICENSE.txt +0 -0
  66. {dataeval-0.74.2.dist-info → dataeval-0.75.0.dist-info}/WHEEL +0 -0
dataeval/__init__.py CHANGED
@@ -1,36 +1,40 @@
-__version__ = "0.74.2"
+"""
+DataEval provides a simple interface to characterize image data and its impact on model performance
+across classification and object-detection tasks. It also provides capabilities to select and curate
+datasets to test and train performant, robust, unbiased and reliable AI models and monitor for data
+shifts that impact performance of deployed models.
+"""
+
+from __future__ import annotations
+
+__all__ = ["detectors", "log", "metrics", "utils", "workflows"]
+__version__ = "0.75.0"
 
 import logging
-from importlib.util import find_spec
+
+from dataeval import detectors, metrics, utils, workflows
 
 logging.getLogger(__name__).addHandler(logging.NullHandler())
 
 
-def log_stderr(level: int = logging.DEBUG) -> None:
+def log(level: int = logging.DEBUG, handler: logging.Handler | None = None) -> None:
     """
-    Helper for quickly adding a StreamHandler to the logger. Useful for
-    debugging.
+    Helper for quickly adding a StreamHandler to the logger. Useful for debugging.
+
+    Parameters
+    ----------
+    level : int, default logging.DEBUG(10)
+        Set the logging level for the logger
+    handler : logging.Handler, optional
+        Sets the logging handler for the logger if provided, otherwise logger will be
+        provided with a StreamHandler
     """
     import logging
 
     logger = logging.getLogger(__name__)
-    handler = logging.StreamHandler()
-    handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
+    if handler is None:
+        handler = logging.StreamHandler() if handler is None else handler
+        handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
     logger.addHandler(handler)
     logger.setLevel(level)
-    logger.debug("Added a stderr logging handler to logger: %s", __name__)
-
-
-_IS_TORCH_AVAILABLE = find_spec("torch") is not None
-_IS_TORCHVISION_AVAILABLE = find_spec("torchvision") is not None
-
-del find_spec
-
-from dataeval import detectors, metrics  # noqa: E402
-
-__all__ = ["log_stderr", "detectors", "metrics"]
-
-if _IS_TORCH_AVAILABLE:
-    from dataeval import utils, workflows
-
-    __all__ += ["utils", "workflows"]
+    logger.debug(f"Added logging handler {handler} to logger: {__name__}")
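Note: the renamed helper keeps its old stderr behavior when called bare, and the new `handler` parameter lets callers supply their own sink. A minimal usage sketch based on the signature above (the file-handler target is hypothetical):

    import logging

    import dataeval

    # Default: attach a stderr StreamHandler at INFO level
    dataeval.log(logging.INFO)

    # Or route DataEval logs to a handler of your choosing
    dataeval.log(logging.DEBUG, logging.FileHandler("dataeval.log"))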
dataeval/detectors/__init__.py CHANGED
@@ -2,6 +2,6 @@
 Detectors can determine if a dataset or individual images in a dataset are indicative of a specific issue.
 """
 
-from dataeval.detectors import drift, linters, ood
-
 __all__ = ["drift", "linters", "ood"]
+
+from dataeval.detectors import drift, linters, ood
dataeval/detectors/drift/__init__.py CHANGED
@@ -2,19 +2,21 @@
 :term:`Drift` detectors identify if the statistical properties of the data has changed.
 """
 
-from dataeval import _IS_TORCH_AVAILABLE
+__all__ = [
+    "DriftCVM",
+    "DriftKS",
+    "DriftMMD",
+    "DriftMMDOutput",
+    "DriftOutput",
+    "DriftUncertainty",
+    "preprocess_drift",
+    "updates",
+]
+
 from dataeval.detectors.drift import updates
 from dataeval.detectors.drift.base import DriftOutput
 from dataeval.detectors.drift.cvm import DriftCVM
 from dataeval.detectors.drift.ks import DriftKS
-
-__all__ = ["DriftCVM", "DriftKS", "DriftOutput", "updates"]
-
-if _IS_TORCH_AVAILABLE:
-    from dataeval.detectors.drift.mmd import DriftMMD, DriftMMDOutput
-    from dataeval.detectors.drift.torch import preprocess_drift
-    from dataeval.detectors.drift.uncertainty import DriftUncertainty
-
-    __all__ += ["DriftMMD", "DriftMMDOutput", "DriftUncertainty", "preprocess_drift"]
-
-del _IS_TORCH_AVAILABLE
+from dataeval.detectors.drift.mmd import DriftMMD, DriftMMDOutput
+from dataeval.detectors.drift.torch import preprocess_drift
+from dataeval.detectors.drift.uncertainty import DriftUncertainty
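Note: with the `_IS_TORCH_AVAILABLE` gate removed, every drift symbol is importable unconditionally; torch is now a hard dependency of these modules. A sketch of what downstream imports look like in 0.75.0:

    # All names resolve without an availability check in 0.75.0
    from dataeval.detectors.drift import (
        DriftCVM,
        DriftKS,
        DriftMMD,
        DriftUncertainty,
        preprocess_drift,
    )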
dataeval/detectors/drift/base.py CHANGED
@@ -8,7 +8,7 @@ Licensed under Apache Software License (Apache 2.0)
 
 from __future__ import annotations
 
-__all__ = ["DriftOutput"]
+__all__ = []
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
dataeval/detectors/drift/cvm.py CHANGED
@@ -8,7 +8,7 @@ Licensed under Apache Software License (Apache 2.0)
 
 from __future__ import annotations
 
-__all__ = ["DriftCVM"]
+__all__ = []
 
 from typing import Callable, Literal
 
dataeval/detectors/drift/ks.py CHANGED
@@ -8,7 +8,7 @@ Licensed under Apache Software License (Apache 2.0)
 
 from __future__ import annotations
 
-__all__ = ["DriftKS"]
+__all__ = []
 
 from typing import Callable, Literal
 
dataeval/detectors/drift/mmd.py CHANGED
@@ -8,7 +8,7 @@ Licensed under Apache Software License (Apache 2.0)
 
 from __future__ import annotations
 
-__all__ = ["DriftMMD", "DriftMMDOutput"]
+__all__ = []
 
 from dataclasses import dataclass
 from typing import Callable
@@ -17,9 +17,10 @@ import torch
 from numpy.typing import ArrayLike
 
 from dataeval.detectors.drift.base import BaseDrift, DriftBaseOutput, UpdateStrategy, preprocess_x, update_x_ref
-from dataeval.detectors.drift.torch import _GaussianRBF, _mmd2_from_kernel_matrix, get_device
+from dataeval.detectors.drift.torch import GaussianRBF, mmd2_from_kernel_matrix
 from dataeval.interop import as_numpy
 from dataeval.output import set_metadata
+from dataeval.utils.torch.internal import get_device
 
 
 @dataclass(frozen=True)
@@ -109,7 +110,7 @@ class DriftMMD(BaseDrift):
 
         # initialize kernel
         sigma_tensor = torch.from_numpy(as_numpy(sigma)).to(self.device) if sigma is not None else None
-        self._kernel = _GaussianRBF(sigma_tensor).to(self.device)
+        self._kernel = GaussianRBF(sigma_tensor).to(self.device)
 
         # compute kernel matrix for the reference data
         if self._infer_sigma or isinstance(sigma_tensor, torch.Tensor):
@@ -150,9 +151,9 @@ class DriftMMD(BaseDrift):
         n = x.shape[0]
         kernel_mat = self._kernel_matrix(x_ref, torch.from_numpy(x).to(self.device))
         kernel_mat = kernel_mat - torch.diag(kernel_mat.diag())  # zero diagonal
-        mmd2 = _mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False)
+        mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False)
         mmd2_permuted = torch.Tensor(
-            [_mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)]
+            [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)]
         )
         mmd2, mmd2_permuted = mmd2.detach().cpu(), mmd2_permuted.detach().cpu()
         p_val = (mmd2 <= mmd2_permuted).float().mean()
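Note: the renames do not alter the statistic itself; the p-value is still the fraction of permuted MMD² values at least as large as the observed one. A standalone sketch of that permutation test using the now-public helper (`kernel_mat` and `n` are assumed to be prepared as in `DriftMMD` above):

    import torch

    from dataeval.detectors.drift.torch import mmd2_from_kernel_matrix

    n_permutations = 100
    # kernel_mat: kernel matrix over reference and test samples; n: test sample count
    mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False)
    permuted = torch.Tensor(
        [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(n_permutations)]
    )
    p_val = (mmd2 <= permuted).float().mean()  # share of permutations at least as extreme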
dataeval/detectors/drift/torch.py CHANGED
@@ -17,10 +17,10 @@ import torch
 import torch.nn as nn
 from numpy.typing import NDArray
 
-from dataeval.utils.torch.utils import get_device, predict_batch
+from dataeval.utils.torch.internal import get_device, predict_batch
 
 
-def _mmd2_from_kernel_matrix(
+def mmd2_from_kernel_matrix(
     kernel_mat: torch.Tensor, m: int, permute: bool = False, zero_diag: bool = True
 ) -> torch.Tensor:
     """
@@ -127,7 +127,7 @@ def _squared_pairwise_distance(
 
 def sigma_median(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.Tensor:
     """
-    Bandwidth estimation using the median heuristic :cite:t:`Gretton2012`.
+    Bandwidth estimation using the median heuristic `Gretton2012`
 
     Parameters
     ----------
@@ -151,7 +151,7 @@ def sigma_median(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.Tensor:
     return sigma
 
 
-class _GaussianRBF(nn.Module):
+class GaussianRBF(nn.Module):
     """
     Gaussian RBF kernel: k(x,y) = exp(-(1/(2*sigma^2)||x-y||^2).
 
@@ -179,18 +179,18 @@ class _GaussianRBF(nn.Module):
     ) -> None:
         super().__init__()
         init_sigma_fn = sigma_median if init_sigma_fn is None else init_sigma_fn
-        self.config = {
+        self.config: dict[str, Any] = {
             "sigma": sigma,
             "trainable": trainable,
             "init_sigma_fn": init_sigma_fn,
         }
         if sigma is None:
-            self.log_sigma = nn.Parameter(torch.empty(1), requires_grad=trainable)
-            self.init_required = True
+            self.log_sigma: nn.Parameter = nn.Parameter(torch.empty(1), requires_grad=trainable)
+            self.init_required: bool = True
         else:
             sigma = sigma.reshape(-1)  # [Ns,]
-            self.log_sigma = nn.Parameter(sigma.log(), requires_grad=trainable)
-            self.init_required = False
+            self.log_sigma: nn.Parameter = nn.Parameter(sigma.log(), requires_grad=trainable)
+            self.init_required: bool = False
         self.init_sigma_fn = init_sigma_fn
         self.trainable = trainable
@@ -200,8 +200,8 @@ class _GaussianRBF(nn.Module):
 
     def forward(
         self,
-        x: np.ndarray | torch.Tensor,
-        y: np.ndarray | torch.Tensor,
+        x: np.ndarray[Any, Any] | torch.Tensor,
+        y: np.ndarray[Any, Any] | torch.Tensor,
         infer_sigma: bool = False,
     ) -> torch.Tensor:
         x, y = torch.as_tensor(x), torch.as_tensor(y)
@@ -213,7 +213,7 @@ class _GaussianRBF(nn.Module):
             sigma = self.init_sigma_fn(x, y, dist)
             with torch.no_grad():
                 self.log_sigma.copy_(sigma.log().clone())
-            self.init_required = False
+            self.init_required: bool = False
 
         gamma = 1.0 / (2.0 * self.sigma**2)  # [Ns,]
         # TODO: do matrix multiplication after all?
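Note: the kernel math is untouched by the `_GaussianRBF` → `GaussianRBF` rename. For a single pair of flattened tensors it reduces to the docstring's formula, k(x,y) = exp(-||x-y||² / (2σ²)); an illustrative scalar sketch (not the class's batched, multi-bandwidth implementation):

    import torch

    def rbf_kernel(x: torch.Tensor, y: torch.Tensor, sigma: float) -> torch.Tensor:
        # k(x, y) = exp(-||x - y||^2 / (2 * sigma^2))
        return torch.exp(-torch.sum((x - y) ** 2) / (2.0 * sigma**2))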
dataeval/detectors/drift/uncertainty.py CHANGED
@@ -8,7 +8,7 @@ Licensed under Apache Software License (Apache 2.0)
 
 from __future__ import annotations
 
-__all__ = ["DriftUncertainty"]
+__all__ = []
 
 from functools import partial
 from typing import Callable, Literal
@@ -20,7 +20,8 @@ from scipy.stats import entropy
 
 from dataeval.detectors.drift.base import DriftOutput, UpdateStrategy
 from dataeval.detectors.drift.ks import DriftKS
-from dataeval.detectors.drift.torch import get_device, preprocess_drift
+from dataeval.detectors.drift.torch import preprocess_drift
+from dataeval.utils.torch.internal import get_device
 
 
 def classifier_uncertainty(
dataeval/detectors/linters/__init__.py CHANGED
@@ -2,10 +2,6 @@
 Linters help identify potential issues in training and test data and are an important aspect of data cleaning.
 """
 
-from dataeval.detectors.linters.clusterer import Clusterer, ClustererOutput
-from dataeval.detectors.linters.duplicates import Duplicates, DuplicatesOutput
-from dataeval.detectors.linters.outliers import Outliers, OutliersOutput
-
 __all__ = [
     "Clusterer",
     "ClustererOutput",
@@ -14,3 +10,7 @@ __all__ = [
     "Outliers",
     "OutliersOutput",
 ]
+
+from dataeval.detectors.linters.clusterer import Clusterer, ClustererOutput
+from dataeval.detectors.linters.duplicates import Duplicates, DuplicatesOutput
+from dataeval.detectors.linters.outliers import Outliers, OutliersOutput
dataeval/detectors/linters/clusterer.py CHANGED
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-__all__ = ["ClustererOutput", "Clusterer"]
+__all__ = []
 
 from dataclasses import dataclass
 from typing import Any, Iterable, NamedTuple, cast
@@ -147,12 +147,6 @@ class Clusterer:
     ----
     The Clusterer works best when the length of the feature dimension, P, is less than 500.
     If flattening a CxHxW image results in a dimension larger than 500, then it is recommended to reduce the dimensions.
-
-    Example
-    -------
-    Initialize the Clusterer class:
-
-    >>> cluster = Clusterer(dataset)
     """
 
     def __init__(self, dataset: ArrayLike) -> None:
@@ -506,6 +500,7 @@ class Clusterer:
 
         Example
        -------
+        >>> cluster = Clusterer(clusterer_images)
         >>> cluster.evaluate()
         ClustererOutput(outliers=[18, 21, 34, 35, 45], potential_outliers=[13, 15, 42], duplicates=[[9, 24], [23, 48]], potential_duplicates=[[1, 11]])
         """  # noqa: E501
dataeval/detectors/linters/duplicates.py CHANGED
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-__all__ = ["DuplicatesOutput", "Duplicates"]
+__all__ = []
 
 from dataclasses import dataclass
 from typing import Generic, Iterable, Sequence, TypeVar, overload
@@ -51,13 +51,6 @@ class Duplicates:
     ----------
     only_exact : bool, default False
         Only inspect the dataset for exact image matches
-
-    Example
-    -------
-    Initialize the Duplicates class:
-
-    >>> all_dupes = Duplicates()
-    >>> exact_dupes = Duplicates(only_exact=True)
     """
 
     def __init__(self, only_exact: bool = False) -> None:
@@ -73,7 +66,8 @@ class Duplicates:
         if not self.only_exact:
             near_dict: dict[int, list] = {}
             for i, value in enumerate(stats["pchash"]):
-                near_dict.setdefault(value, []).append(i)
+                if value:
+                    near_dict.setdefault(value, []).append(i)
             near = [sorted(v) for v in near_dict.values() if len(v) > 1 and not any(set(v).issubset(x) for x in exact)]
         else:
             near = []
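Note: the new `if value:` guard stops falsy perceptual hashes (e.g. the empty string returned for images too small to hash) from being grouped with one another as near duplicates. In miniature:

    pchashes = ["abc", "", "abc", ""]
    near_dict: dict[str, list[int]] = {}
    for i, value in enumerate(pchashes):
        if value:  # skip unhashable images instead of clustering them together
            near_dict.setdefault(value, []).append(i)
    assert near_dict == {"abc": [0, 2]}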
@@ -112,6 +106,7 @@ class Duplicates:
 
         Example
         -------
+        >>> exact_dupes = Duplicates(only_exact=True)
         >>> exact_dupes.from_stats([hashes1, hashes2])
         DuplicatesOutput(exact=[{0: [3, 20]}, {0: [16], 1: [12]}], near=[])
         """
@@ -159,7 +154,8 @@ class Duplicates:
 
         Example
         -------
-        >>> all_dupes.evaluate(images)
+        >>> all_dupes = Duplicates()
+        >>> all_dupes.evaluate(duplicate_images)
         DuplicatesOutput(exact=[[3, 20], [16, 37]], near=[[3, 20, 22], [12, 18], [13, 36], [14, 31], [17, 27], [19, 38, 47]])
         """  # noqa: E501
         self.stats = hashstats(data)
dataeval/detectors/linters/outliers.py CHANGED
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-__all__ = ["OutliersOutput", "Outliers"]
+__all__ = []
 
 from dataclasses import dataclass
 from typing import Generic, Iterable, Literal, Sequence, TypeVar, Union, overload
@@ -188,6 +188,7 @@ class Outliers:
         -------
         Evaluate the dataset:
 
+        >>> outliers = Outliers(outlier_method="zscore", outlier_threshold=3.5)
         >>> results = outliers.from_stats([stats1, stats2])
         >>> len(results)
         2
@@ -248,7 +249,8 @@ class Outliers:
         -------
         Evaluate the dataset:
 
-        >>> results = outliers.evaluate(images)
+        >>> outliers = Outliers(outlier_method="zscore", outlier_threshold=3.5)
+        >>> results = outliers.evaluate(outlier_images)
         >>> list(results.issues)
         [10, 12]
         >>> results.issues[10]
dataeval/detectors/ood/__init__.py CHANGED
@@ -2,14 +2,7 @@
 :term:`Out-of-distribution (OOD)` detectors identify data that is different from the data used to train a particular model.
 """
 
-from dataeval import _IS_TORCH_AVAILABLE
-from dataeval.detectors.ood.base import OODOutput, OODScoreOutput
+__all__ = ["OODOutput", "OODScoreOutput", "OOD_AE"]
 
-__all__ = ["OODOutput", "OODScoreOutput"]
-
-if _IS_TORCH_AVAILABLE:
-    from dataeval.detectors.ood.ae_torch import OOD_AE
-
-    __all__ += ["OOD_AE"]
-
-del _IS_TORCH_AVAILABLE
+from dataeval.detectors.ood.ae import OOD_AE
+from dataeval.detectors.ood.output import OODOutput, OODScoreOutput
dataeval/detectors/ood/{ae_torch.py → ae.py} RENAMED
@@ -10,16 +10,18 @@ Licensed under Apache Software License (Apache 2.0)
 
 from __future__ import annotations
 
+__all__ = []
+
 from typing import Callable
 
 import numpy as np
 import torch
 from numpy.typing import ArrayLike
 
-from dataeval.detectors.ood.base import OODScoreOutput
-from dataeval.detectors.ood.base_torch import OODBase
+from dataeval.detectors.ood.base import OODBase
+from dataeval.detectors.ood.output import OODScoreOutput
 from dataeval.interop import as_numpy
-from dataeval.utils.torch.utils import predict_batch
+from dataeval.utils.torch.internal import predict_batch
 
 
 class OOD_AE(OODBase):
@@ -28,7 +30,7 @@ class OOD_AE(OODBase):
 
     Parameters
     ----------
-    model : AriaAutoencoder
+    model : Autoencoder
         An Autoencoder model.
     """
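Note: since the module moved from `ae_torch.py` to `ae.py` and the import gate is gone, constructing the detector is now a plain import away. A sketch (the `model` instance is assumed to be built elsewhere; per the docstring above, any autoencoder model works):

    from dataeval.detectors.ood import OOD_AE

    detector = OOD_AE(model)  # model: an Autoencoder instance, constructed elsewhere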