@tryhamster/gerbil 1.0.0-rc.9 → 1.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +1 -1
- package/README.md +318 -104
- package/dist/architectures-C1I5V3Dt.mjs +6070 -0
- package/dist/architectures-C1I5V3Dt.mjs.map +1 -0
- package/dist/browser/index.d.ts +276 -590
- package/dist/browser/index.d.ts.map +1 -1
- package/dist/browser/index.js +592 -2334
- package/dist/browser/index.js.map +1 -1
- package/dist/cli.mjs +625 -1098
- package/dist/cli.mjs.map +1 -1
- package/dist/defaults-9komdrbY.mjs +24 -0
- package/dist/defaults-9komdrbY.mjs.map +1 -0
- package/dist/frameworks/express.d.mts +1 -3
- package/dist/frameworks/express.d.mts.map +1 -1
- package/dist/frameworks/express.mjs +7 -7
- package/dist/frameworks/express.mjs.map +1 -1
- package/dist/frameworks/fastify.d.mts +1 -1
- package/dist/frameworks/fastify.d.mts.map +1 -1
- package/dist/frameworks/fastify.mjs +3 -3
- package/dist/frameworks/fastify.mjs.map +1 -1
- package/dist/frameworks/hono.d.mts +1 -1
- package/dist/frameworks/hono.d.mts.map +1 -1
- package/dist/frameworks/hono.mjs +4 -4
- package/dist/frameworks/hono.mjs.map +1 -1
- package/dist/frameworks/next.d.mts +3 -2
- package/dist/frameworks/next.d.mts.map +1 -1
- package/dist/frameworks/next.mjs +4 -4
- package/dist/frameworks/next.mjs.map +1 -1
- package/dist/frameworks/react.d.mts +1 -1
- package/dist/frameworks/trpc.d.mts +1 -1
- package/dist/frameworks/trpc.d.mts.map +1 -1
- package/dist/frameworks/trpc.mjs +4 -4
- package/dist/frameworks/trpc.mjs.map +1 -1
- package/dist/gerbil-BetB5xb0.d.mts +488 -0
- package/dist/gerbil-BetB5xb0.d.mts.map +1 -0
- package/dist/gerbil-CTZUa8EZ.mjs +4 -0
- package/dist/gerbil-DNniplr4.mjs +1656 -0
- package/dist/gerbil-DNniplr4.mjs.map +1 -0
- package/dist/gpu/hooks.d.mts +640 -0
- package/dist/gpu/hooks.d.mts.map +1 -0
- package/dist/gpu/hooks.mjs +1369 -0
- package/dist/gpu/hooks.mjs.map +1 -0
- package/dist/gpu/index.d.mts +2 -0
- package/dist/gpu/index.mjs +6 -0
- package/dist/gpu-DFuglcEx.mjs +3790 -0
- package/dist/gpu-DFuglcEx.mjs.map +1 -0
- package/dist/index-Dgmb2kE3.d.mts +245 -0
- package/dist/index-Dgmb2kE3.d.mts.map +1 -0
- package/dist/index-DukkJRMj.d.mts +2114 -0
- package/dist/index-DukkJRMj.d.mts.map +1 -0
- package/dist/index.d.mts +22 -487
- package/dist/index.d.mts.map +1 -1
- package/dist/index.mjs +13 -8
- package/dist/index.mjs.map +1 -1
- package/dist/indexeddb-store-BWIMtxxH.mjs +103 -0
- package/dist/indexeddb-store-BWIMtxxH.mjs.map +1 -0
- package/dist/indexeddb-store-ClH12Xnl.mjs +4 -0
- package/dist/integrations/ai-sdk.d.mts +75 -6
- package/dist/integrations/ai-sdk.d.mts.map +1 -1
- package/dist/integrations/ai-sdk.mjs +131 -15
- package/dist/integrations/ai-sdk.mjs.map +1 -1
- package/dist/integrations/langchain.d.mts +1 -1
- package/dist/integrations/langchain.d.mts.map +1 -1
- package/dist/integrations/langchain.mjs +5 -5
- package/dist/integrations/langchain.mjs.map +1 -1
- package/dist/integrations/llamaindex.d.mts +1 -1
- package/dist/integrations/llamaindex.d.mts.map +1 -1
- package/dist/integrations/llamaindex.mjs +5 -5
- package/dist/integrations/llamaindex.mjs.map +1 -1
- package/dist/integrations/mcp-client.mjs +3 -3
- package/dist/integrations/mcp-client.mjs.map +1 -1
- package/dist/integrations/mcp.d.mts +3 -2
- package/dist/integrations/mcp.d.mts.map +1 -1
- package/dist/integrations/mcp.mjs +5 -5
- package/dist/{mcp-BvbriaBy.mjs → mcp-D2vvH1Xc.mjs} +4 -4
- package/dist/mcp-D2vvH1Xc.mjs.map +1 -0
- package/dist/memory/index.d.mts +3 -0
- package/dist/memory/index.mjs +6 -0
- package/dist/memory-D1P7Tmda.mjs +4 -0
- package/dist/memory-DVN0MnIG.mjs +132 -0
- package/dist/memory-DVN0MnIG.mjs.map +1 -0
- package/dist/memory-Dj0J1v88.mjs +294 -0
- package/dist/memory-Dj0J1v88.mjs.map +1 -0
- package/dist/moonshine-stt-17dpP1kr.mjs +4 -0
- package/dist/moonshine-stt-4ojLtMq7.mjs +11962 -0
- package/dist/moonshine-stt-4ojLtMq7.mjs.map +1 -0
- package/dist/{one-liner-s-lD8rCC.mjs → one-liner-JhdIPxzF.mjs} +14 -16
- package/dist/one-liner-JhdIPxzF.mjs.map +1 -0
- package/dist/repl-BDRkwPGX.mjs +9 -0
- package/dist/skills/index.d.mts +270 -320
- package/dist/skills/index.d.mts.map +1 -1
- package/dist/skills/index.mjs +5 -5
- package/dist/{skills-CD3Orlex.mjs → skills-CU694Dc8.mjs} +187 -32
- package/dist/skills-CU694Dc8.mjs.map +1 -0
- package/dist/{tools-Bi1P7Xoy.mjs → tools-DQ1mPUw5.mjs} +34 -22
- package/dist/tools-DQ1mPUw5.mjs.map +1 -0
- package/dist/types-DQBe2lFo.d.mts +165 -0
- package/dist/types-DQBe2lFo.d.mts.map +1 -0
- package/dist/{types-CiTc7ez3.d.mts → types-LlyYILII.d.mts} +112 -14
- package/dist/types-LlyYILII.d.mts.map +1 -0
- package/dist/{utils-CZBZ8dgR.mjs → utils-DKO55ZmZ.mjs} +1 -1
- package/dist/{utils-CZBZ8dgR.mjs.map → utils-DKO55ZmZ.mjs.map} +1 -1
- package/dist/vector-B0panuy6.mjs +95 -0
- package/dist/vector-B0panuy6.mjs.map +1 -0
- package/docs/PROJECT-STATE.md +321 -0
- package/docs/adding-a-model-family.md +280 -0
- package/docs/ai-sdk.md +70 -61
- package/docs/architecture/overview.md +17 -7
- package/docs/browser.md +203 -8
- package/docs/embeddings.md +156 -0
- package/docs/gerbil-site-native-migration.md +217 -0
- package/docs/gpu-engine/architectures.md +398 -0
- package/docs/gpu-engine/ir.md +372 -0
- package/docs/gpu-engine/kernels.md +718 -0
- package/docs/gpu-engine/paper.html +1759 -0
- package/docs/gpu-engine/paper.md +2109 -0
- package/docs/gpu-engine/safetensors.md +312 -0
- package/docs/gpu-engine/tokenizer.md +302 -0
- package/docs/memory-rag.md +91 -0
- package/docs/metal-safari-intel.md +190 -0
- package/docs/mobile-failure-diagnosis.md +124 -0
- package/docs/mobile.md +99 -0
- package/docs/observability.md +230 -0
- package/docs/onnx-removal-plan.md +339 -0
- package/docs/research/autoresearch-portable.md +904 -0
- package/docs/research/dispatch-reduction-hivemind.md +84 -0
- package/docs/research/ios-safari-model-caching.md +117 -0
- package/docs/research/mobile-webgpu-speed-fusion.md +135 -0
- package/docs/research/native-stt-model-selection.md +49 -0
- package/docs/research/native-tts-model-selection.md +90 -0
- package/docs/research/native-vs-chromium-decision.md +152 -0
- package/docs/research/nemotron-mamba2-inference.md +910 -0
- package/docs/research/qwen35-multimodal.md +293 -0
- package/docs/research/qwen36-gemma4-targets.md +337 -0
- package/docs/research/sota-embedding-models.md +179 -0
- package/docs/research/sota-mobile-models-2026.md +263 -0
- package/docs/research/sota-modality-models.md +202 -0
- package/docs/research/tps-baselines.md +71 -0
- package/docs/research/webgpu-m4-reference.md +104 -0
- package/docs/site-update-plan.md +155 -0
- package/docs/structured-output.md +123 -0
- package/docs/stt.md +63 -446
- package/docs/tts.md +77 -499
- package/docs/vision.md +100 -338
- package/package.json +22 -7
- package/dist/chrome-backend-CORwaIyC.mjs +0 -1212
- package/dist/chrome-backend-CORwaIyC.mjs.map +0 -1
- package/dist/chrome-backend-DIKYoWj-.mjs +0 -3
- package/dist/gerbil-CJ3ifloF.mjs +0 -4
- package/dist/gerbil-Dw4Qj77e.mjs +0 -1631
- package/dist/gerbil-Dw4Qj77e.mjs.map +0 -1
- package/dist/gerbil-qOTe1nl2.d.mts +0 -431
- package/dist/gerbil-qOTe1nl2.d.mts.map +0 -1
- package/dist/kokoro-BNTb6egA.mjs +0 -20210
- package/dist/kokoro-BNTb6egA.mjs.map +0 -1
- package/dist/kokoro-CMOGDSgT.js +0 -20212
- package/dist/kokoro-CMOGDSgT.js.map +0 -1
- package/dist/mcp-BvbriaBy.mjs.map +0 -1
- package/dist/one-liner-s-lD8rCC.mjs.map +0 -1
- package/dist/repl-DveXw36T.mjs +0 -9
- package/dist/skills-CD3Orlex.mjs.map +0 -1
- package/dist/stt-Bu-E23Sc.js +0 -433
- package/dist/stt-Bu-E23Sc.js.map +0 -1
- package/dist/stt-CpLYbGFd.mjs +0 -433
- package/dist/stt-CpLYbGFd.mjs.map +0 -1
- package/dist/stt-DRPLEEHB.mjs +0 -3
- package/dist/tools-Bi1P7Xoy.mjs.map +0 -1
- package/dist/transformers.web-DiD1gTwk.js +0 -44695
- package/dist/transformers.web-DiD1gTwk.js.map +0 -1
- package/dist/transformers.web-u34VxRFM.js +0 -3
- package/dist/tts-CqroPaSK.js +0 -724
- package/dist/tts-CqroPaSK.js.map +0 -1
- package/dist/tts-DXgsKGCe.mjs +0 -3
- package/dist/tts-DeGANMNV.mjs +0 -730
- package/dist/tts-DeGANMNV.mjs.map +0 -1
- package/dist/types-CiTc7ez3.d.mts.map +0 -1
- /package/dist/{auto-update-S9s5-g0C.mjs → auto-update-BVaLXcDE.mjs} +0 -0
- /package/dist/{chunk-CkXuGtQK.mjs → chunk-B9cbKln6.mjs} +0 -0
- /package/dist/{microphone-DaMZFRuR.mjs → microphone-Bqmoz9_K.mjs} +0 -0
|
@@ -0,0 +1,3790 @@
|
|
|
1
|
+
import { n as resolveDefaultRepo } from "./defaults-9komdrbY.mjs";
|
|
2
|
+
import { C as parseKaniConfig, E as GEMMA4_VIS_KEYS, T as DTYPE_BYTES, _ as generateNanoCodecDecoderGraph, b as kaniLayerAlpha, d as KANI_START_OF_HUMAN, f as audioTokensToCodes, g as generateKaniTtsGraph, m as computeKaniPositions, p as buildKaniLayerCosSin, u as KANI_END_OF_HUMAN, v as kaniAttentionLayerIndices, w as CANONICAL_KEYS, x as kaniSinTensor, y as kaniCosTensor } from "./architectures-C1I5V3Dt.mjs";
|
|
3
|
+
import { c as MATMUL_BIAS_F16C_SPEC, d as createStorageBuffer, f as createUniformBuffer, g as verifyGPU, h as initGPU, i as loadModel, l as clearPipelineCache, m as getOrCreatePipeline, o as Executor, p as destroyBuffers, r as loadKaniTTS, s as KERNEL_REGISTRY, u as createBindGroup } from "./moonshine-stt-4ojLtMq7.mjs";
|
|
4
|
+
|
|
5
|
+
//#region src/gpu/architectures/gemma4_vision.ts
|
|
6
|
+
/**
 * Resolve the Gemma 4 vision-tower dimensions from a raw HF config.
 *
 * Accepts either a top-level multimodal config (dims are read from
 * `.vision_config`, the text width from `.text_config.hidden_size`) or a bare
 * vision_config object. When no `text_config.hidden_size` is available,
 * `textHidden` falls back to the vision `hidden_size`. Family-general — no E2B
 * constants.
 *
 * Fallbacks for absent fields: `head_dim` = floor(hidden_size / num_attention_heads),
 * `num_channels` = 3, `pooling_kernel_size` = 1, `rope_parameters.rope_theta` = 100,
 * `rms_norm_eps` = 1e-6.
 *
 * @param {object} rawConfig - top-level HF config or a bare vision_config.
 * @returns {{hiddenSize: number, numHeads: number, headDim: number, depth: number,
 *   intermediateSize: number, textHidden: number, patchSize: number, patchDim: number,
 *   poolingKernelSize: number, ropeTheta: number, rmsNormEps: number}}
 */
function resolveGemma4VisionInfo(rawConfig) {
	const vision = rawConfig.vision_config ?? rawConfig;
	const text = rawConfig.text_config ?? {};
	const hiddenSize = vision.hidden_size;
	const numHeads = vision.num_attention_heads;
	const patchSize = vision.patch_size;
	const channels = vision.num_channels ?? 3;
	const ropeParams = vision.rope_parameters ?? {};
	return {
		hiddenSize,
		numHeads,
		headDim: vision.head_dim ?? Math.floor(hiddenSize / numHeads),
		depth: vision.num_hidden_layers,
		intermediateSize: vision.intermediate_size,
		textHidden: text.hidden_size ?? hiddenSize,
		patchSize,
		// Flattened patch length: channels × patch_h × patch_w.
		patchDim: channels * patchSize * patchSize,
		poolingKernelSize: vision.pooling_kernel_size ?? 1,
		ropeTheta: ropeParams.rope_theta ?? 100,
		rmsNormEps: vision.rms_norm_eps ?? 1e-6
	};
}
|
|
39
|
+
/**
 * Dequantize an MLX affine-int4 weight to a plain f32 [rows, cols] matrix.
 *
 * MLX packs 8 int4 values per u32 (low nibble first). Each group of `groupSize`
 * consecutive columns in a row shares one scale and one bias:
 *   w[r, c] = scales[r, c / groupSize] * q[r, c] + biases[r, c / groupSize]
 *
 * Used for the Gemma 4 multimodal projector (`embed_vision.embedding_projection`)
 * in MLX-4bit checkpoints, where (unlike the BF16 ViT body) the projector is int4.
 *
 * @param {Uint32Array} packed - row-major packed words, at least rows * cols / 8.
 * @param {ArrayLike<number>} scales - per-group scales, at least rows * (cols / groupSize).
 * @param {ArrayLike<number>} biases - per-group biases, at least rows * (cols / groupSize).
 * @param {number} rows - output rows.
 * @param {number} cols - unpacked output columns; must be a multiple of 8 and of `groupSize`.
 * @param {number} groupSize - columns per quantization group.
 * @returns {Float32Array} dequantized row-major [rows, cols] matrix.
 * @throws {RangeError} when the layout does not divide evenly or an input array is
 *   too short for the requested shape (previously this silently yielded NaN or
 *   bias-only values from out-of-range / fractional indices).
 */
function dequantizeMLXProjection(packed, scales, biases, rows, cols, groupSize) {
	if (!Number.isInteger(rows) || rows < 0) {
		throw new RangeError(`dequantizeMLXProjection: rows must be a non-negative integer, got ${rows}`);
	}
	if (!Number.isInteger(cols) || cols < 0 || cols % 8 !== 0) {
		throw new RangeError(`dequantizeMLXProjection: cols must be a non-negative multiple of 8, got ${cols}`);
	}
	if (!Number.isInteger(groupSize) || groupSize <= 0 || cols % groupSize !== 0) {
		throw new RangeError(`dequantizeMLXProjection: groupSize (${groupSize}) must be a positive integer dividing cols (${cols})`);
	}
	const numGroups = cols / groupSize;
	const u32PerRow = cols / 8;
	if (packed.length < rows * u32PerRow) {
		throw new RangeError(`dequantizeMLXProjection: packed has ${packed.length} words, expected ${rows * u32PerRow}`);
	}
	const groupCount = rows * numGroups;
	if (scales.length < groupCount || biases.length < groupCount) {
		throw new RangeError(`dequantizeMLXProjection: scales/biases need ${groupCount} entries, got ${scales.length}/${biases.length}`);
	}
	const out = new Float32Array(rows * cols);
	for (let r = 0; r < rows; r++) {
		const rowPacked = r * u32PerRow;
		const rowScale = r * numGroups;
		const rowOut = r * cols;
		for (let g = 0; g < numGroups; g++) {
			// Scale/bias are constant across the group, so read them once.
			const scale = scales[rowScale + g];
			const bias = biases[rowScale + g];
			const cEnd = (g + 1) * groupSize;
			for (let c = g * groupSize; c < cEnd; c++) {
				// Column c lives in word c/8, at nibble c%8 (low nibble first).
				const q = (packed[rowPacked + (c >> 3)] >>> ((c & 7) * 4)) & 0xf;
				out[rowOut + c] = scale * q + bias;
			}
		}
	}
	return out;
}
|
|
61
|
+
/**
 * Dequantize an MLX affine-int4 Gemma 4 multimodal projector in place.
 *
 * When the weights map holds the MLX triplet
 * `embed_vision.embedding_projection.{weight (Uint32Array), scales, biases}`,
 * the packed weight entry is replaced by a plain Float32Array of shape
 * [rows, cols] and the scales/biases entries are removed. This lets the vision
 * graph's plain MatMul on the projector also work for MLX-4bit checkpoints.
 *
 * If the weight is missing or not a Uint32Array, or if scales or biases are
 * missing, the map is left untouched (e.g. BF16/HF checkpoints).
 *
 * @param {Map<string, {data: *, shape: number[]}>} weights - Map-like store, mutated in place.
 * @param {number} groupSize - MLX quantization group size.
 * @param {number} rows - projector output rows.
 * @param {number} cols - projector (unpacked) input columns.
 */
function dequantizeGemma4VisionProjection(weights, groupSize, rows, cols) {
	const scalesKey = "embed_vision.embedding_projection.scales";
	const biasesKey = "embed_vision.embedding_projection.biases";
	const weightKey = GEMMA4_VIS_KEYS.projW;
	const packedEntry = weights.get(weightKey);
	const scalesEntry = weights.get(scalesKey);
	const biasesEntry = weights.get(biasesKey);
	const isPackedTriplet = packedEntry && packedEntry.data instanceof Uint32Array && scalesEntry && biasesEntry;
	if (!isPackedTriplet) return;
	// Non-f32 typed arrays are reinterpreted byte-for-byte as f32 (no numeric conversion).
	const viewAsF32 = ({ data }) => data instanceof Float32Array ? data : new Float32Array(data.buffer, data.byteOffset, data.byteLength / 4);
	const dense = dequantizeMLXProjection(packedEntry.data, viewAsF32(scalesEntry), viewAsF32(biasesEntry), rows, cols, groupSize);
	weights.set(weightKey, {
		data: dense,
		shape: [rows, cols]
	});
	weights.delete(scalesKey);
	weights.delete(biasesKey);
}
|
|
84
|
+
/**
 * Ensure the Gemma 4 multimodal-embedder norm gains exist in the weights map.
 *
 * - `GEMMA4_VIS_KEYS.postProjNormW`, sized `textHidden`. HF's
 *   `embed_vision.embedding_post_projection_norm` is a no-scale RMSNorm
 *   (Gemma3nRMSNorm(..., with_scale=False)), so checkpoints ship no weight for it.
 * - `GEMMA4_VIS_KEYS.softEmbNormW`, sized `visionHidden`.
 *   `embed_vision.soft_embedding_norm` is learned and present in the full BF16
 *   repo, but some lean exports (e.g. MLX-4bit conversions) omit it.
 *
 * Any missing gain is filled with ones. Gemma 4 vision RMSNorm uses the full
 * gain (no +1 offset), so a ones gain is a true identity scale. A model missing
 * the learned soft-embedding gain therefore loads and produces coherent, though
 * not bit-exact, output. Existing entries are never overwritten.
 * Call BEFORE VisionExecutor.uploadWeights().
 *
 * @param {Map<string, {data: Float32Array, shape: number[]}>} weights - Map-like store, mutated in place.
 * @param {number} visionHidden - vision hidden size (soft-embedding norm width).
 * @param {number} textHidden - text hidden size (post-projection norm width).
 */
function ensureGemma4VisionEmbedderNorms(weights, visionHidden, textHidden) {
	const required = [[GEMMA4_VIS_KEYS.postProjNormW, textHidden], [GEMMA4_VIS_KEYS.softEmbNormW, visionHidden]];
	for (const [key, width] of required) {
		if (weights.has(key)) continue;
		weights.set(key, {
			data: new Float32Array(width).fill(1),
			shape: [width]
		});
	}
}
|
|
107
|
+
/**
 * Patch the ClippedMatMul nodes of a Gemma 4 vision graph with calibrated clip
 * scalars from the checkpoint, then remove those scalar tensors.
 *
 * The scalars are Gemma4ClippableLinear's per-tensor input/output min/max
 * buffers. Removing them keeps the vision executor from uploading them as GPU
 * buffers.
 *
 * For every node with `opType === "ClippedMatMul"`:
 * - The first element of each tensor named by its `clip_*_key` attributes is
 *   read, and written to `imin`/`imax`/`omin`/`omax`.
 * - A key that is missing, unset, or whose tensor is empty leaves the
 *   corresponding attribute unchanged.
 * - All four referenced keys are then deleted from `weights`.
 *
 * Call BEFORE VisionExecutor.uploadWeights(). Per the graph builder, unset clips
 * default to ±inf (clip = identity), so a checkpoint without calibration still loads.
 *
 * @param {{nodes: Array<{opType: string, attributes: object}>}} graph - mutated in place.
 * @param {Map<string, {data: *}>} weights - Map-like store, mutated in place.
 */
function patchGemma4VisionClips(graph, weights) {
	/** First f32 of the tensor stored under `key`; undefined when key/tensor is absent or empty. */
	function scalarAt(key) {
		if (!key) return undefined;
		const entry = weights.get(key);
		if (!entry) return undefined;
		const { data } = entry;
		// Non-f32 typed arrays are reinterpreted byte-for-byte as f32.
		const floats = data instanceof Float32Array ? data : new Float32Array(data.buffer, data.byteOffset, data.byteLength / 4);
		return floats.length === 0 ? undefined : floats[0];
	}
	// [attribute holding the checkpoint key, attribute receiving the scalar]
	const clipFields = [
		["clip_input_min_key", "imin"],
		["clip_input_max_key", "imax"],
		["clip_output_min_key", "omin"],
		["clip_output_max_key", "omax"]
	];
	for (const node of graph.nodes) {
		if (node.opType !== "ClippedMatMul") continue;
		const attrs = node.attributes;
		// Read all four scalars before deleting anything, in case keys coincide.
		const readings = clipFields.map(([keyAttr, valueAttr]) => [valueAttr, scalarAt(attrs[keyAttr])]);
		for (const [valueAttr, value] of readings) {
			if (value !== undefined) attrs[valueAttr] = value;
		}
		for (const [keyAttr] of clipFields) {
			const key = attrs[keyAttr];
			if (key) weights.delete(key);
		}
	}
}
|
|
142
|
+
/**
 * Build the Gemma 4 ViT graph.
 *
 * The graph is shaped by two symbolic dims resolved from input tensor dims at
 * run time, like the Qwen ViT's "N"/"Nm":
 * - "N": number of patches.
 * - "Np": number of pooled tokens = ceil(grid_h/k)·ceil(grid_w/k).
 *
 * Inputs: g4v_patches [N, patch_dim], g4v_pos_embeds [N, hidden],
 * g4v_cos / g4v_sin [N, head_dim], and g4v_pool_w [Np, N].
 * Output: g4v_image_embeds [Np, text_hidden].
 *
 * `rawConfig` may be a top-level HF config or a bare vision_config (see
 * resolveGemma4VisionInfo). `config.context_length` is read from the same vision
 * config object, so both forms report it consistently. Previously a bare
 * vision_config always yielded 0.
 *
 * @param {object} rawConfig - top-level HF config or bare vision_config.
 * @returns {object} graph description: architecture, config, capabilities,
 *   tensors, nodes, executionOrder, inputs, outputs.
 */
