webinfer 0.0.3 → 0.0.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +65 -21
- package/dist/activation/index.d.ts +30 -0
- package/dist/core/context.d.ts +70 -0
- package/dist/core/paged-kv-cache.d.ts +33 -0
- package/dist/core/tensor.d.ts +51 -19
- package/dist/core/types.d.ts +27 -0
- package/dist/decode/index.d.ts +140 -0
- package/dist/gemm/index.d.ts +27 -0
- package/dist/index.d.ts +29 -21
- package/dist/index.js +3433 -4809
- package/dist/jit/index.d.ts +138 -0
- package/dist/kernels/activation.wgsl.d.ts +14 -0
- package/dist/kernels/batch-decode-paged.wgsl.d.ts +12 -0
- package/dist/kernels/batch-prefill-paged.wgsl.d.ts +13 -0
- package/dist/kernels/decode-attention.wgsl.d.ts +16 -0
- package/dist/kernels/gemm.wgsl.d.ts +17 -0
- package/dist/kernels/page.wgsl.d.ts +10 -0
- package/dist/kernels/prefill-attention.wgsl.d.ts +17 -0
- package/dist/kernels/rmsnorm.wgsl.d.ts +10 -0
- package/dist/kernels/rope.wgsl.d.ts +19 -0
- package/dist/kernels/sampling.wgsl.d.ts +23 -0
- package/dist/norm/index.d.ts +43 -0
- package/dist/page/index.d.ts +21 -0
- package/dist/prefill/index.d.ts +155 -0
- package/dist/rope/index.d.ts +37 -0
- package/dist/sampling/index.d.ts +53 -4
- package/package.json +1 -1
- package/dist/attention/block-sparse/format.d.ts +0 -52
- package/dist/attention/block-sparse/patterns/causal.d.ts +0 -16
- package/dist/attention/block-sparse/patterns/sliding.d.ts +0 -22
- package/dist/attention/block-sparse/patterns/tree.d.ts +0 -65
- package/dist/attention/cascaded-inference.d.ts +0 -29
- package/dist/attention/flash-attention.d.ts +0 -30
- package/dist/attention/index.d.ts +0 -118
- package/dist/attention/paged-attention.d.ts +0 -40
- package/dist/attention/paged-kv/block-manager.d.ts +0 -102
- package/dist/attention/paged-kv/index.d.ts +0 -5
- package/dist/attention/paged-kv/page-table.d.ts +0 -165
- package/dist/attention/scheduler.d.ts +0 -40
- package/dist/core/buffer-pool.d.ts +0 -18
- package/dist/core/device.d.ts +0 -23
- package/dist/core/tdr.d.ts +0 -114
- package/dist/inference/engine.d.ts +0 -69
- package/dist/inference/generate.d.ts +0 -30
- package/dist/inference/index.d.ts +0 -7
- package/dist/inference/types.d.ts +0 -161
- package/dist/jit/compiler.d.ts +0 -23
- package/dist/jit/kernel-cache.d.ts +0 -21
- package/dist/model/gguf.d.ts +0 -90
- package/dist/model/index.d.ts +0 -16
- package/dist/model/safetensors.d.ts +0 -38
- package/dist/model/types.d.ts +0 -182
- package/dist/ops/activations.d.ts +0 -43
- package/dist/ops/elementwise.d.ts +0 -38
- package/dist/ops/embedding.d.ts +0 -30
- package/dist/ops/matmul.d.ts +0 -21
- package/dist/ops/normalization.d.ts +0 -63
- package/dist/ops/reshape.d.ts +0 -39
- package/dist/ops/rope.d.ts +0 -32
- package/dist/ops/softmax.d.ts +0 -18
- package/dist/quantization/index.d.ts +0 -6
- package/dist/quantization/qmatmul.d.ts +0 -38
- package/dist/quantization/quantize.d.ts +0 -52
- package/dist/sampling/beam-search.d.ts +0 -87
- package/dist/sampling/sampler.d.ts +0 -72
- package/dist/sampling/speculative.d.ts +0 -65
- package/dist/sampling/top-k.d.ts +0 -24
- package/dist/sampling/top-p.d.ts +0 -14
- package/dist/tvm/adapter.d.ts +0 -81
- package/dist/tvm/index.d.ts +0 -8
- package/dist/tvm/ops.d.ts +0 -26
- package/dist/tvm/types.d.ts +0 -35
package/README.md
CHANGED
|
@@ -1,44 +1,88 @@
|
|
|
1
1
|
# WebInfer
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
[](https://www.npmjs.com/package/webinfer)
|
|
4
|
+
[](LICENSE)
|
|
4
5
|
|
|
5
|
-
|
|
6
|
+
High-performance LLM inference kernels for WebGPU. A browser-native implementation of [FlashInfer](https://github.com/flashinfer-ai/flashinfer) APIs.
|
|
7
|
+
|
|
8
|
+
## Why WebInfer?
|
|
9
|
+
|
|
10
|
+
Running LLMs in the browser requires efficient GPU kernels. WebInfer brings FlashInfer's battle-tested attention mechanisms to WebGPU:
|
|
11
|
+
|
|
12
|
+
- **Flash Attention**: Online softmax that avoids materializing the full attention matrix, so extra memory does not grow with sequence length
|
|
13
|
+
- **Subgroup Optimizations**: 2-4x faster reductions on supported hardware
|
|
14
|
+
- **JIT Compilation**: Runtime kernel generation with pipeline caching
|
|
15
|
+
- **Dynamic Tile Sizes**: Auto-tuned tile dimensions based on workload and hardware
|
|
16
|
+
|
|
17
|
+
## Installation
|
|
6
18
|
|
|
7
19
|
```bash
|
|
8
20
|
npm install webinfer
|
|
9
21
|
```
|
|
10
22
|
|
|
11
|
-
##
|
|
23
|
+
## Quick Start
|
|
12
24
|
|
|
13
25
|
```typescript
|
|
14
|
-
import
|
|
26
|
+
import * as webinfer from 'webinfer';
|
|
15
27
|
|
|
16
|
-
|
|
28
|
+
// Initialize context
|
|
29
|
+
const ctx = await webinfer.WebInferContext.create();
|
|
17
30
|
|
|
18
|
-
//
|
|
19
|
-
const
|
|
20
|
-
const k = new Float32Array(2048 * 32 * 128); // [kv_len, num_kv_heads, head_dim]
|
|
21
|
-
const v = new Float32Array(2048 * 32 * 128);
|
|
31
|
+
// Run RMSNorm
|
|
32
|
+
const output = await webinfer.norm.rmsnorm(ctx, input, weight, epsilon);
|
|
22
33
|
|
|
23
|
-
|
|
34
|
+
// Flash attention decode
|
|
35
|
+
const attnOutput = await webinfer.decode.single_decode_with_kv_cache(
|

36
|
+
ctx, query, // query: [num_qo_heads, head_dim]
|

37
|
+
k_cache, v_cache, // each: [seq_len, num_kv_heads, head_dim]
|

38
|
+
);
|
|
24
39
|
```
|
|
25
40
|
|
|
26
|
-
##
|
|
41
|
+
## Browser Compatibility
|
|
27
42
|
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
|
31
|
-
|
|
32
|
-
|
|
|
33
|
-
|
|
|
34
|
-
|
|
|
35
|
-
|
|
|
43
|
+
WebInfer requires WebGPU support:
|
|
44
|
+
|
|
45
|
+
| Browser | Status |
|
|
46
|
+
|---------|--------|
|
|
47
|
+
| Chrome 113+ | Supported |
|
|
48
|
+
| Edge 113+ | Supported |
|
|
49
|
+
| Firefox Nightly | Behind flag |
|
|
50
|
+
| Safari 18+ | Supported |
|
|
51
|
+
|
|
52
|
+
Check for WebGPU support:
|
|
53
|
+
|
|
54
|
+
```typescript
|
|
55
|
+
if (!navigator.gpu) {
|
|
56
|
+
console.error('WebGPU not supported');
|
|
57
|
+
}
|
|
58
|
+
```
|
|
36
59
|
|
|
37
|
-
##
|
|
60
|
+
## Development
|
|
38
61
|
|
|
39
62
|
```bash
|
|
63
|
+
# Run tests
|
|
64
|
+
bun test
|
|
65
|
+
|
|
66
|
+
# Build
|
|
40
67
|
bun run build
|
|
41
|
-
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
## Acknowledgments
|
|
71
|
+
|
|
72
|
+
- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) - The original CUDA implementation
|
|
73
|
+
- [WebLLM](https://github.com/mlc-ai/web-llm) - Browser LLM runtime
|
|
74
|
+
|
|
75
|
+
## Citation
|
|
76
|
+
|
|
77
|
+
If you use WebInfer in your research, please cite:
|
|
78
|
+
|
|
79
|
+
```bibtex
|
|
80
|
+
@software{webinfer2026,
|
|
81
|
+
author = {Chiu, Guan-Ming},
|
|
82
|
+
title = {WebInfer: High-Performance LLM Inference Kernels for WebGPU},
|
|
83
|
+
year = {2026},
|
|
84
|
+
url = {https://github.com/guan404ming/webinfer}
|
|
85
|
+
}
|
|
42
86
|
```
|
|
43
87
|
|
|
44
88
|
## License
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Activation functions
|
|
3
|
+
*/
|
|
4
|
+
import type { WebInferContext } from '../core/context.ts';
|
|
5
|
+
import type { Tensor } from '../core/tensor.ts';
|
|
6
|
+
/**
|
|
7
|
+
* SiLU (Swish) activation with gating
|
|
8
|
+
*
|
|
9
|
+
* Input tensor has shape [..., 2 * hidden_size]
|
|
10
|
+
* Splits into gate[..., :hidden_size] and up[..., hidden_size:]
|
|
11
|
+
* Computes: output = silu(gate) * up
|
|
12
|
+
* Where silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
|
|
13
|
+
*
|
|
14
|
+
* @param ctx WebInfer context
|
|
15
|
+
* @param input Input tensor of shape [..., 2 * hidden_size]
|
|
16
|
+
* @returns Activated tensor of shape [..., hidden_size]
|
|
17
|
+
*/
|
|
18
|
+
export declare function silu_and_mul(ctx: WebInferContext, input: Tensor): Promise<Tensor>;
|
|
19
|
+
/**
|
|
20
|
+
* GELU activation with gating
|
|
21
|
+
*
|
|
22
|
+
* Input tensor has shape [..., 2 * hidden_size]
|
|
23
|
+
* Computes: output = gelu(gate) * up
|
|
24
|
+
* Where gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
|
|
25
|
+
*
|
|
26
|
+
* @param ctx WebInfer context
|
|
27
|
+
* @param input Input tensor of shape [..., 2 * hidden_size]
|
|
28
|
+
* @returns Activated tensor of shape [..., hidden_size]
|
|
29
|
+
*/
|
|
30
|
+
export declare function gelu_and_mul(ctx: WebInferContext, input: Tensor): Promise<Tensor>;
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* WebInferContext manages WebGPU device and provides core functionality
|
|
3
|
+
*/
|
|
4
|
+
import type { DType } from './types.ts';
|
|
5
|
+
import { Tensor } from './tensor.ts';
|
|
6
|
+
/**
|
|
7
|
+
* Device vendor detection
|
|
8
|
+
*/
|
|
9
|
+
export type Vendor = 'apple' | 'nvidia' | 'amd' | 'intel' | 'unknown';
|
|
10
|
+
/**
|
|
11
|
+
* Device features detected at runtime
|
|
12
|
+
*/
|
|
13
|
+
export interface DeviceFeatures {
|
|
14
|
+
readonly vendor: Vendor;
|
|
15
|
+
readonly hasSubgroups: boolean;
|
|
16
|
+
readonly hasF16: boolean;
|
|
17
|
+
readonly maxWorkgroupSize: number;
|
|
18
|
+
readonly maxStorageBufferSize: number;
|
|
19
|
+
readonly maxBufferSize: number;
|
|
20
|
+
readonly maxWorkgroupsPerDimension: number;
|
|
21
|
+
}
|
|
22
|
+
/**
|
|
23
|
+
* WebInferContext - main context for all WebInfer operations
|
|
24
|
+
*/
|
|
25
|
+
export declare class WebInferContext {
|
|
26
|
+
readonly device: GPUDevice;
|
|
27
|
+
readonly features: DeviceFeatures;
|
|
28
|
+
private pipelineCache;
|
|
29
|
+
private disposed;
|
|
30
|
+
private constructor();
|
|
31
|
+
/**
|
|
32
|
+
* Create a WebInferContext from an existing GPUDevice.
|
|
33
|
+
* Use this when integrating with TVM or other WebGPU frameworks
|
|
34
|
+
* that already have a GPUDevice to enable zero-copy buffer sharing.
|
|
35
|
+
*
|
|
36
|
+
* @param device - The existing GPUDevice to use
|
|
37
|
+
* @param featureOverrides - Optional overrides for detected features
|
|
38
|
+
* @returns A new WebInferContext using the provided device
|
|
39
|
+
*/
|
|
40
|
+
static fromDevice(device: GPUDevice, featureOverrides?: Partial<DeviceFeatures>): WebInferContext;
|
|
41
|
+
/**
|
|
42
|
+
* Create a new WebInferContext with its own GPUDevice
|
|
43
|
+
*/
|
|
44
|
+
static create(): Promise<WebInferContext>;
|
|
45
|
+
/**
|
|
46
|
+
* Create a GPU buffer
|
|
47
|
+
*/
|
|
48
|
+
createBuffer(size: number, usage?: GPUBufferUsageFlags): GPUBuffer;
|
|
49
|
+
/**
|
|
50
|
+
* Create a tensor with uninitialized data
|
|
51
|
+
*/
|
|
52
|
+
createTensor(shape: number[], dtype: DType): Tensor;
|
|
53
|
+
/**
|
|
54
|
+
* Create a tensor from TypedArray data
|
|
55
|
+
*/
|
|
56
|
+
createTensorFromData(data: Float32Array | Uint16Array, shape: number[], dtype: DType): Tensor;
|
|
57
|
+
/**
|
|
58
|
+
* Get or create a cached compute pipeline
|
|
59
|
+
*/
|
|
60
|
+
getOrCreatePipeline(key: string, createFn: () => Promise<GPUComputePipeline> | GPUComputePipeline): Promise<GPUComputePipeline>;
|
|
61
|
+
/**
|
|
62
|
+
* Clear pipeline cache
|
|
63
|
+
*/
|
|
64
|
+
clearPipelineCache(): void;
|
|
65
|
+
/**
|
|
66
|
+
* Dispose the context and cleanup resources
|
|
67
|
+
*/
|
|
68
|
+
dispose(): void;
|
|
69
|
+
private checkDisposed;
|
|
70
|
+
}
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Paged KV Cache implementation
|
|
3
|
+
*/
|
|
4
|
+
import type { Tensor } from './tensor.ts';
|
|
5
|
+
/**
|
|
6
|
+
* Paged KV Cache for efficient memory management
|
|
7
|
+
*
|
|
8
|
+
* Layout: [num_pages, 2, page_size, num_kv_heads, head_dim]
|
|
9
|
+
* where dim 1 has [0] = keys, [1] = values
|
|
10
|
+
*/
|
|
11
|
+
export interface PagedKvCache {
|
|
12
|
+
/**
|
|
13
|
+
* The backing tensor for the paged KV cache
|
|
14
|
+
* Shape: [num_pages, 2, page_size, num_kv_heads, head_dim]
|
|
15
|
+
*/
|
|
16
|
+
data: Tensor;
|
|
17
|
+
/**
|
|
18
|
+
* Number of tokens per page
|
|
19
|
+
*/
|
|
20
|
+
page_size: number;
|
|
21
|
+
/**
|
|
22
|
+
* Number of KV heads (for GQA/MQA)
|
|
23
|
+
*/
|
|
24
|
+
num_kv_heads: number;
|
|
25
|
+
/**
|
|
26
|
+
* Dimension of each head
|
|
27
|
+
*/
|
|
28
|
+
head_dim: number;
|
|
29
|
+
}
|
|
30
|
+
/**
|
|
31
|
+
* Create a paged KV cache
|
|
32
|
+
*/
|
|
33
|
+
export declare function createPagedKvCache(data: Tensor, page_size: number, num_kv_heads: number, head_dim: number): PagedKvCache;
|
package/dist/core/tensor.d.ts
CHANGED
|
@@ -1,25 +1,57 @@
|
|
|
1
1
|
/**
|
|
2
|
-
* Tensor
|
|
2
|
+
* Tensor wrapper for WebGPU buffers
|
|
3
|
+
*/
|
|
4
|
+
import type { DType } from './types.ts';
|
|
5
|
+
/**
|
|
6
|
+
* Compute strides from shape (row-major order)
|
|
7
|
+
*/
|
|
8
|
+
export declare function computeStrides(shape: number[]): number[];
|
|
9
|
+
/**
|
|
10
|
+
* Compute total number of elements from shape
|
|
11
|
+
*/
|
|
12
|
+
export declare function computeSize(shape: number[]): number;
|
|
13
|
+
/**
|
|
14
|
+
* Tensor class wrapping a GPUBuffer with shape and dtype information
|
|
3
15
|
*/
|
|
4
|
-
import type { WebInferDevice } from "./device.ts";
|
|
5
|
-
export type DType = "f32" | "f16" | "i32" | "u32";
|
|
6
16
|
export declare class Tensor {
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
17
|
+
readonly buffer: GPUBuffer;
|
|
18
|
+
readonly shape: readonly number[];
|
|
19
|
+
readonly dtype: DType;
|
|
20
|
+
readonly strides: readonly number[];
|
|
21
|
+
readonly size: number;
|
|
22
|
+
readonly byteSize: number;
|
|
23
|
+
constructor(buffer: GPUBuffer, shape: number[], dtype: DType, strides?: number[]);
|
|
24
|
+
/**
|
|
25
|
+
* Create a Tensor from an existing GPUBuffer (zero-copy).
|
|
26
|
+
*
|
|
27
|
+
* This is useful for TVM integration where the buffer is already allocated
|
|
28
|
+
* by TVM's WebGPU runtime and we want to wrap it without copying data.
|
|
29
|
+
*
|
|
30
|
+
* @param buffer - The existing GPUBuffer to wrap
|
|
31
|
+
* @param shape - The shape of the tensor
|
|
32
|
+
* @param dtype - The data type of the tensor
|
|
33
|
+
* @param strides - Optional custom strides (default: row-major)
|
|
34
|
+
* @returns A new Tensor wrapping the provided buffer
|
|
35
|
+
*/
|
|
36
|
+
static fromBuffer(buffer: GPUBuffer, shape: number[], dtype: DType, strides?: number[]): Tensor;
|
|
37
|
+
/**
|
|
38
|
+
* Get the number of dimensions
|
|
39
|
+
*/
|
|
40
|
+
get ndim(): number;
|
|
41
|
+
/**
|
|
42
|
+
* Check if tensor is contiguous (row-major)
|
|
43
|
+
*/
|
|
44
|
+
get isContiguous(): boolean;
|
|
45
|
+
/**
|
|
46
|
+
* Get a view of the tensor with new shape (must have same size)
|
|
47
|
+
*/
|
|
23
48
|
reshape(newShape: number[]): Tensor;
|
|
49
|
+
/**
|
|
50
|
+
* Destroy the underlying GPU buffer
|
|
51
|
+
*/
|
|
24
52
|
dispose(): void;
|
|
53
|
+
/**
|
|
54
|
+
* Get string representation
|
|
55
|
+
*/
|
|
56
|
+
toString(): string;
|
|
25
57
|
}
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Core type definitions for WebInfer
|
|
3
|
+
*/
|
|
4
|
+
/**
|
|
5
|
+
* Data type for tensors
|
|
6
|
+
*/
|
|
7
|
+
export type DType = 'float32' | 'float16';
|
|
8
|
+
/**
|
|
9
|
+
* Position encoding modes for attention
|
|
10
|
+
*/
|
|
11
|
+
export declare enum PosEncodingMode {
|
|
12
|
+
NONE = 0,
|
|
13
|
+
ROPE_LLAMA = 1,
|
|
14
|
+
ALIBI = 2
|
|
15
|
+
}
|
|
16
|
+
/**
|
|
17
|
+
* Get the byte size for a given dtype
|
|
18
|
+
*/
|
|
19
|
+
export declare function dtypeByteSize(dtype: DType): number;
|
|
20
|
+
/**
|
|
21
|
+
* Get the GPUTextureFormat for a given dtype
|
|
22
|
+
*/
|
|
23
|
+
export declare function dtypeToGPUFormat(dtype: DType): string;
|
|
24
|
+
/**
|
|
25
|
+
* Get WGSL type string for a given dtype
|
|
26
|
+
*/
|
|
27
|
+
export declare function dtypeToWGSL(dtype: DType): 'f32' | 'f16';
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Decode attention operations
|
|
3
|
+
*/
|
|
4
|
+
import type { WebInferContext } from '../core/context.ts';
|
|
5
|
+
import type { Tensor } from '../core/tensor.ts';
|
|
6
|
+
import type { PagedKvCache } from '../core/paged-kv-cache.ts';
|
|
7
|
+
import { PosEncodingMode } from '../core/types.ts';
|
|
8
|
+
/**
|
|
9
|
+
* Options for batch_decode_plan()
|
|
10
|
+
*/
|
|
11
|
+
export interface BatchDecodePlanOptions {
|
|
12
|
+
/** Number of sequences in the batch */
|
|
13
|
+
batchSize: number;
|
|
14
|
+
/** Size of each page in the paged KV cache */
|
|
15
|
+
pageSize: number;
|
|
16
|
+
/** Number of query/output heads */
|
|
17
|
+
numQoHeads: number;
|
|
18
|
+
/** Number of key/value heads */
|
|
19
|
+
numKvHeads: number;
|
|
20
|
+
/** Head dimension */
|
|
21
|
+
headDim: number;
|
|
22
|
+
}
|
|
23
|
+
/**
|
|
24
|
+
* Plan information for batch decode with paged KV cache.
|
|
25
|
+
* This is returned by batch_decode_plan() and passed to batch_decode_run().
|
|
26
|
+
*/
|
|
27
|
+
export interface BatchDecodePlanInfo {
|
|
28
|
+
/** Unique key for this configuration */
|
|
29
|
+
key: string;
|
|
30
|
+
/** Number of query/output heads */
|
|
31
|
+
num_qo_heads: number;
|
|
32
|
+
/** Number of key/value heads */
|
|
33
|
+
num_kv_heads: number;
|
|
34
|
+
/** Head dimension */
|
|
35
|
+
head_dim: number;
|
|
36
|
+
/** Page size */
|
|
37
|
+
page_size: number;
|
|
38
|
+
/** Softmax scale */
|
|
39
|
+
sm_scale: number;
|
|
40
|
+
/** Batch size */
|
|
41
|
+
batch_size: number;
|
|
42
|
+
/** Required workspace size in bytes (currently 0) */
|
|
43
|
+
workspaceSize: number;
|
|
44
|
+
}
|
|
45
|
+
/**
|
|
46
|
+
* Plan batch decode with paged KV cache.
|
|
47
|
+
*
|
|
48
|
+
* @param options - Configuration options
|
|
49
|
+
* @returns Plan information for execution
|
|
50
|
+
*
|
|
51
|
+
* @example
|
|
52
|
+
* const plan = batch_decode_plan({
|
|
53
|
+
* batchSize: 4,
|
|
54
|
+
* pageSize: 16,
|
|
55
|
+
* numQoHeads: 32,
|
|
56
|
+
* numKvHeads: 8,
|
|
57
|
+
* headDim: 128
|
|
58
|
+
* });
|
|
59
|
+
*/
|
|
60
|
+
export declare function batch_decode_plan(options: BatchDecodePlanOptions): BatchDecodePlanInfo;
|
|
61
|
+
/**
|
|
62
|
+
* Execute batch decode with paged KV cache.
|
|
63
|
+
*
|
|
64
|
+
* This function executes the attention computation using the plan
|
|
65
|
+
* prepared by batch_decode_plan().
|
|
66
|
+
*
|
|
67
|
+
* @param ctx - WebInfer context
|
|
68
|
+
* @param planInfo - Plan information from batch_decode_plan()
|
|
69
|
+
* @param q - Query tensor [batch_size, num_qo_heads, head_dim]
|
|
70
|
+
* @param pagedKvCache - Paged KV cache
|
|
71
|
+
* @param pageIndptr - Page indirection pointer [batch_size + 1]
|
|
72
|
+
* @param pageIndices - Page indices [nnz_pages]
|
|
73
|
+
* @param lastPageLen - Last page lengths [batch_size]
|
|
74
|
+
* @param output - Output tensor [batch_size, num_qo_heads, head_dim]
|
|
75
|
+
* @param lse - Log-sum-exp output [batch_size, num_qo_heads] (optional)
|
|
76
|
+
*/
|
|
77
|
+
export declare function batch_decode_run(ctx: WebInferContext, planInfo: BatchDecodePlanInfo, q: Tensor, pagedKvCache: PagedKvCache, pageIndptr: Tensor, pageIndices: Tensor, lastPageLen: Tensor, output: Tensor, lse?: Tensor): Promise<void>;
|
|
78
|
+
/**
|
|
79
|
+
* Single decode with KV cache
|
|
80
|
+
*
|
|
81
|
+
* Performs attention for a single token (decode phase) with KV cache.
|
|
82
|
+
* Uses flash attention with online softmax to avoid materializing full attention matrix.
|
|
83
|
+
*
|
|
84
|
+
* @param ctx WebInfer context
|
|
85
|
+
* @param q Query tensor of shape [num_qo_heads, head_dim]
|
|
86
|
+
* @param k_cache Key cache tensor of shape [seq_len, num_kv_heads, head_dim]
|
|
87
|
+
* @param v_cache Value cache tensor of shape [seq_len, num_kv_heads, head_dim]
|
|
88
|
+
* @param output Optional pre-allocated output tensor [num_qo_heads, head_dim].
|
|
89
|
+
* If provided, results are written to this tensor (zero-copy for TVM integration).
|
|
90
|
+
* If not provided, a new tensor is created and returned.
|
|
91
|
+
* @param pos_encoding_mode Position encoding mode (default: NONE)
|
|
92
|
+
* @param sm_scale Softmax scale (default: 1/sqrt(head_dim))
|
|
93
|
+
* @param rope_scale RoPE scale (default: 1.0) - not yet implemented
|
|
94
|
+
* @param rope_theta RoPE theta (default: 10000.0) - not yet implemented
|
|
95
|
+
* @returns Output tensor of shape [num_qo_heads, head_dim]
|
|
96
|
+
*/
|
|
97
|
+
export declare function single_decode_with_kv_cache(ctx: WebInferContext, q: Tensor, k_cache: Tensor, v_cache: Tensor, output?: Tensor, pos_encoding_mode?: PosEncodingMode, sm_scale?: number, rope_scale?: number, rope_theta?: number): Promise<Tensor>;
|
|
98
|
+
/**
|
|
99
|
+
* Batched decode with paged KV cache wrapper
|
|
100
|
+
*
|
|
101
|
+
* FlashInfer-compatible wrapper for batched decode attention with paged KV cache.
|
|
102
|
+
*/
|
|
103
|
+
export declare class BatchDecodeWithPagedKVCacheWrapper {
|
|
104
|
+
private ctx;
|
|
105
|
+
private planned;
|
|
106
|
+
private num_qo_heads;
|
|
107
|
+
private num_kv_heads;
|
|
108
|
+
private head_dim;
|
|
109
|
+
private page_size;
|
|
110
|
+
private batch_size;
|
|
111
|
+
private sm_scale;
|
|
112
|
+
private pos_encoding_mode;
|
|
113
|
+
constructor(ctx: WebInferContext);
|
|
114
|
+
/**
|
|
115
|
+
* Plan the batched decode operation
|
|
116
|
+
*
|
|
117
|
+
* @param indptr Indirection pointer [batch_size + 1]
|
|
118
|
+
* @param indices Page indices [nnz_pages]
|
|
119
|
+
* @param last_page_len Last page lengths [batch_size]
|
|
120
|
+
* @param num_qo_heads Number of query heads
|
|
121
|
+
* @param num_kv_heads Number of KV heads
|
|
122
|
+
* @param head_dim Head dimension
|
|
123
|
+
* @param page_size Page size
|
|
124
|
+
* @param pos_encoding_mode Position encoding mode (default: NONE)
|
|
125
|
+
*/
|
|
126
|
+
plan(indptr: Tensor, indices: Tensor, last_page_len: Tensor, num_qo_heads: number, num_kv_heads: number, head_dim: number, page_size: number, pos_encoding_mode?: PosEncodingMode): void;
|
|
127
|
+
/**
|
|
128
|
+
* Run the batched decode operation
|
|
129
|
+
*
|
|
130
|
+
* @param q Query tensor [batch_size, num_qo_heads, head_dim]
|
|
131
|
+
* @param paged_kv_cache Paged KV cache
|
|
132
|
+
* @param indptr Indirection pointer [batch_size + 1]
|
|
133
|
+
* @param indices Page indices [nnz_pages]
|
|
134
|
+
* @param last_page_len Last page lengths [batch_size]
|
|
135
|
+
* @param output Optional pre-allocated output tensor [batch_size, num_qo_heads, head_dim].
|
|
136
|
+
* If provided, results are written to this tensor (zero-copy for TVM integration).
|
|
137
|
+
* @returns Output tensor [batch_size, num_qo_heads, head_dim]
|
|
138
|
+
*/
|
|
139
|
+
run(q: Tensor, paged_kv_cache: PagedKvCache, indptr: Tensor, indices: Tensor, last_page_len: Tensor, output?: Tensor): Promise<Tensor>;
|
|
140
|
+
}
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* GEMM (General Matrix Multiplication) operations
|
|
3
|
+
*/
|
|
4
|
+
import type { WebInferContext } from '../core/context.ts';
|
|
5
|
+
import type { Tensor } from '../core/tensor.ts';
|
|
6
|
+
/**
|
|
7
|
+
* Batched matrix multiplication for fp16
|
|
8
|
+
*
|
|
9
|
+
* Computes C = A @ B for batched matrices
|
|
10
|
+
*
|
|
11
|
+
* @param ctx WebInfer context
|
|
12
|
+
* @param a Input tensor A [B, M, K]
|
|
13
|
+
* @param b Input tensor B [B, K, N]
|
|
14
|
+
* @param output Optional pre-allocated output tensor C [B, M, N] to write the result into
|
|
15
|
+
* @returns Output tensor C [B, M, N]
|
|
16
|
+
*/
|
|
17
|
+
export declare function bmm_fp16(ctx: WebInferContext, a: Tensor, b: Tensor, output?: Tensor): Promise<Tensor>;
|
|
18
|
+
/**
|
|
19
|
+
* Batched matrix multiplication for fp32
|
|
20
|
+
*
|
|
21
|
+
* @param ctx WebInfer context
|
|
22
|
+
* @param a Input tensor A [B, M, K]
|
|
23
|
+
* @param b Input tensor B [B, K, N]
|
|
24
|
+
* @param output Optional pre-allocated output tensor C [B, M, N] to write the result into
|
|
25
|
+
* @returns Output tensor C [B, M, N]
|
|
26
|
+
*/
|
|
27
|
+
export declare function bmm_fp32(ctx: WebInferContext, a: Tensor, b: Tensor, output?: Tensor): Promise<Tensor>;
|
package/dist/index.d.ts
CHANGED
|
@@ -1,24 +1,32 @@
|
|
|
1
1
|
/**
|
|
2
2
|
* WebInfer - High-performance LLM inference kernels for WebGPU
|
|
3
|
-
*
|
|
3
|
+
*
|
|
4
|
+
* WebGPU implementation of FlashInfer APIs for browser-based LLM inference.
|
|
4
5
|
*/
|
|
5
|
-
export {
|
|
6
|
-
export { Tensor
|
|
7
|
-
export {
|
|
8
|
-
export {
|
|
9
|
-
export
|
|
10
|
-
export {
|
|
11
|
-
export {
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
export {
|
|
22
|
-
export {
|
|
23
|
-
export {
|
|
24
|
-
export
|
|
6
|
+
export { WebInferContext } from './core/context.ts';
|
|
7
|
+
export { Tensor } from './core/tensor.ts';
|
|
8
|
+
export type { DType } from './core/types.ts';
|
|
9
|
+
export { PosEncodingMode } from './core/types.ts';
|
|
10
|
+
export type { PagedKvCache } from './core/paged-kv-cache.ts';
|
|
11
|
+
export { createPagedKvCache } from './core/paged-kv-cache.ts';
|
|
12
|
+
export { WebInferContext as create_context } from './core/context.ts';
|
|
13
|
+
import * as norm from './norm/index.ts';
|
|
14
|
+
import * as activation from './activation/index.ts';
|
|
15
|
+
import * as decode from './decode/index.ts';
|
|
16
|
+
import * as prefill from './prefill/index.ts';
|
|
17
|
+
import * as page from './page/index.ts';
|
|
18
|
+
import * as rope from './rope/index.ts';
|
|
19
|
+
import * as sampling from './sampling/index.ts';
|
|
20
|
+
import * as gemm from './gemm/index.ts';
|
|
21
|
+
import * as jit from './jit/index.ts';
|
|
22
|
+
export { norm, activation, decode, prefill, page, rope, sampling, gemm, jit };
|
|
23
|
+
export type { KernelSpec, CompiledKernel, BindingInfo, PlanInfo, CompiledKernelEntry, CompiledKernelRegistry, } from './jit/index.ts';
|
|
24
|
+
export { compileKernel, getSpecKey, initFromSpecs, getCompiledKernel, } from './jit/index.ts';
|
|
25
|
+
export declare const rmsnorm: typeof norm.rmsnorm, fused_add_rmsnorm: typeof norm.fused_add_rmsnorm, gemma_rmsnorm: typeof norm.gemma_rmsnorm;
|
|
26
|
+
export declare const silu_and_mul: typeof activation.silu_and_mul, gelu_and_mul: typeof activation.gelu_and_mul;
|
|
27
|
+
export declare const single_decode_with_kv_cache: typeof decode.single_decode_with_kv_cache, BatchDecodeWithPagedKVCacheWrapper: typeof decode.BatchDecodeWithPagedKVCacheWrapper;
|
|
28
|
+
export declare const single_prefill_with_kv_cache: typeof prefill.single_prefill_with_kv_cache, BatchPrefillWithPagedKVCacheWrapper: typeof prefill.BatchPrefillWithPagedKVCacheWrapper;
|
|
29
|
+
export declare const append_paged_kv_cache: typeof page.append_paged_kv_cache;
|
|
30
|
+
export declare const apply_rope_inplace: typeof rope.apply_rope_inplace, apply_llama31_rope_inplace: typeof rope.apply_llama31_rope_inplace;
|
|
31
|
+
export declare const sampling_from_probs: typeof sampling.sampling_from_probs, top_k_sampling_from_probs: typeof sampling.top_k_sampling_from_probs, top_p_sampling_from_probs: typeof sampling.top_p_sampling_from_probs, min_p_sampling_from_probs: typeof sampling.min_p_sampling_from_probs, top_k_top_p_sampling_from_probs: typeof sampling.top_k_top_p_sampling_from_probs;
|
|
32
|
+
export declare const bmm_fp16: typeof gemm.bmm_fp16, bmm_fp32: typeof gemm.bmm_fp32;
|