SURE-tools 2.4.13__py3-none-any.whl → 2.4.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -12,7 +12,7 @@ warnings.filterwarnings('ignore')
  class EfficientTranscriptomeDecoder:
  """
  High-performance, memory-efficient transcriptome decoder
- Combines latest research techniques for optimal performance
+ Fixed version with corrected RMSNorm implementation
  """
 
  def __init__(self,
@@ -20,7 +20,7 @@ class EfficientTranscriptomeDecoder:
  gene_dim: int = 60000,
  hidden_dims: List[int] = [512, 1024, 2048],
  bottleneck_dim: int = 256,
- num_experts: int = 8,
+ num_experts: int = 4,
  dropout_rate: float = 0.1,
  device: str = None):
  """
@@ -43,8 +43,8 @@ class EfficientTranscriptomeDecoder:
  self.dropout_rate = dropout_rate
  self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
 
- # Initialize model with advanced architecture
- self.model = self._build_advanced_model()
+ # Initialize model with corrected architecture
+ self.model = self._build_corrected_model()
  self.model.to(self.device)
 
  # Training state
@@ -58,74 +58,50 @@ class EfficientTranscriptomeDecoder:
  print(f" - Hidden Dimensions: {hidden_dims}")
  print(f" - Bottleneck Dimension: {bottleneck_dim}")
  print(f" - Number of Experts: {num_experts}")
- print(f" - Estimated GPU Memory: ~6-8GB")
  print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
 
- class SwiGLU(nn.Module):
- """SwiGLU activation - better than GELU (PaLM, LLaMA)"""
- def forward(self, x):
- x, gate = x.chunk(2, dim=-1)
- return x * F.silu(gate)
-
- class RMSNorm(nn.Module):
- """RMS Normalization - more stable than LayerNorm (GPT-3)"""
+ class CorrectedRMSNorm(nn.Module):
+ """Corrected RMS Normalization with proper dimension handling"""
  def __init__(self, dim: int, eps: float = 1e-8):
  super().__init__()
  self.eps = eps
- self.weight = nn.Parameter(torch.ones(dim))
+ self.dim = dim
+ self.weight = nn.Parameter(torch.ones(dim)) # Correct: weight has same dim as input
 
  def forward(self, x):
- norm_x = x.norm(2, dim=-1, keepdim=True)
- rms_x = norm_x * (x.shape[-1] ** -0.5)
- return x / (rms_x + self.eps) * self.weight
+ # Ensure input has the right dimension
+ if x.size(-1) != self.dim:
+ raise ValueError(f"Input dimension {x.size(-1)} doesn't match RMSNorm dimension {self.dim}")
+
+ # Calculate RMS
+ rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
+ # Normalize and apply weight
+ return x / rms * self.weight
 
- class MixtureOfExperts(nn.Module):
- """Mixture of Experts for conditional computation"""
- def __init__(self, input_dim: int, expert_dim: int, num_experts: int):
- super().__init__()
- self.num_experts = num_experts
- self.experts = nn.ModuleList([
- nn.Sequential(
- nn.Linear(input_dim, expert_dim),
- nn.Dropout(0.1),
- nn.Linear(expert_dim, input_dim)
- ) for _ in range(num_experts)
- ])
- self.gate = nn.Linear(input_dim, num_experts)
- self.expert_dim = expert_dim
-
+ class SimplifiedSwiGLU(nn.Module):
+ """Simplified SwiGLU activation"""
  def forward(self, x):
- # Gate network
- gate_logits = self.gate(x)
- gate_weights = F.softmax(gate_logits, dim=-1)
-
- # Expert outputs
- expert_outputs = []
- for i, expert in enumerate(self.experts):
- expert_out = expert(x)
- expert_outputs.append(expert_out.unsqueeze(-1))
-
- # Combine expert outputs
- expert_outputs = torch.cat(expert_outputs, dim=-1)
- output = torch.einsum('bd, bde -> be', gate_weights, expert_outputs)
-
- return output + x # Residual connection
+ # Split into two parts
+ x, gate = x.chunk(2, dim=-1)
+ return x * F.silu(gate)
 
- class AdaptiveBottleneck(nn.Module):
- """Adaptive bottleneck for memory efficiency"""
+ class MemoryEfficientBottleneck(nn.Module):
+ """Memory-efficient bottleneck with corrected dimensions"""
  def __init__(self, input_dim: int, bottleneck_dim: int, output_dim: int):
  super().__init__()
+ # Ensure proper dimension matching
  self.compress = nn.Linear(input_dim, bottleneck_dim)
- self.norm1 = EfficientTranscriptomeDecoder.RMSNorm(bottleneck_dim)
+ self.norm1 = EfficientTranscriptomeDecoder.CorrectedRMSNorm(bottleneck_dim)
  self.expand = nn.Linear(bottleneck_dim, output_dim)
- self.norm2 = EfficientTranscriptomeDecoder.RMSNorm(output_dim)
+ self.norm2 = EfficientTranscriptomeDecoder.CorrectedRMSNorm(output_dim)
+ self.activation = nn.SiLU()
  self.dropout = nn.Dropout(0.1)
 
  def forward(self, x):
  # Compress
  compressed = self.compress(x)
  compressed = self.norm1(compressed)
- compressed = F.silu(compressed)
+ compressed = self.activation(compressed)
  compressed = self.dropout(compressed)
 
  # Expand
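
For reference on the change above: the removed RMSNorm computed the root mean square via an L2 norm (norm_x * d**-0.5), while CorrectedRMSNorm computes sqrt(mean(x**2)) directly; the two agree apart from where the epsilon enters, so the fix is mainly the dimension check and epsilon placement. A minimal standalone sketch of the new formulation (PyTorch; the class name here is illustrative, not from the package):

    import torch
    import torch.nn as nn

    class RMSNormSketch(nn.Module):
        """Illustrative RMSNorm: x / sqrt(mean(x**2) + eps) * weight."""
        def __init__(self, dim: int, eps: float = 1e-8):
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))  # one scale per feature

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
            return x / rms * self.weight

    x = torch.randn(2, 8)
    print(RMSNormSketch(8)(x).shape)  # torch.Size([2, 8])
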
@@ -134,37 +110,57 @@ class EfficientTranscriptomeDecoder:
 
  return expanded
 
- class GeneSpecificProjection(nn.Module):
- """Gene-specific projection with weight sharing"""
- def __init__(self, latent_dim: int, gene_dim: int, proj_dim: int = 64):
+ class StableMixtureOfExperts(nn.Module):
+ """Stable mixture of experts without dimension issues"""
+ def __init__(self, input_dim: int, num_experts: int = 4):
  super().__init__()
- self.proj_dim = proj_dim
- self.gene_embeddings = nn.Parameter(torch.randn(gene_dim, proj_dim) * 0.02)
- self.latent_projection = nn.Linear(latent_dim, proj_dim)
- self.output_layer = nn.Linear(proj_dim, 1)
-
- def forward(self, latent):
- batch_size = latent.shape[0]
+ self.num_experts = num_experts
+ self.input_dim = input_dim
 
- # Project latent to gene space
- latent_proj = self.latent_projection(latent) # [batch_size, proj_dim]
+ # Shared expert with different scaling factors
+ self.shared_expert = nn.Sequential(
+ nn.Linear(input_dim, input_dim * 2),
+ nn.SiLU(),
+ nn.Dropout(0.1),
+ nn.Linear(input_dim * 2, input_dim)
+ )
 
- # Efficient matrix multiplication
- gene_output = torch.matmul(latent_proj, self.gene_embeddings.T) # [batch_size, gene_dim]
+ # Gating network
+ self.gate = nn.Sequential(
+ nn.Linear(input_dim, num_experts * 4),
+ nn.SiLU(),
+ nn.Linear(num_experts * 4, num_experts)
+ )
 
- return gene_output
+ def forward(self, x):
+ # Get gate weights
+ gate_weights = F.softmax(self.gate(x), dim=-1) # [batch, num_experts]
+
+ # Process through shared expert
+ expert_output = self.shared_expert(x) # [batch, input_dim]
+
+ # Apply expert-specific scaling
+ weighted_output = torch.zeros_like(expert_output)
+ for i in range(self.num_experts):
+ expert_scale = 0.5 + 0.5 * i # Different scaling for each expert
+ expert_contribution = expert_output * expert_scale
+ expert_weight = gate_weights[:, i].unsqueeze(-1) # [batch, 1]
+ weighted_output += expert_weight * expert_contribution
+
+ # Residual connection
+ return x + weighted_output
 
- class AdvancedDecoder(nn.Module):
- """Advanced decoder combining multiple techniques"""
+ class CorrectedDecoder(nn.Module):
+ """Corrected decoder with proper dimension handling"""
 
  def __init__(self, latent_dim: int, gene_dim: int, hidden_dims: List[int],
- bottleneck_dim: int, num_experts: int, dropout_rate: float):
+ bottleneck_dim: int, num_experts: int, dropout_rate: float):
  super().__init__()
 
- # Initial projection
+ # Input projection
  self.input_projection = nn.Sequential(
  nn.Linear(latent_dim, hidden_dims[0]),
- EfficientTranscriptomeDecoder.RMSNorm(hidden_dims[0]),
+ EfficientTranscriptomeDecoder.CorrectedRMSNorm(hidden_dims[0]),
  nn.SiLU(),
  nn.Dropout(dropout_rate)
  )
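
In the StableMixtureOfExperts introduced above, every expert shares a single MLP and differs only by the fixed scale 0.5 + 0.5 * i, so the gated sum collapses to the shared output multiplied by one per-sample scalar; the expert count therefore changes the gate's granularity rather than adding capacity. A short sketch of that equivalence (PyTorch; values are placeholders):

    import torch
    import torch.nn.functional as F

    num_experts = 4
    gate_weights = F.softmax(torch.randn(3, num_experts), dim=-1)  # [batch, num_experts]
    h = torch.randn(3, 16)                                         # shared expert output

    # Loop form, as in the diff above
    scales = 0.5 + 0.5 * torch.arange(num_experts, dtype=h.dtype)
    loop_out = torch.zeros_like(h)
    for i in range(num_experts):
        loop_out += gate_weights[:, i:i + 1] * (h * scales[i])

    # Equivalent closed form: h scaled by the gate's expected scale
    vectorized = h * (gate_weights @ scales).unsqueeze(-1)
    print(torch.allclose(loop_out, vectorized))  # True
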
@@ -173,71 +169,74 @@ class EfficientTranscriptomeDecoder:
  self.blocks = nn.ModuleList()
  current_dim = hidden_dims[0]
 
- for i, hidden_dim in enumerate(hidden_dims[1:], 1):
- block = nn.ModuleList([
- # Mixture of Experts
- EfficientTranscriptomeDecoder.MixtureOfExperts(current_dim, hidden_dim, num_experts),
-
- # Adaptive Bottleneck
- EfficientTranscriptomeDecoder.AdaptiveBottleneck(current_dim, bottleneck_dim, hidden_dim),
-
- # SwiGLU activation
- nn.Sequential(
- nn.Linear(hidden_dim, hidden_dim * 2),
- EfficientTranscriptomeDecoder.SwiGLU(),
- nn.Dropout(dropout_rate)
+ for i, next_dim in enumerate(hidden_dims[1:], 1):
+ block = nn.ModuleDict({
+ 'swiglu': nn.Sequential(
+ nn.Linear(current_dim, current_dim * 2),
+ EfficientTranscriptomeDecoder.SimplifiedSwiGLU(),
+ nn.Dropout(dropout_rate),
+ nn.Linear(current_dim, current_dim) # Project back to same dimension
+ ),
+ 'bottleneck': EfficientTranscriptomeDecoder.MemoryEfficientBottleneck(
+ current_dim, bottleneck_dim, next_dim
+ ),
+ 'experts': EfficientTranscriptomeDecoder.StableMixtureOfExperts(
+ next_dim, num_experts
  )
- ])
+ })
  self.blocks.append(block)
- current_dim = hidden_dim
+ current_dim = next_dim
 
- # Gene-specific projection
- self.gene_projection = EfficientTranscriptomeDecoder.GeneSpecificProjection(
- current_dim, gene_dim, proj_dim=128
+ # Final projection to gene dimension
+ self.final_projection = nn.Sequential(
+ nn.Linear(current_dim, current_dim * 2),
+ nn.SiLU(),
+ nn.Dropout(dropout_rate),
+ nn.Linear(current_dim * 2, gene_dim)
  )
 
- # Output scaling
+ # Output parameters
  self.output_scale = nn.Parameter(torch.ones(1))
  self.output_bias = nn.Parameter(torch.zeros(1))
 
  self._init_weights()
 
  def _init_weights(self):
- """Advanced weight initialization"""
+ """Proper weight initialization"""
  for module in self.modules():
  if isinstance(module, nn.Linear):
- # Kaiming init for SiLU/SwiGLU
- nn.init.kaiming_normal_(module.weight, nonlinearity='linear')
+ nn.init.xavier_uniform_(module.weight)
  if module.bias is not None:
  nn.init.zeros_(module.bias)
 
  def forward(self, x):
- # Initial projection
+ # Input projection
  x = self.input_projection(x)
 
  # Process through blocks
  for block in self.blocks:
- # Mixture of Experts
- expert_out = block[0](x)
+ # SwiGLU with residual
+ residual = x
+ x_swiglu = block['swiglu'](x)
+ x = x + x_swiglu # Residual connection
 
- # Adaptive Bottleneck
- bottleneck_out = block[1](expert_out)
+ # Bottleneck
+ x = block['bottleneck'](x)
 
- # SwiGLU activation with residual
- swiglu_out = block[2](bottleneck_out)
- x = x + swiglu_out # Residual connection
+ # Mixture of Experts with residual
+ x = block['experts'](x)
 
- # Final gene projection
- output = self.gene_projection(x)
+ # Final projection
+ x = self.final_projection(x)
 
  # Ensure non-negative output
- output = F.softplus(output * self.output_scale + self.output_bias)
+ x = F.softplus(x * self.output_scale + self.output_bias)
 
- return output
+ return x
 
- def _build_advanced_model(self):
- """Build the advanced decoder model"""
- return self.AdvancedDecoder(
+ def _build_corrected_model(self):
+ """Build the corrected model"""
+ return self.CorrectedDecoder(
  self.latent_dim, self.gene_dim, self.hidden_dims,
  self.bottleneck_dim, self.num_experts, self.dropout_rate
  )
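
With the defaults shown earlier (hidden_dims=[512, 1024, 2048], bottleneck_dim=256, gene_dim=60000), the per-block widths in CorrectedDecoder work out as sketched below (plain Python, illustrative only):

    hidden_dims = [512, 1024, 2048]
    bottleneck_dim = 256
    gene_dim = 60000

    current = hidden_dims[0]
    for next_dim in hidden_dims[1:]:
        # swiglu keeps the width, the bottleneck changes it, the experts keep it
        print(f"swiglu: {current} -> {current} | "
              f"bottleneck: {current} -> {bottleneck_dim} -> {next_dim} | "
              f"experts: {next_dim} -> {next_dim}")
        current = next_dim
    print(f"final projection: {current} -> {current * 2} -> {gene_dim} (softplus keeps counts non-negative)")
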
@@ -247,24 +246,14 @@ class EfficientTranscriptomeDecoder:
  train_expression: np.ndarray,
  val_latent: np.ndarray = None,
  val_expression: np.ndarray = None,
- batch_size: int = 16, # Smaller batches for memory efficiency
- num_epochs: int = 200,
+ batch_size: int = 32,
+ num_epochs: int = 100,
  learning_rate: float = 1e-4,
- checkpoint_path: str = 'efficient_decoder.pth') -> Dict:
+ checkpoint_path: str = 'transcriptome_decoder.pth') -> Dict:
  """
- Train with advanced optimization techniques
-
- Args:
- train_latent: Training latent variables
- train_expression: Training expression data
- val_latent: Validation latent variables
- val_expression: Validation expression data
- batch_size: Batch size (optimized for memory)
- num_epochs: Number of epochs
- learning_rate: Learning rate
- checkpoint_path: Model save path
+ Train the corrected decoder
  """
- print("🚀 Starting Advanced Training...")
+ print("🚀 Starting Training...")
 
  # Data preparation
  train_dataset = self._create_dataset(train_latent, train_expression)
@@ -273,114 +262,83 @@ class EfficientTranscriptomeDecoder:
  val_dataset = self._create_dataset(val_latent, val_expression)
  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
  val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
+ print(f"📈 Using provided validation data: {len(val_dataset)} samples")
  else:
+ # Auto split
  train_size = int(0.9 * len(train_dataset))
  val_size = len(train_dataset) - train_size
  train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
  train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, pin_memory=True)
  val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, pin_memory=True)
+ print(f"📈 Auto-split validation: {val_size} samples")
 
  print(f"📊 Training samples: {len(train_loader.dataset)}")
  print(f"📊 Validation samples: {len(val_loader.dataset)}")
  print(f"📊 Batch size: {batch_size}")
 
- # Advanced optimizer configuration
+ # Optimizer
  optimizer = optim.AdamW(
  self.model.parameters(),
  lr=learning_rate,
- weight_decay=0.1, # Stronger regularization
- betas=(0.9, 0.95), # Tuned betas
- eps=1e-8
+ weight_decay=0.01,
+ betas=(0.9, 0.999)
  )
 
- # Cosine annealing with warmup
- scheduler = optim.lr_scheduler.OneCycleLR(
- optimizer,
- max_lr=learning_rate * 5,
- epochs=num_epochs,
- steps_per_epoch=len(train_loader),
- pct_start=0.1,
- div_factor=10.0,
- final_div_factor=100.0
- )
+ # Scheduler
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
 
- # Advanced loss function
- def advanced_loss(pred, target):
- # 1. MSE loss for overall accuracy
+ # Loss function
+ def combined_loss(pred, target):
  mse_loss = F.mse_loss(pred, target)
-
- # 2. Poisson loss for count data
  poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()
-
- # 3. Correlation loss for pattern matching
- correlation_loss = 1 - self._pearson_correlation(pred, target)
-
- # 4. Sparsity loss for realistic distribution
- sparsity_loss = F.mse_loss(
- (pred < 1e-3).float().mean(),
- torch.tensor(0.85, device=pred.device) # Target sparsity
- )
-
- # 5. Spectral loss for smoothness
- spectral_loss = self._spectral_loss(pred, target)
-
- # Weighted combination
- total_loss = (mse_loss + 0.3 * poisson_loss + 0.2 * correlation_loss +
- 0.1 * sparsity_loss + 0.05 * spectral_loss)
-
- return total_loss, {
- 'mse': mse_loss.item(),
- 'poisson': poisson_loss.item(),
- 'correlation': correlation_loss.item(),
- 'sparsity': sparsity_loss.item(),
- 'spectral': spectral_loss.item()
- }
+ correlation = self._pearson_correlation(pred, target)
+ correlation_loss = 1 - correlation
+ return mse_loss + 0.3 * poisson_loss + 0.1 * correlation_loss
 
  # Training history
  history = {
  'train_loss': [], 'val_loss': [],
  'train_mse': [], 'val_mse': [],
  'train_correlation': [], 'val_correlation': [],
- 'learning_rates': [], 'grad_norms': []
+ 'learning_rates': []
  }
 
  best_val_loss = float('inf')
- patience = 25
+ patience = 20
  patience_counter = 0
 
- print("\n📈 Starting training with advanced techniques...")
+ print("\n📈 Starting training loop...")
  for epoch in range(1, num_epochs + 1):
- # Training phase
- train_loss, train_components, grad_norm = self._train_epoch_advanced(
- train_loader, optimizer, scheduler, advanced_loss
- )
+ # Training
+ train_metrics = self._train_epoch(train_loader, optimizer, combined_loss)
 
- # Validation phase
- val_loss, val_components = self._validate_epoch_advanced(val_loader, advanced_loss)
+ # Validation
+ val_metrics = self._validate_epoch(val_loader, combined_loss)
+
+ # Update scheduler
+ scheduler.step()
+ current_lr = optimizer.param_groups[0]['lr']
 
  # Record history
- history['train_loss'].append(train_loss)
- history['val_loss'].append(val_loss)
- history['train_mse'].append(train_components['mse'])
- history['val_mse'].append(val_components['mse'])
- history['train_correlation'].append(train_components['correlation'])
- history['val_correlation'].append(val_components['correlation'])
- history['learning_rates'].append(optimizer.param_groups[0]['lr'])
- history['grad_norms'].append(grad_norm)
-
- # Print detailed progress
+ history['train_loss'].append(train_metrics['loss'])
+ history['val_loss'].append(val_metrics['loss'])
+ history['train_mse'].append(train_metrics['mse'])
+ history['val_mse'].append(val_metrics['mse'])
+ history['train_correlation'].append(train_metrics['correlation'])
+ history['val_correlation'].append(val_metrics['correlation'])
+ history['learning_rates'].append(current_lr)
+
+ # Print progress
  if epoch % 10 == 0 or epoch == 1:
- lr = optimizer.param_groups[0]['lr']
  print(f"📍 Epoch {epoch:3d}/{num_epochs} | "
- f"Train: {train_loss:.4f} | "
- f"Val: {val_loss:.4f} | "
- f"Corr: {val_components['correlation']:.4f} | "
- f"LR: {lr:.2e} | "
- f"Grad: {grad_norm:.4f}")
-
- # Early stopping with patience
- if val_loss < best_val_loss:
- best_val_loss = val_loss
+ f"Train Loss: {train_metrics['loss']:.4f} | "
+ f"Val Loss: {val_metrics['loss']:.4f} | "
+ f"Correlation: {val_metrics['correlation']:.4f} | "
+ f"LR: {current_lr:.2e}")
+
+ # Early stopping
+ if val_metrics['loss'] < best_val_loss:
+ best_val_loss = val_metrics['loss']
  patience_counter = 0
  self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
  if epoch % 20 == 0:
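
The new combined_loss keeps three terms: MSE, a Poisson negative log-likelihood up to a constant (rate - count * log(rate)), and one minus the mean Pearson correlation, weighted 1 / 0.3 / 0.1. A self-contained sketch of the same combination (PyTorch; the pearson_corr helper is illustrative, not the package's _pearson_correlation):

    import torch
    import torch.nn.functional as F

    def pearson_corr(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Per-sample Pearson correlation across genes, averaged over the batch."""
        pred_c = pred - pred.mean(dim=1, keepdim=True)
        target_c = target - target.mean(dim=1, keepdim=True)
        num = (pred_c * target_c).sum(dim=1)
        den = torch.sqrt((pred_c ** 2).sum(dim=1) * (target_c ** 2).sum(dim=1))
        return (num / (den + 1e-8)).mean()

    def combined_loss_sketch(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        mse = F.mse_loss(pred, target)
        poisson = (pred - target * torch.log(pred + 1e-8)).mean()  # Poisson NLL up to a constant
        return mse + 0.3 * poisson + 0.1 * (1 - pearson_corr(pred, target))

    pred = F.softplus(torch.randn(4, 100))      # non-negative, like the decoder output
    target = torch.poisson(torch.ones(4, 100))  # toy count data
    print(combined_loss_sketch(pred, target))
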
@@ -402,8 +360,8 @@ class EfficientTranscriptomeDecoder:
  return history
 
  def _create_dataset(self, latent_data, expression_data):
- """Create memory-efficient dataset"""
- class EfficientDataset(Dataset):
+ """Create dataset"""
+ class SimpleDataset(Dataset):
  def __init__(self, latent, expression):
  self.latent = torch.FloatTensor(latent)
  self.expression = torch.FloatTensor(expression)
@@ -414,7 +372,7 @@ class EfficientTranscriptomeDecoder:
  def __getitem__(self, idx):
  return self.latent[idx], self.expression[idx]
 
- return EfficientDataset(latent_data, expression_data)
+ return SimpleDataset(latent_data, expression_data)
 
  def _pearson_correlation(self, pred, target):
  """Calculate Pearson correlation"""
@@ -426,68 +384,48 @@ class EfficientTranscriptomeDecoder:
 
  return (numerator / (denominator + 1e-8)).mean()
 
- def _spectral_loss(self, pred, target):
- """Spectral loss for frequency domain matching"""
- pred_fft = torch.fft.fft(pred, dim=1)
- target_fft = torch.fft.fft(target, dim=1)
-
- magnitude_loss = F.mse_loss(torch.abs(pred_fft), torch.abs(target_fft))
- phase_loss = F.mse_loss(torch.angle(pred_fft), torch.angle(target_fft))
-
- return magnitude_loss + 0.5 * phase_loss
-
- def _train_epoch_advanced(self, train_loader, optimizer, scheduler, loss_fn):
- """Advanced training with gradient accumulation"""
+ def _train_epoch(self, train_loader, optimizer, loss_fn):
+ """Train one epoch"""
  self.model.train()
  total_loss = 0
- total_components = {'mse': 0, 'poisson': 0, 'correlation': 0, 'sparsity': 0, 'spectral': 0}
- grad_norms = []
-
- # Gradient accumulation for effective larger batch size
- accumulation_steps = 4
- optimizer.zero_grad()
+ total_mse = 0
+ total_correlation = 0
 
- for i, (latent, target) in enumerate(train_loader):
+ for latent, target in train_loader:
  latent = latent.to(self.device, non_blocking=True)
  target = target.to(self.device, non_blocking=True)
 
- # Forward pass with mixed precision
- with torch.cuda.amp.autocast(): # Mixed precision for memory efficiency
- pred = self.model(latent)
- loss, components = loss_fn(pred, target)
+ optimizer.zero_grad()
+ pred = self.model(latent)
 
- # Scale loss for gradient accumulation
- loss = loss / accumulation_steps
+ loss = loss_fn(pred, target)
  loss.backward()
 
- # Gradient accumulation
- if (i + 1) % accumulation_steps == 0:
- # Gradient clipping
- grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
- optimizer.step()
- scheduler.step()
- optimizer.zero_grad()
-
- grad_norms.append(grad_norm.item())
+ # Gradient clipping
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+ optimizer.step()
+
+ # Calculate metrics
+ mse_loss = F.mse_loss(pred, target).item()
+ correlation = self._pearson_correlation(pred, target).item()
 
- # Accumulate losses
- total_loss += loss.item() * accumulation_steps
- for key in total_components:
- total_components[key] += components[key]
+ total_loss += loss.item()
+ total_mse += mse_loss
+ total_correlation += correlation
 
- # Average metrics
  num_batches = len(train_loader)
- avg_loss = total_loss / num_batches
- avg_components = {key: value / num_batches for key, value in total_components.items()}
- avg_grad_norm = np.mean(grad_norms) if grad_norms else 0.0
-
- return avg_loss, avg_components, avg_grad_norm
+ return {
+ 'loss': total_loss / num_batches,
+ 'mse': total_mse / num_batches,
+ 'correlation': total_correlation / num_batches
+ }
 
- def _validate_epoch_advanced(self, val_loader, loss_fn):
- """Advanced validation"""
+ def _validate_epoch(self, val_loader, loss_fn):
+ """Validate one epoch"""
  self.model.eval()
  total_loss = 0
- total_components = {'mse': 0, 'poisson': 0, 'correlation': 0, 'sparsity': 0, 'spectral': 0}
+ total_mse = 0
+ total_correlation = 0
 
  with torch.no_grad():
  for latent, target in val_loader:
@@ -495,17 +433,20 @@ class EfficientTranscriptomeDecoder:
  target = target.to(self.device, non_blocking=True)
 
  pred = self.model(latent)
- loss, components = loss_fn(pred, target)
+ loss = loss_fn(pred, target)
+ mse_loss = F.mse_loss(pred, target).item()
+ correlation = self._pearson_correlation(pred, target).item()
 
  total_loss += loss.item()
- for key in total_components:
- total_components[key] += components[key]
+ total_mse += mse_loss
+ total_correlation += correlation
 
  num_batches = len(val_loader)
- avg_loss = total_loss / num_batches
- avg_components = {key: value / num_batches for key, value in total_components.items()}
-
- return avg_loss, avg_components
+ return {
+ 'loss': total_loss / num_batches,
+ 'mse': total_mse / num_batches,
+ 'correlation': total_correlation / num_batches
+ }
 
  def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
  """Save checkpoint"""
@@ -525,8 +466,8 @@ class EfficientTranscriptomeDecoder:
  }
  }, path)
 
- def predict(self, latent_data: np.ndarray, batch_size: int = 16) -> np.ndarray:
- """Memory-efficient prediction"""
+ def predict(self, latent_data: np.ndarray, batch_size: int = 32) -> np.ndarray:
+ """Predict gene expression"""
  if not self.is_trained:
  warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")
 
@@ -539,15 +480,8 @@ class EfficientTranscriptomeDecoder:
  with torch.no_grad():
  for i in range(0, len(latent_data), batch_size):
  batch_latent = latent_data[i:i+batch_size].to(self.device)
-
- with torch.cuda.amp.autocast(): # Mixed precision for memory
- batch_pred = self.model(batch_latent)
-
+ batch_pred = self.model(batch_latent)
  predictions.append(batch_pred.cpu())
-
- # Clear memory
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
 
  return torch.cat(predictions).numpy()
 
@@ -559,11 +493,23 @@ class EfficientTranscriptomeDecoder:
  self.training_history = checkpoint.get('training_history')
  self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
  print(f"✅ Model loaded! Best val loss: {self.best_val_loss:.4f}")
+
+ def get_model_info(self) -> Dict:
+ """Get model information"""
+ return {
+ 'is_trained': self.is_trained,
+ 'best_val_loss': self.best_val_loss,
+ 'parameters': sum(p.numel() for p in self.model.parameters()),
+ 'latent_dim': self.latent_dim,
+ 'gene_dim': self.gene_dim,
+ 'hidden_dims': self.hidden_dims,
+ 'device': str(self.device)
+ }
 
  '''
  # Example usage
  def example_usage():
- """Demonstrate the advanced decoder"""
+ """Example demonstration"""
 
  # Initialize decoder
  decoder = EfficientTranscriptomeDecoder(
@@ -590,7 +536,7 @@ def example_usage():
  history = decoder.train(
  train_latent=latent_data,
  train_expression=expression_data,
- batch_size=16,
+ batch_size=32,
  num_epochs=50
  )
 
@@ -603,5 +549,4 @@ def example_usage():
 
  if __name__ == "__main__":
  example_usage()
-
  '''
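
The commented-out example above exercises train; the prediction path changed in this release (no mixed precision, no cache clearing, default batch_size=32) could be exercised as in the sketch below, assuming the class and signatures shown in this diff are importable from the installed package (constructor arguments and data are placeholders):

    import numpy as np

    decoder = EfficientTranscriptomeDecoder(latent_dim=100, gene_dim=60000)  # illustrative dims
    latent = np.random.randn(64, 100).astype(np.float32)                     # 64 cells, 100 latent dims

    expression = decoder.predict(latent, batch_size=32)  # warns if the model has not been trained
    print(expression.shape)  # expected (64, 60000), non-negative thanks to the softplus output
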