SURE-tools 2.4.4.tar.gz → 2.4.7.tar.gz

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.

This version of SURE-tools has been flagged as potentially problematic.
Files changed (33)
  1. {sure_tools-2.4.4 → sure_tools-2.4.7}/PKG-INFO +1 -1
  2. sure_tools-2.4.7/SURE/TranscriptomeDecoder.py +499 -0
  3. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE_tools.egg-info/PKG-INFO +1 -1
  4. {sure_tools-2.4.4 → sure_tools-2.4.7}/setup.py +1 -1
  5. sure_tools-2.4.4/SURE/TranscriptomeDecoder.py +0 -529
  6. {sure_tools-2.4.4 → sure_tools-2.4.7}/LICENSE +0 -0
  7. {sure_tools-2.4.4 → sure_tools-2.4.7}/README.md +0 -0
  8. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE/DensityFlow.py +0 -0
  9. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE/PerturbE.py +0 -0
  10. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE/SURE.py +0 -0
  11. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE/__init__.py +0 -0
  12. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE/assembly/__init__.py +0 -0
  13. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE/assembly/assembly.py +0 -0
  14. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE/assembly/atlas.py +0 -0
  15. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE/atac/__init__.py +0 -0
  16. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE/atac/utils.py +0 -0
  17. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE/codebook/__init__.py +0 -0
  18. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE/codebook/codebook.py +0 -0
  19. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE/flow/__init__.py +0 -0
  20. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE/flow/flow_stats.py +0 -0
  21. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE/flow/plot_quiver.py +0 -0
  22. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE/perturb/__init__.py +0 -0
  23. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE/perturb/perturb.py +0 -0
  24. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE/utils/__init__.py +0 -0
  25. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE/utils/custom_mlp.py +0 -0
  26. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE/utils/queue.py +0 -0
  27. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE/utils/utils.py +0 -0
  28. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE_tools.egg-info/SOURCES.txt +0 -0
  29. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE_tools.egg-info/dependency_links.txt +0 -0
  30. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE_tools.egg-info/entry_points.txt +0 -0
  31. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE_tools.egg-info/requires.txt +0 -0
  32. {sure_tools-2.4.4 → sure_tools-2.4.7}/SURE_tools.egg-info/top_level.txt +0 -0
  33. {sure_tools-2.4.4 → sure_tools-2.4.7}/setup.cfg +0 -0
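
The substantive change in this release is the replacement of SURE/TranscriptomeDecoder.py: the previous 529-line implementation is removed and a simplified 499-line version is added, alongside the version bump from 2.4.4 to 2.4.7 in PKG-INFO and setup.py. For orientation, the sketch below exercises the public methods visible in the added file (the constructor, train, predict, and get_model_info). It is illustrative only: the import path and the synthetic data are assumptions, not part of the release.

import numpy as np
from SURE.TranscriptomeDecoder import TranscriptomeDecoder  # assumed import path

# Small dimensions purely for illustration; the defaults are latent_dim=100, gene_dim=60000
decoder = TranscriptomeDecoder(latent_dim=100, gene_dim=2000, hidden_dim=256)

# Toy, non-negative expression data with matching sample counts
latent = np.random.randn(500, 100).astype(np.float32)
expression = np.abs(np.random.randn(500, 2000)).astype(np.float32)

history = decoder.train(
    train_latent=latent,
    train_expression=expression,
    batch_size=32,
    num_epochs=5,
    learning_rate=1e-4,
)

predicted = decoder.predict(latent[:10])  # shape (10, 2000)
print(decoder.get_model_info())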
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: SURE-tools
- Version: 2.4.4
+ Version: 2.4.7
  Summary: Succinct Representation of Single Cells
  Home-page: https://github.com/ZengFLab/SURE
  Author: Feng Zeng
@@ -0,0 +1,499 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.optim as optim
5
+ from torch.utils.data import Dataset, DataLoader
6
+ import numpy as np
7
+ from typing import Dict, Optional
8
+ import warnings
9
+ warnings.filterwarnings('ignore')
10
+
11
+ class TranscriptomeDecoder:
12
+ """Transcriptome decoder"""
13
+
14
+ def __init__(self,
15
+ latent_dim: int = 100,
16
+ gene_dim: int = 60000,
17
+ hidden_dim: int = 512,
18
+ device: str = None):
19
+ """
20
+ Simple but powerful decoder for latent to transcriptome mapping
21
+
22
+ Args:
23
+ latent_dim: Latent variable dimension (typically 50-100)
24
+ gene_dim: Number of genes (full transcriptome ~60,000)
25
+ hidden_dim: Hidden layer dimension
26
+ device: Computation device
27
+ """
28
+ self.latent_dim = latent_dim
29
+ self.gene_dim = gene_dim
30
+ self.hidden_dim = hidden_dim
31
+ self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
32
+
33
+ # Initialize model
34
+ self.model = self._build_model()
35
+ self.model.to(self.device)
36
+
37
+ # Training state
38
+ self.is_trained = False
39
+ self.training_history = None
40
+ self.best_val_loss = float('inf')
41
+
42
+ print(f"🚀 SimpleTranscriptomeDecoder Initialized:")
43
+ print(f" - Latent Dimension: {latent_dim}")
44
+ print(f" - Gene Dimension: {gene_dim}")
45
+ print(f" - Hidden Dimension: {hidden_dim}")
46
+ print(f" - Device: {self.device}")
47
+ print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
48
+
49
+ class Decoder(nn.Module):
50
+ """Memory-efficient decoder architecture with dimension handling"""
51
+
52
+ def __init__(self, latent_dim: int, gene_dim: int, hidden_dim: int):
53
+ super().__init__()
54
+ self.latent_dim = latent_dim
55
+ self.gene_dim = gene_dim
56
+ self.hidden_dim = hidden_dim
57
+
58
+ # Stage 1: Latent variable expansion
59
+ self.latent_expansion = nn.Sequential(
60
+ nn.Linear(latent_dim, hidden_dim * 2),
61
+ nn.BatchNorm1d(hidden_dim * 2),
62
+ nn.GELU(),
63
+ nn.Dropout(0.1),
64
+ nn.Linear(hidden_dim * 2, hidden_dim),
65
+ nn.BatchNorm1d(hidden_dim),
66
+ nn.GELU(),
67
+ )
68
+
69
+ # Stage 2: Direct projection to gene dimension (simpler approach)
70
+ self.gene_projector = nn.Sequential(
71
+ nn.Linear(hidden_dim, hidden_dim * 2),
72
+ nn.GELU(),
73
+ nn.Dropout(0.1),
74
+ nn.Linear(hidden_dim * 2, gene_dim), # Direct projection to gene_dim
75
+ )
76
+
77
+ # Stage 3: Lightweight gene interaction
78
+ self.gene_interaction = nn.Sequential(
79
+ nn.Conv1d(1, 32, kernel_size=3, padding=1),
80
+ nn.GELU(),
81
+ nn.Dropout1d(0.1),
82
+ nn.Conv1d(32, 1, kernel_size=3, padding=1),
83
+ )
84
+
85
+ # Output scaling
86
+ self.output_scale = nn.Parameter(torch.ones(1))
87
+ self.output_bias = nn.Parameter(torch.zeros(1))
88
+
89
+ self._init_weights()
90
+
91
+ def _init_weights(self):
92
+ """Weight initialization"""
93
+ for module in self.modules():
94
+ if isinstance(module, nn.Linear):
95
+ nn.init.xavier_uniform_(module.weight)
96
+ if module.bias is not None:
97
+ nn.init.zeros_(module.bias)
98
+ elif isinstance(module, nn.Conv1d):
99
+ nn.init.kaiming_uniform_(module.weight)
100
+ if module.bias is not None:
101
+ nn.init.zeros_(module.bias)
102
+
103
+ def forward(self, latent: torch.Tensor) -> torch.Tensor:
104
+ batch_size = latent.shape[0]
105
+
106
+ # 1. Expand latent variables
107
+ latent_features = self.latent_expansion(latent) # [batch_size, hidden_dim]
108
+
109
+ # 2. Direct projection to gene dimension
110
+ gene_output = self.gene_projector(latent_features) # [batch_size, gene_dim]
111
+
112
+ # 3. Gene interaction with dimension safety
113
+ if self.gene_dim > 1: # Only apply if gene_dim > 1
114
+ gene_output = gene_output.unsqueeze(1) # [batch_size, 1, gene_dim]
115
+ interaction_output = self.gene_interaction(gene_output) # [batch_size, 1, gene_dim]
116
+ gene_output = gene_output + interaction_output # Residual connection
117
+ gene_output = gene_output.squeeze(1) # [batch_size, gene_dim]
118
+
119
+ # 4. Final activation (ensure non-negative)
120
+ gene_output = F.softplus(gene_output * self.output_scale + self.output_bias)
121
+
122
+ return gene_output
123
+
124
+ def _build_model(self):
125
+ """Build the decoder model"""
126
+ return self.Decoder(self.latent_dim, self.gene_dim, self.hidden_dim)
127
+
128
+ def _create_dataset(self, latent_data, expression_data):
129
+ """Create dataset with dimension validation"""
130
+ class SimpleDataset(Dataset):
131
+ def __init__(self, latent, expression):
132
+ # Ensure dimensions match
133
+ assert latent.shape[0] == expression.shape[0], "Sample count mismatch"
134
+ assert latent.shape[1] == latent_dim, f"Latent dim mismatch: expected {latent_dim}, got {latent.shape[1]}"
135
+ assert expression.shape[1] == gene_dim, f"Gene dim mismatch: expected {gene_dim}, got {expression.shape[1]}"
136
+
137
+ self.latent = torch.FloatTensor(latent)
138
+ self.expression = torch.FloatTensor(expression)
139
+
140
+ def __len__(self):
141
+ return len(self.latent)
142
+
143
+ def __getitem__(self, idx):
144
+ return self.latent[idx], self.expression[idx]
145
+
146
+ return SimpleDataset(latent_data, expression_data)
147
+
148
+ def train(self,
149
+ train_latent: np.ndarray,
150
+ train_expression: np.ndarray,
151
+ val_latent: np.ndarray = None,
152
+ val_expression: np.ndarray = None,
153
+ batch_size: int = 32,
154
+ num_epochs: int = 100,
155
+ learning_rate: float = 1e-4,
156
+ checkpoint_path: str = 'transcriptome_decoder.pth'):
157
+ """
158
+ Train the decoder model with dimension safety
159
+
160
+ Args:
161
+ train_latent: Training latent variables [n_samples, latent_dim]
162
+ train_expression: Training expression data [n_samples, gene_dim]
163
+ val_latent: Validation latent variables (optional)
164
+ val_expression: Validation expression data (optional)
165
+ batch_size: Batch size optimized for memory
166
+ num_epochs: Number of training epochs
167
+ learning_rate: Learning rate
168
+ checkpoint_path: Path to save the best model
169
+ """
170
+ print("🚀 Starting training...")
171
+
172
+ # Dimension validation
173
+ self._validate_data_dimensions(train_latent, train_expression, "Training")
174
+ if val_latent is not None and val_expression is not None:
175
+ self._validate_data_dimensions(val_latent, val_expression, "Validation")
176
+
177
+ # Data preparation
178
+ train_dataset = self._create_safe_dataset(train_latent, train_expression)
179
+
180
+ if val_latent is not None and val_expression is not None:
181
+ val_dataset = self._create_safe_dataset(val_latent, val_expression)
182
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
183
+ val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
184
+ print(f"📈 Using provided validation data: {len(val_dataset)} samples")
185
+ else:
186
+ # Auto split
187
+ train_size = int(0.9 * len(train_dataset))
188
+ val_size = len(train_dataset) - train_size
189
+ train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
190
+ train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=2)
191
+ val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=2)
192
+ print(f"📈 Auto-split validation: {val_size} samples")
193
+
194
+ print(f"📊 Training samples: {len(train_loader.dataset)}")
195
+ print(f"📊 Validation samples: {len(val_loader.dataset)}")
196
+ print(f"📊 Batch size: {batch_size}")
197
+
198
+ # Optimizer configuration
199
+ optimizer = optim.AdamW(
200
+ self.model.parameters(),
201
+ lr=learning_rate,
202
+ weight_decay=0.01,
203
+ betas=(0.9, 0.999)
204
+ )
205
+
206
+ # Learning rate scheduler
207
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
208
+
209
+ # Loss function with dimension safety
210
+ def safe_loss(pred, target):
211
+ # Ensure dimensions match
212
+ if pred.shape != target.shape:
213
+ print(f"⚠️ Dimension mismatch: pred {pred.shape}, target {target.shape}")
214
+ # Truncate to minimum dimension (safety measure)
215
+ min_dim = min(pred.shape[1], target.shape[1])
216
+ pred = pred[:, :min_dim]
217
+ target = target[:, :min_dim]
218
+
219
+ mse_loss = F.mse_loss(pred, target)
220
+ poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()
221
+ return mse_loss + 0.3 * poisson_loss
222
+
223
+ # Training history
224
+ history = {
225
+ 'train_loss': [],
226
+ 'val_loss': [],
227
+ 'learning_rate': []
228
+ }
229
+
230
+ best_val_loss = float('inf')
231
+ patience = 15
232
+ patience_counter = 0
233
+
234
+ print("\n📈 Starting training loop...")
235
+ for epoch in range(1, num_epochs + 1):
236
+ # Training phase
237
+ train_loss = self._train_epoch(train_loader, optimizer, safe_loss)
238
+
239
+ # Validation phase
240
+ val_loss = self._validate_epoch(val_loader, safe_loss)
241
+
242
+ # Update scheduler
243
+ scheduler.step()
244
+ current_lr = scheduler.get_last_lr()[0]
245
+
246
+ # Record history
247
+ history['train_loss'].append(train_loss)
248
+ history['val_loss'].append(val_loss)
249
+ history['learning_rate'].append(current_lr)
250
+
251
+ # Print progress
252
+ if epoch % 5 == 0 or epoch == 1:
253
+ print(f"📍 Epoch {epoch:3d}/{num_epochs} | "
254
+ f"Train Loss: {train_loss:.4f} | "
255
+ f"Val Loss: {val_loss:.4f} | "
256
+ f"LR: {current_lr:.2e}")
257
+
258
+ # Early stopping and model saving
259
+ if val_loss < best_val_loss:
260
+ best_val_loss = val_loss
261
+ patience_counter = 0
262
+ self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
263
+ if epoch % 10 == 0:
264
+ print(f"💾 Best model saved (Val Loss: {best_val_loss:.4f})")
265
+ else:
266
+ patience_counter += 1
267
+ if patience_counter >= patience:
268
+ print(f"🛑 Early stopping at epoch {epoch}")
269
+ break
270
+
271
+ # Training completed
272
+ self.is_trained = True
273
+ self.training_history = history
274
+ self.best_val_loss = best_val_loss
275
+
276
+ print(f"\n🎉 Training completed!")
277
+ print(f"🏆 Best validation loss: {best_val_loss:.4f}")
278
+ print(f"📊 Final training loss: {history['train_loss'][-1]:.4f}")
279
+
280
+ return history
281
+
282
+ def _validate_data_dimensions(self, latent_data, expression_data, data_type):
283
+ """Validate input data dimensions"""
284
+ assert latent_data.shape[1] == self.latent_dim, (
285
+ f"{data_type} latent dimension mismatch: expected {self.latent_dim}, got {latent_data.shape[1]}")
286
+ assert expression_data.shape[1] == self.gene_dim, (
287
+ f"{data_type} gene dimension mismatch: expected {self.gene_dim}, got {expression_data.shape[1]}")
288
+ assert latent_data.shape[0] == expression_data.shape[0], (
289
+ f"{data_type} sample count mismatch: latent {latent_data.shape[0]}, expression {expression_data.shape[0]}")
290
+ print(f"✅ {data_type} data dimensions validated")
291
+
292
+ def _create_safe_dataset(self, latent_data, expression_data):
293
+ """Create dataset with safety checks"""
294
+ class SafeDataset(Dataset):
295
+ def __init__(self, latent, expression):
296
+ self.latent = torch.FloatTensor(latent)
297
+ self.expression = torch.FloatTensor(expression)
298
+
299
+ # Safety check
300
+ if self.latent.shape[0] != self.expression.shape[0]:
301
+ raise ValueError(f"Sample count mismatch: latent {self.latent.shape[0]}, expression {self.expression.shape[0]}")
302
+
303
+ def __len__(self):
304
+ return len(self.latent)
305
+
306
+ def __getitem__(self, idx):
307
+ return self.latent[idx], self.expression[idx]
308
+
309
+ return SafeDataset(latent_data, expression_data)
310
+
311
+ def _train_epoch(self, train_loader, optimizer, loss_fn):
312
+ """Train for one epoch with dimension safety"""
313
+ self.model.train()
314
+ total_loss = 0
315
+
316
+ for batch_idx, (latent, target) in enumerate(train_loader):
317
+ latent = latent.to(self.device)
318
+ target = target.to(self.device)
319
+
320
+ # Dimension check
321
+ if latent.shape[1] != self.latent_dim:
322
+ print(f"⚠️ Batch {batch_idx}: Latent dim mismatch {latent.shape[1]} != {self.latent_dim}")
323
+ continue
324
+
325
+ optimizer.zero_grad()
326
+ pred = self.model(latent)
327
+
328
+ # Final dimension check before loss calculation
329
+ if pred.shape[1] != target.shape[1]:
330
+ min_dim = min(pred.shape[1], target.shape[1])
331
+ pred = pred[:, :min_dim]
332
+ target = target[:, :min_dim]
333
+ if batch_idx == 0: # Only warn once
334
+ print(f"⚠️ Truncating to min dimension: {min_dim}")
335
+
336
+ loss = loss_fn(pred, target)
337
+ loss.backward()
338
+
339
+ # Gradient clipping for stability
340
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
341
+ optimizer.step()
342
+
343
+ total_loss += loss.item()
344
+
345
+ return total_loss / len(train_loader)
346
+
347
+ def _validate_epoch(self, val_loader, loss_fn):
348
+ """Validate for one epoch with dimension safety"""
349
+ self.model.eval()
350
+ total_loss = 0
351
+
352
+ with torch.no_grad():
353
+ for batch_idx, (latent, target) in enumerate(val_loader):
354
+ latent = latent.to(self.device)
355
+ target = target.to(self.device)
356
+
357
+ pred = self.model(latent)
358
+
359
+ # Dimension safety
360
+ if pred.shape[1] != target.shape[1]:
361
+ min_dim = min(pred.shape[1], target.shape[1])
362
+ pred = pred[:, :min_dim]
363
+ target = target[:, :min_dim]
364
+
365
+ loss = loss_fn(pred, target)
366
+ total_loss += loss.item()
367
+
368
+ return total_loss / len(val_loader)
369
+
370
+ def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
371
+ """Save model checkpoint"""
372
+ torch.save({
373
+ 'epoch': epoch,
374
+ 'model_state_dict': self.model.state_dict(),
375
+ 'optimizer_state_dict': optimizer.state_dict(),
376
+ 'scheduler_state_dict': scheduler.state_dict(),
377
+ 'best_val_loss': best_loss,
378
+ 'training_history': history,
379
+ 'model_config': {
380
+ 'latent_dim': self.latent_dim,
381
+ 'gene_dim': self.gene_dim,
382
+ 'hidden_dim': self.hidden_dim
383
+ }
384
+ }, path)
385
+
386
+ def predict(self, latent_data: np.ndarray, batch_size: int = 32) -> np.ndarray:
387
+ """
388
+ Predict gene expression from latent variables
389
+
390
+ Args:
391
+ latent_data: Latent variables [n_samples, latent_dim]
392
+ batch_size: Prediction batch size
393
+
394
+ Returns:
395
+ expression: Predicted expression [n_samples, gene_dim]
396
+ """
397
+ if not self.is_trained:
398
+ warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")
399
+
400
+ self.model.eval()
401
+
402
+ # Input validation
403
+ if latent_data.shape[1] != self.latent_dim:
404
+ raise ValueError(f"Latent dimension mismatch: expected {self.latent_dim}, got {latent_data.shape[1]}")
405
+
406
+ if isinstance(latent_data, np.ndarray):
407
+ latent_data = torch.FloatTensor(latent_data)
408
+
409
+ # Predict in batches to save memory
410
+ predictions = []
411
+ with torch.no_grad():
412
+ for i in range(0, len(latent_data), batch_size):
413
+ batch_latent = latent_data[i:i+batch_size].to(self.device)
414
+ batch_pred = self.model(batch_latent)
415
+ predictions.append(batch_pred.cpu())
416
+
417
+ return torch.cat(predictions).numpy()
418
+
419
+ def load_model(self, model_path: str):
420
+ """Load pre-trained model"""
421
+ checkpoint = torch.load(model_path, map_location=self.device)
422
+
423
+ # Check model configuration
424
+ if 'model_config' in checkpoint:
425
+ config = checkpoint['model_config']
426
+ if (config['latent_dim'] != self.latent_dim or
427
+ config['gene_dim'] != self.gene_dim):
428
+ print("⚠️ Model configuration mismatch. Reinitializing model.")
429
+ self.model = self._build_model()
430
+ self.model.to(self.device)
431
+
432
+ self.model.load_state_dict(checkpoint['model_state_dict'])
433
+ self.is_trained = True
434
+ self.training_history = checkpoint.get('training_history')
435
+ self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
436
+
437
+ print(f"✅ Model loaded successfully!")
438
+ print(f"🏆 Best validation loss: {self.best_val_loss:.4f}")
439
+
440
+ def get_model_info(self) -> Dict:
441
+ """Get model information"""
442
+ return {
443
+ 'is_trained': self.is_trained,
444
+ 'best_val_loss': self.best_val_loss,
445
+ 'parameters': sum(p.numel() for p in self.model.parameters()),
446
+ 'latent_dim': self.latent_dim,
447
+ 'gene_dim': self.gene_dim,
448
+ 'hidden_dim': self.hidden_dim,
449
+ 'device': str(self.device)
450
+ }
451
+ '''
452
+ # Example usage
453
+ def example_usage():
454
+ """Example demonstration with dimension safety"""
455
+
456
+ # 1. Initialize decoder
457
+ decoder = TranscriptomeDecoder(
458
+ latent_dim=100,
459
+ gene_dim=2000, # Reduced for example
460
+ hidden_dim=256
461
+ )
462
+
463
+ # 2. Generate example data with correct dimensions
464
+ n_samples = 1000
465
+ latent_data = np.random.randn(n_samples, 100).astype(np.float32)
466
+
467
+ # Create simulated expression data
468
+ weights = np.random.randn(100, 2000) * 0.1
469
+ expression_data = np.tanh(latent_data.dot(weights))
470
+ expression_data = np.maximum(expression_data, 0) # Non-negative
471
+
472
+ print(f"📊 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")
473
+
474
+ # 3. Train the model
475
+ history = decoder.train(
476
+ train_latent=latent_data,
477
+ train_expression=expression_data,
478
+ batch_size=32,
479
+ num_epochs=50,
480
+ learning_rate=1e-4
481
+ )
482
+
483
+ # 4. Make predictions
484
+ test_latent = np.random.randn(10, 100).astype(np.float32)
485
+ predictions = decoder.predict(test_latent)
486
+ print(f"🔮 Prediction shape: {predictions.shape}")
487
+
488
+ # 5. Get model info
489
+ info = decoder.get_model_info()
490
+ print(f"\n📋 Model Info:")
491
+ for key, value in info.items():
492
+ print(f" {key}: {value}")
493
+
494
+ return decoder
495
+
496
+ if __name__ == "__main__":
497
+ example_usage()
498
+
499
+ '''
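
The train() method above builds its objective from a mean-squared-error term plus a Poisson-style negative log-likelihood weighted by 0.3 (the safe_loss closure). The snippet below restates that loss as a standalone function on toy tensors; the function name and shapes are illustrative, not part of the package.

import torch
import torch.nn.functional as F

def combined_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Same form as the safe_loss closure in the new train():
    # MSE plus a 0.3-weighted Poisson negative log-likelihood (constant terms dropped).
    mse = F.mse_loss(pred, target)
    poisson = (pred - target * torch.log(pred + 1e-8)).mean()
    return mse + 0.3 * poisson

pred = torch.rand(4, 8) + 0.1   # strictly positive, as the softplus output layer guarantees
target = torch.rand(4, 8)
print(combined_loss(pred, target).item())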
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: SURE-tools
- Version: 2.4.4
+ Version: 2.4.7
  Summary: Succinct Representation of Single Cells
  Home-page: https://github.com/ZengFLab/SURE
  Author: Feng Zeng
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:

  setup(
      name='SURE-tools',
-     version='2.4.4',
+     version='2.4.7',
      description='Succinct Representation of Single Cells',
      long_description=long_description,
      long_description_content_type="text/markdown",
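
The only change to setup.py, mirrored in both PKG-INFO files above, is the version bump from 2.4.4 to 2.4.7. A generic way to confirm which version is installed after upgrading is shown below; this snippet is not part of the package.

from importlib.metadata import version

# Distribution name as declared in setup.py; prints 2.4.7 once the new release is installed
print(version("SURE-tools"))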
@@ -1,529 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.optim as optim
4
- import torch.nn.functional as F
5
- from torch.utils.data import Dataset, DataLoader
6
- import numpy as np
7
- from typing import Dict, List, Tuple, Optional
8
- import matplotlib.pyplot as plt
9
- from tqdm import tqdm
10
- import warnings
11
- warnings.filterwarnings('ignore')
12
-
13
- class TranscriptomeDecoder:
14
- def __init__(self,
15
- latent_dim: int = 100,
16
- gene_dim: int = 60000,
17
- hidden_dim: int = 512, # Reduced for memory efficiency
18
- device: str = None):
19
- """
20
- Whole-transcriptome decoder
21
-
22
- Args:
23
- latent_dim: Latent variable dimension (typically 50-100)
24
- gene_dim: Number of genes (full transcriptome ~60,000)
25
- hidden_dim: Hidden dimension (reduced for memory efficiency)
26
- device: Computation device
27
- """
28
- self.latent_dim = latent_dim
29
- self.gene_dim = gene_dim
30
- self.hidden_dim = hidden_dim
31
- self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
32
-
33
- # Memory optimization settings
34
- self.gradient_checkpointing = True
35
- self.mixed_precision = True
36
-
37
- # Initialize model
38
- self.model = self._build_model()
39
- self.model.to(self.device)
40
-
41
- # Training state
42
- self.is_trained = False
43
- self.training_history = None
44
- self.best_val_loss = float('inf')
45
-
46
- print(f"🚀 TranscriptomeDecoder Initialized:")
47
- print(f" - Latent Dimension: {latent_dim}")
48
- print(f" - Gene Dimension: {gene_dim}")
49
- print(f" - Hidden Dimension: {hidden_dim}")
50
- print(f" - Device: {self.device}")
51
- print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
52
-
53
- class MemoryEfficientBlock(nn.Module):
54
- """Memory-efficient building block with gradient checkpointing"""
55
- def __init__(self, input_dim, output_dim, use_checkpointing=True):
56
- super().__init__()
57
- self.use_checkpointing = use_checkpointing
58
- self.net = nn.Sequential(
59
- nn.Linear(input_dim, output_dim),
60
- nn.BatchNorm1d(output_dim),
61
- nn.GELU(),
62
- nn.Dropout(0.1)
63
- )
64
-
65
- def forward(self, x):
66
- if self.use_checkpointing and self.training:
67
- return torch.utils.checkpoint.checkpoint(self.net, x)
68
- return self.net(x)
69
-
70
- class SparseGeneProjection(nn.Module):
71
- """Sparse gene projection to reduce memory usage"""
72
- def __init__(self, latent_dim, gene_dim, projection_dim=256):
73
- super().__init__()
74
- self.projection_dim = projection_dim
75
- self.gene_embeddings = nn.Parameter(torch.randn(gene_dim, projection_dim) * 0.02)
76
- self.latent_projection = nn.Linear(latent_dim, projection_dim)
77
- self.activation = nn.GELU()
78
-
79
- def forward(self, latent):
80
- # Project latent to gene space efficiently
81
- batch_size = latent.shape[0]
82
- latent_proj = self.latent_projection(latent) # [batch, projection_dim]
83
-
84
- # Efficient matrix multiplication
85
- gene_embeds = self.gene_embeddings.T # [projection_dim, gene_dim]
86
- output = torch.matmul(latent_proj, gene_embeds) # [batch, gene_dim]
87
-
88
- return self.activation(output)
89
-
90
- class ChunkedTransformer(nn.Module):
91
- """Process genes in chunks to reduce memory usage"""
92
- def __init__(self, gene_dim, hidden_dim, chunk_size=1000, num_layers=4):
93
- super().__init__()
94
- self.chunk_size = chunk_size
95
- self.num_chunks = (gene_dim + chunk_size - 1) // chunk_size
96
- self.layers = nn.ModuleList([
97
- nn.Sequential(
98
- nn.Linear(hidden_dim, hidden_dim),
99
- nn.GELU(),
100
- nn.Dropout(0.1),
101
- nn.Linear(hidden_dim, hidden_dim),
102
- ) for _ in range(num_layers)
103
- ])
104
-
105
- def forward(self, x):
106
- # Process in chunks to save memory
107
- batch_size = x.shape[0]
108
- output = torch.zeros_like(x)
109
-
110
- for i in range(self.num_chunks):
111
- start_idx = i * self.chunk_size
112
- end_idx = min((i + 1) * self.chunk_size, x.shape[1])
113
-
114
- chunk = x[:, start_idx:end_idx]
115
- for layer in self.layers:
116
- chunk = layer(chunk) + chunk # Residual connection
117
-
118
- output[:, start_idx:end_idx] = chunk
119
-
120
- return output
121
-
122
- class Decoder(nn.Module):
123
- """Decoder model"""
124
- def __init__(self, latent_dim, gene_dim, hidden_dim):
125
- super().__init__()
126
- self.latent_dim = latent_dim
127
- self.gene_dim = gene_dim
128
- self.hidden_dim = hidden_dim
129
-
130
- # Stage 1: Latent expansion (memory efficient)
131
- self.latent_expansion = nn.Sequential(
132
- nn.Linear(latent_dim, hidden_dim * 2),
133
- nn.GELU(),
134
- nn.Dropout(0.1),
135
- nn.Linear(hidden_dim * 2, hidden_dim),
136
- )
137
-
138
- # Stage 2: Sparse gene projection
139
- self.gene_projection = TranscriptomeDecoder.SparseGeneProjection(
140
- latent_dim, gene_dim, hidden_dim
141
- )
142
-
143
- # Stage 3: Chunked processing
144
- self.chunked_processor = TranscriptomeDecoder.ChunkedTransformer(
145
- gene_dim, hidden_dim, chunk_size=2000, num_layers=3
146
- )
147
-
148
- # Stage 4: Multi-head output with memory efficiency
149
- self.output_heads = nn.ModuleList([
150
- nn.Sequential(
151
- nn.Linear(hidden_dim, hidden_dim // 2),
152
- nn.GELU(),
153
- nn.Linear(hidden_dim // 2, 1)
154
- ) for _ in range(2) # Reduced from 3 to 2 heads
155
- ])
156
-
157
- # Adaptive fusion
158
- self.fusion_gate = nn.Sequential(
159
- nn.Linear(hidden_dim, hidden_dim // 4),
160
- nn.GELU(),
161
- nn.Linear(hidden_dim // 4, len(self.output_heads)),
162
- nn.Softmax(dim=-1)
163
- )
164
-
165
- # Output scaling
166
- self.output_scale = nn.Parameter(torch.ones(1))
167
- self.output_bias = nn.Parameter(torch.zeros(1))
168
-
169
- self.latent_to_gene = nn.Linear(hidden_dim, gene_dim)
170
-
171
- self._init_weights()
172
-
173
- def _init_weights(self):
174
- for module in self.modules():
175
- if isinstance(module, nn.Linear):
176
- nn.init.xavier_uniform_(module.weight)
177
- if module.bias is not None:
178
- nn.init.zeros_(module.bias)
179
-
180
- def forward(self, latent):
181
- batch_size = latent.shape[0]
182
-
183
- # 1. Latent expansion
184
- latent_expanded = self.latent_expansion(latent)
185
-
186
- # 2. Gene projection (memory efficient)
187
- gene_features = self.gene_projection(latent)
188
-
189
- # 3. Add latent information
190
- latent_gene_injection = self.latent_to_gene(latent_expanded)
191
- gene_features = gene_features + latent_gene_injection
192
-
193
- # 4. Chunked processing (memory efficient)
194
- gene_features = self.chunked_processor(gene_features)
195
-
196
- # 5. Multi-head output with chunking
197
- final_output = torch.zeros(batch_size, self.gene_dim, device=latent.device)
198
-
199
- # Process output in chunks
200
- chunk_size = 5000
201
- for i in range(0, self.gene_dim, chunk_size):
202
- end_idx = min(i + chunk_size, self.gene_dim)
203
- chunk = gene_features[:, i:end_idx]
204
-
205
- head_outputs = []
206
- for head in self.output_heads:
207
- head_out = head(chunk).squeeze(-1)
208
- head_outputs.append(head_out)
209
-
210
- # Adaptive fusion
211
- gate_weights = self.fusion_gate(chunk.mean(dim=1, keepdim=True))
212
- gate_weights = gate_weights.unsqueeze(1)
213
-
214
- # Weighted fusion
215
- chunk_output = torch.zeros_like(head_outputs[0])
216
- for j, head_out in enumerate(head_outputs):
217
- chunk_output = chunk_output + gate_weights[:, :, j] * head_out
218
-
219
- final_output[:, i:end_idx] = chunk_output
220
-
221
- # Final activation
222
- final_output = F.softplus(final_output * self.output_scale + self.output_bias)
223
-
224
- return final_output
225
-
226
- def _build_model(self):
227
- """Build model"""
228
- return self.Decoder(self.latent_dim, self.gene_dim, self.hidden_dim)
229
-
230
- def train(self,
231
- train_latent: np.ndarray,
232
- train_expression: np.ndarray,
233
- val_latent: np.ndarray = None,
234
- val_expression: np.ndarray = None,
235
- batch_size: int = 16, # Reduced batch size for memory
236
- num_epochs: int = 100,
237
- learning_rate: float = 1e-4,
238
- checkpoint_path: str = 'transcriptome_decoder.pth'):
239
- """
240
- Memory-efficient training with optimizations
241
-
242
- Args:
243
- train_latent: Training latent variables
244
- train_expression: Training expression data
245
- val_latent: Validation latent variables
246
- val_expression: Validation expression data
247
- batch_size: Reduced batch size for memory constraints
248
- num_epochs: Number of training epochs
249
- learning_rate: Learning rate
250
- checkpoint_path: Model save path
251
- """
252
- print("🚀 Starting Training...")
253
- print(f"📊 Batch size: {batch_size}")
254
-
255
- # Enable memory optimizations
256
- torch.backends.cudnn.benchmark = True
257
- if self.mixed_precision:
258
- scaler = torch.cuda.amp.GradScaler()
259
-
260
- # Data preparation
261
- train_dataset = self._create_dataset(train_latent, train_expression)
262
-
263
- if val_latent is not None and val_expression is not None:
264
- val_dataset = self._create_dataset(val_latent, val_expression)
265
- else:
266
- # Auto split
267
- train_size = int(0.9 * len(train_dataset))
268
- val_size = len(train_dataset) - train_size
269
- train_dataset, val_dataset = torch.utils.data.random_split(
270
- train_dataset, [train_size, val_size]
271
- )
272
-
273
- train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
274
- pin_memory=True, num_workers=2)
275
- val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
276
- pin_memory=True, num_workers=2)
277
-
278
- # Optimizer with memory-friendly settings
279
- optimizer = optim.AdamW(
280
- self.model.parameters(),
281
- lr=learning_rate,
282
- weight_decay=0.01,
283
- betas=(0.9, 0.999)
284
- )
285
-
286
- # Learning rate scheduler
287
- scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
288
-
289
- # Loss function
290
- criterion = nn.MSELoss()
291
-
292
- # Training history
293
- history = {
294
- 'train_loss': [], 'val_loss': [],
295
- 'learning_rate': [], 'memory_usage': []
296
- }
297
-
298
- best_val_loss = float('inf')
299
-
300
- for epoch in range(1, num_epochs + 1):
301
- print(f"\n📍 Epoch {epoch}/{num_epochs}")
302
-
303
- # Training phase with memory monitoring
304
- train_loss = self._train_epoch(
305
- train_loader, optimizer, criterion, scaler if self.mixed_precision else None
306
- )
307
-
308
- # Validation phase
309
- val_loss = self._validate_epoch(val_loader, criterion)
310
-
311
- # Update scheduler
312
- scheduler.step()
313
-
314
- # Record history
315
- history['train_loss'].append(train_loss)
316
- history['val_loss'].append(val_loss)
317
- history['learning_rate'].append(optimizer.param_groups[0]['lr'])
318
-
319
- # Memory usage tracking
320
- if torch.cuda.is_available():
321
- memory_used = torch.cuda.memory_allocated() / 1024**3 # GB
322
- history['memory_usage'].append(memory_used)
323
- print(f"💾 GPU Memory: {memory_used:.1f}GB / 20GB")
324
-
325
- print(f"📊 Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
326
- print(f"⚡ Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")
327
-
328
- # Save best model
329
- if val_loss < best_val_loss:
330
- best_val_loss = val_loss
331
- self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
332
- print("💾 Best model saved!")
333
-
334
- self.is_trained = True
335
- self.training_history = history
336
- self.best_val_loss = best_val_loss
337
-
338
- print(f"\n🎉 Training completed! Best validation loss: {best_val_loss:.4f}")
339
- return history
340
-
341
- def _train_epoch(self, train_loader, optimizer, criterion, scaler=None):
342
- """Training epoch"""
343
- self.model.train()
344
- total_loss = 0
345
-
346
- pbar = tqdm(train_loader, desc='Training')
347
- for latent, target in pbar:
348
- latent = latent.to(self.device, non_blocking=True)
349
- target = target.to(self.device, non_blocking=True)
350
-
351
- optimizer.zero_grad(set_to_none=True) # Memory optimization
352
-
353
- if scaler: # Mixed precision training
354
- with torch.cuda.amp.autocast():
355
- pred = self.model(latent)
356
- loss = criterion(pred, target)
357
-
358
- scaler.scale(loss).backward()
359
- scaler.step(optimizer)
360
- scaler.update()
361
- else:
362
- pred = self.model(latent)
363
- loss = criterion(pred, target)
364
- loss.backward()
365
- optimizer.step()
366
-
367
- total_loss += loss.item()
368
- pbar.set_postfix({'Loss': f'{loss.item():.4f}'})
369
-
370
- # Clear memory
371
- del pred, loss
372
- if torch.cuda.is_available():
373
- torch.cuda.empty_cache()
374
-
375
- return total_loss / len(train_loader)
376
-
377
- def _validate_epoch(self, val_loader, criterion):
378
- """Validation"""
379
- self.model.eval()
380
- total_loss = 0
381
-
382
- with torch.no_grad():
383
- for latent, target in val_loader:
384
- latent = latent.to(self.device, non_blocking=True)
385
- target = target.to(self.device, non_blocking=True)
386
-
387
- pred = self.model(latent)
388
- loss = criterion(pred, target)
389
- total_loss += loss.item()
390
-
391
- # Clear memory
392
- del pred, loss
393
-
394
- return total_loss / len(val_loader)
395
-
396
- def _create_dataset(self, latent_data, expression_data):
397
- """Create dataset"""
398
- class EfficientDataset(Dataset):
399
- def __init__(self, latent, expression):
400
- self.latent = torch.FloatTensor(latent)
401
- self.expression = torch.FloatTensor(expression)
402
-
403
- def __len__(self):
404
- return len(self.latent)
405
-
406
- def __getitem__(self, idx):
407
- return self.latent[idx], self.expression[idx]
408
-
409
- return EfficientDataset(latent_data, expression_data)
410
-
411
- def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
412
- """Save checkpoint"""
413
- torch.save({
414
- 'epoch': epoch,
415
- 'model_state_dict': self.model.state_dict(),
416
- 'optimizer_state_dict': optimizer.state_dict(),
417
- 'scheduler_state_dict': scheduler.state_dict(),
418
- 'best_val_loss': best_loss,
419
- 'training_history': history,
420
- 'model_config': {
421
- 'latent_dim': self.latent_dim,
422
- 'gene_dim': self.gene_dim,
423
- 'hidden_dim': self.hidden_dim
424
- }
425
- }, path)
426
-
427
- def predict(self, latent_data: np.ndarray, batch_size: int = 8) -> np.ndarray:
428
- """
429
- Prediction
430
-
431
- Args:
432
- latent_data: Latent variables [n_samples, latent_dim]
433
- batch_size: Prediction batch size for memory control
434
-
435
- Returns:
436
- expression: Predicted expression [n_samples, gene_dim]
437
- """
438
- if not self.is_trained:
439
- warnings.warn("Model not trained. Predictions may be inaccurate.")
440
-
441
- self.model.eval()
442
-
443
- if isinstance(latent_data, np.ndarray):
444
- latent_data = torch.FloatTensor(latent_data)
445
-
446
- # Predict in batches to save memory
447
- predictions = []
448
- with torch.no_grad():
449
- for i in range(0, len(latent_data), batch_size):
450
- batch_latent = latent_data[i:i+batch_size].to(self.device)
451
- batch_pred = self.model(batch_latent)
452
- predictions.append(batch_pred.cpu())
453
-
454
- # Clear memory
455
- del batch_pred
456
- if torch.cuda.is_available():
457
- torch.cuda.empty_cache()
458
-
459
- return torch.cat(predictions).numpy()
460
-
461
- def load_model(self, model_path: str):
462
- """Load pre-trained model"""
463
- checkpoint = torch.load(model_path, map_location=self.device)
464
- self.model.load_state_dict(checkpoint['model_state_dict'])
465
- self.is_trained = True
466
- self.training_history = checkpoint.get('training_history')
467
- self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
468
- print(f"✅ Model loaded! Best validation loss: {self.best_val_loss:.4f}")
469
-
470
- def get_memory_info(self) -> Dict:
471
- """Get memory usage information"""
472
- if torch.cuda.is_available():
473
- memory_allocated = torch.cuda.memory_allocated() / 1024**3
474
- memory_reserved = torch.cuda.memory_reserved() / 1024**3
475
- return {
476
- 'allocated_gb': memory_allocated,
477
- 'reserved_gb': memory_reserved,
478
- 'available_gb': 20 - memory_allocated,
479
- 'utilization_percent': (memory_allocated / 20) * 100
480
- }
481
- return {'available_gb': 'N/A (CPU mode)'}
482
-
483
- '''
484
- # Example usage with memory monitoring
485
- def example_usage():
486
- """Memory-efficient example"""
487
-
488
- # 1. Initialize memory-efficient decoder
489
- decoder = TranscriptomeDecoder(
490
- latent_dim=100,
491
- gene_dim=2000, # Reduced for example
492
- hidden_dim=256 # Reduced for memory
493
- )
494
-
495
- # Check memory info
496
- memory_info = decoder.get_memory_info()
497
- print(f"📊 Memory Info: {memory_info}")
498
-
499
- # 2. Generate example data
500
- n_samples = 500 # Reduced for memory
501
- latent_data = np.random.randn(n_samples, 100).astype(np.float32)
502
- expression_data = np.random.randn(n_samples, 2000).astype(np.float32)
503
- expression_data = np.maximum(expression_data, 0) # Non-negative
504
-
505
- print(f"📈 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")
506
-
507
- # 3. Train with memory monitoring
508
- history = decoder.train(
509
- train_latent=latent_data,
510
- train_expression=expression_data,
511
- batch_size=8, # Small batch for memory
512
- num_epochs=20 # Reduced for example
513
- )
514
-
515
- # 4. Memory-efficient prediction
516
- test_latent = np.random.randn(5, 100).astype(np.float32)
517
- predictions = decoder.predict(test_latent, batch_size=2)
518
- print(f"🔮 Prediction shape: {predictions.shape}")
519
-
520
- # 5. Final memory check
521
- final_memory = decoder.get_memory_info()
522
- print(f"💾 Final memory usage: {final_memory}")
523
-
524
- return decoder
525
-
526
- if __name__ == "__main__":
527
- example_usage()
528
-
529
- '''