function generateGemma4VisionGraph(rawConfig) {
	const info = resolveGemma4VisionInfo(rawConfig);
	const { hiddenSize: hidden_size, numHeads: num_heads, headDim: head_dim, depth, intermediateSize: intermediate_size, textHidden: text_hidden, patchDim: patch_dim, rmsNormEps: eps } = info;
	// Same object resolveGemma4VisionInfo reads dims from.
	const visionConfig = rawConfig.vision_config ?? rawConfig;
	const qkv_dim = num_heads * head_dim;
	const tensors = {};
	const nodes = [];
	const executionOrder = [];
	function addTensor(desc) {
		tensors[desc.name] = desc;
	}
	/** Append a node; execution order is insertion order. */
	function addNode(node) {
		nodes.push(node);
		executionOrder.push(node.id);
	}
	/** Declare an f32 weight tensor loaded from `safetensorsKey` (defaults to its name). */
	function constant(name, shape, safetensorsKey = name) {
		addTensor({
			name,
			shape,
			dtype: "f32",
			storage: "constant",
			safetensorsKey
		});
	}
	/** Declare an f32 runtime activation tensor. */
	function activation(name, shape) {
		addTensor({
			name,
			shape,
			dtype: "f32",
			storage: "activation"
		});
	}
	/** MatMul (no bias): out[M,N] = in[M,K] @ W[N,K]^T. M from `mTensor`. */
	function matmul(id, input, weight, output, K, N, mTensor) {
		addNode({
			id,
			opType: "MatMul",
			inputs: [input, weight],
			outputs: [output],
			attributes: {
				M_tensor: mTensor,
				K,
				N
			}
		});
	}
	/**
	 * Gemma4ClippableLinear: out = clamp(clamp(x,imin,imax) @ W^T, omin, omax).
	 *
	 * The four clip scalars live in the checkpoint as per-tensor tensors under
	 * `<projBase>.{input_min,input_max,output_min,output_max}`. Only their keys are
	 * recorded here; patchGemma4VisionClips later reads them, sets the op's
	 * imin/imax/omin/omax attributes, and drops the scalar tensors (like Gemma4's
	 * layer_scalar). `projBase` is the canonical key prefix of the projection
	 * (e.g. `vision_tower.encoder.layers.0.self_attn.q_proj`).
	 */
	function clippedMatmul(id, input, weight, output, K, N, mTensor, projBase) {
		addNode({
			id,
			opType: "ClippedMatMul",
			inputs: [input, weight],
			outputs: [output],
			attributes: {
				M_tensor: mTensor,
				K,
				N,
				clip_input_min_key: `${projBase}.input_min`,
				clip_input_max_key: `${projBase}.input_max`,
				clip_output_min_key: `${projBase}.output_min`,
				clip_output_max_key: `${projBase}.output_max`
			}
		});
	}
	/**
	 * RMSNorm over each `hidden`-wide row (row count from `seqTensor`). Gemma 4
	 * vision stores the full gain (no +1 offset), so the weight is applied as-is.
	 */
	function rmsnorm(id, input, w, output, hidden, seqTensor) {
		addNode({
			id,
			opType: "RMSNorm",
			inputs: [input, w],
			outputs: [output],
			attributes: {
				hidden_size: hidden,
				eps,
				seq_len_tensor: seqTensor
			}
		});
	}
	/** Elementwise a + b over [*, hidden_size] activations. */
	function residualAdd(id, a, b, output) {
		addNode({
			id,
			opType: "Add",
			inputs: [a, b],
			outputs: [output],
			attributes: {
				count_tensor: a,
				hidden_size
			}
		});
	}
	// --- Graph inputs (runtime-provided) ---
	activation("g4v_patches", ["N", patch_dim]);
	activation("g4v_pos_embeds", ["N", hidden_size]);
	activation("g4v_cos", ["N", head_dim]);
	activation("g4v_sin", ["N", head_dim]);
	activation("g4v_pool_w", ["Np", "N"]);
	// --- Patch embedding + position embedding ---
	constant(GEMMA4_VIS_KEYS.patchEmbedProjW, [hidden_size, patch_dim]);
	activation("g4v_patch_mm", ["N", hidden_size]);
	activation("g4v_h0", ["N", hidden_size]);
	matmul("g4v_patch_proj", "g4v_patches", GEMMA4_VIS_KEYS.patchEmbedProjW, "g4v_patch_mm", patch_dim, hidden_size, "g4v_patches");
	addNode({
		id: "g4v_add_pos",
		opType: "Add",
		inputs: ["g4v_patch_mm", "g4v_pos_embeds"],
		outputs: ["g4v_h0"],
		attributes: {
			count_tensor: "g4v_patch_mm",
			hidden_size
		}
	});
	// --- Encoder layers: sandwich-norm attention + GeGLU MLP ---
	let prev = "g4v_h0";
	for (let i = 0; i < depth; i++) {
		const p = `g4v_b${i}`;
		constant(GEMMA4_VIS_KEYS.inputNorm(i), [hidden_size]);
		activation(`${p}_n1`, ["N", hidden_size]);
		rmsnorm(`${p}_input_norm`, prev, GEMMA4_VIS_KEYS.inputNorm(i), `${p}_n1`, hidden_size, prev);
		constant(GEMMA4_VIS_KEYS.qProjW(i), [qkv_dim, hidden_size]);
		constant(GEMMA4_VIS_KEYS.kProjW(i), [qkv_dim, hidden_size]);
		constant(GEMMA4_VIS_KEYS.vProjW(i), [qkv_dim, hidden_size]);
		activation(`${p}_q`, ["N", qkv_dim]);
		activation(`${p}_k`, ["N", qkv_dim]);
		activation(`${p}_v`, ["N", qkv_dim]);
		const saBase = `vision_tower.encoder.layers.${i}.self_attn`;
		clippedMatmul(`${p}_q_proj`, `${p}_n1`, GEMMA4_VIS_KEYS.qProjW(i), `${p}_q`, hidden_size, qkv_dim, `${p}_n1`, `${saBase}.q_proj`);
		clippedMatmul(`${p}_k_proj`, `${p}_n1`, GEMMA4_VIS_KEYS.kProjW(i), `${p}_k`, hidden_size, qkv_dim, `${p}_n1`, `${saBase}.k_proj`);
		clippedMatmul(`${p}_v_proj`, `${p}_n1`, GEMMA4_VIS_KEYS.vProjW(i), `${p}_v`, hidden_size, qkv_dim, `${p}_n1`, `${saBase}.v_proj`);
		// Per-head q/k norms: gain is [head_dim], normalized width is head_dim.
		constant(GEMMA4_VIS_KEYS.qNormW(i), [head_dim]);
		constant(GEMMA4_VIS_KEYS.kNormW(i), [head_dim]);
		activation(`${p}_qn`, ["N", qkv_dim]);
		activation(`${p}_kn`, ["N", qkv_dim]);
		rmsnorm(`${p}_q_norm`, `${p}_q`, GEMMA4_VIS_KEYS.qNormW(i), `${p}_qn`, head_dim, `${p}_q`);
		rmsnorm(`${p}_k_norm`, `${p}_k`, GEMMA4_VIS_KEYS.kNormW(i), `${p}_kn`, head_dim, `${p}_k`);
		activation(`${p}_qr`, ["N", qkv_dim]);
		activation(`${p}_kr`, ["N", qkv_dim]);
		addNode({
			id: `${p}_rope_q`,
			opType: "ApplyRotaryEmb",
			inputs: [
				`${p}_qn`,
				"g4v_cos",
				"g4v_sin"
			],
			outputs: [`${p}_qr`],
			attributes: {
				num_heads,
				head_dim
			}
		});
		addNode({
			id: `${p}_rope_k`,
			opType: "ApplyRotaryEmb",
			inputs: [
				`${p}_kn`,
				"g4v_cos",
				"g4v_sin"
			],
			outputs: [`${p}_kr`],
			attributes: {
				num_heads,
				head_dim
			}
		});
		activation(`${p}_attn`, ["N", qkv_dim]);
		// Bidirectional attention over all patches (no causal mask); v is not rotated.
		addNode({
			id: `${p}_attention`,
			opType: "Attention",
			inputs: [
				`${p}_qr`,
				`${p}_kr`,
				`${p}_v`
			],
			outputs: [`${p}_attn`],
			attributes: {
				num_q_heads: num_heads,
				num_kv_heads: num_heads,
				head_dim,
				causal: false
			}
		});
		constant(GEMMA4_VIS_KEYS.oProjW(i), [hidden_size, qkv_dim]);
		activation(`${p}_o`, ["N", hidden_size]);
		clippedMatmul(`${p}_o_proj`, `${p}_attn`, GEMMA4_VIS_KEYS.oProjW(i), `${p}_o`, qkv_dim, hidden_size, `${p}_attn`, `${saBase}.o_proj`);
		constant(GEMMA4_VIS_KEYS.postAttnNorm(i), [hidden_size]);
		activation(`${p}_post_attn`, ["N", hidden_size]);
		rmsnorm(`${p}_post_attn_norm`, `${p}_o`, GEMMA4_VIS_KEYS.postAttnNorm(i), `${p}_post_attn`, hidden_size, `${p}_o`);
		activation(`${p}_res1`, ["N", hidden_size]);
		residualAdd(`${p}_residual1`, prev, `${p}_post_attn`, `${p}_res1`);
		constant(GEMMA4_VIS_KEYS.preFfNorm(i), [hidden_size]);
		activation(`${p}_n2`, ["N", hidden_size]);
		rmsnorm(`${p}_pre_ff_norm`, `${p}_res1`, GEMMA4_VIS_KEYS.preFfNorm(i), `${p}_n2`, hidden_size, `${p}_res1`);
		constant(GEMMA4_VIS_KEYS.gateProjW(i), [intermediate_size, hidden_size]);
		constant(GEMMA4_VIS_KEYS.upProjW(i), [intermediate_size, hidden_size]);
		constant(GEMMA4_VIS_KEYS.downProjW(i), [hidden_size, intermediate_size]);
		activation(`${p}_gate`, ["N", intermediate_size]);
		activation(`${p}_up`, ["N", intermediate_size]);
		activation(`${p}_gelu`, ["N", intermediate_size]);
		activation(`${p}_geglu`, ["N", intermediate_size]);
		activation(`${p}_mlp`, ["N", hidden_size]);
		const mlpBase = `vision_tower.encoder.layers.${i}.mlp`;
		clippedMatmul(`${p}_gate_proj`, `${p}_n2`, GEMMA4_VIS_KEYS.gateProjW(i), `${p}_gate`, hidden_size, intermediate_size, `${p}_n2`, `${mlpBase}.gate_proj`);
		clippedMatmul(`${p}_up_proj`, `${p}_n2`, GEMMA4_VIS_KEYS.upProjW(i), `${p}_up`, hidden_size, intermediate_size, `${p}_n2`, `${mlpBase}.up_proj`);
		// GeGLU: gelu(gate) * up.
		addNode({
			id: `${p}_gelu_op`,
			opType: "GELU",
			inputs: [`${p}_gate`],
			outputs: [`${p}_gelu`],
			attributes: { count_tensor: `${p}_gate` }
		});
		addNode({
			id: `${p}_geglu_mul`,
			opType: "Mul",
			inputs: [`${p}_gelu`, `${p}_up`],
			outputs: [`${p}_geglu`],
			attributes: {
				count_tensor: `${p}_gelu`,
				hidden_size: intermediate_size
			}
		});
		clippedMatmul(`${p}_down_proj`, `${p}_geglu`, GEMMA4_VIS_KEYS.downProjW(i), `${p}_mlp`, intermediate_size, hidden_size, `${p}_geglu`, `${mlpBase}.down_proj`);
		constant(GEMMA4_VIS_KEYS.postFfNorm(i), [hidden_size]);
		activation(`${p}_post_ff`, ["N", hidden_size]);
		// Row count taken from res1, which has the same [N, hidden] shape as the norm input.
		rmsnorm(`${p}_post_ff_norm`, `${p}_mlp`, GEMMA4_VIS_KEYS.postFfNorm(i), `${p}_post_ff`, hidden_size, `${p}_res1`);
		activation(`${p}_res2`, ["N", hidden_size]);
		residualAdd(`${p}_residual2`, `${p}_res1`, `${p}_post_ff`, `${p}_res2`);
		prev = `${p}_res2`;
	}
	// --- Pooling: [Np, N] @ [N, hidden] -> [Np, hidden] ---
	activation("g4v_pooled", ["Np", hidden_size]);
	addNode({
		id: "g4v_pool",
		opType: "PoolMatMul",
		inputs: ["g4v_pool_w", prev],
		outputs: ["g4v_pooled"],
		attributes: {
			K_tensor: prev,
			hidden_size,
			np_tensor: "g4v_pool_w"
		}
	});
	// --- Multimodal embedder: soft-embedding norm -> projection -> post-projection norm ---
	constant(GEMMA4_VIS_KEYS.softEmbNormW, [hidden_size]);
	activation("g4v_soft_normed", ["Np", hidden_size]);
	rmsnorm("g4v_soft_emb_norm", "g4v_pooled", GEMMA4_VIS_KEYS.softEmbNormW, "g4v_soft_normed", hidden_size, "g4v_pooled");
	constant(GEMMA4_VIS_KEYS.projW, [text_hidden, hidden_size]);
	activation("g4v_proj_out", ["Np", text_hidden]);
	matmul("g4v_proj", "g4v_soft_normed", GEMMA4_VIS_KEYS.projW, "g4v_proj_out", hidden_size, text_hidden, "g4v_proj_out");
	constant(GEMMA4_VIS_KEYS.postProjNormW, [text_hidden]);
	activation("g4v_image_embeds", ["Np", text_hidden]);
	rmsnorm("g4v_post_proj_norm", "g4v_proj_out", GEMMA4_VIS_KEYS.postProjNormW, "g4v_image_embeds", text_hidden, "g4v_proj_out");
	return {
		architecture: "Gemma4VisionModel",
		config: {
			hidden_size,
			num_layers: depth,
			num_heads,
			num_kv_heads: num_heads,
			head_dim,
			intermediate_size,
			vocab_size: 0,
			context_length: visionConfig.max_position_embeddings ?? 0,
			rms_norm_eps: eps,
			norm_type: "rmsnorm",
			rope_base: info.ropeTheta,
			rope_dim: head_dim,
			kv_layout: "LHSd",
			is_moe: false,
			has_vision_tower: true,
			vision_architecture: "gemma4_vision",
			vision_patch_size: info.patchSize,
			vision_embed_dim: hidden_size
		},
		capabilities: {
			text: true,
			vision: true,
			moe: false
		},
		tensors,
		nodes,
		executionOrder,
		inputs: [
			"g4v_patches",
			"g4v_pos_embeds",
			"g4v_cos",
			"g4v_sin",
			"g4v_pool_w"
		],
		outputs: ["g4v_image_embeds"]
	};
}
|
|
439
|
+
|
|
440
|
+
//#endregion
|
|
441
|
+
//#region src/gpu/architectures/qwen3_5_vision.ts
|
|
442
|
+
const VIS_NORM_EPS = 1e-6;
|
|
443
|
+
/**
|
|
444
|
+
* Build the ViT graph. The graph is shaped by symbolic "N" (number of patches),
|
|
445
|
+
* resolved at run time from the input tensor's first dim — exactly like the LM's
|
|
446
|
+
* symbolic "T".
|
|
447
|
+
*/
|
|
448
|
+
function generateQwen3_5VisionGraph(rawConfig) {
	// Accept either the full multimodal config (nested `vision_config`) or a bare vision config.
	const vcfg = rawConfig.vision_config ?? rawConfig;
	const hidden_size = vcfg.hidden_size;
	const num_heads = vcfg.num_heads;
	const head_dim = Math.floor(hidden_size / num_heads);
	const depth = vcfg.depth;
	const intermediate_size = vcfg.intermediate_size;
	const out_hidden_size = vcfg.out_hidden_size;
	const spatial_merge_size = vcfg.spatial_merge_size;
	const in_channels = vcfg.in_channels;
	const temporal_patch_size = vcfg.temporal_patch_size;
	const patch_size = vcfg.patch_size;
	// Row width of `vis_patches`: one flattened channels × temporal × patch × patch cube.
	const patch_dim = in_channels * temporal_patch_size * patch_size * patch_size;
	// Row width consumed by the merger MLP: hidden_size × spatial_merge_size². The merger
	// reads the [N, hidden_size] norm output as "Nm" rows of this width.
	const merged_in = hidden_size * (spatial_merge_size * spatial_merge_size);
	const tensors = {};
	const nodes = [];
	const executionOrder = [];
	function addTensor(desc) {
		tensors[desc.name] = desc;
	}
	/** Append a node; nodes execute in the order they are added. */
	function addNode(node) {
		nodes.push(node);
		executionOrder.push(node.id);
	}
	/** Declare an f32 weight whose safetensors key equals its graph name. */
	function constant(name, shape) {
		addTensor({
			name,
			shape,
			dtype: "f32",
			storage: "constant",
			safetensorsKey: name
		});
	}
	/** Declare an f32 runtime buffer (graph input or intermediate result). */
	function activation(name, shape) {
		addTensor({
			name,
			shape,
			dtype: "f32",
			storage: "activation"
		});
	}
	/**
	 * Fused MatMul + row-broadcast bias → out = in @ W^T + bias. Replaces a MatMul
	 * followed by a separate AddBias, removing a full read+write of the matmul
	 * output per ViT linear layer. `mTensor` is recorded as the `M_tensor`
	 * attribute — presumably the executor takes the row count M from that
	 * tensor's first dim (TODO confirm against the executor).
	 */
	function matmulBias(id, input, weight, bias, output, K, N, mTensor) {
		addNode({
			id,
			opType: "MatMulBias",
			inputs: [
				input,
				weight,
				bias
			],
			outputs: [output],
			attributes: {
				M_tensor: mTensor,
				K,
				N
			}
		});
	}
	/** LayerNorm node (weight + bias) with eps VIS_NORM_EPS over rows of width `hidden`. */
	function layernorm(id, input, w, b, output, hidden, seqTensor) {
		addNode({
			id,
			opType: "LayerNorm",
			inputs: [
				input,
				w,
				b
			],
			outputs: [output],
			attributes: {
				hidden_size: hidden,
				eps: VIS_NORM_EPS,
				seq_len_tensor: seqTensor
			}
		});
	}
	// Graph inputs (all N rows): flattened patches, position embeddings, RoPE tables.
	activation("vis_patches", ["N", patch_dim]);
	activation("vis_pos_embeds", ["N", hidden_size]);
	activation("vis_cos", ["N", head_dim]);
	activation("vis_sin", ["N", head_dim]);
	// Patch embedding (fused linear) followed by adding the position embeddings.
	// Only tensors that some node actually produces or consumes are declared; the
	// pre-fusion "*_mm" intermediates are intentionally absent.
	constant(CANONICAL_KEYS.visPatchEmbedWeight, [hidden_size, patch_dim]);
	constant(CANONICAL_KEYS.visPatchEmbedBias, [hidden_size]);
	activation("vis_patch_bias", ["N", hidden_size]);
	activation("vis_h0", ["N", hidden_size]);
	matmulBias("vis_patch_mm", "vis_patches", CANONICAL_KEYS.visPatchEmbedWeight, CANONICAL_KEYS.visPatchEmbedBias, "vis_patch_bias", patch_dim, hidden_size, "vis_patches");
	addNode({
		id: "vis_add_pos",
		opType: "Add",
		inputs: ["vis_patch_bias", "vis_pos_embeds"],
		outputs: ["vis_h0"],
		attributes: {
			count_tensor: "vis_patch_bias",
			hidden_size
		}
	});
	let prev = "vis_h0";
	const qkv_dim = hidden_size * 3;
	for (let i = 0; i < depth; i++) {
		const p = `vis_b${i}`;
		// Attention sub-block: LN → fused QKV → split → RoPE(q, k) → attention → proj → residual.
		constant(CANONICAL_KEYS.visBlockNorm1W(i), [hidden_size]);
		constant(CANONICAL_KEYS.visBlockNorm1B(i), [hidden_size]);
		activation(`${p}_n1`, ["N", hidden_size]);
		layernorm(`${p}_norm1`, prev, CANONICAL_KEYS.visBlockNorm1W(i), CANONICAL_KEYS.visBlockNorm1B(i), `${p}_n1`, hidden_size, prev);
		constant(CANONICAL_KEYS.visBlockQkvW(i), [qkv_dim, hidden_size]);
		constant(CANONICAL_KEYS.visBlockQkvB(i), [qkv_dim]);
		activation(`${p}_qkv`, ["N", qkv_dim]);
		matmulBias(`${p}_qkv_proj`, `${p}_n1`, CANONICAL_KEYS.visBlockQkvW(i), CANONICAL_KEYS.visBlockQkvB(i), `${p}_qkv`, hidden_size, qkv_dim, `${p}_n1`);
		activation(`${p}_q`, ["N", hidden_size]);
		activation(`${p}_k`, ["N", hidden_size]);
		activation(`${p}_v`, ["N", hidden_size]);
		// The fused QKV row is laid out [q | k | v]; each slice takes hidden_size columns.
		const sliceCols = (id, output, offset) => addNode({
			id,
			opType: "SliceCols",
			inputs: [`${p}_qkv`],
			outputs: [output],
			attributes: {
				in_width: qkv_dim,
				out_width: hidden_size,
				col_offset: offset,
				seq_len_tensor: `${p}_qkv`
			}
		});
		sliceCols(`${p}_split_q`, `${p}_q`, 0);
		sliceCols(`${p}_split_k`, `${p}_k`, hidden_size);
		sliceCols(`${p}_split_v`, `${p}_v`, 2 * hidden_size);
		activation(`${p}_qr`, ["N", hidden_size]);
		activation(`${p}_kr`, ["N", hidden_size]);
		const rope = (id, input, output) => addNode({
			id,
			opType: "ApplyRotaryEmb",
			inputs: [
				input,
				"vis_cos",
				"vis_sin"
			],
			outputs: [output],
			attributes: {
				num_heads,
				head_dim
			}
		});
		rope(`${p}_rope_q`, `${p}_q`, `${p}_qr`);
		rope(`${p}_rope_k`, `${p}_k`, `${p}_kr`);
		activation(`${p}_attn`, ["N", hidden_size]);
		// Bidirectional (non-causal) attention; q and kv use the same head count.
		addNode({
			id: `${p}_attention`,
			opType: "Attention",
			inputs: [
				`${p}_qr`,
				`${p}_kr`,
				`${p}_v`
			],
			outputs: [`${p}_attn`],
			attributes: {
				num_q_heads: num_heads,
				num_kv_heads: num_heads,
				head_dim,
				causal: false
			}
		});
		constant(CANONICAL_KEYS.visBlockProjW(i), [hidden_size, hidden_size]);
		constant(CANONICAL_KEYS.visBlockProjB(i), [hidden_size]);
		activation(`${p}_proj`, ["N", hidden_size]);
		matmulBias(`${p}_proj_op`, `${p}_attn`, CANONICAL_KEYS.visBlockProjW(i), CANONICAL_KEYS.visBlockProjB(i), `${p}_proj`, hidden_size, hidden_size, `${p}_attn`);
		activation(`${p}_res1`, ["N", hidden_size]);
		addNode({
			id: `${p}_residual1`,
			opType: "Add",
			inputs: [prev, `${p}_proj`],
			outputs: [`${p}_res1`],
			attributes: {
				count_tensor: prev,
				hidden_size
			}
		});
		// MLP sub-block: LN → fc1 → GELU → fc2 → residual.
		constant(CANONICAL_KEYS.visBlockNorm2W(i), [hidden_size]);
		constant(CANONICAL_KEYS.visBlockNorm2B(i), [hidden_size]);
		constant(CANONICAL_KEYS.visBlockFc1W(i), [intermediate_size, hidden_size]);
		constant(CANONICAL_KEYS.visBlockFc1B(i), [intermediate_size]);
		constant(CANONICAL_KEYS.visBlockFc2W(i), [hidden_size, intermediate_size]);
		constant(CANONICAL_KEYS.visBlockFc2B(i), [hidden_size]);
		activation(`${p}_n2`, ["N", hidden_size]);
		activation(`${p}_fc1`, ["N", intermediate_size]);
		activation(`${p}_gelu`, ["N", intermediate_size]);
		activation(`${p}_fc2`, ["N", hidden_size]);
		activation(`${p}_res2`, ["N", hidden_size]);
		layernorm(`${p}_norm2`, `${p}_res1`, CANONICAL_KEYS.visBlockNorm2W(i), CANONICAL_KEYS.visBlockNorm2B(i), `${p}_n2`, hidden_size, `${p}_res1`);
		matmulBias(`${p}_fc1_op`, `${p}_n2`, CANONICAL_KEYS.visBlockFc1W(i), CANONICAL_KEYS.visBlockFc1B(i), `${p}_fc1`, hidden_size, intermediate_size, `${p}_n2`);
		addNode({
			id: `${p}_gelu_op`,
			opType: "GELU",
			inputs: [`${p}_fc1`],
			outputs: [`${p}_gelu`],
			attributes: { count_tensor: `${p}_fc1` }
		});
		matmulBias(`${p}_fc2_op`, `${p}_gelu`, CANONICAL_KEYS.visBlockFc2W(i), CANONICAL_KEYS.visBlockFc2B(i), `${p}_fc2`, intermediate_size, hidden_size, `${p}_gelu`);
		addNode({
			id: `${p}_residual2`,
			opType: "Add",
			inputs: [`${p}_res1`, `${p}_fc2`],
			outputs: [`${p}_res2`],
			attributes: {
				count_tensor: `${p}_res1`,
				hidden_size
			}
		});
		prev = `${p}_res2`;
	}
	// Patch merger: LN on N rows, then an MLP over "Nm" merged rows of width merged_in.
	constant(CANONICAL_KEYS.visMergerNormW, [hidden_size]);
	constant(CANONICAL_KEYS.visMergerNormB, [hidden_size]);
	constant(CANONICAL_KEYS.visMergerFc1W, [merged_in, merged_in]);
	constant(CANONICAL_KEYS.visMergerFc1B, [merged_in]);
	constant(CANONICAL_KEYS.visMergerFc2W, [out_hidden_size, merged_in]);
	constant(CANONICAL_KEYS.visMergerFc2B, [out_hidden_size]);
	activation("vis_merge_norm", ["N", hidden_size]);
	activation("vis_merge_fc1", ["Nm", merged_in]);
	activation("vis_merge_gelu", ["Nm", merged_in]);
	activation("vis_image_embeds", ["Nm", out_hidden_size]);
	layernorm("vis_merger_norm", prev, CANONICAL_KEYS.visMergerNormW, CANONICAL_KEYS.visMergerNormB, "vis_merge_norm", hidden_size, prev);
	// The merger matmuls name their Nm-row OUTPUT as M_tensor, since the input is the
	// N-row norm output viewed as Nm rows.
	matmulBias("vis_merger_fc1", "vis_merge_norm", CANONICAL_KEYS.visMergerFc1W, CANONICAL_KEYS.visMergerFc1B, "vis_merge_fc1", merged_in, merged_in, "vis_merge_fc1");
	addNode({
		id: "vis_merger_gelu",
		opType: "GeluErf",
		inputs: ["vis_merge_fc1"],
		outputs: ["vis_merge_gelu"],
		attributes: { count_tensor: "vis_merge_fc1" }
	});
	matmulBias("vis_merger_fc2", "vis_merge_gelu", CANONICAL_KEYS.visMergerFc2W, CANONICAL_KEYS.visMergerFc2B, "vis_image_embeds", merged_in, out_hidden_size, "vis_image_embeds");
	return {
		architecture: "Qwen3_5VisionModel",
		config: {
			hidden_size,
			num_layers: depth,
			num_heads,
			num_kv_heads: num_heads,
			head_dim,
			intermediate_size,
			vocab_size: 0,
			context_length: vcfg.num_position_embeddings,
			rms_norm_eps: VIS_NORM_EPS,
			norm_type: "layernorm",
			rope_base: 1e4,
			rope_dim: head_dim,
			kv_layout: "LHSd",
			is_moe: false,
			has_vision_tower: true,
			vision_architecture: "qwen3_5_vit",
			vision_patch_size: patch_size,
			vision_embed_dim: hidden_size
		},
		capabilities: {
			text: true,
			vision: true,
			moe: false
		},
		tensors,
		nodes,
		executionOrder,
		inputs: [
			"vis_patches",
			"vis_pos_embeds",
			"vis_cos",
			"vis_sin"
		],
		outputs: ["vis_image_embeds"]
	};
}
|
|
738
|
+
|
|
739
|
+
//#endregion
|
|
740
|
+
//#region src/gpu/kani-tts.ts
|
|
741
|
+
/**
|
|
742
|
+
* KaniTTS — native text-to-speech engine for Gerbil's WebGPU backend.
|
|
743
|
+
*
|
|
744
|
+
* Kani-TTS-2 (nineninesix/kani-tts-2-en) is a two-stage TTS model that, like
|
|
745
|
+
* Moonshine, needs more than one graph:
|
|
746
|
+
*
|
|
747
|
+
* 1. CODEC-LM BACKBONE (LFM2-350M body): autoregressively emits NanoCodec audio
|
|
748
|
+
* tokens (4 per frame) into the same vocab as text. Reuses LFM2's block math
|
|
749
|
+
* with two KaniTTS2 deltas — frame-level position IDs (the 4 audio tokens of a
|
|
750
|
+
* frame share a position) and learnable per-layer RoPE (α^(l)-scaled freqs) —
|
|
751
|
+
* both folded host-side into per-layer cos/sin tables fed to the MRoPE op.
|
|
752
|
+
* 2. NANOCODEC DECODER (NVIDIA NeMo 22 kHz): FSQ dequant + causal HiFi-GAN conv
|
|
753
|
+
* decoder → 22 kHz PCM. Validated bit-exact vs MLX (test-nanocodec-decode.mjs).
|
|
754
|
+
*
|
|
755
|
+
* The AR loop runs on the host with full-logit readback so each frame's 4 audio
|
|
756
|
+
* tokens are sampled per-codebook (constrained to the valid codebook window), then
|
|
757
|
+
* the collected codes are decoded once through the NanoCodec graph.
|
|
758
|
+
*
|
|
759
|
+
* Validated on Dawn (desktop) via scripts/engine/test-kani-speak.mjs.
|
|
760
|
+
*/
|
|
761
|
+
/** Sample rate (Hz) of the PCM returned by KaniTTS.speak(). */
const KANI_SAMPLE_RATE = 22050;
/** PCM samples emitted per codec frame (22050 / 1764 = 12.5 frames per second). */
const KANI_HOP = 1764;
/** Token id appended right after the prompt text, before the end-of-human marker. */
const KANI_END_OF_TEXT = 2;
/**
 * NanoCodec decode is done in windows of this many frames so a single dispatch
 * stays under WebGPU's 65535 per-dimension cap.
 */
const KANI_CODEC_CHUNK_FRAMES = 32;
/** Left-context frames re-decoded ahead of each chunk (their samples are discarded). */
const KANI_CODEC_LOOKBACK_FRAMES = 32;
/** Upper bound on generated tokens so a decode that never ends cannot run forever (reference uses 3000). */
const DEFAULT_MAX_NEW_TOKENS = 3000;
|
|
769
|
+
/**
 * Collect the weights a graph actually needs.
 *
 * For every tensor the graph marks as `storage: "constant"`, look the weight up
 * by its graph name first and fall back to its `safetensorsKey` (when set).
 * Tensors with no matching weight are left out. Returns a new Map keyed by
 * graph tensor name (see moonshine-stt).
 */
function selectGraphWeights(graph, weights) {
	const selected = new Map();
	const constantEntries = Object.entries(graph.tensors).filter(([, desc]) => desc.storage === "constant");
	for (const [tensorName, desc] of constantEntries) {
		const byName = weights.get(tensorName);
		const found = byName ?? (desc.safetensorsKey ? weights.get(desc.safetensorsKey) : undefined);
		if (!found) continue;
		selected.set(tensorName, found);
	}
	return selected;
}
|
|
779
|
+
/**
 * Top-p (nucleus) sampling over an already-softmaxed probability array.
 *
 * Candidates are ranked by probability (highest first; ties keep index order).
 * The nucleus is the shortest ranked prefix whose cumulative mass reaches `topP`
 * (or every candidate if it never does). One draw of Math.random() then picks
 * an entry from the nucleus, weighted by the renormalized probabilities.
 *
 * @param probs  probability per candidate
 * @param ids    token id per candidate (parallel to `probs`)
 * @param topP   cumulative-mass threshold
 * @returns the chosen id from `ids`
 */
