mambacode.js 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,217 @@
1
+ /**
2
+ * gpu_utils.js – WebGPU device management and buffer helpers.
3
+ *
4
+ * Provides thin, consistent wrappers around the WebGPU API so that
5
+ * the rest of MambaCode.js never calls navigator.gpu directly.
6
+ */
7
+
8
/**
 * Initialise WebGPU and return the { device, adapter } pair.
 *
 * Requests an adapter (defaulting to the high-performance GPU), then a
 * device whose buffer limits are raised toward 3 GiB — but never beyond
 * what the adapter actually supports.
 *
 * @param {{ powerPreference?: 'high-performance'|'low-power' }} [opts]
 * @returns {Promise<{ device: GPUDevice, adapter: GPUAdapter }>}
 * @throws {Error} When WebGPU is unavailable or no adapter can be acquired.
 */
export async function initWebGPU(opts = {}) {
  const gpu = typeof navigator === 'undefined' ? undefined : navigator.gpu;
  if (!gpu) {
    throw new Error(
      'WebGPU is not available in this environment. ' +
        'Use Chrome 113+, Edge 113+, or Firefox Nightly with WebGPU enabled.'
    );
  }

  const adapter = await gpu.requestAdapter({
    powerPreference: opts.powerPreference ?? 'high-performance',
  });

  if (!adapter) {
    throw new Error('Failed to acquire a GPUAdapter. Your GPU may not support WebGPU.');
  }

  // Ask for generous limits, capped to the adapter's maxima — requesting a
  // limit the adapter cannot grant would reject the requestDevice() call.
  const { limits } = adapter;
  const threeGiB = 3 * 1024 * 1024 * 1024;
  const device = await adapter.requestDevice({
    requiredLimits: {
      maxBufferSize: Math.min(threeGiB, limits.maxBufferSize),
      maxStorageBufferBindingSize: Math.min(threeGiB, limits.maxStorageBufferBindingSize),
      maxComputeInvocationsPerWorkgroup: Math.min(256, limits.maxComputeInvocationsPerWorkgroup),
    },
  });

  // Surface device loss (driver reset, GPU removal, …) in the console.
  device.lost.then((info) => {
    console.error('WebGPU device lost:', info.message);
  });

  return { device, adapter };
}
56
+
57
// ─── Buffer factory helpers ───────────────────────────────────────────────────

// WebGPU usage flags with raw numeric fallbacks so this module can be imported
// (e.g. for unit tests or server-side rendering) where WebGPU is absent.
// NOTE: `globalThis.` is required — a bare `GPUBufferUsage?.UNIFORM` throws a
// ReferenceError when the identifier is not declared at all; optional chaining
// only guards against null/undefined values, not undeclared globals.
const UNIFORM = globalThis.GPUBufferUsage?.UNIFORM ?? 0x40;
const STORAGE = globalThis.GPUBufferUsage?.STORAGE ?? 0x80;
const COPY_SRC = globalThis.GPUBufferUsage?.COPY_SRC ?? 0x04;
const COPY_DST = globalThis.GPUBufferUsage?.COPY_DST ?? 0x08;
const MAP_READ = globalThis.GPUBufferUsage?.MAP_READ ?? 0x01;
64
+
65
/**
 * Create a GPU storage buffer pre-filled with Float32 data.
 *
 * @param {GPUDevice} device
 * @param {Float32Array|number[]} data Values copied into the buffer.
 * @param {boolean} [readable=false] Also attach COPY_SRC so it can be read back.
 * @returns {GPUBuffer}
 */
export function createStorageBuffer(device, data, readable = false) {
  const values = data instanceof Float32Array ? data : new Float32Array(data);

  let usage = STORAGE | COPY_DST;
  if (readable) {
    usage |= COPY_SRC;
  }

  // mappedAtCreation lets us fill the buffer without a separate queue write.
  const buffer = device.createBuffer({
    size: values.byteLength,
    usage,
    mappedAtCreation: true,
  });
  new Float32Array(buffer.getMappedRange()).set(values);
  buffer.unmap();
  return buffer;
}
81
+
82
/**
 * Create a GPU storage buffer of `byteSize` bytes, zero-initialised.
 *
 * @param {GPUDevice} device
 * @param {number} byteSize Size of the buffer in bytes.
 * @param {boolean} [readable=false] Also attach COPY_SRC for readback.
 * @returns {GPUBuffer}
 */
export function createEmptyStorageBuffer(device, byteSize, readable = false) {
  let usage = STORAGE | COPY_DST;
  if (readable) {
    usage |= COPY_SRC;
  }
  return device.createBuffer({ size: byteSize, usage });
}
94
+
95
/**
 * Create a uniform buffer for a plain-old-data struct.
 * The caller must supply a correctly-packed ArrayBuffer / TypedArray.
 *
 * @param {GPUDevice} device
 * @param {ArrayBuffer|TypedArray} data
 * @returns {GPUBuffer}
 */
