pg-sui 1.0.2.1__py3-none-any.whl → 1.6.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pg-sui might be problematic. Click here for more details.

Files changed (112) hide show
  1. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/METADATA +51 -70
  2. pg_sui-1.6.8.dist-info/RECORD +78 -0
  3. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.8.dist-info/entry_points.txt +4 -0
  5. pg_sui-1.6.8.dist-info/top_level.txt +1 -0
  6. pgsui/__init__.py +35 -54
  7. pgsui/_version.py +34 -0
  8. pgsui/cli.py +635 -0
  9. pgsui/data_processing/config.py +576 -0
  10. pgsui/data_processing/containers.py +1782 -0
  11. pgsui/data_processing/transformers.py +121 -1103
  12. pgsui/electron/app/__main__.py +5 -0
  13. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  14. pgsui/electron/app/icons/icons/128x128.png +0 -0
  15. pgsui/electron/app/icons/icons/16x16.png +0 -0
  16. pgsui/electron/app/icons/icons/24x24.png +0 -0
  17. pgsui/electron/app/icons/icons/256x256.png +0 -0
  18. pgsui/electron/app/icons/icons/32x32.png +0 -0
  19. pgsui/electron/app/icons/icons/48x48.png +0 -0
  20. pgsui/electron/app/icons/icons/512x512.png +0 -0
  21. pgsui/electron/app/icons/icons/64x64.png +0 -0
  22. pgsui/electron/app/icons/icons/icon.icns +0 -0
  23. pgsui/electron/app/icons/icons/icon.ico +0 -0
  24. pgsui/electron/app/main.js +189 -0
  25. pgsui/electron/app/package-lock.json +6893 -0
  26. pgsui/electron/app/package.json +50 -0
  27. pgsui/electron/app/preload.js +15 -0
  28. pgsui/electron/app/server.py +146 -0
  29. pgsui/electron/app/ui/logo.png +0 -0
  30. pgsui/electron/app/ui/renderer.js +130 -0
  31. pgsui/electron/app/ui/styles.css +59 -0
  32. pgsui/electron/app/ui/ui_shim.js +72 -0
  33. pgsui/electron/bootstrap.py +43 -0
  34. pgsui/electron/launch.py +59 -0
  35. pgsui/electron/package.json +14 -0
  36. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  37. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  38. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  39. pgsui/impute/deterministic/imputers/allele_freq.py +691 -0
  40. pgsui/impute/deterministic/imputers/mode.py +679 -0
  41. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  42. pgsui/impute/deterministic/imputers/phylo.py +971 -0
  43. pgsui/impute/deterministic/imputers/ref_allele.py +530 -0
  44. pgsui/impute/supervised/base.py +339 -0
  45. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +293 -0
  46. pgsui/impute/supervised/imputers/random_forest.py +287 -0
  47. pgsui/impute/unsupervised/base.py +924 -0
  48. pgsui/impute/unsupervised/callbacks.py +89 -263
  49. pgsui/impute/unsupervised/imputers/autoencoder.py +972 -0
  50. pgsui/impute/unsupervised/imputers/nlpca.py +1264 -0
  51. pgsui/impute/unsupervised/imputers/ubp.py +1288 -0
  52. pgsui/impute/unsupervised/imputers/vae.py +957 -0
  53. pgsui/impute/unsupervised/loss_functions.py +158 -0
  54. pgsui/impute/unsupervised/models/autoencoder_model.py +208 -558
  55. pgsui/impute/unsupervised/models/nlpca_model.py +149 -468
  56. pgsui/impute/unsupervised/models/ubp_model.py +198 -1317
  57. pgsui/impute/unsupervised/models/vae_model.py +259 -618
  58. pgsui/impute/unsupervised/nn_scorers.py +215 -0
  59. pgsui/utils/classification_viz.py +591 -0
  60. pgsui/utils/misc.py +35 -480
  61. pgsui/utils/plotting.py +514 -824
  62. pgsui/utils/scorers.py +212 -438
  63. pg_sui-1.0.2.1.dist-info/RECORD +0 -75
  64. pg_sui-1.0.2.1.dist-info/top_level.txt +0 -3
  65. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  66. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  67. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  68. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  69. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  70. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  71. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  72. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  73. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  74. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  75. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  76. pgsui/example_data/trees/test.iqtree +0 -376
  77. pgsui/example_data/trees/test.qmat +0 -5
  78. pgsui/example_data/trees/test.rate +0 -2033
  79. pgsui/example_data/trees/test.tre +0 -1
  80. pgsui/example_data/trees/test_n10.rate +0 -19
  81. pgsui/example_data/trees/test_n100.rate +0 -109
  82. pgsui/example_data/trees/test_n500.rate +0 -509
  83. pgsui/example_data/trees/test_siterates.txt +0 -2024
  84. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  85. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  86. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  87. pgsui/example_data/vcf_files/test.vcf +0 -244
  88. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  89. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  90. pgsui/impute/estimators.py +0 -735
  91. pgsui/impute/impute.py +0 -1486
  92. pgsui/impute/simple_imputers.py +0 -1439
  93. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -785
  94. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1027
  95. pgsui/impute/unsupervised/keras_classifiers.py +0 -702
  96. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  97. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1424
  98. pgsui/impute/unsupervised/neural_network_methods.py +0 -1549
  99. pgsui/pg_sui.py +0 -261
  100. pgsui/utils/sequence_tools.py +0 -407
  101. simulation/sim_benchmarks.py +0 -333
  102. simulation/sim_treeparams.py +0 -475
  103. test/__init__.py +0 -0
  104. test/pg_sui_simtest.py +0 -215
  105. test/pg_sui_testing.py +0 -523
  106. test/test.py +0 -297
  107. test/test_pgsui.py +0 -374
  108. test/test_tkc.py +0 -214
  109. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info/licenses}/LICENSE +0 -0
  110. /pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  111. /pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  112. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
