ebm-splats 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ebm/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ """
2
+ EBM-Splats: Energy-Based Model with Gaussian Splats on a Hypersphere.
3
+
4
+ Archived research project (2026). See README.md for details.
5
+ """
6
+
# Package metadata: release version and lifecycle status of this archived
# research package.
__version__: str = "2.0.0"
__status__: str = "archived"
ebm/config.py ADDED
@@ -0,0 +1,109 @@
1
+ """
2
+ EBM Configuration — V2 Specification Compliant.
3
+
4
+ Fixes applied:
5
+ - Removed duplicate init_alpha field
6
+ - Added V2 spec parameters (lambda context weights, adaptive noise, etc.)
7
+ - Explicit sign convention documentation
8
+ - Proper default values from V2 Especificacion Tecnica
9
+ """
10
+
11
+ from dataclasses import dataclass, field
12
+ from typing import Optional, Tuple
13
+
14
+
@dataclass
class EBMConfig:
    """
    Configuration for the EBM continuous language model.

    Sign Convention (V2 Spec §2.1):
        p(x) = exp(-E(x)) / Z
        score(x) = ∇_x log p(x) = -∇_x E(x)
        Low energy = high probability (near splat centers)

    ``__post_init__`` checks a few invariants that the model code depends on,
    so an inconsistent configuration fails at construction time rather than
    deep inside training.
    """

    # ── Environment ──
    device: str = "cpu"

    # ── Latent Space ──
    latent_dim: int = 640
    n_splats_init: int = 10000
    max_splats: int = 100000
    knn_k: int = 64  # neighbours queried per point for the splat energy

    # ── Splat Initialization (Phase 1 improvements) ──
    vocab_embedding_path: Optional[str] = None
    init_from_vocab_embeddings: bool = False
    init_alpha: float = 1.0
    init_kappa: float = 10.0

    # ── Splat Regularization ──
    splat_temperature: float = 0.1
    splat_weight_decay: float = 0.0
    splat_weight_decay_start: float = 1.0
    min_kappa: float = 1.0
    max_kappa: float = 50.0

    # ── Training ──
    learning_rate: float = 1e-3
    reg_weight: float = 0.01
    grad_clip: float = 1.0
    temperature: float = 0.1  # divisor of the splat-energy exponent

    # ── Curriculum Learning ──
    enable_curriculum_learning: bool = True
    curriculum_epochs: int = 5
    curriculum_target_splats: int = 50000
    splat_convergence_threshold: float = 0.95

    # ── Noise Schedule (V2 §3.3 Adaptive Noise) ──
    noise_levels: Tuple[float, ...] = (0.01, 0.05, 0.1, 0.2, 0.5)
    sigma_base: float = 0.1
    sigma_scale: float = 1.0
    sigma_rho_ref: float = 1.0

    # ── Langevin Dynamics ──
    langevin_steps: int = 200
    langevin_dt: float = 0.001
    langevin_gamma: float = 0.1
    langevin_T: float = 1.0

    # ── SOC (Self-Organized Criticality) ──
    soc_threshold: float = 0.8
    soc_check_interval: int = 100
    min_splat_distance: float = 0.1

    # ── Hierarchical Context (V2 §4) ──
    context_local_window: int = 16  # 8-16 tokens
    context_medium_window: int = 128  # 64-128 tokens
    context_global_window: int = 512  # 512+ tokens
    beta_local: float = 0.5  # Fast adaptation (EMA weight of the old context)
    beta_medium: float = 0.8  # Moderate
    beta_global: float = 0.95  # Slow, stable
    lambda_context_local: float = 1.0
    lambda_context_medium: float = 0.5
    lambda_context_global: float = 0.2

    # ── Decoder (V2 §5) ──
    vocab_size: int = 50257  # GPT-2 vocab
    moe_experts: int = 4  # total experts in the MoE layer
    moe_active: int = 2  # experts combined per input (top-k)
    hidden_dim: int = 1024

    # ── Collapse Regularization (V2 §2.4) ──
    lambda_reg: float = 0.01
    theta_threshold: float = 0.9  # Min angular separation between splat centers

    # ── Energy Weights ──
    lambda_geom: float = 0.01
    lambda_comp: float = 0.05

    # ── EMA ──
    ema_decay: float = 0.999

    # ── Error Recovery (V2 §7) ──
    energy_stagnation_window: int = 5
    energy_stagnation_epsilon: float = 1e-4
    gradient_magnitude_threshold: float = 1e-5
    max_recovery_attempts: int = 3

    def __post_init__(self) -> None:
        """
        Validate invariants that the model code relies on.

        Raises:
            ValueError: if
                - ``latent_dim`` or ``knn_k`` is < 1;
                - ``temperature`` is <= 0 (the splat energy divides by it);
                - ``moe_active`` is outside [1, moe_experts] (top-k expert
                  routing cannot select more experts than exist, and k=0
                  would normalise by a zero sum);
                - any of ``beta_local`` / ``beta_medium`` / ``beta_global``
                  is outside [0, 1] (they are EMA blending weights).
        """
        if self.latent_dim < 1:
            raise ValueError(f"latent_dim must be >= 1, got {self.latent_dim}")
        if self.knn_k < 1:
            raise ValueError(f"knn_k must be >= 1, got {self.knn_k}")
        if self.temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {self.temperature}")
        if not 1 <= self.moe_active <= self.moe_experts:
            raise ValueError(
                f"moe_active must be in [1, moe_experts={self.moe_experts}], "
                f"got {self.moe_active}"
            )
        for name in ("beta_local", "beta_medium", "beta_global"):
            value = getattr(self, name)
            if not 0.0 <= value <= 1.0:
                raise ValueError(f"{name} must be in [0, 1], got {value}")
