pg-sui 1.6.14.dev9__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. pg_sui-1.7.0.dist-info/METADATA +288 -0
  2. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
  3. pgsui/__init__.py +0 -8
  4. pgsui/_version.py +2 -2
  5. pgsui/cli.py +591 -126
  6. pgsui/data_processing/config.py +1 -2
  7. pgsui/data_processing/containers.py +218 -533
  8. pgsui/data_processing/transformers.py +44 -20
  9. pgsui/impute/deterministic/imputers/mode.py +475 -182
  10. pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
  11. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
  12. pgsui/impute/supervised/imputers/random_forest.py +3 -2
  13. pgsui/impute/unsupervised/base.py +1268 -530
  14. pgsui/impute/unsupervised/callbacks.py +28 -33
  15. pgsui/impute/unsupervised/imputers/autoencoder.py +869 -764
  16. pgsui/impute/unsupervised/imputers/vae.py +928 -696
  17. pgsui/impute/unsupervised/loss_functions.py +156 -202
  18. pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
  19. pgsui/impute/unsupervised/models/vae_model.py +40 -221
  20. pgsui/impute/unsupervised/nn_scorers.py +53 -13
  21. pgsui/utils/classification_viz.py +240 -97
  22. pgsui/utils/misc.py +201 -3
  23. pgsui/utils/plotting.py +73 -58
  24. pgsui/utils/pretty_metrics.py +2 -6
  25. pgsui/utils/scorers.py +39 -0
  26. pg_sui-1.6.14.dev9.dist-info/METADATA +0 -344
  27. pgsui/impute/unsupervised/imputers/nlpca.py +0 -1554
  28. pgsui/impute/unsupervised/imputers/ubp.py +0 -1575
  29. pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
  30. pgsui/impute/unsupervised/models/ubp_model.py +0 -200
  31. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
  32. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
  33. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
  34. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
@@ -1,261 +1,215 @@
- from typing import List, Literal
+ from __future__ import annotations
+
+ from typing import Literal, cast

  import torch
  import torch.nn as nn
  import torch.nn.functional as F


- class SafeFocalCELoss(nn.Module):
-     """Focal cross-entropy with ignore_index and numeric guards.
+ class FocalCELoss(nn.Module):
+     """Focal cross-entropy with ignore_index and optional scaling.
+
+     Supports logits of shape (N, C) or (N, C, d1, d2, ...). Targets must be shape-compatible: (N) or (N, d1, d2, ...).

-     This class implements the focal loss function, which is designed to address class imbalance by down-weighting easy examples and focusing training on hard negatives. It also includes handling for ignored indices and numeric stability.
+     The optional `recon_scale` is useful in reconstruction settings (e.g., VAE) when your base reduction is "mean" over a sparse mask. Multiplying the final reduced loss by `recon_scale` makes the reconstruction term more "sum-like" per batch/sample, preventing KL from dominating.
      """

      def __init__(
          self,
-         gamma: float,
-         weight: torch.Tensor | None = None,
+         *,
+         alpha: torch.Tensor | None = None,
+         gamma: float = 2.0,
          ignore_index: int = -1,
-         eps: float = 1e-8,
-     ):
-         """Initialize the SafeFocalCELoss.
-
-         This class sets up the focal loss with specified focusing parameter, class weights, ignore index, and a small epsilon for numerical stability.
+         reduction: Literal["mean", "sum", "none"] = "mean",
+     ) -> None:
+         """Initialize the focal cross-entropy loss.

          Args:
-             gamma (float): Focusing parameter.
-             weight (torch.Tensor | None): A manual rescaling weight given to each class. If given, has to be a Tensor of size C (number of classes). Defaults to None.
-             ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. Default is -1.
-             eps (float): Small value to avoid numerical issues. Default is 1e-8.
+             alpha: Optional per-class weights of shape (C,).
+             gamma: Focusing parameter.
+             ignore_index: Target value to ignore.
+             reduction: "mean", "sum", or "none".
          """
          super().__init__()
-         self.gamma = gamma
-         self.weight = weight
-         self.ignore_index = ignore_index
-         self.eps = eps
-
-     def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
-         """Calculates the focal loss on pre-flattened tensors.
-
-         Args:
-             logits (torch.Tensor): Logits from the model of shape (N, C) where N is the number of samples and C is the number of classes.
-             targets (torch.Tensor): Ground truth labels of shape (N,).
-
-         Returns:
-             torch.Tensor: The computed scalar loss value.
-         """
-         # logits: (N, C), targets: (N,)
-         valid = targets != self.ignore_index
-
-         if not valid.any():
-             return logits.new_tensor(0.0)
-
-         logits_v = logits[valid]
-         targets_v = targets[valid]
-
-         logp = F.log_softmax(logits_v, dim=-1)  # stable
-         ce = F.nll_loss(logp, targets_v, weight=self.weight, reduction="none")
-
-         # p_t = exp(logp[range, targets])
-         p_t = torch.exp(logp.gather(1, targets_v.unsqueeze(1)).squeeze(1))
-
-         # focal factor with clamp to avoid 0**gamma and NaNs
-         focal = (1.0 - p_t).clamp_min(self.eps).pow(self.gamma)
-
-         loss_vec = focal * ce
-
-         # guard remaining inf/nan if any slipped through
-         loss_vec = torch.nan_to_num(loss_vec, nan=0.0, posinf=1e6, neginf=0.0)
-         return loss_vec.mean()
-
-
- class WeightedMaskedCCELoss(nn.Module):
-     def __init__(
-         self,
-         alpha: float | List[float] | torch.Tensor | None = None,
-         reduction: Literal["mean", "sum"] = "mean",
-     ):
-         """A weighted, masked Categorical Cross-Entropy loss function.
-
-         This method computes the categorical cross-entropy loss while allowing for class weights and masking of invalid (missing) entries. It is particularly useful for sequence data where some positions may be missing or should not contribute to the loss calculation.
-
-         Args:
-             alpha (float | List | Tensor | None): A manual rescaling weight given to each class. If given, has to be a Tensor of size C (number of classes). Defaults to None.
-             reduction (str, optional): Specifies the reduction to apply to the output: 'mean' or 'sum'. Defaults to "mean".
-         """
-         super(WeightedMaskedCCELoss, self).__init__()
+         self._gamma = float(gamma)
+         self.ignore_index = int(ignore_index)
          self.reduction = reduction
-         self.alpha = alpha
+
+         if alpha is not None:
+             if alpha.dim() != 1:
+                 raise ValueError("alpha must be a 1D tensor of shape (C,).")
+             # Register as buffer so it moves with the module across devices.
+             self.register_buffer("alpha", alpha)
+         else:
+             self.alpha = None

      def forward(
          self,
          logits: torch.Tensor,
          targets: torch.Tensor,
-         valid_mask: torch.Tensor | None = None,
+         *,
+         recon_scale: torch.Tensor | float | None = None,
      ) -> torch.Tensor:
-         """Compute the masked categorical cross-entropy loss.
+         """Compute focal cross-entropy loss.

          Args:
-             logits (torch.Tensor): Logits from the model of shape
-                 (batch_size, seq_len, num_classes).
-             targets (torch.Tensor): Ground truth labels of shape (batch_size, seq_len).
-             valid_mask (torch.Tensor, optional): Boolean mask of shape (batch_size, seq_len) where True indicates a valid (observed) value to include in the loss.
-                 Defaults to None, in which case all values are considered valid.
+             logits: Tensor of shape (N, C) or (N, C, d1, d2, ...).
+             targets: Tensor of shape (N) or (N, d1, d2, ...).
+             recon_scale: Optional scalar multiplier applied to the final loss.
+                 - If reduction is "mean" or "sum", multiplies the scalar loss.
+                 - If reduction is "none", multiplies elementwise.

          Returns:
-             torch.Tensor: The computed scalar loss value.
+             Loss tensor:
+                 - Scalar if reduction in {"mean","sum"}
+                 - Tensor shaped like `targets` if reduction == "none"
          """
-         # Automatically detect the device from the input tensor
-         device = logits.device
-         num_classes = logits.shape[-1]
-
-         # Ensure targets are on the correct device and are Long type
-         targets = targets.to(device).long()
-
-         # Prepare weights and pass them directly to the loss function
-         class_weights = None
-         if self.alpha is not None:
-             if not isinstance(self.alpha, torch.Tensor):
-                 class_weights = torch.tensor(
-                     self.alpha, dtype=torch.float, device=device
-                 )
-             else:
-                 class_weights = self.alpha.to(device)
-
-         loss = F.cross_entropy(
-             logits.reshape(-1, num_classes),
-             targets.reshape(-1),
-             weight=class_weights,
-             reduction="none",
-             ignore_index=-1,  # Ignore all targets with the value -1
-         )
-
-         # If a mask is provided, filter the losses for the training set
-         if valid_mask is not None:
-             loss = loss[valid_mask.reshape(-1)]
-
-         # If after masking no valid losses remain, return 0
-         if loss.numel() == 0:
-             return torch.tensor(0.0, device=device)
-
-         # Apply the final reduction
-         if self.reduction == "mean":
-             return loss.mean()
-         elif self.reduction == "sum":
-             return loss.sum()
-         else:
-             msg = f"Reduction mode '{self.reduction}' not supported."
-             raise ValueError(msg)
+         # Move C (dim 1) to the last position for flattening:
+         # (N, C, d1, ...) -> (N, d1, ..., C)
+         if logits.dim() > 2:
+             logits = logits.permute(0, *range(2, logits.dim()), 1)

+         logits_flat = logits.reshape(-1, logits.size(-1))
+         targets_flat = targets.reshape(-1).long()

- class MaskedFocalLoss(nn.Module):
-     """Focal loss (gamma > 0) with optional class weights and a boolean valid mask.
+         valid_mask = targets_flat != self.ignore_index

-     This method implements the focal loss function, which is designed to address class imbalance by down-weighting easy examples and focusing training on hard negatives. It also supports masking of invalid (missing) entries, making it suitable for sequence data with missing values.
-     """
+         # Early exit if everything is ignored
+         if not bool(valid_mask.any()):
+             out = torch.tensor(0.0, device=logits.device, dtype=logits.dtype)
+             # preserve grad path behavior if caller expects it
+             out = out.requires_grad_(True)
+             return out

-     def __init__(
-         self,
-         gamma: float = 2.0,
-         alpha: torch.Tensor | None = None,
-         reduction: Literal["mean", "sum"] = "mean",
-     ):
-         """Initialize the MaskedFocalLoss.
+         logits_v = logits_flat[valid_mask]
+         targets_v = targets_flat[valid_mask]

-         This class sets up the focal loss with specified focusing parameter, class weights, and reduction method. It is designed to handle missing data through a valid mask, ensuring that only relevant entries contribute to the loss calculation.
+         # Numerically stable log-softmax
+         log_probs = F.log_softmax(logits_v, dim=-1)
+         log_pt = log_probs.gather(1, targets_v.unsqueeze(1)).squeeze(1)
+         pt = log_pt.exp()

-         Args:
-             gamma (float): Focusing parameter.
-             alpha (torch.Tensor | None): Class weights.
-             reduction (Literal["mean", "sum"]): Reduction mode ('mean' or 'sum').
-         """
-         super().__init__()
-         self.gamma = gamma
-         self.alpha = alpha
-         self.reduction = reduction
-
-     def forward(
-         self,
-         logits: torch.Tensor,  # Expects (N, C) where N = batch*features
-         targets: torch.Tensor,  # Expects (N,)
-         valid_mask: torch.Tensor,  # Expects (N,)
-     ) -> torch.Tensor:
-         """Calculates the focal loss on pre-flattened tensors.
+         focal_term = (1.0 - pt).pow(self.gamma)
+         loss_vec = -focal_term * log_pt

-         Args:
-             logits (torch.Tensor): Logits from the model of shape (N, C) where N is the number of samples (batch_size * seq_len) and C is the number of classes.
-             targets (torch.Tensor): Ground truth labels of shape (N,).
-             valid_mask (torch.Tensor): Boolean mask of shape (N,) where True indicates a valid (observed) value to include in the loss.
-
-         Returns:
-             torch.Tensor: The computed scalar loss value.
-         """
-         device = logits.device
+         if self.alpha is not None:
+             loss_vec = loss_vec * self.alpha[targets_v]

-         # Calculate standard cross-entropy loss per-token (no reduction)
-         ce = F.cross_entropy(
-             logits,
-             targets,
-             weight=(self.alpha.to(device) if self.alpha is not None else None),
-             reduction="none",
-             ignore_index=-1,
-         )
+         # Apply reduction
+         if self.reduction == "mean":
+             out = loss_vec.mean()
+         elif self.reduction == "sum":
+             out = loss_vec.sum()
+         else:  # "none"
+             out_flat = torch.zeros_like(targets_flat, dtype=loss_vec.dtype)
+             out_flat[valid_mask] = loss_vec
+             out = out_flat.view(targets.shape)
+
+         # Optional scaling (useful for VAE recon term)
+         if recon_scale is not None:
+             if not isinstance(recon_scale, torch.Tensor):
+                 recon_scale = torch.tensor(
+                     float(recon_scale), device=out.device, dtype=out.dtype
+                 )
+             else:
+                 recon_scale = recon_scale.to(device=out.device, dtype=out.dtype)

-         # Calculate p_t from the cross-entropy loss
-         pt = torch.exp(-ce)
-         focal = ((1 - pt) ** self.gamma) * ce
+             out = out * recon_scale

-         # Apply the valid mask. We select only the elements that should contribute to the loss.
-         focal = focal[valid_mask]
+         return out

-         # Return early if no valid elements exist to avoid NaN results
-         if focal.numel() == 0:
-             return torch.tensor(0.0, device=device)
+     @property
+     def gamma(self) -> float:
+         return self._gamma

-         # Apply reduction
-         if self.reduction == "mean":
-             return focal.mean()
-         elif self.reduction == "sum":
-             return focal.sum()
-         else:
-             msg = f"Reduction mode '{self.reduction}' not supported."
-             raise ValueError(msg)
+     @gamma.setter
+     def gamma(self, value: torch.Tensor | float) -> None:
+         if isinstance(value, torch.Tensor):
+             value = float(value.item())
+         self._gamma = float(value)


  def safe_kl_gauss_unit(
      mu: torch.Tensor, logvar: torch.Tensor, reduction: str = "mean"
  ) -> torch.Tensor:
-     """KL divergence between N(mu, exp(logvar)) and N(0, I) with guards."""
-     logvar = logvar.clamp(min=-30.0, max=20.0)
-     kl = -0.5 * (1.0 + logvar - mu.pow(2) - logvar.exp())
+     """Compute KL divergence between N(mu, var) and N(0, I) with numeric guards.
+
+     Args:
+         mu (torch.Tensor): Latent mean (shape: [B, D]).
+         logvar (torch.Tensor): Latent log-variance (shape: [B, D]).
+         reduction (str): Reduction method ('mean' or 'sum').
+
+     Returns:
+         torch.Tensor: KL divergence (scalar).
+     """
+     kl = -0.5 * (1.0 + logvar - mu.pow(2) - logvar.exp())  # (B, D)
+     kl = kl.sum(dim=-1)  # (B,)
+
      if reduction == "sum":
          kl = kl.sum()
      elif reduction == "mean":
          kl = kl.mean()
+     else:
+         raise ValueError(f"Invalid reduction: {reduction}")
+
      return torch.nan_to_num(kl, nan=0.0, posinf=1e6, neginf=0.0)


  def compute_vae_loss(
+     criterion: nn.Module,
      recon_logits: torch.Tensor,
      targets: torch.Tensor,
      *,
      mu: torch.Tensor,
      logvar: torch.Tensor,
-     class_weights: torch.Tensor | None,
-     gamma: float,
-     beta: float,
-     ignore_index: int = -1,
+     kl_beta: torch.Tensor | float,
+     reduction: str = "mean",
+     recon_scale: torch.Tensor | float | None = None,
  ) -> torch.Tensor:
-     """Focal reconstruction + beta * KL with normalized class weights."""
-     cw = None
-     if class_weights is not None:
-         cw = class_weights / class_weights.mean().clamp_min(1e-8)
-
-     criterion = SafeFocalCELoss(
-         gamma=gamma,
-         weight=cw,
-         ignore_index=ignore_index,
-     )
-     rec = criterion(recon_logits.view(-1, recon_logits.size(-1)), targets.view(-1))
-     kl = safe_kl_gauss_unit(mu, logvar, reduction="mean")
-     return rec + beta * kl
+     """Compute VAE loss: reconstruction + KL divergence, with optional recon scaling.
+
+     Args:
+         criterion: Reconstruction loss module (e.g., FocalCELoss / CrossEntropyLoss).
+             Must accept (logits_2d, targets_1d). If it supports `recon_scale`, it will
+             be passed through; otherwise it will be called without it.
+         recon_logits: Reconstruction logits from decoder. Shape: (N, L, C) or (N_eval, C).
+         targets: Ground truth targets. Shape: (N, L) or (N_eval,).
+         mu: Latent mean. Shape: (B, D) (or compatible with safe_kl_gauss_unit).
+         logvar: Latent log-variance. Shape: (B, D).
+         kl_beta: Scalar KL weight.
+         reduction: KL reduction: "mean" or "sum".
+         recon_scale: Optional scalar multiplier applied to reconstruction term.
+             Use this to make reconstruction more "sum-like" for high-dimensional data.
+
+     Returns:
+         Scalar loss tensor.
+     """
+     # Flatten logits/targets to (N_total, C) and (N_total,)
+     if recon_logits.dim() == 3:
+         logits_2d = recon_logits.reshape(-1, recon_logits.size(-1))
+     elif recon_logits.dim() == 2:
+         logits_2d = recon_logits
+     else:
+         msg = f"recon_logits must be 2D or 3D; got shape {tuple(recon_logits.shape)}"
+         raise ValueError(msg)
+
+     tgt_1d = targets.reshape(-1) if targets.dim() > 1 else targets
+
+     # Reconstruction loss (criterion may ignore_index internally)
+     try:
+         rec = criterion(logits_2d, tgt_1d, recon_scale=recon_scale)
+     except TypeError:
+         # Criterion doesn't accept recon_scale (e.g., torch.nn.CrossEntropyLoss)
+         rec = criterion(logits_2d, tgt_1d)
+         if recon_scale is not None:
+             if isinstance(recon_scale, torch.Tensor):
+                 rec = rec * recon_scale.to(device=rec.device, dtype=rec.dtype)
+             else:
+                 rec = rec * float(recon_scale)
+
+     # KL term
+     kl = safe_kl_gauss_unit(mu, logvar, reduction=reduction)
+     loss = rec + (kl_beta * kl)
+
+     return torch.nan_to_num(loss, nan=1e6, posinf=1e6, neginf=1e6)
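For orientation, here is a minimal usage sketch of the new loss API. The module path, class and function names, and signatures are taken from this diff; the tensor shapes, the 0.5 KL weight, and the use of -1 for missing genotype calls are illustrative assumptions only.

```python
import torch

from pgsui.impute.unsupervised.loss_functions import FocalCELoss, compute_vae_loss

B, L, C, D = 8, 100, 4, 2  # batch, loci, genotype classes, latent dim (assumed)
recon_logits = torch.randn(B, L, C, requires_grad=True)  # decoder output
targets = torch.randint(-1, C, (B, L))  # -1 marks missing calls (ignore_index)
mu, logvar = torch.zeros(B, D), torch.zeros(B, D)

# FocalCELoss handles ignore_index itself; recon_scale makes the mean-reduced
# reconstruction term more "sum-like" so the KL term does not dominate.
criterion = FocalCELoss(gamma=2.0, ignore_index=-1, reduction="mean")
loss = compute_vae_loss(
    criterion,
    recon_logits,
    targets,
    mu=mu,
    logvar=logvar,
    kl_beta=0.5,           # assumed KL weight
    recon_scale=float(L),  # scale the per-call mean by the number of loci
)
```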
@@ -5,7 +5,6 @@ import torch
  import torch.nn as nn
  from snpio.utils.logging import LoggerManager

- from pgsui.impute.unsupervised.loss_functions import MaskedFocalLoss
  from pgsui.utils.logging_utils import configure_logger


@@ -44,8 +43,8 @@ class Encoder(nn.Module):
          for hidden_size in hidden_layer_sizes:
              layers.append(nn.Linear(input_dim, hidden_size))
              layers.append(nn.BatchNorm1d(hidden_size))
-             layers.append(nn.Dropout(dropout_rate))
              layers.append(activation)
+             layers.append(nn.Dropout(dropout_rate))
              input_dim = hidden_size

          self.hidden_layers = nn.Sequential(*layers)
@@ -98,8 +97,8 @@ class Decoder(nn.Module):
          for hidden_size in hidden_layer_sizes:
              layers.append(nn.Linear(input_dim, hidden_size))
              layers.append(nn.BatchNorm1d(hidden_size))
-             layers.append(nn.Dropout(dropout_rate))
              layers.append(activation)
+             layers.append(nn.Dropout(dropout_rate))
              input_dim = hidden_size

          self.hidden_layers = nn.Sequential(*layers)
@@ -128,17 +127,17 @@ class AutoencoderModel(nn.Module):

      **Model Architecture and Objective:**

-     The autoencoder consists of two parts: an encoder, $f_{\theta}$, and a decoder, $g_{\phi}$.
+     The autoencoder consists of two parts: an encoder, $f_{\\theta}$, and a decoder, $g_{\\phi}$.
      1. The **encoder** maps the input data $x$ to a latent representation $z$:
      $$
      z = f_{\theta}(x)
      $$
-     2. The **decoder** reconstructs the data $\hat{x}$ from the latent representation:
+     2. The **decoder** reconstructs the data $\\hat{x}$ from the latent representation:
      $$
-     \hat{x} = g_{\phi}(z)
+     \\hat{x} = g_{\\phi}(z)
      $$

-     The model is trained by minimizing a reconstruction loss, $L(x, \hat{x})$, which measures the dissimilarity between the original input and the reconstructed output. This implementation uses a `MaskedFocalLoss` to handle missing values and class imbalance effectively.
+     The model is trained by minimizing a reconstruction loss, $L(x, \\hat{x})$, which measures the dissimilarity between the original input and the reconstructed output. This implementation uses a ``FocalCELoss`` to handle missing values and class imbalance effectively.
      """

      def __init__(
@@ -151,7 +150,7 @@ class AutoencoderModel(nn.Module):
          latent_dim: int = 2,
          dropout_rate: float = 0.2,
          activation: Literal["relu", "elu", "selu", "leaky_relu"] = "relu",
-         gamma: float = 2.0,
+         gamma: torch.Tensor = torch.tensor(2.0),
          device: Literal["cpu", "gpu", "mps"] = "cpu",
          verbose: bool = False,
          debug: bool = False,
@@ -222,47 +221,6 @@ class AutoencoderModel(nn.Module):
          reconstruction = self.decoder(z)
          return reconstruction

-     def compute_loss(
-         self,
-         reconstruction: torch.Tensor,
-         y: torch.Tensor,
-         mask: torch.Tensor | None = None,
-         class_weights: torch.Tensor | None = None,
-     ) -> torch.Tensor:
-         """Computes the reconstruction loss for the Autoencoder model.
-
-         This method calculates the reconstruction loss using a masked focal loss, which is suitable for categorical data with missing values and class imbalance.
-
-         Args:
-             reconstruction (torch.Tensor): The reconstructed output (logits) from the model's forward pass.
-             y (torch.Tensor): The target data tensor, expected to be one-hot encoded. It is converted to class indices internally for the loss calculation.
-             mask (torch.Tensor | None): A boolean mask to exclude missing values from the loss calculation.
-             class_weights (torch.Tensor | None): Weights to apply to each class in the loss to handle imbalance.
-
-         Returns:
-             torch.Tensor: The computed scalar loss value.
-         """
-         if class_weights is None:
-             class_weights = torch.ones(self.num_classes, device=y.device)
-
-         logits_flat = reconstruction.view(-1, self.num_classes)
-         targets_flat = torch.argmax(y, dim=-1).view(-1)
-
-         if mask is None:
-             mask_flat = torch.ones_like(targets_flat, dtype=torch.bool)
-         else:
-             mask_flat = mask.view(-1)
-
-         criterion = MaskedFocalLoss(alpha=class_weights, gamma=self.gamma)
-
-         reconstruction_loss = criterion(
-             logits_flat.to(self.device),
-             targets_flat.to(self.device),
-             valid_mask=mask_flat.to(self.device),
-         )
-
-         return reconstruction_loss
-
      def _resolve_activation(
          self, activation: Literal["relu", "elu", "leaky_relu", "selu"]
      ) -> torch.nn.Module:
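Because ``compute_loss`` (and the ``MaskedFocalLoss`` import) is removed here, the reconstruction loss is now built outside the model. A hedged sketch of what a caller might do with ``FocalCELoss`` instead; the shapes, the uniform class weights, and the variable names are illustrative assumptions rather than code from the package:

```python
import torch

from pgsui.impute.unsupervised.loss_functions import FocalCELoss

num_classes = 4  # assumed number of genotype classes
class_weights = torch.ones(num_classes)  # assumed per-class weights
criterion = FocalCELoss(alpha=class_weights, gamma=2.0, ignore_index=-1)

# reconstruction: logits from AutoencoderModel.forward(), shape (batch, loci, classes)
reconstruction = torch.randn(8, 100, num_classes)
targets = torch.randint(-1, num_classes, (8, 100))  # -1 marks missing values

loss = criterion(
    reconstruction.reshape(-1, num_classes),  # (N, C)
    targets.reshape(-1),                      # (N,)
)
```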