pg-sui 0.2.3__py3-none-any.whl → 1.6.16a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pg_sui-1.6.16a3.dist-info/METADATA +292 -0
- pg_sui-1.6.16a3.dist-info/RECORD +81 -0
- {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info}/WHEEL +1 -1
- pg_sui-1.6.16a3.dist-info/entry_points.txt +4 -0
- {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info/licenses}/LICENSE +0 -0
- pg_sui-1.6.16a3.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +922 -0
- pgsui/data_processing/__init__.py +0 -0
- pgsui/data_processing/config.py +565 -0
- pgsui/data_processing/containers.py +1436 -0
- pgsui/data_processing/transformers.py +557 -907
- pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/extra-resources/.gitkeep +1 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +227 -0
- pgsui/electron/app/package-lock.json +6894 -0
- pgsui/electron/app/package.json +51 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +157 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +131 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +57 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/__init__.py +0 -0
- pgsui/example_data/phylip_files/__init__.py +0 -0
- pgsui/example_data/phylip_files/test.phy +0 -0
- pgsui/example_data/popmaps/__init__.py +0 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/structure_files/__init__.py +0 -0
- pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/__init__.py +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
- pgsui/impute/deterministic/imputers/mode.py +844 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +973 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
- pgsui/impute/supervised/__init__.py +0 -0
- pgsui/impute/supervised/base.py +343 -0
- pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
- pgsui/impute/supervised/imputers/random_forest.py +291 -0
- pgsui/impute/unsupervised/__init__.py +0 -0
- pgsui/impute/unsupervised/base.py +1121 -0
- pgsui/impute/unsupervised/callbacks.py +92 -262
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
- pgsui/impute/unsupervised/imputers/autoencoder.py +1361 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1666 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1660 -0
- pgsui/impute/unsupervised/imputers/vae.py +1316 -0
- pgsui/impute/unsupervised/loss_functions.py +261 -0
- pgsui/impute/unsupervised/models/__init__.py +0 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
- pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
- pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
- pgsui/impute/unsupervised/models/vae_model.py +269 -630
- pgsui/impute/unsupervised/nn_scorers.py +255 -0
- pgsui/utils/__init__.py +0 -0
- pgsui/utils/classification_viz.py +608 -0
- pgsui/utils/logging_utils.py +22 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +996 -829
- pgsui/utils/pretty_metrics.py +290 -0
- pgsui/utils/scorers.py +213 -666
- pg_sui-0.2.3.dist-info/METADATA +0 -322
- pg_sui-0.2.3.dist-info/RECORD +0 -75
- pg_sui-0.2.3.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -1268
- pgsui/impute/impute.py +0 -1463
- pgsui/impute/simple_imputers.py +0 -1431
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
- pgsui/impute/unsupervised/keras_classifiers.py +0 -697
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -151
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -185
@@ -1,645 +1,293 @@
-import
-import os
-import sys
-import warnings
+from typing import List, Literal
 
-
-
-
-
+import numpy as np
+import torch
+import torch.nn as nn
+from snpio.utils.logging import LoggerManager
 
-
+from pgsui.impute.unsupervised.loss_functions import MaskedFocalLoss
+from pgsui.utils.logging_utils import configure_logger
 
-# Disable can't find cuda .dll errors. Also turns of GPU support.
-tf.config.set_visible_devices([], "GPU")
 
-
+class Encoder(nn.Module):
+    """The Encoder module of a standard Autoencoder.
 
-
-
-tf.get_logger().setLevel(logging.ERROR)
-
-
-# Monkey patching deprecation utils to supress warnings.
-# noinspection PyUnusedLocal
-def deprecated(
-    date, instructions, warn_once=True
-): # pylint: disable=unused-argument
-    def deprecated_wrapper(func):
-        return func
-
-    return deprecated_wrapper
-
-
-deprecation.deprecated = deprecated
-
-from tensorflow.keras.layers import (
-    Dropout,
-    Dense,
-    Reshape,
-    Flatten,
-    LeakyReLU,
-    PReLU,
-)
-
-from tensorflow.keras.regularizers import l1_l2
-from tensorflow.keras import backend as K
-
-# Custom Modules
-try:
-    from ..neural_network_methods import NeuralNetworkMethods
-except (ModuleNotFoundError, ValueError, ImportError):
-    from impute.unsupervised.neural_network_methods import NeuralNetworkMethods
-
-
-class Encoder(tf.keras.layers.Layer):
-    """VAE encoder to Encode genotypes to (z_mean, z_log_var, z)."""
+    This module defines the encoder network, which takes high-dimensional input data and maps it to a deterministic, low-dimensional latent representation. The architecture consists of a series of fully-connected hidden layers that progressively compress the flattened input data into a single latent vector, `z`.
+    """
 
     def __init__(
         self,
-        n_features,
-        num_classes,
-        latent_dim,
-        hidden_layer_sizes,
-        dropout_rate,
-        activation,
-        kernel_initializer,
-        kernel_regularizer,
-        beta=K.variable(0.0),
-        name="Encoder",
-        **kwargs,
+        n_features: int,
+        num_classes: int,
+        latent_dim: int,
+        hidden_layer_sizes: List[int],
+        dropout_rate: float,
+        activation: torch.nn.Module,
     ):
-
+        """Initializes the Encoder module.
 
-
+        This class defines the encoder network, which takes high-dimensional input data and maps it to a deterministic, low-dimensional latent representation. The architecture consists of a series of fully-connected hidden layers that progressively compress the flattened input data into a single latent vector, `z`.
 
-
-
-
-
+        Args:
+            n_features (int): The number of features in the input data (e.g., SNPs).
+            num_classes (int): Number of genotype states per locus (2 for haploid, 3 for diploid in practice).
+            latent_dim (int): The dimensionality of the output latent space.
+            hidden_layer_sizes (List[int]): A list of integers specifying the size of each hidden layer.
+            dropout_rate (float): The dropout rate for regularization in the hidden layers.
+            activation (torch.nn.Module): An instantiated activation function module (e.g., `nn.ReLU()`) for the hidden layers.
+        """
+        super(Encoder, self).__init__()
+        self.flatten = nn.Flatten()
 
-
-
+        layers = []
+        input_dim = n_features * num_classes
+        for hidden_size in hidden_layer_sizes:
+            layers.append(nn.Linear(input_dim, hidden_size))
+            layers.append(nn.BatchNorm1d(hidden_size))
+            layers.append(nn.Dropout(dropout_rate))
+            layers.append(activation)
+            input_dim = hidden_size
 
-        self.
-
-            input_shape=(n_features * num_classes,),
-            activation=activation,
-            kernel_initializer=kernel_initializer,
-            kernel_regularizer=kernel_regularizer,
-            name="Encoder1",
-        )
+        self.hidden_layers = nn.Sequential(*layers)
+        self.dense_z = nn.Linear(input_dim, latent_dim)
 
-
-
-            hidden_layer_sizes[1],
-            activation=activation,
-            kernel_initializer=kernel_initializer,
-            kernel_regularizer=kernel_regularizer,
-            name="Encoder2",
-        )
-
-        if len(hidden_layer_sizes) >= 3:
-            self.dense3 = Dense(
-                hidden_layer_sizes[2],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-                name="Encoder3",
-            )
-
-        if len(hidden_layer_sizes) >= 4:
-            self.dense4 = Dense(
-                hidden_layer_sizes[3],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-                name="Encoder4",
-            )
-
-        if len(hidden_layer_sizes) == 5:
-            self.dense5 = Dense(
-                hidden_layer_sizes[4],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-                name="Encoder5",
-            )
-
-        self.dense_latent = Dense(
-            latent_dim,
-            activation=activation,
-            kernel_initializer=kernel_initializer,
-            kernel_regularizer=kernel_regularizer,
-            name="Encoder5",
-        )
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Performs the forward pass through the encoder.
 
-
+        Args:
+            x (torch.Tensor): The input data tensor of shape `(batch_size, n_features, num_classes)`.
 
-
-
-
-        x = self.
-        x = self.
-
-
-        x = self.dropout_layer(x, training=training)
-        if self.dense3 is not None:
-            x = self.dense3(x)
-            x = self.dropout_layer(x, training=training)
-        if self.dense4 is not None:
-            x = self.dense4(x)
-            x = self.dropout_layer(x, training=training)
-        if self.dense5 is not None:
-            x = self.dense5(x)
-            x = self.dropout_layer(x, training=training)
+        Returns:
+            torch.Tensor: The latent representation `z` of shape `(batch_size, latent_dim)`.
+        """
+        x = self.flatten(x)
+        x = self.hidden_layers(x)
+        z = self.dense_z(x)
+        return z
 
-        return self.dense_latent(x)
 
+class Decoder(nn.Module):
+    """The Decoder module of a standard Autoencoder.
 
-
-    """
+    This module defines the decoder network, which takes a deterministic latent vector and maps it back to the high-dimensional data space, aiming to reconstruct the original input. The architecture typically mirrors the encoder, consisting of a series of fully-connected hidden layers that progressively expand the representation, followed by a final linear layer to produce the reconstructed data.
+    """
 
     def __init__(
         self,
-        n_features,
-        num_classes,
-        latent_dim,
-        hidden_layer_sizes,
-        dropout_rate,
-        activation,
-
-
-        name="Decoder",
-        **kwargs,
-    ):
-        super(Decoder, self).__init__(name=name, **kwargs)
-
-        self.dense2 = None
-        self.dense3 = None
-        self.dense4 = None
-        self.dense5 = None
-
-        self.dense1 = Dense(
-            hidden_layer_sizes[0],
-            input_shape=(latent_dim,),
-            activation=activation,
-            kernel_initializer=kernel_initializer,
-            kernel_regularizer=kernel_regularizer,
-            name="Decoder1",
-        )
-
-        if len(hidden_layer_sizes) >= 2:
-            self.dense2 = Dense(
-                hidden_layer_sizes[1],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-                name="Decoder2",
-            )
-
-        if len(hidden_layer_sizes) >= 3:
-            self.dense3 = Dense(
-                hidden_layer_sizes[2],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-                name="Decoder3",
-            )
-
-        if len(hidden_layer_sizes) >= 4:
-            self.dense4 = Dense(
-                hidden_layer_sizes[3],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-                name="Decoder4",
-            )
-
-        if len(hidden_layer_sizes) == 5:
-            self.dense5 = Dense(
-                hidden_layer_sizes[4],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-                name="Decoder5",
-            )
-
-        # No activation for final layer.
-        self.dense_output = Dense(
-            n_features * num_classes,
-            kernel_initializer=kernel_initializer,
-            kernel_regularizer=kernel_regularizer,
-            activation=None,
-            name="Decoder6",
-        )
-
-        self.rshp = Reshape((n_features, num_classes))
-        self.dropout_layer = Dropout(dropout_rate)
-
-    def call(self, inputs, training=None):
-        """Forward pass through model."""
-        x = self.dense1(inputs)
-        x = self.dropout_layer(x, training=training)
-        if self.dense2 is not None:
-            x = self.dense2(x)
-            x = self.dropout_layer(x, training=training)
-        if self.dense3 is not None:
-            x = self.dense3(x)
-            x = self.dropout_layer(x, training=training)
-        if self.dense4 is not None:
-            x = self.dense4(x)
-            x = self.dropout_layer(x, training=training)
-        if self.dense5 is not None:
-            x = self.dense5(x)
-            x = self.dropout_layer(x, training=training)
-
-        x = self.dense_output(x)
-        return self.rshp(x)
-
-
-class AutoEncoderModel(tf.keras.Model):
-    """Standard AutoEncoder model to impute missing data.
+        n_features: int,
+        num_classes: int,
+        latent_dim: int,
+        hidden_layer_sizes: List[int],
+        dropout_rate: float,
+        activation: torch.nn.Module,
+    ) -> None:
+        """Initializes the Decoder module.
 
-
-
-
-
-
-
+        Args:
+            n_features (int): The number of features in the output data (e.g., SNPs).
+            num_classes (int): Number of genotype states per locus (2 or 3 in practice).
+            latent_dim (int): The dimensionality of the input latent space.
+            hidden_layer_sizes (List[int]): A list of integers specifying the size of each hidden layer (typically the reverse of the encoder's).
+            dropout_rate (float): The dropout rate for regularization in the hidden layers.
+            activation (torch.nn.Module): An instantiated activation function module (e.g., `nn.ReLU()`) for the hidden layers.
+        """
+        super(Decoder, self).__init__()
 
-
+        layers = []
+        input_dim = latent_dim
+        for hidden_size in hidden_layer_sizes:
+            layers.append(nn.Linear(input_dim, hidden_size))
+            layers.append(nn.BatchNorm1d(hidden_size))
+            layers.append(nn.Dropout(dropout_rate))
+            layers.append(activation)
+            input_dim = hidden_size
 
-
+        self.hidden_layers = nn.Sequential(*layers)
+        output_dim = n_features * num_classes
+        self.dense_output = nn.Linear(input_dim, output_dim)
+        self.reshape = (n_features, num_classes)
 
-
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Performs the forward pass through the decoder.
 
-
+        Args:
+            x (torch.Tensor): The input latent tensor of shape `(batch_size, latent_dim)`.
 
-
+        Returns:
+            torch.Tensor: The reconstructed output data of shape `(batch_size, n_features, num_classes)`.
+        """
+        x = self.hidden_layers(x)
+        x = self.dense_output(x)
+        return x.view(-1, *self.reshape)
 
-        l1_penalty (float, optional): l1_penalty to use for regularization. Defaults to 1e-6.
 
-
+class AutoencoderModel(nn.Module):
+    """A standard Autoencoder (AE) model for imputation.
 
-
+    This class combines an `Encoder` and a `Decoder` to form a standard autoencoder. The model is trained to learn a compressed, low-dimensional representation of the input data and then reconstruct it as accurately as possible. It is particularly useful for unsupervised dimensionality reduction and data imputation.
 
-
+    **Model Architecture and Objective:**
 
-
+    The autoencoder consists of two parts: an encoder, $f_{\theta}$, and a decoder, $g_{\phi}$.
+    1. The **encoder** maps the input data $x$ to a latent representation $z$:
+    $$
+    z = f_{\theta}(x)
+    $$
+    2. The **decoder** reconstructs the data $\hat{x}$ from the latent representation:
+    $$
+    \hat{x} = g_{\phi}(z)
+    $$
 
-
-        ValueError: Maximum number of hidden layers (5) was exceeded.
+    The model is trained by minimizing a reconstruction loss, $L(x, \hat{x})$, which measures the dissimilarity between the original input and the reconstructed output. This implementation uses a `MaskedFocalLoss` to handle missing values and class imbalance effectively.
     """
 
     def __init__(
         self,
-
-
-
-
-
-
-
-
-
-
-
-
-        missing_mask=None,
-        num_classes=3,
+        n_features: int,
+        prefix: str,
+        *,
+        num_classes: int = 4,
+        hidden_layer_sizes: List[int] | np.ndarray = [128, 64],
+        latent_dim: int = 2,
+        dropout_rate: float = 0.2,
+        activation: Literal["relu", "elu", "selu", "leaky_relu"] = "relu",
+        gamma: float = 2.0,
+        device: Literal["cpu", "gpu", "mps"] = "cpu",
+        verbose: bool = False,
+        debug: bool = False,
     ):
-
-
-        self.nn_ = NeuralNetworkMethods()
-        self.categorical_accuracy = self.nn_.make_masked_categorical_accuracy()
+        """Initializes the AutoencoderModel.
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-        self.
-        n_features = self.n_features
-
-        self.n_components = n_components
-        self.weights_initializer = weights_initializer
-        self.hidden_layer_sizes = hidden_layer_sizes
-        self.num_hidden_layers = num_hidden_layers
-        self.hidden_activation = hidden_activation
-        self.l1_penalty = l1_penalty
-        self.l2_penalty = l2_penalty
-        self.dropout_rate = dropout_rate
-        self.sample_weight = sample_weight
+        Args:
+            n_features (int): The number of features in the input data (e.g., SNPs).
+            prefix (str): A prefix used for logging.
+            num_classes (int): Number of genotype states per locus. Defaults to 4 for backward compatibility, but the genotype imputers pass 2 (haploid) or 3 (diploid).
+            hidden_layer_sizes (List[int] | np.ndarray): A list of integers specifying the size of each hidden layer in the encoder. The decoder will use the reverse of this structure. Defaults to [128, 64].
+            latent_dim (int): The dimensionality of the latent space (bottleneck). Defaults to 2.
+            dropout_rate (float): The dropout rate for regularization in hidden layers. Defaults to 0.2.
+            activation (Literal["relu", "elu", "selu", "leaky_relu"]): The name of the activation function for hidden layers. Defaults to "relu".
+            gamma (float): The focusing parameter for the focal loss function. Defaults to 2.0.
+            device (Literal["cpu", "gpu", "mps"]): The device to run the model on.
+            verbose (bool): If True, enables detailed logging.
+            debug (bool): If True, enables debug mode.
+        """
+        super(AutoencoderModel, self).__init__()
         self.num_classes = num_classes
+        self.gamma = gamma
+        self.device = device
 
-
-
-        hidden_layer_sizes = nn.validate_hidden_layers(
-            self.hidden_layer_sizes, self.num_hidden_layers
+        logman = LoggerManager(
+            name=__name__, prefix=prefix, verbose=verbose, debug=debug
         )
-
-
-            n_features, self.n_components, hidden_layer_sizes, vae=True
+        self.logger = configure_logger(
+            logman.get_logger(), verbose=verbose, debug=debug
         )
 
-
-
-        if self.l1_penalty == 0.0 and self.l2_penalty == 0.0:
-            kernel_regularizer = None
-        else:
-            kernel_regularizer = l1_l2(self.l1_penalty, self.l2_penalty)
-
-        kernel_initializer = self.weights_initializer
-
-        if self.hidden_activation.lower() == "leaky_relu":
-            activation = LeakyReLU(alpha=0.01)
-
-        elif self.hidden_activation.lower() == "prelu":
-            activation = PReLU()
-
-        elif self.hidden_activation.lower() == "selu":
-            activation = "selu"
-            kernel_initializer = "lecun_normal"
+        activation_module = self._resolve_activation(activation)
 
+        if isinstance(hidden_layer_sizes, np.ndarray):
+            hls = hidden_layer_sizes.tolist()
         else:
-
-
-        if num_hidden_layers > 5:
-            raise ValueError(
-                f"The maximum number of hidden layers is 5, but got "
-                f"{num_hidden_layers}"
-            )
+            hls = hidden_layer_sizes
 
         self.encoder = Encoder(
             n_features,
             self.num_classes,
-
-
-
-
-            kernel_initializer,
-            kernel_regularizer,
+            latent_dim,
+            hls,
+            dropout_rate,
+            activation_module,
         )
 
-
-
+        decoder_layer_sizes = list(reversed(hls))
         self.decoder = Decoder(
             n_features,
             self.num_classes,
-
-
-
-
-            kernel_initializer,
-            kernel_regularizer,
-        )
-
-    def call(self, inputs, training=None):
-        """Forward pass through model."""
-        x = self.encoder(inputs)
-        return self.decoder(x)
-
-    def model(self):
-        """To allow model.summary().summar() to be called."""
-        x = tf.keras.Input(shape=(self.n_features, self.num_classes))
-        return tf.keras.Model(inputs=[x], outputs=self.call(x))
-
-    def set_model_outputs(self):
-        """Set expected model outputs."""
-        x = tf.keras.Input(shape=(self.n_features, self.num_classes))
-        model = tf.keras.Model(inputs=[x], outputs=self.call(x))
-        self.outputs = model.outputs
-
-    @property
-    def metrics(self):
-        return [
-            self.total_loss_tracker,
-            self.reconstruction_loss_tracker,
-            self.accuracy_tracker,
-        ]
-
-    @tf.function
-    def train_step(self, data):
-        y = self._y
-
-        (
-            y_true,
-            sample_weight,
-            missing_mask,
-        ) = self.nn_.prepare_training_batches(
-            y,
-            y,
-            self._batch_size,
-            self._batch_idx,
-            True,
-            self.n_components,
-            self._sample_weight,
-            self._missing_mask,
-            ubp=False,
-        )
-
-        if sample_weight is not None:
-            sample_weight_masked = tf.convert_to_tensor(
-                sample_weight[~missing_mask], dtype=tf.float32
-            )
-        else:
-            sample_weight_masked = None
-
-        y_true_masked = tf.boolean_mask(
-            tf.convert_to_tensor(y_true, dtype=tf.float32),
-            tf.reduce_any(tf.not_equal(y_true, -1), axis=2),
-        )
-
-        with tf.GradientTape() as tape:
-            reconstruction = self(y_true, training=True)
-
-            y_pred_masked = tf.boolean_mask(
-                reconstruction, tf.reduce_any(tf.not_equal(y_true, -1), axis=2)
-            )
-
-            # Returns binary crossentropy loss.
-            reconstruction_loss = self.compiled_loss(
-                y_true_masked,
-                y_pred_masked,
-                sample_weight=sample_weight_masked,
-            )
-
-        regularization_loss = sum(self.losses)
-
-        total_loss = reconstruction_loss + regularization_loss
-
-        grads = tape.gradient(total_loss, self.trainable_weights)
-        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
-        self.total_loss_tracker.update_state(total_loss)
-        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
-
-        ### NOTE: If you get the error, "'tuple' object has no attribute
-        ### 'rank', then convert y_true to a tensor object."
-        # self.compiled_metrics.update_state(
-        self.accuracy_tracker.update_state(
-            self.categorical_accuracy(
-                y_true_masked,
-                y_pred_masked,
-                sample_weight=sample_weight_masked,
-            )
+            latent_dim,
+            decoder_layer_sizes,
+            dropout_rate,
+            activation_module,
         )
 
-
-
-            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
-            "accuracy": self.accuracy_tracker.result(),
-        }
-
-    @tf.function
-    def test_step(self, data):
-        """Custom evaluation loop for one step (=batch) in a single epoch.
-
-        This function will evaluate on a batch of samples (rows), which can be adjusted with the ``batch_size`` parameter from the estimator.
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Performs the forward pass through the full Autoencoder model.
 
         Args:
-
+            x (torch.Tensor): The input data tensor of shape `(batch_size, n_features, num_classes)`.
 
         Returns:
-
+            torch.Tensor: The reconstructed data tensor.
         """
-
-
-
-            y_true,
-            sample_weight,
-            missing_mask,
-        ) = self.nn_.prepare_training_batches(
-            y,
-            y,
-            self._batch_size,
-            self._batch_idx,
-            True,
-            self.n_components,
-            self._sample_weight,
-            self._missing_mask,
-            ubp=False,
-        )
-
-        if sample_weight is not None:
-            sample_weight_masked = tf.convert_to_tensor(
-                sample_weight[~missing_mask], dtype=tf.float32
-            )
-        else:
-            sample_weight_masked = None
-
-        y_true_masked = tf.boolean_mask(
-            tf.convert_to_tensor(y_true, dtype=tf.float32),
-            tf.reduce_any(tf.not_equal(y_true, -1), axis=2),
-        )
-
-        reconstruction = self(y_true, training=False)
+        z = self.encoder(x)
+        reconstruction = self.decoder(z)
+        return reconstruction
 
-
-
-
-
-
-
-
-
-        )
-
-        regularization_loss = sum(self.losses)
-
-        total_loss = reconstruction_loss + regularization_loss
-
-        self.accuracy_tracker.update_state(
-            self.categorical_accuracy(
-                y_true_masked,
-                y_pred_masked,
-                sample_weight=sample_weight_masked,
-            )
-        )
+    def compute_loss(
+        self,
+        reconstruction: torch.Tensor,
+        y: torch.Tensor,
+        mask: torch.Tensor | None = None,
+        class_weights: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """Computes the reconstruction loss for the Autoencoder model.
 
-
-        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
+        This method calculates the reconstruction loss using a masked focal loss, which is suitable for categorical data with missing values and class imbalance.
 
-
-
-
-
-
+        Args:
+            reconstruction (torch.Tensor): The reconstructed output (logits) from the model's forward pass.
+            y (torch.Tensor): The target data tensor, expected to be one-hot encoded. It is converted to class indices internally for the loss calculation.
+            mask (torch.Tensor | None): A boolean mask to exclude missing values from the loss calculation.
+            class_weights (torch.Tensor | None): Weights to apply to each class in the loss to handle imbalance.
 
-
-
-        """Batch (=step) size per epoch.
-        :noindex:
+        Returns:
+            torch.Tensor: The computed scalar loss value.
         """
-
+        if class_weights is None:
+            class_weights = torch.ones(self.num_classes, device=y.device)
 
-
-
-        """Current batch (=step) index.
-        :noindex:
-        """
-        return self._batch_idx
+        logits_flat = reconstruction.view(-1, self.num_classes)
+        targets_flat = torch.argmax(y, dim=-1).view(-1)
 
-
-
-
-
-        """
-        return self._y
+        if mask is None:
+            mask_flat = torch.ones_like(targets_flat, dtype=torch.bool)
+        else:
+            mask_flat = mask.view(-1)
 
-
-    def missing_mask(self):
-        """Missing mask of shape (y.shape[0], y.shape[1])
-        :noindex:
-        """
-        return self._missing_mask
+        criterion = MaskedFocalLoss(alpha=class_weights, gamma=self.gamma)
 
-
-
-
-
-
-        return self._sample_weight
+        reconstruction_loss = criterion(
+            logits_flat.to(self.device),
+            targets_flat.to(self.device),
+            valid_mask=mask_flat.to(self.device),
+        )
 
-
-    def batch_size(self, value):
-        """Set batch_size parameter.
-        :noindex:
-        """
-        self._batch_size = int(value)
+        return reconstruction_loss
 
-
-
-
-
-        """
-        self._batch_idx = int(value)
+    def _resolve_activation(
+        self, activation: Literal["relu", "elu", "leaky_relu", "selu"]
+    ) -> torch.nn.Module:
+        """Resolves an activation function module from a string name.
 
-
-
-        """Set y after each epoch.
-        :noindex:
-        """
-        self._y = value
+        Args:
+            activation (Literal["relu", "elu", "leaky_relu", "selu"]): The name of the activation function.
 
-
-
-        """Set missing_mask after each epoch.
-        :noindex:
-        """
-        self._missing_mask = value
+        Returns:
+            torch.nn.Module: The corresponding instantiated PyTorch activation function module.
 
-
-
-        """Set sample_weight after each epoch.
-        :noindex:
+        Raises:
+            ValueError: If the provided activation name is not supported.
         """
-
+        act: str = activation.lower()
+
+        if act == "relu":
+            return nn.ReLU()
+        elif act == "elu":
+            return nn.ELU()
+        elif act in ("leaky_relu", "leakyrelu"):
+            return nn.LeakyReLU()
+        elif act == "selu":
+            return nn.SELU()
+        else:
+            msg = f"Activation {activation} not supported."
+            self.logger.error(msg)
+            raise ValueError(msg)
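The hunk above appears to correspond to pgsui/impute/unsupervised/models/autoencoder_model.py (its +215/-567 line counts in the file listing match this hunk) and replaces the TensorFlow/Keras autoencoder with a PyTorch implementation. Below is a minimal, hypothetical usage sketch based only on the signatures visible in the new code; the import path, tensor shapes, and argument values are assumptions rather than part of the published diff, and running it requires pg-sui 1.6.16a3 and its snpio dependency to be installed.

import torch
import torch.nn.functional as F

# Assumed import path, inferred from the file listing above.
from pgsui.impute.unsupervised.models.autoencoder_model import AutoencoderModel

n_samples, n_features, num_classes = 32, 100, 3  # toy sizes for illustration

# One-hot encoded genotype batch of shape (batch_size, n_features, num_classes),
# matching the input shape documented in Encoder.forward().
genotypes = torch.randint(0, num_classes, (n_samples, n_features))
x = F.one_hot(genotypes, num_classes=num_classes).float()

model = AutoencoderModel(
    n_features=n_features,
    prefix="example",          # per the docstring, prefix is only used for logging
    num_classes=num_classes,   # 3 genotype states for diploid data
    hidden_layer_sizes=[128, 64],
    latent_dim=2,
    activation="relu",
)

logits = model(x)                     # reconstruction logits: (batch_size, n_features, num_classes)
loss = model.compute_loss(logits, x)  # masked focal loss over the one-hot targets
loss.backward()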
|