@@ -1,707 +1,348 @@
1
- import logging
2
- import os
3
- import sys
4
- import warnings
1
+ from typing import List, Literal, Tuple
5
2
 
6
- # Import tensorflow with reduced warnings.
7
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
8
- logging.getLogger("tensorflow").disabled = True
9
- warnings.filterwarnings("ignore", category=UserWarning)
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from snpio.utils.logging import LoggerManager
8
+ from torch.distributions import Normal
10
9
 
11
- import tensorflow as tf
10
+ from pgsui.impute.unsupervised.loss_functions import MaskedFocalLoss
12
11
 
13
- # Disable can't find cuda .dll errors. Also turns of GPU support.
14
- tf.config.set_visible_devices([], "GPU")
15
12
 
16
- from tensorflow.python.util import deprecation
13
+ class Sampling(nn.Module):
14
+ """A layer that samples from a latent distribution using the reparameterization trick.
17
15
 
18
- # Disable warnings and info logs.
19
- tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
20
- tf.get_logger().setLevel(logging.ERROR)
21
-
22
-
23
- # Monkey patching deprecation utils to supress warnings.
24
- # noinspection PyUnusedLocal
25
- def deprecated(
26
- date, instructions, warn_once=True
27
- ): # pylint: disable=unused-argument
28
- def deprecated_wrapper(func):
29
- return func
30
-
31
- return deprecated_wrapper
32
-
33
-
34
- deprecation.deprecated = deprecated
35
-
36
- from tensorflow.keras.layers import (
37
- Dropout,
38
- Dense,
39
- Reshape,
40
- Activation,
41
- LeakyReLU,
42
- PReLU,
43
- )
44
-
45
- from tensorflow.keras.regularizers import l1_l2
46
- from tensorflow.keras import backend as K
47
-
48
- # Custom Modules
49
- try:
50
- from ..neural_network_methods import NeuralNetworkMethods
51
- except (ModuleNotFoundError, ValueError, ImportError):
52
- from impute.unsupervised.neural_network_methods import NeuralNetworkMethods
16
+ This layer is a core component of a Variational Autoencoder (VAE). It takes the mean and log-variance of a latent distribution as input and generates a sample from that distribution. By using the reparameterization trick ($z = \\mu + \\sigma \\cdot \\epsilon$), it allows gradients to be backpropagated through the random sampling process, making the VAE trainable.
17
+ """
53
18
 
19
+ def forward(self, z_mean: torch.Tensor, z_log_var: torch.Tensor) -> torch.Tensor:
20
+ """Performs the forward pass to generate a latent sample.
54
21
 
55
- class Sampling(tf.keras.layers.Layer):
56
- """Layer to calculate Z to sample from latent dimension."""
22
+ Args:
23
+ z_mean (torch.Tensor): The mean of the latent normal distribution.
24
+ z_log_var (torch.Tensor): The log of the variance of the latent normal distribution.
57
25
 
58
- def __init__(self, *args, **kwargs):
59
- self.is_placeholder = True
60
- super(Sampling, self).__init__(*args, **kwargs)
26
+ Returns:
27
+ torch.Tensor: A sampled vector from the latent space.
28
+ """
29
+ z_sigma = torch.exp(0.5 * z_log_var) # Precompute outside
61
30
 
62
- def call(self, inputs):
63
- """Sampling during forward pass."""
64
- z_mean, z_log_var = inputs
65
- z_sigma = tf.math.exp(0.5 * z_log_var)
66
- batch = tf.shape(z_mean)[0]
67
- dim = tf.shape(z_mean)[1]
68
- epsilon = tf.random.normal(shape=(batch, dim))
31
+ # Ensure epsilon is created on the same device as z_mean (CPU or GPU).
32
+ # randn_like draws random values from a standard normal distribution
33
+ # of the same shape as z_mean.
34
+ epsilon = torch.randn_like(z_mean, device=z_mean.device)
69
35
  return z_mean + z_sigma * epsilon
70
36
 
71
37
 
72
- class Encoder(tf.keras.layers.Layer):
73
- """VAE encoder to Encode genotypes to (z_mean, z_log_var, z).
74
-
75
- Args:
76
- n_features (int): Number of featuresi in input dataset.
77
-
78
- num_classes (int): Number of classes in target data.
79
-
80
- latent_dim (int): Number of latent dimensions to use.
81
-
82
- hidden_layer_sizes (list of int): List of hidden layer sizes to use.
83
-
84
- dropout_rate (float): Dropout rate for Dropout layer.
85
-
86
- activation (str): Hidden activation function to use.
87
-
88
- kernel_initializer (str): Initializer to use for weights.
89
-
90
- kernel_regularizer (tf.keras.regularizers): L1 and/ or L2 objects.
91
-
92
- beta (float, optional): KL divergence beta to use. Defualts to 1.0.
93
-
94
- name (str): Name of model. Defaults to Encoder.
38
+ class Encoder(nn.Module):
39
+ """The Encoder module of a Variational Autoencoder (VAE).
95
40
 
