ebm-splats 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ebm/__init__.py +8 -0
- ebm/config.py +109 -0
- ebm/context.py +117 -0
- ebm/cuda/__init__.py +1 -0
- ebm/cuda/energy.py +102 -0
- ebm/data.py +67 -0
- ebm/data_loader.py +192 -0
- ebm/decoder.py +71 -0
- ebm/energy.py +188 -0
- ebm/evaluation.py +125 -0
- ebm/geometry.py +27 -0
- ebm/langevin.py +102 -0
- ebm/logger.py +109 -0
- ebm/model.py +126 -0
- ebm/score_network.py +106 -0
- ebm/soc.py +64 -0
- ebm/splats.py +106 -0
- ebm/vulkan.py +115 -0
- ebm_splats-2.0.0.dist-info/METADATA +108 -0
- ebm_splats-2.0.0.dist-info/RECORD +32 -0
- ebm_splats-2.0.0.dist-info/WHEEL +5 -0
- ebm_splats-2.0.0.dist-info/entry_points.txt +2 -0
- ebm_splats-2.0.0.dist-info/licenses/LICENSE +201 -0
- ebm_splats-2.0.0.dist-info/top_level.txt +2 -0
- pglf/__init__.py +33 -0
- pglf/contrastive_head.py +150 -0
- pglf/embedding_service.py +273 -0
- pglf/encoders.py +329 -0
- pglf/flow_matching.py +278 -0
- pglf/pareto_filter.py +188 -0
- pglf/service.py +140 -0
- pglf/trainer.py +519 -0
ebm/__init__.py
ADDED
ebm/config.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
"""
|
|
2
|
+
EBM Configuration — V2 Specification Compliant.
|
|
3
|
+
|
|
4
|
+
Fixes applied:
|
|
5
|
+
- Removed duplicate init_alpha field
|
|
6
|
+
- Added V2 spec parameters (lambda context weights, adaptive noise, etc.)
|
|
7
|
+
- Explicit sign convention documentation
|
|
8
|
+
- Proper default values from the V2 Technical Specification (Especificación Técnica)
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from dataclasses import dataclass, field
|
|
12
|
+
from typing import Optional, Tuple
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class EBMConfig:
    """
    Configuration for the EBM continuous language model.

    Sign Convention (V2 Spec §2.1):
        p(x) = exp(-E(x)) / Z
        score(x) = ∇_x log p(x) = -∇_x E(x)
        Low energy = high probability (near splat centers)

    ``__post_init__`` rejects a few settings that would otherwise only fail (or
    silently produce inf/NaN) much later inside the model; it raises
    ``ValueError`` listing every problem found.
    """

    # ── Environment ──
    device: str = "cpu"

    # ── Latent Space ──
    latent_dim: int = 640
    n_splats_init: int = 10000
    max_splats: int = 100000
    knn_k: int = 64

    # ── Splat Initialization (Phase 1 improvements) ──
    vocab_embedding_path: Optional[str] = None
    init_from_vocab_embeddings: bool = False
    init_alpha: float = 1.0
    init_kappa: float = 10.0

    # ── Splat Regularization ──
    splat_temperature: float = 0.1
    splat_weight_decay: float = 0.0
    splat_weight_decay_start: float = 1.0
    min_kappa: float = 1.0
    max_kappa: float = 50.0

    # ── Training ──
    learning_rate: float = 1e-3
    reg_weight: float = 0.01
    grad_clip: float = 1.0
    temperature: float = 0.1  # divisor of the splat-energy exponent; must be > 0

    # ── Curriculum Learning ──
    enable_curriculum_learning: bool = True
    curriculum_epochs: int = 5
    curriculum_target_splats: int = 50000
    splat_convergence_threshold: float = 0.95

    # ── Noise Schedule (V2 §3.3 Adaptive Noise) ──
    noise_levels: Tuple[float, ...] = (0.01, 0.05, 0.1, 0.2, 0.5)
    sigma_base: float = 0.1
    sigma_scale: float = 1.0
    sigma_rho_ref: float = 1.0

    # ── Langevin Dynamics ──
    langevin_steps: int = 200
    langevin_dt: float = 0.001
    langevin_gamma: float = 0.1
    langevin_T: float = 1.0

    # ── SOC (Self-Organized Criticality) ──
    soc_threshold: float = 0.8
    soc_check_interval: int = 100
    min_splat_distance: float = 0.1

    # ── Hierarchical Context (V2 §4) ──
    context_local_window: int = 16  # 8-16 tokens
    context_medium_window: int = 128  # 64-128 tokens
    context_global_window: int = 512  # 512+ tokens
    # EMA retention factors (weight of the previous context); must lie in [0, 1].
    beta_local: float = 0.5  # Fast adaptation
    beta_medium: float = 0.8  # Moderate
    beta_global: float = 0.95  # Slow, stable
    lambda_context_local: float = 1.0
    lambda_context_medium: float = 0.5
    lambda_context_global: float = 0.2

    # ── Decoder (V2 §5) ──
    vocab_size: int = 50257  # GPT-2 vocab
    moe_experts: int = 4
    moe_active: int = 2  # experts used per input (top-k); 1 <= moe_active <= moe_experts
    hidden_dim: int = 1024

    # ── Collapse Regularization (V2 §2.4) ──
    lambda_reg: float = 0.01
    theta_threshold: float = 0.9  # Min angular separation between splat centers

    # ── Energy Weights ──
    lambda_geom: float = 0.01
    lambda_comp: float = 0.05

    # ── EMA ──
    ema_decay: float = 0.999

    # ── Error Recovery (V2 §7) ──
    energy_stagnation_window: int = 5
    energy_stagnation_epsilon: float = 1e-4
    gradient_magnitude_threshold: float = 1e-5
    max_recovery_attempts: int = 3

    def __post_init__(self):
        """Validate settings whose misuse would only surface deep inside the model.

        Raises:
            ValueError: if any of the checked invariants is violated.
        """
        problems = []
        if self.latent_dim < 1:
            problems.append(f"latent_dim must be >= 1, got {self.latent_dim}")
        # torch.topk in the MoE decoder requires 1 <= k <= number of experts.
        if not 1 <= self.moe_active <= self.moe_experts:
            problems.append(
                f"moe_active must be in [1, moe_experts={self.moe_experts}], got {self.moe_active}"
            )
        # The splat energy divides by temperature; 0 or negative gives inf/NaN or flips the sign.
        if self.temperature <= 0:
            problems.append(f"temperature must be > 0, got {self.temperature}")
        for name in ("beta_local", "beta_medium", "beta_global"):
            value = getattr(self, name)
            if not 0.0 <= value <= 1.0:
                problems.append(f"{name} must be in [0, 1], got {value}")
        if problems:
            raise ValueError("Invalid EBMConfig: " + "; ".join(problems))
|
ebm/context.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Hierarchical Context System — V2 §4.
|
|
3
|
+
|
|
4
|
+
Three-level context for long-range dependencies:
|
|
5
|
+
- Local: last 8-16 tokens (β≈0.5, fast adaptation)
|
|
6
|
+
- Medium: last 64-128 tokens (β≈0.8, moderate)
|
|
7
|
+
- Global: full sequence (β≈0.95, slow, stable)
|
|
8
|
+
|
|
9
|
+
Integration:
|
|
10
|
+
E_trans(x_t) = -Σ_{l∈{local,medium,global}} λ_l · (x_t · c_l)
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
import torch.nn as nn
|
|
15
|
+
from typing import Dict
|
|
16
|
+
|
|
17
|
+
from .config import EBMConfig
|
|
18
|
+
from .geometry import normalize_sphere
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class HierarchicalContext(nn.Module):
    """Three-level (local / medium / global) context tracker for the EBM language model.

    Each level is an exponential moving average that is re-normalized with
    ``normalize_sphere`` after every update:

    * local  -- updated on every :meth:`update` call toward the new embedding;
    * medium -- on every 4th step (step > 0), toward the mean of the local history;
    * global -- on every 16th step (step > 0), toward the mean of the medium history.

    The history lists hold *batch-averaged* vectors on the CPU, so the medium and
    global levels mix information across the rows of a batch.
    """

    def __init__(self, config: EBMConfig):
        super().__init__()
        self.config = config
        self.latent_dim = config.latent_dim

        # Persistent context vectors, one row per batch element. They start as
        # [1, D] zeros and are replaced by random normalized vectors on reset()
        # or on the first update() with a different batch size / device.
        self.register_buffer('c_local', torch.zeros(1, config.latent_dim))
        self.register_buffer('c_medium', torch.zeros(1, config.latent_dim))
        self.register_buffer('c_global', torch.zeros(1, config.latent_dim))

        # EMA retention factors: weight kept on the previous context vector.
        self.beta_local = config.beta_local
        self.beta_medium = config.beta_medium
        self.beta_global = config.beta_global

        # Update counters
        self.register_buffer('_step', torch.tensor(0, dtype=torch.long))
        self.register_buffer('_local_count', torch.tensor(0, dtype=torch.long))
        self.register_buffer('_medium_count', torch.tensor(0, dtype=torch.long))

        # Plain Python lists of CPU tensors (not registered buffers), so their
        # length can vary freely and .to(device) does not touch them.
        self._local_buffer = []
        self._medium_buffer = []

    def reset(self, batch_size: int, device: torch.device):
        """Start a new sequence: fresh random context rows, zeroed counters, empty history."""
        self.c_local = normalize_sphere(torch.randn(batch_size, self.latent_dim, device=device))
        self.c_medium = normalize_sphere(torch.randn(batch_size, self.latent_dim, device=device))
        self.c_global = normalize_sphere(torch.randn(batch_size, self.latent_dim, device=device))
        self._step.zero_()
        self._local_count.zero_()
        self._medium_count.zero_()
        self._local_buffer = []
        self._medium_buffer = []

    @torch.no_grad()
    def update(self, x: torch.Tensor):
        """Update context with new token embedding x: [B, D].

        If the batch size or device of ``x`` differs from the stored context, the
        tracker is fully reset first (same as :meth:`reset`). The old context
        vectors are discarded in that case, so the history buffers and step
        schedule are cleared too instead of leaking into the new context.
        """
        B = x.size(0)
        device = x.device

        # Ensure batch/device consistency
        if self.c_local.size(0) != B or self.c_local.device != device:
            self.reset(B, device)

        step = self._step.item()

        # ── Local: every step ──
        self.c_local = normalize_sphere(
            self.beta_local * self.c_local.to(device) + (1 - self.beta_local) * x
        )

        # Store batch-averaged embedding (on CPU) for the medium context
        self._local_buffer.append(x.mean(dim=0).detach().cpu())
        if len(self._local_buffer) > self.config.context_local_window:
            self._local_buffer.pop(0)
        self._local_count.add_(1)

        # ── Medium: every 4 steps ──
        if step > 0 and step % 4 == 0 and len(self._local_buffer) > 0:
            local_avg = torch.stack(self._local_buffer).mean(dim=0).to(device)
            self.c_medium = normalize_sphere(
                self.beta_medium * self.c_medium.to(device)
                + (1 - self.beta_medium) * local_avg.unsqueeze(0).expand(B, -1)
            )
            self._medium_buffer.append(self.c_medium.mean(dim=0).detach().cpu())
            if len(self._medium_buffer) > self.config.context_medium_window:
                self._medium_buffer.pop(0)
            self._medium_count.add_(1)

        # ── Global: every 16 steps ──
        if step > 0 and step % 16 == 0 and len(self._medium_buffer) > 0:
            medium_avg = torch.stack(self._medium_buffer).mean(dim=0).to(device)
            self.c_global = normalize_sphere(
                self.beta_global * self.c_global.to(device)
                + (1 - self.beta_global) * medium_avg.unsqueeze(0).expand(B, -1)
            )

        self._step.add_(1)

    def get_context(self) -> Dict[str, torch.Tensor]:
        """Return dict of context vectors keyed 'local', 'medium', 'global'."""
        return {
            'local': self.c_local,
            'medium': self.c_medium,
            'global': self.c_global,
        }

    def get_combined_context(self) -> torch.Tensor:
        """Weighted sum of the three levels (lambda_context_* weights), re-normalized."""
        c = (
            self.config.lambda_context_local * self.c_local +
            self.config.lambda_context_medium * self.c_medium +
            self.config.lambda_context_global * self.c_global
        )
        return normalize_sphere(c)
|
ebm/cuda/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""CUDA-accelerated energy computation (requires PyTorch with CUDA)."""
|
ebm/cuda/energy.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CUDA-native Energy Function — V2 Compliant.
|
|
3
|
+
|
|
4
|
+
Same as energy.py but uses CUDA AMP-safe autograd instead of manual gradients.
|
|
5
|
+
The V2 manual gradient approach in energy.py is preferred for training stability.
|
|
6
|
+
This module serves as CUDA fallback when V2 manual gradients aren't needed.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
import torch.nn.functional as F
|
|
12
|
+
|
|
13
|
+
from ..config import EBMConfig
|
|
14
|
+
from ..geometry import project_to_tangent
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class EnergyFunctionCUDA(nn.Module):
    """Autograd-based energy model using the V2 sign convention (low energy = likely).

    Total energy = splat term + lambda_geom * batch-spread term
    + lambda_comp * neighbour-composition term (+ optional context term).
    ``compute_score`` returns the negated, tangent-projected energy gradient.
    """

    def __init__(self, config: EBMConfig, splat_store):
        super().__init__()
        self.config = config
        self.splats = splat_store
        # Maps the (u, v, u*v) feature triple of compute_comp_energy to one logit.
        self.W_comp = nn.Linear(3, 1)

    @staticmethod
    def _as_rows(x):
        """Return ``(rows, lead_shape)``: x flattened to [N, D] when it has more than 2 dims."""
        lead_shape = x.shape[:-1]
        rows = x.reshape(-1, x.size(-1)) if x.dim() > 2 else x
        return rows, lead_shape

    def compute_splat_energy(self, x: torch.Tensor) -> torch.Tensor:
        """Negative log of a kappa-weighted mixture over the knn_k nearest splats; shape x.shape[:-1]."""
        rows, lead_shape = self._as_rows(x)
        mu, alpha, kappa = self.splats.find_neighbors(rows, self.config.knn_k)
        sims = torch.bmm(mu, rows.unsqueeze(-1)).squeeze(-1)

        # Importance weighting (V2 §2.2): mixture weights proportional to kappa.
        kappa_pos = kappa.clamp(min=1e-4)
        mix = kappa_pos / kappa_pos.sum(dim=-1, keepdim=True)

        logits = alpha * (sims - 1.0) / self.config.temperature
        log_mix = torch.log(mix.clamp(min=1e-8))
        return (-torch.logsumexp(logits + log_mix, dim=-1)).view(lead_shape)

    def compute_geom_energy(self, x: torch.Tensor) -> torch.Tensor:
        """Batch spread penalty as a 0-dim tensor: mean of -log(1 - <x_i, x_j>) over i != j."""
        rows, _ = self._as_rows(x)
        n = rows.size(0)
        if n < 2:
            return torch.tensor(0.0, device=rows.device)

        gram = torch.mm(rows, rows.T)
        off_diagonal = gram[~torch.eye(n, dtype=torch.bool, device=rows.device)]
        return -torch.log(1.0 - off_diagonal.clamp(max=1.0 - 1e-4) + 1e-4).mean()

    def compute_comp_energy(self, x: torch.Tensor) -> torch.Tensor:
        """Sigmoid score of x's similarities to its two nearest splats; shape x.shape[:-1]."""
        rows, lead_shape = self._as_rows(x)
        mu, _, _ = self.splats.find_neighbors(rows, 2)
        column = rows.unsqueeze(-1)
        u = torch.bmm(mu[:, 0:1, :], column).squeeze(-1)
        v = torch.bmm(mu[:, 1:2, :], column).squeeze(-1)
        features = torch.cat([u, v, u * v], dim=-1)
        return torch.sigmoid(self.W_comp(features)).squeeze(-1).view(lead_shape)

    def compute_context_energy(self, x: torch.Tensor, context_vecs: dict) -> torch.Tensor:
        """-sum over levels of lambda_context_<level> * <x, c_level> (lambda 0.0 if not configured)."""
        total = torch.tensor(0.0, device=x.device)
        for name, ctx in context_vecs.items():
            weight = getattr(self.config, f'lambda_context_{name}', 0.0)
            total = total - weight * (x * ctx).sum(dim=-1)
        return total

    def forward(self, x: torch.Tensor, context_vecs: dict = None) -> torch.Tensor:
        """Total energy of x, optionally including the hierarchical-context term."""
        energy = self.compute_splat_energy(x)
        energy = energy + self.compute_geom_energy(x) * self.config.lambda_geom
        energy = energy + self.compute_comp_energy(x) * self.config.lambda_comp
        if context_vecs is not None:
            energy = energy + self.compute_context_energy(x, context_vecs)
        return energy

    def compute_score(self, x: torch.Tensor) -> torch.Tensor:
        """
        Riemannian score via AMP-safe autograd: -project_to_tangent(x, dE/dx), in float32.
        For training stability, prefer energy.py manual gradient approach.
        """
        with torch.amp.autocast('cuda', enabled=False):
            leaf = x.float().detach().requires_grad_(True)
            energy = self.forward(leaf)
            (euclid_grad,) = torch.autograd.grad(energy.sum(), leaf, create_graph=False)
            tangent_grad = project_to_tangent(leaf, euclid_grad)

        return -tangent_grad
|
ebm/data.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch.utils.data import Dataset, DataLoader
|
|
3
|
+
from datasets import load_dataset
|
|
4
|
+
from transformers import AutoTokenizer
|
|
5
|
+
import logging
|
|
6
|
+
|
|
7
|
+
# Module-level logger for this file.
# NOTE(review): basicConfig() at import time configures the root logger of any
# program that imports this module (it does nothing if the root logger already
# has handlers); consider moving it to an application entry point.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
|
|
9
|
+
|
|
10
|
+
class TextDataset(Dataset):
    """Map-style dataset serving fixed-length chunks of a flat token stream.

    Tokens that do not fill a final complete chunk are dropped. Indexing follows
    Python sequence semantics: negative indices count from the end and
    out-of-range indices raise ``IndexError`` (which also lets plain
    ``for chunk in dataset`` iteration terminate).
    """

    def __init__(self, tokenized_data, seq_len: int = 16):
        """
        Dynamically slices a continuous stream of tokens into chunks of `seq_len`.

        Args:
            tokenized_data: sliceable sequence of token IDs (e.g. a list of ints).
            seq_len: number of tokens per chunk; must be >= 1.

        Raises:
            ValueError: if ``seq_len`` < 1.
        """
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        self.seq_len = seq_len
        self.tokens = tokenized_data
        self.total_chunks = len(self.tokens) // seq_len

    def __len__(self):
        return self.total_chunks

    def __getitem__(self, idx):
        """Return chunk ``idx`` as a LongTensor of length ``seq_len``."""
        original_idx = idx
        if idx < 0:
            idx += self.total_chunks
        if not 0 <= idx < self.total_chunks:
            raise IndexError(
                f"chunk index {original_idx} out of range for {self.total_chunks} chunks"
            )
        start = idx * self.seq_len
        end = start + self.seq_len
        return torch.tensor(self.tokens[start:end], dtype=torch.long)
