dataeval 0.64.0__py3-none-any.whl → 0.66.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. dataeval/__init__.py +13 -9
  2. dataeval/_internal/detectors/clusterer.py +63 -49
  3. dataeval/_internal/detectors/drift/base.py +248 -51
  4. dataeval/_internal/detectors/drift/cvm.py +28 -26
  5. dataeval/_internal/detectors/drift/ks.py +31 -28
  6. dataeval/_internal/detectors/drift/mmd.py +62 -42
  7. dataeval/_internal/detectors/drift/torch.py +69 -60
  8. dataeval/_internal/detectors/drift/uncertainty.py +32 -32
  9. dataeval/_internal/detectors/duplicates.py +67 -31
  10. dataeval/_internal/detectors/ood/ae.py +15 -29
  11. dataeval/_internal/detectors/ood/aegmm.py +33 -27
  12. dataeval/_internal/detectors/ood/base.py +86 -47
  13. dataeval/_internal/detectors/ood/llr.py +34 -31
  14. dataeval/_internal/detectors/ood/vae.py +32 -31
  15. dataeval/_internal/detectors/ood/vaegmm.py +34 -28
  16. dataeval/_internal/detectors/{linter.py → outliers.py} +60 -38
  17. dataeval/_internal/flags.py +44 -21
  18. dataeval/_internal/interop.py +5 -3
  19. dataeval/_internal/metrics/balance.py +42 -5
  20. dataeval/_internal/metrics/ber.py +11 -8
  21. dataeval/_internal/metrics/coverage.py +15 -8
  22. dataeval/_internal/metrics/divergence.py +41 -7
  23. dataeval/_internal/metrics/diversity.py +57 -19
  24. dataeval/_internal/metrics/parity.py +141 -66
  25. dataeval/_internal/metrics/stats.py +330 -313
  26. dataeval/_internal/metrics/uap.py +33 -4
  27. dataeval/_internal/metrics/utils.py +79 -40
  28. dataeval/_internal/models/pytorch/autoencoder.py +127 -22
  29. dataeval/_internal/models/tensorflow/autoencoder.py +33 -30
  30. dataeval/_internal/models/tensorflow/gmm.py +4 -2
  31. dataeval/_internal/models/tensorflow/losses.py +17 -13
  32. dataeval/_internal/models/tensorflow/pixelcnn.py +19 -18
  33. dataeval/_internal/models/tensorflow/trainer.py +10 -7
  34. dataeval/_internal/models/tensorflow/utils.py +23 -20
  35. dataeval/_internal/output.py +85 -0
  36. dataeval/_internal/utils.py +5 -3
  37. dataeval/_internal/workflows/sufficiency.py +122 -121
  38. dataeval/detectors/__init__.py +6 -25
  39. dataeval/detectors/drift/__init__.py +16 -0
  40. dataeval/detectors/drift/kernels/__init__.py +6 -0
  41. dataeval/detectors/drift/updates/__init__.py +3 -0
  42. dataeval/detectors/linters/__init__.py +5 -0
  43. dataeval/detectors/ood/__init__.py +11 -0
  44. dataeval/flags/__init__.py +2 -2
  45. dataeval/metrics/__init__.py +2 -26
  46. dataeval/metrics/bias/__init__.py +14 -0
  47. dataeval/metrics/estimators/__init__.py +9 -0
  48. dataeval/metrics/stats/__init__.py +6 -0
  49. dataeval/tensorflow/__init__.py +3 -0
  50. dataeval/tensorflow/loss/__init__.py +3 -0
  51. dataeval/tensorflow/models/__init__.py +5 -0
  52. dataeval/tensorflow/recon/__init__.py +3 -0
  53. dataeval/torch/__init__.py +3 -0
  54. dataeval/{models/torch → torch/models}/__init__.py +1 -2
  55. dataeval/torch/trainer/__init__.py +3 -0
  56. dataeval/utils/__init__.py +3 -6
  57. dataeval/workflows/__init__.py +2 -4
  58. {dataeval-0.64.0.dist-info → dataeval-0.66.0.dist-info}/METADATA +1 -1
  59. dataeval-0.66.0.dist-info/RECORD +72 -0
  60. dataeval/_internal/metrics/base.py +0 -10
  61. dataeval/models/__init__.py +0 -15
  62. dataeval/models/tensorflow/__init__.py +0 -6
  63. dataeval-0.64.0.dist-info/RECORD +0 -60
  64. {dataeval-0.64.0.dist-info → dataeval-0.66.0.dist-info}/LICENSE.txt +0 -0
  65. {dataeval-0.64.0.dist-info → dataeval-0.66.0.dist-info}/WHEEL +0 -0
