webinfer 0.0.2 → 0.0.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +18 -28
- package/dist/attention/block-sparse/patterns/tree.d.ts +65 -0
- package/dist/attention/cascaded-inference.d.ts +29 -0
- package/dist/attention/index.d.ts +112 -3
- package/dist/attention/paged-attention.d.ts +40 -0
- package/dist/attention/paged-kv/index.d.ts +2 -2
- package/dist/attention/paged-kv/page-table.d.ts +66 -0
- package/dist/core/tdr.d.ts +114 -0
- package/dist/index.d.ts +13 -11
- package/dist/index.js +3638 -2582
- package/dist/inference/engine.d.ts +1 -1
- package/dist/inference/index.d.ts +1 -1
- package/dist/jit/compiler.d.ts +1 -1
- package/dist/model/gguf.d.ts +1 -1
- package/dist/model/index.d.ts +4 -4
- package/dist/model/safetensors.d.ts +1 -1
- package/dist/ops/normalization.d.ts +39 -0
- package/dist/quantization/index.d.ts +2 -2
- package/dist/sampling/beam-search.d.ts +87 -0
- package/dist/sampling/index.d.ts +3 -3
- package/dist/sampling/sampler.d.ts +33 -0
- package/dist/sampling/speculative.d.ts +65 -0
- package/dist/tvm/adapter.d.ts +81 -0
- package/dist/tvm/index.d.ts +8 -0
- package/dist/tvm/ops.d.ts +26 -0
- package/dist/tvm/types.d.ts +35 -0
- package/package.json +1 -1
package/README.md
CHANGED
|
@@ -11,44 +11,34 @@ npm install webinfer
|
|
|
11
11
|
## Usage
|
|
12
12
|
|
|
13
13
|
```typescript
|
|
14
|
-
import { WebInferDevice,
|
|
14
|
+
import { WebInferDevice, attention } from 'webinfer';
|
|
15
15
|
|
|
16
|
-
// Initialize WebGPU device
|
|
17
16
|
const device = await WebInferDevice.create();
|
|
18
17
|
|
|
19
|
-
//
|
|
20
|
-
const
|
|
21
|
-
const
|
|
22
|
-
const
|
|
18
|
+
// Single decode attention
|
|
19
|
+
const q = new Float32Array(32 * 128); // [num_qo_heads, head_dim]
|
|
20
|
+
const k = new Float32Array(2048 * 32 * 128); // [kv_len, num_kv_heads, head_dim]
|
|
21
|
+
const v = new Float32Array(2048 * 32 * 128);
|
|
23
22
|
|
|
24
|
-
|
|
25
|
-
const result = await c.toArray();
|
|
23
|
+
const output = await attention(device, { q, k, v });
|
|
26
24
|
```
|
|
27
25
|
|
|
28
|
-
##
|
|
26
|
+
## API
|
|
29
27
|
|
|
30
|
-
| Category |
|
|
31
|
-
|
|
32
|
-
| **
|
|
33
|
-
| **
|
|
34
|
-
| **
|
|
35
|
-
| **
|
|
36
|
-
| **
|
|
37
|
-
| **
|
|
38
|
-
| **Model Loading** | SafeTensors, GGUF |
|
|
28
|
+
| Category | Exports |
|
|
29
|
+
|----------|---------|
|
|
30
|
+
| **Attention** | `attention`, `BatchAttention`, `AttentionKernel`, `cascadedAttention` |
|
|
31
|
+
| **KV Cache** | `PagedKVCache`, `BlockManager`, `pagedAttention` |
|
|
32
|
+
| **Patterns** | `buildCausalMask`, `buildSlidingWindowMask`, `buildBlockSparseCSR` |
|
|
33
|
+
| **Sampling** | `topKSamplingFromProbs`, `topPSamplingFromProbs`, `minPSamplingFromProbs`, `topKTopPSamplingFromLogits` |
|
|
34
|
+
| **Normalization** | `rmsNorm`, `layerNorm`, `fusedAddRmsNorm`, `gemmaRmsNorm` |
|
|
35
|
+
| **Core** | `matmul`, `rope`, `gelu`, `silu`, `softmax` |
|
|
39
36
|
|
|
40
|
-
##
|
|
41
|
-
|
|
42
|
-
- Browser with WebGPU support (Chrome 113+, Edge 113+)
|
|
43
|
-
- Or Node.js with `@aspect-build/aspect-cli` for server-side WebGPU
|
|
44
|
-
|
|
45
|
-
## Benchmarks
|
|
37
|
+
## Release
|
|
46
38
|
|
|
47
39
|
```bash
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
bun install
|
|
51
|
-
bun run bench
|
|
40
|
+
bun run build
|
|
41
|
+
npm publish
|
|
52
42
|
```
|
|
53
43
|
|
|
54
44
|
## License
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Tree Attention Pattern
|
|
3
|
+
* Used in speculative decoding (Medusa, EAGLE) and tree-based generation
|
|
4
|
+
*/
|
|
5
|
+
import { type BlockSparseCSR } from "../format.ts";
|
|
6
|
+
/**
|
|
7
|
+
* Tree node structure
|
|
8
|
+
*/
|
|
9
|
+
export interface TreeNode {
|
|
10
|
+
/** Token position in sequence */
|
|
11
|
+
position: number;
|
|
12
|
+
/** Parent position (-1 for root) */
|
|
13
|
+
parent: number;
|
|
14
|
+
/** Depth in tree (0 for root) */
|
|
15
|
+
depth: number;
|
|
16
|
+
}
|
|
17
|
+
/**
|
|
18
|
+
* Tree attention configuration
|
|
19
|
+
*/
|
|
20
|
+
export interface TreeAttentionConfig {
|
|
21
|
+
/** Total sequence length including prompt */
|
|
22
|
+
seqLen: number;
|
|
23
|
+
/** Prompt length (prefix that all tokens attend to) */
|
|
24
|
+
promptLen: number;
|
|
25
|
+
/** Tree structure for speculative tokens */
|
|
26
|
+
tree: TreeNode[];
|
|
27
|
+
/** Block size for sparse format */
|
|
28
|
+
blockSize?: number;
|
|
29
|
+
}
|
|
30
|
+
/**
|
|
31
|
+
* Build tree attention mask
|
|
32
|
+
*
|
|
33
|
+
* In tree attention:
|
|
34
|
+
* - All tokens attend to the prompt (positions 0 to promptLen-1)
|
|
35
|
+
* - Tree tokens attend to their ancestors in the tree
|
|
36
|
+
* - Maintains causal property within the tree structure
|
|
37
|
+
*/
|
|
38
|
+
export declare function buildTreeMask(config: TreeAttentionConfig): BlockSparseCSR;
|
|
39
|
+
/**
|
|
40
|
+
* Build a simple chain tree (linear speculation)
|
|
41
|
+
* Each token depends on the previous one
|
|
42
|
+
*/
|
|
43
|
+
export declare function buildChainTree(numSpecTokens: number): TreeNode[];
|
|
44
|
+
/**
|
|
45
|
+
* Build a wide tree (parallel speculation)
|
|
46
|
+
* All speculative tokens depend only on the prompt
|
|
47
|
+
*/
|
|
48
|
+
export declare function buildWideTree(numSpecTokens: number): TreeNode[];
|
|
49
|
+
/**
|
|
50
|
+
* Build a binary tree for speculation
|
|
51
|
+
*/
|
|
52
|
+
export declare function buildBinaryTree(depth: number): TreeNode[];
|
|
53
|
+
/**
|
|
54
|
+
* Build Medusa-style tree
|
|
55
|
+
* Multiple heads predict tokens at different positions
|
|
56
|
+
*/
|
|
57
|
+
export declare function buildMedusaTree(numHeads: number, tokensPerHead: number): TreeNode[];
|
|
58
|
+
/**
|
|
59
|
+
* Calculate tree sparsity ratio
|
|
60
|
+
*/
|
|
61
|
+
export declare function getTreeSparsity(config: TreeAttentionConfig): number;
|
|
62
|
+
/**
|
|
63
|
+
* Validate tree structure
|
|
64
|
+
*/
|
|
65
|
+
export declare function validateTree(tree: TreeNode[]): boolean;
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Cascaded Inference - TDR-safe attention for very long sequences
|
|
3
|
+
* Splits attention computation across multiple passes with browser yields
|
|
4
|
+
*/
|
|
5
|
+
import type { WebInferDevice } from "../core/device.ts";
|
|
6
|
+
import { Tensor } from "../core/tensor.ts";
|
|
7
|
+
export interface CascadedAttentionConfig {
|
|
8
|
+
numHeads: number;
|
|
9
|
+
headDim: number;
|
|
10
|
+
seqLen: number;
|
|
11
|
+
scale?: number;
|
|
12
|
+
causal?: boolean;
|
|
13
|
+
}
|
|
14
|
+
/**
|
|
15
|
+
* Cascaded Attention - Safe for very long sequences
|
|
16
|
+
* Uses Split-K strategy to prevent TDR (GPU timeout)
|
|
17
|
+
*
|
|
18
|
+
* @param device WebInfer device
|
|
19
|
+
* @param q Query tensor [seqLen, numHeads, headDim]
|
|
20
|
+
* @param k Key tensor [seqLen, numHeads, headDim]
|
|
21
|
+
* @param v Value tensor [seqLen, numHeads, headDim]
|
|
22
|
+
* @param config Attention configuration
|
|
23
|
+
* @param onProgress Optional progress callback
|
|
24
|
+
*/
|
|
25
|
+
export declare function cascadedAttention(device: WebInferDevice, q: Tensor, k: Tensor, v: Tensor, config: CascadedAttentionConfig, onProgress?: (chunk: number, total: number) => void): Promise<Tensor>;
|
|
26
|
+
/**
|
|
27
|
+
* CPU reference implementation for verification
|
|
28
|
+
*/
|
|
29
|
+
export declare function cascadedAttentionCPU(q: Float32Array, k: Float32Array, v: Float32Array, seqLen: number, numHeads: number, headDim: number, causal?: boolean): Float32Array;
|
|
@@ -1,9 +1,118 @@
|
|
|
1
1
|
/**
|
|
2
2
|
* Attention Module Exports
|
|
3
3
|
*/
|
|
4
|
-
|
|
4
|
+
import type { WebInferDevice } from "../core/device.ts";
|
|
5
|
+
import { Tensor } from "../core/tensor.ts";
|
|
6
|
+
import type { BlockSparseCSR } from "./block-sparse/format.ts";
|
|
7
|
+
import { PagedKVCache } from "./paged-kv/page-table.ts";
|
|
8
|
+
/**
|
|
9
|
+
* Tensor-like input types
|
|
10
|
+
*/
|
|
11
|
+
export type TensorLike = Tensor | Float32Array | {
|
|
12
|
+
data: Float32Array;
|
|
13
|
+
shape: number[];
|
|
14
|
+
};
|
|
15
|
+
/**
|
|
16
|
+
* Attention options
|
|
17
|
+
*/
|
|
18
|
+
export interface AttentionOptions {
|
|
19
|
+
q: TensorLike;
|
|
20
|
+
k: TensorLike;
|
|
21
|
+
v: TensorLike;
|
|
22
|
+
causal?: boolean;
|
|
23
|
+
scale?: number;
|
|
24
|
+
window?: number;
|
|
25
|
+
mask?: BlockSparseCSR;
|
|
26
|
+
returnLse?: boolean;
|
|
27
|
+
}
|
|
28
|
+
/**
|
|
29
|
+
* Attention result with optional LSE
|
|
30
|
+
*/
|
|
31
|
+
export interface AttentionResultWithLse {
|
|
32
|
+
output: Tensor;
|
|
33
|
+
lse: Tensor;
|
|
34
|
+
}
|
|
35
|
+
/**
|
|
36
|
+
* Simple attention function
|
|
37
|
+
*/
|
|
38
|
+
export declare function attention(device: WebInferDevice, options: AttentionOptions): Promise<Tensor>;
|
|
39
|
+
export declare function attention(device: WebInferDevice, options: AttentionOptions & {
|
|
40
|
+
returnLse: true;
|
|
41
|
+
}): Promise<AttentionResultWithLse>;
|
|
42
|
+
/**
|
|
43
|
+
* BatchAttention configuration
|
|
44
|
+
*/
|
|
45
|
+
export interface BatchAttentionConfig {
|
|
46
|
+
numHeads: number;
|
|
47
|
+
headDim: number;
|
|
48
|
+
numKvHeads?: number;
|
|
49
|
+
maxBatchSize?: number;
|
|
50
|
+
maxSeqLen?: number;
|
|
51
|
+
}
|
|
52
|
+
/**
|
|
53
|
+
* Prefill input
|
|
54
|
+
*/
|
|
55
|
+
export interface PrefillInput {
|
|
56
|
+
queries: TensorLike[];
|
|
57
|
+
keys: TensorLike[];
|
|
58
|
+
values: TensorLike[];
|
|
59
|
+
causal?: boolean;
|
|
60
|
+
window?: number;
|
|
61
|
+
}
|
|
62
|
+
/**
|
|
63
|
+
* Decode input
|
|
64
|
+
*/
|
|
65
|
+
export interface DecodeInput {
|
|
66
|
+
query: TensorLike;
|
|
67
|
+
kvCache: PagedKVCache;
|
|
68
|
+
seqIds: number[];
|
|
69
|
+
}
|
|
70
|
+
/**
|
|
71
|
+
* BatchAttention - Batched attention for prefill and decode
|
|
72
|
+
*/
|
|
73
|
+
export declare class BatchAttention {
|
|
74
|
+
private device;
|
|
75
|
+
private config;
|
|
76
|
+
constructor(device: WebInferDevice, config: BatchAttentionConfig);
|
|
77
|
+
/**
|
|
78
|
+
* Prefill: Process variable-length sequences
|
|
79
|
+
*/
|
|
80
|
+
prefill(input: PrefillInput): Promise<Tensor[]>;
|
|
81
|
+
/**
|
|
82
|
+
* Decode: Single token per sequence with KV cache
|
|
83
|
+
*/
|
|
84
|
+
decode(input: DecodeInput): Promise<Tensor>;
|
|
85
|
+
getConfig(): Required<BatchAttentionConfig>;
|
|
86
|
+
dispose(): void;
|
|
87
|
+
}
|
|
88
|
+
/**
|
|
89
|
+
* AttentionKernel configuration
|
|
90
|
+
*/
|
|
91
|
+
export interface AttentionKernelConfig {
|
|
92
|
+
numHeads: number;
|
|
93
|
+
headDim: number;
|
|
94
|
+
causal?: boolean;
|
|
95
|
+
blockSize?: number;
|
|
96
|
+
}
|
|
97
|
+
/**
|
|
98
|
+
* AttentionKernel - Low-level compiled kernel
|
|
99
|
+
*/
|
|
100
|
+
export declare class AttentionKernel {
|
|
101
|
+
private device;
|
|
102
|
+
private config;
|
|
103
|
+
private constructor();
|
|
104
|
+
static compile(device: WebInferDevice, config: AttentionKernelConfig): Promise<AttentionKernel>;
|
|
105
|
+
execute(input: {
|
|
106
|
+
q: Tensor;
|
|
107
|
+
k: Tensor;
|
|
108
|
+
v: Tensor;
|
|
109
|
+
}): Promise<Tensor>;
|
|
110
|
+
dispose(): void;
|
|
111
|
+
}
|
|
112
|
+
export { flashAttention, type AttentionConfig } from "./flash-attention.ts";
|
|
5
113
|
export { buildBlockSparseCSR, getSparsityRatio, estimateMemorySavings, type BlockSparseCSR, type AttentionPattern, } from "./block-sparse/format.ts";
|
|
6
114
|
export { buildCausalMask, getCausalSparsity } from "./block-sparse/patterns/causal.ts";
|
|
7
115
|
export { buildSlidingWindowMask, buildCausalSlidingWindowMask, getSlidingWindowSparsity, } from "./block-sparse/patterns/sliding.ts";
|
|
8
|
-
export {
|
|
9
|
-
export {
|
|
116
|
+
export { PagedKVCache, type PagedKVCacheConfig, type SequenceEntry, type DefragmentResult, BlockManager, ContinuousBatchScheduler, type BlockManagerConfig, type AllocationPolicy, type AllocationRequest, } from "./paged-kv/index.ts";
|
|
117
|
+
export { pagedAttention, appendToPagedCache, type PagedAttentionConfig, type PagedAttentionInput, } from "./paged-attention.ts";
|
|
118
|
+
export { cascadedAttention, type CascadedAttentionConfig, } from "./cascaded-inference.ts";
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Paged Attention Implementation for WebGPU
|
|
3
|
+
* Efficient attention computation using paged KV cache (vLLM-style)
|
|
4
|
+
*/
|
|
5
|
+
import type { WebInferDevice } from "../core/device.ts";
|
|
6
|
+
import { Tensor } from "../core/tensor.ts";
|
|
7
|
+
import type { PagedKVCache } from "./paged-kv/page-table.ts";
|
|
8
|
+
export interface PagedAttentionConfig {
|
|
9
|
+
numHeads: number;
|
|
10
|
+
headDim: number;
|
|
11
|
+
scale?: number;
|
|
12
|
+
}
|
|
13
|
+
export interface PagedAttentionInput {
|
|
14
|
+
query: Tensor;
|
|
15
|
+
kvCache: PagedKVCache;
|
|
16
|
+
seqIds: number[];
|
|
17
|
+
positions: number[];
|
|
18
|
+
}
|
|
19
|
+
/**
|
|
20
|
+
* Paged Attention forward pass for decoding
|
|
21
|
+
* Computes attention against paged KV cache for single-token queries
|
|
22
|
+
*
|
|
23
|
+
* @param device WebInfer device
|
|
24
|
+
* @param input Paged attention input (query, kv cache, sequence info)
|
|
25
|
+
* @param config Attention configuration
|
|
26
|
+
*/
|
|
27
|
+
export declare function pagedAttention(device: WebInferDevice, input: PagedAttentionInput, config: PagedAttentionConfig): Promise<Tensor>;
|
|
28
|
+
/**
|
|
29
|
+
* Append new KV to paged cache
|
|
30
|
+
*/
|
|
31
|
+
export declare function appendToPagedCache(device: WebInferDevice, kvCache: PagedKVCache, seqId: number, key: Tensor, // [numHeads, headDim]
|
|
32
|
+
value: Tensor): Promise<void>;
|
|
33
|
+
/**
|
|
34
|
+
* CPU reference implementation for verification
|
|
35
|
+
*/
|
|
36
|
+
export declare function pagedAttentionCPU(q: Float32Array, // [batchSize, numHeads, headDim]
|
|
37
|
+
keyCache: Float32Array, // [maxPages, pageSize, numHeads, headDim]
|
|
38
|
+
valueCache: Float32Array, pageTable: number[][], // [batchSize][pages]
|
|
39
|
+
seqLens: number[], // [batchSize]
|
|
40
|
+
numHeads: number, headDim: number, pageSize: number, maxPages: number): Float32Array;
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
/**
|
|
2
2
|
* PagedKV Module Exports
|
|
3
3
|
*/
|
|
4
|
-
export {
|
|
5
|
-
export {
|
|
4
|
+
export { type AllocationPolicy, type AllocationRequest, BlockManager, type BlockManagerConfig, ContinuousBatchScheduler, } from "./block-manager.ts";
|
|
5
|
+
export { type DefragmentResult, PagedKVCache, type PagedKVCacheConfig, type SequenceEntry, } from "./page-table.ts";
|
|
@@ -96,4 +96,70 @@ export declare class PagedKVCache {
|
|
|
96
96
|
* Dispose GPU resources
|
|
97
97
|
*/
|
|
98
98
|
dispose(): void;
|
|
99
|
+
/**
|
|
100
|
+
* Allocate a new sequence (v2 alias)
|
|
101
|
+
*/
|
|
102
|
+
alloc(initialLength?: number): number;
|
|
103
|
+
/**
|
|
104
|
+
* Append KV to sequence (v2 API)
|
|
105
|
+
*/
|
|
106
|
+
append(seqId: number, kv: {
|
|
107
|
+
key: Float32Array;
|
|
108
|
+
value: Float32Array;
|
|
109
|
+
layer?: number;
|
|
110
|
+
}): void;
|
|
111
|
+
/**
|
|
112
|
+
* Batch append KV for all layers (v2 API)
|
|
113
|
+
*/
|
|
114
|
+
appendBatch(seqId: number, kv: {
|
|
115
|
+
keys: Float32Array;
|
|
116
|
+
values: Float32Array;
|
|
117
|
+
}): void;
|
|
118
|
+
/**
|
|
119
|
+
* Free sequence (v2 alias)
|
|
120
|
+
*/
|
|
121
|
+
free(seqId: number): void;
|
|
122
|
+
/**
|
|
123
|
+
* Get stats (v2 API)
|
|
124
|
+
*/
|
|
125
|
+
stats(): {
|
|
126
|
+
usedPages: number;
|
|
127
|
+
freePages: number;
|
|
128
|
+
fragmentation: number;
|
|
129
|
+
};
|
|
130
|
+
/**
|
|
131
|
+
* Check if defrag needed (v2 alias)
|
|
132
|
+
*/
|
|
133
|
+
needsDefrag(threshold?: number): boolean;
|
|
134
|
+
/**
|
|
135
|
+
* Check if defragmentation is needed
|
|
136
|
+
* Returns fragmentation ratio (0 = no fragmentation, 1 = fully fragmented)
|
|
137
|
+
*/
|
|
138
|
+
getFragmentationRatio(): number;
|
|
139
|
+
/**
|
|
140
|
+
* Check if defragmentation would be beneficial
|
|
141
|
+
*/
|
|
142
|
+
needsDefragmentation(threshold?: number): boolean;
|
|
143
|
+
/**
|
|
144
|
+
* Defragment the KV cache by compacting pages
|
|
145
|
+
* Returns the number of pages moved
|
|
146
|
+
*/
|
|
147
|
+
defragment(): Promise<DefragmentResult>;
|
|
148
|
+
/**
|
|
149
|
+
* Move a single page from one location to another
|
|
150
|
+
*/
|
|
151
|
+
private movePage;
|
|
152
|
+
/**
|
|
153
|
+
* Rebuild the free page list from scratch
|
|
154
|
+
*/
|
|
155
|
+
private rebuildFreePageList;
|
|
156
|
+
}
|
|
157
|
+
/**
|
|
158
|
+
* Result of defragmentation operation
|
|
159
|
+
*/
|
|
160
|
+
export interface DefragmentResult {
|
|
161
|
+
pagesMoved: number;
|
|
162
|
+
durationMs: number;
|
|
163
|
+
fragmentationBefore: number;
|
|
164
|
+
fragmentationAfter: number;
|
|
99
165
|
}
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* TDR (Timeout Detection and Recovery) Prevention
|
|
3
|
+
* Handles browser-specific GPU timeout limits and graceful degradation
|
|
4
|
+
*/
|
|
5
|
+
/**
|
|
6
|
+
* Browser detection result
|
|
7
|
+
*/
|
|
8
|
+
export interface BrowserInfo {
|
|
9
|
+
name: "chrome" | "safari" | "firefox" | "edge" | "unknown";
|
|
10
|
+
version: number;
|
|
11
|
+
isMobile: boolean;
|
|
12
|
+
hasWebGPU: boolean;
|
|
13
|
+
}
|
|
14
|
+
/**
|
|
15
|
+
* TDR configuration for each browser
|
|
16
|
+
*/
|
|
17
|
+
export interface TDRConfig {
|
|
18
|
+
timeoutMs: number;
|
|
19
|
+
safetyMargin: number;
|
|
20
|
+
maxChunkSize: number;
|
|
21
|
+
supportsTimestampQuery: boolean;
|
|
22
|
+
}
|
|
23
|
+
/**
|
|
24
|
+
* Detect browser information
|
|
25
|
+
*/
|
|
26
|
+
export declare function detectBrowser(): BrowserInfo;
|
|
27
|
+
/**
|
|
28
|
+
* Get TDR configuration for current browser
|
|
29
|
+
*/
|
|
30
|
+
export declare function getTDRConfig(browser?: BrowserInfo): TDRConfig;
|
|
31
|
+
/**
|
|
32
|
+
* Graceful degradation options
|
|
33
|
+
*/
|
|
34
|
+
export interface DegradationOptions {
|
|
35
|
+
/** Use CPU fallback when GPU fails */
|
|
36
|
+
enableCpuFallback: boolean;
|
|
37
|
+
/** Reduce precision to f16 when memory is tight */
|
|
38
|
+
enablePrecisionReduction: boolean;
|
|
39
|
+
/** Auto-chunk large sequences */
|
|
40
|
+
enableAutoChunking: boolean;
|
|
41
|
+
/** Maximum retries before falling back */
|
|
42
|
+
maxRetries: number;
|
|
43
|
+
}
|
|
44
|
+
/**
|
|
45
|
+
* Error types for graceful handling
|
|
46
|
+
*/
|
|
47
|
+
export type WebGPUErrorType = "device_lost" | "out_of_memory" | "validation" | "timeout" | "unknown";
|
|
48
|
+
/**
|
|
49
|
+
* Classify WebGPU error
|
|
50
|
+
*/
|
|
51
|
+
export declare function classifyError(error: unknown): WebGPUErrorType;
|
|
52
|
+
/**
|
|
53
|
+
* TDR-safe execution wrapper
|
|
54
|
+
*/
|
|
55
|
+
export declare class TDRGuard {
|
|
56
|
+
private browser;
|
|
57
|
+
private config;
|
|
58
|
+
private options;
|
|
59
|
+
private lastExecutionTime;
|
|
60
|
+
constructor(options?: Partial<DegradationOptions>);
|
|
61
|
+
/**
|
|
62
|
+
* Get safe execution time limit
|
|
63
|
+
*/
|
|
64
|
+
getSafeTimeLimit(): number;
|
|
65
|
+
/**
|
|
66
|
+
* Check if operation might cause TDR
|
|
67
|
+
*/
|
|
68
|
+
mightTimeout(estimatedMs: number): boolean;
|
|
69
|
+
/**
|
|
70
|
+
* Calculate chunks needed for safe execution
|
|
71
|
+
*/
|
|
72
|
+
calcChunks(estimatedTotalMs: number): number;
|
|
73
|
+
/**
|
|
74
|
+
* Yield to browser main thread
|
|
75
|
+
*/
|
|
76
|
+
yield(): Promise<void>;
|
|
77
|
+
/**
|
|
78
|
+
* Execute with TDR protection
|
|
79
|
+
*/
|
|
80
|
+
execute<T>(fn: () => Promise<T>, options?: {
|
|
81
|
+
estimatedMs?: number;
|
|
82
|
+
onRetry?: (attempt: number, error: unknown) => void;
|
|
83
|
+
fallback?: () => T | Promise<T>;
|
|
84
|
+
}): Promise<T>;
|
|
85
|
+
/**
|
|
86
|
+
* Get browser info
|
|
87
|
+
*/
|
|
88
|
+
getBrowserInfo(): BrowserInfo;
|
|
89
|
+
/**
|
|
90
|
+
* Get TDR config
|
|
91
|
+
*/
|
|
92
|
+
getTDRConfig(): TDRConfig;
|
|
93
|
+
/**
|
|
94
|
+
* Get last execution time
|
|
95
|
+
*/
|
|
96
|
+
getLastExecutionTime(): number;
|
|
97
|
+
}
|
|
98
|
+
/**
|
|
99
|
+
* Check WebGPU support and capabilities
|
|
100
|
+
*/
|
|
101
|
+
export declare function checkWebGPUSupport(): Promise<{
|
|
102
|
+
supported: boolean;
|
|
103
|
+
reason?: string;
|
|
104
|
+
adapter?: GPUAdapter;
|
|
105
|
+
limits?: GPUSupportedLimits;
|
|
106
|
+
}>;
|
|
107
|
+
/**
|
|
108
|
+
* Create device with graceful degradation
|
|
109
|
+
*/
|
|
110
|
+
export declare function createDeviceWithFallback(adapter: GPUAdapter, options?: {
|
|
111
|
+
requiredFeatures?: GPUFeatureName[];
|
|
112
|
+
requiredLimits?: Record<string, number>;
|
|
113
|
+
onFallback?: (reason: string) => void;
|
|
114
|
+
}): Promise<GPUDevice>;
|
package/dist/index.d.ts
CHANGED
|
@@ -5,18 +5,20 @@
|
|
|
5
5
|
export { WebInferDevice, type DeviceInfo } from "./core/device.ts";
|
|
6
6
|
export { Tensor, type DType } from "./core/tensor.ts";
|
|
7
7
|
export { BufferPool } from "./core/buffer-pool.ts";
|
|
8
|
+
export { TDRGuard, detectBrowser, getTDRConfig, classifyError, checkWebGPUSupport, createDeviceWithFallback, type BrowserInfo, type TDRConfig, type DegradationOptions, type WebGPUErrorType, } from "./core/tdr.ts";
|
|
9
|
+
export { attention, BatchAttention, AttentionKernel, type AttentionOptions, type AttentionResultWithLse, type BatchAttentionConfig, type PrefillInput, type DecodeInput, type TensorLike, type AttentionKernelConfig, } from "./attention/index.ts";
|
|
8
10
|
export { KernelCache, type CacheStats } from "./jit/kernel-cache.ts";
|
|
9
11
|
export { WGSLCompiler, type MatMulConfig } from "./jit/compiler.ts";
|
|
10
|
-
export { matmul,
|
|
11
|
-
export { layerNorm,
|
|
12
|
-
export { rope,
|
|
13
|
-
export { gelu,
|
|
14
|
-
export { softmaxGPU
|
|
15
|
-
export { add,
|
|
16
|
-
export { embedding
|
|
17
|
-
export { transpose2D
|
|
18
|
-
export { quantizeToInt8, quantizeToInt4, dequantizeInt8, dequantizeInt4, quantizationError, getMemorySavings,
|
|
19
|
-
export { flashAttention,
|
|
20
|
-
export { topK,
|
|
12
|
+
export { matmul, getMatMulCacheStats } from "./ops/matmul.ts";
|
|
13
|
+
export { layerNorm, rmsNorm, fusedAddRmsNorm, gemmaRmsNorm, gemmaFusedAddRmsNorm, } from "./ops/normalization.ts";
|
|
14
|
+
export { rope, computeRoPEFrequencies, type RoPEConfig } from "./ops/rope.ts";
|
|
15
|
+
export { gelu, silu, relu } from "./ops/activations.ts";
|
|
16
|
+
export { softmaxGPU } from "./ops/softmax.ts";
|
|
17
|
+
export { add, mul, scale } from "./ops/elementwise.ts";
|
|
18
|
+
export { embedding } from "./ops/embedding.ts";
|
|
19
|
+
export { transpose2D } from "./ops/reshape.ts";
|
|
20
|
+
export { quantizeToInt8, quantizeToInt4, dequantizeInt8, dequantizeInt4, quantizationError, getMemorySavings, estimateQMatMulFlops, estimateQMatMulBandwidth, type QuantConfig, type QuantizedTensor, } from "./quantization/index.ts";
|
|
21
|
+
export { flashAttention, type AttentionConfig, buildBlockSparseCSR, getSparsityRatio, estimateMemorySavings, type BlockSparseCSR, type AttentionPattern, buildCausalMask, getCausalSparsity, buildSlidingWindowMask, buildCausalSlidingWindowMask, getSlidingWindowSparsity, PagedKVCache, type PagedKVCacheConfig, type SequenceEntry, type DefragmentResult, BlockManager, ContinuousBatchScheduler, type BlockManagerConfig, type AllocationPolicy, type AllocationRequest, pagedAttention, appendToPagedCache, type PagedAttentionConfig, type PagedAttentionInput, cascadedAttention, type CascadedAttentionConfig, } from "./attention/index.ts";
|
|
22
|
+
export { topK, topKFilter, topPFilter, sample, sampleGreedy, sampleFromProbs, softmax, applyRepetitionPenalty, minPSamplingFromProbs, topKSamplingFromProbs, topPSamplingFromProbs, topKTopPSamplingFromProbs, topKTopPSamplingFromLogits, topPRenormProbs, topKRenormProbs, topKMaskLogits, type SamplingConfig, } from "./sampling/index.ts";
|
|
21
23
|
export { type ModelFormat, type SafetensorsDType, GGUFQuantType, GGUFMetadataValueType, type TensorInfo, type SafetensorsHeader, type ModelMetadata, type LoadedModel, type LoadOptions, parseSafetensorsHeader, loadSafetensors, loadSafetensorsFromUrl, isSafetensors, parseGGUFHeader, loadGGUF, loadGGUFFromUrl, loadGGUFTensor, isGGUF, loadModel, } from "./model/index.ts";
|
|
22
24
|
export { type ModelConfig, type InferenceConfig, type GenerationConfig, type GenerationResult, type StreamToken, type FinishReason, type ForwardResult, DEFAULT_GENERATION_CONFIG, normalizeGenerationConfig, InferenceEngine, generate, generateStream, greedyDecode, sampleNextToken, } from "./inference/index.ts";
|