webinfer 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/dist/attention/block-sparse/format.d.ts +52 -0
- package/dist/attention/block-sparse/patterns/causal.d.ts +16 -0
- package/dist/attention/block-sparse/patterns/sliding.d.ts +22 -0
- package/dist/attention/flash-attention.d.ts +30 -0
- package/dist/attention/index.d.ts +9 -0
- package/dist/attention/paged-kv/block-manager.d.ts +102 -0
- package/dist/attention/paged-kv/index.d.ts +5 -0
- package/dist/attention/paged-kv/page-table.d.ts +99 -0
- package/dist/attention/scheduler.d.ts +40 -0
- package/dist/core/buffer-pool.d.ts +18 -0
- package/dist/core/device.d.ts +23 -0
- package/dist/core/tensor.d.ts +25 -0
- package/dist/index.d.ts +22 -0
- package/dist/index.js +4228 -0
- package/dist/inference/engine.d.ts +69 -0
- package/dist/inference/generate.d.ts +30 -0
- package/dist/inference/index.d.ts +7 -0
- package/dist/inference/types.d.ts +161 -0
- package/dist/jit/compiler.d.ts +23 -0
- package/dist/jit/kernel-cache.d.ts +21 -0
- package/dist/model/gguf.d.ts +90 -0
- package/dist/model/index.d.ts +16 -0
- package/dist/model/safetensors.d.ts +38 -0
- package/dist/model/types.d.ts +182 -0
- package/dist/ops/activations.d.ts +43 -0
- package/dist/ops/elementwise.d.ts +38 -0
- package/dist/ops/embedding.d.ts +30 -0
- package/dist/ops/matmul.d.ts +21 -0
- package/dist/ops/normalization.d.ts +24 -0
- package/dist/ops/reshape.d.ts +39 -0
- package/dist/ops/rope.d.ts +32 -0
- package/dist/ops/softmax.d.ts +18 -0
- package/dist/quantization/index.d.ts +6 -0
- package/dist/quantization/qmatmul.d.ts +38 -0
- package/dist/quantization/quantize.d.ts +52 -0
- package/dist/sampling/index.d.ts +6 -0
- package/dist/sampling/sampler.d.ts +39 -0
- package/dist/sampling/top-k.d.ts +24 -0
- package/dist/sampling/top-p.d.ts +14 -0
- package/package.json +54 -0
package/dist/index.js
ADDED
|
@@ -0,0 +1,4228 @@
|
|
|
1
|
+
// src/core/device.ts
|
|
2
|
+
/**
 * Wraps a WebGPU `GPUDevice` together with a small summary of the adapter it
 * was requested from (vendor, architecture and a few compute limits).
 *
 * Instances are normally obtained through the async {@link WebInferDevice.create}
 * factory; the constructor only stores what it is given.
 */
class WebInferDevice {
  _device;
  _info;

  /**
   * @param {GPUDevice} device - An already-acquired WebGPU device.
   * @param {object} info - Device summary in the shape returned by
   *   {@link WebInferDevice.detectDeviceInfo}.
   */
  constructor(device, info) {
    this._device = device;
    this._info = info;
  }

  /**
   * Requests a high-performance adapter and a device that opts in to the
   * adapter's own maxima for the storage/compute limits listed below.
   *
   * @returns {Promise<WebInferDevice>}
   * @throws {Error} If WebGPU is unavailable or no adapter is found.
   */
  static async create() {
    // `navigator` itself may be missing outside browsers; treat that the same
    // as a browser without WebGPU instead of surfacing a ReferenceError.
    if (typeof navigator === "undefined" || !navigator.gpu) {
      throw new Error("WebGPU not supported in this browser");
    }
    const adapter = await navigator.gpu.requestAdapter({
      powerPreference: "high-performance"
    });
    if (!adapter) {
      throw new Error("No WebGPU adapter found");
    }
    const device = await adapter.requestDevice({
      requiredLimits: {
        maxStorageBufferBindingSize: adapter.limits.maxStorageBufferBindingSize,
        maxBufferSize: adapter.limits.maxBufferSize,
        maxComputeWorkgroupStorageSize: adapter.limits.maxComputeWorkgroupStorageSize,
        maxComputeInvocationsPerWorkgroup: adapter.limits.maxComputeInvocationsPerWorkgroup
      }
    });
    // Device loss is only logged here; no recovery is attempted.
    device.lost.then((lostInfo) => {
      console.error("WebGPU device lost:", lostInfo.message);
    });
    const info = WebInferDevice.detectDeviceInfo(adapter, device);
    return new WebInferDevice(device, info);
  }

  /**
   * Classifies the adapter into a coarse vendor bucket and copies a few
   * limits from the device.
   *
   * @param {GPUAdapter} adapter - Adapter whose `info` (if present) is inspected.
   * @param {GPUDevice} device - Device whose `limits` are reported.
   * @returns {{vendor: string, architecture: string, maxWorkgroupSize: number,
   *   maxComputeInvocationsPerWorkgroup: number, maxStorageBufferBindingSize: number}}
   *   `vendor` is one of "apple", "nvidia", "intel", "amd" or "unknown".
   */
  static detectDeviceInfo(adapter, device) {
    // `adapter.info` is absent on some implementations; fall back to an empty
    // record so detection degrades to "unknown" rather than throwing.
    const adapterInfo = adapter.info ?? {};
    const vendorLower = (adapterInfo.vendor || "").toLowerCase();
    const architectureLower = (adapterInfo.architecture || "").toLowerCase();
    const mentions = (needle) => vendorLower.includes(needle) || architectureLower.includes(needle);

    let vendor = "unknown";
    if (mentions("apple")) {
      vendor = "apple";
    } else if (mentions("nvidia")) {
      vendor = "nvidia";
    } else if (mentions("intel")) {
      vendor = "intel";
    } else if (vendorLower.includes("amd") || vendorLower.includes("advanced micro")) {
      // AMD is matched on the vendor string only.
      vendor = "amd";
    }

    return {
      vendor,
      architecture: adapterInfo.architecture || "unknown",
      maxWorkgroupSize: device.limits.maxComputeWorkgroupSizeX,
      maxComputeInvocationsPerWorkgroup: device.limits.maxComputeInvocationsPerWorkgroup,
      maxStorageBufferBindingSize: device.limits.maxStorageBufferBindingSize
    };
  }

  /** @returns {GPUDevice} The wrapped device. */
  get device() {
    return this._device;
  }

  /** @returns {object} The summary computed by `detectDeviceInfo`. */
  get info() {
    return this._info;
  }

  /** @returns {GPUSupportedLimits} The limits of the wrapped device. */
  get limits() {
    return this._device.limits;
  }

  /** @returns {GPUCommandEncoder} A fresh command encoder from the device. */
  createCommandEncoder() {
    return this._device.createCommandEncoder();
  }

  /**
   * Submits command buffers to the device queue.
   * @param {GPUCommandBuffer[]} commandBuffers
   */
  submit(commandBuffers) {
    this._device.queue.submit(commandBuffers);
  }

  /** Destroys the wrapped device. */
  dispose() {
    this._device.destroy();
  }
}
|
|
74
|
+
// src/core/tensor.ts
|
|
75
|
+
// Size in bytes of one element for every supported dtype tag.
var DTYPE_BYTES = {
  f32: 4,
  f16: 2,
  i32: 4,
  u32: 4
};

/**
 * An n-dimensional array backed by a WebGPU storage buffer.
 *
 * The GPU buffer is padded up to a multiple of 16 bytes. Host data is moved
 * through a typed array that matches the tensor's dtype (see
 * {@link Tensor._hostArrayFor}), so integer tensors round-trip as integers
 * instead of being reinterpreted as 32-bit floats.
 */
class Tensor {
  _device;
  _shape;
  _dtype;
  _buffer;
  _disposed = false;

  /**
   * Maps a dtype tag to the typed-array constructor used for host transfers.
   *
   * f16 has no universally available typed array, so it is exposed as raw
   * 16-bit words (`Uint16Array`).
   * NOTE(review): callers uploading f16 data are therefore expected to pass
   * already-encoded half-precision bit patterns — confirm against callers.
   *
   * @param {string} dtype
   * @returns {Function} Typed-array constructor.
   * @throws {Error} If the dtype is not one of f32, f16, i32, u32.
   */
  static _hostArrayFor(dtype) {
    switch (dtype) {
      case "f32":
        return Float32Array;
      case "i32":
        return Int32Array;
      case "u32":
        return Uint32Array;
      case "f16":
        return Uint16Array;
      default:
        throw new Error(`Unsupported tensor dtype: ${dtype}`);
    }
  }

  /**
   * @param {object} device - Wrapper exposing `.device` (the GPUDevice),
   *   `createCommandEncoder()` and `submit()`.
   * @param {number[]} shape - Dimensions; copied and frozen.
   * @param {string} [dtype="f32"] - One of the keys of `DTYPE_BYTES`.
   * @param {ArrayLike<number>} [data] - Optional initial contents, written
   *   through a dtype-matching typed array while the buffer is mapped.
   * @throws {Error} If the dtype is unknown.
   */
  constructor(device, shape, dtype = "f32", data) {
    const HostArray = Tensor._hostArrayFor(dtype);
    this._device = device;
    this._shape = Object.freeze([...shape]);
    this._dtype = dtype;
    const byteSize = this.numel * DTYPE_BYTES[dtype];
    const alignedSize = Math.ceil(byteSize / 16) * 16;
    this._buffer = device.device.createBuffer({
      size: alignedSize,
      usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
      mappedAtCreation: !!data
    });
    if (data) {
      try {
        // View the whole (padded) range with the dtype's element type so the
        // element count and value conversion match the declared dtype.
        new HostArray(this._buffer.getMappedRange()).set(data);
      } catch (err) {
        // Do not leave a mapped, half-initialised buffer behind.
        this._buffer.destroy();
        throw err;
      }
      this._buffer.unmap();
    }
  }

  /**
   * Async-flavoured alias for the constructor with `data` and `dtype` swapped.
   * @returns {Promise<Tensor>}
   */
  static async fromArray(device, shape, data, dtype = "f32") {
    return new Tensor(device, shape, dtype, data);
  }

  /**
   * Creates a tensor whose elements are all zero.
   * @returns {Tensor}
   */
  static zeros(device, shape, dtype = "f32") {
    const numel = shape.reduce((a, b) => a * b, 1);
    const HostArray = Tensor._hostArrayFor(dtype);
    return new Tensor(device, shape, dtype, new HostArray(numel));
  }

  /**
   * Creates a tensor filled with `Math.random()` samples in [0, 1).
   *
   * NOTE(review): the samples are generated as 32-bit floats and converted
   * element-wise on upload, so this is only meaningful for dtype "f32".
   *
   * @returns {Tensor}
   */
  static rand(device, shape, dtype = "f32") {
    const numel = shape.reduce((a, b) => a * b, 1);
    const data = new Float32Array(numel);
    for (let i = 0; i < numel; i++) {
      data[i] = Math.random();
    }
    return new Tensor(device, shape, dtype, data);
  }

  /** @returns {ReadonlyArray<number>} The frozen shape. */
  get shape() {
    return this._shape;
  }

  /** @returns {string} The dtype tag. */
  get dtype() {
    return this._dtype;
  }

  /** @returns {number} Product of all dimensions (1 for an empty shape). */
  get numel() {
    return this._shape.reduce((a, b) => a * b, 1);
  }

  /** @returns {number} Unpadded payload size in bytes. */
  get byteSize() {
    return this.numel * DTYPE_BYTES[this._dtype];
  }

  /**
   * @returns {GPUBuffer} The backing buffer.
   * @throws {Error} If this tensor has been disposed.
   */
  get buffer() {
    if (this._disposed) {
      throw new Error("Tensor has been disposed");
    }
    return this._buffer;
  }

  /** @returns {object} The device wrapper this tensor was created with. */
  get device() {
    return this._device;
  }

  /**
   * Copies the tensor contents back to the host via a temporary staging
   * buffer.
   *
   * @returns {Promise<Float32Array|Int32Array|Uint32Array|Uint16Array>} A
   *   fresh typed array of `numel` elements whose type follows the dtype
   *   (Float32Array for "f32").
   * @throws {Error} If this tensor has been disposed.
   */
  async toArray() {
    if (this._disposed) {
      throw new Error("Tensor has been disposed");
    }
    const HostArray = Tensor._hostArrayFor(this._dtype);
    const alignedSize = Math.ceil(this.byteSize / 16) * 16;
    const stagingBuffer = this._device.device.createBuffer({
      size: alignedSize,
      usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
    });
    try {
      const encoder = this._device.createCommandEncoder();
      encoder.copyBufferToBuffer(this._buffer, 0, stagingBuffer, 0, alignedSize);
      this._device.submit([encoder.finish()]);
      await stagingBuffer.mapAsync(GPUMapMode.READ);
      // View only the payload (not the padding) and copy it out before unmapping.
      const payload = new HostArray(stagingBuffer.getMappedRange(), 0, this.numel).slice();
      stagingBuffer.unmap();
      return payload;
    } finally {
      // Release the staging buffer even when mapping or the copy fails.
      stagingBuffer.destroy();
    }
  }

  /**
   * Returns a new Tensor object with a different shape that shares this
   * tensor's GPU buffer (no data is copied).
   *
   * The view keeps its own disposed flag, so disposing either object destroys
   * the buffer used by both.
   *
   * @param {number[]} newShape - Must describe the same number of elements.
   * @returns {Tensor}
   * @throws {Error} If the element counts differ.
   */
  reshape(newShape) {
    const newNumel = newShape.reduce((a, b) => a * b, 1);
    if (newNumel !== this.numel) {
      throw new Error(`Cannot reshape tensor of size ${this.numel} to shape [${newShape}]`);
    }
    // Bypass the constructor so that no new GPU buffer is allocated.
    const view = Object.create(Tensor.prototype);
    view._device = this._device;
    view._shape = Object.freeze([...newShape]);
    view._dtype = this._dtype;
    view._buffer = this._buffer;
    view._disposed = false;
    return view;
  }

  /** Destroys the GPU buffer; subsequent calls are no-ops. */
  dispose() {
    if (!this._disposed) {
      this._buffer.destroy();
      this._disposed = true;
    }
  }
}
|
|
181
|
+
// src/core/buffer-pool.ts
|
|
182
|
+
/**
 * Recycles GPU buffers grouped by power-of-two size classes.
 *
 * Buffers handed out by {@link BufferPool#acquire} stay owned by the pool;
 * callers give them back with {@link BufferPool#release} and the pool
 * destroys everything in {@link BufferPool#dispose}.
 */
class BufferPool {
  device;
  pools = new Map;
  sizeClasses;

  /**
   * @param {GPUDevice} device - Device used to create new buffers.
   */
  constructor(device) {
    this.device = device;
    // Size classes double from 256 bytes up to and including 1 GiB.
    const largestClass = 1024 * 1024 * 1024;
    const classes = [];
    let bytes = 256;
    while (bytes <= largestClass) {
      classes.push(bytes);
      bytes *= 2;
    }
    this.sizeClasses = classes;
  }

  /**
   * @param {number} size - Requested size in bytes.
   * @returns {number} The smallest size class that fits, or the next power of
   *   two when the request is larger than every predefined class.
   */
  getSizeClass(size) {
    const fitting = this.sizeClasses.find((candidate) => candidate >= size);
    if (fitting !== undefined) {
      return fitting;
    }
    return Math.pow(2, Math.ceil(Math.log2(size)));
  }

  /**
   * Returns an idle pooled buffer of the matching size class whose usage
   * flags include `usage`, or creates a new one.
   *
   * @param {number} size - Minimum size in bytes.
   * @param {number} usage - Required GPUBufferUsage flags; newly created
   *   buffers additionally get COPY_SRC and COPY_DST.
   * @returns {GPUBuffer}
   */
  acquire(size, usage) {
    const sizeClass = this.getSizeClass(size);
    const bucket = this.pools.get(sizeClass);
    if (bucket) {
      const idle = bucket.find((entry) => !entry.inUse && (entry.buffer.usage & usage) === usage);
      if (idle) {
        idle.inUse = true;
        return idle.buffer;
      }
    }
    const buffer = this.device.createBuffer({
      size: sizeClass,
      usage: usage | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
    });
    const entry = { buffer, size: sizeClass, inUse: true };
    if (this.pools.has(sizeClass)) {
      this.pools.get(sizeClass).push(entry);
    } else {
      this.pools.set(sizeClass, [entry]);
    }
    return buffer;
  }

  /**
   * Marks a previously acquired buffer as idle. Buffers that are not tracked
   * by this pool are ignored.
   *
   * @param {GPUBuffer} buffer
   */
  release(buffer) {
    for (const bucket of this.pools.values()) {
      const owner = bucket.find((entry) => entry.buffer === buffer);
      if (owner) {
        owner.inUse = false;
        return;
      }
    }
  }

  /**
   * @returns {{totalBuffers: number, inUse: number, totalBytes: number}}
   *   Counts across every size class; `totalBytes` sums the class sizes.
   */
  getStats() {
    const stats = { totalBuffers: 0, inUse: 0, totalBytes: 0 };
    this.pools.forEach((bucket) => {
      bucket.forEach((entry) => {
        stats.totalBuffers += 1;
        stats.totalBytes += entry.size;
        if (entry.inUse) {
          stats.inUse += 1;
        }
      });
    });
    return stats;
  }

  /** Destroys every pooled buffer and forgets all size classes. */
  dispose() {
    this.pools.forEach((bucket) => {
      bucket.forEach((entry) => entry.buffer.destroy());
    });
    this.pools.clear();
  }
}
|
|
260
|
+
// src/jit/kernel-cache.ts
|
|
261
|
+
/**
 * Keyed store for compiled pipelines with hit/miss counters.
 *
 * The device passed to the constructor is only stored on the instance; the
 * cache itself never calls into it.
 */
class KernelCache {
  device;
  cache = new Map;
  hits = 0;
  misses = 0;

  /**
   * @param {GPUDevice} device - Device the cached pipelines belong to.
   */
  constructor(device) {
    this.device = device;
  }

  /**
   * Returns the entry stored under `key`, building and storing it with
   * `createFn` when there is no truthy entry yet.
   *
   * @param {string} key
   * @param {() => *} createFn - Invoked only on a miss.
   * @returns {*} The cached or freshly created value.
   */
  getOrCreate(key, createFn) {
    const cached = this.cache.get(key);
    if (!cached) {
      // The miss is counted before the factory runs, so it is recorded even
      // when the factory throws.
      this.misses += 1;
      const created = createFn();
      this.cache.set(key, created);
      return created;
    }
    this.hits += 1;
    return cached;
  }

  /**
   * @param {string} key
   * @returns {boolean} Whether an entry exists for `key`.
   */
  has(key) {
    return this.cache.has(key);
  }

  /**
   * Looks up an entry; a truthy result counts as a hit, an absent one is not
   * counted as a miss.
   *
   * @param {string} key
   * @returns {*} The entry, or `undefined`.
   */
  get(key) {
    const found = this.cache.get(key);
    if (found) {
      this.hits += 1;
    }
    return found;
  }

  /**
   * Stores `pipeline` under `key` without touching the counters.
   * @param {string} key
   * @param {*} pipeline
   */
  set(key, pipeline) {
    this.cache.set(key, pipeline);
  }

  /** @returns {{hits: number, misses: number, size: number}} */
  getStats() {
    const { hits, misses } = this;
    return { hits, misses, size: this.cache.size };
  }

  /** Removes every entry and resets both counters. */
  clear() {
    this.cache.clear();
    this.hits = 0;
    this.misses = 0;
  }
}
|
|
305
|
+
// src/jit/compiler.ts
|
|
306
|
+
/**
 * Generates and compiles WGSL compute kernels, memoising the resulting
 * pipelines in a shared kernel cache.
 */
class WGSLCompiler {
  device;
  cache;
  deviceInfo;

  /**
   * @param {GPUDevice} device - Device used to create shader modules and pipelines.
   * @param {object} cache - Cache exposing `getOrCreate(key, factory)` and `getStats()`.
   * @param {object} deviceInfo - Device summary; only `vendor` is read here.
   */
  constructor(device, cache, deviceInfo) {
    this.device = device;
    this.cache = cache;
    this.deviceInfo = deviceInfo;
  }

  /**
   * Picks default tile dimensions from the GPU vendor.
   *
   * @param {object} config - Currently unused; kept for interface stability.
   * @returns {{tileM: number, tileN: number, tileK: number}}
   */
  selectTileSize(config) {
    if (this.deviceInfo.vendor === "nvidia") {
      return { tileM: 32, tileN: 32, tileK: 16 };
    }
    // Apple and every other vendor share the same conservative default.
    return { tileM: 16, tileN: 16, tileK: 16 };
  }

  /**
   * Returns a (possibly cached) compute pipeline for C[M,N] = A[M,K] @ B[K,N].
   *
   * @param {{M: number, N: number, K: number, tileM?: number, tileN?: number, tileK?: number}} config
   *   Explicit tile sizes override the vendor defaults.
   * @returns {GPUComputePipeline}
   */
  compileMatMul(config) {
    const defaults = this.selectTileSize(config);
    const tileM = config.tileM ?? defaults.tileM;
    const tileN = config.tileN ?? defaults.tileN;
    const tileK = config.tileK ?? defaults.tileK;
    // NOTE(review): the generated shader reads M/N/K from the uniform block
    // only, so keying on them creates one pipeline per shape; the key format
    // is left unchanged here in case other code depends on it.
    const key = `matmul_${config.M}_${config.N}_${config.K}_${tileM}_${tileN}_${tileK}`;
    return this.cache.getOrCreate(key, () => {
      const wgsl = this.generateMatMulWGSL(config.M, config.N, config.K, tileM, tileN, tileK);
      const shaderModule = this.device.createShaderModule({
        code: wgsl
      });
      return this.device.createComputePipeline({
        layout: "auto",
        compute: {
          module: shaderModule,
          entryPoint: "main"
        }
      });
    });
  }

  /**
   * Builds the WGSL source of a tiled matrix multiplication.
   *
   * One workgroup of tileN x tileM invocations produces one tileM x tileN
   * tile of C. The A and B tiles are staged in workgroup memory with a
   * strided cooperative load: every invocation walks the tile by its linear
   * index, so each tile element is written exactly once for any combination
   * of tileM, tileN and tileK. (Indexing the K axis directly with the local
   * invocation id is only valid when all three tile sizes are equal.)
   *
   * @param {number} M - Unused by the generated source (read from `params`).
   * @param {number} N - Unused by the generated source (read from `params`).
   * @param {number} K - Unused by the generated source (read from `params`).
   * @param {number} tileM - Rows of C per workgroup.
   * @param {number} tileN - Columns of C per workgroup.
   * @param {number} tileK - Depth of one staged slice along K.
   * @returns {string} WGSL source with entry point `main`.
   */
  generateMatMulWGSL(M, N, K, tileM, tileN, tileK) {
    const workgroupSizeX = tileN;
    const workgroupSizeY = tileM;
    const invocations = tileM * tileN;
    const tileAElements = tileM * tileK;
    const tileBElements = tileK * tileN;
    return `
// WebInfer MatMul Kernel
// C[M,N] = A[M,K] @ B[K,N]
// Tile size: ${tileM}x${tileN}x${tileK}

struct Params {
  M: u32,
  N: u32,
  K: u32,
}

@group(0) @binding(0) var<storage, read> A: array<f32>;
@group(0) @binding(1) var<storage, read> B: array<f32>;
@group(0) @binding(2) var<storage, read_write> C: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

var<workgroup> tileA: array<f32, ${tileAElements}>;
var<workgroup> tileB: array<f32, ${tileBElements}>;

@compute @workgroup_size(${workgroupSizeX}, ${workgroupSizeY})
fn main(
  @builtin(global_invocation_id) global_id: vec3<u32>,
  @builtin(local_invocation_id) local_id: vec3<u32>,
  @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
  let tileRowBase = workgroup_id.y * ${tileM}u;
  let tileColBase = workgroup_id.x * ${tileN}u;

  let localRow = local_id.y;
  let localCol = local_id.x;

  let row = tileRowBase + localRow;
  let col = tileColBase + localCol;

  // Linear index of this invocation inside the workgroup
  let localIndex = localRow * ${tileN}u + localCol;

  var sum: f32 = 0.0;

  let numTiles = (params.K + ${tileK}u - 1u) / ${tileK}u;

  for (var t: u32 = 0u; t < numTiles; t = t + 1u) {
    let kBase = t * ${tileK}u;

    // Load tile of A into shared memory (strided over the whole workgroup)
    for (var i: u32 = localIndex; i < ${tileAElements}u; i = i + ${invocations}u) {
      let aRow = tileRowBase + i / ${tileK}u;
      let aCol = kBase + i % ${tileK}u;
      if (aRow < params.M && aCol < params.K) {
        tileA[i] = A[aRow * params.K + aCol];
      } else {
        tileA[i] = 0.0;
      }
    }

    // Load tile of B into shared memory (strided over the whole workgroup)
    for (var i: u32 = localIndex; i < ${tileBElements}u; i = i + ${invocations}u) {
      let bRow = kBase + i / ${tileN}u;
      let bCol = tileColBase + i % ${tileN}u;
      if (bRow < params.K && bCol < params.N) {
        tileB[i] = B[bRow * params.N + bCol];
      } else {
        tileB[i] = 0.0;
      }
    }

    workgroupBarrier();

    // Compute partial dot product
    for (var k: u32 = 0u; k < ${tileK}u; k = k + 1u) {
      sum = sum + tileA[localRow * ${tileK}u + k] * tileB[k * ${tileN}u + localCol];
    }

    workgroupBarrier();
  }

  // Write result
  if (row < params.M && col < params.N) {
    C[row * params.N + col] = sum;
  }
}
`;
  }

  /** @returns {object} Statistics reported by the underlying kernel cache. */
  getCacheStats() {
    return this.cache.getStats();
  }
}
|
|
421
|
+
// src/ops/matmul.ts
|
|
422
|
+
// Lazily created compiler/cache pair shared by all matmul calls.
var compilerInstance = null;
var cacheInstance = null;

/**
 * Returns the shared WGSL compiler for `device`, creating it on first use.
 *
 * The pair is rebuilt when a different GPU device is passed, because
 * pipelines and shader modules created on one device cannot be used with
 * another.
 *
 * @param {object} device - Wrapper exposing `.device` (GPUDevice) and `.info`.
 * @returns {WGSLCompiler}
 */
function getCompiler(device) {
  const boundToOtherDevice = compilerInstance !== null && compilerInstance.device !== device.device;
  if (!compilerInstance || !cacheInstance || boundToOtherDevice) {
    cacheInstance = new KernelCache(device.device);
    compilerInstance = new WGSLCompiler(device.device, cacheInstance, device.info);
  }
  return compilerInstance;
}
|
|
431
|
+
/**
 * Multiplies two 2-D tensors on the GPU: C[M,N] = A[M,K] @ B[K,N].
 *
 * The workgroup grid is derived from the same tile sizes the compiler uses
 * for the kernel, so exactly one workgroup is launched per output tile.
 * Resolves only after the queue reports the submitted work as done.
 *
 * @param {object} device - Wrapper exposing `.device`, `createCommandEncoder()` and `submit()`.
 * @param {Tensor} a - Left operand of shape [M, K].
 * @param {Tensor} b - Right operand of shape [K, N].
 * @returns {Promise<Tensor>} New tensor of shape [M, N].
 * @throws {Error} If either input is not 2-D or the inner dimensions differ.
 */
async function matmul(device, a, b) {
  if (a.shape.length !== 2 || b.shape.length !== 2) {
    throw new Error("matmul requires 2D tensors");
  }
  const [M, K1] = a.shape;
  const [K2, N] = b.shape;
  if (K1 !== K2) {
    throw new Error(`matmul shape mismatch: [${M},${K1}] @ [${K2},${N}] - inner dimensions must match`);
  }
  const K = K1;
  const c = Tensor.zeros(device, [M, N]);
  let paramsBuffer = null;
  try {
    const compiler = getCompiler(device);
    // Resolve the tile sizes once and hand them to the compiler explicitly so
    // the kernel's workgroup shape and the dispatch grid cannot disagree.
    const { tileM, tileN, tileK } = compiler.selectTileSize({ M, N, K });
    const pipeline = compiler.compileMatMul({ M, N, K, tileM, tileN, tileK });

    // 16 bytes: three u32 fields padded to the uniform struct size.
    paramsBuffer = device.device.createBuffer({
      size: 16,
      usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
    });
    device.device.queue.writeBuffer(paramsBuffer, 0, new Uint32Array([M, N, K]));

    const bindGroup = device.device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: [
        { binding: 0, resource: { buffer: a.buffer } },
        { binding: 1, resource: { buffer: b.buffer } },
        { binding: 2, resource: { buffer: c.buffer } },
        { binding: 3, resource: { buffer: paramsBuffer } }
      ]
    });

    const encoder = device.createCommandEncoder();
    const pass = encoder.beginComputePass();
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bindGroup);
    // x covers the columns of C, y covers the rows.
    pass.dispatchWorkgroups(Math.ceil(N / tileN), Math.ceil(M / tileM));
    pass.end();
    device.submit([encoder.finish()]);
    await device.device.queue.onSubmittedWorkDone();
    return c;
  } catch (err) {
    // The caller never receives the output on failure, so release it here.
    c.dispose();
    throw err;
  } finally {
    paramsBuffer?.destroy();
  }
}
|
|
472
|
+
/**
 * Reference matrix multiplication on the CPU over row-major flat arrays.
 *
 * @param {ArrayLike<number>} a - Left operand, M*K values.
 * @param {ArrayLike<number>} b - Right operand, K*N values.
 * @param {number} M - Rows of `a` and of the result.
 * @param {number} N - Columns of `b` and of the result.
 * @param {number} K - Shared inner dimension.
 * @returns {Float32Array} The M*N product in row-major order.
 */
function matmulCPU(a, b, M, N, K) {
  const product = new Float32Array(M * N);
  for (let row = 0; row < M; row++) {
    const aRowStart = row * K;
    const outRowStart = row * N;
    for (let col = 0; col < N; col++) {
      // Accumulate in double precision; rounding to f32 happens on store.
      let dot = 0;
      for (let inner = 0; inner < K; inner++) {
        dot += a[aRowStart + inner] * b[inner * N + col];
      }
      product[outRowStart + col] = dot;
    }
  }
  return product;
}
|
|
485
|
+
/**
 * Reports the kernel-cache statistics of the shared matmul compiler.
 *
 * Calling this creates the shared compiler for `device` if it does not
 * exist yet.
 *
 * @param {object} device - Device wrapper forwarded to `getCompiler`.
 * @returns {object} Whatever the compiler's `getCacheStats()` returns.
 */
function getMatMulCacheStats(device) {
  return getCompiler(device).getCacheStats();
}
|
|
489
|
+
// src/ops/normalization.ts
|
|
490
|
+
// Lazily created pipeline cache shared by the normalization kernels.
var kernelCache = null;

/**
 * Returns the shared normalization kernel cache for `device`.
 *
 * A new cache is created on first use and whenever a different GPU device is
 * passed, since pipelines compiled for one device are not valid on another.
 *
 * @param {GPUDevice} device - The raw GPU device (not the wrapper).
 * @returns {KernelCache}
 */
