mambacode.js 1.0.0 → 1.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. package/README.md +198 -76
  2. package/dist/index.d.ts +18 -0
  3. package/dist/index.d.ts.map +1 -0
  4. package/dist/index.js +18 -0
  5. package/dist/index.js.map +1 -0
  6. package/dist/kernels/activations.d.ts +3 -0
  7. package/dist/kernels/activations.d.ts.map +1 -0
  8. package/dist/kernels/activations.js +87 -0
  9. package/dist/kernels/activations.js.map +1 -0
  10. package/dist/kernels/conv1d.d.ts +3 -0
  11. package/dist/kernels/conv1d.d.ts.map +1 -0
  12. package/dist/kernels/conv1d.js +152 -0
  13. package/dist/kernels/conv1d.js.map +1 -0
  14. package/dist/kernels/linear_projection.d.ts +3 -0
  15. package/dist/kernels/linear_projection.d.ts.map +1 -0
  16. package/dist/kernels/linear_projection.js +219 -0
  17. package/dist/kernels/linear_projection.js.map +1 -0
  18. package/dist/kernels/selective_scan.d.ts +3 -0
  19. package/dist/kernels/selective_scan.d.ts.map +1 -0
  20. package/dist/kernels/selective_scan.js +348 -0
  21. package/dist/kernels/selective_scan.js.map +1 -0
  22. package/dist/kernels/weight_update.d.ts +3 -0
  23. package/dist/kernels/weight_update.d.ts.map +1 -0
  24. package/dist/kernels/weight_update.js +119 -0
  25. package/dist/kernels/weight_update.js.map +1 -0
  26. package/dist/model/mamba_block.d.ts +64 -0
  27. package/dist/model/mamba_block.d.ts.map +1 -0
  28. package/dist/model/mamba_block.js +309 -0
  29. package/dist/model/mamba_block.js.map +1 -0
  30. package/dist/model/mamba_model.d.ts +66 -0
  31. package/dist/model/mamba_model.d.ts.map +1 -0
  32. package/dist/model/mamba_model.js +289 -0
  33. package/dist/model/mamba_model.js.map +1 -0
  34. package/dist/tokenizer/bpe.d.ts +29 -0
  35. package/dist/tokenizer/bpe.d.ts.map +1 -0
  36. package/dist/tokenizer/bpe.js +164 -0
  37. package/dist/tokenizer/bpe.js.map +1 -0
  38. package/dist/training/autograd.d.ts +27 -0
  39. package/dist/training/autograd.d.ts.map +1 -0
  40. package/dist/training/autograd.js +120 -0
  41. package/dist/training/autograd.js.map +1 -0
  42. package/dist/training/trainer.d.ts +37 -0
  43. package/dist/training/trainer.d.ts.map +1 -0
  44. package/dist/training/trainer.js +183 -0
  45. package/dist/training/trainer.js.map +1 -0
  46. package/dist/utils/gpu_utils.d.ts +21 -0
  47. package/dist/utils/gpu_utils.d.ts.map +1 -0
  48. package/dist/utils/gpu_utils.js +111 -0
  49. package/dist/utils/gpu_utils.js.map +1 -0
  50. package/dist/utils/quantization.d.ts +26 -0
  51. package/dist/utils/quantization.d.ts.map +1 -0
  52. package/dist/utils/quantization.js +116 -0
  53. package/dist/utils/quantization.js.map +1 -0
  54. package/package.json +43 -18
  55. package/src/index.ts +59 -0
  56. package/src/kernels/{activations.js → activations.ts} +2 -2
  57. package/src/kernels/{linear_projection.js → linear_projection.ts} +2 -2
  58. package/src/kernels/{selective_scan.js → selective_scan.ts} +2 -2
  59. package/src/kernels/{weight_update.js → weight_update.ts} +2 -2
  60. package/src/model/{mamba_block.js → mamba_block.ts} +139 -175
  61. package/src/model/{mamba_model.js → mamba_model.ts} +168 -124
  62. package/src/tokenizer/bpe.ts +186 -0
  63. package/src/training/autograd.ts +135 -0
  64. package/src/training/trainer.ts +312 -0
  65. package/src/utils/gpu_utils.ts +147 -0
  66. package/src/utils/quantization.ts +154 -0
  67. package/src/index.js +0 -89
  68. package/src/tokenizer/bpe.js +0 -256
  69. package/src/training/autograd.js +0 -221
  70. package/src/training/trainer.js +0 -394
  71. package/src/utils/gpu_utils.js +0 -217
  72. package/src/utils/quantization.js +0 -215
  73. /package/src/kernels/{conv1d.js → conv1d.ts} +0 -0
