transformer-toolkit 0.1.0 (transformer_toolkit-0.1.0.tar.gz)
This diff shows the contents of transformer-toolkit 0.1.0 as published to its public registry. Every file is new in this release, so each hunk below consists purely of additions.
- transformer_toolkit-0.1.0/PKG-INFO +28 -0
- transformer_toolkit-0.1.0/README.md +0 -0
- transformer_toolkit-0.1.0/pyproject.toml +43 -0
- transformer_toolkit-0.1.0/setup.cfg +4 -0
- transformer_toolkit-0.1.0/transformer_toolkit/__init__.py +16 -0
- transformer_toolkit-0.1.0/transformer_toolkit/attention.py +131 -0
- transformer_toolkit-0.1.0/transformer_toolkit/block.py +30 -0
- transformer_toolkit-0.1.0/transformer_toolkit/c_tokenizers.py +114 -0
- transformer_toolkit-0.1.0/transformer_toolkit/colors.py +23 -0
- transformer_toolkit-0.1.0/transformer_toolkit/dataloader.py +184 -0
- transformer_toolkit-0.1.0/transformer_toolkit/feed_forward.py +51 -0
- transformer_toolkit-0.1.0/transformer_toolkit/hf_hub.py +257 -0
- transformer_toolkit-0.1.0/transformer_toolkit/inference.py +151 -0
- transformer_toolkit-0.1.0/transformer_toolkit/model.py +128 -0
- transformer_toolkit-0.1.0/transformer_toolkit/normalization.py +39 -0
- transformer_toolkit-0.1.0/transformer_toolkit/positional_encodings.py +63 -0
- transformer_toolkit-0.1.0/transformer_toolkit/sentiment.py +299 -0
- transformer_toolkit-0.1.0/transformer_toolkit/trainer.py +368 -0
- transformer_toolkit-0.1.0/transformer_toolkit.egg-info/PKG-INFO +28 -0
- transformer_toolkit-0.1.0/transformer_toolkit.egg-info/SOURCES.txt +21 -0
- transformer_toolkit-0.1.0/transformer_toolkit.egg-info/dependency_links.txt +1 -0
- transformer_toolkit-0.1.0/transformer_toolkit.egg-info/requires.txt +13 -0
- transformer_toolkit-0.1.0/transformer_toolkit.egg-info/top_level.txt +1 -0
transformer_toolkit-0.1.0/PKG-INFO
@@ -0,0 +1,28 @@
+Metadata-Version: 2.4
+Name: transformer-toolkit
+Version: 0.1.0
+Summary: Minimal, modular transformer library for training your own LLM
+Author: Govind Barbade
+Project-URL: Homepage, https://github.com/govindbarbade/transformer-toolkit
+Project-URL: Repository, https://github.com/govindbarbade/transformer-toolkit
+Keywords: transformer,llm,deep learning,nlp,pytorch
+Classifier: Development Status :: 3 - Alpha
+Classifier: Intended Audience :: Developers
+Classifier: Intended Audience :: Science/Research
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+Requires-Dist: torch>=2.0.0
+Requires-Dist: pydantic>=2.0.0
+Provides-Extra: tokenizers
+Requires-Dist: tokenizers>=0.15.0; extra == "tokenizers"
+Provides-Extra: hf
+Requires-Dist: transformers>=4.35.0; extra == "hf"
+Requires-Dist: huggingface_hub>=0.20.0; extra == "hf"
+Requires-Dist: datasets>=2.14.0; extra == "hf"
+Provides-Extra: all
+Requires-Dist: transformer-toolkit[hf,tokenizers]; extra == "all"
transformer_toolkit-0.1.0/README.md: file without changes
transformer_toolkit-0.1.0/pyproject.toml
@@ -0,0 +1,43 @@
+[build-system]
+requires = ["setuptools>=68", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "transformer-toolkit"
+version = "0.1.0"
+description = "Minimal, modular transformer library for training your own LLM"
+readme = "README.md"
+license = { file = "LICENSE" }
+requires-python = ">=3.10"
+authors = [{ name = "Govind Barbade" }]
+keywords = ["transformer", "llm", "deep learning", "nlp", "pytorch"]
+classifiers = [
+    "Development Status :: 3 - Alpha",
+    "Intended Audience :: Developers",
+    "Intended Audience :: Science/Research",
+    "License :: OSI Approved :: MIT License",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+]
+
+dependencies = [
+    "torch>=2.0.0",
+    "pydantic>=2.0.0",
+]
+
+[project.optional-dependencies]
+tokenizers = ["tokenizers>=0.15.0"]
+hf = ["transformers>=4.35.0",
+      "huggingface_hub>=0.20.0",
+      "datasets>=2.14.0"]
+all = ["transformer-toolkit[tokenizers,hf]"]
+
+[project.urls]
+Homepage = "https://github.com/govindbarbade/transformer-toolkit"
+Repository = "https://github.com/govindbarbade/transformer-toolkit"
+
+[tool.setuptools.packages.find]
+where = ["."]
+include = ["transformer_toolkit*"]
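The `all` extra resolves recursively to the package's own `hf` and `tokenizers` extras, so `pip install "transformer-toolkit[all]"` pulls every optional dependency in one step. A small sketch for inspecting the declared requirements with only the standard library, assuming the package is installed:

from importlib.metadata import metadata, requires

md = metadata("transformer-toolkit")
print(md["Version"], md["Requires-Python"])   # 0.1.0 >=3.10

# the Requires-Dist entries, including the extras markers seen in PKG-INFO above
for req in requires("transformer-toolkit"):
    print(req)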
transformer_toolkit-0.1.0/transformer_toolkit/__init__.py
@@ -0,0 +1,16 @@
+from .model import Transformer, TransformerConfig
+from .trainer import Trainer, TrainConfig
+from .dataloader import DataConfig, from_files, from_binary, from_hf, from_strings
+from .c_tokenizers import RustBPETokenizer, HFTokenizer, ByteLevelTokenizer, CharLevelTokenizer
+from .hf_hub import login, push_to_hub, pull_from_hub
+
+__version__ = "0.1.0"
+__author__ = "Govind Barbade"
+
+__all__ = [
+    "Transformer", "TransformerConfig",
+    "Trainer", "TrainConfig",
+    "DataConfig", "from_files", "from_binary", "from_hf", "from_strings",
+    "RustBPETokenizer", "HFTokenizer", "ByteLevelTokenizer", "CharLevelTokenizer",
+    "login", "push_to_hub", "pull_from_hub",
+]
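Everything is re-exported at the top level, so user code only needs the package name. A quick sanity check, assuming the package is installed and all re-exported names resolve (note that `CharLevelTokenizer` is imported here but does not appear among the classes in the `c_tokenizers.py` hunk further down):

import transformer_toolkit as tt

print(tt.__version__)      # "0.1.0"
print(tt.__author__)       # "Govind Barbade"
print(sorted(tt.__all__))  # the full public surface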
transformer_toolkit-0.1.0/transformer_toolkit/attention.py
@@ -0,0 +1,131 @@
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+class MultiHeadAttention(nn.Module):
+    """Classic MHA. Used in original Transformer, BERT, GPT-2."""
+    def __init__(self, dim: int, n_heads: int):
+        super().__init__()
+        self.n_heads = n_heads
+        self.head_dim = dim // n_heads
+        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
+        self.out = nn.Linear(dim, dim, bias=False)
+
+    def forward(self, x, mask=None):
+        B, T, C = x.shape
+        q, k, v = self.qkv(x).split(C, dim=-1)
+        def split(t): return t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+        q, k, v = split(q), split(k), split(v)
+        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
+        if mask is not None: scores = scores.masked_fill(mask == 0, float('-inf'))
+        return self.out((F.softmax(scores, -1) @ v).transpose(1, 2).reshape(B, T, C))
+
+
+class GroupedQueryAttention(nn.Module):
+    """Fewer k/v heads than q heads. Used in LLaMA 3, Mistral."""
+    def __init__(self, dim: int, n_heads: int, n_kv_heads: int):
+        super().__init__()
+        self.n_heads = n_heads
+        self.n_kv_heads = n_kv_heads
+        self.head_dim = dim // n_heads
+        self.q = nn.Linear(dim, dim, bias=False)
+        self.kv = nn.Linear(dim, 2 * n_kv_heads * self.head_dim, bias=False)
+        self.out = nn.Linear(dim, dim, bias=False)
+
+    def forward(self, x, mask=None):
+        B, T, C = x.shape
+        q = self.q(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+        k, v = self.kv(x).split(self.n_kv_heads * self.head_dim, dim=-1)
+        def split_kv(t): return t.view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
+        k, v = split_kv(k), split_kv(v)
+        # repeat k/v to match q heads
+        r = self.n_heads // self.n_kv_heads
+        k = k.repeat_interleave(r, dim=1)
+        v = v.repeat_interleave(r, dim=1)
+        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
+        if mask is not None: scores = scores.masked_fill(mask == 0, float('-inf'))
+        return self.out((F.softmax(scores, -1) @ v).transpose(1, 2).reshape(B, T, C))
+
+
+class MultiQueryAttention(nn.Module):
+    """Single k/v head shared across all q heads. Used in Falcon, early Gemini."""
+    def __init__(self, dim: int, n_heads: int):
+        super().__init__()
+        self.n_heads = n_heads
+        self.head_dim = dim // n_heads
+        self.q = nn.Linear(dim, dim, bias=False)
+        self.k = nn.Linear(dim, self.head_dim, bias=False)
+        self.v = nn.Linear(dim, self.head_dim, bias=False)
+        self.out = nn.Linear(dim, dim, bias=False)
+
+    def forward(self, x, mask=None):
+        B, T, C = x.shape
+        q = self.q(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+        k = self.k(x).view(B, T, 1, self.head_dim).transpose(1, 2).expand(B, self.n_heads, T, self.head_dim)
+        v = self.v(x).view(B, T, 1, self.head_dim).transpose(1, 2).expand(B, self.n_heads, T, self.head_dim)
+        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
+        if mask is not None: scores = scores.masked_fill(mask == 0, float('-inf'))
+        return self.out((F.softmax(scores, -1) @ v).transpose(1, 2).reshape(B, T, C))
+
+class FlashAttention(nn.Module):
+    """
+    Flash Attention — same result as MHA, far less memory.
+    Uses torch's built-in scaled_dot_product_attention (torch >= 2.0)
+    which calls the CUDA kernel automatically when available.
+    """
+    def __init__(self, dim: int, n_heads: int):
+        super().__init__()
+        self.n_heads = n_heads
+        self.head_dim = dim // n_heads
+        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
+        self.out = nn.Linear(dim, dim, bias=False)
+
+    def forward(self, x, mask=None, causal=True):
+        B, T, C = x.shape
+        q, k, v = self.qkv(x).split(C, dim=-1)
+        def split(t): return t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+        q, k, v = split(q), split(k), split(v)
+
+        # one line — PyTorch handles the tiling/chunking kernel
+        x = F.scaled_dot_product_attention(
+            q, k, v,
+            attn_mask=mask,
+            is_causal=causal and mask is None,  # SDPA builds the causal mask internally; it disallows combining is_causal with an explicit attn_mask
+            dropout_p=0.0,
+        )
+        return self.out(x.transpose(1, 2).reshape(B, T, C))
+
+
+class MLAttention(nn.Module):
+    """
+    Multi-Head Latent Attention (DeepSeek-V2/V3).
+    Compresses k/v into a small latent vector instead of caching full k/v.
+    Huge KV cache reduction at inference time.
+    """
+    def __init__(self, dim: int, n_heads: int, latent_dim: int):
+        super().__init__()
+        self.n_heads = n_heads
+        self.head_dim = dim // n_heads
+        self.kv_down = nn.Linear(dim, latent_dim, bias=False)
+        self.k_up = nn.Linear(latent_dim, dim, bias=False)
+        self.v_up = nn.Linear(latent_dim, dim, bias=False)
+
+        self.q = nn.Linear(dim, dim, bias=False)
+        self.out = nn.Linear(dim, dim, bias=False)
+
+    def forward(self, x, mask=None):
+        B, T, C = x.shape
+
+        q = self.q(x)
+
+        # compress → this is what you cache at inference, not full k/v
+        latent = self.kv_down(x)  # [B, T, latent_dim]
+
+        k = self.k_up(latent)
+        v = self.v_up(latent)
+
+        def split(t): return t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+        q, k, v = split(q), split(k), split(v)
+
+        x = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+        return self.out(x.transpose(1, 2).reshape(B, T, C))
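All five variants keep the same [B, T, dim] interface, so they are drop-in replacements for one another. A minimal shape check, assuming the module path implied by the file list above:

import torch
from transformer_toolkit.attention import (
    MultiHeadAttention, GroupedQueryAttention, MultiQueryAttention,
    FlashAttention, MLAttention,
)

B, T, dim = 2, 16, 64
x = torch.randn(B, T, dim)
causal = torch.tril(torch.ones(T, T))  # zeros above the diagonal get masked to -inf

for attn in (
    MultiHeadAttention(dim, n_heads=4),
    GroupedQueryAttention(dim, n_heads=4, n_kv_heads=2),  # 2 kv heads serve 4 q heads
    MultiQueryAttention(dim, n_heads=4),                  # 1 shared kv head
):
    assert attn(x, causal).shape == (B, T, dim)

assert FlashAttention(dim, n_heads=4)(x).shape == (B, T, dim)               # causal=True by default
assert MLAttention(dim, n_heads=4, latent_dim=16)(x).shape == (B, T, dim)   # caches only the 16-dim latent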
transformer_toolkit-0.1.0/transformer_toolkit/block.py
@@ -0,0 +1,30 @@
+from .feed_forward import FFN
+from .normalization import LayerNorm
+from .attention import MultiHeadAttention
+from torch import nn
+
+
+class TransformerBlock(nn.Module):
+    """
+    Default: original 'Attention is All You Need' components — MHA + FFN + LayerNorm,
+    wired pre-norm. Swap any component via attn=, ffn=, norm= for modern variants.
+    """
+    def __init__(
+        self,
+        dim: int,
+        n_heads: int,
+        hidden: int,
+        norm = None,   # default: LayerNorm; note a passed instance is shared by both norm sites
+        attn = None,   # default: MultiHeadAttention (original)
+        ffn = None,    # default: FFN (two-layer, GELU)
+    ):
+        super().__init__()
+        self.norm1 = norm or LayerNorm(dim)
+        self.norm2 = norm or LayerNorm(dim)
+        self.attn = attn or MultiHeadAttention(dim, n_heads)
+        self.ffn = ffn or FFN(dim, hidden)
+
+    def forward(self, x, mask=None):
+        x = x + self.attn(self.norm1(x), mask)
+        x = x + self.ffn(self.norm2(x))
+        return x
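Swapping in the grouped-query attention and SwiGLU defined elsewhere in this package gives a LLaMA-style block without touching the residual wiring. A sketch, assuming the package's own modules:

import torch
from transformer_toolkit.block import TransformerBlock
from transformer_toolkit.attention import GroupedQueryAttention
from transformer_toolkit.feed_forward import SwiGLU

dim, n_heads, hidden = 64, 4, 256
block = TransformerBlock(
    dim, n_heads, hidden,
    attn=GroupedQueryAttention(dim, n_heads, n_kv_heads=2),
    ffn=SwiGLU(dim, hidden),
)  # norm is left at the LayerNorm default

x = torch.randn(2, 16, dim)
assert block(x).shape == x.shape  # residual block preserves shape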
transformer_toolkit-0.1.0/transformer_toolkit/c_tokenizers.py
@@ -0,0 +1,114 @@
+from abc import ABC, abstractmethod
+
+
+class BaseTokenizer(ABC):
+    @abstractmethod
+    def train(self, texts: list[str], vocab_size: int): ...
+
+    @abstractmethod
+    def encode(self, text: str) -> list[int]: ...
+
+    @abstractmethod
+    def decode(self, ids: list[int]) -> str: ...
+
+    @abstractmethod
+    def save(self, path: str): ...
+
+    @abstractmethod
+    def load(self, path: str): ...
+
+    @property
+    @abstractmethod
+    def vocab_size(self) -> int: ...
+
+class ByteLevelTokenizer(BaseTokenizer):
+    """
+    Zero deps. Every byte is a token (0-255).
+    Works on any text/language out of the box.
+    """
+    def train(self, texts, vocab_size=256): pass  # nothing to train
+
+    def encode(self, text: str) -> list[int]:
+        return list(text.encode("utf-8"))
+
+    def decode(self, ids: list[int]) -> str:
+        return bytes(ids).decode("utf-8", errors="replace")
+
+    def save(self, path: str): pass
+    def load(self, path: str): pass
+
+    @property
+    def vocab_size(self) -> int: return 256
+
+
+class HFTokenizer(BaseTokenizer):
+    """
+    Thin wrapper around any HuggingFace tokenizer.
+    pip install transformers
+    """
+    def __init__(self, model_name: str = "gpt2"):
+        from transformers import AutoTokenizer
+        self._tok = AutoTokenizer.from_pretrained(model_name)
+
+    def train(self, texts, vocab_size=None):
+        raise NotImplementedError("use HF's train_new_from_iterator for custom training")
+
+    def encode(self, text: str) -> list[int]:
+        return self._tok.encode(text)
+
+    def decode(self, ids: list[int]) -> str:
+        return self._tok.decode(ids)
+
+    def save(self, path: str):
+        self._tok.save_pretrained(path)
+
+    def load(self, path: str):
+        from transformers import AutoTokenizer
+        self._tok = AutoTokenizer.from_pretrained(path)
+
+    @property
+    def vocab_size(self) -> int:
+        return len(self._tok)
+
+
+class RustBPETokenizer(BaseTokenizer):
+    """
+    BPE tokenizer backed by HuggingFace's `tokenizers` Rust crate.
+    Trains ~100x faster than pure Python BPE.
+    pip install tokenizers
+    """
+    def __init__(self):
+        from tokenizers import Tokenizer
+        from tokenizers.models import BPE
+        self._tok = Tokenizer(BPE(unk_token="[UNK]"))
+        self._trained = False
+
+    def train(self, texts: list[str], vocab_size: int = 8000):
+        from tokenizers.trainers import BpeTrainer
+        from tokenizers.pre_tokenizers import Whitespace
+
+        self._tok.pre_tokenizer = Whitespace()
+        trainer = BpeTrainer(
+            vocab_size=vocab_size,
+            special_tokens=["[UNK]", "[PAD]", "[BOS]", "[EOS]"]
+        )
+        self._tok.train_from_iterator(texts, trainer)
+        self._trained = True
+
+    def encode(self, text: str) -> list[int]:
+        return self._tok.encode(text).ids
+
+    def decode(self, ids: list[int]) -> str:
+        return self._tok.decode(ids)
+
+    def save(self, path: str):
+        self._tok.save(path)
+
+    def load(self, path: str):
+        from tokenizers import Tokenizer
+        self._tok = Tokenizer.from_file(path)
+        self._trained = True
+
+    @property
+    def vocab_size(self) -> int:
+        return self._tok.get_vocab_size()
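All implementations share the BaseTokenizer interface, so they can be swapped under the dataloader without further changes. A quick round-trip, assuming `tokenizers` is installed for the BPE case:

from transformer_toolkit.c_tokenizers import ByteLevelTokenizer, RustBPETokenizer

byte_tok = ByteLevelTokenizer()
ids = byte_tok.encode("héllo")        # UTF-8 bytes, so non-ASCII chars cost extra tokens
assert byte_tok.decode(ids) == "héllo"
assert byte_tok.vocab_size == 256

bpe = RustBPETokenizer()              # pip install tokenizers
bpe.train(["the quick brown fox", "jumps over the lazy dog"], vocab_size=300)
print(bpe.decode(bpe.encode("the lazy fox")), bpe.vocab_size)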
transformer_toolkit-0.1.0/transformer_toolkit/colors.py
@@ -0,0 +1,23 @@
+class C:
+    RESET = "\033[0m"; BOLD = "\033[1m"; DIM = "\033[2m"
+    GREEN = "\033[32m"; CYAN = "\033[36m"; YELLOW = "\033[33m"
+    BLUE = "\033[34m"; RED = "\033[31m"; WHITE = "\033[37m"
+    MAGENTA = "\033[35m"
+
+def _bar(current, total, width=28):
+    filled = int(width * current / max(total, 1))
+    return f"{C.CYAN}{'█' * filled}{'░' * (width - filled)}{C.RESET}"
+
+def _section(title):
+    print(f"\n{C.BOLD}{C.CYAN}{'─' * 52}{C.RESET}")
+    print(f"{C.BOLD}{C.CYAN} {title}{C.RESET}")
+    print(f"{C.BOLD}{C.CYAN}{'─' * 52}{C.RESET}")
+
+def _info(label, value):
+    print(f" {C.DIM}{label:<18}{C.RESET} {C.WHITE}{value}{C.RESET}")
+
+def _ok(msg):
+    print(f" {C.GREEN}✓{C.RESET} {msg}")
+
+def _err(msg):
+    print(f" {C.RED}✗{C.RESET} {msg}")
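These helpers produce the progress bars and status lines printed by the dataloader below. A tiny demo:

from transformer_toolkit.colors import _bar, _section, _info, _ok, _err

_section("demo")
_info("steps", "42/100")
print(f" {_bar(42, 100)}")   # 28-char bar, 42% filled
_ok("checkpoint saved")
_err("checkpoint missing")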
transformer_toolkit-0.1.0/transformer_toolkit/dataloader.py
@@ -0,0 +1,184 @@
+import torch
+import struct
+from pathlib import Path
+from torch.utils.data import Dataset, DataLoader, IterableDataset
+from .c_tokenizers import BaseTokenizer
+from .colors import C, _section, _info, _ok, _bar
+
+
+# ─── config ───────────────────────────────────────────────────────────────────
+
+class DataConfig:
+    def __init__(
+        self,
+        seq_len: int = 512,
+        batch_size: int = 16,
+        shuffle: bool = True,
+        num_workers: int = 0,
+        split: float = 0.9,
+        streaming: bool = False,
+    ):
+        self.seq_len = seq_len
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        self.num_workers = num_workers
+        self.split = split
+        self.streaming = streaming
+
+
+# ─── dataset ──────────────────────────────────────────────────────────────────
+
+class TokenDataset(Dataset):
+    def __init__(self, tokens: torch.Tensor, seq_len: int):
+        self.tokens = tokens
+        self.seq_len = seq_len
+
+    def __len__(self):
+        return max(0, len(self.tokens) - self.seq_len)
+
+    def __getitem__(self, idx):
+        x = self.tokens[idx: idx + self.seq_len]
+        y = self.tokens[idx + 1: idx + self.seq_len + 1]
+        return x, y
+
+
+class StreamingDataset(IterableDataset):
+    def __init__(self, paths: list[str], tokenizer: BaseTokenizer, seq_len: int):
+        self.paths = paths
+        self.tokenizer = tokenizer
+        self.seq_len = seq_len
+
+    def __iter__(self):
+        buf = []
+        for path in self.paths:
+            for line in open(path, encoding="utf-8", errors="replace"):
+                buf.extend(self.tokenizer.encode(line))
+                while len(buf) >= self.seq_len + 1:
+                    chunk = buf[:self.seq_len + 1]
+                    buf = buf[self.seq_len:]
+                    yield (
+                        torch.tensor(chunk[:-1], dtype=torch.long),
+                        torch.tensor(chunk[1:], dtype=torch.long),
+                    )
+
+
+# ─── tokenize + split ─────────────────────────────────────────────────────────
+
+def _tokenize(texts: list[str], tokenizer: BaseTokenizer) -> torch.Tensor:
+    ids = []
+    for i, text in enumerate(texts):
+        print(f"\r {C.DIM}tokenizing{C.RESET} {_bar(i+1, len(texts))} {C.DIM}{i+1}/{len(texts)}{C.RESET}", end="", flush=True)
+        ids.extend(tokenizer.encode(text))
+    print()
+    return torch.tensor(ids, dtype=torch.long)
+
+
+def _split(tokens: torch.Tensor, cfg: DataConfig) -> tuple[TokenDataset, TokenDataset]:
+    n = len(tokens)
+    split_at = int(n * cfg.split)
+    train_ds = TokenDataset(tokens[:split_at], cfg.seq_len)
+    val_ds = TokenDataset(tokens[split_at:], cfg.seq_len)
+
+    if len(val_ds) == 0:
+        print(f" {C.YELLOW}⚠ val set empty — using last 10% of train{C.RESET}")
+        split_at = int(split_at * 0.9)
+        train_ds = TokenDataset(tokens[:split_at], cfg.seq_len)
+        val_ds = TokenDataset(tokens[split_at:], cfg.seq_len)
+
+    return train_ds, val_ds
+
+
+def _loaders(train_ds, val_ds, cfg, shuffle=None) -> tuple[DataLoader, DataLoader]:
+    s = cfg.shuffle if shuffle is None else shuffle
+    train_dl = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=s, num_workers=cfg.num_workers)
+    val_dl = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers)
+    if hasattr(train_ds, "__len__"): print(f" {C.DIM}train{C.RESET} {C.WHITE}{len(train_ds):>10,}{C.RESET} samples {C.DIM}│{C.RESET} {C.YELLOW}{len(train_dl):,}{C.RESET} batches")
+    if hasattr(val_ds, "__len__"): print(f" {C.DIM}val  {C.RESET} {C.WHITE}{len(val_ds):>10,}{C.RESET} samples {C.DIM}│{C.RESET} {C.YELLOW}{len(val_dl):,}{C.RESET} batches\n")
+    return train_dl, val_dl
+
+
+# ─── binary ───────────────────────────────────────────────────────────────────
+
+def save_binary(tokens: list[int], path: str):
+    Path(path).write_bytes(struct.pack(f"{len(tokens)}H", *tokens))  # uint16 per token: assumes ids < 65536
+    print(f" {C.GREEN}✓{C.RESET} saved {C.CYAN}{len(tokens):,}{C.RESET} tokens → {C.DIM}{path}{C.RESET}")
+
+
+def load_binary(path: str) -> torch.Tensor:
+    raw = Path(path).read_bytes()
+    ids = struct.unpack(f"{len(raw)//2}H", raw)
+    tok = torch.tensor(ids, dtype=torch.long)
+    print(f" {C.GREEN}✓{C.RESET} loaded {C.CYAN}{len(tok):,}{C.RESET} tokens ← {C.DIM}{path}{C.RESET}")
+    return tok
+
+
+# ─── public API ───────────────────────────────────────────────────────────────
+
+def from_binary(path: str, cfg: DataConfig) -> tuple[DataLoader, DataLoader]:
+    _section("💾 Binary dataset")
+    _info("path", path)
+    tokens = load_binary(path)
+    train_ds, val_ds = _split(tokens, cfg)
+    return _loaders(train_ds, val_ds, cfg)
+
+
+def from_files(paths: list[str], tokenizer: BaseTokenizer, cfg: DataConfig) -> tuple[DataLoader, DataLoader]:
+    _section("📂 File dataset")
+    for p in paths: _info("file", p)
+
+    if cfg.streaming:
+        _info("mode", "streaming")
+        split_at = max(1, int(len(paths) * cfg.split))
+        train_ds = StreamingDataset(paths[:split_at], tokenizer, cfg.seq_len)
+        val_ds = StreamingDataset(paths[split_at:] or paths[-1:], tokenizer, cfg.seq_len)
+        return _loaders(train_ds, val_ds, cfg, shuffle=False)
+
+    _info("mode", "in-memory")
+    texts = [Path(p).read_text(encoding="utf-8", errors="replace") for p in paths]
+    tokens = _tokenize(texts, tokenizer)
+    train_ds, val_ds = _split(tokens, cfg)
+    return _loaders(train_ds, val_ds, cfg)
+
+
+def from_hf(dataset_name: str, tokenizer: BaseTokenizer, cfg: DataConfig,
+            split: str = "train", text_col: str = "text") -> tuple[DataLoader, DataLoader]:
+    _section("🤗 HuggingFace dataset")
+    _info("dataset", dataset_name)
+    _info("split", split)
+    _info("streaming", str(cfg.streaming))
+
+    from datasets import load_dataset
+    ds = load_dataset(dataset_name, split=split, streaming=cfg.streaming)
+
+    if cfg.streaming:
+        def _gen():
+            buf = []
+            for row in ds:
+                buf.extend(tokenizer.encode(row[text_col]))
+                while len(buf) >= cfg.seq_len + 1:
+                    chunk = buf[:cfg.seq_len + 1]
+                    buf = buf[cfg.seq_len:]
+                    yield (
+                        torch.tensor(chunk[:-1], dtype=torch.long),
+                        torch.tensor(chunk[1:], dtype=torch.long),
+                    )
+        class _HFStream(IterableDataset):
+            def __iter__(self): return _gen()
+        dl = DataLoader(_HFStream(), batch_size=cfg.batch_size, num_workers=cfg.num_workers)
+        _ok("streaming dataloader ready")
+        return dl, dl
+
+    print(f" {C.YELLOW}⏳ downloading...{C.RESET}", flush=True)
+    texts = [row[text_col] for row in ds]
+    _ok(f"downloaded {len(texts):,} documents")
+    tokens = _tokenize(texts, tokenizer)
+    train_ds, val_ds = _split(tokens, cfg)
+    return _loaders(train_ds, val_ds, cfg)
+
+
+def from_strings(texts: list[str], tokenizer: BaseTokenizer, cfg: DataConfig) -> tuple[DataLoader, DataLoader]:
+    _section("📝 String dataset")
+    _info("documents", str(len(texts)))
+    tokens = _tokenize(texts, tokenizer)
+    train_ds, val_ds = _split(tokens, cfg)
+    return _loaders(train_ds, val_ds, cfg)
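Every entry point returns a (train_dl, val_dl) pair of next-token batches. A minimal end-to-end sketch with the zero-dependency byte tokenizer:

import torch
from transformer_toolkit import DataConfig, from_strings, ByteLevelTokenizer

cfg = DataConfig(seq_len=32, batch_size=8)  # 90/10 train/val split by default
texts = ["a tiny corpus for smoke-testing the pipeline. " * 100]
train_dl, val_dl = from_strings(texts, ByteLevelTokenizer(), cfg)

x, y = next(iter(train_dl))
assert x.shape == (8, 32)
assert torch.equal(x[:, 1:], y[:, :-1])     # y is x shifted one token to the left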
transformer_toolkit-0.1.0/transformer_toolkit/feed_forward.py
@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class FFN(nn.Module):
+    """Standard two-layer FFN (GELU here; the original used ReLU). Used in original Transformer, BERT."""
+    def __init__(self, dim: int, hidden: int):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(dim, hidden),
+            nn.GELU(),
+            nn.Linear(hidden, dim)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+class SwiGLU(nn.Module):
+    """Gated FFN with Swish. Used in LLaMA, Mistral, PaLM."""
+    def __init__(self, dim: int, hidden: int):
+        super().__init__()
+        self.w1 = nn.Linear(dim, hidden, bias=False)
+        self.w2 = nn.Linear(hidden, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden, bias=False)
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class MoE(nn.Module):
+    """Mixture of Experts. Used in Mixtral, GPT-4 (rumoured)."""
+    def __init__(self, dim: int, hidden: int, n_experts: int, top_k: int = 2):
+        super().__init__()
+        self.top_k = top_k
+        self.gate = nn.Linear(dim, n_experts, bias=False)
+        self.experts = nn.ModuleList([SwiGLU(dim, hidden) for _ in range(n_experts)])
+
+    def forward(self, x):
+        B, T, C = x.shape
+        x_flat = x.view(-1, C)
+        weights, idx = self.gate(x_flat).topk(self.top_k, dim=-1)  # [B*T, top_k]
+        weights = F.softmax(weights, dim=-1)
+        out = torch.zeros_like(x_flat)
+        for i, expert in enumerate(self.experts):
+            mask = (idx == i).any(dim=-1)
+            if mask.any():
+                w = weights[mask][idx[mask] == i].unsqueeze(-1)
+                out[mask] += w * expert(x_flat[mask])
+        return out.view(B, T, C)
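All three map dim in to dim out, so any of them can be passed as ffn= to TransformerBlock. A shape check, with a note on the MoE routing:

import torch
from transformer_toolkit.feed_forward import FFN, SwiGLU, MoE

x = torch.randn(2, 8, 64)
for ffn in (FFN(64, 256), SwiGLU(64, 256), MoE(64, 256, n_experts=4, top_k=2)):
    assert ffn(x).shape == x.shape

# FFN and SwiGLU always run the full hidden layer; MoE routes each token
# through only top_k of n_experts SwiGLUs, so per-token compute stays
# roughly top_k/n_experts of the all-experts total.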