SURE-tools 2.4.17__tar.gz → 2.4.20__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sure_tools-2.4.17 → sure_tools-2.4.20}/PKG-INFO +1 -1
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/DensityFlow.py +7 -0
- sure_tools-2.4.20/SURE/PerturbationAwareDecoder.py +770 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/VirtualCellDecoder.py +0 -1
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/__init__.py +4 -1
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE_tools.egg-info/SOURCES.txt +1 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/setup.py +1 -1
- {sure_tools-2.4.17 → sure_tools-2.4.20}/LICENSE +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/README.md +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/EfficientTranscriptomeDecoder.py +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/PerturbE.py +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/SURE.py +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/SimpleTranscriptomeDecoder.py +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/TranscriptomeDecoder.py +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/atac/utils.py +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/perturb/perturb.py +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/utils/queue.py +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/utils/utils.py +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.4.17 → sure_tools-2.4.20}/setup.cfg +0 -0
|
@@ -109,6 +109,13 @@ class DensityFlow(nn.Module):
|
|
|
109
109
|
|
|
110
110
|
set_random_seed(seed)
|
|
111
111
|
self.setup_networks()
|
|
112
|
+
|
|
113
|
+
print(f"🧬 DensityFlow Initialized:")
|
|
114
|
+
print(f" - Latent Dimension: {self.latent_dim}")
|
|
115
|
+
print(f" - Gene Dimension: {self.input_size}")
|
|
116
|
+
print(f" - Hidden Dimensions: {self.hidden_layers}")
|
|
117
|
+
print(f" - Device: {self.get_device()}")
|
|
118
|
+
print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
|
|
112
119
|
|
|
113
120
|
def setup_networks(self):
|
|
114
121
|
latent_dim = self.latent_dim
|
|
@@ -0,0 +1,770 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
import torch.optim as optim
|
|
5
|
+
from torch.utils.data import Dataset, DataLoader
|
|
6
|
+
import numpy as np
|
|
7
|
+
from typing import Dict, List, Optional, Tuple, Union
|
|
8
|
+
import math
|
|
9
|
+
import warnings
|
|
10
|
+
# Ignore only UserWarnings raised from inside torch. A bare
# ``warnings.filterwarnings('ignore')`` at import time would silence every
# warning in the whole process -- including the "Model not trained" and
# "Similarity scores outside typical range" warnings this module emits via
# ``warnings.warn``, which then could never be seen by users.
warnings.filterwarnings('ignore', category=UserWarning, module=r'torch(\.|$)')
|
|
11
|
+
|
|
12
|
+
class PerturbationAwareDecoder:
|
|
13
|
+
"""
|
|
14
|
+
Advanced transcriptome decoder with perturbation awareness
|
|
15
|
+
Similarity matrix columns correspond to known perturbations for novel perturbation prediction
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
def __init__(self,
|
|
19
|
+
latent_dim: int = 100,
|
|
20
|
+
num_known_perturbations: int = 50, # Number of known perturbation types
|
|
21
|
+
gene_dim: int = 60000,
|
|
22
|
+
hidden_dims: List[int] = [512, 1024, 2048],
|
|
23
|
+
perturbation_embedding_dim: int = 128,
|
|
24
|
+
biological_prior_dim: int = 256,
|
|
25
|
+
dropout_rate: float = 0.1,
|
|
26
|
+
device: str = None):
|
|
27
|
+
"""
|
|
28
|
+
Multi-modal decoder with correct similarity matrix definition
|
|
29
|
+
|
|
30
|
+
Args:
|
|
31
|
+
latent_dim: Latent variable dimension
|
|
32
|
+
num_known_perturbations: Number of known perturbation types for one-hot encoding
|
|
33
|
+
gene_dim: Number of genes
|
|
34
|
+
hidden_dims: Hidden layer dimensions
|
|
35
|
+
perturbation_embedding_dim: Embedding dimension for perturbations
|
|
36
|
+
biological_prior_dim: Dimension for biological prior knowledge
|
|
37
|
+
dropout_rate: Dropout rate
|
|
38
|
+
device: Computation device
|
|
39
|
+
"""
|
|
40
|
+
self.latent_dim = latent_dim
|
|
41
|
+
self.num_known_perturbations = num_known_perturbations
|
|
42
|
+
self.gene_dim = gene_dim
|
|
43
|
+
self.hidden_dims = hidden_dims
|
|
44
|
+
self.perturbation_embedding_dim = perturbation_embedding_dim
|
|
45
|
+
self.biological_prior_dim = biological_prior_dim
|
|
46
|
+
self.dropout_rate = dropout_rate
|
|
47
|
+
self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
|
|
48
|
+
|
|
49
|
+
# Initialize multi-modal model
|
|
50
|
+
self.model = self._build_corrected_model()
|
|
51
|
+
self.model.to(self.device)
|
|
52
|
+
|
|
53
|
+
# Training state
|
|
54
|
+
self.is_trained = False
|
|
55
|
+
self.training_history = None
|
|
56
|
+
self.best_val_loss = float('inf')
|
|
57
|
+
self.known_perturbation_names = [] # For mapping indices to perturbation names
|
|
58
|
+
self.perturbation_prototypes = None # Learned perturbation representations
|
|
59
|
+
|
|
60
|
+
print(f"🧬 PerturbationAwareDecoder Initialized:")
|
|
61
|
+
print(f" - Latent Dimension: {latent_dim}")
|
|
62
|
+
print(f" - Known Perturbations: {num_known_perturbations}")
|
|
63
|
+
print(f" - Gene Dimension: {gene_dim}")
|
|
64
|
+
print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
|
|
65
|
+
|
|
66
|
+
class CorrectedPerturbationEncoder(nn.Module):
    """Encoder for one-hot encoded perturbations.

    Maps each row of a [batch_size, num_perturbations] one-hot matrix to a
    [batch_size, hidden_dim] vector: embedding lookup -> 2-layer MLP ->
    self-attention residual block with LayerNorm.
    """

    def __init__(self, num_perturbations: int, embedding_dim: int, hidden_dim: int):
        """
        Args:
            num_perturbations: Number of known perturbation types (one-hot width).
            embedding_dim: Size of each perturbation embedding.
            hidden_dim: Output width; must be divisible by the 8 attention heads.
        """
        super().__init__()
        self.num_perturbations = num_perturbations

        # Embedding for perturbation types
        self.perturbation_embedding = nn.Embedding(num_perturbations, embedding_dim)

        # Projection to hidden space
        self.projection = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Attention mechanism for perturbation context
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True)
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, one_hot_perturbations):
        """
        Args:
            one_hot_perturbations: [batch_size, num_perturbations] one-hot encoded;
                an all-zero row denotes a control (unperturbed) sample.

        Returns:
            [batch_size, hidden_dim] perturbation representation.
        """
        # Multiplying by the embedding table selects exactly the embedding row for
        # a one-hot row, and yields a zero vector for an all-zero control row.
        # (argmax would silently map controls onto perturbation 0's embedding.)
        weights = one_hot_perturbations.to(self.perturbation_embedding.weight.dtype)
        perturbation_embeds = weights @ self.perturbation_embedding.weight  # [batch_size, embedding_dim]

        # Project to hidden space
        hidden_repr = self.projection(perturbation_embeds)  # [batch_size, hidden_dim]

        # Add a length-1 sequence dimension for attention
        hidden_repr = hidden_repr.unsqueeze(1)  # [batch_size, 1, hidden_dim]

        # Self-attention over the single token, followed by a residual LayerNorm
        attended, _ = self.attention(hidden_repr, hidden_repr, hidden_repr)
        attended = self.norm(hidden_repr + attended)

        return attended.squeeze(1)  # [batch_size, hidden_dim]
|
|
112
|
+
|
|
113
|
+
class CrossModalFusion(nn.Module):
    """Fuse a latent cell state with an encoded perturbation.

    Both inputs are projected to ``fusion_dim``. The projected latent queries
    the projected perturbation through multi-head cross-attention (a single key,
    so the attention weight is trivially 1). A sigmoid gate then blends the
    attention output with the projected latent, followed by LayerNorm and dropout.
    """

    def __init__(self, latent_dim: int, perturbation_dim: int, fusion_dim: int):
        super().__init__()
        # Attribute names are part of the state_dict layout; keep them stable.
        self.latent_projection = nn.Linear(latent_dim, fusion_dim)
        self.perturbation_projection = nn.Linear(perturbation_dim, fusion_dim)
        self.cross_attention = nn.MultiheadAttention(fusion_dim, num_heads=8, batch_first=True)
        self.fusion_gate = nn.Sequential(nn.Linear(2 * fusion_dim, fusion_dim), nn.Sigmoid())
        self.norm = nn.LayerNorm(fusion_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, latent, perturbation_encoded):
        """Return the [batch_size, fusion_dim] gated fusion of the two inputs."""
        query = self.latent_projection(latent)[:, None, :]
        context = self.perturbation_projection(perturbation_encoded)[:, None, :]

        attention_out = self.cross_attention(query, context, context)[0]

        gate = self.fusion_gate(torch.cat((attention_out, query), dim=-1))
        blended = gate * attention_out + (1 - gate) * query

        return self.dropout(self.norm(blended))[:, 0, :]
|
|
152
|
+
|
|
153
|
+
class PerturbationResponseNetwork(nn.Module):
    """Network for predicting perturbation-specific responses.

    An MLP (Linear -> BatchNorm1d -> ReLU -> Dropout for each entry of
    ``hidden_dims``) feeds a final linear layer with ``gene_dim`` outputs. The
    result is multiplied by a per-sample scale in (0, 2) and shifted by a
    per-sample bias, both predicted from the fused input, and passed through
    softplus so the predicted expression is non-negative.

    Note: BatchNorm1d cannot normalise a batch of a single sample in training
    mode, so callers should avoid size-1 training batches.
    """

    def __init__(self, fusion_dim: int, gene_dim: int, hidden_dims: List[int],
                 dropout_rate: float = 0.1):
        """
        Args:
            fusion_dim: Width of the fused input representation.
            gene_dim: Number of genes to predict.
            hidden_dims: Hidden layer widths. May be empty, in which case the
                fused input is projected directly to ``gene_dim``.
            dropout_rate: Dropout probability after each hidden layer.
        """
        super().__init__()

        # Base network
        layers = []
        input_dim = fusion_dim
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(input_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
            ])
            input_dim = hidden_dim

        self.base_network = nn.Sequential(*layers)
        # ``input_dim`` is the last hidden width, or ``fusion_dim`` when there are
        # no hidden layers (indexing ``hidden_dims[-1]`` would fail in that case).
        self.final_projection = nn.Linear(input_dim, gene_dim)

        # Perturbation-aware scaling
        self.scale = nn.Linear(fusion_dim, 1)
        self.bias = nn.Linear(fusion_dim, 1)

    def forward(self, fused_representation):
        """Map [batch_size, fusion_dim] to non-negative [batch_size, gene_dim] expression."""
        base_output = self.base_network(fused_representation)
        expression = self.final_projection(base_output)

        # Per-sample scale in (0, 2) and additive bias, both from the fused input
        scale = torch.sigmoid(self.scale(fused_representation)) * 2
        bias = self.bias(fused_representation)

        return F.softplus(expression * scale + bias)
|
|
188
|
+
|
|
189
|
+
class CorrectedNovelPerturbationPredictor(nn.Module):
    """Predict the response to an unseen perturbation from its similarity to known ones.

    Two estimates are blended:
      * a similarity-weighted sum of learnable per-perturbation prototypes, and
      * an MLP applied directly to the similarity row.
    The blend weight is sigmoid(mean similarity of the row); it is computed,
    not a learned parameter. An optional latent input rescales the result by
    ``1 + 0.5 * sigmoid(mean(latent))``. Softplus keeps the output non-negative.
    """

    def __init__(self, num_known_perturbations: int, gene_dim: int, hidden_dim: int):
        super().__init__()
        self.num_known_perturbations = num_known_perturbations
        self.gene_dim = gene_dim

        # One learnable response profile per known perturbation (small random init).
        initial_prototypes = torch.randn(num_known_perturbations, gene_dim) * 0.1
        self.perturbation_prototypes = nn.Parameter(initial_prototypes)

        # MLP from a similarity row straight to an expression profile.
        self.response_generator = nn.Sequential(
            nn.Linear(num_known_perturbations, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, gene_dim),
        )

        # NOTE: registered (and therefore saved in the state_dict) but never read
        # by forward(); kept so existing checkpoints keep loading.
        self.attention_weights = nn.Parameter(torch.randn(num_known_perturbations, 1))

    def forward(self, similarity_matrix, latent_features=None):
        """
        Args:
            similarity_matrix: [batch_size, num_known_perturbations] similarity of
                each novel perturbation to every known perturbation.
            latent_features: optional [batch_size, latent_dim] cell state.

        Returns:
            [batch_size, gene_dim] predicted expression.
        """
        prototype_mix = similarity_matrix @ self.perturbation_prototypes
        direct_estimate = self.response_generator(similarity_matrix)

        alpha = torch.sigmoid(similarity_matrix.mean(dim=1, keepdim=True))
        response = alpha * prototype_mix + (1 - alpha) * direct_estimate

        if latent_features is None:
            return F.softplus(response)

        gain = 1 + 0.5 * torch.sigmoid(latent_features.mean(dim=1, keepdim=True))
        return F.softplus(response * gain)
|
|
249
|
+
|
|
250
|
+
class CorrectedMultimodalDecoder(nn.Module):
    """Decoder network with a known-perturbation path and a novel-perturbation path.

    * ``mode='one_hot'``: one-hot perturbation -> CorrectedPerturbationEncoder ->
      CrossModalFusion with the latent -> PerturbationResponseNetwork.
    * ``mode='similarity'``: similarity-to-known-perturbations row (plus the
      latent) -> CorrectedNovelPerturbationPredictor.

    ``biological_prior_dim`` and ``dropout_rate`` are accepted but not used here.
    """

    def __init__(self, latent_dim: int, num_known_perturbations: int, gene_dim: int,
                 hidden_dims: List[int], perturbation_embedding_dim: int,
                 biological_prior_dim: int, dropout_rate: float):
        super().__init__()

        self.num_known_perturbations = num_known_perturbations
        self.latent_dim = latent_dim
        self.gene_dim = gene_dim

        outer = PerturbationAwareDecoder
        # Width shared by the encoder output, the fusion space and the
        # novel predictor's hidden layer.
        shared_width = hidden_dims[0]

        # Submodule creation order is kept so seeded initialisation is unchanged.
        self.perturbation_encoder = outer.CorrectedPerturbationEncoder(
            num_known_perturbations, perturbation_embedding_dim, shared_width
        )
        self.cross_modal_fusion = outer.CrossModalFusion(
            latent_dim, shared_width, shared_width
        )
        self.response_network = outer.PerturbationResponseNetwork(
            shared_width, gene_dim, hidden_dims[1:]
        )
        self.novel_predictor = outer.CorrectedNovelPerturbationPredictor(
            num_known_perturbations, gene_dim, shared_width
        )

    def forward(self, latent, perturbation_matrix, mode='one_hot'):
        """
        Args:
            latent: [batch_size, latent_dim] latent variables
            perturbation_matrix: [batch_size, num_known_perturbations]; one-hot rows
                in 'one_hot' mode, similarity scores in 'similarity' mode.
            mode: 'one_hot' for known perturbations, 'similarity' for novel ones.

        Raises:
            ValueError: for any other ``mode``.
        """
        if mode == 'similarity':
            return self.novel_predictor(perturbation_matrix, latent)
        if mode != 'one_hot':
            raise ValueError(f"Unknown mode: {mode}. Use 'one_hot' or 'similarity'")

        encoded = self.perturbation_encoder(perturbation_matrix)
        fused = self.cross_modal_fusion(latent, encoded)
        return self.response_network(fused)

    def get_perturbation_prototypes(self):
        """Return the novel predictor's prototype matrix, detached from autograd."""
        prototypes = self.novel_predictor.perturbation_prototypes
        return prototypes.detach()