@@ -0,0 +1,154 @@
1
+ /**
2
+ * quantization.ts – FP16 and Int8 quantization utilities.
3
+ */
4
+
5
/**
 * Output of {@link quantizeInt8}: a whole tensor quantized with one shared scale.
 */
export interface QuantizeInt8Result {
  /** Quantized values, one per input element. */
  data: Int8Array;

  /** Factor that maps a quantized value back to a float (`data[i] * scale`). */
  scale: number;
}
9
+
10
/**
 * Output of {@link quantizeInt8PerChannel}: quantized values plus one scale
 * per channel.
 */
export interface QuantizeInt8PerChannelResult {
  /** Quantized values, laid out channel after channel like the input. */
  data: Int8Array;

  /** `scales[c]` is the factor that maps channel `c` back to floats. */
  scales: Float32Array;
}
14
+
15
/**
 * Storage size, in bytes, of the same number of elements at each precision
 * (see {@link estimateMemory}).
 */
export interface MemoryEstimate {
  /** Bytes needed at 4 bytes per element. */
  fp32: number;

  /** Bytes needed at 2 bytes per element. */
  fp16: number;

  /** Bytes needed at 1 byte per element. */
  int8: number;
}
20
+
21
/**
 * Converts a number to its IEEE 754 binary16 (half-precision) bit pattern.
 *
 * The value is first stored as a float32, then its sign / exponent / mantissa
 * fields are re-packed into 1 + 5 + 10 bits. Mantissa bits that do not fit are
 * rounded to nearest, ties to even (the IEEE 754 default), instead of being
 * truncated toward zero.
 *
 * @param val - Value to convert.
 * @returns The 16-bit pattern as an unsigned integer in `[0, 0xFFFF]`.
 *   Magnitudes too large for binary16 become ±Infinity, magnitudes that round
 *   below the smallest subnormal become ±0, and NaN maps to a quiet NaN.
 */
export function floatToFp16(val: number): number {
  const buf = new ArrayBuffer(4);
  const f32 = new Float32Array(buf);
  const u32 = new Uint32Array(buf);
  f32[0] = val;
  const bits = u32[0]!;

  const signBit = ((bits >>> 31) & 0x1) << 15;
  const exponent = (bits >>> 23) & 0xff;
  const mantissa = bits & 0x7fffff;

  // Right-shift with round-to-nearest, ties-to-even on the discarded bits.
  const shiftRoundNearestEven = (value: number, shift: number): number => {
    const kept = value >>> shift;
    const dropped = value & ((1 << shift) - 1);
    const halfway = 1 << (shift - 1);
    const roundUp = dropped > halfway || (dropped === halfway && (kept & 1) === 1);
    return roundUp ? kept + 1 : kept;
  };

  // float32 Infinity / NaN: keep the class, force a mantissa bit for NaN.
  if (exponent === 0xff) {
    return signBit | 0x7c00 | (mantissa !== 0 ? 0x200 : 0);
  }

  // Re-bias the exponent from float32 (127) to binary16 (15).
  const halfExponent = exponent - 127 + 15;

  // Too large for binary16: overflow to Infinity.
  if (halfExponent >= 31) {
    return signBit | 0x7c00;
  }

  if (halfExponent <= 0) {
    // Below half of the smallest subnormal (2^-25): rounds to zero.
    if (halfExponent < -10) {
      return signBit;
    }
    // Subnormal result: make the implicit leading 1 explicit, then shift it
    // into place. Rounding up out of the subnormal range lands exactly on the
    // smallest normal number (0x0400), which is the correct encoding.
    return signBit | shiftRoundNearestEven(mantissa | 0x800000, 14 - halfExponent);
  }

  // Normal result: round exponent and mantissa as one field so that a mantissa
  // carry bumps the exponent (and overflows into Infinity at the top end).
  return signBit | shiftRoundNearestEven((halfExponent << 23) | mantissa, 13);
}
50
+
51
/**
 * Decodes an IEEE 754 binary16 (half-precision) bit pattern to a number.
 *
 * @param val - 16-bit pattern; only the low 16 bits are inspected.
 * @returns The represented value. Subnormals (exponent field 0) are scaled by
 *   2^-24 per mantissa step, an all-ones exponent yields ±Infinity for a zero
 *   mantissa and NaN otherwise (regardless of the sign bit).
 */
