SURE-tools 2.4.2__py3-none-any.whl → 2.4.13__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Note: this release of SURE-tools has been flagged as potentially problematic.

@@ -0,0 +1,567 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader
+ import numpy as np
+ from typing import Dict, List, Optional, Tuple
+ import warnings
+ warnings.filterwarnings('ignore')
+
+ class SimpleTranscriptomeDecoder:
+     """MLP-based transcriptome decoder for latent to expression mapping"""
+
+     def __init__(self,
+                  latent_dim: int = 100,
+                  gene_dim: int = 60000,
+                  hidden_dims: List[int] = [512, 1024, 2048, 4096],
+                  dropout_rate: float = 0.1,
+                  device: str = None):
+         """
+         Multi-Layer Perceptron based decoder for transcriptome prediction
+
+         Args:
+             latent_dim: Latent variable dimension
+             gene_dim: Number of genes (full transcriptome)
+             hidden_dims: List of hidden layer dimensions
+             dropout_rate: Dropout rate for regularization
+             device: Computation device
+         """
+         self.latent_dim = latent_dim
+         self.gene_dim = gene_dim
+         self.hidden_dims = hidden_dims
+         self.dropout_rate = dropout_rate
+         self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+
+         # Initialize model
+         self.model = self._build_mlp_model()
+         self.model.to(self.device)
+
+         # Training state
+         self.is_trained = False
+         self.training_history = None
+         self.best_val_loss = float('inf')
+
+         print(f"🚀 SimpleTranscriptomeDecoder Initialized:")
+         print(f" - Latent Dimension: {latent_dim}")
+         print(f" - Gene Dimension: {gene_dim}")
+         print(f" - Hidden Dimensions: {hidden_dims}")
+         print(f" - Dropout Rate: {dropout_rate}")
+         print(f" - Device: {self.device}")
+         print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
+
+     class MLPModel(nn.Module):
+         """MLP-based decoder architecture"""
+
+         def __init__(self, latent_dim: int, gene_dim: int, hidden_dims: List[int], dropout_rate: float):
+             super().__init__()
+             self.latent_dim = latent_dim
+             self.gene_dim = gene_dim
+
+             # Build the MLP layers
+             layers = []
+             input_dim = latent_dim
+
+             # Encoder part: expand latent dimension
+             for hidden_dim in hidden_dims:
+                 layers.extend([
+                     nn.Linear(input_dim, hidden_dim),
+                     nn.BatchNorm1d(hidden_dim),
+                     nn.GELU(),
+                     nn.Dropout(dropout_rate)
+                 ])
+                 input_dim = hidden_dim
+
+             # Decoder part: project to gene dimension
+             # Reverse the hidden_dims for decoder
+             decoder_dims = hidden_dims[::-1]
+             for i, hidden_dim in enumerate(decoder_dims[1:], 1):
+                 layers.extend([
+                     nn.Linear(input_dim, hidden_dim),
+                     nn.BatchNorm1d(hidden_dim),
+                     nn.GELU(),
+                     nn.Dropout(dropout_rate)
+                 ])
+                 input_dim = hidden_dim
+
+             # Final projection to gene dimension
+             layers.append(nn.Linear(input_dim, gene_dim))
+
+             self.mlp_layers = nn.Sequential(*layers)
+
+             # Output scaling parameters
+             self.output_scale = nn.Parameter(torch.ones(1))
+             self.output_bias = nn.Parameter(torch.zeros(1))
+
+             self._init_weights()
+
+         def _init_weights(self):
+             """Weight initialization"""
+             for module in self.modules():
+                 if isinstance(module, nn.Linear):
+                     nn.init.xavier_uniform_(module.weight)
+                     if module.bias is not None:
+                         nn.init.zeros_(module.bias)
+
+         def forward(self, x: torch.Tensor) -> torch.Tensor:
+             # Pass through MLP layers
+             output = self.mlp_layers(x)
+
+             # Ensure non-negative output with softplus
+             output = F.softplus(output * self.output_scale + self.output_bias)
+
+             return output
+
+     class ResidualMLPModel(nn.Module):
+         """Residual MLP decoder with skip connections"""
+
+         def __init__(self, latent_dim: int, gene_dim: int, hidden_dims: List[int], dropout_rate: float):
+             super().__init__()
+             self.latent_dim = latent_dim
+             self.gene_dim = gene_dim
+
+             # Build residual blocks
+             self.blocks = nn.ModuleList()
+             input_dim = latent_dim
+
+             for hidden_dim in hidden_dims:
+                 block = self._build_residual_block(input_dim, hidden_dim, dropout_rate)
+                 self.blocks.append(block)
+                 input_dim = hidden_dim
+
+             # Final projection to gene dimension
+             self.final_projection = nn.Sequential(
+                 nn.Linear(input_dim, input_dim // 2),
+                 nn.GELU(),
+                 nn.Dropout(dropout_rate),
+                 nn.Linear(input_dim // 2, gene_dim)
+             )
+
+             # Output parameters
+             self.output_scale = nn.Parameter(torch.ones(1))
+             self.output_bias = nn.Parameter(torch.zeros(1))
+
+             self._init_weights()
+
+         def _build_residual_block(self, input_dim: int, hidden_dim: int, dropout_rate: float) -> nn.Module:
+             """Build a residual block with skip connection"""
+             return nn.Sequential(
+                 nn.Linear(input_dim, hidden_dim),
+                 nn.BatchNorm1d(hidden_dim),
+                 nn.GELU(),
+                 nn.Dropout(dropout_rate),
+                 nn.Linear(hidden_dim, hidden_dim), # Residual path
+                 nn.BatchNorm1d(hidden_dim),
+                 nn.GELU(),
+                 nn.Dropout(dropout_rate),
+             )
+
+         def _init_weights(self):
+             """Weight initialization"""
+             for module in self.modules():
+                 if isinstance(module, nn.Linear):
+                     nn.init.kaiming_uniform_(module.weight, nonlinearity='relu')
+                     if module.bias is not None:
+                         nn.init.zeros_(module.bias)
+
+         def forward(self, x: torch.Tensor) -> torch.Tensor:
+             # Initial projection
+             identity = x
+
+             for block in self.blocks:
+                 # Residual connection
+                 out = block(x)
+                 # Skip connection if dimensions match, otherwise project
+                 if out.shape[1] == identity.shape[1]:
+                     x = out + identity
+                 else:
+                     x = out
+                 identity = x
+
+             # Final projection
+             output = self.final_projection(x)
+             output = F.softplus(output * self.output_scale + self.output_bias)
+
+             return output
+
+     def _build_mlp_model(self):
+         """Build the MLP model - fixed method name conflict"""
+         # Use simple MLP model for stability
+         return self.MLPModel(
+             self.latent_dim,
+             self.gene_dim,
+             self.hidden_dims,
+             self.dropout_rate
+         )
+
+     def train(self,
+               train_latent: np.ndarray,
+               train_expression: np.ndarray,
+               val_latent: np.ndarray = None,
+               val_expression: np.ndarray = None,
+               batch_size: int = 32,
+               num_epochs: int = 100,
+               learning_rate: float = 1e-4,
+               weight_decay: float = 1e-5,
+               checkpoint_path: str = 'mlp_decoder.pth') -> Dict:
+         """
+         Train the MLP decoder model
+
+         Args:
+             train_latent: Training latent variables [n_samples, latent_dim]
+             train_expression: Training expression data [n_samples, gene_dim]
+             val_latent: Validation latent variables (optional)
+             val_expression: Validation expression data (optional)
+             batch_size: Batch size for training
+             num_epochs: Number of training epochs
+             learning_rate: Learning rate
+             weight_decay: Weight decay for regularization
+             checkpoint_path: Path to save the best model
+
+         Returns:
+             Training history dictionary
+         """
+         print("🚀 Starting MLP Decoder Training...")
+
+         # Data validation
+         self._validate_input_data(train_latent, train_expression, "Training")
+         if val_latent is not None and val_expression is not None:
+             self._validate_input_data(val_latent, val_expression, "Validation")
+
+         # Create datasets and data loaders
+         train_dataset = self._create_dataset(train_latent, train_expression)
+
+         if val_latent is not None and val_expression is not None:
+             val_dataset = self._create_dataset(val_latent, val_expression)
+             train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+             val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
+             print(f"📈 Using provided validation data: {len(val_dataset)} samples")
+         else:
+             # Auto split
+             train_size = int(0.9 * len(train_dataset))
+             val_size = len(train_dataset) - train_size
+             train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
+             train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
+             val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
+             print(f"📈 Auto-split validation: {val_size} samples")
+
+         print(f"📊 Training samples: {len(train_loader.dataset)}")
+         print(f"📊 Validation samples: {len(val_loader.dataset)}")
+         print(f"📊 Batch size: {batch_size}")
+
+         # Optimizer configuration
+         optimizer = optim.AdamW(
+             self.model.parameters(),
+             lr=learning_rate,
+             weight_decay=weight_decay,
+             betas=(0.9, 0.999)
+         )
+
+         # Learning rate scheduler
+         scheduler = optim.lr_scheduler.OneCycleLR(
+             optimizer,
+             max_lr=learning_rate * 10,
+             epochs=num_epochs,
+             steps_per_epoch=len(train_loader)
+         )
+
+         # Loss function combining MSE and Poisson loss
+         def combined_loss(pred, target):
+             mse_loss = F.mse_loss(pred, target)
+             poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()
+             correlation_loss = 1 - self._pearson_correlation(pred, target)
+             return mse_loss + 0.3 * poisson_loss + 0.1 * correlation_loss
+
+         # Training history
+         history = {
+             'train_loss': [], 'val_loss': [],
+             'train_mse': [], 'val_mse': [],
+             'train_correlation': [], 'val_correlation': [],
+             'learning_rates': []
+         }
+
+         best_val_loss = float('inf')
+         patience = 20
+         patience_counter = 0
+
+         print("\n📈 Starting training loop...")
+         for epoch in range(1, num_epochs + 1):
+             # Training phase
+             train_metrics = self._train_epoch(train_loader, optimizer, scheduler, combined_loss)
+
+             # Validation phase
+             val_metrics = self._validate_epoch(val_loader, combined_loss)
+
+             # Record history
+             history['train_loss'].append(train_metrics['total_loss'])
+             history['train_mse'].append(train_metrics['mse_loss'])
+             history['train_correlation'].append(train_metrics['correlation'])
+
+             history['val_loss'].append(val_metrics['val_loss'])
+             history['val_mse'].append(val_metrics['val_mse'])
+             history['val_correlation'].append(val_metrics['val_correlation'])
+
+             history['learning_rates'].append(optimizer.param_groups[0]['lr'])
+
+             # Print progress
+             if epoch % 10 == 0 or epoch == 1:
+                 print(f"📍 Epoch {epoch:3d}/{num_epochs} | "
+                       f"Train Loss: {train_metrics['total_loss']:.4f} | "
+                       f"Val Loss: {val_metrics['val_loss']:.4f} | "
+                       f"Correlation: {val_metrics['val_correlation']:.4f} | "
+                       f"LR: {optimizer.param_groups[0]['lr']:.2e}")
+
+             # Early stopping and model saving
+             if val_metrics['val_loss'] < best_val_loss:
+                 best_val_loss = val_metrics['val_loss']
+                 patience_counter = 0
+                 self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
+                 if epoch % 20 == 0:
+                     print(f"💾 Best model saved (Val Loss: {best_val_loss:.4f})")
+             else:
+                 patience_counter += 1
+                 if patience_counter >= patience:
+                     print(f"🛑 Early stopping at epoch {epoch}")
+                     break
+
+         # Training completed
+         self.is_trained = True
+         self.training_history = history
+         self.best_val_loss = best_val_loss
+
+         print(f"\n🎉 Training completed!")
+         print(f"🏆 Best validation loss: {best_val_loss:.4f}")
+
+         return history
+
+     def _validate_input_data(self, latent_data: np.ndarray, expression_data: np.ndarray, data_type: str):
+         """Validate input data dimensions and types"""
+         assert latent_data.shape[1] == self.latent_dim, \
+             f"{data_type} latent dimension mismatch: expected {self.latent_dim}, got {latent_data.shape[1]}"
+         assert expression_data.shape[1] == self.gene_dim, \
+             f"{data_type} gene dimension mismatch: expected {self.gene_dim}, got {expression_data.shape[1]}"
+         assert latent_data.shape[0] == expression_data.shape[0], \
+             f"{data_type} sample count mismatch"
+         print(f"✅ {data_type} data validated: {latent_data.shape[0]} samples")
+
+     def _create_dataset(self, latent_data: np.ndarray, expression_data: np.ndarray) -> Dataset:
+         """Create PyTorch dataset"""
+         class TranscriptomeDataset(Dataset):
+             def __init__(self, latent, expression):
+                 self.latent = torch.FloatTensor(latent)
+                 self.expression = torch.FloatTensor(expression)
+
+             def __len__(self):
+                 return len(self.latent)
+
+             def __getitem__(self, idx):
+                 return self.latent[idx], self.expression[idx]
+
+         return TranscriptomeDataset(latent_data, expression_data)
+
+     def _pearson_correlation(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+         """Calculate Pearson correlation coefficient"""
+         pred_centered = pred - pred.mean(dim=1, keepdim=True)
+         target_centered = target - target.mean(dim=1, keepdim=True)
+
+         numerator = (pred_centered * target_centered).sum(dim=1)
+         denominator = torch.sqrt(torch.sum(pred_centered ** 2, dim=1)) * torch.sqrt(torch.sum(target_centered ** 2, dim=1))
+
+         return (numerator / (denominator + 1e-8)).mean()
+
+     def _train_epoch(self, train_loader, optimizer, scheduler, loss_fn):
+         """Train for one epoch"""
+         self.model.train()
+         total_loss = 0
+         total_mse = 0
+         total_correlation = 0
+
+         for latent, target in train_loader:
+             latent = latent.to(self.device)
+             target = target.to(self.device)
+
+             optimizer.zero_grad()
+             pred = self.model(latent)
+
+             loss = loss_fn(pred, target)
+             loss.backward()
+
+             # Gradient clipping
+             torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+             optimizer.step()
+             scheduler.step()
+
+             # Calculate metrics
+             mse_loss = F.mse_loss(pred, target).item()
+             correlation = self._pearson_correlation(pred, target).item()
+
+             total_loss += loss.item()
+             total_mse += mse_loss
+             total_correlation += correlation
+
+         num_batches = len(train_loader)
+         return {
+             'total_loss': total_loss / num_batches,
+             'mse_loss': total_mse / num_batches,
+             'correlation': total_correlation / num_batches
+         }
+
+     def _validate_epoch(self, val_loader, loss_fn):
+         """Validate for one epoch"""
+         self.model.eval()
+         total_loss = 0
+         total_mse = 0
+         total_correlation = 0
+
+         with torch.no_grad():
+             for latent, target in val_loader:
+                 latent = latent.to(self.device)
+                 target = target.to(self.device)
+
+                 pred = self.model(latent)
+
+                 loss = loss_fn(pred, target)
+                 mse_loss = F.mse_loss(pred, target).item()
+                 correlation = self._pearson_correlation(pred, target).item()
+
+                 total_loss += loss.item()
+                 total_mse += mse_loss
+                 total_correlation += correlation
+
+         num_batches = len(val_loader)
+         return {
+             'val_loss': total_loss / num_batches,
+             'val_mse': total_mse / num_batches,
+             'val_correlation': total_correlation / num_batches
+         }
+
+     def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
+         """Save model checkpoint"""
+         torch.save({
+             'epoch': epoch,
+             'model_state_dict': self.model.state_dict(),
+             'optimizer_state_dict': optimizer.state_dict(),
+             'scheduler_state_dict': scheduler.state_dict(),
+             'best_val_loss': best_loss,
+             'training_history': history,
+             'model_config': {
+                 'latent_dim': self.latent_dim,
+                 'gene_dim': self.gene_dim,
+                 'hidden_dims': self.hidden_dims,
+                 'dropout_rate': self.dropout_rate
+             }
+         }, path)
+
+     def predict(self, latent_data: np.ndarray, batch_size: int = 32) -> np.ndarray:
+         """
+         Predict gene expression from latent variables
+
+         Args:
+             latent_data: Latent variables [n_samples, latent_dim]
+             batch_size: Prediction batch size
+
+         Returns:
+             expression: Predicted expression [n_samples, gene_dim]
+         """
+         if not self.is_trained:
+             warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")
+
+         self.model.eval()
+
+         if isinstance(latent_data, np.ndarray):
+             latent_data = torch.FloatTensor(latent_data)
+
+         # Predict in batches
+         predictions = []
+         with torch.no_grad():
+             for i in range(0, len(latent_data), batch_size):
+                 batch_latent = latent_data[i:i+batch_size].to(self.device)
+                 batch_pred = self.model(batch_latent)
+                 predictions.append(batch_pred.cpu())
+
+         return torch.cat(predictions).numpy()
+
+     def load_model(self, model_path: str):
+         """Load pre-trained model"""
+         checkpoint = torch.load(model_path, map_location=self.device)
+
+         # Check model configuration
+         if 'model_config' in checkpoint:
+             config = checkpoint['model_config']
+             if (config['latent_dim'] != self.latent_dim or
+                 config['gene_dim'] != self.gene_dim):
+                 print("⚠️ Model configuration mismatch. Reinitializing model.")
+                 self.model = self._build_mlp_model()
+                 self.model.to(self.device)
+
+         self.model.load_state_dict(checkpoint['model_state_dict'])
+         self.is_trained = True
+         self.training_history = checkpoint.get('training_history')
+         self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
+
+         print(f"✅ Model loaded successfully!")
+         print(f"🏆 Best validation loss: {self.best_val_loss:.4f}")
+
+     def get_model_info(self) -> Dict:
+         """Get model information"""
+         return {
+             'is_trained': self.is_trained,
+             'best_val_loss': self.best_val_loss,
+             'parameters': sum(p.numel() for p in self.model.parameters()),
+             'latent_dim': self.latent_dim,
+             'gene_dim': self.gene_dim,
+             'hidden_dims': self.hidden_dims,
+             'dropout_rate': self.dropout_rate,
+             'device': str(self.device)
+         }
+
+ '''
+ # Example usage
+ def example_usage():
+     """Example demonstration of MLP decoder"""
+
+     # 1. Initialize decoder
+     decoder = SimpleTranscriptomeDecoder(
+         latent_dim=100,
+         gene_dim=2000, # Reduced for example
+         hidden_dims=[256, 512, 1024], # Progressive expansion
+         dropout_rate=0.1
+     )
+
+     # 2. Generate example data
+     n_samples = 1000
+     latent_data = np.random.randn(n_samples, 100).astype(np.float32)
+
+     # Create simulated expression data
+     weights = np.random.randn(100, 2000) * 0.1
+     expression_data = np.tanh(latent_data.dot(weights))
+     expression_data = np.maximum(expression_data, 0) # Non-negative
+
+     print(f"📊 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")
+
+     # 3. Train the model
+     history = decoder.train(
+         train_latent=latent_data,
+         train_expression=expression_data,
+         batch_size=32,
+         num_epochs=50,
+         learning_rate=1e-4
+     )
+
+     # 4. Make predictions
+     test_latent = np.random.randn(10, 100).astype(np.float32)
+     predictions = decoder.predict(test_latent)
+     print(f"🔮 Prediction shape: {predictions.shape}")
+
+     # 5. Get model info
+     info = decoder.get_model_info()
+     print(f"\n📋 Model Info:")
+     for key, value in info.items():
+         print(f" {key}: {value}")
+
+     return decoder
+
+ if __name__ == "__main__":
+     example_usage()
+
+ '''
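
The added module is usable on its own once train() has produced a checkpoint via _save_checkpoint(). The short sketch below is illustrative only and is not part of the diffed file: the import path and the checkpoint filename 'mlp_decoder.pth' are assumptions that depend on how the installed package actually exposes the class; the dimensions simply reuse the constructor defaults shown above.

import numpy as np
# Hypothetical import path -- adjust to the package's real module layout.
from SURE import SimpleTranscriptomeDecoder

# Rebuild the decoder with the same dimensions used at training time.
decoder = SimpleTranscriptomeDecoder(latent_dim=100, gene_dim=60000)

# Restore the weights saved by train() (placeholder checkpoint path).
decoder.load_model('mlp_decoder.pth')

# Predict expression for a small batch of latent codes.
latent = np.random.randn(5, 100).astype(np.float32)
expression = decoder.predict(latent, batch_size=32)
print(expression.shape)  # expected: (5, 60000)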