function nucleusSample(probs, ids, topP) {
	const ranked = [...probs.keys()].sort((a, b) => probs[b] - probs[a]);
	let nucleusSize = ranked.length;
	let mass = 0;
	for (let rank = 0; rank < ranked.length; rank++) {
		mass += probs[ranked[rank]];
		if (mass < topP) continue;
		nucleusSize = rank + 1;
		break;
	}
	const nucleus = ranked.slice(0, nucleusSize);
	const nucleusMass = nucleus.reduce((acc, k) => acc + probs[k], 0);
	let remaining = Math.random() * nucleusMass;
	for (const k of nucleus) {
		remaining -= probs[k];
		if (remaining <= 0) return ids[k];
	}
	// Floating-point slack: fall back to the most probable candidate.
	return ids[ranked[0]];
}
|
|
805
|
+
/**
 * Replace `scores` with its softmax, in place.
 * The maximum is subtracted before exponentiating for numerical stability.
 */
function softmaxInPlace(scores) {
	const peak = scores.reduce((best, v) => (v > best ? v : best), Number.NEGATIVE_INFINITY);
	let total = 0;
	scores.forEach((v, i) => {
		const e = Math.exp(v - peak);
		scores[i] = e;
		total += e;
	});
	scores.forEach((v, i) => {
		scores[i] = v / total;
	});
}
|
|
817
|
+
/**
 * Kani-TTS-2 engine: an LFM2-style codec-LM backbone (kept resident on the GPU)
 * plus a NanoCodec decoder graph built per decode window. Use `KaniTTS.create()`
 * to load weights, `speak()` to synthesize, and `destroy()` to release resources.
 */
var KaniTTS = class KaniTTS {
	ctx;
	loaded;
	tokenizer;
	cfg;
	rawConfig;
	maxSeqLen;
	/** Backbone executor (built once; reused across speak() calls). */
	backboneExec;
	/** The attention layer indices (carry learnable α) and their α values. */
	attnLayers;
	layerAlpha;
	headDim;
	ropeBase;
	_destroyed = false;
	architecture = "KaniTTS2ForCausalLM";
	constructor(ctx, loaded, maxSeqLen) {
		this.ctx = ctx;
		this.loaded = loaded;
		this.tokenizer = loaded.tokenizer;
		this.rawConfig = loaded.rawConfig;
		this.cfg = parseKaniConfig(loaded.rawConfig);
		this.maxSeqLen = maxSeqLen;
		const hidden = this.rawConfig.hidden_size;
		const heads = this.rawConfig.num_attention_heads;
		this.headDim = this.rawConfig.head_dim ?? Math.floor(hidden / heads);
		this.ropeBase = this.rawConfig.rope_theta ?? 1e6;
		this.attnLayers = kaniAttentionLayerIndices(this.rawConfig);
		this.layerAlpha = /* @__PURE__ */ new Map();
		// Per-layer RoPE scale: 1 unless the checkpoint enables learnable RoPE, in which
		// case it is derived from the stored alpha_weight scalar (0 when missing).
		const useLearnable = this.rawConfig.use_learnable_rope ?? false;
		for (const layer of this.attnLayers) {
			let alpha = 1;
			if (useLearnable) {
				const w = loaded.backboneWeights.get(`learnable_rope_layers.${layer}.alpha_weight`);
				alpha = kaniLayerAlpha(w ? w.data[0] ?? 0 : 0, this.cfg);
			}
			this.layerAlpha.set(layer, alpha);
		}
		const graph = generateKaniTtsGraph(this.rawConfig);
		this.backboneExec = new Executor(ctx, graph, {
			maxSeqLen,
			kvMode: "f32"
		});
		this.backboneExec.uploadWeightsMap(selectGraphWeights(graph, loaded.backboneWeights));
		this.backboneExec.initBindGroups();
	}
	/** Initialize the GPU, load backbone + codec weights, and build the engine. */
	static async create(options = {}) {
		const ctx = await initGPU();
		const loaded = await loadKaniTTS({
			repo: options.repo,
			codecRepo: options.codecRepo,
			revision: options.revision,
			hfToken: options.hfToken,
			cacheDir: options.cacheDir,
			onProgress: options.onProgress
		});
		const ctxLen = loaded.rawConfig.max_position_embeddings ?? 4096;
		return new KaniTTS(ctx, loaded, Math.min(options.maxSeqLen ?? 2048, ctxLen));
	}
	/** Write per-layer cos/sin for token rows [rowStart, rowStart+positions.length). */
	writeCosSin(positions, rowStart) {
		const rowBytes = this.headDim * 4;
		for (const layer of this.attnLayers) {
			const alpha = this.layerAlpha.get(layer) ?? 1;
			const { cos, sin } = buildKaniLayerCosSin(positions, this.headDim, this.ropeBase, alpha);
			if (rowStart === 0) {
				this.backboneExec.writeInput(kaniCosTensor(layer), cos);
				this.backboneExec.writeInput(kaniSinTensor(layer), sin);
			} else {
				this.backboneExec.writeInputAt(kaniCosTensor(layer), cos, rowStart * rowBytes);
				this.backboneExec.writeInputAt(kaniSinTensor(layer), sin, rowStart * rowBytes);
			}
		}
	}
	/**
	 * Synthesize speech for `text`. Returns 22 kHz mono PCM.
	 *
	 * Pipeline: build the [SOH]+text+[EOT,EOH] prompt → prefill the codec-LM →
	 * AR-decode 4-token frames (per-codebook constrained sampling) until end_of_speech
	 * → strip markers → codes → NanoCodec decode → PCM.
	 *
	 * `opts.temperature` below 1e-6 selects greedy (argmax) decoding.
	 *
	 * @throws if the engine was destroyed, the prompt does not fit in maxSeqLen,
	 *   or no complete audio frame was produced.
	 */
	async speak(text, opts = {}) {
		if (this._destroyed) throw new Error("KaniTTS has been destroyed.");
		const temperature = opts.temperature ?? 1;
		const topP = opts.topP ?? .95;
		const repetitionPenalty = opts.repetitionPenalty ?? 1.1;
		const tag = opts.languageTag ?? "en_us";
		const maxNewTokens = Math.min(opts.maxNewTokens ?? DEFAULT_MAX_NEW_TOKENS, this.maxSeqLen);
		const prompt = [
			KANI_START_OF_HUMAN,
			...this.tokenizer.encode(`${tag.trim()}: ${text}`),
			KANI_END_OF_TEXT,
			KANI_END_OF_HUMAN
		];
		if (prompt.length >= this.maxSeqLen) throw new Error(`Kani prompt is ${prompt.length} tokens >= maxSeqLen ${this.maxSeqLen}.`);
		// Start the wall clock before the prefill so the logged time/RTF covers the
		// whole synthesis (prefill + AR decode + codec decode).
		const startTime = performance.now();
		this.backboneExec.reset();
		const promptPos = computeKaniPositions(prompt, this.cfg);
		this.writeCosSin(promptPos, 0);
		const { logits } = await this.backboneExec.forward(new Uint32Array(prompt));
		const { audioTokens, finished } = await this.runDecodeLoop(prompt, logits, {
			temperature,
			topP,
			repetitionPenalty,
			maxNewTokens,
			maxFrames: opts.maxFrames ?? Number.POSITIVE_INFINITY
		});
		// Drop a trailing partial frame; only whole frames can be decoded.
		const usable = audioTokens.length - audioTokens.length % this.cfg.tokensPerFrame;
		if (usable === 0) throw new Error(`Kani produced no complete audio frames (finished=${finished}, audioTokens=${audioTokens.length}). The codec-LM emitted no speech.`);
		const { codes, numFrames } = audioTokensToCodes(audioTokens.slice(0, usable), this.cfg);
		const pcm = await this.decodeCodes(codes, numFrames);
		const totalTime = (performance.now() - startTime) / 1e3;
		const audioSeconds = numFrames * KANI_HOP / KANI_SAMPLE_RATE;
		console.log(`[kani] frames=${numFrames} audio=${audioSeconds.toFixed(2)}s wall=${totalTime.toFixed(2)}s RTF=${(audioSeconds / totalTime).toFixed(2)}x`);
		return {
			pcm,
			sampleRate: KANI_SAMPLE_RATE,
			frames: numFrames,
			audioSeconds
		};
	}
	/**
	 * Autoregressive decode: from the prefill logits, emit one token per step —
	 * greedy for the structural markers ([SOA][SOS]) and per-codebook constrained
	 * sampling once in speech — collecting the audio tokens between SOS and EOS.
	 * Writes each step's per-layer cos/sin row (frame-level logical position) before
	 * the forward. Returns the collected audio tokens and whether EOS/cap was hit.
	 */
	async runDecodeLoop(prompt, prefillLogits, p) {
		const seq = [...prompt];
		const audioTokens = [];
		const frameCap = p.maxFrames * this.cfg.tokensPerFrame;
		let logits = prefillLogits;
		let inSpeech = false;
		let finished = false;
		for (let step = 0; step < p.maxNewTokens; step++) {
			const nextToken = inSpeech ? this.sampleAudioToken(logits, audioTokens.length % this.cfg.tokensPerFrame, {
				temperature: p.temperature,
				topP: p.topP,
				repetitionPenalty: p.repetitionPenalty,
				previous: seq
			}) : this.argmax(logits);
			if (nextToken === this.cfg.startOfSpeech) inSpeech = true;
			else if (nextToken === this.cfg.endOfSpeech) {
				seq.push(nextToken);
				finished = true;
				break;
			} else if (inSpeech && nextToken >= this.cfg.audioTokensStart) audioTokens.push(nextToken);
			seq.push(nextToken);
			if (inSpeech && audioTokens.length >= frameCap) {
				finished = true;
				break;
			}
			if (seq.length >= this.maxSeqLen) break;
			// The new token's RoPE row goes at the executor's current sequence slot.
			this.writeCosSin(Float32Array.of(this.logicalPositionAt(seq)), this.backboneExec.currentSeqPos);
			logits = (await this.backboneExec.forward(new Uint32Array([nextToken]))).logits;
		}
		return {
			audioTokens,
			finished
		};
	}
	/** Logical (frame-level) position of the LAST token in `seq`. */
	logicalPositionAt(seq) {
		const positions = computeKaniPositions(seq, this.cfg);
		return positions[positions.length - 1];
	}
	/** Greedy argmax over a logits row (first index wins ties). */
	argmax(logits) {
		let best = 0;
		let bestVal = logits[0];
		for (let i = 1; i < logits.length; i++) if (logits[i] > bestVal) {
			bestVal = logits[i];
			best = i;
		}
		return best;
	}
	/**
	 * Sample one audio token for codebook position `codebook`, constrained to that
	 * codebook's window [audioTokensStart + codebookSize*c, +codebookSize). Allows
	 * end_of_speech only at codebook 0 (frame boundary). Applies the repetition
	 * penalty to ids present in `p.previous`, then temperature and top-p.
	 *
	 * A temperature below 1e-6 (including 0) selects the highest penalized score in
	 * the window instead of dividing by ~0, matching sampleToken's greedy cutoff.
	 */
	sampleAudioToken(logits, codebook, p) {
		const winStart = this.cfg.audioTokensStart + this.cfg.codebookSize * codebook;
		const winLen = this.cfg.codebookSize;
		const allowEos = codebook === 0;
		const n = winLen + (allowEos ? 1 : 0);
		const ids = new Int32Array(n);
		for (let i = 0; i < winLen; i++) ids[i] = winStart + i;
		if (allowEos) ids[winLen] = this.cfg.endOfSpeech;
		const prev = p.repetitionPenalty !== 1 ? new Set(p.previous) : null;
		const greedy = p.temperature < 1e-6;
		const scores = new Float32Array(n);
		for (let i = 0; i < n; i++) {
			let s = logits[ids[i]];
			if (prev?.has(ids[i])) s = s > 0 ? s / p.repetitionPenalty : s * p.repetitionPenalty;
			scores[i] = greedy ? s : s / p.temperature;
		}
		if (greedy) {
			let best = 0;
			for (let i = 1; i < n; i++) if (scores[i] > scores[best]) best = i;
			return ids[best];
		}
		softmaxInPlace(scores);
		return nucleusSample(scores, ids, p.topP);
	}
	/**
	 * Decode NanoCodec codes [groups, T] (group-major) → PCM.
	 *
	 * The decoder graph carries concrete lengths, and the upsampled conv activations
	 * for long clips overflow WebGPU's 65535 per-dimension dispatch cap. The decoder
	 * is fully causal with a small (≤ a few frames) receptive field, so we decode in
	 * frame chunks with a left-context lookback and keep only each chunk's own output
	 * samples — numerically identical to a single decode, but bounded per dispatch.
	 * Output samples are clamped to [-1, 1].
	 */
	async decodeCodes(codes, numFrames) {
		const groups = 4;
		const lookback = KANI_CODEC_LOOKBACK_FRAMES;
		const chunk = KANI_CODEC_CHUNK_FRAMES;
		const pcm = new Float32Array(numFrames * KANI_HOP);
		for (let start = 0; start < numFrames; start += chunk) {
			const ctxStart = Math.max(0, start - lookback);
			const ctxFrames = start - ctxStart;
			const end = Math.min(numFrames, start + chunk);
			const winFrames = end - ctxStart;
			// Re-pack the [ctxStart, end) frame window, still group-major.
			const winCodes = new Uint32Array(groups * winFrames);
			for (let g = 0; g < groups; g++) for (let t = 0; t < winFrames; t++) winCodes[g * winFrames + t] = codes[g * numFrames + (ctxStart + t)];
			const win = await this.decodeCodesWindow(winCodes, winFrames);
			const dropSamples = ctxFrames * KANI_HOP;
			const keepSamples = (end - start) * KANI_HOP;
			pcm.set(win.subarray(dropSamples, dropSamples + keepSamples), start * KANI_HOP);
		}
		for (let i = 0; i < pcm.length; i++) pcm[i] = Math.min(1, Math.max(-1, pcm[i]));
		return pcm;
	}
	/**
	 * Run the NanoCodec decoder graph for a single (bounded) code window → PCM.
	 * A fresh executor is built per window and always destroyed, even on error.
	 */
	async decodeCodesWindow(winCodes, winFrames) {
		const graph = generateNanoCodecDecoderGraph({ numFrames: winFrames });
		const exec = new Executor(this.ctx, graph, {
			maxSeqLen: winFrames,
			kvMode: "f32"
		});
		try {
			exec.uploadWeightsMap(selectGraphWeights(graph, this.loaded.codecWeights));
			exec.initBindGroups();
			exec.reset();
			exec.writeInput("audio_codes", winCodes);
			return await exec.runGraphOutput(graph.outputs[0], winFrames * KANI_HOP);
		} finally {
			exec.destroy();
		}
	}
	/** Release the backbone executor and drop loaded weights. Idempotent. */
	destroy() {
		if (this._destroyed) return;
		this._destroyed = true;
		this.backboneExec.destroy();
		this.loaded.backboneWeights.clear();
		this.loaded.codecWeights.clear();
	}
};
|
|
1071
|
+
|
|
1072
|
+
//#endregion
|
|
1073
|
+
//#region src/gpu/sampler.ts
|
|
1074
|
+
/** Scratch heap buffers reused across sampleToken() calls; grown when a larger top-K is requested. */
let _heapIndices = null;
let _heapValues = null;
/**
 * Sample a token ID from logits.
 *
 * Pipeline: repetition penalty → temperature → top-k (min-heap) → softmax → top-p → sample.
 *
 * @param logits          one raw score per vocabulary id
 * @param params          { temperature = 0.7, topK = 50 (<= 0 keeps every id), topP = 0.9, repetitionPenalty = 1 }
 * @param previousTokens  ids whose scores receive the repetition penalty
 * @returns the sampled id; plain argmax of `logits` when temperature < 1e-6
 */
function sampleToken(logits, params = {}, previousTokens) {
	const temperature = params.temperature ?? .7;
	const topK = params.topK ?? 50;
	const topP = params.topP ?? .9;
	const repetitionPenalty = params.repetitionPenalty ?? 1;
	if (temperature < 1e-6) return argmax(logits);
	const N = logits.length;
	const K = Math.min(topK > 0 ? topK : N, N);
	if (!_heapIndices || _heapIndices.length < K) {
		_heapIndices = new Uint32Array(K);
		_heapValues = new Float32Array(K);
	}
	const hIdx = _heapIndices;
	const hVal = _heapValues;
	let penaltySet = null;
	if (repetitionPenalty !== 1 && previousTokens?.length) penaltySet = new Set(previousTokens);
	// Keep the K best scores in a min-heap: the root is the weakest survivor.
	let heapSize = 0;
	for (let i = 0; i < N; i++) {
		let s = logits[i];
		if (penaltySet?.has(i)) s = s > 0 ? s / repetitionPenalty : s * repetitionPenalty;
		s /= temperature;
		if (heapSize < K) {
			hIdx[heapSize] = i;
			hVal[heapSize] = s;
			heapSize++;
			if (heapSize === K) for (let j = (K >> 1) - 1; j >= 0; j--) siftDown(hIdx, hVal, j, K);
		} else if (s > hVal[0]) {
			hIdx[0] = i;
			hVal[0] = s;
			siftDown(hIdx, hVal, 0, K);
		}
	}
	// Order the survivors by descending score. Array#sort is stable (ES2019), so ties
	// keep heap order exactly as an insertion sort would, but this stays O(K log K)
	// when topK <= 0 makes K the whole vocabulary (an O(K²) sort would not finish).
	const order = Array.from({ length: heapSize }, (_, i) => i).sort((a, b) => hVal[b] - hVal[a]);
	const valsCopy = hVal.slice(0, heapSize);
	const idxCopy = hIdx.slice(0, heapSize);
	for (let r = 0; r < heapSize; r++) {
		hVal[r] = valsCopy[order[r]];
		hIdx[r] = idxCopy[order[r]];
	}
	// Softmax over the sorted survivors (hVal[0] is the maximum).
	const maxScore = hVal[0];
	let sumExp = 0;
	for (let i = 0; i < heapSize; i++) {
		const p = Math.exp(hVal[i] - maxScore);
		hVal[i] = p;
		sumExp += p;
	}
	const invSum = 1 / sumExp;
	for (let i = 0; i < heapSize; i++) hVal[i] *= invSum;
	// Nucleus: shortest prefix reaching topP, renormalized.
	let candidateCount = heapSize;
	if (topP < 1) {
		let cumulative$1 = 0;
		for (let i = 0; i < heapSize; i++) {
			cumulative$1 += hVal[i];
			if (cumulative$1 >= topP) {
				candidateCount = i + 1;
				break;
			}
		}
		let sum = 0;
		for (let i = 0; i < candidateCount; i++) sum += hVal[i];
		const inv = 1 / sum;
		for (let i = 0; i < candidateCount; i++) hVal[i] *= inv;
	}
	const r = Math.random();
	let cumulative = 0;
	for (let i = 0; i < candidateCount; i++) {
		cumulative += hVal[i];
		if (r <= cumulative) return hIdx[i];
	}
	// Floating-point slack: fall back to the last kept candidate.
	return hIdx[candidateCount - 1];
}
|
|
1157
|
+
/**
 * Index of the largest element (greedy decoding). Ties resolve to the earliest
 * index; an empty input yields 0.
 */
function argmax(arr) {
	let winner = 0;
	for (let k = 1; k < arr.length; k++) {
		if (!(arr[k] > arr[winner])) continue;
		winner = k;
	}
	return winner;
}
|
|
1169
|
+
/**
 * Min-heap sift-down over parallel index/value arrays, starting at node `i` and
 * considering only the first `n` entries. Indices are swapped alongside their values.
 */
function siftDown(indices, values, i, n) {
	let parent = i;
	for (;;) {
		const leftChild = parent * 2 + 1;
		const rightChild = leftChild + 1;
		let target = parent;
		if (leftChild < n && values[leftChild] < values[target]) target = leftChild;
		if (rightChild < n && values[rightChild] < values[target]) target = rightChild;
		if (target === parent) return;
		[indices[parent], indices[target]] = [indices[target], indices[parent]];
		[values[parent], values[target]] = [values[target], values[parent]];
		parent = target;
	}
}
|
|
1187
|
+
|
|
1188
|
+
//#endregion
|
|
1189
|
+
//#region src/gpu/vision-executor.ts
|
|
1190
|
+
const MAP_MODE_READ = 1;
|
|
1191
|
+
/** Shared 4-byte scratch used to reinterpret each f32's bits without allocating. */
const _f16scratch = /* @__PURE__ */ new ArrayBuffer(4);
const _f16f32 = new Float32Array(_f16scratch);
const _f16u32 = new Uint32Array(_f16scratch);
/**
* Pack an f32 array to IEEE 754 binary16 (f16) bit patterns.
*
* Rounding is round-to-nearest, ties-to-even (the IEEE default) for both normal
* and subnormal results:
* - magnitudes up to 2^-25 (half the smallest f16 subnormal) become signed zero;
* - values beyond the largest finite f16 become ±Infinity;
* - NaN becomes a quiet f16 NaN (sign kept) instead of collapsing to Infinity.
*
* @param src array-like of numbers (typically a Float32Array)
* @returns Uint16Array of the same length holding the f16 bit patterns
*/
function packF16(src) {
	const out = new Uint16Array(src.length);
	for (let i = 0; i < src.length; i++) {
		_f16f32[0] = src[i];
		const f = _f16u32[0];
		const sign = f >>> 16 & 32768;
		const f32Exp = f >>> 23 & 255;
		const mant = f & 8388607;
		if (f32Exp === 255) {
			// Inf stays Inf; any NaN payload becomes a quiet NaN.
			out[i] = sign | 31744 | (mant !== 0 ? 512 : 0);
			continue;
		}
		// Re-bias the exponent from f32 (127) to f16 (15).
		const exp = f32Exp - 112;
		if (exp >= 31) {
			out[i] = sign | 31744;
		} else if (exp > 0) {
			// Normal f16: keep the top 10 mantissa bits and round on the 13 dropped bits.
			let half = sign | exp << 10 | mant >>> 13;
			const rem = mant & 8191;
			// A mantissa carry correctly bumps the exponent (up to Infinity).
			if (rem > 4096 || rem === 4096 && (half & 1) !== 0) half += 1;
			out[i] = half;
		} else if (exp < -10) {
			// |x| < 2^-25 (includes f32 zeros/subnormals): rounds to signed zero.
			out[i] = sign;
		} else {
			// f16 subnormal: restore the implicit leading 1 and express the value in
			// units of 2^-24, rounding to nearest-even on the shifted-out bits.
			const full = mant | 8388608;
			const shift = 14 - exp;
			let half = full >>> shift;
			const rem = full & (1 << shift) - 1;
			const halfway = 1 << shift - 1;
			if (rem > halfway || rem === halfway && (half & 1) !== 0) half += 1;
			// half may reach 1024 here, which is exactly the smallest normal encoding.
			out[i] = sign | half;
		}
	}
	return out;
}
|
|
1214
|
+
/**
* Query-row window used to split the ViT's bidirectional Attention dispatch on
* WebKit only.
*
* That op runs as a single dispatch over an [N query rows, heads] grid in which
* every row loops over all N keys: O(N²) work in one submission. Beyond a few
* hundred rows this can outlast Metal's GPU watchdog (a few seconds), which kills
* the web-content process in a way indistinguishable from OOM.
*
* The grid is therefore dispatched in windows of this many query rows. Each
* window is submitted and drained with onSubmittedWorkDone() on its own, so no
* single submission runs long enough to trip the watchdog. The math is unchanged:
* every query row computes identically however the grid is partitioned.
*/
const WEBKIT_ATTN_CHUNK_ROWS = 64;
|
|
1225
|
+
/**
* Executes a compiled vision-tower graph on WebGPU: either the Qwen-style ViT +
* merger (`encode`) or the Gemma 4 ViT with a host-built pooling matrix
* (`encodeGemma4`).
*
* Call order (each step depends on the previous one):
* construct → `uploadWeights()` → `initBindGroups()` → `encode()`/`encodeGemma4()`
* per image → `destroy()`.
*
* Activation buffers are sized once for `maxPatches`. The symbolic dims are
* resolved per call: "N" (patches), "Nm" (N / mergeUnit) and "Np" (pooled tokens).
*/
var VisionExecutor = class {
	ctx;
	graph;
	/** Patches folded into one merged token: merger fc1 input width / hidden_size (1 when unknown). */
	mergeUnit;
	/** GPU buffers for constant tensors, keyed by tensor name. */
	weightBuffers = /* @__PURE__ */ new Map();
	/** GPU buffers for activation tensors, keyed by tensor name. */
	activationBuffers = /* @__PURE__ */ new Map();
	/** One entry per graph node, in execution order (filled by initBindGroups). */
	dispatches = [];
	maxPatches;
	/** Weight (B) names of MatMulBias nodes stored as f16 (empty without shader-f16). */
	f16WeightNames = /* @__PURE__ */ new Set();
	/** Runtime pooled-token count for the Gemma 4 ViT ("Np" dim); 0 for Qwen. */
	gemma4Np = 0;
	/** True when this graph is the Gemma 4 vision tower (uses "Np" + PoolMatMul). */
	isGemma4;
	constructor(ctx, graph, maxPatches) {
		this.ctx = ctx;
		this.graph = graph;
		this.maxPatches = maxPatches;
		this.isGemma4 = graph.architecture === "Gemma4VisionModel";
		const mergedIn = graph.tensors[CANONICAL_KEYS.visMergerFc1W]?.shape[1];
		const hidden = graph.config.hidden_size;
		this.mergeUnit = mergedIn && hidden ? mergedIn / hidden : 1;
		// With shader-f16 available, every MatMulBias weight is uploaded packed as f16.
		if (ctx.hasF16) {
			for (const node of graph.nodes) if (node.opType === "MatMulBias" && node.inputs[1]) this.f16WeightNames.add(node.inputs[1]);
		}
		this.allocateActivationBuffers();
	}
	/**
	* Upload every constant tensor of the graph to its own storage buffer.
	* Weights listed in f16WeightNames are packed to f16 first.
	* @throws if `weights` lacks any constant tensor.
	*/
	uploadWeights(weights) {
		for (const [name, desc] of Object.entries(this.graph.tensors)) {
			if (desc.storage !== "constant") continue;
			const w = weights.get(name);
			if (!w) throw new Error(`VisionExecutor: missing weight "${name}"`);
			let data = w.data;
			if (this.f16WeightNames.has(name) && w.data instanceof Float32Array) data = packF16(w.data);
			const buffer = createStorageBuffer(this.ctx, `vweight_${name}`, data.byteLength, data);
			this.weightBuffers.set(name, buffer);
		}
	}
	/**
	* Create one pipeline, uniform buffer and bind group per node, in graph
	* execution order. Requires uploadWeights() to have run, because bind groups
	* reference the weight buffers.
	*/
	initBindGroups() {
		const dummyShapes = this.resolveShapes(this.maxPatches);
		// Index nodes once instead of a linear find() per id. First occurrence
		// wins, matching Array#find on duplicate ids.
		const nodesById = /* @__PURE__ */ new Map();
		for (const n of this.graph.nodes) if (!nodesById.has(n.id)) nodesById.set(n.id, n);
		for (const nodeId of this.graph.executionOrder) {
			const node = nodesById.get(nodeId);
			if (!node) throw new Error(`VisionExecutor: executionOrder references unknown node "${nodeId}"`);
			let spec = KERNEL_REGISTRY[node.opType];
			if (!spec) throw new Error(`VisionExecutor: no kernel for op "${node.opType}"`);
			if (node.opType === "MatMulBias" && this.f16WeightNames.has(node.inputs[1])) spec = MATMUL_BIAS_F16C_SPEC;
			const pipeline = getOrCreatePipeline(this.ctx, `vkernel_${nodeId}`, spec.shaderCode, spec.entryPoint);
			const paramsData = spec.buildParams(node, dummyShapes, { seqPos: 0 });
			const uniformBuffer = createUniformBuffer(this.ctx, `vuniform_${nodeId}`, paramsData);
			const bufferEntries = this.gatherBuffers(spec, node, uniformBuffer);
			const bindGroup = createBindGroup(this.ctx, pipeline, bufferEntries, `vbg_${nodeId}`);
			this.dispatches.push({
				node,
				spec,
				pipeline,
				bindGroup,
				uniformBuffer
			});
		}
	}
	/**
	* Encode patches → merged image embeddings [Nm, out_hidden_size].
	*
	* `onStage` (optional) fires coarse phase breadcrumbs during the WebKit path so
	* a host can localize a GPU-process crash to a specific layer.
	*
	* @throws if numPatches exceeds maxPatches or is not a multiple of mergeUnit.
	*/
	async encode(inputs, onStage) {
		const N = inputs.numPatches;
		if (N > this.maxPatches) throw new Error(`VisionExecutor: ${N} patches exceeds maxPatches=${this.maxPatches}`);
		// Output rows are merged tokens; a fractional count would give a malformed readback size.
		const rows = N / this.mergeUnit;
		if (!Number.isInteger(rows)) throw new Error(`VisionExecutor: ${N} patches is not a multiple of mergeUnit=${this.mergeUnit}`);
		const q = this.ctx.device.queue;
		q.writeBuffer(this.activationBuffers.get("vis_patches"), 0, inputs.patches);
		q.writeBuffer(this.activationBuffers.get("vis_pos_embeds"), 0, inputs.posEmbeds);
		q.writeBuffer(this.activationBuffers.get("vis_cos"), 0, inputs.cos);
		q.writeBuffer(this.activationBuffers.get("vis_sin"), 0, inputs.sin);
		const resolvedShapes = this.resolveShapes(N);
		const sizes = this.writeDispatchParams(resolvedShapes);
		await this.runDispatches(sizes, resolvedShapes, onStage, "vision");
		const dim = this.graph.tensors.vis_image_embeds.shape[1];
		const embeds = await this.readbackF32(this.activationBuffers.get("vis_image_embeds"), rows * dim * 4);
		return {
			embeds,
			rows,
			dim
		};
	}
	/**
	* Encode patches through the Gemma 4 ViT → projected image tokens [Np, text_hidden].
	*
	* Distinct from the Qwen `encode()`: the Gemma graph has 5 inputs (patches,
	* axial pos-embeds, axial rotary cos/sin, and a host-built [Np,N] pooling matrix),
	* and its output rows (Np) are the pooled soft-token count. That count is taken
	* from the pooling matrix rather than an N/mergeUnit ratio. Reuses the same
	* dispatch machinery and WebKit per-dispatch-drain discipline.
	*
	* @throws if numPatches exceeds maxPatches or numPooled exceeds maxPooled().
	*/
	async encodeGemma4(inputs, onStage) {
		const N = inputs.numPatches;
		if (N > this.maxPatches) throw new Error(`VisionExecutor(Gemma4): ${N} patches exceeds maxPatches=${this.maxPatches}`);
		// "Np"-shaped activations were allocated for maxPooled() rows; more would overrun them.
		if (inputs.numPooled > this.maxPooled()) throw new Error(`VisionExecutor(Gemma4): ${inputs.numPooled} pooled tokens exceeds maxPooled=${this.maxPooled()}`);
		this.gemma4Np = inputs.numPooled;
		const q = this.ctx.device.queue;
		q.writeBuffer(this.activationBuffers.get("g4v_patches"), 0, inputs.patches);
		q.writeBuffer(this.activationBuffers.get("g4v_pos_embeds"), 0, inputs.posEmbeds);
		q.writeBuffer(this.activationBuffers.get("g4v_cos"), 0, inputs.cos);
		q.writeBuffer(this.activationBuffers.get("g4v_sin"), 0, inputs.sin);
		q.writeBuffer(this.activationBuffers.get("g4v_pool_w"), 0, inputs.poolMatrix);
		const resolvedShapes = this.resolveShapes(N);
		const sizes = this.writeDispatchParams(resolvedShapes);
		await this.runDispatches(sizes, resolvedShapes, onStage, "gemma4_vision");
		const rows = inputs.numPooled;
		const dim = this.graph.tensors.g4v_image_embeds.shape[1];
		const embeds = await this.readbackF32(this.activationBuffers.get("g4v_image_embeds"), rows * dim * 4);
		return {
			embeds,
			rows,
			dim
		};
	}
	/** True if this executor is the Gemma 4 vision tower. */
	get gemma4() {
		return this.isGemma4;
	}
	/** Read back any named activation (debug). Must be called right after encode(). */
	async debugReadBuffer(name, maxElements) {
		const buffer = this.activationBuffers.get(name) ?? this.weightBuffers.get(name);
		if (!buffer) throw new Error(`VisionExecutor: no buffer "${name}"`);
		const byteLen = maxElements ? Math.min(maxElements * 4, buffer.size) : buffer.size;
		return this.readbackF32(buffer, byteLen);
	}
	/** Release every GPU buffer owned by this executor and forget the dispatches. */
	destroy() {
		destroyBuffers([...this.weightBuffers.values()]);
		destroyBuffers([...this.activationBuffers.values()]);
		for (const d of this.dispatches) d.uniformBuffer.destroy();
		this.weightBuffers.clear();
		this.activationBuffers.clear();
		this.dispatches = [];
	}
	/** Max pooled tokens for buffer sizing: maxPatches with no merge/pool ratio applied. */
	maxPooled() {
		return this.maxPatches;
	}
	/** Map every tensor's symbolic shape ("N", "Nm", "Np") to concrete sizes for N patches. */
	resolveShapes(N) {
		const Nm = N / this.mergeUnit;
		const Np = this.gemma4Np || N;
		const resolved = {};
		for (const [name, desc] of Object.entries(this.graph.tensors)) resolved[name] = desc.shape.map((d) => {
			if (d === "N") return N;
			if (d === "Nm") return Nm;
			if (d === "Np") return Np;
			return d;
		});
		return resolved;
	}
	/** Allocate one storage buffer per activation tensor, sized for the maximum patch count. */
	allocateActivationBuffers() {
		for (const [name, desc] of Object.entries(this.graph.tensors)) {
			if (desc.storage !== "activation") continue;
			const bytes = desc.shape.map((d) => {
				if (d === "N") return this.maxPatches;
				if (d === "Nm") return this.maxPatches / this.mergeUnit;
				if (d === "Np") return this.maxPooled();
				return d;
			}).reduce((a, b) => a * b, 1) * DTYPE_BYTES[desc.dtype];
			this.activationBuffers.set(name, createStorageBuffer(this.ctx, `vact_${name}`, bytes));
		}
	}
	/**
	* Build bind-group entries in the kernel's binding order. The uniform slot gets
	* `uniformBuffer`; tensor slots consume node.inputs then node.outputs in order.
	*/
	gatherBuffers(spec, node, uniformBuffer) {
		const entries = [];
		const tensorNames = [...node.inputs, ...node.outputs];
		let tensorIdx = 0;
		for (const binding of spec.bindings) if (binding.type === "uniform") entries.push({ buffer: uniformBuffer });
		else {
			const tensorName = tensorNames[tensorIdx++];
			const buffer = this.weightBuffers.get(tensorName) ?? this.activationBuffers.get(tensorName);
			if (!buffer) throw new Error(`VisionExecutor: no buffer for "${tensorName}" in op ${node.id}`);
			entries.push({ buffer });
		}
		return entries;
	}
	/**
	* Write every dispatch's uniform params for `resolvedShapes` and return the
	* workgroup grid of each dispatch (same order as this.dispatches).
	*/
	writeDispatchParams(resolvedShapes) {
		const q = this.ctx.device.queue;
		const ctxRt = { seqPos: 0 };
		return this.dispatches.map((d) => {
			q.writeBuffer(d.uniformBuffer, 0, d.spec.buildParams(d.node, resolvedShapes, ctxRt));
			return d.spec.getDispatchSize(d.node, resolvedShapes, ctxRt);
		});
	}
	/**
	* Run all dispatches.
	*
	* Off WebKit: a single labelled compute pass and one submit (not awaited).
	*
	* On WebKit: every dispatch gets its own submit + drain. Attention is further
	* split into WEBKIT_ATTN_CHUNK_ROWS-wide windows along grid X, each using a
	* `qOffset` uniform. `label` names the encoder/pass ("<label>_encode" / "<label>_pass").
	*/
	async runDispatches(sizes, resolvedShapes, onStage, label) {
		const q = this.ctx.device.queue;
		if (!this.ctx.isWebKitWebGPU) {
			const enc = this.ctx.device.createCommandEncoder({ label: `${label}_encode` });
			const pass = enc.beginComputePass({ label: `${label}_pass` });
			for (let i = 0; i < this.dispatches.length; i++) {
				pass.setPipeline(this.dispatches[i].pipeline);
				pass.setBindGroup(0, this.dispatches[i].bindGroup);
				pass.dispatchWorkgroups(...sizes[i]);
			}
			pass.end();
			q.submit([enc.finish()]);
			return;
		}
		onStage?.("vit-encode-start", { total: this.dispatches.length });
		for (let i = 0; i < this.dispatches.length; i++) {
			const d = this.dispatches[i];
			if (d.node.opType !== "Attention") {
				await this.submitAndDrain(d, sizes[i]);
				continue;
			}
			const [gridX, gridY, gridZ] = sizes[i];
			const layer = Number.parseInt(d.node.id.replace(/\D+/g, ""), 10);
			if (!Number.isNaN(layer)) onStage?.("vit-attn", { layer });
			for (let start = 0; start < gridX; start += WEBKIT_ATTN_CHUNK_ROWS) {
				// queue.writeBuffer is ordered before the following submit, so each window sees its own offset.
				q.writeBuffer(d.uniformBuffer, 0, d.spec.buildParams(d.node, resolvedShapes, {
					seqPos: 0,
					qOffset: start
				}));
				await this.submitAndDrain(d, [
					Math.min(WEBKIT_ATTN_CHUNK_ROWS, gridX - start),
					gridY,
					gridZ
				]);
			}
		}
		onStage?.("vit-encode-done");
	}
	/** Submit a single dispatch of `d` with workgroup grid `grid` and wait for the queue to drain. */
	async submitAndDrain(d, grid) {
		const q = this.ctx.device.queue;
		const enc = this.ctx.device.createCommandEncoder();
		const pass = enc.beginComputePass();
		pass.setPipeline(d.pipeline);
		pass.setBindGroup(0, d.bindGroup);
		pass.dispatchWorkgroups(...grid);
		pass.end();
		q.submit([enc.finish()]);
		await q.onSubmittedWorkDone();
	}
	/**
	* Copy the first `byteLen` bytes of `buffer` into a new Float32Array.
	* The staging buffer is destroyed even if mapping fails.
	*/
	async readbackF32(buffer, byteLen) {
		const device = this.ctx.device;
		// 9 = GPUBufferUsage.MAP_READ (0x1) | GPUBufferUsage.COPY_DST (0x8).
		const readback = device.createBuffer({
			size: byteLen,
			usage: 9
		});
		try {
			const enc = device.createCommandEncoder();
			enc.copyBufferToBuffer(buffer, 0, readback, 0, byteLen);
			device.queue.submit([enc.finish()]);
			await readback.mapAsync(MAP_MODE_READ, 0, byteLen);
			// slice() copies out of the mapped range before unmap() detaches it.
			const data = new Float32Array(readback.getMappedRange(0, byteLen).slice(0));
			readback.unmap();
			return data;
		} finally {
			readback.destroy();
		}
	}
};
|
|
1541
|
+
|
|
1542
|
+
//#endregion
|
|
1543
|
+
//#region src/gpu/vision-preprocess.ts
|
|
1544
|
+
/**
* `count` evenly spaced samples over [0, stop], endpoints included, like
* torch.linspace(0, stop, count). A single sample sits at 0.
*/
function linspace(stop, count) {
	const samples = new Float64Array(count);
	// The typed array is zero-initialised, so the lone sample is already 0.
	if (count === 1) return samples;
	const stride = stop / (count - 1);
	for (let idx = 0; idx < count; idx++) samples[idx] = idx * stride;
	return samples;
}
|
|
1555
|
+
/**
* Spatial-merge reorder index for an h×w patch grid.
*
* Slot k of the result holds the row-major source patch index (row * w + col).
* Each merge×merge block is laid out contiguously, with blocks visited in
* row-major order. This matches the `reorder` in
* get_vision_bilinear_indices_and_weights and the position-id reshape.
*/
function spatialMergeReorder(h, w, merge) {
	const order = new Int32Array(h * w);
	const blocksTall = h / merge;
	const blocksWide = w / merge;
	let slot = 0;
	for (let by = 0; by < blocksTall; by++) {
		for (let bx = 0; bx < blocksWide; bx++) {
			// Row-major index of this block's top-left patch.
			const topLeft = by * merge * w + bx * merge;
			for (let dy = 0; dy < merge; dy++) {
				for (let dx = 0; dx < merge; dx++) order[slot++] = topLeft + dy * w + dx;
			}
		}
	}
	return order;
}
|
|
1573
|
+
/**
* Bilinearly interpolate the learned position table onto a t×h×w patch grid.
*
* Returns [t·h·w, hidden] f32 embeddings in spatial-merge order; every temporal
* frame gets the same values. `posEmbedTable` is pos_embed.weight flattened: a
* side×side grid of `hidden`-wide rows, with side = round(sqrt(numPositionEmbeddings)).
*/
function buildPosEmbeds(gridTHW, posEmbedTable, cfg) {
	const [t, h, w] = gridTHW;
	const hidden = cfg.hiddenSize;
	const side = Math.round(Math.sqrt(cfg.numPositionEmbeddings));
	const last = side - 1;
	// Fractional sample coordinates along each axis of the side×side table.
	const rowCoords = linspace(last, h);
	const colCoords = linspace(last, w);
	const hw = h * w;
	const order = spatialMergeReorder(h, w, cfg.spatialMergeSize);
	const out = new Float32Array(t * hw * hidden);
	for (let k = 0; k < hw; k++) {
		// Row-major source patch that lands in output slot k.
		const src = order[k];
		const r = Math.floor(src / w);
		const c = src % w;
		const r0 = Math.trunc(rowCoords[r]);
		const c0 = Math.trunc(colCoords[c]);
		const r1 = Math.min(r0 + 1, last);
		const c1 = Math.min(c0 + 1, last);
		const dr = rowCoords[r] - r0;
		const dc = colCoords[c] - c0;
		// [table row, weight] for the four neighbours: (r0,c0), (r0,c1), (r1,c0), (r1,c1).
		const corners = [
			[r0 * side + c0, (1 - dr) * (1 - dc)],
			[r0 * side + c1, (1 - dr) * dc],
			[r1 * side + c0, dr * (1 - dc)],
			[r1 * side + c1, dr * dc]
		];
		for (let frame = 0; frame < t; frame++) {
			const dst = (frame * hw + k) * hidden;
			for (const [tableRow, wgt] of corners) {
				if (wgt === 0) continue;
				const rowStart = tableRow * hidden;
				for (let d = 0; d < hidden; d++) out[dst + d] += posEmbedTable[rowStart + d] * wgt;
			}
		}
	}
	return out;
}
|
|
1645
|
+
/**
* Reordered (row, col) position ids, flattened as [N, 2], for the vision rotary
* embedding (matches get_vision_position_ids).
*
* Patches are listed merge-block by merge-block, like spatialMergeReorder, and
* the same per-frame ids repeat for each of the t frames.
*/
function buildPositionIds(gridTHW, merge) {
	const [t, h, w] = gridTHW;
	const hw = h * w;
	const frame = new Int32Array(hw * 2);
	let o = 0;
	for (let by = 0; by < h / merge; by++) {
		for (let bx = 0; bx < w / merge; bx++) {
			for (let dy = 0; dy < merge; dy++) {
				for (let dx = 0; dx < merge; dx++) {
					frame[o++] = by * merge + dy;
					frame[o++] = bx * merge + dx;
				}
			}
		}
	}
	const ids = new Int32Array(t * hw * 2);
	for (let f = 0; f < t; f++) ids.set(frame, f * frame.length);
	return ids;
}
|
|
1665
|
+
/**
* Rotary cos/sin tables [N, head_dim] built from (h, w) position ids. Matches
* Qwen3_5VisionRotaryEmbedding plus the cat((rotary, rotary)) in VisionModel.forward.
*
* inv_freq is computed over dim = head_dim/2 and so has head_dim/4 entries. For
* each token, the h and w positions each produce head_dim/4 angles. These are
* concatenated to [h | w] (head_dim/2 values), and that half is repeated once to
* fill head_dim: layout [h, w, h, w].
*
* @returns {{cos: Float32Array, sin: Float32Array, numPatches: number}}
*/
function buildRotaryCosSin(positionIds, headDim, theta = 1e4) {
	const rotaryDim = headDim / 2;
	const quarter = rotaryDim / 2;
	const invFreq = new Float64Array(quarter);
	for (let i = 0; i < quarter; i++) invFreq[i] = 1 / theta ** (2 * i / rotaryDim);
	const N = positionIds.length / 2;
	const cos = new Float32Array(N * headDim);
	const sin = new Float32Array(N * headDim);
	for (let n = 0; n < N; n++) {
		const hPos = positionIds[2 * n];
		const wPos = positionIds[2 * n + 1];
		const row = n * headDim;
		for (let i = 0; i < quarter; i++) {
			const angleH = hPos * invFreq[i];
			const angleW = wPos * invFreq[i];
			const cosH = Math.cos(angleH);
			const sinH = Math.sin(angleH);
			const cosW = Math.cos(angleW);
			const sinW = Math.sin(angleW);
			// First copy: [h | w].
			cos[row + i] = cosH;
			sin[row + i] = sinH;
			cos[row + quarter + i] = cosW;
			sin[row + quarter + i] = sinW;
			// Second copy, offset by rotaryDim.
			cos[row + rotaryDim + i] = cosH;
			sin[row + rotaryDim + i] = sinH;
			cos[row + rotaryDim + quarter + i] = cosW;
			sin[row + rotaryDim + quarter + i] = sinW;
		}
	}
	return {
		cos,
		sin,
		numPatches: N
	};
}
|
|
1709
|
+
/**
* Build every host-side position tensor for one image grid: interpolated
* learned position embeddings plus rotary cos/sin (and the patch count).
* head_dim is floor(hiddenSize / numHeads); rope theta falls back to 1e4.
*/
function buildVisionPositionTensors(gridTHW, posEmbedTable, cfg) {
	const headDim = Math.floor(cfg.hiddenSize / cfg.numHeads);
	const posEmbeds = buildPosEmbeds(gridTHW, posEmbedTable, cfg);
	const positionIds = buildPositionIds(gridTHW, cfg.spatialMergeSize);
	const rotary = buildRotaryCosSin(positionIds, headDim, cfg.ropeTheta ?? 1e4);
	return {
		posEmbeds,
		cos: rotary.cos,
		sin: rotary.sin,
		numPatches: rotary.numPatches
	};
}
|
|
1723
|
+
/**
* Row-major (x, y) coordinates of every patch in a gridH×gridW grid, matching the
* Gemma4 image processor's meshgrid(arange(width), arange(height)) reshape.
* Patch index p = row * gridW + col has x = col and y = row.
*/
function gemma4PatchXY(gridH, gridW) {
	const total = gridH * gridW;
	const x = new Int32Array(total);
	const y = new Int32Array(total);
	let idx = 0;
	for (let row = 0; row < gridH; row++) {
		for (let col = 0; col < gridW; col++, idx++) {
			x[idx] = col;
			y[idx] = row;
		}
	}
	return {
		x,
		y
	};
}
|
|
1743
|
+
/**
* Build axial learned position embeddings [N, hidden] from the [2, posSize, hidden]
* table: pos[p] = table[0][x_p] + table[1][y_p]. This is a direct lookup with no
* interpolation (HF F.embedding on clamped positions).
*
* @param gridH number of patch rows
* @param gridW number of patch columns
* @param posEmbedTable flattened [2, posSize, hidden] table (x plane, then y plane)
* @param hidden embedding width
* @param posSize rows per plane in the table
* @returns Float32Array of length gridH·gridW·hidden
* @throws {RangeError} if a grid dimension needs a position ≥ posSize, or if the
*   table holds fewer than 2·posSize·hidden values. Either case would otherwise
*   silently read the wrong plane or past the end of the table (NaN).
*/
function buildGemma4PosEmbeds(gridH, gridW, posEmbedTable, hidden, posSize) {
	if (gridW > posSize || gridH > posSize) throw new RangeError(`buildGemma4PosEmbeds: grid ${gridH}x${gridW} exceeds position table size ${posSize}`);
	if (posEmbedTable.length < 2 * posSize * hidden) throw new RangeError(`buildGemma4PosEmbeds: table has ${posEmbedTable.length} values, expected at least ${2 * posSize * hidden}`);
	const { x, y } = gemma4PatchXY(gridH, gridW);
	const n = gridH * gridW;
	const out = new Float32Array(n * hidden);
	const yPlane = posSize * hidden;
	for (let p = 0; p < n; p++) {
		// Lower clamp mirrors the reference; gemma4PatchXY itself never yields negatives.
		const xBase = Math.max(0, x[p]) * hidden;
		const yBase = yPlane + Math.max(0, y[p]) * hidden;
		const dst = p * hidden;
		for (let d = 0; d < hidden; d++) out[dst + d] = posEmbedTable[xBase + d] + posEmbedTable[yBase + d];
	}
	return out;
}
|
|
1763
|
+
/**
* 2D axial rotary cos/sin tables [N, headDim] for the Gemma 4 ViT.
*
* headDim is split into two spatialDim = headDim/2 halves: the first encodes x
* (column), the second y (row). Each half uses spatialDim/2 frequencies
* inv_freq[j] = 1/theta^(2j/spatialDim), with angle = position · inv_freq[j],
* written twice ([f, f]). The result is the layout [fx, fx, fy, fy].
*
* That layout is what the global-half rotate_half kernel (ApplyRotaryEmb) needs:
* out = x*cos + rotate_half(x)*sin element-wise.
*/
function buildGemma4RotaryCosSin(gridH, gridW, headDim, theta) {
	const { x, y } = gemma4PatchXY(gridH, gridW);
	const count = gridH * gridW;
	const spatialDim = headDim / 2;
	const quarter = spatialDim / 2;
	const invFreq = new Float64Array(quarter);
	for (let j = 0; j < quarter; j++) invFreq[j] = 1 / theta ** (2 * j / spatialDim);
	const cos = new Float32Array(count * headDim);
	const sin = new Float32Array(count * headDim);
	for (let p = 0; p < count; p++) {
		const col = Math.max(0, x[p]);
		const row = Math.max(0, y[p]);
		const start = p * headDim;
		for (let j = 0; j < quarter; j++) {
			const angleX = col * invFreq[j];
			const angleY = row * invFreq[j];
			const cosX = Math.cos(angleX);
			const sinX = Math.sin(angleX);
			const cosY = Math.cos(angleY);
			const sinY = Math.sin(angleY);
			// x half: [fx, fx]
			cos[start + j] = cosX;
			sin[start + j] = sinX;
			cos[start + quarter + j] = cosX;
			sin[start + quarter + j] = sinX;
			// y half: [fy, fy]
			cos[start + spatialDim + j] = cosY;
			sin[start + spatialDim + j] = sinY;
			cos[start + spatialDim + quarter + j] = cosY;
			sin[start + spatialDim + quarter + j] = sinY;
		}
	}
	return {
		cos,
		sin
	};
}
|
|
1806
|
+
/**
* Row-major [Np, N] average-pooling matrix for k×k spatial pooling over the real
* (unpadded) patch grid, mirroring modeling_gemma4's kernel_idxs/one_hot pooling.
*
* Patch p goes to cell floor(x_p/k) + ceil(gridW/k)·floor(y_p/k) with weight 1/k²,
* so poolMatrix @ hidden averages each k×k block. Edge cells holding fewer than k²
* patches still divide by k², as in HF. Np = ceil(gridH/k)·ceil(gridW/k).
*
* @returns {{poolMatrix: Float32Array, numPooled: number}}
*/
function buildGemma4PoolMatrix(gridH, gridW, k) {
	const n = gridH * gridW;
	const cellsPerRow = Math.ceil(gridW / k);
	const numPooled = cellsPerRow * Math.ceil(gridH / k);
	const weight = 1 / (k * k);
	const poolMatrix = new Float32Array(numPooled * n);
	const { x, y } = gemma4PatchXY(gridH, gridW);
	for (let p = 0; p < n; p++) {
		const cellCol = Math.floor(x[p] / k);
		const cellRow = Math.floor(y[p] / k);
		poolMatrix[(cellCol + cellsPerRow * cellRow) * n + p] = weight;
	}
	return {
		poolMatrix,
		numPooled
	};
}
|
|
1831
|
+
/**
* Build every host-side tensor the Gemma 4 vision graph needs for one image grid.
* That is axial position embeddings, axial rotary cos/sin, and the pooling matrix,
* plus the patch and pooled-token counts.
* `posEmbedTable` is the raw [2, posSize, hidden] table, flattened.
*/
function buildGemma4VisionPositionTensors(gridH, gridW, posEmbedTable, posSize, cfg) {
	const posEmbeds = buildGemma4PosEmbeds(gridH, gridW, posEmbedTable, cfg.hiddenSize, posSize);
	const rotary = buildGemma4RotaryCosSin(gridH, gridW, cfg.headDim, cfg.ropeTheta);
	const pooling = buildGemma4PoolMatrix(gridH, gridW, cfg.poolingKernelSize);
	return {
		posEmbeds,
		cos: rotary.cos,
		sin: rotary.sin,
		poolMatrix: pooling.poolMatrix,
		numPatches: gridH * gridW,
		numPooled: pooling.numPooled
	};
}
|
|
1848
|
+
/**
* Gemma 4 image-processor settings, mirroring processor_config.json.
* Mean 0 / std 1 make normalization a no-op, so pixels are only rescaled by 1/255.
* maxPixels allows 2520 patches of 16×16 px.
*/
const GEMMA4_IMAGE_PROCESSOR = {
	patchSize: 16,
	temporalPatchSize: 1,
	mergeSize: 1,
	imageMean: [0, 0, 0],
	imageStd: [1, 1, 1],
	rescaleFactor: 1 / 255,
	minPixels: 256,
	maxPixels: 2520 * 16 * 16
};
|
|
1867
|
+
/**
* Preprocess a decoded RGB image for the Gemma 4 ViT.
*
* 1. Resize, preserving aspect ratio, so both sides are multiples of
*    k·patchSize (k = poolingKernelSize), each at least k·patchSize, and the
*    patch count is ≤ maxSoftTokens·k².
* 2. Rescale by 1/255 (no mean/std normalization).
* 3. Patchify row-major into [N, 3·patchSize·patchSize]. Within a patch the
*    layout is channel, then row, then column.
*
* @param pixels row-major HWC RGB (0..255), length width*height*3.
* @returns {{patches: Float32Array, gridHW: [number, number]}}
*/
function preprocessImageGemma4(pixels, width, height, maxSoftTokens = 280, poolingKernelSize = 3, patchSize = 16) {
	const sideMultiple = poolingKernelSize * patchSize;
	const patchBudget = maxSoftTokens * poolingKernelSize * poolingKernelSize;
	const snap = (len) => Math.max(sideMultiple, Math.floor(len / sideMultiple) * sideMultiple);
	const patchCount = (hPx, wPx) => hPx / patchSize * (wPx / patchSize);
	let targetH = snap(height);
	let targetW = snap(width);
	// Shrink both sides by the same factor until the patch budget fits
	// (or both sides already sit at the minimum size).
	while (patchCount(targetH, targetW) > patchBudget) {
		const shrink = Math.sqrt(patchCount(targetH, targetW) / patchBudget);
		targetH = snap(targetH / shrink);
		targetW = snap(targetW / shrink);
		if (targetH <= sideMultiple && targetW <= sideMultiple) break;
	}
	const resized = bilinearResize(pixels, height, width, targetH, targetW);
	const gridH = targetH / patchSize;
	const gridW = targetW / patchSize;
	const patchDim = 3 * patchSize * patchSize;
	const patches = new Float32Array(gridH * gridW * patchDim);
	const rescale = 1 / 255;
	let o = 0;
	for (let pr = 0; pr < gridH; pr++) {
		for (let pc = 0; pc < gridW; pc++) {
			for (let c = 0; c < 3; c++) {
				for (let py = 0; py < patchSize; py++) {
					// Byte offset (HWC) of this patch row's first pixel.
					const rowStart = ((pr * patchSize + py) * targetW + pc * patchSize) * 3;
					for (let px = 0; px < patchSize; px++) patches[o++] = resized[rowStart + px * 3 + c] * rescale;
				}
			}
		}
	}
	return {
		patches,
		gridHW: [gridH, gridW]
	};
}
|
|
1913
|
+
/**
 * Qwen3.5 image-processor settings: 16px patches, temporal patch size 2, 2×2
 * spatial merge, mean/std 0.5 per channel (maps 0..1 to -1..1), and a pixel
 * budget between 256² and 4096² pixels.
 */
const QWEN3_5_IMAGE_PROCESSOR = {
	patchSize: 16,
	temporalPatchSize: 2,
	mergeSize: 2,
	imageMean: [.5, .5, .5],
	imageStd: [.5, .5, .5],
	rescaleFactor: 1 / 255,
	minPixels: 65536,
	maxPixels: 16777216
};
|
|
1931
|
+
/**
 * Qwen2-VL smart-resize: round H and W to multiples of factor=patch*merge,
 * keeping aspect ratio and clamping the total pixel budget to [minPixels, maxPixels].
 * Matches transformers.models.qwen2_vl.image_processing.smart_resize.
 *
 * The initial rounding uses round-half-to-even, because the reference uses
 * Python's built-in round(). Math.round rounds .5 up, which disagrees on exact
 * half quotients (e.g. height=80, factor=32: Python → 64, Math.round → 96).
 *
 * Unlike the reference (which rejects them), inputs with a side smaller than
 * `factor` are first upscaled, aspect preserved, so the shorter side reaches `factor`.
 *
 * @returns [hBar, wBar] target height and width, both multiples of `factor`.
 */
function smartResize(height, width, factor, minPixels, maxPixels) {
	// Python-compatible round(): ties go to the even neighbour.
	const roundHalfEven = (x) => {
		const lower = Math.floor(x);
		const frac = x - lower;
		if (frac > .5) return lower + 1;
		if (frac < .5) return lower;
		return lower % 2 === 0 ? lower : lower + 1;
	};
	let h = height;
	let w = width;
	if (h < factor || w < factor) {
		const scale = factor / Math.min(h, w);
		h = Math.max(factor, Math.round(h * scale));
		w = Math.max(factor, Math.round(w * scale));
	}
	const roundByFactor = (n) => roundHalfEven(n / factor) * factor;
	const floorByFactor = (n) => Math.floor(n / factor) * factor;
	const ceilByFactor = (n) => Math.ceil(n / factor) * factor;
	let hBar = Math.max(factor, roundByFactor(h));
	let wBar = Math.max(factor, roundByFactor(w));
	if (hBar * wBar > maxPixels) {
		const beta = Math.sqrt(h * w / maxPixels);
		hBar = Math.max(factor, floorByFactor(h / beta));
		wBar = Math.max(factor, floorByFactor(w / beta));
	} else if (hBar * wBar < minPixels) {
		const beta = Math.sqrt(minPixels / (h * w));
		hBar = ceilByFactor(h * beta);
		wBar = ceilByFactor(w * beta);
	}
	return [hBar, wBar];
}
|
|
1958
|
+
/**
 * Resample a row-major [H, W, 3] RGB image to (outH, outW) with bilinear
 * interpolation. Sample points are half-pixel centred and clamped to the
 * source edges, approximating PIL/torchvision BILINEAR. Works for any value
 * range (0..255 or 0..1).
 *
 * @returns Float32Array of length outH*outW*3.
 */
function bilinearResize(pixels, inH, inW, outH, outW) {
	const dst = new Float32Array(outH * outW * 3);
	const ratioY = inH / outH;
	const ratioX = inW / outW;
	const clampTo = (v, hi) => Math.min(Math.max(v, 0), hi);
	for (let oy = 0; oy < outH; oy++) {
		const srcY = clampTo((oy + .5) * ratioY - .5, inH - 1);
		const yTop = Math.floor(srcY);
		const yBottom = Math.min(yTop + 1, inH - 1);
		const wy = srcY - yTop;
		for (let ox = 0; ox < outW; ox++) {
			const srcX = clampTo((ox + .5) * ratioX - .5, inW - 1);
			const xLeft = Math.floor(srcX);
			const xRight = Math.min(xLeft + 1, inW - 1);
			const wx = srcX - xLeft;
			const topLeft = (yTop * inW + xLeft) * 3;
			const topRight = (yTop * inW + xRight) * 3;
			const bottomLeft = (yBottom * inW + xLeft) * 3;
			const bottomRight = (yBottom * inW + xRight) * 3;
			const target = (oy * outW + ox) * 3;
			for (let c = 0; c < 3; c++) {
				const upper = pixels[topLeft + c] * (1 - wx) + pixels[topRight + c] * wx;
				const lower = pixels[bottomLeft + c] * (1 - wx) + pixels[bottomRight + c] * wx;
				dst[target + c] = upper * (1 - wy) + lower * wy;
			}
		}
	}
	return dst;
}
|
|
1995
|
+
/**
 * Turn a decoded RGB image into the [N, C·T·P·P] patch tensor and grid_thw that
 * `encodeImage()` expects, following the HF Qwen2-VL image processor:
 * smart_resize → rescale (×1/255) → normalize → repeat the frame
 * temporal_patch_size times → group patches into spatial_merge×spatial_merge
 * blocks → flatten.
 *
 * @param pixels row-major HWC RGB, length width*height*3. Values 0..255 (default)
 *   or already 0..1 if `rescaled` is true.
 * @param width source pixel width
 * @param height source pixel height
 * @param cfg processor settings (defaults to QWEN3_5_IMAGE_PROCESSOR)
 * @param rescaled skip the ×rescaleFactor step when pixels are already 0..1
 * @returns {{patches: Float32Array, gridTHW: [number, number, number]}}
 */
function preprocessImage(pixels, width, height, cfg = QWEN3_5_IMAGE_PROCESSOR, rescaled = false) {
	const ps = cfg.patchSize;
	const merge = cfg.mergeSize;
	const [outH, outW] = smartResize(height, width, ps * merge, cfg.minPixels, cfg.maxPixels);
	const resized = bilinearResize(pixels, height, width, outH, outW);
	const mean = [...cfg.imageMean];
	const std = [...cfg.imageStd];
	const scale = rescaled ? 1 : cfg.rescaleFactor;
	// Normalize into planar CHW so each patch row is a contiguous slice.
	const plane = outH * outW;
	const chw = new Float32Array(3 * plane);
	for (let p = 0; p < plane; p++) {
		for (let c = 0; c < 3; c++) {
			chw[c * plane + p] = (resized[p * 3 + c] * scale - mean[c]) / std[c];
		}
	}
	// A still image is a single frame (t = 1), duplicated across temporal slots.
	const t = 1;
	const tps = cfg.temporalPatchSize;
	const gridH = outH / ps;
	const gridW = outW / ps;
	const blocksH = gridH / merge;
	const blocksW = gridW / merge;
	const patchDim = 3 * tps * ps * ps;
	const patches = new Float32Array(t * gridH * gridW * patchDim);
	let k = 0;
	for (let frame = 0; frame < t; frame++) {
		for (let by = 0; by < blocksH; by++) {
			for (let bx = 0; bx < blocksW; bx++) {
				for (let my = 0; my < merge; my++) {
					for (let mx = 0; mx < merge; mx++) {
						const y0 = (by * merge + my) * ps;
						const x0 = (bx * merge + mx) * ps;
						for (let c = 0; c < 3; c++) {
							for (let tt = 0; tt < tps; tt++) {
								for (let py = 0; py < ps; py++) {
									const rowBase = c * plane + (y0 + py) * outW + x0;
									for (let px = 0; px < ps; px++) patches[k++] = chw[rowBase + px];
								}
							}
						}
					}
				}
			}
		}
	}
	return {
		patches,
		gridTHW: [t, gridH, gridW]
	};
}
|
|
2065
|
+
/**
 * Fill the 3D position rows for one image run starting at sequence index `at`
 * and logical position `start`.
 *
 * Tokens are visited in (t outer, h, w inner) order over the MERGED grid
 * (gh/mergeSize × gw/mergeSize per frame). Each gets t = start + frame,
 * h = start + row and w = start + col.
 *
 * @param rows {t, h, w} position rows, one entry per sequence index
 * @param at sequence index of the run's first image token
 * @param start logical position of the image origin
 * @param grid [t, h, w] pre-merge patch grid of this image
 * @param mergeSize spatial merge factor
 * @returns {{next: number, cursor: number}} `next` is the sequence index just past
 *   the run. `cursor` is the logical position for the following text: one past the
 *   largest position written, i.e. start + max(t, h_merged, w_merged).
 */
function fillImagePositions(rows, at, start, grid, mergeSize) {
	const [gt, gh, gw] = grid;
	const hm = gh / mergeSize;
	const wm = gw / mergeSize;
	let i = at;
	for (let tt = 0; tt < gt; tt++) for (let h = 0; h < hm; h++) for (let w = 0; w < wm; w++) {
		rows.t[i] = start + tt;
		rows.h[i] = start + h;
		rows.w[i] = start + w;
		i++;
	}
	// Text resumes at (max position used) + 1, as in the reference get_rope_index.
	// The temporal extent must be included so multi-frame grids (gt > 1) do not
	// hand the following text a position the image already occupies. For still
	// images (gt = 1) this equals start + max(hm, wm).
	return {
		next: i,
		cursor: start + Math.max(gt, hm, wm)
	};
}
/**
 * Build the 3D M-RoPE position ids [3, seq] for a text sequence that contains
 * `image_token_id` runs, matching Qwen3_5Model.get_rope_index.
 *
 * Text tokens advance all three (t,h,w) components together (standard 1D RoPE).
 * Each image run uses its vision grid (consumed from `imageGrids` in order):
 * temporal=start+frame, height=start+h_block, width=start+w_block, in row-major
 * (h outer, w inner) order. After the image, the cursor jumps to one past the
 * largest position used: max(t, h_merged, w_merged), which is
 * max(h_merged, w_merged) for a still image.
 *
 * @returns Int32Array of length 3*seq laid out as [T-row(seq), H-row(seq), W-row(seq)].
 * @throws {Error} when the sequence has more image-token runs than `imageGrids` entries.
 */
function buildMRoPEPositionIds(inputIds, imageGrids, imageTokenId, mergeSize) {
	const seq = inputIds.length;
	const rows = {
		t: new Int32Array(seq),
		h: new Int32Array(seq),
		w: new Int32Array(seq)
	};
	let i = 0;
	let cursor = 0;
	let imageIdx = 0;
	while (i < seq) {
		if (inputIds[i] === imageTokenId) {
			const grid = imageGrids[imageIdx];
			if (!grid) throw new Error(`buildMRoPEPositionIds: image-token run #${imageIdx + 1} at index ${i} has no matching entry in imageGrids (${imageGrids.length} provided)`);
			imageIdx++;
			const res = fillImagePositions(rows, i, cursor, grid, mergeSize);
			i = res.next;
			cursor = res.cursor;
		} else {
			rows.t[i] = cursor;
			rows.h[i] = cursor;
			rows.w[i] = cursor;
			cursor++;
			i++;
		}
	}
	const out = new Int32Array(3 * seq);
	out.set(rows.t, 0);
	out.set(rows.h, seq);
	out.set(rows.w, 2 * seq);
	return out;
}
|
|
2122
|
+
/**
 * Per-pair frequency→position-component assignment for interleaved M-RoPE,
 * matching Qwen3_5TextRotaryEmbedding.apply_interleaved_mrope.
 *
 * Every pair defaults to T (0). H (1) overwrites pairs 1, 4, 7, … below
 * 3·section[1], and W (2) overwrites pairs 2, 5, 8, … below 3·section[2],
 * mirroring the reference's strided slices `offset : 3·section[d] : 3`.
 *
 * @param mropeSection [t, h, w] section sizes.
 * @param size number of frequency pairs to describe. Defaults to
 *   sum(mropeSection); pass rope_dim/2 so every pair gets an entry even when the
 *   two differ.
 * @returns Int32Array of length `size` with values 0=T, 1=H, 2=W.
 */
function mropeFreqDims(mropeSection, size = mropeSection[0] + mropeSection[1] + mropeSection[2]) {
	const dims = new Int32Array(size);
	for (let d = 1; d <= 2; d++) {
		const end = Math.min(mropeSection[d] * 3, size);
		for (let idx = d; idx < end; idx += 3) dims[idx] = d;
	}
	return dims;
}
/**
 * Build the interleaved-M-RoPE cos/sin tables [seq, rope_dim] from 3D position
 * ids, matching Qwen3_5TextRotaryEmbedding.forward:
 * freqs[d][i] = pos[d] * inv_freq[i], inv_freq[i] = 1/theta^(2i/rope_dim)
 * freq[i] picks component mropeFreqDims[i]; emb = cat(freqs, freqs).
 * cos/sin have length seq*rope_dim. For text-only input (all 3 position rows
 * equal) this reduces exactly to standard 1D partial RoPE.
 *
 * Pairs beyond sum(mropeSection) use the T component, as in the reference.
 * Previously they indexed past the section table and produced NaN angles.
 *
 * @param positionIds3 [3, seq] as produced by buildMRoPEPositionIds.
 * @param seq sequence length.
 * @param ropeDim number of rotated dims per head (head_dim * partial_factor).
 * @param theta RoPE base.
 * @param mropeSection [t, h, w] section sizes.
 * @returns {{cos: Float32Array, sin: Float32Array}}
 */
function buildMRoPECosSin(positionIds3, seq, ropeDim, theta, mropeSection) {
	const half = ropeDim / 2;
	const freqDims = mropeFreqDims(mropeSection, half);
	const invFreq = new Float64Array(half);
	for (let i = 0; i < half; i++) invFreq[i] = 1 / theta ** (2 * i / ropeDim);
	const cos = new Float32Array(seq * ropeDim);
	const sin = new Float32Array(seq * ropeDim);
	for (let n = 0; n < seq; n++) {
		const base = n * ropeDim;
		for (let i = 0; i < half; i++) {
			const angle = positionIds3[freqDims[i] * seq + n] * invFreq[i];
			const c = Math.cos(angle);
			const s = Math.sin(angle);
			cos[base + i] = c;
			sin[base + i] = s;
			cos[base + half + i] = c;
			sin[base + half + i] = s;
		}
	}
	return {
		cos,
		sin
	};
}
|
|
2173
|
+
|
|
2174
|
+
//#endregion
|
|
2175
|
+
//#region src/browser/device-guards.ts
|
|
2176
|
+
const WEBKIT_GROUP_PROBE_KEY = "gerbil-webkit-group-v1";
/** Probe record used when nothing valid is persisted: group 1, no probe in flight, not capped. */
const DEFAULT_PROBE = {
	knownGood: 1,
	trying: null,
	capped: false
};
/**
 * Return the page's `localStorage`, or null when it is unavailable.
 *
 * `typeof localStorage` is deliberately evaluated inside the try. In some
 * browser contexts (storage blocked by user settings, sandboxed iframes),
 * merely reading the `localStorage` global throws a SecurityError. Outside a
 * try, that error would escape the probe helpers and abort engine creation.
 */
function getProbeStorage() {
	try {
		return typeof localStorage === "undefined" ? null : localStorage;
	} catch {
		return null;
	}
}
/** Read the persisted WebKit group probe record (guarded; safe on node and when storage is blocked). */
function readGroupProbe() {
	const storage = getProbeStorage();
	if (!storage) return { ...DEFAULT_PROBE };
	try {
		const raw = storage.getItem(WEBKIT_GROUP_PROBE_KEY);
		if (!raw) return { ...DEFAULT_PROBE };
		const parsed = JSON.parse(raw);
		// Validate each field independently; anything malformed falls back to its default.
		return {
			knownGood: typeof parsed.knownGood === "number" && parsed.knownGood >= 1 ? parsed.knownGood : 1,
			trying: typeof parsed.trying === "number" ? parsed.trying : null,
			capped: parsed.capped === true
		};
	} catch {
		return { ...DEFAULT_PROBE };
	}
}
/** Persist the WebKit group probe record (guarded; no-op on node or when storage is blocked). */
function writeGroupProbe(rec) {
	const storage = getProbeStorage();
	if (!storage) return;
	try {
		storage.setItem(WEBKIT_GROUP_PROBE_KEY, JSON.stringify(rec));
	} catch {
		// Best effort (e.g. quota exceeded): the probe record simply does not persist.
	}
}
/**
 * The validated non-phone sweet spot. iPad swept 1→7.9, 8→19, 32→24.8, 64→26.6,
 * 128→26.9 (peak), 256→26.2 tok/s — a plateau from ~64 up, so 128 is the best
 * stable target (more batching just costs memory). Non-phone WebKit jumps here
 * directly; the crash breadcrumb caps it down if a device can't sustain it.
 */
const NONPHONE_TARGET_GROUP = 128;
/**
 * Resolve the WebKit group size to use this session. As a side effect, record
 * `trying` so that a crash during this load is detectable on the next load.
 *
 * Decision order:
 * 1. `args.override > 0` → use it as-is.
 * 2. Not WebKit → 1 (the probe is inert elsewhere).
 * 3. Read the record (default {knownGood:1, trying:null, capped:false}).
 * 4. If `trying !== null` on entry → the previous load set it but never cleared
 *    it, which is treated as a crash at `trying`. Persist capped=true, clear
 *    `trying`, and use `knownGood` this session.
 * 5. Else if not capped, not `args.conservative`, and knownGood <
 *    NONPHONE_TARGET_GROUP → persist `trying = NONPHONE_TARGET_GROUP` BEFORE any
 *    GPU work and use it. This jumps straight to the target; there is no
 *    intermediate rung.
 * 6. Else → use `knownGood`.
 *
 * @param args {override?: number, isWebKit: boolean, conservative?: boolean}
 * @returns the group size to use this session.
 */
function resolveWebkitGroupSize(args) {
	if (args.override && args.override > 0) return args.override;
	if (!args.isWebKit) return 1;
	const rec = readGroupProbe();
	if (rec.trying !== null) {
		const crashedAt = rec.trying;
		const next = {
			knownGood: rec.knownGood,
			trying: null,
			capped: true
		};
		writeGroupProbe(next);
		console.log(`[engine] webkit group probe: previous load crashed at group=${crashedAt} → capping at knownGood=${next.knownGood}`);
		return next.knownGood;
	}
	if (!rec.capped && !args.conservative && rec.knownGood < NONPHONE_TARGET_GROUP) {
		writeGroupProbe({
			knownGood: rec.knownGood,
			trying: NONPHONE_TARGET_GROUP,
			capped: false
		});
		console.log(`[engine] webkit group probe: trying target group=${NONPHONE_TARGET_GROUP} (knownGood=${rec.knownGood})`);
		return NONPHONE_TARGET_GROUP;
	}
	console.log(`[engine] webkit group probe: knownGood=${rec.knownGood} trying=null capped=${rec.capped} → using ${rec.knownGood}`);
	return rec.knownGood;
}
/**
 * Promote (or cap) the WebKit group probe after the first successful forward.
 *
 * Call this once per page-load, after the model has loaded AND a first forward
 * completed without the page dying. The breadcrumb already handles the crash
 * class (the page death leaves `trying` set for the next load); this handles the
 * wrong-output class and records success. No-op when storage is unavailable or
 * no probe is in flight.
 *
 * @param correct true if the first forward produced non-corrupt output.
 *   - correct → promote: knownGood = trying, trying = null.
 *   - incorrect → cap: keep knownGood at the prior rung, trying = null, capped.
 */
function promoteGroupProbe(correct) {
	if (!getProbeStorage()) return;
	const rec = readGroupProbe();
	if (rec.trying === null) return;
	if (correct) {
		writeGroupProbe({
			knownGood: rec.trying,
			trying: null,
			capped: rec.capped
		});
		console.log(`[engine] webkit group probe: PROMOTED group=${rec.trying} to known-good`);
	} else {
		writeGroupProbe({
			knownGood: rec.knownGood,
			trying: null,
			capped: true
		});
		console.log(`[engine] webkit group probe: group=${rec.trying} produced INCORRECT output → capping at knownGood=${rec.knownGood}`);
	}
}
|
|
2286
|
+
|
|
2287
|
+
//#endregion
|
|
2288
|
+
//#region src/gpu/index.ts
|
|
2289
|
+
/**
|
|
2290
|
+
* WebGPUEngine -- gerbil's native WebGPU inference engine.
|
|
2291
|
+
*
|
|
2292
|
+
* Public API:
|
|
2293
|
+
* const engine = await WebGPUEngine.create({ repo: "Qwen/Qwen3.5-0.8B" });
|
|
2294
|
+
* const result = await engine.generate("What is 2+2?");
|
|
2295
|
+
* engine.destroy();
|
|
2296
|
+
*/
|
|
2297
|
+
/**
 * Task prefixes prepended to EmbeddingGemma inputs, as listed under `prompts` in
 * the model's config_sentence_transformers.json: one for search queries and one
 * for documents.
 */
const EMBEDDING_GEMMA_PROMPTS = { query: "task: search result | query: ", document: "title: none | text: " };
|
|
2302
|
+
/**
 * Pull the first complete JSON object or array out of a model's text output.
 *
 * ```json / ``` fence markers are removed first. The scan then starts at the
 * first `{` or `[` and walks forward. It counts only that opening bracket type,
 * and ignores brackets inside string literals (honouring backslash escapes),
 * until the depth returns to zero. The balanced substring is returned. The
 * result is `undefined` when no opener exists or the span never closes.
 * Used by {@link WebGPUEngine.generateObject}.
 */
function extractJson(text) {
	const body = text.replace(/```(?:json)?/gi, "");
	const begin = body.search(/[{[]/);
	if (begin < 0) return undefined;
	const opener = body[begin];
	const closer = opener === "[" ? "]" : "}";
	let nesting = 0;
	let quoted = false;
	let afterBackslash = false;
	for (let pos = begin; pos < body.length; pos++) {
		const c = body[pos];
		if (quoted) {
			if (afterBackslash) afterBackslash = false;
			else if (c === "\\") afterBackslash = true;
			else if (c === "\"") quoted = false;
		} else if (c === "\"") {
			quoted = true;
		} else if (c === opener) {
			nesting++;
		} else if (c === closer) {
			nesting--;
			if (nesting === 0) return body.slice(begin, pos + 1);
		}
	}
	return undefined;
}
|
|
2336
|
+
/**
|
|
2337
|
+
* The main WebGPU inference engine (the `WebGPUEngine` class defined further below).
|
|
2338
|
+
*
|
|
2339
|
+
* Usage:
|
|
2340
|
+
* const engine = await WebGPUEngine.create({ repo: "Qwen/Qwen3.5-0.8B" });
|
|
2341
|
+
* const result = await engine.generate("Hello!");
|
|
2342
|
+
* console.log(result.text);
|
|
2343
|
+
* engine.destroy();
|
|
2344
|
+
*/
|
|
2345
|
+
/**
 * System prompt that pins the model to inline "continue the text" autocomplete
 * behaviour: continuation only, no assistant voice, with a one-line example.
 * The sentences are joined with single spaces.
 */
const AUTOCOMPLETE_SYSTEM = [
	"You are an inline autocomplete engine.",
	"Continue the user's text with a brief, natural continuation of the SAME sentence or thought.",
	"Output ONLY the continuation text — no preamble, no quotes, no explanations, no assistant voice.",
	"Do not answer questions; just continue the writing.",
	'Example — input: "The quick brown fox" → continuation: " jumps over the lazy dog."'
].join(" ");
|
|
2353
|
+
/**
 * Turn raw model output into a clean inline continuation.
 *
 * - When `singleLine`, cut at the first line break (`\n` or `\r\n`, so no stray
 *   `\r` survives).
 * - Strip wrapping straight and curly quotes (" ' “ ” ‘ ’). The previous
 *   character class listed `'` twice and never matched the curly single quotes.
 * - Drop an echoed copy of the typed text and leading whitespace.
 * - Prefix a single space unless the suggestion starts with punctuation or the
 *   typed text is empty or already ends with whitespace.
 *
 * @param raw model output
 * @param typed text the user has typed so far
 * @param singleLine keep only the first line of `raw`
 * @returns the continuation to display, or "" when nothing remains.
 */
function normalizeContinuation(raw, typed, singleLine) {
	let s = singleLine ? raw.replace(/\r?\n[\s\S]*$/, "") : raw;
	s = s.replace(/^["'“”‘’]+/, "").replace(/["'“”‘’]+$/, "");
	if (s.startsWith(typed)) s = s.slice(typed.length);
	s = s.replace(/^\s+/, "");
	if (!s) return "";
	const startsWithPunct = /^[.,;:!?)\]}'"”’%]/.test(s);
	const typedEndsWithSpace = /\s$/.test(typed) || typed.length === 0;
	return startsWithPunct || typedEndsWithSpace ? s : ` ${s}`;
}
|
|
2369
|
+
/**
 * Build the tool-use system prompt. It contains a "# Tools" section with each
 * tool's name, description and JSON parameter schema (an empty object schema
 * when `parameters` is missing), followed by the `<tool_call>` calling
 * convention and an instruction to answer normally once done.
 *
 * @param tools array of {name, description, parameters?}
 * @returns the prompt string (lines separated by "\n").
 */
function formatAgentToolsPrompt(tools) {
	const describeTool = (t) => {
		const schema = t.parameters ?? {
			type: "object",
			properties: {}
		};
		return `## ${t.name}\nDescription: ${t.description}\nParameters: ${JSON.stringify(schema)}`;
	};
	return [
		"You are a helpful assistant with access to tools.",
		"",
		"# Tools",
		"",
		tools.map(describeTool).join("\n\n"),
		"",
		"## How to call a tool",
		"",
		"Reply with ONLY:",
		"<tool_call>",
		"{\"name\": \"tool_name\", \"arguments\": {\"param\": \"value\"}}",
		"</tool_call>",
		"",
		"When you have the final answer, reply normally with no tool_call."
	].join("\n");
}
|
|
2388
|
+
/**
 * Parse a single tool call from model output.
 *
 * Prefers the body of a `<tool_call>…</tool_call>` block. Otherwise it uses the
 * first `{"name": …}` object in the text. The span is parsed first exactly as
 * before (tag body as-is; untagged text from `{"name"` to the LAST `}`). If that
 * fails, progressively shorter spans ending at earlier `}` characters are tried.
 * This recovers calls wrapped in ```json fences inside the tag, and untagged
 * calls followed by prose that contains braces. Previously the greedy match
 * swallowed that prose and the parse failed.
 *
 * @param text raw model output
 * @returns {{name: string, args: object}|null} the call, or null when none parses.
 */
function parseAgentToolCall(text) {
	// Parse one candidate span; malformed JSON is an expected "no call" outcome.
	const toCall = (json) => {
		try {
			const parsed = JSON.parse(json);
			if (typeof parsed.name === "string") return {
				name: parsed.name,
				args: parsed.arguments ?? parsed.args ?? {}
			};
		} catch {}
		return null;
	};
	// Try spans s[from .. each `}`], longest first; the first that parses wins.
	const recover = (s, from) => {
		for (let end = s.lastIndexOf("}"); end > from; end = s.lastIndexOf("}", end - 1)) {
			const call = toCall(s.slice(from, end + 1));
			if (call) return call;
		}
		return null;
	};
	const tagged = text.match(/<tool_call>\s*([\s\S]*?)\s*<\/tool_call>/);
	if (tagged) {
		const body = tagged[1];
		const direct = toCall(body);
		if (direct) return direct;
		const open = body.indexOf("{");
		return open === -1 ? null : recover(body, open);
	}
	const at = text.search(/\{\s*"name"\s*:/);
	return at === -1 ? null : recover(text, at);
}
|
|
2401
|
+
var WebGPUEngine = class WebGPUEngine {
|
|
2402
|
+
/** GPU context handed to the constructor (presumably the initGPU() result from create() — confirm). */
ctx;
/** LM executor; backs decode profiling and stats (getProfile, resetProfile, profileDecodeStep, decodeDispatchCount, maxStorageBuffers). */
executor;
/** Tokenizer handed to the constructor (presumably from loadModel() — confirm). */
tokenizer;
/** Lifecycle flag, initially false (NOTE(review): presumably set by destroy(), which is outside this chunk). */
_destroyed = false;
/** True when the model graph's outputs include "embedding" (set in the constructor). */
_isEmbedding;
/** HF architecture string (e.g. "Gemma3TextModel", "Qwen3ForCausalLM"). */
_architecture;
/** Vision encoder (built only when enableVision and the model is vision-capable). */
visionExecutor;
/** Raw vision_config (for host preprocessing of grids). */
visionConfig;
/** Raw pos_embed.weight table for bilinear interpolation. */
visionPosEmbedTable;
/** True when the LM graph was built with the multimodal (M-RoPE + splice) path. */
_multimodalGraph;
/** Raw config.json (for M-RoPE params: mrope_section, rope_theta, partial factor). */
rawConfig;
/** Effective max sequence length (cos/sin table coverage). */
maxSeqLen;
/** Original create() options (used to lazily spin up the Kani-TTS engine for speak()). */
_createOptions;
/** Lazily-created Kani-TTS engine (codec-LM + NanoCodec) backing speak(). */
_kaniTTS = null;
/**
* WebKit group-size probe state. When true, a candidate group size is being
* tried this page-load and must be promoted (or capped) after the FIRST
* successful forward produces non-corrupt logits. Goes false once handled so
* promotion runs at most once per session. Always false on Dawn/node.
*/
_groupProbePending = false;
/** Model capabilities (text, vision, moe). */
capabilities;
/** Model architecture config. */
config;
|
|
2436
|
+
constructor(ctx, executor, tokenizer, graph, opts, vision) {
|
|
2437
|
+
this.ctx = ctx;
|
|
2438
|
+
this.executor = executor;
|
|
2439
|
+
this.tokenizer = tokenizer;
|
|
2440
|
+
this.capabilities = graph.capabilities;
|
|
2441
|
+
this.config = graph.config;
|
|
2442
|
+
this._isEmbedding = graph.outputs.includes("embedding");
|
|
2443
|
+
this._architecture = graph.architecture;
|
|
2444
|
+
this.visionExecutor = vision?.executor ?? null;
|
|
2445
|
+
this.visionConfig = vision?.config ?? null;
|
|
2446
|
+
this.visionPosEmbedTable = vision?.posEmbedTable ?? null;
|
|
2447
|
+
this._multimodalGraph = opts.multimodalGraph;
|
|
2448
|
+
this.rawConfig = opts.rawConfig;
|
|
2449
|
+
this.maxSeqLen = opts.maxSeqLen;
|
|
2450
|
+
this._createOptions = opts.createOptions;
|
|
2451
|
+
this._groupProbePending = opts.groupProbePending ?? false;
|
|
2452
|
+
}
|
|
2453
|
+
/** True if this engine has a vision encoder built (use encodeImage()). */
|
|
2454
|
+
get hasVision() {
|
|
2455
|
+
return this.visionExecutor !== null;
|
|
2456
|
+
}
|
|
2457
|
+
/** Per-opType decode GPU-time breakdown (only populated under GERBIL_PROFILE). */
|
|
2458
|
+
getDecodeProfile() {
|
|
2459
|
+
return this.executor.getProfile();
|
|
2460
|
+
}
|
|
2461
|
+
/** Clear accumulated decode profiler data (e.g. to drop warm-up tokens). */
|
|
2462
|
+
resetDecodeProfile() {
|
|
2463
|
+
this.executor.resetProfile();
|
|
2464
|
+
}
|
|
2465
|
+
/** Profile ONE real decode step (the pipelined-greedy kernels). Token-independent
|
|
2466
|
+
* timing — pass any valid id. Only meaningful under GERBIL_PROFILE. */
|
|
2467
|
+
async profileDecodeStep(tokenId) {
|
|
2468
|
+
await this.executor.profileDecodeStep(tokenId);
|
|
2469
|
+
}
|
|
2470
|
+
/** Decode dispatch count + the device's storage-buffer limit (which gates the
|
|
2471
|
+
* INT4 projection fusions). Lets the iPad runner report whether fusions applied
|
|
2472
|
+
* on-device or silently fell back (8 < 9 ⇒ more dispatches ⇒ more mobile drains). */
|
|
2473
|
+
getDecodeStats() {
|
|
2474
|
+
return {
|
|
2475
|
+
dispatches: this.executor.decodeDispatchCount,
|
|
2476
|
+
maxStorageBuffers: this.executor.maxStorageBuffers
|
|
2477
|
+
};
|
|
2478
|
+
}
|
|
2479
|
+
/**
|
|
2480
|
+
* Write a coarse crash-phase breadcrumb that survives a GPU-process kill / page
|
|
2481
|
+
* reload. The iPad harness reads `localStorage["gerbil-crash-phase"]` after a
|
|
2482
|
+
* crash; without these, a describe-time crash only shows the last load phase
|
|
2483
|
+
* ("engine:ready"). The describe path tags vit-encode / splice / text-decode so
|
|
2484
|
+
* the next run shows WHERE it died, not just "crashed after load".
|
|
2485
|
+
*/
|
|
2486
|
+
setPhase(phase) {
|
|
2487
|
+
try {
|
|
2488
|
+
if (typeof localStorage !== "undefined") localStorage.setItem("gerbil-crash-phase", phase);
|
|
2489
|
+
} catch {}
|
|
2490
|
+
}
|
|
2491
|
+
/** True if this engine was loaded as an embedding model (use embed(), not generate()). */
|
|
2492
|
+
get isEmbedding() {
|
|
2493
|
+
return this._isEmbedding;
|
|
2494
|
+
}
|
|
2495
|
+
/**
|
|
2496
|
+
* WebKit group-size probe promotion hook. Runs at most once per session, after
|
|
2497
|
+
* the FIRST forward completes without the page dying. If the page had crashed
|
|
2498
|
+
* at this group size, this code never runs and the localStorage breadcrumb
|
|
2499
|
+
* (left by the resolver) caps the device on the next load — that is what makes
|
|
2500
|
+
* the probe survive the crash class. Here we additionally handle the
|
|
2501
|
+
* wrong-output class by inspecting the first forward's logits for corruption
|
|
2502
|
+
* (NaN / Inf / all-zero / all-same), reusing the same signals as integrityCheck().
|
|
2503
|
+
*/
|
|
2504
|
+
maybePromoteGroupProbe(logits) {
|
|
2505
|
+
if (!this._groupProbePending) return;
|
|
2506
|
+
this._groupProbePending = false;
|
|
2507
|
+
const n = Math.min(logits.length, 256);
|
|
2508
|
+
let allZero = true;
|
|
2509
|
+
let allSame = true;
|
|
2510
|
+
let finite = true;
|
|
2511
|
+
const first = logits[0];
|
|
2512
|
+
for (let i = 0; i < n; i++) {
|
|
2513
|
+
const v = logits[i];
|
|
2514
|
+
if (!Number.isFinite(v)) {
|
|
2515
|
+
finite = false;
|
|
2516
|
+
break;
|
|
2517
|
+
}
|
|
2518
|
+
if (v !== 0) allZero = false;
|
|
2519
|
+
if (v !== first) allSame = false;
|
|
2520
|
+
}
|
|
2521
|
+
promoteGroupProbe(finite && !allZero && !allSame);
|
|
2522
|
+
}
|
|
2523
|
+
/**
|
|
2524
|
+
* Create and initialize a WebGPUEngine.
|
|
2525
|
+
*
|
|
2526
|
+
* Downloads the model from HuggingFace, compiles shaders, uploads weights.
|
|
2527
|
+
*/
|
|
2528
|
+
static async create(options = {}) {
|
|
2529
|
+
options = {
|
|
2530
|
+
...options,
|
|
2531
|
+
repo: resolveDefaultRepo(options)
|
|
2532
|
+
};
|
|
2533
|
+
const ctx = await initGPU();
|
|
2534
|
+
const isBrowser = typeof navigator !== "undefined" && typeof location !== "undefined";
|
|
2535
|
+
const isSafari = isBrowser && /Safari/.test(navigator.userAgent) && !/Chrome/.test(navigator.userAgent);
|
|
2536
|
+
const params = isBrowser ? new URLSearchParams(location.search) : null;
|
|
2537
|
+
const forceKvF32 = params?.has("kvf32");
|
|
2538
|
+
const maxSeqOverride = params?.get("maxseq");
|
|
2539
|
+
const groupOverride = params?.get("group");
|
|
2540
|
+
let kvMode;
|
|
2541
|
+
if (options.kvMode) kvMode = options.kvMode;
|
|
2542
|
+
else if (forceKvF32) kvMode = "f32";
|
|
2543
|
+
else if (isSafari && ctx.hasF16) kvMode = "packed-f16";
|
|
2544
|
+
else if (ctx.hasF16) kvMode = "native-f16";
|
|
2545
|
+
else kvMode = "f32";
|
|
2546
|
+
const kvDtype = kvMode === "f32" ? "f32" : "f16";
|
|
2547
|
+
console.log(`[engine] kvMode: ${kvMode}, kvDtype: ${kvDtype}, f16 supported: ${ctx.hasF16}, safari: ${isSafari}`);
|
|
2548
|
+
const setPhase = (phase) => {
|
|
2549
|
+
try {
|
|
2550
|
+
if (typeof localStorage !== "undefined") localStorage.setItem("gerbil-crash-phase", phase);
|
|
2551
|
+
} catch {}
|
|
2552
|
+
};
|
|
2553
|
+
setPhase("engine:loading-model");
|
|
2554
|
+
const maxVisionPatches = options.maxVisionPatches ?? (ctx.isWebKitWebGPU ? 1024 : 4096);
|
|
2555
|
+
const multimodal = options.enableVision ? { maxVisionTokens: Math.ceil(maxVisionPatches / 4) } : void 0;
|
|
2556
|
+
const { graph, tokenizer, weights, rawConfig, pleSource } = await loadModel({
|
|
2557
|
+
...options,
|
|
2558
|
+
repo: resolveDefaultRepo(options),
|
|
2559
|
+
kvDtype,
|
|
2560
|
+
multimodal
|
|
2561
|
+
});
|
|
2562
|
+
setPhase(`engine:model-loaded:${weights.size}-weights`);
|
|
2563
|
+
let visionBundle = null;
|
|
2564
|
+
const hasVisionConfig = rawConfig.vision_config != null;
|
|
2565
|
+
const isGemma4Vision = rawConfig.vision_config?.model_type === "gemma4_vision" || [...weights.keys()].some((k) => k.startsWith("vision_tower."));
|
|
2566
|
+
const visPosKey = isGemma4Vision ? "vision_tower.patch_embedder.position_embedding_table" : "visual.pos_embed.weight";
|
|
2567
|
+
const towerPrefix = isGemma4Vision ? "vision_tower." : "visual.";
|
|
2568
|
+
const hasVisualWeights = weights.keys().some((k) => k.startsWith(towerPrefix));
|
|
2569
|
+
if (Boolean(options.enableVision && hasVisionConfig) && hasVisualWeights) {
|
|
2570
|
+
const visGraph = isGemma4Vision ? generateGemma4VisionGraph(rawConfig) : generateQwen3_5VisionGraph(rawConfig);
|
|
2571
|
+
const maxPatches = maxVisionPatches;
|
|
2572
|
+
const visWeights = /* @__PURE__ */ new Map();
|
|
2573
|
+
for (const k of weights.keys()) if (k.startsWith(towerPrefix) || k.startsWith("embed_vision.") || k === visPosKey) {
|
|
2574
|
+
const w = await weights.get(k);
|
|
2575
|
+
if (w) visWeights.set(k, w);
|
|
2576
|
+
}
|
|
2577
|
+
if (isGemma4Vision) {
|
|
2578
|
+
patchGemma4VisionClips(visGraph, visWeights);
|
|
2579
|
+
const gi = resolveGemma4VisionInfo(rawConfig);
|
|
2580
|
+
dequantizeGemma4VisionProjection(visWeights, (rawConfig.quantization_config ?? rawConfig.quantization)?.group_size ?? 64, gi.textHidden, gi.hiddenSize);
|
|
2581
|
+
ensureGemma4VisionEmbedderNorms(visWeights, gi.hiddenSize, gi.textHidden);
|
|
2582
|
+
}
|
|
2583
|
+
const visExec = new VisionExecutor(ctx, visGraph, maxPatches);
|
|
2584
|
+
const posW = visWeights.get(visPosKey);
|
|
2585
|
+
const posTable = posW && posW.data instanceof Float32Array ? new Float32Array(posW.data) : posW ? new Float32Array(posW.data.buffer, posW.data.byteOffset, posW.data.byteLength / 4).slice() : null;
|
|
2586
|
+
await visExec.uploadWeights(visWeights);
|
|
2587
|
+
visExec.initBindGroups();
|
|
2588
|
+
if (posTable) visionBundle = {
|
|
2589
|
+
executor: visExec,
|
|
2590
|
+
config: rawConfig.vision_config ?? rawConfig,
|
|
2591
|
+
posEmbedTable: posTable
|
|
2592
|
+
};
|
|
2593
|
+
}
|
|
2594
|
+
let maxSeqLen;
|
|
2595
|
+
if (maxSeqOverride) maxSeqLen = Math.min(Number.parseInt(maxSeqOverride, 10), graph.config.context_length);
|
|
2596
|
+
else if (forceKvF32 && isSafari) maxSeqLen = Math.min(options.maxSeqLen ?? 1024, graph.config.context_length, 1024);
|
|
2597
|
+
else if (ctx.isWebKitWebGPU) maxSeqLen = Math.min(options.maxSeqLen ?? 512, graph.config.context_length, 2048);
|
|
2598
|
+
else maxSeqLen = Math.min(options.maxSeqLen ?? graph.config.context_length, graph.config.context_length, 4096);
|
|
2599
|
+
console.log(`[engine] maxSeqLen: ${maxSeqLen}, architecture: ${graph.architecture}`);
|
|
2600
|
+
if (ctx.isWebKitWebGPU) console.log(`[engine] device limits: maxBufferSize=${ctx.limits.maxBufferSize}, maxStorageBufferBindingSize=${ctx.limits.maxStorageBufferBindingSize}, maxComputeWorkgroupStorageSize=${ctx.limits.maxComputeWorkgroupStorageSize}`);
|
|
2601
|
+
setPhase("engine:allocating-buffers");
|
|
2602
|
+
ctx.device.pushErrorScope("out-of-memory");
|
|
2603
|
+
ctx.device.pushErrorScope("validation");
|
|
2604
|
+
const groupOverrideNum = groupOverride ? Number.parseInt(groupOverride, 10) : void 0;
|
|
2605
|
+
let webkitGroupSize;
|
|
2606
|
+
let probingGroup = false;
|
|
2607
|
+
if (groupOverrideNum) {
|
|
2608
|
+
webkitGroupSize = groupOverrideNum;
|
|
2609
|
+
console.log(`[engine] webkitGroupSize override: ${webkitGroupSize} dispatches/command buffer`);
|
|
2610
|
+
} else if (ctx.isWebKitWebGPU) {
|
|
2611
|
+
webkitGroupSize = resolveWebkitGroupSize({ isWebKit: true });
|
|
2612
|
+
probingGroup = true;
|
|
2613
|
+
}
|
|
2614
|
+
const executor = new Executor(ctx, graph, {
|
|
2615
|
+
maxSeqLen,
|
|
2616
|
+
kvMode,
|
|
2617
|
+
webkitGroupSize
|
|
2618
|
+
});
|
|
2619
|
+
if (pleSource) executor.setPleSource(pleSource);
|
|
2620
|
+
setPhase("engine:uploading-weights");
|
|
2621
|
+
await executor.uploadWeights(weights);
|
|
2622
|
+
await weights.dispose?.();
|
|
2623
|
+
setPhase("engine:compiling-shaders");
|
|
2624
|
+
executor.initBindGroups();
|
|
2625
|
+
const validationError = await ctx.device.popErrorScope();
|
|
2626
|
+
const oomError = await ctx.device.popErrorScope();
|
|
2627
|
+
if (oomError || validationError) {
|
|
2628
|
+
const detail = [oomError ? `out-of-memory: ${oomError.message}` : null, validationError ? `validation: ${validationError.message}` : null].filter(Boolean).join("; ");
|
|
2629
|
+
setPhase(`engine:gpu-error:${detail.slice(0, 120)}`);
|
|
2630
|
+
throw new Error(`GPU setup failed (${detail}). The model likely exceeds this device's memory budget — try a smaller maxSeqLen or a q4-quantized model.`);
|
|
2631
|
+
}
|
|
2632
|
+
setPhase("engine:ready");
|
|
2633
|
+
return new WebGPUEngine(ctx, executor, tokenizer, graph, {
|
|
2634
|
+
multimodalGraph: Boolean(multimodal),
|
|
2635
|
+
rawConfig,
|
|
2636
|
+
maxSeqLen,
|
|
2637
|
+
createOptions: options,
|
|
2638
|
+
groupProbePending: probingGroup
|
|
2639
|
+
}, visionBundle);
|
|
2640
|
+
}
|
|
2641
|
+
/**
|
|
2642
|
+
* Encode an image (already preprocessed into patches) into merged
|
|
2643
|
+
* image-embedding tokens of dim `out_hidden_size` (1024 for Qwen3.5).
|
|
2644
|
+
*
|
|
2645
|
+
* This is the VISION ENCODER ONLY — it returns the image tokens; it does not
|
|
2646
|
+
* splice them into a text sequence or apply M-RoPE (that is the LM-side
|
|
2647
|
+
* integration phase). Requires `enableVision: true` at create() on a
|
|
2648
|
+
* vision-capable checkpoint.
|
|
2649
|
+
*
|
|
2650
|
+
* @param patches Flattened patches, row-major [numPatches, patch_dim].
|
|
2651
|
+
* patch_dim = in_channels * temporal_patch_size * patch_size^2 (1536 for Qwen3.5).
|
|
2652
|
+
* Patches must already be ordered in spatial_merge_size×spatial_merge_size
|
|
2653
|
+
* groups (as the HF image processor emits them).
|
|
2654
|
+
* @param gridTHW The (temporal, height, width) patch-grid dims for the image.
|
|
2655
|
+
* numPatches must equal t*h*w.
|
|
2656
|
+
*/
|
|
2657
|
+
async encodeImage(patches, gridTHW, onStage) {
|
|
2658
|
+
this.checkDestroyed();
|
|
2659
|
+
if (!this.visionExecutor || !this.visionConfig || !this.visionPosEmbedTable) throw new Error("encodeImage() requires a vision encoder. Load with { enableVision: true } on a vision-capable checkpoint (e.g. Qwen/Qwen3.5-0.8B).");
|
|
2660
|
+
const vcfg = this.visionConfig;
|
|
2661
|
+
if (this.visionExecutor.gemma4) {
|
|
2662
|
+
const info = resolveGemma4VisionInfo({ vision_config: vcfg });
|
|
2663
|
+
const gridH = gridTHW[1];
|
|
2664
|
+
const gridW = gridTHW[2];
|
|
2665
|
+
const posSize = Math.floor(this.visionPosEmbedTable.length / (2 * info.hiddenSize));
|
|
2666
|
+
const host$1 = buildGemma4VisionPositionTensors(gridH, gridW, this.visionPosEmbedTable, posSize, {
|
|
2667
|
+
hiddenSize: info.hiddenSize,
|
|
2668
|
+
numHeads: info.numHeads,
|
|
2669
|
+
headDim: info.headDim,
|
|
2670
|
+
ropeTheta: info.ropeTheta,
|
|
2671
|
+
poolingKernelSize: info.poolingKernelSize
|
|
2672
|
+
});
|
|
2673
|
+
return this.visionExecutor.encodeGemma4({
|
|
2674
|
+
patches,
|
|
2675
|
+
posEmbeds: host$1.posEmbeds,
|
|
2676
|
+
cos: host$1.cos,
|
|
2677
|
+
sin: host$1.sin,
|
|
2678
|
+
poolMatrix: host$1.poolMatrix,
|
|
2679
|
+
numPatches: host$1.numPatches,
|
|
2680
|
+
numPooled: host$1.numPooled
|
|
2681
|
+
}, onStage);
|
|
2682
|
+
}
|
|
2683
|
+
const hiddenSize = vcfg.hidden_size;
|
|
2684
|
+
const numHeads = vcfg.num_heads;
|
|
2685
|
+
const host = buildVisionPositionTensors(gridTHW, this.visionPosEmbedTable, {
|
|
2686
|
+
hiddenSize,
|
|
2687
|
+
numHeads,
|
|
2688
|
+
numPositionEmbeddings: vcfg.num_position_embeddings,
|
|
2689
|
+
spatialMergeSize: vcfg.spatial_merge_size,
|
|
2690
|
+
ropeTheta: 1e4
|
|
2691
|
+
});
|
|
2692
|
+
const numPatches = gridTHW[0] * gridTHW[1] * gridTHW[2];
|
|
2693
|
+
return this.visionExecutor.encode({
|
|
2694
|
+
patches,
|
|
2695
|
+
posEmbeds: host.posEmbeds,
|
|
2696
|
+
cos: host.cos,
|
|
2697
|
+
sin: host.sin,
|
|
2698
|
+
numPatches
|
|
2699
|
+
}, onStage);
|
|
2700
|
+
}
|
|
2701
|
+
/** Resolve M-RoPE params from rawConfig: rope_dim, theta, mrope_section. */
|
|
2702
|
+
mropeParams() {
|
|
2703
|
+
const cfg = this.rawConfig ?? {};
|
|
2704
|
+
const rope = (cfg.text_config ?? cfg).rope_parameters ?? {};
|
|
2705
|
+
const headDim = this.config.head_dim;
|
|
2706
|
+
const partial = rope.partial_rotary_factor ?? .25;
|
|
2707
|
+
const ropeDim = Math.floor(headDim * partial);
|
|
2708
|
+
const theta = rope.rope_theta ?? this.config.rope_base ?? 1e4;
|
|
2709
|
+
const section = rope.mrope_section ?? [
|
|
2710
|
+
11,
|
|
2711
|
+
11,
|
|
2712
|
+
10
|
|
2713
|
+
];
|
|
2714
|
+
const visCfg = cfg.vision_config ?? {};
|
|
2715
|
+
return {
|
|
2716
|
+
ropeDim,
|
|
2717
|
+
theta,
|
|
2718
|
+
section,
|
|
2719
|
+
imageTokenId: cfg.image_token_id ?? 248056,
|
|
2720
|
+
mergeSize: visCfg.spatial_merge_size ?? 2
|
|
2721
|
+
};
|
|
2722
|
+
}
|
|
2723
|
+
/**
|
|
2724
|
+
* Write the M-RoPE cos/sin (token order) + image row-map for a prefill of
|
|
2725
|
+
* `positionIds3` ([3, seq]). `rowMap[i]` = vision-buffer row for image tokens,
|
|
2726
|
+
* -1 for text. Returns the logical position of the last token (for decode).
|
|
2727
|
+
*/
|
|
2728
|
+
writeMRoPEPrefill(positionIds3, seq, rowMap) {
|
|
2729
|
+
const { ropeDim, theta, section } = this.mropeParams();
|
|
2730
|
+
const { cos, sin } = buildMRoPECosSin(positionIds3, seq, ropeDim, theta, section);
|
|
2731
|
+
this.executor.writeInput("mrope_cos", cos);
|
|
2732
|
+
this.executor.writeInput("mrope_sin", sin);
|
|
2733
|
+
this.executor.writeInput("vision_row_map", rowMap);
|
|
2734
|
+
const last = seq - 1;
|
|
2735
|
+
return Math.max(positionIds3[last], positionIds3[seq + last], positionIds3[2 * seq + last]);
|
|
2736
|
+
}
|
|
2737
|
+
/**
|
|
2738
|
+
* Write a single decode-step M-RoPE cos/sin row at table slot `seqPos` for a
|
|
2739
|
+
* text token at logical position `logicalPos`, plus a -1 row-map entry.
|
|
2740
|
+
*/
|
|
2741
|
+
writeMRoPEDecodeStep(seqPos, logicalPos) {
|
|
2742
|
+
const { ropeDim, theta, section } = this.mropeParams();
|
|
2743
|
+
const pid = new Int32Array(3);
|
|
2744
|
+
pid[0] = logicalPos;
|
|
2745
|
+
pid[1] = logicalPos;
|
|
2746
|
+
pid[2] = logicalPos;
|
|
2747
|
+
const { cos, sin } = buildMRoPECosSin(pid, 1, ropeDim, theta, section);
|
|
2748
|
+
const rowBytes = ropeDim * 4;
|
|
2749
|
+
this.executor.writeInputAt("mrope_cos", cos, seqPos * rowBytes);
|
|
2750
|
+
this.executor.writeInputAt("mrope_sin", sin, seqPos * rowBytes);
|
|
2751
|
+
}
|
|
2752
|
+
/** Write linear-position M-RoPE inputs for a pure-text forward (no image). */
|
|
2753
|
+
writeMRoPELinearText(seq) {
|
|
2754
|
+
const { ropeDim, theta, section } = this.mropeParams();
|
|
2755
|
+
const pid = new Int32Array(3 * seq);
|
|
2756
|
+
for (let i = 0; i < seq; i++) {
|
|
2757
|
+
pid[i] = i;
|
|
2758
|
+
pid[seq + i] = i;
|
|
2759
|
+
pid[2 * seq + i] = i;
|
|
2760
|
+
}
|
|
2761
|
+
const { cos, sin } = buildMRoPECosSin(pid, seq, ropeDim, theta, section);
|
|
2762
|
+
this.executor.writeInput("mrope_cos", cos);
|
|
2763
|
+
this.executor.writeInput("mrope_sin", sin);
|
|
2764
|
+
const rowMap = new Int32Array(seq).fill(-1);
|
|
2765
|
+
this.executor.writeInput("vision_row_map", rowMap);
|
|
2766
|
+
}
|
|
2767
|
+
/**
|
|
2768
|
+
* Generate text from a prompt.
|
|
2769
|
+
*/
|
|
2770
|
+
async generate(prompt, options = {}) {
|
|
2771
|
+
this.checkDestroyed();
|
|
2772
|
+
const { maxTokens = 512, stopSequences = [], sampling = {}, systemPrompt, onToken } = options;
|
|
2773
|
+
this.executor.reset();
|
|
2774
|
+
let inputIds;
|
|
2775
|
+
if (typeof prompt === "string") {
|
|
2776
|
+
const messages = [];
|
|
2777
|
+
if (systemPrompt) messages.push({
|
|
2778
|
+
role: "system",
|
|
2779
|
+
content: systemPrompt
|
|
2780
|
+
});
|
|
2781
|
+
messages.push({
|
|
2782
|
+
role: "user",
|
|
2783
|
+
content: prompt
|
|
2784
|
+
});
|
|
2785
|
+
inputIds = this.tokenizer.encodeChat(messages, { addGenerationPrompt: true });
|
|
2786
|
+
} else {
|
|
2787
|
+
const messages = systemPrompt ? [{
|
|
2788
|
+
role: "system",
|
|
2789
|
+
content: systemPrompt
|
|
2790
|
+
}, ...prompt] : prompt;
|
|
2791
|
+
inputIds = this.tokenizer.encodeChat(messages, { addGenerationPrompt: true });
|
|
2792
|
+
}
|
|
2793
|
+
const startTime = performance.now();
|
|
2794
|
+
const isGreedy = (sampling.temperature ?? .7) < 1e-6;
|
|
2795
|
+
const hasMRoPE = this._multimodalGraph && this.executor.hasBuffer("mrope_cos");
|
|
2796
|
+
if (this._multimodalGraph) {
|
|
2797
|
+
if (hasMRoPE) this.writeMRoPELinearText(inputIds.length);
|
|
2798
|
+
else if (this.executor.hasBuffer("vision_row_map")) this.executor.writeInput("vision_row_map", new Int32Array(inputIds.length).fill(-1));
|
|
2799
|
+
}
|
|
2800
|
+
let { logits } = await this.executor.forward(new Uint32Array(inputIds));
|
|
2801
|
+
this.maybePromoteGroupProbe(logits);
|
|
2802
|
+
const generatedIds = [];
|
|
2803
|
+
let finishReason = "max_tokens";
|
|
2804
|
+
let generatedText = "";
|
|
2805
|
+
const eosId = this.tokenizer.config.eosTokenId;
|
|
2806
|
+
const decodeStart = performance.now();
|
|
2807
|
+
const consumeToken = (nextToken) => {
|
|
2808
|
+
generatedIds.push(nextToken);
|
|
2809
|
+
if (eosId !== null && nextToken === eosId) {
|
|
2810
|
+
finishReason = "eos";
|
|
2811
|
+
return true;
|
|
2812
|
+
}
|
|
2813
|
+
const tokenText = this.tokenizer.decode([nextToken], true);
|
|
2814
|
+
generatedText += tokenText;
|
|
2815
|
+
if (onToken) {
|
|
2816
|
+
const elapsedMs = performance.now() - decodeStart;
|
|
2817
|
+
const tokenIndex = generatedIds.length;
|
|
2818
|
+
onToken(tokenText, {
|
|
2819
|
+
tokenIndex,
|
|
2820
|
+
elapsedMs,
|
|
2821
|
+
tps: elapsedMs > 0 ? tokenIndex / elapsedMs * 1e3 : 0
|
|
2822
|
+
});
|
|
2823
|
+
}
|
|
2824
|
+
if (stopSequences.some((s) => generatedText.includes(s))) {
|
|
2825
|
+
for (const s of stopSequences) {
|
|
2826
|
+
const idx = generatedText.indexOf(s);
|
|
2827
|
+
if (idx !== -1) generatedText = generatedText.slice(0, idx);
|
|
2828
|
+
}
|
|
2829
|
+
finishReason = "stop_sequence";
|
|
2830
|
+
return true;
|
|
2831
|
+
}
|
|
2832
|
+
return false;
|
|
2833
|
+
};
|
|
2834
|
+
const mmDecode = hasMRoPE;
|
|
2835
|
+
let mmLogicalPos = inputIds.length;
|
|
2836
|
+
if (isGreedy && !this.executor.needsMultiEncoder && !mmDecode) {
|
|
2837
|
+
const firstToken = sampleToken(logits, sampling, [...inputIds, ...generatedIds]);
|
|
2838
|
+
if (!consumeToken(firstToken)) {
|
|
2839
|
+
const depth = Executor.PIPELINE_DEPTH;
|
|
2840
|
+
const stepsNeeded = Math.min(maxTokens - 1, this.executor.decodeCapacityRemaining());
|
|
2841
|
+
let submitted = 0;
|
|
2842
|
+
let consumed = 0;
|
|
2843
|
+
while (consumed < stepsNeeded) {
|
|
2844
|
+
while (submitted < stepsNeeded && submitted < consumed + depth) {
|
|
2845
|
+
this.executor.submitGreedyDecodeStep(submitted === 0 ? firstToken : null, submitted % depth);
|
|
2846
|
+
submitted++;
|
|
2847
|
+
}
|
|
2848
|
+
const tok = await this.executor.readDecodeToken(consumed % depth);
|
|
2849
|
+
consumed++;
|
|
2850
|
+
if (consumeToken(tok)) break;
|
|
2851
|
+
}
|
|
2852
|
+
}
|
|
2853
|
+
} else for (let step = 0; step < maxTokens; step++) {
|
|
2854
|
+
let nextToken;
|
|
2855
|
+
if (step === 0 || !isGreedy) nextToken = sampleToken(logits, sampling, [...inputIds, ...generatedIds]);
|
|
2856
|
+
else {
|
|
2857
|
+
if (mmDecode) this.writeMRoPEDecodeStep(this.executor.currentSeqPos, mmLogicalPos);
|
|
2858
|
+
nextToken = await this.executor.forwardArgmax(new Uint32Array([generatedIds[generatedIds.length - 1]]));
|
|
2859
|
+
mmLogicalPos++;
|
|
2860
|
+
}
|
|
2861
|
+
if (consumeToken(nextToken)) break;
|
|
2862
|
+
if (!isGreedy) {
|
|
2863
|
+
if (mmDecode) this.writeMRoPEDecodeStep(this.executor.currentSeqPos, mmLogicalPos);
|
|
2864
|
+
logits = (await this.executor.forward(new Uint32Array([nextToken]))).logits;
|
|
2865
|
+
mmLogicalPos++;
|
|
2866
|
+
}
|
|
2867
|
+
}
|
|
2868
|
+
const totalTime = performance.now() - startTime;
|
|
2869
|
+
const tokensGenerated = generatedIds.length;
|
|
2870
|
+
const tokensPerSecond = tokensGenerated / (totalTime / 1e3);
|
|
2871
|
+
return {
|
|
2872
|
+
text: generatedText,
|
|
2873
|
+
tokensGenerated,
|
|
2874
|
+
tokensPerSecond,
|
|
2875
|
+
totalTime,
|
|
2876
|
+
finishReason
|
|
2877
|
+
};
|
|
2878
|
+
}
|
|
2879
|
+
/**
|
|
2880
|
+
* Inline autocomplete: continue `prefix` with a brief, single-line continuation.
|
|
2881
|
+
* Wraps `generate` with low-latency defaults (16 tokens, temp 0.3, stop at the
|
|
2882
|
+
* first newline) + a continuation system prompt, then cleans the output (strip
|
|
2883
|
+
* after newline, dequote, drop an echoed prefix, smart leading space).
|
|
2884
|
+
*
|
|
2885
|
+
* ```ts
|
|
2886
|
+
* const suggestion = await engine.autocomplete("The quick brown fox");
|
|
2887
|
+
* // " jumps over the lazy dog."
|
|
2888
|
+
* ```
|
|
2889
|
+
*/
|
|
2890
|
+
async autocomplete(prefix, opts = {}) {
|
|
2891
|
+
return normalizeContinuation((await this.generate(prefix, {
|
|
2892
|
+
systemPrompt: AUTOCOMPLETE_SYSTEM,
|
|
2893
|
+
maxTokens: opts.maxTokens ?? 16,
|
|
2894
|
+
sampling: { temperature: opts.temperature ?? .3 },
|
|
2895
|
+
stopSequences: opts.stop ?? ["\n"]
|
|
2896
|
+
})).text, prefix, opts.singleLine ?? true);
|
|
2897
|
+
}
|
|
2898
|
+
/**
|
|
2899
|
+
* Rewrite `text` in a target tone (e.g. "professional", "friendly", "concise",
|
|
2900
|
+
* "playful", "pirate") or with free-form `instructions`. Returns only the
|
|
2901
|
+
* rewritten text.
|
|
2902
|
+
*
|
|
2903
|
+
* ```ts
|
|
2904
|
+
* await engine.rewrite("hey can u send the file", { tone: "professional" });
|
|
2905
|
+
* ```
|
|
2906
|
+
*/
|
|
2907
|
+
async rewrite(text, opts = {}) {
|
|
2908
|
+
const system = opts.instructions ?? `Rewrite the user's text in a ${opts.tone ?? "professional"} tone. Output ONLY the rewritten text — no preamble, no quotes, no commentary.`;
|
|
2909
|
+
return (await this.generate(text, {
|
|
2910
|
+
systemPrompt: system,
|
|
2911
|
+
maxTokens: opts.maxTokens ?? 256,
|
|
2912
|
+
sampling: { temperature: opts.temperature ?? .7 }
|
|
2913
|
+
})).text.trim();
|
|
2914
|
+
}
|
|
2915
|
+
/**
|
|
2916
|
+
* Agentic tool-calling loop: generate, parse a `<tool_call>`, run the matching
|
|
2917
|
+
* tool's `execute`, feed the result back, and repeat up to `maxSteps` until the
|
|
2918
|
+
* model answers without calling a tool. Returns the final text + a step trace.
|
|
2919
|
+
*
|
|
2920
|
+
* ```ts
|
|
2921
|
+
* const { text, steps } = await engine.generateWithTools("Weather in Paris?", {
|
|
2922
|
+
* tools: [weatherTool],
|
|
2923
|
+
* });
|
|
2924
|
+
* ```
|
|
2925
|
+
*/
|
|
2926
|
+
async generateWithTools(prompt, opts) {
|
|
2927
|
+
const { tools, maxSteps = 5, onStep, maxTokens, sampling } = opts;
|
|
2928
|
+
const systemPrompt = formatAgentToolsPrompt(tools);
|
|
2929
|
+
const messages = typeof prompt === "string" ? [{
|
|
2930
|
+
role: "user",
|
|
2931
|
+
content: prompt
|
|
2932
|
+
}] : [...prompt];
|
|
2933
|
+
const steps = [];
|
|
2934
|
+
let finalText = "";
|
|
2935
|
+
for (let i = 0; i < maxSteps; i++) {
|
|
2936
|
+
const result = await this.generate(messages, {
|
|
2937
|
+
systemPrompt,
|
|
2938
|
+
maxTokens,
|
|
2939
|
+
sampling
|
|
2940
|
+
});
|
|
2941
|
+
const call = parseAgentToolCall(result.text);
|
|
2942
|
+
if (!call) {
|
|
2943
|
+
finalText = result.text;
|
|
2944
|
+
const answer = {
|
|
2945
|
+
kind: "answer",
|
|
2946
|
+
text: result.text
|
|
2947
|
+
};
|
|
2948
|
+
steps.push(answer);
|
|
2949
|
+
onStep?.(answer);
|
|
2950
|
+
break;
|
|
2951
|
+
}
|
|
2952
|
+
const callStep = {
|
|
2953
|
+
kind: "tool_call",
|
|
2954
|
+
tool: call.name,
|
|
2955
|
+
args: call.args
|
|
2956
|
+
};
|
|
2957
|
+
steps.push(callStep);
|
|
2958
|
+
onStep?.(callStep);
|
|
2959
|
+
const tool = tools.find((t) => t.name === call.name);
|
|
2960
|
+
let resultText;
|
|
2961
|
+
if (tool) try {
|
|
2962
|
+
resultText = String(await tool.execute(call.args));
|
|
2963
|
+
} catch (e) {
|
|
2964
|
+
resultText = `Error executing ${call.name}: ${e}`;
|
|
2965
|
+
}
|
|
2966
|
+
else resultText = `Error: unknown tool "${call.name}"`;
|
|
2967
|
+
const resultStep = {
|
|
2968
|
+
kind: "tool_result",
|
|
2969
|
+
tool: call.name,
|
|
2970
|
+
result: resultText
|
|
2971
|
+
};
|
|
2972
|
+
steps.push(resultStep);
|
|
2973
|
+
onStep?.(resultStep);
|
|
2974
|
+
messages.push({
|
|
2975
|
+
role: "assistant",
|
|
2976
|
+
content: result.text
|
|
2977
|
+
});
|
|
2978
|
+
messages.push({
|
|
2979
|
+
role: "user",
|
|
2980
|
+
content: `Tool ${call.name} returned:\n${resultText}`
|
|
2981
|
+
});
|
|
2982
|
+
finalText = resultText;
|
|
2983
|
+
}
|
|
2984
|
+
return {
|
|
2985
|
+
text: finalText,
|
|
2986
|
+
steps
|
|
2987
|
+
};
|
|
2988
|
+
}
|
|
2989
|
+
/**
|
|
2990
|
+
* Generate a STRUCTURED object: generate text, extract the first JSON
|
|
2991
|
+
* object/array, parse it, validate it, and RETRY until it is valid (on-device
|
|
2992
|
+
* tokens are free, so re-rolling a malformed JSON is cheap).
|
|
2993
|
+
*
|
|
2994
|
+
* Extraction is tolerant: prose, markdown, and ```json code fences are
|
|
2995
|
+
* stripped, then the outermost balanced `{...}` or `[...]` is matched and
|
|
2996
|
+
* `JSON.parse`d. Validation is one of:
|
|
2997
|
+
* - a predicate `(o) => boolean` (return false to reject),
|
|
2998
|
+
* - a minimal JSON-schema-ish object with `required` (those keys must exist),
|
|
2999
|
+
* - nothing (only valid JSON is required).
|
|
3000
|
+
*
|
|
3001
|
+
* On each retry the prompt is nudged with a terse "return ONLY valid JSON…"
|
|
3002
|
+
* instruction (including the required-key shape when known). Throws a clear
|
|
3003
|
+
* error if it never validates within `maxRetries + 1` attempts.
|
|
3004
|
+
*
|
|
3005
|
+
* ```ts
|
|
3006
|
+
* const { object } = await engine.generateObject(
|
|
3007
|
+
* 'Extract {name, age} from: "I am Sarah, 28"',
|
|
3008
|
+
* { schema: { required: ["name", "age"] } },
|
|
3009
|
+
* );
|
|
3010
|
+
* // object === { name: "Sarah", age: 28 }
|
|
3011
|
+
* ```
|
|
3012
|
+
*
|
|
3013
|
+
* @typeParam T Expected object type (not enforced at runtime — validate via schema).
|
|
3014
|
+
*/
|
|
3015
|
+
async generateObject(prompt, options = {}) {
|
|
3016
|
+
this.checkDestroyed();
|
|
3017
|
+
const { schema, maxRetries = 4, ...generateOpts } = options;
|
|
3018
|
+
const validate = (value) => {
|
|
3019
|
+
if (typeof schema === "function") return schema(value);
|
|
3020
|
+
if (schema && typeof schema === "object" && Array.isArray(schema.required)) {
|
|
3021
|
+
if (value === null || typeof value !== "object") return false;
|
|
3022
|
+
const obj = value;
|
|
3023
|
+
return schema.required.every((key) => key in obj);
|
|
3024
|
+
}
|
|
3025
|
+
return true;
|
|
3026
|
+
};
|
|
3027
|
+
const nudge = `\n\nReturn ONLY valid JSON matching ${schema && typeof schema === "object" && Array.isArray(schema.required) ? `{ ${schema.required.join(", ")} }` : "the requested shape"}. No prose, no markdown, no code fences.`;
|
|
3028
|
+
const attemptsMax = Math.max(0, maxRetries) + 1;
|
|
3029
|
+
let lastText = "";
|
|
3030
|
+
let lastError = "";
|
|
3031
|
+
for (let attempt = 1; attempt <= attemptsMax; attempt++) {
|
|
3032
|
+
const promptForAttempt = attempt === 1 ? prompt : prompt + nudge;
|
|
3033
|
+
const result = await this.generate(promptForAttempt, generateOpts);
|
|
3034
|
+
lastText = result.text;
|
|
3035
|
+
const parsed = extractJson(result.text);
|
|
3036
|
+
if (parsed === void 0) {
|
|
3037
|
+
lastError = "no JSON object/array found in output";
|
|
3038
|
+
continue;
|
|
3039
|
+
}
|
|
3040
|
+
let value;
|
|
3041
|
+
try {
|
|
3042
|
+
value = JSON.parse(parsed);
|
|
3043
|
+
} catch (e) {
|
|
3044
|
+
lastError = `JSON.parse failed: ${e instanceof Error ? e.message : String(e)}`;
|
|
3045
|
+
continue;
|
|
3046
|
+
}
|
|
3047
|
+
if (!validate(value)) {
|
|
3048
|
+
lastError = "parsed JSON failed schema validation";
|
|
3049
|
+
continue;
|
|
3050
|
+
}
|
|
3051
|
+
return {
|
|
3052
|
+
object: value,
|
|
3053
|
+
text: result.text,
|
|
3054
|
+
attempts: attempt
|
|
3055
|
+
};
|
|
3056
|
+
}
|
|
3057
|
+
throw new Error(`generateObject failed after ${attemptsMax} attempt(s): ${lastError}. Last output: ${JSON.stringify(lastText.slice(0, 200))}`);
|
|
3058
|
+
}
|
|
3059
|
+
/**
|
|
3060
|
+
* Text-to-speech: text → 22 kHz PCM via Kani-TTS-2 (LFM2-350M codec-LM + NVIDIA
|
|
3061
|
+
* NeMo NanoCodec). Returns `{ pcm: Float32Array, sampleRate: 22050 }`.
|
|
3062
|
+
*
|
|
3063
|
+
* Runs the full pipeline: the codec-LM backbone autoregressively emits NanoCodec
|
|
3064
|
+
* audio tokens (4 per frame, frame-level positions + learnable per-layer RoPE),
|
|
3065
|
+
* then the bit-exact NanoCodec decoder (FSQ + causal HiFi-GAN) turns the codes
|
|
3066
|
+
* into PCM. The heavy lifting lives in {@link KaniTTS} (src/gpu/kani-tts.ts); this
|
|
3067
|
+
* lazily constructs that engine on first use (downloading the NanoCodec codec
|
|
3068
|
+
* checkpoint alongside the backbone).
|
|
3069
|
+
*
|
|
3070
|
+
* Requires a Kani-TTS-2 checkpoint (architecture "KaniTTS2ForCausalLM").
|
|
3071
|
+
*/
|
|
3072
|
+
async speak(text, options = {}) {
|
|
3073
|
+
this.checkDestroyed();
|
|
3074
|
+
if (this._architecture !== "KaniTTS2ForCausalLM") throw new Error(`speak() requires a Kani-TTS-2 model (architecture "KaniTTS2ForCausalLM"), loaded engine is "${this._architecture}".`);
|
|
3075
|
+
if (!this._kaniTTS) this._kaniTTS = await KaniTTS.create({
|
|
3076
|
+
repo: this._createOptions.repo,
|
|
3077
|
+
revision: this._createOptions.revision,
|
|
3078
|
+
hfToken: this._createOptions.hfToken,
|
|
3079
|
+
cacheDir: this._createOptions.cacheDir,
|
|
3080
|
+
maxSeqLen: this.maxSeqLen
|
|
3081
|
+
});
|
|
3082
|
+
return this._kaniTTS.speak(text, options);
|
|
3083
|
+
}
|
|
3084
|
+
/**
|
|
3085
|
+
* Describe an image: image-in → text-out. Runs the vision encoder, splices the
|
|
3086
|
+
* merged image tokens into a text prompt, applies multimodal M-RoPE, and
|
|
3087
|
+
* generates a description. Requires `enableVision: true` at create().
|
|
3088
|
+
*
|
|
3089
|
+
* Image input forms:
|
|
3090
|
+
* - `{ pixels, width, height }` — decoded RGB (HWC, 0..255), host-preprocessed
|
|
3091
|
+
* (smart-resize/normalize/patchify) to match the HF image processor.
|
|
3092
|
+
* - `{ patches, gridTHW }` — already-built [N,1536] patch tensor + grid (e.g.
|
|
3093
|
+
* HF-exact pixel_values from a reference; skips host preprocessing).
|
|
3094
|
+
*/
|
|
3095
|
+
async describeImage(image, prompt = "Describe this image.", options = {}) {
|
|
3096
|
+
this.checkDestroyed();
|
|
3097
|
+
if (!this.visionExecutor) throw new Error("describeImage() requires a vision encoder. Load with { enableVision: true } on a vision-capable checkpoint (Qwen3.5 or Gemma 4).");
|
|
3098
|
+
if (this.visionExecutor.gemma4) {
|
|
3099
|
+
if (!this._multimodalGraph) throw new Error("Gemma 4 describeImage() requires the multimodal text graph. Load with { enableVision: true } so generateGemma4Graph builds the vision_embeds + EmbedSplice variant.");
|
|
3100
|
+
let g4patches;
|
|
3101
|
+
let g4gridHW;
|
|
3102
|
+
if ("patches" in image) {
|
|
3103
|
+
g4patches = image.patches;
|
|
3104
|
+
g4gridHW = [image.gridTHW[1], image.gridTHW[2]];
|
|
3105
|
+
} else {
|
|
3106
|
+
const pre$1 = preprocessImageGemma4(image.pixels, image.width, image.height);
|
|
3107
|
+
g4patches = pre$1.patches;
|
|
3108
|
+
g4gridHW = pre$1.gridHW;
|
|
3109
|
+
}
|
|
3110
|
+
const gThw = [
|
|
3111
|
+
1,
|
|
3112
|
+
g4gridHW[0],
|
|
3113
|
+
g4gridHW[1]
|
|
3114
|
+
];
|
|
3115
|
+
const numPatches$1 = gThw[0] * gThw[1] * gThw[2];
|
|
3116
|
+
this.setPhase(`describe:vit-encode:N=${numPatches$1}`);
|
|
3117
|
+
const vision$1 = await this.encodeImage(g4patches, gThw, (stage, info) => {
|
|
3118
|
+
const suffix = info?.layer != null ? `:L${info.layer}` : "";
|
|
3119
|
+
this.setPhase(`describe:${stage}${suffix}`);
|
|
3120
|
+
});
|
|
3121
|
+
const cfg = this.rawConfig ?? {};
|
|
3122
|
+
const imageTokenId$1 = cfg.image_token_id ?? 258880;
|
|
3123
|
+
const boiId = this.tokenizer.tokenToId("<|image>") ?? cfg.boi_token_id ?? 255999;
|
|
3124
|
+
const eoiId = this.tokenizer.tokenToId("<image|>") ?? cfg.eoi_token_id ?? 258882;
|
|
3125
|
+
const bos = this.tokenizer.config.bosToken ? this.tokenizer.tokenToId(this.tokenizer.config.bosToken) ?? [] : [];
|
|
3126
|
+
const bosIds = Array.isArray(bos) ? bos : [bos];
|
|
3127
|
+
const turnUserOpen = this.tokenizer.encode("<|turn>user\n");
|
|
3128
|
+
const imgLead = this.tokenizer.encode("\n\n");
|
|
3129
|
+
const imgTrail = this.tokenizer.encode("\n\n");
|
|
3130
|
+
const promptIds = this.tokenizer.encode(prompt);
|
|
3131
|
+
const turnClose = this.tokenizer.encode("<turn|>\n<|turn>model\n");
|
|
3132
|
+
const imageRun$1 = new Array(vision$1.rows).fill(imageTokenId$1);
|
|
3133
|
+
const inputIds$1 = [
|
|
3134
|
+
...bosIds,
|
|
3135
|
+
...turnUserOpen,
|
|
3136
|
+
...imgLead,
|
|
3137
|
+
boiId,
|
|
3138
|
+
...imageRun$1,
|
|
3139
|
+
eoiId,
|
|
3140
|
+
...imgTrail,
|
|
3141
|
+
...promptIds,
|
|
3142
|
+
...turnClose
|
|
3143
|
+
];
|
|
3144
|
+
return this.runMultimodalGemma4(vision$1.embeds, inputIds$1, imageTokenId$1, options);
|
|
3145
|
+
}
|
|
3146
|
+
if (!this._multimodalGraph) throw new Error("describeImage() requires a multimodal engine. Load with { enableVision: true } on a vision-capable checkpoint (e.g. Qwen/Qwen3.5-0.8B).");
|
|
3147
|
+
const procCfg = options.imageProcessor ?? QWEN3_5_IMAGE_PROCESSOR;
|
|
3148
|
+
let patches;
|
|
3149
|
+
let gridTHW;
|
|
3150
|
+
if ("patches" in image) {
|
|
3151
|
+
patches = image.patches;
|
|
3152
|
+
gridTHW = image.gridTHW;
|
|
3153
|
+
} else {
|
|
3154
|
+
const pre$1 = preprocessImage(image.pixels, image.width, image.height, procCfg);
|
|
3155
|
+
patches = pre$1.patches;
|
|
3156
|
+
gridTHW = pre$1.gridTHW;
|
|
3157
|
+
}
|
|
3158
|
+
const numPatches = gridTHW[0] * gridTHW[1] * gridTHW[2];
|
|
3159
|
+
this.setPhase(`describe:vit-encode:N=${numPatches}`);
|
|
3160
|
+
const vision = await this.encodeImage(patches, gridTHW, (stage, info) => {
|
|
3161
|
+
const suffix = info?.layer != null ? `:L${info.layer}` : "";
|
|
3162
|
+
this.setPhase(`describe:${stage}${suffix}`);
|
|
3163
|
+
});
|
|
3164
|
+
const visStart = this.tokenizer.tokenToId("<|vision_start|>") ?? 248053;
|
|
3165
|
+
const visEnd = this.tokenizer.tokenToId("<|vision_end|>") ?? 248054;
|
|
3166
|
+
const { imageTokenId } = this.mropeParams();
|
|
3167
|
+
const numImageTokens = vision.rows;
|
|
3168
|
+
const pre = this.tokenizer.encode("<|im_start|>user\n");
|
|
3169
|
+
const post = this.tokenizer.encode(`${prompt}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n`);
|
|
3170
|
+
const imageRun = new Array(numImageTokens).fill(imageTokenId);
|
|
3171
|
+
const inputIds = [
|
|
3172
|
+
...pre,
|
|
3173
|
+
visStart,
|
|
3174
|
+
...imageRun,
|
|
3175
|
+
visEnd,
|
|
3176
|
+
...post
|
|
3177
|
+
];
|
|
3178
|
+
return this.runMultimodal(vision.embeds, gridTHW, inputIds, options);
|
|
3179
|
+
}
|
|
3180
|
+
/**
|
|
3181
|
+
* Prepare the multimodal prefill: upload vision embeds, build the image row-map
|
|
3182
|
+
* and 3D M-RoPE cos/sin, reset state, and write all host inputs. Returns the
|
|
3183
|
+
* input ids and the post-image logical cursor for decode. Does NOT run forward.
|
|
3184
|
+
*/
|
|
3185
|
+
prepareMultimodalPrefill(visionEmbeds, gridTHW, inputIds) {
|
|
3186
|
+
const { imageTokenId, mergeSize } = this.mropeParams();
|
|
3187
|
+
const seq = inputIds.length;
|
|
3188
|
+
if (seq > this.maxSeqLen) throw new Error(`describeImage: prompt+image is ${seq} tokens > maxSeqLen ${this.maxSeqLen}. Increase maxSeqLen or use a smaller image.`);
|
|
3189
|
+
this.executor.reset();
|
|
3190
|
+
this.executor.writeInput("vision_embeds", visionEmbeds);
|
|
3191
|
+
const rowMap = new Int32Array(seq).fill(-1);
|
|
3192
|
+
let v = 0;
|
|
3193
|
+
for (let i = 0; i < seq; i++) if (inputIds[i] === imageTokenId) rowMap[i] = v++;
|
|
3194
|
+
const positionIds3 = buildMRoPEPositionIds(inputIds, [gridTHW], imageTokenId, mergeSize);
|
|
3195
|
+
return { lastLogicalPos: this.writeMRoPEPrefill(positionIds3, seq, rowMap) };
|
|
3196
|
+
}
|
|
3197
|
+
/**
|
|
3198
|
+
* Gemma 4 multimodal prefill + decode. Unlike Qwen3.5 (M-RoPE), Gemma 4 uses
|
|
3199
|
+
* STANDARD sequential 1D RoPE computed inside each layer from the KV write
|
|
3200
|
+
* position, so there are no host cos/sin inputs and decode positions are simply
|
|
3201
|
+
* the running seqPos — identical to plain text generation. We only upload the
|
|
3202
|
+
* merged vision embeds + an image-token row-map (EmbedSplice scatters them into
|
|
3203
|
+
* the image_token rows) before the forward pass.
|
|
3204
|
+
*/
|
|
3205
|
+
async runMultimodalGemma4(visionEmbeds, inputIds, imageTokenId, options) {
|
|
3206
|
+
const seq = inputIds.length;
|
|
3207
|
+
if (seq > this.maxSeqLen) throw new Error(`describeImage: prompt+image is ${seq} tokens > maxSeqLen ${this.maxSeqLen}. Increase maxSeqLen or use a smaller image.`);
|
|
3208
|
+
this.setPhase(`describe:splice:seq=${seq}`);
|
|
3209
|
+
this.executor.reset();
|
|
3210
|
+
this.executor.writeInput("vision_embeds", visionEmbeds);
|
|
3211
|
+
const rowMap = new Int32Array(seq).fill(-1);
|
|
3212
|
+
let v = 0;
|
|
3213
|
+
for (let i = 0; i < seq; i++) if (inputIds[i] === imageTokenId) rowMap[i] = v++;
|
|
3214
|
+
this.executor.writeInput("vision_row_map", rowMap);
|
|
3215
|
+
this.setPhase("describe:text-decode");
|
|
3216
|
+
const { maxTokens = 512, stopSequences = [], sampling = {}, onToken } = options;
|
|
3217
|
+
const startTime = performance.now();
|
|
3218
|
+
const isGreedy = (sampling.temperature ?? .7) < 1e-6;
|
|
3219
|
+
let { logits } = await this.executor.forward(new Uint32Array(inputIds));
|
|
3220
|
+
const generatedIds = [];
|
|
3221
|
+
let finishReason = "max_tokens";
|
|
3222
|
+
let generatedText = "";
|
|
3223
|
+
const eosId = this.tokenizer.config.eosTokenId;
|
|
3224
|
+
const eotId = this.tokenizer.tokenToId("<turn|>");
|
|
3225
|
+
const consumeToken = (nextToken) => {
|
|
3226
|
+
generatedIds.push(nextToken);
|
|
3227
|
+
if (eosId !== null && nextToken === eosId || eotId !== null && nextToken === eotId) {
|
|
3228
|
+
finishReason = "eos";
|
|
3229
|
+
return true;
|
|
3230
|
+
}
|
|
3231
|
+
const tokenText = this.tokenizer.decode([nextToken], true);
|
|
3232
|
+
generatedText += tokenText;
|
|
3233
|
+
onToken?.(tokenText);
|
|
3234
|
+
if (stopSequences.some((s) => generatedText.includes(s))) {
|
|
3235
|
+
for (const s of stopSequences) {
|
|
3236
|
+
const idx = generatedText.indexOf(s);
|
|
3237
|
+
if (idx !== -1) generatedText = generatedText.slice(0, idx);
|
|
3238
|
+
}
|
|
3239
|
+
finishReason = "stop_sequence";
|
|
3240
|
+
return true;
|
|
3241
|
+
}
|
|
3242
|
+
return false;
|
|
3243
|
+
};
|
|
3244
|
+
for (let step = 0; step < maxTokens; step++) {
|
|
3245
|
+
let nextToken;
|
|
3246
|
+
if (step === 0) nextToken = sampleToken(logits, sampling, [...inputIds, ...generatedIds]);
|
|
3247
|
+
else if (isGreedy) nextToken = await this.executor.forwardArgmax(new Uint32Array([generatedIds[generatedIds.length - 1]]));
|
|
3248
|
+
else nextToken = sampleToken(logits, sampling, [...inputIds, ...generatedIds]);
|
|
3249
|
+
if (consumeToken(nextToken)) break;
|
|
3250
|
+
if (!isGreedy) logits = (await this.executor.forward(new Uint32Array([nextToken]))).logits;
|
|
3251
|
+
}
|
|
3252
|
+
const totalTime = performance.now() - startTime;
|
|
3253
|
+
const tokensGenerated = generatedIds.length;
|
|
3254
|
+
return {
|
|
3255
|
+
text: generatedText,
|
|
3256
|
+
tokensGenerated,
|
|
3257
|
+
tokensPerSecond: tokensGenerated / (totalTime / 1e3),
|
|
3258
|
+
totalTime,
|
|
3259
|
+
finishReason
|
|
3260
|
+
};
|
|
3261
|
+
}
|
|
3262
|
+
/** Prepare + prefill + decode for a fully-specified multimodal token sequence. */
|
|
3263
|
+
async runMultimodal(visionEmbeds, gridTHW, inputIds, options) {
|
|
3264
|
+
this.setPhase(`describe:splice:seq=${inputIds.length}`);
|
|
3265
|
+
const { lastLogicalPos } = this.prepareMultimodalPrefill(visionEmbeds, gridTHW, inputIds);
|
|
3266
|
+
this.setPhase("describe:text-decode");
|
|
3267
|
+
return this.generateFromPrepared(inputIds, lastLogicalPos + 1, options);
|
|
3268
|
+
}
|
|
3269
|
+
/**
|
|
3270
|
+
* Debug: run ONLY the multimodal prefill for an explicit token sequence and
|
|
3271
|
+
* return the spliced input embeddings [seq, hidden] + first-token logits. Lets
|
|
3272
|
+
* tests compare the fused text+vision stream and M-RoPE numerically vs HF
|
|
3273
|
+
* without the decode loop overwriting intermediate buffers.
|
|
3274
|
+
*/
|
|
3275
|
+
async debugMultimodalPrefill(patches, gridTHW, inputIds) {
|
|
3276
|
+
this.checkDestroyed();
|
|
3277
|
+
if (!this._multimodalGraph || !this.visionExecutor) throw new Error("debugMultimodalPrefill requires a multimodal engine (enableVision: true).");
|
|
3278
|
+
const vision = await this.encodeImage(patches, gridTHW);
|
|
3279
|
+
this.prepareMultimodalPrefill(vision.embeds, gridTHW, inputIds);
|
|
3280
|
+
const { logits } = await this.executor.forward(new Uint32Array(inputIds));
|
|
3281
|
+
return {
|
|
3282
|
+
splicedEmbeds: await this.executor.debugReadBuffer("embed_spliced", inputIds.length * this.config.hidden_size),
|
|
3283
|
+
logits,
|
|
3284
|
+
seq: inputIds.length
|
|
3285
|
+
};
|
|
3286
|
+
}
|
|
3287
|
+
/**
|
|
3288
|
+
* Internal: run prefill (assumes M-RoPE/splice inputs already written) + decode,
|
|
3289
|
+
* with decode logical positions starting at `decodeStartPos`. Used by
|
|
3290
|
+
* describeImage so the post-image cursor is honored.
|
|
3291
|
+
*/
|
|
3292
|
+
async generateFromPrepared(inputIds, decodeStartPos, options) {
|
|
3293
|
+
const { maxTokens = 512, stopSequences = [], sampling = {}, onToken } = options;
|
|
3294
|
+
const startTime = performance.now();
|
|
3295
|
+
const isGreedy = (sampling.temperature ?? .7) < 1e-6;
|
|
3296
|
+
let { logits } = await this.executor.forward(new Uint32Array(inputIds));
|
|
3297
|
+
const generatedIds = [];
|
|
3298
|
+
let finishReason = "max_tokens";
|
|
3299
|
+
let generatedText = "";
|
|
3300
|
+
const eosId = this.tokenizer.config.eosTokenId;
|
|
3301
|
+
let mmLogicalPos = decodeStartPos;
|
|
3302
|
+
const consumeToken = (nextToken) => {
|
|
3303
|
+
generatedIds.push(nextToken);
|
|
3304
|
+
if (eosId !== null && nextToken === eosId) {
|
|
3305
|
+
finishReason = "eos";
|
|
3306
|
+
return true;
|
|
3307
|
+
}
|
|
3308
|
+
const tokenText = this.tokenizer.decode([nextToken], true);
|
|
3309
|
+
generatedText += tokenText;
|
|
3310
|
+
onToken?.(tokenText);
|
|
3311
|
+
if (stopSequences.some((s) => generatedText.includes(s))) {
|
|
3312
|
+
for (const s of stopSequences) {
|
|
3313
|
+
const idx = generatedText.indexOf(s);
|
|
3314
|
+
if (idx !== -1) generatedText = generatedText.slice(0, idx);
|
|
3315
|
+
}
|
|
3316
|
+
finishReason = "stop_sequence";
|
|
3317
|
+
return true;
|
|
3318
|
+
}
|
|
3319
|
+
return false;
|
|
3320
|
+
};
|
|
3321
|
+
for (let step = 0; step < maxTokens; step++) {
|
|
3322
|
+
let nextToken;
|
|
3323
|
+
if (step === 0) nextToken = sampleToken(logits, sampling, [...inputIds, ...generatedIds]);
|
|
3324
|
+
else if (isGreedy) {
|
|
3325
|
+
this.writeMRoPEDecodeStep(this.executor.currentSeqPos, mmLogicalPos);
|
|
3326
|
+
nextToken = await this.executor.forwardArgmax(new Uint32Array([generatedIds[generatedIds.length - 1]]));
|
|
3327
|
+
mmLogicalPos++;
|
|
3328
|
+
} else nextToken = sampleToken(logits, sampling, [...inputIds, ...generatedIds]);
|
|
3329
|
+
if (consumeToken(nextToken)) break;
|
|
3330
|
+
if (!isGreedy) {
|
|
3331
|
+
this.writeMRoPEDecodeStep(this.executor.currentSeqPos, mmLogicalPos);
|
|
3332
|
+
logits = (await this.executor.forward(new Uint32Array([nextToken]))).logits;
|
|
3333
|
+
mmLogicalPos++;
|
|
3334
|
+
}
|
|
3335
|
+
}
|
|
3336
|
+
const totalTime = performance.now() - startTime;
|
|
3337
|
+
const tokensGenerated = generatedIds.length;
|
|
3338
|
+
return {
|
|
3339
|
+
text: generatedText,
|
|
3340
|
+
tokensGenerated,
|
|
3341
|
+
tokensPerSecond: tokensGenerated / (totalTime / 1e3),
|
|
3342
|
+
totalTime,
|
|
3343
|
+
finishReason
|
|
3344
|
+
};
|
|
3345
|
+
}
|
|
3346
|
+
/**
|
|
3347
|
+
* Embed text into an L2-normalized vector. The pooling strategy depends on the
|
|
3348
|
+
* model: Qwen3-Embedding uses last-token (EOS-position) pooling, while
|
|
3349
|
+
* EmbeddingGemma (Gemma3 encoder) uses mean pooling over all tokens followed by
|
|
3350
|
+
* a 2-layer Dense head. Requires an embedding model (loaded with
|
|
3351
|
+
* { embedding: true }).
|
|
3352
|
+
*
|
|
3353
|
+
* The returned Float32Array has unit L2 norm, so cosine similarity reduces to a
|
|
3354
|
+
* dot product. Length is the model's embedding dim (768 for EmbeddingGemma;
|
|
3355
|
+
* config.hidden_size for Qwen3-Embedding).
|
|
3356
|
+
*
|
|
3357
|
+
* EmbeddingGemma is asymmetric — pass `{ taskType: "query" }` for search
|
|
3358
|
+
* queries and `{ taskType: "document" }` for the corpus, or a raw
|
|
3359
|
+
* `{ taskPrompt }` for other tasks (clustering/classification/STS).
|
|
3360
|
+
*/
|
|
3361
|
+
async embed(text, options = {}) {
|
|
3362
|
+
this.checkDestroyed();
|
|
3363
|
+
if (!this._isEmbedding) throw new Error("embed() requires an embedding model. Load with { embedding: true } (e.g. repo: 'Qwen/Qwen3-Embedding-0.6B' or 'mlx-community/embeddinggemma-300m-4bit').");
|
|
3364
|
+
this.executor.reset();
|
|
3365
|
+
const { instruction, taskType, taskPrompt, maxTokens } = options;
|
|
3366
|
+
if (this._architecture === "Gemma3TextModel" || this._architecture === "Gemma3Model") {
|
|
3367
|
+
const input$1 = `${taskPrompt ?? EMBEDDING_GEMMA_PROMPTS[taskType ?? "query"] ?? EMBEDDING_GEMMA_PROMPTS.query}${text}`;
|
|
3368
|
+
let ids$1 = this.tokenizer.encode(input$1);
|
|
3369
|
+
const cap$1 = maxTokens ?? this.config.context_length;
|
|
3370
|
+
if (ids$1.length > cap$1) ids$1 = ids$1.slice(0, cap$1);
|
|
3371
|
+
return this.executor.embed(new Uint32Array(ids$1));
|
|
3372
|
+
}
|
|
3373
|
+
const input = instruction ? `Instruct: ${instruction}\nQuery:${text}` : text;
|
|
3374
|
+
let ids = this.tokenizer.encode(input);
|
|
3375
|
+
const padId = this.tokenizer.tokenToId("<|endoftext|>") ?? this.tokenizer.config.eosTokenId;
|
|
3376
|
+
if (padId !== null && ids[ids.length - 1] !== padId) ids.push(padId);
|
|
3377
|
+
const cap = maxTokens ?? this.config.context_length;
|
|
3378
|
+
if (ids.length > cap) {
|
|
3379
|
+
const tail = padId !== null && ids[ids.length - 1] === padId ? [padId] : [];
|
|
3380
|
+
ids = [...ids.slice(0, cap - tail.length), ...tail];
|
|
3381
|
+
}
|
|
3382
|
+
return this.executor.embed(new Uint32Array(ids));
|
|
3383
|
+
}
|
|
3384
|
+
/**
|
|
3385
|
+
* Generate text as an async iterator (streaming).
|
|
3386
|
+
*
|
|
3387
|
+
* Uses the onToken callback from generate() to push tokens into a queue
|
|
3388
|
+
* that the async generator yields from. The generator returns the full
|
|
3389
|
+
* GenerateResult when generation completes.
|
|
3390
|
+
*
|
|
3391
|
+
* Usage:
|
|
3392
|
+
* const gen = engine.stream("Hello!");
|
|
3393
|
+
* for await (const token of gen) {
|
|
3394
|
+
* process.stdout.write(token);
|
|
3395
|
+
* }
|
|
3396
|
+
* const result = gen.next(); // { done: true, value: GenerateResult }
|
|
3397
|
+
*/
|
|
3398
|
+
async *stream(prompt, options = {}) {
|
|
3399
|
+
this.checkDestroyed();
|
|
3400
|
+
const tokenQueue = [];
|
|
3401
|
+
let resolve = null;
|
|
3402
|
+
let done = false;
|
|
3403
|
+
const pushToken = (token) => {
|
|
3404
|
+
tokenQueue.push(token);
|
|
3405
|
+
if (resolve) {
|
|
3406
|
+
const r = resolve;
|
|
3407
|
+
resolve = null;
|
|
3408
|
+
r();
|
|
3409
|
+
}
|
|
3410
|
+
};
|
|
3411
|
+
const waitForToken = () => {
|
|
3412
|
+
if (tokenQueue.length > 0 || done) return Promise.resolve();
|
|
3413
|
+
return new Promise((r) => {
|
|
3414
|
+
resolve = r;
|
|
3415
|
+
});
|
|
3416
|
+
};
|
|
3417
|
+
const genPromise = this.generate(prompt, {
|
|
3418
|
+
...options,
|
|
3419
|
+
onToken: (token) => {
|
|
3420
|
+
options.onToken?.(token);
|
|
3421
|
+
pushToken(token);
|
|
3422
|
+
}
|
|
3423
|
+
}).then((result) => {
|
|
3424
|
+
done = true;
|
|
3425
|
+
if (resolve) {
|
|
3426
|
+
const r = resolve;
|
|
3427
|
+
resolve = null;
|
|
3428
|
+
r();
|
|
3429
|
+
}
|
|
3430
|
+
return result;
|
|
3431
|
+
});
|
|
3432
|
+
let yielded = 0;
|
|
3433
|
+
while (true) {
|
|
3434
|
+
await waitForToken();
|
|
3435
|
+
while (yielded < tokenQueue.length) yield tokenQueue[yielded++];
|
|
3436
|
+
if (done && yielded >= tokenQueue.length) break;
|
|
3437
|
+
}
|
|
3438
|
+
return await genPromise;
|
|
3439
|
+
}
|
|
3440
|
+
/**
|
|
3441
|
+
* Debug: read back a named GPU buffer (weight or activation).
|
|
3442
|
+
* Call after forward() to inspect intermediate values.
|
|
3443
|
+
*/
|
|
3444
|
+
async debugReadBuffer(tensorName, maxElements) {
|
|
3445
|
+
return this.executor.debugReadBuffer(tensorName, maxElements);
|
|
3446
|
+
}
|
|
3447
|
+
/**
|
|
3448
|
+
* Run GPU diagnostics (buffer integrity, compute, shared memory).
|
|
3449
|
+
* Useful for isolating Safari/WebKit-specific WebGPU issues.
|
|
3450
|
+
*/
|
|
3451
|
+
async diagnose() {
|
|
3452
|
+
return verifyGPU(this.ctx);
|
|
3453
|
+
}
|
|
3454
|
+
/**
|
|
3455
|
+
* Run GPU diagnostics without loading a model.
|
|
3456
|
+
* Quick way to check if WebGPU is working correctly on this device.
|
|
3457
|
+
*/
|
|
3458
|
+
static async quickDiagnose() {
|
|
3459
|
+
const ctx = await initGPU();
|
|
3460
|
+
const result = await verifyGPU(ctx);
|
|
3461
|
+
clearPipelineCache(ctx.device);
|
|
3462
|
+
ctx.device.destroy();
|
|
3463
|
+
return result;
|
|
3464
|
+
}
|
|
3465
|
+
/**
|
|
3466
|
+
* Run a raw forward pass (no tokenization/chat template).
|
|
3467
|
+
* Returns logits for the last token.
|
|
3468
|
+
*/
|
|
3469
|
+
async rawForward(inputIds) {
|
|
3470
|
+
return this.executor.forward(inputIds);
|
|
3471
|
+
}
|
|
3472
|
+
/**
|
|
3473
|
+
* Reset executor state (SSM, positions, etc.)
|
|
3474
|
+
*/
|
|
3475
|
+
resetState() {
|
|
3476
|
+
this.executor.reset();
|
|
3477
|
+
}
|
|
3478
|
+
/**
|
|
3479
|
+
* Encode text to token IDs (useful for debugging / token counting).
|
|
3480
|
+
*/
|
|
3481
|
+
encode(text) {
|
|
3482
|
+
return this.tokenizer.encode(text);
|
|
3483
|
+
}
|
|
3484
|
+
/**
|
|
3485
|
+
* Decode token IDs to text.
|
|
3486
|
+
*/
|
|
3487
|
+
decode(ids, skipSpecialTokens) {
|
|
3488
|
+
return this.tokenizer.decode(ids, skipSpecialTokens);
|
|
3489
|
+
}
|
|
3490
|
+
/**
|
|
3491
|
+
* Integrity check: reads back key weight tensors and runs a single forward pass,
|
|
3492
|
+
* returning checksums for comparison against a known-good reference (Dawn/Node.js).
|
|
3493
|
+
*
|
|
3494
|
+
* Use this to isolate Safari/iPad corruption:
|
|
3495
|
+
* - If weights mismatch → fetch/download pipeline is corrupt
|
|
3496
|
+
* - If weights match but logits mismatch → kernel computation bug on Metal
|
|
3497
|
+
*
|
|
3498
|
+
* Resets executor state before and after (safe to call anytime).
|
|
3499
|
+
*/
|
|
3500
|
+
async integrityCheck() {
|
|
3501
|
+
this.checkDestroyed();
|
|
3502
|
+
const checks = [];
|
|
3503
|
+
const log = (label, data) => {
|
|
3504
|
+
const f = data instanceof Float32Array ? data : new Float32Array(data.buffer);
|
|
3505
|
+
let sum = 0;
|
|
3506
|
+
let maxVal = -Infinity;
|
|
3507
|
+
let argmax$1 = 0;
|
|
3508
|
+
for (let i = 0; i < f.length; i++) {
|
|
3509
|
+
sum += f[i];
|
|
3510
|
+
if (f[i] > maxVal) {
|
|
3511
|
+
maxVal = f[i];
|
|
3512
|
+
argmax$1 = i;
|
|
3513
|
+
}
|
|
3514
|
+
}
|
|
3515
|
+
const entry = {
|
|
3516
|
+
label,
|
|
3517
|
+
length: f.length,
|
|
3518
|
+
sum: Math.round(sum * 1e6) / 1e6,
|
|
3519
|
+
first4: [
|
|
3520
|
+
Math.round(f[0] * 1e6) / 1e6,
|
|
3521
|
+
Math.round(f[1] * 1e6) / 1e6,
|
|
3522
|
+
Math.round(f[2] * 1e6) / 1e6,
|
|
3523
|
+
Math.round(f[3] * 1e6) / 1e6
|
|
3524
|
+
],
|
|
3525
|
+
argmax: argmax$1,
|
|
3526
|
+
maxVal: Math.round(maxVal * 1e4) / 1e4
|
|
3527
|
+
};
|
|
3528
|
+
checks.push(entry);
|
|
3529
|
+
return entry;
|
|
3530
|
+
};
|
|
3531
|
+
checks.push({
|
|
3532
|
+
label: `webkit_detection: isWebKitWebGPU=${this.ctx.isWebKitWebGPU} needsMultiEncoder=${this.executor.needsMultiEncoder}`,
|
|
3533
|
+
length: 0,
|
|
3534
|
+
sum: 0,
|
|
3535
|
+
first4: [
|
|
3536
|
+
0,
|
|
3537
|
+
0,
|
|
3538
|
+
0,
|
|
3539
|
+
0
|
|
3540
|
+
],
|
|
3541
|
+
argmax: 0,
|
|
3542
|
+
maxVal: 0,
|
|
3543
|
+
note: this.ctx.adapterDescription
|
|
3544
|
+
});
|
|
3545
|
+
for (const name of ["norm.weight", "layers.0.input_layernorm.weight"]) try {
|
|
3546
|
+
log(name, await this.debugReadBuffer(name));
|
|
3547
|
+
} catch {
|
|
3548
|
+
checks.push({
|
|
3549
|
+
label: name,
|
|
3550
|
+
length: 0,
|
|
3551
|
+
sum: NaN,
|
|
3552
|
+
first4: [
|
|
3553
|
+
NaN,
|
|
3554
|
+
NaN,
|
|
3555
|
+
NaN,
|
|
3556
|
+
NaN
|
|
3557
|
+
],
|
|
3558
|
+
argmax: -1,
|
|
3559
|
+
maxVal: NaN,
|
|
3560
|
+
error: "buffer not found"
|
|
3561
|
+
});
|
|
3562
|
+
}
|
|
3563
|
+
for (const qName of [
|
|
3564
|
+
"layers.0.mlp.gate_proj.weight.q",
|
|
3565
|
+
"embed_tokens.weight.q",
|
|
3566
|
+
"embed_tokens.weight.scales",
|
|
3567
|
+
"embed_tokens.weight.zeros"
|
|
3568
|
+
]) try {
|
|
3569
|
+
const q = await this.debugReadBuffer(qName, 16);
|
|
3570
|
+
log(`${qName} (reinterpret)`, q);
|
|
3571
|
+
} catch {
|
|
3572
|
+
checks.push({
|
|
3573
|
+
label: qName,
|
|
3574
|
+
length: 0,
|
|
3575
|
+
sum: NaN,
|
|
3576
|
+
first4: [
|
|
3577
|
+
NaN,
|
|
3578
|
+
NaN,
|
|
3579
|
+
NaN,
|
|
3580
|
+
NaN
|
|
3581
|
+
],
|
|
3582
|
+
argmax: -1,
|
|
3583
|
+
maxVal: NaN,
|
|
3584
|
+
error: "not found"
|
|
3585
|
+
});
|
|
3586
|
+
}
|
|
3587
|
+
this.executor.reset();
|
|
3588
|
+
try {
|
|
3589
|
+
const result = await this.executor.debugFirstDispatch(new Uint32Array([1]));
|
|
3590
|
+
const entry = log(`single_dispatch(${result.opType})`, result.output);
|
|
3591
|
+
entry.note = `dispatch=${result.dispatchSize.join(",")} node=${result.nodeId}`;
|
|
3592
|
+
} catch (e) {
|
|
3593
|
+
checks.push({
|
|
3594
|
+
label: "single_dispatch",
|
|
3595
|
+
length: 0,
|
|
3596
|
+
sum: NaN,
|
|
3597
|
+
first4: [
|
|
3598
|
+
NaN,
|
|
3599
|
+
NaN,
|
|
3600
|
+
NaN,
|
|
3601
|
+
NaN
|
|
3602
|
+
],
|
|
3603
|
+
argmax: -1,
|
|
3604
|
+
maxVal: NaN,
|
|
3605
|
+
error: e.message
|
|
3606
|
+
});
|
|
3607
|
+
}
|
|
3608
|
+
try {
|
|
3609
|
+
const result = await this.executor.debugDispatchEntry(1, 1);
|
|
3610
|
+
const entry = log(`isolated_entry1(${result.opType})`, result.output);
|
|
3611
|
+
entry.note = `node=${result.nodeId}`;
|
|
3612
|
+
} catch (e) {
|
|
3613
|
+
checks.push({
|
|
3614
|
+
label: "isolated_entry1",
|
|
3615
|
+
length: 0,
|
|
3616
|
+
sum: NaN,
|
|
3617
|
+
first4: [
|
|
3618
|
+
NaN,
|
|
3619
|
+
NaN,
|
|
3620
|
+
NaN,
|
|
3621
|
+
NaN
|
|
3622
|
+
],
|
|
3623
|
+
argmax: -1,
|
|
3624
|
+
maxVal: NaN,
|
|
3625
|
+
error: e.message
|
|
3626
|
+
});
|
|
3627
|
+
}
|
|
3628
|
+
try {
|
|
3629
|
+
const result = await this.executor.debugDispatchEntry(2, 1);
|
|
3630
|
+
const entry = log(`isolated_entry2(${result.opType})`, result.output);
|
|
3631
|
+
entry.note = `node=${result.nodeId}`;
|
|
3632
|
+
} catch (e) {
|
|
3633
|
+
checks.push({
|
|
3634
|
+
label: "isolated_entry2",
|
|
3635
|
+
length: 0,
|
|
3636
|
+
sum: NaN,
|
|
3637
|
+
first4: [
|
|
3638
|
+
NaN,
|
|
3639
|
+
NaN,
|
|
3640
|
+
NaN,
|
|
3641
|
+
NaN
|
|
3642
|
+
],
|
|
3643
|
+
argmax: -1,
|
|
3644
|
+
maxVal: NaN,
|
|
3645
|
+
error: e.message
|
|
3646
|
+
});
|
|
3647
|
+
}
|
|
3648
|
+
this.executor.reset();
|
|
3649
|
+
const jsParams = this.executor.debugComputeParams(1, 5);
|
|
3650
|
+
for (const p of jsParams) checks.push({
|
|
3651
|
+
label: `jsParams[${p.idx}] ${p.opType} dispatch=[${p.dispatchSize}] u32=[${p.paramsU32}]`,
|
|
3652
|
+
length: p.paramsU32.length,
|
|
3653
|
+
sum: p.paramsU32.reduce((a, b) => a + b, 0),
|
|
3654
|
+
first4: p.paramsU32.slice(0, 4).map(Number),
|
|
3655
|
+
argmax: 0,
|
|
3656
|
+
maxVal: 0
|
|
3657
|
+
});
|
|
3658
|
+
this.executor.reset();
|
|
3659
|
+
try {
|
|
3660
|
+
const { logits } = await this.executor.forward(new Uint32Array([1]));
|
|
3661
|
+
log("logits(token=1)", logits);
|
|
3662
|
+
try {
|
|
3663
|
+
log("embed_out", await this.debugReadBuffer("embed_out", 16));
|
|
3664
|
+
} catch {}
|
|
3665
|
+
try {
|
|
3666
|
+
const probes = await this.executor.debugPipelineProbe(1);
|
|
3667
|
+
for (const p of probes) {
|
|
3668
|
+
const paramsStr = p.uniformParams ? ` params=[${p.uniformParams.join(",")}]` : "";
|
|
3669
|
+
const entry = {
|
|
3670
|
+
label: `probe[${p.idx}] ${p.opType} → ${p.tensor}${paramsStr}`,
|
|
3671
|
+
length: 16,
|
|
3672
|
+
sum: p.sum,
|
|
3673
|
+
first4: p.first4,
|
|
3674
|
+
argmax: 0,
|
|
3675
|
+
maxVal: Math.max(...p.first4.map(Math.abs))
|
|
3676
|
+
};
|
|
3677
|
+
if (p.sum === 0 && p.first4.every((v) => v === 0)) {
|
|
3678
|
+
entry.match = "FAIL";
|
|
3679
|
+
entry.note = "all zeros — activation dead";
|
|
3680
|
+
}
|
|
3681
|
+
checks.push(entry);
|
|
3682
|
+
}
|
|
3683
|
+
} catch (e) {
|
|
3684
|
+
checks.push({
|
|
3685
|
+
label: "pipeline_probe",
|
|
3686
|
+
length: 0,
|
|
3687
|
+
sum: NaN,
|
|
3688
|
+
first4: [
|
|
3689
|
+
NaN,
|
|
3690
|
+
NaN,
|
|
3691
|
+
NaN,
|
|
3692
|
+
NaN
|
|
3693
|
+
],
|
|
3694
|
+
argmax: -1,
|
|
3695
|
+
maxVal: NaN,
|
|
3696
|
+
error: e.message
|
|
3697
|
+
});
|
|
3698
|
+
}
|
|
3699
|
+
} catch (e) {
|
|
3700
|
+
checks.push({
|
|
3701
|
+
label: "logits(token=1)",
|
|
3702
|
+
length: 0,
|
|
3703
|
+
sum: NaN,
|
|
3704
|
+
first4: [
|
|
3705
|
+
NaN,
|
|
3706
|
+
NaN,
|
|
3707
|
+
NaN,
|
|
3708
|
+
NaN
|
|
3709
|
+
],
|
|
3710
|
+
argmax: -1,
|
|
3711
|
+
maxVal: NaN,
|
|
3712
|
+
error: e.message
|
|
3713
|
+
});
|
|
3714
|
+
}
|
|
3715
|
+
this.executor.reset();
|
|
3716
|
+
const reference = {
|
|
3717
|
+
"norm.weight": {
|
|
3718
|
+
sum: 4412.571289,
|
|
3719
|
+
first4: [
|
|
3720
|
+
.222656,
|
|
3721
|
+
3.75,
|
|
3722
|
+
3.640625,
|
|
3723
|
+
4.0625
|
|
3724
|
+
]
|
|
3725
|
+
},
|
|
3726
|
+
"layers.0.input_layernorm.weight": {
|
|
3727
|
+
sum: 1267.312683,
|
|
3728
|
+
first4: [
|
|
3729
|
+
2,
|
|
3730
|
+
1.294922,
|
|
3731
|
+
1.535156,
|
|
3732
|
+
1.746094
|
|
3733
|
+
]
|
|
3734
|
+
},
|
|
3735
|
+
embed_out: {
|
|
3736
|
+
sum: -.067839,
|
|
3737
|
+
first4: [
|
|
3738
|
+
.010059,
|
|
3739
|
+
.0264,
|
|
3740
|
+
.001888,
|
|
3741
|
+
-.030794
|
|
3742
|
+
]
|
|
3743
|
+
}
|
|
3744
|
+
};
|
|
3745
|
+
for (const check of checks) {
|
|
3746
|
+
const ref = reference[check.label];
|
|
3747
|
+
if (ref) {
|
|
3748
|
+
const sumMatch = Math.abs(check.sum - ref.sum) < Math.max(Math.abs(ref.sum) * .001, 1);
|
|
3749
|
+
const first4Match = check.first4.every((v, i) => Math.abs(v - ref.first4[i]) < .01);
|
|
3750
|
+
check.match = sumMatch && first4Match ? "PASS" : "FAIL";
|
|
3751
|
+
check.refSum = ref.sum;
|
|
3752
|
+
} else if (check.label === "logits(token=1)") {
|
|
3753
|
+
const hasNaN = check.first4.some((v) => Number.isNaN(v));
|
|
3754
|
+
const hasInf = check.first4.some((v) => !Number.isFinite(v));
|
|
3755
|
+
const allSame = check.first4.every((v) => v === check.first4[0]);
|
|
3756
|
+
const argmaxValid = check.argmax >= 0 && check.argmax < check.length;
|
|
3757
|
+
check.match = !hasNaN && !hasInf && !allSame && argmaxValid ? "PASS" : "FAIL";
|
|
3758
|
+
}
|
|
3759
|
+
}
|
|
3760
|
+
console.log("\n=== INTEGRITY CHECK ===");
|
|
3761
|
+
for (const c of checks) {
|
|
3762
|
+
const status = c.error ? "ERR" : c.match ?? "---";
|
|
3763
|
+
console.log(`[${status}] ${c.label}: sum=${c.sum}, first4=[${c.first4.join(", ")}], argmax=${c.argmax}, len=${c.length}` + (c.refSum !== void 0 ? ` (ref sum=${c.refSum})` : "") + (c.error ? ` ERROR: ${c.error}` : "") + (c.note ? ` | ${c.note}` : ""));
|
|
3764
|
+
}
|
|
3765
|
+
console.log("=== END INTEGRITY CHECK ===\n");
|
|
3766
|
+
return {
|
|
3767
|
+
checks,
|
|
3768
|
+
allPass: checks.every((c) => c.match === "PASS" || c.match === void 0)
|
|
3769
|
+
};
|
|
3770
|
+
}
|
|
3771
|
+
/**
|
|
3772
|
+
* Destroy the engine and free all GPU resources.
|
|
3773
|
+
*/
|
|
3774
|
+
destroy() {
|
|
3775
|
+
if (this._destroyed) return;
|
|
3776
|
+
this._destroyed = true;
|
|
3777
|
+
this.executor.destroy();
|
|
3778
|
+
this._kaniTTS?.destroy();
|
|
3779
|
+
this.visionExecutor?.destroy();
|
|
3780
|
+
clearPipelineCache(this.ctx.device);
|
|
3781
|
+
this.ctx.device.destroy();
|
|
3782
|
+
}
|
|
3783
|
+
checkDestroyed() {
|
|
3784
|
+
if (this._destroyed) throw new Error("WebGPUEngine has been destroyed");
|
|
3785
|
+
}
|
|
3786
|
+
};
|
|
3787
|
+
|
|
3788
|
+
//#endregion
|
|
3789
|
+
export { generateGemma4VisionGraph as C, dequantizeMLXProjection as S, resolveGemma4VisionInfo as T, smartResize as _, buildGemma4PosEmbeds as a, generateQwen3_5VisionGraph as b, buildMRoPECosSin as c, buildPositionIds as d, buildRotaryCosSin as f, preprocessImageGemma4 as g, preprocessImage as h, buildGemma4PoolMatrix as i, buildMRoPEPositionIds as l, mropeFreqDims as m, GEMMA4_IMAGE_PROCESSOR as n, buildGemma4RotaryCosSin as o, buildVisionPositionTensors as p, QWEN3_5_IMAGE_PROCESSOR as r, buildGemma4VisionPositionTensors as s, WebGPUEngine as t, buildPosEmbeds as u, VisionExecutor as v, patchGemma4VisionClips as w, dequantizeGemma4VisionProjection as x, KaniTTS as y };
|
|
3790
|
+
//# sourceMappingURL=gpu-DFuglcEx.mjs.map
|