SURE-tools 2.4.5__py3-none-any.whl → 2.4.13__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools may be problematic; the original page linked to further details, but that link was not preserved here.

@@ -1,28 +1,28 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
- import torch.optim as optim
4
3
  import torch.nn.functional as F
4
+ import torch.optim as optim
5
5
  from torch.utils.data import Dataset, DataLoader
6
6
  import numpy as np
7
- from typing import Dict, List, Tuple, Optional
8
- import matplotlib.pyplot as plt
9
- from tqdm import tqdm
7
+ from typing import Dict, Optional
10
8
  import warnings
11
9
  warnings.filterwarnings('ignore')
12
10
 
13
11
  class TranscriptomeDecoder:
12
+ """Transcriptome decoder"""
13
+
14
14
  def __init__(self,
15
15
  latent_dim: int = 100,
16
16
  gene_dim: int = 60000,
17
- hidden_dim: int = 512, # Reduced for memory efficiency
17
+ hidden_dim: int = 512,
18
18
  device: str = None):
19
19
  """
20
- Whole-transcriptome decoder
20
+ Simple but powerful decoder for latent to transcriptome mapping
21
21
 
22
22
  Args:
23
23
  latent_dim: Latent variable dimension (typically 50-100)
24
24
  gene_dim: Number of genes (full transcriptome ~60,000)
25
- hidden_dim: Hidden dimension (reduced for memory efficiency)
25
+ hidden_dim: Hidden dimension optimized
26
26
  device: Computation device
27
27
  """
28
28
  self.latent_dim = latent_dim
@@ -30,10 +30,6 @@ class TranscriptomeDecoder:
30
30
  self.hidden_dim = hidden_dim
31
31
  self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
32
32
 
33
- # Memory optimization settings
34
- self.gradient_checkpointing = True
35
- self.mixed_precision = True
36
-
37
33
  # Initialize model
38
34
  self.model = self._build_model()
39
35
  self.model.to(self.device)
@@ -43,259 +39,163 @@ class TranscriptomeDecoder:
43
39
  self.training_history = None
44
40
  self.best_val_loss = float('inf')
45
41
 
46
- print(f"🚀 TranscriptomeDecoder Initialized:")
42
+ print(f"🚀 SimpleTranscriptomeDecoder Initialized:")
47
43
  print(f" - Latent Dimension: {latent_dim}")
48
44
  print(f" - Gene Dimension: {gene_dim}")
49
45
  print(f" - Hidden Dimension: {hidden_dim}")
50
46
  print(f" - Device: {self.device}")
51
47
  print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
52
48
 
