webinfer 0.0.3 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. package/README.md +40 -25
  2. package/dist/activation/index.d.ts +30 -0
  3. package/dist/core/context.d.ts +60 -0
  4. package/dist/core/paged-kv-cache.d.ts +33 -0
  5. package/dist/core/tensor.d.ts +38 -19
  6. package/dist/core/types.d.ts +27 -0
  7. package/dist/decode/index.d.ts +65 -0
  8. package/dist/gemm/index.d.ts +25 -0
  9. package/dist/index.d.ts +26 -21
  10. package/dist/index.js +2439 -4872
  11. package/dist/kernels/activation.wgsl.d.ts +14 -0
  12. package/dist/kernels/batch-decode-paged.wgsl.d.ts +12 -0
  13. package/dist/kernels/batch-prefill-paged.wgsl.d.ts +13 -0
  14. package/dist/kernels/decode-attention.wgsl.d.ts +16 -0
  15. package/dist/kernels/gemm.wgsl.d.ts +17 -0
  16. package/dist/kernels/page.wgsl.d.ts +10 -0
  17. package/dist/kernels/prefill-attention.wgsl.d.ts +17 -0
  18. package/dist/kernels/rmsnorm.wgsl.d.ts +10 -0
  19. package/dist/kernels/rope.wgsl.d.ts +19 -0
  20. package/dist/kernels/sampling.wgsl.d.ts +23 -0
  21. package/dist/norm/index.d.ts +43 -0
  22. package/dist/page/index.d.ts +21 -0
  23. package/dist/prefill/index.d.ts +69 -0
  24. package/dist/rope/index.d.ts +37 -0
  25. package/dist/sampling/index.d.ts +53 -4
  26. package/package.json +1 -1
  27. package/dist/attention/block-sparse/format.d.ts +0 -52
  28. package/dist/attention/block-sparse/patterns/causal.d.ts +0 -16
  29. package/dist/attention/block-sparse/patterns/sliding.d.ts +0 -22
  30. package/dist/attention/block-sparse/patterns/tree.d.ts +0 -65
  31. package/dist/attention/cascaded-inference.d.ts +0 -29
  32. package/dist/attention/flash-attention.d.ts +0 -30
  33. package/dist/attention/index.d.ts +0 -118
  34. package/dist/attention/paged-attention.d.ts +0 -40
  35. package/dist/attention/paged-kv/block-manager.d.ts +0 -102
  36. package/dist/attention/paged-kv/index.d.ts +0 -5
  37. package/dist/attention/paged-kv/page-table.d.ts +0 -165
  38. package/dist/attention/scheduler.d.ts +0 -40
  39. package/dist/core/buffer-pool.d.ts +0 -18
  40. package/dist/core/device.d.ts +0 -23
  41. package/dist/core/tdr.d.ts +0 -114
  42. package/dist/inference/engine.d.ts +0 -69
  43. package/dist/inference/generate.d.ts +0 -30
  44. package/dist/inference/index.d.ts +0 -7
  45. package/dist/inference/types.d.ts +0 -161
  46. package/dist/jit/compiler.d.ts +0 -23
  47. package/dist/jit/kernel-cache.d.ts +0 -21
  48. package/dist/model/gguf.d.ts +0 -90
  49. package/dist/model/index.d.ts +0 -16
  50. package/dist/model/safetensors.d.ts +0 -38
  51. package/dist/model/types.d.ts +0 -182
  52. package/dist/ops/activations.d.ts +0 -43
  53. package/dist/ops/elementwise.d.ts +0 -38
  54. package/dist/ops/embedding.d.ts +0 -30
  55. package/dist/ops/matmul.d.ts +0 -21
  56. package/dist/ops/normalization.d.ts +0 -63
  57. package/dist/ops/reshape.d.ts +0 -39
  58. package/dist/ops/rope.d.ts +0 -32
  59. package/dist/ops/softmax.d.ts +0 -18
  60. package/dist/quantization/index.d.ts +0 -6
  61. package/dist/quantization/qmatmul.d.ts +0 -38
  62. package/dist/quantization/quantize.d.ts +0 -52
  63. package/dist/sampling/beam-search.d.ts +0 -87
  64. package/dist/sampling/sampler.d.ts +0 -72
  65. package/dist/sampling/speculative.d.ts +0 -65
  66. package/dist/sampling/top-k.d.ts +0 -24
  67. package/dist/sampling/top-p.d.ts +0 -14
  68. package/dist/tvm/adapter.d.ts +0 -81
  69. package/dist/tvm/index.d.ts +0 -8
  70. package/dist/tvm/ops.d.ts +0 -26
  71. package/dist/tvm/types.d.ts +0 -35
