webinfer 0.0.4 → 0.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,138 @@
1
+ /**
2
+ * JIT compilation module for WebInfer
3
+ *
4
+ * This module provides functionality to compile kernel specifications into
5
+ * executable WGSL shaders at runtime. It follows a similar pattern to
6
+ * FlashInfer's JIT compilation but targets WebGPU instead of CUDA.
7
+ */
8
+ /**
9
+ * Kernel specification for JIT compilation.
10
+ * This is the format used by TVM to pass kernel configurations to WebInfer.
11
+ */
12
+ export interface KernelSpec {
13
+ /** Type of kernel to compile */
14
+ kernel_type: 'batch_prefill_paged' | 'batch_decode_paged' | 'single_prefill' | 'single_decode' | 'rmsnorm' | 'silu_and_mul' | 'gelu_and_mul' | 'rope' | 'sampling';
15
+ /** Data type for computation */
16
+ dtype: 'float16' | 'float32';
17
+ /** Number of query/output heads */
18
+ num_qo_heads?: number;
19
+ /** Number of key/value heads (for GQA/MQA) */
20
+ num_kv_heads?: number;
21
+ /** Head dimension for query and key */
22
+ qk_head_dim?: number;
23
+ /** Head dimension for value (may differ from qk_head_dim) */
24
+ v_head_dim?: number;
25
+ /** Page size for paged KV cache */
26
+ page_size?: number;
27
+ /** Whether to use causal masking */
28
+ causal?: boolean;
29
+ /** Whether to enable inline RoPE */
30
+ enable_inline_rope?: boolean;
31
+ /** Hidden dimension (for normalization kernels) */
32
+ hidden_dim?: number;
33
+ /** Epsilon for numerical stability (for normalization) */
34
+ eps?: number;
35
+ /** RoPE theta base */
36
+ rope_theta?: number;
37
+ /** RoPE scaling factor */
38
+ rope_scale?: number;
39
+ }
40
+ /**
41
+ * Information about a binding in the compiled shader
42
+ */
43
+ export interface BindingInfo {
44
+ /** Binding index */
45
+ binding: number;
46
+ /** Name of the binding (for documentation) */
47
+ name: string;
48
+ /** Type of binding */
49
+ type: 'storage' | 'storage_read' | 'uniform';
50
+ /** Data type */
51
+ dtype: string;
52
+ }
53
+ /**
54
+ * Result of JIT compilation
55
+ */
56
+ export interface CompiledKernel {
57
+ /** Generated WGSL shader code */
58
+ wgsl: string;
59
+ /** Workgroup size [x, y, z] */
60
+ workgroupSize: [number, number, number];
61
+ /** Binding layout information */
62
+ bindings: BindingInfo[];
63
+ /** Entry point function name */
64
+ entryPoint: string;
65
+ /**
66
+ * Calculate dispatch size based on input parameters
67
+ * @param params - Named runtime sizes (e.g. batch_size, seq_len), keyed by parameter name
68
+ * @returns Dispatch size [x, y, z]
69
+ */
70
+ dispatchSize: (params: Record<string, number>) => [number, number, number];
71
+ /** Original kernel specification */
72
+ spec: KernelSpec;
73
+ }
74
+ /**
75
+ * Plan information returned by plan phase
76
+ */
77
+ export interface PlanInfo {
78
+ /** Unique key for this configuration */
79
+ key: string;
80
+ /** Required workspace size in bytes */
81
+ workspaceSize: number;
82
+ /** Compiled kernel reference */
83
+ kernel: CompiledKernel;
84
+ /** Additional configuration for execution */
85
+ config: Record<string, number | boolean>;
86
+ }
87
+ /**
88
+ * Compile a kernel from specification.
89
+ *
90
+ * This is the main entry point for JIT compilation. It takes a kernel
91
+ * specification and returns a compiled kernel that can be executed.
92
+ *
93
+ * @param spec - Kernel specification
94
+ * @returns Compiled kernel
95
+ */
96
+ export declare function compileKernel(spec: KernelSpec): CompiledKernel;
97
+ /**
98
+ * Generate a unique key for a kernel specification.
99
+ * Used for caching compiled pipelines.
100
+ */
101
+ export declare function getSpecKey(spec: KernelSpec): string;
102
+ /**
103
+ * Registry entry for a compiled kernel
104
+ */
105
+ export interface CompiledKernelEntry {
106
+ /** The compiled kernel information */
107
+ kernel: CompiledKernel;
108
+ /** The GPU compute pipeline */
109
+ pipeline: GPUComputePipeline;
110
+ /** The bind group layout for this kernel */
111
+ bindGroupLayout: GPUBindGroupLayout;
112
+ }
113
+ /**
114
+ * Registry of compiled kernels, keyed by spec key
115
+ */
116
+ export interface CompiledKernelRegistry {
117
+ [specKey: string]: CompiledKernelEntry;
118
+ }
119
+ /**
120
+ * Initialize multiple kernels from specifications.
121
+ *
122
+ * This function takes a list of kernel specs and compiles them all,
123
+ * returning a registry that can be used to quickly lookup pipelines
124
+ * at execution time.
125
+ *
126
+ * @param device - The GPU device to create pipelines on
127
+ * @param specs - Array of kernel specifications to compile
128
+ * @returns Promise resolving to a registry of compiled kernels
129
+ */
130
+ export declare function initFromSpecs(device: GPUDevice, specs: KernelSpec[]): Promise<CompiledKernelRegistry>;
131
+ /**
132
+ * Get a compiled kernel entry from the registry.
133
+ *
134
+ * @param registry - The kernel registry
135
+ * @param spec - The kernel specification to look up
136
+ * @returns The compiled kernel entry, or undefined if not found
137
+ */
138
+ export declare function getCompiledKernel(registry: CompiledKernelRegistry, spec: KernelSpec): CompiledKernelEntry | undefined;
@@ -5,6 +5,87 @@ import type { WebInferContext } from '../core/context.ts';
5
5
  import type { Tensor } from '../core/tensor.ts';
6
6
  import type { PagedKvCache } from '../core/paged-kv-cache.ts';
7
7
  import { PosEncodingMode } from '../core/types.ts';
8
+ /**
9
+ * Options for batch_prefill_plan()
10
+ */
11
+ export interface BatchPrefillPlanOptions {
12
+ /** Number of sequences in the batch */
13
+ batchSize: number;
14
+ /** Total number of query tokens across all sequences */
15
+ totalQoLen: number;
16
+ /** Size of each page in the paged KV cache */
17
+ pageSize: number;
18
+ /** Number of query/output heads */
19
+ numQoHeads: number;
20
+ /** Number of key/value heads */
21
+ numKvHeads: number;
22
+ /** Head dimension for query and key */
23
+ headDim: number;
24
+ /** Whether to apply causal masking (default: true) */
25
+ causal?: boolean;
26
+ }
27
+ /**
28
+ * Plan information for batch prefill with paged KV cache.
29
+ * This is returned by batch_prefill_plan() and passed to batch_prefill_run().
30
+ */
31
+ export interface BatchPrefillPlanInfo {
32
+ /** Unique key for this configuration */
33
+ key: string;
34
+ /** Number of query/output heads */
35
+ num_qo_heads: number;
36
+ /** Number of key/value heads */
37
+ num_kv_heads: number;
38
+ /** Head dimension */
39
+ head_dim: number;
40
+ /** Page size */
41
+ page_size: number;
42
+ /** Softmax scale */
43
+ sm_scale: number;
44
+ /** Whether causal masking is enabled */
45
+ causal: boolean;
46
+ /** Batch size */
47
+ batch_size: number;
48
+ /** Total query/output length */
49
+ total_qo_len: number;
50
+ /** Required workspace size in bytes (currently 0) */
51
+ workspaceSize: number;
52
+ }
53
+ /**
54
+ * Plan batch prefill with paged KV cache.
55
+ *
56
+ * @param options - Configuration options
57
+ * @returns Plan information for execution
58
+ *
59
+ * @example
60
+ * const plan = batch_prefill_plan({
61
+ * batchSize: 4,
62
+ * totalQoLen: 2048,
63
+ * pageSize: 16,
64
+ * numQoHeads: 32,
65
+ * numKvHeads: 8,
66
+ * headDim: 128,
67
+ * causal: true
68
+ * });
69
+ */
70
+ export declare function batch_prefill_plan(options: BatchPrefillPlanOptions): BatchPrefillPlanInfo;
71
+ /**
72
+ * Execute batch prefill with paged KV cache.
73
+ *
74
+ * This function executes the attention computation using the plan
75
+ * prepared by batch_prefill_plan().
76
+ *
77
+ * @param ctx - WebInfer context
78
+ * @param planInfo - Plan information from batch_prefill_plan()
79
+ * @param q - Query tensor [total_qo_len, num_qo_heads, head_dim]
80
+ * @param pagedKvCache - Paged KV cache
81
+ * @param qoIndptr - Query/output index pointer (CSR-style per-sequence offsets into q) [batch_size + 1]
82
+ * @param pageIndptr - Page index pointer (CSR-style per-sequence offsets into pageIndices) [batch_size + 1]
83
+ * @param pageIndices - Page indices [nnz_pages]
84
+ * @param lastPageLen - Last page lengths [batch_size]
85
+ * @param output - Output tensor [total_qo_len, num_qo_heads, head_dim]
86
+ * @param lse - Log-sum-exp output [total_qo_len, num_qo_heads] (optional)
87
+ */
88
+ export declare function batch_prefill_run(ctx: WebInferContext, planInfo: BatchPrefillPlanInfo, q: Tensor, pagedKvCache: PagedKvCache, qoIndptr: Tensor, pageIndptr: Tensor, pageIndices: Tensor, lastPageLen: Tensor, output: Tensor, lse?: Tensor): Promise<void>;
8
89
  /**
9
90
  * Single prefill with KV cache
10
91
  *
@@ -15,6 +96,9 @@ import { PosEncodingMode } from '../core/types.ts';
15
96
  * @param q Query tensor of shape [qo_len, num_qo_heads, head_dim]
16
97
  * @param k Key tensor of shape [kv_len, num_kv_heads, head_dim]
17
98
  * @param v Value tensor of shape [kv_len, num_kv_heads, head_dim]
99
+ * @param output Optional pre-allocated output tensor [qo_len, num_qo_heads, head_dim].
100
+ * If provided, results are written to this tensor (zero-copy for TVM integration).
101
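+ * If not provided, a new tensor is created and returned. Because this parameter precedes `causal`, positional callers written against 0.0.4 must pass `undefined` here to keep their remaining arguments aligned.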
18
102
  * @param causal Whether to apply causal masking (default: true)
19
103
  * @param pos_encoding_mode Position encoding mode (default: NONE)
20
104
  * @param sm_scale Softmax scale (default: 1/sqrt(head_dim))
@@ -22,7 +106,7 @@ import { PosEncodingMode } from '../core/types.ts';
22
106
  * @param rope_theta RoPE theta (default: 10000.0) - not yet implemented
23
107
  * @returns Output tensor of shape [qo_len, num_qo_heads, head_dim]
24
108
  */
25
- export declare function single_prefill_with_kv_cache(ctx: WebInferContext, q: Tensor, k: Tensor, v: Tensor, causal?: boolean, pos_encoding_mode?: PosEncodingMode, sm_scale?: number, rope_scale?: number, rope_theta?: number): Promise<Tensor>;
109
+ export declare function single_prefill_with_kv_cache(ctx: WebInferContext, q: Tensor, k: Tensor, v: Tensor, output?: Tensor, causal?: boolean, pos_encoding_mode?: PosEncodingMode, sm_scale?: number, rope_scale?: number, rope_theta?: number): Promise<Tensor>;
26
110
  /**
27
111
  * Batched prefill with paged KV cache wrapper
28
112
  *
@@ -63,7 +147,9 @@ export declare class BatchPrefillWithPagedKVCacheWrapper {
63
147
  * @param paged_kv_indptr Paged KV index pointer (CSR-style per-sequence offsets into paged_kv_indices) [batch_size + 1]
64
148
  * @param paged_kv_indices Paged KV indices [nnz_pages]
65
149
  * @param paged_kv_last_page_len Last page lengths [batch_size]
150
+ * @param output Optional pre-allocated output tensor [total_qo_len, num_qo_heads, head_dim].
151
+ * If provided, results are written to this tensor (zero-copy for TVM integration).
66
152
  * @returns Output tensor [total_qo_len, num_qo_heads, head_dim]
67
153
  */
68
- run(q: Tensor, paged_kv_cache: PagedKvCache, qo_indptr: Tensor, paged_kv_indptr: Tensor, paged_kv_indices: Tensor, paged_kv_last_page_len: Tensor): Promise<Tensor>;
154
+ run(q: Tensor, paged_kv_cache: PagedKvCache, qo_indptr: Tensor, paged_kv_indptr: Tensor, paged_kv_indices: Tensor, paged_kv_last_page_len: Tensor, output?: Tensor): Promise<Tensor>;
69
155
  }
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "webinfer",
3
- "version": "0.0.4",
3
+ "version": "0.0.5",
4
4
  "description": "High-performance LLM inference kernels for WebGPU",
5
5
  "license": "Apache-2.0",
6
6
  "repository": {