SURE-tools 2.4.2__py3-none-any.whl → 2.4.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- SURE/EfficientTranscriptomeDecoder.py +607 -0
- SURE/SimpleTranscriptomeDecoder.py +567 -0
- SURE/TranscriptomeDecoder.py +273 -289
- SURE/__init__.py +6 -1
- {sure_tools-2.4.2.dist-info → sure_tools-2.4.13.dist-info}/METADATA +1 -1
- {sure_tools-2.4.2.dist-info → sure_tools-2.4.13.dist-info}/RECORD +10 -8
- {sure_tools-2.4.2.dist-info → sure_tools-2.4.13.dist-info}/WHEEL +0 -0
- {sure_tools-2.4.2.dist-info → sure_tools-2.4.13.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.4.2.dist-info → sure_tools-2.4.13.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.4.2.dist-info → sure_tools-2.4.13.dist-info}/top_level.txt +0 -0
SURE/TranscriptomeDecoder.py
CHANGED
|
@@ -1,28 +1,28 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import torch.nn as nn
|
|
3
|
-
import torch.optim as optim
|
|
4
3
|
import torch.nn.functional as F
|
|
4
|
+
import torch.optim as optim
|
|
5
5
|
from torch.utils.data import Dataset, DataLoader
|
|
6
6
|
import numpy as np
|
|
7
|
-
from typing import Dict,
|
|
8
|
-
import matplotlib.pyplot as plt
|
|
9
|
-
from tqdm import tqdm
|
|
7
|
+
from typing import Dict, Optional
|
|
10
8
|
import warnings
|
|
11
9
|
warnings.filterwarnings('ignore')
|
|
12
10
|
|
|
13
11
|
class TranscriptomeDecoder:
|
|
12
|
+
"""Transcriptome decoder"""
|
|
13
|
+
|
|
14
14
|
def __init__(self,
|
|
15
15
|
latent_dim: int = 100,
|
|
16
16
|
gene_dim: int = 60000,
|
|
17
|
-
hidden_dim: int = 512,
|
|
17
|
+
hidden_dim: int = 512,
|
|
18
18
|
device: str = None):
|
|
19
19
|
"""
|
|
20
|
-
|
|
20
|
+
Simple but powerful decoder for latent to transcriptome mapping
|
|
21
21
|
|
|
22
22
|
Args:
|
|
23
23
|
latent_dim: Latent variable dimension (typically 50-100)
|
|
24
24
|
gene_dim: Number of genes (full transcriptome ~60,000)
|
|
25
|
-
hidden_dim: Hidden dimension
|
|
25
|
+
hidden_dim: Hidden dimension optimized
|
|
26
26
|
device: Computation device
|
|
27
27
|
"""
|
|
28
28
|
self.latent_dim = latent_dim
|
|
@@ -30,10 +30,6 @@ class TranscriptomeDecoder:
|
|
|
30
30
|
self.hidden_dim = hidden_dim
|
|
31
31
|
self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
|
|
32
32
|
|
|
33
|
-
# Memory optimization settings
|
|
34
|
-
self.gradient_checkpointing = True
|
|
35
|
-
self.mixed_precision = True
|
|
36
|
-
|
|
37
33
|
# Initialize model
|
|
38
34
|
self.model = self._build_model()
|
|
39
35
|
self.model.to(self.device)
|
|
@@ -43,123 +39,47 @@ class TranscriptomeDecoder:
|
|
|
43
39
|
self.training_history = None
|
|
44
40
|
self.best_val_loss = float('inf')
|
|
45
41
|
|
|
46
|
-
print(f"🚀
|
|
42
|
+
print(f"🚀 SimpleTranscriptomeDecoder Initialized:")
|
|
47
43
|
print(f" - Latent Dimension: {latent_dim}")
|
|
48
44
|
print(f" - Gene Dimension: {gene_dim}")
|
|
49
45
|
print(f" - Hidden Dimension: {hidden_dim}")
|
|
50
46
|
print(f" - Device: {self.device}")
|
|
51
47
|
print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
|
|
52
48
|
|
|
53
|
-
class MemoryEfficientBlock(nn.Module):
|
|
54
|
-
"""Memory-efficient building block with gradient checkpointing"""
|
|
55
|
-
def __init__(self, input_dim, output_dim, use_checkpointing=True):
|
|
56
|
-
super().__init__()
|
|
57
|
-
self.use_checkpointing = use_checkpointing
|
|
58
|
-
self.net = nn.Sequential(
|
|
59
|
-
nn.Linear(input_dim, output_dim),
|
|
60
|
-
nn.BatchNorm1d(output_dim),
|
|
61
|
-
nn.GELU(),
|
|
62
|
-
nn.Dropout(0.1)
|
|
63
|
-
)
|
|
64
|
-
|
|
65
|
-
def forward(self, x):
|
|
66
|
-
if self.use_checkpointing and self.training:
|
|
67
|
-
return torch.utils.checkpoint.checkpoint(self.net, x)
|
|
68
|
-
return self.net(x)
|
|
69
|
-
|
|
70
|
-
class SparseGeneProjection(nn.Module):
|
|
71
|
-
"""Sparse gene projection to reduce memory usage"""
|
|
72
|
-
def __init__(self, latent_dim, gene_dim, projection_dim=256):
|
|
73
|
-
super().__init__()
|
|
74
|
-
self.projection_dim = projection_dim
|
|
75
|
-
self.gene_embeddings = nn.Parameter(torch.randn(gene_dim, projection_dim) * 0.02)
|
|
76
|
-
self.latent_projection = nn.Linear(latent_dim, projection_dim)
|
|
77
|
-
self.activation = nn.GELU()
|
|
78
|
-
|
|
79
|
-
def forward(self, latent):
|
|
80
|
-
# Project latent to gene space efficiently
|
|
81
|
-
batch_size = latent.shape[0]
|
|
82
|
-
latent_proj = self.latent_projection(latent) # [batch, projection_dim]
|
|
83
|
-
|
|
84
|
-
# Efficient matrix multiplication
|
|
85
|
-
gene_embeds = self.gene_embeddings.T # [projection_dim, gene_dim]
|
|
86
|
-
output = torch.matmul(latent_proj, gene_embeds) # [batch, gene_dim]
|
|
87
|
-
|
|
88
|
-
return self.activation(output)
|
|
89
|
-
|
|
90
|
-
class ChunkedTransformer(nn.Module):
|
|
91
|
-
"""Process genes in chunks to reduce memory usage"""
|
|
92
|
-
def __init__(self, gene_dim, hidden_dim, chunk_size=1000, num_layers=4):
|
|
93
|
-
super().__init__()
|
|
94
|
-
self.chunk_size = chunk_size
|
|
95
|
-
self.num_chunks = (gene_dim + chunk_size - 1) // chunk_size
|
|
96
|
-
self.layers = nn.ModuleList([
|
|
97
|
-
nn.Sequential(
|
|
98
|
-
nn.Linear(hidden_dim, hidden_dim),
|
|
99
|
-
nn.GELU(),
|
|
100
|
-
nn.Dropout(0.1),
|
|
101
|
-
nn.Linear(hidden_dim, hidden_dim),
|
|
102
|
-
) for _ in range(num_layers)
|
|
103
|
-
])
|
|
104
|
-
|
|
105
|
-
def forward(self, x):
|
|
106
|
-
# Process in chunks to save memory
|
|
107
|
-
batch_size = x.shape[0]
|
|
108
|
-
output = torch.zeros_like(x)
|
|
109
|
-
|
|
110
|
-
for i in range(self.num_chunks):
|
|
111
|
-
start_idx = i * self.chunk_size
|
|
112
|
-
end_idx = min((i + 1) * self.chunk_size, x.shape[1])
|
|
113
|
-
|
|
114
|
-
chunk = x[:, start_idx:end_idx]
|
|
115
|
-
for layer in self.layers:
|
|
116
|
-
chunk = layer(chunk) + chunk # Residual connection
|
|
117
|
-
|
|
118
|
-
output[:, start_idx:end_idx] = chunk
|
|
119
|
-
|
|
120
|
-
return output
|
|
121
|
-
|
|
122
49
|
class Decoder(nn.Module):
|
|
123
|
-
"""
|
|
124
|
-
|
|
50
|
+
"""Memory-efficient decoder architecture with dimension handling"""
|
|
51
|
+
|
|
52
|
+
def __init__(self, latent_dim: int, gene_dim: int, hidden_dim: int):
|
|
125
53
|
super().__init__()
|
|
126
54
|
self.latent_dim = latent_dim
|
|
127
55
|
self.gene_dim = gene_dim
|
|
128
56
|
self.hidden_dim = hidden_dim
|
|
129
57
|
|
|
130
|
-
# Stage 1: Latent expansion
|
|
58
|
+
# Stage 1: Latent variable expansion
|
|
131
59
|
self.latent_expansion = nn.Sequential(
|
|
132
60
|
nn.Linear(latent_dim, hidden_dim * 2),
|
|
61
|
+
nn.BatchNorm1d(hidden_dim * 2),
|
|
133
62
|
nn.GELU(),
|
|
134
63
|
nn.Dropout(0.1),
|
|
135
64
|
nn.Linear(hidden_dim * 2, hidden_dim),
|
|
65
|
+
nn.BatchNorm1d(hidden_dim),
|
|
66
|
+
nn.GELU(),
|
|
136
67
|
)
|
|
137
68
|
|
|
138
|
-
# Stage 2:
|
|
139
|
-
self.
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
self.chunked_processor = TranscriptomeDecoder.ChunkedTransformer(
|
|
145
|
-
gene_dim, hidden_dim, chunk_size=2000, num_layers=3
|
|
69
|
+
# Stage 2: Direct projection to gene dimension (simpler approach)
|
|
70
|
+
self.gene_projector = nn.Sequential(
|
|
71
|
+
nn.Linear(hidden_dim, hidden_dim * 2),
|
|
72
|
+
nn.GELU(),
|
|
73
|
+
nn.Dropout(0.1),
|
|
74
|
+
nn.Linear(hidden_dim * 2, gene_dim), # Direct projection to gene_dim
|
|
146
75
|
)
|
|
147
76
|
|
|
148
|
-
# Stage
|
|
149
|
-
self.
|
|
150
|
-
nn.
|
|
151
|
-
nn.Linear(hidden_dim, hidden_dim // 2),
|
|
152
|
-
nn.GELU(),
|
|
153
|
-
nn.Linear(hidden_dim // 2, 1)
|
|
154
|
-
) for _ in range(2) # Reduced from 3 to 2 heads
|
|
155
|
-
])
|
|
156
|
-
|
|
157
|
-
# Adaptive fusion
|
|
158
|
-
self.fusion_gate = nn.Sequential(
|
|
159
|
-
nn.Linear(hidden_dim, hidden_dim // 4),
|
|
77
|
+
# Stage 3: Lightweight gene interaction
|
|
78
|
+
self.gene_interaction = nn.Sequential(
|
|
79
|
+
nn.Conv1d(1, 32, kernel_size=3, padding=1),
|
|
160
80
|
nn.GELU(),
|
|
161
|
-
nn.
|
|
162
|
-
nn.
|
|
81
|
+
nn.Dropout1d(0.1),
|
|
82
|
+
nn.Conv1d(32, 1, kernel_size=3, padding=1),
|
|
163
83
|
)
|
|
164
84
|
|
|
165
85
|
# Output scaling
|
|
@@ -169,111 +89,113 @@ class TranscriptomeDecoder:
|
|
|
169
89
|
self._init_weights()
|
|
170
90
|
|
|
171
91
|
def _init_weights(self):
|
|
92
|
+
"""Weight initialization"""
|
|
172
93
|
for module in self.modules():
|
|
173
94
|
if isinstance(module, nn.Linear):
|
|
174
95
|
nn.init.xavier_uniform_(module.weight)
|
|
175
96
|
if module.bias is not None:
|
|
176
97
|
nn.init.zeros_(module.bias)
|
|
98
|
+
elif isinstance(module, nn.Conv1d):
|
|
99
|
+
nn.init.kaiming_uniform_(module.weight)
|
|
100
|
+
if module.bias is not None:
|
|
101
|
+
nn.init.zeros_(module.bias)
|
|
177
102
|
|
|
178
|
-
def forward(self, latent):
|
|
103
|
+
def forward(self, latent: torch.Tensor) -> torch.Tensor:
|
|
179
104
|
batch_size = latent.shape[0]
|
|
180
105
|
|
|
181
|
-
# 1.
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
# 2. Gene projection (memory efficient)
|
|
185
|
-
gene_features = self.gene_projection(latent)
|
|
106
|
+
# 1. Expand latent variables
|
|
107
|
+
latent_features = self.latent_expansion(latent) # [batch_size, hidden_dim]
|
|
186
108
|
|
|
187
|
-
#
|
|
188
|
-
|
|
189
|
-
gene_features = gene_features + latent_expanded.unsqueeze(1)
|
|
109
|
+
# 2. Direct projection to gene dimension
|
|
110
|
+
gene_output = self.gene_projector(latent_features) # [batch_size, gene_dim]
|
|
190
111
|
|
|
191
|
-
#
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
# Process output in chunks
|
|
198
|
-
chunk_size = 5000
|
|
199
|
-
for i in range(0, self.gene_dim, chunk_size):
|
|
200
|
-
end_idx = min(i + chunk_size, self.gene_dim)
|
|
201
|
-
chunk = gene_features[:, i:end_idx]
|
|
202
|
-
|
|
203
|
-
head_outputs = []
|
|
204
|
-
for head in self.output_heads:
|
|
205
|
-
head_out = head(chunk).squeeze(-1)
|
|
206
|
-
head_outputs.append(head_out)
|
|
207
|
-
|
|
208
|
-
# Adaptive fusion
|
|
209
|
-
gate_weights = self.fusion_gate(chunk.mean(dim=1, keepdim=True))
|
|
210
|
-
gate_weights = gate_weights.unsqueeze(1)
|
|
211
|
-
|
|
212
|
-
# Weighted fusion
|
|
213
|
-
chunk_output = torch.zeros_like(head_outputs[0])
|
|
214
|
-
for j, head_out in enumerate(head_outputs):
|
|
215
|
-
chunk_output = chunk_output + gate_weights[:, :, j] * head_out
|
|
216
|
-
|
|
217
|
-
final_output[:, i:end_idx] = chunk_output
|
|
112
|
+
# 3. Gene interaction with dimension safety
|
|
113
|
+
if self.gene_dim > 1: # Only apply if gene_dim > 1
|
|
114
|
+
gene_output = gene_output.unsqueeze(1) # [batch_size, 1, gene_dim]
|
|
115
|
+
interaction_output = self.gene_interaction(gene_output) # [batch_size, 1, gene_dim]
|
|
116
|
+
gene_output = gene_output + interaction_output # Residual connection
|
|
117
|
+
gene_output = gene_output.squeeze(1) # [batch_size, gene_dim]
|
|
218
118
|
|
|
219
|
-
# Final activation
|
|
220
|
-
|
|
119
|
+
# 4. Final activation (ensure non-negative)
|
|
120
|
+
gene_output = F.softplus(gene_output * self.output_scale + self.output_bias)
|
|
221
121
|
|
|
222
|
-
return
|
|
122
|
+
return gene_output
|
|
223
123
|
|
|
224
124
|
def _build_model(self):
|
|
225
|
-
"""Build model"""
|
|
125
|
+
"""Build the decoder model"""
|
|
226
126
|
return self.Decoder(self.latent_dim, self.gene_dim, self.hidden_dim)
|
|
227
127
|
|
|
128
|
+
def _create_dataset(self, latent_data, expression_data):
|
|
129
|
+
"""Create dataset with dimension validation"""
|
|
130
|
+
class SimpleDataset(Dataset):
|
|
131
|
+
def __init__(self, latent, expression):
|
|
132
|
+
# Ensure dimensions match
|
|
133
|
+
assert latent.shape[0] == expression.shape[0], "Sample count mismatch"
|
|
134
|
+
assert latent.shape[1] == self.latent_dim, f"Latent dim mismatch: expected {self.latent_dim}, got {latent.shape[1]}"
|
|
135
|
+
assert expression.shape[1] == self.gene_dim, f"Gene dim mismatch: expected {self.gene_dim}, got {expression.shape[1]}"
|
|
136
|
+
|
|
137
|
+
self.latent = torch.FloatTensor(latent)
|
|
138
|
+
self.expression = torch.FloatTensor(expression)
|
|
139
|
+
|
|
140
|
+
def __len__(self):
|
|
141
|
+
return len(self.latent)
|
|
142
|
+
|
|
143
|
+
def __getitem__(self, idx):
|
|
144
|
+
return self.latent[idx], self.expression[idx]
|
|
145
|
+
|
|
146
|
+
return SimpleDataset(latent_data, expression_data)
|
|
147
|
+
|
|
228
148
|
def train(self,
|
|
229
149
|
train_latent: np.ndarray,
|
|
230
150
|
train_expression: np.ndarray,
|
|
231
151
|
val_latent: np.ndarray = None,
|
|
232
152
|
val_expression: np.ndarray = None,
|
|
233
|
-
batch_size: int =
|
|
153
|
+
batch_size: int = 32,
|
|
234
154
|
num_epochs: int = 100,
|
|
235
155
|
learning_rate: float = 1e-4,
|
|
236
156
|
checkpoint_path: str = 'transcriptome_decoder.pth'):
|
|
237
157
|
"""
|
|
238
|
-
|
|
158
|
+
Train the decoder model with dimension safety
|
|
239
159
|
|
|
240
160
|
Args:
|
|
241
|
-
train_latent: Training latent variables
|
|
242
|
-
train_expression: Training expression data
|
|
243
|
-
val_latent: Validation latent variables
|
|
244
|
-
val_expression: Validation expression data
|
|
245
|
-
batch_size:
|
|
161
|
+
train_latent: Training latent variables [n_samples, latent_dim]
|
|
162
|
+
train_expression: Training expression data [n_samples, gene_dim]
|
|
163
|
+
val_latent: Validation latent variables (optional)
|
|
164
|
+
val_expression: Validation expression data (optional)
|
|
165
|
+
batch_size: Batch size optimized for memory
|
|
246
166
|
num_epochs: Number of training epochs
|
|
247
167
|
learning_rate: Learning rate
|
|
248
|
-
checkpoint_path:
|
|
168
|
+
checkpoint_path: Path to save the best model
|
|
249
169
|
"""
|
|
250
|
-
print("🚀 Starting
|
|
251
|
-
print(f"📊 Batch size: {batch_size}")
|
|
170
|
+
print("🚀 Starting training...")
|
|
252
171
|
|
|
253
|
-
#
|
|
254
|
-
|
|
255
|
-
if
|
|
256
|
-
|
|
172
|
+
# Dimension validation
|
|
173
|
+
self._validate_data_dimensions(train_latent, train_expression, "Training")
|
|
174
|
+
if val_latent is not None and val_expression is not None:
|
|
175
|
+
self._validate_data_dimensions(val_latent, val_expression, "Validation")
|
|
257
176
|
|
|
258
177
|
# Data preparation
|
|
259
|
-
train_dataset = self.
|
|
178
|
+
train_dataset = self._create_safe_dataset(train_latent, train_expression)
|
|
260
179
|
|
|
261
180
|
if val_latent is not None and val_expression is not None:
|
|
262
|
-
val_dataset = self.
|
|
181
|
+
val_dataset = self._create_safe_dataset(val_latent, val_expression)
|
|
182
|
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
|
|
183
|
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
|
|
184
|
+
print(f"📈 Using provided validation data: {len(val_dataset)} samples")
|
|
263
185
|
else:
|
|
264
186
|
# Auto split
|
|
265
187
|
train_size = int(0.9 * len(train_dataset))
|
|
266
188
|
val_size = len(train_dataset) - train_size
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
)
|
|
189
|
+
train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
|
|
190
|
+
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=2)
|
|
191
|
+
val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=2)
|
|
192
|
+
print(f"📈 Auto-split validation: {val_size} samples")
|
|
270
193
|
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
pin_memory=True, num_workers=2)
|
|
194
|
+
print(f"📊 Training samples: {len(train_loader.dataset)}")
|
|
195
|
+
print(f"📊 Validation samples: {len(val_loader.dataset)}")
|
|
196
|
+
print(f"📊 Batch size: {batch_size}")
|
|
275
197
|
|
|
276
|
-
# Optimizer
|
|
198
|
+
# Optimizer configuration
|
|
277
199
|
optimizer = optim.AdamW(
|
|
278
200
|
self.model.parameters(),
|
|
279
201
|
lr=learning_rate,
|
|
@@ -284,130 +206,181 @@ class TranscriptomeDecoder:
|
|
|
284
206
|
# Learning rate scheduler
|
|
285
207
|
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
|
|
286
208
|
|
|
287
|
-
# Loss function
|
|
288
|
-
|
|
209
|
+
# Loss function with dimension safety
|
|
210
|
+
def safe_loss(pred, target):
|
|
211
|
+
# Ensure dimensions match
|
|
212
|
+
if pred.shape != target.shape:
|
|
213
|
+
print(f"⚠️ Dimension mismatch: pred {pred.shape}, target {target.shape}")
|
|
214
|
+
# Truncate to minimum dimension (safety measure)
|
|
215
|
+
min_dim = min(pred.shape[1], target.shape[1])
|
|
216
|
+
pred = pred[:, :min_dim]
|
|
217
|
+
target = target[:, :min_dim]
|
|
218
|
+
|
|
219
|
+
def correlation_loss(pred, target):
|
|
220
|
+
pred_centered = pred - pred.mean(dim=1, keepdim=True)
|
|
221
|
+
target_centered = target - target.mean(dim=1, keepdim=True)
|
|
222
|
+
|
|
223
|
+
correlation = (pred_centered * target_centered).sum(dim=1) / (
|
|
224
|
+
torch.sqrt(torch.sum(pred_centered ** 2, dim=1)) *
|
|
225
|
+
torch.sqrt(torch.sum(target_centered ** 2, dim=1)) + 1e-8
|
|
226
|
+
)
|
|
227
|
+
|
|
228
|
+
return 1 - correlation.mean()
|
|
229
|
+
|
|
230
|
+
mse_loss = F.mse_loss(pred, target)
|
|
231
|
+
poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()
|
|
232
|
+
corr_loss = correlation_loss(pred, target)
|
|
233
|
+
return mse_loss + 0.5 * poisson_loss + 0.3 * corr_loss
|
|
289
234
|
|
|
290
235
|
# Training history
|
|
291
236
|
history = {
|
|
292
|
-
'train_loss': [],
|
|
293
|
-
'
|
|
237
|
+
'train_loss': [],
|
|
238
|
+
'val_loss': [],
|
|
239
|
+
'learning_rate': []
|
|
294
240
|
}
|
|
295
241
|
|
|
296
242
|
best_val_loss = float('inf')
|
|
243
|
+
patience = 15
|
|
244
|
+
patience_counter = 0
|
|
297
245
|
|
|
246
|
+
print("\n📈 Starting training loop...")
|
|
298
247
|
for epoch in range(1, num_epochs + 1):
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
# Training phase with memory monitoring
|
|
302
|
-
train_loss = self._train_epoch(
|
|
303
|
-
train_loader, optimizer, criterion, scaler if self.mixed_precision else None
|
|
304
|
-
)
|
|
248
|
+
# Training phase
|
|
249
|
+
train_loss = self._train_epoch(train_loader, optimizer, safe_loss)
|
|
305
250
|
|
|
306
251
|
# Validation phase
|
|
307
|
-
val_loss = self._validate_epoch(val_loader,
|
|
252
|
+
val_loss = self._validate_epoch(val_loader, safe_loss)
|
|
308
253
|
|
|
309
254
|
# Update scheduler
|
|
310
255
|
scheduler.step()
|
|
256
|
+
current_lr = scheduler.get_last_lr()[0]
|
|
311
257
|
|
|
312
258
|
# Record history
|
|
313
259
|
history['train_loss'].append(train_loss)
|
|
314
260
|
history['val_loss'].append(val_loss)
|
|
315
|
-
history['learning_rate'].append(
|
|
316
|
-
|
|
317
|
-
# Memory usage tracking
|
|
318
|
-
if torch.cuda.is_available():
|
|
319
|
-
memory_used = torch.cuda.memory_allocated() / 1024**3 # GB
|
|
320
|
-
history['memory_usage'].append(memory_used)
|
|
321
|
-
print(f"💾 GPU Memory: {memory_used:.1f}GB / 20GB")
|
|
261
|
+
history['learning_rate'].append(current_lr)
|
|
322
262
|
|
|
323
|
-
|
|
324
|
-
|
|
263
|
+
# Print progress
|
|
264
|
+
if epoch % 5 == 0 or epoch == 1:
|
|
265
|
+
print(f"📍 Epoch {epoch:3d}/{num_epochs} | "
|
|
266
|
+
f"Train Loss: {train_loss:.4f} | "
|
|
267
|
+
f"Val Loss: {val_loss:.4f} | "
|
|
268
|
+
f"LR: {current_lr:.2e}")
|
|
325
269
|
|
|
326
|
-
#
|
|
270
|
+
# Early stopping and model saving
|
|
327
271
|
if val_loss < best_val_loss:
|
|
328
272
|
best_val_loss = val_loss
|
|
273
|
+
patience_counter = 0
|
|
329
274
|
self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
|
|
330
|
-
|
|
275
|
+
if epoch % 10 == 0:
|
|
276
|
+
print(f"💾 Best model saved (Val Loss: {best_val_loss:.4f})")
|
|
277
|
+
else:
|
|
278
|
+
patience_counter += 1
|
|
279
|
+
if patience_counter >= patience:
|
|
280
|
+
print(f"🛑 Early stopping at epoch {epoch}")
|
|
281
|
+
break
|
|
331
282
|
|
|
283
|
+
# Training completed
|
|
332
284
|
self.is_trained = True
|
|
333
285
|
self.training_history = history
|
|
334
286
|
self.best_val_loss = best_val_loss
|
|
335
287
|
|
|
336
|
-
print(f"\n🎉 Training completed!
|
|
288
|
+
print(f"\n🎉 Training completed!")
|
|
289
|
+
print(f"🏆 Best validation loss: {best_val_loss:.4f}")
|
|
290
|
+
print(f"📊 Final training loss: {history['train_loss'][-1]:.4f}")
|
|
291
|
+
|
|
337
292
|
return history
|
|
338
293
|
|
|
339
|
-
def
|
|
340
|
-
"""
|
|
294
|
+
def _validate_data_dimensions(self, latent_data, expression_data, data_type):
|
|
295
|
+
"""Validate input data dimensions"""
|
|
296
|
+
assert latent_data.shape[1] == self.latent_dim, (
|
|
297
|
+
f"{data_type} latent dimension mismatch: expected {self.latent_dim}, got {latent_data.shape[1]}")
|
|
298
|
+
assert expression_data.shape[1] == self.gene_dim, (
|
|
299
|
+
f"{data_type} gene dimension mismatch: expected {self.gene_dim}, got {expression_data.shape[1]}")
|
|
300
|
+
assert latent_data.shape[0] == expression_data.shape[0], (
|
|
301
|
+
f"{data_type} sample count mismatch: latent {latent_data.shape[0]}, expression {expression_data.shape[0]}")
|
|
302
|
+
print(f"✅ {data_type} data dimensions validated")
|
|
303
|
+
|
|
304
|
+
def _create_safe_dataset(self, latent_data, expression_data):
|
|
305
|
+
"""Create dataset with safety checks"""
|
|
306
|
+
class SafeDataset(Dataset):
|
|
307
|
+
def __init__(self, latent, expression):
|
|
308
|
+
self.latent = torch.FloatTensor(latent)
|
|
309
|
+
self.expression = torch.FloatTensor(expression)
|
|
310
|
+
|
|
311
|
+
# Safety check
|
|
312
|
+
if self.latent.shape[0] != self.expression.shape[0]:
|
|
313
|
+
raise ValueError(f"Sample count mismatch: latent {self.latent.shape[0]}, expression {self.expression.shape[0]}")
|
|
314
|
+
|
|
315
|
+
def __len__(self):
|
|
316
|
+
return len(self.latent)
|
|
317
|
+
|
|
318
|
+
def __getitem__(self, idx):
|
|
319
|
+
return self.latent[idx], self.expression[idx]
|
|
320
|
+
|
|
321
|
+
return SafeDataset(latent_data, expression_data)
|
|
322
|
+
|
|
323
|
+
def _train_epoch(self, train_loader, optimizer, loss_fn):
|
|
324
|
+
"""Train for one epoch with dimension safety"""
|
|
341
325
|
self.model.train()
|
|
342
326
|
total_loss = 0
|
|
343
327
|
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
target = target.to(self.device, non_blocking=True)
|
|
348
|
-
|
|
349
|
-
optimizer.zero_grad(set_to_none=True) # Memory optimization
|
|
328
|
+
for batch_idx, (latent, target) in enumerate(train_loader):
|
|
329
|
+
latent = latent.to(self.device)
|
|
330
|
+
target = target.to(self.device)
|
|
350
331
|
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
332
|
+
# Dimension check
|
|
333
|
+
if latent.shape[1] != self.latent_dim:
|
|
334
|
+
print(f"⚠️ Batch {batch_idx}: Latent dim mismatch {latent.shape[1]} != {self.latent_dim}")
|
|
335
|
+
continue
|
|
355
336
|
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
scaler.update()
|
|
359
|
-
else:
|
|
360
|
-
pred = self.model(latent)
|
|
361
|
-
loss = criterion(pred, target)
|
|
362
|
-
loss.backward()
|
|
363
|
-
optimizer.step()
|
|
337
|
+
optimizer.zero_grad()
|
|
338
|
+
pred = self.model(latent)
|
|
364
339
|
|
|
365
|
-
|
|
366
|
-
|
|
340
|
+
# Final dimension check before loss calculation
|
|
341
|
+
if pred.shape[1] != target.shape[1]:
|
|
342
|
+
min_dim = min(pred.shape[1], target.shape[1])
|
|
343
|
+
pred = pred[:, :min_dim]
|
|
344
|
+
target = target[:, :min_dim]
|
|
345
|
+
if batch_idx == 0: # Only warn once
|
|
346
|
+
print(f"⚠️ Truncating to min dimension: {min_dim}")
|
|
347
|
+
|
|
348
|
+
loss = loss_fn(pred, target)
|
|
349
|
+
loss.backward()
|
|
367
350
|
|
|
368
|
-
#
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
351
|
+
# Gradient clipping for stability
|
|
352
|
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
|
353
|
+
optimizer.step()
|
|
354
|
+
|
|
355
|
+
total_loss += loss.item()
|
|
372
356
|
|
|
373
357
|
return total_loss / len(train_loader)
|
|
374
358
|
|
|
375
|
-
def _validate_epoch(self, val_loader,
|
|
376
|
-
"""
|
|
359
|
+
def _validate_epoch(self, val_loader, loss_fn):
|
|
360
|
+
"""Validate for one epoch with dimension safety"""
|
|
377
361
|
self.model.eval()
|
|
378
362
|
total_loss = 0
|
|
379
363
|
|
|
380
364
|
with torch.no_grad():
|
|
381
|
-
for latent, target in val_loader:
|
|
382
|
-
latent = latent.to(self.device
|
|
383
|
-
target = target.to(self.device
|
|
365
|
+
for batch_idx, (latent, target) in enumerate(val_loader):
|
|
366
|
+
latent = latent.to(self.device)
|
|
367
|
+
target = target.to(self.device)
|
|
384
368
|
|
|
385
369
|
pred = self.model(latent)
|
|
386
|
-
loss = criterion(pred, target)
|
|
387
|
-
total_loss += loss.item()
|
|
388
370
|
|
|
389
|
-
#
|
|
390
|
-
|
|
371
|
+
# Dimension safety
|
|
372
|
+
if pred.shape[1] != target.shape[1]:
|
|
373
|
+
min_dim = min(pred.shape[1], target.shape[1])
|
|
374
|
+
pred = pred[:, :min_dim]
|
|
375
|
+
target = target[:, :min_dim]
|
|
376
|
+
|
|
377
|
+
loss = loss_fn(pred, target)
|
|
378
|
+
total_loss += loss.item()
|
|
391
379
|
|
|
392
380
|
return total_loss / len(val_loader)
|
|
393
381
|
|
|
394
|
-
def _create_dataset(self, latent_data, expression_data):
|
|
395
|
-
"""Create dataset"""
|
|
396
|
-
class EfficientDataset(Dataset):
|
|
397
|
-
def __init__(self, latent, expression):
|
|
398
|
-
self.latent = torch.FloatTensor(latent)
|
|
399
|
-
self.expression = torch.FloatTensor(expression)
|
|
400
|
-
|
|
401
|
-
def __len__(self):
|
|
402
|
-
return len(self.latent)
|
|
403
|
-
|
|
404
|
-
def __getitem__(self, idx):
|
|
405
|
-
return self.latent[idx], self.expression[idx]
|
|
406
|
-
|
|
407
|
-
return EfficientDataset(latent_data, expression_data)
|
|
408
|
-
|
|
409
382
|
def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
|
|
410
|
-
"""Save checkpoint"""
|
|
383
|
+
"""Save model checkpoint"""
|
|
411
384
|
torch.save({
|
|
412
385
|
'epoch': epoch,
|
|
413
386
|
'model_state_dict': self.model.state_dict(),
|
|
@@ -422,22 +395,26 @@ class TranscriptomeDecoder:
|
|
|
422
395
|
}
|
|
423
396
|
}, path)
|
|
424
397
|
|
|
425
|
-
def predict(self, latent_data: np.ndarray, batch_size: int =
|
|
398
|
+
def predict(self, latent_data: np.ndarray, batch_size: int = 32) -> np.ndarray:
|
|
426
399
|
"""
|
|
427
|
-
|
|
400
|
+
Predict gene expression from latent variables
|
|
428
401
|
|
|
429
402
|
Args:
|
|
430
403
|
latent_data: Latent variables [n_samples, latent_dim]
|
|
431
|
-
batch_size: Prediction batch size
|
|
404
|
+
batch_size: Prediction batch size
|
|
432
405
|
|
|
433
406
|
Returns:
|
|
434
407
|
expression: Predicted expression [n_samples, gene_dim]
|
|
435
408
|
"""
|
|
436
409
|
if not self.is_trained:
|
|
437
|
-
warnings.warn("Model not trained. Predictions may be inaccurate.")
|
|
410
|
+
warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")
|
|
438
411
|
|
|
439
412
|
self.model.eval()
|
|
440
413
|
|
|
414
|
+
# Input validation
|
|
415
|
+
if latent_data.shape[1] != self.latent_dim:
|
|
416
|
+
raise ValueError(f"Latent dimension mismatch: expected {self.latent_dim}, got {latent_data.shape[1]}")
|
|
417
|
+
|
|
441
418
|
if isinstance(latent_data, np.ndarray):
|
|
442
419
|
latent_data = torch.FloatTensor(latent_data)
|
|
443
420
|
|
|
@@ -448,76 +425,83 @@ class TranscriptomeDecoder:
|
|
|
448
425
|
batch_latent = latent_data[i:i+batch_size].to(self.device)
|
|
449
426
|
batch_pred = self.model(batch_latent)
|
|
450
427
|
predictions.append(batch_pred.cpu())
|
|
451
|
-
|
|
452
|
-
# Clear memory
|
|
453
|
-
del batch_pred
|
|
454
|
-
if torch.cuda.is_available():
|
|
455
|
-
torch.cuda.empty_cache()
|
|
456
428
|
|
|
457
429
|
return torch.cat(predictions).numpy()
|
|
458
430
|
|
|
459
431
|
def load_model(self, model_path: str):
|
|
460
432
|
"""Load pre-trained model"""
|
|
461
433
|
checkpoint = torch.load(model_path, map_location=self.device)
|
|
434
|
+
|
|
435
|
+
# Check model configuration
|
|
436
|
+
if 'model_config' in checkpoint:
|
|
437
|
+
config = checkpoint['model_config']
|
|
438
|
+
if (config['latent_dim'] != self.latent_dim or
|
|
439
|
+
config['gene_dim'] != self.gene_dim):
|
|
440
|
+
print("⚠️ Model configuration mismatch. Reinitializing model.")
|
|
441
|
+
self.model = self._build_model()
|
|
442
|
+
self.model.to(self.device)
|
|
443
|
+
|
|
462
444
|
self.model.load_state_dict(checkpoint['model_state_dict'])
|
|
463
445
|
self.is_trained = True
|
|
464
446
|
self.training_history = checkpoint.get('training_history')
|
|
465
447
|
self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
"""Get memory usage information"""
|
|
470
|
-
if torch.cuda.is_available():
|
|
471
|
-
memory_allocated = torch.cuda.memory_allocated() / 1024**3
|
|
472
|
-
memory_reserved = torch.cuda.memory_reserved() / 1024**3
|
|
473
|
-
return {
|
|
474
|
-
'allocated_gb': memory_allocated,
|
|
475
|
-
'reserved_gb': memory_reserved,
|
|
476
|
-
'available_gb': 20 - memory_allocated,
|
|
477
|
-
'utilization_percent': (memory_allocated / 20) * 100
|
|
478
|
-
}
|
|
479
|
-
return {'available_gb': 'N/A (CPU mode)'}
|
|
448
|
+
|
|
449
|
+
print(f"✅ Model loaded successfully!")
|
|
450
|
+
print(f"🏆 Best validation loss: {self.best_val_loss:.4f}")
|
|
480
451
|
|
|
452
|
+
def get_model_info(self) -> Dict:
|
|
453
|
+
"""Get model information"""
|
|
454
|
+
return {
|
|
455
|
+
'is_trained': self.is_trained,
|
|
456
|
+
'best_val_loss': self.best_val_loss,
|
|
457
|
+
'parameters': sum(p.numel() for p in self.model.parameters()),
|
|
458
|
+
'latent_dim': self.latent_dim,
|
|
459
|
+
'gene_dim': self.gene_dim,
|
|
460
|
+
'hidden_dim': self.hidden_dim,
|
|
461
|
+
'device': str(self.device)
|
|
462
|
+
}
|
|
481
463
|
'''
|
|
482
|
-
# Example usage
|
|
464
|
+
# Example usage
|
|
483
465
|
def example_usage():
|
|
484
|
-
"""
|
|
466
|
+
"""Example demonstration with dimension safety"""
|
|
485
467
|
|
|
486
|
-
# 1. Initialize
|
|
487
|
-
decoder =
|
|
468
|
+
# 1. Initialize decoder
|
|
469
|
+
decoder = SimpleTranscriptomeDecoder(
|
|
488
470
|
latent_dim=100,
|
|
489
471
|
gene_dim=2000, # Reduced for example
|
|
490
|
-
hidden_dim=256
|
|
472
|
+
hidden_dim=256
|
|
491
473
|
)
|
|
492
474
|
|
|
493
|
-
#
|
|
494
|
-
|
|
495
|
-
print(f"📊 Memory Info: {memory_info}")
|
|
496
|
-
|
|
497
|
-
# 2. Generate example data
|
|
498
|
-
n_samples = 500 # Reduced for memory
|
|
475
|
+
# 2. Generate example data with correct dimensions
|
|
476
|
+
n_samples = 1000
|
|
499
477
|
latent_data = np.random.randn(n_samples, 100).astype(np.float32)
|
|
500
|
-
|
|
478
|
+
|
|
479
|
+
# Create simulated expression data
|
|
480
|
+
weights = np.random.randn(100, 2000) * 0.1
|
|
481
|
+
expression_data = np.tanh(latent_data.dot(weights))
|
|
501
482
|
expression_data = np.maximum(expression_data, 0) # Non-negative
|
|
502
483
|
|
|
503
|
-
print(f"
|
|
484
|
+
print(f"📊 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")
|
|
504
485
|
|
|
505
|
-
# 3. Train
|
|
486
|
+
# 3. Train the model
|
|
506
487
|
history = decoder.train(
|
|
507
488
|
train_latent=latent_data,
|
|
508
489
|
train_expression=expression_data,
|
|
509
|
-
batch_size=
|
|
510
|
-
num_epochs=
|
|
490
|
+
batch_size=32,
|
|
491
|
+
num_epochs=50,
|
|
492
|
+
learning_rate=1e-4
|
|
511
493
|
)
|
|
512
494
|
|
|
513
|
-
# 4.
|
|
514
|
-
test_latent = np.random.randn(
|
|
515
|
-
predictions = decoder.predict(test_latent
|
|
495
|
+
# 4. Make predictions
|
|
496
|
+
test_latent = np.random.randn(10, 100).astype(np.float32)
|
|
497
|
+
predictions = decoder.predict(test_latent)
|
|
516
498
|
print(f"🔮 Prediction shape: {predictions.shape}")
|
|
517
499
|
|
|
518
|
-
# 5.
|
|
519
|
-
|
|
520
|
-
print(f"
|
|
500
|
+
# 5. Get model info
|
|
501
|
+
info = decoder.get_model_info()
|
|
502
|
+
print(f"\n📋 Model Info:")
|
|
503
|
+
for key, value in info.items():
|
|
504
|
+
print(f" {key}: {value}")
|
|
521
505
|
|
|
522
506
|
return decoder
|
|
523
507
|
|