SURE-tools 2.3.2.tar.gz → 2.4.2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of SURE-tools might be problematic.

Files changed (32)
  1. {sure_tools-2.3.2 → sure_tools-2.4.2}/PKG-INFO +1 -1
  2. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE/PerturbE.py +11 -4
  3. sure_tools-2.4.2/SURE/TranscriptomeDecoder.py +527 -0
  4. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE/__init__.py +4 -1
  5. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE_tools.egg-info/PKG-INFO +1 -1
  6. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE_tools.egg-info/SOURCES.txt +1 -0
  7. {sure_tools-2.3.2 → sure_tools-2.4.2}/setup.py +1 -1
  8. {sure_tools-2.3.2 → sure_tools-2.4.2}/LICENSE +0 -0
  9. {sure_tools-2.3.2 → sure_tools-2.4.2}/README.md +0 -0
  10. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE/DensityFlow.py +0 -0
  11. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE/SURE.py +0 -0
  12. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE/assembly/__init__.py +0 -0
  13. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE/assembly/assembly.py +0 -0
  14. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE/assembly/atlas.py +0 -0
  15. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE/atac/__init__.py +0 -0
  16. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE/atac/utils.py +0 -0
  17. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE/codebook/__init__.py +0 -0
  18. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE/codebook/codebook.py +0 -0
  19. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE/flow/__init__.py +0 -0
  20. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE/flow/flow_stats.py +0 -0
  21. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE/flow/plot_quiver.py +0 -0
  22. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE/perturb/__init__.py +0 -0
  23. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE/perturb/perturb.py +0 -0
  24. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE/utils/__init__.py +0 -0
  25. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE/utils/custom_mlp.py +0 -0
  26. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE/utils/queue.py +0 -0
  27. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE/utils/utils.py +0 -0
  28. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE_tools.egg-info/dependency_links.txt +0 -0
  29. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE_tools.egg-info/entry_points.txt +0 -0
  30. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE_tools.egg-info/requires.txt +0 -0
  31. {sure_tools-2.3.2 → sure_tools-2.4.2}/SURE_tools.egg-info/top_level.txt +0 -0
  32. {sure_tools-2.3.2 → sure_tools-2.4.2}/setup.cfg +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: SURE-tools
- Version: 2.3.2
+ Version: 2.4.2
  Summary: Succinct Representation of Single Cells
  Home-page: https://github.com/ZengFLab/SURE
  Author: Feng Zeng
SURE/PerturbE.py
@@ -349,7 +349,8 @@ class PerturbE(nn.Module):
      else:
          pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
  elif self.loss_func == 'multinomial':
-     pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+     #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+     pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
  elif self.loss_func == 'bernoulli':
      if self.use_zeroinflate:
          pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -433,7 +434,8 @@ class PerturbE(nn.Module):
      else:
          pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
  elif self.loss_func == 'multinomial':
-     pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+     #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+     pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
  elif self.loss_func == 'bernoulli':
      if self.use_zeroinflate:
          pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -529,7 +531,8 @@ class PerturbE(nn.Module):
      else:
          pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
  elif self.loss_func == 'multinomial':
-     pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+     #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+     pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
  elif self.loss_func == 'bernoulli':
      if self.use_zeroinflate:
          pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -635,7 +638,8 @@ class PerturbE(nn.Module):
      else:
          pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
  elif self.loss_func == 'multinomial':
-     pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+     #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+     pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
  elif self.loss_func == 'bernoulli':
      if self.use_zeroinflate:
          pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
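All four hunks above make the same substitution: the multinomial likelihood is now parameterized directly by the unnormalized decoder output (logits=concentrate) rather than by the normalized probabilities (probs=theta). If theta was the softmax of concentrate, the two parameterizations describe the same distribution; the change just skips the explicit normalization. A standalone sketch with toy tensors (not the PerturbE code path; concentrate and theta here are placeholders for the model's decoder outputs):

# Toy illustration of probs= vs logits= for a Pyro Multinomial; names below are
# local to this sketch, not the actual PerturbE tensors.
import torch
import pyro.distributions as dist

concentrate = torch.tensor([[0.2, 1.5, -0.7, 0.0]])   # unnormalized values
theta = torch.softmax(concentrate, dim=-1)            # what the old code passed as probs=

d_old = dist.Multinomial(total_count=int(1e8), probs=theta)
d_new = dist.Multinomial(total_count=int(1e8), logits=concentrate)

print(torch.allclose(d_old.probs, d_new.probs))       # True: same event probabilities
print(torch.allclose(d_old.mean, d_new.mean))         # True: same expected counts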
@@ -858,6 +862,9 @@ class PerturbE(nn.Module):
  if self.loss_func == 'bernoulli':
      #counts = self.sigmoid(concentrate)
      counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
+ elif self.loss_func == 'multinomial':
+     theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
+     counts = theta * library_size
  else:
      rate = concentrate.exp()
      theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
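The hunk above gives the 'multinomial' loss its own branch in the predicted-counts computation, reading the Multinomial mean off the logits before scaling by library size. For reference, the mean of a Multinomial is total_count * softmax(logits); a toy sketch of that identity and of rescaling an expected profile to a chosen library size (plain PyTorch; variable names are local to this sketch, not PerturbE internals):

# Toy sketch: Multinomial mean from logits, and rescaling it to a target library size.
import torch
from torch.distributions import Multinomial

concentrate = torch.tensor([1.0, 2.0, 0.5])                                 # toy logits
m = Multinomial(total_count=100, logits=concentrate)

expected = m.mean                                                           # == 100 * softmax(concentrate)
print(torch.allclose(expected, 100 * torch.softmax(concentrate, dim=-1)))   # True

library_size = 10_000.0
profile = expected / expected.sum() * library_size                          # expected counts summing to library_size
print(profile.sum())                                                        # ~10000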
SURE/TranscriptomeDecoder.py (new file)
@@ -0,0 +1,527 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset, DataLoader
+ import numpy as np
+ from typing import Dict, List, Tuple, Optional
+ import matplotlib.pyplot as plt
+ from tqdm import tqdm
+ import warnings
+ warnings.filterwarnings('ignore')
+
+ class TranscriptomeDecoder:
+     def __init__(self,
+                  latent_dim: int = 100,
+                  gene_dim: int = 60000,
+                  hidden_dim: int = 512,  # Reduced for memory efficiency
+                  device: str = None):
+         """
+         Whole-transcriptome decoder
+
+         Args:
+             latent_dim: Latent variable dimension (typically 50-100)
+             gene_dim: Number of genes (full transcriptome ~60,000)
+             hidden_dim: Hidden dimension (reduced for memory efficiency)
+             device: Computation device
+         """
+         self.latent_dim = latent_dim
+         self.gene_dim = gene_dim
+         self.hidden_dim = hidden_dim
+         self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+
+         # Memory optimization settings
+         self.gradient_checkpointing = True
+         self.mixed_precision = True
+
+         # Initialize model
+         self.model = self._build_model()
+         self.model.to(self.device)
+
+         # Training state
+         self.is_trained = False
+         self.training_history = None
+         self.best_val_loss = float('inf')
+
+         print(f"🚀 TranscriptomeDecoder Initialized:")
+         print(f" - Latent Dimension: {latent_dim}")
+         print(f" - Gene Dimension: {gene_dim}")
+         print(f" - Hidden Dimension: {hidden_dim}")
+         print(f" - Device: {self.device}")
+         print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
+
+     class MemoryEfficientBlock(nn.Module):
+         """Memory-efficient building block with gradient checkpointing"""
+         def __init__(self, input_dim, output_dim, use_checkpointing=True):
+             super().__init__()
+             self.use_checkpointing = use_checkpointing
+             self.net = nn.Sequential(
+                 nn.Linear(input_dim, output_dim),
+                 nn.BatchNorm1d(output_dim),
+                 nn.GELU(),
+                 nn.Dropout(0.1)
+             )
+
+         def forward(self, x):
+             if self.use_checkpointing and self.training:
+                 return torch.utils.checkpoint.checkpoint(self.net, x)
+             return self.net(x)
+
+     class SparseGeneProjection(nn.Module):
+         """Sparse gene projection to reduce memory usage"""
+         def __init__(self, latent_dim, gene_dim, projection_dim=256):
+             super().__init__()
+             self.projection_dim = projection_dim
+             self.gene_embeddings = nn.Parameter(torch.randn(gene_dim, projection_dim) * 0.02)
+             self.latent_projection = nn.Linear(latent_dim, projection_dim)
+             self.activation = nn.GELU()
+
+         def forward(self, latent):
+             # Project latent to gene space efficiently
+             batch_size = latent.shape[0]
+             latent_proj = self.latent_projection(latent)  # [batch, projection_dim]
+
+             # Efficient matrix multiplication
+             gene_embeds = self.gene_embeddings.T  # [projection_dim, gene_dim]
+             output = torch.matmul(latent_proj, gene_embeds)  # [batch, gene_dim]
+
+             return self.activation(output)
+
+     class ChunkedTransformer(nn.Module):
+         """Process genes in chunks to reduce memory usage"""
+         def __init__(self, gene_dim, hidden_dim, chunk_size=1000, num_layers=4):
+             super().__init__()
+             self.chunk_size = chunk_size
+             self.num_chunks = (gene_dim + chunk_size - 1) // chunk_size
+             self.layers = nn.ModuleList([
+                 nn.Sequential(
+                     nn.Linear(hidden_dim, hidden_dim),
+                     nn.GELU(),
+                     nn.Dropout(0.1),
+                     nn.Linear(hidden_dim, hidden_dim),
+                 ) for _ in range(num_layers)
+             ])
+
+         def forward(self, x):
+             # Process in chunks to save memory
+             batch_size = x.shape[0]
+             output = torch.zeros_like(x)
+
+             for i in range(self.num_chunks):
+                 start_idx = i * self.chunk_size
+                 end_idx = min((i + 1) * self.chunk_size, x.shape[1])
+
+                 chunk = x[:, start_idx:end_idx]
+                 for layer in self.layers:
+                     chunk = layer(chunk) + chunk  # Residual connection
+
+                 output[:, start_idx:end_idx] = chunk
+
+             return output
+
+     class Decoder(nn.Module):
+         """Decoder model"""
+         def __init__(self, latent_dim, gene_dim, hidden_dim):
+             super().__init__()
+             self.latent_dim = latent_dim
+             self.gene_dim = gene_dim
+             self.hidden_dim = hidden_dim
+
+             # Stage 1: Latent expansion (memory efficient)
+             self.latent_expansion = nn.Sequential(
+                 nn.Linear(latent_dim, hidden_dim * 2),
+                 nn.GELU(),
+                 nn.Dropout(0.1),
+                 nn.Linear(hidden_dim * 2, hidden_dim),
+             )
+
+             # Stage 2: Sparse gene projection
+             self.gene_projection = TranscriptomeDecoder.SparseGeneProjection(
+                 latent_dim, gene_dim, hidden_dim
+             )
+
+             # Stage 3: Chunked processing
+             self.chunked_processor = TranscriptomeDecoder.ChunkedTransformer(
+                 gene_dim, hidden_dim, chunk_size=2000, num_layers=3
+             )
+
+             # Stage 4: Multi-head output with memory efficiency
+             self.output_heads = nn.ModuleList([
+                 nn.Sequential(
+                     nn.Linear(hidden_dim, hidden_dim // 2),
+                     nn.GELU(),
+                     nn.Linear(hidden_dim // 2, 1)
+                 ) for _ in range(2)  # Reduced from 3 to 2 heads
+             ])
+
+             # Adaptive fusion
+             self.fusion_gate = nn.Sequential(
+                 nn.Linear(hidden_dim, hidden_dim // 4),
+                 nn.GELU(),
+                 nn.Linear(hidden_dim // 4, len(self.output_heads)),
+                 nn.Softmax(dim=-1)
+             )
+
+             # Output scaling
+             self.output_scale = nn.Parameter(torch.ones(1))
+             self.output_bias = nn.Parameter(torch.zeros(1))
+
+             self._init_weights()
+
+         def _init_weights(self):
+             for module in self.modules():
+                 if isinstance(module, nn.Linear):
+                     nn.init.xavier_uniform_(module.weight)
+                     if module.bias is not None:
+                         nn.init.zeros_(module.bias)
+
+         def forward(self, latent):
+             batch_size = latent.shape[0]
+
+             # 1. Latent expansion
+             latent_expanded = self.latent_expansion(latent)
+
+             # 2. Gene projection (memory efficient)
+             gene_features = self.gene_projection(latent)
+
+             # 3. Add latent information
+             print(f'{gene_features.shape}; {latent_expanded.shape}')
+             gene_features = gene_features + latent_expanded.unsqueeze(1)
+
+             # 4. Chunked processing (memory efficient)
+             gene_features = self.chunked_processor(gene_features)
+
+             # 5. Multi-head output with chunking
+             final_output = torch.zeros(batch_size, self.gene_dim, device=latent.device)
+
+             # Process output in chunks
+             chunk_size = 5000
+             for i in range(0, self.gene_dim, chunk_size):
+                 end_idx = min(i + chunk_size, self.gene_dim)
+                 chunk = gene_features[:, i:end_idx]
+
+                 head_outputs = []
+                 for head in self.output_heads:
+                     head_out = head(chunk).squeeze(-1)
+                     head_outputs.append(head_out)
+
+                 # Adaptive fusion
+                 gate_weights = self.fusion_gate(chunk.mean(dim=1, keepdim=True))
+                 gate_weights = gate_weights.unsqueeze(1)
+
+                 # Weighted fusion
+                 chunk_output = torch.zeros_like(head_outputs[0])
+                 for j, head_out in enumerate(head_outputs):
+                     chunk_output = chunk_output + gate_weights[:, :, j] * head_out
+
+                 final_output[:, i:end_idx] = chunk_output
+
+             # Final activation
+             final_output = F.softplus(final_output * self.output_scale + self.output_bias)
+
+             return final_output
+
+     def _build_model(self):
+         """Build model"""
+         return self.Decoder(self.latent_dim, self.gene_dim, self.hidden_dim)
+
+     def train(self,
+               train_latent: np.ndarray,
+               train_expression: np.ndarray,
+               val_latent: np.ndarray = None,
+               val_expression: np.ndarray = None,
+               batch_size: int = 16,  # Reduced batch size for memory
+               num_epochs: int = 100,
+               learning_rate: float = 1e-4,
+               checkpoint_path: str = 'transcriptome_decoder.pth'):
+         """
+         Memory-efficient training with optimizations
+
+         Args:
+             train_latent: Training latent variables
+             train_expression: Training expression data
+             val_latent: Validation latent variables
+             val_expression: Validation expression data
+             batch_size: Reduced batch size for memory constraints
+             num_epochs: Number of training epochs
+             learning_rate: Learning rate
+             checkpoint_path: Model save path
+         """
+         print("🚀 Starting Training...")
+         print(f"📊 Batch size: {batch_size}")
+
+         # Enable memory optimizations
+         torch.backends.cudnn.benchmark = True
+         if self.mixed_precision:
+             scaler = torch.cuda.amp.GradScaler()
+
+         # Data preparation
+         train_dataset = self._create_dataset(train_latent, train_expression)
+
+         if val_latent is not None and val_expression is not None:
+             val_dataset = self._create_dataset(val_latent, val_expression)
+         else:
+             # Auto split
+             train_size = int(0.9 * len(train_dataset))
+             val_size = len(train_dataset) - train_size
+             train_dataset, val_dataset = torch.utils.data.random_split(
+                 train_dataset, [train_size, val_size]
+             )
+
+         train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
+                                   pin_memory=True, num_workers=2)
+         val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
+                                 pin_memory=True, num_workers=2)
+
+         # Optimizer with memory-friendly settings
+         optimizer = optim.AdamW(
+             self.model.parameters(),
+             lr=learning_rate,
+             weight_decay=0.01,
+             betas=(0.9, 0.999)
+         )
+
+         # Learning rate scheduler
+         scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
+
+         # Loss function
+         criterion = nn.MSELoss()
+
+         # Training history
+         history = {
+             'train_loss': [], 'val_loss': [],
+             'learning_rate': [], 'memory_usage': []
+         }
+
+         best_val_loss = float('inf')
+
+         for epoch in range(1, num_epochs + 1):
+             print(f"\n📍 Epoch {epoch}/{num_epochs}")
+
+             # Training phase with memory monitoring
+             train_loss = self._train_epoch(
+                 train_loader, optimizer, criterion, scaler if self.mixed_precision else None
+             )
+
+             # Validation phase
+             val_loss = self._validate_epoch(val_loader, criterion)
+
+             # Update scheduler
+             scheduler.step()
+
+             # Record history
+             history['train_loss'].append(train_loss)
+             history['val_loss'].append(val_loss)
+             history['learning_rate'].append(optimizer.param_groups[0]['lr'])
+
+             # Memory usage tracking
+             if torch.cuda.is_available():
+                 memory_used = torch.cuda.memory_allocated() / 1024**3  # GB
+                 history['memory_usage'].append(memory_used)
+                 print(f"💾 GPU Memory: {memory_used:.1f}GB / 20GB")
+
+             print(f"📊 Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
+             print(f"⚡ Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")
+
+             # Save best model
+             if val_loss < best_val_loss:
+                 best_val_loss = val_loss
+                 self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
+                 print("💾 Best model saved!")
+
+         self.is_trained = True
+         self.training_history = history
+         self.best_val_loss = best_val_loss
+
+         print(f"\n🎉 Training completed! Best validation loss: {best_val_loss:.4f}")
+         return history
+
+     def _train_epoch(self, train_loader, optimizer, criterion, scaler=None):
+         """Training epoch"""
+         self.model.train()
+         total_loss = 0
+
+         pbar = tqdm(train_loader, desc='Training')
+         for latent, target in pbar:
+             latent = latent.to(self.device, non_blocking=True)
+             target = target.to(self.device, non_blocking=True)
+
+             optimizer.zero_grad(set_to_none=True)  # Memory optimization
+
+             if scaler:  # Mixed precision training
+                 with torch.cuda.amp.autocast():
+                     pred = self.model(latent)
+                     loss = criterion(pred, target)
+
+                 scaler.scale(loss).backward()
+                 scaler.step(optimizer)
+                 scaler.update()
+             else:
+                 pred = self.model(latent)
+                 loss = criterion(pred, target)
+                 loss.backward()
+                 optimizer.step()
+
+             total_loss += loss.item()
+             pbar.set_postfix({'Loss': f'{loss.item():.4f}'})
+
+             # Clear memory
+             del pred, loss
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+
+         return total_loss / len(train_loader)
+
+     def _validate_epoch(self, val_loader, criterion):
+         """Validation"""
+         self.model.eval()
+         total_loss = 0
+
+         with torch.no_grad():
+             for latent, target in val_loader:
+                 latent = latent.to(self.device, non_blocking=True)
+                 target = target.to(self.device, non_blocking=True)
+
+                 pred = self.model(latent)
+                 loss = criterion(pred, target)
+                 total_loss += loss.item()
+
+                 # Clear memory
+                 del pred, loss
+
+         return total_loss / len(val_loader)
+
+     def _create_dataset(self, latent_data, expression_data):
+         """Create dataset"""
+         class EfficientDataset(Dataset):
+             def __init__(self, latent, expression):
+                 self.latent = torch.FloatTensor(latent)
+                 self.expression = torch.FloatTensor(expression)
+
+             def __len__(self):
+                 return len(self.latent)
+
+             def __getitem__(self, idx):
+                 return self.latent[idx], self.expression[idx]
+
+         return EfficientDataset(latent_data, expression_data)
+
+     def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
+         """Save checkpoint"""
+         torch.save({
+             'epoch': epoch,
+             'model_state_dict': self.model.state_dict(),
+             'optimizer_state_dict': optimizer.state_dict(),
+             'scheduler_state_dict': scheduler.state_dict(),
+             'best_val_loss': best_loss,
+             'training_history': history,
+             'model_config': {
+                 'latent_dim': self.latent_dim,
+                 'gene_dim': self.gene_dim,
+                 'hidden_dim': self.hidden_dim
+             }
+         }, path)
+
+     def predict(self, latent_data: np.ndarray, batch_size: int = 8) -> np.ndarray:
+         """
+         Prediction
+
+         Args:
+             latent_data: Latent variables [n_samples, latent_dim]
+             batch_size: Prediction batch size for memory control
+
+         Returns:
+             expression: Predicted expression [n_samples, gene_dim]
+         """
+         if not self.is_trained:
+             warnings.warn("Model not trained. Predictions may be inaccurate.")
+
+         self.model.eval()
+
+         if isinstance(latent_data, np.ndarray):
+             latent_data = torch.FloatTensor(latent_data)
+
+         # Predict in batches to save memory
+         predictions = []
+         with torch.no_grad():
+             for i in range(0, len(latent_data), batch_size):
+                 batch_latent = latent_data[i:i+batch_size].to(self.device)
+                 batch_pred = self.model(batch_latent)
+                 predictions.append(batch_pred.cpu())
+
+                 # Clear memory
+                 del batch_pred
+                 if torch.cuda.is_available():
+                     torch.cuda.empty_cache()
+
+         return torch.cat(predictions).numpy()
+
+     def load_model(self, model_path: str):
+         """Load pre-trained model"""
+         checkpoint = torch.load(model_path, map_location=self.device)
+         self.model.load_state_dict(checkpoint['model_state_dict'])
+         self.is_trained = True
+         self.training_history = checkpoint.get('training_history')
+         self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
+         print(f"✅ Model loaded! Best validation loss: {self.best_val_loss:.4f}")
+
+     def get_memory_info(self) -> Dict:
+         """Get memory usage information"""
+         if torch.cuda.is_available():
+             memory_allocated = torch.cuda.memory_allocated() / 1024**3
+             memory_reserved = torch.cuda.memory_reserved() / 1024**3
+             return {
+                 'allocated_gb': memory_allocated,
+                 'reserved_gb': memory_reserved,
+                 'available_gb': 20 - memory_allocated,
+                 'utilization_percent': (memory_allocated / 20) * 100
+             }
+         return {'available_gb': 'N/A (CPU mode)'}
+
+ '''
+ # Example usage with memory monitoring
+ def example_usage():
+     """Memory-efficient example"""
+
+     # 1. Initialize memory-efficient decoder
+     decoder = TranscriptomeDecoder(
+         latent_dim=100,
+         gene_dim=2000,  # Reduced for example
+         hidden_dim=256  # Reduced for memory
+     )
+
+     # Check memory info
+     memory_info = decoder.get_memory_info()
+     print(f"📊 Memory Info: {memory_info}")
+
+     # 2. Generate example data
+     n_samples = 500  # Reduced for memory
+     latent_data = np.random.randn(n_samples, 100).astype(np.float32)
+     expression_data = np.random.randn(n_samples, 2000).astype(np.float32)
+     expression_data = np.maximum(expression_data, 0)  # Non-negative
+
+     print(f"📈 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")
+
+     # 3. Train with memory monitoring
+     history = decoder.train(
+         train_latent=latent_data,
+         train_expression=expression_data,
+         batch_size=8,  # Small batch for memory
+         num_epochs=20  # Reduced for example
+     )
+
+     # 4. Memory-efficient prediction
+     test_latent = np.random.randn(5, 100).astype(np.float32)
+     predictions = decoder.predict(test_latent, batch_size=2)
+     print(f"🔮 Prediction shape: {predictions.shape}")
+
+     # 5. Final memory check
+     final_memory = decoder.get_memory_info()
+     print(f"💾 Final memory usage: {final_memory}")
+
+     return decoder
+
+ if __name__ == "__main__":
+     example_usage()
+
+ '''
SURE/__init__.py
@@ -1,6 +1,7 @@
  from .SURE import SURE
  from .DensityFlow import DensityFlow
  from .PerturbE import PerturbE
+ from .TranscriptomeDecoder import TranscriptomeDecoder
  
  from . import utils
  from . import codebook
@@ -9,5 +10,7 @@ from . import DensityFlow
  from . import atac
  from . import flow
  from . import perturb
+ from . import PerturbE
+ from . import TranscriptomeDecoder
  
- __all__ = ['SURE', 'DensityFlow', 'PerturbE', 'flow', 'perturb', 'atac', 'utils', 'codebook']
+ __all__ = ['SURE', 'DensityFlow', 'PerturbE', 'TranscriptomeDecoder', 'flow', 'perturb', 'atac', 'utils', 'codebook']
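With this export, the decoder is importable directly from the package namespace. A small sketch of the new import path (assumes SURE-tools 2.4.2 is installed; the dimensions are arbitrary illustration values, and a fuller train/predict walkthrough is already included, commented out, at the bottom of the new module):

# Illustrative quick check of the new public import path; sizes are placeholders.
from SURE import TranscriptomeDecoder

decoder = TranscriptomeDecoder(latent_dim=100, gene_dim=2000, hidden_dim=256)
print(decoder.get_memory_info())  # GPU memory summary, or a CPU-mode placeholder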
SURE_tools.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: SURE-tools
- Version: 2.3.2
+ Version: 2.4.2
  Summary: Succinct Representation of Single Cells
  Home-page: https://github.com/ZengFLab/SURE
  Author: Feng Zeng
SURE_tools.egg-info/SOURCES.txt
@@ -4,6 +4,7 @@ setup.py
  SURE/DensityFlow.py
  SURE/PerturbE.py
  SURE/SURE.py
+ SURE/TranscriptomeDecoder.py
  SURE/__init__.py
  SURE/assembly/__init__.py
  SURE/assembly/assembly.py
setup.py
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
  
  setup(
      name='SURE-tools',
-     version='2.3.2',
+     version='2.4.2',
      description='Succinct Representation of Single Cells',
      long_description=long_description,
      long_description_content_type="text/markdown",