@@ -4,15 +4,17 @@ FR Test Statistic based estimate for the upperbound
4
4
  average precision using empirical mean precision
5
5
  """
6
6
 
7
- from typing import NamedTuple
7
+ from dataclasses import dataclass
8
8
 
9
9
  from numpy.typing import ArrayLike
10
10
  from sklearn.metrics import average_precision_score
11
11
 
12
12
  from dataeval._internal.interop import to_numpy
13
+ from dataeval._internal.output import OutputMetadata, set_metadata
13
14
 
14
15
 
15
- class UAPOutput(NamedTuple):
16
+ @dataclass(frozen=True)
17
+ class UAPOutput(OutputMetadata):
16
18
  """
17
19
  Attributes
18
20
  ----------
@@ -23,6 +25,7 @@ class UAPOutput(NamedTuple):
23
25
  uap: float
24
26
 
25
27
 
28
+ @set_metadata("dataeval.metrics")
26
29
  def uap(labels: ArrayLike, scores: ArrayLike) -> UAPOutput:
27
30
  """
28
31
  FR Test Statistic based estimate of the empirical mean precision for
@@ -37,13 +40,39 @@ def uap(labels: ArrayLike, scores: ArrayLike) -> UAPOutput:
37
40
 
38
41
  Returns
39
42
  -------
40
- Dict[str, float]
41
- uap : The empirical mean precision estimate
43
+ UAPOutput
44
+ The empirical mean precision estimate, float
42
45
 
43
46
  Raises
44
47
  ------
45
48
  ValueError
46
49
  If unique classes M < 2
50
+
51
+ Notes
52
+ -----
53
+ This function calculates the empirical mean precision using the
54
+ ``average_precision_score`` from scikit-learn, weighted by the class distribution.
55
+
56
+ Examples
57
+ --------
58
+ >>> y_true = np.array([0, 0, 1, 1])
59
+ >>> y_scores = np.array([0.1, 0.4, 0.35, 0.8])
60
+ >>> uap(y_true, y_scores)
61
+ UAPOutput(uap=0.8333333333333333)
62
+
63
+ >>> y_true = np.array([0, 0, 1, 1, 2, 2])
64
+ >>> y_scores = np.array(
65
+ ... [
66
+ ... [0.7, 0.2, 0.1],
67
+ ... [0.4, 0.3, 0.3],
68
+ ... [0.1, 0.8, 0.1],
69
+ ... [0.2, 0.3, 0.5],
70
+ ... [0.4, 0.4, 0.2],
71
+ ... [0.1, 0.2, 0.7],
72
+ ... ]
73
+ ... )
74
+ >>> uap(y_true, y_scores)
75
+ UAPOutput(uap=0.7777777777777777)
47
76
  """
48
77
 
49
78
  precision = float(average_precision_score(to_numpy(labels), to_numpy(scores), average="weighted"))
@@ -1,7 +1,10 @@
1
- from typing import Any, Callable, Dict, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Callable, Literal, NamedTuple, Sequence
2
4
 
3
5
  import numpy as np
4
6
  import xxhash as xxh
7
+ from numpy.typing import NDArray
5
8
  from PIL import Image
6
9
  from scipy.fftpack import dct
7
10
  from scipy.signal import convolve2d
@@ -18,22 +21,22 @@ HASH_SIZE = 8
18
21
  MAX_FACTOR = 4
19
22
 
20
23
 
21
- def get_method(method_map: Dict[str, Callable], method: str) -> Callable:
24
+ def get_method(method_map: dict[str, Callable], method: str) -> Callable:
22
25
  if method not in method_map:
23
26
  raise ValueError(f"Specified method {method} is not a valid method: {method_map}.")