@@ -0,0 +1,14 @@
1
+ /**
2
+ * Activation function WGSL kernels
3
+ */
4
+ export interface ActivationConfig {
5
+ dtype: 'f32' | 'f16';
6
+ }
7
+ /**
8
+ * SiLU (Swish) activation: silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
9
+ */
10
+ export declare function generateSiLUAndMulShader(config: ActivationConfig): string;
11
+ /**
12
+ * GELU activation: gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
13
+ */
14
+ export declare function generateGELUAndMulShader(config: ActivationConfig): string;
@@ -0,0 +1,12 @@
1
+ /**
2
+ * Batched decode attention with paged KV cache
3
+ */
4
+ export interface BatchDecodePagedConfig {
5
+ num_qo_heads: number;
6
+ num_kv_heads: number;
7
+ head_dim: number;
8
+ page_size: number;
9
+ sm_scale: number;
10
+ dtype: 'f32' | 'f16';
11
+ }
12
+ export declare function generateBatchDecodePagedShader(config: BatchDecodePagedConfig): string;
@@ -0,0 +1,13 @@
1
+ /**
2
+ * Batched prefill attention with paged KV cache
3
+ */
4
+ export interface BatchPrefillPagedConfig {
5
+ num_qo_heads: number;
6
+ num_kv_heads: number;
7
+ head_dim: number;
8
+ page_size: number;
9
+ sm_scale: number;
10
+ causal: boolean;
11
+ dtype: 'f32' | 'f16';
12
+ }
13
+ export declare function generateBatchPrefillPagedShader(config: BatchPrefillPagedConfig): string;
@@ -0,0 +1,16 @@
1
+ /**
2
+ * Single decode attention WGSL kernel
3
+ *
4
+ * Flash Attention for decode phase (single token query)
5
+ * Uses online softmax to avoid materializing full attention matrix
6
+ */
7
+ export interface DecodeAttentionConfig {
8
+ num_qo_heads: number;
9
+ num_kv_heads: number;
10
+ head_dim: number;
11
+ seq_len: number;
12
+ sm_scale: number;
13
+ dtype: 'f32' | 'f16';
14
+ useSubgroups: boolean;
15
+ }
16
+ export declare function generateDecodeAttentionShader(config: DecodeAttentionConfig): string;
@@ -0,0 +1,17 @@
1
+ /**
2
+ * GEMM (General Matrix Multiplication) kernels
3
+ */
4
+ export interface GemmConfig {
5
+ M: number;
6
+ N: number;
7
+ K: number;
8
+ tile_m: number;
9
+ tile_n: number;
10
+ tile_k: number;
11
+ dtype: 'f32' | 'f16';
12
+ }
13
+ export declare function generateBmmShader(config: GemmConfig): string;
14
+ /**
15
+ * Simpler non-tiled GEMM for small matrices
16
+ */
17
+ export declare function generateSimpleBmmShader(dtype: 'f32' | 'f16'): string;
@@ -0,0 +1,10 @@
1
+ /**
2
+ * Paged KV cache append kernel
3
+ */
4
+ export interface PageAppendConfig {
5
+ page_size: number;
6
+ num_kv_heads: number;
7
+ head_dim: number;
8
+ dtype: 'f32' | 'f16';
9
+ }
10
+ export declare function generatePageAppendShader(config: PageAppendConfig): string;
@@ -0,0 +1,17 @@
1
+ /**
2
+ * Single prefill attention WGSL kernel
3
+ *
4
+ * Flash Attention for prefill phase (multiple query tokens)
5
+ * Uses online softmax and supports causal masking
6
+ */
7
+ export interface PrefillAttentionConfig {
8
+ num_qo_heads: number;
9
+ num_kv_heads: number;
10
+ head_dim: number;
11
+ qo_len: number;
12
+ kv_len: number;
13
+ sm_scale: number;
14
+ causal: boolean;
15
+ dtype: 'f32' | 'f16';
16
+ }
17
+ export declare function generatePrefillAttentionShader(config: PrefillAttentionConfig): string;
@@ -0,0 +1,10 @@
1
+ /**
2
+ * RMSNorm WGSL kernel template
3
+ */
4
+ export interface RMSNormConfig {
5
+ hiddenSize: number;
6
+ eps: number;
7
+ dtype: 'f32' | 'f16';
8
+ useSubgroups: boolean;
9
+ }
10
+ export declare function generateRMSNormShader(config: RMSNormConfig): string;
@@ -0,0 +1,19 @@
1
+ /**
2
+ * RoPE (Rotary Position Embedding) kernels
3
+ */
4
+ export interface RoPEConfig {
5
+ head_dim: number;
6
+ rope_scale: number;
7
+ rope_theta: number;
8
+ dtype: 'f32' | 'f16';
9
+ }
10
+ export declare function generateRoPEShader(config: RoPEConfig): string;
11
+ export interface Llama31RoPEConfig {
12
+ head_dim: number;
13
+ rope_theta: number;
14
+ low_freq_factor: number;
15
+ high_freq_factor: number;
16
+ old_context_len: number;
17
+ dtype: 'f32' | 'f16';
18
+ }
19
+ export declare function generateLlama31RoPEShader(config: Llama31RoPEConfig): string;
@@ -0,0 +1,23 @@
1
+ /**
2
+ * Sampling kernels for token generation
3
+ */
4
+ export interface SamplingConfig {
5
+ vocab_size: number;
6
+ batch_size: number;
7
+ dtype: 'f32' | 'f16';
8
+ }
9
+ export declare function generateCategoricalSamplingShader(config: SamplingConfig): string;
10
+ export interface TopKSamplingConfig {
11
+ vocab_size: number;
12
+ batch_size: number;
13
+ top_k: number;
14
+ dtype: 'f32' | 'f16';
15
+ }
16
+ export declare function generateTopKSamplingShader(config: TopKSamplingConfig): string;
17
+ export interface TopPSamplingConfig {
18
+ vocab_size: number;
19
+ batch_size: number;
20
+ top_p: number;
21
+ dtype: 'f32' | 'f16';
22
+ }
23
+ export declare function generateTopPSamplingShader(config: TopPSamplingConfig): string;
@@ -0,0 +1,43 @@
1
+ /**
2
+ * Normalization operations
3
+ */
4
+ import type { WebInferContext } from '../core/context.ts';
5
+ import type { Tensor } from '../core/tensor.ts';
6
+ /**
7
+ * Root Mean Square Normalization
8
+ *
9
+ * Computes: output = (input / sqrt(mean(input^2) + eps)) * weight
10
+ *
11
+ * @param ctx WebInfer context
12
+ * @param input Input tensor of shape [..., hidden_size]
13
+ * @param weight Weight tensor of shape [hidden_size]
14
+ * @param eps Small constant for numerical stability (default: 1e-5)
15
+ * @returns Normalized tensor of same shape as input
16
+ */
17
+ export declare function rmsnorm(ctx: WebInferContext, input: Tensor, weight: Tensor, eps?: number): Promise<Tensor>;
18
+ /**
19
+ * Fused Add + RMSNorm
20
+ *
21
+ * Computes: output = rmsnorm(input + residual)
22
+ * Returns both the normalized output and the new residual (input + residual)
23
+ *
24
+ * @param ctx WebInfer context
25
+ * @param input Input tensor
26
+ * @param residual Residual tensor to add
27
+ * @param weight Weight tensor
28
+ * @param eps Small constant for numerical stability (default: 1e-5)
29
+ * @returns [output, new_residual] tuple
30
+ */
31
+ export declare function fused_add_rmsnorm(ctx: WebInferContext, input: Tensor, residual: Tensor, weight: Tensor, eps?: number): Promise<[Tensor, Tensor]>;
32
+ /**
33
+ * Gemma RMSNorm variant (adds 1.0 to weight)
34
+ *
35
+ * Computes: output = (input / sqrt(mean(input^2) + eps)) * (weight + 1.0)
36
+ *
37
+ * @param ctx WebInfer context
38
+ * @param input Input tensor
39
+ * @param weight Weight tensor
40
+ * @param eps Small constant for numerical stability (default: 1e-5)
41
+ * @returns Normalized tensor
42
+ */
43
+ export declare function gemma_rmsnorm(ctx: WebInferContext, input: Tensor, weight: Tensor, eps?: number): Promise<Tensor>;
@@ -0,0 +1,21 @@
1
+ /**
2
+ * Paged KV cache operations
3
+ */
4
+ import type { WebInferContext } from '../core/context.ts';
5
+ import type { Tensor } from '../core/tensor.ts';
6
+ import type { PagedKvCache } from '../core/paged-kv-cache.ts';
7
+ /**
8
+ * Append new KV pairs to paged KV cache
9
+ *
10
+ * Appends new key-value pairs to the paged KV cache structure.
11
+ * Each batch item gets its new KV appended to its last page.
12
+ *
13
+ * @param ctx WebInfer context
14
+ * @param paged_kv_cache The paged KV cache to append to
15
+ * @param indptr Indirection pointer array [batch_size + 1]
16
+ * @param indices Page indices array [nnz_pages]
17
+ * @param last_page_len Number of valid tokens in last page for each batch item [batch_size]
18
+ * @param k New keys to append [total_len, num_kv_heads, head_dim]
19
+ * @param v New values to append [total_len, num_kv_heads, head_dim]
20
+ */
21
+ export declare function append_paged_kv_cache(ctx: WebInferContext, paged_kv_cache: PagedKvCache, indptr: Tensor, indices: Tensor, last_page_len: Tensor, k: Tensor, v: Tensor): Promise<void>;
@@ -0,0 +1,69 @@
1
+ /**
2
+ * Prefill attention operations
3
+ */
4
+ import type { WebInferContext } from '../core/context.ts';
5
+ import type { Tensor } from '../core/tensor.ts';
6
+ import type { PagedKvCache } from '../core/paged-kv-cache.ts';
7
+ import { PosEncodingMode } from '../core/types.ts';
8
+ /**
9
+ * Single prefill with KV cache
10
+ *
11
+ * Performs attention for multiple query tokens (prefill phase) with KV cache.
12
+ * Uses flash attention with online softmax and supports causal masking.
13
+ *
14
+ * @param ctx WebInfer context
15
+ * @param q Query tensor of shape [qo_len, num_qo_heads, head_dim]
16
+ * @param k Key tensor of shape [kv_len, num_kv_heads, head_dim]
17
+ * @param v Value tensor of shape [kv_len, num_kv_heads, head_dim]
18
+ * @param causal Whether to apply causal masking (default: true)
19
+ * @param pos_encoding_mode Position encoding mode (default: NONE)
20
+ * @param sm_scale Softmax scale (default: 1/sqrt(head_dim))
21
+ * @param rope_scale RoPE scale (default: 1.0) - not yet implemented
22
+ * @param rope_theta RoPE theta (default: 10000.0) - not yet implemented
23
+ * @returns Output tensor of shape [qo_len, num_qo_heads, head_dim]
24
+ */
25
+ export declare function single_prefill_with_kv_cache(ctx: WebInferContext, q: Tensor, k: Tensor, v: Tensor, causal?: boolean, pos_encoding_mode?: PosEncodingMode, sm_scale?: number, rope_scale?: number, rope_theta?: number): Promise<Tensor>;
26
+ /**
27
+ * Batched prefill with paged KV cache wrapper
28
+ *
29
+ * FlashInfer-compatible wrapper for batched prefill attention with paged KV cache.
30
+ */
31
+ export declare class BatchPrefillWithPagedKVCacheWrapper {
32
+ private ctx;
33
+ private planned;
34
+ private num_qo_heads;
35
+ private num_kv_heads;
36
+ private head_dim;
37
+ private page_size;
38
+ private sm_scale;
39
+ private causal;
40
+ private pos_encoding_mode;
41
+ constructor(ctx: WebInferContext);
42
+ /**
43
+ * Plan the batched prefill operation
44
+ *
45
+ * @param qo_indptr Query indirection pointer [batch_size + 1]
46
+ * @param paged_kv_indptr Paged KV indirection pointer [batch_size + 1]
47
+ * @param paged_kv_indices Paged KV indices [nnz_pages]
48
+ * @param paged_kv_last_page_len Last page lengths [batch_size]
49
+ * @param num_qo_heads Number of query heads
50
+ * @param num_kv_heads Number of KV heads
51
+ * @param head_dim Head dimension
52
+ * @param page_size Page size
53
+ * @param causal Whether to apply causal masking (default: true)
54
+ * @param pos_encoding_mode Position encoding mode (default: NONE)
55
+ */
56
+ plan(qo_indptr: Tensor, paged_kv_indptr: Tensor, paged_kv_indices: Tensor, paged_kv_last_page_len: Tensor, num_qo_heads: number, num_kv_heads: number, head_dim: number, page_size: number, causal?: boolean, pos_encoding_mode?: PosEncodingMode): void;
57
+ /**
58
+ * Run the batched prefill operation
59
+ *
60
+ * @param q Query tensor [total_qo_len, num_qo_heads, head_dim]
61
+ * @param paged_kv_cache Paged KV cache
62
+ * @param qo_indptr Query indirection pointer [batch_size + 1]
63
+ * @param paged_kv_indptr Paged KV indirection pointer [batch_size + 1]
64
+ * @param paged_kv_indices Paged KV indices [nnz_pages]
65
+ * @param paged_kv_last_page_len Last page lengths [batch_size]
66
+ * @returns Output tensor [total_qo_len, num_qo_heads, head_dim]
67
+ */
68
+ run(q: Tensor, paged_kv_cache: PagedKvCache, qo_indptr: Tensor, paged_kv_indptr: Tensor, paged_kv_indices: Tensor, paged_kv_last_page_len: Tensor): Promise<Tensor>;
69
+ }
@@ -0,0 +1,37 @@
1
+ /**
2
+ * RoPE (Rotary Position Embedding) operations
3
+ */
4
+ import type { WebInferContext } from '../core/context.ts';
5
+ import type { Tensor } from '../core/tensor.ts';
6
+ /**
7
+ * Apply RoPE (Rotary Position Embedding) in-place
8
+ *
9
+ * Applies rotary position embeddings to query and key tensors.
10
+ * Modifies tensors in-place.
11
+ *
12
+ * @param ctx WebInfer context
13
+ * @param q Query tensor [nnz, num_qo_heads, head_dim]
14
+ * @param k Key tensor [nnz, num_kv_heads, head_dim]
15
+ * @param indptr Indirection pointer [batch_size + 1]
16
+ * @param offsets Position offsets for each batch item [batch_size]
17
+ * @param rope_scale RoPE scale factor (default: 1.0)
18
+ * @param rope_theta RoPE theta base (default: 10000.0)
19
+ */
20
+ export declare function apply_rope_inplace(ctx: WebInferContext, q: Tensor, k: Tensor, indptr: Tensor, offsets: Tensor, rope_scale?: number, rope_theta?: number): Promise<void>;
21
+ /**
22
+ * Apply Llama 3.1 RoPE variant in-place
23
+ *
24
+ * Applies Llama 3.1's frequency-scaled RoPE for long context support.
25
+ * Modifies tensors in-place.
26
+ *
27
+ * @param ctx WebInfer context
28
+ * @param q Query tensor [nnz, num_qo_heads, head_dim]
29
+ * @param k Key tensor [nnz, num_kv_heads, head_dim]
30
+ * @param indptr Indirection pointer [batch_size + 1]
31
+ * @param offsets Position offsets [batch_size]
32
+ * @param low_freq_factor Low frequency scaling factor (default: 1.0)
33
+ * @param high_freq_factor High frequency scaling factor (default: 4.0)
34
+ * @param old_context_len Original context length before scaling (default: 8192)
35
+ * @param rope_theta RoPE theta base (default: 500000.0)
36
+ */
37
+ export declare function apply_llama31_rope_inplace(ctx: WebInferContext, q: Tensor, k: Tensor, indptr: Tensor, offsets: Tensor, low_freq_factor?: number, high_freq_factor?: number, old_context_len?: number, rope_theta?: number): Promise<void>;
@@ -1,6 +1,55 @@
1
1
  /**
2
- * Sampling Module Exports
2
+ * Sampling operations for token generation
3
3
  */
