SURE-tools 2.2.24__py3-none-any.whl → 2.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- SURE/DensityFlow.py +15 -8
- SURE/PerturbE.py +1300 -0
- SURE/TranscriptomeDecoder.py +499 -0
- SURE/__init__.py +5 -1
- SURE/utils/custom_mlp.py +39 -2
- {sure_tools-2.2.24.dist-info → sure_tools-2.4.7.dist-info}/METADATA +1 -1
- {sure_tools-2.2.24.dist-info → sure_tools-2.4.7.dist-info}/RECORD +11 -9
- {sure_tools-2.2.24.dist-info → sure_tools-2.4.7.dist-info}/WHEEL +0 -0
- {sure_tools-2.2.24.dist-info → sure_tools-2.4.7.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.2.24.dist-info → sure_tools-2.4.7.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.2.24.dist-info → sure_tools-2.4.7.dist-info}/top_level.txt +0 -0
SURE/TranscriptomeDecoder.py ADDED
@@ -0,0 +1,499 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import Dataset, DataLoader
+import numpy as np
+from typing import Dict, Optional
+import warnings
+warnings.filterwarnings('ignore')
+
+class TranscriptomeDecoder:
+    """Transcriptome decoder"""
+
+    def __init__(self,
+                 latent_dim: int = 100,
+                 gene_dim: int = 60000,
+                 hidden_dim: int = 512,
+                 device: str = None):
+        """
+        Simple but powerful decoder for latent to transcriptome mapping
+
+        Args:
+            latent_dim: Latent variable dimension (typically 50-100)
+            gene_dim: Number of genes (full transcriptome ~60,000)
+            hidden_dim: Hidden dimension optimized
+            device: Computation device
+        """
+        self.latent_dim = latent_dim
+        self.gene_dim = gene_dim
+        self.hidden_dim = hidden_dim
+        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+
+        # Initialize model
+        self.model = self._build_model()
+        self.model.to(self.device)
+
+        # Training state
+        self.is_trained = False
+        self.training_history = None
+        self.best_val_loss = float('inf')
+
+        print(f"🚀 SimpleTranscriptomeDecoder Initialized:")
+        print(f"   - Latent Dimension: {latent_dim}")
+        print(f"   - Gene Dimension: {gene_dim}")
+        print(f"   - Hidden Dimension: {hidden_dim}")
+        print(f"   - Device: {self.device}")
+        print(f"   - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
+
+    class Decoder(nn.Module):
+        """Memory-efficient decoder architecture with dimension handling"""
+
+        def __init__(self, latent_dim: int, gene_dim: int, hidden_dim: int):
+            super().__init__()
+            self.latent_dim = latent_dim
+            self.gene_dim = gene_dim
+            self.hidden_dim = hidden_dim
+
+            # Stage 1: Latent variable expansion
+            self.latent_expansion = nn.Sequential(
+                nn.Linear(latent_dim, hidden_dim * 2),
+                nn.BatchNorm1d(hidden_dim * 2),
+                nn.GELU(),
+                nn.Dropout(0.1),
+                nn.Linear(hidden_dim * 2, hidden_dim),
+                nn.BatchNorm1d(hidden_dim),
+                nn.GELU(),
+            )
+
+            # Stage 2: Direct projection to gene dimension (simpler approach)
+            self.gene_projector = nn.Sequential(
+                nn.Linear(hidden_dim, hidden_dim * 2),
+                nn.GELU(),
+                nn.Dropout(0.1),
+                nn.Linear(hidden_dim * 2, gene_dim),  # Direct projection to gene_dim
+            )
+
+            # Stage 3: Lightweight gene interaction
+            self.gene_interaction = nn.Sequential(
+                nn.Conv1d(1, 32, kernel_size=3, padding=1),
+                nn.GELU(),
+                nn.Dropout1d(0.1),
+                nn.Conv1d(32, 1, kernel_size=3, padding=1),
+            )
+
+            # Output scaling
+            self.output_scale = nn.Parameter(torch.ones(1))
+            self.output_bias = nn.Parameter(torch.zeros(1))
+
+            self._init_weights()
+
+        def _init_weights(self):
+            """Weight initialization"""
+            for module in self.modules():
+                if isinstance(module, nn.Linear):
+                    nn.init.xavier_uniform_(module.weight)
+                    if module.bias is not None:
+                        nn.init.zeros_(module.bias)
+                elif isinstance(module, nn.Conv1d):
+                    nn.init.kaiming_uniform_(module.weight)
+                    if module.bias is not None:
+                        nn.init.zeros_(module.bias)
+
+        def forward(self, latent: torch.Tensor) -> torch.Tensor:
+            batch_size = latent.shape[0]
+
+            # 1. Expand latent variables
+            latent_features = self.latent_expansion(latent)  # [batch_size, hidden_dim]
+
+            # 2. Direct projection to gene dimension
+            gene_output = self.gene_projector(latent_features)  # [batch_size, gene_dim]
+
+            # 3. Gene interaction with dimension safety
+            if self.gene_dim > 1:  # Only apply if gene_dim > 1
+                gene_output = gene_output.unsqueeze(1)  # [batch_size, 1, gene_dim]
+                interaction_output = self.gene_interaction(gene_output)  # [batch_size, 1, gene_dim]
+                gene_output = gene_output + interaction_output  # Residual connection
+                gene_output = gene_output.squeeze(1)  # [batch_size, gene_dim]
+
+            # 4. Final activation (ensure non-negative)
+            gene_output = F.softplus(gene_output * self.output_scale + self.output_bias)
+
+            return gene_output
+
+    def _build_model(self):
+        """Build the decoder model"""
+        return self.Decoder(self.latent_dim, self.gene_dim, self.hidden_dim)
+
+    def _create_dataset(self, latent_data, expression_data):
+        """Create dataset with dimension validation"""
+        class SimpleDataset(Dataset):
+            def __init__(self, latent, expression):
+                # Ensure dimensions match
+                assert latent.shape[0] == expression.shape[0], "Sample count mismatch"
+                assert latent.shape[1] == self.latent_dim, f"Latent dim mismatch: expected {self.latent_dim}, got {latent.shape[1]}"
+                assert expression.shape[1] == self.gene_dim, f"Gene dim mismatch: expected {self.gene_dim}, got {expression.shape[1]}"
+
+                self.latent = torch.FloatTensor(latent)
+                self.expression = torch.FloatTensor(expression)
+
+            def __len__(self):
+                return len(self.latent)
+
+            def __getitem__(self, idx):
+                return self.latent[idx], self.expression[idx]
+
+        return SimpleDataset(latent_data, expression_data)
+
+    def train(self,
+              train_latent: np.ndarray,
+              train_expression: np.ndarray,
+              val_latent: np.ndarray = None,
+              val_expression: np.ndarray = None,
+              batch_size: int = 32,
+              num_epochs: int = 100,
+              learning_rate: float = 1e-4,
+              checkpoint_path: str = 'transcriptome_decoder.pth'):
+        """
+        Train the decoder model with dimension safety
+
+        Args:
+            train_latent: Training latent variables [n_samples, latent_dim]
+            train_expression: Training expression data [n_samples, gene_dim]
+            val_latent: Validation latent variables (optional)
+            val_expression: Validation expression data (optional)
+            batch_size: Batch size optimized for memory
+            num_epochs: Number of training epochs
+            learning_rate: Learning rate
+            checkpoint_path: Path to save the best model
+        """
+        print("🚀 Starting training...")
+
+        # Dimension validation
+        self._validate_data_dimensions(train_latent, train_expression, "Training")
+        if val_latent is not None and val_expression is not None:
+            self._validate_data_dimensions(val_latent, val_expression, "Validation")
+
+        # Data preparation
+        train_dataset = self._create_safe_dataset(train_latent, train_expression)
+
+        if val_latent is not None and val_expression is not None:
+            val_dataset = self._create_safe_dataset(val_latent, val_expression)
+            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
+            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
+            print(f"📈 Using provided validation data: {len(val_dataset)} samples")
+        else:
+            # Auto split
+            train_size = int(0.9 * len(train_dataset))
+            val_size = len(train_dataset) - train_size
+            train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
+            train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=2)
+            val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=2)
+            print(f"📈 Auto-split validation: {val_size} samples")
+
+        print(f"📊 Training samples: {len(train_loader.dataset)}")
+        print(f"📊 Validation samples: {len(val_loader.dataset)}")
+        print(f"📊 Batch size: {batch_size}")
+
+        # Optimizer configuration
+        optimizer = optim.AdamW(
+            self.model.parameters(),
+            lr=learning_rate,
+            weight_decay=0.01,
+            betas=(0.9, 0.999)
+        )
+
+        # Learning rate scheduler
+        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
+
+        # Loss function with dimension safety
+        def safe_loss(pred, target):
+            # Ensure dimensions match
+            if pred.shape != target.shape:
+                print(f"⚠️ Dimension mismatch: pred {pred.shape}, target {target.shape}")
+                # Truncate to minimum dimension (safety measure)
+                min_dim = min(pred.shape[1], target.shape[1])
+                pred = pred[:, :min_dim]
+                target = target[:, :min_dim]
+
+            mse_loss = F.mse_loss(pred, target)
+            poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()
+            return mse_loss + 0.3 * poisson_loss
+
+        # Training history
+        history = {
+            'train_loss': [],
+            'val_loss': [],
+            'learning_rate': []
+        }
+
+        best_val_loss = float('inf')
+        patience = 15
+        patience_counter = 0
+
+        print("\n📈 Starting training loop...")
+        for epoch in range(1, num_epochs + 1):
+            # Training phase
+            train_loss = self._train_epoch(train_loader, optimizer, safe_loss)
+
+            # Validation phase
+            val_loss = self._validate_epoch(val_loader, safe_loss)
+
+            # Update scheduler
+            scheduler.step()
+            current_lr = scheduler.get_last_lr()[0]
+
+            # Record history
+            history['train_loss'].append(train_loss)
+            history['val_loss'].append(val_loss)
+            history['learning_rate'].append(current_lr)
+
+            # Print progress
+            if epoch % 5 == 0 or epoch == 1:
+                print(f"📍 Epoch {epoch:3d}/{num_epochs} | "
+                      f"Train Loss: {train_loss:.4f} | "
+                      f"Val Loss: {val_loss:.4f} | "
+                      f"LR: {current_lr:.2e}")
+
+            # Early stopping and model saving
+            if val_loss < best_val_loss:
+                best_val_loss = val_loss
+                patience_counter = 0
+                self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
+                if epoch % 10 == 0:
+                    print(f"💾 Best model saved (Val Loss: {best_val_loss:.4f})")
+            else:
+                patience_counter += 1
+                if patience_counter >= patience:
+                    print(f"🛑 Early stopping at epoch {epoch}")
+                    break
+
+        # Training completed
+        self.is_trained = True
+        self.training_history = history
+        self.best_val_loss = best_val_loss
+
+        print(f"\n🎉 Training completed!")
+        print(f"🏆 Best validation loss: {best_val_loss:.4f}")
+        print(f"📊 Final training loss: {history['train_loss'][-1]:.4f}")
+
+        return history
+
+    def _validate_data_dimensions(self, latent_data, expression_data, data_type):
+        """Validate input data dimensions"""
+        assert latent_data.shape[1] == self.latent_dim, (
+            f"{data_type} latent dimension mismatch: expected {self.latent_dim}, got {latent_data.shape[1]}")
+        assert expression_data.shape[1] == self.gene_dim, (
+            f"{data_type} gene dimension mismatch: expected {self.gene_dim}, got {expression_data.shape[1]}")
+        assert latent_data.shape[0] == expression_data.shape[0], (
+            f"{data_type} sample count mismatch: latent {latent_data.shape[0]}, expression {expression_data.shape[0]}")
+        print(f"✅ {data_type} data dimensions validated")
+
+    def _create_safe_dataset(self, latent_data, expression_data):
+        """Create dataset with safety checks"""
+        class SafeDataset(Dataset):
+            def __init__(self, latent, expression):
+                self.latent = torch.FloatTensor(latent)
+                self.expression = torch.FloatTensor(expression)
+
+                # Safety check
+                if self.latent.shape[0] != self.expression.shape[0]:
+                    raise ValueError(f"Sample count mismatch: latent {self.latent.shape[0]}, expression {self.expression.shape[0]}")
+
+            def __len__(self):
+                return len(self.latent)
+
+            def __getitem__(self, idx):
+                return self.latent[idx], self.expression[idx]
+
+        return SafeDataset(latent_data, expression_data)
+
+    def _train_epoch(self, train_loader, optimizer, loss_fn):
+        """Train for one epoch with dimension safety"""
+        self.model.train()
+        total_loss = 0
+
+        for batch_idx, (latent, target) in enumerate(train_loader):
+            latent = latent.to(self.device)
+            target = target.to(self.device)
+
+            # Dimension check
+            if latent.shape[1] != self.latent_dim:
+                print(f"⚠️ Batch {batch_idx}: Latent dim mismatch {latent.shape[1]} != {self.latent_dim}")
+                continue
+
+            optimizer.zero_grad()
+            pred = self.model(latent)
+
+            # Final dimension check before loss calculation
+            if pred.shape[1] != target.shape[1]:
+                min_dim = min(pred.shape[1], target.shape[1])
+                pred = pred[:, :min_dim]
+                target = target[:, :min_dim]
+                if batch_idx == 0:  # Only warn once
+                    print(f"⚠️ Truncating to min dimension: {min_dim}")
+
+            loss = loss_fn(pred, target)
+            loss.backward()
+
+            # Gradient clipping for stability
+            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+            optimizer.step()
+
+            total_loss += loss.item()
+
+        return total_loss / len(train_loader)
+
+    def _validate_epoch(self, val_loader, loss_fn):
+        """Validate for one epoch with dimension safety"""
+        self.model.eval()
+        total_loss = 0
+
+        with torch.no_grad():
+            for batch_idx, (latent, target) in enumerate(val_loader):
+                latent = latent.to(self.device)
+                target = target.to(self.device)
+
+                pred = self.model(latent)
+
+                # Dimension safety
+                if pred.shape[1] != target.shape[1]:
+                    min_dim = min(pred.shape[1], target.shape[1])
+                    pred = pred[:, :min_dim]
+                    target = target[:, :min_dim]
+
+                loss = loss_fn(pred, target)
+                total_loss += loss.item()
+
+        return total_loss / len(val_loader)
+
+    def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
+        """Save model checkpoint"""
+        torch.save({
+            'epoch': epoch,
+            'model_state_dict': self.model.state_dict(),
+            'optimizer_state_dict': optimizer.state_dict(),
+            'scheduler_state_dict': scheduler.state_dict(),
+            'best_val_loss': best_loss,
+            'training_history': history,
+            'model_config': {
+                'latent_dim': self.latent_dim,
+                'gene_dim': self.gene_dim,
+                'hidden_dim': self.hidden_dim
+            }
+        }, path)
+
+    def predict(self, latent_data: np.ndarray, batch_size: int = 32) -> np.ndarray:
+        """
+        Predict gene expression from latent variables
+
+        Args:
+            latent_data: Latent variables [n_samples, latent_dim]
+            batch_size: Prediction batch size
+
+        Returns:
+            expression: Predicted expression [n_samples, gene_dim]
+        """
+        if not self.is_trained:
+            warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")
+
+        self.model.eval()
+
+        # Input validation
+        if latent_data.shape[1] != self.latent_dim:
+            raise ValueError(f"Latent dimension mismatch: expected {self.latent_dim}, got {latent_data.shape[1]}")
+
+        if isinstance(latent_data, np.ndarray):
+            latent_data = torch.FloatTensor(latent_data)
+
+        # Predict in batches to save memory
+        predictions = []
+        with torch.no_grad():
+            for i in range(0, len(latent_data), batch_size):
+                batch_latent = latent_data[i:i+batch_size].to(self.device)
+                batch_pred = self.model(batch_latent)
+                predictions.append(batch_pred.cpu())
+
+        return torch.cat(predictions).numpy()
+
+    def load_model(self, model_path: str):
+        """Load pre-trained model"""
+        checkpoint = torch.load(model_path, map_location=self.device)
+
+        # Check model configuration
+        if 'model_config' in checkpoint:
+            config = checkpoint['model_config']
+            if (config['latent_dim'] != self.latent_dim or
+                config['gene_dim'] != self.gene_dim):
+                print("⚠️ Model configuration mismatch. Reinitializing model.")
+                self.model = self._build_model()
+                self.model.to(self.device)
+
+        self.model.load_state_dict(checkpoint['model_state_dict'])
+        self.is_trained = True
+        self.training_history = checkpoint.get('training_history')
+        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
+
+        print(f"✅ Model loaded successfully!")
+        print(f"🏆 Best validation loss: {self.best_val_loss:.4f}")
+
+    def get_model_info(self) -> Dict:
+        """Get model information"""
+        return {
+            'is_trained': self.is_trained,
+            'best_val_loss': self.best_val_loss,
+            'parameters': sum(p.numel() for p in self.model.parameters()),
+            'latent_dim': self.latent_dim,
+            'gene_dim': self.gene_dim,
+            'hidden_dim': self.hidden_dim,
+            'device': str(self.device)
+        }
+'''
+# Example usage
+def example_usage():
+    """Example demonstration with dimension safety"""
+
+    # 1. Initialize decoder
+    decoder = SimpleTranscriptomeDecoder(
+        latent_dim=100,
+        gene_dim=2000,  # Reduced for example
+        hidden_dim=256
+    )
+
+    # 2. Generate example data with correct dimensions
+    n_samples = 1000
+    latent_data = np.random.randn(n_samples, 100).astype(np.float32)
+
+    # Create simulated expression data
+    weights = np.random.randn(100, 2000) * 0.1
+    expression_data = np.tanh(latent_data.dot(weights))
+    expression_data = np.maximum(expression_data, 0)  # Non-negative
+
+    print(f"📊 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")
+
+    # 3. Train the model
+    history = decoder.train(
+        train_latent=latent_data,
+        train_expression=expression_data,
+        batch_size=32,
+        num_epochs=50,
+        learning_rate=1e-4
+    )
+
+    # 4. Make predictions
+    test_latent = np.random.randn(10, 100).astype(np.float32)
+    predictions = decoder.predict(test_latent)
+    print(f"🔮 Prediction shape: {predictions.shape}")
+
+    # 5. Get model info
+    info = decoder.get_model_info()
+    print(f"\n📋 Model Info:")
+    for key, value in info.items():
+        print(f"   {key}: {value}")
+
+    return decoder
+
+if __name__ == "__main__":
+    example_usage()
+
+'''
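Note that the module's own example_usage() sits behind a triple-quoted string (and refers to a SimpleTranscriptomeDecoder name the file never defines), so it never executes. For orientation, here is a minimal sketch, not taken from the package docs, of driving the class exactly as added above. Everything beyond the class API is an assumption: sizes, seed, and the checkpoint filename are made up and kept small so the run finishes quickly on CPU.

# Hedged sketch: exercise TranscriptomeDecoder as published in 2.4.7.
# All dimensions and file names below are illustrative, not package defaults.
import numpy as np
from SURE.TranscriptomeDecoder import TranscriptomeDecoder

def main():
    rng = np.random.default_rng(0)
    latent = rng.standard_normal((256, 16)).astype(np.float32)      # [n_samples, latent_dim]
    weights = (0.1 * rng.standard_normal((16, 500))).astype(np.float32)
    expression = np.maximum(latent @ weights, 0.0)                   # non-negative [n_samples, gene_dim]

    decoder = TranscriptomeDecoder(latent_dim=16, gene_dim=500, hidden_dim=64)
    decoder.train(train_latent=latent,
                  train_expression=expression,     # no val data -> 90/10 auto split
                  batch_size=32, num_epochs=2, learning_rate=1e-3,
                  checkpoint_path='demo_decoder.pth')

    preds = decoder.predict(rng.standard_normal((8, 16)).astype(np.float32))
    print(preds.shape)                 # (8, 500); outputs are non-negative via softplus
    print(decoder.get_model_info())

if __name__ == "__main__":             # guard matters: train() uses DataLoader worker processes
    main()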
SURE/__init__.py CHANGED
@@ -1,5 +1,7 @@
 from .SURE import SURE
 from .DensityFlow import DensityFlow
+from .PerturbE import PerturbE
+from .TranscriptomeDecoder import TranscriptomeDecoder
 
 from . import utils
 from . import codebook
@@ -8,5 +10,7 @@ from . import DensityFlow
 from . import atac
 from . import flow
 from . import perturb
+from . import PerturbE
+from . import TranscriptomeDecoder
 
-__all__ = ['SURE', 'DensityFlow', 'flow', 'perturb', 'atac', 'utils', 'codebook']
+__all__ = ['SURE', 'DensityFlow', 'PerturbE', 'TranscriptomeDecoder', 'flow', 'perturb', 'atac', 'utils', 'codebook']
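The package root now re-exports the two new modules. A quick smoke test of the widened public surface, a sketch assuming sure-tools 2.4.7 is installed and imports cleanly:

# Sketch: check the names newly exported from the SURE package root in 2.4.7.
import SURE

assert 'PerturbE' in SURE.__all__
assert 'TranscriptomeDecoder' in SURE.__all__

from SURE import PerturbE, TranscriptomeDecoder
print(PerturbE, TranscriptomeDecoder)   # both resolve without reaching into submodules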
SURE/utils/custom_mlp.py CHANGED
@@ -240,12 +240,49 @@ class ZeroBiasMLP(nn.Module):
         y = self.mlp(x)
         mask = torch.zeros_like(y)
         if len(y.shape)==2:
-
+            if type(x)==list:
+                mask[x[1][:,0]>0,:] = 1
+            else:
+                mask[x[:,0]>0,:] = 1
         elif len(y.shape)==3:
-
+            if type(x)==list:
+                mask[:,x[1][:,0]>0,:] = 1
+            else:
+                mask[:,x[:,0]>0,:] = 1
         return y*mask
 
 
+
+class ZeroBiasMLP2(nn.Module):
+    def __init__(
+        self,
+        mlp_sizes,
+        activation=nn.ReLU,
+        output_activation=None,
+        post_layer_fct=lambda layer_ix, total_layers, layer: None,
+        post_act_fct=lambda layer_ix, total_layers, layer: None,
+        allow_broadcast=False,
+        use_cuda=False,
+    ):
+        # init the module object
+        super().__init__()
+        self.mlp = MLP(mlp_sizes=mlp_sizes,
+                       activation=activation,
+                       output_activation=output_activation,
+                       post_layer_fct=post_layer_fct,
+                       post_act_fct=post_act_fct,
+                       allow_broadcast=allow_broadcast,
+                       use_cuda=use_cuda,
+                       bias=True)
+
+    # pass through our sequential for the output!
+    def forward(self, x):
+        y = self.mlp(x)
+        mask = torch.zeros_like(y)
+        x_sum = torch.sum(x, dim=1)
+        mask[x_sum>0,:] = 1
+        return y*mask
+
 class HDMLP(nn.Module):
     def __init__(
         self,
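The new ZeroBiasMLP2 keys its mask on the whole input row rather than the first column: output rows whose input row does not sum to a positive value are zeroed. A toy illustration of that rule (not the package's MLP class; a plain nn.Linear stands in for self.mlp):

# Toy illustration of the ZeroBiasMLP2.forward masking rule shown above.
import torch
import torch.nn as nn

torch.manual_seed(0)
mlp = nn.Linear(3, 2)                      # stand-in for the wrapped MLP

x = torch.tensor([[0.0, 0.0, 0.0],         # row sums to 0  -> output row forced to 0
                  [1.0, 2.0, 0.0]])        # row sums to 3  -> output row kept

y = mlp(x)
mask = torch.zeros_like(y)
mask[torch.sum(x, dim=1) > 0, :] = 1       # same rule as ZeroBiasMLP2
print(y * mask)                            # first row all zeros, second row unchanged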
{sure_tools-2.2.24.dist-info → sure_tools-2.4.7.dist-info}/RECORD CHANGED
@@ -1,6 +1,8 @@
-SURE/DensityFlow.py,sha256=
+SURE/DensityFlow.py,sha256=YvaE9aPbAC2U7WhTye5i2AMtcw0BI_qS3gv9SP4aE0k,56676
+SURE/PerturbE.py,sha256=DxEp-qef--x8-GMZdPfBf8ts8UDDc34h2P5AnpqZ-YM,52265
 SURE/SURE.py,sha256=MXs7iuvcj-lU4dJ_MwKegpL2Rqk2HB4eFfAgHRA3RtA,47744
-SURE/
+SURE/TranscriptomeDecoder.py,sha256=e1AOt5fVTSfHTcNK1pyUfny3hFqdMoIRJ3NVh8r7wuY,20387
+SURE/__init__.py,sha256=pNSGQ4BMqMXBAPHpFOYNB8_0vFW-RqPy3rr5fvdEEyU,473
 SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
 SURE/assembly/assembly.py,sha256=6IMdelPOiRO4mUb4dC7gVCoF1Uvfw86-Map8P_jnUag,21477
 SURE/assembly/atlas.py,sha256=ALjmVWutm_tOHTcT1aqOxmuCEQw-XzrtDoMCV_8oXLk,21794
@@ -14,12 +16,12 @@ SURE/flow/plot_quiver.py,sha256=UbmuScUcgbQHeMmjKmgqxjrIjHhiHx0VWct16UMMwuE,8110
 SURE/perturb/__init__.py,sha256=8TP1dSUhXiZzKpFebHZmm8XMMGbUz_OfQ10xu-6uPPY,43
 SURE/perturb/perturb.py,sha256=ey7cxsM1tO1MW4UaE_MLpLHK87CjvXzn2CBPtvv1VZ0,6116
 SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
-SURE/utils/custom_mlp.py,sha256=
+SURE/utils/custom_mlp.py,sha256=Rn_PQouxPMSda-KKBYrwVVv3GFFuUmCLxp8cV5LszZo,10580
 SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
 SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
-sure_tools-2.
-sure_tools-2.
-sure_tools-2.
-sure_tools-2.
-sure_tools-2.
-sure_tools-2.
+sure_tools-2.4.7.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
+sure_tools-2.4.7.dist-info/METADATA,sha256=4TYiOBuq9ddmR7U9GaO1YQvyRRIZr37cwM50bHJ0O2E,2677
+sure_tools-2.4.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+sure_tools-2.4.7.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
+sure_tools-2.4.7.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
+sure_tools-2.4.7.dist-info/RECORD,,
(Removed RECORD entries are shown truncated, as rendered by the registry diff view.)
Files without changes: the remaining dist-info files listed in the summary above (WHEEL, entry_points.txt, licenses/LICENSE, top_level.txt).
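The sha256= value and trailing number in each RECORD line are the standard wheel RECORD fields: a urlsafe, unpadded base64 encoding of the file's SHA-256 digest, followed by its size in bytes. A small sketch for recomputing an entry after unpacking the new wheel; the wheel filename is assumed from the versions shown here, not taken from the registry listing:

# Sketch: recompute a RECORD entry (path,sha256=<urlsafe b64 digest, no padding>,size)
# for a file unpacked from sure_tools-2.4.7-py3-none-any.whl (filename assumed).
import base64
import hashlib
import pathlib

def record_entry(path: str) -> str:
    data = pathlib.Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
    return f"{path},sha256={digest},{len(data)}"

print(record_entry("SURE/TranscriptomeDecoder.py"))   # should match the RECORD line added above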