webinfer 0.0.3 → 0.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +40 -25
- package/dist/activation/index.d.ts +30 -0
- package/dist/core/context.d.ts +60 -0
- package/dist/core/paged-kv-cache.d.ts +33 -0
- package/dist/core/tensor.d.ts +38 -19
- package/dist/core/types.d.ts +27 -0
- package/dist/decode/index.d.ts +65 -0
- package/dist/gemm/index.d.ts +25 -0
- package/dist/index.d.ts +26 -21
- package/dist/index.js +2439 -4872
- package/dist/kernels/activation.wgsl.d.ts +14 -0
- package/dist/kernels/batch-decode-paged.wgsl.d.ts +12 -0
- package/dist/kernels/batch-prefill-paged.wgsl.d.ts +13 -0
- package/dist/kernels/decode-attention.wgsl.d.ts +16 -0
- package/dist/kernels/gemm.wgsl.d.ts +17 -0
- package/dist/kernels/page.wgsl.d.ts +10 -0
- package/dist/kernels/prefill-attention.wgsl.d.ts +17 -0
- package/dist/kernels/rmsnorm.wgsl.d.ts +10 -0
- package/dist/kernels/rope.wgsl.d.ts +19 -0
- package/dist/kernels/sampling.wgsl.d.ts +23 -0
- package/dist/norm/index.d.ts +43 -0
- package/dist/page/index.d.ts +21 -0
- package/dist/prefill/index.d.ts +69 -0
- package/dist/rope/index.d.ts +37 -0
- package/dist/sampling/index.d.ts +53 -4
- package/package.json +1 -1
- package/dist/attention/block-sparse/format.d.ts +0 -52
- package/dist/attention/block-sparse/patterns/causal.d.ts +0 -16
- package/dist/attention/block-sparse/patterns/sliding.d.ts +0 -22
- package/dist/attention/block-sparse/patterns/tree.d.ts +0 -65
- package/dist/attention/cascaded-inference.d.ts +0 -29
- package/dist/attention/flash-attention.d.ts +0 -30
- package/dist/attention/index.d.ts +0 -118
- package/dist/attention/paged-attention.d.ts +0 -40
- package/dist/attention/paged-kv/block-manager.d.ts +0 -102
- package/dist/attention/paged-kv/index.d.ts +0 -5
- package/dist/attention/paged-kv/page-table.d.ts +0 -165
- package/dist/attention/scheduler.d.ts +0 -40
- package/dist/core/buffer-pool.d.ts +0 -18
- package/dist/core/device.d.ts +0 -23
- package/dist/core/tdr.d.ts +0 -114
- package/dist/inference/engine.d.ts +0 -69
- package/dist/inference/generate.d.ts +0 -30
- package/dist/inference/index.d.ts +0 -7
- package/dist/inference/types.d.ts +0 -161
- package/dist/jit/compiler.d.ts +0 -23
- package/dist/jit/kernel-cache.d.ts +0 -21
- package/dist/model/gguf.d.ts +0 -90
- package/dist/model/index.d.ts +0 -16
- package/dist/model/safetensors.d.ts +0 -38
- package/dist/model/types.d.ts +0 -182
- package/dist/ops/activations.d.ts +0 -43
- package/dist/ops/elementwise.d.ts +0 -38
- package/dist/ops/embedding.d.ts +0 -30
- package/dist/ops/matmul.d.ts +0 -21
- package/dist/ops/normalization.d.ts +0 -63
- package/dist/ops/reshape.d.ts +0 -39
- package/dist/ops/rope.d.ts +0 -32
- package/dist/ops/softmax.d.ts +0 -18
- package/dist/quantization/index.d.ts +0 -6
- package/dist/quantization/qmatmul.d.ts +0 -38
- package/dist/quantization/quantize.d.ts +0 -52
- package/dist/sampling/beam-search.d.ts +0 -87
- package/dist/sampling/sampler.d.ts +0 -72
- package/dist/sampling/speculative.d.ts +0 -65
- package/dist/sampling/top-k.d.ts +0 -24
- package/dist/sampling/top-p.d.ts +0 -14
- package/dist/tvm/adapter.d.ts +0 -81
- package/dist/tvm/index.d.ts +0 -8
- package/dist/tvm/ops.d.ts +0 -26
- package/dist/tvm/types.d.ts +0 -35
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Activation function WGSL kernels
|
|
3
|
+
*/
|
|
4
|
+
export interface ActivationConfig {
|
|
5
|
+
dtype: 'f32' | 'f16';
|
|
6
|
+
}
|
|
7
|
+
/**
|
|
8
|
+
* SiLU (Swish) activation: silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
|
|
9
|
+
*/
|
|
10
|
+
export declare function generateSiLUAndMulShader(config: ActivationConfig): string;
|
|
11
|
+
/**
|
|
12
|
+
* GELU activation: gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
|
|
13
|
+
*/
|
|
14
|
+
export declare function generateGELUAndMulShader(config: ActivationConfig): string;
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Batched decode attention with paged KV cache
|
|
3
|
+
*/
|
|
4
|
+
export interface BatchDecodePagedConfig {
|
|
5
|
+
num_qo_heads: number;
|
|
6
|
+
num_kv_heads: number;
|
|
7
|
+
head_dim: number;
|
|
8
|
+
page_size: number;
|
|
9
|
+
sm_scale: number;
|
|
10
|
+
dtype: 'f32' | 'f16';
|
|
11
|
+
}
|
|
12
|
+
export declare function generateBatchDecodePagedShader(config: BatchDecodePagedConfig): string;
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Batched prefill attention with paged KV cache
|
|
3
|
+
*/
|
|
4
|
+
export interface BatchPrefillPagedConfig {
|
|
5
|
+
num_qo_heads: number;
|
|
6
|
+
num_kv_heads: number;
|
|
7
|
+
head_dim: number;
|
|
8
|
+
page_size: number;
|
|
9
|
+
sm_scale: number;
|
|
10
|
+
causal: boolean;
|
|
11
|
+
dtype: 'f32' | 'f16';
|
|
12
|
+
}
|
|
13
|
+
export declare function generateBatchPrefillPagedShader(config: BatchPrefillPagedConfig): string;
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Single decode attention WGSL kernel
|
|
3
|
+
*
|
|
4
|
+
* Flash Attention for decode phase (single token query)
|
|
5
|
+
* Uses online softmax to avoid materializing full attention matrix
|
|
6
|
+
*/
|
|
7
|
+
export interface DecodeAttentionConfig {
|
|
8
|
+
num_qo_heads: number;
|
|
9
|
+
num_kv_heads: number;
|
|
10
|
+
head_dim: number;
|
|
11
|
+
seq_len: number;
|
|
12
|
+
sm_scale: number;
|
|
13
|
+
dtype: 'f32' | 'f16';
|
|
14
|
+
useSubgroups: boolean;
|
|
15
|
+
}
|
|
16
|
+
export declare function generateDecodeAttentionShader(config: DecodeAttentionConfig): string;
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* GEMM (General Matrix Multiplication) kernels
|
|
3
|
+
*/
|
|
4
|
+
export interface GemmConfig {
|
|
5
|
+
M: number;
|
|
6
|
+
N: number;
|
|
7
|
+
K: number;
|
|
8
|
+
tile_m: number;
|
|
9
|
+
tile_n: number;
|
|
10
|
+
tile_k: number;
|
|
11
|
+
dtype: 'f32' | 'f16';
|
|
12
|
+
}
|
|
13
|
+
export declare function generateBmmShader(config: GemmConfig): string;
|
|
14
|
+
/**
|
|
15
|
+
* Simpler non-tiled GEMM for small matrices
|
|
16
|
+
*/
|
|
17
|
+
export declare function generateSimpleBmmShader(dtype: 'f32' | 'f16'): string;
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Single prefill attention WGSL kernel
|
|
3
|
+
*
|
|
4
|
+
* Flash Attention for prefill phase (multiple query tokens)
|
|
5
|
+
* Uses online softmax and supports causal masking
|
|
6
|
+
*/
|
|
7
|
+
export interface PrefillAttentionConfig {
|
|
8
|
+
num_qo_heads: number;
|
|
9
|
+
num_kv_heads: number;
|
|
10
|
+
head_dim: number;
|
|
11
|
+
qo_len: number;
|
|
12
|
+
kv_len: number;
|
|
13
|
+
sm_scale: number;
|
|
14
|
+
causal: boolean;
|
|
15
|
+
dtype: 'f32' | 'f16';
|
|
16
|
+
}
|
|
17
|
+
export declare function generatePrefillAttentionShader(config: PrefillAttentionConfig): string;
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* RoPE (Rotary Position Embedding) kernels
|
|
3
|
+
*/
|
|
4
|
+
export interface RoPEConfig {
|
|
5
|
+
head_dim: number;
|
|
6
|
+
rope_scale: number;
|
|
7
|
+
rope_theta: number;
|
|
8
|
+
dtype: 'f32' | 'f16';
|
|
9
|
+
}
|
|
10
|
+
export declare function generateRoPEShader(config: RoPEConfig): string;
|
|
11
|
+
export interface Llama31RoPEConfig {
|
|
12
|
+
head_dim: number;
|
|
13
|
+
rope_theta: number;
|
|
14
|
+
low_freq_factor: number;
|
|
15
|
+
high_freq_factor: number;
|
|
16
|
+
old_context_len: number;
|
|
17
|
+
dtype: 'f32' | 'f16';
|
|
18
|
+
}
|
|
19
|
+
export declare function generateLlama31RoPEShader(config: Llama31RoPEConfig): string;
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Sampling kernels for token generation
|
|
3
|
+
*/
|
|
4
|
+
export interface SamplingConfig {
|
|
5
|
+
vocab_size: number;
|
|
6
|
+
batch_size: number;
|
|
7
|
+
dtype: 'f32' | 'f16';
|
|
8
|
+
}
|
|
9
|
+
export declare function generateCategoricalSamplingShader(config: SamplingConfig): string;
|
|
10
|
+
export interface TopKSamplingConfig {
|
|
11
|
+
vocab_size: number;
|
|
12
|
+
batch_size: number;
|
|
13
|
+
top_k: number;
|
|
14
|
+
dtype: 'f32' | 'f16';
|
|
15
|
+
}
|
|
16
|
+
export declare function generateTopKSamplingShader(config: TopKSamplingConfig): string;
|
|
17
|
+
export interface TopPSamplingConfig {
|
|
18
|
+
vocab_size: number;
|
|
19
|
+
batch_size: number;
|
|
20
|
+
top_p: number;
|
|
21
|
+
dtype: 'f32' | 'f16';
|
|
22
|
+
}
|
|
23
|
+
export declare function generateTopPSamplingShader(config: TopPSamplingConfig): string;
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Normalization operations
|
|
3
|
+
*/
|
|
4
|
+
import type { WebInferContext } from '../core/context.ts';
|
|
5
|
+
import type { Tensor } from '../core/tensor.ts';
|
|
6
|
+
/**
|
|
7
|
+
* Root Mean Square Normalization
|
|
8
|
+
*
|
|
9
|
+
* Computes: output = (input / sqrt(mean(input^2) + eps)) * weight
|
|
10
|
+
*
|
|
11
|
+
* @param ctx WebInfer context
|
|
12
|
+
* @param input Input tensor of shape [..., hidden_size]
|
|
13
|
+
* @param weight Weight tensor of shape [hidden_size]
|
|
14
|
+
* @param eps Small constant for numerical stability (default: 1e-5)
|
|
15
|
+
* @returns Normalized tensor of same shape as input
|
|
16
|
+
*/
|
|
17
|
+
export declare function rmsnorm(ctx: WebInferContext, input: Tensor, weight: Tensor, eps?: number): Promise<Tensor>;
|
|
18
|
+
/**
|
|
19
|
+
* Fused Add + RMSNorm
|
|
20
|
+
*
|
|
21
|
+
* Computes: output = rmsnorm(input + residual)
|
|
22
|
+
* Returns both the normalized output and the new residual (input + residual)
|
|
23
|
+
*
|
|
24
|
+
* @param ctx WebInfer context
|
|
25
|
+
* @param input Input tensor
|
|
26
|
+
* @param residual Residual tensor to add
|
|
27
|
+
* @param weight Weight tensor
|
|
28
|
+
* @param eps Small constant for numerical stability (default: 1e-5)
|
|
29
|
+
* @returns [output, new_residual] tuple
|
|
30
|
+
*/
|
|
31
|
+
export declare function fused_add_rmsnorm(ctx: WebInferContext, input: Tensor, residual: Tensor, weight: Tensor, eps?: number): Promise<[Tensor, Tensor]>;
|
|
32
|
+
/**
|
|
33
|
+
* Gemma RMSNorm variant (adds 1.0 to weight)
|
|
34
|
+
*
|
|
35
|
+
* Computes: output = (input / sqrt(mean(input^2) + eps)) * (weight + 1.0)
|
|
36
|
+
*
|
|
37
|
+
* @param ctx WebInfer context
|
|
38
|
+
* @param input Input tensor
|
|
39
|
+
* @param weight Weight tensor
|
|
40
|
+
* @param eps Small constant for numerical stability (default: 1e-5)
|
|
41
|
+
* @returns Normalized tensor
|
|
42
|
+
*/
|
|
43
|
+
export declare function gemma_rmsnorm(ctx: WebInferContext, input: Tensor, weight: Tensor, eps?: number): Promise<Tensor>;
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Paged KV cache operations
|
|
3
|
+
*/
|
|
4
|
+
import type { WebInferContext } from '../core/context.ts';
|
|
5
|
+
import type { Tensor } from '../core/tensor.ts';
|
|
6
|
+
import type { PagedKvCache } from '../core/paged-kv-cache.ts';
|
|
7
|
+
/**
|
|
8
|
+
* Append new KV pairs to paged KV cache
|
|
9
|
+
*
|
|
10
|
+
* Appends new key-value pairs to the paged KV cache structure.
|
|
11
|
+
* Each batch item gets its new KV appended to its last page.
|
|
12
|
+
*
|
|
13
|
+
* @param ctx WebInfer context
|
|
14
|
+
* @param paged_kv_cache The paged KV cache to append to
|
|
15
|
+
* @param indptr Indirection pointer array [batch_size + 1]
|
|
16
|
+
* @param indices Page indices array [nnz_pages]
|
|
17
|
+
* @param last_page_len Number of valid tokens in last page for each batch item [batch_size]
|
|
18
|
+
* @param k New keys to append [total_len, num_kv_heads, head_dim]
|
|
19
|
+
* @param v New values to append [total_len, num_kv_heads, head_dim]
|
|
20
|
+
*/
|
|
21
|
+
export declare function append_paged_kv_cache(ctx: WebInferContext, paged_kv_cache: PagedKvCache, indptr: Tensor, indices: Tensor, last_page_len: Tensor, k: Tensor, v: Tensor): Promise<void>;
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Prefill attention operations
|
|
3
|
+
*/
|
|
4
|
+
import type { WebInferContext } from '../core/context.ts';
|
|
5
|
+
import type { Tensor } from '../core/tensor.ts';
|
|
6
|
+
import type { PagedKvCache } from '../core/paged-kv-cache.ts';
|
|
7
|
+
import { PosEncodingMode } from '../core/types.ts';
|
|
8
|
+
/**
|
|
9
|
+
* Single prefill with KV cache
|
|
10
|
+
*
|
|
11
|
+
* Performs attention for multiple query tokens (prefill phase) with KV cache.
|
|
12
|
+
* Uses flash attention with online softmax and supports causal masking.
|
|
13
|
+
*
|
|
14
|
+
* @param ctx WebInfer context
|
|
15
|
+
* @param q Query tensor of shape [qo_len, num_qo_heads, head_dim]
|
|
16
|
+
* @param k Key tensor of shape [kv_len, num_kv_heads, head_dim]
|
|
17
|
+
* @param v Value tensor of shape [kv_len, num_kv_heads, head_dim]
|
|
18
|
+
* @param causal Whether to apply causal masking (default: true)
|
|
19
|
+
* @param pos_encoding_mode Position encoding mode (default: NONE)
|
|
20
|
+
* @param sm_scale Softmax scale (default: 1/sqrt(head_dim))
|
|
21
|
+
* @param rope_scale RoPE scale (default: 1.0) - not yet implemented
|
|
22
|
+
* @param rope_theta RoPE theta (default: 10000.0) - not yet implemented
|
|
23
|
+
* @returns Output tensor of shape [qo_len, num_qo_heads, head_dim]
|
|
24
|
+
*/
|
|
25
|
+
export declare function single_prefill_with_kv_cache(ctx: WebInferContext, q: Tensor, k: Tensor, v: Tensor, causal?: boolean, pos_encoding_mode?: PosEncodingMode, sm_scale?: number, rope_scale?: number, rope_theta?: number): Promise<Tensor>;
|
|
26
|
+
/**
|
|
27
|
+
* Batched prefill with paged KV cache wrapper
|
|
28
|
+
*
|
|
29
|
+
* FlashInfer-compatible wrapper for batched prefill attention with paged KV cache.
|
|
30
|
+
*/
|
|
31
|
+
export declare class BatchPrefillWithPagedKVCacheWrapper {
|
|
32
|
+
private ctx;
|
|
33
|
+
private planned;
|
|
34
|
+
private num_qo_heads;
|
|
35
|
+
private num_kv_heads;
|
|
36
|
+
private head_dim;
|
|
37
|
+
private page_size;
|
|
38
|
+
private sm_scale;
|
|
39
|
+
private causal;
|
|
40
|
+
private pos_encoding_mode;
|
|
41
|
+
constructor(ctx: WebInferContext);
|
|
42
|
+
/**
|
|
43
|
+
* Plan the batched prefill operation
|
|
44
|
+
*
|
|
45
|
+
* @param qo_indptr Query indirection pointer [batch_size + 1]
|
|
46
|
+
* @param paged_kv_indptr Paged KV indirection pointer [batch_size + 1]
|
|
47
|
+
* @param paged_kv_indices Paged KV indices [nnz_pages]
|
|
48
|
+
* @param paged_kv_last_page_len Last page lengths [batch_size]
|
|
49
|
+
* @param num_qo_heads Number of query heads
|
|
50
|
+
* @param num_kv_heads Number of KV heads
|
|
51
|
+
* @param head_dim Head dimension
|
|
52
|
+
* @param page_size Page size
|
|
53
|
+
* @param causal Whether to apply causal masking (default: true)
|
|
54
|
+
* @param pos_encoding_mode Position encoding mode (default: NONE)
|
|
55
|
+
*/
|
|
56
|
+
plan(qo_indptr: Tensor, paged_kv_indptr: Tensor, paged_kv_indices: Tensor, paged_kv_last_page_len: Tensor, num_qo_heads: number, num_kv_heads: number, head_dim: number, page_size: number, causal?: boolean, pos_encoding_mode?: PosEncodingMode): void;
|
|
57
|
+
/**
|
|
58
|
+
* Run the batched prefill operation
|
|
59
|
+
*
|
|
60
|
+
* @param q Query tensor [total_qo_len, num_qo_heads, head_dim]
|
|
61
|
+
* @param paged_kv_cache Paged KV cache
|
|
62
|
+
* @param qo_indptr Query indirection pointer [batch_size + 1]
|
|
63
|
+
* @param paged_kv_indptr Paged KV indirection pointer [batch_size + 1]
|
|
64
|
+
* @param paged_kv_indices Paged KV indices [nnz_pages]
|
|
65
|
+
* @param paged_kv_last_page_len Last page lengths [batch_size]
|
|
66
|
+
* @returns Output tensor [total_qo_len, num_qo_heads, head_dim]
|
|
67
|
+
*/
|
|
68
|
+
run(q: Tensor, paged_kv_cache: PagedKvCache, qo_indptr: Tensor, paged_kv_indptr: Tensor, paged_kv_indices: Tensor, paged_kv_last_page_len: Tensor): Promise<Tensor>;
|
|
69
|
+
}
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* RoPE (Rotary Position Embedding) operations
|
|
3
|
+
*/
|
|
4
|
+
import type { WebInferContext } from '../core/context.ts';
|
|
5
|
+
import type { Tensor } from '../core/tensor.ts';
|
|
6
|
+
/**
|
|
7
|
+
* Apply RoPE (Rotary Position Embedding) in-place
|
|
8
|
+
*
|
|
9
|
+
* Applies rotary position embeddings to query and key tensors.
|
|
10
|
+
* Modifies tensors in-place.
|
|
11
|
+
*
|
|
12
|
+
* @param ctx WebInfer context
|
|
13
|
+
* @param q Query tensor [nnz, num_qo_heads, head_dim]
|
|
14
|
+
* @param k Key tensor [nnz, num_kv_heads, head_dim]
|
|
15
|
+
* @param indptr Indirection pointer [batch_size + 1]
|
|
16
|
+
* @param offsets Position offsets for each batch item [batch_size]
|
|
17
|
+
* @param rope_scale RoPE scale factor (default: 1.0)
|
|
18
|
+
* @param rope_theta RoPE theta base (default: 10000.0)
|
|
19
|
+
*/
|
|
20
|
+
export declare function apply_rope_inplace(ctx: WebInferContext, q: Tensor, k: Tensor, indptr: Tensor, offsets: Tensor, rope_scale?: number, rope_theta?: number): Promise<void>;
|
|
21
|
+
/**
|
|
22
|
+
* Apply Llama 3.1 RoPE variant in-place
|
|
23
|
+
*
|
|
24
|
+
* Applies Llama 3.1's frequency-scaled RoPE for long context support.
|
|
25
|
+
* Modifies tensors in-place.
|
|
26
|
+
*
|
|
27
|
+
* @param ctx WebInfer context
|
|
28
|
+
* @param q Query tensor [nnz, num_qo_heads, head_dim]
|
|
29
|
+
* @param k Key tensor [nnz, num_kv_heads, head_dim]
|
|
30
|
+
* @param indptr Indirection pointer [batch_size + 1]
|
|
31
|
+
* @param offsets Position offsets [batch_size]
|
|
32
|
+
* @param low_freq_factor Low frequency scaling factor (default: 1.0)
|
|
33
|
+
* @param high_freq_factor High frequency scaling factor (default: 4.0)
|
|
34
|
+
* @param old_context_len Original context length before scaling (default: 8192)
|
|
35
|
+
* @param rope_theta RoPE theta base (default: 500000.0)
|
|
36
|
+
*/
|
|
37
|
+
export declare function apply_llama31_rope_inplace(ctx: WebInferContext, q: Tensor, k: Tensor, indptr: Tensor, offsets: Tensor, low_freq_factor?: number, high_freq_factor?: number, old_context_len?: number, rope_theta?: number): Promise<void>;
|
package/dist/sampling/index.d.ts
CHANGED
|
@@ -1,6 +1,55 @@
|
|
|
1
1
|
/**
|
|
2
|
-
* Sampling
|
|
2
|
+
* Sampling operations for token generation
|
|
3
3
|
*/
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
4
|
+
import type { WebInferContext } from '../core/context.ts';
|
|
5
|
+
import { Tensor } from '../core/tensor.ts';
|
|
6
|
+
/**
|
|
7
|
+
* Sample from probability distribution (categorical sampling)
|
|
8
|
+
*
|
|
9
|
+
* @param ctx WebInfer context
|
|
10
|
+
* @param probs Probability tensor [batch_size, vocab_size]
|
|
11
|
+
* @param uniform_samples Uniform random samples [batch_size] in range [0, 1)
|
|
12
|
+
* @returns Sampled token IDs [batch_size]
|
|
13
|
+
*/
|
|
14
|
+
export declare function sampling_from_probs(ctx: WebInferContext, probs: Tensor, uniform_samples: Tensor): Promise<Tensor>;
|
|
15
|
+
/**
|
|
16
|
+
* Top-k sampling from probability distribution
|
|
17
|
+
*
|
|
18
|
+
* @param ctx WebInfer context
|
|
19
|
+
* @param probs Probability tensor [batch_size, vocab_size]
|
|
20
|
+
* @param uniform_samples Uniform random samples [batch_size]
|
|
21
|
+
* @param top_k Number of top candidates to consider
|
|
22
|
+
* @returns Sampled token IDs [batch_size]
|
|
23
|
+
*/
|
|
24
|
+
export declare function top_k_sampling_from_probs(ctx: WebInferContext, probs: Tensor, uniform_samples: Tensor, top_k: number): Promise<Tensor>;
|
|
25
|
+
/**
|
|
26
|
+
* Top-p (nucleus) sampling from probability distribution
|
|
27
|
+
*
|
|
28
|
+
* @param ctx WebInfer context
|
|
29
|
+
* @param probs Probability tensor [batch_size, vocab_size]
|
|
30
|
+
* @param uniform_samples Uniform random samples [batch_size]
|
|
31
|
+
* @param top_p Cumulative probability threshold (0, 1]
|
|
32
|
+
* @returns Sampled token IDs [batch_size]
|
|
33
|
+
*/
|
|
34
|
+
export declare function top_p_sampling_from_probs(ctx: WebInferContext, probs: Tensor, uniform_samples: Tensor, top_p: number): Promise<Tensor>;
|
|
35
|
+
/**
|
|
36
|
+
* Min-p sampling from probability distribution
|
|
37
|
+
*
|
|
38
|
+
* @param ctx WebInfer context
|
|
39
|
+
* @param probs Probability tensor [batch_size, vocab_size]
|
|
40
|
+
* @param uniform_samples Uniform random samples [batch_size]
|
|
41
|
+
* @param min_p Minimum probability threshold relative to max prob
|
|
42
|
+
* @returns Sampled token IDs [batch_size]
|
|
43
|
+
*/
|
|
44
|
+
export declare function min_p_sampling_from_probs(ctx: WebInferContext, probs: Tensor, uniform_samples: Tensor, min_p: number): Promise<Tensor>;
|
|
45
|
+
/**
|
|
46
|
+
* Combined top-k and top-p sampling
|
|
47
|
+
*
|
|
48
|
+
* @param ctx WebInfer context
|
|
49
|
+
* @param probs Probability tensor [batch_size, vocab_size]
|
|
50
|
+
* @param uniform_samples Uniform random samples [batch_size]
|
|
51
|
+
* @param top_k Number of top candidates
|
|
52
|
+
* @param top_p Cumulative probability threshold
|
|
53
|
+
* @returns Sampled token IDs [batch_size]
|
|
54
|
+
*/
|
|
55
|
+
export declare function top_k_top_p_sampling_from_probs(ctx: WebInferContext, probs: Tensor, uniform_samples: Tensor, top_k: number, top_p: number): Promise<Tensor>;
|
package/package.json
CHANGED
|
@@ -1,52 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* Block-Sparse CSR Format
|
|
3
|
-
* Enables a single kernel to support all attention variants
|
|
4
|
-
*/
|
|
5
|
-
/**
|
|
6
|
-
* Block-Sparse CSR representation of attention mask
|
|
7
|
-
*/
|
|
8
|
-
export interface BlockSparseCSR {
|
|
9
|
-
blockSize: number;
|
|
10
|
-
rowPtr: Uint32Array;
|
|
11
|
-
colIdx: Uint32Array;
|
|
12
|
-
blockMask?: Uint8Array;
|
|
13
|
-
numRows: number;
|
|
14
|
-
numCols: number;
|
|
15
|
-
numBlockRows: number;
|
|
16
|
-
numBlockCols: number;
|
|
17
|
-
nnzBlocks: number;
|
|
18
|
-
}
|
|
19
|
-
/**
|
|
20
|
-
* Attention pattern types
|
|
21
|
-
*/
|
|
22
|
-
export type AttentionPattern = {
|
|
23
|
-
type: "dense";
|
|
24
|
-
} | {
|
|
25
|
-
type: "causal";
|
|
26
|
-
} | {
|
|
27
|
-
type: "sliding";
|
|
28
|
-
windowSize: number;
|
|
29
|
-
} | {
|
|
30
|
-
type: "global-local";
|
|
31
|
-
globalTokens: number[];
|
|
32
|
-
localWindow: number;
|
|
33
|
-
} | {
|
|
34
|
-
type: "custom";
|
|
35
|
-
mask: boolean[][];
|
|
36
|
-
};
|
|
37
|
-
/**
|
|
38
|
-
* Build BS-CSR from attention pattern
|
|
39
|
-
*/
|
|
40
|
-
export declare function buildBlockSparseCSR(seqLen: number, pattern: AttentionPattern, blockSize?: number): BlockSparseCSR;
|
|
41
|
-
/**
|
|
42
|
-
* Calculate sparsity ratio of the mask
|
|
43
|
-
*/
|
|
44
|
-
export declare function getSparsityRatio(csr: BlockSparseCSR): number;
|
|
45
|
-
/**
|
|
46
|
-
* Estimate memory savings from sparsity
|
|
47
|
-
*/
|
|
48
|
-
export declare function estimateMemorySavings(csr: BlockSparseCSR): {
|
|
49
|
-
denseBytes: number;
|
|
50
|
-
sparseBytes: number;
|
|
51
|
-
savingsRatio: number;
|
|
52
|
-
};
|
|
@@ -1,16 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* Causal (autoregressive) attention pattern
|
|
3
|
-
* Used in GPT, Llama, and most decoder-only models
|
|
4
|
-
*/
|
|
5
|
-
import { type BlockSparseCSR } from "../format.ts";
|
|
6
|
-
/**
|
|
7
|
-
* Build causal attention mask in BS-CSR format
|
|
8
|
-
* Each query position can only attend to positions <= its own position
|
|
9
|
-
*/
|
|
10
|
-
export declare function buildCausalMask(seqLen: number, blockSize?: number): BlockSparseCSR;
|
|
11
|
-
/**
|
|
12
|
-
* Get the theoretical sparsity of causal attention
|
|
13
|
-
* For a sequence of length N, causal attention has N*(N+1)/2 non-zero elements
|
|
14
|
-
* out of N*N total, giving ~50% sparsity for large N
|
|
15
|
-
*/
|
|
16
|
-
export declare function getCausalSparsity(seqLen: number): number;
|
|
@@ -1,22 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* Sliding window attention pattern
|
|
3
|
-
* Used in Mistral and other efficient attention models
|
|
4
|
-
*/
|
|
5
|
-
import { type BlockSparseCSR } from "../format.ts";
|
|
6
|
-
/**
|
|
7
|
-
* Build sliding window attention mask in BS-CSR format
|
|
8
|
-
* Each query can only attend to the previous `windowSize` positions
|
|
9
|
-
*/
|
|
10
|
-
export declare function buildSlidingWindowMask(seqLen: number, windowSize: number, blockSize?: number): BlockSparseCSR;
|
|
11
|
-
/**
|
|
12
|
-
* Get the theoretical sparsity of sliding window attention
|
|
13
|
-
* For window size W and sequence length N:
|
|
14
|
-
* - First W positions have triangular attention (like causal)
|
|
15
|
-
* - Remaining N-W positions have W+1 attention each
|
|
16
|
-
*/
|
|
17
|
-
export declare function getSlidingWindowSparsity(seqLen: number, windowSize: number): number;
|
|
18
|
-
/**
|
|
19
|
-
* Sliding window with causal constraint
|
|
20
|
-
* This is what Mistral uses - combines sliding window with causal masking
|
|
21
|
-
*/
|
|
22
|
-
export declare function buildCausalSlidingWindowMask(seqLen: number, windowSize: number, blockSize?: number): BlockSparseCSR;
|
|
@@ -1,65 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* Tree Attention Pattern
|
|
3
|
-
* Used in speculative decoding (Medusa, EAGLE) and tree-based generation
|
|
4
|
-
*/
|
|
5
|
-
import { type BlockSparseCSR } from "../format.ts";
|
|
6
|
-
/**
|
|
7
|
-
* Tree node structure
|
|
8
|
-
*/
|
|
9
|
-
export interface TreeNode {
|
|
10
|
-
/** Token position in sequence */
|
|
11
|
-
position: number;
|
|
12
|
-
/** Parent position (-1 for root) */
|
|
13
|
-
parent: number;
|
|
14
|
-
/** Depth in tree (0 for root) */
|
|
15
|
-
depth: number;
|
|
16
|
-
}
|
|
17
|
-
/**
|
|
18
|
-
* Tree attention configuration
|
|
19
|
-
*/
|
|
20
|
-
export interface TreeAttentionConfig {
|
|
21
|
-
/** Total sequence length including prompt */
|
|
22
|
-
seqLen: number;
|
|
23
|
-
/** Prompt length (prefix that all tokens attend to) */
|
|
24
|
-
promptLen: number;
|
|
25
|
-
/** Tree structure for speculative tokens */
|
|
26
|
-
tree: TreeNode[];
|
|
27
|
-
/** Block size for sparse format */
|
|
28
|
-
blockSize?: number;
|
|
29
|
-
}
|
|
30
|
-
/**
|
|
31
|
-
* Build tree attention mask
|
|
32
|
-
*
|
|
33
|
-
* In tree attention:
|
|
34
|
-
* - All tokens attend to the prompt (positions 0 to promptLen-1)
|
|
35
|
-
* - Tree tokens attend to their ancestors in the tree
|
|
36
|
-
* - Maintains causal property within the tree structure
|
|
37
|
-
*/
|
|
38
|
-
export declare function buildTreeMask(config: TreeAttentionConfig): BlockSparseCSR;
|
|
39
|
-
/**
|
|
40
|
-
* Build a simple chain tree (linear speculation)
|
|
41
|
-
* Each token depends on the previous one
|
|
42
|
-
*/
|
|
43
|
-
export declare function buildChainTree(numSpecTokens: number): TreeNode[];
|
|
44
|
-
/**
|
|
45
|
-
* Build a wide tree (parallel speculation)
|
|
46
|
-
* All speculative tokens depend only on the prompt
|
|
47
|
-
*/
|
|
48
|
-
export declare function buildWideTree(numSpecTokens: number): TreeNode[];
|
|
49
|
-
/**
|
|
50
|
-
* Build a binary tree for speculation
|
|
51
|
-
*/
|
|
52
|
-
export declare function buildBinaryTree(depth: number): TreeNode[];
|
|
53
|
-
/**
|
|
54
|
-
* Build Medusa-style tree
|
|
55
|
-
* Multiple heads predict tokens at different positions
|
|
56
|
-
*/
|
|
57
|
-
export declare function buildMedusaTree(numHeads: number, tokensPerHead: number): TreeNode[];
|
|
58
|
-
/**
|
|
59
|
-
* Calculate tree sparsity ratio
|
|
60
|
-
*/
|
|
61
|
-
export declare function getTreeSparsity(config: TreeAttentionConfig): number;
|
|
62
|
-
/**
|
|
63
|
-
* Validate tree structure
|
|
64
|
-
*/
|
|
65
|
-
export declare function validateTree(tree: TreeNode[]): boolean;
|
|
@@ -1,29 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* Cascaded Inference - TDR-safe attention for very long sequences
|
|
3
|
-
* Splits attention computation across multiple passes with browser yields
|
|
4
|
-
*/
|
|
5
|
-
import type { WebInferDevice } from "../core/device.ts";
|
|
6
|
-
import { Tensor } from "../core/tensor.ts";
|
|
7
|
-
export interface CascadedAttentionConfig {
|
|
8
|
-
numHeads: number;
|
|
9
|
-
headDim: number;
|
|
10
|
-
seqLen: number;
|
|
11
|
-
scale?: number;
|
|
12
|
-
causal?: boolean;
|
|
13
|
-
}
|
|
14
|
-
/**
|
|
15
|
-
* Cascaded Attention - Safe for very long sequences
|
|
16
|
-
* Uses Split-K strategy to prevent TDR (GPU timeout)
|
|
17
|
-
*
|
|
18
|
-
* @param device WebInfer device
|
|
19
|
-
* @param q Query tensor [seqLen, numHeads, headDim]
|
|
20
|
-
* @param k Key tensor [seqLen, numHeads, headDim]
|
|
21
|
-
* @param v Value tensor [seqLen, numHeads, headDim]
|
|
22
|
-
* @param config Attention configuration
|
|
23
|
-
* @param onProgress Optional progress callback
|
|
24
|
-
*/
|
|
25
|
-
export declare function cascadedAttention(device: WebInferDevice, q: Tensor, k: Tensor, v: Tensor, config: CascadedAttentionConfig, onProgress?: (chunk: number, total: number) => void): Promise<Tensor>;
|
|
26
|
-
/**
|
|
27
|
-
* CPU reference implementation for verification
|
|
28
|
-
*/
|
|
29
|
-
export declare function cascadedAttentionCPU(q: Float32Array, k: Float32Array, v: Float32Array, seqLen: number, numHeads: number, headDim: number, causal?: boolean): Float32Array;
|
|
@@ -1,30 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* FlashAttention Implementation for WebGPU
|
|
3
|
-
* Memory-efficient attention using online softmax and tiling
|
|
4
|
-
*/
|
|
5
|
-
import type { WebInferDevice } from "../core/device.ts";
|
|
6
|
-
import { Tensor } from "../core/tensor.ts";
|
|
7
|
-
import type { AttentionPattern } from "./block-sparse/format.ts";
|
|
8
|
-
export interface AttentionConfig {
|
|
9
|
-
numHeads: number;
|
|
10
|
-
headDim: number;
|
|
11
|
-
seqLen: number;
|
|
12
|
-
scale?: number;
|
|
13
|
-
pattern?: AttentionPattern;
|
|
14
|
-
blockSize?: number;
|
|
15
|
-
}
|
|
16
|
-
/**
|
|
17
|
-
* FlashAttention forward pass
|
|
18
|
-
* Computes: softmax(Q @ K^T / sqrt(d)) @ V
|
|
19
|
-
*
|
|
20
|
-
* @param device WebInfer device
|
|
21
|
-
* @param q Query tensor [batch, seqLen, numHeads, headDim]
|
|
22
|
-
* @param k Key tensor [batch, seqLen, numHeads, headDim]
|
|
23
|
-
* @param v Value tensor [batch, seqLen, numHeads, headDim]
|
|
24
|
-
* @param config Attention configuration
|
|
25
|
-
*/
|
|
26
|
-
export declare function flashAttention(device: WebInferDevice, q: Tensor, k: Tensor, v: Tensor, config: AttentionConfig): Promise<Tensor>;
|
|
27
|
-
/**
|
|
28
|
-
* CPU reference implementation for verification
|
|
29
|
-
*/
|
|
30
|
-
export declare function attentionCPU(q: Float32Array, k: Float32Array, v: Float32Array, seqLen: number, numHeads: number, headDim: number, causal?: boolean): Float32Array;
|