@simulatte/doppler 0.1.4 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +26 -10
- package/package.json +30 -6
- package/src/client/doppler-api.browser.d.ts +1 -0
- package/src/client/doppler-api.browser.js +288 -0
- package/src/client/doppler-api.js +1 -1
- package/src/client/doppler-provider/types.js +1 -1
- package/src/config/execution-contract-check.d.ts +33 -0
- package/src/config/execution-contract-check.js +72 -0
- package/src/config/execution-v0-contract-check.d.ts +94 -0
- package/src/config/execution-v0-contract-check.js +251 -0
- package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
- package/src/config/execution-v0-graph-contract-check.js +64 -0
- package/src/config/kernel-path-contract-check.d.ts +76 -0
- package/src/config/kernel-path-contract-check.js +479 -0
- package/src/config/kernel-path-loader.d.ts +16 -0
- package/src/config/kernel-path-loader.js +54 -0
- package/src/config/kernels/kernel-ref-digests.js +39 -27
- package/src/config/kernels/registry.json +598 -2
- package/src/config/loader.js +81 -48
- package/src/config/merge-contract-check.d.ts +16 -0
- package/src/config/merge-contract-check.js +321 -0
- package/src/config/merge-helpers.d.ts +58 -0
- package/src/config/merge-helpers.js +54 -0
- package/src/config/merge.js +21 -6
- package/src/config/presets/models/janus-text.json +2 -0
- package/src/config/presets/models/qwen3.json +9 -2
- package/src/config/presets/models/transformer.json +5 -0
- package/src/config/quantization-contract-check.d.ts +12 -0
- package/src/config/quantization-contract-check.js +91 -0
- package/src/config/required-inference-fields-contract-check.d.ts +24 -0
- package/src/config/required-inference-fields-contract-check.js +237 -0
- package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
- package/src/config/schema/browser-suite-metrics.schema.js +46 -0
- package/src/config/schema/conversion-report.schema.d.ts +40 -0
- package/src/config/schema/conversion-report.schema.js +108 -0
- package/src/config/schema/doppler.schema.js +12 -18
- package/src/config/schema/index.d.ts +22 -0
- package/src/config/schema/index.js +18 -0
- package/src/config/schema/inference-defaults.schema.js +3 -0
- package/src/config/schema/inference.schema.d.ts +9 -0
- package/src/config/schema/kernel-path.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.js +3 -0
- package/src/converter/core.d.ts +10 -0
- package/src/converter/core.js +27 -2
- package/src/converter/parsers/diffusion.js +63 -3
- package/src/converter/rope-config.js +42 -0
- package/src/gpu/device.js +58 -0
- package/src/gpu/kernels/attention.js +98 -0
- package/src/gpu/kernels/bias_add.wgsl +8 -6
- package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
- package/src/gpu/kernels/conv2d.js +1 -1
- package/src/gpu/kernels/conv2d.wgsl +7 -8
- package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
- package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
- package/src/gpu/kernels/depthwise_conv2d.js +99 -0
- package/src/gpu/kernels/depthwise_conv2d.wgsl +55 -0
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +59 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +93 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +44 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +48 -0
- package/src/gpu/kernels/index.d.ts +30 -0
- package/src/gpu/kernels/index.js +25 -0
- package/src/gpu/kernels/matmul.js +25 -0
- package/src/gpu/kernels/pixel_shuffle.js +1 -1
- package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
- package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
- package/src/gpu/kernels/relu.d.ts +18 -0
- package/src/gpu/kernels/relu.js +58 -0
- package/src/gpu/kernels/relu.wgsl +22 -0
- package/src/gpu/kernels/relu_f16.wgsl +24 -0
- package/src/gpu/kernels/repeat_channels.d.ts +21 -0
- package/src/gpu/kernels/repeat_channels.js +60 -0
- package/src/gpu/kernels/repeat_channels.wgsl +28 -0
- package/src/gpu/kernels/repeat_channels_f16.wgsl +30 -0
- package/src/gpu/kernels/residual.js +44 -8
- package/src/gpu/kernels/residual.wgsl +6 -3
- package/src/gpu/kernels/residual_f16.wgsl +2 -1
- package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
- package/src/gpu/kernels/residual_vec4.wgsl +2 -1
- package/src/gpu/kernels/rmsnorm.js +58 -6
- package/src/gpu/kernels/rmsnorm.wgsl +14 -6
- package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
- package/src/gpu/kernels/rope.d.ts +2 -0
- package/src/gpu/kernels/rope.js +11 -1
- package/src/gpu/kernels/rope.wgsl +56 -40
- package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
- package/src/gpu/kernels/sana_linear_attention.js +121 -0
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +43 -0
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +46 -0
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +51 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +53 -0
- package/src/gpu/kernels/silu.d.ts +1 -0
- package/src/gpu/kernels/silu.js +32 -14
- package/src/gpu/kernels/silu.wgsl +19 -9
- package/src/gpu/kernels/silu_f16.wgsl +19 -9
- package/src/gpu/kernels/transpose.js +15 -2
- package/src/gpu/kernels/transpose.wgsl +5 -6
- package/src/gpu/kernels/upsample2d.js +2 -1
- package/src/gpu/kernels/upsample2d.wgsl +6 -9
- package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
- package/src/gpu/kernels/utils.js +16 -1
- package/src/index-browser.d.ts +1 -1
- package/src/index-browser.js +2 -2
- package/src/index.js +1 -1
- package/src/inference/browser-harness.js +109 -23
- package/src/inference/pipelines/diffusion/init.js +14 -0
- package/src/inference/pipelines/diffusion/pipeline.js +215 -77
- package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
- package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
- package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
- package/src/inference/pipelines/diffusion/scheduler.js +91 -3
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +11 -4
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +282 -0
- package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
- package/src/inference/pipelines/diffusion/types.d.ts +4 -0
- package/src/inference/pipelines/diffusion/vae.js +782 -78
- package/src/inference/pipelines/text/attention/record.js +11 -2
- package/src/inference/pipelines/text/attention/run.js +11 -2
- package/src/inference/pipelines/text/chat-format.js +25 -1
- package/src/inference/pipelines/text/config.d.ts +9 -0
- package/src/inference/pipelines/text/config.js +69 -2
- package/src/inference/pipelines/text/execution-plan.js +23 -31
- package/src/inference/pipelines/text/execution-v0.js +43 -95
- package/src/inference/pipelines/text/ffn/standard.js +3 -0
- package/src/inference/pipelines/text/init.d.ts +4 -0
- package/src/inference/pipelines/text/init.js +56 -9
- package/src/inference/pipelines/text/layer.js +11 -0
- package/src/inference/pipelines/text.js +4 -0
- package/src/inference/tokenizers/bundled.js +156 -33
- package/src/rules/execution-rules-contract-check.d.ts +17 -0
- package/src/rules/execution-rules-contract-check.js +245 -0
- package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/relu.rules.json +6 -0
- package/src/rules/kernels/repeat-channels.rules.json +6 -0
- package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
- package/src/rules/layer-pattern-contract-check.d.ts +17 -0
- package/src/rules/layer-pattern-contract-check.js +231 -0
- package/src/rules/rule-registry.d.ts +28 -0
- package/src/rules/rule-registry.js +38 -0
- package/src/rules/tooling/command-runtime.rules.json +18 -0
- package/src/tooling/command-api.d.ts +27 -1
- package/src/tooling/command-api.js +142 -3
- package/src/tooling/conversion-config-materializer.d.ts +24 -0
- package/src/tooling/conversion-config-materializer.js +99 -0
- package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
- package/src/tooling/lean-execution-contract-runner.js +158 -0
- package/src/tooling/node-browser-command-runner.d.ts +4 -0
- package/src/tooling/node-browser-command-runner.js +58 -3
- package/src/tooling/node-command-runner.js +15 -0
- package/src/tooling/node-convert.d.ts +10 -0
- package/src/tooling/node-converter.js +59 -0
- package/src/tooling/node-webgpu.js +11 -89
- package/src/training/checkpoint-watch.d.ts +7 -0
- package/src/training/checkpoint-watch.js +106 -0
- package/src/training/checkpoint.d.ts +6 -1
- package/src/training/checkpoint.js +12 -2
- package/src/training/distillation/artifacts.d.ts +71 -0
- package/src/training/distillation/artifacts.js +132 -0
- package/src/training/distillation/checkpoint-watch.d.ts +10 -0
- package/src/training/distillation/checkpoint-watch.js +57 -0
- package/src/training/distillation/dataset.d.ts +59 -0
- package/src/training/distillation/dataset.js +337 -0
- package/src/training/distillation/eval.d.ts +34 -0
- package/src/training/distillation/eval.js +310 -0
- package/src/training/distillation/index.d.ts +29 -0
- package/src/training/distillation/index.js +29 -0
- package/src/training/distillation/runtime.d.ts +20 -0
- package/src/training/distillation/runtime.js +121 -0
- package/src/training/distillation/scoreboard.d.ts +6 -0
- package/src/training/distillation/scoreboard.js +8 -0
- package/src/training/distillation/stage-a.d.ts +45 -0
- package/src/training/distillation/stage-a.js +338 -0
- package/src/training/distillation/stage-b.d.ts +24 -0
- package/src/training/distillation/stage-b.js +20 -0
- package/src/training/index.d.ts +10 -0
- package/src/training/index.js +10 -0
- package/src/training/lora-pipeline.d.ts +40 -0
- package/src/training/lora-pipeline.js +796 -0
- package/src/training/operator-artifacts.d.ts +62 -0
- package/src/training/operator-artifacts.js +140 -0
- package/src/training/operator-command.d.ts +5 -0
- package/src/training/operator-command.js +453 -0
- package/src/training/operator-eval.d.ts +48 -0
- package/src/training/operator-eval.js +230 -0
- package/src/training/operator-scoreboard.d.ts +5 -0
- package/src/training/operator-scoreboard.js +44 -0
- package/src/training/runner.d.ts +52 -0
- package/src/training/runner.js +29 -4
- package/src/training/suite.d.ts +112 -0
- package/src/training/suite.js +9 -9
- package/src/training/workloads.d.ts +164 -0
- package/src/training/workloads.js +539 -0
- package/src/version.d.ts +2 -0
- package/src/version.js +2 -0
- package/tools/convert-safetensors-node.js +47 -0
- package/tools/doppler-cli.js +252 -41
|
@@ -182,10 +182,18 @@ export async function recordLayerAttentionGPU(
|
|
|
182
182
|
// 3. RoPE (modifies tensor in-place)
|
|
183
183
|
if (!disableRoPE && state.ropeFreqsCos && state.ropeFreqsSin) {
|
|
184
184
|
await recordRoPE(recorder, qTensor, state.ropeFreqsCos, state.ropeFreqsSin, numTokens, {
|
|
185
|
-
numHeads,
|
|
185
|
+
numHeads,
|
|
186
|
+
headDim,
|
|
187
|
+
rotaryDim: config.ropeRotaryDim,
|
|
188
|
+
interleaved: config.ropeInterleaved,
|
|
189
|
+
startPos: currentSeqLen,
|
|
186
190
|
});
|
|
187
191
|
await recordRoPE(recorder, kTensor, state.ropeFreqsCos, state.ropeFreqsSin, numTokens, {
|
|
188
|
-
numHeads: numKVHeads,
|
|
192
|
+
numHeads: numKVHeads,
|
|
193
|
+
headDim,
|
|
194
|
+
rotaryDim: config.ropeRotaryDim,
|
|
195
|
+
interleaved: config.ropeInterleaved,
|
|
196
|
+
startPos: currentSeqLen,
|
|
189
197
|
});
|
|
190
198
|
}
|
|
191
199
|
|
|
@@ -502,6 +510,7 @@ export async function recordLayerAttentionGPU(
|
|
|
502
510
|
size: numTokens * numHeads * headDim,
|
|
503
511
|
gate: qGateTensor,
|
|
504
512
|
gateActivation: 'sigmoid',
|
|
513
|
+
inputActivation: 'identity',
|
|
505
514
|
swigluLimit: null,
|
|
506
515
|
});
|
|
507
516
|
recorder.trackTemporaryBuffer(attnOutput.buffer);
|
|
@@ -299,10 +299,18 @@ export async function runLayerAttentionGPU(
|
|
|
299
299
|
|
|
300
300
|
if (!disableRoPE && state.ropeFreqsCos && state.ropeFreqsSin) {
|
|
301
301
|
await runRoPE(qTensor, state.ropeFreqsCos, state.ropeFreqsSin, numTokens, {
|
|
302
|
-
numHeads,
|
|
302
|
+
numHeads,
|
|
303
|
+
headDim,
|
|
304
|
+
rotaryDim: config.ropeRotaryDim,
|
|
305
|
+
interleaved: config.ropeInterleaved,
|
|
306
|
+
startPos: currentSeqLen,
|
|
303
307
|
});
|
|
304
308
|
await runRoPE(kTensor, state.ropeFreqsCos, state.ropeFreqsSin, numTokens, {
|
|
305
|
-
numHeads: numKVHeads,
|
|
309
|
+
numHeads: numKVHeads,
|
|
310
|
+
headDim,
|
|
311
|
+
rotaryDim: config.ropeRotaryDim,
|
|
312
|
+
interleaved: config.ropeInterleaved,
|
|
313
|
+
startPos: currentSeqLen,
|
|
306
314
|
});
|
|
307
315
|
|
|
308
316
|
// Trace RoPE outputs
|
|
@@ -690,6 +698,7 @@ export async function runLayerAttentionGPU(
|
|
|
690
698
|
size: numTokens * numHeads * headDim,
|
|
691
699
|
gate: qGateTensor,
|
|
692
700
|
gateActivation: 'sigmoid',
|
|
701
|
+
inputActivation: 'identity',
|
|
693
702
|
swigluLimit: null,
|
|
694
703
|
});
|
|
695
704
|
releaseBuffer(attnOutput.buffer);
|
|
@@ -224,6 +224,29 @@ function formatChatML(messages) {
|
|
|
224
224
|
return parts.join('');
|
|
225
225
|
}
|
|
226
226
|
|
|
227
|
+
/**
 * Render a message list with the Qwen chat template.
 *
 * Qwen 3.5 uses a ChatML-like layout: each turn is
 * `<|im_start|>{role}\n{content}<|im_end|>\n`. The generation prelude that
 * ends the prompt opens an assistant turn that already contains an explicit
 * empty thinking block (`<think>\n\n</think>\n\n`).
 *
 * @param {Array<{role?: unknown, content?: unknown}>} messages - Conversation turns, in order.
 * @returns {string} The fully formatted prompt, ending with the assistant prelude.
 * @throws {Error} If a role is rejected by `assertSupportedChatRole`, or if a
 *   system message appears anywhere other than index 0.
 */
function formatQwen(messages) {
  const turns = [];
  for (const [index, message] of messages.entries()) {
    const role = normalizeChatRole(message?.role);
    assertSupportedChatRole(role, 'Qwen', index);
    // Only a leading system prompt is representable in this template.
    if (index !== 0 && role === 'system') {
      throw new Error('Qwen template requires any system message to appear first.');
    }
    const content = normalizeChatMessageContent(message?.content);
    // These three roles share one turn layout. Any other role that the role
    // assertion lets through is not rendered.
    const rendersTurn = role === 'system' || role === 'user' || role === 'assistant';
    if (rendersTurn) {
      turns.push(`<|im_start|>${role}\n${content}<|im_end|>\n`);
    }
  }
  // Generation prelude: open the assistant turn with an empty thinking block.
  turns.push('<|im_start|>assistant\n<think>\n\n</think>\n\n');
  return turns.join('');
}
|
|
249
|
+
|
|
227
250
|
function formatTranslateGemmaUserPrompt(content) {
|
|
228
251
|
if (!Array.isArray(content) || content.length !== 1) {
|
|
229
252
|
throw new Error(
|
|
@@ -345,7 +368,7 @@ const CHAT_FORMATTERS = {
|
|
|
345
368
|
'llama3': formatHeaderBased,
|
|
346
369
|
'gpt-oss': formatChannelBased,
|
|
347
370
|
'chatml': formatChatML,
|
|
348
|
-
'qwen':
|
|
371
|
+
'qwen': formatQwen,
|
|
349
372
|
'translategemma': formatTranslateGemma,
|
|
350
373
|
};
|
|
351
374
|
|
|
@@ -363,4 +386,5 @@ export function formatChatMessages(messages, templateType) {
|
|
|
363
386
|
export const formatGemmaChat = formatTurnBased;
|
|
364
387
|
export const formatLlama3Chat = formatHeaderBased;
|
|
365
388
|
export const formatGptOssChat = formatChannelBased;
|
|
389
|
+
export const formatQwenChat = formatQwen;
|
|
366
390
|
export const formatTranslateGemmaChat = formatTranslateGemma;
|
|
@@ -148,6 +148,10 @@ export interface ParsedModelConfig {
|
|
|
148
148
|
slidingWindow: number | null;
|
|
149
149
|
ropeTheta: number;
|
|
150
150
|
ropeLocalTheta: number | null;
|
|
151
|
+
ropeRotaryDim: number;
|
|
152
|
+
ropeInterleaved: boolean;
|
|
153
|
+
mropeSection: number[] | null;
|
|
154
|
+
partialRotaryFactor: number | null;
|
|
151
155
|
ropeScale: number;
|
|
152
156
|
ropeLocalScale: number;
|
|
153
157
|
ropeScalingType: string | null;
|
|
@@ -210,6 +214,11 @@ export interface ManifestWithInference {
|
|
|
210
214
|
*/
|
|
211
215
|
export function hasManifestInference(manifest: Manifest): manifest is Manifest & { inference: ManifestInferenceSchema };
|
|
212
216
|
|
|
217
|
+
export function validateRequiredInferenceFields(
|
|
218
|
+
inf: ManifestInferenceSchema,
|
|
219
|
+
modelId: string
|
|
220
|
+
): void;
|
|
221
|
+
|
|
213
222
|
/**
|
|
214
223
|
* Convert MergedConfig to ParsedModelConfig.
|
|
215
224
|
*/
|
|
@@ -21,6 +21,28 @@ function assertSupportedRuntimeModelType(manifest) {
|
|
|
21
21
|
);
|
|
22
22
|
}
|
|
23
23
|
|
|
24
|
+
/**
 * Resolve how many leading dimensions of each attention head RoPE rotates.
 *
 * A null/undefined `partialRotaryFactor` means full rotation, and `headDim`
 * is returned unchanged without further validation. Otherwise the rotary dim
 * is `Math.trunc(headDim * partialRotaryFactor)`. That result must be a
 * positive even integer because RoPE rotates dimensions in pairs.
 *
 * @param {number} headDim - Per-head dimension.
 * @param {number|null|undefined} partialRotaryFactor - Fraction of the head to rotate, in (0, 1].
 * @param {string} modelId - Model identifier used in error messages.
 * @returns {number} The rotary dimension.
 * @throws {Error} If the factor is not a non-NaN number, lies outside (0, 1],
 *   or yields a non-positive or odd rotary dim.
 */
function resolveRotaryDim(headDim, partialRotaryFactor, modelId) {
  const factor = partialRotaryFactor;
  if (factor === null || factor === undefined) {
    // No partial rotary configured: rotate the whole head.
    return headDim;
  }

  const isNumeric = typeof factor === 'number' && !Number.isNaN(factor);
  if (!isNumeric) {
    throw new Error(`Manifest "${modelId}" has invalid rope.partialRotaryFactor.`);
  }

  const isInRange = factor > 0 && factor <= 1;
  if (!isInRange) {
    throw new Error(
      `Manifest "${modelId}" requires 0 < rope.partialRotaryFactor <= 1; got ${factor}.`
    );
  }

  // Truncate toward zero to match the reference int(headDim * factor) semantics.
  const rotaryDim = Math.trunc(headDim * factor);
  const isPositiveEven = rotaryDim > 0 && rotaryDim % 2 === 0;
  if (!isPositiveEven) {
    const detail = `from headDim=${headDim} and partialRotaryFactor=${factor}`;
    throw new Error(
      `Manifest "${modelId}" resolves rope rotary dim ${rotaryDim} ${detail}, ` +
      'but rotary dim must be a positive even integer.'
    );
  }
  return rotaryDim;
}
|
|
45
|
+
|
|
24
46
|
export function getStopTokenIds(manifest) {
|
|
25
47
|
const eosTokenId = manifest?.eos_token_id;
|
|
26
48
|
if (Array.isArray(eosTokenId)) return eosTokenId;
|
|
@@ -129,8 +151,15 @@ export function hasManifestInference(manifest) {
|
|
|
129
151
|
}
|
|
130
152
|
|
|
131
153
|
|
|
132
|
-
function validateRequiredInferenceFields(inf, modelId) {
|
|
133
|
-
|
|
154
|
+
export function validateRequiredInferenceFields(inf, modelId) {
|
|
155
|
+
inf = inf ?? {};
|
|
156
|
+
inf.attention = inf.attention ?? {};
|
|
157
|
+
inf.normalization = inf.normalization ?? {};
|
|
158
|
+
inf.ffn = inf.ffn ?? {};
|
|
159
|
+
inf.rope = inf.rope ?? {};
|
|
160
|
+
inf.output = inf.output ?? {};
|
|
161
|
+
inf.layerPattern = inf.layerPattern ?? {};
|
|
162
|
+
inf.chatTemplate = inf.chatTemplate ?? {};
|
|
134
163
|
const errors = [];
|
|
135
164
|
|
|
136
165
|
// Attention fields - non-nullable required
|
|
@@ -201,6 +230,20 @@ function validateRequiredInferenceFields(inf, modelId) {
|
|
|
201
230
|
if (inf.rope.ropeLocalTheta === undefined) {
|
|
202
231
|
errors.push('rope.ropeLocalTheta must be explicitly set (null for no local theta, or number)');
|
|
203
232
|
}
|
|
233
|
+
if (inf.rope.mropeInterleaved == null) {
|
|
234
|
+
errors.push('rope.mropeInterleaved is required');
|
|
235
|
+
}
|
|
236
|
+
if (inf.rope.mropeSection === undefined) {
|
|
237
|
+
errors.push('rope.mropeSection must be explicitly set (null when unused, or an array of positive integers)');
|
|
238
|
+
}
|
|
239
|
+
if (inf.rope.partialRotaryFactor === undefined) {
|
|
240
|
+
errors.push('rope.partialRotaryFactor must be explicitly set (null when unused, or a number in (0, 1])');
|
|
241
|
+
} else {
|
|
242
|
+
const factor = inf.rope.partialRotaryFactor;
|
|
243
|
+
if (factor !== null && (typeof factor !== 'number' || Number.isNaN(factor) || factor <= 0 || factor > 1)) {
|
|
244
|
+
errors.push('rope.partialRotaryFactor must be a number in (0, 1] or null');
|
|
245
|
+
}
|
|
246
|
+
}
|
|
204
247
|
|
|
205
248
|
// Output fields - non-nullable required
|
|
206
249
|
if (inf.output.tieWordEmbeddings == null) {
|
|
@@ -458,6 +501,26 @@ export function toParsedConfigFromMerged(merged, manifest) {
|
|
|
458
501
|
const ropeScalingType = inf.rope.ropeScalingType;
|
|
459
502
|
const ropeLocalScale = inf.rope.ropeLocalScalingFactor ?? ropeScale;
|
|
460
503
|
const ropeLocalScalingType = inf.rope.ropeLocalScalingType ?? ropeScalingType;
|
|
504
|
+
const partialRotaryFactor = inf.rope.partialRotaryFactor;
|
|
505
|
+
const ropeInterleaved = inf.rope.mropeInterleaved === true;
|
|
506
|
+
const mropeSection = Array.isArray(inf.rope.mropeSection)
|
|
507
|
+
? inf.rope.mropeSection.map((entry) => Math.trunc(Number(entry)))
|
|
508
|
+
: null;
|
|
509
|
+
const ropeRotaryDim = resolveRotaryDim(arch.headDim, partialRotaryFactor, merged.modelId);
|
|
510
|
+
if (mropeSection && mropeSection.some((entry) => !Number.isFinite(entry) || entry <= 0)) {
|
|
511
|
+
throw new Error(
|
|
512
|
+
`Manifest "${merged.modelId}" has invalid rope.mropeSection; expected positive integers.`
|
|
513
|
+
);
|
|
514
|
+
}
|
|
515
|
+
if (ropeInterleaved && mropeSection) {
|
|
516
|
+
const doubledMropeDim = mropeSection.reduce((sum, entry) => sum + entry, 0) * 2;
|
|
517
|
+
if (doubledMropeDim !== ropeRotaryDim) {
|
|
518
|
+
throw new Error(
|
|
519
|
+
`Manifest "${merged.modelId}" declares rope.mropeSection=${JSON.stringify(mropeSection)}, ` +
|
|
520
|
+
`which expands to rotary dim ${doubledMropeDim}, but the resolved rotary dim is ${ropeRotaryDim}.`
|
|
521
|
+
);
|
|
522
|
+
}
|
|
523
|
+
}
|
|
461
524
|
|
|
462
525
|
// Build ropeScaling object from manifest values if scaling is enabled
|
|
463
526
|
// Include YARN params when present
|
|
@@ -532,6 +595,10 @@ export function toParsedConfigFromMerged(merged, manifest) {
|
|
|
532
595
|
slidingWindow: inf.attention.slidingWindow,
|
|
533
596
|
ropeTheta: inf.rope.ropeTheta,
|
|
534
597
|
ropeLocalTheta: inf.rope.ropeLocalTheta,
|
|
598
|
+
ropeRotaryDim,
|
|
599
|
+
ropeInterleaved,
|
|
600
|
+
mropeSection,
|
|
601
|
+
partialRotaryFactor,
|
|
535
602
|
ropeScale,
|
|
536
603
|
ropeLocalScale,
|
|
537
604
|
ropeScalingType,
|
|
@@ -42,56 +42,48 @@ function resolveFallbackActivationDtype(primaryActivationDtype) {
|
|
|
42
42
|
function resolveFallbackKernelPath(primaryKernelPath) {
|
|
43
43
|
const primaryKernelPathId = primaryKernelPath?.id ?? null;
|
|
44
44
|
if (!primaryKernelPathId) {
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
};
|
|
45
|
+
throw new Error(
|
|
46
|
+
'[ExecutionPlan] F16 finiteness fallback requires a primary kernel path with a stable id. ' +
|
|
47
|
+
'Add a registered kernelPath id and a finiteness fallback rule.'
|
|
48
|
+
);
|
|
50
49
|
}
|
|
51
50
|
|
|
52
|
-
const
|
|
51
|
+
const explicitFallbackKernelPathId = typeof primaryKernelPath?.finitenessFallbackKernelPathId === 'string'
|
|
52
|
+
&& primaryKernelPath.finitenessFallbackKernelPathId.length > 0
|
|
53
|
+
? primaryKernelPath.finitenessFallbackKernelPathId
|
|
54
|
+
: null;
|
|
53
55
|
|
|
54
|
-
const fallbackKernelPathId = selectRuleValue(
|
|
56
|
+
const fallbackKernelPathId = explicitFallbackKernelPathId ?? selectRuleValue(
|
|
55
57
|
'inference',
|
|
56
58
|
'kernelPath',
|
|
57
59
|
'finitenessFallback',
|
|
58
60
|
{ kernelPathId: primaryKernelPathId }
|
|
59
61
|
);
|
|
60
62
|
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
63
|
+
if (typeof fallbackKernelPathId !== 'string' || fallbackKernelPathId.length === 0) {
|
|
64
|
+
throw new Error(
|
|
65
|
+
`[ExecutionPlan] Missing finiteness fallback kernel path mapping for "${primaryKernelPathId}". ` +
|
|
66
|
+
'Add an explicit rule in src/rules/inference/kernel-path.rules.json.'
|
|
67
|
+
);
|
|
68
|
+
}
|
|
65
69
|
|
|
66
|
-
if (
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
`
|
|
70
|
+
if (fallbackKernelPathId === primaryKernelPathId) {
|
|
71
|
+
throw new Error(
|
|
72
|
+
`[ExecutionPlan] Invalid finiteness fallback mapping for "${primaryKernelPathId}": ` +
|
|
73
|
+
`fallback kernel path resolves to itself. Add an explicit widening path.`
|
|
70
74
|
);
|
|
71
75
|
}
|
|
72
76
|
|
|
73
77
|
try {
|
|
74
|
-
const kernelPath = resolveKernelPath(
|
|
78
|
+
const kernelPath = resolveKernelPath(fallbackKernelPathId);
|
|
75
79
|
return {
|
|
76
80
|
kernelPath,
|
|
77
|
-
kernelPathId:
|
|
78
|
-
kernelPathSource,
|
|
81
|
+
kernelPathId: fallbackKernelPathId,
|
|
82
|
+
kernelPathSource: 'rule',
|
|
79
83
|
};
|
|
80
84
|
} catch (error) {
|
|
81
|
-
if (primaryKernelPathIsObject) {
|
|
82
|
-
log.warn(
|
|
83
|
-
'Pipeline',
|
|
84
|
-
`[ExecutionPlan] Failed to resolve finiteness fallback kernel path "${resolvedKernelPathId}" ` +
|
|
85
|
-
`for "${primaryKernelPathId}", using inline kernel path as fallback. ${error?.message || error}`
|
|
86
|
-
);
|
|
87
|
-
return {
|
|
88
|
-
kernelPath: primaryKernelPath,
|
|
89
|
-
kernelPathId: primaryKernelPathId,
|
|
90
|
-
kernelPathSource,
|
|
91
|
-
};
|
|
92
|
-
}
|
|
93
85
|
throw new Error(
|
|
94
|
-
`[ExecutionPlan] Failed to resolve finiteness fallback kernel path "${
|
|
86
|
+
`[ExecutionPlan] Failed to resolve finiteness fallback kernel path "${fallbackKernelPathId}" ` +
|
|
95
87
|
`(from "${primaryKernelPathId}"): ${error?.message || error}`
|
|
96
88
|
);
|
|
97
89
|
}
|
|
@@ -1,4 +1,13 @@
|
|
|
1
1
|
import { mergeRuntimeValues } from '../../../config/runtime-merge.js';
|
|
2
|
+
import {
|
|
3
|
+
buildExecutionV0KernelProfileKey,
|
|
4
|
+
indexExecutionV0KernelProfiles,
|
|
5
|
+
normalizeExecutionV0Dtype,
|
|
6
|
+
resolveExecutionV0KernelProfile,
|
|
7
|
+
resolveExecutionV0KVIO,
|
|
8
|
+
resolveExecutionV0Precision,
|
|
9
|
+
} from '../../../config/execution-v0-contract-check.js';
|
|
10
|
+
import { selectRuleValue } from '../../../rules/rule-registry.js';
|
|
2
11
|
import {
|
|
3
12
|
EXECUTION_V0_SCHEMA_ID,
|
|
4
13
|
DEFAULT_EXECUTION_V0_POLICIES,
|
|
@@ -59,13 +68,9 @@ function cloneJson(value) {
|
|
|
59
68
|
return JSON.parse(JSON.stringify(value));
|
|
60
69
|
}
|
|
61
70
|
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
throw new Error(`[ExecutionV0] ${label} must be "f16" or "f32"; got "${value}"`);
|
|
66
|
-
}
|
|
67
|
-
return normalized;
|
|
68
|
-
}
|
|
71
|
+
const normalizeDtype = normalizeExecutionV0Dtype;
|
|
72
|
+
const resolvePrecision = resolveExecutionV0Precision;
|
|
73
|
+
const resolveKVIO = resolveExecutionV0KVIO;
|
|
69
74
|
|
|
70
75
|
function normalizePhase(value, label) {
|
|
71
76
|
const normalized = String(value ?? '').trim().toLowerCase();
|
|
@@ -117,10 +122,7 @@ function stepHasLayer(step, layerIdx) {
|
|
|
117
122
|
return step.layers.includes(layerIdx);
|
|
118
123
|
}
|
|
119
124
|
|
|
120
|
-
|
|
121
|
-
if (!kernelRef) return null;
|
|
122
|
-
return `${kernelRef.id}|${kernelRef.version}|${kernelRef.digest}`;
|
|
123
|
-
}
|
|
125
|
+
const buildKernelProfileKey = buildExecutionV0KernelProfileKey;
|
|
124
126
|
|
|
125
127
|
function normalizeSlot(value, label) {
|
|
126
128
|
if (typeof value !== 'string' || value.trim().length === 0) {
|
|
@@ -212,90 +214,10 @@ function hasDefinedPath(root, pathSegments) {
|
|
|
212
214
|
return current !== undefined;
|
|
213
215
|
}
|
|
214
216
|
|
|
215
|
-
|
|
216
|
-
const byKey = new Map();
|
|
217
|
-
const profiles = sessionDefaults?.compute?.kernelProfiles ?? [];
|
|
218
|
-
for (const profile of profiles) {
|
|
219
|
-
assertKernelRef(profile.kernelRef, 'sessionDefaults.compute.kernelProfiles[].kernelRef');
|
|
220
|
-
byKey.set(buildKernelProfileKey(profile.kernelRef), profile);
|
|
221
|
-
}
|
|
222
|
-
return byKey;
|
|
223
|
-
}
|
|
217
|
+
const indexKernelProfiles = indexExecutionV0KernelProfiles;
|
|
224
218
|
|
|
225
219
|
function resolveProfile(profileIndex, step) {
|
|
226
|
-
|
|
227
|
-
if (!key) return null;
|
|
228
|
-
return profileIndex.get(key) ?? null;
|
|
229
|
-
}
|
|
230
|
-
|
|
231
|
-
function resolvePrecision(step, profile, sessionDefaults) {
|
|
232
|
-
const defaults = sessionDefaults.compute.defaults;
|
|
233
|
-
const precision = {
|
|
234
|
-
inputDtype: step.precision?.inputDtype
|
|
235
|
-
?? profile?.precision?.inputDtype
|
|
236
|
-
?? null,
|
|
237
|
-
mathDtype: step.precision?.mathDtype
|
|
238
|
-
?? profile?.precision?.mathDtype
|
|
239
|
-
?? defaults.mathDtype,
|
|
240
|
-
accumDtype: step.precision?.accumDtype
|
|
241
|
-
?? profile?.precision?.accumDtype
|
|
242
|
-
?? defaults.accumDtype,
|
|
243
|
-
outputDtype: step.precision?.outputDtype
|
|
244
|
-
?? profile?.precision?.outputDtype
|
|
245
|
-
?? defaults.outputDtype,
|
|
246
|
-
};
|
|
247
|
-
const sources = {
|
|
248
|
-
inputDtype: step.precision?.inputDtype != null
|
|
249
|
-
? 'manifest'
|
|
250
|
-
: profile?.precision?.inputDtype != null
|
|
251
|
-
? 'kernelProfile'
|
|
252
|
-
: 'derived',
|
|
253
|
-
mathDtype: step.precision?.mathDtype != null
|
|
254
|
-
? 'manifest'
|
|
255
|
-
: profile?.precision?.mathDtype != null
|
|
256
|
-
? 'kernelProfile'
|
|
257
|
-
: 'sessionDefault',
|
|
258
|
-
accumDtype: step.precision?.accumDtype != null
|
|
259
|
-
? 'manifest'
|
|
260
|
-
: profile?.precision?.accumDtype != null
|
|
261
|
-
? 'kernelProfile'
|
|
262
|
-
: 'sessionDefault',
|
|
263
|
-
outputDtype: step.precision?.outputDtype != null
|
|
264
|
-
? 'manifest'
|
|
265
|
-
: profile?.precision?.outputDtype != null
|
|
266
|
-
? 'kernelProfile'
|
|
267
|
-
: 'sessionDefault',
|
|
268
|
-
};
|
|
269
|
-
return { precision, sources };
|
|
270
|
-
}
|
|
271
|
-
|
|
272
|
-
function resolveKVIO(step, profile, sessionDefaults) {
|
|
273
|
-
if (step.kvIO) {
|
|
274
|
-
return {
|
|
275
|
-
value: {
|
|
276
|
-
readDtype: normalizeDtype(step.kvIO.readDtype, `${step.id}.kvIO.readDtype`),
|
|
277
|
-
writeDtype: normalizeDtype(step.kvIO.writeDtype, `${step.id}.kvIO.writeDtype`),
|
|
278
|
-
},
|
|
279
|
-
source: 'manifest',
|
|
280
|
-
};
|
|
281
|
-
}
|
|
282
|
-
if (profile?.kvIO) {
|
|
283
|
-
return {
|
|
284
|
-
value: {
|
|
285
|
-
readDtype: normalizeDtype(profile.kvIO.readDtype, `${step.id}.profile.kvIO.readDtype`),
|
|
286
|
-
writeDtype: normalizeDtype(profile.kvIO.writeDtype, `${step.id}.profile.kvIO.writeDtype`),
|
|
287
|
-
},
|
|
288
|
-
source: 'kernelProfile',
|
|
289
|
-
};
|
|
290
|
-
}
|
|
291
|
-
const kvDtype = normalizeDtype(
|
|
292
|
-
sessionDefaults?.kvcache?.kvDtype ?? sessionDefaults.compute.defaults.activationDtype,
|
|
293
|
-
`${step.id}.sessionDefaults.kvcache.kvDtype`
|
|
294
|
-
);
|
|
295
|
-
return {
|
|
296
|
-
value: { readDtype: kvDtype, writeDtype: kvDtype },
|
|
297
|
-
source: 'sessionDefault',
|
|
298
|
-
};
|
|
220
|
+
return resolveExecutionV0KernelProfile(profileIndex, step);
|
|
299
221
|
}
|
|
300
222
|
|
|
301
223
|
function validateStepShape(step, index) {
|
|
@@ -935,7 +857,7 @@ function assertInlineKernelPathSessionCompatibility(path, sessionDefaults) {
|
|
|
935
857
|
}
|
|
936
858
|
}
|
|
937
859
|
|
|
938
|
-
function buildInlineKernelPath(steps, sessionDefaults, modelId, numLayers) {
|
|
860
|
+
function buildInlineKernelPath(steps, sessionDefaults, modelId, numLayers, finitenessFallbackKernelPathId = null) {
|
|
939
861
|
const activationDtype = normalizeDtype(
|
|
940
862
|
sessionDefaults?.compute?.defaults?.activationDtype ?? 'f16',
|
|
941
863
|
'sessionDefaults.compute.defaults.activationDtype'
|
|
@@ -956,6 +878,9 @@ function buildInlineKernelPath(steps, sessionDefaults, modelId, numLayers) {
|
|
|
956
878
|
description: 'Generated from manifest.inference.execution.steps',
|
|
957
879
|
activationDtype,
|
|
958
880
|
kvDtype,
|
|
881
|
+
...(typeof finitenessFallbackKernelPathId === 'string' && finitenessFallbackKernelPathId.length > 0
|
|
882
|
+
? { finitenessFallbackKernelPathId }
|
|
883
|
+
: {}),
|
|
959
884
|
decode: {
|
|
960
885
|
steps: decodeSteps.length > 0 ? decodeSteps : prefillSteps,
|
|
961
886
|
},
|
|
@@ -1186,7 +1111,26 @@ export function compileExecutionV0(options = {}) {
|
|
|
1186
1111
|
...resolvedDecodeSteps.filter((step) => step.phase === 'decode'),
|
|
1187
1112
|
];
|
|
1188
1113
|
|
|
1189
|
-
const
|
|
1114
|
+
const defaultKernelPathId = typeof manifestInference.defaultKernelPath === 'string'
|
|
1115
|
+
&& manifestInference.defaultKernelPath.trim().length > 0
|
|
1116
|
+
? manifestInference.defaultKernelPath.trim()
|
|
1117
|
+
: null;
|
|
1118
|
+
const finitenessFallbackKernelPathId = defaultKernelPathId
|
|
1119
|
+
? selectRuleValue(
|
|
1120
|
+
'inference',
|
|
1121
|
+
'kernelPath',
|
|
1122
|
+
'finitenessFallback',
|
|
1123
|
+
{ kernelPathId: defaultKernelPathId }
|
|
1124
|
+
)
|
|
1125
|
+
: null;
|
|
1126
|
+
|
|
1127
|
+
const kernelPath = buildInlineKernelPath(
|
|
1128
|
+
patchedSteps,
|
|
1129
|
+
resolvedSession,
|
|
1130
|
+
modelId,
|
|
1131
|
+
numLayers,
|
|
1132
|
+
finitenessFallbackKernelPathId
|
|
1133
|
+
);
|
|
1190
1134
|
const layerPipeline = buildLayerPipelineFromExecution(resolvedSteps);
|
|
1191
1135
|
const sessionPatch = buildSessionRuntimePatch(resolvedSession);
|
|
1192
1136
|
const modelOverrides = buildModelRuntimeOverrides(manifestInference);
|
|
@@ -1241,6 +1185,10 @@ export function applyExecutionV0RuntimeConfig(options = {}) {
|
|
|
1241
1185
|
}
|
|
1242
1186
|
|
|
1243
1187
|
const runtimeInferencePatch = { ...executionV0State.runtimeInferencePatch };
|
|
1188
|
+
if (runtimeInference.kernelPath !== undefined) {
|
|
1189
|
+
delete runtimeInferencePatch.kernelPath;
|
|
1190
|
+
delete runtimeInferencePatch.kernelPathSource;
|
|
1191
|
+
}
|
|
1244
1192
|
if (runtimeInferencePatch.modelOverrides) {
|
|
1245
1193
|
runtimeInferencePatch.modelOverrides = mergeRuntimeValues(
|
|
1246
1194
|
runtimeInferencePatch.modelOverrides,
|
|
@@ -42,6 +42,7 @@ export async function processFFNStandard(
|
|
|
42
42
|
hiddenSize,
|
|
43
43
|
probes: context.debugProbes,
|
|
44
44
|
recorder,
|
|
45
|
+
dtype: normedTensor.dtype,
|
|
45
46
|
});
|
|
46
47
|
|
|
47
48
|
// 2. FFN
|
|
@@ -58,6 +59,7 @@ export async function processFFNStandard(
|
|
|
58
59
|
hiddenSize,
|
|
59
60
|
probes: context.debugProbes,
|
|
60
61
|
recorder,
|
|
62
|
+
dtype: ffnOutput.dtype,
|
|
61
63
|
});
|
|
62
64
|
|
|
63
65
|
// 3. Residual add
|
|
@@ -72,6 +74,7 @@ export async function processFFNStandard(
|
|
|
72
74
|
hiddenSize,
|
|
73
75
|
probes: context.debugProbes,
|
|
74
76
|
recorder,
|
|
77
|
+
dtype: output.dtype,
|
|
75
78
|
});
|
|
76
79
|
|
|
77
80
|
if (normedTensor !== postAttn) {
|
|
@@ -71,9 +71,13 @@ export interface PipelineContexts {
|
|
|
71
71
|
*/
|
|
72
72
|
export interface RoPEConfig {
|
|
73
73
|
headDim: number;
|
|
74
|
+
rotaryDim?: number;
|
|
74
75
|
maxSeqLen: number;
|
|
75
76
|
ropeTheta: number;
|
|
76
77
|
ropeLocalTheta?: number | null;
|
|
78
|
+
mropeInterleaved?: boolean;
|
|
79
|
+
mropeSection?: number[] | null;
|
|
80
|
+
partialRotaryFactor?: number | null;
|
|
77
81
|
ropeScale: number;
|
|
78
82
|
ropeLocalScale?: number;
|
|
79
83
|
ropeScalingType?: string | null;
|
|
@@ -206,13 +206,45 @@ function isSameRoPEScalingConfig(
|
|
|
206
206
|
=== (rightScaling?.original_max_position_embeddings ?? null);
|
|
207
207
|
}
|
|
208
208
|
|
|
209
|
+
/**
 * Work out how many dimensions of each attention head receive rotary embedding.
 *
 * Precedence:
 *   1. An explicit `rotaryDim` wins. It must be a finite, positive, even number
 *      that does not exceed `headDim`.
 *   2. Otherwise, if `partialRotaryFactor` is supplied, it must lie in (0, 1].
 *      The dimension is `Math.trunc(headDim * partialRotaryFactor)`, which must
 *      itself be positive and even.
 *   3. Otherwise the whole head (`headDim`) is rotated.
 *
 * @param {number} headDim - Per-head dimension.
 * @param {number|null|undefined} rotaryDim - Explicit rotary dimension override.
 * @param {number|null|undefined} partialRotaryFactor - Fraction of `headDim` to rotate.
 * @returns {number} The resolved rotary dimension.
 * @throws {Error} If the override or factor is out of range, or the derived
 *   dimension is not a positive even integer.
 */
function resolveRotaryDim(headDim, rotaryDim, partialRotaryFactor) {
  // `value % 2 === 0` only holds for even integers, so fractional values fail too.
  // NaN and Infinity also fail (their remainder is NaN).
  const isPositiveEven = (value) => value > 0 && value % 2 === 0;

  if (rotaryDim != null) {
    const isWellFormed = Number.isFinite(rotaryDim) && isPositiveEven(rotaryDim);
    if (!isWellFormed) {
      throw new Error(`RoPE rotary dim must be a positive even integer; got "${rotaryDim}".`);
    }
    if (rotaryDim > headDim) {
      throw new Error(`RoPE rotary dim ${rotaryDim} cannot exceed headDim ${headDim}.`);
    }
    return rotaryDim;
  }

  // No override of either kind: rotate the full head.
  if (partialRotaryFactor == null) {
    return headDim;
  }

  const factorInRange = Number.isFinite(partialRotaryFactor)
    && partialRotaryFactor > 0
    && partialRotaryFactor <= 1;
  if (!factorInRange) {
    throw new Error(
      `RoPE partialRotaryFactor must be a number in (0, 1]; got "${partialRotaryFactor}".`
    );
  }

  // Truncate toward zero rather than rounding; an odd or zero result is rejected below.
  const derivedDim = Math.trunc(headDim * partialRotaryFactor);
  if (!isPositiveEven(derivedDim)) {
    throw new Error(
      `RoPE partialRotaryFactor=${partialRotaryFactor} with headDim=${headDim} resolves ` +
      `to rotaryDim=${derivedDim}, but rotaryDim must be a positive even integer.`
    );
  }
  return derivedDim;
}
|
|
236
|
+
|
|
209
237
|
|
|
210
238
|
export async function initRoPEFrequencies(config, useGPU) {
|
|
211
239
|
const {
|
|
212
240
|
headDim,
|
|
241
|
+
rotaryDim,
|
|
213
242
|
maxSeqLen,
|
|
214
243
|
ropeTheta,
|
|
215
244
|
ropeLocalTheta,
|
|
245
|
+
mropeInterleaved,
|
|
246
|
+
mropeSection,
|
|
247
|
+
partialRotaryFactor,
|
|
216
248
|
ropeScale,
|
|
217
249
|
ropeLocalScale,
|
|
218
250
|
ropeScalingType,
|
|
@@ -230,14 +262,23 @@ export async function initRoPEFrequencies(config, useGPU) {
|
|
|
230
262
|
const resolvedLocalTheta = ropeLocalTheta ?? ropeTheta;
|
|
231
263
|
const resolvedLocalScalingType = ropeLocalScalingType ?? ropeScalingType;
|
|
232
264
|
const resolvedLocalScaling = ropeLocalScaling ?? ropeScaling;
|
|
265
|
+
const resolvedRotaryDim = resolveRotaryDim(headDim, rotaryDim, partialRotaryFactor);
|
|
266
|
+
const halfDim = resolvedRotaryDim / 2;
|
|
267
|
+
if (mropeInterleaved === true && Array.isArray(mropeSection)) {
|
|
268
|
+
const expandedDim = mropeSection.reduce((sum, entry) => sum + entry, 0) * 2;
|
|
269
|
+
if (expandedDim !== resolvedRotaryDim) {
|
|
270
|
+
throw new Error(
|
|
271
|
+
`RoPE mropeSection expands to ${expandedDim} dims, but rotaryDim is ${resolvedRotaryDim}.`
|
|
272
|
+
);
|
|
273
|
+
}
|
|
274
|
+
}
|
|
233
275
|
|
|
234
|
-
const halfDim = headDim / 2;
|
|
235
276
|
const isYarn = ropeScalingType === 'yarn';
|
|
236
277
|
const isLocalYarn = resolvedLocalScalingType === 'yarn';
|
|
237
278
|
|
|
238
279
|
// Compute global (full_attention) frequencies
|
|
239
280
|
const globalFreqs = computeRoPEFreqsForTheta(
|
|
240
|
-
ropeTheta,
|
|
281
|
+
ropeTheta, resolvedRotaryDim, maxSeqLen, ropeScale, ropeScalingType, ropeScaling
|
|
241
282
|
);
|
|
242
283
|
|
|
243
284
|
// Compute local (sliding_attention) frequencies if different from global.
|
|
@@ -256,7 +297,7 @@ export async function initRoPEFrequencies(config, useGPU) {
|
|
|
256
297
|
if (hasDistinctLocalTheta || hasDistinctLocalScaling) {
|
|
257
298
|
localFreqs = computeRoPEFreqsForTheta(
|
|
258
299
|
resolvedLocalTheta,
|
|
259
|
-
|
|
300
|
+
resolvedRotaryDim,
|
|
260
301
|
maxSeqLen,
|
|
261
302
|
resolvedLocalScale,
|
|
262
303
|
resolvedLocalScalingType,
|
|
@@ -303,9 +344,10 @@ export async function initRoPEFrequencies(config, useGPU) {
|
|
|
303
344
|
|
|
304
345
|
log.debug(
|
|
305
346
|
'Pipeline',
|
|
306
|
-
`RoPE frequencies initialized (GPU): ${maxSeqLen} positions, dim=${halfDim}, headDim=${headDim}, ` +
|
|
347
|
+
`RoPE frequencies initialized (GPU): ${maxSeqLen} positions, dim=${halfDim}, headDim=${headDim}, rotaryDim=${resolvedRotaryDim}, ` +
|
|
307
348
|
`theta=${ropeTheta}${hasDistinctLocalTheta ? `, localTheta=${resolvedLocalTheta}` : ''}, ` +
|
|
308
|
-
`scaling=${ropeScalingType ?? 'none'}:${ropeScale}${hasDistinctLocalScaling ? `, localScaling=${resolvedLocalScalingType ?? 'none'}:${resolvedLocalScale}` : ''}`
|
|
349
|
+
`scaling=${ropeScalingType ?? 'none'}:${ropeScale}${hasDistinctLocalScaling ? `, localScaling=${resolvedLocalScalingType ?? 'none'}:${resolvedLocalScale}` : ''}, ` +
|
|
350
|
+
`interleaved=${mropeInterleaved === true}`
|
|
309
351
|
);
|
|
310
352
|
|
|
311
353
|
return {
|
|
@@ -318,9 +360,10 @@ export async function initRoPEFrequencies(config, useGPU) {
|
|
|
318
360
|
|
|
319
361
|
log.debug(
|
|
320
362
|
'Pipeline',
|
|
321
|
-
`RoPE frequencies initialized (CPU): ${maxSeqLen} positions, dim=${halfDim}, headDim=${headDim}, ` +
|
|
363
|
+
`RoPE frequencies initialized (CPU): ${maxSeqLen} positions, dim=${halfDim}, headDim=${headDim}, rotaryDim=${resolvedRotaryDim}, ` +
|
|
322
364
|
`theta=${ropeTheta}${hasDistinctLocalTheta ? `, localTheta=${resolvedLocalTheta}` : ''}, ` +
|
|
323
|
-
`scaling=${ropeScalingType ?? 'none'}:${ropeScale}${hasDistinctLocalScaling ? `, localScaling=${resolvedLocalScalingType ?? 'none'}:${resolvedLocalScale}` : ''}`
|
|
365
|
+
`scaling=${ropeScalingType ?? 'none'}:${ropeScale}${hasDistinctLocalScaling ? `, localScaling=${resolvedLocalScalingType ?? 'none'}:${resolvedLocalScale}` : ''}, ` +
|
|
366
|
+
`interleaved=${mropeInterleaved === true}`
|
|
324
367
|
);
|
|
325
368
|
|
|
326
369
|
return {
|
|
@@ -688,6 +731,10 @@ function applyChatMLTemplate(prompt) {
|
|
|
688
731
|
return `<|im_start|>user\n${prompt}<|im_end|>\n<|im_start|>assistant\n`;
|
|
689
732
|
}
|
|
690
733
|
|
|
734
|
+
/**
 * Wrap a single user prompt in ChatML-style turns for the 'qwen' template.
 *
 * The output ends with an already-closed, empty `<think>` block after the
 * assistant header. Presumably this mirrors the Qwen3 non-thinking generation
 * prompt (`enable_thinking=false`); confirm against the model's chat template.
 *
 * @param {string} prompt - Raw user message text.
 * @returns {string} The formatted prompt, ending where the assistant reply begins.
 */
function applyQwenTemplate(prompt) {
  const userTurn = `<|im_start|>user\n${prompt}<|im_end|>\n`;
  const assistantHeader = '<|im_start|>assistant\n';
  const emptyThinkBlock = '<think>\n\n</think>\n\n';
  return `${userTurn}${assistantHeader}${emptyThinkBlock}`;
}
|
|
737
|
+
|
|
691
738
|
function applyTranslateGemmaTemplate() {
|
|
692
739
|
throw new Error(
|
|
693
740
|
'TranslateGemma template requires structured messages. ' +
|
|
@@ -702,7 +749,7 @@ const PROMPT_TEMPLATES = {
|
|
|
702
749
|
'llama3': applyHeaderBasedTemplate,
|
|
703
750
|
'gpt-oss': applyChannelBasedTemplate,
|
|
704
751
|
'chatml': applyChatMLTemplate,
|
|
705
|
-
'qwen':
|
|
752
|
+
'qwen': applyQwenTemplate,
|
|
706
753
|
'translategemma': applyTranslateGemmaTemplate,
|
|
707
754
|
};
|
|
708
755
|
|
|
@@ -721,7 +768,7 @@ export function applyChatTemplate(prompt, templateType) {
|
|
|
721
768
|
export const applyGemmaChatTemplate = applyTurnBasedTemplate;
|
|
722
769
|
export const applyLlama3ChatTemplate = applyHeaderBasedTemplate;
|
|
723
770
|
export const applyGptOssChatTemplate = applyChannelBasedTemplate;
|
|
724
|
-
export const applyQwenChatTemplate =
|
|
771
|
+
export const applyQwenChatTemplate = applyQwenTemplate;
|
|
725
772
|
|
|
726
773
|
|
|
727
774
|
export function isStopToken(token, stopTokenIds, eosTokenId) {
|