SURE-tools 2.4.5__py3-none-any.whl → 2.4.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- SURE/EfficientTranscriptomeDecoder.py +607 -0
- SURE/SimpleTranscriptomeDecoder.py +567 -0
- SURE/TranscriptomeDecoder.py +273 -311
- SURE/__init__.py +6 -1
- {sure_tools-2.4.5.dist-info → sure_tools-2.4.13.dist-info}/METADATA +1 -1
- {sure_tools-2.4.5.dist-info → sure_tools-2.4.13.dist-info}/RECORD +10 -8
- {sure_tools-2.4.5.dist-info → sure_tools-2.4.13.dist-info}/WHEEL +0 -0
- {sure_tools-2.4.5.dist-info → sure_tools-2.4.13.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.4.5.dist-info → sure_tools-2.4.13.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.4.5.dist-info → sure_tools-2.4.13.dist-info}/top_level.txt +0 -0
SURE/TranscriptomeDecoder.py
CHANGED
|
@@ -1,28 +1,28 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import torch.nn as nn
|
|
3
|
-
import torch.optim as optim
|
|
4
3
|
import torch.nn.functional as F
|
|
4
|
+
import torch.optim as optim
|
|
5
5
|
from torch.utils.data import Dataset, DataLoader
|
|
6
6
|
import numpy as np
|
|
7
|
-
from typing import Dict,
|
|
8
|
-
import matplotlib.pyplot as plt
|
|
9
|
-
from tqdm import tqdm
|
|
7
|
+
from typing import Dict, Optional
|
|
10
8
|
import warnings
|
|
11
9
|
warnings.filterwarnings('ignore')
|
|
12
10
|
|
|
13
11
|
class TranscriptomeDecoder:
|
|
12
|
+
"""Transcriptome decoder"""
|
|
13
|
+
|
|
14
14
|
def __init__(self,
|
|
15
15
|
latent_dim: int = 100,
|
|
16
16
|
gene_dim: int = 60000,
|
|
17
|
-
hidden_dim: int = 512,
|
|
17
|
+
hidden_dim: int = 512,
|
|
18
18
|
device: str = None):
|
|
19
19
|
"""
|
|
20
|
-
|
|
20
|
+
Simple but powerful decoder for latent to transcriptome mapping
|
|
21
21
|
|
|
22
22
|
Args:
|
|
23
23
|
latent_dim: Latent variable dimension (typically 50-100)
|
|
24
24
|
gene_dim: Number of genes (full transcriptome ~60,000)
|
|
25
|
-
hidden_dim: Hidden dimension
|
|
25
|
+
hidden_dim: Hidden dimension optimized
|
|
26
26
|
device: Computation device
|
|
27
27
|
"""
|
|
28
28
|
self.latent_dim = latent_dim
|
|
@@ -30,10 +30,6 @@ class TranscriptomeDecoder:
|
|
|
30
30
|
self.hidden_dim = hidden_dim
|
|
31
31
|
self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
|
|
32
32
|
|
|
33
|
-
# Memory optimization settings
|
|
34
|
-
self.gradient_checkpointing = True
|
|
35
|
-
self.mixed_precision = True
|
|
36
|
-
|
|
37
33
|
# Initialize model
|
|
38
34
|
self.model = self._build_model()
|
|
39
35
|
self.model.to(self.device)
|
|
@@ -43,259 +39,163 @@ class TranscriptomeDecoder:
|
|
|
43
39
|
self.training_history = None
|
|
44
40
|
self.best_val_loss = float('inf')
|
|
45
41
|
|
|
46
|
-
print(f"🚀
|
|
42
|
+
print(f"🚀 SimpleTranscriptomeDecoder Initialized:")
|
|
47
43
|
print(f" - Latent Dimension: {latent_dim}")
|
|
48
44
|
print(f" - Gene Dimension: {gene_dim}")
|
|
49
45
|
print(f" - Hidden Dimension: {hidden_dim}")
|
|
50
46
|
print(f" - Device: {self.device}")
|
|
51
47
|
print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
|
|
52
48
|
|
|
53
|
-
class MemoryEfficientBlock(nn.Module):
|
|
54
|
-
"""Memory-efficient building block with gradient checkpointing"""
|
|
55
|
-
def __init__(self, input_dim, output_dim, use_checkpointing=True):
|
|
56
|
-
super().__init__()
|
|
57
|
-
self.use_checkpointing = use_checkpointing
|
|
58
|
-
self.net = nn.Sequential(
|
|
59
|
-
nn.Linear(input_dim, output_dim),
|
|
60
|
-
nn.BatchNorm1d(output_dim),
|
|
61
|
-
nn.GELU(),
|
|
62
|
-
nn.Dropout(0.1)
|
|
63
|
-
)
|
|
64
|
-
|
|
65
|
-
def forward(self, x):
|
|
66
|
-
if self.use_checkpointing and self.training:
|
|
67
|
-
return torch.utils.checkpoint.checkpoint(self.net, x)
|
|
68
|
-
return self.net(x)
|
|
69
|
-
|
|
70
|
-
class SparseGeneProjection(nn.Module):
|
|
71
|
-
"""Sparse gene projection to reduce memory usage"""
|
|
72
|
-
def __init__(self, latent_dim, gene_dim, projection_dim=256):
|
|
73
|
-
super().__init__()
|
|
74
|
-
self.projection_dim = projection_dim
|
|
75
|
-
self.gene_embeddings = nn.Parameter(torch.randn(gene_dim, projection_dim) * 0.02)
|
|
76
|
-
self.latent_projection = nn.Linear(latent_dim, projection_dim)
|
|
77
|
-
self.activation = nn.GELU()
|
|
78
|
-
|
|
79
|
-
def forward(self, latent):
|
|
80
|
-
# Project latent to gene space efficiently
|
|
81
|
-
batch_size = latent.shape[0]
|
|
82
|
-
latent_proj = self.latent_projection(latent) # [batch, projection_dim]
|
|
83
|
-
|
|
84
|
-
# Efficient matrix multiplication
|
|
85
|
-
gene_embeds = self.gene_embeddings.T # [projection_dim, gene_dim]
|
|
86
|
-
output = torch.matmul(latent_proj, gene_embeds) # [batch, gene_dim]
|
|
87
|
-
|
|
88
|
-
return self.activation(output)
|
|
89
|
-
|
|
90
|
-
class ChunkedTransformer(nn.Module):
|
|
91
|
-
def __init__(self, gene_dim, hidden_dim=512, chunk_size=2000, num_layers=3):
|
|
92
|
-
super().__init__()
|
|
93
|
-
self.chunk_size = chunk_size
|
|
94
|
-
self.hidden_dim = hidden_dim
|
|
95
|
-
self.num_chunks = (gene_dim + chunk_size - 1) // chunk_size
|
|
96
|
-
|
|
97
|
-
# 共享的Transformer层
|
|
98
|
-
self.transformer_layers = nn.ModuleList([
|
|
99
|
-
nn.Sequential(
|
|
100
|
-
nn.Linear(hidden_dim, hidden_dim),
|
|
101
|
-
nn.GELU(),
|
|
102
|
-
nn.Dropout(0.1),
|
|
103
|
-
nn.Linear(hidden_dim, hidden_dim),
|
|
104
|
-
) for _ in range(num_layers)
|
|
105
|
-
])
|
|
106
|
-
|
|
107
|
-
# 每个chunk独立的投影层
|
|
108
|
-
self.input_projections = nn.ModuleList([
|
|
109
|
-
nn.Linear(min(chunk_size, gene_dim - i * chunk_size), hidden_dim)
|
|
110
|
-
for i in range(self.num_chunks)
|
|
111
|
-
])
|
|
112
|
-
self.output_projections = nn.ModuleList([
|
|
113
|
-
nn.Linear(hidden_dim, min(chunk_size, gene_dim - i * chunk_size))
|
|
114
|
-
for i in range(self.num_chunks)
|
|
115
|
-
])
|
|
116
|
-
|
|
117
|
-
def forward(self, x):
|
|
118
|
-
batch_size, gene_dim = x.shape
|
|
119
|
-
output = torch.zeros_like(x)
|
|
120
|
-
|
|
121
|
-
for i in range(self.num_chunks):
|
|
122
|
-
start_idx = i * self.chunk_size
|
|
123
|
-
end_idx = min((i + 1) * self.chunk_size, gene_dim)
|
|
124
|
-
current_chunk_size = end_idx - start_idx
|
|
125
|
-
|
|
126
|
-
chunk = x[:, start_idx:end_idx] # [batch_size, current_chunk_size]
|
|
127
|
-
|
|
128
|
-
# 投影到hidden_dim
|
|
129
|
-
chunk_proj = self.input_projections[i](chunk) # [batch_size, hidden_dim]
|
|
130
|
-
|
|
131
|
-
# Transformer处理
|
|
132
|
-
for layer in self.transformer_layers:
|
|
133
|
-
chunk_proj = layer(chunk_proj) + chunk_proj
|
|
134
|
-
|
|
135
|
-
# 投影回原始维度
|
|
136
|
-
chunk_out = self.output_projections[i](chunk_proj) # [batch_size, current_chunk_size]
|
|
137
|
-
|
|
138
|
-
output[:, start_idx:end_idx] = chunk_out
|
|
139
|
-
|
|
140
|
-
return output
|
|
141
|
-
|
|
142
49
|
class Decoder(nn.Module):
|
|
143
|
-
"""
|
|
144
|
-
|
|
50
|
+
"""Memory-efficient decoder architecture with dimension handling"""
|
|
51
|
+
|
|
52
|
+
def __init__(self, latent_dim: int, gene_dim: int, hidden_dim: int):
|
|
145
53
|
super().__init__()
|
|
146
54
|
self.latent_dim = latent_dim
|
|
147
55
|
self.gene_dim = gene_dim
|
|
148
56
|
self.hidden_dim = hidden_dim
|
|
149
57
|
|
|
150
|
-
# Stage 1: Latent expansion
|
|
58
|
+
# Stage 1: Latent variable expansion
|
|
151
59
|
self.latent_expansion = nn.Sequential(
|
|
152
60
|
nn.Linear(latent_dim, hidden_dim * 2),
|
|
61
|
+
nn.BatchNorm1d(hidden_dim * 2),
|
|
153
62
|
nn.GELU(),
|
|
154
63
|
nn.Dropout(0.1),
|
|
155
64
|
nn.Linear(hidden_dim * 2, hidden_dim),
|
|
65
|
+
nn.BatchNorm1d(hidden_dim),
|
|
66
|
+
nn.GELU(),
|
|
156
67
|
)
|
|
157
68
|
|
|
158
|
-
# Stage 2:
|
|
159
|
-
self.
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
self.chunked_processor = TranscriptomeDecoder.ChunkedTransformer(
|
|
165
|
-
gene_dim, hidden_dim, chunk_size=2000, num_layers=3
|
|
69
|
+
# Stage 2: Direct projection to gene dimension (simpler approach)
|
|
70
|
+
self.gene_projector = nn.Sequential(
|
|
71
|
+
nn.Linear(hidden_dim, hidden_dim * 2),
|
|
72
|
+
nn.GELU(),
|
|
73
|
+
nn.Dropout(0.1),
|
|
74
|
+
nn.Linear(hidden_dim * 2, gene_dim), # Direct projection to gene_dim
|
|
166
75
|
)
|
|
167
76
|
|
|
168
|
-
# Stage
|
|
169
|
-
self.
|
|
170
|
-
nn.
|
|
171
|
-
nn.Linear(hidden_dim, hidden_dim // 2),
|
|
172
|
-
nn.GELU(),
|
|
173
|
-
nn.Linear(hidden_dim // 2, 1)
|
|
174
|
-
) for _ in range(2) # Reduced from 3 to 2 heads
|
|
175
|
-
])
|
|
176
|
-
|
|
177
|
-
# Adaptive fusion
|
|
178
|
-
self.fusion_gate = nn.Sequential(
|
|
179
|
-
nn.Linear(hidden_dim, hidden_dim // 4),
|
|
77
|
+
# Stage 3: Lightweight gene interaction
|
|
78
|
+
self.gene_interaction = nn.Sequential(
|
|
79
|
+
nn.Conv1d(1, 32, kernel_size=3, padding=1),
|
|
180
80
|
nn.GELU(),
|
|
181
|
-
nn.
|
|
182
|
-
nn.
|
|
81
|
+
nn.Dropout1d(0.1),
|
|
82
|
+
nn.Conv1d(32, 1, kernel_size=3, padding=1),
|
|
183
83
|
)
|
|
184
84
|
|
|
185
85
|
# Output scaling
|
|
186
86
|
self.output_scale = nn.Parameter(torch.ones(1))
|
|
187
87
|
self.output_bias = nn.Parameter(torch.zeros(1))
|
|
188
88
|
|
|
189
|
-
self.latent_to_gene = nn.Linear(hidden_dim, gene_dim)
|
|
190
|
-
|
|
191
89
|
self._init_weights()
|
|
192
90
|
|
|
193
91
|
def _init_weights(self):
|
|
92
|
+
"""Weight initialization"""
|
|
194
93
|
for module in self.modules():
|
|
195
94
|
if isinstance(module, nn.Linear):
|
|
196
95
|
nn.init.xavier_uniform_(module.weight)
|
|
197
96
|
if module.bias is not None:
|
|
198
97
|
nn.init.zeros_(module.bias)
|
|
98
|
+
elif isinstance(module, nn.Conv1d):
|
|
99
|
+
nn.init.kaiming_uniform_(module.weight)
|
|
100
|
+
if module.bias is not None:
|
|
101
|
+
nn.init.zeros_(module.bias)
|
|
199
102
|
|
|
200
|
-
def forward(self, latent):
|
|
103
|
+
def forward(self, latent: torch.Tensor) -> torch.Tensor:
|
|
201
104
|
batch_size = latent.shape[0]
|
|
202
105
|
|
|
203
|
-
# 1.
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
# 2. Gene projection (memory efficient)
|
|
207
|
-
gene_features = self.gene_projection(latent)
|
|
106
|
+
# 1. Expand latent variables
|
|
107
|
+
latent_features = self.latent_expansion(latent) # [batch_size, hidden_dim]
|
|
208
108
|
|
|
209
|
-
#
|
|
210
|
-
|
|
211
|
-
gene_features = gene_features + latent_gene_injection
|
|
109
|
+
# 2. Direct projection to gene dimension
|
|
110
|
+
gene_output = self.gene_projector(latent_features) # [batch_size, gene_dim]
|
|
212
111
|
|
|
213
|
-
#
|
|
214
|
-
|
|
112
|
+
# 3. Gene interaction with dimension safety
|
|
113
|
+
if self.gene_dim > 1: # Only apply if gene_dim > 1
|
|
114
|
+
gene_output = gene_output.unsqueeze(1) # [batch_size, 1, gene_dim]
|
|
115
|
+
interaction_output = self.gene_interaction(gene_output) # [batch_size, 1, gene_dim]
|
|
116
|
+
gene_output = gene_output + interaction_output # Residual connection
|
|
117
|
+
gene_output = gene_output.squeeze(1) # [batch_size, gene_dim]
|
|
215
118
|
|
|
216
|
-
#
|
|
217
|
-
|
|
119
|
+
# 4. Final activation (ensure non-negative)
|
|
120
|
+
gene_output = F.softplus(gene_output * self.output_scale + self.output_bias)
|
|
218
121
|
|
|
219
|
-
|
|
220
|
-
chunk_size = 5000
|
|
221
|
-
for i in range(0, self.gene_dim, chunk_size):
|
|
222
|
-
end_idx = min(i + chunk_size, self.gene_dim)
|
|
223
|
-
chunk = gene_features[:, i:end_idx]
|
|
224
|
-
|
|
225
|
-
head_outputs = []
|
|
226
|
-
for head in self.output_heads:
|
|
227
|
-
head_out = head(chunk).squeeze(-1)
|
|
228
|
-
head_outputs.append(head_out)
|
|
229
|
-
|
|
230
|
-
# Adaptive fusion
|
|
231
|
-
gate_weights = self.fusion_gate(chunk.mean(dim=1, keepdim=True))
|
|
232
|
-
gate_weights = gate_weights.unsqueeze(1)
|
|
233
|
-
|
|
234
|
-
# Weighted fusion
|
|
235
|
-
chunk_output = torch.zeros_like(head_outputs[0])
|
|
236
|
-
for j, head_out in enumerate(head_outputs):
|
|
237
|
-
chunk_output = chunk_output + gate_weights[:, :, j] * head_out
|
|
238
|
-
|
|
239
|
-
final_output[:, i:end_idx] = chunk_output
|
|
240
|
-
|
|
241
|
-
# Final activation
|
|
242
|
-
final_output = F.softplus(final_output * self.output_scale + self.output_bias)
|
|
243
|
-
|
|
244
|
-
return final_output
|
|
122
|
+
return gene_output
|
|
245
123
|
|
|
246
124
|
def _build_model(self):
|
|
247
|
-
"""Build model"""
|
|
125
|
+
"""Build the decoder model"""
|
|
248
126
|
return self.Decoder(self.latent_dim, self.gene_dim, self.hidden_dim)
|
|
249
127
|
|
|
128
|
+
def _create_dataset(self, latent_data, expression_data):
|
|
129
|
+
"""Create dataset with dimension validation"""
|
|
130
|
+
class SimpleDataset(Dataset):
|
|
131
|
+
def __init__(self, latent, expression):
|
|
132
|
+
# Ensure dimensions match
|
|
133
|
+
assert latent.shape[0] == expression.shape[0], "Sample count mismatch"
|
|
134
|
+
assert latent.shape[1] == self.latent_dim, f"Latent dim mismatch: expected {self.latent_dim}, got {latent.shape[1]}"
|
|
135
|
+
assert expression.shape[1] == self.gene_dim, f"Gene dim mismatch: expected {self.gene_dim}, got {expression.shape[1]}"
|
|
136
|
+
|
|
137
|
+
self.latent = torch.FloatTensor(latent)
|
|
138
|
+
self.expression = torch.FloatTensor(expression)
|
|
139
|
+
|
|
140
|
+
def __len__(self):
|
|
141
|
+
return len(self.latent)
|
|
142
|
+
|
|
143
|
+
def __getitem__(self, idx):
|
|
144
|
+
return self.latent[idx], self.expression[idx]
|
|
145
|
+
|
|
146
|
+
return SimpleDataset(latent_data, expression_data)
|
|
147
|
+
|
|
250
148
|
def train(self,
|
|
251
149
|
train_latent: np.ndarray,
|
|
252
150
|
train_expression: np.ndarray,
|
|
253
151
|
val_latent: np.ndarray = None,
|
|
254
152
|
val_expression: np.ndarray = None,
|
|
255
|
-
batch_size: int =
|
|
153
|
+
batch_size: int = 32,
|
|
256
154
|
num_epochs: int = 100,
|
|
257
155
|
learning_rate: float = 1e-4,
|
|
258
156
|
checkpoint_path: str = 'transcriptome_decoder.pth'):
|
|
259
157
|
"""
|
|
260
|
-
|
|
158
|
+
Train the decoder model with dimension safety
|
|
261
159
|
|
|
262
160
|
Args:
|
|
263
|
-
train_latent: Training latent variables
|
|
264
|
-
train_expression: Training expression data
|
|
265
|
-
val_latent: Validation latent variables
|
|
266
|
-
val_expression: Validation expression data
|
|
267
|
-
batch_size:
|
|
161
|
+
train_latent: Training latent variables [n_samples, latent_dim]
|
|
162
|
+
train_expression: Training expression data [n_samples, gene_dim]
|
|
163
|
+
val_latent: Validation latent variables (optional)
|
|
164
|
+
val_expression: Validation expression data (optional)
|
|
165
|
+
batch_size: Batch size optimized for memory
|
|
268
166
|
num_epochs: Number of training epochs
|
|
269
167
|
learning_rate: Learning rate
|
|
270
|
-
checkpoint_path:
|
|
168
|
+
checkpoint_path: Path to save the best model
|
|
271
169
|
"""
|
|
272
|
-
print("🚀 Starting
|
|
273
|
-
print(f"📊 Batch size: {batch_size}")
|
|
170
|
+
print("🚀 Starting training...")
|
|
274
171
|
|
|
275
|
-
#
|
|
276
|
-
|
|
277
|
-
if
|
|
278
|
-
|
|
172
|
+
# Dimension validation
|
|
173
|
+
self._validate_data_dimensions(train_latent, train_expression, "Training")
|
|
174
|
+
if val_latent is not None and val_expression is not None:
|
|
175
|
+
self._validate_data_dimensions(val_latent, val_expression, "Validation")
|
|
279
176
|
|
|
280
177
|
# Data preparation
|
|
281
|
-
train_dataset = self.
|
|
178
|
+
train_dataset = self._create_safe_dataset(train_latent, train_expression)
|
|
282
179
|
|
|
283
180
|
if val_latent is not None and val_expression is not None:
|
|
284
|
-
val_dataset = self.
|
|
181
|
+
val_dataset = self._create_safe_dataset(val_latent, val_expression)
|
|
182
|
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
|
|
183
|
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
|
|
184
|
+
print(f"📈 Using provided validation data: {len(val_dataset)} samples")
|
|
285
185
|
else:
|
|
286
186
|
# Auto split
|
|
287
187
|
train_size = int(0.9 * len(train_dataset))
|
|
288
188
|
val_size = len(train_dataset) - train_size
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
)
|
|
189
|
+
train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
|
|
190
|
+
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=2)
|
|
191
|
+
val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=2)
|
|
192
|
+
print(f"📈 Auto-split validation: {val_size} samples")
|
|
292
193
|
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
pin_memory=True, num_workers=2)
|
|
194
|
+
print(f"📊 Training samples: {len(train_loader.dataset)}")
|
|
195
|
+
print(f"📊 Validation samples: {len(val_loader.dataset)}")
|
|
196
|
+
print(f"📊 Batch size: {batch_size}")
|
|
297
197
|
|
|
298
|
-
# Optimizer
|
|
198
|
+
# Optimizer configuration
|
|
299
199
|
optimizer = optim.AdamW(
|
|
300
200
|
self.model.parameters(),
|
|
301
201
|
lr=learning_rate,
|
|
@@ -306,130 +206,181 @@ class TranscriptomeDecoder:
|
|
|
306
206
|
# Learning rate scheduler
|
|
307
207
|
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
|
|
308
208
|
|
|
309
|
-
# Loss function
|
|
310
|
-
|
|
209
|
+
# Loss function with dimension safety
|
|
210
|
+
def safe_loss(pred, target):
    """Combined MSE + Poisson-style + correlation loss.

    When ``pred`` and ``target`` differ in shape, a warning is printed and
    both are truncated along axis 1 to the smaller width before the loss is
    computed.

    Returns:
        ``mse + 0.5 * poisson + 0.3 * (1 - mean per-row correlation)``.
    """
    if pred.shape != target.shape:
        print(f"⚠️ Dimension mismatch: pred {pred.shape}, target {target.shape}")
        # Safety measure: compare only the columns both tensors have.
        shared_width = min(pred.shape[1], target.shape[1])
        pred, target = pred[:, :shared_width], target[:, :shared_width]

    # Per-row correlation between prediction and target (rows = samples).
    pred_dev = pred - pred.mean(dim=1, keepdim=True)
    target_dev = target - target.mean(dim=1, keepdim=True)
    covariance = (pred_dev * target_dev).sum(dim=1)
    norm_product = (
        torch.sqrt(torch.sum(pred_dev ** 2, dim=1)) *
        torch.sqrt(torch.sum(target_dev ** 2, dim=1)) + 1e-8  # avoid divide-by-zero
    )
    corr_term = 1 - (covariance / norm_product).mean()

    mse_term = F.mse_loss(pred, target)
    # 1e-8 keeps the logarithm finite when a prediction is exactly zero.
    poisson_term = (pred - target * torch.log(pred + 1e-8)).mean()
    return mse_term + 0.5 * poisson_term + 0.3 * corr_term
|
|
311
234
|
|
|
312
235
|
# Training history
|
|
313
236
|
history = {
|
|
314
|
-
'train_loss': [],
|
|
315
|
-
'
|
|
237
|
+
'train_loss': [],
|
|
238
|
+
'val_loss': [],
|
|
239
|
+
'learning_rate': []
|
|
316
240
|
}
|
|
317
241
|
|
|
318
242
|
best_val_loss = float('inf')
|
|
243
|
+
patience = 15
|
|
244
|
+
patience_counter = 0
|
|
319
245
|
|
|
246
|
+
print("\n📈 Starting training loop...")
|
|
320
247
|
for epoch in range(1, num_epochs + 1):
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
# Training phase with memory monitoring
|
|
324
|
-
train_loss = self._train_epoch(
|
|
325
|
-
train_loader, optimizer, criterion, scaler if self.mixed_precision else None
|
|
326
|
-
)
|
|
248
|
+
# Training phase
|
|
249
|
+
train_loss = self._train_epoch(train_loader, optimizer, safe_loss)
|
|
327
250
|
|
|
328
251
|
# Validation phase
|
|
329
|
-
val_loss = self._validate_epoch(val_loader,
|
|
252
|
+
val_loss = self._validate_epoch(val_loader, safe_loss)
|
|
330
253
|
|
|
331
254
|
# Update scheduler
|
|
332
255
|
scheduler.step()
|
|
256
|
+
current_lr = scheduler.get_last_lr()[0]
|
|
333
257
|
|
|
334
258
|
# Record history
|
|
335
259
|
history['train_loss'].append(train_loss)
|
|
336
260
|
history['val_loss'].append(val_loss)
|
|
337
|
-
history['learning_rate'].append(
|
|
338
|
-
|
|
339
|
-
# Memory usage tracking
|
|
340
|
-
if torch.cuda.is_available():
|
|
341
|
-
memory_used = torch.cuda.memory_allocated() / 1024**3 # GB
|
|
342
|
-
history['memory_usage'].append(memory_used)
|
|
343
|
-
print(f"💾 GPU Memory: {memory_used:.1f}GB / 20GB")
|
|
261
|
+
history['learning_rate'].append(current_lr)
|
|
344
262
|
|
|
345
|
-
|
|
346
|
-
|
|
263
|
+
# Print progress
|
|
264
|
+
if epoch % 5 == 0 or epoch == 1:
|
|
265
|
+
print(f"📍 Epoch {epoch:3d}/{num_epochs} | "
|
|
266
|
+
f"Train Loss: {train_loss:.4f} | "
|
|
267
|
+
f"Val Loss: {val_loss:.4f} | "
|
|
268
|
+
f"LR: {current_lr:.2e}")
|
|
347
269
|
|
|
348
|
-
#
|
|
270
|
+
# Early stopping and model saving
|
|
349
271
|
if val_loss < best_val_loss:
|
|
350
272
|
best_val_loss = val_loss
|
|
273
|
+
patience_counter = 0
|
|
351
274
|
self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
|
|
352
|
-
|
|
275
|
+
if epoch % 10 == 0:
|
|
276
|
+
print(f"💾 Best model saved (Val Loss: {best_val_loss:.4f})")
|
|
277
|
+
else:
|
|
278
|
+
patience_counter += 1
|
|
279
|
+
if patience_counter >= patience:
|
|
280
|
+
print(f"🛑 Early stopping at epoch {epoch}")
|
|
281
|
+
break
|
|
353
282
|
|
|
283
|
+
# Training completed
|
|
354
284
|
self.is_trained = True
|
|
355
285
|
self.training_history = history
|
|
356
286
|
self.best_val_loss = best_val_loss
|
|
357
287
|
|
|
358
|
-
print(f"\n🎉 Training completed!
|
|
288
|
+
print(f"\n🎉 Training completed!")
|
|
289
|
+
print(f"🏆 Best validation loss: {best_val_loss:.4f}")
|
|
290
|
+
print(f"📊 Final training loss: {history['train_loss'][-1]:.4f}")
|
|
291
|
+
|
|
359
292
|
return history
|
|
360
293
|
|
|
361
|
-
def
|
|
362
|
-
"""
|
|
294
|
+
def _validate_data_dimensions(self, latent_data, expression_data, data_type):
|
|
295
|
+
"""Validate input data dimensions"""
|
|
296
|
+
assert latent_data.shape[1] == self.latent_dim, (
|
|
297
|
+
f"{data_type} latent dimension mismatch: expected {self.latent_dim}, got {latent_data.shape[1]}")
|
|
298
|
+
assert expression_data.shape[1] == self.gene_dim, (
|
|
299
|
+
f"{data_type} gene dimension mismatch: expected {self.gene_dim}, got {expression_data.shape[1]}")
|
|
300
|
+
assert latent_data.shape[0] == expression_data.shape[0], (
|
|
301
|
+
f"{data_type} sample count mismatch: latent {latent_data.shape[0]}, expression {expression_data.shape[0]}")
|
|
302
|
+
print(f"✅ {data_type} data dimensions validated")
|
|
303
|
+
|
|
304
|
+
def _create_safe_dataset(self, latent_data, expression_data):
|
|
305
|
+
"""Create dataset with safety checks"""
|
|
306
|
+
class SafeDataset(Dataset):
|
|
307
|
+
def __init__(self, latent, expression):
|
|
308
|
+
self.latent = torch.FloatTensor(latent)
|
|
309
|
+
self.expression = torch.FloatTensor(expression)
|
|
310
|
+
|
|
311
|
+
# Safety check
|
|
312
|
+
if self.latent.shape[0] != self.expression.shape[0]:
|
|
313
|
+
raise ValueError(f"Sample count mismatch: latent {self.latent.shape[0]}, expression {self.expression.shape[0]}")
|
|
314
|
+
|
|
315
|
+
def __len__(self):
|
|
316
|
+
return len(self.latent)
|
|
317
|
+
|
|
318
|
+
def __getitem__(self, idx):
|
|
319
|
+
return self.latent[idx], self.expression[idx]
|
|
320
|
+
|
|
321
|
+
return SafeDataset(latent_data, expression_data)
|
|
322
|
+
|
|
323
|
+
def _train_epoch(self, train_loader, optimizer, loss_fn):
|
|
324
|
+
"""Train for one epoch with dimension safety"""
|
|
363
325
|
self.model.train()
|
|
364
326
|
total_loss = 0
|
|
365
327
|
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
target = target.to(self.device, non_blocking=True)
|
|
370
|
-
|
|
371
|
-
optimizer.zero_grad(set_to_none=True) # Memory optimization
|
|
328
|
+
for batch_idx, (latent, target) in enumerate(train_loader):
|
|
329
|
+
latent = latent.to(self.device)
|
|
330
|
+
target = target.to(self.device)
|
|
372
331
|
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
332
|
+
# Dimension check
|
|
333
|
+
if latent.shape[1] != self.latent_dim:
|
|
334
|
+
print(f"⚠️ Batch {batch_idx}: Latent dim mismatch {latent.shape[1]} != {self.latent_dim}")
|
|
335
|
+
continue
|
|
377
336
|
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
scaler.update()
|
|
381
|
-
else:
|
|
382
|
-
pred = self.model(latent)
|
|
383
|
-
loss = criterion(pred, target)
|
|
384
|
-
loss.backward()
|
|
385
|
-
optimizer.step()
|
|
337
|
+
optimizer.zero_grad()
|
|
338
|
+
pred = self.model(latent)
|
|
386
339
|
|
|
387
|
-
|
|
388
|
-
|
|
340
|
+
# Final dimension check before loss calculation
|
|
341
|
+
if pred.shape[1] != target.shape[1]:
|
|
342
|
+
min_dim = min(pred.shape[1], target.shape[1])
|
|
343
|
+
pred = pred[:, :min_dim]
|
|
344
|
+
target = target[:, :min_dim]
|
|
345
|
+
if batch_idx == 0: # Only warn once
|
|
346
|
+
print(f"⚠️ Truncating to min dimension: {min_dim}")
|
|
389
347
|
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
348
|
+
loss = loss_fn(pred, target)
|
|
349
|
+
loss.backward()
|
|
350
|
+
|
|
351
|
+
# Gradient clipping for stability
|
|
352
|
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
|
353
|
+
optimizer.step()
|
|
354
|
+
|
|
355
|
+
total_loss += loss.item()
|
|
394
356
|
|
|
395
357
|
return total_loss / len(train_loader)
|
|
396
358
|
|
|
397
|
-
def _validate_epoch(self, val_loader,
|
|
398
|
-
"""
|
|
359
|
+
def _validate_epoch(self, val_loader, loss_fn):
|
|
360
|
+
"""Validate for one epoch with dimension safety"""
|
|
399
361
|
self.model.eval()
|
|
400
362
|
total_loss = 0
|
|
401
363
|
|
|
402
364
|
with torch.no_grad():
|
|
403
|
-
for latent, target in val_loader:
|
|
404
|
-
latent = latent.to(self.device
|
|
405
|
-
target = target.to(self.device
|
|
365
|
+
for batch_idx, (latent, target) in enumerate(val_loader):
|
|
366
|
+
latent = latent.to(self.device)
|
|
367
|
+
target = target.to(self.device)
|
|
406
368
|
|
|
407
369
|
pred = self.model(latent)
|
|
408
|
-
loss = criterion(pred, target)
|
|
409
|
-
total_loss += loss.item()
|
|
410
370
|
|
|
411
|
-
#
|
|
412
|
-
|
|
371
|
+
# Dimension safety
|
|
372
|
+
if pred.shape[1] != target.shape[1]:
|
|
373
|
+
min_dim = min(pred.shape[1], target.shape[1])
|
|
374
|
+
pred = pred[:, :min_dim]
|
|
375
|
+
target = target[:, :min_dim]
|
|
376
|
+
|
|
377
|
+
loss = loss_fn(pred, target)
|
|
378
|
+
total_loss += loss.item()
|
|
413
379
|
|
|
414
380
|
return total_loss / len(val_loader)
|
|
415
381
|
|
|
416
|
-
def _create_dataset(self, latent_data, expression_data):
|
|
417
|
-
"""Create dataset"""
|
|
418
|
-
class EfficientDataset(Dataset):
|
|
419
|
-
def __init__(self, latent, expression):
|
|
420
|
-
self.latent = torch.FloatTensor(latent)
|
|
421
|
-
self.expression = torch.FloatTensor(expression)
|
|
422
|
-
|
|
423
|
-
def __len__(self):
|
|
424
|
-
return len(self.latent)
|
|
425
|
-
|
|
426
|
-
def __getitem__(self, idx):
|
|
427
|
-
return self.latent[idx], self.expression[idx]
|
|
428
|
-
|
|
429
|
-
return EfficientDataset(latent_data, expression_data)
|
|
430
|
-
|
|
431
382
|
def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
|
|
432
|
-
"""Save checkpoint"""
|
|
383
|
+
"""Save model checkpoint"""
|
|
433
384
|
torch.save({
|
|
434
385
|
'epoch': epoch,
|
|
435
386
|
'model_state_dict': self.model.state_dict(),
|
|
@@ -444,22 +395,26 @@ class TranscriptomeDecoder:
|
|
|
444
395
|
}
|
|
445
396
|
}, path)
|
|
446
397
|
|
|
447
|
-
def predict(self, latent_data: np.ndarray, batch_size: int =
|
|
398
|
+
def predict(self, latent_data: np.ndarray, batch_size: int = 32) -> np.ndarray:
|
|
448
399
|
"""
|
|
449
|
-
|
|
400
|
+
Predict gene expression from latent variables
|
|
450
401
|
|
|
451
402
|
Args:
|
|
452
403
|
latent_data: Latent variables [n_samples, latent_dim]
|
|
453
|
-
batch_size: Prediction batch size
|
|
404
|
+
batch_size: Prediction batch size
|
|
454
405
|
|
|
455
406
|
Returns:
|
|
456
407
|
expression: Predicted expression [n_samples, gene_dim]
|
|
457
408
|
"""
|
|
458
409
|
if not self.is_trained:
|
|
459
|
-
warnings.warn("Model not trained. Predictions may be inaccurate.")
|
|
410
|
+
warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")
|
|
460
411
|
|
|
461
412
|
self.model.eval()
|
|
462
413
|
|
|
414
|
+
# Input validation
|
|
415
|
+
if latent_data.shape[1] != self.latent_dim:
|
|
416
|
+
raise ValueError(f"Latent dimension mismatch: expected {self.latent_dim}, got {latent_data.shape[1]}")
|
|
417
|
+
|
|
463
418
|
if isinstance(latent_data, np.ndarray):
|
|
464
419
|
latent_data = torch.FloatTensor(latent_data)
|
|
465
420
|
|
|
@@ -470,76 +425,83 @@ class TranscriptomeDecoder:
|
|
|
470
425
|
batch_latent = latent_data[i:i+batch_size].to(self.device)
|
|
471
426
|
batch_pred = self.model(batch_latent)
|
|
472
427
|
predictions.append(batch_pred.cpu())
|
|
473
|
-
|
|
474
|
-
# Clear memory
|
|
475
|
-
del batch_pred
|
|
476
|
-
if torch.cuda.is_available():
|
|
477
|
-
torch.cuda.empty_cache()
|
|
478
428
|
|
|
479
429
|
return torch.cat(predictions).numpy()
|
|
480
430
|
|
|
481
431
|
def load_model(self, model_path: str):
|
|
482
432
|
"""Load pre-trained model"""
|
|
483
433
|
checkpoint = torch.load(model_path, map_location=self.device)
|
|
434
|
+
|
|
435
|
+
# Check model configuration
|
|
436
|
+
if 'model_config' in checkpoint:
|
|
437
|
+
config = checkpoint['model_config']
|
|
438
|
+
if (config['latent_dim'] != self.latent_dim or
|
|
439
|
+
config['gene_dim'] != self.gene_dim):
|
|
440
|
+
print("⚠️ Model configuration mismatch. Reinitializing model.")
|
|
441
|
+
self.model = self._build_model()
|
|
442
|
+
self.model.to(self.device)
|
|
443
|
+
|
|
484
444
|
self.model.load_state_dict(checkpoint['model_state_dict'])
|
|
485
445
|
self.is_trained = True
|
|
486
446
|
self.training_history = checkpoint.get('training_history')
|
|
487
447
|
self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
"""Get memory usage information"""
|
|
492
|
-
if torch.cuda.is_available():
|
|
493
|
-
memory_allocated = torch.cuda.memory_allocated() / 1024**3
|
|
494
|
-
memory_reserved = torch.cuda.memory_reserved() / 1024**3
|
|
495
|
-
return {
|
|
496
|
-
'allocated_gb': memory_allocated,
|
|
497
|
-
'reserved_gb': memory_reserved,
|
|
498
|
-
'available_gb': 20 - memory_allocated,
|
|
499
|
-
'utilization_percent': (memory_allocated / 20) * 100
|
|
500
|
-
}
|
|
501
|
-
return {'available_gb': 'N/A (CPU mode)'}
|
|
448
|
+
|
|
449
|
+
print(f"✅ Model loaded successfully!")
|
|
450
|
+
print(f"🏆 Best validation loss: {self.best_val_loss:.4f}")
|
|
502
451
|
|
|
452
|
+
def get_model_info(self) -> Dict:
    """Return a summary of the decoder's configuration and training state.

    Returns:
        A dict with the keys ``is_trained``, ``best_val_loss``,
        ``parameters`` (total element count over ``self.model.parameters()``),
        ``latent_dim``, ``gene_dim``, ``hidden_dim`` and ``device`` (the
        device converted to a string).
    """
    summary = {}
    summary['is_trained'] = self.is_trained
    summary['best_val_loss'] = self.best_val_loss

    # Total number of scalar elements across all model parameters.
    element_counts = [weights.numel() for weights in self.model.parameters()]
    summary['parameters'] = sum(element_counts)

    for field in ('latent_dim', 'gene_dim', 'hidden_dim'):
        summary[field] = getattr(self, field)
    summary['device'] = str(self.device)
    return summary
|
|
503
463
|
'''
|
|
504
|
-
# Example usage
|
|
464
|
+
# Example usage
|
|
505
465
|
def example_usage():
|
|
506
|
-
"""
|
|
466
|
+
"""Example demonstration with dimension safety"""
|
|
507
467
|
|
|
508
|
-
# 1. Initialize
|
|
509
|
-
decoder =
|
|
468
|
+
# 1. Initialize decoder
|
|
469
|
+
decoder = SimpleTranscriptomeDecoder(
|
|
510
470
|
latent_dim=100,
|
|
511
471
|
gene_dim=2000, # Reduced for example
|
|
512
|
-
hidden_dim=256
|
|
472
|
+
hidden_dim=256
|
|
513
473
|
)
|
|
514
474
|
|
|
515
|
-
#
|
|
516
|
-
|
|
517
|
-
print(f"📊 Memory Info: {memory_info}")
|
|
518
|
-
|
|
519
|
-
# 2. Generate example data
|
|
520
|
-
n_samples = 500 # Reduced for memory
|
|
475
|
+
# 2. Generate example data with correct dimensions
|
|
476
|
+
n_samples = 1000
|
|
521
477
|
latent_data = np.random.randn(n_samples, 100).astype(np.float32)
|
|
522
|
-
|
|
478
|
+
|
|
479
|
+
# Create simulated expression data
|
|
480
|
+
weights = np.random.randn(100, 2000) * 0.1
|
|
481
|
+
expression_data = np.tanh(latent_data.dot(weights))
|
|
523
482
|
expression_data = np.maximum(expression_data, 0) # Non-negative
|
|
524
483
|
|
|
525
|
-
print(f"
|
|
484
|
+
print(f"📊 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")
|
|
526
485
|
|
|
527
|
-
# 3. Train
|
|
486
|
+
# 3. Train the model
|
|
528
487
|
history = decoder.train(
|
|
529
488
|
train_latent=latent_data,
|
|
530
489
|
train_expression=expression_data,
|
|
531
|
-
batch_size=
|
|
532
|
-
num_epochs=
|
|
490
|
+
batch_size=32,
|
|
491
|
+
num_epochs=50,
|
|
492
|
+
learning_rate=1e-4
|
|
533
493
|
)
|
|
534
494
|
|
|
535
|
-
# 4.
|
|
536
|
-
test_latent = np.random.randn(
|
|
537
|
-
predictions = decoder.predict(test_latent
|
|
495
|
+
# 4. Make predictions
|
|
496
|
+
test_latent = np.random.randn(10, 100).astype(np.float32)
|
|
497
|
+
predictions = decoder.predict(test_latent)
|
|
538
498
|
print(f"🔮 Prediction shape: {predictions.shape}")
|
|
539
499
|
|
|
540
|
-
# 5.
|
|
541
|
-
|
|
542
|
-
print(f"
|
|
500
|
+
# 5. Get model info
|
|
501
|
+
info = decoder.get_model_info()
|
|
502
|
+
print(f"\n📋 Model Info:")
|
|
503
|
+
for key, value in info.items():
|
|
504
|
+
print(f" {key}: {value}")
|
|
543
505
|
|
|
544
506
|
return decoder
|
|
545
507
|
|