53
- class MemoryEfficientBlock(nn.Module):
54
- """Memory-efficient building block with gradient checkpointing"""
55
- def __init__(self, input_dim, output_dim, use_checkpointing=True):
56
- super().__init__()
57
- self.use_checkpointing = use_checkpointing
58
- self.net = nn.Sequential(
59
- nn.Linear(input_dim, output_dim),
60
- nn.BatchNorm1d(output_dim),
61
- nn.GELU(),
62
- nn.Dropout(0.1)
63
- )
64
-
65
- def forward(self, x):
66
- if self.use_checkpointing and self.training:
67
- return torch.utils.checkpoint.checkpoint(self.net, x)
68
- return self.net(x)
69
-
70
- class SparseGeneProjection(nn.Module):
71
- """Sparse gene projection to reduce memory usage"""
72
- def __init__(self, latent_dim, gene_dim, projection_dim=256):
73
- super().__init__()
74
- self.projection_dim = projection_dim
75
- self.gene_embeddings = nn.Parameter(torch.randn(gene_dim, projection_dim) * 0.02)
76
- self.latent_projection = nn.Linear(latent_dim, projection_dim)
77
- self.activation = nn.GELU()
78
-
79
- def forward(self, latent):
80
- # Project latent to gene space efficiently
81
- batch_size = latent.shape[0]
82
- latent_proj = self.latent_projection(latent) # [batch, projection_dim]
83
-
84
- # Efficient matrix multiplication
85
- gene_embeds = self.gene_embeddings.T # [projection_dim, gene_dim]
86
- output = torch.matmul(latent_proj, gene_embeds) # [batch, gene_dim]
87
-
88
- return self.activation(output)
89
-
90
- class ChunkedTransformer(nn.Module):
91
- def __init__(self, gene_dim, hidden_dim=512, chunk_size=2000, num_layers=3):
92
- super().__init__()
93
- self.chunk_size = chunk_size
94
- self.hidden_dim = hidden_dim
95
- self.num_chunks = (gene_dim + chunk_size - 1) // chunk_size
96
-
97
- # 共享的Transformer层
98
- self.transformer_layers = nn.ModuleList([
99
- nn.Sequential(
100
- nn.Linear(hidden_dim, hidden_dim),
101
- nn.GELU(),
102
- nn.Dropout(0.1),
103
- nn.Linear(hidden_dim, hidden_dim),
104
- ) for _ in range(num_layers)
105
- ])
106
-
107
- # 每个chunk独立的投影层
108
- self.input_projections = nn.ModuleList([
109
- nn.Linear(min(chunk_size, gene_dim - i * chunk_size), hidden_dim)
110
- for i in range(self.num_chunks)
111
- ])
112
- self.output_projections = nn.ModuleList([
113
- nn.Linear(hidden_dim, min(chunk_size, gene_dim - i * chunk_size))
114
- for i in range(self.num_chunks)
115
- ])
116
-
117
- def forward(self, x):
118
- batch_size, gene_dim = x.shape
119
- output = torch.zeros_like(x)
120
-
121
- for i in range(self.num_chunks):
122
- start_idx = i * self.chunk_size
123
- end_idx = min((i + 1) * self.chunk_size, gene_dim)
124
- current_chunk_size = end_idx - start_idx
125
-
126
- chunk = x[:, start_idx:end_idx] # [batch_size, current_chunk_size]
127
-
128
- # 投影到hidden_dim
129
- chunk_proj = self.input_projections[i](chunk) # [batch_size, hidden_dim]
130
-
131
- # Transformer处理
132
- for layer in self.transformer_layers:
133
- chunk_proj = layer(chunk_proj) + chunk_proj
134
-
135
- # 投影回原始维度
136
- chunk_out = self.output_projections[i](chunk_proj) # [batch_size, current_chunk_size]
137
-
138
- output[:, start_idx:end_idx] = chunk_out
139
-
140
- return output
141
-
142
49
  class Decoder(nn.Module):
143
- """Decoder model"""
144
- def __init__(self, latent_dim, gene_dim, hidden_dim):
50
+ """Memory-efficient decoder architecture with dimension handling"""
51
+
52
+ def __init__(self, latent_dim: int, gene_dim: int, hidden_dim: int):
145
53
  super().__init__()
146
54
  self.latent_dim = latent_dim
147
55
  self.gene_dim = gene_dim
148
56
  self.hidden_dim = hidden_dim
149
57
 
150
- # Stage 1: Latent expansion (memory efficient)
58
+ # Stage 1: Latent variable expansion
151
59
  self.latent_expansion = nn.Sequential(
152
60
  nn.Linear(latent_dim, hidden_dim * 2),
61
+ nn.BatchNorm1d(hidden_dim * 2),
153
62
  nn.GELU(),
154
63
  nn.Dropout(0.1),
155
64
  nn.Linear(hidden_dim * 2, hidden_dim),
65
+ nn.BatchNorm1d(hidden_dim),
66
+ nn.GELU(),
156
67
  )
157
68
 
