SURE-tools 2.4.5__py3-none-any.whl → 2.4.17__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Potentially problematic release: this version of SURE-tools might be problematic.
- SURE/EfficientTranscriptomeDecoder.py +552 -0
- SURE/SimpleTranscriptomeDecoder.py +567 -0
- SURE/TranscriptomeDecoder.py +273 -311
- SURE/VirtualCellDecoder.py +659 -0
- SURE/__init__.py +8 -1
- {sure_tools-2.4.5.dist-info → sure_tools-2.4.17.dist-info}/METADATA +1 -1
- {sure_tools-2.4.5.dist-info → sure_tools-2.4.17.dist-info}/RECORD +11 -8
- {sure_tools-2.4.5.dist-info → sure_tools-2.4.17.dist-info}/WHEEL +0 -0
- {sure_tools-2.4.5.dist-info → sure_tools-2.4.17.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.4.5.dist-info → sure_tools-2.4.17.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.4.5.dist-info → sure_tools-2.4.17.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,552 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from typing import Dict, List, Optional, Tuple
import math
import warnings
warnings.filterwarnings('ignore')

class EfficientTranscriptomeDecoder:
    """
    High-performance, memory-efficient transcriptome decoder
    Fixed version with corrected RMSNorm implementation
    """

    def __init__(self,
                 latent_dim: int = 100,
                 gene_dim: int = 60000,
                 hidden_dims: List[int] = [512, 1024, 2048],
                 bottleneck_dim: int = 256,
                 num_experts: int = 4,
                 dropout_rate: float = 0.1,
                 device: str = None):
        """
        Advanced decoder combining multiple state-of-the-art techniques

        Args:
            latent_dim: Latent variable dimension
            gene_dim: Number of genes (full transcriptome)
            hidden_dims: Hidden layer dimensions
            bottleneck_dim: Bottleneck dimension for memory efficiency
            num_experts: Number of mixture-of-experts
            dropout_rate: Dropout rate
            device: Computation device
        """
        self.latent_dim = latent_dim
        self.gene_dim = gene_dim
        self.hidden_dims = hidden_dims
        self.bottleneck_dim = bottleneck_dim
        self.num_experts = num_experts
        self.dropout_rate = dropout_rate
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

        # Initialize model with corrected architecture
        self.model = self._build_corrected_model()
        self.model.to(self.device)

        # Training state
        self.is_trained = False
        self.training_history = None
        self.best_val_loss = float('inf')

        print(f"🚀 EfficientTranscriptomeDecoder Initialized:")
        print(f" - Latent Dimension: {latent_dim}")
        print(f" - Gene Dimension: {gene_dim}")
        print(f" - Hidden Dimensions: {hidden_dims}")
        print(f" - Bottleneck Dimension: {bottleneck_dim}")
        print(f" - Number of Experts: {num_experts}")
        print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")

    class CorrectedRMSNorm(nn.Module):
        """Corrected RMS Normalization with proper dimension handling"""
        def __init__(self, dim: int, eps: float = 1e-8):
            super().__init__()
            self.eps = eps
            self.dim = dim
            self.weight = nn.Parameter(torch.ones(dim))  # Correct: weight has same dim as input

        def forward(self, x):
            # Ensure input has the right dimension
            if x.size(-1) != self.dim:
                raise ValueError(f"Input dimension {x.size(-1)} doesn't match RMSNorm dimension {self.dim}")

            # Calculate RMS
            rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
            # Normalize and apply weight
            return x / rms * self.weight

    class SimplifiedSwiGLU(nn.Module):
        """Simplified SwiGLU activation"""
        def forward(self, x):
            # Split into two parts
            x, gate = x.chunk(2, dim=-1)
            return x * F.silu(gate)

    class MemoryEfficientBottleneck(nn.Module):
        """Memory-efficient bottleneck with corrected dimensions"""
        def __init__(self, input_dim: int, bottleneck_dim: int, output_dim: int):
            super().__init__()
            # Ensure proper dimension matching
            self.compress = nn.Linear(input_dim, bottleneck_dim)
            self.norm1 = EfficientTranscriptomeDecoder.CorrectedRMSNorm(bottleneck_dim)
            self.expand = nn.Linear(bottleneck_dim, output_dim)
            self.norm2 = EfficientTranscriptomeDecoder.CorrectedRMSNorm(output_dim)
            self.activation = nn.SiLU()
            self.dropout = nn.Dropout(0.1)

        def forward(self, x):
            # Compress
            compressed = self.compress(x)
            compressed = self.norm1(compressed)
            compressed = self.activation(compressed)
            compressed = self.dropout(compressed)

            # Expand
            expanded = self.expand(compressed)
            expanded = self.norm2(expanded)

            return expanded

    class StableMixtureOfExperts(nn.Module):
        """Stable mixture of experts without dimension issues"""
        def __init__(self, input_dim: int, num_experts: int = 4):
            super().__init__()
            self.num_experts = num_experts
            self.input_dim = input_dim

            # Shared expert with different scaling factors
            self.shared_expert = nn.Sequential(
                nn.Linear(input_dim, input_dim * 2),
                nn.SiLU(),
                nn.Dropout(0.1),
                nn.Linear(input_dim * 2, input_dim)
            )

            # Gating network
            self.gate = nn.Sequential(
                nn.Linear(input_dim, num_experts * 4),
                nn.SiLU(),
                nn.Linear(num_experts * 4, num_experts)
            )

        def forward(self, x):
            # Get gate weights
            gate_weights = F.softmax(self.gate(x), dim=-1)  # [batch, num_experts]

            # Process through shared expert
            expert_output = self.shared_expert(x)  # [batch, input_dim]

            # Apply expert-specific scaling
            weighted_output = torch.zeros_like(expert_output)
            for i in range(self.num_experts):
                expert_scale = 0.5 + 0.5 * i  # Different scaling for each expert
                expert_contribution = expert_output * expert_scale
                expert_weight = gate_weights[:, i].unsqueeze(-1)  # [batch, 1]
                weighted_output += expert_weight * expert_contribution

            # Residual connection
            return x + weighted_output

    class CorrectedDecoder(nn.Module):
        """Corrected decoder with proper dimension handling"""

        def __init__(self, latent_dim: int, gene_dim: int, hidden_dims: List[int],
                     bottleneck_dim: int, num_experts: int, dropout_rate: float):
            super().__init__()

            # Input projection
            self.input_projection = nn.Sequential(
                nn.Linear(latent_dim, hidden_dims[0]),
                EfficientTranscriptomeDecoder.CorrectedRMSNorm(hidden_dims[0]),
                nn.SiLU(),
                nn.Dropout(dropout_rate)
            )

            # Main processing blocks
            self.blocks = nn.ModuleList()
            current_dim = hidden_dims[0]

            for i, next_dim in enumerate(hidden_dims[1:], 1):
                block = nn.ModuleDict({
                    'swiglu': nn.Sequential(
                        nn.Linear(current_dim, current_dim * 2),
                        EfficientTranscriptomeDecoder.SimplifiedSwiGLU(),
                        nn.Dropout(dropout_rate),
                        nn.Linear(current_dim, current_dim)  # Project back to same dimension
                    ),
                    'bottleneck': EfficientTranscriptomeDecoder.MemoryEfficientBottleneck(
                        current_dim, bottleneck_dim, next_dim
                    ),
                    'experts': EfficientTranscriptomeDecoder.StableMixtureOfExperts(
                        next_dim, num_experts
                    )
                })
                self.blocks.append(block)
                current_dim = next_dim

            # Final projection to gene dimension
            self.final_projection = nn.Sequential(
                nn.Linear(current_dim, current_dim * 2),
                nn.SiLU(),
                nn.Dropout(dropout_rate),
                nn.Linear(current_dim * 2, gene_dim)
            )

            # Output parameters
            self.output_scale = nn.Parameter(torch.ones(1))
            self.output_bias = nn.Parameter(torch.zeros(1))

            self._init_weights()

        def _init_weights(self):
            """Proper weight initialization"""
            for module in self.modules():
                if isinstance(module, nn.Linear):
                    nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)

        def forward(self, x):
            # Input projection
            x = self.input_projection(x)

            # Process through blocks
            for block in self.blocks:
                # SwiGLU with residual
                residual = x
                x_swiglu = block['swiglu'](x)
                x = x + x_swiglu  # Residual connection

                # Bottleneck
                x = block['bottleneck'](x)

                # Mixture of Experts with residual
                x = block['experts'](x)

            # Final projection
            x = self.final_projection(x)

            # Ensure non-negative output
            x = F.softplus(x * self.output_scale + self.output_bias)

            return x

    def _build_corrected_model(self):
        """Build the corrected model"""
        return self.CorrectedDecoder(
            self.latent_dim, self.gene_dim, self.hidden_dims,
            self.bottleneck_dim, self.num_experts, self.dropout_rate
        )

    def train(self,
              train_latent: np.ndarray,
              train_expression: np.ndarray,
              val_latent: np.ndarray = None,
              val_expression: np.ndarray = None,
              batch_size: int = 32,
              num_epochs: int = 100,
              learning_rate: float = 1e-4,
              checkpoint_path: str = 'transcriptome_decoder.pth') -> Dict:
        """
        Train the corrected decoder
        """
        print("🚀 Starting Training...")

        # Data preparation
        train_dataset = self._create_dataset(train_latent, train_expression)

        if val_latent is not None and val_expression is not None:
            val_dataset = self._create_dataset(val_latent, val_expression)
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
            print(f"📈 Using provided validation data: {len(val_dataset)} samples")
        else:
            # Auto split
            train_size = int(0.9 * len(train_dataset))
            val_size = len(train_dataset) - train_size
            train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
            train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, pin_memory=True)
            val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, pin_memory=True)
            print(f"📈 Auto-split validation: {val_size} samples")

        print(f"📊 Training samples: {len(train_loader.dataset)}")
        print(f"📊 Validation samples: {len(val_loader.dataset)}")
        print(f"📊 Batch size: {batch_size}")

        # Optimizer
        optimizer = optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=0.01,
            betas=(0.9, 0.999)
        )

        # Scheduler
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

        # Loss function
        def combined_loss(pred, target):
            mse_loss = F.mse_loss(pred, target)
            poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()
            correlation = self._pearson_correlation(pred, target)
            correlation_loss = 1 - correlation
            return mse_loss + 0.3 * poisson_loss + 0.1 * correlation_loss

        # Training history
        history = {
            'train_loss': [], 'val_loss': [],
            'train_mse': [], 'val_mse': [],
            'train_correlation': [], 'val_correlation': [],
            'learning_rates': []
        }

        best_val_loss = float('inf')
        patience = 20
        patience_counter = 0

        print("\n📈 Starting training loop...")
        for epoch in range(1, num_epochs + 1):
            # Training
            train_metrics = self._train_epoch(train_loader, optimizer, combined_loss)

            # Validation
            val_metrics = self._validate_epoch(val_loader, combined_loss)

            # Update scheduler
            scheduler.step()
            current_lr = optimizer.param_groups[0]['lr']

            # Record history
            history['train_loss'].append(train_metrics['loss'])
            history['val_loss'].append(val_metrics['loss'])
            history['train_mse'].append(train_metrics['mse'])
            history['val_mse'].append(val_metrics['mse'])
            history['train_correlation'].append(train_metrics['correlation'])
            history['val_correlation'].append(val_metrics['correlation'])
            history['learning_rates'].append(current_lr)

            # Print progress
            if epoch % 10 == 0 or epoch == 1:
                print(f"📍 Epoch {epoch:3d}/{num_epochs} | "
                      f"Train Loss: {train_metrics['loss']:.4f} | "
                      f"Val Loss: {val_metrics['loss']:.4f} | "
                      f"Correlation: {val_metrics['correlation']:.4f} | "
                      f"LR: {current_lr:.2e}")

            # Early stopping
            if val_metrics['loss'] < best_val_loss:
                best_val_loss = val_metrics['loss']
                patience_counter = 0
                self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
                if epoch % 20 == 0:
                    print(f"💾 Best model saved (Val Loss: {best_val_loss:.4f})")
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"🛑 Early stopping at epoch {epoch}")
                    break

        # Training completed
        self.is_trained = True
        self.training_history = history
        self.best_val_loss = best_val_loss

        print(f"\n🎉 Training completed!")
        print(f"🏆 Best validation loss: {best_val_loss:.4f}")

        return history

    def _create_dataset(self, latent_data, expression_data):
        """Create dataset"""
        class SimpleDataset(Dataset):
            def __init__(self, latent, expression):
                self.latent = torch.FloatTensor(latent)
                self.expression = torch.FloatTensor(expression)

            def __len__(self):
                return len(self.latent)

            def __getitem__(self, idx):
                return self.latent[idx], self.expression[idx]

        return SimpleDataset(latent_data, expression_data)

    def _pearson_correlation(self, pred, target):
        """Calculate Pearson correlation"""
        pred_centered = pred - pred.mean(dim=1, keepdim=True)
        target_centered = target - target.mean(dim=1, keepdim=True)

        numerator = (pred_centered * target_centered).sum(dim=1)
        denominator = torch.sqrt(torch.sum(pred_centered ** 2, dim=1)) * torch.sqrt(torch.sum(target_centered ** 2, dim=1))

        return (numerator / (denominator + 1e-8)).mean()

    def _train_epoch(self, train_loader, optimizer, loss_fn):
        """Train one epoch"""
        self.model.train()
        total_loss = 0
        total_mse = 0
        total_correlation = 0

        for latent, target in train_loader:
            latent = latent.to(self.device, non_blocking=True)
            target = target.to(self.device, non_blocking=True)

            optimizer.zero_grad()
            pred = self.model(latent)

            loss = loss_fn(pred, target)
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            optimizer.step()

            # Calculate metrics
            mse_loss = F.mse_loss(pred, target).item()
            correlation = self._pearson_correlation(pred, target).item()

            total_loss += loss.item()
            total_mse += mse_loss
            total_correlation += correlation

        num_batches = len(train_loader)
        return {
            'loss': total_loss / num_batches,
            'mse': total_mse / num_batches,
            'correlation': total_correlation / num_batches
        }

    def _validate_epoch(self, val_loader, loss_fn):
        """Validate one epoch"""
        self.model.eval()
        total_loss = 0
        total_mse = 0
        total_correlation = 0

        with torch.no_grad():
            for latent, target in val_loader:
                latent = latent.to(self.device, non_blocking=True)
                target = target.to(self.device, non_blocking=True)

                pred = self.model(latent)
                loss = loss_fn(pred, target)
                mse_loss = F.mse_loss(pred, target).item()
                correlation = self._pearson_correlation(pred, target).item()

                total_loss += loss.item()
                total_mse += mse_loss
                total_correlation += correlation

        num_batches = len(val_loader)
        return {
            'loss': total_loss / num_batches,
            'mse': total_mse / num_batches,
            'correlation': total_correlation / num_batches
        }

    def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
        """Save checkpoint"""
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_loss': best_loss,
            'training_history': history,
            'model_config': {
                'latent_dim': self.latent_dim,
                'gene_dim': self.gene_dim,
                'hidden_dims': self.hidden_dims,
                'bottleneck_dim': self.bottleneck_dim,
                'num_experts': self.num_experts
            }
        }, path)

    def predict(self, latent_data: np.ndarray, batch_size: int = 32) -> np.ndarray:
        """Predict gene expression"""
        if not self.is_trained:
            warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")

        self.model.eval()

        if isinstance(latent_data, np.ndarray):
            latent_data = torch.FloatTensor(latent_data)

        predictions = []
        with torch.no_grad():
            for i in range(0, len(latent_data), batch_size):
                batch_latent = latent_data[i:i+batch_size].to(self.device)
                batch_pred = self.model(batch_latent)
                predictions.append(batch_pred.cpu())

        return torch.cat(predictions).numpy()

    def load_model(self, model_path: str):
        """Load pre-trained model"""
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.is_trained = True
        self.training_history = checkpoint.get('training_history')
        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
        print(f"✅ Model loaded! Best val loss: {self.best_val_loss:.4f}")

    def get_model_info(self) -> Dict:
        """Get model information"""
        return {
            'is_trained': self.is_trained,
            'best_val_loss': self.best_val_loss,
            'parameters': sum(p.numel() for p in self.model.parameters()),
            'latent_dim': self.latent_dim,
            'gene_dim': self.gene_dim,
            'hidden_dims': self.hidden_dims,
            'device': str(self.device)
        }

'''
# Example usage
def example_usage():
    """Example demonstration"""

    # Initialize decoder
    decoder = EfficientTranscriptomeDecoder(
        latent_dim=100,
        gene_dim=2000,  # Reduced for example
        hidden_dims=[256, 512, 1024],
        bottleneck_dim=128,
        num_experts=4,
        dropout_rate=0.1
    )

    # Generate example data
    n_samples = 1000
    latent_data = np.random.randn(n_samples, 100).astype(np.float32)

    # Simulate expression data
    weights = np.random.randn(100, 2000) * 0.1
    expression_data = np.tanh(latent_data.dot(weights))
    expression_data = np.maximum(expression_data, 0)

    print(f"📊 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")

    # Train
    history = decoder.train(
        train_latent=latent_data,
        train_expression=expression_data,
        batch_size=32,
        num_epochs=50
    )

    # Predict
    test_latent = np.random.randn(10, 100).astype(np.float32)
    predictions = decoder.predict(test_latent)
    print(f"🔮 Prediction shape: {predictions.shape}")

    return decoder

if __name__ == "__main__":
    example_usage()
'''
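Note for readers of this diff: the example_usage() block above ships commented out inside a module-level string literal, so nothing in the new file exercises the class. The sketch below shows how the public API introduced here (constructor, train, load_model, predict) could be driven end to end. It is not part of the package: the import path is inferred from the file list at the top of this diff, and the synthetic data, shapes, and checkpoint name are illustrative assumptions.

# Minimal sketch (assumptions: import path inferred from the file list; data and
# checkpoint name are illustrative, not shipped with the package).
import numpy as np
from SURE.EfficientTranscriptomeDecoder import EfficientTranscriptomeDecoder

# Small dimensions so the sketch runs quickly; the constructor defaults target
# a full 60,000-gene transcriptome.
decoder = EfficientTranscriptomeDecoder(
    latent_dim=100,
    gene_dim=2000,
    hidden_dims=[256, 512, 1024],
    bottleneck_dim=128,
    num_experts=4,
)

# Synthetic latent codes and non-negative expression targets (toy data).
rng = np.random.default_rng(0)
latent = rng.standard_normal((512, 100)).astype(np.float32)
weights = rng.standard_normal((100, 2000)).astype(np.float32) * 0.1
expression = np.maximum(np.tanh(latent @ weights), 0).astype(np.float32)

# train() auto-splits 90/10 when no validation arrays are passed and saves the
# best checkpoint to checkpoint_path.
history = decoder.train(latent, expression, batch_size=32, num_epochs=5,
                        checkpoint_path='transcriptome_decoder.pth')

# Reload the best checkpoint and decode new latent vectors.
decoder.load_model('transcriptome_decoder.pth')
pred = decoder.predict(rng.standard_normal((8, 100)).astype(np.float32))
print(pred.shape)  # (8, 2000)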