4
- export { topK, topKFilter } from "./top-k.ts";
5
- export { topPFilter } from "./top-p.ts";
6
- export { sample, sampleGreedy, sampleFromProbs, softmax, applyRepetitionPenalty, minPSamplingFromProbs, topKSamplingFromProbs, topPSamplingFromProbs, topKTopPSamplingFromProbs, topKTopPSamplingFromLogits, topPRenormProbs, topKRenormProbs, topKMaskLogits, type SamplingConfig, } from "./sampler.ts";
4
+ import type { WebInferContext } from '../core/context.ts';
5
+ import { Tensor } from '../core/tensor.ts';
6
+ /**
7
+ * Sample from probability distribution (categorical sampling)
8
+ *
9
+ * @param ctx WebInfer context
10
+ * @param probs Probability tensor [batch_size, vocab_size]
11
+ * @param uniform_samples Uniform random samples [batch_size] in range [0, 1)
12
+ * @returns Sampled token IDs [batch_size]
13
+ */
14
+ export declare function sampling_from_probs(ctx: WebInferContext, probs: Tensor, uniform_samples: Tensor): Promise<Tensor>;
15
+ /**
16
+ * Top-k sampling from probability distribution
17
+ *
18
+ * @param ctx WebInfer context
19
+ * @param probs Probability tensor [batch_size, vocab_size]
20
+ * @param uniform_samples Uniform random samples [batch_size]
21
+ * @param top_k Number of top candidates to consider
22
+ * @returns Sampled token IDs [batch_size]
23
+ */
24
+ export declare function top_k_sampling_from_probs(ctx: WebInferContext, probs: Tensor, uniform_samples: Tensor, top_k: number): Promise<Tensor>;
25
+ /**
26
+ * Top-p (nucleus) sampling from probability distribution
27
+ *
28
+ * @param ctx WebInfer context
29
+ * @param probs Probability tensor [batch_size, vocab_size]
30
+ * @param uniform_samples Uniform random samples [batch_size]
31
+ * @param top_p Cumulative probability threshold (0, 1]
32
+ * @returns Sampled token IDs [batch_size]
33
+ */
34
+ export declare function top_p_sampling_from_probs(ctx: WebInferContext, probs: Tensor, uniform_samples: Tensor, top_p: number): Promise<Tensor>;
35
+ /**
36
+ * Min-p sampling from probability distribution
37
+ *
38
+ * @param ctx WebInfer context
39
+ * @param probs Probability tensor [batch_size, vocab_size]
40
+ * @param uniform_samples Uniform random samples [batch_size]
41
+ * @param min_p Minimum probability threshold relative to max prob
42
+ * @returns Sampled token IDs [batch_size]
43
+ */
44
+ export declare function min_p_sampling_from_probs(ctx: WebInferContext, probs: Tensor, uniform_samples: Tensor, min_p: number): Promise<Tensor>;
45
+ /**
46
+ * Combined top-k and top-p sampling
47
+ *
48
+ * @param ctx WebInfer context
49
+ * @param probs Probability tensor [batch_size, vocab_size]
50
+ * @param uniform_samples Uniform random samples [batch_size]
51
+ * @param top_k Number of top candidates
52
+ * @param top_p Cumulative probability threshold
53
+ * @returns Sampled token IDs [batch_size]
54
+ */
55
+ export declare function top_k_top_p_sampling_from_probs(ctx: WebInferContext, probs: Tensor, uniform_samples: Tensor, top_k: number, top_p: number): Promise<Tensor>;
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "webinfer",
3
- "version": "0.0.3",
3
+ "version": "0.0.4",
4
4
  "description": "High-performance LLM inference kernels for WebGPU",
