SURE-tools 2.2.2-py3-none-any.whl → 2.4.3-py3-none-any.whl
This diff shows the content of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of SURE-tools might be problematic.
- SURE/DensityFlow.py +103 -74
- SURE/{PerturbFlow.py → PerturbE.py} +51 -110
- SURE/TranscriptomeDecoder.py +527 -0
- SURE/__init__.py +5 -1
- SURE/perturb/perturb.py +27 -1
- SURE/utils/custom_mlp.py +39 -2
- {sure_tools-2.2.2.dist-info → sure_tools-2.4.3.dist-info}/METADATA +1 -1
- {sure_tools-2.2.2.dist-info → sure_tools-2.4.3.dist-info}/RECORD +12 -11
- {sure_tools-2.2.2.dist-info → sure_tools-2.4.3.dist-info}/WHEEL +0 -0
- {sure_tools-2.2.2.dist-info → sure_tools-2.4.3.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.2.2.dist-info → sure_tools-2.4.3.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.2.2.dist-info → sure_tools-2.4.3.dist-info}/top_level.txt +0 -0
SURE/TranscriptomeDecoder.py
ADDED
@@ -0,0 +1,527 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader
+import numpy as np
+from typing import Dict, List, Tuple, Optional
+import matplotlib.pyplot as plt
+from tqdm import tqdm
+import warnings
+warnings.filterwarnings('ignore')
+
+class TranscriptomeDecoder:
+    def __init__(self,
+                 latent_dim: int = 100,
+                 gene_dim: int = 60000,
+                 hidden_dim: int = 512,  # Reduced for memory efficiency
+                 device: str = None):
+        """
+        Whole-transcriptome decoder
+
+        Args:
+            latent_dim: Latent variable dimension (typically 50-100)
+            gene_dim: Number of genes (full transcriptome ~60,000)
+            hidden_dim: Hidden dimension (reduced for memory efficiency)
+            device: Computation device
+        """
+        self.latent_dim = latent_dim
+        self.gene_dim = gene_dim
+        self.hidden_dim = hidden_dim
+        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+
+        # Memory optimization settings
+        self.gradient_checkpointing = True
+        self.mixed_precision = True
+
+        # Initialize model
+        self.model = self._build_model()
+        self.model.to(self.device)
+
+        # Training state
+        self.is_trained = False
+        self.training_history = None
+        self.best_val_loss = float('inf')
+
+        print(f"🚀 TranscriptomeDecoder Initialized:")
+        print(f"   - Latent Dimension: {latent_dim}")
+        print(f"   - Gene Dimension: {gene_dim}")
+        print(f"   - Hidden Dimension: {hidden_dim}")
+        print(f"   - Device: {self.device}")
+        print(f"   - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
+
+    class MemoryEfficientBlock(nn.Module):
+        """Memory-efficient building block with gradient checkpointing"""
+        def __init__(self, input_dim, output_dim, use_checkpointing=True):
+            super().__init__()
+            self.use_checkpointing = use_checkpointing
+            self.net = nn.Sequential(
+                nn.Linear(input_dim, output_dim),
+                nn.BatchNorm1d(output_dim),
+                nn.GELU(),
+                nn.Dropout(0.1)
+            )
+
+        def forward(self, x):
+            if self.use_checkpointing and self.training:
+                return torch.utils.checkpoint.checkpoint(self.net, x)
+            return self.net(x)
+
+    class SparseGeneProjection(nn.Module):
+        """Sparse gene projection to reduce memory usage"""
+        def __init__(self, latent_dim, gene_dim, projection_dim=256):
+            super().__init__()
+            self.projection_dim = projection_dim
+            self.gene_embeddings = nn.Parameter(torch.randn(gene_dim, projection_dim) * 0.02)
+            self.latent_projection = nn.Linear(latent_dim, projection_dim)
+            self.activation = nn.GELU()
+
+        def forward(self, latent):
+            # Project latent to gene space efficiently
+            batch_size = latent.shape[0]
+            latent_proj = self.latent_projection(latent)  # [batch, projection_dim]
+
+            # Efficient matrix multiplication
+            gene_embeds = self.gene_embeddings.T  # [projection_dim, gene_dim]
+            output = torch.matmul(latent_proj, gene_embeds)  # [batch, gene_dim]
+
+            return self.activation(output)
+
+    class ChunkedTransformer(nn.Module):
+        """Process genes in chunks to reduce memory usage"""
+        def __init__(self, gene_dim, hidden_dim, chunk_size=1000, num_layers=4):
+            super().__init__()
+            self.chunk_size = chunk_size
+            self.num_chunks = (gene_dim + chunk_size - 1) // chunk_size
+            self.layers = nn.ModuleList([
+                nn.Sequential(
+                    nn.Linear(hidden_dim, hidden_dim),
+                    nn.GELU(),
+                    nn.Dropout(0.1),
+                    nn.Linear(hidden_dim, hidden_dim),
+                ) for _ in range(num_layers)
+            ])
+
+        def forward(self, x):
+            # Process in chunks to save memory
+            batch_size = x.shape[0]
+            output = torch.zeros_like(x)
+
+            for i in range(self.num_chunks):
+                start_idx = i * self.chunk_size
+                end_idx = min((i + 1) * self.chunk_size, x.shape[1])
+
+                chunk = x[:, start_idx:end_idx]
+                for layer in self.layers:
+                    chunk = layer(chunk) + chunk  # Residual connection
+
+                output[:, start_idx:end_idx] = chunk
+
+            return output
+
+    class Decoder(nn.Module):
+        """Decoder model"""
+        def __init__(self, latent_dim, gene_dim, hidden_dim):
+            super().__init__()
+            self.latent_dim = latent_dim
+            self.gene_dim = gene_dim
+            self.hidden_dim = hidden_dim
+
+            # Stage 1: Latent expansion (memory efficient)
+            self.latent_expansion = nn.Sequential(
+                nn.Linear(latent_dim, hidden_dim * 2),
+                nn.GELU(),
+                nn.Dropout(0.1),
+                nn.Linear(hidden_dim * 2, hidden_dim),
+            )
+
+            # Stage 2: Sparse gene projection
+            self.gene_projection = TranscriptomeDecoder.SparseGeneProjection(
+                latent_dim, gene_dim, hidden_dim
+            )
+
+            # Stage 3: Chunked processing
+            self.chunked_processor = TranscriptomeDecoder.ChunkedTransformer(
+                gene_dim, hidden_dim, chunk_size=2000, num_layers=3
+            )
+
+            # Stage 4: Multi-head output with memory efficiency
+            self.output_heads = nn.ModuleList([
+                nn.Sequential(
+                    nn.Linear(hidden_dim, hidden_dim // 2),
+                    nn.GELU(),
+                    nn.Linear(hidden_dim // 2, 1)
+                ) for _ in range(2)  # Reduced from 3 to 2 heads
+            ])
+
+            # Adaptive fusion
+            self.fusion_gate = nn.Sequential(
+                nn.Linear(hidden_dim, hidden_dim // 4),
+                nn.GELU(),
+                nn.Linear(hidden_dim // 4, len(self.output_heads)),
+                nn.Softmax(dim=-1)
+            )
+
+            # Output scaling
+            self.output_scale = nn.Parameter(torch.ones(1))
+            self.output_bias = nn.Parameter(torch.zeros(1))
+
+            self._init_weights()
+
+        def _init_weights(self):
+            for module in self.modules():
+                if isinstance(module, nn.Linear):
+                    nn.init.xavier_uniform_(module.weight)
+                    if module.bias is not None:
+                        nn.init.zeros_(module.bias)
+
+        def forward(self, latent):
+            batch_size = latent.shape[0]
+
+            # 1. Latent expansion
+            latent_expanded = self.latent_expansion(latent)
+
+            # 2. Gene projection (memory efficient)
+            gene_features = self.gene_projection(latent)
+
+            # 3. Add latent information
+            print(f'{gene_features.shape}; {latent_expanded.unsqueeze(1).shape}')
+            gene_features = gene_features + latent_expanded.unsqueeze(1)
+
+            # 4. Chunked processing (memory efficient)
+            gene_features = self.chunked_processor(gene_features)
+
+            # 5. Multi-head output with chunking
+            final_output = torch.zeros(batch_size, self.gene_dim, device=latent.device)
+
+            # Process output in chunks
+            chunk_size = 5000
+            for i in range(0, self.gene_dim, chunk_size):
+                end_idx = min(i + chunk_size, self.gene_dim)
+                chunk = gene_features[:, i:end_idx]
+
+                head_outputs = []
+                for head in self.output_heads:
+                    head_out = head(chunk).squeeze(-1)
+                    head_outputs.append(head_out)
+
+                # Adaptive fusion
+                gate_weights = self.fusion_gate(chunk.mean(dim=1, keepdim=True))
+                gate_weights = gate_weights.unsqueeze(1)
+
+                # Weighted fusion
+                chunk_output = torch.zeros_like(head_outputs[0])
+                for j, head_out in enumerate(head_outputs):
+                    chunk_output = chunk_output + gate_weights[:, :, j] * head_out
+
+                final_output[:, i:end_idx] = chunk_output
+
+            # Final activation
+            final_output = F.softplus(final_output * self.output_scale + self.output_bias)
+
+            return final_output
+
+    def _build_model(self):
+        """Build model"""
+        return self.Decoder(self.latent_dim, self.gene_dim, self.hidden_dim)
+
+    def train(self,
+              train_latent: np.ndarray,
+              train_expression: np.ndarray,
+              val_latent: np.ndarray = None,
+              val_expression: np.ndarray = None,
+              batch_size: int = 16,  # Reduced batch size for memory
+              num_epochs: int = 100,
+              learning_rate: float = 1e-4,
+              checkpoint_path: str = 'transcriptome_decoder.pth'):
+        """
+        Memory-efficient training with optimizations
+
+        Args:
+            train_latent: Training latent variables
+            train_expression: Training expression data
+            val_latent: Validation latent variables
+            val_expression: Validation expression data
+            batch_size: Reduced batch size for memory constraints
+            num_epochs: Number of training epochs
+            learning_rate: Learning rate
+            checkpoint_path: Model save path
+        """
+        print("🚀 Starting Training...")
+        print(f"📊 Batch size: {batch_size}")
+
+        # Enable memory optimizations
+        torch.backends.cudnn.benchmark = True
+        if self.mixed_precision:
+            scaler = torch.cuda.amp.GradScaler()
+
+        # Data preparation
+        train_dataset = self._create_dataset(train_latent, train_expression)
+
+        if val_latent is not None and val_expression is not None:
+            val_dataset = self._create_dataset(val_latent, val_expression)
+        else:
+            # Auto split
+            train_size = int(0.9 * len(train_dataset))
+            val_size = len(train_dataset) - train_size
+            train_dataset, val_dataset = torch.utils.data.random_split(
+                train_dataset, [train_size, val_size]
+            )
+
+        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
+                                  pin_memory=True, num_workers=2)
+        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
+                                pin_memory=True, num_workers=2)
+
+        # Optimizer with memory-friendly settings
+        optimizer = optim.AdamW(
+            self.model.parameters(),
+            lr=learning_rate,
+            weight_decay=0.01,
+            betas=(0.9, 0.999)
+        )
+
+        # Learning rate scheduler
+        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
+
+        # Loss function
+        criterion = nn.MSELoss()
+
+        # Training history
+        history = {
+            'train_loss': [], 'val_loss': [],
+            'learning_rate': [], 'memory_usage': []
+        }
+
+        best_val_loss = float('inf')
+
+        for epoch in range(1, num_epochs + 1):
+            print(f"\n📍 Epoch {epoch}/{num_epochs}")
+
+            # Training phase with memory monitoring
+            train_loss = self._train_epoch(
+                train_loader, optimizer, criterion, scaler if self.mixed_precision else None
+            )
+
+            # Validation phase
+            val_loss = self._validate_epoch(val_loader, criterion)
+
+            # Update scheduler
+            scheduler.step()
+
+            # Record history
+            history['train_loss'].append(train_loss)
+            history['val_loss'].append(val_loss)
+            history['learning_rate'].append(optimizer.param_groups[0]['lr'])
+
+            # Memory usage tracking
+            if torch.cuda.is_available():
+                memory_used = torch.cuda.memory_allocated() / 1024**3  # GB
+                history['memory_usage'].append(memory_used)
+                print(f"💾 GPU Memory: {memory_used:.1f}GB / 20GB")
+
+            print(f"📊 Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
+            print(f"⚡ Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")
+
+            # Save best model
+            if val_loss < best_val_loss:
+                best_val_loss = val_loss
+                self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
+                print("💾 Best model saved!")
+
+        self.is_trained = True
+        self.training_history = history
+        self.best_val_loss = best_val_loss
+
+        print(f"\n🎉 Training completed! Best validation loss: {best_val_loss:.4f}")
+        return history
+
+    def _train_epoch(self, train_loader, optimizer, criterion, scaler=None):
+        """Training epoch"""
+        self.model.train()
+        total_loss = 0
+
+        pbar = tqdm(train_loader, desc='Training')
+        for latent, target in pbar:
+            latent = latent.to(self.device, non_blocking=True)
+            target = target.to(self.device, non_blocking=True)
+
+            optimizer.zero_grad(set_to_none=True)  # Memory optimization
+
+            if scaler:  # Mixed precision training
+                with torch.cuda.amp.autocast():
+                    pred = self.model(latent)
+                    loss = criterion(pred, target)
+
+                scaler.scale(loss).backward()
+                scaler.step(optimizer)
+                scaler.update()
+            else:
+                pred = self.model(latent)
+                loss = criterion(pred, target)
+                loss.backward()
+                optimizer.step()
+
+            total_loss += loss.item()
+            pbar.set_postfix({'Loss': f'{loss.item():.4f}'})
+
+            # Clear memory
+            del pred, loss
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+        return total_loss / len(train_loader)
+
+    def _validate_epoch(self, val_loader, criterion):
+        """Validation"""
+        self.model.eval()
+        total_loss = 0
+
+        with torch.no_grad():
+            for latent, target in val_loader:
+                latent = latent.to(self.device, non_blocking=True)
+                target = target.to(self.device, non_blocking=True)
+
+                pred = self.model(latent)
+                loss = criterion(pred, target)
+                total_loss += loss.item()
+
+                # Clear memory
+                del pred, loss
+
+        return total_loss / len(val_loader)
+
+    def _create_dataset(self, latent_data, expression_data):
+        """Create dataset"""
+        class EfficientDataset(Dataset):
+            def __init__(self, latent, expression):
+                self.latent = torch.FloatTensor(latent)
+                self.expression = torch.FloatTensor(expression)
+
+            def __len__(self):
+                return len(self.latent)
+
+            def __getitem__(self, idx):
+                return self.latent[idx], self.expression[idx]
+
+        return EfficientDataset(latent_data, expression_data)
+
+    def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
+        """Save checkpoint"""
+        torch.save({
+            'epoch': epoch,
+            'model_state_dict': self.model.state_dict(),
+            'optimizer_state_dict': optimizer.state_dict(),
+            'scheduler_state_dict': scheduler.state_dict(),
+            'best_val_loss': best_loss,
+            'training_history': history,
+            'model_config': {
+                'latent_dim': self.latent_dim,
+                'gene_dim': self.gene_dim,
+                'hidden_dim': self.hidden_dim
+            }
+        }, path)
+
+    def predict(self, latent_data: np.ndarray, batch_size: int = 8) -> np.ndarray:
+        """
+        Prediction
+
+        Args:
+            latent_data: Latent variables [n_samples, latent_dim]
+            batch_size: Prediction batch size for memory control
+
+        Returns:
+            expression: Predicted expression [n_samples, gene_dim]
+        """
+        if not self.is_trained:
+            warnings.warn("Model not trained. Predictions may be inaccurate.")
+
+        self.model.eval()
+
+        if isinstance(latent_data, np.ndarray):
+            latent_data = torch.FloatTensor(latent_data)
+
+        # Predict in batches to save memory
+        predictions = []
+        with torch.no_grad():
+            for i in range(0, len(latent_data), batch_size):
+                batch_latent = latent_data[i:i+batch_size].to(self.device)
+                batch_pred = self.model(batch_latent)
+                predictions.append(batch_pred.cpu())
+
+                # Clear memory
+                del batch_pred
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+
+        return torch.cat(predictions).numpy()
+
+    def load_model(self, model_path: str):
+        """Load pre-trained model"""
+        checkpoint = torch.load(model_path, map_location=self.device)
+        self.model.load_state_dict(checkpoint['model_state_dict'])
+        self.is_trained = True
+        self.training_history = checkpoint.get('training_history')
+        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
+        print(f"✅ Model loaded! Best validation loss: {self.best_val_loss:.4f}")
+
+    def get_memory_info(self) -> Dict:
+        """Get memory usage information"""
+        if torch.cuda.is_available():
+            memory_allocated = torch.cuda.memory_allocated() / 1024**3
+            memory_reserved = torch.cuda.memory_reserved() / 1024**3
+            return {
+                'allocated_gb': memory_allocated,
+                'reserved_gb': memory_reserved,
+                'available_gb': 20 - memory_allocated,
+                'utilization_percent': (memory_allocated / 20) * 100
+            }
+        return {'available_gb': 'N/A (CPU mode)'}
+
+'''
+# Example usage with memory monitoring
+def example_usage():
+    """Memory-efficient example"""
+
+    # 1. Initialize memory-efficient decoder
+    decoder = TranscriptomeDecoder(
+        latent_dim=100,
+        gene_dim=2000,  # Reduced for example
+        hidden_dim=256  # Reduced for memory
+    )
+
+    # Check memory info
+    memory_info = decoder.get_memory_info()
+    print(f"📊 Memory Info: {memory_info}")
+
+    # 2. Generate example data
+    n_samples = 500  # Reduced for memory
+    latent_data = np.random.randn(n_samples, 100).astype(np.float32)
+    expression_data = np.random.randn(n_samples, 2000).astype(np.float32)
+    expression_data = np.maximum(expression_data, 0)  # Non-negative
+
+    print(f"📈 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")
+
+    # 3. Train with memory monitoring
+    history = decoder.train(
+        train_latent=latent_data,
+        train_expression=expression_data,
+        batch_size=8,  # Small batch for memory
+        num_epochs=20  # Reduced for example
+    )
+
+    # 4. Memory-efficient prediction
+    test_latent = np.random.randn(5, 100).astype(np.float32)
+    predictions = decoder.predict(test_latent, batch_size=2)
+    print(f"🔮 Prediction shape: {predictions.shape}")
+
+    # 5. Final memory check
+    final_memory = decoder.get_memory_info()
+    print(f"💾 Final memory usage: {final_memory}")
+
+    return decoder
+
+if __name__ == "__main__":
+    example_usage()
+
+'''
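The memory-saving trick in MemoryEfficientBlock above is standard PyTorch activation checkpointing: the wrapped sub-network's intermediate activations are not kept during the forward pass and are recomputed during backward. A minimal standalone sketch of that pattern (illustrative names only, not part of the package):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# A block whose intermediate activations we would rather recompute than store.
block = nn.Sequential(nn.Linear(512, 512), nn.GELU(), nn.Linear(512, 512))

x = torch.randn(16, 512, requires_grad=True)

# Forward pass without caching intermediates; they are recomputed on backward,
# trading extra compute for lower peak memory (as in MemoryEfficientBlock.forward).
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)

The same motive drives ChunkedTransformer and the chunked output heads: the gene dimension is processed in slices so that only one slice's activations are alive at a time.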
SURE/__init__.py
CHANGED
@@ -1,5 +1,7 @@
 from .SURE import SURE
 from .DensityFlow import DensityFlow
+from .PerturbE import PerturbE
+from .TranscriptomeDecoder import TranscriptomeDecoder
 
 from . import utils
 from . import codebook
@@ -8,5 +10,7 @@ from . import DensityFlow
 from . import atac
 from . import flow
 from . import perturb
+from . import PerturbE
+from . import TranscriptomeDecoder
 
-__all__ = ['SURE', 'DensityFlow', 'flow', 'perturb', 'atac', 'utils', 'codebook']
+__all__ = ['SURE', 'DensityFlow', 'PerturbE', 'TranscriptomeDecoder', 'flow', 'perturb', 'atac', 'utils', 'codebook']
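With these re-exports, the new classes are importable straight from the package root. A minimal sketch assuming SURE-tools 2.4.3 is installed:

from SURE import PerturbE, TranscriptomeDecoder

# Both names are now in SURE.__all__ (see the diff above), so they are also
# picked up by `from SURE import *`. The constructor arguments below follow
# the TranscriptomeDecoder.__init__ signature shown earlier in this diff.
decoder = TranscriptomeDecoder(latent_dim=100, gene_dim=2000, hidden_dim=256)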
SURE/perturb/perturb.py
CHANGED
@@ -1,5 +1,6 @@
 import re
 import numpy as np
+import pandas as pd
 from numba import njit
 from itertools import chain
 from joblib import Parallel, delayed
@@ -8,6 +9,8 @@ from typing import Literal
 class LabelMatrix:
     def __init__(self):
         self.labels_ = None
+        self.control_label = None
+        self.sep_pattern = None
 
     def fit_transform(self, labels, control_label=None, sep_pattern=r'[,;_\s]', speedup: Literal['none','vectorize','parallel']='none'):
         if speedup=='none':
@@ -24,8 +27,31 @@ class LabelMatrix:
         mat = np.delete(mat, idx, axis=1)
         self.labels_ = np.delete(self.labels_, idx)
 
+        self.control_label = control_label
+        self.sep_pattern=sep_pattern
+
         return mat
-
+
+    def transform(self, labels, speedup: Literal['none','vectorize','parallel']='none'):
+        sep_pattern = self.sep_pattern
+        if speedup=='none':
+            mat, labels_ = label_to_matrix(labels=labels, sep_pattern=sep_pattern)
+        elif speedup=='vectorize':
+            mat, labels_ = vectorized_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
+        elif speedup=='parallel':
+            mat, labels_ = parallel_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
+
+        mat_df = pd.DataFrame(mat, columns=labels_)
+
+        labels_valid = [x for x in labels_ if x in self.labels_]
+        mat_df = mat_df[labels_valid]
+
+        mat_valid = np.zeros([mat.shape[0], len(self.labels_)])
+        mat_valid_df = pd.DataFrame(mat_valid, columns=self.labels_)
+        mat_valid_df[labels_valid] = mat_df
+
+        return mat_valid_df.values
+
     def inverse_transform(self, matrix):
         return matrix_to_labels(matrix=matrix, unique_labels=self.labels_)
 
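The new LabelMatrix.transform re-encodes unseen label strings against the vocabulary learned by fit_transform: columns are aligned to the fitted labels_, and labels not seen at fit time are dropped, so their columns stay zero. A rough usage sketch, with a made-up import path and made-up perturbation labels:

from SURE.perturb import LabelMatrix  # import path assumed for illustration

lm = LabelMatrix()

# Learn the label vocabulary from training annotations; compound labels are
# split on the default sep_pattern r'[,;_\s]'.
train_mat = lm.fit_transform(['GENE1', 'GENE2', 'GENE1,GENE2'])
print(lm.labels_)        # fitted column order

# Encode new samples against the same columns; 'GENE3' was never fitted,
# so it contributes nothing to the output matrix.
test_mat = lm.transform(['GENE2', 'GENE3'])
print(test_mat.shape)    # (2, len(lm.labels_))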
SURE/utils/custom_mlp.py
CHANGED
@@ -240,12 +240,49 @@ class ZeroBiasMLP(nn.Module):
         y = self.mlp(x)
         mask = torch.zeros_like(y)
         if len(y.shape)==2:
-
+            if type(x)==list:
+                mask[x[1][:,0]>0,:] = 1
+            else:
+                mask[x[:,0]>0,:] = 1
         elif len(y.shape)==3:
-
+            if type(x)==list:
+                mask[:,x[1][:,0]>0,:] = 1
+            else:
+                mask[:,x[:,0]>0,:] = 1
         return y*mask
 
 
+
+class ZeroBiasMLP2(nn.Module):
+    def __init__(
+        self,
+        mlp_sizes,
+        activation=nn.ReLU,
+        output_activation=None,
+        post_layer_fct=lambda layer_ix, total_layers, layer: None,
+        post_act_fct=lambda layer_ix, total_layers, layer: None,
+        allow_broadcast=False,
+        use_cuda=False,
+    ):
+        # init the module object
+        super().__init__()
+        self.mlp = MLP(mlp_sizes=mlp_sizes,
+                       activation=activation,
+                       output_activation=output_activation,
+                       post_layer_fct=post_layer_fct,
+                       post_act_fct=post_act_fct,
+                       allow_broadcast=allow_broadcast,
+                       use_cuda=use_cuda,
+                       bias=True)
+
+    # pass through our sequential for the output!
+    def forward(self, x):
+        y = self.mlp(x)
+        mask = torch.zeros_like(y)
+        x_sum = torch.sum(x, dim=1)
+        mask[x_sum>0,:] = 1
+        return y*mask
+
 class HDMLP(nn.Module):
     def __init__(
         self,
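ZeroBiasMLP2 applies the same output-masking idea as ZeroBiasMLP, but keys the mask on the row-sum of the raw input: any sample whose input features do not sum to a positive value (e.g. an all-zero perturbation indicator) gets an all-zero output, so the MLP's bias terms cannot leak into those rows. A self-contained sketch of just that masking rule, with a plain nn.Linear standing in for the package's MLP wrapper:

import torch
import torch.nn as nn

class ZeroMaskedNet(nn.Module):
    """Toy stand-in for ZeroBiasMLP2: zero the outputs of all-zero input rows."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.net = nn.Linear(in_dim, out_dim)  # bias would otherwise shift zero rows

    def forward(self, x):
        y = self.net(x)
        mask = torch.zeros_like(y)
        mask[torch.sum(x, dim=1) > 0, :] = 1   # same rule as ZeroBiasMLP2.forward
        return y * mask

net = ZeroMaskedNet(4, 3)
x = torch.tensor([[0., 0., 0., 0.],   # control row: output forced to zero
                  [1., 0., 1., 0.]])  # perturbed row: passes through
print(net(x))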
{sure_tools-2.2.2.dist-info → sure_tools-2.4.3.dist-info}/RECORD
CHANGED
@@ -1,7 +1,8 @@
-SURE/DensityFlow.py,sha256=
-SURE/
+SURE/DensityFlow.py,sha256=YvaE9aPbAC2U7WhTye5i2AMtcw0BI_qS3gv9SP4aE0k,56676
+SURE/PerturbE.py,sha256=DxEp-qef--x8-GMZdPfBf8ts8UDDc34h2P5AnpqZ-YM,52265
 SURE/SURE.py,sha256=MXs7iuvcj-lU4dJ_MwKegpL2Rqk2HB4eFfAgHRA3RtA,47744
-SURE/
+SURE/TranscriptomeDecoder.py,sha256=4Ai0AeXnEwgUHB-gDsS9v3pHWFkawnTCzdidSRlXlnk,20337
+SURE/__init__.py,sha256=pNSGQ4BMqMXBAPHpFOYNB8_0vFW-RqPy3rr5fvdEEyU,473
 SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
 SURE/assembly/assembly.py,sha256=6IMdelPOiRO4mUb4dC7gVCoF1Uvfw86-Map8P_jnUag,21477
 SURE/assembly/atlas.py,sha256=ALjmVWutm_tOHTcT1aqOxmuCEQw-XzrtDoMCV_8oXLk,21794
@@ -13,14 +14,14 @@ SURE/flow/__init__.py,sha256=rsAjYsh1xVIrxBCuwOE0Q_6N5th1wBgjJceV0ABPG3c,183
 SURE/flow/flow_stats.py,sha256=6SzNMT59WRFRP1nC6bvpBPF7BugWnkIS_DSlr4S-Ez0,11338
 SURE/flow/plot_quiver.py,sha256=UbmuScUcgbQHeMmjKmgqxjrIjHhiHx0VWct16UMMwuE,8110
 SURE/perturb/__init__.py,sha256=8TP1dSUhXiZzKpFebHZmm8XMMGbUz_OfQ10xu-6uPPY,43
-SURE/perturb/perturb.py,sha256=
+SURE/perturb/perturb.py,sha256=ey7cxsM1tO1MW4UaE_MLpLHK87CjvXzn2CBPtvv1VZ0,6116
 SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
-SURE/utils/custom_mlp.py,sha256=
+SURE/utils/custom_mlp.py,sha256=Rn_PQouxPMSda-KKBYrwVVv3GFFuUmCLxp8cV5LszZo,10580
 SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
 SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
-sure_tools-2.
-sure_tools-2.
-sure_tools-2.
-sure_tools-2.
-sure_tools-2.
-sure_tools-2.
+sure_tools-2.4.3.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
+sure_tools-2.4.3.dist-info/METADATA,sha256=GX1UIc4xRMWryqUkvvF9xbtTSke6nIpFhUrHT8IMagU,2677
+sure_tools-2.4.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+sure_tools-2.4.3.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
+sure_tools-2.4.3.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
+sure_tools-2.4.3.dist-info/RECORD,,
File without changes