mambacode.js 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +196 -0
- package/package.json +54 -0
- package/src/index.js +89 -0
- package/src/kernels/activations.js +88 -0
- package/src/kernels/conv1d.js +153 -0
- package/src/kernels/linear_projection.js +220 -0
- package/src/kernels/selective_scan.js +350 -0
- package/src/kernels/weight_update.js +120 -0
- package/src/model/mamba_block.js +443 -0
- package/src/model/mamba_model.js +335 -0
- package/src/tokenizer/bpe.js +256 -0
- package/src/training/autograd.js +221 -0
- package/src/training/trainer.js +394 -0
- package/src/utils/gpu_utils.js +217 -0
- package/src/utils/quantization.js +215 -0
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* gpu_utils.js – WebGPU device management and buffer helpers.
|
|
3
|
+
*
|
|
4
|
+
* Provides thin, consistent wrappers around the WebGPU API so that
|
|
5
|
+
* the rest of MambaCode.js never calls navigator.gpu directly.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
/**
 * Initialise WebGPU and return the { device, adapter } pair.
 *
 * @param {{ powerPreference?: 'high-performance'|'low-power' }} [opts]
 * @returns {Promise<{ device: GPUDevice, adapter: GPUAdapter }>}
 * @throws {Error} When WebGPU is unavailable or no adapter can be acquired.
 */
export async function initWebGPU(opts = {}) {
  const gpu = typeof navigator !== 'undefined' ? navigator.gpu : undefined;
  if (!gpu) {
    throw new Error(
      'WebGPU is not available in this environment. ' +
      'Use Chrome 113+, Edge 113+, or Firefox Nightly with WebGPU enabled.'
    );
  }

  const adapter = await gpu.requestAdapter({
    powerPreference: opts.powerPreference ?? 'high-performance',
  });
  if (!adapter) {
    throw new Error('Failed to acquire a GPUAdapter. Your GPU may not support WebGPU.');
  }

  // Cap every requested limit at what the adapter actually supports;
  // requesting more than the adapter allows would reject the device request.
  const limits = adapter.limits;
  const threeGiB = 3 * 1024 * 1024 * 1024;
  const device = await adapter.requestDevice({
    requiredLimits: {
      maxBufferSize: Math.min(threeGiB, limits.maxBufferSize),
      maxStorageBufferBindingSize: Math.min(threeGiB, limits.maxStorageBufferBindingSize),
      maxComputeInvocationsPerWorkgroup: Math.min(256, limits.maxComputeInvocationsPerWorkgroup),
    },
  });

  // Surface device loss in the console so silent GPU resets are diagnosable.
  device.lost.then((info) => {
    console.error('WebGPU device lost:', info.message);
  });

  return { device, adapter };
}
|
|
56
|
+
|
|
57
|
+
// ─── Buffer factory helpers ───────────────────────────────────────────────────

// GPUBufferUsage is only declared in WebGPU-capable environments. Optional
// chaining on the bare identifier (`GPUBufferUsage?.UNIFORM`) still throws a
// ReferenceError when the global is entirely undeclared (e.g. Node.js or
// non-WebGPU browsers), so read it through globalThis instead. The numeric
// fallbacks are the spec-defined flag values.
const UNIFORM  = globalThis.GPUBufferUsage?.UNIFORM  ?? 0x40;
const STORAGE  = globalThis.GPUBufferUsage?.STORAGE  ?? 0x80;
const COPY_SRC = globalThis.GPUBufferUsage?.COPY_SRC ?? 0x04;
const COPY_DST = globalThis.GPUBufferUsage?.COPY_DST ?? 0x08;
const MAP_READ = globalThis.GPUBufferUsage?.MAP_READ ?? 0x01;
|
|
64
|
+
|
|
65
|
+
/**
 * Create a GPU storage buffer pre-filled with Float32 data.
 *
 * @param {GPUDevice} device
 * @param {Float32Array|number[]} data
 * @param {boolean} [readable=false] Also attach COPY_SRC so it can be read back.
 * @returns {GPUBuffer}
 */
export function createStorageBuffer(device, data, readable = false) {
  const values = data instanceof Float32Array ? data : new Float32Array(data);
  let usage = STORAGE | COPY_DST;
  if (readable) usage |= COPY_SRC;
  const buffer = device.createBuffer({
    size: values.byteLength,
    usage,
    mappedAtCreation: true,
  });
  // Fill the mapped range before handing the buffer to the GPU.
  new Float32Array(buffer.getMappedRange()).set(values);
  buffer.unmap();
  return buffer;
}
|
|
81
|
+
|
|
82
|
+
/**
 * Create a GPU storage buffer of `byteSize` bytes. WebGPU guarantees newly
 * created (unmapped) buffers are zero-initialised.
 *
 * @param {GPUDevice} device
 * @param {number} byteSize
 * @param {boolean} [readable=false] Also attach COPY_SRC so it can be read back.
 * @returns {GPUBuffer}
 */