function getCache(device) {
  if (!kernelCache || kernelCache.device !== device) {
    kernelCache = new KernelCache(device);
  }
  return kernelCache;
}
|
|
497
|
+
/**
 * Reference layer normalization on the CPU.
 *
 * The flat input is treated as rows of `shape[shape.length - 1]` values; each
 * row is normalized to zero mean and unit (biased) variance, then scaled by
 * `weight` and shifted by `bias` when one is given.
 *
 * @param {ArrayLike<number>} x - Flat input values.
 * @param {ArrayLike<number>} weight - Per-feature scale, one per row element.
 * @param {ArrayLike<number>|null|undefined} bias - Optional per-feature shift.
 * @param {number[]} shape - Only the last dimension is read.
 * @param {number} [eps=1e-5] - Added to the variance before the square root.
 * @returns {Float32Array} Normalized values, same length as `x`.
 */
function layerNormCPU(x, weight, bias, shape, eps = 0.00001) {
  const width = shape[shape.length - 1];
  const rowCount = x.length / width;
  const result = new Float32Array(x.length);

  for (let r = 0; r < rowCount; r++) {
    const start = r * width;

    // Mean of the row.
    let total = 0;
    for (let j = 0; j < width; j++) {
      total += x[start + j];
    }
    const mean = total / width;

    // Biased variance of the row.
    let squaredDeviation = 0;
    for (let j = 0; j < width; j++) {
      const centered = x[start + j] - mean;
      squaredDeviation += centered * centered;
    }
    const invStd = 1 / Math.sqrt(squaredDeviation / width + eps);

    // Normalize, then apply the affine transform.
    for (let j = 0; j < width; j++) {
      const shift = bias ? bias[j] : 0;
      result[start + j] = (x[start + j] - mean) * invStd * weight[j] + shift;
    }
  }
  return result;
}
|
|
522
|
+
/**
 * Reference RMS normalization on the CPU.
 *
 * The flat input is treated as rows of `shape[shape.length - 1]` values; each
 * row is divided by sqrt(mean(x^2) + eps) and scaled by `weight`.
 *
 * @param {ArrayLike<number>} x - Flat input values.
 * @param {ArrayLike<number>} weight - Per-feature scale, one per row element.
 * @param {number[]} shape - Only the last dimension is read.
 * @param {number} [eps=1e-5] - Added to the mean square before the square root.
 * @returns {Float32Array} Normalized values, same length as `x`.
 */
function rmsNormCPU(x, weight, shape, eps = 0.00001) {
  const width = shape[shape.length - 1];
  const rowCount = x.length / width;
  const result = new Float32Array(x.length);

  for (let r = 0; r < rowCount; r++) {
    const start = r * width;

    let squareTotal = 0;
    for (let j = 0; j < width; j++) {
      const value = x[start + j];
      squareTotal += value * value;
    }
    const invRms = 1 / Math.sqrt(squareTotal / width + eps);

    for (let j = 0; j < width; j++) {
      result[start + j] = x[start + j] * invRms * weight[j];
    }
  }
  return result;
}
|
|
540
|
+
/**
 * Layer normalization over the last dimension of `x`, executed on the GPU.
 *
 * One workgroup is dispatched per row. Resolves only after the queue reports
 * the submitted work as done.
 *
 * @param {object} device - Wrapper exposing `.device`, `createCommandEncoder()` and `submit()`.
 * @param {Tensor} x - Input tensor; normalized along its last dimension.
 * @param {Tensor} weight - Per-feature scale.
 * @param {Tensor|null|undefined} bias - Optional per-feature shift. Both
 *   `null` and an omitted argument select the bias-free kernel.
 * @param {number} [eps=1e-5] - Added to the variance before the square root.
 * @returns {Promise<Tensor>} New tensor with the same shape as `x`.
 */
async function layerNorm(device, x, weight, bias, eps = 0.00001) {
  const lastDim = x.shape[x.shape.length - 1];
  const outerSize = x.numel / lastDim;
  // The kernel variant and the bind group must agree on whether binding 4
  // exists, so derive both from the same flag.
  const hasBias = Boolean(bias);
  const cache = getCache(device.device);
  const pipeline = cache.getOrCreate(`layernorm_${lastDim}_${hasBias}`, () => compileLayerNormKernel(device.device, lastDim, hasBias));
  const output = Tensor.zeros(device, [...x.shape]);
  let paramsBuffer = null;
  try {
    // Uniform block layout: outerSize, dim, eps, padding (all f32).
    const params = new Float32Array([outerSize, lastDim, eps, 0]);
    paramsBuffer = device.device.createBuffer({
      size: params.byteLength,
      usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
    });
    device.device.queue.writeBuffer(paramsBuffer, 0, params);

    const entries = [
      { binding: 0, resource: { buffer: x.buffer } },
      { binding: 1, resource: { buffer: weight.buffer } },
      { binding: 2, resource: { buffer: output.buffer } },
      { binding: 3, resource: { buffer: paramsBuffer } }
    ];
    if (hasBias) {
      entries.push({ binding: 4, resource: { buffer: bias.buffer } });
    }
    const bindGroup = device.device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries
    });

    const encoder = device.createCommandEncoder();
    const pass = encoder.beginComputePass();
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bindGroup);
    pass.dispatchWorkgroups(outerSize);
    pass.end();
    device.submit([encoder.finish()]);
    await device.device.queue.onSubmittedWorkDone();
    return output;
  } catch (err) {
    // The caller never receives the output on failure, so release it here.
    output.dispose();
    throw err;
  } finally {
    paramsBuffer?.destroy();
  }
}
|
|
576
|
+
/**
 * RMS normalization over the last dimension of `x`, executed on the GPU.
 *
 * One workgroup is dispatched per row. Resolves only after the queue reports
 * the submitted work as done; the temporary uniform buffer is destroyed
 * afterwards.
 *
 * @param {object} device - Wrapper exposing `.device`, `createCommandEncoder()` and `submit()`.
 * @param {Tensor} x - Input tensor; normalized along its last dimension.
 * @param {Tensor} weight - Per-feature scale.
 * @param {number} [eps=1e-5] - Added to the mean square before the square root.
 * @returns {Promise<Tensor>} New tensor with the same shape as `x`.
 */
async function rmsNorm(device, x, weight, eps = 0.00001) {
  const gpu = device.device;
  const featureCount = x.shape[x.shape.length - 1];
  const rowCount = x.numel / featureCount;

  const pipeline = getCache(gpu).getOrCreate(
    `rmsnorm_${featureCount}`,
    () => compileRMSNormKernel(device.device, featureCount)
  );
  const output = Tensor.zeros(device, [...x.shape]);

  // Uniform block layout: outerSize, dim, eps, padding (all f32).
  const uniformData = new Float32Array([rowCount, featureCount, eps, 0]);
  const paramsBuffer = gpu.createBuffer({
    size: uniformData.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
  });
  gpu.queue.writeBuffer(paramsBuffer, 0, uniformData);

  const boundBuffers = [x.buffer, weight.buffer, output.buffer, paramsBuffer];
  const bindGroup = gpu.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries: boundBuffers.map((buffer, binding) => ({ binding, resource: { buffer } }))
  });

  const encoder = device.createCommandEncoder();
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  pass.dispatchWorkgroups(rowCount);
  pass.end();
  device.submit([encoder.finish()]);

  await gpu.queue.onSubmittedWorkDone();
  paramsBuffer.destroy();
  return output;
}
|
|
608
|
+
/**
 * Builds the layer-normalization compute pipeline.
 *
 * The generated shader runs one 256-invocation workgroup per row and makes
 * three strided passes over the row (sum for the mean, squared deviations
 * for the variance, then normalize and apply weight/bias), reducing the
 * partial sums in workgroup memory.
 *
 * @param {GPUDevice} device - Device used to create the shader module and pipeline.
 * @param {number} dim - Not used by the shader source; the row length is read
 *   from the uniform block at run time.
 * @param {boolean} hasBias - Whether to declare binding 4 and add the bias.
 * @returns {GPUComputePipeline}
 */
function compileLayerNormKernel(device, dim, hasBias) {
  const threads = 256;
  const firstStride = threads / 2;
  // The two shader variants differ only in these two fragments.
  const biasDeclaration = hasBias ? "@group(0) @binding(4) var<storage, read> bias: array<f32>;" : "";
  const storeStatement = hasBias
    ? "output[offset + i] = normalized * weight[i] + bias[i];"
    : "output[offset + i] = normalized * weight[i];";

  const code = `
struct Params {
  outerSize: f32,
  dim: f32,
  eps: f32,
  _pad: f32,
}

@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read> weight: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
${biasDeclaration}

var<workgroup> shared_sum: array<f32, ${threads}>;
var<workgroup> shared_mean: f32;

@compute @workgroup_size(${threads})
fn main(
  @builtin(local_invocation_id) lid: vec3<u32>,
  @builtin(workgroup_id) wgid: vec3<u32>
) {
  let row = wgid.x;
  let tid = lid.x;
  let dim = u32(params.dim);
  let offset = row * dim;

  // === Pass 1: Compute mean ===
  var partial_sum: f32 = 0.0;
  for (var i = tid; i < dim; i += ${threads}u) {
    partial_sum += x[offset + i];
  }
  shared_sum[tid] = partial_sum;
  workgroupBarrier();

  // Parallel reduction for sum
  for (var stride = ${firstStride}u; stride > 0u; stride >>= 1u) {
    if (tid < stride) {
      shared_sum[tid] += shared_sum[tid + stride];
    }
    workgroupBarrier();
  }

  // Store mean for all threads to use
  if (tid == 0u) {
    shared_mean = shared_sum[0] / params.dim;
  }
  workgroupBarrier();
  let mean = shared_mean;

  // === Pass 2: Compute variance ===
  var partial_var: f32 = 0.0;
  for (var i = tid; i < dim; i += ${threads}u) {
    let diff = x[offset + i] - mean;
    partial_var += diff * diff;
  }
  shared_sum[tid] = partial_var;
  workgroupBarrier();

  // Parallel reduction for variance
  for (var stride = ${firstStride}u; stride > 0u; stride >>= 1u) {
    if (tid < stride) {
      shared_sum[tid] += shared_sum[tid + stride];
    }
    workgroupBarrier();
  }

  // Compute inverse standard deviation
  let inv_std = 1.0 / sqrt(shared_sum[0] / params.dim + params.eps);

  // === Pass 3: Normalize and apply affine transform ===
  for (var i = tid; i < dim; i += ${threads}u) {
    let normalized = (x[offset + i] - mean) * inv_std;
    ${storeStatement}
  }
}
`;

  return device.createComputePipeline({
    layout: "auto",
    compute: { module: device.createShaderModule({ code }), entryPoint: "main" }
  });
}
|
|
693
|
+
/**
 * Builds the RMS-normalization compute pipeline.
 *
 * The generated shader runs one 256-invocation workgroup per row: each
 * invocation accumulates a strided partial sum of squares, the partials are
 * reduced in workgroup memory, and every invocation then scales its share of
 * the row by the inverse RMS and the weight.
 *
 * @param {GPUDevice} device - Device used to create the shader module and pipeline.
 * @param {number} dim - Not used by the shader source; the row length is read
 *   from the uniform block at run time.
 * @returns {GPUComputePipeline}
 */
function compileRMSNormKernel(device, dim) {
  const threads = 256;
  const firstStride = threads / 2;

  const code = `
struct Params {
  outerSize: f32,
  dim: f32,
  eps: f32,
  _pad: f32,
}

@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read> weight: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

var<workgroup> shared_sum: array<f32, ${threads}>;

@compute @workgroup_size(${threads})
fn main(
  @builtin(local_invocation_id) lid: vec3<u32>,
  @builtin(workgroup_id) wgid: vec3<u32>
) {
  let row = wgid.x;
  let tid = lid.x;
  let dim = u32(params.dim);
  let offset = row * dim;

  // Each thread computes partial sum of squares
  var partial_sum: f32 = 0.0;
  for (var i = tid; i < dim; i += ${threads}u) {
    let val = x[offset + i];
    partial_sum += val * val;
  }
  shared_sum[tid] = partial_sum;
  workgroupBarrier();

  // Parallel reduction in shared memory
  for (var stride = ${firstStride}u; stride > 0u; stride >>= 1u) {
    if (tid < stride) {
      shared_sum[tid] += shared_sum[tid + stride];
    }
    workgroupBarrier();
  }

  // Compute inverse RMS (thread 0 has the final sum)
  let inv_rms = 1.0 / sqrt(shared_sum[0] / params.dim + params.eps);

  // All threads normalize their portion
  for (var i = tid; i < dim; i += ${threads}u) {
    output[offset + i] = x[offset + i] * inv_rms * weight[i];
  }
}
`;

  return device.createComputePipeline({
    layout: "auto",
    compute: { module: device.createShaderModule({ code }), entryPoint: "main" }
  });
}
|
|
752
|
+
// src/ops/rope.ts
|
|
753
|
+
// Lazily created pipeline cache shared by the RoPE kernels.
var kernelCache2 = null;

/**
 * Returns the shared RoPE kernel cache for `device`.
 *
 * A new cache is created on first use and whenever a different GPU device is
 * passed, since pipelines compiled for one device are not valid on another.
 *
 * @param {GPUDevice} device - The raw GPU device (not the wrapper).
 * @returns {KernelCache}
 */
function getCache2(device) {
  if (!kernelCache2 || kernelCache2.device !== device) {
    kernelCache2 = new KernelCache(device);
  }
  return kernelCache2;
}
|
|
760
|
+
/**
 * Precomputes the cosine and sine tables used by rotary position embeddings.
 *
 * For every position `pos` in [0, maxSeqLen) and every frequency index `i`
 * in [0, dim / 2), the angle is `(pos / scaling) * base^(-2i / dim)`. Both
 * tables are laid out row-major as [maxSeqLen, dim / 2].
 *
 * @param {{dim: number, maxSeqLen: number, base?: number, scaling?: number}} config
 *   `base` defaults to 10000 and `scaling` to 1.
 * @returns {{cos: Float32Array, sin: Float32Array}}
 */
function computeRoPEFrequencies(config) {
  const { dim, maxSeqLen, base = 1e4, scaling = 1 } = config;
  const halfDim = dim / 2;

  // Inverse frequencies are rounded to f32 before the angles are formed.
  const invFreq = new Float32Array(halfDim);
  let f = 0;
  while (f < halfDim) {
    invFreq[f] = 1 / Math.pow(base, 2 * f / dim);
    f += 1;
  }

  const tableLength = maxSeqLen * halfDim;
  const cos = new Float32Array(tableLength);
  const sin = new Float32Array(tableLength);
  for (let pos = 0; pos < maxSeqLen; pos++) {
    const scaledPos = pos / scaling;
    const rowStart = pos * halfDim;
    for (let i = 0; i < halfDim; i++) {
      const angle = scaledPos * invFreq[i];
      cos[rowStart + i] = Math.cos(angle);
      sin[rowStart + i] = Math.sin(angle);
    }
  }
  return { cos, sin };
}
|
|
779
|
+
/**
 * CPU implementation of rotary position embedding.
 *
 * `x` is read as a flat [seqLen, numHeads, headDim] array. Within each head,
 * element `d` is paired with element `d + headDim / 2`, and the pair is
 * rotated by the angle whose cosine/sine sit at `positions[s] * (headDim / 2) + d`
 * in the `cos` / `sin` tables.
 *
 * @param x - Flat input values.
 * @param positions - Table row to use for each sequence index.
 * @param cos - Cosine table with `headDim / 2` entries per position.
 * @param sin - Sine table with the same layout as `cos`.
 * @param seqLen - Number of sequence entries to process.
 * @param numHeads - Heads per sequence entry.
 * @param headDim - Elements per head.
 * @returns A new Float32Array of `x.length` elements holding the rotated values.
 */
function ropeCPU(x, positions, cos, sin, seqLen, numHeads, headDim) {
  const pairCount = headDim / 2;
  const rotated = new Float32Array(x.length);
  for (let token = 0; token < seqLen; token++) {
    const tableRow = positions[token] * pairCount;
    for (let head = 0; head < numHeads; head++) {
      const lowerStart = token * numHeads * headDim + head * headDim;
      const upperStart = lowerStart + pairCount;
      for (let k = 0; k < pairCount; k++) {
        const lower = x[lowerStart + k];
        const upper = x[upperStart + k];
        const cosine = cos[tableRow + k];
        const sine = sin[tableRow + k];
        rotated[lowerStart + k] = lower * cosine - upper * sine;
        rotated[upperStart + k] = lower * sine + upper * cosine;
      }
    }
  }
  return rotated;
}
|
|
800
|
+
/**
 * Applies rotary position embedding to `x` on the GPU.
 *
 * The cos/sin tables are rebuilt from `config` on every call, uploaded to
 * temporary storage buffers, and the kernel is dispatched as
 * `ceil(seqLen / 64)` x `numHeads` workgroups. The call resolves once the
 * queue reports the submitted work as done.
 *
 * @param device - Wrapper exposing the raw GPU device as `device.device`, plus
 *   `createCommandEncoder()` and `submit()`.
 * @param x - Input tensor of shape [seqLen, numHeads, headDim].
 * @param positions - Tensor whose buffer the kernel reads as `array<u32>`,
 *   one position per sequence index.
 * @param config - Options for `computeRoPEFrequencies`; `config.dim` must
 *   equal `headDim`.
 * @returns A new tensor with the same shape as `x`.
 * @throws {Error} If `x` is not 3D, `headDim` is odd, or `config.dim` differs
 *   from `headDim`.
 */
async function rope(device, x, positions, config) {
  if (x.shape.length !== 3) {
    throw new Error("RoPE input must be 3D [seqLen, numHeads, headDim]");
  }
  const [seqLen, numHeads, headDim] = x.shape;
  if (headDim % 2 !== 0) {
    throw new Error(`RoPE headDim must be even, got ${headDim}`);
  }
  // The kernel strides the tables by headDim / 2 while the tables are built
  // with config.dim / 2 entries per row; a mismatch reads the wrong angles.
  if (config.dim !== headDim) {
    throw new Error(`RoPE config.dim (${config.dim}) must equal headDim (${headDim})`);
  }
  const { cos, sin } = computeRoPEFrequencies(config);
  const gpu = device.device;
  const cache = getCache2(gpu);
  const pipeline = cache.getOrCreate(`rope_${headDim}_${numHeads}`, () => compileRoPEKernel(gpu, headDim, numHeads));
  const output = Tensor.zeros(device, [seqLen, numHeads, headDim]);
  const params = new Uint32Array([seqLen, numHeads, headDim, headDim / 2]);

  // Every temporary buffer is tracked so it is released on failure too.
  const temporaries = [];
  const upload = (data, usage) => {
    const buffer = gpu.createBuffer({
      size: data.byteLength,
      usage: usage | GPUBufferUsage.COPY_DST
    });
    temporaries.push(buffer);
    gpu.queue.writeBuffer(buffer, 0, data);
    return buffer;
  };

  try {
    const cosBuffer = upload(cos, GPUBufferUsage.STORAGE);
    const sinBuffer = upload(sin, GPUBufferUsage.STORAGE);
    const paramsBuffer = upload(params, GPUBufferUsage.UNIFORM);
    const bindGroup = gpu.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: [
        { binding: 0, resource: { buffer: x.buffer } },
        { binding: 1, resource: { buffer: positions.buffer } },
        { binding: 2, resource: { buffer: cosBuffer } },
        { binding: 3, resource: { buffer: sinBuffer } },
        { binding: 4, resource: { buffer: output.buffer } },
        { binding: 5, resource: { buffer: paramsBuffer } }
      ]
    });
    const encoder = device.createCommandEncoder();
    const pass = encoder.beginComputePass();
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bindGroup);
    pass.dispatchWorkgroups(Math.ceil(seqLen / 64), numHeads);
    pass.end();
    device.submit([encoder.finish()]);
    await gpu.queue.onSubmittedWorkDone();
  } finally {
    for (const buffer of temporaries) {
      buffer.destroy();
    }
  }
  return output;
}
|
|
849
|
+
/**
 * Builds the compute pipeline for the RoPE kernel.
 *
 * The shader text is fixed: sequence length, head count and head size are
 * read from the `Params` uniform at run time, so `headDim` and `numHeads`
 * do not influence the generated code and are accepted only to keep the
 * signature stable. Each invocation handles one (sequence index, head) pair;
 * the sequence index comes from `global_invocation_id.x` and the head from
 * `workgroup_id.y`.
 *
 * @param device - Object providing `createShaderModule` and
 *   `createComputePipeline`.
 * @param headDim - Unused.
 * @param numHeads - Unused.
 * @returns The pipeline returned by `device.createComputePipeline`.
 */
function compileRoPEKernel(device, headDim, numHeads) {
  const wgsl = `
    struct Params {
      seqLen: u32,
      numHeads: u32,
      headDim: u32,
      halfDim: u32,
    }

    @group(0) @binding(0) var<storage, read> x: array<f32>;
    @group(0) @binding(1) var<storage, read> positions: array<u32>;
    @group(0) @binding(2) var<storage, read> cos: array<f32>;
    @group(0) @binding(3) var<storage, read> sin: array<f32>;
    @group(0) @binding(4) var<storage, read_write> output: array<f32>;
    @group(0) @binding(5) var<uniform> params: Params;

    @compute @workgroup_size(64)
    fn main(
      @builtin(global_invocation_id) gid: vec3<u32>,
      @builtin(workgroup_id) wgid: vec3<u32>
    ) {
      let seqIdx = gid.x;
      let headIdx = wgid.y;

      if (seqIdx >= params.seqLen) {
        return;
      }

      let pos = positions[seqIdx];
      let halfDim = params.halfDim;
      let headDim = params.headDim;
      let numHeads = params.numHeads;

      let baseIdx = seqIdx * numHeads * headDim + headIdx * headDim;
      let freqOffset = pos * halfDim;

      // Apply rotation to pairs
      for (var d = 0u; d < halfDim; d = d + 1u) {
        let x0 = x[baseIdx + d];
        let x1 = x[baseIdx + halfDim + d];
        let c = cos[freqOffset + d];
        let s = sin[freqOffset + d];

        output[baseIdx + d] = x0 * c - x1 * s;
        output[baseIdx + halfDim + d] = x0 * s + x1 * c;
      }
    }
  `;
  const shaderModule = device.createShaderModule({ code: wgsl });
  return device.createComputePipeline({
    layout: "auto",
    compute: {
      module: shaderModule,
      entryPoint: "main"
    }
  });
}
|
|
907
|
+
// src/ops/activations.ts
|
|
908
|
+
// Lazily created kernel cache shared by the activation ops in this module.
var kernelCache3 = null;
/**
 * Returns the module-wide activation kernel cache, building it on first use.
 *
 * The cache is constructed from the `device` of the first call; later calls
 * return that same instance whatever argument they pass.
 *
 * @param device - Forwarded to the KernelCache constructor on the first call.
 * @returns The shared KernelCache instance.
 */
function getCache3(device) {
  if (kernelCache3) {
    return kernelCache3;
  }
  kernelCache3 = new KernelCache(device);
  return kernelCache3;
}
|
|
915
|
+
/**
 * CPU GELU using the tanh approximation:
 * `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`.
 *
 * @param x - Input values.
 * @returns A new Float32Array with one activated value per input element.
 */
function geluCPU(x) {
  const scale = Math.sqrt(2 / Math.PI);
  return Float32Array.from(x, (value) => {
    const inner = scale * (value + 0.044715 * value * value * value);
    return value * 0.5 * (1 + Math.tanh(inner));
  });
}
|
|
925
|
+
/**
 * CPU GELU in its erf form: `0.5 * x * (1 + erf(x / sqrt(2)))`, evaluated
 * with this module's `erf` helper.
 *
 * @param x - Input values.
 * @returns A new Float32Array with one activated value per input element.
 */
function geluExactCPU(x) {
  const rootTwo = Math.sqrt(2);
  return Float32Array.from(x, (value) => value * 0.5 * (1 + erf(value / rootTwo)));
}
|
|
934
|
+
/**
 * Approximates the error function with a five-term polynomial in
 * `t = 1 / (1 + p * |x|)` multiplied by `exp(-x^2)`.
 *
 * The result is computed for `|x|` and given the sign of `x`, so the
 * function is odd by construction.
 *
 * @param x - Real input.
 * @returns The approximate value of erf(x).
 */
function erf(x) {
  const p = 0.3275911;
  // Coefficients a4..a1; a5 seeds the Horner evaluation below.
  const lowerCoefficients = [-1.453152027, 1.421413741, -0.284496736, 0.254829592];
  const leadingCoefficient = 1.061405429;
  const magnitude = Math.abs(x);
  const t = 1 / (1 + p * magnitude);
  const polynomial = lowerCoefficients.reduce((acc, coefficient) => acc * t + coefficient, leadingCoefficient);
  const y = 1 - polynomial * t * Math.exp(-magnitude * magnitude);
  return x < 0 ? -1 * y : 1 * y;
}
|
|
947
|
+
/**
 * CPU SiLU: `x / (1 + exp(-x))`, element by element.
 *
 * @param x - Input values.
 * @returns A new Float32Array with one activated value per input element.
 */
function siluCPU(x) {
  return Float32Array.from(x, (value) => value / (1 + Math.exp(-value)));
}
|
|
955
|
+
/**
 * CPU ReLU: `max(0, x)`, element by element.
 *
 * @param x - Input values.
 * @returns A new Float32Array with one clamped value per input element.
 */
function reluCPU(x) {
  // Math.max is kept (rather than a comparison) so NaN inputs stay NaN.
  return Float32Array.from(x, (value) => Math.max(0, value));
}
|
|
962
|
+
/**
 * CPU logistic sigmoid: `1 / (1 + exp(-x))`, element by element.
 *
 * @param x - Input values.
 * @returns A new Float32Array with one value per input element.
 */
function sigmoidCPU(x) {
  return Float32Array.from(x, (value) => 1 / (1 + Math.exp(-value)));
}
|
|
969
|
+
/**
 * Shared driver for the element-wise activation kernels.
 *
 * Fetches (or compiles) the pipeline stored under `kernelName`, allocates an
 * output tensor shaped like `x`, uploads the element count as a one-word
 * uniform, dispatches `ceil(x.numel / 256)` workgroups, and waits for the
 * queue to finish. The uniform buffer is destroyed even when a step throws.
 *
 * @param device - Wrapper exposing the raw GPU device as `device.device`, plus
 *   `createCommandEncoder()` and `submit()`.
 * @param x - Input tensor (`shape`, `numel`, `buffer` are read).
 * @param kernelName - Cache key for the pipeline.
 * @param compileKernel - Called with the raw GPU device on a cache miss.
 * @returns A new tensor with the same shape as `x`.
 */
async function dispatchActivationKernel(device, x, kernelName, compileKernel) {
  const gpu = device.device;
  const cache = getCache3(gpu);
  const pipeline = cache.getOrCreate(kernelName, () => compileKernel(gpu));
  const output = Tensor.zeros(device, [...x.shape]);
  const params = new Uint32Array([x.numel]);
  const paramsBuffer = gpu.createBuffer({
    size: params.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
  });
  try {
    gpu.queue.writeBuffer(paramsBuffer, 0, params);
    const bindGroup = gpu.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: [
        { binding: 0, resource: { buffer: x.buffer } },
        { binding: 1, resource: { buffer: output.buffer } },
        { binding: 2, resource: { buffer: paramsBuffer } }
      ]
    });
    const encoder = device.createCommandEncoder();
    const pass = encoder.beginComputePass();
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bindGroup);
    pass.dispatchWorkgroups(Math.ceil(x.numel / 256));
    pass.end();
    device.submit([encoder.finish()]);
    await gpu.queue.onSubmittedWorkDone();
  } finally {
    paramsBuffer.destroy();
  }
  return output;
}
/**
 * GPU GELU (tanh approximation) applied element-wise.
 *
 * @param device - GPU device wrapper.
 * @param x - Input tensor.
 * @returns A new tensor with the same shape as `x`.
 */
async function gelu(device, x) {
  return dispatchActivationKernel(device, x, "gelu", compileGeluKernel);
}
/**
 * GPU SiLU applied element-wise.
 *
 * @param device - GPU device wrapper.
 * @param x - Input tensor.
 * @returns A new tensor with the same shape as `x`.
 */
async function silu(device, x) {
  return dispatchActivationKernel(device, x, "silu", compileSiluKernel);
}
/**
 * GPU ReLU applied element-wise.
 *
 * @param device - GPU device wrapper.
 * @param x - Input tensor.
 * @returns A new tensor with the same shape as `x`.
 */
async function relu(device, x) {
  return dispatchActivationKernel(device, x, "relu", compileReluKernel);
}
|
|
1056
|
+
/**
 * Builds the compute pipeline for the tanh-approximation GELU kernel.
 *
 * The shader runs 256 invocations per workgroup; each one handles the element
 * at `global_invocation_id.x` and exits early when that index is at or beyond
 * `params.size`.
 *
 * @param device - Object providing `createShaderModule` and
 *   `createComputePipeline`.
 * @returns The pipeline returned by `device.createComputePipeline`.
 */
function compileGeluKernel(device) {
  const shaderSource = `
    struct Params {
      size: u32,
    }

    @group(0) @binding(0) var<storage, read> x: array<f32>;
    @group(0) @binding(1) var<storage, read_write> output: array<f32>;
    @group(0) @binding(2) var<uniform> params: Params;

    const SQRT_2_OVER_PI: f32 = 0.7978845608;
    const COEFF: f32 = 0.044715;

    @compute @workgroup_size(256)
    fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
      let idx = gid.x;
      if (idx >= params.size) {
        return;
      }

      let xi = x[idx];
      let inner = SQRT_2_OVER_PI * (xi + COEFF * xi * xi * xi);
      output[idx] = xi * 0.5 * (1.0 + tanh(inner));
    }
  `;
  return device.createComputePipeline({
    layout: "auto",
    compute: {
      module: device.createShaderModule({ code: shaderSource }),
      entryPoint: "main"
    }
  });
}
|
|
1087
|
+
/**
 * Builds the compute pipeline for the SiLU kernel.
 *
 * The shader runs 256 invocations per workgroup; each one handles the element
 * at `global_invocation_id.x` and exits early when that index is at or beyond
 * `params.size`.
 *
 * @param device - Object providing `createShaderModule` and
 *   `createComputePipeline`.
 * @returns The pipeline returned by `device.createComputePipeline`.
 */
function compileSiluKernel(device) {
  const shaderSource = `
    struct Params {
      size: u32,
    }

    @group(0) @binding(0) var<storage, read> x: array<f32>;
    @group(0) @binding(1) var<storage, read_write> output: array<f32>;
    @group(0) @binding(2) var<uniform> params: Params;

    @compute @workgroup_size(256)
    fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
      let idx = gid.x;
      if (idx >= params.size) {
        return;
      }

      let xi = x[idx];
      // SiLU: x * sigmoid(x) = x / (1 + exp(-x))
      output[idx] = xi / (1.0 + exp(-xi));
    }
  `;
  return device.createComputePipeline({
    layout: "auto",
    compute: {
      module: device.createShaderModule({ code: shaderSource }),
      entryPoint: "main"
    }
  });
}
|
|
1115
|
+
/**
 * Builds the compute pipeline for the ReLU kernel.
 *
 * The shader runs 256 invocations per workgroup; each one handles the element
 * at `global_invocation_id.x` and exits early when that index is at or beyond
 * `params.size`.
 *
 * @param device - Object providing `createShaderModule` and
 *   `createComputePipeline`.
 * @returns The pipeline returned by `device.createComputePipeline`.
 */
