mambacode.js 1.0.0 → 1.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +198 -76
- package/dist/index.d.ts +18 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +18 -0
- package/dist/index.js.map +1 -0
- package/dist/kernels/activations.d.ts +3 -0
- package/dist/kernels/activations.d.ts.map +1 -0
- package/dist/kernels/activations.js +87 -0
- package/dist/kernels/activations.js.map +1 -0
- package/dist/kernels/conv1d.d.ts +3 -0
- package/dist/kernels/conv1d.d.ts.map +1 -0
- package/dist/kernels/conv1d.js +152 -0
- package/dist/kernels/conv1d.js.map +1 -0
- package/dist/kernels/linear_projection.d.ts +3 -0
- package/dist/kernels/linear_projection.d.ts.map +1 -0
- package/dist/kernels/linear_projection.js +219 -0
- package/dist/kernels/linear_projection.js.map +1 -0
- package/dist/kernels/selective_scan.d.ts +3 -0
- package/dist/kernels/selective_scan.d.ts.map +1 -0
- package/dist/kernels/selective_scan.js +348 -0
- package/dist/kernels/selective_scan.js.map +1 -0
- package/dist/kernels/weight_update.d.ts +3 -0
- package/dist/kernels/weight_update.d.ts.map +1 -0
- package/dist/kernels/weight_update.js +119 -0
- package/dist/kernels/weight_update.js.map +1 -0
- package/dist/model/mamba_block.d.ts +64 -0
- package/dist/model/mamba_block.d.ts.map +1 -0
- package/dist/model/mamba_block.js +309 -0
- package/dist/model/mamba_block.js.map +1 -0
- package/dist/model/mamba_model.d.ts +66 -0
- package/dist/model/mamba_model.d.ts.map +1 -0
- package/dist/model/mamba_model.js +289 -0
- package/dist/model/mamba_model.js.map +1 -0
- package/dist/tokenizer/bpe.d.ts +29 -0
- package/dist/tokenizer/bpe.d.ts.map +1 -0
- package/dist/tokenizer/bpe.js +164 -0
- package/dist/tokenizer/bpe.js.map +1 -0
- package/dist/training/autograd.d.ts +27 -0
- package/dist/training/autograd.d.ts.map +1 -0
- package/dist/training/autograd.js +120 -0
- package/dist/training/autograd.js.map +1 -0
- package/dist/training/trainer.d.ts +37 -0
- package/dist/training/trainer.d.ts.map +1 -0
- package/dist/training/trainer.js +183 -0
- package/dist/training/trainer.js.map +1 -0
- package/dist/utils/gpu_utils.d.ts +21 -0
- package/dist/utils/gpu_utils.d.ts.map +1 -0
- package/dist/utils/gpu_utils.js +111 -0
- package/dist/utils/gpu_utils.js.map +1 -0
- package/dist/utils/quantization.d.ts +26 -0
- package/dist/utils/quantization.d.ts.map +1 -0
- package/dist/utils/quantization.js +116 -0
- package/dist/utils/quantization.js.map +1 -0
- package/package.json +43 -18
- package/src/index.ts +59 -0
- package/src/kernels/{activations.js → activations.ts} +2 -2
- package/src/kernels/{linear_projection.js → linear_projection.ts} +2 -2
- package/src/kernels/{selective_scan.js → selective_scan.ts} +2 -2
- package/src/kernels/{weight_update.js → weight_update.ts} +2 -2
- package/src/model/{mamba_block.js → mamba_block.ts} +139 -175
- package/src/model/{mamba_model.js → mamba_model.ts} +168 -124
- package/src/tokenizer/bpe.ts +186 -0
- package/src/training/autograd.ts +135 -0
- package/src/training/trainer.ts +312 -0
- package/src/utils/gpu_utils.ts +147 -0
- package/src/utils/quantization.ts +154 -0
- package/src/index.js +0 -89
- package/src/tokenizer/bpe.js +0 -256
- package/src/training/autograd.js +0 -221
- package/src/training/trainer.js +0 -394
- package/src/utils/gpu_utils.js +0 -217
- package/src/utils/quantization.js +0 -215
- /package/src/kernels/{conv1d.js → conv1d.ts} +0 -0
|
@@ -1,15 +1,8 @@
|
|
|
1
1
|
/**
|
|
2
|
-
* mamba_model.
|
|
3
|
-
*
|
|
4
|
-
* Architecture (matches Qwen3.5-Coder-0.8B-style Mamba):
|
|
5
|
-
*
|
|
6
|
-
* Token IDs ──► Embedding ──► [MambaBlock × numLayers] ──► RMSNorm ──► LM Head
|
|
7
|
-
*
|
|
8
|
-
* The LM Head is a linear projection from dModel → vocabSize.
|
|
9
|
-
* All computations run on WebGPU via the kernels in src/kernels/.
|
|
2
|
+
* mamba_model.ts – Full Mamba language model.
|
|
10
3
|
*/
|
|
11
4
|
|
|
12
|
-
import { MambaBlock } from './mamba_block
|
|
5
|
+
import { MambaBlock, BlockCache, BlockParam } from './mamba_block';
|
|
13
6
|
import {
|
|
14
7
|
createStorageBuffer,
|
|
15
8
|
createEmptyStorageBuffer,
|
|
@@ -18,40 +11,60 @@ import {
|
|
|
18
11
|
createBindGroup,
|
|
19
12
|
dispatchKernel,
|
|
20
13
|
readBuffer,
|
|
14
|
+
uploadBuffer,
|
|
21
15
|
cdiv,
|
|
22
|
-
} from '../utils/gpu_utils
|
|
23
|
-
import { LINEAR_FORWARD_WGSL } from '../kernels/linear_projection
|
|
24
|
-
import { ACTIVATIONS_WGSL } from '../kernels/activations
|
|
16
|
+
} from '../utils/gpu_utils';
|
|
17
|
+
import { LINEAR_FORWARD_WGSL } from '../kernels/linear_projection';
|
|
18
|
+
import { ACTIVATIONS_WGSL } from '../kernels/activations';
|
|
19
|
+
|
|
20
|
+
/**
 * Construction options for MambaModel.
 *
 * The MambaModel constructor spreads the caller's config over its own
 * defaults, so the optional fields fall back to:
 * dState = 16, dConv = 4, expand = 2, eosId = -1.
 */