|
|
312
|
+
|
|
313
|
+
def _build_corrected_model(self):
|
|
314
|
+
"""Build the corrected model"""
|
|
315
|
+
return self.CorrectedMultimodalDecoder(
|
|
316
|
+
self.latent_dim, self.num_known_perturbations, self.gene_dim,
|
|
317
|
+
self.hidden_dims, self.perturbation_embedding_dim,
|
|
318
|
+
self.biological_prior_dim, self.dropout_rate
|
|
319
|
+
)
|
|
320
|
+
|
|
321
|
+
def train(self,
          train_latent: np.ndarray,
          train_perturbations: np.ndarray,  # One-hot encoded [n_samples, num_known_perturbations]
          train_expression: np.ndarray,
          val_latent: Optional[np.ndarray] = None,
          val_perturbations: Optional[np.ndarray] = None,
          val_expression: Optional[np.ndarray] = None,
          batch_size: int = 32,
          num_epochs: int = 200,
          learning_rate: float = 1e-4,
          checkpoint_path: str = 'corrected_decoder.pth',
          patience: int = 20) -> Dict:
    """
    Train the decoder with one-hot encoded perturbations.

    Optimises ``self.model`` with AdamW (weight decay 1e-5) and a cosine
    learning-rate schedule on the loss
    ``MSE + 0.3 * Poisson NLL (up to a constant) + 0.1 * (1 - Pearson correlation)``.
    Each epoch that improves the validation loss is written to
    ``checkpoint_path``; training stops after ``patience`` epochs without
    improvement. The weights of the best validation epoch are restored into
    ``self.model`` before returning, so the in-memory model matches the
    checkpoint on disk.

    Args:
        train_latent: [n_samples, latent_dim] latent variables.
        train_perturbations: [n_samples, num_known_perturbations] one-hot rows
            (all-zero rows are accepted as controls).
        train_expression: [n_samples, gene_dim] target expression.
        val_latent, val_perturbations, val_expression: Optional validation set,
            used only when all three are given; otherwise a random 10% of the
            training data is held out.
        batch_size: Mini-batch size.
        num_epochs: Maximum number of epochs.
        learning_rate: Initial AdamW learning rate.
        checkpoint_path: Where the best checkpoint is saved.
        patience: Epochs without validation improvement before early stopping.

    Returns:
        History dict with per-epoch 'train_loss', 'val_loss', 'train_mse',
        'val_mse', 'train_correlation', 'val_correlation' and 'learning_rates'.
    """
    print("🧬 Starting Training with Corrected Similarity Definition...")

    # Validate one-hot encoding
    self._validate_one_hot_perturbations(train_perturbations)

    # Data preparation
    train_dataset = self._create_dataset(train_latent, train_perturbations, train_expression)

    if val_latent is not None and val_perturbations is not None and val_expression is not None:
        # Validation perturbations go through the same one-hot check as training ones.
        self._validate_one_hot_perturbations(val_perturbations)
        train_data = train_dataset
        val_data = self._create_dataset(val_latent, val_perturbations, val_expression)
    else:
        train_size = int(0.9 * len(train_dataset))
        val_size = len(train_dataset) - train_size
        train_data, val_data = torch.utils.data.random_split(train_dataset, [train_size, val_size])

    # BatchNorm1d in the response network raises on a single-sample batch in
    # training mode, so drop a trailing batch that would contain exactly one sample.
    drop_last = len(train_data) > batch_size and len(train_data) % batch_size == 1
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=drop_last)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

    print(f"📊 Training samples: {len(train_loader.dataset)}")
    print(f"📊 Validation samples: {len(val_loader.dataset)}")
    print(f"🧪 Known perturbations: {self.num_known_perturbations}")

    # Optimizer
    optimizer = optim.AdamW(
        self.model.parameters(),
        lr=learning_rate,
        weight_decay=1e-5
    )

    # Scheduler
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    # Loss function
    def loss_fn(pred, target):
        mse_loss = F.mse_loss(pred, target)
        poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()
        correlation = self._pearson_correlation(pred, target)
        correlation_loss = 1 - correlation
        return mse_loss + 0.3 * poisson_loss + 0.1 * correlation_loss

    # Training history
    history = {
        'train_loss': [], 'val_loss': [],
        'train_mse': [], 'val_mse': [],
        'train_correlation': [], 'val_correlation': [],
        'learning_rates': []
    }

    best_val_loss = float('inf')
    best_state = None  # CPU copy of the best weights seen so far
    patience_counter = 0

    print("\n🔬 Starting training...")
    for epoch in range(1, num_epochs + 1):
        train_metrics = self._train_epoch(train_loader, optimizer, loss_fn)
        val_metrics = self._validate_epoch(val_loader, loss_fn)

        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']

        # Record history
        history['train_loss'].append(train_metrics['loss'])
        history['val_loss'].append(val_metrics['loss'])
        history['train_mse'].append(train_metrics['mse'])
        history['val_mse'].append(val_metrics['mse'])
        history['train_correlation'].append(train_metrics['correlation'])
        history['val_correlation'].append(val_metrics['correlation'])
        history['learning_rates'].append(current_lr)

        # Print progress
        if epoch % 10 == 0 or epoch == 1:
            print(f"🧪 Epoch {epoch:3d}/{num_epochs} | "
                  f"Train: {train_metrics['loss']:.4f} | "
                  f"Val: {val_metrics['loss']:.4f} | "
                  f"Corr: {val_metrics['correlation']:.4f} | "
                  f"LR: {current_lr:.2e}")

        # Early stopping
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            patience_counter = 0
            # Keep the best weights in host memory so they can be restored after
            # early stopping without doubling device memory.
            best_state = {
                name: tensor.detach().to('cpu', copy=True)
                for name, tensor in self.model.state_dict().items()
            }
            self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"🛑 Early stopping at epoch {epoch}")
                break

    # Without this the model would keep the weights of the final epoch, which
    # can be up to ``patience`` epochs past the reported best.
    if best_state is not None:
        self.model.load_state_dict(best_state)

    self.is_trained = True
    self.training_history = history
    self.best_val_loss = best_val_loss
    self.perturbation_prototypes = self.model.get_perturbation_prototypes().cpu().numpy()

    print(f"\n🎉 Training completed! Best val loss: {best_val_loss:.4f}")
    print(f"📊 Learned perturbation prototypes: {self.perturbation_prototypes.shape}")
    return history
