mambacode.js 1.0.0 → 1.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +198 -76
- package/dist/index.d.ts +18 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +18 -0
- package/dist/index.js.map +1 -0
- package/dist/kernels/activations.d.ts +3 -0
- package/dist/kernels/activations.d.ts.map +1 -0
- package/dist/kernels/activations.js +87 -0
- package/dist/kernels/activations.js.map +1 -0
- package/dist/kernels/conv1d.d.ts +3 -0
- package/dist/kernels/conv1d.d.ts.map +1 -0
- package/dist/kernels/conv1d.js +152 -0
- package/dist/kernels/conv1d.js.map +1 -0
- package/dist/kernels/linear_projection.d.ts +3 -0
- package/dist/kernels/linear_projection.d.ts.map +1 -0
- package/dist/kernels/linear_projection.js +219 -0
- package/dist/kernels/linear_projection.js.map +1 -0
- package/dist/kernels/selective_scan.d.ts +3 -0
- package/dist/kernels/selective_scan.d.ts.map +1 -0
- package/dist/kernels/selective_scan.js +348 -0
- package/dist/kernels/selective_scan.js.map +1 -0
- package/dist/kernels/weight_update.d.ts +3 -0
- package/dist/kernels/weight_update.d.ts.map +1 -0
- package/dist/kernels/weight_update.js +119 -0
- package/dist/kernels/weight_update.js.map +1 -0
- package/dist/model/mamba_block.d.ts +64 -0
- package/dist/model/mamba_block.d.ts.map +1 -0
- package/dist/model/mamba_block.js +309 -0
- package/dist/model/mamba_block.js.map +1 -0
- package/dist/model/mamba_model.d.ts +66 -0
- package/dist/model/mamba_model.d.ts.map +1 -0
- package/dist/model/mamba_model.js +289 -0
- package/dist/model/mamba_model.js.map +1 -0
- package/dist/tokenizer/bpe.d.ts +29 -0
- package/dist/tokenizer/bpe.d.ts.map +1 -0
- package/dist/tokenizer/bpe.js +164 -0
- package/dist/tokenizer/bpe.js.map +1 -0
- package/dist/training/autograd.d.ts +27 -0
- package/dist/training/autograd.d.ts.map +1 -0
- package/dist/training/autograd.js +120 -0
- package/dist/training/autograd.js.map +1 -0
- package/dist/training/trainer.d.ts +37 -0
- package/dist/training/trainer.d.ts.map +1 -0
- package/dist/training/trainer.js +183 -0
- package/dist/training/trainer.js.map +1 -0
- package/dist/utils/gpu_utils.d.ts +21 -0
- package/dist/utils/gpu_utils.d.ts.map +1 -0
- package/dist/utils/gpu_utils.js +111 -0
- package/dist/utils/gpu_utils.js.map +1 -0
- package/dist/utils/quantization.d.ts +26 -0
- package/dist/utils/quantization.d.ts.map +1 -0
- package/dist/utils/quantization.js +116 -0
- package/dist/utils/quantization.js.map +1 -0
- package/package.json +43 -18
- package/src/index.ts +59 -0
- package/src/kernels/{activations.js → activations.ts} +2 -2
- package/src/kernels/{linear_projection.js → linear_projection.ts} +2 -2
- package/src/kernels/{selective_scan.js → selective_scan.ts} +2 -2
- package/src/kernels/{weight_update.js → weight_update.ts} +2 -2
- package/src/model/{mamba_block.js → mamba_block.ts} +139 -175
- package/src/model/{mamba_model.js → mamba_model.ts} +168 -124
- package/src/tokenizer/bpe.ts +186 -0
- package/src/training/autograd.ts +135 -0
- package/src/training/trainer.ts +312 -0
- package/src/utils/gpu_utils.ts +147 -0
- package/src/utils/quantization.ts +154 -0
- package/src/index.js +0 -89
- package/src/tokenizer/bpe.js +0 -256
- package/src/training/autograd.js +0 -221
- package/src/training/trainer.js +0 -394
- package/src/utils/gpu_utils.js +0 -217
- package/src/utils/quantization.js +0 -215
- /package/src/kernels/{conv1d.js → conv1d.ts} +0 -0
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* quantization.ts – FP16 and Int8 quantization utilities.
|
|
3
|
+
*/
|
|
4
|
+
|
|
5
|
+
/** Result of per-tensor Int8 quantization, as returned by `quantizeInt8`. */
export interface QuantizeInt8Result {
  /** Quantized values, one signed byte per input element. */
  data: Int8Array;
  /** Single multiplier that maps a quantized value back to a float. */
  scale: number;
}
|
|
9
|
+
|
|
10
|
+
/** Result of per-channel Int8 quantization, as returned by `quantizeInt8PerChannel`. */
export interface QuantizeInt8PerChannelResult {
  /** Quantized values, laid out channel after channel like the input. */
  data: Int8Array;
  /** One dequantization multiplier per channel. */
  scales: Float32Array;
}
|
|
14
|
+
|
|
15
|
+
/** Byte counts for the same number of elements in each precision, as returned by `estimateMemory`. */
export interface MemoryEstimate {
  /** Bytes needed at 4 bytes per element. */
  fp32: number;
  /** Bytes needed at 2 bytes per element. */
  fp16: number;
  /** Bytes needed at 1 byte per element. */
  int8: number;
}
|
|
20
|
+
|
|
21
|
+
/**
 * Shift `value` right by `shift` bits, rounding to nearest with ties going to
 * the even result (the IEEE 754 default rounding mode).
 */