24
27
  return method_map[method]
25
28
 
26
29
 
27
30
  def get_counts(
28
- data: np.ndarray, names: List[str], is_categorical: List[bool], subset_mask: Optional[np.ndarray] = None
29
- ) -> tuple[Dict, Dict]:
31
+ data: NDArray, names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
32
+ ) -> tuple[dict, dict]:
30
33
  """
31
34
  Initialize dictionary of histogram counts --- treat categorical values
32
35
  as histogram bins.
33
36
 
34
37
  Parameters
35
38
  ----------
36
- subset_mask: Optional[np.ndarray[bool]]
39
+ subset_mask: NDArray[np.bool_] | None
37
40
  Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
38
41
 
39
42
  Returns
@@ -66,24 +69,24 @@ def get_counts(
66
69
 
67
70
 
68
71
  def entropy(
69
- data: np.ndarray,
70
- names: List[str],
71
- is_categorical: List[bool],
72
+ data: NDArray,
73
+ names: list[str],
74
+ is_categorical: list[bool],
72
75
  normalized: bool = False,
73
- subset_mask: Optional[np.ndarray] = None,
74
- ) -> np.ndarray:
76
+ subset_mask: NDArray[np.bool_] | None = None,
77
+ ) -> NDArray[np.float64]:
75
78
  """
76
79
  Meant for use with Bias metrics, Balance, Diversity, ClasswiseBalance,
77
80
  and Classwise Diversity.
78
81
 
79
- Compute entropy for discrete/categorical variables and, through standard
80
- histogram binning, for continuous variables.
82
+ Compute entropy for discrete/categorical variables and for continuous variables through standard
83
+ histogram binning.
81
84
 
82
85
  Parameters
83
86
  ----------
84
87
  normalized: bool
85
88
  Flag that determines whether or not to normalize entropy by log(num_bins)
86
- subset_mask: Optional[np.ndarray[bool]]
89
+ subset_mask: NDArray[np.bool_] | None
87
90
  Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
88
91
 
89
92
  Notes
@@ -93,7 +96,7 @@ def entropy(
93
96
 
94
97
  Returns
95
98
  -------
96
- ent: np.ndarray[float]
99
+ ent: NDArray[np.float64]
97
100
  Entropy estimate per column of X
98
101
 
99
102
  See Also
@@ -119,16 +122,20 @@ def entropy(
119
122
 
120
123
 
121
124
  def get_num_bins(
122
- data: np.ndarray, names: List[str], is_categorical: List[bool], subset_mask: Optional[np.ndarray] = None
123
- ) -> np.ndarray:
125
+ data: NDArray, names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
126
+ ) -> NDArray[np.float64]:
124
127
  """
125
128
  Number of bins or unique values for each metadata factor, used to
126
129
  normalize entropy/diversity.
127
130
 
128
131
  Parameters
129
132
  ----------
130
- subset_mask: Optional[np.ndarray[bool]]
133
+ subset_mask: NDArray[np.bool_] | None
131
134
  Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
135
+
136
+ Returns
137
+ -------
138
+ NDArray[np.float64]
132
139
  """
133
140
  # likely cached
134
141
  hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)
