mambacode.js 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +196 -0
- package/package.json +54 -0
- package/src/index.js +89 -0
- package/src/kernels/activations.js +88 -0
- package/src/kernels/conv1d.js +153 -0
- package/src/kernels/linear_projection.js +220 -0
- package/src/kernels/selective_scan.js +350 -0
- package/src/kernels/weight_update.js +120 -0
- package/src/model/mamba_block.js +443 -0
- package/src/model/mamba_model.js +335 -0
- package/src/tokenizer/bpe.js +256 -0
- package/src/training/autograd.js +221 -0
- package/src/training/trainer.js +394 -0
- package/src/utils/gpu_utils.js +217 -0
- package/src/utils/quantization.js +215 -0
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* mamba_model.js – Full Mamba language model.
|
|
3
|
+
*
|
|
4
|
+
* Architecture (matches Qwen3.5-Coder-0.8B-style Mamba):
|
|
5
|
+
*
|
|
6
|
+
* Token IDs ──► Embedding ──► [MambaBlock × numLayers] ──► RMSNorm ──► LM Head
|
|
7
|
+
*
|
|
8
|
+
* The LM Head is a linear projection from dModel → vocabSize.
|
|
9
|
+
* All computations run on WebGPU via the kernels in src/kernels/.
|
|
10
|
+
*/
|
|
11
|
+
|
|
12
|
+
import { MambaBlock } from './mamba_block.js';
|
|
13
|
+
import {
|
|
14
|
+
createStorageBuffer,
|
|
15
|
+
createEmptyStorageBuffer,
|
|
16
|
+
createUniformBuffer,
|
|
17
|
+
createComputePipeline,
|
|
18
|
+
createBindGroup,
|
|
19
|
+
dispatchKernel,
|
|
20
|
+
readBuffer,
|
|
21
|
+
cdiv,
|
|
22
|
+
} from '../utils/gpu_utils.js';
|
|
23
|
+
import { LINEAR_FORWARD_WGSL } from '../kernels/linear_projection.js';
|
|
24
|
+
import { ACTIVATIONS_WGSL } from '../kernels/activations.js';
|
|
25
|
+
|
|
26
|
+
/**
 * @typedef {Object} MambaModelConfig
 * @property {number} vocabSize – vocabulary size (Qwen3.5-Coder: 151936)
 * @property {number} dModel – model (embedding) dimension
 * @property {number} numLayers – number of Mamba blocks
 * @property {number} [dState] – SSM state dimension (default 16)
 * @property {number} [dConv] – conv kernel size (default 4)
 * @property {number} [expand] – inner-dim expansion factor (default 2)
 * @property {number} [eosId] – token id that ends `generate()` early; when
 *   omitted, generation always runs for `maxNewTokens` steps
 */

export class MambaModel {
  /**
   * Build a randomly initialised model and compile its GPU pipelines.
   *
   * @param {GPUDevice} device
   * @param {MambaModelConfig} config – `dState`, `dConv` and `expand` fall
   *   back to 16, 4 and 2 when not supplied.
   */
  constructor(device, config) {
    this.device = device;
    this.config = {
      dState : 16,
      dConv  : 4,
      expand : 2,
      ...config,
    };

    const { vocabSize, dModel, numLayers } = this.config;

    // Token embedding table: (vocabSize, dModel), row-major.
    // Each entry is drawn from a normal distribution with standard deviation
    // 1/sqrt(dModel) via the Box–Muller transform; the 1e-12 keeps log()
    // finite when Math.random() returns exactly 0.
    const embedData = new Float32Array(vocabSize * dModel);
    const std = 1.0 / Math.sqrt(dModel);
    for (let i = 0; i < embedData.length; i++) {
      const u1 = Math.random();
      const u2 = Math.random();
      embedData[i] = std * Math.sqrt(-2 * Math.log(u1 + 1e-12)) *
                     Math.cos(2 * Math.PI * u2);
    }
    this.gpuEmbedding = createStorageBuffer(device, embedData, true);

    // Stacked Mamba blocks
    this.blocks = Array.from({ length: numLayers }, () =>
      new MambaBlock(device, {
        dModel,
        dState : this.config.dState,
        dConv  : this.config.dConv,
        expand : this.config.expand,
      })
    );

    // Final RMSNorm weight, initialised to ones.
    const finalNormW = new Float32Array(dModel).fill(1.0);
    this.gpuFinalNorm = createStorageBuffer(device, finalNormW, true);

    // LM head: (vocabSize, dModel). The embedding table is reused as the
    // head weight (weight tying saves memory). NOTE: nothing in this class
    // assigns `this.gpuLMHeadWeight`, so it must be provided before setting
    // `tiedEmbedding = false`.
    this.tiedEmbedding = true;

    // Compile pipelines
    this._lmHeadPipeline  = createComputePipeline(device, LINEAR_FORWARD_WGSL, 'linear_forward');
    this._rmsnormPipeline = createComputePipeline(device, ACTIVATIONS_WGSL, 'rmsnorm_forward');

    // LM head bias (zeroed)
    this.gpuLMHeadBias = createStorageBuffer(device, new Float32Array(vocabSize), true);

    // Embedding lookup pipeline (gather rows)
    this._embedPipeline = createComputePipeline(device, EMBED_LOOKUP_WGSL, 'embed_lookup');
  }

