@simulatte/doppler 0.1.5 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130):
  1. package/README.md +23 -8
  2. package/package.json +7 -4
  3. package/src/config/kernels/kernel-ref-digests.js +39 -39
  4. package/src/config/kernels/registry.json +42 -2
  5. package/src/config/loader.js +31 -2
  6. package/src/config/merge.js +18 -0
  7. package/src/config/presets/models/qwen3.json +9 -2
  8. package/src/config/presets/models/transformer.json +5 -0
  9. package/src/config/required-inference-fields-contract-check.js +6 -0
  10. package/src/config/schema/inference-defaults.schema.js +3 -0
  11. package/src/config/schema/inference.schema.d.ts +9 -0
  12. package/src/config/schema/kernel-path.schema.d.ts +6 -0
  13. package/src/config/schema/manifest.schema.d.ts +6 -0
  14. package/src/config/schema/manifest.schema.js +3 -0
  15. package/src/converter/rope-config.js +42 -0
  16. package/src/gpu/device.js +58 -0
  17. package/src/gpu/kernels/attention.js +98 -0
  18. package/src/gpu/kernels/bias_add.wgsl +8 -6
  19. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  20. package/src/gpu/kernels/conv2d.js +1 -1
  21. package/src/gpu/kernels/conv2d.wgsl +7 -8
  22. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  23. package/src/gpu/kernels/depthwise_conv2d.js +2 -1
  24. package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
  25. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
  26. package/src/gpu/kernels/grouped_pointwise_conv2d.js +2 -1
  27. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
  28. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
  29. package/src/gpu/kernels/matmul.js +25 -0
  30. package/src/gpu/kernels/pixel_shuffle.js +1 -1
  31. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  32. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  33. package/src/gpu/kernels/relu.js +15 -2
  34. package/src/gpu/kernels/relu.wgsl +2 -1
  35. package/src/gpu/kernels/relu_f16.wgsl +2 -1
  36. package/src/gpu/kernels/repeat_channels.js +1 -1
  37. package/src/gpu/kernels/repeat_channels.wgsl +4 -5
  38. package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
  39. package/src/gpu/kernels/residual.js +44 -8
  40. package/src/gpu/kernels/residual.wgsl +6 -3
  41. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  42. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  43. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  44. package/src/gpu/kernels/rmsnorm.js +58 -6
  45. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  46. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  47. package/src/gpu/kernels/rope.d.ts +2 -0
  48. package/src/gpu/kernels/rope.js +11 -1
  49. package/src/gpu/kernels/rope.wgsl +56 -40
  50. package/src/gpu/kernels/sana_linear_attention.js +1 -2
  51. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
  52. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
  53. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
  54. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
  55. package/src/gpu/kernels/silu.d.ts +1 -0
  56. package/src/gpu/kernels/silu.js +32 -14
  57. package/src/gpu/kernels/silu.wgsl +19 -9
  58. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  59. package/src/gpu/kernels/transpose.js +15 -2
  60. package/src/gpu/kernels/transpose.wgsl +5 -6
  61. package/src/gpu/kernels/upsample2d.js +2 -1
  62. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  63. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  64. package/src/gpu/kernels/utils.js +16 -1
  65. package/src/inference/browser-harness.js +47 -1
  66. package/src/inference/pipelines/diffusion/pipeline.js +15 -6
  67. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
  68. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
  69. package/src/inference/pipelines/text/attention/record.js +11 -2
  70. package/src/inference/pipelines/text/attention/run.js +11 -2
  71. package/src/inference/pipelines/text/chat-format.js +25 -1
  72. package/src/inference/pipelines/text/config.d.ts +4 -0
  73. package/src/inference/pipelines/text/config.js +68 -1
  74. package/src/inference/pipelines/text/execution-plan.js +23 -31
  75. package/src/inference/pipelines/text/execution-v0.js +29 -2
  76. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  77. package/src/inference/pipelines/text/init.d.ts +4 -0
  78. package/src/inference/pipelines/text/init.js +56 -9
  79. package/src/inference/pipelines/text/layer.js +11 -0
  80. package/src/inference/pipelines/text.js +4 -0
  81. package/src/inference/tokenizers/bundled.js +156 -33
  82. package/src/rules/tooling/command-runtime.rules.json +18 -0
  83. package/src/tooling/command-api.d.ts +27 -1
  84. package/src/tooling/command-api.js +142 -3
  85. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  86. package/src/tooling/node-browser-command-runner.js +58 -3
  87. package/src/tooling/node-command-runner.js +15 -0
  88. package/src/tooling/node-webgpu.js +9 -87
  89. package/src/training/checkpoint-watch.d.ts +7 -0
  90. package/src/training/checkpoint-watch.js +106 -0
  91. package/src/training/checkpoint.d.ts +6 -1
  92. package/src/training/checkpoint.js +12 -2
  93. package/src/training/distillation/artifacts.d.ts +71 -0
  94. package/src/training/distillation/artifacts.js +132 -0
  95. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  96. package/src/training/distillation/checkpoint-watch.js +57 -0
  97. package/src/training/distillation/dataset.d.ts +59 -0
  98. package/src/training/distillation/dataset.js +337 -0
  99. package/src/training/distillation/eval.d.ts +34 -0
  100. package/src/training/distillation/eval.js +310 -0
  101. package/src/training/distillation/index.d.ts +29 -0
  102. package/src/training/distillation/index.js +29 -0
  103. package/src/training/distillation/runtime.d.ts +20 -0
  104. package/src/training/distillation/runtime.js +121 -0
  105. package/src/training/distillation/scoreboard.d.ts +6 -0
  106. package/src/training/distillation/scoreboard.js +8 -0
  107. package/src/training/distillation/stage-a.d.ts +45 -0
  108. package/src/training/distillation/stage-a.js +338 -0
  109. package/src/training/distillation/stage-b.d.ts +24 -0
  110. package/src/training/distillation/stage-b.js +20 -0
  111. package/src/training/index.d.ts +10 -0
  112. package/src/training/index.js +10 -0
  113. package/src/training/lora-pipeline.d.ts +40 -0
  114. package/src/training/lora-pipeline.js +796 -0
  115. package/src/training/operator-artifacts.d.ts +62 -0
  116. package/src/training/operator-artifacts.js +140 -0
  117. package/src/training/operator-command.d.ts +5 -0
  118. package/src/training/operator-command.js +453 -0
  119. package/src/training/operator-eval.d.ts +48 -0
  120. package/src/training/operator-eval.js +230 -0
  121. package/src/training/operator-scoreboard.d.ts +5 -0
  122. package/src/training/operator-scoreboard.js +44 -0
  123. package/src/training/runner.d.ts +52 -0
  124. package/src/training/runner.js +29 -4
  125. package/src/training/suite.d.ts +112 -0
  126. package/src/training/suite.js +9 -9
  127. package/src/training/workloads.d.ts +164 -0
  128. package/src/training/workloads.js +539 -0
  129. package/src/version.js +1 -1
  130. package/tools/doppler-cli.js +137 -40
