turbobpe 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- turbobpe-0.1.0/LICENSE +21 -0
- turbobpe-0.1.0/MANIFEST.in +18 -0
- turbobpe-0.1.0/PKG-INFO +22 -0
- turbobpe-0.1.0/README.md +0 -0
- turbobpe-0.1.0/pyproject.toml +53 -0
- turbobpe-0.1.0/setup.cfg +4 -0
- turbobpe-0.1.0/setup.py +14 -0
- turbobpe-0.1.0/src/turbobpe/__init__.py +5 -0
- turbobpe-0.1.0/src/turbobpe/base.py +237 -0
- turbobpe-0.1.0/src/turbobpe/regex.py +185 -0
- turbobpe-0.1.0/src/turbobpe/utils.c +9933 -0
- turbobpe-0.1.0/src/turbobpe.egg-info/PKG-INFO +22 -0
- turbobpe-0.1.0/src/turbobpe.egg-info/SOURCES.txt +14 -0
- turbobpe-0.1.0/src/turbobpe.egg-info/dependency_links.txt +1 -0
- turbobpe-0.1.0/src/turbobpe.egg-info/requires.txt +1 -0
- turbobpe-0.1.0/src/turbobpe.egg-info/top_level.txt +1 -0
turbobpe-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Amrendra Gupta
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
include src/turbobpe/utils.c
|
|
2
|
+
|
|
3
|
+
include README.md
|
|
4
|
+
include LICENSE
|
|
5
|
+
include pyproject.toml
|
|
6
|
+
include setup.py
|
|
7
|
+
|
|
8
|
+
exclude src/turbobpe/utils.pyx
|
|
9
|
+
|
|
10
|
+
# Exclude build artefacts and dev tooling
|
|
11
|
+
exclude src/turbobpe/*.so
|
|
12
|
+
exclude src/turbobpe/*.pyd
|
|
13
|
+
prune build
|
|
14
|
+
prune dist
|
|
15
|
+
prune *.egg-info
|
|
16
|
+
global-exclude __pycache__
|
|
17
|
+
global-exclude *.py[cod]
|
|
18
|
+
global-exclude .DS_Store
|
turbobpe-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: turbobpe
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A fast BPE tokenizer with Cython-accelerated core
|
|
5
|
+
Author: Amrendra Gupta
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/Amrendra-gupta/turbobpe
|
|
8
|
+
Project-URL: Repository, https://github.com/Amrendra-gupta/turbobpe
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
16
|
+
Classifier: Operating System :: OS Independent
|
|
17
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
18
|
+
Requires-Python: >=3.9
|
|
19
|
+
Description-Content-Type: text/markdown
|
|
20
|
+
License-File: LICENSE
|
|
21
|
+
Requires-Dist: regex>=2023.0
|
|
22
|
+
Dynamic: license-file
|
turbobpe-0.1.0/README.md
ADDED
|
File without changes
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
[build-system]
# setuptools 77 is the first release that accepts an SPDX string for
# `project.license` together with `project.license-files` (PEP 639);
# older versions reject the [project] table below.
requires = ["setuptools>=77", "wheel", "packaging>=24.2"]
build-backend = "setuptools.build_meta"

[project]
name = "turbobpe"
version = "0.1.0"
description = "A fast BPE tokenizer with Cython-accelerated core"
readme = "README.md"
license = "MIT"
license-files = ["LICENSE"]
requires-python = ">=3.9"
dependencies = ["regex>=2023.0"]
authors = [{ name = "Amrendra Gupta" }]
classifiers = [
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Programming Language :: Python :: 3.13",
    "Programming Language :: Python :: 3.14",
    "Operating System :: OS Independent",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
]

[project.urls]
Homepage = "https://github.com/Amrendra-gupta/turbobpe"
Repository = "https://github.com/Amrendra-gupta/turbobpe"

[tool.setuptools.packages.find]
where = ["src"]

[tool.setuptools.package-data]
turbobpe = ["*.c"]

[tool.cibuildwheel]
# Build for CPython 3.9–3.14 only (no PyPy); skip 32-bit Windows and
# 32-bit (i686) musllinux.
build = "cp39-* cp310-* cp311-* cp312-* cp313-* cp314-*"
skip = "*-win32 *-musllinux_i686 pp*"

# Import the compiled extension explicitly: the package silently falls back
# to pure Python, so importing only the public classes would pass even if the
# extension failed to build.
test-command = "python -c \"import turbobpe.utils; from turbobpe import Tokenizer, RegexTokenizer; print('OK')\""

[tool.cibuildwheel.linux]
manylinux-x86_64-image = "manylinux2014"
manylinux-aarch64-image = "manylinux2014"
archs = ["x86_64", "aarch64"]

[tool.cibuildwheel.macos]
archs = ["x86_64", "arm64"]

[tool.cibuildwheel.windows]
archs = ["AMD64"]
|
turbobpe-0.1.0/setup.cfg
ADDED
turbobpe-0.1.0/setup.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Contains the base Tokenizer class and a few common helper functions.
|
|
3
|
+
The base class also contains the (common) save/load functionality.
|
|
4
|
+
It would be possible to be a lot more strict about the interface and
|
|
5
|
+
e.g. isolating all regex/pattern parts to the RegexTokenizer, but
|
|
6
|
+
some concessions are made for simplicity.
|
|
7
|
+
"""
|
|
8
|
+
import unicodedata
|
|
9
|
+
|
|
10
|
+
# -----------------------------------------------------------------------------
|
|
11
|
+
# a few helper functions used by the tokenizer implementations (e.g. RegexTokenizer)
|
|
12
|
+
|
|
13
|
+
def get_stats(ids, counts=None):
    """
    Count how often each pair of neighbouring integers occurs in ``ids``.

    Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}

    When ``counts`` is supplied it is updated in place (counts are added to
    what is already there) and returned; otherwise a new dict is created.
    """
    if counts is None:
        counts = {}
    for pos in range(len(ids) - 1):
        key = (ids[pos], ids[pos + 1])
        counts[key] = counts.get(key, 0) + 1
    return counts
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def merge(ids, pair, idx):
    """
    Return a new list in which every occurrence of ``pair`` in ``ids`` is
    replaced by the single token ``idx``. Occurrences are matched left to
    right and never overlap.

    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    out = []
    last = len(ids) - 1
    pos = 0
    while pos <= last:
        # a pair can only start before the final element
        if ids[pos] == pair[0] and pos < last and ids[pos + 1] == pair[1]:
            out.append(idx)
            pos += 2
            continue
        out.append(ids[pos])
        pos += 1
    return out
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def find_overlapping_cases(coordinates):
    """
    Given a list of (A, B) pairs, return every pair that takes part in a chain
    overlap.

    A chain overlap occurs when the right token of one pair is the left token
    of another: (A, B) and (B, C). Merging both in the same batch would be
    unsafe, because applying the first merge changes the occurrences of the
    second. A pair whose two tokens are equal, e.g. (A, A), chains with itself.

    Each overlapping pair is reported once, in the order it is first found.

    Example: [(1, 2), (2, 3), (4, 5)] -> [(1, 2), (2, 3)]  # (4, 5) is safe, no chain
    """
    # group right tokens by their left token: A -> [B, ...]
    successors = {}
    for left, right in coordinates:
        successors.setdefault(left, []).append(right)

    # a dict is used as an insertion-ordered set, so a pair that chains with
    # several others is still reported only once
    overlapping = {}
    for a, b_list in successors.items():
        for b in b_list:
            # (a, b) chains with every (b, c) present in the input
            for c in successors.get(b, ()):
                overlapping[(a, b)] = None
                overlapping[(b, c)] = None
    return list(overlapping)
|
|
70
|
+
|
|
71
|
+
def filter_top_pairs(top_pairs, overlaps):
    """
    Select a prefix of ``top_pairs`` that is safe to merge as one batch.

    ``top_pairs`` is expected to be sorted by descending frequency, and
    ``overlaps`` to be the chained pairs reported by
    ``find_overlapping_cases(top_pairs)``.

    Pairs are kept in order until the second pair that appears in
    ``overlaps`` is reached; that pair and every pair after it are dropped.
    The result is therefore a frequency-ordered prefix of ``top_pairs`` that
    contains at most one overlapping pair. Since both ends of a chain are
    listed in ``overlaps``, no two kept pairs can chain.

    Examples:
        top_pairs=[(1,2),(2,3),(4,5)], overlaps=[(1,2),(2,3)] -> [(1,2)]
        top_pairs=[(1,2),(4,5),(2,3)], overlaps=[(1,2),(2,3)] -> [(1,2),(4,5)]
    """
    overlap_set = set(overlaps)  # O(1) membership instead of scanning a list
    result = []
    seen_overlap = False
    for pair in top_pairs:
        if pair in overlap_set:
            if seen_overlap:
                # a second overlapping pair could chain with the first: stop here
                break
            seen_overlap = True
        result.append(pair)
    return result
|
|
89
|
+
|
|
90
|
+
def batch_merge(ids, pairs, indices):
    """
    Apply a batch of pair merges to a single ID sequence, one pair at a time.

    Each ``pairs[i]`` is replaced by ``indices[i]`` wherever it occurs (left to
    right, non-overlapping). Each merge sees the result of the previous ones.
    The input ``ids`` is not modified.

    Returns the merged list, or ``None`` as soon as the sequence has shrunk to
    a single token (no pair can be formed any more, so callers may drop the
    chunk). If ``pairs`` is empty, a copy of ``ids`` is returned unchanged.
    """
    newids = list(ids)
    for i, (p0, p1) in enumerate(pairs):
        idx = indices[i]
        # Rebuild the list in one linear pass. Deleting in place would shift
        # the tail of the list on every single merge.
        merged = []
        n = len(newids)
        j = 0
        while j < n:
            if j + 1 < n and newids[j] == p0 and newids[j + 1] == p1:
                merged.append(idx)
                j += 2
            else:
                merged.append(newids[j])
                j += 1
        newids = merged
        if len(newids) == 1:  # nothing left that could be merged
            return None
    return newids
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
# helpers for rendering tokens in a human-readable form (used by save())
|
|
116
|
+
def replace_control_characters(s: str) -> str:
    """
    Return ``s`` with every character of Unicode general category "C*"
    (control, format, unassigned, ...) replaced by a ``\\uXXXX`` escape.
    This keeps printed output from being distorted by characters such as
    newlines.

    References:
    https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117
    http://www.unicode.org/reports/tr44/#GC_Values_Table
    """
    def _printable(ch):
        if unicodedata.category(ch).startswith("C"):
            return f"\\u{ord(ch):04x}"  # escape
        return ch  # this character is ok

    return "".join(_printable(ch) for ch in s)
|
|
128
|
+
|
|
129
|
+
def render_token(t: bytes) -> str:
    """
    Pretty-print a token for humans: decode it as UTF-8 (invalid byte
    sequences become U+FFFD) and escape any control characters.
    """
    return replace_control_characters(t.decode('utf-8', errors='replace'))
|
|
134
|
+
|
|
135
|
+
# -----------------------------------------------------------------------------
|
|
136
|
+
# the base Tokenizer class
|
|
137
|
+
|
|
138
|
+
class Tokenizer:
    """Base class for Tokenizers.

    Holds the merges, split pattern and special tokens shared by concrete
    tokenizers, and implements saving/loading of the model. Subclasses
    provide train(), encode() and decode().
    """

    def __init__(self):
        # default: vocab size of 256 (all bytes), no merges, no patterns
        self.merges = {}  # (int, int) -> int
        self.pattern = ""  # str
        self.special_tokens = {}  # str -> int, e.g. {'<|endoftext|>': 100257}
        self.vocab = self._build_vocab()  # int -> bytes

    def train(self, text, vocab_size, verbose=False):
        # Tokenizer can train a vocabulary of size vocab_size from text
        raise NotImplementedError

    def encode(self, text):
        # Tokenizer can encode a string into a list of integers
        raise NotImplementedError

    def decode(self, ids):
        # Tokenizer can decode a list of integers into a string
        raise NotImplementedError

    def _build_vocab(self):
        # The vocab is derived deterministically: the 256 raw bytes, then each
        # merge as the concatenation of its two parts (so merges must iterate
        # in an order where both parts are already defined), then the special
        # tokens as their UTF-8 bytes.
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def save(self, file_prefix):
        """
        Saves two files: file_prefix.vocab and file_prefix.model
        This is inspired (but not equivalent to!) sentencepiece's model saving:
        - model file is the critical one, intended for load()
        - vocab file is just a pretty printed version for human inspection only
        Both files are written as UTF-8, the encoding load() reads.
        """
        # write the model: to be used in load() later
        model_file = file_prefix + ".model"
        # Use an explicit encoding: the platform default (e.g. cp1252 on Windows)
        # can fail on, or mis-encode, non-ASCII patterns or special tokens.
        with open(model_file, "w", encoding="utf-8") as f:
            # write the version, pattern and merges, that's all that's needed
            f.write("minbpe v1\n")
            f.write(f"{self.pattern}\n")
            # write the special tokens, first the number of them, then each one
            f.write(f"{len(self.special_tokens)}\n")
            for special, idx in self.special_tokens.items():
                f.write(f"{special} {idx}\n")
            # The merges are written in dict order. load() re-assigns ids
            # 256, 257, ... in file order, so insertion order must match id order.
            for idx1, idx2 in self.merges:
                f.write(f"{idx1} {idx2}\n")
        # write the vocab: for the human to look at
        vocab_file = file_prefix + ".vocab"
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, "w", encoding="utf-8") as f:
            for idx, token in self.vocab.items():
                # Many tokens may be partial UTF-8 sequences that cannot be
                # decoded into valid strings. render_token uses errors='replace'
                # (replacement char U+FFFD), which is lossy, so the .vocab file
                # can never be used by load().
                s = render_token(token)
                if idx in inverted_merges:
                    # this token has children: render it as a merge
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
                else:
                    # Otherwise it is a leaf token (a raw byte or a special
                    # token), so just print it.
                    f.write(f"[{s}] {idx}\n")

    def load(self, model_file):
        """
        Inverse of save() but only for the model file.

        Reads the pattern, special tokens and merges written by save(). Merge
        ids are re-assigned sequentially from 256 in file order. Blank lines
        in the merges section are ignored. Raises AssertionError if the path
        does not end in ".model" or the version header is not "minbpe v1".
        """
        assert model_file.endswith(".model")
        merges = {}
        special_tokens = {}
        idx = 256
        with open(model_file, 'r', encoding="utf-8") as f:
            # read the version
            version = f.readline().strip()
            assert version == "minbpe v1"
            # Read the pattern, removing only the line terminator: a pattern may
            # legitimately start or end with whitespace, which strip() would eat.
            self.pattern = f.readline().rstrip("\r\n")
            # read the special tokens
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                # save() writes "<token> <id>". Splitting on the last space
                # keeps special tokens that themselves contain spaces intact.
                special, special_idx = f.readline().rstrip("\r\n").rsplit(" ", 1)
                special_tokens[special] = int(special_idx)
            # read the merges
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines, e.g. an extra trailing newline
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1
        self.merges = merges
        self.special_tokens = special_tokens
        self.vocab = self._build_vocab()
|
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Minimal (byte-level) Byte Pair Encoding tokenizer.
|
|
3
|
+
|
|
4
|
+
Algorithmically follows along the GPT tokenizer:
|
|
5
|
+
https://github.com/openai/gpt-2/blob/master/src/encoder.py
|
|
6
|
+
|
|
7
|
+
Unlike BasicTokenizer:
|
|
8
|
+
- RegexTokenizer handles an optional regex splitting pattern.
|
|
9
|
+
- RegexTokenizer handles optional special tokens.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import regex as re
|
|
13
|
+
from .base import Tokenizer, find_overlapping_cases, filter_top_pairs
|
|
14
|
+
try:
    # compiled (Cython) implementations of the hot training/encoding helpers
    from turbobpe.utils import get_stats, merge, batch_merge
    ACCELERATED = True
except ImportError:
    # pure-Python fallbacks, used interchangeably with the compiled helpers below
    import warnings

    from .base import get_stats, merge, batch_merge

    ACCELERATED = False
    # A library should not write to stdout on import. A warning can be
    # filtered or escalated by the application instead.
    warnings.warn(
        "turbobpe: C-accelerated backend (turbobpe.utils) not found; using the "
        "pure-Python fallback. Performance may be degraded.",
        RuntimeWarning,
    )
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# The GPT-4 text split pattern, assembled from its alternatives (tried in
# order). See https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py
GPT4_SPLIT_PATTERN = "|".join([
    # English contractions: 's 'd 'm 't 'll 've 're (case-insensitive)
    r"""'(?i:[sdmt]|ll|ve|re)""",
    # a run of letters, optionally preceded by one char that is not a
    # letter, digit or newline (e.g. a leading space)
    r"""[^\r\n\p{L}\p{N}]?+\p{L}+""",
    # one to three digits
    r"""\p{N}{1,3}""",
    # a run of other symbols/punctuation, optional leading space, plus any
    # trailing newlines
    r""" ?[^\s\p{L}\p{N}]++[\r\n]*""",
    # whitespace ending in a newline
    r"""\s*[\r\n]""",
    # whitespace not followed by a non-space char (backtracking leaves the
    # last space for the next token)
    r"""\s+(?!\S)""",
    # any remaining whitespace
    r"""\s+""",
])
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class RegexTokenizer(Tokenizer):
    """Byte-level BPE tokenizer that splits text with a regex before merging
    and supports optional special tokens."""

    def __init__(self, pattern=None):
        """
        - pattern: optional string to override the default (GPT-4 split pattern)
        Special tokens are added afterwards with register_special_tokens(),
        e.g. {'<|endoftext|>': 100257}.
        """
        super().__init__()
        # honour a caller-supplied pattern (it used to be silently ignored)
        self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
        self.compiled_pattern = re.compile(self.pattern)
        self.special_tokens = {}
        self.inverse_special_tokens = {}

    def train(self, text, vocab_size, batch_size=10, verbose=False):
        """
        Learn up to vocab_size - 256 merges from text.

        Each pass works in four steps:
        - count adjacent pairs over all regex chunks;
        - take the batch_size most frequent pairs (None means all of them);
        - trim the batch so that no two kept pairs chain (see
          filter_top_pairs);
        - merge all kept pairs before recounting.

        Training stops early, with fewer merges, once no adjacent pairs are
        left. Raises ValueError if batch_size is below 1.

        Note: batch_size comes before verbose, so pass verbose by keyword.
        """
        assert vocab_size >= 256
        if batch_size is not None and batch_size < 1:
            # an empty batch would never advance training: fail instead of hanging
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        num_merges = vocab_size - 256

        # split the text up into text chunks
        text_chunks = re.findall(self.compiled_pattern, text)

        # input text preprocessing
        ids = [list(ch.encode("utf-8")) for ch in text_chunks]

        # iteratively merge the most common pairs to create new tokens
        merges = {}  # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)}  # idx -> bytes
        s = 256  # next free token id
        while s - 256 < num_merges:
            # count the number of times every consecutive pair appears
            stats = {}
            for chunk_ids in ids:
                # passing in stats will update it in place, adding up counts
                get_stats(chunk_ids, stats)
            if not stats:
                # Every chunk has collapsed to a single token (or the text was
                # empty), so nothing is left to merge. Without this check the
                # batch would be empty, s would never advance, and the loop
                # would spin forever.
                break
            # pairs we attempt to merge at once (e.g., 10) with highest occurrence
            ranked = sorted(stats.items(), key=lambda x: x[1], reverse=True)[:batch_size]
            top_pairs = [pair for pair, _count in ranked]
            # filter out structural chain overlaps
            overlaps = find_overlapping_cases(top_pairs)
            if overlaps:
                top_pairs = filter_top_pairs(top_pairs, overlaps)

            # truncate so we never overshoot the exact num_merges limit
            remaining_slots = num_merges - (s - 256)
            top_pairs = top_pairs[:remaining_slots]

            # mint new tokens: assign them the next available ids
            idx = list(range(s, s + len(top_pairs)))
            s += len(idx)
            # Replace all occurrences of each pair. Chunks that collapse to one
            # token (batch_merge returns None) cannot merge further and are dropped.
            ids = [res for chunk_ids in ids if (res := batch_merge(chunk_ids, top_pairs, idx)) is not None]

            # save the merges and print details
            for pair, new_idx in zip(top_pairs, idx):
                merges[pair] = new_idx
                vocab[new_idx] = vocab[pair[0]] + vocab[pair[1]]
                if verbose:
                    print(f"merge {new_idx-256}: {pair} -> {new_idx} ({vocab[new_idx]}) had {stats[pair]} occurrences")

        # save class variables
        self.merges = merges  # used in encode()
        self.vocab = vocab  # used in decode()

    def register_special_tokens(self, special_tokens):
        # special_tokens is a dictionary of str -> int
        # example: {"<|endoftext|>": 100257}
        self.special_tokens = special_tokens
        self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}

    def load(self, model_file):
        """
        Load a model written by save(), then bring the regex-specific state
        in line with it. This means recompiling the split pattern read from
        the file and rebuilding the id -> special-token lookup used by
        decode(). A model with an empty pattern keeps the default GPT-4 split
        pattern.
        """
        super().load(model_file)
        if not self.pattern:
            # an empty pattern would make findall() yield only empty matches
            self.pattern = GPT4_SPLIT_PATTERN
        self.compiled_pattern = re.compile(self.pattern)
        self.inverse_special_tokens = {v: k for k, v in self.special_tokens.items()}

    def decode(self, ids):
        """
        Given ids (list of integers), return a Python string.

        Ids are looked up in the vocab first, then among the registered special
        tokens. Any other id raises ValueError. Bytes that are not valid UTF-8
        decode to U+FFFD replacement characters.
        """
        part_bytes = []
        for idx in ids:
            if idx in self.vocab:
                part_bytes.append(self.vocab[idx])
            elif idx in self.inverse_special_tokens:
                part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
            else:
                raise ValueError(f"invalid token id: {idx}")
        text_bytes = b"".join(part_bytes)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    def _encode_chunk(self, text_bytes):
        # return the token ids
        # let's begin. first, convert all bytes to integers in range 0..255
        ids = list(text_bytes)
        while len(ids) >= 2:
            # find the pair with the lowest merge index
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # Subtle: if no merges are available, the key is inf for every
            # pair and min() just returns the first pair, arbitrarily. We
            # detect this terminating case with a membership check.
            if pair not in self.merges:
                break  # nothing else can be merged anymore
            # otherwise let's merge the best pair (lowest merge index)
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids

    def encode_ordinary(self, text):
        """Encoding that ignores any special tokens."""
        # split text into chunks of text by categories defined in regex pattern
        text_chunks = re.findall(self.compiled_pattern, text)
        # all chunks of text are encoded separately, then results are joined
        ids = []
        for chunk in text_chunks:
            chunk_bytes = chunk.encode("utf-8")  # raw bytes
            chunk_ids = self._encode_chunk(chunk_bytes)
            ids.extend(chunk_ids)
        return ids

    def encode(self, text, allowed_special="none_raise"):
        """
        Unlike encode_ordinary, this function handles special tokens.
        allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens
        if none_raise, then an error is raised if any special token is encountered in text
        this is the default tiktoken behavior right now as well
        any other behavior is either annoying, or a major footgun
        """
        # decode the user desire w.r.t. handling of special tokens
        special = None
        if allowed_special == "all":
            special = self.special_tokens
        elif allowed_special == "none":
            special = {}
        elif allowed_special == "none_raise":
            special = {}
            # NOTE(review): this check is an assert, so it is skipped under python -O
            assert all(token not in text for token in self.special_tokens)
        elif isinstance(allowed_special, set):
            special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
        else:
            raise ValueError(f"allowed_special={allowed_special} not understood")
        if not special:
            # shortcut: if no special tokens, just use the ordinary encoding
            return self.encode_ordinary(text)
        # Otherwise, split the text on exact matches of any special token.
        # Wrapping the pattern in () makes it a capturing group, so re.split
        # keeps the special tokens in its output. Longest tokens come first, so
        # a special token that is a prefix of another cannot shadow it.
        ordered = sorted(special, key=len, reverse=True)
        special_pattern = "(" + "|".join(re.escape(k) for k in ordered) + ")"
        special_chunks = re.split(special_pattern, text)
        # all chunks of text are encoded separately, then results are joined
        ids = []
        for part in special_chunks:
            if part in special:
                # this is a special token, encode it separately as a special case
                ids.append(special[part])
            else:
                # this is an ordinary sequence, encode it normally
                ids.extend(self.encode_ordinary(part))
        return ids
|