webinfer 0.0.4 → 0.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -1,59 +1,88 @@
  # WebInfer
 
- High-performance LLM inference kernels for WebGPU.
+ [![npm version](https://img.shields.io/npm/v/webinfer.svg)](https://www.npmjs.com/package/webinfer)
+ [![License](https://img.shields.io/badge/license-Apache--2.0-blue.svg)](LICENSE)
 
- WebGPU implementation of FlashInfer APIs for browser-based LLM inference.
+ High-performance LLM inference kernels for WebGPU. A browser-native implementation of [FlashInfer](https://github.com/flashinfer-ai/flashinfer) APIs.
 
- ## Features
+ ## Why WebInfer?
 
- - **Flash Attention for WebGPU**: Online softmax and paged KV cache for memory-efficient attention
- - **Subgroup Optimization**: Uses WebGPU subgroup operations for 2-4x faster reductions when available
- - **Fused Kernels**: Reduces memory bandwidth with operations like `silu_and_mul`, `fused_add_rmsnorm`
- - **JIT Compilation**: Runtime kernel generation and pipeline caching
- - **f16 Safe Accumulation**: Uses f32 accumulators for f16 computation to prevent precision loss
- - **FlashInfer Compatible API**: Drop-in replacement for FlashInfer with identical naming
+ Running LLMs in the browser requires efficient GPU kernels. WebInfer brings FlashInfer's battle-tested attention mechanisms to WebGPU:
 
- ## Install
+ - **Flash Attention**: Online softmax with O(1) memory for long sequences
+ - **Subgroup Optimizations**: 2-4x faster reductions on supported hardware
+ - **JIT Compilation**: Runtime kernel generation with pipeline caching
+ - **Dynamic Tile Sizes**: Auto-tuned tile dimensions based on workload and hardware
+
+ ## Installation
 
  ```bash
  npm install webinfer
  ```
 
- ## API Status
+ ## Quick Start
+
+ ```typescript
+ import * as webinfer from 'webinfer';
 
- **Phase 1: Core + Simple Ops**
- - `webinfer.WebInferContext.create()` - Initialize WebGPU context
- - `webinfer.norm.rmsnorm()` - Root mean square normalization
- - `webinfer.activation.silu_and_mul()` - SiLU activation with gating
- - `webinfer.activation.gelu_and_mul()` - GELU activation with gating
+ // Initialize context
+ const ctx = await webinfer.WebInferContext.create();
 
- **Phase 2: Single Attention**
- - `webinfer.decode.single_decode_with_kv_cache()` - Flash attention for decode phase
- - `webinfer.prefill.single_prefill_with_kv_cache()` - Flash attention for prefill phase with causal masking
+ // Run RMSNorm
+ const output = ctx.norm.rmsnorm(input, weight, epsilon);
+
+ // Flash attention decode
+ const attnOutput = ctx.decode.single_decode_with_kv_cache(
+ query, // [num_heads, head_dim]
+ kv_cache, // [seq_len, 2, num_kv_heads, head_dim]
+ );
+ ```
 
- **Phase 3: Paged KV + Batched Operations**
- - `webinfer.page.append_paged_kv_cache()` - Append new KV pairs to paged cache
- - `webinfer.decode.BatchDecodeWithPagedKVCacheWrapper` - Batched decode with paged KV cache
- - `webinfer.prefill.BatchPrefillWithPagedKVCacheWrapper` - Batched prefill with paged KV cache
+ ## Browser Compatibility
 
- **Phase 4: Additional Operations**
- - `webinfer.rope.apply_rope_inplace()` - Apply RoPE position encoding
- - `webinfer.rope.apply_llama31_rope_inplace()` - Llama 3.1 RoPE variant for long context
- - `webinfer.sampling.sampling_from_probs()` - Categorical sampling
- - `webinfer.sampling.top_k_sampling_from_probs()` - Top-k sampling
- - `webinfer.sampling.top_p_sampling_from_probs()` - Top-p (nucleus) sampling
- - `webinfer.sampling.min_p_sampling_from_probs()` - Min-p sampling
- - `webinfer.sampling.top_k_top_p_sampling_from_probs()` - Combined top-k and top-p
- - `webinfer.gemm.bmm_fp16()` - Batched matrix multiplication (fp16)
- - `webinfer.gemm.bmm_fp32()` - Batched matrix multiplication (fp32)
+ WebInfer requires WebGPU support:
 
- See [dev/design.md](dev/design.md) for full design and implementation plan.
+ | Browser | Status |
+ |---------|--------|
+ | Chrome 113+ | Supported |
+ | Edge 113+ | Supported |
+ | Firefox Nightly | Behind flag |
+ | Safari 18+ | Supported |
+
+ Check for WebGPU support:
+
+ ```typescript
+ if (!navigator.gpu) {
+ console.error('WebGPU not supported');
+ }
+ ```
 
- ## Release
+ ## Development
 
  ```bash
+ # Run tests
+ bun test
+
+ # Build
  bun run build
- npm publish
+ ```
+
+ ## Acknowledgments
+
+ - [FlashInfer](https://github.com/flashinfer-ai/flashinfer) - The original CUDA implementation
+ - [WebLLM](https://github.com/mlc-ai/web-llm) - Browser LLM runtime
+
+ ## Citation
+
+ If you use WebInfer in your research, please cite:
+
+ ```bibtex
+ @software{webinfer2025,
+   author = {Chiu, Guan-Ming},
+   title = {WebInfer: High-Performance LLM Inference Kernels for WebGPU},
+   year = {2025},
+   url = {https://github.com/guan404ming/webinfer}
+ }
  ```
 
  ## License
@@ -29,7 +29,17 @@ export declare class WebInferContext
  private disposed;
  private constructor();
  /**
- * Create a new WebInferContext
+ * Create a WebInferContext from an existing GPUDevice.
+ * Use this when integrating with TVM or other WebGPU frameworks
+ * that already have a GPUDevice to enable zero-copy buffer sharing.
+ *
+ * @param device - The existing GPUDevice to use
+ * @param featureOverrides - Optional overrides for detected features
+ * @returns A new WebInferContext using the provided device
+ */
+ static fromDevice(device: GPUDevice, featureOverrides?: Partial<DeviceFeatures>): WebInferContext;
+ /**
+ * Create a new WebInferContext with its own GPUDevice
  */
  static create(): Promise<WebInferContext>;
  /**
@@ -21,6 +21,19 @@ export declare class Tensor {
  readonly size: number;
  readonly byteSize: number;
  constructor(buffer: GPUBuffer, shape: number[], dtype: DType, strides?: number[]);
+ /**
+ * Create a Tensor from an existing GPUBuffer (zero-copy).
+ *
+ * This is useful for TVM integration where the buffer is already allocated
+ * by TVM's WebGPU runtime and we want to wrap it without copying data.
+ *
+ * @param buffer - The existing GPUBuffer to wrap
+ * @param shape - The shape of the tensor
+ * @param dtype - The data type of the tensor
+ * @param strides - Optional custom strides (default: row-major)
+ * @returns A new Tensor wrapping the provided buffer
+ */
+ static fromBuffer(buffer: GPUBuffer, shape: number[], dtype: DType, strides?: number[]): Tensor;
  /**
  * Get the number of dimensions
  */
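The `fromBuffer` doc comment above says strides default to row-major when none are passed. As an illustrative sketch (the helper below is hypothetical, not part of the package), the row-major convention can be written in plain TypeScript:

```typescript
// Row-major (C-contiguous) strides: the last dimension is contiguous,
// and each earlier stride is the product of all later dimension sizes.
// This mirrors the documented default for Tensor.fromBuffer when the
// optional `strides` argument is omitted.
function rowMajorStrides(shape: number[]): number[] {
  const strides = new Array<number>(shape.length).fill(1);
  for (let i = shape.length - 2; i >= 0; i--) {
    strides[i] = strides[i + 1] * shape[i + 1];
  }
  return strides;
}

// e.g. a [seq_len=4, num_kv_heads=2, head_dim=8] tensor has strides [16, 8, 1]
```

Passing custom strides instead would let a caller wrap a non-contiguous view (for example, one head of a larger buffer) without copying.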
@@ -5,6 +5,76 @@ import type { WebInferContext } from '../core/context.ts';
  import type { Tensor } from '../core/tensor.ts';
  import type { PagedKvCache } from '../core/paged-kv-cache.ts';
  import { PosEncodingMode } from '../core/types.ts';
+ /**
+ * Options for batch_decode_plan()
+ */
+ export interface BatchDecodePlanOptions {
+ /** Number of sequences in the batch */
+ batchSize: number;
+ /** Size of each page in the paged KV cache */
+ pageSize: number;
+ /** Number of query/output heads */
+ numQoHeads: number;
+ /** Number of key/value heads */
+ numKvHeads: number;
+ /** Head dimension */
+ headDim: number;
+ }
+ /**
+ * Plan information for batch decode with paged KV cache.
+ * This is returned by batch_decode_plan() and passed to batch_decode_run().
+ */
+ export interface BatchDecodePlanInfo {
+ /** Unique key for this configuration */
+ key: string;
+ /** Number of query/output heads */
+ num_qo_heads: number;
+ /** Number of key/value heads */
+ num_kv_heads: number;
+ /** Head dimension */
+ head_dim: number;
+ /** Page size */
+ page_size: number;
+ /** Softmax scale */
+ sm_scale: number;
+ /** Batch size */
+ batch_size: number;
+ /** Required workspace size in bytes (currently 0) */
+ workspaceSize: number;
+ }
+ /**
+ * Plan batch decode with paged KV cache.
+ *
+ * @param options - Configuration options
+ * @returns Plan information for execution
+ *
+ * @example
+ * const plan = batch_decode_plan({
+ * batchSize: 4,
+ * pageSize: 16,
+ * numQoHeads: 32,
+ * numKvHeads: 8,
+ * headDim: 128
+ * });
+ */
+ export declare function batch_decode_plan(options: BatchDecodePlanOptions): BatchDecodePlanInfo;
+ /**
+ * Execute batch decode with paged KV cache.
+ *
+ * This function executes the attention computation using the plan
+ * prepared by batch_decode_plan().
+ *
+ * @param ctx - WebInfer context
+ * @param planInfo - Plan information from batch_decode_plan()
+ * @param q - Query tensor [batch_size, num_qo_heads, head_dim]
+ * @param pagedKvCache - Paged KV cache
+ * @param pageIndptr - Page indirection pointer [batch_size + 1]
+ * @param pageIndices - Page indices [nnz_pages]
+ * @param lastPageLen - Last page lengths [batch_size]
+ * @param output - Output tensor [batch_size, num_qo_heads, head_dim]
+ * @param lse - Log-sum-exp output [batch_size, num_qo_heads] (optional)
+ */
+ export declare function batch_decode_run(ctx: WebInferContext, planInfo: BatchDecodePlanInfo, q: Tensor, pagedKvCache: PagedKvCache, pageIndptr: Tensor, pageIndices: Tensor, lastPageLen: Tensor, output: Tensor, lse?: Tensor): Promise<void>;
  /**
  * Single decode with KV cache
  *
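The `pageIndptr` / `pageIndices` / `lastPageLen` parameters documented above follow FlashInfer's CSR-style page-table layout: `pageIndptr[i]..pageIndptr[i+1]` delimits sequence `i`'s slice of `pageIndices`, and `lastPageLen[i]` gives the token count in its final, possibly partial, page. A hedged sketch of how a caller might build these arrays from per-sequence token counts (`buildPageTables` is a hypothetical helper, not a package export, and it allocates page ids sequentially purely for illustration):

```typescript
// Build the CSR-style bookkeeping arrays that batch_decode_run expects,
// given the number of cached tokens per sequence and the page size.
function buildPageTables(seqLens: number[], pageSize: number) {
  const pageIndptr: number[] = [0]; // [batch_size + 1]
  const pageIndices: number[] = []; // [nnz_pages]
  const lastPageLen: number[] = []; // [batch_size]
  let nextPage = 0; // a real allocator would reuse freed pages
  for (const len of seqLens) {
    const numPages = Math.ceil(len / pageSize);
    for (let p = 0; p < numPages; p++) pageIndices.push(nextPage++);
    pageIndptr.push(pageIndices.length);
    // tokens in the final (possibly partial) page, in [1, pageSize] for len >= 1
    lastPageLen.push(len - (numPages - 1) * pageSize);
  }
  return { pageIndptr, pageIndices, lastPageLen };
}
```

These plain arrays would still need to be uploaded as `Tensor`s (e.g. via `Tensor.fromBuffer` over an existing GPUBuffer) before being passed to `batch_decode_run`.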
@@ -15,13 +85,16 @@ import { PosEncodingMode } from '../core/types.ts';
  * @param q Query tensor of shape [num_qo_heads, head_dim]
  * @param k_cache Key cache tensor of shape [seq_len, num_kv_heads, head_dim]
  * @param v_cache Value cache tensor of shape [seq_len, num_kv_heads, head_dim]
+ * @param output Optional pre-allocated output tensor [num_qo_heads, head_dim].
+ * If provided, results are written to this tensor (zero-copy for TVM integration).
+ * If not provided, a new tensor is created and returned.
  * @param pos_encoding_mode Position encoding mode (default: NONE)
  * @param sm_scale Softmax scale (default: 1/sqrt(head_dim))
  * @param rope_scale RoPE scale (default: 1.0) - not yet implemented
  * @param rope_theta RoPE theta (default: 10000.0) - not yet implemented
  * @returns Output tensor of shape [num_qo_heads, head_dim]
  */
- export declare function single_decode_with_kv_cache(ctx: WebInferContext, q: Tensor, k_cache: Tensor, v_cache: Tensor, pos_encoding_mode?: PosEncodingMode, sm_scale?: number, rope_scale?: number, rope_theta?: number): Promise<Tensor>;
+ export declare function single_decode_with_kv_cache(ctx: WebInferContext, q: Tensor, k_cache: Tensor, v_cache: Tensor, output?: Tensor, pos_encoding_mode?: PosEncodingMode, sm_scale?: number, rope_scale?: number, rope_theta?: number): Promise<Tensor>;
  /**
  * Batched decode with paged KV cache wrapper
  *
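For intuition about what `single_decode_with_kv_cache` computes per query head, the online-softmax recurrence behind Flash Attention (running max `m`, running normalizer `d`, rescaled accumulator `acc`) can be written as a plain-TypeScript reference. This is an illustrative CPU sketch, not the package's WGSL kernel; the `1 / sqrt(head_dim)` default mirrors the documented `sm_scale` default.

```typescript
// Reference single-query attention over one head using online softmax.
// Each step folds in one cached key/value pair without materializing
// the full score vector, which is the O(1)-memory trick the README cites.
function decodeAttentionRef(
  q: number[],     // [head_dim]
  k: number[][],   // [seq_len][head_dim]
  v: number[][],   // [seq_len][head_dim]
  smScale: number = 1 / Math.sqrt(q.length), // documented default sm_scale
): number[] {
  let m = -Infinity; // running max of scores
  let d = 0;         // running softmax denominator
  const acc = new Array<number>(q.length).fill(0);
  for (let t = 0; t < k.length; t++) {
    const s = q.reduce((sum, qi, i) => sum + qi * k[t][i], 0) * smScale;
    const mNew = Math.max(m, s);
    const scale = Math.exp(m - mNew); // rescale old state to the new max
    const p = Math.exp(s - mNew);
    d = d * scale + p;
    for (let i = 0; i < acc.length; i++) acc[i] = acc[i] * scale + p * v[t][i];
    m = mNew;
  }
  return acc.map((a) => a / d);
}
```

Running this over a whole head set reproduces the `[num_qo_heads, head_dim]` output shape documented above, one head at a time.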
@@ -59,7 +132,9 @@ export declare class BatchDecodeWithPagedKVCacheWrapper
  * @param indptr Indirection pointer [batch_size + 1]
  * @param indices Page indices [nnz_pages]
  * @param last_page_len Last page lengths [batch_size]
+ * @param output Optional pre-allocated output tensor [batch_size, num_qo_heads, head_dim].
+ * If provided, results are written to this tensor (zero-copy for TVM integration).
  * @returns Output tensor [batch_size, num_qo_heads, head_dim]
  */
- run(q: Tensor, paged_kv_cache: PagedKvCache, indptr: Tensor, indices: Tensor, last_page_len: Tensor): Promise<Tensor>;
+ run(q: Tensor, paged_kv_cache: PagedKvCache, indptr: Tensor, indices: Tensor, last_page_len: Tensor, output?: Tensor): Promise<Tensor>;
  }
@@ -11,15 +11,17 @@ import type { Tensor } from '../core/tensor.ts';
  * @param ctx WebInfer context
  * @param a Input tensor A [B, M, K]
  * @param b Input tensor B [B, K, N]
+ * @param output Optional output tensor C [B, M, N] for in-place operation
  * @returns Output tensor C [B, M, N]
  */
- export declare function bmm_fp16(ctx: WebInferContext, a: Tensor, b: Tensor): Promise<Tensor>;
+ export declare function bmm_fp16(ctx: WebInferContext, a: Tensor, b: Tensor, output?: Tensor): Promise<Tensor>;
  /**
  * Batched matrix multiplication for fp32
  *
  * @param ctx WebInfer context
  * @param a Input tensor A [B, M, K]
  * @param b Input tensor B [B, K, N]
+ * @param output Optional output tensor C [B, M, N] for in-place operation
  * @returns Output tensor C [B, M, N]
  */
- export declare function bmm_fp32(ctx: WebInferContext, a: Tensor, b: Tensor): Promise<Tensor>;
+ export declare function bmm_fp32(ctx: WebInferContext, a: Tensor, b: Tensor, output?: Tensor): Promise<Tensor>;
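As a shape-contract reference for the `bmm_*` signatures above, here is a naive CPU version of batched matmul: `[B, M, K] × [B, K, N] → [B, M, N]`. This is purely illustrative (a hypothetical helper, not a package export); the real kernels run tiled on the GPU in fp16 or fp32.

```typescript
// Naive batched matrix multiply matching the documented bmm_fp16/bmm_fp32
// shape contract: for each batch b, C[b] = A[b] @ B[b].
function bmmRef(a: number[][][], b: number[][][]): number[][][] {
  return a.map((aBatch, batch) =>
    aBatch.map((row) =>
      // one output column per column of B[batch]
      b[batch][0].map((_, n) =>
        row.reduce((sum, aik, k) => sum + aik * b[batch][k][n], 0),
      ),
    ),
  );
}
```

The optional `output` parameter in the GPU signatures serves the same purpose as everywhere else in this release: writing into a caller-provided buffer instead of allocating a fresh one.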
package/dist/index.d.ts CHANGED
@@ -18,7 +18,10 @@ import * as page from './page/index.ts';
  import * as rope from './rope/index.ts';
  import * as sampling from './sampling/index.ts';
  import * as gemm from './gemm/index.ts';
- export { norm, activation, decode, prefill, page, rope, sampling, gemm };
+ import * as jit from './jit/index.ts';
+ export { norm, activation, decode, prefill, page, rope, sampling, gemm, jit };
+ export type { KernelSpec, CompiledKernel, BindingInfo, PlanInfo, CompiledKernelEntry, CompiledKernelRegistry, } from './jit/index.ts';
+ export { compileKernel, getSpecKey, initFromSpecs, getCompiledKernel, } from './jit/index.ts';
  export declare const rmsnorm: typeof norm.rmsnorm, fused_add_rmsnorm: typeof norm.fused_add_rmsnorm, gemma_rmsnorm: typeof norm.gemma_rmsnorm;
  export declare const silu_and_mul: typeof activation.silu_and_mul, gelu_and_mul: typeof activation.gelu_and_mul;
  export declare const single_decode_with_kv_cache: typeof decode.single_decode_with_kv_cache, BatchDecodeWithPagedKVCacheWrapper: typeof decode.BatchDecodeWithPagedKVCacheWrapper;