webinfer 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. package/LICENSE +201 -0
  2. package/dist/attention/block-sparse/format.d.ts +52 -0
  3. package/dist/attention/block-sparse/patterns/causal.d.ts +16 -0
  4. package/dist/attention/block-sparse/patterns/sliding.d.ts +22 -0
  5. package/dist/attention/flash-attention.d.ts +30 -0
  6. package/dist/attention/index.d.ts +9 -0
  7. package/dist/attention/paged-kv/block-manager.d.ts +102 -0
  8. package/dist/attention/paged-kv/index.d.ts +5 -0
  9. package/dist/attention/paged-kv/page-table.d.ts +99 -0
  10. package/dist/attention/scheduler.d.ts +40 -0
  11. package/dist/core/buffer-pool.d.ts +18 -0
  12. package/dist/core/device.d.ts +23 -0
  13. package/dist/core/tensor.d.ts +25 -0
  14. package/dist/index.d.ts +22 -0
  15. package/dist/index.js +4228 -0
  16. package/dist/inference/engine.d.ts +69 -0
  17. package/dist/inference/generate.d.ts +30 -0
  18. package/dist/inference/index.d.ts +7 -0
  19. package/dist/inference/types.d.ts +161 -0
  20. package/dist/jit/compiler.d.ts +23 -0
  21. package/dist/jit/kernel-cache.d.ts +21 -0
  22. package/dist/model/gguf.d.ts +90 -0
  23. package/dist/model/index.d.ts +16 -0
  24. package/dist/model/safetensors.d.ts +38 -0
  25. package/dist/model/types.d.ts +182 -0
  26. package/dist/ops/activations.d.ts +43 -0
  27. package/dist/ops/elementwise.d.ts +38 -0
  28. package/dist/ops/embedding.d.ts +30 -0
  29. package/dist/ops/matmul.d.ts +21 -0
  30. package/dist/ops/normalization.d.ts +24 -0
  31. package/dist/ops/reshape.d.ts +39 -0
  32. package/dist/ops/rope.d.ts +32 -0
  33. package/dist/ops/softmax.d.ts +18 -0
  34. package/dist/quantization/index.d.ts +6 -0
  35. package/dist/quantization/qmatmul.d.ts +38 -0
  36. package/dist/quantization/quantize.d.ts +52 -0
  37. package/dist/sampling/index.d.ts +6 -0
  38. package/dist/sampling/sampler.d.ts +39 -0
  39. package/dist/sampling/top-k.d.ts +24 -0
  40. package/dist/sampling/top-p.d.ts +14 -0
  41. package/package.json +54 -0
package/dist/index.js ADDED
@@ -0,0 +1,4228 @@
1
+ // src/core/device.ts
2
+ class WebInferDevice {
3
+ _device;
4
+ _info;
5
+ constructor(device, info) {
6
+ this._device = device;
7
+ this._info = info;
8
+ }
9
+ static async create() {
10
+ if (!navigator.gpu) {
11
+ throw new Error("WebGPU not supported in this browser");
12
+ }
13
+ const adapter = await navigator.gpu.requestAdapter({
14
+ powerPreference: "high-performance"
15
+ });
16
+ if (!adapter) {
17
+ throw new Error("No WebGPU adapter found");
18
+ }
19
+ const device = await adapter.requestDevice({
20
+ requiredLimits: {
21
+ maxStorageBufferBindingSize: adapter.limits.maxStorageBufferBindingSize,
22
+ maxBufferSize: adapter.limits.maxBufferSize,
23
+ maxComputeWorkgroupStorageSize: adapter.limits.maxComputeWorkgroupStorageSize,
24
+ maxComputeInvocationsPerWorkgroup: adapter.limits.maxComputeInvocationsPerWorkgroup
25
+ }
26
+ });
27
+ device.lost.then((info2) => {
28
+ console.error("WebGPU device lost:", info2.message);
29
+ });
30
+ const info = WebInferDevice.detectDeviceInfo(adapter, device);
31
+ return new WebInferDevice(device, info);
32
+ }
33
+ static detectDeviceInfo(adapter, device) {
34
+ const adapterInfo = adapter.info;
35
+ const vendorLower = (adapterInfo.vendor || "").toLowerCase();
36
+ const architectureLower = (adapterInfo.architecture || "").toLowerCase();
37
+ let vendor = "unknown";
38
+ if (vendorLower.includes("apple") || architectureLower.includes("apple")) {
39
+ vendor = "apple";
40
+ } else if (vendorLower.includes("nvidia") || architectureLower.includes("nvidia")) {
41
+ vendor = "nvidia";
42
+ } else if (vendorLower.includes("intel") || architectureLower.includes("intel")) {
43
+ vendor = "intel";
44
+ } else if (vendorLower.includes("amd") || vendorLower.includes("advanced micro")) {
45
+ vendor = "amd";
46
+ }
47
+ return {
48
+ vendor,
49
+ architecture: adapterInfo.architecture || "unknown",
50
+ maxWorkgroupSize: device.limits.maxComputeWorkgroupSizeX,
51
+ maxComputeInvocationsPerWorkgroup: device.limits.maxComputeInvocationsPerWorkgroup,
52
+ maxStorageBufferBindingSize: device.limits.maxStorageBufferBindingSize
53
+ };
54
+ }
55
+ get device() {
56
+ return this._device;
57
+ }
58
+ get info() {
59
+ return this._info;
60
+ }
61
+ get limits() {
62
+ return this._device.limits;
63
+ }
64
+ createCommandEncoder() {
65
+ return this._device.createCommandEncoder();
66
+ }
67
+ submit(commandBuffers) {
68
+ this._device.queue.submit(commandBuffers);
69
+ }
70
+ dispose() {
71
+ this._device.destroy();
72
+ }
73
+ }
74
// src/core/tensor.ts

// Bytes per element for each supported dtype.
var DTYPE_BYTES = {
  f32: 4,
  f16: 2,
  i32: 4,
  u32: 4
};

// TypedArray constructor used to view host data for each dtype.
// f16 has no native TypedArray, so raw 16-bit words (Uint16Array) are used.
var DTYPE_ARRAY = {
  f32: Float32Array,
  f16: Uint16Array,
  i32: Int32Array,
  u32: Uint32Array
};

/**
 * Immutable-shape tensor backed by a GPU storage buffer.
 * Buffer sizes are rounded up to 16-byte multiples for WebGPU binding rules.
 *
 * Bug fix: uploads (constructor) and downloads (toArray) previously always
 * viewed the buffer as Float32Array regardless of dtype, which corrupted
 * i32/u32 data and mis-sized f16 transfers. Views are now chosen per dtype.
 */
class Tensor {
  _device;
  _shape;
  _dtype;
  _buffer;
  _disposed = false;
  /**
   * @param {WebInferDevice} device - device wrapper (exposes .device)
   * @param {number[]} shape - tensor dimensions
   * @param {string} [dtype="f32"] - "f32" | "f16" | "i32" | "u32"
   * @param {ArrayLike<number>} [data] - optional initial contents, uploaded
   *   through mappedAtCreation
   */
  constructor(device, shape, dtype = "f32", data) {
    this._device = device;
    this._shape = Object.freeze([...shape]);
    this._dtype = dtype;
    const byteSize = this.numel * DTYPE_BYTES[dtype];
    const alignedSize = Math.ceil(byteSize / 16) * 16;
    this._buffer = device.device.createBuffer({
      size: alignedSize,
      usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
      mappedAtCreation: !!data
    });
    if (data) {
      // View the mapped range with the dtype's matching TypedArray so the
      // element values (not their reinterpreted bits) land in the buffer.
      const ViewCtor = DTYPE_ARRAY[dtype];
      const mapped = new ViewCtor(this._buffer.getMappedRange());
      mapped.set(data);
      this._buffer.unmap();
    }
  }
  static async fromArray(device, shape, data, dtype = "f32") {
    return new Tensor(device, shape, dtype, data);
  }
  /** Allocates a zero-filled tensor of the given shape and dtype. */
  static zeros(device, shape, dtype = "f32") {
    const numel = shape.reduce((a, b) => a * b, 1);
    // Use the dtype's own array so the upload view matches the byte size
    // (the old Float32Array was too large for f16 tensors).
    const data = new DTYPE_ARRAY[dtype](numel);
    return new Tensor(device, shape, dtype, data);
  }
  /** Uniform [0,1) random values; only meaningful for dtype "f32". */
  static rand(device, shape, dtype = "f32") {
    const numel = shape.reduce((a, b) => a * b, 1);
    const data = new Float32Array(numel);
    for (let i = 0; i < numel; i++) {
      data[i] = Math.random();
    }
    return new Tensor(device, shape, dtype, data);
  }
  get shape() {
    return this._shape;
  }
  get dtype() {
    return this._dtype;
  }
  get numel() {
    return this._shape.reduce((a, b) => a * b, 1);
  }
  get byteSize() {
    return this.numel * DTYPE_BYTES[this._dtype];
  }
  get buffer() {
    if (this._disposed) {
      throw new Error("Tensor has been disposed");
    }
    return this._buffer;
  }
  get device() {
    return this._device;
  }
  /**
   * Copies the buffer into a staging buffer and reads it back.
   * @returns {Promise<Float32Array|Int32Array|Uint32Array|Uint16Array>}
   *   values viewed with the dtype's TypedArray (f16 returns raw u16 words)
   * @throws {Error} if the tensor was disposed
   */
  async toArray() {
    if (this._disposed) {
      throw new Error("Tensor has been disposed");
    }
    const byteSize = this.byteSize;
    const alignedSize = Math.ceil(byteSize / 16) * 16;
    const stagingBuffer = this._device.device.createBuffer({
      size: alignedSize,
      usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
    });
    const encoder = this._device.createCommandEncoder();
    encoder.copyBufferToBuffer(this._buffer, 0, stagingBuffer, 0, alignedSize);
    this._device.submit([encoder.finish()]);
    await stagingBuffer.mapAsync(GPUMapMode.READ);
    // Read back with the dtype's TypedArray instead of always Float32Array.
    const ViewCtor = DTYPE_ARRAY[this._dtype];
    const data = new ViewCtor(stagingBuffer.getMappedRange().slice(0));
    stagingBuffer.unmap();
    stagingBuffer.destroy();
    return data.slice(0, this.numel);
  }
  /**
   * Returns a view with a new shape sharing this tensor's buffer.
   * NOTE: the view aliases the buffer — disposing either the view or the
   * source destroys the shared GPUBuffer for both.
   * @throws {Error} if element counts differ
   */
  reshape(newShape) {
    const newNumel = newShape.reduce((a, b) => a * b, 1);
    if (newNumel !== this.numel) {
      throw new Error(`Cannot reshape tensor of size ${this.numel} to shape [${newShape}]`);
    }
    const view = Object.create(Tensor.prototype);
    view._device = this._device;
    view._shape = Object.freeze([...newShape]);
    view._dtype = this._dtype;
    view._buffer = this._buffer;
    view._disposed = false;
    return view;
  }
  /** Destroys the underlying GPUBuffer; further buffer access throws. */
  dispose() {
    if (!this._disposed) {
      this._buffer.destroy();
      this._disposed = true;
    }
  }
}
181
// src/core/buffer-pool.ts

/**
 * Size-class based GPUBuffer pool. Buffers are bucketed into power-of-two
 * size classes (256 B .. 1 GiB) and reused across acquire/release cycles.
 */
class BufferPool {
  device;
  pools = new Map;
  sizeClasses;
  constructor(device) {
    this.device = device;
    // Precompute the power-of-two size ladder once.
    const classes = [];
    for (let size = 256; size <= 1024 * 1024 * 1024; size *= 2) {
      classes.push(size);
    }
    this.sizeClasses = classes;
  }
  /** Smallest precomputed class >= size, or the next power of two beyond 1 GiB. */
  getSizeClass(size) {
    const found = this.sizeClasses.find((sizeClass) => sizeClass >= size);
    if (found !== undefined) {
      return found;
    }
    return Math.pow(2, Math.ceil(Math.log2(size)));
  }
  /**
   * Returns a free pooled buffer whose usage flags cover `usage`, or creates
   * a new buffer in the matching size class (COPY_SRC/COPY_DST always added).
   */
  acquire(size, usage) {
    const sizeClass = this.getSizeClass(size);
    const bucket = this.pools.get(sizeClass);
    if (bucket) {
      const reusable = bucket.find(
        (entry) => !entry.inUse && (entry.buffer.usage & usage) === usage
      );
      if (reusable) {
        reusable.inUse = true;
        return reusable.buffer;
      }
    }
    const buffer = this.device.createBuffer({
      size: sizeClass,
      usage: usage | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
    });
    const entry = { buffer, size: sizeClass, inUse: true };
    if (bucket) {
      bucket.push(entry);
    } else {
      this.pools.set(sizeClass, [entry]);
    }
    return buffer;
  }
  /** Marks a previously acquired buffer as free; unknown buffers are ignored. */
  release(buffer) {
    for (const bucket of this.pools.values()) {
      const entry = bucket.find((candidate) => candidate.buffer === buffer);
      if (entry) {
        entry.inUse = false;
        return;
      }
    }
  }
  /** Aggregate counters across all size classes. */
  getStats() {
    let totalBuffers = 0;
    let inUse = 0;
    let totalBytes = 0;
    for (const bucket of this.pools.values()) {
      for (const entry of bucket) {
        totalBuffers += 1;
        totalBytes += entry.size;
        if (entry.inUse) {
          inUse += 1;
        }
      }
    }
    return { totalBuffers, inUse, totalBytes };
  }
  /** Destroys every pooled buffer and empties the pool. */
  dispose() {
    for (const bucket of this.pools.values()) {
      for (const entry of bucket) {
        entry.buffer.destroy();
      }
    }
    this.pools.clear();
  }
}
260
// src/jit/kernel-cache.ts

/**
 * Keyed cache of compiled compute pipelines with hit/miss accounting.
 */
class KernelCache {
  device;
  cache = new Map;
  hits = 0;
  misses = 0;
  constructor(device) {
    this.device = device;
  }
  /** Returns the cached pipeline for `key`, compiling via `createFn` on miss. */
  getOrCreate(key, createFn) {
    const cached = this.cache.get(key);
    if (cached) {
      this.hits += 1;
      return cached;
    }
    this.misses += 1;
    const compiled = createFn();
    this.cache.set(key, compiled);
    return compiled;
  }
  has(key) {
    return this.cache.has(key);
  }
  /** Lookup without compiling; counts a hit only when found. */
  get(key) {
    const found = this.cache.get(key);
    if (found) {
      this.hits += 1;
    }
    return found;
  }
  set(key, pipeline) {
    this.cache.set(key, pipeline);
  }
  /** @returns {{hits: number, misses: number, size: number}} */
  getStats() {
    return { hits: this.hits, misses: this.misses, size: this.cache.size };
  }
  /** Empties the cache and resets the counters. */
  clear() {
    this.cache.clear();
    this.hits = 0;
    this.misses = 0;
  }
}
305
// src/jit/compiler.ts

// Generates and caches WGSL compute pipelines, specialised per GPU vendor.
class WGSLCompiler {
  device;      // GPUDevice used to create shader modules / pipelines
  cache;       // KernelCache shared with callers
  deviceInfo;  // vendor/limits snapshot from WebInferDevice.detectDeviceInfo
  constructor(device, cache, deviceInfo) {
    this.device = device;
    this.cache = cache;
    this.deviceInfo = deviceInfo;
  }
  // Per-vendor tile-size heuristic for the tiled matmul kernel. `config`
  // (M/N/K) is currently ignored.
  // NOTE(review): 32x32 tiles imply a 1024-invocation workgroup; this assumes
  // the device granted maxComputeInvocationsPerWorkgroup >= 1024 — confirm.
  selectTileSize(config) {
    if (this.deviceInfo.vendor === "apple") {
      return { tileM: 16, tileN: 16, tileK: 16 };
    } else if (this.deviceInfo.vendor === "nvidia") {
      return { tileM: 32, tileN: 32, tileK: 16 };
    }
    return { tileM: 16, tileN: 16, tileK: 16 };
  }
  // Compiles (or fetches from cache) a matmul pipeline for the given
  // M/N/K, keyed by problem size and tile sizes. Explicit tileM/N/K in
  // `config` override the vendor heuristic.
  compileMatMul(config) {
    const tiles = this.selectTileSize(config);
    const tileM = config.tileM ?? tiles.tileM;
    const tileN = config.tileN ?? tiles.tileN;
    const tileK = config.tileK ?? tiles.tileK;
    const key = `matmul_${config.M}_${config.N}_${config.K}_${tileM}_${tileN}_${tileK}`;
    return this.cache.getOrCreate(key, () => {
      const wgsl = this.generateMatMulWGSL(config.M, config.N, config.K, tileM, tileN, tileK);
      const shaderModule = this.device.createShaderModule({
        code: wgsl
      });
      return this.device.createComputePipeline({
        layout: "auto",
        compute: {
          module: shaderModule,
          entryPoint: "main"
        }
      });
    });
  }
  // Emits the tiled matmul shader. One workgroup computes a tileM x tileN
  // output tile, staging tileM x tileK of A and tileK x tileN of B in
  // workgroup memory per K-step.
  // NOTE(review): the B-tile load indexes tileB[localRow * tileN + localCol]
  // with localRow in [0, tileM); when tileM > tileK (nvidia 32x32x16) this
  // runs past tileB's tileK*tileN elements — the load should likely be
  // guarded by localRow < tileK. Confirm before enabling 32x32 tiles.
  generateMatMulWGSL(M, N, K, tileM, tileN, tileK) {
    const workgroupSizeX = tileN;
    const workgroupSizeY = tileM;
    return `
// WebInfer MatMul Kernel
// C[M,N] = A[M,K] @ B[K,N]
// Tile size: ${tileM}x${tileN}x${tileK}

struct Params {
  M: u32,
  N: u32,
  K: u32,
}

@group(0) @binding(0) var<storage, read> A: array<f32>;
@group(0) @binding(1) var<storage, read> B: array<f32>;
@group(0) @binding(2) var<storage, read_write> C: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

var<workgroup> tileA: array<f32, ${tileM * tileK}>;
var<workgroup> tileB: array<f32, ${tileK * tileN}>;

@compute @workgroup_size(${workgroupSizeX}, ${workgroupSizeY})
fn main(
  @builtin(global_invocation_id) global_id: vec3<u32>,
  @builtin(local_invocation_id) local_id: vec3<u32>,
  @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
  let row = workgroup_id.y * ${tileM}u + local_id.y;
  let col = workgroup_id.x * ${tileN}u + local_id.x;

  let localRow = local_id.y;
  let localCol = local_id.x;

  var sum: f32 = 0.0;

  let numTiles = (params.K + ${tileK}u - 1u) / ${tileK}u;

  for (var t: u32 = 0u; t < numTiles; t = t + 1u) {
    // Load tile of A into shared memory
    let aRow = row;
    let aCol = t * ${tileK}u + localCol;
    if (aRow < params.M && aCol < params.K) {
      tileA[localRow * ${tileK}u + localCol] = A[aRow * params.K + aCol];
    } else {
      tileA[localRow * ${tileK}u + localCol] = 0.0;
    }

    // Load tile of B into shared memory
    let bRow = t * ${tileK}u + localRow;
    let bCol = col;
    if (bRow < params.K && bCol < params.N) {
      tileB[localRow * ${tileN}u + localCol] = B[bRow * params.N + bCol];
    } else {
      tileB[localRow * ${tileN}u + localCol] = 0.0;
    }

    workgroupBarrier();

    // Compute partial dot product
    for (var k: u32 = 0u; k < ${tileK}u; k = k + 1u) {
      sum = sum + tileA[localRow * ${tileK}u + k] * tileB[k * ${tileN}u + localCol];
    }

    workgroupBarrier();
  }

  // Write result
  if (row < params.M && col < params.N) {
    C[row * params.N + col] = sum;
  }
}
`;
  }
  // Exposes the underlying KernelCache hit/miss counters.
  getCacheStats() {
    return this.cache.getStats();
  }
}
421
// src/ops/matmul.ts
var compilerInstance = null;
var cacheInstance = null;

/**
 * Lazily builds the shared KernelCache + WGSLCompiler pair for matmul.
 * Subsequent calls return the same compiler (module-level singleton).
 */
function getCompiler(device) {
  const needsInit = compilerInstance === null || cacheInstance === null;
  if (needsInit) {
    cacheInstance = new KernelCache(device.device);
    compilerInstance = new WGSLCompiler(device.device, cacheInstance, device.info);
  }
  return compilerInstance;
}
431
/**
 * GPU matrix multiply: C[M,N] = A[M,K] @ B[K,N].
 *
 * Fix: the dispatch previously assumed a hard-coded 16x16 tile, but the
 * compiler selects per-vendor tile sizes (e.g. 32x32 on NVIDIA), so the
 * workgroup grid did not match the compiled kernel. Workgroup counts are now
 * derived from the same tile selection the kernel was compiled with.
 *
 * @param {WebInferDevice} device - device wrapper (.device / .submit)
 * @param {Tensor} a - 2D tensor [M, K]
 * @param {Tensor} b - 2D tensor [K, N]
 * @returns {Promise<Tensor>} newly allocated [M, N] result tensor
 * @throws {Error} if inputs are not 2D or inner dimensions mismatch
 */
async function matmul(device, a, b) {
  if (a.shape.length !== 2 || b.shape.length !== 2) {
    throw new Error("matmul requires 2D tensors");
  }
  const [M, K1] = a.shape;
  const [K2, N] = b.shape;
  if (K1 !== K2) {
    throw new Error(`matmul shape mismatch: [${M},${K1}] @ [${K2},${N}] - inner dimensions must match`);
  }
  const K = K1;
  const c = Tensor.zeros(device, [M, N]);
  const compiler = getCompiler(device);
  const pipeline = compiler.compileMatMul({ M, N, K });
  // Uniform buffer holding M, N, K (size padded to 16 bytes for WebGPU).
  const paramsBuffer = device.device.createBuffer({
    size: 16,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
  });
  device.device.queue.writeBuffer(paramsBuffer, 0, new Uint32Array([M, N, K]));
  const bindGroup = device.device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: a.buffer } },
      { binding: 1, resource: { buffer: b.buffer } },
      { binding: 2, resource: { buffer: c.buffer } },
      { binding: 3, resource: { buffer: paramsBuffer } }
    ]
  });
  const encoder = device.createCommandEncoder();
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  // Use the same per-vendor tile sizes compileMatMul used (it consults
  // selectTileSize with this config) so the grid exactly covers the output.
  const { tileM, tileN } = compiler.selectTileSize({ M, N, K });
  const workgroupsX = Math.ceil(N / tileN);
  const workgroupsY = Math.ceil(M / tileM);
  pass.dispatchWorkgroups(workgroupsX, workgroupsY);
  pass.end();
  device.submit([encoder.finish()]);
  await device.device.queue.onSubmittedWorkDone();
  paramsBuffer.destroy();
  return c;
}
472
/**
 * Reference CPU implementation of C[M,N] = A[M,K] @ B[K,N] on flat
 * row-major arrays. Used to validate the GPU path.
 * @returns {Float32Array} row-major result of length M*N
 */
function matmulCPU(a, b, M, N, K) {
  const out = new Float32Array(M * N);
  for (let row = 0; row < M; row++) {
    const aBase = row * K;
    for (let col = 0; col < N; col++) {
      let acc = 0;
      for (let k = 0; k < K; k++) {
        acc += a[aBase + k] * b[k * N + col];
      }
      out[row * N + col] = acc;
    }
  }
  return out;
}
485
/** Exposes kernel-cache hit/miss stats from the shared matmul compiler. */
function getMatMulCacheStats(device) {
  return getCompiler(device).getCacheStats();
}
489
// src/ops/normalization.ts
var kernelCache = null;