41
+ This module defines the encoder network, which takes high-dimensional input data and maps it to the parameters of a lower-dimensional latent distribution. The architecture consists of a series of fully-connected hidden layers that process the flattened input. The network culminates in two separate linear layers that output the mean (`z_mean`) and log-variance (`z_log_var`) of the approximate posterior distribution, $q(z|x)$.
96
42
  """
97
43
 
98
44
  def __init__(
99
45
  self,
100
- n_features,
101
- num_classes,
102
- latent_dim,
103
- hidden_layer_sizes,
104
- dropout_rate,
105
- activation,
106
- kernel_initializer,
107
- kernel_regularizer,
108
- beta=1.0,
109
- name="Encoder",
110
- **kwargs,
46
+ n_features: int,
47
+ num_classes: int,
48
+ latent_dim: int,
49
+ hidden_layer_sizes: List[int],
50
+ dropout_rate: float,
51
+ activation: torch.nn.Module,
111
52
  ):
112
- super(Encoder, self).__init__(name=name, **kwargs)
53
+ """Initializes the Encoder module.
54
+
55
+ Args:
56
+ n_features (int): The number of features in the input data (e.g., SNPs).
57
+ num_classes (int): The number of possible classes for each input element (e.g., 4 alleles).
58
+ latent_dim (int): The dimensionality of the latent space.
59
+ hidden_layer_sizes (List[int]): A list of integers specifying the size of each hidden layer.
60
+ dropout_rate (float): The dropout rate for regularization in the hidden layers.
61
+ activation (torch.nn.Module): An instantiated activation function module (e.g., `nn.ReLU()`) for the hidden layers.
62
+ """
63
+ super(Encoder, self).__init__()
64
+ self.flatten = nn.Flatten()
65
+ self.activation = (
66
+ getattr(F, activation) if isinstance(activation, str) else activation
67
+ )
113
68
 
114
- self.beta = beta * latent_dim
69
+ layers = []
70
+ # The input dimension accounts for channels
71
+ input_dim = n_features * num_classes
72
+ for hidden_size in hidden_layer_sizes:
73
+ layers.append(nn.Linear(input_dim, hidden_size))
115
74
 
116
- self.dense2 = None
117
- self.dense3 = None
118
- self.dense4 = None
119
- self.dense5 = None
75
+ # BatchNorm can lead to faster convergence.
76
+ layers.append(nn.BatchNorm1d(hidden_size))
120
77
 
121
- # # n_features * num_classes.
122
- self.flatten = tf.keras.layers.Flatten()
78
+ layers.append(nn.Dropout(dropout_rate))
79
+ layers.append(activation)
80
+ input_dim = hidden_size
123
81
 
124
- self.dense1 = Dense(
125
- hidden_layer_sizes[0],
126
- input_shape=(n_features * num_classes,),
127
- activation=activation,
128
- kernel_initializer=kernel_initializer,
129
- kernel_regularizer=kernel_regularizer,
130
- name="Encoder1",
131
- )
82
+ self.hidden_layers = nn.Sequential(*layers)
83
+ self.dense_z_mean = nn.Linear(input_dim, latent_dim)
84
+ self.dense_z_log_var = nn.Linear(input_dim, latent_dim)
85
+ self.sampling = Sampling()
132
86
 
133
- if len(hidden_layer_sizes) >= 2:
134
- self.dense2 = Dense(
135
- hidden_layer_sizes[1],
136
- activation=activation,
137
- kernel_initializer=kernel_initializer,
138
- kernel_regularizer=kernel_regularizer,
139
- name="Encoder2",
140
- )
141
-
142
- if len(hidden_layer_sizes) >= 3:
143
- self.dense3 = Dense(
144
- hidden_layer_sizes[2],
145
- activation=activation,
146
- kernel_initializer=kernel_initializer,
147
- kernel_regularizer=kernel_regularizer,
148
- name="Encoder3",
149
- )
150
-
151
- if len(hidden_layer_sizes) >= 4:
152
- self.dense4 = Dense(
153
- hidden_layer_sizes[3],
154
- activation=activation,
155
- kernel_initializer=kernel_initializer,
156
- kernel_regularizer=kernel_regularizer,
157
- name="Encoder4",
158
- )
159
-
160
- if len(hidden_layer_sizes) == 5:
161
- self.dense5 = Dense(
162
- hidden_layer_sizes[4],
163
- activation=activation,
164
- kernel_initializer=kernel_initializer,
165
- kernel_regularizer=kernel_regularizer,
166
- name="Encoder5",
167
- )
168
-
169
- self.dense_z_mean = Dense(
170
- latent_dim,
171
- name="z_mean",
172
- )
173
- self.dense_z_log_var = Dense(
174
- latent_dim,
175
- name="z_log_var",
176
- )
177
- # z_mean and z_log_var are inputs.
178
- self.sampling = Sampling(
179
- name="z",
180
- )
87
+ def forward(
88
+ self, x: torch.Tensor
89
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
90
+ """Performs the forward pass through the encoder.
181
91
 
182
- self.dense_latent = Dense(
183
- latent_dim,
184
- activation=activation,
185
- kernel_initializer=kernel_initializer,
186
- kernel_regularizer=kernel_regularizer,
187
- name="Encoder5",
188
- )
92
+ Args:
93
+ x (torch.Tensor): The input data tensor of shape `(batch_size, n_features, num_classes)`.
189
94
 