export function createEmptyStorageBuffer(device, byteSize, readable = false) {
  let usage = STORAGE | COPY_DST;
  if (readable) usage |= COPY_SRC;
  return device.createBuffer({ size: byteSize, usage });
}
|
|
94
|
+
|
|
95
|
+
/**
 * Create a uniform buffer for a plain-old-data struct.
 * The caller must supply a correctly-packed ArrayBuffer / TypedArray.
 *
 * @param {GPUDevice} device
 * @param {ArrayBuffer|TypedArray} data
 * @returns {GPUBuffer}
 */
export function createUniformBuffer(device, data) {
  // Respect the view's byteOffset/byteLength: a TypedArray may be a window
  // into a larger ArrayBuffer, so copying `data.buffer` wholesale would both
  // over-size the buffer and upload bytes that lie outside the view.
  const bytes = ArrayBuffer.isView(data)
    ? new Uint8Array(data.buffer, data.byteOffset, data.byteLength)
    : new Uint8Array(data);
  const buffer = device.createBuffer({
    size : bytes.byteLength,
    usage : UNIFORM | COPY_DST,
    mappedAtCreation: true,
  });
  new Uint8Array(buffer.getMappedRange()).set(bytes);
  buffer.unmap();
  return buffer;
}
|
|
114
|
+
|
|
115
|
+
/**
 * Read back a GPU storage buffer to a Float32Array (async, for debugging/eval).
 *
 * @param {GPUDevice} device
 * @param {GPUBuffer} srcBuffer Must have COPY_SRC usage.
 * @param {number} byteSize
 * @returns {Promise<Float32Array>}
 */
export async function readBuffer(device, srcBuffer, byteSize) {
  // MAP_READ buffers cannot be bound as storage, so a staging copy is required.
  const stagingBuffer = device.createBuffer({
    size : byteSize,
    usage : MAP_READ | COPY_DST,
  });

  const encoder = device.createCommandEncoder();
  encoder.copyBufferToBuffer(srcBuffer, 0, stagingBuffer, 0, byteSize);
  device.queue.submit([encoder.finish()]);

  // globalThis lookup avoids a ReferenceError if GPUMapMode is undeclared;
  // 0x01 is the spec-defined value of GPUMapMode.READ.
  await stagingBuffer.mapAsync(globalThis.GPUMapMode?.READ ?? 0x01);
  // slice(0) copies out of the mapped range so the data survives unmap().
  const result = new Float32Array(stagingBuffer.getMappedRange().slice(0));
  stagingBuffer.unmap();
  stagingBuffer.destroy();
  return result;
}
|
|
139
|
+
|
|
140
|
+
/**
 * Upload a Float32Array to an existing GPU buffer.
 *
 * Thin wrapper over GPUQueue.writeBuffer: the write is enqueued on the device
 * queue and is ordered before any subsequently submitted GPU commands.
 *
 * @param {GPUDevice} device
 * @param {GPUBuffer} buffer Must have COPY_DST usage.
 * @param {Float32Array} data
 * @param {number} [byteOffset=0] Destination offset in bytes within `buffer`.
 */
export function uploadBuffer(device, buffer, data, byteOffset = 0) {
  device.queue.writeBuffer(buffer, byteOffset, data);
}
|
|
151
|
+
|
|
152
|
+
// ─── Pipeline / Shader helpers ────────────────────────────────────────────────
|
|
153
|
+
|
|
154
|
+
/**
 * Compile a WGSL compute shader and return a GPUComputePipeline.
 *
 * Uses layout 'auto' so bind group layouts are inferred from the shader.
 *
 * @param {GPUDevice} device
 * @param {string} wgslSource
 * @param {string} entryPoint
 * @returns {GPUComputePipeline}
 */
export function createComputePipeline(device, wgslSource, entryPoint) {
  const module = device.createShaderModule({ code: wgslSource });
  return device.createComputePipeline({
    layout: 'auto',
    compute: { module, entryPoint },
  });
}
|
|
169
|
+
|
|
170
|
+
/**
 * Build a GPUBindGroup from an array of GPUBuffer bindings.
 *
 * @param {GPUDevice} device
 * @param {GPUComputePipeline} pipeline
 * @param {GPUBuffer[]} buffers Ordered list matching @binding(i).
 * @param {number} [groupIndex=0]
 * @returns {GPUBindGroup}
 */