  // ─── Embedding lookup ─────────────────────────────────────────────────────

  /**
   * Look up token embeddings.
   *
   * @param {number[]|Int32Array|Uint32Array} tokenIds – (batch * seqLen,);
   *   only the first batch * seqLen entries are used
   * @param {number} batch
   * @param {number} seqLen
   * @returns {GPUBuffer} – (batch * seqLen, dModel); owned by the caller
   * @throws {RangeError} if fewer than batch * seqLen ids are supplied or an
   *   id is not in [0, vocabSize)
   */
  embedTokens(tokenIds, batch, seqLen) {
    const { dModel, vocabSize } = this.config;
    const M = batch * seqLen;

    const ids = tokenIds instanceof Uint32Array ? tokenIds : new Uint32Array(tokenIds);

    // Validate on the CPU so bad input fails loudly instead of silently
    // producing wrong embedding rows on the GPU. (Negative Int32 ids wrap to
    // large u32 values and are caught by the range check.)
    if (ids.length < M) {
      throw new RangeError(`embedTokens: expected at least ${M} token ids, got ${ids.length}`);
    }
    for (let i = 0; i < M; i++) {
      if (ids[i] >= vocabSize) {
        throw new RangeError(
          `embedTokens: token id ${ids[i]} at position ${i} is outside [0, ${vocabSize})`);
      }
    }

    const idsBuf = createStorageBuffer(this.device, ids, false);
    const outBuf = createEmptyStorageBuffer(this.device, M * dModel * 4, true);

    const params = new Uint32Array([M, dModel]).buffer;
    const pBuf = createUniformBuffer(this.device, params);

    const bg = createBindGroup(this.device, this._embedPipeline,
      [pBuf, idsBuf, this.gpuEmbedding, outBuf]);
    dispatchKernel(this.device, this._embedPipeline, bg, [cdiv(M, 64), 1, 1]);

    idsBuf.destroy();
    pBuf.destroy();
    return outBuf;
  }

  // ─── Forward pass ─────────────────────────────────────────────────────────