190
- self.dropout_layer = Dropout(dropout_rate)
191
-
192
- def call(self, inputs, training=None):
193
- """Forward pass for model."""
194
- x = self.flatten(inputs)
195
- x = self.dense1(x)
196
- x = self.dropout_layer(x, training=training)
197
- if self.dense2 is not None:
198
- x = self.dense2(x)
199
- x = self.dropout_layer(x, training=training)
200
- if self.dense3 is not None:
201
- x = self.dense3(x)
202
- x = self.dropout_layer(x, training=training)
203
- if self.dense4 is not None:
204
- x = self.dense4(x)
205
- x = self.dropout_layer(x, training=training)
206
- if self.dense5 is not None:
207
- x = self.dense5(x)
208
- x = self.dropout_layer(x, training=training)
209
-
210
- x = self.dense_latent(x)
95
+ Returns:
96
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the latent mean (`z_mean`), latent log-variance (`z_log_var`), and a sample from the latent distribution (`z`).
97
+ """
98
+ x = self.flatten(x)
99
+ x = self.hidden_layers(x)
211
100
  z_mean = self.dense_z_mean(x)
212
101
  z_log_var = self.dense_z_log_var(x)
213
-
214
- # Compute the KL divergence
215
- kl_loss = -0.5 * tf.reduce_sum(
216
- 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=-1
217
- )
218
- # Add the KL divergence to the model's total loss
219
- self.add_loss(self.beta * tf.reduce_mean(kl_loss))
220
-
221
- z = self.sampling([z_mean, z_log_var])
222
-
102
+ z = self.sampling(z_mean, z_log_var)
223
103
  return z_mean, z_log_var, z
224
104
 
225
105
 
226
- class Decoder(tf.keras.layers.Layer):
227
- """Converts z, the encoded vector, back into the reconstructed output.
228
-
229
- Args:
230
- n_features (int): Number of features in input dataset.
231
-
232
- num_classes (int): Number of classes in input dataset.
106
+ class Decoder(nn.Module):
107
+ """The Decoder module of a Variational Autoencoder (VAE).
233
108
 
234
- latent_dim (int): Number of latent dimensions to use.
235
-
236
- hidden_layer_sizes (list of int): List of hidden layer sizes to use.
237
-
238
- dropout_rate (float): Dropout rate for Dropout layer.
239
-
240
- activation (str): Hidden activation function to use.
241
-
242
- kernel initializer (str): Function for initilizing weights.
243
-
244
- kernel_regularizer (tf.keras.regularizer): Initialized L1 and/ or L2 regularizer.
245
-
246
- name (str): Name of model. Defaults to "Decoder".
109
+ This module defines the decoder network, which takes a sample from the low-dimensional latent space and maps it back to the high-dimensional data space. It aims to reconstruct the original input data. The architecture consists of a series of fully-connected hidden layers followed by a final linear layer that produces the reconstructed data, which is then reshaped to match the original input's dimensions.
247
110
  """
248
111
 
249
112
  def __init__(
250
113
  self,
251
- n_features,
252
- num_classes,
253
- latent_dim,
254
- hidden_layer_sizes,
255
- dropout_rate,
256
- activation,
257
- kernel_initializer,
258
- kernel_regularizer,
259
- name="Decoder",
260
- **kwargs,
261
- ):
262
- super(Decoder, self).__init__(name=name, **kwargs)
263
-
264
- self.dense2 = None
265
- self.dense3 = None
266
- self.dense4 = None
267
- self.dense5 = None
268
-
269
- self.dense1 = Dense(
270
- hidden_layer_sizes[0],
271
- input_shape=(latent_dim,),
272
- activation=activation,
273
- kernel_initializer=kernel_initializer,
274
- kernel_regularizer=kernel_regularizer,
275
- name="Decoder1",
276
- )
277
-
278
- if len(hidden_layer_sizes) >= 2:
279
- self.dense2 = Dense(
280
- hidden_layer_sizes[1],
281
- activation=activation,
282
- kernel_initializer=kernel_initializer,
283
- kernel_regularizer=kernel_regularizer,
284
- name="Decoder2",
285
- )
286
-
287
- if len(hidden_layer_sizes) >= 3:
288
- self.dense3 = Dense(
289
- hidden_layer_sizes[2],
290
- activation=activation,
291
- kernel_initializer=kernel_initializer,
292
- kernel_regularizer=kernel_regularizer,
293
- name="Decoder3",
294
- )
295
-
296
- if len(hidden_layer_sizes) >= 4:
297
- self.dense4 = Dense(
298
- hidden_layer_sizes[3],
299
- activation=activation,
300
- kernel_initializer=kernel_initializer,
301
- kernel_regularizer=kernel_regularizer,
302
- name="Decoder4",
303
- )
304
-
305
- if len(hidden_layer_sizes) == 5:
306
- self.dense5 = Dense(
307
- hidden_layer_sizes[4],
308
- activation=activation,
309
- kernel_initializer=kernel_initializer,
310
- kernel_regularizer=kernel_regularizer,
311
- name="Decoder5",
312
- )
313
-
314
- # No activation for final layer.
315
- self.dense_output = Dense(
316
- n_features * num_classes,
317
- kernel_initializer=kernel_initializer,
318
- kernel_regularizer=kernel_regularizer,
319
- name="DecoderExpanded",
320
- )
321
-
322
- self.rshp = Reshape((n_features, num_classes))
323
-
324
- self.dropout_layer = Dropout(dropout_rate)
325
-
326
- def call(self, inputs, training=None):
327
- """Forward pass for model."""
328
- x = self.dense1(inputs)
329
- x = self.dropout_layer(x, training=training)
330
- if self.dense2 is not None:
331
- x = self.dense2(x)
332
- x = self.dropout_layer(x, training=training)
333
- if self.dense3 is not None:
334
- x = self.dense3(x)
335
- x = self.dropout_layer(x, training=training)
336
- if self.dense4 is not None:
337
- x = self.dense4(x)
338
- x = self.dropout_layer(x, training=training)
339
- if self.dense5 is not None:
340
- x = self.dense5(x)
341
- x = self.dropout_layer(x, training=training)
342
-
343
- x = self.dense_output(x)
344
- return self.rshp(x)
345
-
346
-
347
- class VAEModel(tf.keras.Model):
348
- """Variational Autoencoder model. Runs the encdoer and decoder and outputs the reconsruction.
349
-
350
- Args:
351
- output_shape (tuple): Shape of output. Defaults to None.
352
-
353
- n_components (int): Number of latent dimensions to use. Defaults to 3.
354
-
355
- weights_initializer (str, optional): kernel initializer to use. Defaults to "glorot_normal".
356
-
357
- hidden_layer_sizes (str or list, optional): List of hidden layer sizes to use, or can use "midpoint", "log2", or "sqrt". Defaults to "midpoint".
358
-
359
- num_hidden_layers (int, optional): Number of hidden layers to use. Defaults to 1.
360
-
361
- hidden_activation (str, optional): Hidden activation function to use. Defaults to "elu".
114
+ n_features: int,
115
+ num_classes: int,
116
+ latent_dim: int,
117
+ hidden_layer_sizes: List[int],
118
+ dropout_rate: float,
119
+ activation: torch.nn.Module,
120
+ ) -> None:
121
+ """Initializes the Decoder module.
122
+
123
+ Args:
124
+ n_features (int): The number of features in the output data (e.g., SNPs).
125
+ num_classes (int): The number of possible classes for each output element (e.g., 4 alleles).
126
+ latent_dim (int): The dimensionality of the input latent space.
127
+ hidden_layer_sizes (List[int]): A list of integers specifying the size of each hidden layer.
128
+ dropout_rate (float): The dropout rate for regularization in the hidden layers.
129
+ activation (torch.nn.Module): An instantiated activation function module (e.g., `nn.ReLU()`) for the hidden layers.
130
+ """
131
+ super(Decoder, self).__init__()
362
132
 
363
- l1_penalty (float, optional): L1 regularization penalty to use. Defaults to 1e-06.
133
+ layers = []
134
+ input_dim = latent_dim
135
+ for hidden_size in hidden_layer_sizes:
136
+ layers.append(nn.Linear(input_dim, hidden_size))
364
137
 
365
- l2_penalty (float, optional): L2 regularization penalty. Defaults to 1e-06.
138
+ # BatchNorm can lead to faster convergence.
139
+ layers.append(nn.BatchNorm1d(hidden_size))
366
140
 
367
- dropout_rate (float, optional): Dropout rate to use for Dropout layers.
141
+ layers.append(nn.Dropout(dropout_rate))
142
+ layers.append(activation)
143
+ input_dim = hidden_size
368
144
 
369
- kl_beta (float, optional): Beta to use for KL divergence loss. The KL divergence gets multiplied by this value. Defaults to 1.0 (no scaling).
145
+ self.hidden_layers = nn.Sequential(*layers)
146
+ # UPDATED: Output dimension must account for channels
147
+ output_dim = n_features * num_classes
148
+ self.dense_output = nn.Linear(input_dim, output_dim)
149
+ # UPDATED: Reshape must account for channels
150
+ self.reshape = (n_features, num_classes)
370
151
 
371
- num_classes (int, optional): Number of classes in input data. Defaults to 10.
152
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
153
+ """Performs the forward pass through the decoder.
372
154
 