/** Lazily creates the module-level KernelCache for normalization kernels. */
function getCache(device) {
  if (kernelCache === null) {
    kernelCache = new KernelCache(device);
  }
  return kernelCache;
}
497
/**
 * CPU layer normalization over the last axis of a flat row-major array:
 * y = (x - mean) / sqrt(var + eps) * weight (+ bias when provided).
 * @param {Float32Array} x - input, length divisible by shape's last dim
 * @param {Float32Array} weight - per-feature scale, length = last dim
 * @param {Float32Array|null} bias - optional per-feature shift
 * @param {number[]} shape - logical shape; only the last dim is used
 * @param {number} [eps=1e-5] - variance stabiliser
 * @returns {Float32Array} normalized output, same length as x
 */
function layerNormCPU(x, weight, bias, shape, eps = 0.00001) {
  const lastDim = shape[shape.length - 1];
  const rows = x.length / lastDim;
  const out = new Float32Array(x.length);
  for (let r = 0; r < rows; r++) {
    const base = r * lastDim;
    let sum = 0;
    for (let j = 0; j < lastDim; j++) {
      sum += x[base + j];
    }
    const mean = sum / lastDim;
    let sqDiff = 0;
    for (let j = 0; j < lastDim; j++) {
      const d = x[base + j] - mean;
      sqDiff += d * d;
    }
    const invStd = 1 / Math.sqrt(sqDiff / lastDim + eps);
    for (let j = 0; j < lastDim; j++) {
      const normalized = (x[base + j] - mean) * invStd;
      out[base + j] = normalized * weight[j] + (bias ? bias[j] : 0);
    }
  }
  return out;
}
522
/**
 * CPU RMSNorm over the last axis: y = x / sqrt(mean(x^2) + eps) * weight.
 * @param {Float32Array} x - input, length divisible by shape's last dim
 * @param {Float32Array} weight - per-feature scale, length = last dim
 * @param {number[]} shape - logical shape; only the last dim is used
 * @param {number} [eps=1e-5] - stabiliser added to the mean square
 * @returns {Float32Array} normalized output, same length as x
 */
function rmsNormCPU(x, weight, shape, eps = 0.00001) {
  const lastDim = shape[shape.length - 1];
  const rows = x.length / lastDim;
  const out = new Float32Array(x.length);
  for (let r = 0; r < rows; r++) {
    const base = r * lastDim;
    let sumSq = 0;
    for (let j = 0; j < lastDim; j++) {
      sumSq += x[base + j] * x[base + j];
    }
    const rms = Math.sqrt(sumSq / lastDim + eps);
    const invRms = 1 / rms;
    for (let j = 0; j < lastDim; j++) {
      out[base + j] = x[base + j] * invRms * weight[j];
    }
  }
  return out;
}
540
/**
 * GPU layer normalization over the last axis of `x`.
 * One workgroup handles one row; see compileLayerNormKernel for the WGSL.
 * @param {WebInferDevice} device
 * @param {Tensor} x - input tensor
 * @param {Tensor} weight - per-feature scale (length = last dim)
 * @param {Tensor|null} bias - optional per-feature shift
 * @param {number} [eps=1e-5]
 * @returns {Promise<Tensor>} newly allocated output tensor, same shape as x
 */
async function layerNorm(device, x, weight, bias, eps = 0.00001) {
  const lastDim = x.shape[x.shape.length - 1];
  const rowCount = x.numel / lastDim;
  // Kernel is specialised (and cached) per row length and bias presence.
  const pipeline = getCache(device.device).getOrCreate(
    `layernorm_${lastDim}_${bias !== null}`,
    () => compileLayerNormKernel(device.device, lastDim, bias !== null)
  );
  const output = Tensor.zeros(device, [...x.shape]);
  // Params are packed as f32s; the kernel converts dim back to u32.
  const params = new Float32Array([rowCount, lastDim, eps, 0]);
  const paramsBuffer = device.device.createBuffer({
    size: params.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
  });
  device.device.queue.writeBuffer(paramsBuffer, 0, params);
  // Binding 4 only exists in shaders compiled with a bias.
  const entries = [
    { binding: 0, resource: { buffer: x.buffer } },
    { binding: 1, resource: { buffer: weight.buffer } },
    { binding: 2, resource: { buffer: output.buffer } },
    { binding: 3, resource: { buffer: paramsBuffer } },
    ...(bias ? [{ binding: 4, resource: { buffer: bias.buffer } }] : [])
  ];
  const bindGroup = device.device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries
  });
  const encoder = device.createCommandEncoder();
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  pass.dispatchWorkgroups(rowCount);
  pass.end();
  device.submit([encoder.finish()]);
  await device.device.queue.onSubmittedWorkDone();
  paramsBuffer.destroy();
  return output;
}
576
/**
 * GPU RMSNorm over the last axis of `x` (one workgroup per row).
 * See compileRMSNormKernel for the reduction details.
 * @param {WebInferDevice} device
 * @param {Tensor} x - input tensor
 * @param {Tensor} weight - per-feature scale (length = last dim)
 * @param {number} [eps=1e-5]
 * @returns {Promise<Tensor>} newly allocated output tensor with x's shape
 */
async function rmsNorm(device, x, weight, eps = 0.00001) {
  const lastDim = x.shape[x.shape.length - 1];
  const rowCount = x.numel / lastDim;
  const pipeline = getCache(device.device).getOrCreate(
    `rmsnorm_${lastDim}`,
    () => compileRMSNormKernel(device.device, lastDim)
  );
  const output = Tensor.zeros(device, [...x.shape]);
  // Params are packed as f32s; the kernel converts dim back to u32.
  const params = new Float32Array([rowCount, lastDim, eps, 0]);
  const paramsBuffer = device.device.createBuffer({
    size: params.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
  });
  device.device.queue.writeBuffer(paramsBuffer, 0, params);
  const bindGroup = device.device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: x.buffer } },
      { binding: 1, resource: { buffer: weight.buffer } },
      { binding: 2, resource: { buffer: output.buffer } },
      { binding: 3, resource: { buffer: paramsBuffer } }
    ]
  });
  const encoder = device.createCommandEncoder();
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  pass.dispatchWorkgroups(rowCount);
  pass.end();
  device.submit([encoder.finish()]);
  await device.device.queue.onSubmittedWorkDone();
  paramsBuffer.destroy();
  return output;
}
608
/**
 * Compiles a WGSL layer-norm pipeline.
 * Layout: one 256-thread workgroup per row; three strided passes over the
 * row (mean, variance, normalize) with tree reductions in workgroup memory.
 * When `hasBias` is true the shader declares binding 4 and adds bias[i], so
 * the "auto" pipeline layout differs between the two variants — callers must
 * cache them under distinct keys.
 * Note: `dim` is only used for cache keying by callers; the shader reads the
 * row length from the uniform params at runtime.
 */
function compileLayerNormKernel(device, dim, hasBias) {
  const WORKGROUP_SIZE = 256;
  const wgsl = `
struct Params {
  outerSize: f32,
  dim: f32,
  eps: f32,
  _pad: f32,
}

@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read> weight: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
${hasBias ? "@group(0) @binding(4) var<storage, read> bias: array<f32>;" : ""}

var<workgroup> shared_sum: array<f32, ${WORKGROUP_SIZE}>;
var<workgroup> shared_mean: f32;

@compute @workgroup_size(${WORKGROUP_SIZE})
fn main(
  @builtin(local_invocation_id) lid: vec3<u32>,
  @builtin(workgroup_id) wgid: vec3<u32>
) {
  let row = wgid.x;
  let tid = lid.x;
  let dim = u32(params.dim);
  let offset = row * dim;

  // === Pass 1: Compute mean ===
  var partial_sum: f32 = 0.0;
  for (var i = tid; i < dim; i += ${WORKGROUP_SIZE}u) {
    partial_sum += x[offset + i];
  }
  shared_sum[tid] = partial_sum;
  workgroupBarrier();

  // Parallel reduction for sum
  for (var stride = ${WORKGROUP_SIZE / 2}u; stride > 0u; stride >>= 1u) {
    if (tid < stride) {
      shared_sum[tid] += shared_sum[tid + stride];
    }
    workgroupBarrier();
  }

  // Store mean for all threads to use
  if (tid == 0u) {
    shared_mean = shared_sum[0] / params.dim;
  }
  workgroupBarrier();
  let mean = shared_mean;

  // === Pass 2: Compute variance ===
  var partial_var: f32 = 0.0;
  for (var i = tid; i < dim; i += ${WORKGROUP_SIZE}u) {
    let diff = x[offset + i] - mean;
    partial_var += diff * diff;
  }
  shared_sum[tid] = partial_var;
  workgroupBarrier();

  // Parallel reduction for variance
  for (var stride = ${WORKGROUP_SIZE / 2}u; stride > 0u; stride >>= 1u) {
    if (tid < stride) {
      shared_sum[tid] += shared_sum[tid + stride];
    }
    workgroupBarrier();
  }

  // Compute inverse standard deviation
  let inv_std = 1.0 / sqrt(shared_sum[0] / params.dim + params.eps);

  // === Pass 3: Normalize and apply affine transform ===
  for (var i = tid; i < dim; i += ${WORKGROUP_SIZE}u) {
    let normalized = (x[offset + i] - mean) * inv_std;
    ${hasBias ? "output[offset + i] = normalized * weight[i] + bias[i];" : "output[offset + i] = normalized * weight[i];"}
  }
}
`;
  const shaderModule = device.createShaderModule({ code: wgsl });
  return device.createComputePipeline({
    layout: "auto",
    compute: { module: shaderModule, entryPoint: "main" }
  });
}
693
/**
 * Compiles a WGSL RMSNorm pipeline.
 * Layout: one 256-thread workgroup per row; a strided sum-of-squares pass,
 * a tree reduction in workgroup memory, then a strided normalize pass
 * (y = x * invRms * weight).
 * Note: `dim` is only used for cache keying by callers; the shader reads the
 * row length from the uniform params at runtime.
 */
function compileRMSNormKernel(device, dim) {
  const WORKGROUP_SIZE = 256;
  const wgsl = `
struct Params {
  outerSize: f32,
  dim: f32,
  eps: f32,
  _pad: f32,
}

@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read> weight: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

var<workgroup> shared_sum: array<f32, ${WORKGROUP_SIZE}>;

@compute @workgroup_size(${WORKGROUP_SIZE})
fn main(
  @builtin(local_invocation_id) lid: vec3<u32>,
  @builtin(workgroup_id) wgid: vec3<u32>
) {
  let row = wgid.x;
  let tid = lid.x;
  let dim = u32(params.dim);
  let offset = row * dim;

  // Each thread computes partial sum of squares
  var partial_sum: f32 = 0.0;
  for (var i = tid; i < dim; i += ${WORKGROUP_SIZE}u) {
    let val = x[offset + i];
    partial_sum += val * val;
  }
  shared_sum[tid] = partial_sum;
  workgroupBarrier();

  // Parallel reduction in shared memory
  for (var stride = ${WORKGROUP_SIZE / 2}u; stride > 0u; stride >>= 1u) {
    if (tid < stride) {
      shared_sum[tid] += shared_sum[tid + stride];
    }
    workgroupBarrier();
  }

  // Compute inverse RMS (thread 0 has the final sum)
  let inv_rms = 1.0 / sqrt(shared_sum[0] / params.dim + params.eps);

  // All threads normalize their portion
  for (var i = tid; i < dim; i += ${WORKGROUP_SIZE}u) {
    output[offset + i] = x[offset + i] * inv_rms * weight[i];
  }
}
`;
  const shaderModule = device.createShaderModule({ code: wgsl });
  return device.createComputePipeline({
    layout: "auto",
    compute: { module: shaderModule, entryPoint: "main" }
  });
}
752
// src/ops/rope.ts
var kernelCache2 = null;

/** Lazily creates the module-level KernelCache for RoPE kernels. */
function getCache2(device) {
  if (kernelCache2 === null) {
    kernelCache2 = new KernelCache(device);
  }
  return kernelCache2;
}
760
/**
 * Precomputes RoPE cos/sin tables.
 * invFreq[i] = base^(-2i/dim); entry (pos, i) holds cos/sin of
 * (pos / scaling) * invFreq[i].
 * @param {{dim: number, maxSeqLen: number, base?: number, scaling?: number}} config
 * @returns {{cos: Float32Array, sin: Float32Array}} tables of shape
 *   [maxSeqLen, dim/2], row-major
 */
function computeRoPEFrequencies(config) {
  const { dim, maxSeqLen, base = 1e4, scaling = 1 } = config;
  const halfDim = dim / 2;
  const invFreq = Float32Array.from(
    { length: halfDim },
    (_, i) => 1 / Math.pow(base, 2 * i / dim)
  );
  const cos = new Float32Array(maxSeqLen * halfDim);
  const sin = new Float32Array(maxSeqLen * halfDim);
  for (let pos = 0; pos < maxSeqLen; pos++) {
    const theta = pos / scaling;
    const rowBase = pos * halfDim;
    for (let i = 0; i < halfDim; i++) {
      const angle = theta * invFreq[i];
      cos[rowBase + i] = Math.cos(angle);
      sin[rowBase + i] = Math.sin(angle);
    }
  }
  return { cos, sin };
}
779
/**
 * CPU rotary position embedding (non-interleaved halves).
 * For each (token, head), rotates pairs (x[d], x[d + headDim/2]) by the
 * angle stored in the cos/sin tables for positions[s].
 * @returns {Float32Array} rotated copy of x
 */
function ropeCPU(x, positions, cos, sin, seqLen, numHeads, headDim) {
  const halfDim = headDim / 2;
  const out = new Float32Array(x.length);
  for (let s = 0; s < seqLen; s++) {
    const freqBase = positions[s] * halfDim;
    for (let h = 0; h < numHeads; h++) {
      const head = (s * numHeads + h) * headDim;
      for (let d = 0; d < halfDim; d++) {
        const re = x[head + d];
        const im = x[head + halfDim + d];
        const c = cos[freqBase + d];
        const sn = sin[freqBase + d];
        out[head + d] = re * c - im * sn;
        out[head + halfDim + d] = re * sn + im * c;
      }
    }
  }
  return out;
}
800
/**
 * Applies rotary position embeddings on the GPU.
 * x must be [seqLen, numHeads, headDim]; positions holds one absolute
 * position per token.
 * NOTE(review): the shader declares positions as array<u32>, so the
 * positions tensor must contain u32 data — confirm how callers build it.
 * The cos/sin tables are recomputed and re-uploaded on every call; only the
 * pipeline itself is cached (keyed by headDim and numHeads).
 * @returns {Promise<Tensor>} newly allocated rotated tensor, same shape as x
 * @throws {Error} if x is not 3D
 */
async function rope(device, x, positions, config) {
  if (x.shape.length !== 3) {
    throw new Error("RoPE input must be 3D [seqLen, numHeads, headDim]");
  }
  const [seqLen, numHeads, headDim] = x.shape;
  const { cos, sin } = computeRoPEFrequencies(config);
  const cache = getCache2(device.device);
  const pipeline = cache.getOrCreate(`rope_${headDim}_${numHeads}`, () => compileRoPEKernel(device.device, headDim, numHeads));
  const output = Tensor.zeros(device, [seqLen, numHeads, headDim]);
  const cosBuffer = device.device.createBuffer({
    size: cos.byteLength,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST
  });
  // cos is already a Float32Array; the extra wrap makes a redundant copy.
  device.device.queue.writeBuffer(cosBuffer, 0, new Float32Array(cos));
  const sinBuffer = device.device.createBuffer({
    size: sin.byteLength,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST
  });
  device.device.queue.writeBuffer(sinBuffer, 0, new Float32Array(sin));
  const params = new Uint32Array([seqLen, numHeads, headDim, headDim / 2]);
  const paramsBuffer = device.device.createBuffer({
    size: params.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
  });
  device.device.queue.writeBuffer(paramsBuffer, 0, params);
  const bindGroup = device.device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: x.buffer } },
      { binding: 1, resource: { buffer: positions.buffer } },
      { binding: 2, resource: { buffer: cosBuffer } },
      { binding: 3, resource: { buffer: sinBuffer } },
      { binding: 4, resource: { buffer: output.buffer } },
      { binding: 5, resource: { buffer: paramsBuffer } }
    ]
  });
  const encoder = device.createCommandEncoder();
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  // x: one 64-thread workgroup per 64 tokens; y: one workgroup per head.
  pass.dispatchWorkgroups(Math.ceil(seqLen / 64), numHeads);
  pass.end();
  device.submit([encoder.finish()]);
  await device.device.queue.onSubmittedWorkDone();
  cosBuffer.destroy();
  sinBuffer.destroy();
  paramsBuffer.destroy();
  return output;
}
849
/**
 * Builds the RoPE WGSL pipeline.
 * Dispatch layout: x = ceil(seqLen / 64) workgroups of 64 threads (one per
 * token), y = one workgroup per head. Each thread rotates the halfDim pairs
 * of one (token, head) slice using the precomputed cos/sin tables.
 * Fix: removed an unused JS local (`halfDim`) — the shader reads halfDim
 * from the uniform params instead.
 * @param {GPUDevice} device
 * @param {number} headDim - per-head dimension (used by callers as cache key)
 * @param {number} numHeads - head count (used by callers as cache key)
 * @returns {GPUComputePipeline}
 */
