SURE-tools 2.4.7__py3-none-any.whl → 2.4.42__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- SURE/DensityFlow.py +159 -70
- SURE/DensityFlow2.py +1422 -0
- SURE/DensityFlowLinear.py +1414 -0
- SURE/EfficientTranscriptomeDecoder.py +552 -0
- SURE/PerturbationAwareDecoder.py +737 -0
- SURE/SimpleTranscriptomeDecoder.py +567 -0
- SURE/TranscriptomeDecoder.py +13 -1
- SURE/VirtualCellDecoder.py +658 -0
- SURE/__init__.py +13 -1
- {sure_tools-2.4.7.dist-info → sure_tools-2.4.42.dist-info}/METADATA +1 -1
- {sure_tools-2.4.7.dist-info → sure_tools-2.4.42.dist-info}/RECORD +15 -9
- {sure_tools-2.4.7.dist-info → sure_tools-2.4.42.dist-info}/WHEEL +0 -0
- {sure_tools-2.4.7.dist-info → sure_tools-2.4.42.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.4.7.dist-info → sure_tools-2.4.42.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.4.7.dist-info → sure_tools-2.4.42.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,567 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
import torch.optim as optim
|
|
5
|
+
from torch.utils.data import Dataset, DataLoader
|
|
6
|
+
import numpy as np
|
|
7
|
+
from typing import Dict, List, Optional, Tuple
|
|
8
|
+
import warnings
|
|
9
|
+
# NOTE(review): this installs an "ignore" filter for every warning category
# as soon as the module is imported. It affects the whole interpreter,
# including the warnings.warn() call in SimpleTranscriptomeDecoder.predict().
# Consider scoping it with warnings.catch_warnings() instead.
warnings.simplefilter('ignore')
|
|
10
|
+
|
|
11
|
+
class SimpleTranscriptomeDecoder:
    """MLP-based transcriptome decoder for latent to expression mapping.

    Holds a PyTorch model (``self.model``, built from the nested ``MLPModel``)
    together with a training loop (validation, early stopping,
    checkpointing), batched prediction and checkpoint loading.
    """

    def __init__(self,
                 latent_dim: int = 100,
                 gene_dim: int = 60000,
                 hidden_dims: Optional[List[int]] = None,
                 dropout_rate: float = 0.1,
                 device: str = None):
        """
        Multi-Layer Perceptron based decoder for transcriptome prediction

        Args:
            latent_dim: Latent variable dimension
            gene_dim: Number of genes (full transcriptome)
            hidden_dims: List of hidden layer dimensions. ``None`` (the
                default) means ``[512, 1024, 2048, 4096]``. The list is
                copied, so later changes to the caller's list have no effect.
            dropout_rate: Dropout rate for regularization
            device: Computation device; ``None`` selects ``'cuda'`` when
                available, otherwise ``'cpu'``.
        """
        # Avoid a shared mutable default argument. Copy caller-provided lists
        # so that external mutation cannot desynchronise the config from the model.
        if hidden_dims is None:
            hidden_dims = [512, 1024, 2048, 4096]
        self.latent_dim = latent_dim
        self.gene_dim = gene_dim
        self.hidden_dims = list(hidden_dims)
        self.dropout_rate = dropout_rate
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

        # Initialize model
        self.model = self._build_mlp_model()
        self.model.to(self.device)

        # Training state
        self.is_trained = False
        self.training_history = None
        self.best_val_loss = float('inf')

        print(f"🚀 SimpleTranscriptomeDecoder Initialized:")
        print(f" - Latent Dimension: {latent_dim}")
        print(f" - Gene Dimension: {gene_dim}")
        print(f" - Hidden Dimensions: {hidden_dims}")
        print(f" - Dropout Rate: {dropout_rate}")
        print(f" - Device: {self.device}")
        print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")

    class MLPModel(nn.Module):
        """MLP-based decoder architecture.

        Layer widths are
        ``latent_dim -> hidden_dims[0] -> ... -> hidden_dims[-1]``, then back
        down through the reversed ``hidden_dims`` (the widest layer is not
        repeated), then a final linear projection to ``gene_dim``. Each
        hidden layer is Linear -> BatchNorm1d -> GELU -> Dropout. The output
        goes through a learnable scalar affine map and softplus, so it is
        non-negative.
        """

        def __init__(self, latent_dim: int, gene_dim: int, hidden_dims: List[int], dropout_rate: float):
            super().__init__()
            self.latent_dim = latent_dim
            self.gene_dim = gene_dim

            layers = []
            input_dim = latent_dim

            # Expanding part: widen the latent representation
            for hidden_dim in hidden_dims:
                layers.extend([
                    nn.Linear(input_dim, hidden_dim),
                    nn.BatchNorm1d(hidden_dim),
                    nn.GELU(),
                    nn.Dropout(dropout_rate)
                ])
                input_dim = hidden_dim

            # Contracting part: walk back down the reversed hidden_dims,
            # skipping the widest layer we just produced.
            decoder_dims = hidden_dims[::-1]
            for hidden_dim in decoder_dims[1:]:
                layers.extend([
                    nn.Linear(input_dim, hidden_dim),
                    nn.BatchNorm1d(hidden_dim),
                    nn.GELU(),
                    nn.Dropout(dropout_rate)
                ])
                input_dim = hidden_dim

            # Final projection to gene dimension
            layers.append(nn.Linear(input_dim, gene_dim))

            self.mlp_layers = nn.Sequential(*layers)

            # Learnable scalar output scale/shift applied before softplus
            self.output_scale = nn.Parameter(torch.ones(1))
            self.output_bias = nn.Parameter(torch.zeros(1))

            self._init_weights()

        def _init_weights(self):
            """Xavier-uniform weights and zero biases for every Linear layer."""
            for module in self.modules():
                if isinstance(module, nn.Linear):
                    nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Map a latent batch to non-negative expression values."""
            output = self.mlp_layers(x)
            # Ensure non-negative output with softplus
            output = F.softplus(output * self.output_scale + self.output_bias)
            return output

    class ResidualMLPModel(nn.Module):
        """Residual MLP decoder with skip connections.

        NOTE(review): ``_build_mlp_model`` does not use this class (it builds
        ``MLPModel``); it is kept as an alternative architecture.
        A skip connection is added only when a block's output width equals
        its input width. That happens when ``hidden_dims[0] == latent_dim`` or
        when consecutive ``hidden_dims`` entries are equal.
        """

        def __init__(self, latent_dim: int, gene_dim: int, hidden_dims: List[int], dropout_rate: float):
            super().__init__()
            self.latent_dim = latent_dim
            self.gene_dim = gene_dim

            # One two-layer block per entry of hidden_dims
            self.blocks = nn.ModuleList()
            input_dim = latent_dim

            for hidden_dim in hidden_dims:
                block = self._build_residual_block(input_dim, hidden_dim, dropout_rate)
                self.blocks.append(block)
                input_dim = hidden_dim

            # Final projection to gene dimension via a half-width layer
            self.final_projection = nn.Sequential(
                nn.Linear(input_dim, input_dim // 2),
                nn.GELU(),
                nn.Dropout(dropout_rate),
                nn.Linear(input_dim // 2, gene_dim)
            )

            # Learnable scalar output scale/shift applied before softplus
            self.output_scale = nn.Parameter(torch.ones(1))
            self.output_bias = nn.Parameter(torch.zeros(1))

            self._init_weights()

        def _build_residual_block(self, input_dim: int, hidden_dim: int, dropout_rate: float) -> nn.Module:
            """Build the (non-skip) body of one block: two Linear/BN/GELU/Dropout stages.

            The skip connection itself is added in ``forward``.
            """
            return nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout_rate),
                nn.Linear(hidden_dim, hidden_dim),  # second stage of the block body
                nn.BatchNorm1d(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout_rate),
            )

        def _init_weights(self):
            """Kaiming-uniform weights (ReLU gain) and zero biases for every Linear layer."""
            for module in self.modules():
                if isinstance(module, nn.Linear):
                    nn.init.kaiming_uniform_(module.weight, nonlinearity='relu')
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Map a latent batch to non-negative expression values."""
            # The skip path starts from the raw input
            identity = x

            for block in self.blocks:
                out = block(x)
                # Add the skip connection only when the widths match
                if out.shape[1] == identity.shape[1]:
                    x = out + identity
                else:
                    x = out
                identity = x

            output = self.final_projection(x)
            output = F.softplus(output * self.output_scale + self.output_bias)

            return output

    def _build_mlp_model(self):
        """Build the decoder network (an ``MLPModel``) from the current config.

        Named ``_build_mlp_model`` to avoid a name clash with ``MLPModel``.
        """
        # Use simple MLP model for stability
        return self.MLPModel(
            self.latent_dim,
            self.gene_dim,
            self.hidden_dims,
            self.dropout_rate
        )

    def train(self,
              train_latent: np.ndarray,
              train_expression: np.ndarray,
              val_latent: np.ndarray = None,
              val_expression: np.ndarray = None,
              batch_size: int = 32,
              num_epochs: int = 100,
              learning_rate: float = 1e-4,
              weight_decay: float = 1e-5,
              checkpoint_path: Optional[str] = 'mlp_decoder.pth') -> Dict:
        """
        Train the MLP decoder model

        Optimises MSE + 0.3 * Poisson NLL + 0.1 * (1 - Pearson correlation)
        with AdamW and a per-batch OneCycleLR schedule. Early stopping uses a
        patience of 20 epochs. The weights with the lowest validation loss are
        saved to ``checkpoint_path`` whenever they improve. They are also
        restored into ``self.model`` when training ends, so the in-memory
        model matches ``self.best_val_loss``.

        Args:
            train_latent: Training latent variables [n_samples, latent_dim]
            train_expression: Training expression data [n_samples, gene_dim]
            val_latent: Validation latent variables (optional)
            val_expression: Validation expression data (optional). If either
                validation array is missing, 10% of the training data is
                randomly held out instead.
            batch_size: Batch size for training
            num_epochs: Number of training epochs
            learning_rate: Learning rate (OneCycleLR peaks at 10x this value)
            weight_decay: Weight decay for regularization
            checkpoint_path: Path to save the best model; ``None`` disables
                writing checkpoints to disk.

        Returns:
            Training history dictionary with per-epoch lists under
            'train_loss', 'val_loss', 'train_mse', 'val_mse',
            'train_correlation', 'val_correlation' and 'learning_rates'.

        Raises:
            ValueError: If the input arrays have inconsistent shapes.
        """
        print("🚀 Starting MLP Decoder Training...")

        has_val = val_latent is not None and val_expression is not None
        if not has_val and (val_latent is not None or val_expression is not None):
            print("⚠️ Only one of val_latent / val_expression was given; "
                  "falling back to an automatic validation split.")

        # Data validation
        self._validate_input_data(train_latent, train_expression, "Training")
        if has_val:
            self._validate_input_data(val_latent, val_expression, "Validation")

        # Create datasets
        train_dataset = self._create_dataset(train_latent, train_expression)

        if has_val:
            fit_dataset = train_dataset
            val_dataset = self._create_dataset(val_latent, val_expression)
            print(f"📈 Using provided validation data: {len(val_dataset)} samples")
        else:
            # Auto split: 90% train / 10% validation
            train_size = int(0.9 * len(train_dataset))
            val_size = len(train_dataset) - train_size
            fit_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
            print(f"📈 Auto-split validation: {val_size} samples")

        # BatchNorm1d raises on a training batch of a single sample, so drop
        # the trailing batch only when it would contain exactly one sample.
        drop_last = len(fit_dataset) > batch_size and len(fit_dataset) % batch_size == 1
        train_loader = DataLoader(fit_dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

        print(f"📊 Training samples: {len(train_loader.dataset)}")
        print(f"📊 Validation samples: {len(val_loader.dataset)}")
        print(f"📊 Batch size: {batch_size}")

        # Optimizer configuration
        optimizer = optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.999)
        )

        # Learning rate scheduler (stepped once per batch in _train_epoch)
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=learning_rate * 10,
            epochs=num_epochs,
            steps_per_epoch=len(train_loader)
        )

        # Loss function combining MSE, Poisson and correlation losses
        def combined_loss(pred, target):
            mse_loss = F.mse_loss(pred, target)
            poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()
            correlation_loss = 1 - self._pearson_correlation(pred, target)
            return mse_loss + 0.3 * poisson_loss + 0.1 * correlation_loss

        # Training history
        history = {
            'train_loss': [], 'val_loss': [],
            'train_mse': [], 'val_mse': [],
            'train_correlation': [], 'val_correlation': [],
            'learning_rates': []
        }

        best_val_loss = float('inf')
        best_state = None
        patience = 20
        patience_counter = 0

        print("\n📈 Starting training loop...")
        for epoch in range(1, num_epochs + 1):
            train_metrics = self._train_epoch(train_loader, optimizer, scheduler, combined_loss)
            val_metrics = self._validate_epoch(val_loader, combined_loss)

            # Record history
            history['train_loss'].append(train_metrics['total_loss'])
            history['train_mse'].append(train_metrics['mse_loss'])
            history['train_correlation'].append(train_metrics['correlation'])

            history['val_loss'].append(val_metrics['val_loss'])
            history['val_mse'].append(val_metrics['val_mse'])
            history['val_correlation'].append(val_metrics['val_correlation'])

            history['learning_rates'].append(optimizer.param_groups[0]['lr'])

            # Print progress
            if epoch % 10 == 0 or epoch == 1:
                print(f"📍 Epoch {epoch:3d}/{num_epochs} | "
                      f"Train Loss: {train_metrics['total_loss']:.4f} | "
                      f"Val Loss: {val_metrics['val_loss']:.4f} | "
                      f"Correlation: {val_metrics['val_correlation']:.4f} | "
                      f"LR: {optimizer.param_groups[0]['lr']:.2e}")

            # Early stopping and model saving
            if val_metrics['val_loss'] < best_val_loss:
                best_val_loss = val_metrics['val_loss']
                patience_counter = 0
                # Keep a CPU copy of the best weights so they can be restored
                # after training, independently of the checkpoint file.
                best_state = {k: v.detach().cpu().clone() for k, v in self.model.state_dict().items()}
                if checkpoint_path is not None:
                    self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
                    if epoch % 20 == 0:
                        print(f"💾 Best model saved (Val Loss: {best_val_loss:.4f})")
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"🛑 Early stopping at epoch {epoch}")
                    break

        # Restore the best weights so predictions match best_val_loss
        if best_state is not None:
            self.model.load_state_dict(best_state)

        # Training completed
        self.is_trained = True
        self.training_history = history
        self.best_val_loss = best_val_loss

        print(f"\n🎉 Training completed!")
        print(f"🏆 Best validation loss: {best_val_loss:.4f}")

        return history

    def _validate_input_data(self, latent_data: np.ndarray, expression_data: np.ndarray, data_type: str):
        """Check that latent/expression arrays are 2-D with matching sizes.

        Raises:
            ValueError: On wrong rank, wrong latent/gene width, or a sample
                count mismatch. (Explicit raises, unlike ``assert``, are not
                stripped under ``python -O``.)
        """
        latent_shape = np.shape(latent_data)
        expression_shape = np.shape(expression_data)
        if len(latent_shape) != 2 or len(expression_shape) != 2:
            raise ValueError(
                f"{data_type} data must be 2-D, got shapes {tuple(latent_shape)} and {tuple(expression_shape)}")
        if latent_shape[1] != self.latent_dim:
            raise ValueError(
                f"{data_type} latent dimension mismatch: expected {self.latent_dim}, got {latent_shape[1]}")
        if expression_shape[1] != self.gene_dim:
            raise ValueError(
                f"{data_type} gene dimension mismatch: expected {self.gene_dim}, got {expression_shape[1]}")
        if latent_shape[0] != expression_shape[0]:
            raise ValueError(f"{data_type} sample count mismatch")
        print(f"✅ {data_type} data validated: {latent_shape[0]} samples")

    def _create_dataset(self, latent_data: np.ndarray, expression_data: np.ndarray) -> Dataset:
        """Wrap paired latent/expression arrays in a float32 PyTorch Dataset."""
        class TranscriptomeDataset(Dataset):
            def __init__(self, latent, expression):
                self.latent = torch.FloatTensor(latent)
                self.expression = torch.FloatTensor(expression)

            def __len__(self):
                return len(self.latent)

            def __getitem__(self, idx):
                return self.latent[idx], self.expression[idx]

        return TranscriptomeDataset(latent_data, expression_data)

    def _pearson_correlation(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Mean over rows (dim 0) of the per-row Pearson correlation along dim 1.

        A 1e-8 epsilon in the denominator avoids division by zero for
        constant rows.
        """
        pred_centered = pred - pred.mean(dim=1, keepdim=True)
        target_centered = target - target.mean(dim=1, keepdim=True)

        numerator = (pred_centered * target_centered).sum(dim=1)
        denominator = torch.sqrt(torch.sum(pred_centered ** 2, dim=1)) * torch.sqrt(torch.sum(target_centered ** 2, dim=1))

        return (numerator / (denominator + 1e-8)).mean()

    def _train_epoch(self, train_loader, optimizer, scheduler, loss_fn):
        """Run one optimisation pass over ``train_loader``.

        The scheduler is stepped once per batch, matching the per-batch
        OneCycleLR schedule built in ``train``.

        Returns:
            Dict with batch-averaged 'total_loss', 'mse_loss' and 'correlation'.
        """
        self.model.train()
        total_loss = 0.0
        total_mse = 0.0
        total_correlation = 0.0

        for latent, target in train_loader:
            latent = latent.to(self.device)
            target = target.to(self.device)

            optimizer.zero_grad()
            pred = self.model(latent)

            loss = loss_fn(pred, target)
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            # Monitoring metrics only: no autograd graph needed
            with torch.no_grad():
                pred_detached = pred.detach()
                total_mse += F.mse_loss(pred_detached, target).item()
                total_correlation += self._pearson_correlation(pred_detached, target).item()
            total_loss += loss.item()

        # Guard against an empty loader (avoids ZeroDivisionError)
        num_batches = max(len(train_loader), 1)
        return {
            'total_loss': total_loss / num_batches,
            'mse_loss': total_mse / num_batches,
            'correlation': total_correlation / num_batches
        }

    def _validate_epoch(self, val_loader, loss_fn):
        """Evaluate on ``val_loader`` in eval mode without gradients.

        Returns:
            Dict with batch-averaged 'val_loss', 'val_mse' and 'val_correlation'.
        """
        self.model.eval()
        total_loss = 0.0
        total_mse = 0.0
        total_correlation = 0.0

        with torch.no_grad():
            for latent, target in val_loader:
                latent = latent.to(self.device)
                target = target.to(self.device)

                pred = self.model(latent)

                total_loss += loss_fn(pred, target).item()
                total_mse += F.mse_loss(pred, target).item()
                total_correlation += self._pearson_correlation(pred, target).item()

        # Guard against an empty loader (avoids ZeroDivisionError)
        num_batches = max(len(val_loader), 1)
        return {
            'val_loss': total_loss / num_batches,
            'val_mse': total_mse / num_batches,
            'val_correlation': total_correlation / num_batches
        }

    def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
        """Save model/optimizer/scheduler state, history and config to ``path`` via ``torch.save``."""
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_loss': best_loss,
            'training_history': history,
            'model_config': {
                'latent_dim': self.latent_dim,
                'gene_dim': self.gene_dim,
                'hidden_dims': self.hidden_dims,
                'dropout_rate': self.dropout_rate
            }
        }, path)

    def predict(self, latent_data: np.ndarray, batch_size: int = 32) -> np.ndarray:
        """
        Predict gene expression from latent variables

        Args:
            latent_data: Latent variables [n_samples, latent_dim]; a numpy
                array, torch tensor or nested list (converted to float32).
            batch_size: Prediction batch size

        Returns:
            expression: Predicted expression [n_samples, gene_dim] as a
            float32 numpy array (shape ``(0, gene_dim)`` for empty input).
        """
        if not self.is_trained:
            warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")

        self.model.eval()

        # The model weights are float32, so convert any input to float32
        if isinstance(latent_data, torch.Tensor):
            latent_tensor = latent_data.to(dtype=torch.float32)
        else:
            latent_tensor = torch.as_tensor(np.asarray(latent_data), dtype=torch.float32)

        if len(latent_tensor) == 0:
            # torch.cat([]) would raise; return an empty result instead
            return np.empty((0, self.gene_dim), dtype=np.float32)

        # Predict in batches
        predictions = []
        with torch.no_grad():
            for start in range(0, len(latent_tensor), batch_size):
                batch_latent = latent_tensor[start:start + batch_size].to(self.device)
                predictions.append(self.model(batch_latent).cpu())

        return torch.cat(predictions).numpy()

    def load_model(self, model_path: str):
        """Load a checkpoint written by ``_save_checkpoint``.

        If the checkpoint's ``model_config`` describes a different architecture
        (latent_dim, gene_dim or hidden_dims), this decoder adopts the
        checkpoint's configuration and rebuilds ``self.model`` before loading
        the weights.

        Args:
            model_path: Path to the checkpoint file.
        """
        # NOTE(security): torch.load is pickle-based; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(model_path, map_location=self.device)

        # Check model configuration
        if 'model_config' in checkpoint:
            config = checkpoint['model_config']
            checkpoint_hidden = list(config.get('hidden_dims', self.hidden_dims))
            if (config['latent_dim'] != self.latent_dim or
                    config['gene_dim'] != self.gene_dim or
                    checkpoint_hidden != list(self.hidden_dims)):
                print("⚠️ Model configuration mismatch. Reinitializing model.")
                # Rebuild with the checkpoint's architecture. Rebuilding with
                # our own (mismatched) settings would make load_state_dict fail.
                self.latent_dim = config['latent_dim']
                self.gene_dim = config['gene_dim']
                self.hidden_dims = checkpoint_hidden
                self.dropout_rate = config.get('dropout_rate', self.dropout_rate)
                self.model = self._build_mlp_model()
                self.model.to(self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.is_trained = True
        self.training_history = checkpoint.get('training_history')
        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))

        print(f"✅ Model loaded successfully!")
        print(f"🏆 Best validation loss: {self.best_val_loss:.4f}")

    def get_model_info(self) -> Dict:
        """Return a summary dict of training state and model configuration."""
        return {
            'is_trained': self.is_trained,
            'best_val_loss': self.best_val_loss,
            'parameters': sum(p.numel() for p in self.model.parameters()),
            'latent_dim': self.latent_dim,
            'gene_dim': self.gene_dim,
            'hidden_dims': self.hidden_dims,
            'dropout_rate': self.dropout_rate,
            'device': str(self.device)
        }
|
|
517
|
+
|
|
518
|
+
# Example usage of SimpleTranscriptomeDecoder. NOTE: the whole example below is
# wrapped in a bare string literal, so it never executes. That holds both on
# import and when this file is run as a script, because the
# `if __name__ == "__main__":` guard is itself inside the string. To try the
# example, copy the code out of the string.
'''
# Example usage
def example_usage():
    """Example demonstration of MLP decoder"""

    # 1. Initialize decoder
    decoder = SimpleTranscriptomeDecoder(
        latent_dim=100,
        gene_dim=2000, # Reduced for example
        hidden_dims=[256, 512, 1024], # Progressive expansion
        dropout_rate=0.1
    )

    # 2. Generate example data
    n_samples = 1000
    latent_data = np.random.randn(n_samples, 100).astype(np.float32)

    # Create simulated expression data
    weights = np.random.randn(100, 2000) * 0.1
    expression_data = np.tanh(latent_data.dot(weights))
    expression_data = np.maximum(expression_data, 0) # Non-negative

    print(f"📊 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")

    # 3. Train the model
    history = decoder.train(
        train_latent=latent_data,
        train_expression=expression_data,
        batch_size=32,
        num_epochs=50,
        learning_rate=1e-4
    )

    # 4. Make predictions
    test_latent = np.random.randn(10, 100).astype(np.float32)
    predictions = decoder.predict(test_latent)
    print(f"🔮 Prediction shape: {predictions.shape}")

    # 5. Get model info
    info = decoder.get_model_info()
    print(f"\n📋 Model Info:")
    for key, value in info.items():
        print(f" {key}: {value}")

    return decoder

if __name__ == "__main__":
    example_usage()

'''
|
SURE/TranscriptomeDecoder.py
CHANGED
|
@@ -215,10 +215,22 @@ class TranscriptomeDecoder:
|
|
|
215
215
|
min_dim = min(pred.shape[1], target.shape[1])
|
|
216
216
|
pred = pred[:, :min_dim]
|
|
217
217
|
target = target[:, :min_dim]
|
|
218
|
+
|
|
219
|
+
def correlation_loss(pred, target):
|
|
220
|
+
pred_centered = pred - pred.mean(dim=1, keepdim=True)
|
|
221
|
+
target_centered = target - target.mean(dim=1, keepdim=True)
|
|
222
|
+
|
|
223
|
+
correlation = (pred_centered * target_centered).sum(dim=1) / (
|
|
224
|
+
torch.sqrt(torch.sum(pred_centered ** 2, dim=1)) *
|
|
225
|
+
torch.sqrt(torch.sum(target_centered ** 2, dim=1)) + 1e-8
|
|
226
|
+
)
|
|
227
|
+
|
|
228
|
+
return 1 - correlation.mean()
|
|
218
229
|
|
|
219
230
|
mse_loss = F.mse_loss(pred, target)
|
|
220
231
|
poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()
|
|
221
|
-
|
|
232
|
+
corr_loss = correlation_loss(pred, target)
|
|
233
|
+
return mse_loss + 0.5 * poisson_loss + 0.3 * corr_loss
|
|
222
234
|
|
|
223
235
|
# Training history
|
|
224
236
|
history = {
|