@simulatte/doppler 0.1.7 → 0.1.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +32 -0
- package/README.md +25 -6
- package/package.json +25 -38
- package/src/browser/browser-converter.js +5 -0
- package/src/client/doppler-api.browser.js +6 -0
- package/src/client/doppler-api.d.ts +3 -0
- package/src/client/doppler-api.js +11 -2
- package/src/client/doppler-registry.js +3 -5
- package/src/client/doppler-registry.json +2 -2
- package/src/config/kernel-path-loader.d.ts +5 -0
- package/src/config/kernel-path-loader.js +13 -0
- package/src/config/kernels/kernel-ref-digests.js +23 -21
- package/src/config/kernels/moe/mixtral.paths.json +46 -0
- package/src/config/kernels/registry.json +74 -0
- package/src/config/loader.js +9 -0
- package/src/config/merge-contract-check.js +7 -0
- package/src/config/platforms/loader.js +3 -1
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-nosubgroups.json +16 -16
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-online.json +8 -8
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-small-attn.json +61 -0
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
- package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
- package/src/config/presets/kernel-paths/registry.json +21 -0
- package/src/config/presets/models/gemma2.json +2 -1
- package/src/config/presets/models/gemma3.json +4 -1
- package/src/config/presets/models/gemma4.json +61 -0
- package/src/config/presets/models/granite-docling.json +70 -0
- package/src/config/presets/models/lfm2.json +6 -1
- package/src/config/presets/models/qwen3.json +4 -3
- package/src/config/presets/models/qwen3_5.json +16 -0
- package/src/config/presets/models/qwen3_vl.json +40 -0
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +2 -1
- package/src/config/presets/runtime/experiments/verify/lfm2-verify.json +46 -0
- package/src/config/presets/runtime/experiments/verify/translategemma-verify.json +39 -0
- package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
- package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
- package/src/config/presets/runtime/modes/trace-layers.json +1 -0
- package/src/config/presets/runtime/tiers/gemma4-16gb.json +69 -0
- package/src/config/presets/runtime/tiers/gemma4-24gb.json +66 -0
- package/src/config/presets/runtime/tiers/gemma4-32gb.json +66 -0
- package/src/config/runtime.js +3 -0
- package/src/config/schema/conversion.schema.d.ts +1 -0
- package/src/config/schema/debug.schema.d.ts +40 -0
- package/src/config/schema/debug.schema.js +28 -0
- package/src/config/schema/index.js +2 -0
- package/src/config/schema/inference-defaults.schema.js +1 -1
- package/src/config/schema/kernel-path.schema.d.ts +1 -0
- package/src/config/schema/manifest.schema.d.ts +1 -1
- package/src/config/schema/manifest.schema.js +1 -1
- package/src/config/schema/memory-limits.schema.js +2 -2
- package/src/config/schema/storage.schema.js +2 -2
- package/src/converter/conversion-plan.js +11 -3
- package/src/converter/core.js +19 -8
- package/src/converter/manifest-inference.js +12 -22
- package/src/converter/parsers/transformer.js +4 -0
- package/src/converter/quantization-info.js +5 -1
- package/src/converter/quantizer.d.ts +5 -0
- package/src/converter/quantizer.js +34 -12
- package/src/converter/rope-config.js +8 -6
- package/src/converter/tokenizer-utils.d.ts +1 -0
- package/src/converter/tokenizer-utils.js +4 -1
- package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
- package/src/distribution/shard-delivery.js +40 -1
- package/src/formats/rdrr/classification.js +32 -0
- package/src/formats/rdrr/parsing.d.ts +4 -0
- package/src/formats/rdrr/parsing.js +14 -1
- package/src/gpu/kernel-runtime.js +4 -2
- package/src/gpu/kernels/attention.js +2 -1
- package/src/gpu/kernels/dequant_f16_out.wgsl +4 -2
- package/src/gpu/kernels/dequant_f16_out_vec4.wgsl +5 -2
- package/src/gpu/kernels/dequant_shared.wgsl +4 -2
- package/src/gpu/kernels/dequant_shared_vec4.wgsl +4 -2
- package/src/gpu/kernels/dequant_subgroup.wgsl +6 -2
- package/src/gpu/kernels/gated-short-conv.d.ts +63 -0
- package/src/gpu/kernels/gated-short-conv.js +284 -0
- package/src/gpu/kernels/index.d.ts +8 -0
- package/src/gpu/kernels/index.js +6 -0
- package/src/gpu/kernels/linear-attention-core.js +37 -17
- package/src/gpu/kernels/matmul-selection.js +48 -4
- package/src/gpu/kernels/matmul.d.ts +5 -0
- package/src/gpu/kernels/matmul.js +71 -2
- package/src/gpu/kernels/matmul_gemv_subgroup.wgsl +77 -79
- package/src/gpu/kernels/rmsnorm.js +9 -2
- package/src/gpu/kernels/sample.js +1 -3
- package/src/gpu/kernels/sample.wgsl +39 -9
- package/src/gpu/kernels/sample_f16.wgsl +38 -8
- package/src/gpu/kernels/shader-cache.js +9 -4
- package/src/gpu/kernels/split_qg.d.ts +50 -0
- package/src/gpu/kernels/split_qg.js +46 -0
- package/src/gpu/kernels/split_qg.wgsl +58 -0
- package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
- package/src/gpu/weight-buffer.d.ts +1 -1
- package/src/gpu/weight-buffer.js +1 -1
- package/src/inference/browser-harness.d.ts +2 -0
- package/src/inference/browser-harness.js +20 -1
- package/src/inference/kv-cache/base.js +3 -10
- package/src/inference/pipelines/diffusion/helpers.js +3 -0
- package/src/inference/pipelines/diffusion/pipeline.js +2 -1
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +10 -3
- package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
- package/src/inference/pipelines/text/attention/output-projection.js +8 -0
- package/src/inference/pipelines/text/attention/projections.d.ts +13 -1
- package/src/inference/pipelines/text/attention/projections.js +54 -13
- package/src/inference/pipelines/text/attention/record.js +16 -6
- package/src/inference/pipelines/text/attention/run.js +59 -6
- package/src/inference/pipelines/text/config.d.ts +1 -0
- package/src/inference/pipelines/text/config.js +46 -4
- package/src/inference/pipelines/text/embed.js +26 -7
- package/src/inference/pipelines/text/execution-plan.js +5 -4
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +10 -3
- package/src/inference/pipelines/text/execution-v0.js +12 -1
- package/src/inference/pipelines/text/generator-helpers.js +1 -0
- package/src/inference/pipelines/text/generator-runtime.js +19 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +15 -0
- package/src/inference/pipelines/text/generator-steps.js +71 -26
- package/src/inference/pipelines/text/generator.d.ts +5 -0
- package/src/inference/pipelines/text/generator.js +353 -166
- package/src/inference/pipelines/text/init.d.ts +15 -0
- package/src/inference/pipelines/text/init.js +35 -10
- package/src/inference/pipelines/text/layer.js +38 -8
- package/src/inference/pipelines/text/linear-attention.d.ts +5 -0
- package/src/inference/pipelines/text/linear-attention.js +33 -3
- package/src/inference/pipelines/text/logits/gpu.js +2 -2
- package/src/inference/pipelines/text/logits/index.d.ts +6 -1
- package/src/inference/pipelines/text/logits/index.js +3 -1
- package/src/inference/pipelines/text/model-load.js +3 -0
- package/src/inference/pipelines/text/moe-gpu.js +21 -3
- package/src/inference/pipelines/text/moe-shape-validator.d.ts +9 -0
- package/src/inference/pipelines/text/moe-shape-validator.js +31 -11
- package/src/inference/pipelines/text/ops.js +123 -53
- package/src/inference/pipelines/text/probes.js +1 -0
- package/src/inference/pipelines/text/sampling.js +52 -6
- package/src/inference/pipelines/text/state.js +2 -0
- package/src/inference/pipelines/text.d.ts +5 -0
- package/src/inference/pipelines/text.js +59 -1
- package/src/inference/pipelines/vision/encoder.js +386 -0
- package/src/inference/pipelines/vision/image-preprocess.js +151 -0
- package/src/inference/pipelines/vision/index.js +173 -0
- package/src/inference/pipelines/vision/ops.js +78 -0
- package/src/inference/pipelines/vision/patch-embed.js +151 -0
- package/src/inference/test-harness.js +11 -9
- package/src/loader/doppler-loader.d.ts +3 -0
- package/src/loader/doppler-loader.js +20 -3
- package/src/loader/experts/expert-cache.js +6 -2
- package/src/loader/experts/expert-loader.js +6 -2
- package/src/loader/final-weights-loader.js +2 -0
- package/src/loader/layer-loader.js +42 -3
- package/src/loader/manifest-config.js +3 -1
- package/src/loader/shard-cache.js +3 -2
- package/src/loader/tensors/tensor-loader.d.ts +3 -0
- package/src/loader/tensors/tensor-loader.js +130 -4
- package/src/rules/inference/dtype.rules.json +5 -0
- package/src/rules/inference/kernel-path.rules.json +2 -2
- package/src/rules/kernels/moe.rules.mixtral.json +75 -0
- package/src/rules/kernels/softmax.rules.json +2 -0
- package/src/rules/kernels/split-qg.rules.json +6 -0
- package/src/rules/rule-registry.d.ts +1 -0
- package/src/rules/rule-registry.js +4 -0
- package/src/storage/downloader.js +2 -1
- package/src/storage/quickstart-downloader.d.ts +3 -0
- package/src/storage/quickstart-downloader.js +27 -30
- package/src/storage/shard-manager.js +4 -3
- package/src/tooling/conversion-config-materializer.js +3 -5
- package/src/tooling/node-converter.js +28 -7
- package/src/tooling/node-source-runtime.js +65 -5
- package/src/tooling/node-webgpu.js +24 -7
- package/src/types/model.d.ts +5 -0
- package/src/utils/hf-resolve-url.d.ts +16 -0
- package/src/utils/hf-resolve-url.js +17 -0
- package/src/version.js +1 -1
- package/tools/doppler-cli.js +6 -1
- package/src/tooling/node-convert.d.ts +0 -54
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import { log } from '../debug/index.js';
|
|
2
|
+
import { getExpectedShardHash } from '../formats/rdrr/index.js';
|
|
2
3
|
import {
|
|
3
4
|
computeHash,
|
|
4
5
|
createStreamingHasher,
|
|
@@ -1316,6 +1317,25 @@ async function clearPersistedShardState(shardIndex) {
|
|
|
1316
1317
|
await writer.abort?.();
|
|
1317
1318
|
}
|
|
1318
1319
|
|
|
1320
|
+
async function recoverHttpRejectedResumeRange(
|
|
1321
|
+
baseUrl,
|
|
1322
|
+
shardInfo,
|
|
1323
|
+
shardIndex,
|
|
1324
|
+
options,
|
|
1325
|
+
transferState,
|
|
1326
|
+
writeToStore
|
|
1327
|
+
) {
|
|
1328
|
+
await abortHttpTransferState(transferState);
|
|
1329
|
+
if (writeToStore) {
|
|
1330
|
+
await clearPersistedShardState(shardIndex);
|
|
1331
|
+
}
|
|
1332
|
+
return downloadShardFromHttp(baseUrl, shardInfo, shardIndex, {
|
|
1333
|
+
...options,
|
|
1334
|
+
__disablePersistedResume: true,
|
|
1335
|
+
__resumeRangeRecoveryAttempted: true,
|
|
1336
|
+
});
|
|
1337
|
+
}
|
|
1338
|
+
|
|
1319
1339
|
async function downloadShardFromHttp(baseUrl, shardInfo, shardIndex, options = {}) {
|
|
1320
1340
|
const {
|
|
1321
1341
|
signal,
|
|
@@ -1528,6 +1548,21 @@ async function downloadShardFromHttp(baseUrl, shardInfo, shardIndex, options = {
|
|
|
1528
1548
|
throw error;
|
|
1529
1549
|
}
|
|
1530
1550
|
|
|
1551
|
+
if (
|
|
1552
|
+
error?.status === 416
|
|
1553
|
+
&& transferState.receivedBytes > 0
|
|
1554
|
+
&& options.__resumeRangeRecoveryAttempted !== true
|
|
1555
|
+
) {
|
|
1556
|
+
return recoverHttpRejectedResumeRange(
|
|
1557
|
+
baseUrl,
|
|
1558
|
+
shardInfo,
|
|
1559
|
+
shardIndex,
|
|
1560
|
+
options,
|
|
1561
|
+
transferState,
|
|
1562
|
+
writeToStore
|
|
1563
|
+
);
|
|
1564
|
+
}
|
|
1565
|
+
|
|
1531
1566
|
if (Number.isInteger(error?.status) && error.status >= 400 && error.status < 500 && error.status !== 429) {
|
|
1532
1567
|
await abortHttpTransferState(transferState);
|
|
1533
1568
|
throw error;
|
|
@@ -2018,7 +2053,11 @@ export async function downloadShard(
|
|
|
2018
2053
|
onDeliveryMetrics,
|
|
2019
2054
|
signal,
|
|
2020
2055
|
requiredEncoding: requiredEncoding ?? activeConfig.requiredContentEncoding ?? null,
|
|
2021
|
-
expectedHash:
|
|
2056
|
+
expectedHash:
|
|
2057
|
+
options.expectedHash
|
|
2058
|
+
?? getExpectedShardHash(shardInfo, algorithm)
|
|
2059
|
+
?? activeConfig.expectedHash
|
|
2060
|
+
?? null,
|
|
2022
2061
|
expectedSize: expectedSize ?? shardInfo?.size ?? null,
|
|
2023
2062
|
expectedManifestVersionSet: options.expectedManifestVersionSet ?? null,
|
|
2024
2063
|
writeToStore,
|
|
@@ -32,6 +32,12 @@ export function classifyTensor(name, modelType) {
|
|
|
32
32
|
return 'head';
|
|
33
33
|
}
|
|
34
34
|
|
|
35
|
+
// Multimodal groups
|
|
36
|
+
const role = classifyTensorRole(name);
|
|
37
|
+
if (role === 'vision') return 'vision';
|
|
38
|
+
if (role === 'projector') return 'projector';
|
|
39
|
+
if (role === 'audio') return 'audio';
|
|
40
|
+
|
|
35
41
|
// Extract layer index
|
|
36
42
|
const layerMatch = name.match(/layers?[._](\d+)/i);
|
|
37
43
|
if (!layerMatch) {
|
|
@@ -96,6 +102,29 @@ export function classifyTensorRole(name) {
|
|
|
96
102
|
if (lower.includes('lm_head')) return 'lm_head';
|
|
97
103
|
if (lower.endsWith('output.weight') && !lower.includes('attn_')) return 'lm_head';
|
|
98
104
|
|
|
105
|
+
// Multimodal: vision encoder tensors
|
|
106
|
+
if (lower.startsWith('vision_tower.') || lower.startsWith('vision_model.')
|
|
107
|
+
|| lower.startsWith('visual.') || lower.startsWith('model.visual.')
|
|
108
|
+
|| lower.startsWith('vision.') || lower.startsWith('model.vision.')
|
|
109
|
+
|| lower.startsWith('vision_encoder.') || lower.startsWith('image_encoder.')
|
|
110
|
+
|| lower.startsWith('image_tower.') || lower.startsWith('image.')
|
|
111
|
+
|| lower.startsWith('model.image.')) {
|
|
112
|
+
return 'vision';
|
|
113
|
+
}
|
|
114
|
+
|
|
115
|
+
// Multimodal: audio encoder tensors
|
|
116
|
+
if (lower.startsWith('audio_tower.') || lower.startsWith('audio_model.')
|
|
117
|
+
|| lower.startsWith('audio.') || lower.startsWith('model.audio.')
|
|
118
|
+
|| lower.startsWith('audio_encoder.')) {
|
|
119
|
+
return 'audio';
|
|
120
|
+
}
|
|
121
|
+
|
|
122
|
+
// Multimodal: projector tensors
|
|
123
|
+
if (lower.startsWith('multi_modal_projector.') || lower.startsWith('model.multi_modal_projector.')
|
|
124
|
+
|| lower.startsWith('mm_projector.') || lower.startsWith('model.mm_projector.')) {
|
|
125
|
+
return 'projector';
|
|
126
|
+
}
|
|
127
|
+
|
|
99
128
|
if (lower.includes('shared_expert') || /experts?[._]/.test(lower)) {
|
|
100
129
|
return 'expert';
|
|
101
130
|
}
|
|
@@ -207,6 +236,9 @@ export function getGroupType(groupId, modelType) {
|
|
|
207
236
|
}
|
|
208
237
|
if (groupId === 'embed') return 'embed';
|
|
209
238
|
if (groupId === 'head') return 'head';
|
|
239
|
+
if (groupId === 'vision') return 'vision';
|
|
240
|
+
if (groupId === 'projector') return 'projector';
|
|
241
|
+
if (groupId === 'audio') return 'audio';
|
|
210
242
|
if (groupId === 'other') return 'layer';
|
|
211
243
|
|
|
212
244
|
if (groupId.includes('.expert.')) return 'expert';
|
|
@@ -7,6 +7,10 @@
|
|
|
7
7
|
import type { RDRRManifest, ShardInfo, TensorMap } from './types.js';
|
|
8
8
|
|
|
9
9
|
export declare function parseManifest(jsonString: string): RDRRManifest;
|
|
10
|
+
export declare function getExpectedShardHash(
|
|
11
|
+
shard: Partial<ShardInfo> | Record<string, unknown> | null | undefined,
|
|
12
|
+
manifestHashAlgorithm?: string | null
|
|
13
|
+
): string;
|
|
10
14
|
|
|
11
15
|
export declare function parseTensorMap(jsonString: string): TensorMap;
|
|
12
16
|
|
|
@@ -4,6 +4,19 @@ import { validateManifest } from './validation.js';
|
|
|
4
4
|
|
|
5
5
|
let currentManifest = null;
|
|
6
6
|
|
|
7
|
+
export function getExpectedShardHash(shard, manifestHashAlgorithm = null) {
|
|
8
|
+
if (!shard || typeof shard !== 'object' || Array.isArray(shard)) {
|
|
9
|
+
return '';
|
|
10
|
+
}
|
|
11
|
+
const algorithm = typeof manifestHashAlgorithm === 'string'
|
|
12
|
+
? manifestHashAlgorithm.trim().toLowerCase()
|
|
13
|
+
: '';
|
|
14
|
+
if (algorithm === 'blake3') {
|
|
15
|
+
return shard.blake3 || shard.hash || '';
|
|
16
|
+
}
|
|
17
|
+
return shard.hash || shard.blake3 || '';
|
|
18
|
+
}
|
|
19
|
+
|
|
7
20
|
export function parseManifest(jsonString) {
|
|
8
21
|
let manifest;
|
|
9
22
|
|
|
@@ -21,7 +34,7 @@ export function parseManifest(jsonString) {
|
|
|
21
34
|
index: shard.index ?? i,
|
|
22
35
|
filename: shard.filename || shard.fileName || '',
|
|
23
36
|
size: shard.size,
|
|
24
|
-
hash: shard
|
|
37
|
+
hash: getExpectedShardHash(shard, manifest.hashAlgorithm),
|
|
25
38
|
blake3: shard.blake3 || shard.hash,
|
|
26
39
|
offset: shard.offset ?? offset,
|
|
27
40
|
hashAlgorithm: shard.hashAlgorithm,
|
|
@@ -2,13 +2,15 @@
|
|
|
2
2
|
|
|
3
3
|
import { autoTuneKernels, prewarmKernels, clearKernelCaches } from './kernels/utils.js';
|
|
4
4
|
import { getRuntimeConfig } from '../config/runtime.js';
|
|
5
|
-
import { DEFAULT_KERNEL_WARMUP_CONFIG } from '../config/schema/kernel-warmup.schema.js';
|
|
6
5
|
|
|
7
6
|
|
|
8
7
|
export async function prepareKernelRuntime(
|
|
9
8
|
options = {}
|
|
10
9
|
) {
|
|
11
|
-
const kernelWarmup = getRuntimeConfig().shared?.kernelWarmup
|
|
10
|
+
const kernelWarmup = getRuntimeConfig().shared?.kernelWarmup;
|
|
11
|
+
if (!kernelWarmup) {
|
|
12
|
+
throw new Error('runtime.shared.kernelWarmup is required but missing from resolved config');
|
|
13
|
+
}
|
|
12
14
|
const {
|
|
13
15
|
prewarm = kernelWarmup.prewarm,
|
|
14
16
|
prewarmMode = kernelWarmup.prewarmMode,
|
|
@@ -513,9 +513,10 @@ function resolveAttentionPlan(
|
|
|
513
513
|
useF16KV,
|
|
514
514
|
useF16Q,
|
|
515
515
|
numHeads,
|
|
516
|
+
headDim,
|
|
516
517
|
kvLen,
|
|
518
|
+
isPaged,
|
|
517
519
|
caps,
|
|
518
|
-
headDim,
|
|
519
520
|
sharedLimit
|
|
520
521
|
);
|
|
521
522
|
const workgroups = calculateAttentionWorkgroups(adaptiveSelection.tier, seqLen, numHeads);
|
|
@@ -113,9 +113,11 @@ fn get_q4(qs: array<u32, 32>, idx: u32) -> u32 {
|
|
|
113
113
|
@compute @workgroup_size(WORKGROUP_SIZE_MAIN, 1, 1)
|
|
114
114
|
fn main(
|
|
115
115
|
@builtin(local_invocation_id) local_id: vec3<u32>,
|
|
116
|
-
@builtin(workgroup_id) workgroup_id: vec3<u32
|
|
116
|
+
@builtin(workgroup_id) workgroup_id: vec3<u32>,
|
|
117
|
+
@builtin(num_workgroups) num_wg: vec3<u32>
|
|
117
118
|
) {
|
|
118
|
-
|
|
119
|
+
// Support 2D dispatch for tensors with >65535 blocks.
|
|
120
|
+
let block_idx = workgroup_id.x + workgroup_id.y * num_wg.x;
|
|
119
121
|
let elem_idx = local_id.x;
|
|
120
122
|
|
|
121
123
|
if (block_idx >= u.num_blocks) {
|
|
@@ -106,9 +106,12 @@ fn get_q4(qs: array<u32, 32>, idx: u32) -> u32 {
|
|
|
106
106
|
@compute @workgroup_size(WORKGROUP_SIZE_VEC4, 1, 1)
|
|
107
107
|
fn main_vec4(
|
|
108
108
|
@builtin(local_invocation_id) local_id: vec3<u32>,
|
|
109
|
-
@builtin(workgroup_id) workgroup_id: vec3<u32
|
|
109
|
+
@builtin(workgroup_id) workgroup_id: vec3<u32>,
|
|
110
|
+
@builtin(num_workgroups) num_wg: vec3<u32>
|
|
110
111
|
) {
|
|
111
|
-
|
|
112
|
+
// Support 2D dispatch for tensors with >65535 blocks (e.g. large FFN weights).
|
|
113
|
+
// block_idx = flat workgroup index across both X and Y dimensions.
|
|
114
|
+
let block_idx = workgroup_id.x + workgroup_id.y * num_wg.x;
|
|
112
115
|
let thread_idx = local_id.x;
|
|
113
116
|
|
|
114
117
|
if (block_idx >= u.num_blocks) {
|
|
@@ -115,9 +115,11 @@ fn get_q4(qs: array<u32, 32>, idx: u32) -> u32 {
|
|
|
115
115
|
fn main(
|
|
116
116
|
@builtin(global_invocation_id) global_id: vec3<u32>,
|
|
117
117
|
@builtin(local_invocation_id) local_id: vec3<u32>,
|
|
118
|
-
@builtin(workgroup_id) workgroup_id: vec3<u32
|
|
118
|
+
@builtin(workgroup_id) workgroup_id: vec3<u32>,
|
|
119
|
+
@builtin(num_workgroups) num_wg: vec3<u32>
|
|
119
120
|
) {
|
|
120
|
-
|
|
121
|
+
// Support 2D dispatch for tensors with >65535 blocks.
|
|
122
|
+
let block_idx = workgroup_id.x + workgroup_id.y * num_wg.x;
|
|
121
123
|
let elem_idx = local_id.x;
|
|
122
124
|
|
|
123
125
|
if (block_idx >= u.num_blocks) {
|
|
@@ -108,9 +108,11 @@ fn get_q4(qs: array<u32, 32>, idx: u32) -> u32 {
|
|
|
108
108
|
@compute @workgroup_size(WORKGROUP_SIZE_VEC4, 1, 1)
|
|
109
109
|
fn main_vec4(
|
|
110
110
|
@builtin(local_invocation_id) local_id: vec3<u32>,
|
|
111
|
-
@builtin(workgroup_id) workgroup_id: vec3<u32
|
|
111
|
+
@builtin(workgroup_id) workgroup_id: vec3<u32>,
|
|
112
|
+
@builtin(num_workgroups) num_wg: vec3<u32>
|
|
112
113
|
) {
|
|
113
|
-
|
|
114
|
+
// Support 2D dispatch for tensors with >65535 blocks.
|
|
115
|
+
let block_idx = workgroup_id.x + workgroup_id.y * num_wg.x;
|
|
114
116
|
let thread_idx = local_id.x;
|
|
115
117
|
|
|
116
118
|
if (block_idx >= u.num_blocks) {
|
|
@@ -118,11 +118,15 @@ fn get_q4(qs: array<u32, 32>, idx: u32) -> u32 {
|
|
|
118
118
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
119
119
|
fn main(
|
|
120
120
|
@builtin(global_invocation_id) global_id: vec3<u32>,
|
|
121
|
+
@builtin(num_workgroups) num_wg: vec3<u32>,
|
|
121
122
|
@builtin(subgroup_invocation_id) sg_id: u32,
|
|
122
123
|
@builtin(subgroup_size) sg_size: u32
|
|
123
124
|
) {
|
|
124
|
-
|
|
125
|
-
|
|
125
|
+
// Support 2D dispatch for tensors with >65535 workgroups.
|
|
126
|
+
// Compute flat global thread id across both X and Y dimensions.
|
|
127
|
+
let flat_global_x = global_id.x + global_id.y * num_wg.x * WORKGROUP_SIZE;
|
|
128
|
+
let block_idx = flat_global_x / QK_K;
|
|
129
|
+
let elem_idx = flat_global_x % QK_K;
|
|
126
130
|
|
|
127
131
|
// Use block 0 for out-of-bounds threads to maintain uniform control flow
|
|
128
132
|
// (required for subgroup operations)
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* LFM2 gated short convolution kernel.
|
|
3
|
+
*
|
|
4
|
+
* Fuses B*x pre-gating, depthwise causal conv1d, and C*conv_out post-gating
|
|
5
|
+
* into a single GPU dispatch. Each thread handles one channel across all tokens
|
|
6
|
+
* sequentially, maintaining persistent conv state for autoregressive decode.
|
|
7
|
+
*/
|
|
8
|
+
|
|
9
|
+
/** Per-layer state maintained between calls. */
|
|
10
|
+
export interface GatedShortConvLayerState {
|
|
11
|
+
/** Pre-dequantized conv1d weights as GPUBuffer, shape [hiddenSize, kernelSize]. */
|
|
12
|
+
convWeightGPU: GPUBuffer;
|
|
13
|
+
|
|
14
|
+
/** Persistent conv state as GPUBuffer, shape [hiddenSize, kernelSize - 1]. */
|
|
15
|
+
convStateGPU: GPUBuffer;
|
|
16
|
+
|
|
17
|
+
/** Number of channels (hidden dimension). */
|
|
18
|
+
hiddenSize: number;
|
|
19
|
+
|
|
20
|
+
/** Conv1d kernel width (e.g., 4). */
|
|
21
|
+
kernelSize: number;
|
|
22
|
+
}
|
|
23
|
+
|
|
24
|
+
/** Tensor returned by the kernel. */
|
|
25
|
+
export interface Tensor {
|
|
26
|
+
buffer: GPUBuffer;
|
|
27
|
+
dtype: string;
|
|
28
|
+
shape: readonly number[];
|
|
29
|
+
label: string;
|
|
30
|
+
}
|
|
31
|
+
|
|
32
|
+
/** Options for runGatedShortConvGPU. */
|
|
33
|
+
export interface GatedShortConvOptions {
|
|
34
|
+
/** Number of tokens in this batch. Required. */
|
|
35
|
+
numTokens?: number;
|
|
36
|
+
|
|
37
|
+
/** Layer index for labeling/tracing. */
|
|
38
|
+
layerIdx?: number;
|
|
39
|
+
|
|
40
|
+
/** Command recorder for batched submission. */
|
|
41
|
+
recorder?: {
|
|
42
|
+
getEncoder(): GPUCommandEncoder;
|
|
43
|
+
trackTemporaryBuffer(buffer: GPUBuffer): void;
|
|
44
|
+
beginComputePass(label: string): GPUComputePassEncoder;
|
|
45
|
+
createUniformBuffer(data: ArrayBuffer, label: string): GPUBuffer;
|
|
46
|
+
device: GPUDevice;
|
|
47
|
+
} | null;
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
/**
|
|
51
|
+
* Run the LFM2 gated short convolution on GPU.
|
|
52
|
+
*
|
|
53
|
+
* @param inputTensor Tensor with shape [numTokens, 3 * hiddenSize] containing
|
|
54
|
+
* concatenated B, C, x from in_proj matmul output.
|
|
55
|
+
* @param layerState Persistent per-layer state (conv weights + conv state buffer).
|
|
56
|
+
* @param options Dispatch options.
|
|
57
|
+
* @returns Output tensor with shape [numTokens, hiddenSize].
|
|
58
|
+
*/
|
|
59
|
+
export function runGatedShortConvGPU(
|
|
60
|
+
inputTensor: Tensor,
|
|
61
|
+
layerState: GatedShortConvLayerState,
|
|
62
|
+
options?: GatedShortConvOptions
|
|
63
|
+
): Promise<Tensor>;
|
|
@@ -0,0 +1,284 @@
|
|
|
1
|
+
import { getDevice, getDeviceEpoch } from '../device.js';
|
|
2
|
+
import { WORKGROUP_SIZES } from './constants.js';
|
|
3
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
4
|
+
import { createTensor } from '../tensor.js';
|
|
5
|
+
import {
|
|
6
|
+
createUniformBufferFromData,
|
|
7
|
+
getOrCreateBindGroupLayout,
|
|
8
|
+
getOrCreatePipelineLayout,
|
|
9
|
+
} from './utils.js';
|
|
10
|
+
import { recordDispatch } from './dispatch.js';
|
|
11
|
+
|
|
12
|
+
const CONV_WORKGROUP_SIZE = WORKGROUP_SIZES.DEFAULT;
|
|
13
|
+
|
|
14
|
+
const SHADER = /* wgsl */ `
|
|
15
|
+
override WORKGROUP_SIZE: u32 = 256u;
|
|
16
|
+
|
|
17
|
+
struct Params {
|
|
18
|
+
num_tokens: u32,
|
|
19
|
+
hidden_size: u32,
|
|
20
|
+
kernel_size: u32,
|
|
21
|
+
_pad: u32,
|
|
22
|
+
}
|
|
23
|
+
|
|
24
|
+
@group(0) @binding(0) var<uniform> params: Params;
|
|
25
|
+
@group(0) @binding(1) var<storage, read> input: array<f32>;
|
|
26
|
+
@group(0) @binding(2) var<storage, read> conv_weight: array<f32>;
|
|
27
|
+
@group(0) @binding(3) var<storage, read_write> conv_state: array<f32>;
|
|
28
|
+
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
|
|
29
|
+
|
|
30
|
+
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
31
|
+
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
32
|
+
let channel = gid.x;
|
|
33
|
+
if (channel >= params.hidden_size) {
|
|
34
|
+
return;
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
let hidden_size = params.hidden_size;
|
|
38
|
+
let kernel_size = params.kernel_size;
|
|
39
|
+
let state_width = kernel_size - 1u;
|
|
40
|
+
let row_stride = 3u * hidden_size;
|
|
41
|
+
let state_base = channel * state_width;
|
|
42
|
+
let weight_base = channel * kernel_size;
|
|
43
|
+
|
|
44
|
+
for (var t: u32 = 0u; t < params.num_tokens; t = t + 1u) {
|
|
45
|
+
let row_offset = t * row_stride;
|
|
46
|
+
|
|
47
|
+
let b_val = input[row_offset + channel];
|
|
48
|
+
let c_val = input[row_offset + hidden_size + channel];
|
|
49
|
+
let x_val = input[row_offset + 2u * hidden_size + channel];
|
|
50
|
+
|
|
51
|
+
let bx = b_val * x_val;
|
|
52
|
+
|
|
53
|
+
var conv_sum: f32 = 0.0;
|
|
54
|
+
for (var k: u32 = 0u; k < state_width; k = k + 1u) {
|
|
55
|
+
conv_sum = conv_sum + conv_state[state_base + k] * conv_weight[weight_base + k];
|
|
56
|
+
}
|
|
57
|
+
conv_sum = conv_sum + bx * conv_weight[weight_base + state_width];
|
|
58
|
+
|
|
59
|
+
for (var k: u32 = 0u; k + 1u < state_width; k = k + 1u) {
|
|
60
|
+
conv_state[state_base + k] = conv_state[state_base + k + 1u];
|
|
61
|
+
}
|
|
62
|
+
if (state_width > 0u) {
|
|
63
|
+
conv_state[state_base + state_width - 1u] = bx;
|
|
64
|
+
}
|
|
65
|
+
|
|
66
|
+
output[t * hidden_size + channel] = c_val * conv_sum;
|
|
67
|
+
}
|
|
68
|
+
}
|
|
69
|
+
`;
|
|
70
|
+
|
|
71
|
+
// ======================================================================
|
|
72
|
+
// UNIFORM BUFFER
|
|
73
|
+
// ======================================================================
|
|
74
|
+
|
|
75
|
+
const UNIFORM_LAYOUT = {
|
|
76
|
+
numTokens: { offset: 0, size: 4 },
|
|
77
|
+
hiddenSize: { offset: 4, size: 4 },
|
|
78
|
+
kernelSize: { offset: 8, size: 4 },
|
|
79
|
+
_pad: { offset: 12, size: 4 },
|
|
80
|
+
};
|
|
81
|
+
|
|
82
|
+
const UNIFORM_SIZE = 16;
|
|
83
|
+
|
|
84
|
+
function buildParamsData(numTokens, hiddenSize, kernelSize) {
|
|
85
|
+
const data = new ArrayBuffer(UNIFORM_SIZE);
|
|
86
|
+
const view = new DataView(data);
|
|
87
|
+
view.setUint32(UNIFORM_LAYOUT.numTokens.offset, numTokens, true);
|
|
88
|
+
view.setUint32(UNIFORM_LAYOUT.hiddenSize.offset, hiddenSize, true);
|
|
89
|
+
view.setUint32(UNIFORM_LAYOUT.kernelSize.offset, kernelSize, true);
|
|
90
|
+
view.setUint32(UNIFORM_LAYOUT._pad.offset, 0, true);
|
|
91
|
+
return data;
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
// ======================================================================
|
|
95
|
+
// PIPELINE CACHE
|
|
96
|
+
// ======================================================================
|
|
97
|
+
|
|
98
|
+
let cachedEpoch = -1;
|
|
99
|
+
let pipeline = null;
|
|
100
|
+
let bindGroupLayout = null;
|
|
101
|
+
|
|
102
|
+
function createPipeline(device) {
|
|
103
|
+
bindGroupLayout = getOrCreateBindGroupLayout(
|
|
104
|
+
'gated_short_conv_layout',
|
|
105
|
+
[
|
|
106
|
+
{ binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } },
|
|
107
|
+
{ binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
|
|
108
|
+
{ binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
|
|
109
|
+
{ binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
|
|
110
|
+
{ binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
|
|
111
|
+
],
|
|
112
|
+
device
|
|
113
|
+
);
|
|
114
|
+
|
|
115
|
+
const module = device.createShaderModule({
|
|
116
|
+
label: 'gated_short_conv',
|
|
117
|
+
code: SHADER,
|
|
118
|
+
});
|
|
119
|
+
|
|
120
|
+
pipeline = device.createComputePipeline({
|
|
121
|
+
label: 'gated_short_conv_pipeline',
|
|
122
|
+
layout: getOrCreatePipelineLayout('gated_short_conv_pipeline_layout', [bindGroupLayout], device),
|
|
123
|
+
compute: {
|
|
124
|
+
module,
|
|
125
|
+
entryPoint: 'main',
|
|
126
|
+
constants: {
|
|
127
|
+
WORKGROUP_SIZE: CONV_WORKGROUP_SIZE,
|
|
128
|
+
},
|
|
129
|
+
},
|
|
130
|
+
});
|
|
131
|
+
}
|
|
132
|
+
|
|
133
|
+
function ensurePipeline(device) {
|
|
134
|
+
const epoch = getDeviceEpoch();
|
|
135
|
+
if (epoch !== cachedEpoch || !pipeline) {
|
|
136
|
+
createPipeline(device);
|
|
137
|
+
cachedEpoch = epoch;
|
|
138
|
+
}
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
// ======================================================================
|
|
142
|
+
// VALIDATION
|
|
143
|
+
// ======================================================================
|
|
144
|
+
|
|
145
|
+
function requireGpuBuffer(buffer, label) {
|
|
146
|
+
if (!(buffer instanceof GPUBuffer)) {
|
|
147
|
+
throw new Error(`gated_short_conv kernel requires GPUBuffer for ${label}.`);
|
|
148
|
+
}
|
|
149
|
+
}
|
|
150
|
+
|
|
151
|
+
// ======================================================================
|
|
152
|
+
// DISPATCH
|
|
153
|
+
// ======================================================================
|
|
154
|
+
|
|
155
|
+
/**
 * Dispatches the gated short-convolution compute kernel on the GPU.
 *
 * Validates the input/weight/state buffers, (re)ensures the pipeline for the
 * current device, allocates an f32 output buffer of [numTokens, hiddenSize],
 * and either records the dispatch into `options.recorder` (batched encoding)
 * or encodes + submits immediately.
 *
 * NOTE(review): conv state is presumably updated in-place by the shader via
 * binding 3 — confirm against the WGSL source; the JS side only binds it.
 *
 * @param {{buffer: GPUBuffer}} inputTensor - Activations, f32, [numTokens, hiddenSize].
 * @param {Object} layerState - Must carry convWeightGPU, convStateGPU (GPUBuffers),
 *   hiddenSize (> 0) and kernelSize (>= 2).
 * @param {Object} [options] - { numTokens (> 0, required), layerIdx, recorder }.
 * @returns {Promise<Object>} Tensor handle wrapping the output buffer.
 * @throws {Error} On missing device, non-GPUBuffer inputs, or invalid sizes.
 */
export async function runGatedShortConvGPU(inputTensor, layerState, options = {}) {
  const device = getDevice();
  if (!device) {
    throw new Error('No GPU device available for gated_short_conv.');
  }

  // A recorder is only usable if it exposes the full batching contract.
  const recorder = options.recorder ?? null;
  const useRecorder = recorder
    && typeof recorder.getEncoder === 'function'
    && typeof recorder.trackTemporaryBuffer === 'function';

  requireGpuBuffer(inputTensor?.buffer, 'inputTensor');
  requireGpuBuffer(layerState?.convWeightGPU, 'convWeightGPU');
  requireGpuBuffer(layerState?.convStateGPU, 'convStateGPU');

  const numTokens = Number(options.numTokens ?? 0);
  if (!Number.isFinite(numTokens) || numTokens <= 0) {
    throw new Error('runGatedShortConvGPU requires numTokens > 0.');
  }

  const hiddenSize = Number(layerState.hiddenSize ?? 0);
  if (!Number.isFinite(hiddenSize) || hiddenSize <= 0) {
    throw new Error('runGatedShortConvGPU requires hiddenSize > 0.');
  }

  // kernelSize < 2 would make the short conv degenerate to a pointwise op.
  const kernelSize = Number(layerState.kernelSize ?? 0);
  if (!Number.isFinite(kernelSize) || kernelSize < 2) {
    throw new Error('runGatedShortConvGPU requires kernelSize >= 2.');
  }

  ensurePipeline(device);

  const outputSize = numTokens * hiddenSize * Float32Array.BYTES_PER_ELEMENT;
  const outputBuffer = acquireBuffer(outputSize, undefined, `L${options.layerIdx ?? 0}.gated_short_conv_out`);

  // Shared pieces that were previously duplicated verbatim between the
  // recorder and non-recorder paths (bind group + output tensor); keeping
  // them in one place prevents the two copies from drifting apart.
  const makeBindGroup = (paramsBuffer) => device.createBindGroup({
    label: 'gated_short_conv_bind_group',
    layout: bindGroupLayout,
    entries: [
      { binding: 0, resource: { buffer: paramsBuffer } },
      { binding: 1, resource: { buffer: inputTensor.buffer } },
      { binding: 2, resource: { buffer: layerState.convWeightGPU } },
      { binding: 3, resource: { buffer: layerState.convStateGPU } },
      { binding: 4, resource: { buffer: outputBuffer } },
    ],
  });
  // One workgroup column per CONV_WORKGROUP_SIZE channels.
  const workgroupCounts = [Math.ceil(hiddenSize / CONV_WORKGROUP_SIZE), 1, 1];
  const makeOutputTensor = () => createTensor(
    outputBuffer,
    'f32',
    [numTokens, hiddenSize],
    `L${options.layerIdx ?? 0}.gated_short_conv`
  );

  if (useRecorder) {
    // Recorder path: the recorder owns the params buffer lifetime
    // (trackTemporaryBuffer inside createUniformBufferFromData).
    const paramsBuffer = createUniformBufferFromData(
      'gated_short_conv_params',
      buildParamsData(numTokens, hiddenSize, kernelSize),
      recorder
    );

    try {
      recordDispatch(
        recorder,
        pipeline,
        makeBindGroup(paramsBuffer),
        workgroupCounts,
        'gated_short_conv'
      );
      return makeOutputTensor();
    } catch (error) {
      // Output buffer is pooled; return it on failure so it is not leaked.
      releaseBuffer(outputBuffer);
      throw error;
    }
  }

  // Non-recorder path: encode, submit immediately, and destroy the uncached
  // params buffer once the GPU is done with it.
  const paramsBuffer = createUniformBufferFromData(
    'gated_short_conv_params',
    buildParamsData(numTokens, hiddenSize, kernelSize),
    null,
    device,
    { useCache: false }
  );
  let submitted = false;

  try {
    const bg = makeBindGroup(paramsBuffer);

    const encoder = device.createCommandEncoder({ label: 'gated_short_conv' });
    const pass = encoder.beginComputePass({ label: 'gated_short_conv_pass' });
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bg);
    pass.dispatchWorkgroups(workgroupCounts[0], workgroupCounts[1], workgroupCounts[2]);
    pass.end();
    device.queue.submit([encoder.finish()]);
    submitted = true;

    return makeOutputTensor();
  } catch (error) {
    releaseBuffer(outputBuffer);
    throw error;
  } finally {
    const destroyParams = () => paramsBuffer.destroy();
    if (submitted) {
      // Defer destruction until the submitted work completes; swallow a
      // rejected onSubmittedWorkDone the same way the destroy still runs.
      device.queue.onSubmittedWorkDone().then(destroyParams, destroyParams);
    } else {
      // Submit never happened (e.g. bind-group creation threw): safe to
      // destroy synchronously.
      destroyParams();
    }
  }
}
|
|
@@ -326,6 +326,14 @@ export {
|
|
|
326
326
|
type SplitQKVResult,
|
|
327
327
|
} from './split_qkv.js';
|
|
328
328
|
|
|
329
|
+
// Split Q and Gate (de-interleave attentionOutputGate q_proj output)
|
|
330
|
+
export {
|
|
331
|
+
runSplitQG,
|
|
332
|
+
recordSplitQG,
|
|
333
|
+
type SplitQGOptions,
|
|
334
|
+
type SplitQGResult,
|
|
335
|
+
} from './split_qg.js';
|
|
336
|
+
|
|
329
337
|
// Transpose
|
|
330
338
|
export {
|
|
331
339
|
runTranspose,
|
package/src/gpu/kernels/index.js
CHANGED
|
@@ -268,6 +268,12 @@ export {
|
|
|
268
268
|
recordSplitQKV,
|
|
269
269
|
} from './split_qkv.js';
|
|
270
270
|
|
|
271
|
+
// Split Q and Gate (de-interleave attentionOutputGate q_proj output)
|
|
272
|
+
export {
|
|
273
|
+
runSplitQG,
|
|
274
|
+
recordSplitQG,
|
|
275
|
+
} from './split_qg.js';
|
|
276
|
+
|
|
271
277
|
// Transpose
|
|
272
278
|
export {
|
|
273
279
|
runTranspose,
|