|
|
436
|
+
|
|
437
|
+
def _validate_one_hot_perturbations(self, perturbations):
|
|
438
|
+
"""Validate that perturbations are proper one-hot encodings"""
|
|
439
|
+
assert perturbations.shape[1] == self.num_known_perturbations, \
|
|
440
|
+
f"Perturbation dimension {perturbations.shape[1]} doesn't match expected {self.num_known_perturbations}"
|
|
441
|
+
|
|
442
|
+
# Check that each row sums to 1 (perturbation) or 0 (control)
|
|
443
|
+
row_sums = perturbations.sum(axis=1)
|
|
444
|
+
valid_rows = np.all((row_sums == 0) | (row_sums == 1))
|
|
445
|
+
assert valid_rows, "Perturbations should be one-hot encoded (sum to 0 or 1 per row)"
|
|
446
|
+
|
|
447
|
+
print("✅ One-hot perturbations validated")
|
|
448
|
+
|
|
449
|
+
def _create_dataset(self, latent_data, perturbations, expression_data):
|
|
450
|
+
"""Create dataset with one-hot perturbations"""
|
|
451
|
+
class OneHotDataset(Dataset):
|
|
452
|
+
def __init__(self, latent, perturbations, expression):
|
|
453
|
+
self.latent = torch.FloatTensor(latent)
|
|
454
|
+
self.perturbations = torch.FloatTensor(perturbations)
|
|
455
|
+
self.expression = torch.FloatTensor(expression)
|
|
456
|
+
|
|
457
|
+
def __len__(self):
|
|
458
|
+
return len(self.latent)
|
|
459
|
+
|
|
460
|
+
def __getitem__(self, idx):
|
|
461
|
+
return self.latent[idx], self.perturbations[idx], self.expression[idx]
|
|
462
|
+
|
|
463
|
+
return OneHotDataset(latent_data, perturbations, expression_data)
|
|
464
|
+
|
|
465
|
+
def predict(self,
            latent_data: np.ndarray,
            perturbations: np.ndarray,
            batch_size: int = 32) -> np.ndarray:
    """
    Predict expression for known perturbations using one-hot encoding.

    Args:
        latent_data: [n_samples, latent_dim] latent variables (array-like or tensor)
        perturbations: [n_samples, num_known_perturbations] one-hot encoded perturbations
        batch_size: Batch size

    Returns:
        expression: [n_samples, gene_dim] predicted expression. An empty input
        yields an empty ``(0, gene_dim)`` array.

    Raises:
        AssertionError: if ``perturbations`` fail one-hot validation.
        ValueError: if the inputs have different numbers of rows.
    """
    if not self.is_trained:
        warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")

    # Validate one-hot encoding
    self._validate_one_hot_perturbations(perturbations)

    # as_tensor accepts NumPy arrays, lists and tensors, and guarantees the float32
    # dtype the network expects (a float64 tensor would otherwise reach the model).
    latent_tensor = torch.as_tensor(latent_data, dtype=torch.float32)
    perturbation_tensor = torch.as_tensor(perturbations, dtype=torch.float32)

    if len(latent_tensor) != len(perturbation_tensor):
        raise ValueError(
            f"latent_data has {len(latent_tensor)} rows but perturbations has {len(perturbation_tensor)}")
    if len(latent_tensor) == 0:
        # torch.cat([]) would raise; return a correctly shaped empty result instead.
        return np.empty((0, self.gene_dim), dtype=np.float32)

    self.model.eval()

    predictions = []
    with torch.no_grad():
        for start in range(0, len(latent_tensor), batch_size):
            stop = start + batch_size
            batch_latent = latent_tensor[start:stop].to(self.device)
            batch_perturbations = perturbation_tensor[start:stop].to(self.device)

            # Use one-hot mode for known perturbations
            batch_pred = self.model(batch_latent, batch_perturbations, mode='one_hot')
            predictions.append(batch_pred.cpu())

    return torch.cat(predictions).numpy()
