@simulatte/doppler 0.1.8 → 0.1.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +14 -1
- package/README.md +25 -6
- package/package.json +5 -3
- package/src/client/doppler-api.browser.js +6 -0
- package/src/client/doppler-api.d.ts +3 -0
- package/src/client/doppler-api.js +11 -2
- package/src/client/doppler-registry.js +3 -5
- package/src/client/doppler-registry.json +16 -0
- package/src/config/kernels/kernel-ref-digests.js +23 -21
- package/src/config/kernels/moe/mixtral.paths.json +46 -0
- package/src/config/loader.js +6 -0
- package/src/config/platforms/loader.js +3 -1
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-nosubgroups.json +16 -16
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-online.json +8 -8
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-small-attn.json +61 -0
- package/src/config/presets/kernel-paths/registry.json +7 -0
- package/src/config/presets/models/gemma3.json +2 -1
- package/src/config/presets/models/gemma4.json +61 -0
- package/src/config/presets/models/granite-docling.json +70 -0
- package/src/config/presets/models/lfm2.json +6 -1
- package/src/config/presets/models/qwen3_vl.json +40 -0
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +2 -1
- package/src/config/presets/runtime/experiments/verify/lfm2-verify.json +46 -0
- package/src/config/presets/runtime/experiments/verify/translategemma-verify.json +39 -0
- package/src/config/presets/runtime/modes/trace-layers.json +1 -0
- package/src/config/presets/runtime/tiers/gemma4-16gb.json +69 -0
- package/src/config/presets/runtime/tiers/gemma4-24gb.json +66 -0
- package/src/config/presets/runtime/tiers/gemma4-32gb.json +66 -0
- package/src/config/runtime.js +3 -0
- package/src/config/schema/debug.schema.d.ts +40 -0
- package/src/config/schema/debug.schema.js +28 -0
- package/src/config/schema/index.js +2 -0
- package/src/config/schema/inference-defaults.schema.js +1 -1
- package/src/config/schema/kernel-path.schema.d.ts +1 -0
- package/src/config/schema/memory-limits.schema.js +2 -2
- package/src/config/schema/storage.schema.js +1 -1
- package/src/converter/conversion-plan.js +1 -1
- package/src/converter/core.js +17 -8
- package/src/converter/quantizer.d.ts +5 -0
- package/src/converter/quantizer.js +15 -0
- package/src/distribution/shard-delivery.js +34 -0
- package/src/formats/rdrr/classification.js +32 -0
- package/src/gpu/kernel-runtime.js +4 -2
- package/src/gpu/kernels/attention.js +2 -1
- package/src/gpu/kernels/dequant_f16_out.wgsl +4 -2
- package/src/gpu/kernels/dequant_f16_out_vec4.wgsl +5 -2
- package/src/gpu/kernels/dequant_shared.wgsl +4 -2
- package/src/gpu/kernels/dequant_shared_vec4.wgsl +4 -2
- package/src/gpu/kernels/dequant_subgroup.wgsl +6 -2
- package/src/gpu/kernels/gated-short-conv.d.ts +63 -0
- package/src/gpu/kernels/gated-short-conv.js +284 -0
- package/src/gpu/kernels/linear-attention-core.js +37 -17
- package/src/gpu/kernels/matmul-selection.js +1 -0
- package/src/gpu/kernels/matmul.d.ts +3 -0
- package/src/gpu/kernels/matmul.js +70 -1
- package/src/gpu/kernels/matmul_gemv_subgroup.wgsl +77 -79
- package/src/gpu/kernels/sample.js +1 -3
- package/src/gpu/kernels/sample.wgsl +39 -9
- package/src/gpu/kernels/sample_f16.wgsl +38 -8
- package/src/gpu/kernels/shader-cache.js +9 -4
- package/src/inference/kv-cache/base.js +3 -10
- package/src/inference/pipelines/diffusion/pipeline.js +2 -1
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +2 -1
- package/src/inference/pipelines/text/attention/projections.d.ts +3 -0
- package/src/inference/pipelines/text/attention/projections.js +13 -2
- package/src/inference/pipelines/text/attention/record.js +1 -0
- package/src/inference/pipelines/text/attention/run.js +9 -0
- package/src/inference/pipelines/text/config.d.ts +1 -0
- package/src/inference/pipelines/text/config.js +32 -4
- package/src/inference/pipelines/text/embed.js +26 -7
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +10 -3
- package/src/inference/pipelines/text/execution-v0.js +12 -1
- package/src/inference/pipelines/text/generator-helpers.js +1 -0
- package/src/inference/pipelines/text/generator-runtime.js +14 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +9 -0
- package/src/inference/pipelines/text/generator-steps.js +46 -29
- package/src/inference/pipelines/text/generator.d.ts +5 -0
- package/src/inference/pipelines/text/generator.js +320 -166
- package/src/inference/pipelines/text/init.d.ts +2 -0
- package/src/inference/pipelines/text/init.js +19 -5
- package/src/inference/pipelines/text/layer.js +37 -8
- package/src/inference/pipelines/text/moe-gpu.js +21 -3
- package/src/inference/pipelines/text/moe-shape-validator.d.ts +9 -0
- package/src/inference/pipelines/text/moe-shape-validator.js +31 -11
- package/src/inference/pipelines/text/ops.js +123 -53
- package/src/inference/pipelines/text/probes.js +1 -0
- package/src/inference/pipelines/text/state.js +2 -0
- package/src/inference/pipelines/text.d.ts +5 -0
- package/src/inference/pipelines/text.js +59 -1
- package/src/inference/pipelines/vision/encoder.js +386 -0
- package/src/inference/pipelines/vision/image-preprocess.js +151 -0
- package/src/inference/pipelines/vision/index.js +173 -0
- package/src/inference/pipelines/vision/ops.js +78 -0
- package/src/inference/pipelines/vision/patch-embed.js +151 -0
- package/src/inference/test-harness.js +9 -7
- package/src/loader/doppler-loader.d.ts +3 -0
- package/src/loader/doppler-loader.js +20 -3
- package/src/loader/experts/expert-cache.js +6 -2
- package/src/loader/experts/expert-loader.js +6 -2
- package/src/loader/layer-loader.js +42 -3
- package/src/loader/manifest-config.js +3 -1
- package/src/loader/tensors/tensor-loader.d.ts +3 -0
- package/src/loader/tensors/tensor-loader.js +124 -3
- package/src/rules/kernels/moe.rules.mixtral.json +75 -0
- package/src/rules/kernels/softmax.rules.json +2 -0
- package/src/rules/rule-registry.d.ts +1 -0
- package/src/rules/rule-registry.js +2 -0
- package/src/storage/quickstart-downloader.d.ts +3 -0
- package/src/storage/quickstart-downloader.js +27 -30
- package/src/tooling/node-converter.js +25 -7
- package/src/tooling/node-source-runtime.js +29 -5
- package/src/tooling/node-webgpu.js +24 -7
- package/src/utils/hf-resolve-url.d.ts +16 -0
- package/src/utils/hf-resolve-url.js +17 -0
- package/src/version.js +1 -1
- package/src/tooling/node-convert.d.ts +0 -54
|
@@ -473,7 +473,7 @@ export function resolveConversionPlan(options) {
|
|
|
473
473
|
// role dtypes should not change kernel-path selection when explicit compute precision is targeted.
|
|
474
474
|
const embedDtypeRaw = normalizeWeightDtype(findTensorDtypeByRole(tensors, 'embedding'));
|
|
475
475
|
const lmHeadDtypeRaw = normalizeWeightDtype(findTensorDtypeByRole(tensors, 'lm_head'));
|
|
476
|
-
const hasVision = hasAnyTensorPattern(tensors, ['vision_', 'vision_tower', 'vision_model', 'image_encoder']);
|
|
476
|
+
const hasVision = hasAnyTensorPattern(tensors, ['vision_', 'vision_tower', 'vision_model', 'image_encoder', 'visual.']);
|
|
477
477
|
const hasAudio = hasAnyTensorPattern(tensors, ['audio_', 'audio_encoder', 'whisper', 'wav2vec']);
|
|
478
478
|
const hasProjector = hasAnyTensorPattern(tensors, ['multi_modal_projector', 'mm_projector', 'projector']);
|
|
479
479
|
const quantizationInfo = buildQuantizationInfo(
|
package/src/converter/core.js
CHANGED
|
@@ -114,6 +114,15 @@ export function resolveTensorTargetQuant(tensorName, fallbackQuant, quantization
|
|
|
114
114
|
const headQuant = quantizationInfo.lmHead ?? quantizationInfo.embeddings ?? fallback;
|
|
115
115
|
return normalizeStorageQuant(headQuant) ?? fallback;
|
|
116
116
|
}
|
|
117
|
+
if (role === 'vision') {
|
|
118
|
+
return normalizeStorageQuant(quantizationInfo.vision ?? fallback) ?? fallback;
|
|
119
|
+
}
|
|
120
|
+
if (role === 'projector') {
|
|
121
|
+
return normalizeStorageQuant(quantizationInfo.projector ?? fallback) ?? fallback;
|
|
122
|
+
}
|
|
123
|
+
if (role === 'audio') {
|
|
124
|
+
return normalizeStorageQuant(quantizationInfo.audio ?? fallback) ?? fallback;
|
|
125
|
+
}
|
|
117
126
|
return normalizeStorageQuant(quantizationInfo.weights ?? fallback) ?? fallback;
|
|
118
127
|
}
|
|
119
128
|
|
|
@@ -819,11 +828,11 @@ export function extractArchitecture(config, ggufConfig) {
|
|
|
819
828
|
vocabSize,
|
|
820
829
|
maxSeqLen,
|
|
821
830
|
ropeTheta,
|
|
822
|
-
linearNumKeyHeads
|
|
823
|
-
linearNumValueHeads
|
|
824
|
-
linearKeyHeadDim
|
|
825
|
-
linearValueHeadDim
|
|
826
|
-
linearConvKernelDim
|
|
831
|
+
linearNumKeyHeads,
|
|
832
|
+
linearNumValueHeads,
|
|
833
|
+
linearKeyHeadDim,
|
|
834
|
+
linearValueHeadDim,
|
|
835
|
+
linearConvKernelDim,
|
|
827
836
|
linearNormMode,
|
|
828
837
|
};
|
|
829
838
|
}
|
|
@@ -1056,7 +1065,7 @@ export function createManifest(
|
|
|
1056
1065
|
modelId,
|
|
1057
1066
|
modelType: resolvedModelType,
|
|
1058
1067
|
quantization: resolvedQuantization,
|
|
1059
|
-
quantizationInfo: options.quantizationInfo
|
|
1068
|
+
quantizationInfo: options.quantizationInfo,
|
|
1060
1069
|
architecture: resolvedArchitecture,
|
|
1061
1070
|
moeConfig,
|
|
1062
1071
|
inference,
|
|
@@ -1065,8 +1074,8 @@ export function createManifest(
|
|
|
1065
1074
|
totalSize: shards.reduce((sum, s) => sum + s.size, 0),
|
|
1066
1075
|
hashAlgorithm,
|
|
1067
1076
|
eos_token_id: eosTokenId,
|
|
1068
|
-
config: isDiffusion ? rawConfig : undefined,
|
|
1069
|
-
conversion: options.conversionInfo
|
|
1077
|
+
config: isDiffusion ? rawConfig : (rawConfig.vision_config ? { vision_config: rawConfig.vision_config } : undefined),
|
|
1078
|
+
conversion: options.conversionInfo,
|
|
1070
1079
|
metadata: {
|
|
1071
1080
|
source,
|
|
1072
1081
|
convertedAt: resolveConvertedAt(
|
|
@@ -73,6 +73,11 @@ export declare function dequantizeQ4KM(
|
|
|
73
73
|
shape: number[]
|
|
74
74
|
): Float32Array;
|
|
75
75
|
|
|
76
|
+
/**
 * Dequantize a Q4_K matrix whose rows were quantized independently.
 *
 * Each of the `rows` rows is read as `ceil(cols / QK_K)` consecutive Q4_K
 * blocks, expanded with `dequantizeQ4KM`, and written to the output in
 * row-major order.
 *
 * @param quantized Packed Q4_K bytes for all rows, back to back.
 * @param shape `[rows, cols]` of the dequantized matrix.
 * @returns Row-major values, `rows * cols` elements.
 */
export declare function dequantizeQ4KMRowWise(
  quantized: Uint8Array,
  shape: [number, number]
): Float32Array;
|
|
80
|
+
|
|
76
81
|
export declare function calculateQuantizationError(
|
|
77
82
|
original: Float32Array,
|
|
78
83
|
reconstructed: Float32Array
|
|
@@ -355,6 +355,21 @@ export function dequantizeQ4KM(quantized, numBlocks, shape) {
|
|
|
355
355
|
return result;
|
|
356
356
|
}
|
|
357
357
|
|
|
358
|
+
/**
 * Dequantize a row-major Q4_K matrix whose rows were quantized independently.
 *
 * Every row occupies `ceil(cols / QK_K)` Q4_K blocks of `QK4_K_BLOCK_SIZE`
 * bytes, so rows start on block boundaries even when `cols` is not a multiple
 * of `QK_K`. Any tail padding in a row's last block is dropped from the output
 * instead of spilling into the following row.
 *
 * @param {Uint8Array} quantized - Packed Q4_K bytes, `rows * ceil(cols / QK_K)` blocks.
 * @param {[number, number]} shape - Logical `[rows, cols]` of the result.
 * @returns {Float32Array} Dense row-major values, `rows * cols` elements.
 * @throws {Error} If `quantized` is shorter than the shape requires.
 */
export function dequantizeQ4KMRowWise(quantized, shape) {
  const [rows, cols] = shape;
  const blocksPerRow = Math.ceil(cols / QK_K);
  const rowByteLength = blocksPerRow * QK4_K_BLOCK_SIZE;
  const requiredBytes = rows * rowByteLength;
  if (quantized.length < requiredBytes) {
    throw new Error(
      `dequantizeQ4KMRowWise: expected at least ${requiredBytes} bytes for shape [${rows}, ${cols}], got ${quantized.length}`
    );
  }

  const result = new Float32Array(rows * cols);

  for (let row = 0; row < rows; row++) {
    const rowOffset = row * rowByteLength;
    // slice() copies into a fresh buffer at byteOffset 0; kept (rather than
    // subarray) in case dequantizeQ4KM assumes that — NOTE(review): confirm.
    const rowBytes = quantized.slice(rowOffset, rowOffset + rowByteLength);
    const rowDequantized = dequantizeQ4KM(rowBytes, blocksPerRow, [1, cols]);
    // The last block of a row may decode more than `cols` values; keep only
    // the first `cols` so padding never overwrites (or overruns) the next row.
    result.set(rowDequantized.subarray(0, cols), row * cols);
  }

  return result;
}
|
|
372
|
+
|
|
358
373
|
export function calculateQuantizationError(original, reconstructed) {
|
|
359
374
|
if (original.length !== reconstructed.length) {
|
|
360
375
|
throw new Error('Length mismatch');
|
|
@@ -1317,6 +1317,25 @@ async function clearPersistedShardState(shardIndex) {
|
|
|
1317
1317
|
await writer.abort?.();
|
|
1318
1318
|
}
|
|
1319
1319
|
|
|
1320
|
+
/**
 * Restart a shard download from byte 0 after the server refused the resume
 * `Range` request (the HTTP error path routes 416 responses here).
 *
 * The partially completed transfer is aborted first. When writing to the
 * store, the persisted partial shard is also discarded. The download is then
 * re-entered with persisted resume disabled. `__resumeRangeRecoveryAttempted`
 * is set so that a second 416 is not routed back here.
 *
 * @returns {Promise<*>} Whatever `downloadShardFromHttp` resolves to.
 */
async function recoverHttpRejectedResumeRange(
  baseUrl,
  shardInfo,
  shardIndex,
  options,
  transferState,
  writeToStore
) {
  await abortHttpTransferState(transferState);

  const shouldDiscardPersisted = Boolean(writeToStore);
  if (shouldDiscardPersisted) {
    await clearPersistedShardState(shardIndex);
  }

  const retryOptions = {
    ...options,
    __disablePersistedResume: true,
    __resumeRangeRecoveryAttempted: true,
  };
  return downloadShardFromHttp(baseUrl, shardInfo, shardIndex, retryOptions);
}
|
|
1338
|
+
|
|
1320
1339
|
async function downloadShardFromHttp(baseUrl, shardInfo, shardIndex, options = {}) {
|
|
1321
1340
|
const {
|
|
1322
1341
|
signal,
|
|
@@ -1529,6 +1548,21 @@ async function downloadShardFromHttp(baseUrl, shardInfo, shardIndex, options = {
|
|
|
1529
1548
|
throw error;
|
|
1530
1549
|
}
|
|
1531
1550
|
|
|
1551
|
+
if (
|
|
1552
|
+
error?.status === 416
|
|
1553
|
+
&& transferState.receivedBytes > 0
|
|
1554
|
+
&& options.__resumeRangeRecoveryAttempted !== true
|
|
1555
|
+
) {
|
|
1556
|
+
return recoverHttpRejectedResumeRange(
|
|
1557
|
+
baseUrl,
|
|
1558
|
+
shardInfo,
|
|
1559
|
+
shardIndex,
|
|
1560
|
+
options,
|
|
1561
|
+
transferState,
|
|
1562
|
+
writeToStore
|
|
1563
|
+
);
|
|
1564
|
+
}
|
|
1565
|
+
|
|
1532
1566
|
if (Number.isInteger(error?.status) && error.status >= 400 && error.status < 500 && error.status !== 429) {
|
|
1533
1567
|
await abortHttpTransferState(transferState);
|
|
1534
1568
|
throw error;
|
|
@@ -32,6 +32,12 @@ export function classifyTensor(name, modelType) {
|
|
|
32
32
|
return 'head';
|
|
33
33
|
}
|
|
34
34
|
|
|
35
|
+
// Multimodal groups
|
|
36
|
+
const role = classifyTensorRole(name);
|
|
37
|
+
if (role === 'vision') return 'vision';
|
|
38
|
+
if (role === 'projector') return 'projector';
|
|
39
|
+
if (role === 'audio') return 'audio';
|
|
40
|
+
|
|
35
41
|
// Extract layer index
|
|
36
42
|
const layerMatch = name.match(/layers?[._](\d+)/i);
|
|
37
43
|
if (!layerMatch) {
|
|
@@ -96,6 +102,29 @@ export function classifyTensorRole(name) {
|
|
|
96
102
|
if (lower.includes('lm_head')) return 'lm_head';
|
|
97
103
|
if (lower.endsWith('output.weight') && !lower.includes('attn_')) return 'lm_head';
|
|
98
104
|
|
|
105
|
+
// Multimodal: vision encoder tensors
|
|
106
|
+
if (lower.startsWith('vision_tower.') || lower.startsWith('vision_model.')
|
|
107
|
+
|| lower.startsWith('visual.') || lower.startsWith('model.visual.')
|
|
108
|
+
|| lower.startsWith('vision.') || lower.startsWith('model.vision.')
|
|
109
|
+
|| lower.startsWith('vision_encoder.') || lower.startsWith('image_encoder.')
|
|
110
|
+
|| lower.startsWith('image_tower.') || lower.startsWith('image.')
|
|
111
|
+
|| lower.startsWith('model.image.')) {
|
|
112
|
+
return 'vision';
|
|
113
|
+
}
|
|
114
|
+
|
|
115
|
+
// Multimodal: audio encoder tensors
|
|
116
|
+
if (lower.startsWith('audio_tower.') || lower.startsWith('audio_model.')
|
|
117
|
+
|| lower.startsWith('audio.') || lower.startsWith('model.audio.')
|
|
118
|
+
|| lower.startsWith('audio_encoder.')) {
|
|
119
|
+
return 'audio';
|
|
120
|
+
}
|
|
121
|
+
|
|
122
|
+
// Multimodal: projector tensors
|
|
123
|
+
if (lower.startsWith('multi_modal_projector.') || lower.startsWith('model.multi_modal_projector.')
|
|
124
|
+
|| lower.startsWith('mm_projector.') || lower.startsWith('model.mm_projector.')) {
|
|
125
|
+
return 'projector';
|
|
126
|
+
}
|
|
127
|
+
|
|
99
128
|
if (lower.includes('shared_expert') || /experts?[._]/.test(lower)) {
|
|
100
129
|
return 'expert';
|
|
101
130
|
}
|
|
@@ -207,6 +236,9 @@ export function getGroupType(groupId, modelType) {
|
|
|
207
236
|
}
|
|
208
237
|
if (groupId === 'embed') return 'embed';
|
|
209
238
|
if (groupId === 'head') return 'head';
|
|
239
|
+
if (groupId === 'vision') return 'vision';
|
|
240
|
+
if (groupId === 'projector') return 'projector';
|
|
241
|
+
if (groupId === 'audio') return 'audio';
|
|
210
242
|
if (groupId === 'other') return 'layer';
|
|
211
243
|
|
|
212
244
|
if (groupId.includes('.expert.')) return 'expert';
|
|
@@ -2,13 +2,15 @@
|
|
|
2
2
|
|
|
3
3
|
import { autoTuneKernels, prewarmKernels, clearKernelCaches } from './kernels/utils.js';
|
|
4
4
|
import { getRuntimeConfig } from '../config/runtime.js';
|
|
5
|
-
import { DEFAULT_KERNEL_WARMUP_CONFIG } from '../config/schema/kernel-warmup.schema.js';
|
|
6
5
|
|
|
7
6
|
|
|
8
7
|
export async function prepareKernelRuntime(
|
|
9
8
|
options = {}
|
|
10
9
|
) {
|
|
11
|
-
const kernelWarmup = getRuntimeConfig().shared?.kernelWarmup
|
|
10
|
+
const kernelWarmup = getRuntimeConfig().shared?.kernelWarmup;
|
|
11
|
+
if (!kernelWarmup) {
|
|
12
|
+
throw new Error('runtime.shared.kernelWarmup is required but missing from resolved config');
|
|
13
|
+
}
|
|
12
14
|
const {
|
|
13
15
|
prewarm = kernelWarmup.prewarm,
|
|
14
16
|
prewarmMode = kernelWarmup.prewarmMode,
|
|
@@ -513,9 +513,10 @@ function resolveAttentionPlan(
|
|
|
513
513
|
useF16KV,
|
|
514
514
|
useF16Q,
|
|
515
515
|
numHeads,
|
|
516
|
+
headDim,
|
|
516
517
|
kvLen,
|
|
518
|
+
isPaged,
|
|
517
519
|
caps,
|
|
518
|
-
headDim,
|
|
519
520
|
sharedLimit
|
|
520
521
|
);
|
|
521
522
|
const workgroups = calculateAttentionWorkgroups(adaptiveSelection.tier, seqLen, numHeads);
|
|
@@ -113,9 +113,11 @@ fn get_q4(qs: array<u32, 32>, idx: u32) -> u32 {
|
|
|
113
113
|
@compute @workgroup_size(WORKGROUP_SIZE_MAIN, 1, 1)
|
|
114
114
|
fn main(
|
|
115
115
|
@builtin(local_invocation_id) local_id: vec3<u32>,
|
|
116
|
-
@builtin(workgroup_id) workgroup_id: vec3<u32
|
|
116
|
+
@builtin(workgroup_id) workgroup_id: vec3<u32>,
|
|
117
|
+
@builtin(num_workgroups) num_wg: vec3<u32>
|
|
117
118
|
) {
|
|
118
|
-
|
|
119
|
+
// Support 2D dispatch for tensors with >65535 blocks.
|
|
120
|
+
let block_idx = workgroup_id.x + workgroup_id.y * num_wg.x;
|
|
119
121
|
let elem_idx = local_id.x;
|
|
120
122
|
|
|
121
123
|
if (block_idx >= u.num_blocks) {
|
|
@@ -106,9 +106,12 @@ fn get_q4(qs: array<u32, 32>, idx: u32) -> u32 {
|
|
|
106
106
|
@compute @workgroup_size(WORKGROUP_SIZE_VEC4, 1, 1)
|
|
107
107
|
fn main_vec4(
|
|
108
108
|
@builtin(local_invocation_id) local_id: vec3<u32>,
|
|
109
|
-
@builtin(workgroup_id) workgroup_id: vec3<u32
|
|
109
|
+
@builtin(workgroup_id) workgroup_id: vec3<u32>,
|
|
110
|
+
@builtin(num_workgroups) num_wg: vec3<u32>
|
|
110
111
|
) {
|
|
111
|
-
|
|
112
|
+
// Support 2D dispatch for tensors with >65535 blocks (e.g. large FFN weights).
|
|
113
|
+
// block_idx = flat workgroup index across both X and Y dimensions.
|
|
114
|
+
let block_idx = workgroup_id.x + workgroup_id.y * num_wg.x;
|
|
112
115
|
let thread_idx = local_id.x;
|
|
113
116
|
|
|
114
117
|
if (block_idx >= u.num_blocks) {
|
|
@@ -115,9 +115,11 @@ fn get_q4(qs: array<u32, 32>, idx: u32) -> u32 {
|
|
|
115
115
|
fn main(
|
|
116
116
|
@builtin(global_invocation_id) global_id: vec3<u32>,
|
|
117
117
|
@builtin(local_invocation_id) local_id: vec3<u32>,
|
|
118
|
-
@builtin(workgroup_id) workgroup_id: vec3<u32
|
|
118
|
+
@builtin(workgroup_id) workgroup_id: vec3<u32>,
|
|
119
|
+
@builtin(num_workgroups) num_wg: vec3<u32>
|
|
119
120
|
) {
|
|
120
|
-
|
|
121
|
+
// Support 2D dispatch for tensors with >65535 blocks.
|
|
122
|
+
let block_idx = workgroup_id.x + workgroup_id.y * num_wg.x;
|
|
121
123
|
let elem_idx = local_id.x;
|
|
122
124
|
|
|
123
125
|
if (block_idx >= u.num_blocks) {
|
|
@@ -108,9 +108,11 @@ fn get_q4(qs: array<u32, 32>, idx: u32) -> u32 {
|
|
|
108
108
|
@compute @workgroup_size(WORKGROUP_SIZE_VEC4, 1, 1)
|
|
109
109
|
fn main_vec4(
|
|
110
110
|
@builtin(local_invocation_id) local_id: vec3<u32>,
|
|
111
|
-
@builtin(workgroup_id) workgroup_id: vec3<u32
|
|
111
|
+
@builtin(workgroup_id) workgroup_id: vec3<u32>,
|
|
112
|
+
@builtin(num_workgroups) num_wg: vec3<u32>
|
|
112
113
|
) {
|
|
113
|
-
|
|
114
|
+
// Support 2D dispatch for tensors with >65535 blocks.
|
|
115
|
+
let block_idx = workgroup_id.x + workgroup_id.y * num_wg.x;
|
|
114
116
|
let thread_idx = local_id.x;
|
|
115
117
|
|
|
116
118
|
if (block_idx >= u.num_blocks) {
|
|
@@ -118,11 +118,15 @@ fn get_q4(qs: array<u32, 32>, idx: u32) -> u32 {
|
|
|
118
118
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
119
119
|
fn main(
|
|
120
120
|
@builtin(global_invocation_id) global_id: vec3<u32>,
|
|
121
|
+
@builtin(num_workgroups) num_wg: vec3<u32>,
|
|
121
122
|
@builtin(subgroup_invocation_id) sg_id: u32,
|
|
122
123
|
@builtin(subgroup_size) sg_size: u32
|
|
123
124
|
) {
|
|
124
|
-
|
|
125
|
-
|
|
125
|
+
// Support 2D dispatch for tensors with >65535 workgroups.
|
|
126
|
+
// Compute flat global thread id across both X and Y dimensions.
|
|
127
|
+
let flat_global_x = global_id.x + global_id.y * num_wg.x * WORKGROUP_SIZE;
|
|
128
|
+
let block_idx = flat_global_x / QK_K;
|
|
129
|
+
let elem_idx = flat_global_x % QK_K;
|
|
126
130
|
|
|
127
131
|
// Use block 0 for out-of-bounds threads to maintain uniform control flow
|
|
128
132
|
// (required for subgroup operations)
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* LFM2 gated short convolution kernel.
|
|
3
|
+
*
|
|
4
|
+
* Fuses B*x pre-gating, depthwise causal conv1d, and C*conv_out post-gating
|
|
5
|
+
* into a single GPU dispatch. Each thread handles one channel across all tokens
|
|
6
|
+
* sequentially, maintaining persistent conv state for autoregressive decode.
|
|
7
|
+
*/
|
|
8
|
+
|
|
9
|
+
/**
 * Per-layer state maintained between calls.
 *
 * The kernel binds both buffers as `array<f32>`, so they must hold 32-bit
 * floats.
 */
export interface GatedShortConvLayerState {
  /**
   * Pre-dequantized conv1d weights as GPUBuffer, shape [hiddenSize, kernelSize]
   * (row-major: element `channel * kernelSize + k`). Only read by the kernel.
   */
  convWeightGPU: GPUBuffer;

  /**
   * Persistent conv state as GPUBuffer, shape [hiddenSize, kernelSize - 1].
   * Updated in place by every call: for each token the row is shifted left by
   * one and the newest gated input (B * x) is appended. Successive calls
   * therefore continue the same causal convolution.
   */
  convStateGPU: GPUBuffer;

  /** Number of channels (hidden dimension). Must be > 0. */
  hiddenSize: number;

  /** Conv1d kernel width (e.g., 4). Must be >= 2. */
  kernelSize: number;
}
|
|
23
|
+
|
|
24
|
+
/**
 * Tensor returned by the kernel; the same shape is accepted as input.
 *
 * The kernel binds `buffer` as `array<f32>`, so input data must be f32. The
 * returned tensor always has dtype `'f32'` and shape `[numTokens, hiddenSize]`.
 */
export interface Tensor {
  buffer: GPUBuffer;
  dtype: string;
  shape: readonly number[];
  label: string;
}
|
|
31
|
+
|
|
32
|
+
/** Options for runGatedShortConvGPU. */
export interface GatedShortConvOptions {
  /**
   * Number of tokens in this batch. Required in practice: the property is
   * optional in the type, but runGatedShortConvGPU throws unless it is a
   * positive number.
   */
  numTokens?: number;

  /** Layer index for labeling/tracing; labels use 0 when omitted. */
  layerIdx?: number;

  /**
   * Command recorder for batched submission. Used only when it provides both
   * `getEncoder` and `trackTemporaryBuffer`. Otherwise the dispatch is
   * encoded and submitted directly on the device queue.
   */
  recorder?: {
    getEncoder(): GPUCommandEncoder;
    trackTemporaryBuffer(buffer: GPUBuffer): void;
    beginComputePass(label: string): GPUComputePassEncoder;
    createUniformBuffer(data: ArrayBuffer, label: string): GPUBuffer;
    device: GPUDevice;
  } | null;
}
|
|
49
|
+
|
|
50
|
+
/**
 * Run the LFM2 gated short convolution on GPU.
 *
 * For every channel, computes `C * conv1d(B * x)` over the tokens in order.
 * The call writes back `layerState.convStateGPU` so that the next call
 * continues the causal convolution.
 *
 * @param inputTensor Tensor with shape [numTokens, 3 * hiddenSize] containing
 *   concatenated B, C, x from in_proj matmul output (row stride 3 * hiddenSize).
 * @param layerState Persistent per-layer state (conv weights + conv state buffer).
 * @param options Dispatch options; `numTokens` must be provided.
 * @returns Output tensor with shape [numTokens, hiddenSize] and dtype 'f32'.
 * @throws If no GPU device is available, if any of the three buffers is not a
 *   GPUBuffer, or if numTokens / hiddenSize / kernelSize fail validation.
 */
export function runGatedShortConvGPU(
  inputTensor: Tensor,
  layerState: GatedShortConvLayerState,
  options?: GatedShortConvOptions
): Promise<Tensor>;
|
|
@@ -0,0 +1,284 @@
|
|
|
1
|
+
import { getDevice, getDeviceEpoch } from '../device.js';
|
|
2
|
+
import { WORKGROUP_SIZES } from './constants.js';
|
|
3
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
4
|
+
import { createTensor } from '../tensor.js';
|
|
5
|
+
import {
|
|
6
|
+
createUniformBufferFromData,
|
|
7
|
+
getOrCreateBindGroupLayout,
|
|
8
|
+
getOrCreatePipelineLayout,
|
|
9
|
+
} from './utils.js';
|
|
10
|
+
import { recordDispatch } from './dispatch.js';
|
|
11
|
+
|
|
12
|
+
// Threads per workgroup. Passed to the shader as the WORKGROUP_SIZE override
// constant and also used to size dispatches (ceil(hiddenSize / CONV_WORKGROUP_SIZE)),
// so the two always agree.
const CONV_WORKGROUP_SIZE = WORKGROUP_SIZES.DEFAULT;

// WGSL source for the fused gated short conv. One invocation per channel
// (gid.x) walks every token t in order:
//   bx          = B[t, c] * x[t, c]
//   y           = sum_{k < K-1} state[c, k] * w[c, k] + bx * w[c, K-1]
//   state[c, :] <- shifted left by one, bx appended
//   out[t, c]   = C[t, c] * y
// All buffers are array<f32>. Each `input` row is [B | C | x], so the row
// stride is 3 * hidden_size. `conv_weight` is indexed
// channel * kernel_size + k; `conv_state` is indexed
// channel * (kernel_size - 1) + k.
// The 256u default is replaced by the pipeline's override constant.
// Everything inside the template literal is runtime shader source.
const SHADER = /* wgsl */ `
override WORKGROUP_SIZE: u32 = 256u;

struct Params {
num_tokens: u32,
hidden_size: u32,
kernel_size: u32,
_pad: u32,
}

@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read> conv_weight: array<f32>;
@group(0) @binding(3) var<storage, read_write> conv_state: array<f32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let channel = gid.x;
if (channel >= params.hidden_size) {
return;
}

let hidden_size = params.hidden_size;
let kernel_size = params.kernel_size;
let state_width = kernel_size - 1u;
let row_stride = 3u * hidden_size;
let state_base = channel * state_width;
let weight_base = channel * kernel_size;

for (var t: u32 = 0u; t < params.num_tokens; t = t + 1u) {
let row_offset = t * row_stride;

let b_val = input[row_offset + channel];
let c_val = input[row_offset + hidden_size + channel];
let x_val = input[row_offset + 2u * hidden_size + channel];

let bx = b_val * x_val;

var conv_sum: f32 = 0.0;
for (var k: u32 = 0u; k < state_width; k = k + 1u) {
conv_sum = conv_sum + conv_state[state_base + k] * conv_weight[weight_base + k];
}
conv_sum = conv_sum + bx * conv_weight[weight_base + state_width];

for (var k: u32 = 0u; k + 1u < state_width; k = k + 1u) {
conv_state[state_base + k] = conv_state[state_base + k + 1u];
}
if (state_width > 0u) {
conv_state[state_base + state_width - 1u] = bx;
}

output[t * hidden_size + channel] = c_val * conv_sum;
}
}
`;
|
|
70
|
+
|
|
71
|
+
// ======================================================================
|
|
72
|
+
// UNIFORM BUFFER
|
|
73
|
+
// ======================================================================
|
|
74
|
+
|
|
75
|
+
// Byte layout of the WGSL `Params` uniform struct: four little-endian u32s.
const UNIFORM_LAYOUT = {
  numTokens: { offset: 0, size: 4 },
  hiddenSize: { offset: 4, size: 4 },
  kernelSize: { offset: 8, size: 4 },
  _pad: { offset: 12, size: 4 },
};

// Total uniform size in bytes (4 x u32).
const UNIFORM_SIZE = 16;

/**
 * Pack the kernel parameters into a fresh ArrayBuffer matching `Params`.
 *
 * @param {number} numTokens - Tokens processed in this dispatch.
 * @param {number} hiddenSize - Channel count.
 * @param {number} kernelSize - Conv1d kernel width.
 * @returns {ArrayBuffer} UNIFORM_SIZE bytes; the padding word is zero.
 */
function buildParamsData(numTokens, hiddenSize, kernelSize) {
  const fieldValues = [
    [UNIFORM_LAYOUT.numTokens, numTokens],
    [UNIFORM_LAYOUT.hiddenSize, hiddenSize],
    [UNIFORM_LAYOUT.kernelSize, kernelSize],
    [UNIFORM_LAYOUT._pad, 0],
  ];
  const packed = new ArrayBuffer(UNIFORM_SIZE);
  const writer = new DataView(packed);
  for (const [{ offset }, value] of fieldValues) {
    writer.setUint32(offset, value, true);
  }
  return packed;
}
|
|
93
|
+
|
|
94
|
+
// ======================================================================
|
|
95
|
+
// PIPELINE CACHE
|
|
96
|
+
// ======================================================================
|
|
97
|
+
|
|
98
|
+
// Module-level pipeline cache. It is rebuilt whenever getDeviceEpoch()
// differs from the epoch the cached objects were created for.
let cachedEpoch = -1;
let pipeline = null;
let bindGroupLayout = null;

// Buffer binding types in @binding order; must stay in sync with SHADER.
const GATED_SHORT_CONV_BINDING_TYPES = [
  'uniform',
  'read-only-storage',
  'read-only-storage',
  'storage',
  'storage',
];

/**
 * Build the bind group layout and compute pipeline for `device` and store
 * them in the module-level cache.
 *
 * @param {GPUDevice} device
 */
function createPipeline(device) {
  const layoutEntries = GATED_SHORT_CONV_BINDING_TYPES.map((type, binding) => ({
    binding,
    visibility: GPUShaderStage.COMPUTE,
    buffer: { type },
  }));
  bindGroupLayout = getOrCreateBindGroupLayout('gated_short_conv_layout', layoutEntries, device);

  const shaderModule = device.createShaderModule({
    label: 'gated_short_conv',
    code: SHADER,
  });
  const pipelineLayout = getOrCreatePipelineLayout(
    'gated_short_conv_pipeline_layout',
    [bindGroupLayout],
    device
  );

  pipeline = device.createComputePipeline({
    label: 'gated_short_conv_pipeline',
    layout: pipelineLayout,
    compute: {
      module: shaderModule,
      entryPoint: 'main',
      constants: { WORKGROUP_SIZE: CONV_WORKGROUP_SIZE },
    },
  });
}

/**
 * Ensure `pipeline` and `bindGroupLayout` exist for the current device epoch.
 * They are rebuilt on first use and after every epoch change.
 *
 * @param {GPUDevice} device
 */
function ensurePipeline(device) {
  const currentEpoch = getDeviceEpoch();
  const cacheIsValid = currentEpoch === cachedEpoch && Boolean(pipeline);
  if (cacheIsValid) {
    return;
  }
  createPipeline(device);
  cachedEpoch = currentEpoch;
}
|
|
140
|
+
|
|
141
|
+
// ======================================================================
|
|
142
|
+
// VALIDATION
|
|
143
|
+
// ======================================================================
|
|
144
|
+
|
|
145
|
+
/**
 * Assert that `buffer` is a real GPUBuffer before it is bound.
 *
 * @param {*} buffer - Candidate buffer.
 * @param {string} label - Name used in the error message.
 * @throws {Error} When `buffer` is not a GPUBuffer instance.
 */
function requireGpuBuffer(buffer, label) {
  const isGpuBuffer = buffer instanceof GPUBuffer;
  if (isGpuBuffer) {
    return;
  }
  throw new Error(`gated_short_conv kernel requires GPUBuffer for ${label}.`);
}
|
|
150
|
+
|
|
151
|
+
// ======================================================================
|
|
152
|
+
// DISPATCH
|
|
153
|
+
// ======================================================================
|
|
154
|
+
|
|
155
|
+
// Reject fractional sizes. They would be truncated when packed into the u32
// uniform, yet still inflate the byte size of the output allocation.
function requireIntegerDimension(value, name) {
  if (!Number.isInteger(value)) {
    throw new Error(`runGatedShortConvGPU requires integer ${name}, got ${value}.`);
  }
}

// Throw when a bound buffer is smaller than the region the shader indexes.
function requireMinBufferBytes(buffer, minBytes, label) {
  if (buffer.size < minBytes) {
    throw new Error(
      `gated_short_conv ${label} buffer is ${buffer.size} bytes; expected at least ${minBytes}.`
    );
  }
}

// Bind group for one dispatch; binding indices match the WGSL declarations.
function createGatedShortConvBindGroup(device, paramsBuffer, inputBuffer, layerState, outputBuffer) {
  return device.createBindGroup({
    label: 'gated_short_conv_bind_group',
    layout: bindGroupLayout,
    entries: [
      { binding: 0, resource: { buffer: paramsBuffer } },
      { binding: 1, resource: { buffer: inputBuffer } },
      { binding: 2, resource: { buffer: layerState.convWeightGPU } },
      { binding: 3, resource: { buffer: layerState.convStateGPU } },
      { binding: 4, resource: { buffer: outputBuffer } },
    ],
  });
}

/**
 * Run the LFM2 gated short convolution on the GPU.
 *
 * `inputTensor` holds `numTokens` rows of `[B | C | x]`, each `hiddenSize`
 * f32 values wide. For every channel the kernel computes
 * `C * conv1d(B * x)` over the tokens in order. It reads and updates
 * `layerState.convStateGPU` in place, so the next call continues the same
 * causal convolution.
 *
 * If `options.recorder` exposes `getEncoder` and `trackTemporaryBuffer`, the
 * dispatch is recorded into it. Otherwise the dispatch is submitted
 * immediately, and the temporary uniform buffer is destroyed once the queue
 * reports the work done.
 *
 * @param {{buffer: GPUBuffer, dtype?: string}} inputTensor - f32 input with at
 *   least `numTokens * 3 * hiddenSize` elements.
 * @param {{convWeightGPU: GPUBuffer, convStateGPU: GPUBuffer, hiddenSize: number, kernelSize: number}} layerState
 * @param {{numTokens?: number, layerIdx?: number, recorder?: object|null}} [options]
 * @returns {Promise<object>} f32 tensor of shape `[numTokens, hiddenSize]`.
 * @throws {Error} If no device is available; if a buffer is missing or too
 *   small; if the input dtype is not f32; or if numTokens, hiddenSize or
 *   kernelSize is not a valid integer.
 */
export async function runGatedShortConvGPU(inputTensor, layerState, options = {}) {
  const device = getDevice();
  if (!device) {
    throw new Error('No GPU device available for gated_short_conv.');
  }

  const recorder = options.recorder ?? null;
  const useRecorder = recorder
    && typeof recorder.getEncoder === 'function'
    && typeof recorder.trackTemporaryBuffer === 'function';

  requireGpuBuffer(inputTensor?.buffer, 'inputTensor');
  requireGpuBuffer(layerState?.convWeightGPU, 'convWeightGPU');
  requireGpuBuffer(layerState?.convStateGPU, 'convStateGPU');

  const numTokens = Number(options.numTokens ?? 0);
  if (!Number.isFinite(numTokens) || numTokens <= 0) {
    throw new Error('runGatedShortConvGPU requires numTokens > 0.');
  }
  requireIntegerDimension(numTokens, 'numTokens');

  const hiddenSize = Number(layerState.hiddenSize ?? 0);
  if (!Number.isFinite(hiddenSize) || hiddenSize <= 0) {
    throw new Error('runGatedShortConvGPU requires hiddenSize > 0.');
  }
  requireIntegerDimension(hiddenSize, 'hiddenSize');

  const kernelSize = Number(layerState.kernelSize ?? 0);
  if (!Number.isFinite(kernelSize) || kernelSize < 2) {
    throw new Error('runGatedShortConvGPU requires kernelSize >= 2.');
  }
  requireIntegerDimension(kernelSize, 'kernelSize');

  // The shader binds every buffer as array<f32>. Any other input dtype would
  // be reinterpreted bit-for-bit and produce garbage.
  if (typeof inputTensor.dtype === 'string' && inputTensor.dtype !== 'f32') {
    throw new Error(`runGatedShortConvGPU requires an f32 input tensor, got ${inputTensor.dtype}.`);
  }

  const bytesPerElement = Float32Array.BYTES_PER_ELEMENT;
  requireMinBufferBytes(inputTensor.buffer, numTokens * 3 * hiddenSize * bytesPerElement, 'inputTensor');
  requireMinBufferBytes(layerState.convWeightGPU, hiddenSize * kernelSize * bytesPerElement, 'convWeightGPU');
  requireMinBufferBytes(layerState.convStateGPU, hiddenSize * (kernelSize - 1) * bytesPerElement, 'convStateGPU');

  ensurePipeline(device);

  const labelPrefix = `L${options.layerIdx ?? 0}`;
  const tensorLabel = `${labelPrefix}.gated_short_conv`;
  const outputSize = numTokens * hiddenSize * bytesPerElement;
  const outputBuffer = acquireBuffer(outputSize, undefined, `${labelPrefix}.gated_short_conv_out`);
  const paramsData = buildParamsData(numTokens, hiddenSize, kernelSize);
  // One invocation per channel.
  const workgroupCount = Math.ceil(hiddenSize / CONV_WORKGROUP_SIZE);

  if (useRecorder) {
    try {
      // Created inside the try so outputBuffer is released if this throws.
      // The params buffer is handed to the recorder (lifetime managed there —
      // NOTE(review): confirm createUniformBufferFromData tracks it).
      const paramsBuffer = createUniformBufferFromData('gated_short_conv_params', paramsData, recorder);
      const bindGroup = createGatedShortConvBindGroup(
        device,
        paramsBuffer,
        inputTensor.buffer,
        layerState,
        outputBuffer
      );
      recordDispatch(recorder, pipeline, bindGroup, [workgroupCount, 1, 1], 'gated_short_conv');
      return createTensor(outputBuffer, 'f32', [numTokens, hiddenSize], tensorLabel);
    } catch (error) {
      releaseBuffer(outputBuffer);
      throw error;
    }
  }

  // Non-recorder path: encode and submit immediately.
  let paramsBuffer = null;
  let submitted = false;
  try {
    paramsBuffer = createUniformBufferFromData(
      'gated_short_conv_params',
      paramsData,
      null,
      device,
      { useCache: false }
    );
    const bindGroup = createGatedShortConvBindGroup(
      device,
      paramsBuffer,
      inputTensor.buffer,
      layerState,
      outputBuffer
    );

    const encoder = device.createCommandEncoder({ label: 'gated_short_conv' });
    const pass = encoder.beginComputePass({ label: 'gated_short_conv_pass' });
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bindGroup);
    pass.dispatchWorkgroups(workgroupCount, 1, 1);
    pass.end();
    device.queue.submit([encoder.finish()]);
    submitted = true;

    return createTensor(outputBuffer, 'f32', [numTokens, hiddenSize], tensorLabel);
  } catch (error) {
    releaseBuffer(outputBuffer);
    throw error;
  } finally {
    if (paramsBuffer) {
      if (submitted) {
        // The GPU may still read the uniform, so destroy it only after the
        // queue drains. Fire-and-forget: destroy on success or failure.
        const destroyParams = () => paramsBuffer.destroy();
        void device.queue.onSubmittedWorkDone().then(destroyParams).catch(destroyParams);
      } else {
        paramsBuffer.destroy();
      }
    }
  }
}
|