373
- sample_weight (list of float, optional): sample weights to use. Defaults to None.
155
+ Args:
156
+ x (torch.Tensor): The input latent tensor of shape `(batch_size, latent_dim)`.
374
157
 
375
- missing_mask (np.ndarray, optional): Missing mask to use in model for masking missing values. Defaults to None.
158
+ Returns:
159
+ torch.Tensor: The reconstructed output data of shape `(batch_size, n_features, num_classes)`.
160
+ """
161
+ x = self.hidden_layers(x)
162
+ x = self.dense_output(x)
163
+ return x.view(-1, *self.reshape)
376
164
 
377
- batch_size (int, optional): Batch size to use for training. Defaults to 32.
378
165
 
379
- final_activation (str, optional): Final activation function to use. Should be either "sigmoid" (if doing multilabel classification) or "softmax" (if doing multiclass). If left None, then activation is not performed. Defaults to None.
166
+ class VAEModel(nn.Module):
167
+ """A Variational Autoencoder (VAE) model for imputation.
380
168
 
381
- y (np.ndarray, optional): Input dataset y. Should be the full dataset. It will get subset by a callback to get batch_size samples. Default to None.
169
+ This class combines an `Encoder` and a `Decoder` to form a VAE, a generative model for learning complex data distributions. It is designed for imputing missing values in categorical data, such as genomic SNPs. The model is trained by maximizing the Evidence Lower Bound (ELBO), which is a lower bound on the log-likelihood of the data.
382
170
 
