dataeval-0.64.0-py3-none-any.whl → dataeval-0.66.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +13 -9
- dataeval/_internal/detectors/clusterer.py +63 -49
- dataeval/_internal/detectors/drift/base.py +248 -51
- dataeval/_internal/detectors/drift/cvm.py +28 -26
- dataeval/_internal/detectors/drift/ks.py +31 -28
- dataeval/_internal/detectors/drift/mmd.py +62 -42
- dataeval/_internal/detectors/drift/torch.py +69 -60
- dataeval/_internal/detectors/drift/uncertainty.py +32 -32
- dataeval/_internal/detectors/duplicates.py +67 -31
- dataeval/_internal/detectors/ood/ae.py +15 -29
- dataeval/_internal/detectors/ood/aegmm.py +33 -27
- dataeval/_internal/detectors/ood/base.py +86 -47
- dataeval/_internal/detectors/ood/llr.py +34 -31
- dataeval/_internal/detectors/ood/vae.py +32 -31
- dataeval/_internal/detectors/ood/vaegmm.py +34 -28
- dataeval/_internal/detectors/{linter.py → outliers.py} +60 -38
- dataeval/_internal/flags.py +44 -21
- dataeval/_internal/interop.py +5 -3
- dataeval/_internal/metrics/balance.py +42 -5
- dataeval/_internal/metrics/ber.py +11 -8
- dataeval/_internal/metrics/coverage.py +15 -8
- dataeval/_internal/metrics/divergence.py +41 -7
- dataeval/_internal/metrics/diversity.py +57 -19
- dataeval/_internal/metrics/parity.py +141 -66
- dataeval/_internal/metrics/stats.py +330 -313
- dataeval/_internal/metrics/uap.py +33 -4
- dataeval/_internal/metrics/utils.py +79 -40
- dataeval/_internal/models/pytorch/autoencoder.py +127 -22
- dataeval/_internal/models/tensorflow/autoencoder.py +33 -30
- dataeval/_internal/models/tensorflow/gmm.py +4 -2
- dataeval/_internal/models/tensorflow/losses.py +17 -13
- dataeval/_internal/models/tensorflow/pixelcnn.py +19 -18
- dataeval/_internal/models/tensorflow/trainer.py +10 -7
- dataeval/_internal/models/tensorflow/utils.py +23 -20
- dataeval/_internal/output.py +85 -0
- dataeval/_internal/utils.py +5 -3
- dataeval/_internal/workflows/sufficiency.py +122 -121
- dataeval/detectors/__init__.py +6 -25
- dataeval/detectors/drift/__init__.py +16 -0
- dataeval/detectors/drift/kernels/__init__.py +6 -0
- dataeval/detectors/drift/updates/__init__.py +3 -0
- dataeval/detectors/linters/__init__.py +5 -0
- dataeval/detectors/ood/__init__.py +11 -0
- dataeval/flags/__init__.py +2 -2
- dataeval/metrics/__init__.py +2 -26
- dataeval/metrics/bias/__init__.py +14 -0
- dataeval/metrics/estimators/__init__.py +9 -0
- dataeval/metrics/stats/__init__.py +6 -0
- dataeval/tensorflow/__init__.py +3 -0
- dataeval/tensorflow/loss/__init__.py +3 -0
- dataeval/tensorflow/models/__init__.py +5 -0
- dataeval/tensorflow/recon/__init__.py +3 -0
- dataeval/torch/__init__.py +3 -0
- dataeval/{models/torch → torch/models}/__init__.py +1 -2
- dataeval/torch/trainer/__init__.py +3 -0
- dataeval/utils/__init__.py +3 -6
- dataeval/workflows/__init__.py +2 -4
- {dataeval-0.64.0.dist-info → dataeval-0.66.0.dist-info}/METADATA +1 -1
- dataeval-0.66.0.dist-info/RECORD +72 -0
- dataeval/_internal/metrics/base.py +0 -10
- dataeval/models/__init__.py +0 -15
- dataeval/models/tensorflow/__init__.py +0 -6
- dataeval-0.64.0.dist-info/RECORD +0 -60
- {dataeval-0.64.0.dist-info → dataeval-0.66.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.64.0.dist-info → dataeval-0.66.0.dist-info}/WHEEL +0 -0
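The file list above amounts to a namespace reorganization: the flat detectors and metrics packages gain purpose-specific subpackages, linter becomes outliers under detectors/linters, and framework-specific code moves under dataeval.tensorflow and dataeval.torch. A minimal sketch of the new import surface, assuming dataeval 0.66.0 is installed; the per-subpackage groupings in the comments are inferred from the internal file names, not confirmed by this diff:

import dataeval.detectors.drift      # drift detectors (cvm, ks, mmd, uncertainty)
import dataeval.detectors.linters    # the former "linter" detector, renamed to outliers
import dataeval.detectors.ood        # out-of-distribution detectors
import dataeval.metrics.bias         # likely balance, coverage, diversity, parity
import dataeval.metrics.estimators   # likely ber, divergence, uap
import dataeval.metrics.stats        # image and dataset statistics
import dataeval.torch.models         # moved from dataeval.models.torch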
dataeval/_internal/metrics/uap.py

@@ -4,15 +4,17 @@ FR Test Statistic based estimate for the upperbound
 average precision using empirical mean precision
 """
 
-from typing import NamedTuple
+from dataclasses import dataclass
 
 from numpy.typing import ArrayLike
 from sklearn.metrics import average_precision_score
 
 from dataeval._internal.interop import to_numpy
+from dataeval._internal.output import OutputMetadata, set_metadata
 
 
-class UAPOutput(NamedTuple):
+@dataclass(frozen=True)
+class UAPOutput(OutputMetadata):
     """
     Attributes
     ----------
@@ -23,6 +25,7 @@ class UAPOutput(NamedTuple):
     uap: float
 
 
+@set_metadata("dataeval.metrics")
 def uap(labels: ArrayLike, scores: ArrayLike) -> UAPOutput:
     """
     FR Test Statistic based estimate of the empirical mean precision for
@@ -37,13 +40,39 @@ def uap(labels: ArrayLike, scores: ArrayLike) -> UAPOutput:
 
     Returns
     -------
-
-
+    UAPOutput
+        The empirical mean precision estimate, float
 
     Raises
     ------
     ValueError
         If unique classes M < 2
+
+    Notes
+    -----
+    This function calculates the empirical mean precision using the
+    ``average_precision_score`` from scikit-learn, weighted by the class distribution.
+
+    Examples
+    --------
+    >>> y_true = np.array([0, 0, 1, 1])
+    >>> y_scores = np.array([0.1, 0.4, 0.35, 0.8])
+    >>> uap(y_true, y_scores)
+    UAPOutput(uap=0.8333333333333333)
+
+    >>> y_true = np.array([0, 0, 1, 1, 2, 2])
+    >>> y_scores = np.array(
+    ...     [
+    ...         [0.7, 0.2, 0.1],
+    ...         [0.4, 0.3, 0.3],
+    ...         [0.1, 0.8, 0.1],
+    ...         [0.2, 0.3, 0.5],
+    ...         [0.4, 0.4, 0.2],
+    ...         [0.1, 0.2, 0.7],
+    ...     ]
+    ... )
+    >>> uap(y_true, y_scores)
+    UAPOutput(uap=0.7777777777777777)
     """
 
     precision = float(average_precision_score(to_numpy(labels), to_numpy(scores), average="weighted"))
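The switch from NamedTuple to a frozen dataclass deriving OutputMetadata, together with the @set_metadata("dataeval.metrics") decorator, makes uap return an immutable result that carries run metadata. A usage sketch built from the doctest above, assuming uap is re-exported through the new dataeval.metrics.estimators subpackage shown in the file list:

import numpy as np
from dataeval.metrics.estimators import uap  # assumed 0.66.0 re-export path

y_true = np.array([0, 0, 1, 1])
y_scores = np.array([0.1, 0.4, 0.35, 0.8])

result = uap(y_true, y_scores)
print(result.uap)  # 0.8333333333333333, per the doctest above

# frozen=True makes the output immutable:
# result.uap = 0.0  # would raise dataclasses.FrozenInstanceError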
dataeval/_internal/metrics/utils.py

@@ -1,7 +1,10 @@
-from
+from __future__ import annotations
+
+from typing import Any, Callable, Literal, NamedTuple, Sequence
 
 import numpy as np
 import xxhash as xxh
+from numpy.typing import NDArray
 from PIL import Image
 from scipy.fftpack import dct
 from scipy.signal import convolve2d
@@ -18,22 +21,22 @@ HASH_SIZE = 8
 MAX_FACTOR = 4
 
 
-def get_method(method_map:
+def get_method(method_map: dict[str, Callable], method: str) -> Callable:
     if method not in method_map:
         raise ValueError(f"Specified method {method} is not a valid method: {method_map}.")
     return method_map[method]
 
 
 def get_counts(
-    data:
-) -> tuple[
+    data: NDArray, names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
+) -> tuple[dict, dict]:
     """
     Initialize dictionary of histogram counts --- treat categorical values
     as histogram bins.
 
     Parameters
     ----------
-    subset_mask:
+    subset_mask: NDArray[np.bool_] | None
         Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
 
     Returns
@@ -66,24 +69,24 @@ def get_counts(
 
 
 def entropy(
-    data:
-    names:
-    is_categorical:
+    data: NDArray,
+    names: list[str],
+    is_categorical: list[bool],
     normalized: bool = False,
-    subset_mask:
-) -> np.
+    subset_mask: NDArray[np.bool_] | None = None,
+) -> NDArray[np.float64]:
     """
     Meant for use with Bias metrics, Balance, Diversity, ClasswiseBalance,
     and Classwise Diversity.
 
-    Compute entropy for discrete/categorical variables and
-    histogram binning
+    Compute entropy for discrete/categorical variables and for continuous variables through standard
+    histogram binning.
 
     Parameters
     ----------
     normalized: bool
         Flag that determines whether or not to normalize entropy by log(num_bins)
-    subset_mask:
+    subset_mask: NDArray[np.bool_] | None
         Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
 
     Notes
@@ -93,7 +96,7 @@ def entropy(
 
     Returns
     -------
-    ent: np.
+    ent: NDArray[np.float64]
         Entropy estimate per column of X
 
     See Also
@@ -119,16 +122,20 @@ def entropy(
 
 
 def get_num_bins(
-    data:
-) -> np.
+    data: NDArray, names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
+) -> NDArray[np.float64]:
     """
     Number of bins or unique values for each metadata factor, used to
     normalize entropy/diversity.
 
     Parameters
     ----------
-    subset_mask:
+    subset_mask: NDArray[np.bool_] | None
         Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
+
+    Returns
+    -------
+    NDArray[np.float64]
     """
     # likely cached
     hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)
@@ -139,7 +146,7 @@ def get_num_bins(
     return num_bins
 
 
-def infer_categorical(X: np.ndarray, threshold: float = 0.5) -> np.ndarray:
+def infer_categorical(X: NDArray, threshold: float = 0.2) -> NDArray:
     """
     Compute fraction of feature values that are unique --- intended to be used
     for inferring whether variables are categorical.
@@ -154,9 +161,11 @@ def infer_categorical(X: np.ndarray, threshold: float = 0.5) -> np.ndarray:
     return pct_unique < threshold
 
 
-def preprocess_metadata(
+def preprocess_metadata(
+    class_labels: Sequence[int], metadata: list[dict], cat_thresh: float = 0.2
+) -> tuple[NDArray, list[str], list[bool]]:
     # convert class_labels and list of metadata dicts to dict of ndarrays
-    metadata_dict:
+    metadata_dict: dict[str, NDArray] = {
         "class_label": np.asarray(class_labels, dtype=int),
         **{k: np.array([d[k] for d in metadata]) for k in metadata[0]},
     }
@@ -172,18 +181,35 @@ def preprocess_metadata(class_labels: Sequence[int], metadata: List[Dict]) -> Tu
 
     data = np.stack(list(metadata_dict.values()), axis=-1)
     names = list(metadata_dict.keys())
-    is_categorical = [infer_categorical(metadata_dict[var],
+    is_categorical = [infer_categorical(metadata_dict[var], cat_thresh)[0] for var in names]
 
     return data, names, is_categorical
 
 
-def minimum_spanning_tree(X: np.ndarray) -> Any:
+def flatten(X: NDArray):
+    """
+    Flattens input array from (N, ... ) to (N, -1) where all samples N have all data in their last dimension
+
+    Parameters
+    ----------
+    X : NDArray, shape - (N, ... )
+        Input array
+
+    Returns
+    -------
+    NDArray, shape - (N, -1)
+    """
+
+    return X.reshape((X.shape[0], -1))
+
+
+def minimum_spanning_tree(X: NDArray) -> Any:
     """
     Returns the minimum spanning tree from a NumPy image array.
 
     Parameters
     ----------
-    X:
+    X : NDArray
         Numpy image array
 
     Returns
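With the annotations filled in, the metadata pipeline reads directly off the signatures: preprocess_metadata stacks class labels and per-sample metadata dicts into one factor matrix and infers which factors are categorical (note the infer_categorical default threshold change from 0.5 to 0.2), and entropy then histogram-bins each factor. A sketch of how these helpers compose, assuming direct use of the internal dataeval._internal.metrics.utils module (an internal API, so subject to change):

import numpy as np
from dataeval._internal.metrics.utils import entropy, preprocess_metadata

class_labels = [0, 0, 1, 1, 1, 0]
metadata = [{"lighting": i % 2, "altitude": i * 100} for i in range(6)]

# Stack labels and metadata factors into one (N, F) matrix
data, names, is_categorical = preprocess_metadata(class_labels, metadata)

# Per-factor entropy, normalized by log(num_bins) so values fall in [0, 1]
ent = entropy(data, names, is_categorical, normalized=True)
print(dict(zip(names, ent)))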
@@ -191,7 +217,7 @@ def minimum_spanning_tree(X: np.ndarray) -> Any:
         Data representing the minimum spanning tree
     """
     # All features belong on second dimension
-    X =
+    X = flatten(X)
     # We add a small constant to the distance matrix to ensure scipy interprets
     # the input graph as fully-connected.
     dense_eudist = squareform(pdist(X)) + EPSILON
@@ -199,13 +225,13 @@ def minimum_spanning_tree(X: np.ndarray) -> Any:
     return mst(eudist_csr)
 
 
-def get_classes_counts(labels:
+def get_classes_counts(labels: NDArray) -> tuple[int, int]:
     """
     Returns the classes and counts of from an array of labels
 
     Parameters
     ----------
-    label:
+    label : NDArray
         Numpy labels array
 
     Returns
@@ -226,17 +252,17 @@ def get_classes_counts(labels: np.ndarray) -> Tuple[int, int]:
 
 
 def compute_neighbors(
-    A:
-    B:
+    A: NDArray,
+    B: NDArray,
     k: int = 1,
     algorithm: Literal["auto", "ball_tree", "kd_tree"] = "auto",
-) ->
+) -> NDArray:
     """
     For each sample in A, compute the nearest neighbor in B
 
     Parameters
     ----------
-    A, B :
+    A, B : NDArray
         The n_samples and n_features respectively
     k : int
         The number of neighbors to find
@@ -252,11 +278,24 @@ def compute_neighbors(
     List:
         Closest points to each point in A and B
 
+    Raises
+    ------
+    ValueError
+        If algorithm is not "auto", "ball_tree", or "kd_tree"
+
     See Also
     --------
     sklearn.neighbors.NearestNeighbors
     """
 
+    if k < 1:
+        raise ValueError("k must be >= 1")
+    if algorithm not in ["auto", "ball_tree", "kd_tree"]:
+        raise ValueError("Algorithm must be 'auto', 'ball_tree', or 'kd_tree'")
+
+    A = flatten(A)
+    B = flatten(B)
+
     nbrs = NearestNeighbors(n_neighbors=k + 1, algorithm=algorithm).fit(B)
     nns = nbrs.kneighbors(A)[1]
     nns = nns[:, 1:].squeeze()
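The new guard clauses make compute_neighbors fail fast on a bad k or algorithm, and the added flatten calls mean A and B can now be passed as unflattened image stacks. A small sketch of the resulting behavior, again against the internal module:

import numpy as np
from dataeval._internal.metrics.utils import compute_neighbors

rng = np.random.default_rng(0)
A = rng.normal(size=(10, 3, 8, 8))   # ten images, flattened internally to (10, 192)
B = rng.normal(size=(25, 3, 8, 8))

nns = compute_neighbors(A, B, k=1)   # index into B of each A sample's nearest neighbor
print(nns.shape)                     # (10,)

# compute_neighbors(A, B, algorithm="brute")  # now raises ValueError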
@@ -266,11 +305,11 @@ def compute_neighbors(
 
 class BitDepth(NamedTuple):
     depth: int
-    pmin:
-    pmax:
+    pmin: float | int
+    pmax: float | int
 
 
-def get_bitdepth(image: np.ndarray) -> BitDepth:
+def get_bitdepth(image: NDArray) -> BitDepth:
     """
     Approximates the bit depth of the image using the
     min and max pixel values.
@@ -283,7 +322,7 @@ def get_bitdepth(image: np.ndarray) -> BitDepth:
     return BitDepth(depth, 0, 2**depth - 1)
 
 
-def rescale(image: np.ndarray, depth: int = 1) -> np.ndarray:
+def rescale(image: NDArray, depth: int = 1) -> NDArray:
     """
     Rescales the image using the bit depth provided.
     """
@@ -295,7 +334,7 @@ def rescale(image: np.ndarray, depth: int = 1) -> np.ndarray:
     return normalized * (2**depth - 1)
 
 
-def normalize_image_shape(image: np.ndarray) -> np.ndarray:
+def normalize_image_shape(image: NDArray) -> NDArray:
     """
     Normalizes the image shape into (C,H,W).
     """
@@ -311,7 +350,7 @@ def normalize_image_shape(image: np.ndarray) -> np.ndarray:
     raise ValueError("Images must have 2 or more dimensions.")
 
 
-def edge_filter(image: np.ndarray, offset: float = 0.5) -> np.ndarray:
+def edge_filter(image: NDArray, offset: float = 0.5) -> NDArray:
     """
     Returns the image filtered using a 3x3 edge detection kernel:
     [[ -1, -1, -1 ],
@@ -323,7 +362,7 @@ def edge_filter(image: np.ndarray, offset: float = 0.5) -> np.ndarray:
     return edges
 
 
-def pchash(image: np.ndarray) -> str:
+def pchash(image: NDArray) -> str:
     """
     Performs a perceptual hash on an image by resizing to a square NxN image
     using the Lanczos algorithm where N is 32x32 or the largest multiple of
@@ -334,7 +373,7 @@ def pchash(image: np.ndarray) -> str:
 
     Parameters
     ----------
-    image :
+    image : NDArray
         An image as a numpy array in CxHxW format
 
     Returns
@@ -374,7 +413,7 @@ def pchash(image: np.ndarray) -> str:
     return hash_hex if hash_hex else "0"
 
 
-def xxhash(image: np.ndarray) -> str:
+def xxhash(image: NDArray) -> str:
     """
     Performs a fast non-cryptographic hash using the xxhash algorithm
     (xxhash.com) against the image as a flattened bytearray. The hash
@@ -382,7 +421,7 @@ def xxhash(image: np.ndarray) -> str:
 
     Parameters
     ----------
-    image :
+    image : NDArray
         An image as a numpy array
 
     Returns
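The hashing helpers round out the module: xxhash fingerprints the exact bytes for exact-duplicate detection, while pchash produces a perceptual hash that tolerates resizing and minor perturbation. A usage sketch on a synthetic CxHxW image, per the signatures above:

import numpy as np
from dataeval._internal.metrics.utils import pchash, xxhash

rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(3, 64, 64), dtype=np.uint8)  # CxHxW

print(xxhash(image))  # fast non-cryptographic hash of the raw bytes
print(pchash(image))  # perceptual hash as a hex string

# Identical pixel data always shares both hashes
assert xxhash(image) == xxhash(image.copy())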
dataeval/_internal/models/pytorch/autoencoder.py

@@ -1,4 +1,6 @@
-from
+from __future__ import annotations
+
+from typing import Any
 
 import torch
 import torch.nn as nn
@@ -14,40 +16,52 @@ def get_images_from_batch(batch: Any) -> Any:
 
 
 class AETrainer:
+    """
+    A class to train and evaluate an autoencoder model.
+
+    Parameters
+    ----------
+    model : nn.Module
+        The model to be trained.
+    device : str or torch.device, default "auto"
+        The hardware device to use for training.
+        If "auto", the device will be set to "cuda" if available, otherwise "cpu".
+    batch_size : int, default 8
+        The number of images to process in a batch.
+    """
+
     def __init__(
         self,
         model: nn.Module,
-        device:
+        device: str | torch.device = "auto",
         batch_size: int = 8,
     ):
-        """
-        model : nn.Module
-            Model to be trained
-        device : str | torch.device, default "cpu"
-            Hardware device for model, optimizer, and data to run on
-        batch_size : int, default 8
-            Number of images to group together in `torch.utils.data.DataLoader`
-        """
         if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
         self.device = device
         self.model = model.to(device)
         self.batch_size = batch_size
 
-    def train(self, dataset: Dataset, epochs: int = 25) ->
+    def train(self, dataset: Dataset, epochs: int = 25) -> list[float]:
         """
-        Basic training function for Autoencoder models
+        Basic image reconstruction training function for Autoencoder models
 
         Uses `torch.optim.Adam` and `torch.nn.MSELoss` as default hyperparameters
 
         Parameters
         ----------
         dataset : Dataset
-
+            The dataset to train on.
+            Torch Dataset containing images in the first return position.
         epochs : int, default 25
             Number of full training loops
 
-
+        Returns
+        -------
+        List[float]
+            A list of average loss values for each epoch.
+
+        Notes
         ----
         To replace this function with a custom function, do
             AETrainer.train = custom_function
@@ -58,7 +72,7 @@ class AETrainer:
         opt = Adam(self.model.parameters(), lr=0.001)
         criterion = nn.MSELoss().to(self.device)
         # Record loss
-        loss_history:
+        loss_history: list[float] = []
 
         for _ in range(epochs):
             epoch_loss: float = 0
@@ -89,19 +103,20 @@ class AETrainer:
     @torch.no_grad
     def eval(self, dataset: Dataset) -> float:
         """
-        Basic evaluation function for Autoencoder models
+        Basic image reconstruction evaluation function for Autoencoder models
 
-        Uses `torch.
+        Uses `torch.nn.MSELoss` as default loss function.
 
         Parameters
         ----------
         dataset : Dataset
-
+            The dataset to evaluate on.
+            Torch Dataset containing images in the first return position.
 
         Returns
         -------
         float
-            Total reconstruction loss over
+            Total reconstruction loss over the entire dataset
 
         Note
         ----
@@ -124,18 +139,25 @@ class AETrainer:
     @torch.no_grad
     def encode(self, dataset: Dataset) -> torch.Tensor:
         """
-
-
+        Create image embeddings for the dataset using the model's encoder.
+
+        If the model has an `encode` method, it will be used; otherwise,
+        `model.forward` will be used.
 
         Parameters
         ----------
         dataset: Dataset
-
+            The dataset to encode.
+            Torch Dataset containing images in the first return position.
 
         Returns
         -------
         torch.Tensor
             Data encoded by the model
+
+        Notes
+        -----
+        This function should be run after the model has been trained and evaluated.
         """
         self.model.eval()
         dl = DataLoader(dataset, batch_size=self.batch_size)
@@ -155,21 +177,67 @@ class AETrainer:
 
 
 class AriaAutoencoder(nn.Module):
+    """
+    An autoencoder model with a separate encoder and decoder.
+
+    Parameters
+    ----------
+    channels : int, default 3
+        Number of input channels
+    """
+
     def __init__(self, channels=3):
         super().__init__()
         self.encoder = Encoder(channels)
         self.decoder = Decoder(channels)
 
     def forward(self, x):
+        """
+        Perform a forward pass through the encoder and decoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor
+
+        Returns
+        -------
+        torch.Tensor
+            The reconstructed output tensor.
+        """
         x = self.encoder(x)
         x = self.decoder(x)
         return x
 
     def encode(self, x):
+        """
+        Encode the input tensor using the encoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor
+
+        Returns
+        -------
+        torch.Tensor
+            The encoded representation of the input tensor.
+        """
         return self.encoder(x)
 
 
 class Encoder(nn.Module):
+    """
+    A simple encoder to be used in an autoencoder model.
+
+    This is the encoder used by the AriaAutoencoder model.
+
+    Parameters
+    ----------
+    channels : int, default 3
+        Number of input channels
+    """
+
     def __init__(self, channels=3):
         super().__init__()
         self.encoder = nn.Sequential(
@@ -183,10 +251,34 @@ class Encoder(nn.Module):
         )
 
     def forward(self, x):
+        """
+        Perform a forward pass through the encoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor
+
+        Returns
+        -------
+        torch.Tensor
+            The encoded representation of the input tensor.
+        """
         return self.encoder(x)
 
 
 class Decoder(nn.Module):
+    """
+    A simple decoder to be used in an autoencoder model.
+
+    This is the decoder used by the AriaAutoencoder model.
+
+    Parameters
+    ----------
+    channels : int
+        Number of output channels
+    """
+
     def __init__(self, channels):
         super().__init__()
         self.decoder = nn.Sequential(
@@ -199,4 +291,17 @@ class Decoder(nn.Module):
         )
 
     def forward(self, x):
+        """
+        Perform a forward pass through the decoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            The encoded tensor.
+
+        Returns
+        -------
+        torch.Tensor
+            The reconstructed output tensor.
+        """
         return self.decoder(x)