@simulatte/doppler 0.1.7 → 0.1.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. package/CHANGELOG.md +32 -0
  2. package/README.md +25 -6
  3. package/package.json +25 -38
  4. package/src/browser/browser-converter.js +5 -0
  5. package/src/client/doppler-api.browser.js +6 -0
  6. package/src/client/doppler-api.d.ts +3 -0
  7. package/src/client/doppler-api.js +11 -2
  8. package/src/client/doppler-registry.js +3 -5
  9. package/src/client/doppler-registry.json +2 -2
  10. package/src/config/kernel-path-loader.d.ts +5 -0
  11. package/src/config/kernel-path-loader.js +13 -0
  12. package/src/config/kernels/kernel-ref-digests.js +23 -21
  13. package/src/config/kernels/moe/mixtral.paths.json +46 -0
  14. package/src/config/kernels/registry.json +74 -0
  15. package/src/config/loader.js +9 -0
  16. package/src/config/merge-contract-check.js +7 -0
  17. package/src/config/platforms/loader.js +3 -1
  18. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-nosubgroups.json +16 -16
  19. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-online.json +8 -8
  20. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-small-attn.json +61 -0
  21. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
  22. package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
  23. package/src/config/presets/kernel-paths/registry.json +21 -0
  24. package/src/config/presets/models/gemma2.json +2 -1
  25. package/src/config/presets/models/gemma3.json +4 -1
  26. package/src/config/presets/models/gemma4.json +61 -0
  27. package/src/config/presets/models/granite-docling.json +70 -0
  28. package/src/config/presets/models/lfm2.json +6 -1
  29. package/src/config/presets/models/qwen3.json +4 -3
  30. package/src/config/presets/models/qwen3_5.json +16 -0
  31. package/src/config/presets/models/qwen3_vl.json +40 -0
  32. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +2 -1
  33. package/src/config/presets/runtime/experiments/verify/lfm2-verify.json +46 -0
  34. package/src/config/presets/runtime/experiments/verify/translategemma-verify.json +39 -0
  35. package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
  36. package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
  37. package/src/config/presets/runtime/modes/trace-layers.json +1 -0
  38. package/src/config/presets/runtime/tiers/gemma4-16gb.json +69 -0
  39. package/src/config/presets/runtime/tiers/gemma4-24gb.json +66 -0
  40. package/src/config/presets/runtime/tiers/gemma4-32gb.json +66 -0
  41. package/src/config/runtime.js +3 -0
  42. package/src/config/schema/conversion.schema.d.ts +1 -0
  43. package/src/config/schema/debug.schema.d.ts +40 -0
  44. package/src/config/schema/debug.schema.js +28 -0
  45. package/src/config/schema/index.js +2 -0
  46. package/src/config/schema/inference-defaults.schema.js +1 -1
  47. package/src/config/schema/kernel-path.schema.d.ts +1 -0
  48. package/src/config/schema/manifest.schema.d.ts +1 -1
  49. package/src/config/schema/manifest.schema.js +1 -1
  50. package/src/config/schema/memory-limits.schema.js +2 -2
  51. package/src/config/schema/storage.schema.js +2 -2
  52. package/src/converter/conversion-plan.js +11 -3
  53. package/src/converter/core.js +19 -8
  54. package/src/converter/manifest-inference.js +12 -22
  55. package/src/converter/parsers/transformer.js +4 -0
  56. package/src/converter/quantization-info.js +5 -1
  57. package/src/converter/quantizer.d.ts +5 -0
  58. package/src/converter/quantizer.js +34 -12
  59. package/src/converter/rope-config.js +8 -6
  60. package/src/converter/tokenizer-utils.d.ts +1 -0
  61. package/src/converter/tokenizer-utils.js +4 -1
  62. package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
  63. package/src/distribution/shard-delivery.js +40 -1
  64. package/src/formats/rdrr/classification.js +32 -0
  65. package/src/formats/rdrr/parsing.d.ts +4 -0
  66. package/src/formats/rdrr/parsing.js +14 -1
  67. package/src/gpu/kernel-runtime.js +4 -2
  68. package/src/gpu/kernels/attention.js +2 -1
  69. package/src/gpu/kernels/dequant_f16_out.wgsl +4 -2
  70. package/src/gpu/kernels/dequant_f16_out_vec4.wgsl +5 -2
  71. package/src/gpu/kernels/dequant_shared.wgsl +4 -2
  72. package/src/gpu/kernels/dequant_shared_vec4.wgsl +4 -2
  73. package/src/gpu/kernels/dequant_subgroup.wgsl +6 -2
  74. package/src/gpu/kernels/gated-short-conv.d.ts +63 -0
  75. package/src/gpu/kernels/gated-short-conv.js +284 -0
  76. package/src/gpu/kernels/index.d.ts +8 -0
  77. package/src/gpu/kernels/index.js +6 -0
  78. package/src/gpu/kernels/linear-attention-core.js +37 -17
  79. package/src/gpu/kernels/matmul-selection.js +48 -4
  80. package/src/gpu/kernels/matmul.d.ts +5 -0
  81. package/src/gpu/kernels/matmul.js +71 -2
  82. package/src/gpu/kernels/matmul_gemv_subgroup.wgsl +77 -79
  83. package/src/gpu/kernels/rmsnorm.js +9 -2
  84. package/src/gpu/kernels/sample.js +1 -3
  85. package/src/gpu/kernels/sample.wgsl +39 -9
  86. package/src/gpu/kernels/sample_f16.wgsl +38 -8
  87. package/src/gpu/kernels/shader-cache.js +9 -4
  88. package/src/gpu/kernels/split_qg.d.ts +50 -0
  89. package/src/gpu/kernels/split_qg.js +46 -0
  90. package/src/gpu/kernels/split_qg.wgsl +58 -0
  91. package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
  92. package/src/gpu/weight-buffer.d.ts +1 -1
  93. package/src/gpu/weight-buffer.js +1 -1
  94. package/src/inference/browser-harness.d.ts +2 -0
  95. package/src/inference/browser-harness.js +20 -1
  96. package/src/inference/kv-cache/base.js +3 -10
  97. package/src/inference/pipelines/diffusion/helpers.js +3 -0
  98. package/src/inference/pipelines/diffusion/pipeline.js +2 -1
  99. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +10 -3
  100. package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
  101. package/src/inference/pipelines/text/attention/output-projection.js +8 -0
  102. package/src/inference/pipelines/text/attention/projections.d.ts +13 -1
  103. package/src/inference/pipelines/text/attention/projections.js +54 -13
  104. package/src/inference/pipelines/text/attention/record.js +16 -6
  105. package/src/inference/pipelines/text/attention/run.js +59 -6
  106. package/src/inference/pipelines/text/config.d.ts +1 -0
  107. package/src/inference/pipelines/text/config.js +46 -4
  108. package/src/inference/pipelines/text/embed.js +26 -7
  109. package/src/inference/pipelines/text/execution-plan.js +5 -4
  110. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +10 -3
  111. package/src/inference/pipelines/text/execution-v0.js +12 -1
  112. package/src/inference/pipelines/text/generator-helpers.js +1 -0
  113. package/src/inference/pipelines/text/generator-runtime.js +19 -0
  114. package/src/inference/pipelines/text/generator-steps.d.ts +15 -0
  115. package/src/inference/pipelines/text/generator-steps.js +71 -26
  116. package/src/inference/pipelines/text/generator.d.ts +5 -0
  117. package/src/inference/pipelines/text/generator.js +353 -166
  118. package/src/inference/pipelines/text/init.d.ts +15 -0
  119. package/src/inference/pipelines/text/init.js +35 -10
  120. package/src/inference/pipelines/text/layer.js +38 -8
  121. package/src/inference/pipelines/text/linear-attention.d.ts +5 -0
  122. package/src/inference/pipelines/text/linear-attention.js +33 -3
  123. package/src/inference/pipelines/text/logits/gpu.js +2 -2
  124. package/src/inference/pipelines/text/logits/index.d.ts +6 -1
  125. package/src/inference/pipelines/text/logits/index.js +3 -1
  126. package/src/inference/pipelines/text/model-load.js +3 -0
  127. package/src/inference/pipelines/text/moe-gpu.js +21 -3
  128. package/src/inference/pipelines/text/moe-shape-validator.d.ts +9 -0
  129. package/src/inference/pipelines/text/moe-shape-validator.js +31 -11
  130. package/src/inference/pipelines/text/ops.js +123 -53
  131. package/src/inference/pipelines/text/probes.js +1 -0
  132. package/src/inference/pipelines/text/sampling.js +52 -6
  133. package/src/inference/pipelines/text/state.js +2 -0
  134. package/src/inference/pipelines/text.d.ts +5 -0
  135. package/src/inference/pipelines/text.js +59 -1
  136. package/src/inference/pipelines/vision/encoder.js +386 -0
  137. package/src/inference/pipelines/vision/image-preprocess.js +151 -0
  138. package/src/inference/pipelines/vision/index.js +173 -0
  139. package/src/inference/pipelines/vision/ops.js +78 -0
  140. package/src/inference/pipelines/vision/patch-embed.js +151 -0
  141. package/src/inference/test-harness.js +11 -9
  142. package/src/loader/doppler-loader.d.ts +3 -0
  143. package/src/loader/doppler-loader.js +20 -3
  144. package/src/loader/experts/expert-cache.js +6 -2
  145. package/src/loader/experts/expert-loader.js +6 -2
  146. package/src/loader/final-weights-loader.js +2 -0
  147. package/src/loader/layer-loader.js +42 -3
  148. package/src/loader/manifest-config.js +3 -1
  149. package/src/loader/shard-cache.js +3 -2
  150. package/src/loader/tensors/tensor-loader.d.ts +3 -0
  151. package/src/loader/tensors/tensor-loader.js +130 -4
  152. package/src/rules/inference/dtype.rules.json +5 -0
  153. package/src/rules/inference/kernel-path.rules.json +2 -2
  154. package/src/rules/kernels/moe.rules.mixtral.json +75 -0
  155. package/src/rules/kernels/softmax.rules.json +2 -0
  156. package/src/rules/kernels/split-qg.rules.json +6 -0
  157. package/src/rules/rule-registry.d.ts +1 -0
  158. package/src/rules/rule-registry.js +4 -0
  159. package/src/storage/downloader.js +2 -1
  160. package/src/storage/quickstart-downloader.d.ts +3 -0
  161. package/src/storage/quickstart-downloader.js +27 -30
  162. package/src/storage/shard-manager.js +4 -3
  163. package/src/tooling/conversion-config-materializer.js +3 -5
  164. package/src/tooling/node-converter.js +28 -7
  165. package/src/tooling/node-source-runtime.js +65 -5
  166. package/src/tooling/node-webgpu.js +24 -7
  167. package/src/types/model.d.ts +5 -0
  168. package/src/utils/hf-resolve-url.d.ts +16 -0
  169. package/src/utils/hf-resolve-url.js +17 -0
  170. package/src/version.js +1 -1
  171. package/tools/doppler-cli.js +6 -1
  172. package/src/tooling/node-convert.d.ts +0 -54
