mambacode.js 1.0.0 → 1.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. package/README.md +198 -76
  2. package/dist/index.d.ts +18 -0
  3. package/dist/index.d.ts.map +1 -0
  4. package/dist/index.js +18 -0
  5. package/dist/index.js.map +1 -0
  6. package/dist/kernels/activations.d.ts +3 -0
  7. package/dist/kernels/activations.d.ts.map +1 -0
  8. package/dist/kernels/activations.js +87 -0
  9. package/dist/kernels/activations.js.map +1 -0
  10. package/dist/kernels/conv1d.d.ts +3 -0
  11. package/dist/kernels/conv1d.d.ts.map +1 -0
  12. package/dist/kernels/conv1d.js +152 -0
  13. package/dist/kernels/conv1d.js.map +1 -0
  14. package/dist/kernels/linear_projection.d.ts +3 -0
  15. package/dist/kernels/linear_projection.d.ts.map +1 -0
  16. package/dist/kernels/linear_projection.js +219 -0
  17. package/dist/kernels/linear_projection.js.map +1 -0
  18. package/dist/kernels/selective_scan.d.ts +3 -0
  19. package/dist/kernels/selective_scan.d.ts.map +1 -0
  20. package/dist/kernels/selective_scan.js +348 -0
  21. package/dist/kernels/selective_scan.js.map +1 -0
  22. package/dist/kernels/weight_update.d.ts +3 -0
  23. package/dist/kernels/weight_update.d.ts.map +1 -0
  24. package/dist/kernels/weight_update.js +119 -0
  25. package/dist/kernels/weight_update.js.map +1 -0
  26. package/dist/model/mamba_block.d.ts +64 -0
  27. package/dist/model/mamba_block.d.ts.map +1 -0
  28. package/dist/model/mamba_block.js +309 -0
  29. package/dist/model/mamba_block.js.map +1 -0
  30. package/dist/model/mamba_model.d.ts +66 -0
  31. package/dist/model/mamba_model.d.ts.map +1 -0
  32. package/dist/model/mamba_model.js +289 -0
  33. package/dist/model/mamba_model.js.map +1 -0
  34. package/dist/tokenizer/bpe.d.ts +29 -0
  35. package/dist/tokenizer/bpe.d.ts.map +1 -0
  36. package/dist/tokenizer/bpe.js +164 -0
  37. package/dist/tokenizer/bpe.js.map +1 -0
  38. package/dist/training/autograd.d.ts +27 -0
  39. package/dist/training/autograd.d.ts.map +1 -0
  40. package/dist/training/autograd.js +120 -0
  41. package/dist/training/autograd.js.map +1 -0
  42. package/dist/training/trainer.d.ts +37 -0
  43. package/dist/training/trainer.d.ts.map +1 -0
  44. package/dist/training/trainer.js +183 -0
  45. package/dist/training/trainer.js.map +1 -0
  46. package/dist/utils/gpu_utils.d.ts +21 -0
  47. package/dist/utils/gpu_utils.d.ts.map +1 -0
  48. package/dist/utils/gpu_utils.js +111 -0
  49. package/dist/utils/gpu_utils.js.map +1 -0
  50. package/dist/utils/quantization.d.ts +26 -0
  51. package/dist/utils/quantization.d.ts.map +1 -0
  52. package/dist/utils/quantization.js +116 -0
  53. package/dist/utils/quantization.js.map +1 -0
  54. package/package.json +43 -18
  55. package/src/index.ts +59 -0
  56. package/src/kernels/{activations.js → activations.ts} +2 -2
  57. package/src/kernels/{linear_projection.js → linear_projection.ts} +2 -2
  58. package/src/kernels/{selective_scan.js → selective_scan.ts} +2 -2
  59. package/src/kernels/{weight_update.js → weight_update.ts} +2 -2
  60. package/src/model/{mamba_block.js → mamba_block.ts} +139 -175
  61. package/src/model/{mamba_model.js → mamba_model.ts} +168 -124
  62. package/src/tokenizer/bpe.ts +186 -0
  63. package/src/training/autograd.ts +135 -0
  64. package/src/training/trainer.ts +312 -0
  65. package/src/utils/gpu_utils.ts +147 -0
  66. package/src/utils/quantization.ts +154 -0
  67. package/src/index.js +0 -89
  68. package/src/tokenizer/bpe.js +0 -256
  69. package/src/training/autograd.js +0 -221
  70. package/src/training/trainer.js +0 -394
  71. package/src/utils/gpu_utils.js +0 -217
  72. package/src/utils/quantization.js +0 -215
  73. package/src/kernels/{conv1d.js → conv1d.ts} +0 -0