export function fp16ToFloat(val: number): number {
  const negative = ((val >>> 15) & 0x1) === 1;
  const exponent = (val >>> 10) & 0x1f;
  const mantissa = val & 0x3ff;

  let magnitude: number;
  if (exponent === 0x1f) {
    // NaN has no meaningful sign; never turn it into -Infinity.
    if (mantissa !== 0) {
      return NaN;
    }
    magnitude = Infinity;
  } else if (exponent === 0) {
    // Zero / subnormal: no implicit leading 1, fixed exponent of -14,
    // i.e. (mantissa / 1024) * 2^-14 = mantissa * 2^-24.
    magnitude = mantissa * Math.pow(2, -24);
  } else {
    // Normal: implicit leading 1 and exponent bias 15.
    magnitude = (1 + mantissa / 1024) * Math.pow(2, exponent - 15);
  }

  return negative ? -magnitude : magnitude;
}
69
+
70
/**
 * Converts every element of a float32 array to its half-precision bit pattern.
 *
 * @param f32 - Source values.
 * @returns A new array of the same length holding `floatToFp16` of each element.
 */
export function quantizeFp16(f32: Float32Array): Uint16Array {
  return Uint16Array.from(f32, (value) => floatToFp16(value));
}
77
+
78
/**
 * Expands an array of half-precision bit patterns back into float32 values.
 *
 * @param fp16 - Half-precision bit patterns.
 * @returns A new array of the same length holding `fp16ToFloat` of each element.
 */
export function dequantizeFp16(fp16: Uint16Array): Float32Array {
  return Float32Array.from(fp16, (halfBits) => fp16ToFloat(halfBits));
}
85
+
86
/**
 * Symmetric int8 quantization of a whole array with a single scale.
 *
 * The scale is the largest absolute value divided by 127, or 1 when that
 * quotient is zero (empty or all-zero input). Each element is divided by the
 * scale, rounded with `Math.round` and clamped to `[-128, 127]`.
 *
 * @param f32 - Values to quantize.
 * @returns The quantized values together with the scale that was used.
 */
export function quantizeInt8(f32: Float32Array): QuantizeInt8Result {
  // Comparison-based maximum so that NaN entries never become the peak.
  const peak = f32.reduce((largest, value) => {
    const magnitude = Math.abs(value);
    return magnitude > largest ? magnitude : largest;
  }, 0);

  const rawScale = peak / 127.0;
  const scale = rawScale !== 0 ? rawScale : 1.0;

  const data = Int8Array.from(f32, (value) => {
    const rounded = Math.round(value / scale);
    return Math.max(-128, Math.min(127, rounded));
  });

  return { data, scale };
}
102
+
103
/**
 * Reverses single-scale int8 quantization by multiplying each value by `scale`.
 *
 * @param int8 - Quantized values.
 * @param scale - Scale to multiply by.
 * @returns A new float32 array of the same length.
 */
export function dequantizeInt8(int8: Int8Array, scale: number): Float32Array {
  return Float32Array.from(int8, (quantized) => quantized * scale);
}
110
+
111
/**
 * Symmetric int8 quantization with one scale per channel.
 *
 * The input is treated as `numChannels` contiguous runs of equal length. Each
 * run gets its own scale (largest absolute value / 127, or 1 when that is
 * zero) and is quantized with `Math.round` and clamped to `[-128, 127]`.
 *
 * @param f32 - Values to quantize, laid out channel after channel.
 * @param numChannels - Number of channels; must evenly divide `f32.length`.
 * @returns Quantized values and the per-channel scales.
 * @throws RangeError if `numChannels` is not a non-negative integer or does
 *   not evenly divide `f32.length` (a fractional channel size would otherwise
 *   index the arrays with non-integers and silently produce garbage).
 */