|
|
504
|
+
|
|
505
|
+
def predict_novel_perturbation(self,
                               latent_data: np.ndarray,
                               similarity_matrix: np.ndarray,
                               batch_size: int = 32) -> np.ndarray:
    """
    Predict response to novel perturbations using similarity to known perturbations.

    Args:
        latent_data: [n_samples, latent_dim] latent variables
        similarity_matrix: [n_samples, num_known_perturbations]
            Each row: similarity scores between novel perturbation and known perturbations
            Columns correspond to model's known perturbation types
        batch_size: Batch size

    Returns:
        expression: [n_samples, gene_dim] predicted expression. An empty input
        yields an empty ``(0, gene_dim)`` array.

    Raises:
        AssertionError: if ``similarity_matrix`` is not 2-D with
            ``num_known_perturbations`` columns. It is raised explicitly so the
            check survives ``python -O``; the type is kept for backward compatibility.
        ValueError: if the inputs have different numbers of rows.
    """
    if not self.is_trained:
        warnings.warn("⚠️ Model not trained. Novel perturbation prediction may be inaccurate.")

    latent_tensor = torch.as_tensor(latent_data, dtype=torch.float32)
    similarity_tensor = torch.as_tensor(similarity_matrix, dtype=torch.float32)

    # Validate similarity matrix dimensions
    if similarity_tensor.dim() != 2:
        raise AssertionError(
            f"Similarity matrix must be 2-D [n_samples, {self.num_known_perturbations}], "
            f"got shape {tuple(similarity_tensor.shape)}")
    if similarity_tensor.shape[1] != self.num_known_perturbations:
        raise AssertionError(
            f"Similarity matrix columns {similarity_tensor.shape[1]} must match known perturbations {self.num_known_perturbations}")
    if len(latent_tensor) != len(similarity_tensor):
        raise ValueError(
            f"latent_data has {len(latent_tensor)} rows but similarity_matrix has {len(similarity_tensor)}")

    # Warn when scores fall outside the recommended [0, 1] range (the old check
    # used 2 as the upper bound, contradicting its own message). A small tolerance
    # absorbs float round-off from cosine similarities.
    if similarity_tensor.numel() > 0 and (
            similarity_tensor.min() < 0 or similarity_tensor.max() > 1 + 1e-6):
        warnings.warn("⚠️ Similarity scores outside typical range [0, 1]. Consider normalizing.")

    if len(latent_tensor) == 0:
        return np.empty((0, self.gene_dim), dtype=np.float32)

    self.model.eval()

    predictions = []
    with torch.no_grad():
        for start in range(0, len(latent_tensor), batch_size):
            stop = start + batch_size
            batch_latent = latent_tensor[start:stop].to(self.device)
            batch_similarity = similarity_tensor[start:stop].to(self.device)

            # Use similarity mode for novel perturbations
            batch_pred = self.model(batch_latent, batch_similarity, mode='similarity')
            predictions.append(batch_pred.cpu())

    return torch.cat(predictions).numpy()
