@simulatte/doppler 0.1.5 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. package/README.md +23 -8
  2. package/package.json +7 -4
  3. package/src/config/kernels/kernel-ref-digests.js +39 -39
  4. package/src/config/kernels/registry.json +42 -2
  5. package/src/config/loader.js +31 -2
  6. package/src/config/merge.js +18 -0
  7. package/src/config/presets/models/qwen3.json +9 -2
  8. package/src/config/presets/models/transformer.json +5 -0
  9. package/src/config/required-inference-fields-contract-check.js +6 -0
  10. package/src/config/schema/inference-defaults.schema.js +3 -0
  11. package/src/config/schema/inference.schema.d.ts +9 -0
  12. package/src/config/schema/kernel-path.schema.d.ts +6 -0
  13. package/src/config/schema/manifest.schema.d.ts +6 -0
  14. package/src/config/schema/manifest.schema.js +3 -0
  15. package/src/converter/rope-config.js +42 -0
  16. package/src/gpu/device.js +58 -0
  17. package/src/gpu/kernels/attention.js +98 -0
  18. package/src/gpu/kernels/bias_add.wgsl +8 -6
  19. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  20. package/src/gpu/kernels/conv2d.js +1 -1
  21. package/src/gpu/kernels/conv2d.wgsl +7 -8
  22. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  23. package/src/gpu/kernels/depthwise_conv2d.js +2 -1
  24. package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
  25. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
  26. package/src/gpu/kernels/grouped_pointwise_conv2d.js +2 -1
  27. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
  28. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
  29. package/src/gpu/kernels/matmul.js +25 -0
  30. package/src/gpu/kernels/pixel_shuffle.js +1 -1
  31. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  32. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  33. package/src/gpu/kernels/relu.js +15 -2
  34. package/src/gpu/kernels/relu.wgsl +2 -1
  35. package/src/gpu/kernels/relu_f16.wgsl +2 -1
  36. package/src/gpu/kernels/repeat_channels.js +1 -1
  37. package/src/gpu/kernels/repeat_channels.wgsl +4 -5
  38. package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
  39. package/src/gpu/kernels/residual.js +44 -8
  40. package/src/gpu/kernels/residual.wgsl +6 -3
  41. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  42. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  43. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  44. package/src/gpu/kernels/rmsnorm.js +58 -6
  45. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  46. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  47. package/src/gpu/kernels/rope.d.ts +2 -0
  48. package/src/gpu/kernels/rope.js +11 -1
  49. package/src/gpu/kernels/rope.wgsl +56 -40
  50. package/src/gpu/kernels/sana_linear_attention.js +1 -2
  51. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
  52. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
  53. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
  54. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
  55. package/src/gpu/kernels/silu.d.ts +1 -0
  56. package/src/gpu/kernels/silu.js +32 -14
  57. package/src/gpu/kernels/silu.wgsl +19 -9
  58. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  59. package/src/gpu/kernels/transpose.js +15 -2
  60. package/src/gpu/kernels/transpose.wgsl +5 -6
  61. package/src/gpu/kernels/upsample2d.js +2 -1
  62. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  63. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  64. package/src/gpu/kernels/utils.js +16 -1
  65. package/src/inference/browser-harness.js +47 -1
  66. package/src/inference/pipelines/diffusion/pipeline.js +15 -6
  67. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
  68. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
  69. package/src/inference/pipelines/text/attention/record.js +11 -2
  70. package/src/inference/pipelines/text/attention/run.js +11 -2
  71. package/src/inference/pipelines/text/chat-format.js +25 -1
  72. package/src/inference/pipelines/text/config.d.ts +4 -0
  73. package/src/inference/pipelines/text/config.js +68 -1
  74. package/src/inference/pipelines/text/execution-plan.js +23 -31
  75. package/src/inference/pipelines/text/execution-v0.js +29 -2
  76. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  77. package/src/inference/pipelines/text/init.d.ts +4 -0
  78. package/src/inference/pipelines/text/init.js +56 -9
  79. package/src/inference/pipelines/text/layer.js +11 -0
  80. package/src/inference/pipelines/text.js +4 -0
  81. package/src/inference/tokenizers/bundled.js +156 -33
  82. package/src/rules/tooling/command-runtime.rules.json +18 -0
  83. package/src/tooling/command-api.d.ts +27 -1
  84. package/src/tooling/command-api.js +142 -3
  85. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  86. package/src/tooling/node-browser-command-runner.js +58 -3
  87. package/src/tooling/node-command-runner.js +15 -0
  88. package/src/tooling/node-webgpu.js +9 -87
  89. package/src/training/checkpoint-watch.d.ts +7 -0
  90. package/src/training/checkpoint-watch.js +106 -0
  91. package/src/training/checkpoint.d.ts +6 -1
  92. package/src/training/checkpoint.js +12 -2
  93. package/src/training/distillation/artifacts.d.ts +71 -0
  94. package/src/training/distillation/artifacts.js +132 -0
  95. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  96. package/src/training/distillation/checkpoint-watch.js +57 -0
  97. package/src/training/distillation/dataset.d.ts +59 -0
  98. package/src/training/distillation/dataset.js +337 -0
  99. package/src/training/distillation/eval.d.ts +34 -0
  100. package/src/training/distillation/eval.js +310 -0
  101. package/src/training/distillation/index.d.ts +29 -0
  102. package/src/training/distillation/index.js +29 -0
  103. package/src/training/distillation/runtime.d.ts +20 -0
  104. package/src/training/distillation/runtime.js +121 -0
  105. package/src/training/distillation/scoreboard.d.ts +6 -0
  106. package/src/training/distillation/scoreboard.js +8 -0
  107. package/src/training/distillation/stage-a.d.ts +45 -0
  108. package/src/training/distillation/stage-a.js +338 -0
  109. package/src/training/distillation/stage-b.d.ts +24 -0
  110. package/src/training/distillation/stage-b.js +20 -0
  111. package/src/training/index.d.ts +10 -0
  112. package/src/training/index.js +10 -0
  113. package/src/training/lora-pipeline.d.ts +40 -0
  114. package/src/training/lora-pipeline.js +796 -0
  115. package/src/training/operator-artifacts.d.ts +62 -0
  116. package/src/training/operator-artifacts.js +140 -0
  117. package/src/training/operator-command.d.ts +5 -0
  118. package/src/training/operator-command.js +453 -0
  119. package/src/training/operator-eval.d.ts +48 -0
  120. package/src/training/operator-eval.js +230 -0
  121. package/src/training/operator-scoreboard.d.ts +5 -0
  122. package/src/training/operator-scoreboard.js +44 -0
  123. package/src/training/runner.d.ts +52 -0
  124. package/src/training/runner.js +29 -4
  125. package/src/training/suite.d.ts +112 -0
  126. package/src/training/suite.js +9 -9
  127. package/src/training/workloads.d.ts +164 -0
  128. package/src/training/workloads.js +539 -0
  129. package/src/version.js +1 -1
  130. package/tools/doppler-cli.js +137 -40