171
+ **Objective Function (ELBO):**
172
+ The VAE loss function is derived from the ELBO and consists of two main components: a reconstruction term and a regularization term.
173
+ $$
174
+ \\mathcal{L}(\\theta, \\phi; x) = \\underbrace{\\mathbb{E}_{q_{\\phi}(z|x)}[\\log p_{\\theta}(x|z)]}_{\\text{Reconstruction Loss}} - \\underbrace{D_{KL}(q_{\\phi}(z|x) || p(z))}_{\\text{KL Divergence}}
175
+ $$
176
+ - The **Reconstruction Loss** encourages the decoder to accurately reconstruct the input data from its latent representation. This implementation uses a `MaskedFocalLoss`.
177
+ - The **KL Divergence** acts as a regularizer, forcing the approximate posterior distribution $q_{\\phi}(z|x)$ learned by the encoder to be close to a prior distribution $p(z)$ (typically a standard normal distribution).
383
178
  """
384
179
 
385
180
  def __init__(
386
181
  self,
387
- output_shape=None,
388
- n_components=3,
389
- weights_initializer="glorot_normal",
390
- hidden_layer_sizes="midpoint",
391
- num_hidden_layers=1,
392
- hidden_activation="elu",
393
- l1_penalty=1e-6,
394
- l2_penalty=1e-6,
395
- dropout_rate=0.2,
396
- kl_beta=1.0,
397
- num_classes=10,
398
- sample_weight=None,
399
- missing_mask=None,
400
- batch_size=32,
401
- final_activation=None,
402
- y=None,
182
+ n_features: int,
183
+ prefix: str,
184
+ *,
185
+ num_classes: int = 4,
186
+ hidden_layer_sizes: List[int] | np.ndarray = [128, 64],
187
+ latent_dim: int = 2,
188
+ dropout_rate: float = 0.2,
189
+ activation: Literal["relu", "elu", "selu", "leaky_relu"] = "relu",
190
+ gamma: float = 2.0,
191
+ beta: float = 1.0,
192
+ device: Literal["cpu", "gpu", "mps"] = "cpu",
193
+ verbose: bool = False,
194
+ debug: bool = False,
403
195
  ):
196
+ """Initializes the VAEModel.
197
+
198
+ Args:
199
+ n_features (int): The number of features in the input data (e.g., SNPs).
200
+ prefix (str): A prefix used for logging.
201
+ num_classes (int): The number of possible classes for each input element. Defaults to 4.
202
+ hidden_layer_sizes (List[int] | np.ndarray): A list of integers specifying the size of each hidden layer in the encoder and decoder. Defaults to [128, 64].
203
+ latent_dim (int): The dimensionality of the latent space. Defaults to 2.
204
+ dropout_rate (float): The dropout rate for regularization in the hidden layers. Defaults to 0.2.
205
+ activation (str): The name of the activation function to use in hidden layers. Defaults to "relu".
206
+ gamma (float): The focusing parameter for the focal loss component. Defaults to 2.0.
207
+ beta (float): A weighting factor for the KL divergence term in the total loss ($\\beta$-VAE). Defaults to 1.0.
208
+ device (Literal["cpu", "gpu", "mps"]): The device to run the model on. Defaults to "cpu".
209
+ verbose (bool): If True, enables detailed logging. Defaults to False.
210
+ debug (bool): If True, enables debug mode. Defaults to False.
211
+ """
404
212
  super(VAEModel, self).__init__()
405
-
406
- self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
407
- self.binary_accuracy_tracker = tf.keras.metrics.Mean(
408
- name="binary_accuracy"
409
- )
410
-
411
- self.kl_beta = K.variable(0.0)
412
- self.kl_beta._trainable = False
413
-
414
- self._sample_weight = sample_weight
415
- self._missing_mask = missing_mask
416
- self._batch_idx = 0
417
- self._batch_size = batch_size
418
- self._y = y
419
-
420
- self._final_activation = final_activation
421
- if num_classes == 10 or num_classes == 3:
422
- self.acc_func = tf.keras.metrics.categorical_accuracy
423
- elif num_classes == 4:
424
- self.acc_func = tf.keras.metrics.binary_accuracy
425
-
426
- self.nn_ = NeuralNetworkMethods()
427
-
428
- # y_train[1] dimension.
429
- self.n_features = output_shape
430
-
431
- self.n_components = n_components
432
- self.weights_initializer = weights_initializer
433
- self.hidden_layer_sizes = hidden_layer_sizes
434
- self.num_hidden_layers = num_hidden_layers
435
- self.hidden_activation = hidden_activation
436
- self.l1_penalty = l1_penalty
437
- self.l2_penalty = l2_penalty
438
- self.dropout_rate = dropout_rate
439
213
  self.num_classes = num_classes
214
+ self.gamma = gamma
215
+ self.beta = beta
216
+ self.device = device
440
217
 
441
- nn = NeuralNetworkMethods()
442
-
443
- hidden_layer_sizes = nn.validate_hidden_layers(
444
- self.hidden_layer_sizes, self.num_hidden_layers
445
- )
446
-
447
- hidden_layer_sizes = nn.get_hidden_layer_sizes(
448
- self.n_features, self.n_components, hidden_layer_sizes, vae=True
218
+ logman = LoggerManager(
219
+ name=__name__, prefix=prefix, verbose=verbose, debug=debug
449
220
  )
221
+ self.logger = logman.get_logger()
450
222
 
451
- hidden_layer_sizes = [h * self.num_classes for h in hidden_layer_sizes]
452
-
453
- if self.l1_penalty == 0.0 and self.l2_penalty == 0.0:
454
- kernel_regularizer = None
455
- else:
456
- kernel_regularizer = l1_l2(self.l1_penalty, self.l2_penalty)
457
-
458
- kernel_initializer = self.weights_initializer
459
-
460
- if self.hidden_activation.lower() == "leaky_relu":
461
- activation = LeakyReLU(alpha=0.01)
462
-
463
- elif self.hidden_activation.lower() == "prelu":
464
- activation = PReLU()
465
-
466
- elif self.hidden_activation.lower() == "selu":
467
- activation = "selu"
468
- kernel_initializer = "lecun_normal"
469
-
470
- else:
471
- activation = self.hidden_activation
472
-
473
- if num_hidden_layers > 5:
474
- raise ValueError(
475
- f"The maximum number of hidden layers is 5, but got "
476
- f"{num_hidden_layers}"
477
- )
223
+ activation = self._resolve_activation(activation)
478
224
 
479
225
  self.encoder = Encoder(
480
- self.n_features,
226
+ n_features,
481
227
  self.num_classes,
482
- self.n_components,
228
+ latent_dim,
483
229
  hidden_layer_sizes,
484
- self.dropout_rate,
230
+ dropout_rate,
485
231
  activation,
486
- kernel_initializer,
487
- kernel_regularizer,
488
- beta=self.kl_beta,
489
232
  )
490
233
 
491
- hidden_layer_sizes.reverse()
234
+ decoder_layer_sizes = list(reversed(hidden_layer_sizes))
492
235
 
493
236
  self.decoder = Decoder(
494
- self.n_features,
237
+ n_features,
495
238
  self.num_classes,
496
- self.n_components,
497
- hidden_layer_sizes,
498
- self.dropout_rate,
239
+ latent_dim,
240
+ decoder_layer_sizes,
241
+ dropout_rate,
499
242
  activation,
500
- kernel_initializer,
501
- kernel_regularizer,
502
243
  )
503
244
 
504
- self.activation = Activation("sigmoid")
245
+ def forward(
246
+ self, x: torch.Tensor
247
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
248
+ """Performs the forward pass through the full VAE model.
505
249
 
506
- def call(self, inputs, training=None):
507
- """Forward pass for model."""
508
- z_mean, z_log_var, z = self.encoder(inputs)
509
- reconstruction = self.decoder(z)
510
- return self.activation(reconstruction)
250
+ Args:
251
+ x (torch.Tensor): The input data tensor of shape `(batch_size, n_features, num_classes)`.
511
252
 
512
- def model(self):
513
- """Here so that mymodel.model().summary() can be called for debugging.
514
- :noindex:
253
+ Returns:
254
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the reconstructed output, the latent mean (`z_mean`), and the latent log-variance (`z_log_var`).
515
255
  """
516
- x = tf.keras.Input(shape=(self.n_features, self.num_classes))
517
- return tf.keras.Model(inputs=[x], outputs=self.call(x))
256
+ z_mean, z_log_var, z = self.encoder(x)
257
+ reconstruction = self.decoder(z)
258
+ return reconstruction, z_mean, z_log_var
518
259
 
519
- def set_model_outputs(self):
520
- """Set model output dimensions for building model.
521
- :noindex:
260
+ def compute_loss(
261
+ self,
262
+ outputs: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
263
+ y: torch.Tensor,
264
+ mask: torch.Tensor | None = None,
265
+ class_weights: torch.Tensor | None = None,
266
+ ) -> torch.Tensor:
267
+ """Computes the VAE loss function (negative ELBO).
268
+
269
+ The loss is the sum of a reconstruction term and a regularizing KL divergence term. The reconstruction loss is calculated using a masked focal loss, and the KL divergence measures the difference between the learned latent distribution and a standard normal prior.
270
+
271
+ Args:
272
+ outputs (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): The tuple of (reconstruction, z_mean, z_log_var) from the model's forward pass.
273
+ y (torch.Tensor): The target data tensor, expected to be one-hot encoded. This is converted to class indices internally for the loss function.
274
+ mask (torch.Tensor | None): A boolean mask to exclude missing values from the reconstruction loss.
275
+ class_weights (torch.Tensor | None): Weights to apply to each class in the reconstruction loss to handle imbalance.
276
+
277
+ Returns:
278
+ torch.Tensor: The computed scalar loss value.
522
279
  """
523
- x = tf.keras.Input(shape=(self.n_features, self.num_classes))
524
- model = tf.keras.Model(inputs=[x], outputs=self.call(x))
525
- self.outputs = model.outputs
526
-
527
- @property
528
- def metrics(self):
529
- return [
530
- self.total_loss_tracker,
531
- self.binary_accuracy_tracker,
532
- ]
533
-
534
- @tf.function
535
- def train_step(self, data):
536
- y = self._y
537
-
538
- (
539
- y_true,
540
- sample_weight,
541
- missing_mask,
542
- ) = self.nn_.prepare_training_batches(
543
- y,
544
- y,
545
- self._batch_size,
546
- self._batch_idx,
547
- True,
548
- self.n_components,
549
- self._sample_weight,
550
- self._missing_mask,
551
- ubp=False,
552
- )
553
-
554
- if sample_weight is not None:
555
- sample_weight_masked = tf.convert_to_tensor(
556
- sample_weight[~missing_mask], dtype=tf.float32
557
- )
558
- else:
559
- sample_weight_masked = None
280
+ reconstruction, z_mean, z_log_var = outputs
560
281
 
561
- y_true_masked = tf.boolean_mask(
562
- tf.convert_to_tensor(y_true, dtype=tf.float32),
563
- tf.reduce_any(tf.not_equal(y_true, -1), axis=-1),
282
+ # 1. KL Divergence Calculation
283
+ prior = Normal(torch.zeros_like(z_mean), torch.ones_like(z_log_var))
284
+ posterior = Normal(z_mean, torch.exp(0.5 * z_log_var))
285
+ kl_loss = (
286
+ torch.distributions.kl.kl_divergence(posterior, prior).sum(dim=1).mean()
564
287
  )
565
288
 
566
- with tf.GradientTape() as tape:
567
- reconstruction = self(y_true, training=True)
289
+ if class_weights is None:
290
+ class_weights = torch.ones(self.num_classes, device=y.device)
568
291
 
569
- y_pred_masked = tf.boolean_mask(
570
- reconstruction,
571
- tf.reduce_any(tf.not_equal(y_true, -1), axis=-1),
572
- )
292
+ # 2. Reconstruction Loss Calculation
293
+ # Reverting to the robust method of flattening tensors and using the
294
+ # custom loss function.
295
+ n_classes = reconstruction.shape[-1]
296
+ logits_flat = reconstruction.reshape(-1, n_classes)
573
297
 
574
- # Returns binary crossentropy loss.
575
- reconstruction_loss = self.compiled_loss(
576
- y_true_masked,
577
- y_pred_masked,
578
- sample_weight=sample_weight_masked,
579
- regularization_losses=self.losses,
580
- )
298
+ # Convert one-hot `y` to class indices for the loss function.
299
+ targets_flat = torch.argmax(y, dim=-1).reshape(-1)
581
300
 
582
- # Doesn't include KL Divergence Loss.
583
- regularization_loss = sum(self.losses)
584
-
585
- total_loss = reconstruction_loss + regularization_loss
586
-
587
- grads = tape.gradient(total_loss, self.trainable_variables)
588
- self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
589
-
590
- self.total_loss_tracker.update_state(total_loss)
591
- self.binary_accuracy_tracker.update_state(
592
- tf.keras.metrics.binary_accuracy(y_true_masked, y_pred_masked)
593
- )
594
-
595
- return {
596
- "loss": self.total_loss_tracker.result(),
597
- "binary_accuracy": self.binary_accuracy_tracker.result(),
598
- }
599
-
600
- @tf.function
601
- def test_step(self, data):
602
- y = self._y
603
-
604
- (
605
- y_true,
606
- sample_weight,
607
- missing_mask,
608
- ) = self.nn_.prepare_training_batches(
609
- y,
610
- y,
611
- self._batch_size,
612
- self._batch_idx,
613
- True,
614
- self.n_components,
615
- self._sample_weight,
616
- self._missing_mask,
617
- ubp=False,
618
- )
619
-
620
- if sample_weight is not None:
621
- sample_weight_masked = tf.convert_to_tensor(
622
- sample_weight[~missing_mask], dtype=tf.float32
623
- )
301
+ if mask is None:
302
+ # If no mask is provided, all targets are considered valid.
303
+ mask_flat = torch.ones_like(targets_flat, dtype=torch.bool)
624
304
  else:
625
- sample_weight_masked = None
626
-
627
- y_true_masked = tf.boolean_mask(
628
- tf.convert_to_tensor(y_true, dtype=tf.float32),
629
- tf.reduce_any(tf.not_equal(y_true, -1), axis=-1),
630
- )
305
+ # The mask needs to be reshaped to match the flattened targets.
306
+ mask_flat = mask.reshape(-1)
631
307
 
632
- reconstruction = self(y_true, training=False)
308
+ # Logits, class-index targets, and the valid mask.
309
+ criterion = MaskedFocalLoss(alpha=class_weights, gamma=self.gamma)
633
310
 
634
- y_pred_masked = tf.boolean_mask(
635
- reconstruction,
636
- tf.reduce_any(tf.not_equal(y_true, -1), axis=-1),
311
+ reconstruction_loss = criterion(
312
+ logits_flat.to(self.device),
313
+ targets_flat.to(self.device),
314
+ valid_mask=mask_flat.to(self.device),
637
315
  )
638
316
 
639
- reconstruction_loss = self.compiled_loss(
640
- y_true_masked,
641
- y_pred_masked,
642
- sample_weight=sample_weight_masked,
643
- regularization_losses=self.losses,
644
- )
317
+ return reconstruction_loss + self.beta * kl_loss
645
318
 
646
- # Doesn't include KL Divergence Loss.
647
- regularization_loss = sum(self.losses)
319
+ def _resolve_activation(
320
+ self, activation: Literal["relu", "elu", "leaky_relu", "selu"]
321
+ ) -> torch.nn.Module:
322
+ """Resolves an activation function module from a string name.
648
323
 
