pg-sui 1.0.2.1__py3-none-any.whl → 1.6.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pg-sui might be problematic.
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/METADATA +51 -70
- pg_sui-1.6.8.dist-info/RECORD +78 -0
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/WHEEL +1 -1
- pg_sui-1.6.8.dist-info/entry_points.txt +4 -0
- pg_sui-1.6.8.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +635 -0
- pgsui/data_processing/config.py +576 -0
- pgsui/data_processing/containers.py +1782 -0
- pgsui/data_processing/transformers.py +121 -1103
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +189 -0
- pgsui/electron/app/package-lock.json +6893 -0
- pgsui/electron/app/package.json +50 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +146 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +130 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +59 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +691 -0
- pgsui/impute/deterministic/imputers/mode.py +679 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +971 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +530 -0
- pgsui/impute/supervised/base.py +339 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +293 -0
- pgsui/impute/supervised/imputers/random_forest.py +287 -0
- pgsui/impute/unsupervised/base.py +924 -0
- pgsui/impute/unsupervised/callbacks.py +89 -263
- pgsui/impute/unsupervised/imputers/autoencoder.py +972 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1264 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1288 -0
- pgsui/impute/unsupervised/imputers/vae.py +957 -0
- pgsui/impute/unsupervised/loss_functions.py +158 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +208 -558
- pgsui/impute/unsupervised/models/nlpca_model.py +149 -468
- pgsui/impute/unsupervised/models/ubp_model.py +198 -1317
- pgsui/impute/unsupervised/models/vae_model.py +259 -618
- pgsui/impute/unsupervised/nn_scorers.py +215 -0
- pgsui/utils/classification_viz.py +591 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +514 -824
- pgsui/utils/scorers.py +212 -438
- pg_sui-1.0.2.1.dist-info/RECORD +0 -75
- pg_sui-1.0.2.1.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -735
- pgsui/impute/impute.py +0 -1486
- pgsui/impute/simple_imputers.py +0 -1439
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -785
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1027
- pgsui/impute/unsupervised/keras_classifiers.py +0 -702
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1424
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1549
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -297
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -214
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info/licenses}/LICENSE +0 -0
- /pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- /pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
pgsui/impute/unsupervised/models/vae_model.py

@@ -1,707 +1,348 @@
-import os
-import sys
-import warnings
+from typing import List, Literal, Tuple

+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from snpio.utils.logging import LoggerManager
+from torch.distributions import Normal

+from pgsui.impute.unsupervised.loss_functions import MaskedFocalLoss

-# Disable can't find cuda .dll errors. Also turns of GPU support.
-tf.config.set_visible_devices([], "GPU")

+class Sampling(nn.Module):
+    """A layer that samples from a latent distribution using the reparameterization trick.

-tf.get_logger().setLevel(logging.ERROR)
-
-# Monkey patching deprecation utils to supress warnings.
-# noinspection PyUnusedLocal
-def deprecated(
-    date, instructions, warn_once=True
-):  # pylint: disable=unused-argument
-    def deprecated_wrapper(func):
-        return func
-
-    return deprecated_wrapper
-
-
-deprecation.deprecated = deprecated
-
-from tensorflow.keras.layers import (
-    Dropout,
-    Dense,
-    Reshape,
-    Activation,
-    LeakyReLU,
-    PReLU,
-)
-
-from tensorflow.keras.regularizers import l1_l2
-from tensorflow.keras import backend as K
-
-# Custom Modules
-try:
-    from ..neural_network_methods import NeuralNetworkMethods
-except (ModuleNotFoundError, ValueError, ImportError):
-    from impute.unsupervised.neural_network_methods import NeuralNetworkMethods
+    This layer is a core component of a Variational Autoencoder (VAE). It takes the mean and log-variance of a latent distribution as input and generates a sample from that distribution. By using the reparameterization trick ($z = \mu + \sigma \cdot \epsilon$), it allows gradients to be backpropagated through the random sampling process, making the VAE trainable.
+    """

+    def forward(self, z_mean: torch.Tensor, z_log_var: torch.Tensor) -> torch.Tensor:
+        """Performs the forward pass to generate a latent sample.
+
+        Args:
+            z_mean (torch.Tensor): The mean of the latent normal distribution.
+            z_log_var (torch.Tensor): The log of the variance of the latent normal distribution.
+
+        Returns:
+            torch.Tensor: A sampled vector from the latent space.
+        """
+        z_sigma = torch.exp(0.5 * z_log_var)  # Precompute outside