  /**
   * Full model forward pass.
   *
   * Every intermediate GPU buffer created here is released before returning,
   * except `gpuLogits`, which the caller owns and must destroy.
   *
   * @param {number[]|Int32Array|Uint32Array} tokenIds – (batch * seqLen,) flat
   * @param {number} batch
   * @param {number} seqLen
   * @returns {Promise<{ logits: Float32Array, gpuLogits: GPUBuffer, caches: Array }>}
   *   logits    – CPU Float32Array of shape (batch * seqLen, vocabSize)
   *   gpuLogits – GPU buffer (same data, for chained backward)
   *   caches    – one entry per block, as returned by MambaBlock.forward
   * Rejects with a RangeError for invalid token ids (see embedTokens).
   */
  async forward(tokenIds, batch, seqLen) {
    const { dModel, vocabSize } = this.config;
    const M = batch * seqLen;

    // 1. Token embedding lookup → (M, dModel)
    let hidden = this.embedTokens(tokenIds, batch, seqLen);

    // 2. Mamba blocks; each block's input is released once its output exists.
    const caches = [];
    for (const block of this.blocks) {
      const { output, cache } = block.forward(hidden, batch, seqLen);
      caches.push(cache);
      hidden.destroy();
      hidden = output;
    }

    // 3. Final RMSNorm (eps = 1e-6)
    const normOut = createEmptyStorageBuffer(this.device, M * dModel * 4, true);
    const normInv = createEmptyStorageBuffer(this.device, M * 4, false);
    {
      const params = new ArrayBuffer(16);
      new Uint32Array(params, 0, 2).set([M, dModel]);
      new Float32Array(params, 8, 1).set([1e-6]);
      const pBuf = createUniformBuffer(this.device, params);
      const bg = createBindGroup(this.device, this._rmsnormPipeline,
        [pBuf, hidden, this.gpuFinalNorm, normOut, normInv]);
      dispatchKernel(this.device, this._rmsnormPipeline, bg, [cdiv(M, 64), 1, 1]);
      pBuf.destroy();
    }
    // The last block's output is only consumed by the RMSNorm dispatch above,
    // mirroring how the loop releases every earlier block's output.
    hidden.destroy();

    // 4. LM Head: (M, vocabSize) = normOut @ embedding^T + bias
    const gpuLogits = createEmptyStorageBuffer(this.device, M * vocabSize * 4, true);
    {
      const params = new Uint32Array([M, dModel, vocabSize]).buffer;
      const pBuf = createUniformBuffer(this.device, params);
      const weightBuf = this.tiedEmbedding ? this.gpuEmbedding : this.gpuLMHeadWeight;
      const bg = createBindGroup(this.device, this._lmHeadPipeline,
        [pBuf, normOut, weightBuf, this.gpuLMHeadBias, gpuLogits]);
      dispatchKernel(this.device, this._lmHeadPipeline, bg,
        [cdiv(M, 16), cdiv(vocabSize, 16), 1]);
      pBuf.destroy();
    }

    normOut.destroy();
    normInv.destroy();

    // 5. Read back logits to CPU
    const logits = await readBuffer(this.device, gpuLogits, M * vocabSize * 4);

    return { logits, gpuLogits, caches };
  }

  /**
   * Greedy / top-k / temperature-sampled autoregressive generation.
   *
   * @param {number[]} promptIds – starting token IDs (must be non-empty)
   * @param {number} maxNewTokens
   * @param {{ temperature?: number, topK?: number, topP?: number }} [samplingOpts]
   * @returns {Promise<number[]>} – full sequence (prompt + generated)
   * @throws {RangeError} if promptIds is empty
   */
  async generate(promptIds, maxNewTokens = 200, samplingOpts = {}) {
    const { temperature = 1.0, topK = 50, topP = 0.9 } = samplingOpts;
    const { vocabSize } = this.config;

    const ids = [...promptIds];
    if (ids.length === 0) {
      throw new RangeError('generate: promptIds must contain at least one token');
    }

    for (let step = 0; step < maxNewTokens; step++) {
      // The whole context is re-run every step (no recurrent state is carried
      // between steps here), so total work grows roughly quadratically with
      // the number of generated tokens.
      const { logits, gpuLogits } = await this.forward(
        new Uint32Array(ids), 1, ids.length
      );
      // Sampling only needs the CPU copy. NOTE(review): the per-block
      // `caches` returned by forward are dropped here; if MambaBlock caches
      // own GPU buffers they are not released — confirm against MambaBlock.
      gpuLogits.destroy();

      // Logits for the last position
      const last = ids.length - 1;
      const lastLogits = logits.slice(last * vocabSize, (last + 1) * vocabSize);

      const nextId = sampleToken(lastLogits, { temperature, topK, topP });
      ids.push(nextId);

      // Stop on EOS
      if (nextId === this.config.eosId) break;
    }

    return ids;
  }