export function createUniformBuffer(device, data) {
  // Respect a view's byteOffset/byteLength. Using `data.buffer` alone would
  // upload the ENTIRE underlying ArrayBuffer — wrong size and wrong bytes
  // whenever `data` is a subarray view.
  const bytes = ArrayBuffer.isView(data)
    ? new Uint8Array(data.buffer, data.byteOffset, data.byteLength)
    : new Uint8Array(data);

  const buffer = device.createBuffer({
    size: bytes.byteLength,
    usage: UNIFORM | COPY_DST,
    mappedAtCreation: true,
  });
  new Uint8Array(buffer.getMappedRange()).set(bytes);
  buffer.unmap();
  return buffer;
}
114
+
115
/**
 * Read back a GPU storage buffer to a Float32Array (async, for debugging/eval).
 *
 * Copies into a temporary MAP_READ staging buffer, maps it, snapshots the
 * bytes, then unmaps and destroys the staging buffer.
 *
 * @param {GPUDevice} device
 * @param {GPUBuffer} srcBuffer Must have COPY_SRC usage.
 * @param {number} byteSize Number of bytes to read (from offset 0).
 * @returns {Promise<Float32Array>}
 */
export async function readBuffer(device, srcBuffer, byteSize) {
  const stagingBuffer = device.createBuffer({
    size: byteSize,
    usage: MAP_READ | COPY_DST,
  });

  const encoder = device.createCommandEncoder();
  encoder.copyBufferToBuffer(srcBuffer, 0, stagingBuffer, 0, byteSize);
  device.queue.submit([encoder.finish()]);

  // NOTE: `globalThis.` guard — a bare `GPUMapMode?.READ` throws a
  // ReferenceError where the global is not declared at all.
  await stagingBuffer.mapAsync(globalThis.GPUMapMode?.READ ?? 0x01);
  // slice(0) copies out of the mapped range, which is detached by unmap().
  const result = new Float32Array(stagingBuffer.getMappedRange().slice(0));
  stagingBuffer.unmap();
  stagingBuffer.destroy();
  return result;
}
139
+
140
/**
 * Upload a Float32Array to an existing GPU buffer via the device queue.
 *
 * @param {GPUDevice} device
 * @param {GPUBuffer} buffer Must have COPY_DST usage.
 * @param {Float32Array} data Values to write.
 * @param {number} [byteOffset=0] Destination offset within `buffer`, in bytes.
 */
export function uploadBuffer(device, buffer, data, byteOffset = 0) {
  const { queue } = device;
  queue.writeBuffer(buffer, byteOffset, data);
}
151
+
152
+ // ─── Pipeline / Shader helpers ────────────────────────────────────────────────
153
+
154
/**
 * Compile WGSL source into a compute pipeline with an auto-generated layout.
 *
 * @param {GPUDevice} device
 * @param {string} wgslSource WGSL shader code.
 * @param {string} entryPoint Name of the @compute entry function in the WGSL.
 * @returns {GPUComputePipeline}
 */
export function createComputePipeline(device, wgslSource, entryPoint) {
  const module = device.createShaderModule({ code: wgslSource });
  const descriptor = {
    layout: 'auto',
    compute: { module, entryPoint },
  };
  return device.createComputePipeline(descriptor);
}
169
+
170
/**
 * Build a GPUBindGroup from an ordered list of buffers: buffers[i] is bound
 * to @binding(i) in the shader.
 *
 * @param {GPUDevice} device
 * @param {GPUComputePipeline} pipeline Pipeline whose layout is used.
 * @param {GPUBuffer[]} buffers Ordered list matching @binding(i).
 * @param {number} [groupIndex=0] Which @group(n) to build for.
 * @returns {GPUBindGroup}
 */
export function createBindGroup(device, pipeline, buffers, groupIndex = 0) {
  return device.createBindGroup({
    layout: pipeline.getBindGroupLayout(groupIndex),
    entries: buffers.map((buffer, binding) => ({
      binding,
      resource: { buffer },
    })),
  });
}
189
+
190
/**
 * Encode and submit a single compute dispatch in one call.
 *
 * @param {GPUDevice} device
 * @param {GPUComputePipeline} pipeline
 * @param {GPUBindGroup} bindGroup Bound at group index 0.
 * @param {[number, number, number]} workgroups [x, y, z] workgroup counts.
 */