export function createBindGroup(device, pipeline, buffers, groupIndex = 0) {
  // Binding index i corresponds to @binding(i) in the WGSL source.
  return device.createBindGroup({
    layout: pipeline.getBindGroupLayout(groupIndex),
    entries: buffers.map((buffer, binding) => ({ binding, resource: { buffer } })),
  });
}
|
|
189
|
+
|
|
190
|
+
/**
 * Encode and submit a single compute dispatch in one call.
 *
 * Note: submission is asynchronous on the GPU side; this only enqueues work.
 *
 * @param {GPUDevice} device
 * @param {GPUComputePipeline} pipeline
 * @param {GPUBindGroup} bindGroup
 * @param {[number, number, number]} workgroups [x, y, z]
 */
export function dispatchKernel(device, pipeline, bindGroup, workgroups) {
  const encoder = device.createCommandEncoder();
  const pass = encoder.beginComputePass();

  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  pass.dispatchWorkgroups(...workgroups);
  pass.end();

  device.queue.submit([encoder.finish()]);
}
|
|
207
|
+
|
|
208
|
+
/**
 * Ceil-divide helper: Math.ceil(a / b) in integer arithmetic.
 * Commonly used to size workgroup dispatch counts.
 *
 * @param {number} a
 * @param {number} b
 * @returns {number}
 */
export function cdiv(a, b) {
  const quotient = a / b;
  return Math.ceil(quotient);
}
|
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* quantization.js – FP16 and Int8 quantization utilities.
|
|
3
|
+
*
|
|
4
|
+
* MambaCode.js supports two quantization modes to reduce VRAM usage:
|
|
5
|
+
* • FP16 – weights stored as 16-bit floats (halves memory vs FP32)
|
|
6
|
+
* • Int8 – non-critical activations quantized to signed 8-bit integers
|
|
7
|
+
*
|
|
8
|
+
* All quantization/dequantization happens in JavaScript; the GPU kernels
|
|
9
|
+
* always operate on FP32 tensors internally (dequantized on upload).
|
|
10
|
+
*/
|
|
11
|
+
|
|
12
|
+
// ─── FP16 Utilities ──────────────────────────────────────────────────────────
|
|
13
|
+
|
|
14
|
+
/**
 * Convert a 32-bit float to a 16-bit IEEE 754 float (represented as Uint16).
 * Uses bit manipulation to avoid the need for a Float16Array (not in spec yet).
 *
 * @param {number} val – 32-bit float
 * @returns {number} – 16-bit float packed as an integer (0–65535)
 */
export function floatToFp16(val) {
  // Reinterpret the float32 bit pattern as an unsigned integer.
  const scratch = new DataView(new ArrayBuffer(4));
  scratch.setFloat32(0, val);
  const bits = scratch.getUint32(0);

  const sign     = (bits >>> 31) & 0x1;
  const exponent = (bits >>> 23) & 0xFF;
  const mantissa = bits & 0x7FFFFF;

  // Inf / NaN: keep a quiet-NaN payload bit so NaN survives the narrowing.
  if (exponent === 255) {
    return (sign << 15) | 0x7C00 | (mantissa ? 0x200 : 0);
  }

  const expAdj = exponent - 127 + 15; // re-bias from 127 to 15

  if (expAdj >= 31) {
    return (sign << 15) | 0x7C00; // overflow → Inf
  }

  if (expAdj <= 0) {
    if (expAdj < -10) {
      return sign << 15; // too small even for an FP16 denormal → signed zero
    }
    // Denormal: shift the full 24-bit significand into the 10-bit field.
    return (sign << 15) | ((mantissa | 0x800000) >> (14 - expAdj));
  }

  return (sign << 15) | (expAdj << 10) | (mantissa >> 13);
}
|
|
53
|
+
|
|
54
|
+
/**
 * Convert a 16-bit FP16 integer to a 32-bit float.
 *
 * @param {number} val – Uint16 representation of an FP16 value
 * @returns {number} – JavaScript number (float64, but semantically float32)
 */
export function fp16ToFloat(val) {
  const sign     = (val >>> 15) & 0x1;
  const exponent = (val >>> 10) & 0x1F;
  const mantissa = val & 0x3FF;

  if (exponent === 0) {
    // Zero or denormal: value = (mantissa / 2^10) × 2^-14, no implicit
    // leading 1. (The 2^-14 factor is the minimum FP16 normal exponent;
    // omitting it would inflate denormals by a factor of 16384.)
    const f = (mantissa / 1024.0) * Math.pow(2, -14);
    return sign ? -f : f;
  }

  if (exponent === 31) {
    // NaN takes priority over the sign bit: 0xFC01 must decode to NaN,
    // not -Infinity.
    if (mantissa) return NaN;
    return sign ? -Infinity : Infinity;
  }

  const expUnbiased = exponent - 15;
  const f = (1 + mantissa / 1024.0) * Math.pow(2, expUnbiased);
  return sign ? -f : f;
}
|
|
80
|
+
|
|
81
|
+
/**
 * Quantize a Float32Array to FP16 (stored as Uint16Array).
 *
 * @param {Float32Array} f32
 * @returns {Uint16Array}
 */