  /**
   * Collect all trainable parameters: embedding, every block's parameters
   * (prefixed `block<i>.`), then the final norm weight.
   * @returns {Array<{buf: GPUBuffer, numel: number, name: string}>}
   */
  parameters() {
    const { vocabSize, dModel } = this.config;

    const blockParams = this.blocks.flatMap((block, i) =>
      block.parameters().map((p) => ({ ...p, name: `block${i}.${p.name}` }))
    );

    return [
      { buf: this.gpuEmbedding, numel: vocabSize * dModel, name: 'embedding' },
      ...blockParams,
      { buf: this.gpuFinalNorm, numel: dModel, name: 'final_norm' },
    ];
  }

  /**
   * Enable WSLA (selective fine-tuning of B and C only) across all blocks.
   * @param {boolean} enabled
   */
  setWSLAMode(enabled) {
    this.blocks.forEach((block) => block.setWSLAMode(enabled));
    this._wslaMode = enabled;
  }
}
|
|
256
|
+
|
|
257
|
+
// ─── Embedding lookup WGSL kernel ────────────────────────────────────────────
|
|
258
|
+
|
|
259
|
+
const EMBED_LOOKUP_WGSL = /* wgsl */`
struct EmbedParams {
  num_tokens : u32,
  d_model    : u32,
};

@group(0) @binding(0) var<uniform> params : EmbedParams;
@group(0) @binding(1) var<storage, read> ids : array<u32>;
@group(0) @binding(2) var<storage, read> table : array<f32>;        // (V, D)
@group(0) @binding(3) var<storage, read_write> out : array<f32>;   // (T, D)

// One invocation per token: copies row ids[t] of table into row t of out.
// An id that is not a valid row of table yields an all-zero output row
// instead of an out-of-bounds read.
@compute @workgroup_size(64, 1, 1)
fn embed_lookup(@builtin(global_invocation_id) gid: vec3<u32>) {
  let token_idx = gid.x;
  if (token_idx >= params.num_tokens) { return; }

  let D = params.d_model;
  if (D == 0u) { return; }

  let tok = ids[token_idx];
  let dst = token_idx * D;
  let num_rows = arrayLength(&table) / D;

  if (tok >= num_rows) {
    for (var i: u32 = 0u; i < D; i = i + 1u) {
      out[dst + i] = 0.0;
    }
    return;
  }

  let src = tok * D;
  for (var i: u32 = 0u; i < D; i = i + 1u) {
    out[dst + i] = table[src + i];
  }
}
`;
|
|
285
|
+
|
|
286
|
+
// ─── Sampling helper ──────────────────────────────────────────────────────────
|
|
287
|
+
|
|
288
|
+
/**
 * Sample a token from logits using temperature + top-k + nucleus (top-p),
 * applied in that order.
 *
 * @param {Float32Array} logits – unnormalised scores, one per vocabulary entry
 * @param {{ temperature?: number, topK?: number, topP?: number }} opts
 *   temperature – logits are divided by max(temperature, 1e-7); very small
 *                 values make sampling effectively greedy.
 *   topK        – keep the K most probable candidates (at least one); 0, a
 *                 negative value or a non-finite value disables the filter.
 *   topP        – keep the shortest prefix of the top-k candidates whose
 *                 cumulative probability reaches topP.
 * @returns {number} the sampled index into `logits`
 * @throws {RangeError} if `logits` is empty
 */
