SURE-tools 2.2.27.tar.gz → 2.4.20.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sure_tools-2.2.27 → sure_tools-2.4.20}/PKG-INFO +1 -1
- {sure_tools-2.2.27 → sure_tools-2.4.20}/SURE/DensityFlow.py +22 -8
- sure_tools-2.4.20/SURE/EfficientTranscriptomeDecoder.py +552 -0
- sure_tools-2.4.20/SURE/PerturbE.py +1300 -0
- sure_tools-2.4.20/SURE/PerturbationAwareDecoder.py +770 -0
- sure_tools-2.4.20/SURE/SimpleTranscriptomeDecoder.py +567 -0
- sure_tools-2.4.20/SURE/TranscriptomeDecoder.py +511 -0
- sure_tools-2.4.20/SURE/VirtualCellDecoder.py +658 -0
- sure_tools-2.4.20/SURE/__init__.py +26 -0
- {sure_tools-2.2.27 → sure_tools-2.4.20}/SURE/utils/custom_mlp.py +31 -1
- {sure_tools-2.2.27 → sure_tools-2.4.20}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.2.27 → sure_tools-2.4.20}/SURE_tools.egg-info/SOURCES.txt +6 -0
- {sure_tools-2.2.27 → sure_tools-2.4.20}/setup.py +1 -1
- sure_tools-2.2.27/SURE/__init__.py +0 -12
- {sure_tools-2.2.27 → sure_tools-2.4.20}/LICENSE +0 -0
- {sure_tools-2.2.27 → sure_tools-2.4.20}/README.md +0 -0
- {sure_tools-2.2.27 → sure_tools-2.4.20}/SURE/SURE.py +0 -0
- {sure_tools-2.2.27 → sure_tools-2.4.20}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.2.27 → sure_tools-2.4.20}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.2.27 → sure_tools-2.4.20}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.2.27 → sure_tools-2.4.20}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.2.27 → sure_tools-2.4.20}/SURE/atac/utils.py +0 -0
- {sure_tools-2.2.27 → sure_tools-2.4.20}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.2.27 → sure_tools-2.4.20}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.2.27 → sure_tools-2.4.20}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.2.27 → sure_tools-2.4.20}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.2.27 → sure_tools-2.4.20}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.2.27 → sure_tools-2.4.20}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.2.27 → sure_tools-2.4.20}/SURE/perturb/perturb.py +0 -0
- {sure_tools-2.2.27 → sure_tools-2.4.20}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.2.27 → sure_tools-2.4.20}/SURE/utils/queue.py +0 -0
- {sure_tools-2.2.27 → sure_tools-2.4.20}/SURE/utils/utils.py +0 -0
- {sure_tools-2.2.27 → sure_tools-2.4.20}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.2.27 → sure_tools-2.4.20}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.2.27 → sure_tools-2.4.20}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.2.27 → sure_tools-2.4.20}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.2.27 → sure_tools-2.4.20}/setup.cfg +0 -0
SURE/DensityFlow.py:

@@ -109,6 +109,13 @@ class DensityFlow(nn.Module):
 
         set_random_seed(seed)
         self.setup_networks()
+
+        print(f"🧬 DensityFlow Initialized:")
+        print(f"   - Latent Dimension: {self.latent_dim}")
+        print(f"   - Gene Dimension: {self.input_size}")
+        print(f"   - Hidden Dimensions: {self.hidden_layers}")
+        print(f"   - Device: {self.get_device()}")
+        print(f"   - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
 
     def setup_networks(self):
         latent_dim = self.latent_dim
@@ -396,7 +403,8 @@ class DensityFlow(nn.Module):
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
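The probs→logits switch above recurs at four sampling sites in this file. Multinomial(logits=...) normalizes internally via softmax, so passing the unnormalized concentrate is equivalent to passing probs=softmax(concentrate) while skipping the explicit normalization. A minimal sketch in plain torch.distributions (which Pyro's dist wraps; toy values, not SURE's data):

```python
import torch
from torch.distributions import Multinomial

concentrate = torch.randn(5)                 # unnormalized per-gene scores
theta = torch.softmax(concentrate, dim=-1)   # normalized probabilities

x = torch.tensor([20., 30., 10., 25., 15.])  # counts summing to total_count
lp_probs = Multinomial(total_count=100, probs=theta).log_prob(x)
lp_logits = Multinomial(total_count=100, logits=concentrate).log_prob(x)
assert torch.allclose(lp_probs, lp_logits)   # same likelihood either way
```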
@@ -480,7 +488,8 @@ class DensityFlow(nn.Module):
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -576,7 +585,8 @@ class DensityFlow(nn.Module):
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -682,7 +692,8 @@ class DensityFlow(nn.Module):
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -946,6 +957,9 @@ class DensityFlow(nn.Module):
         if self.loss_func == 'bernoulli':
             #counts = self.sigmoid(concentrate)
             counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
+        elif self.loss_func == 'multinomial':
+            theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
+            counts = theta * library_size
         else:
             rate = concentrate.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
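The new multinomial branch recovers expected counts through Multinomial.mean, which for this parameterization is total_count * softmax(logits). A quick standalone check (plain PyTorch, toy values, not the package's code path):

```python
import torch
from torch.distributions import Multinomial

concentrate = torch.randn(4)
mean_counts = Multinomial(total_count=100, logits=concentrate).mean
assert torch.allclose(mean_counts, 100 * torch.softmax(concentrate, dim=-1))
```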
@@ -1340,7 +1354,7 @@ def main():
     cell_factor_size = 0 if us is None else us.shape[1]
 
     ###########################################
-
+    df = DensityFlow(
         input_size=input_size,
         cell_factor_size=cell_factor_size,
         inverse_dispersion=args.inverse_dispersion,
@@ -1359,7 +1373,7 @@ def main():
         dtype=dtype,
     )
 
-
+    df.fit(xs, us=us,
            num_epochs=args.num_epochs,
            learning_rate=args.learning_rate,
            batch_size=args.batch_size,
@@ -1371,9 +1385,9 @@ def main():
 
     if args.save_model is not None:
         if args.save_model.endswith('gz'):
-            DensityFlow.save_model(
+            DensityFlow.save_model(df, args.save_model, compression=True)
         else:
-            DensityFlow.save_model(
+            DensityFlow.save_model(df, args.save_model)
 
 
 
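End to end, the main() changes construct the model as df, fit it, and pass the fitted instance to save_model. A condensed sketch of that flow (toy data; keyword values stand in for the argparse options shown in the diff, and the import path assumes the 2.4.x package exports):

```python
import numpy as np
from SURE import DensityFlow   # assumed export, per the updated SURE/__init__.py

xs = np.random.poisson(2.0, size=(128, 500)).astype(np.float32)  # toy count matrix

df = DensityFlow(input_size=xs.shape[1], cell_factor_size=0)
df.fit(xs, us=None, num_epochs=10, learning_rate=1e-3, batch_size=32)

# Per the diff, save_model takes the fitted instance as its first argument,
# with compression=True for .gz paths; the filename here is illustrative.
DensityFlow.save_model(df, 'densityflow.pkl.gz', compression=True)
```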
SURE/EfficientTranscriptomeDecoder.py (new file):

@@ -0,0 +1,552 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import Dataset, DataLoader
+import numpy as np
+from typing import Dict, List, Optional, Tuple
+import math
+import warnings
+warnings.filterwarnings('ignore')
+
+class EfficientTranscriptomeDecoder:
+    """
+    High-performance, memory-efficient transcriptome decoder
+    Fixed version with corrected RMSNorm implementation
+    """
+
+    def __init__(self,
+                 latent_dim: int = 100,
+                 gene_dim: int = 60000,
+                 hidden_dims: List[int] = [512, 1024, 2048],
+                 bottleneck_dim: int = 256,
+                 num_experts: int = 4,
+                 dropout_rate: float = 0.1,
+                 device: str = None):
+        """
+        Advanced decoder combining multiple state-of-the-art techniques
+
+        Args:
+            latent_dim: Latent variable dimension
+            gene_dim: Number of genes (full transcriptome)
+            hidden_dims: Hidden layer dimensions
+            bottleneck_dim: Bottleneck dimension for memory efficiency
+            num_experts: Number of mixture-of-experts
+            dropout_rate: Dropout rate
+            device: Computation device
+        """
+        self.latent_dim = latent_dim
+        self.gene_dim = gene_dim
+        self.hidden_dims = hidden_dims
+        self.bottleneck_dim = bottleneck_dim
+        self.num_experts = num_experts
+        self.dropout_rate = dropout_rate
+        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+
+        # Initialize model with corrected architecture
+        self.model = self._build_corrected_model()
+        self.model.to(self.device)
+
+        # Training state
+        self.is_trained = False
+        self.training_history = None
+        self.best_val_loss = float('inf')
+
+        print(f"🚀 EfficientTranscriptomeDecoder Initialized:")
+        print(f"   - Latent Dimension: {latent_dim}")
+        print(f"   - Gene Dimension: {gene_dim}")
+        print(f"   - Hidden Dimensions: {hidden_dims}")
+        print(f"   - Bottleneck Dimension: {bottleneck_dim}")
+        print(f"   - Number of Experts: {num_experts}")
+        print(f"   - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
+
+    class CorrectedRMSNorm(nn.Module):
+        """Corrected RMS Normalization with proper dimension handling"""
+        def __init__(self, dim: int, eps: float = 1e-8):
+            super().__init__()
+            self.eps = eps
+            self.dim = dim
+            self.weight = nn.Parameter(torch.ones(dim)) # Correct: weight has same dim as input
+
+        def forward(self, x):
+            # Ensure input has the right dimension
+            if x.size(-1) != self.dim:
+                raise ValueError(f"Input dimension {x.size(-1)} doesn't match RMSNorm dimension {self.dim}")
+
+            # Calculate RMS
+            rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
+            # Normalize and apply weight
+            return x / rms * self.weight
+
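As a sanity check on the RMSNorm arithmetic above: with unit weights, each row of the output has root-mean-square ≈ 1. A standalone sketch (toy tensors, independent of the class):

```python
import torch

x = torch.randn(4, 8)
eps = 1e-8
rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
y = x / rms * torch.ones(8)    # what CorrectedRMSNorm.forward computes
print(y.pow(2).mean(dim=-1))   # each entry ≈ 1.0
```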
+    class SimplifiedSwiGLU(nn.Module):
+        """Simplified SwiGLU activation"""
+        def forward(self, x):
+            # Split into two parts
+            x, gate = x.chunk(2, dim=-1)
+            return x * F.silu(gate)
+
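Note that SimplifiedSwiGLU halves the last dimension, which is why the decoder blocks below pair it with a Linear(d, 2*d). A quick shape check (toy sizes):

```python
import torch
import torch.nn.functional as F

h = torch.randn(3, 16)        # e.g. the output of nn.Linear(8, 16)
x, gate = h.chunk(2, dim=-1)  # two halves of width 8
out = x * F.silu(gate)
print(out.shape)              # torch.Size([3, 8])
```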
+    class MemoryEfficientBottleneck(nn.Module):
+        """Memory-efficient bottleneck with corrected dimensions"""
+        def __init__(self, input_dim: int, bottleneck_dim: int, output_dim: int):
+            super().__init__()
+            # Ensure proper dimension matching
+            self.compress = nn.Linear(input_dim, bottleneck_dim)
+            self.norm1 = EfficientTranscriptomeDecoder.CorrectedRMSNorm(bottleneck_dim)
+            self.expand = nn.Linear(bottleneck_dim, output_dim)
+            self.norm2 = EfficientTranscriptomeDecoder.CorrectedRMSNorm(output_dim)
+            self.activation = nn.SiLU()
+            self.dropout = nn.Dropout(0.1)
+
+        def forward(self, x):
+            # Compress
+            compressed = self.compress(x)
+            compressed = self.norm1(compressed)
+            compressed = self.activation(compressed)
+            compressed = self.dropout(compressed)
+
+            # Expand
+            expanded = self.expand(compressed)
+            expanded = self.norm2(expanded)
+
+            return expanded
+
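The memory saving comes from factorizing one wide projection into a compress/expand pair. A back-of-the-envelope weight count (bias terms ignored; example dims chosen to match one block at the class's default scale):

```python
d_in, b, d_out = 1024, 256, 2048     # input, bottleneck, output widths
direct = d_in * d_out                # single Linear: 2,097,152 weights
factored = d_in * b + b * d_out      # compress + expand: 786,432 weights
print(direct / factored)             # ≈ 2.7x fewer parameters
```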
+    class StableMixtureOfExperts(nn.Module):
+        """Stable mixture of experts without dimension issues"""
+        def __init__(self, input_dim: int, num_experts: int = 4):
+            super().__init__()
+            self.num_experts = num_experts
+            self.input_dim = input_dim
+
+            # Shared expert with different scaling factors
+            self.shared_expert = nn.Sequential(
+                nn.Linear(input_dim, input_dim * 2),
+                nn.SiLU(),
+                nn.Dropout(0.1),
+                nn.Linear(input_dim * 2, input_dim)
+            )
+
+            # Gating network
+            self.gate = nn.Sequential(
+                nn.Linear(input_dim, num_experts * 4),
+                nn.SiLU(),
+                nn.Linear(num_experts * 4, num_experts)
+            )
+
+        def forward(self, x):
+            # Get gate weights
+            gate_weights = F.softmax(self.gate(x), dim=-1) # [batch, num_experts]
+
+            # Process through shared expert
+            expert_output = self.shared_expert(x) # [batch, input_dim]
+
+            # Apply expert-specific scaling
+            weighted_output = torch.zeros_like(expert_output)
+            for i in range(self.num_experts):
+                expert_scale = 0.5 + 0.5 * i # Different scaling for each expert
+                expert_contribution = expert_output * expert_scale
+                expert_weight = gate_weights[:, i].unsqueeze(-1) # [batch, 1]
+                weighted_output += expert_weight * expert_contribution
+
+            # Residual connection
+            return x + weighted_output
+
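Since all "experts" here share one network and differ only by the scalar 0.5 + 0.5*i, the loop above collapses algebraically to scaling the shared output by a single gate-weighted coefficient per sample. A vectorized equivalent (toy shapes, for readers inspecting the design):

```python
import torch

batch, dim, n_exp = 3, 8, 4
gate_weights = torch.softmax(torch.randn(batch, n_exp), dim=-1)
expert_output = torch.randn(batch, dim)         # stands in for shared_expert(x)
scales = 0.5 + 0.5 * torch.arange(n_exp)        # per-expert scalar
coef = (gate_weights * scales).sum(dim=-1, keepdim=True)
weighted = coef * expert_output                 # equals the loop's result
```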
+    class CorrectedDecoder(nn.Module):
+        """Corrected decoder with proper dimension handling"""
+
+        def __init__(self, latent_dim: int, gene_dim: int, hidden_dims: List[int],
+                     bottleneck_dim: int, num_experts: int, dropout_rate: float):
+            super().__init__()
+
+            # Input projection
+            self.input_projection = nn.Sequential(
+                nn.Linear(latent_dim, hidden_dims[0]),
+                EfficientTranscriptomeDecoder.CorrectedRMSNorm(hidden_dims[0]),
+                nn.SiLU(),
+                nn.Dropout(dropout_rate)
+            )
+
+            # Main processing blocks
+            self.blocks = nn.ModuleList()
+            current_dim = hidden_dims[0]
+
+            for i, next_dim in enumerate(hidden_dims[1:], 1):
+                block = nn.ModuleDict({
+                    'swiglu': nn.Sequential(
+                        nn.Linear(current_dim, current_dim * 2),
+                        EfficientTranscriptomeDecoder.SimplifiedSwiGLU(),
+                        nn.Dropout(dropout_rate),
+                        nn.Linear(current_dim, current_dim) # Project back to same dimension
+                    ),
+                    'bottleneck': EfficientTranscriptomeDecoder.MemoryEfficientBottleneck(
+                        current_dim, bottleneck_dim, next_dim
+                    ),
+                    'experts': EfficientTranscriptomeDecoder.StableMixtureOfExperts(
+                        next_dim, num_experts
+                    )
+                })
+                self.blocks.append(block)
+                current_dim = next_dim
+
+            # Final projection to gene dimension
+            self.final_projection = nn.Sequential(
+                nn.Linear(current_dim, current_dim * 2),
+                nn.SiLU(),
+                nn.Dropout(dropout_rate),
+                nn.Linear(current_dim * 2, gene_dim)
+            )
+
+            # Output parameters
+            self.output_scale = nn.Parameter(torch.ones(1))
+            self.output_bias = nn.Parameter(torch.zeros(1))
+
+            self._init_weights()
+
+        def _init_weights(self):
+            """Proper weight initialization"""
+            for module in self.modules():
+                if isinstance(module, nn.Linear):
+                    nn.init.xavier_uniform_(module.weight)
+                    if module.bias is not None:
+                        nn.init.zeros_(module.bias)
+
+        def forward(self, x):
+            # Input projection
+            x = self.input_projection(x)
+
+            # Process through blocks
+            for block in self.blocks:
+                # SwiGLU with residual
+                residual = x
+                x_swiglu = block['swiglu'](x)
+                x = x + x_swiglu # Residual connection
+
+                # Bottleneck
+                x = block['bottleneck'](x)
+
+                # Mixture of Experts with residual
+                x = block['experts'](x)
+
+            # Final projection
+            x = self.final_projection(x)
+
+            # Ensure non-negative output
+            x = F.softplus(x * self.output_scale + self.output_bias)
+
+            return x
+
+    def _build_corrected_model(self):
+        """Build the corrected model"""
+        return self.CorrectedDecoder(
+            self.latent_dim, self.gene_dim, self.hidden_dims,
+            self.bottleneck_dim, self.num_experts, self.dropout_rate
+        )
+
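Putting the pieces together, a quick shape trace through the assembled model (small hypothetical dims; assumes the class definitions above are importable):

```python
import torch

dec = EfficientTranscriptomeDecoder(latent_dim=16, gene_dim=200,
                                    hidden_dims=[32, 64], bottleneck_dim=24,
                                    num_experts=2, dropout_rate=0.1)
z = torch.randn(5, 16)
out = dec.model(z.to(dec.device))  # CorrectedDecoder from _build_corrected_model
print(out.shape)                   # torch.Size([5, 200]); softplus keeps values >= 0
```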
+    def train(self,
+              train_latent: np.ndarray,
+              train_expression: np.ndarray,
+              val_latent: np.ndarray = None,
+              val_expression: np.ndarray = None,
+              batch_size: int = 32,
+              num_epochs: int = 100,
+              learning_rate: float = 1e-4,
+              checkpoint_path: str = 'transcriptome_decoder.pth') -> Dict:
+        """
+        Train the corrected decoder
+        """
+        print("🚀 Starting Training...")
+
+        # Data preparation
+        train_dataset = self._create_dataset(train_latent, train_expression)
+
+        if val_latent is not None and val_expression is not None:
+            val_dataset = self._create_dataset(val_latent, val_expression)
+            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
+            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
+            print(f"📈 Using provided validation data: {len(val_dataset)} samples")
+        else:
+            # Auto split
+            train_size = int(0.9 * len(train_dataset))
+            val_size = len(train_dataset) - train_size
+            train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
+            train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, pin_memory=True)
+            val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, pin_memory=True)
+            print(f"📈 Auto-split validation: {val_size} samples")
+
+        print(f"📊 Training samples: {len(train_loader.dataset)}")
+        print(f"📊 Validation samples: {len(val_loader.dataset)}")
+        print(f"📊 Batch size: {batch_size}")
+
+        # Optimizer
+        optimizer = optim.AdamW(
+            self.model.parameters(),
+            lr=learning_rate,
+            weight_decay=0.01,
+            betas=(0.9, 0.999)
+        )
+
+        # Scheduler
+        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
+
+        # Loss function
+        def combined_loss(pred, target):
+            mse_loss = F.mse_loss(pred, target)
+            poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()
+            correlation = self._pearson_correlation(pred, target)
+            correlation_loss = 1 - correlation
+            return mse_loss + 0.3 * poisson_loss + 0.1 * correlation_loss
+
+        # Training history
+        history = {
+            'train_loss': [], 'val_loss': [],
+            'train_mse': [], 'val_mse': [],
+            'train_correlation': [], 'val_correlation': [],
+            'learning_rates': []
+        }
+
+        best_val_loss = float('inf')
+        patience = 20
+        patience_counter = 0
+
+        print("\n📈 Starting training loop...")
+        for epoch in range(1, num_epochs + 1):
+            # Training
+            train_metrics = self._train_epoch(train_loader, optimizer, combined_loss)
+
+            # Validation
+            val_metrics = self._validate_epoch(val_loader, combined_loss)
+
+            # Update scheduler
+            scheduler.step()
+            current_lr = optimizer.param_groups[0]['lr']
+
+            # Record history
+            history['train_loss'].append(train_metrics['loss'])
+            history['val_loss'].append(val_metrics['loss'])
+            history['train_mse'].append(train_metrics['mse'])
+            history['val_mse'].append(val_metrics['mse'])
+            history['train_correlation'].append(train_metrics['correlation'])
+            history['val_correlation'].append(val_metrics['correlation'])
+            history['learning_rates'].append(current_lr)
+
+            # Print progress
+            if epoch % 10 == 0 or epoch == 1:
+                print(f"📍 Epoch {epoch:3d}/{num_epochs} | "
+                      f"Train Loss: {train_metrics['loss']:.4f} | "
+                      f"Val Loss: {val_metrics['loss']:.4f} | "
+                      f"Correlation: {val_metrics['correlation']:.4f} | "
+                      f"LR: {current_lr:.2e}")
+
+            # Early stopping
+            if val_metrics['loss'] < best_val_loss:
+                best_val_loss = val_metrics['loss']
+                patience_counter = 0
+                self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
+                if epoch % 20 == 0:
+                    print(f"💾 Best model saved (Val Loss: {best_val_loss:.4f})")
+            else:
+                patience_counter += 1
+                if patience_counter >= patience:
+                    print(f"🛑 Early stopping at epoch {epoch}")
+                    break
+
+        # Training completed
+        self.is_trained = True
+        self.training_history = history
+        self.best_val_loss = best_val_loss
+
+        print(f"\n🎉 Training completed!")
+        print(f"🏆 Best validation loss: {best_val_loss:.4f}")
+
+        return history
+
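The combined_loss defined inside train() mixes MSE, an unnormalized Poisson negative log-likelihood, and a correlation penalty with fixed 0.3/0.1 weights. Reproduced standalone (toy tensors) for readers who want to inspect it outside the training loop:

```python
import torch
import torch.nn.functional as F

pred = torch.rand(4, 10) + 0.1   # strictly positive, as the decoder's softplus guarantees
target = torch.rand(4, 10)

mse = F.mse_loss(pred, target)
poisson = (pred - target * torch.log(pred + 1e-8)).mean()
pc = pred - pred.mean(dim=1, keepdim=True)
tc = target - target.mean(dim=1, keepdim=True)
corr = ((pc * tc).sum(1) / (pc.pow(2).sum(1).sqrt() * tc.pow(2).sum(1).sqrt() + 1e-8)).mean()
loss = mse + 0.3 * poisson + 0.1 * (1 - corr)
```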
+    def _create_dataset(self, latent_data, expression_data):
+        """Create dataset"""
+        class SimpleDataset(Dataset):
+            def __init__(self, latent, expression):
+                self.latent = torch.FloatTensor(latent)
+                self.expression = torch.FloatTensor(expression)
+
+            def __len__(self):
+                return len(self.latent)
+
+            def __getitem__(self, idx):
+                return self.latent[idx], self.expression[idx]
+
+        return SimpleDataset(latent_data, expression_data)
+
+    def _pearson_correlation(self, pred, target):
+        """Calculate Pearson correlation"""
+        pred_centered = pred - pred.mean(dim=1, keepdim=True)
+        target_centered = target - target.mean(dim=1, keepdim=True)
+
+        numerator = (pred_centered * target_centered).sum(dim=1)
+        denominator = torch.sqrt(torch.sum(pred_centered ** 2, dim=1)) * torch.sqrt(torch.sum(target_centered ** 2, dim=1))
+
+        return (numerator / (denominator + 1e-8)).mean()
+
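_pearson_correlation averages a per-row (per-sample) correlation across the batch; row by row it matches NumPy's corrcoef. A cross-check on toy data:

```python
import numpy as np
import torch

pred, target = torch.rand(3, 50), torch.rand(3, 50)
pc = pred - pred.mean(dim=1, keepdim=True)
tc = target - target.mean(dim=1, keepdim=True)
r = (pc * tc).sum(1) / (pc.pow(2).sum(1).sqrt() * tc.pow(2).sum(1).sqrt() + 1e-8)
r_np = np.array([np.corrcoef(p, t)[0, 1] for p, t in zip(pred.numpy(), target.numpy())])
print(np.allclose(r.numpy(), r_np, atol=1e-5))   # True
```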
+    def _train_epoch(self, train_loader, optimizer, loss_fn):
+        """Train one epoch"""
+        self.model.train()
+        total_loss = 0
+        total_mse = 0
+        total_correlation = 0
+
+        for latent, target in train_loader:
+            latent = latent.to(self.device, non_blocking=True)
+            target = target.to(self.device, non_blocking=True)
+
+            optimizer.zero_grad()
+            pred = self.model(latent)
+
+            loss = loss_fn(pred, target)
+            loss.backward()
+
+            # Gradient clipping
+            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+            optimizer.step()
+
+            # Calculate metrics
+            mse_loss = F.mse_loss(pred, target).item()
+            correlation = self._pearson_correlation(pred, target).item()
+
+            total_loss += loss.item()
+            total_mse += mse_loss
+            total_correlation += correlation
+
+        num_batches = len(train_loader)
+        return {
+            'loss': total_loss / num_batches,
+            'mse': total_mse / num_batches,
+            'correlation': total_correlation / num_batches
+        }
+
+    def _validate_epoch(self, val_loader, loss_fn):
+        """Validate one epoch"""
+        self.model.eval()
+        total_loss = 0
+        total_mse = 0
+        total_correlation = 0
+
+        with torch.no_grad():
+            for latent, target in val_loader:
+                latent = latent.to(self.device, non_blocking=True)
+                target = target.to(self.device, non_blocking=True)
+
+                pred = self.model(latent)
+                loss = loss_fn(pred, target)
+                mse_loss = F.mse_loss(pred, target).item()
+                correlation = self._pearson_correlation(pred, target).item()
+
+                total_loss += loss.item()
+                total_mse += mse_loss
+                total_correlation += correlation
+
+        num_batches = len(val_loader)
+        return {
+            'loss': total_loss / num_batches,
+            'mse': total_mse / num_batches,
+            'correlation': total_correlation / num_batches
+        }
+
+    def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
+        """Save checkpoint"""
+        torch.save({
+            'epoch': epoch,
+            'model_state_dict': self.model.state_dict(),
+            'optimizer_state_dict': optimizer.state_dict(),
+            'scheduler_state_dict': scheduler.state_dict(),
+            'best_val_loss': best_loss,
+            'training_history': history,
+            'model_config': {
+                'latent_dim': self.latent_dim,
+                'gene_dim': self.gene_dim,
+                'hidden_dims': self.hidden_dims,
+                'bottleneck_dim': self.bottleneck_dim,
+                'num_experts': self.num_experts
+            }
+        }, path)
+
+    def predict(self, latent_data: np.ndarray, batch_size: int = 32) -> np.ndarray:
+        """Predict gene expression"""
+        if not self.is_trained:
+            warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")
+
+        self.model.eval()
+
+        if isinstance(latent_data, np.ndarray):
+            latent_data = torch.FloatTensor(latent_data)
+
+        predictions = []
+        with torch.no_grad():
+            for i in range(0, len(latent_data), batch_size):
+                batch_latent = latent_data[i:i+batch_size].to(self.device)
+                batch_pred = self.model(batch_latent)
+                predictions.append(batch_pred.cpu())
+
+        return torch.cat(predictions).numpy()
+
+    def load_model(self, model_path: str):
+        """Load pre-trained model"""
+        checkpoint = torch.load(model_path, map_location=self.device)
+        self.model.load_state_dict(checkpoint['model_state_dict'])
+        self.is_trained = True
+        self.training_history = checkpoint.get('training_history')
+        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
+        print(f"✅ Model loaded! Best val loss: {self.best_val_loss:.4f}")
+
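Checkpoints written by _save_checkpoint during train() can be restored into a fresh instance, provided it is constructed with the same config. A hedged round-trip sketch (the path is the default checkpoint_path in train(); dims are illustrative):

```python
decoder = EfficientTranscriptomeDecoder(latent_dim=100, gene_dim=2000,
                                        hidden_dims=[256, 512, 1024],
                                        bottleneck_dim=128, num_experts=4)
decoder.load_model('transcriptome_decoder.pth')
print(decoder.get_model_info()['best_val_loss'])
```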
+    def get_model_info(self) -> Dict:
+        """Get model information"""
+        return {
+            'is_trained': self.is_trained,
+            'best_val_loss': self.best_val_loss,
+            'parameters': sum(p.numel() for p in self.model.parameters()),
+            'latent_dim': self.latent_dim,
+            'gene_dim': self.gene_dim,
+            'hidden_dims': self.hidden_dims,
+            'device': str(self.device)
+        }
+
+'''
+# Example usage
+def example_usage():
+    """Example demonstration"""
+
+    # Initialize decoder
+    decoder = EfficientTranscriptomeDecoder(
+        latent_dim=100,
+        gene_dim=2000, # Reduced for example
+        hidden_dims=[256, 512, 1024],
+        bottleneck_dim=128,
+        num_experts=4,
+        dropout_rate=0.1
+    )
+
+    # Generate example data
+    n_samples = 1000
+    latent_data = np.random.randn(n_samples, 100).astype(np.float32)
+
+    # Simulate expression data
+    weights = np.random.randn(100, 2000) * 0.1
+    expression_data = np.tanh(latent_data.dot(weights))
+    expression_data = np.maximum(expression_data, 0)
+
+    print(f"📊 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")
+
+    # Train
+    history = decoder.train(
+        train_latent=latent_data,
+        train_expression=expression_data,
+        batch_size=32,
+        num_epochs=50
+    )
+
+    # Predict
+    test_latent = np.random.randn(10, 100).astype(np.float32)
+    predictions = decoder.predict(test_latent)
+    print(f"🔮 Prediction shape: {predictions.shape}")
+
+    return decoder
+
+if __name__ == "__main__":
+    example_usage()
+'''