-        batch = tf.shape(z_mean)[0]
-        dim = tf.shape(z_mean)[1]
-        epsilon = tf.random.normal(shape=(batch, dim))
+        # Ensure on GPU
+        # rand_like takes random values from a normal distribution
+        # of the same shape as z_mean.
+        epsilon = torch.randn_like(z_mean, device=z_mean.device)
         return z_mean + z_sigma * epsilon


-class Encoder(
-    """
-    Args:
-        n_features (int): Number of featuresi in input dataset.
-        num_classes (int): Number of classes in target data.
-        latent_dim (int): Number of latent dimensions to use.
-        hidden_layer_sizes (list of int): List of hidden layer sizes to use.
-        dropout_rate (float): Dropout rate for Dropout layer.
-        activation (str): Hidden activation function to use.
-        kernel_initializer (str): Initializer to use for weights.
-        kernel_regularizer (tf.keras.regularizers): L1 and/ or L2 objects.
-        beta (float, optional): KL divergence beta to use. Defualts to 1.0.
-        name (str): Name of model. Defaults to Encoder.
+class Encoder(nn.Module):
+    """The Encoder module of a Variational Autoencoder (VAE).

+    This module defines the encoder network, which takes high-dimensional input data and maps it to the parameters of a lower-dimensional latent distribution. The architecture consists of a series of fully-connected hidden layers that process the flattened input. The network culminates in two separate linear layers that output the mean (`z_mean`) and log-variance (`z_log_var`) of the approximate posterior distribution, $q(z|x)$.
     """

     def __init__(
         self,
-        n_features,
-        num_classes,
-        latent_dim,
-        hidden_layer_sizes,
-        dropout_rate,
-        activation,
-        kernel_initializer,
-        kernel_regularizer,
-        beta=1.0,
-        name="Encoder",
-        **kwargs,
+        n_features: int,
+        num_classes: int,
+        latent_dim: int,
+        hidden_layer_sizes: List[int],
+        dropout_rate: float,
+        activation: torch.nn.Module,
     ):
+        """Initializes the Encoder module.
+
+        Args:
+            n_features (int): The number of features in the input data (e.g., SNPs).
+            num_classes (int): The number of possible classes for each input element (e.g., 4 alleles).
+            latent_dim (int): The dimensionality of the latent space.
+            hidden_layer_sizes (List[int]): A list of integers specifying the size of each hidden layer.
+            dropout_rate (float): The dropout rate for regularization in the hidden layers.
+            activation (torch.nn.Module): An instantiated activation function module (e.g., `nn.ReLU()`) for the hidden layers.
+        """
+        super(Encoder, self).__init__()
+        self.flatten = nn.Flatten()
+        self.activation = (
+            getattr(F, activation) if isinstance(activation, str) else activation
+        )

+        layers = []
+        # The input dimension accounts for channels
+        input_dim = n_features * num_classes
+        for hidden_size in hidden_layer_sizes:
+            layers.append(nn.Linear(input_dim, hidden_size))

-        self.dense4 = None
-        self.dense5 = None
+            # BatchNorm can lead to faster convergence.
+            layers.append(nn.BatchNorm1d(hidden_size))

+            layers.append(nn.Dropout(dropout_rate))
+            layers.append(activation)
+            input_dim = hidden_size

-            kernel_initializer=kernel_initializer,
-            kernel_regularizer=kernel_regularizer,
-            name="Encoder1",
-        )
+        self.hidden_layers = nn.Sequential(*layers)
+        self.dense_z_mean = nn.Linear(input_dim, latent_dim)
+        self.dense_z_log_var = nn.Linear(input_dim, latent_dim)
+        self.sampling = Sampling()

-            kernel_initializer=kernel_initializer,
-            kernel_regularizer=kernel_regularizer,
-            name="Encoder2",
-        )
-
-        if len(hidden_layer_sizes) >= 3:
-            self.dense3 = Dense(
-                hidden_layer_sizes[2],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-                name="Encoder3",
-            )
-
-        if len(hidden_layer_sizes) >= 4:
-            self.dense4 = Dense(
-                hidden_layer_sizes[3],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-                name="Encoder4",
-            )
-
-        if len(hidden_layer_sizes) == 5:
-            self.dense5 = Dense(
-                hidden_layer_sizes[4],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-                name="Encoder5",
-            )
-
-        self.dense_z_mean = Dense(
-            latent_dim,
-            name="z_mean",
-        )
-        self.dense_z_log_var = Dense(
-            latent_dim,
-            name="z_log_var",
-        )
-        # z_mean and z_log_var are inputs.
-        self.sampling = Sampling(
-            name="z",
-        )
+    def forward(
+        self, x: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Performs the forward pass through the encoder.

-            activation=activation,
-            kernel_initializer=kernel_initializer,
-            kernel_regularizer=kernel_regularizer,
-            name="Encoder5",
-        )
+        Args:
+            x (torch.Tensor): The input data tensor of shape `(batch_size, n_features, num_classes)`.

-        x = self.dense1(x)
-        x = self.dropout_layer(x, training=training)
-        if self.dense2 is not None:
-            x = self.dense2(x)
-            x = self.dropout_layer(x, training=training)
-        if self.dense3 is not None:
-            x = self.dense3(x)
-            x = self.dropout_layer(x, training=training)
-        if self.dense4 is not None:
-            x = self.dense4(x)
-            x = self.dropout_layer(x, training=training)
-        if self.dense5 is not None:
-            x = self.dense5(x)
-            x = self.dropout_layer(x, training=training)
-
-        x = self.dense_latent(x)
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the latent mean (`z_mean`), latent log-variance (`z_log_var`), and a sample from the latent distribution (`z`).
+        """
+        x = self.flatten(x)
+        x = self.hidden_layers(x)
         z_mean = self.dense_z_mean(x)
         z_log_var = self.dense_z_log_var(x)
-
-        # Compute the KL divergence
-        kl_loss = -0.5 * tf.reduce_sum(
-            1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=-1
-        )
-        # Add the KL divergence to the model's total loss
-        self.add_loss(self.beta * tf.reduce_mean(kl_loss))
-
-        z = self.sampling([z_mean, z_log_var])
-
+        z = self.sampling(z_mean, z_log_var)
         return z_mean, z_log_var, z


-class Decoder(
-    """
-    Args:
-        n_features (int): Number of features in input dataset.
-        num_classes (int): Number of classes in input dataset.
+class Decoder(nn.Module):
+    """The Decoder module of a Variational Autoencoder (VAE).

-        hidden_layer_sizes (list of int): List of hidden layer sizes to use.
-        dropout_rate (float): Dropout rate for Dropout layer.
-        activation (str): Hidden activation function to use.
-        kernel initializer (str): Function for initilizing weights.
-        kernel_regularizer (tf.keras.regularizer): Initialized L1 and/ or L2 regularizer.
-        name (str): Name of model. Defaults to "Decoder".
+    This module defines the decoder network, which takes a sample from the low-dimensional latent space and maps it back to the high-dimensional data space. It aims to reconstruct the original input data. The architecture consists of a series of fully-connected hidden layers followed by a final linear layer that produces the reconstructed data, which is then reshaped to match the original input's dimensions.
     """

     def __init__(
         self,
-        n_features,
-        num_classes,
-        latent_dim,
-        hidden_layer_sizes,
-        dropout_rate,
-        activation,
-
-        self.dense1 = Dense(
-            hidden_layer_sizes[0],
-            input_shape=(latent_dim,),
-            activation=activation,
-            kernel_initializer=kernel_initializer,
-            kernel_regularizer=kernel_regularizer,
-            name="Decoder1",
-        )
-
-        if len(hidden_layer_sizes) >= 2:
-            self.dense2 = Dense(
-                hidden_layer_sizes[1],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-                name="Decoder2",
-            )
-
-        if len(hidden_layer_sizes) >= 3:
-            self.dense3 = Dense(
-                hidden_layer_sizes[2],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-                name="Decoder3",
-            )
-
-        if len(hidden_layer_sizes) >= 4:
-            self.dense4 = Dense(
-                hidden_layer_sizes[3],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-                name="Decoder4",
-            )
-
-        if len(hidden_layer_sizes) == 5:
-            self.dense5 = Dense(
-                hidden_layer_sizes[4],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-                name="Decoder5",
-            )
-
-        # No activation for final layer.
-        self.dense_output = Dense(
-            n_features * num_classes,
-            kernel_initializer=kernel_initializer,
-            kernel_regularizer=kernel_regularizer,
-            name="DecoderExpanded",
-        )
-
-        self.rshp = Reshape((n_features, num_classes))
-
-        self.dropout_layer = Dropout(dropout_rate)
-
-    def call(self, inputs, training=None):
-        """Forward pass for model."""
-        x = self.dense1(inputs)
-        x = self.dropout_layer(x, training=training)
-        if self.dense2 is not None:
-            x = self.dense2(x)
-            x = self.dropout_layer(x, training=training)
-        if self.dense3 is not None:
-            x = self.dense3(x)
-            x = self.dropout_layer(x, training=training)
-        if self.dense4 is not None:
-            x = self.dense4(x)
-            x = self.dropout_layer(x, training=training)
-        if self.dense5 is not None:
-            x = self.dense5(x)
-            x = self.dropout_layer(x, training=training)
-
-        x = self.dense_output(x)
-        return self.rshp(x)
-
-
-class VAEModel(tf.keras.Model):
-    """Variational Autoencoder model. Runs the encdoer and decoder and outputs the reconsruction.
-
-    Args:
-        output_shape (tuple): Shape of output. Defaults to None.
-        n_components (int): Number of latent dimensions to use. Defaults to 3.
-        weights_initializer (str, optional): kernel initializer to use. Defaults to "glorot_normal".
-        hidden_layer_sizes (str or list, optional): List of hidden layer sizes to use, or can use "midpoint", "log2", or "sqrt". Defaults to "midpoint".
-        num_hidden_layers (int, optional): Number of hidden layers to use. Defaults to 1.
-        hidden_activation (str, optional): Hidden activation function to use. Defaults to "elu".
-        batch_size (int, optional): Batch size to use for training. Defaults to 32.
+        n_features: int,
+        num_classes: int,
+        latent_dim: int,
+        hidden_layer_sizes: List[int],
+        dropout_rate: float,
+        activation: torch.nn.Module,
+    ) -> None:
+        """Initializes the Decoder module.
+
+        Args:
+            n_features (int): The number of features in the output data (e.g., SNPs).
+            num_classes (int): The number of possible classes for each output element (e.g., 4 alleles).
+            latent_dim (int): The dimensionality of the input latent space.
+            hidden_layer_sizes (List[int]): A list of integers specifying the size of each hidden layer.
+            dropout_rate (float): The dropout rate for regularization in the hidden layers.
+            activation (torch.nn.Module): An instantiated activation function module (e.g., `nn.ReLU()`) for the hidden layers.
+        """
+        super(Decoder, self).__init__()

+        layers = []
+        input_dim = latent_dim
+        for hidden_size in hidden_layer_sizes:
+            layers.append(nn.Linear(input_dim, hidden_size))

+            # BatchNorm can lead to faster convergence.
+            layers.append(nn.BatchNorm1d(hidden_size))

+            layers.append(nn.Dropout(dropout_rate))
+            layers.append(activation)
+            input_dim = hidden_size

+        self.hidden_layers = nn.Sequential(*layers)
+        # UPDATED: Output dimension must account for channels
+        output_dim = n_features * num_classes
+        self.dense_output = nn.Linear(input_dim, output_dim)
+        # UPDATED: Reshape must account for channels
+        self.reshape = (n_features, num_classes)

+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Performs the forward pass through the decoder.
+
+        Args:
+            x (torch.Tensor): The input latent tensor of shape `(batch_size, latent_dim)`.
+
+        Returns:
+            torch.Tensor: The reconstructed output data of shape `(batch_size, n_features, num_classes)`.
+        """
+        x = self.hidden_layers(x)
+        x = self.dense_output(x)
+        return x.view(-1, *self.reshape)


+class VAEModel(nn.Module):
+    """A Variational Autoencoder (VAE) model for imputation.

+    This class combines an `Encoder` and a `Decoder` to form a VAE, a generative model for learning complex data distributions. It is designed for imputing missing values in categorical data, such as genomic SNPs. The model is trained by maximizing the Evidence Lower Bound (ELBO), which is a lower bound on the log-likelihood of the data.

+    **Objective Function (ELBO):**
+    The VAE loss function is derived from the ELBO and consists of two main components: a reconstruction term and a regularization term.
+    $$
+    \\mathcal{L}(\\theta, \\phi; x) = \\underbrace{\\mathbb{E}_{q_{\\phi}(z|x)}[\\log p_{\\theta}(x|z)]}_{\\text{Reconstruction Loss}} - \\underbrace{D_{KL}(q_{\\phi}(z|x) || p(z))}_{\\text{KL Divergence}}
+    $$
+    - The **Reconstruction Loss** encourages the decoder to accurately reconstruct the input data from its latent representation. This implementation uses a `MaskedFocalLoss`.
+    - The **KL Divergence** acts as a regularizer, forcing the approximate posterior distribution $q_{\\phi}(z|x)$ learned by the encoder to be close to a prior distribution $p(z)$ (typically a standard normal distribution).
     """

     def __init__(
         self,
-        batch_size=32,
-        final_activation=None,
-        y=None,
+        n_features: int,
+        prefix: str,
+        *,
+        num_classes: int = 4,
+        hidden_layer_sizes: List[int] | np.ndarray = [128, 64],
+        latent_dim: int = 2,
+        dropout_rate: float = 0.2,
+        activation: Literal["relu", "elu", "selu", "leaky_relu"] = "relu",
+        gamma: float = 2.0,
+        beta: float = 1.0,
+        device: Literal["cpu", "gpu", "mps"] = "cpu",
+        verbose: bool = False,
+        debug: bool = False,
     ):
+        """Initializes the VAEModel.
+
+        Args:
+            n_features (int): The number of features in the input data (e.g., SNPs).
+            prefix (str): A prefix used for logging.
+            num_classes (int): The number of possible classes for each input element. Defaults to 4.
+            hidden_layer_sizes (List[int] | np.ndarray): A list of integers specifying the size of each hidden layer in the encoder and decoder. Defaults to [128, 64].
+            latent_dim (int): The dimensionality of the latent space. Defaults to 2.
+            dropout_rate (float): The dropout rate for regularization in the hidden layers. Defaults to 0.2.
+            activation (str): The name of the activation function to use in hidden layers. Defaults to "relu".
+            gamma (float): The focusing parameter for the focal loss component. Defaults to 2.0.
+            beta (float): A weighting factor for the KL divergence term in the total loss ($\beta$-VAE). Defaults to 1.0.
+            device (Literal["cpu", "gpu", "mps"]): The device to run the model on.
+            verbose (bool): If True, enables detailed logging. Defaults to False.
+            debug (bool): If True, enables debug mode. Defaults to False.
+        """
         super(VAEModel, self).__init__()
-
-        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
-        self.binary_accuracy_tracker = tf.keras.metrics.Mean(
-            name="binary_accuracy"
-        )
-
-        self.kl_beta = K.variable(0.0)
-        self.kl_beta._trainable = False
-
-        self._sample_weight = sample_weight
-        self._missing_mask = missing_mask
-        self._batch_idx = 0
-        self._batch_size = batch_size
-        self._y = y
-
-        self._final_activation = final_activation
-        if num_classes == 10 or num_classes == 3:
-            self.acc_func = tf.keras.metrics.categorical_accuracy
-        elif num_classes == 4:
-            self.acc_func = tf.keras.metrics.binary_accuracy
-
-        self.nn_ = NeuralNetworkMethods()
-
-        # y_train[1] dimension.
-        self.n_features = output_shape
-
-        self.n_components = n_components
-        self.weights_initializer = weights_initializer
-        self.hidden_layer_sizes = hidden_layer_sizes
-        self.num_hidden_layers = num_hidden_layers
-        self.hidden_activation = hidden_activation
-        self.l1_penalty = l1_penalty
-        self.l2_penalty = l2_penalty
-        self.dropout_rate = dropout_rate
         self.num_classes = num_classes
+        self.gamma = gamma
+        self.beta = beta
+        self.device = device

-        hidden_layer_sizes = nn.validate_hidden_layers(
-            self.hidden_layer_sizes, self.num_hidden_layers
-        )
-
-        hidden_layer_sizes = nn.get_hidden_layer_sizes(
-            self.n_features, self.n_components, hidden_layer_sizes, vae=True
+        logman = LoggerManager(
+            name=__name__, prefix=prefix, verbose=verbose, debug=debug
         )
+        self.logger = logman.get_logger()

-        if self.l1_penalty == 0.0 and self.l2_penalty == 0.0:
-            kernel_regularizer = None
-        else:
-            kernel_regularizer = l1_l2(self.l1_penalty, self.l2_penalty)
-
-        kernel_initializer = self.weights_initializer
-
-        if self.hidden_activation.lower() == "leaky_relu":
-            activation = LeakyReLU(alpha=0.01)
-
-        elif self.hidden_activation.lower() == "prelu":
-            activation = PReLU()
-
-        elif self.hidden_activation.lower() == "selu":
-            activation = "selu"
-            kernel_initializer = "lecun_normal"
-
-        else:
-            activation = self.hidden_activation
-
-        if num_hidden_layers > 5:
-            raise ValueError(
-                f"The maximum number of hidden layers is 5, but got "
-                f"{num_hidden_layers}"
-            )
+        activation = self._resolve_activation(activation)

         self.encoder = Encoder(
+            n_features,
             self.num_classes,
+            latent_dim,
             hidden_layer_sizes,
+            dropout_rate,
             activation,
-            kernel_initializer,
-            kernel_regularizer,
-            beta=self.kl_beta,
         )

+        decoder_layer_sizes = list(reversed(hidden_layer_sizes))

         self.decoder = Decoder(
+            n_features,
             self.num_classes,
+            latent_dim,
+            decoder_layer_sizes,
+            dropout_rate,
             activation,
-            kernel_initializer,
-            kernel_regularizer,
         )

+    def forward(
+        self, x: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Performs the forward pass through the full VAE model.

-        z_mean, z_log_var, z = self.encoder(inputs)
-        reconstruction = self.decoder(z)
-        return self.activation(reconstruction)
+        Args:
+            x (torch.Tensor): The input data tensor of shape `(batch_size, n_features, num_classes)`.

-        :noindex:
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the reconstructed output, the latent mean (`z_mean`), and the latent log-variance (`z_log_var`).
         """
+        z_mean, z_log_var, z = self.encoder(x)
+        reconstruction = self.decoder(z)
+        return reconstruction, z_mean, z_log_var

+    def compute_loss(
+        self,
+        outputs: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+        y: torch.Tensor,
+        mask: torch.Tensor | None = None,
+        class_weights: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """Computes the VAE loss function (negative ELBO).
+
+        The loss is the sum of a reconstruction term and a regularizing KL divergence term. The reconstruction loss is calculated using a masked focal loss, and the KL divergence measures the difference between the learned latent distribution and a standard normal prior.
+
+        Args:
+            outputs (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): The tuple of (reconstruction, z_mean, z_log_var) from the model's forward pass.
+            y (torch.Tensor): The target data tensor, expected to be one-hot encoded. This is converted to class indices internally for the loss function.
+            mask (torch.Tensor | None): A boolean mask to exclude missing values from the reconstruction loss.
+            class_weights (torch.Tensor | None): Weights to apply to each class in the reconstruction loss to handle imbalance.
+
+        Returns:
+            torch.Tensor: The computed scalar loss value.
         """
-        model = tf.keras.Model(inputs=[x], outputs=self.call(x))
-        self.outputs = model.outputs
-
-    @property
-    def metrics(self):
-        return [
-            self.total_loss_tracker,
-            self.binary_accuracy_tracker,
-        ]
-
-    @tf.function
-    def train_step(self, data):
-        y = self._y
-
-        (
-            y_true,
-            sample_weight,
-            missing_mask,
-        ) = self.nn_.prepare_training_batches(
-            y,
-            y,
-            self._batch_size,
-            self._batch_idx,
-            True,
-            self.n_components,
-            self._sample_weight,
-            self._missing_mask,
-            ubp=False,
-        )
-
-        if sample_weight is not None:
-            sample_weight_masked = tf.convert_to_tensor(
-                sample_weight[~missing_mask], dtype=tf.float32
-            )
-        else:
-            sample_weight_masked = None
+        reconstruction, z_mean, z_log_var = outputs

+        # 1. KL Divergence Calculation
+        prior = Normal(torch.zeros_like(z_mean), torch.ones_like(z_log_var))
+        posterior = Normal(z_mean, torch.exp(0.5 * z_log_var))
+        kl_loss = (
+            torch.distributions.kl.kl_divergence(posterior, prior).sum(dim=1).mean()
         )

+        if class_weights is None:
+            class_weights = torch.ones(self.num_classes, device=y.device)

+        # 2. Reconstruction Loss Calculation
+        # Reverting to the robust method of flattening tensors and using the
+        # custom loss function.
+        n_classes = reconstruction.shape[-1]
+        logits_flat = reconstruction.reshape(-1, n_classes)

-            y_true_masked,
-            y_pred_masked,
-            sample_weight=sample_weight_masked,
-            regularization_losses=self.losses,
-        )
+        # Convert one-hot `y` to class indices for the loss function.
+        targets_flat = torch.argmax(y, dim=-1).reshape(-1)

-        total_loss = reconstruction_loss + regularization_loss
-
-        grads = tape.gradient(total_loss, self.trainable_variables)
-        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
-
-        self.total_loss_tracker.update_state(total_loss)
-        self.binary_accuracy_tracker.update_state(
-            tf.keras.metrics.binary_accuracy(y_true_masked, y_pred_masked)
-        )
-
-        return {
-            "loss": self.total_loss_tracker.result(),
-            "binary_accuracy": self.binary_accuracy_tracker.result(),
-        }
-
-    @tf.function
-    def test_step(self, data):
-        y = self._y
-
-        (
-            y_true,
-            sample_weight,
-            missing_mask,
-        ) = self.nn_.prepare_training_batches(
-            y,
-            y,
-            self._batch_size,
-            self._batch_idx,
-            True,
-            self.n_components,
-            self._sample_weight,
-            self._missing_mask,
-            ubp=False,
-        )
-
-        if sample_weight is not None:
-            sample_weight_masked = tf.convert_to_tensor(
-                sample_weight[~missing_mask], dtype=tf.float32
-            )
+        if mask is None:
+            # If no mask is provided, all targets are considered valid.
+            mask_flat = torch.ones_like(targets_flat, dtype=torch.bool)
         else:
-
-            y_true_masked = tf.boolean_mask(
-                tf.convert_to_tensor(y_true, dtype=tf.float32),
-                tf.reduce_any(tf.not_equal(y_true, -1), axis=-1),
-            )
+            # The mask needs to be reshaped to match the flattened targets.
+            mask_flat = mask.reshape(-1)

+        # Logits, class-index targets, and the valid mask.
+        criterion = MaskedFocalLoss(alpha=class_weights, gamma=self.gamma)

+        reconstruction_loss = criterion(
+            logits_flat.to(self.device),
+            targets_flat.to(self.device),
+            valid_mask=mask_flat.to(self.device),
         )

-            y_true_masked,
-            y_pred_masked,
-            sample_weight=sample_weight_masked,
-            regularization_losses=self.losses,
-        )
+        return reconstruction_loss + self.beta * kl_loss