function sampleToken(logits, { temperature = 1.0, topK = 50, topP = 0.9 } = {}) {
  const n = logits.length;
  if (n === 0) throw new RangeError('sampleToken: logits must not be empty');

  // Apply temperature
  const temp = Math.max(temperature, 1e-7);
  const scaled = new Float32Array(n);
  for (let i = 0; i < n; i++) scaled[i] = logits[i] / temp;

  // Softmax numerators (max-subtracted for numerical stability)
  let maxL = -Infinity;
  for (let i = 0; i < n; i++) if (scaled[i] > maxL) maxL = scaled[i];
  let sumE = 0;
  const exps = new Float32Array(n);
  for (let i = 0; i < n; i++) { exps[i] = Math.exp(scaled[i] - maxL); sumE += exps[i]; }

  // Sort indices by probability (descending)
  const indices = Array.from({ length: n }, (_, i) => i)
    .sort((a, b) => exps[b] - exps[a]);

  // Top-K filter. A non-positive or non-finite topK means "no limit"; the
  // result always keeps at least one candidate so the nucleus is non-empty.
  const k = Number.isFinite(topK) && topK > 0
    ? Math.max(1, Math.min(Math.floor(topK), n))
    : n;
  const topKIndices = indices.slice(0, k);

  // Nucleus (top-p) filter, using probabilities of the full distribution
  let cumSum = 0;
  const nucleus = [];
  for (const idx of topKIndices) {
    cumSum += exps[idx] / sumE;
    nucleus.push(idx);
    if (cumSum >= topP) break;
  }

  // Sample from the renormalised nucleus
  let nucleusSum = 0;
  for (const idx of nucleus) nucleusSum += exps[idx];
  const threshold = Math.random() * nucleusSum;
  let acc = 0;
  for (const idx of nucleus) {
    acc += exps[idx];
    if (acc >= threshold) return idx;
  }
  return nucleus[nucleus.length - 1];
}
|
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* bpe.js – Browser-side Byte Pair Encoding (BPE) tokenizer.
|
|
3
|
+
*
|
|
4
|
+
* Compatible with Qwen3.5-Coder's vocabulary and merge rules.
|
|
5
|
+
* Implements the standard Hugging Face / tiktoken BPE algorithm:
|
|
6
|
+
*
|
|
7
|
+
* 1. Pre-tokenize text with a regex into word, number, punctuation and whitespace pieces.
|
|
8
|
+
* 2. Byte-level encode each piece: every UTF-8 byte is mapped to a printable character.
|
|
9
|
+
* 3. Iteratively merge the highest-priority adjacent byte-pair
|
|
10
|
+
* according to the learnt merge table.
|
|
11
|
+
*
|
|
12
|
+
* Usage
|
|
13
|
+
* -----
|
|
14
|
+
* const tokenizer = new BPETokenizer();
|
|
15
|
+
* await tokenizer.load(vocabUrl, mergesUrl);
|
|
16
|
+
*
|
|
17
|
+
* const ids = tokenizer.encode("function foo() {}");
|
|
18
|
+
* const text = tokenizer.decode(ids);
|
|
19
|
+
*/
|
|
20
|
+
|
|
21
|
+
// GPT-2 / Qwen style byte ↔ character tables. Every raw byte 0–255 gets a
// distinct, visible Unicode character, so any byte sequence can be written
// as an ordinary string and looked up in a string-keyed vocabulary.

/**
 * Build the byte → character table.
 *
 * Bytes that are already visible Latin-1 characters (`!`–`~`, `¡`–`¬`,
 * `®`–`ÿ`) map to themselves and are inserted first. All remaining bytes
 * (controls, space, DEL, NBSP, soft hyphen, …) are then assigned consecutive
 * code points starting at U+0100, in ascending byte order.
 *
 * @returns {Map<number, string>} byte value → single-character string
 */
function buildByteEncoder() {
  const visibleRanges = [
    [0x21, 0x7E], // ! → ~
    [0xA1, 0xAC], // ¡ → ¬
    [0xAE, 0xFF], // ® → ÿ
  ];
  const isVisible = (byte) => visibleRanges.some(([lo, hi]) => byte >= lo && byte <= hi);
  const everyByte = Array.from({ length: 256 }, (_, byte) => byte);

  /** @type {Map<number, string>} */
  const table = new Map(
    everyByte.filter(isVisible).map((byte) => [byte, String.fromCodePoint(byte)])
  );

  let nextCodePoint = 256;
  for (const byte of everyByte) {
    if (isVisible(byte)) continue;
    table.set(byte, String.fromCodePoint(nextCodePoint));
    nextCodePoint += 1;
  }
  return table;
}

const BYTE_ENCODER = buildByteEncoder();
// Inverse table: printable character → original byte value.
const BYTE_DECODER = new Map(
  Array.from(BYTE_ENCODER, ([byte, ch]) => [ch, byte])
);
|
|
49
|
+
|
|
50
|
+
// Pre-tokenisation regex: splits text into contractions, letter runs (with
// at most one leading non-letter/non-digit character), single digits,
// punctuation runs and whitespace runs. It follows the Qwen2-family pattern
//   (?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+
// with the case-insensitive contraction group spelled out as explicit
// character classes, because inline `(?i:…)` modifiers are not supported by
// every JavaScript engine. Digits are matched one at a time, as in Qwen,
// rather than in groups of up to three (GPT-4 cl100k style).
const PRE_TOKENIZE_RE =
  /'(?:[sS]|[tT]|[rR][eE]|[vV][eE]|[mM]|[lL][lL]|[dD])|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+/gu;