5
5
  "license": "Apache-2.0",
6
6
  "repository": {
@@ -1,52 +0,0 @@
1
- /**
2
- * Block-Sparse CSR Format
3
- * Enables a single kernel to support all attention variants
4
- */
5
- /**
6
- * Block-Sparse CSR representation of attention mask
7
- */
8
- export interface BlockSparseCSR {
9
- blockSize: number;
10
- rowPtr: Uint32Array;
11
- colIdx: Uint32Array;
12
- blockMask?: Uint8Array;
13
- numRows: number;
14
- numCols: number;
15
- numBlockRows: number;
16
- numBlockCols: number;
17
- nnzBlocks: number;
18
- }
19
- /**
20
- * Attention pattern types
21
- */
22
- export type AttentionPattern = {
23
- type: "dense";
24
- } | {
25
- type: "causal";
26
- } | {
27
- type: "sliding";
28
- windowSize: number;
29
- } | {
30
- type: "global-local";
31
- globalTokens: number[];
32
- localWindow: number;
33
- } | {
34
- type: "custom";
35
- mask: boolean[][];
36
- };
37
- /**
38
- * Build BS-CSR from attention pattern
39
- */
40
- export declare function buildBlockSparseCSR(seqLen: number, pattern: AttentionPattern, blockSize?: number): BlockSparseCSR;
41
- /**
42
- * Calculate sparsity ratio of the mask
43
- */
44
- export declare function getSparsityRatio(csr: BlockSparseCSR): number;
45
- /**
46
- * Estimate memory savings from sparsity
47
- */
48
- export declare function estimateMemorySavings(csr: BlockSparseCSR): {
49
- denseBytes: number;
50
- sparseBytes: number;
51
- savingsRatio: number;
52
- };
@@ -1,16 +0,0 @@
1
- /**
2
- * Causal (autoregressive) attention pattern
3
- * Used in GPT, Llama, and most decoder-only models
4
- */
5
- import { type BlockSparseCSR } from "../format.ts";
6
- /**
7
- * Build causal attention mask in BS-CSR format
8
- * Each query position can only attend to positions <= its own position
9
- */
10
- export declare function buildCausalMask(seqLen: number, blockSize?: number): BlockSparseCSR;
11
- /**
12
- * Get the theoretical sparsity of causal attention
13
- * For a sequence of length N, causal attention has N*(N+1)/2 non-zero elements
14
- * out of N*N total, giving ~50% sparsity for large N
15
- */
16
- export declare function getCausalSparsity(seqLen: number): number;
@@ -1,22 +0,0 @@
1
- /**
2
- * Sliding window attention pattern
3
- * Used in Mistral and other efficient attention models
4
- */
5
- import { type BlockSparseCSR } from "../format.ts";
6
- /**
7
- * Build sliding window attention mask in BS-CSR format
8
- * Each query can only attend to the previous `windowSize` positions
9
- */
10
- export declare function buildSlidingWindowMask(seqLen: number, windowSize: number, blockSize?: number): BlockSparseCSR;
11
- /**
12
- * Get the theoretical sparsity of sliding window attention
13
- * For window size W and sequence length N:
14
- * - First W positions have triangular attention (like causal)
15
- * - Remaining N-W positions have W+1 attention each
16
- */
17
- export declare function getSlidingWindowSparsity(seqLen: number, windowSize: number): number;
18
- /**
19
- * Sliding window with causal constraint
20
- * This is what Mistral uses - combines sliding window with causal masking
21
- */
22
- export declare function buildCausalSlidingWindowMask(seqLen: number, windowSize: number, blockSize?: number): BlockSparseCSR;
@@ -1,65 +0,0 @@
1
- /**
2
- * Tree Attention Pattern
3
- * Used in speculative decoding (Medusa, EAGLE) and tree-based generation
4
- */
5
- import { type BlockSparseCSR } from "../format.ts";
6
- /**
7
- * Tree node structure
8
- */
9
- export interface TreeNode {
10
- /** Token position in sequence */
11
- position: number;
12
- /** Parent position (-1 for root) */
13
- parent: number;
14
- /** Depth in tree (0 for root) */
15
- depth: number;
16
- }
17
- /**
18
- * Tree attention configuration
19
- */
20
- export interface TreeAttentionConfig {
21
- /** Total sequence length including prompt */
22
- seqLen: number;
23
- /** Prompt length (prefix that all tokens attend to) */
24
- promptLen: number;
25
- /** Tree structure for speculative tokens */
26
- tree: TreeNode[];
27
- /** Block size for sparse format */
28
- blockSize?: number;
29
- }
30
- /**
31
- * Build tree attention mask
32
- *
33
- * In tree attention:
34
- * - All tokens attend to the prompt (positions 0 to promptLen-1)
35
- * - Tree tokens attend to their ancestors in the tree
36
- * - Maintains causal property within the tree structure
37
- */
38
- export declare function buildTreeMask(config: TreeAttentionConfig): BlockSparseCSR;
39
- /**
40
- * Build a simple chain tree (linear speculation)
41
- * Each token depends on the previous one
42
- */
43
- export declare function buildChainTree(numSpecTokens: number): TreeNode[];
44
- /**
45
- * Build a wide tree (parallel speculation)
46
- * All speculative tokens depend only on the prompt
47
- */
48
- export declare function buildWideTree(numSpecTokens: number): TreeNode[];
49
- /**
50
- * Build a binary tree for speculation
51
- */
52
- export declare function buildBinaryTree(depth: number): TreeNode[];
53
- /**
54
- * Build Medusa-style tree
55
- * Multiple heads predict tokens at different positions
56
- */
57
- export declare function buildMedusaTree(numHeads: number, tokensPerHead: number): TreeNode[];
58
- /**
59
- * Calculate tree sparsity ratio
60
- */
61
- export declare function getTreeSparsity(config: TreeAttentionConfig): number;
62
- /**
63
- * Validate tree structure
64
- */
65
- export declare function validateTree(tree: TreeNode[]): boolean;
@@ -1,29 +0,0 @@
1
- /**
2
- * Cascaded Inference - TDR-safe attention for very long sequences
3
- * Splits attention computation across multiple passes with browser yields
4
- */
5
- import type { WebInferDevice } from "../core/device.ts";
6
- import { Tensor } from "../core/tensor.ts";
7
- export interface CascadedAttentionConfig {
8
- numHeads: number;
9
- headDim: number;
10
- seqLen: number;
11
- scale?: number;
12
- causal?: boolean;
13
- }
14
- /**
15
- * Cascaded Attention - Safe for very long sequences
16
- * Uses Split-K strategy to prevent TDR (GPU timeout)
17
- *
18
- * @param device WebInfer device
19
- * @param q Query tensor [seqLen, numHeads, headDim]
20
- * @param k Key tensor [seqLen, numHeads, headDim]
21
- * @param v Value tensor [seqLen, numHeads, headDim]
22
- * @param config Attention configuration
23
- * @param onProgress Optional progress callback
24
- */
25
- export declare function cascadedAttention(device: WebInferDevice, q: Tensor, k: Tensor, v: Tensor, config: CascadedAttentionConfig, onProgress?: (chunk: number, total: number) => void): Promise<Tensor>;
26
- /**
27
- * CPU reference implementation for verification
28
- */
29
- export declare function cascadedAttentionCPU(q: Float32Array, k: Float32Array, v: Float32Array, seqLen: number, numHeads: number, headDim: number, causal?: boolean): Float32Array;
@@ -1,30 +0,0 @@
1
- /**
2
- * FlashAttention Implementation for WebGPU
3
- * Memory-efficient attention using online softmax and tiling
4
- */
5
- import type { WebInferDevice } from "../core/device.ts";
6
- import { Tensor } from "../core/tensor.ts";
7
- import type { AttentionPattern } from "./block-sparse/format.ts";
8
- export interface AttentionConfig {
9
- numHeads: number;
10
- headDim: number;
11
- seqLen: number;
12
- scale?: number;
13
- pattern?: AttentionPattern;
14
- blockSize?: number;
15
- }
16
- /**
17
- * FlashAttention forward pass
18
- * Computes: softmax(Q @ K^T / sqrt(d)) @ V
19
- *
20
- * @param device WebInfer device
21
- * @param q Query tensor [batch, seqLen, numHeads, headDim]
22
- * @param k Key tensor [batch, seqLen, numHeads, headDim]
23
- * @param v Value tensor [batch, seqLen, numHeads, headDim]
24
- * @param config Attention configuration
25
- */
26
- export declare function flashAttention(device: WebInferDevice, q: Tensor, k: Tensor, v: Tensor, config: AttentionConfig): Promise<Tensor>;
27
- /**
28
- * CPU reference implementation for verification
29
- */
30
- export declare function attentionCPU(q: Float32Array, k: Float32Array, v: Float32Array, seqLen: number, numHeads: number, headDim: number, causal?: boolean): Float32Array;