SURE-tools 2.2.24__py3-none-any.whl → 2.4.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,511 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader
+ import numpy as np
+ from typing import Dict, Optional
+ import warnings
+ warnings.filterwarnings('ignore')
+
+ class TranscriptomeDecoder:
+     """Transcriptome decoder"""
+
+     def __init__(self,
+                  latent_dim: int = 100,
+                  gene_dim: int = 60000,
+                  hidden_dim: int = 512,
+                  device: Optional[str] = None):
+         """
+         Simple but powerful decoder for latent-to-transcriptome mapping
+
+         Args:
+             latent_dim: Latent variable dimension (typically 50-100)
+             gene_dim: Number of genes (full transcriptome, ~60,000)
+             hidden_dim: Hidden layer dimension
+             device: Computation device ('cuda' or 'cpu'; auto-detected if None)
+         """
+         self.latent_dim = latent_dim
+         self.gene_dim = gene_dim
+         self.hidden_dim = hidden_dim
+         self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+
+         # Initialize model
+         self.model = self._build_model()
+         self.model.to(self.device)
+
+         # Training state
+         self.is_trained = False
+         self.training_history = None
+         self.best_val_loss = float('inf')
+
+         print("🚀 TranscriptomeDecoder initialized:")
+         print(f" - Latent Dimension: {latent_dim}")
+         print(f" - Gene Dimension: {gene_dim}")
+         print(f" - Hidden Dimension: {hidden_dim}")
+         print(f" - Device: {self.device}")
+         print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
+
+     class Decoder(nn.Module):
+         """Memory-efficient decoder architecture with dimension handling"""
+
+         def __init__(self, latent_dim: int, gene_dim: int, hidden_dim: int):
+             super().__init__()
+             self.latent_dim = latent_dim
+             self.gene_dim = gene_dim
+             self.hidden_dim = hidden_dim
+
+             # Stage 1: Latent variable expansion
+             self.latent_expansion = nn.Sequential(
+                 nn.Linear(latent_dim, hidden_dim * 2),
+                 nn.BatchNorm1d(hidden_dim * 2),
+                 nn.GELU(),
+                 nn.Dropout(0.1),
+                 nn.Linear(hidden_dim * 2, hidden_dim),
+                 nn.BatchNorm1d(hidden_dim),
+                 nn.GELU(),
+             )
+
+             # Stage 2: Direct projection to gene dimension (simpler approach)
+             self.gene_projector = nn.Sequential(
+                 nn.Linear(hidden_dim, hidden_dim * 2),
+                 nn.GELU(),
+                 nn.Dropout(0.1),
+                 nn.Linear(hidden_dim * 2, gene_dim),  # Direct projection to gene_dim
+             )
+
+             # Stage 3: Lightweight gene interaction
+             self.gene_interaction = nn.Sequential(
+                 nn.Conv1d(1, 32, kernel_size=3, padding=1),
+                 nn.GELU(),
+                 nn.Dropout1d(0.1),
+                 nn.Conv1d(32, 1, kernel_size=3, padding=1),
+             )
+
+             # Output scaling
+             self.output_scale = nn.Parameter(torch.ones(1))
+             self.output_bias = nn.Parameter(torch.zeros(1))
+
+             self._init_weights()
+
+         def _init_weights(self):
+             """Weight initialization"""
+             for module in self.modules():
+                 if isinstance(module, nn.Linear):
+                     nn.init.xavier_uniform_(module.weight)
+                     if module.bias is not None:
+                         nn.init.zeros_(module.bias)
+                 elif isinstance(module, nn.Conv1d):
+                     nn.init.kaiming_uniform_(module.weight)
+                     if module.bias is not None:
+                         nn.init.zeros_(module.bias)
+
+         def forward(self, latent: torch.Tensor) -> torch.Tensor:
+             # 1. Expand latent variables
+             latent_features = self.latent_expansion(latent)  # [batch_size, hidden_dim]
+
+             # 2. Direct projection to gene dimension
+             gene_output = self.gene_projector(latent_features)  # [batch_size, gene_dim]
+
+             # 3. Gene interaction with dimension safety
+             if self.gene_dim > 1:  # Only apply if gene_dim > 1
+                 gene_output = gene_output.unsqueeze(1)  # [batch_size, 1, gene_dim]
+                 interaction_output = self.gene_interaction(gene_output)  # [batch_size, 1, gene_dim]
+                 gene_output = gene_output + interaction_output  # Residual connection
+                 gene_output = gene_output.squeeze(1)  # [batch_size, gene_dim]
+
+             # 4. Final activation (ensure non-negative output)
+             gene_output = F.softplus(gene_output * self.output_scale + self.output_bias)
+
+             return gene_output
+
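Taken together, the nested Decoder maps a latent batch through three stages: expansion to hidden_dim, a direct linear projection to gene_dim, and a lightweight Conv1d refinement over the gene axis (the gene vector is treated as a one-channel sequence) added back as a residual, followed by a softplus that keeps expression non-negative. The following minimal sketch traces the shapes with deliberately tiny, illustrative dimensions (not the library defaults) and assumes the class above is in scope:

    import torch

    # Tiny dimensions chosen only to make the shape flow easy to follow
    dec = TranscriptomeDecoder.Decoder(latent_dim=16, gene_dim=200, hidden_dim=32)
    dec.eval()  # put BatchNorm/Dropout into inference mode

    z = torch.randn(4, 16)              # [batch, latent_dim]
    with torch.no_grad():
        y = dec(z)                      # [batch, gene_dim]

    print(y.shape)                      # torch.Size([4, 200])
    print(bool((y >= 0).all()))         # True: softplus output is non-negative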
+     def _build_model(self):
+         """Build the decoder model"""
+         return self.Decoder(self.latent_dim, self.gene_dim, self.hidden_dim)
+
+     def _create_dataset(self, latent_data, expression_data):
+         """Create dataset with dimension validation"""
+         # Capture the decoder's expected dimensions so the nested class can
+         # validate against them (it has no latent_dim/gene_dim of its own).
+         expected_latent_dim = self.latent_dim
+         expected_gene_dim = self.gene_dim
+
+         class SimpleDataset(Dataset):
+             def __init__(self, latent, expression):
+                 # Ensure dimensions match
+                 assert latent.shape[0] == expression.shape[0], "Sample count mismatch"
+                 assert latent.shape[1] == expected_latent_dim, f"Latent dim mismatch: expected {expected_latent_dim}, got {latent.shape[1]}"
+                 assert expression.shape[1] == expected_gene_dim, f"Gene dim mismatch: expected {expected_gene_dim}, got {expression.shape[1]}"
+
+                 self.latent = torch.FloatTensor(latent)
+                 self.expression = torch.FloatTensor(expression)
+
+             def __len__(self):
+                 return len(self.latent)
+
+             def __getitem__(self, idx):
+                 return self.latent[idx], self.expression[idx]
+
+         return SimpleDataset(latent_data, expression_data)
+
+     def train(self,
+               train_latent: np.ndarray,
+               train_expression: np.ndarray,
+               val_latent: Optional[np.ndarray] = None,
+               val_expression: Optional[np.ndarray] = None,
+               batch_size: int = 32,
+               num_epochs: int = 100,
+               learning_rate: float = 1e-4,
+               checkpoint_path: str = 'transcriptome_decoder.pth'):
+         """
+         Train the decoder model with dimension safety
+
+         Args:
+             train_latent: Training latent variables [n_samples, latent_dim]
+             train_expression: Training expression data [n_samples, gene_dim]
+             val_latent: Validation latent variables (optional)
+             val_expression: Validation expression data (optional)
+             batch_size: Batch size (kept small to limit memory use)
+             num_epochs: Number of training epochs
+             learning_rate: Learning rate
+             checkpoint_path: Path to save the best model
+         """
+         print("🚀 Starting training...")
+
+         # Dimension validation
+         self._validate_data_dimensions(train_latent, train_expression, "Training")
+         if val_latent is not None and val_expression is not None:
+             self._validate_data_dimensions(val_latent, val_expression, "Validation")
+
+         # Data preparation
+         train_dataset = self._create_safe_dataset(train_latent, train_expression)
+
+         if val_latent is not None and val_expression is not None:
+             val_dataset = self._create_safe_dataset(val_latent, val_expression)
+             train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
+             val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
+             print(f"📈 Using provided validation data: {len(val_dataset)} samples")
+         else:
+             # Auto split
+             train_size = int(0.9 * len(train_dataset))
+             val_size = len(train_dataset) - train_size
+             train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
+             train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=2)
+             val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=2)
+             print(f"📈 Auto-split validation: {val_size} samples")
+
+         print(f"📊 Training samples: {len(train_loader.dataset)}")
+         print(f"📊 Validation samples: {len(val_loader.dataset)}")
+         print(f"📊 Batch size: {batch_size}")
+
+         # Optimizer configuration
+         optimizer = optim.AdamW(
+             self.model.parameters(),
+             lr=learning_rate,
+             weight_decay=0.01,
+             betas=(0.9, 0.999)
+         )
+
+         # Learning rate scheduler
+         scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
+
+         # Loss function with dimension safety
+         def safe_loss(pred, target):
+             # Ensure dimensions match
+             if pred.shape != target.shape:
+                 print(f"⚠️ Dimension mismatch: pred {pred.shape}, target {target.shape}")
+                 # Truncate to the smaller dimension (safety measure)
+                 min_dim = min(pred.shape[1], target.shape[1])
+                 pred = pred[:, :min_dim]
+                 target = target[:, :min_dim]
+
+             def correlation_loss(pred, target):
+                 pred_centered = pred - pred.mean(dim=1, keepdim=True)
+                 target_centered = target - target.mean(dim=1, keepdim=True)
+
+                 correlation = (pred_centered * target_centered).sum(dim=1) / (
+                     torch.sqrt(torch.sum(pred_centered ** 2, dim=1)) *
+                     torch.sqrt(torch.sum(target_centered ** 2, dim=1)) + 1e-8
+                 )
+
+                 return 1 - correlation.mean()
+
+             mse_loss = F.mse_loss(pred, target)
+             poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()
+             corr_loss = correlation_loss(pred, target)
+             return mse_loss + 0.5 * poisson_loss + 0.3 * corr_loss
+
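The safe_loss above combines three terms: mean squared error on expression values, a Poisson-style negative log-likelihood term, pred - target * log(pred + 1e-8), suited to count-like expression data, and a correlation term, one minus the per-sample Pearson correlation, which rewards matching the shape of each expression profile; the 0.5 and 0.3 weights come straight from the return statement. A small standalone sketch of the correlation term (independent of the class, assuming only PyTorch is installed):

    import torch

    target = torch.rand(4, 10)
    pred = 2.0 * target + 3.0  # a positive linear rescaling of the target

    # Same centering and normalization as correlation_loss above
    pred_c = pred - pred.mean(dim=1, keepdim=True)
    target_c = target - target.mean(dim=1, keepdim=True)
    corr = (pred_c * target_c).sum(dim=1) / (
        pred_c.norm(dim=1) * target_c.norm(dim=1) + 1e-8
    )

    # ~0: the profile shape matches even though the absolute scale is off
    print(float(1 - corr.mean()))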
+         # Training history
+         history = {
+             'train_loss': [],
+             'val_loss': [],
+             'learning_rate': []
+         }
+
+         best_val_loss = float('inf')
+         patience = 15
+         patience_counter = 0
+
+         print("\n📈 Starting training loop...")
+         for epoch in range(1, num_epochs + 1):
+             # Training phase
+             train_loss = self._train_epoch(train_loader, optimizer, safe_loss)
+
+             # Validation phase
+             val_loss = self._validate_epoch(val_loader, safe_loss)
+
+             # Update scheduler
+             scheduler.step()
+             current_lr = scheduler.get_last_lr()[0]
+
+             # Record history
+             history['train_loss'].append(train_loss)
+             history['val_loss'].append(val_loss)
+             history['learning_rate'].append(current_lr)
+
+             # Print progress
+             if epoch % 5 == 0 or epoch == 1:
+                 print(f"📍 Epoch {epoch:3d}/{num_epochs} | "
+                       f"Train Loss: {train_loss:.4f} | "
+                       f"Val Loss: {val_loss:.4f} | "
+                       f"LR: {current_lr:.2e}")
+
+             # Early stopping and model saving
+             if val_loss < best_val_loss:
+                 best_val_loss = val_loss
+                 patience_counter = 0
+                 self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
+                 if epoch % 10 == 0:
+                     print(f"💾 Best model saved (Val Loss: {best_val_loss:.4f})")
+             else:
+                 patience_counter += 1
+                 if patience_counter >= patience:
+                     print(f"🛑 Early stopping at epoch {epoch}")
+                     break
+
+         # Training completed
+         self.is_trained = True
+         self.training_history = history
+         self.best_val_loss = best_val_loss
+
+         print("\n🎉 Training completed!")
+         print(f"🏆 Best validation loss: {best_val_loss:.4f}")
+         print(f"📊 Final training loss: {history['train_loss'][-1]:.4f}")
+
+         return history
+
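train() returns the history dictionary recorded above, with per-epoch 'train_loss', 'val_loss', and 'learning_rate' lists. A minimal sketch for inspecting the curves afterwards; plotting assumes matplotlib is available (it is not a dependency of this module), and decoder, train_latent, and train_expression are assumed to be defined as in the example at the end of the file:

    import matplotlib.pyplot as plt

    history = decoder.train(train_latent, train_expression, num_epochs=50)
    plt.plot(history['train_loss'], label='train')
    plt.plot(history['val_loss'], label='validation')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend()
    plt.show()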
+     def _validate_data_dimensions(self, latent_data, expression_data, data_type):
+         """Validate input data dimensions"""
+         assert latent_data.shape[1] == self.latent_dim, (
+             f"{data_type} latent dimension mismatch: expected {self.latent_dim}, got {latent_data.shape[1]}")
+         assert expression_data.shape[1] == self.gene_dim, (
+             f"{data_type} gene dimension mismatch: expected {self.gene_dim}, got {expression_data.shape[1]}")
+         assert latent_data.shape[0] == expression_data.shape[0], (
+             f"{data_type} sample count mismatch: latent {latent_data.shape[0]}, expression {expression_data.shape[0]}")
+         print(f"✅ {data_type} data dimensions validated")
+
+     def _create_safe_dataset(self, latent_data, expression_data):
+         """Create dataset with safety checks"""
+         class SafeDataset(Dataset):
+             def __init__(self, latent, expression):
+                 self.latent = torch.FloatTensor(latent)
+                 self.expression = torch.FloatTensor(expression)
+
+                 # Safety check
+                 if self.latent.shape[0] != self.expression.shape[0]:
+                     raise ValueError(f"Sample count mismatch: latent {self.latent.shape[0]}, expression {self.expression.shape[0]}")
+
+             def __len__(self):
+                 return len(self.latent)
+
+             def __getitem__(self, idx):
+                 return self.latent[idx], self.expression[idx]
+
+         return SafeDataset(latent_data, expression_data)
+
+     def _train_epoch(self, train_loader, optimizer, loss_fn):
+         """Train for one epoch with dimension safety"""
+         self.model.train()
+         total_loss = 0
+
+         for batch_idx, (latent, target) in enumerate(train_loader):
+             latent = latent.to(self.device)
+             target = target.to(self.device)
+
+             # Dimension check
+             if latent.shape[1] != self.latent_dim:
+                 print(f"⚠️ Batch {batch_idx}: Latent dim mismatch {latent.shape[1]} != {self.latent_dim}")
+                 continue
+
+             optimizer.zero_grad()
+             pred = self.model(latent)
+
+             # Final dimension check before loss calculation
+             if pred.shape[1] != target.shape[1]:
+                 min_dim = min(pred.shape[1], target.shape[1])
+                 pred = pred[:, :min_dim]
+                 target = target[:, :min_dim]
+                 if batch_idx == 0:  # Only warn once
+                     print(f"⚠️ Truncating to min dimension: {min_dim}")
+
+             loss = loss_fn(pred, target)
+             loss.backward()
+
+             # Gradient clipping for stability
+             torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+             optimizer.step()
+
+             total_loss += loss.item()
+
+         return total_loss / len(train_loader)
+
+     def _validate_epoch(self, val_loader, loss_fn):
+         """Validate for one epoch with dimension safety"""
+         self.model.eval()
+         total_loss = 0
+
+         with torch.no_grad():
+             for batch_idx, (latent, target) in enumerate(val_loader):
+                 latent = latent.to(self.device)
+                 target = target.to(self.device)
+
+                 pred = self.model(latent)
+
+                 # Dimension safety
+                 if pred.shape[1] != target.shape[1]:
+                     min_dim = min(pred.shape[1], target.shape[1])
+                     pred = pred[:, :min_dim]
+                     target = target[:, :min_dim]
+
+                 loss = loss_fn(pred, target)
+                 total_loss += loss.item()
+
+         return total_loss / len(val_loader)
+
+     def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
+         """Save model checkpoint"""
+         torch.save({
+             'epoch': epoch,
+             'model_state_dict': self.model.state_dict(),
+             'optimizer_state_dict': optimizer.state_dict(),
+             'scheduler_state_dict': scheduler.state_dict(),
+             'best_val_loss': best_loss,
+             'training_history': history,
+             'model_config': {
+                 'latent_dim': self.latent_dim,
+                 'gene_dim': self.gene_dim,
+                 'hidden_dim': self.hidden_dim
+             }
+         }, path)
+
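_save_checkpoint writes a single .pth file holding the model weights, optimizer and scheduler state, the training history, and a 'model_config' block with the three dimensions. A quick sketch for inspecting a checkpoint produced by train() at its default checkpoint_path (assuming such a file exists):

    import torch

    ckpt = torch.load('transcriptome_decoder.pth', map_location='cpu')
    print(sorted(ckpt.keys()))
    # ['best_val_loss', 'epoch', 'model_config', 'model_state_dict',
    #  'optimizer_state_dict', 'scheduler_state_dict', 'training_history']
    print(ckpt['model_config'])  # e.g. {'latent_dim': 100, 'gene_dim': 60000, 'hidden_dim': 512}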
+     def predict(self, latent_data: np.ndarray, batch_size: int = 32) -> np.ndarray:
+         """
+         Predict gene expression from latent variables
+
+         Args:
+             latent_data: Latent variables [n_samples, latent_dim]
+             batch_size: Prediction batch size
+
+         Returns:
+             expression: Predicted expression [n_samples, gene_dim]
+         """
+         if not self.is_trained:
+             warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")
+
+         self.model.eval()
+
+         # Input validation
+         if latent_data.shape[1] != self.latent_dim:
+             raise ValueError(f"Latent dimension mismatch: expected {self.latent_dim}, got {latent_data.shape[1]}")
+
+         if isinstance(latent_data, np.ndarray):
+             latent_data = torch.FloatTensor(latent_data)
+
+         # Predict in batches to save memory
+         predictions = []
+         with torch.no_grad():
+             for i in range(0, len(latent_data), batch_size):
+                 batch_latent = latent_data[i:i+batch_size].to(self.device)
+                 batch_pred = self.model(batch_latent)
+                 predictions.append(batch_pred.cpu())
+
+         return torch.cat(predictions).numpy()
+
+     def load_model(self, model_path: str):
+         """Load pre-trained model"""
+         checkpoint = torch.load(model_path, map_location=self.device)
+
+         # Check model configuration; rebuild with the checkpoint's dimensions if they
+         # differ, otherwise load_state_dict would fail on mismatched shapes
+         if 'model_config' in checkpoint:
+             config = checkpoint['model_config']
+             if (config['latent_dim'] != self.latent_dim or
+                     config['gene_dim'] != self.gene_dim or
+                     config['hidden_dim'] != self.hidden_dim):
+                 print("⚠️ Model configuration mismatch. Rebuilding model with the checkpoint's configuration.")
+                 self.latent_dim = config['latent_dim']
+                 self.gene_dim = config['gene_dim']
+                 self.hidden_dim = config['hidden_dim']
+                 self.model = self._build_model()
+                 self.model.to(self.device)
+
+         self.model.load_state_dict(checkpoint['model_state_dict'])
+         self.is_trained = True
+         self.training_history = checkpoint.get('training_history')
+         self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
+
+         print("✅ Model loaded successfully!")
+         print(f"🏆 Best validation loss: {self.best_val_loss:.4f}")
+
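Putting load_model and predict together: the sketch below restores a decoder from a checkpoint saved by train() (assuming one exists at the default path) and runs batched prediction. It uses only methods defined above; the array sizes are illustrative.

    import numpy as np

    decoder = TranscriptomeDecoder(latent_dim=100, gene_dim=60000)
    decoder.load_model('transcriptome_decoder.pth')

    latent = np.random.randn(8, 100).astype(np.float32)
    expression = decoder.predict(latent, batch_size=4)

    print(expression.shape)                        # (8, 60000)
    print(decoder.get_model_info()['is_trained'])  # True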
+     def get_model_info(self) -> Dict:
+         """Get model information"""
+         return {
+             'is_trained': self.is_trained,
+             'best_val_loss': self.best_val_loss,
+             'parameters': sum(p.numel() for p in self.model.parameters()),
+             'latent_dim': self.latent_dim,
+             'gene_dim': self.gene_dim,
+             'hidden_dim': self.hidden_dim,
+             'device': str(self.device)
+         }
+
+ '''
+ # Example usage
+ def example_usage():
+     """Example demonstration with dimension safety"""
+
+     # 1. Initialize decoder
+     decoder = TranscriptomeDecoder(
+         latent_dim=100,
+         gene_dim=2000,  # Reduced for example
+         hidden_dim=256
+     )
+
+     # 2. Generate example data with correct dimensions
+     n_samples = 1000
+     latent_data = np.random.randn(n_samples, 100).astype(np.float32)
+
+     # Create simulated expression data
+     weights = np.random.randn(100, 2000) * 0.1
+     expression_data = np.tanh(latent_data.dot(weights))
+     expression_data = np.maximum(expression_data, 0)  # Non-negative
+
+     print(f"📊 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")
+
+     # 3. Train the model
+     history = decoder.train(
+         train_latent=latent_data,
+         train_expression=expression_data,
+         batch_size=32,
+         num_epochs=50,
+         learning_rate=1e-4
+     )
+
+     # 4. Make predictions
+     test_latent = np.random.randn(10, 100).astype(np.float32)
+     predictions = decoder.predict(test_latent)
+     print(f"🔮 Prediction shape: {predictions.shape}")
+
+     # 5. Get model info
+     info = decoder.get_model_info()
+     print("\n📋 Model Info:")
+     for key, value in info.items():
+         print(f" {key}: {value}")
+
+     return decoder
+
+ if __name__ == "__main__":
+     example_usage()
+
+ '''