|
|
54
|
+
|
|
55
|
+
export class BPETokenizer {
  constructor() {
    /** @type {Map<string, number>} token → id */
    this.vocab = new Map();
    /** @type {Map<number, string>} id → token */
    this.idToToken = new Map();
    /** @type {Map<string, number>} "a b" → rank (lower = higher priority) */
    this.merges = new Map();

    // Special-token strings. Their ids are looked up whenever a vocabulary is
    // installed and stay null if the vocabulary does not contain them.
    this.bosToken = '<|im_start|>';
    this.eosToken = '<|im_end|>';
    this.padToken = '<|endoftext|>';
    this.unkToken = '<unk>';

    this.bosId = null;
    this.eosId = null;
    this.padId = null;
  }

  /**
   * Load vocabulary and merge rules from URLs or pre-parsed values.
   *
   * vocab.json – { "token": id, ... }
   * merges.txt – one merge per line: "a b" (sorted by rank). A "#version"
   *              header line and blank lines are skipped; other lines that
   *              begin with '#' (e.g. the merge "# #") are kept.
   *
   * The tokenizer's state is replaced only after both inputs have loaded,
   * so a failed fetch leaves the previous vocabulary intact.
   *
   * @param {string|object} vocab – URL string or pre-parsed vocab object
   * @param {string|Array<string|string[]>} merges – URL string, or an array
   *   of "a b" strings or ["a", "b"] pairs
   * @returns {Promise<void>}
   * @throws {Error} if a fetch returns a non-2xx HTTP status
   */
  async load(vocab, merges) {
    let vocabObj = vocab;
    if (typeof vocab === 'string') {
      const res = await this._fetchOk(vocab);
      vocabObj = await res.json();
    }

    let mergeList = merges;
    if (typeof merges === 'string') {
      const res = await this._fetchOk(merges);
      mergeList = BPETokenizer._parseMergesText(await res.text());
    }

    this._install(vocabObj, mergeList);
  }

  /**
   * Load from plain JavaScript objects (no network fetch).
   * Useful for bundling a small vocabulary directly.
   *
   * @param {Object} vocabObj – { token: id }
   * @param {Array<string|string[]>} mergeArr – ["a b", "c d", ...] or
   *   [["a", "b"], ...]; array position is the merge rank
   */
  loadFromObjects(vocabObj, mergeArr) {
    this._install(vocabObj, mergeArr);
  }

  /** Fetch `url`, rejecting on non-2xx responses instead of parsing them. */
  async _fetchOk(url) {
    const res = await fetch(url);
    if (!res.ok) {
      throw new Error(`BPETokenizer: failed to fetch ${url} (HTTP ${res.status})`);
    }
    return res;
  }

  /** Split merges.txt content into merge lines (CRLF-safe). */
  static _parseMergesText(txt) {
    return txt
      .split('\n')
      .map((line) => line.trim())
      .filter((line) => line !== '' && !line.startsWith('#version'));
  }

  /** Normalise one merge entry to the "a b" key used by `this.merges`. */
  static _mergeKey(merge) {
    return Array.isArray(merge) ? merge.join(' ') : String(merge).trim();
  }

  /** Replace vocab, inverse vocab, merge ranks and special-token ids. */
  _install(vocabObj, mergeList) {
    this.vocab = new Map(Object.entries(vocabObj).map(([tok, id]) => [tok, Number(id)]));
    this.idToToken = new Map([...this.vocab].map(([tok, id]) => [id, tok]));
    this.merges = new Map(mergeList.map((m, rank) => [BPETokenizer._mergeKey(m), rank]));

    this.bosId = this.vocab.get(this.bosToken) ?? null;
    this.eosId = this.vocab.get(this.eosToken) ?? null;
    this.padId = this.vocab.get(this.padToken) ?? null;
  }

  // ── Encoding ──────────────────────────────────────────────────────────────