158
- # Stage 2: Sparse gene projection
159
- self.gene_projection = TranscriptomeDecoder.SparseGeneProjection(
160
- latent_dim, gene_dim, hidden_dim
161
- )
162
-
163
- # Stage 3: Chunked processing
164
- self.chunked_processor = TranscriptomeDecoder.ChunkedTransformer(
165
- gene_dim, hidden_dim, chunk_size=2000, num_layers=3
69
+ # Stage 2: Direct projection to gene dimension (simpler approach)
70
+ self.gene_projector = nn.Sequential(
71
+ nn.Linear(hidden_dim, hidden_dim * 2),
72
+ nn.GELU(),
73
+ nn.Dropout(0.1),
74
+ nn.Linear(hidden_dim * 2, gene_dim), # Direct projection to gene_dim
166
75
  )
167
76
 
168
- # Stage 4: Multi-head output with memory efficiency
169
- self.output_heads = nn.ModuleList([
170
- nn.Sequential(
171
- nn.Linear(hidden_dim, hidden_dim // 2),
172
- nn.GELU(),
173
- nn.Linear(hidden_dim // 2, 1)
174
- ) for _ in range(2) # Reduced from 3 to 2 heads
175
- ])
176
-
177
- # Adaptive fusion
178
- self.fusion_gate = nn.Sequential(
179
- nn.Linear(hidden_dim, hidden_dim // 4),
77
+ # Stage 3: Lightweight gene interaction
78
+ self.gene_interaction = nn.Sequential(
79
+ nn.Conv1d(1, 32, kernel_size=3, padding=1),
180
80
  nn.GELU(),
181
- nn.Linear(hidden_dim // 4, len(self.output_heads)),
182
- nn.Softmax(dim=-1)
81
+ nn.Dropout1d(0.1),
82
+ nn.Conv1d(32, 1, kernel_size=3, padding=1),
183
83
  )
184
84
 
185
85
  # Output scaling
186
86
  self.output_scale = nn.Parameter(torch.ones(1))
187
87
  self.output_bias = nn.Parameter(torch.zeros(1))
188
88
 
189
- self.latent_to_gene = nn.Linear(hidden_dim, gene_dim)
190
-
191
89
  self._init_weights()
192
90
 
193
91
  def _init_weights(self):
92
+ """Weight initialization"""
194
93
  for module in self.modules():
195
94
  if isinstance(module, nn.Linear):
196
95
  nn.init.xavier_uniform_(module.weight)
197
96
  if module.bias is not None:
198
97
  nn.init.zeros_(module.bias)
98
+ elif isinstance(module, nn.Conv1d):
99
+ nn.init.kaiming_uniform_(module.weight)
100
+ if module.bias is not None:
101
+ nn.init.zeros_(module.bias)
199
102
 
200
- def forward(self, latent):
103
+ def forward(self, latent: torch.Tensor) -> torch.Tensor:
201
104
  batch_size = latent.shape[0]
202
105
 
203
- # 1. Latent expansion
204
- latent_expanded = self.latent_expansion(latent)
205
-
206
- # 2. Gene projection (memory efficient)
207
- gene_features = self.gene_projection(latent)
106
+ # 1. Expand latent variables
107
+ latent_features = self.latent_expansion(latent) # [batch_size, hidden_dim]
208
108
 
209
- # 3. Add latent information
210
- latent_gene_injection = self.latent_to_gene(latent_expanded)
211
- gene_features = gene_features + latent_gene_injection
109
+ # 2. Direct projection to gene dimension
110
+ gene_output = self.gene_projector(latent_features) # [batch_size, gene_dim]
212
111
 
213
- # 4. Chunked processing (memory efficient)
214
- gene_features = self.chunked_processor(gene_features)
112
+ # 3. Gene interaction with dimension safety
113
+ if self.gene_dim > 1: # Only apply if gene_dim > 1
114
+ gene_output = gene_output.unsqueeze(1) # [batch_size, 1, gene_dim]
115
+ interaction_output = self.gene_interaction(gene_output) # [batch_size, 1, gene_dim]
116
+ gene_output = gene_output + interaction_output # Residual connection
117
+ gene_output = gene_output.squeeze(1) # [batch_size, gene_dim]
215
118
 
216
- # 5. Multi-head output with chunking
217
- final_output = torch.zeros(batch_size, self.gene_dim, device=latent.device)
119
+ # 4. Final activation (ensure non-negative)
120
+ gene_output = F.softplus(gene_output * self.output_scale + self.output_bias)
218
121
 
219
- # Process output in chunks
220
- chunk_size = 5000
221
- for i in range(0, self.gene_dim, chunk_size):
222
- end_idx = min(i + chunk_size, self.gene_dim)
223
- chunk = gene_features[:, i:end_idx]
224
-
225
- head_outputs = []
226
- for head in self.output_heads:
227
- head_out = head(chunk).squeeze(-1)
228
- head_outputs.append(head_out)
229
-
230
- # Adaptive fusion
231
- gate_weights = self.fusion_gate(chunk.mean(dim=1, keepdim=True))
232
- gate_weights = gate_weights.unsqueeze(1)
233
-
234
- # Weighted fusion
235
- chunk_output = torch.zeros_like(head_outputs[0])
236
- for j, head_out in enumerate(head_outputs):
237
- chunk_output = chunk_output + gate_weights[:, :, j] * head_out
238
-
239
- final_output[:, i:end_idx] = chunk_output
240
-
241
- # Final activation
242
- final_output = F.softplus(final_output * self.output_scale + self.output_bias)
243
-
244
- return final_output
122
+ return gene_output
245
123
 
246
124
  def _build_model(self):
247
- """Build model"""
125
+ """Build the decoder model"""
248
126
  return self.Decoder(self.latent_dim, self.gene_dim, self.hidden_dim)
249
127
 
128
+ def _create_dataset(self, latent_data, expression_data):
129
+ """Create dataset with dimension validation"""
130
+ class SimpleDataset(Dataset):
131
+ def __init__(self, latent, expression):
132
+ # Ensure dimensions match
133
+ assert latent.shape[0] == expression.shape[0], "Sample count mismatch"
134
+ assert latent.shape[1] == self.latent_dim, f"Latent dim mismatch: expected {self.latent_dim}, got {latent.shape[1]}"
135
+ assert expression.shape[1] == self.gene_dim, f"Gene dim mismatch: expected {self.gene_dim}, got {expression.shape[1]}"
136
+
137
+ self.latent = torch.FloatTensor(latent)
138
+ self.expression = torch.FloatTensor(expression)
139
+
140
+ def __len__(self):
141
+ return len(self.latent)
142
+
143
+ def __getitem__(self, idx):
144
+ return self.latent[idx], self.expression[idx]
145
+
146
+ return SimpleDataset(latent_data, expression_data)
147
+
250
148
  def train(self,
251
149
  train_latent: np.ndarray,
252
150
  train_expression: np.ndarray,
253
151
  val_latent: np.ndarray = None,
254
152
  val_expression: np.ndarray = None,
255
- batch_size: int = 16, # Reduced batch size for memory
153
+ batch_size: int = 32,
256
154
  num_epochs: int = 100,
257
155
  learning_rate: float = 1e-4,
258
156
  checkpoint_path: str = 'transcriptome_decoder.pth'):
259
157
  """
260
- Memory-efficient training with optimizations
158
+ Train the decoder model with dimension safety
261
159
 
262
160
  Args:
263
- train_latent: Training latent variables
264
- train_expression: Training expression data
265
- val_latent: Validation latent variables
266
- val_expression: Validation expression data
267
- batch_size: Reduced batch size for memory constraints
161
+ train_latent: Training latent variables [n_samples, latent_dim]
162
+ train_expression: Training expression data [n_samples, gene_dim]
163
+ val_latent: Validation latent variables (optional)
164
+ val_expression: Validation expression data (optional)
165
+ batch_size: Batch size optimized for memory
268
166
  num_epochs: Number of training epochs
269
167
  learning_rate: Learning rate
270
- checkpoint_path: Model save path
168
+ checkpoint_path: Path to save the best model
271
169
  """
272
- print("🚀 Starting Training...")
273
- print(f"📊 Batch size: {batch_size}")
170
+ print("🚀 Starting training...")
274
171
 
275
- # Enable memory optimizations
276
- torch.backends.cudnn.benchmark = True
277
- if self.mixed_precision:
278
- scaler = torch.cuda.amp.GradScaler()
172
+ # Dimension validation
173
+ self._validate_data_dimensions(train_latent, train_expression, "Training")
174
+ if val_latent is not None and val_expression is not None:
175
+ self._validate_data_dimensions(val_latent, val_expression, "Validation")
279
176
 
280
177
  # Data preparation
281
- train_dataset = self._create_dataset(train_latent, train_expression)
178
+ train_dataset = self._create_safe_dataset(train_latent, train_expression)
282
179
 
283
180
  if val_latent is not None and val_expression is not None:
284
- val_dataset = self._create_dataset(val_latent, val_expression)
181
+ val_dataset = self._create_safe_dataset(val_latent, val_expression)
182
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
183
+ val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
184
+ print(f"📈 Using provided validation data: {len(val_dataset)} samples")
285
185
  else:
286
186
  # Auto split
287
187
  train_size = int(0.9 * len(train_dataset))
288
188
  val_size = len(train_dataset) - train_size
289
- train_dataset, val_dataset = torch.utils.data.random_split(
290
- train_dataset, [train_size, val_size]
291
- )
189
+ train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
190
+ train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=2)
191
+ val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=2)
192
+ print(f"📈 Auto-split validation: {val_size} samples")
292
193
 
293
- train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
294
- pin_memory=True, num_workers=2)
295
- val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
296
- pin_memory=True, num_workers=2)
194
+ print(f"📊 Training samples: {len(train_loader.dataset)}")
195
+ print(f"📊 Validation samples: {len(val_loader.dataset)}")
196
+ print(f"📊 Batch size: {batch_size}")
297
197
 
298
- # Optimizer with memory-friendly settings
198
+ # Optimizer configuration
299
199
  optimizer = optim.AdamW(
300
200
  self.model.parameters(),
301
201
  lr=learning_rate,
@@ -306,130 +206,181 @@ class TranscriptomeDecoder:
306
206
  # Learning rate scheduler
307
207
  scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
308
208
 
309
- # Loss function
310
- criterion = nn.MSELoss()
209
+ # Loss function with dimension safety
210
+ def safe_loss(pred, target):
211
+ # Ensure dimensions match
212
+ if pred.shape != target.shape:
213
+ print(f"⚠️ Dimension mismatch: pred {pred.shape}, target {target.shape}")
214
+ # Truncate to minimum dimension (safety measure)
215
+ min_dim = min(pred.shape[1], target.shape[1])
216
+ pred = pred[:, :min_dim]
217
+ target = target[:, :min_dim]
218
+
219
+ def correlation_loss(pred, target):
220
+ pred_centered = pred - pred.mean(dim=1, keepdim=True)
221
+ target_centered = target - target.mean(dim=1, keepdim=True)
222
+
223
+ correlation = (pred_centered * target_centered).sum(dim=1) / (
224
+ torch.sqrt(torch.sum(pred_centered ** 2, dim=1)) *
225
+ torch.sqrt(torch.sum(target_centered ** 2, dim=1)) + 1e-8
226
+ )
227
+
228
+ return 1 - correlation.mean()
229
+
230
+ mse_loss = F.mse_loss(pred, target)
231
+ poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()
232
+ corr_loss = correlation_loss(pred, target)
233
+ return mse_loss + 0.5 * poisson_loss + 0.3 * corr_loss
311
234
 
312
235
  # Training history
313
236
  history = {
314
- 'train_loss': [], 'val_loss': [],
315
- 'learning_rate': [], 'memory_usage': []
237
+ 'train_loss': [],
238
+ 'val_loss': [],
239
+ 'learning_rate': []
316
240
  }
317
241
 
318
242
  best_val_loss = float('inf')
243
+ patience = 15
244
+ patience_counter = 0
319
245
 
246
+ print("\n📈 Starting training loop...")
320
247
  for epoch in range(1, num_epochs + 1):
321
- print(f"\n📍 Epoch {epoch}/{num_epochs}")
322
-
323
- # Training phase with memory monitoring
324
- train_loss = self._train_epoch(
325
- train_loader, optimizer, criterion, scaler if self.mixed_precision else None
326
- )
248
+ # Training phase
249
+ train_loss = self._train_epoch(train_loader, optimizer, safe_loss)
327
250
 
328
251
  # Validation phase
329
- val_loss = self._validate_epoch(val_loader, criterion)
252
+ val_loss = self._validate_epoch(val_loader, safe_loss)
330
253
 
331
254
  # Update scheduler
332
255
  scheduler.step()
256
+ current_lr = scheduler.get_last_lr()[0]
333
257
 
334
258
  # Record history
335
259
  history['train_loss'].append(train_loss)
336
260
  history['val_loss'].append(val_loss)
337
- history['learning_rate'].append(optimizer.param_groups[0]['lr'])
338
-
339
- # Memory usage tracking
340
- if torch.cuda.is_available():
341
- memory_used = torch.cuda.memory_allocated() / 1024**3 # GB
342
- history['memory_usage'].append(memory_used)
343
- print(f"💾 GPU Memory: {memory_used:.1f}GB / 20GB")
261
+ history['learning_rate'].append(current_lr)
344
262
 
345
- print(f"📊 Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
346
- print(f"⚡ Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")
263
+ # Print progress
264
+ if epoch % 5 == 0 or epoch == 1:
265
+ print(f"📍 Epoch {epoch:3d}/{num_epochs} | "
266
+ f"Train Loss: {train_loss:.4f} | "
267
+ f"Val Loss: {val_loss:.4f} | "
268
+ f"LR: {current_lr:.2e}")
347
269
 
348
- # Save best model
270
+ # Early stopping and model saving
349
271
  if val_loss < best_val_loss:
350
272
  best_val_loss = val_loss
273
+ patience_counter = 0
351
274
  self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
352
- print("💾 Best model saved!")
275
+ if epoch % 10 == 0:
276
+ print(f"💾 Best model saved (Val Loss: {best_val_loss:.4f})")
277
+ else:
278
+ patience_counter += 1
279
+ if patience_counter >= patience:
280
+ print(f"🛑 Early stopping at epoch {epoch}")
281
+ break
353
282
 
283
+ # Training completed
354
284
  self.is_trained = True
355
285
  self.training_history = history
356
286
  self.best_val_loss = best_val_loss
357
287
 
358
- print(f"\n🎉 Training completed! Best validation loss: {best_val_loss:.4f}")
288
+ print(f"\n🎉 Training completed!")
289
+ print(f"🏆 Best validation loss: {best_val_loss:.4f}")
290
+ print(f"📊 Final training loss: {history['train_loss'][-1]:.4f}")
291
+
359
292
  return history
360
293
 
361
- def _train_epoch(self, train_loader, optimizer, criterion, scaler=None):
362
- """Training epoch"""
294
+ def _validate_data_dimensions(self, latent_data, expression_data, data_type):
295
+ """Validate input data dimensions"""
296
+ assert latent_data.shape[1] == self.latent_dim, (
297
+ f"{data_type} latent dimension mismatch: expected {self.latent_dim}, got {latent_data.shape[1]}")
298
+ assert expression_data.shape[1] == self.gene_dim, (
299
+ f"{data_type} gene dimension mismatch: expected {self.gene_dim}, got {expression_data.shape[1]}")
300
+ assert latent_data.shape[0] == expression_data.shape[0], (
301
+ f"{data_type} sample count mismatch: latent {latent_data.shape[0]}, expression {expression_data.shape[0]}")
302
+ print(f"✅ {data_type} data dimensions validated")
303
+
304
+ def _create_safe_dataset(self, latent_data, expression_data):
305
+ """Create dataset with safety checks"""
306
+ class SafeDataset(Dataset):
307
+ def __init__(self, latent, expression):
308
+ self.latent = torch.FloatTensor(latent)
309
+ self.expression = torch.FloatTensor(expression)
310
+
311
+ # Safety check
312
+ if self.latent.shape[0] != self.expression.shape[0]:
313
+ raise ValueError(f"Sample count mismatch: latent {self.latent.shape[0]}, expression {self.expression.shape[0]}")
314
+
315
+ def __len__(self):
316
+ return len(self.latent)
317
+
318
+ def __getitem__(self, idx):
319
+ return self.latent[idx], self.expression[idx]
320
+
321
+ return SafeDataset(latent_data, expression_data)
322
+
323
+ def _train_epoch(self, train_loader, optimizer, loss_fn):
324
+ """Train for one epoch with dimension safety"""
363
325
  self.model.train()
364
326
  total_loss = 0
365
327
 
366
- pbar = tqdm(train_loader, desc='Training')
367
- for latent, target in pbar:
368
- latent = latent.to(self.device, non_blocking=True)
369
- target = target.to(self.device, non_blocking=True)
370
-
371
- optimizer.zero_grad(set_to_none=True) # Memory optimization
328
+ for batch_idx, (latent, target) in enumerate(train_loader):
329
+ latent = latent.to(self.device)
330
+ target = target.to(self.device)
372
331
 
373
- if scaler: # Mixed precision training
374
- with torch.cuda.amp.autocast():
375
- pred = self.model(latent)
376
- loss = criterion(pred, target)
332
+ # Dimension check
333
+ if latent.shape[1] != self.latent_dim:
334
+ print(f"⚠️ Batch {batch_idx}: Latent dim mismatch {latent.shape[1]} != {self.latent_dim}")
335
+ continue
377
336
 
378
- scaler.scale(loss).backward()
379
- scaler.step(optimizer)
380
- scaler.update()
381
- else:
382
- pred = self.model(latent)
383
- loss = criterion(pred, target)
384
- loss.backward()
385
- optimizer.step()
337
+ optimizer.zero_grad()
338
+ pred = self.model(latent)
386
339
 
387
- total_loss += loss.item()
388
- pbar.set_postfix({'Loss': f'{loss.item():.4f}'})
340
+ # Final dimension check before loss calculation
341
+ if pred.shape[1] != target.shape[1]:
342
+ min_dim = min(pred.shape[1], target.shape[1])
343
+ pred = pred[:, :min_dim]
344
+ target = target[:, :min_dim]
345
+ if batch_idx == 0: # Only warn once
346
+ print(f"⚠️ Truncating to min dimension: {min_dim}")
389
347
 
390
- # Clear memory
391
- del pred, loss
392
- if torch.cuda.is_available():
393
- torch.cuda.empty_cache()
348
+ loss = loss_fn(pred, target)
349
+ loss.backward()
350
+
351
+ # Gradient clipping for stability
352
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
353
+ optimizer.step()
354
+
355
+ total_loss += loss.item()
394
356
 
395
357
  return total_loss / len(train_loader)
396
358
 
397
- def _validate_epoch(self, val_loader, criterion):
398
- """Validation"""
359
+ def _validate_epoch(self, val_loader, loss_fn):
360
+ """Validate for one epoch with dimension safety"""
399
361
  self.model.eval()
400
362
  total_loss = 0
401
363
 
402
364
  with torch.no_grad():
403
- for latent, target in val_loader:
404
- latent = latent.to(self.device, non_blocking=True)
405
- target = target.to(self.device, non_blocking=True)
365
+ for batch_idx, (latent, target) in enumerate(val_loader):
366
+ latent = latent.to(self.device)
367
+ target = target.to(self.device)
406
368
 
407
369
  pred = self.model(latent)
408
- loss = criterion(pred, target)
409
- total_loss += loss.item()
410
370
 
411
- # Clear memory
412
- del pred, loss
371
+ # Dimension safety
372
+ if pred.shape[1] != target.shape[1]:
373
+ min_dim = min(pred.shape[1], target.shape[1])
374
+ pred = pred[:, :min_dim]
375
+ target = target[:, :min_dim]
376
+
377
+ loss = loss_fn(pred, target)
378
+ total_loss += loss.item()
413
379
 
414
380
  return total_loss / len(val_loader)
415
381
 
416
- def _create_dataset(self, latent_data, expression_data):
417
- """Create dataset"""
418
- class EfficientDataset(Dataset):
419
- def __init__(self, latent, expression):
420
- self.latent = torch.FloatTensor(latent)
421
- self.expression = torch.FloatTensor(expression)
422
-
423
- def __len__(self):
424
- return len(self.latent)
425
-
426
- def __getitem__(self, idx):
427
- return self.latent[idx], self.expression[idx]
428
-
429
- return EfficientDataset(latent_data, expression_data)
430
-
431
382
  def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
432
- """Save checkpoint"""
383
+ """Save model checkpoint"""
433
384
  torch.save({
434
385
  'epoch': epoch,
435
386
  'model_state_dict': self.model.state_dict(),
@@ -444,22 +395,26 @@ class TranscriptomeDecoder:
444
395
  }
445
396
  }, path)
446
397
 
447
- def predict(self, latent_data: np.ndarray, batch_size: int = 8) -> np.ndarray:
398
+ def predict(self, latent_data: np.ndarray, batch_size: int = 32) -> np.ndarray:
448
399
  """
449
- Prediction
400
+ Predict gene expression from latent variables
450
401
 
451
402
  Args:
452
403
  latent_data: Latent variables [n_samples, latent_dim]
453
- batch_size: Prediction batch size for memory control
404
+ batch_size: Prediction batch size
454
405
 
455
406
  Returns:
456
407
  expression: Predicted expression [n_samples, gene_dim]
457
408
  """
458
409
  if not self.is_trained:
459
- warnings.warn("Model not trained. Predictions may be inaccurate.")
410
+ warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")
460
411
 
461
412
  self.model.eval()
462
413
 
414
+ # Input validation
415
+ if latent_data.shape[1] != self.latent_dim:
416
+ raise ValueError(f"Latent dimension mismatch: expected {self.latent_dim}, got {latent_data.shape[1]}")
417
+
463
418
  if isinstance(latent_data, np.ndarray):
464
419
  latent_data = torch.FloatTensor(latent_data)
465
420
 
@@ -470,76 +425,83 @@ class TranscriptomeDecoder:
470
425
  batch_latent = latent_data[i:i+batch_size].to(self.device)
471
426
  batch_pred = self.model(batch_latent)
472
427
  predictions.append(batch_pred.cpu())
473
-
474
- # Clear memory
475
- del batch_pred
476
- if torch.cuda.is_available():
477
- torch.cuda.empty_cache()
478
428
 
479
429
  return torch.cat(predictions).numpy()
480
430
 
481
431
  def load_model(self, model_path: str):
482
432
  """Load pre-trained model"""
483
433
  checkpoint = torch.load(model_path, map_location=self.device)
434
+
435
+ # Check model configuration
436
+ if 'model_config' in checkpoint:
437
+ config = checkpoint['model_config']
438
+ if (config['latent_dim'] != self.latent_dim or
439
+ config['gene_dim'] != self.gene_dim):
440
+ print("⚠️ Model configuration mismatch. Reinitializing model.")
441
+ self.model = self._build_model()
442
+ self.model.to(self.device)
443
+
484
444
  self.model.load_state_dict(checkpoint['model_state_dict'])
485
445
  self.is_trained = True
486
446
  self.training_history = checkpoint.get('training_history')
487
447
  self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
488
- print(f"✅ Model loaded! Best validation loss: {self.best_val_loss:.4f}")
489
-
490
- def get_memory_info(self) -> Dict:
491
- """Get memory usage information"""
492
- if torch.cuda.is_available():
493
- memory_allocated = torch.cuda.memory_allocated() / 1024**3
494
- memory_reserved = torch.cuda.memory_reserved() / 1024**3
495
- return {
496
- 'allocated_gb': memory_allocated,
497
- 'reserved_gb': memory_reserved,
498
- 'available_gb': 20 - memory_allocated,
499
- 'utilization_percent': (memory_allocated / 20) * 100
500
- }
501
- return {'available_gb': 'N/A (CPU mode)'}
448
+
449
+ print(f"✅ Model loaded successfully!")
450
+ print(f"🏆 Best validation loss: {self.best_val_loss:.4f}")
502
451
 
452
+ def get_model_info(self) -> Dict:
453
+ """Get model information"""
454
+ return {
455
+ 'is_trained': self.is_trained,
456
+ 'best_val_loss': self.best_val_loss,
457
+ 'parameters': sum(p.numel() for p in self.model.parameters()),
458
+ 'latent_dim': self.latent_dim,
459
+ 'gene_dim': self.gene_dim,
460
+ 'hidden_dim': self.hidden_dim,
461
+ 'device': str(self.device)
462
+ }
503
463
  '''
504
- # Example usage with memory monitoring
464
+ # Example usage
505
465
  def example_usage():
506
- """Memory-efficient example"""
466
+ """Example demonstration with dimension safety"""
507
467
 
508
- # 1. Initialize memory-efficient decoder
509
- decoder = TranscriptomeDecoder(
468
+ # 1. Initialize decoder
469
+ decoder = SimpleTranscriptomeDecoder(
510
470
  latent_dim=100,
511
471
  gene_dim=2000, # Reduced for example
512
- hidden_dim=256 # Reduced for memory
472
+ hidden_dim=256
513
473
  )
514
474
 
515
- # Check memory info
516
- memory_info = decoder.get_memory_info()
517
- print(f"📊 Memory Info: {memory_info}")
518
-
519
- # 2. Generate example data
520
- n_samples = 500 # Reduced for memory
475
+ # 2. Generate example data with correct dimensions
476
+ n_samples = 1000
521
477
  latent_data = np.random.randn(n_samples, 100).astype(np.float32)
522
- expression_data = np.random.randn(n_samples, 2000).astype(np.float32)
478
+
479
+ # Create simulated expression data
480
+ weights = np.random.randn(100, 2000) * 0.1
481
+ expression_data = np.tanh(latent_data.dot(weights))
523
482
  expression_data = np.maximum(expression_data, 0) # Non-negative
524
483
 
525
- print(f"📈 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")
484
+ print(f"📊 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")
526
485
 
527
- # 3. Train with memory monitoring
486
+ # 3. Train the model
528
487
  history = decoder.train(
529
488
  train_latent=latent_data,
530
489
  train_expression=expression_data,
531
- batch_size=8, # Small batch for memory
532
- num_epochs=20 # Reduced for example
490
+ batch_size=32,
491
+ num_epochs=50,
492
+ learning_rate=1e-4
533
493
  )
534
494
 
535
- # 4. Memory-efficient prediction
536
- test_latent = np.random.randn(5, 100).astype(np.float32)
537
- predictions = decoder.predict(test_latent, batch_size=2)
495
+ # 4. Make predictions
496
+ test_latent = np.random.randn(10, 100).astype(np.float32)
497
+ predictions = decoder.predict(test_latent)
538
498
  print(f"🔮 Prediction shape: {predictions.shape}")
539
499
 
540
- # 5. Final memory check
541
- final_memory = decoder.get_memory_info()
542
- print(f"💾 Final memory usage: {final_memory}")
500
+ # 5. Get model info
501
+ info = decoder.get_model_info()
502
+ print(f"\n📋 Model Info:")
503
+ for key, value in info.items():
504
+ print(f" {key}: {value}")
543
505
 
544
506
  return decoder
545
507