export function dispatchKernel(device, pipeline, bindGroup, workgroups) {
  const [x, y, z] = workgroups;
  const encoder = device.createCommandEncoder();
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  pass.dispatchWorkgroups(x, y, z);
  pass.end();
  device.queue.submit([encoder.finish()]);
}
207
+
208
/**
 * Ceiling division: the smallest integer ≥ a / b.
 * Typical use: number of workgroups needed to cover `a` items, `b` per group.
 *
 * @param {number} a Dividend.
 * @param {number} b Divisor.
 * @returns {number}
 */
export function cdiv(a, b) {
  const quotient = a / b;
  return Math.ceil(quotient);
}
@@ -0,0 +1,215 @@
1
+ /**
2
+ * quantization.js – FP16 and Int8 quantization utilities.
3
+ *
4
+ * MambaCode.js supports two quantization modes to reduce VRAM usage:
5
+ * • FP16 – weights stored as 16-bit floats (halves memory vs FP32)
6
+ * • Int8 – non-critical activations quantized to signed 8-bit integers
7
+ *
8
+ * All quantization/dequantization happens in JavaScript; the GPU kernels
9
+ * always operate on FP32 tensors internally (dequantized on upload).
10
+ */
11
+
12
+ // ─── FP16 Utilities ──────────────────────────────────────────────────────────
13
+
14
+ /**
15
+ * Convert a 32-bit float to a 16-bit IEEE 754 float (represented as Uint16).
16
+ * Uses bit manipulation to avoid the need for a Float16Array (not in spec yet).
17
+ *
18
+ * @param {number} val – 32-bit float
19
+ * @returns {number} – 16-bit float packed as an integer (0–65535)
20
+ */
21
+ export function floatToFp16(val) {
22
+ const buf = new ArrayBuffer(4);
23
+ const f32 = new Float32Array(buf);
24
+ const u32 = new Uint32Array(buf);
25
+ f32[0] = val;
26
+ const bits = u32[0];
27
+
28
+ const sign = (bits >>> 31) & 0x1;
29
+ const exponent = (bits >>> 23) & 0xFF;
30
+ const mantissa = bits & 0x7FFFFF;
31
+
32
+ if (exponent === 255) {
33
+ // Inf / NaN
34
+ return (sign << 15) | 0x7C00 | (mantissa ? 0x200 : 0);
35
+ }
36
+
37
+ const expAdj = exponent - 127 + 15; // re-bias from 127 to 15
38
+
39
+ if (expAdj >= 31) {
40
+ // Overflow → Inf
41
+ return (sign << 15) | 0x7C00;
42
+ }
43
+
44
+ if (expAdj <= 0) {
45
+ // Underflow or denormal
46
+ if (expAdj < -10) { return sign << 15; } // flush to zero
47
+ const shift = 14 - expAdj;
48
+ return (sign << 15) | ((mantissa | 0x800000) >> shift);
49
+ }
50
+
51
+ return (sign << 15) | (expAdj << 10) | (mantissa >> 13);
52
+ }
53
+
54
/**
 * Convert a 16-bit FP16 bit pattern back to a JavaScript number.
 *
 * Decodes normals, subnormals, signed zero/infinity, and NaN. The result is
 * a float64 but every FP16 value is exactly representable.
 *
 * @param {number} val – Uint16 representation of an FP16 value
 * @returns {number}
 */
export function fp16ToFloat(val) {
  const sign = (val >>> 15) & 0x1;
  const exponent = (val >>> 10) & 0x1F;
  const mantissa = val & 0x3FF;

  if (exponent === 0) {
    // Subnormal or zero: value = (mantissa / 2^10) * 2^-14 = mantissa * 2^-24.
    // The 2^-14 factor is the minimum normal exponent and must be applied —
    // mantissa / 1024 alone would be off by a factor of 16384.
    const f = (mantissa / 1024.0) * Math.pow(2, -14);
    return sign ? -f : f;
  }

  if (exponent === 31) {
    // A non-zero mantissa means NaN regardless of the sign bit;
    // otherwise this is a signed infinity.
    if (mantissa) {
      return NaN;
    }
    return sign ? -Infinity : Infinity;
  }

  const expUnbiased = exponent - 15;
  const f = (1 + mantissa / 1024.0) * Math.pow(2, expUnbiased);
  return sign ? -f : f;
}
80
+
81
/**
 * Quantize a Float32Array to FP16, packed as a Uint16Array of bit patterns.
 *
 * @param {Float32Array} f32
 * @returns {Uint16Array}
 */
export function quantizeFp16(f32) {
  // TypedArray.from applies the converter element-by-element.
  return Uint16Array.from(f32, (x) => floatToFp16(x));
}
94
+
95
/**
 * Dequantize a Uint16Array of FP16 bit patterns back to a Float32Array.
 *
 * @param {Uint16Array} fp16
 * @returns {Float32Array}
 */
