@simulatte/doppler 0.1.8 → 0.1.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. package/CHANGELOG.md +14 -1
  2. package/README.md +25 -6
  3. package/package.json +5 -3
  4. package/src/client/doppler-api.browser.js +6 -0
  5. package/src/client/doppler-api.d.ts +3 -0
  6. package/src/client/doppler-api.js +11 -2
  7. package/src/client/doppler-registry.js +3 -5
  8. package/src/client/doppler-registry.json +16 -0
  9. package/src/config/kernels/kernel-ref-digests.js +23 -21
  10. package/src/config/kernels/moe/mixtral.paths.json +46 -0
  11. package/src/config/loader.js +6 -0
  12. package/src/config/platforms/loader.js +3 -1
  13. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-nosubgroups.json +16 -16
  14. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-online.json +8 -8
  15. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-small-attn.json +61 -0
  16. package/src/config/presets/kernel-paths/registry.json +7 -0
  17. package/src/config/presets/models/gemma3.json +2 -1
  18. package/src/config/presets/models/gemma4.json +61 -0
  19. package/src/config/presets/models/granite-docling.json +70 -0
  20. package/src/config/presets/models/lfm2.json +6 -1
  21. package/src/config/presets/models/qwen3_vl.json +40 -0
  22. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +2 -1
  23. package/src/config/presets/runtime/experiments/verify/lfm2-verify.json +46 -0
  24. package/src/config/presets/runtime/experiments/verify/translategemma-verify.json +39 -0
  25. package/src/config/presets/runtime/modes/trace-layers.json +1 -0
  26. package/src/config/presets/runtime/tiers/gemma4-16gb.json +69 -0
  27. package/src/config/presets/runtime/tiers/gemma4-24gb.json +66 -0
  28. package/src/config/presets/runtime/tiers/gemma4-32gb.json +66 -0
  29. package/src/config/runtime.js +3 -0
  30. package/src/config/schema/debug.schema.d.ts +40 -0
  31. package/src/config/schema/debug.schema.js +28 -0
  32. package/src/config/schema/index.js +2 -0
  33. package/src/config/schema/inference-defaults.schema.js +1 -1
  34. package/src/config/schema/kernel-path.schema.d.ts +1 -0
  35. package/src/config/schema/memory-limits.schema.js +2 -2
  36. package/src/config/schema/storage.schema.js +1 -1
  37. package/src/converter/conversion-plan.js +1 -1
  38. package/src/converter/core.js +17 -8
  39. package/src/converter/quantizer.d.ts +5 -0
  40. package/src/converter/quantizer.js +15 -0
  41. package/src/distribution/shard-delivery.js +34 -0
  42. package/src/formats/rdrr/classification.js +32 -0
  43. package/src/gpu/kernel-runtime.js +4 -2
  44. package/src/gpu/kernels/attention.js +2 -1
  45. package/src/gpu/kernels/dequant_f16_out.wgsl +4 -2
  46. package/src/gpu/kernels/dequant_f16_out_vec4.wgsl +5 -2
  47. package/src/gpu/kernels/dequant_shared.wgsl +4 -2
  48. package/src/gpu/kernels/dequant_shared_vec4.wgsl +4 -2
  49. package/src/gpu/kernels/dequant_subgroup.wgsl +6 -2
  50. package/src/gpu/kernels/gated-short-conv.d.ts +63 -0
  51. package/src/gpu/kernels/gated-short-conv.js +284 -0
  52. package/src/gpu/kernels/linear-attention-core.js +37 -17
  53. package/src/gpu/kernels/matmul-selection.js +1 -0
  54. package/src/gpu/kernels/matmul.d.ts +3 -0
  55. package/src/gpu/kernels/matmul.js +70 -1
  56. package/src/gpu/kernels/matmul_gemv_subgroup.wgsl +77 -79
  57. package/src/gpu/kernels/sample.js +1 -3
  58. package/src/gpu/kernels/sample.wgsl +39 -9
  59. package/src/gpu/kernels/sample_f16.wgsl +38 -8
  60. package/src/gpu/kernels/shader-cache.js +9 -4
  61. package/src/inference/kv-cache/base.js +3 -10
  62. package/src/inference/pipelines/diffusion/pipeline.js +2 -1
  63. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +2 -1
  64. package/src/inference/pipelines/text/attention/projections.d.ts +3 -0
  65. package/src/inference/pipelines/text/attention/projections.js +13 -2
  66. package/src/inference/pipelines/text/attention/record.js +1 -0
  67. package/src/inference/pipelines/text/attention/run.js +9 -0
  68. package/src/inference/pipelines/text/config.d.ts +1 -0
  69. package/src/inference/pipelines/text/config.js +32 -4
  70. package/src/inference/pipelines/text/embed.js +26 -7
  71. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +10 -3
  72. package/src/inference/pipelines/text/execution-v0.js +12 -1
  73. package/src/inference/pipelines/text/generator-helpers.js +1 -0
  74. package/src/inference/pipelines/text/generator-runtime.js +14 -0
  75. package/src/inference/pipelines/text/generator-steps.d.ts +9 -0
  76. package/src/inference/pipelines/text/generator-steps.js +46 -29
  77. package/src/inference/pipelines/text/generator.d.ts +5 -0
  78. package/src/inference/pipelines/text/generator.js +320 -166
  79. package/src/inference/pipelines/text/init.d.ts +2 -0
  80. package/src/inference/pipelines/text/init.js +19 -5
  81. package/src/inference/pipelines/text/layer.js +37 -8
  82. package/src/inference/pipelines/text/moe-gpu.js +21 -3
  83. package/src/inference/pipelines/text/moe-shape-validator.d.ts +9 -0
  84. package/src/inference/pipelines/text/moe-shape-validator.js +31 -11
  85. package/src/inference/pipelines/text/ops.js +123 -53
  86. package/src/inference/pipelines/text/probes.js +1 -0
  87. package/src/inference/pipelines/text/state.js +2 -0
  88. package/src/inference/pipelines/text.d.ts +5 -0
  89. package/src/inference/pipelines/text.js +59 -1
  90. package/src/inference/pipelines/vision/encoder.js +386 -0
  91. package/src/inference/pipelines/vision/image-preprocess.js +151 -0
  92. package/src/inference/pipelines/vision/index.js +173 -0
  93. package/src/inference/pipelines/vision/ops.js +78 -0
  94. package/src/inference/pipelines/vision/patch-embed.js +151 -0
  95. package/src/inference/test-harness.js +9 -7
  96. package/src/loader/doppler-loader.d.ts +3 -0
  97. package/src/loader/doppler-loader.js +20 -3
  98. package/src/loader/experts/expert-cache.js +6 -2
  99. package/src/loader/experts/expert-loader.js +6 -2
  100. package/src/loader/layer-loader.js +42 -3
  101. package/src/loader/manifest-config.js +3 -1
  102. package/src/loader/tensors/tensor-loader.d.ts +3 -0
  103. package/src/loader/tensors/tensor-loader.js +124 -3
  104. package/src/rules/kernels/moe.rules.mixtral.json +75 -0
  105. package/src/rules/kernels/softmax.rules.json +2 -0
  106. package/src/rules/rule-registry.d.ts +1 -0
  107. package/src/rules/rule-registry.js +2 -0
  108. package/src/storage/quickstart-downloader.d.ts +3 -0
  109. package/src/storage/quickstart-downloader.js +27 -30
  110. package/src/tooling/node-converter.js +25 -7
  111. package/src/tooling/node-source-runtime.js +29 -5
  112. package/src/tooling/node-webgpu.js +24 -7
  113. package/src/utils/hf-resolve-url.d.ts +16 -0
  114. package/src/utils/hf-resolve-url.js +17 -0
  115. package/src/version.js +1 -1
  116. package/src/tooling/node-convert.d.ts +0 -54
@@ -16,7 +16,7 @@ export const DEFAULT_QUOTA_CONFIG = {
16
16
 
17
17
  export const DEFAULT_VRAM_ESTIMATION_CONFIG = {
18
18
  unifiedMemoryRatio: 0.5, // 50% of system RAM
19
- fallbackVramBytes: 2 * GB,
19
+ fallbackVramBytes: 4 * GB,
20
20
  lowVramHeadroomBytes: 500 * MB,
21
21
  };
22
22
 
@@ -473,7 +473,7 @@ export function resolveConversionPlan(options) {
473
473
  // role dtypes should not change kernel-path selection when explicit compute precision is targeted.
474
474
  const embedDtypeRaw = normalizeWeightDtype(findTensorDtypeByRole(tensors, 'embedding'));
475
475
  const lmHeadDtypeRaw = normalizeWeightDtype(findTensorDtypeByRole(tensors, 'lm_head'));
476
- const hasVision = hasAnyTensorPattern(tensors, ['vision_', 'vision_tower', 'vision_model', 'image_encoder']);
476
+ const hasVision = hasAnyTensorPattern(tensors, ['vision_', 'vision_tower', 'vision_model', 'image_encoder', 'visual.']);
477
477
  const hasAudio = hasAnyTensorPattern(tensors, ['audio_', 'audio_encoder', 'whisper', 'wav2vec']);
478
478
  const hasProjector = hasAnyTensorPattern(tensors, ['multi_modal_projector', 'mm_projector', 'projector']);
479
479
  const quantizationInfo = buildQuantizationInfo(
@@ -114,6 +114,15 @@ export function resolveTensorTargetQuant(tensorName, fallbackQuant, quantization
114
114
  const headQuant = quantizationInfo.lmHead ?? quantizationInfo.embeddings ?? fallback;
115
115
  return normalizeStorageQuant(headQuant) ?? fallback;
116
116
  }
117
+ if (role === 'vision') {
118
+ return normalizeStorageQuant(quantizationInfo.vision ?? fallback) ?? fallback;
119
+ }
120
+ if (role === 'projector') {
121
+ return normalizeStorageQuant(quantizationInfo.projector ?? fallback) ?? fallback;
122
+ }
123
+ if (role === 'audio') {
124
+ return normalizeStorageQuant(quantizationInfo.audio ?? fallback) ?? fallback;
125
+ }
117
126
  return normalizeStorageQuant(quantizationInfo.weights ?? fallback) ?? fallback;
118
127
  }
119
128
 
@@ -819,11 +828,11 @@ export function extractArchitecture(config, ggufConfig) {
819
828
  vocabSize,
820
829
  maxSeqLen,
821
830
  ropeTheta,
822
- linearNumKeyHeads: linearNumKeyHeads ?? undefined,
823
- linearNumValueHeads: linearNumValueHeads ?? undefined,
824
- linearKeyHeadDim: linearKeyHeadDim ?? undefined,
825
- linearValueHeadDim: linearValueHeadDim ?? undefined,
826
- linearConvKernelDim: linearConvKernelDim ?? undefined,
831
+ linearNumKeyHeads,
832
+ linearNumValueHeads,
833
+ linearKeyHeadDim,
834
+ linearValueHeadDim,
835
+ linearConvKernelDim,
827
836
  linearNormMode,
828
837
  };
829
838
  }
@@ -1056,7 +1065,7 @@ export function createManifest(
1056
1065
  modelId,
1057
1066
  modelType: resolvedModelType,
1058
1067
  quantization: resolvedQuantization,
1059
- quantizationInfo: options.quantizationInfo ?? undefined,
1068
+ quantizationInfo: options.quantizationInfo,
1060
1069
  architecture: resolvedArchitecture,
1061
1070
  moeConfig,
1062
1071
  inference,
@@ -1065,8 +1074,8 @@ export function createManifest(
1065
1074
  totalSize: shards.reduce((sum, s) => sum + s.size, 0),
1066
1075
  hashAlgorithm,
1067
1076
  eos_token_id: eosTokenId,
1068
- config: isDiffusion ? rawConfig : undefined,
1069
- conversion: options.conversionInfo ?? undefined,
1077
+ config: isDiffusion ? rawConfig : (rawConfig.vision_config ? { vision_config: rawConfig.vision_config } : undefined),
1078
+ conversion: options.conversionInfo,
1070
1079
  metadata: {
1071
1080
  source,
1072
1081
  convertedAt: resolveConvertedAt(
@@ -73,6 +73,11 @@ export declare function dequantizeQ4KM(
73
73
  shape: number[]
74
74
  ): Float32Array;
75
75
 
76
/**
 * Dequantize a Q4_K_M matrix that was quantized row by row.
 *
 * Each row is its own run of ceil(cols / QK_K) blocks; the result is a dense
 * row-major array holding `rows * cols` values.
 *
 * @param quantized Row-wise Q4_K_M bytes.
 * @param shape `[rows, cols]` of the dequantized matrix.
 * @returns The dequantized values in row-major order.
 */
export declare function dequantizeQ4KMRowWise(quantized: Uint8Array, shape: [number, number]): Float32Array;
80
+
76
81
  export declare function calculateQuantizationError(
77
82
  original: Float32Array,
78
83
  reconstructed: Float32Array
@@ -355,6 +355,21 @@ export function dequantizeQ4KM(quantized, numBlocks, shape) {
355
355
  return result;
356
356
  }
357
357
 
358
/**
 * Dequantize a Q4_K_M matrix that was quantized one row at a time.
 *
 * Each of the `rows` rows occupies ceil(cols / QK_K) Q4_K blocks of
 * QK4_K_BLOCK_SIZE bytes, so every row starts on a block boundary even when
 * `cols` is not a multiple of QK_K. Rows are dequantized independently and the
 * first `cols` values of each are packed into a dense row-major result.
 *
 * @param {Uint8Array} quantized - Row-wise Q4_K_M bytes.
 * @param {[number, number]} shape - `[rows, cols]` of the dequantized matrix.
 * @returns {Float32Array} `rows * cols` values in row-major order.
 * @throws {Error} If `shape` is not two non-negative integers, or `quantized`
 *   is shorter than `rows * ceil(cols / QK_K) * QK4_K_BLOCK_SIZE` bytes.
 */
export function dequantizeQ4KMRowWise(quantized, shape) {
  const [rows, cols] = shape;
  if (!Number.isInteger(rows) || !Number.isInteger(cols) || rows < 0 || cols < 0) {
    throw new Error(`dequantizeQ4KMRowWise: invalid shape [${rows}, ${cols}]`);
  }

  const blocksPerRow = Math.ceil(cols / QK_K);
  const rowByteLength = blocksPerRow * QK4_K_BLOCK_SIZE;
  const requiredBytes = rows * rowByteLength;
  // Fail early instead of handing a silently truncated slice to the block decoder.
  if (quantized.length < requiredBytes) {
    throw new Error(
      `dequantizeQ4KMRowWise: expected at least ${requiredBytes} bytes for shape [${rows}, ${cols}], got ${quantized.length}`
    );
  }

  const result = new Float32Array(rows * cols);
  if (cols === 0) {
    return result;
  }

  for (let row = 0; row < rows; row++) {
    const rowOffset = row * rowByteLength;
    // slice() (not subarray()) gives dequantizeQ4KM its own zero-offset buffer,
    // in case it indexes the underlying ArrayBuffer directly.
    const rowBytes = quantized.slice(rowOffset, rowOffset + rowByteLength);
    const rowValues = dequantizeQ4KM(rowBytes, blocksPerRow, [1, cols]);
    // Copy only `cols` values so a block-padded decode cannot spill into the
    // next row (or past the end of `result` on the last row).
    const rowSlice = rowValues.length > cols ? rowValues.subarray(0, cols) : rowValues;
    result.set(rowSlice, row * cols);
  }

  return result;
}
372
+
358
373
  export function calculateQuantizationError(original, reconstructed) {
359
374
  if (original.length !== reconstructed.length) {
360
375
  throw new Error('Length mismatch');
@@ -1317,6 +1317,25 @@ async function clearPersistedShardState(shardIndex) {
1317
1317
  await writer.abort?.();
1318
1318
  }
1319
1319
 
1320
// Options merged into the retry so it restarts from byte zero and cannot
// trigger this recovery path a second time.
const RESUME_RANGE_RETRY_FLAGS = Object.freeze({
  __disablePersistedResume: true,
  __resumeRangeRecoveryAttempted: true,
});

/**
 * Restart a shard download from scratch after the server refused the resume
 * Range request (the download loop calls this on HTTP 416 when some bytes had
 * already been received).
 *
 * Steps, in order: abort the current transfer state; when persisting to the
 * store, clear the saved partial shard; then re-enter downloadShardFromHttp
 * with the caller's options plus RESUME_RANGE_RETRY_FLAGS.
 *
 * @returns {Promise<*>} Whatever downloadShardFromHttp resolves to for the retry.
 */
async function recoverHttpRejectedResumeRange(
  baseUrl,
  shardInfo,
  shardIndex,
  options,
  transferState,
  writeToStore
) {
  await abortHttpTransferState(transferState);

  if (writeToStore) {
    await clearPersistedShardState(shardIndex);
  }

  const retryOptions = { ...options, ...RESUME_RANGE_RETRY_FLAGS };
  return downloadShardFromHttp(baseUrl, shardInfo, shardIndex, retryOptions);
}
1338
+
1320
1339
  async function downloadShardFromHttp(baseUrl, shardInfo, shardIndex, options = {}) {
1321
1340
  const {
1322
1341
  signal,
@@ -1529,6 +1548,21 @@ async function downloadShardFromHttp(baseUrl, shardInfo, shardIndex, options = {
1529
1548
  throw error;
1530
1549
  }
1531
1550
 
1551
+ if (
1552
+ error?.status === 416
1553
+ && transferState.receivedBytes > 0
1554
+ && options.__resumeRangeRecoveryAttempted !== true
1555
+ ) {
1556
+ return recoverHttpRejectedResumeRange(
1557
+ baseUrl,
1558
+ shardInfo,
1559
+ shardIndex,
1560
+ options,
1561
+ transferState,
1562
+ writeToStore
1563
+ );
1564
+ }
1565
+
1532
1566
  if (Number.isInteger(error?.status) && error.status >= 400 && error.status < 500 && error.status !== 429) {
1533
1567
  await abortHttpTransferState(transferState);
1534
1568
  throw error;
@@ -32,6 +32,12 @@ export function classifyTensor(name, modelType) {
32
32
  return 'head';
33
33
  }
34
34
 
35
+ // Multimodal groups
36
+ const role = classifyTensorRole(name);
37
+ if (role === 'vision') return 'vision';
38
+ if (role === 'projector') return 'projector';
39
+ if (role === 'audio') return 'audio';
40
+
35
41
  // Extract layer index
36
42
  const layerMatch = name.match(/layers?[._](\d+)/i);
37
43
  if (!layerMatch) {
@@ -96,6 +102,29 @@ export function classifyTensorRole(name) {
96
102
  if (lower.includes('lm_head')) return 'lm_head';
97
103
  if (lower.endsWith('output.weight') && !lower.includes('attn_')) return 'lm_head';
98
104
 
105
+ // Multimodal: vision encoder tensors
106
+ if (lower.startsWith('vision_tower.') || lower.startsWith('vision_model.')
107
+ || lower.startsWith('visual.') || lower.startsWith('model.visual.')
108
+ || lower.startsWith('vision.') || lower.startsWith('model.vision.')
109
+ || lower.startsWith('vision_encoder.') || lower.startsWith('image_encoder.')
110
+ || lower.startsWith('image_tower.') || lower.startsWith('image.')
111
+ || lower.startsWith('model.image.')) {
112
+ return 'vision';
113
+ }
114
+
115
+ // Multimodal: audio encoder tensors
116
+ if (lower.startsWith('audio_tower.') || lower.startsWith('audio_model.')
117
+ || lower.startsWith('audio.') || lower.startsWith('model.audio.')
118
+ || lower.startsWith('audio_encoder.')) {
119
+ return 'audio';
120
+ }
121
+
122
+ // Multimodal: projector tensors
123
+ if (lower.startsWith('multi_modal_projector.') || lower.startsWith('model.multi_modal_projector.')
124
+ || lower.startsWith('mm_projector.') || lower.startsWith('model.mm_projector.')) {
125
+ return 'projector';
126
+ }
127
+
99
128
  if (lower.includes('shared_expert') || /experts?[._]/.test(lower)) {
100
129
  return 'expert';
101
130
  }
@@ -207,6 +236,9 @@ export function getGroupType(groupId, modelType) {
207
236
  }
208
237
  if (groupId === 'embed') return 'embed';
209
238
  if (groupId === 'head') return 'head';
239
+ if (groupId === 'vision') return 'vision';
240
+ if (groupId === 'projector') return 'projector';
241
+ if (groupId === 'audio') return 'audio';
210
242
  if (groupId === 'other') return 'layer';
211
243
 
212
244
  if (groupId.includes('.expert.')) return 'expert';
@@ -2,13 +2,15 @@
2
2
 
3
3
  import { autoTuneKernels, prewarmKernels, clearKernelCaches } from './kernels/utils.js';
4
4
  import { getRuntimeConfig } from '../config/runtime.js';
5
- import { DEFAULT_KERNEL_WARMUP_CONFIG } from '../config/schema/kernel-warmup.schema.js';
6
5
 
7
6
 
8
7
  export async function prepareKernelRuntime(
9
8
  options = {}
10
9
  ) {
11
- const kernelWarmup = getRuntimeConfig().shared?.kernelWarmup ?? DEFAULT_KERNEL_WARMUP_CONFIG;
10
+ const kernelWarmup = getRuntimeConfig().shared?.kernelWarmup;
11
+ if (!kernelWarmup) {
12
+ throw new Error('runtime.shared.kernelWarmup is required but missing from resolved config');
13
+ }
12
14
  const {
13
15
  prewarm = kernelWarmup.prewarm,
14
16
  prewarmMode = kernelWarmup.prewarmMode,
@@ -513,9 +513,10 @@ function resolveAttentionPlan(
513
513
  useF16KV,
514
514
  useF16Q,
515
515
  numHeads,
516
+ headDim,
516
517
  kvLen,
518
+ isPaged,
517
519
  caps,
518
- headDim,
519
520
  sharedLimit
520
521
  );
521
522
  const workgroups = calculateAttentionWorkgroups(adaptiveSelection.tier, seqLen, numHeads);
@@ -113,9 +113,11 @@ fn get_q4(qs: array<u32, 32>, idx: u32) -> u32 {
113
113
  @compute @workgroup_size(WORKGROUP_SIZE_MAIN, 1, 1)
114
114
  fn main(
115
115
  @builtin(local_invocation_id) local_id: vec3<u32>,
116
- @builtin(workgroup_id) workgroup_id: vec3<u32>
116
+ @builtin(workgroup_id) workgroup_id: vec3<u32>,
117
+ @builtin(num_workgroups) num_wg: vec3<u32>
117
118
  ) {
118
- let block_idx = workgroup_id.x;
119
+ // Support 2D dispatch for tensors with >65535 blocks.
120
+ let block_idx = workgroup_id.x + workgroup_id.y * num_wg.x;
119
121
  let elem_idx = local_id.x;
120
122
 
121
123
  if (block_idx >= u.num_blocks) {
@@ -106,9 +106,12 @@ fn get_q4(qs: array<u32, 32>, idx: u32) -> u32 {
106
106
  @compute @workgroup_size(WORKGROUP_SIZE_VEC4, 1, 1)
107
107
  fn main_vec4(
108
108
  @builtin(local_invocation_id) local_id: vec3<u32>,
109
- @builtin(workgroup_id) workgroup_id: vec3<u32>
109
+ @builtin(workgroup_id) workgroup_id: vec3<u32>,
110
+ @builtin(num_workgroups) num_wg: vec3<u32>
110
111
  ) {
111
- let block_idx = workgroup_id.x;
112
+ // Support 2D dispatch for tensors with >65535 blocks (e.g. large FFN weights).
113
+ // block_idx = flat workgroup index across both X and Y dimensions.
114
+ let block_idx = workgroup_id.x + workgroup_id.y * num_wg.x;
112
115
  let thread_idx = local_id.x;
113
116
 
114
117
  if (block_idx >= u.num_blocks) {
@@ -115,9 +115,11 @@ fn get_q4(qs: array<u32, 32>, idx: u32) -> u32 {
115
115
  fn main(
116
116
  @builtin(global_invocation_id) global_id: vec3<u32>,
117
117
  @builtin(local_invocation_id) local_id: vec3<u32>,
118
- @builtin(workgroup_id) workgroup_id: vec3<u32>
118
+ @builtin(workgroup_id) workgroup_id: vec3<u32>,
119
+ @builtin(num_workgroups) num_wg: vec3<u32>
119
120
  ) {
120
- let block_idx = workgroup_id.x;
121
+ // Support 2D dispatch for tensors with >65535 blocks.
122
+ let block_idx = workgroup_id.x + workgroup_id.y * num_wg.x;
121
123
  let elem_idx = local_id.x;
122
124
 
123
125
  if (block_idx >= u.num_blocks) {
@@ -108,9 +108,11 @@ fn get_q4(qs: array<u32, 32>, idx: u32) -> u32 {
108
108
  @compute @workgroup_size(WORKGROUP_SIZE_VEC4, 1, 1)
109
109
  fn main_vec4(
110
110
  @builtin(local_invocation_id) local_id: vec3<u32>,
111
- @builtin(workgroup_id) workgroup_id: vec3<u32>
111
+ @builtin(workgroup_id) workgroup_id: vec3<u32>,
112
+ @builtin(num_workgroups) num_wg: vec3<u32>
112
113
  ) {
113
- let block_idx = workgroup_id.x;
114
+ // Support 2D dispatch for tensors with >65535 blocks.
115
+ let block_idx = workgroup_id.x + workgroup_id.y * num_wg.x;
114
116
  let thread_idx = local_id.x;
115
117
 
116
118
  if (block_idx >= u.num_blocks) {
@@ -118,11 +118,15 @@ fn get_q4(qs: array<u32, 32>, idx: u32) -> u32 {
118
118
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
119
119
  fn main(
120
120
  @builtin(global_invocation_id) global_id: vec3<u32>,
121
+ @builtin(num_workgroups) num_wg: vec3<u32>,
121
122
  @builtin(subgroup_invocation_id) sg_id: u32,
122
123
  @builtin(subgroup_size) sg_size: u32
123
124
  ) {
124
- let block_idx = global_id.x / QK_K;
125
- let elem_idx = global_id.x % QK_K;
125
+ // Support 2D dispatch for tensors with >65535 workgroups.
126
+ // Compute flat global thread id across both X and Y dimensions.
127
+ let flat_global_x = global_id.x + global_id.y * num_wg.x * WORKGROUP_SIZE;
128
+ let block_idx = flat_global_x / QK_K;
129
+ let elem_idx = flat_global_x % QK_K;
126
130
 
127
131
  // Use block 0 for out-of-bounds threads to maintain uniform control flow
128
132
  // (required for subgroup operations)
@@ -0,0 +1,63 @@
1
/**
 * LFM2 gated short convolution kernel.
 *
 * Fuses B*x pre-gating, depthwise causal conv1d, and C*conv_out post-gating
 * into a single GPU dispatch. Each thread handles one channel across all tokens
 * sequentially, maintaining persistent conv state for autoregressive decode.
 */

/** Per-layer state maintained between calls. */
export interface GatedShortConvLayerState {
  /**
   * Pre-dequantized f32 conv1d weights as GPUBuffer, shape [hiddenSize, kernelSize].
   * The last tap of each channel multiplies the current token's B*x.
   */
  convWeightGPU: GPUBuffer;

  /**
   * Persistent f32 conv state as GPUBuffer, shape [hiddenSize, kernelSize - 1],
   * oldest value first. Every call reads and overwrites it in place, so it
   * carries history into the next call.
   */
  convStateGPU: GPUBuffer;

  /** Number of channels (hidden dimension). Must be > 0. */
  hiddenSize: number;

  /** Conv1d kernel width (e.g., 4). Must be >= 2. */
  kernelSize: number;
}

/** Tensor returned by the kernel. */
export interface Tensor {
  buffer: GPUBuffer;
  dtype: string;
  shape: readonly number[];
  label: string;
}

/** Options for runGatedShortConvGPU. */
export interface GatedShortConvOptions {
  /**
   * Number of tokens in this batch. Declared optional, but the implementation
   * rejects a missing or non-positive value.
   */
  numTokens?: number;

  /** Layer index used in buffer/tensor labels (`L{layerIdx}.…`); defaults to 0. */
  layerIdx?: number;

  /**
   * Command recorder for batched submission. The dispatch is recorded only when
   * both `getEncoder` and `trackTemporaryBuffer` are functions; otherwise the
   * kernel encodes and submits its own command buffer.
   */
  recorder?: {
    getEncoder(): GPUCommandEncoder;
    trackTemporaryBuffer(buffer: GPUBuffer): void;
    beginComputePass(label: string): GPUComputePassEncoder;
    createUniformBuffer(data: ArrayBuffer, label: string): GPUBuffer;
    device: GPUDevice;
  } | null;
}

/**
 * Run the LFM2 gated short convolution on GPU.
 *
 * @param inputTensor Tensor with shape [numTokens, 3 * hiddenSize] containing
 *   concatenated B, C, x from in_proj matmul output. The shader reads it as f32.
 * @param layerState Persistent per-layer state (conv weights + conv state buffer).
 *   `convStateGPU` is updated in place.
 * @param options Dispatch options; `numTokens` must be provided and > 0.
 * @returns Promise of an f32 output tensor with shape [numTokens, hiddenSize].
 *   The promise rejects when no GPU device is available, a buffer is not a
 *   GPUBuffer, or numTokens / hiddenSize / kernelSize are out of range.
 */
export function runGatedShortConvGPU(
  inputTensor: Tensor,
  layerState: GatedShortConvLayerState,
  options?: GatedShortConvOptions
): Promise<Tensor>;
@@ -0,0 +1,284 @@
1
+ import { getDevice, getDeviceEpoch } from '../device.js';
2
+ import { WORKGROUP_SIZES } from './constants.js';
3
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
4
+ import { createTensor } from '../tensor.js';
5
+ import {
6
+ createUniformBufferFromData,
7
+ getOrCreateBindGroupLayout,
8
+ getOrCreatePipelineLayout,
9
+ } from './utils.js';
10
+ import { recordDispatch } from './dispatch.js';
11
+
12
// Threads per workgroup for the conv dispatch. The pipeline passes it to the
// shader's `WORKGROUP_SIZE` override constant, and dispatch sizes are computed
// as ceil(hiddenSize / CONV_WORKGROUP_SIZE).
const CONV_WORKGROUP_SIZE = WORKGROUP_SIZES.DEFAULT;

// WGSL for the fused gated short conv. One invocation per channel (gid.x); it
// walks the tokens in order, so the per-channel state it reads and rewrites
// stays causal. Bindings:
//   0 params       uniform     { num_tokens, hidden_size, kernel_size, _pad }
//   1 input        read        num_tokens rows of [B | C | x], row stride 3 * hidden_size
//   2 conv_weight  read        kernel_size taps per channel; the last tap multiplies B*x
//   3 conv_state   read_write  kernel_size - 1 previous B*x values per channel, oldest first
//   4 output       read_write  num_tokens rows of hidden_size values: C * conv
const SHADER = /* wgsl */ `
override WORKGROUP_SIZE: u32 = 256u;

struct Params {
num_tokens: u32,
hidden_size: u32,
kernel_size: u32,
_pad: u32,
}

@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read> conv_weight: array<f32>;
@group(0) @binding(3) var<storage, read_write> conv_state: array<f32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let channel = gid.x;
if (channel >= params.hidden_size) {
return;
}

let hidden_size = params.hidden_size;
let kernel_size = params.kernel_size;
let state_width = kernel_size - 1u;
let row_stride = 3u * hidden_size;
let state_base = channel * state_width;
let weight_base = channel * kernel_size;

for (var t: u32 = 0u; t < params.num_tokens; t = t + 1u) {
let row_offset = t * row_stride;

let b_val = input[row_offset + channel];
let c_val = input[row_offset + hidden_size + channel];
let x_val = input[row_offset + 2u * hidden_size + channel];

let bx = b_val * x_val;

var conv_sum: f32 = 0.0;
for (var k: u32 = 0u; k < state_width; k = k + 1u) {
conv_sum = conv_sum + conv_state[state_base + k] * conv_weight[weight_base + k];
}
conv_sum = conv_sum + bx * conv_weight[weight_base + state_width];

for (var k: u32 = 0u; k + 1u < state_width; k = k + 1u) {
conv_state[state_base + k] = conv_state[state_base + k + 1u];
}
if (state_width > 0u) {
conv_state[state_base + state_width - 1u] = bx;
}

output[t * hidden_size + channel] = c_val * conv_sum;
}
}
`;
70
+
71
+ // ======================================================================
72
+ // UNIFORM BUFFER
73
+ // ======================================================================
74
+
75
// Byte layout of the shader's `Params` uniform: four little-endian u32 words.
const UNIFORM_LAYOUT = {
  numTokens: { offset: 0, size: 4 },
  hiddenSize: { offset: 4, size: 4 },
  kernelSize: { offset: 8, size: 4 },
  _pad: { offset: 12, size: 4 },
};

// Size in bytes of the packed uniform (4 x u32).
const UNIFORM_SIZE = 16;

/**
 * Pack the dispatch parameters into a fresh ArrayBuffer laid out like `Params`.
 *
 * Fields are written in UNIFORM_LAYOUT order with DataView.setUint32
 * (little-endian); the trailing pad word is always zero.
 *
 * @param {number} numTokens
 * @param {number} hiddenSize
 * @param {number} kernelSize
 * @returns {ArrayBuffer} UNIFORM_SIZE bytes.
 */
function buildParamsData(numTokens, hiddenSize, kernelSize) {
  const fieldValues = { numTokens, hiddenSize, kernelSize, _pad: 0 };
  const packed = new ArrayBuffer(UNIFORM_SIZE);
  const writer = new DataView(packed);
  for (const [field, { offset }] of Object.entries(UNIFORM_LAYOUT)) {
    writer.setUint32(offset, fieldValues[field], true);
  }
  return packed;
}
93
+
94
+ // ======================================================================
95
+ // PIPELINE CACHE
96
+ // ======================================================================
97
+
98
// Module-level pipeline cache. It is (re)built before first use and whenever
// getDeviceEpoch() reports a different epoch than the one it was built for.
let cachedEpoch = -1;
let pipeline = null;
let bindGroupLayout = null;

// Buffer binding types for bindings 0..4, in shader order:
// params, input, conv_weight, conv_state, output.
const GATED_SHORT_CONV_BINDING_TYPES = [
  'uniform',
  'read-only-storage',
  'read-only-storage',
  'storage',
  'storage',
];

/**
 * Build the bind-group layout and compute pipeline for `device` and store them
 * in the module cache (`bindGroupLayout`, `pipeline`).
 */
function createPipeline(device) {
  const layoutEntries = GATED_SHORT_CONV_BINDING_TYPES.map((type, binding) => ({
    binding,
    visibility: GPUShaderStage.COMPUTE,
    buffer: { type },
  }));
  bindGroupLayout = getOrCreateBindGroupLayout('gated_short_conv_layout', layoutEntries, device);

  const shaderModule = device.createShaderModule({ label: 'gated_short_conv', code: SHADER });
  const pipelineLayout = getOrCreatePipelineLayout(
    'gated_short_conv_pipeline_layout',
    [bindGroupLayout],
    device
  );

  pipeline = device.createComputePipeline({
    label: 'gated_short_conv_pipeline',
    layout: pipelineLayout,
    compute: {
      module: shaderModule,
      entryPoint: 'main',
      // Specializes the shader's `override WORKGROUP_SIZE`.
      constants: { WORKGROUP_SIZE: CONV_WORKGROUP_SIZE },
    },
  });
}

/** Make sure the cached pipeline exists and matches the current device epoch. */
function ensurePipeline(device) {
  const currentEpoch = getDeviceEpoch();
  if (pipeline && currentEpoch === cachedEpoch) {
    return;
  }
  createPipeline(device);
  cachedEpoch = currentEpoch;
}
140
+
141
+ // ======================================================================
142
+ // VALIDATION
143
+ // ======================================================================
144
+
145
/**
 * Guard that `candidate` is a real GPUBuffer before it is bound.
 *
 * @param {*} candidate - Value to check.
 * @param {string} label - Name used in the error message.
 * @throws {Error} If `candidate` is not a GPUBuffer instance.
 */
function requireGpuBuffer(candidate, label) {
  if (candidate instanceof GPUBuffer) {
    return;
  }
  throw new Error(`gated_short_conv kernel requires GPUBuffer for ${label}.`);
}
150
+
151
+ // ======================================================================
152
+ // DISPATCH
153
+ // ======================================================================
154
+
155
/**
 * Throw `message` unless `raw` (nullish treated as 0) is an integer >= `min`.
 *
 * @returns {number} The validated value.
 */
function requireIntegerAtLeast(raw, min, message) {
  const value = Number(raw ?? 0);
  // Number.isInteger also rejects NaN and +/-Infinity. A fractional size would
  // be truncated by the u32 uniform while the output buffer used the raw value.
  if (!Number.isInteger(value) || value < min) {
    throw new Error(message);
  }
  return value;
}

/**
 * Throw if `buffer` reports fewer bytes than `elementCount` f32 values need.
 * WebGPU does not fault on out-of-bounds shader access, so an undersized buffer
 * would otherwise silently produce wrong values.
 */
function requireF32Capacity(buffer, elementCount, label) {
  const requiredBytes = elementCount * Float32Array.BYTES_PER_ELEMENT;
  if (typeof buffer.size === 'number' && buffer.size < requiredBytes) {
    throw new Error(
      `gated_short_conv ${label} buffer holds ${buffer.size} bytes; at least ${requiredBytes} are required.`
    );
  }
}

/** Bind params, input, weights, state and output in the order the shader declares them. */
function createGatedShortConvBindGroup(device, paramsBuffer, inputBuffer, layerState, outputBuffer) {
  return device.createBindGroup({
    label: 'gated_short_conv_bind_group',
    layout: bindGroupLayout,
    entries: [
      { binding: 0, resource: { buffer: paramsBuffer } },
      { binding: 1, resource: { buffer: inputBuffer } },
      { binding: 2, resource: { buffer: layerState.convWeightGPU } },
      { binding: 3, resource: { buffer: layerState.convStateGPU } },
      { binding: 4, resource: { buffer: outputBuffer } },
    ],
  });
}

/**
 * Run the LFM2 gated short convolution on the GPU.
 *
 * For each channel the shader computes bx = B * x, convolves bx with the
 * channel's previous kernelSize - 1 values kept in `convStateGPU`, multiplies
 * the result by C, and shifts bx into the state. `convStateGPU` is therefore
 * updated in place.
 *
 * @param {object} inputTensor - f32 tensor whose buffer holds numTokens rows of
 *   [B | C | x], each hiddenSize wide.
 * @param {object} layerState - `{ convWeightGPU, convStateGPU, hiddenSize, kernelSize }`.
 * @param {object} [options] - `{ numTokens, layerIdx, recorder }`. numTokens is required.
 *   With a recorder exposing getEncoder/trackTemporaryBuffer the dispatch is
 *   recorded; otherwise it is submitted immediately.
 * @returns {Promise<object>} f32 tensor with shape [numTokens, hiddenSize].
 * @throws {Error} No device; a non-GPUBuffer or undersized buffer; non-f32 input;
 *   numTokens/hiddenSize not positive integers; kernelSize not an integer >= 2.
 */
export async function runGatedShortConvGPU(inputTensor, layerState, options = {}) {
  const device = getDevice();
  if (!device) {
    throw new Error('No GPU device available for gated_short_conv.');
  }

  const recorder = options.recorder ?? null;
  const useRecorder = recorder
    && typeof recorder.getEncoder === 'function'
    && typeof recorder.trackTemporaryBuffer === 'function';

  requireGpuBuffer(inputTensor?.buffer, 'inputTensor');
  requireGpuBuffer(layerState?.convWeightGPU, 'convWeightGPU');
  requireGpuBuffer(layerState?.convStateGPU, 'convStateGPU');

  const numTokens = requireIntegerAtLeast(options.numTokens, 1, 'runGatedShortConvGPU requires numTokens > 0.');
  const hiddenSize = requireIntegerAtLeast(layerState.hiddenSize, 1, 'runGatedShortConvGPU requires hiddenSize > 0.');
  const kernelSize = requireIntegerAtLeast(layerState.kernelSize, 2, 'runGatedShortConvGPU requires kernelSize >= 2.');

  // The shader binds the input as array<f32>; any other dtype would be misread.
  if (inputTensor.dtype != null && inputTensor.dtype !== 'f32') {
    throw new Error(`runGatedShortConvGPU requires an f32 input tensor (got ${inputTensor.dtype}).`);
  }
  requireF32Capacity(inputTensor.buffer, numTokens * 3 * hiddenSize, 'inputTensor');
  requireF32Capacity(layerState.convWeightGPU, hiddenSize * kernelSize, 'convWeightGPU');
  requireF32Capacity(layerState.convStateGPU, hiddenSize * (kernelSize - 1), 'convStateGPU');

  ensurePipeline(device);

  const layerLabel = `L${options.layerIdx ?? 0}`;
  const workgroupCount = Math.ceil(hiddenSize / CONV_WORKGROUP_SIZE);
  const paramsData = buildParamsData(numTokens, hiddenSize, kernelSize);
  const outputSize = numTokens * hiddenSize * Float32Array.BYTES_PER_ELEMENT;
  const outputBuffer = acquireBuffer(outputSize, undefined, `${layerLabel}.gated_short_conv_out`);

  if (useRecorder) {
    try {
      // Created inside the try so a failure here still returns outputBuffer to the pool.
      const paramsBuffer = createUniformBufferFromData('gated_short_conv_params', paramsData, recorder);
      const bindGroup = createGatedShortConvBindGroup(device, paramsBuffer, inputTensor.buffer, layerState, outputBuffer);
      recordDispatch(recorder, pipeline, bindGroup, [workgroupCount, 1, 1], 'gated_short_conv');
      return createTensor(outputBuffer, 'f32', [numTokens, hiddenSize], `${layerLabel}.gated_short_conv`);
    } catch (error) {
      releaseBuffer(outputBuffer);
      throw error;
    }
  }

  // Non-recorder path: encode and submit a dedicated command buffer.
  let paramsBuffer = null;
  let submitted = false;
  try {
    paramsBuffer = createUniformBufferFromData('gated_short_conv_params', paramsData, null, device, { useCache: false });
    const bindGroup = createGatedShortConvBindGroup(device, paramsBuffer, inputTensor.buffer, layerState, outputBuffer);

    const encoder = device.createCommandEncoder({ label: 'gated_short_conv' });
    const pass = encoder.beginComputePass({ label: 'gated_short_conv_pass' });
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bindGroup);
    pass.dispatchWorkgroups(workgroupCount, 1, 1);
    pass.end();
    device.queue.submit([encoder.finish()]);
    submitted = true;

    return createTensor(outputBuffer, 'f32', [numTokens, hiddenSize], `${layerLabel}.gated_short_conv`);
  } catch (error) {
    releaseBuffer(outputBuffer);
    throw error;
  } finally {
    if (paramsBuffer && submitted) {
      // The queued dispatch still reads the uniforms; destroy them once the GPU is done.
      const ownedParams = paramsBuffer;
      const destroyParams = () => ownedParams.destroy();
      device.queue.onSubmittedWorkDone().then(destroyParams).catch(destroyParams);
    } else if (paramsBuffer) {
      paramsBuffer.destroy();
    }
  }
}