SURE-tools 2.2.23.tar.gz → 2.4.17.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {sure_tools-2.2.23 → sure_tools-2.4.17}/PKG-INFO +1 -1
  2. {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/DensityFlow.py +17 -10
  3. sure_tools-2.4.17/SURE/EfficientTranscriptomeDecoder.py +552 -0
  4. sure_tools-2.4.17/SURE/PerturbE.py +1300 -0
  5. sure_tools-2.4.17/SURE/SimpleTranscriptomeDecoder.py +567 -0
  6. sure_tools-2.4.17/SURE/TranscriptomeDecoder.py +511 -0
  7. sure_tools-2.4.17/SURE/VirtualCellDecoder.py +659 -0
  8. sure_tools-2.4.17/SURE/__init__.py +23 -0
  9. {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/utils/custom_mlp.py +39 -2
  10. {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE_tools.egg-info/PKG-INFO +1 -1
  11. {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE_tools.egg-info/SOURCES.txt +5 -0
  12. {sure_tools-2.2.23 → sure_tools-2.4.17}/setup.py +1 -1
  13. sure_tools-2.2.23/SURE/__init__.py +0 -12
  14. {sure_tools-2.2.23 → sure_tools-2.4.17}/LICENSE +0 -0
  15. {sure_tools-2.2.23 → sure_tools-2.4.17}/README.md +0 -0
  16. {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/SURE.py +0 -0
  17. {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/assembly/__init__.py +0 -0
  18. {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/assembly/assembly.py +0 -0
  19. {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/assembly/atlas.py +0 -0
  20. {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/atac/__init__.py +0 -0
  21. {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/atac/utils.py +0 -0
  22. {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/codebook/__init__.py +0 -0
  23. {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/codebook/codebook.py +0 -0
  24. {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/flow/__init__.py +0 -0
  25. {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/flow/flow_stats.py +0 -0
  26. {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/flow/plot_quiver.py +0 -0
  27. {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/perturb/__init__.py +0 -0
  28. {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/perturb/perturb.py +0 -0
  29. {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/utils/__init__.py +0 -0
  30. {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/utils/queue.py +0 -0
  31. {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/utils/utils.py +0 -0
  32. {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE_tools.egg-info/dependency_links.txt +0 -0
  33. {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE_tools.egg-info/entry_points.txt +0 -0
  34. {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE_tools.egg-info/requires.txt +0 -0
  35. {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE_tools.egg-info/top_level.txt +0 -0
  36. {sure_tools-2.2.23 → sure_tools-2.4.17}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: SURE-tools
-Version: 2.2.23
+Version: 2.4.17
 Summary: Succinct Representation of Single Cells
 Home-page: https://github.com/ZengFLab/SURE
 Author: Feng Zeng
@@ -396,7 +396,8 @@ class DensityFlow(nn.Module):
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -480,7 +481,8 @@ class DensityFlow(nn.Module):
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -576,7 +578,8 @@ class DensityFlow(nn.Module):
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -682,7 +685,8 @@ class DensityFlow(nn.Module):
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
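
Note: the one-line change is identical in all four branches above. The Multinomial observation is now parameterized by the raw decoder output `concentrate` as logits instead of the softmax-normalized probabilities `theta`. A minimal sketch of why the two parameterizations agree, assuming Pyro is installed (small total_count for the demo; the package itself uses int(1e8)):

    import torch
    import pyro.distributions as dist

    concentrate = torch.randn(4, 10)            # stand-in for the decoder output
    theta = torch.softmax(concentrate, dim=-1)  # the explicit normalization

    # `logits` are softmax-normalized internally, so passing `concentrate`
    # directly defines the same distribution while avoiding the explicit
    # softmax and its underflow risk for near-zero probabilities.
    d_probs = dist.Multinomial(total_count=100, probs=theta)
    d_logits = dist.Multinomial(total_count=100, logits=concentrate)

    x = d_logits.sample()
    assert torch.allclose(d_probs.log_prob(x), d_logits.log_prob(x), atol=1e-4)
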
@@ -854,12 +858,12 @@ class DensityFlow(nn.Module):
         us_i = us[:,pert_idx].reshape(-1,1)
 
         # factor effect of xs
-        dzs0 = self.get_cell_response(zs, factor_idx=pert_idx, perturb=us_i)
+        dzs0 = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=us_i)
 
         # perturbation effect
         ps = np.ones_like(us_i)
         if np.sum(np.abs(ps-us_i))>=1:
-            dzs = self.get_cell_response(zs, factor_idx=pert_idx, perturb=ps)
+            dzs = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=ps)
             zs = zs + dzs0 + dzs
         else:
             zs = zs + dzs0
@@ -946,6 +950,9 @@ class DensityFlow(nn.Module):
         if self.loss_func == 'bernoulli':
             #counts = self.sigmoid(concentrate)
             counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
+        elif self.loss_func == 'multinomial':
+            theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
+            counts = theta * library_size
         else:
             rate = concentrate.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
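
Note: in torch/Pyro, `Multinomial.mean` evaluates to `total_count * probs`, so the tensor bound to `theta` in the new branch carries the `int(1e8)` scale factor rather than being a bare probability vector. A quick self-contained check of that API fact (hypothetical shapes, small count for the demo):

    import torch
    from torch.distributions import Multinomial

    concentrate = torch.randn(4, 10)            # stand-in for the decoder output
    probs = torch.softmax(concentrate, dim=-1)  # normalized per-gene proportions

    m = Multinomial(total_count=100, logits=concentrate)
    assert torch.allclose(m.mean, 100 * probs, atol=1e-5)  # mean = total_count * probs
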
@@ -1340,7 +1347,7 @@ def main():
     cell_factor_size = 0 if us is None else us.shape[1]
 
     ###########################################
-    DensityFlow = DensityFlow(
+    df = DensityFlow(
         input_size=input_size,
         cell_factor_size=cell_factor_size,
         inverse_dispersion=args.inverse_dispersion,
@@ -1359,7 +1366,7 @@ def main():
         dtype=dtype,
     )
 
-    DensityFlow.fit(xs, us=us,
+    df.fit(xs, us=us,
         num_epochs=args.num_epochs,
         learning_rate=args.learning_rate,
         batch_size=args.batch_size,
@@ -1371,9 +1378,9 @@ def main():
 
     if args.save_model is not None:
         if args.save_model.endswith('gz'):
-            DensityFlow.save_model(DensityFlow, args.save_model, compression=True)
+            DensityFlow.save_model(df, args.save_model, compression=True)
         else:
-            DensityFlow.save_model(DensityFlow, args.save_model)
+            DensityFlow.save_model(df, args.save_model)
 
 
 
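Note: the three main() hunks above fix a name-shadowing bug. Binding the constructed instance to the name DensityFlow rebinds the class name, so the class becomes unreachable and later class-level calls such as DensityFlow.save_model(...) resolve against the instance. A minimal reproduction of the hazard with a toy class (not the package's API):

    class DF:
        """Toy stand-in for the DensityFlow class."""

    DF = DF()   # the name DF now refers to an instance, not the class
    try:
        DF()    # re-instantiation fails: the instance is not callable
    except TypeError as exc:
        print(exc)  # 'DF' object is not callable
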
@@ -0,0 +1,552 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import Dataset, DataLoader
+import numpy as np
+from typing import Dict, List, Optional, Tuple
+import math
+import warnings
+warnings.filterwarnings('ignore')
+
+class EfficientTranscriptomeDecoder:
+    """
+    High-performance, memory-efficient transcriptome decoder
+    Fixed version with corrected RMSNorm implementation
+    """
+
+    def __init__(self,
+                 latent_dim: int = 100,
+                 gene_dim: int = 60000,
+                 hidden_dims: List[int] = [512, 1024, 2048],
+                 bottleneck_dim: int = 256,
+                 num_experts: int = 4,
+                 dropout_rate: float = 0.1,
+                 device: str = None):
+        """
+        Advanced decoder combining multiple state-of-the-art techniques
+
+        Args:
+            latent_dim: Latent variable dimension
+            gene_dim: Number of genes (full transcriptome)
+            hidden_dims: Hidden layer dimensions
+            bottleneck_dim: Bottleneck dimension for memory efficiency
+            num_experts: Number of mixture-of-experts
+            dropout_rate: Dropout rate
+            device: Computation device
+        """
+        self.latent_dim = latent_dim
+        self.gene_dim = gene_dim
+        self.hidden_dims = hidden_dims
+        self.bottleneck_dim = bottleneck_dim
+        self.num_experts = num_experts
+        self.dropout_rate = dropout_rate
+        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+
+        # Initialize model with corrected architecture
+        self.model = self._build_corrected_model()
+        self.model.to(self.device)
+
+        # Training state
+        self.is_trained = False
+        self.training_history = None
+        self.best_val_loss = float('inf')
+
+        print(f"🚀 EfficientTranscriptomeDecoder Initialized:")
+        print(f"   - Latent Dimension: {latent_dim}")
+        print(f"   - Gene Dimension: {gene_dim}")
+        print(f"   - Hidden Dimensions: {hidden_dims}")
+        print(f"   - Bottleneck Dimension: {bottleneck_dim}")
+        print(f"   - Number of Experts: {num_experts}")
+        print(f"   - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
+
+    class CorrectedRMSNorm(nn.Module):
+        """Corrected RMS Normalization with proper dimension handling"""
+        def __init__(self, dim: int, eps: float = 1e-8):
+            super().__init__()
+            self.eps = eps
+            self.dim = dim
+            self.weight = nn.Parameter(torch.ones(dim))  # Correct: weight has same dim as input
+
+        def forward(self, x):
+            # Ensure input has the right dimension
+            if x.size(-1) != self.dim:
+                raise ValueError(f"Input dimension {x.size(-1)} doesn't match RMSNorm dimension {self.dim}")
+
+            # Calculate RMS
+            rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
+            # Normalize and apply weight
+            return x / rms * self.weight
+
+    class SimplifiedSwiGLU(nn.Module):
+        """Simplified SwiGLU activation"""
+        def forward(self, x):
+            # Split into two parts
+            x, gate = x.chunk(2, dim=-1)
+            return x * F.silu(gate)
+
+    class MemoryEfficientBottleneck(nn.Module):
+        """Memory-efficient bottleneck with corrected dimensions"""
+        def __init__(self, input_dim: int, bottleneck_dim: int, output_dim: int):
+            super().__init__()
+            # Ensure proper dimension matching
+            self.compress = nn.Linear(input_dim, bottleneck_dim)
+            self.norm1 = EfficientTranscriptomeDecoder.CorrectedRMSNorm(bottleneck_dim)
+            self.expand = nn.Linear(bottleneck_dim, output_dim)
+            self.norm2 = EfficientTranscriptomeDecoder.CorrectedRMSNorm(output_dim)
+            self.activation = nn.SiLU()
+            self.dropout = nn.Dropout(0.1)
+
+        def forward(self, x):
+            # Compress
+            compressed = self.compress(x)
+            compressed = self.norm1(compressed)
+            compressed = self.activation(compressed)
+            compressed = self.dropout(compressed)
+
+            # Expand
+            expanded = self.expand(compressed)
+            expanded = self.norm2(expanded)
+
+            return expanded
+
+    class StableMixtureOfExperts(nn.Module):
+        """Stable mixture of experts without dimension issues"""
+        def __init__(self, input_dim: int, num_experts: int = 4):
+            super().__init__()
+            self.num_experts = num_experts
+            self.input_dim = input_dim
+
+            # Shared expert with different scaling factors
+            self.shared_expert = nn.Sequential(
+                nn.Linear(input_dim, input_dim * 2),
+                nn.SiLU(),
+                nn.Dropout(0.1),
+                nn.Linear(input_dim * 2, input_dim)
+            )
+
+            # Gating network
+            self.gate = nn.Sequential(
+                nn.Linear(input_dim, num_experts * 4),
+                nn.SiLU(),
+                nn.Linear(num_experts * 4, num_experts)
+            )
+
+        def forward(self, x):
+            # Get gate weights
+            gate_weights = F.softmax(self.gate(x), dim=-1)  # [batch, num_experts]
+
+            # Process through shared expert
+            expert_output = self.shared_expert(x)  # [batch, input_dim]
+
+            # Apply expert-specific scaling
+            weighted_output = torch.zeros_like(expert_output)
+            for i in range(self.num_experts):
+                expert_scale = 0.5 + 0.5 * i  # Different scaling for each expert
+                expert_contribution = expert_output * expert_scale
+                expert_weight = gate_weights[:, i].unsqueeze(-1)  # [batch, 1]
+                weighted_output += expert_weight * expert_contribution
+
+            # Residual connection
+            return x + weighted_output
+
+    class CorrectedDecoder(nn.Module):
+        """Corrected decoder with proper dimension handling"""
+
+        def __init__(self, latent_dim: int, gene_dim: int, hidden_dims: List[int],
+                     bottleneck_dim: int, num_experts: int, dropout_rate: float):
+            super().__init__()
+
+            # Input projection
+            self.input_projection = nn.Sequential(
+                nn.Linear(latent_dim, hidden_dims[0]),
+                EfficientTranscriptomeDecoder.CorrectedRMSNorm(hidden_dims[0]),
+                nn.SiLU(),
+                nn.Dropout(dropout_rate)
+            )
+
+            # Main processing blocks
+            self.blocks = nn.ModuleList()
+            current_dim = hidden_dims[0]
+
+            for i, next_dim in enumerate(hidden_dims[1:], 1):
+                block = nn.ModuleDict({
+                    'swiglu': nn.Sequential(
+                        nn.Linear(current_dim, current_dim * 2),
+                        EfficientTranscriptomeDecoder.SimplifiedSwiGLU(),
+                        nn.Dropout(dropout_rate),
+                        nn.Linear(current_dim, current_dim)  # Project back to same dimension
+                    ),
+                    'bottleneck': EfficientTranscriptomeDecoder.MemoryEfficientBottleneck(
+                        current_dim, bottleneck_dim, next_dim
+                    ),
+                    'experts': EfficientTranscriptomeDecoder.StableMixtureOfExperts(
+                        next_dim, num_experts
+                    )
+                })
+                self.blocks.append(block)
+                current_dim = next_dim
+
+            # Final projection to gene dimension
+            self.final_projection = nn.Sequential(
+                nn.Linear(current_dim, current_dim * 2),
+                nn.SiLU(),
+                nn.Dropout(dropout_rate),
+                nn.Linear(current_dim * 2, gene_dim)
+            )
+
+            # Output parameters
+            self.output_scale = nn.Parameter(torch.ones(1))
+            self.output_bias = nn.Parameter(torch.zeros(1))
+
+            self._init_weights()
+
+        def _init_weights(self):
+            """Proper weight initialization"""
+            for module in self.modules():
+                if isinstance(module, nn.Linear):
+                    nn.init.xavier_uniform_(module.weight)
+                    if module.bias is not None:
+                        nn.init.zeros_(module.bias)
+
+        def forward(self, x):
+            # Input projection
+            x = self.input_projection(x)
+
+            # Process through blocks
+            for block in self.blocks:
+                # SwiGLU with residual
+                residual = x
+                x_swiglu = block['swiglu'](x)
+                x = x + x_swiglu  # Residual connection
+
+                # Bottleneck
+                x = block['bottleneck'](x)
+
+                # Mixture of Experts with residual
+                x = block['experts'](x)
+
+            # Final projection
+            x = self.final_projection(x)
+
+            # Ensure non-negative output
+            x = F.softplus(x * self.output_scale + self.output_bias)
+
+            return x
+
+    def _build_corrected_model(self):
+        """Build the corrected model"""
+        return self.CorrectedDecoder(
+            self.latent_dim, self.gene_dim, self.hidden_dims,
+            self.bottleneck_dim, self.num_experts, self.dropout_rate
+        )
+
+    def train(self,
+              train_latent: np.ndarray,
+              train_expression: np.ndarray,
+              val_latent: np.ndarray = None,
+              val_expression: np.ndarray = None,
+              batch_size: int = 32,
+              num_epochs: int = 100,
+              learning_rate: float = 1e-4,
+              checkpoint_path: str = 'transcriptome_decoder.pth') -> Dict:
+        """
+        Train the corrected decoder
+        """
+        print("🚀 Starting Training...")
+
+        # Data preparation
+        train_dataset = self._create_dataset(train_latent, train_expression)
+
+        if val_latent is not None and val_expression is not None:
+            val_dataset = self._create_dataset(val_latent, val_expression)
+            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
+            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
+            print(f"📈 Using provided validation data: {len(val_dataset)} samples")
+        else:
+            # Auto split
+            train_size = int(0.9 * len(train_dataset))
+            val_size = len(train_dataset) - train_size
+            train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
+            train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, pin_memory=True)
+            val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, pin_memory=True)
+            print(f"📈 Auto-split validation: {val_size} samples")
+
+        print(f"📊 Training samples: {len(train_loader.dataset)}")
+        print(f"📊 Validation samples: {len(val_loader.dataset)}")
+        print(f"📊 Batch size: {batch_size}")
+
+        # Optimizer
+        optimizer = optim.AdamW(
+            self.model.parameters(),
+            lr=learning_rate,
+            weight_decay=0.01,
+            betas=(0.9, 0.999)
+        )
+
+        # Scheduler
+        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
+
+        # Loss function
+        def combined_loss(pred, target):
+            mse_loss = F.mse_loss(pred, target)
+            poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()
+            correlation = self._pearson_correlation(pred, target)
+            correlation_loss = 1 - correlation
+            return mse_loss + 0.3 * poisson_loss + 0.1 * correlation_loss
+
+        # Training history
+        history = {
+            'train_loss': [], 'val_loss': [],
+            'train_mse': [], 'val_mse': [],
+            'train_correlation': [], 'val_correlation': [],
+            'learning_rates': []
+        }
+
+        best_val_loss = float('inf')
+        patience = 20
+        patience_counter = 0
+
+        print("\n📈 Starting training loop...")
+        for epoch in range(1, num_epochs + 1):
+            # Training
+            train_metrics = self._train_epoch(train_loader, optimizer, combined_loss)
+
+            # Validation
+            val_metrics = self._validate_epoch(val_loader, combined_loss)
+
+            # Update scheduler
+            scheduler.step()
+            current_lr = optimizer.param_groups[0]['lr']
+
+            # Record history
+            history['train_loss'].append(train_metrics['loss'])
+            history['val_loss'].append(val_metrics['loss'])
+            history['train_mse'].append(train_metrics['mse'])
+            history['val_mse'].append(val_metrics['mse'])
+            history['train_correlation'].append(train_metrics['correlation'])
+            history['val_correlation'].append(val_metrics['correlation'])
+            history['learning_rates'].append(current_lr)
+
+            # Print progress
+            if epoch % 10 == 0 or epoch == 1:
+                print(f"📍 Epoch {epoch:3d}/{num_epochs} | "
+                      f"Train Loss: {train_metrics['loss']:.4f} | "
+                      f"Val Loss: {val_metrics['loss']:.4f} | "
+                      f"Correlation: {val_metrics['correlation']:.4f} | "
+                      f"LR: {current_lr:.2e}")
+
+            # Early stopping
+            if val_metrics['loss'] < best_val_loss:
+                best_val_loss = val_metrics['loss']
+                patience_counter = 0
+                self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
+                if epoch % 20 == 0:
+                    print(f"💾 Best model saved (Val Loss: {best_val_loss:.4f})")
+            else:
+                patience_counter += 1
+                if patience_counter >= patience:
+                    print(f"🛑 Early stopping at epoch {epoch}")
+                    break
+
+        # Training completed
+        self.is_trained = True
+        self.training_history = history
+        self.best_val_loss = best_val_loss
+
+        print(f"\n🎉 Training completed!")
+        print(f"🏆 Best validation loss: {best_val_loss:.4f}")
+
+        return history
+
+    def _create_dataset(self, latent_data, expression_data):
+        """Create dataset"""
+        class SimpleDataset(Dataset):
+            def __init__(self, latent, expression):
+                self.latent = torch.FloatTensor(latent)
+                self.expression = torch.FloatTensor(expression)
+
+            def __len__(self):
+                return len(self.latent)
+
+            def __getitem__(self, idx):
+                return self.latent[idx], self.expression[idx]
+
+        return SimpleDataset(latent_data, expression_data)
+
+    def _pearson_correlation(self, pred, target):
+        """Calculate Pearson correlation"""
+        pred_centered = pred - pred.mean(dim=1, keepdim=True)
+        target_centered = target - target.mean(dim=1, keepdim=True)
+
+        numerator = (pred_centered * target_centered).sum(dim=1)
+        denominator = torch.sqrt(torch.sum(pred_centered ** 2, dim=1)) * torch.sqrt(torch.sum(target_centered ** 2, dim=1))
+
+        return (numerator / (denominator + 1e-8)).mean()
+
+    def _train_epoch(self, train_loader, optimizer, loss_fn):
+        """Train one epoch"""
+        self.model.train()
+        total_loss = 0
+        total_mse = 0
+        total_correlation = 0
+
+        for latent, target in train_loader:
+            latent = latent.to(self.device, non_blocking=True)
+            target = target.to(self.device, non_blocking=True)
+
+            optimizer.zero_grad()
+            pred = self.model(latent)
+
+            loss = loss_fn(pred, target)
+            loss.backward()
+
+            # Gradient clipping
+            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+            optimizer.step()
+
+            # Calculate metrics
+            mse_loss = F.mse_loss(pred, target).item()
+            correlation = self._pearson_correlation(pred, target).item()
+
+            total_loss += loss.item()
+            total_mse += mse_loss
+            total_correlation += correlation
+
+        num_batches = len(train_loader)
+        return {
+            'loss': total_loss / num_batches,
+            'mse': total_mse / num_batches,
+            'correlation': total_correlation / num_batches
+        }
+
+    def _validate_epoch(self, val_loader, loss_fn):
+        """Validate one epoch"""
+        self.model.eval()
+        total_loss = 0
+        total_mse = 0
+        total_correlation = 0
+
+        with torch.no_grad():
+            for latent, target in val_loader:
+                latent = latent.to(self.device, non_blocking=True)
+                target = target.to(self.device, non_blocking=True)
+
+                pred = self.model(latent)
+                loss = loss_fn(pred, target)
+                mse_loss = F.mse_loss(pred, target).item()
+                correlation = self._pearson_correlation(pred, target).item()
+
+                total_loss += loss.item()
+                total_mse += mse_loss
+                total_correlation += correlation
+
+        num_batches = len(val_loader)
+        return {
+            'loss': total_loss / num_batches,
+            'mse': total_mse / num_batches,
+            'correlation': total_correlation / num_batches
+        }
+
+    def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
+        """Save checkpoint"""
+        torch.save({
+            'epoch': epoch,
+            'model_state_dict': self.model.state_dict(),
+            'optimizer_state_dict': optimizer.state_dict(),
+            'scheduler_state_dict': scheduler.state_dict(),
+            'best_val_loss': best_loss,
+            'training_history': history,
+            'model_config': {
+                'latent_dim': self.latent_dim,
+                'gene_dim': self.gene_dim,
+                'hidden_dims': self.hidden_dims,
+                'bottleneck_dim': self.bottleneck_dim,
+                'num_experts': self.num_experts
+            }
+        }, path)
+
+    def predict(self, latent_data: np.ndarray, batch_size: int = 32) -> np.ndarray:
+        """Predict gene expression"""
+        if not self.is_trained:
+            warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")
+
+        self.model.eval()
+
+        if isinstance(latent_data, np.ndarray):
+            latent_data = torch.FloatTensor(latent_data)
+
+        predictions = []
+        with torch.no_grad():
+            for i in range(0, len(latent_data), batch_size):
+                batch_latent = latent_data[i:i+batch_size].to(self.device)
+                batch_pred = self.model(batch_latent)
+                predictions.append(batch_pred.cpu())
+
+        return torch.cat(predictions).numpy()
+
+    def load_model(self, model_path: str):
+        """Load pre-trained model"""
+        checkpoint = torch.load(model_path, map_location=self.device)
+        self.model.load_state_dict(checkpoint['model_state_dict'])
+        self.is_trained = True
+        self.training_history = checkpoint.get('training_history')
+        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
+        print(f"✅ Model loaded! Best val loss: {self.best_val_loss:.4f}")
+
+    def get_model_info(self) -> Dict:
+        """Get model information"""
+        return {
+            'is_trained': self.is_trained,
+            'best_val_loss': self.best_val_loss,
+            'parameters': sum(p.numel() for p in self.model.parameters()),
+            'latent_dim': self.latent_dim,
+            'gene_dim': self.gene_dim,
+            'hidden_dims': self.hidden_dims,
+            'device': str(self.device)
+        }
+
+'''
+# Example usage
+def example_usage():
+    """Example demonstration"""
+
+    # Initialize decoder
+    decoder = EfficientTranscriptomeDecoder(
+        latent_dim=100,
+        gene_dim=2000,  # Reduced for example
+        hidden_dims=[256, 512, 1024],
+        bottleneck_dim=128,
+        num_experts=4,
+        dropout_rate=0.1
+    )
+
+    # Generate example data
+    n_samples = 1000
+    latent_data = np.random.randn(n_samples, 100).astype(np.float32)
+
+    # Simulate expression data
+    weights = np.random.randn(100, 2000) * 0.1
+    expression_data = np.tanh(latent_data.dot(weights))
+    expression_data = np.maximum(expression_data, 0)
+
+    print(f"📊 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")
+
+    # Train
+    history = decoder.train(
+        train_latent=latent_data,
+        train_expression=expression_data,
+        batch_size=32,
+        num_epochs=50
+    )
+
+    # Predict
+    test_latent = np.random.randn(10, 100).astype(np.float32)
+    predictions = decoder.predict(test_latent)
+    print(f"🔮 Prediction shape: {predictions.shape}")
+
+    return decoder
+
+if __name__ == "__main__":
+    example_usage()
+'''
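
Note: two building blocks in the new file are easy to sanity-check in isolation. The sketch below (plain torch only, no package dependencies) verifies that the RMS normalization used by CorrectedRMSNorm produces unit root-mean-square rows before the learned weight is applied, and that the chunk-based SimplifiedSwiGLU halves its input width, which is why each block wraps it between Linear(d, 2d) and Linear(d, d):

    import torch
    import torch.nn.functional as F

    # RMSNorm core: x / sqrt(mean(x**2) + eps) has unit RMS per row.
    x = torch.randn(8, 512)
    rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + 1e-8)
    y = x / rms
    print(torch.sqrt((y ** 2).mean(dim=-1)))  # ~1.0 for every row

    # SimplifiedSwiGLU: a width-2d input is chunked into (value, gate) and
    # returns at width d after value * silu(gate).
    h = torch.randn(8, 2 * 512)
    value, gate = h.chunk(2, dim=-1)
    out = value * F.silu(gate)
    print(out.shape)  # torch.Size([8, 512])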