SURE-tools 2.4.7__py3-none-any.whl → 2.4.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,737 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.optim as optim
5
+ from torch.utils.data import Dataset, DataLoader
6
+ import numpy as np
7
+ from typing import Dict, List, Optional, Tuple, Union
8
+ import math
9
+ import warnings
10
# Intentionally no process-wide ``warnings.filterwarnings('ignore')`` here:
# a blanket filter installed at import time silences every warning in the
# host application, including this module's own "Model not trained"
# warnings. Callers wanting quieter output should scope suppression with
# ``warnings.catch_warnings()`` around their own calls.
11
+
12
class PerturbationAwareDecoder:
    """
    Perturbation-aware transcriptome decoder.

    Owns a ``FixedMultimodalDecoder`` (a ``torch.nn.Module``) together with
    the training, inference and checkpoint helpers around it. There are two
    inference paths:

    * ``predict``: known perturbations given as one-hot rows. An all-zero row
      is accepted and encoded as "no perturbation" (a zero embedding).
    * ``predict_novel_perturbation``: unseen perturbations described by a row
      of similarities to the known perturbations.

    Any number (>= 1) of hidden layers is supported.
    """

    def __init__(self,
                 latent_dim: int = 100,
                 num_known_perturbations: int = 50,
                 gene_dim: int = 60000,
                 hidden_dims: Optional[List[int]] = None,
                 perturbation_embedding_dim: int = 128,
                 biological_prior_dim: int = 256,
                 dropout_rate: float = 0.1,
                 device: str = None):
        """
        Build the decoder and move it to ``device``.

        Args:
            latent_dim: Width of the latent input vectors.
            num_known_perturbations: Number of columns of the one-hot and
                similarity perturbation matrices.
            gene_dim: Number of output genes.
            hidden_dims: Hidden layer widths of the response network. The
                first entry is also the width of the fusion space. Defaults
                to ``[512]``. The list is copied, so later mutation by the
                caller does not affect this instance.
            perturbation_embedding_dim: Width of the learned perturbation
                embeddings.
            biological_prior_dim: Stored and forwarded to the model, but not
                used by any layer at present.
            dropout_rate: Dropout probability used by every dropout layer.
            device: Torch device string. Defaults to ``'cuda'`` when
                available, otherwise ``'cpu'``.
        """
        # Fresh list per instance: a ``[512]`` literal default would be a
        # single list object shared by every call.
        hidden_dims = [512] if hidden_dims is None else list(hidden_dims)

        self.latent_dim = latent_dim
        self.num_known_perturbations = num_known_perturbations
        self.gene_dim = gene_dim
        self.hidden_dims = hidden_dims
        self.perturbation_embedding_dim = perturbation_embedding_dim
        self.biological_prior_dim = biological_prior_dim
        self.dropout_rate = dropout_rate
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

        # Validate hidden_dims
        self._validate_hidden_dims()

        # Initialize multi-modal model
        self.model = self._build_fixed_model()
        self.model.to(self.device)

        # Training state
        self.is_trained = False
        self.training_history = None
        self.best_val_loss = float('inf')
        self.known_perturbation_names = []
        self.perturbation_prototypes = None

        print("🧬 PerturbationAwareDecoder Initialized:")
        print(f" - Latent Dimension: {latent_dim}")
        print(f" - Known Perturbations: {num_known_perturbations}")
        print(f" - Gene Dimension: {gene_dim}")
        print(f" - Hidden Dimensions: {hidden_dims}")
        print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")

    def _validate_hidden_dims(self):
        """Check that ``hidden_dims`` is non-empty and all entries are positive."""
        assert len(self.hidden_dims) >= 1, "hidden_dims must have at least one element"
        assert all(dim > 0 for dim in self.hidden_dims), "All hidden dimensions must be positive"

        if len(self.hidden_dims) == 1:
            print("🔧 Single hidden layer configuration detected")
        else:
            print(f"🔧 Multi-layer configuration: {len(self.hidden_dims)} hidden layers")

    class FixedPerturbationEncoder(nn.Module):
        """Map one-hot perturbation rows to a ``hidden_dim`` representation."""

        def __init__(self, num_perturbations: int, embedding_dim: int, hidden_dim: int,
                     dropout_rate: float = 0.1):
            super().__init__()
            self.num_perturbations = num_perturbations

            # One learned vector per known perturbation type.
            self.perturbation_embedding = nn.Embedding(num_perturbations, embedding_dim)

            # Projection to hidden space
            self.projection = nn.Sequential(
                nn.Linear(embedding_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate)
            )

        def forward(self, one_hot_perturbations):
            """
            Encode a batch of perturbation rows.

            Args:
                one_hot_perturbations: Tensor of shape
                    ``(batch, num_perturbations)``. Each row is one-hot, or
                    all zeros for "no perturbation".

            Returns:
                Tensor of shape ``(batch, hidden_dim)``.
            """
            # For a one-hot row, multiplying by the embedding table selects
            # exactly the same vector as an index lookup. Unlike ``argmax``,
            # an all-zero (control) row yields a zero embedding instead of
            # silently becoming perturbation 0.
            weight = self.perturbation_embedding.weight
            perturbation_embeds = torch.matmul(one_hot_perturbations.to(weight.dtype), weight)
            return self.projection(perturbation_embeds)

    class FixedCrossModalFusion(nn.Module):
        """Gated fusion of a latent vector and an encoded perturbation."""

        def __init__(self, latent_dim: int, perturbation_dim: int, fusion_dim: int,
                     dropout_rate: float = 0.1):
            super().__init__()
            self.latent_projection = nn.Linear(latent_dim, fusion_dim)
            self.perturbation_projection = nn.Linear(perturbation_dim, fusion_dim)

            # Per-feature gate in (0, 1) computed from both projections.
            self.fusion_gate = nn.Sequential(
                nn.Linear(fusion_dim * 2, fusion_dim),
                nn.Sigmoid()
            )

            self.norm = nn.LayerNorm(fusion_dim)
            self.dropout = nn.Dropout(dropout_rate)

        def forward(self, latent, perturbation_encoded):
            """Return ``LayerNorm(g * latent_proj + (1 - g) * perturbation_proj)`` with dropout."""
            latent_proj = self.latent_projection(latent)
            perturbation_proj = self.perturbation_projection(perturbation_encoded)

            concatenated = torch.cat([latent_proj, perturbation_proj], dim=-1)
            fusion_gate = self.fusion_gate(concatenated)

            # Convex combination of the two modalities, per feature.
            fused = fusion_gate * latent_proj + (1 - fusion_gate) * perturbation_proj
            fused = self.norm(fused)
            fused = self.dropout(fused)

            return fused

    class FixedPerturbationResponseNetwork(nn.Module):
        """MLP from the fused representation to non-negative gene expression."""

        def __init__(self, fusion_dim: int, gene_dim: int, hidden_dims: List[int],
                     dropout_rate: float = 0.1):
            super().__init__()

            # Linear -> BatchNorm -> ReLU -> Dropout for each hidden width.
            layers = []
            input_dim = fusion_dim
            for hidden_dim in hidden_dims:
                layers.extend([
                    nn.Linear(input_dim, hidden_dim),
                    nn.BatchNorm1d(hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(dropout_rate)
                ])
                input_dim = hidden_dim

            self.base_network = nn.Sequential(*layers)

            # input_dim now equals the width of the last hidden layer.
            self.final_projection = nn.Linear(input_dim, gene_dim)

            # Scalar scale/bias per sample, computed from the fused input.
            self.scale = nn.Linear(fusion_dim, 1)
            self.bias = nn.Linear(fusion_dim, 1)

        def forward(self, fused_representation):
            """Return ``softplus(expression * scale + bias)`` (always > 0)."""
            base_output = self.base_network(fused_representation)
            expression = self.final_projection(base_output)

            # scale lies in (0, 2) because of the sigmoid * 2.
            scale = torch.sigmoid(self.scale(fused_representation)) * 2
            bias = self.bias(fused_representation)

            return F.softplus(expression * scale + bias)

    class FixedNovelPerturbationPredictor(nn.Module):
        """Predict expression for unseen perturbations from similarity rows."""

        def __init__(self, num_known_perturbations: int, gene_dim: int, hidden_dim: int):
            super().__init__()
            self.num_known_perturbations = num_known_perturbations
            self.gene_dim = gene_dim

            # Learnable per-perturbation response prototypes.
            self.perturbation_prototypes = nn.Parameter(
                torch.randn(num_known_perturbations, gene_dim) * 0.1
            )

            # Response generator; a non-positive hidden_dim means no hidden layer.
            if hidden_dim > 0:
                self.response_generator = nn.Sequential(
                    nn.Linear(num_known_perturbations, hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, gene_dim)
                )
            else:
                self.response_generator = nn.Linear(num_known_perturbations, gene_dim)

            # Attention weights over the known perturbations.
            self.similarity_attention = nn.Sequential(
                nn.Linear(num_known_perturbations, num_known_perturbations),
                nn.Softmax(dim=-1)
            )

        def forward(self, similarity_matrix, latent_features=None):
            """
            Blend prototype attention with a direct generator.

            ``latent_features`` is accepted for interface compatibility but is
            not used.
            """
            # Path 1: attention-weighted combination of the prototypes.
            attention_weights = self.similarity_attention(similarity_matrix)
            weighted_response = torch.matmul(attention_weights, self.perturbation_prototypes)

            # Path 2: direct generation from the similarity row.
            generated_response = self.response_generator(similarity_matrix)

            # Per-sample mixing weight from the mean similarity.
            combination_weights = torch.sigmoid(similarity_matrix.mean(dim=1, keepdim=True))
            return (combination_weights * weighted_response +
                    (1 - combination_weights) * generated_response)

    class FixedMultimodalDecoder(nn.Module):
        """
        Top-level ``nn.Module`` combining the sub-networks above.

        ``biological_prior_dim`` is accepted but currently unused.
        """

        def __init__(self, latent_dim: int, num_known_perturbations: int, gene_dim: int,
                     hidden_dims: List[int], perturbation_embedding_dim: int,
                     biological_prior_dim: int, dropout_rate: float):
            super().__init__()

            self.num_known_perturbations = num_known_perturbations
            self.latent_dim = latent_dim
            self.gene_dim = gene_dim

            # The first hidden width doubles as the fusion width.
            main_hidden_dim = hidden_dims[0]

            self.perturbation_encoder = PerturbationAwareDecoder.FixedPerturbationEncoder(
                num_known_perturbations, perturbation_embedding_dim, main_hidden_dim,
                dropout_rate
            )

            self.cross_modal_fusion = PerturbationAwareDecoder.FixedCrossModalFusion(
                latent_dim, main_hidden_dim, main_hidden_dim, dropout_rate
            )

            # The response network consumes every hidden width.
            self.response_network = PerturbationAwareDecoder.FixedPerturbationResponseNetwork(
                main_hidden_dim, gene_dim, hidden_dims, dropout_rate
            )

            self.novel_predictor = PerturbationAwareDecoder.FixedNovelPerturbationPredictor(
                num_known_perturbations, gene_dim, main_hidden_dim
            )

        def forward(self, latent, perturbation_matrix, mode='one_hot'):
            """
            Decode expression.

            ``mode='one_hot'`` treats ``perturbation_matrix`` as one-hot rows.
            ``mode='similarity'`` treats it as similarity rows and ignores
            ``latent``. Any other mode raises ``ValueError``.
            """
            if mode == 'one_hot':
                perturbation_encoded = self.perturbation_encoder(perturbation_matrix)
                fused = self.cross_modal_fusion(latent, perturbation_encoded)
                expression = self.response_network(fused)
            elif mode == 'similarity':
                expression = self.novel_predictor(perturbation_matrix, latent)
            else:
                raise ValueError(f"Unknown mode: {mode}. Use 'one_hot' or 'similarity'")

            return expression

        def get_perturbation_prototypes(self):
            """Return the learned prototypes, detached (same device as the model)."""
            return self.novel_predictor.perturbation_prototypes.detach()

    def _build_fixed_model(self):
        """Instantiate ``FixedMultimodalDecoder`` from this instance's config."""
        return self.FixedMultimodalDecoder(
            self.latent_dim, self.num_known_perturbations, self.gene_dim,
            self.hidden_dims, self.perturbation_embedding_dim,
            self.biological_prior_dim, self.dropout_rate
        )

    def train(self,
              train_latent: np.ndarray,
              train_perturbations: np.ndarray,
              train_expression: np.ndarray,
              val_latent: np.ndarray = None,
              val_perturbations: np.ndarray = None,
              val_expression: np.ndarray = None,
              batch_size: int = 32,
              num_epochs: int = 200,
              learning_rate: float = 1e-4,
              checkpoint_path: str = 'fixed_decoder.pth') -> Dict:
        """
        Train on known (one-hot) perturbations with early stopping.

        If any of the three ``val_*`` arrays is missing, 10% of the training
        data is split off at random for validation. Whenever the validation
        loss improves, a checkpoint is written to ``checkpoint_path``. When
        training ends, the weights from that best epoch are restored in
        memory, so the model matches the checkpoint and ``best_val_loss``.

        Returns:
            Dict of per-epoch metric lists (``train_loss``, ``val_loss``,
            ``train_mse``, ``val_mse``, ``train_correlation``,
            ``val_correlation``, ``learning_rates``).
        """
        print("🧬 Starting Training with Fixed Single Layer Support...")

        self._validate_one_hot_perturbations(train_perturbations)

        train_dataset = self._create_dataset(train_latent, train_perturbations, train_expression)

        if val_latent is not None and val_perturbations is not None and val_expression is not None:
            val_dataset = self._create_dataset(val_latent, val_perturbations, val_expression)
            train_part, val_part = train_dataset, val_dataset
        else:
            train_size = int(0.9 * len(train_dataset))
            val_size = len(train_dataset) - train_size
            train_part, val_part = torch.utils.data.random_split(train_dataset, [train_size, val_size])

        # BatchNorm1d cannot normalise a training batch of one sample. Drop the
        # trailing batch only when it would contain exactly one sample.
        drop_last = len(train_part) > batch_size and len(train_part) % batch_size == 1
        train_loader = DataLoader(train_part, batch_size=batch_size, shuffle=True, drop_last=drop_last)
        val_loader = DataLoader(val_part, batch_size=batch_size, shuffle=False)

        print(f"📊 Training samples: {len(train_loader.dataset)}")
        print(f"📊 Validation samples: {len(val_loader.dataset)}")
        print(f"🔧 Hidden layers: {len(self.hidden_dims)}")
        print(f"🔧 Hidden dimensions: {self.hidden_dims}")

        optimizer = optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=1e-5
        )
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

        def loss_fn(pred, target):
            # MSE + Poisson-style NLL (pred > 0 via softplus) + (1 - Pearson r).
            mse_loss = F.mse_loss(pred, target)
            poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()
            correlation = self._pearson_correlation(pred, target)
            correlation_loss = 1 - correlation
            return mse_loss + 0.3 * poisson_loss + 0.1 * correlation_loss

        history = {
            'train_loss': [], 'val_loss': [],
            'train_mse': [], 'val_mse': [],
            'train_correlation': [], 'val_correlation': [],
            'learning_rates': []
        }

        best_val_loss = float('inf')
        best_state = None
        patience = 20
        patience_counter = 0

        print("\n🔬 Starting training...")
        for epoch in range(1, num_epochs + 1):
            train_metrics = self._train_epoch(train_loader, optimizer, loss_fn)
            val_metrics = self._validate_epoch(val_loader, loss_fn)

            scheduler.step()
            current_lr = optimizer.param_groups[0]['lr']

            history['train_loss'].append(train_metrics['loss'])
            history['val_loss'].append(val_metrics['loss'])
            history['train_mse'].append(train_metrics['mse'])
            history['val_mse'].append(val_metrics['mse'])
            history['train_correlation'].append(train_metrics['correlation'])
            history['val_correlation'].append(val_metrics['correlation'])
            history['learning_rates'].append(current_lr)

            if epoch % 10 == 0 or epoch == 1:
                print(f"🧪 Epoch {epoch:3d}/{num_epochs} | "
                      f"Train: {train_metrics['loss']:.4f} | "
                      f"Val: {val_metrics['loss']:.4f} | "
                      f"Corr: {val_metrics['correlation']:.4f} | "
                      f"LR: {current_lr:.2e}")

            # Early stopping on validation loss.
            if val_metrics['loss'] < best_val_loss:
                best_val_loss = val_metrics['loss']
                patience_counter = 0
                self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
                # CPU snapshot so the best weights can be restored at the end.
                best_state = {k: v.detach().to('cpu', copy=True)
                              for k, v in self.model.state_dict().items()}
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"🛑 Early stopping at epoch {epoch}")
                    break

        # Restore the best epoch (the one saved to checkpoint_path), if any.
        if best_state is not None:
            self.model.load_state_dict(best_state)

        self.is_trained = True
        self.training_history = history
        self.best_val_loss = best_val_loss
        self.perturbation_prototypes = self.model.get_perturbation_prototypes().cpu().numpy()

        print(f"\n🎉 Training completed! Best val loss: {best_val_loss:.4f}")
        return history

    def _validate_one_hot_perturbations(self, perturbations):
        """Assert the column count matches and every row sums to 0 or 1."""
        assert perturbations.shape[1] == self.num_known_perturbations, \
            f"Perturbation dimension {perturbations.shape[1]} doesn't match expected {self.num_known_perturbations}"

        row_sums = perturbations.sum(axis=1)
        valid_rows = np.all((row_sums == 0) | (row_sums == 1))
        assert valid_rows, "Perturbations should be one-hot encoded (sum to 0 or 1 per row)"

        print("✅ One-hot perturbations validated")

    def _create_dataset(self, latent_data, perturbations, expression_data):
        """Wrap the three arrays as a float32 ``Dataset`` of (latent, perturbation, expression)."""
        class OneHotDataset(Dataset):
            def __init__(self, latent, perturbations, expression):
                self.latent = torch.FloatTensor(latent)
                self.perturbations = torch.FloatTensor(perturbations)
                self.expression = torch.FloatTensor(expression)

            def __len__(self):
                return len(self.latent)

            def __getitem__(self, idx):
                return self.latent[idx], self.perturbations[idx], self.expression[idx]

        return OneHotDataset(latent_data, perturbations, expression_data)

    def predict(self,
                latent_data: np.ndarray,
                perturbations: np.ndarray,
                batch_size: int = 32) -> np.ndarray:
        """
        Predict expression for known (one-hot) perturbations.

        Returns:
            Array of shape ``(n_samples, gene_dim)``. For empty input, an
            empty ``(0, gene_dim)`` float32 array is returned.
        """
        if not self.is_trained:
            warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")

        self._validate_one_hot_perturbations(perturbations)

        self.model.eval()

        if isinstance(latent_data, np.ndarray):
            latent_data = torch.FloatTensor(latent_data)
        if isinstance(perturbations, np.ndarray):
            perturbations = torch.FloatTensor(perturbations)

        if len(latent_data) == 0:
            # torch.cat([]) would raise; an empty batch has an empty answer.
            return np.empty((0, self.gene_dim), dtype=np.float32)

        predictions = []
        with torch.no_grad():
            for i in range(0, len(latent_data), batch_size):
                batch_latent = latent_data[i:i+batch_size].to(self.device)
                batch_perturbations = perturbations[i:i+batch_size].to(self.device)

                batch_pred = self.model(batch_latent, batch_perturbations, mode='one_hot')
                predictions.append(batch_pred.cpu())

        return torch.cat(predictions).numpy()

    def predict_novel_perturbation(self,
                                   latent_data: np.ndarray,
                                   similarity_matrix: np.ndarray,
                                   batch_size: int = 32) -> np.ndarray:
        """
        Predict responses to novel perturbations from similarity rows.

        ``similarity_matrix`` must have ``num_known_perturbations`` columns.
        Returns an array of shape ``(n_samples, gene_dim)``; for empty input,
        an empty ``(0, gene_dim)`` float32 array.
        """
        if not self.is_trained:
            warnings.warn("⚠️ Model not trained. Novel perturbation prediction may be inaccurate.")

        assert similarity_matrix.shape[1] == self.num_known_perturbations, \
            f"Similarity matrix columns {similarity_matrix.shape[1]} must match known perturbations {self.num_known_perturbations}"

        self.model.eval()

        if isinstance(latent_data, np.ndarray):
            latent_data = torch.FloatTensor(latent_data)
        if isinstance(similarity_matrix, np.ndarray):
            similarity_matrix = torch.FloatTensor(similarity_matrix)

        if len(latent_data) == 0:
            return np.empty((0, self.gene_dim), dtype=np.float32)

        predictions = []
        with torch.no_grad():
            for i in range(0, len(latent_data), batch_size):
                batch_latent = latent_data[i:i+batch_size].to(self.device)
                batch_similarity = similarity_matrix[i:i+batch_size].to(self.device)

                batch_pred = self.model(batch_latent, batch_similarity, mode='similarity')
                predictions.append(batch_pred.cpu())

        return torch.cat(predictions).numpy()

    def get_known_perturbation_prototypes(self) -> np.ndarray:
        """Return (and cache) the learned prototypes as a numpy array."""
        if not self.is_trained:
            warnings.warn("⚠️ Model not trained. Prototypes may be uninformative.")

        if self.perturbation_prototypes is None:
            self.model.eval()
            with torch.no_grad():
                self.perturbation_prototypes = self.model.get_perturbation_prototypes().cpu().numpy()

        return self.perturbation_prototypes

    def _pearson_correlation(self, pred, target):
        """Mean over rows of the per-row Pearson correlation (1e-8 guards zero variance)."""
        pred_centered = pred - pred.mean(dim=1, keepdim=True)
        target_centered = target - target.mean(dim=1, keepdim=True)

        numerator = (pred_centered * target_centered).sum(dim=1)
        denominator = torch.sqrt(torch.sum(pred_centered ** 2, dim=1)) * torch.sqrt(torch.sum(target_centered ** 2, dim=1))

        return (numerator / (denominator + 1e-8)).mean()

    def _train_epoch(self, train_loader, optimizer, loss_fn):
        """Run one optimisation epoch; return mean loss, MSE and correlation per batch."""
        self.model.train()
        total_loss = 0
        total_mse = 0
        total_correlation = 0

        for latent, perturbations, target in train_loader:
            latent = latent.to(self.device)
            perturbations = perturbations.to(self.device)
            target = target.to(self.device)

            optimizer.zero_grad()
            pred = self.model(latent, perturbations, mode='one_hot')

            loss = loss_fn(pred, target)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            optimizer.step()

            # Reporting metrics need no gradients; detach to skip graph building.
            pred_detached = pred.detach()
            mse_loss = F.mse_loss(pred_detached, target).item()
            correlation = self._pearson_correlation(pred_detached, target).item()

            total_loss += loss.item()
            total_mse += mse_loss
            total_correlation += correlation

        num_batches = len(train_loader)
        return {
            'loss': total_loss / num_batches,
            'mse': total_mse / num_batches,
            'correlation': total_correlation / num_batches
        }

    def _validate_epoch(self, val_loader, loss_fn):
        """Evaluate without gradients; return mean loss, MSE and correlation per batch."""
        self.model.eval()
        total_loss = 0
        total_mse = 0
        total_correlation = 0

        with torch.no_grad():
            for latent, perturbations, target in val_loader:
                latent = latent.to(self.device)
                perturbations = perturbations.to(self.device)
                target = target.to(self.device)

                pred = self.model(latent, perturbations, mode='one_hot')
                loss = loss_fn(pred, target)
                mse_loss = F.mse_loss(pred, target).item()
                correlation = self._pearson_correlation(pred, target).item()

                total_loss += loss.item()
                total_mse += mse_loss
                total_correlation += correlation

        num_batches = len(val_loader)
        return {
            'loss': total_loss / num_batches,
            'mse': total_mse / num_batches,
            'correlation': total_correlation / num_batches
        }

    def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
        """
        Write a training checkpoint to ``path`` with ``torch.save``.

        Prototypes are taken from the current model as a CPU tensor, not from
        ``self.perturbation_prototypes``. That attribute is ``None`` during a
        first training run and stale during a re-training run.
        """
        prototypes = self.model.get_perturbation_prototypes().cpu()
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_loss': best_loss,
            'training_history': history,
            'perturbation_prototypes': prototypes,
            'model_config': {
                'latent_dim': self.latent_dim,
                'num_known_perturbations': self.num_known_perturbations,
                'gene_dim': self.gene_dim,
                'hidden_dims': self.hidden_dims
            }
        }, path)

    def load_model(self, model_path: str):
        """
        Load weights and training metadata from a checkpoint file.

        The model architecture must match the one saved; otherwise
        ``load_state_dict`` raises.
        """
        # SECURITY NOTE: torch.load may unpickle arbitrary objects depending on
        # the torch version's ``weights_only`` default; only load trusted files.
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # Newer checkpoints store a tensor, older ones a numpy array or None.
        prototypes = checkpoint.get('perturbation_prototypes')
        if isinstance(prototypes, torch.Tensor):
            prototypes = prototypes.detach().cpu().numpy()
        self.perturbation_prototypes = prototypes

        self.is_trained = True
        self.training_history = checkpoint.get('training_history')
        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
        print(f"✅ Model loaded! Best val loss: {self.best_val_loss:.4f}")