@@ -42,56 +42,48 @@ function resolveFallbackActivationDtype(primaryActivationDtype) {
42
42
  function resolveFallbackKernelPath(primaryKernelPath) {
43
43
  const primaryKernelPathId = primaryKernelPath?.id ?? null;
44
44
  if (!primaryKernelPathId) {
45
- return {
46
- kernelPath: null,
47
- kernelPathId: null,
48
- kernelPathSource: 'none',
49
- };
45
+ throw new Error(
46
+ '[ExecutionPlan] F16 finiteness fallback requires a primary kernel path with a stable id. ' +
47
+ 'Add a registered kernelPath id and a finiteness fallback rule.'
48
+ );
50
49
  }
51
50
 
52
- const primaryKernelPathIsObject = typeof primaryKernelPath === 'object' && primaryKernelPath !== null;
51
+ const explicitFallbackKernelPathId = typeof primaryKernelPath?.finitenessFallbackKernelPathId === 'string'
52
+ && primaryKernelPath.finitenessFallbackKernelPathId.length > 0
53
+ ? primaryKernelPath.finitenessFallbackKernelPathId
54
+ : null;
53
55
 
54
- const fallbackKernelPathId = selectRuleValue(
56
+ const fallbackKernelPathId = explicitFallbackKernelPathId ?? selectRuleValue(
55
57
  'inference',
56
58
  'kernelPath',
57
59
  'finitenessFallback',
58
60
  { kernelPathId: primaryKernelPathId }
59
61
  );
60
62
 
61
- const resolvedKernelPathId = typeof fallbackKernelPathId === 'string' && fallbackKernelPathId.length > 0
62
- ? fallbackKernelPathId
63
- : primaryKernelPathId;
64
- const kernelPathSource = resolvedKernelPathId === primaryKernelPathId ? 'self' : 'rule';
63
+ if (typeof fallbackKernelPathId !== 'string' || fallbackKernelPathId.length === 0) {
64
+ throw new Error(
65
+ `[ExecutionPlan] Missing finiteness fallback kernel path mapping for "${primaryKernelPathId}". ` +
66
+ 'Add an explicit rule in src/rules/inference/kernel-path.rules.json.'
67
+ );
68
+ }
65
69
 
66
- if (kernelPathSource === 'self') {
67
- log.warn(
68
- 'Pipeline',
69
- `[ExecutionPlan] No finiteness fallback kernel path mapping for "${primaryKernelPathId}"; using primary kernel path.`
70
+ if (fallbackKernelPathId === primaryKernelPathId) {
71
+ throw new Error(
72
+ `[ExecutionPlan] Invalid finiteness fallback mapping for "${primaryKernelPathId}": ` +
73
+ `fallback kernel path resolves to itself. Add an explicit widening path.`
70
74
  );
71
75
  }
72
76
 
73
77
  try {
74
- const kernelPath = resolveKernelPath(resolvedKernelPathId);
78
+ const kernelPath = resolveKernelPath(fallbackKernelPathId);
75
79
  return {
76
80
  kernelPath,
77
- kernelPathId: resolvedKernelPathId,
78
- kernelPathSource,
81
+ kernelPathId: fallbackKernelPathId,
82
+ kernelPathSource: 'rule',
79
83
  };
80
84
  } catch (error) {
81
- if (primaryKernelPathIsObject) {
82
- log.warn(
83
- 'Pipeline',
84
- `[ExecutionPlan] Failed to resolve finiteness fallback kernel path "${resolvedKernelPathId}" ` +
85
- `for "${primaryKernelPathId}", using inline kernel path as fallback. ${error?.message || error}`
86
- );
87
- return {
88
- kernelPath: primaryKernelPath,
89
- kernelPathId: primaryKernelPathId,
90
- kernelPathSource,
91
- };
92
- }
93
85
  throw new Error(
94
- `[ExecutionPlan] Failed to resolve finiteness fallback kernel path "${resolvedKernelPathId}" ` +
86
+ `[ExecutionPlan] Failed to resolve finiteness fallback kernel path "${fallbackKernelPathId}" ` +
95
87
  `(from "${primaryKernelPathId}"): ${error?.message || error}`
96
88
  );
97
89
  }
@@ -7,6 +7,7 @@ import {
7
7
  resolveExecutionV0KVIO,
8
8
  resolveExecutionV0Precision,
9
9
  } from '../../../config/execution-v0-contract-check.js';
10
+ import { selectRuleValue } from '../../../rules/rule-registry.js';
10
11
  import {
11
12
  EXECUTION_V0_SCHEMA_ID,
12
13
  DEFAULT_EXECUTION_V0_POLICIES,
@@ -856,7 +857,7 @@ function assertInlineKernelPathSessionCompatibility(path, sessionDefaults) {
856
857
  }
857
858
  }
858
859
 
859
- function buildInlineKernelPath(steps, sessionDefaults, modelId, numLayers) {
860
+ function buildInlineKernelPath(steps, sessionDefaults, modelId, numLayers, finitenessFallbackKernelPathId = null) {
860
861
  const activationDtype = normalizeDtype(
861
862
  sessionDefaults?.compute?.defaults?.activationDtype ?? 'f16',
862
863
  'sessionDefaults.compute.defaults.activationDtype'
@@ -877,6 +878,9 @@ function buildInlineKernelPath(steps, sessionDefaults, modelId, numLayers) {
877
878
  description: 'Generated from manifest.inference.execution.steps',
878
879
  activationDtype,
879
880
  kvDtype,
881
+ ...(typeof finitenessFallbackKernelPathId === 'string' && finitenessFallbackKernelPathId.length > 0
882
+ ? { finitenessFallbackKernelPathId }
883
+ : {}),
880
884
  decode: {
881
885
  steps: decodeSteps.length > 0 ? decodeSteps : prefillSteps,
882
886
  },
@@ -1107,7 +1111,26 @@ export function compileExecutionV0(options = {}) {
1107
1111
  ...resolvedDecodeSteps.filter((step) => step.phase === 'decode'),
1108
1112
  ];
1109
1113
 
1110
- const kernelPath = buildInlineKernelPath(patchedSteps, resolvedSession, modelId, numLayers);
1114
+ const defaultKernelPathId = typeof manifestInference.defaultKernelPath === 'string'
1115
+ && manifestInference.defaultKernelPath.trim().length > 0
1116
+ ? manifestInference.defaultKernelPath.trim()
1117
+ : null;
1118
+ const finitenessFallbackKernelPathId = defaultKernelPathId
1119
+ ? selectRuleValue(
1120
+ 'inference',
1121
+ 'kernelPath',
1122
+ 'finitenessFallback',
1123
+ { kernelPathId: defaultKernelPathId }
1124
+ )
1125
+ : null;
1126
+
1127
+ const kernelPath = buildInlineKernelPath(
1128
+ patchedSteps,
1129
+ resolvedSession,
1130
+ modelId,
1131
+ numLayers,
1132
+ finitenessFallbackKernelPathId
1133
+ );
1111
1134
  const layerPipeline = buildLayerPipelineFromExecution(resolvedSteps);
1112
1135
  const sessionPatch = buildSessionRuntimePatch(resolvedSession);
1113
1136
  const modelOverrides = buildModelRuntimeOverrides(manifestInference);
@@ -1162,6 +1185,10 @@ export function applyExecutionV0RuntimeConfig(options = {}) {
1162
1185
  }
1163
1186
 
1164
1187
  const runtimeInferencePatch = { ...executionV0State.runtimeInferencePatch };
1188
+ if (runtimeInference.kernelPath !== undefined) {
1189
+ delete runtimeInferencePatch.kernelPath;
1190
+ delete runtimeInferencePatch.kernelPathSource;
1191
+ }
1165
1192
  if (runtimeInferencePatch.modelOverrides) {
1166
1193
  runtimeInferencePatch.modelOverrides = mergeRuntimeValues(
1167
1194
  runtimeInferencePatch.modelOverrides,
@@ -42,6 +42,7 @@ export async function processFFNStandard(
42
42
  hiddenSize,
43
43
  probes: context.debugProbes,
44
44
  recorder,
45
+ dtype: normedTensor.dtype,
45
46
  });
46
47
 
47
48
  // 2. FFN
@@ -58,6 +59,7 @@ export async function processFFNStandard(
58
59
  hiddenSize,
59
60
  probes: context.debugProbes,
60
61
  recorder,
62
+ dtype: ffnOutput.dtype,
61
63
  });
62
64
 
63
65
  // 3. Residual add
@@ -72,6 +74,7 @@ export async function processFFNStandard(
72
74
  hiddenSize,
73
75
  probes: context.debugProbes,
74
76
  recorder,
77
+ dtype: output.dtype,
75
78
  });
76
79
 
77
80
  if (normedTensor !== postAttn) {
@@ -71,9 +71,13 @@ export interface PipelineContexts {
71
71
  */
72
72
  export interface RoPEConfig {
73
73
  headDim: number;
74
+ rotaryDim?: number;
74
75
  maxSeqLen: number;
75
76
  ropeTheta: number;
76
77
  ropeLocalTheta?: number | null;
78
+ mropeInterleaved?: boolean;
79
+ mropeSection?: number[] | null;
80
+ partialRotaryFactor?: number | null;
77
81
  ropeScale: number;
78
82
  ropeLocalScale?: number;
79
83
  ropeScalingType?: string | null;
@@ -206,13 +206,45 @@ function isSameRoPEScalingConfig(
206
206
  === (rightScaling?.original_max_position_embeddings ?? null);
207
207
  }
208
208
 
209
+ function resolveRotaryDim(headDim, rotaryDim, partialRotaryFactor) {
210
+ if (rotaryDim != null) {
211
+ if (!Number.isFinite(rotaryDim) || rotaryDim <= 0 || (rotaryDim % 2) !== 0) {
212
+ throw new Error(`RoPE rotary dim must be a positive even integer; got "${rotaryDim}".`);
213
+ }
214
+ if (rotaryDim > headDim) {
215
+ throw new Error(`RoPE rotary dim ${rotaryDim} cannot exceed headDim ${headDim}.`);
216
+ }
217
+ return rotaryDim;
218
+ }
219
+ if (partialRotaryFactor == null) {
220
+ return headDim;
221
+ }
222
+ if (!Number.isFinite(partialRotaryFactor) || partialRotaryFactor <= 0 || partialRotaryFactor > 1) {
223
+ throw new Error(
224
+ `RoPE partialRotaryFactor must be a number in (0, 1]; got "${partialRotaryFactor}".`
225
+ );
226
+ }
227
+ const resolved = Math.trunc(headDim * partialRotaryFactor);
228
+ if (resolved <= 0 || (resolved % 2) !== 0) {
229
+ throw new Error(
230
+ `RoPE partialRotaryFactor=${partialRotaryFactor} with headDim=${headDim} resolves ` +
231
+ `to rotaryDim=${resolved}, but rotaryDim must be a positive even integer.`
232
+ );
233
+ }
234
+ return resolved;
235
+ }
236
+
209
237
 
210
238
  export async function initRoPEFrequencies(config, useGPU) {
211
239
  const {
212
240
  headDim,
241
+ rotaryDim,
213
242
  maxSeqLen,
214
243
  ropeTheta,
215
244
  ropeLocalTheta,
245
+ mropeInterleaved,
246
+ mropeSection,
247
+ partialRotaryFactor,
216
248
  ropeScale,
217
249
  ropeLocalScale,
218
250
  ropeScalingType,
@@ -230,14 +262,23 @@ export async function initRoPEFrequencies(config, useGPU) {
230
262
  const resolvedLocalTheta = ropeLocalTheta ?? ropeTheta;
231
263
  const resolvedLocalScalingType = ropeLocalScalingType ?? ropeScalingType;
232
264
  const resolvedLocalScaling = ropeLocalScaling ?? ropeScaling;
265
+ const resolvedRotaryDim = resolveRotaryDim(headDim, rotaryDim, partialRotaryFactor);
266
+ const halfDim = resolvedRotaryDim / 2;
267
+ if (mropeInterleaved === true && Array.isArray(mropeSection)) {
268
+ const expandedDim = mropeSection.reduce((sum, entry) => sum + entry, 0) * 2;
269
+ if (expandedDim !== resolvedRotaryDim) {
270
+ throw new Error(
271
+ `RoPE mropeSection expands to ${expandedDim} dims, but rotaryDim is ${resolvedRotaryDim}.`
272
+ );
273
+ }
274
+ }
233
275
 
234
- const halfDim = headDim / 2;
235
276
  const isYarn = ropeScalingType === 'yarn';
236
277
  const isLocalYarn = resolvedLocalScalingType === 'yarn';
237
278
 
238
279
  // Compute global (full_attention) frequencies
239
280
  const globalFreqs = computeRoPEFreqsForTheta(
240
- ropeTheta, headDim, maxSeqLen, ropeScale, ropeScalingType, ropeScaling
281
+ ropeTheta, resolvedRotaryDim, maxSeqLen, ropeScale, ropeScalingType, ropeScaling
241
282
  );
242
283
 
243
284
  // Compute local (sliding_attention) frequencies if different from global.
@@ -256,7 +297,7 @@ export async function initRoPEFrequencies(config, useGPU) {
256
297
  if (hasDistinctLocalTheta || hasDistinctLocalScaling) {
257
298
  localFreqs = computeRoPEFreqsForTheta(
258
299
  resolvedLocalTheta,
259
- headDim,
300
+ resolvedRotaryDim,
260
301
  maxSeqLen,
261
302
  resolvedLocalScale,
262
303
  resolvedLocalScalingType,
@@ -303,9 +344,10 @@ export async function initRoPEFrequencies(config, useGPU) {
303
344
 
304
345
  log.debug(
305
346
  'Pipeline',
306
- `RoPE frequencies initialized (GPU): ${maxSeqLen} positions, dim=${halfDim}, headDim=${headDim}, ` +
347
+ `RoPE frequencies initialized (GPU): ${maxSeqLen} positions, dim=${halfDim}, headDim=${headDim}, rotaryDim=${resolvedRotaryDim}, ` +
307
348
  `theta=${ropeTheta}${hasDistinctLocalTheta ? `, localTheta=${resolvedLocalTheta}` : ''}, ` +
308
- `scaling=${ropeScalingType ?? 'none'}:${ropeScale}${hasDistinctLocalScaling ? `, localScaling=${resolvedLocalScalingType ?? 'none'}:${resolvedLocalScale}` : ''}`
349
+ `scaling=${ropeScalingType ?? 'none'}:${ropeScale}${hasDistinctLocalScaling ? `, localScaling=${resolvedLocalScalingType ?? 'none'}:${resolvedLocalScale}` : ''}, ` +
350
+ `interleaved=${mropeInterleaved === true}`
309
351
  );
310
352
 
311
353
  return {
@@ -318,9 +360,10 @@ export async function initRoPEFrequencies(config, useGPU) {
318
360
 
319
361
  log.debug(
320
362
  'Pipeline',
321
- `RoPE frequencies initialized (CPU): ${maxSeqLen} positions, dim=${halfDim}, headDim=${headDim}, ` +
363
+ `RoPE frequencies initialized (CPU): ${maxSeqLen} positions, dim=${halfDim}, headDim=${headDim}, rotaryDim=${resolvedRotaryDim}, ` +
322
364
  `theta=${ropeTheta}${hasDistinctLocalTheta ? `, localTheta=${resolvedLocalTheta}` : ''}, ` +
323
- `scaling=${ropeScalingType ?? 'none'}:${ropeScale}${hasDistinctLocalScaling ? `, localScaling=${resolvedLocalScalingType ?? 'none'}:${resolvedLocalScale}` : ''}`
365
+ `scaling=${ropeScalingType ?? 'none'}:${ropeScale}${hasDistinctLocalScaling ? `, localScaling=${resolvedLocalScalingType ?? 'none'}:${resolvedLocalScale}` : ''}, ` +
366
+ `interleaved=${mropeInterleaved === true}`
324
367
  );
325
368
 
326
369
  return {
@@ -688,6 +731,10 @@ function applyChatMLTemplate(prompt) {
688
731
  return `<|im_start|>user\n${prompt}<|im_end|>\n<|im_start|>assistant\n`;
689
732
  }
690
733
 
734
+ function applyQwenTemplate(prompt) {
735
+ return `<|im_start|>user\n${prompt}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n`;
736
+ }
737
+
691
738
  function applyTranslateGemmaTemplate() {
692
739
  throw new Error(
693
740
  'TranslateGemma template requires structured messages. ' +
@@ -702,7 +749,7 @@ const PROMPT_TEMPLATES = {
702
749
  'llama3': applyHeaderBasedTemplate,
703
750
  'gpt-oss': applyChannelBasedTemplate,
704
751
  'chatml': applyChatMLTemplate,
705
- 'qwen': applyChatMLTemplate, // Qwen uses ChatML format
752
+ 'qwen': applyQwenTemplate,
706
753
  'translategemma': applyTranslateGemmaTemplate,
707
754
  };
708
755
 
@@ -721,7 +768,7 @@ export function applyChatTemplate(prompt, templateType) {
721
768
  export const applyGemmaChatTemplate = applyTurnBasedTemplate;
722
769
  export const applyLlama3ChatTemplate = applyHeaderBasedTemplate;
723
770
  export const applyGptOssChatTemplate = applyChannelBasedTemplate;
724
- export const applyQwenChatTemplate = applyChatMLTemplate;
771
+ export const applyQwenChatTemplate = applyQwenTemplate;
725
772
 
726
773
 
727
774
  export function isStopToken(token, stopTokenIds, eosTokenId) {
@@ -259,6 +259,8 @@ export async function processLayerGPU(layerIdx, inputBuffer, numTokens, isPrefil
259
259
  attentionOutputGate: config.attentionOutputGate,
260
260
  causalAttention: config.causalAttention,
261
261
  rmsNormWeightOffset: config.rmsNormWeightOffset,
262
+ ropeRotaryDim: config.ropeRotaryDim,
263
+ ropeInterleaved: config.ropeInterleaved,
262
264
  tokenIds: context.currentTokenIds ?? null,
263
265
  kernelPath: context.kernelPath ?? null,
264
266
  disableRoPE,
@@ -661,6 +663,8 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
661
663
  attentionOutputGate: config.attentionOutputGate,
662
664
  causalAttention: config.causalAttention,
663
665
  rmsNormWeightOffset: config.rmsNormWeightOffset,
666
+ ropeRotaryDim: config.ropeRotaryDim,
667
+ ropeInterleaved: config.ropeInterleaved,
664
668
  tokenIds: context.currentTokenIds ?? null,
665
669
  skipInputNorm: step.skipInputNorm === true,
666
670
  activationDtype,
@@ -690,6 +694,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
690
694
  hiddenSize,
691
695
  probes: context.debugProbes,
692
696
  recorder,
697
+ dtype: outputDtype,
693
698
  });
694
699
  }
695
700
  break;
@@ -733,6 +738,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
733
738
  hiddenSize,
734
739
  probes: context.debugProbes,
735
740
  recorder,
741
+ dtype: outputDtype,
736
742
  });
737
743
  }
738
744
  break;
@@ -767,6 +773,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
767
773
  hiddenSize,
768
774
  probes: context.debugProbes,
769
775
  recorder,
776
+ dtype: outputDtype,
770
777
  });
771
778
  }
772
779
  break;
@@ -801,6 +808,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
801
808
  hiddenSize,
802
809
  probes: context.debugProbes,
803
810
  recorder,
811
+ dtype: outputDtype,
804
812
  });
805
813
  }
806
814
  break;
@@ -825,6 +833,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
825
833
  hiddenSize,
826
834
  probes: context.debugProbes,
827
835
  recorder,
836
+ dtype: outputDtype,
828
837
  });
829
838
  }
830
839
  break;
@@ -851,6 +860,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
851
860
  hiddenSize,
852
861
  probes: context.debugProbes,
853
862
  recorder,
863
+ dtype: toDtype,
854
864
  });
855
865
  }
856
866
  break;
@@ -880,6 +890,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
880
890
  hiddenSize,
881
891
  probes: context.debugProbes,
882
892
  recorder,
893
+ dtype: getSlotDtype('state') ?? activationDtype,
883
894
  });
884
895
 
885
896
  const computeConfig = context.runtimeComputeConfig ?? null;
@@ -299,9 +299,13 @@ export class InferencePipeline extends PipelineState {
299
299
  const maxSeqLen = config.maxSeqLen;
300
300
  const ropeBuffers = await initRoPEFrequencies({
301
301
  headDim: config.headDim,
302
+ rotaryDim: config.ropeRotaryDim,
302
303
  maxSeqLen,
303
304
  ropeTheta: config.ropeTheta,
304
305
  ropeLocalTheta: config.ropeLocalTheta,
306
+ mropeInterleaved: config.ropeInterleaved,
307
+ mropeSection: config.mropeSection,
308
+ partialRotaryFactor: config.partialRotaryFactor,
305
309
  ropeScale: config.ropeScale,
306
310
  ropeLocalScale: config.ropeLocalScale,
307
311
  ropeScalingType: config.ropeScalingType,
@@ -64,6 +64,68 @@ function resolveSpecialTokens(specialTokensRaw, fallbackTokens, vocab) {
64
64
  return resolved;
65
65
  }
66
66
 
67
+ function resolveByteLevelPretokenizerConfig(preTokenizer) {
68
+ if (!preTokenizer || typeof preTokenizer !== 'object') {
69
+ return {
70
+ useByteLevel: false,
71
+ addPrefixSpace: null,
72
+ };
73
+ }
74
+
75
+ if (preTokenizer.type === 'ByteLevel') {
76
+ return {
77
+ useByteLevel: true,
78
+ addPrefixSpace: preTokenizer.add_prefix_space === true,
79
+ };
80
+ }
81
+
82
+ if (preTokenizer.type === 'Sequence' && Array.isArray(preTokenizer.pretokenizers)) {
83
+ for (const entry of preTokenizer.pretokenizers) {
84
+ const resolved = resolveByteLevelPretokenizerConfig(entry);
85
+ if (resolved.useByteLevel) {
86
+ return resolved;
87
+ }
88
+ }
89
+ }
90
+
91
+ return {
92
+ useByteLevel: false,
93
+ addPrefixSpace: null,
94
+ };
95
+ }
96
+
97
+ function registerAddedTokens(addedTokens, vocab, reverseVocab, patterns, specialTokenIds, derivedSpecialTokens = null) {
98
+ let maxId = -1;
99
+ for (const token of addedTokens) {
100
+ const content = token?.content;
101
+ const id = typeof token?.id === 'number' ? token.id : parseInt(token?.id, 10);
102
+ if (!Number.isFinite(id) || !content) continue;
103
+ if (!vocab.has(content)) {
104
+ vocab.set(content, id);
105
+ reverseVocab.set(id, content);
106
+ }
107
+ if (id > maxId) maxId = id;
108
+ if (content.length > 1) {
109
+ patterns.push({ content, id });
110
+ }
111
+ if (token.special) {
112
+ specialTokenIds.add(id);
113
+ if (derivedSpecialTokens) {
114
+ if (derivedSpecialTokens.bos == null && (content === '<bos>' || content === '<s>' || content.includes('bos'))) {
115
+ derivedSpecialTokens.bos = id;
116
+ } else if (derivedSpecialTokens.eos == null && (content === '<eos>' || content === '</s>' || content.includes('eos'))) {
117
+ derivedSpecialTokens.eos = id;
118
+ } else if (derivedSpecialTokens.pad == null && (content === '<pad>' || content.includes('pad'))) {
119
+ derivedSpecialTokens.pad = id;
120
+ } else if (derivedSpecialTokens.unk == null && (content === '<unk>' || content.includes('unk'))) {
121
+ derivedSpecialTokens.unk = id;
122
+ }
123
+ }
124
+ }
125
+ }
126
+ return maxId;
127
+ }
128
+
67
129
 
68
130
  export class TransformersTokenizer extends BaseTokenizer {
69
131
 
@@ -156,6 +218,10 @@ export class BundledTokenizer extends BaseTokenizer {
156
218
 
157
219
  #byteDecoder = null;
158
220
 
221
+ #byteEncoder = null;
222
+
223
+ #useByteLevelEncoding = false;
224
+
159
225
 
160
226
  constructor(config = {}) {
161
227
  // BundledTokenizer gets vocabSize from load(), so defer validation
@@ -199,9 +265,20 @@ export class BundledTokenizer extends BaseTokenizer {
199
265
  }
200
266
 
201
267
  this.#byteDecoder = new Map();
268
+ this.#byteEncoder = new Map();
202
269
  for (let i = 0; i < base.length; i++) {
203
270
  this.#byteDecoder.set(String.fromCodePoint(chars[i]), base[i]);
271
+ this.#byteEncoder.set(base[i], String.fromCodePoint(chars[i]));
272
+ }
273
+ }
274
+
275
+ #encodeByteLevelText(text) {
276
+ const bytes = new TextEncoder().encode(text);
277
+ let out = '';
278
+ for (const byte of bytes) {
279
+ out += this.#byteEncoder?.get(byte) ?? String.fromCharCode(byte);
204
280
  }
281
+ return out;
205
282
  }
206
283
 
207
284
 
@@ -290,30 +367,16 @@ export class BundledTokenizer extends BaseTokenizer {
290
367
  eos: null,
291
368
  unk: null,
292
369
  };
293
- for (const token of addedTokens) {
294
- const content = token.content;
295
- const id = typeof token.id === 'number' ? token.id : parseInt( (token.id), 10);
296
- if (!Number.isFinite(id) || !content) continue;
297
- if (!this.#vocab.has(content)) {
298
- this.#vocab.set(content, id);
299
- this.#reverseVocab.set(id, content);
300
- }
301
- if (id > maxId) maxId = id;
302
- if (token.special) {
303
- specialTokenIds.add(id);
304
- if (content.length > 1) {
305
- specialTokenPatterns.push({ content, id });
306
- }
307
- if (derivedSpecialTokens.bos == null && (content === '<bos>' || content === '<s>' || content.includes('bos'))) {
308
- derivedSpecialTokens.bos = id;
309
- } else if (derivedSpecialTokens.eos == null && (content === '<eos>' || content === '</s>' || content.includes('eos'))) {
310
- derivedSpecialTokens.eos = id;
311
- } else if (derivedSpecialTokens.pad == null && (content === '<pad>' || content.includes('pad'))) {
312
- derivedSpecialTokens.pad = id;
313
- } else if (derivedSpecialTokens.unk == null && (content === '<unk>' || content.includes('unk'))) {
314
- derivedSpecialTokens.unk = id;
315
- }
316
- }
370
+ const addedMaxId = registerAddedTokens(
371
+ addedTokens,
372
+ this.#vocab,
373
+ this.#reverseVocab,
374
+ specialTokenPatterns,
375
+ specialTokenIds,
376
+ derivedSpecialTokens
377
+ );
378
+ if (addedMaxId > maxId) {
379
+ maxId = addedMaxId;
317
380
  }
318
381
 
319
382
  const specialTokensRaw = hf.special_tokens_map || hf.specialTokens || hf.special_tokens || null;
@@ -351,6 +414,7 @@ export class BundledTokenizer extends BaseTokenizer {
351
414
 
352
415
  // Handle behavior flags (use HF config if present, else runtime defaults)
353
416
  const runtimeDefaults = getRuntimeConfig().inference.tokenizer;
417
+ const byteLevelPretokenizer = resolveByteLevelPretokenizerConfig(hf.pre_tokenizer);
354
418
  const configuredAddBosToken = this.addBosToken;
355
419
  const configuredAddEosToken = this.addEosToken;
356
420
  this.addBosToken =
@@ -378,9 +442,16 @@ export class BundledTokenizer extends BaseTokenizer {
378
442
  // - runtime config addSpacePrefix (user override or null for auto-detect)
379
443
  const decoderPrepend = hf.decoder?.prepend_scheme === 'always' || hf.decoder?.add_prefix_space === true;
380
444
  const normalizerPrepend = hf.normalizer?.prepend_scheme === 'always' || hf.normalizer?.add_prefix_space === true;
445
+ this.#useByteLevelEncoding = byteLevelPretokenizer.useByteLevel;
381
446
  const runtimeSpacePrefix = runtimeDefaults.addSpacePrefix;
382
447
  // Use explicit runtime config if set (non-null), otherwise auto-detect from tokenizer.json
383
- this.#addSpacePrefix = runtimeSpacePrefix ?? model.add_prefix_space ?? model.add_dummy_prefix ?? decoderPrepend ?? normalizerPrepend ?? false;
448
+ this.#addSpacePrefix = runtimeSpacePrefix
449
+ ?? byteLevelPretokenizer.addPrefixSpace
450
+ ?? model.add_prefix_space
451
+ ?? model.add_dummy_prefix
452
+ ?? decoderPrepend
453
+ ?? normalizerPrepend
454
+ ?? false;
384
455
  log.debug('Tokenizer', `addSpacePrefix=${this.#addSpacePrefix} (runtime=${runtimeSpacePrefix}, model=${model.add_prefix_space ?? model.add_dummy_prefix}, decoder=${decoderPrepend}, normalizer=${normalizerPrepend})`);
385
456
 
386
457
  // Detect space prefix style by checking which WORD tokens exist in vocab
@@ -469,11 +540,47 @@ export class BundledTokenizer extends BaseTokenizer {
469
540
  this.#tokenTypes = tokenizerJson.tokenTypes;
470
541
  }
471
542
 
543
+ let maxId = -1;
544
+ for (const id of this.#vocab.values()) {
545
+ if (Number.isFinite(id) && id > maxId) {
546
+ maxId = id;
547
+ }
548
+ }
549
+
550
+ const addedTokens = Array.isArray(tokenizerJson.added_tokens) ? tokenizerJson.added_tokens : [];
551
+ const tokenPatterns = [];
552
+ const specialTokenIds = new Set();
553
+ const derivedSpecialTokens = {
554
+ pad: null,
555
+ bos: null,
556
+ eos: null,
557
+ unk: null,
558
+ };
559
+ const addedMaxId = registerAddedTokens(
560
+ addedTokens,
561
+ this.#vocab,
562
+ this.#reverseVocab,
563
+ tokenPatterns,
564
+ specialTokenIds,
565
+ derivedSpecialTokens
566
+ );
567
+ if (addedMaxId > maxId) {
568
+ maxId = addedMaxId;
569
+ }
570
+
472
571
  // Set special tokens - support both camelCase and snake_case formats
473
572
  const specialTokensRaw = (tokenizerJson.specialTokens || (tokenizerJson).special_tokens);
474
- this.specialTokens = resolveSpecialTokens(specialTokensRaw, this.specialTokens, this.#vocab);
573
+ this.specialTokens = resolveSpecialTokens(
574
+ specialTokensRaw,
575
+ {
576
+ ...derivedSpecialTokens,
577
+ ...this.specialTokens,
578
+ },
579
+ this.#vocab
580
+ );
475
581
  log.debug('Tokenizer', `Special tokens: BOS=${this.specialTokens.bos}, EOS=${this.specialTokens.eos}`);
476
- this.#specialTokenIds = new Set();
582
+ this.#specialTokenIds = specialTokenIds;
583
+ this.#specialTokenPatterns = tokenPatterns;
477
584
  const builtinSpecials = [
478
585
  this.specialTokens.pad,
479
586
  this.specialTokens.bos,
@@ -485,8 +592,13 @@ export class BundledTokenizer extends BaseTokenizer {
485
592
  this.#specialTokenIds.add(id);
486
593
  }
487
594
  }
595
+ this.#specialTokenPatterns.sort((a, b) => b.content.length - a.content.length);
596
+ if (maxId >= 0) {
597
+ this.vocabSize = Math.max(this.vocabSize, maxId + 1);
598
+ }
488
599
 
489
600
  const runtimeDefaults = getRuntimeConfig().inference.tokenizer;
601
+ const byteLevelPretokenizer = resolveByteLevelPretokenizerConfig(tokenizerJson.pre_tokenizer);
490
602
  const configuredAddBosToken = this.addBosToken;
491
603
  const configuredAddEosToken = this.addEosToken;
492
604
  this.addBosToken =
@@ -505,9 +617,11 @@ export class BundledTokenizer extends BaseTokenizer {
505
617
  if (this.addEosToken && this.specialTokens.eos == null) {
506
618
  throw new Error('[Tokenizer] addEosToken is enabled but eos token is missing.');
507
619
  }
620
+ this.#useByteLevelEncoding = byteLevelPretokenizer.useByteLevel;
508
621
  // NOTE: Default to FALSE - first word shouldn't get space prefix
509
622
  // Space prefixes are only for words that follow a space in original text
510
- this.#addSpacePrefix = tokenizerJson.addSpacePrefix === true;
623
+ this.#addSpacePrefix = tokenizerJson.addSpacePrefix === true
624
+ || byteLevelPretokenizer.addPrefixSpace === true;
511
625
 
512
626
  // Detect space prefix style based on vocab tokens
513
627
  // GPT-style uses 'Ġ' (U+0120), SentencePiece uses '▁' (U+2581)
@@ -548,7 +662,8 @@ export class BundledTokenizer extends BaseTokenizer {
548
662
  ids.push(this.specialTokens.bos);
549
663
  }
550
664
 
551
- // Split text around special tokens and tokenize each segment
665
+ // Split text around literal added tokens and special tokens, then tokenize
666
+ // the remaining plain-text segments normally.
552
667
  const segments = this.#splitOnSpecialTokens(text);
553
668
  for (const seg of segments) {
554
669
  if (seg.isSpecial && seg.id !== undefined) {
@@ -690,11 +805,19 @@ export class BundledTokenizer extends BaseTokenizer {
690
805
  if (text.length === 0) return [];
691
806
 
692
807
  let normalized = text;
693
- if (this.#addSpacePrefix && !normalized.startsWith(' ')) {
694
- normalized = ` ${normalized}`;
808
+ let prefixed;
809
+ if (this.#useByteLevelEncoding) {
810
+ if (this.#addSpacePrefix && !normalized.startsWith(' ')) {
811
+ normalized = ` ${normalized}`;
812
+ }
813
+ prefixed = this.#encodeByteLevelText(normalized);
814
+ } else {
815
+ if (this.#addSpacePrefix && !normalized.startsWith(' ')) {
816
+ normalized = ` ${normalized}`;
817
+ }
818
+ const sp = this.#spacePrefixChar;
819
+ prefixed = normalized.replace(/ /g, sp);
695
820
  }
696
- const sp = this.#spacePrefixChar;
697
- const prefixed = normalized.replace(/ /g, sp);
698
821
 
699
822
  if (this.#mergeRanks.size === 0) {
700
823
  return this.#encodeBPEGreedy(prefixed);