function compileReluKernel(device) {
  const shaderSource = `
    struct Params {
      size: u32,
    }

    @group(0) @binding(0) var<storage, read> x: array<f32>;
    @group(0) @binding(1) var<storage, read_write> output: array<f32>;
    @group(0) @binding(2) var<uniform> params: Params;

    @compute @workgroup_size(256)
    fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
      let idx = gid.x;
      if (idx >= params.size) {
        return;
      }

      output[idx] = max(0.0, x[idx]);
    }
  `;
  return device.createComputePipeline({
    layout: "auto",
    compute: {
      module: device.createShaderModule({ code: shaderSource }),
      entryPoint: "main"
    }
  });
}
|
|
1141
|
+
// src/ops/softmax.ts
|
|
1142
|
+
// Lazily created kernel cache shared by the softmax ops in this module.
var kernelCache4 = null;
/**
 * Returns the module-wide softmax kernel cache, building it on first use.
 *
 * The cache is constructed from the `device` of the first call; later calls
 * return that same instance whatever argument they pass.
 *
 * @param device - Forwarded to the KernelCache constructor on the first call.
 * @returns The shared KernelCache instance.
 */
function getCache4(device) {
  if (kernelCache4) {
    return kernelCache4;
  }
  kernelCache4 = new KernelCache(device);
  return kernelCache4;
}
|
|
1149
|
+
/**
 * CPU softmax over the last dimension.
 *
 * `x` is treated as consecutive rows of `shape[shape.length - 1]` elements.
 * Each row is shifted by its maximum before exponentiation, then divided by
 * the sum of the exponentials.
 *
 * @param x - Flat input values.
 * @param shape - Tensor shape; only its last entry is read.
 * @returns A new Float32Array of `x.length` elements.
 */
function softmaxCPU(x, shape) {
  const rowLength = shape[shape.length - 1];
  const rowCount = x.length / rowLength;
  const probabilities = new Float32Array(x.length);
  for (let row = 0; row < rowCount; row++) {
    const start = row * rowLength;
    const end = start + rowLength;

    let rowMax = -Infinity;
    for (let k = start; k < end; k++) {
      rowMax = Math.max(rowMax, x[k]);
    }

    // The sum accumulates the unrounded exponentials while the output holds
    // their f32-rounded copies.
    let total = 0;
    for (let k = start; k < end; k++) {
      const exponential = Math.exp(x[k] - rowMax);
      probabilities[k] = exponential;
      total += exponential;
    }

    for (let k = start; k < end; k++) {
      probabilities[k] = probabilities[k] / total;
    }
  }
  return probabilities;
}
|
|
1171
|
+
/**
 * CPU log-softmax over the last dimension.
 *
 * `x` is treated as consecutive rows of `shape[shape.length - 1]` elements.
 * For each row the log-sum-exp is computed as `max + log(sum(exp(x - max)))`
 * and subtracted from every element.
 *
 * @param x - Flat input values.
 * @param shape - Tensor shape; only its last entry is read.
 * @returns A new Float32Array of `x.length` elements.
 */
function logSoftmaxCPU(x, shape) {
  const rowLength = shape[shape.length - 1];
  const rowCount = x.length / rowLength;
  const logProbabilities = new Float32Array(x.length);
  for (let row = 0; row < rowCount; row++) {
    const start = row * rowLength;
    const end = start + rowLength;

    let rowMax = -Infinity;
    for (let k = start; k < end; k++) {
      rowMax = Math.max(rowMax, x[k]);
    }

    let total = 0;
    for (let k = start; k < end; k++) {
      total += Math.exp(x[k] - rowMax);
    }

    const logSumExp = rowMax + Math.log(total);
    for (let k = start; k < end; k++) {
      logProbabilities[k] = x[k] - logSumExp;
    }
  }
  return logProbabilities;
}
|
|
1192
|
+
/**
 * GPU softmax over the last dimension of `x`.
 *
 * One workgroup is dispatched per row (`x.numel / lastDim` rows) and the call
 * resolves once the queue reports the submitted work as done.
 *
 * @param device - Wrapper exposing the raw GPU device as `device.device`, plus
 *   `createCommandEncoder()` and `submit()`.
 * @param x - Input tensor (`shape`, `numel`, `buffer` are read).
 * @returns A new tensor with the same shape as `x`.
 */
async function softmaxGPU(device, x) {
  const lastDim = x.shape[x.shape.length - 1];
  const outerSize = x.numel / lastDim;
  const gpu = device.device;
  const cache = getCache4(gpu);
  // The shader reads the row length from its uniform, so one pipeline serves
  // every size. Keying by lastDim would compile and retain an identical
  // pipeline for each new row length.
  const pipeline = cache.getOrCreate("softmax", () => compileSoftmaxKernel(gpu, lastDim));
  const output = Tensor.zeros(device, [...x.shape]);
  const params = new Uint32Array([outerSize, lastDim]);
  const paramsBuffer = gpu.createBuffer({
    size: params.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
  });
  try {
    gpu.queue.writeBuffer(paramsBuffer, 0, params);
    const bindGroup = gpu.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: [
        { binding: 0, resource: { buffer: x.buffer } },
        { binding: 1, resource: { buffer: output.buffer } },
        { binding: 2, resource: { buffer: paramsBuffer } }
      ]
    });
    const encoder = device.createCommandEncoder();
    const pass = encoder.beginComputePass();
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bindGroup);
    pass.dispatchWorkgroups(outerSize);
    pass.end();
    device.submit([encoder.finish()]);
    await gpu.queue.onSubmittedWorkDone();
  } finally {
    // Released on failure as well as success.
    paramsBuffer.destroy();
  }
  return output;
}
|
|
1223
|
+
/**
 * Builds the compute pipeline for the row-wise softmax kernel.
 *
 * The shader uses a workgroup size of 1: the single invocation for row
 * `global_invocation_id.x` finds the row maximum, writes the shifted
 * exponentials while summing them, then scales the row by the reciprocal of
 * that sum. The row length comes from `params.dim` at run time, so the `dim`
 * argument does not affect the generated code.
 *
 * @param device - Object providing `createShaderModule` and
 *   `createComputePipeline`.
 * @param dim - Unused.
 * @returns The pipeline returned by `device.createComputePipeline`.
 */
function compileSoftmaxKernel(device, dim) {
  const shaderSource = `
    struct Params {
      outerSize: u32,
      dim: u32,
    }

    @group(0) @binding(0) var<storage, read> x: array<f32>;
    @group(0) @binding(1) var<storage, read_write> output: array<f32>;
    @group(0) @binding(2) var<uniform> params: Params;

    @compute @workgroup_size(1)
    fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
      let idx = gid.x;
      if (idx >= params.outerSize) {
        return;
      }

      let dim = params.dim;
      let offset = idx * dim;

      // Find max
      var maxVal: f32 = x[offset];
      for (var j = 1u; j < dim; j = j + 1u) {
        maxVal = max(maxVal, x[offset + j]);
      }

      // Compute exp and sum
      var sumExp: f32 = 0.0;
      for (var j = 0u; j < dim; j = j + 1u) {
        let expVal = exp(x[offset + j] - maxVal);
        output[offset + j] = expVal;
        sumExp = sumExp + expVal;
      }

      // Normalize
      let invSum = 1.0 / sumExp;
      for (var j = 0u; j < dim; j = j + 1u) {
        output[offset + j] = output[offset + j] * invSum;
      }
    }
  `;
  return device.createComputePipeline({
    layout: "auto",
    compute: {
      module: device.createShaderModule({ code: shaderSource }),
      entryPoint: "main"
    }
  });
}
|
|
1271
|
+
// src/ops/elementwise.ts
|
|
1272
|
+
// Lazily created kernel cache shared by the element-wise ops in this module.
var kernelCache5 = null;
/**
 * Returns the module-wide element-wise kernel cache, building it on first use.
 *
 * The cache is constructed from the `device` of the first call; later calls
 * return that same instance whatever argument they pass.
 *
 * @param device - Forwarded to the KernelCache constructor on the first call.
 * @returns The shared KernelCache instance.
 */
function getCache5(device) {
  if (kernelCache5) {
    return kernelCache5;
  }
  kernelCache5 = new KernelCache(device);
  return kernelCache5;
}
|
|
1279
|
+
/**
 * CPU element-wise addition of two equally long arrays.
 *
 * @param a - Left operand.
 * @param b - Right operand.
 * @returns A new Float32Array holding `a[i] + b[i]`.
 * @throws {Error} If the two lengths differ.
 */
function addCPU(a, b) {
  const count = a.length;
  if (count !== b.length) {
    throw new Error(`Shape mismatch: ${a.length} vs ${b.length}`);
  }
  const sums = new Float32Array(count);
  let i = 0;
  while (i < count) {
    sums[i] = a[i] + b[i];
    i += 1;
  }
  return sums;
}
|
|
1289
|
+
/**
 * CPU element-wise multiplication of two equally long arrays.
 *
 * @param a - Left operand.
 * @param b - Right operand.
 * @returns A new Float32Array holding `a[i] * b[i]`.
 * @throws {Error} If the two lengths differ.
 */
function mulCPU(a, b) {
  const count = a.length;
  if (count !== b.length) {
    throw new Error(`Shape mismatch: ${a.length} vs ${b.length}`);
  }
  const products = new Float32Array(count);
  let i = 0;
  while (i < count) {
    products[i] = a[i] * b[i];
    i += 1;
  }
  return products;
}
|
|
1299
|
+
/**
 * CPU multiplication of every element by a scalar.
 *
 * @param a - Input values.
 * @param scalar - Factor applied to each element.
 * @returns A new Float32Array holding `a[i] * scalar`.
 */
function scaleCPU(a, scalar) {
  return Float32Array.from(a, (value) => value * scalar);
}
|
|
1306
|
+
/**
 * CPU addition of a scalar to every element.
 *
 * @param a - Input values.
 * @param scalar - Offset added to each element.
 * @returns A new Float32Array holding `a[i] + scalar`.
 */
function addScalarCPU(a, scalar) {
  return Float32Array.from(a, (value) => value + scalar);
}
|
|
1313
|
+
/**
 * CPU element-wise multiply-add: `a[i] * b[i] + c[i]`.
 *
 * @param a - Multiplicand.
 * @param b - Multiplier.
 * @param c - Addend.
 * @returns A new Float32Array with one result per element.
 * @throws {Error} If the three lengths are not all equal.
 */
function fmaCPU(a, b, c) {
  const count = a.length;
  const lengthsAgree = count === b.length && count === c.length;
  if (!lengthsAgree) {
    throw new Error("Shape mismatch");
  }
  const result = new Float32Array(count);
  let i = 0;
  while (i < count) {
    result[i] = a[i] * b[i] + c[i];
    i += 1;
  }
  return result;
}
|
|
1323
|
+
/**
 * Shared driver for the two-input element-wise kernels.
 *
 * Checks that both tensors hold the same number of elements (only `numel` is
 * compared, not the shapes), fetches or compiles the pipeline stored under
 * `kernelName`, dispatches `ceil(a.numel / 256)` workgroups, and waits for
 * the queue to finish. The uniform buffer is destroyed even when a step
 * throws.
 *
 * @param device - Wrapper exposing the raw GPU device as `device.device`, plus
 *   `createCommandEncoder()` and `submit()`.
 * @param a - Left operand tensor; the output takes its shape.
 * @param b - Right operand tensor.
 * @param kernelName - Cache key for the pipeline.
 * @param compileKernel - Called with the raw GPU device on a cache miss.
 * @returns A new tensor shaped like `a`.
 * @throws {Error} If `a.numel` and `b.numel` differ.
 */
async function dispatchBinaryKernel(device, a, b, kernelName, compileKernel) {
  if (a.numel !== b.numel) {
    throw new Error(`Shape mismatch: ${a.shape} vs ${b.shape}`);
  }
  const gpu = device.device;
  const cache = getCache5(gpu);
  const pipeline = cache.getOrCreate(kernelName, () => compileKernel(gpu));
  const output = Tensor.zeros(device, [...a.shape]);
  const params = new Uint32Array([a.numel]);
  const paramsBuffer = gpu.createBuffer({
    size: params.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
  });
  try {
    gpu.queue.writeBuffer(paramsBuffer, 0, params);
    const bindGroup = gpu.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: [
        { binding: 0, resource: { buffer: a.buffer } },
        { binding: 1, resource: { buffer: b.buffer } },
        { binding: 2, resource: { buffer: output.buffer } },
        { binding: 3, resource: { buffer: paramsBuffer } }
      ]
    });
    const encoder = device.createCommandEncoder();
    const pass = encoder.beginComputePass();
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bindGroup);
    pass.dispatchWorkgroups(Math.ceil(a.numel / 256));
    pass.end();
    device.submit([encoder.finish()]);
    await gpu.queue.onSubmittedWorkDone();
  } finally {
    paramsBuffer.destroy();
  }
  return output;
}
/**
 * GPU element-wise addition.
 *
 * @param device - GPU device wrapper.
 * @param a - Left operand tensor.
 * @param b - Right operand tensor with the same element count.
 * @returns A new tensor shaped like `a`.
 * @throws {Error} If the element counts differ.
 */
async function add(device, a, b) {
  return dispatchBinaryKernel(device, a, b, "add", compileAddKernel);
}
/**
 * GPU element-wise multiplication.
 *
 * @param device - GPU device wrapper.
 * @param a - Left operand tensor.
 * @param b - Right operand tensor with the same element count.
 * @returns A new tensor shaped like `a`.
 * @throws {Error} If the element counts differ.
 */
async function mul(device, a, b) {
  return dispatchBinaryKernel(device, a, b, "mul", compileMulKernel);
}
|
|
1389
|
+
/**
 * GPU multiplication of every element of `a` by `scalar`.
 *
 * The uniform block is 8 bytes: the element count as a u32 at offset 0 and
 * the scalar as an f32 at offset 4, matching the kernel's `Params` struct.
 * `ceil(a.numel / 256)` workgroups are dispatched and the call resolves once
 * the queue reports the submitted work as done.
 *
 * @param device - Wrapper exposing the raw GPU device as `device.device`, plus
 *   `createCommandEncoder()` and `submit()`.
 * @param a - Input tensor (`shape`, `numel`, `buffer` are read).
 * @param scalar - Factor applied to each element (stored as f32).
 * @returns A new tensor with the same shape as `a`.
 */
async function scale(device, a, scalar) {
  const gpu = device.device;
  const cache = getCache5(gpu);
  const pipeline = cache.getOrCreate("scale", () => compileScaleKernel(gpu));
  const output = Tensor.zeros(device, [...a.shape]);
  // Pack both fields into one block so a single upload covers the uniform.
  const params = new ArrayBuffer(8);
  new Uint32Array(params, 0, 1)[0] = a.numel;
  new Float32Array(params, 4, 1)[0] = scalar;
  const paramsBuffer = gpu.createBuffer({
    size: params.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
  });
  try {
    gpu.queue.writeBuffer(paramsBuffer, 0, params);
    const bindGroup = gpu.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: [
        { binding: 0, resource: { buffer: a.buffer } },
        { binding: 1, resource: { buffer: output.buffer } },
        { binding: 2, resource: { buffer: paramsBuffer } }
      ]
    });
    const encoder = device.createCommandEncoder();
    const pass = encoder.beginComputePass();
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bindGroup);
    pass.dispatchWorkgroups(Math.ceil(a.numel / 256));
    pass.end();
    device.submit([encoder.finish()]);
    await gpu.queue.onSubmittedWorkDone();
  } finally {
    // Released on failure as well as success.
    paramsBuffer.destroy();
  }
  return output;
}
|
|
1419
|
+
/**
 * Builds the compute pipeline for the element-wise addition kernel.
 *
 * The shader runs 256 invocations per workgroup; each one handles the element
 * at `global_invocation_id.x` and exits early when that index is at or beyond
 * `params.size`.
 *
 * @param device - Object providing `createShaderModule` and
 *   `createComputePipeline`.
 * @returns The pipeline returned by `device.createComputePipeline`.
 */
function compileAddKernel(device) {
  const shaderSource = `
    struct Params {
      size: u32,
    }

    @group(0) @binding(0) var<storage, read> a: array<f32>;
    @group(0) @binding(1) var<storage, read> b: array<f32>;
    @group(0) @binding(2) var<storage, read_write> output: array<f32>;
    @group(0) @binding(3) var<uniform> params: Params;

    @compute @workgroup_size(256)
    fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
      let idx = gid.x;
      if (idx >= params.size) {
        return;
      }
      output[idx] = a[idx] + b[idx];
    }
  `;
  return device.createComputePipeline({
    layout: "auto",
    compute: {
      module: device.createShaderModule({ code: shaderSource }),
      entryPoint: "main"
    }
  });
}
|
|
1445
|
+
/**
 * Builds the compute pipeline for the element-wise multiplication kernel.
 *
 * The shader runs 256 invocations per workgroup; each one handles the element
 * at `global_invocation_id.x` and exits early when that index is at or beyond
 * `params.size`.
 *
 * @param device - Object providing `createShaderModule` and
 *   `createComputePipeline`.
 * @returns The pipeline returned by `device.createComputePipeline`.
 */
function compileMulKernel(device) {
  const shaderSource = `
    struct Params {
      size: u32,
    }

    @group(0) @binding(0) var<storage, read> a: array<f32>;
    @group(0) @binding(1) var<storage, read> b: array<f32>;
    @group(0) @binding(2) var<storage, read_write> output: array<f32>;
    @group(0) @binding(3) var<uniform> params: Params;

    @compute @workgroup_size(256)
    fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
      let idx = gid.x;
      if (idx >= params.size) {
        return;
      }
      output[idx] = a[idx] * b[idx];
    }
  `;
  return device.createComputePipeline({
    layout: "auto",
    compute: {
      module: device.createShaderModule({ code: shaderSource }),
      entryPoint: "main"
    }
  });
}
|
|
1471
|
+
/**
 * Builds the compute pipeline for the scalar-multiplication kernel.
 *
 * The shader's `Params` uniform holds the element count (`u32`) followed by
 * the scalar (`f32`). It runs 256 invocations per workgroup; each one handles
 * the element at `global_invocation_id.x` and exits early when that index is
 * at or beyond `params.size`.
 *
 * @param device - Object providing `createShaderModule` and
 *   `createComputePipeline`.
 * @returns The pipeline returned by `device.createComputePipeline`.
 */
function compileScaleKernel(device) {
  const shaderSource = `
    struct Params {
      size: u32,
      scalar: f32,
    }

    @group(0) @binding(0) var<storage, read> a: array<f32>;
    @group(0) @binding(1) var<storage, read_write> output: array<f32>;
    @group(0) @binding(2) var<uniform> params: Params;

    @compute @workgroup_size(256)
    fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
      let idx = gid.x;
      if (idx >= params.size) {
        return;
      }
      output[idx] = a[idx] * params.scalar;
    }
  `;
  return device.createComputePipeline({
    layout: "auto",
    compute: {
      module: device.createShaderModule({ code: shaderSource }),
      entryPoint: "main"
    }
  });
}
|
|
1497
|
+
// src/ops/embedding.ts
|
|
1498
|
+
// Lazily created kernel cache shared by the embedding op in this module.
var kernelCache6 = null;
/**
 * Returns the module-wide embedding kernel cache, building it on first use.
 *
 * The cache is constructed from the `device` of the first call; later calls
 * return that same instance whatever argument they pass.
 *
 * @param device - Forwarded to the KernelCache constructor on the first call.
 * @returns The shared KernelCache instance.
 */
function getCache6(device) {
  if (kernelCache6) {
    return kernelCache6;
  }
  kernelCache6 = new KernelCache(device);
  return kernelCache6;
}
|
|
1505
|
+
/**
 * CPU embedding lookup.
 *
 * Copies, for each entry of `tokens`, the `embeddingDim` values starting at
 * `token * embeddingDim` in the flat `embeddings` table into the matching
 * output row.
 *
 * @param embeddings - Flat table with `embeddingDim` values per row.
 * @param tokens - Row indices to gather, one per output row.
 * @param embeddingDim - Number of values per row.
 * @returns A new Float32Array of `tokens.length * embeddingDim` values.
 */
function embeddingCPU(embeddings, tokens, embeddingDim) {
  const rowCount = tokens.length;
  const gathered = new Float32Array(rowCount * embeddingDim);
  for (let row = 0; row < rowCount; row++) {
    const readStart = tokens[row] * embeddingDim;
    const writeStart = row * embeddingDim;
    let column = 0;
    while (column < embeddingDim) {
      gathered[writeStart + column] = embeddings[readStart + column];
      column += 1;
    }
  }
  return gathered;
}
|
|
1518
|
+
/**
 * GPU embedding lookup.
 *
 * Gathers one row of the `[vocabSize, embeddingDim]` table per token. One
 * workgroup is dispatched per token and the call resolves once the queue
 * reports the submitted work as done.
 *
 * @param device - Wrapper exposing the raw GPU device as `device.device`, plus
 *   `createCommandEncoder()` and `submit()`.
 * @param embeddings - 2D table tensor.
 * @param tokens - 1D tensor of row indices; its buffer is bound as the
 *   kernel's `tokens` input.
 * @returns A new tensor of shape [seqLen, embeddingDim].
 * @throws {Error} If `embeddings` is not 2D or `tokens` is not 1D.
 */
async function embedding(device, embeddings, tokens) {
  if (embeddings.shape.length !== 2) {
    throw new Error("Embedding table must be 2D [vocabSize, embeddingDim]");
  }
  if (tokens.shape.length !== 1) {
    throw new Error("Tokens must be 1D [seqLen]");
  }
  const [, embeddingDim] = embeddings.shape;
  const seqLen = tokens.shape[0];
  const gpu = device.device;
  const cache = getCache6(gpu);
  const pipeline = cache.getOrCreate(`embedding_${embeddingDim}`, () => compileEmbeddingKernel(gpu, embeddingDim));
  const output = Tensor.zeros(device, [seqLen, embeddingDim]);
  const params = new Uint32Array([seqLen, embeddingDim]);
  const paramsBuffer = gpu.createBuffer({
    size: params.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
  });
  try {
    gpu.queue.writeBuffer(paramsBuffer, 0, params);
    const bindGroup = gpu.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: [
        { binding: 0, resource: { buffer: embeddings.buffer } },
        { binding: 1, resource: { buffer: tokens.buffer } },
        { binding: 2, resource: { buffer: output.buffer } },
        { binding: 3, resource: { buffer: paramsBuffer } }
      ]
    });
    const encoder = device.createCommandEncoder();
    const pass = encoder.beginComputePass();
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bindGroup);
    pass.dispatchWorkgroups(seqLen);
    pass.end();
    device.submit([encoder.finish()]);
    await gpu.queue.onSubmittedWorkDone();
  } finally {
    // Released on failure as well as success.
    paramsBuffer.destroy();
  }
  return output;
}
|
|
1556
|
+
/**
 * Builds the compute pipeline for the embedding gather.
 *
 * The shader runs one single-invocation workgroup per sequence position and
 * copies `params.embeddingDim` floats from the row selected by the token id.
 *
 * @param {*} device - Object providing `createShaderModule` and `createComputePipeline`.
 * @param {number} embeddingDim - Not referenced by the shader source; the
 *   dimension is read from the uniform buffer at run time.
 * @returns {*} Whatever `device.createComputePipeline` returns.
 */
function compileEmbeddingKernel(device, embeddingDim) {
  const shaderSource = `
struct Params {
  seqLen: u32,
  embeddingDim: u32,
}

@group(0) @binding(0) var<storage, read> embeddings: array<f32>;
@group(0) @binding(1) var<storage, read> tokens: array<u32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let seqIdx = gid.x;
  if (seqIdx >= params.seqLen) {
    return;
  }

  let tokenId = tokens[seqIdx];
  let srcOffset = tokenId * params.embeddingDim;
  let dstOffset = seqIdx * params.embeddingDim;

  for (var j = 0u; j < params.embeddingDim; j = j + 1u) {
    output[dstOffset + j] = embeddings[srcOffset + j];
  }
}
`;
  const module = device.createShaderModule({ code: shaderSource });
  const descriptor = {
    layout: "auto",
    compute: { module, entryPoint: "main" }
  };
  return device.createComputePipeline(descriptor);
}
|
|
1590
|
+
/**
 * CPU reference implementation of a batched embedding lookup.
 *
 * The sequence length is taken from the first batch row; every row must have
 * that same length.
 *
 * @param {ArrayLike<number>} embeddings - Flat row-major table of `embeddingDim`-wide rows.
 * @param {ArrayLike<ArrayLike<number>>} tokens - One token-id sequence per batch entry.
 * @param {number} embeddingDim - Number of values per table row.
 * @returns {Float32Array} Flat `[batch, seqLen, embeddingDim]` result.
 * @throws {Error} If the batch rows do not all have the same length.
 * @throws {RangeError} If a token id is not an integer row index inside the table.
 */
function batchedEmbeddingCPU(embeddings, tokens, embeddingDim) {
  const batchSize = tokens.length;
  const seqLen = tokens[0]?.length ?? 0;
  // Number of complete rows available; a trailing partial row is not addressable.
  const rowCount = Math.floor(embeddings.length / embeddingDim);
  const output = new Float32Array(batchSize * seqLen * embeddingDim);
  for (let b = 0; b < batchSize; b++) {
    const row = tokens[b];
    // A shorter row would read `undefined` ids (NaN output); a longer one would be truncated.
    if (row.length !== seqLen) {
      throw new Error(`All token sequences must have the same length: row ${b} has ${row.length}, expected ${seqLen}`);
    }
    for (let i = 0; i < seqLen; i++) {
      const tokenId = row[i];
      if (!Number.isInteger(tokenId) || tokenId < 0 || tokenId >= rowCount) {
        throw new RangeError(`Token id ${tokenId} at [${b}, ${i}] is outside the embedding table (rows: ${rowCount})`);
      }
      const srcOffset = tokenId * embeddingDim;
      const dstOffset = (b * seqLen + i) * embeddingDim;
      for (let j = 0; j < embeddingDim; j++) {
        output[dstOffset + j] = embeddings[srcOffset + j];
      }
    }
  }
  return output;
}
|
|
1606
|
+
// src/ops/reshape.ts
|
|
1607
|
+
// Pipeline cache shared by the reshape/transpose ops; created lazily on first use.
var kernelCache7 = null;
/**
 * Returns the shared kernel cache for the reshape/transpose ops.
 *
 * The cache is constructed from the `device` given on the first call; every
 * later call returns that same instance and ignores its argument.
 *
 * @param {*} device - Handle forwarded to the `KernelCache` constructor.
 * @returns {KernelCache} The shared cache instance.
 */
function getCache7(device) {
  if (kernelCache7) {
    return kernelCache7;
  }
  kernelCache7 = new KernelCache(device);
  return kernelCache7;
}
|
|
1614
|
+
/**
 * CPU transpose of a row-major `rows x cols` matrix.
 *
 * @param {ArrayLike<number>} x - Flat row-major input.
 * @param {number} rows - Row count of the input.
 * @param {number} cols - Column count of the input.
 * @returns {Float32Array} Flat row-major `cols x rows` matrix with `x.length` elements.
 */
function transpose2DCPU(x, rows, cols) {
  const transposed = new Float32Array(x.length);
  for (let r = 0; r < rows; r++) {
    const rowBase = r * cols;
    for (let c = 0; c < cols; c++) {
      // Element (r, c) of the input lands at (c, r) of the output.
      transposed[c * rows + r] = x[rowBase + c];
    }
  }
  return transposed;
}
|
|
1623
|
+
/**
 * CPU transpose of the last two axes of a row-major tensor.
 *
 * All leading axes are treated as a batch of independent `M x N` matrices.
 *
 * @param {ArrayLike<number>} x - Flat row-major input.
 * @param {number[]} shape - Input shape with at least two axes.
 * @returns {{data: Float32Array, shape: number[]}} Transposed data and the
 *   shape with its last two extents swapped.
 * @throws {Error} If `shape` has fewer than two axes.
 */
function transposeCPU(x, shape) {
  const rank = shape.length;
  if (rank < 2) {
    throw new Error("Transpose requires at least 2D tensor");
  }
  const leading = shape.slice(0, -2);
  const M = shape[rank - 2];
  const N = shape[rank - 1];
  const batchSize = leading.reduce((acc, extent) => acc * extent, 1);
  const matrixSize = M * N;
  const output = new Float32Array(x.length);
  for (let b = 0; b < batchSize; b++) {
    const base = b * matrixSize;
    for (let r = 0; r < M; r++) {
      const srcRow = base + r * N;
      for (let c = 0; c < N; c++) {
        output[base + c * M + r] = x[srcRow + c];
      }
    }
  }
  return { data: output, shape: [...leading, N, M] };
}
|
|
1643
|
+
/**
 * Validates a reshape and resolves a single `-1` placeholder dimension.
 *
 * The data is not copied: the returned `data` is the same object as `x`.
 *
 * @param {*} x - Tensor data, passed through untouched.
 * @param {number[]} oldShape - Current shape.
 * @param {number[]} newShape - Target shape; at most one entry may be `-1`,
 *   which is inferred from the remaining extents.
 * @returns {{data: *, shape: number[]}} `x` together with the resolved shape.
 * @throws {Error} If more than one `-1` is given, a dimension is not a
 *   non-negative integer, or the element counts do not match.
 */
function reshapeCPU(x, oldShape, newShape) {
  const oldSize = oldShape.reduce((a, b) => a * b, 1);
  let inferIdx = -1;
  let knownSize = 1;
  for (let i = 0; i < newShape.length; i++) {
    const dim = newShape[i];
    if (dim === -1) {
      if (inferIdx !== -1) {
        throw new Error("Can only have one -1 in reshape");
      }
      inferIdx = i;
    } else if (!Number.isInteger(dim) || dim < 0) {
      // Previously e.g. [-2, -3] was accepted for 6 elements because only the product was checked.
      throw new Error(`Invalid dimension ${dim} in reshape target ${newShape}`);
    } else {
      knownSize *= dim;
    }
  }
  const finalShape = [...newShape];
  if (inferIdx !== -1) {
    // With knownSize === 0 the remainder is NaN, so the inference is rejected here too.
    if (oldSize % knownSize !== 0) {
      throw new Error(`Cannot reshape ${oldShape} to ${newShape}`);
    }
    finalShape[inferIdx] = oldSize / knownSize;
  }
  const newSize = finalShape.reduce((a, b) => a * b, 1);
  if (oldSize !== newSize) {
    throw new Error(`Shape mismatch: ${oldSize} vs ${newSize}`);
  }
  return { data: x, shape: finalShape };
}
|
|
1670
|
+
/**
 * GPU transpose of a 2-D tensor.
 *
 * @param {*} device - Wrapper exposing `device` (used for `createBuffer`,
 *   `createBindGroup` and `queue`), `createCommandEncoder()` and `submit()`.
 * @param {*} x - 2-D tensor `[rows, cols]` with a `buffer`.
 * @returns {Promise<*>} Tensor of shape `[cols, rows]`, resolved once the
 *   submitted GPU work has completed.
 * @throws {Error} If `x` is not 2-D.
 */