export function dequantizeFp16(fp16) {
  // TypedArray.from applies the converter element-by-element.
  return Float32Array.from(fp16, (bitsVal) => fp16ToFloat(bitsVal));
}
108
+
109
+ // ─── Int8 Quantization ───────────────────────────────────────────────────────
110
+
111
/**
 * Symmetric per-tensor Int8 quantization.
 * Quantization: q = round(x / scale), scale = max(|x|) / 127
 *
 * @param {Float32Array} f32
 * @returns {{ data: Int8Array, scale: number }}
 */
export function quantizeInt8(f32) {
  let peak = 0;
  for (const x of f32) {
    const mag = Math.abs(x);
    if (mag > peak) {
      peak = mag;
    }
  }

  // `|| 1.0` keeps an all-zero tensor from dividing by zero below.
  const scale = peak / 127.0 || 1.0;
  const data = Int8Array.from(f32, (x) =>
    Math.max(-128, Math.min(127, Math.round(x / scale)))
  );

  return { data, scale };
}
134
+
135
/**
 * Dequantize an Int8Array back to Float32Array: x = q * scale.
 *
 * @param {Int8Array} int8
 * @param {number} scale Per-tensor scale from quantizeInt8().
 * @returns {Float32Array}
 */
export function dequantizeInt8(int8, scale) {
  return Float32Array.from(int8, (q) => q * scale);
}
149
+
150
/**
 * Per-channel symmetric Int8 quantization (useful for weight matrices).
 * Each output channel (row) gets its own scale = max(|row|) / 127 for
 * better accuracy than a single per-tensor scale.
 *
 * @param {Float32Array} f32 – Flat weight tensor, row-major
 * @param {number} numChannels – Number of output channels (rows)
 * @returns {{ data: Int8Array, scales: Float32Array }}
 * @throws {RangeError} If `numChannels` is not a positive integer that
 *   evenly divides `f32.length`.
 */
export function quantizeInt8PerChannel(f32, numChannels) {
  // Guard: a fractional channel size would make `base + j` non-integral, and
  // typed-array writes at fractional indices are silently ignored — the
  // output would be silently corrupted rather than failing loudly.
  if (!Number.isInteger(numChannels) || numChannels <= 0 || f32.length % numChannels !== 0) {
    throw new RangeError(
      `numChannels (${numChannels}) must be a positive integer that evenly divides the tensor length (${f32.length})`
    );
  }

  const channelSize = f32.length / numChannels;
  const scales = new Float32Array(numChannels);
  const data = new Int8Array(f32.length);

  for (let c = 0; c < numChannels; c++) {
    const base = c * channelSize;

    let maxAbs = 0;
    for (let j = 0; j < channelSize; j++) {
      const a = Math.abs(f32[base + j]);
      if (a > maxAbs) maxAbs = a;
    }
    // `|| 1.0` keeps an all-zero channel from dividing by zero below.
    scales[c] = maxAbs / 127.0 || 1.0;

    // Quantize with the float32-rounded scale actually stored in `scales`,
    // so dequantization later uses exactly the same factor.
    for (let j = 0; j < channelSize; j++) {
      data[base + j] = Math.max(-128, Math.min(127,
        Math.round(f32[base + j] / scales[c])
      ));
    }
  }

  return { data, scales };
}
180
+
181
/**
 * Dequantize per-channel Int8 data: x = q * scales[channel].
 *
 * @param {Int8Array} int8 Flat quantized tensor, row-major.
 * @param {Float32Array} scales One scale per channel.
 * @param {number} numChannels Number of channels (rows).
 * @returns {Float32Array}
 */
export function dequantizeInt8PerChannel(int8, scales, numChannels) {
  const perChannel = int8.length / numChannels;
  const out = new Float32Array(int8.length);

  for (let ch = 0; ch < numChannels; ch++) {
    const factor = scales[ch];
    const offset = ch * perChannel;
    for (let k = 0; k < perChannel; k++) {
      out[offset + k] = int8[offset + k] * factor;
    }
  }

  return out;
}
202
+
203
/**
 * Estimate the memory footprint of a weight tensor at each supported precision.
 *
 * @param {number} numElements
 * @returns {{ fp32: number, fp16: number, int8: number }} – size in bytes
 */
export function estimateMemory(numElements) {
  const bytesPerElement = { fp32: 4, fp16: 2, int8: 1 };
  return Object.fromEntries(
    Object.entries(bytesPerElement).map(([precision, width]) => [
      precision,
      numElements * width,
    ])
  );
}