SURE-tools 2.3.2.tar.gz → 2.4.23.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of SURE-tools has been flagged; details are available on the registry page.

Files changed (37)
  1. {sure_tools-2.3.2 → sure_tools-2.4.23}/PKG-INFO +1 -1
  2. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE/DensityFlow.py +7 -0
  3. sure_tools-2.4.23/SURE/EfficientTranscriptomeDecoder.py +552 -0
  4. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE/PerturbE.py +11 -4
  5. sure_tools-2.4.23/SURE/PerturbationAwareDecoder.py +737 -0
  6. sure_tools-2.4.23/SURE/SimpleTranscriptomeDecoder.py +567 -0
  7. sure_tools-2.4.23/SURE/TranscriptomeDecoder.py +511 -0
  8. sure_tools-2.4.23/SURE/VirtualCellDecoder.py +658 -0
  9. sure_tools-2.4.23/SURE/__init__.py +26 -0
  10. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE_tools.egg-info/PKG-INFO +1 -1
  11. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE_tools.egg-info/SOURCES.txt +5 -0
  12. {sure_tools-2.3.2 → sure_tools-2.4.23}/setup.py +1 -1
  13. sure_tools-2.3.2/SURE/__init__.py +0 -13
  14. {sure_tools-2.3.2 → sure_tools-2.4.23}/LICENSE +0 -0
  15. {sure_tools-2.3.2 → sure_tools-2.4.23}/README.md +0 -0
  16. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE/SURE.py +0 -0
  17. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE/assembly/__init__.py +0 -0
  18. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE/assembly/assembly.py +0 -0
  19. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE/assembly/atlas.py +0 -0
  20. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE/atac/__init__.py +0 -0
  21. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE/atac/utils.py +0 -0
  22. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE/codebook/__init__.py +0 -0
  23. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE/codebook/codebook.py +0 -0
  24. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE/flow/__init__.py +0 -0
  25. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE/flow/flow_stats.py +0 -0
  26. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE/flow/plot_quiver.py +0 -0
  27. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE/perturb/__init__.py +0 -0
  28. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE/perturb/perturb.py +0 -0
  29. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE/utils/__init__.py +0 -0
  30. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE/utils/custom_mlp.py +0 -0
  31. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE/utils/queue.py +0 -0
  32. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE/utils/utils.py +0 -0
  33. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE_tools.egg-info/dependency_links.txt +0 -0
  34. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE_tools.egg-info/entry_points.txt +0 -0
  35. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE_tools.egg-info/requires.txt +0 -0
  36. {sure_tools-2.3.2 → sure_tools-2.4.23}/SURE_tools.egg-info/top_level.txt +0 -0
  37. {sure_tools-2.3.2 → sure_tools-2.4.23}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: SURE-tools
-Version: 2.3.2
+Version: 2.4.23
 Summary: Succinct Representation of Single Cells
 Home-page: https://github.com/ZengFLab/SURE
 Author: Feng Zeng
@@ -109,6 +109,13 @@ class DensityFlow(nn.Module):
 
         set_random_seed(seed)
         self.setup_networks()
+
+        print(f"🧬 DensityFlow Initialized:")
+        print(f" - Latent Dimension: {self.latent_dim}")
+        print(f" - Gene Dimension: {self.input_size}")
+        print(f" - Hidden Dimensions: {self.hidden_layers}")
+        print(f" - Device: {self.get_device()}")
+        print(f" - Parameters: {sum(p.numel() for p in self.parameters()):,}")
 
     def setup_networks(self):
         latent_dim = self.latent_dim
@@ -0,0 +1,552 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import Dataset, DataLoader
+import numpy as np
+from typing import Dict, List, Optional, Tuple
+import math
+import warnings
+warnings.filterwarnings('ignore')
+
+class EfficientTranscriptomeDecoder:
+    """
+    High-performance, memory-efficient transcriptome decoder
+    Fixed version with corrected RMSNorm implementation
+    """
+
+    def __init__(self,
+                 latent_dim: int = 100,
+                 gene_dim: int = 60000,
+                 hidden_dims: List[int] = [512, 1024, 2048],
+                 bottleneck_dim: int = 256,
+                 num_experts: int = 4,
+                 dropout_rate: float = 0.1,
+                 device: str = None):
+        """
+        Advanced decoder combining multiple state-of-the-art techniques
+
+        Args:
+            latent_dim: Latent variable dimension
+            gene_dim: Number of genes (full transcriptome)
+            hidden_dims: Hidden layer dimensions
+            bottleneck_dim: Bottleneck dimension for memory efficiency
+            num_experts: Number of mixture-of-experts
+            dropout_rate: Dropout rate
+            device: Computation device
+        """
+        self.latent_dim = latent_dim
+        self.gene_dim = gene_dim
+        self.hidden_dims = hidden_dims
+        self.bottleneck_dim = bottleneck_dim
+        self.num_experts = num_experts
+        self.dropout_rate = dropout_rate
+        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+
+        # Initialize model with corrected architecture
+        self.model = self._build_corrected_model()
+        self.model.to(self.device)
+
+        # Training state
+        self.is_trained = False
+        self.training_history = None
+        self.best_val_loss = float('inf')
+
+        print(f"🚀 EfficientTranscriptomeDecoder Initialized:")
+        print(f" - Latent Dimension: {latent_dim}")
+        print(f" - Gene Dimension: {gene_dim}")
+        print(f" - Hidden Dimensions: {hidden_dims}")
+        print(f" - Bottleneck Dimension: {bottleneck_dim}")
+        print(f" - Number of Experts: {num_experts}")
+        print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
+
+    class CorrectedRMSNorm(nn.Module):
+        """Corrected RMS Normalization with proper dimension handling"""
+        def __init__(self, dim: int, eps: float = 1e-8):
+            super().__init__()
+            self.eps = eps
+            self.dim = dim
+            self.weight = nn.Parameter(torch.ones(dim))  # Correct: weight has same dim as input
+
+        def forward(self, x):
+            # Ensure input has the right dimension
+            if x.size(-1) != self.dim:
+                raise ValueError(f"Input dimension {x.size(-1)} doesn't match RMSNorm dimension {self.dim}")
+
+            # Calculate RMS
+            rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
+            # Normalize and apply weight
+            return x / rms * self.weight
+
+    class SimplifiedSwiGLU(nn.Module):
+        """Simplified SwiGLU activation"""
+        def forward(self, x):
+            # Split into two parts
+            x, gate = x.chunk(2, dim=-1)
+            return x * F.silu(gate)
+
+    class MemoryEfficientBottleneck(nn.Module):
+        """Memory-efficient bottleneck with corrected dimensions"""
+        def __init__(self, input_dim: int, bottleneck_dim: int, output_dim: int):
+            super().__init__()
+            # Ensure proper dimension matching
+            self.compress = nn.Linear(input_dim, bottleneck_dim)
+            self.norm1 = EfficientTranscriptomeDecoder.CorrectedRMSNorm(bottleneck_dim)
+            self.expand = nn.Linear(bottleneck_dim, output_dim)
+            self.norm2 = EfficientTranscriptomeDecoder.CorrectedRMSNorm(output_dim)
+            self.activation = nn.SiLU()
+            self.dropout = nn.Dropout(0.1)
+
+        def forward(self, x):
+            # Compress
+            compressed = self.compress(x)
+            compressed = self.norm1(compressed)
+            compressed = self.activation(compressed)
+            compressed = self.dropout(compressed)
+
+            # Expand
+            expanded = self.expand(compressed)
+            expanded = self.norm2(expanded)
+
+            return expanded
+
+    class StableMixtureOfExperts(nn.Module):
+        """Stable mixture of experts without dimension issues"""
+        def __init__(self, input_dim: int, num_experts: int = 4):
+            super().__init__()
+            self.num_experts = num_experts
+            self.input_dim = input_dim
+
+            # Shared expert with different scaling factors
+            self.shared_expert = nn.Sequential(
+                nn.Linear(input_dim, input_dim * 2),
+                nn.SiLU(),
+                nn.Dropout(0.1),
+                nn.Linear(input_dim * 2, input_dim)
+            )
+
+            # Gating network
+            self.gate = nn.Sequential(
+                nn.Linear(input_dim, num_experts * 4),
+                nn.SiLU(),
+                nn.Linear(num_experts * 4, num_experts)
+            )
+
+        def forward(self, x):
+            # Get gate weights
+            gate_weights = F.softmax(self.gate(x), dim=-1)  # [batch, num_experts]
+
+            # Process through shared expert
+            expert_output = self.shared_expert(x)  # [batch, input_dim]
+
+            # Apply expert-specific scaling
+            weighted_output = torch.zeros_like(expert_output)
+            for i in range(self.num_experts):
+                expert_scale = 0.5 + 0.5 * i  # Different scaling for each expert
+                expert_contribution = expert_output * expert_scale
+                expert_weight = gate_weights[:, i].unsqueeze(-1)  # [batch, 1]
+                weighted_output += expert_weight * expert_contribution
+
+            # Residual connection
+            return x + weighted_output
+
+    class CorrectedDecoder(nn.Module):
+        """Corrected decoder with proper dimension handling"""
+
+        def __init__(self, latent_dim: int, gene_dim: int, hidden_dims: List[int],
+                     bottleneck_dim: int, num_experts: int, dropout_rate: float):
+            super().__init__()
+
+            # Input projection
+            self.input_projection = nn.Sequential(
+                nn.Linear(latent_dim, hidden_dims[0]),
+                EfficientTranscriptomeDecoder.CorrectedRMSNorm(hidden_dims[0]),
+                nn.SiLU(),
+                nn.Dropout(dropout_rate)
+            )
+
+            # Main processing blocks
+            self.blocks = nn.ModuleList()
+            current_dim = hidden_dims[0]
+
+            for i, next_dim in enumerate(hidden_dims[1:], 1):
+                block = nn.ModuleDict({
+                    'swiglu': nn.Sequential(
+                        nn.Linear(current_dim, current_dim * 2),
+                        EfficientTranscriptomeDecoder.SimplifiedSwiGLU(),
+                        nn.Dropout(dropout_rate),
+                        nn.Linear(current_dim, current_dim)  # Project back to same dimension
+                    ),
+                    'bottleneck': EfficientTranscriptomeDecoder.MemoryEfficientBottleneck(
+                        current_dim, bottleneck_dim, next_dim
+                    ),
+                    'experts': EfficientTranscriptomeDecoder.StableMixtureOfExperts(
+                        next_dim, num_experts
+                    )
+                })
+                self.blocks.append(block)
+                current_dim = next_dim
+
+            # Final projection to gene dimension
+            self.final_projection = nn.Sequential(
+                nn.Linear(current_dim, current_dim * 2),
+                nn.SiLU(),
+                nn.Dropout(dropout_rate),
+                nn.Linear(current_dim * 2, gene_dim)
+            )
+
+            # Output parameters
+            self.output_scale = nn.Parameter(torch.ones(1))
+            self.output_bias = nn.Parameter(torch.zeros(1))
+
+            self._init_weights()
+
+        def _init_weights(self):
+            """Proper weight initialization"""
+            for module in self.modules():
+                if isinstance(module, nn.Linear):
+                    nn.init.xavier_uniform_(module.weight)
+                    if module.bias is not None:
+                        nn.init.zeros_(module.bias)
+
+        def forward(self, x):
+            # Input projection
+            x = self.input_projection(x)
+
+            # Process through blocks
+            for block in self.blocks:
+                # SwiGLU with residual
+                residual = x
+                x_swiglu = block['swiglu'](x)
+                x = x + x_swiglu  # Residual connection
+
+                # Bottleneck
+                x = block['bottleneck'](x)
+
+                # Mixture of Experts with residual
+                x = block['experts'](x)
+
+            # Final projection
+            x = self.final_projection(x)
+
+            # Ensure non-negative output
+            x = F.softplus(x * self.output_scale + self.output_bias)
+
+            return x
+
+    def _build_corrected_model(self):
+        """Build the corrected model"""
+        return self.CorrectedDecoder(
+            self.latent_dim, self.gene_dim, self.hidden_dims,
+            self.bottleneck_dim, self.num_experts, self.dropout_rate
+        )
+
+    def train(self,
+              train_latent: np.ndarray,
+              train_expression: np.ndarray,
+              val_latent: np.ndarray = None,
+              val_expression: np.ndarray = None,
+              batch_size: int = 32,
+              num_epochs: int = 100,
+              learning_rate: float = 1e-4,
+              checkpoint_path: str = 'transcriptome_decoder.pth') -> Dict:
+        """
+        Train the corrected decoder
+        """
+        print("🚀 Starting Training...")
+
+        # Data preparation
+        train_dataset = self._create_dataset(train_latent, train_expression)
+
+        if val_latent is not None and val_expression is not None:
+            val_dataset = self._create_dataset(val_latent, val_expression)
+            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
+            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
+            print(f"📈 Using provided validation data: {len(val_dataset)} samples")
+        else:
+            # Auto split
+            train_size = int(0.9 * len(train_dataset))
+            val_size = len(train_dataset) - train_size
+            train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
+            train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, pin_memory=True)
+            val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, pin_memory=True)
+            print(f"📈 Auto-split validation: {val_size} samples")
+
+        print(f"📊 Training samples: {len(train_loader.dataset)}")
+        print(f"📊 Validation samples: {len(val_loader.dataset)}")
+        print(f"📊 Batch size: {batch_size}")
+
+        # Optimizer
+        optimizer = optim.AdamW(
+            self.model.parameters(),
+            lr=learning_rate,
+            weight_decay=0.01,
+            betas=(0.9, 0.999)
+        )
+
+        # Scheduler
+        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
+
+        # Loss function
+        def combined_loss(pred, target):
+            mse_loss = F.mse_loss(pred, target)
+            poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()
+            correlation = self._pearson_correlation(pred, target)
+            correlation_loss = 1 - correlation
+            return mse_loss + 0.3 * poisson_loss + 0.1 * correlation_loss
+
+        # Training history
+        history = {
+            'train_loss': [], 'val_loss': [],
+            'train_mse': [], 'val_mse': [],
+            'train_correlation': [], 'val_correlation': [],
+            'learning_rates': []
+        }
+
+        best_val_loss = float('inf')
+        patience = 20
+        patience_counter = 0
+
+        print("\n📈 Starting training loop...")
+        for epoch in range(1, num_epochs + 1):
+            # Training
+            train_metrics = self._train_epoch(train_loader, optimizer, combined_loss)
+
+            # Validation
+            val_metrics = self._validate_epoch(val_loader, combined_loss)
+
+            # Update scheduler
+            scheduler.step()
+            current_lr = optimizer.param_groups[0]['lr']
+
+            # Record history
+            history['train_loss'].append(train_metrics['loss'])
+            history['val_loss'].append(val_metrics['loss'])
+            history['train_mse'].append(train_metrics['mse'])
+            history['val_mse'].append(val_metrics['mse'])
+            history['train_correlation'].append(train_metrics['correlation'])
+            history['val_correlation'].append(val_metrics['correlation'])
+            history['learning_rates'].append(current_lr)
+
+            # Print progress
+            if epoch % 10 == 0 or epoch == 1:
+                print(f"📍 Epoch {epoch:3d}/{num_epochs} | "
+                      f"Train Loss: {train_metrics['loss']:.4f} | "
+                      f"Val Loss: {val_metrics['loss']:.4f} | "
+                      f"Correlation: {val_metrics['correlation']:.4f} | "
+                      f"LR: {current_lr:.2e}")
+
+            # Early stopping
+            if val_metrics['loss'] < best_val_loss:
+                best_val_loss = val_metrics['loss']
+                patience_counter = 0
+                self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
+                if epoch % 20 == 0:
+                    print(f"💾 Best model saved (Val Loss: {best_val_loss:.4f})")
+            else:
+                patience_counter += 1
+                if patience_counter >= patience:
+                    print(f"🛑 Early stopping at epoch {epoch}")
+                    break
+
+        # Training completed
+        self.is_trained = True
+        self.training_history = history
+        self.best_val_loss = best_val_loss
+
+        print(f"\n🎉 Training completed!")
+        print(f"🏆 Best validation loss: {best_val_loss:.4f}")
+
+        return history
+
+    def _create_dataset(self, latent_data, expression_data):
+        """Create dataset"""
+        class SimpleDataset(Dataset):
+            def __init__(self, latent, expression):
+                self.latent = torch.FloatTensor(latent)
+                self.expression = torch.FloatTensor(expression)
+
+            def __len__(self):
+                return len(self.latent)
+
+            def __getitem__(self, idx):
+                return self.latent[idx], self.expression[idx]
+
+        return SimpleDataset(latent_data, expression_data)
+
+    def _pearson_correlation(self, pred, target):
+        """Calculate Pearson correlation"""
+        pred_centered = pred - pred.mean(dim=1, keepdim=True)
+        target_centered = target - target.mean(dim=1, keepdim=True)
+
+        numerator = (pred_centered * target_centered).sum(dim=1)
+        denominator = torch.sqrt(torch.sum(pred_centered ** 2, dim=1)) * torch.sqrt(torch.sum(target_centered ** 2, dim=1))
+
+        return (numerator / (denominator + 1e-8)).mean()
+
+    def _train_epoch(self, train_loader, optimizer, loss_fn):
+        """Train one epoch"""
+        self.model.train()
+        total_loss = 0
+        total_mse = 0
+        total_correlation = 0
+
+        for latent, target in train_loader:
+            latent = latent.to(self.device, non_blocking=True)
+            target = target.to(self.device, non_blocking=True)
+
+            optimizer.zero_grad()
+            pred = self.model(latent)
+
+            loss = loss_fn(pred, target)
+            loss.backward()
+
+            # Gradient clipping
+            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+            optimizer.step()
+
+            # Calculate metrics
+            mse_loss = F.mse_loss(pred, target).item()
+            correlation = self._pearson_correlation(pred, target).item()
+
+            total_loss += loss.item()
+            total_mse += mse_loss
+            total_correlation += correlation
+
+        num_batches = len(train_loader)
+        return {
+            'loss': total_loss / num_batches,
+            'mse': total_mse / num_batches,
+            'correlation': total_correlation / num_batches
+        }
+
+    def _validate_epoch(self, val_loader, loss_fn):
+        """Validate one epoch"""
+        self.model.eval()
+        total_loss = 0
+        total_mse = 0
+        total_correlation = 0
+
+        with torch.no_grad():
+            for latent, target in val_loader:
+                latent = latent.to(self.device, non_blocking=True)
+                target = target.to(self.device, non_blocking=True)
+
+                pred = self.model(latent)
+                loss = loss_fn(pred, target)
+                mse_loss = F.mse_loss(pred, target).item()
+                correlation = self._pearson_correlation(pred, target).item()
+
+                total_loss += loss.item()
+                total_mse += mse_loss
+                total_correlation += correlation
+
+        num_batches = len(val_loader)
+        return {
+            'loss': total_loss / num_batches,
+            'mse': total_mse / num_batches,
+            'correlation': total_correlation / num_batches
+        }
+
+    def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
+        """Save checkpoint"""
+        torch.save({
+            'epoch': epoch,
+            'model_state_dict': self.model.state_dict(),
+            'optimizer_state_dict': optimizer.state_dict(),
+            'scheduler_state_dict': scheduler.state_dict(),
+            'best_val_loss': best_loss,
+            'training_history': history,
+            'model_config': {
+                'latent_dim': self.latent_dim,
+                'gene_dim': self.gene_dim,
+                'hidden_dims': self.hidden_dims,
+                'bottleneck_dim': self.bottleneck_dim,
+                'num_experts': self.num_experts
+            }
+        }, path)
+
+    def predict(self, latent_data: np.ndarray, batch_size: int = 32) -> np.ndarray:
+        """Predict gene expression"""
+        if not self.is_trained:
+            warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")
+
+        self.model.eval()
+
+        if isinstance(latent_data, np.ndarray):
+            latent_data = torch.FloatTensor(latent_data)
+
+        predictions = []
+        with torch.no_grad():
+            for i in range(0, len(latent_data), batch_size):
+                batch_latent = latent_data[i:i+batch_size].to(self.device)
+                batch_pred = self.model(batch_latent)
+                predictions.append(batch_pred.cpu())
+
+        return torch.cat(predictions).numpy()
+
+    def load_model(self, model_path: str):
+        """Load pre-trained model"""
+        checkpoint = torch.load(model_path, map_location=self.device)
+        self.model.load_state_dict(checkpoint['model_state_dict'])
+        self.is_trained = True
+        self.training_history = checkpoint.get('training_history')
+        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
+        print(f"✅ Model loaded! Best val loss: {self.best_val_loss:.4f}")
+
+    def get_model_info(self) -> Dict:
+        """Get model information"""
+        return {
+            'is_trained': self.is_trained,
+            'best_val_loss': self.best_val_loss,
+            'parameters': sum(p.numel() for p in self.model.parameters()),
+            'latent_dim': self.latent_dim,
+            'gene_dim': self.gene_dim,
+            'hidden_dims': self.hidden_dims,
+            'device': str(self.device)
+        }
+
+'''
+# Example usage
+def example_usage():
+    """Example demonstration"""
+
+    # Initialize decoder
+    decoder = EfficientTranscriptomeDecoder(
+        latent_dim=100,
+        gene_dim=2000,  # Reduced for example
+        hidden_dims=[256, 512, 1024],
+        bottleneck_dim=128,
+        num_experts=4,
+        dropout_rate=0.1
+    )
+
+    # Generate example data
+    n_samples = 1000
+    latent_data = np.random.randn(n_samples, 100).astype(np.float32)
+
+    # Simulate expression data
+    weights = np.random.randn(100, 2000) * 0.1
+    expression_data = np.tanh(latent_data.dot(weights))
+    expression_data = np.maximum(expression_data, 0)
+
+    print(f"📊 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")
+
+    # Train
+    history = decoder.train(
+        train_latent=latent_data,
+        train_expression=expression_data,
+        batch_size=32,
+        num_epochs=50
+    )
+
+    # Predict
+    test_latent = np.random.randn(10, 100).astype(np.float32)
+    predictions = decoder.predict(test_latent)
+    print(f"🔮 Prediction shape: {predictions.shape}")
+
+    return decoder
+
+if __name__ == "__main__":
+    example_usage()
+'''
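
A note on the StableMixtureOfExperts block in the new file: it applies one shared expert and mixes scalar-scaled copies of its output, so the gated loop collapses to a single per-sample scalar times the shared output. A standalone sketch with illustrative shapes (not the package's actual tensors) confirming the algebra:

import torch

w = torch.softmax(torch.randn(4, 3), dim=-1)  # gate weights: [batch, num_experts]
e = torch.randn(4, 8)                         # shared expert output: [batch, dim]
scales = 0.5 + 0.5 * torch.arange(3.0)        # expert_scale = 0.5 + 0.5 * i

# Loop form, as written in StableMixtureOfExperts.forward
loop = sum(w[:, i:i + 1] * e * scales[i] for i in range(3))
# Equivalent collapsed form: one scalar gate per sample
collapsed = (w * scales).sum(dim=-1, keepdim=True) * e
print(torch.allclose(loop, collapsed))  # True
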
@@ -349,7 +349,8 @@ class PerturbE(nn.Module):
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -433,7 +434,8 @@ class PerturbE(nn.Module):
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -529,7 +531,8 @@ class PerturbE(nn.Module):
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -635,7 +638,8 @@ class PerturbE(nn.Module):
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
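
The four PerturbE hunks above all make the same substitution: the Multinomial likelihood is now parameterized by the raw decoder output concentrate as logits rather than by the normalized probability vector theta. The two forms agree when theta = softmax(concentrate); passing logits simply skips the explicit normalization step. A minimal sketch with made-up tensors (not the model's actual variables), assuming Pyro is installed:

import torch
import pyro.distributions as dist

concentrate = torch.randn(5)                # stand-in for the decoder output (logits)
theta = torch.softmax(concentrate, dim=-1)  # explicit normalization

d_probs = dist.Multinomial(total_count=int(1e8), probs=theta)
d_logits = dist.Multinomial(total_count=int(1e8), logits=concentrate)

# Both parameterizations describe the same distribution:
print(torch.allclose(d_probs.mean, d_logits.mean))  # True; mean = total_count * theta
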
@@ -858,6 +862,9 @@ class PerturbE(nn.Module):
         if self.loss_func == 'bernoulli':
             #counts = self.sigmoid(concentrate)
             counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
+        elif self.loss_func == 'multinomial':
+            theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
+            counts = theta * library_size
         else:
             rate = concentrate.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
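
The final hunk extends PerturbE's prediction path so the 'multinomial' loss also yields expected counts. Since the mean of a Multinomial is total_count * softmax(logits), the new theta line evaluates to 1e8 * softmax(concentrate), which the next line scales by the library size. A quick check of that identity (hypothetical shapes, assuming Pyro is installed):

import torch
import pyro.distributions as dist

concentrate = torch.randn(3, 5)  # stand-in [cells, genes] logits
mean = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
print(torch.allclose(mean, 1e8 * torch.softmax(concentrate, dim=-1)))  # True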