bpe-lite 0.3.1 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/spm.js CHANGED
@@ -6,72 +6,194 @@
6
6
  * Merges are ordered list of "token_a token_b" strings — applied by rank (index = priority).
7
7
  */
8
8
 
9
+ const { MinHeap } = require('./bpe');
10
+
9
11
  const SPACE_CHAR = '\u2581'; // ▁
10
12
 
13
+ function buildPreparedSPM(vocabData) {
14
+ const { vocab, merges } = vocabData;
15
+
16
+ const mergeRank = new Map();
17
+ for (let i = 0; i < merges.length; i++) {
18
+ mergeRank.set(merges[i], i);
19
+ }
20
+
21
+ const idToStr = new Map();
22
+ for (const [str, id] of Object.entries(vocab)) {
23
+ idToStr.set(id, str);
24
+ }
25
+
26
+ return {
27
+ vocab,
28
+ merges,
29
+ mergeRank,
30
+ idToStr,
31
+ // opt A — segment-level cache: each ▁-prefixed word segment → ids[]
32
+ // Generalises across different inputs (same words reused across texts).
33
+ // Note: 1 of 514,906 Gemma merges crosses a ▁ boundary ("> ▁</"),
34
+ // making this negligibly imprecise for that HTML pattern.
35
+ cache: new Map(),
36
+ // opt B — per-instance grow-only scratch
37
+ scratch: { str: null, ids: null, prev: null, next: null, ver: null, alive: null, cap: 0, heap: new MinHeap() },
38
+ };
39
+ }
40
+
41
+ // opt B — grow SPM scratch arrays only when needed
42
+ function ensureScratch(scratch, n) {
43
+ if (n <= scratch.cap) return;
44
+ const cap = n * 2;
45
+ scratch.str = new Array(cap);
46
+ scratch.ids = new Int32Array(cap);
47
+ scratch.prev = new Int32Array(cap);
48
+ scratch.next = new Int32Array(cap);
49
+ scratch.ver = new Int32Array(cap);
50
+ scratch.alive = new Uint8Array(cap);
51
+ scratch.cap = cap;
52
+ }
53
+
11
54
  /**
12
55
  * Encode text using SentencePiece BPE.
13
56
  * @param {string} text
14
- * @param {Object} vocabData — { engine, vocab, merges }
57
+ * @param {Object} vocabData — { engine, vocab, merges }
15
58
  * @returns {number[]}
16
59
  */
17
60
  function encodeSPM(text, vocabData) {
18
- if (!text) return [];
19
-
20
- const { vocab, merges } = vocabData;
61
+ return encodeSPMPrepared(text, buildPreparedSPM(vocabData));
62
+ }
21
63
 
