pg-sui 1.6.16a3-py3-none-any.whl → 1.7.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/METADATA +26 -30
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
- pgsui/__init__.py +0 -8
- pgsui/_version.py +2 -2
- pgsui/cli.py +577 -125
- pgsui/data_processing/config.py +1 -2
- pgsui/data_processing/containers.py +203 -530
- pgsui/data_processing/transformers.py +44 -20
- pgsui/impute/deterministic/imputers/mode.py +475 -182
- pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
- pgsui/impute/supervised/imputers/random_forest.py +3 -2
- pgsui/impute/unsupervised/base.py +1269 -534
- pgsui/impute/unsupervised/callbacks.py +28 -33
- pgsui/impute/unsupervised/imputers/autoencoder.py +870 -841
- pgsui/impute/unsupervised/imputers/vae.py +931 -787
- pgsui/impute/unsupervised/loss_functions.py +156 -202
- pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
- pgsui/impute/unsupervised/models/vae_model.py +40 -221
- pgsui/impute/unsupervised/nn_scorers.py +53 -13
- pgsui/utils/classification_viz.py +240 -97
- pgsui/utils/misc.py +201 -3
- pgsui/utils/plotting.py +73 -58
- pgsui/utils/pretty_metrics.py +2 -6
- pgsui/utils/scorers.py +39 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +0 -1666
- pgsui/impute/unsupervised/imputers/ubp.py +0 -1660
- pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
- pgsui/impute/unsupervised/models/ubp_model.py +0 -200
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
pgsui/impute/unsupervised/models/vae_model.py

@@ -1,46 +1,27 @@
-from
+from __future__ import annotations
 
-import
+import copy
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from
-
+from typing import List, Literal, Optional, Tuple, Union
+import numpy as np
 
-from
+from snpio.utils.logging import LoggerManager
 from pgsui.utils.logging_utils import configure_logger
 
 
 class Sampling(nn.Module):
-    """A layer that samples from a latent distribution using the reparameterization trick.
-
-    This layer is a core component of a Variational Autoencoder (VAE). It takes the mean and log-variance of a latent distribution as input and generates a sample from that distribution. By using the reparameterization trick ($z = \mu + \sigma \cdot \epsilon$), it allows gradients to be backpropagated through the random sampling process, making the VAE trainable.
-    """
+    """A layer that samples from a latent distribution using the reparameterization trick."""
 
     def forward(self, z_mean: torch.Tensor, z_log_var: torch.Tensor) -> torch.Tensor:
-
-
-        Args:
-            z_mean (torch.Tensor): The mean of the latent normal distribution.
-            z_log_var (torch.Tensor): The log of the variance of the latent normal distribution.
-
-        Returns:
-            torch.Tensor: A sampled vector from the latent space.
-        """
-        z_sigma = torch.exp(0.5 * z_log_var)  # Precompute outside
-
-        # Ensure on GPU
-        # rand_like takes random values from a normal distribution
-        # of the same shape as z_mean.
+        z_sigma = torch.exp(0.5 * z_log_var)
         epsilon = torch.randn_like(z_mean, device=z_mean.device)
         return z_mean + z_sigma * epsilon
 
 
 class Encoder(nn.Module):
-    """The Encoder module of a Variational Autoencoder (VAE).
-
-    This module defines the encoder network, which takes high-dimensional input data and maps it to the parameters of a lower-dimensional latent distribution. The architecture consists of a series of fully-connected hidden layers that process the flattened input. The network culminates in two separate linear layers that output the mean (`z_mean`) and log-variance (`z_log_var`) of the approximate posterior distribution, $q(z|x)$.
-    """
+    """The Encoder module of a Variational Autoencoder (VAE)."""
 
     def __init__(
         self,
@@ -51,33 +32,17 @@ class Encoder(nn.Module):
         dropout_rate: float,
         activation: torch.nn.Module,
     ):
-
-
-        Args:
-            n_features (int): The number of features in the input data (e.g., SNPs).
-            num_classes (int): Number of genotype states per locus (2 for haploid, 3 for diploid in practice).
-            latent_dim (int): The dimensionality of the latent space.
-            hidden_layer_sizes (List[int]): A list of integers specifying the size of each hidden layer.
-            dropout_rate (float): The dropout rate for regularization in the hidden layers.
-            activation (torch.nn.Module): An instantiated activation function module (e.g., `nn.ReLU()`) for the hidden layers.
-        """
-        super(Encoder, self).__init__()
+        super().__init__()
         self.flatten = nn.Flatten()
-        self.activation = (
-            getattr(F, activation) if isinstance(activation, str) else activation
-        )
 
         layers = []
-        # The input dimension accounts for channels
         input_dim = n_features * num_classes
+
         for hidden_size in hidden_layer_sizes:
             layers.append(nn.Linear(input_dim, hidden_size))
-
-            # BatchNorm can lead to faster convergence.
             layers.append(nn.BatchNorm1d(hidden_size))
-
+            layers.append(copy.deepcopy(activation))
             layers.append(nn.Dropout(dropout_rate))
-            layers.append(activation)
             input_dim = hidden_size
 
         self.hidden_layers = nn.Sequential(*layers)
@@ -88,14 +53,6 @@ class Encoder(nn.Module):
     def forward(
         self, x: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Performs the forward pass through the encoder.
-
-        Args:
-            x (torch.Tensor): The input data tensor of shape `(batch_size, n_features, num_classes)`.
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the latent mean (`z_mean`), latent log-variance (`z_log_var`), and a sample from the latent distribution (`z`).
-        """
         x = self.flatten(x)
         x = self.hidden_layers(x)
         z_mean = self.dense_z_mean(x)
@@ -105,10 +62,7 @@ class Encoder(nn.Module):
 
 
 class Decoder(nn.Module):
-    """The Decoder module of a Variational Autoencoder (VAE).
-
-    This module defines the decoder network, which takes a sample from the low-dimensional latent space and maps it back to the high-dimensional data space. It aims to reconstruct the original input data. The architecture consists of a series of fully-connected hidden layers followed by a final linear layer that produces the reconstructed data, which is then reshaped to match the original input's dimensions.
-    """
+    """The Decoder module of a Variational Autoencoder (VAE)."""
 
     def __init__(
         self,
@@ -119,65 +73,30 @@ class Decoder(nn.Module):
         dropout_rate: float,
         activation: torch.nn.Module,
     ) -> None:
-
-
-        Args:
-            n_features (int): The number of features in the output data (e.g., SNPs).
-            num_classes (int): Number of genotype states per locus (typically 2 or 3).
-            latent_dim (int): The dimensionality of the input latent space.
-            hidden_layer_sizes (List[int]): A list of integers specifying the size of each hidden layer.
-            dropout_rate (float): The dropout rate for regularization in the hidden layers.
-            activation (torch.nn.Module): An instantiated activation function module (e.g., `nn.ReLU()`) for the hidden layers.
-        """
-        super(Decoder, self).__init__()
+        super().__init__()
 
         layers = []
         input_dim = latent_dim
+
         for hidden_size in hidden_layer_sizes:
             layers.append(nn.Linear(input_dim, hidden_size))
-
-            # BatchNorm can lead to faster convergence.
             layers.append(nn.BatchNorm1d(hidden_size))
-
+            layers.append(copy.deepcopy(activation))
             layers.append(nn.Dropout(dropout_rate))
-            layers.append(activation)
             input_dim = hidden_size
 
         self.hidden_layers = nn.Sequential(*layers)
-        # UPDATED: Output dimension must account for channels
         output_dim = n_features * num_classes
         self.dense_output = nn.Linear(input_dim, output_dim)
-        # UPDATED: Reshape must account for channels
         self.reshape = (n_features, num_classes)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Performs the forward pass through the decoder.
-
-        Args:
-            x (torch.Tensor): The input latent tensor of shape `(batch_size, latent_dim)`.
-
-        Returns:
-            torch.Tensor: The reconstructed output data of shape `(batch_size, n_features, num_classes)`.
-        """
         x = self.hidden_layers(x)
         x = self.dense_output(x)
         return x.view(-1, *self.reshape)
 
 
 class VAEModel(nn.Module):
-    """A Variational Autoencoder (VAE) model for imputation.
-
-    This class combines an `Encoder` and a `Decoder` to form a VAE, a generative model for learning complex data distributions. It is designed for imputing missing values in categorical data, such as genomic SNPs. The model is trained by maximizing the Evidence Lower Bound (ELBO), which is a lower bound on the log-likelihood of the data.
-
-    **Objective Function (ELBO):**
-    The VAE loss function is derived from the ELBO and consists of two main components: a reconstruction term and a regularization term.
-    $$
-    \\mathcal{L}(\\theta, \\phi; x) = \\underbrace{\\mathbb{E}_{q_{\\phi}(z|x)}[\\log p_{\\theta}(x|z)]}_{\\text{Reconstruction Loss}} - \\underbrace{D_{KL}(q_{\\phi}(z|x) || p(z))}_{\\text{KL Divergence}}
-    $$
-    - The **Reconstruction Loss** encourages the decoder to accurately reconstruct the input data from its latent representation. This implementation uses a `MaskedFocalLoss`.
-    - The **KL Divergence** acts as a regularizer, forcing the approximate posterior distribution $q_{\\phi}(z|x)$ learned by the encoder to be close to a prior distribution $p(z)$ (typically a standard normal distribution).
-    """
-
     def __init__(
         self,
         n_features: int,
@@ -188,33 +107,18 @@ class VAEModel(nn.Module):
         latent_dim: int = 2,
         dropout_rate: float = 0.2,
         activation: Literal["relu", "elu", "selu", "leaky_relu"] = "relu",
-
-        beta: float = 1.0,
+        kl_beta: float = 1.0,
         device: Literal["cpu", "gpu", "mps"] = "cpu",
         verbose: bool = False,
         debug: bool = False,
     ):
-        """
-
-
-
-
-
-
-            latent_dim (int): The dimensionality of the latent space. Defaults to 2.
-            dropout_rate (float): The dropout rate for regularization in the hidden layers. Defaults to 0.2.
-            activation (str): The name of the activation function to use in hidden layers. Defaults to "relu".
-            gamma (float): The focusing parameter for the focal loss component. Defaults to 2.0.
-            beta (float): A weighting factor for the KL divergence term in the total loss ($\beta$-VAE). Defaults to 1.0.
-            device (Literal["cpu", "gpu", "mps"]): The device to run the model on.
-            verbose (bool): If True, enables detailed logging. Defaults to False.
-            debug (bool): If True, enables debug mode. Defaults to False.
-        """
-        super(VAEModel, self).__init__()
-        self.num_classes = num_classes
-        self.gamma = gamma
-        self.beta = beta
-        self.device = device
+        """Variational Autoencoder (VAE) model for unsupervised imputation."""
+        super().__init__()
+        self.n_features = int(n_features)
+        self.num_classes = int(num_classes)
+        self.latent_dim = int(latent_dim)
+        self.kl_beta = float(kl_beta)
+        self.torch_device = device
 
         logman = LoggerManager(
             name=__name__, prefix=prefix, verbose=verbose, debug=debug
@@ -224,23 +128,20 @@ class VAEModel(nn.Module):
         )
 
         act = self._resolve_activation(activation)
-
-
-
-
-
+        hls = (
+            hidden_layer_sizes.tolist()
+            if isinstance(hidden_layer_sizes, np.ndarray)
+            else hidden_layer_sizes
+        )
 
         self.encoder = Encoder(
-            n_features, self.num_classes, latent_dim, hls, dropout_rate, act
+            self.n_features, self.num_classes, self.latent_dim, hls, dropout_rate, act
         )
-
-        decoder_layer_sizes = list(reversed(hls))
-
         self.decoder = Decoder(
-            n_features,
+            self.n_features,
             self.num_classes,
-            latent_dim,
-
+            self.latent_dim,
+            list(reversed(hls)),
             dropout_rate,
             act,
         )
@@ -248,102 +149,20 @@ class VAEModel(nn.Module):
     def forward(
         self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Performs the forward pass through the full VAE model.
-
-        Args:
-            x (torch.Tensor): The input data tensor of shape `(batch_size, n_features, num_classes)`.
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the reconstructed output, the latent mean (`z_mean`), and the latent log-variance (`z_log_var`).
-        """
         z_mean, z_log_var, z = self.encoder(x)
         reconstruction = self.decoder(z)
         return reconstruction, z_mean, z_log_var
 
-    def
-
-
-
-        mask: torch.Tensor | None = None,
-        class_weights: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        """Computes the VAE loss function (negative ELBO).
-
-        The loss is the sum of a reconstruction term and a regularizing KL divergence term. The reconstruction loss is calculated using a masked focal loss, and the KL divergence measures the difference between the learned latent distribution and a standard normal prior.
-
-        Args:
-            outputs (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): The tuple of (reconstruction, z_mean, z_log_var) from the model's forward pass.
-            y (torch.Tensor): The target data tensor, expected to be one-hot encoded. This is converted to class indices internally for the loss function.
-            mask (torch.Tensor | None): A boolean mask to exclude missing values from the reconstruction loss.
-            class_weights (torch.Tensor | None): Weights to apply to each class in the reconstruction loss to handle imbalance.
-
-        Returns:
-            torch.Tensor: The computed scalar loss value.
-        """
-        reconstruction, z_mean, z_log_var = outputs
-
-        # 1. KL Divergence Calculation
-        prior = Normal(torch.zeros_like(z_mean), torch.ones_like(z_log_var))
-        posterior = Normal(z_mean, torch.exp(0.5 * z_log_var))
-        kl_loss = (
-            torch.distributions.kl.kl_divergence(posterior, prior).sum(dim=1).mean()
-        )
-
-        if class_weights is None:
-            class_weights = torch.ones(self.num_classes, device=y.device)
-
-        # 2. Reconstruction Loss Calculation
-        # Reverting to the robust method of flattening tensors and using the
-        # custom loss function.
-        n_classes = reconstruction.shape[-1]
-        logits_flat = reconstruction.reshape(-1, n_classes)
-
-        # Convert one-hot `y` to class indices for the loss function.
-        targets_flat = torch.argmax(y, dim=-1).reshape(-1)
-
-        if mask is None:
-            # If no mask is provided, all targets are considered valid.
-            mask_flat = torch.ones_like(targets_flat, dtype=torch.bool)
-        else:
-            # The mask needs to be reshaped to match the flattened targets.
-            mask_flat = mask.reshape(-1)
-
-        # Logits, class-index targets, and the valid mask.
-        criterion = MaskedFocalLoss(alpha=class_weights, gamma=self.gamma)
-
-        reconstruction_loss = criterion(
-            logits_flat.to(self.device),
-            targets_flat.to(self.device),
-            valid_mask=mask_flat.to(self.device),
-        )
-
-        return reconstruction_loss + self.beta * kl_loss
-
-    def _resolve_activation(
-        self, activation: Literal["relu", "elu", "leaky_relu", "selu"]
-    ) -> torch.nn.Module:
-        """Resolves an activation function module from a string name.
-
-        Args:
-            activation (Literal["relu", "elu", "leaky_relu", "selu"]): The name of the activation function.
-
-        Returns:
-            torch.nn.Module: The corresponding instantiated PyTorch activation function module.
-
-        Raises:
-            ValueError: If the provided activation name is not supported.
-        """
-        if isinstance(activation, str):
-            a = activation.lower()
+    def _resolve_activation(self, activation: Union[str, nn.Module]) -> nn.Module:
+        if isinstance(activation, nn.Module):
+            return activation
+        a = activation.lower()
         if a == "relu":
             return nn.ReLU()
-
+        if a == "elu":
             return nn.ELU()
-
+        if a in {"leaky_relu", "leakyrelu"}:
             return nn.LeakyReLU()
-
+        if a == "selu":
             return nn.SELU()
-
-        msg = f"Activation {activation} not supported."
-        self.logger.error(msg)
-        raise ValueError(msg)
+        raise ValueError(f"Activation {activation} not supported.")
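Both the Encoder and Decoder in the new version append `copy.deepcopy(activation)` to each hidden block instead of reusing the single `activation` instance that was appended before. A minimal standalone sketch (illustrative only, not pg-sui code; `nn.PReLU` and the layer sizes are arbitrary choices for the demo) of why per-layer copies can matter: stateless activations behave identically either way, but a parameterized activation shared across layers would also share its learnable parameters.

```python
# Illustrative sketch, not pg-sui code: the effect of copy.deepcopy(activation)
# per hidden layer versus reusing one module instance.
import copy

import torch.nn as nn

activation = nn.PReLU()  # hypothetical parameterized activation chosen for the demo
hidden_sizes = [32, 16, 8]

shared, copied = [], []
in_dim = 64
for h in hidden_sizes:
    shared += [nn.Linear(in_dim, h), activation]                 # same object reused
    copied += [nn.Linear(in_dim, h), copy.deepcopy(activation)]  # fresh copy per layer
    in_dim = h

# The shared variant registers only one PReLU slope; the copied variant gets one per layer.
n_shared = sum(p.numel() for p in nn.Sequential(*shared).parameters())
n_copied = sum(p.numel() for p in nn.Sequential(*copied).parameters())
print(n_copied - n_shared)  # 2 extra learnable slopes with three hidden layers
```

With the stateless activations the model actually exposes ("relu", "elu", "selu", "leaky_relu"), the copy is a safety measure rather than a behavioral change.

pgsui/impute/unsupervised/nn_scorers.py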
@@ -6,6 +6,8 @@ from sklearn.metrics import (
     accuracy_score,
     average_precision_score,
     f1_score,
+    jaccard_score,
+    matthews_corrcoef,
     precision_score,
     recall_score,
     roc_auc_score,
@@ -106,18 +108,23 @@ class Scorer:
             recall_score(y_true, y_pred, average=self.average, zero_division=0)
         )
 
-    def roc_auc(self,
+    def roc_auc(self, y_true_ohe: np.ndarray, y_pred_proba: np.ndarray) -> float:
         """Compute the ROC AUC score.
 
         Args:
-
+            y_true_ohe (np.ndarray): One-hot encoded ground truth (correct) target values.
             y_pred_proba (np.ndarray): Predicted probabilities.
 
         Returns:
             float: The ROC AUC score.
         """
-        if
-
+        if np.all(np.count_nonzero(y_true_ohe[..., 1]) == 0) or np.all(
+            np.count_nonzero(y_true_ohe[..., 2]) == 0
+        ):
+            # ROC AUC is not defined in that case
+            msg = "No positive samples in y_true; ROC AUC score is undefined. Setting to 0.5 (random classification chance)."
+            self.logger.warning(msg)
+            return 0.5  # Return a neutral score
 
         if y_pred_proba.shape[-1] == 2:
             # Binary classification case
@@ -125,14 +132,11 @@ class Scorer:
             # Otherwise it throws an error.
             y_pred_proba = y_pred_proba[:, 1]
 
-
-
-
-                y_true, y_pred_proba, average=self.average, multi_class="ovr"
-            )
+        return float(
+            roc_auc_score(
+                y_true_ohe, y_pred_proba, average=self.average, multi_class="ovr"
             )
-
-        return float(roc_auc_score(y_true, y_pred_proba, average=self.average))
+        )
 
     # This method now correctly expects one-hot encoded true labels
     def average_precision(
@@ -160,6 +164,34 @@ class Scorer:
             average_precision_score(y_true_ohe, y_pred_proba, average=self.average)
         )
 
+    def jaccard(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
+        """Compute the Jaccard score.
+
+        Args:
+            y_true (np.ndarray): Ground truth (correct) target values.
+            y_pred (np.ndarray): Estimated target values.
+
+        Returns:
+            float: The Jaccard score.
+        """
+        return float(
+            jaccard_score(y_true, y_pred, average=self.average, zero_division=0)
+        )
+
+    def mcc(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
+        """Compute the Matthews correlation coefficient (MCC).
+
+        MCC is a balanced measure that can be used even if the classes are of very different sizes. It returns a value between -1 and +1, where +1 indicates a perfect prediction, 0 indicates no better than random prediction, and -1 indicates total disagreement between prediction and observation.
+
+        Args:
+            y_true (np.ndarray): Ground truth (correct) target values.
+            y_pred (np.ndarray): Estimated target values.
+
+        Returns:
+            float: The Matthews correlation coefficient.
+        """
+        return float(matthews_corrcoef(y_true, y_pred))
+
     def pr_macro(self, y_true_ohe: np.ndarray, y_pred_proba: np.ndarray) -> float:
         """Compute the macro-average precision score.
 
@@ -196,6 +228,8 @@ class Scorer:
             "f1",
             "precision",
             "recall",
+            "mcc",
+            "jaccard",
         ] = "pr_macro",
     ) -> Dict[str, float]:
         """Evaluate the model using various metrics.
@@ -218,7 +252,7 @@ class Scorer:
                 np.asarray(y_true_ohe), np.asarray(y_pred_proba)
             ),
             "roc_auc": lambda: self.roc_auc(
-                np.asarray(
+                np.asarray(y_true_ohe), np.asarray(y_pred_proba)
             ),
             "average_precision": lambda: self.average_precision(
                 np.asarray(y_true_ohe), np.asarray(y_pred_proba)
@@ -231,6 +265,8 @@ class Scorer:
                 np.asarray(y_true), np.asarray(y_pred)
             ),
             "recall": lambda: self.recall(np.asarray(y_true), np.asarray(y_pred)),
+            "mcc": lambda: self.mcc(np.asarray(y_true), np.asarray(y_pred)),
+            "jaccard": lambda: self.jaccard(np.asarray(y_true), np.asarray(y_pred)),
         }
         if tune_metric not in metric_calculators:
             msg = f"Invalid tune_metric provided: '{tune_metric}'."
@@ -244,12 +280,16 @@ class Scorer:
             "f1": self.f1(np.asarray(y_true), np.asarray(y_pred)),
             "precision": self.precision(np.asarray(y_true), np.asarray(y_pred)),
             "recall": self.recall(np.asarray(y_true), np.asarray(y_pred)),
-            "roc_auc": self.roc_auc(
+            "roc_auc": self.roc_auc(
+                np.asarray(y_true_ohe), np.asarray(y_pred_proba)
+            ),
             "average_precision": self.average_precision(
                 np.asarray(y_true_ohe), np.asarray(y_pred_proba)
             ),
             "pr_macro": self.pr_macro(
                 np.asarray(y_true_ohe), np.asarray(y_pred_proba)
             ),
+            "mcc": self.mcc(np.asarray(y_true), np.asarray(y_pred)),
+            "jaccard": self.jaccard(np.asarray(y_true), np.asarray(y_pred)),
         }
         return {k: float(v) for k, v in metrics.items()}