@@ -1,15 +1,8 @@
1
1
  /**
2
- * mamba_model.js – Full Mamba language model.
3
- *
4
- * Architecture (matches Qwen3.5-Coder-0.8B-style Mamba):
5
- *
6
- * Token IDs ──► Embedding ──► [MambaBlock × numLayers] ──► RMSNorm ──► LM Head
7
- *
8
- * The LM Head is a linear projection from dModel → vocabSize.
9
- * All computations run on WebGPU via the kernels in src/kernels/.
2
+ * mamba_model.ts – Full Mamba language model.
10
3
  */
11
4
 
12
- import { MambaBlock } from './mamba_block.js';
5
+ import { MambaBlock, BlockCache, BlockParam } from './mamba_block';
13
6
  import {
14
7
  createStorageBuffer,
15
8
  createEmptyStorageBuffer,
@@ -18,40 +11,60 @@ import {
18
11
  createBindGroup,
19
12
  dispatchKernel,
20
13
  readBuffer,
14
+ uploadBuffer,
21
15
  cdiv,
22
- } from '../utils/gpu_utils.js';
23
- import { LINEAR_FORWARD_WGSL } from '../kernels/linear_projection.js';
24
- import { ACTIVATIONS_WGSL } from '../kernels/activations.js';
16
+ } from '../utils/gpu_utils';
17
+ import { LINEAR_FORWARD_WGSL } from '../kernels/linear_projection';
18
+ import { ACTIVATIONS_WGSL } from '../kernels/activations';
19
+
20
+ export interface MambaModelConfig {
21
+ vocabSize: number;
22
+ dModel: number;
23
+ numLayers: number;
24
+ dState?: number;
25
+ dConv?: number;
26
+ expand?: number;
27
+ eosId?: number;
28
+ }
25
29
 
26
- /**
27
- * @typedef {Object} MambaModelConfig
28
- * @property {number} vocabSize – vocabulary size (Qwen3.5-Coder: 151936)
29
- * @property {number} dModel – model (embedding) dimension
30
- * @property {number} numLayers – number of Mamba blocks
31
- * @property {number} [dState] – SSM state dimension (default 16)
32
- * @property {number} [dConv] – conv kernel size (default 4)
33
- * @property {number} [expand] – inner-dim expansion factor (default 2)
34
- */
30
+ export interface ModelForwardResult {
31
+ logits: Float32Array;
32
+ gpuLogits: GPUBuffer;
33
+ caches: BlockCache[];
34
+ }
35
+
36
+ export interface SamplingOptions {
37
+ temperature?: number;
38
+ topK?: number;
39
+ topP?: number;
40
+ }
35
41
 
