SURE-tools 2.4.7__py3-none-any.whl → 2.4.42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
SURE/VirtualCellDecoder.py ADDED
@@ -0,0 +1,658 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.optim as optim
5
+ from torch.utils.data import Dataset, DataLoader
6
+ import numpy as np
7
+ from typing import Dict, List, Optional, Tuple
8
+ import math
9
+ import warnings
10
+ warnings.filterwarnings('ignore')
11
+
12
+ class VirtualCellDecoder:
13
+ """
14
+ Advanced transcriptome decoder based on Virtual Cell Challenge research
15
+ Optimized for latent-to-expression mapping with biological constraints
16
+ """
17
+
18
+ def __init__(self,
19
+ latent_dim: int = 100,
20
+ gene_dim: int = 60000,
21
+ hidden_dims: List[int] = [512, 1024, 2048],
22
+ biological_prior_dim: int = 256,
23
+ dropout_rate: float = 0.1,
24
+ device: str = None):
25
+ """
26
+ Decoder design informed by insights from the Virtual Cell Challenge
27
+
28
+ Args:
29
+ latent_dim: Latent variable dimension (typically 50-100)
30
+ gene_dim: Number of genes (full transcriptome ~60,000)
31
+ hidden_dims: Hidden layer dimensions for progressive expansion
32
+ biological_prior_dim: Dimension for biological prior knowledge
33
+ dropout_rate: Dropout rate for regularization
34
+ device: Computation device
35
+ """
36
+ self.latent_dim = latent_dim
37
+ self.gene_dim = gene_dim
38
+ self.hidden_dims = hidden_dims
39
+ self.biological_prior_dim = biological_prior_dim
40
+ self.dropout_rate = dropout_rate
41
+ self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
42
+
43
+ # Initialize model with biological constraints
44
+ self.model = self._build_biological_model()
45
+ self.model.to(self.device)
46
+
47
+ # Training state
48
+ self.is_trained = False
49
+ self.training_history = None
50
+ self.best_val_loss = float('inf')
51
+
52
+ print(f"🧬 VirtualCellDecoder Initialized:")
53
+ print(f" - Latent Dimension: {latent_dim}")
54
+ print(f" - Gene Dimension: {gene_dim}")
55
+ print(f" - Hidden Dimensions: {hidden_dims}")
56
+ print(f" - Biological Prior Dimension: {biological_prior_dim}")
57
+ print(f" - Device: {self.device}")
58
+ print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
59
+
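For orientation, a minimal instantiation sketch (dimensions are illustrative, not package defaults; a reduced gene panel keeps the gene-level attention block tractable):

from SURE import VirtualCellDecoder  # exported via SURE/__init__.py (see the hunk later in this diff)

decoder = VirtualCellDecoder(
    latent_dim=64,
    gene_dim=2000,            # reduced panel; full-transcriptome attention is memory-bound
    hidden_dims=[256, 512, 1024],
    biological_prior_dim=128,
    dropout_rate=0.1,
)
print(decoder.get_model_info()['parameters'])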
60
+ class BiologicalPriorNetwork(nn.Module):
61
+ """Biological prior network based on gene regulatory knowledge"""
62
+
63
+ def __init__(self, latent_dim: int, prior_dim: int, gene_dim: int):
64
+ super().__init__()
65
+ self.gene_dim = gene_dim
66
+ self.prior_dim = prior_dim
67
+
68
+ # Learnable gene regulatory matrix (sparse initialization)
69
+ self.regulatory_matrix = nn.Parameter(
70
+ torch.randn(gene_dim, prior_dim) * 0.01
71
+ )
72
+
73
+ # Latent to regulatory space projection
74
+ self.latent_to_regulatory = nn.Sequential(
75
+ nn.Linear(latent_dim, prior_dim * 2),
76
+ nn.ReLU(),
77
+ nn.Dropout(0.1),
78
+ nn.Linear(prior_dim * 2, prior_dim)
79
+ )
80
+
81
+ # Regulatory to expression projection
82
+ self.regulatory_to_expression = nn.Sequential(
83
+ nn.Linear(prior_dim, prior_dim),
84
+ nn.ReLU(),
85
+ nn.Dropout(0.1),
86
+ nn.Linear(prior_dim, gene_dim)
87
+ )
88
+
89
+ self._init_weights()
90
+
91
+ def _init_weights(self):
92
+ """Initialize with biological constraints"""
93
+ # Sparse initialization for regulatory matrix
94
+ nn.init.sparse_(self.regulatory_matrix, sparsity=0.8)
95
+ for module in self.modules():
96
+ if isinstance(module, nn.Linear):
97
+ nn.init.xavier_uniform_(module.weight)
98
+ if module.bias is not None:
99
+ nn.init.zeros_(module.bias)
100
+
101
+ def forward(self, latent):
102
+ batch_size = latent.shape[0]
103
+
104
+ # Project latent to regulatory space
105
+ regulatory_factors = self.latent_to_regulatory(latent) # [batch, prior_dim]
106
+
107
+ # Apply regulatory matrix (gene-specific modulation)
108
+ regulatory_effect = torch.matmul(
109
+ regulatory_factors, self.regulatory_matrix.T # [batch, gene_dim]
110
+ )
111
+
112
+ # Final expression projection
113
+ expression_base = self.regulatory_to_expression(regulatory_factors)
114
+
115
+ # Combine regulatory effect with base expression
116
+ biological_prior = expression_base + regulatory_effect
117
+
118
+ return biological_prior
119
+
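A quick standalone shape check of the prior network (hedged sketch; dimensions are illustrative, and in the full model the input dimension is hidden_dims[1]):

import torch

prior = VirtualCellDecoder.BiologicalPriorNetwork(latent_dim=512, prior_dim=128, gene_dim=2000)
out = prior(torch.randn(4, 512))
assert out.shape == (4, 2000)  # base expression plus regulatory effect, one value per gene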
120
+ class GeneSpecificAttention(nn.Module):
121
+ """Gene-specific attention mechanism for capturing co-expression patterns"""
122
+
123
+ def __init__(self, gene_dim: int, attention_dim: int = 128, num_heads: int = 8):
124
+ super().__init__()
125
+ self.gene_dim = gene_dim
126
+ self.attention_dim = attention_dim
127
+ self.num_heads = num_heads
128
+
129
+ # Gene embeddings for attention
130
+ self.gene_embeddings = nn.Parameter(torch.randn(gene_dim, attention_dim))
131
+
132
+ # Attention mechanism
133
+ self.query_proj = nn.Linear(attention_dim, attention_dim)
134
+ self.key_proj = nn.Linear(attention_dim, attention_dim)
135
+ self.value_proj = nn.Linear(attention_dim, attention_dim)
136
+
137
+ # Output projection
138
+ self.output_proj = nn.Linear(attention_dim, attention_dim)
139
+
140
+ self._init_weights()
141
+
142
+ def _init_weights(self):
143
+ """Initialize attention weights"""
144
+ nn.init.xavier_uniform_(self.gene_embeddings)
145
+ for module in [self.query_proj, self.key_proj, self.value_proj, self.output_proj]:
146
+ nn.init.xavier_uniform_(module.weight)
147
+ if module.bias is not None:
148
+ nn.init.zeros_(module.bias)
149
+
150
+ def forward(self, x):
151
+ batch_size = x.shape[0]
152
+
153
+ # Prepare gene embeddings
154
+ gene_embeds = self.gene_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
155
+
156
+ # Compute attention
157
+ Q = self.query_proj(gene_embeds)
158
+ K = self.key_proj(gene_embeds)
159
+ V = self.value_proj(gene_embeds * x.unsqueeze(-1)) # condition value vectors on the input expression estimate
160
+
161
+ # Multi-head attention
162
+ head_dim = self.attention_dim // self.num_heads
163
+ Q = Q.view(batch_size, self.gene_dim, self.num_heads, head_dim).transpose(1, 2)
164
+ K = K.view(batch_size, self.gene_dim, self.num_heads, head_dim).transpose(1, 2)
165
+ V = V.view(batch_size, self.gene_dim, self.num_heads, head_dim).transpose(1, 2)
166
+
167
+ # Scaled dot-product attention
168
+ attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(head_dim)
169
+ attn_weights = F.softmax(attn_scores, dim=-1)
170
+
171
+ # Apply attention
172
+ attn_output = torch.matmul(attn_weights, V)
173
+ attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, self.gene_dim, self.attention_dim)
174
+
175
+ # Output projection, collapsed to one correction per gene, with a residual path back to the input
176
+ output = self.output_proj(attn_output).mean(dim=-1) # [batch, gene_dim]
177
+
178
+ return x + output
179
+
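Because attention is taken over every gene pair, the score tensor alone scales as O(gene_dim^2); a rough budget check before picking gene_dim (hedged sketch, float32 storage assumed):

def attention_score_bytes(batch_size: int, num_heads: int, gene_dim: int) -> int:
    # One [batch, heads, gene_dim, gene_dim] float32 tensor of attention scores
    return batch_size * num_heads * gene_dim * gene_dim * 4

print(attention_score_bytes(32, 8, 2_000) / 1e9)   # ~4.1 GB: feasible on a large GPU
print(attention_score_bytes(32, 8, 60_000) / 1e9)  # ~3,700 GB: not feasible on a single device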
180
+ class SparseActivation(nn.Module):
181
+ """Sparse activation function for biological data"""
182
+
183
+ def __init__(self, sparsity_target: float = 0.85):
184
+ super().__init__()
185
+ self.sparsity_target = sparsity_target
186
+ self.alpha = nn.Parameter(torch.tensor(1.0))
187
+ self.beta = nn.Parameter(torch.tensor(0.0))
188
+
189
+ def forward(self, x):
190
+ # Learnable softplus with sparsity constraint
191
+ activated = F.softplus(x * self.alpha + self.beta)
192
+
193
+ # Auxiliary penalty pulling the mean activation toward sparsity_target (a proxy for biological sparsity)
194
+ sparsity_loss = (activated.mean() - self.sparsity_target) ** 2
195
+ self.sparsity_loss = sparsity_loss * 0.01 # Light regularization
196
+
197
+ return activated
198
+
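A small sketch of the activation's behaviour and where its auxiliary penalty is stored (illustrative input sizes):

import torch

act = VirtualCellDecoder.SparseActivation(sparsity_target=0.85)
y = act(torch.randn(8, 2000))        # softplus(alpha * x + beta), strictly positive
print(bool((y > 0).all()))           # True
print(float(act.sparsity_loss))      # penalty nudging the mean activation toward sparsity_target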
199
+ class VirtualCellModel(nn.Module):
200
+ """Main Virtual Cell Challenge inspired model"""
201
+
202
+ def __init__(self, latent_dim: int, gene_dim: int, hidden_dims: List[int],
203
+ biological_prior_dim: int, dropout_rate: float):
204
+ super().__init__()
205
+
206
+ # Phase 1: Latent expansion with biological constraints
207
+ self.latent_expansion = nn.Sequential(
208
+ nn.Linear(latent_dim, hidden_dims[0]),
209
+ nn.BatchNorm1d(hidden_dims[0]),
210
+ nn.ReLU(),
211
+ nn.Dropout(dropout_rate),
212
+ nn.Linear(hidden_dims[0], hidden_dims[1]),
213
+ nn.BatchNorm1d(hidden_dims[1]),
214
+ nn.ReLU(),
215
+ nn.Dropout(dropout_rate)
216
+ )
217
+
218
+ # Phase 2: Biological prior network
219
+ self.biological_prior = VirtualCellDecoder.BiologicalPriorNetwork(
220
+ hidden_dims[1], biological_prior_dim, gene_dim
221
+ )
222
+
223
+ # Phase 3: Gene-specific processing
224
+ self.gene_attention = VirtualCellDecoder.GeneSpecificAttention(gene_dim)
225
+
226
+ # Phase 4: Final expression refinement
227
+ self.expression_refinement = nn.Sequential(
228
+ nn.Linear(gene_dim, hidden_dims[2]),
229
+ nn.BatchNorm1d(hidden_dims[2]),
230
+ nn.ReLU(),
231
+ nn.Dropout(dropout_rate),
232
+ nn.Linear(hidden_dims[2], gene_dim)
233
+ )
234
+
235
+ # Phase 5: Sparse activation
236
+ self.sparse_activation = VirtualCellDecoder.SparseActivation()
237
+
238
+ self._init_weights()
239
+
240
+ def _init_weights(self):
241
+ """Biological-inspired weight initialization"""
242
+ for module in self.modules():
243
+ if isinstance(module, nn.Linear):
244
+ # Xavier initialization for stable training
245
+ nn.init.xavier_uniform_(module.weight)
246
+ if module.bias is not None:
247
+ nn.init.zeros_(module.bias)
248
+ elif isinstance(module, nn.BatchNorm1d):
249
+ nn.init.ones_(module.weight)
250
+ nn.init.zeros_(module.bias)
251
+
252
+ def forward(self, latent):
253
+ # Phase 1: Latent expansion
254
+ expanded_latent = self.latent_expansion(latent)
255
+
256
+ # Phase 2: Biological prior
257
+ biological_output = self.biological_prior(expanded_latent)
258
+
259
+ # Phase 3: Gene attention
260
+ attention_output = self.gene_attention(biological_output)
261
+
262
+ # Phase 4: Refinement with residual connection
263
+ refined_output = self.expression_refinement(attention_output) + biological_output
264
+
265
+ # Phase 5: Sparse activation
266
+ final_output = self.sparse_activation(refined_output)
267
+
268
+ return final_output
269
+
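An end-to-end shape check of the five-phase pipeline (hedged sketch with small illustrative dimensions; eval mode sidesteps BatchNorm's batch-size requirements):

import torch

model = VirtualCellDecoder.VirtualCellModel(
    latent_dim=64, gene_dim=1000, hidden_dims=[256, 512, 1024],
    biological_prior_dim=128, dropout_rate=0.1,
)
model.eval()
with torch.no_grad():
    out = model(torch.randn(2, 64))
assert out.shape == (2, 1000) and bool((out > 0).all())  # softplus keeps expression positive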
270
+ def _build_biological_model(self):
271
+ """Build the biologically constrained model"""
272
+ return self.VirtualCellModel(
273
+ self.latent_dim, self.gene_dim, self.hidden_dims,
274
+ self.biological_prior_dim, self.dropout_rate
275
+ )
276
+
277
+ def train(self,
278
+ train_latent: np.ndarray,
279
+ train_expression: np.ndarray,
280
+ val_latent: np.ndarray = None,
281
+ val_expression: np.ndarray = None,
282
+ batch_size: int = 32,
283
+ num_epochs: int = 200,
284
+ learning_rate: float = 1e-4,
285
+ biological_weight: float = 0.1,
286
+ checkpoint_path: str = 'virtual_cell_decoder.pth') -> Dict:
287
+ """
288
+ Train with biological constraints and Virtual Cell Challenge insights
289
+
290
+ Args:
291
+ train_latent: Training latent variables
292
+ train_expression: Training expression data
293
+ val_latent: Validation latent variables
294
+ val_expression: Validation expression data
295
+ batch_size: Batch size optimized for biological data
296
+ num_epochs: Number of training epochs
297
+ learning_rate: Learning rate
298
+ biological_weight: Weight for biological constraint loss
299
+ checkpoint_path: Model save path
300
+ """
301
+ print("🧬 Starting Virtual Cell Challenge Training...")
302
+ print("📚 Incorporating biological constraints and regulatory priors")
303
+
304
+ # Data preparation
305
+ train_dataset = self._create_dataset(train_latent, train_expression)
306
+
307
+ if val_latent is not None and val_expression is not None:
308
+ val_dataset = self._create_dataset(val_latent, val_expression)
309
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
310
+ val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
311
+ print(f"📈 Using provided validation data: {len(val_dataset)} samples")
312
+ else:
313
+ # Auto split (90/10)
314
+ train_size = int(0.9 * len(train_dataset))
315
+ val_size = len(train_dataset) - train_size
316
+ train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
317
+ train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
318
+ val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
319
+ print(f"📈 Auto-split validation: {val_size} samples")
320
+
321
+ print(f"📊 Training samples: {len(train_loader.dataset)}")
322
+ print(f"📊 Validation samples: {len(val_loader.dataset)}")
323
+ print(f"📊 Batch size: {batch_size}")
324
+
325
+ # Optimizer with biological regularization
326
+ optimizer = optim.AdamW(
327
+ self.model.parameters(),
328
+ lr=learning_rate,
329
+ weight_decay=1e-5, # decoupled weight decay (AdamW)
330
+ betas=(0.9, 0.999)
331
+ )
332
+
333
+ # Cosine annealing with warm restarts
334
+ scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
335
+ optimizer, T_0=50, T_mult=2, eta_min=1e-6
336
+ )
337
+
338
+ # Biological loss function
339
+ def biological_loss(pred, target):
340
+ # 1. Reconstruction loss
341
+ mse_loss = F.mse_loss(pred, target)
342
+
343
+ # 2. Poisson loss for count data
344
+ poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()
345
+
346
+ # 3. Correlation loss for pattern matching
347
+ correlation = self._pearson_correlation(pred, target)
348
+ correlation_loss = 1 - correlation
349
+
350
+ # 4. Sparsity loss (biological constraint)
351
+ sparsity_loss = self.model.sparse_activation.sparsity_loss
352
+
353
+ # 5. Biological consistency loss
354
+ biological_loss = self._biological_consistency_loss(pred)
355
+
356
+ total_loss = (mse_loss + 0.5 * poisson_loss + 0.3 * correlation_loss +
357
+ 0.1 * sparsity_loss + biological_weight * biological_loss)
358
+
359
+ return total_loss, {
360
+ 'mse': mse_loss.item(),
361
+ 'poisson': poisson_loss.item(),
362
+ 'correlation': correlation.item(),
363
+ 'sparsity': sparsity_loss.item(),
364
+ 'biological': biological_loss.item()
365
+ }
366
+
367
+ # Training history
368
+ history = {
369
+ 'train_loss': [], 'val_loss': [],
370
+ 'train_mse': [], 'val_mse': [],
371
+ 'train_correlation': [], 'val_correlation': [],
372
+ 'train_sparsity': [], 'val_sparsity': [],
373
+ 'learning_rates': [], 'grad_norms': []
374
+ }
375
+
376
+ best_val_loss = float('inf')
377
+ patience = 25
378
+ patience_counter = 0
379
+
380
+ print("\n🔬 Starting training with biological constraints...")
381
+ for epoch in range(1, num_epochs + 1):
382
+ # Training phase
383
+ train_loss, train_components, grad_norm = self._train_epoch(
384
+ train_loader, optimizer, biological_loss
385
+ )
386
+
387
+ # Validation phase
388
+ val_loss, val_components = self._validate_epoch(val_loader, biological_loss)
389
+
390
+ # Update scheduler
391
+ scheduler.step()
392
+ current_lr = optimizer.param_groups[0]['lr']
393
+
394
+ # Record history
395
+ history['train_loss'].append(train_loss)
396
+ history['val_loss'].append(val_loss)
397
+ history['train_mse'].append(train_components['mse'])
398
+ history['val_mse'].append(val_components['mse'])
399
+ history['train_correlation'].append(train_components['correlation'])
400
+ history['val_correlation'].append(val_components['correlation'])
401
+ history['train_sparsity'].append(train_components['sparsity'])
402
+ history['val_sparsity'].append(val_components['sparsity'])
403
+ history['learning_rates'].append(current_lr)
404
+ history['grad_norms'].append(grad_norm)
405
+
406
+ # Print detailed progress
407
+ if epoch % 10 == 0 or epoch == 1:
408
+ print(f"🧪 Epoch {epoch:3d}/{num_epochs} | "
409
+ f"Train: {train_loss:.4f} | Val: {val_loss:.4f} | "
410
+ f"Corr: {val_components['correlation']:.4f} | "
411
+ f"Sparsity: {val_components['sparsity']:.4f} | "
412
+ f"LR: {current_lr:.2e}")
413
+
414
+ # Early stopping and model saving
415
+ if val_loss < best_val_loss:
416
+ best_val_loss = val_loss
417
+ patience_counter = 0
418
+ self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
419
+ if epoch % 20 == 0:
420
+ print(f"💾 Best model saved (Val Loss: {best_val_loss:.4f})")
421
+ else:
422
+ patience_counter += 1
423
+ if patience_counter >= patience:
424
+ print(f"🛑 Early stopping at epoch {epoch}")
425
+ break
426
+
427
+ # Training completed
428
+ self.is_trained = True
429
+ self.training_history = history
430
+ self.best_val_loss = best_val_loss
431
+
432
+ print(f"\n🎉 Training completed!")
433
+ print(f"🏆 Best validation loss: {best_val_loss:.4f}")
434
+ print(f"📊 Final correlation: {history['val_correlation'][-1]:.4f}")
435
+ print(f"🌿 Final sparsity: {history['val_sparsity'][-1]:.4f}")
436
+
437
+ return history
438
+
439
+ def _biological_consistency_loss(self, pred):
440
+ """Biological consistency loss based on Virtual Cell Challenge insights"""
441
+ # 1. Gene expression variance consistency
442
+ gene_variance = pred.var(dim=0)
443
+ target_variance = torch.ones_like(gene_variance) * 0.5 # Reasonable biological variance
444
+ variance_loss = F.mse_loss(gene_variance, target_variance)
445
+
446
+ # 2. Co-expression pattern consistency (gene-gene correlation matrix; O(gene_dim^2) memory)
447
+ correlation_matrix = torch.corrcoef(pred.T)
448
+ correlation_loss = torch.mean(torch.abs(correlation_matrix)) # Penalize excessive global co-expression
449
+
450
+ return variance_loss + 0.5 * correlation_loss
451
+
452
+ def _create_dataset(self, latent_data, expression_data):
453
+ """Create dataset with biological data validation"""
454
+ class BiologicalDataset(Dataset):
455
+ def __init__(self, latent, expression):
456
+ # Validate biological data characteristics
457
+ assert np.all(expression >= 0), "Expression data must be non-negative"
458
+ assert np.mean(expression == 0) > 0.7, "Expression data should be sparse (typical scRNA-seq)"
459
+
460
+ self.latent = torch.FloatTensor(latent)
461
+ self.expression = torch.FloatTensor(expression)
462
+
463
+ def __len__(self):
464
+ return len(self.latent)
465
+
466
+ def __getitem__(self, idx):
467
+ return self.latent[idx], self.expression[idx]
468
+
469
+ return BiologicalDataset(latent_data, expression_data)
470
+
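Since the dataset wrapper hard-asserts non-negative, mostly-zero expression, a pre-check on the input arrays fails fast with a clearer message (hedged sketch mirroring those assertions):

import numpy as np

def check_expression_matrix(expression: np.ndarray) -> None:
    zero_fraction = float(np.mean(expression == 0))
    if np.any(expression < 0):
        raise ValueError("expression matrix contains negative values")
    if zero_fraction <= 0.7:
        raise ValueError(f"expression matrix is only {zero_fraction:.0%} zeros; "
                         "the dataset wrapper expects sparse scRNA-seq-like counts")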
471
+ def _pearson_correlation(self, pred, target):
472
+ """Calculate Pearson correlation coefficient"""
473
+ pred_centered = pred - pred.mean(dim=1, keepdim=True)
474
+ target_centered = target - target.mean(dim=1, keepdim=True)
475
+
476
+ numerator = (pred_centered * target_centered).sum(dim=1)
477
+ denominator = torch.sqrt(torch.sum(pred_centered ** 2, dim=1)) * torch.sqrt(torch.sum(target_centered ** 2, dim=1))
478
+
479
+ return (numerator / (denominator + 1e-8)).mean()
480
+
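The correlation term is the mean of per-cell Pearson correlations between predicted and target profiles; a cross-check against NumPy (hedged sketch, `decoder` is any VirtualCellDecoder instance):

import numpy as np
import torch

pred, target = torch.rand(5, 200), torch.rand(5, 200)
torch_r = float(decoder._pearson_correlation(pred, target))
numpy_r = float(np.mean([np.corrcoef(p, t)[0, 1] for p, t in zip(pred.numpy(), target.numpy())]))
assert abs(torch_r - numpy_r) < 1e-4  # equal up to the 1e-8 stabiliser in the denominator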
481
+ def _train_epoch(self, train_loader, optimizer, loss_fn):
482
+ """Train one epoch with biological constraints"""
483
+ self.model.train()
484
+ total_loss = 0
485
+ total_components = {'mse': 0, 'poisson': 0, 'correlation': 0, 'sparsity': 0, 'biological': 0}
486
+ grad_norms = []
487
+
488
+ for latent, target in train_loader:
489
+ latent = latent.to(self.device, non_blocking=True)
490
+ target = target.to(self.device, non_blocking=True)
491
+
492
+ optimizer.zero_grad()
493
+ pred = self.model(latent)
494
+
495
+ loss, components = loss_fn(pred, target)
496
+ loss.backward()
497
+
498
+ # Gradient clipping for stability
499
+ grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
500
+ optimizer.step()
501
+
502
+ total_loss += loss.item()
503
+ for key in components:
504
+ total_components[key] += components[key]
505
+ grad_norms.append(grad_norm.item())
506
+
507
+ num_batches = len(train_loader)
508
+ avg_loss = total_loss / num_batches
509
+ avg_components = {key: value / num_batches for key, value in total_components.items()}
510
+ avg_grad_norm = np.mean(grad_norms)
511
+
512
+ return avg_loss, avg_components, avg_grad_norm
513
+
514
+ def _validate_epoch(self, val_loader, loss_fn):
515
+ """Validate one epoch"""
516
+ self.model.eval()
517
+ total_loss = 0
518
+ total_components = {'mse': 0, 'poisson': 0, 'correlation': 0, 'sparsity': 0, 'biological': 0}
519
+
520
+ with torch.no_grad():
521
+ for latent, target in val_loader:
522
+ latent = latent.to(self.device, non_blocking=True)
523
+ target = target.to(self.device, non_blocking=True)
524
+
525
+ pred = self.model(latent)
526
+ loss, components = loss_fn(pred, target)
527
+
528
+ total_loss += loss.item()
529
+ for key in components:
530
+ total_components[key] += components[key]
531
+
532
+ num_batches = len(val_loader)
533
+ avg_loss = total_loss / num_batches
534
+ avg_components = {key: value / num_batches for key, value in total_components.items()}
535
+
536
+ return avg_loss, avg_components
537
+
538
+ def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
539
+ """Save model checkpoint"""
540
+ torch.save({
541
+ 'epoch': epoch,
542
+ 'model_state_dict': self.model.state_dict(),
543
+ 'optimizer_state_dict': optimizer.state_dict(),
544
+ 'scheduler_state_dict': scheduler.state_dict(),
545
+ 'best_val_loss': best_loss,
546
+ 'training_history': history,
547
+ 'model_config': {
548
+ 'latent_dim': self.latent_dim,
549
+ 'gene_dim': self.gene_dim,
550
+ 'hidden_dims': self.hidden_dims,
551
+ 'biological_prior_dim': self.biological_prior_dim,
552
+ 'dropout_rate': self.dropout_rate
553
+ }
554
+ }, path)
555
+
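Because the checkpoint stores `model_config`, a decoder with matching dimensions can be rebuilt before the weights are loaded (hedged sketch; the path is a placeholder):

import torch

config = torch.load('virtual_cell_decoder.pth', map_location='cpu')['model_config']
decoder = VirtualCellDecoder(**config)            # dimensions must match the saved weights
decoder.load_model('virtual_cell_decoder.pth')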
556
+ def predict(self, latent_data: np.ndarray, batch_size: int = 32) -> np.ndarray:
557
+ """
558
+ Predict gene expression with biological constraints
559
+
560
+ Args:
561
+ latent_data: Latent variables [n_samples, latent_dim]
562
+ batch_size: Prediction batch size
563
+
564
+ Returns:
565
+ expression: Predicted expression [n_samples, gene_dim]
566
+ """
567
+ if not self.is_trained:
568
+ warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")
569
+
570
+ self.model.eval()
571
+
572
+ if isinstance(latent_data, np.ndarray):
573
+ latent_data = torch.FloatTensor(latent_data)
574
+
575
+ # Predict in batches
576
+ predictions = []
577
+ with torch.no_grad():
578
+ for i in range(0, len(latent_data), batch_size):
579
+ batch_latent = latent_data[i:i+batch_size].to(self.device)
580
+ batch_pred = self.model(batch_latent)
581
+ predictions.append(batch_pred.cpu())
582
+
583
+ return torch.cat(predictions).numpy()
584
+
585
+ def load_model(self, model_path: str):
586
+ """Load pre-trained model"""
587
+ checkpoint = torch.load(model_path, map_location=self.device)
588
+ self.model.load_state_dict(checkpoint['model_state_dict'])
589
+ self.is_trained = True
590
+ self.training_history = checkpoint.get('training_history')
591
+ self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
592
+ print(f"✅ Model loaded! Best validation loss: {self.best_val_loss:.4f}")
593
+
594
+ def get_model_info(self) -> Dict:
595
+ """Get model information"""
596
+ return {
597
+ 'is_trained': self.is_trained,
598
+ 'best_val_loss': self.best_val_loss,
599
+ 'parameters': sum(p.numel() for p in self.model.parameters()),
600
+ 'latent_dim': self.latent_dim,
601
+ 'gene_dim': self.gene_dim,
602
+ 'hidden_dims': self.hidden_dims,
603
+ 'biological_prior_dim': self.biological_prior_dim,
604
+ 'device': str(self.device)
605
+ }
606
+
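And a prediction call on a trained (or loaded) decoder; `decoder` is assumed to be a VirtualCellDecoder instance (hedged sketch):

import numpy as np

latent = np.random.randn(16, decoder.latent_dim).astype(np.float32)
expression = decoder.predict(latent, batch_size=8)   # np.ndarray of shape (16, decoder.gene_dim)
print(expression.shape, float(expression.min()))     # non-negative: the final softplus never goes below zero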
607
+ '''
608
+ # Example usage
609
+ def example_usage():
610
+ """Example demonstration of Virtual Cell Challenge decoder"""
611
+
612
+ # Initialize decoder
613
+ decoder = VirtualCellDecoder(
614
+ latent_dim=100,
615
+ gene_dim=2000, # Reduced for example
616
+ hidden_dims=[256, 512, 1024],
617
+ biological_prior_dim=128,
618
+ dropout_rate=0.1
619
+ )
620
+
621
+ # Generate example data with biological characteristics
622
+ n_samples = 1000
623
+ latent_data = np.random.randn(n_samples, 100).astype(np.float32)
624
+
625
+ # Simulate biological expression data (sparse, non-negative)
626
+ weights = np.random.randn(100, 2000) * 0.1
627
+ expression_data = np.tanh(latent_data.dot(weights))
628
+ expression_data = np.maximum(expression_data, 0)
629
+
630
+ # Add biological sparsity (typical scRNA-seq characteristics)
631
+ mask = np.random.random(expression_data.shape) < 0.8 # zero out ~80% of entries (matches the sparsity check in _create_dataset)
632
+ expression_data[mask] = 0
633
+
634
+ print(f"📊 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")
635
+ print(f"🌿 Biological sparsity: {(expression_data == 0).mean():.3f}")
636
+
637
+ # Train with biological constraints
638
+ history = decoder.train(
639
+ train_latent=latent_data,
640
+ train_expression=expression_data,
641
+ batch_size=32,
642
+ num_epochs=50,
643
+ learning_rate=1e-4,
644
+ biological_weight=0.1
645
+ )
646
+
647
+ # Predict
648
+ test_latent = np.random.randn(10, 100).astype(np.float32)
649
+ predictions = decoder.predict(test_latent)
650
+ print(f"🔮 Prediction shape: {predictions.shape}")
651
+ print(f"🌿 Predicted sparsity: {(predictions < 0.1).mean():.3f}")
652
+
653
+ return decoder
654
+
655
+ if __name__ == "__main__":
656
+ example_usage()
657
+
658
+ '''
SURE/__init__.py CHANGED
@@ -1,16 +1,28 @@
1
1
  from .SURE import SURE
2
2
  from .DensityFlow import DensityFlow
3
+ from .DensityFlowLinear import DensityFlowLinear
3
4
  from .PerturbE import PerturbE
4
5
  from .TranscriptomeDecoder import TranscriptomeDecoder
6
+ from .SimpleTranscriptomeDecoder import SimpleTranscriptomeDecoder
7
+ from .EfficientTranscriptomeDecoder import EfficientTranscriptomeDecoder
8
+ from .VirtualCellDecoder import VirtualCellDecoder
9
+ from .PerturbationAwareDecoder import PerturbationAwareDecoder
5
10
 
6
11
  from . import utils
7
12
  from . import codebook
8
13
  from . import SURE
9
14
  from . import DensityFlow
15
+ from . import DensityFlowLinear
10
16
  from . import atac
11
17
  from . import flow
12
18
  from . import perturb
13
19
  from . import PerturbE
14
20
  from . import TranscriptomeDecoder
21
+ from . import SimpleTranscriptomeDecoder
22
+ from . import EfficientTranscriptomeDecoder
23
+ from . import VirtualCellDecoder
24
+ from . import PerturbationAwareDecoder
15
25
 
16
- __all__ = ['SURE', 'DensityFlow', 'PerturbE', 'TranscriptomeDecoder', 'flow', 'perturb', 'atac', 'utils', 'codebook']
26
+ __all__ = ['SURE', 'DensityFlow', 'DensityFlowLinear', 'PerturbE', 'TranscriptomeDecoder', 'SimpleTranscriptomeDecoder',
27
+ 'EfficientTranscriptomeDecoder', 'VirtualCellDecoder', 'PerturbationAwareDecoder',
28
+ 'flow', 'perturb', 'atac', 'utils', 'codebook']
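With the expanded `__all__`, the new decoders are importable straight from the package namespace, e.g. (sketch):

from SURE import VirtualCellDecoder, PerturbationAwareDecoder, DensityFlowLinear

decoder = VirtualCellDecoder(latent_dim=100, gene_dim=2000)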
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.4.7
3
+ Version: 2.4.42
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng