@simulatte/doppler 0.1.5 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +23 -8
- package/package.json +7 -4
- package/src/config/kernels/kernel-ref-digests.js +39 -39
- package/src/config/kernels/registry.json +42 -2
- package/src/config/loader.js +31 -2
- package/src/config/merge.js +18 -0
- package/src/config/presets/models/qwen3.json +9 -2
- package/src/config/presets/models/transformer.json +5 -0
- package/src/config/required-inference-fields-contract-check.js +6 -0
- package/src/config/schema/inference-defaults.schema.js +3 -0
- package/src/config/schema/inference.schema.d.ts +9 -0
- package/src/config/schema/kernel-path.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.js +3 -0
- package/src/converter/rope-config.js +42 -0
- package/src/gpu/device.js +58 -0
- package/src/gpu/kernels/attention.js +98 -0
- package/src/gpu/kernels/bias_add.wgsl +8 -6
- package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
- package/src/gpu/kernels/conv2d.js +1 -1
- package/src/gpu/kernels/conv2d.wgsl +7 -8
- package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
- package/src/gpu/kernels/depthwise_conv2d.js +2 -1
- package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +2 -1
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/matmul.js +25 -0
- package/src/gpu/kernels/pixel_shuffle.js +1 -1
- package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
- package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
- package/src/gpu/kernels/relu.js +15 -2
- package/src/gpu/kernels/relu.wgsl +2 -1
- package/src/gpu/kernels/relu_f16.wgsl +2 -1
- package/src/gpu/kernels/repeat_channels.js +1 -1
- package/src/gpu/kernels/repeat_channels.wgsl +4 -5
- package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
- package/src/gpu/kernels/residual.js +44 -8
- package/src/gpu/kernels/residual.wgsl +6 -3
- package/src/gpu/kernels/residual_f16.wgsl +2 -1
- package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
- package/src/gpu/kernels/residual_vec4.wgsl +2 -1
- package/src/gpu/kernels/rmsnorm.js +58 -6
- package/src/gpu/kernels/rmsnorm.wgsl +14 -6
- package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
- package/src/gpu/kernels/rope.d.ts +2 -0
- package/src/gpu/kernels/rope.js +11 -1
- package/src/gpu/kernels/rope.wgsl +56 -40
- package/src/gpu/kernels/sana_linear_attention.js +1 -2
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
- package/src/gpu/kernels/silu.d.ts +1 -0
- package/src/gpu/kernels/silu.js +32 -14
- package/src/gpu/kernels/silu.wgsl +19 -9
- package/src/gpu/kernels/silu_f16.wgsl +19 -9
- package/src/gpu/kernels/transpose.js +15 -2
- package/src/gpu/kernels/transpose.wgsl +5 -6
- package/src/gpu/kernels/upsample2d.js +2 -1
- package/src/gpu/kernels/upsample2d.wgsl +6 -9
- package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
- package/src/gpu/kernels/utils.js +16 -1
- package/src/inference/browser-harness.js +47 -1
- package/src/inference/pipelines/diffusion/pipeline.js +15 -6
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
- package/src/inference/pipelines/text/attention/record.js +11 -2
- package/src/inference/pipelines/text/attention/run.js +11 -2
- package/src/inference/pipelines/text/chat-format.js +25 -1
- package/src/inference/pipelines/text/config.d.ts +4 -0
- package/src/inference/pipelines/text/config.js +68 -1
- package/src/inference/pipelines/text/execution-plan.js +23 -31
- package/src/inference/pipelines/text/execution-v0.js +29 -2
- package/src/inference/pipelines/text/ffn/standard.js +3 -0
- package/src/inference/pipelines/text/init.d.ts +4 -0
- package/src/inference/pipelines/text/init.js +56 -9
- package/src/inference/pipelines/text/layer.js +11 -0
- package/src/inference/pipelines/text.js +4 -0
- package/src/inference/tokenizers/bundled.js +156 -33
- package/src/rules/tooling/command-runtime.rules.json +18 -0
- package/src/tooling/command-api.d.ts +27 -1
- package/src/tooling/command-api.js +142 -3
- package/src/tooling/node-browser-command-runner.d.ts +4 -0
- package/src/tooling/node-browser-command-runner.js +58 -3
- package/src/tooling/node-command-runner.js +15 -0
- package/src/tooling/node-webgpu.js +9 -87
- package/src/training/checkpoint-watch.d.ts +7 -0
- package/src/training/checkpoint-watch.js +106 -0
- package/src/training/checkpoint.d.ts +6 -1
- package/src/training/checkpoint.js +12 -2
- package/src/training/distillation/artifacts.d.ts +71 -0
- package/src/training/distillation/artifacts.js +132 -0
- package/src/training/distillation/checkpoint-watch.d.ts +10 -0
- package/src/training/distillation/checkpoint-watch.js +57 -0
- package/src/training/distillation/dataset.d.ts +59 -0
- package/src/training/distillation/dataset.js +337 -0
- package/src/training/distillation/eval.d.ts +34 -0
- package/src/training/distillation/eval.js +310 -0
- package/src/training/distillation/index.d.ts +29 -0
- package/src/training/distillation/index.js +29 -0
- package/src/training/distillation/runtime.d.ts +20 -0
- package/src/training/distillation/runtime.js +121 -0
- package/src/training/distillation/scoreboard.d.ts +6 -0
- package/src/training/distillation/scoreboard.js +8 -0
- package/src/training/distillation/stage-a.d.ts +45 -0
- package/src/training/distillation/stage-a.js +338 -0
- package/src/training/distillation/stage-b.d.ts +24 -0
- package/src/training/distillation/stage-b.js +20 -0
- package/src/training/index.d.ts +10 -0
- package/src/training/index.js +10 -0
- package/src/training/lora-pipeline.d.ts +40 -0
- package/src/training/lora-pipeline.js +796 -0
- package/src/training/operator-artifacts.d.ts +62 -0
- package/src/training/operator-artifacts.js +140 -0
- package/src/training/operator-command.d.ts +5 -0
- package/src/training/operator-command.js +453 -0
- package/src/training/operator-eval.d.ts +48 -0
- package/src/training/operator-eval.js +230 -0
- package/src/training/operator-scoreboard.d.ts +5 -0
- package/src/training/operator-scoreboard.js +44 -0
- package/src/training/runner.d.ts +52 -0
- package/src/training/runner.js +29 -4
- package/src/training/suite.d.ts +112 -0
- package/src/training/suite.js +9 -9
- package/src/training/workloads.d.ts +164 -0
- package/src/training/workloads.js +539 -0
- package/src/version.js +1 -1
- package/tools/doppler-cli.js +137 -40
|
@@ -3,19 +3,32 @@ import { createTensor, dtypeBytes } from '../tensor.js';
|
|
|
3
3
|
import { WORKGROUP_SIZES } from './constants.js';
|
|
4
4
|
import { unifiedKernelWrapper } from './utils.js';
|
|
5
5
|
|
|
6
|
+
function planTransposeDispatch(target, cols) {
|
|
7
|
+
const device = target?.device;
|
|
8
|
+
const maxPerDim = Number.isFinite(device?.limits?.maxComputeWorkgroupsPerDimension)
|
|
9
|
+
? device.limits.maxComputeWorkgroupsPerDimension
|
|
10
|
+
: 65535;
|
|
11
|
+
const dispatchStride = Math.min(cols, maxPerDim * WORKGROUP_SIZES.DEFAULT);
|
|
12
|
+
return {
|
|
13
|
+
dispatchStride,
|
|
14
|
+
workgroups: [Math.ceil(dispatchStride / WORKGROUP_SIZES.DEFAULT), 1, 1],
|
|
15
|
+
};
|
|
16
|
+
}
|
|
17
|
+
|
|
6
18
|
async function _transpose(target, input, rows, cols, options = {}) {
|
|
7
19
|
const { outputBuffer = null } = options;
|
|
8
20
|
const bytesPerElement = dtypeBytes(input.dtype);
|
|
9
21
|
const outputSize = rows * cols * bytesPerElement;
|
|
10
22
|
const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'transpose_output');
|
|
23
|
+
const dispatchPlan = planTransposeDispatch(target, cols);
|
|
11
24
|
|
|
12
25
|
await unifiedKernelWrapper(
|
|
13
26
|
'transpose',
|
|
14
27
|
target,
|
|
15
28
|
'default',
|
|
16
29
|
[input, outputBuf],
|
|
17
|
-
{ rows, cols },
|
|
18
|
-
|
|
30
|
+
{ rows, cols, _pad0: dispatchPlan.dispatchStride, _pad1: 0 },
|
|
31
|
+
[dispatchPlan.workgroups[0], rows, 1]
|
|
19
32
|
);
|
|
20
33
|
|
|
21
34
|
return createTensor(outputBuf, input.dtype, [cols, rows], 'transpose_output');
|
|
@@ -19,14 +19,13 @@ struct Uniforms {
|
|
|
19
19
|
|
|
20
20
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
21
21
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
22
|
-
let
|
|
23
|
-
let
|
|
24
|
-
|
|
22
|
+
let dispatch_stride = max(u._pad0, 1u);
|
|
23
|
+
let row = gid.y;
|
|
24
|
+
let col = gid.x + row * dispatch_stride;
|
|
25
|
+
if (row >= u.rows || col >= u.cols) {
|
|
25
26
|
return;
|
|
26
27
|
}
|
|
27
|
-
|
|
28
|
-
let row = idx / u.cols;
|
|
29
|
-
let col = idx % u.cols;
|
|
28
|
+
let idx = row * u.cols + col;
|
|
30
29
|
let out_idx = col * u.rows + row;
|
|
31
30
|
output[out_idx] = input[idx];
|
|
32
31
|
}
|
|
@@ -31,6 +31,7 @@ async function _upsample2d(target, input, options = {}) {
|
|
|
31
31
|
|
|
32
32
|
const outHeight = resolvedHeight * scale;
|
|
33
33
|
const outWidth = resolvedWidth * scale;
|
|
34
|
+
const outSpatial = outHeight * outWidth;
|
|
34
35
|
const bytesPerElement = dtypeBytes(input.dtype);
|
|
35
36
|
const outputSize = channels * outHeight * outWidth * bytesPerElement;
|
|
36
37
|
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'upsample2d_output');
|
|
@@ -43,7 +44,7 @@ async function _upsample2d(target, input, options = {}) {
|
|
|
43
44
|
out_height: outHeight, out_width: outWidth, scale,
|
|
44
45
|
_pad0: 0, _pad1: 0,
|
|
45
46
|
},
|
|
46
|
-
Math.ceil(
|
|
47
|
+
[Math.ceil(outSpatial / WORKGROUP_SIZES.DEFAULT), channels, 1]
|
|
47
48
|
);
|
|
48
49
|
|
|
49
50
|
return createTensor(output, input.dtype, [channels, outHeight, outWidth], 'upsample2d_output');
|
|
@@ -19,19 +19,16 @@ struct Uniforms {
|
|
|
19
19
|
|
|
20
20
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
21
21
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
22
|
-
let idx = gid.x;
|
|
23
22
|
let out_spatial = u.out_height * u.out_width;
|
|
24
|
-
let
|
|
25
|
-
|
|
23
|
+
let spatial_idx = gid.x;
|
|
24
|
+
let channel = gid.y;
|
|
25
|
+
if (spatial_idx >= out_spatial || channel >= u.channels) {
|
|
26
26
|
return;
|
|
27
27
|
}
|
|
28
|
-
|
|
29
|
-
let
|
|
30
|
-
let rem = idx - channel * out_spatial;
|
|
31
|
-
let out_y = rem / u.out_width;
|
|
32
|
-
let out_x = rem - out_y * u.out_width;
|
|
28
|
+
let out_y = spatial_idx / u.out_width;
|
|
29
|
+
let out_x = spatial_idx - out_y * u.out_width;
|
|
33
30
|
let in_y = out_y / u.scale;
|
|
34
31
|
let in_x = out_x / u.scale;
|
|
35
32
|
let in_idx = (channel * u.in_height + in_y) * u.in_width + in_x;
|
|
36
|
-
output[
|
|
33
|
+
output[channel * out_spatial + spatial_idx] = input[in_idx];
|
|
37
34
|
}
|
|
@@ -23,19 +23,16 @@ struct Uniforms {
|
|
|
23
23
|
|
|
24
24
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
25
25
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
26
|
-
let idx = gid.x;
|
|
27
26
|
let out_spatial = u.out_height * u.out_width;
|
|
28
|
-
let
|
|
29
|
-
|
|
27
|
+
let spatial_idx = gid.x;
|
|
28
|
+
let channel = gid.y;
|
|
29
|
+
if (spatial_idx >= out_spatial || channel >= u.channels) {
|
|
30
30
|
return;
|
|
31
31
|
}
|
|
32
|
-
|
|
33
|
-
let
|
|
34
|
-
let rem = idx - channel * out_spatial;
|
|
35
|
-
let out_y = rem / u.out_width;
|
|
36
|
-
let out_x = rem - out_y * u.out_width;
|
|
32
|
+
let out_y = spatial_idx / u.out_width;
|
|
33
|
+
let out_x = spatial_idx - out_y * u.out_width;
|
|
37
34
|
let in_y = out_y / u.scale;
|
|
38
35
|
let in_x = out_x / u.scale;
|
|
39
36
|
let in_idx = (channel * u.in_height + in_y) * u.in_width + in_x;
|
|
40
|
-
output[
|
|
37
|
+
output[channel * out_spatial + spatial_idx] = input[in_idx];
|
|
41
38
|
}
|
package/src/gpu/kernels/utils.js
CHANGED
|
@@ -116,9 +116,24 @@ export async function unifiedKernelWrapper(opName, target, variant, bindings, un
|
|
|
116
116
|
index = config.variantMetadata.outputBinding;
|
|
117
117
|
}
|
|
118
118
|
|
|
119
|
+
const buffer = binding?.buffer || binding;
|
|
120
|
+
const isGpuBuffer = buffer && (
|
|
121
|
+
typeof GPUBuffer === 'undefined'
|
|
122
|
+
? true
|
|
123
|
+
: buffer instanceof GPUBuffer
|
|
124
|
+
);
|
|
125
|
+
if (!isGpuBuffer) {
|
|
126
|
+
const bindingLabel = binding?.label ?? buffer?.label ?? 'unknown';
|
|
127
|
+
const bufferType = buffer === null ? 'null' : buffer === undefined ? 'undefined' : buffer.constructor?.name || typeof buffer;
|
|
128
|
+
throw new Error(
|
|
129
|
+
`Kernel "${opName}/${variant}" binding "${bindingConfig.name}" (index ${index}) requires a GPUBuffer ` +
|
|
130
|
+
`(label=${bindingLabel}, type=${bufferType}).`
|
|
131
|
+
);
|
|
132
|
+
}
|
|
133
|
+
|
|
119
134
|
bindGroupEntries.push({
|
|
120
135
|
binding: index,
|
|
121
|
-
resource: { buffer
|
|
136
|
+
resource: { buffer }
|
|
122
137
|
});
|
|
123
138
|
}
|
|
124
139
|
|
|
@@ -929,6 +929,9 @@ async function resolveHarnessOverride(options = {}) {
|
|
|
929
929
|
|
|
930
930
|
async function initializeSuiteModel(options = {}) {
|
|
931
931
|
if (options.harnessOverride) {
|
|
932
|
+
if (options.runtime?.runtimeConfig) {
|
|
933
|
+
setRuntimeConfig(options.runtime.runtimeConfig);
|
|
934
|
+
}
|
|
932
935
|
return resolveHarnessOverride(options);
|
|
933
936
|
}
|
|
934
937
|
const loadStart = performance.now();
|
|
@@ -988,6 +991,14 @@ async function runKernelSuite(options = {}) {
|
|
|
988
991
|
|
|
989
992
|
const DEFAULT_HARNESS_PROMPT = 'Summarize this input in one sentence.';
|
|
990
993
|
const DEFAULT_RUNTIME_PLACEHOLDER_PROMPT = 'Hello from Doppler.';
|
|
994
|
+
const DEFAULT_QWEN_PROMPT = Object.freeze({
|
|
995
|
+
messages: Object.freeze([
|
|
996
|
+
Object.freeze({
|
|
997
|
+
role: 'user',
|
|
998
|
+
content: 'Answer in one short sentence: What color is the sky on a clear day?',
|
|
999
|
+
}),
|
|
1000
|
+
]),
|
|
1001
|
+
});
|
|
991
1002
|
const DEFAULT_TRANSLATEGEMMA_PROMPT = Object.freeze({
|
|
992
1003
|
messages: Object.freeze([
|
|
993
1004
|
Object.freeze({
|
|
@@ -1273,6 +1284,9 @@ function resolvePromptTemplateType(source) {
|
|
|
1273
1284
|
}
|
|
1274
1285
|
|
|
1275
1286
|
function buildDefaultGenerationPrompt(templateType) {
|
|
1287
|
+
if (templateType === 'qwen') {
|
|
1288
|
+
return clonePromptInput(DEFAULT_QWEN_PROMPT);
|
|
1289
|
+
}
|
|
1276
1290
|
if (templateType === 'translategemma') {
|
|
1277
1291
|
return clonePromptInput(DEFAULT_TRANSLATEGEMMA_PROMPT);
|
|
1278
1292
|
}
|
|
@@ -1280,7 +1294,7 @@ function buildDefaultGenerationPrompt(templateType) {
|
|
|
1280
1294
|
}
|
|
1281
1295
|
|
|
1282
1296
|
function shouldPreferModelDefaultPrompt(runtimePrompt, templateType) {
|
|
1283
|
-
if (templateType !== 'translategemma') {
|
|
1297
|
+
if (templateType !== 'translategemma' && templateType !== 'qwen') {
|
|
1284
1298
|
return false;
|
|
1285
1299
|
}
|
|
1286
1300
|
if (typeof runtimePrompt !== 'string') {
|
|
@@ -1289,6 +1303,31 @@ function shouldPreferModelDefaultPrompt(runtimePrompt, templateType) {
|
|
|
1289
1303
|
return runtimePrompt.trim() === DEFAULT_RUNTIME_PLACEHOLDER_PROMPT;
|
|
1290
1304
|
}
|
|
1291
1305
|
|
|
1306
|
+
function assertPromptContract(runtimePrompt, templateType, source = 'runtime.inference.prompt') {
|
|
1307
|
+
if (templateType !== 'translategemma') {
|
|
1308
|
+
return;
|
|
1309
|
+
}
|
|
1310
|
+
if (runtimePrompt === undefined || runtimePrompt === null) {
|
|
1311
|
+
return;
|
|
1312
|
+
}
|
|
1313
|
+
if (typeof runtimePrompt === 'string') {
|
|
1314
|
+
const trimmed = runtimePrompt.trim();
|
|
1315
|
+
if (!trimmed || trimmed === DEFAULT_RUNTIME_PLACEHOLDER_PROMPT) {
|
|
1316
|
+
return;
|
|
1317
|
+
}
|
|
1318
|
+
throw new Error(
|
|
1319
|
+
`TranslateGemma harness prompt contract violation: ${source} must be ` +
|
|
1320
|
+
'{ messages: [...] } with source_lang_code/target_lang_code blocks, not a plain string.'
|
|
1321
|
+
);
|
|
1322
|
+
}
|
|
1323
|
+
if (!isStructuredPromptInput(runtimePrompt)) {
|
|
1324
|
+
throw new Error(
|
|
1325
|
+
`TranslateGemma harness prompt contract violation: ${source} must be ` +
|
|
1326
|
+
'{ messages: [...] } with source_lang_code/target_lang_code blocks.'
|
|
1327
|
+
);
|
|
1328
|
+
}
|
|
1329
|
+
}
|
|
1330
|
+
|
|
1292
1331
|
function describePromptInput(promptInput) {
|
|
1293
1332
|
if (typeof promptInput === 'string') {
|
|
1294
1333
|
return promptInput.trim() || DEFAULT_HARNESS_PROMPT;
|
|
@@ -1305,6 +1344,11 @@ function describePromptInput(promptInput) {
|
|
|
1305
1344
|
if (sourceLang && targetLang) {
|
|
1306
1345
|
return `${sourceLang} -> ${targetLang}: ${text || '[non-text request]'}`;
|
|
1307
1346
|
}
|
|
1347
|
+
const stringContent = asText(firstMessage?.content);
|
|
1348
|
+
if (stringContent) {
|
|
1349
|
+
const role = asText(firstMessage?.role) || 'user';
|
|
1350
|
+
return `${role}: ${stringContent}`;
|
|
1351
|
+
}
|
|
1308
1352
|
try {
|
|
1309
1353
|
return JSON.stringify(promptInput);
|
|
1310
1354
|
} catch {
|
|
@@ -1315,6 +1359,7 @@ function describePromptInput(promptInput) {
|
|
|
1315
1359
|
function resolveGenerationPromptInput(runtimeConfig, runOverrides = null, source = null) {
|
|
1316
1360
|
const templateType = resolvePromptTemplateType(source);
|
|
1317
1361
|
const overridePrompt = runOverrides?.prompt;
|
|
1362
|
+
assertPromptContract(overridePrompt, templateType, 'runOverrides.prompt');
|
|
1318
1363
|
if (typeof overridePrompt === 'string' && overridePrompt.trim()) {
|
|
1319
1364
|
return overridePrompt.trim();
|
|
1320
1365
|
}
|
|
@@ -1323,6 +1368,7 @@ function resolveGenerationPromptInput(runtimeConfig, runOverrides = null, source
|
|
|
1323
1368
|
}
|
|
1324
1369
|
|
|
1325
1370
|
const runtimePrompt = runtimeConfig?.inference?.prompt;
|
|
1371
|
+
assertPromptContract(runtimePrompt, templateType, 'runtimeConfig.inference.prompt');
|
|
1326
1372
|
if (shouldPreferModelDefaultPrompt(runtimePrompt, templateType)) {
|
|
1327
1373
|
return buildDefaultGenerationPrompt(templateType);
|
|
1328
1374
|
}
|
|
@@ -52,6 +52,18 @@ function generateLatents(width, height, channels, latentScale, seed) {
|
|
|
52
52
|
return { latents, latentWidth, latentHeight };
|
|
53
53
|
}
|
|
54
54
|
|
|
55
|
+
function generateNoiseVector(size, seed) {
|
|
56
|
+
if (!Number.isFinite(size) || size <= 0) {
|
|
57
|
+
throw new Error(`generateNoiseVector requires a positive size, got ${size}.`);
|
|
58
|
+
}
|
|
59
|
+
const out = new Float32Array(size);
|
|
60
|
+
const rand = createRng(seed ?? createRandomSeed());
|
|
61
|
+
for (let i = 0; i < size; i++) {
|
|
62
|
+
out[i] = sampleNormal(rand);
|
|
63
|
+
}
|
|
64
|
+
return out;
|
|
65
|
+
}
|
|
66
|
+
|
|
55
67
|
function extractTokenSet(tokensByEncoder, key) {
|
|
56
68
|
const output = {};
|
|
57
69
|
for (const [name, entry] of Object.entries(tokensByEncoder || {})) {
|
|
@@ -195,13 +207,10 @@ async function applySchedulerStep(latentsTensor, scheduler, stepIndex, timestep,
|
|
|
195
207
|
const isFinalStep = stepIndex + 1 >= scheduler.timesteps.length - 1;
|
|
196
208
|
const noise = isFinalStep
|
|
197
209
|
? null
|
|
198
|
-
:
|
|
199
|
-
|
|
200
|
-
runtime.latent.height,
|
|
201
|
-
runtime.latent.channels,
|
|
202
|
-
runtime.latent.scale,
|
|
210
|
+
: generateNoiseVector(
|
|
211
|
+
sample.length,
|
|
203
212
|
(options.seedBase ?? createRandomSeed()) + stepIndex + 1
|
|
204
|
-
)
|
|
213
|
+
);
|
|
205
214
|
const step = stepScmScheduler(scheduler, modelOutput, timestep, sample, stepIndex, noise);
|
|
206
215
|
return createLatentTensor(step.prevSample, [...latentsTensor.shape], runtime);
|
|
207
216
|
}
|
|
@@ -80,3 +80,8 @@ export declare function projectContext(
|
|
|
80
80
|
): Promise<Tensor>;
|
|
81
81
|
|
|
82
82
|
export declare function assertClipHiddenActivationSupported(config: { hidden_act?: string }): void;
|
|
83
|
+
|
|
84
|
+
export declare function resolveGemma2WeightRoot(
|
|
85
|
+
weights: Map<string, any>,
|
|
86
|
+
prefix?: string
|
|
87
|
+
): string;
|
|
@@ -723,8 +723,19 @@ function buildGemma2LayerTypes(layerCount, slidingWindow) {
|
|
|
723
723
|
));
|
|
724
724
|
}
|
|
725
725
|
|
|
726
|
-
function
|
|
727
|
-
const
|
|
726
|
+
export function resolveGemma2WeightRoot(weights, prefix = 'text_encoder') {
|
|
727
|
+
const nestedRoot = `${prefix}.model`;
|
|
728
|
+
if (weights?.has(`${nestedRoot}.embed_tokens.weight`)) {
|
|
729
|
+
return nestedRoot;
|
|
730
|
+
}
|
|
731
|
+
if (weights?.has(`${prefix}.embed_tokens.weight`)) {
|
|
732
|
+
return prefix;
|
|
733
|
+
}
|
|
734
|
+
return nestedRoot;
|
|
735
|
+
}
|
|
736
|
+
|
|
737
|
+
function getGemma2LayerWeight(weights, weightRoot, layerIdx, suffix, required = true) {
|
|
738
|
+
const key = `${weightRoot}.layers.${layerIdx}.${suffix}`;
|
|
728
739
|
const weight = weights.get(key) || null;
|
|
729
740
|
if (!weight && required) {
|
|
730
741
|
throw new Error(`Missing Gemma2 diffusion weight "${key}".`);
|
|
@@ -805,8 +816,9 @@ async function runGemma2TextEncoder(tokens, weightsEntry, config, runtime, optio
|
|
|
805
816
|
const tokenIds = normalizeTokens(tokens, options.maxLength ?? resolved.maxPositionEmbeddings, padTokenId);
|
|
806
817
|
const numTokens = tokenIds.length;
|
|
807
818
|
const tokenBuffer = createDiffusionIndexBuffer(device, tokenIds, `${prefix}_tokens`);
|
|
819
|
+
const weightRoot = resolveGemma2WeightRoot(weights, prefix);
|
|
808
820
|
|
|
809
|
-
const embedKey = `${
|
|
821
|
+
const embedKey = `${weightRoot}.embed_tokens.weight`;
|
|
810
822
|
const embedWeight = expectDiffusionWeight(
|
|
811
823
|
weights.get(embedKey),
|
|
812
824
|
embedKey
|
|
@@ -837,16 +849,16 @@ async function runGemma2TextEncoder(tokens, weightsEntry, config, runtime, optio
|
|
|
837
849
|
const layerWeights = new Map();
|
|
838
850
|
for (let layerIdx = 0; layerIdx < resolved.numLayers; layerIdx++) {
|
|
839
851
|
layerWeights.set(`layer_${layerIdx}`, {
|
|
840
|
-
inputNorm: getGemma2LayerWeight(weights,
|
|
841
|
-
qProj: getGemma2LayerWeight(weights,
|
|
842
|
-
kProj: getGemma2LayerWeight(weights,
|
|
843
|
-
vProj: getGemma2LayerWeight(weights,
|
|
844
|
-
oProj: getGemma2LayerWeight(weights,
|
|
845
|
-
postAttentionNorm: getGemma2LayerWeight(weights,
|
|
846
|
-
preFeedforwardNorm: getGemma2LayerWeight(weights,
|
|
847
|
-
gate: getGemma2LayerWeight(weights,
|
|
848
|
-
up: getGemma2LayerWeight(weights,
|
|
849
|
-
down: getGemma2LayerWeight(weights,
|
|
852
|
+
inputNorm: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'input_layernorm.weight'),
|
|
853
|
+
qProj: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'self_attn.q_proj.weight'),
|
|
854
|
+
kProj: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'self_attn.k_proj.weight'),
|
|
855
|
+
vProj: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'self_attn.v_proj.weight'),
|
|
856
|
+
oProj: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'self_attn.o_proj.weight'),
|
|
857
|
+
postAttentionNorm: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'post_attention_layernorm.weight'),
|
|
858
|
+
preFeedforwardNorm: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'pre_feedforward_layernorm.weight'),
|
|
859
|
+
gate: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'mlp.gate_proj.weight'),
|
|
860
|
+
up: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'mlp.up_proj.weight'),
|
|
861
|
+
down: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'mlp.down_proj.weight'),
|
|
850
862
|
});
|
|
851
863
|
}
|
|
852
864
|
|
|
@@ -910,10 +922,10 @@ async function runGemma2TextEncoder(tokens, weightsEntry, config, runtime, optio
|
|
|
910
922
|
numTokens * resolved.hiddenSize,
|
|
911
923
|
context
|
|
912
924
|
);
|
|
913
|
-
hidden = createTensor(output
|
|
925
|
+
hidden = createTensor(output, activationDtype, [numTokens, resolved.hiddenSize], `gemma2_layer_${layerIdx}`);
|
|
914
926
|
}
|
|
915
927
|
|
|
916
|
-
const finalNormKey = `${
|
|
928
|
+
const finalNormKey = `${weightRoot}.norm.weight`;
|
|
917
929
|
const finalNorm = expectDiffusionWeight(weights.get(finalNormKey), finalNormKey);
|
|
918
930
|
const final = await ops.rmsNorm(hidden, getBuffer(finalNorm), resolved.rmsNormEps, {
|
|
919
931
|
batchSize: numTokens,
|
|
@@ -182,10 +182,18 @@ export async function recordLayerAttentionGPU(
|
|
|
182
182
|
// 3. RoPE (modifies tensor in-place)
|
|
183
183
|
if (!disableRoPE && state.ropeFreqsCos && state.ropeFreqsSin) {
|
|
184
184
|
await recordRoPE(recorder, qTensor, state.ropeFreqsCos, state.ropeFreqsSin, numTokens, {
|
|
185
|
-
numHeads,
|
|
185
|
+
numHeads,
|
|
186
|
+
headDim,
|
|
187
|
+
rotaryDim: config.ropeRotaryDim,
|
|
188
|
+
interleaved: config.ropeInterleaved,
|
|
189
|
+
startPos: currentSeqLen,
|
|
186
190
|
});
|
|
187
191
|
await recordRoPE(recorder, kTensor, state.ropeFreqsCos, state.ropeFreqsSin, numTokens, {
|
|
188
|
-
numHeads: numKVHeads,
|
|
192
|
+
numHeads: numKVHeads,
|
|
193
|
+
headDim,
|
|
194
|
+
rotaryDim: config.ropeRotaryDim,
|
|
195
|
+
interleaved: config.ropeInterleaved,
|
|
196
|
+
startPos: currentSeqLen,
|
|
189
197
|
});
|
|
190
198
|
}
|
|
191
199
|
|
|
@@ -502,6 +510,7 @@ export async function recordLayerAttentionGPU(
|
|
|
502
510
|
size: numTokens * numHeads * headDim,
|
|
503
511
|
gate: qGateTensor,
|
|
504
512
|
gateActivation: 'sigmoid',
|
|
513
|
+
inputActivation: 'identity',
|
|
505
514
|
swigluLimit: null,
|
|
506
515
|
});
|
|
507
516
|
recorder.trackTemporaryBuffer(attnOutput.buffer);
|
|
@@ -299,10 +299,18 @@ export async function runLayerAttentionGPU(
|
|
|
299
299
|
|
|
300
300
|
if (!disableRoPE && state.ropeFreqsCos && state.ropeFreqsSin) {
|
|
301
301
|
await runRoPE(qTensor, state.ropeFreqsCos, state.ropeFreqsSin, numTokens, {
|
|
302
|
-
numHeads,
|
|
302
|
+
numHeads,
|
|
303
|
+
headDim,
|
|
304
|
+
rotaryDim: config.ropeRotaryDim,
|
|
305
|
+
interleaved: config.ropeInterleaved,
|
|
306
|
+
startPos: currentSeqLen,
|
|
303
307
|
});
|
|
304
308
|
await runRoPE(kTensor, state.ropeFreqsCos, state.ropeFreqsSin, numTokens, {
|
|
305
|
-
numHeads: numKVHeads,
|
|
309
|
+
numHeads: numKVHeads,
|
|
310
|
+
headDim,
|
|
311
|
+
rotaryDim: config.ropeRotaryDim,
|
|
312
|
+
interleaved: config.ropeInterleaved,
|
|
313
|
+
startPos: currentSeqLen,
|
|
306
314
|
});
|
|
307
315
|
|
|
308
316
|
// Trace RoPE outputs
|
|
@@ -690,6 +698,7 @@ export async function runLayerAttentionGPU(
|
|
|
690
698
|
size: numTokens * numHeads * headDim,
|
|
691
699
|
gate: qGateTensor,
|
|
692
700
|
gateActivation: 'sigmoid',
|
|
701
|
+
inputActivation: 'identity',
|
|
693
702
|
swigluLimit: null,
|
|
694
703
|
});
|
|
695
704
|
releaseBuffer(attnOutput.buffer);
|
|
@@ -224,6 +224,29 @@ function formatChatML(messages) {
|
|
|
224
224
|
return parts.join('');
|
|
225
225
|
}
|
|
226
226
|
|
|
227
|
+
function formatQwen(messages) {
|
|
228
|
+
// Qwen 3.5 chat format is ChatML-like, but the generation prelude includes
|
|
229
|
+
// an explicit empty thinking block before assistant output.
|
|
230
|
+
const parts = [];
|
|
231
|
+
for (const [index, message] of messages.entries()) {
|
|
232
|
+
const role = normalizeChatRole(message?.role);
|
|
233
|
+
assertSupportedChatRole(role, 'Qwen', index);
|
|
234
|
+
if (role === 'system' && index !== 0) {
|
|
235
|
+
throw new Error('Qwen template requires any system message to appear first.');
|
|
236
|
+
}
|
|
237
|
+
const content = normalizeChatMessageContent(message?.content);
|
|
238
|
+
if (role === 'system') {
|
|
239
|
+
parts.push(`<|im_start|>system\n${content}<|im_end|>\n`);
|
|
240
|
+
} else if (role === 'user') {
|
|
241
|
+
parts.push(`<|im_start|>user\n${content}<|im_end|>\n`);
|
|
242
|
+
} else if (role === 'assistant') {
|
|
243
|
+
parts.push(`<|im_start|>assistant\n${content}<|im_end|>\n`);
|
|
244
|
+
}
|
|
245
|
+
}
|
|
246
|
+
parts.push('<|im_start|>assistant\n<think>\n\n</think>\n\n');
|
|
247
|
+
return parts.join('');
|
|
248
|
+
}
|
|
249
|
+
|
|
227
250
|
function formatTranslateGemmaUserPrompt(content) {
|
|
228
251
|
if (!Array.isArray(content) || content.length !== 1) {
|
|
229
252
|
throw new Error(
|
|
@@ -345,7 +368,7 @@ const CHAT_FORMATTERS = {
|
|
|
345
368
|
'llama3': formatHeaderBased,
|
|
346
369
|
'gpt-oss': formatChannelBased,
|
|
347
370
|
'chatml': formatChatML,
|
|
348
|
-
'qwen':
|
|
371
|
+
'qwen': formatQwen,
|
|
349
372
|
'translategemma': formatTranslateGemma,
|
|
350
373
|
};
|
|
351
374
|
|
|
@@ -363,4 +386,5 @@ export function formatChatMessages(messages, templateType) {
|
|
|
363
386
|
export const formatGemmaChat = formatTurnBased;
|
|
364
387
|
export const formatLlama3Chat = formatHeaderBased;
|
|
365
388
|
export const formatGptOssChat = formatChannelBased;
|
|
389
|
+
export const formatQwenChat = formatQwen;
|
|
366
390
|
export const formatTranslateGemmaChat = formatTranslateGemma;
|
|
@@ -148,6 +148,10 @@ export interface ParsedModelConfig {
|
|
|
148
148
|
slidingWindow: number | null;
|
|
149
149
|
ropeTheta: number;
|
|
150
150
|
ropeLocalTheta: number | null;
|
|
151
|
+
ropeRotaryDim: number;
|
|
152
|
+
ropeInterleaved: boolean;
|
|
153
|
+
mropeSection: number[] | null;
|
|
154
|
+
partialRotaryFactor: number | null;
|
|
151
155
|
ropeScale: number;
|
|
152
156
|
ropeLocalScale: number;
|
|
153
157
|
ropeScalingType: string | null;
|
|
@@ -21,6 +21,28 @@ function assertSupportedRuntimeModelType(manifest) {
|
|
|
21
21
|
);
|
|
22
22
|
}
|
|
23
23
|
|
|
24
|
+
function resolveRotaryDim(headDim, partialRotaryFactor, modelId) {
|
|
25
|
+
if (partialRotaryFactor == null) {
|
|
26
|
+
return headDim;
|
|
27
|
+
}
|
|
28
|
+
if (typeof partialRotaryFactor !== 'number' || Number.isNaN(partialRotaryFactor)) {
|
|
29
|
+
throw new Error(`Manifest "${modelId}" has invalid rope.partialRotaryFactor.`);
|
|
30
|
+
}
|
|
31
|
+
if (partialRotaryFactor <= 0 || partialRotaryFactor > 1) {
|
|
32
|
+
throw new Error(
|
|
33
|
+
`Manifest "${modelId}" requires 0 < rope.partialRotaryFactor <= 1; got ${partialRotaryFactor}.`
|
|
34
|
+
);
|
|
35
|
+
}
|
|
36
|
+
const rotaryDim = Math.trunc(headDim * partialRotaryFactor);
|
|
37
|
+
if (rotaryDim <= 0 || (rotaryDim % 2) !== 0) {
|
|
38
|
+
throw new Error(
|
|
39
|
+
`Manifest "${modelId}" resolves rope rotary dim ${rotaryDim} from headDim=${headDim} ` +
|
|
40
|
+
`and partialRotaryFactor=${partialRotaryFactor}, but rotary dim must be a positive even integer.`
|
|
41
|
+
);
|
|
42
|
+
}
|
|
43
|
+
return rotaryDim;
|
|
44
|
+
}
|
|
45
|
+
|
|
24
46
|
export function getStopTokenIds(manifest) {
|
|
25
47
|
const eosTokenId = manifest?.eos_token_id;
|
|
26
48
|
if (Array.isArray(eosTokenId)) return eosTokenId;
|
|
@@ -130,7 +152,14 @@ export function hasManifestInference(manifest) {
|
|
|
130
152
|
|
|
131
153
|
|
|
132
154
|
export function validateRequiredInferenceFields(inf, modelId) {
|
|
133
|
-
|
|
155
|
+
inf = inf ?? {};
|
|
156
|
+
inf.attention = inf.attention ?? {};
|
|
157
|
+
inf.normalization = inf.normalization ?? {};
|
|
158
|
+
inf.ffn = inf.ffn ?? {};
|
|
159
|
+
inf.rope = inf.rope ?? {};
|
|
160
|
+
inf.output = inf.output ?? {};
|
|
161
|
+
inf.layerPattern = inf.layerPattern ?? {};
|
|
162
|
+
inf.chatTemplate = inf.chatTemplate ?? {};
|
|
134
163
|
const errors = [];
|
|
135
164
|
|
|
136
165
|
// Attention fields - non-nullable required
|
|
@@ -201,6 +230,20 @@ export function validateRequiredInferenceFields(inf, modelId) {
|
|
|
201
230
|
if (inf.rope.ropeLocalTheta === undefined) {
|
|
202
231
|
errors.push('rope.ropeLocalTheta must be explicitly set (null for no local theta, or number)');
|
|
203
232
|
}
|
|
233
|
+
if (inf.rope.mropeInterleaved == null) {
|
|
234
|
+
errors.push('rope.mropeInterleaved is required');
|
|
235
|
+
}
|
|
236
|
+
if (inf.rope.mropeSection === undefined) {
|
|
237
|
+
errors.push('rope.mropeSection must be explicitly set (null when unused, or an array of positive integers)');
|
|
238
|
+
}
|
|
239
|
+
if (inf.rope.partialRotaryFactor === undefined) {
|
|
240
|
+
errors.push('rope.partialRotaryFactor must be explicitly set (null when unused, or a number in (0, 1])');
|
|
241
|
+
} else {
|
|
242
|
+
const factor = inf.rope.partialRotaryFactor;
|
|
243
|
+
if (factor !== null && (typeof factor !== 'number' || Number.isNaN(factor) || factor <= 0 || factor > 1)) {
|
|
244
|
+
errors.push('rope.partialRotaryFactor must be a number in (0, 1] or null');
|
|
245
|
+
}
|
|
246
|
+
}
|
|
204
247
|
|
|
205
248
|
// Output fields - non-nullable required
|
|
206
249
|
if (inf.output.tieWordEmbeddings == null) {
|
|
@@ -458,6 +501,26 @@ export function toParsedConfigFromMerged(merged, manifest) {
|
|
|
458
501
|
const ropeScalingType = inf.rope.ropeScalingType;
|
|
459
502
|
const ropeLocalScale = inf.rope.ropeLocalScalingFactor ?? ropeScale;
|
|
460
503
|
const ropeLocalScalingType = inf.rope.ropeLocalScalingType ?? ropeScalingType;
|
|
504
|
+
const partialRotaryFactor = inf.rope.partialRotaryFactor;
|
|
505
|
+
const ropeInterleaved = inf.rope.mropeInterleaved === true;
|
|
506
|
+
const mropeSection = Array.isArray(inf.rope.mropeSection)
|
|
507
|
+
? inf.rope.mropeSection.map((entry) => Math.trunc(Number(entry)))
|
|
508
|
+
: null;
|
|
509
|
+
const ropeRotaryDim = resolveRotaryDim(arch.headDim, partialRotaryFactor, merged.modelId);
|
|
510
|
+
if (mropeSection && mropeSection.some((entry) => !Number.isFinite(entry) || entry <= 0)) {
|
|
511
|
+
throw new Error(
|
|
512
|
+
`Manifest "${merged.modelId}" has invalid rope.mropeSection; expected positive integers.`
|
|
513
|
+
);
|
|
514
|
+
}
|
|
515
|
+
if (ropeInterleaved && mropeSection) {
|
|
516
|
+
const doubledMropeDim = mropeSection.reduce((sum, entry) => sum + entry, 0) * 2;
|
|
517
|
+
if (doubledMropeDim !== ropeRotaryDim) {
|
|
518
|
+
throw new Error(
|
|
519
|
+
`Manifest "${merged.modelId}" declares rope.mropeSection=${JSON.stringify(mropeSection)}, ` +
|
|
520
|
+
`which expands to rotary dim ${doubledMropeDim}, but the resolved rotary dim is ${ropeRotaryDim}.`
|
|
521
|
+
);
|
|
522
|
+
}
|
|
523
|
+
}
|
|
461
524
|
|
|
462
525
|
// Build ropeScaling object from manifest values if scaling is enabled
|
|
463
526
|
// Include YARN params when present
|
|
@@ -532,6 +595,10 @@ export function toParsedConfigFromMerged(merged, manifest) {
|
|
|
532
595
|
slidingWindow: inf.attention.slidingWindow,
|
|
533
596
|
ropeTheta: inf.rope.ropeTheta,
|
|
534
597
|
ropeLocalTheta: inf.rope.ropeLocalTheta,
|
|
598
|
+
ropeRotaryDim,
|
|
599
|
+
ropeInterleaved,
|
|
600
|
+
mropeSection,
|
|
601
|
+
partialRotaryFactor,
|
|
535
602
|
ropeScale,
|
|
536
603
|
ropeLocalScale,
|
|
537
604
|
ropeScalingType,
|