function compileRoPEKernel(device, headDim, numHeads) {
  const wgsl = `
struct Params {
  seqLen: u32,
  numHeads: u32,
  headDim: u32,
  halfDim: u32,
}

@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read> positions: array<u32>;
@group(0) @binding(2) var<storage, read> cos: array<f32>;
@group(0) @binding(3) var<storage, read> sin: array<f32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
@group(0) @binding(5) var<uniform> params: Params;

@compute @workgroup_size(64)
fn main(
  @builtin(global_invocation_id) gid: vec3<u32>,
  @builtin(workgroup_id) wgid: vec3<u32>
) {
  let seqIdx = gid.x;
  let headIdx = wgid.y;

  if (seqIdx >= params.seqLen) {
    return;
  }

  let pos = positions[seqIdx];
  let halfDim = params.halfDim;
  let headDim = params.headDim;
  let numHeads = params.numHeads;

  let baseIdx = seqIdx * numHeads * headDim + headIdx * headDim;
  let freqOffset = pos * halfDim;

  // Apply rotation to pairs
  for (var d = 0u; d < halfDim; d = d + 1u) {
    let x0 = x[baseIdx + d];
    let x1 = x[baseIdx + halfDim + d];
    let c = cos[freqOffset + d];
    let s = sin[freqOffset + d];

    output[baseIdx + d] = x0 * c - x1 * s;
    output[baseIdx + halfDim + d] = x0 * s + x1 * c;
  }
}
`;
  const shaderModule = device.createShaderModule({ code: wgsl });
  return device.createComputePipeline({
    layout: "auto",
    compute: {
      module: shaderModule,
      entryPoint: "main"
    }
  });
}
907
+ // src/ops/activations.ts
908
// Module-level singleton kernel cache for the activation ops.
var kernelCache3 = null;
// Returns the shared KernelCache, constructing it lazily on first use.
function getCache3(device) {
  kernelCache3 = kernelCache3 ?? new KernelCache(device);
  return kernelCache3;
}
915
// CPU GELU using the tanh approximation:
//   gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
function geluCPU(x) {
  const k = Math.sqrt(2 / Math.PI);
  const out = new Float32Array(x.length);
  for (let i = 0; i < x.length; i++) {
    const v = x[i];
    const t = Math.tanh(k * (v + 0.044715 * v * v * v));
    out[i] = v * 0.5 * (1 + t);
  }
  return out;
}
925
// CPU GELU in its exact form: x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2))),
// relying on the local erf() approximation below.
function geluExactCPU(x) {
  const sqrt2 = Math.sqrt(2);
  const out = new Float32Array(x.length);
  for (let i = 0; i < x.length; i++) {
    const v = x[i];
    out[i] = v * 0.5 * (1 + erf(v / sqrt2));
  }
  return out;
}
934
// Error function via the Abramowitz & Stegun 7.1.26 rational approximation
// (absolute error below ~1.5e-7). Odd symmetry is applied explicitly.
function erf(x) {
  const sign = x < 0 ? -1 : 1;
  const ax = Math.abs(x);
  const t = 1 / (1 + 0.3275911 * ax);
  const poly = (((1.061405429 * t - 1.453152027) * t + 1.421413741) * t - 0.284496736) * t + 0.254829592;
  return sign * (1 - poly * t * Math.exp(-ax * ax));
}
947
// CPU SiLU (swish): x * sigmoid(x), computed as x / (1 + e^(-x)).
function siluCPU(x) {
  const out = new Float32Array(x.length);
  for (let i = x.length - 1; i >= 0; i--) {
    const v = x[i];
    out[i] = v / (1 + Math.exp(-v));
  }
  return out;
}
955
// CPU element-wise ReLU: max(0, x).
function reluCPU(x) {
  const out = new Float32Array(x.length);
  x.forEach((v, i) => {
    out[i] = Math.max(0, v);
  });
  return out;
}
962
// CPU element-wise logistic sigmoid: 1 / (1 + e^(-x)).
function sigmoidCPU(x) {
  const out = new Float32Array(x.length);
  for (let i = 0; i < x.length; i += 1) {
    out[i] = 1 / (1 + Math.exp(-x[i]));
  }
  return out;
}
969
// GPU GELU: one thread per element, 256 threads per workgroup.
// Dispatches the cached "gelu" pipeline, waits for GPU completion, then frees
// the transient uniform buffer and returns the new output tensor.
async function gelu(device, x) {
  const cache = getCache3(device.device);
  const pipeline = cache.getOrCreate("gelu", () => compileGeluKernel(device.device));
  const output = Tensor.zeros(device, [...x.shape]);
  // Single-u32 uniform: total element count, used by the shader's bounds check.
  const params = new Uint32Array([x.numel]);
  const paramsBuffer = device.device.createBuffer({
    size: params.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
  });
  device.device.queue.writeBuffer(paramsBuffer, 0, params);
  const bindGroup = device.device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: x.buffer } },
      { binding: 1, resource: { buffer: output.buffer } },
      { binding: 2, resource: { buffer: paramsBuffer } }
    ]
  });
  const encoder = device.createCommandEncoder();
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  pass.dispatchWorkgroups(Math.ceil(x.numel / 256));
  pass.end();
  device.submit([encoder.finish()]);
  // Wait before destroying the uniform buffer the submitted pass still uses.
  await device.device.queue.onSubmittedWorkDone();
  paramsBuffer.destroy();
  return output;
}
998
// GPU SiLU: one thread per element, 256 threads per workgroup.
// Same dispatch/cleanup sequence as gelu(), using the cached "silu" pipeline.
async function silu(device, x) {
  const cache = getCache3(device.device);
  const pipeline = cache.getOrCreate("silu", () => compileSiluKernel(device.device));
  const output = Tensor.zeros(device, [...x.shape]);
  // Element count uniform for the shader's bounds check.
  const params = new Uint32Array([x.numel]);
  const paramsBuffer = device.device.createBuffer({
    size: params.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
  });
  device.device.queue.writeBuffer(paramsBuffer, 0, params);
  const bindGroup = device.device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: x.buffer } },
      { binding: 1, resource: { buffer: output.buffer } },
      { binding: 2, resource: { buffer: paramsBuffer } }
    ]
  });
  const encoder = device.createCommandEncoder();
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  pass.dispatchWorkgroups(Math.ceil(x.numel / 256));
  pass.end();
  device.submit([encoder.finish()]);
  // Wait before destroying the uniform buffer the submitted pass still uses.
  await device.device.queue.onSubmittedWorkDone();
  paramsBuffer.destroy();
  return output;
}
1027
// GPU ReLU: one thread per element, 256 threads per workgroup.
// Same dispatch/cleanup sequence as gelu(), using the cached "relu" pipeline.
async function relu(device, x) {
  const cache = getCache3(device.device);
  const pipeline = cache.getOrCreate("relu", () => compileReluKernel(device.device));
  const output = Tensor.zeros(device, [...x.shape]);
  // Element count uniform for the shader's bounds check.
  const params = new Uint32Array([x.numel]);
  const paramsBuffer = device.device.createBuffer({
    size: params.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
  });
  device.device.queue.writeBuffer(paramsBuffer, 0, params);
  const bindGroup = device.device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: x.buffer } },
      { binding: 1, resource: { buffer: output.buffer } },
      { binding: 2, resource: { buffer: paramsBuffer } }
    ]
  });
  const encoder = device.createCommandEncoder();
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  pass.dispatchWorkgroups(Math.ceil(x.numel / 256));
  pass.end();
  device.submit([encoder.finish()]);
  // Wait before destroying the uniform buffer the submitted pass still uses.
  await device.device.queue.onSubmittedWorkDone();
  paramsBuffer.destroy();
  return output;
}
1056
// Builds the tanh-approximation GELU pipeline: 256 threads per workgroup,
// one element each; Params.size bounds the grid.
function compileGeluKernel(device) {
  const wgsl = `
struct Params {
  size: u32,
}

@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;

const SQRT_2_OVER_PI: f32 = 0.7978845608;
const COEFF: f32 = 0.044715;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  if (idx >= params.size) {
    return;
  }

  let xi = x[idx];
  let inner = SQRT_2_OVER_PI * (xi + COEFF * xi * xi * xi);
  output[idx] = xi * 0.5 * (1.0 + tanh(inner));
}
`;
  const shaderModule = device.createShaderModule({ code: wgsl });
  return device.createComputePipeline({
    layout: "auto",
    compute: { module: shaderModule, entryPoint: "main" }
  });
}
1087
// Builds the SiLU pipeline: 256 threads per workgroup, one element each.
function compileSiluKernel(device) {
  const wgsl = `
struct Params {
  size: u32,
}

@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  if (idx >= params.size) {
    return;
  }

  let xi = x[idx];
  // SiLU: x * sigmoid(x) = x / (1 + exp(-x))
  output[idx] = xi / (1.0 + exp(-xi));
}
`;
  const shaderModule = device.createShaderModule({ code: wgsl });
  return device.createComputePipeline({
    layout: "auto",
    compute: { module: shaderModule, entryPoint: "main" }
  });
}
1115
// Builds the ReLU pipeline: 256 threads per workgroup, one element each.
function compileReluKernel(device) {
  const wgsl = `
struct Params {
  size: u32,
}

@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  if (idx >= params.size) {
    return;
  }

  output[idx] = max(0.0, x[idx]);
}
`;
  const shaderModule = device.createShaderModule({ code: wgsl });
  return device.createComputePipeline({
    layout: "auto",
    compute: { module: shaderModule, entryPoint: "main" }
  });
}
1141
+ // src/ops/softmax.ts
1142
// Module-level singleton kernel cache for the softmax ops.
var kernelCache4 = null;
// Returns the shared KernelCache, constructing it lazily on first use.
function getCache4(device) {
  kernelCache4 = kernelCache4 ?? new KernelCache(device);
  return kernelCache4;
}
1149
// Numerically stable softmax over the last dimension of `shape`.
// Each row is shifted by its max before exponentiation to avoid overflow.
function softmaxCPU(x, shape) {
  const dim = shape[shape.length - 1];
  const rows = x.length / dim;
  const out = new Float32Array(x.length);
  for (let r = 0; r < rows; r++) {
    const base = r * dim;
    let rowMax = -Infinity;
    for (let j = 0; j < dim; j++) {
      rowMax = Math.max(rowMax, x[base + j]);
    }
    let total = 0;
    for (let j = 0; j < dim; j++) {
      const e = Math.exp(x[base + j] - rowMax);
      out[base + j] = e;
      total += e;
    }
    for (let j = 0; j < dim; j++) {
      out[base + j] = out[base + j] / total;
    }
  }
  return out;
}
1171
// Row-wise log-softmax over the last dimension: x - logsumexp(x),
// computed with the max-shift trick for numerical stability.
function logSoftmaxCPU(x, shape) {
  const dim = shape[shape.length - 1];
  const rows = x.length / dim;
  const out = new Float32Array(x.length);
  for (let r = 0; r < rows; r++) {
    const base = r * dim;
    let rowMax = -Infinity;
    for (let j = 0; j < dim; j++) {
      rowMax = Math.max(rowMax, x[base + j]);
    }
    let total = 0;
    for (let j = 0; j < dim; j++) {
      total += Math.exp(x[base + j] - rowMax);
    }
    const lse = rowMax + Math.log(total);
    for (let j = 0; j < dim; j++) {
      out[base + j] = x[base + j] - lse;
    }
  }
  return out;
}
1192
// GPU softmax over the last dimension. Each row is reduced serially by a
// single thread (the kernel uses workgroup_size(1)), so one workgroup is
// dispatched per row.
async function softmaxGPU(device, x) {
  const lastDim = x.shape[x.shape.length - 1];
  const outerSize = x.numel / lastDim;
  const cache = getCache4(device.device);
  // NOTE(review): the kernel reads the row length from the uniform, so the
  // per-lastDim cache key looks redundant — confirm before collapsing it.
  const pipeline = cache.getOrCreate(`softmax_${lastDim}`, () => compileSoftmaxKernel(device.device, lastDim));
  const output = Tensor.zeros(device, [...x.shape]);
  // Uniform: [row count, row length].
  const params = new Uint32Array([outerSize, lastDim]);
  const paramsBuffer = device.device.createBuffer({
    size: params.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
  });
  device.device.queue.writeBuffer(paramsBuffer, 0, params);
  const bindGroup = device.device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: x.buffer } },
      { binding: 1, resource: { buffer: output.buffer } },
      { binding: 2, resource: { buffer: paramsBuffer } }
    ]
  });
  const encoder = device.createCommandEncoder();
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  pass.dispatchWorkgroups(outerSize);
  pass.end();
  device.submit([encoder.finish()]);
  // Wait before destroying the uniform buffer the submitted pass still uses.
  await device.device.queue.onSubmittedWorkDone();
  paramsBuffer.destroy();
  return output;
}
1223
// Builds the softmax pipeline. One thread handles an entire row: max scan,
// exp + sum, then normalization. The `dim` parameter is not interpolated into
// the shader (the kernel reads params.dim), so it only mirrors the caller's
// cache key — presumably kept for future specialization; verify before use.
function compileSoftmaxKernel(device, dim) {
  const wgsl = `
struct Params {
  outerSize: u32,
  dim: u32,
}

@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;

@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  if (idx >= params.outerSize) {
    return;
  }

  let dim = params.dim;
  let offset = idx * dim;

  // Find max
  var maxVal: f32 = x[offset];
  for (var j = 1u; j < dim; j = j + 1u) {
    maxVal = max(maxVal, x[offset + j]);
  }

  // Compute exp and sum
  var sumExp: f32 = 0.0;
  for (var j = 0u; j < dim; j = j + 1u) {
    let expVal = exp(x[offset + j] - maxVal);
    output[offset + j] = expVal;
    sumExp = sumExp + expVal;
  }

  // Normalize
  let invSum = 1.0 / sumExp;
  for (var j = 0u; j < dim; j = j + 1u) {
    output[offset + j] = output[offset + j] * invSum;
  }
}
`;
  const shaderModule = device.createShaderModule({ code: wgsl });
  return device.createComputePipeline({
    layout: "auto",
    compute: { module: shaderModule, entryPoint: "main" }
  });
}
1271
+ // src/ops/elementwise.ts
1272
// Module-level singleton kernel cache for the elementwise ops.
var kernelCache5 = null;
// Returns the shared KernelCache, constructing it lazily on first use.
function getCache5(device) {
  kernelCache5 = kernelCache5 ?? new KernelCache(device);
  return kernelCache5;
}
1279
// Element-wise sum of two equal-length vectors; throws on length mismatch.
function addCPU(a, b) {
  if (a.length !== b.length) {
    throw new Error(`Shape mismatch: ${a.length} vs ${b.length}`);
  }
  return Float32Array.from(a, (v, i) => v + b[i]);
}
1289
// Element-wise product of two equal-length vectors; throws on length mismatch.
function mulCPU(a, b) {
  if (a.length !== b.length) {
    throw new Error(`Shape mismatch: ${a.length} vs ${b.length}`);
  }
  return Float32Array.from(a, (v, i) => v * b[i]);
}
1299
// Multiplies every element by a scalar, returning a new array.
function scaleCPU(a, scalar) {
  return Float32Array.from(a, (v) => v * scalar);
}
1306
// Adds a scalar to every element, returning a new array.
function addScalarCPU(a, scalar) {
  return Float32Array.from(a, (v) => v + scalar);
}
1313
// Fused multiply-add: out[i] = a[i] * b[i] + c[i]; all inputs must share a length.
function fmaCPU(a, b, c) {
  if (a.length !== b.length || a.length !== c.length) {
    throw new Error("Shape mismatch");
  }
  return Float32Array.from(a, (v, i) => v * b[i] + c[i]);
}
1323
// GPU element-wise add: one thread per element, 256 per workgroup.
// Validates total element counts match, dispatches the cached "add" pipeline,
// waits for completion, then frees the transient uniform buffer.
async function add(device, a, b) {
  if (a.numel !== b.numel) {
    throw new Error(`Shape mismatch: ${a.shape} vs ${b.shape}`);
  }
  const cache = getCache5(device.device);
  const pipeline = cache.getOrCreate("add", () => compileAddKernel(device.device));
  const output = Tensor.zeros(device, [...a.shape]);
  // Element count uniform for the shader's bounds check.
  const params = new Uint32Array([a.numel]);
  const paramsBuffer = device.device.createBuffer({
    size: params.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
  });
  device.device.queue.writeBuffer(paramsBuffer, 0, params);
  const bindGroup = device.device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: a.buffer } },
      { binding: 1, resource: { buffer: b.buffer } },
      { binding: 2, resource: { buffer: output.buffer } },
      { binding: 3, resource: { buffer: paramsBuffer } }
    ]
  });
  const encoder = device.createCommandEncoder();
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  pass.dispatchWorkgroups(Math.ceil(a.numel / 256));
  pass.end();
  device.submit([encoder.finish()]);
  // Wait before destroying the uniform buffer the submitted pass still uses.
  await device.device.queue.onSubmittedWorkDone();
  paramsBuffer.destroy();
  return output;
}
1356
// GPU element-wise multiply: one thread per element, 256 per workgroup.
// Same validation and dispatch/cleanup sequence as add().
async function mul(device, a, b) {
  if (a.numel !== b.numel) {
    throw new Error(`Shape mismatch: ${a.shape} vs ${b.shape}`);
  }
  const cache = getCache5(device.device);
  const pipeline = cache.getOrCreate("mul", () => compileMulKernel(device.device));
  const output = Tensor.zeros(device, [...a.shape]);
  // Element count uniform for the shader's bounds check.
  const params = new Uint32Array([a.numel]);
  const paramsBuffer = device.device.createBuffer({
    size: params.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
  });
  device.device.queue.writeBuffer(paramsBuffer, 0, params);
  const bindGroup = device.device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: a.buffer } },
      { binding: 1, resource: { buffer: b.buffer } },
      { binding: 2, resource: { buffer: output.buffer } },
      { binding: 3, resource: { buffer: paramsBuffer } }
    ]
  });
  const encoder = device.createCommandEncoder();
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  pass.dispatchWorkgroups(Math.ceil(a.numel / 256));
  pass.end();
  device.submit([encoder.finish()]);
  // Wait before destroying the uniform buffer the submitted pass still uses.
  await device.device.queue.onSubmittedWorkDone();
  paramsBuffer.destroy();
  return output;
}
1389
// GPU scalar multiply: one thread per element, 256 per workgroup.
// The uniform block mixes types ({ size: u32, scalar: f32 }), so the two
// fields are written separately at byte offsets 0 and 4 into an 8-byte buffer.
// (The previous body also built an unused Float32Array([a.numel, scalar]);
// it was dead — and would have stored the count as f32 — so it is removed.)
async function scale(device, a, scalar) {
  const cache = getCache5(device.device);
  const pipeline = cache.getOrCreate("scale", () => compileScaleKernel(device.device));
  const output = Tensor.zeros(device, [...a.shape]);
  const paramsBuffer = device.device.createBuffer({
    size: 8, // u32 size + f32 scalar
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
  });
  device.device.queue.writeBuffer(paramsBuffer, 0, new Uint32Array([a.numel]));
  device.device.queue.writeBuffer(paramsBuffer, 4, new Float32Array([scalar]));
  const bindGroup = device.device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: a.buffer } },
      { binding: 1, resource: { buffer: output.buffer } },
      { binding: 2, resource: { buffer: paramsBuffer } }
    ]
  });
  const encoder = device.createCommandEncoder();
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  pass.dispatchWorkgroups(Math.ceil(a.numel / 256));
  pass.end();
  device.submit([encoder.finish()]);
  // Wait before destroying the uniform buffer the submitted pass still uses.
  await device.device.queue.onSubmittedWorkDone();
  paramsBuffer.destroy();
  return output;
}
1419
// Builds the element-wise add pipeline: 256 threads per workgroup.
function compileAddKernel(device) {
  const wgsl = `
struct Params {
  size: u32,
}

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  if (idx >= params.size) {
    return;
  }
  output[idx] = a[idx] + b[idx];
}
`;
  const shaderModule = device.createShaderModule({ code: wgsl });
  return device.createComputePipeline({
    layout: "auto",
    compute: { module: shaderModule, entryPoint: "main" }
  });
}
1445
// Builds the element-wise multiply pipeline: 256 threads per workgroup.
function compileMulKernel(device) {
  const wgsl = `
struct Params {
  size: u32,
}

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  if (idx >= params.size) {
    return;
  }
  output[idx] = a[idx] * b[idx];
}
`;
  const shaderModule = device.createShaderModule({ code: wgsl });
  return device.createComputePipeline({
    layout: "auto",
    compute: { module: shaderModule, entryPoint: "main" }
  });
}
1471
// Builds the scalar-multiply pipeline. Params packs { size: u32, scalar: f32 },
// matching the two writeBuffer calls in scale() at offsets 0 and 4.
function compileScaleKernel(device) {
  const wgsl = `
struct Params {
  size: u32,
  scalar: f32,
}

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  if (idx >= params.size) {
    return;
  }
  output[idx] = a[idx] * params.scalar;
}
`;
  const shaderModule = device.createShaderModule({ code: wgsl });
  return device.createComputePipeline({
    layout: "auto",
    compute: { module: shaderModule, entryPoint: "main" }
  });
}
1497
+ // src/ops/embedding.ts
1498
// Module-level singleton kernel cache for the embedding ops.
var kernelCache6 = null;
// Returns the shared KernelCache, constructing it lazily on first use.
function getCache6(device) {
  kernelCache6 = kernelCache6 ?? new KernelCache(device);
  return kernelCache6;
}
1505
// CPU embedding lookup: gathers one row of the flat [vocab * dim] table per
// token id into a flat [tokens.length * dim] output.
function embeddingCPU(embeddings, tokens, embeddingDim) {
  const output = new Float32Array(tokens.length * embeddingDim);
  tokens.forEach((tokenId, i) => {
    const src = tokenId * embeddingDim;
    const dst = i * embeddingDim;
    for (let j = 0; j < embeddingDim; j++) {
      output[dst + j] = embeddings[src + j];
    }
  });
  return output;
}
1518
// GPU embedding lookup: one workgroup per sequence position copies that
// token's embedding row. `tokens` must be a 1-D tensor whose buffer holds u32
// ids (the kernel binds it as array<u32>).
async function embedding(device, embeddings, tokens) {
  if (embeddings.shape.length !== 2) {
    throw new Error("Embedding table must be 2D [vocabSize, embeddingDim]");
  }
  if (tokens.shape.length !== 1) {
    throw new Error("Tokens must be 1D [seqLen]");
  }
  const [, embeddingDim] = embeddings.shape;
  const seqLen = tokens.shape[0];
  const cache = getCache6(device.device);
  // NOTE(review): the kernel reads embeddingDim from the uniform, so the
  // per-dim cache key looks redundant — confirm before collapsing it.
  const pipeline = cache.getOrCreate(`embedding_${embeddingDim}`, () => compileEmbeddingKernel(device.device, embeddingDim));
  const output = Tensor.zeros(device, [seqLen, embeddingDim]);
  // Uniform: [sequence length, embedding width].
  const params = new Uint32Array([seqLen, embeddingDim]);
  const paramsBuffer = device.device.createBuffer({
    size: params.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
  });
  device.device.queue.writeBuffer(paramsBuffer, 0, params);
  const bindGroup = device.device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: embeddings.buffer } },
      { binding: 1, resource: { buffer: tokens.buffer } },
      { binding: 2, resource: { buffer: output.buffer } },
      { binding: 3, resource: { buffer: paramsBuffer } }
    ]
  });
  const encoder = device.createCommandEncoder();
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  pass.dispatchWorkgroups(seqLen);
  pass.end();
  device.submit([encoder.finish()]);
  // Wait before destroying the uniform buffer the submitted pass still uses.
  await device.device.queue.onSubmittedWorkDone();
  paramsBuffer.destroy();
  return output;
}
1556
// Builds the embedding-gather pipeline. One single-thread workgroup copies a
// whole embedding row per sequence position. The `embeddingDim` parameter is
// not interpolated into the shader (it reads params.embeddingDim); it only
// mirrors the caller's cache key — verify before relying on it.
function compileEmbeddingKernel(device, embeddingDim) {
  const wgsl = `
struct Params {
  seqLen: u32,
  embeddingDim: u32,
}

@group(0) @binding(0) var<storage, read> embeddings: array<f32>;
@group(0) @binding(1) var<storage, read> tokens: array<u32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let seqIdx = gid.x;
  if (seqIdx >= params.seqLen) {
    return;
  }

  let tokenId = tokens[seqIdx];
  let srcOffset = tokenId * params.embeddingDim;
  let dstOffset = seqIdx * params.embeddingDim;

  for (var j = 0u; j < params.embeddingDim; j = j + 1u) {
    output[dstOffset + j] = embeddings[srcOffset + j];
  }
}
`;
  const shaderModule = device.createShaderModule({ code: wgsl });
  return device.createComputePipeline({
    layout: "auto",
    compute: { module: shaderModule, entryPoint: "main" }
  });
}
1590
// Batched CPU embedding lookup. `tokens` is [batch][seq]; rows are gathered
// into a flat [batch * seq * dim] buffer in batch-major order. Sequence
// length is taken from the first batch row (0 when the batch is empty).
function batchedEmbeddingCPU(embeddings, tokens, embeddingDim) {
  const batchSize = tokens.length;
  const seqLen = tokens[0]?.length ?? 0;
  const output = new Float32Array(batchSize * seqLen * embeddingDim);
  let dst = 0;
  for (const row of tokens) {
    for (let i = 0; i < seqLen; i++) {
      const src = row[i] * embeddingDim;
      for (let j = 0; j < embeddingDim; j++) {
        output[dst++] = embeddings[src + j];
      }
    }
  }
  return output;
}
1606
+ // src/ops/reshape.ts
1607
// Module-level singleton kernel cache for the reshape/transpose ops.
var kernelCache7 = null;
// Returns the shared KernelCache, constructing it lazily on first use.
function getCache7(device) {
  kernelCache7 = kernelCache7 ?? new KernelCache(device);
  return kernelCache7;
}
1614
// Transposes a row-major rows x cols matrix into cols x rows.
function transpose2DCPU(x, rows, cols) {
  const out = new Float32Array(x.length);
  for (let j = 0; j < cols; j++) {
    for (let i = 0; i < rows; i++) {
      out[j * rows + i] = x[i * cols + j];
    }
  }
  return out;
}
1623
// Transposes the last two dimensions of a batched row-major tensor.
// Returns { data, shape } with the trailing dims swapped.
function transposeCPU(x, shape) {
  if (shape.length < 2) {
    throw new Error("Transpose requires at least 2D tensor");
  }
  const rows = shape[shape.length - 2];
  const cols = shape[shape.length - 1];
  const batch = shape.slice(0, -2).reduce((p, d) => p * d, 1);
  const stride = rows * cols;
  const out = new Float32Array(x.length);
  for (let b = 0; b < batch; b++) {
    const base = b * stride;
    for (let j = 0; j < cols; j++) {
      for (let i = 0; i < rows; i++) {
        out[base + j * rows + i] = x[base + i * cols + j];
      }
    }
  }
  return { data: out, shape: [...shape.slice(0, -2), cols, rows] };
}
1643
// Reinterprets x under a new shape without copying (data is returned as-is).
// Supports a single -1 wildcard whose extent is inferred from the remaining
// dimensions; throws if element counts cannot match.
function reshapeCPU(x, oldShape, newShape) {
  const total = oldShape.reduce((p, d) => p * d, 1);
  let wildcard = -1;
  let known = 1;
  newShape.forEach((d, i) => {
    if (d === -1) {
      if (wildcard !== -1) {
        throw new Error("Can only have one -1 in reshape");
      }
      wildcard = i;
    } else {
      known *= d;
    }
  });
  const resolved = [...newShape];
  if (wildcard !== -1) {
    if (total % known !== 0) {
      throw new Error(`Cannot reshape ${oldShape} to ${newShape}`);
    }
    resolved[wildcard] = total / known;
  }
  const newTotal = resolved.reduce((p, d) => p * d, 1);
  if (newTotal !== total) {
    throw new Error(`Shape mismatch: ${total} vs ${newTotal}`);
  }
  return { data: x, shape: resolved };
}
1670
// GPU 2-D transpose using 16x16 thread tiles.
// The compiled kernel reads rows/cols from the uniform, so one pipeline
// serves every matrix size. It is therefore cached under a single key —
// the previous per-(rows, cols) key compiled and retained a redundant
// pipeline for each distinct shape.
async function transpose2D(device, x) {
  if (x.shape.length !== 2) {
    throw new Error("transpose2D requires 2D tensor");
  }
  const [rows, cols] = x.shape;
  const cache = getCache7(device.device);
  const pipeline = cache.getOrCreate("transpose2d", () => compileTranspose2DKernel(device.device));
  const output = Tensor.zeros(device, [cols, rows]);
  // Uniform: [rows, cols] of the input matrix.
  const params = new Uint32Array([rows, cols]);
  const paramsBuffer = device.device.createBuffer({
    size: params.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
  });
  device.device.queue.writeBuffer(paramsBuffer, 0, params);
  const bindGroup = device.device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: x.buffer } },
      { binding: 1, resource: { buffer: output.buffer } },
      { binding: 2, resource: { buffer: paramsBuffer } }
    ]
  });
  const encoder = device.createCommandEncoder();
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  // x covers columns, y covers rows, in 16x16 tiles.
  pass.dispatchWorkgroups(Math.ceil(cols / 16), Math.ceil(rows / 16));
  pass.end();
  device.submit([encoder.finish()]);
  // Wait before destroying the uniform buffer the submitted pass still uses.
  await device.device.queue.onSubmittedWorkDone();
  paramsBuffer.destroy();
  return output;
}
1703
// Rearranges tensor axes: output axis d corresponds to input axis dims[d].
// `dims` must be a permutation of [0 .. shape.length). Returns { data, shape }.
function permuteCPU(x, shape, dims) {
  if (dims.length !== shape.length) {
    throw new Error("Permutation must have same length as shape");
  }
  // Validate that dims contains each axis index exactly once.
  const check = [...dims].sort((p, q) => p - q);
  check.forEach((v, i) => {
    if (v !== i) {
      throw new Error("Invalid permutation");
    }
  });
  const outShape = dims.map((d) => shape[d]);
  const srcStrides = computeStrides(shape);
  const dstStrides = computeStrides(outShape);
  const out = new Float32Array(x.length);
  for (let flat = 0; flat < x.length; flat++) {
    // Decompose the destination flat index into coordinates, then re-project
    // those coordinates onto the source strides through the permutation.
    let rem = flat;
    let src = 0;
    for (let d = 0; d < dims.length; d++) {
      const idx = Math.floor(rem / dstStrides[d]);
      rem -= idx * dstStrides[d];
      src += idx * srcStrides[dims[d]];
    }
    out[flat] = x[src];
  }
  return { data: out, shape: outShape };
}
1733
// Row-major (C-order) strides: strides[i] = product of shape[i+1:].
function computeStrides(shape) {
  const strides = new Array(shape.length);
  let acc = 1;
  for (let i = shape.length - 1; i >= 0; i--) {
    strides[i] = acc;
    acc *= shape[i];
  }
  return strides;
}
1741
// Builds the 2-D transpose pipeline: 16x16 thread tiles, gid.x = column,
// gid.y = row. Matrix dims come from the Params uniform, so the pipeline is
// shape-independent.
function compileTranspose2DKernel(device) {
  const wgsl = `
struct Params {
  rows: u32,
  cols: u32,
}

@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;

@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let col = gid.x;
  let row = gid.y;

  if (col >= params.cols || row >= params.rows) {
    return;
  }

  // Input: [row, col] at row * cols + col
  // Output: [col, row] at col * rows + row
  output[col * params.rows + row] = x[row * params.cols + col];
}
`;
  const shaderModule = device.createShaderModule({ code: wgsl });
  return device.createComputePipeline({
    layout: "auto",
    compute: { module: shaderModule, entryPoint: "main" }
  });
}
1772
+ // src/quantization/quantize.ts
1773
// Group-wise 8-bit quantization.
// Symmetric: values are scaled by 127 / absmax and stored offset-by-128
// (unsigned bytes where 128 encodes zero); `zeros` is null.
// Asymmetric: affine mapping of [min, max] onto [0, 255], with the per-group
// minimum stored in `zeros`.
// A zero-range group gets scale 0 and quantizes everything to the zero code.
function quantizeToInt8(x, groupSize = 128, symmetric = true) {
  const numGroups = Math.ceil(x.length / groupSize);
  const scales = new Float32Array(numGroups);
  const zeros = symmetric ? null : new Float32Array(numGroups);
  const quantized = new Uint8Array(x.length);
  for (let g = 0; g < numGroups; g++) {
    const start = g * groupSize;
    const end = Math.min(start + groupSize, x.length);
    // Range scan for this group.
    let lo = x[start];
    let hi = x[start];
    for (let i = start + 1; i < end; i++) {
      lo = Math.min(lo, x[i]);
      hi = Math.max(hi, x[i]);
    }
    if (symmetric) {
      scales[g] = Math.max(Math.abs(lo), Math.abs(hi)) / 127;
      const inv = scales[g] > 0 ? 1 / scales[g] : 0;
      for (let i = start; i < end; i++) {
        const q = Math.round(x[i] * inv) + 128;
        quantized[i] = q < 0 ? 0 : q > 255 ? 255 : q;
      }
    } else {
      scales[g] = (hi - lo) / 255;
      zeros[g] = lo;
      const z = zeros[g];
      const inv = scales[g] > 0 ? 1 / scales[g] : 0;
      for (let i = start; i < end; i++) {
        const q = Math.round((x[i] - z) * inv);
        quantized[i] = q < 0 ? 0 : q > 255 ? 255 : q;
      }
    }
  }
  return {
    data: quantized,
    scales,
    zeros,
    shape: [x.length],
    config: { bits: 8, groupSize, symmetric }
  };
}
1815
/**
 * Quantize floats to packed 4-bit codes with per-group scales.
 * Two codes per byte: even logical index -> low nibble, odd -> high nibble.
 * Symmetric mode maps round(x/scale)+8 into [0, 15] with scale = absMax/7;
 * asymmetric mode maps round((x-min)/scale) with scale = (max-min)/15.
 */