22
- // Build merge rank map: "a b" → rank index (lower = higher priority)
23
- const mergeRank = new Map();
24
- for (let i = 0; i < merges.length; i++) {
25
- mergeRank.set(merges[i], i);
64
+ /**
65
+ * Hot path: scan normalized text purely from cache.
66
+ * Returns true if every segment was a cache hit; false on first miss.
67
+ * Isolated from encodeSegment so V8 can keep this function optimised
68
+ * even when encodeSPMPrepared is called with wildly different text lengths.
69
+ */
70
+ function _scanFromCache(normalized, cache, result) {
71
+ let segStart = 0;
72
+ for (let i = 1; i <= normalized.length; i++) {
73
+ if (i === normalized.length || normalized[i] === SPACE_CHAR) {
74
+ const seg = normalized.slice(segStart, i);
75
+ const segIds = cache.get(seg);
76
+ if (segIds === undefined) return false; // cache miss → caller handles cold path
77
+ for (let j = 0; j < segIds.length; j++) result.push(segIds[j]);
78
+ segStart = i;
79
+ }
26
80
  }
81
+ return true;
82
+ }
27
83
 
28
- // Build reverse vocab: id string (for decode)
29
- // vocab: { tokenString: id }
84
+ // Cold-path helper kept separate so it is never inlined into the hot loop.
85
+ function _encodeAndCache(seg, vocab, mergeRank, scratch, cache) {
86
+ const ids = encodeSegment(seg, vocab, mergeRank, scratch);
87
+ cache.set(seg, ids);
88
+ return ids;
89
+ }
90
+
91
+ function encodeSPMPrepared(text, prepared) {
92
+ if (!text) return [];
93
+
94
+ const { vocab, mergeRank, scratch, cache } = prepared;
30
95
 
31
96
  // Normalize: replace spaces with ▁, prepend ▁
32
97
  const normalized = SPACE_CHAR + text.replace(/ /g, SPACE_CHAR);
33
98
 
34
- // Split into individual Unicode characters
35
- const chars = [...normalized];
99
+ // Fast path: serve every segment from the segment cache.
100
+ // After the first call, this path handles all subsequent calls for common text.
101
+ const result = [];
102
+ if (_scanFromCache(normalized, cache, result)) return result;
103
+
104
+ // Cold path: at least one segment is missing — encode everything from scratch.
105
+ // (Simpler to re-scan than to continue from the miss point.)
106
+ result.length = 0;
107
+ let segStart = 0;
108
+ for (let i = 1; i <= normalized.length; i++) {
109
+ if (i === normalized.length || normalized[i] === SPACE_CHAR) {
110
+ const seg = normalized.slice(segStart, i);
111
+ const segIds = cache.get(seg) ?? _encodeAndCache(seg, vocab, mergeRank, scratch, cache);
112
+ for (let j = 0; j < segIds.length; j++) result.push(segIds[j]);
113
+ segStart = i;
114
+ }
115
+ }
116
+ return result;
117
+ }
118
+
119
+ // Encode a single ▁-prefixed segment using MinHeap BPE.
120
+ function encodeSegment(seg, vocab, mergeRank, scratch) {
121
+ const chars = [...seg];
122
+ const n = chars.length;
123
+
124
+ ensureScratch(scratch, n);
125
+ const { str, ids, prev, next, ver, alive, heap } = scratch;
126
+ heap.reset();
127
+
128
+ for (let i = 0; i < n; i++) {
129
+ const c = chars[i];
130
+ prev[i] = i - 1;
131
+ next[i] = i + 1;
132
+ ver[i] = 0;
133
+ alive[i] = 1;
36
134
 
37
- // Map each character to a token (may be a multi-byte char)
38
- let tokens = chars.map(c => {
39
135
  if (vocab[c] !== undefined) {
40
- return { str: c, id: vocab[c] };
41
- }
42
- // Try byte fallback: <0xNN>
43
- const codePoint = c.codePointAt(0);
44
- const hex = codePoint.toString(16).toUpperCase().padStart(2, '0');
45
- const byteKey = `<0x${hex}>`;
46
- if (vocab[byteKey] !== undefined) {
47
- return { str: byteKey, id: vocab[byteKey] };
48
- }
49
- // Unknown — use UNK (id 3 in Gemma) or skip
50
- return { str: c, id: vocab['<unk>'] ?? 0 };
51
- });
52
-
53
- // Greedy BPE merges: find pair with lowest merge rank, merge, repeat
54
- while (tokens.length >= 2) {
55
- let bestRank = Infinity;
56
- let bestIdx = -1;
57
-
58
- for (let i = 0; i < tokens.length - 1; i++) {
59
- const key = `${tokens[i].str} ${tokens[i + 1].str}`;
60
- const rank = mergeRank.get(key);
61
- if (rank !== undefined && rank < bestRank) {
62
- bestRank = rank;
63
- bestIdx = i;
136
+ str[i] = c;
137
+ ids[i] = vocab[c];
138
+ } else {
139
+ const codePoint = c.codePointAt(0);
140
+ const hex = codePoint.toString(16).toUpperCase().padStart(2, '0');
141
+ const byteKey = `<0x${hex}>`;
142
+ if (vocab[byteKey] !== undefined) {
143
+ str[i] = byteKey;
144
+ ids[i] = vocab[byteKey];
145
+ } else {
146
+ str[i] = c;
147
+ ids[i] = vocab['<unk>'] ?? 0;
64
148
  }
65
149
  }
150
+ }
151
+ next[n - 1] = -1;
152
+
153
+ for (let i = 0; i < n - 1; i++) {
154
+ const rank = mergeRank.get(`${str[i]} ${str[i + 1]}`);
155
+ if (rank !== undefined) heap.push(rank, i, i + 1, ver[i], ver[i + 1]);
156
+ }
157
+
158
+ while (heap.size > 0) {
159
+ const top = heap.pop();
160
+ if (!top) break;
161
+ const { left, right, verL, verR } = top;
66
162
 
67
- if (bestIdx === -1) break;
163
+ if (!alive[left] || !alive[right]) continue;
164
+ if (next[left] !== right) continue;
165
+ if (ver[left] !== verL || ver[right] !== verR) continue;
68
166
 
69
- const merged = tokens[bestIdx].str + tokens[bestIdx + 1].str;
70
- const mergedId = vocab[merged] ?? vocab['<unk>'] ?? 0;
71
- tokens.splice(bestIdx, 2, { str: merged, id: mergedId });
167
+ str[left] = str[left] + str[right];
168
+ ids[left] = vocab[str[left]] ?? vocab['<unk>'] ?? 0;
169
+ ver[left]++;
170
+
171
+ alive[right] = 0;
172
+ ver[right]++;
173
+
174
+ const nr = next[right];
175
+ next[left] = nr;
176
+ if (nr !== -1) prev[nr] = left;
177
+
178
+ const pl = prev[left];
179
+ if (pl !== -1 && alive[pl]) {
180
+ const r = mergeRank.get(`${str[pl]} ${str[left]}`);
181
+ if (r !== undefined) heap.push(r, pl, left, ver[pl], ver[left]);
182
+ }
183
+ const nl = next[left];
184
+ if (nl !== -1 && alive[nl]) {
185
+ const r = mergeRank.get(`${str[left]} ${str[nl]}`);
186
+ if (r !== undefined) heap.push(r, left, nl, ver[left], ver[nl]);
187
+ }
72
188
  }
73
189
 
74
- return tokens.map(t => t.id);
190
+ const result = [];
191
+ let i = 0;
192
+ while (i !== -1) {
193
+ if (alive[i]) result.push(ids[i]);
194
+ i = next[i];
195
+ }
196
+ return result;
75
197
  }
76
198
 
77
199
  /**
@@ -81,20 +203,16 @@ function encodeSPM(text, vocabData) {
81
203
  * @returns {string}
82
204
  */
83
205
  function decodeSPM(ids, vocabData) {
84
- if (!ids || ids.length === 0) return '';
85
-
86
- const { vocab } = vocabData;
206
+ return decodeSPMPrepared(ids, buildPreparedSPM(vocabData));
207
+ }
87
208
 
88
- // Build id → string map
89
- const idToStr = new Map();
90
- for (const [str, id] of Object.entries(vocab)) {
91
- idToStr.set(id, str);
92
- }
209
+ function decodeSPMPrepared(ids, prepared) {
210
+ if (!ids || ids.length === 0) return '';
93
211
 
94
212
  let result = '';
95
- for (const id of ids) {
96
- const str = idToStr.get(id) ?? '';
97
- // Handle byte fallbacks like <0x41> → 'A'
213
+ for (let i = 0; i < ids.length; i++) {
214
+ const id = ids[i];
215
+ const str = prepared.idToStr.get(id) ?? '';
98
216
  const byteMatch = str.match(/^<0x([0-9A-Fa-f]{2})>$/);
99
217
  if (byteMatch) {
100
218
  result += String.fromCharCode(parseInt(byteMatch[1], 16));
@@ -107,4 +225,4 @@ function decodeSPM(ids, vocabData) {
107
225
  return result.replace(new RegExp(SPACE_CHAR, 'g'), ' ').replace(/^ /, '');
108
226
  }
109
227
 
110
- module.exports = { encodeSPM, decodeSPM };
228
+ module.exports = { buildPreparedSPM, encodeSPM, decodeSPM, encodeSPMPrepared, decodeSPMPrepared };
package/src/tokenizer.js CHANGED
@@ -1,29 +1,44 @@
1
1
  'use strict';
2
2
 
3
- const { encodeTiktoken, decodeTiktoken, countTiktokenUpTo } = require('./bpe');
4
- const { encodeSPM, decodeSPM } = require('./spm');
3
+ const {
4
+ buildPreparedTiktoken,
5
+ encodeTiktokenPrepared,
6
+ decodeTiktokenPrepared,
7
+ countTiktokenPrepared,
8
+ countTiktokenUpToPrepared,
9
+ } = require('./bpe');
10
+ const { buildPreparedSPM, encodeSPMPrepared, decodeSPMPrepared } = require('./spm');
5
11
 
6
12
  class Tokenizer {
7
13
  constructor(vocabData) {
8
14
  this._data = vocabData;
9
15
  this._engine = vocabData.engine;
16
+ this._preparedTiktoken = null;
17
+ this._preparedSPM = null;
10
18
 
11
19
  if (this._engine !== 'tiktoken' && this._engine !== 'spm') {
12
20
  throw new Error(`Unknown tokenizer engine: ${this._engine}`);
13
21
  }
22
+
23
+ if (this._engine === 'tiktoken') {
24
+ this._preparedTiktoken = buildPreparedTiktoken(vocabData);
25
+ } else {
26
+ this._preparedSPM = buildPreparedSPM(vocabData);
27
+ }
14
28
  }
15
29
 
16
30
  encode(text) {
17
- if (this._engine === 'tiktoken') return encodeTiktoken(text, this._data);
18
- return encodeSPM(text, this._data);
31
+ if (this._engine === 'tiktoken') return encodeTiktokenPrepared(text, this._preparedTiktoken);
32
+ return encodeSPMPrepared(text, this._preparedSPM);
19
33
  }
20
34
 
21
35
  decode(ids) {
22
- if (this._engine === 'tiktoken') return decodeTiktoken(ids, this._data);
23
- return decodeSPM(ids, this._data);
36
+ if (this._engine === 'tiktoken') return decodeTiktokenPrepared(ids, this._preparedTiktoken);
37
+ return decodeSPMPrepared(ids, this._preparedSPM);
24
38
  }
25
39
 
26
40
  count(text) {
41
+ if (this._engine === 'tiktoken') return countTiktokenPrepared(text, this._preparedTiktoken);
27
42
  return this.encode(text).length;
28
43
  }
29
44
 
@@ -35,7 +50,7 @@ class Tokenizer {
35
50
  * @returns {number}
36
51
  */
37
52
  countUpTo(text, limit) {
38
- if (this._engine === 'tiktoken') return countTiktokenUpTo(text, this._data, limit);
53
+ if (this._engine === 'tiktoken') return countTiktokenUpToPrepared(text, this._preparedTiktoken, limit);
39
54
  // SPM encodes the whole text as one unit — no clean early exit, just encode and count
40
55
  return this.encode(text).length;
41
56
  }
Binary file
Binary file