SURE-tools 2.4.5__py3-none-any.whl → 2.4.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of SURE-tools might be problematic.

@@ -0,0 +1,607 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from typing import Dict, List, Optional, Tuple
import math
import warnings
warnings.filterwarnings('ignore')

class EfficientTranscriptomeDecoder:
    """
    High-performance, memory-efficient transcriptome decoder.
    Combines several recent architectural techniques (SwiGLU, RMSNorm,
    mixture-of-experts, bottleneck layers) in a single decoder.
    """

    def __init__(self,
                 latent_dim: int = 100,
                 gene_dim: int = 60000,
                 hidden_dims: List[int] = [512, 1024, 2048],
                 bottleneck_dim: int = 256,
                 num_experts: int = 8,
                 dropout_rate: float = 0.1,
                 device: str = None):
        """
        Advanced decoder combining multiple state-of-the-art techniques.

        Args:
            latent_dim: Latent variable dimension
            gene_dim: Number of genes (full transcriptome)
            hidden_dims: Hidden layer dimensions
            bottleneck_dim: Bottleneck dimension for memory efficiency
            num_experts: Number of mixture-of-experts
            dropout_rate: Dropout rate
            device: Computation device
        """
        self.latent_dim = latent_dim
        self.gene_dim = gene_dim
        self.hidden_dims = hidden_dims
        self.bottleneck_dim = bottleneck_dim
        self.num_experts = num_experts
        self.dropout_rate = dropout_rate
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

        # Initialize model with advanced architecture
        self.model = self._build_advanced_model()
        self.model.to(self.device)

        # Training state
        self.is_trained = False
        self.training_history = None
        self.best_val_loss = float('inf')

        print(f"🚀 EfficientTranscriptomeDecoder Initialized:")
        print(f" - Latent Dimension: {latent_dim}")
        print(f" - Gene Dimension: {gene_dim}")
        print(f" - Hidden Dimensions: {hidden_dims}")
        print(f" - Bottleneck Dimension: {bottleneck_dim}")
        print(f" - Number of Experts: {num_experts}")
        print(f" - Estimated GPU Memory: ~6-8GB")
        print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")

    class SwiGLU(nn.Module):
        """SwiGLU gated activation, as used in PaLM and LLaMA feed-forward layers"""
        def forward(self, x):
            x, gate = x.chunk(2, dim=-1)
            return x * F.silu(gate)

    class RMSNorm(nn.Module):
        """RMS normalization (as in T5/LLaMA); like LayerNorm but without mean-centering"""
        def __init__(self, dim: int, eps: float = 1e-8):
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))

        def forward(self, x):
            norm_x = x.norm(2, dim=-1, keepdim=True)
            rms_x = norm_x * (x.shape[-1] ** -0.5)
            return x / (rms_x + self.eps) * self.weight
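            # i.e. x / (RMS(x) + eps) * weight with RMS(x) = ||x||_2 / sqrt(d) = sqrt(mean_i x_i**2);
            # unlike LayerNorm there is no mean-subtraction and no bias term.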

    class MixtureOfExperts(nn.Module):
        """Mixture of Experts for conditional computation"""
        def __init__(self, input_dim: int, expert_dim: int, num_experts: int):
            super().__init__()
            self.num_experts = num_experts
            self.experts = nn.ModuleList([
                nn.Sequential(
                    nn.Linear(input_dim, expert_dim),
                    nn.Dropout(0.1),
                    nn.Linear(expert_dim, input_dim)
                ) for _ in range(num_experts)
            ])
            self.gate = nn.Linear(input_dim, num_experts)
            self.expert_dim = expert_dim

        def forward(self, x):
            # Gate network
            gate_logits = self.gate(x)
            gate_weights = F.softmax(gate_logits, dim=-1)  # [batch, num_experts]

            # Expert outputs
            expert_outputs = []
            for expert in self.experts:
                expert_out = expert(x)
                expert_outputs.append(expert_out.unsqueeze(-1))

            # Combine expert outputs: weighted sum over the expert axis
            expert_outputs = torch.cat(expert_outputs, dim=-1)  # [batch, input_dim, num_experts]
            output = torch.einsum('bk,bdk->bd', gate_weights, expert_outputs)

            return output + x  # Residual connection
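            # Shape walk-through (B = batch, D = input_dim, K = num_experts):
            #   gate_weights: [B, K]; each expert maps [B, D] -> [B, D];
            #   stacked expert_outputs: [B, D, K]; the einsum forms
            #   output[b, d] = sum_k gate_weights[b, k] * expert_outputs[b, d, k],
            #   so every expert is evaluated and softly weighted (no top-k routing here).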

    class AdaptiveBottleneck(nn.Module):
        """Adaptive bottleneck for memory efficiency"""
        def __init__(self, input_dim: int, bottleneck_dim: int, output_dim: int):
            super().__init__()
            self.compress = nn.Linear(input_dim, bottleneck_dim)
            self.norm1 = EfficientTranscriptomeDecoder.RMSNorm(bottleneck_dim)
            self.expand = nn.Linear(bottleneck_dim, output_dim)
            self.norm2 = EfficientTranscriptomeDecoder.RMSNorm(output_dim)
            self.dropout = nn.Dropout(0.1)

        def forward(self, x):
            # Compress
            compressed = self.compress(x)
            compressed = self.norm1(compressed)
            compressed = F.silu(compressed)
            compressed = self.dropout(compressed)

            # Expand
            expanded = self.expand(compressed)
            expanded = self.norm2(expanded)

            return expanded

    class GeneSpecificProjection(nn.Module):
        """Gene-specific projection with weight sharing"""
        def __init__(self, latent_dim: int, gene_dim: int, proj_dim: int = 64):
            super().__init__()
            self.proj_dim = proj_dim
            self.gene_embeddings = nn.Parameter(torch.randn(gene_dim, proj_dim) * 0.02)
            self.latent_projection = nn.Linear(latent_dim, proj_dim)
            self.output_layer = nn.Linear(proj_dim, 1)

        def forward(self, latent):
            batch_size = latent.shape[0]

            # Project latent into the shared gene-embedding space
            latent_proj = self.latent_projection(latent)  # [batch_size, proj_dim]

            # One dot product per gene against its learned embedding
            gene_output = torch.matmul(latent_proj, self.gene_embeddings.T)  # [batch_size, gene_dim]

            return gene_output
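            # Size note: the only per-gene parameters are the proj_dim-wide embeddings,
            # e.g. 60,000 genes x 128 dims ≈ 7.7M parameters (~31 MB in fp32), versus
            # ~123M parameters for a dense nn.Linear(2048, 60000) output head.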

    class AdvancedDecoder(nn.Module):
        """Advanced decoder combining multiple techniques"""

        def __init__(self, latent_dim: int, gene_dim: int, hidden_dims: List[int],
                     bottleneck_dim: int, num_experts: int, dropout_rate: float):
            super().__init__()

            # Initial projection
            self.input_projection = nn.Sequential(
                nn.Linear(latent_dim, hidden_dims[0]),
                EfficientTranscriptomeDecoder.RMSNorm(hidden_dims[0]),
                nn.SiLU(),
                nn.Dropout(dropout_rate)
            )

            # Main processing blocks
            self.blocks = nn.ModuleList()
            current_dim = hidden_dims[0]

            for i, hidden_dim in enumerate(hidden_dims[1:], 1):
                block = nn.ModuleList([
                    # Mixture of Experts (keeps the current width)
                    EfficientTranscriptomeDecoder.MixtureOfExperts(current_dim, hidden_dim, num_experts),

                    # Adaptive Bottleneck (maps current_dim -> hidden_dim)
                    EfficientTranscriptomeDecoder.AdaptiveBottleneck(current_dim, bottleneck_dim, hidden_dim),

                    # SwiGLU feed-forward (keeps hidden_dim)
                    nn.Sequential(
                        nn.Linear(hidden_dim, hidden_dim * 2),
                        EfficientTranscriptomeDecoder.SwiGLU(),
                        nn.Dropout(dropout_rate)
                    )
                ])
                self.blocks.append(block)
                current_dim = hidden_dim

            # Gene-specific projection
            self.gene_projection = EfficientTranscriptomeDecoder.GeneSpecificProjection(
                current_dim, gene_dim, proj_dim=128
            )

            # Output scaling
            self.output_scale = nn.Parameter(torch.ones(1))
            self.output_bias = nn.Parameter(torch.zeros(1))

            self._init_weights()

        def _init_weights(self):
            """Advanced weight initialization"""
            for module in self.modules():
                if isinstance(module, nn.Linear):
                    # Kaiming init for SiLU/SwiGLU
                    nn.init.kaiming_normal_(module.weight, nonlinearity='linear')
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)

        def forward(self, x):
            # Initial projection
            x = self.input_projection(x)

            # Process through blocks
            for block in self.blocks:
                # Mixture of Experts
                expert_out = block[0](x)

                # Adaptive Bottleneck (widens the features to this block's hidden_dim)
                bottleneck_out = block[1](expert_out)

                # SwiGLU feed-forward with residual; the residual is taken from the
                # bottleneck output so shapes match even when the block changes width
                swiglu_out = block[2](bottleneck_out)
                x = bottleneck_out + swiglu_out

            # Final gene projection
            output = self.gene_projection(x)

            # Ensure non-negative output
            output = F.softplus(output * self.output_scale + self.output_bias)

            return output

    def _build_advanced_model(self):
        """Build the advanced decoder model"""
        return self.AdvancedDecoder(
            self.latent_dim, self.gene_dim, self.hidden_dims,
            self.bottleneck_dim, self.num_experts, self.dropout_rate
        )

    def train(self,
              train_latent: np.ndarray,
              train_expression: np.ndarray,
              val_latent: np.ndarray = None,
              val_expression: np.ndarray = None,
              batch_size: int = 16,  # Smaller batches for memory efficiency
              num_epochs: int = 200,
              learning_rate: float = 1e-4,
              checkpoint_path: str = 'efficient_decoder.pth') -> Dict:
        """
        Train with advanced optimization techniques.

        Args:
            train_latent: Training latent variables
            train_expression: Training expression data
            val_latent: Validation latent variables
            val_expression: Validation expression data
            batch_size: Batch size (optimized for memory)
            num_epochs: Number of epochs
            learning_rate: Learning rate
            checkpoint_path: Model save path
        """
        print("🚀 Starting Advanced Training...")

        # Data preparation
        train_dataset = self._create_dataset(train_latent, train_expression)

        if val_latent is not None and val_expression is not None:
            val_dataset = self._create_dataset(val_latent, val_expression)
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
        else:
            train_size = int(0.9 * len(train_dataset))
            val_size = len(train_dataset) - train_size
            train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
            train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, pin_memory=True)
            val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, pin_memory=True)

        print(f"📊 Training samples: {len(train_loader.dataset)}")
        print(f"📊 Validation samples: {len(val_loader.dataset)}")
        print(f"📊 Batch size: {batch_size}")

        # Advanced optimizer configuration
        optimizer = optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=0.1,   # Stronger regularization
            betas=(0.9, 0.95),  # Tuned betas
            eps=1e-8
        )

        # One-cycle schedule (warmup then cosine annealing). The scheduler is stepped
        # once per optimizer update, i.e. once every 4 accumulated batches
        # (see _train_epoch_advanced), so steps_per_epoch is divided accordingly.
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=learning_rate * 5,
            epochs=num_epochs,
            steps_per_epoch=max(1, math.ceil(len(train_loader) / 4)),
            pct_start=0.1,
            div_factor=10.0,
            final_div_factor=100.0
        )
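        # With the defaults, the learning rate starts at max_lr / div_factor = 5e-5,
        # peaks at 5 * 1e-4 = 5e-4 after the first 10% of updates (pct_start=0.1),
        # then anneals down to 5e-5 / final_div_factor = 5e-7 by the end of training.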

        # Advanced loss function
        def advanced_loss(pred, target):
            # 1. MSE loss for overall accuracy
            mse_loss = F.mse_loss(pred, target)

            # 2. Poisson loss for count data
            poisson_loss = (pred - target * torch.log(pred + 1e-8)).mean()

            # 3. Correlation loss for pattern matching
            correlation_loss = 1 - self._pearson_correlation(pred, target)

            # 4. Sparsity loss for realistic distribution
            sparsity_loss = F.mse_loss(
                (pred < 1e-3).float().mean(),
                torch.tensor(0.85, device=pred.device)  # Target sparsity
            )

            # 5. Spectral loss for smoothness
            spectral_loss = self._spectral_loss(pred, target)

            # Weighted combination
            total_loss = (mse_loss + 0.3 * poisson_loss + 0.2 * correlation_loss +
                          0.1 * sparsity_loss + 0.05 * spectral_loss)

            return total_loss, {
                'mse': mse_loss.item(),
                'poisson': poisson_loss.item(),
                'correlation': correlation_loss.item(),
                'sparsity': sparsity_loss.item(),
                'spectral': spectral_loss.item()
            }
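            # Notes: pred - target * log(pred) is the Poisson negative log-likelihood up to
            # the constant log(target!) term; the correlation term is 1 minus the mean
            # per-sample Pearson r; the sparsity term pushes the fraction of near-zero
            # predictions (< 1e-3) toward the hard-coded target of 0.85.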

        # Training history
        history = {
            'train_loss': [], 'val_loss': [],
            'train_mse': [], 'val_mse': [],
            'train_correlation': [], 'val_correlation': [],
            'learning_rates': [], 'grad_norms': []
        }

        best_val_loss = float('inf')
        patience = 25
        patience_counter = 0

        print("\n📈 Starting training with advanced techniques...")
        for epoch in range(1, num_epochs + 1):
            # Training phase
            train_loss, train_components, grad_norm = self._train_epoch_advanced(
                train_loader, optimizer, scheduler, advanced_loss
            )

            # Validation phase
            val_loss, val_components = self._validate_epoch_advanced(val_loader, advanced_loss)

            # Record history
            history['train_loss'].append(train_loss)
            history['val_loss'].append(val_loss)
            history['train_mse'].append(train_components['mse'])
            history['val_mse'].append(val_components['mse'])
            history['train_correlation'].append(train_components['correlation'])
            history['val_correlation'].append(val_components['correlation'])
            history['learning_rates'].append(optimizer.param_groups[0]['lr'])
            history['grad_norms'].append(grad_norm)

            # Print detailed progress
            if epoch % 10 == 0 or epoch == 1:
                lr = optimizer.param_groups[0]['lr']
                print(f"📍 Epoch {epoch:3d}/{num_epochs} | "
                      f"Train: {train_loss:.4f} | "
                      f"Val: {val_loss:.4f} | "
                      f"Corr: {val_components['correlation']:.4f} | "
                      f"LR: {lr:.2e} | "
                      f"Grad: {grad_norm:.4f}")

            # Early stopping with patience
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                self._save_checkpoint(epoch, optimizer, scheduler, best_val_loss, history, checkpoint_path)
                if epoch % 20 == 0:
                    print(f"💾 Best model saved (Val Loss: {best_val_loss:.4f})")
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"🛑 Early stopping at epoch {epoch}")
                    break

        # Training completed
        self.is_trained = True
        self.training_history = history
        self.best_val_loss = best_val_loss

        print(f"\n🎉 Training completed!")
        print(f"🏆 Best validation loss: {best_val_loss:.4f}")

        return history

    def _create_dataset(self, latent_data, expression_data):
        """Create memory-efficient dataset"""
        class EfficientDataset(Dataset):
            def __init__(self, latent, expression):
                self.latent = torch.FloatTensor(latent)
                self.expression = torch.FloatTensor(expression)

            def __len__(self):
                return len(self.latent)

            def __getitem__(self, idx):
                return self.latent[idx], self.expression[idx]

        return EfficientDataset(latent_data, expression_data)

    def _pearson_correlation(self, pred, target):
        """Calculate Pearson correlation"""
        pred_centered = pred - pred.mean(dim=1, keepdim=True)
        target_centered = target - target.mean(dim=1, keepdim=True)

        numerator = (pred_centered * target_centered).sum(dim=1)
        denominator = torch.sqrt(torch.sum(pred_centered ** 2, dim=1)) * torch.sqrt(torch.sum(target_centered ** 2, dim=1))

        return (numerator / (denominator + 1e-8)).mean()
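        # Computed per sample across the gene axis (dim=1): r = cov(pred, target) /
        # (std(pred) * std(target)), then averaged over the batch; the 1e-8 guards
        # against division by zero for all-constant rows.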

    def _spectral_loss(self, pred, target):
        """Spectral loss for frequency domain matching"""
        pred_fft = torch.fft.fft(pred, dim=1)
        target_fft = torch.fft.fft(target, dim=1)

        magnitude_loss = F.mse_loss(torch.abs(pred_fft), torch.abs(target_fft))
        phase_loss = F.mse_loss(torch.angle(pred_fft), torch.angle(target_fft))

        return magnitude_loss + 0.5 * phase_loss
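        # The FFT is taken along the gene axis, so this compares expression profiles in
        # the frequency domain; the phase MSE ignores the ±π wrap-around of torch.angle,
        # so it acts only as a rough smoothness regularizer.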

    def _train_epoch_advanced(self, train_loader, optimizer, scheduler, loss_fn):
        """Advanced training with gradient accumulation"""
        self.model.train()
        total_loss = 0
        total_components = {'mse': 0, 'poisson': 0, 'correlation': 0, 'sparsity': 0, 'spectral': 0}
        grad_norms = []

        # Gradient accumulation for effective larger batch size
        accumulation_steps = 4
        optimizer.zero_grad()

        for i, (latent, target) in enumerate(train_loader):
            latent = latent.to(self.device, non_blocking=True)
            target = target.to(self.device, non_blocking=True)

            # Forward pass with mixed precision
            with torch.cuda.amp.autocast():  # Mixed precision for memory efficiency
                pred = self.model(latent)
                loss, components = loss_fn(pred, target)

            # Scale loss for gradient accumulation
            loss = loss / accumulation_steps
            loss.backward()

            # Gradient accumulation
            if (i + 1) % accumulation_steps == 0:
                # Gradient clipping
                grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

                grad_norms.append(grad_norm.item())

            # Accumulate losses
            total_loss += loss.item() * accumulation_steps
            for key in total_components:
                total_components[key] += components[key]

        # Average metrics
        num_batches = len(train_loader)
        avg_loss = total_loss / num_batches
        avg_components = {key: value / num_batches for key, value in total_components.items()}
        avg_grad_norm = np.mean(grad_norms) if grad_norms else 0.0

        return avg_loss, avg_components, avg_grad_norm
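        # Effective batch size is batch_size * accumulation_steps (16 * 4 = 64 with the
        # defaults). Two caveats of this implementation: autocast is used without a
        # GradScaler, so fp16 gradients are not loss-scaled, and gradients from trailing
        # batches that do not fill a full accumulation window are zeroed at the start of
        # the next epoch rather than applied.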

    def _validate_epoch_advanced(self, val_loader, loss_fn):
        """Advanced validation"""
        self.model.eval()
        total_loss = 0
        total_components = {'mse': 0, 'poisson': 0, 'correlation': 0, 'sparsity': 0, 'spectral': 0}

        with torch.no_grad():
            for latent, target in val_loader:
                latent = latent.to(self.device, non_blocking=True)
                target = target.to(self.device, non_blocking=True)

                pred = self.model(latent)
                loss, components = loss_fn(pred, target)

                total_loss += loss.item()
                for key in total_components:
                    total_components[key] += components[key]

        num_batches = len(val_loader)
        avg_loss = total_loss / num_batches
        avg_components = {key: value / num_batches for key, value in total_components.items()}

        return avg_loss, avg_components

    def _save_checkpoint(self, epoch, optimizer, scheduler, best_loss, history, path):
        """Save checkpoint"""
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_loss': best_loss,
            'training_history': history,
            'model_config': {
                'latent_dim': self.latent_dim,
                'gene_dim': self.gene_dim,
                'hidden_dims': self.hidden_dims,
                'bottleneck_dim': self.bottleneck_dim,
                'num_experts': self.num_experts
            }
        }, path)

    def predict(self, latent_data: np.ndarray, batch_size: int = 16) -> np.ndarray:
        """Memory-efficient prediction"""
        if not self.is_trained:
            warnings.warn("⚠️ Model not trained. Predictions may be inaccurate.")

        self.model.eval()

        if isinstance(latent_data, np.ndarray):
            latent_data = torch.FloatTensor(latent_data)

        predictions = []
        with torch.no_grad():
            for i in range(0, len(latent_data), batch_size):
                batch_latent = latent_data[i:i+batch_size].to(self.device)

                with torch.cuda.amp.autocast():  # Mixed precision for memory
                    batch_pred = self.model(batch_latent)

                predictions.append(batch_pred.cpu())

                # Clear memory
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        return torch.cat(predictions).numpy()

    def load_model(self, model_path: str):
        """Load pre-trained model"""
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.is_trained = True
        self.training_history = checkpoint.get('training_history')
        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
        print(f"✅ Model loaded! Best val loss: {self.best_val_loss:.4f}")

'''
# Example usage
def example_usage():
    """Demonstrate the advanced decoder"""

    # Initialize decoder
    decoder = EfficientTranscriptomeDecoder(
        latent_dim=100,
        gene_dim=2000,  # Reduced for example
        hidden_dims=[256, 512, 1024],
        bottleneck_dim=128,
        num_experts=4,
        dropout_rate=0.1
    )

    # Generate example data
    n_samples = 1000
    latent_data = np.random.randn(n_samples, 100).astype(np.float32)

    # Simulate expression data
    weights = np.random.randn(100, 2000) * 0.1
    expression_data = np.tanh(latent_data.dot(weights))
    expression_data = np.maximum(expression_data, 0)

    print(f"📊 Data shapes: Latent {latent_data.shape}, Expression {expression_data.shape}")

    # Train
    history = decoder.train(
        train_latent=latent_data,
        train_expression=expression_data,
        batch_size=16,
        num_epochs=50
    )

    # Predict
    test_latent = np.random.randn(10, 100).astype(np.float32)
    predictions = decoder.predict(test_latent)
    print(f"🔮 Prediction shape: {predictions.shape}")

    return decoder

if __name__ == "__main__":
    example_usage()

'''