SURE-tools 2.4.17__tar.gz → 2.4.20__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {sure_tools-2.4.17 → sure_tools-2.4.20}/PKG-INFO +1 -1
  2. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/DensityFlow.py +7 -0
  3. sure_tools-2.4.20/SURE/PerturbationAwareDecoder.py +770 -0 (new module; usage sketch below)
  4. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/VirtualCellDecoder.py +0 -1
  5. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/__init__.py +4 -1
  6. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE_tools.egg-info/PKG-INFO +1 -1
  7. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE_tools.egg-info/SOURCES.txt +1 -0
  8. {sure_tools-2.4.17 → sure_tools-2.4.20}/setup.py +1 -1
  9. {sure_tools-2.4.17 → sure_tools-2.4.20}/LICENSE +0 -0
  10. {sure_tools-2.4.17 → sure_tools-2.4.20}/README.md +0 -0
  11. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/EfficientTranscriptomeDecoder.py +0 -0
  12. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/PerturbE.py +0 -0
  13. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/SURE.py +0 -0
  14. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/SimpleTranscriptomeDecoder.py +0 -0
  15. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/TranscriptomeDecoder.py +0 -0
  16. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/assembly/__init__.py +0 -0
  17. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/assembly/assembly.py +0 -0
  18. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/assembly/atlas.py +0 -0
  19. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/atac/__init__.py +0 -0
  20. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/atac/utils.py +0 -0
  21. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/codebook/__init__.py +0 -0
  22. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/codebook/codebook.py +0 -0
  23. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/flow/__init__.py +0 -0
  24. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/flow/flow_stats.py +0 -0
  25. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/flow/plot_quiver.py +0 -0
  26. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/perturb/__init__.py +0 -0
  27. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/perturb/perturb.py +0 -0
  28. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/utils/__init__.py +0 -0
  29. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/utils/custom_mlp.py +0 -0
  30. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/utils/queue.py +0 -0
  31. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/utils/utils.py +0 -0
  32. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE_tools.egg-info/dependency_links.txt +0 -0
  33. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE_tools.egg-info/entry_points.txt +0 -0
  34. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE_tools.egg-info/requires.txt +0 -0
  35. {sure_tools-2.4.17 → sure_tools-2.4.20}/SURE_tools.egg-info/top_level.txt +0 -0
  36. {sure_tools-2.4.17 → sure_tools-2.4.20}/setup.cfg +0 -0
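
The headline change in this release is the new SURE/PerturbationAwareDecoder.py module (file 3). Before the hunks, here is a minimal usage sketch based only on the public methods visible in the diff below; it is illustrative, not shipped package code, and the array shapes, epoch count, and synthetic data are assumptions.

# Illustrative sketch of the new PerturbationAwareDecoder API (toy shapes; not package code)
import numpy as np
from SURE.PerturbationAwareDecoder import PerturbationAwareDecoder

decoder = PerturbationAwareDecoder(latent_dim=100, num_known_perturbations=10, gene_dim=2000)

# Synthetic inputs: latent codes, one-hot perturbation labels, expression targets
latent = np.random.randn(1000, 100).astype(np.float32)
onehot = np.eye(10, dtype=np.float32)[np.random.randint(0, 10, size=1000)]
expression = np.abs(np.random.randn(1000, 2000)).astype(np.float32)

decoder.train(train_latent=latent, train_perturbations=onehot,
              train_expression=expression, num_epochs=5)

# Known perturbations: one-hot mode
pred_known = decoder.predict(latent[:8], onehot[:8])

# Novel perturbations: each row holds similarity scores against the 10 known perturbations
sim = decoder.compute_similarity(np.random.randn(4, 2000))  # hypothetical novel-perturbation features
sim = np.clip(sim, 0.0, 1.0)                                # keep scores in the recommended [0, 1] range
pred_novel = decoder.predict_novel_perturbation(latent[:4], sim)
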
{sure_tools-2.4.17 → sure_tools-2.4.20}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: SURE-tools
- Version: 2.4.17
+ Version: 2.4.20
  Summary: Succinct Representation of Single Cells
  Home-page: https://github.com/ZengFLab/SURE
  Author: Feng Zeng
{sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/DensityFlow.py
@@ -109,6 +109,13 @@ class DensityFlow(nn.Module):
 
          set_random_seed(seed)
          self.setup_networks()
+
+         print(f"🧬 DensityFlow Initialized:")
+         print(f" - Latent Dimension: {self.latent_dim}")
+         print(f" - Gene Dimension: {self.input_size}")
+         print(f" - Hidden Dimensions: {self.hidden_layers}")
+         print(f" - Device: {self.get_device()}")
+         print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
 
      def setup_networks(self):
          latent_dim = self.latent_dim
sure_tools-2.4.20/SURE/PerturbationAwareDecoder.py (new file)
@@ -0,0 +1,770 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader
+ import numpy as np
+ from typing import Dict, List, Optional, Tuple, Union
+ import math
+ import warnings
+ warnings.filterwarnings('ignore')
+
+ class PerturbationAwareDecoder:
+     """
+     Advanced transcriptome decoder with perturbation awareness
+     Similarity matrix columns correspond to known perturbations for novel perturbation prediction
+     """
+
+     def __init__(self,
+                  latent_dim: int = 100,
+                  num_known_perturbations: int = 50,  # Number of known perturbation types
+                  gene_dim: int = 60000,
+                  hidden_dims: List[int] = [512, 1024, 2048],
+                  perturbation_embedding_dim: int = 128,
+                  biological_prior_dim: int = 256,
+                  dropout_rate: float = 0.1,
+                  device: str = None):
+         """
+         Multi-modal decoder with correct similarity matrix definition
+
+         Args:
+             latent_dim: Latent variable dimension
+             num_known_perturbations: Number of known perturbation types for one-hot encoding
+             gene_dim: Number of genes
+             hidden_dims: Hidden layer dimensions
+             perturbation_embedding_dim: Embedding dimension for perturbations
+             biological_prior_dim: Dimension for biological prior knowledge
+             dropout_rate: Dropout rate
+             device: Computation device
+         """
+         self.latent_dim = latent_dim
+         self.num_known_perturbations = num_known_perturbations
+         self.gene_dim = gene_dim
+         self.hidden_dims = hidden_dims
+         self.perturbation_embedding_dim = perturbation_embedding_dim
+         self.biological_prior_dim = biological_prior_dim
+         self.dropout_rate = dropout_rate
+         self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+
+         # Initialize multi-modal model
+         self.model = self._build_corrected_model()
+         self.model.to(self.device)
+
+         # Training state
+         self.is_trained = False
+         self.training_history = None
+         self.best_val_loss = float('inf')
+         self.known_perturbation_names = []  # For mapping indices to perturbation names
+         self.perturbation_prototypes = None  # Learned perturbation representations
+
+         print(f"🧬 PerturbationAwareDecoder Initialized:")
+         print(f" - Latent Dimension: {latent_dim}")
+         print(f" - Known Perturbations: {num_known_perturbations}")
+         print(f" - Gene Dimension: {gene_dim}")
+         print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
+
+     class CorrectedPerturbationEncoder(nn.Module):
+         """Encoder for one-hot encoded perturbations"""
+
+         def __init__(self, num_perturbations: int, embedding_dim: int, hidden_dim: int):
+             super().__init__()
+             self.num_perturbations = num_perturbations
+
+             # Embedding for perturbation types
+             self.perturbation_embedding = nn.Embedding(num_perturbations, embedding_dim)
+
+             # Projection to hidden space
+             self.projection = nn.Sequential(
+                 nn.Linear(embedding_dim, hidden_dim),
+                 nn.ReLU(),
+                 nn.Dropout(0.1),
+                 nn.Linear(hidden_dim, hidden_dim)
+             )
+
+             # Attention mechanism for perturbation context
+             self.attention = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True)
+             self.norm = nn.LayerNorm(hidden_dim)
+
+         def forward(self, one_hot_perturbations):
+             """
+             Args:
+                 one_hot_perturbations: [batch_size, num_perturbations] one-hot encoded
+             """
+             batch_size = one_hot_perturbations.shape[0]
+
+             # Convert one-hot to indices
+             perturbation_indices = torch.argmax(one_hot_perturbations, dim=1)  # [batch_size]
+
+             # Get perturbation embeddings
+             perturbation_embeds = self.perturbation_embedding(perturbation_indices)  # [batch_size, embedding_dim]
+
+             # Project to hidden space
+             hidden_repr = self.projection(perturbation_embeds)  # [batch_size, hidden_dim]
+
+             # Add sequence dimension for attention
+             hidden_repr = hidden_repr.unsqueeze(1)  # [batch_size, 1, hidden_dim]
+
+             # Self-attention for perturbation context
+             attended, _ = self.attention(hidden_repr, hidden_repr, hidden_repr)
+             attended = self.norm(hidden_repr + attended)
+
+             return attended.squeeze(1)  # [batch_size, hidden_dim]
+
+     class CrossModalFusion(nn.Module):
+         """Cross-modal fusion of latent variables and perturbation information"""
+
+         def __init__(self, latent_dim: int, perturbation_dim: int, fusion_dim: int):
+             super().__init__()
+             self.latent_projection = nn.Linear(latent_dim, fusion_dim)
+             self.perturbation_projection = nn.Linear(perturbation_dim, fusion_dim)
+
+             # Cross-attention
+             self.cross_attention = nn.MultiheadAttention(
+                 fusion_dim, num_heads=8, batch_first=True
+             )
+
+             # Fusion gate
+             self.fusion_gate = nn.Sequential(
+                 nn.Linear(fusion_dim * 2, fusion_dim),
+                 nn.Sigmoid()
+             )
+
+             self.norm = nn.LayerNorm(fusion_dim)
+             self.dropout = nn.Dropout(0.1)
+
+         def forward(self, latent, perturbation_encoded):
+             # Project both modalities
+             latent_proj = self.latent_projection(latent).unsqueeze(1)  # [batch_size, 1, fusion_dim]
+             perturbation_proj = self.perturbation_projection(perturbation_encoded).unsqueeze(1)
+
+             # Cross-attention: latent attends to perturbation
+             attended, _ = self.cross_attention(latent_proj, perturbation_proj, perturbation_proj)
+
+             # Gated fusion
+             concatenated = torch.cat([attended, latent_proj], dim=-1)
+             fusion_gate = self.fusion_gate(concatenated)
+             fused = fusion_gate * attended + (1 - fusion_gate) * latent_proj
+
+             fused = self.norm(fused)
+             fused = self.dropout(fused)
+
+             return fused.squeeze(1)
+
+     class PerturbationResponseNetwork(nn.Module):
+         """Network for predicting perturbation-specific responses"""
+
+         def __init__(self, fusion_dim: int, gene_dim: int, hidden_dims: List[int]):
+             super().__init__()
+
+             # Base network
+             layers = []
+             input_dim = fusion_dim
+
+             for hidden_dim in hidden_dims:
+                 layers.extend([
+                     nn.Linear(input_dim, hidden_dim),
+                     nn.BatchNorm1d(hidden_dim),
+                     nn.ReLU(),
+                     nn.Dropout(0.1)
+                 ])
+                 input_dim = hidden_dim
+
+             self.base_network = nn.Sequential(*layers)
+             self.final_projection = nn.Linear(hidden_dims[-1], gene_dim)
+
+             # Perturbation-aware scaling
+             self.scale = nn.Linear(fusion_dim, 1)
+             self.bias = nn.Linear(fusion_dim, 1)
+
+         def forward(self, fused_representation):
+             base_output = self.base_network(fused_representation)
+             expression = self.final_projection(base_output)
+
+             # Perturbation-aware scaling
+             scale = torch.sigmoid(self.scale(fused_representation)) * 2
+             bias = self.bias(fused_representation)
+
+             return F.softplus(expression * scale + bias)
+
+     class CorrectedNovelPerturbationPredictor(nn.Module):
+         """Predictor for novel perturbations using similarity to known perturbations"""
+
+         def __init__(self, num_known_perturbations: int, gene_dim: int, hidden_dim: int):
+             super().__init__()
+             self.num_known_perturbations = num_known_perturbations
+             self.gene_dim = gene_dim
+
+             # Learnable perturbation prototypes (response patterns for known perturbations)
+             self.perturbation_prototypes = nn.Parameter(
+                 torch.randn(num_known_perturbations, gene_dim) * 0.1
+             )
+
+             # Similarity-based response generator
+             self.response_generator = nn.Sequential(
+                 nn.Linear(num_known_perturbations, hidden_dim),  # Input: similarity to known perturbations
+                 nn.ReLU(),
+                 nn.Linear(hidden_dim, gene_dim)
+             )
+
+             # Attention mechanism for combining prototypes
+             self.attention_weights = nn.Parameter(torch.randn(num_known_perturbations, 1))
+
+         def forward(self, similarity_matrix, latent_features=None):
+             """
+             Predict response to novel perturbation using similarity to known perturbations
+
+             Args:
+                 similarity_matrix: [batch_size, num_known_perturbations]
+                     Each row: similarity scores between novel perturbation and known perturbations
+                 latent_features: [batch_size, latent_dim] (optional) cell state information
+
+             Returns:
+                 expression: [batch_size, gene_dim] predicted expression
+             """
+             batch_size = similarity_matrix.shape[0]
+
+             # Method 1: Weighted combination of known perturbation prototypes
+             # similarity_matrix: [batch_size, num_known_perturbations]
+             # perturbation_prototypes: [num_known_perturbations, gene_dim]
+             weighted_prototypes = torch.matmul(similarity_matrix, self.perturbation_prototypes)  # [batch_size, gene_dim]
+
+             # Method 2: Direct generation from similarity profile
+             generated_response = self.response_generator(similarity_matrix)  # [batch_size, gene_dim]
+
+             # Combine both methods with learned weights
+             combination_weights = torch.sigmoid(
+                 similarity_matrix.mean(dim=1, keepdim=True)  # [batch_size, 1]
+             )
+
+             combined_response = (combination_weights * weighted_prototypes +
+                                  (1 - combination_weights) * generated_response)
+
+             # If latent features provided, modulate response by cell state
+             if latent_features is not None:
+                 # Simple modulation based on latent state
+                 modulation = torch.sigmoid(latent_features.mean(dim=1, keepdim=True))  # [batch_size, 1]
+                 combined_response = combined_response * (1 + 0.5 * modulation)
+
+             return F.softplus(combined_response)
+
+     class CorrectedMultimodalDecoder(nn.Module):
+         """Main decoder with corrected similarity matrix handling"""
+
+         def __init__(self, latent_dim: int, num_known_perturbations: int, gene_dim: int,
+                      hidden_dims: List[int], perturbation_embedding_dim: int,
+                      biological_prior_dim: int, dropout_rate: float):
+             super().__init__()
+
+             self.num_known_perturbations = num_known_perturbations
+             self.latent_dim = latent_dim
+             self.gene_dim = gene_dim
+
+             # Perturbation encoder for one-hot inputs
+             self.perturbation_encoder = PerturbationAwareDecoder.CorrectedPerturbationEncoder(
+                 num_known_perturbations, perturbation_embedding_dim, hidden_dims[0]
+             )
+
+             # Cross-modal fusion
+             self.cross_modal_fusion = PerturbationAwareDecoder.CrossModalFusion(
+                 latent_dim, hidden_dims[0], hidden_dims[0]
+             )
+
+             # Response network for known perturbations
+             self.response_network = PerturbationAwareDecoder.PerturbationResponseNetwork(
+                 hidden_dims[0], gene_dim, hidden_dims[1:]
+             )
+
+             # Novel perturbation predictor
+             self.novel_predictor = PerturbationAwareDecoder.CorrectedNovelPerturbationPredictor(
+                 num_known_perturbations, gene_dim, hidden_dims[0]
+             )
+
+         def forward(self, latent, perturbation_matrix, mode='one_hot'):
+             """
+             Forward pass with corrected similarity matrix definition
+
+             Args:
+                 latent: [batch_size, latent_dim] latent variables
+                 perturbation_matrix:
+                     - one_hot mode: [batch_size, num_known_perturbations] one-hot encoded
+                     - similarity mode: [batch_size, num_known_perturbations] similarity scores
+                 mode: 'one_hot' for known perturbations, 'similarity' for novel perturbations
+             """
+             if mode == 'one_hot':
+                 # Known perturbation pathway
+                 perturbation_encoded = self.perturbation_encoder(perturbation_matrix)
+                 fused = self.cross_modal_fusion(latent, perturbation_encoded)
+                 expression = self.response_network(fused)
+
+             elif mode == 'similarity':
+                 # Novel perturbation pathway
+                 # perturbation_matrix: similarity to known perturbations [batch_size, num_known_perturbations]
+                 expression = self.novel_predictor(perturbation_matrix, latent)
+
+             else:
+                 raise ValueError(f"Unknown mode: {mode}. Use 'one_hot' or 'similarity'")
+
+             return expression
+
+         def get_perturbation_prototypes(self):
+             """Get learned perturbation response prototypes"""
+             return self.novel_predictor.perturbation_prototypes.detach()
+
+     def _build_corrected_model(self):
+         """Build the corrected model"""
+         return self.CorrectedMultimodalDecoder(
+             self.latent_dim, self.num_known_perturbations, self.gene_dim,
+             self.hidden_dims, self.perturbation_embedding_dim,
+             self.biological_prior_dim, self.dropout_rate
+         )
+
+     def train(self,
+               train_latent: np.ndarray,
+               train_perturbations: np.ndarray,  # One-hot encoded [n_samples, num_known_perturbations]
+               train_expression: np.ndarray,
+               val_latent: np.ndarray = None,
+               val_perturbations: np.ndarray = None,
+               val_expression: np.ndarray = None,
+               batch_size: int = 32,
+               num_epochs: int = 200,
+               learning_rate: float = 1e-4,
+               checkpoint_path: str = 'corrected_decoder.pth') -> Dict:
+         """
+         Train the decoder with one-hot encoded perturbations
+         """
+         print("🧬 Starting Training with Corrected Similarity Definition...")
+
+         # Validate one-hot encoding
+         self._validate_one_hot_perturbations(train_perturbations)
+
+         # Data preparation
+         train_dataset = self._create_dataset(train_latent, train_perturbations, train_expression)
+
+         if val_latent is not None and val_perturbations is not None and val_expression is not None:
+             val_dataset = self._create_dataset(val_latent, val_perturbations, val_expression)
+             train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+             val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
+         else:
+             train_size = int(0.9 * len(train_dataset))
+             val_size = len(train_dataset) - train_size
+             train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
+             train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
+             val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
+
+         print(f"📊 Training samples: {len(train_loader.dataset)}")
+         print(f"📊 Validation samples: {len(val_loader.dataset)}")
+         print(f"🧪 Known perturbations: {self.num_known_perturbations}")
+
+         # Optimizer
+         optimizer = optim.AdamW(
+             self.model.parameters(),
+             lr=learning_rate,
+             weight_decay=1e-5
+         )
+
+         # Scheduler
+         scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
+
+         # Loss function
+         def loss_fn(pred, target):
+             mse_loss = F.mse_loss(pred, target)
+             poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()
+             correlation = self._pearson_correlation(pred, target)
+             correlation_loss = 1 - correlation
+             return mse_loss + 0.3 * poisson_loss + 0.1 * correlation_loss
+
+         # Training history
+         history = {
+             'train_loss': [], 'val_loss': [],
+             'train_mse': [], 'val_mse': [],
+             'train_correlation': [], 'val_correlation': [],
+             'learning_rates': []
+         }
+
+         best_val_loss = float('inf')
+         patience = 20
+         patience_counter = 0
+
+         print("\n🔬 Starting training...")
+         for epoch in range(1, num_epochs + 1):
+             # Training
+             train_metrics = self._train_epoch(train_loader, optimizer, loss_fn)
+
+             # Validation
+             val_metrics = self._validate_epoch(val_loader, loss_fn)
+
+             # Update scheduler
+             scheduler.step()
+             current_lr = optimizer.param_groups[0]['lr']
+
+             # Record history
+             history['train_loss'].append(train_metrics['loss'])
+             history['val_loss'].append(val_metrics['loss'])
+             history['train_mse'].append(train_metrics['mse'])
+             history['val_mse'].append(val_metrics['mse'])
+             history['train_correlation'].append(train_metrics['correlation'])
+             history['val_correlation'].append(val_metrics['correlation'])
+             history['learning_rates'].append(current_lr)
+
+             # Print progress
+             if epoch % 10 == 0 or epoch == 1:
+                 print(f"🧪 Epoch {epoch:3d}/{num_epochs} | "
+                       f"Train: {train_metrics['loss']:.4f} | "
+                       f"Val: {val_metrics['loss']:.4f} | "
+                       f"Corr: {val_metrics['correlation']:.4f} | "
+                       f"LR: {current_lr:.2e}")
+
+             # Early stopping
+             if val_metrics['loss'] < best_val_loss:
+                 best_val_loss = val_metrics['loss']
+                 patience_counter = 0
+                 self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
+             else:
+                 patience_counter += 1
+                 if patience_counter >= patience:
+                     print(f"🛑 Early stopping at epoch {epoch}")
+                     break
+
+         self.is_trained = True
+         self.training_history = history
+         self.best_val_loss = best_val_loss
+         self.perturbation_prototypes = self.model.get_perturbation_prototypes().cpu().numpy()
+
+         print(f"\n🎉 Training completed! Best val loss: {best_val_loss:.4f}")
+         print(f"📊 Learned perturbation prototypes: {self.perturbation_prototypes.shape}")
+         return history
+
+     def _validate_one_hot_perturbations(self, perturbations):
+         """Validate that perturbations are proper one-hot encodings"""
+         assert perturbations.shape[1] == self.num_known_perturbations, \
+             f"Perturbation dimension {perturbations.shape[1]} doesn't match expected {self.num_known_perturbations}"
+
+         # Check that each row sums to 1 (perturbation) or 0 (control)
+         row_sums = perturbations.sum(axis=1)
+         valid_rows = np.all((row_sums == 0) | (row_sums == 1))
+         assert valid_rows, "Perturbations should be one-hot encoded (sum to 0 or 1 per row)"
+
+         print("✅ One-hot perturbations validated")
+
+     def _create_dataset(self, latent_data, perturbations, expression_data):
+         """Create dataset with one-hot perturbations"""
+         class OneHotDataset(Dataset):
+             def __init__(self, latent, perturbations, expression):
+                 self.latent = torch.FloatTensor(latent)
+                 self.perturbations = torch.FloatTensor(perturbations)
+                 self.expression = torch.FloatTensor(expression)
+
+             def __len__(self):
+                 return len(self.latent)
+
+             def __getitem__(self, idx):
+                 return self.latent[idx], self.perturbations[idx], self.expression[idx]
+
+         return OneHotDataset(latent_data, perturbations, expression_data)
+
+     def predict(self,
+                 latent_data: np.ndarray,
+                 perturbations: np.ndarray,
+                 batch_size: int = 32) -> np.ndarray:
+         """
+         Predict expression for known perturbations using one-hot encoding
+
+         Args:
+             latent_data: [n_samples, latent_dim] latent variables
+             perturbations: [n_samples, num_known_perturbations] one-hot encoded perturbations
+             batch_size: Batch size
+
+         Returns:
+             expression: [n_samples, gene_dim] predicted expression
+         """
+         if not self.is_trained:
+             warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")
+
+         # Validate one-hot encoding
+         self._validate_one_hot_perturbations(perturbations)
+
+         self.model.eval()
+
+         if isinstance(latent_data, np.ndarray):
+             latent_data = torch.FloatTensor(latent_data)
+         if isinstance(perturbations, np.ndarray):
+             perturbations = torch.FloatTensor(perturbations)
+
+         predictions = []
+         with torch.no_grad():
+             for i in range(0, len(latent_data), batch_size):
+                 batch_latent = latent_data[i:i+batch_size].to(self.device)
+                 batch_perturbations = perturbations[i:i+batch_size].to(self.device)
+
+                 # Use one-hot mode for known perturbations
+                 batch_pred = self.model(batch_latent, batch_perturbations, mode='one_hot')
+                 predictions.append(batch_pred.cpu())
+
+         return torch.cat(predictions).numpy()
+
+     def predict_novel_perturbation(self,
+                                    latent_data: np.ndarray,
+                                    similarity_matrix: np.ndarray,
+                                    batch_size: int = 32) -> np.ndarray:
+         """
+         Predict response to novel perturbations using similarity to known perturbations
+
+         Args:
+             latent_data: [n_samples, latent_dim] latent variables
+             similarity_matrix: [n_samples, num_known_perturbations]
+                 Each row: similarity scores between novel perturbation and known perturbations
+                 Columns correspond to model's known perturbation types
+             batch_size: Batch size
+
+         Returns:
+             expression: [n_samples, gene_dim] predicted expression
+         """
+         if not self.is_trained:
+             warnings.warn("⚠️ Model not trained. Novel perturbation prediction may be inaccurate.")
+
+         # Validate similarity matrix dimensions
+         assert similarity_matrix.shape[1] == self.num_known_perturbations, \
+             f"Similarity matrix columns {similarity_matrix.shape[1]} must match known perturbations {self.num_known_perturbations}"
+
+         # Validate similarity scores are reasonable (0-1 range recommended)
+         if np.any(similarity_matrix < 0) or np.any(similarity_matrix > 2):
+             warnings.warn("⚠️ Similarity scores outside typical range [0, 1]. Consider normalizing.")
+
+         self.model.eval()
+
+         if isinstance(latent_data, np.ndarray):
+             latent_data = torch.FloatTensor(latent_data)
+         if isinstance(similarity_matrix, np.ndarray):
+             similarity_matrix = torch.FloatTensor(similarity_matrix)
+
+         predictions = []
+         with torch.no_grad():
+             for i in range(0, len(latent_data), batch_size):
+                 batch_latent = latent_data[i:i+batch_size].to(self.device)
+                 batch_similarity = similarity_matrix[i:i+batch_size].to(self.device)
+
+                 # Use similarity mode for novel perturbations
+                 batch_pred = self.model(batch_latent, batch_similarity, mode='similarity')
+                 predictions.append(batch_pred.cpu())
+
+         return torch.cat(predictions).numpy()
+
+     def get_known_perturbation_prototypes(self) -> np.ndarray:
+         """Get learned response prototypes for known perturbations"""
+         if not self.is_trained:
+             warnings.warn("⚠️ Model not trained. Prototypes may be uninformative.")
+
+         if self.perturbation_prototypes is None:
+             self.model.eval()
+             with torch.no_grad():
+                 self.perturbation_prototypes = self.model.get_perturbation_prototypes().cpu().numpy()
+
+         return self.perturbation_prototypes
+
+     def compute_similarity(self, novel_perturbation_features: np.ndarray,
+                            known_perturbation_features: np.ndarray = None) -> np.ndarray:
+         """
+         Compute similarity matrix between novel perturbations and known perturbations
+
+         Args:
+             novel_perturbation_features: [n_novel, feature_dim] features of novel perturbations
+             known_perturbation_features: [n_known, feature_dim] features of known perturbations
+                 If None, uses learned perturbation prototypes
+
+         Returns:
+             similarity_matrix: [n_novel, num_known_perturbations] similarity scores
+         """
+         if known_perturbation_features is None:
+             # Use learned prototypes
+             known_perturbation_features = self.get_known_perturbation_prototypes()
+
+         # Normalize features for cosine similarity
+         novel_norm = novel_perturbation_features / (np.linalg.norm(novel_perturbation_features, axis=1, keepdims=True) + 1e-8)
+         known_norm = known_perturbation_features / (np.linalg.norm(known_perturbation_features, axis=1, keepdims=True) + 1e-8)
+
+         # Compute cosine similarity
+         similarity_matrix = np.dot(novel_norm, known_norm.T)
+
+         return similarity_matrix
+
+     def _pearson_correlation(self, pred, target):
+         """Calculate Pearson correlation coefficient"""
+         pred_centered = pred - pred.mean(dim=1, keepdim=True)
+         target_centered = target - target.mean(dim=1, keepdim=True)
+
+         numerator = (pred_centered * target_centered).sum(dim=1)
+         denominator = torch.sqrt(torch.sum(pred_centered ** 2, dim=1)) * torch.sqrt(torch.sum(target_centered ** 2, dim=1))
+
+         return (numerator / (denominator + 1e-8)).mean()
+
+     def _train_epoch(self, train_loader, optimizer, loss_fn):
+         """Train one epoch"""
+         self.model.train()
+         total_loss = 0
+         total_mse = 0
+         total_correlation = 0
+
+         for latent, perturbations, target in train_loader:
+             latent = latent.to(self.device)
+             perturbations = perturbations.to(self.device)
+             target = target.to(self.device)
+
+             optimizer.zero_grad()
+             pred = self.model(latent, perturbations, mode='one_hot')
+
+             loss = loss_fn(pred, target)
+             loss.backward()
+
+             torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+             optimizer.step()
+
+             mse_loss = F.mse_loss(pred, target).item()
+             correlation = self._pearson_correlation(pred, target).item()
+
+             total_loss += loss.item()
+             total_mse += mse_loss
+             total_correlation += correlation
+
+         num_batches = len(train_loader)
+         return {
+             'loss': total_loss / num_batches,
+             'mse': total_mse / num_batches,
+             'correlation': total_correlation / num_batches
+         }
+
+     def _validate_epoch(self, val_loader, loss_fn):
+         """Validate one epoch"""
+         self.model.eval()
+         total_loss = 0
+         total_mse = 0
+         total_correlation = 0
+
+         with torch.no_grad():
+             for latent, perturbations, target in val_loader:
+                 latent = latent.to(self.device)
+                 perturbations = perturbations.to(self.device)
+                 target = target.to(self.device)
+
+                 pred = self.model(latent, perturbations, mode='one_hot')
+                 loss = loss_fn(pred, target)
+                 mse_loss = F.mse_loss(pred, target).item()
+                 correlation = self._pearson_correlation(pred, target).item()
+
+                 total_loss += loss.item()
+                 total_mse += mse_loss
+                 total_correlation += correlation
+
+         num_batches = len(val_loader)
+         return {
+             'loss': total_loss / num_batches,
+             'mse': total_mse / num_batches,
+             'correlation': total_correlation / num_batches
+         }
+
+     def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
+         """Save model checkpoint"""
+         torch.save({
+             'epoch': epoch,
+             'model_state_dict': self.model.state_dict(),
+             'optimizer_state_dict': optimizer.state_dict(),
+             'scheduler_state_dict': scheduler.state_dict(),
+             'best_val_loss': best_loss,
+             'training_history': history,
+             'perturbation_prototypes': self.perturbation_prototypes,
+             'model_config': {
+                 'latent_dim': self.latent_dim,
+                 'num_known_perturbations': self.num_known_perturbations,
+                 'gene_dim': self.gene_dim,
+                 'hidden_dims': self.hidden_dims
+             }
+         }, path)
+
+     def load_model(self, model_path: str):
+         """Load pre-trained model"""
+         checkpoint = torch.load(model_path, map_location=self.device)
+         self.model.load_state_dict(checkpoint['model_state_dict'])
+         self.perturbation_prototypes = checkpoint.get('perturbation_prototypes')
+         self.is_trained = True
+         self.training_history = checkpoint.get('training_history')
+         self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
+         print(f"✅ Model loaded! Best val loss: {self.best_val_loss:.4f}")
+
+ '''# Example usage
+ def example_usage():
+     """Example demonstration of the corrected perturbation decoder"""
+
+     # Initialize decoder
+     decoder = PerturbationAwareDecoder(
+         latent_dim=100,
+         num_known_perturbations=10,  # 10 known perturbation types
+         gene_dim=2000,  # Reduced for example
+         hidden_dims=[256, 512, 1024],
+         perturbation_embedding_dim=128
+     )
+
+     # Generate example data
+     n_samples = 1000
+     n_perturbations = 10
+
+     # Latent variables
+     latent_data = np.random.randn(n_samples, 100).astype(np.float32)
+
+     # One-hot encoded perturbations
+     perturbations = np.zeros((n_samples, n_perturbations))
+     for i in range(n_samples):
+         if i % 10 != 0:  # 90% perturbed, 10% control
+             perturb_id = np.random.randint(0, n_perturbations)
+             perturbations[i, perturb_id] = 1.0
+
+     # Expression data with perturbation effects
+     base_weights = np.random.randn(100, 2000) * 0.1
+     perturbation_effects = np.random.randn(n_perturbations, 2000) * 0.5
+
+     expression_data = np.tanh(latent_data.dot(base_weights))
+     for i in range(n_samples):
+         if perturbations[i].sum() > 0:  # Perturbed sample
+             perturb_id = np.argmax(perturbations[i])
+             expression_data[i] += perturbation_effects[perturb_id]
+
+     expression_data = np.maximum(expression_data, 0)
+
+     print(f"📊 Data shapes: Latent {latent_data.shape}, Perturbations {perturbations.shape}, Expression {expression_data.shape}")
+     print(f"🧪 Control samples: {(perturbations.sum(axis=1) == 0).sum()}")
+     print(f"🧪 Perturbed samples: {(perturbations.sum(axis=1) > 0).sum()}")
+
+     # Train with one-hot perturbations
+     history = decoder.train(
+         train_latent=latent_data,
+         train_perturbations=perturbations,
+         train_expression=expression_data,
+         batch_size=32,
+         num_epochs=50
+     )
+
+     # Test predictions with one-hot perturbations
+     test_latent = np.random.randn(10, 100).astype(np.float32)
+     test_perturbations = np.zeros((10, n_perturbations))
+     for i in range(10):
+         test_perturbations[i, i % n_perturbations] = 1.0  # One-hot encoding
+
+     predictions = decoder.predict(test_latent, test_perturbations)
+     print(f"🔮 Known perturbation prediction shape: {predictions.shape}")
+
+     # Test novel perturbation prediction with similarity matrix
+     test_latent_novel = np.random.randn(5, 100).astype(np.float32)
+
+     # Create similarity matrix: [5, 10] - 5 novel perturbations, similarity to 10 known perturbations
+     similarity_matrix = np.random.rand(5, n_perturbations)
+     similarity_matrix = similarity_matrix / similarity_matrix.sum(axis=1, keepdims=True)  # Normalize rows
+
+     novel_predictions = decoder.predict_novel_perturbation(test_latent_novel, similarity_matrix)
+     print(f"🔮 Novel perturbation prediction shape: {novel_predictions.shape}")
+
+     # Get learned perturbation prototypes
+     prototypes = decoder.get_known_perturbation_prototypes()
+     print(f"📊 Perturbation prototypes shape: {prototypes.shape}")
+
+     return decoder
+
+ if __name__ == "__main__":
+     example_usage()'''
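
For orientation, the 'similarity' pathway added above blends two terms: a similarity-weighted sum of the learned per-perturbation prototypes and an MLP applied to the similarity profile. A tiny standalone NumPy sketch of the prototype-weighting term (toy shapes, not package code):

# Toy illustration of the prototype-weighting step used in 'similarity' mode
import numpy as np
S = np.random.rand(4, 10)             # similarity of 4 novel perturbations to 10 known ones
P = np.random.randn(10, 2000) * 0.1   # stand-in for the learned prototypes (one pattern per known perturbation)
weighted = S @ P                      # [4, 2000]: each novel response expressed as a mix of known responses
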
{sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/VirtualCellDecoder.py
@@ -56,7 +56,6 @@ class VirtualCellDecoder:
          print(f" - Biological Prior Dimension: {biological_prior_dim}")
          print(f" - Device: {self.device}")
          print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
-         print(f" - Estimated GPU Memory: ~8-10GB (optimized for 20GB)")
 
      class BiologicalPriorNetwork(nn.Module):
          """Biological prior network based on gene regulatory knowledge"""
{sure_tools-2.4.17 → sure_tools-2.4.20}/SURE/__init__.py
@@ -5,6 +5,7 @@ from .TranscriptomeDecoder import TranscriptomeDecoder
  from .SimpleTranscriptomeDecoder import SimpleTranscriptomeDecoder
  from .EfficientTranscriptomeDecoder import EfficientTranscriptomeDecoder
  from .VirtualCellDecoder import VirtualCellDecoder
+ from .PerturbationAwareDecoder import PerturbationAwareDecoder
 
  from . import utils
  from . import codebook
@@ -18,6 +19,8 @@ from . import TranscriptomeDecoder
  from . import SimpleTranscriptomeDecoder
  from . import EfficientTranscriptomeDecoder
  from . import VirtualCellDecoder
+ from . import PerturbationAwareDecoder
 
  __all__ = ['SURE', 'DensityFlow', 'PerturbE', 'TranscriptomeDecoder', 'SimpleTranscriptomeDecoder',
-            'EfficientTranscriptomeDecoder', 'VirtualCellDecoder', 'flow', 'perturb', 'atac', 'utils', 'codebook']
+            'EfficientTranscriptomeDecoder', 'VirtualCellDecoder', 'PerturbationAwareDecoder',
+            'flow', 'perturb', 'atac', 'utils', 'codebook']
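
With the __init__.py change above, the new decoder is also re-exported at package level; a quick sanity check (illustrative, assuming the package and its dependencies are installed):

# Check that the new name is exported from the package
import SURE
print('PerturbationAwareDecoder' in SURE.__all__)  # expected: True per the __all__ diff above
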
{sure_tools-2.4.17 → sure_tools-2.4.20}/SURE_tools.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: SURE-tools
- Version: 2.4.17
+ Version: 2.4.20
  Summary: Succinct Representation of Single Cells
  Home-page: https://github.com/ZengFLab/SURE
  Author: Feng Zeng
{sure_tools-2.4.17 → sure_tools-2.4.20}/SURE_tools.egg-info/SOURCES.txt
@@ -4,6 +4,7 @@ setup.py
  SURE/DensityFlow.py
  SURE/EfficientTranscriptomeDecoder.py
  SURE/PerturbE.py
+ SURE/PerturbationAwareDecoder.py
  SURE/SURE.py
  SURE/SimpleTranscriptomeDecoder.py
  SURE/TranscriptomeDecoder.py
{sure_tools-2.4.17 → sure_tools-2.4.20}/setup.py
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
 
  setup(
      name='SURE-tools',
-     version='2.4.17',
+     version='2.4.20',
      description='Succinct Representation of Single Cells',
      long_description=long_description,
      long_description_content_type="text/markdown",