SURE-tools 2.4.22__py3-none-any.whl → 2.4.43__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. See the registry's advisory for this release for more details.
- SURE/DensityFlow.py +151 -69
- SURE/DensityFlow2.py +1422 -0
- SURE/DensityFlowLinear.py +1414 -0
- SURE/PerturbationAwareDecoder.py +162 -148
- SURE/__init__.py +3 -1
- {sure_tools-2.4.22.dist-info → sure_tools-2.4.43.dist-info}/METADATA +1 -1
- {sure_tools-2.4.22.dist-info → sure_tools-2.4.43.dist-info}/RECORD +11 -9
- {sure_tools-2.4.22.dist-info → sure_tools-2.4.43.dist-info}/WHEEL +0 -0
- {sure_tools-2.4.22.dist-info → sure_tools-2.4.43.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.4.22.dist-info → sure_tools-2.4.43.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.4.22.dist-info → sure_tools-2.4.43.dist-info}/top_level.txt +0 -0
SURE/PerturbationAwareDecoder.py
CHANGED
|
@@ -12,30 +12,20 @@ warnings.filterwarnings('ignore')
|
|
|
12
12
|
class PerturbationAwareDecoder:
|
|
13
13
|
"""
|
|
14
14
|
Advanced transcriptome decoder with perturbation awareness
|
|
15
|
-
|
|
15
|
+
Fixed version with proper handling of single hidden layer configurations
|
|
16
16
|
"""
|
|
17
17
|
|
|
18
18
|
def __init__(self,
|
|
19
19
|
latent_dim: int = 100,
|
|
20
20
|
num_known_perturbations: int = 50,
|
|
21
21
|
gene_dim: int = 60000,
|
|
22
|
-
hidden_dims: List[int] = [512],
|
|
22
|
+
hidden_dims: List[int] = [512],
|
|
23
23
|
perturbation_embedding_dim: int = 128,
|
|
24
24
|
biological_prior_dim: int = 256,
|
|
25
25
|
dropout_rate: float = 0.1,
|
|
26
26
|
device: str = None):
|
|
27
27
|
"""
|
|
28
|
-
Multi-modal decoder
|
|
29
|
-
|
|
30
|
-
Args:
|
|
31
|
-
latent_dim: Latent variable dimension
|
|
32
|
-
num_known_perturbations: Number of known perturbation types
|
|
33
|
-
gene_dim: Number of genes
|
|
34
|
-
hidden_dims: Hidden layer dimensions (can be single element list)
|
|
35
|
-
perturbation_embedding_dim: Embedding dimension for perturbations
|
|
36
|
-
biological_prior_dim: Dimension for biological prior knowledge
|
|
37
|
-
dropout_rate: Dropout rate
|
|
38
|
-
device: Computation device
|
|
28
|
+
Multi-modal decoder with fixed single layer support
|
|
39
29
|
"""
|
|
40
30
|
self.latent_dim = latent_dim
|
|
41
31
|
self.num_known_perturbations = num_known_perturbations
|
|
@@ -50,7 +40,7 @@ class PerturbationAwareDecoder:
|
|
|
50
40
|
self._validate_hidden_dims()
|
|
51
41
|
|
|
52
42
|
# Initialize multi-modal model
|
|
53
|
-
self.model = self.
|
|
43
|
+
self.model = self._build_fixed_model()
|
|
54
44
|
self.model.to(self.device)
|
|
55
45
|
|
|
56
46
|
# Training state
|
|
@@ -77,8 +67,8 @@ class PerturbationAwareDecoder:
|
|
|
77
67
|
else:
|
|
78
68
|
print(f"🔧 Multi-layer configuration: {len(self.hidden_dims)} hidden layers")
|
|
79
69
|
|
|
80
|
-
class
|
|
81
|
-
"""
|
|
70
|
+
class FixedPerturbationEncoder(nn.Module):
|
|
71
|
+
"""Fixed perturbation encoder"""
|
|
82
72
|
|
|
83
73
|
def __init__(self, num_perturbations: int, embedding_dim: int, hidden_dim: int):
|
|
84
74
|
super().__init__()
|
|
@@ -87,7 +77,7 @@ class PerturbationAwareDecoder:
|
|
|
87
77
|
# Embedding for perturbation types
|
|
88
78
|
self.perturbation_embedding = nn.Embedding(num_perturbations, embedding_dim)
|
|
89
79
|
|
|
90
|
-
#
|
|
80
|
+
# Projection to hidden space
|
|
91
81
|
self.projection = nn.Sequential(
|
|
92
82
|
nn.Linear(embedding_dim, hidden_dim),
|
|
93
83
|
nn.ReLU(),
|
|
@@ -95,28 +85,26 @@ class PerturbationAwareDecoder:
|
|
|
95
85
|
)
|
|
96
86
|
|
|
97
87
|
def forward(self, one_hot_perturbations):
|
|
98
|
-
batch_size = one_hot_perturbations.shape[0]
|
|
99
|
-
|
|
100
88
|
# Convert one-hot to indices
|
|
101
89
|
perturbation_indices = torch.argmax(one_hot_perturbations, dim=1)
|
|
102
90
|
|
|
103
91
|
# Get perturbation embeddings
|
|
104
92
|
perturbation_embeds = self.perturbation_embedding(perturbation_indices)
|
|
105
93
|
|
|
106
|
-
#
|
|
94
|
+
# Project to hidden space
|
|
107
95
|
hidden_repr = self.projection(perturbation_embeds)
|
|
108
96
|
|
|
109
97
|
return hidden_repr
|
|
110
98
|
|
|
111
|
-
class
|
|
112
|
-
"""
|
|
99
|
+
class FixedCrossModalFusion(nn.Module):
|
|
100
|
+
"""Fixed cross-modal fusion"""
|
|
113
101
|
|
|
114
102
|
def __init__(self, latent_dim: int, perturbation_dim: int, fusion_dim: int):
|
|
115
103
|
super().__init__()
|
|
116
104
|
self.latent_projection = nn.Linear(latent_dim, fusion_dim)
|
|
117
105
|
self.perturbation_projection = nn.Linear(perturbation_dim, fusion_dim)
|
|
118
106
|
|
|
119
|
-
#
|
|
107
|
+
# Fusion gate
|
|
120
108
|
self.fusion_gate = nn.Sequential(
|
|
121
109
|
nn.Linear(fusion_dim * 2, fusion_dim),
|
|
122
110
|
nn.Sigmoid()
|
|
@@ -130,7 +118,7 @@ class PerturbationAwareDecoder:
|
|
|
130
118
|
latent_proj = self.latent_projection(latent)
|
|
131
119
|
perturbation_proj = self.perturbation_projection(perturbation_encoded)
|
|
132
120
|
|
|
133
|
-
#
|
|
121
|
+
# Gated fusion
|
|
134
122
|
concatenated = torch.cat([latent_proj, perturbation_proj], dim=-1)
|
|
135
123
|
fusion_gate = self.fusion_gate(concatenated)
|
|
136
124
|
|
|
@@ -141,43 +129,32 @@ class PerturbationAwareDecoder:
|
|
|
141
129
|
|
|
142
130
|
return fused
|
|
143
131
|
|
|
144
|
-
class
|
|
145
|
-
"""
|
|
132
|
+
class FixedPerturbationResponseNetwork(nn.Module):
|
|
133
|
+
"""Fixed response network with proper single layer handling"""
|
|
146
134
|
|
|
147
135
|
def __init__(self, fusion_dim: int, gene_dim: int, hidden_dims: List[int]):
|
|
148
136
|
super().__init__()
|
|
149
137
|
|
|
150
|
-
#
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
138
|
+
# Build network layers
|
|
139
|
+
layers = []
|
|
140
|
+
input_dim = fusion_dim
|
|
141
|
+
|
|
142
|
+
# Handle both single and multi-layer cases
|
|
143
|
+
for i, hidden_dim in enumerate(hidden_dims):
|
|
144
|
+
layers.extend([
|
|
145
|
+
nn.Linear(input_dim, hidden_dim),
|
|
146
|
+
nn.BatchNorm1d(hidden_dim),
|
|
156
147
|
nn.ReLU(),
|
|
157
148
|
nn.Dropout(0.1)
|
|
158
|
-
)
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
layers = []
|
|
163
|
-
input_dim = fusion_dim
|
|
164
|
-
|
|
165
|
-
for hidden_dim in hidden_dims:
|
|
166
|
-
layers.extend([
|
|
167
|
-
nn.Linear(input_dim, hidden_dim),
|
|
168
|
-
nn.BatchNorm1d(hidden_dim),
|
|
169
|
-
nn.ReLU(),
|
|
170
|
-
nn.Dropout(0.1)
|
|
171
|
-
])
|
|
172
|
-
input_dim = hidden_dim
|
|
173
|
-
|
|
174
|
-
self.base_network = nn.Sequential(*layers)
|
|
175
|
-
final_input_dim = hidden_dims[-1]
|
|
149
|
+
])
|
|
150
|
+
input_dim = hidden_dim
|
|
151
|
+
|
|
152
|
+
self.base_network = nn.Sequential(*layers)
|
|
176
153
|
|
|
177
|
-
# Final projection
|
|
178
|
-
self.final_projection = nn.Linear(
|
|
154
|
+
# Final projection - FIXED: Use current input_dim instead of hidden_dims[-1]
|
|
155
|
+
self.final_projection = nn.Linear(input_dim, gene_dim)
|
|
179
156
|
|
|
180
|
-
# Perturbation-aware scaling
|
|
157
|
+
# Perturbation-aware scaling
|
|
181
158
|
self.scale = nn.Linear(fusion_dim, 1)
|
|
182
159
|
self.bias = nn.Linear(fusion_dim, 1)
|
|
183
160
|
|
|
@@ -191,8 +168,8 @@ class PerturbationAwareDecoder:
|
|
|
191
168
|
|
|
192
169
|
return F.softplus(expression * scale + bias)
|
|
193
170
|
|
|
194
|
-
class
|
|
195
|
-
"""
|
|
171
|
+
class FixedNovelPerturbationPredictor(nn.Module):
|
|
172
|
+
"""Fixed novel perturbation predictor"""
|
|
196
173
|
|
|
197
174
|
def __init__(self, num_known_perturbations: int, gene_dim: int, hidden_dim: int):
|
|
198
175
|
super().__init__()
|
|
@@ -204,7 +181,7 @@ class PerturbationAwareDecoder:
|
|
|
204
181
|
torch.randn(num_known_perturbations, gene_dim) * 0.1
|
|
205
182
|
)
|
|
206
183
|
|
|
207
|
-
#
|
|
184
|
+
# Response generator - handle case where hidden_dim might be 0
|
|
208
185
|
if hidden_dim > 0:
|
|
209
186
|
self.response_generator = nn.Sequential(
|
|
210
187
|
nn.Linear(num_known_perturbations, hidden_dim),
|
|
@@ -215,7 +192,7 @@ class PerturbationAwareDecoder:
|
|
|
215
192
|
# Direct projection if no hidden layer
|
|
216
193
|
self.response_generator = nn.Linear(num_known_perturbations, gene_dim)
|
|
217
194
|
|
|
218
|
-
#
|
|
195
|
+
# Attention mechanism
|
|
219
196
|
self.similarity_attention = nn.Sequential(
|
|
220
197
|
nn.Linear(num_known_perturbations, num_known_perturbations),
|
|
221
198
|
nn.Softmax(dim=-1)
|
|
@@ -238,8 +215,8 @@ class PerturbationAwareDecoder:
|
|
|
238
215
|
|
|
239
216
|
return final_response
|
|
240
217
|
|
|
241
|
-
class
|
|
242
|
-
"""Main decoder
|
|
218
|
+
class FixedMultimodalDecoder(nn.Module):
|
|
219
|
+
"""Main decoder with fixed single layer handling"""
|
|
243
220
|
|
|
244
221
|
def __init__(self, latent_dim: int, num_known_perturbations: int, gene_dim: int,
|
|
245
222
|
hidden_dims: List[int], perturbation_embedding_dim: int,
|
|
@@ -250,33 +227,26 @@ class PerturbationAwareDecoder:
|
|
|
250
227
|
self.latent_dim = latent_dim
|
|
251
228
|
self.gene_dim = gene_dim
|
|
252
229
|
|
|
253
|
-
#
|
|
254
|
-
|
|
255
|
-
# Use the single dimension for all components
|
|
256
|
-
main_hidden_dim = hidden_dims[0]
|
|
257
|
-
response_hidden_dims = [] # No additional hidden layers for response
|
|
258
|
-
else:
|
|
259
|
-
# Multiple layers: first for fusion, rest for response
|
|
260
|
-
main_hidden_dim = hidden_dims[0]
|
|
261
|
-
response_hidden_dims = hidden_dims[1:]
|
|
230
|
+
# Use first hidden dimension for fusion
|
|
231
|
+
main_hidden_dim = hidden_dims[0]
|
|
262
232
|
|
|
263
233
|
# Perturbation encoder
|
|
264
|
-
self.perturbation_encoder = PerturbationAwareDecoder.
|
|
234
|
+
self.perturbation_encoder = PerturbationAwareDecoder.FixedPerturbationEncoder(
|
|
265
235
|
num_known_perturbations, perturbation_embedding_dim, main_hidden_dim
|
|
266
236
|
)
|
|
267
237
|
|
|
268
238
|
# Cross-modal fusion
|
|
269
|
-
self.cross_modal_fusion = PerturbationAwareDecoder.
|
|
239
|
+
self.cross_modal_fusion = PerturbationAwareDecoder.FixedCrossModalFusion(
|
|
270
240
|
latent_dim, main_hidden_dim, main_hidden_dim
|
|
271
241
|
)
|
|
272
242
|
|
|
273
|
-
# Response network
|
|
274
|
-
self.response_network = PerturbationAwareDecoder.
|
|
275
|
-
main_hidden_dim, gene_dim,
|
|
243
|
+
# Response network - FIXED: Use all hidden_dims for response network
|
|
244
|
+
self.response_network = PerturbationAwareDecoder.FixedPerturbationResponseNetwork(
|
|
245
|
+
main_hidden_dim, gene_dim, hidden_dims # Pass all hidden_dims
|
|
276
246
|
)
|
|
277
247
|
|
|
278
248
|
# Novel perturbation predictor
|
|
279
|
-
self.novel_predictor = PerturbationAwareDecoder.
|
|
249
|
+
self.novel_predictor = PerturbationAwareDecoder.FixedNovelPerturbationPredictor(
|
|
280
250
|
num_known_perturbations, gene_dim, main_hidden_dim
|
|
281
251
|
)
|
|
282
252
|
|
|
@@ -300,9 +270,9 @@ class PerturbationAwareDecoder:
|
|
|
300
270
|
"""Get learned perturbation response prototypes"""
|
|
301
271
|
return self.novel_predictor.perturbation_prototypes.detach()
|
|
302
272
|
|
|
303
|
-
def
|
|
304
|
-
"""Build
|
|
305
|
-
return self.
|
|
273
|
+
def _build_fixed_model(self):
|
|
274
|
+
"""Build the fixed model"""
|
|
275
|
+
return self.FixedMultimodalDecoder(
|
|
306
276
|
self.latent_dim, self.num_known_perturbations, self.gene_dim,
|
|
307
277
|
self.hidden_dims, self.perturbation_embedding_dim,
|
|
308
278
|
self.biological_prior_dim, self.dropout_rate
|
|
@@ -318,11 +288,11 @@ class PerturbationAwareDecoder:
|
|
|
318
288
|
batch_size: int = 32,
|
|
319
289
|
num_epochs: int = 200,
|
|
320
290
|
learning_rate: float = 1e-4,
|
|
321
|
-
checkpoint_path: str = '
|
|
291
|
+
checkpoint_path: str = 'fixed_decoder.pth') -> Dict:
|
|
322
292
|
"""
|
|
323
|
-
Train the decoder
|
|
293
|
+
Train the fixed decoder
|
|
324
294
|
"""
|
|
325
|
-
print("🧬 Starting Training with Single Layer Support...")
|
|
295
|
+
print("🧬 Starting Training with Fixed Single Layer Support...")
|
|
326
296
|
|
|
327
297
|
# Validate one-hot encoding
|
|
328
298
|
self._validate_one_hot_perturbations(train_perturbations)
|
|
@@ -626,98 +596,142 @@ class PerturbationAwareDecoder:
|
|
|
626
596
|
self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
|
|
627
597
|
print(f"✅ Model loaded! Best val loss: {self.best_val_loss:.4f}")
|
|
628
598
|
|
|
629
|
-
'''#
|
|
630
|
-
def
|
|
631
|
-
"""
|
|
599
|
+
'''# Test the fixed implementation
|
|
600
|
+
def test_single_layer_fix():
|
|
601
|
+
"""Test the fixed single layer implementation"""
|
|
632
602
|
|
|
633
|
-
|
|
634
|
-
|
|
603
|
+
print("🧪 Testing single layer configuration...")
|
|
604
|
+
|
|
605
|
+
# Test with single hidden layer
|
|
606
|
+
decoder_single = PerturbationAwareDecoder(
|
|
635
607
|
latent_dim=100,
|
|
636
608
|
num_known_perturbations=10,
|
|
637
|
-
gene_dim=2000,
|
|
638
|
-
hidden_dims=[512], # Single
|
|
609
|
+
gene_dim=2000,
|
|
610
|
+
hidden_dims=[512], # Single element list
|
|
639
611
|
perturbation_embedding_dim=128
|
|
640
612
|
)
|
|
641
613
|
|
|
642
|
-
# Generate
|
|
643
|
-
n_samples =
|
|
644
|
-
n_perturbations = 10
|
|
645
|
-
|
|
646
|
-
# Latent variables
|
|
614
|
+
# Generate test data
|
|
615
|
+
n_samples = 100
|
|
647
616
|
latent_data = np.random.randn(n_samples, 100).astype(np.float32)
|
|
648
|
-
|
|
649
|
-
# One-hot encoded perturbations
|
|
650
|
-
perturbations = np.zeros((n_samples, n_perturbations))
|
|
651
|
-
for i in range(n_samples):
|
|
652
|
-
if i % 10 != 0: # 90% perturbed, 10% control
|
|
653
|
-
perturb_id = np.random.randint(0, n_perturbations)
|
|
654
|
-
perturbations[i, perturb_id] = 1.0
|
|
655
|
-
|
|
656
|
-
# Expression data
|
|
657
|
-
base_weights = np.random.randn(100, 2000) * 0.1
|
|
658
|
-
perturbation_effects = np.random.randn(n_perturbations, 2000) * 0.5
|
|
659
|
-
|
|
660
|
-
expression_data = np.tanh(latent_data.dot(base_weights))
|
|
617
|
+
perturbations = np.zeros((n_samples, 10))
|
|
661
618
|
for i in range(n_samples):
|
|
662
|
-
if
|
|
663
|
-
|
|
664
|
-
expression_data[i] += perturbation_effects[perturb_id]
|
|
619
|
+
if i % 10 != 0:
|
|
620
|
+
perturbations[i, np.random.randint(0, 10)] = 1.0
|
|
665
621
|
|
|
622
|
+
expression_data = np.random.randn(n_samples, 2000).astype(np.float32)
|
|
666
623
|
expression_data = np.maximum(expression_data, 0)
|
|
667
624
|
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
625
|
+
# Test forward pass
|
|
626
|
+
decoder_single.model.eval()
|
|
627
|
+
with torch.no_grad():
|
|
628
|
+
latent_tensor = torch.FloatTensor(latent_data[:5]).to(decoder_single.device)
|
|
629
|
+
perturbations_tensor = torch.FloatTensor(perturbations[:5]).to(decoder_single.device)
|
|
630
|
+
|
|
631
|
+
# Test known perturbation prediction
|
|
632
|
+
output = decoder_single.model(latent_tensor, perturbations_tensor, mode='one_hot')
|
|
633
|
+
print(f"✅ Known perturbation prediction shape: {output.shape}")
|
|
634
|
+
|
|
635
|
+
# Test novel perturbation prediction
|
|
636
|
+
similarity_matrix = np.random.rand(5, 10).astype(np.float32)
|
|
637
|
+
similarity_tensor = torch.FloatTensor(similarity_matrix).to(decoder_single.device)
|
|
638
|
+
novel_output = decoder_single.model(latent_tensor, similarity_tensor, mode='similarity')
|
|
639
|
+
print(f"✅ Novel perturbation prediction shape: {novel_output.shape}")
|
|
680
640
|
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
test_perturbations[i, i % n_perturbations] = 1.0
|
|
641
|
+
print("🎉 Single layer test passed!")
|
|
642
|
+
|
|
643
|
+
def test_multi_layer_fix():
|
|
644
|
+
"""Test the multi-layer implementation"""
|
|
686
645
|
|
|
687
|
-
|
|
688
|
-
print(f"🔮 Known perturbation prediction shape: {predictions.shape}")
|
|
646
|
+
print("\n🧪 Testing multi-layer configuration...")
|
|
689
647
|
|
|
690
|
-
# Test
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
648
|
+
# Test with multiple hidden layers
|
|
649
|
+
decoder_multi = PerturbationAwareDecoder(
|
|
650
|
+
latent_dim=100,
|
|
651
|
+
num_known_perturbations=10,
|
|
652
|
+
gene_dim=2000,
|
|
653
|
+
hidden_dims=[256, 512, 1024], # Multiple layers
|
|
654
|
+
perturbation_embedding_dim=128
|
|
655
|
+
)
|
|
694
656
|
|
|
695
|
-
|
|
696
|
-
|
|
657
|
+
print("🎉 Multi-layer test passed!")
|
|
658
|
+
|
|
659
|
+
def test_edge_cases():
|
|
660
|
+
"""Test edge cases"""
|
|
661
|
+
|
|
662
|
+
print("\n🧪 Testing edge cases...")
|
|
663
|
+
|
|
664
|
+
# Test with different hidden_dims configurations
|
|
665
|
+
configs = [
|
|
666
|
+
[512], # Single layer
|
|
667
|
+
[256, 512], # Two layers
|
|
668
|
+
[128, 256, 512], # Three layers
|
|
669
|
+
[1024], # Wide single layer
|
|
670
|
+
[64, 128, 256, 512, 1024] # Deep network
|
|
671
|
+
]
|
|
672
|
+
|
|
673
|
+
for i, hidden_dims in enumerate(configs):
|
|
674
|
+
try:
|
|
675
|
+
decoder = PerturbationAwareDecoder(
|
|
676
|
+
latent_dim=50,
|
|
677
|
+
num_known_perturbations=5,
|
|
678
|
+
gene_dim=1000,
|
|
679
|
+
hidden_dims=hidden_dims,
|
|
680
|
+
perturbation_embedding_dim=64
|
|
681
|
+
)
|
|
682
|
+
print(f"✅ Config {i+1}: {hidden_dims} - Success")
|
|
683
|
+
except Exception as e:
|
|
684
|
+
print(f"❌ Config {i+1}: {hidden_dims} - Failed: {e}")
|
|
697
685
|
|
|
698
|
-
|
|
686
|
+
print("🎉 Edge case testing completed!")
|
|
699
687
|
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
688
|
+
if __name__ == "__main__":
|
|
689
|
+
# Run tests
|
|
690
|
+
test_single_layer_fix()
|
|
691
|
+
test_multi_layer_fix()
|
|
692
|
+
test_edge_cases()
|
|
703
693
|
|
|
704
|
-
#
|
|
694
|
+
# Example usage
|
|
695
|
+
print("\n🎯 Example Usage:")
|
|
696
|
+
|
|
697
|
+
# Single hidden layer example
|
|
705
698
|
decoder = PerturbationAwareDecoder(
|
|
706
699
|
latent_dim=100,
|
|
707
700
|
num_known_perturbations=10,
|
|
708
701
|
gene_dim=2000,
|
|
709
|
-
hidden_dims=[
|
|
702
|
+
hidden_dims=[512], # Single hidden layer
|
|
710
703
|
perturbation_embedding_dim=128
|
|
711
704
|
)
|
|
712
705
|
|
|
713
|
-
|
|
714
|
-
|
|
706
|
+
# Generate example data
|
|
707
|
+
n_samples = 1000
|
|
708
|
+
latent_data = np.random.randn(n_samples, 100).astype(np.float32)
|
|
709
|
+
perturbations = np.zeros((n_samples, 10))
|
|
710
|
+
for i in range(n_samples):
|
|
711
|
+
if i % 10 != 0:
|
|
712
|
+
perturbations[i, np.random.randint(0, 10)] = 1.0
|
|
713
|
+
|
|
714
|
+
# Simulate expression data
|
|
715
|
+
base_weights = np.random.randn(100, 2000) * 0.1
|
|
716
|
+
perturbation_effects = np.random.randn(10, 2000) * 0.5
|
|
715
717
|
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
if
|
|
719
|
-
|
|
720
|
-
|
|
718
|
+
expression_data = np.tanh(latent_data.dot(base_weights))
|
|
719
|
+
for i in range(n_samples):
|
|
720
|
+
if perturbations[i].sum() > 0:
|
|
721
|
+
perturb_id = np.argmax(perturbations[i])
|
|
722
|
+
expression_data[i] += perturbation_effects[perturb_id]
|
|
723
|
+
|
|
724
|
+
expression_data = np.maximum(expression_data, 0)
|
|
725
|
+
|
|
726
|
+
print(f"📊 Example data shapes: Latent {latent_data.shape}, Perturbations {perturbations.shape}")
|
|
727
|
+
|
|
728
|
+
# Train (commented out for quick testing)
|
|
729
|
+
# history = decoder.train(
|
|
730
|
+
# train_latent=latent_data,
|
|
731
|
+
# train_perturbations=perturbations,
|
|
732
|
+
# train_expression=expression_data,
|
|
733
|
+
# batch_size=32,
|
|
734
|
+
# num_epochs=10 # Short training for testing
|
|
735
|
+
# )
|
|
721
736
|
|
|
722
|
-
print("
|
|
723
|
-
decoder_multi = example_multi_layer_usage()'''
|
|
737
|
+
print("🎉 All tests completed successfully!")'''
|
SURE/__init__.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from .SURE import SURE
|
|
2
2
|
from .DensityFlow import DensityFlow
|
|
3
|
+
from .DensityFlowLinear import DensityFlowLinear
|
|
3
4
|
from .PerturbE import PerturbE
|
|
4
5
|
from .TranscriptomeDecoder import TranscriptomeDecoder
|
|
5
6
|
from .SimpleTranscriptomeDecoder import SimpleTranscriptomeDecoder
|
|
@@ -11,6 +12,7 @@ from . import utils
|
|
|
11
12
|
from . import codebook
|
|
12
13
|
from . import SURE
|
|
13
14
|
from . import DensityFlow
|
|
15
|
+
from . import DensityFlowLinear
|
|
14
16
|
from . import atac
|
|
15
17
|
from . import flow
|
|
16
18
|
from . import perturb
|
|
@@ -21,6 +23,6 @@ from . import EfficientTranscriptomeDecoder
|
|
|
21
23
|
from . import VirtualCellDecoder
|
|
22
24
|
from . import PerturbationAwareDecoder
|
|
23
25
|
|
|
24
|
-
__all__ = ['SURE', 'DensityFlow', 'PerturbE', 'TranscriptomeDecoder', 'SimpleTranscriptomeDecoder',
|
|
26
|
+
__all__ = ['SURE', 'DensityFlow', 'DensityFlowLinear', 'PerturbE', 'TranscriptomeDecoder', 'SimpleTranscriptomeDecoder',
|
|
25
27
|
'EfficientTranscriptomeDecoder', 'VirtualCellDecoder', 'PerturbationAwareDecoder',
|
|
26
28
|
'flow', 'perturb', 'atac', 'utils', 'codebook']
|
|
@@ -1,12 +1,14 @@
|
|
|
1
|
-
SURE/DensityFlow.py,sha256=
|
|
1
|
+
SURE/DensityFlow.py,sha256=fqqI8sHnfXuTK9O1il-dL7F4W7gbUMjGHD8uRwpESlc,60218
|
|
2
|
+
SURE/DensityFlow2.py,sha256=BBRCoA4NpU4EjghToOvowo17UtwYokTN75KxWYHTX1E,58404
|
|
3
|
+
SURE/DensityFlowLinear.py,sha256=bYiPHJ6mza4sOXUjlFq7wButu3rNLYZuqWUTtIO06F4,57540
|
|
2
4
|
SURE/EfficientTranscriptomeDecoder.py,sha256=O_x-4edKBU5OJJbOOS-59u3TQElZqhAtOVJMPlpw8m0,21667
|
|
3
5
|
SURE/PerturbE.py,sha256=DxEp-qef--x8-GMZdPfBf8ts8UDDc34h2P5AnpqZ-YM,52265
|
|
4
|
-
SURE/PerturbationAwareDecoder.py,sha256=
|
|
6
|
+
SURE/PerturbationAwareDecoder.py,sha256=duhvBvZjOpAk7c2YTfmA2qKbrgVvwT7IW1pxaukq_iU,30231
|
|
5
7
|
SURE/SURE.py,sha256=MXs7iuvcj-lU4dJ_MwKegpL2Rqk2HB4eFfAgHRA3RtA,47744
|
|
6
8
|
SURE/SimpleTranscriptomeDecoder.py,sha256=mLgYYipfRrmuXlpoaxLPfJS009OVwyshdL3nXTJygIE,22285
|
|
7
9
|
SURE/TranscriptomeDecoder.py,sha256=n2tVB8hNVLwSQ1G1Jpd6WzMl2Iw63eK0_Ujk9d48SJY,20982
|
|
8
10
|
SURE/VirtualCellDecoder.py,sha256=z1Z7GRTYmTE3DaSKZueofv138R0J7kGFfnh0a_Lee38,27468
|
|
9
|
-
SURE/__init__.py,sha256=
|
|
11
|
+
SURE/__init__.py,sha256=ayUV9hNysHtZxbRK87nnC6cXs-7wAq6FUkllqotrx8E,1123
|
|
10
12
|
SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
|
|
11
13
|
SURE/assembly/assembly.py,sha256=6IMdelPOiRO4mUb4dC7gVCoF1Uvfw86-Map8P_jnUag,21477
|
|
12
14
|
SURE/assembly/atlas.py,sha256=ALjmVWutm_tOHTcT1aqOxmuCEQw-XzrtDoMCV_8oXLk,21794
|
|
@@ -23,9 +25,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
|
|
|
23
25
|
SURE/utils/custom_mlp.py,sha256=Rn_PQouxPMSda-KKBYrwVVv3GFFuUmCLxp8cV5LszZo,10580
|
|
24
26
|
SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
|
|
25
27
|
SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
|
|
26
|
-
sure_tools-2.4.
|
|
27
|
-
sure_tools-2.4.
|
|
28
|
-
sure_tools-2.4.
|
|
29
|
-
sure_tools-2.4.
|
|
30
|
-
sure_tools-2.4.
|
|
31
|
-
sure_tools-2.4.
|
|
28
|
+
sure_tools-2.4.43.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
|
|
29
|
+
sure_tools-2.4.43.dist-info/METADATA,sha256=q0DTzGgBqj5Hi8n2YNmJymHD25dZSUdRVlMrfiy-5Hw,2678
|
|
30
|
+
sure_tools-2.4.43.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
31
|
+
sure_tools-2.4.43.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
|
|
32
|
+
sure_tools-2.4.43.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
|
|
33
|
+
sure_tools-2.4.43.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|