|
|
551
|
+
|
|
552
|
+
def get_known_perturbation_prototypes(self) -> np.ndarray:
    """Return the learned response prototypes of the known perturbations.

    The NumPy copy is cached on ``self.perturbation_prototypes``. When nothing
    is cached yet, it is pulled from the network in eval mode under ``no_grad``.
    Emits a warning if the model has not been trained.
    """
    if not self.is_trained:
        warnings.warn("⚠️ Model not trained. Prototypes may be uninformative.")

    if self.perturbation_prototypes is not None:
        return self.perturbation_prototypes

    self.model.eval()
    with torch.no_grad():
        fresh = self.model.get_perturbation_prototypes().cpu().numpy()
    self.perturbation_prototypes = fresh
    return self.perturbation_prototypes
|
|
563
|
+
|
|
564
|
+
def compute_similarity(self, novel_perturbation_features: np.ndarray,
                       known_perturbation_features: Optional[np.ndarray] = None) -> np.ndarray:
    """
    Compute the cosine-similarity matrix between novel and known perturbations.

    Args:
        novel_perturbation_features: [n_novel, feature_dim] features of novel
            perturbations; a single 1-D feature vector is treated as one row.
        known_perturbation_features: [n_known, feature_dim] features of known
            perturbations. If None, the learned perturbation prototypes are used,
            so the novel features must then have ``gene_dim`` columns.

    Returns:
        similarity_matrix: [n_novel, n_known] cosine similarity scores in [-1, 1].

    Raises:
        ValueError: if the two feature matrices have different widths.
    """
    if known_perturbation_features is None:
        # Use learned prototypes
        known_perturbation_features = self.get_known_perturbation_prototypes()

    # atleast_2d lets callers pass a single perturbation's feature vector directly.
    novel = np.atleast_2d(np.asarray(novel_perturbation_features))
    known = np.atleast_2d(np.asarray(known_perturbation_features))
    if novel.shape[1] != known.shape[1]:
        raise ValueError(
            f"Novel features have {novel.shape[1]} columns but known features have {known.shape[1]}")

    # Row-normalise (epsilon guards zero vectors), then a dot product gives cosine similarity
    novel_norm = novel / (np.linalg.norm(novel, axis=1, keepdims=True) + 1e-8)
    known_norm = known / (np.linalg.norm(known, axis=1, keepdims=True) + 1e-8)

    return np.dot(novel_norm, known_norm.T)
|
|
589
|
+
|
|
590
|
+
def _pearson_correlation(self, pred, target):
|
|
591
|
+
"""Calculate Pearson correlation coefficient"""
|
|
592
|
+
pred_centered = pred - pred.mean(dim=1, keepdim=True)
|
|
593
|
+
target_centered = target - target.mean(dim=1, keepdim=True)
|
|
594
|
+
|
|
595
|
+
numerator = (pred_centered * target_centered).sum(dim=1)
|
|
596
|
+
denominator = torch.sqrt(torch.sum(pred_centered ** 2, dim=1)) * torch.sqrt(torch.sum(target_centered ** 2, dim=1))
|
|
597
|
+
|
|
598
|
+
return (numerator / (denominator + 1e-8)).mean()
|
|
599
|
+
|
|
600
|
+
def _train_epoch(self, train_loader, optimizer, loss_fn):
|
|
601
|
+
"""Train one epoch"""
|
|
602
|
+
self.model.train()
|
|
603
|
+
total_loss = 0
|
|
604
|
+
total_mse = 0
|
|
605
|
+
total_correlation = 0
|
|
606
|
+
|
|
607
|
+
for latent, perturbations, target in train_loader:
|
|
608
|
+
latent = latent.to(self.device)
|
|
609
|
+
perturbations = perturbations.to(self.device)
|
|
610
|
+
target = target.to(self.device)
|
|
611
|
+
|
|
612
|
+
optimizer.zero_grad()
|
|
613
|
+
pred = self.model(latent, perturbations, mode='one_hot')
|
|
614
|
+
|
|
615
|
+
loss = loss_fn(pred, target)
|
|
616
|
+
loss.backward()
|
|
617
|
+
|
|
618
|
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
|
619
|
+
optimizer.step()
|
|
620
|
+
|
|
621
|
+
mse_loss = F.mse_loss(pred, target).item()
|
|
622
|
+
correlation = self._pearson_correlation(pred, target).item()
|
|
623
|
+
|
|
624
|
+
total_loss += loss.item()
|
|
625
|
+
total_mse += mse_loss
|
|
626
|
+
total_correlation += correlation
|
|
627
|
+
|
|
628
|
+
num_batches = len(train_loader)
|
|
629
|
+
return {
|
|
630
|
+
'loss': total_loss / num_batches,
|
|
631
|
+
'mse': total_mse / num_batches,
|
|
632
|
+
'correlation': total_correlation / num_batches
|
|
633
|
+
}
|
|
634
|
+
|
|
635
|
+
def _validate_epoch(self, val_loader, loss_fn):
|
|
636
|
+
"""Validate one epoch"""
|
|
637
|
+
self.model.eval()
|
|
638
|
+
total_loss = 0
|
|
639
|
+
total_mse = 0
|
|
640
|
+
total_correlation = 0
|
|
641
|
+
|
|
642
|
+
with torch.no_grad():
|
|
643
|
+
for latent, perturbations, target in val_loader:
|
|
644
|
+
latent = latent.to(self.device)
|
|
645
|
+
perturbations = perturbations.to(self.device)
|
|
646
|
+
target = target.to(self.device)
|
|
647
|
+
|
|
648
|
+
pred = self.model(latent, perturbations, mode='one_hot')
|
|
649
|
+
loss = loss_fn(pred, target)
|
|
650
|
+
mse_loss = F.mse_loss(pred, target).item()
|
|
651
|
+
correlation = self._pearson_correlation(pred, target).item()
|
|
652
|
+
|
|
653
|
+
total_loss += loss.item()
|
|
654
|
+
total_mse += mse_loss
|
|
655
|
+
total_correlation += correlation
|
|
656
|
+
|
|
657
|
+
num_batches = len(val_loader)
|
|
658
|
+
return {
|
|
659
|
+
'loss': total_loss / num_batches,
|
|
660
|
+
'mse': total_mse / num_batches,
|
|
661
|
+
'correlation': total_correlation / num_batches
|
|
662
|
+
}
|
|
663
|
+
|
|
664
|
+
def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
|
|
665
|
+
"""Save model checkpoint"""
|
|
666
|
+
torch.save({
|
|
667
|
+
'epoch': epoch,
|
|
668
|
+
'model_state_dict': self.model.state_dict(),
|
|
669
|
+
'optimizer_state_dict': optimizer.state_dict(),
|
|
670
|
+
'scheduler_state_dict': scheduler.state_dict(),
|
|
671
|
+
'best_val_loss': best_loss,
|
|
672
|
+
'training_history': history,
|
|
673
|
+
'perturbation_prototypes': self.perturbation_prototypes,
|
|
674
|
+
'model_config': {
|
|
675
|
+
'latent_dim': self.latent_dim,
|
|
676
|
+
'num_known_perturbations': self.num_known_perturbations,
|
|
677
|
+
'gene_dim': self.gene_dim,
|
|
678
|
+
'hidden_dims': self.hidden_dims
|
|
679
|
+
}
|
|
680
|
+
}, path)
|
|
681
|
+
|
|
682
|
+
def load_model(self, model_path: str):
    """Restore model weights and training bookkeeping from a checkpoint file.

    The file is read with ``torch.load`` (mapped onto ``self.device``). Its
    ``'model_state_dict'`` entry is loaded into ``self.model``. Then
    ``perturbation_prototypes``, ``training_history`` and ``best_val_loss``
    (defaulting to ``inf``) are copied onto the instance, the decoder is
    marked as trained, and a confirmation line is printed.
    """
    # NOTE(review): torch.load unpickles the file contents; only load
    # checkpoints that come from a trusted source.
    ckpt = torch.load(model_path, map_location=self.device)
    self.model.load_state_dict(ckpt['model_state_dict'])

    # Insertion order mirrors the order in which attributes are assigned.
    restored = {
        'perturbation_prototypes': ckpt.get('perturbation_prototypes'),
        'is_trained': True,
        'training_history': ckpt.get('training_history'),
        'best_val_loss': ckpt.get('best_val_loss', float('inf')),
    }
    for attr_name, value in restored.items():
        setattr(self, attr_name, value)

    print(f"✅ Model loaded! Best val loss: {self.best_val_loss:.4f}")
|
|
691
|
+
|
|
692
|
+
# NOTE: The triple-quoted string below is an inert, disabled usage example.
# It is a bare module-level string expression. Nothing inside it, including
# the ``example_usage`` definition and the ``if __name__ == "__main__"``
# guard, is defined or executed when this module is imported or run. To try
# the example, copy it into a separate script.
'''# Example usage
def example_usage():
    """Example demonstration of the corrected perturbation decoder"""

    # Initialize decoder
    decoder = PerturbationAwareDecoder(
        latent_dim=100,
        num_known_perturbations=10, # 10 known perturbation types
        gene_dim=2000, # Reduced for example
        hidden_dims=[256, 512, 1024],
        perturbation_embedding_dim=128
    )

    # Generate example data
    n_samples = 1000
    n_perturbations = 10

    # Latent variables
    latent_data = np.random.randn(n_samples, 100).astype(np.float32)

    # One-hot encoded perturbations
    perturbations = np.zeros((n_samples, n_perturbations))
    for i in range(n_samples):
        if i % 10 != 0: # 90% perturbed, 10% control
            perturb_id = np.random.randint(0, n_perturbations)
            perturbations[i, perturb_id] = 1.0

    # Expression data with perturbation effects
    base_weights = np.random.randn(100, 2000) * 0.1
    perturbation_effects = np.random.randn(n_perturbations, 2000) * 0.5

    expression_data = np.tanh(latent_data.dot(base_weights))
    for i in range(n_samples):
        if perturbations[i].sum() > 0: # Perturbed sample
            perturb_id = np.argmax(perturbations[i])
            expression_data[i] += perturbation_effects[perturb_id]

    expression_data = np.maximum(expression_data, 0)

    print(f"📊 Data shapes: Latent {latent_data.shape}, Perturbations {perturbations.shape}, Expression {expression_data.shape}")
    print(f"🧪 Control samples: {(perturbations.sum(axis=1) == 0).sum()}")
    print(f"🧪 Perturbed samples: {(perturbations.sum(axis=1) > 0).sum()}")

    # Train with one-hot perturbations
    history = decoder.train(
        train_latent=latent_data,
        train_perturbations=perturbations,
        train_expression=expression_data,
        batch_size=32,
        num_epochs=50
    )

    # Test predictions with one-hot perturbations
    test_latent = np.random.randn(10, 100).astype(np.float32)
    test_perturbations = np.zeros((10, n_perturbations))
    for i in range(10):
        test_perturbations[i, i % n_perturbations] = 1.0 # One-hot encoding

    predictions = decoder.predict(test_latent, test_perturbations)
    print(f"🔮 Known perturbation prediction shape: {predictions.shape}")

    # Test novel perturbation prediction with similarity matrix
    test_latent_novel = np.random.randn(5, 100).astype(np.float32)

    # Create similarity matrix: [5, 10] - 5 novel perturbations, similarity to 10 known perturbations
    similarity_matrix = np.random.rand(5, n_perturbations)
    similarity_matrix = similarity_matrix / similarity_matrix.sum(axis=1, keepdims=True) # Normalize rows

    novel_predictions = decoder.predict_novel_perturbation(test_latent_novel, similarity_matrix)
    print(f"🔮 Novel perturbation prediction shape: {novel_predictions.shape}")

    # Get learned perturbation prototypes
    prototypes = decoder.get_known_perturbation_prototypes()
    print(f"📊 Perturbation prototypes shape: {prototypes.shape}")

    return decoder

if __name__ == "__main__":
    example_usage()'''
|
|
@@ -56,7 +56,6 @@ class VirtualCellDecoder:
|
|
|
56
56
|
print(f" - Biological Prior Dimension: {biological_prior_dim}")
|
|
57
57
|
print(f" - Device: {self.device}")
|
|
58
58
|
print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
|
|
59
|
-
print(f" - Estimated GPU Memory: ~8-10GB (optimized for 20GB)")
|
|
60
59
|
|
|
61
60
|
class BiologicalPriorNetwork(nn.Module):
|
|
62
61
|
"""Biological prior network based on gene regulatory knowledge"""
|
|
@@ -5,6 +5,7 @@ from .TranscriptomeDecoder import TranscriptomeDecoder
|
|
|
5
5
|
from .SimpleTranscriptomeDecoder import SimpleTranscriptomeDecoder
|
|
6
6
|
from .EfficientTranscriptomeDecoder import EfficientTranscriptomeDecoder
|
|
7
7
|
from .VirtualCellDecoder import VirtualCellDecoder
|
|
8
|
+
from .PerturbationAwareDecoder import PerturbationAwareDecoder
|
|
8
9
|
|
|
9
10
|
from . import utils
|
|
10
11
|
from . import codebook
|
|
@@ -18,6 +19,8 @@ from . import TranscriptomeDecoder
|
|
|
18
19
|
from . import SimpleTranscriptomeDecoder
|
|
19
20
|
from . import EfficientTranscriptomeDecoder
|
|
20
21
|
from . import VirtualCellDecoder
|
|
22
|
+
from . import PerturbationAwareDecoder
|
|
21
23
|
|
|
22
24
|
__all__ = ['SURE', 'DensityFlow', 'PerturbE', 'TranscriptomeDecoder', 'SimpleTranscriptomeDecoder',
|
|
23
|
-
'EfficientTranscriptomeDecoder', 'VirtualCellDecoder', '
|
|
25
|
+
'EfficientTranscriptomeDecoder', 'VirtualCellDecoder', 'PerturbationAwareDecoder',
|
|
26
|
+
'flow', 'perturb', 'atac', 'utils', 'codebook']
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|