SURE-tools 2.2.23.tar.gz → 2.4.17.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sure_tools-2.2.23 → sure_tools-2.4.17}/PKG-INFO +1 -1
- {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/DensityFlow.py +17 -10
- sure_tools-2.4.17/SURE/EfficientTranscriptomeDecoder.py +552 -0
- sure_tools-2.4.17/SURE/PerturbE.py +1300 -0
- sure_tools-2.4.17/SURE/SimpleTranscriptomeDecoder.py +567 -0
- sure_tools-2.4.17/SURE/TranscriptomeDecoder.py +511 -0
- sure_tools-2.4.17/SURE/VirtualCellDecoder.py +659 -0
- sure_tools-2.4.17/SURE/__init__.py +23 -0
- {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/utils/custom_mlp.py +39 -2
- {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE_tools.egg-info/SOURCES.txt +5 -0
- {sure_tools-2.2.23 → sure_tools-2.4.17}/setup.py +1 -1
- sure_tools-2.2.23/SURE/__init__.py +0 -12
- {sure_tools-2.2.23 → sure_tools-2.4.17}/LICENSE +0 -0
- {sure_tools-2.2.23 → sure_tools-2.4.17}/README.md +0 -0
- {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/SURE.py +0 -0
- {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/atac/utils.py +0 -0
- {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/perturb/perturb.py +0 -0
- {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/utils/queue.py +0 -0
- {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE/utils/utils.py +0 -0
- {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.2.23 → sure_tools-2.4.17}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.2.23 → sure_tools-2.4.17}/setup.cfg +0 -0
--- sure_tools-2.2.23/SURE/DensityFlow.py
+++ sure_tools-2.4.17/SURE/DensityFlow.py
@@ -396,7 +396,8 @@ class DensityFlow(nn.Module):
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -480,7 +481,8 @@ class DensityFlow(nn.Module):
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -576,7 +578,8 @@ class DensityFlow(nn.Module):
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -682,7 +685,8 @@ class DensityFlow(nn.Module):
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
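The same one-line substitution appears in all four hunks above: the multinomial observation is rebuilt from the raw `concentrate` scores passed as `logits`, instead of the pre-normalized `theta` probabilities. A minimal standalone sketch of the `torch.distributions` behavior this relies on (logits are softmax-normalized internally, so the two parameterizations describe the same distribution); this is illustrative only, not code from the package:

```python
import torch
from torch.distributions import Multinomial

concentrate = torch.randn(5)                 # unnormalized per-gene scores ("logits")
theta = torch.softmax(concentrate, dim=-1)   # explicit probability parameterization

d_logits = Multinomial(total_count=100, logits=concentrate)
d_probs = Multinomial(total_count=100, probs=theta)

x = d_probs.sample()
assert torch.allclose(d_logits.log_prob(x), d_probs.log_prob(x), atol=1e-5)
```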
@@ -854,12 +858,12 @@ class DensityFlow(nn.Module):
             us_i = us[:,pert_idx].reshape(-1,1)
 
             # factor effect of xs
-            dzs0 = self.get_cell_response(zs,
+            dzs0 = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=us_i)
 
             # perturbation effect
             ps = np.ones_like(us_i)
             if np.sum(np.abs(ps-us_i))>=1:
-                dzs = self.get_cell_response(zs,
+                dzs = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=ps)
                 zs = zs + dzs0 + dzs
             else:
                 zs = zs + dzs0
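Both `get_cell_response` calls now receive the perturbation column index and the per-cell perturbation values as explicit keywords (`perturb_idx`, `perturb_us`). Below is a hedged sketch of the calling pattern the hunk implies, wrapped as a helper so it is self-contained; the fitted `DensityFlow` instance `df`, the one-column-per-factor layout of `us`, and the returned shapes are assumptions, not facts confirmed by this diff:

```python
import numpy as np

def apply_perturbation(df, zs, us, pert_idx):
    """Shift latent states zs by the modelled response to perturbation pert_idx.

    df is assumed to be a fitted DensityFlow; us is assumed to hold one column
    of perturbation values per factor, mirroring the hunk above.
    """
    us_i = us[:, pert_idx].reshape(-1, 1)
    dzs0 = df.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=us_i)    # factor effect
    ps = np.ones_like(us_i)                                                   # "fully on" counterfactual
    if np.sum(np.abs(ps - us_i)) >= 1:
        dzs = df.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=ps)   # perturbation effect
        return zs + dzs0 + dzs
    return zs + dzs0
```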
@@ -946,6 +950,9 @@ class DensityFlow(nn.Module):
         if self.loss_func == 'bernoulli':
             #counts = self.sigmoid(concentrate)
             counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
+        elif self.loss_func == 'multinomial':
+            theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
+            counts = theta * library_size
         else:
             rate = concentrate.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
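The new `multinomial` branch derives expected counts from the distribution mean rather than from a sample. A quick standalone check of the `torch.distributions` property it leans on, `Multinomial.mean == total_count * softmax(logits)`; the subsequent `library_size` scaling is the package's own convention and is not reproduced here:

```python
import torch
from torch.distributions import Multinomial

concentrate = torch.randn(4, 6)   # per-cell, per-gene scores
m = Multinomial(total_count=int(1e8), logits=concentrate)

expected = int(1e8) * torch.softmax(concentrate, dim=-1)
assert torch.allclose(m.mean, expected)
```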
@@ -1340,7 +1347,7 @@ def main():
     cell_factor_size = 0 if us is None else us.shape[1]
 
     ###########################################
-
+    df = DensityFlow(
         input_size=input_size,
         cell_factor_size=cell_factor_size,
         inverse_dispersion=args.inverse_dispersion,
@@ -1359,7 +1366,7 @@ def main():
         dtype=dtype,
     )
 
-
+    df.fit(xs, us=us,
         num_epochs=args.num_epochs,
         learning_rate=args.learning_rate,
         batch_size=args.batch_size,
@@ -1371,9 +1378,9 @@ def main():
 
     if args.save_model is not None:
         if args.save_model.endswith('gz'):
-            DensityFlow.save_model(
+            DensityFlow.save_model(df, args.save_model, compression=True)
         else:
-            DensityFlow.save_model(
+            DensityFlow.save_model(df, args.save_model)
 
 
 
--- /dev/null
+++ sure_tools-2.4.17/SURE/EfficientTranscriptomeDecoder.py
@@ -0,0 +1,552 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import Dataset, DataLoader
+import numpy as np
+from typing import Dict, List, Optional, Tuple
+import math
+import warnings
+warnings.filterwarnings('ignore')
+
+class EfficientTranscriptomeDecoder:
+    """
+    High-performance, memory-efficient transcriptome decoder
+    Fixed version with corrected RMSNorm implementation
+    """
+
+    def __init__(self,
+                 latent_dim: int = 100,
+                 gene_dim: int = 60000,
+                 hidden_dims: List[int] = [512, 1024, 2048],
+                 bottleneck_dim: int = 256,
+                 num_experts: int = 4,
+                 dropout_rate: float = 0.1,
+                 device: str = None):
+        """
+        Advanced decoder combining multiple state-of-the-art techniques
+
+        Args:
+            latent_dim: Latent variable dimension
+            gene_dim: Number of genes (full transcriptome)
+            hidden_dims: Hidden layer dimensions
+            bottleneck_dim: Bottleneck dimension for memory efficiency
+            num_experts: Number of mixture-of-experts
+            dropout_rate: Dropout rate
+            device: Computation device
+        """
+        self.latent_dim = latent_dim
+        self.gene_dim = gene_dim
+        self.hidden_dims = hidden_dims
+        self.bottleneck_dim = bottleneck_dim
+        self.num_experts = num_experts
+        self.dropout_rate = dropout_rate
+        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+
+        # Initialize model with corrected architecture
+        self.model = self._build_corrected_model()
+        self.model.to(self.device)
+
+        # Training state
+        self.is_trained = False
+        self.training_history = None
+        self.best_val_loss = float('inf')
+
+        print(f"🚀 EfficientTranscriptomeDecoder Initialized:")
+        print(f"   - Latent Dimension: {latent_dim}")
+        print(f"   - Gene Dimension: {gene_dim}")
+        print(f"   - Hidden Dimensions: {hidden_dims}")
+        print(f"   - Bottleneck Dimension: {bottleneck_dim}")
+        print(f"   - Number of Experts: {num_experts}")
+        print(f"   - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
+
+    class CorrectedRMSNorm(nn.Module):
+        """Corrected RMS Normalization with proper dimension handling"""
+        def __init__(self, dim: int, eps: float = 1e-8):
+            super().__init__()
+            self.eps = eps
+            self.dim = dim
+            self.weight = nn.Parameter(torch.ones(dim))  # Correct: weight has same dim as input
+
+        def forward(self, x):
+            # Ensure input has the right dimension
+            if x.size(-1) != self.dim:
+                raise ValueError(f"Input dimension {x.size(-1)} doesn't match RMSNorm dimension {self.dim}")
+
+            # Calculate RMS
+            rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
+            # Normalize and apply weight
+            return x / rms * self.weight
+
+    class SimplifiedSwiGLU(nn.Module):
+        """Simplified SwiGLU activation"""
+        def forward(self, x):
+            # Split into two parts
+            x, gate = x.chunk(2, dim=-1)
+            return x * F.silu(gate)
+
+    class MemoryEfficientBottleneck(nn.Module):
+        """Memory-efficient bottleneck with corrected dimensions"""
+        def __init__(self, input_dim: int, bottleneck_dim: int, output_dim: int):
+            super().__init__()
+            # Ensure proper dimension matching
+            self.compress = nn.Linear(input_dim, bottleneck_dim)
+            self.norm1 = EfficientTranscriptomeDecoder.CorrectedRMSNorm(bottleneck_dim)
+            self.expand = nn.Linear(bottleneck_dim, output_dim)
+            self.norm2 = EfficientTranscriptomeDecoder.CorrectedRMSNorm(output_dim)
+            self.activation = nn.SiLU()
+            self.dropout = nn.Dropout(0.1)
+
+        def forward(self, x):
+            # Compress
+            compressed = self.compress(x)
+            compressed = self.norm1(compressed)
+            compressed = self.activation(compressed)
+            compressed = self.dropout(compressed)
+
+            # Expand
+            expanded = self.expand(compressed)
+            expanded = self.norm2(expanded)
+
+            return expanded
+
+    class StableMixtureOfExperts(nn.Module):
+        """Stable mixture of experts without dimension issues"""
+        def __init__(self, input_dim: int, num_experts: int = 4):
+            super().__init__()
+            self.num_experts = num_experts
+            self.input_dim = input_dim
+
+            # Shared expert with different scaling factors
+            self.shared_expert = nn.Sequential(
+                nn.Linear(input_dim, input_dim * 2),
+                nn.SiLU(),
+                nn.Dropout(0.1),
+                nn.Linear(input_dim * 2, input_dim)
+            )
+
+            # Gating network
+            self.gate = nn.Sequential(
+                nn.Linear(input_dim, num_experts * 4),
+                nn.SiLU(),
+                nn.Linear(num_experts * 4, num_experts)
+            )
+
+        def forward(self, x):
+            # Get gate weights
+            gate_weights = F.softmax(self.gate(x), dim=-1)  # [batch, num_experts]
+
+            # Process through shared expert
+            expert_output = self.shared_expert(x)  # [batch, input_dim]
+
+            # Apply expert-specific scaling
+            weighted_output = torch.zeros_like(expert_output)
+            for i in range(self.num_experts):
+                expert_scale = 0.5 + 0.5 * i  # Different scaling for each expert
+                expert_contribution = expert_output * expert_scale
+                expert_weight = gate_weights[:, i].unsqueeze(-1)  # [batch, 1]
+                weighted_output += expert_weight * expert_contribution
+
+            # Residual connection
+            return x + weighted_output
+
+    class CorrectedDecoder(nn.Module):
+        """Corrected decoder with proper dimension handling"""
+
+        def __init__(self, latent_dim: int, gene_dim: int, hidden_dims: List[int],
+                     bottleneck_dim: int, num_experts: int, dropout_rate: float):
+            super().__init__()
+
+            # Input projection
+            self.input_projection = nn.Sequential(
+                nn.Linear(latent_dim, hidden_dims[0]),
+                EfficientTranscriptomeDecoder.CorrectedRMSNorm(hidden_dims[0]),
+                nn.SiLU(),
+                nn.Dropout(dropout_rate)
+            )
+
+            # Main processing blocks
+            self.blocks = nn.ModuleList()
+            current_dim = hidden_dims[0]
+
+            for i, next_dim in enumerate(hidden_dims[1:], 1):
+                block = nn.ModuleDict({
+                    'swiglu': nn.Sequential(
+                        nn.Linear(current_dim, current_dim * 2),
+                        EfficientTranscriptomeDecoder.SimplifiedSwiGLU(),
+                        nn.Dropout(dropout_rate),
+                        nn.Linear(current_dim, current_dim)  # Project back to same dimension
+                    ),
+                    'bottleneck': EfficientTranscriptomeDecoder.MemoryEfficientBottleneck(
+                        current_dim, bottleneck_dim, next_dim
+                    ),
+                    'experts': EfficientTranscriptomeDecoder.StableMixtureOfExperts(
+                        next_dim, num_experts
+                    )
+                })
+                self.blocks.append(block)
+                current_dim = next_dim
+
+            # Final projection to gene dimension
+            self.final_projection = nn.Sequential(
+                nn.Linear(current_dim, current_dim * 2),
+                nn.SiLU(),
+                nn.Dropout(dropout_rate),
+                nn.Linear(current_dim * 2, gene_dim)
+            )
+
+            # Output parameters
+            self.output_scale = nn.Parameter(torch.ones(1))
+            self.output_bias = nn.Parameter(torch.zeros(1))
+
+            self._init_weights()
+
+        def _init_weights(self):
+            """Proper weight initialization"""
+            for module in self.modules():
+                if isinstance(module, nn.Linear):
+                    nn.init.xavier_uniform_(module.weight)
+                    if module.bias is not None:
+                        nn.init.zeros_(module.bias)
+
+        def forward(self, x):
+            # Input projection
+            x = self.input_projection(x)
+
+            # Process through blocks
+            for block in self.blocks:
+                # SwiGLU with residual
+                residual = x
+                x_swiglu = block['swiglu'](x)
+                x = x + x_swiglu  # Residual connection
+
+                # Bottleneck
+                x = block['bottleneck'](x)
+
+                # Mixture of Experts with residual
+                x = block['experts'](x)
+
+            # Final projection
+            x = self.final_projection(x)
+
+            # Ensure non-negative output
+            x = F.softplus(x * self.output_scale + self.output_bias)
+
+            return x
+
+    def _build_corrected_model(self):
+        """Build the corrected model"""
+        return self.CorrectedDecoder(
+            self.latent_dim, self.gene_dim, self.hidden_dims,
+            self.bottleneck_dim, self.num_experts, self.dropout_rate
+        )
+
+    def train(self,
+              train_latent: np.ndarray,
+              train_expression: np.ndarray,
+              val_latent: np.ndarray = None,
+              val_expression: np.ndarray = None,
+              batch_size: int = 32,
+              num_epochs: int = 100,
+              learning_rate: float = 1e-4,
+              checkpoint_path: str = 'transcriptome_decoder.pth') -> Dict:
+        """
+        Train the corrected decoder
+        """
+        print("🚀 Starting Training...")
+
+        # Data preparation
+        train_dataset = self._create_dataset(train_latent, train_expression)
+
+        if val_latent is not None and val_expression is not None:
+            val_dataset = self._create_dataset(val_latent, val_expression)
+            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
+            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
+            print(f"📈 Using provided validation data: {len(val_dataset)} samples")
+        else:
+            # Auto split
+            train_size = int(0.9 * len(train_dataset))
+            val_size = len(train_dataset) - train_size
+            train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
+            train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, pin_memory=True)
+            val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, pin_memory=True)
+            print(f"📈 Auto-split validation: {val_size} samples")
+
+        print(f"📊 Training samples: {len(train_loader.dataset)}")
+        print(f"📊 Validation samples: {len(val_loader.dataset)}")
+        print(f"📊 Batch size: {batch_size}")
+
+        # Optimizer
+        optimizer = optim.AdamW(
+            self.model.parameters(),
+            lr=learning_rate,
+            weight_decay=0.01,
+            betas=(0.9, 0.999)
+        )
+
+        # Scheduler
+        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
+
+        # Loss function
+        def combined_loss(pred, target):
+            mse_loss = F.mse_loss(pred, target)
+            poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()
+            correlation = self._pearson_correlation(pred, target)
+            correlation_loss = 1 - correlation
+            return mse_loss + 0.3 * poisson_loss + 0.1 * correlation_loss
+
+        # Training history
+        history = {
+            'train_loss': [], 'val_loss': [],
+            'train_mse': [], 'val_mse': [],
+            'train_correlation': [], 'val_correlation': [],
+            'learning_rates': []
+        }
+
+        best_val_loss = float('inf')
+        patience = 20
+        patience_counter = 0
+
+        print("\n📈 Starting training loop...")
+        for epoch in range(1, num_epochs + 1):
+            # Training
+            train_metrics = self._train_epoch(train_loader, optimizer, combined_loss)
+
+            # Validation
+            val_metrics = self._validate_epoch(val_loader, combined_loss)
+
+            # Update scheduler
+            scheduler.step()
+            current_lr = optimizer.param_groups[0]['lr']
+
+            # Record history
+            history['train_loss'].append(train_metrics['loss'])
+            history['val_loss'].append(val_metrics['loss'])
+            history['train_mse'].append(train_metrics['mse'])
+            history['val_mse'].append(val_metrics['mse'])
+            history['train_correlation'].append(train_metrics['correlation'])
+            history['val_correlation'].append(val_metrics['correlation'])
+            history['learning_rates'].append(current_lr)
+
+            # Print progress
+            if epoch % 10 == 0 or epoch == 1:
+                print(f"📍 Epoch {epoch:3d}/{num_epochs} | "
+                      f"Train Loss: {train_metrics['loss']:.4f} | "
+                      f"Val Loss: {val_metrics['loss']:.4f} | "
+                      f"Correlation: {val_metrics['correlation']:.4f} | "
+                      f"LR: {current_lr:.2e}")
+
+            # Early stopping
+            if val_metrics['loss'] < best_val_loss:
+                best_val_loss = val_metrics['loss']
+                patience_counter = 0
+                self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
+                if epoch % 20 == 0:
+                    print(f"💾 Best model saved (Val Loss: {best_val_loss:.4f})")
+            else:
+                patience_counter += 1
+                if patience_counter >= patience:
+                    print(f"🛑 Early stopping at epoch {epoch}")
+                    break
+
+        # Training completed
+        self.is_trained = True
+        self.training_history = history
+        self.best_val_loss = best_val_loss
+
+        print(f"\n🎉 Training completed!")
+        print(f"🏆 Best validation loss: {best_val_loss:.4f}")
+
+        return history
+
+    def _create_dataset(self, latent_data, expression_data):
+        """Create dataset"""
+        class SimpleDataset(Dataset):
+            def __init__(self, latent, expression):
+                self.latent = torch.FloatTensor(latent)
+                self.expression = torch.FloatTensor(expression)
+
+            def __len__(self):
+                return len(self.latent)
+
+            def __getitem__(self, idx):
+                return self.latent[idx], self.expression[idx]
+
+        return SimpleDataset(latent_data, expression_data)
+
+    def _pearson_correlation(self, pred, target):
+        """Calculate Pearson correlation"""
+        pred_centered = pred - pred.mean(dim=1, keepdim=True)
+        target_centered = target - target.mean(dim=1, keepdim=True)
+
+        numerator = (pred_centered * target_centered).sum(dim=1)
+        denominator = torch.sqrt(torch.sum(pred_centered ** 2, dim=1)) * torch.sqrt(torch.sum(target_centered ** 2, dim=1))
+
+        return (numerator / (denominator + 1e-8)).mean()
+
+    def _train_epoch(self, train_loader, optimizer, loss_fn):
+        """Train one epoch"""
+        self.model.train()
+        total_loss = 0
+        total_mse = 0
+        total_correlation = 0
+
+        for latent, target in train_loader:
+            latent = latent.to(self.device, non_blocking=True)
+            target = target.to(self.device, non_blocking=True)
+
+            optimizer.zero_grad()
+            pred = self.model(latent)
+
+            loss = loss_fn(pred, target)
+            loss.backward()
+
+            # Gradient clipping
+            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+            optimizer.step()
+
+            # Calculate metrics
+            mse_loss = F.mse_loss(pred, target).item()
+            correlation = self._pearson_correlation(pred, target).item()
+
+            total_loss += loss.item()
+            total_mse += mse_loss
+            total_correlation += correlation
+
+        num_batches = len(train_loader)
+        return {
+            'loss': total_loss / num_batches,
+            'mse': total_mse / num_batches,
+            'correlation': total_correlation / num_batches
+        }
+
+    def _validate_epoch(self, val_loader, loss_fn):
+        """Validate one epoch"""
+        self.model.eval()
+        total_loss = 0
+        total_mse = 0
+        total_correlation = 0
+
+        with torch.no_grad():
+            for latent, target in val_loader:
+                latent = latent.to(self.device, non_blocking=True)
+                target = target.to(self.device, non_blocking=True)
+
+                pred = self.model(latent)
+                loss = loss_fn(pred, target)
+                mse_loss = F.mse_loss(pred, target).item()
+                correlation = self._pearson_correlation(pred, target).item()
+
+                total_loss += loss.item()
+                total_mse += mse_loss
+                total_correlation += correlation
+
+        num_batches = len(val_loader)
+        return {
+            'loss': total_loss / num_batches,
+            'mse': total_mse / num_batches,
+            'correlation': total_correlation / num_batches
+        }
+
+    def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
+        """Save checkpoint"""
+        torch.save({
+            'epoch': epoch,
+            'model_state_dict': self.model.state_dict(),
+            'optimizer_state_dict': optimizer.state_dict(),
+            'scheduler_state_dict': scheduler.state_dict(),
+            'best_val_loss': best_loss,
+            'training_history': history,
+            'model_config': {
+                'latent_dim': self.latent_dim,
+                'gene_dim': self.gene_dim,
+                'hidden_dims': self.hidden_dims,
+                'bottleneck_dim': self.bottleneck_dim,
+                'num_experts': self.num_experts
+            }
+        }, path)
+
+    def predict(self, latent_data: np.ndarray, batch_size: int = 32) -> np.ndarray:
+        """Predict gene expression"""
+        if not self.is_trained:
+            warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")
+
+        self.model.eval()
+
+        if isinstance(latent_data, np.ndarray):
+            latent_data = torch.FloatTensor(latent_data)
+
+        predictions = []
+        with torch.no_grad():
+            for i in range(0, len(latent_data), batch_size):
+                batch_latent = latent_data[i:i+batch_size].to(self.device)
+                batch_pred = self.model(batch_latent)
+                predictions.append(batch_pred.cpu())
+
+        return torch.cat(predictions).numpy()
+
+    def load_model(self, model_path: str):
+        """Load pre-trained model"""
+        checkpoint = torch.load(model_path, map_location=self.device)
+        self.model.load_state_dict(checkpoint['model_state_dict'])
+        self.is_trained = True
+        self.training_history = checkpoint.get('training_history')
+        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
+        print(f"✅ Model loaded! Best val loss: {self.best_val_loss:.4f}")
+
+    def get_model_info(self) -> Dict:
+        """Get model information"""
+        return {
+            'is_trained': self.is_trained,
+            'best_val_loss': self.best_val_loss,
+            'parameters': sum(p.numel() for p in self.model.parameters()),
+            'latent_dim': self.latent_dim,
+            'gene_dim': self.gene_dim,
+            'hidden_dims': self.hidden_dims,
+            'device': str(self.device)
+        }
+
+'''
+# Example usage
+def example_usage():
+    """Example demonstration"""
+
+    # Initialize decoder
+    decoder = EfficientTranscriptomeDecoder(
+        latent_dim=100,
+        gene_dim=2000,  # Reduced for example
+        hidden_dims=[256, 512, 1024],
+        bottleneck_dim=128,
+        num_experts=4,
+        dropout_rate=0.1
+    )
+
+    # Generate example data
+    n_samples = 1000
+    latent_data = np.random.randn(n_samples, 100).astype(np.float32)
+
+    # Simulate expression data
+    weights = np.random.randn(100, 2000) * 0.1
+    expression_data = np.tanh(latent_data.dot(weights))
+    expression_data = np.maximum(expression_data, 0)
+
+    print(f"📊 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")
+
+    # Train
+    history = decoder.train(
+        train_latent=latent_data,
+        train_expression=expression_data,
+        batch_size=32,
+        num_epochs=50
+    )
+
+    # Predict
+    test_latent = np.random.randn(10, 100).astype(np.float32)
+    predictions = decoder.predict(test_latent)
+    print(f"🔮 Prediction shape: {predictions.shape}")
+
+    return decoder
+
+if __name__ == "__main__":
+    example_usage()
+'''
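The new module carries its own demonstration, but only inside the commented-out `'''` block at the end of the file. Below is a condensed version of that flow, using the import path implied by the file listing and a deliberately small `gene_dim` so it finishes quickly; the synthetic data is illustrative only:

```python
import numpy as np
from SURE.EfficientTranscriptomeDecoder import EfficientTranscriptomeDecoder

decoder = EfficientTranscriptomeDecoder(latent_dim=50, gene_dim=2000,
                                        hidden_dims=[256, 512], bottleneck_dim=128)

# Toy latent/expression pairs in place of real embeddings and counts
latent = np.random.randn(500, 50).astype(np.float32)
expression = np.maximum(np.tanh(latent @ (np.random.randn(50, 2000) * 0.1)), 0).astype(np.float32)

history = decoder.train(train_latent=latent, train_expression=expression,
                        batch_size=32, num_epochs=10)

predictions = decoder.predict(np.random.randn(8, 50).astype(np.float32))
print(predictions.shape)  # (8, 2000)
```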