ebm/context.py ADDED
@@ -0,0 +1,117 @@
1
+ """
2
+ Hierarchical Context System — V2 §4.
3
+
4
+ Three-level context for long-range dependencies:
5
+ - Local: last 8-16 tokens (β≈0.5, fast adaptation)
6
+ - Medium: last 64-128 tokens (β≈0.8, moderate)
7
+ - Global: full sequence (β≈0.95, slow, stable)
8
+
9
+ Integration:
10
+ E_trans(x_t) = -Σ_{l∈{local,medium,global}} λ_l · (x_t · c_l)
11
+ """
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ from typing import Dict
16
+
17
+ from .config import EBMConfig
18
+ from .geometry import normalize_sphere
19
+
20
+
class HierarchicalContext(nn.Module):
    """
    Three exponentially-smoothed context vectors tracking a stream of token
    embeddings at different time scales.

    - ``c_local``  is blended with every incoming embedding (``beta_local``).
    - ``c_medium`` is blended every 4th step with the mean of the most recent
      batch-averaged embeddings (history capped at ``context_local_window``).
    - ``c_global`` is blended every 16th step with the mean of the most recent
      batch-averaged medium contexts (history capped at
      ``context_medium_window``).

    Each blend is ``beta * old + (1 - beta) * new`` followed by
    ``normalize_sphere``.
    """

    _CONTEXT_NAMES = ("c_local", "c_medium", "c_global")

    def __init__(self, config: EBMConfig):
        super().__init__()
        self.config = config
        dim = config.latent_dim
        self.latent_dim = dim

        # Context vectors start as [1, D] zeros; reset()/update() replace them
        # with per-batch random vectors.
        for name in self._CONTEXT_NAMES:
            self.register_buffer(name, torch.zeros(1, dim))

        self.beta_local, self.beta_medium, self.beta_global = (
            config.beta_local,
            config.beta_medium,
            config.beta_global,
        )

        # Step counters, registered as buffers (so they are part of state_dict).
        for name in ("_step", "_local_count", "_medium_count"):
            self.register_buffer(name, torch.tensor(0, dtype=torch.long))

        # Histories are plain Python lists of CPU tensors, not buffers.
        self._local_buffer = []
        self._medium_buffer = []

    def _random_contexts(self, batch_size: int, device) -> None:
        """Replace local, medium and global contexts (in that order) with
        normalized random vectors of shape [batch_size, latent_dim]."""
        for name in self._CONTEXT_NAMES:
            fresh = torch.randn(batch_size, self.latent_dim, device=device)
            setattr(self, name, normalize_sphere(fresh))

    @staticmethod
    def _blend(beta: float, old: torch.Tensor, new: torch.Tensor) -> torch.Tensor:
        """EMA step ``beta * old + (1 - beta) * new`` re-projected by normalize_sphere."""
        return normalize_sphere(beta * old + (1 - beta) * new)

    def reset(self, batch_size: int, device: torch.device):
        """Start a new sequence: random contexts, zeroed counters, empty histories."""
        self._random_contexts(batch_size, device)
        for counter in (self._step, self._local_count, self._medium_count):
            counter.zero_()
        self._local_buffer = []
        self._medium_buffer = []

    @torch.no_grad()
    def update(self, x: torch.Tensor):
        """Fold one step of token embeddings ``x`` ([B, D]) into the contexts."""
        batch, dev = x.size(0), x.device
        step = self._step.item()
        cfg = self.config

        # Re-seed every context when the batch size or device has changed.
        if self.c_local.size(0) != batch or self.c_local.device != dev:
            self._random_contexts(batch, dev)

        # Local context: updated on every call.
        self.c_local = self._blend(self.beta_local, self.c_local.to(dev), x)

        history = self._local_buffer
        history.append(x.mean(dim=0).detach().cpu())
        if len(history) > cfg.context_local_window:
            del history[0]
        self._local_count.add_(1)

        # Medium context: every 4th step, never at step 0.
        if step > 0 and step % 4 == 0 and self._local_buffer:
            recent = torch.stack(self._local_buffer).mean(dim=0).to(dev)
            target = recent.unsqueeze(0).expand(batch, -1)
            self.c_medium = self._blend(self.beta_medium, self.c_medium.to(dev), target)

            medium_history = self._medium_buffer
            medium_history.append(self.c_medium.mean(dim=0).detach().cpu())
            if len(medium_history) > cfg.context_medium_window:
                del medium_history[0]
            self._medium_count.add_(1)

        # Global context: every 16th step, never at step 0.
        if step > 0 and step % 16 == 0 and self._medium_buffer:
            recent = torch.stack(self._medium_buffer).mean(dim=0).to(dev)
            target = recent.unsqueeze(0).expand(batch, -1)
            self.c_global = self._blend(self.beta_global, self.c_global.to(dev), target)

        self._step.add_(1)

    def get_context(self) -> Dict[str, torch.Tensor]:
        """Return the three context vectors keyed by level name."""
        return {level: getattr(self, f"c_{level}") for level in ("local", "medium", "global")}

    def get_combined_context(self) -> torch.Tensor:
        """Return the lambda-weighted sum of the three contexts, passed through normalize_sphere."""
        cfg = self.config
        weighted = (
            cfg.lambda_context_local * self.c_local
            + cfg.lambda_context_medium * self.c_medium
            + cfg.lambda_context_global * self.c_global
        )
        return normalize_sphere(weighted)