export interface MambaModelConfig {
|
|
21
|
+
// Number of rows in the token embedding table; also the number of logits per position.
vocabSize: number;
|
|
22
|
+
// Width of each embedding row and of the final RMSNorm weight vector.
dModel: number;
|
|
23
|
+
// Number of MambaBlock instances stacked by the constructor.
numLayers: number;
|
|
24
|
+
// Defaults to 16 when omitted.
dState?: number;
|
|
25
|
+
// Defaults to 4 when omitted.
dConv?: number;
|
|
26
|
+
// Defaults to 2 when omitted.
expand?: number;
|
|
27
|
+
// Token ID that makes generate() stop early; defaults to -1 when omitted.
eosId?: number;
|
|
28
|
+
}
|
|
25
29
|
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
30
|
+
/**
 * Value resolved by MambaModel.forward().
 */
export interface ModelForwardResult {
|
|
31
|
+
// CPU copy of the logits, read back from `gpuLogits` (batch * seqLen * vocabSize floats).
logits: Float32Array;
|
|
32
|
+
// GPU buffer holding the same logits; forward() returns it without destroying it.
gpuLogits: GPUBuffer;
|
|
33
|
+
// One cache per block, in the order the blocks were executed.
caches: BlockCache[];
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
/**
 * Sampling controls accepted by MambaModel.generate().
 * Omitted fields default to temperature = 1.0, topK = 50, topP = 0.9.
 */
export interface SamplingOptions {
|
|
37
|
+
// Divisor applied to the logits before the softmax (floored at 1e-7 by the sampler).
temperature?: number;
|
|
38
|
+
// Number of most-probable tokens kept as candidates.
topK?: number;
|
|
39
|
+
// Cumulative-probability threshold at which the candidate list is cut off.
topP?: number;
|
|
40
|
+
}
|
|
35
41
|
|
|
36
42
|
export class MambaModel {
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
43
|
+
device: GPUDevice;
|
|
44
|
+
config: Required<MambaModelConfig>;
|
|
45
|
+
gpuEmbedding: GPUBuffer;
|
|
46
|
+
blocks: MambaBlock[];
|
|
47
|
+
gpuFinalNorm: GPUBuffer;
|
|
48
|
+
tiedEmbedding: boolean;
|
|
49
|
+
gpuLMHeadBias: GPUBuffer;
|
|
50
|
+
private _lmHeadPipeline: GPUComputePipeline;
|
|
51
|
+
private _rmsnormPipeline: GPUComputePipeline;
|
|
52
|
+
private _embedPipeline: GPUComputePipeline;
|
|
53
|
+
private _wslaMode = false;
|
|
54
|
+
|
|
55
|
+
constructor(device: GPUDevice, config: MambaModelConfig) {
|
|
42
56
|
this.device = device;
|
|
43
57
|
this.config = {
|
|
44
58
|
dState : 16,
|
|
45
59
|
dConv : 4,
|
|
46
60
|
expand : 2,
|
|
61
|
+
eosId : -1,
|
|
47
62
|
...config,
|
|
48
|
-
}
|
|
63
|
+
} as Required<MambaModelConfig>;
|
|
49
64
|
|
|
50
65
|
const { vocabSize, dModel, numLayers } = this.config;
|
|
51
66
|
|
|
52
|
-
// Token embedding table: (vocabSize, dModel)
|
|
53
67
|
const embedData = new Float32Array(vocabSize * dModel);
|
|
54
|
-
// Xavier-style initialisation
|
|
55
68
|
const std = 1.0 / Math.sqrt(dModel);
|
|
56
69
|
for (let i = 0; i < embedData.length; i++) {
|
|
57
70
|
const u1 = Math.random(), u2 = Math.random();
|
|
@@ -60,7 +73,6 @@ export class MambaModel {
|
|
|
60
73
|
}
|
|
61
74
|
this.gpuEmbedding = createStorageBuffer(device, embedData, true);
|
|
62
75
|
|
|
63
|
-
// Stacked Mamba blocks
|
|
64
76
|
this.blocks = Array.from({ length: numLayers }, () =>
|
|
65
77
|
new MambaBlock(device, {
|
|
66
78
|
dModel,
|
|
@@ -70,36 +82,20 @@ export class MambaModel {
|
|
|
70
82
|
})
|
|
71
83
|
);
|
|
72
84
|
|
|
73
|
-
// Final RMSNorm
|
|
74
85
|
const finalNormW = new Float32Array(dModel).fill(1.0);
|
|
75
86
|
this.gpuFinalNorm = createStorageBuffer(device, finalNormW, true);
|
|
76
87
|
|
|
77
|
-
// LM Head: (vocabSize, dModel) – tied to embedding by default
|
|
78
|
-
// We share the embedding weight (weight tying saves memory).
|
|
79
88
|
this.tiedEmbedding = true;
|
|
80
89
|
|
|
81
|
-
// Compile pipelines
|
|
82
90
|
this._lmHeadPipeline = createComputePipeline(device, LINEAR_FORWARD_WGSL, 'linear_forward');
|
|
83
91
|
this._rmsnormPipeline = createComputePipeline(device, ACTIVATIONS_WGSL, 'rmsnorm_forward');
|
|
84
92
|
|
|
85
|
-
// LM Head bias (zeroed)
|
|
86
93
|
this.gpuLMHeadBias = createStorageBuffer(device, new Float32Array(vocabSize), true);
|
|
87
94
|
|
|
88
|
-
// Embedding lookup pipeline (gather rows)
|
|
89
95
|
this._embedPipeline = createComputePipeline(device, EMBED_LOOKUP_WGSL, 'embed_lookup');
|
|
90
96
|
}
|
|
91
97
|
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
/**
|
|
95
|
-
* Look up token embeddings.
|
|
96
|
-
*
|
|
97
|
-
* @param {Int32Array|Uint32Array} tokenIds – (batch * seqLen,)
|
|
98
|
-
* @param {number} batch
|
|
99
|
-
* @param {number} seqLen
|
|
100
|
-
* @returns {GPUBuffer} – (batch * seqLen, dModel)
|
|
101
|
-
*/
|
|
102
|
-
embedTokens(tokenIds, batch, seqLen) {
|
|
98
|
+
embedTokens(tokenIds: number[] | Uint32Array, batch: number, seqLen: number): GPUBuffer {
|
|
103
99
|
const { dModel } = this.config;
|
|
104
100
|
const M = batch * seqLen;
|
|
105
101
|
|
|
@@ -119,27 +115,13 @@ export class MambaModel {
|
|
|
119
115
|
return outBuf;
|
|
120
116
|
}
|
|
121
117
|
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
/**
|
|
125
|
-
* Full model forward pass.
|
|
126
|
-
*
|
|
127
|
-
* @param {number[]|Uint32Array} tokenIds – (batch * seqLen,) flat
|
|
128
|
-
* @param {number} batch
|
|
129
|
-
* @param {number} seqLen
|
|
130
|
-
* @returns {Promise<{ logits: Float32Array, gpuLogits: GPUBuffer }>}
|
|
131
|
-
* logits – CPU Float32Array of shape (batch * seqLen, vocabSize)
|
|
132
|
-
* gpuLogits – GPU buffer (same data, for chained backward)
|
|
133
|
-
*/
|
|
134
|
-
async forward(tokenIds, batch, seqLen) {
|
|
118
|
+
async forward(tokenIds: number[] | Uint32Array, batch: number, seqLen: number): Promise<ModelForwardResult> {
|
|
135
119
|
const { dModel, vocabSize } = this.config;
|
|
136
120
|
const M = batch * seqLen;
|
|
137
121
|
|
|
138
|
-
// 1. Token embedding lookup
|
|
139
122
|
let hidden = this.embedTokens(tokenIds, batch, seqLen);
|
|
140
123
|
|
|
141
|
-
|
|
142
|
-
const caches = [];
|
|
124
|
+
const caches: BlockCache[] = [];
|
|
143
125
|
for (const block of this.blocks) {
|
|
144
126
|
const { output, cache } = block.forward(hidden, batch, seqLen);
|
|
145
127
|
caches.push(cache);
|
|
@@ -147,7 +129,6 @@ export class MambaModel {
|
|
|
147
129
|
hidden = output;
|
|
148
130
|
}
|
|
149
131
|
|
|
150
|
-
// 3. Final RMSNorm
|
|
151
132
|
const normOut = createEmptyStorageBuffer(this.device, M * dModel * 4, true);
|
|
152
133
|
const normInv = createEmptyStorageBuffer(this.device, M * 4, false);
|
|
153
134
|
{
|
|
@@ -160,12 +141,11 @@ export class MambaModel {
|
|
|
160
141
|
dispatchKernel(this.device, this._rmsnormPipeline, bg, [cdiv(M, 64), 1, 1]);
|
|
161
142
|
}
|
|
162
143
|
|
|
163
|
-
// 4. LM Head: (M, vocabSize) = normOut @ embedding^T + bias
|
|
164
144
|
const gpuLogits = createEmptyStorageBuffer(this.device, M * vocabSize * 4, true);
|
|
165
145
|
{
|
|
166
146
|
const params = new Uint32Array([M, dModel, vocabSize]).buffer;
|
|
167
147
|
const pBuf = createUniformBuffer(this.device, params);
|
|
168
|
-
const weightBuf = this.tiedEmbedding ? this.gpuEmbedding : this.
|
|
148
|
+
const weightBuf = this.tiedEmbedding ? this.gpuEmbedding : this.gpuLMHeadBias;
|
|
169
149
|
const bg = createBindGroup(this.device, this._lmHeadPipeline,
|
|
170
150
|
[pBuf, normOut, weightBuf, this.gpuLMHeadBias, gpuLogits]);
|
|
171
151
|
dispatchKernel(this.device, this._lmHeadPipeline, bg,
|
|
@@ -175,66 +155,47 @@ export class MambaModel {
|
|
|
175
155
|
normOut.destroy();
|
|
176
156
|
normInv.destroy();
|
|
177
157
|
|
|
178
|
-
// 5. Read back logits to CPU
|
|
179
158
|
const logits = await readBuffer(this.device, gpuLogits, M * vocabSize * 4);
|
|
180
159
|
|
|
181
160
|
return { logits, gpuLogits, caches };
|
|
182
161
|
}
|
|
183
162
|
|
|
184
|
-
|
|
185
|
-
* Greedy / top-k / temperature-sampled autoregressive generation.
|
|
186
|
-
*
|
|
187
|
-
* @param {number[]} promptIds – starting token IDs
|
|
188
|
-
* @param {number} maxNewTokens
|
|
189
|
-
* @param {{ temperature?: number, topK?: number, topP?: number }} [samplingOpts]
|
|
190
|
-
* @returns {Promise<number[]>} – full sequence (prompt + generated)
|
|
191
|
-
*/
|
|
192
|
-
async generate(promptIds, maxNewTokens = 200, samplingOpts = {}) {
|
|
163
|
+
/**
 * Autoregressively extends `promptIds` by sampling up to `maxNewTokens` tokens.
 *
 * Every step re-runs `forward()` over the whole sequence so far (batch of 1)
 * and samples the next token from the logits of the last position.
 * Generation stops early once the sampled token equals `config.eosId`;
 * that token is still included in the returned sequence.
 *
 * @param promptIds    Starting token IDs (not mutated). Must be non-empty
 *                     whenever at least one token is to be generated.
 * @param maxNewTokens Upper bound on the number of generated tokens.
 * @param samplingOpts Temperature / top-k / top-p settings for the sampler.
 * @returns The prompt followed by the generated token IDs.
 * @throws {Error} if tokens are requested for an empty prompt.
 */
async generate(promptIds: number[], maxNewTokens = 200, samplingOpts: SamplingOptions = {}): Promise<number[]> {
  const { temperature = 1.0, topK = 50, topP = 0.9 } = samplingOpts;
  const { vocabSize } = this.config;

  const ids = [...promptIds];

  // A zero-length forward pass has no "last position" to sample from.
  if (ids.length === 0 && maxNewTokens > 0) {
    throw new Error('generate() requires at least one prompt token.');
  }

  for (let step = 0; step < maxNewTokens; step++) {
    const { logits, gpuLogits } = await this.forward(
      new Uint32Array(ids), 1, ids.length
    );

    // forward() allocates a fresh logits buffer per call and returns it
    // undestroyed; only the CPU copy is needed here, so release it now
    // instead of leaking one GPU buffer per generated token.
    // NOTE(review): the returned caches may also hold GPU buffers — confirm
    // whether MambaBlock offers a way to release them.
    gpuLogits.destroy();

    // Logits of the last position only (a view, no copy).
    const lastLogits = logits.subarray((ids.length - 1) * vocabSize, ids.length * vocabSize);

    const nextId = sampleToken(lastLogits, { temperature, topK, topP });
    ids.push(nextId);

    if (nextId === this.config.eosId) break;
  }

  return ids;
}
|
|
215
183
|
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
* @returns {Array<{buf: GPUBuffer, numel: number, name: string}>}
|
|
219
|
-
*/
|
|
220
|
-
parameters() {
|
|
221
|
-
const params = [];
|
|
184
|
+
parameters(): BlockParam[] {
|
|
185
|
+
const params: BlockParam[] = [];
|
|
222
186
|
|
|
223
|
-
// Embedding
|
|
224
187
|
params.push({
|
|
225
188
|
buf : this.gpuEmbedding,
|
|
226
189
|
numel: this.config.vocabSize * this.config.dModel,
|
|
227
190
|
name : 'embedding',
|
|
228
191
|
});
|
|
229
192
|
|
|
230
|
-
// Blocks
|
|
231
193
|
for (let i = 0; i < this.blocks.length; i++) {
|
|
232
|
-
for (const p of this.blocks[i]
|
|
194
|
+
for (const p of this.blocks[i]!.parameters()) {
|
|
233
195
|
params.push({ ...p, name: `block${i}.${p.name}` });
|
|
234
196
|
}
|
|
235
197
|
}
|
|
236
198
|
|
|
237
|
-
// Final norm
|
|
238
199
|
params.push({
|
|
239
200
|
buf : this.gpuFinalNorm,
|
|
240
201
|
numel: this.config.dModel,
|
|
@@ -244,19 +205,117 @@ export class MambaModel {
|
|
|
244
205
|
return params;
|
|
245
206
|
}
|
|
246
207
|
|
|
247
|
-
|
|
248
|
-
* Enable WSLA (selective fine-tuning of B and C only) across all blocks.
|
|
249
|
-
* @param {boolean} enabled
|
|
250
|
-
*/
|
|
251
|
-
setWSLAMode(enabled) {
|
|
208
|
+
/**
 * Turns WSLA mode (selective fine-tuning of B and C only) on or off for
 * every block, then records the setting on the model itself.
 */
setWSLAMode(enabled: boolean): void {
  this.blocks.forEach((blk) => {
    blk.setWSLAMode(enabled);
  });
  this._wslaMode = enabled;
}
|
|
255
|
-
}
|
|
256
212
|
|
|
257
|
-
|
|
213
|
+
/**
 * Serialise all model parameters to an ArrayBuffer.
 *
 * Binary format (header fields are little-endian):
 *   [0..3]   magic  : uint32 = 0x4D424A53 ('MBJS')
 *   [4..7]   version: uint32 = 1
 *   [8..11]  nParams: uint32
 *   [12 .. 12+4*nParams-1]  numel[i]: uint32 for each parameter i
 *   [12+4*nParams ..]       float32 data for each parameter, concatenated;
 *                           exactly numel[i] floats per parameter, written
 *                           through a Float32Array view (platform byte order)
 *
 * Save the returned buffer to a file or IndexedDB and reload it with
 * `loadWeights()` to resume from a checkpoint.
 *
 * @returns A buffer holding the header followed by all parameter data.
 * @throws {Error} if a GPU readback yields fewer elements than the
 *   parameter declares.
 */
async exportWeights(): Promise<ArrayBuffer> {
  const params = this.parameters();
  const nParams = params.length;

  // Read all GPU buffers into CPU Float32Arrays
  const arrays: Float32Array[] = await Promise.all(
    params.map(p => readBuffer(this.device, p.buf, p.numel * 4))
  );

  // Size the data section from the declared element counts so that it always
  // agrees with the numel table, whatever length the readback arrays have.
  const headerBytes = 4 + 4 + 4 + nParams * 4; // magic + version + nParams + numel[]
  const totalNumel = params.reduce((acc, p) => acc + p.numel, 0);
  const out = new ArrayBuffer(headerBytes + totalNumel * 4);
  const view = new DataView(out);

  let offset = 0;
  view.setUint32(offset, 0x4D424A53, true); offset += 4; // magic 'MBJS'
  view.setUint32(offset, 1, true);          offset += 4; // version
  view.setUint32(offset, nParams, true);    offset += 4; // nParams

  for (const p of params) {
    view.setUint32(offset, p.numel, true);
    offset += 4;
  }

  for (let i = 0; i < nParams; i++) {
    const p = params[i]!;
    const arr = arrays[i]!;
    if (arr.length < p.numel) {
      throw new Error(
        `Parameter ${i} ("${p.name}") readback returned ${arr.length} elements, ` +
        `expected ${p.numel}.`
      );
    }
    // Write exactly numel floats; any surplus in the readback is ignored.
    new Float32Array(out, offset, p.numel).set(arr.subarray(0, p.numel));
    offset += p.numel * 4;
  }

  return out;
}
|
|
258
|
+
|
|
259
|
+
/**
 * Load model parameters from an ArrayBuffer previously produced by
 * `exportWeights()`. The parameter count and element counts must match
 * the current model configuration exactly.
 *
 * The whole file is validated before anything is uploaded, so a rejected
 * file leaves the model's current weights untouched. Bytes beyond the last
 * parameter are ignored.
 *
 * @param buffer Serialised weights (header + numel table + float32 data).
 * @throws {Error} if the buffer is truncated, or if the magic number,
 *   version, or parameter layout do not match the current model.
 */
async loadWeights(buffer: ArrayBuffer): Promise<void> {
  const FIXED_HEADER_BYTES = 12; // magic + version + nParams

  if (buffer.byteLength < FIXED_HEADER_BYTES) {
    throw new Error('Invalid weight file: too short to contain a header.');
  }

  const view = new DataView(buffer);

  const magic = view.getUint32(0, true);
  if (magic !== 0x4D424A53) {
    throw new Error(
      'Invalid weight file: bad magic number. ' +
      'Ensure the file was exported by MambaModel.exportWeights().'
    );
  }

  const version = view.getUint32(4, true);
  if (version !== 1) {
    throw new Error(`Unsupported weight file version: ${version}. Expected version 1.`);
  }

  const nParams = view.getUint32(8, true);
  const params = this.parameters();

  if (nParams !== params.length) {
    throw new Error(
      `Weight file has ${nParams} parameters but this model has ${params.length}. ` +
      'Ensure the model configuration matches the one used when exporting.'
    );
  }

  const dataStart = FIXED_HEADER_BYTES + nParams * 4;
  if (buffer.byteLength < dataStart) {
    throw new Error('Invalid weight file: truncated parameter size table.');
  }

  // Pass 1: validate every declared size and the total length. Nothing is
  // uploaded yet, so a bad file cannot leave the model half-overwritten.
  let expectedBytes = dataStart;
  for (let i = 0; i < nParams; i++) {
    // i is in bounds: nParams === params.length was verified above
    const p = params[i]!;
    const numel = view.getUint32(FIXED_HEADER_BYTES + i * 4, true);
    if (numel !== p.numel) {
      throw new Error(
        `Parameter ${i} ("${p.name}") size mismatch: ` +
        `file has ${numel} elements, model expects ${p.numel}.`
      );
    }
    expectedBytes += numel * 4;
  }

  if (buffer.byteLength < expectedBytes) {
    throw new Error(
      `Invalid weight file: expected at least ${expectedBytes} bytes ` +
      `but got ${buffer.byteLength}.`
    );
  }

  // Pass 2: upload. dataStart is a multiple of 4, so every Float32Array view
  // below is correctly aligned.
  let offset = dataStart;
  for (const p of params) {
    const slice = new Float32Array(buffer, offset, p.numel);
    uploadBuffer(this.device, p.buf, slice);
    offset += p.numel * 4;
  }
}
|
|
316
|
+
}
|
|
317
|
+
|
|
318
|
+
const EMBED_LOOKUP_WGSL: string = /* wgsl */`
|
|
260
319
|
struct EmbedParams {
|
|
261
320
|
num_tokens : u32,
|
|
262
321
|
d_model : u32,
|
|
@@ -264,8 +323,8 @@ struct EmbedParams {
|
|
|
264
323
|
|
|
265
324
|
@group(0) @binding(0) var<uniform> params : EmbedParams;
|
|
266
325
|
@group(0) @binding(1) var<storage, read> ids : array<u32>;
|
|
267
|
-
@group(0) @binding(2) var<storage, read> table : array<f32>;
|
|
268
|
-
@group(0) @binding(3) var<storage, read_write> out : array<f32>;
|
|
326
|
+
@group(0) @binding(2) var<storage, read> table : array<f32>;
|
|
327
|
+
@group(0) @binding(3) var<storage, read_write> out : array<f32>;
|
|
269
328
|
|
|
270
329
|
@compute @workgroup_size(64, 1, 1)
|
|
271
330
|
fn embed_lookup(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
@@ -283,53 +342,38 @@ fn embed_lookup(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
283
342
|
}
|
|
284
343
|
`;
|
|
285
344
|
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
/**
|
|
289
|
-
* Sample a token from logits using temperature + top-k + nucleus (top-p).
|
|
290
|
-
*
|
|
291
|
-
* @param {Float32Array} logits
|
|
292
|
-
* @param {{ temperature?: number, topK?: number, topP?: number }} opts
|
|
293
|
-
* @returns {number}
|
|
294
|
-
*/
|
|
295
|
-
function sampleToken(logits, { temperature = 1.0, topK = 50, topP = 0.9 } = {}) {
|
|
345
|
+
/**
 * Sample one token index from `logits` using temperature scaling followed by
 * top-k and nucleus (top-p) filtering.
 *
 * Steps: divide the logits by the temperature (floored at 1e-7), take a
 * numerically stable softmax, rank tokens by probability, keep the `topK`
 * best, cut that list where the cumulative probability first reaches `topP`,
 * and draw from what is left in proportion to probability.
 *
 * @param logits      Unnormalised scores, one per vocabulary entry; not modified.
 * @param temperature Softmax temperature; values near 0 make sampling greedy.
 * @param topK        Number of candidates to keep. Values below 1 (or NaN)
 *                    disable the top-k filter instead of emptying the list.
 * @param topP        Cumulative probability at which the candidate list is cut.
 * @returns Index of the sampled token, always within `[0, logits.length)`.
 * @throws {RangeError} if `logits` is empty.
 */
function sampleToken(logits: Float32Array, { temperature = 1.0, topK = 50, topP = 0.9 } = {}): number {
  const n = logits.length;
  if (n === 0) {
    throw new RangeError('sampleToken: logits must not be empty.');
  }

  // Temperature scaling (the divisor is loop-invariant).
  const divisor = Math.max(temperature, 1e-7);
  const scaled = new Float32Array(n);
  for (let i = 0; i < n; i++) scaled[i] = logits[i]! / divisor;

  // Softmax numerators, shifted by the maximum for numerical stability.
  let maxL = -Infinity;
  for (let i = 0; i < n; i++) if (scaled[i]! > maxL) maxL = scaled[i]!;
  let sumE = 0;
  const exps = new Float32Array(n);
  for (let i = 0; i < n; i++) { exps[i] = Math.exp(scaled[i]! - maxL); sumE += exps[i]!; }

  // Token indices ordered by descending probability.
  const indices = Array.from({ length: n }, (_, i) => i)
    .sort((a, b) => exps[b]! - exps[a]!);

  // Top-k: always keep at least one candidate so a valid index can be returned.
  const k = topK >= 1 ? Math.min(Math.floor(topK), n) : n;

  // Nucleus: extend the candidate list until its probability mass reaches topP.
  let cumSum = 0;
  let nucleusSum = 0;
  const nucleus: number[] = [];
  for (let r = 0; r < k; r++) {
    const idx = indices[r]!;
    cumSum += exps[idx]! / sumE;
    nucleusSum += exps[idx]!;
    nucleus.push(idx);
    if (cumSum >= topP) break;
  }

  // Draw from the nucleus in proportion to the unnormalised probabilities.
  const threshold = Math.random() * nucleusSum;
  let acc = 0;
  for (const idx of nucleus) {
    acc += exps[idx]!;
    if (acc >= threshold) return idx;
  }
  // Reached only through floating-point round-off or NaN logits.
  return nucleus[nucleus.length - 1]!;
}
|
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* bpe.ts – Browser-side Byte Pair Encoding (BPE) tokenizer.
|
|
3
|
+
*/
|
|
4
|
+
|
|
5
|
+
/**
 * Flags accepted by BPETokenizer.encode().
 */
export interface BPEEncodeOptions {
|
|
6
|
+
// Prepend the BOS token ID, when the loaded vocabulary defines one.
addBos?: boolean;
|
|
7
|
+
// Append the EOS token ID, when the loaded vocabulary defines one.
addEos?: boolean;
|
|
8
|
+
}
|
|
9
|
+
|
|
10
|
+
// Side of the sequence on which BPETokenizer.padOrTruncate() inserts padding.
export type PadSide = 'right' | 'left';
|
|
11
|
+
|
|
12
|
+
/**
 * Builds the byte → single-character table used for byte-level BPE.
 *
 * Bytes in 0x21–0x7E, 0xA1–0xAC and 0xAE–0xFF map to the character with the
 * same code point. Every other byte is assigned, in ascending byte order,
 * the next unused code point starting at 256, so all 256 bytes end up with
 * a distinct character.
 */
function buildByteEncoder(): Map<number, string> {
  const table = new Map<number, string>();

  const keepsOwnCodePoint = (byte: number): boolean =>
    (byte >= 0x21 && byte <= 0x7E) ||
    (byte >= 0xA1 && byte <= 0xAC) ||
    (byte >= 0xAE && byte <= 0xFF);

  // First pass: bytes that are represented by themselves.
  for (let byte = 0; byte < 256; byte++) {
    if (keepsOwnCodePoint(byte)) {
      table.set(byte, String.fromCodePoint(byte));
    }
  }

  // Second pass: remaining bytes are shifted above the Latin-1 range.
  let nextCodePoint = 256;
  for (let byte = 0; byte < 256; byte++) {
    if (table.has(byte)) continue;
    table.set(byte, String.fromCodePoint(nextCodePoint));
    nextCodePoint += 1;
  }

  return table;
}
|
|
33
|
+
|
|
34
|
+
// Byte value → stand-in character, built once at module load.
const BYTE_ENCODER = buildByteEncoder();
|
|
35
|
+
// Inverse table: stand-in character → byte value.
const BYTE_DECODER = new Map([...BYTE_ENCODER].map(([k, v]) => [v, k]));
|
|
36
|
+
|
|
37
|
+
// Pre-tokenisation pattern (global, Unicode-aware). Alternatives, in order:
//   1. the contractions 's 't 're 've 'm 'll 'd (matched case-sensitively)
//   2. a run of letters, optionally preceded by one character that is not a
//      letter, digit, CR or LF
//   3. a run of one to three digits
//   4. a run of non-space, non-letter, non-digit characters, optionally
//      preceded by a space and followed by any CR/LF characters
//   5. optional whitespace followed by one or more CR/LF characters
//   6. whitespace not followed by a non-space character
//   7. any remaining whitespace
const PRE_TOKENIZE_RE =
|
|
38
|
+
/(?:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+/gu;
|
|
39
|
+
|
|
40
|
+
/**
 * Byte-level Byte Pair Encoding tokenizer.
 *
 * Text is split with PRE_TOKENIZE_RE, each piece is converted to UTF-8 bytes,
 * every byte is replaced by its stand-in character from BYTE_ENCODER, and the
 * resulting string is merged pair by pair according to the loaded merge ranks.
 * Call `load()` or `loadFromObjects()` before encoding or decoding.
 */
export class BPETokenizer {
  /** Token string → token ID. */
  vocab: Map<string, number>;
  /** Token ID → token string (inverse of `vocab`). */
  idToToken: Map<number, string>;
  /** Merge rule ("left right") → rank; a lower rank is merged first. */
  merges: Map<string, number>;
  bosToken: string;
  eosToken: string;
  padToken: string;
  unkToken: string;
  /** IDs of the special tokens, or null when the vocabulary lacks them. */
  bosId: number | null;
  eosId: number | null;
  padId: number | null;

  constructor() {
    this.vocab = new Map();
    this.idToToken = new Map();
    this.merges = new Map();
    this.bosToken = '<|im_start|>';
    this.eosToken = '<|im_end|>';
    this.padToken = '<|endoftext|>';
    this.unkToken = '<unk>';
    this.bosId = null;
    this.eosId = null;
    this.padId = null;
  }

  /**
   * Loads the vocabulary and merge rules.
   *
   * @param vocab  URL of a JSON file mapping token → ID, or the mapping itself.
   * @param merges URL of a merges text file (one "left right" rule per line;
   *               empty lines and lines starting with '#' are skipped), or
   *               the rules as an array.
   * @throws {Error} if a URL is given and the HTTP response is not OK.
   */
  async load(vocab: string | Record<string, number>, merges: string | string[]): Promise<void> {
    let vocabObj: Record<string, number>;
    if (typeof vocab === 'string') {
      const res = await fetch(vocab);
      // Without this check an error page would be parsed as the vocabulary.
      if (!res.ok) {
        throw new Error(`Failed to fetch vocabulary from "${vocab}": HTTP ${res.status} ${res.statusText}`);
      }
      vocabObj = await res.json() as Record<string, number>;
    } else {
      vocabObj = vocab;
    }

    let mergeLines: string[];
    if (typeof merges === 'string') {
      const res = await fetch(merges);
      if (!res.ok) {
        throw new Error(`Failed to fetch merges from "${merges}": HTTP ${res.status} ${res.statusText}`);
      }
      const txt = await res.text();
      mergeLines = txt.split('\n').filter(l => l && !l.startsWith('#'));
    } else {
      mergeLines = merges;
    }

    this._install(vocabObj, mergeLines);
  }

  /**
   * Synchronous variant of `load()` for data that is already in memory.
   * Merge rules are trimmed exactly as `load()` trims them.
   */
  loadFromObjects(vocabObj: Record<string, number>, mergeArr: string[]): void {
    this._install(vocabObj, mergeArr);
  }

  /** Rebuilds every lookup table from a vocabulary object and ordered merge rules. */
  private _install(vocabObj: Record<string, number>, mergeLines: string[]): void {
    this.vocab = new Map(
      Object.entries(vocabObj).map(([tok, id]): [string, number] => [tok, Number(id)])
    );
    this.idToToken = new Map(
      [...this.vocab].map(([tok, id]): [number, string] => [id, tok])
    );

    this.merges = new Map();
    mergeLines.forEach((line, rank) => {
      // trim() also removes the '\r' left behind by CRLF line endings
      this.merges.set(line.trim(), rank);
    });

    this.bosId = this.vocab.get(this.bosToken) ?? null;
    this.eosId = this.vocab.get(this.eosToken) ?? null;
    this.padId = this.vocab.get(this.padToken) ?? null;
  }

  /**
   * Converts text to token IDs.
   *
   * A merged symbol that is missing from the vocabulary is broken back into
   * its characters. A character that is missing as well is emitted as the
   * `unkToken` ID when the vocabulary defines one, and skipped otherwise.
   *
   * @param text Input text.
   * @param opts Whether to add the BOS / EOS IDs (only if the vocabulary has them).
   * @returns The token IDs in order.
   */
  encode(text: string, opts: BPEEncodeOptions = {}): number[] {
    const words = text.match(PRE_TOKENIZE_RE) ?? [];
    const ids: number[] = [];
    const utf8 = new TextEncoder(); // one encoder for all pieces
    const unkId = this.vocab.get(this.unkToken);

    if (opts.addBos && this.bosId !== null) ids.push(this.bosId);

    for (const word of words) {
      const byteStr = Array.from(utf8.encode(word), b => BYTE_ENCODER.get(b) ?? '?').join('');

      for (const tok of this._bpe(byteStr)) {
        const id = this.vocab.get(tok);
        if (id !== undefined) {
          ids.push(id);
          continue;
        }
        for (const ch of tok) {
          const cid = this.vocab.get(ch) ?? unkId;
          if (cid !== undefined) ids.push(cid);
        }
      }
    }

    if (opts.addEos && this.eosId !== null) ids.push(this.eosId);
    return ids;
  }

  /**
   * Converts token IDs back to text. Unknown IDs are skipped.
   *
   * Characters that belong to the byte alphabet are mapped back to their
   * byte. Any other character (for example inside a token stored verbatim in
   * the vocabulary) contributes its own UTF-8 bytes rather than being
   * truncated to a single byte. Invalid UTF-8 sequences decode to U+FFFD.
   */
  decode(ids: number[]): string {
    let byteStr = '';
    for (const id of ids) {
      const tok = this.idToToken.get(id);
      if (tok !== undefined) byteStr += tok;
    }

    const utf8 = new TextEncoder();
    const bytes: number[] = [];
    for (const ch of byteStr) {
      const b = BYTE_DECODER.get(ch);
      if (b !== undefined) {
        bytes.push(b);
      } else {
        bytes.push(...utf8.encode(ch));
      }
    }

    // A non-fatal TextDecoder never throws, so no fallback path is needed.
    return new TextDecoder('utf-8').decode(new Uint8Array(bytes));
  }

  /**
   * Applies the merge rules to one byte-encoded piece.
   *
   * @param word Piece in byte-alphabet form.
   * @returns The piece itself if it is a vocabulary entry; otherwise its
   *   symbols after repeatedly merging the leftmost adjacent pair with the
   *   lowest rank until no rule applies.
   */
  _bpe(word: string): string[] {
    if (this.vocab.has(word)) return [word];

    const symbols = [...word];

    while (symbols.length > 1) {
      let bestRank = Infinity;
      let bestIdx = -1;

      for (let i = 0; i < symbols.length - 1; i++) {
        const rank = this.merges.get(symbols[i] + ' ' + symbols[i + 1]);
        if (rank !== undefined && rank < bestRank) {
          bestRank = rank;
          bestIdx = i;
        }
      }

      if (bestIdx === -1) break;

      // Replace the winning pair with its concatenation, in place.
      symbols.splice(bestIdx, 2, symbols[bestIdx]! + symbols[bestIdx + 1]!);
    }

    return symbols;
  }

  /**
   * Returns a copy of `ids` that is exactly `maxLen` long.
   *
   * Longer inputs keep their first `maxLen` IDs. Shorter inputs are padded
   * with `padId` (0 when the vocabulary has no pad token) on the given side.
   * A negative `maxLen` is treated as 0.
   */
  padOrTruncate(ids: number[], maxLen: number, side: PadSide = 'right'): number[] {
    // slice(0, negative) would drop IDs from the tail instead of returning []
    const target = Math.max(0, maxLen);
    if (ids.length >= target) return ids.slice(0, target);
    const padId = this.padId ?? 0;
    const pad = new Array<number>(target - ids.length).fill(padId);
    return side === 'right' ? [...ids, ...pad] : [...pad, ...ids];
  }

  /** Number of entries in the loaded vocabulary. */
  get vocabSize(): number { return this.vocab.size; }
}
|