SURE-tools 2.4.5__py3-none-any.whl → 2.4.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of SURE-tools has been flagged as potentially problematic.

SURE/VirtualCellDecoder.py ADDED
@@ -0,0 +1,659 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader
+ import numpy as np
+ from typing import Dict, List, Optional, Tuple
+ import math
+ import warnings
+ warnings.filterwarnings('ignore')
+
+ class VirtualCellDecoder:
+     """
+     Advanced transcriptome decoder based on Virtual Cell Challenge research
+     Optimized for latent-to-expression mapping with biological constraints
+     """
+
+     def __init__(self,
+                  latent_dim: int = 100,
+                  gene_dim: int = 60000,
+                  hidden_dims: List[int] = [512, 1024, 2048],
+                  biological_prior_dim: int = 256,
+                  dropout_rate: float = 0.1,
+                  device: str = None):
+         """
+         State-of-the-art decoder based on Virtual Cell Challenge insights
+
+         Args:
+             latent_dim: Latent variable dimension (typically 50-100)
+             gene_dim: Number of genes (full transcriptome ~60,000)
+             hidden_dims: Hidden layer dimensions for progressive expansion
+             biological_prior_dim: Dimension for biological prior knowledge
+             dropout_rate: Dropout rate for regularization
+             device: Computation device
+         """
+         self.latent_dim = latent_dim
+         self.gene_dim = gene_dim
+         self.hidden_dims = hidden_dims
+         self.biological_prior_dim = biological_prior_dim
+         self.dropout_rate = dropout_rate
+         self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+
+         # Initialize model with biological constraints
+         self.model = self._build_biological_model()
+         self.model.to(self.device)
+
+         # Training state
+         self.is_trained = False
+         self.training_history = None
+         self.best_val_loss = float('inf')
+
+         print(f"🧬 VirtualCellDecoder Initialized:")
+         print(f" - Latent Dimension: {latent_dim}")
+         print(f" - Gene Dimension: {gene_dim}")
+         print(f" - Hidden Dimensions: {hidden_dims}")
+         print(f" - Biological Prior Dimension: {biological_prior_dim}")
+         print(f" - Device: {self.device}")
+         print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
+         print(f" - Estimated GPU Memory: ~8-10GB (optimized for 20GB)")
+
+     class BiologicalPriorNetwork(nn.Module):
+         """Biological prior network based on gene regulatory knowledge"""
+
+         def __init__(self, latent_dim: int, prior_dim: int, gene_dim: int):
+             super().__init__()
+             self.gene_dim = gene_dim
+             self.prior_dim = prior_dim
+
+             # Learnable gene regulatory matrix (sparse initialization)
+             self.regulatory_matrix = nn.Parameter(
+                 torch.randn(gene_dim, prior_dim) * 0.01
+             )
+
+             # Latent to regulatory space projection
+             self.latent_to_regulatory = nn.Sequential(
+                 nn.Linear(latent_dim, prior_dim * 2),
+                 nn.ReLU(),
+                 nn.Dropout(0.1),
+                 nn.Linear(prior_dim * 2, prior_dim)
+             )
+
+             # Regulatory to expression projection
+             self.regulatory_to_expression = nn.Sequential(
+                 nn.Linear(prior_dim, prior_dim),
+                 nn.ReLU(),
+                 nn.Dropout(0.1),
+                 nn.Linear(prior_dim, gene_dim)
+             )
+
+             self._init_weights()
+
+         def _init_weights(self):
+             """Initialize with biological constraints"""
+             # Sparse initialization for regulatory matrix
+             nn.init.sparse_(self.regulatory_matrix, sparsity=0.8)
+             for module in self.modules():
+                 if isinstance(module, nn.Linear):
+                     nn.init.xavier_uniform_(module.weight)
+                     if module.bias is not None:
+                         nn.init.zeros_(module.bias)
+
+         def forward(self, latent):
+             batch_size = latent.shape[0]
+
+             # Project latent to regulatory space
+             regulatory_factors = self.latent_to_regulatory(latent)  # [batch, prior_dim]
+
+             # Apply regulatory matrix (gene-specific modulation)
+             regulatory_effect = torch.matmul(
+                 regulatory_factors, self.regulatory_matrix.T  # [batch, gene_dim]
+             )
+
+             # Final expression projection
+             expression_base = self.regulatory_to_expression(regulatory_factors)
+
+             # Combine regulatory effect with base expression
+             biological_prior = expression_base + regulatory_effect
+
+             return biological_prior
+
+     class GeneSpecificAttention(nn.Module):
+         """Gene-specific attention mechanism for capturing co-expression patterns"""
+
+         def __init__(self, gene_dim: int, attention_dim: int = 128, num_heads: int = 8):
+             super().__init__()
+             self.gene_dim = gene_dim
+             self.attention_dim = attention_dim
+             self.num_heads = num_heads
+
+             # Gene embeddings for attention
+             self.gene_embeddings = nn.Parameter(torch.randn(gene_dim, attention_dim))
+
+             # Attention mechanism
+             self.query_proj = nn.Linear(attention_dim, attention_dim)
+             self.key_proj = nn.Linear(attention_dim, attention_dim)
+             self.value_proj = nn.Linear(attention_dim, attention_dim)
+
+             # Output projection
+             self.output_proj = nn.Linear(attention_dim, attention_dim)
+
+             self._init_weights()
+
+         def _init_weights(self):
+             """Initialize attention weights"""
+             nn.init.xavier_uniform_(self.gene_embeddings)
+             for module in [self.query_proj, self.key_proj, self.value_proj, self.output_proj]:
+                 nn.init.xavier_uniform_(module.weight)
+                 if module.bias is not None:
+                     nn.init.zeros_(module.bias)
+
+         def forward(self, x):
+             batch_size = x.shape[0]
+
+             # Prepare gene embeddings
+             gene_embeds = self.gene_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
+
+             # Compute attention
+             Q = self.query_proj(gene_embeds)
+             K = self.key_proj(gene_embeds)
+             V = self.value_proj(gene_embeds)
+
+             # Multi-head attention
+             head_dim = self.attention_dim // self.num_heads
+             Q = Q.view(batch_size, self.gene_dim, self.num_heads, head_dim).transpose(1, 2)
+             K = K.view(batch_size, self.gene_dim, self.num_heads, head_dim).transpose(1, 2)
+             V = V.view(batch_size, self.gene_dim, self.num_heads, head_dim).transpose(1, 2)
+
+             # Scaled dot-product attention
+             attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(head_dim)
+             attn_weights = F.softmax(attn_scores, dim=-1)
+
+             # Apply attention
+             attn_output = torch.matmul(attn_weights, V)
+             attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, self.gene_dim, self.attention_dim)
+
+             # Output projection
+             output = self.output_proj(attn_output)
+
+             return output
+
+     class SparseActivation(nn.Module):
+         """Sparse activation function for biological data"""
+
+         def __init__(self, sparsity_target: float = 0.85):
+             super().__init__()
+             self.sparsity_target = sparsity_target
+             self.alpha = nn.Parameter(torch.tensor(1.0))
+             self.beta = nn.Parameter(torch.tensor(0.0))
+
+         def forward(self, x):
+             # Learnable softplus with sparsity constraint
+             activated = F.softplus(x * self.alpha + self.beta)
+
+             # Sparsity regularization (encourages biological sparsity)
+             sparsity_loss = (activated.mean() - self.sparsity_target) ** 2
+             self.sparsity_loss = sparsity_loss * 0.01  # Light regularization
+
+             return activated
+
+     class VirtualCellModel(nn.Module):
+         """Main Virtual Cell Challenge inspired model"""
+
+         def __init__(self, latent_dim: int, gene_dim: int, hidden_dims: List[int],
+                      biological_prior_dim: int, dropout_rate: float):
+             super().__init__()
+
+             # Phase 1: Latent expansion with biological constraints
+             self.latent_expansion = nn.Sequential(
+                 nn.Linear(latent_dim, hidden_dims[0]),
+                 nn.BatchNorm1d(hidden_dims[0]),
+                 nn.ReLU(),
+                 nn.Dropout(dropout_rate),
+                 nn.Linear(hidden_dims[0], hidden_dims[1]),
+                 nn.BatchNorm1d(hidden_dims[1]),
+                 nn.ReLU(),
+                 nn.Dropout(dropout_rate)
+             )
+
+             # Phase 2: Biological prior network
+             self.biological_prior = VirtualCellDecoder.BiologicalPriorNetwork(
+                 hidden_dims[1], biological_prior_dim, gene_dim
+             )
+
+             # Phase 3: Gene-specific processing
+             self.gene_attention = VirtualCellDecoder.GeneSpecificAttention(gene_dim)
+
+             # Phase 4: Final expression refinement
+             self.expression_refinement = nn.Sequential(
+                 nn.Linear(gene_dim, hidden_dims[2]),
+                 nn.BatchNorm1d(hidden_dims[2]),
+                 nn.ReLU(),
+                 nn.Dropout(dropout_rate),
+                 nn.Linear(hidden_dims[2], gene_dim)
+             )
+
+             # Phase 5: Sparse activation
+             self.sparse_activation = VirtualCellDecoder.SparseActivation()
+
+             self._init_weights()
+
+         def _init_weights(self):
+             """Biological-inspired weight initialization"""
+             for module in self.modules():
+                 if isinstance(module, nn.Linear):
+                     # Xavier initialization for stable training
+                     nn.init.xavier_uniform_(module.weight)
+                     if module.bias is not None:
+                         nn.init.zeros_(module.bias)
+                 elif isinstance(module, nn.BatchNorm1d):
+                     nn.init.ones_(module.weight)
+                     nn.init.zeros_(module.bias)
+
+         def forward(self, latent):
+             # Phase 1: Latent expansion
+             expanded_latent = self.latent_expansion(latent)
+
+             # Phase 2: Biological prior
+             biological_output = self.biological_prior(expanded_latent)
+
+             # Phase 3: Gene attention
+             attention_output = self.gene_attention(biological_output)
+
+             # Phase 4: Refinement with residual connection
+             refined_output = self.expression_refinement(attention_output) + biological_output
+
+             # Phase 5: Sparse activation
+             final_output = self.sparse_activation(refined_output)
+
+             return final_output
+
+     def _build_biological_model(self):
+         """Build the biologically constrained model"""
+         return self.VirtualCellModel(
+             self.latent_dim, self.gene_dim, self.hidden_dims,
+             self.biological_prior_dim, self.dropout_rate
+         )
+
+     def train(self,
+               train_latent: np.ndarray,
+               train_expression: np.ndarray,
+               val_latent: np.ndarray = None,
+               val_expression: np.ndarray = None,
+               batch_size: int = 32,
+               num_epochs: int = 200,
+               learning_rate: float = 1e-4,
+               biological_weight: float = 0.1,
+               checkpoint_path: str = 'virtual_cell_decoder.pth') -> Dict:
+         """
+         Train with biological constraints and Virtual Cell Challenge insights
+
+         Args:
+             train_latent: Training latent variables
+             train_expression: Training expression data
+             val_latent: Validation latent variables
+             val_expression: Validation expression data
+             batch_size: Batch size optimized for biological data
+             num_epochs: Number of training epochs
+             learning_rate: Learning rate
+             biological_weight: Weight for biological constraint loss
+             checkpoint_path: Model save path
+         """
+         print("🧬 Starting Virtual Cell Challenge Training...")
+         print("📚 Incorporating biological constraints and regulatory priors")
+
+         # Data preparation
+         train_dataset = self._create_dataset(train_latent, train_expression)
+
+         if val_latent is not None and val_expression is not None:
+             val_dataset = self._create_dataset(val_latent, val_expression)
+             train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+             val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
+             print(f"📈 Using provided validation data: {len(val_dataset)} samples")
+         else:
+             # Auto split (90/10)
+             train_size = int(0.9 * len(train_dataset))
+             val_size = len(train_dataset) - train_size
+             train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
+             train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
+             val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
+             print(f"📈 Auto-split validation: {val_size} samples")
+
+         print(f"📊 Training samples: {len(train_loader.dataset)}")
+         print(f"📊 Validation samples: {len(val_loader.dataset)}")
+         print(f"📊 Batch size: {batch_size}")
+
+         # Optimizer with biological regularization
+         optimizer = optim.AdamW(
+             self.model.parameters(),
+             lr=learning_rate,
+             weight_decay=1e-5,  # L2 regularization
+             betas=(0.9, 0.999)
+         )
+
+         # Cosine annealing with warmup
+         scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
+             optimizer, T_0=50, T_mult=2, eta_min=1e-6
+         )
+
+         # Biological loss function
+         def biological_loss(pred, target):
+             # 1. Reconstruction loss
+             mse_loss = F.mse_loss(pred, target)
+
+             # 2. Poisson loss for count data
+             poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()
+
+             # 3. Correlation loss for pattern matching
+             correlation = self._pearson_correlation(pred, target)
+             correlation_loss = 1 - correlation
+
+             # 4. Sparsity loss (biological constraint)
+             sparsity_loss = self.model.sparse_activation.sparsity_loss
+
+             # 5. Biological consistency loss
+             biological_loss = self._biological_consistency_loss(pred)
+
+             total_loss = (mse_loss + 0.5 * poisson_loss + 0.3 * correlation_loss +
+                           0.1 * sparsity_loss + biological_weight * biological_loss)
+
+             return total_loss, {
+                 'mse': mse_loss.item(),
+                 'poisson': poisson_loss.item(),
+                 'correlation': correlation.item(),
+                 'sparsity': sparsity_loss.item(),
+                 'biological': biological_loss.item()
+             }
+
+         # Training history
+         history = {
+             'train_loss': [], 'val_loss': [],
+             'train_mse': [], 'val_mse': [],
+             'train_correlation': [], 'val_correlation': [],
+             'train_sparsity': [], 'val_sparsity': [],
+             'learning_rates': [], 'grad_norms': []
+         }
+
+         best_val_loss = float('inf')
+         patience = 25
+         patience_counter = 0
+
+         print("\n🔬 Starting training with biological constraints...")
+         for epoch in range(1, num_epochs + 1):
+             # Training phase
+             train_loss, train_components, grad_norm = self._train_epoch(
+                 train_loader, optimizer, biological_loss
+             )
+
+             # Validation phase
+             val_loss, val_components = self._validate_epoch(val_loader, biological_loss)
+
+             # Update scheduler
+             scheduler.step()
+             current_lr = optimizer.param_groups[0]['lr']
+
+             # Record history
+             history['train_loss'].append(train_loss)
+             history['val_loss'].append(val_loss)
+             history['train_mse'].append(train_components['mse'])
+             history['val_mse'].append(val_components['mse'])
+             history['train_correlation'].append(train_components['correlation'])
+             history['val_correlation'].append(val_components['correlation'])
+             history['train_sparsity'].append(train_components['sparsity'])
+             history['val_sparsity'].append(val_components['sparsity'])
+             history['learning_rates'].append(current_lr)
+             history['grad_norms'].append(grad_norm)
+
+             # Print detailed progress
+             if epoch % 10 == 0 or epoch == 1:
+                 print(f"🧪 Epoch {epoch:3d}/{num_epochs} | "
+                       f"Train: {train_loss:.4f} | Val: {val_loss:.4f} | "
+                       f"Corr: {val_components['correlation']:.4f} | "
+                       f"Sparsity: {val_components['sparsity']:.4f} | "
+                       f"LR: {current_lr:.2e}")
+
+             # Early stopping and model saving
+             if val_loss < best_val_loss:
+                 best_val_loss = val_loss
+                 patience_counter = 0
+                 self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
+                 if epoch % 20 == 0:
+                     print(f"💾 Best model saved (Val Loss: {best_val_loss:.4f})")
+             else:
+                 patience_counter += 1
+                 if patience_counter >= patience:
+                     print(f"🛑 Early stopping at epoch {epoch}")
+                     break
+
+         # Training completed
+         self.is_trained = True
+         self.training_history = history
+         self.best_val_loss = best_val_loss
+
+         print(f"\n🎉 Training completed!")
+         print(f"🏆 Best validation loss: {best_val_loss:.4f}")
+         print(f"📊 Final correlation: {history['val_correlation'][-1]:.4f}")
+         print(f"🌿 Final sparsity: {history['val_sparsity'][-1]:.4f}")
+
+         return history
+
+     def _biological_consistency_loss(self, pred):
+         """Biological consistency loss based on Virtual Cell Challenge insights"""
+         # 1. Gene expression variance consistency
+         gene_variance = pred.var(dim=0)
+         target_variance = torch.ones_like(gene_variance) * 0.5  # Reasonable biological variance
+         variance_loss = F.mse_loss(gene_variance, target_variance)
+
+         # 2. Co-expression pattern consistency
+         correlation_matrix = torch.corrcoef(pred.T)
+         correlation_loss = torch.mean(torch.abs(correlation_matrix))  # Encourage moderate correlations
+
+         return variance_loss + 0.5 * correlation_loss
+
+     def _create_dataset(self, latent_data, expression_data):
+         """Create dataset with biological data validation"""
+         class BiologicalDataset(Dataset):
+             def __init__(self, latent, expression):
+                 # Validate biological data characteristics
+                 assert np.all(expression >= 0), "Expression data must be non-negative"
+                 assert np.mean(expression == 0) > 0.7, "Expression data should be sparse (typical scRNA-seq)"
+
+                 self.latent = torch.FloatTensor(latent)
+                 self.expression = torch.FloatTensor(expression)
+
+             def __len__(self):
+                 return len(self.latent)
+
+             def __getitem__(self, idx):
+                 return self.latent[idx], self.expression[idx]
+
+         return BiologicalDataset(latent_data, expression_data)
+
+     def _pearson_correlation(self, pred, target):
+         """Calculate Pearson correlation coefficient"""
+         pred_centered = pred - pred.mean(dim=1, keepdim=True)
+         target_centered = target - target.mean(dim=1, keepdim=True)
+
+         numerator = (pred_centered * target_centered).sum(dim=1)
+         denominator = torch.sqrt(torch.sum(pred_centered ** 2, dim=1)) * torch.sqrt(torch.sum(target_centered ** 2, dim=1))
+
+         return (numerator / (denominator + 1e-8)).mean()
+
+     def _train_epoch(self, train_loader, optimizer, loss_fn):
+         """Train one epoch with biological constraints"""
+         self.model.train()
+         total_loss = 0
+         total_components = {'mse': 0, 'poisson': 0, 'correlation': 0, 'sparsity': 0, 'biological': 0}
+         grad_norms = []
+
+         for latent, target in train_loader:
+             latent = latent.to(self.device, non_blocking=True)
+             target = target.to(self.device, non_blocking=True)
+
+             optimizer.zero_grad()
+             pred = self.model(latent)
+
+             loss, components = loss_fn(pred, target)
+             loss.backward()
+
+             # Gradient clipping for stability
+             grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+             optimizer.step()
+
+             total_loss += loss.item()
+             for key in components:
+                 total_components[key] += components[key]
+             grad_norms.append(grad_norm.item())
+
+         num_batches = len(train_loader)
+         avg_loss = total_loss / num_batches
+         avg_components = {key: value / num_batches for key, value in total_components.items()}
+         avg_grad_norm = np.mean(grad_norms)
+
+         return avg_loss, avg_components, avg_grad_norm
+
+     def _validate_epoch(self, val_loader, loss_fn):
+         """Validate one epoch"""
+         self.model.eval()
+         total_loss = 0
+         total_components = {'mse': 0, 'poisson': 0, 'correlation': 0, 'sparsity': 0, 'biological': 0}
+
+         with torch.no_grad():
+             for latent, target in val_loader:
+                 latent = latent.to(self.device, non_blocking=True)
+                 target = target.to(self.device, non_blocking=True)
+
+                 pred = self.model(latent)
+                 loss, components = loss_fn(pred, target)
+
+                 total_loss += loss.item()
+                 for key in components:
+                     total_components[key] += components[key]
+
+         num_batches = len(val_loader)
+         avg_loss = total_loss / num_batches
+         avg_components = {key: value / num_batches for key, value in total_components.items()}
+
+         return avg_loss, avg_components
+
+     def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
+         """Save model checkpoint"""
+         torch.save({
+             'epoch': epoch,
+             'model_state_dict': self.model.state_dict(),
+             'optimizer_state_dict': optimizer.state_dict(),
+             'scheduler_state_dict': scheduler.state_dict(),
+             'best_val_loss': best_loss,
+             'training_history': history,
+             'model_config': {
+                 'latent_dim': self.latent_dim,
+                 'gene_dim': self.gene_dim,
+                 'hidden_dims': self.hidden_dims,
+                 'biological_prior_dim': self.biological_prior_dim,
+                 'dropout_rate': self.dropout_rate
+             }
+         }, path)
+
+     def predict(self, latent_data: np.ndarray, batch_size: int = 32) -> np.ndarray:
+         """
+         Predict gene expression with biological constraints
+
+         Args:
+             latent_data: Latent variables [n_samples, latent_dim]
+             batch_size: Prediction batch size
+
+         Returns:
+             expression: Predicted expression [n_samples, gene_dim]
+         """
+         if not self.is_trained:
+             warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")
+
+         self.model.eval()
+
+         if isinstance(latent_data, np.ndarray):
+             latent_data = torch.FloatTensor(latent_data)
+
+         # Predict in batches
+         predictions = []
+         with torch.no_grad():
+             for i in range(0, len(latent_data), batch_size):
+                 batch_latent = latent_data[i:i+batch_size].to(self.device)
+                 batch_pred = self.model(batch_latent)
+                 predictions.append(batch_pred.cpu())
+
+         return torch.cat(predictions).numpy()
+
+     def load_model(self, model_path: str):
+         """Load pre-trained model"""
+         checkpoint = torch.load(model_path, map_location=self.device)
+         self.model.load_state_dict(checkpoint['model_state_dict'])
+         self.is_trained = True
+         self.training_history = checkpoint.get('training_history')
+         self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
+         print(f"✅ Model loaded! Best validation loss: {self.best_val_loss:.4f}")
+
+     def get_model_info(self) -> Dict:
+         """Get model information"""
+         return {
+             'is_trained': self.is_trained,
+             'best_val_loss': self.best_val_loss,
+             'parameters': sum(p.numel() for p in self.model.parameters()),
+             'latent_dim': self.latent_dim,
+             'gene_dim': self.gene_dim,
+             'hidden_dims': self.hidden_dims,
+             'biological_prior_dim': self.biological_prior_dim,
+             'device': str(self.device)
+         }
+
+ '''
+ # Example usage
+ def example_usage():
+     """Example demonstration of Virtual Cell Challenge decoder"""
+
+     # Initialize decoder
+     decoder = VirtualCellDecoder(
+         latent_dim=100,
+         gene_dim=2000,  # Reduced for example
+         hidden_dims=[256, 512, 1024],
+         biological_prior_dim=128,
+         dropout_rate=0.1
+     )
+
+     # Generate example data with biological characteristics
+     n_samples = 1000
+     latent_data = np.random.randn(n_samples, 100).astype(np.float32)
+
+     # Simulate biological expression data (sparse, non-negative)
+     weights = np.random.randn(100, 2000) * 0.1
+     expression_data = np.tanh(latent_data.dot(weights))
+     expression_data = np.maximum(expression_data, 0)
+
+     # Add biological sparsity (typical scRNA-seq characteristics)
+     mask = np.random.random(expression_data.shape) < 0.8  # zero out ~80% of entries
+     expression_data[mask] = 0
+
+     print(f"📊 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")
+     print(f"🌿 Biological sparsity: {(expression_data == 0).mean():.3f}")
+
+     # Train with biological constraints
+     history = decoder.train(
+         train_latent=latent_data,
+         train_expression=expression_data,
+         batch_size=32,
+         num_epochs=50,
+         learning_rate=1e-4,
+         biological_weight=0.1
+     )
+
+     # Predict
+     test_latent = np.random.randn(10, 100).astype(np.float32)
+     predictions = decoder.predict(test_latent)
+     print(f"🔮 Prediction shape: {predictions.shape}")
+     print(f"🌿 Predicted sparsity: {(predictions < 0.1).mean():.3f}")
+
+     return decoder
+
+ if __name__ == "__main__":
+     example_usage()
+
+ '''
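
For reference outside the diff, here is a minimal sketch of the inference path exposed by the new class (load a checkpoint, predict, inspect the model). The checkpoint path and dimensions below are illustrative placeholders rather than values shipped with the package, and must match whatever configuration the checkpoint was trained with.

    import numpy as np
    from SURE import VirtualCellDecoder  # exported via SURE/__init__.py (see below)

    # Placeholder dimensions; they must match the checkpoint's model_config.
    decoder = VirtualCellDecoder(latent_dim=100, gene_dim=2000,
                                 hidden_dims=[256, 512, 1024],
                                 biological_prior_dim=128)
    decoder.load_model('virtual_cell_decoder.pth')  # default path written by train()

    latent = np.random.randn(8, 100).astype(np.float32)
    expression = decoder.predict(latent, batch_size=4)  # -> array of shape (8, 2000)
    print(decoder.get_model_info())
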
SURE/__init__.py CHANGED
@@ -2,6 +2,9 @@ from .SURE import SURE
  from .DensityFlow import DensityFlow
  from .PerturbE import PerturbE
  from .TranscriptomeDecoder import TranscriptomeDecoder
+ from .SimpleTranscriptomeDecoder import SimpleTranscriptomeDecoder
+ from .EfficientTranscriptomeDecoder import EfficientTranscriptomeDecoder
+ from .VirtualCellDecoder import VirtualCellDecoder
 
  from . import utils
  from . import codebook
@@ -12,5 +15,9 @@ from . import flow
  from . import perturb
  from . import PerturbE
  from . import TranscriptomeDecoder
+ from . import SimpleTranscriptomeDecoder
+ from . import EfficientTranscriptomeDecoder
+ from . import VirtualCellDecoder
 
- __all__ = ['SURE', 'DensityFlow', 'PerturbE', 'TranscriptomeDecoder', 'flow', 'perturb', 'atac', 'utils', 'codebook']
+ __all__ = ['SURE', 'DensityFlow', 'PerturbE', 'TranscriptomeDecoder', 'SimpleTranscriptomeDecoder',
+            'EfficientTranscriptomeDecoder', 'VirtualCellDecoder', 'flow', 'perturb', 'atac', 'utils', 'codebook']
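
A quick smoke test of the new exports after installing the updated wheel (a sketch; assumes the package is importable as SURE):

    from SURE import (SimpleTranscriptomeDecoder,
                      EfficientTranscriptomeDecoder,
                      VirtualCellDecoder)

    # All three names are listed in __all__ above, so this should not raise ImportError.
    print(SimpleTranscriptomeDecoder, EfficientTranscriptomeDecoder, VirtualCellDecoder)
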
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: SURE-tools
- Version: 2.4.5
+ Version: 2.4.17
  Summary: Succinct Representation of Single Cells
  Home-page: https://github.com/ZengFLab/SURE
  Author: Feng Zeng
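
To verify that this exact release is what ended up in an environment, a small check against the distribution metadata shown above (a sketch; assumes the wheel was installed with pip under its distribution name SURE-tools):

    import importlib.metadata

    # Distribution name and version come from the METADATA diff above.
    assert importlib.metadata.version("SURE-tools") == "2.4.17"
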