649
- total_loss = reconstruction_loss + regularization_loss
324
+ Args:
325
+ activation (Literal["relu", "elu", "leaky_relu", "selu"]): The name of the activation function.
650
326
 
651
- ### NOTE: If you get the error, "'tuple' object has no attribute
652
- ### 'rank', then convert y_true to a tensor object."
653
- self.total_loss_tracker.update_state(total_loss)
654
- self.binary_accuracy_tracker.update_state(
655
- tf.keras.metrics.binary_accuracy(y_true_masked, y_pred_masked)
656
- )
327
+ Returns:
328
+ torch.nn.Module: The corresponding instantiated PyTorch activation function module.
657
329
 
658
- return {
659
- "loss": self.total_loss_tracker.result(),
660
- "binary_accuracy": self.binary_accuracy_tracker.result(),
661
- }
662
-
663
- @property
664
- def batch_size(self):
665
- """Batch (=step) size per epoch."""
666
- return self._batch_size
667
-
668
- @property
669
- def batch_idx(self):
670
- """Current batch (=step) index."""
671
- return self._batch_idx
672
-
673
- @property
674
- def y(self):
675
- return self._y
676
-
677
- @property
678
- def missing_mask(self):
679
- return self._missing_mask
680
-
681
- @property
682
- def sample_weight(self):
683
- return self._sample_weight
684
-
685
- @batch_size.setter
686
- def batch_size(self, value):
687
- """Set batch_size parameter."""
688
- self._batch_size = int(value)
689
-
690
- @batch_idx.setter
691
- def batch_idx(self, value):
692
- """Set current batch (=step) index."""
693
- self._batch_idx = int(value)
694
-
695
- @y.setter
696
- def y(self, value):
697
- """Set y after each epoch."""
698
- self._y = value
699
-
700
- @missing_mask.setter
701
- def missing_mask(self, value):
702
- """Set y after each epoch."""
703
- self._missing_mask = value
704
-
705
- @sample_weight.setter
706
- def sample_weight(self, value):
707
- self._sample_weight = value
330
+ Raises:
331
+ ValueError: If the provided activation name is not supported.
332
+ """
333
+ if isinstance(activation, str):
334
+ activation = activation.lower()
335
+ if activation == "relu":
336
+ return nn.ReLU()
337
+ elif activation == "elu":
338
+ return nn.ELU()
339
+ elif activation in ["leaky_relu", "leakyrelu"]:
340
+ return nn.LeakyReLU()
341
+ elif activation == "selu":
342
+ return nn.SELU()
343
+ else:
344
+ msg = f"Activation {activation} not supported."
345
+ self.logger.error(msg)
346
+ raise ValueError(msg)
347
+
348
+ return activation