function roundShiftRightTiesToEven(value: number, shift: number): number {
  const quotient = value >>> shift;
  const remainder = value & ((1 << shift) - 1);
  const half = 1 << (shift - 1);
  const roundUp = remainder > half || (remainder === half && (quotient & 1) === 1);
  return roundUp ? quotient + 1 : quotient;
}

/**
 * Convert a number to the 16-bit pattern of the nearest IEEE 754 half-precision
 * value.
 *
 * The input is first stored as a float32, then narrowed to binary16 using
 * round-to-nearest, ties-to-even. Values too large for binary16 become
 * ±Infinity, values too small become signed zero, and NaN maps to a quiet NaN.
 *
 * @param val - Number to convert.
 * @returns The binary16 bit pattern as an integer in [0, 0xFFFF].
 */
export function floatToFp16(val: number): number {
  // Reinterpret the float32 encoding of `val` as an unsigned integer.
  const buf = new ArrayBuffer(4);
  const f32 = new Float32Array(buf);
  const u32 = new Uint32Array(buf);
  f32[0] = val;
  const bits = u32[0]!;

  const sign = (bits >>> 16) & 0x8000; // sign already in its binary16 position
  const exponent = (bits >>> 23) & 0xff;
  const mantissa = bits & 0x7fffff;

  // Infinity / NaN: keep the class; any NaN payload collapses to the quiet bit.
  if (exponent === 255) {
    return sign | 0x7c00 | (mantissa ? 0x200 : 0);
  }

  // Re-bias the exponent from float32 (127) to binary16 (15).
  const expAdj = exponent - 127 + 15;

  // Too large for binary16: overflow to infinity.
  if (expAdj >= 31) {
    return sign | 0x7c00;
  }

  if (expAdj <= 0) {
    // Below 2^-25 the value rounds to zero (this also covers float32 zeros
    // and float32 subnormals).
    if (expAdj < -10) {
      return sign;
    }
    // Subnormal result: make the implicit leading 1 explicit and shift it into
    // the 10-bit mantissa field. Rounding may carry into the smallest normal
    // (0x0400), which is the correct encoding for that case.
    const shift = 14 - expAdj;
    return sign | roundShiftRightTiesToEven(mantissa | 0x800000, shift);
  }

  // Normal result. Adding (rather than OR-ing) lets a mantissa round-up carry
  // into the exponent, including the carry from 65520 up to infinity.
  return sign | ((expAdj << 10) + roundShiftRightTiesToEven(mantissa, 13));
}
|
|
50
|
+
|
|
51
|
+
/**
 * Decode an IEEE 754 half-precision bit pattern into a JavaScript number.
 *
 * Handles all four encoding classes: signed zero, subnormals
 * (mantissa × 2^-24), normals, and infinities / NaN. Any pattern with an
 * all-ones exponent and a non-zero mantissa decodes to NaN regardless of the
 * sign bit.
 *
 * @param val - binary16 bit pattern; only the low 16 bits are inspected.
 * @returns The decoded value.
 */
export function fp16ToFloat(val: number): number {
  const negative = (val & 0x8000) !== 0;
  const exponent = (val >>> 10) & 0x1f;
  const mantissa = val & 0x3ff;

  let magnitude: number;
  if (exponent === 31) {
    // NaN has no meaningful sign; check the mantissa before the sign.
    if (mantissa !== 0) {
      return NaN;
    }
    magnitude = Infinity;
  } else if (exponent === 0) {
    // Zero or subnormal: no implicit leading 1, fixed exponent of -14, so the
    // value is (mantissa / 1024) × 2^-14 = mantissa × 2^-24.
    magnitude = mantissa * Math.pow(2, -24);
  } else {
    magnitude = (1 + mantissa / 1024.0) * Math.pow(2, exponent - 15);
  }

  return negative ? -magnitude : magnitude;
}
|
|
69
|
+
|
|
70
|
+
/**
 * Convert every element of a float32 array to its half-precision bit pattern.
 *
 * @param f32 - Source values.
 * @returns A new array of the same length holding one binary16 pattern per element.
 */
export function quantizeFp16(f32: Float32Array): Uint16Array {
  const halves = new Uint16Array(f32.length);
  f32.forEach((value, idx) => {
    halves[idx] = floatToFp16(value);
  });
  return halves;
}
|
|
77
|
+
|
|
78
|
+
/**
 * Decode an array of half-precision bit patterns back to float32 values.
 *
 * @param fp16 - binary16 bit patterns.
 * @returns A new float32 array of the same length.
 */
export function dequantizeFp16(fp16: Uint16Array): Float32Array {
  const floats = new Float32Array(fp16.length);
  fp16.forEach((pattern, idx) => {
    floats[idx] = fp16ToFloat(pattern);
  });
  return floats;
}
|
|
85
|
+
|
|
86
|
+
/**
 * Symmetric per-tensor Int8 quantization.
 *
 * The scale is chosen so that the largest finite magnitude in the input maps
 * to ±127. Non-finite elements are excluded when choosing the scale so that a
 * single ±Infinity cannot collapse every other value to zero; ±Infinity then
 * saturates to ±127 and NaN is stored as 0 (typed-array conversion of NaN).
 * An all-zero (or empty) input uses a scale of 1.
 *
 * @param f32 - Values to quantize.
 * @returns The quantized bytes and the scale needed to dequantize them.
 */
export function quantizeInt8(f32: Float32Array): QuantizeInt8Result {
  // Largest finite |x|; NaN never satisfies `>` and Infinity is skipped explicitly.
  let maxAbs = 0;
  for (const value of f32) {
    const magnitude = Math.abs(value);
    if (magnitude > maxAbs && magnitude !== Infinity) maxAbs = magnitude;
  }

  // `|| 1.0` avoids dividing by zero when every finite element is 0.
  const scale = maxAbs / 127.0 || 1.0;
  const data = new Int8Array(f32.length);

  for (let i = 0; i < f32.length; i++) {
    data[i] = Math.max(-128, Math.min(127, Math.round(f32[i]! / scale)));
  }

  return { data, scale };
}
|
|
102
|
+
|
|
103
|
+
/**
 * Reverse per-tensor Int8 quantization by multiplying every byte by `scale`.
 *
 * @param int8 - Quantized values.
 * @param scale - Multiplier applied to each element.
 * @returns A new float32 array of the same length.
 */
export function dequantizeInt8(int8: Int8Array, scale: number): Float32Array {
  const restored = new Float32Array(int8.length);
  int8.forEach((q, idx) => {
    restored[idx] = q * scale;
  });
  return restored;
}
|
|
110
|
+
|
|
111
|
+
/**
 * Symmetric Int8 quantization with an independent scale per channel.
 *
 * The input is treated as `numChannels` contiguous, equally sized runs. Within
 * each run the largest finite magnitude maps to ±127; ±Infinity saturates to
 * ±127 and NaN is stored as 0. A channel whose finite values are all zero uses
 * a scale of 1. Scales are stored as float32, and the stored (rounded) scale
 * is the one used for the division so that dequantizing with it is consistent.
 *
 * @param f32 - Values to quantize, laid out channel after channel.
 * @param numChannels - Number of channels; must be a non-negative integer
 *   that divides `f32.length` exactly (0 is accepted only for an empty input).
 * @returns The quantized bytes and one scale per channel.
 * @throws RangeError if `f32.length` cannot be split into `numChannels` equal runs.
 */
export function quantizeInt8PerChannel(f32: Float32Array, numChannels: number): QuantizeInt8PerChannelResult {
  const channelSize = numChannels === 0 ? 0 : f32.length / numChannels;
  // A fractional channel size would produce fractional typed-array indices,
  // which silently read `undefined` and drop writes.
  if (
    !Number.isInteger(numChannels) ||
    numChannels < 0 ||
    !Number.isInteger(channelSize) ||
    channelSize * numChannels !== f32.length
  ) {
    throw new RangeError(
      `quantizeInt8PerChannel: cannot split ${f32.length} elements into ${numChannels} equal channels`,
    );
  }

  const scales = new Float32Array(numChannels);
  const data = new Int8Array(f32.length);

  for (let c = 0; c < numChannels; c++) {
    const base = c * channelSize;

    // Largest finite |x| in this channel.
    let maxAbs = 0;
    for (let j = 0; j < channelSize; j++) {
      const a = Math.abs(f32[base + j]!);
      if (a > maxAbs && a !== Infinity) maxAbs = a;
    }

    scales[c] = maxAbs / 127.0 || 1.0;
    const scale = scales[c]!; // read back the float32-rounded value

    for (let j = 0; j < channelSize; j++) {
      data[base + j] = Math.max(-128, Math.min(127, Math.round(f32[base + j]! / scale)));
    }
  }

  return { data, scales };
}
|
|
133
|
+
|
|
134
|
+
/**
 * Reverse per-channel Int8 quantization.
 *
 * The input is treated as `numChannels` contiguous, equally sized runs; every
 * byte in run `c` is multiplied by `scales[c]`.
 *
 * @param int8 - Quantized values, laid out channel after channel.
 * @param scales - One multiplier per channel; must hold at least `numChannels` entries.
 * @param numChannels - Number of channels; must be a non-negative integer
 *   that divides `int8.length` exactly (0 is accepted only for an empty input).
 * @returns A new float32 array of the same length as `int8`.
 * @throws RangeError if the length cannot be split into `numChannels` equal
 *   runs, or if `scales` has fewer than `numChannels` entries.
 */
export function dequantizeInt8PerChannel(int8: Int8Array, scales: Float32Array, numChannels: number): Float32Array {
  const channelSize = numChannels === 0 ? 0 : int8.length / numChannels;
  // A fractional channel size would index the typed arrays at non-integer
  // positions, silently producing NaN and dropped writes.
  if (
    !Number.isInteger(numChannels) ||
    numChannels < 0 ||
    !Number.isInteger(channelSize) ||
    channelSize * numChannels !== int8.length
  ) {
    throw new RangeError(
      `dequantizeInt8PerChannel: cannot split ${int8.length} elements into ${numChannels} equal channels`,
    );
  }
  if (scales.length < numChannels) {
    throw new RangeError(
      `dequantizeInt8PerChannel: expected ${numChannels} scales, got ${scales.length}`,
    );
  }

  const out = new Float32Array(int8.length);

  for (let c = 0; c < numChannels; c++) {
    const base = c * channelSize;
    const scale = scales[c]!;
    for (let j = 0; j < channelSize; j++) {
      out[base + j] = int8[base + j]! * scale;
    }
  }

  return out;
}
|
|
147
|
+
|
|
148
|
+
/**
 * Report how many bytes `numElements` values occupy at 4, 2 and 1 bytes per
 * element (float32, float16 and int8 respectively).
 *
 * @param numElements - Number of values to store.
 * @returns Byte counts keyed by precision.
 */
export function estimateMemory(numElements: number): MemoryEstimate {
  const fp32 = numElements * 4;
  const fp16 = numElements * 2;
  const int8 = numElements * 1;
  return { fp32, fp16, int8 };
}
|
package/src/index.js
DELETED
|
@@ -1,89 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* MambaCode.js – Entry Point
|
|
3
|
-
*
|
|
4
|
-
* High-performance JavaScript/WGSL Mamba SSM library for browser-based
|
|
5
|
-
* code model training and inference.
|
|
6
|
-
*
|
|
7
|
-
* Quick-start example
|
|
8
|
-
* -------------------
|
|
9
|
-
* import { MambaModel, MambaTrainer, BPETokenizer, initWebGPU } from 'mambacode.js';
|
|
10
|
-
*
|
|
11
|
-
* const { device } = await initWebGPU();
|
|
12
|
-
* const tokenizer = new BPETokenizer();
|
|
13
|
-
* await tokenizer.load('/vocab.json', '/merges.txt');
|
|
14
|
-
*
|
|
15
|
-
* const model = new MambaModel(device, {
|
|
16
|
-
* vocabSize : tokenizer.vocabSize,
|
|
17
|
-
* dModel : 512,
|
|
18
|
-
* numLayers : 8,
|
|
19
|
-
* });
|
|
20
|
-
*
|
|
21
|
-
* const trainer = new MambaTrainer(model, tokenizer);
|
|
22
|
-
* const losses = await trainer.train(myCodeString, { learningRate: 1e-4, epochs: 5 });
|
|
23
|
-
*
|
|
24
|
-
* const generated = await model.generate(tokenizer.encode('function '), 100);
|
|
25
|
-
* console.log(tokenizer.decode(generated));
|
|
26
|
-
*/
|
|
27
|
-
|
|
28
|
-
// ── Core model ────────────────────────────────────────────────────────────────
|
|
29
|
-
export { MambaModel } from './model/mamba_model.js';
|
|
30
|
-
export { MambaBlock } from './model/mamba_block.js';
|
|
31
|
-
|
|
32
|
-
// ── Training ──────────────────────────────────────────────────────────────────
|
|
33
|
-
export { MambaTrainer } from './training/trainer.js';
|
|
34
|
-
export {
|
|
35
|
-
Tensor,
|
|
36
|
-
backward,
|
|
37
|
-
enableGrad,
|
|
38
|
-
noGrad,
|
|
39
|
-
clearTape,
|
|
40
|
-
recordOperation,
|
|
41
|
-
crossEntropyLoss,
|
|
42
|
-
crossEntropyGrad,
|
|
43
|
-
} from './training/autograd.js';
|
|
44
|
-
|
|
45
|
-
// ── Tokenizer ─────────────────────────────────────────────────────────────────
|
|
46
|
-
export { BPETokenizer } from './tokenizer/bpe.js';
|
|
47
|
-
|
|
48
|
-
// ── WebGPU utilities ──────────────────────────────────────────────────────────
|
|
49
|
-
export {
|
|
50
|
-
initWebGPU,
|
|
51
|
-
createStorageBuffer,
|
|
52
|
-
createEmptyStorageBuffer,
|
|
53
|
-
createUniformBuffer,
|
|
54
|
-
createComputePipeline,
|
|
55
|
-
createBindGroup,
|
|
56
|
-
dispatchKernel,
|
|
57
|
-
readBuffer,
|
|
58
|
-
uploadBuffer,
|
|
59
|
-
cdiv,
|
|
60
|
-
} from './utils/gpu_utils.js';
|
|
61
|
-
|
|
62
|
-
// ── Quantization utilities ────────────────────────────────────────────────────
|
|
63
|
-
export {
|
|
64
|
-
quantizeFp16,
|
|
65
|
-
dequantizeFp16,
|
|
66
|
-
floatToFp16,
|
|
67
|
-
fp16ToFloat,
|
|
68
|
-
quantizeInt8,
|
|
69
|
-
dequantizeInt8,
|
|
70
|
-
quantizeInt8PerChannel,
|
|
71
|
-
dequantizeInt8PerChannel,
|
|
72
|
-
estimateMemory,
|
|
73
|
-
} from './utils/quantization.js';
|
|
74
|
-
|
|
75
|
-
// ── Raw WGSL kernel sources (for advanced users / custom pipelines) ───────────
|
|
76
|
-
export { SELECTIVE_SCAN_FORWARD_WGSL, SELECTIVE_SCAN_BACKWARD_WGSL }
|
|
77
|
-
from './kernels/selective_scan.js';
|
|
78
|
-
export { CONV1D_FORWARD_WGSL, CONV1D_BACKWARD_WGSL }
|
|
79
|
-
from './kernels/conv1d.js';
|
|
80
|
-
export { LINEAR_FORWARD_WGSL, LINEAR_BACKWARD_WGSL }
|
|
81
|
-
from './kernels/linear_projection.js';
|
|
82
|
-
export { WEIGHT_UPDATE_WGSL, GRAD_CLIP_WGSL }
|
|
83
|
-
from './kernels/weight_update.js';
|
|
84
|
-
export { ACTIVATIONS_WGSL, ACTIVATIONS_BACKWARD_WGSL }
|
|
85
|
-
from './kernels/activations.js';
|
|
86
|
-
|
|
87
|
-
// ── Library metadata ──────────────────────────────────────────────────────────
|
|
88
|
-
export const VERSION = '0.1.0';
|
|
89
|
-
export const DESCRIPTION = 'MambaCode.js: WebGPU-accelerated Mamba SSM for browser code models';
|
package/src/tokenizer/bpe.js
DELETED
|
@@ -1,256 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* bpe.js – Browser-side Byte Pair Encoding (BPE) tokenizer.
|
|
3
|
-
*
|
|
4
|
-
* Compatible with Qwen3.5-Coder's vocabulary and merge rules.
|
|
5
|
-
* Implements the standard Hugging Face / tiktoken BPE algorithm:
|
|
6
|
-
*
|
|
7
|
-
* 1. Pre-tokenize text into Unicode code point sequences.
|
|
8
|
-
* 2. Apply byte-level encoding for unknown characters.
|
|
9
|
-
* 3. Iteratively merge the highest-priority adjacent byte-pair
|
|
10
|
-
* according to the learnt merge table.
|
|
11
|
-
*
|
|
12
|
-
* Usage
|
|
13
|
-
* -----
|
|
14
|
-
* const tokenizer = new BPETokenizer();
|
|
15
|
-
* await tokenizer.load(vocabUrl, mergesUrl);
|
|
16
|
-
*
|
|
17
|
-
* const ids = tokenizer.encode("function foo() {}");
|
|
18
|
-
* const text = tokenizer.decode(ids);
|
|
19
|
-
*/
|
|
20
|
-
|
|
21
|
-
/**
 * Build the byte → character table used for byte-level BPE.
 *
 * Bytes in the ranges 0x21–0x7E, 0xA1–0xAC and 0xAE–0xFF map to the character
 * with the same code point. Every remaining byte, taken in ascending order,
 * maps to the next unused code point starting at 256, so each of the 256 byte
 * values gets a distinct character.
 *
 * @returns {Map<number, string>} byte value → single-character string
 */
function buildByteEncoder() {
  const mapsToItself = (byte) =>
    (byte >= 0x21 && byte <= 0x7E) ||   // ! → ~
    (byte >= 0xA1 && byte <= 0xAC) ||   // ¡ → ¬
    (byte >= 0xAE && byte <= 0xFF);     // ® → ÿ

  /** @type {Map<number, string>} */
  const table = new Map();

  // First pass: bytes that are represented by their own code point.
  for (let byte = 0; byte < 256; byte++) {
    if (mapsToItself(byte)) {
      table.set(byte, String.fromCodePoint(byte));
    }
  }

  // Second pass: everything else is shifted to code points 256, 257, ...
  let nextOffset = 0;
  for (let byte = 0; byte < 256; byte++) {
    if (table.has(byte)) continue;
    table.set(byte, String.fromCodePoint(256 + nextOffset));
    nextOffset += 1;
  }

  return table;
}
|
|
46
|
-
|
|
47
|
-
/** Byte value → byte-level character. */
const BYTE_ENCODER = buildByteEncoder();

/** Inverse of BYTE_ENCODER: byte-level character → byte value. */
const BYTE_DECODER = new Map();
for (const [byte, ch] of BYTE_ENCODER) {
  BYTE_DECODER.set(ch, byte);
}
|
|
49
|
-
|
|
50
|
-
/**
 * Pre-tokenisation pattern applied before BPE. Alternatives, in priority order:
 *   1. English contraction suffixes ('s, 't, 're, 've, 'm, 'll, 'd)
 *   2. a run of letters, optionally led by one non-letter, non-digit,
 *      non-newline character
 *   3. one to three digits
 *   4. an optional space plus a run of symbols, with any trailing newlines
 *   5. newline runs together with their leading whitespace
 *   6. whitespace that is not followed by a non-whitespace character
 *   7. any other whitespace run
 * Global + Unicode flags are required for \p{…} classes and for match-all use.
 */
const PRE_TOKENIZE_RE = /(?:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+/gu;
|
|
54
|
-
|
|
55
|
-
/**
 * Byte-level BPE tokenizer.
 *
 * Text is split with PRE_TOKENIZE_RE, each piece is UTF-8 encoded and mapped
 * through BYTE_ENCODER, and adjacent symbols are merged according to the
 * loaded merge ranks. Decoding reverses the byte mapping via BYTE_DECODER.
 */
export class BPETokenizer {
  constructor() {
    /** @type {Map<string, number>} token → id */
    this.vocab = new Map();
    /** @type {Map<number, string>} id → token */
    this.idToToken = new Map();
    /** @type {Map<string, number>} "a b" → rank (lower = higher priority) */
    this.merges = new Map();

    // Special token strings; their ids are resolved when a vocabulary is loaded.
    this.bosToken = '<|im_start|>';
    this.eosToken = '<|im_end|>';
    this.padToken = '<|endoftext|>';
    this.unkToken = '<unk>';

    this.bosId = null;
    this.eosId = null;
    this.padId = null;
  }

  /**
   * Fetch a URL and reject on a non-2xx response, so that an HTTP error page
   * is never parsed as vocabulary or merge data.
   *
   * @param {string} url
   * @returns {Promise<Response>}
   */
  static async _fetchOk(url) {
    const res = await fetch(url);
    if (!res.ok) {
      throw new Error(`BPETokenizer: failed to fetch ${url} (${res.status} ${res.statusText})`);
    }
    return res;
  }

  /**
   * Load vocabulary and merge rules from JSON / text URLs or in-memory values.
   *
   *   vocab.json  – { "token": id, ... }
   *   merges.txt  – one merge per line: "a b" (sorted by rank)
   *
   * The tokenizer state is only replaced once both inputs have been obtained.
   *
   * @param {string|object}   vocab   – URL string or pre-parsed vocab object
   * @param {string|string[]} merges  – URL string or array of merge strings
   * @returns {Promise<void>}
   * @throws {Error} if a URL is given and the HTTP response is not OK
   */
  async load(vocab, merges) {
    const vocabObj = typeof vocab === 'string'
      ? await (await BPETokenizer._fetchOk(vocab)).json()
      : vocab;

    let mergeLines;
    if (typeof merges === 'string') {
      const txt = await (await BPETokenizer._fetchOk(merges)).text();
      // Skip blank lines and the "#version" header only. Other lines that
      // begin with '#' are real merges (e.g. "# #") and must keep their rank.
      mergeLines = txt
        .split('\n')
        .map(line => line.trim())
        .filter(line => line !== '' && !line.startsWith('#version'));
    } else {
      mergeLines = merges.map(line => line.trim());
    }

    this.loadFromObjects(vocabObj, mergeLines);
  }

  /**
   * Load from plain JavaScript objects (no network fetch).
   * Useful for bundling a small vocabulary directly.
   *
   * @param {Object}   vocabObj  – { token: id }
   * @param {string[]} mergeArr  – ["a b", "c d", ...]; index is the merge rank
   */
  loadFromObjects(vocabObj, mergeArr) {
    this.vocab     = new Map(Object.entries(vocabObj).map(([k, v]) => [k, Number(v)]));
    this.idToToken = new Map([...this.vocab].map(([k, v]) => [v, k]));
    this.merges    = new Map(mergeArr.map((m, i) => [m, i]));

    // Special ids stay null when the vocabulary does not contain the token.
    this.bosId = this.vocab.get(this.bosToken) ?? null;
    this.eosId = this.vocab.get(this.eosToken) ?? null;
    this.padId = this.vocab.get(this.padToken) ?? null;
  }

  // ── Encoding ──────────────────────────────────────────────────────────────

  /**
   * Encode a string to an array of token IDs.
   *
   * Pieces that are missing from the vocabulary fall back to their individual
   * characters; characters that are also missing are dropped.
   *
   * @param {string} text
   * @param {{ addBos?: boolean, addEos?: boolean }} [opts]
   * @returns {number[]}
   */
  encode(text, opts = {}) {
    const words = text.match(PRE_TOKENIZE_RE) ?? [];
    const ids = [];
    const utf8 = new TextEncoder(); // one encoder for the whole call

    if (opts.addBos && this.bosId !== null) ids.push(this.bosId);

    for (const word of words) {
      // Convert to byte-level encoding
      const byteStr = Array.from(utf8.encode(word), b => BYTE_ENCODER.get(b) ?? '?').join('');

      for (const tok of this._bpe(byteStr)) {
        const id = this.vocab.get(tok);
        if (id !== undefined) {
          ids.push(id);
          continue;
        }
        // Fallback: encode each character individually
        for (const ch of tok) {
          const cid = this.vocab.get(ch);
          if (cid !== undefined) ids.push(cid);
        }
      }
    }

    if (opts.addEos && this.eosId !== null) ids.push(this.eosId);
    return ids;
  }

  /**
   * Decode an array of token IDs back to a string. Unknown ids are skipped.
   *
   * @param {number[]} ids
   * @returns {string}
   */
  decode(ids) {
    let byteStr = '';
    for (const id of ids) {
      const tok = this.idToToken.get(id);
      if (tok !== undefined) byteStr += tok;
    }

    // Convert byte-level string back to raw bytes then UTF-8 decode.
    // Characters outside the byte table fall back to their code point, which
    // Uint8Array reduces modulo 256.
    const bytes = new Uint8Array(
      [...byteStr].map(ch => BYTE_DECODER.get(ch) ?? ch.codePointAt(0))
    );

    // A TextDecoder created without { fatal: true } replaces malformed
    // sequences with U+FFFD instead of throwing, so no try/catch is needed.
    return new TextDecoder('utf-8').decode(bytes);
  }

  // ── BPE merge algorithm ───────────────────────────────────────────────────

  /**
   * Apply BPE merges to a byte-encoded word.
   *
   * A word that is already a vocabulary entry is returned as a single token
   * without consulting the merge table.
   *
   * @param {string} word  – Space-free string of byte-level characters
   * @returns {string[]}   – Merged token pieces
   */
  _bpe(word) {
    if (this.vocab.has(word)) return [word];

    // Start: each character is its own symbol
    const symbols = [...word];

    // Iteratively merge the highest-priority (lowest-rank) adjacent pair;
    // on equal rank the leftmost pair wins.
    while (symbols.length > 1) {
      let bestRank = Infinity;
      let bestIdx  = -1;

      for (let i = 0; i < symbols.length - 1; i++) {
        const rank = this.merges.get(symbols[i] + ' ' + symbols[i + 1]);
        if (rank !== undefined && rank < bestRank) {
          bestRank = rank;
          bestIdx  = i;
        }
      }

      if (bestIdx === -1) break; // no more merges available

      // Replace the pair at bestIdx with its concatenation.
      symbols.splice(bestIdx, 2, symbols[bestIdx] + symbols[bestIdx + 1]);
    }

    return symbols;
  }

  // ── Padding / truncation helpers ─────────────────────────────────────────

  /**
   * Pad or truncate a sequence to a fixed length.
   *
   * Truncation always keeps the first `maxLen` ids; `side` only controls
   * where padding is added. The pad id falls back to 0 when the vocabulary
   * has no pad token.
   *
   * @param {number[]} ids
   * @param {number}   maxLen
   * @param {'right'|'left'} [side='right']
   * @returns {number[]}
   */
  padOrTruncate(ids, maxLen, side = 'right') {
    if (ids.length >= maxLen) return ids.slice(0, maxLen);
    const padId = this.padId ?? 0;
    const pad   = new Array(maxLen - ids.length).fill(padId);
    return side === 'right' ? [...ids, ...pad] : [...pad, ...ids];
  }

  /** @returns {number} number of entries in the loaded vocabulary */
  get vocabSize() { return this.vocab.size; }
}
|