webinfer 0.0.3 → 0.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72) hide show
  1. package/README.md +65 -21
  2. package/dist/activation/index.d.ts +30 -0
  3. package/dist/core/context.d.ts +70 -0
  4. package/dist/core/paged-kv-cache.d.ts +33 -0
  5. package/dist/core/tensor.d.ts +51 -19
  6. package/dist/core/types.d.ts +27 -0
  7. package/dist/decode/index.d.ts +140 -0
  8. package/dist/gemm/index.d.ts +27 -0
  9. package/dist/index.d.ts +29 -21
  10. package/dist/index.js +3433 -4809
  11. package/dist/jit/index.d.ts +138 -0
  12. package/dist/kernels/activation.wgsl.d.ts +14 -0
  13. package/dist/kernels/batch-decode-paged.wgsl.d.ts +12 -0
  14. package/dist/kernels/batch-prefill-paged.wgsl.d.ts +13 -0
  15. package/dist/kernels/decode-attention.wgsl.d.ts +16 -0
  16. package/dist/kernels/gemm.wgsl.d.ts +17 -0
  17. package/dist/kernels/page.wgsl.d.ts +10 -0
  18. package/dist/kernels/prefill-attention.wgsl.d.ts +17 -0
  19. package/dist/kernels/rmsnorm.wgsl.d.ts +10 -0
  20. package/dist/kernels/rope.wgsl.d.ts +19 -0
  21. package/dist/kernels/sampling.wgsl.d.ts +23 -0
  22. package/dist/norm/index.d.ts +43 -0
  23. package/dist/page/index.d.ts +21 -0
  24. package/dist/prefill/index.d.ts +155 -0
  25. package/dist/rope/index.d.ts +37 -0
  26. package/dist/sampling/index.d.ts +53 -4
  27. package/package.json +1 -1
  28. package/dist/attention/block-sparse/format.d.ts +0 -52
  29. package/dist/attention/block-sparse/patterns/causal.d.ts +0 -16
  30. package/dist/attention/block-sparse/patterns/sliding.d.ts +0 -22
  31. package/dist/attention/block-sparse/patterns/tree.d.ts +0 -65
  32. package/dist/attention/cascaded-inference.d.ts +0 -29
  33. package/dist/attention/flash-attention.d.ts +0 -30
  34. package/dist/attention/index.d.ts +0 -118
  35. package/dist/attention/paged-attention.d.ts +0 -40
  36. package/dist/attention/paged-kv/block-manager.d.ts +0 -102
  37. package/dist/attention/paged-kv/index.d.ts +0 -5
  38. package/dist/attention/paged-kv/page-table.d.ts +0 -165
  39. package/dist/attention/scheduler.d.ts +0 -40
  40. package/dist/core/buffer-pool.d.ts +0 -18
  41. package/dist/core/device.d.ts +0 -23
  42. package/dist/core/tdr.d.ts +0 -114
  43. package/dist/inference/engine.d.ts +0 -69
  44. package/dist/inference/generate.d.ts +0 -30
  45. package/dist/inference/index.d.ts +0 -7
  46. package/dist/inference/types.d.ts +0 -161
  47. package/dist/jit/compiler.d.ts +0 -23
  48. package/dist/jit/kernel-cache.d.ts +0 -21
  49. package/dist/model/gguf.d.ts +0 -90
  50. package/dist/model/index.d.ts +0 -16
  51. package/dist/model/safetensors.d.ts +0 -38
  52. package/dist/model/types.d.ts +0 -182
  53. package/dist/ops/activations.d.ts +0 -43
  54. package/dist/ops/elementwise.d.ts +0 -38
  55. package/dist/ops/embedding.d.ts +0 -30
  56. package/dist/ops/matmul.d.ts +0 -21
  57. package/dist/ops/normalization.d.ts +0 -63
  58. package/dist/ops/reshape.d.ts +0 -39
  59. package/dist/ops/rope.d.ts +0 -32
  60. package/dist/ops/softmax.d.ts +0 -18
  61. package/dist/quantization/index.d.ts +0 -6
  62. package/dist/quantization/qmatmul.d.ts +0 -38
  63. package/dist/quantization/quantize.d.ts +0 -52
  64. package/dist/sampling/beam-search.d.ts +0 -87
  65. package/dist/sampling/sampler.d.ts +0 -72
  66. package/dist/sampling/speculative.d.ts +0 -65
  67. package/dist/sampling/top-k.d.ts +0 -24
  68. package/dist/sampling/top-p.d.ts +0 -14
  69. package/dist/tvm/adapter.d.ts +0 -81
  70. package/dist/tvm/index.d.ts +0 -8
  71. package/dist/tvm/ops.d.ts +0 -26
  72. package/dist/tvm/types.d.ts +0 -35
