webinfer 0.0.2 → 0.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +38 -33
- package/dist/activation/index.d.ts +30 -0
- package/dist/core/context.d.ts +60 -0
- package/dist/core/paged-kv-cache.d.ts +33 -0
- package/dist/core/tensor.d.ts +38 -19
- package/dist/core/types.d.ts +27 -0
- package/dist/decode/index.d.ts +65 -0
- package/dist/gemm/index.d.ts +25 -0
- package/dist/index.d.ts +26 -19
- package/dist/index.js +2508 -3885
- package/dist/kernels/activation.wgsl.d.ts +14 -0
- package/dist/kernels/batch-decode-paged.wgsl.d.ts +12 -0
- package/dist/kernels/batch-prefill-paged.wgsl.d.ts +13 -0
- package/dist/kernels/decode-attention.wgsl.d.ts +16 -0
- package/dist/kernels/gemm.wgsl.d.ts +17 -0
- package/dist/kernels/page.wgsl.d.ts +10 -0
- package/dist/kernels/prefill-attention.wgsl.d.ts +17 -0
- package/dist/kernels/rmsnorm.wgsl.d.ts +10 -0
- package/dist/kernels/rope.wgsl.d.ts +19 -0
- package/dist/kernels/sampling.wgsl.d.ts +23 -0
- package/dist/norm/index.d.ts +43 -0
- package/dist/page/index.d.ts +21 -0
- package/dist/prefill/index.d.ts +69 -0
- package/dist/rope/index.d.ts +37 -0
- package/dist/sampling/index.d.ts +53 -4
- package/package.json +1 -1
- package/dist/attention/block-sparse/format.d.ts +0 -52
- package/dist/attention/block-sparse/patterns/causal.d.ts +0 -16
- package/dist/attention/block-sparse/patterns/sliding.d.ts +0 -22
- package/dist/attention/flash-attention.d.ts +0 -30
- package/dist/attention/index.d.ts +0 -9
- package/dist/attention/paged-kv/block-manager.d.ts +0 -102
- package/dist/attention/paged-kv/index.d.ts +0 -5
- package/dist/attention/paged-kv/page-table.d.ts +0 -99
- package/dist/attention/scheduler.d.ts +0 -40
- package/dist/core/buffer-pool.d.ts +0 -18
- package/dist/core/device.d.ts +0 -23
- package/dist/inference/engine.d.ts +0 -69
- package/dist/inference/generate.d.ts +0 -30
- package/dist/inference/index.d.ts +0 -7
- package/dist/inference/types.d.ts +0 -161
- package/dist/jit/compiler.d.ts +0 -23
- package/dist/jit/kernel-cache.d.ts +0 -21
- package/dist/model/gguf.d.ts +0 -90
- package/dist/model/index.d.ts +0 -16
- package/dist/model/safetensors.d.ts +0 -38
- package/dist/model/types.d.ts +0 -182
- package/dist/ops/activations.d.ts +0 -43
- package/dist/ops/elementwise.d.ts +0 -38
- package/dist/ops/embedding.d.ts +0 -30
- package/dist/ops/matmul.d.ts +0 -21
- package/dist/ops/normalization.d.ts +0 -24
- package/dist/ops/reshape.d.ts +0 -39
- package/dist/ops/rope.d.ts +0 -32
- package/dist/ops/softmax.d.ts +0 -18
- package/dist/quantization/index.d.ts +0 -6
- package/dist/quantization/qmatmul.d.ts +0 -38
- package/dist/quantization/quantize.d.ts +0 -52
- package/dist/sampling/sampler.d.ts +0 -39
- package/dist/sampling/top-k.d.ts +0 -24
- package/dist/sampling/top-p.d.ts +0 -14
package/README.md
CHANGED
|
@@ -2,53 +2,58 @@
|
|
|
2
2
|
|
|
3
3
|
High-performance LLM inference kernels for WebGPU.
|
|
4
4
|
|
|
5
|
+
WebGPU implementation of FlashInfer APIs for browser-based LLM inference.
|
|
6
|
+
|
|
7
|
+
## Features
|
|
8
|
+
|
|
9
|
+
- **Flash Attention for WebGPU**: Online softmax and paged KV cache for memory-efficient attention
|
|
10
|
+
- **Subgroup Optimization**: Uses WebGPU subgroup operations for 2-4x faster reductions when available
|
|
11
|
+
- **Fused Kernels**: Reduces memory bandwidth with operations like `silu_and_mul`, `fused_add_rmsnorm`
|
|
12
|
+
- **JIT Compilation**: Runtime kernel generation and pipeline caching
|
|
13
|
+
- **f16 Safe Accumulation**: Uses f32 accumulators for f16 computation to prevent precision loss
|
|
14
|
+
- **FlashInfer Compatible API**: Drop-in replacement for FlashInfer with identical naming
|
|
15
|
+
|
|
5
16
|
## Install
|
|
6
17
|
|
|
7
18
|
```bash
|
|
8
19
|
npm install webinfer
|
|
9
20
|
```
|
|
10
21
|
|
|
11
|
-
##
|
|
12
|
-
|
|
13
|
-
```typescript
|
|
14
|
-
import { WebInferDevice, Tensor, matmul, flashAttention } from 'webinfer';
|
|
22
|
+
## API Status
|
|
15
23
|
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
const b = Tensor.rand(device, [1024, 1024]);
|
|
22
|
-
const c = await matmul(device, a, b);
|
|
23
|
-
|
|
24
|
-
// Read result back to CPU
|
|
25
|
-
const result = await c.toArray();
|
|
26
|
-
```
|
|
24
|
+
**Phase 1: Core + Simple Ops**
|
|
25
|
+
- `webinfer.WebInferContext.create()` - Initialize WebGPU context
|
|
26
|
+
- `webinfer.norm.rmsnorm()` - Root mean square normalization
|
|
27
|
+
- `webinfer.activation.silu_and_mul()` - SiLU activation with gating
|
|
28
|
+
- `webinfer.activation.gelu_and_mul()` - GELU activation with gating
|
|
27
29
|
|
|
28
|
-
|
|
30
|
+
**Phase 2: Single Attention**
|
|
31
|
+
- `webinfer.decode.single_decode_with_kv_cache()` - Flash attention for decode phase
|
|
32
|
+
- `webinfer.prefill.single_prefill_with_kv_cache()` - Flash attention for prefill phase with causal masking
|
|
29
33
|
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
| **Activations** | gelu, silu, relu, softmax |
|
|
35
|
-
| **Position** | rope (rotary embeddings) |
|
|
36
|
-
| **Sampling** | topK, topP, sample |
|
|
37
|
-
| **Quantization** | INT4/INT8 quantized matmul |
|
|
38
|
-
| **Model Loading** | SafeTensors, GGUF |
|
|
34
|
+
**Phase 3: Paged KV + Batched Operations**
|
|
35
|
+
- `webinfer.page.append_paged_kv_cache()` - Append new KV pairs to paged cache
|
|
36
|
+
- `webinfer.decode.BatchDecodeWithPagedKVCacheWrapper` - Batched decode with paged KV cache
|
|
37
|
+
- `webinfer.prefill.BatchPrefillWithPagedKVCacheWrapper` - Batched prefill with paged KV cache
|
|
39
38
|
|
|
40
|
-
|
|
39
|
+
**Phase 4: Additional Operations**
|
|
40
|
+
- `webinfer.rope.apply_rope_inplace()` - Apply RoPE position encoding
|
|
41
|
+
- `webinfer.rope.apply_llama31_rope_inplace()` - Llama 3.1 RoPE variant for long context
|
|
42
|
+
- `webinfer.sampling.sampling_from_probs()` - Categorical sampling
|
|
43
|
+
- `webinfer.sampling.top_k_sampling_from_probs()` - Top-k sampling
|
|
44
|
+
- `webinfer.sampling.top_p_sampling_from_probs()` - Top-p (nucleus) sampling
|
|
45
|
+
- `webinfer.sampling.min_p_sampling_from_probs()` - Min-p sampling
|
|
46
|
+
- `webinfer.sampling.top_k_top_p_sampling_from_probs()` - Combined top-k and top-p
|
|
47
|
+
- `webinfer.gemm.bmm_fp16()` - Batched matrix multiplication (fp16)
|
|
48
|
+
- `webinfer.gemm.bmm_fp32()` - Batched matrix multiplication (fp32)
|
|
41
49
|
|
|
42
|
-
|
|
43
|
-
- Or Node.js with `@aspect-build/aspect-cli` for server-side WebGPU
|
|
50
|
+
See [dev/design.md](dev/design.md) for full design and implementation plan.
|
|
44
51
|
|
|
45
|
-
##
|
|
52
|
+
## Release
|
|
46
53
|
|
|
47
54
|
```bash
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
bun install
|
|
51
|
-
bun run bench
|
|
55
|
+
bun run build
|
|
56
|
+
npm publish
|
|
52
57
|
```
|
|
53
58
|
|
|
54
59
|
## License
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Activation functions
|
|
3
|
+
*/
|
|
4
|
+
import type { WebInferContext } from '../core/context.ts';
|
|
5
|
+
import type { Tensor } from '../core/tensor.ts';
|
|
6
|
+
/**
|
|
7
|
+
* SiLU (Swish) activation with gating
|
|
8
|
+
*
|
|
9
|
+
* Input tensor has shape [..., 2 * hidden_size]
|
|
10
|
+
* Splits into gate[..., :hidden_size] and up[..., hidden_size:]
|
|
11
|
+
* Computes: output = silu(gate) * up
|
|
12
|
+
* Where silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
|
|
13
|
+
*
|
|
14
|
+
* @param ctx WebInfer context
|
|
15
|
+
* @param input Input tensor of shape [..., 2 * hidden_size]
|
|
16
|
+
* @returns Activated tensor of shape [..., hidden_size]
|
|
17
|
+
*/
|
|
18
|
+
export declare function silu_and_mul(ctx: WebInferContext, input: Tensor): Promise<Tensor>;
|
|
19
|
+
/**
|
|
20
|
+
* GELU activation with gating
|
|
21
|
+
*
|
|
22
|
+
* Input tensor has shape [..., 2 * hidden_size]
|
|
23
|
+
* Computes: output = gelu(gate) * up
|
|
24
|
+
* Where gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
|
|
25
|
+
*
|
|
26
|
+
* @param ctx WebInfer context
|
|
27
|
+
* @param input Input tensor of shape [..., 2 * hidden_size]
|
|
28
|
+
* @returns Activated tensor of shape [..., hidden_size]
|
|
29
|
+
*/
|
|
30
|
+
export declare function gelu_and_mul(ctx: WebInferContext, input: Tensor): Promise<Tensor>;
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* WebInferContext manages WebGPU device and provides core functionality
|
|
3
|
+
*/
|
|
4
|
+
import type { DType } from './types.ts';
|
|
5
|
+
import { Tensor } from './tensor.ts';
|
|
6
|
+
/**
|
|
7
|
+
* Device vendor detection
|
|
8
|
+
*/
|
|
9
|
+
export type Vendor = 'apple' | 'nvidia' | 'amd' | 'intel' | 'unknown';
|
|
10
|
+
/**
|
|
11
|
+
* Device features detected at runtime
|
|
12
|
+
*/
|
|
13
|
+
export interface DeviceFeatures {
|
|
14
|
+
readonly vendor: Vendor;
|
|
15
|
+
readonly hasSubgroups: boolean;
|
|
16
|
+
readonly hasF16: boolean;
|
|
17
|
+
readonly maxWorkgroupSize: number;
|
|
18
|
+
readonly maxStorageBufferSize: number;
|
|
19
|
+
readonly maxBufferSize: number;
|
|
20
|
+
readonly maxWorkgroupsPerDimension: number;
|
|
21
|
+
}
|
|
22
|
+
/**
|
|
23
|
+
* WebInferContext - main context for all WebInfer operations
|
|
24
|
+
*/
|
|
25
|
+
export declare class WebInferContext {
|
|
26
|
+
readonly device: GPUDevice;
|
|
27
|
+
readonly features: DeviceFeatures;
|
|
28
|
+
private pipelineCache;
|
|
29
|
+
private disposed;
|
|
30
|
+
private constructor();
|
|
31
|
+
/**
|
|
32
|
+
* Create a new WebInferContext
|
|
33
|
+
*/
|
|
34
|
+
static create(): Promise<WebInferContext>;
|
|
35
|
+
/**
|
|
36
|
+
* Create a GPU buffer
|
|
37
|
+
*/
|
|
38
|
+
createBuffer(size: number, usage?: GPUBufferUsageFlags): GPUBuffer;
|
|
39
|
+
/**
|
|
40
|
+
* Create a tensor with uninitialized data
|
|
41
|
+
*/
|
|
42
|
+
createTensor(shape: number[], dtype: DType): Tensor;
|
|
43
|
+
/**
|
|
44
|
+
* Create a tensor from TypedArray data
|
|
45
|
+
*/
|
|
46
|
+
createTensorFromData(data: Float32Array | Uint16Array, shape: number[], dtype: DType): Tensor;
|
|
47
|
+
/**
|
|
48
|
+
* Get or create a cached compute pipeline
|
|
49
|
+
*/
|
|
50
|
+
getOrCreatePipeline(key: string, createFn: () => Promise<GPUComputePipeline> | GPUComputePipeline): Promise<GPUComputePipeline>;
|
|
51
|
+
/**
|
|
52
|
+
* Clear pipeline cache
|
|
53
|
+
*/
|
|
54
|
+
clearPipelineCache(): void;
|
|
55
|
+
/**
|
|
56
|
+
* Dispose the context and cleanup resources
|
|
57
|
+
*/
|
|
58
|
+
dispose(): void;
|
|
59
|
+
private checkDisposed;
|
|
60
|
+
}
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Paged KV Cache implementation
|
|
3
|
+
*/
|
|
4
|
+
import type { Tensor } from './tensor.ts';
|
|
5
|
+
/**
|
|
6
|
+
* Paged KV Cache for efficient memory management
|
|
7
|
+
*
|
|
8
|
+
* Layout: [num_pages, 2, page_size, num_kv_heads, head_dim]
|
|
9
|
+
* where dim 1 has [0] = keys, [1] = values
|
|
10
|
+
*/
|
|
11
|
+
export interface PagedKvCache {
|
|
12
|
+
/**
|
|
13
|
+
* The backing tensor for the paged KV cache
|
|
14
|
+
* Shape: [num_pages, 2, page_size, num_kv_heads, head_dim]
|
|
15
|
+
*/
|
|
16
|
+
data: Tensor;
|
|
17
|
+
/**
|
|
18
|
+
* Number of tokens per page
|
|
19
|
+
*/
|
|
20
|
+
page_size: number;
|
|
21
|
+
/**
|
|
22
|
+
* Number of KV heads (for GQA/MQA)
|
|
23
|
+
*/
|
|
24
|
+
num_kv_heads: number;
|
|
25
|
+
/**
|
|
26
|
+
* Dimension of each head
|
|
27
|
+
*/
|
|
28
|
+
head_dim: number;
|
|
29
|
+
}
|
|
30
|
+
/**
|
|
31
|
+
* Create a paged KV cache
|
|
32
|
+
*/
|
|
33
|
+
export declare function createPagedKvCache(data: Tensor, page_size: number, num_kv_heads: number, head_dim: number): PagedKvCache;
|
package/dist/core/tensor.d.ts
CHANGED
|
@@ -1,25 +1,44 @@
|
|
|
1
1
|
/**
|
|
2
|
-
* Tensor
|
|
2
|
+
* Tensor wrapper for WebGPU buffers
|
|
3
|
+
*/
|
|
4
|
+
import type { DType } from './types.ts';
|
|
5
|
+
/**
|
|
6
|
+
* Compute strides from shape (row-major order)
|
|
7
|
+
*/
|
|
8
|
+
export declare function computeStrides(shape: number[]): number[];
|
|
9
|
+
/**
|
|
10
|
+
* Compute total number of elements from shape
|
|
11
|
+
*/
|
|
12
|
+
export declare function computeSize(shape: number[]): number;
|
|
13
|
+
/**
|
|
14
|
+
* Tensor class wrapping a GPUBuffer with shape and dtype information
|
|
3
15
|
*/
|
|
4
|
-
import type { WebInferDevice } from "./device.ts";
|
|
5
|
-
export type DType = "f32" | "f16" | "i32" | "u32";
|
|
6
16
|
export declare class Tensor {
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
get
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
get
|
|
22
|
-
|
|
17
|
+
readonly buffer: GPUBuffer;
|
|
18
|
+
readonly shape: readonly number[];
|
|
19
|
+
readonly dtype: DType;
|
|
20
|
+
readonly strides: readonly number[];
|
|
21
|
+
readonly size: number;
|
|
22
|
+
readonly byteSize: number;
|
|
23
|
+
constructor(buffer: GPUBuffer, shape: number[], dtype: DType, strides?: number[]);
|
|
24
|
+
/**
|
|
25
|
+
* Get the number of dimensions
|
|
26
|
+
*/
|
|
27
|
+
get ndim(): number;
|
|
28
|
+
/**
|
|
29
|
+
* Check if tensor is contiguous (row-major)
|
|
30
|
+
*/
|
|
31
|
+
get isContiguous(): boolean;
|
|
32
|
+
/**
|
|
33
|
+
* Get a view of the tensor with new shape (must have same size)
|
|
34
|
+
*/
|
|
23
35
|
reshape(newShape: number[]): Tensor;
|
|
36
|
+
/**
|
|
37
|
+
* Destroy the underlying GPU buffer
|
|
38
|
+
*/
|
|
24
39
|
dispose(): void;
|
|
40
|
+
/**
|
|
41
|
+
* Get string representation
|
|
42
|
+
*/
|
|
43
|
+
toString(): string;
|
|
25
44
|
}
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Core type definitions for WebInfer
|
|
3
|
+
*/
|
|
4
|
+
/**
|
|
5
|
+
* Data type for tensors
|
|
6
|
+
*/
|
|
7
|
+
export type DType = 'float32' | 'float16';
|
|
8
|
+
/**
|
|
9
|
+
* Position encoding modes for attention
|
|
10
|
+
*/
|
|
11
|
+
export declare enum PosEncodingMode {
|
|
12
|
+
NONE = 0,
|
|
13
|
+
ROPE_LLAMA = 1,
|
|
14
|
+
ALIBI = 2
|
|
15
|
+
}
|
|
16
|
+
/**
|
|
17
|
+
* Get the byte size for a given dtype
|
|
18
|
+
*/
|
|
19
|
+
export declare function dtypeByteSize(dtype: DType): number;
|
|
20
|
+
/**
|
|
21
|
+
* Get the GPUTextureFormat for a given dtype
|
|
22
|
+
*/
|
|
23
|
+
export declare function dtypeToGPUFormat(dtype: DType): string;
|
|
24
|
+
/**
|
|
25
|
+
* Get WGSL type string for a given dtype
|
|
26
|
+
*/
|
|
27
|
+
export declare function dtypeToWGSL(dtype: DType): 'f32' | 'f16';
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Decode attention operations
|
|
3
|
+
*/
|
|
4
|
+
import type { WebInferContext } from '../core/context.ts';
|
|
5
|
+
import type { Tensor } from '../core/tensor.ts';
|
|
6
|
+
import type { PagedKvCache } from '../core/paged-kv-cache.ts';
|
|
7
|
+
import { PosEncodingMode } from '../core/types.ts';
|
|
8
|
+
/**
|
|
9
|
+
* Single decode with KV cache
|
|
10
|
+
*
|
|
11
|
+
* Performs attention for a single token (decode phase) with KV cache.
|
|
12
|
+
* Uses flash attention with online softmax to avoid materializing full attention matrix.
|
|
13
|
+
*
|
|
14
|
+
* @param ctx WebInfer context
|
|
15
|
+
* @param q Query tensor of shape [num_qo_heads, head_dim]
|
|
16
|
+
* @param k_cache Key cache tensor of shape [seq_len, num_kv_heads, head_dim]
|
|
17
|
+
* @param v_cache Value cache tensor of shape [seq_len, num_kv_heads, head_dim]
|
|
18
|
+
* @param pos_encoding_mode Position encoding mode (default: NONE)
|
|
19
|
+
* @param sm_scale Softmax scale (default: 1/sqrt(head_dim))
|
|
20
|
+
* @param rope_scale RoPE scale (default: 1.0) - not yet implemented
|
|
21
|
+
* @param rope_theta RoPE theta (default: 10000.0) - not yet implemented
|
|
22
|
+
* @returns Output tensor of shape [num_qo_heads, head_dim]
|
|
23
|
+
*/
|
|
24
|
+
export declare function single_decode_with_kv_cache(ctx: WebInferContext, q: Tensor, k_cache: Tensor, v_cache: Tensor, pos_encoding_mode?: PosEncodingMode, sm_scale?: number, rope_scale?: number, rope_theta?: number): Promise<Tensor>;
|
|
25
|
+
/**
|
|
26
|
+
* Batched decode with paged KV cache wrapper
|
|
27
|
+
*
|
|
28
|
+
* FlashInfer-compatible wrapper for batched decode attention with paged KV cache.
|
|
29
|
+
*/
|
|
30
|
+
export declare class BatchDecodeWithPagedKVCacheWrapper {
|
|
31
|
+
private ctx;
|
|
32
|
+
private planned;
|
|
33
|
+
private num_qo_heads;
|
|
34
|
+
private num_kv_heads;
|
|
35
|
+
private head_dim;
|
|
36
|
+
private page_size;
|
|
37
|
+
private batch_size;
|
|
38
|
+
private sm_scale;
|
|
39
|
+
private pos_encoding_mode;
|
|
40
|
+
constructor(ctx: WebInferContext);
|
|
41
|
+
/**
|
|
42
|
+
* Plan the batched decode operation
|
|
43
|
+
*
|
|
44
|
+
* @param indptr Indirection pointer [batch_size + 1]
|
|
45
|
+
* @param indices Page indices [nnz_pages]
|
|
46
|
+
* @param last_page_len Last page lengths [batch_size]
|
|
47
|
+
* @param num_qo_heads Number of query heads
|
|
48
|
+
* @param num_kv_heads Number of KV heads
|
|
49
|
+
* @param head_dim Head dimension
|
|
50
|
+
* @param page_size Page size
|
|
51
|
+
* @param pos_encoding_mode Position encoding mode (default: NONE)
|
|
52
|
+
*/
|
|
53
|
+
plan(indptr: Tensor, indices: Tensor, last_page_len: Tensor, num_qo_heads: number, num_kv_heads: number, head_dim: number, page_size: number, pos_encoding_mode?: PosEncodingMode): void;
|
|
54
|
+
/**
|
|
55
|
+
* Run the batched decode operation
|
|
56
|
+
*
|
|
57
|
+
* @param q Query tensor [batch_size, num_qo_heads, head_dim]
|
|
58
|
+
* @param paged_kv_cache Paged KV cache
|
|
59
|
+
* @param indptr Indirection pointer [batch_size + 1]
|
|
60
|
+
* @param indices Page indices [nnz_pages]
|
|
61
|
+
* @param last_page_len Last page lengths [batch_size]
|
|
62
|
+
* @returns Output tensor [batch_size, num_qo_heads, head_dim]
|
|
63
|
+
*/
|
|
64
|
+
run(q: Tensor, paged_kv_cache: PagedKvCache, indptr: Tensor, indices: Tensor, last_page_len: Tensor): Promise<Tensor>;
|
|
65
|
+
}
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* GEMM (General Matrix Multiplication) operations
|
|
3
|
+
*/
|
|
4
|
+
import type { WebInferContext } from '../core/context.ts';
|
|
5
|
+
import type { Tensor } from '../core/tensor.ts';
|
|
6
|
+
/**
|
|
7
|
+
* Batched matrix multiplication for fp16
|
|
8
|
+
*
|
|
9
|
+
* Computes C = A @ B for batched matrices
|
|
10
|
+
*
|
|
11
|
+
* @param ctx WebInfer context
|
|
12
|
+
* @param a Input tensor A [B, M, K]
|
|
13
|
+
* @param b Input tensor B [B, K, N]
|
|
14
|
+
* @returns Output tensor C [B, M, N]
|
|
15
|
+
*/
|
|
16
|
+
export declare function bmm_fp16(ctx: WebInferContext, a: Tensor, b: Tensor): Promise<Tensor>;
|
|
17
|
+
/**
|
|
18
|
+
* Batched matrix multiplication for fp32
|
|
19
|
+
*
|
|
20
|
+
* @param ctx WebInfer context
|
|
21
|
+
* @param a Input tensor A [B, M, K]
|
|
22
|
+
* @param b Input tensor B [B, K, N]
|
|
23
|
+
* @returns Output tensor C [B, M, N]
|
|
24
|
+
*/
|
|
25
|
+
export declare function bmm_fp32(ctx: WebInferContext, a: Tensor, b: Tensor): Promise<Tensor>;
|
package/dist/index.d.ts
CHANGED
|
@@ -1,22 +1,29 @@
|
|
|
1
1
|
/**
|
|
2
2
|
* WebInfer - High-performance LLM inference kernels for WebGPU
|
|
3
|
-
*
|
|
3
|
+
*
|
|
4
|
+
* WebGPU implementation of FlashInfer APIs for browser-based LLM inference.
|
|
4
5
|
*/
|
|
5
|
-
export {
|
|
6
|
-
export { Tensor
|
|
7
|
-
export {
|
|
8
|
-
export {
|
|
9
|
-
export {
|
|
10
|
-
export {
|
|
11
|
-
export {
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
export {
|
|
21
|
-
export
|
|
22
|
-
export
|
|
6
|
+
export { WebInferContext } from './core/context.ts';
|
|
7
|
+
export { Tensor } from './core/tensor.ts';
|
|
8
|
+
export type { DType } from './core/types.ts';
|
|
9
|
+
export { PosEncodingMode } from './core/types.ts';
|
|
10
|
+
export type { PagedKvCache } from './core/paged-kv-cache.ts';
|
|
11
|
+
export { createPagedKvCache } from './core/paged-kv-cache.ts';
|
|
12
|
+
export { WebInferContext as create_context } from './core/context.ts';
|
|
13
|
+
import * as norm from './norm/index.ts';
|
|
14
|
+
import * as activation from './activation/index.ts';
|
|
15
|
+
import * as decode from './decode/index.ts';
|
|
16
|
+
import * as prefill from './prefill/index.ts';
|
|
17
|
+
import * as page from './page/index.ts';
|
|
18
|
+
import * as rope from './rope/index.ts';
|
|
19
|
+
import * as sampling from './sampling/index.ts';
|
|
20
|
+
import * as gemm from './gemm/index.ts';
|
|
21
|
+
export { norm, activation, decode, prefill, page, rope, sampling, gemm };
|
|
22
|
+
export declare const rmsnorm: typeof norm.rmsnorm, fused_add_rmsnorm: typeof norm.fused_add_rmsnorm, gemma_rmsnorm: typeof norm.gemma_rmsnorm;
|
|
23
|
+
export declare const silu_and_mul: typeof activation.silu_and_mul, gelu_and_mul: typeof activation.gelu_and_mul;
|
|
24
|
+
export declare const single_decode_with_kv_cache: typeof decode.single_decode_with_kv_cache, BatchDecodeWithPagedKVCacheWrapper: typeof decode.BatchDecodeWithPagedKVCacheWrapper;
|
|
25
|
+
export declare const single_prefill_with_kv_cache: typeof prefill.single_prefill_with_kv_cache, BatchPrefillWithPagedKVCacheWrapper: typeof prefill.BatchPrefillWithPagedKVCacheWrapper;
|
|
26
|
+
export declare const append_paged_kv_cache: typeof page.append_paged_kv_cache;
|
|
27
|
+
export declare const apply_rope_inplace: typeof rope.apply_rope_inplace, apply_llama31_rope_inplace: typeof rope.apply_llama31_rope_inplace;
|
|
28
|
+
export declare const sampling_from_probs: typeof sampling.sampling_from_probs, top_k_sampling_from_probs: typeof sampling.top_k_sampling_from_probs, top_p_sampling_from_probs: typeof sampling.top_p_sampling_from_probs, min_p_sampling_from_probs: typeof sampling.min_p_sampling_from_probs, top_k_top_p_sampling_from_probs: typeof sampling.top_k_top_p_sampling_from_probs;
|
|
29
|
+
export declare const bmm_fp16: typeof gemm.bmm_fp16, bmm_fp32: typeof gemm.bmm_fp32;
|