SURE-tools 2.4.13__py3-none-any.whl → 2.4.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- SURE/DensityFlow.py +7 -0
- SURE/EfficientTranscriptomeDecoder.py +211 -266
- SURE/PerturbationAwareDecoder.py +723 -0
- SURE/VirtualCellDecoder.py +658 -0
- SURE/__init__.py +6 -1
- {sure_tools-2.4.13.dist-info → sure_tools-2.4.22.dist-info}/METADATA +1 -1
- {sure_tools-2.4.13.dist-info → sure_tools-2.4.22.dist-info}/RECORD +11 -9
- {sure_tools-2.4.13.dist-info → sure_tools-2.4.22.dist-info}/WHEEL +0 -0
- {sure_tools-2.4.13.dist-info → sure_tools-2.4.22.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.4.13.dist-info → sure_tools-2.4.22.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.4.13.dist-info → sure_tools-2.4.22.dist-info}/top_level.txt +0 -0
|
@@ -12,7 +12,7 @@ warnings.filterwarnings('ignore')
|
|
|
12
12
|
class EfficientTranscriptomeDecoder:
|
|
13
13
|
"""
|
|
14
14
|
High-performance, memory-efficient transcriptome decoder
|
|
15
|
-
|
|
15
|
+
Fixed version with corrected RMSNorm implementation
|
|
16
16
|
"""
|
|
17
17
|
|
|
18
18
|
def __init__(self,
|
|
@@ -20,7 +20,7 @@ class EfficientTranscriptomeDecoder:
|
|
|
20
20
|
gene_dim: int = 60000,
|
|
21
21
|
hidden_dims: List[int] = [512, 1024, 2048],
|
|
22
22
|
bottleneck_dim: int = 256,
|
|
23
|
-
num_experts: int =
|
|
23
|
+
num_experts: int = 4,
|
|
24
24
|
dropout_rate: float = 0.1,
|
|
25
25
|
device: str = None):
|
|
26
26
|
"""
|
|
@@ -43,8 +43,8 @@ class EfficientTranscriptomeDecoder:
|
|
|
43
43
|
self.dropout_rate = dropout_rate
|
|
44
44
|
self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
|
|
45
45
|
|
|
46
|
-
# Initialize model with
|
|
47
|
-
self.model = self.
|
|
46
|
+
# Initialize model with corrected architecture
|
|
47
|
+
self.model = self._build_corrected_model()
|
|
48
48
|
self.model.to(self.device)
|
|
49
49
|
|
|
50
50
|
# Training state
|
|
@@ -58,74 +58,50 @@ class EfficientTranscriptomeDecoder:
|
|
|
58
58
|
print(f" - Hidden Dimensions: {hidden_dims}")
|
|
59
59
|
print(f" - Bottleneck Dimension: {bottleneck_dim}")
|
|
60
60
|
print(f" - Number of Experts: {num_experts}")
|
|
61
|
-
print(f" - Estimated GPU Memory: ~6-8GB")
|
|
62
61
|
print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
|
|
63
62
|
|
|
64
|
-
class
|
|
65
|
-
"""
|
|
66
|
-
def forward(self, x):
|
|
67
|
-
x, gate = x.chunk(2, dim=-1)
|
|
68
|
-
return x * F.silu(gate)
|
|
69
|
-
|
|
70
|
-
class RMSNorm(nn.Module):
|
|
71
|
-
"""RMS Normalization - more stable than LayerNorm (GPT-3)"""
|
|
63
|
+
class CorrectedRMSNorm(nn.Module):
|
|
64
|
+
"""Corrected RMS Normalization with proper dimension handling"""
|
|
72
65
|
def __init__(self, dim: int, eps: float = 1e-8):
|
|
73
66
|
super().__init__()
|
|
74
67
|
self.eps = eps
|
|
75
|
-
self.
|
|
68
|
+
self.dim = dim
|
|
69
|
+
self.weight = nn.Parameter(torch.ones(dim)) # Correct: weight has same dim as input
|
|
76
70
|
|
|
77
71
|
def forward(self, x):
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
72
|
+
# Ensure input has the right dimension
|
|
73
|
+
if x.size(-1) != self.dim:
|
|
74
|
+
raise ValueError(f"Input dimension {x.size(-1)} doesn't match RMSNorm dimension {self.dim}")
|
|
75
|
+
|
|
76
|
+
# Calculate RMS
|
|
77
|
+
rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
|
|
78
|
+
# Normalize and apply weight
|
|
79
|
+
return x / rms * self.weight
|
|
81
80
|
|
|
82
|
-
class
|
|
83
|
-
"""
|
|
84
|
-
def __init__(self, input_dim: int, expert_dim: int, num_experts: int):
|
|
85
|
-
super().__init__()
|
|
86
|
-
self.num_experts = num_experts
|
|
87
|
-
self.experts = nn.ModuleList([
|
|
88
|
-
nn.Sequential(
|
|
89
|
-
nn.Linear(input_dim, expert_dim),
|
|
90
|
-
nn.Dropout(0.1),
|
|
91
|
-
nn.Linear(expert_dim, input_dim)
|
|
92
|
-
) for _ in range(num_experts)
|
|
93
|
-
])
|
|
94
|
-
self.gate = nn.Linear(input_dim, num_experts)
|
|
95
|
-
self.expert_dim = expert_dim
|
|
96
|
-
|
|
81
|
+
class SimplifiedSwiGLU(nn.Module):
|
|
82
|
+
"""Simplified SwiGLU activation"""
|
|
97
83
|
def forward(self, x):
|
|
98
|
-
#
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
# Expert outputs
|
|
103
|
-
expert_outputs = []
|
|
104
|
-
for i, expert in enumerate(self.experts):
|
|
105
|
-
expert_out = expert(x)
|
|
106
|
-
expert_outputs.append(expert_out.unsqueeze(-1))
|
|
107
|
-
|
|
108
|
-
# Combine expert outputs
|
|
109
|
-
expert_outputs = torch.cat(expert_outputs, dim=-1)
|
|
110
|
-
output = torch.einsum('bd, bde -> be', gate_weights, expert_outputs)
|
|
111
|
-
|
|
112
|
-
return output + x # Residual connection
|
|
84
|
+
# Split into two parts
|
|
85
|
+
x, gate = x.chunk(2, dim=-1)
|
|
86
|
+
return x * F.silu(gate)
|
|
113
87
|
|
|
114
|
-
class
|
|
115
|
-
"""
|
|
88
|
+
class MemoryEfficientBottleneck(nn.Module):
|
|
89
|
+
"""Memory-efficient bottleneck with corrected dimensions"""
|
|
116
90
|
def __init__(self, input_dim: int, bottleneck_dim: int, output_dim: int):
|
|
117
91
|
super().__init__()
|
|
92
|
+
# Ensure proper dimension matching
|
|
118
93
|
self.compress = nn.Linear(input_dim, bottleneck_dim)
|
|
119
|
-
self.norm1 = EfficientTranscriptomeDecoder.
|
|
94
|
+
self.norm1 = EfficientTranscriptomeDecoder.CorrectedRMSNorm(bottleneck_dim)
|
|
120
95
|
self.expand = nn.Linear(bottleneck_dim, output_dim)
|
|
121
|
-
self.norm2 = EfficientTranscriptomeDecoder.
|
|
96
|
+
self.norm2 = EfficientTranscriptomeDecoder.CorrectedRMSNorm(output_dim)
|
|
97
|
+
self.activation = nn.SiLU()
|
|
122
98
|
self.dropout = nn.Dropout(0.1)
|
|
123
99
|
|
|
124
100
|
def forward(self, x):
|
|
125
101
|
# Compress
|
|
126
102
|
compressed = self.compress(x)
|
|
127
103
|
compressed = self.norm1(compressed)
|
|
128
|
-
compressed =
|
|
104
|
+
compressed = self.activation(compressed)
|
|
129
105
|
compressed = self.dropout(compressed)
|
|
130
106
|
|
|
131
107
|
# Expand
|
|
@@ -134,37 +110,57 @@ class EfficientTranscriptomeDecoder:
|
|
|
134
110
|
|
|
135
111
|
return expanded
|
|
136
112
|
|
|
137
|
-
class
|
|
138
|
-
"""
|
|
139
|
-
def __init__(self,
|
|
113
|
+
class StableMixtureOfExperts(nn.Module):
|
|
114
|
+
"""Stable mixture of experts without dimension issues"""
|
|
115
|
+
def __init__(self, input_dim: int, num_experts: int = 4):
|
|
140
116
|
super().__init__()
|
|
141
|
-
self.
|
|
142
|
-
self.
|
|
143
|
-
self.latent_projection = nn.Linear(latent_dim, proj_dim)
|
|
144
|
-
self.output_layer = nn.Linear(proj_dim, 1)
|
|
145
|
-
|
|
146
|
-
def forward(self, latent):
|
|
147
|
-
batch_size = latent.shape[0]
|
|
117
|
+
self.num_experts = num_experts
|
|
118
|
+
self.input_dim = input_dim
|
|
148
119
|
|
|
149
|
-
#
|
|
150
|
-
|
|
120
|
+
# Shared expert with different scaling factors
|
|
121
|
+
self.shared_expert = nn.Sequential(
|
|
122
|
+
nn.Linear(input_dim, input_dim * 2),
|
|
123
|
+
nn.SiLU(),
|
|
124
|
+
nn.Dropout(0.1),
|
|
125
|
+
nn.Linear(input_dim * 2, input_dim)
|
|
126
|
+
)
|
|
151
127
|
|
|
152
|
-
#
|
|
153
|
-
|
|
128
|
+
# Gating network
|
|
129
|
+
self.gate = nn.Sequential(
|
|
130
|
+
nn.Linear(input_dim, num_experts * 4),
|
|
131
|
+
nn.SiLU(),
|
|
132
|
+
nn.Linear(num_experts * 4, num_experts)
|
|
133
|
+
)
|
|
154
134
|
|
|
155
|
-
|
|
135
|
+
def forward(self, x):
|
|
136
|
+
# Get gate weights
|
|
137
|
+
gate_weights = F.softmax(self.gate(x), dim=-1) # [batch, num_experts]
|
|
138
|
+
|
|
139
|
+
# Process through shared expert
|
|
140
|
+
expert_output = self.shared_expert(x) # [batch, input_dim]
|
|
141
|
+
|
|
142
|
+
# Apply expert-specific scaling
|
|
143
|
+
weighted_output = torch.zeros_like(expert_output)
|
|
144
|
+
for i in range(self.num_experts):
|
|
145
|
+
expert_scale = 0.5 + 0.5 * i # Different scaling for each expert
|
|
146
|
+
expert_contribution = expert_output * expert_scale
|
|
147
|
+
expert_weight = gate_weights[:, i].unsqueeze(-1) # [batch, 1]
|
|
148
|
+
weighted_output += expert_weight * expert_contribution
|
|
149
|
+
|
|
150
|
+
# Residual connection
|
|
151
|
+
return x + weighted_output
|
|
156
152
|
|
|
157
|
-
class
|
|
158
|
-
"""
|
|
153
|
+
class CorrectedDecoder(nn.Module):
|
|
154
|
+
"""Corrected decoder with proper dimension handling"""
|
|
159
155
|
|
|
160
156
|
def __init__(self, latent_dim: int, gene_dim: int, hidden_dims: List[int],
|
|
161
|
-
|
|
157
|
+
bottleneck_dim: int, num_experts: int, dropout_rate: float):
|
|
162
158
|
super().__init__()
|
|
163
159
|
|
|
164
|
-
#
|
|
160
|
+
# Input projection
|
|
165
161
|
self.input_projection = nn.Sequential(
|
|
166
162
|
nn.Linear(latent_dim, hidden_dims[0]),
|
|
167
|
-
EfficientTranscriptomeDecoder.
|
|
163
|
+
EfficientTranscriptomeDecoder.CorrectedRMSNorm(hidden_dims[0]),
|
|
168
164
|
nn.SiLU(),
|
|
169
165
|
nn.Dropout(dropout_rate)
|
|
170
166
|
)
|
|
@@ -173,71 +169,74 @@ class EfficientTranscriptomeDecoder:
|
|
|
173
169
|
self.blocks = nn.ModuleList()
|
|
174
170
|
current_dim = hidden_dims[0]
|
|
175
171
|
|
|
176
|
-
for i,
|
|
177
|
-
block = nn.
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
172
|
+
for i, next_dim in enumerate(hidden_dims[1:], 1):
|
|
173
|
+
block = nn.ModuleDict({
|
|
174
|
+
'swiglu': nn.Sequential(
|
|
175
|
+
nn.Linear(current_dim, current_dim * 2),
|
|
176
|
+
EfficientTranscriptomeDecoder.SimplifiedSwiGLU(),
|
|
177
|
+
nn.Dropout(dropout_rate),
|
|
178
|
+
nn.Linear(current_dim, current_dim) # Project back to same dimension
|
|
179
|
+
),
|
|
180
|
+
'bottleneck': EfficientTranscriptomeDecoder.MemoryEfficientBottleneck(
|
|
181
|
+
current_dim, bottleneck_dim, next_dim
|
|
182
|
+
),
|
|
183
|
+
'experts': EfficientTranscriptomeDecoder.StableMixtureOfExperts(
|
|
184
|
+
next_dim, num_experts
|
|
189
185
|
)
|
|
190
|
-
|
|
186
|
+
})
|
|
191
187
|
self.blocks.append(block)
|
|
192
|
-
current_dim =
|
|
188
|
+
current_dim = next_dim
|
|
193
189
|
|
|
194
|
-
#
|
|
195
|
-
self.
|
|
196
|
-
current_dim,
|
|
190
|
+
# Final projection to gene dimension
|
|
191
|
+
self.final_projection = nn.Sequential(
|
|
192
|
+
nn.Linear(current_dim, current_dim * 2),
|
|
193
|
+
nn.SiLU(),
|
|
194
|
+
nn.Dropout(dropout_rate),
|
|
195
|
+
nn.Linear(current_dim * 2, gene_dim)
|
|
197
196
|
)
|
|
198
197
|
|
|
199
|
-
# Output
|
|
198
|
+
# Output parameters
|
|
200
199
|
self.output_scale = nn.Parameter(torch.ones(1))
|
|
201
200
|
self.output_bias = nn.Parameter(torch.zeros(1))
|
|
202
201
|
|
|
203
202
|
self._init_weights()
|
|
204
203
|
|
|
205
204
|
def _init_weights(self):
|
|
206
|
-
"""
|
|
205
|
+
"""Proper weight initialization"""
|
|
207
206
|
for module in self.modules():
|
|
208
207
|
if isinstance(module, nn.Linear):
|
|
209
|
-
|
|
210
|
-
nn.init.kaiming_normal_(module.weight, nonlinearity='linear')
|
|
208
|
+
nn.init.xavier_uniform_(module.weight)
|
|
211
209
|
if module.bias is not None:
|
|
212
210
|
nn.init.zeros_(module.bias)
|
|
213
211
|
|
|
214
212
|
def forward(self, x):
|
|
215
|
-
#
|
|
213
|
+
# Input projection
|
|
216
214
|
x = self.input_projection(x)
|
|
217
215
|
|
|
218
216
|
# Process through blocks
|
|
219
217
|
for block in self.blocks:
|
|
220
|
-
#
|
|
221
|
-
|
|
218
|
+
# SwiGLU with residual
|
|
219
|
+
residual = x
|
|
220
|
+
x_swiglu = block['swiglu'](x)
|
|
221
|
+
x = x + x_swiglu # Residual connection
|
|
222
222
|
|
|
223
|
-
#
|
|
224
|
-
|
|
223
|
+
# Bottleneck
|
|
224
|
+
x = block['bottleneck'](x)
|
|
225
225
|
|
|
226
|
-
#
|
|
227
|
-
|
|
228
|
-
x = x + swiglu_out # Residual connection
|
|
226
|
+
# Mixture of Experts with residual
|
|
227
|
+
x = block['experts'](x)
|
|
229
228
|
|
|
230
|
-
# Final
|
|
231
|
-
|
|
229
|
+
# Final projection
|
|
230
|
+
x = self.final_projection(x)
|
|
232
231
|
|
|
233
232
|
# Ensure non-negative output
|
|
234
|
-
|
|
233
|
+
x = F.softplus(x * self.output_scale + self.output_bias)
|
|
235
234
|
|
|
236
|
-
return
|
|
235
|
+
return x
|
|
237
236
|
|
|
238
|
-
def
|
|
239
|
-
"""Build the
|
|
240
|
-
return self.
|
|
237
|
+
def _build_corrected_model(self):
|
|
238
|
+
"""Build the corrected model"""
|
|
239
|
+
return self.CorrectedDecoder(
|
|
241
240
|
self.latent_dim, self.gene_dim, self.hidden_dims,
|
|
242
241
|
self.bottleneck_dim, self.num_experts, self.dropout_rate
|
|
243
242
|
)
|
|
@@ -247,24 +246,14 @@ class EfficientTranscriptomeDecoder:
|
|
|
247
246
|
train_expression: np.ndarray,
|
|
248
247
|
val_latent: np.ndarray = None,
|
|
249
248
|
val_expression: np.ndarray = None,
|
|
250
|
-
batch_size: int =
|
|
251
|
-
num_epochs: int =
|
|
249
|
+
batch_size: int = 32,
|
|
250
|
+
num_epochs: int = 100,
|
|
252
251
|
learning_rate: float = 1e-4,
|
|
253
|
-
checkpoint_path: str = '
|
|
252
|
+
checkpoint_path: str = 'transcriptome_decoder.pth') -> Dict:
|
|
254
253
|
"""
|
|
255
|
-
Train
|
|
256
|
-
|
|
257
|
-
Args:
|
|
258
|
-
train_latent: Training latent variables
|
|
259
|
-
train_expression: Training expression data
|
|
260
|
-
val_latent: Validation latent variables
|
|
261
|
-
val_expression: Validation expression data
|
|
262
|
-
batch_size: Batch size (optimized for memory)
|
|
263
|
-
num_epochs: Number of epochs
|
|
264
|
-
learning_rate: Learning rate
|
|
265
|
-
checkpoint_path: Model save path
|
|
254
|
+
Train the corrected decoder
|
|
266
255
|
"""
|
|
267
|
-
print("🚀 Starting
|
|
256
|
+
print("🚀 Starting Training...")
|
|
268
257
|
|
|
269
258
|
# Data preparation
|
|
270
259
|
train_dataset = self._create_dataset(train_latent, train_expression)
|
|
@@ -273,114 +262,83 @@ class EfficientTranscriptomeDecoder:
|
|
|
273
262
|
val_dataset = self._create_dataset(val_latent, val_expression)
|
|
274
263
|
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
|
|
275
264
|
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
|
|
265
|
+
print(f"📈 Using provided validation data: {len(val_dataset)} samples")
|
|
276
266
|
else:
|
|
267
|
+
# Auto split
|
|
277
268
|
train_size = int(0.9 * len(train_dataset))
|
|
278
269
|
val_size = len(train_dataset) - train_size
|
|
279
270
|
train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
|
|
280
271
|
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, pin_memory=True)
|
|
281
272
|
val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, pin_memory=True)
|
|
273
|
+
print(f"📈 Auto-split validation: {val_size} samples")
|
|
282
274
|
|
|
283
275
|
print(f"📊 Training samples: {len(train_loader.dataset)}")
|
|
284
276
|
print(f"📊 Validation samples: {len(val_loader.dataset)}")
|
|
285
277
|
print(f"📊 Batch size: {batch_size}")
|
|
286
278
|
|
|
287
|
-
#
|
|
279
|
+
# Optimizer
|
|
288
280
|
optimizer = optim.AdamW(
|
|
289
281
|
self.model.parameters(),
|
|
290
282
|
lr=learning_rate,
|
|
291
|
-
weight_decay=0.
|
|
292
|
-
betas=(0.9, 0.
|
|
293
|
-
eps=1e-8
|
|
283
|
+
weight_decay=0.01,
|
|
284
|
+
betas=(0.9, 0.999)
|
|
294
285
|
)
|
|
295
286
|
|
|
296
|
-
#
|
|
297
|
-
scheduler = optim.lr_scheduler.
|
|
298
|
-
optimizer,
|
|
299
|
-
max_lr=learning_rate * 5,
|
|
300
|
-
epochs=num_epochs,
|
|
301
|
-
steps_per_epoch=len(train_loader),
|
|
302
|
-
pct_start=0.1,
|
|
303
|
-
div_factor=10.0,
|
|
304
|
-
final_div_factor=100.0
|
|
305
|
-
)
|
|
287
|
+
# Scheduler
|
|
288
|
+
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
|
|
306
289
|
|
|
307
|
-
#
|
|
308
|
-
def
|
|
309
|
-
# 1. MSE loss for overall accuracy
|
|
290
|
+
# Loss function
|
|
291
|
+
def combined_loss(pred, target):
|
|
310
292
|
mse_loss = F.mse_loss(pred, target)
|
|
311
|
-
|
|
312
|
-
# 2. Poisson loss for count data
|
|
313
293
|
poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
# 4. Sparsity loss for realistic distribution
|
|
319
|
-
sparsity_loss = F.mse_loss(
|
|
320
|
-
(pred < 1e-3).float().mean(),
|
|
321
|
-
torch.tensor(0.85, device=pred.device) # Target sparsity
|
|
322
|
-
)
|
|
323
|
-
|
|
324
|
-
# 5. Spectral loss for smoothness
|
|
325
|
-
spectral_loss = self._spectral_loss(pred, target)
|
|
326
|
-
|
|
327
|
-
# Weighted combination
|
|
328
|
-
total_loss = (mse_loss + 0.3 * poisson_loss + 0.2 * correlation_loss +
|
|
329
|
-
0.1 * sparsity_loss + 0.05 * spectral_loss)
|
|
330
|
-
|
|
331
|
-
return total_loss, {
|
|
332
|
-
'mse': mse_loss.item(),
|
|
333
|
-
'poisson': poisson_loss.item(),
|
|
334
|
-
'correlation': correlation_loss.item(),
|
|
335
|
-
'sparsity': sparsity_loss.item(),
|
|
336
|
-
'spectral': spectral_loss.item()
|
|
337
|
-
}
|
|
294
|
+
correlation = self._pearson_correlation(pred, target)
|
|
295
|
+
correlation_loss = 1 - correlation
|
|
296
|
+
return mse_loss + 0.3 * poisson_loss + 0.1 * correlation_loss
|
|
338
297
|
|
|
339
298
|
# Training history
|
|
340
299
|
history = {
|
|
341
300
|
'train_loss': [], 'val_loss': [],
|
|
342
301
|
'train_mse': [], 'val_mse': [],
|
|
343
302
|
'train_correlation': [], 'val_correlation': [],
|
|
344
|
-
'learning_rates': []
|
|
303
|
+
'learning_rates': []
|
|
345
304
|
}
|
|
346
305
|
|
|
347
306
|
best_val_loss = float('inf')
|
|
348
|
-
patience =
|
|
307
|
+
patience = 20
|
|
349
308
|
patience_counter = 0
|
|
350
309
|
|
|
351
|
-
print("\n📈 Starting training
|
|
310
|
+
print("\n📈 Starting training loop...")
|
|
352
311
|
for epoch in range(1, num_epochs + 1):
|
|
353
|
-
# Training
|
|
354
|
-
|
|
355
|
-
train_loader, optimizer, scheduler, advanced_loss
|
|
356
|
-
)
|
|
312
|
+
# Training
|
|
313
|
+
train_metrics = self._train_epoch(train_loader, optimizer, combined_loss)
|
|
357
314
|
|
|
358
|
-
# Validation
|
|
359
|
-
|
|
315
|
+
# Validation
|
|
316
|
+
val_metrics = self._validate_epoch(val_loader, combined_loss)
|
|
317
|
+
|
|
318
|
+
# Update scheduler
|
|
319
|
+
scheduler.step()
|
|
320
|
+
current_lr = optimizer.param_groups[0]['lr']
|
|
360
321
|
|
|
361
322
|
# Record history
|
|
362
|
-
history['train_loss'].append(
|
|
363
|
-
history['val_loss'].append(
|
|
364
|
-
history['train_mse'].append(
|
|
365
|
-
history['val_mse'].append(
|
|
366
|
-
history['train_correlation'].append(
|
|
367
|
-
history['val_correlation'].append(
|
|
368
|
-
history['learning_rates'].append(
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
# Print detailed progress
|
|
323
|
+
history['train_loss'].append(train_metrics['loss'])
|
|
324
|
+
history['val_loss'].append(val_metrics['loss'])
|
|
325
|
+
history['train_mse'].append(train_metrics['mse'])
|
|
326
|
+
history['val_mse'].append(val_metrics['mse'])
|
|
327
|
+
history['train_correlation'].append(train_metrics['correlation'])
|
|
328
|
+
history['val_correlation'].append(val_metrics['correlation'])
|
|
329
|
+
history['learning_rates'].append(current_lr)
|
|
330
|
+
|
|
331
|
+
# Print progress
|
|
372
332
|
if epoch % 10 == 0 or epoch == 1:
|
|
373
|
-
lr = optimizer.param_groups[0]['lr']
|
|
374
333
|
print(f"📍 Epoch {epoch:3d}/{num_epochs} | "
|
|
375
|
-
f"Train: {
|
|
376
|
-
f"Val: {
|
|
377
|
-
f"
|
|
378
|
-
f"LR: {
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
best_val_loss = val_loss
|
|
334
|
+
f"Train Loss: {train_metrics['loss']:.4f} | "
|
|
335
|
+
f"Val Loss: {val_metrics['loss']:.4f} | "
|
|
336
|
+
f"Correlation: {val_metrics['correlation']:.4f} | "
|
|
337
|
+
f"LR: {current_lr:.2e}")
|
|
338
|
+
|
|
339
|
+
# Early stopping
|
|
340
|
+
if val_metrics['loss'] < best_val_loss:
|
|
341
|
+
best_val_loss = val_metrics['loss']
|
|
384
342
|
patience_counter = 0
|
|
385
343
|
self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
|
|
386
344
|
if epoch % 20 == 0:
|
|
@@ -402,8 +360,8 @@ class EfficientTranscriptomeDecoder:
|
|
|
402
360
|
return history
|
|
403
361
|
|
|
404
362
|
def _create_dataset(self, latent_data, expression_data):
|
|
405
|
-
"""Create
|
|
406
|
-
class
|
|
363
|
+
"""Create dataset"""
|
|
364
|
+
class SimpleDataset(Dataset):
|
|
407
365
|
def __init__(self, latent, expression):
|
|
408
366
|
self.latent = torch.FloatTensor(latent)
|
|
409
367
|
self.expression = torch.FloatTensor(expression)
|
|
@@ -414,7 +372,7 @@ class EfficientTranscriptomeDecoder:
|
|
|
414
372
|
def __getitem__(self, idx):
|
|
415
373
|
return self.latent[idx], self.expression[idx]
|
|
416
374
|
|
|
417
|
-
return
|
|
375
|
+
return SimpleDataset(latent_data, expression_data)
|
|
418
376
|
|
|
419
377
|
def _pearson_correlation(self, pred, target):
|
|
420
378
|
"""Calculate Pearson correlation"""
|
|
@@ -426,68 +384,48 @@ class EfficientTranscriptomeDecoder:
|
|
|
426
384
|
|
|
427
385
|
return (numerator / (denominator + 1e-8)).mean()
|
|
428
386
|
|
|
429
|
-
def
|
|
430
|
-
"""
|
|
431
|
-
pred_fft = torch.fft.fft(pred, dim=1)
|
|
432
|
-
target_fft = torch.fft.fft(target, dim=1)
|
|
433
|
-
|
|
434
|
-
magnitude_loss = F.mse_loss(torch.abs(pred_fft), torch.abs(target_fft))
|
|
435
|
-
phase_loss = F.mse_loss(torch.angle(pred_fft), torch.angle(target_fft))
|
|
436
|
-
|
|
437
|
-
return magnitude_loss + 0.5 * phase_loss
|
|
438
|
-
|
|
439
|
-
def _train_epoch_advanced(self, train_loader, optimizer, scheduler, loss_fn):
|
|
440
|
-
"""Advanced training with gradient accumulation"""
|
|
387
|
+
def _train_epoch(self, train_loader, optimizer, loss_fn):
|
|
388
|
+
"""Train one epoch"""
|
|
441
389
|
self.model.train()
|
|
442
390
|
total_loss = 0
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
# Gradient accumulation for effective larger batch size
|
|
447
|
-
accumulation_steps = 4
|
|
448
|
-
optimizer.zero_grad()
|
|
391
|
+
total_mse = 0
|
|
392
|
+
total_correlation = 0
|
|
449
393
|
|
|
450
|
-
for
|
|
394
|
+
for latent, target in train_loader:
|
|
451
395
|
latent = latent.to(self.device, non_blocking=True)
|
|
452
396
|
target = target.to(self.device, non_blocking=True)
|
|
453
397
|
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
pred = self.model(latent)
|
|
457
|
-
loss, components = loss_fn(pred, target)
|
|
398
|
+
optimizer.zero_grad()
|
|
399
|
+
pred = self.model(latent)
|
|
458
400
|
|
|
459
|
-
|
|
460
|
-
loss = loss / accumulation_steps
|
|
401
|
+
loss = loss_fn(pred, target)
|
|
461
402
|
loss.backward()
|
|
462
403
|
|
|
463
|
-
# Gradient
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
grad_norms.append(grad_norm.item())
|
|
404
|
+
# Gradient clipping
|
|
405
|
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
|
406
|
+
optimizer.step()
|
|
407
|
+
|
|
408
|
+
# Calculate metrics
|
|
409
|
+
mse_loss = F.mse_loss(pred, target).item()
|
|
410
|
+
correlation = self._pearson_correlation(pred, target).item()
|
|
472
411
|
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
total_components[key] += components[key]
|
|
412
|
+
total_loss += loss.item()
|
|
413
|
+
total_mse += mse_loss
|
|
414
|
+
total_correlation += correlation
|
|
477
415
|
|
|
478
|
-
# Average metrics
|
|
479
416
|
num_batches = len(train_loader)
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
417
|
+
return {
|
|
418
|
+
'loss': total_loss / num_batches,
|
|
419
|
+
'mse': total_mse / num_batches,
|
|
420
|
+
'correlation': total_correlation / num_batches
|
|
421
|
+
}
|
|
485
422
|
|
|
486
|
-
def
|
|
487
|
-
"""
|
|
423
|
+
def _validate_epoch(self, val_loader, loss_fn):
|
|
424
|
+
"""Validate one epoch"""
|
|
488
425
|
self.model.eval()
|
|
489
426
|
total_loss = 0
|
|
490
|
-
|
|
427
|
+
total_mse = 0
|
|
428
|
+
total_correlation = 0
|
|
491
429
|
|
|
492
430
|
with torch.no_grad():
|
|
493
431
|
for latent, target in val_loader:
|
|
@@ -495,17 +433,20 @@ class EfficientTranscriptomeDecoder:
|
|
|
495
433
|
target = target.to(self.device, non_blocking=True)
|
|
496
434
|
|
|
497
435
|
pred = self.model(latent)
|
|
498
|
-
loss
|
|
436
|
+
loss = loss_fn(pred, target)
|
|
437
|
+
mse_loss = F.mse_loss(pred, target).item()
|
|
438
|
+
correlation = self._pearson_correlation(pred, target).item()
|
|
499
439
|
|
|
500
440
|
total_loss += loss.item()
|
|
501
|
-
|
|
502
|
-
|
|
441
|
+
total_mse += mse_loss
|
|
442
|
+
total_correlation += correlation
|
|
503
443
|
|
|
504
444
|
num_batches = len(val_loader)
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
445
|
+
return {
|
|
446
|
+
'loss': total_loss / num_batches,
|
|
447
|
+
'mse': total_mse / num_batches,
|
|
448
|
+
'correlation': total_correlation / num_batches
|
|
449
|
+
}
|
|
509
450
|
|
|
510
451
|
def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
|
|
511
452
|
"""Save checkpoint"""
|
|
@@ -525,8 +466,8 @@ class EfficientTranscriptomeDecoder:
|
|
|
525
466
|
}
|
|
526
467
|
}, path)
|
|
527
468
|
|
|
528
|
-
def predict(self, latent_data: np.ndarray, batch_size: int =
|
|
529
|
-
"""
|
|
469
|
+
def predict(self, latent_data: np.ndarray, batch_size: int = 32) -> np.ndarray:
|
|
470
|
+
"""Predict gene expression"""
|
|
530
471
|
if not self.is_trained:
|
|
531
472
|
warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")
|
|
532
473
|
|
|
@@ -539,15 +480,8 @@ class EfficientTranscriptomeDecoder:
|
|
|
539
480
|
with torch.no_grad():
|
|
540
481
|
for i in range(0, len(latent_data), batch_size):
|
|
541
482
|
batch_latent = latent_data[i:i+batch_size].to(self.device)
|
|
542
|
-
|
|
543
|
-
with torch.cuda.amp.autocast(): # Mixed precision for memory
|
|
544
|
-
batch_pred = self.model(batch_latent)
|
|
545
|
-
|
|
483
|
+
batch_pred = self.model(batch_latent)
|
|
546
484
|
predictions.append(batch_pred.cpu())
|
|
547
|
-
|
|
548
|
-
# Clear memory
|
|
549
|
-
if torch.cuda.is_available():
|
|
550
|
-
torch.cuda.empty_cache()
|
|
551
485
|
|
|
552
486
|
return torch.cat(predictions).numpy()
|
|
553
487
|
|
|
@@ -559,11 +493,23 @@ class EfficientTranscriptomeDecoder:
|
|
|
559
493
|
self.training_history = checkpoint.get('training_history')
|
|
560
494
|
self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
|
|
561
495
|
print(f"✅ Model loaded! Best val loss: {self.best_val_loss:.4f}")
|
|
496
|
+
|
|
497
|
+
def get_model_info(self) -> Dict:
|
|
498
|
+
"""Get model information"""
|
|
499
|
+
return {
|
|
500
|
+
'is_trained': self.is_trained,
|
|
501
|
+
'best_val_loss': self.best_val_loss,
|
|
502
|
+
'parameters': sum(p.numel() for p in self.model.parameters()),
|
|
503
|
+
'latent_dim': self.latent_dim,
|
|
504
|
+
'gene_dim': self.gene_dim,
|
|
505
|
+
'hidden_dims': self.hidden_dims,
|
|
506
|
+
'device': str(self.device)
|
|
507
|
+
}
|
|
562
508
|
|
|
563
509
|
'''
|
|
564
510
|
# Example usage
|
|
565
511
|
def example_usage():
|
|
566
|
-
"""
|
|
512
|
+
"""Example demonstration"""
|
|
567
513
|
|
|
568
514
|
# Initialize decoder
|
|
569
515
|
decoder = EfficientTranscriptomeDecoder(
|
|
@@ -590,7 +536,7 @@ def example_usage():
|
|
|
590
536
|
history = decoder.train(
|
|
591
537
|
train_latent=latent_data,
|
|
592
538
|
train_expression=expression_data,
|
|
593
|
-
batch_size=
|
|
539
|
+
batch_size=32,
|
|
594
540
|
num_epochs=50
|
|
595
541
|
)
|
|
596
542
|
|
|
@@ -603,5 +549,4 @@ def example_usage():
|
|
|
603
549
|
|
|
604
550
|
if __name__ == "__main__":
|
|
605
551
|
example_usage()
|
|
606
|
-
|
|
607
552
|
'''
|