36
42
  export class MambaModel {
37
- /**
38
- * @param {GPUDevice} device
39
- * @param {MambaModelConfig} config
40
- */
41
- constructor(device, config) {
43
+ device: GPUDevice;
44
+ config: Required<MambaModelConfig>;
45
+ gpuEmbedding: GPUBuffer;
46
+ blocks: MambaBlock[];
47
+ gpuFinalNorm: GPUBuffer;
48
+ tiedEmbedding: boolean;
49
+ gpuLMHeadBias: GPUBuffer;
50
+ private _lmHeadPipeline: GPUComputePipeline;
51
+ private _rmsnormPipeline: GPUComputePipeline;
52
+ private _embedPipeline: GPUComputePipeline;
53
+ private _wslaMode = false;
54
+
55
+ constructor(device: GPUDevice, config: MambaModelConfig) {
42
56
  this.device = device;
43
57
  this.config = {
44
58
  dState : 16,
45
59
  dConv : 4,
46
60
  expand : 2,
61
+ eosId : -1,
47
62
  ...config,
48
- };
63
+ } as Required<MambaModelConfig>;
49
64
 
50
65
  const { vocabSize, dModel, numLayers } = this.config;
51
66
 
52
- // Token embedding table: (vocabSize, dModel)
53
67
  const embedData = new Float32Array(vocabSize * dModel);
54
- // Xavier-style initialisation
55
68
  const std = 1.0 / Math.sqrt(dModel);
56
69
  for (let i = 0; i < embedData.length; i++) {
57
70
  const u1 = Math.random(), u2 = Math.random();
@@ -60,7 +73,6 @@ export class MambaModel {
60
73
  }
61
74
  this.gpuEmbedding = createStorageBuffer(device, embedData, true);
62
75
 
63
- // Stacked Mamba blocks
64
76
  this.blocks = Array.from({ length: numLayers }, () =>
65
77
  new MambaBlock(device, {
66
78
  dModel,
@@ -70,36 +82,20 @@ export class MambaModel {
70
82
  })
71
83
  );
72
84
 
73
- // Final RMSNorm
74
85
  const finalNormW = new Float32Array(dModel).fill(1.0);
75
86
  this.gpuFinalNorm = createStorageBuffer(device, finalNormW, true);
76
87
 
77
- // LM Head: (vocabSize, dModel) – tied to embedding by default
78
- // We share the embedding weight (weight tying saves memory).
79
88
  this.tiedEmbedding = true;
80
89
 
81
- // Compile pipelines
82
90
  this._lmHeadPipeline = createComputePipeline(device, LINEAR_FORWARD_WGSL, 'linear_forward');
83
91
  this._rmsnormPipeline = createComputePipeline(device, ACTIVATIONS_WGSL, 'rmsnorm_forward');
84
92
 
85
- // LM Head bias (zeroed)
86
93
  this.gpuLMHeadBias = createStorageBuffer(device, new Float32Array(vocabSize), true);
87
94
 
88
- // Embedding lookup pipeline (gather rows)
89
95
  this._embedPipeline = createComputePipeline(device, EMBED_LOOKUP_WGSL, 'embed_lookup');
90
96
  }
91
97
 
92
- // ─── Embedding lookup ─────────────────────────────────────────────────────
93
-
94
- /**
95
- * Look up token embeddings.
96
- *
97
- * @param {Int32Array|Uint32Array} tokenIds – (batch * seqLen,)
98
- * @param {number} batch
99
- * @param {number} seqLen
100
- * @returns {GPUBuffer} – (batch * seqLen, dModel)
101
- */
102
- embedTokens(tokenIds, batch, seqLen) {
98
+ embedTokens(tokenIds: number[] | Uint32Array, batch: number, seqLen: number): GPUBuffer {
103
99
  const { dModel } = this.config;
104
100
  const M = batch * seqLen;
105
101
 
@@ -119,27 +115,13 @@ export class MambaModel {
119
115
  return outBuf;
120
116
  }
121
117
 
122
- // ─── Forward pass ─────────────────────────────────────────────────────────
123
-
124
- /**
125
- * Full model forward pass.
126
- *
127
- * @param {number[]|Uint32Array} tokenIds – (batch * seqLen,) flat
128
- * @param {number} batch
129
- * @param {number} seqLen
130
- * @returns {Promise<{ logits: Float32Array, gpuLogits: GPUBuffer }>}
131
- * logits – CPU Float32Array of shape (batch * seqLen, vocabSize)
132
- * gpuLogits – GPU buffer (same data, for chained backward)
133
- */
134
- async forward(tokenIds, batch, seqLen) {
118
+ async forward(tokenIds: number[] | Uint32Array, batch: number, seqLen: number): Promise<ModelForwardResult> {
135
119
  const { dModel, vocabSize } = this.config;
136
120
  const M = batch * seqLen;
137
121
 
138
- // 1. Token embedding lookup
139
122
  let hidden = this.embedTokens(tokenIds, batch, seqLen);
140
123
 
141
- // 2. Mamba blocks
142
- const caches = [];
124
+ const caches: BlockCache[] = [];
143
125
  for (const block of this.blocks) {
144
126
  const { output, cache } = block.forward(hidden, batch, seqLen);
145
127
  caches.push(cache);
@@ -147,7 +129,6 @@ export class MambaModel {
147
129
  hidden = output;
148
130
  }
149
131
 
150
- // 3. Final RMSNorm
151
132
  const normOut = createEmptyStorageBuffer(this.device, M * dModel * 4, true);
152
133
  const normInv = createEmptyStorageBuffer(this.device, M * 4, false);
153
134
  {
@@ -160,12 +141,11 @@ export class MambaModel {
160
141
  dispatchKernel(this.device, this._rmsnormPipeline, bg, [cdiv(M, 64), 1, 1]);
161
142
  }
162
143
 
163
- // 4. LM Head: (M, vocabSize) = normOut @ embedding^T + bias
164
144
  const gpuLogits = createEmptyStorageBuffer(this.device, M * vocabSize * 4, true);
165
145
  {
166
146
  const params = new Uint32Array([M, dModel, vocabSize]).buffer;
167
147
  const pBuf = createUniformBuffer(this.device, params);
168
- const weightBuf = this.tiedEmbedding ? this.gpuEmbedding : this.gpuLMHeadWeight;
148
+ const weightBuf = this.tiedEmbedding ? this.gpuEmbedding : this.gpuLMHeadBias;
169
149
  const bg = createBindGroup(this.device, this._lmHeadPipeline,
170
150
  [pBuf, normOut, weightBuf, this.gpuLMHeadBias, gpuLogits]);
171
151
  dispatchKernel(this.device, this._lmHeadPipeline, bg,
@@ -175,66 +155,47 @@ export class MambaModel {
175
155
  normOut.destroy();
176
156
  normInv.destroy();
177
157
 
178
- // 5. Read back logits to CPU
179
158
  const logits = await readBuffer(this.device, gpuLogits, M * vocabSize * 4);
180
159
 
181
160
  return { logits, gpuLogits, caches };
182
161
  }
183
162
 
184
- /**
185
- * Greedy / top-k / temperature-sampled autoregressive generation.
186
- *
187
- * @param {number[]} promptIds – starting token IDs
188
- * @param {number} maxNewTokens
189
- * @param {{ temperature?: number, topK?: number, topP?: number }} [samplingOpts]
190
- * @returns {Promise<number[]>} – full sequence (prompt + generated)
191
- */
192
- async generate(promptIds, maxNewTokens = 200, samplingOpts = {}) {
163
+ async generate(promptIds: number[], maxNewTokens = 200, samplingOpts: SamplingOptions = {}): Promise<number[]> {
193
164
  const { temperature = 1.0, topK = 50, topP = 0.9 } = samplingOpts;
194
165
  const { vocabSize } = this.config;
195
166
 
196
167
  let ids = [...promptIds];
197
168
 
198
169
  for (let step = 0; step < maxNewTokens; step++) {
199
- // Use the full context each step (linear cost with Mamba – no kv-cache needed)
200
170
  const { logits } = await this.forward(
201
171
  new Uint32Array(ids), 1, ids.length
202
172
  );
203
- // Get logits for the last position
204
173
  const lastLogits = logits.slice((ids.length - 1) * vocabSize, ids.length * vocabSize);
205
174
 
206
175
  const nextId = sampleToken(lastLogits, { temperature, topK, topP });
207
176
  ids.push(nextId);
208
177
 
209
- // Stop on EOS
210
178
  if (nextId === this.config.eosId) break;
211
179
  }
212
180
 
213
181
  return ids;
214
182
  }
215
183
 
216
- /**
217
- * Collect all trainable parameters across all blocks.
218
- * @returns {Array<{buf: GPUBuffer, numel: number, name: string}>}
219
- */
220
- parameters() {
221
- const params = [];
184
+ parameters(): BlockParam[] {
185
+ const params: BlockParam[] = [];
222
186
 
223
- // Embedding
224
187
  params.push({
225
188
  buf : this.gpuEmbedding,
226
189
  numel: this.config.vocabSize * this.config.dModel,
227
190
  name : 'embedding',
228
191
  });
229
192
 
230
- // Blocks
231
193
  for (let i = 0; i < this.blocks.length; i++) {
232
- for (const p of this.blocks[i].parameters()) {
194
+ for (const p of this.blocks[i]!.parameters()) {
233
195
  params.push({ ...p, name: `block${i}.${p.name}` });
234
196
  }
235
197
  }
236
198
 
237
- // Final norm
238
199
  params.push({
239
200
  buf : this.gpuFinalNorm,
240
201
  numel: this.config.dModel,
@@ -244,19 +205,117 @@ export class MambaModel {
244
205
  return params;
245
206
  }
246
207
 
247
- /**
248
- * Enable WSLA (selective fine-tuning of B and C only) across all blocks.
249
- * @param {boolean} enabled
250
- */
251
- setWSLAMode(enabled) {
208
+ setWSLAMode(enabled: boolean): void {
252
209
  for (const block of this.blocks) block.setWSLAMode(enabled);
253
210
  this._wslaMode = enabled;
254
211
  }
255
- }
256
212
 
257
- // ─── Embedding lookup WGSL kernel ────────────────────────────────────────────
213
+ /**
214
+ * Serialise all model parameters to an ArrayBuffer.
215
+ *
216
+ * Binary format:
217
+ * [0..3] magic : uint32 = 0x4D424A53 ('MBJS')
218
+ * [4..7] version: uint32 = 1
219
+ * [8..11] nParams: uint32
220
+ * [12 .. 12+4*nParams-1] numel[i]: uint32 for each parameter i
221
+ * [12+4*nParams ..] float32 data for each parameter, concatenated
222
+ *
223
+ * Save the returned buffer to a file or IndexedDB and reload it with
224
+ * `loadWeights()` to resume from a checkpoint.
225
+ */
226
+ async exportWeights(): Promise<ArrayBuffer> {
227
+ const params = this.parameters();
228
+ const nParams = params.length;
229
+
230
+ // Read all GPU buffers into CPU Float32Arrays
231
+ const arrays: Float32Array[] = await Promise.all(
232
+ params.map(p => readBuffer(this.device, p.buf, p.numel * 4))
233
+ );
234
+
235
+ // Calculate total byte size: header + numel table + all float data
236
+ const headerBytes = 4 + 4 + 4 + nParams * 4; // magic + version + nParams + numel[]
237
+ const dataBytes = arrays.reduce((acc, a) => acc + a.byteLength, 0);
238
+ const out = new ArrayBuffer(headerBytes + dataBytes);
239
+ const view = new DataView(out);
258
240
 
259
- const EMBED_LOOKUP_WGSL = /* wgsl */`
241
+ let offset = 0;
242
+ view.setUint32(offset, 0x4D424A53, true); offset += 4; // magic 'MBJS'
243
+ view.setUint32(offset, 1, true); offset += 4; // version
244
+ view.setUint32(offset, nParams, true); offset += 4; // nParams
245
+
246
+ for (const p of params) {
247
+ view.setUint32(offset, p.numel, true);
248
+ offset += 4;
249
+ }
250
+
251
+ for (const arr of arrays) {
252
+ new Float32Array(out, offset, arr.length).set(arr);
253
+ offset += arr.byteLength;
254
+ }
255
+
256
+ return out;
257
+ }
258
+
259
+ /**
260
+ * Load model parameters from an ArrayBuffer previously produced by
261
+ * `exportWeights()`. The parameter count and element counts must match
262
+ * the current model configuration exactly.
263
+ *
264
+ * @throws {Error} if the magic number, version, or parameter layout do
265
+ * not match the current model.
266
+ */
267
+ async loadWeights(buffer: ArrayBuffer): Promise<void> {
268
+ const view = new DataView(buffer);
269
+ let offset = 0;
270
+
271
+ const magic = view.getUint32(offset, true); offset += 4;
272
+ if (magic !== 0x4D424A53) {
273
+ throw new Error(
274
+ 'Invalid weight file: bad magic number. ' +
275
+ 'Ensure the file was exported by MambaModel.exportWeights().'
276
+ );
277
+ }
278
+
279
+ const version = view.getUint32(offset, true); offset += 4;
280
+ if (version !== 1) {
281
+ throw new Error(`Unsupported weight file version: ${version}. Expected version 1.`);
282
+ }
283
+
284
+ const nParams = view.getUint32(offset, true); offset += 4;
285
+ const params = this.parameters();
286
+
287
+ if (nParams !== params.length) {
288
+ throw new Error(
289
+ `Weight file has ${nParams} parameters but this model has ${params.length}. ` +
290
+ 'Ensure the model configuration matches the one used when exporting.'
291
+ );
292
+ }
293
+
294
+ const numels: number[] = [];
295
+ for (let i = 0; i < nParams; i++) {
296
+ numels.push(view.getUint32(offset, true));
297
+ offset += 4;
298
+ }
299
+
300
+ for (let i = 0; i < nParams; i++) {
301
+ // i is guaranteed in-bounds: nParams === params.length was verified above
302
+ const p = params[i]!;
303
+ const numel = numels[i]!;
304
+ if (numel !== p.numel) {
305
+ throw new Error(
306
+ `Parameter ${i} ("${p.name}") size mismatch: ` +
307
+ `file has ${numel} elements, model expects ${p.numel}.`
308
+ );
309
+ }
310
+
311
+ const slice = new Float32Array(buffer, offset, p.numel);
312
+ uploadBuffer(this.device, p.buf, slice);
313
+ offset += p.numel * 4;
314
+ }
315
+ }
316
+ }
317
+
318
+ const EMBED_LOOKUP_WGSL: string = /* wgsl */`
260
319
  struct EmbedParams {
261
320
  num_tokens : u32,
262
321
  d_model : u32,
@@ -264,8 +323,8 @@ struct EmbedParams {
264
323
 
265
324
  @group(0) @binding(0) var<uniform> params : EmbedParams;
266
325
  @group(0) @binding(1) var<storage, read> ids : array<u32>;
267
- @group(0) @binding(2) var<storage, read> table : array<f32>; // (V, D)
268
- @group(0) @binding(3) var<storage, read_write> out : array<f32>; // (T, D)
326
+ @group(0) @binding(2) var<storage, read> table : array<f32>;
327
+ @group(0) @binding(3) var<storage, read_write> out : array<f32>;
269
328
 
270
329
  @compute @workgroup_size(64, 1, 1)
271
330
  fn embed_lookup(@builtin(global_invocation_id) gid: vec3<u32>) {
@@ -283,53 +342,38 @@ fn embed_lookup(@builtin(global_invocation_id) gid: vec3<u32>) {
283
342
  }
284
343
  `;
285
344
 
286
- // ─── Sampling helper ──────────────────────────────────────────────────────────
287
-
288
- /**
289
- * Sample a token from logits using temperature + top-k + nucleus (top-p).
290
- *
291
- * @param {Float32Array} logits
292
- * @param {{ temperature?: number, topK?: number, topP?: number }} opts
293
- * @returns {number}
294
- */
295
- function sampleToken(logits, { temperature = 1.0, topK = 50, topP = 0.9 } = {}) {
345
+ function sampleToken(logits: Float32Array, { temperature = 1.0, topK = 50, topP = 0.9 } = {}): number {
296
346
  const n = logits.length;
297
347
 
298
- // Apply temperature
299
348
  const scaled = new Float32Array(n);
300
- for (let i = 0; i < n; i++) scaled[i] = logits[i] / Math.max(temperature, 1e-7);
349
+ for (let i = 0; i < n; i++) scaled[i] = logits[i]! / Math.max(temperature, 1e-7);
301
350
 
302
- // Softmax
303
351
  let maxL = -Infinity;
304
- for (let i = 0; i < n; i++) if (scaled[i] > maxL) maxL = scaled[i];
352
+ for (let i = 0; i < n; i++) if (scaled[i]! > maxL) maxL = scaled[i]!;
305
353
  let sumE = 0;
306
354
  const exps = new Float32Array(n);
307
- for (let i = 0; i < n; i++) { exps[i] = Math.exp(scaled[i] - maxL); sumE += exps[i]; }
355
+ for (let i = 0; i < n; i++) { exps[i] = Math.exp(scaled[i]! - maxL); sumE += exps[i]!; }
308
356
 
309
- // Sort indices by probability (descending)
310
357
  const indices = Array.from({ length: n }, (_, i) => i)
311
- .sort((a, b) => exps[b] - exps[a]);
358
+ .sort((a, b) => exps[b]! - exps[a]!);
312
359
 
313
- // Top-K filter
314
360
  const topKIndices = indices.slice(0, topK);
315
361
 
316
- // Nucleus (top-p) filter
317
362
  let cumSum = 0;
318
- const nucleus = [];
363
+ const nucleus: number[] = [];
319
364
  for (const idx of topKIndices) {
320
- cumSum += exps[idx] / sumE;
365
+ cumSum += exps[idx]! / sumE;
321
366
  nucleus.push(idx);
322
367
  if (cumSum >= topP) break;
323
368
  }
324
369
 
325
- // Sample from nucleus
326
370
  let nucleusSum = 0;
327
- for (const idx of nucleus) nucleusSum += exps[idx];
371
+ for (const idx of nucleus) nucleusSum += exps[idx]!;
328
372
  const threshold = Math.random() * nucleusSum;
329
373
  let acc = 0;
330
374
  for (const idx of nucleus) {
331
- acc += exps[idx];
375
+ acc += exps[idx]!;
332
376
  if (acc >= threshold) return idx;
333
377
  }
334
- return nucleus[nucleus.length - 1];
378
+ return nucleus[nucleus.length - 1]!;
335
379
  }
@@ -0,0 +1,186 @@
1
+ /**
2
+ * bpe.ts – Browser-side Byte Pair Encoding (BPE) tokenizer.
3
+ */
4
+
5
+ export interface BPEEncodeOptions {
6
+ addBos?: boolean;
7
+ addEos?: boolean;
8
+ }
9
+
10
+ export type PadSide = 'right' | 'left';
11
+
12
+ function buildByteEncoder(): Map<number, string> {
13
+ const enc = new Map<number, string>();
14
+ const ranges: [number, number][] = [
15
+ [0x21, 0x7E],
16
+ [0xA1, 0xAC],
17
+ [0xAE, 0xFF],
18
+ ];
19
+ let n = 0;
20
+ for (const [lo, hi] of ranges) {
21
+ for (let b = lo; b <= hi; b++) {
22
+ enc.set(b, String.fromCodePoint(b));
23
+ }
24
+ }
25
+ for (let b = 0; b < 256; b++) {
26
+ if (!enc.has(b)) {
27
+ enc.set(b, String.fromCodePoint(256 + n));
28
+ n++;
29
+ }
30
+ }
31
+ return enc;
32
+ }
33
+
34
+ const BYTE_ENCODER = buildByteEncoder();
35
+ const BYTE_DECODER = new Map([...BYTE_ENCODER].map(([k, v]) => [v, k]));
36
+
37
+ const PRE_TOKENIZE_RE =
38
+ /(?:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+/gu;
39
+
40
+ export class BPETokenizer {
41
+ vocab: Map<string, number>;
42
+ idToToken: Map<number, string>;
43
+ merges: Map<string, number>;
44
+ bosToken: string;
45
+ eosToken: string;
46
+ padToken: string;
47
+ unkToken: string;
48
+ bosId: number | null;
49
+ eosId: number | null;
50
+ padId: number | null;
51
+
52
+ constructor() {
53
+ this.vocab = new Map();
54
+ this.idToToken = new Map();
55
+ this.merges = new Map();
56
+ this.bosToken = '<|im_start|>';
57
+ this.eosToken = '<|im_end|>';
58
+ this.padToken = '<|endoftext|>';
59
+ this.unkToken = '<unk>';
60
+ this.bosId = null;
61
+ this.eosId = null;
62
+ this.padId = null;
63
+ }
64
+
65
+ async load(vocab: string | Record<string, number>, merges: string | string[]): Promise<void> {
66
+ let vocabObj: Record<string, number>;
67
+ if (typeof vocab === 'string') {
68
+ const res = await fetch(vocab);
69
+ vocabObj = await res.json() as Record<string, number>;
70
+ } else {
71
+ vocabObj = vocab;
72
+ }
73
+ this.vocab = new Map(Object.entries(vocabObj).map(([k, v]) => [k, Number(v)]));
74
+ this.idToToken = new Map([...this.vocab].map(([k, v]) => [v, k]));
75
+
76
+ let mergeLines: string[];
77
+ if (typeof merges === 'string') {
78
+ const res = await fetch(merges);
79
+ const txt = await res.text();
80
+ mergeLines = txt.split('\n').filter(l => l && !l.startsWith('#'));
81
+ } else {
82
+ mergeLines = merges;
83
+ }
84
+ this.merges = new Map();
85
+ mergeLines.forEach((line, rank) => {
86
+ this.merges.set(line.trim(), rank);
87
+ });
88
+
89
+ this.bosId = this.vocab.get(this.bosToken) ?? null;
90
+ this.eosId = this.vocab.get(this.eosToken) ?? null;
91
+ this.padId = this.vocab.get(this.padToken) ?? null;
92
+ }
93
+
94
+ loadFromObjects(vocabObj: Record<string, number>, mergeArr: string[]): void {
95
+ this.vocab = new Map(Object.entries(vocabObj).map(([k, v]) => [k, Number(v)]));
96
+ this.idToToken = new Map([...this.vocab].map(([k, v]) => [v, k]));
97
+ this.merges = new Map(mergeArr.map((m, i) => [m, i]));
98
+ this.bosId = this.vocab.get(this.bosToken) ?? null;
99
+ this.eosId = this.vocab.get(this.eosToken) ?? null;
100
+ this.padId = this.vocab.get(this.padToken) ?? null;
101
+ }
102
+
103
+ encode(text: string, opts: BPEEncodeOptions = {}): number[] {
104
+ const words = text.match(PRE_TOKENIZE_RE) ?? [];
105
+ const ids: number[] = [];
106
+
107
+ if (opts.addBos && this.bosId !== null) ids.push(this.bosId);
108
+
109
+ for (const word of words) {
110
+ const bytes = new TextEncoder().encode(word);
111
+ const byteStr = Array.from(bytes).map(b => BYTE_ENCODER.get(b) ?? '?').join('');
112
+ const bpeTokens = this._bpe(byteStr);
113
+
114
+ for (const tok of bpeTokens) {
115
+ const id = this.vocab.get(tok);
116
+ if (id !== undefined) {
117
+ ids.push(id);
118
+ } else {
119
+ for (const ch of tok) {
120
+ const cid = this.vocab.get(ch);
121
+ if (cid !== undefined) ids.push(cid);
122
+ }
123
+ }
124
+ }
125
+ }
126
+
127
+ if (opts.addEos && this.eosId !== null) ids.push(this.eosId);
128
+ return ids;
129
+ }
130
+
131
+ decode(ids: number[]): string {
132
+ let byteStr = '';
133
+ for (const id of ids) {
134
+ const tok = this.idToToken.get(id);
135
+ if (tok !== undefined) byteStr += tok;
136
+ }
137
+ const bytes = new Uint8Array(
138
+ [...byteStr].map(ch => BYTE_DECODER.get(ch) ?? ch.codePointAt(0) ?? 0)
139
+ );
140
+ try {
141
+ return new TextDecoder('utf-8').decode(bytes);
142
+ } catch {
143
+ return byteStr;
144
+ }
145
+ }
146
+
147
+ _bpe(word: string): string[] {
148
+ if (this.vocab.has(word)) return [word];
149
+
150
+ let symbols = [...word];
151
+
152
+ while (symbols.length > 1) {
153
+ let bestRank = Infinity;
154
+ let bestIdx = -1;
155
+
156
+ for (let i = 0; i < symbols.length - 1; i++) {
157
+ const pair = symbols[i] + ' ' + symbols[i + 1];
158
+ const rank = this.merges.get(pair);
159
+ if (rank !== undefined && rank < bestRank) {
160
+ bestRank = rank;
161
+ bestIdx = i;
162
+ }
163
+ }
164
+
165
+ if (bestIdx === -1) break;
166
+
167
+ const merged = symbols[bestIdx]! + symbols[bestIdx + 1]!;
168
+ symbols = [
169
+ ...symbols.slice(0, bestIdx),
170
+ merged,
171
+ ...symbols.slice(bestIdx + 2),
172
+ ];
173
+ }
174
+
175
+ return symbols;
176
+ }
177
+
178
+ padOrTruncate(ids: number[], maxLen: number, side: PadSide = 'right'): number[] {
179
+ if (ids.length >= maxLen) return ids.slice(0, maxLen);
180
+ const padId = this.padId ?? 0;
181
+ const pad = new Array<number>(maxLen - ids.length).fill(padId);
182
+ return side === 'right' ? [...ids, ...pad] : [...pad, ...ids];
183
+ }
184
+
185
+ get vocabSize(): number { return this.vocab.size; }
186
+ }