@@ -3,19 +3,32 @@ import { createTensor, dtypeBytes } from '../tensor.js';
3
3
  import { WORKGROUP_SIZES } from './constants.js';
4
4
  import { unifiedKernelWrapper } from './utils.js';
5
5
 
6
/**
 * Plan the X-dimension dispatch for the transpose kernel.
 *
 * The X workgroup count is capped by the device's
 * `maxComputeWorkgroupsPerDimension` limit (65535 when the device does not
 * report a finite value). `dispatchStride` is the number of columns the X
 * dispatch covers; the caller forwards it to the shader via the `_pad0`
 * uniform and dispatches one Y workgroup per row.
 *
 * NOTE(review): only X is capped here; the caller dispatches `rows` in Y,
 * which is not checked against the same per-dimension limit.
 *
 * @param {object} target - Kernel target; `target.device.limits` is read when present.
 * @param {number} cols - Column count of the source matrix.
 * @returns {{ dispatchStride: number, workgroups: number[] }}
 */
function planTransposeDispatch(target, cols) {
  const reportedLimit = target?.device?.limits?.maxComputeWorkgroupsPerDimension;
  const workgroupLimit = Number.isFinite(reportedLimit) ? reportedLimit : 65535;
  const groupSize = WORKGROUP_SIZES.DEFAULT;

  const dispatchStride = Math.min(cols, workgroupLimit * groupSize);
  const groupsX = Math.ceil(dispatchStride / groupSize);

  return { dispatchStride, workgroups: [groupsX, 1, 1] };
}
17
+
6
18
  async function _transpose(target, input, rows, cols, options = {}) {
7
19
  const { outputBuffer = null } = options;
8
20
  const bytesPerElement = dtypeBytes(input.dtype);
9
21
  const outputSize = rows * cols * bytesPerElement;
10
22
  const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'transpose_output');
23
+ const dispatchPlan = planTransposeDispatch(target, cols);
11
24
 
12
25
  await unifiedKernelWrapper(
13
26
  'transpose',
14
27
  target,
15
28
  'default',
16
29
  [input, outputBuf],
17
- { rows, cols },
18
- Math.ceil((rows * cols) / WORKGROUP_SIZES.DEFAULT)
30
+ { rows, cols, _pad0: dispatchPlan.dispatchStride, _pad1: 0 },
31
+ [dispatchPlan.workgroups[0], rows, 1]
19
32
  );
20
33
 
21
34
  return createTensor(outputBuf, input.dtype, [cols, rows], 'transpose_output');
@@ -19,14 +19,13 @@ struct Uniforms {
19
19
 
20
20
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
21
21
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
22
- let idx = gid.x;
23
- let total = u.rows * u.cols;
24
- if (idx >= total) {
22
+ let dispatch_stride = max(u._pad0, 1u);
23
+ let row = gid.y;
24
+ let col = gid.x + row * dispatch_stride;
25
+ if (row >= u.rows || col >= u.cols) {
25
26
  return;
26
27
  }
27
-
28
- let row = idx / u.cols;
29
- let col = idx % u.cols;
28
+ let idx = row * u.cols + col;
30
29
  let out_idx = col * u.rows + row;
31
30
  output[out_idx] = input[idx];
32
31
  }
@@ -31,6 +31,7 @@ async function _upsample2d(target, input, options = {}) {
31
31
 
32
32
  const outHeight = resolvedHeight * scale;
33
33
  const outWidth = resolvedWidth * scale;
34
+ const outSpatial = outHeight * outWidth;
34
35
  const bytesPerElement = dtypeBytes(input.dtype);
35
36
  const outputSize = channels * outHeight * outWidth * bytesPerElement;
36
37
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'upsample2d_output');
@@ -43,7 +44,7 @@ async function _upsample2d(target, input, options = {}) {
43
44
  out_height: outHeight, out_width: outWidth, scale,
44
45
  _pad0: 0, _pad1: 0,
45
46
  },
46
- Math.ceil((channels * outHeight * outWidth) / WORKGROUP_SIZES.DEFAULT)
47
+ [Math.ceil(outSpatial / WORKGROUP_SIZES.DEFAULT), channels, 1]
47
48
  );