598
+
599
# NOTE(review): everything below is one module-level string literal. The test
# helpers and the ``if __name__ == "__main__":`` demo inside it are therefore
# never defined or executed; they are kept only as reference code. To run them,
# move them into a real test/example module.
+ '''# Test the fixed implementation
600
+ def test_single_layer_fix():
601
+ """Test the fixed single layer implementation"""
602
+
603
+ print("🧪 Testing single layer configuration...")
604
+
605
+ # Test with single hidden layer
606
+ decoder_single = PerturbationAwareDecoder(
607
+ latent_dim=100,
608
+ num_known_perturbations=10,
609
+ gene_dim=2000,
610
+ hidden_dims=[512], # Single element list
611
+ perturbation_embedding_dim=128
612
+ )
613
+
614
+ # Generate test data
615
+ n_samples = 100
616
+ latent_data = np.random.randn(n_samples, 100).astype(np.float32)
617
+ perturbations = np.zeros((n_samples, 10))
618
+ for i in range(n_samples):
619
+ if i % 10 != 0:
620
+ perturbations[i, np.random.randint(0, 10)] = 1.0
621
+
622
+ expression_data = np.random.randn(n_samples, 2000).astype(np.float32)
623
+ expression_data = np.maximum(expression_data, 0)
624
+
625
+ # Test forward pass
626
+ decoder_single.model.eval()
627
+ with torch.no_grad():
628
+ latent_tensor = torch.FloatTensor(latent_data[:5]).to(decoder_single.device)
629
+ perturbations_tensor = torch.FloatTensor(perturbations[:5]).to(decoder_single.device)
630
+
631
+ # Test known perturbation prediction
632
+ output = decoder_single.model(latent_tensor, perturbations_tensor, mode='one_hot')
633
+ print(f"✅ Known perturbation prediction shape: {output.shape}")
634
+
635
+ # Test novel perturbation prediction
636
+ similarity_matrix = np.random.rand(5, 10).astype(np.float32)
637
+ similarity_tensor = torch.FloatTensor(similarity_matrix).to(decoder_single.device)
638
+ novel_output = decoder_single.model(latent_tensor, similarity_tensor, mode='similarity')
639
+ print(f"✅ Novel perturbation prediction shape: {novel_output.shape}")
640
+
641
+ print("🎉 Single layer test passed!")
642
+
643
+ def test_multi_layer_fix():
644
+ """Test the multi-layer implementation"""
645
+
646
+ print("\n🧪 Testing multi-layer configuration...")
647
+
648
+ # Test with multiple hidden layers
649
+ decoder_multi = PerturbationAwareDecoder(
650
+ latent_dim=100,
651
+ num_known_perturbations=10,
652
+ gene_dim=2000,
653
+ hidden_dims=[256, 512, 1024], # Multiple layers
654
+ perturbation_embedding_dim=128
655
+ )
656
+
657
+ print("🎉 Multi-layer test passed!")
658
+
659
+ def test_edge_cases():
660
+ """Test edge cases"""
661
+
662
+ print("\n🧪 Testing edge cases...")
663
+
664
+ # Test with different hidden_dims configurations
665
+ configs = [
666
+ [512], # Single layer
667
+ [256, 512], # Two layers
668
+ [128, 256, 512], # Three layers
669
+ [1024], # Wide single layer
670
+ [64, 128, 256, 512, 1024] # Deep network
671
+ ]
672
+
673
+ for i, hidden_dims in enumerate(configs):
674
+ try:
675
+ decoder = PerturbationAwareDecoder(
676
+ latent_dim=50,
677
+ num_known_perturbations=5,
678
+ gene_dim=1000,
679
+ hidden_dims=hidden_dims,
680
+ perturbation_embedding_dim=64
681
+ )
682
+ print(f"✅ Config {i+1}: {hidden_dims} - Success")
683
+ except Exception as e:
684
+ print(f"❌ Config {i+1}: {hidden_dims} - Failed: {e}")
685
+
686
+ print("🎉 Edge case testing completed!")
687
+
688
+ if __name__ == "__main__":
689
+ # Run tests
690
+ test_single_layer_fix()
691
+ test_multi_layer_fix()
692
+ test_edge_cases()
693
+
694
+ # Example usage
695
+ print("\n🎯 Example Usage:")
696
+
697
+ # Single hidden layer example
698
+ decoder = PerturbationAwareDecoder(
699
+ latent_dim=100,
700
+ num_known_perturbations=10,
701
+ gene_dim=2000,
702
+ hidden_dims=[512], # Single hidden layer
703
+ perturbation_embedding_dim=128
704
+ )
705
+
706
+ # Generate example data
707
+ n_samples = 1000
708
+ latent_data = np.random.randn(n_samples, 100).astype(np.float32)
709
+ perturbations = np.zeros((n_samples, 10))
710
+ for i in range(n_samples):
711
+ if i % 10 != 0:
712
+ perturbations[i, np.random.randint(0, 10)] = 1.0
713
+
714
+ # Simulate expression data
715
+ base_weights = np.random.randn(100, 2000) * 0.1
716
+ perturbation_effects = np.random.randn(10, 2000) * 0.5
717
+
718
+ expression_data = np.tanh(latent_data.dot(base_weights))
719
+ for i in range(n_samples):
720
+ if perturbations[i].sum() > 0:
721
+ perturb_id = np.argmax(perturbations[i])
722
+ expression_data[i] += perturbation_effects[perturb_id]
723
+
724
+ expression_data = np.maximum(expression_data, 0)
725
+
726
+ print(f"📊 Example data shapes: Latent {latent_data.shape}, Perturbations {perturbations.shape}")
727
+
728
+ # Train (commented out for quick testing)
729
+ # history = decoder.train(
730
+ # train_latent=latent_data,
731
+ # train_perturbations=perturbations,
732
+ # train_expression=expression_data,
733
+ # batch_size=32,
734
+ # num_epochs=10 # Short training for testing
735
+ # )
736
+
737
+ print("🎉 All tests completed successfully!")'''