ebm/cuda/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """CUDA-accelerated energy computation (requires PyTorch with CUDA)."""
ebm/cuda/energy.py ADDED
@@ -0,0 +1,102 @@
1
+ """
2
+ CUDA-native Energy Function — V2 Compliant.
3
+
4
+ Same as energy.py but uses CUDA AMP-safe autograd instead of manual gradients.
5
+ The V2 manual gradient approach in energy.py is preferred for training stability.
6
+ This module serves as CUDA fallback when V2 manual gradients aren't needed.
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from ..config import EBMConfig
14
+ from ..geometry import project_to_tangent
15
+
16
+
class EnergyFunctionCUDA(nn.Module):
    """
    CUDA-native energy with V2 sign conventions (low energy = high probability).

    Total energy:
        E(x) = E_splat(x) + lambda_geom * E_geom(x) + lambda_comp * E_comp(x)
               [+ E_trans(x) when context vectors are supplied]

    Scores are obtained with autograd (``compute_score``) instead of the manual
    gradients used by ``energy.py``.
    """

    def __init__(self, config: EBMConfig, splat_store):
        super().__init__()
        self.config = config
        # NOTE(review): expected to provide find_neighbors(x, k) returning
        # (mu [N, k, D], alpha [N, k], kappa [N, k]); shapes inferred from the
        # bmm/reductions below — confirm against the splat store.
        self.splats = splat_store
        # Scores the (u, v, u*v) features of the two nearest splats.
        self.W_comp = nn.Linear(3, 1)

    def compute_splat_energy(self, x: torch.Tensor) -> torch.Tensor:
        """
        Splat energy: negative log of a kappa-weighted mixture over the
        ``knn_k`` nearest splats,

            E_splat(x) = -logsumexp_i( alpha_i * (mu_i·x - 1) / T + log w_i ),
            w_i = kappa_i / sum_j kappa_j.

        Args:
            x: points of shape [..., D]; inputs with more than two dims are
               flattened for the neighbour search.
        Returns:
            Energy of shape x.shape[:-1].
        """
        original_shape = x.shape[:-1]
        if x.dim() > 2:
            x = x.reshape(-1, x.size(-1))

        neighbors_mu, neighbors_alpha, neighbors_kappa = self.splats.find_neighbors(x, self.config.knn_k)

        # [N, k, D] @ [N, D, 1] -> [N, k] dot products with each neighbour center.
        dot_products = torch.bmm(neighbors_mu, x.unsqueeze(-1)).squeeze(-1)

        # Importance weighting (V2 §2.2): normalized kappa.
        importance = neighbors_kappa.clamp(min=1e-4)
        weights = importance / importance.sum(dim=-1, keepdim=True)

        # For unit-norm mu and x, (mu·x - 1) <= 0 and peaks at the splat center.
        exponent = neighbors_alpha * (dot_products - 1.0) / self.config.temperature
        weighted_exponent = exponent + torch.log(weights.clamp(min=1e-8))

        energy = -torch.logsumexp(weighted_exponent, dim=-1)
        return energy.view(original_shape)

    def compute_geom_energy(self, x: torch.Tensor) -> torch.Tensor:
        """
        Batch spread energy: mean of -log(1 - <x_i, x_j> + 1e-4) over all
        ordered pairs i != j (inputs with more than two dims are flattened).
        Penalizes points collapsing onto each other.

        Returns:
            A scalar tensor (0 when there are fewer than two points).
        """
        if x.dim() > 2:
            x = x.reshape(-1, x.size(-1))

        n = x.size(0)
        if n < 2:
            return torch.tensor(0.0, device=x.device)

        batch_sims = torch.mm(x, x.T)
        off_diag = batch_sims[~torch.eye(n, dtype=torch.bool, device=x.device)]
        # The clamp keeps the log argument >= 2e-4 even for identical points.
        return -torch.log(1.0 - off_diag.clamp(max=1.0 - 1e-4) + 1e-4).mean()

    def compute_comp_energy(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compositional energy from the two nearest splats:
        sigmoid(W_comp([u, v, u*v])), where u and v are x's dot products with
        those splat centers. Returns shape x.shape[:-1], values in (0, 1).
        """
        original_shape = x.shape[:-1]
        if x.dim() > 2:
            x = x.reshape(-1, x.size(-1))

        neighbors_mu, _, _ = self.splats.find_neighbors(x, 2)
        u = torch.bmm(neighbors_mu[:, 0:1, :], x.unsqueeze(-1)).squeeze(-1)
        v = torch.bmm(neighbors_mu[:, 1:2, :], x.unsqueeze(-1)).squeeze(-1)
        uv_concat = torch.cat([u, v, u * v], dim=-1)
        comp_energy = torch.sigmoid(self.W_comp(uv_concat)).squeeze(-1)
        return comp_energy.view(original_shape)

    def compute_context_energy(self, x: torch.Tensor, context_vecs: dict) -> torch.Tensor:
        """
        Transition energy (V2 §4): E_trans(x) = -Σ_l λ_l · (x · c_l).

        λ_l is read from ``config.lambda_context_<level>`` (0.0 when absent).
        A context vector of shape [B, D] is first tried with ordinary
        broadcasting. If that fails for a higher-rank x whose leading dim is
        also B (e.g. x of shape [B, S, D]), singleton axes are inserted, giving
        [B, 1, ..., D], so the batch axes line up.
        """
        energy = torch.tensor(0.0, device=x.device)
        for level, vec in context_vecs.items():
            lam = getattr(self.config, f'lambda_context_{level}', 0.0)
            if vec.dim() >= 2 and x.dim() > vec.dim() and vec.size(0) == x.size(0):
                try:
                    torch.broadcast_shapes(x.shape, vec.shape)
                except (RuntimeError, ValueError):
                    extra = (1,) * (x.dim() - vec.dim())
                    vec = vec.reshape(tuple(vec.shape[:-1]) + extra + (vec.shape[-1],))
            energy = energy - lam * (x * vec).sum(dim=-1)
        return energy

    def forward(self, x: torch.Tensor, context_vecs: dict = None) -> torch.Tensor:
        """
        Total energy E_splat + lambda_geom * E_geom + lambda_comp * E_comp,
        plus E_trans when ``context_vecs`` is given. E_geom is a batch-level
        scalar and is broadcast onto every point's energy.
        """
        e_splat = self.compute_splat_energy(x)
        e_geom = self.compute_geom_energy(x) * self.config.lambda_geom
        e_comp = self.compute_comp_energy(x) * self.config.lambda_comp
        e_total = e_splat + e_geom + e_comp

        if context_vecs is not None:
            e_trans = self.compute_context_energy(x, context_vecs)
            e_total = e_total + e_trans

        return e_total

    def compute_score(self, x: torch.Tensor) -> torch.Tensor:
        """
        Riemannian score: -∇E(x) projected onto the tangent space at x.

        The computation runs in float32 with CUDA autocast disabled. Gradient
        tracking is re-enabled locally around the energy/gradient evaluation,
        so this also works when called inside ``torch.no_grad()`` (e.g. during
        sampling). ``x`` is detached, so no gradient flows back into the
        caller's tensor. Returns a float32 tensor.
        For training stability, prefer the manual-gradient approach in energy.py.
        """
        with torch.amp.autocast('cuda', enabled=False):
            x_f32 = x.detach().float().requires_grad_(True)
            with torch.enable_grad():
                energy = self.forward(x_f32)
                grad_e = torch.autograd.grad(energy.sum(), x_f32, create_graph=False)[0]
            grad_riemann = project_to_tangent(x_f32, grad_e)

        return -grad_riemann