export function quantizeInt8PerChannel(f32: Float32Array, numChannels: number): QuantizeInt8PerChannelResult {
  if (!Number.isInteger(numChannels) || numChannels < 0) {
    throw new RangeError(`numChannels must be a non-negative integer, got ${numChannels}`);
  }
  const dividesEvenly = numChannels === 0 ? f32.length === 0 : f32.length % numChannels === 0;
  if (!dividesEvenly) {
    throw new RangeError(
      `input length ${f32.length} is not divisible into ${numChannels} channels`,
    );
  }

  const channelSize = numChannels === 0 ? 0 : f32.length / numChannels;
  const scales = new Float32Array(numChannels);
  const data = new Int8Array(f32.length);

  for (let c = 0; c < numChannels; c++) {
    const base = c * channelSize;
    const end = base + channelSize;

    let maxAbs = 0;
    for (let i = base; i < end; i++) {
      const a = Math.abs(f32[i]!);
      if (a > maxAbs) maxAbs = a;
    }

    // Round to float32 BEFORE the zero check: the scale is stored in a
    // Float32Array, and a tiny quotient could otherwise underflow to 0 there,
    // making every element of the channel saturate.
    const scale = Math.fround(maxAbs / 127.0) || 1.0;
    scales[c] = scale;

    for (let i = base; i < end; i++) {
      data[i] = Math.max(-128, Math.min(127, Math.round(f32[i]! / scale)));
    }
  }

  return { data, scales };
}
133
+
134
/**
 * Reverses per-channel int8 quantization.
 *
 * The input is treated as `numChannels` contiguous runs of equal length; every
 * value in run `c` is multiplied by `scales[c]`.
 *
 * @param int8 - Quantized values, laid out channel after channel.
 * @param scales - One scale per channel (at least `numChannels` entries).
 * @param numChannels - Number of channels; must evenly divide `int8.length`.
 * @returns A new float32 array of the same length as `int8`.
 * @throws RangeError if `numChannels` is not a non-negative integer, does not
 *   evenly divide `int8.length`, or exceeds `scales.length` (each of which
 *   would otherwise silently yield wrong or NaN output).
 */
export function dequantizeInt8PerChannel(int8: Int8Array, scales: Float32Array, numChannels: number): Float32Array {
  if (!Number.isInteger(numChannels) || numChannels < 0) {
    throw new RangeError(`numChannels must be a non-negative integer, got ${numChannels}`);
  }
  const dividesEvenly = numChannels === 0 ? int8.length === 0 : int8.length % numChannels === 0;
  if (!dividesEvenly) {
    throw new RangeError(
      `input length ${int8.length} is not divisible into ${numChannels} channels`,
    );
  }
  if (scales.length < numChannels) {
    throw new RangeError(
      `expected at least ${numChannels} scales, got ${scales.length}`,
    );
  }

  const channelSize = numChannels === 0 ? 0 : int8.length / numChannels;
  const out = new Float32Array(int8.length);

  for (let c = 0; c < numChannels; c++) {
    const scale = scales[c]!;
    const base = c * channelSize;
    const end = base + channelSize;
    for (let i = base; i < end; i++) {
      out[i] = int8[i]! * scale;
    }
  }

  return out;
}
147
+
148
/**
 * Computes how many bytes `numElements` values occupy at each precision.
 *
 * @param numElements - Number of elements.
 * @returns Byte counts at 4 (fp32), 2 (fp16) and 1 (int8) bytes per element.
 */
export function estimateMemory(numElements: number): MemoryEstimate {
  const bytesAt = (bytesPerElement: number): number => numElements * bytesPerElement;

  return {
    fp32: bytesAt(4),
    fp16: bytesAt(2),
    int8: bytesAt(1),
  };
}
package/src/index.js DELETED
@@ -1,89 +0,0 @@
1
- /**
2
- * MambaCode.js – Entry Point
3
- *
4
- * High-performance JavaScript/WGSL Mamba SSM library for browser-based
5
- * code model training and inference.
6
- *
7
- * Quick-start example
8
- * -------------------
9
- * import { MambaModel, MambaTrainer, BPETokenizer, initWebGPU } from 'mambacode.js';
10
- *
11
- * const { device } = await initWebGPU();
12
- * const tokenizer = new BPETokenizer();
13
- * await tokenizer.load('/vocab.json', '/merges.txt');
14
- *
15
- * const model = new MambaModel(device, {
16
- * vocabSize : tokenizer.vocabSize,
17
- * dModel : 512,
18
- * numLayers : 8,
19
- * });
20
- *
21
- * const trainer = new MambaTrainer(model, tokenizer);
22
- * const losses = await trainer.train(myCodeString, { learningRate: 1e-4, epochs: 5 });
23
- *
24
- * const generated = await model.generate(tokenizer.encode('function '), 100);
25
- * console.log(tokenizer.decode(generated));
26
- */
27
-
28
- // ── Core model ────────────────────────────────────────────────────────────────
29
export { MambaModel } from './model/mamba_model.js';
export { MambaBlock } from './model/mamba_block.js';

