pg-sui 1.6.14.dev9__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pg_sui-1.7.0.dist-info/METADATA +288 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
- pgsui/__init__.py +0 -8
- pgsui/_version.py +2 -2
- pgsui/cli.py +591 -126
- pgsui/data_processing/config.py +1 -2
- pgsui/data_processing/containers.py +218 -533
- pgsui/data_processing/transformers.py +44 -20
- pgsui/impute/deterministic/imputers/mode.py +475 -182
- pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
- pgsui/impute/supervised/imputers/random_forest.py +3 -2
- pgsui/impute/unsupervised/base.py +1268 -530
- pgsui/impute/unsupervised/callbacks.py +28 -33
- pgsui/impute/unsupervised/imputers/autoencoder.py +869 -764
- pgsui/impute/unsupervised/imputers/vae.py +928 -696
- pgsui/impute/unsupervised/loss_functions.py +156 -202
- pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
- pgsui/impute/unsupervised/models/vae_model.py +40 -221
- pgsui/impute/unsupervised/nn_scorers.py +53 -13
- pgsui/utils/classification_viz.py +240 -97
- pgsui/utils/misc.py +201 -3
- pgsui/utils/plotting.py +73 -58
- pgsui/utils/pretty_metrics.py +2 -6
- pgsui/utils/scorers.py +39 -0
- pg_sui-1.6.14.dev9.dist-info/METADATA +0 -344
- pgsui/impute/unsupervised/imputers/nlpca.py +0 -1554
- pgsui/impute/unsupervised/imputers/ubp.py +0 -1575
- pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
- pgsui/impute/unsupervised/models/ubp_model.py +0 -200
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
pgsui/impute/unsupervised/loss_functions.py

@@ -1,261 +1,215 @@
-from
+from __future__ import annotations
+
+from typing import Literal, cast
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 
-class
-    """Focal cross-entropy with ignore_index and
+class FocalCELoss(nn.Module):
+    """Focal cross-entropy with ignore_index and optional scaling.
+
+    Supports logits of shape (N, C) or (N, C, d1, d2, ...). Targets must be shape-compatible: (N) or (N, d1, d2, ...).
 
-
+    The optional `recon_scale` is useful in reconstruction settings (e.g., VAE) when your base reduction is "mean" over a sparse mask. Multiplying the final reduced loss by `recon_scale` makes the reconstruction term more "sum-like" per batch/sample, preventing KL from dominating.
     """
 
     def __init__(
         self,
-
-
+        *,
+        alpha: torch.Tensor | None = None,
+        gamma: float = 2.0,
         ignore_index: int = -1,
-
-    ):
-        """Initialize the
-
-        This class sets up the focal loss with specified focusing parameter, class weights, ignore index, and a small epsilon for numerical stability.
+        reduction: Literal["mean", "sum", "none"] = "mean",
+    ) -> None:
+        """Initialize the focal cross-entropy loss.
 
         Args:
-
-
-            ignore_index
-
+            alpha: Optional per-class weights of shape (C,).
+            gamma: Focusing parameter.
+            ignore_index: Target value to ignore.
+            reduction: "mean", "sum", or "none".
         """
         super().__init__()
-        self.
-        self.
-        self.ignore_index = ignore_index
-        self.eps = eps
-
-    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
-        """Calculates the focal loss on pre-flattened tensors.
-
-        Args:
-            logits (torch.Tensor): Logits from the model of shape (N, C) where N is the number of samples and C is the number of classes.
-            targets (torch.Tensor): Ground truth labels of shape (N,).
-
-        Returns:
-            torch.Tensor: The computed scalar loss value.
-        """
-        # logits: (N, C), targets: (N,)
-        valid = targets != self.ignore_index
-
-        if not valid.any():
-            return logits.new_tensor(0.0)
-
-        logits_v = logits[valid]
-        targets_v = targets[valid]
-
-        logp = F.log_softmax(logits_v, dim=-1)  # stable
-        ce = F.nll_loss(logp, targets_v, weight=self.weight, reduction="none")
-
-        # p_t = exp(logp[range, targets])
-        p_t = torch.exp(logp.gather(1, targets_v.unsqueeze(1)).squeeze(1))
-
-        # focal factor with clamp to avoid 0**gamma and NaNs
-        focal = (1.0 - p_t).clamp_min(self.eps).pow(self.gamma)
-
-        loss_vec = focal * ce
-
-        # guard remaining inf/nan if any slipped through
-        loss_vec = torch.nan_to_num(loss_vec, nan=0.0, posinf=1e6, neginf=0.0)
-        return loss_vec.mean()
-
-
-class WeightedMaskedCCELoss(nn.Module):
-    def __init__(
-        self,
-        alpha: float | List[float] | torch.Tensor | None = None,
-        reduction: Literal["mean", "sum"] = "mean",
-    ):
-        """A weighted, masked Categorical Cross-Entropy loss function.
-
-        This method computes the categorical cross-entropy loss while allowing for class weights and masking of invalid (missing) entries. It is particularly useful for sequence data where some positions may be missing or should not contribute to the loss calculation.
-
-        Args:
-            alpha (float | List | Tensor | None): A manual rescaling weight given to each class. If given, has to be a Tensor of size C (number of classes). Defaults to None.
-            reduction (str, optional): Specifies the reduction to apply to the output: 'mean' or 'sum'. Defaults to "mean".
-        """
-        super(WeightedMaskedCCELoss, self).__init__()
+        self._gamma = float(gamma)
+        self.ignore_index = int(ignore_index)
         self.reduction = reduction
-
+
+        if alpha is not None:
+            if alpha.dim() != 1:
+                raise ValueError("alpha must be a 1D tensor of shape (C,).")
+            # Register as buffer so it moves with the module across devices.
+            self.register_buffer("alpha", alpha)
+        else:
+            self.alpha = None
 
     def forward(
         self,
         logits: torch.Tensor,
         targets: torch.Tensor,
-
+        *,
+        recon_scale: torch.Tensor | float | None = None,
     ) -> torch.Tensor:
-        """Compute
+        """Compute focal cross-entropy loss.
 
         Args:
-            logits (
-
-
-
-
+            logits: Tensor of shape (N, C) or (N, C, d1, d2, ...).
+            targets: Tensor of shape (N) or (N, d1, d2, ...).
+            recon_scale: Optional scalar multiplier applied to the final loss.
+                - If reduction is "mean" or "sum", multiplies the scalar loss.
+                - If reduction is "none", multiplies elementwise.
 
         Returns:
-
+            Loss tensor:
+                - Scalar if reduction in {"mean","sum"}
+                - Tensor shaped like `targets` if reduction == "none"
         """
-        #
-
-
-
-        # Ensure targets are on the correct device and are Long type
-        targets = targets.to(device).long()
-
-        # Prepare weights and pass them directly to the loss function
-        class_weights = None
-        if self.alpha is not None:
-            if not isinstance(self.alpha, torch.Tensor):
-                class_weights = torch.tensor(
-                    self.alpha, dtype=torch.float, device=device
-                )
-            else:
-                class_weights = self.alpha.to(device)
-
-        loss = F.cross_entropy(
-            logits.reshape(-1, num_classes),
-            targets.reshape(-1),
-            weight=class_weights,
-            reduction="none",
-            ignore_index=-1,  # Ignore all targets with the value -1
-        )
-
-        # If a mask is provided, filter the losses for the training set
-        if valid_mask is not None:
-            loss = loss[valid_mask.reshape(-1)]
-
-        # If after masking no valid losses remain, return 0
-        if loss.numel() == 0:
-            return torch.tensor(0.0, device=device)
-
-        # Apply the final reduction
-        if self.reduction == "mean":
-            return loss.mean()
-        elif self.reduction == "sum":
-            return loss.sum()
-        else:
-            msg = f"Reduction mode '{self.reduction}' not supported."
-            raise ValueError(msg)
+        # Move C (dim 1) to the last position for flattening:
+        # (N, C, d1, ...) -> (N, d1, ..., C)
+        if logits.dim() > 2:
+            logits = logits.permute(0, *range(2, logits.dim()), 1)
 
+        logits_flat = logits.reshape(-1, logits.size(-1))
+        targets_flat = targets.reshape(-1).long()
 
-
-    """Focal loss (gamma > 0) with optional class weights and a boolean valid mask.
+        valid_mask = targets_flat != self.ignore_index
 
-
-
+        # Early exit if everything is ignored
+        if not bool(valid_mask.any()):
+            out = torch.tensor(0.0, device=logits.device, dtype=logits.dtype)
+            # preserve grad path behavior if caller expects it
+            out = out.requires_grad_(True)
+            return out
 
-
-
-        gamma: float = 2.0,
-        alpha: torch.Tensor | None = None,
-        reduction: Literal["mean", "sum"] = "mean",
-    ):
-        """Initialize the MaskedFocalLoss.
+        logits_v = logits_flat[valid_mask]
+        targets_v = targets_flat[valid_mask]
 
-
+        # Numerically stable log-softmax
+        log_probs = F.log_softmax(logits_v, dim=-1)
+        log_pt = log_probs.gather(1, targets_v.unsqueeze(1)).squeeze(1)
+        pt = log_pt.exp()
 
-
-
-            alpha (torch.Tensor | None): Class weights.
-            reduction (Literal["mean", "sum"]): Reduction mode ('mean' or 'sum').
-        """
-        super().__init__()
-        self.gamma = gamma
-        self.alpha = alpha
-        self.reduction = reduction
-
-    def forward(
-        self,
-        logits: torch.Tensor,  # Expects (N, C) where N = batch*features
-        targets: torch.Tensor,  # Expects (N,)
-        valid_mask: torch.Tensor,  # Expects (N,)
-    ) -> torch.Tensor:
-        """Calculates the focal loss on pre-flattened tensors.
+        focal_term = (1.0 - pt).pow(self.gamma)
+        loss_vec = -focal_term * log_pt
 
-
-
-            targets (torch.Tensor): Ground truth labels of shape (N,).
-            valid_mask (torch.Tensor): Boolean mask of shape (N,) where True indicates a valid (observed) value to include in the loss.
-
-        Returns:
-            torch.Tensor: The computed scalar loss value.
-        """
-        device = logits.device
+        if self.alpha is not None:
+            loss_vec = loss_vec * self.alpha[targets_v]
 
-        #
-
-
-
-
-
-
-
+        # Apply reduction
+        if self.reduction == "mean":
+            out = loss_vec.mean()
+        elif self.reduction == "sum":
+            out = loss_vec.sum()
+        else:  # "none"
+            out_flat = torch.zeros_like(targets_flat, dtype=loss_vec.dtype)
+            out_flat[valid_mask] = loss_vec
+            out = out_flat.view(targets.shape)
+
+        # Optional scaling (useful for VAE recon term)
+        if recon_scale is not None:
+            if not isinstance(recon_scale, torch.Tensor):
+                recon_scale = torch.tensor(
+                    float(recon_scale), device=out.device, dtype=out.dtype
+                )
+            else:
+                recon_scale = recon_scale.to(device=out.device, dtype=out.dtype)
 
-
-        pt = torch.exp(-ce)
-        focal = ((1 - pt) ** self.gamma) * ce
+            out = out * recon_scale
 
-
-        focal = focal[valid_mask]
+        return out
 
-
-
-
+    @property
+    def gamma(self) -> float:
+        return self._gamma
 
-
-
-
-
-
-        else:
-            msg = f"Reduction mode '{self.reduction}' not supported."
-            raise ValueError(msg)
+    @gamma.setter
+    def gamma(self, value: torch.Tensor | float) -> None:
+        if isinstance(value, torch.Tensor):
+            value = float(value.item())
+        self._gamma = float(value)
 
 
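A minimal usage sketch for the new `FocalCELoss` added above (class name, keyword-only constructor, and `ignore_index` semantics as shown in the diff; the synthetic tensors and the import of an installed pg-sui are illustrative only):

```python
import torch

from pgsui.impute.unsupervised.loss_functions import FocalCELoss

num_classes = 4
logits = torch.randn(8, num_classes)            # (N, C) logits
targets = torch.randint(0, num_classes, (8,))   # (N,) integer class labels
targets[:2] = -1                                # missing genotypes use ignore_index

criterion = FocalCELoss(
    alpha=torch.ones(num_classes),  # optional per-class weights, shape (C,)
    gamma=2.0,
    ignore_index=-1,
    reduction="mean",
)

loss = criterion(logits, targets)                      # ignored targets drop out of the mean
scaled = criterion(logits, targets, recon_scale=32.0)  # same loss, scaled to be more "sum-like"
```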
 def safe_kl_gauss_unit(
     mu: torch.Tensor, logvar: torch.Tensor, reduction: str = "mean"
 ) -> torch.Tensor:
-    """KL divergence between N(mu,
-
-
+    """Compute KL divergence between N(mu, var) and N(0, I) with numeric guards.
+
+    Args:
+        mu (torch.Tensor): Latent mean (shape: [B, D]).
+        logvar (torch.Tensor): Latent log-variance (shape: [B, D]).
+        reduction (str): Reduction method ('mean' or 'sum').
+
+    Returns:
+        torch.Tensor: KL divergence (scalar).
+    """
+    kl = -0.5 * (1.0 + logvar - mu.pow(2) - logvar.exp())  # (B, D)
+    kl = kl.sum(dim=-1)  # (B,)
+
     if reduction == "sum":
         kl = kl.sum()
     elif reduction == "mean":
         kl = kl.mean()
+    else:
+        raise ValueError(f"Invalid reduction: {reduction}")
+
     return torch.nan_to_num(kl, nan=0.0, posinf=1e6, neginf=0.0)
 
 
 def compute_vae_loss(
+    criterion: nn.Module,
     recon_logits: torch.Tensor,
     targets: torch.Tensor,
     *,
     mu: torch.Tensor,
     logvar: torch.Tensor,
-
-
-
-    ignore_index: int = -1,
+    kl_beta: torch.Tensor | float,
+    reduction: str = "mean",
+    recon_scale: torch.Tensor | float | None = None,
 ) -> torch.Tensor:
-    """
-
-
-
-
-
-
-
-
-
-
-
-
+    """Compute VAE loss: reconstruction + KL divergence, with optional recon scaling.
+
+    Args:
+        criterion: Reconstruction loss module (e.g., FocalCELoss / CrossEntropyLoss).
+            Must accept (logits_2d, targets_1d). If it supports `recon_scale`, it will
+            be passed through; otherwise it will be called without it.
+        recon_logits: Reconstruction logits from decoder. Shape: (N, L, C) or (N_eval, C).
+        targets: Ground truth targets. Shape: (N, L) or (N_eval,).
+        mu: Latent mean. Shape: (B, D) (or compatible with safe_kl_gauss_unit).
+        logvar: Latent log-variance. Shape: (B, D).
+        kl_beta: Scalar KL weight.
+        reduction: KL reduction: "mean" or "sum".
+        recon_scale: Optional scalar multiplier applied to reconstruction term.
+            Use this to make reconstruction more "sum-like" for high-dimensional data.
+
+    Returns:
+        Scalar loss tensor.
+    """
+    # Flatten logits/targets to (N_total, C) and (N_total,)
+    if recon_logits.dim() == 3:
+        logits_2d = recon_logits.reshape(-1, recon_logits.size(-1))
+    elif recon_logits.dim() == 2:
+        logits_2d = recon_logits
+    else:
+        msg = f"recon_logits must be 2D or 3D; got shape {tuple(recon_logits.shape)}"
+        raise ValueError(msg)
+
+    tgt_1d = targets.reshape(-1) if targets.dim() > 1 else targets
+
+    # Reconstruction loss (criterion may ignore_index internally)
+    try:
+        rec = criterion(logits_2d, tgt_1d, recon_scale=recon_scale)
+    except TypeError:
+        # Criterion doesn't accept recon_scale (e.g., torch.nn.CrossEntropyLoss)
+        rec = criterion(logits_2d, tgt_1d)
+        if recon_scale is not None:
+            if isinstance(recon_scale, torch.Tensor):
+                rec = rec * recon_scale.to(device=rec.device, dtype=rec.dtype)
+            else:
+                rec = rec * float(recon_scale)
+
+    # KL term
+    kl = safe_kl_gauss_unit(mu, logvar, reduction=reduction)
+    loss = rec + (kl_beta * kl)
+
+    return torch.nan_to_num(loss, nan=1e6, posinf=1e6, neginf=1e6)
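The new `compute_vae_loss` helper combines the reconstruction term with the analytic KL from `safe_kl_gauss_unit`, i.e. KL = -0.5 * sum(1 + logvar - mu^2 - exp(logvar)) per sample. A minimal wiring sketch using the functions added above (shapes and the `kl_beta`/`recon_scale` values are illustrative assumptions, not defaults from the package):

```python
import torch

from pgsui.impute.unsupervised.loss_functions import FocalCELoss, compute_vae_loss

B, L, C, D = 4, 10, 4, 2                 # batch, loci, genotype classes, latent dim
recon_logits = torch.randn(B, L, C)      # decoder logits
targets = torch.randint(0, C, (B, L))    # integer genotype classes
mu = torch.zeros(B, D)                   # encoder latent mean
logvar = torch.zeros(B, D)               # encoder latent log-variance

criterion = FocalCELoss(gamma=2.0, ignore_index=-1, reduction="mean")
loss = compute_vae_loss(
    criterion,
    recon_logits,
    targets,
    mu=mu,
    logvar=logvar,
    kl_beta=1.0,
    reduction="mean",
    recon_scale=float(L),  # makes the mean-reduced recon term more "sum-like"
)
```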
pgsui/impute/unsupervised/models/autoencoder_model.py

@@ -5,7 +5,6 @@ import torch
 import torch.nn as nn
 from snpio.utils.logging import LoggerManager
 
-from pgsui.impute.unsupervised.loss_functions import MaskedFocalLoss
 from pgsui.utils.logging_utils import configure_logger
 
 
@@ -44,8 +43,8 @@ class Encoder(nn.Module):
         for hidden_size in hidden_layer_sizes:
             layers.append(nn.Linear(input_dim, hidden_size))
             layers.append(nn.BatchNorm1d(hidden_size))
-            layers.append(nn.Dropout(dropout_rate))
             layers.append(activation)
+            layers.append(nn.Dropout(dropout_rate))
             input_dim = hidden_size
 
         self.hidden_layers = nn.Sequential(*layers)
@@ -98,8 +97,8 @@ class Decoder(nn.Module):
         for hidden_size in hidden_layer_sizes:
             layers.append(nn.Linear(input_dim, hidden_size))
             layers.append(nn.BatchNorm1d(hidden_size))
-            layers.append(nn.Dropout(dropout_rate))
             layers.append(activation)
+            layers.append(nn.Dropout(dropout_rate))
             input_dim = hidden_size
 
         self.hidden_layers = nn.Sequential(*layers)
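Both reorderings above move `nn.Dropout` after the activation, so each hidden block built by the `Encoder`/`Decoder` loops is now Linear -> BatchNorm1d -> activation -> Dropout. A minimal sketch of one such block (the layer sizes, activation, and dropout rate are placeholders):

```python
import torch.nn as nn

# One hidden block as assembled by the updated loops.
hidden_block = nn.Sequential(
    nn.Linear(128, 64),
    nn.BatchNorm1d(64),
    nn.ReLU(),        # activation now precedes dropout
    nn.Dropout(0.2),
)
```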
@@ -128,17 +127,17 @@ class AutoencoderModel(nn.Module):
 
     **Model Architecture and Objective:**
 
-    The autoencoder consists of two parts: an encoder, $f_{
+    The autoencoder consists of two parts: an encoder, $f_{\\theta}$, and a decoder, $g_{\\phi}$.
     1. The **encoder** maps the input data $x$ to a latent representation $z$:
     $$
     z = f_{\theta}(x)
     $$
-    2. The **decoder** reconstructs the data
+    2. The **decoder** reconstructs the data $\\hat{x}$ from the latent representation:
     $$
-
+    \\hat{x} = g_{\\phi}(z)
     $$
 
-    The model is trained by minimizing a reconstruction loss, $L(x,
+    The model is trained by minimizing a reconstruction loss, $L(x, \\hat{x})$, which measures the dissimilarity between the original input and the reconstructed output. This implementation uses a ``FocalCELoss`` to handle missing values and class imbalance effectively.
     """
 
     def __init__(
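The revised docstring spells out the standard autoencoder objective: encode x to z, decode z back to a reconstruction, and minimize a reconstruction loss (here a ``FocalCELoss``, per the docstring). A schematic, self-contained sketch of that objective with toy stand-in layers (the real `AutoencoderModel` builds its encoder/decoder from `hidden_layer_sizes`, `latent_dim`, `dropout_rate`, and the chosen activation; plain cross-entropy stands in for the focal term):

```python
import torch
import torch.nn as nn

n_loci, n_classes, latent_dim = 10, 4, 2
encoder = nn.Linear(n_loci * n_classes, latent_dim)   # f_theta: x -> z
decoder = nn.Linear(latent_dim, n_loci * n_classes)   # g_phi:  z -> reconstruction logits

x = torch.randn(8, n_loci * n_classes)                # flattened one-hot-like input
z = encoder(x)                                        # z = f_theta(x)
recon_logits = decoder(z).view(8, n_loci, n_classes)  # x_hat = g_phi(z)

targets = torch.randint(0, n_classes, (8, n_loci))    # integer class targets
loss = nn.functional.cross_entropy(                   # L(x, x_hat), focal loss in the package
    recon_logits.reshape(-1, n_classes), targets.reshape(-1), ignore_index=-1
)
```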
@@ -151,7 +150,7 @@ class AutoencoderModel(nn.Module):
         latent_dim: int = 2,
         dropout_rate: float = 0.2,
         activation: Literal["relu", "elu", "selu", "leaky_relu"] = "relu",
-        gamma:
+        gamma: torch.Tensor = torch.tensor(2.0),
         device: Literal["cpu", "gpu", "mps"] = "cpu",
         verbose: bool = False,
         debug: bool = False,
@@ -222,47 +221,6 @@ class AutoencoderModel(nn.Module):
         reconstruction = self.decoder(z)
         return reconstruction
 
-    def compute_loss(
-        self,
-        reconstruction: torch.Tensor,
-        y: torch.Tensor,
-        mask: torch.Tensor | None = None,
-        class_weights: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        """Computes the reconstruction loss for the Autoencoder model.
-
-        This method calculates the reconstruction loss using a masked focal loss, which is suitable for categorical data with missing values and class imbalance.
-
-        Args:
-            reconstruction (torch.Tensor): The reconstructed output (logits) from the model's forward pass.
-            y (torch.Tensor): The target data tensor, expected to be one-hot encoded. It is converted to class indices internally for the loss calculation.
-            mask (torch.Tensor | None): A boolean mask to exclude missing values from the loss calculation.
-            class_weights (torch.Tensor | None): Weights to apply to each class in the loss to handle imbalance.
-
-        Returns:
-            torch.Tensor: The computed scalar loss value.
-        """
-        if class_weights is None:
-            class_weights = torch.ones(self.num_classes, device=y.device)
-
-        logits_flat = reconstruction.view(-1, self.num_classes)
-        targets_flat = torch.argmax(y, dim=-1).view(-1)
-
-        if mask is None:
-            mask_flat = torch.ones_like(targets_flat, dtype=torch.bool)
-        else:
-            mask_flat = mask.view(-1)
-
-        criterion = MaskedFocalLoss(alpha=class_weights, gamma=self.gamma)
-
-        reconstruction_loss = criterion(
-            logits_flat.to(self.device),
-            targets_flat.to(self.device),
-            valid_mask=mask_flat.to(self.device),
-        )
-
-        return reconstruction_loss
-
     def _resolve_activation(
         self, activation: Literal["relu", "elu", "leaky_relu", "selu"]
     ) -> torch.nn.Module: