webinfer 0.0.2 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. package/README.md +38 -33
  2. package/dist/activation/index.d.ts +30 -0
  3. package/dist/core/context.d.ts +60 -0
  4. package/dist/core/paged-kv-cache.d.ts +33 -0
  5. package/dist/core/tensor.d.ts +38 -19
  6. package/dist/core/types.d.ts +27 -0
  7. package/dist/decode/index.d.ts +65 -0
  8. package/dist/gemm/index.d.ts +25 -0
  9. package/dist/index.d.ts +26 -19
  10. package/dist/index.js +2508 -3885
  11. package/dist/kernels/activation.wgsl.d.ts +14 -0
  12. package/dist/kernels/batch-decode-paged.wgsl.d.ts +12 -0
  13. package/dist/kernels/batch-prefill-paged.wgsl.d.ts +13 -0
  14. package/dist/kernels/decode-attention.wgsl.d.ts +16 -0
  15. package/dist/kernels/gemm.wgsl.d.ts +17 -0
  16. package/dist/kernels/page.wgsl.d.ts +10 -0
  17. package/dist/kernels/prefill-attention.wgsl.d.ts +17 -0
  18. package/dist/kernels/rmsnorm.wgsl.d.ts +10 -0
  19. package/dist/kernels/rope.wgsl.d.ts +19 -0
  20. package/dist/kernels/sampling.wgsl.d.ts +23 -0
  21. package/dist/norm/index.d.ts +43 -0
  22. package/dist/page/index.d.ts +21 -0
  23. package/dist/prefill/index.d.ts +69 -0
  24. package/dist/rope/index.d.ts +37 -0
  25. package/dist/sampling/index.d.ts +53 -4
  26. package/package.json +1 -1
  27. package/dist/attention/block-sparse/format.d.ts +0 -52
  28. package/dist/attention/block-sparse/patterns/causal.d.ts +0 -16
  29. package/dist/attention/block-sparse/patterns/sliding.d.ts +0 -22
  30. package/dist/attention/flash-attention.d.ts +0 -30
  31. package/dist/attention/index.d.ts +0 -9
  32. package/dist/attention/paged-kv/block-manager.d.ts +0 -102
  33. package/dist/attention/paged-kv/index.d.ts +0 -5
  34. package/dist/attention/paged-kv/page-table.d.ts +0 -99
  35. package/dist/attention/scheduler.d.ts +0 -40
  36. package/dist/core/buffer-pool.d.ts +0 -18
  37. package/dist/core/device.d.ts +0 -23
  38. package/dist/inference/engine.d.ts +0 -69
  39. package/dist/inference/generate.d.ts +0 -30
  40. package/dist/inference/index.d.ts +0 -7
  41. package/dist/inference/types.d.ts +0 -161
  42. package/dist/jit/compiler.d.ts +0 -23
  43. package/dist/jit/kernel-cache.d.ts +0 -21
  44. package/dist/model/gguf.d.ts +0 -90
  45. package/dist/model/index.d.ts +0 -16
  46. package/dist/model/safetensors.d.ts +0 -38
  47. package/dist/model/types.d.ts +0 -182
  48. package/dist/ops/activations.d.ts +0 -43
  49. package/dist/ops/elementwise.d.ts +0 -38
  50. package/dist/ops/embedding.d.ts +0 -30
  51. package/dist/ops/matmul.d.ts +0 -21
  52. package/dist/ops/normalization.d.ts +0 -24
  53. package/dist/ops/reshape.d.ts +0 -39
  54. package/dist/ops/rope.d.ts +0 -32
  55. package/dist/ops/softmax.d.ts +0 -18
  56. package/dist/quantization/index.d.ts +0 -6
  57. package/dist/quantization/qmatmul.d.ts +0 -38
  58. package/dist/quantization/quantize.d.ts +0 -52
  59. package/dist/sampling/sampler.d.ts +0 -39
  60. package/dist/sampling/top-k.d.ts +0 -24
  61. package/dist/sampling/top-p.d.ts +0 -14
package/README.md CHANGED
@@ -2,53 +2,58 @@
 
  High-performance LLM inference kernels for WebGPU.
 
+ WebGPU implementation of FlashInfer APIs for browser-based LLM inference.
+
+ ## Features
+
+ - **Flash Attention for WebGPU**: Online softmax and paged KV cache for memory-efficient attention
+ - **Subgroup Optimization**: Uses WebGPU subgroup operations for 2-4x faster reductions when available
+ - **Fused Kernels**: Reduces memory bandwidth with operations like `silu_and_mul`, `fused_add_rmsnorm`
+ - **JIT Compilation**: Runtime kernel generation and pipeline caching
+ - **f16 Safe Accumulation**: Uses f32 accumulators for f16 computation to prevent precision loss
+ - **FlashInfer Compatible API**: Drop-in replacement for FlashInfer with identical naming
+
  ## Install
 
  ```bash
  npm install webinfer
  ```
 
- ## Usage
-
- ```typescript
- import { WebInferDevice, Tensor, matmul, flashAttention } from 'webinfer';
+ ## API Status
 
- // Initialize WebGPU device
- const device = await WebInferDevice.create();
-
- // Matrix multiplication
- const a = Tensor.rand(device, [1024, 1024]);
- const b = Tensor.rand(device, [1024, 1024]);
- const c = await matmul(device, a, b);
-
- // Read result back to CPU
- const result = await c.toArray();
- ```
+ **Phase 1: Core + Simple Ops**
+ - `webinfer.WebInferContext.create()` - Initialize WebGPU context
+ - `webinfer.norm.rmsnorm()` - Root mean square normalization
+ - `webinfer.activation.silu_and_mul()` - SiLU activation with gating
+ - `webinfer.activation.gelu_and_mul()` - GELU activation with gating
 
- ## Operations
+ **Phase 2: Single Attention**
+ - `webinfer.decode.single_decode_with_kv_cache()` - Flash attention for decode phase
+ - `webinfer.prefill.single_prefill_with_kv_cache()` - Flash attention for prefill phase with causal masking
 
- | Category | Operations |
- |----------|------------|
- | **Core** | matmul, flashAttention |
- | **Normalization** | rmsNorm, layerNorm |
- | **Activations** | gelu, silu, relu, softmax |
- | **Position** | rope (rotary embeddings) |
- | **Sampling** | topK, topP, sample |
- | **Quantization** | INT4/INT8 quantized matmul |
- | **Model Loading** | SafeTensors, GGUF |
+ **Phase 3: Paged KV + Batched Operations**
+ - `webinfer.page.append_paged_kv_cache()` - Append new KV pairs to paged cache
+ - `webinfer.decode.BatchDecodeWithPagedKVCacheWrapper` - Batched decode with paged KV cache
+ - `webinfer.prefill.BatchPrefillWithPagedKVCacheWrapper` - Batched prefill with paged KV cache
 
- ## Requirements
+ **Phase 4: Additional Operations**
+ - `webinfer.rope.apply_rope_inplace()` - Apply RoPE position encoding
+ - `webinfer.rope.apply_llama31_rope_inplace()` - Llama 3.1 RoPE variant for long context
+ - `webinfer.sampling.sampling_from_probs()` - Categorical sampling
+ - `webinfer.sampling.top_k_sampling_from_probs()` - Top-k sampling
+ - `webinfer.sampling.top_p_sampling_from_probs()` - Top-p (nucleus) sampling
+ - `webinfer.sampling.min_p_sampling_from_probs()` - Min-p sampling
+ - `webinfer.sampling.top_k_top_p_sampling_from_probs()` - Combined top-k and top-p
+ - `webinfer.gemm.bmm_fp16()` - Batched matrix multiplication (fp16)
+ - `webinfer.gemm.bmm_fp32()` - Batched matrix multiplication (fp32)
 
- - Browser with WebGPU support (Chrome 113+, Edge 113+)
- - Or Node.js with `@aspect-build/aspect-cli` for server-side WebGPU
+ See [dev/design.md](dev/design.md) for full design and implementation plan.
 
- ## Benchmarks
+ ## Release
 
  ```bash
- git clone https://github.com/guan404ming/webinfer
- cd webinfer
- bun install
- bun run bench
+ bun run build
+ npm publish
  ```
 
  ## License
package/dist/activation/index.d.ts ADDED
@@ -0,0 +1,30 @@
+ /**
+ * Activation functions
+ */
+ import type { WebInferContext } from '../core/context.ts';
+ import type { Tensor } from '../core/tensor.ts';
+ /**
+ * SiLU (Swish) activation with gating
+ *
+ * Input tensor has shape [..., 2 * hidden_size]
+ * Splits into gate[..., :hidden_size] and up[..., hidden_size:]
+ * Computes: output = silu(gate) * up
+ * Where silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
+ *
+ * @param ctx WebInfer context
+ * @param input Input tensor of shape [..., 2 * hidden_size]
+ * @returns Activated tensor of shape [..., hidden_size]
+ */
+ export declare function silu_and_mul(ctx: WebInferContext, input: Tensor): Promise<Tensor>;
+ /**
+ * GELU activation with gating
+ *
+ * Input tensor has shape [..., 2 * hidden_size]
+ * Computes: output = gelu(gate) * up
+ * Where gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
+ *
+ * @param ctx WebInfer context
+ * @param input Input tensor of shape [..., 2 * hidden_size]
+ * @returns Activated tensor of shape [..., hidden_size]
+ */
+ export declare function gelu_and_mul(ctx: WebInferContext, input: Tensor): Promise<Tensor>;
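The gate/up split documented above can be checked on the CPU. This helper mirrors the documented formula; it is for verification only and is not part of the package:

```typescript
// CPU reference for the documented semantics: the last dimension splits into
// gate = row[0..h) and up = row[h..2h), and output[i] = silu(gate[i]) * up[i]
// with silu(x) = x / (1 + exp(-x)). Not webinfer's code; reference only.
function siluAndMulRef(row: Float32Array): Float32Array {
  const h = row.length / 2;
  const out = new Float32Array(h);
  for (let i = 0; i < h; i++) {
    const gate = row[i];
    const up = row[h + i];
    out[i] = (gate / (1 + Math.exp(-gate))) * up;
  }
  return out;
}
```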
package/dist/core/context.d.ts ADDED
@@ -0,0 +1,60 @@
+ /**
+ * WebInferContext manages WebGPU device and provides core functionality
+ */
+ import type { DType } from './types.ts';
+ import { Tensor } from './tensor.ts';
+ /**
+ * Device vendor detection
+ */
+ export type Vendor = 'apple' | 'nvidia' | 'amd' | 'intel' | 'unknown';
+ /**
+ * Device features detected at runtime
+ */
+ export interface DeviceFeatures {
+ readonly vendor: Vendor;
+ readonly hasSubgroups: boolean;
+ readonly hasF16: boolean;
+ readonly maxWorkgroupSize: number;
+ readonly maxStorageBufferSize: number;
+ readonly maxBufferSize: number;
+ readonly maxWorkgroupsPerDimension: number;
+ }
+ /**
+ * WebInferContext - main context for all WebInfer operations
+ */
+ export declare class WebInferContext {
+ readonly device: GPUDevice;
+ readonly features: DeviceFeatures;
+ private pipelineCache;
+ private disposed;
+ private constructor();
+ /**
+ * Create a new WebInferContext
+ */
+ static create(): Promise<WebInferContext>;
+ /**
+ * Create a GPU buffer
+ */
+ createBuffer(size: number, usage?: GPUBufferUsageFlags): GPUBuffer;
+ /**
+ * Create a tensor with uninitialized data
+ */
+ createTensor(shape: number[], dtype: DType): Tensor;
+ /**
+ * Create a tensor from TypedArray data
+ */
+ createTensorFromData(data: Float32Array | Uint16Array, shape: number[], dtype: DType): Tensor;
+ /**
+ * Get or create a cached compute pipeline
+ */
+ getOrCreatePipeline(key: string, createFn: () => Promise<GPUComputePipeline> | GPUComputePipeline): Promise<GPUComputePipeline>;
+ /**
+ * Clear pipeline cache
+ */
+ clearPipelineCache(): void;
+ /**
+ * Dispose the context and cleanup resources
+ */
+ dispose(): void;
+ private checkDisposed;
+ }
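A sketch of the context lifecycle these declarations imply (buffer sizes and dtypes here are arbitrary examples):

```typescript
// Sketch of the declared WebInferContext lifecycle; values are illustrative.
import { WebInferContext } from 'webinfer';

const ctx = await WebInferContext.create();

// Runtime feature detection drives kernel selection (subgroups, f16).
console.log(ctx.features.vendor, ctx.features.hasSubgroups, ctx.features.hasF16);

const scratch = ctx.createTensor([64, 64], 'float32'); // uninitialized
const ones = ctx.createTensorFromData(new Float32Array(16).fill(1), [4, 4], 'float32');

scratch.dispose();
ones.dispose();
ctx.dispose(); // also drops the cached pipelines
```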
package/dist/core/paged-kv-cache.d.ts ADDED
@@ -0,0 +1,33 @@
+ /**
+ * Paged KV Cache implementation
+ */
+ import type { Tensor } from './tensor.ts';
+ /**
+ * Paged KV Cache for efficient memory management
+ *
+ * Layout: [num_pages, 2, page_size, num_kv_heads, head_dim]
+ * where dim 1 has [0] = keys, [1] = values
+ */
+ export interface PagedKvCache {
+ /**
+ * The backing tensor for the paged KV cache
+ * Shape: [num_pages, 2, page_size, num_kv_heads, head_dim]
+ */
+ data: Tensor;
+ /**
+ * Number of tokens per page
+ */
+ page_size: number;
+ /**
+ * Number of KV heads (for GQA/MQA)
+ */
+ num_kv_heads: number;
+ /**
+ * Dimension of each head
+ */
+ head_dim: number;
+ }
+ /**
+ * Create a paged KV cache
+ */
+ export declare function createPagedKvCache(data: Tensor, page_size: number, num_kv_heads: number, head_dim: number): PagedKvCache;
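Allocating a cache with the layout documented above might look like this (capacity numbers and the dtype are made up):

```typescript
// Sketch: backing tensor uses the documented layout
// [num_pages, 2, page_size, num_kv_heads, head_dim], dim 1: 0 = keys, 1 = values.
import { WebInferContext, createPagedKvCache } from 'webinfer';

const ctx = await WebInferContext.create();
const num_pages = 128, page_size = 16, num_kv_heads = 8, head_dim = 64; // illustrative

const backing = ctx.createTensor(
  [num_pages, 2, page_size, num_kv_heads, head_dim],
  'float16', // dtype illustrative; requires ctx.features.hasF16
);
const kv = createPagedKvCache(backing, page_size, num_kv_heads, head_dim);
// Total token capacity: num_pages * page_size = 2048 tokens.
```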
package/dist/core/tensor.d.ts CHANGED
@@ -1,25 +1,44 @@
  /**
- * Tensor - GPU-backed multidimensional array
+ * Tensor wrapper for WebGPU buffers
+ */
+ import type { DType } from './types.ts';
+ /**
+ * Compute strides from shape (row-major order)
+ */
+ export declare function computeStrides(shape: number[]): number[];
+ /**
+ * Compute total number of elements from shape
+ */
+ export declare function computeSize(shape: number[]): number;
+ /**
+ * Tensor class wrapping a GPUBuffer with shape and dtype information
  */
- import type { WebInferDevice } from "./device.ts";
- export type DType = "f32" | "f16" | "i32" | "u32";
  export declare class Tensor {
- private _device;
- private _shape;
- private _dtype;
- private _buffer;
- private _disposed;
- constructor(device: WebInferDevice, shape: number[], dtype?: DType, data?: Float32Array | Uint32Array | Int32Array);
- static fromArray(device: WebInferDevice, shape: number[], data: Float32Array, dtype?: DType): Promise<Tensor>;
- static zeros(device: WebInferDevice, shape: number[], dtype?: DType): Tensor;
- static rand(device: WebInferDevice, shape: number[], dtype?: DType): Tensor;
- get shape(): readonly number[];
- get dtype(): DType;
- get numel(): number;
- get byteSize(): number;
- get buffer(): GPUBuffer;
- get device(): WebInferDevice;
- toArray(): Promise<Float32Array>;
+ readonly buffer: GPUBuffer;
+ readonly shape: readonly number[];
+ readonly dtype: DType;
+ readonly strides: readonly number[];
+ readonly size: number;
+ readonly byteSize: number;
+ constructor(buffer: GPUBuffer, shape: number[], dtype: DType, strides?: number[]);
+ /**
+ * Get the number of dimensions
+ */
+ get ndim(): number;
+ /**
+ * Check if tensor is contiguous (row-major)
+ */
+ get isContiguous(): boolean;
+ /**
+ * Get a view of the tensor with new shape (must have same size)
+ */
  reshape(newShape: number[]): Tensor;
+ /**
+ * Destroy the underlying GPU buffer
+ */
  dispose(): void;
+ /**
+ * Get string representation
+ */
+ toString(): string;
  }
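The stride helpers above follow the usual row-major convention. A reference implementation of that convention (not the package's code):

```typescript
// Row-major stride semantics as documented above: stride[i] is the element
// step for dimension i, so strides([2, 3, 4]) = [12, 4, 1] and size = 24.
// Reference only, not webinfer's implementation.
function computeStridesRef(shape: number[]): number[] {
  const strides = new Array<number>(shape.length);
  let acc = 1;
  for (let i = shape.length - 1; i >= 0; i--) {
    strides[i] = acc;
    acc *= shape[i];
  }
  return strides;
}
```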
package/dist/core/types.d.ts ADDED
@@ -0,0 +1,27 @@
+ /**
+ * Core type definitions for WebInfer
+ */
+ /**
+ * Data type for tensors
+ */
+ export type DType = 'float32' | 'float16';
+ /**
+ * Position encoding modes for attention
+ */
+ export declare enum PosEncodingMode {
+ NONE = 0,
+ ROPE_LLAMA = 1,
+ ALIBI = 2
+ }
+ /**
+ * Get the byte size for a given dtype
+ */
+ export declare function dtypeByteSize(dtype: DType): number;
+ /**
+ * Get the GPUTextureFormat for a given dtype
+ */
+ export declare function dtypeToGPUFormat(dtype: DType): string;
+ /**
+ * Get WGSL type string for a given dtype
+ */
+ export declare function dtypeToWGSL(dtype: DType): 'f32' | 'f16';
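The dtype helpers presumably map to the standard IEEE widths and WGSL scalar names; the expected mapping (an assumption, since the implementation body is not in this diff):

```typescript
// Assumed mapping for the declared helpers: standard IEEE byte widths and
// the corresponding WGSL scalar names. Reference only.
type DType = 'float32' | 'float16';

function dtypeByteSizeRef(dtype: DType): number {
  return dtype === 'float32' ? 4 : 2;
}
function dtypeToWGSLRef(dtype: DType): 'f32' | 'f16' {
  return dtype === 'float32' ? 'f32' : 'f16';
}
```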
package/dist/decode/index.d.ts ADDED
@@ -0,0 +1,65 @@
+ /**
+ * Decode attention operations
+ */
+ import type { WebInferContext } from '../core/context.ts';
+ import type { Tensor } from '../core/tensor.ts';
+ import type { PagedKvCache } from '../core/paged-kv-cache.ts';
+ import { PosEncodingMode } from '../core/types.ts';
+ /**
+ * Single decode with KV cache
+ *
+ * Performs attention for a single token (decode phase) with KV cache.
+ * Uses flash attention with online softmax to avoid materializing full attention matrix.
+ *
+ * @param ctx WebInfer context
+ * @param q Query tensor of shape [num_qo_heads, head_dim]
+ * @param k_cache Key cache tensor of shape [seq_len, num_kv_heads, head_dim]
+ * @param v_cache Value cache tensor of shape [seq_len, num_kv_heads, head_dim]
+ * @param pos_encoding_mode Position encoding mode (default: NONE)
+ * @param sm_scale Softmax scale (default: 1/sqrt(head_dim))
+ * @param rope_scale RoPE scale (default: 1.0) - not yet implemented
+ * @param rope_theta RoPE theta (default: 10000.0) - not yet implemented
+ * @returns Output tensor of shape [num_qo_heads, head_dim]
+ */
+ export declare function single_decode_with_kv_cache(ctx: WebInferContext, q: Tensor, k_cache: Tensor, v_cache: Tensor, pos_encoding_mode?: PosEncodingMode, sm_scale?: number, rope_scale?: number, rope_theta?: number): Promise<Tensor>;
+ /**
+ * Batched decode with paged KV cache wrapper
+ *
+ * FlashInfer-compatible wrapper for batched decode attention with paged KV cache.
+ */
+ export declare class BatchDecodeWithPagedKVCacheWrapper {
+ private ctx;
+ private planned;
+ private num_qo_heads;
+ private num_kv_heads;
+ private head_dim;
+ private page_size;
+ private batch_size;
+ private sm_scale;
+ private pos_encoding_mode;
+ constructor(ctx: WebInferContext);
+ /**
+ * Plan the batched decode operation
+ *
+ * @param indptr Indirection pointer [batch_size + 1]
+ * @param indices Page indices [nnz_pages]
+ * @param last_page_len Last page lengths [batch_size]
+ * @param num_qo_heads Number of query heads
+ * @param num_kv_heads Number of KV heads
+ * @param head_dim Head dimension
+ * @param page_size Page size
+ * @param pos_encoding_mode Position encoding mode (default: NONE)
+ */
+ plan(indptr: Tensor, indices: Tensor, last_page_len: Tensor, num_qo_heads: number, num_kv_heads: number, head_dim: number, page_size: number, pos_encoding_mode?: PosEncodingMode): void;
+ /**
+ * Run the batched decode operation
+ *
+ * @param q Query tensor [batch_size, num_qo_heads, head_dim]
+ * @param paged_kv_cache Paged KV cache
+ * @param indptr Indirection pointer [batch_size + 1]
+ * @param indices Page indices [nnz_pages]
+ * @param last_page_len Last page lengths [batch_size]
+ * @returns Output tensor [batch_size, num_qo_heads, head_dim]
+ */
+ run(q: Tensor, paged_kv_cache: PagedKvCache, indptr: Tensor, indices: Tensor, last_page_len: Tensor): Promise<Tensor>;
+ }
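The wrapper follows FlashInfer's plan/run split: batch geometry is fixed once in `plan()`, then `run()` executes per decode step. A sketch against the declared signatures; the page-table tensors are assumed to exist, since building them is not shown in this diff:

```typescript
// Sketch of the declared plan/run flow. Head counts, head_dim, and page_size
// are illustrative; indptr/indices/last_page_len construction is assumed.
import {
  WebInferContext,
  Tensor,
  PagedKvCache,
  BatchDecodeWithPagedKVCacheWrapper,
} from 'webinfer';

declare const indptr: Tensor;        // [batch_size + 1]
declare const indices: Tensor;       // [nnz_pages]
declare const last_page_len: Tensor; // [batch_size]
declare const q: Tensor;             // [batch_size, num_qo_heads, head_dim]
declare const kv: PagedKvCache;

const ctx = await WebInferContext.create();
const wrapper = new BatchDecodeWithPagedKVCacheWrapper(ctx);

// plan() once per batch geometry (e.g. GQA with 32 query / 8 KV heads)...
wrapper.plan(indptr, indices, last_page_len, 32, 8, 128, 16);

// ...then run() per decode step.
const out = await wrapper.run(q, kv, indptr, indices, last_page_len);
```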
package/dist/gemm/index.d.ts ADDED
@@ -0,0 +1,25 @@
+ /**
+ * GEMM (General Matrix Multiplication) operations
+ */
+ import type { WebInferContext } from '../core/context.ts';
+ import type { Tensor } from '../core/tensor.ts';
+ /**
+ * Batched matrix multiplication for fp16
+ *
+ * Computes C = A @ B for batched matrices
+ *
+ * @param ctx WebInfer context
+ * @param a Input tensor A [B, M, K]
+ * @param b Input tensor B [B, K, N]
+ * @returns Output tensor C [B, M, N]
+ */
+ export declare function bmm_fp16(ctx: WebInferContext, a: Tensor, b: Tensor): Promise<Tensor>;
+ /**
+ * Batched matrix multiplication for fp32
+ *
+ * @param ctx WebInfer context
+ * @param a Input tensor A [B, M, K]
+ * @param b Input tensor B [B, K, N]
+ * @returns Output tensor C [B, M, N]
+ */
+ export declare function bmm_fp32(ctx: WebInferContext, a: Tensor, b: Tensor): Promise<Tensor>;
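A shape-level sketch of the batched contract above (sizes are arbitrary):

```typescript
// Sketch: [B, M, K] @ [B, K, N] -> [B, M, N]. With all-ones inputs every
// output element equals K, which makes a handy smoke test.
import { WebInferContext, bmm_fp32 } from 'webinfer';

const ctx = await WebInferContext.create();
const B = 2, M = 4, K = 8, N = 4; // illustrative sizes

const a = ctx.createTensorFromData(new Float32Array(B * M * K).fill(1), [B, M, K], 'float32');
const b = ctx.createTensorFromData(new Float32Array(B * K * N).fill(1), [B, K, N], 'float32');

const c = await bmm_fp32(ctx, a, b); // shape [2, 4, 4], every element 8
```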
package/dist/index.d.ts CHANGED
@@ -1,22 +1,29 @@
  /**
  * WebInfer - High-performance LLM inference kernels for WebGPU
- * "The cuDNN/FlashInfer of WebGPU"
+ *
+ * WebGPU implementation of FlashInfer APIs for browser-based LLM inference.
  */
- export { WebInferDevice, type DeviceInfo } from "./core/device.ts";
- export { Tensor, type DType } from "./core/tensor.ts";
- export { BufferPool } from "./core/buffer-pool.ts";
- export { KernelCache, type CacheStats } from "./jit/kernel-cache.ts";
- export { WGSLCompiler, type MatMulConfig } from "./jit/compiler.ts";
- export { matmul, matmulCPU, getMatMulCacheStats } from "./ops/matmul.ts";
- export { layerNorm, layerNormCPU, rmsNorm, rmsNormCPU, } from "./ops/normalization.ts";
- export { rope, ropeCPU, computeRoPEFrequencies, type RoPEConfig, } from "./ops/rope.ts";
- export { gelu, geluCPU, geluExactCPU, silu, siluCPU, relu, reluCPU, sigmoidCPU, } from "./ops/activations.ts";
- export { softmaxGPU, softmaxCPU, logSoftmaxCPU } from "./ops/softmax.ts";
- export { add, addCPU, mul, mulCPU, scale, scaleCPU, addScalarCPU, fmaCPU, } from "./ops/elementwise.ts";
- export { embedding, embeddingCPU, batchedEmbeddingCPU, } from "./ops/embedding.ts";
- export { transpose2D, transpose2DCPU, transposeCPU, reshapeCPU, permuteCPU, } from "./ops/reshape.ts";
- export { quantizeToInt8, quantizeToInt4, dequantizeInt8, dequantizeInt4, quantizationError, getMemorySavings, qmatmulInt8CPU, qmatmulInt4CPU, qmatmulInt8BlockCPU, estimateQMatMulFlops, estimateQMatMulBandwidth, type QuantConfig, type QuantizedTensor, } from "./quantization/index.ts";
- export { flashAttention, attentionCPU, type AttentionConfig, buildBlockSparseCSR, getSparsityRatio, estimateMemorySavings, type BlockSparseCSR, type AttentionPattern, buildCausalMask, getCausalSparsity, buildSlidingWindowMask, buildCausalSlidingWindowMask, getSlidingWindowSparsity, AttentionScheduler, type ChunkPlan, PagedKVCache, type PagedKVCacheConfig, type SequenceEntry, BlockManager, ContinuousBatchScheduler, type BlockManagerConfig, type AllocationPolicy, type AllocationRequest, } from "./attention/index.ts";
- export { topK, topKCPU, topKFilter, topPFilter, topPFilterCPU, sample, sampleCPU, sampleGreedy, sampleFromProbs, softmax, applyRepetitionPenalty, type SamplingConfig, } from "./sampling/index.ts";
- export { type ModelFormat, type SafetensorsDType, GGUFQuantType, GGUFMetadataValueType, type TensorInfo, type SafetensorsHeader, type ModelMetadata, type LoadedModel, type LoadOptions, parseSafetensorsHeader, loadSafetensors, loadSafetensorsFromUrl, isSafetensors, parseGGUFHeader, loadGGUF, loadGGUFFromUrl, loadGGUFTensor, isGGUF, loadModel, } from "./model/index.ts";
- export { type ModelConfig, type InferenceConfig, type GenerationConfig, type GenerationResult, type StreamToken, type FinishReason, type ForwardResult, DEFAULT_GENERATION_CONFIG, normalizeGenerationConfig, InferenceEngine, generate, generateStream, greedyDecode, sampleNextToken, } from "./inference/index.ts";
+ export { WebInferContext } from './core/context.ts';
+ export { Tensor } from './core/tensor.ts';
+ export type { DType } from './core/types.ts';
+ export { PosEncodingMode } from './core/types.ts';
+ export type { PagedKvCache } from './core/paged-kv-cache.ts';
+ export { createPagedKvCache } from './core/paged-kv-cache.ts';
+ export { WebInferContext as create_context } from './core/context.ts';
+ import * as norm from './norm/index.ts';
+ import * as activation from './activation/index.ts';
+ import * as decode from './decode/index.ts';
+ import * as prefill from './prefill/index.ts';
+ import * as page from './page/index.ts';
+ import * as rope from './rope/index.ts';
+ import * as sampling from './sampling/index.ts';
+ import * as gemm from './gemm/index.ts';
+ export { norm, activation, decode, prefill, page, rope, sampling, gemm };
+ export declare const rmsnorm: typeof norm.rmsnorm, fused_add_rmsnorm: typeof norm.fused_add_rmsnorm, gemma_rmsnorm: typeof norm.gemma_rmsnorm;
+ export declare const silu_and_mul: typeof activation.silu_and_mul, gelu_and_mul: typeof activation.gelu_and_mul;
+ export declare const single_decode_with_kv_cache: typeof decode.single_decode_with_kv_cache, BatchDecodeWithPagedKVCacheWrapper: typeof decode.BatchDecodeWithPagedKVCacheWrapper;
+ export declare const single_prefill_with_kv_cache: typeof prefill.single_prefill_with_kv_cache, BatchPrefillWithPagedKVCacheWrapper: typeof prefill.BatchPrefillWithPagedKVCacheWrapper;
+ export declare const append_paged_kv_cache: typeof page.append_paged_kv_cache;
+ export declare const apply_rope_inplace: typeof rope.apply_rope_inplace, apply_llama31_rope_inplace: typeof rope.apply_llama31_rope_inplace;
+ export declare const sampling_from_probs: typeof sampling.sampling_from_probs, top_k_sampling_from_probs: typeof sampling.top_k_sampling_from_probs, top_p_sampling_from_probs: typeof sampling.top_p_sampling_from_probs, min_p_sampling_from_probs: typeof sampling.min_p_sampling_from_probs, top_k_top_p_sampling_from_probs: typeof sampling.top_k_top_p_sampling_from_probs;
+ export declare const bmm_fp16: typeof gemm.bmm_fp16, bmm_fp32: typeof gemm.bmm_fp32;
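The re-exports above give every op two equivalent spellings. A sketch using `silu_and_mul`, whose signature appears earlier in this diff:

```typescript
// Both import styles resolve to the same function per the re-exports above.
import { activation, silu_and_mul, WebInferContext, Tensor } from 'webinfer';

declare const ctx: WebInferContext;
declare const input: Tensor; // [..., 2 * hidden_size]

const a = await activation.silu_and_mul(ctx, input); // namespaced
const b = await silu_and_mul(ctx, input);            // flat alias
```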