package/README.md CHANGED
@@ -1,44 +1,88 @@
1
1
  # WebInfer
2
2
 
3
- High-performance LLM inference kernels for WebGPU.
3
+ [![npm version](https://img.shields.io/npm/v/webinfer.svg)](https://www.npmjs.com/package/webinfer)
4
+ [![License](https://img.shields.io/badge/license-Apache--2.0-blue.svg)](LICENSE)
4
5
 
5
- ## Install
6
+ High-performance LLM inference kernels for WebGPU. A browser-native implementation of [FlashInfer](https://github.com/flashinfer-ai/flashinfer) APIs.
7
+
8
+ ## Why WebInfer?
9
+
10
+ Running LLMs in the browser requires efficient GPU kernels. WebInfer brings FlashInfer's battle-tested attention mechanisms to WebGPU:
11
+
12
+ - **Flash Attention**: Online softmax with O(1) memory for long sequences
13
+ - **Subgroup Optimizations**: 2-4x faster reductions on supported hardware
14
+ - **JIT Compilation**: Runtime kernel generation with pipeline caching
15
+ - **Dynamic Tile Sizes**: Auto-tuned tile dimensions based on workload and hardware
16
+
17
+ ## Installation
6
18
 
7
19
  ```bash
8
20
  npm install webinfer
9
21
  ```
10
22
 
11
- ## Usage
23
+ ## Quick Start
12
24
 
13
25
  ```typescript
14
- import { WebInferDevice, attention } from 'webinfer';
26
+ import * as webinfer from 'webinfer';
15
27
 
16
- const device = await WebInferDevice.create();
28
+ // Initialize context
29
+ const ctx = await webinfer.WebInferContext.create();
17
30
 
18
- // Single decode attention
19
- const q = new Float32Array(32 * 128); // [num_qo_heads, head_dim]
20
- const k = new Float32Array(2048 * 32 * 128); // [kv_len, num_kv_heads, head_dim]
21
- const v = new Float32Array(2048 * 32 * 128);
31
+ // Run RMSNorm
32
+ const output = ctx.norm.rmsnorm(input, weight, epsilon);
22
33
 
23
- const output = await attention(device, { q, k, v });
34
+ // Flash attention decode
35
+ const attnOutput = ctx.decode.single_decode_with_kv_cache(
36
+ query, // [num_heads, head_dim]
37
+ kv_cache, // [seq_len, 2, num_kv_heads, head_dim]
38
+ );
24
39
  ```
25
40
 
26
- ## API
41
+ ## Browser Compatibility
27
42
 
28
- | Category | Exports |
29
- |----------|---------|
30
- | **Attention** | `attention`, `BatchAttention`, `AttentionKernel`, `cascadedAttention` |
31
- | **KV Cache** | `PagedKVCache`, `BlockManager`, `pagedAttention` |
32
- | **Patterns** | `buildCausalMask`, `buildSlidingWindowMask`, `buildBlockSparseCSR` |
33
- | **Sampling** | `topKSamplingFromProbs`, `topPSamplingFromProbs`, `minPSamplingFromProbs`, `topKTopPSamplingFromLogits` |
34
- | **Normalization** | `rmsNorm`, `layerNorm`, `fusedAddRmsNorm`, `gemmaRmsNorm` |
35
- | **Core** | `matmul`, `rope`, `gelu`, `silu`, `softmax` |
43
+ WebInfer requires WebGPU support:
44
+
45
+ | Browser | Status |
46
+ |---------|--------|
47
+ | Chrome 113+ | Supported |
48
+ | Edge 113+ | Supported |
49
+ | Firefox Nightly | Behind flag |
50
+ | Safari 18+ | Supported |
51
+
52
+ Check for WebGPU support:
53
+
54
+ ```typescript
55
+ if (!navigator.gpu) {
56
+ console.error('WebGPU not supported');
57
+ }
58
+ ```
36
59
 
37
- ## Release
60
+ ## Development
38
61
 
39
62
  ```bash
63
+ # Run tests
64
+ bun test
65
+
66
+ # Build
40
67
  bun run build
41
- npm publish
68
+ ```
69
+
70
+ ## Acknowledgments
71
+
72
+ - [FlashInfer](https://github.com/flashinfer-ai/flashinfer) - The original CUDA implementation
73
+ - [WebLLM](https://github.com/mlc-ai/web-llm) - Browser LLM runtime
74
+
75
+ ## Citation
76
+
77
+ If you use WebInfer in your research, please cite:
78
+
79
+ ```bibtex
80
+ @software{webinfer2025,
81
+ author = {Chiu, Guan-Ming},
82
+ title = {WebInfer: High-Performance LLM Inference Kernels for WebGPU},
83
+ year = {2025},
84
+ url = {https://github.com/guan404ming/webinfer}
85
+ }
42
86
  ```
43
87
 
44
88
  ## License
@@ -0,0 +1,30 @@
1
+ /**
2
+ * Activation functions
3
+ */
4
+ import type { WebInferContext } from '../core/context.ts';
5
+ import type { Tensor } from '../core/tensor.ts';
6
+ /**
7
+ * SiLU (Swish) activation with gating
8
+ *
9
+ * Input tensor has shape [..., 2 * hidden_size]
10
+ * Splits into gate[..., :hidden_size] and up[..., hidden_size:]
11
+ * Computes: output = silu(gate) * up
12
+ * Where silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
13
+ *
14
+ * @param ctx WebInfer context
15
+ * @param input Input tensor of shape [..., 2 * hidden_size]
16
+ * @returns Activated tensor of shape [..., hidden_size]
17
+ */
18
+ export declare function silu_and_mul(ctx: WebInferContext, input: Tensor): Promise<Tensor>;
19
+ /**
20
+ * GELU activation with gating
21
+ *
22
+ * Input tensor has shape [..., 2 * hidden_size]
23
+ * Computes: output = gelu(gate) * up
24
+ * Where gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
25
+ *
26
+ * @param ctx WebInfer context
27
+ * @param input Input tensor of shape [..., 2 * hidden_size]
28
+ * @returns Activated tensor of shape [..., hidden_size]
29
+ */
30
+ export declare function gelu_and_mul(ctx: WebInferContext, input: Tensor): Promise<Tensor>;
@@ -0,0 +1,70 @@
1
+ /**
2
+ * WebInferContext manages WebGPU device and provides core functionality
3
+ */
4
+ import type { DType } from './types.ts';
5
+ import { Tensor } from './tensor.ts';
6
+ /**
7
+ * Device vendor detection
8
+ */
9
+ export type Vendor = 'apple' | 'nvidia' | 'amd' | 'intel' | 'unknown';
10
+ /**
11
+ * Device features detected at runtime
12
+ */
13
+ export interface DeviceFeatures {
14
+ readonly vendor: Vendor;
15
+ readonly hasSubgroups: boolean;
16
+ readonly hasF16: boolean;
17
+ readonly maxWorkgroupSize: number;
18
+ readonly maxStorageBufferSize: number;
19
+ readonly maxBufferSize: number;
20
+ readonly maxWorkgroupsPerDimension: number;
21
+ }
22
+ /**
23
+ * WebInferContext - main context for all WebInfer operations
24
+ */
25
+ export declare class WebInferContext {
26
+ readonly device: GPUDevice;
27
+ readonly features: DeviceFeatures;
28
+ private pipelineCache;
29
+ private disposed;
30
+ private constructor();
31
+ /**
32
+ * Create a WebInferContext from an existing GPUDevice.
33
+ * Use this when integrating with TVM or other WebGPU frameworks
34
+ * that already have a GPUDevice to enable zero-copy buffer sharing.
35
+ *
36
+ * @param device - The existing GPUDevice to use
37
+ * @param featureOverrides - Optional overrides for detected features
38
+ * @returns A new WebInferContext using the provided device
39
+ */
40
+ static fromDevice(device: GPUDevice, featureOverrides?: Partial<DeviceFeatures>): WebInferContext;
41
+ /**
42
+ * Create a new WebInferContext with its own GPUDevice
43
+ */
44
+ static create(): Promise<WebInferContext>;
45
+ /**
46
+ * Create a GPU buffer
47
+ */
48
+ createBuffer(size: number, usage?: GPUBufferUsageFlags): GPUBuffer;
49
+ /**
50
+ * Create a tensor with uninitialized data
51
+ */
52
+ createTensor(shape: number[], dtype: DType): Tensor;
53
+ /**
54
+ * Create a tensor from TypedArray data
55
+ */
56
+ createTensorFromData(data: Float32Array | Uint16Array, shape: number[], dtype: DType): Tensor;
57
+ /**
58
+ * Get or create a cached compute pipeline
59
+ */
60
+ getOrCreatePipeline(key: string, createFn: () => Promise<GPUComputePipeline> | GPUComputePipeline): Promise<GPUComputePipeline>;
61
+ /**
62
+ * Clear pipeline cache
63
+ */
64
+ clearPipelineCache(): void;
65
+ /**
66
+ * Dispose the context and cleanup resources
67
+ */
68
+ dispose(): void;
69
+ private checkDisposed;
70
+ }
@@ -0,0 +1,33 @@
1
+ /**
2
+ * Paged KV Cache implementation
3
+ */
4
+ import type { Tensor } from './tensor.ts';
5
+ /**
6
+ * Paged KV Cache for efficient memory management
7
+ *
8
+ * Layout: [num_pages, 2, page_size, num_kv_heads, head_dim]
9
+ * where dim 1 has [0] = keys, [1] = values
10
+ */
11
+ export interface PagedKvCache {
12
+ /**
13
+ * The backing tensor for the paged KV cache
14
+ * Shape: [num_pages, 2, page_size, num_kv_heads, head_dim]
15
+ */
16
+ data: Tensor;
17
+ /**
18
+ * Number of tokens per page
19
+ */
20
+ page_size: number;
21
+ /**
22
+ * Number of KV heads (for GQA/MQA)
23
+ */
24
+ num_kv_heads: number;
25
+ /**
26
+ * Dimension of each head
27
+ */
28
+ head_dim: number;
29
+ }
30
+ /**
31
+ * Create a paged KV cache
32
+ */
33
+ export declare function createPagedKvCache(data: Tensor, page_size: number, num_kv_heads: number, head_dim: number): PagedKvCache;
@@ -1,25 +1,57 @@
1
1
  /**
2
- * Tensor - GPU-backed multidimensional array
2
+ * Tensor wrapper for WebGPU buffers
3
+ */
4
+ import type { DType } from './types.ts';
5
+ /**
6
+ * Compute strides from shape (row-major order)
7
+ */
8
+ export declare function computeStrides(shape: number[]): number[];
9
+ /**
10
+ * Compute total number of elements from shape
11
+ */
12
+ export declare function computeSize(shape: number[]): number;
13
+ /**
14
+ * Tensor class wrapping a GPUBuffer with shape and dtype information
3
15
  */
4
- import type { WebInferDevice } from "./device.ts";
5
- export type DType = "f32" | "f16" | "i32" | "u32";
6
16
  export declare class Tensor {
7
- private _device;
8
- private _shape;
9
- private _dtype;
10
- private _buffer;
11
- private _disposed;
12
- constructor(device: WebInferDevice, shape: number[], dtype?: DType, data?: Float32Array | Uint32Array | Int32Array);
13
- static fromArray(device: WebInferDevice, shape: number[], data: Float32Array, dtype?: DType): Promise<Tensor>;
14
- static zeros(device: WebInferDevice, shape: number[], dtype?: DType): Tensor;
15
- static rand(device: WebInferDevice, shape: number[], dtype?: DType): Tensor;
16
- get shape(): readonly number[];
17
- get dtype(): DType;
18
- get numel(): number;
19
- get byteSize(): number;
20
- get buffer(): GPUBuffer;
21
- get device(): WebInferDevice;
22
- toArray(): Promise<Float32Array>;
17
+ readonly buffer: GPUBuffer;
18
+ readonly shape: readonly number[];
19
+ readonly dtype: DType;
20
+ readonly strides: readonly number[];
21
+ readonly size: number;
22
+ readonly byteSize: number;
23
+ constructor(buffer: GPUBuffer, shape: number[], dtype: DType, strides?: number[]);
24
+ /**
25
+ * Create a Tensor from an existing GPUBuffer (zero-copy).
26
+ *
27
+ * This is useful for TVM integration where the buffer is already allocated
28
+ * by TVM's WebGPU runtime and we want to wrap it without copying data.
29
+ *
30
+ * @param buffer - The existing GPUBuffer to wrap
31
+ * @param shape - The shape of the tensor
32
+ * @param dtype - The data type of the tensor
33
+ * @param strides - Optional custom strides (default: row-major)
34
+ * @returns A new Tensor wrapping the provided buffer
35
+ */
36
+ static fromBuffer(buffer: GPUBuffer, shape: number[], dtype: DType, strides?: number[]): Tensor;
37
+ /**
38
+ * Get the number of dimensions
39
+ */
40
+ get ndim(): number;
41
+ /**
42
+ * Check if tensor is contiguous (row-major)
43
+ */
44
+ get isContiguous(): boolean;
45
+ /**
46
+ * Get a view of the tensor with new shape (must have same size)
47
+ */
23
48
  reshape(newShape: number[]): Tensor;
49
+ /**
50
+ * Destroy the underlying GPU buffer
51
+ */
24
52
  dispose(): void;
53
+ /**
54
+ * Get string representation
55
+ */
56
+ toString(): string;
25
57
  }
@@ -0,0 +1,27 @@
1
+ /**
2
+ * Core type definitions for WebInfer
3
+ */
4
+ /**
5
+ * Data type for tensors
6
+ */
7
+ export type DType = 'float32' | 'float16';
8
+ /**
9
+ * Position encoding modes for attention
10
+ */
11
+ export declare enum PosEncodingMode {
12
+ NONE = 0,
13
+ ROPE_LLAMA = 1,
14
+ ALIBI = 2
15
+ }
16
+ /**
17
+ * Get the byte size for a given dtype
18
+ */
19
+ export declare function dtypeByteSize(dtype: DType): number;
20
+ /**
21
+ * Get the GPUTextureFormat for a given dtype
22
+ */
23
+ export declare function dtypeToGPUFormat(dtype: DType): string;
24
+ /**
25
+ * Get WGSL type string for a given dtype
26
+ */
27
+ export declare function dtypeToWGSL(dtype: DType): 'f32' | 'f16';
@@ -0,0 +1,140 @@
1
+ /**
2
+ * Decode attention operations
3
+ */
4
+ import type { WebInferContext } from '../core/context.ts';
5
+ import type { Tensor } from '../core/tensor.ts';
6
+ import type { PagedKvCache } from '../core/paged-kv-cache.ts';
7
+ import { PosEncodingMode } from '../core/types.ts';
8
+ /**
9
+ * Options for batch_decode_plan()
10
+ */
11
+ export interface BatchDecodePlanOptions {
12
+ /** Number of sequences in the batch */
13
+ batchSize: number;
14
+ /** Size of each page in the paged KV cache */
15
+ pageSize: number;
16
+ /** Number of query/output heads */
17
+ numQoHeads: number;
18
+ /** Number of key/value heads */
19
+ numKvHeads: number;
20
+ /** Head dimension */
21
+ headDim: number;
22
+ }
23
+ /**
24
+ * Plan information for batch decode with paged KV cache.
25
+ * This is returned by batch_decode_plan() and passed to batch_decode_run().
26
+ */
27
+ export interface BatchDecodePlanInfo {
28
+ /** Unique key for this configuration */
29
+ key: string;
30
+ /** Number of query/output heads */
31
+ num_qo_heads: number;
32
+ /** Number of key/value heads */
33
+ num_kv_heads: number;
34
+ /** Head dimension */
35
+ head_dim: number;
36
+ /** Page size */
37
+ page_size: number;
38
+ /** Softmax scale */
39
+ sm_scale: number;
40
+ /** Batch size */
41
+ batch_size: number;
42
+ /** Required workspace size in bytes (currently 0) */
43
+ workspaceSize: number;
44
+ }
45
+ /**
46
+ * Plan batch decode with paged KV cache.
47
+ *
48
+ * @param options - Configuration options
49
+ * @returns Plan information for execution
50
+ *
51
+ * @example
52
+ * const plan = batch_decode_plan({
53
+ * batchSize: 4,
54
+ * pageSize: 16,
55
+ * numQoHeads: 32,
56
+ * numKvHeads: 8,
57
+ * headDim: 128
58
+ * });
59
+ */
60
+ export declare function batch_decode_plan(options: BatchDecodePlanOptions): BatchDecodePlanInfo;
61
+ /**
62
+ * Execute batch decode with paged KV cache.
63
+ *
64
+ * This function executes the attention computation using the plan
65
+ * prepared by batch_decode_plan().
66
+ *
67
+ * @param ctx - WebInfer context
68
+ * @param planInfo - Plan information from batch_decode_plan()
69
+ * @param q - Query tensor [batch_size, num_qo_heads, head_dim]
70
+ * @param pagedKvCache - Paged KV cache
71
+ * @param pageIndptr - Page indirection pointer [batch_size + 1]
72
+ * @param pageIndices - Page indices [nnz_pages]
73
+ * @param lastPageLen - Last page lengths [batch_size]
74
+ * @param output - Output tensor [batch_size, num_qo_heads, head_dim]
75
+ * @param lse - Log-sum-exp output [batch_size, num_qo_heads] (optional)
76
+ */
77
+ export declare function batch_decode_run(ctx: WebInferContext, planInfo: BatchDecodePlanInfo, q: Tensor, pagedKvCache: PagedKvCache, pageIndptr: Tensor, pageIndices: Tensor, lastPageLen: Tensor, output: Tensor, lse?: Tensor): Promise<void>;
78
+ /**
79
+ * Single decode with KV cache
80
+ *
81
+ * Performs attention for a single token (decode phase) with KV cache.
82
+ * Uses flash attention with online softmax to avoid materializing full attention matrix.
83
+ *
84
+ * @param ctx WebInfer context
85
+ * @param q Query tensor of shape [num_qo_heads, head_dim]
86
+ * @param k_cache Key cache tensor of shape [seq_len, num_kv_heads, head_dim]
87
+ * @param v_cache Value cache tensor of shape [seq_len, num_kv_heads, head_dim]
88
+ * @param output Optional pre-allocated output tensor [num_qo_heads, head_dim].
89
+ * If provided, results are written to this tensor (zero-copy for TVM integration).
90
+ * If not provided, a new tensor is created and returned.
91
+ * @param pos_encoding_mode Position encoding mode (default: NONE)
92
+ * @param sm_scale Softmax scale (default: 1/sqrt(head_dim))
93
+ * @param rope_scale RoPE scale (default: 1.0) - not yet implemented
94
+ * @param rope_theta RoPE theta (default: 10000.0) - not yet implemented
95
+ * @returns Output tensor of shape [num_qo_heads, head_dim]
96
+ */
97
+ export declare function single_decode_with_kv_cache(ctx: WebInferContext, q: Tensor, k_cache: Tensor, v_cache: Tensor, output?: Tensor, pos_encoding_mode?: PosEncodingMode, sm_scale?: number, rope_scale?: number, rope_theta?: number): Promise<Tensor>;
98
+ /**
99
+ * Batched decode with paged KV cache wrapper
100
+ *
101
+ * FlashInfer-compatible wrapper for batched decode attention with paged KV cache.
102
+ */
103
+ export declare class BatchDecodeWithPagedKVCacheWrapper {
104
+ private ctx;
105
+ private planned;
106
+ private num_qo_heads;
107
+ private num_kv_heads;
108
+ private head_dim;
109
+ private page_size;
110
+ private batch_size;
111
+ private sm_scale;
112
+ private pos_encoding_mode;
113
+ constructor(ctx: WebInferContext);
114
+ /**
115
+ * Plan the batched decode operation
116
+ *
117
+ * @param indptr Indirection pointer [batch_size + 1]
118
+ * @param indices Page indices [nnz_pages]
119
+ * @param last_page_len Last page lengths [batch_size]
120
+ * @param num_qo_heads Number of query heads
121
+ * @param num_kv_heads Number of KV heads
122
+ * @param head_dim Head dimension
123
+ * @param page_size Page size
124
+ * @param pos_encoding_mode Position encoding mode (default: NONE)
125
+ */
126
+ plan(indptr: Tensor, indices: Tensor, last_page_len: Tensor, num_qo_heads: number, num_kv_heads: number, head_dim: number, page_size: number, pos_encoding_mode?: PosEncodingMode): void;
127
+ /**
128
+ * Run the batched decode operation
129
+ *
130
+ * @param q Query tensor [batch_size, num_qo_heads, head_dim]
131
+ * @param paged_kv_cache Paged KV cache
132
+ * @param indptr Indirection pointer [batch_size + 1]
133
+ * @param indices Page indices [nnz_pages]
134
+ * @param last_page_len Last page lengths [batch_size]
135
+ * @param output Optional pre-allocated output tensor [batch_size, num_qo_heads, head_dim].
136
+ * If provided, results are written to this tensor (zero-copy for TVM integration).
137
+ * @returns Output tensor [batch_size, num_qo_heads, head_dim]
138
+ */
139
+ run(q: Tensor, paged_kv_cache: PagedKvCache, indptr: Tensor, indices: Tensor, last_page_len: Tensor, output?: Tensor): Promise<Tensor>;
140
+ }
@@ -0,0 +1,27 @@
1
+ /**
2
+ * GEMM (General Matrix Multiplication) operations
3
+ */
4
+ import type { WebInferContext } from '../core/context.ts';
5
+ import type { Tensor } from '../core/tensor.ts';
6
+ /**
7
+ * Batched matrix multiplication for fp16
8
+ *
9
+ * Computes C = A @ B for batched matrices
10
+ *
11
+ * @param ctx WebInfer context
12
+ * @param a Input tensor A [B, M, K]
13
+ * @param b Input tensor B [B, K, N]
14
+ * @param output Optional output tensor C [B, M, N] for in-place operation
15
+ * @returns Output tensor C [B, M, N]
16
+ */
17
+ export declare function bmm_fp16(ctx: WebInferContext, a: Tensor, b: Tensor, output?: Tensor): Promise<Tensor>;
18
+ /**
19
+ * Batched matrix multiplication for fp32
20
+ *
21
+ * @param ctx WebInfer context
22
+ * @param a Input tensor A [B, M, K]
23
+ * @param b Input tensor B [B, K, N]
24
+ * @param output Optional output tensor C [B, M, N] for in-place operation
25
+ * @returns Output tensor C [B, M, N]
26
+ */
27
+ export declare function bmm_fp32(ctx: WebInferContext, a: Tensor, b: Tensor, output?: Tensor): Promise<Tensor>;
package/dist/index.d.ts CHANGED
@@ -1,24 +1,32 @@
1
1
  /**
2
2
  * WebInfer - High-performance LLM inference kernels for WebGPU
3
- * "The cuDNN/FlashInfer of WebGPU"
3
+ *
4
+ * WebGPU implementation of FlashInfer APIs for browser-based LLM inference.
4
5
  */
5
- export { WebInferDevice, type DeviceInfo } from "./core/device.ts";
6
- export { Tensor, type DType } from "./core/tensor.ts";
7
- export { BufferPool } from "./core/buffer-pool.ts";
8
- export { TDRGuard, detectBrowser, getTDRConfig, classifyError, checkWebGPUSupport, createDeviceWithFallback, type BrowserInfo, type TDRConfig, type DegradationOptions, type WebGPUErrorType, } from "./core/tdr.ts";
9
- export { attention, BatchAttention, AttentionKernel, type AttentionOptions, type AttentionResultWithLse, type BatchAttentionConfig, type PrefillInput, type DecodeInput, type TensorLike, type AttentionKernelConfig, } from "./attention/index.ts";
10
- export { KernelCache, type CacheStats } from "./jit/kernel-cache.ts";
11
- export { WGSLCompiler, type MatMulConfig } from "./jit/compiler.ts";
12
- export { matmul, getMatMulCacheStats } from "./ops/matmul.ts";
13
- export { layerNorm, rmsNorm, fusedAddRmsNorm, gemmaRmsNorm, gemmaFusedAddRmsNorm, } from "./ops/normalization.ts";
14
- export { rope, computeRoPEFrequencies, type RoPEConfig } from "./ops/rope.ts";
15
- export { gelu, silu, relu } from "./ops/activations.ts";
16
- export { softmaxGPU } from "./ops/softmax.ts";
17
- export { add, mul, scale } from "./ops/elementwise.ts";
18
- export { embedding } from "./ops/embedding.ts";
19
- export { transpose2D } from "./ops/reshape.ts";
20
- export { quantizeToInt8, quantizeToInt4, dequantizeInt8, dequantizeInt4, quantizationError, getMemorySavings, estimateQMatMulFlops, estimateQMatMulBandwidth, type QuantConfig, type QuantizedTensor, } from "./quantization/index.ts";
21
- export { flashAttention, type AttentionConfig, buildBlockSparseCSR, getSparsityRatio, estimateMemorySavings, type BlockSparseCSR, type AttentionPattern, buildCausalMask, getCausalSparsity, buildSlidingWindowMask, buildCausalSlidingWindowMask, getSlidingWindowSparsity, PagedKVCache, type PagedKVCacheConfig, type SequenceEntry, type DefragmentResult, BlockManager, ContinuousBatchScheduler, type BlockManagerConfig, type AllocationPolicy, type AllocationRequest, pagedAttention, appendToPagedCache, type PagedAttentionConfig, type PagedAttentionInput, cascadedAttention, type CascadedAttentionConfig, } from "./attention/index.ts";
22
- export { topK, topKFilter, topPFilter, sample, sampleGreedy, sampleFromProbs, softmax, applyRepetitionPenalty, minPSamplingFromProbs, topKSamplingFromProbs, topPSamplingFromProbs, topKTopPSamplingFromProbs, topKTopPSamplingFromLogits, topPRenormProbs, topKRenormProbs, topKMaskLogits, type SamplingConfig, } from "./sampling/index.ts";
23
- export { type ModelFormat, type SafetensorsDType, GGUFQuantType, GGUFMetadataValueType, type TensorInfo, type SafetensorsHeader, type ModelMetadata, type LoadedModel, type LoadOptions, parseSafetensorsHeader, loadSafetensors, loadSafetensorsFromUrl, isSafetensors, parseGGUFHeader, loadGGUF, loadGGUFFromUrl, loadGGUFTensor, isGGUF, loadModel, } from "./model/index.ts";
24
- export { type ModelConfig, type InferenceConfig, type GenerationConfig, type GenerationResult, type StreamToken, type FinishReason, type ForwardResult, DEFAULT_GENERATION_CONFIG, normalizeGenerationConfig, InferenceEngine, generate, generateStream, greedyDecode, sampleNextToken, } from "./inference/index.ts";
6
+ export { WebInferContext } from './core/context.ts';
7
+ export { Tensor } from './core/tensor.ts';
8
+ export type { DType } from './core/types.ts';
9
+ export { PosEncodingMode } from './core/types.ts';
10
+ export type { PagedKvCache } from './core/paged-kv-cache.ts';
11
+ export { createPagedKvCache } from './core/paged-kv-cache.ts';
12
+ export { WebInferContext as create_context } from './core/context.ts';
13
+ import * as norm from './norm/index.ts';
14
+ import * as activation from './activation/index.ts';
15
+ import * as decode from './decode/index.ts';
16
+ import * as prefill from './prefill/index.ts';
17
+ import * as page from './page/index.ts';
18
+ import * as rope from './rope/index.ts';
19
+ import * as sampling from './sampling/index.ts';
20
+ import * as gemm from './gemm/index.ts';
21
+ import * as jit from './jit/index.ts';
22
+ export { norm, activation, decode, prefill, page, rope, sampling, gemm, jit };
23
+ export type { KernelSpec, CompiledKernel, BindingInfo, PlanInfo, CompiledKernelEntry, CompiledKernelRegistry, } from './jit/index.ts';
24
+ export { compileKernel, getSpecKey, initFromSpecs, getCompiledKernel, } from './jit/index.ts';
25
+ export declare const rmsnorm: typeof norm.rmsnorm, fused_add_rmsnorm: typeof norm.fused_add_rmsnorm, gemma_rmsnorm: typeof norm.gemma_rmsnorm;
26
+ export declare const silu_and_mul: typeof activation.silu_and_mul, gelu_and_mul: typeof activation.gelu_and_mul;
27
+ export declare const single_decode_with_kv_cache: typeof decode.single_decode_with_kv_cache, BatchDecodeWithPagedKVCacheWrapper: typeof decode.BatchDecodeWithPagedKVCacheWrapper;
28
+ export declare const single_prefill_with_kv_cache: typeof prefill.single_prefill_with_kv_cache, BatchPrefillWithPagedKVCacheWrapper: typeof prefill.BatchPrefillWithPagedKVCacheWrapper;
29
+ export declare const append_paged_kv_cache: typeof page.append_paged_kv_cache;
30
+ export declare const apply_rope_inplace: typeof rope.apply_rope_inplace, apply_llama31_rope_inplace: typeof rope.apply_llama31_rope_inplace;
31
+ export declare const sampling_from_probs: typeof sampling.sampling_from_probs, top_k_sampling_from_probs: typeof sampling.top_k_sampling_from_probs, top_p_sampling_from_probs: typeof sampling.top_p_sampling_from_probs, min_p_sampling_from_probs: typeof sampling.min_p_sampling_from_probs, top_k_top_p_sampling_from_probs: typeof sampling.top_k_top_p_sampling_from_probs;
32
+ export declare const bmm_fp16: typeof gemm.bmm_fp16, bmm_fp32: typeof gemm.bmm_fp32;