pg-sui 0.2.3__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +99 -77
  2. pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +909 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1424 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1118 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1228 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.3.dist-info/RECORD +0 -75
  83. pg_sui-0.2.3.dist-info/top_level.txt +0 -3
  84. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  85. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  88. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  89. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  90. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  93. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  94. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  95. pgsui/example_data/trees/test.iqtree +0 -376
  96. pgsui/example_data/trees/test.qmat +0 -5
  97. pgsui/example_data/trees/test.rate +0 -2033
  98. pgsui/example_data/trees/test.tre +0 -1
  99. pgsui/example_data/trees/test_n10.rate +0 -19
  100. pgsui/example_data/trees/test_n100.rate +0 -109
  101. pgsui/example_data/trees/test_n500.rate +0 -509
  102. pgsui/example_data/trees/test_siterates.txt +0 -2024
  103. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  104. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  105. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  106. pgsui/example_data/vcf_files/test.vcf +0 -244
  107. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  108. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  109. pgsui/impute/estimators.py +0 -1268
  110. pgsui/impute/impute.py +0 -1463
  111. pgsui/impute/simple_imputers.py +0 -1431
  112. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  113. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  114. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  115. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  116. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  117. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  118. pgsui/pg_sui.py +0 -261
  119. pgsui/utils/sequence_tools.py +0 -407
  120. simulation/sim_benchmarks.py +0 -333
  121. simulation/sim_treeparams.py +0 -475
  122. test/__init__.py +0 -0
  123. test/pg_sui_simtest.py +0 -215
  124. test/pg_sui_testing.py +0 -523
  125. test/test.py +0 -151
  126. test/test_pgsui.py +0 -374
  127. test/test_tkc.py +0 -185
pgsui/impute/unsupervised/models/nlpca_model.py
@@ -1,445 +1,206 @@
-import logging
-import os
-import sys
-import warnings
+from typing import List, Literal
 
-# Import tensorflow with reduced warnings.
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-logging.getLogger("tensorflow").disabled = True
-warnings.filterwarnings("ignore", category=UserWarning)
+import numpy as np
+import torch
+import torch.nn as nn
+from snpio.utils.logging import LoggerManager
 
-# noinspection PyPackageRequirements
-import tensorflow as tf
+from pgsui.impute.unsupervised.loss_functions import MaskedFocalLoss
+from pgsui.utils.logging_utils import configure_logger
 
-# Disable can't find cuda .dll errors. Also turns off GPU support.
-tf.config.set_visible_devices([], "GPU")
 
-from tensorflow.python.util import deprecation
+class NLPCAModel(nn.Module):
+    r"""A non-linear Principal Component Analysis (NLPCA) decoder for genotypes.
 
-# Disable warnings and info logs.
-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-tf.get_logger().setLevel(logging.ERROR)
+    This module maps a low-dimensional latent vector to logits over genotype states
+    (two classes for haploids or three for diploids) at every locus. It is a fully
+    connected network with optional batch normalization and dropout layers and is
+    used as the decoder inside the NLPCA imputer.
 
+    **Model Architecture**
 
-# Monkey patching deprecation utils to supress warnings.
-# noinspection PyUnusedLocal
-def deprecated(
-    date, instructions, warn_once=True
-):  # pylint: disable=unused-argument
-    def deprecated_wrapper(func):
-        return func
+    Let :math:`z \in \mathbb{R}^{d_{\text{latent}}}` be the latent vector. For a
+    network with :math:`L` hidden layers, the transformations are
 
-    return deprecated_wrapper
+    .. math::
 
+        h_1 = f(W_1 z + b_1)
 
-deprecation.deprecated = deprecated
+    .. math::
 
-from tensorflow.keras.layers import (
-    Dropout,
-    Dense,
-    Reshape,
-    LeakyReLU,
-    PReLU,
-)
+        h_2 = f(W_2 h_1 + b_2)
 
-from tensorflow.keras.regularizers import l1_l2
+    .. math::
 
-# Custom Modules
-try:
-    from ..neural_network_methods import NeuralNetworkMethods
-except (ModuleNotFoundError, ValueError, ImportError):
-    from impute.unsupervised.neural_network_methods import NeuralNetworkMethods
+        \vdots
 
+    .. math::
 
-class NLPCAModel(tf.keras.Model):
-    """NLPCA model to train and use to predict imputations.
+        h_L = f(W_L h_{L-1} + b_L)
 
-    NLPCAModel subclasses the tf.keras.Model and overrides the train_step function, which does training and evaluation for each batch in each epoch.
+    The final layer produces logits of shape ``(batch_size, n_features, num_classes)``
+    by reshaping a linear projection back to the (loci, genotype-state) grid.
 
-    Args:
-        V (numpy.ndarray(float)): V should have been randomly initialized and will be used as the input data that gets refined during training. Defaults to None.
-
-        y (numpy.ndarray): Target values to predict. Actual input data. Defaults to None.
-
-        batch_size (int, optional): Batch size per epoch. Defaults to 32.
-
-        missing_mask (numpy.ndarray): Missing data mask for y. Defaults to None.
-
-        output_shape (int): Output units for n_features dimension. Output will be of shape (batch_size, n_features). Defaults to None.
-
-        n_components (int, optional): Number of features in input V to use. Defaults to 3.
-
-        weights_initializer (str, optional): Kernel initializer to use for initializing model weights. Defaults to "glorot_normal".
-
-        hidden_layer_sizes (List[int], optional): Output units for each hidden layer. List should be of same length as the number of hidden layers. Defaults to "midpoint".
-
-        num_hidden_layers (int, optional): Number of hidden layers to use. Defaults to 1.
-
-        hidden_activation (str, optional): Activation function to use for hidden layers. Defaults to "elu".
-
-        l1_penalty (float, optional): L1 regularization penalty to use to reduce overfitting. Defaults to 0.01.
-
-        l2_penalty (float, optional): L2 regularization penalty to use to reduce overfitting. Defaults to 0.01.
-
-        dropout_rate (float, optional): Dropout rate during training to reduce overfitting. Must be a float between 0 and 1. Defaults to 0.2.
-
-        num_classes (int, optional): Number of classes in output. Corresponds to the 3rd dimension of the output shape (batch_size, n_features, num_classes). Defaults to 1.
-
-        phase (NoneType): Here for compatibility with UBP.
-
-        sample_weight (numpy.ndarray, optional): 2D sample weights of shape (n_samples, n_features). Should have values for each class weighted. Defaults to None.
-
-    Example:
-        >>>model = NLPCAModel(V=V, y=y, batch_size=32, missing_mask=missing_mask, output_shape, n_components, weights_initializer, hidden_layer_sizes, num_hidden_layers, hidden_activation, l1_penalty, l2_penalty, dropout_rate, num_classes=3)
-        >>>
-        >>>model.compile(optimizer=optimizer, loss=loss_func, metrics=[my_metrics], run_eagerly=True)
-        >>>
-        >>>history = model.fit(X, y, batch_size=batch_size, epochs=epochs, callbacks=[MyCallback()], validation_split=validation_split, shuffle=False)
-
-    Raises:
-        TypeError: V, y, missing_mask, output_shape must not be NoneType.
-        ValueError: Maximum of 5 hidden layers.
+    **Loss Function**
 
+    Training minimizes ``MaskedFocalLoss``, which extends cross-entropy with class
+    weighting, focal re-weighting, and masking so that only observed genotypes
+    contribute to the objective.
     """
 
     def __init__(
         self,
-        V=None,
-        y=None,
-        batch_size=32,
-        missing_mask=None,
-        output_shape=None,
-        n_components=3,
-        weights_initializer="glorot_normal",
-        hidden_layer_sizes="midpoint",
-        num_hidden_layers=1,
-        hidden_activation="elu",
-        l1_penalty=0.01,
-        l2_penalty=0.01,
-        dropout_rate=0.2,
-        num_classes=3,
-        phase=None,
-        sample_weight=None,
+        n_features: int,
+        prefix: str,
+        *,
+        num_classes: int = 4,
+        hidden_layer_sizes: List[int] | np.ndarray = [128, 64],
+        latent_dim: int = 2,
+        dropout_rate: float = 0.2,
+        activation: Literal["relu", "elu", "selu", "leaky_relu"] = "relu",
+        gamma: float = 2.0,
+        device: Literal["gpu", "cpu", "mps"] = "cpu",
+        verbose: bool = False,
+        debug: bool = False,
     ):
+ """Initializes the NLPCAModel.
67
+
68
+ Args:
69
+ n_features (int): The number of features (SNPs) in the input data.
70
+ prefix (str): A prefix used for logging.
71
+ num_classes (int): Number of genotype states per locus (2 for haploid, 3 for diploid in practice). Defaults to 4 for backward compatibility.
72
+ hidden_layer_sizes (list[int] | np.ndarray): A list of integers specifying the number of units in each hidden layer. Defaults to [128, 64].
73
+ latent_dim (int): The dimensionality of the latent space (the size of the bottleneck layer). Defaults to 2.
74
+ dropout_rate (float): The dropout rate applied to each hidden layer for regularization. Defaults to 0.2.
75
+ activation (Literal["relu", "elu", "selu", "leaky_relu"]): The non-linear activation function to use in hidden layers. Defaults to 'relu'.
76
+ gamma (float): The focusing parameter for the focal loss function, which down-weights well-classified examples. Defaults to 2.0.
77
+ device (Literal["gpu", "cpu", "mps"]): The PyTorch device to run the model on. Defaults to 'cpu'.
78
+ verbose (bool): If True, enables detailed logging. Defaults to False.
79
+ debug (bool): If True, enables debug mode. Defaults to False.
80
+ """
124
81
  super(NLPCAModel, self).__init__()
125
82
 
126
- nn = NeuralNetworkMethods()
127
- self.nn = nn
128
-
129
- if V is None:
130
- self._V = nn.init_weights(y.shape[0], n_components)
131
- elif isinstance(V, dict):
132
- self._V = V[n_components]
133
- else:
134
- self._V = V
135
-
136
- self._y = y
137
-
138
- hidden_layer_sizes = nn.validate_hidden_layers(
139
- hidden_layer_sizes, num_hidden_layers
140
- )
141
-
142
- hidden_layer_sizes = nn.get_hidden_layer_sizes(
143
- y.shape[1], self._V.shape[1], hidden_layer_sizes
144
- )
145
-
146
- nn.validate_model_inputs(y, missing_mask, output_shape)
147
-
148
- self._missing_mask = missing_mask
149
- self.weights_initializer = weights_initializer
150
- self.phase = phase
151
- self.dropout_rate = dropout_rate
152
- self._sample_weight = sample_weight
153
-
154
- ### NOTE: I tried using just _V as the input to be refined, but it
155
- # wasn't getting updated. So I copy it here and it works.
156
- # V_latent is refined during train_step.
157
- self.V_latent_ = self._V.copy()
158
-
159
- # Initialize parameters used during train_step.
160
- self._batch_idx = 0
161
- self._batch_size = batch_size
162
- self.n_components = n_components
163
-
164
- if l1_penalty == 0.0 and l2_penalty == 0.0:
165
- kernel_regularizer = None
166
- else:
167
- kernel_regularizer = l1_l2(l1_penalty, l2_penalty)
168
-
169
- self.kernel_regularizer = kernel_regularizer
170
- kernel_initializer = weights_initializer
171
-
172
- if hidden_activation.lower() == "leaky_relu":
173
- activation = LeakyReLU(alpha=0.01)
174
-
175
- elif hidden_activation.lower() == "prelu":
176
- activation = PReLU()
177
-
178
- elif hidden_activation.lower() == "selu":
179
- activation = "selu"
180
- kernel_initializer = "lecun_normal"
181
-
182
- else:
183
- activation = hidden_activation
184
-
185
- if num_hidden_layers > 5:
186
- raise ValueError(
187
- f"The maximum number of hidden layers is 5, but got "
188
- f"{num_hidden_layers}"
189
- )
190
-
191
- self.dense2 = None
192
- self.dense3 = None
193
- self.dense4 = None
194
- self.dense5 = None
195
-
196
- # Construct multi-layer perceptron.
197
- # Add hidden layers dynamically.
198
- self.dense1 = Dense(
199
- hidden_layer_sizes[0],
200
- input_shape=(n_components,),
201
- activation=activation,
202
- kernel_initializer=kernel_initializer,
203
- kernel_regularizer=kernel_regularizer,
83
+        logman = LoggerManager(
+            name=__name__, prefix=prefix, verbose=verbose, debug=debug
         )
-
-        if num_hidden_layers >= 2:
-            self.dense2 = Dense(
-                hidden_layer_sizes[1],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-            )
-
-        if num_hidden_layers >= 3:
-            self.dense3 = Dense(
-                hidden_layer_sizes[2],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-            )
-
-        if num_hidden_layers >= 4:
-            self.dense4 = Dense(
-                hidden_layer_sizes[3],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-            )
-
-        if num_hidden_layers == 5:
-            self.dense5 = Dense(
-                hidden_layer_sizes[4],
-                activation=activation,
-                kernel_initializer=kernel_initializer,
-                kernel_regularizer=kernel_regularizer,
-            )
-
-        self.output1 = Dense(
-            output_shape * num_classes,
-            kernel_initializer=kernel_initializer,
-            kernel_regularizer=kernel_regularizer,
+        self.logger = configure_logger(
+            logman.get_logger(), verbose=verbose, debug=debug
         )
 
-        self.rshp = Reshape((output_shape, num_classes))
-
-        self.dropout_layer = Dropout(rate=dropout_rate)
-
-    def call(self, inputs, training=None):
-        x = self.dense1(inputs)
-        x = self.dropout_layer(x, training=training)
-        if self.dense2 is not None:
-            x = self.dense2(x)
-            x = self.dropout_layer(x, training=training)
-        if self.dense3 is not None:
-            x = self.dense3(x)
-            x = self.dropout_layer(x, training=training)
-        if self.dense4 is not None:
-            x = self.dense4(x)
-            x = self.dropout_layer(x, training=training)
-        if self.dense5 is not None:
-            x = self.dense5(x)
-            x = self.dropout_layer(x, training=training)
-
-        x = self.output1(x)
-        return self.rshp(x)
-
-    def model(self):
-        x = tf.keras.Input(shape=(self.n_components,))
-        return tf.keras.Model(inputs=[x], outputs=self.call(x))
-
-    def set_model_outputs(self):
-        x = tf.keras.Input(shape=(self.n_components,))
-        model = tf.keras.Model(inputs=[x], outputs=self.call(x))
-        self.outputs = model.outputs
-
-    def train_step(self, data):
-        """Train step function. Parameters are set in UBPCallbacks callback."""
-        y = self._y
-
-        (
-            v,
-            y_true,
-            sample_weight,
-            missing_mask,
-            batch_start,
-            batch_end,
-        ) = self.nn.prepare_training_batches(
-            self.V_latent_,
-            y,
-            self._batch_size,
-            self._batch_idx,
-            True,
-            self.n_components,
-            self._sample_weight,
-            self._missing_mask,
-        )
+        self.n_features = n_features
+        self.num_classes = num_classes
+        self.latent_dim = latent_dim
+        self.gamma = gamma
+        self.device = device
 
-        src = [v]
+        if isinstance(hidden_layer_sizes, np.ndarray):
+            hidden_layer_sizes = hidden_layer_sizes.tolist()
 
-        if sample_weight is not None:
-            sample_weight_masked = tf.convert_to_tensor(
-                sample_weight[~missing_mask], dtype=tf.float32
-            )
-        else:
-            sample_weight_masked = None
-
-        y_true_masked = tf.boolean_mask(
-            tf.convert_to_tensor(y_true, dtype=tf.float32),
-            tf.reduce_any(tf.not_equal(y_true, -1), axis=2),
-        )
+        layers = []
+        input_dim = latent_dim
+        for size in hidden_layer_sizes:
+            layers.append(nn.Linear(input_dim, size))
+            layers.append(nn.BatchNorm1d(size))
+            layers.append(nn.Dropout(dropout_rate))
+            layers.append(self._resolve_activation(activation))
+            input_dim = size
 
-        # NOTE: Earlier model architectures incorrectly
-        # applied one gradient to all the variables, including
-        # the weights and v. Here we apply them separately, per
-        # the UBP manuscript.
-        with tf.GradientTape(persistent=True) as tape:
-            # Forward pass. Watch input tensor v.
-            tape.watch(v)
-            y_pred = self(v, training=True)
-            y_pred_masked = tf.boolean_mask(
-                y_pred, tf.reduce_any(tf.not_equal(y_true, -1), axis=2)
-            )
-            ### NOTE: If you get the error, "'tuple' object has no attribute
-            ### 'rank'", then convert y_true to a tensor object."
-            loss = self.compiled_loss(
-                y_true_masked,
-                y_pred_masked,
-                sample_weight=sample_weight_masked,
-                regularization_losses=self.losses,
-            )
-
-        # Refine the watched variables with
-        # gradient descent backpropagation
-        gradients = tape.gradient(loss, self.trainable_variables)
-        self.optimizer.apply_gradients(
-            zip(gradients, self.trainable_variables)
-        )
+        # Final layer output size is now n_features * num_classes
+        final_output_size = self.n_features * self.num_classes
+        layers.append(nn.Linear(hidden_layer_sizes[-1], final_output_size))
 
-        # Apply separate gradients to v.
-        vgrad = tape.gradient(loss, src)
-        self.optimizer.apply_gradients(zip(vgrad, src))
+        self.phase23_decoder = nn.Sequential(*layers)
 
-        del tape
+        # Reshape tuple reflects the output structure
+        self.reshape = (self.n_features, self.num_classes)
 
-        ### NOTE: If you get the error, "'tuple' object has no attribute
-        ### 'rank', then convert y_true to a tensor object."
-        self.compiled_metrics.update_state(
-            y_true_masked,
-            y_pred_masked,
-            sample_weight=sample_weight_masked,
-        )
+    def _resolve_activation(
+        self, activation: Literal["relu", "elu", "selu", "leaky_relu"]
+    ) -> nn.Module:
+        """Resolves an activation function from a string name.
 
-        # NOTE: run_eagerly must be set to True in the compile() method for this
-        # to work. Otherwise it can't convert a Tensor object to a numpy array.
-        # There is really no other way to set v back to V_latent_ in graph
-        # mode as far as I know. eager execution is slower, so it would be nice
-        # to find a way to do this without converting to numpy.
-        self.V_latent_[batch_start:batch_end, :] = v.numpy()
+        This method acts as a factory, returning the correct PyTorch activation function module based on the provided name.
 
-        # history object that gets returned from fit().
-        return {m.name: m.result() for m in self.metrics}
+        Args:
+            activation (Literal["relu", "elu", "selu", "leaky_relu"]): The name of the activation function.
 
-    @property
-    def V_latent(self):
-        """Randomly initialized input that gets refined during training.
-        :noindex:
-        """
-        return self.V_latent_
+        Returns:
+            nn.Module: The corresponding PyTorch activation function module.
 
-    @property
-    def batch_size(self):
-        """Batch (=step) size per epoch.
-        :noindex:
+        Raises:
+            ValueError: If the provided activation name is not supported.
         """
-        return self._batch_size
+        act: str = activation.lower()
+
+        if act == "relu":
+            return nn.ReLU()
+        elif act == "elu":
+            return nn.ELU()
+        elif act == "leaky_relu":
+            return nn.LeakyReLU()
+        elif act == "selu":
+            return nn.SELU()
+        else:
+            msg = f"Activation function {act} not supported."
+            self.logger.error(msg)
+            raise ValueError(msg)
 
-    @property
-    def batch_idx(self):
-        """Current batch (=step) index.
-        :noindex:
-        """
-        return self._batch_idx
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Performs the forward pass of the model.
 
-    @property
-    def y(self):
-        """Full dataset y.
-        :noindex:
-        """
-        return self._y
+        The input tensor is passed through the decoder network to produce logits,
+        which are reshaped to align with the locus-by-class grid used by the loss.
 
-    @property
-    def missing_mask(self):
-        """Missing mask of shape (y.shape[0], y.shape[1])
-        :noindex:
-        """
-        return self._missing_mask
+        Args:
+            x (torch.Tensor): The input tensor, which should represent the latent space vector.
 
-    @property
-    def sample_weight(self):
-        """Sample weights of shape (y.shape[0], y.shape[1])
-        :noindex:
+        Returns:
+            torch.Tensor: The reconstructed output tensor of shape `(batch_size, n_features, num_classes)`.
         """
-        return self._sample_weight
+        x = self.phase23_decoder(x)
 
-    @V_latent.setter
-    def V_latent(self, value):
-        """Set randomly initialized input. Refined during training.
-        :noindex:
-        """
-        self.V_latent_ = value
+        # Reshape to (batch, features, num_classes)
+        return x.view(-1, *self.reshape)
 
-    @batch_size.setter
-    def batch_size(self, value):
-        """Set batch_size parameter.
-        :noindex:
+    def compute_loss(
+        self,
+        y: torch.Tensor,
+        outputs: torch.Tensor,
+        mask: torch.Tensor | None = None,
+        class_weights: torch.Tensor | None = None,
+        gamma: float = 2.0,
+    ) -> torch.Tensor:
+        """Computes the masked focal loss between model outputs and ground truth.
+
+        This method calculates the loss value, handling class imbalance with weights and ignoring masked (missing) values.
+
+        Args:
+            y (torch.Tensor): Integer ground-truth genotypes of shape `(batch_size, n_features)`.
+            outputs (torch.Tensor): Logits of shape `(batch_size, n_features, num_classes)`.
+            mask (torch.Tensor | None): An optional boolean mask indicating which elements should be included in the loss calculation. Defaults to None.
+            class_weights (torch.Tensor | None): An optional tensor of weights for each class to address imbalance. Defaults to None.
+            gamma (float): The focusing parameter for the focal loss. Defaults to 2.0.
+
+        Returns:
+            torch.Tensor: The computed scalar loss value.
         """
-        self._batch_size = int(value)
+        if class_weights is None:
+            class_weights = torch.ones(self.num_classes, device=outputs.device)
 
-    @batch_idx.setter
-    def batch_idx(self, value):
-        """Set current batch (=step) index.
-        :noindex:
-        """
-        self._batch_idx = int(value)
+        if mask is None:
+            mask = torch.ones_like(y, dtype=torch.bool)
 
-    @y.setter
-    def y(self, value):
-        """Set y after each epoch.
-        :noindex:
-        """
-        self._y = value
+        # Explicitly flatten all tensors to the (N, C) and (N,) format.
+        # This creates a clear contract with the new MaskedFocalLoss function.
+        n_classes = outputs.shape[-1]
+        logits_flat = outputs.reshape(-1, n_classes)
+        targets_flat = y.reshape(-1)
+        mask_flat = mask.reshape(-1)
 
-    @missing_mask.setter
-    def missing_mask(self, value):
-        """Set missing_mask after each epoch.
-        :noindex:
-        """
-        self._missing_mask = value
+        criterion = MaskedFocalLoss(gamma=gamma, alpha=class_weights)
 
-    @sample_weight.setter
-    def sample_weight(self, value):
-        """Set sample_weight after each epoch.
-        :noindex:
-        """
-        self._sample_weight = value
+        return criterion(
+            logits_flat.to(self.device),
+            targets_flat.to(self.device),
+            valid_mask=mask_flat.to(self.device),
+        )
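
For orientation, below is a minimal usage sketch of the new PyTorch NLPCAModel decoder shown in the diff above. It is not taken from the package itself: the import path pgsui.impute.unsupervised.models.nlpca_model is inferred from the file list, the sketch assumes pg-sui and its snpio dependency are installed, and the latent vectors, genotype targets, and mask are random placeholders. It only exercises the constructor, forward, and compute_loss signatures visible in the added lines.

import torch

# Import path assumed from the file list above; adjust if the package layout differs.
from pgsui.impute.unsupervised.models.nlpca_model import NLPCAModel

n_samples, n_features, latent_dim, num_classes = 8, 100, 2, 3

# Decoder mapping latent vectors to per-locus genotype logits.
model = NLPCAModel(
    n_features=n_features,
    prefix="example",              # only used for logger naming
    num_classes=num_classes,
    hidden_layer_sizes=[128, 64],
    latent_dim=latent_dim,
    dropout_rate=0.2,
    activation="relu",
    gamma=2.0,
    device="cpu",
)

# Randomly initialized latent inputs (illustrative); requires_grad=True lets
# gradients flow back into them, mirroring how NLPCA refines its inputs.
z = torch.randn(n_samples, latent_dim, requires_grad=True)

# Dummy integer genotype targets and a boolean mask of observed entries.
y = torch.randint(0, num_classes, (n_samples, n_features))
mask = torch.rand(n_samples, n_features) > 0.2

logits = model(z)  # shape: (n_samples, n_features, num_classes)
loss = model.compute_loss(y, logits, mask=mask, gamma=2.0)
loss.backward()    # gradients reach both the decoder weights and z

Compared with the removed TensorFlow version, the model no longer owns batching, latent-input bookkeeping, or a custom train_step; those responsibilities appear to move into the imputer classes added elsewhere in this release (e.g. pgsui/impute/unsupervised/imputers/nlpca.py).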