async function transpose2D(device, x) {
  if (x.shape.length !== 2) {
    throw new Error("transpose2D requires 2D tensor");
  }
  const [rows, cols] = x.shape;
  const gpu = device.device;
  // The kernel takes rows/cols from the uniform buffer, so one pipeline serves
  // all shapes; keying by shape compiled a new pipeline per distinct shape.
  const pipeline = getCache7(gpu).getOrCreate("transpose2d", () => compileTranspose2DKernel(gpu));
  const output = Tensor.zeros(device, [cols, rows]);
  const params = new Uint32Array([rows, cols]);
  const paramsBuffer = gpu.createBuffer({
    size: params.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
  });
  try {
    gpu.queue.writeBuffer(paramsBuffer, 0, params);
    const bindGroup = gpu.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: [
        { binding: 0, resource: { buffer: x.buffer } },
        { binding: 1, resource: { buffer: output.buffer } },
        { binding: 2, resource: { buffer: paramsBuffer } }
      ]
    });
    const encoder = device.createCommandEncoder();
    const pass = encoder.beginComputePass();
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bindGroup);
    // 16x16 tiles: x covers columns, y covers rows (matches the kernel's workgroup size).
    pass.dispatchWorkgroups(Math.ceil(cols / 16), Math.ceil(rows / 16));
    pass.end();
    device.submit([encoder.finish()]);
    await gpu.queue.onSubmittedWorkDone();
  } finally {
    // Release the temporary uniform buffer even when encoding or submission fails.
    paramsBuffer.destroy();
  }
  return output;
}
|
|
1703
|
+
/**
 * CPU axis permutation of a row-major tensor.
 *
 * Output axis `d` takes its extent and data from input axis `dims[d]`.
 *
 * @param {ArrayLike<number>} x - Flat row-major input.
 * @param {number[]} shape - Input shape.
 * @param {number[]} dims - Permutation of `0 .. shape.length - 1`.
 * @returns {{data: Float32Array, shape: number[]}} Permuted data and shape.
 * @throws {Error} If `dims` has the wrong length or is not a permutation.
 */
function permuteCPU(x, shape, dims) {
  const rank = shape.length;
  if (dims.length !== rank) {
    throw new Error("Permutation must have same length as shape");
  }
  // A valid permutation, once sorted, is exactly 0, 1, ..., rank - 1.
  const ordered = [...dims].sort((a, b) => a - b);
  if (ordered.some((axis, position) => axis !== position)) {
    throw new Error("Invalid permutation");
  }
  const newShape = dims.map((axis) => shape[axis]);
  const output = new Float32Array(x.length);
  const srcStrides = computeStrides(shape);
  const dstStrides = computeStrides(newShape);
  // Source stride to apply for each destination axis.
  const gatherStrides = dims.map((axis) => srcStrides[axis]);
  for (let dst = 0; dst < x.length; dst++) {
    let rest = dst;
    let src = 0;
    for (let axis = 0; axis < rank; axis++) {
      // Peel off the coordinate along this destination axis.
      const coord = Math.floor(rest / dstStrides[axis]);
      rest = rest % dstStrides[axis];
      src += coord * gatherStrides[axis];
    }
    output[dst] = x[src];
  }
  return { data: output, shape: newShape };
}
|
|
1733
|
+
/**
 * Computes row-major (last axis fastest) element strides for a shape.
 *
 * @param {number[]} shape - Extent of each axis.
 * @returns {number[]} One stride per axis; an empty array for a rank-0 shape.
 */
function computeStrides(shape) {
  const rank = shape.length;
  const strides = new Array(rank);
  // Walking from the innermost axis avoids the original's unconditional write
  // to index `rank - 1`, which left a stray "-1" property for rank 0.
  let stride = 1;
  for (let axis = rank - 1; axis >= 0; axis--) {
    strides[axis] = stride;
    stride *= shape[axis];
  }
  return strides;
}
|
|
1741
|
+
/**
 * Builds the compute pipeline for the 2-D transpose.
 *
 * The shader uses 16x16 workgroups; invocation (x, y) copies input element
 * (row = y, col = x) to output position (col, row), skipping coordinates
 * outside the `rows x cols` extent given in the uniform buffer.
 *
 * @param {*} device - Object providing `createShaderModule` and `createComputePipeline`.
 * @returns {*} Whatever `device.createComputePipeline` returns.
 */
function compileTranspose2DKernel(device) {
  const shaderSource = `
struct Params {
  rows: u32,
  cols: u32,
}

@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;

@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let col = gid.x;
  let row = gid.y;

  if (col >= params.cols || row >= params.rows) {
    return;
  }

  // Input: [row, col] at row * cols + col
  // Output: [col, row] at col * rows + row
  output[col * params.rows + row] = x[row * params.cols + col];
}
`;
  const module = device.createShaderModule({ code: shaderSource });
  const descriptor = {
    layout: "auto",
    compute: { module, entryPoint: "main" }
  };
  return device.createComputePipeline(descriptor);
}
|
|
1772
|
+
// src/quantization/quantize.ts
|
|
1773
|
+
/**
 * Group-wise 8-bit quantization of a flat array.
 *
 * Symmetric mode stores `round(x / scale) + 128` with `scale = absMax / 127`.
 * Asymmetric mode stores `round((x - min) / scale)` with
 * `scale = (max - min) / 255` and keeps `min` as the group's zero value.
 * Codes are clamped to 0..255.
 *
 * @param {ArrayLike<number>} x - Values to quantize.
 * @param {number} [groupSize=128] - Number of consecutive values sharing one scale.
 * @param {boolean} [symmetric=true] - Selects symmetric or min/max quantization.
 * @returns {{data: Uint8Array, scales: Float32Array, zeros: (Float32Array|null),
 *   shape: number[], config: {bits: number, groupSize: number, symmetric: boolean}}}
 */
function quantizeToInt8(x, groupSize = 128, symmetric = true) {
  const total = x.length;
  const numGroups = Math.ceil(total / groupSize);
  const scales = new Float32Array(numGroups);
  const zeros = symmetric ? null : new Float32Array(numGroups);
  const quantized = new Uint8Array(total);
  for (let g = 0; g < numGroups; g++) {
    const start = g * groupSize;
    const end = Math.min(start + groupSize, total);
    let lo = x[start];
    let hi = x[start];
    for (let i = start; i < end; i++) {
      lo = Math.min(lo, x[i]);
      hi = Math.max(hi, x[i]);
    }
    // `offset` is subtracted before scaling, `bias` is added after rounding.
    let offset = 0;
    let bias = 128;
    if (symmetric) {
      scales[g] = Math.max(Math.abs(lo), Math.abs(hi)) / 127;
    } else {
      scales[g] = (hi - lo) / 255;
      zeros[g] = lo;
      // Read back so the float32-rounded zero value is the one used below.
      offset = zeros[g];
      bias = 0;
    }
    // Read back so the float32-rounded scale is the one inverted.
    const storedScale = scales[g];
    const invScale = storedScale > 0 ? 1 / storedScale : 0;
    for (let i = start; i < end; i++) {
      const level = Math.round((x[i] - offset) * invScale) + bias;
      quantized[i] = Math.max(0, Math.min(255, level));
    }
  }
  return {
    data: quantized,
    scales,
    zeros,
    shape: [total],
    config: { bits: 8, groupSize, symmetric }
  };
}
|
|
1815
|
+
/**
 * Group-wise 4-bit quantization of a flat array, two codes per byte.
 *
 * Symmetric mode stores `round(x / scale) + 8` with `scale = absMax / 7`.
 * Asymmetric mode stores `round((x - min) / scale)` with
 * `scale = (max - min) / 15` and keeps `min` as the group's zero value.
 * Codes are clamped to 0..15. Even element indices occupy the low nibble of a
 * byte, odd indices the high nibble.
 *
 * @param {ArrayLike<number>} x - Values to quantize.
 * @param {number} [groupSize=128] - Number of consecutive values sharing one scale.
 * @param {boolean} [symmetric=true] - Selects symmetric or min/max quantization.
 * @returns {{data: Uint8Array, scales: Float32Array, zeros: (Float32Array|null),
 *   shape: number[], config: {bits: number, groupSize: number, symmetric: boolean}}}
 */
function quantizeToInt4(x, groupSize = 128, symmetric = true) {
  const total = x.length;
  const numGroups = Math.ceil(total / groupSize);
  const scales = new Float32Array(numGroups);
  const zeros = symmetric ? null : new Float32Array(numGroups);
  const quantized = new Uint8Array(Math.ceil(total / 2));
  for (let g = 0; g < numGroups; g++) {
    const start = g * groupSize;
    const end = Math.min(start + groupSize, total);
    let lo = x[start];
    let hi = x[start];
    for (let i = start; i < end; i++) {
      lo = Math.min(lo, x[i]);
      hi = Math.max(hi, x[i]);
    }
    // `offset` is subtracted before scaling, `bias` is added after rounding.
    let offset = 0;
    let bias = 8;
    if (symmetric) {
      scales[g] = Math.max(Math.abs(lo), Math.abs(hi)) / 7;
    } else {
      scales[g] = (hi - lo) / 15;
      zeros[g] = lo;
      // Read back so the float32-rounded zero value is the one used below.
      offset = zeros[g];
      bias = 0;
    }
    // Read back so the float32-rounded scale is the one inverted.
    const storedScale = scales[g];
    const invScale = storedScale > 0 ? 1 / storedScale : 0;
    for (let i = start; i < end; i++) {
      const level = Math.round((x[i] - offset) * invScale) + bias;
      const nibble = Math.max(0, Math.min(15, level));
      const byteIdx = Math.floor(i / 2);
      // Even index starts the byte (low nibble); odd index fills the high nibble.
      quantized[byteIdx] = i % 2 === 0 ? nibble : quantized[byteIdx] | (nibble << 4);
    }
  }
  return {
    data: quantized,
    scales,
    zeros,
    shape: [total],
    config: { bits: 4, groupSize, symmetric }
  };
}
|
|
1870
|
+
/**
 * Reconstructs float values from an 8-bit group-quantized tensor.
 *
 * Symmetric: `(code - 128) * scale`. Asymmetric: `code * scale + zero`.
 *
 * @param {{data: ArrayLike<number>, scales: ArrayLike<number>,
 *   zeros: (ArrayLike<number>|null),
 *   config: {bits: number, groupSize: number, symmetric: boolean}}} qt
 * @returns {Float32Array} One value per entry of `qt.data`.
 * @throws {Error} If `qt.config.bits` is not 8.
 */
function dequantizeInt8(qt) {
  if (qt.config.bits !== 8) {
    throw new Error("Expected INT8 quantized tensor");
  }
  const { data, scales, zeros } = qt;
  const { groupSize, symmetric } = qt.config;
  // Pick the decoding rule once instead of branching per element.
  const decode = symmetric
    ? (i, g) => (data[i] - 128) * scales[g]
    : (i, g) => data[i] * scales[g] + zeros[g];
  const count = data.length;
  const output = new Float32Array(count);
  for (let i = 0; i < count; i++) {
    output[i] = decode(i, Math.floor(i / groupSize));
  }
  return output;
}
|
|
1888
|
+
/**
 * Reconstructs float values from a packed 4-bit group-quantized tensor.
 *
 * Even element indices are read from the low nibble of their byte, odd
 * indices from the high nibble. Symmetric: `(code - 8) * scale`.
 * Asymmetric: `code * scale + zero`.
 *
 * @param {{data: ArrayLike<number>, scales: ArrayLike<number>,
 *   zeros: (ArrayLike<number>|null), shape: number[],
 *   config: {bits: number, groupSize: number, symmetric: boolean}}} qt
 * @returns {Float32Array} One value per element described by `qt.shape`.
 * @throws {Error} If `qt.config.bits` is not 4.
 */
function dequantizeInt4(qt) {
  if (qt.config.bits !== 4) {
    throw new Error("Expected INT4 quantized tensor");
  }
  const { data, scales, zeros, shape } = qt;
  const { groupSize, symmetric } = qt.config;
  // The element count comes from the shape because the packed byte count
  // cannot tell an odd length from the next even one.
  const numElements = shape.reduce((acc, extent) => acc * extent, 1);
  const output = new Float32Array(numElements);
  for (let i = 0; i < numElements; i++) {
    const shift = i % 2 === 1 ? 4 : 0;
    const code = (data[Math.floor(i / 2)] >> shift) & 15;
    const g = Math.floor(i / groupSize);
    output[i] = symmetric ? (code - 8) * scales[g] : code * scales[g] + zeros[g];
  }
  return output;
}
|
|
1915
|
+
/**
 * Mean squared error between two equally long arrays.
 *
 * @param {ArrayLike<number>} original - Reference values.
 * @param {ArrayLike<number>} reconstructed - Values to compare against the reference.
 * @returns {number} Sum of squared differences divided by the length
 *   (NaN for two empty inputs, since that is 0 / 0).
 * @throws {Error} If the lengths differ.
 */
function quantizationError(original, reconstructed) {
  const count = original.length;
  if (count !== reconstructed.length) {
    throw new Error("Length mismatch");
  }
  let squaredTotal = 0;
  let idx = 0;
  while (idx < count) {
    const delta = original[idx] - reconstructed[idx];
    squaredTotal += delta * delta;
    idx++;
  }
  return squaredTotal / count;
}
|
|
1926
|
+
/**
 * Compares the byte size of a quantized tensor with the original size.
 *
 * The quantized size is the sum of the `byteLength`s of `qt.data`,
 * `qt.scales` and, when present, `qt.zeros`.
 *
 * @param {number} originalBytes - Size of the unquantized data in bytes.
 * @param {{data: {byteLength: number}, scales: {byteLength: number},
 *   zeros?: ({byteLength: number}|null)}} qt - Quantized tensor.
 * @returns {{originalBytes: number, quantizedBytes: number, savings: number, ratio: number}}
 */
function getMemorySavings(originalBytes, qt) {
  const quantizedBytes = qt.data.byteLength + qt.scales.byteLength + (qt.zeros?.byteLength ?? 0);
  const savings = originalBytes - quantizedBytes;
  const ratio = originalBytes / quantizedBytes;
  return { originalBytes, quantizedBytes, savings, ratio };
}
|
|
1938
|
+
// src/quantization/qmatmul.ts
|
|
1939
|
+
/**
 * CPU reference matmul `A (M x K) @ B (K x N)` with 8-bit quantized weights.
 *
 * Each weight is dequantized on the fly. Its quantization group is derived
 * from its flat row-major index `k * N + n`.
 *
 * @param {ArrayLike<number>} A - Flat row-major activations.
 * @param {{data: ArrayLike<number>, scales: ArrayLike<number>,
 *   zeros: (ArrayLike<number>|null),
 *   config: {bits: number, groupSize: number, symmetric: boolean}}} B - Quantized weights.
 * @param {number} M - Rows of `A`.
 * @param {number} K - Shared inner dimension.
 * @param {number} N - Columns of `B`.
 * @returns {Float32Array} Flat row-major `M x N` product.
 * @throws {Error} If `B.config.bits` is not 8.
 */
function qmatmulInt8CPU(A, B, M, K, N) {
  if (B.config.bits !== 8) {
    throw new Error("Expected INT8 weights");
  }
  const { data: Bq, scales, zeros, config: { groupSize, symmetric } } = B;
  // Dequantizes the weight stored at flat index `idx`.
  const weightAt = symmetric
    ? (idx) => (Bq[idx] - 128) * scales[Math.floor(idx / groupSize)]
    : (idx) => {
      const g = Math.floor(idx / groupSize);
      return Bq[idx] * scales[g] + zeros[g];
    };
  const output = new Float32Array(M * N);
  for (let row = 0; row < M; row++) {
    const aBase = row * K;
    for (let col = 0; col < N; col++) {
      let acc = 0;
      for (let k = 0; k < K; k++) {
        acc += A[aBase + k] * weightAt(k * N + col);
      }
      output[row * N + col] = acc;
    }
  }
  return output;
}
|
|
1967
|
+
/**
 * CPU reference matmul `A (M x K) @ B (K x N)` with packed 4-bit weights.
 *
 * The weight at flat row-major index `k * N + n` is read from the low nibble
 * of its byte when the index is even and from the high nibble when it is odd,
 * then dequantized on the fly using the group derived from that same index.
 *
 * @param {ArrayLike<number>} A - Flat row-major activations.
 * @param {{data: ArrayLike<number>, scales: ArrayLike<number>,
 *   zeros: (ArrayLike<number>|null),
 *   config: {bits: number, groupSize: number, symmetric: boolean}}} B - Quantized weights.
 * @param {number} M - Rows of `A`.
 * @param {number} K - Shared inner dimension.
 * @param {number} N - Columns of `B`.
 * @returns {Float32Array} Flat row-major `M x N` product.
 * @throws {Error} If `B.config.bits` is not 4.
 */
function qmatmulInt4CPU(A, B, M, K, N) {
  if (B.config.bits !== 4) {
    throw new Error("Expected INT4 weights");
  }
  const { data: Bq, scales, zeros, config: { groupSize, symmetric } } = B;
  // Extracts the 4-bit code stored for flat element index `idx`.
  const codeAt = (idx) => {
    const shift = idx % 2 === 1 ? 4 : 0;
    return (Bq[Math.floor(idx / 2)] >> shift) & 15;
  };
  // Dequantizes the weight stored at flat index `idx`.
  const weightAt = (idx) => {
    const code = codeAt(idx);
    const g = Math.floor(idx / groupSize);
    return symmetric ? (code - 8) * scales[g] : code * scales[g] + zeros[g];
  };
  const output = new Float32Array(M * N);
  for (let row = 0; row < M; row++) {
    const aBase = row * K;
    for (let col = 0; col < N; col++) {
      let acc = 0;
      for (let k = 0; k < K; k++) {
        acc += A[aBase + k] * weightAt(k * N + col);
      }
      output[row * N + col] = acc;
    }
  }
  return output;
}
|
|
2003
|
+
/**
 * Tiled CPU reference matmul `A (M x K) @ B (K x N)` with 8-bit quantized
 * weights.
 *
 * The M, N and K ranges are processed in `blockSize`-wide tiles. Partial sums
 * for an output element are carried between K tiles through the float32
 * output array.
 *
 * @param {ArrayLike<number>} A - Flat row-major activations.
 * @param {{data: ArrayLike<number>, scales: ArrayLike<number>,
 *   zeros: (ArrayLike<number>|null),
 *   config: {bits: number, groupSize: number, symmetric: boolean}}} B - Quantized weights.
 * @param {number} M - Rows of `A`.
 * @param {number} K - Shared inner dimension.
 * @param {number} N - Columns of `B`.
 * @param {number} [blockSize=32] - Tile edge length; must be a positive integer.
 * @returns {Float32Array} Flat row-major `M x N` product.
 * @throws {Error} If `B.config.bits` is not 8.
 * @throws {RangeError} If `blockSize` is not a positive integer.
 */
function qmatmulInt8BlockCPU(A, B, M, K, N, blockSize = 32) {
  if (B.config.bits !== 8) {
    throw new Error("Expected INT8 weights");
  }
  // A step of zero or less never advances the tile loops, NaN skips them
  // entirely, and a fractional step produces fractional indices.
  if (!Number.isInteger(blockSize) || blockSize < 1) {
    throw new RangeError(`blockSize must be a positive integer, got ${blockSize}`);
  }
  const { data: Bq, scales, zeros, config } = B;
  const { groupSize, symmetric } = config;
  const output = new Float32Array(M * N);
  for (let mb = 0; mb < M; mb += blockSize) {
    const mEnd = Math.min(mb + blockSize, M);
    for (let nb = 0; nb < N; nb += blockSize) {
      const nEnd = Math.min(nb + blockSize, N);
      for (let kb = 0; kb < K; kb += blockSize) {
        const kEnd = Math.min(kb + blockSize, K);
        for (let m = mb; m < mEnd; m++) {
          for (let n = nb; n < nEnd; n++) {
            // Continue from the partial sum left by the previous K tile.
            let sum = output[m * N + n];
            for (let k = kb; k < kEnd; k++) {
              const a = A[m * K + k];
              const bIdx = k * N + n;
              const g = Math.floor(bIdx / groupSize);
              const scale2 = scales[g];
              const b = symmetric ? (Bq[bIdx] - 128) * scale2 : Bq[bIdx] * scale2 + zeros[g];
              sum += a * b;
            }
            output[m * N + n] = sum;
          }
        }
      }
    }
  }
  return output;
}
|
|
2040
|
+
/**
 * Floating-point operation count for an `M x K` by `K x N` matmul, counting
 * one multiply and one add per inner-product term.
 *
 * @param {number} M - Rows of the left operand.
 * @param {number} K - Shared inner dimension.
 * @param {number} N - Columns of the right operand.
 * @returns {number} `2 * M * K * N`.
 */
function estimateQMatMulFlops(M, K, N) {
  const multiplyAdds = M * K * N;
  return 2 * multiplyAdds;
}
|
|
2043
|
+
/**
 * Estimates the bytes moved by a quantized `M x K` by `K x N` matmul.
 *
 * Activations and outputs are counted at 4 bytes per element. Weights take
 * one byte per element when `bits === 8`; any other value is counted as two
 * elements per byte. One 4-byte scale is counted per quantization group.
 *
 * @param {number} M - Rows of the activations.
 * @param {number} K - Shared inner dimension.
 * @param {number} N - Columns of the weights.
 * @param {number} bits - Weight bit width.
 * @param {number} groupSize - Weight elements per quantization group.
 * @returns {{activationBytes: number, weightBytes: number, scaleBytes: number,
 *   outputBytes: number, totalBytes: number}}
 */
function estimateQMatMulBandwidth(M, K, N, bits, groupSize) {
  const weightElements = K * N;
  const traffic = {
    activationBytes: M * K * 4,
    weightBytes: bits === 8 ? weightElements : Math.ceil(weightElements / 2),
    scaleBytes: Math.ceil(weightElements / groupSize) * 4,
    outputBytes: M * N * 4
  };
  traffic.totalBytes = traffic.activationBytes + traffic.weightBytes + traffic.scaleBytes + traffic.outputBytes;
  return traffic;
}
|
|
2058
|
+
// src/attention/block-sparse/format.ts
|
|
2059
|
+
/**
 * Builds a block-level CSR description of a square attention mask.
 *
 * The `seqLen x seqLen` matrix is split into `blockSize`-wide blocks and
 * `isBlockNonZero` decides, in row-major block order, which blocks are kept.
 *
 * @param {number} seqLen - Sequence length (rows and columns of the mask).
 * @param {*} pattern - Sparsity pattern descriptor forwarded to `isBlockNonZero`.
 * @param {number} [blockSize=64] - Block edge length; must be greater than zero.
 * @returns {{blockSize: number, rowPtr: Uint32Array, colIdx: Uint32Array,
 *   numRows: number, numCols: number, numBlockRows: number,
 *   numBlockCols: number, nnzBlocks: number}} `rowPtr[r]..rowPtr[r + 1]`
 *   delimits the entries of `colIdx` holding the kept block columns of block row `r`.
 * @throws {RangeError} If `blockSize` is not greater than zero.
 */
function buildBlockSparseCSR(seqLen, pattern, blockSize = 64) {
  // A non-positive block size makes the block count infinite (or NaN) and the
  // scan below would never finish.
  if (!(blockSize > 0)) {
    throw new RangeError(`blockSize must be greater than zero, got ${blockSize}`);
  }
  const numBlockRows = Math.ceil(seqLen / blockSize);
  const numBlockCols = Math.ceil(seqLen / blockSize);
  const rowPtr = new Uint32Array(numBlockRows + 1);
  const keptCols = [];
  // Blocks are visited row by row, so row pointers and column indices can be
  // emitted in the same pass instead of rescanning the block list per row.
  for (let br = 0; br < numBlockRows; br++) {
    rowPtr[br] = keptCols.length;
    for (let bc = 0; bc < numBlockCols; bc++) {
      if (isBlockNonZero(br, bc, blockSize, seqLen, pattern)) {
        keptCols.push(bc);
      }
    }
  }
  rowPtr[numBlockRows] = keptCols.length;
  return {
    blockSize,
    rowPtr,
    colIdx: Uint32Array.from(keptCols),
    numRows: seqLen,
    numCols: seqLen,
    numBlockRows,
    numBlockCols,
    nnzBlocks: keptCols.length
  };
}
|
|
2093
|
+
/**
 * Decides whether a block of the attention mask must be kept for a pattern.
 *
 * The block covers rows `[rowStart, rowEnd)` and columns `[colStart, colEnd)`,
 * both clipped to `seqLen`.
 *
 * - `"dense"` and unrecognised types: always kept.
 * - `"causal"`: kept when the block's row range ends after its column range begins.
 * - `"sliding"`: kept when the columns begin before the rows end and end
 *   after `rowStart - windowSize` (floored at 0).
 * - `"global-local"`: kept when any entry of `globalTokens` lies in the
 *   column range, otherwise the sliding test with `localWindow`.
 * - `"custom"`: kept when any `mask[i][j]` inside the block is truthy.
 *
 * @param {number} blockRow - Block row index.
 * @param {number} blockCol - Block column index.
 * @param {number} blockSize - Block edge length.
 * @param {number} seqLen - Sequence length.
 * @param {{type: string}} pattern - Pattern descriptor with type-specific fields.
 * @returns {boolean} Whether the block is kept.
 */
function isBlockNonZero(blockRow, blockCol, blockSize, seqLen, pattern) {
  const rowStart = blockRow * blockSize;
  const rowEnd = Math.min(rowStart + blockSize, seqLen);
  const colStart = blockCol * blockSize;
  const colEnd = Math.min(colStart + blockSize, seqLen);
  // Shared band test for the windowed patterns.
  const overlapsWindow = (span) => colStart < rowEnd && colEnd > Math.max(0, rowStart - span);
  const kind = pattern.type;
  if (kind === "causal") {
    return rowEnd > colStart;
  }
  if (kind === "sliding") {
    return overlapsWindow(pattern.windowSize);
  }
  if (kind === "global-local") {
    const { globalTokens, localWindow } = pattern;
    for (const tokenPos of globalTokens) {
      if (tokenPos >= colStart && tokenPos < colEnd) {
        return true;
      }
    }
    return overlapsWindow(localWindow);
  }
  if (kind === "custom") {
    for (let i = rowStart; i < rowEnd; i++) {
      const maskRow = pattern.mask[i];
      for (let j = colStart; j < colEnd; j++) {
        if (maskRow?.[j]) {
          return true;
        }
      }
    }
    return false;
  }
  // "dense" and any unknown pattern type keep every block.
  return true;
}
|
|
2127
|
+
/**
 * Fraction of blocks that are absent from a block-sparse matrix.
 *
 * @param {{numBlockRows: number, numBlockCols: number, nnzBlocks: number}} csr
 * @returns {number} `1 - nnzBlocks / (numBlockRows * numBlockCols)`.
 */
function getSparsityRatio(csr) {
  const density = csr.nnzBlocks / (csr.numBlockRows * csr.numBlockCols);
  return 1 - density;
}
|
|
2131
|
+
/**
 * Estimates storage for the dense matrix versus its block-sparse form.
 *
 * Dense storage is 4 bytes per element. Sparse storage counts 4 bytes per
 * element of every kept block, plus 4 bytes per row-pointer entry
 * (`numBlockRows + 1`) and 4 bytes per column index (`nnzBlocks`).
 *
 * @param {{numRows: number, numCols: number, nnzBlocks: number,
 *   blockSize: number, numBlockRows: number}} csr
 * @returns {{denseBytes: number, sparseBytes: number, savingsRatio: number}}
 */
function estimateMemorySavings(csr) {
  const denseBytes = csr.numRows * csr.numCols * 4;
  const blockPayloadBytes = csr.nnzBlocks * csr.blockSize * csr.blockSize * 4;
  const rowPtrBytes = (csr.numBlockRows + 1) * 4;
  const colIdxBytes = csr.nnzBlocks * 4;
  const sparseBytes = blockPayloadBytes + rowPtrBytes + colIdxBytes;
  const savingsRatio = 1 - sparseBytes / denseBytes;
  return { denseBytes, sparseBytes, savingsRatio };
}
|
|
2140
|
+
|
|
2141
|
+
// src/attention/flash-attention.ts
|
|
2142
|
+
// Pipeline cache shared by the flash-attention op; created lazily on first use.
var kernelCache8 = null;
/**
 * Returns the shared kernel cache for the flash-attention op.
 *
 * The cache is constructed from the `device` given on the first call; every
 * later call returns that same instance and ignores its argument.
 *
 * @param {*} device - Handle forwarded to the `KernelCache` constructor.
 * @returns {KernelCache} The shared cache instance.
 */
function getCache8(device) {
  if (kernelCache8) {
    return kernelCache8;
  }
  kernelCache8 = new KernelCache(device);
  return kernelCache8;
}
|
|
2149
|
+
/**
 * Runs the block-sparse flash-attention kernel on the GPU.
 *
 * One workgroup is dispatched per (block row, head). The block-sparse layout
 * is built on the CPU with `buildBlockSparseCSR` and uploaded as two storage
 * buffers.
 *
 * @param {*} device - Wrapper exposing `device` (used for `createBuffer`,
 *   `createBindGroup` and `queue`), `createCommandEncoder()` and `submit()`.
 * @param {*} q - Query tensor with a `buffer`; the kernel indexes it as
 *   `[seqLen, numHeads, headDim]`.
 * @param {*} k - Key tensor with a `buffer`, same layout as `q`.
 * @param {*} v - Value tensor with a `buffer`, same layout as `q`.
 * @param {{numHeads: number, headDim: number, seqLen: number, scale?: number,
 *   blockSize?: number, pattern?: object}} config - `scale` defaults to
 *   `1 / sqrt(headDim)`, `blockSize` to 64 and `pattern` to `{ type: "causal" }`.
 * @returns {Promise<*>} Tensor of shape `[seqLen, numHeads, headDim]`,
 *   resolved once the submitted GPU work has completed.
 */