// ── Training ──────────────────────────────────────────────────────────────────
export { MambaTrainer } from './training/trainer.js';
export {
  Tensor, backward,
  enableGrad, noGrad,
  clearTape, recordOperation,
  crossEntropyLoss, crossEntropyGrad,
} from './training/autograd.js';

// ── Tokenizer ─────────────────────────────────────────────────────────────────
export { BPETokenizer } from './tokenizer/bpe.js';

// ── WebGPU utilities ──────────────────────────────────────────────────────────
export {
  initWebGPU,
  createStorageBuffer, createEmptyStorageBuffer, createUniformBuffer,
  createComputePipeline, createBindGroup,
  dispatchKernel,
  readBuffer, uploadBuffer,
  cdiv,
} from './utils/gpu_utils.js';

// ── Quantization utilities ────────────────────────────────────────────────────
export {
  quantizeFp16, dequantizeFp16,
  floatToFp16, fp16ToFloat,
  quantizeInt8, dequantizeInt8,
  quantizeInt8PerChannel, dequantizeInt8PerChannel,
  estimateMemory,
} from './utils/quantization.js';

// ── Raw WGSL kernel sources, re-exported for callers building custom pipelines ─
export { SELECTIVE_SCAN_FORWARD_WGSL, SELECTIVE_SCAN_BACKWARD_WGSL } from './kernels/selective_scan.js';
export { CONV1D_FORWARD_WGSL, CONV1D_BACKWARD_WGSL } from './kernels/conv1d.js';
export { LINEAR_FORWARD_WGSL, LINEAR_BACKWARD_WGSL } from './kernels/linear_projection.js';
export { WEIGHT_UPDATE_WGSL, GRAD_CLIP_WGSL } from './kernels/weight_update.js';
export { ACTIVATIONS_WGSL, ACTIVATIONS_BACKWARD_WGSL } from './kernels/activations.js';

// ── Library metadata ──────────────────────────────────────────────────────────
// NOTE(review): the published package version for this file is reported as
// 1.0.0, while this constant says 0.1.0 — confirm which one is authoritative.
export const VERSION = '0.1.0';
export const DESCRIPTION = 'MambaCode.js: WebGPU-accelerated Mamba SSM for browser code models';
@@ -1,256 +0,0 @@
1
- /**
2
- * bpe.js – Browser-side Byte Pair Encoding (BPE) tokenizer.
3
- *
4
- * Compatible with Qwen3.5-Coder's vocabulary and merge rules.
5
- * Implements the standard Hugging Face / tiktoken BPE algorithm:
6
- *
7
- * 1. Pre-tokenize text into Unicode code point sequences.
8
- * 2. Apply byte-level encoding for unknown characters.
9
- * 3. Iteratively merge the highest-priority adjacent byte-pair
10
- * according to the learnt merge table.
11
- *
12
- * Usage
13
- * -----
14
- * const tokenizer = new BPETokenizer();
15
- * await tokenizer.load(vocabUrl, mergesUrl);
16
- *
17
- * const ids = tokenizer.encode("function foo() {}");
18
- * const text = tokenizer.decode(ids);
19
- */
20
-
21
- // Byte-level fallback map (matches GPT-2 / Qwen convention).
22
- // Maps raw bytes 0-255 to printable Unicode characters so that every
23
- // byte sequence has a valid string representation.
24
/**
 * Builds the byte → character table used for byte-level BPE.
 *
 * Bytes inside the three listed ranges map to the character with the same
 * code point. Every other byte, taken in ascending order, maps to the next
 * unused code point starting at 256, so all 256 bytes get a distinct character.
 *
 * @returns {Map<number, string>} byte value → single-character string
 */