function quantizeToInt4(x, groupSize = 128, symmetric = true) {
  const numGroups = Math.ceil(x.length / groupSize);
  const scales = new Float32Array(numGroups);
  const zeros = symmetric ? null : new Float32Array(numGroups);
  const quantized = new Uint8Array(Math.ceil(x.length / 2));
  // Pack one 4-bit code at logical index i.
  const pack = (i, code) => {
    const byte = i >> 1;
    if ((i & 1) === 0) {
      quantized[byte] = code;
    } else {
      quantized[byte] |= code << 4;
    }
  };
  for (let g = 0; g < numGroups; g++) {
    const lo = g * groupSize;
    const hi = Math.min(lo + groupSize, x.length);
    let gMin = x[lo];
    let gMax = x[lo];
    for (let i = lo; i < hi; i++) {
      if (x[i] < gMin) gMin = x[i];
      if (x[i] > gMax) gMax = x[i];
    }
    if (symmetric) {
      scales[g] = Math.max(Math.abs(gMin), Math.abs(gMax)) / 7;
    } else {
      scales[g] = (gMax - gMin) / 15;
      zeros[g] = gMin;
    }
    const s = scales[g];
    const inv = s > 0 ? 1 / s : 0;
    for (let i = lo; i < hi; i++) {
      const q = symmetric
        ? Math.round(x[i] * inv) + 8
        : Math.round((x[i] - zeros[g]) * inv);
      pack(i, Math.max(0, Math.min(15, q)));
    }
  }
  return {
    data: quantized,
    scales,
    zeros,
    shape: [x.length],
    config: { bits: 4, groupSize, symmetric }
  };
}
1870
/**
 * Reconstruct floats from an INT8 quantized tensor (inverse of
 * quantizeToInt8). Throws if the tensor was not 8-bit quantized.
 */
function dequantizeInt8(qt) {
  if (qt.config.bits !== 8) {
    throw new Error("Expected INT8 quantized tensor");
  }
  const { data, scales, zeros, config } = qt;
  const { groupSize, symmetric } = config;
  const output = new Float32Array(data.length);
  for (let i = 0; i < data.length; i++) {
    const g = Math.floor(i / groupSize);
    // Symmetric codes carry a +128 bias; asymmetric codes use a zero point.
    output[i] = symmetric
      ? (data[i] - 128) * scales[g]
      : data[i] * scales[g] + zeros[g];
  }
  return output;
}
1888
/**
 * Reconstruct floats from a packed INT4 quantized tensor (inverse of
 * quantizeToInt4). Element count comes from `shape`, since the data is
 * nibble-packed. Throws if the tensor was not 4-bit quantized.
 */
function dequantizeInt4(qt) {
  if (qt.config.bits !== 4) {
    throw new Error("Expected INT4 quantized tensor");
  }
  const { data, scales, zeros, shape, config } = qt;
  const { groupSize, symmetric } = config;
  const count = shape.reduce((acc, dim) => acc * dim, 1);
  const output = new Float32Array(count);
  for (let i = 0; i < count; i++) {
    // Even logical index -> low nibble, odd -> high nibble.
    const byte = data[i >> 1];
    const q = (i & 1) === 1 ? (byte >> 4) & 15 : byte & 15;
    const g = Math.floor(i / groupSize);
    output[i] = symmetric ? (q - 8) * scales[g] : q * scales[g] + zeros[g];
  }
  return output;
}
1915
/**
 * Mean squared error between an original array and its reconstruction.
 * Throws when the lengths differ.
 */
function quantizationError(original, reconstructed) {
  if (original.length !== reconstructed.length) {
    throw new Error("Length mismatch");
  }
  const sumSq = original.reduce((acc, v, i) => {
    const d = v - reconstructed[i];
    return acc + d * d;
  }, 0);
  return sumSq / original.length;
}
1926
/**
 * Byte accounting for a quantized tensor versus its float original.
 * Counts packed data plus scale (and optional zero-point) metadata.
 */
function getMemorySavings(originalBytes, qt) {
  const quantizedBytes =
    qt.data.byteLength + qt.scales.byteLength + (qt.zeros?.byteLength ?? 0);
  return {
    originalBytes,
    quantizedBytes,
    savings: originalBytes - quantizedBytes,
    ratio: originalBytes / quantizedBytes
  };
}
1938
// src/quantization/qmatmul.ts
/**
 * CPU reference matmul with INT8-quantized weights: output = A (MxK, f32)
 * times dequantized B (KxN). Accumulation is f64 per output element, then
 * stored to f32 -- identical numerics to dequantizing inline.
 */
function qmatmulInt8CPU(A, B, M, K, N) {
  if (B.config.bits !== 8) {
    throw new Error("Expected INT8 weights");
  }
  const { data: Bq, scales, zeros, config } = B;
  const { groupSize, symmetric } = config;
  // Dequantize the weight matrix once up front (f64 to match the on-the-fly
  // math of a naive implementation exactly).
  const Bf = new Float64Array(K * N);
  for (let idx = 0; idx < K * N; idx++) {
    const g = Math.floor(idx / groupSize);
    Bf[idx] = symmetric
      ? (Bq[idx] - 128) * scales[g]
      : Bq[idx] * scales[g] + zeros[g];
  }
  const output = new Float32Array(M * N);
  for (let m = 0; m < M; m++) {
    for (let n = 0; n < N; n++) {
      let acc = 0;
      for (let k = 0; k < K; k++) {
        acc += A[m * K + k] * Bf[k * N + n];
      }
      output[m * N + n] = acc;
    }
  }
  return output;
}
1967
/**
 * CPU reference matmul with packed INT4-quantized weights: output = A
 * (MxK, f32) times dequantized B (KxN). Nibble layout: even flat index ->
 * low nibble, odd -> high nibble of Bq[idx >> 1].
 */
function qmatmulInt4CPU(A, B, M, K, N) {
  if (B.config.bits !== 4) {
    throw new Error("Expected INT4 weights");
  }
  const { data: Bq, scales, zeros, config } = B;
  const { groupSize, symmetric } = config;
  // Unpack and dequantize the weights once (f64, matching inline math).
  const Bf = new Float64Array(K * N);
  for (let idx = 0; idx < K * N; idx++) {
    const byte = Bq[idx >> 1];
    const q = (idx & 1) === 1 ? (byte >> 4) & 15 : byte & 15;
    const g = Math.floor(idx / groupSize);
    Bf[idx] = symmetric ? (q - 8) * scales[g] : q * scales[g] + zeros[g];
  }
  const output = new Float32Array(M * N);
  for (let m = 0; m < M; m++) {
    for (let n = 0; n < N; n++) {
      let acc = 0;
      for (let k = 0; k < K; k++) {
        acc += A[m * K + k] * Bf[k * N + n];
      }
      output[m * N + n] = acc;
    }
  }
  return output;
}
2003
/**
 * Cache-blocked variant of qmatmulInt8CPU. Loop order and the f32
 * round-trip of partial sums through `output` between k-blocks are
 * preserved exactly, so results match the original blocked kernel.
 */
function qmatmulInt8BlockCPU(A, B, M, K, N, blockSize = 32) {
  if (B.config.bits !== 8) {
    throw new Error("Expected INT8 weights");
  }
  const { data: Bq, scales, zeros, config } = B;
  const { groupSize, symmetric } = config;
  // Dequantize one weight element on demand.
  const weightAt = (idx) => {
    const g = Math.floor(idx / groupSize);
    return symmetric
      ? (Bq[idx] - 128) * scales[g]
      : Bq[idx] * scales[g] + zeros[g];
  };
  const output = new Float32Array(M * N);
  for (let m0 = 0; m0 < M; m0 += blockSize) {
    const mLim = Math.min(m0 + blockSize, M);
    for (let n0 = 0; n0 < N; n0 += blockSize) {
      const nLim = Math.min(n0 + blockSize, N);
      for (let k0 = 0; k0 < K; k0 += blockSize) {
        const kLim = Math.min(k0 + blockSize, K);
        for (let m = m0; m < mLim; m++) {
          for (let n = n0; n < nLim; n++) {
            // Resume from the f32 partial sum of previous k-blocks.
            let acc = output[m * N + n];
            for (let k = k0; k < kLim; k++) {
              acc += A[m * K + k] * weightAt(k * N + n);
            }
            output[m * N + n] = acc;
          }
        }
      }
    }
  }
  return output;
}
2040
/** FLOPs for an MxK by KxN matmul: one multiply plus one add per term. */
function estimateQMatMulFlops(M, K, N) {
  return 2 * M * K * N;
}
2043
/**
 * Byte-traffic estimate for a quantized matmul: f32 activations and output,
 * packed weights (1 byte per element at 8-bit, half at 4-bit), plus one f32
 * scale per weight group.
 */
function estimateQMatMulBandwidth(M, K, N, bits, groupSize) {
  const weightElements = K * N;
  const parts = {
    activationBytes: M * K * 4,
    weightBytes: bits === 8 ? weightElements : Math.ceil(weightElements / 2),
    scaleBytes: Math.ceil(weightElements / groupSize) * 4,
    outputBytes: M * N * 4
  };
  return {
    ...parts,
    totalBytes: parts.activationBytes + parts.weightBytes + parts.scaleBytes + parts.outputBytes
  };
}
2058
// src/attention/block-sparse/format.ts
/**
 * Build a block-level CSR description of a seqLen x seqLen attention mask.
 * A "block" is a blockSize x blockSize tile; a tile is kept when any entry
 * inside it is unmasked for the given pattern.
 *
 * Returns { blockSize, rowPtr, colIdx, numRows, numCols, numBlockRows,
 * numBlockCols, nnzBlocks } where rowPtr/colIdx are standard CSR arrays
 * over block rows/columns.
 *
 * Fixed: the CSR arrays were previously assembled with a second scan over
 * the whole non-zero list once per block row (O(numBlockRows * nnz));
 * blocks are discovered in row-major order, so a single pass suffices.
 */
function buildBlockSparseCSR(seqLen, pattern, blockSize = 64) {
  const numBlockRows = Math.ceil(seqLen / blockSize);
  const numBlockCols = Math.ceil(seqLen / blockSize);
  const rowPtr = new Uint32Array(numBlockRows + 1);
  const cols = [];
  // Single row-major pass: record each row's starting offset, then its
  // surviving column indices.
  for (let br = 0; br < numBlockRows; br++) {
    rowPtr[br] = cols.length;
    for (let bc = 0; bc < numBlockCols; bc++) {
      if (isBlockNonZero(br, bc, blockSize, seqLen, pattern)) {
        cols.push(bc);
      }
    }
  }
  rowPtr[numBlockRows] = cols.length;
  return {
    blockSize,
    rowPtr,
    colIdx: Uint32Array.from(cols),
    numRows: seqLen,
    numCols: seqLen,
    numBlockRows,
    numBlockCols,
    nnzBlocks: cols.length
  };
}
/**
 * Decide whether block (blockRow, blockCol) contains any unmasked entry
 * for the given pattern. Block bounds are clipped to seqLen.
 */
function isBlockNonZero(blockRow, blockCol, blockSize, seqLen, pattern) {
  const rowStart = blockRow * blockSize;
  const rowEnd = Math.min(rowStart + blockSize, seqLen);
  const colStart = blockCol * blockSize;
  const colEnd = Math.min(colStart + blockSize, seqLen);
  switch (pattern.type) {
    case "dense":
      return true;
    case "causal":
      // Lower-triangular: keep the block if its last row reaches its first column.
      return rowEnd > colStart;
    case "sliding": {
      const windowSize = pattern.windowSize;
      // Block overlaps [rowStart - windowSize, rowEnd) in the column axis.
      return colStart < rowEnd && colEnd > Math.max(0, rowStart - windowSize);
    }
    case "global-local": {
      const { globalTokens, localWindow } = pattern;
      // Any global token inside this column range keeps the block.
      for (const gt of globalTokens) {
        if (gt >= colStart && gt < colEnd)
          return true;
      }
      return colStart < rowEnd && colEnd > Math.max(0, rowStart - localWindow);
    }
    case "custom":
      // Dense scan of the user-supplied mask over the block footprint.
      for (let i = rowStart; i < rowEnd; i++) {
        for (let j = colStart; j < colEnd; j++) {
          if (pattern.mask[i]?.[j])
            return true;
        }
      }
      return false;
    default:
      // Unknown pattern types degrade to dense (never drops attention).
      return true;
  }
}
2127
/** Fraction of block tiles that are entirely masked out (0 = dense). */
function getSparsityRatio(csr) {
  const totalBlocks = csr.numBlockRows * csr.numBlockCols;
  const density = csr.nnzBlocks / totalBlocks;
  return 1 - density;
}
2131
/**
 * f32-byte comparison of a dense seqLen x seqLen score matrix versus its
 * block-sparse representation (non-zero tiles plus CSR index arrays).
 */
function estimateMemorySavings(csr) {
  const denseBytes = csr.numRows * csr.numCols * 4;
  const tileBytes = csr.nnzBlocks * csr.blockSize * csr.blockSize * 4;
  const rowPtrBytes = (csr.numBlockRows + 1) * 4;
  const colIdxBytes = csr.nnzBlocks * 4;
  const sparseBytes = tileBytes + rowPtrBytes + colIdxBytes;
  return {
    denseBytes,
    sparseBytes,
    savingsRatio: 1 - sparseBytes / denseBytes
  };
}
2140
+
2141
// src/attention/flash-attention.ts
// Lazily-created module-wide kernel cache for flash-attention pipelines.
var kernelCache8 = null;
// Return the shared KernelCache, constructing it on first use.
function getCache8(device) {
  kernelCache8 ??= new KernelCache(device);
  return kernelCache8;
}
2149
/**
 * Block-sparse FlashAttention forward pass on WebGPU.
 *
 * q/k/v are [seqLen, numHeads, headDim] tensors; returns a freshly
 * allocated output tensor of the same shape. The attention pattern
 * (default: causal) is lowered to a block-level CSR mask consumed by the
 * kernel; the pipeline is cached per (heads, headDim, seqLen, blockSize).
 *
 * NOTE(review): the compiled kernel applies a causal mask unconditionally,
 * so non-causal patterns are additionally causally masked -- confirm intent.
 */
async function flashAttention(device, q, k, v, config) {
  const { numHeads, headDim, seqLen } = config;
  const scale2 = config.scale ?? 1 / Math.sqrt(headDim);
  const blockSize = config.blockSize ?? 64;
  const pattern = config.pattern ?? { type: "causal" };
  const sparseMask = buildBlockSparseCSR(seqLen, pattern, blockSize);
  const cache = getCache8(device.device);
  const pipeline = cache.getOrCreate(`flash_attn_${numHeads}_${headDim}_${seqLen}_${blockSize}`, () => compileFlashAttentionKernel(device.device, config, blockSize));
  const output = Tensor.zeros(device, [seqLen, numHeads, headDim]);
  // BUGFIX: the WGSL Params struct declares seqLen/numHeads/headDim/
  // blockSize/numBlockRows as u32, but this buffer was previously built as
  // a Float32Array, so e.g. seqLen=1024 reached the shader as the bit
  // pattern of 1024.0f (0x44800000). Write integer fields through a u32
  // view and only `scale` through the f32 view of the same 32-byte buffer.
  const paramsData = new ArrayBuffer(32);
  const paramsU32 = new Uint32Array(paramsData);
  const paramsF32 = new Float32Array(paramsData);
  paramsU32[0] = seqLen;
  paramsU32[1] = numHeads;
  paramsU32[2] = headDim;
  paramsF32[3] = scale2;
  paramsU32[4] = blockSize;
  paramsU32[5] = sparseMask.numBlockRows;
  paramsU32[6] = 0; // _pad0
  paramsU32[7] = 0; // _pad1
  const paramsBuffer = device.device.createBuffer({
    size: paramsData.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
  });
  device.device.queue.writeBuffer(paramsBuffer, 0, paramsData);
  // CSR row pointers and column indices for the block-sparse mask.
  const rowPtrBuffer = device.device.createBuffer({
    size: sparseMask.rowPtr.byteLength,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST
  });
  device.device.queue.writeBuffer(rowPtrBuffer, 0, new Uint32Array(sparseMask.rowPtr));
  // colIdx may be empty; WebGPU buffers must be non-zero sized.
  const colIdxBuffer = device.device.createBuffer({
    size: Math.max(sparseMask.colIdx.byteLength, 4),
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST
  });
  if (sparseMask.colIdx.length > 0) {
    device.device.queue.writeBuffer(colIdxBuffer, 0, new Uint32Array(sparseMask.colIdx));
  }
  const bindGroup = device.device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: q.buffer } },
      { binding: 1, resource: { buffer: k.buffer } },
      { binding: 2, resource: { buffer: v.buffer } },
      { binding: 3, resource: { buffer: output.buffer } },
      { binding: 4, resource: { buffer: paramsBuffer } },
      { binding: 5, resource: { buffer: rowPtrBuffer } },
      { binding: 6, resource: { buffer: colIdxBuffer } }
    ]
  });
  const encoder = device.createCommandEncoder();
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  // One workgroup per (block row, head); each workgroup has blockSize threads.
  const workgroupsX = sparseMask.numBlockRows;
  const workgroupsY = numHeads;
  pass.dispatchWorkgroups(workgroupsX, workgroupsY);
  pass.end();
  device.submit([encoder.finish()]);
  await device.device.queue.onSubmittedWorkDone();
  // Transient buffers are safe to destroy once the queue has drained.
  paramsBuffer.destroy();
  rowPtrBuffer.destroy();
  colIdxBuffer.destroy();
  return output;
}
2212
/**
 * Compile the block-sparse FlashAttention WGSL pipeline.
 *
 * Specializes the shader on blockSize and config.headDim (workgroup array
 * sizes and @workgroup_size are baked in at compile time); everything else
 * arrives via the Params uniform. One workgroup handles one (block row,
 * head) pair with blockSize threads, each thread owning one query row.
 * Params integer fields are declared u32 -- the caller must write them
 * with u32 views, and `scale` with an f32 view.
 * NOTE(review): the causal mask (colPos > globalRow) is applied regardless
 * of which sparsity pattern produced the CSR mask -- confirm intent for
 * non-causal patterns.
 */