@@ -1,4 +1,5 @@
1
1
  import { log } from '../debug/index.js';
2
+ import { getExpectedShardHash } from '../formats/rdrr/index.js';
2
3
  import {
3
4
  computeHash,
4
5
  createStreamingHasher,
@@ -1316,6 +1317,25 @@ async function clearPersistedShardState(shardIndex) {
1316
1317
  await writer.abort?.();
1317
1318
  }
1318
1319
 
1320
/**
 * Recover from an HTTP 416 (Range Not Satisfiable) on a resumed shard download.
 *
 * Aborts the in-flight transfer, discards any persisted partial shard state
 * when the store is in use, and restarts the download from scratch with
 * resume disabled. The `__resumeRangeRecoveryAttempted` flag guarantees the
 * retry happens at most once per shard.
 *
 * @param baseUrl        Origin to download the shard from.
 * @param shardInfo      Shard metadata (size, hash, filename, ...).
 * @param shardIndex     Index of the shard being recovered.
 * @param options        Options of the failed download attempt.
 * @param transferState  Mutable transfer state of the aborted attempt.
 * @param writeToStore   Whether persisted shard state must be cleared first.
 * @returns The result of the restarted download.
 */
async function recoverHttpRejectedResumeRange(
  baseUrl,
  shardInfo,
  shardIndex,
  options,
  transferState,
  writeToStore
) {
  await abortHttpTransferState(transferState);

  if (writeToStore) {
    await clearPersistedShardState(shardIndex);
  }

  // Restart from byte 0: the server refused our Range request, so any
  // partial data we had is unusable. Mark the retry so we never loop.
  const retryOptions = {
    ...options,
    __disablePersistedResume: true,
    __resumeRangeRecoveryAttempted: true,
  };
  return downloadShardFromHttp(baseUrl, shardInfo, shardIndex, retryOptions);
}
1338
+
1319
1339
  async function downloadShardFromHttp(baseUrl, shardInfo, shardIndex, options = {}) {
1320
1340
  const {
1321
1341
  signal,
@@ -1528,6 +1548,21 @@ async function downloadShardFromHttp(baseUrl, shardInfo, shardIndex, options = {
1528
1548
  throw error;
1529
1549
  }
1530
1550
 
1551
+ if (
1552
+ error?.status === 416
1553
+ && transferState.receivedBytes > 0
1554
+ && options.__resumeRangeRecoveryAttempted !== true
1555
+ ) {
1556
+ return recoverHttpRejectedResumeRange(
1557
+ baseUrl,
1558
+ shardInfo,
1559
+ shardIndex,
1560
+ options,
1561
+ transferState,
1562
+ writeToStore
1563
+ );
1564
+ }
1565
+
1531
1566
  if (Number.isInteger(error?.status) && error.status >= 400 && error.status < 500 && error.status !== 429) {
1532
1567
  await abortHttpTransferState(transferState);
1533
1568
  throw error;
@@ -2018,7 +2053,11 @@ export async function downloadShard(
2018
2053
  onDeliveryMetrics,
2019
2054
  signal,
2020
2055
  requiredEncoding: requiredEncoding ?? activeConfig.requiredContentEncoding ?? null,
2021
- expectedHash: options.expectedHash ?? shardInfo?.hash ?? activeConfig.expectedHash ?? null,
2056
+ expectedHash:
2057
+ options.expectedHash
2058
+ ?? getExpectedShardHash(shardInfo, algorithm)
2059
+ ?? activeConfig.expectedHash
2060
+ ?? null,
2022
2061
  expectedSize: expectedSize ?? shardInfo?.size ?? null,
2023
2062
  expectedManifestVersionSet: options.expectedManifestVersionSet ?? null,
2024
2063
  writeToStore,
@@ -32,6 +32,12 @@ export function classifyTensor(name, modelType) {
32
32
  return 'head';
33
33
  }
34
34
 
35
+ // Multimodal groups
36
+ const role = classifyTensorRole(name);
37
+ if (role === 'vision') return 'vision';
38
+ if (role === 'projector') return 'projector';
39
+ if (role === 'audio') return 'audio';
40
+
35
41
  // Extract layer index
36
42
  const layerMatch = name.match(/layers?[._](\d+)/i);
37
43
  if (!layerMatch) {
@@ -96,6 +102,29 @@ export function classifyTensorRole(name) {
96
102
  if (lower.includes('lm_head')) return 'lm_head';
97
103
  if (lower.endsWith('output.weight') && !lower.includes('attn_')) return 'lm_head';
98
104
 
105
+ // Multimodal: vision encoder tensors
106
+ if (lower.startsWith('vision_tower.') || lower.startsWith('vision_model.')
107
+ || lower.startsWith('visual.') || lower.startsWith('model.visual.')
108
+ || lower.startsWith('vision.') || lower.startsWith('model.vision.')
109
+ || lower.startsWith('vision_encoder.') || lower.startsWith('image_encoder.')
110
+ || lower.startsWith('image_tower.') || lower.startsWith('image.')
111
+ || lower.startsWith('model.image.')) {
112
+ return 'vision';
113
+ }
114
+
115
+ // Multimodal: audio encoder tensors
116
+ if (lower.startsWith('audio_tower.') || lower.startsWith('audio_model.')
117
+ || lower.startsWith('audio.') || lower.startsWith('model.audio.')
118
+ || lower.startsWith('audio_encoder.')) {
119
+ return 'audio';
120
+ }
121
+
122
+ // Multimodal: projector tensors
123
+ if (lower.startsWith('multi_modal_projector.') || lower.startsWith('model.multi_modal_projector.')
124
+ || lower.startsWith('mm_projector.') || lower.startsWith('model.mm_projector.')) {
125
+ return 'projector';
126
+ }
127
+
99
128
  if (lower.includes('shared_expert') || /experts?[._]/.test(lower)) {
100
129
  return 'expert';
101
130
  }
@@ -207,6 +236,9 @@ export function getGroupType(groupId, modelType) {
207
236
  }
208
237
  if (groupId === 'embed') return 'embed';
209
238
  if (groupId === 'head') return 'head';
239
+ if (groupId === 'vision') return 'vision';
240
+ if (groupId === 'projector') return 'projector';
241
+ if (groupId === 'audio') return 'audio';
210
242
  if (groupId === 'other') return 'layer';
211
243
 
212
244
  if (groupId.includes('.expert.')) return 'expert';
@@ -7,6 +7,10 @@
7
7
  import type { RDRRManifest, ShardInfo, TensorMap } from './types.js';
8
8
 
9
9
  export declare function parseManifest(jsonString: string): RDRRManifest;
10
+ export declare function getExpectedShardHash(
11
+ shard: Partial<ShardInfo> | Record<string, unknown> | null | undefined,
12
+ manifestHashAlgorithm?: string | null
13
+ ): string;
10
14
 
11
15
  export declare function parseTensorMap(jsonString: string): TensorMap;
12
16
 
@@ -4,6 +4,19 @@ import { validateManifest } from './validation.js';
4
4
 
5
5
  let currentManifest = null;
6
6
 
7
/**
 * Resolve the expected integrity hash for a shard entry.
 *
 * Shard records may carry both a generic `hash` field and an explicit
 * `blake3` field. When the manifest declares `blake3` as its hash algorithm
 * the `blake3` value is preferred; otherwise `hash` wins. Either field falls
 * back to the other, and the empty string is returned when neither is
 * present or the input is not a plain object.
 *
 * @param shard Shard record (or null/undefined).
 * @param manifestHashAlgorithm Manifest-level algorithm name, e.g. 'blake3'.
 * @returns The expected hash string, or '' when unavailable.
 */
export function getExpectedShardHash(shard, manifestHashAlgorithm = null) {
  const isPlainObject = shard !== null
    && typeof shard === 'object'
    && !Array.isArray(shard);
  if (!isPlainObject) {
    return '';
  }

  // Normalize the algorithm name so ' BLAKE3 ' and 'blake3' compare equal.
  let algorithm = '';
  if (typeof manifestHashAlgorithm === 'string') {
    algorithm = manifestHashAlgorithm.trim().toLowerCase();
  }

  const preferBlake3 = algorithm === 'blake3';
  const primary = preferBlake3 ? shard.blake3 : shard.hash;
  const fallback = preferBlake3 ? shard.hash : shard.blake3;
  return primary || fallback || '';
}
19
+
7
20
  export function parseManifest(jsonString) {
8
21
  let manifest;
9
22
 
@@ -21,7 +34,7 @@ export function parseManifest(jsonString) {
21
34
  index: shard.index ?? i,
22
35
  filename: shard.filename || shard.fileName || '',
23
36
  size: shard.size,
24
- hash: shard.hash || shard.blake3 || '',
37
+ hash: getExpectedShardHash(shard, manifest.hashAlgorithm),
25
38
  blake3: shard.blake3 || shard.hash,
26
39
  offset: shard.offset ?? offset,
27
40
  hashAlgorithm: shard.hashAlgorithm,
@@ -2,13 +2,15 @@
2
2
 
3
3
  import { autoTuneKernels, prewarmKernels, clearKernelCaches } from './kernels/utils.js';
4
4
  import { getRuntimeConfig } from '../config/runtime.js';
5
- import { DEFAULT_KERNEL_WARMUP_CONFIG } from '../config/schema/kernel-warmup.schema.js';
6
5
 
7
6
 
8
7
  export async function prepareKernelRuntime(
9
8
  options = {}
10
9
  ) {
11
- const kernelWarmup = getRuntimeConfig().shared?.kernelWarmup ?? DEFAULT_KERNEL_WARMUP_CONFIG;
10
+ const kernelWarmup = getRuntimeConfig().shared?.kernelWarmup;
11
+ if (!kernelWarmup) {
12
+ throw new Error('runtime.shared.kernelWarmup is required but missing from resolved config');
13
+ }
12
14
  const {
13
15
  prewarm = kernelWarmup.prewarm,
14
16
  prewarmMode = kernelWarmup.prewarmMode,
@@ -513,9 +513,10 @@ function resolveAttentionPlan(
513
513
  useF16KV,
514
514
  useF16Q,
515
515
  numHeads,
516
+ headDim,
516
517
  kvLen,
518
+ isPaged,
517
519
  caps,
518
- headDim,
519
520
  sharedLimit
520
521
  );
521
522
  const workgroups = calculateAttentionWorkgroups(adaptiveSelection.tier, seqLen, numHeads);
@@ -113,9 +113,11 @@ fn get_q4(qs: array<u32, 32>, idx: u32) -> u32 {
113
113
  @compute @workgroup_size(WORKGROUP_SIZE_MAIN, 1, 1)
114
114
  fn main(
115
115
  @builtin(local_invocation_id) local_id: vec3<u32>,
116
- @builtin(workgroup_id) workgroup_id: vec3<u32>
116
+ @builtin(workgroup_id) workgroup_id: vec3<u32>,
117
+ @builtin(num_workgroups) num_wg: vec3<u32>
117
118
  ) {
118
- let block_idx = workgroup_id.x;
119
+ // Support 2D dispatch for tensors with >65535 blocks.
120
+ let block_idx = workgroup_id.x + workgroup_id.y * num_wg.x;
119
121
  let elem_idx = local_id.x;
120
122
 
121
123
  if (block_idx >= u.num_blocks) {
@@ -106,9 +106,12 @@ fn get_q4(qs: array<u32, 32>, idx: u32) -> u32 {
106
106
  @compute @workgroup_size(WORKGROUP_SIZE_VEC4, 1, 1)
107
107
  fn main_vec4(
108
108
  @builtin(local_invocation_id) local_id: vec3<u32>,
109
- @builtin(workgroup_id) workgroup_id: vec3<u32>
109
+ @builtin(workgroup_id) workgroup_id: vec3<u32>,
110
+ @builtin(num_workgroups) num_wg: vec3<u32>
110
111
  ) {
111
- let block_idx = workgroup_id.x;
112
+ // Support 2D dispatch for tensors with >65535 blocks (e.g. large FFN weights).
113
+ // block_idx = flat workgroup index across both X and Y dimensions.
114
+ let block_idx = workgroup_id.x + workgroup_id.y * num_wg.x;
112
115
  let thread_idx = local_id.x;
113
116
 
114
117
  if (block_idx >= u.num_blocks) {
@@ -115,9 +115,11 @@ fn get_q4(qs: array<u32, 32>, idx: u32) -> u32 {
115
115
  fn main(
116
116
  @builtin(global_invocation_id) global_id: vec3<u32>,
117
117
  @builtin(local_invocation_id) local_id: vec3<u32>,
118
- @builtin(workgroup_id) workgroup_id: vec3<u32>
118
+ @builtin(workgroup_id) workgroup_id: vec3<u32>,
119
+ @builtin(num_workgroups) num_wg: vec3<u32>
119
120
  ) {
120
- let block_idx = workgroup_id.x;
121
+ // Support 2D dispatch for tensors with >65535 blocks.
122
+ let block_idx = workgroup_id.x + workgroup_id.y * num_wg.x;
121
123
  let elem_idx = local_id.x;
122
124
 
123
125
  if (block_idx >= u.num_blocks) {
@@ -108,9 +108,11 @@ fn get_q4(qs: array<u32, 32>, idx: u32) -> u32 {
108
108
  @compute @workgroup_size(WORKGROUP_SIZE_VEC4, 1, 1)
109
109
  fn main_vec4(
110
110
  @builtin(local_invocation_id) local_id: vec3<u32>,
111
- @builtin(workgroup_id) workgroup_id: vec3<u32>
111
+ @builtin(workgroup_id) workgroup_id: vec3<u32>,
112
+ @builtin(num_workgroups) num_wg: vec3<u32>
112
113
  ) {
113
- let block_idx = workgroup_id.x;
114
+ // Support 2D dispatch for tensors with >65535 blocks.
115
+ let block_idx = workgroup_id.x + workgroup_id.y * num_wg.x;
114
116
  let thread_idx = local_id.x;
115
117
 
116
118
  if (block_idx >= u.num_blocks) {
@@ -118,11 +118,15 @@ fn get_q4(qs: array<u32, 32>, idx: u32) -> u32 {
118
118
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
119
119
  fn main(
120
120
  @builtin(global_invocation_id) global_id: vec3<u32>,
121
+ @builtin(num_workgroups) num_wg: vec3<u32>,
121
122
  @builtin(subgroup_invocation_id) sg_id: u32,
122
123
  @builtin(subgroup_size) sg_size: u32
123
124
  ) {
124
- let block_idx = global_id.x / QK_K;
125
- let elem_idx = global_id.x % QK_K;
125
+ // Support 2D dispatch for tensors with >65535 workgroups.
126
+ // Compute flat global thread id across both X and Y dimensions.
127
+ let flat_global_x = global_id.x + global_id.y * num_wg.x * WORKGROUP_SIZE;
128
+ let block_idx = flat_global_x / QK_K;
129
+ let elem_idx = flat_global_x % QK_K;
126
130
 
127
131
  // Use block 0 for out-of-bounds threads to maintain uniform control flow
128
132
  // (required for subgroup operations)
@@ -0,0 +1,63 @@
1
+ /**
2
+ * LFM2 gated short convolution kernel.
3
+ *
4
+ * Fuses B*x pre-gating, depthwise causal conv1d, and C*conv_out post-gating
5
+ * into a single GPU dispatch. Each thread handles one channel across all tokens
6
+ * sequentially, maintaining persistent conv state for autoregressive decode.
7
+ */
8
+
9
+ /** Per-layer state maintained between calls. */
10
+ export interface GatedShortConvLayerState {
11
+ /** Pre-dequantized conv1d weights as GPUBuffer, shape [hiddenSize, kernelSize]. */
12
+ convWeightGPU: GPUBuffer;
13
+
14
+ /** Persistent conv state as GPUBuffer, shape [hiddenSize, kernelSize - 1]. */
15
+ convStateGPU: GPUBuffer;
16
+
17
+ /** Number of channels (hidden dimension). */
18
+ hiddenSize: number;
19
+
20
+ /** Conv1d kernel width (e.g., 4). */
21
+ kernelSize: number;
22
+ }
23
+
24
+ /** Tensor returned by the kernel. */
25
+ export interface Tensor {
26
+ buffer: GPUBuffer;
27
+ dtype: string;
28
+ shape: readonly number[];
29
+ label: string;
30
+ }
31
+
32
+ /** Options for runGatedShortConvGPU. */
33
+ export interface GatedShortConvOptions {
34
+ /** Number of tokens in this batch. Required. */
35
+ numTokens?: number;
36
+
37
+ /** Layer index for labeling/tracing. */
38
+ layerIdx?: number;
39
+
40
+ /** Command recorder for batched submission. */
41
+ recorder?: {
42
+ getEncoder(): GPUCommandEncoder;
43
+ trackTemporaryBuffer(buffer: GPUBuffer): void;
44
+ beginComputePass(label: string): GPUComputePassEncoder;
45
+ createUniformBuffer(data: ArrayBuffer, label: string): GPUBuffer;
46
+ device: GPUDevice;
47
+ } | null;
48
+ }
49
+
50
+ /**
51
+ * Run the LFM2 gated short convolution on GPU.
52
+ *
53
+ * @param inputTensor Tensor with shape [numTokens, 3 * hiddenSize] containing
54
+ * concatenated B, C, x from in_proj matmul output.
55
+ * @param layerState Persistent per-layer state (conv weights + conv state buffer).
56
+ * @param options Dispatch options.
57
+ * @returns Output tensor with shape [numTokens, hiddenSize].
58
+ */
59
+ export function runGatedShortConvGPU(
60
+ inputTensor: Tensor,
61
+ layerState: GatedShortConvLayerState,
62
+ options?: GatedShortConvOptions
63
+ ): Promise<Tensor>;
@@ -0,0 +1,284 @@
1
+ import { getDevice, getDeviceEpoch } from '../device.js';
2
+ import { WORKGROUP_SIZES } from './constants.js';
3
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
4
+ import { createTensor } from '../tensor.js';
5
+ import {
6
+ createUniformBufferFromData,
7
+ getOrCreateBindGroupLayout,
8
+ getOrCreatePipelineLayout,
9
+ } from './utils.js';
10
+ import { recordDispatch } from './dispatch.js';
11
+
12
+ const CONV_WORKGROUP_SIZE = WORKGROUP_SIZES.DEFAULT;
13
+
14
+ const SHADER = /* wgsl */ `
15
+ override WORKGROUP_SIZE: u32 = 256u;
16
+
17
+ struct Params {
18
+ num_tokens: u32,
19
+ hidden_size: u32,
20
+ kernel_size: u32,
21
+ _pad: u32,
22
+ }
23
+
24
+ @group(0) @binding(0) var<uniform> params: Params;
25
+ @group(0) @binding(1) var<storage, read> input: array<f32>;
26
+ @group(0) @binding(2) var<storage, read> conv_weight: array<f32>;
27
+ @group(0) @binding(3) var<storage, read_write> conv_state: array<f32>;
28
+ @group(0) @binding(4) var<storage, read_write> output: array<f32>;
29
+
30
+ @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
31
+ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
32
+ let channel = gid.x;
33
+ if (channel >= params.hidden_size) {
34
+ return;
35
+ }
36
+
37
+ let hidden_size = params.hidden_size;
38
+ let kernel_size = params.kernel_size;
39
+ let state_width = kernel_size - 1u;
40
+ let row_stride = 3u * hidden_size;
41
+ let state_base = channel * state_width;
42
+ let weight_base = channel * kernel_size;
43
+
44
+ for (var t: u32 = 0u; t < params.num_tokens; t = t + 1u) {
45
+ let row_offset = t * row_stride;
46
+
47
+ let b_val = input[row_offset + channel];
48
+ let c_val = input[row_offset + hidden_size + channel];
49
+ let x_val = input[row_offset + 2u * hidden_size + channel];
50
+
51
+ let bx = b_val * x_val;
52
+
53
+ var conv_sum: f32 = 0.0;
54
+ for (var k: u32 = 0u; k < state_width; k = k + 1u) {
55
+ conv_sum = conv_sum + conv_state[state_base + k] * conv_weight[weight_base + k];
56
+ }
57
+ conv_sum = conv_sum + bx * conv_weight[weight_base + state_width];
58
+
59
+ for (var k: u32 = 0u; k + 1u < state_width; k = k + 1u) {
60
+ conv_state[state_base + k] = conv_state[state_base + k + 1u];
61
+ }
62
+ if (state_width > 0u) {
63
+ conv_state[state_base + state_width - 1u] = bx;
64
+ }
65
+
66
+ output[t * hidden_size + channel] = c_val * conv_sum;
67
+ }
68
+ }
69
+ `;
70
+
71
+ // ======================================================================
72
+ // UNIFORM BUFFER
73
+ // ======================================================================
74
+
75
+ const UNIFORM_LAYOUT = {
76
+ numTokens: { offset: 0, size: 4 },
77
+ hiddenSize: { offset: 4, size: 4 },
78
+ kernelSize: { offset: 8, size: 4 },
79
+ _pad: { offset: 12, size: 4 },
80
+ };
81
+
82
+ const UNIFORM_SIZE = 16;
83
+
84
/**
 * Pack the shader's Params uniform into a little-endian ArrayBuffer.
 *
 * Layout matches the WGSL `Params` struct: num_tokens, hidden_size,
 * kernel_size, plus one u32 of padding to reach 16 bytes.
 *
 * @param numTokens  Tokens processed in this dispatch.
 * @param hiddenSize Channel count.
 * @param kernelSize Conv1d kernel width.
 * @returns A UNIFORM_SIZE-byte buffer ready for upload.
 */
function buildParamsData(numTokens, hiddenSize, kernelSize) {
  const buffer = new ArrayBuffer(UNIFORM_SIZE);
  const writer = new DataView(buffer);
  const fields = [
    [UNIFORM_LAYOUT.numTokens.offset, numTokens],
    [UNIFORM_LAYOUT.hiddenSize.offset, hiddenSize],
    [UNIFORM_LAYOUT.kernelSize.offset, kernelSize],
    [UNIFORM_LAYOUT._pad.offset, 0],
  ];
  for (const [offset, value] of fields) {
    writer.setUint32(offset, value, true); // true = little-endian, per WebGPU convention
  }
  return buffer;
}
93
+
94
+ // ======================================================================
95
+ // PIPELINE CACHE
96
+ // ======================================================================
97
+
98
+ let cachedEpoch = -1;
99
+ let pipeline = null;
100
+ let bindGroupLayout = null;
101
+
102
/**
 * Build the compute pipeline and bind group layout for the current device.
 *
 * Stores results in the module-level `pipeline` / `bindGroupLayout` cache;
 * `ensurePipeline` decides when this needs to run again (device epoch change).
 *
 * @param device Active GPUDevice.
 */
function createPipeline(device) {
  // Binding layout mirrors the WGSL declarations: uniform params, read-only
  // input + conv weights, read-write conv state + output.
  const layoutEntries = [
    { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } },
    { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
    { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
    { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
    { binding: 4, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
  ];
  bindGroupLayout = getOrCreateBindGroupLayout('gated_short_conv_layout', layoutEntries, device);

  const shaderModule = device.createShaderModule({
    label: 'gated_short_conv',
    code: SHADER,
  });

  pipeline = device.createComputePipeline({
    label: 'gated_short_conv_pipeline',
    layout: getOrCreatePipelineLayout('gated_short_conv_pipeline_layout', [bindGroupLayout], device),
    compute: {
      module: shaderModule,
      entryPoint: 'main',
      // Override the WGSL `WORKGROUP_SIZE` constant at pipeline creation.
      constants: {
        WORKGROUP_SIZE: CONV_WORKGROUP_SIZE,
      },
    },
  });
}
132
+
133
/**
 * Lazily (re)create the pipeline, rebuilding whenever the device epoch
 * changes (e.g. after device loss) or nothing has been built yet.
 *
 * @param device Active GPUDevice.
 */
function ensurePipeline(device) {
  const epoch = getDeviceEpoch();
  if (pipeline && epoch === cachedEpoch) {
    return; // cache still valid for this device generation
  }
  createPipeline(device);
  cachedEpoch = epoch;
}
140
+
141
+ // ======================================================================
142
+ // VALIDATION
143
+ // ======================================================================
144
+
145
/**
 * Assert that `buffer` is a real GPUBuffer, naming the offending argument.
 *
 * @param buffer Candidate value.
 * @param label  Argument name used in the error message.
 * @throws Error when `buffer` is not a GPUBuffer instance.
 */
function requireGpuBuffer(buffer, label) {
  if (buffer instanceof GPUBuffer) {
    return;
  }
  throw new Error(`gated_short_conv kernel requires GPUBuffer for ${label}.`);
}
150
+
151
+ // ======================================================================
152
+ // DISPATCH
153
+ // ======================================================================
154
+
155
/**
 * Run the LFM2 gated short convolution on GPU.
 *
 * Validates inputs, ensures the pipeline exists, acquires an output buffer
 * from the pool, then dispatches either through the supplied recorder
 * (batched submission) or via an immediately-submitted command encoder.
 *
 * @param inputTensor Tensor with shape [numTokens, 3 * hiddenSize] holding
 *   the concatenated B, C, x rows from the in_proj matmul output.
 * @param layerState Per-layer conv weights + persistent conv state buffers.
 * @param options Dispatch options: numTokens (required), layerIdx, recorder.
 * @returns Output tensor with shape [numTokens, hiddenSize].
 * @throws Error when no device is available or validation fails.
 */
export async function runGatedShortConvGPU(inputTensor, layerState, options = {}) {
  const device = getDevice();
  if (!device) {
    throw new Error('No GPU device available for gated_short_conv.');
  }

  const recorder = options.recorder ?? null;
  // A usable recorder must expose both encoder access and temp-buffer tracking.
  const hasRecorder = recorder
    && typeof recorder.getEncoder === 'function'
    && typeof recorder.trackTemporaryBuffer === 'function';

  requireGpuBuffer(inputTensor?.buffer, 'inputTensor');
  requireGpuBuffer(layerState?.convWeightGPU, 'convWeightGPU');
  requireGpuBuffer(layerState?.convStateGPU, 'convStateGPU');

  const numTokens = Number(options.numTokens ?? 0);
  if (!Number.isFinite(numTokens) || numTokens <= 0) {
    throw new Error('runGatedShortConvGPU requires numTokens > 0.');
  }

  const hiddenSize = Number(layerState.hiddenSize ?? 0);
  if (!Number.isFinite(hiddenSize) || hiddenSize <= 0) {
    throw new Error('runGatedShortConvGPU requires hiddenSize > 0.');
  }

  const kernelSize = Number(layerState.kernelSize ?? 0);
  if (!Number.isFinite(kernelSize) || kernelSize < 2) {
    throw new Error('runGatedShortConvGPU requires kernelSize >= 2.');
  }

  ensurePipeline(device);

  const tensorLabel = `L${options.layerIdx ?? 0}.gated_short_conv`;
  const outputBytes = numTokens * hiddenSize * Float32Array.BYTES_PER_ELEMENT;
  const outBuffer = acquireBuffer(outputBytes, undefined, `${tensorLabel}_out`);

  // One thread per channel; tokens are walked sequentially inside the shader.
  const workgroups = [Math.ceil(hiddenSize / CONV_WORKGROUP_SIZE), 1, 1];

  // Shared between both dispatch paths: same bindings, same output tensor.
  const makeBindGroup = (paramsBuffer) => device.createBindGroup({
    label: 'gated_short_conv_bind_group',
    layout: bindGroupLayout,
    entries: [
      { binding: 0, resource: { buffer: paramsBuffer } },
      { binding: 1, resource: { buffer: inputTensor.buffer } },
      { binding: 2, resource: { buffer: layerState.convWeightGPU } },
      { binding: 3, resource: { buffer: layerState.convStateGPU } },
      { binding: 4, resource: { buffer: outBuffer } },
    ],
  });
  const makeOutputTensor = () =>
    createTensor(outBuffer, 'f32', [numTokens, hiddenSize], tensorLabel);

  if (hasRecorder) {
    // Recorder path: uniform buffer lifetime is owned by the recorder.
    const paramsBuffer = createUniformBufferFromData(
      'gated_short_conv_params',
      buildParamsData(numTokens, hiddenSize, kernelSize),
      recorder
    );

    try {
      recordDispatch(
        recorder,
        pipeline,
        makeBindGroup(paramsBuffer),
        workgroups,
        'gated_short_conv'
      );
      return makeOutputTensor();
    } catch (error) {
      releaseBuffer(outBuffer); // return pooled buffer on failure
      throw error;
    }
  }

  // Standalone path: encode + submit now; destroy uniforms after GPU completion.
  const paramsBuffer = createUniformBufferFromData(
    'gated_short_conv_params',
    buildParamsData(numTokens, hiddenSize, kernelSize),
    null,
    device,
    { useCache: false }
  );
  let submitted = false;

  try {
    const bindGroup = makeBindGroup(paramsBuffer);

    const encoder = device.createCommandEncoder({ label: 'gated_short_conv' });
    const pass = encoder.beginComputePass({ label: 'gated_short_conv_pass' });
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bindGroup);
    pass.dispatchWorkgroups(workgroups[0], workgroups[1], workgroups[2]);
    pass.end();
    device.queue.submit([encoder.finish()]);
    submitted = true;

    return makeOutputTensor();
  } catch (error) {
    releaseBuffer(outBuffer); // return pooled buffer on failure
    throw error;
  } finally {
    if (submitted) {
      // Defer destroying the uniform buffer until the submitted work is done.
      device.queue.onSubmittedWorkDone()
        .then(() => {
          paramsBuffer.destroy();
        })
        .catch(() => {
          paramsBuffer.destroy();
        });
    } else {
      paramsBuffer.destroy();
    }
  }
}
@@ -326,6 +326,14 @@ export {
326
326
  type SplitQKVResult,
327
327
  } from './split_qkv.js';
328
328
 
329
+ // Split Q and Gate (de-interleave attentionOutputGate q_proj output)
330
+ export {
331
+ runSplitQG,
332
+ recordSplitQG,
333
+ type SplitQGOptions,
334
+ type SplitQGResult,
335
+ } from './split_qg.js';
336
+
329
337
  // Transpose
330
338
  export {
331
339
  runTranspose,
@@ -268,6 +268,12 @@ export {
268
268
  recordSplitQKV,
269
269
  } from './split_qkv.js';
270
270
 
271
+ // Split Q and Gate (de-interleave attentionOutputGate q_proj output)
272
+ export {
273
+ runSplitQG,
274
+ recordSplitQG,
275
+ } from './split_qg.js';
276
+
271
277
  // Transpose
272
278
  export {
273
279
  runTranspose,