pg-sui 1.0.2.1__py3-none-any.whl → 1.6.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (112)
  1. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/METADATA +51 -70
  2. pg_sui-1.6.8.dist-info/RECORD +78 -0
  3. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.8.dist-info/entry_points.txt +4 -0
  5. pg_sui-1.6.8.dist-info/top_level.txt +1 -0
  6. pgsui/__init__.py +35 -54
  7. pgsui/_version.py +34 -0
  8. pgsui/cli.py +635 -0
  9. pgsui/data_processing/config.py +576 -0
  10. pgsui/data_processing/containers.py +1782 -0
  11. pgsui/data_processing/transformers.py +121 -1103
  12. pgsui/electron/app/__main__.py +5 -0
  13. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  14. pgsui/electron/app/icons/icons/128x128.png +0 -0
  15. pgsui/electron/app/icons/icons/16x16.png +0 -0
  16. pgsui/electron/app/icons/icons/24x24.png +0 -0
  17. pgsui/electron/app/icons/icons/256x256.png +0 -0
  18. pgsui/electron/app/icons/icons/32x32.png +0 -0
  19. pgsui/electron/app/icons/icons/48x48.png +0 -0
  20. pgsui/electron/app/icons/icons/512x512.png +0 -0
  21. pgsui/electron/app/icons/icons/64x64.png +0 -0
  22. pgsui/electron/app/icons/icons/icon.icns +0 -0
  23. pgsui/electron/app/icons/icons/icon.ico +0 -0
  24. pgsui/electron/app/main.js +189 -0
  25. pgsui/electron/app/package-lock.json +6893 -0
  26. pgsui/electron/app/package.json +50 -0
  27. pgsui/electron/app/preload.js +15 -0
  28. pgsui/electron/app/server.py +146 -0
  29. pgsui/electron/app/ui/logo.png +0 -0
  30. pgsui/electron/app/ui/renderer.js +130 -0
  31. pgsui/electron/app/ui/styles.css +59 -0
  32. pgsui/electron/app/ui/ui_shim.js +72 -0
  33. pgsui/electron/bootstrap.py +43 -0
  34. pgsui/electron/launch.py +59 -0
  35. pgsui/electron/package.json +14 -0
  36. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  37. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  38. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  39. pgsui/impute/deterministic/imputers/allele_freq.py +691 -0
  40. pgsui/impute/deterministic/imputers/mode.py +679 -0
  41. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  42. pgsui/impute/deterministic/imputers/phylo.py +971 -0
  43. pgsui/impute/deterministic/imputers/ref_allele.py +530 -0
  44. pgsui/impute/supervised/base.py +339 -0
  45. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +293 -0
  46. pgsui/impute/supervised/imputers/random_forest.py +287 -0
  47. pgsui/impute/unsupervised/base.py +924 -0
  48. pgsui/impute/unsupervised/callbacks.py +89 -263
  49. pgsui/impute/unsupervised/imputers/autoencoder.py +972 -0
  50. pgsui/impute/unsupervised/imputers/nlpca.py +1264 -0
  51. pgsui/impute/unsupervised/imputers/ubp.py +1288 -0
  52. pgsui/impute/unsupervised/imputers/vae.py +957 -0
  53. pgsui/impute/unsupervised/loss_functions.py +158 -0
  54. pgsui/impute/unsupervised/models/autoencoder_model.py +208 -558
  55. pgsui/impute/unsupervised/models/nlpca_model.py +149 -468
  56. pgsui/impute/unsupervised/models/ubp_model.py +198 -1317
  57. pgsui/impute/unsupervised/models/vae_model.py +259 -618
  58. pgsui/impute/unsupervised/nn_scorers.py +215 -0
  59. pgsui/utils/classification_viz.py +591 -0
  60. pgsui/utils/misc.py +35 -480
  61. pgsui/utils/plotting.py +514 -824
  62. pgsui/utils/scorers.py +212 -438
  63. pg_sui-1.0.2.1.dist-info/RECORD +0 -75
  64. pg_sui-1.0.2.1.dist-info/top_level.txt +0 -3
  65. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  66. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  67. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  68. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  69. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  70. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  71. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  72. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  73. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  74. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  75. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  76. pgsui/example_data/trees/test.iqtree +0 -376
  77. pgsui/example_data/trees/test.qmat +0 -5
  78. pgsui/example_data/trees/test.rate +0 -2033
  79. pgsui/example_data/trees/test.tre +0 -1
  80. pgsui/example_data/trees/test_n10.rate +0 -19
  81. pgsui/example_data/trees/test_n100.rate +0 -109
  82. pgsui/example_data/trees/test_n500.rate +0 -509
  83. pgsui/example_data/trees/test_siterates.txt +0 -2024
  84. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  85. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  86. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  87. pgsui/example_data/vcf_files/test.vcf +0 -244
  88. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  89. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  90. pgsui/impute/estimators.py +0 -735
  91. pgsui/impute/impute.py +0 -1486
  92. pgsui/impute/simple_imputers.py +0 -1439
  93. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -785
  94. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1027
  95. pgsui/impute/unsupervised/keras_classifiers.py +0 -702
  96. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  97. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1424
  98. pgsui/impute/unsupervised/neural_network_methods.py +0 -1549
  99. pgsui/pg_sui.py +0 -261
  100. pgsui/utils/sequence_tools.py +0 -407
  101. simulation/sim_benchmarks.py +0 -333
  102. simulation/sim_treeparams.py +0 -475
  103. test/__init__.py +0 -0
  104. test/pg_sui_simtest.py +0 -215
  105. test/pg_sui_testing.py +0 -523
  106. test/test.py +0 -297
  107. test/test_pgsui.py +0 -374
  108. test/test_tkc.py +0 -214
  109. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info/licenses}/LICENSE +0 -0
  110. /pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  111. /pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  112. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
pgsui/impute/unsupervised/models/nlpca_model.py
@@ -1,523 +1,204 @@
- import logging
- import os
- import sys
- import warnings
+ from typing import List, Literal

- # Import tensorflow with reduced warnings.
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
- logging.getLogger("tensorflow").disabled = True
- warnings.filterwarnings("ignore", category=UserWarning)
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from snpio.utils.logging import LoggerManager

- # noinspection PyPackageRequirements
- import tensorflow as tf
+ from pgsui.impute.unsupervised.loss_functions import MaskedFocalLoss

- # Disable can't find cuda .dll errors. Also turns off GPU support.
- tf.config.set_visible_devices([], "GPU")

- from tensorflow.python.util import deprecation
+ class NLPCAModel(nn.Module):
+ r"""A non-linear Principal Component Analysis (NLPCA) model.

- # Disable warnings and info logs.
- tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
- tf.get_logger().setLevel(logging.ERROR)
+ This model serves as a decoder for an autoencoder-based imputation strategy. It's a deep neural network that takes a low-dimensional latent vector as input and reconstructs the high-dimensional allele data. The architecture is a multi-layered, fully connected network with optional batch normalization and dropout layers. The model is specifically designed for two-channel allele data, predicting allele probabilities for each of the two channels at every SNP.

+ **Model Architecture**

- # Monkey patching deprecation utils to supress warnings.
- # noinspection PyUnusedLocal
- def deprecated(
- date, instructions, warn_once=True
- ): # pylint: disable=unused-argument
- def deprecated_wrapper(func):
- return func
+ The model's forward pass, from a latent representation :math:`z` to the reconstructed input :math:`\hat{x}`, can be described as follows.

- return deprecated_wrapper
+ Let :math:`z \in \mathbb{R}^{d_{\text{latent}}}` be the latent vector.

+ The decoder consists of a series of fully connected layers with activation functions, batch normalization, and dropout. For a network with :math:`L` hidden layers, the transformations are:

- deprecation.deprecated = deprecated
+ .. math::

- from tensorflow.keras.layers import (
- Dropout,
- Dense,
- Reshape,
- LeakyReLU,
- PReLU,
- Activation,
- )
+ h_1 = f(W_1 z + b_1)

- from tensorflow.keras.regularizers import l1_l2
+ .. math::

- # Custom Modules
- try:
- from ..neural_network_methods import NeuralNetworkMethods
- except (ModuleNotFoundError, ValueError, ImportError):
- from impute.unsupervised.neural_network_methods import NeuralNetworkMethods
+ h_2 = f(W_2 h_1 + b_2)

+ .. math::

- class NLPCAModel(tf.keras.Model):
- """NLPCA model to train and use to predict imputations.
+ \vdots

- NLPCAModel subclasses the tf.keras.Model and overrides the train_step function, which does training and evaluation for each batch in each epoch.
+ .. math::

- Args:
- V (numpy.ndarray(float)): V should have been randomly initialized and will be used as the input data that gets refined during training. Defaults to None.
+ h_L = f(W_L h_{L-1} + b_L)

- y (numpy.ndarray): Target values to predict. Actual input data. Defaults to None.
+ The final output layer produces a tensor of shape ``(batch_size, n_features, n_channels, n_classes)``:

- batch_size (int, optional): Batch size per epoch. Defaults to 32.
+ .. math::

- missing_mask (numpy.ndarray): Missing data mask for y. Defaults to None.
+ \hat{x} = W_{L+1} h_L + b_{L+1}

- output_shape (int): Output units for n_features dimension. Output will be of shape (batch_size, n_features). Defaults to None.
+ where :math:`f(\cdot)` is the activation function (e.g., ReLU, ELU), and :math:`W_i` and :math:`b_i` are the weights and biases of each layer.

- n_components (int, optional): Number of features in input V to use. Defaults to 3.
-
- weights_initializer (str, optional): Kernel initializer to use for initializing model weights. Defaults to "glorot_normal".
-
- hidden_layer_sizes (List[int], optional): Output units for each hidden layer. List should be of same length as the number of hidden layers. Defaults to "midpoint".
-
- num_hidden_layers (int, optional): Number of hidden layers to use. Defaults to 1.
-
- hidden_activation (str, optional): Activation function to use for hidden layers. Defaults to "elu".
-
- l1_penalty (float, optional): L1 regularization penalty to use to reduce overfitting. Defaults to 0.01.
-
- l2_penalty (float, optional): L2 regularization penalty to use to reduce overfitting. Defaults to 0.01.
-
- dropout_rate (float, optional): Dropout rate during training to reduce overfitting. Must be a float between 0 and 1. Defaults to 0.2.
-
- num_classes (int, optional): Number of classes in output. Corresponds to the 3rd dimension of the output shape (batch_size, n_features, num_classes). Defaults to 1.
-
- phase (NoneType): Here for compatibility with UBP.
-
- sample_weight (numpy.ndarray, optional): 2D sample weights of shape (n_samples, n_features). Should have values for each class weighted. Defaults to None.
-
- Example:
- >>>model = NLPCAModel(V=V, y=y, batch_size=32, missing_mask=missing_mask, output_shape, n_components, weights_initializer, hidden_layer_sizes, num_hidden_layers, hidden_activation, l1_penalty, l2_penalty, dropout_rate, num_classes=3)
- >>>
- >>>model.compile(optimizer=optimizer, loss=loss_func, metrics=[my_metrics], run_eagerly=True)
- >>>
- >>>history = model.fit(X, y, batch_size=batch_size, epochs=epochs, callbacks=[MyCallback()], validation_split=validation_split, shuffle=False)
-
- Raises:
- TypeError: V, y, missing_mask, output_shape must not be NoneType.
- ValueError: Maximum of 5 hidden layers.
+ **Loss Function**

+ The model is trained by minimizing the ``MaskedFocalLoss``, an extension of cross-entropy that emphasizes hard-to-classify examples and handles missing values via a mask. The loss is computed on the reconstructed output and the ground truth, using a mask to include only observed data.
  """

  def __init__(
  self,
- V=None,
- y=None,
- batch_size=32,
- missing_mask=None,
- output_shape=None,
- n_components=3,
- weights_initializer="glorot_normal",
- hidden_layer_sizes="midpoint",
- num_hidden_layers=1,
- hidden_activation="elu",
- l1_penalty=0.01,
- l2_penalty=0.01,
- dropout_rate=0.2,
- num_classes=3,
- phase=None,
- sample_weight=None,
+ n_features: int,
+ prefix: str,
+ *,
+ num_classes: int = 4,
+ hidden_layer_sizes: List[int] | np.ndarray = [128, 64],
+ latent_dim: int = 2,
+ dropout_rate: float = 0.2,
+ activation: Literal["relu", "elu", "selu", "leaky_relu"] = "relu",
+ gamma: float = 2.0,
+ device: Literal["gpu", "cpu", "mps"] = "cpu",
+ verbose: bool = False,
+ debug: bool = False,
  ):
+ """Initializes the NLPCAModel.
69
+
70
+ Args:
71
+ n_features (int): The number of features (SNPs) in the input data.
72
+ prefix (str): A prefix used for logging.
73
+ num_classes (int): The number of possible allele classes (e.g., 4 for A, T, C, G). Defaults to 4.
74
+ hidden_layer_sizes (list[int] | np.ndarray): A list of integers specifying the number of units in each hidden layer. Defaults to [128, 64].
75
+ latent_dim (int): The dimensionality of the latent space (the size of the bottleneck layer). Defaults to 2.
76
+ dropout_rate (float): The dropout rate applied to each hidden layer for regularization. Defaults to 0.2.
77
+ activation (Literal["relu", "elu", "selu", "leaky_relu"]): The non-linear activation function to use in hidden layers. Defaults to 'relu'.
78
+ gamma (float): The focusing parameter for the focal loss function, which down-weights well-classified examples. Defaults to 2.0.
79
+ device (Literal["gpu", "cpu", "mps"]): The PyTorch device to run the model on. Defaults to 'cpu'.
80
+ verbose (bool): If True, enables detailed logging. Defaults to False.
81
+ debug (bool): If True, enables debug mode. Defaults to False.
82
+ """
125
83
  super(NLPCAModel, self).__init__()
126
84
 
127
- self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
128
- self.binary_accuracy_tracker = tf.keras.metrics.Mean(
129
- name="binary_accuracy"
130
- )
131
-
132
- nn = NeuralNetworkMethods()
133
- self.nn = nn
134
-
135
- if V is None:
136
- self._V = nn.init_weights(y.shape[0], n_components)
137
- elif isinstance(V, dict):
138
- self._V = V[n_components]
139
- else:
140
- self._V = V
141
-
142
- self._y = y
143
-
144
- hidden_layer_sizes = nn.validate_hidden_layers(
145
- hidden_layer_sizes, num_hidden_layers
146
- )
147
-
148
- hidden_layer_sizes = nn.get_hidden_layer_sizes(
149
- y.shape[1], self._V.shape[1], hidden_layer_sizes
150
- )
151
-
152
- nn.validate_model_inputs(y, missing_mask, output_shape)
153
-
154
- self._missing_mask = missing_mask
155
- self.weights_initializer = weights_initializer
156
- self.phase = phase
157
- self.dropout_rate = dropout_rate
158
- self._sample_weight = sample_weight
159
-
160
- ### NOTE: I tried using just _V as the input to be refined, but it
- # wasn't getting updated. So I copy it here and it works.
- # V_latent is refined during train_step.
- self.V_latent_ = self._V.copy()
-
- # Initialize parameters used during train_step.
- self._batch_idx = 0
- self._batch_size = batch_size
- self.n_components = n_components
-
- if l1_penalty == 0.0 and l2_penalty == 0.0:
- kernel_regularizer = None
- else:
- kernel_regularizer = l1_l2(l1_penalty, l2_penalty)
-
- self.kernel_regularizer = kernel_regularizer
- kernel_initializer = weights_initializer
-
- if hidden_activation.lower() == "leaky_relu":
- activation = LeakyReLU(alpha=0.01)
-
- elif hidden_activation.lower() == "prelu":
- activation = PReLU()
-
- elif hidden_activation.lower() == "selu":
- activation = "selu"
- kernel_initializer = "lecun_normal"
-
- else:
- activation = hidden_activation
-
- if num_hidden_layers > 5:
- raise ValueError(
- f"The maximum number of hidden layers is 5, but got "
- f"{num_hidden_layers}"
- )
-
- self.dense2 = None
- self.dense3 = None
- self.dense4 = None
- self.dense5 = None
-
- # Construct multi-layer perceptron.
- # Add hidden layers dynamically.
- self.dense1 = Dense(
- hidden_layer_sizes[0],
- input_shape=(n_components,),
- activation=activation,
- kernel_initializer=kernel_initializer,
- kernel_regularizer=kernel_regularizer,
+ logman = LoggerManager(
+ name=__name__, prefix=prefix, verbose=verbose, debug=debug
  )
+ self.logger = logman.get_logger()

- if num_hidden_layers >= 2:
- self.dense2 = Dense(
- hidden_layer_sizes[1],
- activation=activation,
- kernel_initializer=kernel_initializer,
- kernel_regularizer=kernel_regularizer,
- )
-
- if num_hidden_layers >= 3:
- self.dense3 = Dense(
- hidden_layer_sizes[2],
- activation=activation,
- kernel_initializer=kernel_initializer,
- kernel_regularizer=kernel_regularizer,
- )
-
- if num_hidden_layers >= 4:
- self.dense4 = Dense(
- hidden_layer_sizes[3],
- activation=activation,
- kernel_initializer=kernel_initializer,
- kernel_regularizer=kernel_regularizer,
- )
-
- if num_hidden_layers == 5:
- self.dense5 = Dense(
- hidden_layer_sizes[4],
- activation=activation,
- kernel_initializer=kernel_initializer,
- kernel_regularizer=kernel_regularizer,
- )
-
- self.output1 = Dense(
- output_shape * num_classes,
- kernel_initializer=kernel_initializer,
- kernel_regularizer=kernel_regularizer,
- )
+ self.n_features = n_features
+ self.num_classes = num_classes
+ self.latent_dim = latent_dim
+ self.gamma = gamma
+ self.device = device

- self.rshp = Reshape((output_shape, num_classes))
-
- self.dropout_layer = Dropout(rate=dropout_rate)
-
- self.activation = Activation("sigmoid")
-
- def call(self, inputs, training=None):
- x = self.dense1(inputs)
- x = self.dropout_layer(x, training=training)
- if self.dense2 is not None:
- x = self.dense2(x)
- x = self.dropout_layer(x, training=training)
- if self.dense3 is not None:
- x = self.dense3(x)
- x = self.dropout_layer(x, training=training)
- if self.dense4 is not None:
- x = self.dense4(x)
- x = self.dropout_layer(x, training=training)
- if self.dense5 is not None:
- x = self.dense5(x)
- x = self.dropout_layer(x, training=training)
-
- x = self.output1(x)
- x = self.rshp(x)
- return self.activation(x)
-
- def model(self):
- x = tf.keras.Input(shape=(self.n_components,))
- x = tf.keras.Model(inputs=[x], outputs=self.call(x))
-
- def set_model_outputs(self):
- x = tf.keras.Input(shape=(self.n_components,))
- model = tf.keras.Model(inputs=[x], outputs=self.call(x))
- self.outputs = model.outputs
-
- @property
- def metrics(self):
- """Set metric trackers."""
- return [
- self.total_loss_tracker,
- self.binary_accuracy_tracker,
- ]
-
- def train_step(self, data):
- """Train step function. Parameters are set in UBPCallbacks callback."""
- y = self._y
-
- (
- v,
- y_true,
- sample_weight,
- missing_mask,
- batch_start,
- batch_end,
- ) = self.nn.prepare_training_batches(
- self.V_latent_,
- y,
- self._batch_size,
- self._batch_idx,
- True,
- self.n_components,
- self._sample_weight,
- self._missing_mask,
- )
+ if isinstance(hidden_layer_sizes, np.ndarray):
+ hidden_layer_sizes = hidden_layer_sizes.tolist()

- src = [v]
+ layers = []
+ input_dim = latent_dim
+ for size in hidden_layer_sizes:
+ layers.append(nn.Linear(input_dim, size))
+ layers.append(nn.BatchNorm1d(size))
+ layers.append(nn.Dropout(dropout_rate))
+ layers.append(self._resolve_activation(activation))
+ input_dim = size

- if sample_weight is not None:
- sample_weight_masked = tf.convert_to_tensor(
- sample_weight[~missing_mask], dtype=tf.float32
- )
- else:
- sample_weight_masked = None
+ # Final layer output size is now n_features * num_classes
+ final_output_size = self.n_features * self.num_classes
+ layers.append(nn.Linear(hidden_layer_sizes[-1], final_output_size))

- y_true_masked = tf.boolean_mask(
- tf.convert_to_tensor(y_true, dtype=tf.float32),
- tf.reduce_any(tf.not_equal(y_true, -1), axis=2),
- )
+ self.phase23_decoder = nn.Sequential(*layers)

- # NOTE: Earlier model architectures incorrectly
- # applied one gradient to all the variables, including
- # the weights and v. Here we apply them separately, per
- # the UBP manuscript.
- with tf.GradientTape(persistent=True) as tape:
- # Forward pass. Watch input tensor v.
- tape.watch(v)
- y_pred = self(v, training=True)
- y_pred_masked = tf.boolean_mask(
- y_pred, tf.reduce_any(tf.not_equal(y_true, -1), axis=2)
- )
- ### NOTE: If you get the error, "'tuple' object has no attribute
- ### 'rank'", then convert y_true to a tensor object."
- loss = self.compiled_loss(
- y_true_masked,
- y_pred_masked,
- sample_weight=sample_weight_masked,
- regularization_losses=self.losses,
- )
-
- # Refine the watched variables with
- # gradient descent backpropagation
- gradients = tape.gradient(loss, self.trainable_variables)
- self.optimizer.apply_gradients(
- zip(gradients, self.trainable_variables)
- )
-
- # Apply separate gradients to v.
- vgrad = tape.gradient(loss, src)
- self.optimizer.apply_gradients(zip(vgrad, src))
+ # Reshape tuple reflects the output structure
+ self.reshape = (self.n_features, self.num_classes)

- del tape
+ def _resolve_activation(
+ self, activation: Literal["relu", "elu", "selu", "leaky_relu"]
+ ) -> nn.Module:
+ """Resolves an activation function from a string name.

- # NOTE: run_eagerly must be set to True in the compile() method for this
- # to work. Otherwise it can't convert a Tensor object to a numpy array.
- # There is really no other way to set v back to V_latent_ in graph
- # mode as far as I know. eager execution is slower, so it would be nice
- # to find a way to do this without converting to numpy.
- self.V_latent_[batch_start:batch_end, :] = v.numpy()
+ This method acts as a factory, returning the correct PyTorch activation function module based on the provided name.

- ### NOTE: If you get the error, "'tuple' object has no attribute
- ### 'rank', then convert y_true to a tensor object."
- self.total_loss_tracker.update_state(loss)
- self.binary_accuracy_tracker.update_state(
- tf.keras.metrics.binary_accuracy(y_true_masked, y_pred_masked)
- )
+ Args:
+ activation (Literal["relu", "elu", "selu", "leaky_relu"]): The name of the activation function.

- return {
- "loss": self.total_loss_tracker.result(),
- "binary_accuracy": self.binary_accuracy_tracker.result(),
- }
-
- def test_step(self, data):
- """Test step function. Parameters are set in UBPCallbacks callback."""
- y = self._y
-
- (
- v,
- y_true,
- sample_weight,
- missing_mask,
- batch_start,
- batch_end,
- ) = self.nn.prepare_training_batches(
- self.V_latent_,
- y,
- self._batch_size,
- self._batch_idx,
- True,
- self.n_components,
- self._sample_weight,
- self._missing_mask,
- )
-
- if sample_weight is not None:
- sample_weight_masked = tf.convert_to_tensor(
- sample_weight[~missing_mask], dtype=tf.float32
- )
- else:
- sample_weight_masked = None
-
- y_true_masked = tf.boolean_mask(
- tf.convert_to_tensor(y_true, dtype=tf.float32),
- tf.reduce_any(tf.not_equal(y_true, -1), axis=2),
- )
-
- y_pred = self(v, training=False)
- y_pred_masked = tf.boolean_mask(
- y_pred, tf.reduce_any(tf.not_equal(y_true, -1), axis=2)
- )
-
- ### NOTE: If you get the error, "'tuple' object has no attribute
- ### 'rank'", then convert y_true to a tensor object."
- loss = self.compiled_loss(
- y_true_masked,
- y_pred_masked,
- sample_weight=sample_weight_masked,
- regularization_losses=self.losses,
- )
+ Returns:
+ nn.Module: The corresponding PyTorch activation function module.

- ### NOTE: If you get the error, "'tuple' object has no attribute
- ### 'rank', then convert y_true to a tensor object."
- self.total_loss_tracker.update_state(loss)
- self.binary_accuracy_tracker.update_state(
- tf.keras.metrics.binary_accuracy(y_true_masked, y_pred_masked)
- )
-
- return {
- "loss": self.total_loss_tracker.result(),
- "binary_accuracy": self.binary_accuracy_tracker.result(),
- }
-
- @property
- def V_latent(self):
- """Randomly initialized input that gets refined during training.
- :noindex:
- """
- return self.V_latent_
-
- @property
- def batch_size(self):
- """Batch (=step) size per epoch.
- :noindex:
+ Raises:
+ ValueError: If the provided activation name is not supported.
  """
- return self._batch_size
+ activation = activation.lower()
+ if activation == "relu":
+ return nn.ReLU()
+ elif activation == "elu":
+ return nn.ELU()
+ elif activation == "leaky_relu":
+ return nn.LeakyReLU()
+ elif activation == "selu":
+ return nn.SELU()
+ else:
+ msg = f"Activation function {activation} not supported."
+ self.logger.error(msg)
+ raise ValueError(msg)

- @property
- def batch_idx(self):
- """Current batch (=step) index.
- :noindex:
- """
- return self._batch_idx
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Performs the forward pass of the model.

- @property
- def y(self):
- """Full dataset y.
- :noindex:
- """
- return self._y
+ The input tensor is passed through the decoder network to produce the reconstructed output. The output is then reshaped to a 4D tensor representing batches, features, channels, and classes.

- @property
- def missing_mask(self):
- """Missing mask of shape (y.shape[0], y.shape[1])
- :noindex:
- """
- return self._missing_mask
+ Args:
+ x (torch.Tensor): The input tensor, which should represent the latent space vector.

- @property
- def sample_weight(self):
- """Sample weights of shape (y.shape[0], y.shape[1])
- :noindex:
+ Returns:
+ torch.Tensor: The reconstructed output tensor of shape `(batch_size, n_features, n_channels, n_classes)`.
  """
- return self._sample_weight
+ x = self.phase23_decoder(x)

- @V_latent.setter
- def V_latent(self, value):
- """Set randomly initialized input. Refined during training.
- :noindex:
- """
- self.V_latent_ = value
+ # Reshape to (batch, features, channels, classes)
+ return x.view(-1, *self.reshape)

- @batch_size.setter
- def batch_size(self, value):
- """Set batch_size parameter.
- :noindex:
+ def compute_loss(
+ self,
+ y: torch.Tensor,
+ outputs: torch.Tensor,
+ mask: torch.Tensor | None = None,
+ class_weights: torch.Tensor | None = None,
+ gamma: float = 2.0,
+ ) -> torch.Tensor:
+ """Computes the masked focal loss between model outputs and ground truth.
+
+ This method calculates the loss value, handling class imbalance with weights and ignoring masked (missing) values.
+
+ Args:
+ y (torch.Tensor): The ground truth tensor of shape `(batch_size, n_features, n_channels)`.
+ outputs (torch.Tensor): The model's raw output (logits) of shape `(batch_size, n_features, n_channels, n_classes)`.
+ mask (torch.Tensor | None): An optional boolean mask indicating which elements should be included in the loss calculation. Defaults to None.
+ class_weights (torch.Tensor | None): An optional tensor of weights for each class to address imbalance. Defaults to None.
+ gamma (float): The focusing parameter for the focal loss. Defaults to 2.0.
+
+ Returns:
+ torch.Tensor: The computed scalar loss value.
  """
- self._batch_size = int(value)
+ if class_weights is None:
+ class_weights = torch.ones(self.num_classes, device=outputs.device)

- @batch_idx.setter
- def batch_idx(self, value):
- """Set current batch (=step) index.
- :noindex:
- """
- self._batch_idx = int(value)
+ if mask is None:
+ mask = torch.ones_like(y, dtype=torch.bool)

- @y.setter
- def y(self, value):
- """Set y after each epoch.
- :noindex:
- """
- self._y = value
+ # Explicitly flatten all tensors to the (N, C) and (N,) format.
+ # This creates a clear contract with the new MaskedFocalLoss function.
+ n_classes = outputs.shape[-1]
+ logits_flat = outputs.reshape(-1, n_classes)
+ targets_flat = y.reshape(-1)
+ mask_flat = mask.reshape(-1)

- @missing_mask.setter
- def missing_mask(self, value):
- """Set missing_mask after each epoch.
- :noindex:
- """
- self._missing_mask = value
+ criterion = MaskedFocalLoss(gamma=gamma, alpha=class_weights)

- @sample_weight.setter
- def sample_weight(self, value):
- """Set sample_weight after each epoch.
- :noindex:
- """
- self._sample_weight = value
+ return criterion(
+ logits_flat.to(self.device),
+ targets_flat.to(self.device),
+ valid_mask=mask_flat.to(self.device),
+ )
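
For orientation, the constructor, forward, and compute_loss signatures added in this diff suggest a usage pattern along the following lines. This is a minimal sketch inferred from the diff alone: it assumes pg-sui 1.6.8 and its snpio dependency are installed and that the import path matches the nlpca_model.py entry in the file list above; it is not taken from the package's documentation, and values such as the "example" prefix and the tensor sizes are illustrative only.

import torch

# Assumed import path, based on the file listing in this diff.
from pgsui.impute.unsupervised.models.nlpca_model import NLPCAModel

n_samples, n_features, latent_dim, num_classes = 8, 100, 2, 4

# Decoder that maps a latent vector to per-SNP class logits.
model = NLPCAModel(
    n_features=n_features,
    prefix="example",
    num_classes=num_classes,
    hidden_layer_sizes=[128, 64],
    latent_dim=latent_dim,
    dropout_rate=0.2,
    activation="relu",
    gamma=2.0,
    device="cpu",
)

# Randomly initialized latent inputs; in NLPCA these are refined during
# training alongside the decoder weights.
z = torch.randn(n_samples, latent_dim)
logits = model(z)  # (n_samples, n_features, num_classes) per the final reshape

# Integer-encoded genotype targets plus a boolean mask over observed entries.
y = torch.randint(0, num_classes, (n_samples, n_features))
mask = torch.rand(n_samples, n_features) > 0.2
loss = model.compute_loss(y, logits, mask=mask)
loss.backward()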