@@ -139,7 +146,7 @@ def get_num_bins(
139
146
  return num_bins
140
147
 
141
148
 
142
- def infer_categorical(X: np.ndarray, threshold: float = 0.5) -> np.ndarray:
149
+ def infer_categorical(X: NDArray, threshold: float = 0.2) -> NDArray:
143
150
  """
144
151
  Compute fraction of feature values that are unique --- intended to be used
145
152
  for inferring whether variables are categorical.
@@ -154,9 +161,11 @@ def infer_categorical(X: np.ndarray, threshold: float = 0.5) -> np.ndarray:
154
161
  return pct_unique < threshold
155
162
 
156
163
 
157
- def preprocess_metadata(class_labels: Sequence[int], metadata: List[Dict]) -> Tuple[np.ndarray, List[str], List[bool]]:
164
+ def preprocess_metadata(
165
+ class_labels: Sequence[int], metadata: list[dict], cat_thresh: float = 0.2
166
+ ) -> tuple[NDArray, list[str], list[bool]]:
158
167
  # convert class_labels and list of metadata dicts to dict of ndarrays
159
- metadata_dict: Dict[str, np.ndarray] = {
168
+ metadata_dict: dict[str, NDArray] = {
160
169
  "class_label": np.asarray(class_labels, dtype=int),
161
170
  **{k: np.array([d[k] for d in metadata]) for k in metadata[0]},
162
171
  }
@@ -172,18 +181,35 @@ def preprocess_metadata(class_labels: Sequence[int], metadata: List[Dict]) -> Tu
172
181
 
173
182
  data = np.stack(list(metadata_dict.values()), axis=-1)
174
183
  names = list(metadata_dict.keys())
175
- is_categorical = [infer_categorical(metadata_dict[var], 0.25)[0] for var in names]
184
+ is_categorical = [infer_categorical(metadata_dict[var], cat_thresh)[0] for var in names]
176
185
 
177
186
  return data, names, is_categorical
178
187
 
179
188
 
180
- def minimum_spanning_tree(X: np.ndarray) -> Any:
189
+ def flatten(X: NDArray):
190
+ """
191
+ Flattens input array from (N, ... ) to (N, -1) where all samples N have all data in their last dimension
192
+
193
+ Parameters
194
+ ----------
195
+ X : NDArray, shape - (N, ... )
196
+ Input array
197
+
198
+ Returns
199
+ -------
200
+ NDArray, shape - (N, -1)
201
+ """
202
+
203
+ return X.reshape((X.shape[0], -1))
204
+
205
+
206
+ def minimum_spanning_tree(X: NDArray) -> Any:
181
207
  """
182
208
  Returns the minimum spanning tree from a NumPy image array.
183
209
 
184
210
  Parameters
185
211
  ----------
186
- X: np.ndarray
212
+ X : NDArray
187
213
  Numpy image array
188
214
 
189
215
  Returns
@@ -191,7 +217,7 @@ def minimum_spanning_tree(X: np.ndarray) -> Any:
191
217
  Data representing the minimum spanning tree
192
218
  """
193
219
  # All features belong on second dimension
194
- X = X.reshape((X.shape[0], -1))
220
+ X = flatten(X)
195
221
  # We add a small constant to the distance matrix to ensure scipy interprets
196
222
  # the input graph as fully-connected.
197
223
  dense_eudist = squareform(pdist(X)) + EPSILON
@@ -199,13 +225,13 @@ def minimum_spanning_tree(X: np.ndarray) -> Any:
199
225
  return mst(eudist_csr)
200
226
 
201
227
 
202
- def get_classes_counts(labels: np.ndarray) -> Tuple[int, int]:
228
+ def get_classes_counts(labels: NDArray) -> tuple[int, int]:
203
229
  """
204
230
  Returns the classes and counts from an array of labels
205
231
 
206
232
  Parameters
207
233
  ----------
208
- label: np.ndarray
234
+ label : NDArray
209
235
  Numpy labels array
210
236
 
211
237
  Returns
@@ -226,17 +252,17 @@ def get_classes_counts(labels: np.ndarray) -> Tuple[int, int]:
226
252
 
227
253
 
228
254
  def compute_neighbors(
229
- A: np.ndarray,
230
- B: np.ndarray,
255
+ A: NDArray,
256
+ B: NDArray,
231
257
  k: int = 1,
232
258
  algorithm: Literal["auto", "ball_tree", "kd_tree"] = "auto",
233
- ) -> np.ndarray:
259
+ ) -> NDArray:
234
260
  """
235
261
  For each sample in A, compute the nearest neighbor in B
236
262
 
237
263
  Parameters
238
264
  ----------
239
- A, B : np.ndarray
265
+ A, B : NDArray
240
266
  The n_samples and n_features respectively
241
267
  k : int
242
268
  The number of neighbors to find
@@ -252,11 +278,24 @@ def compute_neighbors(
252
278
  List:
253
279
  Closest points to each point in A and B
254
280
 
281
+ Raises
282
+ ------
283
+ ValueError
284
+ If algorithm is not "auto", "ball_tree", or "kd_tree"
285
+
255
286
  See Also
256
287
  --------
257
288
  sklearn.neighbors.NearestNeighbors
258
289
  """