ebm/data.py ADDED
@@ -0,0 +1,67 @@
1
+ import torch
2
+ from torch.utils.data import Dataset, DataLoader
3
+ from datasets import load_dataset
4
+ from transformers import AutoTokenizer
5
+ import logging
6
+
# Module-level logger.
logger = logging.getLogger(__name__)
# NOTE: importing this module also calls logging.basicConfig(level=INFO),
# which configures the root logger if it has no handlers yet.
logging.basicConfig(level=logging.INFO)
9
+
class TextDataset(Dataset):
    """
    Map-style dataset that slices a flat token sequence into consecutive,
    non-overlapping chunks of ``seq_len`` tokens. Trailing tokens that do not
    fill a whole chunk are dropped.
    """

    def __init__(self, tokenized_data, seq_len: int = 16):
        """
        Args:
            tokenized_data: Flat, sliceable sequence of token ids that
                ``torch.tensor`` accepts (e.g. a list of ints).
            seq_len: Number of tokens per chunk.
        """
        self.seq_len = seq_len
        self.tokens = tokenized_data
        self.total_chunks = len(self.tokens) // seq_len

    def __len__(self):
        return self.total_chunks

    def __getitem__(self, idx):
        """
        Return chunk ``idx`` as a LongTensor of shape [seq_len].

        Negative indices count from the end. Out-of-range indices raise
        IndexError, which also lets plain ``for chunk in dataset`` iteration
        terminate: Python's fallback sequence iteration stops only on
        IndexError.
        """
        if idx < 0:
            idx += self.total_chunks
        if not 0 <= idx < self.total_chunks:
            raise IndexError(f"chunk index out of range (dataset has {self.total_chunks} chunks)")
        start = idx * self.seq_len
        return torch.tensor(self.tokens[start:start + self.seq_len], dtype=torch.long)