|
|
26
|
+
|
|
27
|
+
def get_dataloader(tokenizer_name: str = "gpt2", dataset_name: str = "wikitext",
                   config_name: str = "wikitext-2-raw-v1", split: str = "train",
                   batch_size: int = 8, seq_len: int = 16, max_samples: int = 5000):
    """
    Downloads, tokenizes, and maps text blocks to PyTorch DataLoader structs.

    Loads ``dataset_name``/``config_name`` with ``datasets.load_dataset`` and
    takes the first ``max_samples`` rows of ``split``. It then drops
    empty/whitespace-only rows and joins the remaining ``text`` fields with
    single spaces. The result is tokenized as one continuous stream (no special
    tokens) and sliced into ``seq_len``-token chunks with ``TextDataset``.

    Returns:
        (dataloader, tokenizer). The loader shuffles and yields LongTensor
        batches of shape [batch_size, seq_len]; an incomplete final batch is
        dropped.

    Raises:
        Any exception from ``load_dataset`` or an unknown ``split`` (KeyError),
        after logging it.
    """
    logger.info(f"Loading '{dataset_name}' dataset ({split})...")

    try:
        raw_datasets = load_dataset(dataset_name, config_name)
        data_split = raw_datasets[split]
    except Exception as e:
        logger.error(f"Failed to load dataset: {e}")
        raise

    logger.info(f"Loading tokenizer '{tokenizer_name}'...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    # Optional pad token setting if required by specific batching, but we utilize
    # exact chunks here.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    logger.info("Tokenizing active text samples...")
    # Concatenate texts, dropping empty / whitespace-only lines (WikiText has many
    # blank rows, which would otherwise become runs of stray space tokens).
    n_rows = min(len(data_split), max_samples)
    text_corpus = " ".join(
        row["text"] for row in data_split.select(range(n_rows)) if row["text"].strip()
    )

    # Process without truncation targeting a continuous token stream
    tokenized = tokenizer.encode(text_corpus, add_special_tokens=False)

    logger.info(f"Generated {len(tokenized)} tokens.")

    # Construct sequence slicer dataset
    dataset = TextDataset(tokenized, seq_len=seq_len)

    # Drop_last=True ensures all batches are uniformly sized (B, seq_len)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    logger.info(f"Dataloader yielding {len(dataloader)} batches of dimension [{batch_size}, {seq_len}]")
    if len(dataloader) == 0:
        logger.warning(
            f"Not enough tokens for a single batch: need {batch_size * seq_len}, "
            f"got {len(tokenized)}. Increase max_samples or reduce batch_size/seq_len."
        )

    return dataloader, tokenizer
|
ebm/data_loader.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Optimized dataset loader for EBM training. Uses local text files
|
|
4
|
+
instead of downloading from HuggingFace every time.
|
|
5
|
+
|
|
6
|
+
Supported datasets:
|
|
7
|
+
- TinyStories: D:/datasets/ebm/tinystories_train.txt (1.9GB, 2.1M stories)
|
|
8
|
+
- TinyStories val: D:/datasets/ebm/tinystories_val.txt (19MB, 22K stories)
|
|
9
|
+
- WikiText-103: cached via HuggingFace (fallback)
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import os
|
|
13
|
+
import logging
|
|
14
|
+
import warnings
|
|
15
|
+
|
|
16
|
+
# Suppress the tokenizer's "Token indices sequence length is longer than ..."
# message: tokenize_file() deliberately encodes multi-megabyte chunks, so the
# message is expected noise here. transformers emits it through the stdlib
# logger "transformers.tokenization_utils_base" (logger.warning), not through
# warnings.warn, so the warnings filter alone does not silence it; that
# logger's level is raised as well. Other WARNING-level messages from that
# logger are hidden too.
warnings.filterwarnings("ignore", message="Token indices sequence length")
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)

logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
# Lazy imports for heavy dependencies
|
|
22
|
+
def _get_torch_utils():
|
|
23
|
+
"""Lazy import torch utilities."""
|
|
24
|
+
import torch
|
|
25
|
+
from torch.utils.data import Dataset, DataLoader
|
|
26
|
+
return torch, Dataset, DataLoader
|
|
27
|
+
|
|
28
|
+
def _get_transformers():
|
|
29
|
+
"""Lazy import transformers."""
|
|
30
|
+
from transformers import AutoTokenizer
|
|
31
|
+
return AutoTokenizer
|
|
32
|
+
|
|
33
|
+
# Default paths. The dataset directory can be overridden with the
# EBM_DATASETS_DIR environment variable (read once, at import time); when it is
# unset or empty, the original "D:/datasets/ebm" location is used.
DATASETS_DIR = os.environ.get("EBM_DATASETS_DIR") or "D:/datasets/ebm"
TINYSTORIES_TRAIN = os.path.join(DATASETS_DIR, "tinystories_train.txt")
TINYSTORIES_VAL = os.path.join(DATASETS_DIR, "tinystories_val.txt")
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class TextFileDataset:
    """Map-style dataset slicing an in-memory list of token IDs into fixed-length chunks.

    The token IDs are held as given (nothing is memory-mapped). The class does
    not subclass ``torch.utils.data.Dataset``; ``DataLoader`` accepts it because
    it implements ``__len__`` and ``__getitem__``. Tokens that do not fill a
    final complete chunk are dropped. Indexing follows Python sequence
    semantics: negative indices count from the end and out-of-range indices
    raise ``IndexError``.
    """

    def __init__(self, token_ids: list, seq_len: int):
        """
        Args:
            token_ids: sliceable sequence of token IDs.
            seq_len: tokens per chunk; must be >= 1.

        Raises:
            ValueError: if ``seq_len`` < 1.
        """
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        self.seq_len = seq_len
        self.tokens = token_ids
        self.total_chunks = len(self.tokens) // seq_len
        # Kept for backward compatibility only; the class does not inherit from it.
        _, Dataset, _ = _get_torch_utils()
        self._Dataset = Dataset

    def __len__(self):
        return self.total_chunks

    def __getitem__(self, idx):
        """Return chunk ``idx`` as a LongTensor of length ``seq_len``."""
        torch, _, _ = _get_torch_utils()
        original_idx = idx
        if idx < 0:
            idx += self.total_chunks
        if not 0 <= idx < self.total_chunks:
            raise IndexError(
                f"chunk index {original_idx} out of range for {self.total_chunks} chunks"
            )
        start = idx * self.seq_len
        end = start + self.seq_len
        return torch.tensor(self.tokens[start:end], dtype=torch.long)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _split_at_word_boundary(text: str) -> tuple:
    """Split ``text`` just before its last space that follows a non-space character.

    Returns ``(head, tail)`` with ``head + tail == text``. ``tail`` starts with
    that space, so the word after it is not cut in half. Returns ``(text, "")``
    when no such space exists.
    """
    cut = text.rfind(" ")
    while cut > 0:
        if not text[cut - 1].isspace():
            return text[:cut], text[cut:]
        cut = text.rfind(" ", 0, cut)
    return text, ""