async function flashAttention(device, q, k, v, config) {
  const { numHeads, headDim, seqLen } = config;
  const scale2 = config.scale ?? 1 / Math.sqrt(headDim);
  const blockSize = config.blockSize ?? 64;
  const pattern = config.pattern ?? { type: "causal" };
  const sparseMask = buildBlockSparseCSR(seqLen, pattern, blockSize);
  const gpu = device.device;
  const pipeline = getCache8(gpu).getOrCreate(`flash_attn_${numHeads}_${headDim}_${seqLen}_${blockSize}`, () => compileFlashAttentionKernel(gpu, config, blockSize));
  const output = Tensor.zeros(device, [seqLen, numHeads, headDim]);
  // The kernel's Params struct is { u32, u32, u32, f32, u32, u32, u32, u32 }.
  // The integer fields must be written as u32 bit patterns: packing them
  // through a Float32Array makes the shader read e.g. seqLen = 64 as
  // 0x42800000. Two typed views over one buffer keep each field correct.
  const paramsData = new ArrayBuffer(32);
  const paramsU32 = new Uint32Array(paramsData);
  const paramsF32 = new Float32Array(paramsData);
  paramsU32[0] = seqLen;
  paramsU32[1] = numHeads;
  paramsU32[2] = headDim;
  paramsF32[3] = scale2;
  paramsU32[4] = blockSize;
  paramsU32[5] = sparseMask.numBlockRows;
  // Words 6 and 7 are padding and stay zero.
  const transientBuffers = [];
  const createTransientBuffer = (descriptor) => {
    const buffer = gpu.createBuffer(descriptor);
    transientBuffers.push(buffer);
    return buffer;
  };
  try {
    const paramsBuffer = createTransientBuffer({
      size: paramsData.byteLength,
      usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
    });
    gpu.queue.writeBuffer(paramsBuffer, 0, paramsData);
    const rowPtrBuffer = createTransientBuffer({
      size: sparseMask.rowPtr.byteLength,
      usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST
    });
    gpu.queue.writeBuffer(rowPtrBuffer, 0, sparseMask.rowPtr);
    // Allocate at least 4 bytes so the binding exists even with no kept blocks.
    const colIdxBuffer = createTransientBuffer({
      size: Math.max(sparseMask.colIdx.byteLength, 4),
      usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST
    });
    if (sparseMask.colIdx.length > 0) {
      gpu.queue.writeBuffer(colIdxBuffer, 0, sparseMask.colIdx);
    }
    const bindGroup = gpu.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: [
        { binding: 0, resource: { buffer: q.buffer } },
        { binding: 1, resource: { buffer: k.buffer } },
        { binding: 2, resource: { buffer: v.buffer } },
        { binding: 3, resource: { buffer: output.buffer } },
        { binding: 4, resource: { buffer: paramsBuffer } },
        { binding: 5, resource: { buffer: rowPtrBuffer } },
        { binding: 6, resource: { buffer: colIdxBuffer } }
      ]
    });
    const encoder = device.createCommandEncoder();
    const pass = encoder.beginComputePass();
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bindGroup);
    // x: one workgroup per block row; y: one per attention head.
    pass.dispatchWorkgroups(sparseMask.numBlockRows, numHeads);
    pass.end();
    device.submit([encoder.finish()]);
    await gpu.queue.onSubmittedWorkDone();
  } finally {
    // Release every temporary buffer, including when setup or submission fails.
    for (const buffer of transientBuffers) {
      buffer.destroy();
    }
  }
  return output;
}
|
|
2212
|
+
function compileFlashAttentionKernel(device, config, blockSize) {
|
|
2213
|
+
const { headDim } = config;
|
|
2214
|
+
const wgsl = `
|
|
2215
|
+
// WebInfer FlashAttention Kernel
|
|
2216
|
+
// Implements online softmax with tiling for memory efficiency
|
|
2217
|
+
|
|
2218
|
+
struct Params {
|
|
2219
|
+
seqLen: u32,
|
|
2220
|
+
numHeads: u32,
|
|
2221
|
+
headDim: u32,
|
|
2222
|
+
scale: f32,
|
|
2223
|
+
blockSize: u32,
|
|
2224
|
+
numBlockRows: u32,
|
|
2225
|
+
_pad0: u32,
|
|
2226
|
+
_pad1: u32,
|
|
2227
|
+
}
|
|
2228
|
+
|
|
2229
|
+
@group(0) @binding(0) var<storage, read> Q: array<f32>;
|
|
2230
|
+
@group(0) @binding(1) var<storage, read> K: array<f32>;
|
|
2231
|
+
@group(0) @binding(2) var<storage, read> V: array<f32>;
|
|
2232
|
+
@group(0) @binding(3) var<storage, read_write> O: array<f32>;
|
|
2233
|
+
@group(0) @binding(4) var<uniform> params: Params;
|
|
2234
|
+
@group(0) @binding(5) var<storage, read> blockRowPtr: array<u32>;
|
|
2235
|
+
@group(0) @binding(6) var<storage, read> blockColIdx: array<u32>;
|
|
2236
|
+
|
|
2237
|
+
// Shared memory for tiles
|
|
2238
|
+
var<workgroup> tileQ: array<f32, ${blockSize * headDim}>;
|
|
2239
|
+
var<workgroup> tileK: array<f32, ${blockSize * headDim}>;
|
|
2240
|
+
var<workgroup> tileV: array<f32, ${blockSize * headDim}>;
|
|
2241
|
+
var<workgroup> tileS: array<f32, ${blockSize * blockSize}>;
|
|
2242
|
+
|
|
2243
|
+
// Online softmax state per row
|
|
2244
|
+
var<workgroup> rowMax: array<f32, ${blockSize}>;
|
|
2245
|
+
var<workgroup> rowSum: array<f32, ${blockSize}>;
|
|
2246
|
+
var<workgroup> rowOut: array<f32, ${blockSize * headDim}>;
|
|
2247
|
+
|
|
2248
|
+
@compute @workgroup_size(${blockSize})
|
|
2249
|
+
fn main(
|
|
2250
|
+
@builtin(workgroup_id) wgId: vec3<u32>,
|
|
2251
|
+
@builtin(local_invocation_id) localId: vec3<u32>
|
|
2252
|
+
) {
|
|
2253
|
+
let blockRowIdx = wgId.x;
|
|
2254
|
+
let headIdx = wgId.y;
|
|
2255
|
+
let tid = localId.x;
|
|
2256
|
+
|
|
2257
|
+
let blockSize = params.blockSize;
|
|
2258
|
+
let headDim = params.headDim;
|
|
2259
|
+
let seqLen = params.seqLen;
|
|
2260
|
+
let scale = params.scale;
|
|
2261
|
+
|
|
2262
|
+
// Global row index
|
|
2263
|
+
let globalRow = blockRowIdx * blockSize + tid;
|
|
2264
|
+
let validRow = globalRow < seqLen;
|
|
2265
|
+
|
|
2266
|
+
// Initialize online softmax state
|
|
2267
|
+
rowMax[tid] = -3.402823e+38f; // -inf
|
|
2268
|
+
rowSum[tid] = 0.0f;
|
|
2269
|
+
|
|
2270
|
+
// Initialize output accumulator
|
|
2271
|
+
for (var d = 0u; d < headDim; d = d + 1u) {
|
|
2272
|
+
rowOut[tid * headDim + d] = 0.0f;
|
|
2273
|
+
}
|
|
2274
|
+
|
|
2275
|
+
workgroupBarrier();
|
|
2276
|
+
|
|
2277
|
+
// Load Q tile for this block row
|
|
2278
|
+
if (validRow) {
|
|
2279
|
+
for (var d = 0u; d < headDim; d = d + 1u) {
|
|
2280
|
+
let qIdx = globalRow * params.numHeads * headDim + headIdx * headDim + d;
|
|
2281
|
+
tileQ[tid * headDim + d] = Q[qIdx];
|
|
2282
|
+
}
|
|
2283
|
+
}
|
|
2284
|
+
|
|
2285
|
+
workgroupBarrier();
|
|
2286
|
+
|
|
2287
|
+
// Iterate over non-zero blocks in this row (block-sparse)
|
|
2288
|
+
let blockStart = blockRowPtr[blockRowIdx];
|
|
2289
|
+
let blockEnd = blockRowPtr[blockRowIdx + 1u];
|
|
2290
|
+
|
|
2291
|
+
for (var b = blockStart; b < blockEnd; b = b + 1u) {
|
|
2292
|
+
let blockColIdx_b = blockColIdx[b];
|
|
2293
|
+
let globalCol = blockColIdx_b * blockSize + tid;
|
|
2294
|
+
let validCol = globalCol < seqLen;
|
|
2295
|
+
|
|
2296
|
+
// Load K tile
|
|
2297
|
+
if (validCol) {
|
|
2298
|
+
for (var d = 0u; d < headDim; d = d + 1u) {
|
|
2299
|
+
let kIdx = globalCol * params.numHeads * headDim + headIdx * headDim + d;
|
|
2300
|
+
tileK[tid * headDim + d] = K[kIdx];
|
|
2301
|
+
}
|
|
2302
|
+
} else {
|
|
2303
|
+
for (var d = 0u; d < headDim; d = d + 1u) {
|
|
2304
|
+
tileK[tid * headDim + d] = 0.0f;
|
|
2305
|
+
}
|
|
2306
|
+
}
|
|
2307
|
+
|
|
2308
|
+
// Load V tile
|
|
2309
|
+
if (validCol) {
|
|
2310
|
+
for (var d = 0u; d < headDim; d = d + 1u) {
|
|
2311
|
+
let vIdx = globalCol * params.numHeads * headDim + headIdx * headDim + d;
|
|
2312
|
+
tileV[tid * headDim + d] = V[vIdx];
|
|
2313
|
+
}
|
|
2314
|
+
} else {
|
|
2315
|
+
for (var d = 0u; d < headDim; d = d + 1u) {
|
|
2316
|
+
tileV[tid * headDim + d] = 0.0f;
|
|
2317
|
+
}
|
|
2318
|
+
}
|
|
2319
|
+
|
|
2320
|
+
workgroupBarrier();
|
|
2321
|
+
|
|
2322
|
+
// Compute attention scores S = Q @ K^T * scale
|
|
2323
|
+
// Each thread computes one row of scores
|
|
2324
|
+
if (validRow) {
|
|
2325
|
+
for (var j = 0u; j < blockSize; j = j + 1u) {
|
|
2326
|
+
var score = 0.0f;
|
|
2327
|
+
for (var d = 0u; d < headDim; d = d + 1u) {
|
|
2328
|
+
score = score + tileQ[tid * headDim + d] * tileK[j * headDim + d];
|
|
2329
|
+
}
|
|
2330
|
+
score = score * scale;
|
|
2331
|
+
|
|
2332
|
+
// Apply causal mask
|
|
2333
|
+
let colPos = blockColIdx_b * blockSize + j;
|
|
2334
|
+
if (colPos > globalRow) {
|
|
2335
|
+
score = -3.402823e+38f; // -inf for masked positions
|
|
2336
|
+
}
|
|
2337
|
+
|
|
2338
|
+
tileS[tid * blockSize + j] = score;
|
|
2339
|
+
}
|
|
2340
|
+
}
|
|
2341
|
+
|
|
2342
|
+
workgroupBarrier();
|
|
2343
|
+
|
|
2344
|
+
// Online softmax update
|
|
2345
|
+
if (validRow) {
|
|
2346
|
+
// Find max in this tile
|
|
2347
|
+
var tileMax = -3.402823e+38f;
|
|
2348
|
+
for (var j = 0u; j < blockSize; j = j + 1u) {
|
|
2349
|
+
tileMax = max(tileMax, tileS[tid * blockSize + j]);
|
|
2350
|
+
}
|
|
2351
|
+
|
|
2352
|
+
// Update running max
|
|
2353
|
+
let prevMax = rowMax[tid];
|
|
2354
|
+
let newMax = max(prevMax, tileMax);
|
|
2355
|
+
rowMax[tid] = newMax;
|
|
2356
|
+
|
|
2357
|
+
// Rescale previous sum and output
|
|
2358
|
+
let rescale = exp(prevMax - newMax);
|
|
2359
|
+
rowSum[tid] = rowSum[tid] * rescale;
|
|
2360
|
+
for (var d = 0u; d < headDim; d = d + 1u) {
|
|
2361
|
+
rowOut[tid * headDim + d] = rowOut[tid * headDim + d] * rescale;
|
|
2362
|
+
}
|
|
2363
|
+
|
|
2364
|
+
// Compute softmax for this tile and accumulate
|
|
2365
|
+
for (var j = 0u; j < blockSize; j = j + 1u) {
|
|
2366
|
+
let p = exp(tileS[tid * blockSize + j] - newMax);
|
|
2367
|
+
rowSum[tid] = rowSum[tid] + p;
|
|
2368
|
+
|
|
2369
|
+
// Accumulate output: O += p * V
|
|
2370
|
+
for (var d = 0u; d < headDim; d = d + 1u) {
|
|
2371
|
+
rowOut[tid * headDim + d] = rowOut[tid * headDim + d] + p * tileV[j * headDim + d];
|
|
2372
|
+
}
|
|
2373
|
+
}
|
|
2374
|
+
}
|
|
2375
|
+
|
|
2376
|
+
workgroupBarrier();
|
|
2377
|
+
}
|
|
2378
|
+
|
|
2379
|
+
// Final normalization and write output
|
|
2380
|
+
if (validRow) {
|
|
2381
|
+
let sumInv = 1.0f / rowSum[tid];
|
|
2382
|
+
for (var d = 0u; d < headDim; d = d + 1u) {
|
|
2383
|
+
let oIdx = globalRow * params.numHeads * headDim + headIdx * headDim + d;
|
|
2384
|
+
O[oIdx] = rowOut[tid * headDim + d] * sumInv;
|
|
2385
|
+
}
|
|
2386
|
+
}
|
|
2387
|
+
}
|
|
2388
|
+
`;
|
|
2389
|
+
const shaderModule = device.createShaderModule({ code: wgsl });
|
|
2390
|
+
return device.createComputePipeline({
|
|
2391
|
+
layout: "auto",
|
|
2392
|
+
compute: {
|
|
2393
|
+
module: shaderModule,
|
|
2394
|
+
entryPoint: "main"
|
|
2395
|
+
}
|
|
2396
|
+
});
|
|
2397
|
+
}
|
|
2398
|
+
/**
 * Reference (CPU) scaled dot-product attention.
 *
 * q, k, v and the returned buffer are indexed as
 * `token * numHeads * headDim + head * headDim + d`.
 *
 * @param {ArrayLike<number>} q - Query values.
 * @param {ArrayLike<number>} k - Key values.
 * @param {ArrayLike<number>} v - Value values.
 * @param {number} seqLen - Number of tokens.
 * @param {number} numHeads - Number of attention heads.
 * @param {number} headDim - Elements per head.
 * @param {boolean} [causal=true] - When true, a query ignores keys at later positions.
 * @returns {Float32Array} Attention output with seqLen * numHeads * headDim elements.
 */
function attentionCPU(q, k, v, seqLen, numHeads, headDim, causal = true) {
  const tokenStride = numHeads * headDim;
  const invSqrtDim = 1 / Math.sqrt(headDim);
  const output = new Float32Array(seqLen * tokenStride);

  for (let head = 0; head < numHeads; head += 1) {
    const headOffset = head * headDim;

    for (let query = 0; query < seqLen; query += 1) {
      const queryBase = query * tokenStride + headOffset;
      // Kept in float32 so intermediate rounding matches a float32 pipeline.
      const weights = new Float32Array(seqLen);

      // Pass 1: scaled scores (masked entries become -Infinity) and their maximum.
      let peak = -Infinity;
      for (let key = 0; key < seqLen; key += 1) {
        if (causal && key > query) {
          weights[key] = -Infinity;
        } else {
          const keyBase = key * tokenStride + headOffset;
          let dot = 0;
          for (let d = 0; d < headDim; d += 1) {
            dot += q[queryBase + d] * k[keyBase + d];
          }
          weights[key] = dot * invSqrtDim;
        }
        peak = Math.max(peak, weights[key]);
      }

      // Pass 2: exponentiate relative to the maximum for numerical stability.
      let normaliser = 0;
      for (let key = 0; key < seqLen; key += 1) {
        weights[key] = Math.exp(weights[key] - peak);
        normaliser += weights[key];
      }

      // Pass 3: normalise into probabilities.
      for (let key = 0; key < seqLen; key += 1) {
        weights[key] = weights[key] / normaliser;
      }

      // Pass 4: probability-weighted sum of the value vectors.
      for (let d = 0; d < headDim; d += 1) {
        let acc = 0;
        for (let key = 0; key < seqLen; key += 1) {
          acc += weights[key] * v[key * tokenStride + headOffset + d];
        }
        output[queryBase + d] = acc;
      }
    }
  }

  return output;
}
|
|
2440
|
+
// src/attention/block-sparse/patterns/causal.ts
|
|
2441
|
+
/**
 * Builds the block-sparse CSR description for a causal attention pattern.
 *
 * @param {number} seqLen - Sequence length.
 * @param {number} [blockSize=64] - Block edge length forwarded to the CSR builder.
 * @returns {*} Whatever `buildBlockSparseCSR` returns for the `{ type: "causal" }` pattern.
 */
function buildCausalMask(seqLen, blockSize = 64) {
  return buildBlockSparseCSR(seqLen, { type: "causal" }, blockSize);
}
|
|
2445
|
+
/**
 * Fraction of the seqLen x seqLen attention matrix that a causal mask removes.
 *
 * A causal mask keeps the lower triangle including the diagonal, i.e.
 * seqLen * (seqLen + 1) / 2 entries.
 *
 * @param {number} seqLen - Sequence length.
 * @returns {number} Sparsity in [0, 1); 0 for an empty or invalid sequence length.
 */
function getCausalSparsity(seqLen) {
  // An empty matrix has nothing to mask; this also avoids 0 / 0 = NaN.
  if (!(seqLen > 0)) {
    return 0;
  }
  const total = seqLen * seqLen;
  const nonZero = seqLen * (seqLen + 1) / 2;
  return 1 - nonZero / total;
}
|
|
2450
|
+
// src/attention/block-sparse/patterns/sliding.ts
|
|
2451
|
+
/**
 * Builds the block-sparse CSR description for a sliding-window attention pattern.
 *
 * @param {number} seqLen - Sequence length.
 * @param {number} windowSize - Window size stored on the pattern descriptor.
 * @param {number} [blockSize=64] - Block edge length forwarded to the CSR builder.
 * @returns {*} Whatever `buildBlockSparseCSR` returns for the `{ type: "sliding", windowSize }` pattern.
 */
function buildSlidingWindowMask(seqLen, windowSize, blockSize = 64) {
  return buildBlockSparseCSR(seqLen, { type: "sliding", windowSize }, blockSize);
}
|
|
2455
|
+
/**
 * Fraction of the seqLen x seqLen attention matrix removed by a causal
 * sliding-window mask.
 *
 * Counting model: the first `windowSize` rows form a triangle holding
 * 1..windowSize entries, and every later row holds `windowSize + 1` entries.
 * The window is clamped to [0, seqLen]; a window covering the whole sequence
 * therefore yields the plain causal sparsity instead of a negative value.
 *
 * @param {number} seqLen - Sequence length.
 * @param {number} windowSize - Sliding window size.
 * @returns {number} Sparsity in [0, 1); 0 for an empty or invalid sequence length.
 */
function getSlidingWindowSparsity(seqLen, windowSize) {
  // An empty matrix has nothing to mask; this also avoids 0 / 0 = NaN.
  if (!(seqLen > 0)) {
    return 0;
  }
  const effectiveWindow = Math.min(Math.max(windowSize, 0), seqLen);
  const total = seqLen * seqLen;
  const triangularPart = effectiveWindow * (effectiveWindow + 1) / 2;
  const remainingPositions = seqLen - effectiveWindow;
  const windowPart = remainingPositions * (effectiveWindow + 1);
  const nonZero = triangularPart + windowPart;
  return 1 - nonZero / total;
}
|
|
2463
|
+
/**
 * Alias of `buildSlidingWindowMask`: all arguments are forwarded unchanged
 * and its result is returned as-is.
 *
 * @param {number} seqLen - Sequence length.
 * @param {number} windowSize - Sliding window size.
 * @param {number} [blockSize=64] - Block edge length.
 * @returns {*} The result of `buildSlidingWindowMask`.
 */
function buildCausalSlidingWindowMask(seqLen, windowSize, blockSize = 64) {
  const mask = buildSlidingWindowMask(seqLen, windowSize, blockSize);
  return mask;
}
|
|
2466
|
+
// src/attention/scheduler.ts
|
|
2467
|
+
// Per-browser time budget used by AttentionScheduler. Values are compared
// against millisecond estimates, so they are expressed in milliseconds.
// `default` is used when no browser is recognised.
var TDR_LIMITS = {
  chrome: 5e3,
  safari: 3e3,
  firefox: 8e3,
  default: 3e3,
};
|
|
2473
|
+
|
|
2474
|
+
/**
 * Estimates attention runtime and splits long sequences into chunks so that a
 * single dispatch stays inside the browser's time budget (`TDR_LIMITS`, in ms).
 *
 * A workload counts as "single pass" when its estimated time is at most 70% of
 * the budget; the same inclusive threshold is used by every method.
 */
class AttentionScheduler {
  device;
  tdrLimit;

  /**
   * @param {object} device - Device wrapper; `device.info.vendor` selects the throughput estimate.
   */
  constructor(device) {
    this.device = device;
    this.tdrLimit = this.detectTDRLimit();
  }

  /**
   * Picks the time budget from the user-agent string.
   *
   * @returns {number} Budget in milliseconds; `TDR_LIMITS.default` outside a recognised browser.
   */
  detectTDRLimit() {
    if (typeof navigator !== "undefined") {
      // Some runtimes expose `navigator` without a userAgent string.
      const ua = (navigator.userAgent ?? "").toLowerCase();
      // Chrome's user agent also contains "safari", hence the exclusion.
      if (ua.includes("safari") && !ua.includes("chrome")) {
        return TDR_LIMITS.safari;
      } else if (ua.includes("firefox")) {
        return TDR_LIMITS.firefox;
      } else if (ua.includes("chrome") || ua.includes("edge")) {
        return TDR_LIMITS.chrome;
      }
    }
    return TDR_LIMITS.default;
  }

  /**
   * Rough runtime estimate for full attention.
   *
   * @param {number} seqLen - Sequence length.
   * @param {number} numHeads - Number of heads.
   * @param {number} headDim - Elements per head.
   * @returns {number} Estimated time in milliseconds.
   */
  estimateExecutionTime(seqLen, numHeads, headDim) {
    const flops = 4 * seqLen * seqLen * numHeads * headDim;
    // Assumed sustained throughput per vendor, in TFLOP/s.
    let tflopsEstimate;
    switch (this.device.info.vendor) {
      case "apple":
        tflopsEstimate = 10;
        break;
      case "nvidia":
        tflopsEstimate = 20;
        break;
      case "amd":
        tflopsEstimate = 15;
        break;
      case "intel":
        tflopsEstimate = 8;
        break;
      default:
        tflopsEstimate = 5;
    }
    // seconds -> milliseconds, then doubled (presumably a safety margin).
    return flops / (tflopsEstimate * 1000000000000) * 1000 * 2;
  }

  /**
   * Plans how to split a sequence so each chunk targets 50% of the budget.
   *
   * @param {number} seqLen - Sequence length.
   * @param {number} numHeads - Number of heads.
   * @param {number} headDim - Elements per head.
   * @returns {{numChunks: number, chunkSize: number, estimatedTimeMs: number}}
   *   `numChunks * chunkSize` covers `seqLen` with no empty trailing chunk;
   *   `estimatedTimeMs` is the estimate per chunk.
   */
  computeChunkPlan(seqLen, numHeads, headDim) {
    const estimatedTime = this.estimateExecutionTime(seqLen, numHeads, headDim);
    // Inclusive, matching mightCauseTDR() and getMaxSinglePassSeqLen().
    if (estimatedTime <= this.tdrLimit * 0.7) {
      return {
        numChunks: 1,
        chunkSize: seqLen,
        estimatedTimeMs: estimatedTime,
      };
    }
    const targetTimePerChunk = this.tdrLimit * 0.5;
    const requestedChunks = Math.ceil(estimatedTime / targetTimePerChunk);
    const chunkSize = Math.max(1, Math.ceil(seqLen / requestedChunks));
    // Rounding chunkSize up can make fewer chunks sufficient than requested;
    // report the count that actually tiles the sequence.
    const numChunks = Math.ceil(seqLen / chunkSize);
    return {
      numChunks,
      chunkSize,
      estimatedTimeMs: estimatedTime / numChunks,
    };
  }

  /**
   * Yields to the event loop via a zero-delay timer.
   *
   * @returns {Promise<void>}
   */
  async yieldToMain() {
    return new Promise((resolve) => setTimeout(resolve, 0));
  }

  /**
   * @param {number} seqLen - Sequence length.
   * @param {number} numHeads - Number of heads.
   * @param {number} headDim - Elements per head.
   * @returns {boolean} True when the estimate exceeds 70% of the budget.
   */
  mightCauseTDR(seqLen, numHeads, headDim) {
    const estimatedTime = this.estimateExecutionTime(seqLen, numHeads, headDim);
    return estimatedTime > this.tdrLimit * 0.7;
  }

  /**
   * Binary-searches [1, 65536] for the longest sequence that fits a single pass.
   *
   * @param {number} numHeads - Number of heads.
   * @param {number} headDim - Elements per head.
   * @returns {number} Largest fitting sequence length; 1 when nothing in range fits.
   */
  getMaxSinglePassSeqLen(numHeads, headDim) {
    let low = 1;
    let high = 65536;
    while (low < high) {
      // Upper midpoint so that `low = mid` always makes progress.
      const mid = Math.floor((low + high + 1) / 2);
      const time = this.estimateExecutionTime(mid, numHeads, headDim);
      if (time <= this.tdrLimit * 0.7) {
        low = mid;
      } else {
        high = mid - 1;
      }
    }
    return low;
  }
}
|
|
2555
|
+
// src/attention/paged-kv/page-table.ts
|
|
2556
|
+
/**
 * Page-table bookkeeping for a paged key/value cache.
 *
 * Two GPU buffers (keys and values) are created up front, each sized for
 * `maxPages` pages per layer. Sequences own lists of page indices; a page
 * holds `pageSize` token positions.
 */
class PagedKVCache {
  device;
  config;
  keyCache;
  valueCache;
  // seqId -> { seqId, pages: number[], length: number }
  pageTable = new Map();
  // Stack of unused page indices.
  freePages = [];
  nextSeqId = 0;

  /**
   * @param {object} device - Wrapper whose `device.createBuffer` creates GPU buffers.
   * @param {object} config - Reads `pageSize`, `numHeads`, `headDim`, `numLayers`,
   *   `maxPages` and `dtype` ("f16" uses 2 bytes per element, anything else 4).
   */
  constructor(device, config) {
    this.device = device;
    this.config = config;
    const bytesPerElement = config.dtype === "f16" ? 2 : 4;
    const pageBytes = config.pageSize * config.numHeads * config.headDim * bytesPerElement;
    const totalBytes = config.maxPages * pageBytes * config.numLayers;
    this.keyCache = device.device.createBuffer({
      size: totalBytes,
      usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
    });
    this.valueCache = device.device.createBuffer({
      size: totalBytes,
      usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
    });
    for (let i = 0; i < config.maxPages; i++) {
      this.freePages.push(i);
    }
  }

  /**
   * Registers a new sequence and reserves enough pages for `initialLength` tokens.
   *
   * @param {number} [initialLength=0] - Tokens already in the sequence.
   * @returns {number} The new sequence id.
   * @throws {Error} "Out of KV cache memory" when pages run out; pages taken so far are returned.
   */
  allocateSequence(initialLength = 0) {
    const seqId = this.nextSeqId++;
    const numPagesNeeded = Math.ceil(initialLength / this.config.pageSize);
    const pages = [];
    for (let i = 0; i < numPagesNeeded; i++) {
      const page = this.allocatePage();
      if (page === null) {
        for (const p of pages) {
          this.freePage(p);
        }
        throw new Error("Out of KV cache memory");
      }
      pages.push(page);
    }
    this.pageTable.set(seqId, {
      seqId,
      pages,
      length: initialLength,
    });
    return seqId;
  }

  /**
   * Grows a sequence by `numNewTokens`, reserving additional pages as needed.
   * The operation is all-or-nothing: on failure the sequence keeps its previous
   * pages and length, and any pages taken during this call are returned.
   *
   * @param {number} seqId - Sequence to grow.
   * @param {number} numNewTokens - Tokens to add.
   * @throws {Error} When the sequence is unknown or the cache is out of pages.
   */
  extendSequence(seqId, numNewTokens) {
    const entry = this.pageTable.get(seqId);
    if (!entry) {
      throw new Error(`Sequence ${seqId} not found`);
    }
    const newLength = entry.length + numNewTokens;
    const neededPages = Math.ceil(newLength / this.config.pageSize);
    const acquired = [];
    while (entry.pages.length + acquired.length < neededPages) {
      const page = this.allocatePage();
      if (page === null) {
        // Return pages in reverse so the free stack is restored exactly.
        while (acquired.length > 0) {
          this.freePage(acquired.pop());
        }
        throw new Error("Out of KV cache memory");
      }
      acquired.push(page);
    }
    entry.pages.push(...acquired);
    entry.length = newLength;
  }

  /**
   * Releases every page of a sequence and forgets it. Unknown ids are ignored.
   *
   * @param {number} seqId - Sequence to release.
   */
  freeSequence(seqId) {
    const entry = this.pageTable.get(seqId);
    if (!entry) {
      return;
    }
    for (const page of entry.pages) {
      this.freePage(page);
    }
    this.pageTable.delete(seqId);
  }

  /**
   * @param {number} seqId - Sequence id.
   * @returns {number[]|null} The sequence's live page list, or null when unknown.
   */
  getSequencePages(seqId) {
    return this.pageTable.get(seqId)?.pages ?? null;
  }

  /**
   * @param {number} seqId - Sequence id.
   * @returns {number} Token count, or 0 when the sequence is unknown.
   */
  getSequenceLength(seqId) {
    return this.pageTable.get(seqId)?.length ?? 0;
  }

  /**
   * @param {number} seqId - Sequence id.
   * @param {number} position - Token position within the sequence.
   * @returns {number|null} Page index holding that position, or null when unavailable.
   */
  getPageForPosition(seqId, position) {
    const entry = this.pageTable.get(seqId);
    if (!entry) {
      return null;
    }
    const pageIdx = Math.floor(position / this.config.pageSize);
    return entry.pages[pageIdx] ?? null;
  }

  /**
   * @param {number} position - Token position within a sequence.
   * @returns {number} Slot of that position inside its page.
   */
  getOffsetInPage(position) {
    return position % this.config.pageSize;
  }

  /**
   * Takes one page off the free stack.
   *
   * @returns {number|null} Page index, or null when no page is free.
   */
  allocatePage() {
    if (this.freePages.length === 0) {
      return null;
    }
    return this.freePages.pop();
  }

  /**
   * Puts a page back on the free stack. No double-free check is performed.
   *
   * @param {number} page - Page index to release.
   */
  freePage(page) {
    this.freePages.push(page);
  }

  /**
   * @returns {{totalPages: number, usedPages: number, freePages: number,
   *   numSequences: number, memoryUsedBytes: number, memoryTotalBytes: number}}
   *   Byte figures cover both the key and the value buffer (hence the factor 2).
   */
  getStats() {
    const bytesPerElement = this.config.dtype === "f16" ? 2 : 4;
    const pageBytes = this.config.pageSize * this.config.numHeads * this.config.headDim * bytesPerElement;
    const usedPages = this.config.maxPages - this.freePages.length;
    return {
      totalPages: this.config.maxPages,
      usedPages,
      freePages: this.freePages.length,
      numSequences: this.pageTable.size,
      memoryUsedBytes: usedPages * pageBytes * this.config.numLayers * 2,
      memoryTotalBytes: this.config.maxPages * pageBytes * this.config.numLayers * 2,
    };
  }

  /**
   * @returns {{keyCache: object, valueCache: object}} The two backing GPU buffers.
   */
  getBuffers() {
    return {
      keyCache: this.keyCache,
      valueCache: this.valueCache,
    };
  }

  /**
   * @returns {object} Shallow copy of the configuration.
   */
  getConfig() {
    return { ...this.config };
  }

  /**
   * Destroys both GPU buffers and clears all bookkeeping.
   */
  dispose() {
    this.keyCache.destroy();
    this.valueCache.destroy();
    this.pageTable.clear();
    this.freePages = [];
  }
}
|
|
2682
|
+
// src/attention/paged-kv/block-manager.ts
|
|
2683
|
+
/**
 * Admission control and priority-based eviction on top of a `PagedKVCache`.
 */
class BlockManager {
  cache;
  config;
  // seqId -> numeric priority; only sequences allocated with a priority appear here.
  priorities = new Map();