  /**
   * Encode a string to an array of token IDs.
   *
   * Special-token strings inside `text` are not recognised; they are
   * tokenised like ordinary text. Characters whose byte-level form is not
   * in the vocabulary are dropped.
   *
   * @param {string} text
   * @param {{ addBos?: boolean, addEos?: boolean }} [opts]
   * @returns {number[]}
   */
  encode(text, opts = {}) {
    const ids = [];
    if (opts.addBos && this.bosId !== null) ids.push(this.bosId);

    const utf8 = new TextEncoder();
    for (const word of text.match(PRE_TOKENIZE_RE) ?? []) {
      // Byte-level encoding: UTF-8 bytes → printable characters
      const byteStr = Array.from(utf8.encode(word), (b) => BYTE_ENCODER.get(b) ?? '?').join('');

      for (const piece of this._bpe(byteStr)) {
        const id = this.vocab.get(piece);
        if (id !== undefined) {
          ids.push(id);
          continue;
        }
        // Fallback: encode each character individually
        for (const ch of piece) {
          const cid = this.vocab.get(ch);
          if (cid !== undefined) ids.push(cid);
        }
      }
    }

    if (opts.addEos && this.eosId !== null) ids.push(this.eosId);
    return ids;
  }

  /**
   * Decode an array of token IDs back to a string. Unknown ids are skipped;
   * invalid UTF-8 byte sequences decode to U+FFFD.
   *
   * @param {number[]} ids
   * @returns {string}
   */
  decode(ids) {
    let byteStr = '';
    for (const id of ids) {
      const tok = this.idToToken.get(id);
      if (tok !== undefined) byteStr += tok;
    }

    // Map byte-level characters back to raw bytes. A character outside the
    // byte table (e.g. one that appears verbatim in an added token) is
    // emitted as its own UTF-8 bytes rather than truncated to one byte.
    const utf8 = new TextEncoder();
    const bytes = [];
    for (const ch of byteStr) {
      const b = BYTE_DECODER.get(ch);
      if (b !== undefined) bytes.push(b);
      else bytes.push(...utf8.encode(ch));
    }
    return new TextDecoder('utf-8').decode(new Uint8Array(bytes));
  }

  // ── BPE merge algorithm ───────────────────────────────────────────────────

  /**
   * Apply BPE merges to a byte-encoded word. A word that is already a
   * vocabulary entry is returned whole.
   *
   * @param {string} word – Space-free string of byte-level characters
   * @returns {string[]} – Merged token pieces
   */
  _bpe(word) {
    if (this.vocab.has(word)) return [word];

    // Start: each character is its own symbol
    let symbols = [...word];

    // Repeatedly merge the adjacent pair with the lowest rank; ties go to
    // the leftmost occurrence.
    while (symbols.length > 1) {
      let bestRank = Infinity;
      let bestIdx = -1;

      for (let i = 0; i < symbols.length - 1; i++) {
        const rank = this.merges.get(`${symbols[i]} ${symbols[i + 1]}`);
        if (rank !== undefined && rank < bestRank) {
          bestRank = rank;
          bestIdx = i;
        }
      }

      if (bestIdx === -1) break; // no more merges available

      symbols = [
        ...symbols.slice(0, bestIdx),
        symbols[bestIdx] + symbols[bestIdx + 1],
        ...symbols.slice(bestIdx + 2),
      ];
    }

    return symbols;
  }

  // ── Padding / truncation helpers ─────────────────────────────────────────

  /**
   * Pad or truncate a sequence to a fixed length.
   *
   * `side` controls where padding goes; truncation always keeps the first
   * `maxLen` ids. Pads with `padId`, or 0 when no pad token is known.
   *
   * @param {number[]} ids
   * @param {number} maxLen
   * @param {'right'|'left'} [side='right']
   * @returns {number[]}
   */
  padOrTruncate(ids, maxLen, side = 'right') {
    if (ids.length >= maxLen) return ids.slice(0, maxLen);
    const pad = new Array(maxLen - ids.length).fill(this.padId ?? 0);
    return side === 'right' ? [...ids, ...pad] : [...pad, ...ids];
  }

  /** @returns {number} number of entries in the vocabulary map */
  get vocabSize() { return this.vocab.size; }
}
|