48
49
 
49
50
  return createTensor(output, input.dtype, [channels, outHeight, outWidth], 'upsample2d_output');
@@ -19,19 +19,16 @@ struct Uniforms {
19
19
 
20
20
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
21
21
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
22
- let idx = gid.x;
23
22
  let out_spatial = u.out_height * u.out_width;
24
- let total = u.channels * out_spatial;
25
- if (idx >= total) {
23
+ let spatial_idx = gid.x;
24
+ let channel = gid.y;
25
+ if (spatial_idx >= out_spatial || channel >= u.channels) {
26
26
  return;
27
27
  }
28
-
29
- let channel = idx / out_spatial;
30
- let rem = idx - channel * out_spatial;
31
- let out_y = rem / u.out_width;
32
- let out_x = rem - out_y * u.out_width;
28
+ let out_y = spatial_idx / u.out_width;
29
+ let out_x = spatial_idx - out_y * u.out_width;
33
30
  let in_y = out_y / u.scale;
34
31
  let in_x = out_x / u.scale;
35
32
  let in_idx = (channel * u.in_height + in_y) * u.in_width + in_x;
36
- output[idx] = input[in_idx];
33
+ output[channel * out_spatial + spatial_idx] = input[in_idx];
37
34
  }
@@ -23,19 +23,16 @@ struct Uniforms {
23
23
 
24
24
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
25
25
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
26
- let idx = gid.x;
27
26
  let out_spatial = u.out_height * u.out_width;
28
- let total = u.channels * out_spatial;
29
- if (idx >= total) {
27
+ let spatial_idx = gid.x;
28
+ let channel = gid.y;
29
+ if (spatial_idx >= out_spatial || channel >= u.channels) {
30
30
  return;
31
31
  }
32
-
33
- let channel = idx / out_spatial;
34
- let rem = idx - channel * out_spatial;
35
- let out_y = rem / u.out_width;
36
- let out_x = rem - out_y * u.out_width;
32
+ let out_y = spatial_idx / u.out_width;
33
+ let out_x = spatial_idx - out_y * u.out_width;
37
34
  let in_y = out_y / u.scale;
38
35
  let in_x = out_x / u.scale;
39
36
  let in_idx = (channel * u.in_height + in_y) * u.in_width + in_x;
40
- output[idx] = input[in_idx];
37
+ output[channel * out_spatial + spatial_idx] = input[in_idx];
41
38
  }
@@ -116,9 +116,24 @@ export async function unifiedKernelWrapper(opName, target, variant, bindings, un
116
116
  index = config.variantMetadata.outputBinding;
117
117
  }
118
118
 
119
+ const buffer = binding?.buffer || binding;
120
+ const isGpuBuffer = buffer && (
121
+ typeof GPUBuffer === 'undefined'
122
+ ? true
123
+ : buffer instanceof GPUBuffer
124
+ );
125
+ if (!isGpuBuffer) {
126
+ const bindingLabel = binding?.label ?? buffer?.label ?? 'unknown';
127
+ const bufferType = buffer === null ? 'null' : buffer === undefined ? 'undefined' : buffer.constructor?.name || typeof buffer;
128
+ throw new Error(
129
+ `Kernel "${opName}/${variant}" binding "${bindingConfig.name}" (index ${index}) requires a GPUBuffer ` +
130
+ `(label=${bindingLabel}, type=${bufferType}).`
131
+ );
132
+ }
133
+
119
134
  bindGroupEntries.push({
120
135
  binding: index,
121
- resource: { buffer: binding?.buffer || binding }
136
+ resource: { buffer }
122
137
  });
123
138
  }
124
139
 
@@ -929,6 +929,9 @@ async function resolveHarnessOverride(options = {}) {
929
929
 
930
930
  async function initializeSuiteModel(options = {}) {
931
931
  if (options.harnessOverride) {
932
+ if (options.runtime?.runtimeConfig) {
933
+ setRuntimeConfig(options.runtime.runtimeConfig);
934
+ }
932
935
  return resolveHarnessOverride(options);
933
936
  }
934
937
  const loadStart = performance.now();
@@ -988,6 +991,14 @@ async function runKernelSuite(options = {}) {
988
991
 
989
992
  const DEFAULT_HARNESS_PROMPT = 'Summarize this input in one sentence.';
990
993
  const DEFAULT_RUNTIME_PLACEHOLDER_PROMPT = 'Hello from Doppler.';
994
+ const DEFAULT_QWEN_PROMPT = Object.freeze({
995
+ messages: Object.freeze([
996
+ Object.freeze({
997
+ role: 'user',
998
+ content: 'Answer in one short sentence: What color is the sky on a clear day?',
999
+ }),
1000
+ ]),
1001
+ });
991
1002
  const DEFAULT_TRANSLATEGEMMA_PROMPT = Object.freeze({
992
1003
  messages: Object.freeze([
993
1004
  Object.freeze({
@@ -1273,6 +1284,9 @@ function resolvePromptTemplateType(source) {
1273
1284
  }
1274
1285
 
1275
1286
  function buildDefaultGenerationPrompt(templateType) {
1287
+ if (templateType === 'qwen') {
1288
+ return clonePromptInput(DEFAULT_QWEN_PROMPT);
1289
+ }
1276
1290
  if (templateType === 'translategemma') {
1277
1291
  return clonePromptInput(DEFAULT_TRANSLATEGEMMA_PROMPT);
1278
1292
  }
@@ -1280,7 +1294,7 @@ function buildDefaultGenerationPrompt(templateType) {
1280
1294
  }
1281
1295
 
1282
1296
  function shouldPreferModelDefaultPrompt(runtimePrompt, templateType) {
1283
- if (templateType !== 'translategemma') {
1297
+ if (templateType !== 'translategemma' && templateType !== 'qwen') {
1284
1298
  return false;
1285
1299
  }
1286
1300
  if (typeof runtimePrompt !== 'string') {
@@ -1289,6 +1303,31 @@ function shouldPreferModelDefaultPrompt(runtimePrompt, templateType) {
1289
1303
  return runtimePrompt.trim() === DEFAULT_RUNTIME_PLACEHOLDER_PROMPT;
1290
1304
  }
1291
1305
 
1306
/**
 * Enforce the prompt shape required by the TranslateGemma chat template.
 *
 * Every template other than `translategemma` is accepted unchecked. For
 * TranslateGemma:
 * - `null` / `undefined` are accepted;
 * - strings are accepted only when blank or equal to the runtime placeholder;
 * - any other non-string must satisfy `isStructuredPromptInput`.
 *
 * @param {*} runtimePrompt - Prompt value to validate.
 * @param {string|null} templateType - Resolved chat template type.
 * @param {string} [source='runtime.inference.prompt'] - Label used in error messages.
 * @throws {Error} When a TranslateGemma prompt violates the contract.
 */
function assertPromptContract(runtimePrompt, templateType, source = 'runtime.inference.prompt') {
  if (templateType !== 'translategemma' || runtimePrompt == null) {
    return;
  }

  const contractError = (expectation) => new Error(
    `TranslateGemma harness prompt contract violation: ${source} must be ${expectation}`
  );

  if (typeof runtimePrompt !== 'string') {
    if (!isStructuredPromptInput(runtimePrompt)) {
      throw contractError('{ messages: [...] } with source_lang_code/target_lang_code blocks.');
    }
    return;
  }

  // Blank strings and the placeholder are tolerated (presumably resolved to a
  // default prompt by the caller); any other plain string is rejected.
  const trimmed = runtimePrompt.trim();
  if (trimmed !== '' && trimmed !== DEFAULT_RUNTIME_PLACEHOLDER_PROMPT) {
    throw contractError(
      '{ messages: [...] } with source_lang_code/target_lang_code blocks, not a plain string.'
    );
  }
}
1330
+
1292
1331
  function describePromptInput(promptInput) {
1293
1332
  if (typeof promptInput === 'string') {
1294
1333
  return promptInput.trim() || DEFAULT_HARNESS_PROMPT;
@@ -1305,6 +1344,11 @@ function describePromptInput(promptInput) {
1305
1344
  if (sourceLang && targetLang) {
1306
1345
  return `${sourceLang} -> ${targetLang}: ${text || '[non-text request]'}`;
1307
1346
  }
1347
+ const stringContent = asText(firstMessage?.content);
1348
+ if (stringContent) {
1349
+ const role = asText(firstMessage?.role) || 'user';
1350
+ return `${role}: ${stringContent}`;
1351
+ }
1308
1352
  try {
1309
1353
  return JSON.stringify(promptInput);
1310
1354
  } catch {
@@ -1315,6 +1359,7 @@ function describePromptInput(promptInput) {
1315
1359
  function resolveGenerationPromptInput(runtimeConfig, runOverrides = null, source = null) {
1316
1360
  const templateType = resolvePromptTemplateType(source);
1317
1361
  const overridePrompt = runOverrides?.prompt;
1362
+ assertPromptContract(overridePrompt, templateType, 'runOverrides.prompt');
1318
1363
  if (typeof overridePrompt === 'string' && overridePrompt.trim()) {
1319
1364
  return overridePrompt.trim();
1320
1365
  }
@@ -1323,6 +1368,7 @@ function resolveGenerationPromptInput(runtimeConfig, runOverrides = null, source
1323
1368
  }
1324
1369
 
1325
1370
  const runtimePrompt = runtimeConfig?.inference?.prompt;
1371
+ assertPromptContract(runtimePrompt, templateType, 'runtimeConfig.inference.prompt');
1326
1372
  if (shouldPreferModelDefaultPrompt(runtimePrompt, templateType)) {
1327
1373
  return buildDefaultGenerationPrompt(templateType);
1328
1374
  }
@@ -52,6 +52,18 @@ function generateLatents(width, height, channels, latentScale, seed) {
52
52
  return { latents, latentWidth, latentHeight };
53
53
  }
54
54
 
55
/**
 * Fill a Float32Array of length `size` with draws from `sampleNormal`.
 *
 * Used to build per-step scheduler noise whose length matches the sample
 * being denoised.
 *
 * @param {number} size - Number of samples; must be a positive integer.
 * @param {number} [seed] - RNG seed; a fresh seed from `createRandomSeed()` is used when nullish.
 * @returns {Float32Array}
 * @throws {Error} When `size` is not a positive integer.
 */
function generateNoiseVector(size, seed) {
  // Non-integer sizes (e.g. 2.5) would otherwise get past this guard and fail
  // inside `new Float32Array` with an opaque RangeError.
  if (!Number.isInteger(size) || size <= 0) {
    throw new Error(`generateNoiseVector requires a positive integer size, got ${size}.`);
  }
  const out = new Float32Array(size);
  const rand = createRng(seed ?? createRandomSeed());
  for (let i = 0; i < size; i++) {
    out[i] = sampleNormal(rand);
  }
  return out;
}
66
+
55
67
  function extractTokenSet(tokensByEncoder, key) {
56
68
  const output = {};
57
69
  for (const [name, entry] of Object.entries(tokensByEncoder || {})) {
@@ -195,13 +207,10 @@ async function applySchedulerStep(latentsTensor, scheduler, stepIndex, timestep,
195
207
  const isFinalStep = stepIndex + 1 >= scheduler.timesteps.length - 1;
196
208
  const noise = isFinalStep
197
209
  ? null
198
- : generateLatents(
199
- runtime.latent.width,
200
- runtime.latent.height,
201
- runtime.latent.channels,
202
- runtime.latent.scale,
210
+ : generateNoiseVector(
211
+ sample.length,
203
212
  (options.seedBase ?? createRandomSeed()) + stepIndex + 1
204
- ).latents;
213
+ );
205
214
  const step = stepScmScheduler(scheduler, modelOutput, timestep, sample, stepIndex, noise);
206
215
  return createLatentTensor(step.prevSample, [...latentsTensor.shape], runtime);
207
216
  }
@@ -80,3 +80,8 @@ export declare function projectContext(
80
80
  ): Promise<Tensor>;
81
81
 
82
82
  export declare function assertClipHiddenActivationSupported(config: { hidden_act?: string }): void;
83
+
84
+ export declare function resolveGemma2WeightRoot(
85
+ weights: Map<string, any>,
86
+ prefix?: string
87
+ ): string;
@@ -723,8 +723,19 @@ function buildGemma2LayerTypes(layerCount, slidingWindow) {
723
723
  ));
724
724
  }
725
725
 
726
- function getGemma2LayerWeight(weights, prefix, layerIdx, suffix, required = true) {
727
- const key = `${prefix}.model.layers.${layerIdx}.${suffix}`;
726
/**
 * Choose the key root under which Gemma2 text-encoder weights are stored.
 *
 * Two layouts are recognised by probing for the embedding weight: nested
 * (`<prefix>.model.embed_tokens.weight`) and flat
 * (`<prefix>.embed_tokens.weight`). The nested layout is preferred when both
 * exist and is returned when neither is found.
 *
 * @param {Map<string, *>|null|undefined} weights - Weight map keyed by tensor name.
 * @param {string} [prefix='text_encoder'] - Component prefix.
 * @returns {string} Root that `.embed_tokens.weight`, `.layers.<i>.*` and `.norm.weight` are appended to.
 */
export function resolveGemma2WeightRoot(weights, prefix = 'text_encoder') {
  const nestedRoot = `${prefix}.model`;
  const candidates = [nestedRoot, prefix];
  const detected = candidates.find((root) => weights?.has(`${root}.embed_tokens.weight`));
  return detected ?? nestedRoot;
}
736
+
737
+ function getGemma2LayerWeight(weights, weightRoot, layerIdx, suffix, required = true) {
738
+ const key = `${weightRoot}.layers.${layerIdx}.${suffix}`;
728
739
  const weight = weights.get(key) || null;
729
740
  if (!weight && required) {
730
741
  throw new Error(`Missing Gemma2 diffusion weight "${key}".`);
@@ -805,8 +816,9 @@ async function runGemma2TextEncoder(tokens, weightsEntry, config, runtime, optio
805
816
  const tokenIds = normalizeTokens(tokens, options.maxLength ?? resolved.maxPositionEmbeddings, padTokenId);
806
817
  const numTokens = tokenIds.length;
807
818
  const tokenBuffer = createDiffusionIndexBuffer(device, tokenIds, `${prefix}_tokens`);
819
+ const weightRoot = resolveGemma2WeightRoot(weights, prefix);
808
820
 
809
- const embedKey = `${prefix}.model.embed_tokens.weight`;
821
+ const embedKey = `${weightRoot}.embed_tokens.weight`;
810
822
  const embedWeight = expectDiffusionWeight(
811
823
  weights.get(embedKey),
812
824
  embedKey
@@ -837,16 +849,16 @@ async function runGemma2TextEncoder(tokens, weightsEntry, config, runtime, optio
837
849
  const layerWeights = new Map();
838
850
  for (let layerIdx = 0; layerIdx < resolved.numLayers; layerIdx++) {
839
851
  layerWeights.set(`layer_${layerIdx}`, {
840
- inputNorm: getGemma2LayerWeight(weights, prefix, layerIdx, 'input_layernorm.weight'),
841
- qProj: getGemma2LayerWeight(weights, prefix, layerIdx, 'self_attn.q_proj.weight'),
842
- kProj: getGemma2LayerWeight(weights, prefix, layerIdx, 'self_attn.k_proj.weight'),
843
- vProj: getGemma2LayerWeight(weights, prefix, layerIdx, 'self_attn.v_proj.weight'),
844
- oProj: getGemma2LayerWeight(weights, prefix, layerIdx, 'self_attn.o_proj.weight'),
845
- postAttentionNorm: getGemma2LayerWeight(weights, prefix, layerIdx, 'post_attention_layernorm.weight'),
846
- preFeedforwardNorm: getGemma2LayerWeight(weights, prefix, layerIdx, 'pre_feedforward_layernorm.weight'),
847
- gate: getGemma2LayerWeight(weights, prefix, layerIdx, 'mlp.gate_proj.weight'),
848
- up: getGemma2LayerWeight(weights, prefix, layerIdx, 'mlp.up_proj.weight'),
849
- down: getGemma2LayerWeight(weights, prefix, layerIdx, 'mlp.down_proj.weight'),
852
+ inputNorm: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'input_layernorm.weight'),
853
+ qProj: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'self_attn.q_proj.weight'),
854
+ kProj: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'self_attn.k_proj.weight'),
855
+ vProj: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'self_attn.v_proj.weight'),
856
+ oProj: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'self_attn.o_proj.weight'),
857
+ postAttentionNorm: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'post_attention_layernorm.weight'),
858
+ preFeedforwardNorm: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'pre_feedforward_layernorm.weight'),
859
+ gate: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'mlp.gate_proj.weight'),
860
+ up: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'mlp.up_proj.weight'),
861
+ down: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'mlp.down_proj.weight'),
850
862
  });
851
863
  }
852
864
 
@@ -910,10 +922,10 @@ async function runGemma2TextEncoder(tokens, weightsEntry, config, runtime, optio
910
922
  numTokens * resolved.hiddenSize,
911
923
  context
912
924
  );
913
- hidden = createTensor(output.buffer, output.dtype, [numTokens, resolved.hiddenSize], `gemma2_layer_${layerIdx}`);
925
+ hidden = createTensor(output, activationDtype, [numTokens, resolved.hiddenSize], `gemma2_layer_${layerIdx}`);
914
926
  }
915
927
 
916
- const finalNormKey = `${prefix}.model.norm.weight`;
928
+ const finalNormKey = `${weightRoot}.norm.weight`;
917
929
  const finalNorm = expectDiffusionWeight(weights.get(finalNormKey), finalNormKey);
918
930
  const final = await ops.rmsNorm(hidden, getBuffer(finalNorm), resolved.rmsNormEps, {
919
931
  batchSize: numTokens,
@@ -182,10 +182,18 @@ export async function recordLayerAttentionGPU(
182
182
  // 3. RoPE (modifies tensor in-place)
183
183
  if (!disableRoPE && state.ropeFreqsCos && state.ropeFreqsSin) {
184
184
  await recordRoPE(recorder, qTensor, state.ropeFreqsCos, state.ropeFreqsSin, numTokens, {
185
- numHeads, headDim, startPos: currentSeqLen,
185
+ numHeads,
186
+ headDim,
187
+ rotaryDim: config.ropeRotaryDim,
188
+ interleaved: config.ropeInterleaved,
189
+ startPos: currentSeqLen,
186
190
  });
187
191
  await recordRoPE(recorder, kTensor, state.ropeFreqsCos, state.ropeFreqsSin, numTokens, {
188
- numHeads: numKVHeads, headDim, startPos: currentSeqLen,
192
+ numHeads: numKVHeads,
193
+ headDim,
194
+ rotaryDim: config.ropeRotaryDim,
195
+ interleaved: config.ropeInterleaved,
196
+ startPos: currentSeqLen,
189
197
  });
190
198
  }
191
199
 
@@ -502,6 +510,7 @@ export async function recordLayerAttentionGPU(
502
510
  size: numTokens * numHeads * headDim,
503
511
  gate: qGateTensor,
504
512
  gateActivation: 'sigmoid',
513
+ inputActivation: 'identity',
505
514
  swigluLimit: null,
506
515
  });
507
516
  recorder.trackTemporaryBuffer(attnOutput.buffer);
@@ -299,10 +299,18 @@ export async function runLayerAttentionGPU(
299
299
 
300
300
  if (!disableRoPE && state.ropeFreqsCos && state.ropeFreqsSin) {
301
301
  await runRoPE(qTensor, state.ropeFreqsCos, state.ropeFreqsSin, numTokens, {
302
- numHeads, headDim, startPos: currentSeqLen,
302
+ numHeads,
303
+ headDim,
304
+ rotaryDim: config.ropeRotaryDim,
305
+ interleaved: config.ropeInterleaved,
306
+ startPos: currentSeqLen,
303
307
  });
304
308
  await runRoPE(kTensor, state.ropeFreqsCos, state.ropeFreqsSin, numTokens, {
305
- numHeads: numKVHeads, headDim, startPos: currentSeqLen,
309
+ numHeads: numKVHeads,
310
+ headDim,
311
+ rotaryDim: config.ropeRotaryDim,
312
+ interleaved: config.ropeInterleaved,
313
+ startPos: currentSeqLen,
306
314
  });
307
315
 
308
316
  // Trace RoPE outputs
@@ -690,6 +698,7 @@ export async function runLayerAttentionGPU(
690
698
  size: numTokens * numHeads * headDim,
691
699
  gate: qGateTensor,
692
700
  gateActivation: 'sigmoid',
701
+ inputActivation: 'identity',
693
702
  swigluLimit: null,
694
703
  });
695
704
  releaseBuffer(attnOutput.buffer);
@@ -224,6 +224,29 @@ function formatChatML(messages) {
224
224
  return parts.join('');
225
225
  }
226
226
 
227
/**
 * Render chat messages with the Qwen chat template.
 *
 * Each system/user/assistant turn becomes
 * `<|im_start|>{role}\n{content}<|im_end|>\n` (ChatML-like). The generation
 * prelude then opens an assistant turn that already contains an empty
 * `<think>` block.
 *
 * @param {Array<{role?: string, content?: *}>} messages - Conversation turns.
 * @returns {string} The formatted prompt.
 * @throws {Error} When `assertSupportedChatRole` rejects a role, or a system
 *   message appears anywhere but index 0.
 */
function formatQwen(messages) {
  // Roles that produce a turn; any other role accepted by
  // assertSupportedChatRole contributes no text.
  const turnRoles = new Set(['system', 'user', 'assistant']);
  const turns = [];
  for (const [index, message] of messages.entries()) {
    const role = normalizeChatRole(message?.role);
    assertSupportedChatRole(role, 'Qwen', index);
    if (index !== 0 && role === 'system') {
      throw new Error('Qwen template requires any system message to appear first.');
    }
    const content = normalizeChatMessageContent(message?.content);
    if (turnRoles.has(role)) {
      turns.push(`<|im_start|>${role}\n${content}<|im_end|>\n`);
    }
  }
  // Generation prelude: open assistant turn with an empty thinking block.
  turns.push('<|im_start|>assistant\n<think>\n\n</think>\n\n');
  return turns.join('');
}
249
+
227
250
  function formatTranslateGemmaUserPrompt(content) {
228
251
  if (!Array.isArray(content) || content.length !== 1) {
229
252
  throw new Error(
@@ -345,7 +368,7 @@ const CHAT_FORMATTERS = {
345
368
  'llama3': formatHeaderBased,
346
369
  'gpt-oss': formatChannelBased,
347
370
  'chatml': formatChatML,
348
- 'qwen': formatChatML,
371
+ 'qwen': formatQwen,
349
372
  'translategemma': formatTranslateGemma,
350
373
  };
351
374
 
@@ -363,4 +386,5 @@ export function formatChatMessages(messages, templateType) {
363
386
  export const formatGemmaChat = formatTurnBased;
364
387
  export const formatLlama3Chat = formatHeaderBased;
365
388
  export const formatGptOssChat = formatChannelBased;
389
+ export const formatQwenChat = formatQwen;
366
390
  export const formatTranslateGemmaChat = formatTranslateGemma;
@@ -148,6 +148,10 @@ export interface ParsedModelConfig {
148
148
  slidingWindow: number | null;
149
149
  ropeTheta: number;
150
150
  ropeLocalTheta: number | null;
151
+ ropeRotaryDim: number;
152
+ ropeInterleaved: boolean;
153
+ mropeSection: number[] | null;
154
+ partialRotaryFactor: number | null;
151
155
  ropeScale: number;
152
156
  ropeLocalScale: number;
153
157
  ropeScalingType: string | null;
@@ -21,6 +21,28 @@ function assertSupportedRuntimeModelType(manifest) {
21
21
  );
22
22
  }
23
23
 
24
/**
 * Resolve how many dimensions of each attention head receive rotary embeddings.
 *
 * With no partial factor (`null`/`undefined`) the full `headDim` is returned
 * unchanged. Otherwise the factor must be a non-NaN number in (0, 1]. The
 * result is `Math.trunc(headDim * factor)` and must be a positive even integer.
 *
 * @param {number} headDim - Per-head dimension.
 * @param {number|null|undefined} partialRotaryFactor - Fraction of each head to rotate.
 * @param {string} modelId - Model identifier, used only in error messages.
 * @returns {number} The rotary dimension.
 * @throws {Error} On a non-numeric/NaN or out-of-range factor, or when the
 *   resolved rotary dim is not a positive even integer.
 */
function resolveRotaryDim(headDim, partialRotaryFactor, modelId) {
  const factor = partialRotaryFactor;
  if (factor === null || factor === undefined) {
    return headDim;
  }

  const isUsableNumber = typeof factor === 'number' && !Number.isNaN(factor);
  if (!isUsableNumber) {
    throw new Error(`Manifest "${modelId}" has invalid rope.partialRotaryFactor.`);
  }

  const inRange = factor > 0 && factor <= 1;
  if (!inRange) {
    throw new Error(
      `Manifest "${modelId}" requires 0 < rope.partialRotaryFactor <= 1; got ${factor}.`
    );
  }

  const rotaryDim = Math.trunc(headDim * factor);
  const isPositiveEven = rotaryDim > 0 && rotaryDim % 2 === 0;
  if (isPositiveEven) {
    return rotaryDim;
  }
  throw new Error(
    `Manifest "${modelId}" resolves rope rotary dim ${rotaryDim} from headDim=${headDim} ` +
    `and partialRotaryFactor=${factor}, but rotary dim must be a positive even integer.`
  );
}
45
+
24
46
  export function getStopTokenIds(manifest) {
25
47
  const eosTokenId = manifest?.eos_token_id;
26
48
  if (Array.isArray(eosTokenId)) return eosTokenId;
@@ -130,7 +152,14 @@ export function hasManifestInference(manifest) {
130
152
 
131
153
 
132
154
  export function validateRequiredInferenceFields(inf, modelId) {
133
-
155
+ inf = inf ?? {};
156
+ inf.attention = inf.attention ?? {};
157
+ inf.normalization = inf.normalization ?? {};
158
+ inf.ffn = inf.ffn ?? {};
159
+ inf.rope = inf.rope ?? {};
160
+ inf.output = inf.output ?? {};
161
+ inf.layerPattern = inf.layerPattern ?? {};
162
+ inf.chatTemplate = inf.chatTemplate ?? {};
134
163
  const errors = [];
135
164
 
136
165
  // Attention fields - non-nullable required
@@ -201,6 +230,20 @@ export function validateRequiredInferenceFields(inf, modelId) {
201
230
  if (inf.rope.ropeLocalTheta === undefined) {
202
231
  errors.push('rope.ropeLocalTheta must be explicitly set (null for no local theta, or number)');
203
232
  }
233
+ if (inf.rope.mropeInterleaved == null) {
234
+ errors.push('rope.mropeInterleaved is required');
235
+ }
236
+ if (inf.rope.mropeSection === undefined) {
237
+ errors.push('rope.mropeSection must be explicitly set (null when unused, or an array of positive integers)');
238
+ }
239
+ if (inf.rope.partialRotaryFactor === undefined) {
240
+ errors.push('rope.partialRotaryFactor must be explicitly set (null when unused, or a number in (0, 1])');
241
+ } else {
242
+ const factor = inf.rope.partialRotaryFactor;
243
+ if (factor !== null && (typeof factor !== 'number' || Number.isNaN(factor) || factor <= 0 || factor > 1)) {
244
+ errors.push('rope.partialRotaryFactor must be a number in (0, 1] or null');
245
+ }
246
+ }
204
247
 
205
248
  // Output fields - non-nullable required
206
249
  if (inf.output.tieWordEmbeddings == null) {
@@ -458,6 +501,26 @@ export function toParsedConfigFromMerged(merged, manifest) {
458
501
  const ropeScalingType = inf.rope.ropeScalingType;
459
502
  const ropeLocalScale = inf.rope.ropeLocalScalingFactor ?? ropeScale;
460
503
  const ropeLocalScalingType = inf.rope.ropeLocalScalingType ?? ropeScalingType;
504
+ const partialRotaryFactor = inf.rope.partialRotaryFactor;
505
+ const ropeInterleaved = inf.rope.mropeInterleaved === true;
506
+ const mropeSection = Array.isArray(inf.rope.mropeSection)
507
+ ? inf.rope.mropeSection.map((entry) => Math.trunc(Number(entry)))
508
+ : null;
509
+ const ropeRotaryDim = resolveRotaryDim(arch.headDim, partialRotaryFactor, merged.modelId);
510
+ if (mropeSection && mropeSection.some((entry) => !Number.isFinite(entry) || entry <= 0)) {
511
+ throw new Error(
512
+ `Manifest "${merged.modelId}" has invalid rope.mropeSection; expected positive integers.`
513
+ );
514
+ }
515
+ if (ropeInterleaved && mropeSection) {
516
+ const doubledMropeDim = mropeSection.reduce((sum, entry) => sum + entry, 0) * 2;
517
+ if (doubledMropeDim !== ropeRotaryDim) {
518
+ throw new Error(
519
+ `Manifest "${merged.modelId}" declares rope.mropeSection=${JSON.stringify(mropeSection)}, ` +
520
+ `which expands to rotary dim ${doubledMropeDim}, but the resolved rotary dim is ${ropeRotaryDim}.`
521
+ );
522
+ }
523
+ }
461
524
 
462
525
  // Build ropeScaling object from manifest values if scaling is enabled
463
526
  // Include YARN params when present
@@ -532,6 +595,10 @@ export function toParsedConfigFromMerged(merged, manifest) {
532
595
  slidingWindow: inf.attention.slidingWindow,
533
596
  ropeTheta: inf.rope.ropeTheta,
534
597
  ropeLocalTheta: inf.rope.ropeLocalTheta,
598
+ ropeRotaryDim,
599
+ ropeInterleaved,
600
+ mropeSection,
601
+ partialRotaryFactor,
535
602
  ropeScale,
536
603
  ropeLocalScale,
537
604
  ropeScalingType,