SURE-tools 2.4.2__py3-none-any.whl → 2.4.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. Click here for more details.

@@ -1,28 +1,28 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
- import torch.optim as optim
4
3
  import torch.nn.functional as F
4
+ import torch.optim as optim
5
5
  from torch.utils.data import Dataset, DataLoader
6
6
  import numpy as np
7
- from typing import Dict, List, Tuple, Optional
8
- import matplotlib.pyplot as plt
9
- from tqdm import tqdm
7
+ from typing import Dict, Optional
10
8
  import warnings
11
9
  warnings.filterwarnings('ignore')
12
10
 
13
11
  class TranscriptomeDecoder:
12
+ """Transcriptome decoder"""
13
+
14
14
  def __init__(self,
15
15
  latent_dim: int = 100,
16
16
  gene_dim: int = 60000,
17
- hidden_dim: int = 512, # Reduced for memory efficiency
17
+ hidden_dim: int = 512,
18
18
  device: str = None):
19
19
  """
20
- Whole-transcriptome decoder
20
+ Simple but powerful decoder for latent to transcriptome mapping
21
21
 
22
22
  Args:
23
23
  latent_dim: Latent variable dimension (typically 50-100)
24
24
  gene_dim: Number of genes (full transcriptome ~60,000)
25
- hidden_dim: Hidden dimension (reduced for memory efficiency)
25
+ hidden_dim: Hidden dimension optimized
26
26
  device: Computation device
27
27
  """
28
28
  self.latent_dim = latent_dim
@@ -30,10 +30,6 @@ class TranscriptomeDecoder:
30
30
  self.hidden_dim = hidden_dim
31
31
  self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
32
32
 
33
- # Memory optimization settings
34
- self.gradient_checkpointing = True
35
- self.mixed_precision = True
36
-
37
33
  # Initialize model
38
34
  self.model = self._build_model()
39
35
  self.model.to(self.device)
@@ -43,123 +39,47 @@ class TranscriptomeDecoder:
43
39
  self.training_history = None
44
40
  self.best_val_loss = float('inf')
45
41
 
46
- print(f"🚀 TranscriptomeDecoder Initialized:")
42
+ print(f"🚀 SimpleTranscriptomeDecoder Initialized:")
47
43
  print(f" - Latent Dimension: {latent_dim}")
48
44
  print(f" - Gene Dimension: {gene_dim}")
49
45
  print(f" - Hidden Dimension: {hidden_dim}")
50
46
  print(f" - Device: {self.device}")
51
47
  print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
52
48
 
53
- class MemoryEfficientBlock(nn.Module):
54
- """Memory-efficient building block with gradient checkpointing"""
55
- def __init__(self, input_dim, output_dim, use_checkpointing=True):
56
- super().__init__()
57
- self.use_checkpointing = use_checkpointing
58
- self.net = nn.Sequential(
59
- nn.Linear(input_dim, output_dim),
60
- nn.BatchNorm1d(output_dim),
61
- nn.GELU(),
62
- nn.Dropout(0.1)
63
- )
64
-
65
- def forward(self, x):
66
- if self.use_checkpointing and self.training:
67
- return torch.utils.checkpoint.checkpoint(self.net, x)
68
- return self.net(x)
69
-
70
- class SparseGeneProjection(nn.Module):
71
- """Sparse gene projection to reduce memory usage"""
72
- def __init__(self, latent_dim, gene_dim, projection_dim=256):
73
- super().__init__()
74
- self.projection_dim = projection_dim
75
- self.gene_embeddings = nn.Parameter(torch.randn(gene_dim, projection_dim) * 0.02)
76
- self.latent_projection = nn.Linear(latent_dim, projection_dim)
77
- self.activation = nn.GELU()
78
-
79
- def forward(self, latent):
80
- # Project latent to gene space efficiently
81
- batch_size = latent.shape[0]
82
- latent_proj = self.latent_projection(latent) # [batch, projection_dim]
83
-
84
- # Efficient matrix multiplication
85
- gene_embeds = self.gene_embeddings.T # [projection_dim, gene_dim]
86
- output = torch.matmul(latent_proj, gene_embeds) # [batch, gene_dim]
87
-
88
- return self.activation(output)
89
-
90
- class ChunkedTransformer(nn.Module):
91
- """Process genes in chunks to reduce memory usage"""
92
- def __init__(self, gene_dim, hidden_dim, chunk_size=1000, num_layers=4):
93
- super().__init__()
94
- self.chunk_size = chunk_size
95
- self.num_chunks = (gene_dim + chunk_size - 1) // chunk_size
96
- self.layers = nn.ModuleList([
97
- nn.Sequential(
98
- nn.Linear(hidden_dim, hidden_dim),
99
- nn.GELU(),
100
- nn.Dropout(0.1),
101
- nn.Linear(hidden_dim, hidden_dim),
102
- ) for _ in range(num_layers)
103
- ])
104
-
105
- def forward(self, x):
106
- # Process in chunks to save memory
107
- batch_size = x.shape[0]
108
- output = torch.zeros_like(x)
109
-
110
- for i in range(self.num_chunks):
111
- start_idx = i * self.chunk_size
112
- end_idx = min((i + 1) * self.chunk_size, x.shape[1])
113
-
114
- chunk = x[:, start_idx:end_idx]
115
- for layer in self.layers:
116
- chunk = layer(chunk) + chunk # Residual connection
117
-
118
- output[:, start_idx:end_idx] = chunk
119
-
120
- return output
121
-
122
49
  class Decoder(nn.Module):
