SURE-tools 2.3.2__py3-none-any.whl → 2.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of SURE-tools has been flagged as potentially problematic.

SURE/PerturbE.py CHANGED
@@ -349,7 +349,8 @@ class PerturbE(nn.Module):
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -433,7 +434,8 @@ class PerturbE(nn.Module):
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -529,7 +531,8 @@ class PerturbE(nn.Module):
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -635,7 +638,8 @@ class PerturbE(nn.Module):
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -858,6 +862,9 @@ class PerturbE(nn.Module):
         if self.loss_func == 'bernoulli':
             #counts = self.sigmoid(concentrate)
             counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
+        elif self.loss_func == 'multinomial':
+            theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
+            counts = theta * library_size
         else:
             rate = concentrate.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
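Note on the change above: the multinomial likelihood now receives the decoder output `concentrate` as unnormalized logits instead of the normalized probabilities `theta`. Pyro's `Multinomial` normalizes logits internally with a softmax, so the two parameterizations define the same likelihood whenever `theta` is the softmax of `concentrate` (how `theta` is computed is not shown in this diff), and passing logits simply skips the explicit normalization. A minimal sketch of the equivalence, with toy shapes and a small `total_count` in place of the package's `int(1e8)` cap (illustrative only, not part of the release):

    import torch
    import pyro.distributions as dist

    # Toy stand-ins: 4 cells x 10 genes; `concentrate` plays the role of the decoder output.
    concentrate = torch.randn(4, 10)
    theta = torch.softmax(concentrate, dim=-1)   # what the old `probs=theta` branch consumed

    d_probs = dist.Multinomial(total_count=100, probs=theta)
    d_logits = dist.Multinomial(total_count=100, logits=concentrate)

    xs = d_probs.sample()                        # integer counts, each row sums to 100
    print(torch.allclose(d_probs.log_prob(xs), d_logits.log_prob(xs), atol=1e-4))  # True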
SURE/TranscriptomeDecoder.py ADDED
@@ -0,0 +1,549 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader
+import numpy as np
+from typing import Dict, List, Tuple, Optional
+import matplotlib.pyplot as plt
+from tqdm import tqdm
+import warnings
+warnings.filterwarnings('ignore')
+
+class TranscriptomeDecoder:
+    def __init__(self,
+                 latent_dim: int = 100,
+                 gene_dim: int = 60000,
+                 hidden_dim: int = 512, # Reduced for memory efficiency
+                 device: str = None):
+        """
+        Whole-transcriptome decoder
+
+        Args:
+            latent_dim: Latent variable dimension (typically 50-100)
+            gene_dim: Number of genes (full transcriptome ~60,000)
+            hidden_dim: Hidden dimension (reduced for memory efficiency)
+            device: Computation device
+        """
+        self.latent_dim = latent_dim
+        self.gene_dim = gene_dim
+        self.hidden_dim = hidden_dim
+        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+
+        # Memory optimization settings
+        self.gradient_checkpointing = True
+        self.mixed_precision = True
+
+        # Initialize model
+        self.model = self._build_model()
+        self.model.to(self.device)
+
+        # Training state
+        self.is_trained = False
+        self.training_history = None
+        self.best_val_loss = float('inf')
+
+        print(f"🚀 TranscriptomeDecoder Initialized:")
+        print(f" - Latent Dimension: {latent_dim}")
+        print(f" - Gene Dimension: {gene_dim}")
+        print(f" - Hidden Dimension: {hidden_dim}")
+        print(f" - Device: {self.device}")
+        print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
+
+    class MemoryEfficientBlock(nn.Module):
+        """Memory-efficient building block with gradient checkpointing"""
+        def __init__(self, input_dim, output_dim, use_checkpointing=True):
+            super().__init__()
+            self.use_checkpointing = use_checkpointing
+            self.net = nn.Sequential(
+                nn.Linear(input_dim, output_dim),
+                nn.BatchNorm1d(output_dim),
+                nn.GELU(),
+                nn.Dropout(0.1)
+            )
+
+        def forward(self, x):
+            if self.use_checkpointing and self.training:
+                return torch.utils.checkpoint.checkpoint(self.net, x)
+            return self.net(x)
+
+    class SparseGeneProjection(nn.Module):
+        """Sparse gene projection to reduce memory usage"""
+        def __init__(self, latent_dim, gene_dim, projection_dim=256):
+            super().__init__()
+            self.projection_dim = projection_dim
+            self.gene_embeddings = nn.Parameter(torch.randn(gene_dim, projection_dim) * 0.02)
+            self.latent_projection = nn.Linear(latent_dim, projection_dim)
+            self.activation = nn.GELU()
+
+        def forward(self, latent):
+            # Project latent to gene space efficiently
+            batch_size = latent.shape[0]
+            latent_proj = self.latent_projection(latent) # [batch, projection_dim]
+
+            # Efficient matrix multiplication
+            gene_embeds = self.gene_embeddings.T # [projection_dim, gene_dim]
+            output = torch.matmul(latent_proj, gene_embeds) # [batch, gene_dim]
+
+            return self.activation(output)
+
+    class ChunkedTransformer(nn.Module):
+        def __init__(self, gene_dim, hidden_dim=512, chunk_size=2000, num_layers=3):
+            super().__init__()
+            self.chunk_size = chunk_size
+            self.hidden_dim = hidden_dim
+            self.num_chunks = (gene_dim + chunk_size - 1) // chunk_size
+
+            # Shared transformer layers
+            self.transformer_layers = nn.ModuleList([
+                nn.Sequential(
+                    nn.Linear(hidden_dim, hidden_dim),
+                    nn.GELU(),
+                    nn.Dropout(0.1),
+                    nn.Linear(hidden_dim, hidden_dim),
+                ) for _ in range(num_layers)
+            ])
+
+            # Independent projection layers for each chunk
+            self.input_projections = nn.ModuleList([
+                nn.Linear(min(chunk_size, gene_dim - i * chunk_size), hidden_dim)
+                for i in range(self.num_chunks)
+            ])
+            self.output_projections = nn.ModuleList([
+                nn.Linear(hidden_dim, min(chunk_size, gene_dim - i * chunk_size))
+                for i in range(self.num_chunks)
+            ])
+
+        def forward(self, x):
+            batch_size, gene_dim = x.shape
+            output = torch.zeros_like(x)
+
+            for i in range(self.num_chunks):
+                start_idx = i * self.chunk_size
+                end_idx = min((i + 1) * self.chunk_size, gene_dim)
+                current_chunk_size = end_idx - start_idx
+
+                chunk = x[:, start_idx:end_idx] # [batch_size, current_chunk_size]
+
+                # Project to hidden_dim
+                chunk_proj = self.input_projections[i](chunk) # [batch_size, hidden_dim]
+
+                # Transformer processing
+                for layer in self.transformer_layers:
+                    chunk_proj = layer(chunk_proj) + chunk_proj
+
+                # Project back to the original dimension
+                chunk_out = self.output_projections[i](chunk_proj) # [batch_size, current_chunk_size]
+
+                output[:, start_idx:end_idx] = chunk_out
+
+            return output
+
+    class Decoder(nn.Module):
+        """Decoder model"""
+        def __init__(self, latent_dim, gene_dim, hidden_dim):
+            super().__init__()
+            self.latent_dim = latent_dim
+            self.gene_dim = gene_dim
+            self.hidden_dim = hidden_dim
+
+            # Stage 1: Latent expansion (memory efficient)
+            self.latent_expansion = nn.Sequential(
+                nn.Linear(latent_dim, hidden_dim * 2),
+                nn.GELU(),
+                nn.Dropout(0.1),
+                nn.Linear(hidden_dim * 2, hidden_dim),
+            )
+
+            # Stage 2: Sparse gene projection
+            self.gene_projection = TranscriptomeDecoder.SparseGeneProjection(
+                latent_dim, gene_dim, hidden_dim
+            )
+
+            # Stage 3: Chunked processing
+            self.chunked_processor = TranscriptomeDecoder.ChunkedTransformer(
+                gene_dim, hidden_dim, chunk_size=2000, num_layers=3
+            )
+
+            # Stage 4: Multi-head output with memory efficiency
+            self.output_heads = nn.ModuleList([
+                nn.Sequential(
+                    nn.Linear(hidden_dim, hidden_dim // 2),
+                    nn.GELU(),
+                    nn.Linear(hidden_dim // 2, 1)
+                ) for _ in range(2) # Reduced from 3 to 2 heads
+            ])
+
+            # Adaptive fusion
+            self.fusion_gate = nn.Sequential(
+                nn.Linear(hidden_dim, hidden_dim // 4),
+                nn.GELU(),
+                nn.Linear(hidden_dim // 4, len(self.output_heads)),
+                nn.Softmax(dim=-1)
+            )
+
+            # Output scaling
+            self.output_scale = nn.Parameter(torch.ones(1))
+            self.output_bias = nn.Parameter(torch.zeros(1))
+
+            self.latent_to_gene = nn.Linear(hidden_dim, gene_dim)
+
+            self._init_weights()
+
+        def _init_weights(self):
+            for module in self.modules():
+                if isinstance(module, nn.Linear):
+                    nn.init.xavier_uniform_(module.weight)
+                    if module.bias is not None:
+                        nn.init.zeros_(module.bias)
+
+        def forward(self, latent):
+            batch_size = latent.shape[0]
+
+            # 1. Latent expansion
+            latent_expanded = self.latent_expansion(latent)
+
+            # 2. Gene projection (memory efficient)
+            gene_features = self.gene_projection(latent)
+
+            # 3. Add latent information
+            latent_gene_injection = self.latent_to_gene(latent_expanded)
+            gene_features = gene_features + latent_gene_injection
+
+            # 4. Chunked processing (memory efficient)
+            gene_features = self.chunked_processor(gene_features)
+
+            # 5. Multi-head output with chunking
+            final_output = torch.zeros(batch_size, self.gene_dim, device=latent.device)
+
+            # Process output in chunks
+            chunk_size = 5000
+            for i in range(0, self.gene_dim, chunk_size):
+                end_idx = min(i + chunk_size, self.gene_dim)
+                chunk = gene_features[:, i:end_idx]
+
+                head_outputs = []
+                for head in self.output_heads:
+                    head_out = head(chunk).squeeze(-1)
+                    head_outputs.append(head_out)
+
+                # Adaptive fusion
+                gate_weights = self.fusion_gate(chunk.mean(dim=1, keepdim=True))
+                gate_weights = gate_weights.unsqueeze(1)
+
+                # Weighted fusion
+                chunk_output = torch.zeros_like(head_outputs[0])
+                for j, head_out in enumerate(head_outputs):
+                    chunk_output = chunk_output + gate_weights[:, :, j] * head_out
+
+                final_output[:, i:end_idx] = chunk_output
+
+            # Final activation
+            final_output = F.softplus(final_output * self.output_scale + self.output_bias)
+
+            return final_output
+
+    def _build_model(self):
+        """Build model"""
+        return self.Decoder(self.latent_dim, self.gene_dim, self.hidden_dim)
+
+    def train(self,
+              train_latent: np.ndarray,
+              train_expression: np.ndarray,
+              val_latent: np.ndarray = None,
+              val_expression: np.ndarray = None,
+              batch_size: int = 16, # Reduced batch size for memory
+              num_epochs: int = 100,
+              learning_rate: float = 1e-4,
+              checkpoint_path: str = 'transcriptome_decoder.pth'):
+        """
+        Memory-efficient training with optimizations
+
+        Args:
+            train_latent: Training latent variables
+            train_expression: Training expression data
+            val_latent: Validation latent variables
+            val_expression: Validation expression data
+            batch_size: Reduced batch size for memory constraints
+            num_epochs: Number of training epochs
+            learning_rate: Learning rate
+            checkpoint_path: Model save path
+        """
+        print("🚀 Starting Training...")
+        print(f"📊 Batch size: {batch_size}")
+
+        # Enable memory optimizations
+        torch.backends.cudnn.benchmark = True
+        if self.mixed_precision:
+            scaler = torch.cuda.amp.GradScaler()
+
+        # Data preparation
+        train_dataset = self._create_dataset(train_latent, train_expression)
+
+        if val_latent is not None and val_expression is not None:
+            val_dataset = self._create_dataset(val_latent, val_expression)
+        else:
+            # Auto split
+            train_size = int(0.9 * len(train_dataset))
+            val_size = len(train_dataset) - train_size
+            train_dataset, val_dataset = torch.utils.data.random_split(
+                train_dataset, [train_size, val_size]
+            )
+
+        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
+                                  pin_memory=True, num_workers=2)
+        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
+                                pin_memory=True, num_workers=2)
+
+        # Optimizer with memory-friendly settings
+        optimizer = optim.AdamW(
+            self.model.parameters(),
+            lr=learning_rate,
+            weight_decay=0.01,
+            betas=(0.9, 0.999)
+        )
+
+        # Learning rate scheduler
+        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
+
+        # Loss function
+        criterion = nn.MSELoss()
+
+        # Training history
+        history = {
+            'train_loss': [], 'val_loss': [],
+            'learning_rate': [], 'memory_usage': []
+        }
+
+        best_val_loss = float('inf')
+
+        for epoch in range(1, num_epochs + 1):
+            print(f"\n📍 Epoch {epoch}/{num_epochs}")
+
+            # Training phase with memory monitoring
+            train_loss = self._train_epoch(
+                train_loader, optimizer, criterion, scaler if self.mixed_precision else None
+            )
+
+            # Validation phase
+            val_loss = self._validate_epoch(val_loader, criterion)
+
+            # Update scheduler
+            scheduler.step()
+
+            # Record history
+            history['train_loss'].append(train_loss)
+            history['val_loss'].append(val_loss)
+            history['learning_rate'].append(optimizer.param_groups[0]['lr'])
+
+            # Memory usage tracking
+            if torch.cuda.is_available():
+                memory_used = torch.cuda.memory_allocated() / 1024**3 # GB
+                history['memory_usage'].append(memory_used)
+                print(f"💾 GPU Memory: {memory_used:.1f}GB / 20GB")
+
+            print(f"📊 Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
+            print(f"⚡ Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")
+
+            # Save best model
+            if val_loss < best_val_loss:
+                best_val_loss = val_loss
+                self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
+                print("💾 Best model saved!")
+
+        self.is_trained = True
+        self.training_history = history
+        self.best_val_loss = best_val_loss
+
+        print(f"\n🎉 Training completed! Best validation loss: {best_val_loss:.4f}")
+        return history
+
+    def _train_epoch(self, train_loader, optimizer, criterion, scaler=None):
+        """Training epoch"""
+        self.model.train()
+        total_loss = 0
+
+        pbar = tqdm(train_loader, desc='Training')
+        for latent, target in pbar:
+            latent = latent.to(self.device, non_blocking=True)
+            target = target.to(self.device, non_blocking=True)
+
+            optimizer.zero_grad(set_to_none=True) # Memory optimization
+
+            if scaler: # Mixed precision training
+                with torch.cuda.amp.autocast():
+                    pred = self.model(latent)
+                    loss = criterion(pred, target)
+
+                scaler.scale(loss).backward()
+                scaler.step(optimizer)
+                scaler.update()
+            else:
+                pred = self.model(latent)
+                loss = criterion(pred, target)
+                loss.backward()
+                optimizer.step()
+
+            total_loss += loss.item()
+            pbar.set_postfix({'Loss': f'{loss.item():.4f}'})
+
+            # Clear memory
+            del pred, loss
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+        return total_loss / len(train_loader)
+
+    def _validate_epoch(self, val_loader, criterion):
+        """Validation"""
+        self.model.eval()
+        total_loss = 0
+
+        with torch.no_grad():
+            for latent, target in val_loader:
+                latent = latent.to(self.device, non_blocking=True)
+                target = target.to(self.device, non_blocking=True)
+
+                pred = self.model(latent)
+                loss = criterion(pred, target)
+                total_loss += loss.item()
+
+                # Clear memory
+                del pred, loss
+
+        return total_loss / len(val_loader)
+
+    def _create_dataset(self, latent_data, expression_data):
+        """Create dataset"""
+        class EfficientDataset(Dataset):
+            def __init__(self, latent, expression):
+                self.latent = torch.FloatTensor(latent)
+                self.expression = torch.FloatTensor(expression)
+
+            def __len__(self):
+                return len(self.latent)
+
+            def __getitem__(self, idx):
+                return self.latent[idx], self.expression[idx]
+
+        return EfficientDataset(latent_data, expression_data)
+
+    def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
+        """Save checkpoint"""
+        torch.save({
+            'epoch': epoch,
+            'model_state_dict': self.model.state_dict(),
+            'optimizer_state_dict': optimizer.state_dict(),
+            'scheduler_state_dict': scheduler.state_dict(),
+            'best_val_loss': best_loss,
+            'training_history': history,
+            'model_config': {
+                'latent_dim': self.latent_dim,
+                'gene_dim': self.gene_dim,
+                'hidden_dim': self.hidden_dim
+            }
+        }, path)
+
+    def predict(self, latent_data: np.ndarray, batch_size: int = 8) -> np.ndarray:
+        """
+        Prediction
+
+        Args:
+            latent_data: Latent variables [n_samples, latent_dim]
+            batch_size: Prediction batch size for memory control
+
+        Returns:
+            expression: Predicted expression [n_samples, gene_dim]
+        """
+        if not self.is_trained:
+            warnings.warn("Model not trained. Predictions may be inaccurate.")
+
+        self.model.eval()
+
+        if isinstance(latent_data, np.ndarray):
+            latent_data = torch.FloatTensor(latent_data)
+
+        # Predict in batches to save memory
+        predictions = []
+        with torch.no_grad():
+            for i in range(0, len(latent_data), batch_size):
+                batch_latent = latent_data[i:i+batch_size].to(self.device)
+                batch_pred = self.model(batch_latent)
+                predictions.append(batch_pred.cpu())
+
+                # Clear memory
+                del batch_pred
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+
+        return torch.cat(predictions).numpy()
+
+    def load_model(self, model_path: str):
+        """Load pre-trained model"""
+        checkpoint = torch.load(model_path, map_location=self.device)
+        self.model.load_state_dict(checkpoint['model_state_dict'])
+        self.is_trained = True
+        self.training_history = checkpoint.get('training_history')
+        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
+        print(f"✅ Model loaded! Best validation loss: {self.best_val_loss:.4f}")
+
+    def get_memory_info(self) -> Dict:
+        """Get memory usage information"""
+        if torch.cuda.is_available():
+            memory_allocated = torch.cuda.memory_allocated() / 1024**3
+            memory_reserved = torch.cuda.memory_reserved() / 1024**3
+            return {
+                'allocated_gb': memory_allocated,
+                'reserved_gb': memory_reserved,
+                'available_gb': 20 - memory_allocated,
+                'utilization_percent': (memory_allocated / 20) * 100
+            }
+        return {'available_gb': 'N/A (CPU mode)'}
+
+'''
+# Example usage with memory monitoring
+def example_usage():
+    """Memory-efficient example"""
+
+    # 1. Initialize memory-efficient decoder
+    decoder = TranscriptomeDecoder(
+        latent_dim=100,
+        gene_dim=2000, # Reduced for example
+        hidden_dim=256 # Reduced for memory
+    )
+
+    # Check memory info
+    memory_info = decoder.get_memory_info()
+    print(f"📊 Memory Info: {memory_info}")
+
+    # 2. Generate example data
+    n_samples = 500 # Reduced for memory
+    latent_data = np.random.randn(n_samples, 100).astype(np.float32)
+    expression_data = np.random.randn(n_samples, 2000).astype(np.float32)
+    expression_data = np.maximum(expression_data, 0) # Non-negative
+
+    print(f"📈 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")
+
+    # 3. Train with memory monitoring
+    history = decoder.train(
+        train_latent=latent_data,
+        train_expression=expression_data,
+        batch_size=8, # Small batch for memory
+        num_epochs=20 # Reduced for example
+    )
+
+    # 4. Memory-efficient prediction
+    test_latent = np.random.randn(5, 100).astype(np.float32)
+    predictions = decoder.predict(test_latent, batch_size=2)
+    print(f"🔮 Prediction shape: {predictions.shape}")
+
+    # 5. Final memory check
+    final_memory = decoder.get_memory_info()
+    print(f"💾 Final memory usage: {final_memory}")
+
+    return decoder
+
+if __name__ == "__main__":
+    example_usage()
+
+'''
SURE/__init__.py CHANGED
@@ -1,6 +1,7 @@
 from .SURE import SURE
 from .DensityFlow import DensityFlow
 from .PerturbE import PerturbE
+from .TranscriptomeDecoder import TranscriptomeDecoder
 
 from . import utils
 from . import codebook
@@ -9,5 +10,7 @@ from . import DensityFlow
 from . import atac
 from . import flow
 from . import perturb
+from . import PerturbE
+from . import TranscriptomeDecoder
 
-__all__ = ['SURE', 'DensityFlow', 'PerturbE', 'flow', 'perturb', 'atac', 'utils', 'codebook']
+__all__ = ['SURE', 'DensityFlow', 'PerturbE', 'TranscriptomeDecoder', 'flow', 'perturb', 'atac', 'utils', 'codebook']
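The new decoder is exported at the package root, and the later `from . import TranscriptomeDecoder` rebinds that name to the submodule, so importing the class directly from the submodule is the unambiguous route. A minimal sketch, assuming SURE-tools 2.4.5 is installed (the small dimensions mirror the example in TranscriptomeDecoder.py and are illustrative only):

    import SURE
    from SURE.TranscriptomeDecoder import TranscriptomeDecoder

    print('TranscriptomeDecoder' in SURE.__all__)   # True as of 2.4.5

    # Constructing the decoder only builds the network; no training or forward pass runs here.
    decoder = TranscriptomeDecoder(latent_dim=50, gene_dim=2000, hidden_dim=128)
    print(decoder.get_memory_info())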
sure_tools-2.3.2.dist-info/METADATA → sure_tools-2.4.5.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: SURE-tools
-Version: 2.3.2
+Version: 2.4.5
 Summary: Succinct Representation of Single Cells
 Home-page: https://github.com/ZengFLab/SURE
 Author: Feng Zeng
sure_tools-2.3.2.dist-info/RECORD → sure_tools-2.4.5.dist-info/RECORD RENAMED
@@ -1,7 +1,8 @@
 SURE/DensityFlow.py,sha256=YvaE9aPbAC2U7WhTye5i2AMtcw0BI_qS3gv9SP4aE0k,56676
-SURE/PerturbE.py,sha256=nomWc8nl4WkihhaITsAVXc-wTp1OTXfypQYP4kTz7JQ,51685
+SURE/PerturbE.py,sha256=DxEp-qef--x8-GMZdPfBf8ts8UDDc34h2P5AnpqZ-YM,52265
 SURE/SURE.py,sha256=MXs7iuvcj-lU4dJ_MwKegpL2Rqk2HB4eFfAgHRA3RtA,47744
-SURE/__init__.py,sha256=Q6J9HwAMxI42Boj8QCVbIDTAr0YyLY9Qj2ewTkM0uUw,336
+SURE/TranscriptomeDecoder.py,sha256=fjTl2wC-nGTdbQGgFDbTmWYI8RoEg6J4cHPmoUoJJfI,21286
+SURE/__init__.py,sha256=pNSGQ4BMqMXBAPHpFOYNB8_0vFW-RqPy3rr5fvdEEyU,473
 SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
 SURE/assembly/assembly.py,sha256=6IMdelPOiRO4mUb4dC7gVCoF1Uvfw86-Map8P_jnUag,21477
 SURE/assembly/atlas.py,sha256=ALjmVWutm_tOHTcT1aqOxmuCEQw-XzrtDoMCV_8oXLk,21794
@@ -18,9 +19,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
 SURE/utils/custom_mlp.py,sha256=Rn_PQouxPMSda-KKBYrwVVv3GFFuUmCLxp8cV5LszZo,10580
 SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
 SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
-sure_tools-2.3.2.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
-sure_tools-2.3.2.dist-info/METADATA,sha256=NDI2_X4N1FIZkNSBdksNPATk-XtSAxUJJs7Hlx3OTSI,2677
-sure_tools-2.3.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-sure_tools-2.3.2.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
-sure_tools-2.3.2.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
-sure_tools-2.3.2.dist-info/RECORD,,
+sure_tools-2.4.5.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
+sure_tools-2.4.5.dist-info/METADATA,sha256=2GjCK_HUQ_Vs6b8AT2PIelOadhiVeOakI8B_OqbRyi0,2677
+sure_tools-2.4.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+sure_tools-2.4.5.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
+sure_tools-2.4.5.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
+sure_tools-2.4.5.dist-info/RECORD,,