pg-sui 0.2.3__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +99 -77
- pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
- {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
- pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
- {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
- pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +909 -0
- pgsui/data_processing/__init__.py +0 -0
- pgsui/data_processing/config.py +565 -0
- pgsui/data_processing/containers.py +1424 -0
- pgsui/data_processing/transformers.py +557 -907
- pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/extra-resources/.gitkeep +1 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +227 -0
- pgsui/electron/app/package-lock.json +6894 -0
- pgsui/electron/app/package.json +51 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +157 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +131 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +57 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/__init__.py +0 -0
- pgsui/example_data/phylip_files/__init__.py +0 -0
- pgsui/example_data/phylip_files/test.phy +0 -0
- pgsui/example_data/popmaps/__init__.py +0 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/structure_files/__init__.py +0 -0
- pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/__init__.py +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
- pgsui/impute/deterministic/imputers/mode.py +844 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +973 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
- pgsui/impute/supervised/__init__.py +0 -0
- pgsui/impute/supervised/base.py +343 -0
- pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
- pgsui/impute/supervised/imputers/random_forest.py +291 -0
- pgsui/impute/unsupervised/__init__.py +0 -0
- pgsui/impute/unsupervised/base.py +1118 -0
- pgsui/impute/unsupervised/callbacks.py +92 -262
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
- pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
- pgsui/impute/unsupervised/imputers/vae.py +1228 -0
- pgsui/impute/unsupervised/loss_functions.py +261 -0
- pgsui/impute/unsupervised/models/__init__.py +0 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
- pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
- pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
- pgsui/impute/unsupervised/models/vae_model.py +269 -630
- pgsui/impute/unsupervised/nn_scorers.py +255 -0
- pgsui/utils/__init__.py +0 -0
- pgsui/utils/classification_viz.py +608 -0
- pgsui/utils/logging_utils.py +22 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +996 -829
- pgsui/utils/pretty_metrics.py +290 -0
- pgsui/utils/scorers.py +213 -666
- pg_sui-0.2.3.dist-info/RECORD +0 -75
- pg_sui-0.2.3.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -1268
- pgsui/impute/impute.py +0 -1463
- pgsui/impute/simple_imputers.py +0 -1431
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
- pgsui/impute/unsupervised/keras_classifiers.py +0 -697
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -151
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -185

@@ -1,710 +1,349 @@
-import
-import os
-import sys
-import warnings
+from typing import List, Literal, Tuple
 
-
-
-
-
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from snpio.utils.logging import LoggerManager
+from torch.distributions import Normal
 
-
+from pgsui.impute.unsupervised.loss_functions import MaskedFocalLoss
+from pgsui.utils.logging_utils import configure_logger
 
-# Disable can't find cuda .dll errors. Also turns of GPU support.
-tf.config.set_visible_devices([], "GPU")
 
-
+class Sampling(nn.Module):
+    """A layer that samples from a latent distribution using the reparameterization trick.
 
-
-
-tf.get_logger().setLevel(logging.ERROR)
-
-
-# Monkey patching deprecation utils to supress warnings.
-# noinspection PyUnusedLocal
-def deprecated(
-    date, instructions, warn_once=True
-):  # pylint: disable=unused-argument
-    def deprecated_wrapper(func):
-        return func
-
-    return deprecated_wrapper
-
-
-deprecation.deprecated = deprecated
-
-from tensorflow.keras.layers import (
-    Dropout,
-    Dense,
-    Reshape,
-    Activation,
-    LeakyReLU,
-    PReLU,
-)
-
-from tensorflow.keras.regularizers import l1_l2
-from tensorflow.keras import backend as K
-
-# Custom Modules
-try:
-    from ..neural_network_methods import NeuralNetworkMethods
-except (ModuleNotFoundError, ValueError, ImportError):
-    from impute.unsupervised.neural_network_methods import NeuralNetworkMethods
+    This layer is a core component of a Variational Autoencoder (VAE). It takes the mean and log-variance of a latent distribution as input and generates a sample from that distribution. By using the reparameterization trick ($z = \mu + \sigma \cdot \epsilon$), it allows gradients to be backpropagated through the random sampling process, making the VAE trainable.
+    """
 
+    def forward(self, z_mean: torch.Tensor, z_log_var: torch.Tensor) -> torch.Tensor:
+        """Performs the forward pass to generate a latent sample.
 
-
-
+        Args:
+            z_mean (torch.Tensor): The mean of the latent normal distribution.
+            z_log_var (torch.Tensor): The log of the variance of the latent normal distribution.
 
-
-
-
+        Returns:
+            torch.Tensor: A sampled vector from the latent space.
+        """
+        z_sigma = torch.exp(0.5 * z_log_var)  # Precompute outside
 
-
-
-
-
-        batch = tf.shape(z_mean)[0]
-        dim = tf.shape(z_mean)[1]
-        epsilon = tf.random.normal(shape=(batch, dim))
+        # Ensure on GPU
+        # rand_like takes random values from a normal distribution
+        # of the same shape as z_mean.
+        epsilon = torch.randn_like(z_mean, device=z_mean.device)
         return z_mean + z_sigma * epsilon
 
 
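The new `Sampling` layer replaces the old TensorFlow reparameterization code. A minimal standalone sketch of the same trick in plain PyTorch; the `mu`/`log_var` tensors below are illustrative and not part of the package API:

```python
import torch

# Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I).
# Drawing eps outside the graph keeps z differentiable w.r.t. mu and log_var.
mu = torch.zeros(8, 2, requires_grad=True)       # latent means (batch=8, latent_dim=2)
log_var = torch.zeros(8, 2, requires_grad=True)  # latent log-variances

sigma = torch.exp(0.5 * log_var)                 # std dev from log-variance
eps = torch.randn_like(mu)                       # standard-normal noise
z = mu + sigma * eps                             # differentiable latent sample

z.sum().backward()                               # gradients reach mu and log_var
print(mu.grad.shape, log_var.grad.shape)         # torch.Size([8, 2]) torch.Size([8, 2])
```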
-class Encoder(
-    """
-
-    Args:
-        n_features (int): Number of featuresi in input dataset.
-
-        num_classes (int): Number of classes in target data.
-
-        latent_dim (int): Number of latent dimensions to use.
-
-        hidden_layer_sizes (list of int): List of hidden layer sizes to use.
-
-        dropout_rate (float): Dropout rate for Dropout layer.
-
-        activation (str): Hidden activation function to use.
-
-        kernel_initializer (str): Initializer to use for weights.
-
-        kernel_regularizer (tf.keras.regularizers): L1 and/ or L2 objects.
-
-        beta (float, optional): KL divergence beta to use. Defualts to 1.0.
-
-        name (str): Name of model. Defaults to Encoder.
+class Encoder(nn.Module):
+    """The Encoder module of a Variational Autoencoder (VAE).
 
+    This module defines the encoder network, which takes high-dimensional input data and maps it to the parameters of a lower-dimensional latent distribution. The architecture consists of a series of fully-connected hidden layers that process the flattened input. The network culminates in two separate linear layers that output the mean (`z_mean`) and log-variance (`z_log_var`) of the approximate posterior distribution, $q(z|x)$.
     """
 
     def __init__(
         self,
-        n_features,
-        num_classes,
-        latent_dim,
-        hidden_layer_sizes,
-        dropout_rate,
-        activation,
-        kernel_initializer,
-        kernel_regularizer,
-        beta=1.0,
-        name="Encoder",
-        **kwargs,
+        n_features: int,
+        num_classes: int,
+        latent_dim: int,
+        hidden_layer_sizes: List[int],
+        dropout_rate: float,
+        activation: torch.nn.Module,
     ):
-
+        """Initializes the Encoder module.
+
+        Args:
+            n_features (int): The number of features in the input data (e.g., SNPs).
+            num_classes (int): Number of genotype states per locus (2 for haploid, 3 for diploid in practice).
+            latent_dim (int): The dimensionality of the latent space.
+            hidden_layer_sizes (List[int]): A list of integers specifying the size of each hidden layer.
+            dropout_rate (float): The dropout rate for regularization in the hidden layers.
+            activation (torch.nn.Module): An instantiated activation function module (e.g., `nn.ReLU()`) for the hidden layers.
+        """
+        super(Encoder, self).__init__()
+        self.flatten = nn.Flatten()
+        self.activation = (
+            getattr(F, activation) if isinstance(activation, str) else activation
+        )
 
-
+        layers = []
+        # The input dimension accounts for channels
+        input_dim = n_features * num_classes
+        for hidden_size in hidden_layer_sizes:
+            layers.append(nn.Linear(input_dim, hidden_size))
 
-
-
-        self.dense4 = None
-        self.dense5 = None
+            # BatchNorm can lead to faster convergence.
+            layers.append(nn.BatchNorm1d(hidden_size))
 
-
-
+            layers.append(nn.Dropout(dropout_rate))
+            layers.append(activation)
+            input_dim = hidden_size
 
-        self.
-
-
-
-            kernel_initializer=kernel_initializer,
-            kernel_regularizer=kernel_regularizer,
-            name="Encoder1",
-        )
+        self.hidden_layers = nn.Sequential(*layers)
+        self.dense_z_mean = nn.Linear(input_dim, latent_dim)
+        self.dense_z_log_var = nn.Linear(input_dim, latent_dim)
+        self.sampling = Sampling()
 
-
-
-
-
-            kernel_initializer=kernel_initializer,
-            kernel_regularizer=kernel_regularizer,
-            name="Encoder2",
-        )
-
-        if len(hidden_layer_sizes) >= 3:
-            self.dense3 = Dense(
-                hidden_layer_sizes[2],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-                name="Encoder3",
-            )
-
-        if len(hidden_layer_sizes) >= 4:
-            self.dense4 = Dense(
-                hidden_layer_sizes[3],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-                name="Encoder4",
-            )
-
-        if len(hidden_layer_sizes) == 5:
-            self.dense5 = Dense(
-                hidden_layer_sizes[4],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-                name="Encoder5",
-            )
-
-        self.dense_z_mean = Dense(
-            latent_dim,
-            name="z_mean",
-        )
-        self.dense_z_log_var = Dense(
-            latent_dim,
-            name="z_log_var",
-        )
-        # z_mean and z_log_var are inputs.
-        self.sampling = Sampling(
-            name="z",
-        )
+    def forward(
+        self, x: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Performs the forward pass through the encoder.
 
-
-
-            activation=activation,
-            kernel_initializer=kernel_initializer,
-            kernel_regularizer=kernel_regularizer,
-            name="Encoder5",
-        )
+        Args:
+            x (torch.Tensor): The input data tensor of shape `(batch_size, n_features, num_classes)`.
 
-
-
-
-
-        x = self.
-        x = self.dense1(x)
-        x = self.dropout_layer(x, training=training)
-        if self.dense2 is not None:
-            x = self.dense2(x)
-            x = self.dropout_layer(x, training=training)
-        if self.dense3 is not None:
-            x = self.dense3(x)
-            x = self.dropout_layer(x, training=training)
-        if self.dense4 is not None:
-            x = self.dense4(x)
-            x = self.dropout_layer(x, training=training)
-        if self.dense5 is not None:
-            x = self.dense5(x)
-            x = self.dropout_layer(x, training=training)
-
-        x = self.dense_latent(x)
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the latent mean (`z_mean`), latent log-variance (`z_log_var`), and a sample from the latent distribution (`z`).
+        """
+        x = self.flatten(x)
+        x = self.hidden_layers(x)
         z_mean = self.dense_z_mean(x)
         z_log_var = self.dense_z_log_var(x)
-
-        # Compute the KL divergence
-        kl_loss = -0.5 * tf.reduce_sum(
-            1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=-1
-        )
-        # Add the KL divergence to the model's total loss
-        self.add_loss(self.beta * tf.reduce_mean(kl_loss))
-
-        z = self.sampling([z_mean, z_log_var])
-
+        z = self.sampling(z_mean, z_log_var)
         return z_mean, z_log_var, z
 
 
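A quick shape check of the rewritten `Encoder` on dummy one-hot genotype data. This is a sketch only, assuming the hunk above is the rewritten `pgsui/impute/unsupervised/models/vae_model.py` listed in the file table; the sizes used are illustrative, not package defaults:

```python
import torch
import torch.nn as nn

# Assumed import path for the module shown in this diff.
from pgsui.impute.unsupervised.models.vae_model import Encoder

enc = Encoder(
    n_features=100,            # number of loci
    num_classes=3,             # diploid genotype states per locus
    latent_dim=2,
    hidden_layer_sizes=[128, 64],
    dropout_rate=0.2,
    activation=nn.ReLU(),      # instantiated module, per the new signature
)

x = torch.randn(16, 100, 3)    # (batch_size, n_features, num_classes)
z_mean, z_log_var, z = enc(x)
print(z_mean.shape, z_log_var.shape, z.shape)  # each torch.Size([16, 2])
```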
-class Decoder(
-    """
-
-    Args:
-        n_features (int): Number of features in input dataset.
+class Decoder(nn.Module):
+    """The Decoder module of a Variational Autoencoder (VAE).
 
-
-
-        latent_dim (int): Number of latent dimensions to use.
-
-        hidden_layer_sizes (list of int): List of hidden layer sizes to use.
-
-        dropout_rate (float): Dropout rate for Dropout layer.
-
-        activation (str): Hidden activation function to use.
-
-        kernel initializer (str): Function for initilizing weights.
-
-        kernel_regularizer (tf.keras.regularizer): Initialized L1 and/ or L2 regularizer.
-
-        name (str): Name of model. Defaults to "Decoder".
+    This module defines the decoder network, which takes a sample from the low-dimensional latent space and maps it back to the high-dimensional data space. It aims to reconstruct the original input data. The architecture consists of a series of fully-connected hidden layers followed by a final linear layer that produces the reconstructed data, which is then reshaped to match the original input's dimensions.
     """
 
     def __init__(
         self,
-        n_features,
-        num_classes,
-        latent_dim,
-        hidden_layer_sizes,
-        dropout_rate,
-        activation,
-
-
-
-
-
-
-
-
-
-
-
-
-        self.dense1 = Dense(
-            hidden_layer_sizes[0],
-            input_shape=(latent_dim,),
-            activation=activation,
-            kernel_initializer=kernel_initializer,
-            kernel_regularizer=kernel_regularizer,
-            name="Decoder1",
-        )
-
-        if len(hidden_layer_sizes) >= 2:
-            self.dense2 = Dense(
-                hidden_layer_sizes[1],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-                name="Decoder2",
-            )
-
-        if len(hidden_layer_sizes) >= 3:
-            self.dense3 = Dense(
-                hidden_layer_sizes[2],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-                name="Decoder3",
-            )
-
-        if len(hidden_layer_sizes) >= 4:
-            self.dense4 = Dense(
-                hidden_layer_sizes[3],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-                name="Decoder4",
-            )
-
-        if len(hidden_layer_sizes) == 5:
-            self.dense5 = Dense(
-                hidden_layer_sizes[4],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-                name="Decoder5",
-            )
-
-        # No activation for final layer.
-        self.dense_output = Dense(
-            n_features * num_classes,
-            kernel_initializer=kernel_initializer,
-            kernel_regularizer=kernel_regularizer,
-            name="DecoderExpanded",
-        )
-
-        self.rshp = Reshape((n_features, num_classes))
-
-        self.dropout_layer = Dropout(dropout_rate)
-
-    def call(self, inputs, training=None):
-        """Forward pass for model."""
-        x = self.dense1(inputs)
-        x = self.dropout_layer(x, training=training)
-        if self.dense2 is not None:
-            x = self.dense2(x)
-            x = self.dropout_layer(x, training=training)
-        if self.dense3 is not None:
-            x = self.dense3(x)
-            x = self.dropout_layer(x, training=training)
-        if self.dense4 is not None:
-            x = self.dense4(x)
-            x = self.dropout_layer(x, training=training)
-        if self.dense5 is not None:
-            x = self.dense5(x)
-            x = self.dropout_layer(x, training=training)
-
-        x = self.dense_output(x)
-        return self.rshp(x)
-
-
-class VAEModel(tf.keras.Model):
-    """Variational Autoencoder model. Runs the encdoer and decoder and outputs the reconsruction.
-
-    Args:
-        output_shape (tuple): Shape of output. Defaults to None.
-
-        n_components (int): Number of latent dimensions to use. Defaults to 3.
-
-        weights_initializer (str, optional): kernel initializer to use. Defaults to "glorot_normal".
-
-        hidden_layer_sizes (str or list, optional): List of hidden layer sizes to use, or can use "midpoint", "log2", or "sqrt". Defaults to "midpoint".
-
-        num_hidden_layers (int, optional): Number of hidden layers to use. Defaults to 1.
-
-        hidden_activation (str, optional): Hidden activation function to use. Defaults to "elu".
+        n_features: int,
+        num_classes: int,
+        latent_dim: int,
+        hidden_layer_sizes: List[int],
+        dropout_rate: float,
+        activation: torch.nn.Module,
+    ) -> None:
+        """Initializes the Decoder module.
+
+        Args:
+            n_features (int): The number of features in the output data (e.g., SNPs).
+            num_classes (int): Number of genotype states per locus (typically 2 or 3).
+            latent_dim (int): The dimensionality of the input latent space.
+            hidden_layer_sizes (List[int]): A list of integers specifying the size of each hidden layer.
+            dropout_rate (float): The dropout rate for regularization in the hidden layers.
+            activation (torch.nn.Module): An instantiated activation function module (e.g., `nn.ReLU()`) for the hidden layers.
+        """
+        super(Decoder, self).__init__()
 
-
+        layers = []
+        input_dim = latent_dim
+        for hidden_size in hidden_layer_sizes:
+            layers.append(nn.Linear(input_dim, hidden_size))
 
-
+            # BatchNorm can lead to faster convergence.
+            layers.append(nn.BatchNorm1d(hidden_size))
 
-
+            layers.append(nn.Dropout(dropout_rate))
+            layers.append(activation)
+            input_dim = hidden_size
 
-
+        self.hidden_layers = nn.Sequential(*layers)
+        # UPDATED: Output dimension must account for channels
+        output_dim = n_features * num_classes
+        self.dense_output = nn.Linear(input_dim, output_dim)
+        # UPDATED: Reshape must account for channels
+        self.reshape = (n_features, num_classes)
 
-
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Performs the forward pass through the decoder.
 
-
+        Args:
+            x (torch.Tensor): The input latent tensor of shape `(batch_size, latent_dim)`.
 
-
+        Returns:
+            torch.Tensor: The reconstructed output data of shape `(batch_size, n_features, num_classes)`.
+        """
+        x = self.hidden_layers(x)
+        x = self.dense_output(x)
+        return x.view(-1, *self.reshape)
 
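The matching `Decoder` round trip, mapping latent samples back to per-locus class logits, under the same assumptions as the encoder sketch above:

```python
import torch
import torch.nn as nn

from pgsui.impute.unsupervised.models.vae_model import Decoder  # assumed path

dec = Decoder(
    n_features=100,
    num_classes=3,
    latent_dim=2,
    hidden_layer_sizes=[64, 128],  # reversed encoder sizes, as VAEModel does below
    dropout_rate=0.2,
    activation=nn.ReLU(),
)

z = torch.randn(16, 2)             # latent samples
logits = dec(z)                    # flat output reshaped to per-locus logits
print(logits.shape)                # torch.Size([16, 100, 3])
```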
-        batch_size (int, optional): Batch size to use for training. Defaults to 32.
 
-
+class VAEModel(nn.Module):
+    """A Variational Autoencoder (VAE) model for imputation.
 
-
+    This class combines an `Encoder` and a `Decoder` to form a VAE, a generative model for learning complex data distributions. It is designed for imputing missing values in categorical data, such as genomic SNPs. The model is trained by maximizing the Evidence Lower Bound (ELBO), which is a lower bound on the log-likelihood of the data.
 
+    **Objective Function (ELBO):**
+    The VAE loss function is derived from the ELBO and consists of two main components: a reconstruction term and a regularization term.
+    $$
+    \\mathcal{L}(\\theta, \\phi; x) = \\underbrace{\\mathbb{E}_{q_{\\phi}(z|x)}[\\log p_{\\theta}(x|z)]}_{\\text{Reconstruction Loss}} - \\underbrace{D_{KL}(q_{\\phi}(z|x) || p(z))}_{\\text{KL Divergence}}
+    $$
+    - The **Reconstruction Loss** encourages the decoder to accurately reconstruct the input data from its latent representation. This implementation uses a `MaskedFocalLoss`.
+    - The **KL Divergence** acts as a regularizer, forcing the approximate posterior distribution $q_{\\phi}(z|x)$ learned by the encoder to be close to a prior distribution $p(z)$ (typically a standard normal distribution).
     """
 
     def __init__(
         self,
-
-
-
-
-
-
-
-
-
-
-
-
-
-        batch_size=32,
-        final_activation=None,
-        y=None,
+        n_features: int,
+        prefix: str,
+        *,
+        num_classes: int = 4,
+        hidden_layer_sizes: List[int] | np.ndarray = [128, 64],
+        latent_dim: int = 2,
+        dropout_rate: float = 0.2,
+        activation: Literal["relu", "elu", "selu", "leaky_relu"] = "relu",
+        gamma: float = 2.0,
+        beta: float = 1.0,
+        device: Literal["cpu", "gpu", "mps"] = "cpu",
+        verbose: bool = False,
+        debug: bool = False,
     ):
+        """Initializes the VAEModel.
+
+        Args:
+            n_features (int): The number of features in the input data (e.g., SNPs).
+            prefix (str): A prefix used for logging.
+            num_classes (int): Number of genotype states per locus. Defaults to 4 for backward compatibility, though the imputer passes 2 (haploid) or 3 (diploid).
+            hidden_layer_sizes (List[int] | np.ndarray): A list of integers specifying the size of each hidden layer in the encoder and decoder. Defaults to [128, 64].
+            latent_dim (int): The dimensionality of the latent space. Defaults to 2.
+            dropout_rate (float): The dropout rate for regularization in the hidden layers. Defaults to 0.2.
+            activation (str): The name of the activation function to use in hidden layers. Defaults to "relu".
+            gamma (float): The focusing parameter for the focal loss component. Defaults to 2.0.
+            beta (float): A weighting factor for the KL divergence term in the total loss ($\beta$-VAE). Defaults to 1.0.
+            device (Literal["cpu", "gpu", "mps"]): The device to run the model on.
+            verbose (bool): If True, enables detailed logging. Defaults to False.
+            debug (bool): If True, enables debug mode. Defaults to False.
+        """
         super(VAEModel, self).__init__()
-
-        self.kl_beta = K.variable(0.0)
-        self.kl_beta._trainable = False
-
-        self._sample_weight = sample_weight
-        self._missing_mask = missing_mask
-        self._batch_idx = 0
-        self._batch_size = batch_size
-        self._y = y
-
-        self._final_activation = final_activation
-        if num_classes == 10 or num_classes == 3:
-            self.acc_func = tf.keras.metrics.categorical_accuracy
-        elif num_classes == 4:
-            self.acc_func = tf.keras.metrics.binary_accuracy
-
-        self.nn_ = NeuralNetworkMethods()
-
-        self.total_loss_tracker = tf.keras.metrics.Mean(name="loss")
-        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
-            name="reconstruction_loss"
-        )
-        # self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")
-        self.accuracy_tracker = tf.keras.metrics.Mean(name="accuracy")
-
-        # y_train[1] dimension.
-        self.n_features = output_shape
-
-        self.n_components = n_components
-        self.weights_initializer = weights_initializer
-        self.hidden_layer_sizes = hidden_layer_sizes
-        self.num_hidden_layers = num_hidden_layers
-        self.hidden_activation = hidden_activation
-        self.l1_penalty = l1_penalty
-        self.l2_penalty = l2_penalty
-        self.dropout_rate = dropout_rate
         self.num_classes = num_classes
+        self.gamma = gamma
+        self.beta = beta
+        self.device = device
 
-
-
-        hidden_layer_sizes = nn.validate_hidden_layers(
-            self.hidden_layer_sizes, self.num_hidden_layers
+        logman = LoggerManager(
+            name=__name__, prefix=prefix, verbose=verbose, debug=debug
         )
-
-
-            self.n_features, self.n_components, hidden_layer_sizes, vae=True
+        self.logger = configure_logger(
+            logman.get_logger(), verbose=verbose, debug=debug
         )
 
-
+        act = self._resolve_activation(activation)
 
-        if
-
+        if isinstance(hidden_layer_sizes, np.ndarray):
+            hls = hidden_layer_sizes.tolist()
         else:
-
-
-            kernel_initializer = self.weights_initializer
-
-        if self.hidden_activation.lower() == "leaky_relu":
-            activation = LeakyReLU(alpha=0.01)
-
-        elif self.hidden_activation.lower() == "prelu":
-            activation = PReLU()
-
-        elif self.hidden_activation.lower() == "selu":
-            activation = "selu"
-            kernel_initializer = "lecun_normal"
-
-        else:
-            activation = self.hidden_activation
-
-        if num_hidden_layers > 5:
-            raise ValueError(
-                f"The maximum number of hidden layers is 5, but got "
-                f"{num_hidden_layers}"
-            )
+            hls = hidden_layer_sizes
 
         self.encoder = Encoder(
-            self.
-            self.num_classes,
-            self.n_components,
-            hidden_layer_sizes,
-            self.dropout_rate,
-            activation,
-            kernel_initializer,
-            kernel_regularizer,
-            beta=self.kl_beta,
+            n_features, self.num_classes, latent_dim, hls, dropout_rate, act
         )
 
-
+        decoder_layer_sizes = list(reversed(hls))
 
         self.decoder = Decoder(
-
+            n_features,
             self.num_classes,
-
-
-
-
-            kernel_initializer,
-            kernel_regularizer,
+            latent_dim,
+            decoder_layer_sizes,
+            dropout_rate,
+            act,
         )
 
-
-
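Construction of the rewritten `VAEModel` now goes through keyword-only hyperparameters instead of the old Keras-style arguments. A hedged usage sketch; values are illustrative, and the logger setup relies on snpio's `LoggerManager` as imported at the top of the module:

```python
from pgsui.impute.unsupervised.models.vae_model import VAEModel  # assumed path

model = VAEModel(
    n_features=100,            # number of loci
    prefix="pgsui",            # prefix used for logging
    num_classes=3,             # diploid genotype states
    hidden_layer_sizes=[128, 64],
    latent_dim=2,
    dropout_rate=0.2,
    activation="relu",         # resolved to an nn.ReLU() module internally
    gamma=2.0,                 # focal-loss focusing parameter
    beta=1.0,                  # weight on the KL term (beta-VAE)
    device="cpu",
)
```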
+    def forward(
+        self, x: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Performs the forward pass through the full VAE model.
 
-
-
-        z_mean, z_log_var, z = self.encoder(inputs)
-        reconstruction = self.decoder(z)
-        if self._final_activation is not None:
-            reconstruction = self.act(reconstruction)
-        return reconstruction
+        Args:
+            x (torch.Tensor): The input data tensor of shape `(batch_size, n_features, num_classes)`.
 
-
-
-        :noindex:
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the reconstructed output, the latent mean (`z_mean`), and the latent log-variance (`z_log_var`).
         """
-
-
+        z_mean, z_log_var, z = self.encoder(x)
+        reconstruction = self.decoder(z)
+        return reconstruction, z_mean, z_log_var
 
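As the forward docstring notes, the model returns per-locus class logits plus the latent parameters; imputed genotype calls are then simply the arg-max over the class dimension. A sketch, reusing the `model` built above:

```python
import torch

model.eval()
with torch.no_grad():
    X = torch.randn(16, 100, 3)                     # (batch, loci, classes) input
    reconstruction, z_mean, z_log_var = model(X)
    genotype_calls = reconstruction.argmax(dim=-1)  # most probable class per locus

print(genotype_calls.shape)                         # torch.Size([16, 100])
```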
-    def
-
-    :
+    def compute_loss(
+        self,
+        outputs: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+        y: torch.Tensor,
+        mask: torch.Tensor | None = None,
+        class_weights: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """Computes the VAE loss function (negative ELBO).
+
+        The loss is the sum of a reconstruction term and a regularizing KL divergence term. The reconstruction loss is calculated using a masked focal loss, and the KL divergence measures the difference between the learned latent distribution and a standard normal prior.
+
+        Args:
+            outputs (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): The tuple of (reconstruction, z_mean, z_log_var) from the model's forward pass.
+            y (torch.Tensor): The target data tensor, expected to be one-hot encoded. This is converted to class indices internally for the loss function.
+            mask (torch.Tensor | None): A boolean mask to exclude missing values from the reconstruction loss.
+            class_weights (torch.Tensor | None): Weights to apply to each class in the reconstruction loss to handle imbalance.
+
+        Returns:
+            torch.Tensor: The computed scalar loss value.
         """
-
-
-
-
-
-
-
-        return [
-            self.total_loss_tracker,
-            self.reconstruction_loss_tracker,
-            # self.kl_loss_tracker,
-            self.accuracy_tracker,
-        ]
-
-    @tf.function
-    def train_step(self, data):
-        y = self._y
-
-        (
-            y_true,
-            sample_weight,
-            missing_mask,
-        ) = self.nn_.prepare_training_batches(
-            y,
-            y,
-            self._batch_size,
-            self._batch_idx,
-            True,
-            self.n_components,
-            self._sample_weight,
-            self._missing_mask,
-            ubp=False,
+        reconstruction, z_mean, z_log_var = outputs
+
+        # 1. KL Divergence Calculation
+        prior = Normal(torch.zeros_like(z_mean), torch.ones_like(z_log_var))
+        posterior = Normal(z_mean, torch.exp(0.5 * z_log_var))
+        kl_loss = (
+            torch.distributions.kl.kl_divergence(posterior, prior).sum(dim=1).mean()
         )
 
-        if
-
-
-
+        if class_weights is None:
+            class_weights = torch.ones(self.num_classes, device=y.device)
+
+        # 2. Reconstruction Loss Calculation
+        # Reverting to the robust method of flattening tensors and using the
+        # custom loss function.
+        n_classes = reconstruction.shape[-1]
+        logits_flat = reconstruction.reshape(-1, n_classes)
+
+        # Convert one-hot `y` to class indices for the loss function.
+        targets_flat = torch.argmax(y, dim=-1).reshape(-1)
+
+        if mask is None:
+            # If no mask is provided, all targets are considered valid.
+            mask_flat = torch.ones_like(targets_flat, dtype=torch.bool)
         else:
-
+            # The mask needs to be reshaped to match the flattened targets.
+            mask_flat = mask.reshape(-1)
 
-
-
-            tf.reduce_any(tf.not_equal(y_true, -1), axis=-1),
-        )
+        # Logits, class-index targets, and the valid mask.
+        criterion = MaskedFocalLoss(alpha=class_weights, gamma=self.gamma)
 
-
-
-
-
-            reconstruction,
-            tf.reduce_any(tf.not_equal(y_true, -1), axis=-1),
-        )
-
-        # Returns binary crossentropy loss.
-        reconstruction_loss = self.compiled_loss(
-            y_true_masked,
-            y_pred_masked,
-            sample_weight=sample_weight_masked,
-        )
-
-        # Doesn't include KL Divergence Loss.
-        regularization_loss = sum(self.losses)
-
-        total_loss = reconstruction_loss + regularization_loss
-
-        grads = tape.gradient(total_loss, self.trainable_variables)
-        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
-
-        ### NOTE: If you get the error, "'tuple' object has no attribute
-        ### 'rank', then convert y_true to a tensor object."
-        # self.compiled_metrics.update_state(
-        self.accuracy_tracker.update_state(
-            self.acc_func(
-                y_true_masked,
-                y_pred_masked,
-            )
+        reconstruction_loss = criterion(
+            logits_flat.to(self.device),
+            targets_flat.to(self.device),
+            valid_mask=mask_flat.to(self.device),
         )
 
-        self.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        #     y,
-        #     reconstruction,
-        #     sample_weight=sample_weight_masked,
-        # )
-
-        # # Includes KL Divergence Loss.
-        # regularization_loss = sum(self.losses)
-
-        # total_loss = reconstruction_loss + regularization_loss
-
-        # self.accuracy_tracker.update_state(
-        #     self.cateogrical_accuracy(
-        #         y,
-        #         reconstruction,
-        #         sample_weight=sample_weight_masked,
-        #     )
-        # )
-
-        # self.total_loss_tracker.update_state(total_loss)
-        # self.reconstruction_loss_tracker.update_state(reconstruction_loss)
-        # self.kl_loss_tracker.update_state(regularization_loss)
-
-        # return {
-        #     "loss": self.total_loss_tracker.result(),
-        #     "reconstruction_loss": self.reconstruction_loss_tracker.result(),
-        #     "kl_loss": self.kl_loss_tracker.result(),
-        #     "accuracy": self.accuracy_tracker.result(),
-        # }
-
-    @property
-    def batch_size(self):
-        """Batch (=step) size per epoch."""
-        return self._batch_size
-
-    @property
-    def batch_idx(self):
-        """Current batch (=step) index."""
-        return self._batch_idx
-
-    @property
-    def y(self):
-        return self._y
-
-    @property
-    def missing_mask(self):
-        return self._missing_mask
-
-    @property
-    def sample_weight(self):
-        return self._sample_weight
-
-    @batch_size.setter
-    def batch_size(self, value):
-        """Set batch_size parameter."""
-        self._batch_size = int(value)
-
-    @batch_idx.setter
-    def batch_idx(self, value):
-        """Set current batch (=step) index."""
-        self._batch_idx = int(value)
-
-    @y.setter
-    def y(self, value):
-        """Set y after each epoch."""
-        self._y = value
-
-    @missing_mask.setter
-    def missing_mask(self, value):
-        """Set y after each epoch."""
-        self._missing_mask = value
-
-    @sample_weight.setter
-    def sample_weight(self, value):
-        self._sample_weight = value
+        return reconstruction_loss + self.beta * kl_loss
+
+    def _resolve_activation(
+        self, activation: Literal["relu", "elu", "leaky_relu", "selu"]
+    ) -> torch.nn.Module:
+        """Resolves an activation function module from a string name.
+
+        Args:
+            activation (Literal["relu", "elu", "leaky_relu", "selu"]): The name of the activation function.
+
+        Returns:
+            torch.nn.Module: The corresponding instantiated PyTorch activation function module.
+
+        Raises:
+            ValueError: If the provided activation name is not supported.
+        """
+        if isinstance(activation, str):
+            a = activation.lower()
+            if a == "relu":
+                return nn.ReLU()
+            elif a == "elu":
+                return nn.ELU()
+            elif a in {"leaky_relu", "leakyrelu"}:
+                return nn.LeakyReLU()
+            elif a == "selu":
+                return nn.SELU()
+            else:
+                msg = f"Activation {activation} not supported."
+                self.logger.error(msg)
+                raise ValueError(msg)