123
- """Decoder model"""
124
- def __init__(self, latent_dim, gene_dim, hidden_dim):
50
+ """Memory-efficient decoder architecture with dimension handling"""
51
+
52
+ def __init__(self, latent_dim: int, gene_dim: int, hidden_dim: int):
125
53
  super().__init__()
126
54
  self.latent_dim = latent_dim
127
55
  self.gene_dim = gene_dim
128
56
  self.hidden_dim = hidden_dim
129
57
 
130
- # Stage 1: Latent expansion (memory efficient)
58
+ # Stage 1: Latent variable expansion
131
59
  self.latent_expansion = nn.Sequential(
132
60
  nn.Linear(latent_dim, hidden_dim * 2),
61
+ nn.BatchNorm1d(hidden_dim * 2),
133
62
  nn.GELU(),
134
63
  nn.Dropout(0.1),
135
64
  nn.Linear(hidden_dim * 2, hidden_dim),
65
+ nn.BatchNorm1d(hidden_dim),
66
+ nn.GELU(),
136
67
  )
137
68
 
138
- # Stage 2: Sparse gene projection
139
- self.gene_projection = TranscriptomeDecoder.SparseGeneProjection(
140
- latent_dim, gene_dim, hidden_dim
141
- )
142
-
143
- # Stage 3: Chunked processing
144
- self.chunked_processor = TranscriptomeDecoder.ChunkedTransformer(
145
- gene_dim, hidden_dim, chunk_size=2000, num_layers=3
69
+ # Stage 2: Direct projection to gene dimension (simpler approach)
70
+ self.gene_projector = nn.Sequential(
71
+ nn.Linear(hidden_dim, hidden_dim * 2),
72
+ nn.GELU(),
73
+ nn.Dropout(0.1),
74
+ nn.Linear(hidden_dim * 2, gene_dim), # Direct projection to gene_dim
146
75
  )
147
76
 
148
- # Stage 4: Multi-head output with memory efficiency
149
- self.output_heads = nn.ModuleList([
150
- nn.Sequential(
151
- nn.Linear(hidden_dim, hidden_dim // 2),
152
- nn.GELU(),
153
- nn.Linear(hidden_dim // 2, 1)
154
- ) for _ in range(2) # Reduced from 3 to 2 heads
155
- ])
156
-
157
- # Adaptive fusion
158
- self.fusion_gate = nn.Sequential(
159
- nn.Linear(hidden_dim, hidden_dim // 4),
77
+ # Stage 3: Lightweight gene interaction
78
+ self.gene_interaction = nn.Sequential(
79
+ nn.Conv1d(1, 32, kernel_size=3, padding=1),
160
80
  nn.GELU(),
161
- nn.Linear(hidden_dim // 4, len(self.output_heads)),
162
- nn.Softmax(dim=-1)
81
+ nn.Dropout1d(0.1),
82
+ nn.Conv1d(32, 1, kernel_size=3, padding=1),
163
83
  )
164
84
 
165
85
  # Output scaling
@@ -169,111 +89,113 @@ class TranscriptomeDecoder:
169
89
  self._init_weights()
170
90
 
171
91
  def _init_weights(self):
92
+ """Weight initialization"""
172
93
  for module in self.modules():
173
94
  if isinstance(module, nn.Linear):
174
95
  nn.init.xavier_uniform_(module.weight)
175
96
  if module.bias is not None:
176
97
  nn.init.zeros_(module.bias)
98
+ elif isinstance(module, nn.Conv1d):
99
+ nn.init.kaiming_uniform_(module.weight)
100
+ if module.bias is not None:
101
+ nn.init.zeros_(module.bias)
177
102
 
178
- def forward(self, latent):
103
+ def forward(self, latent: torch.Tensor) -> torch.Tensor:
179
104
  batch_size = latent.shape[0]
180
105
 
181
- # 1. Latent expansion
182
- latent_expanded = self.latent_expansion(latent)
183
-
184
- # 2. Gene projection (memory efficient)
185
- gene_features = self.gene_projection(latent)
106
+ # 1. Expand latent variables
107
+ latent_features = self.latent_expansion(latent) # [batch_size, hidden_dim]
186
108
 
187
- # 3. Add latent information
188
- print(f'{gene_features.shape}; {latent_expanded.shape}')
189
- gene_features = gene_features + latent_expanded.unsqueeze(1)
109
+ # 2. Direct projection to gene dimension
110
+ gene_output = self.gene_projector(latent_features) # [batch_size, gene_dim]
190
111
 
191
- # 4. Chunked processing (memory efficient)
192
- gene_features = self.chunked_processor(gene_features)
193
-
194
- # 5. Multi-head output with chunking
195
- final_output = torch.zeros(batch_size, self.gene_dim, device=latent.device)
196
-
197
- # Process output in chunks
198
- chunk_size = 5000
199
- for i in range(0, self.gene_dim, chunk_size):
200
- end_idx = min(i + chunk_size, self.gene_dim)
201
- chunk = gene_features[:, i:end_idx]
202
-
203
- head_outputs = []
204
- for head in self.output_heads:
205
- head_out = head(chunk).squeeze(-1)
206
- head_outputs.append(head_out)
207
-
208
- # Adaptive fusion
209
- gate_weights = self.fusion_gate(chunk.mean(dim=1, keepdim=True))
210
- gate_weights = gate_weights.unsqueeze(1)
211
-
212
- # Weighted fusion
213
- chunk_output = torch.zeros_like(head_outputs[0])
214
- for j, head_out in enumerate(head_outputs):
215
- chunk_output = chunk_output + gate_weights[:, :, j] * head_out
216
-
217
- final_output[:, i:end_idx] = chunk_output
112
+ # 3. Gene interaction with dimension safety
113
+ if self.gene_dim > 1: # Only apply if gene_dim > 1
114
+ gene_output = gene_output.unsqueeze(1) # [batch_size, 1, gene_dim]
115
+ interaction_output = self.gene_interaction(gene_output) # [batch_size, 1, gene_dim]
116
+ gene_output = gene_output + interaction_output # Residual connection
117
+ gene_output = gene_output.squeeze(1) # [batch_size, gene_dim]
218
118
 
219
- # Final activation
220
- final_output = F.softplus(final_output * self.output_scale + self.output_bias)
119
+ # 4. Final activation (ensure non-negative)
120
+ gene_output = F.softplus(gene_output * self.output_scale + self.output_bias)
221
121
 
222
- return final_output
122
+ return gene_output
223
123
 
224
124
  def _build_model(self):
225
- """Build model"""
125
+ """Build the decoder model"""
226
126
  return self.Decoder(self.latent_dim, self.gene_dim, self.hidden_dim)
227
127
 
128
+ def _create_dataset(self, latent_data, expression_data):
129
+ """Create dataset with dimension validation"""
130
+ class SimpleDataset(Dataset):
131
+ def __init__(self, latent, expression):
132
+ # Ensure dimensions match
133
+ assert latent.shape[0] == expression.shape[0], "Sample count mismatch"
134
+ assert latent.shape[1] == self.latent_dim, f"Latent dim mismatch: expected {self.latent_dim}, got {latent.shape[1]}"
135
+ assert expression.shape[1] == self.gene_dim, f"Gene dim mismatch: expected {self.gene_dim}, got {expression.shape[1]}"
136
+
137
+ self.latent = torch.FloatTensor(latent)
138
+ self.expression = torch.FloatTensor(expression)
139
+
140
+ def __len__(self):
141
+ return len(self.latent)
142
+
143
+ def __getitem__(self, idx):
144
+ return self.latent[idx], self.expression[idx]
145
+
146
+ return SimpleDataset(latent_data, expression_data)
147
+
228
148
  def train(self,
229
149
  train_latent: np.ndarray,
230
150
  train_expression: np.ndarray,
231
151
  val_latent: np.ndarray = None,
232
152
  val_expression: np.ndarray = None,
233
- batch_size: int = 16, # Reduced batch size for memory
153
+ batch_size: int = 32,
234
154
  num_epochs: int = 100,
235
155
  learning_rate: float = 1e-4,
236
156
  checkpoint_path: str = 'transcriptome_decoder.pth'):
237
157
  """
238
- Memory-efficient training with optimizations
158
+ Train the decoder model with dimension safety
239
159
 
240
160
  Args:
241
- train_latent: Training latent variables
242
- train_expression: Training expression data
243
- val_latent: Validation latent variables
244
- val_expression: Validation expression data
245
- batch_size: Reduced batch size for memory constraints
161
+ train_latent: Training latent variables [n_samples, latent_dim]
162
+ train_expression: Training expression data [n_samples, gene_dim]
163
+ val_latent: Validation latent variables (optional)
164
+ val_expression: Validation expression data (optional)
165
+ batch_size: Batch size optimized for memory
246
166
  num_epochs: Number of training epochs
247
167
  learning_rate: Learning rate
248
- checkpoint_path: Model save path
168
+ checkpoint_path: Path to save the best model
249
169
  """
250
- print("🚀 Starting Training...")
251
- print(f"📊 Batch size: {batch_size}")
170
+ print("🚀 Starting training...")
252
171
 
253
- # Enable memory optimizations
254
- torch.backends.cudnn.benchmark = True
255
- if self.mixed_precision:
256
- scaler = torch.cuda.amp.GradScaler()
172
+ # Dimension validation
173
+ self._validate_data_dimensions(train_latent, train_expression, "Training")
174
+ if val_latent is not None and val_expression is not None:
175
+ self._validate_data_dimensions(val_latent, val_expression, "Validation")
257
176
 
258
177
  # Data preparation
259
- train_dataset = self._create_dataset(train_latent, train_expression)
178
+ train_dataset = self._create_safe_dataset(train_latent, train_expression)
260
179
 
261
180
  if val_latent is not None and val_expression is not None:
262
- val_dataset = self._create_dataset(val_latent, val_expression)
181
+ val_dataset = self._create_safe_dataset(val_latent, val_expression)
182
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
183
+ val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
184
+ print(f"📈 Using provided validation data: {len(val_dataset)} samples")
263
185
  else:
264
186
  # Auto split
265
187
  train_size = int(0.9 * len(train_dataset))
266
188
  val_size = len(train_dataset) - train_size
267
- train_dataset, val_dataset = torch.utils.data.random_split(
268
- train_dataset, [train_size, val_size]
269
- )
189
+ train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
190
+ train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=2)
191
+ val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=2)
192
+ print(f"📈 Auto-split validation: {val_size} samples")
270
193
 
271
- train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
272
- pin_memory=True, num_workers=2)
273
- val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
274
- pin_memory=True, num_workers=2)
194
+ print(f"📊 Training samples: {len(train_loader.dataset)}")
195
+ print(f"📊 Validation samples: {len(val_loader.dataset)}")
196
+ print(f"📊 Batch size: {batch_size}")
275
197
 
276
- # Optimizer with memory-friendly settings
198
+ # Optimizer configuration
277
199
  optimizer = optim.AdamW(
278
200
  self.model.parameters(),
279
201
  lr=learning_rate,
@@ -284,130 +206,181 @@ class TranscriptomeDecoder:
284
206
  # Learning rate scheduler
285
207
  scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
286
208
 
287
- # Loss function
288
- criterion = nn.MSELoss()
209
+ # Loss function with dimension safety
210
+ def safe_loss(pred, target):
211
+ # Ensure dimensions match
212
+ if pred.shape != target.shape:
213
+ print(f"⚠️ Dimension mismatch: pred {pred.shape}, target {target.shape}")
214
+ # Truncate to minimum dimension (safety measure)
215
+ min_dim = min(pred.shape[1], target.shape[1])
216
+ pred = pred[:, :min_dim]
217
+ target = target[:, :min_dim]
218
+
219
+ def correlation_loss(pred, target):
220
+ pred_centered = pred - pred.mean(dim=1, keepdim=True)
221
+ target_centered = target - target.mean(dim=1, keepdim=True)
222
+
223
+ correlation = (pred_centered * target_centered).sum(dim=1) / (
224
+ torch.sqrt(torch.sum(pred_centered ** 2, dim=1)) *
225
+ torch.sqrt(torch.sum(target_centered ** 2, dim=1)) + 1e-8
226
+ )
227
+
228
+ return 1 - correlation.mean()
229
+
230
+ mse_loss = F.mse_loss(pred, target)
231
+ poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()
232
+ corr_loss = correlation_loss(pred, target)
233
+ return mse_loss + 0.5 * poisson_loss + 0.3 * corr_loss
289
234
 
290
235
  # Training history
291
236
  history = {
292
- 'train_loss': [], 'val_loss': [],
293
- 'learning_rate': [], 'memory_usage': []
237
+ 'train_loss': [],
238
+ 'val_loss': [],
239
+ 'learning_rate': []
294
240
  }
295
241
 
296
242
  best_val_loss = float('inf')
243
+ patience = 15
244
+ patience_counter = 0
297
245
 
246
+ print("\n📈 Starting training loop...")
298
247
  for epoch in range(1, num_epochs + 1):
299
- print(f"\n📍 Epoch {epoch}/{num_epochs}")
300
-
301
- # Training phase with memory monitoring
302
- train_loss = self._train_epoch(
303
- train_loader, optimizer, criterion, scaler if self.mixed_precision else None
304
- )
248
+ # Training phase
249
+ train_loss = self._train_epoch(train_loader, optimizer, safe_loss)
305
250
 
306
251
  # Validation phase
307
- val_loss = self._validate_epoch(val_loader, criterion)
252
+ val_loss = self._validate_epoch(val_loader, safe_loss)
308
253
 
309
254
  # Update scheduler
310
255
  scheduler.step()
256
+ current_lr = scheduler.get_last_lr()[0]
311
257
 
312
258
  # Record history
313
259
  history['train_loss'].append(train_loss)
314
260
  history['val_loss'].append(val_loss)
315
- history['learning_rate'].append(optimizer.param_groups[0]['lr'])
316
-
317
- # Memory usage tracking
318
- if torch.cuda.is_available():
319
- memory_used = torch.cuda.memory_allocated() / 1024**3 # GB
320
- history['memory_usage'].append(memory_used)
321
- print(f"💾 GPU Memory: {memory_used:.1f}GB / 20GB")
261
+ history['learning_rate'].append(current_lr)
322
262
 
323
- print(f"📊 Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
324
- print(f"⚡ Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")
263
+ # Print progress
264
+ if epoch % 5 == 0 or epoch == 1:
265
+ print(f"📍 Epoch {epoch:3d}/{num_epochs} | "
266
+ f"Train Loss: {train_loss:.4f} | "
267
+ f"Val Loss: {val_loss:.4f} | "
268
+ f"LR: {current_lr:.2e}")
325
269
 
326
- # Save best model
270
+ # Early stopping and model saving
327
271
  if val_loss < best_val_loss:
328
272
  best_val_loss = val_loss
273
+ patience_counter = 0
329
274
  self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
330
- print("💾 Best model saved!")
275
+ if epoch % 10 == 0:
276
+ print(f"💾 Best model saved (Val Loss: {best_val_loss:.4f})")
277
+ else:
278
+ patience_counter += 1
279
+ if patience_counter >= patience:
280
+ print(f"🛑 Early stopping at epoch {epoch}")
281
+ break
331
282
 
283
+ # Training completed
332
284
  self.is_trained = True
333
285
  self.training_history = history
334
286
  self.best_val_loss = best_val_loss
335
287
 
336
- print(f"\n🎉 Training completed! Best validation loss: {best_val_loss:.4f}")
288
+ print(f"\n🎉 Training completed!")
289
+ print(f"🏆 Best validation loss: {best_val_loss:.4f}")
290
+ print(f"📊 Final training loss: {history['train_loss'][-1]:.4f}")
291
+
337
292
  return history
338
293
 
339
- def _train_epoch(self, train_loader, optimizer, criterion, scaler=None):
340
- """Training epoch"""
294
+ def _validate_data_dimensions(self, latent_data, expression_data, data_type):
295
+ """Validate input data dimensions"""
296
+ assert latent_data.shape[1] == self.latent_dim, (
297
+ f"{data_type} latent dimension mismatch: expected {self.latent_dim}, got {latent_data.shape[1]}")
298
+ assert expression_data.shape[1] == self.gene_dim, (
299
+ f"{data_type} gene dimension mismatch: expected {self.gene_dim}, got {expression_data.shape[1]}")
300
+ assert latent_data.shape[0] == expression_data.shape[0], (
301
+ f"{data_type} sample count mismatch: latent {latent_data.shape[0]}, expression {expression_data.shape[0]}")
302
+ print(f"✅ {data_type} data dimensions validated")
303
+
304
+ def _create_safe_dataset(self, latent_data, expression_data):
305
+ """Create dataset with safety checks"""
306
+ class SafeDataset(Dataset):
307
+ def __init__(self, latent, expression):
308
+ self.latent = torch.FloatTensor(latent)
309
+ self.expression = torch.FloatTensor(expression)
310
+
311
+ # Safety check
312
+ if self.latent.shape[0] != self.expression.shape[0]:
313
+ raise ValueError(f"Sample count mismatch: latent {self.latent.shape[0]}, expression {self.expression.shape[0]}")
314
+
315
+ def __len__(self):
316
+ return len(self.latent)
317
+
318
+ def __getitem__(self, idx):
319
+ return self.latent[idx], self.expression[idx]
320
+
321
+ return SafeDataset(latent_data, expression_data)
322
+
323
+ def _train_epoch(self, train_loader, optimizer, loss_fn):
324
+ """Train for one epoch with dimension safety"""
341
325
  self.model.train()
342
326
  total_loss = 0
343
327
 
344
- pbar = tqdm(train_loader, desc='Training')
345
- for latent, target in pbar:
346
- latent = latent.to(self.device, non_blocking=True)
347
- target = target.to(self.device, non_blocking=True)
348
-
349
- optimizer.zero_grad(set_to_none=True) # Memory optimization
328
+ for batch_idx, (latent, target) in enumerate(train_loader):
329
+ latent = latent.to(self.device)
330
+ target = target.to(self.device)
350
331
 
351
- if scaler: # Mixed precision training
352
- with torch.cuda.amp.autocast():
353
- pred = self.model(latent)
354
- loss = criterion(pred, target)
332
+ # Dimension check
333
+ if latent.shape[1] != self.latent_dim:
334
+ print(f"⚠️ Batch {batch_idx}: Latent dim mismatch {latent.shape[1]} != {self.latent_dim}")
335
+ continue
355
336
 
356
- scaler.scale(loss).backward()
357
- scaler.step(optimizer)
358
- scaler.update()
359
- else:
360
- pred = self.model(latent)
361
- loss = criterion(pred, target)
362
- loss.backward()
363
- optimizer.step()
337
+ optimizer.zero_grad()
338
+ pred = self.model(latent)
364
339
 
365
- total_loss += loss.item()
366
- pbar.set_postfix({'Loss': f'{loss.item():.4f}'})
340
+ # Final dimension check before loss calculation
341
+ if pred.shape[1] != target.shape[1]:
342
+ min_dim = min(pred.shape[1], target.shape[1])
343
+ pred = pred[:, :min_dim]
344
+ target = target[:, :min_dim]
345
+ if batch_idx == 0: # Only warn once
346
+ print(f"⚠️ Truncating to min dimension: {min_dim}")
347
+
348
+ loss = loss_fn(pred, target)
349
+ loss.backward()
367
350
 
368
- # Clear memory
369
- del pred, loss
370
- if torch.cuda.is_available():
371
- torch.cuda.empty_cache()
351
+ # Gradient clipping for stability
352
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
353
+ optimizer.step()
354
+
355
+ total_loss += loss.item()
372
356
 
373
357
  return total_loss / len(train_loader)
374
358
 
375
- def _validate_epoch(self, val_loader, criterion):
376
- """Validation"""
359
+ def _validate_epoch(self, val_loader, loss_fn):
360
+ """Validate for one epoch with dimension safety"""
377
361
  self.model.eval()
378
362
  total_loss = 0
379
363
 
380
364
  with torch.no_grad():
381
- for latent, target in val_loader:
382
- latent = latent.to(self.device, non_blocking=True)
383
- target = target.to(self.device, non_blocking=True)
365
+ for batch_idx, (latent, target) in enumerate(val_loader):
366
+ latent = latent.to(self.device)
367
+ target = target.to(self.device)
384
368
 
385
369
  pred = self.model(latent)
386
- loss = criterion(pred, target)
387
- total_loss += loss.item()
388
370
 
389
- # Clear memory
390
- del pred, loss
371
+ # Dimension safety
372
+ if pred.shape[1] != target.shape[1]:
373
+ min_dim = min(pred.shape[1], target.shape[1])
374
+ pred = pred[:, :min_dim]
375
+ target = target[:, :min_dim]
376
+
377
+ loss = loss_fn(pred, target)
378
+ total_loss += loss.item()
391
379
 
392
380
  return total_loss / len(val_loader)
393
381
 
394
- def _create_dataset(self, latent_data, expression_data):
395
- """Create dataset"""
396
- class EfficientDataset(Dataset):
397
- def __init__(self, latent, expression):
398
- self.latent = torch.FloatTensor(latent)
399
- self.expression = torch.FloatTensor(expression)
400
-
401
- def __len__(self):
402
- return len(self.latent)
403
-
404
- def __getitem__(self, idx):
405
- return self.latent[idx], self.expression[idx]
406
-
407
- return EfficientDataset(latent_data, expression_data)
408
-
409
382
  def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
410
- """Save checkpoint"""
383
+ """Save model checkpoint"""
411
384
  torch.save({
412
385
  'epoch': epoch,
413
386
  'model_state_dict': self.model.state_dict(),
@@ -422,22 +395,26 @@ class TranscriptomeDecoder:
422
395
  }
423
396
  }, path)
424
397
 
425
- def predict(self, latent_data: np.ndarray, batch_size: int = 8) -> np.ndarray:
398
+ def predict(self, latent_data: np.ndarray, batch_size: int = 32) -> np.ndarray:
426
399
  """
427
- Prediction
400
+ Predict gene expression from latent variables
428
401
 
429
402
  Args:
430
403
  latent_data: Latent variables [n_samples, latent_dim]
431
- batch_size: Prediction batch size for memory control
404
+ batch_size: Prediction batch size
432
405
 
433
406
  Returns:
434
407
  expression: Predicted expression [n_samples, gene_dim]
435
408
  """
436
409
  if not self.is_trained:
437
- warnings.warn("Model not trained. Predictions may be inaccurate.")
410
+ warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")
438
411
 
439
412
  self.model.eval()
440
413
 
414
+ # Input validation
415
+ if latent_data.shape[1] != self.latent_dim:
416
+ raise ValueError(f"Latent dimension mismatch: expected {self.latent_dim}, got {latent_data.shape[1]}")
417
+
441
418
  if isinstance(latent_data, np.ndarray):
442
419
  latent_data = torch.FloatTensor(latent_data)
443
420
 
@@ -448,76 +425,83 @@ class TranscriptomeDecoder:
448
425
  batch_latent = latent_data[i:i+batch_size].to(self.device)
449
426
  batch_pred = self.model(batch_latent)
450
427
  predictions.append(batch_pred.cpu())
451
-
452
- # Clear memory
453
- del batch_pred
454
- if torch.cuda.is_available():
455
- torch.cuda.empty_cache()
456
428
 
457
429
  return torch.cat(predictions).numpy()
458
430
 
459
431
  def load_model(self, model_path: str):
460
432
  """Load pre-trained model"""
461
433
  checkpoint = torch.load(model_path, map_location=self.device)
434
+
435
+ # Check model configuration
436
+ if 'model_config' in checkpoint:
437
+ config = checkpoint['model_config']
438
+ if (config['latent_dim'] != self.latent_dim or
439
+ config['gene_dim'] != self.gene_dim):
440
+ print("⚠️ Model configuration mismatch. Reinitializing model.")
441
+ self.model = self._build_model()
442
+ self.model.to(self.device)
443
+
462
444
  self.model.load_state_dict(checkpoint['model_state_dict'])
463
445
  self.is_trained = True
464
446
  self.training_history = checkpoint.get('training_history')
465
447
  self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
466
- print(f"✅ Model loaded! Best validation loss: {self.best_val_loss:.4f}")
467
-
468
- def get_memory_info(self) -> Dict:
469
- """Get memory usage information"""
470
- if torch.cuda.is_available():
471
- memory_allocated = torch.cuda.memory_allocated() / 1024**3
472
- memory_reserved = torch.cuda.memory_reserved() / 1024**3
473
- return {
474
- 'allocated_gb': memory_allocated,
475
- 'reserved_gb': memory_reserved,
476
- 'available_gb': 20 - memory_allocated,
477
- 'utilization_percent': (memory_allocated / 20) * 100
478
- }
479
- return {'available_gb': 'N/A (CPU mode)'}
448
+
449
+ print(f"✅ Model loaded successfully!")
450
+ print(f"🏆 Best validation loss: {self.best_val_loss:.4f}")
480
451
 
452
+ def get_model_info(self) -> Dict:
453
+ """Get model information"""
454
+ return {
455
+ 'is_trained': self.is_trained,
456
+ 'best_val_loss': self.best_val_loss,
457
+ 'parameters': sum(p.numel() for p in self.model.parameters()),
458
+ 'latent_dim': self.latent_dim,
459
+ 'gene_dim': self.gene_dim,
460
+ 'hidden_dim': self.hidden_dim,
461
+ 'device': str(self.device)
462
+ }
481
463
  '''
482
- # Example usage with memory monitoring
464
+ # Example usage
483
465
  def example_usage():
484
- """Memory-efficient example"""
466
+ """Example demonstration with dimension safety"""
485
467
 
486
- # 1. Initialize memory-efficient decoder
487
- decoder = TranscriptomeDecoder(
468
+ # 1. Initialize decoder
469
+ decoder = SimpleTranscriptomeDecoder(
488
470
  latent_dim=100,
489
471
  gene_dim=2000, # Reduced for example
490
- hidden_dim=256 # Reduced for memory
472
+ hidden_dim=256
491
473
  )
492
474
 
493
- # Check memory info
494
- memory_info = decoder.get_memory_info()
495
- print(f"📊 Memory Info: {memory_info}")
496
-
497
- # 2. Generate example data
498
- n_samples = 500 # Reduced for memory
475
+ # 2. Generate example data with correct dimensions
476
+ n_samples = 1000
499
477
  latent_data = np.random.randn(n_samples, 100).astype(np.float32)
500
- expression_data = np.random.randn(n_samples, 2000).astype(np.float32)
478
+
479
+ # Create simulated expression data
480
+ weights = np.random.randn(100, 2000) * 0.1
481
+ expression_data = np.tanh(latent_data.dot(weights))
501
482
  expression_data = np.maximum(expression_data, 0) # Non-negative
502
483
 
503
- print(f"📈 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")
484
+ print(f"📊 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")
504
485
 
505
- # 3. Train with memory monitoring
486
+ # 3. Train the model
506
487
  history = decoder.train(
507
488
  train_latent=latent_data,
508
489
  train_expression=expression_data,
509
- batch_size=8, # Small batch for memory
510
- num_epochs=20 # Reduced for example
490
+ batch_size=32,
491
+ num_epochs=50,
492
+ learning_rate=1e-4
511
493
  )
512
494
 
513
- # 4. Memory-efficient prediction
514
- test_latent = np.random.randn(5, 100).astype(np.float32)
515
- predictions = decoder.predict(test_latent, batch_size=2)
495
+ # 4. Make predictions
496
+ test_latent = np.random.randn(10, 100).astype(np.float32)
497
+ predictions = decoder.predict(test_latent)
516
498
  print(f"🔮 Prediction shape: {predictions.shape}")
517
499
 
518
- # 5. Final memory check
519
- final_memory = decoder.get_memory_info()
520
- print(f"💾 Final memory usage: {final_memory}")
500
+ # 5. Get model info
501
+ info = decoder.get_model_info()
502
+ print(f"\n📋 Model Info:")
503
+ for key, value in info.items():
504
+ print(f" {key}: {value}")
521
505
 
522
506
  return decoder
523
507