  /**
   * @param {object} device - Forwarded to `PagedKVCache`.
   * @param {object} config - Cache configuration plus optional `policy`
   *   (default "greedy") and `reservedPages` (default 0).
   */
  constructor(device, config) {
    this.config = {
      policy: "greedy",
      reservedPages: 0,
      ...config,
    };
    this.cache = new PagedKVCache(device, config);
  }

  /**
   * Checks whether a request fits into the free pages left after the reserve.
   *
   * @param {{numTokens: number, seqId?: number}} request - With `seqId`, the
   *   request extends that sequence; otherwise it creates a new one.
   * @returns {boolean} True when enough pages are available.
   */
  canAllocate(request) {
    const { pageSize, reservedPages } = this.config;
    const availablePages = this.cache.getStats().freePages - (reservedPages ?? 0);
    if (request.seqId !== undefined) {
      // Only the pages beyond what the sequence already covers are needed.
      const currentLength = this.cache.getSequenceLength(request.seqId);
      const currentPages = Math.ceil(currentLength / pageSize);
      const newPages = Math.ceil((currentLength + request.numTokens) / pageSize);
      return newPages - currentPages <= availablePages;
    }
    return Math.ceil(request.numTokens / pageSize) <= availablePages;
  }

  /**
   * Extends an existing sequence or creates a new one, recording the priority if given.
   *
   * @param {{numTokens: number, seqId?: number, priority?: number}} request
   * @returns {number} Id of the extended or newly created sequence.
   */
  allocate(request) {
    if (request.seqId !== undefined) {
      this.cache.extendSequence(request.seqId, request.numTokens);
      if (request.priority !== undefined) {
        this.priorities.set(request.seqId, request.priority);
      }
      return request.seqId;
    }
    const seqId = this.cache.allocateSequence(request.numTokens);
    if (request.priority !== undefined) {
      this.priorities.set(seqId, request.priority);
    }
    return seqId;
  }

  /**
   * Releases a sequence and forgets its priority.
   *
   * @param {number} seqId - Sequence to release.
   */
  free(seqId) {
    this.cache.freeSequence(seqId);
    this.priorities.delete(seqId);
  }

  /**
   * Frees sequences, lowest priority value first, until at least
   * `neededPages` pages are free. Only sequences that were given a priority
   * are candidates.
   *
   * @param {number} neededPages - Number of free pages required.
   * @returns {number[]} Ids of the evicted sequences, in eviction order.
   */
  evict(neededPages) {
    const evicted = [];
    if (this.cache.getStats().freePages >= neededPages) {
      return evicted;
    }
    const candidates = Array.from(this.priorities.entries())
      .sort((a, b) => a[1] - b[1])
      .map(([seqId]) => seqId);
    for (const seqId of candidates) {
      // Re-read the stats on every iteration: each eviction changes the free
      // page count, and a stale snapshot would evict every candidate.
      if (this.cache.getStats().freePages >= neededPages) {
        break;
      }
      if (this.cache.getSequencePages(seqId)) {
        this.free(seqId);
        evicted.push(seqId);
      }
    }
    return evicted;
  }

  /**
   * @returns {number} Used pages divided by total pages.
   */
  getUtilization() {
    const stats = this.cache.getStats();
    return stats.usedPages / stats.totalPages;
  }

  /**
   * @returns {object} The underlying cache instance.
   */
  getCache() {
    return this.cache;
  }

  /**
   * @returns {object} Statistics reported by the underlying cache.
   */
  getStats() {
    return this.cache.getStats();
  }

  /**
   * Disposes the cache and clears the priority table.
   */
  dispose() {
    this.cache.dispose();
    this.priorities.clear();
  }
}
|
|
2759
|
+
|
|
2760
|
+
/**
 * Keeps a set of running sequences and a queue of requests waiting for
 * cache space, delegating all capacity decisions to a block manager.
 */
class ContinuousBatchScheduler {
  blockManager;
  runningSequences = new Set();
  waitingQueue = [];

  /**
   * @param {object} blockManager - Must provide `canAllocate`, `allocate` and `free`.
   */
  constructor(blockManager) {
    this.blockManager = blockManager;
  }

  /**
   * Starts the request immediately when it fits, otherwise queues it.
   *
   * @param {object} request - Allocation request understood by the block manager.
   */
  addRequest(request) {
    if (!this.blockManager.canAllocate(request)) {
      this.waitingQueue.push(request);
      return;
    }
    this.runningSequences.add(this.blockManager.allocate(request));
  }

  /**
   * Marks a sequence as finished, frees its blocks and retries queued requests.
   *
   * @param {number} seqId - Sequence that completed.
   */
  completeSequence(seqId) {
    this.runningSequences.delete(seqId);
    this.blockManager.free(seqId);
    this.scheduleWaiting();
  }

  /**
   * Tries to grow a running sequence.
   *
   * @param {number} seqId - Sequence to grow.
   * @param {number} numNewTokens - Tokens to add.
   * @returns {boolean} False when the sequence is not running or does not fit.
   */
  extendSequence(seqId, numNewTokens) {
    if (!this.runningSequences.has(seqId)) {
      return false;
    }
    const growth = { seqId, numTokens: numNewTokens };
    if (!this.blockManager.canAllocate(growth)) {
      return false;
    }
    this.blockManager.allocate(growth);
    return true;
  }

  /**
   * Walks the waiting queue in order, starting every request that now fits
   * and keeping the rest queued in their original order.
   */
  scheduleWaiting() {
    this.waitingQueue = this.waitingQueue.filter((pending) => {
      if (!this.blockManager.canAllocate(pending)) {
        return true;
      }
      this.runningSequences.add(this.blockManager.allocate(pending));
      return false;
    });
  }

  /**
   * @returns {number} Number of running sequences.
   */
  getRunningCount() {
    return this.runningSequences.size;
  }

  /**
   * @returns {number} Number of queued requests.
   */
  getWaitingCount() {
    return this.waitingQueue.length;
  }
}
|
|
2813
|
+
// src/sampling/top-k.ts
|
|
2814
|
+
/**
 * Selects the k largest logits per row, in descending order.
 *
 * The tensor is read back with `toArray()` and the selection runs on the CPU.
 *
 * @param {object} device - Device passed through to the `Tensor` constructors.
 * @param {object} logits - 1D `[vocab]` or 2D `[batch, vocab]` tensor.
 * @param {number} k - Number of entries to keep per row.
 * @returns {Promise<{values: object, indices: object}>} Tensors shaped `[k]` or
 *   `[batch, k]`; `indices` is created with dtype "u32".
 * @throws {Error} When the rank is not 1 or 2, or when k exceeds the vocab size.
 */
async function topK(device, logits, k) {
  const rank = logits.shape.length;
  if (rank !== 1 && rank !== 2) {
    throw new Error("topK expects 1D or 2D tensor");
  }
  const is2D = rank === 2;
  const batchSize = is2D ? logits.shape[0] : 1;
  const vocabSize = is2D ? logits.shape[1] : logits.shape[0];
  if (k > vocabSize) {
    throw new Error(`k (${k}) cannot be greater than vocab size (${vocabSize})`);
  }

  const logitsData = await logits.toArray();
  const valuesData = new Float32Array(batchSize * k);
  const indicesData = new Uint32Array(batchSize * k);

  for (let row = 0; row < batchSize; row += 1) {
    const rowStart = row * vocabSize;
    // Vocabulary positions ordered by descending logit.
    const ranked = Array.from({ length: vocabSize }, (_, i) => i);
    ranked.sort((a, b) => logitsData[rowStart + b] - logitsData[rowStart + a]);

    const outStart = row * k;
    for (let slot = 0; slot < k; slot += 1) {
      const vocabIdx = ranked[slot];
      valuesData[outStart + slot] = logitsData[rowStart + vocabIdx];
      indicesData[outStart + slot] = vocabIdx;
    }
  }

  // Each tensor gets its own shape array.
  const makeShape = () => (is2D ? [batchSize, k] : [k]);
  const valuesShape = makeShape();
  const indicesShape = makeShape();
  const values = await Tensor.fromArray(device, valuesShape, valuesData);
  const indices = new Tensor(device, indicesShape, "u32", indicesData);
  return { values, indices };
}
|
|
2843
|
+
/**
 * CPU top-k over a single row of logits.
 *
 * @param {ArrayLike<number>} logits - Logit values.
 * @param {number} k - Number of entries to return.
 * @param {number} vocabSize - Number of leading entries of `logits` to consider.
 * @returns {{values: Float32Array, indices: Uint32Array}} The k largest logits
 *   in descending order and their positions.
 */
function topKCPU(logits, k, vocabSize) {
  // Positions 0..vocabSize-1 ordered by descending logit.
  const ranked = Array.from({ length: vocabSize }, (_, i) => i);
  ranked.sort((a, b) => logits[b] - logits[a]);

  const values = new Float32Array(k);
  const topIndices = new Uint32Array(k);
  for (let slot = 0; slot < k; slot += 1) {
    const vocabIdx = ranked[slot];
    topIndices[slot] = vocabIdx;
    values[slot] = logits[vocabIdx];
  }
  return { values, indices: topIndices };
}
|
|
2855
|
+
/**
 * Masks every logit that is smaller than the k-th largest logit of its row
 * with -Infinity. Entries tied with the k-th value are kept.
 *
 * Rows are taken along the last dimension. A `k` below 1 or at least the
 * vocab size leaves the logits unchanged (consistent with `topK = 0` meaning
 * "disabled" in the sampler); previously such values masked every logit.
 *
 * @param {object} device - Device passed to `Tensor.fromArray`.
 * @param {object} logits - Tensor exposing `shape`, `numel` and `toArray()`.
 * @param {number} k - Number of top entries to keep per row.
 * @returns {Promise<object>} New tensor with the same shape as `logits`.
 */
async function topKFilter(device, logits, k) {
  const logitsData = await logits.toArray();
  const vocabSize = logits.shape[logits.shape.length - 1];
  const batchSize = logits.numel / vocabSize;
  const filtered = new Float32Array(logits.numel);

  const keep = Math.floor(k);
  if (!(keep >= 1) || keep >= vocabSize) {
    // Nothing to filter: copy the logits through untouched.
    filtered.set(logitsData);
    return Tensor.fromArray(device, [...logits.shape], filtered);
  }

  for (let b = 0; b < batchSize; b++) {
    const offset = b * vocabSize;
    const row = logitsData.slice(offset, offset + vocabSize);
    const sorted = Float32Array.from(row).sort((x, y) => y - x);
    const threshold = sorted[keep - 1];
    for (let i = 0; i < vocabSize; i++) {
      const value = logitsData[offset + i];
      filtered[offset + i] = value >= threshold ? value : -Infinity;
    }
  }
  return Tensor.fromArray(device, [...logits.shape], filtered);
}
|
|
2875
|
+
// src/sampling/top-p.ts
|
|
2876
|
+
/**
 * Nucleus (top-p) filtering: per row, keeps the most probable tokens until
 * their cumulative softmax probability reaches `p` and masks the rest with
 * -Infinity. Kept entries retain their original (unscaled) logit.
 *
 * Rows are taken along the last dimension. `temperature` only affects the
 * probabilities used for the selection.
 *
 * @param {object} device - Device passed to `Tensor.fromArray`.
 * @param {object} logits - Tensor exposing `shape`, `numel` and `toArray()`.
 * @param {number} p - Cumulative probability threshold in (0, 1].
 * @param {number} [temperature=1] - Divisor applied to logits before softmax.
 * @returns {Promise<object>} New tensor with the same shape as `logits`.
 * @throws {Error} When `p` is outside (0, 1].
 */
async function topPFilter(device, logits, p, temperature = 1) {
  if (p <= 0 || p > 1) {
    throw new Error("p must be in (0, 1]");
  }
  const logitsData = await logits.toArray();
  const vocabSize = logits.shape[logits.shape.length - 1];
  const batchSize = logits.numel / vocabSize;
  const filtered = new Float32Array(logits.numel);

  for (let b = 0; b < batchSize; b++) {
    const offset = b * vocabSize;

    // Track the maximum in the scaling loop instead of Math.max(...array):
    // spreading a vocabulary-sized array can exceed the engine's argument limit.
    const scaledLogits = new Float32Array(vocabSize);
    let maxLogit = -Infinity;
    for (let i = 0; i < vocabSize; i++) {
      scaledLogits[i] = logitsData[offset + i] / temperature;
      maxLogit = Math.max(maxLogit, scaledLogits[i]);
    }

    const expLogits = scaledLogits.map((l) => Math.exp(l - maxLogit));
    const sumExp = expLogits.reduce((a, b2) => a + b2, 0);
    const probs = expLogits.map((e) => e / sumExp);

    // Token positions ordered by descending probability.
    const indices = Array.from({ length: vocabSize }, (_, i) => i);
    indices.sort((a, b2) => probs[b2] - probs[a]);

    let cumProb = 0;
    const keepIndices = new Set();
    for (const idx of indices) {
      cumProb += probs[idx];
      keepIndices.add(idx);
      if (cumProb >= p) {
        break;
      }
    }

    for (let i = 0; i < vocabSize; i++) {
      filtered[offset + i] = keepIndices.has(i) ? logitsData[offset + i] : -Infinity;
    }
  }
  return Tensor.fromArray(device, [...logits.shape], filtered);
}
|
|
2914
|
+
/**
 * CPU nucleus (top-p) filtering over a single row of logits: keeps the most
 * probable tokens until their cumulative softmax probability reaches `p` and
 * masks the rest with -Infinity. Kept entries retain their original logit.
 *
 * @param {ArrayLike<number>} logits - Logit values.
 * @param {number} p - Cumulative probability threshold.
 * @param {number} [temperature=1] - Divisor applied to logits before softmax.
 * @returns {Float32Array} Filtered logits, same length as the input.
 */
function topPFilterCPU(logits, p, temperature = 1) {
  const vocabSize = logits.length;

  // Track the maximum in the scaling loop instead of Math.max(...array):
  // spreading a vocabulary-sized array can exceed the engine's argument limit.
  const scaledLogits = new Float32Array(vocabSize);
  let maxLogit = -Infinity;
  for (let i = 0; i < vocabSize; i++) {
    scaledLogits[i] = logits[i] / temperature;
    maxLogit = Math.max(maxLogit, scaledLogits[i]);
  }

  const expLogits = scaledLogits.map((l) => Math.exp(l - maxLogit));
  const sumExp = expLogits.reduce((a, b) => a + b, 0);
  const probs = expLogits.map((e) => e / sumExp);

  // Token positions ordered by descending probability.
  const indices = Array.from({ length: vocabSize }, (_, i) => i);
  indices.sort((a, b) => probs[b] - probs[a]);

  let cumProb = 0;
  const keepIndices = new Set();
  for (const idx of indices) {
    cumProb += probs[idx];
    keepIndices.add(idx);
    if (cumProb >= p) {
      break;
    }
  }

  const filtered = new Float32Array(vocabSize);
  for (let i = 0; i < vocabSize; i++) {
    filtered[i] = keepIndices.has(i) ? logits[i] : -Infinity;
  }
  return filtered;
}
|
|
2940
|
+
// src/sampling/sampler.ts
|
|
2941
|
+
function softmax(logits) {
|
|
2942
|
+
const maxLogit = Math.max(...logits);
|
|
2943
|
+
const expLogits = logits.map((l) => Math.exp(l - maxLogit));
|
|
2944
|
+
const sumExp = expLogits.reduce((a, b) => a + b, 0);
|
|
2945
|
+
return expLogits.map((e) => e / sumExp);
|
|
2946
|
+
}
|
|
2947
|
+
/**
 * Applies a repetition penalty to the logits of previously generated tokens:
 * positive logits are divided by `penalty`, non-positive ones are multiplied
 * by it.
 *
 * Each distinct token is penalised once, no matter how often it occurs in
 * `previousTokens`; iterating the raw history would compound the penalty
 * (penalty^n for a token seen n times).
 *
 * @param {ArrayLike<number>} logits - Logit values; never modified.
 * @param {Iterable<number>} previousTokens - Token ids generated so far;
 *   ids outside [0, logits.length) are ignored.
 * @param {number} penalty - Penalty factor; 1 disables the penalty.
 * @returns {ArrayLike<number>|Float32Array} The input itself (same reference)
 *   when `penalty` is 1, otherwise a new Float32Array.
 */
function applyRepetitionPenalty(logits, previousTokens, penalty) {
  if (penalty === 1) {
    return logits;
  }
  const result = new Float32Array(logits);
  for (const token of new Set(previousTokens)) {
    if (Number.isInteger(token) && token >= 0 && token < logits.length) {
      result[token] = result[token] > 0 ? result[token] / penalty : result[token] * penalty;
    }
  }
  return result;
}
|
|
2962
|
+
/**
 * Draws an index from a discrete probability distribution by inverse-CDF
 * sampling.
 *
 * @param {number[]|Float32Array} probs - Probabilities per index.
 * @param {() => number} [random=Math.random] - Source of a uniform number in [0, 1).
 * @returns {number} The first index whose cumulative probability exceeds the
 *   draw, or the last index when none does (e.g. due to rounding).
 */
function sampleFromProbs(probs, random = Math.random) {
  const draw = random();
  let cumulative = 0;
  const hit = probs.findIndex((p) => {
    cumulative += p;
    return draw < cumulative;
  });
  return hit === -1 ? probs.length - 1 : hit;
}
|
|
2973
|
+
/**
 * Greedy decoding: index of the largest logit.
 *
 * @param {ArrayLike<number>} logits - Logit values.
 * @returns {number} Index of the maximum; the first one on ties, 0 for empty input.
 */
function sampleGreedy(logits) {
  let best = 0;
  for (let candidate = 1; candidate < logits.length; candidate += 1) {
    // Strict comparison keeps the earliest index when values are equal.
    if (logits[candidate] > logits[best]) {
      best = candidate;
    }
  }
  return best;
}
|
|
2984
|
+
/**
 * Samples one token id from a logits tensor.
 *
 * Pipeline: repetition penalty -> greedy shortcut for (near-)zero temperature
 * -> temperature scaling -> optional top-k -> optional top-p -> softmax ->
 * random draw.
 *
 * @param {object} device - Device used for the temporary filter tensors.
 * @param {object} logits - Tensor exposing `toArray()`.
 * @param {object} [config={}] - `temperature` (default 1), `topK` (default 0 =
 *   disabled), `topP` (default 1 = disabled), `repetitionPenalty` (default 1 = disabled).
 * @param {number[]} [previousTokens=[]] - Token ids generated so far.
 * @returns {Promise<number>} The sampled token id.
 */
async function sample(device, logits, config = {}, previousTokens = []) {
  const {
    temperature = 1,
    topK: topK2 = 0,
    topP = 1,
    repetitionPenalty = 1,
  } = config;

  let logitsData = await logits.toArray();
  if (repetitionPenalty !== 1 && previousTokens.length > 0) {
    logitsData = applyRepetitionPenalty(logitsData, previousTokens, repetitionPenalty);
  }

  // Temperatures at or near zero degenerate to argmax.
  if (temperature < 0.000001) {
    return sampleGreedy(logitsData);
  }

  const scaledLogits = new Float32Array(logitsData.length);
  for (let i = 0; i < logitsData.length; i++) {
    scaledLogits[i] = logitsData[i] / temperature;
  }

  // Runs one tensor-based filter over `values` and reads the result back.
  // The temporary tensors are disposed even when the filter throws, so a
  // failing filter cannot leak device memory.
  const runFilter = async (values, applyFilter) => {
    const inputTensor = await Tensor.fromArray(device, [values.length], values);
    let outputTensor;
    try {
      outputTensor = await applyFilter(inputTensor);
      return new Float32Array(await outputTensor.toArray());
    } finally {
      inputTensor.dispose();
      outputTensor?.dispose();
    }
  };

  let filteredLogits = scaledLogits;
  if (topK2 > 0 && topK2 < logitsData.length) {
    filteredLogits = await runFilter(scaledLogits, (tensor) => topKFilter(device, tensor, topK2));
  }
  if (topP < 1) {
    // Temperature was already applied above, so the filter gets 1.
    filteredLogits = await runFilter(filteredLogits, (tensor) => topPFilter(device, tensor, topP, 1));
  }

  const probs = softmax(filteredLogits);
  return sampleFromProbs(probs);
}
|
|
3020
|
+
/**
 * CPU-only token sampling.
 *
 * Pipeline: repetition penalty -> greedy shortcut for (near-)zero temperature
 * -> temperature scaling -> optional top-k -> optional top-p -> softmax ->
 * random draw. The caller's `logits` are copied and never modified.
 *
 * @param {ArrayLike<number>} logits - Logit values.
 * @param {object} [config={}] - `temperature` (default 1), `topK` (default 0 =
 *   disabled), `topP` (default 1 = disabled), `repetitionPenalty` (default 1 = disabled).
 * @param {number[]} [previousTokens=[]] - Token ids generated so far.
 * @returns {number} The sampled token id.
 */
function sampleCPU(logits, config = {}, previousTokens = []) {
  const { temperature = 1, topK: topK2 = 0, topP = 1, repetitionPenalty = 1 } = config;

  let work = new Float32Array(logits);
  const penalise = repetitionPenalty !== 1 && previousTokens.length > 0;
  if (penalise) {
    work = new Float32Array(applyRepetitionPenalty(work, previousTokens, repetitionPenalty));
  }

  // Temperatures at or near zero degenerate to argmax.
  if (temperature < 0.000001) {
    return sampleGreedy(work);
  }

  work.forEach((value, i) => {
    work[i] = value / temperature;
  });

  // Top-k: mask everything below the k-th largest value (ties are kept).
  if (topK2 > 0 && topK2 < work.length) {
    const descending = new Float32Array(work).sort((a, b) => b - a);
    const cutoff = descending[topK2 - 1];
    work.forEach((value, i) => {
      if (value < cutoff) {
        work[i] = -Infinity;
      }
    });
  }

  // Top-p: keep the most probable tokens until their mass reaches topP.
  if (topP < 1) {
    const nucleusProbs = softmax(work);
    const order = Array.from({ length: work.length }, (_, i) => i);
    order.sort((a, b) => nucleusProbs[b] - nucleusProbs[a]);

    const kept = new Set();
    let mass = 0;
    for (const idx of order) {
      mass += nucleusProbs[idx];
      kept.add(idx);
      if (mass >= topP) {
        break;
      }
    }
    work.forEach((_, i) => {
      if (!kept.has(i)) {
        work[i] = -Infinity;
      }
    });
  }

  return sampleFromProbs(softmax(work));
}
|
|
3067
|
+
// src/model/types.ts
|
|
3068
|
+
// Bidirectional enum of GGUF tensor type ids: GGUFQuantType.NAME gives the
// numeric id and GGUFQuantType[id] gives the name. Ids 4 and 5 are not defined.
var GGUFQuantType;
((GGUFQuantType2) => {
  const members = [
    ["F32", 0],
    ["F16", 1],
    ["Q4_0", 2],
    ["Q4_1", 3],
    ["Q5_0", 6],
    ["Q5_1", 7],
    ["Q8_0", 8],
    ["Q8_1", 9],
    ["Q2_K", 10],
    ["Q3_K", 11],
    ["Q4_K", 12],
    ["Q5_K", 13],
    ["Q6_K", 14],
    ["Q8_K", 15],
    ["IQ2_XXS", 16],
    ["IQ2_XS", 17],
    ["IQ3_XXS", 18],
    ["IQ1_S", 19],
    ["IQ4_NL", 20],
    ["IQ3_S", 21],
    ["IQ2_S", 22],
    ["IQ4_XS", 23],
    ["I8", 24],
    ["I16", 25],
    ["I32", 26],
    ["I64", 27],
    ["F64", 28],
    ["BF16", 29],
  ];
  for (const [name, id] of members) {
    // Forward (name -> id) and reverse (id -> name) mapping in one step.
    GGUFQuantType2[GGUFQuantType2[name] = id] = name;
  }
})(GGUFQuantType ||= {});
|
|
3099
|
+
/**
 * Type tags that precede each GGUF metadata value, with the reverse
 * (id -> name) mapping. IDs are the positions in the list below (0..12).
 */
var GGUFMetadataValueType;
((valueTypes) => {
  const namesById = [
    "UINT8",
    "INT8",
    "UINT16",
    "INT16",
    "UINT32",
    "INT32",
    "FLOAT32",
    "BOOL",
    "STRING",
    "ARRAY",
    "UINT64",
    "INT64",
    "FLOAT64",
  ];
  namesById.forEach((label, id) => {
    valueTypes[label] = id;
    valueTypes[id] = label;
  });
})(GGUFMetadataValueType ||= {});
|
|
3115
|
+
/**
 * Elements per quantization block, keyed by GGUF tensor type ID.
 * Q4_0, Q4_1, Q5_0, Q5_1, Q8_0 and Q8_1 (2, 3, 6, 7, 8, 9) use 32-element
 * blocks; Q2_K through Q6_K (10..14) use 256-element blocks.
 */
var GGUF_QUANT_BLOCK_SIZE = Object.fromEntries([
  ...[2, 3, 6, 7, 8, 9].map((typeId) => [typeId, 32]),
  ...[10, 11, 12, 13, 14].map((typeId) => [typeId, 256]),
]);
|
|
3128
|
+
/**
 * Stored size in bytes of one block (or one element for F32/F16), keyed by
 * GGUF tensor type ID.
 */
var GGUF_QUANT_BYTES_PER_BLOCK = {
  0: 4, // F32
  1: 2, // F16
  2: 18, // Q4_0
  3: 20, // Q4_1
  6: 22, // Q5_0
  7: 24, // Q5_1
  8: 34, // Q8_0
  9: 36, // Q8_1
  10: 84, // Q2_K
  11: 110, // Q3_K
  12: 144, // Q4_K
  13: 176, // Q5_K
  14: 210, // Q6_K
};
|
|
3143
|
+
// src/model/safetensors.ts
|
|
3144
|
+
/**
 * Parses the length-prefixed JSON header at the start of a SafeTensors file.
 *
 * Bytes 0-7 are read as a little-endian 64-bit header length (only the low
 * 32 bits may be non-zero); that many bytes of UTF-8 JSON follow.
 *
 * @param {ArrayBuffer} buffer - File contents, starting at byte 0.
 * @returns {{header: {tensors: Object, __metadata__: *}, dataOffset: number}}
 *   The JSON entries split into `__metadata__` and everything else
 *   (`tensors`), and the byte offset where tensor data begins.
 * @throws {Error} If the buffer is too short, the declared header does not
 *   fit in the buffer, or the header is not a JSON object.
 */
function parseSafetensorsHeader(buffer) {
  const LENGTH_PREFIX_BYTES = 8;
  // Fail with a format-specific message instead of a DataView RangeError.
  if (buffer.byteLength < LENGTH_PREFIX_BYTES) {
    throw new Error(`Invalid SafeTensors file: expected at least ${LENGTH_PREFIX_BYTES} bytes, got ${buffer.byteLength}`);
  }
  const view = new DataView(buffer);
  const headerSizeLow = view.getUint32(0, true);
  const headerSizeHigh = view.getUint32(4, true);
  if (headerSizeHigh > 0) {
    throw new Error("Header size too large (exceeds 32-bit range)");
  }
  const headerSize = headerSizeLow;
  const dataOffset = LENGTH_PREFIX_BYTES + headerSize;
  if (dataOffset > buffer.byteLength) {
    throw new Error(`Invalid SafeTensors file: header size ${headerSize} exceeds file size`);
  }
  const headerBytes = new Uint8Array(buffer, LENGTH_PREFIX_BYTES, headerSize);
  const headerJson = new TextDecoder("utf-8").decode(headerBytes);
  let parsed;
  try {
    parsed = JSON.parse(headerJson);
  } catch (e) {
    // Keep the original error reachable for debugging.
    throw new Error(`Failed to parse SafeTensors header JSON: ${e}`, { cause: e });
  }
  // Arrays and primitives parse as valid JSON but cannot describe tensors.
  if (parsed === null || typeof parsed !== "object" || Array.isArray(parsed)) {
    throw new Error("Failed to parse SafeTensors header JSON: top-level value is not an object");
  }
  const { __metadata__, ...tensors } = parsed;
  return { header: { tensors, __metadata__ }, dataOffset };
}
|
|
3171
|
+
/**
 * Builds a name -> tensor descriptor map from a parsed SafeTensors header.
 *
 * @param {{tensors: Object}} header - Parsed header; each tensor entry
 *   supplies `data_offsets` ([start, end]), `shape` and `dtype`.
 * @param {number} dataOffset - Offset of the data section; added to each
 *   tensor's start offset.
 * @returns {Map<string, {name: string, shape: *, dtype: *, offset: number, byteSize: number}>}
 */
function getSafetensorsTensorInfos(header, dataOffset) {
  const toMapEntry = ([name, { data_offsets: [begin, finish], shape, dtype }]) => [
    name,
    {
      name,
      shape,
      dtype,
      offset: dataOffset + begin,
      byteSize: finish - begin,
    },
  ];
  return new Map(Object.entries(header.tensors).map(toMapEntry));
}
|
|
3186
|
+
/**
 * Decodes one SafeTensors tensor into a freshly allocated Float32Array.
 *
 * F32 data is copied; F16, BF16 and F64 are converted element by element.
 * Integer and boolean dtypes are recognised but not loadable yet.
 *
 * Tensor offsets are not required to be multiples of the element size, so
 * misaligned data is copied into an aligned scratch buffer before a typed
 * view is created (a typed-array view on a misaligned offset throws).
 *
 * NOTE(review): typed arrays use host byte order; this assumes a
 * little-endian host, like the stored data.
 *
 * @param {ArrayBuffer} buffer - Whole file contents.
 * @param {{dtype: string, shape: number[], offset: number, byteSize: number}} info
 *   Descriptor with `offset` measured from the start of `buffer`.
 * @returns {Float32Array}
 * @throws {RangeError} If `offset`/`byteSize` fall outside `buffer`.
 * @throws {Error} If the shape needs more bytes than `byteSize` provides, or
 *   the dtype is unsupported or unknown.
 */
function loadSafetensorsTensor(buffer, info) {
  const dtype = info.dtype;
  // Also validates offset/byteSize against the buffer bounds.
  const tensorData = new Uint8Array(buffer, info.offset, info.byteSize);

  // Typed view over the first `length` elements; copies only when misaligned.
  const viewAs = (TypedArrayCtor, length) => {
    const width = TypedArrayCtor.BYTES_PER_ELEMENT;
    const neededBytes = length * width;
    if (neededBytes > tensorData.byteLength) {
      throw new Error(`Tensor ${info.name}: shape needs ${neededBytes} bytes but only ${tensorData.byteLength} are available`);
    }
    if (tensorData.byteOffset % width === 0) {
      return new TypedArrayCtor(tensorData.buffer, tensorData.byteOffset, length);
    }
    // Uint8Array#slice copies into a new buffer that starts at offset 0.
    return new TypedArrayCtor(tensorData.slice(0, neededBytes).buffer, 0, length);
  };
  const countElements = () => info.shape.reduce((a, b) => a * b, 1);

  switch (dtype) {
    case "F32": {
      const floats = viewAs(Float32Array, Math.floor(tensorData.byteLength / 4));
      // Always hand back memory that does not alias the file buffer.
      return floats.buffer === tensorData.buffer ? new Float32Array(floats) : floats;
    }
    case "F16": {
      const numel = countElements();
      const result = new Float32Array(numel);
      const uint16View = viewAs(Uint16Array, numel);
      for (let i = 0; i < numel; i++) {
        result[i] = float16ToFloat32(uint16View[i]);
      }
      return result;
    }
    case "BF16": {
      const numel = countElements();
      const result = new Float32Array(numel);
      const uint16View = viewAs(Uint16Array, numel);
      for (let i = 0; i < numel; i++) {
        result[i] = bfloat16ToFloat32(uint16View[i]);
      }
      return result;
    }
    case "F64": {
      // Constructing from another typed array narrows each double to float32.
      return new Float32Array(viewAs(Float64Array, countElements()));
    }
    case "I8":
    case "U8":
    case "I16":
    case "I32":
    case "I64":
    case "BOOL": {
      throw new Error(`Integer dtype ${dtype} not yet supported for loading`);
    }
    default:
      throw new Error(`Unknown dtype: ${dtype}`);
  }
}
|
|
3233
|
+
/**
 * Converts the bit pattern of an IEEE 754 half-precision float to a Number.
 *
 * @param {number} h - 16-bit pattern: 1 sign bit, 5 exponent bits,
 *   10 fraction bits.
 * @returns {number} The decoded value, including signed zero, subnormals,
 *   infinities and NaN.
 */