function buildByteEncoder() {
  // Inclusive [first, last] byte ranges that keep their own code point.
  const selfMappedRanges = [
    [0x21, 0x7E], // ! → ~
    [0xA1, 0xAC], // ¡ → ¬
    [0xAE, 0xFF], // ® → ÿ
  ];

  const table = new Map();
  selfMappedRanges.forEach(([first, last]) => {
    for (let byte = first; byte <= last; byte += 1) {
      table.set(byte, String.fromCodePoint(byte));
    }
  });

  // Remaining bytes are remapped above the single-byte range, in byte order.
  let remapped = 0;
  for (let byte = 0; byte < 256; byte += 1) {
    if (table.has(byte)) continue;
    table.set(byte, String.fromCodePoint(256 + remapped));
    remapped += 1;
  }

  return table;
}
46
-
47
// byte → character table, and its inverse (character → byte).
const BYTE_ENCODER = buildByteEncoder();
const BYTE_DECODER = new Map();
for (const [byte, ch] of BYTE_ENCODER) {
  BYTE_DECODER.set(ch, byte);
}

// Pre-tokenisation pattern applied to the raw text before BPE. Alternatives,
// in order: English contraction suffixes, letter runs (with an optional single
// leading non-letter/non-digit/non-newline character), digit runs of up to 3,
// punctuation runs (optional leading space, trailing newlines), and several
// whitespace forms.
const PRE_TOKENIZE_RE =
  /(?:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+/gu;
54
-
55
/**
 * Byte-level BPE tokenizer.
 *
 * Text is split with PRE_TOKENIZE_RE, each piece is UTF-8 encoded and mapped
 * through BYTE_ENCODER, and adjacent symbols are then merged according to the
 * loaded merge table (lowest rank first).
 */
export class BPETokenizer {
  constructor() {
    /** @type {Map<string, number>} token → id */
    this.vocab = new Map();
    /** @type {Map<number, string>} id → token */
    this.idToToken = new Map();
    /** @type {Map<string, number>} "a b" → rank (lower = higher priority) */
    this.merges = new Map();

    // Special token strings; their ids are resolved when a vocabulary is loaded.
    this.bosToken = '<|im_start|>';
    this.eosToken = '<|im_end|>';
    this.padToken = '<|endoftext|>';
    this.unkToken = '<unk>';

    this.bosId = null;
    this.eosId = null;
    this.padId = null;
  }

  /**
   * Load vocabulary and merge rules, fetching them when given as URLs.
   *
   *   vocab.json  – { "token": id, ... }
   *   merges.txt  – one merge per line: "a b"  (line order defines the rank;
   *                 empty lines and lines starting with '#' are skipped)
   *
   * Nothing on the tokenizer is modified until both inputs have been obtained,
   * so a failed fetch leaves the previous state intact.
   *
   * @param {string|object}   vocab   – URL string or pre-parsed vocab object
   * @param {string|string[]} merges  – URL string or array of merge strings
   * @returns {Promise<void>}
   * @throws {Error} if a fetch completes with a non-2xx HTTP status
   */
  async load(vocab, merges) {
    // fetch() only rejects on network failure; an HTTP error page would
    // otherwise be parsed as vocabulary / merge rules.
    const fetchOk = async (url) => {
      const res = await fetch(url);
      if (!res.ok) {
        throw new Error(`BPETokenizer.load: failed to fetch ${url} (HTTP ${res.status})`);
      }
      return res;
    };

    let vocabObj;
    if (typeof vocab === 'string') {
      vocabObj = await (await fetchOk(vocab)).json();
    } else {
      vocabObj = vocab;
    }

    let mergeLines;
    if (typeof merges === 'string') {
      const txt = await (await fetchOk(merges)).text();
      mergeLines = txt.split('\n').filter(l => l && !l.startsWith('#'));
    } else {
      mergeLines = merges;
    }

    // trim() also strips the '\r' left behind by CRLF line endings.
    this.loadFromObjects(vocabObj, mergeLines.map(line => line.trim()));
  }

  /**
   * Load from plain JavaScript objects (no network fetch).
   * Useful for bundling a small vocabulary directly.
   *
   * @param {Object}   vocabObj  – { token: id }
   * @param {string[]} mergeArr  – ["a b", "c d", ...]; index = rank
   */
  loadFromObjects(vocabObj, mergeArr) {
    this.vocab     = new Map(Object.entries(vocabObj).map(([k, v]) => [k, Number(v)]));
    this.idToToken = new Map([...this.vocab].map(([k, v]) => [v, k]));
    this.merges    = new Map(mergeArr.map((m, i) => [m, i]));

    // Special tokens missing from the vocabulary resolve to null.
    this.bosId = this.vocab.get(this.bosToken) ?? null;
    this.eosId = this.vocab.get(this.eosToken) ?? null;
    this.padId = this.vocab.get(this.padToken) ?? null;
  }

  // ── Encoding ──────────────────────────────────────────────────────────────

  /**
   * Encode a string to an array of token IDs.
   *
   * A merged piece that is not in the vocabulary is broken into its
   * characters; characters that are not in the vocabulary either are dropped.
   *
   * @param {string} text
   * @param {{ addBos?: boolean, addEos?: boolean }} [opts]
   * @returns {number[]}
   */
  encode(text, opts = {}) {
    const words = text.match(PRE_TOKENIZE_RE) ?? [];
    const ids   = [];
    // One encoder for the whole call instead of one per word.
    const utf8  = new TextEncoder();

    if (opts.addBos && this.bosId !== null) ids.push(this.bosId);

    for (const word of words) {
      // Convert to byte-level encoding
      const bytes     = utf8.encode(word);
      const byteStr   = Array.from(bytes).map(b => BYTE_ENCODER.get(b) ?? '?').join('');
      const bpeTokens = this._bpe(byteStr);

      for (const tok of bpeTokens) {
        const id = this.vocab.get(tok);
        if (id !== undefined) {
          ids.push(id);
        } else {
          // Fallback: encode each character individually
          for (const ch of tok) {
            const cid = this.vocab.get(ch);
            if (cid !== undefined) ids.push(cid);
          }
        }
      }
    }

    if (opts.addEos && this.eosId !== null) ids.push(this.eosId);
    return ids;
  }

  /**
   * Decode an array of token IDs back to a string.
   * IDs that are not in the vocabulary are skipped.
   *
   * @param {number[]} ids
   * @returns {string}
   */
  decode(ids) {
    let byteStr = '';
    for (const id of ids) {
      const tok = this.idToToken.get(id);
      if (tok !== undefined) byteStr += tok;
    }
    // Convert byte-level string back to raw bytes then UTF-8 decode
    const bytes = new Uint8Array(
      [...byteStr].map(ch => BYTE_DECODER.get(ch) ?? ch.codePointAt(0))
    );
    try {
      return new TextDecoder('utf-8').decode(bytes);
    } catch {
      // Fall back to the byte-level string if decoding throws.
      return byteStr;
    }
  }

  // ── BPE merge algorithm ───────────────────────────────────────────────────

  /**
   * Apply BPE merges to a byte-encoded word.
   *
   * @param {string} word  – Space-free string of byte-level characters
   * @returns {string[]}   – Merged token pieces
   */
  _bpe(word) {
    if (this.vocab.has(word)) return [word];

    // Start: each character is its own symbol
    const symbols = [...word];

    // Iteratively merge the highest-priority pair
    while (symbols.length > 1) {
      let bestRank = Infinity;
      let bestIdx  = -1;

      // Strict '<' keeps the leftmost pair among equal ranks.
      for (let i = 0; i < symbols.length - 1; i++) {
        const pair = symbols[i] + ' ' + symbols[i + 1];
        const rank = this.merges.get(pair);
        if (rank !== undefined && rank < bestRank) {
          bestRank = rank;
          bestIdx  = i;
        }
      }

      if (bestIdx === -1) break;  // no more merges available

      // Replace the pair in place rather than rebuilding the whole array.
      symbols.splice(bestIdx, 2, symbols[bestIdx] + symbols[bestIdx + 1]);
    }

    return symbols;
  }

  // ── Padding / truncation helpers ─────────────────────────────────────────

  /**
   * Pad or truncate a sequence to a fixed length.
   *
   * Sequences of at least `maxLen` ids are cut to their first `maxLen` ids
   * (regardless of `side`); shorter ones are padded with `padId`, or 0 when no
   * pad token was resolved.
   *
   * @param {number[]} ids
   * @param {number}   maxLen
   * @param {'right'|'left'} [side='right']  – where the padding goes
   * @returns {number[]}
   */
  padOrTruncate(ids, maxLen, side = 'right') {
    if (ids.length >= maxLen) return ids.slice(0, maxLen);
    const padId = this.padId ?? 0;
    const pad   = new Array(maxLen - ids.length).fill(padId);
    return side === 'right' ? [...ids, ...pad] : [...pad, ...ids];
  }

  /** @returns {number} number of entries in the vocabulary */
  get vocabSize() { return this.vocab.size; }
}