SURE-tools 2.4.13__py3-none-any.whl → 2.4.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. Click here for more details.

@@ -0,0 +1,723 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.optim as optim
5
+ from torch.utils.data import Dataset, DataLoader
6
+ import numpy as np
7
+ from typing import Dict, List, Optional, Tuple, Union
8
+ import math
9
+ import warnings
10
+ warnings.filterwarnings('ignore')
11
+
12
class PerturbationAwareDecoder:
    """
    Perturbation-aware transcriptome decoder.

    Wraps a ``SingleLayerCompatibleDecoder`` network together with a training
    loop, checkpointing and batched prediction helpers.

    ``hidden_dims[0]`` sizes the perturbation encoder, the fusion layer and the
    novel-perturbation generator; any remaining entries become hidden layers of
    the response network. A single-element list therefore yields a response
    network that projects the fused representation directly to gene space.
    """

    def __init__(self,
                 latent_dim: int = 100,
                 num_known_perturbations: int = 50,
                 gene_dim: int = 60000,
                 hidden_dims: List[int] = [512],  # never mutated; a copy is stored
                 perturbation_embedding_dim: int = 128,
                 biological_prior_dim: int = 256,
                 dropout_rate: float = 0.1,
                 device: str = None):
        """
        Build the model and move it to the target device.

        Args:
            latent_dim: Latent variable dimension
            num_known_perturbations: Number of known perturbation types
            gene_dim: Number of genes (output width)
            hidden_dims: Hidden layer dimensions (at least one element)
            perturbation_embedding_dim: Embedding dimension for perturbations
            biological_prior_dim: Stored and forwarded to the model, which
                currently does not use it
            dropout_rate: Dropout probability applied in every sub-module
            device: Computation device; defaults to 'cuda' when available,
                otherwise 'cpu'

        Raises:
            AssertionError: If ``hidden_dims`` is empty or contains a
                non-positive value.
        """
        self.latent_dim = latent_dim
        self.num_known_perturbations = num_known_perturbations
        self.gene_dim = gene_dim
        # Copy so later changes to the caller's list (or to the shared default
        # list object) cannot alter this instance's configuration.
        self.hidden_dims = list(hidden_dims)
        self.perturbation_embedding_dim = perturbation_embedding_dim
        self.biological_prior_dim = biological_prior_dim
        self.dropout_rate = dropout_rate
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

        # Validate hidden_dims
        self._validate_hidden_dims()

        # Initialize multi-modal model
        self.model = self._build_single_layer_compatible_model()
        self.model.to(self.device)

        # Training state
        self.is_trained = False
        self.training_history = None
        self.best_val_loss = float('inf')
        self.known_perturbation_names = []
        self.perturbation_prototypes = None

        print("🧬 PerturbationAwareDecoder Initialized:")
        print(f" - Latent Dimension: {latent_dim}")
        print(f" - Known Perturbations: {num_known_perturbations}")
        print(f" - Gene Dimension: {gene_dim}")
        print(f" - Hidden Dimensions: {hidden_dims}")
        print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")

    def _validate_hidden_dims(self):
        """Assert that hidden_dims is non-empty and strictly positive, then report the layout."""
        assert len(self.hidden_dims) >= 1, "hidden_dims must have at least one element"
        assert all(dim > 0 for dim in self.hidden_dims), "All hidden dimensions must be positive"

        if len(self.hidden_dims) == 1:
            print("🔧 Single hidden layer configuration detected")
        else:
            print(f"🔧 Multi-layer configuration: {len(self.hidden_dims)} hidden layers")

    class SingleLayerCompatiblePerturbationEncoder(nn.Module):
        """Embed a perturbation index and project it to ``hidden_dim``."""

        def __init__(self, num_perturbations: int, embedding_dim: int, hidden_dim: int,
                     dropout_rate: float = 0.1):
            super().__init__()
            self.num_perturbations = num_perturbations

            # Embedding for perturbation types
            self.perturbation_embedding = nn.Embedding(num_perturbations, embedding_dim)

            # Single projection layer (compatible with single hidden layer)
            self.projection = nn.Sequential(
                nn.Linear(embedding_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate)
            )

        def forward(self, one_hot_perturbations):
            """Return the projected embedding for each row of a one-hot matrix."""
            # Convert one-hot to indices.
            # NOTE(review): argmax picks index 0 for an all-zero row, so control
            # rows (which the wrapper's validation accepts) share the embedding
            # of perturbation 0 -- confirm this is intended.
            perturbation_indices = torch.argmax(one_hot_perturbations, dim=1)

            # Get perturbation embeddings
            perturbation_embeds = self.perturbation_embedding(perturbation_indices)

            # Single projection
            return self.projection(perturbation_embeds)

    class SingleLayerCrossModalFusion(nn.Module):
        """Gated fusion of the latent vector with the encoded perturbation."""

        def __init__(self, latent_dim: int, perturbation_dim: int, fusion_dim: int,
                     dropout_rate: float = 0.1):
            super().__init__()
            self.latent_projection = nn.Linear(latent_dim, fusion_dim)
            self.perturbation_projection = nn.Linear(perturbation_dim, fusion_dim)

            # Gate computed from both projections concatenated
            self.fusion_gate = nn.Sequential(
                nn.Linear(fusion_dim * 2, fusion_dim),
                nn.Sigmoid()
            )

            self.norm = nn.LayerNorm(fusion_dim)
            self.dropout = nn.Dropout(dropout_rate)

        def forward(self, latent, perturbation_encoded):
            """Return ``dropout(norm(g * latent_proj + (1 - g) * perturbation_proj))``."""
            # Project both modalities to the shared fusion width
            latent_proj = self.latent_projection(latent)
            perturbation_proj = self.perturbation_projection(perturbation_encoded)

            # Element-wise gate in (0, 1)
            concatenated = torch.cat([latent_proj, perturbation_proj], dim=-1)
            fusion_gate = self.fusion_gate(concatenated)

            # Gated fusion
            fused = fusion_gate * latent_proj + (1 - fusion_gate) * perturbation_proj
            fused = self.norm(fused)
            fused = self.dropout(fused)

            return fused

    class SingleLayerPerturbationResponseNetwork(nn.Module):
        """Map the fused representation to non-negative per-gene expression."""

        def __init__(self, fusion_dim: int, gene_dim: int, hidden_dims: List[int],
                     dropout_rate: float = 0.1):
            """
            Args:
                fusion_dim: Width of the fused input
                gene_dim: Output width
                hidden_dims: Widths of the hidden layers; may be empty, in
                    which case the input goes straight to the final projection
                dropout_rate: Dropout probability after each hidden layer
            """
            super().__init__()

            # One Linear/BatchNorm/ReLU/Dropout group per hidden width. With an
            # empty list the Sequential has no modules and passes its input
            # through unchanged.
            layers = []
            input_dim = fusion_dim

            for hidden_dim in hidden_dims:
                layers.extend([
                    nn.Linear(input_dim, hidden_dim),
                    nn.BatchNorm1d(hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(dropout_rate)
                ])
                input_dim = hidden_dim

            self.base_network = nn.Sequential(*layers)

            # Final projection. input_dim is the last hidden width, or
            # fusion_dim when there are no hidden layers (indexing
            # hidden_dims[-1] here would fail on an empty list).
            self.final_projection = nn.Linear(input_dim, gene_dim)

            # Perturbation-aware scaling: one scalar scale and bias per sample
            self.scale = nn.Linear(fusion_dim, 1)
            self.bias = nn.Linear(fusion_dim, 1)

        def forward(self, fused_representation):
            """Return ``softplus(expression * scale + bias)``."""
            base_output = self.base_network(fused_representation)
            expression = self.final_projection(base_output)

            # Scale lies in (0, 2); bias is unbounded
            scale = torch.sigmoid(self.scale(fused_representation)) * 2
            bias = self.bias(fused_representation)

            return F.softplus(expression * scale + bias)

    class SingleLayerNovelPerturbationPredictor(nn.Module):
        """Predict a response from similarities to the known perturbations."""

        def __init__(self, num_known_perturbations: int, gene_dim: int, hidden_dim: int):
            super().__init__()
            self.num_known_perturbations = num_known_perturbations
            self.gene_dim = gene_dim

            # Learnable perturbation prototypes, one gene-space row per known
            # perturbation
            self.perturbation_prototypes = nn.Parameter(
                torch.randn(num_known_perturbations, gene_dim) * 0.1
            )

            # Response generator fed directly with the similarity vector
            if hidden_dim > 0:
                self.response_generator = nn.Sequential(
                    nn.Linear(num_known_perturbations, hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, gene_dim)
                )
            else:
                # Direct projection if no hidden layer
                self.response_generator = nn.Linear(num_known_perturbations, gene_dim)

            # Softmax-normalised attention over the known perturbations
            self.similarity_attention = nn.Sequential(
                nn.Linear(num_known_perturbations, num_known_perturbations),
                nn.Softmax(dim=-1)
            )

        def forward(self, similarity_matrix, latent_features=None):
            """
            Blend a prototype-weighted response with a generated response.

            ``latent_features`` is accepted for interface compatibility but is
            not used by this method.
            """
            # Method 1: Attention-weighted combination of known responses
            attention_weights = self.similarity_attention(similarity_matrix)
            weighted_response = torch.matmul(attention_weights, self.perturbation_prototypes)

            # Method 2: Direct generation from similarity
            generated_response = self.response_generator(similarity_matrix)

            # Per-sample mixing weight derived from the mean similarity
            combination_weights = torch.sigmoid(similarity_matrix.mean(dim=1, keepdim=True))
            final_response = (combination_weights * weighted_response +
                              (1 - combination_weights) * generated_response)

            return final_response

    class SingleLayerCompatibleDecoder(nn.Module):
        """Main network: known-perturbation pathway plus novel-perturbation pathway."""

        def __init__(self, latent_dim: int, num_known_perturbations: int, gene_dim: int,
                     hidden_dims: List[int], perturbation_embedding_dim: int,
                     biological_prior_dim: int, dropout_rate: float):
            """
            Args:
                latent_dim: Latent variable dimension
                num_known_perturbations: Number of known perturbation types
                gene_dim: Number of genes
                hidden_dims: First entry sizes encoder/fusion/novel predictor;
                    remaining entries are the response network's hidden layers
                perturbation_embedding_dim: Embedding dimension for perturbations
                biological_prior_dim: Accepted but not used by this module
                dropout_rate: Dropout probability passed to the encoder, fusion
                    and response sub-modules
            """
            super().__init__()

            self.num_known_perturbations = num_known_perturbations
            self.latent_dim = latent_dim
            self.gene_dim = gene_dim

            # First width for fusion, rest for the response network. For a
            # single-element list the remainder is empty, i.e. the response
            # network has no additional hidden layers.
            main_hidden_dim = hidden_dims[0]
            response_hidden_dims = list(hidden_dims[1:])

            # Perturbation encoder
            self.perturbation_encoder = PerturbationAwareDecoder.SingleLayerCompatiblePerturbationEncoder(
                num_known_perturbations, perturbation_embedding_dim, main_hidden_dim,
                dropout_rate
            )

            # Cross-modal fusion
            self.cross_modal_fusion = PerturbationAwareDecoder.SingleLayerCrossModalFusion(
                latent_dim, main_hidden_dim, main_hidden_dim, dropout_rate
            )

            # Response network
            self.response_network = PerturbationAwareDecoder.SingleLayerPerturbationResponseNetwork(
                main_hidden_dim, gene_dim, response_hidden_dims, dropout_rate
            )

            # Novel perturbation predictor
            self.novel_predictor = PerturbationAwareDecoder.SingleLayerNovelPerturbationPredictor(
                num_known_perturbations, gene_dim, main_hidden_dim
            )

        def forward(self, latent, perturbation_matrix, mode='one_hot'):
            """
            Decode expression.

            Args:
                latent: Latent input (only used when ``mode == 'one_hot'``)
                perturbation_matrix: One-hot rows for 'one_hot', similarity
                    rows for 'similarity'
                mode: 'one_hot' or 'similarity'

            Raises:
                ValueError: For any other ``mode``.
            """
            if mode == 'one_hot':
                # Known perturbation pathway
                perturbation_encoded = self.perturbation_encoder(perturbation_matrix)
                fused = self.cross_modal_fusion(latent, perturbation_encoded)
                expression = self.response_network(fused)

            elif mode == 'similarity':
                # Novel perturbation pathway
                expression = self.novel_predictor(perturbation_matrix, latent)

            else:
                raise ValueError(f"Unknown mode: {mode}. Use 'one_hot' or 'similarity'")

            return expression

        def get_perturbation_prototypes(self):
            """Get learned perturbation response prototypes (detached tensor)."""
            return self.novel_predictor.perturbation_prototypes.detach()

    def _build_single_layer_compatible_model(self):
        """Instantiate the decoder network from this wrapper's configuration."""
        return self.SingleLayerCompatibleDecoder(
            self.latent_dim, self.num_known_perturbations, self.gene_dim,
            self.hidden_dims, self.perturbation_embedding_dim,
            self.biological_prior_dim, self.dropout_rate
        )

    def train(self,
              train_latent: np.ndarray,
              train_perturbations: np.ndarray,
              train_expression: np.ndarray,
              val_latent: np.ndarray = None,
              val_perturbations: np.ndarray = None,
              val_expression: np.ndarray = None,
              batch_size: int = 32,
              num_epochs: int = 200,
              learning_rate: float = 1e-4,
              checkpoint_path: str = 'single_layer_decoder.pth') -> Dict:
        """
        Train the known-perturbation ('one_hot') pathway.

        If any of the three validation arrays is missing, 10% of the training
        set is split off at random for validation. A checkpoint is written to
        ``checkpoint_path`` whenever the validation loss improves; training
        stops early after 20 epochs without improvement.

        The model keeps the weights of the last epoch that ran; the
        best-validation weights are only available from the checkpoint file.

        Args:
            train_latent: Training latent variables
            train_perturbations: Training one-hot perturbation rows
            train_expression: Training target expression
            val_latent: Optional validation latent variables
            val_perturbations: Optional validation perturbation rows
            val_expression: Optional validation target expression
            batch_size: Mini-batch size for both loaders
            num_epochs: Maximum number of epochs (also the cosine schedule length)
            learning_rate: Initial AdamW learning rate
            checkpoint_path: Where the best checkpoint is saved

        Returns:
            History dict with per-epoch lists under 'train_loss', 'val_loss',
            'train_mse', 'val_mse', 'train_correlation', 'val_correlation'
            and 'learning_rates'.
        """
        print("🧬 Starting Training with Single Layer Support...")

        # Validate one-hot encoding
        self._validate_one_hot_perturbations(train_perturbations)

        # Data preparation
        train_dataset = self._create_dataset(train_latent, train_perturbations, train_expression)

        if val_latent is not None and val_perturbations is not None and val_expression is not None:
            val_dataset = self._create_dataset(val_latent, val_perturbations, val_expression)
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        else:
            train_size = int(0.9 * len(train_dataset))
            val_size = len(train_dataset) - train_size
            train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
            train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
            val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

        print(f"📊 Training samples: {len(train_loader.dataset)}")
        print(f"📊 Validation samples: {len(val_loader.dataset)}")
        print(f"🔧 Hidden layers: {len(self.hidden_dims)}")
        print(f"🔧 Hidden dimensions: {self.hidden_dims}")

        # Optimizer
        optimizer = optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=1e-5
        )

        # Scheduler
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

        # Loss function: MSE + 0.3 * Poisson-style term + 0.1 * (1 - Pearson r)
        def loss_fn(pred, target):
            mse_loss = F.mse_loss(pred, target)
            # Small epsilon keeps log() finite when a prediction is ~0
            poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()
            correlation = self._pearson_correlation(pred, target)
            correlation_loss = 1 - correlation
            return mse_loss + 0.3 * poisson_loss + 0.1 * correlation_loss

        # Training history
        history = {
            'train_loss': [], 'val_loss': [],
            'train_mse': [], 'val_mse': [],
            'train_correlation': [], 'val_correlation': [],
            'learning_rates': []
        }

        best_val_loss = float('inf')
        patience = 20
        patience_counter = 0

        print("\n🔬 Starting training...")
        for epoch in range(1, num_epochs + 1):
            # Training
            train_metrics = self._train_epoch(train_loader, optimizer, loss_fn)

            # Validation
            val_metrics = self._validate_epoch(val_loader, loss_fn)

            # Update scheduler
            scheduler.step()
            current_lr = optimizer.param_groups[0]['lr']

            # Record history
            history['train_loss'].append(train_metrics['loss'])
            history['val_loss'].append(val_metrics['loss'])
            history['train_mse'].append(train_metrics['mse'])
            history['val_mse'].append(val_metrics['mse'])
            history['train_correlation'].append(train_metrics['correlation'])
            history['val_correlation'].append(val_metrics['correlation'])
            history['learning_rates'].append(current_lr)

            # Print progress
            if epoch % 10 == 0 or epoch == 1:
                print(f"🧪 Epoch {epoch:3d}/{num_epochs} | "
                      f"Train: {train_metrics['loss']:.4f} | "
                      f"Val: {val_metrics['loss']:.4f} | "
                      f"Corr: {val_metrics['correlation']:.4f} | "
                      f"LR: {current_lr:.2e}")

            # Early stopping
            if val_metrics['loss'] < best_val_loss:
                best_val_loss = val_metrics['loss']
                patience_counter = 0
                self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"🛑 Early stopping at epoch {epoch}")
                    break

        self.is_trained = True
        self.training_history = history
        self.best_val_loss = best_val_loss
        self.perturbation_prototypes = self.model.get_perturbation_prototypes().cpu().numpy()

        print(f"\n🎉 Training completed! Best val loss: {best_val_loss:.4f}")
        return history

    def _validate_one_hot_perturbations(self, perturbations):
        """
        Check the perturbation matrix's width and row sums.

        Only the column count and the per-row sum (0 or 1) are checked; the
        individual entries are not verified to be 0/1.

        Raises:
            AssertionError: On a column-count mismatch or a row sum other
                than 0 or 1.
        """
        assert perturbations.shape[1] == self.num_known_perturbations, \
            f"Perturbation dimension {perturbations.shape[1]} doesn't match expected {self.num_known_perturbations}"

        row_sums = perturbations.sum(axis=1)
        valid_rows = np.all((row_sums == 0) | (row_sums == 1))
        assert valid_rows, "Perturbations should be one-hot encoded (sum to 0 or 1 per row)"

        print("✅ One-hot perturbations validated")

    def _create_dataset(self, latent_data, perturbations, expression_data):
        """Wrap the three arrays as float tensors in a (latent, perturbation, expression) Dataset."""
        class OneHotDataset(Dataset):
            def __init__(self, latent, perturbations, expression):
                self.latent = torch.FloatTensor(latent)
                self.perturbations = torch.FloatTensor(perturbations)
                self.expression = torch.FloatTensor(expression)

            def __len__(self):
                return len(self.latent)

            def __getitem__(self, idx):
                return self.latent[idx], self.perturbations[idx], self.expression[idx]

        return OneHotDataset(latent_data, perturbations, expression_data)

    def predict(self,
                latent_data: np.ndarray,
                perturbations: np.ndarray,
                batch_size: int = 32) -> np.ndarray:
        """
        Predict expression for known perturbations.

        Args:
            latent_data: Latent variables, one row per sample
            perturbations: One-hot perturbation rows aligned with ``latent_data``
            batch_size: Number of samples per forward pass

        Returns:
            Predicted expression as a NumPy array, rows in input order.

        Raises:
            AssertionError: If ``perturbations`` fails the one-hot checks.
        """
        if not self.is_trained:
            warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")

        self._validate_one_hot_perturbations(perturbations)

        self.model.eval()

        if isinstance(latent_data, np.ndarray):
            latent_data = torch.FloatTensor(latent_data)
        if isinstance(perturbations, np.ndarray):
            perturbations = torch.FloatTensor(perturbations)

        predictions = []
        with torch.no_grad():
            for i in range(0, len(latent_data), batch_size):
                batch_latent = latent_data[i:i+batch_size].to(self.device)
                batch_perturbations = perturbations[i:i+batch_size].to(self.device)

                batch_pred = self.model(batch_latent, batch_perturbations, mode='one_hot')
                predictions.append(batch_pred.cpu())

        return torch.cat(predictions).numpy()

    def predict_novel_perturbation(self,
                                   latent_data: np.ndarray,
                                   similarity_matrix: np.ndarray,
                                   batch_size: int = 32) -> np.ndarray:
        """
        Predict response to novel perturbations.

        Args:
            latent_data: Latent variables; used for batching and forwarded to
                the model's 'similarity' pathway
            similarity_matrix: Similarity of each sample's perturbation to
                every known perturbation (one column per known perturbation)
            batch_size: Number of samples per forward pass

        Returns:
            Predicted response as a NumPy array, rows in input order.

        Raises:
            AssertionError: If the similarity matrix has the wrong column count.
        """
        if not self.is_trained:
            warnings.warn("⚠️ Model not trained. Novel perturbation prediction may be inaccurate.")

        assert similarity_matrix.shape[1] == self.num_known_perturbations, \
            f"Similarity matrix columns {similarity_matrix.shape[1]} must match known perturbations {self.num_known_perturbations}"

        self.model.eval()

        if isinstance(latent_data, np.ndarray):
            latent_data = torch.FloatTensor(latent_data)
        if isinstance(similarity_matrix, np.ndarray):
            similarity_matrix = torch.FloatTensor(similarity_matrix)

        predictions = []
        with torch.no_grad():
            for i in range(0, len(latent_data), batch_size):
                batch_latent = latent_data[i:i+batch_size].to(self.device)
                batch_similarity = similarity_matrix[i:i+batch_size].to(self.device)

                batch_pred = self.model(batch_latent, batch_similarity, mode='similarity')
                predictions.append(batch_pred.cpu())

        return torch.cat(predictions).numpy()

    def get_known_perturbation_prototypes(self) -> np.ndarray:
        """Get learned perturbation response prototypes, reading them from the model if not cached."""
        if not self.is_trained:
            warnings.warn("⚠️ Model not trained. Prototypes may be uninformative.")

        if self.perturbation_prototypes is None:
            self.model.eval()
            with torch.no_grad():
                self.perturbation_prototypes = self.model.get_perturbation_prototypes().cpu().numpy()

        return self.perturbation_prototypes

    def _pearson_correlation(self, pred, target):
        """Return the Pearson correlation computed per row (along dim 1) and averaged over rows."""
        pred_centered = pred - pred.mean(dim=1, keepdim=True)
        target_centered = target - target.mean(dim=1, keepdim=True)

        numerator = (pred_centered * target_centered).sum(dim=1)
        denominator = torch.sqrt(torch.sum(pred_centered ** 2, dim=1)) * torch.sqrt(torch.sum(target_centered ** 2, dim=1))

        # Epsilon avoids division by zero for constant rows
        return (numerator / (denominator + 1e-8)).mean()

    def _train_epoch(self, train_loader, optimizer, loss_fn):
        """
        Run one optimisation pass over ``train_loader``.

        Returns:
            Dict with batch-averaged 'loss', 'mse' and 'correlation'.
        """
        self.model.train()
        total_loss = 0
        total_mse = 0
        total_correlation = 0

        for latent, perturbations, target in train_loader:
            latent = latent.to(self.device)
            perturbations = perturbations.to(self.device)
            target = target.to(self.device)

            optimizer.zero_grad()
            pred = self.model(latent, perturbations, mode='one_hot')

            loss = loss_fn(pred, target)
            loss.backward()

            # Clip the global gradient norm before stepping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            optimizer.step()

            # Metrics use the predictions made before this step's update
            mse_loss = F.mse_loss(pred, target).item()
            correlation = self._pearson_correlation(pred, target).item()

            total_loss += loss.item()
            total_mse += mse_loss
            total_correlation += correlation

        num_batches = len(train_loader)
        return {
            'loss': total_loss / num_batches,
            'mse': total_mse / num_batches,
            'correlation': total_correlation / num_batches
        }

    def _validate_epoch(self, val_loader, loss_fn):
        """
        Evaluate the model on ``val_loader`` without gradient tracking.

        Returns:
            Dict with batch-averaged 'loss', 'mse' and 'correlation'.
        """
        self.model.eval()
        total_loss = 0
        total_mse = 0
        total_correlation = 0

        with torch.no_grad():
            for latent, perturbations, target in val_loader:
                latent = latent.to(self.device)
                perturbations = perturbations.to(self.device)
                target = target.to(self.device)

                pred = self.model(latent, perturbations, mode='one_hot')
                loss = loss_fn(pred, target)
                mse_loss = F.mse_loss(pred, target).item()
                correlation = self._pearson_correlation(pred, target).item()

                total_loss += loss.item()
                total_mse += mse_loss
                total_correlation += correlation

        num_batches = len(val_loader)
        return {
            'loss': total_loss / num_batches,
            'mse': total_mse / num_batches,
            'correlation': total_correlation / num_batches
        }

    def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
        """Save model, optimizer and scheduler state plus history and config to ``path``."""
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_loss': best_loss,
            'training_history': history,
            # Whatever the attribute holds right now: train() only assigns it
            # after its epoch loop, so mid-training this is None or a value
            # left over from an earlier run.
            'perturbation_prototypes': self.perturbation_prototypes,
            'model_config': {
                'latent_dim': self.latent_dim,
                'num_known_perturbations': self.num_known_perturbations,
                'gene_dim': self.gene_dim,
                'hidden_dims': self.hidden_dims
            }
        }, path)

    def load_model(self, model_path: str):
        """
        Load weights and training metadata from a checkpoint written by ``_save_checkpoint``.

        The current model must have been built with a configuration matching
        the checkpoint, since only the state dict is restored.
        """
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        # Take the prototypes from the weights just loaded; the copy stored in
        # the checkpoint may be None or stale (see _save_checkpoint).
        self.perturbation_prototypes = self.model.get_perturbation_prototypes().cpu().numpy()
        self.is_trained = True
        self.training_history = checkpoint.get('training_history')
        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
        print(f"✅ Model loaded! Best val loss: {self.best_val_loss:.4f}")