function float16ToFloat32(h) {
  const signFactor = (h & 0x8000) !== 0 ? -1 : 1;
  const exponentBits = (h >> 10) & 0x1f;
  const fractionBits = h & 0x3ff;

  if (exponentBits === 0x1f) {
    // All-ones exponent: infinity when the fraction is empty, NaN otherwise.
    if (fractionBits !== 0) {
      return NaN;
    }
    return signFactor < 0 ? -Infinity : Infinity;
  }
  if (exponentBits === 0) {
    if (fractionBits === 0) {
      return signFactor < 0 ? -0 : 0;
    }
    // Subnormal: no implicit leading 1, fixed exponent of -14.
    return signFactor * Math.pow(2, -14) * (fractionBits / 1024);
  }
  return signFactor * Math.pow(2, exponentBits - 15) * (1 + fractionBits / 1024);
}
|
|
3250
|
+
/**
 * Converts a bfloat16 bit pattern to a Number by placing it in the upper
 * 16 bits of a float32 and leaving the lower 16 bits zero.
 *
 * @param {number} bf16 - 16-bit bfloat16 pattern.
 * @returns {number}
 */
function bfloat16ToFloat32(bf16) {
  const word = new Uint32Array(1);
  // Storing into a Uint32Array wraps the possibly-negative int32 shift result.
  word[0] = bf16 << 16;
  // Both views share host byte order, so the reinterpretation is endian-safe.
  return new Float32Array(word.buffer)[0];
}
|
|
3256
|
+
/**
 * Derives model metadata from a SafeTensors header using tensor-name
 * heuristics.
 *
 * - architecture: "llama" if any name contains "model.layers.", otherwise
 *   "gpt2" if any name contains "transformer.h.".
 * - vocabSize / embeddingLength: dims 0 and 1 of a tensor whose name contains
 *   "embed_tokens" or "wte"; embeddingLength is overwritten with dim 1 of a
 *   "layers.0.self_attn.q_proj" tensor when one is visited.
 * - numLayers: highest index matched by `layers.<n>.`, plus one.
 *
 * @param {{tensors: Object, __metadata__: *}} header
 * @returns {Object} Metadata with `format: "safetensors"` and
 *   `extra: header.__metadata__`, plus whichever fields could be inferred.
 */
function extractMetadata(header) {
  const summary = {
    format: "safetensors",
    extra: header.__metadata__,
  };
  const tensorEntries = Object.entries(header.tensors);
  const anyNameContains = (fragment) =>
    tensorEntries.some(([tensorName]) => tensorName.includes(fragment));

  if (anyNameContains("model.layers.")) {
    summary.architecture = "llama";
  } else if (anyNameContains("transformer.h.")) {
    summary.architecture = "gpt2";
  }

  const layerPattern = /layers\.(\d+)\./;
  let highestLayer = -1;
  for (const [tensorName, entry] of tensorEntries) {
    const isEmbedding = tensorName.includes("embed_tokens") || tensorName.includes("wte");
    if (isEmbedding) {
      summary.vocabSize = entry.shape[0];
      summary.embeddingLength = entry.shape[1];
    }
    if (tensorName.includes("layers.0.self_attn.q_proj")) {
      summary.embeddingLength = entry.shape[1];
    }
    const layerMatch = layerPattern.exec(tensorName);
    if (layerMatch) {
      highestLayer = Math.max(highestLayer, parseInt(layerMatch[1], 10));
    }
  }
  if (highestLayer >= 0) {
    summary.numLayers = highestLayer + 1;
  }
  return summary;
}
|
|
3285
|
+
/**
 * Indexes a SafeTensors file held in memory without decoding tensor data.
 *
 * @param {ArrayBuffer} buffer - Whole file contents.
 * @param {{tensorFilter?: function(string): boolean}} [options] - When
 *   `tensorFilter` is given, tensors whose name it rejects are dropped.
 * @returns {{metadata: Object, tensorInfos: Map, totalBytes: number, buffer: ArrayBuffer, dataOffset: number}}
 *   `totalBytes` sums `byteSize` over the tensors that were kept.
 */
function loadSafetensors(buffer, options) {
  const parsed = parseSafetensorsHeader(buffer);
  const { header, dataOffset } = parsed;
  const tensorInfos = getSafetensorsTensorInfos(header, dataOffset);

  if (options?.tensorFilter) {
    // Deleting the entry currently being visited is safe for Map iteration.
    tensorInfos.forEach((_info, tensorName) => {
      const keep = options.tensorFilter(tensorName);
      if (!keep) {
        tensorInfos.delete(tensorName);
      }
    });
  }

  const totalBytes = [...tensorInfos.values()].reduce(
    (sum, info) => sum + info.byteSize,
    0,
  );
  const metadata = extractMetadata(header);
  return { metadata, tensorInfos, totalBytes, buffer, dataOffset };
}
|
|
3308
|
+
/**
 * Fetches a SafeTensors file and indexes it with `loadSafetensors`.
 *
 * @param {string} url - Location passed to `fetch`.
 * @param {Object} [options] - Forwarded to `loadSafetensors`.
 * @returns {Promise<Object>} The result of `loadSafetensors`.
 * @throws {Error} If the response status is not OK.
 */
async function loadSafetensorsFromUrl(url, options) {
  const response = await fetch(url);
  if (response.ok) {
    const fileBytes = await response.arrayBuffer();
    return loadSafetensors(fileBytes, options);
  }
  throw new Error(`Failed to fetch ${url}: ${response.statusText}`);
}
|
|
3316
|
+
/**
 * Cheap sniff test for the SafeTensors container format.
 *
 * Accepts a buffer when the 64-bit little-endian length prefix is at most
 * 100 MiB, the declared header fits in the buffer, and the first (up to) 100
 * header bytes decode to text that starts with "{" after leading whitespace.
 *
 * @param {ArrayBuffer} buffer
 * @returns {boolean} False for anything that fails the checks or throws while
 *   being inspected.
 */
function isSafetensors(buffer) {
  const LENGTH_PREFIX_BYTES = 8;
  const MAX_HEADER_BYTES = 100 * 1024 * 1024;
  const PEEK_BYTES = 100;

  if (buffer.byteLength < LENGTH_PREFIX_BYTES) {
    return false;
  }
  try {
    const prefix = new DataView(buffer);
    const lowWord = prefix.getUint32(0, true);
    const highWord = prefix.getUint32(4, true);
    const plausibleSize =
      highWord === 0 &&
      lowWord <= MAX_HEADER_BYTES &&
      LENGTH_PREFIX_BYTES + lowWord <= buffer.byteLength;
    if (!plausibleSize) {
      return false;
    }
    const peek = new Uint8Array(buffer, LENGTH_PREFIX_BYTES, Math.min(lowWord, PEEK_BYTES));
    const headerStart = new TextDecoder("utf-8").decode(peek);
    return headerStart.trimStart().charAt(0) === "{";
  } catch {
    return false;
  }
}
|
|
3336
|
+
// src/model/gguf.ts
|
|
3337
|
+
// File magic: the ASCII bytes "GGUF" read as a little-endian uint32.
var GGUF_MAGIC = 0x46554747;
// The only container version parseGGUFHeader accepts.
var GGUF_VERSION = 3;
// Data-section alignment used when the file has no "general.alignment" key.
var DEFAULT_ALIGNMENT = 32;
|
|
3340
|
+
|
|
3341
|
+
/**
 * Sequential little-endian reader over an ArrayBuffer.
 *
 * Every `read*` method decodes one value at the current offset and then
 * advances the offset by the value's size. If a read throws (for example
 * past the end of the buffer) the offset is left unchanged.
 */
class GGUFReader {
  view;
  offset = 0;
  textDecoder = new TextDecoder("utf-8");

  /** @param {ArrayBuffer} buffer - Bytes to read from, starting at offset 0. */
  constructor(buffer) {
    this.view = new DataView(buffer);
  }

  /** Current read offset in bytes. */
  get position() {
    return this.offset;
  }

  set position(pos) {
    this.offset = pos;
  }

  readUint8() {
    const at = this.offset;
    const value = this.view.getUint8(at);
    this.offset = at + 1;
    return value;
  }

  readInt8() {
    const at = this.offset;
    const value = this.view.getInt8(at);
    this.offset = at + 1;
    return value;
  }

  readUint16() {
    const at = this.offset;
    const value = this.view.getUint16(at, true);
    this.offset = at + 2;
    return value;
  }

  readInt16() {
    const at = this.offset;
    const value = this.view.getInt16(at, true);
    this.offset = at + 2;
    return value;
  }

  readUint32() {
    const at = this.offset;
    const value = this.view.getUint32(at, true);
    this.offset = at + 4;
    return value;
  }

  readInt32() {
    const at = this.offset;
    const value = this.view.getInt32(at, true);
    this.offset = at + 4;
    return value;
  }

  /** @returns {bigint} */
  readUint64() {
    const at = this.offset;
    const value = this.view.getBigUint64(at, true);
    this.offset = at + 8;
    return value;
  }

  /** @returns {bigint} */
  readInt64() {
    const at = this.offset;
    const value = this.view.getBigInt64(at, true);
    this.offset = at + 8;
    return value;
  }

  readFloat32() {
    const at = this.offset;
    const value = this.view.getFloat32(at, true);
    this.offset = at + 4;
    return value;
  }

  readFloat64() {
    const at = this.offset;
    const value = this.view.getFloat64(at, true);
    this.offset = at + 8;
    return value;
  }

  /** Reads one byte; any non-zero value is true. */
  readBool() {
    const raw = this.readUint8();
    return raw !== 0;
  }

  /** Reads a uint64 byte length followed by that many bytes of UTF-8. */
  readString() {
    const byteLength = Number(this.readUint64());
    const start = this.offset;
    const bytes = new Uint8Array(this.view.buffer, start, byteLength);
    this.offset = start + byteLength;
    return this.textDecoder.decode(bytes);
  }

  /** Advances the offset to the next multiple of `alignment`, if needed. */
  alignTo(alignment) {
    const misalignment = this.offset % alignment;
    if (misalignment === 0) {
      return;
    }
    this.offset += alignment - misalignment;
  }
}
|
|
3420
|
+
/**
 * Reads and validates the fixed GGUF header: magic (uint32), version
 * (uint32), tensor count (uint64) and metadata key/value count (uint64).
 *
 * @param {Object} reader - Source providing `readUint32()` / `readUint64()`.
 * @returns {{magic: number, version: number, nTensors: bigint, nKV: bigint}}
 * @throws {Error} If the magic or the version does not match the expected
 *   constants.
 */
function parseGGUFHeader(reader) {
  const magic = reader.readUint32();
  const magicMatches = magic === GGUF_MAGIC;
  if (!magicMatches) {
    throw new Error(`Invalid GGUF magic: expected 0x${GGUF_MAGIC.toString(16)}, got 0x${magic.toString(16)}`);
  }

  const version = reader.readUint32();
  const versionSupported = version === GGUF_VERSION;
  if (!versionSupported) {
    throw new Error(`Unsupported GGUF version: ${version} (expected ${GGUF_VERSION})`);
  }

  // Array elements evaluate left to right, matching the on-disk field order.
  const [nTensors, nKV] = [reader.readUint64(), reader.readUint64()];
  return { magic, version, nTensors, nKV };
}
|
|
3433
|
+
/**
 * Reads one GGUF metadata value of the given type from `reader`.
 *
 * Arrays (type 9) are stored as an element type (uint32), a count (uint64)
 * and then that many values, parsed recursively. Every other type maps to a
 * single reader call.
 *
 * @param {Object} reader - Source providing the `read*` methods used below.
 * @param {number} valueType - Metadata value type ID.
 * @returns {*} The decoded value (an array for ARRAY values).
 * @throws {Error} If `valueType` is not a known type ID.
 */
function parseMetadataValue(reader, valueType) {
  if (valueType === 9 /* ARRAY */) {
    const elementType = reader.readUint32();
    const count = Number(reader.readUint64());
    const items = [];
    while (items.length < count) {
      items.push(parseMetadataValue(reader, elementType));
    }
    return items;
  }

  let readerMethod;
  switch (valueType) {
    case 0 /* UINT8 */:
      readerMethod = "readUint8";
      break;
    case 1 /* INT8 */:
      readerMethod = "readInt8";
      break;
    case 2 /* UINT16 */:
      readerMethod = "readUint16";
      break;
    case 3 /* INT16 */:
      readerMethod = "readInt16";
      break;
    case 4 /* UINT32 */:
      readerMethod = "readUint32";
      break;
    case 5 /* INT32 */:
      readerMethod = "readInt32";
      break;
    case 6 /* FLOAT32 */:
      readerMethod = "readFloat32";
      break;
    case 7 /* BOOL */:
      readerMethod = "readBool";
      break;
    case 8 /* STRING */:
      readerMethod = "readString";
      break;
    case 10 /* UINT64 */:
      readerMethod = "readUint64";
      break;
    case 11 /* INT64 */:
      readerMethod = "readInt64";
      break;
    case 12 /* FLOAT64 */:
      readerMethod = "readFloat64";
      break;
    default:
      throw new Error(`Unknown metadata value type: ${valueType}`);
  }
  return reader[readerMethod]();
}
|
|
3472
|
+
/**
 * Reads `nKV` metadata entries, each stored as key (string), value type
 * (uint32) and value.
 *
 * @param {Object} reader - Source providing `readString()` / `readUint32()`.
 * @param {bigint} nKV - Number of key/value pairs to read.
 * @returns {Map<string, *>} Entries in file order; a repeated key keeps the
 *   last value read.
 */
function parseGGUFMetadata(reader, nKV) {
  const entries = new Map();
  let remaining = nKV;
  while (remaining > 0) {
    const key = reader.readString();
    const valueType = reader.readUint32();
    entries.set(key, parseMetadataValue(reader, valueType));
    remaining--;
  }
  return entries;
}
|
|
3482
|
+
/**
 * Reads `nTensors` tensor descriptors. Each is stored as name (string),
 * dimension count (uint32), that many dimensions (uint64 each), type
 * (uint32) and offset (uint64).
 *
 * @param {Object} reader - Source providing `readString()`, `readUint32()`
 *   and `readUint64()`.
 * @param {bigint} nTensors - Number of descriptors to read.
 * @returns {Array<{name: string, nDims: number, dimensions: bigint[], type: number, offset: bigint}>}
 */
function parseGGUFTensorInfos(reader, nTensors) {
  const readDescriptor = () => {
    const name = reader.readString();
    const nDims = reader.readUint32();
    const dimensions = [];
    while (dimensions.length < nDims) {
      dimensions.push(reader.readUint64());
    }
    const type = reader.readUint32();
    const offset = reader.readUint64();
    return { name, nDims, dimensions, type, offset };
  };

  const descriptors = [];
  let remaining = nTensors;
  while (remaining > 0) {
    descriptors.push(readDescriptor());
    remaining--;
  }
  return descriptors;
}
|
|
3497
|
+
/**
 * Computes the stored size in bytes of a GGUF tensor.
 *
 * Plain element types are sized per element. Every other type is looked up
 * in the block tables and sized as `ceil(numel / blockSize) * bytesPerBlock`.
 *
 * @param {number} type - GGUF tensor type ID.
 * @param {number[]} shape - Dimensions; an empty shape counts as one element.
 * @returns {number} Size in bytes.
 * @throws {Error} If the type is neither a plain element type nor present in
 *   both block tables.
 */
function calculateGGUFTensorBytes(type, shape) {
  const numel = shape.reduce((a, b) => a * b, 1);
  switch (type) {
    case 0 /* F32 */:
    case 26 /* I32 */:
      return numel * 4;
    case 1 /* F16 */:
    // GGML numbers BF16 as 30. 29 is IQ1_M, a block-quantized type, so it
    // must not be sized as two bytes per element.
    case 30 /* BF16 */:
    case 25 /* I16 */:
      return numel * 2;
    case 24 /* I8 */:
      return numel;
    case 27 /* I64 */:
    case 28 /* F64 */:
      return numel * 8;
    default:
      break;
  }
  const blockSize = GGUF_QUANT_BLOCK_SIZE[type];
  const bytesPerBlock = GGUF_QUANT_BYTES_PER_BLOCK[type];
  if (blockSize === undefined || bytesPerBlock === undefined) {
    throw new Error(`Unknown quantization type: ${type}`);
  }
  // A trailing partial block still occupies a whole block on disk.
  const numBlocks = Math.ceil(numel / blockSize);
  return numBlocks * bytesPerBlock;
}
|
|
3525
|
+
/**
 * Converts a raw GGUF tensor descriptor into the loader's tensor-info shape.
 *
 * @param {{name: string, dimensions: bigint[], type: number, offset: bigint}} info
 *   Descriptor as read from the file; `offset` is relative to the data section.
 * @param {number} dataOffset - Byte offset of the data section.
 * @returns {{name: string, shape: number[], dtype: number, offset: number, byteSize: number}}
 *   `shape` holds the dimensions as Numbers in file order; `offset` is
 *   `dataOffset` plus the descriptor's offset.
 */
function convertTensorInfo(info, dataOffset) {
  const { name, type, dimensions, offset: relativeOffset } = info;
  const shape = dimensions.map((dim) => Number(dim));
  return {
    name,
    shape,
    dtype: type,
    offset: dataOffset + Number(relativeOffset),
    byteSize: calculateGGUFTensorBytes(type, shape),
  };
}
|
|
3536
|
+
/**
 * Builds the loader's metadata object from raw GGUF key/value pairs.
 *
 * Architecture-specific values are looked up under `<architecture>.<key>`,
 * with an empty prefix when "general.architecture" is missing. `headDim` is
 * added only when both `embeddingLength` and `numHeads` are truthy.
 *
 * @param {Map<string, *>} metadata - Raw metadata entries.
 * @returns {Object} Metadata with `format: "gguf"`, all raw entries under
 *   `extra`, and the well-known fields (undefined when absent).
 */
function extractGGUFMetadata(metadata) {
  const architecture = metadata.get("general.architecture");
  const prefix = architecture || "";
  const archValue = (suffix) => metadata.get(`${prefix}.${suffix}`);

  const meta = {
    format: "gguf",
    extra: Object.fromEntries(metadata),
    name: metadata.get("general.name"),
    architecture,
    contextLength: archValue("context_length"),
    embeddingLength: archValue("embedding_length"),
    numLayers: archValue("block_count"),
    numHeads: archValue("attention.head_count"),
    numKVHeads: archValue("attention.head_count_kv"),
    vocabSize: archValue("vocab_size"),
    ropeFreqBase: archValue("rope.freq_base"),
  };

  const { embeddingLength, numHeads } = meta;
  if (embeddingLength && numHeads) {
    meta.headDim = embeddingLength / numHeads;
  }
  return meta;
}
|
|
3556
|
+
/**
 * Indexes a GGUF file held in memory without decoding tensor data.
 *
 * Reads the header, the metadata entries and the tensor descriptors in that
 * order, then rounds the read position up to the file's alignment
 * ("general.alignment", or the default) to find where tensor data starts.
 *
 * @param {ArrayBuffer} buffer - Whole file contents.
 * @param {{tensorFilter?: function(string): boolean}} [options] - When
 *   `tensorFilter` is given, tensors whose name it rejects are skipped.
 * @returns {{metadata: Object, tensorInfos: Map, totalBytes: number, buffer: ArrayBuffer, dataOffset: number}}
 *   `totalBytes` sums `byteSize` over the tensors that were kept.
 */
function loadGGUF(buffer, options) {
  const reader = new GGUFReader(buffer);
  const { nKV, nTensors } = parseGGUFHeader(reader);
  const rawMetadata = parseGGUFMetadata(reader, nKV);
  const alignment = rawMetadata.get("general.alignment") || DEFAULT_ALIGNMENT;
  const rawTensorInfos = parseGGUFTensorInfos(reader, nTensors);

  // Tensor data begins at the first aligned offset after the descriptors.
  reader.alignTo(alignment);
  const dataOffset = reader.position;

  const isWanted = (tensorName) =>
    !options?.tensorFilter || options.tensorFilter(tensorName);

  const tensorInfos = new Map();
  let totalBytes = 0;
  for (const rawInfo of rawTensorInfos) {
    if (isWanted(rawInfo.name)) {
      const converted = convertTensorInfo(rawInfo, dataOffset);
      tensorInfos.set(rawInfo.name, converted);
      totalBytes += converted.byteSize;
    }
  }

  const metadata = extractGGUFMetadata(rawMetadata);
  return { metadata, tensorInfos, totalBytes, buffer, dataOffset };
}
|
|
3583
|
+
/**
 * Fetches a GGUF file and indexes it with `loadGGUF`.
 *
 * @param {string} url - Location passed to `fetch`.
 * @param {Object} [options] - Forwarded to `loadGGUF`.
 * @returns {Promise<Object>} The result of `loadGGUF`.
 * @throws {Error} If the response status is not OK.
 */
async function loadGGUFFromUrl(url, options) {
  const response = await fetch(url);
  if (response.ok) {
    const fileBytes = await response.arrayBuffer();
    return loadGGUF(fileBytes, options);
  }
  throw new Error(`Failed to fetch ${url}: ${response.statusText}`);
}
|
|
3591
|
+
/**
 * Dequantizes one 18-byte Q4_0 block into 32 floats.
 *
 * Layout: a little-endian float16 scale, then 16 bytes each packing two
 * 4-bit values stored with a +8 bias. Following GGML's Q4_0 layout, the low
 * nibble of byte j is element j and the high nibble is element j + 16 (the
 * nibbles are NOT interleaved).
 *
 * @param {Uint8Array} data - Bytes containing the block.
 * @param {number} offset - Index of the block's first byte in `data`.
 * @returns {Float32Array} The 32 dequantized values.
 */
function dequantizeQ4_0Block(data, offset) {
  const HALF_BLOCK = 16;
  const result = new Float32Array(2 * HALF_BLOCK);
  const scaleU16 = data[offset + 1] << 8 | data[offset];
  const scale2 = float16ToFloat322(scaleU16);
  for (let j = 0; j < HALF_BLOCK; j++) {
    const byte = data[offset + 2 + j];
    const low = (byte & 15) - 8;
    const high = (byte >> 4 & 15) - 8;
    result[j] = low * scale2;
    result[j + HALF_BLOCK] = high * scale2;
  }
  return result;
}
|
|
3604
|
+
/**
 * Dequantizes one 34-byte Q8_0 block into 32 floats.
 *
 * Layout: a little-endian float16 scale followed by 32 bytes, each treated
 * as a signed 8-bit value (bytes above 127 wrap to negatives) and multiplied
 * by the scale.
 *
 * @param {Uint8Array} data - Bytes containing the block.
 * @param {number} offset - Index of the block's first byte in `data`.
 * @returns {Float32Array} The 32 dequantized values.
 */
function dequantizeQ8_0Block(data, offset) {
  const scaleBits = (data[offset + 1] << 8) | data[offset];
  const blockScale = float16ToFloat322(scaleBits);
  const firstQuant = offset + 2;
  return Float32Array.from({ length: 32 }, (_unused, lane) => {
    const raw = data[firstQuant + lane];
    return (raw > 127 ? raw - 256 : raw) * blockScale;
  });
}
|
|
3615
|
+
/**
 * Materialises one GGUF tensor as a Float32Array.
 *
 * F32 (type 0) is returned as a view on `buffer`, not a copy. F16 (1),
 * Q4_0 (2) and Q8_0 (8) are converted into a new array. Any other type
 * throws.
 *
 * @param {ArrayBuffer} buffer - Whole file contents.
 * @param {{dtype: number, offset: number, byteSize: number, shape: number[]}} info
 *   Descriptor with `offset` measured from the start of `buffer`.
 * @returns {Float32Array}
 * @throws {RangeError} If `offset`/`byteSize` fall outside `buffer`.
 * @throws {Error} If the type has no dequantization path.
 */
function loadGGUFTensor(buffer, info) {
  const { dtype: type, offset: byteOffset } = info;
  // Constructed first so out-of-range descriptors fail before any decoding.
  const data = new Uint8Array(buffer, byteOffset, info.byteSize);
  const numel = info.shape.reduce((a, b) => a * b, 1);

  // Shared driver for the 32-element block formats.
  const dequantizeBlocks = (bytesPerBlock, dequantizeBlock) => {
    const blockSize = 32;
    const blockCount = Math.ceil(numel / blockSize);
    const out = new Float32Array(numel);
    for (let block = 0, written = 0; block < blockCount; block++, written += blockSize) {
      const values = dequantizeBlock(data, block * bytesPerBlock);
      // The last block may hold fewer than 32 meaningful values.
      const keep = Math.min(blockSize, numel - written);
      out.set(values.subarray(0, keep), written);
    }
    return out;
  };

  if (type === 0 /* F32 */) {
    return new Float32Array(buffer, byteOffset, numel);
  }
  if (type === 1 /* F16 */) {
    const out = new Float32Array(numel);
    const halves = new Uint16Array(buffer, byteOffset, numel);
    halves.forEach((bits, index) => {
      out[index] = float16ToFloat322(bits);
    });
    return out;
  }
  if (type === 2 /* Q4_0 */) {
    return dequantizeBlocks(18, dequantizeQ4_0Block);
  }
  if (type === 8 /* Q8_0 */) {
    return dequantizeBlocks(34, dequantizeQ8_0Block);
  }
  throw new Error(`Quantization type ${GGUFQuantType[type]} not yet supported for dequantization`);
}
|
|
3661
|
+
/**
 * Cheap sniff test for the GGUF container format: the buffer must hold at
 * least 24 bytes and begin with the GGUF magic (little-endian uint32).
 *
 * @param {ArrayBuffer} buffer
 * @returns {boolean} False when the check fails or throws while inspecting.
 */
function isGGUF(buffer) {
  const MIN_HEADER_BYTES = 24;
  let looksLikeGGUF = false;
  if (buffer.byteLength >= MIN_HEADER_BYTES) {
    try {
      const leadingWord = new DataView(buffer).getUint32(0, true);
      looksLikeGGUF = leadingWord === GGUF_MAGIC;
    } catch {
      looksLikeGGUF = false;
    }
  }
  return looksLikeGGUF;
}
|
|
3672
|
+
/**
 * Decodes an IEEE 754 half-precision bit pattern (1 sign, 5 exponent,
 * 10 fraction bits) into a Number.
 *
 * @param {number} h - 16-bit pattern.
 * @returns {number} Decoded value; covers signed zero, subnormals,
 *   infinities and NaN.
 */
function float16ToFloat322(h) {
  const isNegative = ((h >> 15) & 1) === 1;
  const biasedExponent = (h >> 10) & 31;
  const fraction = h & 1023;

  let magnitude;
  if (biasedExponent === 31) {
    magnitude = fraction === 0 ? Infinity : NaN;
  } else if (biasedExponent === 0) {
    // Zero and subnormals: no implicit leading 1, exponent fixed at -14.
    magnitude = Math.pow(2, -14) * (fraction / 1024);
  } else {
    magnitude = Math.pow(2, biasedExponent - 15) * (1 + fraction / 1024);
  }
  // Negating 0 gives -0 and negating NaN stays NaN.
  return isNegative ? -magnitude : magnitude;
}
|
|
3689
|
+
// src/model/index.ts
|
|
3690
|
+
/**
 * Loads a model from a URL or an in-memory buffer, detecting the container
 * format from its contents (GGUF is checked before SafeTensors).
 *
 * @param {string|ArrayBuffer} source - A string is fetched as a URL; anything
 *   else is used directly as the file contents.
 * @param {Object} [options] - Forwarded to the format-specific loader.
 * @returns {Promise<Object>} The result of `loadGGUF` or `loadSafetensors`.
 * @throws {Error} If the fetch fails or the format is not recognised.
 */
async function loadModel(source, options) {
  const download = async (url) => {
    const response = await fetch(url);
    if (!response.ok) {
      throw new Error(`Failed to fetch model: ${response.statusText}`);
    }
    return response.arrayBuffer();
  };

  const buffer = typeof source === "string" ? await download(source) : source;

  if (isGGUF(buffer)) {
    return loadGGUF(buffer, options);
  }
  if (isSafetensors(buffer)) {
    return loadSafetensors(buffer, options);
  }
  throw new Error("Unknown model format. Expected SafeTensors or GGUF.");
}
|
|
3709
|
+
// src/inference/types.ts
|
|
3710
|
+
/**
 * Fallback generation settings: normalizeGenerationConfig uses each value
 * when the caller leaves the corresponding field null or undefined.
 */
var DEFAULT_GENERATION_CONFIG = {
  // Sampling controls.
  maxTokens: 256,
  temperature: 1.0,
  topK: 0,
  topP: 1.0,
  repetitionPenalty: 1.0,
  // Special token IDs.
  eosTokenId: 2,
  padTokenId: 0,
  bosTokenId: 1,
  // Output mode.
  stream: false,
};
|
|
3721
|
+
/**
 * Fills unset generation options with defaults and validates their ranges.
 *
 * A field falls back to its default only when it is null or undefined, so
 * explicit values such as 0 or false are kept. `stopSequences` and `seed`
 * are copied through unchanged.
 *
 * @param {Object} config - Caller-supplied generation options.
 * @returns {Object} A new, fully populated config object.
 * @throws {Error} If maxTokens < 1, temperature < 0, topK < 0, topP is
 *   outside [0, 1], or repetitionPenalty < 1.
 */
function normalizeGenerationConfig(config) {
  const defaults = DEFAULT_GENERATION_CONFIG;
  const resolve = (field) => config[field] ?? defaults[field];

  const normalized = {
    maxTokens: resolve("maxTokens"),
    temperature: resolve("temperature"),
    topK: resolve("topK"),
    topP: resolve("topP"),
    repetitionPenalty: resolve("repetitionPenalty"),
    eosTokenId: resolve("eosTokenId"),
    padTokenId: resolve("padTokenId"),
    bosTokenId: resolve("bosTokenId"),
    stream: resolve("stream"),
    stopSequences: config.stopSequences,
    seed: config.seed,
  };

  // Checked in order; the first rule that is violated decides the error.
  const rules = [
    [(c) => c.maxTokens < 1, "maxTokens must be >= 1"],
    [(c) => c.temperature !== undefined && c.temperature < 0, "temperature must be >= 0"],
    [(c) => c.topK !== undefined && c.topK < 0, "topK must be >= 0"],
    [(c) => c.topP !== undefined && (c.topP < 0 || c.topP > 1), "topP must be between 0 and 1"],
    [(c) => c.repetitionPenalty !== undefined && c.repetitionPenalty < 1, "repetitionPenalty must be >= 1"],
  ];
  const violated = rules.find(([isViolated]) => isViolated(normalized));
  if (violated) {
    throw new Error(violated[1]);
  }
  return normalized;
}
|
|
3752
|
+
// src/inference/engine.ts
|
|
3753
|
+
/**
 * Engine defaults; the InferenceEngine constructor spreads the caller's
 * config over these, so any field the caller supplies wins.
 */