function compileFlashAttentionKernel(device, config, blockSize) {
  const { headDim } = config;
  const wgsl = `
// WebInfer FlashAttention Kernel
// Implements online softmax with tiling for memory efficiency

struct Params {
seqLen: u32,
numHeads: u32,
headDim: u32,
scale: f32,
blockSize: u32,
numBlockRows: u32,
_pad0: u32,
_pad1: u32,
}

@group(0) @binding(0) var<storage, read> Q: array<f32>;
@group(0) @binding(1) var<storage, read> K: array<f32>;
@group(0) @binding(2) var<storage, read> V: array<f32>;
@group(0) @binding(3) var<storage, read_write> O: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
@group(0) @binding(5) var<storage, read> blockRowPtr: array<u32>;
@group(0) @binding(6) var<storage, read> blockColIdx: array<u32>;

// Shared memory for tiles
var<workgroup> tileQ: array<f32, ${blockSize * headDim}>;
var<workgroup> tileK: array<f32, ${blockSize * headDim}>;
var<workgroup> tileV: array<f32, ${blockSize * headDim}>;
var<workgroup> tileS: array<f32, ${blockSize * blockSize}>;

// Online softmax state per row
var<workgroup> rowMax: array<f32, ${blockSize}>;
var<workgroup> rowSum: array<f32, ${blockSize}>;
var<workgroup> rowOut: array<f32, ${blockSize * headDim}>;

@compute @workgroup_size(${blockSize})
fn main(
@builtin(workgroup_id) wgId: vec3<u32>,
@builtin(local_invocation_id) localId: vec3<u32>
) {
let blockRowIdx = wgId.x;
let headIdx = wgId.y;
let tid = localId.x;

let blockSize = params.blockSize;
let headDim = params.headDim;
let seqLen = params.seqLen;
let scale = params.scale;

// Global row index
let globalRow = blockRowIdx * blockSize + tid;
let validRow = globalRow < seqLen;

// Initialize online softmax state
rowMax[tid] = -3.402823e+38f; // -inf
rowSum[tid] = 0.0f;

// Initialize output accumulator
for (var d = 0u; d < headDim; d = d + 1u) {
rowOut[tid * headDim + d] = 0.0f;
}

workgroupBarrier();

// Load Q tile for this block row
if (validRow) {
for (var d = 0u; d < headDim; d = d + 1u) {
let qIdx = globalRow * params.numHeads * headDim + headIdx * headDim + d;
tileQ[tid * headDim + d] = Q[qIdx];
}
}

workgroupBarrier();

// Iterate over non-zero blocks in this row (block-sparse)
let blockStart = blockRowPtr[blockRowIdx];
let blockEnd = blockRowPtr[blockRowIdx + 1u];

for (var b = blockStart; b < blockEnd; b = b + 1u) {
let blockColIdx_b = blockColIdx[b];
let globalCol = blockColIdx_b * blockSize + tid;
let validCol = globalCol < seqLen;

// Load K tile
if (validCol) {
for (var d = 0u; d < headDim; d = d + 1u) {
let kIdx = globalCol * params.numHeads * headDim + headIdx * headDim + d;
tileK[tid * headDim + d] = K[kIdx];
}
} else {
for (var d = 0u; d < headDim; d = d + 1u) {
tileK[tid * headDim + d] = 0.0f;
}
}

// Load V tile
if (validCol) {
for (var d = 0u; d < headDim; d = d + 1u) {
let vIdx = globalCol * params.numHeads * headDim + headIdx * headDim + d;
tileV[tid * headDim + d] = V[vIdx];
}
} else {
for (var d = 0u; d < headDim; d = d + 1u) {
tileV[tid * headDim + d] = 0.0f;
}
}

workgroupBarrier();

// Compute attention scores S = Q @ K^T * scale
// Each thread computes one row of scores
if (validRow) {
for (var j = 0u; j < blockSize; j = j + 1u) {
var score = 0.0f;
for (var d = 0u; d < headDim; d = d + 1u) {
score = score + tileQ[tid * headDim + d] * tileK[j * headDim + d];
}
score = score * scale;

// Apply causal mask
let colPos = blockColIdx_b * blockSize + j;
if (colPos > globalRow) {
score = -3.402823e+38f; // -inf for masked positions
}

tileS[tid * blockSize + j] = score;
}
}

workgroupBarrier();

// Online softmax update
if (validRow) {
// Find max in this tile
var tileMax = -3.402823e+38f;
for (var j = 0u; j < blockSize; j = j + 1u) {
tileMax = max(tileMax, tileS[tid * blockSize + j]);
}

// Update running max
let prevMax = rowMax[tid];
let newMax = max(prevMax, tileMax);
rowMax[tid] = newMax;

// Rescale previous sum and output
let rescale = exp(prevMax - newMax);
rowSum[tid] = rowSum[tid] * rescale;
for (var d = 0u; d < headDim; d = d + 1u) {
rowOut[tid * headDim + d] = rowOut[tid * headDim + d] * rescale;
}

// Compute softmax for this tile and accumulate
for (var j = 0u; j < blockSize; j = j + 1u) {
let p = exp(tileS[tid * blockSize + j] - newMax);
rowSum[tid] = rowSum[tid] + p;

// Accumulate output: O += p * V
for (var d = 0u; d < headDim; d = d + 1u) {
rowOut[tid * headDim + d] = rowOut[tid * headDim + d] + p * tileV[j * headDim + d];
}
}
}

workgroupBarrier();
}

// Final normalization and write output
if (validRow) {
let sumInv = 1.0f / rowSum[tid];
for (var d = 0u; d < headDim; d = d + 1u) {
let oIdx = globalRow * params.numHeads * headDim + headIdx * headDim + d;
O[oIdx] = rowOut[tid * headDim + d] * sumInv;
}
}
}
`;
  const shaderModule = device.createShaderModule({ code: wgsl });
  // "auto" layout derives the bind group layout from the shader bindings.
  return device.createComputePipeline({
    layout: "auto",
    compute: {
      module: shaderModule,
      entryPoint: "main"
    }
  });
}
2398
/**
 * Reference (naive, O(seqLen^2)) multi-head attention on the CPU, used for
 * validating GPU kernels. q/k/v/output are flat [seq, head, dim] arrays.
 * With `causal`, position i only attends to positions j <= i.
 */
function attentionCPU(q, k, v, seqLen, numHeads, headDim, causal = true) {
  const output = new Float32Array(seqLen * numHeads * headDim);
  const invSqrtD = 1 / Math.sqrt(headDim);
  // Flat offset for (position, head, dim).
  const at = (pos, head, d) => pos * numHeads * headDim + head * headDim + d;
  for (let h = 0; h < numHeads; h++) {
    for (let i = 0; i < seqLen; i++) {
      // Scaled dot-product scores; masked positions get -inf.
      const scores = new Float32Array(seqLen);
      let best = -Infinity;
      for (let j = 0; j < seqLen; j++) {
        if (causal && j > i) {
          scores[j] = -Infinity;
        } else {
          let dot = 0;
          for (let d = 0; d < headDim; d++) {
            dot += q[at(i, h, d)] * k[at(j, h, d)];
          }
          scores[j] = dot * invSqrtD;
        }
        best = Math.max(best, scores[j]);
      }
      // Numerically stable softmax.
      let total = 0;
      for (let j = 0; j < seqLen; j++) {
        scores[j] = Math.exp(scores[j] - best);
        total += scores[j];
      }
      for (let j = 0; j < seqLen; j++) {
        scores[j] = scores[j] / total;
      }
      // Probability-weighted sum over the values.
      for (let d = 0; d < headDim; d++) {
        let acc = 0;
        for (let j = 0; j < seqLen; j++) {
          acc += scores[j] * v[at(j, h, d)];
        }
        output[at(i, h, d)] = acc;
      }
    }
  }
  return output;
}
2440
// src/attention/block-sparse/patterns/causal.ts
/** Block-sparse CSR mask for standard causal (lower-triangular) attention. */
function buildCausalMask(seqLen, blockSize = 64) {
  return buildBlockSparseCSR(seqLen, { type: "causal" }, blockSize);
}
2445
/** Element-level zero fraction of a causal seqLen x seqLen mask. */
function getCausalSparsity(seqLen) {
  const total = seqLen * seqLen;
  // Lower triangle including the diagonal.
  const kept = seqLen * (seqLen + 1) / 2;
  return 1 - kept / total;
}
2450
// src/attention/block-sparse/patterns/sliding.ts
/** Block-sparse CSR mask where each row attends to a trailing window. */
function buildSlidingWindowMask(seqLen, windowSize, blockSize = 64) {
  return buildBlockSparseCSR(seqLen, { type: "sliding", windowSize }, blockSize);
}
2455
/**
 * Element-level zero fraction of a causal sliding-window mask: the first
 * windowSize rows form a causal triangle; each later row keeps windowSize
 * previous tokens plus itself.
 */
function getSlidingWindowSparsity(seqLen, windowSize) {
  const total = seqLen * seqLen;
  const triangle = windowSize * (windowSize + 1) / 2;
  const fullRows = Math.max(0, seqLen - windowSize);
  const kept = triangle + fullRows * (windowSize + 1);
  return 1 - kept / total;
}
2463
// Alias for buildSlidingWindowMask: the "sliding" pattern used here is
// already causal (a trailing window, no look-ahead), so no extra masking
// is needed.
function buildCausalSlidingWindowMask(seqLen, windowSize, blockSize = 64) {
  return buildSlidingWindowMask(seqLen, windowSize, blockSize);
}
2466
// src/attention/scheduler.ts
// Conservative GPU watchdog (TDR) budgets per browser, in milliseconds.
var TDR_LIMITS = {
  chrome: 5000,
  safari: 3000,
  firefox: 8000,
  default: 3000
};

/**
 * Plans attention kernel launches so that a single dispatch stays safely
 * under the browser's GPU watchdog (TDR) time limit.
 */
class AttentionScheduler {
  device;    // device wrapper; only .info.vendor is read here
  tdrLimit;  // detected watchdog budget in ms
  constructor(device) {
    this.device = device;
    this.tdrLimit = this.detectTDRLimit();
  }
  // Pick a TDR budget from the user agent; conservative default otherwise.
  detectTDRLimit() {
    if (typeof navigator === "undefined") {
      return TDR_LIMITS.default;
    }
    const ua = navigator.userAgent.toLowerCase();
    // Order matters: Chrome's UA string also contains "safari".
    if (ua.includes("safari") && !ua.includes("chrome")) {
      return TDR_LIMITS.safari;
    }
    if (ua.includes("firefox")) {
      return TDR_LIMITS.firefox;
    }
    if (ua.includes("chrome") || ua.includes("edge")) {
      return TDR_LIMITS.chrome;
    }
    return TDR_LIMITS.default;
  }
  // Rough wall-clock estimate (ms) for one attention pass, assuming a
  // vendor-dependent sustained throughput and a 2x safety factor.
  estimateExecutionTime(seqLen, numHeads, headDim) {
    const flops = 4 * seqLen * seqLen * numHeads * headDim;
    const vendor = this.device.info.vendor;
    let tflopsEstimate;
    if (vendor === "apple") {
      tflopsEstimate = 10;
    } else if (vendor === "nvidia") {
      tflopsEstimate = 20;
    } else if (vendor === "amd") {
      tflopsEstimate = 15;
    } else if (vendor === "intel") {
      tflopsEstimate = 8;
    } else {
      tflopsEstimate = 5;
    }
    return flops / (tflopsEstimate * 1000000000000) * 1000 * 2;
  }
  // Split the sequence into chunks whose estimated time fits the budget.
  computeChunkPlan(seqLen, numHeads, headDim) {
    const totalMs = this.estimateExecutionTime(seqLen, numHeads, headDim);
    if (totalMs < this.tdrLimit * 0.7) {
      return {
        numChunks: 1,
        chunkSize: seqLen,
        estimatedTimeMs: totalMs
      };
    }
    // Target half the budget per chunk to leave headroom.
    const perChunkTarget = this.tdrLimit * 0.5;
    const numChunks = Math.ceil(totalMs / perChunkTarget);
    return {
      numChunks,
      chunkSize: Math.ceil(seqLen / numChunks),
      estimatedTimeMs: totalMs / numChunks
    };
  }
  // Give the event loop a turn between chunks.
  async yieldToMain() {
    await new Promise((resolve) => setTimeout(resolve, 0));
  }
  // True when a single-pass launch risks tripping the watchdog.
  mightCauseTDR(seqLen, numHeads, headDim) {
    return this.estimateExecutionTime(seqLen, numHeads, headDim) > this.tdrLimit * 0.7;
  }
  // Largest seqLen (binary search, capped at 65536) fitting one pass.
  getMaxSinglePassSeqLen(numHeads, headDim) {
    let lo = 1;
    let hi = 65536;
    while (lo < hi) {
      const mid = Math.floor((lo + hi + 1) / 2);
      if (this.estimateExecutionTime(mid, numHeads, headDim) <= this.tdrLimit * 0.7) {
        lo = mid;
      } else {
        hi = mid - 1;
      }
    }
    return lo;
  }
}
2555
// src/attention/paged-kv/page-table.ts
/**
 * Paged key/value cache for attention. Physical storage is two large GPU
 * buffers (keys and values) partitioned into fixed-size pages; each logical
 * sequence maps to a list of physical page indices via a page table, so
 * sequences can grow without reallocating or copying.
 *
 * Fixes: removed an unused local in extendSequence, and extendSequence now
 * rolls back pages allocated during a failed extension so the entry's
 * pages/length stay consistent when the pool is exhausted mid-call.
 */
class PagedKVCache {
  device;               // device wrapper; .device must expose createBuffer
  config;               // { pageSize, numHeads, headDim, numLayers, maxPages, dtype }
  keyCache;             // one GPUBuffer holding all key pages
  valueCache;           // one GPUBuffer holding all value pages
  pageTable = new Map;  // seqId -> { seqId, pages: number[], length }
  freePages = [];       // stack of free physical page indices
  nextSeqId = 0;        // monotonically increasing sequence id
  constructor(device, config) {
    this.device = device;
    this.config = config;
    const bytesPerElement = config.dtype === "f16" ? 2 : 4;
    // One page stores pageSize tokens across all heads (per layer slot).
    const pageBytes = config.pageSize * config.numHeads * config.headDim * bytesPerElement;
    const totalBytes = config.maxPages * pageBytes * config.numLayers;
    this.keyCache = device.device.createBuffer({
      size: totalBytes,
      usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
    });
    this.valueCache = device.device.createBuffer({
      size: totalBytes,
      usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
    });
    for (let i = 0; i < config.maxPages; i++) {
      this.freePages.push(i);
    }
  }
  /**
   * Register a new sequence and reserve pages for initialLength tokens.
   * Rolls back all pages and throws if the pool is exhausted.
   * @returns the new sequence id
   */
  allocateSequence(initialLength = 0) {
    const seqId = this.nextSeqId++;
    const numPagesNeeded = Math.ceil(initialLength / this.config.pageSize);
    const pages = [];
    for (let i = 0; i < numPagesNeeded; i++) {
      const page = this.allocatePage();
      if (page === null) {
        // Undo partial allocation before reporting failure.
        for (const p of pages) {
          this.freePage(p);
        }
        throw new Error("Out of KV cache memory");
      }
      pages.push(page);
    }
    this.pageTable.set(seqId, {
      seqId,
      pages,
      length: initialLength
    });
    return seqId;
  }
  /**
   * Grow a sequence by numNewTokens, allocating pages as needed. On pool
   * exhaustion, pages grabbed during this call are returned and the entry
   * is left exactly as before the call; throws afterwards.
   */
  extendSequence(seqId, numNewTokens) {
    const entry = this.pageTable.get(seqId);
    if (!entry) {
      throw new Error(`Sequence ${seqId} not found`);
    }
    const newLength = entry.length + numNewTokens;
    const neededPages = Math.ceil(newLength / this.config.pageSize);
    const added = [];
    while (entry.pages.length < neededPages) {
      const page = this.allocatePage();
      if (page === null) {
        // Roll back this call's allocations so state stays consistent.
        for (const p of added) {
          entry.pages.pop();
          this.freePage(p);
        }
        throw new Error("Out of KV cache memory");
      }
      entry.pages.push(page);
      added.push(page);
    }
    entry.length = newLength;
  }
  /** Return a sequence's pages to the free pool; no-op if unknown. */
  freeSequence(seqId) {
    const entry = this.pageTable.get(seqId);
    if (!entry)
      return;
    for (const page of entry.pages) {
      this.freePage(page);
    }
    this.pageTable.delete(seqId);
  }
  /** Physical page list for a sequence, or null if unknown. */
  getSequencePages(seqId) {
    return this.pageTable.get(seqId)?.pages ?? null;
  }
  /** Current token length of a sequence (0 if unknown). */
  getSequenceLength(seqId) {
    return this.pageTable.get(seqId)?.length ?? 0;
  }
  /** Physical page holding a given token position, or null. */
  getPageForPosition(seqId, position) {
    const entry = this.pageTable.get(seqId);
    if (!entry)
      return null;
    const pageIdx = Math.floor(position / this.config.pageSize);
    return entry.pages[pageIdx] ?? null;
  }
  /** Token offset within its page. */
  getOffsetInPage(position) {
    return position % this.config.pageSize;
  }
  /** Pop a free page index, or null when the pool is empty. */
  allocatePage() {
    if (this.freePages.length === 0)
      return null;
    return this.freePages.pop();
  }
  /** Push a page index back onto the free pool. */
  freePage(page) {
    this.freePages.push(page);
  }
  /** Occupancy counters and byte totals (x2 covers keys plus values). */
  getStats() {
    const bytesPerElement = this.config.dtype === "f16" ? 2 : 4;
    const pageBytes = this.config.pageSize * this.config.numHeads * this.config.headDim * bytesPerElement;
    const usedPages = this.config.maxPages - this.freePages.length;
    return {
      totalPages: this.config.maxPages,
      usedPages,
      freePages: this.freePages.length,
      numSequences: this.pageTable.size,
      memoryUsedBytes: usedPages * pageBytes * this.config.numLayers * 2,
      memoryTotalBytes: this.config.maxPages * pageBytes * this.config.numLayers * 2
    };
  }
  /** Raw GPU buffers backing the cache (for binding into kernels). */
  getBuffers() {
    return {
      keyCache: this.keyCache,
      valueCache: this.valueCache
    };
  }
  /** Shallow copy of the configuration. */
  getConfig() {
    return { ...this.config };
  }
  /** Destroy GPU buffers and reset all bookkeeping. */
  dispose() {
    this.keyCache.destroy();
    this.valueCache.destroy();
    this.pageTable.clear();
    this.freePages = [];
  }
}
2682
// src/attention/paged-kv/block-manager.ts
/**
 * Policy layer on top of PagedKVCache: admission control against the free
 * page pool (minus a reserve), per-sequence priorities, and eviction of
 * low-priority sequences when space runs out.
 *
 * Bug fixed in evict(): it tested a stats snapshot taken before any
 * eviction, so `freePages` never appeared to change inside the loop and,
 * once eviction started, EVERY prioritized sequence was evicted regardless
 * of how many pages were actually needed. Stats are now re-read after each
 * eviction.
 */
class BlockManager {
  cache;                 // underlying PagedKVCache
  config;                // cache config plus { policy, reservedPages }
  priorities = new Map;  // seqId -> numeric priority (lower evicts first)
  constructor(device, config) {
    this.config = {
      policy: "greedy",
      reservedPages: 0,
      ...config
    };
    this.cache = new PagedKVCache(device, config);
  }
  /** Can `request.numTokens` more tokens be stored without eviction? */
  canAllocate(request) {
    const stats = this.cache.getStats();
    const availablePages = stats.freePages - (this.config.reservedPages ?? 0);
    if (request.seqId !== undefined) {
      // Extension: only the additional pages count against the pool.
      const currentLength = this.cache.getSequenceLength(request.seqId);
      const currentPages = Math.ceil(currentLength / this.config.pageSize);
      const newLength = currentLength + request.numTokens;
      const newPages = Math.ceil(newLength / this.config.pageSize);
      return newPages - currentPages <= availablePages;
    }
    const neededPages = Math.ceil(request.numTokens / this.config.pageSize);
    return neededPages <= availablePages;
  }
  /**
   * Allocate pages for a new sequence (returns its id) or extend an
   * existing one (returns the same id); records priority when given.
   */
  allocate(request) {
    if (request.seqId !== undefined) {
      this.cache.extendSequence(request.seqId, request.numTokens);
      if (request.priority !== undefined) {
        this.priorities.set(request.seqId, request.priority);
      }
      return request.seqId;
    }
    const seqId = this.cache.allocateSequence(request.numTokens);
    if (request.priority !== undefined) {
      this.priorities.set(seqId, request.priority);
    }
    return seqId;
  }
  /** Release a sequence's pages and drop its priority record. */
  free(seqId) {
    this.cache.freeSequence(seqId);
    this.priorities.delete(seqId);
  }
  /**
   * Evict lowest-priority sequences until at least neededPages are free.
   * @returns the seqIds that were evicted (possibly empty)
   */
  evict(neededPages) {
    const evicted = [];
    // Lowest priority value first.
    const ordered = Array.from(this.priorities.entries())
      .sort((a, b) => a[1] - b[1])
      .map(([seqId]) => seqId);
    for (const seqId of ordered) {
      // Re-read stats every iteration: freeing a sequence changes freePages.
      if (this.cache.getStats().freePages >= neededPages) {
        break;
      }
      const pages = this.cache.getSequencePages(seqId);
      if (pages) {
        this.free(seqId);
        evicted.push(seqId);
      }
    }
    return evicted;
  }
  /** Fraction of the page pool currently in use (0..1). */
  getUtilization() {
    const stats = this.cache.getStats();
    return stats.usedPages / stats.totalPages;
  }
  /** The underlying PagedKVCache. */
  getCache() {
    return this.cache;
  }
  /** Pass-through to the cache's occupancy stats. */
  getStats() {
    return this.cache.getStats();
  }
  /** Destroy the cache's GPU buffers and clear priority records. */
  dispose() {
    this.cache.dispose();
    this.priorities.clear();
  }
}
2759
+
2760
// Continuous-batching admission control: requests start immediately when
// the BlockManager has room, otherwise they wait in a FIFO queue that is
// retried whenever a running sequence completes.
class ContinuousBatchScheduler {
  blockManager;
  runningSequences = new Set;
  waitingQueue = [];
  constructor(blockManager) {
    this.blockManager = blockManager;
  }
  // Try to start `request` right away; park it when capacity is lacking.
  addRequest(request) {
    if (!this.blockManager.canAllocate(request)) {
      this.waitingQueue.push(request);
      return;
    }
    const seqId = this.blockManager.allocate(request);
    this.runningSequences.add(seqId);
  }
  // Release a finished sequence and give queued requests a chance to run.
  completeSequence(seqId) {
    this.runningSequences.delete(seqId);
    this.blockManager.free(seqId);
    this.scheduleWaiting();
  }
  // Grow a running sequence by `numNewTokens`. Returns false when the
  // sequence is not running or there is no room for the extension.
  extendSequence(seqId, numNewTokens) {
    if (!this.runningSequences.has(seqId)) {
      return false;
    }
    const request = {
      seqId,
      numTokens: numNewTokens
    };
    if (!this.blockManager.canAllocate(request)) {
      return false;
    }
    this.blockManager.allocate(request);
    return true;
  }
  // Drain the waiting queue in arrival order, keeping whatever still
  // does not fit.
  scheduleWaiting() {
    const stillWaiting = [];
    for (const request of this.waitingQueue) {
      if (this.blockManager.canAllocate(request)) {
        this.runningSequences.add(this.blockManager.allocate(request));
      } else {
        stillWaiting.push(request);
      }
    }
    this.waitingQueue = stillWaiting;
  }
  getRunningCount() {
    return this.runningSequences.size;
  }
  getWaitingCount() {
    return this.waitingQueue.length;
  }
}
2813
// src/sampling/top-k.ts
/**
 * Select the k largest logits (and their vocab indices) per batch row.
 * Accepts a 1D [vocab] or 2D [batch, vocab] tensor; the sort runs on
 * the CPU after a readback. Returns { values, indices } tensors shaped
 * [k] or [batch, k].
 */
async function topK(device, logits, k) {
  const rank = logits.shape.length;
  if (rank !== 1 && rank !== 2) {
    throw new Error("topK expects 1D or 2D tensor");
  }
  const is2D = rank === 2;
  const batchSize = is2D ? logits.shape[0] : 1;
  const vocabSize = is2D ? logits.shape[1] : logits.shape[0];
  if (k > vocabSize) {
    throw new Error(`k (${k}) cannot be greater than vocab size (${vocabSize})`);
  }
  const logitsData = await logits.toArray();
  const valuesData = new Float32Array(batchSize * k);
  const indicesData = new Uint32Array(batchSize * k);
  for (let row = 0; row < batchSize; row++) {
    const base = row * vocabSize;
    // Rank every vocab position by its logit, descending.
    const order = Array.from({ length: vocabSize }, (_, i) => i);
    order.sort((lhs, rhs) => logitsData[base + rhs] - logitsData[base + lhs]);
    for (let j = 0; j < k; j++) {
      const vocabIdx = order[j];
      valuesData[row * k + j] = logitsData[base + vocabIdx];
      indicesData[row * k + j] = vocabIdx;
    }
  }
  const valuesShape = is2D ? [batchSize, k] : [k];
  const indicesShape = is2D ? [batchSize, k] : [k];
  const values = await Tensor.fromArray(device, valuesShape, valuesData);
  const indices = new Tensor(device, indicesShape, "u32", indicesData);
  return { values, indices };
}
2843
/**
 * CPU top-k over a flat logits array: the k largest values and their
 * indices, sorted descending.
 */
