SURE-tools 2.2.24__py3-none-any.whl → 2.4.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- SURE/DensityFlow.py +130 -65
- SURE/DensityFlow2.py +1422 -0
- SURE/DensityFlowLinear.py +1414 -0
- SURE/EfficientTranscriptomeDecoder.py +552 -0
- SURE/PerturbE.py +1300 -0
- SURE/PerturbationAwareDecoder.py +737 -0
- SURE/SimpleTranscriptomeDecoder.py +567 -0
- SURE/TranscriptomeDecoder.py +511 -0
- SURE/VirtualCellDecoder.py +658 -0
- SURE/__init__.py +17 -1
- SURE/utils/custom_mlp.py +39 -2
- {sure_tools-2.2.24.dist-info → sure_tools-2.4.34.dist-info}/METADATA +1 -1
- {sure_tools-2.2.24.dist-info → sure_tools-2.4.34.dist-info}/RECORD +17 -9
- {sure_tools-2.2.24.dist-info → sure_tools-2.4.34.dist-info}/WHEEL +0 -0
- {sure_tools-2.2.24.dist-info → sure_tools-2.4.34.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.2.24.dist-info → sure_tools-2.4.34.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.2.24.dist-info → sure_tools-2.4.34.dist-info}/top_level.txt +0 -0
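
For orientation, the sketch below shows how the new `PerturbationAwareDecoder` (added in `SURE/PerturbationAwareDecoder.py`, shown in the diff that follows) is meant to be driven. It is a minimal example based only on the API visible in that diff; the import path, array shapes, and random data are illustrative assumptions, not part of the released package.

```python
# Minimal usage sketch (assumptions: import path, toy shapes, random data)
import numpy as np
from SURE.PerturbationAwareDecoder import PerturbationAwareDecoder  # path assumed from the file listing above

n_cells, latent_dim, n_perturbations, n_genes = 1000, 100, 10, 2000
latent = np.random.randn(n_cells, latent_dim).astype(np.float32)
# One-hot perturbation labels, as required by _validate_one_hot_perturbations
perturbations = np.eye(n_perturbations, dtype=np.float32)[np.random.randint(0, n_perturbations, n_cells)]
# Non-negative expression targets, matching the softplus/Poisson-style output of the decoder
expression = np.abs(np.random.randn(n_cells, n_genes)).astype(np.float32)

decoder = PerturbationAwareDecoder(latent_dim=latent_dim,
                                   num_known_perturbations=n_perturbations,
                                   gene_dim=n_genes,
                                   hidden_dims=[512])
decoder.train(train_latent=latent, train_perturbations=perturbations,
              train_expression=expression, num_epochs=10)

# Known perturbations take one-hot rows; novel perturbations take a similarity matrix to the known ones
pred_known = decoder.predict(latent[:5], perturbations[:5])
pred_novel = decoder.predict_novel_perturbation(latent[:5],
                                                np.random.rand(5, n_perturbations).astype(np.float32))
```

The `similarity_matrix` rows describe how similar each unseen perturbation is to the known ones; this is what the decoder's `mode='similarity'` pathway consumes.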
SURE/PerturbationAwareDecoder.py
@@ -0,0 +1,737 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from typing import Dict, List, Optional, Tuple, Union
import math
import warnings
warnings.filterwarnings('ignore')

class PerturbationAwareDecoder:
    """
    Advanced transcriptome decoder with perturbation awareness
    Fixed version with proper handling of single hidden layer configurations
    """

    def __init__(self,
                 latent_dim: int = 100,
                 num_known_perturbations: int = 50,
                 gene_dim: int = 60000,
                 hidden_dims: List[int] = [512],
                 perturbation_embedding_dim: int = 128,
                 biological_prior_dim: int = 256,
                 dropout_rate: float = 0.1,
                 device: str = None):
        """
        Multi-modal decoder with fixed single layer support
        """
        self.latent_dim = latent_dim
        self.num_known_perturbations = num_known_perturbations
        self.gene_dim = gene_dim
        self.hidden_dims = hidden_dims
        self.perturbation_embedding_dim = perturbation_embedding_dim
        self.biological_prior_dim = biological_prior_dim
        self.dropout_rate = dropout_rate
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

        # Validate hidden_dims
        self._validate_hidden_dims()

        # Initialize multi-modal model
        self.model = self._build_fixed_model()
        self.model.to(self.device)

        # Training state
        self.is_trained = False
        self.training_history = None
        self.best_val_loss = float('inf')
        self.known_perturbation_names = []
        self.perturbation_prototypes = None

        print(f"🧬 PerturbationAwareDecoder Initialized:")
        print(f" - Latent Dimension: {latent_dim}")
        print(f" - Known Perturbations: {num_known_perturbations}")
        print(f" - Gene Dimension: {gene_dim}")
        print(f" - Hidden Dimensions: {hidden_dims}")
        print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")

    def _validate_hidden_dims(self):
        """Validate hidden_dims parameter"""
        assert len(self.hidden_dims) >= 1, "hidden_dims must have at least one element"
        assert all(dim > 0 for dim in self.hidden_dims), "All hidden dimensions must be positive"

        if len(self.hidden_dims) == 1:
            print("🔧 Single hidden layer configuration detected")
        else:
            print(f"🔧 Multi-layer configuration: {len(self.hidden_dims)} hidden layers")

    class FixedPerturbationEncoder(nn.Module):
        """Fixed perturbation encoder"""

        def __init__(self, num_perturbations: int, embedding_dim: int, hidden_dim: int):
            super().__init__()
            self.num_perturbations = num_perturbations

            # Embedding for perturbation types
            self.perturbation_embedding = nn.Embedding(num_perturbations, embedding_dim)

            # Projection to hidden space
            self.projection = nn.Sequential(
                nn.Linear(embedding_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.1)
            )

        def forward(self, one_hot_perturbations):
            # Convert one-hot to indices
            perturbation_indices = torch.argmax(one_hot_perturbations, dim=1)

            # Get perturbation embeddings
            perturbation_embeds = self.perturbation_embedding(perturbation_indices)

            # Project to hidden space
            hidden_repr = self.projection(perturbation_embeds)

            return hidden_repr

    class FixedCrossModalFusion(nn.Module):
        """Fixed cross-modal fusion"""

        def __init__(self, latent_dim: int, perturbation_dim: int, fusion_dim: int):
            super().__init__()
            self.latent_projection = nn.Linear(latent_dim, fusion_dim)
            self.perturbation_projection = nn.Linear(perturbation_dim, fusion_dim)

            # Fusion gate
            self.fusion_gate = nn.Sequential(
                nn.Linear(fusion_dim * 2, fusion_dim),
                nn.Sigmoid()
            )

            self.norm = nn.LayerNorm(fusion_dim)
            self.dropout = nn.Dropout(0.1)

        def forward(self, latent, perturbation_encoded):
            # Project both modalities
            latent_proj = self.latent_projection(latent)
            perturbation_proj = self.perturbation_projection(perturbation_encoded)

            # Gated fusion
            concatenated = torch.cat([latent_proj, perturbation_proj], dim=-1)
            fusion_gate = self.fusion_gate(concatenated)

            # Gated fusion
            fused = fusion_gate * latent_proj + (1 - fusion_gate) * perturbation_proj
            fused = self.norm(fused)
            fused = self.dropout(fused)

            return fused

    class FixedPerturbationResponseNetwork(nn.Module):
        """Fixed response network with proper single layer handling"""

        def __init__(self, fusion_dim: int, gene_dim: int, hidden_dims: List[int]):
            super().__init__()

            # Build network layers
            layers = []
            input_dim = fusion_dim

            # Handle both single and multi-layer cases
            for i, hidden_dim in enumerate(hidden_dims):
                layers.extend([
                    nn.Linear(input_dim, hidden_dim),
                    nn.BatchNorm1d(hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(0.1)
                ])
                input_dim = hidden_dim

            self.base_network = nn.Sequential(*layers)

            # Final projection - FIXED: Use current input_dim instead of hidden_dims[-1]
            self.final_projection = nn.Linear(input_dim, gene_dim)

            # Perturbation-aware scaling
            self.scale = nn.Linear(fusion_dim, 1)
            self.bias = nn.Linear(fusion_dim, 1)

        def forward(self, fused_representation):
            base_output = self.base_network(fused_representation)
            expression = self.final_projection(base_output)

            # Perturbation-aware scaling
            scale = torch.sigmoid(self.scale(fused_representation)) * 2
            bias = self.bias(fused_representation)

            return F.softplus(expression * scale + bias)

    class FixedNovelPerturbationPredictor(nn.Module):
        """Fixed novel perturbation predictor"""

        def __init__(self, num_known_perturbations: int, gene_dim: int, hidden_dim: int):
            super().__init__()
            self.num_known_perturbations = num_known_perturbations
            self.gene_dim = gene_dim

            # Learnable perturbation prototypes
            self.perturbation_prototypes = nn.Parameter(
                torch.randn(num_known_perturbations, gene_dim) * 0.1
            )

            # Response generator - handle case where hidden_dim might be 0
            if hidden_dim > 0:
                self.response_generator = nn.Sequential(
                    nn.Linear(num_known_perturbations, hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, gene_dim)
                )
            else:
                # Direct projection if no hidden layer
                self.response_generator = nn.Linear(num_known_perturbations, gene_dim)

            # Attention mechanism
            self.similarity_attention = nn.Sequential(
                nn.Linear(num_known_perturbations, num_known_perturbations),
                nn.Softmax(dim=-1)
            )

        def forward(self, similarity_matrix, latent_features=None):
            batch_size = similarity_matrix.shape[0]

            # Method 1: Attention-weighted combination of known responses
            attention_weights = self.similarity_attention(similarity_matrix)
            weighted_response = torch.matmul(attention_weights, self.perturbation_prototypes)

            # Method 2: Direct generation from similarity
            generated_response = self.response_generator(similarity_matrix)

            # Simple combination
            combination_weights = torch.sigmoid(similarity_matrix.mean(dim=1, keepdim=True))
            final_response = (combination_weights * weighted_response +
                              (1 - combination_weights) * generated_response)

            return final_response

    class FixedMultimodalDecoder(nn.Module):
        """Main decoder with fixed single layer handling"""

        def __init__(self, latent_dim: int, num_known_perturbations: int, gene_dim: int,
                     hidden_dims: List[int], perturbation_embedding_dim: int,
                     biological_prior_dim: int, dropout_rate: float):
            super().__init__()

            self.num_known_perturbations = num_known_perturbations
            self.latent_dim = latent_dim
            self.gene_dim = gene_dim

            # Use first hidden dimension for fusion
            main_hidden_dim = hidden_dims[0]

            # Perturbation encoder
            self.perturbation_encoder = PerturbationAwareDecoder.FixedPerturbationEncoder(
                num_known_perturbations, perturbation_embedding_dim, main_hidden_dim
            )

            # Cross-modal fusion
            self.cross_modal_fusion = PerturbationAwareDecoder.FixedCrossModalFusion(
                latent_dim, main_hidden_dim, main_hidden_dim
            )

            # Response network - FIXED: Use all hidden_dims for response network
            self.response_network = PerturbationAwareDecoder.FixedPerturbationResponseNetwork(
                main_hidden_dim, gene_dim, hidden_dims  # Pass all hidden_dims
            )

            # Novel perturbation predictor
            self.novel_predictor = PerturbationAwareDecoder.FixedNovelPerturbationPredictor(
                num_known_perturbations, gene_dim, main_hidden_dim
            )

        def forward(self, latent, perturbation_matrix, mode='one_hot'):
            if mode == 'one_hot':
                # Known perturbation pathway
                perturbation_encoded = self.perturbation_encoder(perturbation_matrix)
                fused = self.cross_modal_fusion(latent, perturbation_encoded)
                expression = self.response_network(fused)

            elif mode == 'similarity':
                # Novel perturbation pathway
                expression = self.novel_predictor(perturbation_matrix, latent)

            else:
                raise ValueError(f"Unknown mode: {mode}. Use 'one_hot' or 'similarity'")

            return expression

        def get_perturbation_prototypes(self):
            """Get learned perturbation response prototypes"""
            return self.novel_predictor.perturbation_prototypes.detach()

    def _build_fixed_model(self):
        """Build the fixed model"""
        return self.FixedMultimodalDecoder(
            self.latent_dim, self.num_known_perturbations, self.gene_dim,
            self.hidden_dims, self.perturbation_embedding_dim,
            self.biological_prior_dim, self.dropout_rate
        )

    def train(self,
              train_latent: np.ndarray,
              train_perturbations: np.ndarray,
              train_expression: np.ndarray,
              val_latent: np.ndarray = None,
              val_perturbations: np.ndarray = None,
              val_expression: np.ndarray = None,
              batch_size: int = 32,
              num_epochs: int = 200,
              learning_rate: float = 1e-4,
              checkpoint_path: str = 'fixed_decoder.pth') -> Dict:
        """
        Train the fixed decoder
        """
        print("🧬 Starting Training with Fixed Single Layer Support...")

        # Validate one-hot encoding
        self._validate_one_hot_perturbations(train_perturbations)

        # Data preparation
        train_dataset = self._create_dataset(train_latent, train_perturbations, train_expression)

        if val_latent is not None and val_perturbations is not None and val_expression is not None:
            val_dataset = self._create_dataset(val_latent, val_perturbations, val_expression)
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        else:
            train_size = int(0.9 * len(train_dataset))
            val_size = len(train_dataset) - train_size
            train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
            train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
            val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

        print(f"📊 Training samples: {len(train_loader.dataset)}")
        print(f"📊 Validation samples: {len(val_loader.dataset)}")
        print(f"🔧 Hidden layers: {len(self.hidden_dims)}")
        print(f"🔧 Hidden dimensions: {self.hidden_dims}")

        # Optimizer
        optimizer = optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=1e-5
        )

        # Scheduler
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

        # Loss function
        def loss_fn(pred, target):
            mse_loss = F.mse_loss(pred, target)
            poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()
            correlation = self._pearson_correlation(pred, target)
            correlation_loss = 1 - correlation
            return mse_loss + 0.3 * poisson_loss + 0.1 * correlation_loss

        # Training history
        history = {
            'train_loss': [], 'val_loss': [],
            'train_mse': [], 'val_mse': [],
            'train_correlation': [], 'val_correlation': [],
            'learning_rates': []
        }

        best_val_loss = float('inf')
        patience = 20
        patience_counter = 0

        print("\n🔬 Starting training...")
        for epoch in range(1, num_epochs + 1):
            # Training
            train_metrics = self._train_epoch(train_loader, optimizer, loss_fn)

            # Validation
            val_metrics = self._validate_epoch(val_loader, loss_fn)

            # Update scheduler
            scheduler.step()
            current_lr = optimizer.param_groups[0]['lr']

            # Record history
            history['train_loss'].append(train_metrics['loss'])
            history['val_loss'].append(val_metrics['loss'])
            history['train_mse'].append(train_metrics['mse'])
            history['val_mse'].append(val_metrics['mse'])
            history['train_correlation'].append(train_metrics['correlation'])
            history['val_correlation'].append(val_metrics['correlation'])
            history['learning_rates'].append(current_lr)

            # Print progress
            if epoch % 10 == 0 or epoch == 1:
                print(f"🧪 Epoch {epoch:3d}/{num_epochs} | "
                      f"Train: {train_metrics['loss']:.4f} | "
                      f"Val: {val_metrics['loss']:.4f} | "
                      f"Corr: {val_metrics['correlation']:.4f} | "
                      f"LR: {current_lr:.2e}")

            # Early stopping
            if val_metrics['loss'] < best_val_loss:
                best_val_loss = val_metrics['loss']
                patience_counter = 0
                self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"🛑 Early stopping at epoch {epoch}")
                    break

        self.is_trained = True
        self.training_history = history
        self.best_val_loss = best_val_loss
        self.perturbation_prototypes = self.model.get_perturbation_prototypes().cpu().numpy()

        print(f"\n🎉 Training completed! Best val loss: {best_val_loss:.4f}")
        return history

    def _validate_one_hot_perturbations(self, perturbations):
        """Validate that perturbations are proper one-hot encodings"""
        assert perturbations.shape[1] == self.num_known_perturbations, \
            f"Perturbation dimension {perturbations.shape[1]} doesn't match expected {self.num_known_perturbations}"

        row_sums = perturbations.sum(axis=1)
        valid_rows = np.all((row_sums == 0) | (row_sums == 1))
        assert valid_rows, "Perturbations should be one-hot encoded (sum to 0 or 1 per row)"

        print("✅ One-hot perturbations validated")

    def _create_dataset(self, latent_data, perturbations, expression_data):
        """Create dataset with one-hot perturbations"""
        class OneHotDataset(Dataset):
            def __init__(self, latent, perturbations, expression):
                self.latent = torch.FloatTensor(latent)
                self.perturbations = torch.FloatTensor(perturbations)
                self.expression = torch.FloatTensor(expression)

            def __len__(self):
                return len(self.latent)

            def __getitem__(self, idx):
                return self.latent[idx], self.perturbations[idx], self.expression[idx]

        return OneHotDataset(latent_data, perturbations, expression_data)

    def predict(self,
                latent_data: np.ndarray,
                perturbations: np.ndarray,
                batch_size: int = 32) -> np.ndarray:
        """
        Predict expression for known perturbations
        """
        if not self.is_trained:
            warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")

        self._validate_one_hot_perturbations(perturbations)

        self.model.eval()

        if isinstance(latent_data, np.ndarray):
            latent_data = torch.FloatTensor(latent_data)
        if isinstance(perturbations, np.ndarray):
            perturbations = torch.FloatTensor(perturbations)

        predictions = []
        with torch.no_grad():
            for i in range(0, len(latent_data), batch_size):
                batch_latent = latent_data[i:i+batch_size].to(self.device)
                batch_perturbations = perturbations[i:i+batch_size].to(self.device)

                batch_pred = self.model(batch_latent, batch_perturbations, mode='one_hot')
                predictions.append(batch_pred.cpu())

        return torch.cat(predictions).numpy()

    def predict_novel_perturbation(self,
                                   latent_data: np.ndarray,
                                   similarity_matrix: np.ndarray,
                                   batch_size: int = 32) -> np.ndarray:
        """
        Predict response to novel perturbations
        """
        if not self.is_trained:
            warnings.warn("⚠️ Model not trained. Novel perturbation prediction may be inaccurate.")

        assert similarity_matrix.shape[1] == self.num_known_perturbations, \
            f"Similarity matrix columns {similarity_matrix.shape[1]} must match known perturbations {self.num_known_perturbations}"

        self.model.eval()

        if isinstance(latent_data, np.ndarray):
            latent_data = torch.FloatTensor(latent_data)
        if isinstance(similarity_matrix, np.ndarray):
            similarity_matrix = torch.FloatTensor(similarity_matrix)

        predictions = []
        with torch.no_grad():
            for i in range(0, len(latent_data), batch_size):
                batch_latent = latent_data[i:i+batch_size].to(self.device)
                batch_similarity = similarity_matrix[i:i+batch_size].to(self.device)

                batch_pred = self.model(batch_latent, batch_similarity, mode='similarity')
                predictions.append(batch_pred.cpu())

        return torch.cat(predictions).numpy()

    def get_known_perturbation_prototypes(self) -> np.ndarray:
        """Get learned perturbation response prototypes"""
        if not self.is_trained:
            warnings.warn("⚠️ Model not trained. Prototypes may be uninformative.")

        if self.perturbation_prototypes is None:
            self.model.eval()
            with torch.no_grad():
                self.perturbation_prototypes = self.model.get_perturbation_prototypes().cpu().numpy()

        return self.perturbation_prototypes

    def _pearson_correlation(self, pred, target):
        """Calculate Pearson correlation coefficient"""
        pred_centered = pred - pred.mean(dim=1, keepdim=True)
        target_centered = target - target.mean(dim=1, keepdim=True)

        numerator = (pred_centered * target_centered).sum(dim=1)
        denominator = torch.sqrt(torch.sum(pred_centered ** 2, dim=1)) * torch.sqrt(torch.sum(target_centered ** 2, dim=1))

        return (numerator / (denominator + 1e-8)).mean()

    def _train_epoch(self, train_loader, optimizer, loss_fn):
        """Train one epoch"""
        self.model.train()
        total_loss = 0
        total_mse = 0
        total_correlation = 0

        for latent, perturbations, target in train_loader:
            latent = latent.to(self.device)
            perturbations = perturbations.to(self.device)
            target = target.to(self.device)

            optimizer.zero_grad()
            pred = self.model(latent, perturbations, mode='one_hot')

            loss = loss_fn(pred, target)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            optimizer.step()

            mse_loss = F.mse_loss(pred, target).item()
            correlation = self._pearson_correlation(pred, target).item()

            total_loss += loss.item()
            total_mse += mse_loss
            total_correlation += correlation

        num_batches = len(train_loader)
        return {
            'loss': total_loss / num_batches,
            'mse': total_mse / num_batches,
            'correlation': total_correlation / num_batches
        }

    def _validate_epoch(self, val_loader, loss_fn):
        """Validate one epoch"""
        self.model.eval()
        total_loss = 0
        total_mse = 0
        total_correlation = 0

        with torch.no_grad():
            for latent, perturbations, target in val_loader:
                latent = latent.to(self.device)
                perturbations = perturbations.to(self.device)
                target = target.to(self.device)

                pred = self.model(latent, perturbations, mode='one_hot')
                loss = loss_fn(pred, target)
                mse_loss = F.mse_loss(pred, target).item()
                correlation = self._pearson_correlation(pred, target).item()

                total_loss += loss.item()
                total_mse += mse_loss
                total_correlation += correlation

        num_batches = len(val_loader)
        return {
            'loss': total_loss / num_batches,
            'mse': total_mse / num_batches,
            'correlation': total_correlation / num_batches
        }

    def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
        """Save model checkpoint"""
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_loss': best_loss,
            'training_history': history,
            'perturbation_prototypes': self.perturbation_prototypes,
            'model_config': {
                'latent_dim': self.latent_dim,
                'num_known_perturbations': self.num_known_perturbations,
                'gene_dim': self.gene_dim,
                'hidden_dims': self.hidden_dims
            }
        }, path)

    def load_model(self, model_path: str):
        """Load pre-trained model"""
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.perturbation_prototypes = checkpoint.get('perturbation_prototypes')
        self.is_trained = True
        self.training_history = checkpoint.get('training_history')
        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
        print(f"✅ Model loaded! Best val loss: {self.best_val_loss:.4f}")

'''# Test the fixed implementation
def test_single_layer_fix():
    """Test the fixed single layer implementation"""

    print("🧪 Testing single layer configuration...")

    # Test with single hidden layer
    decoder_single = PerturbationAwareDecoder(
        latent_dim=100,
        num_known_perturbations=10,
        gene_dim=2000,
        hidden_dims=[512],  # Single element list
        perturbation_embedding_dim=128
    )

    # Generate test data
    n_samples = 100
    latent_data = np.random.randn(n_samples, 100).astype(np.float32)
    perturbations = np.zeros((n_samples, 10))
    for i in range(n_samples):
        if i % 10 != 0:
            perturbations[i, np.random.randint(0, 10)] = 1.0

    expression_data = np.random.randn(n_samples, 2000).astype(np.float32)
    expression_data = np.maximum(expression_data, 0)

    # Test forward pass
    decoder_single.model.eval()
    with torch.no_grad():
        latent_tensor = torch.FloatTensor(latent_data[:5]).to(decoder_single.device)
        perturbations_tensor = torch.FloatTensor(perturbations[:5]).to(decoder_single.device)

        # Test known perturbation prediction
        output = decoder_single.model(latent_tensor, perturbations_tensor, mode='one_hot')
        print(f"✅ Known perturbation prediction shape: {output.shape}")

        # Test novel perturbation prediction
        similarity_matrix = np.random.rand(5, 10).astype(np.float32)
        similarity_tensor = torch.FloatTensor(similarity_matrix).to(decoder_single.device)
        novel_output = decoder_single.model(latent_tensor, similarity_tensor, mode='similarity')
        print(f"✅ Novel perturbation prediction shape: {novel_output.shape}")

    print("🎉 Single layer test passed!")

def test_multi_layer_fix():
    """Test the multi-layer implementation"""

    print("\n🧪 Testing multi-layer configuration...")

    # Test with multiple hidden layers
    decoder_multi = PerturbationAwareDecoder(
        latent_dim=100,
        num_known_perturbations=10,
        gene_dim=2000,
        hidden_dims=[256, 512, 1024],  # Multiple layers
        perturbation_embedding_dim=128
    )

    print("🎉 Multi-layer test passed!")

def test_edge_cases():
    """Test edge cases"""

    print("\n🧪 Testing edge cases...")

    # Test with different hidden_dims configurations
    configs = [
        [512],  # Single layer
        [256, 512],  # Two layers
        [128, 256, 512],  # Three layers
        [1024],  # Wide single layer
        [64, 128, 256, 512, 1024]  # Deep network
    ]

    for i, hidden_dims in enumerate(configs):
        try:
            decoder = PerturbationAwareDecoder(
                latent_dim=50,
                num_known_perturbations=5,
                gene_dim=1000,
                hidden_dims=hidden_dims,
                perturbation_embedding_dim=64
            )
            print(f"✅ Config {i+1}: {hidden_dims} - Success")
        except Exception as e:
            print(f"❌ Config {i+1}: {hidden_dims} - Failed: {e}")

    print("🎉 Edge case testing completed!")

if __name__ == "__main__":
    # Run tests
    test_single_layer_fix()
    test_multi_layer_fix()
    test_edge_cases()

    # Example usage
    print("\n🎯 Example Usage:")

    # Single hidden layer example
    decoder = PerturbationAwareDecoder(
        latent_dim=100,
        num_known_perturbations=10,
        gene_dim=2000,
        hidden_dims=[512],  # Single hidden layer
        perturbation_embedding_dim=128
    )

    # Generate example data
    n_samples = 1000
    latent_data = np.random.randn(n_samples, 100).astype(np.float32)
    perturbations = np.zeros((n_samples, 10))
    for i in range(n_samples):
        if i % 10 != 0:
            perturbations[i, np.random.randint(0, 10)] = 1.0

    # Simulate expression data
    base_weights = np.random.randn(100, 2000) * 0.1
    perturbation_effects = np.random.randn(10, 2000) * 0.5

    expression_data = np.tanh(latent_data.dot(base_weights))
    for i in range(n_samples):
        if perturbations[i].sum() > 0:
            perturb_id = np.argmax(perturbations[i])
            expression_data[i] += perturbation_effects[perturb_id]

    expression_data = np.maximum(expression_data, 0)

    print(f"📊 Example data shapes: Latent {latent_data.shape}, Perturbations {perturbations.shape}")

    # Train (commented out for quick testing)
    # history = decoder.train(
    #     train_latent=latent_data,
    #     train_perturbations=perturbations,
    #     train_expression=expression_data,
    #     batch_size=32,
    #     num_epochs=10  # Short training for testing
    # )

    print("🎉 All tests completed successfully!")'''