SURE-tools 2.2.24__py3-none-any.whl → 2.4.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- SURE/DensityFlow.py +130 -65
- SURE/DensityFlow2.py +1422 -0
- SURE/DensityFlowLinear.py +1414 -0
- SURE/EfficientTranscriptomeDecoder.py +552 -0
- SURE/PerturbE.py +1300 -0
- SURE/PerturbationAwareDecoder.py +737 -0
- SURE/SimpleTranscriptomeDecoder.py +567 -0
- SURE/TranscriptomeDecoder.py +511 -0
- SURE/VirtualCellDecoder.py +658 -0
- SURE/__init__.py +17 -1
- SURE/utils/custom_mlp.py +39 -2
- {sure_tools-2.2.24.dist-info → sure_tools-2.4.34.dist-info}/METADATA +1 -1
- {sure_tools-2.2.24.dist-info → sure_tools-2.4.34.dist-info}/RECORD +17 -9
- {sure_tools-2.2.24.dist-info → sure_tools-2.4.34.dist-info}/WHEEL +0 -0
- {sure_tools-2.2.24.dist-info → sure_tools-2.4.34.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.2.24.dist-info → sure_tools-2.4.34.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.2.24.dist-info → sure_tools-2.4.34.dist-info}/top_level.txt +0 -0
SURE/VirtualCellDecoder.py
NEW
@@ -0,0 +1,658 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from typing import Dict, List, Optional, Tuple
import math
import warnings
warnings.filterwarnings('ignore')

class VirtualCellDecoder:
    """
    Advanced transcriptome decoder based on Virtual Cell Challenge research
    Optimized for latent-to-expression mapping with biological constraints
    """

    def __init__(self,
                 latent_dim: int = 100,
                 gene_dim: int = 60000,
                 hidden_dims: List[int] = [512, 1024, 2048],
                 biological_prior_dim: int = 256,
                 dropout_rate: float = 0.1,
                 device: str = None):
        """
        State-of-the-art decoder based on Virtual Cell Challenge insights

        Args:
            latent_dim: Latent variable dimension (typically 50-100)
            gene_dim: Number of genes (full transcriptome ~60,000)
            hidden_dims: Hidden layer dimensions for progressive expansion
            biological_prior_dim: Dimension for biological prior knowledge
            dropout_rate: Dropout rate for regularization
            device: Computation device
        """
        self.latent_dim = latent_dim
        self.gene_dim = gene_dim
        self.hidden_dims = hidden_dims
        self.biological_prior_dim = biological_prior_dim
        self.dropout_rate = dropout_rate
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

        # Initialize model with biological constraints
        self.model = self._build_biological_model()
        self.model.to(self.device)

        # Training state
        self.is_trained = False
        self.training_history = None
        self.best_val_loss = float('inf')

        print(f"🧬 VirtualCellDecoder Initialized:")
        print(f" - Latent Dimension: {latent_dim}")
        print(f" - Gene Dimension: {gene_dim}")
        print(f" - Hidden Dimensions: {hidden_dims}")
        print(f" - Biological Prior Dimension: {biological_prior_dim}")
        print(f" - Device: {self.device}")
        print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")

    class BiologicalPriorNetwork(nn.Module):
        """Biological prior network based on gene regulatory knowledge"""

        def __init__(self, latent_dim: int, prior_dim: int, gene_dim: int):
            super().__init__()
            self.gene_dim = gene_dim
            self.prior_dim = prior_dim

            # Learnable gene regulatory matrix (sparse initialization)
            self.regulatory_matrix = nn.Parameter(
                torch.randn(gene_dim, prior_dim) * 0.01
            )

            # Latent to regulatory space projection
            self.latent_to_regulatory = nn.Sequential(
                nn.Linear(latent_dim, prior_dim * 2),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(prior_dim * 2, prior_dim)
            )

            # Regulatory to expression projection
            self.regulatory_to_expression = nn.Sequential(
                nn.Linear(prior_dim, prior_dim),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(prior_dim, gene_dim)
            )

            self._init_weights()

        def _init_weights(self):
            """Initialize with biological constraints"""
            # Sparse initialization for regulatory matrix
            nn.init.sparse_(self.regulatory_matrix, sparsity=0.8)
            for module in self.modules():
                if isinstance(module, nn.Linear):
                    nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)

        def forward(self, latent):
            batch_size = latent.shape[0]

            # Project latent to regulatory space
            regulatory_factors = self.latent_to_regulatory(latent)  # [batch, prior_dim]

            # Apply regulatory matrix (gene-specific modulation)
            regulatory_effect = torch.matmul(
                regulatory_factors, self.regulatory_matrix.T  # [batch, gene_dim]
            )

            # Final expression projection
            expression_base = self.regulatory_to_expression(regulatory_factors)

            # Combine regulatory effect with base expression
            biological_prior = expression_base + regulatory_effect

            return biological_prior

    class GeneSpecificAttention(nn.Module):
        """Gene-specific attention mechanism for capturing co-expression patterns"""

        def __init__(self, gene_dim: int, attention_dim: int = 128, num_heads: int = 8):
            super().__init__()
            self.gene_dim = gene_dim
            self.attention_dim = attention_dim
            self.num_heads = num_heads

            # Gene embeddings for attention
            self.gene_embeddings = nn.Parameter(torch.randn(gene_dim, attention_dim))

            # Attention mechanism
            self.query_proj = nn.Linear(attention_dim, attention_dim)
            self.key_proj = nn.Linear(attention_dim, attention_dim)
            self.value_proj = nn.Linear(attention_dim, attention_dim)

            # Output projection
            self.output_proj = nn.Linear(attention_dim, attention_dim)

            self._init_weights()

        def _init_weights(self):
            """Initialize attention weights"""
            nn.init.xavier_uniform_(self.gene_embeddings)
            for module in [self.query_proj, self.key_proj, self.value_proj, self.output_proj]:
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        def forward(self, x):
            batch_size = x.shape[0]

            # Prepare gene embeddings
            gene_embeds = self.gene_embeddings.unsqueeze(0).expand(batch_size, -1, -1)

            # Compute attention
            Q = self.query_proj(gene_embeds)
            K = self.key_proj(gene_embeds)
            V = self.value_proj(gene_embeds)

            # Multi-head attention
            head_dim = self.attention_dim // self.num_heads
            Q = Q.view(batch_size, self.gene_dim, self.num_heads, head_dim).transpose(1, 2)
            K = K.view(batch_size, self.gene_dim, self.num_heads, head_dim).transpose(1, 2)
            V = V.view(batch_size, self.gene_dim, self.num_heads, head_dim).transpose(1, 2)

            # Scaled dot-product attention
            attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(head_dim)
            attn_weights = F.softmax(attn_scores, dim=-1)

            # Apply attention
            attn_output = torch.matmul(attn_weights, V)
            attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, self.gene_dim, self.attention_dim)

            # Output projection
            output = self.output_proj(attn_output)

            return output
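    # Shape summary for GeneSpecificAttention.forward (descriptive note): the learned
    # gene_embeddings yield Q/K/V of shape [batch, gene_dim, attention_dim], reshaped to
    # [batch, num_heads, gene_dim, head_dim] for scaled dot-product attention, then merged
    # and projected back to [batch, gene_dim, attention_dim]. The input x contributes only
    # its batch size; the attention itself operates on the gene embeddings.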
    class SparseActivation(nn.Module):
        """Sparse activation function for biological data"""

        def __init__(self, sparsity_target: float = 0.85):
            super().__init__()
            self.sparsity_target = sparsity_target
            self.alpha = nn.Parameter(torch.tensor(1.0))
            self.beta = nn.Parameter(torch.tensor(0.0))

        def forward(self, x):
            # Learnable softplus with sparsity constraint
            activated = F.softplus(x * self.alpha + self.beta)

            # Sparsity regularization (encourages biological sparsity)
            sparsity_loss = (activated.mean() - self.sparsity_target) ** 2
            self.sparsity_loss = sparsity_loss * 0.01  # Light regularization

            return activated

    class VirtualCellModel(nn.Module):
        """Main Virtual Cell Challenge inspired model"""

        def __init__(self, latent_dim: int, gene_dim: int, hidden_dims: List[int],
                     biological_prior_dim: int, dropout_rate: float):
            super().__init__()

            # Phase 1: Latent expansion with biological constraints
            self.latent_expansion = nn.Sequential(
                nn.Linear(latent_dim, hidden_dims[0]),
                nn.BatchNorm1d(hidden_dims[0]),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
                nn.Linear(hidden_dims[0], hidden_dims[1]),
                nn.BatchNorm1d(hidden_dims[1]),
                nn.ReLU(),
                nn.Dropout(dropout_rate)
            )

            # Phase 2: Biological prior network
            self.biological_prior = VirtualCellDecoder.BiologicalPriorNetwork(
                hidden_dims[1], biological_prior_dim, gene_dim
            )

            # Phase 3: Gene-specific processing
            self.gene_attention = VirtualCellDecoder.GeneSpecificAttention(gene_dim)

            # Phase 4: Final expression refinement
            self.expression_refinement = nn.Sequential(
                nn.Linear(gene_dim, hidden_dims[2]),
                nn.BatchNorm1d(hidden_dims[2]),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
                nn.Linear(hidden_dims[2], gene_dim)
            )

            # Phase 5: Sparse activation
            self.sparse_activation = VirtualCellDecoder.SparseActivation()

            self._init_weights()

        def _init_weights(self):
            """Biological-inspired weight initialization"""
            for module in self.modules():
                if isinstance(module, nn.Linear):
                    # Xavier initialization for stable training
                    nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)
                elif isinstance(module, nn.BatchNorm1d):
                    nn.init.ones_(module.weight)
                    nn.init.zeros_(module.bias)

        def forward(self, latent):
            # Phase 1: Latent expansion
            expanded_latent = self.latent_expansion(latent)

            # Phase 2: Biological prior
            biological_output = self.biological_prior(expanded_latent)

            # Phase 3: Gene attention
            attention_output = self.gene_attention(biological_output)

            # Phase 4: Refinement with residual connection
            refined_output = self.expression_refinement(attention_output) + biological_output

            # Phase 5: Sparse activation
            final_output = self.sparse_activation(refined_output)

            return final_output

    def _build_biological_model(self):
        """Build the biologically constrained model"""
        return self.VirtualCellModel(
            self.latent_dim, self.gene_dim, self.hidden_dims,
            self.biological_prior_dim, self.dropout_rate
        )

    def train(self,
              train_latent: np.ndarray,
              train_expression: np.ndarray,
              val_latent: np.ndarray = None,
              val_expression: np.ndarray = None,
              batch_size: int = 32,
              num_epochs: int = 200,
              learning_rate: float = 1e-4,
              biological_weight: float = 0.1,
              checkpoint_path: str = 'virtual_cell_decoder.pth') -> Dict:
        """
        Train with biological constraints and Virtual Cell Challenge insights

        Args:
            train_latent: Training latent variables
            train_expression: Training expression data
            val_latent: Validation latent variables
            val_expression: Validation expression data
            batch_size: Batch size optimized for biological data
            num_epochs: Number of training epochs
            learning_rate: Learning rate
            biological_weight: Weight for biological constraint loss
            checkpoint_path: Model save path
        """
        print("🧬 Starting Virtual Cell Challenge Training...")
        print("📚 Incorporating biological constraints and regulatory priors")

        # Data preparation
        train_dataset = self._create_dataset(train_latent, train_expression)

        if val_latent is not None and val_expression is not None:
            val_dataset = self._create_dataset(val_latent, val_expression)
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
            print(f"📈 Using provided validation data: {len(val_dataset)} samples")
        else:
            # Auto split (90/10)
            train_size = int(0.9 * len(train_dataset))
            val_size = len(train_dataset) - train_size
            train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
            train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
            val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
            print(f"📈 Auto-split validation: {val_size} samples")

        print(f"📊 Training samples: {len(train_loader.dataset)}")
        print(f"📊 Validation samples: {len(val_loader.dataset)}")
        print(f"📊 Batch size: {batch_size}")

        # Optimizer with biological regularization
        optimizer = optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=1e-5,  # L2 regularization
            betas=(0.9, 0.999)
        )

        # Cosine annealing with warmup
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=50, T_mult=2, eta_min=1e-6
        )

        # Biological loss function
        def biological_loss(pred, target):
            # 1. Reconstruction loss
            mse_loss = F.mse_loss(pred, target)

            # 2. Poisson loss for count data
            poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()

            # 3. Correlation loss for pattern matching
            correlation = self._pearson_correlation(pred, target)
            correlation_loss = 1 - correlation

            # 4. Sparsity loss (biological constraint)
            sparsity_loss = self.model.sparse_activation.sparsity_loss

            # 5. Biological consistency loss
            biological_loss = self._biological_consistency_loss(pred)

            total_loss = (mse_loss + 0.5 * poisson_loss + 0.3 * correlation_loss +
                          0.1 * sparsity_loss + biological_weight * biological_loss)

            return total_loss, {
                'mse': mse_loss.item(),
                'poisson': poisson_loss.item(),
                'correlation': correlation.item(),
                'sparsity': sparsity_loss.item(),
                'biological': biological_loss.item()
            }
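        # Descriptive note: the objective assembled above weights its terms as
        #   total = MSE + 0.5 * Poisson + 0.3 * (1 - r) + 0.1 * sparsity + biological_weight * consistency
        # where r is the batch-mean Pearson correlation from _pearson_correlation.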
        # Training history
        history = {
            'train_loss': [], 'val_loss': [],
            'train_mse': [], 'val_mse': [],
            'train_correlation': [], 'val_correlation': [],
            'train_sparsity': [], 'val_sparsity': [],
            'learning_rates': [], 'grad_norms': []
        }

        best_val_loss = float('inf')
        patience = 25
        patience_counter = 0

        print("\n🔬 Starting training with biological constraints...")
        for epoch in range(1, num_epochs + 1):
            # Training phase
            train_loss, train_components, grad_norm = self._train_epoch(
                train_loader, optimizer, biological_loss
            )

            # Validation phase
            val_loss, val_components = self._validate_epoch(val_loader, biological_loss)

            # Update scheduler
            scheduler.step()
            current_lr = optimizer.param_groups[0]['lr']

            # Record history
            history['train_loss'].append(train_loss)
            history['val_loss'].append(val_loss)
            history['train_mse'].append(train_components['mse'])
            history['val_mse'].append(val_components['mse'])
            history['train_correlation'].append(train_components['correlation'])
            history['val_correlation'].append(val_components['correlation'])
            history['train_sparsity'].append(train_components['sparsity'])
            history['val_sparsity'].append(val_components['sparsity'])
            history['learning_rates'].append(current_lr)
            history['grad_norms'].append(grad_norm)

            # Print detailed progress
            if epoch % 10 == 0 or epoch == 1:
                print(f"🧪 Epoch {epoch:3d}/{num_epochs} | "
                      f"Train: {train_loss:.4f} | Val: {val_loss:.4f} | "
                      f"Corr: {val_components['correlation']:.4f} | "
                      f"Sparsity: {val_components['sparsity']:.4f} | "
                      f"LR: {current_lr:.2e}")

            # Early stopping and model saving
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
                if epoch % 20 == 0:
                    print(f"💾 Best model saved (Val Loss: {best_val_loss:.4f})")
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"🛑 Early stopping at epoch {epoch}")
                    break

        # Training completed
        self.is_trained = True
        self.training_history = history
        self.best_val_loss = best_val_loss

        print(f"\n🎉 Training completed!")
        print(f"🏆 Best validation loss: {best_val_loss:.4f}")
        print(f"📊 Final correlation: {history['val_correlation'][-1]:.4f}")
        print(f"🌿 Final sparsity: {history['val_sparsity'][-1]:.4f}")

        return history

    def _biological_consistency_loss(self, pred):
        """Biological consistency loss based on Virtual Cell Challenge insights"""
        # 1. Gene expression variance consistency
        gene_variance = pred.var(dim=0)
        target_variance = torch.ones_like(gene_variance) * 0.5  # Reasonable biological variance
        variance_loss = F.mse_loss(gene_variance, target_variance)

        # 2. Co-expression pattern consistency
        correlation_matrix = torch.corrcoef(pred.T)
        correlation_loss = torch.mean(torch.abs(correlation_matrix))  # Encourage moderate correlations

        return variance_loss + 0.5 * correlation_loss

    def _create_dataset(self, latent_data, expression_data):
        """Create dataset with biological data validation"""
        class BiologicalDataset(Dataset):
            def __init__(self, latent, expression):
                # Validate biological data characteristics
                assert np.all(expression >= 0), "Expression data must be non-negative"
                assert np.mean(expression == 0) > 0.7, "Expression data should be sparse (typical scRNA-seq)"

                self.latent = torch.FloatTensor(latent)
                self.expression = torch.FloatTensor(expression)

            def __len__(self):
                return len(self.latent)

            def __getitem__(self, idx):
                return self.latent[idx], self.expression[idx]

        return BiologicalDataset(latent_data, expression_data)

    def _pearson_correlation(self, pred, target):
        """Calculate Pearson correlation coefficient"""
        pred_centered = pred - pred.mean(dim=1, keepdim=True)
        target_centered = target - target.mean(dim=1, keepdim=True)

        numerator = (pred_centered * target_centered).sum(dim=1)
        denominator = torch.sqrt(torch.sum(pred_centered ** 2, dim=1)) * torch.sqrt(torch.sum(target_centered ** 2, dim=1))

        return (numerator / (denominator + 1e-8)).mean()
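    # Descriptive note: _pearson_correlation computes the correlation across genes for
    # each sample (row-wise) and returns the batch mean; the 1e-8 term only guards
    # against a zero denominator for constant rows.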
    def _train_epoch(self, train_loader, optimizer, loss_fn):
        """Train one epoch with biological constraints"""
        self.model.train()
        total_loss = 0
        total_components = {'mse': 0, 'poisson': 0, 'correlation': 0, 'sparsity': 0, 'biological': 0}
        grad_norms = []

        for latent, target in train_loader:
            latent = latent.to(self.device, non_blocking=True)
            target = target.to(self.device, non_blocking=True)

            optimizer.zero_grad()
            pred = self.model(latent)

            loss, components = loss_fn(pred, target)
            loss.backward()

            # Gradient clipping for stability
            grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
            for key in components:
                total_components[key] += components[key]
            grad_norms.append(grad_norm.item())

        num_batches = len(train_loader)
        avg_loss = total_loss / num_batches
        avg_components = {key: value / num_batches for key, value in total_components.items()}
        avg_grad_norm = np.mean(grad_norms)

        return avg_loss, avg_components, avg_grad_norm

    def _validate_epoch(self, val_loader, loss_fn):
        """Validate one epoch"""
        self.model.eval()
        total_loss = 0
        total_components = {'mse': 0, 'poisson': 0, 'correlation': 0, 'sparsity': 0, 'biological': 0}

        with torch.no_grad():
            for latent, target in val_loader:
                latent = latent.to(self.device, non_blocking=True)
                target = target.to(self.device, non_blocking=True)

                pred = self.model(latent)
                loss, components = loss_fn(pred, target)

                total_loss += loss.item()
                for key in components:
                    total_components[key] += components[key]

        num_batches = len(val_loader)
        avg_loss = total_loss / num_batches
        avg_components = {key: value / num_batches for key, value in total_components.items()}

        return avg_loss, avg_components

    def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
        """Save model checkpoint"""
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_loss': best_loss,
            'training_history': history,
            'model_config': {
                'latent_dim': self.latent_dim,
                'gene_dim': self.gene_dim,
                'hidden_dims': self.hidden_dims,
                'biological_prior_dim': self.biological_prior_dim,
                'dropout_rate': self.dropout_rate
            }
        }, path)

    def predict(self, latent_data: np.ndarray, batch_size: int = 32) -> np.ndarray:
        """
        Predict gene expression with biological constraints

        Args:
            latent_data: Latent variables [n_samples, latent_dim]
            batch_size: Prediction batch size

        Returns:
            expression: Predicted expression [n_samples, gene_dim]
        """
        if not self.is_trained:
            warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")

        self.model.eval()

        if isinstance(latent_data, np.ndarray):
            latent_data = torch.FloatTensor(latent_data)

        # Predict in batches
        predictions = []
        with torch.no_grad():
            for i in range(0, len(latent_data), batch_size):
                batch_latent = latent_data[i:i+batch_size].to(self.device)
                batch_pred = self.model(batch_latent)
                predictions.append(batch_pred.cpu())

        return torch.cat(predictions).numpy()

    def load_model(self, model_path: str):
        """Load pre-trained model"""
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.is_trained = True
        self.training_history = checkpoint.get('training_history')
        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
        print(f"✅ Model loaded! Best validation loss: {self.best_val_loss:.4f}")

    def get_model_info(self) -> Dict:
        """Get model information"""
        return {
            'is_trained': self.is_trained,
            'best_val_loss': self.best_val_loss,
            'parameters': sum(p.numel() for p in self.model.parameters()),
            'latent_dim': self.latent_dim,
            'gene_dim': self.gene_dim,
            'hidden_dims': self.hidden_dims,
            'biological_prior_dim': self.biological_prior_dim,
            'device': str(self.device)
        }

'''
# Example usage
def example_usage():
    """Example demonstration of Virtual Cell Challenge decoder"""

    # Initialize decoder
    decoder = VirtualCellDecoder(
        latent_dim=100,
        gene_dim=2000,  # Reduced for example
        hidden_dims=[256, 512, 1024],
        biological_prior_dim=128,
        dropout_rate=0.1
    )

    # Generate example data with biological characteristics
    n_samples = 1000
    latent_data = np.random.randn(n_samples, 100).astype(np.float32)

    # Simulate biological expression data (sparse, non-negative)
    weights = np.random.randn(100, 2000) * 0.1
    expression_data = np.tanh(latent_data.dot(weights))
    expression_data = np.maximum(expression_data, 0)

    # Add biological sparsity (typical scRNA-seq characteristics)
    mask = np.random.random(expression_data.shape) > 0.8  # 80% sparsity
    expression_data[mask] = 0

    print(f"📊 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")
    print(f"🌿 Biological sparsity: {(expression_data == 0).mean():.3f}")

    # Train with biological constraints
    history = decoder.train(
        train_latent=latent_data,
        train_expression=expression_data,
        batch_size=32,
        num_epochs=50,
        learning_rate=1e-4,
        biological_weight=0.1
    )

    # Predict
    test_latent = np.random.randn(10, 100).astype(np.float32)
    predictions = decoder.predict(test_latent)
    print(f"🔮 Prediction shape: {predictions.shape}")
    print(f"🌿 Predicted sparsity: {(predictions < 0.1).mean():.3f}")

    return decoder

if __name__ == "__main__":
    example_usage()

'''
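Note that the example_usage() block above ships quoted out inside a string literal. As a brief editorial sketch (not taken from the package), the new class can be smoke-tested through construction and introspection alone, using only the constructor and get_model_info() defined in this file; a reduced gene_dim keeps the parameter count small. The quoted-out example_usage() shows the intended train/predict flow on synthetic data.

    from SURE import VirtualCellDecoder  # exported at the package root in 2.4.34 (see the __init__.py diff below)

    # Hypothetical smoke test: build the model with a small gene panel and inspect it.
    decoder = VirtualCellDecoder(latent_dim=50,
                                 gene_dim=2000,
                                 hidden_dims=[256, 512, 1024],
                                 biological_prior_dim=128)
    print(decoder.get_model_info())  # parameter count, dimensions, device, is_trained flag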
SURE/__init__.py
CHANGED
@@ -1,12 +1,28 @@
 from .SURE import SURE
 from .DensityFlow import DensityFlow
+from .DensityFlowLinear import DensityFlowLinear
+from .PerturbE import PerturbE
+from .TranscriptomeDecoder import TranscriptomeDecoder
+from .SimpleTranscriptomeDecoder import SimpleTranscriptomeDecoder
+from .EfficientTranscriptomeDecoder import EfficientTranscriptomeDecoder
+from .VirtualCellDecoder import VirtualCellDecoder
+from .PerturbationAwareDecoder import PerturbationAwareDecoder

 from . import utils
 from . import codebook
 from . import SURE
 from . import DensityFlow
+from . import DensityFlowLinear
 from . import atac
 from . import flow
 from . import perturb
+from . import PerturbE
+from . import TranscriptomeDecoder
+from . import SimpleTranscriptomeDecoder
+from . import EfficientTranscriptomeDecoder
+from . import VirtualCellDecoder
+from . import PerturbationAwareDecoder

-__all__ = ['SURE', 'DensityFlow', '
+__all__ = ['SURE', 'DensityFlow', 'DensityFlowLinear', 'PerturbE', 'TranscriptomeDecoder', 'SimpleTranscriptomeDecoder',
+           'EfficientTranscriptomeDecoder', 'VirtualCellDecoder', 'PerturbationAwareDecoder',
+           'flow', 'perturb', 'atac', 'utils', 'codebook']