function topKCPU(logits, k, vocabSize) {
  const order = Array.from({ length: vocabSize }, (_, i) => i);
  order.sort((lhs, rhs) => logits[rhs] - logits[lhs]);
  const values = new Float32Array(k);
  const topIndices = new Uint32Array(k);
  for (let rank = 0; rank < k; rank++) {
    const best = order[rank];
    values[rank] = logits[best];
    topIndices[rank] = best;
  }
  return { values, indices: topIndices };
}
2855
/**
 * Mask all but the k highest logits per row to -Infinity (ties at the
 * k-th value are all kept). Runs on the CPU after a readback; returns a
 * new tensor with the original shape.
 */
async function topKFilter(device, logits, k) {
  const logitsData = await logits.toArray();
  const vocabSize = logits.shape[logits.shape.length - 1];
  const batchSize = logits.numel / vocabSize;
  const filtered = new Float32Array(logits.numel);
  for (let row = 0; row < batchSize; row++) {
    const base = row * vocabSize;
    // The k-th largest value of this row is the cut-off.
    const rowCopy = Float32Array.from(logitsData.slice(base, base + vocabSize));
    rowCopy.sort((x, y) => y - x);
    const threshold = rowCopy[k - 1];
    for (let i = 0; i < vocabSize; i++) {
      const keep = logitsData[base + i] >= threshold;
      filtered[base + i] = keep ? logitsData[base + i] : -Infinity;
    }
  }
  return Tensor.fromArray(device, [...logits.shape], filtered);
}
2875
// src/sampling/top-p.ts
/**
 * Nucleus (top-p) filtering: keep the smallest set of logits whose
 * softmax probabilities sum to >= p, mask the rest to -Infinity.
 * `temperature` scales the logits before the softmax used for the cut;
 * the surviving entries keep their *unscaled* values. Runs on the CPU
 * after a readback; returns a new tensor with the original shape.
 */
async function topPFilter(device, logits, p, temperature = 1) {
  if (p <= 0 || p > 1) {
    throw new Error("p must be in (0, 1]");
  }
  const logitsData = await logits.toArray();
  const vocabSize = logits.shape[logits.shape.length - 1];
  const batchSize = logits.numel / vocabSize;
  const filtered = new Float32Array(logits.numel);
  for (let b = 0; b < batchSize; b++) {
    const offset = b * vocabSize;
    const scaledLogits = new Float32Array(vocabSize);
    for (let i = 0; i < vocabSize; i++) {
      scaledLogits[i] = logitsData[offset + i] / temperature;
    }
    // BUGFIX: Math.max(...scaledLogits) exceeds the engine's argument
    // limit and throws RangeError for vocab-sized arrays; scan instead.
    let maxLogit = -Infinity;
    for (let i = 0; i < vocabSize; i++) {
      if (scaledLogits[i] > maxLogit) maxLogit = scaledLogits[i];
    }
    const expLogits = scaledLogits.map((l) => Math.exp(l - maxLogit));
    const sumExp = expLogits.reduce((a, b2) => a + b2, 0);
    const probs = expLogits.map((e) => e / sumExp);
    // Walk indices from most to least probable, keeping until mass >= p.
    const indices = new Array(vocabSize).fill(0).map((_, i) => i);
    indices.sort((a, b2) => probs[b2] - probs[a]);
    let cumProb = 0;
    const keepIndices = new Set;
    for (const idx of indices) {
      cumProb += probs[idx];
      keepIndices.add(idx);
      if (cumProb >= p)
        break;
    }
    for (let i = 0; i < vocabSize; i++) {
      filtered[offset + i] = keepIndices.has(i) ? logitsData[offset + i] : -Infinity;
    }
  }
  return Tensor.fromArray(device, [...logits.shape], filtered);
}
2914
/**
 * CPU nucleus (top-p) filter over a flat logits array: softmax at the
 * given temperature, keep the smallest high-probability set whose mass
 * reaches p, mask everything else to -Infinity. Survivors keep their
 * original (unscaled) logit values.
 */
function topPFilterCPU(logits, p, temperature = 1) {
  const vocabSize = logits.length;
  const scaledLogits = new Float32Array(vocabSize);
  for (let i = 0; i < vocabSize; i++) {
    scaledLogits[i] = logits[i] / temperature;
  }
  // BUGFIX: Math.max(...scaledLogits) exceeds the engine's argument
  // limit and throws RangeError for vocab-sized arrays; scan instead.
  let maxLogit = -Infinity;
  for (let i = 0; i < vocabSize; i++) {
    if (scaledLogits[i] > maxLogit) maxLogit = scaledLogits[i];
  }
  const expLogits = scaledLogits.map((l) => Math.exp(l - maxLogit));
  const sumExp = expLogits.reduce((a, b) => a + b, 0);
  const probs = expLogits.map((e) => e / sumExp);
  // Walk indices from most to least probable, keeping until mass >= p.
  const indices = new Array(vocabSize).fill(0).map((_, i) => i);
  indices.sort((a, b) => probs[b] - probs[a]);
  let cumProb = 0;
  const keepIndices = new Set;
  for (const idx of indices) {
    cumProb += probs[idx];
    keepIndices.add(idx);
    if (cumProb >= p)
      break;
  }
  const filtered = new Float32Array(vocabSize);
  for (let i = 0; i < vocabSize; i++) {
    filtered[i] = keepIndices.has(i) ? logits[i] : -Infinity;
  }
  return filtered;
}
2940
// src/sampling/sampler.ts
/**
 * Numerically stable softmax: shift by the max before exponentiating.
 * Returns the same array type as the input (Float32Array in, f32 out).
 */
function softmax(logits) {
  // BUGFIX: Math.max(...logits) spreads every element as a call
  // argument and throws RangeError for large vocab-sized arrays; find
  // the max with a plain scan instead.
  let maxLogit = -Infinity;
  for (let i = 0; i < logits.length; i++) {
    if (logits[i] > maxLogit) maxLogit = logits[i];
  }
  const expLogits = logits.map((l) => Math.exp(l - maxLogit));
  const sumExp = expLogits.reduce((a, b) => a + b, 0);
  return expLogits.map((e) => e / sumExp);
}
2947
/**
 * CTRL-style repetition penalty: each previously emitted token has its
 * logit divided by `penalty` when positive and multiplied when negative,
 * making it less likely to repeat. Out-of-range token ids are ignored.
 * `penalty === 1` is a no-op and returns the input array unchanged
 * (same reference).
 */
function applyRepetitionPenalty(logits, previousTokens, penalty) {
  if (penalty === 1)
    return logits;
  const result = new Float32Array(logits);
  for (const token of previousTokens) {
    const inRange = token >= 0 && token < logits.length;
    if (!inRange)
      continue;
    result[token] = result[token] > 0 ? result[token] / penalty : result[token] * penalty;
  }
  return result;
}
2962
/**
 * Draw an index from a probability distribution via an inverse-CDF
 * walk. `random` must return a float in [0, 1); the final index is the
 * fallback so rounding error in the cumulative sum cannot run off the
 * end of the array.
 */
function sampleFromProbs(probs, random = Math.random) {
  const target = random();
  let cumulative = 0;
  for (let idx = 0; idx < probs.length; idx++) {
    cumulative += probs[idx];
    if (target < cumulative) {
      return idx;
    }
  }
  return probs.length - 1;
}
2973
/**
 * Argmax over the logits; ties resolve to the earliest index.
 */
function sampleGreedy(logits) {
  let best = 0;
  for (let i = 1; i < logits.length; i++) {
    if (logits[i] > logits[best]) {
      best = i;
    }
  }
  return best;
}
2984
/**
 * Sample one token id from a 1D logits tensor. Pipeline, in order:
 * repetition penalty, temperature scaling, top-k filter, top-p filter,
 * softmax, multinomial draw. A (near-)zero temperature short-circuits
 * to greedy argmax.
 */
async function sample(device, logits, config = {}, previousTokens = []) {
  const {
    temperature = 1,
    topK: topK2 = 0,
    topP = 1,
    repetitionPenalty = 1
  } = config;
  // Run a tensor-based filter over CPU data: upload, filter, read back,
  // then release both tensors.
  const runFilter = async (data, filterFn) => {
    const input = await Tensor.fromArray(device, [data.length], data);
    const output = await filterFn(input);
    const result = new Float32Array(await output.toArray());
    input.dispose();
    output.dispose();
    return result;
  };
  let logitsData = await logits.toArray();
  if (repetitionPenalty !== 1 && previousTokens.length > 0) {
    logitsData = applyRepetitionPenalty(logitsData, previousTokens, repetitionPenalty);
  }
  if (temperature === 0 || temperature < 0.000001) {
    return sampleGreedy(logitsData);
  }
  let filteredLogits = new Float32Array(logitsData.length);
  for (let i = 0; i < logitsData.length; i++) {
    filteredLogits[i] = logitsData[i] / temperature;
  }
  if (topK2 > 0 && topK2 < logitsData.length) {
    filteredLogits = await runFilter(filteredLogits, (t) => topKFilter(device, t, topK2));
  }
  if (topP < 1) {
    // Temperature was already applied, so the filter gets temperature 1.
    filteredLogits = await runFilter(filteredLogits, (t) => topPFilter(device, t, topP, 1));
  }
  return sampleFromProbs(softmax(filteredLogits));
}
3020
/**
 * Pure-CPU variant of sample(): repetition penalty, temperature
 * scaling, top-k, top-p, softmax, then a multinomial draw. A
 * (near-)zero temperature short-circuits to greedy argmax.
 */
function sampleCPU(logits, config = {}, previousTokens = []) {
  const {
    temperature = 1,
    topK: topK2 = 0,
    topP = 1,
    repetitionPenalty = 1
  } = config;
  let processed = new Float32Array(logits);
  if (repetitionPenalty !== 1 && previousTokens.length > 0) {
    processed = new Float32Array(applyRepetitionPenalty(processed, previousTokens, repetitionPenalty));
  }
  if (temperature === 0 || temperature < 0.000001) {
    return sampleGreedy(processed);
  }
  for (let i = 0; i < processed.length; i++) {
    processed[i] /= temperature;
  }
  if (topK2 > 0 && topK2 < processed.length) {
    // Everything strictly below the k-th largest value is masked out.
    const descending = new Float32Array(processed).sort((x, y) => y - x);
    const cutoff = descending[topK2 - 1];
    for (let i = 0; i < processed.length; i++) {
      if (processed[i] < cutoff) {
        processed[i] = -Infinity;
      }
    }
  }
  if (topP < 1) {
    const probs2 = softmax(processed);
    // Keep the most probable indices until their mass reaches topP.
    const ranked = Array.from(probs2.keys()).sort((x, y) => probs2[y] - probs2[x]);
    let cumulative = 0;
    const nucleus = new Set;
    for (const idx of ranked) {
      cumulative += probs2[idx];
      nucleus.add(idx);
      if (cumulative >= topP)
        break;
    }
    for (let i = 0; i < processed.length; i++) {
      if (!nucleus.has(i)) {
        processed[i] = -Infinity;
      }
    }
  }
  return sampleFromProbs(softmax(processed));
}
3067
// src/model/types.ts
// GGUF/ggml tensor data-type ids (compiled TypeScript enum: each entry
// installs both a name -> id and an id -> name mapping on the object).
// Ids 4 and 5 are intentionally skipped in this table.
var GGUFQuantType;
((GGUFQuantType2) => {
  GGUFQuantType2[GGUFQuantType2["F32"] = 0] = "F32";
  GGUFQuantType2[GGUFQuantType2["F16"] = 1] = "F16";
  GGUFQuantType2[GGUFQuantType2["Q4_0"] = 2] = "Q4_0";
  GGUFQuantType2[GGUFQuantType2["Q4_1"] = 3] = "Q4_1";
  GGUFQuantType2[GGUFQuantType2["Q5_0"] = 6] = "Q5_0";
  GGUFQuantType2[GGUFQuantType2["Q5_1"] = 7] = "Q5_1";
  GGUFQuantType2[GGUFQuantType2["Q8_0"] = 8] = "Q8_0";
  GGUFQuantType2[GGUFQuantType2["Q8_1"] = 9] = "Q8_1";
  GGUFQuantType2[GGUFQuantType2["Q2_K"] = 10] = "Q2_K";
  GGUFQuantType2[GGUFQuantType2["Q3_K"] = 11] = "Q3_K";
  GGUFQuantType2[GGUFQuantType2["Q4_K"] = 12] = "Q4_K";
  GGUFQuantType2[GGUFQuantType2["Q5_K"] = 13] = "Q5_K";
  GGUFQuantType2[GGUFQuantType2["Q6_K"] = 14] = "Q6_K";
  GGUFQuantType2[GGUFQuantType2["Q8_K"] = 15] = "Q8_K";
  GGUFQuantType2[GGUFQuantType2["IQ2_XXS"] = 16] = "IQ2_XXS";
  GGUFQuantType2[GGUFQuantType2["IQ2_XS"] = 17] = "IQ2_XS";
  GGUFQuantType2[GGUFQuantType2["IQ3_XXS"] = 18] = "IQ3_XXS";
  GGUFQuantType2[GGUFQuantType2["IQ1_S"] = 19] = "IQ1_S";
  GGUFQuantType2[GGUFQuantType2["IQ4_NL"] = 20] = "IQ4_NL";
  GGUFQuantType2[GGUFQuantType2["IQ3_S"] = 21] = "IQ3_S";
  GGUFQuantType2[GGUFQuantType2["IQ2_S"] = 22] = "IQ2_S";
  GGUFQuantType2[GGUFQuantType2["IQ4_XS"] = 23] = "IQ4_XS";
  GGUFQuantType2[GGUFQuantType2["I8"] = 24] = "I8";
  GGUFQuantType2[GGUFQuantType2["I16"] = 25] = "I16";
  GGUFQuantType2[GGUFQuantType2["I32"] = 26] = "I32";
  GGUFQuantType2[GGUFQuantType2["I64"] = 27] = "I64";
  GGUFQuantType2[GGUFQuantType2["F64"] = 28] = "F64";
  GGUFQuantType2[GGUFQuantType2["BF16"] = 29] = "BF16";
})(GGUFQuantType ||= {});
3099
// Wire-format type tags for GGUF metadata key/value entries (compiled
// TypeScript enum; forward and reverse mappings). These are the codes
// parseMetadataValue() dispatches on.
var GGUFMetadataValueType;
((GGUFMetadataValueType2) => {
  GGUFMetadataValueType2[GGUFMetadataValueType2["UINT8"] = 0] = "UINT8";
  GGUFMetadataValueType2[GGUFMetadataValueType2["INT8"] = 1] = "INT8";
  GGUFMetadataValueType2[GGUFMetadataValueType2["UINT16"] = 2] = "UINT16";
  GGUFMetadataValueType2[GGUFMetadataValueType2["INT16"] = 3] = "INT16";
  GGUFMetadataValueType2[GGUFMetadataValueType2["UINT32"] = 4] = "UINT32";
  GGUFMetadataValueType2[GGUFMetadataValueType2["INT32"] = 5] = "INT32";
  GGUFMetadataValueType2[GGUFMetadataValueType2["FLOAT32"] = 6] = "FLOAT32";
  GGUFMetadataValueType2[GGUFMetadataValueType2["BOOL"] = 7] = "BOOL";
  GGUFMetadataValueType2[GGUFMetadataValueType2["STRING"] = 8] = "STRING";
  GGUFMetadataValueType2[GGUFMetadataValueType2["ARRAY"] = 9] = "ARRAY";
  GGUFMetadataValueType2[GGUFMetadataValueType2["UINT64"] = 10] = "UINT64";
  GGUFMetadataValueType2[GGUFMetadataValueType2["INT64"] = 11] = "INT64";
  GGUFMetadataValueType2[GGUFMetadataValueType2["FLOAT64"] = 12] = "FLOAT64";
})(GGUFMetadataValueType ||= {});
3115
// Elements packed into one quantization block, keyed by GGUFQuantType:
// 32 for the classic Q*_0/Q*_1 formats, 256 for the K-quants. Used with
// GGUF_QUANT_BYTES_PER_BLOCK to size tensor payloads.
var GGUF_QUANT_BLOCK_SIZE = {
  [2 /* Q4_0 */]: 32,
  [3 /* Q4_1 */]: 32,
  [6 /* Q5_0 */]: 32,
  [7 /* Q5_1 */]: 32,
  [8 /* Q8_0 */]: 32,
  [9 /* Q8_1 */]: 32,
  [10 /* Q2_K */]: 256,
  [11 /* Q3_K */]: 256,
  [12 /* Q4_K */]: 256,
  [13 /* Q5_K */]: 256,
  [14 /* Q6_K */]: 256
};
3128
// On-disk bytes per block (or per element for F32/F16), keyed by
// GGUFQuantType. E.g. Q4_0 = 2-byte f16 scale + 16 packed bytes = 18;
// Q8_0 = 2-byte scale + 32 signed bytes = 34.
var GGUF_QUANT_BYTES_PER_BLOCK = {
  [0 /* F32 */]: 4,
  [1 /* F16 */]: 2,
  [2 /* Q4_0 */]: 18,
  [3 /* Q4_1 */]: 20,
  [6 /* Q5_0 */]: 22,
  [7 /* Q5_1 */]: 24,
  [8 /* Q8_0 */]: 34,
  [9 /* Q8_1 */]: 36,
  [10 /* Q2_K */]: 84,
  [11 /* Q3_K */]: 110,
  [12 /* Q4_K */]: 144,
  [13 /* Q5_K */]: 176,
  [14 /* Q6_K */]: 210
};
3143
// src/model/safetensors.ts
/**
 * Parse the SafeTensors preamble: an 8-byte little-endian u64 header
 * length followed by a JSON header. Returns the parsed header (tensor
 * entries split from the optional __metadata__ key) and the byte offset
 * where tensor data begins. Throws on oversized or truncated headers
 * and on malformed JSON.
 */
function parseSafetensorsHeader(buffer) {
  const view = new DataView(buffer);
  const headerSizeLow = view.getUint32(0, true);
  const headerSizeHigh = view.getUint32(4, true);
  // The length field is u64 on disk; anything needing the high word is
  // rejected outright.
  if (headerSizeHigh > 0) {
    throw new Error("Header size too large (exceeds 32-bit range)");
  }
  const headerSize = headerSizeLow;
  const dataOffset = 8 + headerSize;
  if (dataOffset > buffer.byteLength) {
    throw new Error(`Invalid SafeTensors file: header size ${headerSize} exceeds file size`);
  }
  const headerJson = new TextDecoder("utf-8").decode(new Uint8Array(buffer, 8, headerSize));
  let header;
  try {
    const { __metadata__, ...tensors } = JSON.parse(headerJson);
    header = { tensors, __metadata__ };
  } catch (e) {
    throw new Error(`Failed to parse SafeTensors header JSON: ${e}`);
  }
  return { header, dataOffset };
}
3171
/**
 * Convert header entries into a name -> tensor-info Map. Header
 * data_offsets are relative to the data section; the stored `offset`
 * is absolute within the file buffer.
 */
function getSafetensorsTensorInfos(header, dataOffset) {
  const tensorInfos = new Map;
  for (const [name, entry] of Object.entries(header.tensors)) {
    const [start, end] = entry.data_offsets;
    tensorInfos.set(name, {
      name,
      shape: entry.shape,
      dtype: entry.dtype,
      offset: dataOffset + start,
      byteSize: end - start
    });
  }
  return tensorInfos;
}
3186
/**
 * Decode one tensor's raw bytes into a Float32Array. F16/BF16/F64 are
 * converted element-wise to f32; integer dtypes are not supported yet.
 */
function loadSafetensorsTensor(buffer, info) {
  const dtype = info.dtype;
  const tensorData = new Uint8Array(buffer, info.offset, info.byteSize);
  // BUGFIX: copy into a fresh offset-0 buffer before creating typed
  // views. `info.offset` is 8 + JSON header length + relative offset,
  // and nothing guarantees it is a multiple of the element size; a
  // typed array constructed at a misaligned byteOffset throws a
  // RangeError.
  const aligned = tensorData.slice();
  switch (dtype) {
    case "F32": {
      return new Float32Array(aligned.buffer, 0, aligned.byteLength / 4);
    }
    case "F16": {
      const numel = info.shape.reduce((a, b) => a * b, 1);
      const result = new Float32Array(numel);
      const uint16View = new Uint16Array(aligned.buffer, 0, numel);
      for (let i = 0; i < numel; i++) {
        result[i] = float16ToFloat32(uint16View[i]);
      }
      return result;
    }
    case "BF16": {
      const numel = info.shape.reduce((a, b) => a * b, 1);
      const result = new Float32Array(numel);
      const uint16View = new Uint16Array(aligned.buffer, 0, numel);
      for (let i = 0; i < numel; i++) {
        result[i] = bfloat16ToFloat32(uint16View[i]);
      }
      return result;
    }
    case "F64": {
      const numel = info.shape.reduce((a, b) => a * b, 1);
      const result = new Float32Array(numel);
      const float64View = new Float64Array(aligned.buffer, 0, numel);
      for (let i = 0; i < numel; i++) {
        result[i] = float64View[i];
      }
      return result;
    }
    case "I8":
    case "U8":
    case "I16":
    case "I32":
    case "I64":
    case "BOOL": {
      throw new Error(`Integer dtype ${dtype} not yet supported for loading`);
    }
    default:
      throw new Error(`Unknown dtype: ${dtype}`);
  }
}
3233
/**
 * Decode an IEEE 754 binary16 value (given as a uint16 bit pattern)
 * into a JS number. Handles signed zero, subnormals, infinities and
 * NaN.
 */
function float16ToFloat32(h) {
  const sign = h >> 15 & 1;
  const exponent = h >> 10 & 31;
  const fraction = h & 1023;
  const signed = sign === 1 ? -1 : 1;
  if (exponent === 31) {
    // All-ones exponent: infinity when the mantissa is clear, else NaN.
    if (fraction !== 0) {
      return NaN;
    }
    return signed * Infinity;
  }
  if (exponent === 0) {
    // Zero / subnormal range (no implicit leading 1).
    if (fraction === 0) {
      return sign === 1 ? -0 : 0;
    }
    return signed * Math.pow(2, -14) * (fraction / 1024);
  }
  // Normal numbers: implicit leading 1, exponent bias 15.
  return signed * Math.pow(2, exponent - 15) * (1 + fraction / 1024);
}
3250
/**
 * Decode bfloat16 (the high 16 bits of an f32) into a JS number by
 * widening the bit pattern back to 32 bits and reinterpreting it.
 */
function bfloat16ToFloat32(bf16) {
  const scratch = new DataView(new ArrayBuffer(4));
  scratch.setUint32(0, bf16 << 16, false);
  return scratch.getFloat32(0, false);
}
3256
/**
 * Best-effort model metadata inferred from SafeTensors tensor names:
 * architecture (llama- vs gpt2-style naming), vocab/embedding sizes
 * from the embedding weight, and layer count from the highest
 * `layers.N.` index seen.
 */
