transformer-toolkit 0.1.0 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23)
  1. transformer_toolkit-0.1.0/PKG-INFO +28 -0
  2. transformer_toolkit-0.1.0/README.md +0 -0
  3. transformer_toolkit-0.1.0/pyproject.toml +43 -0
  4. transformer_toolkit-0.1.0/setup.cfg +4 -0
  5. transformer_toolkit-0.1.0/transformer_toolkit/__init__.py +16 -0
  6. transformer_toolkit-0.1.0/transformer_toolkit/attention.py +131 -0
  7. transformer_toolkit-0.1.0/transformer_toolkit/block.py +30 -0
  8. transformer_toolkit-0.1.0/transformer_toolkit/c_tokenizers.py +114 -0
  9. transformer_toolkit-0.1.0/transformer_toolkit/colors.py +23 -0
  10. transformer_toolkit-0.1.0/transformer_toolkit/dataloader.py +184 -0
  11. transformer_toolkit-0.1.0/transformer_toolkit/feed_forward.py +51 -0
  12. transformer_toolkit-0.1.0/transformer_toolkit/hf_hub.py +257 -0
  13. transformer_toolkit-0.1.0/transformer_toolkit/inference.py +151 -0
  14. transformer_toolkit-0.1.0/transformer_toolkit/model.py +128 -0
  15. transformer_toolkit-0.1.0/transformer_toolkit/normalization.py +39 -0
  16. transformer_toolkit-0.1.0/transformer_toolkit/positional_encodings.py +63 -0
  17. transformer_toolkit-0.1.0/transformer_toolkit/sentiment.py +299 -0
  18. transformer_toolkit-0.1.0/transformer_toolkit/trainer.py +368 -0
  19. transformer_toolkit-0.1.0/transformer_toolkit.egg-info/PKG-INFO +28 -0
  20. transformer_toolkit-0.1.0/transformer_toolkit.egg-info/SOURCES.txt +21 -0
  21. transformer_toolkit-0.1.0/transformer_toolkit.egg-info/dependency_links.txt +1 -0
  22. transformer_toolkit-0.1.0/transformer_toolkit.egg-info/requires.txt +13 -0
  23. transformer_toolkit-0.1.0/transformer_toolkit.egg-info/top_level.txt +1 -0
@@ -0,0 +1,28 @@
+ Metadata-Version: 2.4
+ Name: transformer-toolkit
+ Version: 0.1.0
+ Summary: Minimal, modular transformer library for training your own LLM
+ Author: Govind Barbade
+ Project-URL: Homepage, https://github.com/govindbarbade/transformer-toolkit
+ Project-URL: Repository, https://github.com/govindbarbade/transformer-toolkit
+ Keywords: transformer,llm,deep learning,nlp,pytorch
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ Requires-Dist: torch>=2.0.0
+ Requires-Dist: pydantic>=2.0.0
+ Provides-Extra: tokenizers
+ Requires-Dist: tokenizers>=0.15.0; extra == "tokenizers"
+ Provides-Extra: hf
+ Requires-Dist: transformers>=4.35.0; extra == "hf"
+ Requires-Dist: huggingface_hub>=0.20.0; extra == "hf"
+ Requires-Dist: datasets>=2.14.0; extra == "hf"
+ Provides-Extra: all
+ Requires-Dist: transformer-toolkit[hf,tokenizers]; extra == "all"
File without changes
@@ -0,0 +1,43 @@
+ [build-system]
+ requires = ["setuptools>=68", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "transformer-toolkit"
+ version = "0.1.0"
+ description = "Minimal, modular transformer library for training your own LLM"
+ readme = "README.md"
+ license = { file = "LICENSE" }
+ requires-python = ">=3.10"
+ authors = [{ name = "Govind Barbade" }]
+ keywords = ["transformer", "llm", "deep learning", "nlp", "pytorch"]
+ classifiers = [
+     "Development Status :: 3 - Alpha",
+     "Intended Audience :: Developers",
+     "Intended Audience :: Science/Research",
+     "License :: OSI Approved :: MIT License",
+     "Programming Language :: Python :: 3.10",
+     "Programming Language :: Python :: 3.11",
+     "Programming Language :: Python :: 3.12",
+     "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ ]
+
+ dependencies = [
+     "torch>=2.0.0",
+     "pydantic>=2.0.0",
+ ]
+
+ [project.optional-dependencies]
+ tokenizers = ["tokenizers>=0.15.0"]
+ hf = ["transformers>=4.35.0",
+       "huggingface_hub>=0.20.0",
+       "datasets>=2.14.0"]
+ all = ["transformer-toolkit[tokenizers,hf]"]
+
+ [project.urls]
+ Homepage = "https://github.com/govindbarbade/transformer-toolkit"
+ Repository = "https://github.com/govindbarbade/transformer-toolkit"
+
+ [tool.setuptools.packages.find]
+ where = ["."]
+ include = ["transformer_toolkit*"]
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
@@ -0,0 +1,16 @@
+ from .model import Transformer, TransformerConfig
+ from .trainer import Trainer, TrainConfig
+ from .dataloader import DataConfig, from_files, from_binary, from_hf, from_strings
+ from .c_tokenizers import RustBPETokenizer, HFTokenizer, ByteLevelTokenizer, CharLevelTokenizer
+ from .hf_hub import login, push_to_hub, pull_from_hub
+
+ __version__ = "0.1.0"
+ __author__ = "Govind Barbade"
+
+ __all__ = [
+     "Transformer", "TransformerConfig",
+     "Trainer", "TrainConfig",
+     "DataConfig", "from_files", "from_binary", "from_hf", "from_strings",
+     "RustBPETokenizer", "HFTokenizer", "ByteLevelTokenizer", "CharLevelTokenizer",
+     "login", "push_to_hub", "pull_from_hub",
+ ]
@@ -0,0 +1,131 @@
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+
+ class MultiHeadAttention(nn.Module):
+     """Classic MHA. Used in the original Transformer, BERT, GPT-2."""
+     def __init__(self, dim: int, n_heads: int):
+         super().__init__()
+         self.n_heads = n_heads
+         self.head_dim = dim // n_heads
+         self.qkv = nn.Linear(dim, 3 * dim, bias=False)
+         self.out = nn.Linear(dim, dim, bias=False)
+
+     def forward(self, x, mask=None):
+         B, T, C = x.shape
+         q, k, v = self.qkv(x).split(C, dim=-1)
+         def split(t): return t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         q, k, v = split(q), split(k), split(v)
+         scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
+         if mask is not None: scores = scores.masked_fill(mask == 0, float('-inf'))
+         return self.out((F.softmax(scores, -1) @ v).transpose(1, 2).reshape(B, T, C))  # merge heads back to [B, T, C] *before* the output projection
+
+
+ class GroupedQueryAttention(nn.Module):
+     """Fewer k/v heads than q heads. Used in LLaMA 3, Mistral."""
+     def __init__(self, dim: int, n_heads: int, n_kv_heads: int):
+         super().__init__()
+         self.n_heads = n_heads
+         self.n_kv_heads = n_kv_heads
+         self.head_dim = dim // n_heads
+         self.q = nn.Linear(dim, dim, bias=False)
+         self.kv = nn.Linear(dim, 2 * n_kv_heads * self.head_dim, bias=False)
+         self.out = nn.Linear(dim, dim, bias=False)
+
+     def forward(self, x, mask=None):
+         B, T, C = x.shape
+         q = self.q(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         k, v = self.kv(x).split(self.n_kv_heads * self.head_dim, dim=-1)
+         def split_kv(t): return t.view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
+         k, v = split_kv(k), split_kv(v)
+         # repeat k/v to match q heads
+         r = self.n_heads // self.n_kv_heads
+         k = k.repeat_interleave(r, dim=1)
+         v = v.repeat_interleave(r, dim=1)
+         scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
+         if mask is not None: scores = scores.masked_fill(mask == 0, float('-inf'))
+         return self.out((F.softmax(scores, -1) @ v).transpose(1, 2).reshape(B, T, C))
+
+
+ class MultiQueryAttention(nn.Module):
+     """Single k/v head shared across all q heads. Used in Falcon, early Gemini."""
+     def __init__(self, dim: int, n_heads: int):
+         super().__init__()
+         self.n_heads = n_heads
+         self.head_dim = dim // n_heads
+         self.q = nn.Linear(dim, dim, bias=False)
+         self.k = nn.Linear(dim, self.head_dim, bias=False)
+         self.v = nn.Linear(dim, self.head_dim, bias=False)
+         self.out = nn.Linear(dim, dim, bias=False)
+
+     def forward(self, x, mask=None):
+         B, T, C = x.shape
+         q = self.q(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         k = self.k(x).view(B, T, 1, self.head_dim).transpose(1, 2).expand(B, self.n_heads, T, self.head_dim)  # broadcast, not copied
+         v = self.v(x).view(B, T, 1, self.head_dim).transpose(1, 2).expand(B, self.n_heads, T, self.head_dim)
+         scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
+         if mask is not None: scores = scores.masked_fill(mask == 0, float('-inf'))
+         return self.out((F.softmax(scores, -1) @ v).transpose(1, 2).reshape(B, T, C))
+
+ class FlashAttention(nn.Module):
+     """
+     Flash attention: the same result as MHA with far less memory.
+     Uses torch's built-in scaled_dot_product_attention (torch >= 2.0),
+     which dispatches to the fused CUDA kernel automatically when available.
+     """
+     def __init__(self, dim: int, n_heads: int):
+         super().__init__()
+         self.n_heads = n_heads
+         self.head_dim = dim // n_heads
+         self.qkv = nn.Linear(dim, 3 * dim, bias=False)
+         self.out = nn.Linear(dim, dim, bias=False)
+
+     def forward(self, x, mask=None, causal=True):
+         B, T, C = x.shape
+         q, k, v = self.qkv(x).split(C, dim=-1)
+         def split(t): return t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         q, k, v = split(q), split(k), split(v)
+
+         # one call: PyTorch handles the tiling/chunking kernel
+         x = F.scaled_dot_product_attention(
+             q, k, v,
+             attn_mask=mask,
+             is_causal=causal and mask is None,  # SDPA forbids setting both; its internal causal mask costs no extra memory
+             dropout_p=0.0,
+         )
+         return self.out(x.transpose(1, 2).reshape(B, T, C))
+
+
+ class MLAttention(nn.Module):
+     """
+     Multi-Head Latent Attention (DeepSeek-V2/V3).
+     Compresses k/v into a small latent vector instead of caching full k/v,
+     for a huge KV-cache reduction at inference time.
+     """
+     def __init__(self, dim: int, n_heads: int, latent_dim: int):
+         super().__init__()
+         self.n_heads = n_heads
+         self.head_dim = dim // n_heads
+         self.kv_down = nn.Linear(dim, latent_dim, bias=False)
+         self.k_up = nn.Linear(latent_dim, dim, bias=False)
+         self.v_up = nn.Linear(latent_dim, dim, bias=False)
+
+         self.q = nn.Linear(dim, dim, bias=False)
+         self.out = nn.Linear(dim, dim, bias=False)
+
+     def forward(self, x, mask=None):
+         B, T, C = x.shape
+
+         q = self.q(x)
+
+         # compress: the latent is what you cache at inference, not full k/v
+         latent = self.kv_down(x)  # [B, T, latent_dim]
+
+         k = self.k_up(latent)
+         v = self.v_up(latent)
+
+         def split(t): return t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         q, k, v = split(q), split(k), split(v)
+
+         x = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+         return self.out(x.transpose(1, 2).reshape(B, T, C))
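
All five variants are interchangeable from the caller's side: [B, T, C] in, [B, T, C] out, with the differences confined to how k/v are produced and cached. A quick sanity sketch, not part of the package, assuming it is importable as transformer_toolkit:

import torch
from transformer_toolkit.attention import (
    MultiHeadAttention, GroupedQueryAttention, MultiQueryAttention,
    FlashAttention, MLAttention,
)

x = torch.randn(2, 16, 64)  # [batch, seq, dim]
variants = [
    MultiHeadAttention(64, 8),
    GroupedQueryAttention(64, 8, n_kv_heads=2),  # 8 q heads share 2 k/v heads
    MultiQueryAttention(64, 8),                  # 8 q heads share 1 k/v head
    FlashAttention(64, 8),
    MLAttention(64, 8, latent_dim=16),           # k/v rebuilt from a 16-dim latent
]
for attn in variants:
    assert attn(x).shape == x.shape  # every variant is shape-preserving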
@@ -0,0 +1,30 @@
+ from .feed_forward import FFN
+ from .normalization import LayerNorm
+ from .attention import MultiHeadAttention
+ from torch import nn
+
+
+ class TransformerBlock(nn.Module):
+     """
+     Default: the original 'Attention Is All You Need' components (MHA + FFN +
+     LayerNorm), wired as a pre-norm residual block.
+     Swap any component via attn=, ffn=, norm= for modern variants.
+     """
+     def __init__(
+         self,
+         dim: int,
+         n_heads: int,
+         hidden: int,
+         norm=None,  # default: LayerNorm; note: a passed-in instance is shared by both norm sites
+         attn=None,  # default: MultiHeadAttention (original)
+         ffn=None,   # default: FFN (two-layer GELU)
+     ):
+         super().__init__()
+         self.norm1 = norm or LayerNorm(dim)
+         self.norm2 = norm or LayerNorm(dim)
+         self.attn = attn or MultiHeadAttention(dim, n_heads)
+         self.ffn = ffn or FFN(dim, hidden)
+
+     def forward(self, x, mask=None):
+         x = x + self.attn(self.norm1(x), mask)  # pre-norm residual
+         x = x + self.ffn(self.norm2(x))
+         return x
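
Because every sublayer is injectable, a LLaMA-flavoured block is just a different set of constructor arguments. A hedged sketch using only components visible in this diff (LLaMA proper would also use RMSNorm and rotary embeddings, which may live in normalization.py and positional_encodings.py but are not shown here):

from transformer_toolkit.attention import GroupedQueryAttention
from transformer_toolkit.block import TransformerBlock
from transformer_toolkit.feed_forward import SwiGLU

block = TransformerBlock(
    dim=512, n_heads=8, hidden=2048,
    attn=GroupedQueryAttention(512, 8, n_kv_heads=2),  # GQA in place of MHA
    ffn=SwiGLU(512, 2048),                             # gated SwiGLU in place of the GELU FFN
)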
@@ -0,0 +1,114 @@
+ from abc import ABC, abstractmethod
+
+
+ class BaseTokenizer(ABC):
+     @abstractmethod
+     def train(self, texts: list[str], vocab_size: int): ...
+
+     @abstractmethod
+     def encode(self, text: str) -> list[int]: ...
+
+     @abstractmethod
+     def decode(self, ids: list[int]) -> str: ...
+
+     @abstractmethod
+     def save(self, path: str): ...
+
+     @abstractmethod
+     def load(self, path: str): ...
+
+     @property
+     @abstractmethod
+     def vocab_size(self) -> int: ...
+
+ class ByteLevelTokenizer(BaseTokenizer):
+     """
+     Zero deps. Every byte is a token (0-255).
+     Works on any text/language out of the box.
+     """
+     def train(self, texts, vocab_size=256): pass  # nothing to train
+
+     def encode(self, text: str) -> list[int]:
+         return list(text.encode("utf-8"))
+
+     def decode(self, ids: list[int]) -> str:
+         return bytes(ids).decode("utf-8", errors="replace")
+
+     def save(self, path: str): pass
+     def load(self, path: str): pass
+
+     @property
+     def vocab_size(self) -> int: return 256
+
+
+ class HFTokenizer(BaseTokenizer):
+     """
+     Thin wrapper around any HuggingFace tokenizer.
+     pip install transformers
+     """
+     def __init__(self, model_name: str = "gpt2"):
+         from transformers import AutoTokenizer
+         self._tok = AutoTokenizer.from_pretrained(model_name)
+
+     def train(self, texts, vocab_size=None):
+         raise NotImplementedError("use HF's train_new_from_iterator for custom training")
+
+     def encode(self, text: str) -> list[int]:
+         return self._tok.encode(text)
+
+     def decode(self, ids: list[int]) -> str:
+         return self._tok.decode(ids)
+
+     def save(self, path: str):
+         self._tok.save_pretrained(path)
+
+     def load(self, path: str):
+         from transformers import AutoTokenizer
+         self._tok = AutoTokenizer.from_pretrained(path)
+
+     @property
+     def vocab_size(self) -> int:
+         return len(self._tok)
+
+
+ class RustBPETokenizer(BaseTokenizer):
+     """
+     BPE tokenizer backed by HuggingFace's Rust-based `tokenizers` library.
+     Trains roughly 100x faster than pure-Python BPE.
+     pip install tokenizers
+     """
+     def __init__(self):
+         from tokenizers import Tokenizer
+         from tokenizers.models import BPE
+         self._tok = Tokenizer(BPE(unk_token="[UNK]"))
+         self._trained = False
+
+     def train(self, texts: list[str], vocab_size: int = 8000):
+         from tokenizers.trainers import BpeTrainer
+         from tokenizers.pre_tokenizers import Whitespace
+
+         self._tok.pre_tokenizer = Whitespace()
+         trainer = BpeTrainer(
+             vocab_size=vocab_size,
+             special_tokens=["[UNK]", "[PAD]", "[BOS]", "[EOS]"]
+         )
+         self._tok.train_from_iterator(texts, trainer)
+         self._trained = True
+
+     def encode(self, text: str) -> list[int]:
+         return self._tok.encode(text).ids
+
+     def decode(self, ids: list[int]) -> str:
+         return self._tok.decode(ids)
+
+     def save(self, path: str):
+         self._tok.save(path)
+
+     def load(self, path: str):
+         from tokenizers import Tokenizer
+         self._tok = Tokenizer.from_file(path)
+         self._trained = True
+
+     @property
+     def vocab_size(self) -> int:
+         return self._tok.get_vocab_size()
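
One inconsistency worth flagging: __init__.py imports CharLevelTokenizer from this module, but no such class appears in the file as diffed, so `import transformer_toolkit` would fail with an ImportError against this exact source. A minimal sketch of what would satisfy the BaseTokenizer interface (hypothetical, not the package's actual implementation):

import json

class CharLevelTokenizer(BaseTokenizer):
    """Vocab = the set of characters seen during train()."""
    def __init__(self):
        self.stoi: dict[str, int] = {}
        self.itos: dict[int, str] = {}

    def train(self, texts: list[str], vocab_size: int = 0):  # vocab_size unused: vocab is whatever appears
        chars = sorted(set("".join(texts)))
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for ch, i in self.stoi.items()}

    def encode(self, text: str) -> list[int]:
        return [self.stoi[ch] for ch in text if ch in self.stoi]  # silently drops unseen chars

    def decode(self, ids: list[int]) -> str:
        return "".join(self.itos.get(i, "") for i in ids)

    def save(self, path: str):
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.stoi, f)

    def load(self, path: str):
        with open(path, encoding="utf-8") as f:
            self.stoi = json.load(f)
        self.itos = {i: ch for ch, i in self.stoi.items()}

    @property
    def vocab_size(self) -> int:
        return len(self.stoi)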
@@ -0,0 +1,23 @@
+ class C:
+     RESET = "\033[0m"; BOLD = "\033[1m"; DIM = "\033[2m"
+     GREEN = "\033[32m"; CYAN = "\033[36m"; YELLOW = "\033[33m"
+     BLUE = "\033[34m"; RED = "\033[31m"; WHITE = "\033[37m"
+     MAGENTA = "\033[35m"
+
+ def _bar(current, total, width=28):
+     filled = int(width * current / max(total, 1))
+     return f"{C.CYAN}{'█' * filled}{'░' * (width - filled)}{C.RESET}"
+
+ def _section(title):
+     print(f"\n{C.BOLD}{C.CYAN}{'─' * 52}{C.RESET}")
+     print(f"{C.BOLD}{C.CYAN} {title}{C.RESET}")
+     print(f"{C.BOLD}{C.CYAN}{'─' * 52}{C.RESET}")
+
+ def _info(label, value):
+     print(f" {C.DIM}{label:<18}{C.RESET} {C.WHITE}{value}{C.RESET}")
+
+ def _ok(msg):
+     print(f" {C.GREEN}✓{C.RESET} {msg}")
+
+ def _err(msg):
+     print(f" {C.RED}✗{C.RESET} {msg}")
@@ -0,0 +1,184 @@
+ import torch
+ import struct
+ from pathlib import Path
+ from torch.utils.data import Dataset, DataLoader, IterableDataset
+ from .c_tokenizers import BaseTokenizer
+ from .colors import C, _section, _info, _ok, _bar
+
+
+ # ─── config ───────────────────────────────────────────────────────────────────
+
+ class DataConfig:
+     def __init__(
+         self,
+         seq_len: int = 512,
+         batch_size: int = 16,
+         shuffle: bool = True,
+         num_workers: int = 0,
+         split: float = 0.9,
+         streaming: bool = False,
+     ):
+         self.seq_len = seq_len
+         self.batch_size = batch_size
+         self.shuffle = shuffle
+         self.num_workers = num_workers
+         self.split = split
+         self.streaming = streaming
+
+
+ # ─── dataset ──────────────────────────────────────────────────────────────────
+
+ class TokenDataset(Dataset):
+     def __init__(self, tokens: torch.Tensor, seq_len: int):
+         self.tokens = tokens
+         self.seq_len = seq_len
+
+     def __len__(self):
+         return max(0, len(self.tokens) - self.seq_len)
+
+     def __getitem__(self, idx):
+         x = self.tokens[idx: idx + self.seq_len]
+         y = self.tokens[idx + 1: idx + self.seq_len + 1]  # next-token targets: x shifted by one
+         return x, y
+
+
+ class StreamingDataset(IterableDataset):
+     def __init__(self, paths: list[str], tokenizer: BaseTokenizer, seq_len: int):
+         self.paths = paths
+         self.tokenizer = tokenizer
+         self.seq_len = seq_len
+
+     def __iter__(self):
+         buf = []
+         for path in self.paths:
+             for line in open(path, encoding="utf-8", errors="replace"):
+                 buf.extend(self.tokenizer.encode(line))
+                 while len(buf) >= self.seq_len + 1:
+                     chunk = buf[:self.seq_len + 1]
+                     buf = buf[self.seq_len:]
+                     yield (
+                         torch.tensor(chunk[:-1], dtype=torch.long),
+                         torch.tensor(chunk[1:], dtype=torch.long),
+                     )
+
+
+ # ─── tokenize + split ─────────────────────────────────────────────────────────
+
+ def _tokenize(texts: list[str], tokenizer: BaseTokenizer) -> torch.Tensor:
+     ids = []
+     for i, text in enumerate(texts):
+         print(f"\r {C.DIM}tokenizing{C.RESET} {_bar(i+1, len(texts))} {C.DIM}{i+1}/{len(texts)}{C.RESET}", end="", flush=True)
+         ids.extend(tokenizer.encode(text))
+     print()
+     return torch.tensor(ids, dtype=torch.long)
+
+
+ def _split(tokens: torch.Tensor, cfg: DataConfig) -> tuple[TokenDataset, TokenDataset]:
+     n = len(tokens)
+     split_at = int(n * cfg.split)
+     train_ds = TokenDataset(tokens[:split_at], cfg.seq_len)
+     val_ds = TokenDataset(tokens[split_at:], cfg.seq_len)
+
+     if len(val_ds) == 0:
+         print(f" {C.YELLOW}⚠ val set empty — using last 10% of train{C.RESET}")
+         split_at = int(split_at * 0.9)
+         train_ds = TokenDataset(tokens[:split_at], cfg.seq_len)
+         val_ds = TokenDataset(tokens[split_at:], cfg.seq_len)
+
+     return train_ds, val_ds
+
+
+ def _loaders(train_ds, val_ds, cfg, shuffle=None) -> tuple[DataLoader, DataLoader]:
+     s = cfg.shuffle if shuffle is None else shuffle
+     train_dl = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=s, num_workers=cfg.num_workers)
+     val_dl = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers)
+     if hasattr(train_ds, "__len__"): print(f" {C.DIM}train{C.RESET} {C.WHITE}{len(train_ds):>10,}{C.RESET} samples {C.DIM}│{C.RESET} {C.YELLOW}{len(train_dl):,}{C.RESET} batches")  # streaming datasets have no __len__
+     if hasattr(val_ds, "__len__"): print(f" {C.DIM}val  {C.RESET} {C.WHITE}{len(val_ds):>10,}{C.RESET} samples {C.DIM}│{C.RESET} {C.YELLOW}{len(val_dl):,}{C.RESET} batches\n")
+     return train_dl, val_dl
+
+
+ # ─── binary ───────────────────────────────────────────────────────────────────
+
+ def save_binary(tokens: list[int], path: str):
+     Path(path).write_bytes(struct.pack(f"{len(tokens)}H", *tokens))  # one uint16 per token: assumes vocab_size < 65536
+     print(f" {C.GREEN}✓{C.RESET} saved {C.CYAN}{len(tokens):,}{C.RESET} tokens → {C.DIM}{path}{C.RESET}")
+
+
+ def load_binary(path: str) -> torch.Tensor:
+     raw = Path(path).read_bytes()
+     ids = struct.unpack(f"{len(raw)//2}H", raw)
+     tok = torch.tensor(ids, dtype=torch.long)
+     print(f" {C.GREEN}✓{C.RESET} loaded {C.CYAN}{len(tok):,}{C.RESET} tokens ← {C.DIM}{path}{C.RESET}")
+     return tok
+
+
+ # ─── public API ───────────────────────────────────────────────────────────────
+
+ def from_binary(path: str, cfg: DataConfig) -> tuple[DataLoader, DataLoader]:
+     _section("💾 Binary dataset")
+     _info("path", path)
+     tokens = load_binary(path)
+     train_ds, val_ds = _split(tokens, cfg)
+     return _loaders(train_ds, val_ds, cfg)
+
+
+ def from_files(paths: list[str], tokenizer: BaseTokenizer, cfg: DataConfig) -> tuple[DataLoader, DataLoader]:
+     _section("📂 File dataset")
+     for p in paths: _info("file", p)
+
+     if cfg.streaming:
+         _info("mode", "streaming")
+         split_at = max(1, int(len(paths) * cfg.split))
+         train_ds = StreamingDataset(paths[:split_at], tokenizer, cfg.seq_len)
+         val_ds = StreamingDataset(paths[split_at:] or paths[-1:], tokenizer, cfg.seq_len)  # reuse the last file if none remain for val
+         return _loaders(train_ds, val_ds, cfg, shuffle=False)
+
+     _info("mode", "in-memory")
+     texts = [Path(p).read_text(encoding="utf-8", errors="replace") for p in paths]
+     tokens = _tokenize(texts, tokenizer)
+     train_ds, val_ds = _split(tokens, cfg)
+     return _loaders(train_ds, val_ds, cfg)
+
+
+ def from_hf(dataset_name: str, tokenizer: BaseTokenizer, cfg: DataConfig,
+             split: str = "train", text_col: str = "text") -> tuple[DataLoader, DataLoader]:
+     _section("🤗 HuggingFace dataset")
+     _info("dataset", dataset_name)
+     _info("split", split)
+     _info("streaming", str(cfg.streaming))
+
+     from datasets import load_dataset
+     ds = load_dataset(dataset_name, split=split, streaming=cfg.streaming)
+
+     if cfg.streaming:
+         def _gen():
+             buf = []
+             for row in ds:
+                 buf.extend(tokenizer.encode(row[text_col]))
+                 while len(buf) >= cfg.seq_len + 1:
+                     chunk = buf[:cfg.seq_len + 1]
+                     buf = buf[cfg.seq_len:]
+                     yield (
+                         torch.tensor(chunk[:-1], dtype=torch.long),
+                         torch.tensor(chunk[1:], dtype=torch.long),
+                     )
+         class _HFStream(IterableDataset):
+             def __iter__(self): return _gen()
+         dl = DataLoader(_HFStream(), batch_size=cfg.batch_size, num_workers=cfg.num_workers)
+         _ok("streaming dataloader ready")
+         return dl, dl  # note: train and val share the same stream
+
+     print(f" {C.YELLOW}⏳ downloading...{C.RESET}", flush=True)
+     texts = [row[text_col] for row in ds]
+     _ok(f"downloaded {len(texts):,} documents")
+     tokens = _tokenize(texts, tokenizer)
+     train_ds, val_ds = _split(tokens, cfg)
+     return _loaders(train_ds, val_ds, cfg)
+
+
+ def from_strings(texts: list[str], tokenizer: BaseTokenizer, cfg: DataConfig) -> tuple[DataLoader, DataLoader]:
+     _section("📝 String dataset")
+     _info("documents", str(len(texts)))
+     tokens = _tokenize(texts, tokenizer)
+     train_ds, val_ds = _split(tokens, cfg)
+     return _loaders(train_ds, val_ds, cfg)
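
Putting the loader together with a tokenizer, an end-to-end in-memory pipeline looks roughly like this (a sketch assuming the `tokenizers` extra is installed; train.txt is an illustrative filename):

from transformer_toolkit.c_tokenizers import RustBPETokenizer
from transformer_toolkit.dataloader import DataConfig, from_files

tok = RustBPETokenizer()
tok.train([open("train.txt", encoding="utf-8").read()], vocab_size=8000)

cfg = DataConfig(seq_len=256, batch_size=8)
train_dl, val_dl = from_files(["train.txt"], tok, cfg)

x, y = next(iter(train_dl))  # y is x shifted one token to the left
print(x.shape, y.shape)      # torch.Size([8, 256]) twice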
@@ -0,0 +1,51 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class FFN(nn.Module):
+     """Standard two-layer FFN with GELU, as in BERT/GPT-2 (the original Transformer used ReLU)."""
+     def __init__(self, dim: int, hidden: int):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(dim, hidden),
+             nn.GELU(),
+             nn.Linear(hidden, dim)
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+
+ class SwiGLU(nn.Module):
+     """Gated FFN with Swish/SiLU. Used in LLaMA, Mistral, PaLM."""
+     def __init__(self, dim: int, hidden: int):
+         super().__init__()
+         self.w1 = nn.Linear(dim, hidden, bias=False)  # gate projection
+         self.w2 = nn.Linear(hidden, dim, bias=False)  # down projection
+         self.w3 = nn.Linear(dim, hidden, bias=False)  # up projection
+
+     def forward(self, x):
+         return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+ class MoE(nn.Module):
+     """Mixture of Experts. Used in Mixtral, GPT-4 (rumoured)."""
+     def __init__(self, dim: int, hidden: int, n_experts: int, top_k: int = 2):
+         super().__init__()
+         self.top_k = top_k
+         self.gate = nn.Linear(dim, n_experts, bias=False)
+         self.experts = nn.ModuleList([SwiGLU(dim, hidden) for _ in range(n_experts)])
+
+     def forward(self, x):
+         B, T, C = x.shape
+         x_flat = x.view(-1, C)
+         weights, idx = self.gate(x_flat).topk(self.top_k, dim=-1)  # both [B*T, top_k]
+         weights = F.softmax(weights, dim=-1)  # renormalise over the chosen experts
+         out = torch.zeros_like(x_flat)
+         for i, expert in enumerate(self.experts):
+             mask = (idx == i).any(dim=-1)  # tokens routed to expert i
+             if mask.any():
+                 w = weights[mask][idx[mask] == i].unsqueeze(-1)
+                 out[mask] += w * expert(x_flat[mask])
+         return out.view(B, T, C)
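
The MoE routing is dense in the gate but sparse in execution: every token scores all experts, yet only its top_k experts actually run on it. A small sanity sketch (assumes the module is importable):

import torch
from transformer_toolkit.feed_forward import MoE

moe = MoE(dim=64, hidden=256, n_experts=4, top_k=2)
x = torch.randn(2, 10, 64)
out = moe(x)
assert out.shape == x.shape  # each token's output is a weighted mix of its 2 chosen experts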