SURE-tools 2.4.2__py3-none-any.whl → 2.4.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic.
- SURE/EfficientTranscriptomeDecoder.py +607 -0
- SURE/SimpleTranscriptomeDecoder.py +567 -0
- SURE/TranscriptomeDecoder.py +273 -289
- SURE/__init__.py +6 -1
- {sure_tools-2.4.2.dist-info → sure_tools-2.4.13.dist-info}/METADATA +1 -1
- {sure_tools-2.4.2.dist-info → sure_tools-2.4.13.dist-info}/RECORD +10 -8
- {sure_tools-2.4.2.dist-info → sure_tools-2.4.13.dist-info}/WHEEL +0 -0
- {sure_tools-2.4.2.dist-info → sure_tools-2.4.13.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.4.2.dist-info → sure_tools-2.4.13.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.4.2.dist-info → sure_tools-2.4.13.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,607 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from typing import Dict, List, Optional, Tuple
import math
import warnings
warnings.filterwarnings('ignore')

class EfficientTranscriptomeDecoder:
    """
    High-performance, memory-efficient transcriptome decoder
    Combines latest research techniques for optimal performance
    """

    def __init__(self,
                 latent_dim: int = 100,
                 gene_dim: int = 60000,
                 hidden_dims: List[int] = [512, 1024, 2048],
                 bottleneck_dim: int = 256,
                 num_experts: int = 8,
                 dropout_rate: float = 0.1,
                 device: str = None):
        """
        Advanced decoder combining multiple state-of-the-art techniques

        Args:
            latent_dim: Latent variable dimension
            gene_dim: Number of genes (full transcriptome)
            hidden_dims: Hidden layer dimensions
            bottleneck_dim: Bottleneck dimension for memory efficiency
            num_experts: Number of mixture-of-experts
            dropout_rate: Dropout rate
            device: Computation device
        """
        self.latent_dim = latent_dim
        self.gene_dim = gene_dim
        self.hidden_dims = hidden_dims
        self.bottleneck_dim = bottleneck_dim
        self.num_experts = num_experts
        self.dropout_rate = dropout_rate
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

        # Initialize model with advanced architecture
        self.model = self._build_advanced_model()
        self.model.to(self.device)

        # Training state
        self.is_trained = False
        self.training_history = None
        self.best_val_loss = float('inf')

        print(f"🚀 EfficientTranscriptomeDecoder Initialized:")
        print(f" - Latent Dimension: {latent_dim}")
        print(f" - Gene Dimension: {gene_dim}")
        print(f" - Hidden Dimensions: {hidden_dims}")
        print(f" - Bottleneck Dimension: {bottleneck_dim}")
        print(f" - Number of Experts: {num_experts}")
        print(f" - Estimated GPU Memory: ~6-8GB")
        print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")

    class SwiGLU(nn.Module):
        """SwiGLU activation - better than GELU (PaLM, LLaMA)"""
        def forward(self, x):
            x, gate = x.chunk(2, dim=-1)
            return x * F.silu(gate)
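    # Illustrative sketch: SwiGLU splits its input in half along the last dimension
    # and gates one half with SiLU of the other, so a [batch, 2 * d] input yields a
    # [batch, d] output. Assuming a toy 8-feature input:
    #
    #     swiglu = EfficientTranscriptomeDecoder.SwiGLU()
    #     out = swiglu(torch.randn(4, 8))   # out.shape == (4, 4)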
    class RMSNorm(nn.Module):
        """RMS Normalization - more stable than LayerNorm (GPT-3)"""
        def __init__(self, dim: int, eps: float = 1e-8):
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))

        def forward(self, x):
            norm_x = x.norm(2, dim=-1, keepdim=True)
            rms_x = norm_x * (x.shape[-1] ** -0.5)
            return x / (rms_x + self.eps) * self.weight
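    # Illustrative note: norm_x * x.shape[-1] ** -0.5 equals sqrt(mean(x ** 2)), so the
    # forward pass is standard RMSNorm, x / sqrt(mean(x ** 2)) * weight, with eps added
    # to the RMS value rather than under the square root. A minimal sketch with an
    # assumed feature size:
    #
    #     rms = EfficientTranscriptomeDecoder.RMSNorm(dim=16)
    #     y = rms(torch.randn(4, 16))       # y.shape == (4, 16)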
    class MixtureOfExperts(nn.Module):
        """Mixture of Experts for conditional computation"""
        def __init__(self, input_dim: int, expert_dim: int, num_experts: int):
            super().__init__()
            self.num_experts = num_experts
            self.experts = nn.ModuleList([
                nn.Sequential(
                    nn.Linear(input_dim, expert_dim),
                    nn.Dropout(0.1),
                    nn.Linear(expert_dim, input_dim)
                ) for _ in range(num_experts)
            ])
            self.gate = nn.Linear(input_dim, num_experts)
            self.expert_dim = expert_dim

        def forward(self, x):
            # Gate network
            gate_logits = self.gate(x)
            gate_weights = F.softmax(gate_logits, dim=-1)

            # Expert outputs
            expert_outputs = []
            for i, expert in enumerate(self.experts):
                expert_out = expert(x)
                expert_outputs.append(expert_out.unsqueeze(-1))

            # Combine expert outputs
            # gate_weights: [batch, num_experts]; expert_outputs: [batch, input_dim, num_experts]
            expert_outputs = torch.cat(expert_outputs, dim=-1)
            output = torch.einsum('be,bde->bd', gate_weights, expert_outputs)

            return output + x  # Residual connection
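    # Illustrative note: each expert maps [batch, input_dim] -> [batch, input_dim], and
    # the einsum takes a gate-weighted average over the expert axis; with two experts
    # and gate weights (0.25, 0.75) the combination is 0.25 * expert_0(x) + 0.75 *
    # expert_1(x), added back onto x as a residual. A shape check with assumed toy sizes:
    #
    #     moe = EfficientTranscriptomeDecoder.MixtureOfExperts(input_dim=32, expert_dim=64, num_experts=4)
    #     y = moe(torch.randn(8, 32))       # y.shape == (8, 32)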
    class AdaptiveBottleneck(nn.Module):
        """Adaptive bottleneck for memory efficiency"""
        def __init__(self, input_dim: int, bottleneck_dim: int, output_dim: int):
            super().__init__()
            self.compress = nn.Linear(input_dim, bottleneck_dim)
            self.norm1 = EfficientTranscriptomeDecoder.RMSNorm(bottleneck_dim)
            self.expand = nn.Linear(bottleneck_dim, output_dim)
            self.norm2 = EfficientTranscriptomeDecoder.RMSNorm(output_dim)
            self.dropout = nn.Dropout(0.1)

        def forward(self, x):
            # Compress
            compressed = self.compress(x)
            compressed = self.norm1(compressed)
            compressed = F.silu(compressed)
            compressed = self.dropout(compressed)

            # Expand
            expanded = self.expand(compressed)
            expanded = self.norm2(expanded)

            return expanded
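    # Illustrative note: the compress-then-expand pair costs roughly
    # input_dim * bottleneck_dim + bottleneck_dim * output_dim weights instead of
    # input_dim * output_dim for a direct projection; for the 1024 -> 256 -> 2048
    # block of the default configuration that is about 0.8M versus 2.1M weights.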
    class GeneSpecificProjection(nn.Module):
        """Gene-specific projection with weight sharing"""
        def __init__(self, latent_dim: int, gene_dim: int, proj_dim: int = 64):
            super().__init__()
            self.proj_dim = proj_dim
            self.gene_embeddings = nn.Parameter(torch.randn(gene_dim, proj_dim) * 0.02)
            self.latent_projection = nn.Linear(latent_dim, proj_dim)
            self.output_layer = nn.Linear(proj_dim, 1)

        def forward(self, latent):
            batch_size = latent.shape[0]

            # Project latent to gene space
            latent_proj = self.latent_projection(latent)  # [batch_size, proj_dim]

            # Efficient matrix multiplication
            gene_output = torch.matmul(latent_proj, self.gene_embeddings.T)  # [batch_size, gene_dim]

            return gene_output
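    # Illustrative note: this is a rank-proj_dim factorization of a dense
    # hidden_dim x gene_dim output layer. Under the default configuration (2048-dim
    # final hidden state, 60,000 genes, proj_dim=128) it needs roughly
    # 2048*128 + 60000*128 ≈ 7.9M parameters instead of 2048*60000 ≈ 123M for a
    # dense output layer.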
    class AdvancedDecoder(nn.Module):
        """Advanced decoder combining multiple techniques"""

        def __init__(self, latent_dim: int, gene_dim: int, hidden_dims: List[int],
                     bottleneck_dim: int, num_experts: int, dropout_rate: float):
            super().__init__()

            # Initial projection
            self.input_projection = nn.Sequential(
                nn.Linear(latent_dim, hidden_dims[0]),
                EfficientTranscriptomeDecoder.RMSNorm(hidden_dims[0]),
                nn.SiLU(),
                nn.Dropout(dropout_rate)
            )

            # Main processing blocks
            self.blocks = nn.ModuleList()
            current_dim = hidden_dims[0]

            for i, hidden_dim in enumerate(hidden_dims[1:], 1):
                block = nn.ModuleList([
                    # Mixture of Experts
                    EfficientTranscriptomeDecoder.MixtureOfExperts(current_dim, hidden_dim, num_experts),

                    # Adaptive Bottleneck
                    EfficientTranscriptomeDecoder.AdaptiveBottleneck(current_dim, bottleneck_dim, hidden_dim),

                    # SwiGLU activation
                    nn.Sequential(
                        nn.Linear(hidden_dim, hidden_dim * 2),
                        EfficientTranscriptomeDecoder.SwiGLU(),
                        nn.Dropout(dropout_rate)
                    )
                ])
                self.blocks.append(block)
                current_dim = hidden_dim

            # Gene-specific projection
            self.gene_projection = EfficientTranscriptomeDecoder.GeneSpecificProjection(
                current_dim, gene_dim, proj_dim=128
            )

            # Output scaling
            self.output_scale = nn.Parameter(torch.ones(1))
            self.output_bias = nn.Parameter(torch.zeros(1))

            self._init_weights()

        def _init_weights(self):
            """Advanced weight initialization"""
            for module in self.modules():
                if isinstance(module, nn.Linear):
                    # Kaiming init (unit gain) for SiLU/SwiGLU layers
                    nn.init.kaiming_normal_(module.weight, nonlinearity='linear')
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)

        def forward(self, x):
            # Initial projection
            x = self.input_projection(x)

            # Process through blocks
            for block in self.blocks:
                # Mixture of Experts
                expert_out = block[0](x)

                # Adaptive Bottleneck
                bottleneck_out = block[1](expert_out)

                # SwiGLU activation with residual
                swiglu_out = block[2](bottleneck_out)
                if swiglu_out.shape[-1] == x.shape[-1]:
                    x = x + swiglu_out  # Residual connection
                else:
                    x = swiglu_out  # Block changes width, so no residual is added

            # Final gene projection
            output = self.gene_projection(x)

            # Ensure non-negative output
            output = F.softplus(output * self.output_scale + self.output_bias)

            return output
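    # Illustrative shape walk-through under the default configuration: latent [B, 100]
    # -> input_projection -> [B, 512]; block 1 (MoE keeps 512, bottleneck 512 -> 256 ->
    # 1024, SwiGLU 1024 -> 2048 -> 1024) -> [B, 1024]; block 2 -> [B, 2048];
    # gene_projection -> [B, 60000]; softplus keeps the predicted expression non-negative.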
    def _build_advanced_model(self):
        """Build the advanced decoder model"""
        return self.AdvancedDecoder(
            self.latent_dim, self.gene_dim, self.hidden_dims,
            self.bottleneck_dim, self.num_experts, self.dropout_rate
        )

    def train(self,
              train_latent: np.ndarray,
              train_expression: np.ndarray,
              val_latent: np.ndarray = None,
              val_expression: np.ndarray = None,
              batch_size: int = 16,  # Smaller batches for memory efficiency
              num_epochs: int = 200,
              learning_rate: float = 1e-4,
              checkpoint_path: str = 'efficient_decoder.pth') -> Dict:
        """
        Train with advanced optimization techniques

        Args:
            train_latent: Training latent variables
            train_expression: Training expression data
            val_latent: Validation latent variables
            val_expression: Validation expression data
            batch_size: Batch size (optimized for memory)
            num_epochs: Number of epochs
            learning_rate: Learning rate
            checkpoint_path: Model save path
        """
        print("🚀 Starting Advanced Training...")

        # Data preparation
        train_dataset = self._create_dataset(train_latent, train_expression)

        if val_latent is not None and val_expression is not None:
            val_dataset = self._create_dataset(val_latent, val_expression)
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
        else:
            train_size = int(0.9 * len(train_dataset))
            val_size = len(train_dataset) - train_size
            train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
            train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, pin_memory=True)
            val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, pin_memory=True)

        print(f"📊 Training samples: {len(train_loader.dataset)}")
        print(f"📊 Validation samples: {len(val_loader.dataset)}")
        print(f"📊 Batch size: {batch_size}")

        # Advanced optimizer configuration
        optimizer = optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=0.1,  # Stronger regularization
            betas=(0.9, 0.95),  # Tuned betas
            eps=1e-8
        )

        # Cosine annealing with warmup
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=learning_rate * 5,
            epochs=num_epochs,
            steps_per_epoch=len(train_loader),
            pct_start=0.1,
            div_factor=10.0,
            final_div_factor=100.0
        )

        # Advanced loss function
        def advanced_loss(pred, target):
            # 1. MSE loss for overall accuracy
            mse_loss = F.mse_loss(pred, target)

            # 2. Poisson loss for count data
            poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()

            # 3. Correlation loss for pattern matching
            correlation_loss = 1 - self._pearson_correlation(pred, target)

            # 4. Sparsity loss for realistic distribution
            sparsity_loss = F.mse_loss(
                (pred < 1e-3).float().mean(),
                torch.tensor(0.85, device=pred.device)  # Target sparsity
            )

            # 5. Spectral loss for smoothness
            spectral_loss = self._spectral_loss(pred, target)

            # Weighted combination
            total_loss = (mse_loss + 0.3 * poisson_loss + 0.2 * correlation_loss +
                          0.1 * sparsity_loss + 0.05 * spectral_loss)

            return total_loss, {
                'mse': mse_loss.item(),
                'poisson': poisson_loss.item(),
                'correlation': correlation_loss.item(),
                'sparsity': sparsity_loss.item(),
                'spectral': spectral_loss.item()
            }
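        # Illustrative note: the combined objective is
        #     L = MSE + 0.3 * Poisson + 0.2 * (1 - Pearson r) + 0.1 * sparsity + 0.05 * spectral,
        # where the Poisson term omits the constant log(target!) and the sparsity term
        # compares the fraction of near-zero predictions (< 1e-3) against a 0.85 target.
        # Because that fraction comes from a hard threshold, the sparsity term carries
        # no gradient and mainly serves as a monitored penalty.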
        # Training history
        history = {
            'train_loss': [], 'val_loss': [],
            'train_mse': [], 'val_mse': [],
            'train_correlation': [], 'val_correlation': [],
            'learning_rates': [], 'grad_norms': []
        }

        best_val_loss = float('inf')
        patience = 25
        patience_counter = 0

        print("\n📈 Starting training with advanced techniques...")
        for epoch in range(1, num_epochs + 1):
            # Training phase
            train_loss, train_components, grad_norm = self._train_epoch_advanced(
                train_loader, optimizer, scheduler, advanced_loss
            )

            # Validation phase
            val_loss, val_components = self._validate_epoch_advanced(val_loader, advanced_loss)

            # Record history
            history['train_loss'].append(train_loss)
            history['val_loss'].append(val_loss)
            history['train_mse'].append(train_components['mse'])
            history['val_mse'].append(val_components['mse'])
            history['train_correlation'].append(train_components['correlation'])
            history['val_correlation'].append(val_components['correlation'])
            history['learning_rates'].append(optimizer.param_groups[0]['lr'])
            history['grad_norms'].append(grad_norm)

            # Print detailed progress
            if epoch % 10 == 0 or epoch == 1:
                lr = optimizer.param_groups[0]['lr']
                print(f"📍 Epoch {epoch:3d}/{num_epochs} | "
                      f"Train: {train_loss:.4f} | "
                      f"Val: {val_loss:.4f} | "
                      f"Corr: {val_components['correlation']:.4f} | "
                      f"LR: {lr:.2e} | "
                      f"Grad: {grad_norm:.4f}")

            # Early stopping with patience
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
                if epoch % 20 == 0:
                    print(f"💾 Best model saved (Val Loss: {best_val_loss:.4f})")
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"🛑 Early stopping at epoch {epoch}")
                    break

        # Training completed
        self.is_trained = True
        self.training_history = history
        self.best_val_loss = best_val_loss

        print(f"\n🎉 Training completed!")
        print(f"🏆 Best validation loss: {best_val_loss:.4f}")

        return history

    def _create_dataset(self, latent_data, expression_data):
        """Create memory-efficient dataset"""
        class EfficientDataset(Dataset):
            def __init__(self, latent, expression):
                self.latent = torch.FloatTensor(latent)
                self.expression = torch.FloatTensor(expression)

            def __len__(self):
                return len(self.latent)

            def __getitem__(self, idx):
                return self.latent[idx], self.expression[idx]

        return EfficientDataset(latent_data, expression_data)

    def _pearson_correlation(self, pred, target):
        """Calculate Pearson correlation"""
        pred_centered = pred - pred.mean(dim=1, keepdim=True)
        target_centered = target - target.mean(dim=1, keepdim=True)

        numerator = (pred_centered * target_centered).sum(dim=1)
        denominator = torch.sqrt(torch.sum(pred_centered ** 2, dim=1)) * torch.sqrt(torch.sum(target_centered ** 2, dim=1))

        return (numerator / (denominator + 1e-8)).mean()
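    # Illustrative note: for each row (cell) this computes the sample Pearson correlation
    #     r = sum((x - x_bar) * (y - y_bar)) / (||x - x_bar|| * ||y - y_bar||)
    # between predicted and observed expression across genes, then averages over the
    # batch; the training objective uses 1 - r as its correlation loss.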
    def _spectral_loss(self, pred, target):
        """Spectral loss for frequency domain matching"""
        pred_fft = torch.fft.fft(pred, dim=1)
        target_fft = torch.fft.fft(target, dim=1)

        magnitude_loss = F.mse_loss(torch.abs(pred_fft), torch.abs(target_fft))
        phase_loss = F.mse_loss(torch.angle(pred_fft), torch.angle(target_fft))

        return magnitude_loss + 0.5 * phase_loss
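    # Illustrative note: the FFT runs along the gene axis (dim=1), so this term compares
    # predicted and observed expression profiles in the frequency domain, as MSE on the
    # spectral magnitudes plus a half-weighted MSE on the phases; its value therefore
    # depends on the ordering of genes in the expression matrix.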
    def _train_epoch_advanced(self, train_loader, optimizer, scheduler, loss_fn):
        """Advanced training with gradient accumulation"""
        self.model.train()
        total_loss = 0
        total_components = {'mse': 0, 'poisson': 0, 'correlation': 0, 'sparsity': 0, 'spectral': 0}
        grad_norms = []

        # Gradient accumulation for effective larger batch size
        accumulation_steps = 4
        optimizer.zero_grad()

        for i, (latent, target) in enumerate(train_loader):
            latent = latent.to(self.device, non_blocking=True)
            target = target.to(self.device, non_blocking=True)

            # Forward pass with mixed precision
            with torch.cuda.amp.autocast():  # Mixed precision for memory efficiency
                pred = self.model(latent)
                loss, components = loss_fn(pred, target)

            # Scale loss for gradient accumulation
            loss = loss / accumulation_steps
            loss.backward()

            # Gradient accumulation
            if (i + 1) % accumulation_steps == 0:
                # Gradient clipping
                grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

                grad_norms.append(grad_norm.item())

            # Accumulate losses
            total_loss += loss.item() * accumulation_steps
            for key in total_components:
                total_components[key] += components[key]

        # Average metrics
        num_batches = len(train_loader)
        avg_loss = total_loss / num_batches
        avg_components = {key: value / num_batches for key, value in total_components.items()}
        avg_grad_norm = np.mean(grad_norms) if grad_norms else 0.0

        return avg_loss, avg_components, avg_grad_norm
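    # Illustrative note: with accumulation_steps = 4 and the default batch_size of 16,
    # gradients are summed over four mini-batches before each optimizer step, giving an
    # effective batch size of 64 while only holding 16 samples of the 60,000-gene output
    # in memory at a time. Gradients from a trailing partial accumulation window are
    # discarded by the zero_grad() at the start of the next epoch.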
    def _validate_epoch_advanced(self, val_loader, loss_fn):
        """Advanced validation"""
        self.model.eval()
        total_loss = 0
        total_components = {'mse': 0, 'poisson': 0, 'correlation': 0, 'sparsity': 0, 'spectral': 0}

        with torch.no_grad():
            for latent, target in val_loader:
                latent = latent.to(self.device, non_blocking=True)
                target = target.to(self.device, non_blocking=True)

                pred = self.model(latent)
                loss, components = loss_fn(pred, target)

                total_loss += loss.item()
                for key in total_components:
                    total_components[key] += components[key]

        num_batches = len(val_loader)
        avg_loss = total_loss / num_batches
        avg_components = {key: value / num_batches for key, value in total_components.items()}

        return avg_loss, avg_components

    def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
        """Save checkpoint"""
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_loss': best_loss,
            'training_history': history,
            'model_config': {
                'latent_dim': self.latent_dim,
                'gene_dim': self.gene_dim,
                'hidden_dims': self.hidden_dims,
                'bottleneck_dim': self.bottleneck_dim,
                'num_experts': self.num_experts
            }
        }, path)

    def predict(self, latent_data: np.ndarray, batch_size: int = 16) -> np.ndarray:
        """Memory-efficient prediction"""
        if not self.is_trained:
            warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")

        self.model.eval()

        if isinstance(latent_data, np.ndarray):
            latent_data = torch.FloatTensor(latent_data)

        predictions = []
        with torch.no_grad():
            for i in range(0, len(latent_data), batch_size):
                batch_latent = latent_data[i:i+batch_size].to(self.device)

                with torch.cuda.amp.autocast():  # Mixed precision for memory
                    batch_pred = self.model(batch_latent)

                predictions.append(batch_pred.cpu())

                # Clear memory
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        return torch.cat(predictions).numpy()
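    # Illustrative usage sketch (assumed file name and shapes):
    #
    #     decoder = EfficientTranscriptomeDecoder(latent_dim=100, gene_dim=60000)
    #     decoder.load_model('efficient_decoder.pth')   # or train first via decoder.train(...)
    #     expression = decoder.predict(np.random.randn(32, 100).astype(np.float32))
    #     # expression.shape == (32, 60000)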
    def load_model(self, model_path: str):
        """Load pre-trained model"""
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.is_trained = True
        self.training_history = checkpoint.get('training_history')
        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
        print(f"✅ Model loaded! Best val loss: {self.best_val_loss:.4f}")

'''
# Example usage
def example_usage():
    """Demonstrate the advanced decoder"""

    # Initialize decoder
    decoder = EfficientTranscriptomeDecoder(
        latent_dim=100,
        gene_dim=2000,  # Reduced for example
        hidden_dims=[256, 512, 1024],
        bottleneck_dim=128,
        num_experts=4,
        dropout_rate=0.1
    )

    # Generate example data
    n_samples = 1000
    latent_data = np.random.randn(n_samples, 100).astype(np.float32)

    # Simulate expression data
    weights = np.random.randn(100, 2000) * 0.1
    expression_data = np.tanh(latent_data.dot(weights))
    expression_data = np.maximum(expression_data, 0)

    print(f"📊 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")

    # Train
    history = decoder.train(
        train_latent=latent_data,
        train_expression=expression_data,
        batch_size=16,
        num_epochs=50
    )

    # Predict
    test_latent = np.random.randn(10, 100).astype(np.float32)
    predictions = decoder.predict(test_latent)
    print(f"🔮 Prediction shape: {predictions.shape}")

    return decoder

if __name__ == "__main__":
    example_usage()

'''