function extractMetadata(header) {
  const meta = {
    format: "safetensors",
    extra: header.__metadata__
  };
  const tensorNames = Object.keys(header.tensors);
  const hasName = (marker) => tensorNames.some((n) => n.includes(marker));
  if (hasName("model.layers.")) {
    meta.architecture = "llama";
  } else if (hasName("transformer.h.")) {
    meta.architecture = "gpt2";
  }
  for (const [name, entry] of Object.entries(header.tensors)) {
    if (name.includes("embed_tokens") || name.includes("wte")) {
      meta.vocabSize = entry.shape[0];
      meta.embeddingLength = entry.shape[1];
    }
    if (name.includes("layers.0.self_attn.q_proj")) {
      meta.embeddingLength = entry.shape[1];
    }
  }
  // Layer count = highest layer index seen + 1.
  const layerNums = tensorNames.map((n) => {
    const match = n.match(/layers\.(\d+)\./);
    return match ? parseInt(match[1], 10) : -1;
  }).filter((n) => n >= 0);
  if (layerNums.length > 0) {
    meta.numLayers = Math.max(...layerNums) + 1;
  }
  return meta;
}
3285
/**
 * Parse a SafeTensors buffer into { metadata, tensorInfos, totalBytes,
 * buffer, dataOffset }. `options.tensorFilter` drops tensors by name
 * before sizes are totalled. Tensor payloads are not decoded here.
 */
function loadSafetensors(buffer, options) {
  const { header, dataOffset } = parseSafetensorsHeader(buffer);
  const tensorInfos = getSafetensorsTensorInfos(header, dataOffset);
  const filter = options?.tensorFilter;
  if (filter) {
    // Deleting while iterating keys() is safe for Maps.
    for (const name of tensorInfos.keys()) {
      if (!filter(name)) {
        tensorInfos.delete(name);
      }
    }
  }
  let totalBytes = 0;
  for (const info of tensorInfos.values()) {
    totalBytes += info.byteSize;
  }
  return {
    metadata: extractMetadata(header),
    tensorInfos,
    totalBytes,
    buffer,
    dataOffset
  };
}
3308
/**
 * Fetch a SafeTensors file and parse it; throws on a non-2xx response.
 */
async function loadSafetensorsFromUrl(url, options) {
  const response = await fetch(url);
  if (!response.ok) {
    throw new Error(`Failed to fetch ${url}: ${response.statusText}`);
  }
  return loadSafetensors(await response.arrayBuffer(), options);
}
3316
/**
 * Cheap sniff test: does this buffer plausibly start with a SafeTensors
 * preamble (sane u64 header length followed by JSON)? Never throws.
 */