259
290
 
291
+ if k < 1:
292
+ raise ValueError("k must be >= 1")
293
+ if algorithm not in ["auto", "ball_tree", "kd_tree"]:
294
+ raise ValueError("Algorithm must be 'auto', 'ball_tree', or 'kd_tree'")
295
+
296
+ A = flatten(A)
297
+ B = flatten(B)
298
+
260
299
  nbrs = NearestNeighbors(n_neighbors=k + 1, algorithm=algorithm).fit(B)
261
300
  nns = nbrs.kneighbors(A)[1]
262
301
  nns = nns[:, 1:].squeeze()
@@ -266,11 +305,11 @@ def compute_neighbors(
266
305
 
267
306
  class BitDepth(NamedTuple):
268
307
  depth: int
269
- pmin: Union[float, int]
270
- pmax: Union[float, int]
308
+ pmin: float | int
309
+ pmax: float | int
271
310
 
272
311
 
273
- def get_bitdepth(image: np.ndarray) -> BitDepth:
312
+ def get_bitdepth(image: NDArray) -> BitDepth:
274
313
  """
275
314
  Approximates the bit depth of the image using the
276
315
  min and max pixel values.
@@ -283,7 +322,7 @@ def get_bitdepth(image: np.ndarray) -> BitDepth:
283
322
  return BitDepth(depth, 0, 2**depth - 1)
284
323
 
285
324
 
286
- def rescale(image: np.ndarray, depth: int = 1) -> np.ndarray:
325
+ def rescale(image: NDArray, depth: int = 1) -> NDArray:
287
326
  """
288
327
  Rescales the image using the bit depth provided.
289
328
  """
@@ -295,7 +334,7 @@ def rescale(image: np.ndarray, depth: int = 1) -> np.ndarray:
295
334
  return normalized * (2**depth - 1)
296
335
 
297
336
 
298
- def normalize_image_shape(image: np.ndarray) -> np.ndarray:
337
+ def normalize_image_shape(image: NDArray) -> NDArray:
299
338
  """
300
339
  Normalizes the image shape into (C,H,W).
301
340
  """
@@ -311,7 +350,7 @@ def normalize_image_shape(image: np.ndarray) -> np.ndarray:
311
350
  raise ValueError("Images must have 2 or more dimensions.")
312
351
 
313
352
 
314
- def edge_filter(image: np.ndarray, offset: float = 0.5) -> np.ndarray:
353
+ def edge_filter(image: NDArray, offset: float = 0.5) -> NDArray:
315
354
  """
316
355
  Returns the image filtered using a 3x3 edge detection kernel:
317
356
  [[ -1, -1, -1 ],
@@ -323,7 +362,7 @@ def edge_filter(image: np.ndarray, offset: float = 0.5) -> np.ndarray:
323
362
  return edges
324
363
 
325
364
 
326
- def pchash(image: np.ndarray) -> str:
365
+ def pchash(image: NDArray) -> str:
327
366
  """
328
367
  Performs a perceptual hash on an image by resizing to a square NxN image
329
368
  using the Lanczos algorithm where N is 32x32 or the largest multiple of
@@ -334,7 +373,7 @@ def pchash(image: np.ndarray) -> str:
334
373
 
335
374
  Parameters
336
375
  ----------
337
- image : np.ndarray
376
+ image : NDArray
338
377
  An image as a numpy array in CxHxW format
339
378
 
340
379
  Returns
@@ -374,7 +413,7 @@ def pchash(image: np.ndarray) -> str:
374
413
  return hash_hex if hash_hex else "0"
375
414
 
376
415
 
377
- def xxhash(image: np.ndarray) -> str:
416
+ def xxhash(image: NDArray) -> str:
378
417
  """
379
418
  Performs a fast non-cryptographic hash using the xxhash algorithm
380
419
  (xxhash.com) against the image as a flattened bytearray. The hash
@@ -382,7 +421,7 @@ def xxhash(image: np.ndarray) -> str:
382
421
 
383
422
  Parameters
384
423
  ----------
385
- image : np.ndarray
424
+ image : NDArray
386
425
  An image as a numpy array
387
426
 
388
427
  Returns
@@ -1,4 +1,6 @@
1
- from typing import Any, List, Union
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
2
4
 
3
5
  import torch
4
6
  import torch.nn as nn
@@ -14,40 +16,52 @@ def get_images_from_batch(batch: Any) -> Any:
14
16
 
15
17
 
16
18
  class AETrainer:
19
+ """
20
+ A class to train and evaluate an autoencoder model.
21
+
22
+ Parameters
23
+ ----------
24
+ model : nn.Module
25
+ The model to be trained.
26
+ device : str or torch.device, default "auto"
27
+ The hardware device to use for training.
28
+ If "auto", the device will be set to "cuda" if available, otherwise "cpu".
29
+ batch_size : int, default 8
30
+ The number of images to process in a batch.
31
+ """
32
+
17
33
  def __init__(
18
34
  self,
19
35
  model: nn.Module,
20
- device: Union[str, torch.device] = "auto",
36
+ device: str | torch.device = "auto",
21
37
  batch_size: int = 8,
22
38
  ):
23
- """
24
- model : nn.Module
25
- Model to be trained
26
- device : str | torch.device, default "cpu"
27
- Hardware device for model, optimizer, and data to run on
28
- batch_size : int, default 8
29
- Number of images to group together in `torch.utils.data.DataLoader`
30
- """
31
39
  if device == "auto":
32
40
  device = "cuda" if torch.cuda.is_available() else "cpu"
33
41
  self.device = device
34
42
  self.model = model.to(device)
35
43
  self.batch_size = batch_size
36
44
 
37
- def train(self, dataset: Dataset, epochs: int = 25) -> List[float]:
45
+ def train(self, dataset: Dataset, epochs: int = 25) -> list[float]:
38
46
  """
39
- Basic training function for Autoencoder models for reconstruction tasks
47
+ Basic image reconstruction training function for Autoencoder models
40
48
 
41
49
  Uses `torch.optim.Adam` and `torch.nn.MSELoss` as default hyperparameters
42
50
 
43
51
  Parameters
44
52
  ----------
45
53
  dataset : Dataset
46
- Torch Dataset containing images in the first return position
54
+ The dataset to train on.
55
+ Torch Dataset containing images in the first return position.
47
56
  epochs : int, default 25
48
57
  Number of full training loops
49
58
 
50
- Note
59
+ Returns
60
+ -------
61
+ List[float]
62
+ A list of average loss values for each epoch.
63
+
64
+ Notes
51
65
  ----
52
66
  To replace this function with a custom function, do
53
67
  AETrainer.train = custom_function
@@ -58,7 +72,7 @@ class AETrainer:
58
72
  opt = Adam(self.model.parameters(), lr=0.001)
59
73
  criterion = nn.MSELoss().to(self.device)
60
74
  # Record loss
61
- loss_history: List[float] = []
75
+ loss_history: list[float] = []
62
76
 
63
77
  for _ in range(epochs):
64
78
  epoch_loss: float = 0
@@ -89,19 +103,20 @@ class AETrainer:
89
103
  @torch.no_grad
90
104
  def eval(self, dataset: Dataset) -> float:
91
105
  """
92
- Basic evaluation function for Autoencoder models for reconstruction tasks
106
+ Basic image reconstruction evaluation function for Autoencoder models
93
107
 
94
- Uses `torch.optim.Adam` and `torch.nn.MSELoss` as default hyperparameters
108
+ Uses `torch.nn.MSELoss` as default loss function.
95
109
 
96
110
  Parameters
97
111
  ----------
98
112
  dataset : Dataset
99
- Torch Dataset containing images in the first return position
113
+ The dataset to evaluate on.
114
+ Torch Dataset containing images in the first return position.
100
115
 
101
116
  Returns
102
117
  -------
103
118
  float
104
- Total reconstruction loss over all data
119
+ Total reconstruction loss over the entire dataset
105
120
 
106
121
  Note
107
122
  ----
@@ -124,18 +139,25 @@ class AETrainer:
124
139
  @torch.no_grad
125
140
  def encode(self, dataset: Dataset) -> torch.Tensor:
126
141
  """
127
- Encode data through model if it has an encode attribute,
128
- otherwise passes data through model.forward
142
+ Create image embeddings for the dataset using the model's encoder.
143
+
144
+ If the model has an `encode` method, it will be used; otherwise,
145
+ `model.forward` will be used.
129
146
 
130
147
  Parameters
131
148
  ----------
132
149
  dataset: Dataset
133
- Dataset containing images to be encoded by the model
150
+ The dataset to encode.
151
+ Torch Dataset containing images in the first return position.
134
152
 
135
153
  Returns
136
154
  -------
137
155
  torch.Tensor
138
156
  Data encoded by the model
157
+
158
+ Notes
159
+ -----
160
+ This function should be run after the model has been trained and evaluated.
139
161
  """
140
162
  self.model.eval()
141
163
  dl = DataLoader(dataset, batch_size=self.batch_size)
@@ -155,21 +177,67 @@ class AETrainer:
155
177
 
156
178
 
157
179
  class AriaAutoencoder(nn.Module):
180
+ """
181
+ An autoencoder model with a separate encoder and decoder.
182
+
183
+ Parameters
184
+ ----------
185
+ channels : int, default 3
186
+ Number of input channels
187
+ """
188
+
158
189
  def __init__(self, channels=3):
159
190
  super().__init__()
160
191
  self.encoder = Encoder(channels)
161
192
  self.decoder = Decoder(channels)
162
193
 
163
194
  def forward(self, x):
195
+ """
196
+ Perform a forward pass through the encoder and decoder.
197
+
198
+ Parameters
199
+ ----------
200
+ x : torch.Tensor
201
+ Input tensor
202
+
203
+ Returns
204
+ -------
205
+ torch.Tensor
206
+ The reconstructed output tensor.
207
+ """
164
208
  x = self.encoder(x)
165
209
  x = self.decoder(x)
166
210
  return x
167
211
 
168
212
  def encode(self, x):
213
+ """
214
+ Encode the input tensor using the encoder.
215
+
216
+ Parameters
217
+ ----------
218
+ x : torch.Tensor
219
+ Input tensor
220
+
221
+ Returns
222
+ -------
223
+ torch.Tensor
224
+ The encoded representation of the input tensor.
225
+ """
169
226
  return self.encoder(x)
170
227
 
171
228
 
172
229
  class Encoder(nn.Module):
230
+ """
231
+ A simple encoder to be used in an autoencoder model.
232
+
233
+ This is the encoder used by the AriaAutoencoder model.
234
+
235
+ Parameters
236
+ ----------
237
+ channels : int, default 3
238
+ Number of input channels
239
+ """
240
+
173
241
  def __init__(self, channels=3):
174
242
  super().__init__()
175
243
  self.encoder = nn.Sequential(
@@ -183,10 +251,34 @@ class Encoder(nn.Module):
183
251
  )
184
252
 
185
253
  def forward(self, x):
254
+ """
255
+ Perform a forward pass through the encoder.
256
+
257
+ Parameters
258
+ ----------
259
+ x : torch.Tensor
260
+ Input tensor
261
+
262
+ Returns
263
+ -------
264
+ torch.Tensor
265
+ The encoded representation of the input tensor.
266
+ """
186
267
  return self.encoder(x)
187
268
 
188
269
 
189
270
  class Decoder(nn.Module):
271
+ """
272
+ A simple decoder to be used in an autoencoder model.
273
+
274
+ This is the decoder used by the AriaAutoencoder model.
275
+
276
+ Parameters
277
+ ----------
278
+ channels : int
279
+ Number of output channels
280
+ """
281
+
190
282
  def __init__(self, channels):
191
283
  super().__init__()
192
284
  self.decoder = nn.Sequential(
@@ -199,4 +291,17 @@ class Decoder(nn.Module):
199
291
  )
200
292
 
201
293
  def forward(self, x):
294
+ """
295
+ Perform a forward pass through the decoder.
296
+
297
+ Parameters
298
+ ----------
299
+ x : torch.Tensor
300
+ The encoded tensor.
301
+
302
+ Returns
303
+ -------
304
+ torch.Tensor
305
+ The reconstructed output tensor.
306
+ """
202
307
  return self.decoder(x)