export function quantizeFp16(f32) {
  // Element-wise narrowing via the shared scalar converter.
  return Uint16Array.from(f32, (x) => floatToFp16(x));
}
|
|
94
|
+
|
|
95
|
+
/**
 * Dequantize a Uint16Array (FP16) back to Float32Array.
 *
 * @param {Uint16Array} fp16
 * @returns {Float32Array}
 */
export function dequantizeFp16(fp16) {
  // Element-wise widening via the shared scalar converter.
  return Float32Array.from(fp16, (v) => fp16ToFloat(v));
}
|
|
108
|
+
|
|
109
|
+
// ─── Int8 Quantization ───────────────────────────────────────────────────────
|
|
110
|
+
|
|
111
|
+
/**
 * Symmetric per-tensor Int8 quantization.
 * Quantization: q = round(x / scale), scale = max(|x|) / 127
 *
 * @param {Float32Array} f32
 * @returns {{ data: Int8Array, scale: number }}
 */
export function quantizeInt8(f32) {
  // Running-max fold over absolute values (comparison form preserves the
  // original's behavior for every input, including non-finite values).
  const maxAbs = f32.reduce((m, x) => (Math.abs(x) > m ? Math.abs(x) : m), 0);

  const scale = maxAbs / 127.0 || 1.0; // avoid division by zero

  const data = Int8Array.from(f32, (x) =>
    Math.max(-128, Math.min(127, Math.round(x / scale)))
  );

  return { data, scale };
}
|
|
134
|
+
|
|
135
|
+
/**
 * Dequantize an Int8Array back to Float32Array.
 *
 * @param {Int8Array} int8
 * @param {number} scale
 * @returns {Float32Array}
 */
export function dequantizeInt8(int8, scale) {
  // x ≈ q × scale; results are rounded to float32 on store.
  return Float32Array.from(int8, (q) => q * scale);
}
|
|
149
|
+
|
|
150
|
+
/**
 * Per-channel Int8 quantization (useful for weight matrices).
 * Each output channel gets its own scale factor for better accuracy.
 *
 * @param {Float32Array} f32 – Flat weight tensor, row-major
 * @param {number} numChannels – Number of output channels (rows)
 * @returns {{ data: Int8Array, scales: Float32Array }}
 * @throws {Error} If `f32.length` is not an exact multiple of `numChannels`
 *                 (a fractional channel size would silently index garbage).
 */
export function quantizeInt8PerChannel(f32, numChannels) {
  if (!Number.isInteger(numChannels) || numChannels <= 0 || f32.length % numChannels !== 0) {
    throw new Error(
      `quantizeInt8PerChannel: tensor length ${f32.length} is not divisible into ${numChannels} channels`
    );
  }

  const channelSize = f32.length / numChannels;
  const scales = new Float32Array(numChannels);
  const data = new Int8Array(f32.length);

  for (let c = 0; c < numChannels; c++) {
    const base = c * channelSize;

    // Pass 1: per-channel absolute maximum.
    let maxAbs = 0;
    for (let j = 0; j < channelSize; j++) {
      const a = Math.abs(f32[base + j]);
      if (a > maxAbs) maxAbs = a;
    }
    scales[c] = maxAbs / 127.0 || 1.0; // avoid division by zero

    // Pass 2: quantize with the float32-rounded scale actually stored.
    for (let j = 0; j < channelSize; j++) {
      data[base + j] = Math.max(-128, Math.min(127,
        Math.round(f32[base + j] / scales[c])
      ));
    }
  }

  return { data, scales };
}
|
|
180
|
+
|
|
181
|
+
/**
 * Dequantize per-channel Int8 data.
 *
 * @param {Int8Array} int8
 * @param {Float32Array} scales
 * @param {number} numChannels
 * @returns {Float32Array}
 */
export function dequantizeInt8PerChannel(int8, scales, numChannels) {
  const channelSize = int8.length / numChannels;
  const out = new Float32Array(int8.length);

  for (let c = 0; c < numChannels; c++) {
    const scale = scales[c]; // hoisted: constant across the channel
    const start = c * channelSize;
    for (let j = 0; j < channelSize; j++) {
      out[start + j] = int8[start + j] * scale;
    }
  }

  return out;
}
|
|
202
|
+
|
|
203
|
+
/**
 * Estimate memory usage for a weight tensor under different precisions.
 *
 * @param {number} numElements
 * @returns {{ fp32: number, fp16: number, int8: number }} – bytes
 */
export function estimateMemory(numElements) {
  // Bytes per element for each supported precision.
  const BYTES = { fp32: 4, fp16: 2, int8: 1 };
  return {
    fp32: numElements * BYTES.fp32,
    fp16: numElements * BYTES.fp16,
    int8: numElements * BYTES.int8,
  };
}
|