var DEFAULT_INFERENCE_CONFIG = {
  maxBatchSize: 1,
  // Sizes the RoPE frequency tables and each layer's KV cache buffers.
  maxSeqLen: 2048,
  // When true, loading a model also allocates the KV cache.
  useKVCache: true,
  memoryLimit: 0,
  enableProfiling: false,
};
|
|
3760
|
+
|
|
3761
|
+
class InferenceEngine {
|
|
3762
|
+
device;
|
|
3763
|
+
config;
|
|
3764
|
+
modelConfig = null;
|
|
3765
|
+
weights = null;
|
|
3766
|
+
loadedModel = null;
|
|
3767
|
+
kvCache = null;
|
|
3768
|
+
ropeFreqsCos = null;
|
|
3769
|
+
ropeFreqsSin = null;
|
|
3770
|
+
constructor(device, config) {
|
|
3771
|
+
this.device = device;
|
|
3772
|
+
this.config = { ...DEFAULT_INFERENCE_CONFIG, ...config };
|
|
3773
|
+
}
|
|
3774
|
+
async loadModel(model, modelConfig) {
|
|
3775
|
+
this.loadedModel = model;
|
|
3776
|
+
this.modelConfig = modelConfig;
|
|
3777
|
+
this.weights = await this.extractWeights(model, modelConfig);
|
|
3778
|
+
const headDim = modelConfig.headDim ?? modelConfig.hiddenSize / modelConfig.numHeads;
|
|
3779
|
+
const ropeFreqBase = modelConfig.ropeFreqBase ?? 1e4;
|
|
3780
|
+
const { cos, sin } = computeRoPEFrequencies({
|
|
3781
|
+
dim: headDim,
|
|
3782
|
+
maxSeqLen: this.config.maxSeqLen,
|
|
3783
|
+
base: ropeFreqBase
|
|
3784
|
+
});
|
|
3785
|
+
this.ropeFreqsCos = cos;
|
|
3786
|
+
this.ropeFreqsSin = sin;
|
|
3787
|
+
if (this.config.useKVCache) {
|
|
3788
|
+
this.initKVCache(modelConfig);
|
|
3789
|
+
}
|
|
3790
|
+
}
|
|
3791
|
+
async extractWeights(model, config) {
|
|
3792
|
+
const loadTensor = (name) => {
|
|
3793
|
+
const info = model.tensorInfos.get(name);
|
|
3794
|
+
if (!info) {
|
|
3795
|
+
throw new Error(`Tensor not found: ${name}`);
|
|
3796
|
+
}
|
|
3797
|
+
if (model.metadata.format === "safetensors") {
|
|
3798
|
+
return loadSafetensorsTensor(model.buffer, info);
|
|
3799
|
+
} else {
|
|
3800
|
+
return loadGGUFTensor(model.buffer, info);
|
|
3801
|
+
}
|
|
3802
|
+
};
|
|
3803
|
+
const tryLoad = (names) => {
|
|
3804
|
+
for (const name of names) {
|
|
3805
|
+
if (model.tensorInfos.has(name)) {
|
|
3806
|
+
return loadTensor(name);
|
|
3807
|
+
}
|
|
3808
|
+
}
|
|
3809
|
+
throw new Error(`None of these tensors found: ${names.join(", ")}`);
|
|
3810
|
+
};
|
|
3811
|
+
const embedTokens = tryLoad([
|
|
3812
|
+
"model.embed_tokens.weight",
|
|
3813
|
+
"transformer.wte.weight",
|
|
3814
|
+
"embedding.weight"
|
|
3815
|
+
]);
|
|
3816
|
+
const layers = [];
|
|
3817
|
+
for (let i = 0;i < config.numLayers; i++) {
|
|
3818
|
+
const prefix = `model.layers.${i}`;
|
|
3819
|
+
const gptPrefix = `transformer.h.${i}`;
|
|
3820
|
+
const layerWeights = {
|
|
3821
|
+
attention: {
|
|
3822
|
+
qProj: tryLoad([`${prefix}.self_attn.q_proj.weight`, `${gptPrefix}.attn.q_proj.weight`]),
|
|
3823
|
+
kProj: tryLoad([`${prefix}.self_attn.k_proj.weight`, `${gptPrefix}.attn.k_proj.weight`]),
|
|
3824
|
+
vProj: tryLoad([`${prefix}.self_attn.v_proj.weight`, `${gptPrefix}.attn.v_proj.weight`]),
|
|
3825
|
+
oProj: tryLoad([`${prefix}.self_attn.o_proj.weight`, `${gptPrefix}.attn.o_proj.weight`])
|
|
3826
|
+
},
|
|
3827
|
+
ffn: {
|
|
3828
|
+
gate: model.tensorInfos.has(`${prefix}.mlp.gate_proj.weight`) ? loadTensor(`${prefix}.mlp.gate_proj.weight`) : undefined,
|
|
3829
|
+
up: tryLoad([`${prefix}.mlp.up_proj.weight`, `${gptPrefix}.mlp.up_proj.weight`]),
|
|
3830
|
+
down: tryLoad([`${prefix}.mlp.down_proj.weight`, `${gptPrefix}.mlp.down_proj.weight`])
|
|
3831
|
+
},
|
|
3832
|
+
inputNorm: tryLoad([
|
|
3833
|
+
`${prefix}.input_layernorm.weight`,
|
|
3834
|
+
`${gptPrefix}.ln_1.weight`
|
|
3835
|
+
]),
|
|
3836
|
+
postAttentionNorm: tryLoad([
|
|
3837
|
+
`${prefix}.post_attention_layernorm.weight`,
|
|
3838
|
+
`${gptPrefix}.ln_2.weight`
|
|
3839
|
+
])
|
|
3840
|
+
};
|
|
3841
|
+
layers.push(layerWeights);
|
|
3842
|
+
}
|
|
3843
|
+
const finalNorm = tryLoad([
|
|
3844
|
+
"model.norm.weight",
|
|
3845
|
+
"transformer.ln_f.weight"
|
|
3846
|
+
]);
|
|
3847
|
+
const lmHead = tryLoad([
|
|
3848
|
+
"lm_head.weight",
|
|
3849
|
+
"transformer.lm_head.weight"
|
|
3850
|
+
]);
|
|
3851
|
+
return { embedTokens, layers, finalNorm, lmHead };
|
|
3852
|
+
}
|
|
3853
|
+
initKVCache(config) {
|
|
3854
|
+
const headDim = config.headDim ?? config.hiddenSize / config.numHeads;
|
|
3855
|
+
const numKVHeads = config.numKVHeads ?? config.numHeads;
|
|
3856
|
+
const cacheSize = this.config.maxSeqLen * numKVHeads * headDim;
|
|
3857
|
+
this.kvCache = {
|
|
3858
|
+
keys: [],
|
|
3859
|
+
values: [],
|
|
3860
|
+
seqLen: 0
|
|
3861
|
+
};
|
|
3862
|
+
for (let i = 0;i < config.numLayers; i++) {
|
|
3863
|
+
this.kvCache.keys.push(new Float32Array(cacheSize));
|
|
3864
|
+
this.kvCache.values.push(new Float32Array(cacheSize));
|
|
3865
|
+
}
|
|
3866
|
+
}
|
|
3867
|
+
resetKVCache() {
|
|
3868
|
+
if (this.kvCache) {
|
|
3869
|
+
this.kvCache.seqLen = 0;
|
|
3870
|
+
for (let i = 0;i < this.kvCache.keys.length; i++) {
|
|
3871
|
+
this.kvCache.keys[i].fill(0);
|
|
3872
|
+
this.kvCache.values[i].fill(0);
|
|
3873
|
+
}
|
|
3874
|
+
}
|
|
3875
|
+
}
|
|
3876
|
+
forward(inputIds, startPos = 0) {
|
|
3877
|
+
if (!this.weights || !this.modelConfig) {
|
|
3878
|
+
throw new Error("Model not loaded. Call loadModel() first.");
|
|
3879
|
+
}
|
|
3880
|
+
const config = this.modelConfig;
|
|
3881
|
+
const weights = this.weights;
|
|
3882
|
+
const seqLen = inputIds.length;
|
|
3883
|
+
const headDim = config.headDim ?? config.hiddenSize / config.numHeads;
|
|
3884
|
+
const numKVHeads = config.numKVHeads ?? config.numHeads;
|
|
3885
|
+
const eps = config.rmsNormEps ?? 0.00001;
|
|
3886
|
+
const inputIdsArray = Array.from(inputIds);
|
|
3887
|
+
let hidden = embeddingCPU(weights.embedTokens, inputIdsArray, config.hiddenSize);
|
|
3888
|
+
for (let layer = 0;layer < config.numLayers; layer++) {
|
|
3889
|
+
const lw = weights.layers[layer];
|
|
3890
|
+
const normedHidden = rmsNormCPU(hidden, lw.inputNorm, [seqLen, config.hiddenSize], eps);
|
|
3891
|
+
hidden = this.attentionForward(normedHidden, lw, layer, startPos, seqLen, headDim, numKVHeads, hidden);
|
|
3892
|
+
const normedHidden2 = rmsNormCPU(hidden, lw.postAttentionNorm, [seqLen, config.hiddenSize], eps);
|
|
3893
|
+
hidden = this.ffnForward(normedHidden2, lw, hidden);
|
|
3894
|
+
}
|
|
3895
|
+
hidden = rmsNormCPU(hidden, weights.finalNorm, [seqLen, config.hiddenSize], eps);
|
|
3896
|
+
const lastTokenHidden = hidden.slice((seqLen - 1) * config.hiddenSize, seqLen * config.hiddenSize);
|
|
3897
|
+
const logits = matmulCPU(lastTokenHidden, weights.lmHead, 1, config.vocabSize, config.hiddenSize);
|
|
3898
|
+
return {
|
|
3899
|
+
logits,
|
|
3900
|
+
logitsShape: [1, config.vocabSize]
|
|
3901
|
+
};
|
|
3902
|
+
}
|
|
3903
|
+
attentionForward(x, lw, layerIdx, startPos, seqLen, headDim, numKVHeads, residual) {
|
|
3904
|
+
const config = this.modelConfig;
|
|
3905
|
+
const hiddenSize = config.hiddenSize;
|
|
3906
|
+
const numHeads = config.numHeads;
|
|
3907
|
+
let q = matmulCPU(x, lw.attention.qProj, seqLen, numHeads * headDim, hiddenSize);
|
|
3908
|
+
let k = matmulCPU(x, lw.attention.kProj, seqLen, numKVHeads * headDim, hiddenSize);
|
|
3909
|
+
let v = matmulCPU(x, lw.attention.vProj, seqLen, numKVHeads * headDim, hiddenSize);
|
|
3910
|
+
if (this.ropeFreqsCos && this.ropeFreqsSin) {
|
|
3911
|
+
for (let pos = 0;pos < seqLen; pos++) {
|
|
3912
|
+
const actualPos = startPos + pos;
|
|
3913
|
+
for (let h = 0;h < numHeads; h++) {
|
|
3914
|
+
const qOffset = pos * numHeads * headDim + h * headDim;
|
|
3915
|
+
this.applyRoPE(q, qOffset, actualPos, headDim);
|
|
3916
|
+
}
|
|
3917
|
+
for (let h = 0;h < numKVHeads; h++) {
|
|
3918
|
+
const kOffset = pos * numKVHeads * headDim + h * headDim;
|
|
3919
|
+
this.applyRoPE(k, kOffset, actualPos, headDim);
|
|
3920
|
+
}
|
|
3921
|
+
}
|
|
3922
|
+
}
|
|
3923
|
+
if (this.kvCache) {
|
|
3924
|
+
const kvSize = numKVHeads * headDim;
|
|
3925
|
+
for (let pos = 0;pos < seqLen; pos++) {
|
|
3926
|
+
const cachePos = (startPos + pos) * kvSize;
|
|
3927
|
+
this.kvCache.keys[layerIdx].set(k.subarray(pos * kvSize, (pos + 1) * kvSize), cachePos);
|
|
3928
|
+
this.kvCache.values[layerIdx].set(v.subarray(pos * kvSize, (pos + 1) * kvSize), cachePos);
|
|
3929
|
+
}
|
|
3930
|
+
this.kvCache.seqLen = startPos + seqLen;
|
|
3931
|
+
const totalLen = startPos + seqLen;
|
|
3932
|
+
k = this.kvCache.keys[layerIdx].slice(0, totalLen * kvSize);
|
|
3933
|
+
v = this.kvCache.values[layerIdx].slice(0, totalLen * kvSize);
|
|
3934
|
+
}
|
|
3935
|
+
const scale2 = 1 / Math.sqrt(headDim);
|
|
3936
|
+
const totalKVLen = this.kvCache ? this.kvCache.seqLen : seqLen;
|
|
3937
|
+
const attnOutput = new Float32Array(seqLen * numHeads * headDim);
|
|
3938
|
+
for (let pos = 0;pos < seqLen; pos++) {
|
|
3939
|
+
for (let h = 0;h < numHeads; h++) {
|
|
3940
|
+
const kvHead = Math.floor(h * numKVHeads / numHeads);
|
|
3941
|
+
const scores = new Float32Array(totalKVLen);
|
|
3942
|
+
for (let kPos = 0;kPos < totalKVLen; kPos++) {
|
|
3943
|
+
if (kPos > startPos + pos) {
|
|
3944
|
+
scores[kPos] = -Infinity;
|
|
3945
|
+
continue;
|
|
3946
|
+
}
|
|
3947
|
+
let score = 0;
|
|
3948
|
+
for (let d = 0;d < headDim; d++) {
|
|
3949
|
+
const qIdx = pos * numHeads * headDim + h * headDim + d;
|
|
3950
|
+
const kIdx = kPos * numKVHeads * headDim + kvHead * headDim + d;
|
|
3951
|
+
score += q[qIdx] * k[kIdx];
|
|
3952
|
+
}
|
|
3953
|
+
scores[kPos] = score * scale2;
|
|
3954
|
+
}
|
|
3955
|
+
const probs = softmaxCPU(scores, [totalKVLen]);
|
|
3956
|
+
for (let d = 0;d < headDim; d++) {
|
|
3957
|
+
let val = 0;
|
|
3958
|
+
for (let vPos = 0;vPos < totalKVLen; vPos++) {
|
|
3959
|
+
const vIdx = vPos * numKVHeads * headDim + kvHead * headDim + d;
|
|
3960
|
+
val += probs[vPos] * v[vIdx];
|
|
3961
|
+
}
|
|
3962
|
+
const outIdx = pos * numHeads * headDim + h * headDim + d;
|
|
3963
|
+
attnOutput[outIdx] = val;
|
|
3964
|
+
}
|
|
3965
|
+
}
|
|
3966
|
+
}
|
|
3967
|
+
const projected = matmulCPU(attnOutput, lw.attention.oProj, seqLen, hiddenSize, numHeads * headDim);
|
|
3968
|
+
return addCPU(residual, projected);
|
|
3969
|
+
}
|
|
3970
|
+
applyRoPE(x, offset, position, headDim) {
|
|
3971
|
+
for (let i = 0;i < headDim / 2; i++) {
|
|
3972
|
+
const freqIdx = position * (headDim / 2) + i;
|
|
3973
|
+
const cos = this.ropeFreqsCos[freqIdx];
|
|
3974
|
+
const sin = this.ropeFreqsSin[freqIdx];
|
|
3975
|
+
const x0 = x[offset + i];
|
|
3976
|
+
const x1 = x[offset + headDim / 2 + i];
|
|
3977
|
+
x[offset + i] = x0 * cos - x1 * sin;
|
|
3978
|
+
x[offset + headDim / 2 + i] = x0 * sin + x1 * cos;
|
|
3979
|
+
}
|
|
3980
|
+
}
|
|
3981
|
+
ffnForward(x, lw, residual) {
|
|
3982
|
+
const config = this.modelConfig;
|
|
3983
|
+
const seqLen = x.length / config.hiddenSize;
|
|
3984
|
+
const up = matmulCPU(x, lw.ffn.up, seqLen, config.intermediateSize, config.hiddenSize);
|
|
3985
|
+
let gateOut;
|
|
3986
|
+
if (lw.ffn.gate) {
|
|
3987
|
+
gateOut = matmulCPU(x, lw.ffn.gate, seqLen, config.intermediateSize, config.hiddenSize);
|
|
3988
|
+
const upSilu = siluCPU(up);
|
|
3989
|
+
gateOut = mulCPU(gateOut, upSilu);
|
|
3990
|
+
} else {
|
|
3991
|
+
gateOut = siluCPU(up);
|
|
3992
|
+
}
|
|
3993
|
+
const down = matmulCPU(gateOut, lw.ffn.down, seqLen, config.hiddenSize, config.intermediateSize);
|
|
3994
|
+
return addCPU(residual, down);
|
|
3995
|
+
}
|
|
3996
|
+
getModelConfig() {
|
|
3997
|
+
return this.modelConfig;
|
|
3998
|
+
}
|
|
3999
|
+
isLoaded() {
|
|
4000
|
+
return this.weights !== null;
|
|
4001
|
+
}
|
|
4002
|
+
dispose() {
|
|
4003
|
+
this.weights = null;
|
|
4004
|
+
this.loadedModel = null;
|
|
4005
|
+
this.kvCache = null;
|
|
4006
|
+
this.ropeFreqsCos = null;
|
|
4007
|
+
this.ropeFreqsSin = null;
|
|
4008
|
+
}
|
|
4009
|
+
}
|
|
4010
|
+
// src/inference/generate.ts
|
|
4011
|
+
/**
 * Samples the next token by delegating to `sampleCPU`.
 *
 * Only the sampling-related fields of `config` are forwarded.
 *
 * @param {*} logits Logits passed through to `sampleCPU` unchanged.
 * @param {object} config Source of `temperature`, `topK`, `topP` and
 *   `repetitionPenalty`.
 * @param {number[]} [generatedTokens] Tokens produced so far; any falsy value
 *   is replaced with an empty array.
 * @returns {*} Whatever `sampleCPU` returns.
 */
function sampleNextToken(logits, config, generatedTokens) {
  const { temperature, topK, topP, repetitionPenalty } = config;
  const samplingOptions = { temperature, topK, topP, repetitionPenalty };
  const history = generatedTokens || [];
  return sampleCPU(logits, samplingOptions, history);
}
|
|
4019
|
+
/**
 * Reports whether the generated tokens currently end with any stop sequence.
 *
 * Empty stop sequences are ignored. Previously they went through
 * `slice(-0)`, which returns the whole array, so an empty stop sequence
 * matched only when nothing had been generated yet. The comparison is also
 * done in place, so no tail array is allocated per check.
 *
 * @param {ArrayLike<number>} generatedTokens Tokens produced so far.
 * @param {Iterable<ArrayLike<number>>} [stopSequences] Candidate stop sequences.
 * @returns {boolean} `true` if some non-empty stop sequence is a suffix of
 *   `generatedTokens`, otherwise `false`.
 */
function checkStopSequences(generatedTokens, stopSequences) {
  if (!stopSequences || stopSequences.length === 0) {
    return false;
  }
  const producedCount = generatedTokens.length;
  for (const stopSeq of stopSequences) {
    const stopLen = stopSeq.length;
    if (stopLen === 0 || stopLen > producedCount) {
      continue;
    }
    const tailStart = producedCount - stopLen;
    let matches = true;
    for (let i = 0; i < stopLen; i++) {
      if (generatedTokens[tailStart + i] !== stopSeq[i]) {
        matches = false;
        break;
      }
    }
    if (matches) {
      return true;
    }
  }
  return false;
}
|
|
4033
|
+
/**
 * Generates tokens autoregressively and returns them with run statistics.
 *
 * The KV cache is reset, the whole prompt is run once, and then one token is
 * sampled per step until the EOS token is produced, a stop sequence matches,
 * or `maxTokens` tokens have been generated. The terminating token is included
 * in the returned tokens.
 *
 * No forward pass is run for the final token at the length limit, because its
 * logits would never be sampled. This matches `generateStream`, saves a full
 * model pass, and avoids writing one position past what the run needs. As a
 * result the last generated token is not in the KV cache on return.
 *
 * @param {object} engine Object exposing `resetKVCache()` and
 *   `forward(tokens, startPos)` returning `{ logits }`.
 * @param {Uint32Array|ArrayLike<number>} promptTokens Prompt token ids.
 * @param {object} [config={}] Generation options, passed through
 *   `normalizeGenerationConfig`.
 * @returns {Promise<{tokens: number[], finishReason: string,
 *   promptTokens: number, generatedTokens: number, totalTimeMs: number,
 *   tokensPerSecond: number}>} `finishReason` is `"eos"`, `"stop"` or `"length"`.
 */
async function generate(engine, promptTokens, config = {}) {
  const normalizedConfig = normalizeGenerationConfig(config);
  const startTime = performance.now();
  engine.resetKVCache();
  const prompt = promptTokens instanceof Uint32Array ? promptTokens : new Uint32Array(promptTokens);
  let result = engine.forward(prompt, 0);
  const generatedTokens = [];
  let finishReason = "length";
  let currentPos = prompt.length;
  const { maxTokens } = normalizedConfig;

  for (let i = 0; i < maxTokens; i++) {
    const nextToken = sampleNextToken(result.logits, normalizedConfig, generatedTokens);
    generatedTokens.push(nextToken);
    if (nextToken === normalizedConfig.eosTokenId) {
      finishReason = "eos";
      break;
    }
    if (checkStopSequences(generatedTokens, normalizedConfig.stopSequences)) {
      finishReason = "stop";
      break;
    }
    if (i === maxTokens - 1) {
      // Length limit reached: the next logits would be discarded.
      break;
    }
    result = engine.forward(new Uint32Array([nextToken]), currentPos);
    currentPos += 1;
  }

  const totalTimeMs = performance.now() - startTime;
  return {
    tokens: generatedTokens,
    finishReason,
    promptTokens: prompt.length,
    generatedTokens: generatedTokens.length,
    totalTimeMs,
    // Guard against a zero elapsed time producing NaN/Infinity.
    tokensPerSecond: totalTimeMs > 0 ? generatedTokens.length / totalTimeMs * 1000 : 0
  };
}
|
|
4068
|
+
/**
 * Streaming variant of generation: yields one record per sampled token.
 *
 * The KV cache is reset and the prompt is run once. Each step samples a
 * token, yields `{ tokenId, index, isLast, finishReason }`, and, unless that
 * token ended the run, feeds it back through the engine and then yields to
 * the event loop via a zero-delay timer.
 *
 * `finishReason` is `"eos"`, `"stop"` or `"length"` on the final record and
 * `undefined` on all earlier ones.
 *
 * @param {object} engine Object exposing `resetKVCache()` and
 *   `forward(tokens, startPos)` returning `{ logits }`.
 * @param {Uint32Array|ArrayLike<number>} promptTokens Prompt token ids.
 * @param {object} [config={}] Generation options, passed through
 *   `normalizeGenerationConfig`.
 * @yields {{tokenId: *, index: number, isLast: boolean,
 *   finishReason: (string|undefined)}}
 */
async function* generateStream(engine, promptTokens, config = {}) {
  const settings = normalizeGenerationConfig(config);
  engine.resetKVCache();
  const prompt = promptTokens instanceof Uint32Array ? promptTokens : new Uint32Array(promptTokens);
  let step = engine.forward(prompt, 0);
  const history = [];
  let position = prompt.length;

  for (let index = 0; index < settings.maxTokens; index++) {
    const tokenId = sampleNextToken(step.logits, settings, history);
    history.push(tokenId);

    // Checks run in priority order; later ones are skipped once one matches.
    let finishReason;
    if (tokenId === settings.eosTokenId) {
      finishReason = "eos";
    } else if (checkStopSequences(history, settings.stopSequences)) {
      finishReason = "stop";
    } else if (index === settings.maxTokens - 1) {
      finishReason = "length";
    }
    const isLast = finishReason !== undefined;

    yield { tokenId, index, isLast, finishReason };
    if (isLast) {
      return;
    }

    step = engine.forward(new Uint32Array([tokenId]), position);
    position += 1;
    // Let other tasks run between tokens.
    await new Promise((resolve) => setTimeout(resolve, 0));
  }
}
|
|
4105
|
+
/**
 * Greedy (argmax) decoding: repeatedly picks the highest-logit token.
 *
 * The KV cache is reset and the prompt is run once. Each step takes the index
 * of the largest logit (first index wins on ties), appends it, and stops when
 * that token equals `eosTokenId` or `maxTokens` tokens have been produced.
 * The EOS token, when reached, is included in the result.
 *
 * No forward pass is run for the final token at the length limit, because its
 * logits would never be read. That saves a full model pass and avoids
 * touching one cache position more than the run needs.
 *
 * @param {object} engine Object exposing `resetKVCache()` and
 *   `forward(tokens, startPos)` returning `{ logits }`.
 * @param {Uint32Array|ArrayLike<number>} promptTokens Prompt token ids.
 * @param {number} maxTokens Maximum number of tokens to generate.
 * @param {number} [eosTokenId=2] Token id that terminates generation.
 * @returns {number[]} The generated token ids.
 */
function greedyDecode(engine, promptTokens, maxTokens, eosTokenId = 2) {
  engine.resetKVCache();
  const prompt = promptTokens instanceof Uint32Array ? promptTokens : new Uint32Array(promptTokens);
  let result = engine.forward(prompt, 0);
  const generatedTokens = [];
  let currentPos = prompt.length;

  for (let i = 0; i < maxTokens; i++) {
    const logits = result.logits;
    let maxIdx = 0;
    let maxVal = logits[0];
    for (let j = 1; j < logits.length; j++) {
      if (logits[j] > maxVal) {
        maxVal = logits[j];
        maxIdx = j;
      }
    }
    generatedTokens.push(maxIdx);

    // Stop on EOS, or at the length limit where further logits are unused.
    if (maxIdx === eosTokenId || i === maxTokens - 1) {
      break;
    }
    result = engine.forward(new Uint32Array([maxIdx]), currentPos);
    currentPos += 1;
  }
  return generatedTokens;
}
|
|
4130
|
+
export {
|
|
4131
|
+
transposeCPU,
|
|
4132
|
+
transpose2DCPU,
|
|
4133
|
+
transpose2D,
|
|
4134
|
+
topPFilterCPU,
|
|
4135
|
+
topPFilter,
|
|
4136
|
+
topKFilter,
|
|
4137
|
+
topKCPU,
|
|
4138
|
+
topK,
|
|
4139
|
+
softmaxGPU,
|
|
4140
|
+
softmaxCPU,
|
|
4141
|
+
softmax,
|
|
4142
|
+
siluCPU,
|
|
4143
|
+
silu,
|
|
4144
|
+
sigmoidCPU,
|
|
4145
|
+
scaleCPU,
|
|
4146
|
+
scale,
|
|
4147
|
+
sampleNextToken,
|
|
4148
|
+
sampleGreedy,
|
|
4149
|
+
sampleFromProbs,
|
|
4150
|
+
sampleCPU,
|
|
4151
|
+
sample,
|
|
4152
|
+
ropeCPU,
|
|
4153
|
+
rope,
|
|
4154
|
+
rmsNormCPU,
|
|
4155
|
+
rmsNorm,
|
|
4156
|
+
reshapeCPU,
|
|
4157
|
+
reluCPU,
|
|
4158
|
+
relu,
|
|
4159
|
+
quantizeToInt8,
|
|
4160
|
+
quantizeToInt4,
|
|
4161
|
+
quantizationError,
|
|
4162
|
+
qmatmulInt8CPU,
|
|
4163
|
+
qmatmulInt8BlockCPU,
|
|
4164
|
+
qmatmulInt4CPU,
|
|
4165
|
+
permuteCPU,
|
|
4166
|
+
parseSafetensorsHeader,
|
|
4167
|
+
parseGGUFHeader,
|
|
4168
|
+
normalizeGenerationConfig,
|
|
4169
|
+
mulCPU,
|
|
4170
|
+
mul,
|
|
4171
|
+
matmulCPU,
|
|
4172
|
+
matmul,
|
|
4173
|
+
logSoftmaxCPU,
|
|
4174
|
+
loadSafetensorsFromUrl,
|
|
4175
|
+
loadSafetensors,
|
|
4176
|
+
loadModel,
|
|
4177
|
+
loadGGUFTensor,
|
|
4178
|
+
loadGGUFFromUrl,
|
|
4179
|
+
loadGGUF,
|
|
4180
|
+
layerNormCPU,
|
|
4181
|
+
layerNorm,
|
|
4182
|
+
isSafetensors,
|
|
4183
|
+
isGGUF,
|
|
4184
|
+
greedyDecode,
|
|
4185
|
+
getSparsityRatio,
|
|
4186
|
+
getSlidingWindowSparsity,
|
|
4187
|
+
getMemorySavings,
|
|
4188
|
+
getMatMulCacheStats,
|
|
4189
|
+
getCausalSparsity,
|
|
4190
|
+
generateStream,
|
|
4191
|
+
generate,
|
|
4192
|
+
geluExactCPU,
|
|
4193
|
+
geluCPU,
|
|
4194
|
+
gelu,
|
|
4195
|
+
fmaCPU,
|
|
4196
|
+
flashAttention,
|
|
4197
|
+
estimateQMatMulFlops,
|
|
4198
|
+
estimateQMatMulBandwidth,
|
|
4199
|
+
estimateMemorySavings,
|
|
4200
|
+
embeddingCPU,
|
|
4201
|
+
embedding,
|
|
4202
|
+
dequantizeInt8,
|
|
4203
|
+
dequantizeInt4,
|
|
4204
|
+
computeRoPEFrequencies,
|
|
4205
|
+
buildSlidingWindowMask,
|
|
4206
|
+
buildCausalSlidingWindowMask,
|
|
4207
|
+
buildCausalMask,
|
|
4208
|
+
buildBlockSparseCSR,
|
|
4209
|
+
batchedEmbeddingCPU,
|
|
4210
|
+
attentionCPU,
|
|
4211
|
+
applyRepetitionPenalty,
|
|
4212
|
+
addScalarCPU,
|
|
4213
|
+
addCPU,
|
|
4214
|
+
add,
|
|
4215
|
+
WebInferDevice,
|
|
4216
|
+
WGSLCompiler,
|
|
4217
|
+
Tensor,
|
|
4218
|
+
PagedKVCache,
|
|
4219
|
+
KernelCache,
|
|
4220
|
+
InferenceEngine,
|
|
4221
|
+
GGUFQuantType,
|
|
4222
|
+
GGUFMetadataValueType,
|
|
4223
|
+
DEFAULT_GENERATION_CONFIG,
|
|
4224
|
+
ContinuousBatchScheduler,
|
|
4225
|
+
BufferPool,
|
|
4226
|
+
BlockManager,
|
|
4227
|
+
AttentionScheduler
|
|
4228
|
+
};
|