function isSafetensors(buffer) {
  if (buffer.byteLength < 8)
    return false;
  try {
    const view = new DataView(buffer);
    const headerSize = view.getUint32(0, true);
    const headerSizeHigh = view.getUint32(4, true);
    // Reject lengths that need the high u64 word, exceed a 100 MB cap,
    // or overrun the buffer.
    const sane = headerSizeHigh === 0 && headerSize <= 100 * 1024 * 1024 && 8 + headerSize <= buffer.byteLength;
    if (!sane)
      return false;
    // Only peek at the first 100 bytes — enough to spot the JSON brace.
    const peek = new Uint8Array(buffer, 8, Math.min(headerSize, 100));
    return new TextDecoder("utf-8").decode(peek).trimStart().startsWith("{");
  } catch {
    return false;
  }
}
3336
// src/model/gguf.ts
// "GGUF" in little-endian ASCII (0x46554747).
var GGUF_MAGIC = 1179993927;
// Only this container version is accepted by parseGGUFHeader.
var GGUF_VERSION = 3;
// Data-section alignment used when general.alignment is absent.
var DEFAULT_ALIGNMENT = 32;
3340
+
3341
// Little-endian cursor over an ArrayBuffer, with the primitive readers
// the GGUF parser needs. All multi-byte reads advance `offset`.
class GGUFReader {
  view;
  offset = 0;
  textDecoder = new TextDecoder("utf-8");
  constructor(buffer) {
    this.view = new DataView(buffer);
  }
  get position() {
    return this.offset;
  }
  set position(pos) {
    this.offset = pos;
  }
  // Advance the cursor by `size` bytes and return where it used to be.
  #take(size) {
    const at = this.offset;
    this.offset += size;
    return at;
  }
  readUint8() {
    return this.view.getUint8(this.#take(1));
  }
  readInt8() {
    return this.view.getInt8(this.#take(1));
  }
  readUint16() {
    return this.view.getUint16(this.#take(2), true);
  }
  readInt16() {
    return this.view.getInt16(this.#take(2), true);
  }
  readUint32() {
    return this.view.getUint32(this.#take(4), true);
  }
  readInt32() {
    return this.view.getInt32(this.#take(4), true);
  }
  // 64-bit reads return BigInt.
  readUint64() {
    return this.view.getBigUint64(this.#take(8), true);
  }
  readInt64() {
    return this.view.getBigInt64(this.#take(8), true);
  }
  readFloat32() {
    return this.view.getFloat32(this.#take(4), true);
  }
  readFloat64() {
    return this.view.getFloat64(this.#take(8), true);
  }
  readBool() {
    return this.readUint8() !== 0;
  }
  // GGUF string: u64 byte length followed by UTF-8 bytes, no terminator.
  readString() {
    const length = Number(this.readUint64());
    const bytes = new Uint8Array(this.view.buffer, this.#take(length), length);
    return this.textDecoder.decode(bytes);
  }
  // Round the cursor up to the next multiple of `alignment`.
  alignTo(alignment) {
    const overshoot = this.offset % alignment;
    if (overshoot !== 0) {
      this.offset += alignment - overshoot;
    }
  }
}
3420
/**
 * Read and validate the fixed GGUF header: magic, version, then tensor
 * count and metadata KV count (both returned as BigInt).
 */
function parseGGUFHeader(reader) {
  const magic = reader.readUint32();
  if (magic !== GGUF_MAGIC) {
    throw new Error(`Invalid GGUF magic: expected 0x${GGUF_MAGIC.toString(16)}, got 0x${magic.toString(16)}`);
  }
  const version = reader.readUint32();
  if (version !== GGUF_VERSION) {
    throw new Error(`Unsupported GGUF version: ${version} (expected ${GGUF_VERSION})`);
  }
  return {
    magic,
    version,
    nTensors: reader.readUint64(),
    nKV: reader.readUint64()
  };
}
3433
/**
 * Decode a single GGUF metadata value of the given wire type
 * (GGUFMetadataValueType code). Arrays recurse with a shared element
 * type. Throws on unknown type codes.
 */
function parseMetadataValue(reader, valueType) {
  const scalarReaders = {
    0: () => reader.readUint8(),
    1: () => reader.readInt8(),
    2: () => reader.readUint16(),
    3: () => reader.readInt16(),
    4: () => reader.readUint32(),
    5: () => reader.readInt32(),
    6: () => reader.readFloat32(),
    7: () => reader.readBool(),
    8: () => reader.readString(),
    10: () => reader.readUint64(),
    11: () => reader.readInt64(),
    12: () => reader.readFloat64()
  };
  const scalar = scalarReaders[valueType];
  if (scalar) {
    return scalar();
  }
  if (valueType === 9 /* ARRAY */) {
    // Array: element type (u32), length (u64), then `length` elements.
    const arrayType = reader.readUint32();
    const arrayLen = Number(reader.readUint64());
    const result = [];
    for (let i = 0; i < arrayLen; i++) {
      result.push(parseMetadataValue(reader, arrayType));
    }
    return result;
  }
  throw new Error(`Unknown metadata value type: ${valueType}`);
}
3472
/**
 * Read `nKV` (BigInt) metadata entries into a Map. Each entry on the
 * wire is: string key, u32 type tag, then the typed value.
 */
function parseGGUFMetadata(reader, nKV) {
  const metadata = new Map;
  for (let remaining = nKV; remaining > 0n; remaining--) {
    const key = reader.readString();
    const valueType = reader.readUint32();
    metadata.set(key, parseMetadataValue(reader, valueType));
  }
  return metadata;
}
3482
/**
 * Read `nTensors` (BigInt) tensor descriptors. Each entry on the wire
 * is: name, dim count (u32), dims (u64 each, kept as BigInt), ggml
 * type id (u32), and data-section-relative offset (u64, BigInt).
 */
function parseGGUFTensorInfos(reader, nTensors) {
  const tensorInfos = [];
  for (let remaining = nTensors; remaining > 0n; remaining--) {
    const name = reader.readString();
    const nDims = reader.readUint32();
    const dimensions = Array.from({ length: nDims }, () => reader.readUint64());
    tensorInfos.push({
      name,
      nDims,
      dimensions,
      type: reader.readUint32(),
      offset: reader.readUint64()
    });
  }
  return tensorInfos;
}
3497
/**
 * Byte size of a tensor's on-disk payload. Plain numeric types use a
 * fixed element width; quantized types round up to whole blocks using
 * the GGUF block tables. Throws for unknown quantization ids.
 */
function calculateGGUFTensorBytes(type, shape) {
  const numel = shape.reduce((total, dim) => total * dim, 1);
  // Fixed bytes-per-element for the non-quantized types.
  const elementWidth = {
    [0 /* F32 */]: 4,
    [1 /* F16 */]: 2,
    [29 /* BF16 */]: 2,
    [24 /* I8 */]: 1,
    [25 /* I16 */]: 2,
    [26 /* I32 */]: 4,
    [27 /* I64 */]: 8,
    [28 /* F64 */]: 8
  }[type];
  if (elementWidth !== undefined) {
    return numel * elementWidth;
  }
  const blockSize = GGUF_QUANT_BLOCK_SIZE[type];
  const bytesPerBlock = GGUF_QUANT_BYTES_PER_BLOCK[type];
  if (blockSize === undefined || bytesPerBlock === undefined) {
    throw new Error(`Unknown quantization type: ${type}`);
  }
  return Math.ceil(numel / blockSize) * bytesPerBlock;
}
3525
/**
 * Normalize a raw GGUF tensor descriptor: BigInt dims and offset become
 * numbers, the offset becomes absolute within the file buffer, and the
 * payload byte size is computed from the type and shape.
 */
function convertTensorInfo(info, dataOffset) {
  const shape = info.dimensions.map((d) => Number(d));
  return {
    name: info.name,
    shape,
    dtype: info.type,
    offset: dataOffset + Number(info.offset),
    byteSize: calculateGGUFTensorBytes(info.type, shape)
  };
}
3536
/**
 * Lift well-known GGUF metadata keys into a flat metadata object; the
 * full key/value map is preserved under `extra`. headDim is derived
 * when both embedding length and head count are present.
 */
function extractGGUFMetadata(metadata) {
  const meta = {
    format: "gguf",
    extra: Object.fromEntries(metadata)
  };
  meta.name = metadata.get("general.name");
  meta.architecture = metadata.get("general.architecture");
  const arch = meta.architecture || "";
  // Architecture-scoped keys live under e.g. "llama.context_length".
  const archKey = (suffix) => metadata.get(`${arch}.${suffix}`);
  meta.contextLength = archKey("context_length");
  meta.embeddingLength = archKey("embedding_length");
  meta.numLayers = archKey("block_count");
  meta.numHeads = archKey("attention.head_count");
  meta.numKVHeads = archKey("attention.head_count_kv");
  meta.vocabSize = archKey("vocab_size");
  meta.ropeFreqBase = archKey("rope.freq_base");
  if (meta.embeddingLength && meta.numHeads) {
    meta.headDim = meta.embeddingLength / meta.numHeads;
  }
  return meta;
}
3556
/**
 * Parse a GGUF buffer: header, metadata map, tensor directory. Tensor
 * payloads stay in place; each info records an absolute offset into
 * `buffer`. `options.tensorFilter` drops tensors by name before sizes
 * are totalled.
 */
function loadGGUF(buffer, options) {
  const reader = new GGUFReader(buffer);
  const header = parseGGUFHeader(reader);
  const rawMetadata = parseGGUFMetadata(reader, header.nKV);
  // BUGFIX: general.alignment may arrive as a BigInt (a UINT64-typed
  // metadata value); coerce to a number so the modulo inside alignTo()
  // cannot throw on a BigInt/number mix. Falsy or NaN values fall back
  // to the default (32), as before.
  const alignment = Number(rawMetadata.get("general.alignment") ?? DEFAULT_ALIGNMENT) || DEFAULT_ALIGNMENT;
  const ggufTensorInfos = parseGGUFTensorInfos(reader, header.nTensors);
  // Tensor data starts at the next aligned boundary after the directory.
  reader.alignTo(alignment);
  const dataOffset = reader.position;
  const tensorInfos = new Map;
  let totalBytes = 0;
  for (const info of ggufTensorInfos) {
    if (options?.tensorFilter && !options.tensorFilter(info.name)) {
      continue;
    }
    const converted = convertTensorInfo(info, dataOffset);
    tensorInfos.set(info.name, converted);
    totalBytes += converted.byteSize;
  }
  const metadata = extractGGUFMetadata(rawMetadata);
  return {
    metadata,
    tensorInfos,
    totalBytes,
    buffer,
    dataOffset
  };
}
3583
/**
 * Fetch a GGUF file and parse it; throws on a non-2xx response.
 */
async function loadGGUFFromUrl(url, options) {
  const response = await fetch(url);
  if (!response.ok) {
    throw new Error(`Failed to fetch ${url}: ${response.statusText}`);
  }
  return loadGGUF(await response.arrayBuffer(), options);
}
3591
/**
 * Expand one Q4_0 block (2-byte little-endian f16 scale + 16 bytes of
 * packed 4-bit values) into 32 floats. Each nibble is biased by 8; the
 * low nibble is the even element, the high nibble the odd one.
 */
function dequantizeQ4_0Block(data, offset) {
  const result = new Float32Array(32);
  const scale2 = float16ToFloat322(data[offset] | data[offset + 1] << 8);
  for (let i = 0; i < 16; i++) {
    const packed = data[offset + 2 + i];
    result[i * 2] = ((packed & 15) - 8) * scale2;
    result[i * 2 + 1] = ((packed >> 4 & 15) - 8) * scale2;
  }
  return result;
}
3604
/**
 * Expand one Q8_0 block (2-byte little-endian f16 scale + 32 signed
 * bytes) into 32 floats.
 */
function dequantizeQ8_0Block(data, offset) {
  const result = new Float32Array(32);
  const scale2 = float16ToFloat322(data[offset] | data[offset + 1] << 8);
  for (let i = 0; i < 32; i++) {
    // Sign-extend the unsigned byte to int8.
    const signed = data[offset + 2 + i] << 24 >> 24;
    result[i] = signed * scale2;
  }
  return result;
}
3615
/**
 * Decode one GGUF tensor's payload into a Float32Array. F32 returns a
 * zero-copy view over `buffer` (mutating it mutates the file buffer);
 * F16 / Q4_0 / Q8_0 are dequantized into fresh arrays. Other quant
 * types throw until dequantizers exist for them.
 *
 * NOTE(review): the typed-array views assume `info.offset` is suitably
 * aligned (4 bytes for F32, 2 for F16). The GGUF data-section alignment
 * (default 32) presumably guarantees this — confirm for tensors that
 * follow odd-sized quantized blocks.
 */
function loadGGUFTensor(buffer, info) {
  const type = info.dtype;
  const data = new Uint8Array(buffer, info.offset, info.byteSize);
  const numel = info.shape.reduce((a, b) => a * b, 1);
  switch (type) {
    case 0 /* F32 */: {
      // Zero-copy view straight into the file buffer.
      return new Float32Array(buffer, info.offset, numel);
    }
    case 1 /* F16 */: {
      const result = new Float32Array(numel);
      const u16 = new Uint16Array(buffer, info.offset, numel);
      for (let i = 0;i < numel; i++) {
        result[i] = float16ToFloat322(u16[i]);
      }
      return result;
    }
    case 2 /* Q4_0 */: {
      // 32 4-bit values + f16 scale per 18-byte block.
      const blockSize = 32;
      const bytesPerBlock = 18;
      const numBlocks = Math.ceil(numel / blockSize);
      const result = new Float32Array(numel);
      for (let b = 0;b < numBlocks; b++) {
        const blockData = dequantizeQ4_0Block(data, b * bytesPerBlock);
        const outOffset = b * blockSize;
        // The final block may be partial when numel % 32 !== 0.
        const copyLen = Math.min(blockSize, numel - outOffset);
        result.set(blockData.subarray(0, copyLen), outOffset);
      }
      return result;
    }
    case 8 /* Q8_0 */: {
      // 32 signed bytes + f16 scale per 34-byte block.
      const blockSize = 32;
      const bytesPerBlock = 34;
      const numBlocks = Math.ceil(numel / blockSize);
      const result = new Float32Array(numel);
      for (let b = 0;b < numBlocks; b++) {
        const blockData = dequantizeQ8_0Block(data, b * bytesPerBlock);
        const outOffset = b * blockSize;
        const copyLen = Math.min(blockSize, numel - outOffset);
        result.set(blockData.subarray(0, copyLen), outOffset);
      }
      return result;
    }
    default:
      throw new Error(`Quantization type ${GGUFQuantType[type]} not yet supported for dequantization`);
  }
}
3661
/**
 * Magic-number sniff: does this buffer look like a GGUF file?
 *
 * Requires at least the 24-byte fixed GGUF header and checks the
 * little-endian magic word at offset 0. Never throws.
 *
 * @param {ArrayBuffer} buffer - Candidate file contents.
 * @returns {boolean} True when the buffer starts with the GGUF magic.
 */
function isGGUF(buffer) {
  if (buffer.byteLength < 24) {
    return false;
  }
  try {
    return new DataView(buffer).getUint32(0, true) === GGUF_MAGIC;
  } catch {
    return false;
  }
}
3672
/**
 * Convert an IEEE 754 binary16 (half-precision) bit pattern to a JS number.
 *
 * Handles signed zero, subnormals, infinities and NaN.
 *
 * @param {number} h - 16-bit half-precision pattern (0..65535).
 * @returns {number} The decoded value.
 */
function float16ToFloat322(h) {
  const negative = (h & 32768) !== 0;
  const exp = h >> 10 & 31;
  const mant = h & 1023;
  if (exp === 31) {
    // All-ones exponent: infinity when the mantissa is zero, otherwise NaN.
    return mant !== 0 ? NaN : negative ? -Infinity : Infinity;
  }
  if (exp === 0) {
    // Zero exponent: signed zero or subnormal (mant * 2^-24).
    if (mant === 0) {
      return negative ? -0 : 0;
    }
    return (negative ? -1 : 1) * mant * Math.pow(2, -24);
  }
  // Normal number: (1 + mant/1024) * 2^(exp-15), folded into one product.
  return (negative ? -1 : 1) * (1024 + mant) * Math.pow(2, exp - 25);
}
3689
+ // src/model/index.ts
3690
/**
 * Load a model from a URL or an in-memory ArrayBuffer, auto-detecting the
 * container format (GGUF vs. SafeTensors) by its magic bytes.
 *
 * @param {string|ArrayBuffer} source - URL to fetch, or raw file contents.
 * @param {object} [options] - Forwarded to the format-specific loader.
 * @returns {Promise<object>} The parsed model.
 * @throws {Error} On fetch failure or an unrecognized container format.
 */
async function loadModel(source, options) {
  let buffer;
  if (typeof source !== "string") {
    buffer = source;
  } else {
    const response = await fetch(source);
    if (!response.ok) {
      throw new Error(`Failed to fetch model: ${response.statusText}`);
    }
    buffer = await response.arrayBuffer();
  }
  if (isGGUF(buffer)) {
    return loadGGUF(buffer, options);
  }
  if (isSafetensors(buffer)) {
    return loadSafetensors(buffer, options);
  }
  throw new Error("Unknown model format. Expected SafeTensors or GGUF.");
}
3709
+ // src/inference/types.ts
3710
// Baseline sampling settings; normalizeGenerationConfig() falls back to these
// (via ??) for any field the caller omits.
var DEFAULT_GENERATION_CONFIG = {
  maxTokens: 256,        // hard cap on the number of generated tokens
  temperature: 1,        // 1 = sample from raw logits; lower sharpens
  topK: 0,               // 0 disables top-k filtering
  topP: 1,               // 1 disables nucleus (top-p) filtering
  repetitionPenalty: 1,  // 1 disables the penalty
  eosTokenId: 2,         // NOTE(review): LLaMA-style token ids — confirm per model
  padTokenId: 0,
  bosTokenId: 1,
  stream: false
};
3721
/**
 * Fill a partial GenerationConfig with defaults and validate the ranges.
 *
 * Only null/undefined fields fall back to DEFAULT_GENERATION_CONFIG (nullish
 * coalescing), so explicit 0/false values are respected. `stopSequences` and
 * `seed` have no defaults and pass through as given.
 *
 * @param {object} config - Partial generation settings.
 * @returns {object} A fully-populated, validated config.
 * @throws {Error} When maxTokens < 1, temperature < 0, topK < 0,
 *   topP outside [0, 1], or repetitionPenalty < 1.
 */
function normalizeGenerationConfig(config) {
  const withDefault = (key) => config[key] ?? DEFAULT_GENERATION_CONFIG[key];
  const normalized = {
    maxTokens: withDefault("maxTokens"),
    temperature: withDefault("temperature"),
    topK: withDefault("topK"),
    topP: withDefault("topP"),
    repetitionPenalty: withDefault("repetitionPenalty"),
    eosTokenId: withDefault("eosTokenId"),
    padTokenId: withDefault("padTokenId"),
    bosTokenId: withDefault("bosTokenId"),
    stream: withDefault("stream"),
    stopSequences: config.stopSequences,
    seed: config.seed
  };
  if (normalized.maxTokens < 1) {
    throw new Error("maxTokens must be >= 1");
  }
  if (normalized.temperature !== undefined && normalized.temperature < 0) {
    throw new Error("temperature must be >= 0");
  }
  if (normalized.topK !== undefined && normalized.topK < 0) {
    throw new Error("topK must be >= 0");
  }
  if (normalized.topP !== undefined && (normalized.topP < 0 || normalized.topP > 1)) {
    throw new Error("topP must be between 0 and 1");
  }
  if (normalized.repetitionPenalty !== undefined && normalized.repetitionPenalty < 1) {
    throw new Error("repetitionPenalty must be >= 1");
  }
  return normalized;
}
3752
+ // src/inference/engine.ts
3753
// Engine-level defaults, merged with constructor overrides in InferenceEngine.
var DEFAULT_INFERENCE_CONFIG = {
  maxBatchSize: 1,       // batch-size hint; not consulted by the visible CPU path
  maxSeqLen: 2048,       // sizes both the RoPE tables and the KV cache
  useKVCache: true,      // cache per-layer K/V for incremental decoding
  memoryLimit: 0,        // 0 = no explicit limit
  enableProfiling: false
};
3760
+
3761
/**
 * CPU reference inference engine for decoder-only transformer models in the
 * LLaMA style: RMSNorm pre-norms, rotary position embeddings (RoPE), grouped
 * or multi-head attention with a causal mask, and an (optionally gated) FFN.
 *
 * Lifecycle: construct with a device + config, `await loadModel(...)`, then
 * call forward() directly or drive it via generate()/generateStream().
 *
 * NOTE(review): although a `device` is accepted and stored, every operation
 * below uses the *CPU* kernels (matmulCPU, rmsNormCPU, softmaxCPU, ...);
 * the device is not used in this class.
 */
class InferenceEngine {
  device;
  config;               // merged DEFAULT_INFERENCE_CONFIG + constructor overrides
  modelConfig = null;   // architecture hyperparameters; set by loadModel()
  weights = null;       // dequantized Float32Array weights; set by loadModel()
  loadedModel = null;   // parsed model container (tensorInfos + raw buffer)
  kvCache = null;       // { keys[], values[], seqLen } per layer, or null
  ropeFreqsCos = null;  // precomputed RoPE cos table (indexed pos * headDim/2 + i)
  ropeFreqsSin = null;  // precomputed RoPE sin table (same indexing)
  constructor(device, config) {
    this.device = device;
    this.config = { ...DEFAULT_INFERENCE_CONFIG, ...config };
  }
  /**
   * Load weights from a parsed model, precompute RoPE tables for
   * config.maxSeqLen positions, and (when enabled) allocate the KV cache.
   */
  async loadModel(model, modelConfig) {
    this.loadedModel = model;
    this.modelConfig = modelConfig;
    this.weights = await this.extractWeights(model, modelConfig);
    // headDim falls back to hiddenSize / numHeads when not set explicitly.
    const headDim = modelConfig.headDim ?? modelConfig.hiddenSize / modelConfig.numHeads;
    const ropeFreqBase = modelConfig.ropeFreqBase ?? 1e4;
    const { cos, sin } = computeRoPEFrequencies({
      dim: headDim,
      maxSeqLen: this.config.maxSeqLen,
      base: ropeFreqBase
    });
    this.ropeFreqsCos = cos;
    this.ropeFreqsSin = sin;
    if (this.config.useKVCache) {
      this.initKVCache(modelConfig);
    }
  }
  /**
   * Resolve and dequantize all weight tensors, accepting either LLaMA-style
   * ("model.layers.N...") or GPT-style ("transformer.h.N...") tensor names.
   * Throws if a required tensor is missing under every known alias.
   */
  async extractWeights(model, config) {
    // Load one tensor by exact name, dispatching on container format.
    const loadTensor = (name) => {
      const info = model.tensorInfos.get(name);
      if (!info) {
        throw new Error(`Tensor not found: ${name}`);
      }
      if (model.metadata.format === "safetensors") {
        return loadSafetensorsTensor(model.buffer, info);
      } else {
        return loadGGUFTensor(model.buffer, info);
      }
    };
    // First matching alias wins; throws when no alias is present.
    const tryLoad = (names) => {
      for (const name of names) {
        if (model.tensorInfos.has(name)) {
          return loadTensor(name);
        }
      }
      throw new Error(`None of these tensors found: ${names.join(", ")}`);
    };
    const embedTokens = tryLoad([
      "model.embed_tokens.weight",
      "transformer.wte.weight",
      "embedding.weight"
    ]);
    const layers = [];
    for (let i = 0;i < config.numLayers; i++) {
      const prefix = `model.layers.${i}`;
      const gptPrefix = `transformer.h.${i}`;
      const layerWeights = {
        attention: {
          qProj: tryLoad([`${prefix}.self_attn.q_proj.weight`, `${gptPrefix}.attn.q_proj.weight`]),
          kProj: tryLoad([`${prefix}.self_attn.k_proj.weight`, `${gptPrefix}.attn.k_proj.weight`]),
          vProj: tryLoad([`${prefix}.self_attn.v_proj.weight`, `${gptPrefix}.attn.v_proj.weight`]),
          oProj: tryLoad([`${prefix}.self_attn.o_proj.weight`, `${gptPrefix}.attn.o_proj.weight`])
        },
        ffn: {
          // gate is optional: present for gated (SwiGLU-style) FFNs only.
          gate: model.tensorInfos.has(`${prefix}.mlp.gate_proj.weight`) ? loadTensor(`${prefix}.mlp.gate_proj.weight`) : undefined,
          up: tryLoad([`${prefix}.mlp.up_proj.weight`, `${gptPrefix}.mlp.up_proj.weight`]),
          down: tryLoad([`${prefix}.mlp.down_proj.weight`, `${gptPrefix}.mlp.down_proj.weight`])
        },
        inputNorm: tryLoad([
          `${prefix}.input_layernorm.weight`,
          `${gptPrefix}.ln_1.weight`
        ]),
        postAttentionNorm: tryLoad([
          `${prefix}.post_attention_layernorm.weight`,
          `${gptPrefix}.ln_2.weight`
        ])
      };
      layers.push(layerWeights);
    }
    const finalNorm = tryLoad([
      "model.norm.weight",
      "transformer.ln_f.weight"
    ]);
    // NOTE(review): no fallback to embed_tokens for tied-embedding models —
    // such models will throw here. Verify whether tying should be supported.
    const lmHead = tryLoad([
      "lm_head.weight",
      "transformer.lm_head.weight"
    ]);
    return { embedTokens, layers, finalNorm, lmHead };
  }
  /**
   * Allocate per-layer K/V buffers sized for maxSeqLen positions
   * (maxSeqLen * numKVHeads * headDim floats per layer, per tensor).
   */
  initKVCache(config) {
    const headDim = config.headDim ?? config.hiddenSize / config.numHeads;
    const numKVHeads = config.numKVHeads ?? config.numHeads;
    const cacheSize = this.config.maxSeqLen * numKVHeads * headDim;
    this.kvCache = {
      keys: [],
      values: [],
      seqLen: 0
    };
    for (let i = 0;i < config.numLayers; i++) {
      this.kvCache.keys.push(new Float32Array(cacheSize));
      this.kvCache.values.push(new Float32Array(cacheSize));
    }
  }
  /** Clear cached K/V and rewind the cached sequence length to 0. */
  resetKVCache() {
    if (this.kvCache) {
      this.kvCache.seqLen = 0;
      // Zero-filling is not strictly required for correctness (positions past
      // seqLen are masked), but keeps stale data from leaking across runs.
      for (let i = 0;i < this.kvCache.keys.length; i++) {
        this.kvCache.keys[i].fill(0);
        this.kvCache.values[i].fill(0);
      }
    }
  }
  /**
   * Run one forward pass over `inputIds` (absolute positions start at
   * `startPos` — pass the running position for incremental decoding).
   * Returns logits for the LAST input token only: shape [1, vocabSize].
   * @throws {Error} If called before loadModel().
   */
  forward(inputIds, startPos = 0) {
    if (!this.weights || !this.modelConfig) {
      throw new Error("Model not loaded. Call loadModel() first.");
    }
    const config = this.modelConfig;
    const weights = this.weights;
    const seqLen = inputIds.length;
    const headDim = config.headDim ?? config.hiddenSize / config.numHeads;
    const numKVHeads = config.numKVHeads ?? config.numHeads;
    const eps = config.rmsNormEps ?? 0.00001;
    const inputIdsArray = Array.from(inputIds);
    // Token embedding lookup -> [seqLen, hiddenSize] (flattened).
    let hidden = embeddingCPU(weights.embedTokens, inputIdsArray, config.hiddenSize);
    // Pre-norm transformer blocks; residual additions happen inside the
    // attention/FFN helpers (they receive the un-normed `hidden` as residual).
    for (let layer = 0;layer < config.numLayers; layer++) {
      const lw = weights.layers[layer];
      const normedHidden = rmsNormCPU(hidden, lw.inputNorm, [seqLen, config.hiddenSize], eps);
      hidden = this.attentionForward(normedHidden, lw, layer, startPos, seqLen, headDim, numKVHeads, hidden);
      const normedHidden2 = rmsNormCPU(hidden, lw.postAttentionNorm, [seqLen, config.hiddenSize], eps);
      hidden = this.ffnForward(normedHidden2, lw, hidden);
    }
    hidden = rmsNormCPU(hidden, weights.finalNorm, [seqLen, config.hiddenSize], eps);
    // Only the final position is projected to vocabulary logits.
    const lastTokenHidden = hidden.slice((seqLen - 1) * config.hiddenSize, seqLen * config.hiddenSize);
    const logits = matmulCPU(lastTokenHidden, weights.lmHead, 1, config.vocabSize, config.hiddenSize);
    return {
      logits,
      logitsShape: [1, config.vocabSize]
    };
  }
  /**
   * One attention block: QKV projections, RoPE, KV-cache append, causal
   * scaled-dot-product attention (with grouped-query head sharing), output
   * projection, and residual add.
   * @returns hidden + attnOut, shape [seqLen, hiddenSize] (flattened).
   */
  attentionForward(x, lw, layerIdx, startPos, seqLen, headDim, numKVHeads, residual) {
    const config = this.modelConfig;
    const hiddenSize = config.hiddenSize;
    const numHeads = config.numHeads;
    let q = matmulCPU(x, lw.attention.qProj, seqLen, numHeads * headDim, hiddenSize);
    let k = matmulCPU(x, lw.attention.kProj, seqLen, numKVHeads * headDim, hiddenSize);
    let v = matmulCPU(x, lw.attention.vProj, seqLen, numKVHeads * headDim, hiddenSize);
    // Rotate Q and K in place; positions are absolute (startPos + pos).
    if (this.ropeFreqsCos && this.ropeFreqsSin) {
      for (let pos = 0;pos < seqLen; pos++) {
        const actualPos = startPos + pos;
        for (let h = 0;h < numHeads; h++) {
          const qOffset = pos * numHeads * headDim + h * headDim;
          this.applyRoPE(q, qOffset, actualPos, headDim);
        }
        for (let h = 0;h < numKVHeads; h++) {
          const kOffset = pos * numKVHeads * headDim + h * headDim;
          this.applyRoPE(k, kOffset, actualPos, headDim);
        }
      }
    }
    // Append this step's K/V to the cache, then attend over the full prefix.
    if (this.kvCache) {
      const kvSize = numKVHeads * headDim;
      for (let pos = 0;pos < seqLen; pos++) {
        const cachePos = (startPos + pos) * kvSize;
        this.kvCache.keys[layerIdx].set(k.subarray(pos * kvSize, (pos + 1) * kvSize), cachePos);
        this.kvCache.values[layerIdx].set(v.subarray(pos * kvSize, (pos + 1) * kvSize), cachePos);
      }
      this.kvCache.seqLen = startPos + seqLen;
      const totalLen = startPos + seqLen;
      k = this.kvCache.keys[layerIdx].slice(0, totalLen * kvSize);
      v = this.kvCache.values[layerIdx].slice(0, totalLen * kvSize);
    }
    const scale2 = 1 / Math.sqrt(headDim);
    // NOTE(review): with the KV cache disabled and startPos > 0, k/v only
    // hold the current chunk but the mask uses absolute positions — verify
    // that the no-cache path is only ever used with startPos === 0.
    const totalKVLen = this.kvCache ? this.kvCache.seqLen : seqLen;
    const attnOutput = new Float32Array(seqLen * numHeads * headDim);
    for (let pos = 0;pos < seqLen; pos++) {
      for (let h = 0;h < numHeads; h++) {
        // Grouped-query attention: query head h reads KV head
        // floor(h * numKVHeads / numHeads).
        const kvHead = Math.floor(h * numKVHeads / numHeads);
        const scores = new Float32Array(totalKVLen);
        for (let kPos = 0;kPos < totalKVLen; kPos++) {
          // Causal mask: no attention to positions after the query token.
          if (kPos > startPos + pos) {
            scores[kPos] = -Infinity;
            continue;
          }
          let score = 0;
          for (let d = 0;d < headDim; d++) {
            const qIdx = pos * numHeads * headDim + h * headDim + d;
            const kIdx = kPos * numKVHeads * headDim + kvHead * headDim + d;
            score += q[qIdx] * k[kIdx];
          }
          scores[kPos] = score * scale2;
        }
        const probs = softmaxCPU(scores, [totalKVLen]);
        // Weighted sum of values for this head/position.
        for (let d = 0;d < headDim; d++) {
          let val = 0;
          for (let vPos = 0;vPos < totalKVLen; vPos++) {
            const vIdx = vPos * numKVHeads * headDim + kvHead * headDim + d;
            val += probs[vPos] * v[vIdx];
          }
          const outIdx = pos * numHeads * headDim + h * headDim + d;
          attnOutput[outIdx] = val;
        }
      }
    }
    const projected = matmulCPU(attnOutput, lw.attention.oProj, seqLen, hiddenSize, numHeads * headDim);
    return addCPU(residual, projected);
  }
  /**
   * Apply rotary position embedding in place to one head's vector at
   * x[offset .. offset+headDim). Uses the "rotate halves" pairing:
   * element i pairs with element i + headDim/2.
   */
  applyRoPE(x, offset, position, headDim) {
    for (let i = 0;i < headDim / 2; i++) {
      // Table layout assumed: [position][headDim/2] row-major —
      // must match computeRoPEFrequencies' output ordering.
      const freqIdx = position * (headDim / 2) + i;
      const cos = this.ropeFreqsCos[freqIdx];
      const sin = this.ropeFreqsSin[freqIdx];
      const x0 = x[offset + i];
      const x1 = x[offset + headDim / 2 + i];
      x[offset + i] = x0 * cos - x1 * sin;
      x[offset + headDim / 2 + i] = x0 * sin + x1 * cos;
    }
  }
  /**
   * Feed-forward block with residual add. Gated path when `gate` weights
   * exist, otherwise plain SiLU MLP.
   * NOTE(review): the gated path computes gate_proj(x) * silu(up_proj(x));
   * standard SwiGLU (LLaMA) applies SiLU to the *gate* branch instead —
   * confirm against the reference model before relying on parity.
   */
  ffnForward(x, lw, residual) {
    const config = this.modelConfig;
    const seqLen = x.length / config.hiddenSize;
    const up = matmulCPU(x, lw.ffn.up, seqLen, config.intermediateSize, config.hiddenSize);
    let gateOut;
    if (lw.ffn.gate) {
      gateOut = matmulCPU(x, lw.ffn.gate, seqLen, config.intermediateSize, config.hiddenSize);
      const upSilu = siluCPU(up);
      gateOut = mulCPU(gateOut, upSilu);
    } else {
      gateOut = siluCPU(up);
    }
    const down = matmulCPU(gateOut, lw.ffn.down, seqLen, config.hiddenSize, config.intermediateSize);
    return addCPU(residual, down);
  }
  /** @returns The model config passed to loadModel(), or null. */
  getModelConfig() {
    return this.modelConfig;
  }
  /** @returns True once loadModel() has populated the weights. */
  isLoaded() {
    return this.weights !== null;
  }
  /** Drop all references to weights/caches so they can be garbage-collected. */
  dispose() {
    this.weights = null;
    this.loadedModel = null;
    this.kvCache = null;
    this.ropeFreqsCos = null;
    this.ropeFreqsSin = null;
  }
}
4010
+ // src/inference/generate.ts
4011
/**
 * Sample the next token id from raw logits, delegating to sampleCPU with the
 * generation config's sampling knobs.
 *
 * @param {Float32Array|number[]} logits - Vocabulary logits.
 * @param {object} config - Normalized generation config.
 * @param {number[]} [generatedTokens] - History for the repetition penalty.
 * @returns {number} The sampled token id.
 */
function sampleNextToken(logits, config, generatedTokens) {
  const samplingOptions = {
    temperature: config.temperature,
    topK: config.topK,
    topP: config.topP,
    repetitionPenalty: config.repetitionPenalty
  };
  return sampleCPU(logits, samplingOptions, generatedTokens || []);
}
4019
/**
 * Check whether the tail of `generatedTokens` exactly matches any of the
 * provided stop sequences.
 *
 * @param {number[]} generatedTokens - Tokens generated so far.
 * @param {number[][]} [stopSequences] - Candidate terminating sequences.
 * @returns {boolean} True when generation should stop.
 */
function checkStopSequences(generatedTokens, stopSequences) {
  if (!stopSequences?.length) {
    return false;
  }
  return stopSequences.some((seq) => {
    if (generatedTokens.length < seq.length) {
      return false;
    }
    const tail = generatedTokens.slice(-seq.length);
    return tail.every((token, i) => token === seq[i]);
  });
}
4033
/**
 * Autoregressively generate tokens from a prompt.
 *
 * Resets the engine's KV cache, runs one prefill pass over the full prompt,
 * then decodes a single token per step until EOS, a stop sequence, or
 * maxTokens is reached. The terminating EOS/stop token is included in the
 * returned token list.
 *
 * @param {InferenceEngine} engine - A loaded engine.
 * @param {number[]|Uint32Array} promptTokens - Prompt token ids.
 * @param {object} [config] - Partial GenerationConfig.
 * @returns {Promise<object>} { tokens, finishReason ("eos"|"stop"|"length"),
 *   promptTokens, generatedTokens, totalTimeMs, tokensPerSecond }.
 */
async function generate(engine, promptTokens, config = {}) {
  const cfg = normalizeGenerationConfig(config);
  const t0 = performance.now();
  engine.resetKVCache();
  const prompt = promptTokens instanceof Uint32Array ? promptTokens : new Uint32Array(promptTokens);
  const tokens = [];
  let finishReason = "length";
  let position = prompt.length;
  // Prefill over the whole prompt, then step one token at a time.
  let step = engine.forward(prompt, 0);
  for (let produced = 0; produced < cfg.maxTokens; produced++) {
    const token = sampleNextToken(step.logits, cfg, tokens);
    tokens.push(token);
    if (token === cfg.eosTokenId) {
      finishReason = "eos";
      break;
    }
    if (checkStopSequences(tokens, cfg.stopSequences)) {
      finishReason = "stop";
      break;
    }
    step = engine.forward(new Uint32Array([token]), position);
    position += 1;
  }
  const elapsedMs = performance.now() - t0;
  return {
    tokens,
    finishReason,
    promptTokens: prompt.length,
    generatedTokens: tokens.length,
    totalTimeMs: elapsedMs,
    tokensPerSecond: tokens.length / elapsedMs * 1000
  };
}
4068
/**
 * Streaming variant of generate(): yields one chunk per decoded token.
 *
 * Each chunk is { tokenId, index, isLast, finishReason }; finishReason is
 * undefined except on the final chunk ("eos" | "stop" | "length"). Between
 * decode steps the generator awaits setTimeout(0) so the event loop (e.g. a
 * UI thread) can run.
 *
 * @param {InferenceEngine} engine - A loaded engine.
 * @param {number[]|Uint32Array} promptTokens - Prompt token ids.
 * @param {object} [config] - Partial GenerationConfig.
 */
async function* generateStream(engine, promptTokens, config = {}) {
  const cfg = normalizeGenerationConfig(config);
  engine.resetKVCache();
  const prompt = promptTokens instanceof Uint32Array ? promptTokens : new Uint32Array(promptTokens);
  const tokens = [];
  let position = prompt.length;
  let step = engine.forward(prompt, 0);
  for (let index = 0; index < cfg.maxTokens; index++) {
    const tokenId = sampleNextToken(step.logits, cfg, tokens);
    tokens.push(tokenId);
    // Decide whether this chunk terminates the stream, and why.
    let finishReason;
    if (tokenId === cfg.eosTokenId) {
      finishReason = "eos";
    } else if (checkStopSequences(tokens, cfg.stopSequences)) {
      finishReason = "stop";
    } else if (index === cfg.maxTokens - 1) {
      finishReason = "length";
    }
    const isLast = finishReason !== undefined;
    yield {
      tokenId,
      index,
      isLast,
      finishReason
    };
    if (isLast) {
      break;
    }
    step = engine.forward(new Uint32Array([tokenId]), position);
    position += 1;
    // Cooperative yield between tokens.
    await new Promise((resolve) => setTimeout(resolve, 0));
  }
}
4105
/**
 * Deterministic argmax decoding: repeatedly pick the highest-logit token
 * until `eosTokenId` is produced or `maxTokens` tokens have been generated.
 * Ties keep the lowest token id; the EOS token, when hit, is included in
 * the returned list.
 *
 * @param {InferenceEngine} engine - A loaded engine.
 * @param {number[]|Uint32Array} promptTokens - Prompt token ids.
 * @param {number} maxTokens - Maximum tokens to generate.
 * @param {number} [eosTokenId=2] - Token id that terminates decoding.
 * @returns {number[]} The generated token ids.
 */
function greedyDecode(engine, promptTokens, maxTokens, eosTokenId = 2) {
  engine.resetKVCache();
  const prompt = promptTokens instanceof Uint32Array ? promptTokens : new Uint32Array(promptTokens);
  const tokens = [];
  let position = prompt.length;
  let step = engine.forward(prompt, 0);
  for (let produced = 0; produced < maxTokens; produced++) {
    // Argmax over the logits (strict >, so earlier index wins ties).
    const { logits } = step;
    let best = 0;
    for (let j = 1; j < logits.length; j++) {
      if (logits[j] > logits[best]) {
        best = j;
      }
    }
    tokens.push(best);
    if (best === eosTokenId) {
      break;
    }
    step = engine.forward(new Uint32Array([best]), position);
    position += 1;
  }
  return tokens;
}
4130
+ export {
4131
+ transposeCPU,
4132
+ transpose2DCPU,
4133
+ transpose2D,
4134
+ topPFilterCPU,
4135
+ topPFilter,
4136
+ topKFilter,
4137
+ topKCPU,
4138
+ topK,
4139
+ softmaxGPU,
4140
+ softmaxCPU,
4141
+ softmax,
4142
+ siluCPU,
4143
+ silu,
4144
+ sigmoidCPU,
4145
+ scaleCPU,
4146
+ scale,
4147
+ sampleNextToken,
4148
+ sampleGreedy,
4149
+ sampleFromProbs,
4150
+ sampleCPU,
4151
+ sample,
4152
+ ropeCPU,
4153
+ rope,
4154
+ rmsNormCPU,
4155
+ rmsNorm,
4156
+ reshapeCPU,
4157
+ reluCPU,
4158
+ relu,
4159
+ quantizeToInt8,
4160
+ quantizeToInt4,
4161
+ quantizationError,
4162
+ qmatmulInt8CPU,
4163
+ qmatmulInt8BlockCPU,
4164
+ qmatmulInt4CPU,
4165
+ permuteCPU,
4166
+ parseSafetensorsHeader,
4167
+ parseGGUFHeader,
4168
+ normalizeGenerationConfig,
4169
+ mulCPU,
4170
+ mul,
4171
+ matmulCPU,
4172
+ matmul,
4173
+ logSoftmaxCPU,
4174
+ loadSafetensorsFromUrl,
4175
+ loadSafetensors,
4176
+ loadModel,
4177
+ loadGGUFTensor,
4178
+ loadGGUFFromUrl,
4179
+ loadGGUF,
4180
+ layerNormCPU,
4181
+ layerNorm,
4182
+ isSafetensors,
4183
+ isGGUF,
4184
+ greedyDecode,
4185
+ getSparsityRatio,
4186
+ getSlidingWindowSparsity,
4187
+ getMemorySavings,
4188
+ getMatMulCacheStats,
4189
+ getCausalSparsity,
4190
+ generateStream,
4191
+ generate,
4192
+ geluExactCPU,
4193
+ geluCPU,
4194
+ gelu,
4195
+ fmaCPU,
4196
+ flashAttention,
4197
+ estimateQMatMulFlops,
4198
+ estimateQMatMulBandwidth,
4199
+ estimateMemorySavings,
4200
+ embeddingCPU,
4201
+ embedding,
4202
+ dequantizeInt8,
4203
+ dequantizeInt4,
4204
+ computeRoPEFrequencies,
4205
+ buildSlidingWindowMask,
4206
+ buildCausalSlidingWindowMask,
4207
+ buildCausalMask,
4208
+ buildBlockSparseCSR,
4209
+ batchedEmbeddingCPU,
4210
+ attentionCPU,
4211
+ applyRepetitionPenalty,
4212
+ addScalarCPU,
4213
+ addCPU,
4214
+ add,
4215
+ WebInferDevice,
4216
+ WGSLCompiler,
4217
+ Tensor,
4218
+ PagedKVCache,
4219
+ KernelCache,
4220
+ InferenceEngine,
4221
+ GGUFQuantType,
4222
+ GGUFMetadataValueType,
4223
+ DEFAULT_GENERATION_CONFIG,
4224
+ ContinuousBatchScheduler,
4225
+ BufferPool,
4226
+ BlockManager,
4227
+ AttentionScheduler
4228
+ };