def tokenize_file(filepath: str, tokenizer, max_chars: int = None, chunk_size: int = 50_000_000) -> list:
    """Tokenize a text file into a flat list of token IDs, reading in chunks to avoid OOM.

    Each chunk is cut just before its last word boundary. The unfinished tail is
    carried into the next chunk, so no word is split across two ``encode``
    calls (splitting a word would produce different sub-word tokens than
    encoding the file in one piece).

    Args:
        filepath: UTF-8 text file to read.
        tokenizer: object with ``encode(text, add_special_tokens=False) -> list``.
        max_chars: if not None, read at most this many characters
            (``0`` reads nothing).
        chunk_size: characters read per iteration; must be >= 1.

    Returns:
        List of token IDs for the text read.

    Raises:
        ValueError: if ``chunk_size`` < 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    logger.info(f"Reading {filepath}...")

    all_token_ids = []
    chars_read = 0
    carry = ""  # trailing partial word deferred to the next encode() call

    with open(filepath, 'r', encoding='utf-8') as f:
        while max_chars is None or chars_read < max_chars:
            read_size = chunk_size if max_chars is None else min(chunk_size, max_chars - chars_read)
            text = f.read(read_size)
            if not text:
                break
            chars_read += len(text)

            # Tokenize this chunk up to its last word boundary
            head, carry = _split_at_word_boundary(carry + text)
            all_token_ids.extend(tokenizer.encode(head, add_special_tokens=False))
            logger.info(f"  {chars_read:,} chars → {len(all_token_ids):,} tokens")

    if carry:
        all_token_ids.extend(tokenizer.encode(carry, add_special_tokens=False))

    logger.info(f"  Total: {chars_read:,} characters, {len(all_token_ids):,} tokens")
    return all_token_ids
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def get_tinystories_dataloader(
    seq_len: int = 64,
    batch_size: int = 64,
    split: str = "train",
    max_chars: int = None,
    tokenizer_name: str = "gpt2",
    shuffle: bool = True,
):
    """Load the TinyStories dataset from the local text files under DATASETS_DIR.

    Args:
        seq_len: tokens per sequence.
        batch_size: sequences per batch (an incomplete last batch is dropped).
        split: "train" (TINYSTORIES_TRAIN) or "val" (TINYSTORIES_VAL).
        max_chars: if given, only this many leading characters are tokenized.
        tokenizer_name: Hugging Face tokenizer to load.
        shuffle: whether the DataLoader shuffles the chunks.

    Returns:
        (dataloader, tokenizer).

    Raises:
        ValueError: unknown ``split``.
        FileNotFoundError: the text file for ``split`` does not exist.
    """
    if split == "train":
        filepath = TINYSTORIES_TRAIN
    elif split == "val":
        filepath = TINYSTORIES_VAL
    else:
        raise ValueError(f"Unknown split: {split}")

    if not os.path.exists(filepath):
        raise FileNotFoundError(f"Dataset not found: {filepath}. Run download script first.")

    # Lazy load transformers and torch
    AutoTokenizer = _get_transformers()
    torch, _, DataLoader = _get_torch_utils()

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    token_ids = tokenize_file(filepath, tokenizer, max_chars)
    dataset = TextFileDataset(token_ids, seq_len)
    # Pinned host memory only speeds up host->GPU copies; without CUDA it brings
    # no benefit (and recent PyTorch versions warn about it).
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=True,
                            num_workers=0, pin_memory=torch.cuda.is_available())

    logger.info(f"TinyStories {split}: {len(dataset):,} chunks of {seq_len} tokens, "
                f"{len(dataloader):,} batches of [{batch_size} x {seq_len}]")
    return dataloader, tokenizer
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def get_wikitext_dataloader(
    seq_len: int = 32,
    batch_size: int = 64,
    split: str = "train",
    max_samples: int = 5000,
    tokenizer_name: str = "gpt2",
    shuffle: bool = True,
):
    """Fallback: load WikiText-103 (raw) from the Hugging Face hub/cache.

    The first ``max_samples`` rows of ``split`` are joined with single spaces,
    tokenized as one stream (no special tokens) and sliced into ``seq_len``
    chunks.

    Args:
        split: "train", "validation", "test", or "val" as an alias for
            "validation" (the name get_tinystories_dataloader uses).
        shuffle: whether the DataLoader shuffles the chunks.

    Returns:
        (dataloader, tokenizer).
    """
    from datasets import load_dataset

    AutoTokenizer = _get_transformers()
    torch, _, DataLoader = _get_torch_utils()

    # Hugging Face names the held-out split "validation"; accept "val" so one
    # split name works for every dataset routed through get_dataloader().
    hf_split = "validation" if split == "val" else split

    raw = load_dataset("wikitext", "wikitext-103-raw-v1")
    data = raw[hf_split]
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    text_corpus = " ".join(
        row["text"] for row in data.select(range(min(len(data), max_samples)))
    )
    token_ids = tokenizer.encode(text_corpus, add_special_tokens=False)

    dataset = TextFileDataset(token_ids, seq_len)
    # Pinned host memory only helps host->GPU copies; skip it without CUDA.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=True,
                            num_workers=0, pin_memory=torch.cuda.is_available())
    return dataloader, tokenizer
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def get_dataloader(
    dataset_name: str = "tinystories",
    seq_len: int = 64,
    batch_size: int = 64,
    split: str = "train",
    max_chars: int = None,
    tokenizer_name: str = "gpt2",
    **kwargs,
):
    """Unified dataloader factory.

    Args:
        dataset_name: "tinystories" (local files) or "wikitext" (HF fallback).
        max_chars: character limit, used by "tinystories" only.
        **kwargs: ``shuffle`` (bool, default True) is forwarded for
            "tinystories"; ``max_samples`` (int, default 5000) is forwarded for
            "wikitext". Other keys are ignored.

    Returns:
        (dataloader, tokenizer) from the selected loader.

    Raises:
        ValueError: unknown ``dataset_name``.
    """
    if dataset_name == "tinystories":
        return get_tinystories_dataloader(
            seq_len=seq_len, batch_size=batch_size, split=split,
            max_chars=max_chars, tokenizer_name=tokenizer_name,
            shuffle=kwargs.get("shuffle", True),
        )
    elif dataset_name == "wikitext":
        return get_wikitext_dataloader(
            seq_len=seq_len, batch_size=batch_size, split=split,
            max_samples=kwargs.get("max_samples", 5000),
            tokenizer_name=tokenizer_name,
        )
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
if __name__ == "__main__":
    # Manual smoke test: build a tiny loader for each dataset and show one batch.
    logging.basicConfig(level=logging.INFO, format='%(message)s')

    print("=== TinyStories (100K chars test) ===")
    ts_loader, ts_tokenizer = get_dataloader("tinystories", seq_len=64, batch_size=8, max_chars=100_000)
    ts_batch = next(iter(ts_loader))
    first_ten = ts_batch[0][:10]
    print(f"Batch shape: {ts_batch.shape}")
    print(f"Sample tokens (first 10): {first_ten.tolist()}")
    print(f"Decoded: {ts_tokenizer.decode(first_ten)}")

    print("\n=== WikiText fallback ===")
    wt_loader, _ = get_dataloader("wikitext", seq_len=32, batch_size=8, max_samples=100)
    wt_batch = next(iter(wt_loader))
    print(f"Batch shape: {wt_batch.shape}")
|
ebm/decoder.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from .config import EBMConfig
|
|
4
|
+
|
|
5
|
+
class MoELayer(nn.Module):
    """Mixture-of-experts feed-forward layer with top-k softmax routing.

    Each expert is a two-layer GELU MLP ``in_features -> hidden_dim -> in_features``.
    All experts are evaluated densely in one batched einsum. The outputs of the
    ``moe_active`` highest-scoring experts are then mixed with the router
    probabilities renormalized over those experts.
    """

    def __init__(self, config: "EBMConfig", in_features: int):
        super().__init__()
        self.num_experts = config.moe_experts
        self.k = config.moe_active
        self.hidden_dim = config.hidden_dim
        # torch.topk requires 1 <= k <= num_experts; fail here rather than mid-forward.
        if not 1 <= self.k <= self.num_experts:
            raise ValueError(
                f"moe_active must be in [1, moe_experts={self.num_experts}], got {self.k}"
            )

        # Router
        self.router = nn.Linear(in_features, self.num_experts)

        # Experts (simplified as parallel linear layers for CPU efficiency)
        self.experts_w1 = nn.Parameter(torch.randn(self.num_experts, in_features, self.hidden_dim) / in_features**0.5)
        self.experts_b1 = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))

        self.experts_w2 = nn.Parameter(torch.randn(self.num_experts, self.hidden_dim, in_features) / self.hidden_dim**0.5)
        self.experts_b2 = nn.Parameter(torch.zeros(self.num_experts, in_features))

        self.activation = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer to ``x`` of shape [..., in_features]; output has the same shape."""
        # Flatten any leading dims to a single batch axis: [..., D] -> [B, D].
        lead_shape = x.shape[:-1]
        D = x.shape[-1]
        x = x.reshape(-1, D)

        router_logits = self.router(x)
        routing_weights = torch.softmax(router_logits, dim=-1)

        # Top-K routing, weights renormalized to sum to 1 over the chosen experts
        top_k_weights, top_k_indices = torch.topk(routing_weights, self.k, dim=-1)
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)

        # Vectorized MoE: compute all experts, then gather top-K
        x_expanded = x.unsqueeze(1).expand(-1, self.num_experts, -1)  # [B, E, D]

        # experts_w1: [E, D, H], x_expanded: [B, E, D] → h: [B, E, H]
        h = torch.einsum('bed,edh->beh', x_expanded, self.experts_w1) + self.experts_b1.unsqueeze(0)
        h = self.activation(h)
        # experts_w2: [E, H, D], h: [B, E, H] → o: [B, E, D]
        o = torch.einsum('beh,ehd->bed', h, self.experts_w2) + self.experts_b2.unsqueeze(0)

        # Gather top-K expert outputs: [B, E, D] → [B, K, D]
        top_k_idx = top_k_indices.unsqueeze(-1).expand(-1, -1, D)
        selected = torch.gather(o, 1, top_k_idx)

        # Weighted sum: [B, K, 1] * [B, K, D] → [B, D], then restore leading dims
        out = (top_k_weights.unsqueeze(-1) * selected).sum(dim=1)
        return out.reshape(*lead_shape, D)
|
|
52
|
+
|
|
53
|
+
class EBMDecoder(nn.Module):
    """Decode a latent state plus its context vector into vocabulary logits.

    The two latent_dim-wide inputs are concatenated, passed through a MoELayer
    of width 2 * latent_dim, and linearly projected to ``vocab_size`` logits.
    """

    def __init__(self, config: EBMConfig):
        super().__init__()

        # Width of the concatenation [x; context].
        joint_dim = 2 * config.latent_dim

        self.moe = MoELayer(config, joint_dim)
        self.output_layer = nn.Linear(joint_dim, config.vocab_size)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: current latent state [B, D].
            context: hierarchical context vector [B, D].

        Returns:
            Vocabulary logits [B, vocab_size].
        """
        joint = torch.cat((x, context), dim=-1)
        return self.output_layer(self.moe(joint))
|