+    def _resolve_activation(
+        self, activation: Literal["relu", "elu", "leaky_relu", "selu"]
+    ) -> torch.nn.Module:
+        """Resolves an activation function module from a string name.

-        self.total_loss_tracker.update_state(total_loss)
-        self.binary_accuracy_tracker.update_state(
-            tf.keras.metrics.binary_accuracy(y_true_masked, y_pred_masked)
-        )
+        Args:
+            activation (Literal["relu", "elu", "leaky_relu", "selu"]): The name of the activation function.

-    @property
-    def missing_mask(self):
-        return self._missing_mask
-
-    @property
-    def sample_weight(self):
-        return self._sample_weight
-
-    @batch_size.setter
-    def batch_size(self, value):
-        """Set batch_size parameter."""
-        self._batch_size = int(value)
-
-    @batch_idx.setter
-    def batch_idx(self, value):
-        """Set current batch (=step) index."""
-        self._batch_idx = int(value)
-
-    @y.setter
-    def y(self, value):
-        """Set y after each epoch."""
-        self._y = value
-
-    @missing_mask.setter
-    def missing_mask(self, value):
-        """Set y after each epoch."""
-        self._missing_mask = value
-
-    @sample_weight.setter
-    def sample_weight(self, value):
-        self._sample_weight = value
+        Returns:
+            torch.nn.Module: The corresponding instantiated PyTorch activation function module.
+
+        Raises:
+            ValueError: If the provided activation name is not supported.
+        """
+        if isinstance(activation, str):
+            activation = activation.lower()
+            if activation == "relu":
+                return nn.ReLU()
+            elif activation == "elu":
+                return nn.ELU()
+            elif activation in ["leaky_relu", "leakyrelu"]:
+                return nn.LeakyReLU()
+            elif activation == "selu":
+                return nn.SELU()
+            else:
+                msg = f"Activation {activation} not supported."
+                self.logger.error(msg)
+                raise ValueError(msg)
+
+        return activation