628
+
629
+ '''# Example usage with single hidden layer
630
+ def example_single_layer_usage():
631
+ """Example demonstration with single hidden layer"""
632
+
633
+ # Initialize decoder with single hidden layer
634
+ decoder = PerturbationAwareDecoder(
635
+ latent_dim=100,
636
+ num_known_perturbations=10,
637
+ gene_dim=2000, # Reduced for example
638
+ hidden_dims=[512], # Single hidden layer
639
+ perturbation_embedding_dim=128
640
+ )
641
+
642
+ # Generate example data
643
+ n_samples = 1000
644
+ n_perturbations = 10
645
+
646
+ # Latent variables
647
+ latent_data = np.random.randn(n_samples, 100).astype(np.float32)
648
+
649
+ # One-hot encoded perturbations
650
+ perturbations = np.zeros((n_samples, n_perturbations))
651
+ for i in range(n_samples):
652
+ if i % 10 != 0: # 90% perturbed, 10% control
653
+ perturb_id = np.random.randint(0, n_perturbations)
654
+ perturbations[i, perturb_id] = 1.0
655
+
656
+ # Expression data
657
+ base_weights = np.random.randn(100, 2000) * 0.1
658
+ perturbation_effects = np.random.randn(n_perturbations, 2000) * 0.5
659
+
660
+ expression_data = np.tanh(latent_data.dot(base_weights))
661
+ for i in range(n_samples):
662
+ if perturbations[i].sum() > 0:
663
+ perturb_id = np.argmax(perturbations[i])
664
+ expression_data[i] += perturbation_effects[perturb_id]
665
+
666
+ expression_data = np.maximum(expression_data, 0)
667
+
668
+ print(f"📊 Data shapes: Latent {latent_data.shape}, Perturbations {perturbations.shape}, Expression {expression_data.shape}")
669
+ print(f"🔧 Hidden layers: {len(decoder.hidden_dims)}")
670
+ print(f"🔧 Hidden dimensions: {decoder.hidden_dims}")
671
+
672
+ # Train with single hidden layer
673
+ history = decoder.train(
674
+ train_latent=latent_data,
675
+ train_perturbations=perturbations,
676
+ train_expression=expression_data,
677
+ batch_size=32,
678
+ num_epochs=50
679
+ )
680
+
681
+ # Test predictions
682
+ test_latent = np.random.randn(10, 100).astype(np.float32)
683
+ test_perturbations = np.zeros((10, n_perturbations))
684
+ for i in range(10):
685
+ test_perturbations[i, i % n_perturbations] = 1.0
686
+
687
+ predictions = decoder.predict(test_latent, test_perturbations)
688
+ print(f"🔮 Known perturbation prediction shape: {predictions.shape}")
689
+
690
+ # Test novel perturbation prediction
691
+ test_latent_novel = np.random.randn(5, 100).astype(np.float32)
692
+ similarity_matrix = np.random.rand(5, n_perturbations)
693
+ similarity_matrix = similarity_matrix / similarity_matrix.sum(axis=1, keepdims=True)
694
+
695
+ novel_predictions = decoder.predict_novel_perturbation(test_latent_novel, similarity_matrix)
696
+ print(f"🔮 Novel perturbation prediction shape: {novel_predictions.shape}")
697
+
698
+ return decoder
699
+
700
+ # Example usage with multiple hidden layers
701
+ def example_multi_layer_usage():
702
+ """Example demonstration with multiple hidden layers"""
703
+
704
+ # Initialize decoder with multiple hidden layers
705
+ decoder = PerturbationAwareDecoder(
706
+ latent_dim=100,
707
+ num_known_perturbations=10,
708
+ gene_dim=2000,
709
+ hidden_dims=[256, 512, 1024], # Multiple hidden layers
710
+ perturbation_embedding_dim=128
711
+ )
712
+
713
+ print(f"🔧 Hidden layers: {len(decoder.hidden_dims)}")
714
+ print(f"🔧 Hidden dimensions: {decoder.hidden_dims}")
715
+
716
+ return decoder
717
+
718
+ if __name__ == "__main__":
719
+ print("=== Single Hidden Layer Example ===")
720
+ decoder_single = example_single_layer_usage()
721
+
722
+ print("\n=== Multiple Hidden Layers Example ===")
723
+ decoder_multi = example_multi_layer_usage()'''