26
+
def get_dataloader(tokenizer_name: str = "gpt2", dataset_name: str = "wikitext",
                   config_name: str = "wikitext-2-raw-v1", split: str = "train",
                   batch_size: int = 8, seq_len: int = 16, max_samples: int = 5000):
    """
    Download a Hugging Face text dataset, tokenize it into one continuous token
    stream, and wrap it in a shuffling DataLoader of [batch_size, seq_len]
    LongTensor batches.

    Args:
        tokenizer_name: Tokenizer id passed to ``AutoTokenizer.from_pretrained``.
        dataset_name: Dataset id passed to ``load_dataset``.
        config_name: Dataset configuration passed to ``load_dataset``.
        split: Split of the loaded dataset to use.
        batch_size: Batch size; an incomplete final batch is dropped.
        seq_len: Tokens per training example.
        max_samples: Only the first ``max_samples`` rows of the split are read.
            Blank rows among them are skipped.

    Returns:
        (dataloader, tokenizer)

    Raises:
        Any error from loading the dataset or selecting ``split``. It is
        logged and then re-raised.
        ValueError: if the corpus yields fewer than ``seq_len`` tokens.
    """
    logger.info(f"Loading '{dataset_name}' dataset ({split})...")

    try:
        raw_datasets = load_dataset(dataset_name, config_name)
        data_split = raw_datasets[split]
    except Exception as e:
        logger.error(f"Failed to load dataset: {e}")
        raise

    logger.info(f"Loading tokenizer '{tokenizer_name}'...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    # Not needed for the fixed-size chunks built here, but lets callers pad
    # with this tokenizer (e.g. GPT-2's tokenizer defines no pad token).
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    logger.info("Tokenizing active text samples...")
    # Concatenate the first `max_samples` rows and skip blank ones. Blank rows
    # (common in WikiText) would otherwise inject runs of spaces that the
    # tokenizer encodes as extra tokens.
    n_rows = min(len(data_split), max_samples)
    text_corpus = " ".join(
        row["text"] for row in data_split.select(range(n_rows)) if row["text"].strip()
    )

    # Process without truncation, targeting a continuous token stream.
    tokenized = tokenizer.encode(text_corpus, add_special_tokens=False)

    logger.info(f"Generated {len(tokenized)} tokens.")

    # Construct sequence slicer dataset
    dataset = TextDataset(tokenized, seq_len=seq_len)
    if len(dataset) == 0:
        # A shuffling DataLoader cannot be built over an empty dataset; fail
        # with an explicit message instead of the sampler's generic error.
        raise ValueError(
            f"Corpus produced {len(tokenized)} tokens, fewer than seq_len={seq_len}; "
            f"cannot build any training example."
        )

    # drop_last=True keeps every batch exactly [batch_size, seq_len].
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    if len(dataloader) == 0:
        logger.warning(
            f"Only {len(dataset)} chunks of {seq_len} tokens, fewer than "
            f"batch_size={batch_size}; the dataloader yields no batches."
        )
    logger.info(f"Dataloader yielding {len(dataloader)} batches of dimension [{batch_size}, {seq_len}]")

    return dataloader, tokenizer
ebm/data_loader.py ADDED
@@ -0,0 +1,192 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Optimized dataset loader for EBM training. Uses local text files
4
+ instead of downloading from HuggingFace every time.
5
+
6
+ Supported datasets:
7
+ - TinyStories: D:/datasets/ebm/tinystories_train.txt (1.9GB, 2.1M stories)
8
+ - TinyStories val: D:/datasets/ebm/tinystories_val.txt (19MB, 22K stories)
9
+ - WikiText-103: cached via HuggingFace (fallback)
10
+ """
11
+
12
+ import os
13
+ import logging
14
+ import warnings
15
+
# Module-level logger for the dataset loaders.
logger = logging.getLogger(__name__)

# Silence the tokenizer's "Token indices sequence length ..." warning. Long
# texts are encoded on purpose here and sliced into seq_len chunks afterwards.
warnings.filterwarnings("ignore", message="Token indices sequence length")
20
+
21
+ # Lazy imports for heavy dependencies
22
+ def _get_torch_utils():
23
+ """Lazy import torch utilities."""
24
+ import torch
25
+ from torch.utils.data import Dataset, DataLoader
26
+ return torch, Dataset, DataLoader
27
+
def _get_transformers():
    """Import transformers on demand and return its ``AutoTokenizer`` class."""
    import transformers

    return transformers.AutoTokenizer
32
+
# Default dataset locations. The directory can be overridden with the
# EBM_DATASETS_DIR environment variable, which is read once at import time.
# When the variable is unset or empty, the original Windows path is used.
DATASETS_DIR = os.environ.get("EBM_DATASETS_DIR") or "D:/datasets/ebm"
TINYSTORIES_TRAIN = os.path.join(DATASETS_DIR, "tinystories_train.txt")
TINYSTORIES_VAL = os.path.join(DATASETS_DIR, "tinystories_val.txt")
37
+
38
+
class TextFileDataset:
    """
    In-memory token sequence sliced into consecutive ``seq_len`` chunks.

    Implements the map-style dataset protocol (``__len__`` / ``__getitem__``)
    by duck typing, so it can be passed to ``DataLoader`` without importing
    torch when this module is imported. Trailing tokens that do not fill a
    whole chunk are dropped.
    """

    def __init__(self, token_ids: list, seq_len: int):
        """
        Args:
            token_ids: Flat list of token ids.
            seq_len: Number of tokens per chunk.
        """
        self.seq_len = seq_len
        self.tokens = token_ids
        self.total_chunks = len(self.tokens) // seq_len
        # Kept for backward compatibility: a reference to torch's Dataset
        # class. This class does not inherit from it.
        from torch.utils.data import Dataset
        self._Dataset = Dataset

    def __len__(self):
        return self.total_chunks

    def __getitem__(self, idx):
        """
        Return chunk ``idx`` as a LongTensor of shape [seq_len].

        Negative indices count from the end. Out-of-range indices raise
        IndexError, which also lets plain ``for chunk in dataset`` iteration
        terminate: Python's fallback sequence iteration stops only on
        IndexError.
        """
        import torch  # lazy; a cheap sys.modules lookup after the first call

        if idx < 0:
            idx += self.total_chunks
        if not 0 <= idx < self.total_chunks:
            raise IndexError(f"chunk index out of range (dataset has {self.total_chunks} chunks)")
        start = idx * self.seq_len
        return torch.tensor(self.tokens[start:start + self.seq_len], dtype=torch.long)
58
+
59
+
def tokenize_file(filepath: str, tokenizer, max_chars: int = None, chunk_size: int = 50_000_000) -> list:
    """
    Tokenize a UTF-8 text file into a flat list of token ids.

    The file is read and encoded about ``chunk_size`` characters at a time to
    bound peak memory. After each full read, the chunk is extended to the end
    of its current line (reading at most another ``chunk_size`` characters).
    This way words are not split across chunk boundaries; a split word would
    be encoded into different sub-word tokens than the intact text.

    Args:
        filepath: Path of the text file.
        tokenizer: Object with ``encode(text, add_special_tokens=False)``
            returning a list of ids.
        max_chars: Stop after exactly this many characters. ``None`` or any
            falsy value (including 0) reads the whole file.
        chunk_size: Characters read per step, before the line extension.

    Returns:
        List of token ids.
    """
    logger.info(f"Reading {filepath}...")

    all_token_ids = []
    chars_read = 0

    with open(filepath, 'r', encoding='utf-8') as f:
        while True:
            if max_chars and chars_read >= max_chars:
                break
            read_size = min(chunk_size, max_chars - chars_read) if max_chars else chunk_size
            text = f.read(read_size)
            if not text:
                break

            # A full read may have stopped mid-line: finish the line, without
            # exceeding max_chars or one extra chunk.
            if len(text) == read_size:
                room = chunk_size
                if max_chars:
                    room = min(room, max_chars - chars_read - len(text))
                if room > 0:
                    text += f.readline(room)
            chars_read += len(text)

            # Tokenize this chunk
            chunk_ids = tokenizer.encode(text, add_special_tokens=False)
            all_token_ids.extend(chunk_ids)
            logger.info(f"  {chars_read:,} chars → {len(all_token_ids):,} tokens")

    logger.info(f"  Total: {chars_read:,} characters, {len(all_token_ids):,} tokens")
    return all_token_ids
84
+
85
+
def get_tinystories_dataloader(
    seq_len: int = 64,
    batch_size: int = 64,
    split: str = "train",
    max_chars: int = None,
    tokenizer_name: str = "gpt2",
    shuffle: bool = True,
):
    """
    Build a DataLoader over the local TinyStories text file for ``split``.

    Args:
        seq_len: Tokens per training example.
        batch_size: Batch size; an incomplete final batch is dropped.
        split: ``"train"`` or ``"val"``.
        max_chars: Optional cap on characters read from the file (see
            ``tokenize_file``).
        tokenizer_name: Tokenizer id passed to ``AutoTokenizer.from_pretrained``.
        shuffle: Shuffle chunks each epoch.

    Returns:
        (dataloader, tokenizer)

    Raises:
        ValueError: Unknown ``split``.
        FileNotFoundError: The dataset file does not exist.
    """
    if split == "train":
        filepath = TINYSTORIES_TRAIN
    elif split == "val":
        filepath = TINYSTORIES_VAL
    else:
        raise ValueError(f"Unknown split: {split}")

    if not os.path.exists(filepath):
        raise FileNotFoundError(f"Dataset not found: {filepath}. Run download script first.")

    # Lazy load transformers and torch
    AutoTokenizer = _get_transformers()
    torch, _, DataLoader = _get_torch_utils()

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    token_ids = tokenize_file(filepath, tokenizer, max_chars)
    dataset = TextFileDataset(token_ids, seq_len)
    # Pinned host memory only helps when batches are copied to a CUDA device.
    # Without CUDA it is ignored, and recent PyTorch versions warn about it.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=True,
                            num_workers=0, pin_memory=torch.cuda.is_available())

    if len(dataloader) == 0:
        logger.warning(
            f"TinyStories {split}: only {len(dataset):,} chunks of {seq_len} tokens, "
            f"fewer than batch_size={batch_size}; the dataloader yields no batches."
        )
    logger.info(
        f"TinyStories {split}: {len(dataset):,} chunks of {seq_len} tokens → "
        f"{len(dataloader):,} batches of [{batch_size} x {seq_len}]"
    )
    return dataloader, tokenizer
121
+
122
+
def get_wikitext_dataloader(
    seq_len: int = 32,
    batch_size: int = 64,
    split: str = "train",
    max_samples: int = 5000,
    tokenizer_name: str = "gpt2",
):
    """
    Fallback loader: WikiText-103 (raw) via ``datasets.load_dataset``, which
    downloads on first use and reads the local Hugging Face cache afterwards.

    Only the first ``max_samples`` rows of ``split`` are read. Blank rows among
    them are skipped before the texts are joined with single spaces, so they do
    not add runs of whitespace tokens to the stream. Chunks are shuffled and an
    incomplete final batch is dropped.

    Returns:
        (dataloader, tokenizer)
    """
    from datasets import load_dataset

    AutoTokenizer = _get_transformers()
    torch, _, DataLoader = _get_torch_utils()

    raw = load_dataset("wikitext", "wikitext-103-raw-v1")
    data = raw[split]
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    rows = data.select(range(min(len(data), max_samples)))
    text_corpus = " ".join(row["text"] for row in rows if row["text"].strip())
    token_ids = tokenizer.encode(text_corpus, add_special_tokens=False)

    dataset = TextFileDataset(token_ids, seq_len)
    # Pin host memory only when there is a CUDA device to copy batches to.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True,
                            num_workers=0, pin_memory=torch.cuda.is_available())
    return dataloader, tokenizer
151
+
152
+
def get_dataloader(
    dataset_name: str = "tinystories",
    seq_len: int = 64,
    batch_size: int = 64,
    split: str = "train",
    max_chars: int = None,
    tokenizer_name: str = "gpt2",
    **kwargs,
):
    """
    Unified dataloader factory.

    Args:
        dataset_name: ``"tinystories"`` (local files) or ``"wikitext"``
            (Hugging Face).
        seq_len, batch_size, split, tokenizer_name: Forwarded to the
            dataset-specific loader.
        max_chars: Character cap; used by ``"tinystories"`` only.
        **kwargs: Dataset-specific options:
            - ``shuffle`` (tinystories only, default True)
            - ``max_samples`` (wikitext only, default 5000)
            Any other keys are ignored.

    Returns:
        (dataloader, tokenizer)

    Raises:
        ValueError: Unknown ``dataset_name``.
    """
    if dataset_name == "tinystories":
        return get_tinystories_dataloader(
            seq_len=seq_len, batch_size=batch_size, split=split,
            max_chars=max_chars, tokenizer_name=tokenizer_name,
            shuffle=kwargs.get("shuffle", True),
        )
    if dataset_name == "wikitext":
        return get_wikitext_dataloader(
            seq_len=seq_len, batch_size=batch_size, split=split,
            max_samples=kwargs.get("max_samples", 5000),
            tokenizer_name=tokenizer_name,
        )
    raise ValueError(f"Unknown dataset: {dataset_name}")
176
+
177
+
if __name__ == "__main__":
    # Smoke test: pull one batch from each data source and print it.
    logging.basicConfig(level=logging.INFO, format='%(message)s')

    print("=== TinyStories (100K chars test) ===")
    ts_loader, ts_tok = get_dataloader("tinystories", seq_len=64, batch_size=8, max_chars=100_000)
    first_batch = next(iter(ts_loader))
    head = first_batch[0][:10]
    print(f"Batch shape: {first_batch.shape}")
    print(f"Sample tokens (first 10): {head.tolist()}")
    print(f"Decoded: {ts_tok.decode(head)}")

    print("\n=== WikiText fallback ===")
    wt_loader, _wt_tok = get_dataloader("wikitext", seq_len=32, batch_size=8, max_samples=100)
    wt_batch = next(iter(wt_loader))
    print(f"Batch shape: {wt_batch.shape}")
ebm/decoder.py ADDED
@@ -0,0 +1,71 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from .config import EBMConfig
4
+
class MoELayer(nn.Module):
    """
    Mixture-of-experts feed-forward layer with top-k softmax routing.

    Each expert is an MLP ``in_features -> hidden_dim -> in_features`` with
    GELU. Every expert is evaluated densely on every input (simple and
    CPU-friendly). The top-k expert outputs are then combined with the
    router's softmax weights, renormalised over those k experts. Compute
    therefore does not shrink with ``moe_active``.

    Accepts inputs of shape ``[..., in_features]`` (any number of leading
    dims) and returns the same shape.
    """

    def __init__(self, config: EBMConfig, in_features: int):
        super().__init__()
        self.num_experts = config.moe_experts
        self.k = config.moe_active
        self.hidden_dim = config.hidden_dim

        # Router
        self.router = nn.Linear(in_features, self.num_experts)

        # Experts, stored as stacked weight tensors: [E, D, H] / [E, H, D].
        self.experts_w1 = nn.Parameter(torch.randn(self.num_experts, in_features, self.hidden_dim) / in_features**0.5)
        self.experts_b1 = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))

        self.experts_w2 = nn.Parameter(torch.randn(self.num_experts, self.hidden_dim, in_features) / self.hidden_dim**0.5)
        self.experts_b2 = nn.Parameter(torch.zeros(self.num_experts, in_features))

        self.activation = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [..., D] with D == in_features.
        Returns:
            [..., D]: routing-weighted sum of the top-k expert outputs.
        """
        lead_shape = x.shape[:-1]
        D = x.size(-1)
        x = x.reshape(-1, D)  # [N, D]; a view for contiguous 2-D input

        router_logits = self.router(x)
        routing_weights = torch.softmax(router_logits, dim=-1)

        # Top-K routing, weights renormalised over the selected experts.
        top_k_weights, top_k_indices = torch.topk(routing_weights, self.k, dim=-1)
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)

        # Dense MoE: evaluate all experts, then gather the top-K.
        x_expanded = x.unsqueeze(1).expand(-1, self.num_experts, -1)  # [N, E, D]

        # experts_w1: [E, D, H], x_expanded: [N, E, D] -> h: [N, E, H]
        h = torch.einsum('bed,edh->beh', x_expanded, self.experts_w1) + self.experts_b1.unsqueeze(0)
        h = self.activation(h)
        # experts_w2: [E, H, D], h: [N, E, H] -> o: [N, E, D]
        o = torch.einsum('beh,ehd->bed', h, self.experts_w2) + self.experts_b2.unsqueeze(0)

        # Gather top-K expert outputs: [N, E, D] -> [N, K, D]
        top_k_idx = top_k_indices.unsqueeze(-1).expand(-1, -1, D)
        selected = torch.gather(o, 1, top_k_idx)

        # Weighted sum: [N, K, 1] * [N, K, D] -> [N, D], then restore leading dims.
        out = (top_k_weights.unsqueeze(-1) * selected).sum(dim=1)
        return out.reshape(*lead_shape, D)
52
+
class EBMDecoder(nn.Module):
    """
    Map a latent state and its context vector to vocabulary logits.

    Pipeline: concat([x, context]) -> MoELayer (width 2 * latent_dim)
    -> Linear(2 * latent_dim, vocab_size).
    """

    def __init__(self, config: EBMConfig):
        super().__init__()
        width = 2 * config.latent_dim  # x and context are concatenated
        self.moe = MoELayer(config, width)
        self.output_layer = nn.Linear(width, config.vocab_size)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Current latent state [B, D].
            context: Hierarchical context vector [B, D].
        Returns:
            Logits of shape [B, vocab_size].
        """
        joined = torch.cat((x, context), dim=-1)
        return self.output_layer(self.moe(joined))