@tryhamster/gerbil 1.0.0-rc.9 → 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. package/LICENSE +1 -1
  2. package/README.md +247 -84
  3. package/dist/architectures-C1I5V3Dt.mjs +6070 -0
  4. package/dist/architectures-C1I5V3Dt.mjs.map +1 -0
  5. package/dist/browser/index.d.ts +264 -588
  6. package/dist/browser/index.d.ts.map +1 -1
  7. package/dist/browser/index.js +585 -2334
  8. package/dist/browser/index.js.map +1 -1
  9. package/dist/cli.mjs +625 -1098
  10. package/dist/cli.mjs.map +1 -1
  11. package/dist/defaults-9komdrbY.mjs +24 -0
  12. package/dist/defaults-9komdrbY.mjs.map +1 -0
  13. package/dist/frameworks/express.d.mts +1 -3
  14. package/dist/frameworks/express.d.mts.map +1 -1
  15. package/dist/frameworks/express.mjs +7 -7
  16. package/dist/frameworks/express.mjs.map +1 -1
  17. package/dist/frameworks/fastify.d.mts +1 -1
  18. package/dist/frameworks/fastify.d.mts.map +1 -1
  19. package/dist/frameworks/fastify.mjs +3 -3
  20. package/dist/frameworks/fastify.mjs.map +1 -1
  21. package/dist/frameworks/hono.d.mts +1 -1
  22. package/dist/frameworks/hono.d.mts.map +1 -1
  23. package/dist/frameworks/hono.mjs +4 -4
  24. package/dist/frameworks/hono.mjs.map +1 -1
  25. package/dist/frameworks/next.d.mts +3 -2
  26. package/dist/frameworks/next.d.mts.map +1 -1
  27. package/dist/frameworks/next.mjs +4 -4
  28. package/dist/frameworks/next.mjs.map +1 -1
  29. package/dist/frameworks/react.d.mts +1 -1
  30. package/dist/frameworks/trpc.d.mts +1 -1
  31. package/dist/frameworks/trpc.d.mts.map +1 -1
  32. package/dist/frameworks/trpc.mjs +4 -4
  33. package/dist/frameworks/trpc.mjs.map +1 -1
  34. package/dist/gerbil-BHrJJIa4.mjs +1656 -0
  35. package/dist/gerbil-BHrJJIa4.mjs.map +1 -0
  36. package/dist/gerbil-BT9fCydo.d.mts +488 -0
  37. package/dist/gerbil-BT9fCydo.d.mts.map +1 -0
  38. package/dist/gerbil-DomNfIr1.mjs +4 -0
  39. package/dist/gpu/hooks.d.mts +520 -0
  40. package/dist/gpu/hooks.d.mts.map +1 -0
  41. package/dist/gpu/hooks.mjs +1188 -0
  42. package/dist/gpu/hooks.mjs.map +1 -0
  43. package/dist/gpu/index.d.mts +2 -0
  44. package/dist/gpu/index.mjs +6 -0
  45. package/dist/gpu-33qCAtHW.mjs +3615 -0
  46. package/dist/gpu-33qCAtHW.mjs.map +1 -0
  47. package/dist/index-Dgmb2kE3.d.mts +245 -0
  48. package/dist/index-Dgmb2kE3.d.mts.map +1 -0
  49. package/dist/index-jEAL2s-A.d.mts +2022 -0
  50. package/dist/index-jEAL2s-A.d.mts.map +1 -0
  51. package/dist/index.d.mts +22 -487
  52. package/dist/index.d.mts.map +1 -1
  53. package/dist/index.mjs +13 -8
  54. package/dist/index.mjs.map +1 -1
  55. package/dist/indexeddb-store-BWIMtxxH.mjs +103 -0
  56. package/dist/indexeddb-store-BWIMtxxH.mjs.map +1 -0
  57. package/dist/indexeddb-store-ClH12Xnl.mjs +4 -0
  58. package/dist/integrations/ai-sdk.d.mts +75 -6
  59. package/dist/integrations/ai-sdk.d.mts.map +1 -1
  60. package/dist/integrations/ai-sdk.mjs +131 -15
  61. package/dist/integrations/ai-sdk.mjs.map +1 -1
  62. package/dist/integrations/langchain.d.mts +1 -1
  63. package/dist/integrations/langchain.d.mts.map +1 -1
  64. package/dist/integrations/langchain.mjs +5 -5
  65. package/dist/integrations/langchain.mjs.map +1 -1
  66. package/dist/integrations/llamaindex.d.mts +1 -1
  67. package/dist/integrations/llamaindex.d.mts.map +1 -1
  68. package/dist/integrations/llamaindex.mjs +5 -5
  69. package/dist/integrations/llamaindex.mjs.map +1 -1
  70. package/dist/integrations/mcp-client.mjs +3 -3
  71. package/dist/integrations/mcp-client.mjs.map +1 -1
  72. package/dist/integrations/mcp.d.mts +3 -2
  73. package/dist/integrations/mcp.d.mts.map +1 -1
  74. package/dist/integrations/mcp.mjs +5 -5
  75. package/dist/{mcp-BvbriaBy.mjs → mcp-1DaMsaBc.mjs} +4 -4
  76. package/dist/mcp-1DaMsaBc.mjs.map +1 -0
  77. package/dist/memory/index.d.mts +3 -0
  78. package/dist/memory/index.mjs +6 -0
  79. package/dist/memory-D1P7Tmda.mjs +4 -0
  80. package/dist/memory-DVN0MnIG.mjs +132 -0
  81. package/dist/memory-DVN0MnIG.mjs.map +1 -0
  82. package/dist/memory-Dj0J1v88.mjs +294 -0
  83. package/dist/memory-Dj0J1v88.mjs.map +1 -0
  84. package/dist/moonshine-stt-BLyVoRpB.mjs +4 -0
  85. package/dist/moonshine-stt-v_P_Ci_m.mjs +11936 -0
  86. package/dist/moonshine-stt-v_P_Ci_m.mjs.map +1 -0
  87. package/dist/{one-liner-s-lD8rCC.mjs → one-liner-DnQn7HJK.mjs} +14 -16
  88. package/dist/one-liner-DnQn7HJK.mjs.map +1 -0
  89. package/dist/repl-jV5gcJFA.mjs +9 -0
  90. package/dist/skills/index.d.mts +270 -320
  91. package/dist/skills/index.d.mts.map +1 -1
  92. package/dist/skills/index.mjs +5 -5
  93. package/dist/{skills-CD3Orlex.mjs → skills-DX8D59UH.mjs} +187 -32
  94. package/dist/skills-DX8D59UH.mjs.map +1 -0
  95. package/dist/{tools-Bi1P7Xoy.mjs → tools-DQ1mPUw5.mjs} +34 -22
  96. package/dist/tools-DQ1mPUw5.mjs.map +1 -0
  97. package/dist/{types-CiTc7ez3.d.mts → types-D6FiR_oh.d.mts} +106 -12
  98. package/dist/types-D6FiR_oh.d.mts.map +1 -0
  99. package/dist/types-DQBe2lFo.d.mts +165 -0
  100. package/dist/types-DQBe2lFo.d.mts.map +1 -0
  101. package/dist/{utils-CZBZ8dgR.mjs → utils-DKO55ZmZ.mjs} +1 -1
  102. package/dist/{utils-CZBZ8dgR.mjs.map → utils-DKO55ZmZ.mjs.map} +1 -1
  103. package/dist/vector-B0panuy6.mjs +95 -0
  104. package/dist/vector-B0panuy6.mjs.map +1 -0
  105. package/docs/PROJECT-STATE.md +321 -0
  106. package/docs/adding-a-model-family.md +280 -0
  107. package/docs/ai-sdk.md +70 -61
  108. package/docs/architecture/overview.md +17 -7
  109. package/docs/browser.md +203 -8
  110. package/docs/embeddings.md +156 -0
  111. package/docs/gerbil-site-native-migration.md +217 -0
  112. package/docs/gpu-engine/architectures.md +398 -0
  113. package/docs/gpu-engine/ir.md +372 -0
  114. package/docs/gpu-engine/kernels.md +718 -0
  115. package/docs/gpu-engine/paper.html +1759 -0
  116. package/docs/gpu-engine/paper.md +2109 -0
  117. package/docs/gpu-engine/safetensors.md +312 -0
  118. package/docs/gpu-engine/tokenizer.md +302 -0
  119. package/docs/memory-rag.md +91 -0
  120. package/docs/metal-safari-intel.md +190 -0
  121. package/docs/mobile-failure-diagnosis.md +124 -0
  122. package/docs/mobile.md +99 -0
  123. package/docs/observability.md +230 -0
  124. package/docs/onnx-removal-plan.md +339 -0
  125. package/docs/research/autoresearch-portable.md +904 -0
  126. package/docs/research/dispatch-reduction-hivemind.md +84 -0
  127. package/docs/research/ios-safari-model-caching.md +117 -0
  128. package/docs/research/mobile-webgpu-speed-fusion.md +135 -0
  129. package/docs/research/native-stt-model-selection.md +49 -0
  130. package/docs/research/native-tts-model-selection.md +90 -0
  131. package/docs/research/native-vs-chromium-decision.md +152 -0
  132. package/docs/research/nemotron-mamba2-inference.md +910 -0
  133. package/docs/research/qwen35-multimodal.md +293 -0
  134. package/docs/research/qwen36-gemma4-targets.md +337 -0
  135. package/docs/research/sota-embedding-models.md +179 -0
  136. package/docs/research/sota-mobile-models-2026.md +263 -0
  137. package/docs/research/sota-modality-models.md +202 -0
  138. package/docs/research/tps-baselines.md +71 -0
  139. package/docs/research/webgpu-m4-reference.md +104 -0
  140. package/docs/site-update-plan.md +155 -0
  141. package/docs/structured-output.md +123 -0
  142. package/docs/stt.md +63 -446
  143. package/docs/tts.md +77 -499
  144. package/docs/vision.md +100 -338
  145. package/package.json +22 -7
  146. package/dist/chrome-backend-CORwaIyC.mjs +0 -1212
  147. package/dist/chrome-backend-CORwaIyC.mjs.map +0 -1
  148. package/dist/chrome-backend-DIKYoWj-.mjs +0 -3
  149. package/dist/gerbil-CJ3ifloF.mjs +0 -4
  150. package/dist/gerbil-Dw4Qj77e.mjs +0 -1631
  151. package/dist/gerbil-Dw4Qj77e.mjs.map +0 -1
  152. package/dist/gerbil-qOTe1nl2.d.mts +0 -431
  153. package/dist/gerbil-qOTe1nl2.d.mts.map +0 -1
  154. package/dist/kokoro-BNTb6egA.mjs +0 -20210
  155. package/dist/kokoro-BNTb6egA.mjs.map +0 -1
  156. package/dist/kokoro-CMOGDSgT.js +0 -20212
  157. package/dist/kokoro-CMOGDSgT.js.map +0 -1
  158. package/dist/mcp-BvbriaBy.mjs.map +0 -1
  159. package/dist/one-liner-s-lD8rCC.mjs.map +0 -1
  160. package/dist/repl-DveXw36T.mjs +0 -9
  161. package/dist/skills-CD3Orlex.mjs.map +0 -1
  162. package/dist/stt-Bu-E23Sc.js +0 -433
  163. package/dist/stt-Bu-E23Sc.js.map +0 -1
  164. package/dist/stt-CpLYbGFd.mjs +0 -433
  165. package/dist/stt-CpLYbGFd.mjs.map +0 -1
  166. package/dist/stt-DRPLEEHB.mjs +0 -3
  167. package/dist/tools-Bi1P7Xoy.mjs.map +0 -1
  168. package/dist/transformers.web-DiD1gTwk.js +0 -44695
  169. package/dist/transformers.web-DiD1gTwk.js.map +0 -1
  170. package/dist/transformers.web-u34VxRFM.js +0 -3
  171. package/dist/tts-CqroPaSK.js +0 -724
  172. package/dist/tts-CqroPaSK.js.map +0 -1
  173. package/dist/tts-DXgsKGCe.mjs +0 -3
  174. package/dist/tts-DeGANMNV.mjs +0 -730
  175. package/dist/tts-DeGANMNV.mjs.map +0 -1
  176. package/dist/types-CiTc7ez3.d.mts.map +0 -1
  177. /package/dist/{auto-update-S9s5-g0C.mjs → auto-update-BVaLXcDE.mjs} +0 -0
  178. /package/dist/{chunk-CkXuGtQK.mjs → chunk-B9cbKln6.mjs} +0 -0
  179. /package/dist/{microphone-DaMZFRuR.mjs → microphone-Bqmoz9_K.mjs} +0 -0
@@ -0,0 +1,3615 @@
1
+ import { n as resolveDefaultRepo } from "./defaults-9komdrbY.mjs";
2
+ import { C as parseKaniConfig, E as GEMMA4_VIS_KEYS, T as DTYPE_BYTES, _ as generateNanoCodecDecoderGraph, b as kaniLayerAlpha, d as KANI_START_OF_HUMAN, f as audioTokensToCodes, g as generateKaniTtsGraph, m as computeKaniPositions, p as buildKaniLayerCosSin, u as KANI_END_OF_HUMAN, v as kaniAttentionLayerIndices, w as CANONICAL_KEYS, x as kaniSinTensor, y as kaniCosTensor } from "./architectures-C1I5V3Dt.mjs";
3
+ import { c as MATMUL_BIAS_F16C_SPEC, d as createStorageBuffer, f as createUniformBuffer, g as verifyGPU, h as initGPU, i as loadModel, l as clearPipelineCache, m as getOrCreatePipeline, o as Executor, p as destroyBuffers, r as loadKaniTTS, s as KERNEL_REGISTRY, u as createBindGroup } from "./moonshine-stt-v_P_Ci_m.mjs";
4
+
5
+ //#region src/gpu/architectures/gemma4_vision.ts
6
/**
 * Derive the Gemma 4 vision-tower dimensions from a raw HF-style config object.
 *
 * `rawConfig` may be the top-level model config (the vision fields are then
 * read from `rawConfig.vision_config` and the text width from
 * `rawConfig.text_config.hidden_size`) or a bare vision config (no
 * `vision_config` property). When no text width is available, `textHidden`
 * falls back to the vision hidden size.
 *
 * Defaults applied here: `head_dim` → floor(hidden_size / num_attention_heads),
 * `num_channels` → 3, `pooling_kernel_size` → 1, `rope_theta` → 100,
 * `rms_norm_eps` → 1e-6. All other fields are passed through unvalidated, so a
 * missing required field surfaces as `undefined`/`NaN` in the result.
 *
 * @param {object} rawConfig - Top-level config or bare vision config.
 * @returns {{hiddenSize, numHeads, headDim, depth, intermediateSize, textHidden,
 *   patchSize, patchDim, poolingKernelSize, ropeTheta, rmsNormEps}}
 */
function resolveGemma4VisionInfo(rawConfig) {
  const vision = rawConfig.vision_config ?? rawConfig;
  const text = rawConfig.text_config ?? {};
  const {
    hidden_size: hiddenSize,
    num_attention_heads: numHeads,
    num_hidden_layers: depth,
    intermediate_size: intermediateSize,
    patch_size: patchSize
  } = vision;
  const channels = vision.num_channels ?? 3;
  const ropeParams = vision.rope_parameters ?? {};
  return {
    hiddenSize,
    numHeads,
    // An explicit head_dim wins; otherwise split the hidden width evenly.
    headDim: vision.head_dim ?? Math.floor(hiddenSize / numHeads),
    depth,
    intermediateSize,
    textHidden: text.hidden_size ?? hiddenSize,
    patchSize,
    // One flattened square patch across all channels.
    patchDim: channels * patchSize * patchSize,
    poolingKernelSize: vision.pooling_kernel_size ?? 1,
    ropeTheta: ropeParams.rope_theta ?? 100,
    rmsNormEps: vision.rms_norm_eps ?? 1e-6
  };
}
39
/**
 * Expand an affine-int4 packed weight into a dense row-major f32 matrix.
 *
 * Layout read by this routine: every 32-bit word of `packed` holds eight 4-bit
 * codes, lowest nibble first, so column `c` of a row lives in word `c >> 3` at
 * bit offset `(c & 7) * 4`. Each run of `groupSize` consecutive columns in a row
 * shares one scale and one bias:
 *   out[r, c] = scales[r, floor(c / groupSize)] * code + biases[r, floor(c / groupSize)]
 *
 * `cols` is expected to be a multiple of both 8 and `groupSize`; nothing here
 * checks that.
 *
 * @param {Uint32Array} packed - Packed codes, `cols / 8` words per row.
 * @param {Float32Array} scales - Per-group scales, `cols / groupSize` per row.
 * @param {Float32Array} biases - Per-group biases, same layout as `scales`.
 * @param {number} rows - Number of output rows.
 * @param {number} cols - Number of output columns.
 * @param {number} groupSize - Columns per quantization group.
 * @returns {Float32Array} Dense matrix of length `rows * cols`.
 */
function dequantizeMLXProjection(packed, scales, biases, rows, cols, groupSize) {
  const dense = new Float32Array(rows * cols);
  const groupsPerRow = cols / groupSize;
  const wordsPerRow = cols / 8;
  for (let row = 0; row < rows; row += 1) {
    const wordBase = row * wordsPerRow;
    const groupBase = row * groupsPerRow;
    const outBase = row * cols;
    for (let col = 0; col < cols; col += 1) {
      const word = packed[wordBase + (col >> 3)];
      const shift = (col & 7) * 4;
      const code = (word >>> shift) & 15;
      const groupIndex = groupBase + Math.floor(col / groupSize);
      dense[outBase + col] = scales[groupIndex] * code + biases[groupIndex];
    }
  }
  return dense;
}
61
/**
 * Replace a packed-int4 multimodal projector with a dense f32 weight.
 *
 * When the weights map holds the projector (`GEMMA4_VIS_KEYS.projW`) as a
 * `Uint32Array` together with `embed_vision.embedding_projection.scales` and
 * `...biases`, the triplet is dequantized through `dequantizeMLXProjection`,
 * the projector entry is overwritten with `{ data: Float32Array, shape: [rows, cols] }`
 * and the scales/biases entries are removed. In every other case (projector
 * missing, not a `Uint32Array`, or either companion tensor absent) the map is
 * left untouched.
 *
 * Scales/biases that are not already `Float32Array` are reinterpreted as f32
 * over the same underlying bytes (no numeric conversion) — this presumes the
 * loader stored f32 bytes there; TODO confirm against the loader.
 *
 * @param {Map<string, {data: ArrayBufferView, shape?: number[]}>} weights - Mutated in place.
 * @param {number} groupSize - Columns per quantization group.
 * @param {number} rows - Projector output rows.
 * @param {number} cols - Projector input columns.
 * @returns {void}
 */
function dequantizeGemma4VisionProjection(weights, groupSize, rows, cols) {
  const SCALES_KEY = "embed_vision.embedding_projection.scales";
  const BIASES_KEY = "embed_vision.embedding_projection.biases";
  const weightKey = GEMMA4_VIS_KEYS.projW;

  const packedEntry = weights.get(weightKey);
  const scalesEntry = weights.get(SCALES_KEY);
  const biasesEntry = weights.get(BIASES_KEY);

  const hasPackedWeight = Boolean(packedEntry) && packedEntry.data instanceof Uint32Array;
  if (!hasPackedWeight || !scalesEntry || !biasesEntry) {
    return;
  }

  // View the raw bytes as f32 without copying when the data is another typed array.
  const viewAsF32 = (data) =>
    data instanceof Float32Array
      ? data
      : new Float32Array(data.buffer, data.byteOffset, data.byteLength / 4);

  const scaleValues = viewAsF32(scalesEntry.data);
  const biasValues = viewAsF32(biasesEntry.data);
  const dense = dequantizeMLXProjection(packedEntry.data, scaleValues, biasValues, rows, cols, groupSize);

  weights.set(weightKey, {
    data: dense,
    shape: [rows, cols]
  });
  weights.delete(SCALES_KEY);
  weights.delete(BIASES_KEY);
}
84
/**
 * Guarantee that both multimodal-embedder norm gains exist in the weights map.
 *
 * For each of the two keys below, an entry is inserted only when the map does
 * not already contain it; existing entries are never overwritten:
 *   - `GEMMA4_VIS_KEYS.postProjNormW` → all-ones gain of length `textHidden`
 *   - `GEMMA4_VIS_KEYS.softEmbNormW`  → all-ones gain of length `visionHidden`
 *
 * Rationale carried over from the original note (not verifiable from this
 * file): the post-projection norm is scale-free upstream so checkpoints ship no
 * tensor for it, and some lean exports omit the learned soft-embedding gain.
 * A ones gain is an identity scale provided the RMSNorm kernel multiplies by
 * the stored gain directly (no "+1" baked in) — TODO confirm against the kernel.
 *
 * Intended to run before the vision executor uploads weights.
 *
 * @param {Map<string, {data: Float32Array, shape: number[]}>} weights - Mutated in place.
 * @param {number} visionHidden - Length of the soft-embedding norm gain.
 * @param {number} textHidden - Length of the post-projection norm gain.
 * @returns {void}
 */
function ensureGemma4VisionEmbedderNorms(weights, visionHidden, textHidden) {
  const fallbackGains = [
    [GEMMA4_VIS_KEYS.postProjNormW, textHidden],
    [GEMMA4_VIS_KEYS.softEmbNormW, visionHidden]
  ];
  for (const [key, width] of fallbackGains) {
    if (weights.has(key)) {
      continue;
    }
    weights.set(key, {
      data: new Float32Array(width).fill(1),
      shape: [width]
    });
  }
}
107
/**
 * Copy calibrated clip scalars from the weights map onto the `ClippedMatMul`
 * nodes of a graph, then remove those scalar tensors from the map.
 *
 * For every node whose `opType` is `"ClippedMatMul"`, the attributes
 * `clip_input_min_key`, `clip_input_max_key`, `clip_output_min_key` and
 * `clip_output_max_key` name entries in `weights`. The first f32 element of
 * each entry that exists is written to `imin` / `imax` / `omin` / `omax`
 * respectively. A missing (or empty) entry leaves the attribute untouched —
 * per the original note the op then behaves as an unclipped matmul, which is
 * decided by the consumer of these attributes, not here.
 *
 * Every referenced key is deleted from `weights` only after all nodes have been
 * patched, so two nodes that point at the same scalar both receive it
 * regardless of their order in `graph.nodes`. The final contents of `weights`
 * are the same as deleting eagerly.
 *
 * Entries whose `data` is not a `Float32Array` are reinterpreted as f32 over the
 * same bytes (no numeric conversion).
 *
 * Intended to run before the vision executor uploads weights, so the scalars
 * are not uploaded as tensors.
 *
 * @param {{nodes: Array<{opType: string, attributes: object}>}} graph - Nodes are mutated in place.
 * @param {Map<string, {data: ArrayBufferView}>} weights - Clip scalar entries are removed.
 * @returns {void}
 */
function patchGemma4VisionClips(graph, weights) {
  // [attribute naming the weights key, attribute that receives the scalar]
  const clipFields = [
    ["clip_input_min_key", "imin"],
    ["clip_input_max_key", "imax"],
    ["clip_output_min_key", "omin"],
    ["clip_output_max_key", "omax"]
  ];
  const readScalar = (key) => {
    const entry = weights.get(key);
    if (!entry) return undefined;
    const { data } = entry;
    const values =
      data instanceof Float32Array
        ? data
        : new Float32Array(data.buffer, data.byteOffset, data.byteLength / 4);
    return values.length > 0 ? values[0] : undefined;
  };

  const consumedKeys = new Set();
  for (const node of graph.nodes) {
    if (node.opType !== "ClippedMatMul") continue;
    const attrs = node.attributes;
    for (const [keyAttr, targetAttr] of clipFields) {
      const key = attrs[keyAttr];
      if (!key) continue;
      const value = readScalar(key);
      if (value !== undefined) attrs[targetAttr] = value;
      consumedKeys.add(key);
    }
  }
  // Deferred so a scalar shared by several nodes is still readable for all of them.
  for (const key of consumedKeys) weights.delete(key);
}
142
/**
 * Build the graph description for the Gemma 4 vision tower.
 *
 * The graph uses two symbolic leading dimensions: "N" (number of patches) and
 * "Np" (number of pooled tokens). Both are left symbolic in the tensor shapes
 * emitted here and are expected to be resolved from the runtime inputs.
 *
 * Node sequence (also the execution order, which is insertion order):
 *   1. patch projection MatMul + Add of position embeddings → `g4v_h0`
 *   2. `depth` encoder blocks, each: input RMSNorm → clipped q/k/v projections →
 *      per-head q/k RMSNorm → rotary on q and k → non-causal Attention →
 *      clipped o projection → post-attention RMSNorm → residual Add →
 *      pre-FF RMSNorm → clipped gate/up projections → GELU(gate) * up →
 *      clipped down projection → post-FF RMSNorm → residual Add
 *   3. PoolMatMul with the runtime pooling matrix `g4v_pool_w` → soft-embedding
 *      RMSNorm → projector MatMul to the text width → post-projection RMSNorm
 *      → `g4v_image_embeds`
 *
 * Weight tensors are registered as f32 constants keyed by `GEMMA4_VIS_KEYS`;
 * `ClippedMatMul` nodes only carry the *names* of their clip scalars
 * (`<projBase>.{input_min,input_max,output_min,output_max}`) — the values are
 * expected to be patched in by a later step.
 *
 * `config.context_length` is taken from `max_position_embeddings` of the vision
 * config, using the same "nested `vision_config` or bare vision config"
 * convention as `resolveGemma4VisionInfo`, and is 0 when absent.
 *
 * @param {object} rawConfig - Top-level model config or bare vision config.
 * @returns {object} Graph with `architecture`, `config`, `capabilities`,
 *   `tensors`, `nodes`, `executionOrder`, `inputs` and `outputs`.
 */
function generateGemma4VisionGraph(rawConfig) {
  const info = resolveGemma4VisionInfo(rawConfig);
  const {
    hiddenSize: hidden_size,
    numHeads: num_heads,
    headDim: head_dim,
    depth,
    intermediateSize: intermediate_size,
    textHidden: text_hidden,
    patchDim: patch_dim,
    rmsNormEps: eps
  } = info;
  // A bare vision config has no `vision_config` wrapper; read from whichever applies.
  const visionConfig = rawConfig.vision_config ?? rawConfig;
  const qkv_dim = num_heads * head_dim;
  const tensors = {};
  const nodes = [];
  const executionOrder = [];

  /** Register a tensor descriptor under its name (a repeated name replaces the earlier one). */
  function addTensor(desc) {
    tensors[desc.name] = desc;
  }
  /** Append a node; its id is appended to the execution order at the same time. */
  function addNode(node) {
    nodes.push(node);
    executionOrder.push(node.id);
  }
  /** Declare an f32 weight tensor; the checkpoint key defaults to the tensor name. */
  function constant(name, shape, safetensorsKey = name) {
    addTensor({
      name,
      shape,
      dtype: "f32",
      storage: "constant",
      safetensorsKey
    });
  }
  /** Declare an f32 activation tensor. */
  function activation(name, shape) {
    addTensor({
      name,
      shape,
      dtype: "f32",
      storage: "activation"
    });
  }
  /** MatMul (no bias): out[M,N] = in[M,K] @ W[N,K]^T. M comes from `mTensor`. */
  function matmul(id, input, weight, output, K, N, mTensor) {
    addNode({
      id,
      opType: "MatMul",
      inputs: [input, weight],
      outputs: [output],
      attributes: {
        M_tensor: mTensor,
        K,
        N
      }
    });
  }
  /**
   * Clipped linear: out = clamp(clamp(x, imin, imax) @ W^T, omin, omax).
   * Only the key names of the four clip scalars are recorded here, derived from
   * `projBase` (e.g. `vision_tower.encoder.layers.0.self_attn.q_proj`).
   */
  function clippedMatmul(id, input, weight, output, K, N, mTensor, projBase) {
    addNode({
      id,
      opType: "ClippedMatMul",
      inputs: [input, weight],
      outputs: [output],
      attributes: {
        M_tensor: mTensor,
        K,
        N,
        clip_input_min_key: `${projBase}.input_min`,
        clip_input_max_key: `${projBase}.input_max`,
        clip_output_min_key: `${projBase}.output_min`,
        clip_output_max_key: `${projBase}.output_max`
      }
    });
  }
  /** RMSNorm over the trailing `hidden` elements, using the graph-wide eps. */
  function rmsnorm(id, input, w, output, hidden, seqTensor) {
    addNode({
      id,
      opType: "RMSNorm",
      inputs: [input, w],
      outputs: [output],
      attributes: {
        hidden_size: hidden,
        eps,
        seq_len_tensor: seqTensor
      }
    });
  }
  /** Elementwise Add of two [rows, hidden_size] tensors; row count follows `a`. */
  function residualAdd(id, a, b, output) {
    addNode({
      id,
      opType: "Add",
      inputs: [a, b],
      outputs: [output],
      attributes: {
        count_tensor: a,
        hidden_size
      }
    });
  }
  /** Rotary embedding using the shared runtime cos/sin tables. */
  function rotary(id, input, output) {
    addNode({
      id,
      opType: "ApplyRotaryEmb",
      inputs: [input, "g4v_cos", "g4v_sin"],
      outputs: [output],
      attributes: {
        num_heads,
        head_dim
      }
    });
  }

  // --- Runtime inputs and patch embedding ---------------------------------
  activation("g4v_patches", ["N", patch_dim]);
  activation("g4v_pos_embeds", ["N", hidden_size]);
  activation("g4v_cos", ["N", head_dim]);
  activation("g4v_sin", ["N", head_dim]);
  activation("g4v_pool_w", ["Np", "N"]);
  constant(GEMMA4_VIS_KEYS.patchEmbedProjW, [hidden_size, patch_dim]);
  activation("g4v_patch_mm", ["N", hidden_size]);
  activation("g4v_h0", ["N", hidden_size]);
  matmul("g4v_patch_proj", "g4v_patches", GEMMA4_VIS_KEYS.patchEmbedProjW, "g4v_patch_mm", patch_dim, hidden_size, "g4v_patches");
  residualAdd("g4v_add_pos", "g4v_patch_mm", "g4v_pos_embeds", "g4v_h0");

  // --- Encoder blocks -------------------------------------------------------
  let prev = "g4v_h0";
  for (let i = 0; i < depth; i++) {
    const p = `g4v_b${i}`;
    const saBase = `vision_tower.encoder.layers.${i}.self_attn`;
    const mlpBase = `vision_tower.encoder.layers.${i}.mlp`;

    // Attention sub-block.
    constant(GEMMA4_VIS_KEYS.inputNorm(i), [hidden_size]);
    activation(`${p}_n1`, ["N", hidden_size]);
    rmsnorm(`${p}_input_norm`, prev, GEMMA4_VIS_KEYS.inputNorm(i), `${p}_n1`, hidden_size, prev);
    constant(GEMMA4_VIS_KEYS.qProjW(i), [qkv_dim, hidden_size]);
    constant(GEMMA4_VIS_KEYS.kProjW(i), [qkv_dim, hidden_size]);
    constant(GEMMA4_VIS_KEYS.vProjW(i), [qkv_dim, hidden_size]);
    activation(`${p}_q`, ["N", qkv_dim]);
    activation(`${p}_k`, ["N", qkv_dim]);
    activation(`${p}_v`, ["N", qkv_dim]);
    clippedMatmul(`${p}_q_proj`, `${p}_n1`, GEMMA4_VIS_KEYS.qProjW(i), `${p}_q`, hidden_size, qkv_dim, `${p}_n1`, `${saBase}.q_proj`);
    clippedMatmul(`${p}_k_proj`, `${p}_n1`, GEMMA4_VIS_KEYS.kProjW(i), `${p}_k`, hidden_size, qkv_dim, `${p}_n1`, `${saBase}.k_proj`);
    clippedMatmul(`${p}_v_proj`, `${p}_n1`, GEMMA4_VIS_KEYS.vProjW(i), `${p}_v`, hidden_size, qkv_dim, `${p}_n1`, `${saBase}.v_proj`);
    // q/k norms are applied with hidden = head_dim, i.e. per head rather than per row.
    constant(GEMMA4_VIS_KEYS.qNormW(i), [head_dim]);
    constant(GEMMA4_VIS_KEYS.kNormW(i), [head_dim]);
    activation(`${p}_qn`, ["N", qkv_dim]);
    activation(`${p}_kn`, ["N", qkv_dim]);
    rmsnorm(`${p}_q_norm`, `${p}_q`, GEMMA4_VIS_KEYS.qNormW(i), `${p}_qn`, head_dim, `${p}_q`);
    rmsnorm(`${p}_k_norm`, `${p}_k`, GEMMA4_VIS_KEYS.kNormW(i), `${p}_kn`, head_dim, `${p}_k`);
    activation(`${p}_qr`, ["N", qkv_dim]);
    activation(`${p}_kr`, ["N", qkv_dim]);
    rotary(`${p}_rope_q`, `${p}_qn`, `${p}_qr`);
    rotary(`${p}_rope_k`, `${p}_kn`, `${p}_kr`);
    activation(`${p}_attn`, ["N", qkv_dim]);
    addNode({
      id: `${p}_attention`,
      opType: "Attention",
      inputs: [`${p}_qr`, `${p}_kr`, `${p}_v`],
      outputs: [`${p}_attn`],
      attributes: {
        num_q_heads: num_heads,
        num_kv_heads: num_heads,
        head_dim,
        causal: false
      }
    });
    constant(GEMMA4_VIS_KEYS.oProjW(i), [hidden_size, qkv_dim]);
    activation(`${p}_o`, ["N", hidden_size]);
    clippedMatmul(`${p}_o_proj`, `${p}_attn`, GEMMA4_VIS_KEYS.oProjW(i), `${p}_o`, qkv_dim, hidden_size, `${p}_attn`, `${saBase}.o_proj`);
    constant(GEMMA4_VIS_KEYS.postAttnNorm(i), [hidden_size]);
    activation(`${p}_post_attn`, ["N", hidden_size]);
    rmsnorm(`${p}_post_attn_norm`, `${p}_o`, GEMMA4_VIS_KEYS.postAttnNorm(i), `${p}_post_attn`, hidden_size, `${p}_o`);
    activation(`${p}_res1`, ["N", hidden_size]);
    residualAdd(`${p}_residual1`, prev, `${p}_post_attn`, `${p}_res1`);

    // Feed-forward sub-block: down(GELU(gate(x)) * up(x)).
    constant(GEMMA4_VIS_KEYS.preFfNorm(i), [hidden_size]);
    activation(`${p}_n2`, ["N", hidden_size]);
    rmsnorm(`${p}_pre_ff_norm`, `${p}_res1`, GEMMA4_VIS_KEYS.preFfNorm(i), `${p}_n2`, hidden_size, `${p}_res1`);
    constant(GEMMA4_VIS_KEYS.gateProjW(i), [intermediate_size, hidden_size]);
    constant(GEMMA4_VIS_KEYS.upProjW(i), [intermediate_size, hidden_size]);
    constant(GEMMA4_VIS_KEYS.downProjW(i), [hidden_size, intermediate_size]);
    activation(`${p}_gate`, ["N", intermediate_size]);
    activation(`${p}_up`, ["N", intermediate_size]);
    activation(`${p}_gelu`, ["N", intermediate_size]);
    activation(`${p}_geglu`, ["N", intermediate_size]);
    activation(`${p}_mlp`, ["N", hidden_size]);
    clippedMatmul(`${p}_gate_proj`, `${p}_n2`, GEMMA4_VIS_KEYS.gateProjW(i), `${p}_gate`, hidden_size, intermediate_size, `${p}_n2`, `${mlpBase}.gate_proj`);
    clippedMatmul(`${p}_up_proj`, `${p}_n2`, GEMMA4_VIS_KEYS.upProjW(i), `${p}_up`, hidden_size, intermediate_size, `${p}_n2`, `${mlpBase}.up_proj`);
    addNode({
      id: `${p}_gelu_op`,
      opType: "GELU",
      inputs: [`${p}_gate`],
      outputs: [`${p}_gelu`],
      attributes: { count_tensor: `${p}_gate` }
    });
    addNode({
      id: `${p}_geglu_mul`,
      opType: "Mul",
      inputs: [`${p}_gelu`, `${p}_up`],
      outputs: [`${p}_geglu`],
      attributes: {
        count_tensor: `${p}_gelu`,
        hidden_size: intermediate_size
      }
    });
    clippedMatmul(`${p}_down_proj`, `${p}_geglu`, GEMMA4_VIS_KEYS.downProjW(i), `${p}_mlp`, intermediate_size, hidden_size, `${p}_geglu`, `${mlpBase}.down_proj`);
    constant(GEMMA4_VIS_KEYS.postFfNorm(i), [hidden_size]);
    activation(`${p}_post_ff`, ["N", hidden_size]);
    // Row count is taken from `_res1`, which has the same ["N", hidden_size] shape as `_mlp`.
    rmsnorm(`${p}_post_ff_norm`, `${p}_mlp`, GEMMA4_VIS_KEYS.postFfNorm(i), `${p}_post_ff`, hidden_size, `${p}_res1`);
    activation(`${p}_res2`, ["N", hidden_size]);
    residualAdd(`${p}_residual2`, `${p}_res1`, `${p}_post_ff`, `${p}_res2`);
    prev = `${p}_res2`;
  }

  // --- Pooling and multimodal projector -------------------------------------
  activation("g4v_pooled", ["Np", hidden_size]);
  addNode({
    id: "g4v_pool",
    opType: "PoolMatMul",
    inputs: ["g4v_pool_w", prev],
    outputs: ["g4v_pooled"],
    attributes: {
      K_tensor: prev,
      hidden_size,
      np_tensor: "g4v_pool_w"
    }
  });
  constant(GEMMA4_VIS_KEYS.softEmbNormW, [hidden_size]);
  activation("g4v_soft_normed", ["Np", hidden_size]);
  rmsnorm("g4v_soft_emb_norm", "g4v_pooled", GEMMA4_VIS_KEYS.softEmbNormW, "g4v_soft_normed", hidden_size, "g4v_pooled");
  constant(GEMMA4_VIS_KEYS.projW, [text_hidden, hidden_size]);
  activation("g4v_proj_out", ["Np", text_hidden]);
  // M is read from the output tensor here; it shares the leading "Np" dim with the input.
  matmul("g4v_proj", "g4v_soft_normed", GEMMA4_VIS_KEYS.projW, "g4v_proj_out", hidden_size, text_hidden, "g4v_proj_out");
  constant(GEMMA4_VIS_KEYS.postProjNormW, [text_hidden]);
  activation("g4v_image_embeds", ["Np", text_hidden]);
  rmsnorm("g4v_post_proj_norm", "g4v_proj_out", GEMMA4_VIS_KEYS.postProjNormW, "g4v_image_embeds", text_hidden, "g4v_proj_out");

  return {
    architecture: "Gemma4VisionModel",
    config: {
      hidden_size,
      num_layers: depth,
      num_heads,
      num_kv_heads: num_heads,
      head_dim,
      intermediate_size,
      vocab_size: 0,
      context_length: visionConfig.max_position_embeddings ?? 0,
      rms_norm_eps: eps,
      norm_type: "rmsnorm",
      rope_base: info.ropeTheta,
      rope_dim: head_dim,
      kv_layout: "LHSd",
      is_moe: false,
      has_vision_tower: true,
      vision_architecture: "gemma4_vision",
      vision_patch_size: info.patchSize,
      vision_embed_dim: hidden_size
    },
    capabilities: {
      text: true,
      vision: true,
      moe: false
    },
    tensors,
    nodes,
    executionOrder,
    inputs: [
      "g4v_patches",
      "g4v_pos_embeds",
      "g4v_cos",
      "g4v_sin",
      "g4v_pool_w"
    ],
    outputs: ["g4v_image_embeds"]
  };
}
439
+
440
+ //#endregion
441
+ //#region src/gpu/architectures/qwen3_5_vision.ts
442
+ const VIS_NORM_EPS = 1e-6;
443
+ /**
444
+ * Build the ViT graph. The graph is shaped by symbolic "N" (number of patches),
445
+ * resolved at run time from the input tensor's first dim — exactly like the LM's
446
+ * symbolic "T".
447
+ */
448
+ function generateQwen3_5VisionGraph(rawConfig) {
449
+ const vcfg = rawConfig.vision_config ?? rawConfig;
450
+ const hidden_size = vcfg.hidden_size;
451
+ const num_heads = vcfg.num_heads;
452
+ const head_dim = Math.floor(hidden_size / num_heads);
453
+ const depth = vcfg.depth;
454
+ const intermediate_size = vcfg.intermediate_size;
455
+ const out_hidden_size = vcfg.out_hidden_size;
456
+ const spatial_merge_size = vcfg.spatial_merge_size;
457
+ const in_channels = vcfg.in_channels;
458
+ const temporal_patch_size = vcfg.temporal_patch_size;
459
+ const patch_size = vcfg.patch_size;
460
+ const patch_dim = in_channels * temporal_patch_size * patch_size * patch_size;
461
+ const merged_in = hidden_size * (spatial_merge_size * spatial_merge_size);
462
+ const tensors = {};
463
+ const nodes = [];
464
+ const executionOrder = [];
465
+ function addTensor(desc) {
466
+ tensors[desc.name] = desc;
467
+ }
468
+ function addNode(node) {
469
+ nodes.push(node);
470
+ executionOrder.push(node.id);
471
+ }
472
+ function constant(name, shape) {
473
+ addTensor({
474
+ name,
475
+ shape,
476
+ dtype: "f32",
477
+ storage: "constant",
478
+ safetensorsKey: name
479
+ });
480
+ }
481
+ function activation(name, shape) {
482
+ addTensor({
483
+ name,
484
+ shape,
485
+ dtype: "f32",
486
+ storage: "activation"
487
+ });
488
+ }
489
+ /**
490
+ * Fused MatMul + row-broadcast bias → out = in @ W^T + bias. Replaces a MatMul
491
+ * followed by a separate AddBias, removing a full read+write of the matmul
492
+ * output per ViT linear layer.
493
+ */
494
+ function matmulBias(id, input, weight, bias, output, K, N, mTensor) {
495
+ addNode({
496
+ id,
497
+ opType: "MatMulBias",
498
+ inputs: [
499
+ input,
500
+ weight,
501
+ bias
502
+ ],
503
+ outputs: [output],
504
+ attributes: {
505
+ M_tensor: mTensor,
506
+ K,
507
+ N
508
+ }
509
+ });
510
+ }
511
+ function layernorm(id, input, w, b, output, hidden, seqTensor) {
512
+ addNode({
513
+ id,
514
+ opType: "LayerNorm",
515
+ inputs: [
516
+ input,
517
+ w,
518
+ b
519
+ ],
520
+ outputs: [output],
521
+ attributes: {
522
+ hidden_size: hidden,
523
+ eps: VIS_NORM_EPS,
524
+ seq_len_tensor: seqTensor
525
+ }
526
+ });
527
+ }
528
+ activation("vis_patches", ["N", patch_dim]);
529
+ activation("vis_pos_embeds", ["N", hidden_size]);
530
+ activation("vis_cos", ["N", head_dim]);
531
+ activation("vis_sin", ["N", head_dim]);
532
+ constant(CANONICAL_KEYS.visPatchEmbedWeight, [hidden_size, patch_dim]);
533
+ constant(CANONICAL_KEYS.visPatchEmbedBias, [hidden_size]);
534
+ activation("vis_patch_mm", ["N", hidden_size]);
535
+ activation("vis_patch_bias", ["N", hidden_size]);
536
+ activation("vis_h0", ["N", hidden_size]);
537
+ matmulBias("vis_patch_mm", "vis_patches", CANONICAL_KEYS.visPatchEmbedWeight, CANONICAL_KEYS.visPatchEmbedBias, "vis_patch_bias", patch_dim, hidden_size, "vis_patches");
538
+ addNode({
539
+ id: "vis_add_pos",
540
+ opType: "Add",
541
+ inputs: ["vis_patch_bias", "vis_pos_embeds"],
542
+ outputs: ["vis_h0"],
543
+ attributes: {
544
+ count_tensor: "vis_patch_bias",
545
+ hidden_size
546
+ }
547
+ });
548
+ let prev = "vis_h0";
549
+ const qkv_dim = hidden_size * 3;
550
+ for (let i = 0; i < depth; i++) {
551
+ const p = `vis_b${i}`;
552
+ constant(CANONICAL_KEYS.visBlockNorm1W(i), [hidden_size]);
553
+ constant(CANONICAL_KEYS.visBlockNorm1B(i), [hidden_size]);
554
+ activation(`${p}_n1`, ["N", hidden_size]);
555
+ layernorm(`${p}_norm1`, prev, CANONICAL_KEYS.visBlockNorm1W(i), CANONICAL_KEYS.visBlockNorm1B(i), `${p}_n1`, hidden_size, prev);
556
+ constant(CANONICAL_KEYS.visBlockQkvW(i), [qkv_dim, hidden_size]);
557
+ constant(CANONICAL_KEYS.visBlockQkvB(i), [qkv_dim]);
558
+ activation(`${p}_qkv_mm`, ["N", qkv_dim]);
559
+ activation(`${p}_qkv`, ["N", qkv_dim]);
560
+ matmulBias(`${p}_qkv_proj`, `${p}_n1`, CANONICAL_KEYS.visBlockQkvW(i), CANONICAL_KEYS.visBlockQkvB(i), `${p}_qkv`, hidden_size, qkv_dim, `${p}_n1`);
561
+ activation(`${p}_q`, ["N", hidden_size]);
562
+ activation(`${p}_k`, ["N", hidden_size]);
563
+ activation(`${p}_v`, ["N", hidden_size]);
564
+ const sliceCols = (id, output, offset) => addNode({
565
+ id,
566
+ opType: "SliceCols",
567
+ inputs: [`${p}_qkv`],
568
+ outputs: [output],
569
+ attributes: {
570
+ in_width: qkv_dim,
571
+ out_width: hidden_size,
572
+ col_offset: offset,
573
+ seq_len_tensor: `${p}_qkv`
574
+ }
575
+ });
576
+ sliceCols(`${p}_split_q`, `${p}_q`, 0);
577
+ sliceCols(`${p}_split_k`, `${p}_k`, hidden_size);
578
+ sliceCols(`${p}_split_v`, `${p}_v`, 2 * hidden_size);
579
+ activation(`${p}_qr`, ["N", hidden_size]);
580
+ activation(`${p}_kr`, ["N", hidden_size]);
581
+ addNode({
582
+ id: `${p}_rope_q`,
583
+ opType: "ApplyRotaryEmb",
584
+ inputs: [
585
+ `${p}_q`,
586
+ "vis_cos",
587
+ "vis_sin"
588
+ ],
589
+ outputs: [`${p}_qr`],
590
+ attributes: {
591
+ num_heads,
592
+ head_dim
593
+ }
594
+ });
595
+ addNode({
596
+ id: `${p}_rope_k`,
597
+ opType: "ApplyRotaryEmb",
598
+ inputs: [
599
+ `${p}_k`,
600
+ "vis_cos",
601
+ "vis_sin"
602
+ ],
603
+ outputs: [`${p}_kr`],
604
+ attributes: {
605
+ num_heads,
606
+ head_dim
607
+ }
608
+ });
609
+ activation(`${p}_attn`, ["N", hidden_size]);
610
+ addNode({
611
+ id: `${p}_attention`,
612
+ opType: "Attention",
613
+ inputs: [
614
+ `${p}_qr`,
615
+ `${p}_kr`,
616
+ `${p}_v`
617
+ ],
618
+ outputs: [`${p}_attn`],
619
+ attributes: {
620
+ num_q_heads: num_heads,
621
+ num_kv_heads: num_heads,
622
+ head_dim,
623
+ causal: false
624
+ }
625
+ });
626
+ constant(CANONICAL_KEYS.visBlockProjW(i), [hidden_size, hidden_size]);
627
+ constant(CANONICAL_KEYS.visBlockProjB(i), [hidden_size]);
628
+ activation(`${p}_proj_mm`, ["N", hidden_size]);
629
+ activation(`${p}_proj`, ["N", hidden_size]);
630
+ matmulBias(`${p}_proj_op`, `${p}_attn`, CANONICAL_KEYS.visBlockProjW(i), CANONICAL_KEYS.visBlockProjB(i), `${p}_proj`, hidden_size, hidden_size, `${p}_attn`);
631
+ activation(`${p}_res1`, ["N", hidden_size]);
632
+ addNode({
633
+ id: `${p}_residual1`,
634
+ opType: "Add",
635
+ inputs: [prev, `${p}_proj`],
636
+ outputs: [`${p}_res1`],
637
+ attributes: {
638
+ count_tensor: prev,
639
+ hidden_size
640
+ }
641
+ });
642
+ constant(CANONICAL_KEYS.visBlockNorm2W(i), [hidden_size]);
643
+ constant(CANONICAL_KEYS.visBlockNorm2B(i), [hidden_size]);
644
+ constant(CANONICAL_KEYS.visBlockFc1W(i), [intermediate_size, hidden_size]);
645
+ constant(CANONICAL_KEYS.visBlockFc1B(i), [intermediate_size]);
646
+ constant(CANONICAL_KEYS.visBlockFc2W(i), [hidden_size, intermediate_size]);
647
+ constant(CANONICAL_KEYS.visBlockFc2B(i), [hidden_size]);
648
+ activation(`${p}_n2`, ["N", hidden_size]);
649
+ activation(`${p}_fc1_mm`, ["N", intermediate_size]);
650
+ activation(`${p}_fc1`, ["N", intermediate_size]);
651
+ activation(`${p}_gelu`, ["N", intermediate_size]);
652
+ activation(`${p}_fc2_mm`, ["N", hidden_size]);
653
+ activation(`${p}_fc2`, ["N", hidden_size]);
654
+ activation(`${p}_res2`, ["N", hidden_size]);
655
+ layernorm(`${p}_norm2`, `${p}_res1`, CANONICAL_KEYS.visBlockNorm2W(i), CANONICAL_KEYS.visBlockNorm2B(i), `${p}_n2`, hidden_size, `${p}_res1`);
656
+ matmulBias(`${p}_fc1_op`, `${p}_n2`, CANONICAL_KEYS.visBlockFc1W(i), CANONICAL_KEYS.visBlockFc1B(i), `${p}_fc1`, hidden_size, intermediate_size, `${p}_n2`);
657
+ addNode({
658
+ id: `${p}_gelu_op`,
659
+ opType: "GELU",
660
+ inputs: [`${p}_fc1`],
661
+ outputs: [`${p}_gelu`],
662
+ attributes: { count_tensor: `${p}_fc1` }
663
+ });
664
+ matmulBias(`${p}_fc2_op`, `${p}_gelu`, CANONICAL_KEYS.visBlockFc2W(i), CANONICAL_KEYS.visBlockFc2B(i), `${p}_fc2`, intermediate_size, hidden_size, `${p}_gelu`);
665
+ addNode({
666
+ id: `${p}_residual2`,
667
+ opType: "Add",
668
+ inputs: [`${p}_res1`, `${p}_fc2`],
669
+ outputs: [`${p}_res2`],
670
+ attributes: {
671
+ count_tensor: `${p}_res1`,
672
+ hidden_size
673
+ }
674
+ });
675
+ prev = `${p}_res2`;
676
+ }
677
+ constant(CANONICAL_KEYS.visMergerNormW, [hidden_size]);
678
+ constant(CANONICAL_KEYS.visMergerNormB, [hidden_size]);
679
+ constant(CANONICAL_KEYS.visMergerFc1W, [merged_in, merged_in]);
680
+ constant(CANONICAL_KEYS.visMergerFc1B, [merged_in]);
681
+ constant(CANONICAL_KEYS.visMergerFc2W, [out_hidden_size, merged_in]);
682
+ constant(CANONICAL_KEYS.visMergerFc2B, [out_hidden_size]);
683
+ activation("vis_merge_norm", ["N", hidden_size]);
684
+ activation("vis_merge_fc1_mm", ["Nm", merged_in]);
685
+ activation("vis_merge_fc1", ["Nm", merged_in]);
686
+ activation("vis_merge_gelu", ["Nm", merged_in]);
687
+ activation("vis_merge_fc2_mm", ["Nm", out_hidden_size]);
688
+ activation("vis_image_embeds", ["Nm", out_hidden_size]);
689
+ layernorm("vis_merger_norm", prev, CANONICAL_KEYS.visMergerNormW, CANONICAL_KEYS.visMergerNormB, "vis_merge_norm", hidden_size, prev);
690
+ matmulBias("vis_merger_fc1", "vis_merge_norm", CANONICAL_KEYS.visMergerFc1W, CANONICAL_KEYS.visMergerFc1B, "vis_merge_fc1", merged_in, merged_in, "vis_merge_fc1");
691
+ addNode({
692
+ id: "vis_merger_gelu",
693
+ opType: "GeluErf",
694
+ inputs: ["vis_merge_fc1"],
695
+ outputs: ["vis_merge_gelu"],
696
+ attributes: { count_tensor: "vis_merge_fc1" }
697
+ });
698
+ matmulBias("vis_merger_fc2", "vis_merge_gelu", CANONICAL_KEYS.visMergerFc2W, CANONICAL_KEYS.visMergerFc2B, "vis_image_embeds", merged_in, out_hidden_size, "vis_image_embeds");
699
+ return {
700
+ architecture: "Qwen3_5VisionModel",
701
+ config: {
702
+ hidden_size,
703
+ num_layers: depth,
704
+ num_heads,
705
+ num_kv_heads: num_heads,
706
+ head_dim,
707
+ intermediate_size,
708
+ vocab_size: 0,
709
+ context_length: vcfg.num_position_embeddings,
710
+ rms_norm_eps: VIS_NORM_EPS,
711
+ norm_type: "layernorm",
712
+ rope_base: 1e4,
713
+ rope_dim: head_dim,
714
+ kv_layout: "LHSd",
715
+ is_moe: false,
716
+ has_vision_tower: true,
717
+ vision_architecture: "qwen3_5_vit",
718
+ vision_patch_size: patch_size,
719
+ vision_embed_dim: hidden_size
720
+ },
721
+ capabilities: {
722
+ text: true,
723
+ vision: true,
724
+ moe: false
725
+ },
726
+ tensors,
727
+ nodes,
728
+ executionOrder,
729
+ inputs: [
730
+ "vis_patches",
731
+ "vis_pos_embeds",
732
+ "vis_cos",
733
+ "vis_sin"
734
+ ],
735
+ outputs: ["vis_image_embeds"]
736
+ };
737
+ }
738
+
739
+ //#endregion
740
+ //#region src/gpu/kani-tts.ts
741
+ /**
742
+ * KaniTTS — native text-to-speech engine for Gerbil's WebGPU backend.
743
+ *
744
+ * Kani-TTS-2 (nineninesix/kani-tts-2-en) is a two-stage TTS model that, like
745
+ * Moonshine, needs more than one graph:
746
+ *
747
+ * 1. CODEC-LM BACKBONE (LFM2-350M body): autoregressively emits NanoCodec audio
748
+ * tokens (4 per frame) into the same vocab as text. Reuses LFM2's block math
749
+ * with two KaniTTS2 deltas — frame-level position IDs (the 4 audio tokens of a
750
+ * frame share a position) and learnable per-layer RoPE (α^(l)-scaled freqs) —
751
+ * both folded host-side into per-layer cos/sin tables fed to the MRoPE op.
752
+ * 2. NANOCODEC DECODER (NVIDIA NeMo 22 kHz): FSQ dequant + causal HiFi-GAN conv
753
+ * decoder → 22 kHz PCM. Validated bit-exact vs MLX (test-nanocodec-decode.mjs).
754
+ *
755
+ * The AR loop runs on the host with full-logit readback so each frame's 4 audio
756
+ * tokens are sampled per-codebook (constrained to the valid codebook window), then
757
+ * the collected codes are decoded once through the NanoCodec graph.
758
+ *
759
+ * Validated on Dawn (desktop) via scripts/engine/test-kani-speak.mjs.
760
+ */
761
+ const KANI_SAMPLE_RATE = 22050;
762
+ const KANI_HOP = 1764;
763
+ const KANI_END_OF_TEXT = 2;
764
+ /** NanoCodec decode chunking (keeps per-dispatch under WebGPU's 65535 cap). */
765
+ const KANI_CODEC_CHUNK_FRAMES = 32;
766
+ const KANI_CODEC_LOOKBACK_FRAMES = 32;
767
+ /** Hard cap so a stuck decode cannot loop forever (reference uses 3000). */
768
+ const DEFAULT_MAX_NEW_TOKENS = 3e3;
769
/**
 * Collect the weights for every tensor the graph marks as `storage: "constant"`.
 *
 * Each constant is looked up in `weights` by its graph tensor name first and,
 * when that misses, by the descriptor's `safetensorsKey` (if it has one).
 * Constants with no (truthy) weight under either key are left out.
 *
 * @param graph   Graph object exposing a `tensors` name → descriptor record.
 * @param weights Map-like weight store (only `.get` is used).
 * @returns A new Map of graph tensor name → weight.
 */
function selectGraphWeights(graph, weights) {
  const selected = /* @__PURE__ */ new Map();
  Object.keys(graph.tensors).forEach((tensorName) => {
    const descriptor = graph.tensors[tensorName];
    if (descriptor.storage === "constant") {
      let weight = weights.get(tensorName);
      const missing = weight === null || weight === undefined;
      // Fall back to the checkpoint key only when the graph name is absent.
      if (missing) weight = descriptor.safetensorsKey ? weights.get(descriptor.safetensorsKey) : undefined;
      if (weight) selected.set(tensorName, weight);
    }
  });
  return selected;
}
779
/**
 * Top-p (nucleus) sampling over an already-normalized probability array.
 *
 * Ranks the entries by descending probability, keeps the shortest ranked prefix
 * whose accumulated mass reaches `topP` (or everything if it never does), and
 * draws one entry from that prefix in proportion to its probability.
 *
 * @param probs Probabilities, parallel to `ids`.
 * @param ids   Token ids, parallel to `probs`.
 * @param topP  Cumulative-mass threshold for the kept prefix.
 * @returns The sampled id; the top-ranked id if the draw falls through.
 */
function nucleusSample(probs, ids, topP) {
  const count = probs.length;
  const ranked = Array.from({ length: count }, (_, i) => i).sort((a, b) => probs[b] - probs[a]);
  // Grow the kept prefix one entry at a time; `mass` is always the sum of the
  // kept entries, so it doubles as the renormalization constant below.
  let mass = 0;
  let kept = 0;
  while (kept < count) {
    mass += probs[ranked[kept]];
    kept += 1;
    if (mass >= topP) break;
  }
  // Scaling the draw by the kept mass is equivalent to renormalizing the prefix.
  let remaining = Math.random() * mass;
  for (let rank = 0; rank < kept; rank++) {
    const entry = ranked[rank];
    remaining -= probs[entry];
    if (remaining <= 0) return ids[entry];
  }
  return ids[ranked[0]];
}
805
/**
 * Overwrite `scores` with its softmax.
 *
 * The largest score is subtracted before exponentiating so the exponentials
 * cannot overflow. Nothing is returned; the array is mutated.
 *
 * @param scores Array or typed array of raw scores.
 */
function softmaxInPlace(scores) {
  const len = scores.length;
  let peak = Number.NEGATIVE_INFINITY;
  for (let k = 0; k < len; k++) {
    const value = scores[k];
    if (value > peak) peak = value;
  }
  let total = 0;
  for (let k = 0; k < len; k++) {
    // Accumulate the unrounded exponential, not the value read back from the
    // (possibly lower-precision) array.
    const shifted = Math.exp(scores[k] - peak);
    scores[k] = shifted;
    total += shifted;
  }
  for (let k = 0; k < len; k++) scores[k] /= total;
}
817
/**
 * Text-to-speech engine: a codec-LM backbone executor that emits audio tokens
 * autoregressively, plus a NanoCodec decoder graph that turns the collected
 * codes into PCM. Build instances with `KaniTTS.create()`.
 */
var KaniTTS = class KaniTTS {
  ctx;
  loaded;
  tokenizer;
  cfg;
  rawConfig;
  maxSeqLen;
  /** Backbone executor (built once; reused across speak() calls). */
  backboneExec;
  /** The attention layer indices (carry learnable α) and their α values. */
  attnLayers;
  layerAlpha;
  headDim;
  ropeBase;
  _destroyed = false;
  architecture = "KaniTTS2ForCausalLM";
  /**
   * @param ctx       GPU context handed to every Executor.
   * @param loaded    Loaded model bundle (tokenizer, rawConfig, backboneWeights, codecWeights).
   * @param maxSeqLen Sequence capacity of the backbone executor.
   */
  constructor(ctx, loaded, maxSeqLen) {
    this.ctx = ctx;
    this.loaded = loaded;
    this.tokenizer = loaded.tokenizer;
    this.rawConfig = loaded.rawConfig;
    this.cfg = parseKaniConfig(loaded.rawConfig);
    this.maxSeqLen = maxSeqLen;
    const hidden = this.rawConfig.hidden_size;
    const heads = this.rawConfig.num_attention_heads;
    this.headDim = this.rawConfig.head_dim ?? Math.floor(hidden / heads);
    this.ropeBase = this.rawConfig.rope_theta ?? 1e6;
    this.attnLayers = kaniAttentionLayerIndices(this.rawConfig);
    this.layerAlpha = /* @__PURE__ */ new Map();
    const useLearnable = this.rawConfig.use_learnable_rope ?? false;
    for (const layer of this.attnLayers) {
      // α stays 1 unless the config enables learnable RoPE; a missing weight
      // is treated as a raw value of 0 before kaniLayerAlpha maps it.
      let alpha = 1;
      if (useLearnable) {
        const w = loaded.backboneWeights.get(`learnable_rope_layers.${layer}.alpha_weight`);
        alpha = kaniLayerAlpha(w ? w.data[0] ?? 0 : 0, this.cfg);
      }
      this.layerAlpha.set(layer, alpha);
    }
    const graph = generateKaniTtsGraph(this.rawConfig);
    this.backboneExec = new Executor(ctx, graph, {
      maxSeqLen,
      kvMode: "f32"
    });
    this.backboneExec.uploadWeightsMap(selectGraphWeights(graph, loaded.backboneWeights));
    this.backboneExec.initBindGroups();
  }
  /**
   * Async factory: initializes the GPU, loads the model, and constructs the engine.
   * The sequence capacity is `options.maxSeqLen` (default 2048) clamped to the
   * model's `max_position_embeddings` (default 4096).
   */
  static async create(options = {}) {
    const ctx = await initGPU();
    const loaded = await loadKaniTTS({
      repo: options.repo,
      codecRepo: options.codecRepo,
      revision: options.revision,
      hfToken: options.hfToken,
      cacheDir: options.cacheDir,
      onProgress: options.onProgress
    });
    const ctxLen = loaded.rawConfig.max_position_embeddings ?? 4096;
    return new KaniTTS(ctx, loaded, Math.min(options.maxSeqLen ?? 2048, ctxLen));
  }
  /** Write per-layer cos/sin for token rows [rowStart, rowStart+positions.length). */
  writeCosSin(positions, rowStart) {
    const rowBytes = this.headDim * 4;
    for (const layer of this.attnLayers) {
      const alpha = this.layerAlpha.get(layer) ?? 1;
      const { cos, sin } = buildKaniLayerCosSin(positions, this.headDim, this.ropeBase, alpha);
      if (rowStart === 0) {
        this.backboneExec.writeInput(kaniCosTensor(layer), cos);
        this.backboneExec.writeInput(kaniSinTensor(layer), sin);
      } else {
        // Later rows land at a byte offset of rowStart rows into the table.
        this.backboneExec.writeInputAt(kaniCosTensor(layer), cos, rowStart * rowBytes);
        this.backboneExec.writeInputAt(kaniSinTensor(layer), sin, rowStart * rowBytes);
      }
    }
  }
  /**
   * Synthesize speech for `text`. Returns 22 kHz mono PCM.
   *
   * Pipeline: build the [SOH]+text+[EOT,EOH] prompt → prefill the codec-LM →
   * AR-decode 4-token frames (per-codebook constrained sampling) until end_of_speech
   * → strip markers → codes → NanoCodec decode → PCM.
   *
   * @param text Text to speak.
   * @param opts Optional `temperature` (1), `topP` (0.95), `repetitionPenalty` (1.1),
   *             `languageTag` ("en_us"), `maxNewTokens`, `maxFrames`.
   * @returns `{ pcm, sampleRate, frames, audioSeconds }`.
   * @throws Error if the engine was destroyed, the prompt does not fit in
   *         `maxSeqLen`, or no complete audio frame was produced.
   */
  async speak(text, opts = {}) {
    if (this._destroyed) throw new Error("KaniTTS has been destroyed.");
    const temperature = opts.temperature ?? 1;
    const topP = opts.topP ?? .95;
    const repetitionPenalty = opts.repetitionPenalty ?? 1.1;
    const tag = opts.languageTag ?? "en_us";
    const maxNewTokens = Math.min(opts.maxNewTokens ?? DEFAULT_MAX_NEW_TOKENS, this.maxSeqLen);
    const prompt = [
      KANI_START_OF_HUMAN,
      ...this.tokenizer.encode(`${tag.trim()}: ${text}`),
      KANI_END_OF_TEXT,
      KANI_END_OF_HUMAN
    ];
    if (prompt.length >= this.maxSeqLen) throw new Error(`Kani prompt is ${prompt.length} tokens >= maxSeqLen ${this.maxSeqLen}.`);
    this.backboneExec.reset();
    const promptPos = computeKaniPositions(prompt, this.cfg);
    this.writeCosSin(promptPos, 0);
    const { logits } = await this.backboneExec.forward(new Uint32Array(prompt));
    // The timer starts after the prefill, so the logged wall time covers the
    // decode loop and the codec decode only.
    const startTime = performance.now();
    const { audioTokens, finished } = await this.runDecodeLoop(prompt, logits, {
      temperature,
      topP,
      repetitionPenalty,
      maxNewTokens,
      maxFrames: opts.maxFrames ?? Number.POSITIVE_INFINITY
    });
    // Drop a trailing partial frame.
    const usable = audioTokens.length - audioTokens.length % this.cfg.tokensPerFrame;
    if (usable === 0) throw new Error(`Kani produced no complete audio frames (finished=${finished}, audioTokens=${audioTokens.length}). The codec-LM emitted no speech.`);
    const { codes, numFrames } = audioTokensToCodes(audioTokens.slice(0, usable), this.cfg);
    const pcm = await this.decodeCodes(codes, numFrames);
    const totalTime = (performance.now() - startTime) / 1e3;
    const audioSeconds = numFrames * KANI_HOP / KANI_SAMPLE_RATE;
    console.log(`[kani] frames=${numFrames} audio=${audioSeconds.toFixed(2)}s wall=${totalTime.toFixed(2)}s RTF=${(audioSeconds / totalTime).toFixed(2)}x`);
    return {
      pcm,
      sampleRate: KANI_SAMPLE_RATE,
      frames: numFrames,
      audioSeconds
    };
  }
  /**
   * Autoregressive decode: from the prefill logits, emit one token per step —
   * greedy for the structural markers ([SOA][SOS]) and per-codebook constrained
   * sampling once in speech — collecting the audio tokens between SOS and EOS.
   * Writes each step's per-layer cos/sin row (frame-level logical position) before
   * the forward. Returns the collected audio tokens and whether EOS/cap was hit.
   *
   * `finished` stays false when the loop stops because `maxNewTokens` or
   * `maxSeqLen` was exhausted.
   */
  async runDecodeLoop(prompt, prefillLogits, p) {
    const seq = [...prompt];
    const audioTokens = [];
    const frameCap = p.maxFrames * this.cfg.tokensPerFrame;
    let logits = prefillLogits;
    let inSpeech = false;
    let finished = false;
    for (let step = 0; step < p.maxNewTokens; step++) {
      const nextToken = inSpeech ? this.sampleAudioToken(logits, audioTokens.length % this.cfg.tokensPerFrame, {
        temperature: p.temperature,
        topP: p.topP,
        repetitionPenalty: p.repetitionPenalty,
        previous: seq
      }) : this.argmax(logits);
      if (nextToken === this.cfg.startOfSpeech) inSpeech = true;
      else if (nextToken === this.cfg.endOfSpeech) {
        seq.push(nextToken);
        finished = true;
        break;
      } else if (inSpeech && nextToken >= this.cfg.audioTokensStart) audioTokens.push(nextToken);
      seq.push(nextToken);
      if (inSpeech && audioTokens.length >= frameCap) {
        finished = true;
        break;
      }
      if (seq.length >= this.maxSeqLen) break;
      this.writeCosSin(Float32Array.of(this.logicalPositionAt(seq)), this.backboneExec.currentSeqPos);
      logits = (await this.backboneExec.forward(new Uint32Array([nextToken]))).logits;
    }
    return {
      audioTokens,
      finished
    };
  }
  /** Logical (frame-level) position of the LAST token in `seq`. */
  logicalPositionAt(seq) {
    const positions = computeKaniPositions(seq, this.cfg);
    return positions[positions.length - 1];
  }
  /** Greedy argmax over a logits row (first maximum wins; 0 for an empty row). */
  argmax(logits) {
    let best = 0;
    let bestVal = logits[0];
    for (let i = 1; i < logits.length; i++) if (logits[i] > bestVal) {
      bestVal = logits[i];
      best = i;
    }
    return best;
  }
  /**
   * Sample one audio token for codebook position `codebook`, constrained to that
   * codebook's valid window [audioTokensStart + codebookSize*c, +codebookSize).
   * Allows end_of_speech only at codebook 0 (frame boundary). Applies rep-penalty,
   * then temperature and top-p.
   *
   * A temperature below 1e-6 (the threshold `sampleToken` uses for greedy
   * decoding) or a NaN temperature selects the highest-scoring candidate
   * deterministically. Dividing by such a temperature would turn the scores into
   * ±Infinity/NaN and make the softmax meaningless.
   *
   * @param logits   Full-vocabulary logits row.
   * @param codebook Codebook index within the current frame.
   * @param p        `{ temperature, topP, repetitionPenalty, previous }`.
   * @returns The chosen token id.
   */
  sampleAudioToken(logits, codebook, p) {
    const winStart = this.cfg.audioTokensStart + this.cfg.codebookSize * codebook;
    const winLen = this.cfg.codebookSize;
    const allowEos = codebook === 0;
    const n = winLen + (allowEos ? 1 : 0);
    const ids = new Int32Array(n);
    for (let i = 0; i < winLen; i++) ids[i] = winStart + i;
    if (allowEos) ids[winLen] = this.cfg.endOfSpeech;
    const prev = p.repetitionPenalty !== 1 ? new Set(p.previous) : null;
    // Negated comparison so a NaN temperature also takes the greedy path.
    const greedy = !(p.temperature >= 1e-6);
    const scores = new Float32Array(n);
    let bestIdx = 0;
    for (let i = 0; i < n; i++) {
      let s = logits[ids[i]];
      if (prev?.has(ids[i])) s = s > 0 ? s / p.repetitionPenalty : s * p.repetitionPenalty;
      scores[i] = greedy ? s : s / p.temperature;
      if (scores[i] > scores[bestIdx]) bestIdx = i;
    }
    if (greedy) return ids[bestIdx];
    softmaxInPlace(scores);
    return nucleusSample(scores, ids, p.topP);
  }
  /**
   * Decode NanoCodec codes [groups, T] (group-major) → PCM.
   *
   * The decoder graph carries concrete lengths, and the upsampled conv activations
   * for long clips overflow WebGPU's 65535 per-dimension dispatch cap. The decoder
   * is fully causal with a small (≤ a few frames) receptive field, so we decode in
   * frame chunks with a left-context lookback and keep only each chunk's own output
   * samples — numerically identical to a single decode, but bounded per dispatch.
   *
   * The result is clamped to [-1, 1].
   */
  async decodeCodes(codes, numFrames) {
    const groups = 4;
    const lookback = KANI_CODEC_LOOKBACK_FRAMES;
    const chunk = KANI_CODEC_CHUNK_FRAMES;
    const pcm = new Float32Array(numFrames * KANI_HOP);
    for (let start = 0; start < numFrames; start += chunk) {
      const ctxStart = Math.max(0, start - lookback);
      const ctxFrames = start - ctxStart;
      const end = Math.min(numFrames, start + chunk);
      const winFrames = end - ctxStart;
      // Re-pack the window group-major: lookback frames followed by this chunk.
      const winCodes = new Uint32Array(groups * winFrames);
      for (let g = 0; g < groups; g++) for (let t = 0; t < winFrames; t++) winCodes[g * winFrames + t] = codes[g * numFrames + (ctxStart + t)];
      const win = await this.decodeCodesWindow(winCodes, winFrames);
      // Discard the samples that belong to the lookback context.
      const dropSamples = ctxFrames * KANI_HOP;
      const keepSamples = (end - start) * KANI_HOP;
      pcm.set(win.subarray(dropSamples, dropSamples + keepSamples), start * KANI_HOP);
    }
    for (let i = 0; i < pcm.length; i++) pcm[i] = Math.min(1, Math.max(-1, pcm[i]));
    return pcm;
  }
  /** Run the NanoCodec decoder graph for a single (bounded) code window → PCM. */
  async decodeCodesWindow(winCodes, winFrames) {
    const graph = generateNanoCodecDecoderGraph({ numFrames: winFrames });
    const exec = new Executor(this.ctx, graph, {
      maxSeqLen: winFrames,
      kvMode: "f32"
    });
    try {
      exec.uploadWeightsMap(selectGraphWeights(graph, this.loaded.codecWeights));
      exec.initBindGroups();
      exec.reset();
      exec.writeInput("audio_codes", winCodes);
      return await exec.runGraphOutput(graph.outputs[0], winFrames * KANI_HOP);
    } finally {
      // The per-window executor is always released, even when decoding fails.
      exec.destroy();
    }
  }
  /** Release the backbone executor and clear both weight maps. Idempotent. */
  destroy() {
    if (this._destroyed) return;
    this._destroyed = true;
    this.backboneExec.destroy();
    this.loaded.backboneWeights.clear();
    this.loaded.codecWeights.clear();
  }
};
1071
+
1072
+ //#endregion
1073
+ //#region src/gpu/sampler.ts
1074
+ let _heapIndices = null;
1075
+ let _heapValues = null;
1076
/**
 * Draw one token id from a logits row.
 *
 * Stages: repetition penalty → temperature → top-k selection (min-heap held in
 * module-level scratch arrays) → descending sort → softmax → top-p cut → draw.
 * A temperature below 1e-6 skips all of that and returns the plain argmax.
 *
 * @param logits         Logits indexed by token id.
 * @param params         Optional `temperature` (0.7), `topK` (50; <= 0 keeps all),
 *                       `topP` (0.9), `repetitionPenalty` (1).
 * @param previousTokens Token ids subject to the repetition penalty.
 * @returns The sampled token id.
 */
function sampleToken(logits, params = {}, previousTokens) {
  const temp = params.temperature ?? .7;
  if (temp < 1e-6) return argmax(logits);
  const kLimit = params.topK ?? 50;
  const nucleus = params.topP ?? .9;
  const penalty = params.repetitionPenalty ?? 1;
  const vocab = logits.length;
  const keep = Math.min(kLimit > 0 ? kLimit : vocab, vocab);
  // Scratch arrays are shared across calls and only ever grow.
  if (!_heapIndices || _heapIndices.length < keep) {
    _heapIndices = new Uint32Array(keep);
    _heapValues = new Float32Array(keep);
  }
  const ids = _heapIndices;
  const vals = _heapValues;
  const seen = penalty !== 1 && previousTokens?.length ? new Set(previousTokens) : null;
  // Top-k: fill the first `keep` slots, heapify once, then replace the heap
  // minimum whenever a larger score shows up.
  let filled = 0;
  for (let tok = 0; tok < vocab; tok++) {
    let score = logits[tok];
    if (seen?.has(tok)) score = score > 0 ? score / penalty : score * penalty;
    score /= temp;
    if (filled < keep) {
      ids[filled] = tok;
      vals[filled] = score;
      filled += 1;
      if (filled === keep) {
        for (let parent = (keep >> 1) - 1; parent >= 0; parent--) siftDown(ids, vals, parent, keep);
      }
    } else if (score > vals[0]) {
      ids[0] = tok;
      vals[0] = score;
      siftDown(ids, vals, 0, keep);
    }
  }
  // Insertion sort of the survivors, highest score first.
  for (let i = 1; i < filled; i++) {
    const movingVal = vals[i];
    const movingId = ids[i];
    let slot = i;
    for (; slot > 0 && vals[slot - 1] < movingVal; slot--) {
      vals[slot] = vals[slot - 1];
      ids[slot] = ids[slot - 1];
    }
    vals[slot] = movingVal;
    ids[slot] = movingId;
  }
  // Softmax over the sorted survivors (vals[0] is the maximum).
  const top = vals[0];
  let expSum = 0;
  for (let i = 0; i < filled; i++) {
    const e = Math.exp(vals[i] - top);
    vals[i] = e;
    expSum += e;
  }
  const scale = 1 / expSum;
  for (let i = 0; i < filled; i++) vals[i] *= scale;
  // Top-p: keep the shortest prefix reaching the threshold and renormalize it.
  let candidates = filled;
  if (nucleus < 1) {
    // `mass` ends up as the sum of exactly the kept prefix.
    let mass = 0;
    for (let i = 0; i < filled; i++) {
      mass += vals[i];
      if (mass >= nucleus) {
        candidates = i + 1;
        break;
      }
    }
    const renorm = 1 / mass;
    for (let i = 0; i < candidates; i++) vals[i] *= renorm;
  }
  const draw = Math.random();
  let running = 0;
  for (let i = 0; i < candidates; i++) {
    running += vals[i];
    if (draw <= running) return ids[i];
  }
  // Rounding can leave the running total just under the draw.
  return ids[candidates - 1];
}
1157
/**
 * Greedy decoding helper: index of the largest element.
 * Ties resolve to the earliest index; an empty input yields 0.
 *
 * @param arr Array or typed array of numbers.
 * @returns Index of the maximum value.
 */
function argmax(arr) {
  let winner = 0;
  for (let k = 1; k < arr.length; k++) {
    // Strict comparison keeps the first of several equal maxima.
    if (arr[k] > arr[winner]) winner = k;
  }
  return winner;
}
1169
/**
 * Restore the min-heap property below position `i`.
 *
 * The heap lives in two parallel arrays; every swap is applied to both so each
 * value keeps its index. Only the first `n` slots are treated as heap members.
 *
 * @param indices Parallel array of payload indices.
 * @param values  Parallel array of heap keys.
 * @param i       Position to sift down from.
 * @param n       Heap size.
 */
function siftDown(indices, values, i, n) {
  let node = i;
  for (;;) {
    const leftChild = 2 * node + 1;
    const rightChild = leftChild + 1;
    let least = node;
    if (leftChild < n && values[leftChild] < values[least]) least = leftChild;
    if (rightChild < n && values[rightChild] < values[least]) least = rightChild;
    if (least === node) return;
    [indices[node], indices[least]] = [indices[least], indices[node]];
    [values[node], values[least]] = [values[least], values[node]];
    node = least;
  }
}
1187
+
1188
+ //#endregion
1189
+ //#region src/gpu/vision-executor.ts
1190
+ const MAP_MODE_READ = 1;
1191
/** Shared 4-byte scratch used to reinterpret one f32 as its raw bit pattern. */
const _f16scratch = /* @__PURE__ */ new ArrayBuffer(4);
const _f16f32 = new Float32Array(_f16scratch);
const _f16u32 = new Uint32Array(_f16scratch);
/**
 * Pack an f32 array to IEEE 754 half-precision (f16) bits, rounding to nearest
 * with ties to even.
 *
 * - Magnitudes too large for f16 (including rounding overflow) become ±Infinity.
 * - Magnitudes below the smallest f16 normal are encoded as f16 subnormals;
 *   anything that rounds below the smallest subnormal becomes signed zero.
 * - NaN stays NaN (quiet NaN, sign preserved) instead of collapsing to Infinity.
 *
 * @param src Array-like of numbers (each is rounded to f32 first).
 * @returns Uint16Array of f16 bit patterns, same length as `src`.
 */
function packF16(src) {
  const out = new Uint16Array(src.length);
  for (let i = 0; i < src.length; i++) {
    _f16f32[0] = src[i];
    const bits = _f16u32[0];
    const sign = bits >>> 16 & 0x8000;
    const rawExp = bits >>> 23 & 0xff;
    const mant = bits & 0x7fffff;
    if (rawExp === 0xff) {
      // f32 Inf (mantissa 0) or NaN (mantissa != 0).
      out[i] = mant !== 0 ? sign | 0x7e00 : sign | 0x7c00;
      continue;
    }
    // Re-bias the exponent from f32 (127) to f16 (15).
    const exp = rawExp - 112;
    if (exp >= 31) {
      out[i] = sign | 0x7c00;
      continue;
    }
    if (exp <= 0) {
      // Below the f16 normal range. Under 2^-25 everything rounds to zero
      // (this also covers f32 zeros and f32 subnormals).
      if (exp < -10) {
        out[i] = sign;
        continue;
      }
      // Shift the 24-bit significand (implicit 1 restored) into the 10-bit
      // subnormal field, rounding the discarded bits to nearest-even.
      const significand = mant | 0x800000;
      const shift = 14 - exp;
      let sub = significand >>> shift;
      const dropped = significand & (1 << shift) - 1;
      const tie = 1 << shift - 1;
      if (dropped > tie || dropped === tie && (sub & 1) === 1) sub += 1;
      // A carry out of the field yields 0x0400, the smallest normal — correct.
      out[i] = sign | sub;
      continue;
    }
    let half = exp << 10 | mant >>> 13;
    const dropped = mant & 0x1fff;
    // A carry may ripple into the exponent (up to Infinity) — also correct.
    if (dropped > 0x1000 || dropped === 0x1000 && (half & 1) === 1) half += 1;
    out[i] = sign | half;
  }
  return out;
}
1214
+ /**
1215
+ * WebKit only: the ViT bidirectional Attention op is a single dispatch whose
1216
+ * grid is [N query rows, heads], each row looping over all N keys — O(N²) work in
1217
+ * one submission. Past a few hundred rows that can exceed Metal's ~few-second GPU
1218
+ * watchdog and kill the web-content process (indistinguishable from OOM). We split
1219
+ * the attention grid into windows of this many query rows, each its own
1220
+ * submit + onSubmittedWorkDone() drain, so no single submission runs long enough
1221
+ * to trip the watchdog. Math is unchanged: every query row computes identically
1222
+ * regardless of how the grid is partitioned.
1223
+ */
1224
+ const WEBKIT_ATTN_CHUNK_ROWS = 64;
1225
+ var VisionExecutor = class {
1226
+ ctx;
1227
+ graph;
1228
+ mergeUnit;
1229
+ weightBuffers = /* @__PURE__ */ new Map();
1230
+ activationBuffers = /* @__PURE__ */ new Map();
1231
+ dispatches = [];
1232
+ maxPatches;
1233
+ /** Weight (B) names of MatMulBias nodes stored as f16 (empty without shader-f16). */
1234
+ f16WeightNames = /* @__PURE__ */ new Set();
1235
+ /** Runtime pooled-token count for the Gemma 4 ViT ("Np" dim); 0 for Qwen. */
1236
+ gemma4Np = 0;
1237
+ /** True when this graph is the Gemma 4 vision tower (uses "Np" + PoolMatMul). */
1238
+ isGemma4;
1239
+ constructor(ctx, graph, maxPatches) {
1240
+ this.ctx = ctx;
1241
+ this.graph = graph;
1242
+ this.maxPatches = maxPatches;
1243
+ this.isGemma4 = graph.architecture === "Gemma4VisionModel";
1244
+ const mergedIn = graph.tensors[CANONICAL_KEYS.visMergerFc1W]?.shape[1];
1245
+ const hidden = graph.config.hidden_size;
1246
+ this.mergeUnit = mergedIn && hidden ? mergedIn / hidden : 1;
1247
+ if (ctx.hasF16) {
1248
+ for (const node of graph.nodes) if (node.opType === "MatMulBias" && node.inputs[1]) this.f16WeightNames.add(node.inputs[1]);
1249
+ }
1250
+ this.allocateActivationBuffers();
1251
+ }
1252
+ uploadWeights(weights) {
1253
+ for (const [name, desc] of Object.entries(this.graph.tensors)) {
1254
+ if (desc.storage !== "constant") continue;
1255
+ const w = weights.get(name);
1256
+ if (!w) throw new Error(`VisionExecutor: missing weight "${name}"`);
1257
+ let data = w.data;
1258
+ if (this.f16WeightNames.has(name) && w.data instanceof Float32Array) data = packF16(w.data);
1259
+ const buffer = createStorageBuffer(this.ctx, `vweight_${name}`, data.byteLength, data);
1260
+ this.weightBuffers.set(name, buffer);
1261
+ }
1262
+ }
1263
+ initBindGroups() {
1264
+ const dummyShapes = this.resolveShapes(this.maxPatches);
1265
+ for (const nodeId of this.graph.executionOrder) {
1266
+ const node = this.graph.nodes.find((n) => n.id === nodeId);
1267
+ let spec = KERNEL_REGISTRY[node.opType];
1268
+ if (!spec) throw new Error(`VisionExecutor: no kernel for op "${node.opType}"`);
1269
+ if (node.opType === "MatMulBias" && this.f16WeightNames.has(node.inputs[1])) spec = MATMUL_BIAS_F16C_SPEC;
1270
+ const pipeline = getOrCreatePipeline(this.ctx, `vkernel_${nodeId}`, spec.shaderCode, spec.entryPoint);
1271
+ const paramsData = spec.buildParams(node, dummyShapes, { seqPos: 0 });
1272
+ const uniformBuffer = createUniformBuffer(this.ctx, `vuniform_${nodeId}`, paramsData);
1273
+ const bufferEntries = this.gatherBuffers(spec, node, uniformBuffer);
1274
+ const bindGroup = createBindGroup(this.ctx, pipeline, bufferEntries, `vbg_${nodeId}`);
1275
+ this.dispatches.push({
1276
+ node,
1277
+ spec,
1278
+ pipeline,
1279
+ bindGroup,
1280
+ uniformBuffer
1281
+ });
1282
+ }
1283
+ }
1284
+ /**
1285
+ * Encode patches → merged image embeddings [Nm, out_hidden_size].
1286
+ *
1287
+ * `onStage` (optional) fires coarse phase breadcrumbs during the WebKit path so
1288
+ * a host can localize a GPU-process crash to a specific layer.
1289
+ */
1290
+ async encode(inputs, onStage) {
1291
+ const N = inputs.numPatches;
1292
+ if (N > this.maxPatches) throw new Error(`VisionExecutor: ${N} patches exceeds maxPatches=${this.maxPatches}`);
1293
+ const q = this.ctx.device.queue;
1294
+ q.writeBuffer(this.activationBuffers.get("vis_patches"), 0, inputs.patches);
1295
+ q.writeBuffer(this.activationBuffers.get("vis_pos_embeds"), 0, inputs.posEmbeds);
1296
+ q.writeBuffer(this.activationBuffers.get("vis_cos"), 0, inputs.cos);
1297
+ q.writeBuffer(this.activationBuffers.get("vis_sin"), 0, inputs.sin);
1298
+ const resolvedShapes = this.resolveShapes(N);
1299
+ const ctxRt = { seqPos: 0 };
1300
+ const sizes = [];
1301
+ for (const d of this.dispatches) {
1302
+ const params = d.spec.buildParams(d.node, resolvedShapes, ctxRt);
1303
+ q.writeBuffer(d.uniformBuffer, 0, params);
1304
+ sizes.push(d.spec.getDispatchSize(d.node, resolvedShapes, ctxRt));
1305
+ }
1306
+ if (this.ctx.isWebKitWebGPU) {
1307
+ onStage?.("vit-encode-start", { total: this.dispatches.length });
1308
+ for (let i = 0; i < this.dispatches.length; i++) {
1309
+ const d = this.dispatches[i];
1310
+ if (d.node.opType === "Attention") {
1311
+ const [gridX, gridY, gridZ] = sizes[i];
1312
+ const layer = Number.parseInt(d.node.id.replace(/\D+/g, ""), 10);
1313
+ if (!Number.isNaN(layer)) onStage?.("vit-attn", { layer });
1314
+ for (let start = 0; start < gridX; start += WEBKIT_ATTN_CHUNK_ROWS) {
1315
+ const rows$1 = Math.min(WEBKIT_ATTN_CHUNK_ROWS, gridX - start);
1316
+ const params = d.spec.buildParams(d.node, resolvedShapes, {
1317
+ seqPos: 0,
1318
+ qOffset: start
1319
+ });
1320
+ q.writeBuffer(d.uniformBuffer, 0, params);
1321
+ const enc$2 = this.ctx.device.createCommandEncoder();
1322
+ const p$1 = enc$2.beginComputePass();
1323
+ p$1.setPipeline(d.pipeline);
1324
+ p$1.setBindGroup(0, d.bindGroup);
1325
+ p$1.dispatchWorkgroups(rows$1, gridY, gridZ);
1326
+ p$1.end();
1327
+ q.submit([enc$2.finish()]);
1328
+ await q.onSubmittedWorkDone();
1329
+ }
1330
+ continue;
1331
+ }
1332
+ const enc$1 = this.ctx.device.createCommandEncoder();
1333
+ const p = enc$1.beginComputePass();
1334
+ p.setPipeline(d.pipeline);
1335
+ p.setBindGroup(0, d.bindGroup);
1336
+ p.dispatchWorkgroups(...sizes[i]);
1337
+ p.end();
1338
+ q.submit([enc$1.finish()]);
1339
+ await q.onSubmittedWorkDone();
1340
+ }
1341
+ onStage?.("vit-encode-done");
1342
+ } else {
1343
+ const enc$1 = this.ctx.device.createCommandEncoder({ label: "vision_encode" });
1344
+ const pass = enc$1.beginComputePass({ label: "vision_pass" });
1345
+ for (let i = 0; i < this.dispatches.length; i++) {
1346
+ pass.setPipeline(this.dispatches[i].pipeline);
1347
+ pass.setBindGroup(0, this.dispatches[i].bindGroup);
1348
+ pass.dispatchWorkgroups(...sizes[i]);
1349
+ }
1350
+ pass.end();
1351
+ q.submit([enc$1.finish()]);
1352
+ }
1353
+ const rows = N / this.mergeUnit;
1354
+ const dim = this.graph.config.vision_embed_dim ? this.graph.tensors.vis_image_embeds.shape[1] : this.graph.tensors.vis_image_embeds.shape[1];
1355
+ const byteLen = rows * dim * 4;
1356
+ const out = this.activationBuffers.get("vis_image_embeds");
1357
+ const readback = this.ctx.device.createBuffer({
1358
+ size: byteLen,
1359
+ usage: 9
1360
+ });
1361
+ const enc = this.ctx.device.createCommandEncoder();
1362
+ enc.copyBufferToBuffer(out, 0, readback, 0, byteLen);
1363
+ this.ctx.device.queue.submit([enc.finish()]);
1364
+ await readback.mapAsync(MAP_MODE_READ, 0, byteLen);
1365
+ const embeds = new Float32Array(readback.getMappedRange(0, byteLen).slice(0));
1366
+ readback.unmap();
1367
+ readback.destroy();
1368
+ return {
1369
+ embeds,
1370
+ rows,
1371
+ dim
1372
+ };
1373
+ }
1374
+ /**
1375
+ * Encode patches through the Gemma 4 ViT → projected image tokens [Np, text_hidden].
1376
+ *
1377
+ * Distinct from the Qwen `encode()`: the Gemma graph has 5 inputs (patches,
1378
+ * axial pos-embeds, axial rotary cos/sin, and a host-built [Np,N] pooling matrix)
1379
+ * and its output rows (Np) are the pooled soft-token count, resolved from the
1380
+ * pooling matrix rather than an N/mergeUnit ratio. Reuses the same dispatch
1381
+ * machinery + WebKit per-dispatch-drain discipline.
1382
+ */
1383
+ async encodeGemma4(inputs, onStage) {
1384
+ const N = inputs.numPatches;
1385
+ if (N > this.maxPatches) throw new Error(`VisionExecutor(Gemma4): ${N} patches exceeds maxPatches=${this.maxPatches}`);
1386
+ this.gemma4Np = inputs.numPooled;
1387
+ const q = this.ctx.device.queue;
1388
+ q.writeBuffer(this.activationBuffers.get("g4v_patches"), 0, inputs.patches);
1389
+ q.writeBuffer(this.activationBuffers.get("g4v_pos_embeds"), 0, inputs.posEmbeds);
1390
+ q.writeBuffer(this.activationBuffers.get("g4v_cos"), 0, inputs.cos);
1391
+ q.writeBuffer(this.activationBuffers.get("g4v_sin"), 0, inputs.sin);
1392
+ q.writeBuffer(this.activationBuffers.get("g4v_pool_w"), 0, inputs.poolMatrix);
1393
+ const resolvedShapes = this.resolveShapes(N);
1394
+ const ctxRt = { seqPos: 0 };
1395
+ const sizes = [];
1396
+ for (const d of this.dispatches) {
1397
+ const params = d.spec.buildParams(d.node, resolvedShapes, ctxRt);
1398
+ q.writeBuffer(d.uniformBuffer, 0, params);
1399
+ sizes.push(d.spec.getDispatchSize(d.node, resolvedShapes, ctxRt));
1400
+ }
1401
+ if (this.ctx.isWebKitWebGPU) {
1402
+ onStage?.("vit-encode-start", { total: this.dispatches.length });
1403
+ for (let i = 0; i < this.dispatches.length; i++) {
1404
+ const d = this.dispatches[i];
1405
+ if (d.node.opType === "Attention") {
1406
+ const [gridX, gridY, gridZ] = sizes[i];
1407
+ const layer = Number.parseInt(d.node.id.replace(/\D+/g, ""), 10);
1408
+ if (!Number.isNaN(layer)) onStage?.("vit-attn", { layer });
1409
+ for (let start = 0; start < gridX; start += WEBKIT_ATTN_CHUNK_ROWS) {
1410
+ const rows$1 = Math.min(WEBKIT_ATTN_CHUNK_ROWS, gridX - start);
1411
+ const params = d.spec.buildParams(d.node, resolvedShapes, {
1412
+ seqPos: 0,
1413
+ qOffset: start
1414
+ });
1415
+ q.writeBuffer(d.uniformBuffer, 0, params);
1416
+ const enc$2 = this.ctx.device.createCommandEncoder();
1417
+ const p$1 = enc$2.beginComputePass();
1418
+ p$1.setPipeline(d.pipeline);
1419
+ p$1.setBindGroup(0, d.bindGroup);
1420
+ p$1.dispatchWorkgroups(rows$1, gridY, gridZ);
1421
+ p$1.end();
1422
+ q.submit([enc$2.finish()]);
1423
+ await q.onSubmittedWorkDone();
1424
+ }
1425
+ continue;
1426
+ }
1427
+ const enc$1 = this.ctx.device.createCommandEncoder();
1428
+ const p = enc$1.beginComputePass();
1429
+ p.setPipeline(d.pipeline);
1430
+ p.setBindGroup(0, d.bindGroup);
1431
+ p.dispatchWorkgroups(...sizes[i]);
1432
+ p.end();
1433
+ q.submit([enc$1.finish()]);
1434
+ await q.onSubmittedWorkDone();
1435
+ }
1436
+ onStage?.("vit-encode-done");
1437
+ } else {
1438
+ const enc$1 = this.ctx.device.createCommandEncoder({ label: "gemma4_vision_encode" });
1439
+ const pass = enc$1.beginComputePass({ label: "gemma4_vision_pass" });
1440
+ for (let i = 0; i < this.dispatches.length; i++) {
1441
+ pass.setPipeline(this.dispatches[i].pipeline);
1442
+ pass.setBindGroup(0, this.dispatches[i].bindGroup);
1443
+ pass.dispatchWorkgroups(...sizes[i]);
1444
+ }
1445
+ pass.end();
1446
+ q.submit([enc$1.finish()]);
1447
+ }
1448
+ const rows = inputs.numPooled;
1449
+ const dim = this.graph.tensors.g4v_image_embeds.shape[1];
1450
+ const byteLen = rows * dim * 4;
1451
+ const out = this.activationBuffers.get("g4v_image_embeds");
1452
+ const readback = this.ctx.device.createBuffer({
1453
+ size: byteLen,
1454
+ usage: 9
1455
+ });
1456
+ const enc = this.ctx.device.createCommandEncoder();
1457
+ enc.copyBufferToBuffer(out, 0, readback, 0, byteLen);
1458
+ this.ctx.device.queue.submit([enc.finish()]);
1459
+ await readback.mapAsync(MAP_MODE_READ, 0, byteLen);
1460
+ const embeds = new Float32Array(readback.getMappedRange(0, byteLen).slice(0));
1461
+ readback.unmap();
1462
+ readback.destroy();
1463
+ return {
1464
+ embeds,
1465
+ rows,
1466
+ dim
1467
+ };
1468
+ }
1469
+ /** True if this executor is the Gemma 4 vision tower. */
1470
+ get gemma4() {
1471
+ return this.isGemma4;
1472
+ }
1473
+ /** Read back any named activation (debug). Must be called right after encode(). */
1474
+ async debugReadBuffer(name, maxElements) {
1475
+ const buffer = this.activationBuffers.get(name) ?? this.weightBuffers.get(name);
1476
+ if (!buffer) throw new Error(`VisionExecutor: no buffer "${name}"`);
1477
+ const byteLen = maxElements ? Math.min(maxElements * 4, buffer.size) : buffer.size;
1478
+ const readback = this.ctx.device.createBuffer({
1479
+ size: byteLen,
1480
+ usage: 9
1481
+ });
1482
+ const enc = this.ctx.device.createCommandEncoder();
1483
+ enc.copyBufferToBuffer(buffer, 0, readback, 0, byteLen);
1484
+ this.ctx.device.queue.submit([enc.finish()]);
1485
+ await readback.mapAsync(MAP_MODE_READ, 0, byteLen);
1486
+ const data = new Float32Array(readback.getMappedRange(0, byteLen).slice(0));
1487
+ readback.unmap();
1488
+ readback.destroy();
1489
+ return data;
1490
+ }
1491
+ destroy() {
1492
+ destroyBuffers([...this.weightBuffers.values()]);
1493
+ destroyBuffers([...this.activationBuffers.values()]);
1494
+ for (const d of this.dispatches) d.uniformBuffer.destroy();
1495
+ this.weightBuffers.clear();
1496
+ this.activationBuffers.clear();
1497
+ this.dispatches = [];
1498
+ }
1499
+ /** Max pooled tokens for buffer sizing: maxPatches with no merge/pool ratio applied. */
1500
+ maxPooled() {
1501
+ return this.maxPatches;
1502
+ }
1503
+ resolveShapes(N) {
1504
+ const Nm = N / this.mergeUnit;
1505
+ const Np = this.gemma4Np || N;
1506
+ const resolved = {};
1507
+ for (const [name, desc] of Object.entries(this.graph.tensors)) resolved[name] = desc.shape.map((d) => {
1508
+ if (d === "N") return N;
1509
+ if (d === "Nm") return Nm;
1510
+ if (d === "Np") return Np;
1511
+ return d;
1512
+ });
1513
+ return resolved;
1514
+ }
1515
+ allocateActivationBuffers() {
1516
+ for (const [name, desc] of Object.entries(this.graph.tensors)) {
1517
+ if (desc.storage !== "activation") continue;
1518
+ const bytes = desc.shape.map((d) => {
1519
+ if (d === "N") return this.maxPatches;
1520
+ if (d === "Nm") return this.maxPatches / this.mergeUnit;
1521
+ if (d === "Np") return this.maxPooled();
1522
+ return d;
1523
+ }).reduce((a, b) => a * b, 1) * DTYPE_BYTES[desc.dtype];
1524
+ this.activationBuffers.set(name, createStorageBuffer(this.ctx, `vact_${name}`, bytes));
1525
+ }
1526
+ }
1527
+ gatherBuffers(spec, node, uniformBuffer) {
1528
+ const entries = [];
1529
+ const tensorNames = [...node.inputs, ...node.outputs];
1530
+ let tensorIdx = 0;
1531
+ for (const binding of spec.bindings) if (binding.type === "uniform") entries.push({ buffer: uniformBuffer });
1532
+ else {
1533
+ const tensorName = tensorNames[tensorIdx++];
1534
+ const buffer = this.weightBuffers.get(tensorName) ?? this.activationBuffers.get(tensorName);
1535
+ if (!buffer) throw new Error(`VisionExecutor: no buffer for "${tensorName}" in op ${node.id}`);
1536
+ entries.push({ buffer });
1537
+ }
1538
+ return entries;
1539
+ }
1540
+ };
1541
+
1542
+ //#endregion
1543
+ //#region src/gpu/vision-preprocess.ts
1544
/**
 * `count` evenly spaced values from 0 to `stop` inclusive (the original comment
 * describes this as matching torch.linspace).
 *
 * @param stop  Last value of the range.
 * @param count Number of samples; a single sample yields [0].
 * @returns Float64Array of length `count`.
 */
function linspace(stop, count) {
	const samples = new Float64Array(count);
	// A freshly allocated typed array is zero-filled, which is already the
	// single-sample answer.
	if (count === 1) return samples;
	const delta = stop / (count - 1);
	return samples.map((_, index) => index * delta);
}
1555
/**
 * Spatial-merge reorder index for an (h, w) patch grid.
 *
 * Output slot k holds the row-major source index (row * w + col) of the patch
 * that belongs there once merge×merge blocks are laid out contiguously: blocks
 * are visited row-major, and patches inside a block row-major as well.
 *
 * @param h     Grid height in patches.
 * @param w     Grid width in patches.
 * @param merge Block edge length.
 * @returns Int32Array of length h * w.
 */
function spatialMergeReorder(h, w, merge) {
	const blockRows = h / merge;
	const blockCols = w / merge;
	const order = new Int32Array(h * w);
	let slot = 0;
	for (let blockRow = 0; blockRow < blockRows; blockRow++) {
		for (let blockCol = 0; blockCol < blockCols; blockCol++) {
			for (let innerRow = 0; innerRow < merge; innerRow++) {
				const row = blockRow * merge + innerRow;
				for (let innerCol = 0; innerCol < merge; innerCol++) {
					const col = blockCol * merge + innerCol;
					order[slot] = row * w + col;
					slot += 1;
				}
			}
		}
	}
	return order;
}
1573
/**
 * Bilinearly interpolate the learned position-embedding table onto an (h, w)
 * patch grid and emit the rows in spatial-merge order, repeated for each of the
 * t frames.
 *
 * The table is treated as a side×side grid with
 * side = round(sqrt(cfg.numPositionEmbeddings)); grid coordinates are spread over
 * [0, side-1] on each axis.
 *
 * @param gridTHW       [t, h, w] patch grid.
 * @param posEmbedTable Flat table indexed as row * cfg.hiddenSize + channel.
 * @param cfg           Reads hiddenSize, numPositionEmbeddings, spatialMergeSize.
 * @returns Float32Array of length t * h * w * cfg.hiddenSize.
 */
function buildPosEmbeds(gridTHW, posEmbedTable, cfg) {
	const [frames, gridRows, gridCols] = gridTHW;
	const width = cfg.hiddenSize;
	const side = Math.round(Math.sqrt(cfg.numPositionEmbeddings));
	const lastCell = side - 1;

	// For one axis: the lower/upper table coordinate and the blend fraction of
	// every grid position.
	const axisTaps = (length) => {
		const coords = linspace(lastCell, length);
		const lo = new Int32Array(length);
		const hi = new Int32Array(length);
		const frac = new Float64Array(length);
		for (let n = 0; n < length; n++) {
			lo[n] = Math.trunc(coords[n]);
			hi[n] = Math.min(lo[n] + 1, lastCell);
			frac[n] = coords[n] - lo[n];
		}
		return { lo, hi, frac };
	};
	const rowTaps = axisTaps(gridRows);
	const colTaps = axisTaps(gridCols);

	// Four corner table rows + weights per row-major grid cell.
	const cellsPerFrame = gridRows * gridCols;
	const tapIndex = [
		new Int32Array(cellsPerFrame),
		new Int32Array(cellsPerFrame),
		new Int32Array(cellsPerFrame),
		new Int32Array(cellsPerFrame)
	];
	const tapWeight = [
		new Float64Array(cellsPerFrame),
		new Float64Array(cellsPerFrame),
		new Float64Array(cellsPerFrame),
		new Float64Array(cellsPerFrame)
	];
	let cell = 0;
	for (let i = 0; i < gridRows; i++) {
		const upperRow = rowTaps.lo[i] * side;
		const lowerRow = rowTaps.hi[i] * side;
		const rowFrac = rowTaps.frac[i];
		for (let j = 0; j < gridCols; j++) {
			const left = colTaps.lo[j];
			const right = colTaps.hi[j];
			const colFrac = colTaps.frac[j];
			tapIndex[0][cell] = upperRow + left;
			tapIndex[1][cell] = upperRow + right;
			tapIndex[2][cell] = lowerRow + left;
			tapIndex[3][cell] = lowerRow + right;
			tapWeight[0][cell] = (1 - rowFrac) * (1 - colFrac);
			tapWeight[1][cell] = (1 - rowFrac) * colFrac;
			tapWeight[2][cell] = rowFrac * (1 - colFrac);
			tapWeight[3][cell] = rowFrac * colFrac;
			cell += 1;
		}
	}

	// Accumulate the weighted corners, writing output rows in merge order.
	const order = spatialMergeReorder(gridRows, gridCols, cfg.spatialMergeSize);
	const result = new Float32Array(frames * cellsPerFrame * width);
	for (let frame = 0; frame < frames; frame++) {
		for (let slot = 0; slot < cellsPerFrame; slot++) {
			const source = order[slot];
			const outBase = (frame * cellsPerFrame + slot) * width;
			for (let tap = 0; tap < 4; tap++) {
				const tableBase = tapIndex[tap][source] * width;
				const weight = tapWeight[tap][source];
				// Zero-weight corners contribute nothing; skip the inner loop.
				if (weight === 0) continue;
				for (let ch = 0; ch < width; ch++) {
					result[outBase + ch] += posEmbedTable[tableBase + ch] * weight;
				}
			}
		}
	}
	return result;
}
1645
/**
 * (row, col) position ids for every patch, in spatial-merge order, as consumed by
 * the rotary table builder.
 *
 * One frame's ids are generated by walking merge×merge blocks row-major (and the
 * patches inside each block row-major), then the same frame is repeated t times.
 *
 * @param gridTHW [t, h, w] patch grid.
 * @param merge   Block edge length.
 * @returns Int32Array of length t * h * w * 2, laid out as [row0, col0, row1, col1, ...].
 */
function buildPositionIds(gridTHW, merge) {
	const [frames, gridRows, gridCols] = gridTHW;
	const blockRows = gridRows / merge;
	const blockCols = gridCols / merge;
	const cellsPerFrame = gridRows * gridCols;
	const frameIds = new Int32Array(cellsPerFrame * 2);
	let slot = 0;
	for (let blockRow = 0; blockRow < blockRows; blockRow++) {
		for (let blockCol = 0; blockCol < blockCols; blockCol++) {
			for (let innerRow = 0; innerRow < merge; innerRow++) {
				for (let innerCol = 0; innerCol < merge; innerCol++) {
					frameIds[slot * 2] = blockRow * merge + innerRow;
					frameIds[slot * 2 + 1] = blockCol * merge + innerCol;
					slot += 1;
				}
			}
		}
	}
	const ids = new Int32Array(frames * cellsPerFrame * 2);
	for (let frame = 0; frame < frames; frame++) {
		ids.set(frameIds, frame * cellsPerFrame * 2);
	}
	return ids;
}
1665
/**
 * Rotary cos/sin tables of shape [N, headDim] from (row, col) position ids.
 *
 * headDim/4 inverse frequencies are computed over a rotary dimension of headDim/2:
 * inv_freq[i] = 1 / theta^(2i / (headDim/2)). For each token the row angles fill
 * the first quarter, the column angles the second quarter, and that half is then
 * duplicated into the second half of the head dimension.
 *
 * @param positionIds Flat [row0, col0, row1, col1, ...] ids.
 * @param headDim     Attention head dimension.
 * @param theta       Rotary base (default 1e4).
 * @returns `{ cos, sin, numPatches }` with Float32Array tables and N = positionIds.length / 2.
 */
function buildRotaryCosSin(positionIds, headDim, theta = 1e4) {
	const rotaryDim = headDim / 2;
	const quarter = rotaryDim / 2;
	const invFreq = new Float64Array(quarter);
	for (let f = 0; f < quarter; f++) {
		invFreq[f] = 1 / Math.pow(theta, 2 * f / rotaryDim);
	}
	const tokenCount = positionIds.length / 2;
	const cos = new Float32Array(tokenCount * headDim);
	const sin = new Float32Array(tokenCount * headDim);
	for (let token = 0; token < tokenCount; token++) {
		const rowPos = positionIds[token * 2];
		const colPos = positionIds[token * 2 + 1];
		const offset = token * headDim;
		for (let f = 0; f < quarter; f++) {
			const rowAngle = rowPos * invFreq[f];
			const colAngle = colPos * invFreq[f];
			const rowCos = Math.cos(rowAngle);
			const rowSin = Math.sin(rowAngle);
			const colCos = Math.cos(colAngle);
			const colSin = Math.sin(colAngle);
			// First half: [row freqs | col freqs].
			cos[offset + f] = rowCos;
			sin[offset + f] = rowSin;
			cos[offset + quarter + f] = colCos;
			sin[offset + quarter + f] = colSin;
			// Second half repeats the first.
			cos[offset + rotaryDim + f] = rowCos;
			sin[offset + rotaryDim + f] = rowSin;
			cos[offset + rotaryDim + quarter + f] = colCos;
			sin[offset + rotaryDim + quarter + f] = colSin;
		}
	}
	return {
		cos,
		sin,
		numPatches: tokenCount
	};
}
1709
/**
 * Convenience wrapper producing every host-side position tensor for one image
 * grid: interpolated position embeddings plus rotary cos/sin.
 *
 * @param gridTHW       [t, h, w] patch grid.
 * @param posEmbedTable Raw learned position-embedding table.
 * @param cfg           Reads hiddenSize, numHeads, spatialMergeSize and optional
 *                      ropeTheta (defaults to 1e4), plus whatever buildPosEmbeds reads.
 * @returns `{ posEmbeds, cos, sin, numPatches }`.
 */
function buildVisionPositionTensors(gridTHW, posEmbedTable, cfg) {
	const headDim = Math.floor(cfg.hiddenSize / cfg.numHeads);
	const posEmbeds = buildPosEmbeds(gridTHW, posEmbedTable, cfg);
	const positionIds = buildPositionIds(gridTHW, cfg.spatialMergeSize);
	const theta = cfg.ropeTheta ?? 1e4;
	const rotary = buildRotaryCosSin(positionIds, headDim, theta);
	return {
		posEmbeds,
		cos: rotary.cos,
		sin: rotary.sin,
		numPatches: rotary.numPatches
	};
}
1723
/**
 * Row-major (x, y) coordinates for every patch of a gridH × gridW grid: the patch
 * at row r, column c has index r * gridW + c, x = c and y = r.
 *
 * @returns `{ x, y }`, two Int32Arrays of length gridH * gridW.
 */
function gemma4PatchXY(gridH, gridW) {
	const total = gridH * gridW;
	const x = new Int32Array(total);
	const y = new Int32Array(total);
	let index = 0;
	for (let row = 0; row < gridH; row++) {
		for (let col = 0; col < gridW; col++) {
			x[index] = col;
			y[index] = row;
			index += 1;
		}
	}
	return { x, y };
}
1743
/**
 * Axial learned position embeddings [N, hidden] by direct table lookup (no
 * interpolation): row p of the result is table[0][x_p] + table[1][y_p], where the
 * flat table stores the x plane first and the y plane `posSize * hidden` floats later.
 *
 * Coordinates are clamped at 0 from below only; nothing here bounds them by posSize.
 *
 * @param gridH         Grid height in patches.
 * @param gridW         Grid width in patches.
 * @param posEmbedTable Flattened [2, posSize, hidden] table.
 * @param hidden        Embedding width.
 * @param posSize       Rows per plane in the table.
 * @returns Float32Array of length gridH * gridW * hidden.
 */
function buildGemma4PosEmbeds(gridH, gridW, posEmbedTable, hidden, posSize) {
	const coords = gemma4PatchXY(gridH, gridW);
	const patchCount = gridH * gridW;
	const embeds = new Float32Array(patchCount * hidden);
	const yPlaneOffset = posSize * hidden;
	for (let patch = 0; patch < patchCount; patch++) {
		const xRow = Math.max(0, coords.x[patch]) * hidden;
		const yRow = yPlaneOffset + Math.max(0, coords.y[patch]) * hidden;
		const outRow = patch * hidden;
		for (let ch = 0; ch < hidden; ch++) {
			embeds[outRow + ch] = posEmbedTable[xRow + ch] + posEmbedTable[yRow + ch];
		}
	}
	return embeds;
}
1763
/**
 * 2D axial rotary cos/sin tables [N, headDim] for a gridH × gridW patch grid.
 *
 * With spatialDim = headDim / 2 and inv_freq[j] = 1 / theta^(2j / spatialDim) for
 * j in [0, spatialDim/2), each row is laid out as [fx, fx, fy, fy]: the x angles
 * twice, followed by the y angles twice. The original comment notes this layout
 * is meant for the rotate_half kernel (out = x*cos + rotate_half(x)*sin).
 *
 * @returns `{ cos, sin }` Float32Arrays of length gridH * gridW * headDim.
 */
function buildGemma4RotaryCosSin(gridH, gridW, headDim, theta) {
	const coords = gemma4PatchXY(gridH, gridW);
	const patchCount = gridH * gridW;
	const spatialDim = headDim / 2;
	const freqCount = spatialDim / 2;
	const invFreq = new Float64Array(freqCount);
	for (let f = 0; f < freqCount; f++) {
		invFreq[f] = 1 / Math.pow(theta, 2 * f / spatialDim);
	}
	const cos = new Float32Array(patchCount * headDim);
	const sin = new Float32Array(patchCount * headDim);
	for (let patch = 0; patch < patchCount; patch++) {
		const xPos = Math.max(0, coords.x[patch]);
		const yPos = Math.max(0, coords.y[patch]);
		const xStart = patch * headDim;
		const yStart = xStart + spatialDim;
		for (let f = 0; f < freqCount; f++) {
			const xAngle = xPos * invFreq[f];
			const yAngle = yPos * invFreq[f];
			const xCos = Math.cos(xAngle);
			const xSin = Math.sin(xAngle);
			const yCos = Math.cos(yAngle);
			const ySin = Math.sin(yAngle);
			// x block: the angle set written twice.
			cos[xStart + f] = xCos;
			sin[xStart + f] = xSin;
			cos[xStart + freqCount + f] = xCos;
			sin[xStart + freqCount + f] = xSin;
			// y block: same pattern, offset by spatialDim.
			cos[yStart + f] = yCos;
			sin[yStart + f] = ySin;
			cos[yStart + freqCount + f] = yCos;
			sin[yStart + freqCount + f] = ySin;
		}
	}
	return { cos, sin };
}
1806
/**
 * Dense [Np, N] pooling matrix for k×k spatial average pooling over the grid.
 *
 * Patch p (x_p, y_p) falls in cell floor(x_p/k) + ceil(gridW/k) * floor(y_p/k) and
 * contributes weight 1/k² to that cell's row, so `poolMatrix @ hidden` sums each
 * block scaled by 1/k². The weight is always 1/k², including for edge cells that
 * hold fewer than k² patches (the original comment states this mirrors the
 * reference implementation's fixed normalisation).
 *
 * @returns `{ poolMatrix, numPooled }` with numPooled = ceil(gridH/k) * ceil(gridW/k).
 */
function buildGemma4PoolMatrix(gridH, gridW, k) {
	const patchCount = gridH * gridW;
	const cellsPerRow = Math.ceil(gridW / k);
	const numPooled = cellsPerRow * Math.ceil(gridH / k);
	const weight = 1 / (k * k);
	const poolMatrix = new Float32Array(numPooled * patchCount);
	const coords = gemma4PatchXY(gridH, gridW);
	for (let patch = 0; patch < patchCount; patch++) {
		const cellCol = Math.floor(coords.x[patch] / k);
		const cellRow = Math.floor(coords.y[patch] / k);
		const cell = cellCol + cellsPerRow * cellRow;
		poolMatrix[cell * patchCount + patch] = weight;
	}
	return { poolMatrix, numPooled };
}
1831
/**
 * One-call builder for every Gemma 4 vision host tensor of a single image grid.
 *
 * @param gridH         Grid height in patches.
 * @param gridW         Grid width in patches.
 * @param posEmbedTable Raw [2, posSize, hidden] table, flattened.
 * @param posSize       Rows per plane in the table.
 * @param cfg           Reads hiddenSize, headDim, ropeTheta and poolingKernelSize.
 * @returns `{ posEmbeds, cos, sin, poolMatrix, numPatches, numPooled }` with
 *          numPatches = gridH * gridW.
 */
function buildGemma4VisionPositionTensors(gridH, gridW, posEmbedTable, posSize, cfg) {
	const posEmbeds = buildGemma4PosEmbeds(gridH, gridW, posEmbedTable, cfg.hiddenSize, posSize);
	const rotary = buildGemma4RotaryCosSin(gridH, gridW, cfg.headDim, cfg.ropeTheta);
	const pooling = buildGemma4PoolMatrix(gridH, gridW, cfg.poolingKernelSize);
	return {
		posEmbeds,
		cos: rotary.cos,
		sin: rotary.sin,
		poolMatrix: pooling.poolMatrix,
		numPatches: gridH * gridW,
		numPooled: pooling.numPooled
	};
}
1848
/**
 * Gemma 4 image processor settings (per the original note, taken from
 * processor_config.json). Mean 0 / std 1 means normalisation is a no-op, so
 * pixels are only rescaled by 1/255.
 */
const GEMMA4_IMAGE_PROCESSOR = {
	patchSize: 16,
	temporalPatchSize: 1,
	mergeSize: 1,
	imageMean: [0, 0, 0],
	imageStd: [1, 1, 1],
	rescaleFactor: 1 / 255,
	minPixels: 256,
	// 2520 patches of 16×16 pixels each.
	maxPixels: 2520 * (16 * 16)
};
1867
/**
 * Prepare a decoded RGB image for the Gemma 4 ViT.
 *
 * Steps: snap both sides down to a multiple of poolingKernelSize·patchSize (never
 * below one such unit), shrink by sqrt(overshoot) until the patch count fits
 * maxSoftTokens·k², bilinear-resize, scale by 1/255 (no mean/std normalisation)
 * and patchify row-major into [N, 3·patchSize·patchSize] with channel-major,
 * then row, then column ordering inside each patch.
 *
 * @param pixels            Row-major HWC RGB (0..255), length width*height*3.
 * @param width             Source width in pixels.
 * @param height            Source height in pixels.
 * @param maxSoftTokens     Pooled-token budget (default 280).
 * @param poolingKernelSize Pooling kernel edge k (default 3).
 * @param patchSize         Patch edge in pixels (default 16).
 * @returns `{ patches, gridHW: [gridH, gridW] }`.
 */
function preprocessImageGemma4(pixels, width, height, maxSoftTokens = 280, poolingKernelSize = 3, patchSize = 16) {
	const factor = poolingKernelSize * patchSize;
	const patchBudget = maxSoftTokens * poolingKernelSize * poolingKernelSize;
	const snapDown = (extent) => Math.max(factor, Math.floor(extent / factor) * factor);
	const patchCount = (hPx, wPx) => hPx / patchSize * (wPx / patchSize);

	let outH = snapDown(height);
	let outW = snapDown(width);
	while (patchCount(outH, outW) > patchBudget) {
		const beta = Math.sqrt(patchCount(outH, outW) / patchBudget);
		outH = snapDown(outH / beta);
		outW = snapDown(outW / beta);
		// Both sides are already at the minimum size; nothing left to shrink.
		if (outH <= factor && outW <= factor) break;
	}

	const resized = bilinearResize(pixels, height, width, outH, outW);
	const gridH = outH / patchSize;
	const gridW = outW / patchSize;
	const channels = 3;
	const patchDim = channels * patchSize * patchSize;
	const patches = new Float32Array(gridH * gridW * patchDim);
	const rescale = 1 / 255;

	let patchIndex = 0;
	for (let patchRow = 0; patchRow < gridH; patchRow++) {
		for (let patchCol = 0; patchCol < gridW; patchCol++) {
			let write = patchIndex * patchDim;
			for (let c = 0; c < channels; c++) {
				for (let py = 0; py < patchSize; py++) {
					const rowStart = (patchRow * patchSize + py) * outW;
					for (let px = 0; px < patchSize; px++) {
						const pixel = rowStart + patchCol * patchSize + px;
						patches[write] = resized[pixel * 3 + c] * rescale;
						write += 1;
					}
				}
			}
			patchIndex += 1;
		}
	}
	return {
		patches,
		gridHW: [gridH, gridW]
	};
}
1913
/**
 * Default image processor settings used by `preprocessImage` (Qwen 3.5 vision):
 * 16-pixel patches, temporal pairing of 2, 2×2 spatial merge, mean/std of 0.5 per
 * channel after a 1/255 rescale, and a pixel budget of [65536, 16777216].
 */
const QWEN3_5_IMAGE_PROCESSOR = {
	patchSize: 16,
	temporalPatchSize: 2,
	mergeSize: 2,
	imageMean: [0.5, 0.5, 0.5],
	imageStd: [0.5, 0.5, 0.5],
	rescaleFactor: 1 / 255,
	minPixels: 65536,
	maxPixels: 16777216
};
1931
/**
 * Pick an output size whose sides are multiples of `factor` and whose area lies
 * within [minPixels, maxPixels], keeping the aspect ratio approximately. The
 * original comment describes it as the Qwen2-VL "smart resize".
 *
 * - A side shorter than `factor` first scales the image so the shorter side
 *   reaches `factor`.
 * - Both sides are rounded to the nearest multiple of `factor` (at least one).
 * - Over budget: shrink by sqrt(area / maxPixels) and floor to a multiple.
 * - Under budget: grow by sqrt(minPixels / area) and ceil to a multiple.
 *
 * @returns [height, width] of the resized image.
 */
function smartResize(height, width, factor, minPixels, maxPixels) {
	let h = height;
	let w = width;
	if (h < factor || w < factor) {
		const upscale = factor / Math.min(h, w);
		h = Math.max(factor, Math.round(h * upscale));
		w = Math.max(factor, Math.round(w * upscale));
	}
	// Snap a length to a multiple of `factor` using the given rounding function.
	const snap = (length, rounding) => rounding(length / factor) * factor;

	const roundedH = Math.max(factor, snap(h, Math.round));
	const roundedW = Math.max(factor, snap(w, Math.round));
	const roundedArea = roundedH * roundedW;
	if (roundedArea > maxPixels) {
		const shrink = Math.sqrt(h * w / maxPixels);
		return [
			Math.max(factor, snap(h / shrink, Math.floor)),
			Math.max(factor, snap(w / shrink, Math.floor))
		];
	}
	if (roundedArea < minPixels) {
		const grow = Math.sqrt(minPixels / (h * w));
		return [snap(h * grow, Math.ceil), snap(w * grow, Math.ceil)];
	}
	return [roundedH, roundedW];
}
1958
/**
 * Bilinear resampling of a row-major [H, W, 3] image to (outH, outW).
 *
 * Uses half-pixel-centred source coordinates, (o + 0.5) * scale - 0.5, clamped to
 * the valid index range, which the original comment describes as close to
 * PIL/torchvision BILINEAR.
 *
 * @param pixels Source samples, length inH * inW * 3 (any numeric range).
 * @param inH    Source height.
 * @param inW    Source width.
 * @param outH   Target height.
 * @param outW   Target width.
 * @returns Float32Array of length outH * outW * 3.
 */
function bilinearResize(pixels, inH, inW, outH, outW) {
	// Map an output index to its two neighbouring source indices and blend factor.
	const locate = (outIndex, scale, inLength) => {
		const last = inLength - 1;
		let source = (outIndex + .5) * scale - .5;
		if (source < 0) source = 0;
		if (source > last) source = last;
		const near = Math.floor(source);
		return { near, far: Math.min(near + 1, last), blend: source - near };
	};

	const resized = new Float32Array(outH * outW * 3);
	const scaleH = inH / outH;
	const scaleW = inW / outW;
	for (let oy = 0; oy < outH; oy++) {
		const rowTap = locate(oy, scaleH, inH);
		const fy = rowTap.blend;
		for (let ox = 0; ox < outW; ox++) {
			const colTap = locate(ox, scaleW, inW);
			const fx = colTap.blend;
			const topLeft = (rowTap.near * inW + colTap.near) * 3;
			const topRight = (rowTap.near * inW + colTap.far) * 3;
			const bottomLeft = (rowTap.far * inW + colTap.near) * 3;
			const bottomRight = (rowTap.far * inW + colTap.far) * 3;
			const dest = (oy * outW + ox) * 3;
			for (let c = 0; c < 3; c++) {
				const upper = pixels[topLeft + c] * (1 - fx) + pixels[topRight + c] * fx;
				const lower = pixels[bottomLeft + c] * (1 - fx) + pixels[bottomRight + c] * fx;
				resized[dest + c] = upper * (1 - fy) + lower * fy;
			}
		}
	}
	return resized;
}
1995
/**
 * Turn a decoded RGB image into the flat patch tensor and grid expected by the
 * vision encoder.
 *
 * Pipeline: smartResize → bilinear resample → rescale and (x - mean) / std
 * normalise into planar CHW → patchify. Patches are emitted in spatial-merge
 * order (merge×merge blocks row-major, patches inside a block row-major) and each
 * patch is flattened as [channel][temporal copy][row][col]; the single frame is
 * repeated `temporalPatchSize` times along the temporal axis.
 *
 * @param pixels   Row-major HWC RGB, length width*height*3; 0..255 unless `rescaled`.
 * @param width    Source pixel width.
 * @param height   Source pixel height.
 * @param cfg      Processor settings (defaults to QWEN3_5_IMAGE_PROCESSOR).
 * @param rescaled True when `pixels` is already in 0..1, which skips the rescale.
 * @returns `{ patches, gridTHW: [1, gridH, gridW] }`.
 */
function preprocessImage(pixels, width, height, cfg = QWEN3_5_IMAGE_PROCESSOR, rescaled = false) {
	const { patchSize, mergeSize, temporalPatchSize } = cfg;
	const [outH, outW] = smartResize(height, width, patchSize * mergeSize, cfg.minPixels, cfg.maxPixels);
	const resized = bilinearResize(pixels, height, width, outH, outW);
	const mean = [...cfg.imageMean].slice(0, 3);
	const std = [...cfg.imageStd].slice(0, 3);
	const rescale = rescaled ? 1 : cfg.rescaleFactor;

	// Interleaved HWC → normalised planar CHW.
	const planeSize = outH * outW;
	const planar = new Float32Array(3 * planeSize);
	for (let pixel = 0; pixel < planeSize; pixel++) {
		for (let c = 0; c < 3; c++) {
			planar[c * planeSize + pixel] = (resized[pixel * 3 + c] * rescale - mean[c]) / std[c];
		}
	}

	const frames = 1;
	const gridH = outH / patchSize;
	const gridW = outW / patchSize;
	const channels = 3;
	const blockRows = gridH / mergeSize;
	const blockCols = gridW / mergeSize;
	const patchDim = channels * temporalPatchSize * patchSize * patchSize;
	const patches = new Float32Array(frames * gridH * gridW * patchDim);

	let patchIndex = 0;
	for (let frame = 0; frame < frames; frame++) {
		for (let blockRow = 0; blockRow < blockRows; blockRow++) {
			for (let blockCol = 0; blockCol < blockCols; blockCol++) {
				for (let innerRow = 0; innerRow < mergeSize; innerRow++) {
					for (let innerCol = 0; innerCol < mergeSize; innerCol++) {
						const top = (blockRow * mergeSize + innerRow) * patchSize;
						const left = (blockCol * mergeSize + innerCol) * patchSize;
						let write = patchIndex * patchDim;
						for (let c = 0; c < channels; c++) {
							// The same frame fills every temporal slot.
							for (let copy = 0; copy < temporalPatchSize; copy++) {
								for (let py = 0; py < patchSize; py++) {
									const rowStart = (top + py) * outW;
									for (let px = 0; px < patchSize; px++) {
										patches[write] = planar[c * planeSize + rowStart + (left + px)];
										write += 1;
									}
								}
							}
						}
						patchIndex += 1;
					}
				}
			}
		}
	}
	return {
		patches,
		gridTHW: [frames, gridH, gridW]
	};
}
2065
/**
 * Fill the 3D position rows for one image run starting at sequence index `at`
 * and logical position `start`.
 *
 * The run covers gt * (gh / mergeSize) * (gw / mergeSize) tokens, visited frame
 * by frame and row-major within a frame; each token gets t = start + frame,
 * h = start + row and w = start + col.
 *
 * @param rows      `{ t, h, w }` typed arrays that are written in place.
 * @param at        Sequence index of the first image token.
 * @param start     Logical position assigned to the run's origin.
 * @param grid      [t, h, w] patch grid of the image.
 * @param mergeSize Spatial merge factor.
 * @returns `{ next, cursor }`: the sequence index after the run and the logical
 *          position to continue from, start + max(h / mergeSize, w / mergeSize).
 */
function fillImagePositions(rows, at, start, grid, mergeSize) {
	const [gt, gh, gw] = grid;
	const hm = gh / mergeSize;
	const wm = gw / mergeSize;
	let i = at;
	for (let tt = 0; tt < gt; tt++) {
		for (let h = 0; h < hm; h++) {
			for (let w = 0; w < wm; w++) {
				rows.t[i] = start + tt;
				rows.h[i] = start + h;
				rows.w[i] = start + w;
				i++;
			}
		}
	}
	return {
		next: i,
		cursor: start + Math.max(hm, wm)
	};
}
/**
 * Build the 3D M-RoPE position ids [3, seq] for a token sequence containing runs
 * of `imageTokenId`.
 *
 * Text tokens advance all three (t, h, w) components together, like plain 1D
 * positions. Each image run consumes the next entry of `imageGrids` and is filled
 * by fillImagePositions; afterwards the cursor jumps by max(h_merged, w_merged).
 *
 * @param inputIds     Token ids of the sequence.
 * @param imageGrids   One [t, h, w] grid per image run, in order of appearance.
 * @param imageTokenId Token id marking image placeholder positions.
 * @param mergeSize    Spatial merge factor.
 * @returns Int32Array of length 3 * seq laid out as [T-row(seq), H-row(seq), W-row(seq)].
 * @throws Error when an image run has no matching grid, or when a grid yields no
 *         tokens (which would otherwise keep the scan from advancing).
 */
function buildMRoPEPositionIds(inputIds, imageGrids, imageTokenId, mergeSize) {
	const seq = inputIds.length;
	const rows = {
		t: new Int32Array(seq),
		h: new Int32Array(seq),
		w: new Int32Array(seq)
	};
	let i = 0;
	let cursor = 0;
	let imageIdx = 0;
	while (i < seq) {
		if (inputIds[i] !== imageTokenId) {
			rows.t[i] = cursor;
			rows.h[i] = cursor;
			rows.w[i] = cursor;
			cursor++;
			i++;
			continue;
		}
		const grid = imageGrids[imageIdx];
		if (grid == null) {
			throw new Error(`buildMRoPEPositionIds: image run at index ${i} has no grid (image #${imageIdx}, ${imageGrids.length} grids supplied)`);
		}
		imageIdx++;
		const res = fillImagePositions(rows, i, cursor, grid, mergeSize);
		if (!(res.next > i)) {
			// A grid that covers zero tokens would leave `i` unchanged forever.
			throw new Error(`buildMRoPEPositionIds: grid [${Array.from(grid).join(", ")}] yields zero image tokens`);
		}
		i = res.next;
		cursor = res.cursor;
	}
	const out = new Int32Array(3 * seq);
	out.set(rows.t, 0);
	out.set(rows.h, seq);
	out.set(rows.w, 2 * seq);
	return out;
}
2122
/**
 * Which position component (0 = T, 1 = H, 2 = W) drives each rotary frequency
 * pair under interleaved M-RoPE.
 *
 * Every slot starts as T. H then claims slots 1, 4, 7, ... and W slots 2, 5, 8, ...,
 * each only while the slot index stays below 3 × that component's section count
 * (and below the total).
 *
 * @param mropeSection [tCount, hCount, wCount].
 * @returns Int32Array of length tCount + hCount + wCount.
 */
function mropeFreqDims(mropeSection) {
	const total = mropeSection[0] + mropeSection[1] + mropeSection[2];
	const dims = new Int32Array(total);
	[1, 2].forEach((component) => {
		const limit = Math.min(mropeSection[component] * 3, total);
		for (let slot = component; slot < limit; slot += 3) {
			dims[slot] = component;
		}
	});
	return dims;
}
2138
/**
 * Interleaved M-RoPE cos/sin tables [seq, ropeDim] from 3D position ids.
 *
 * For frequency pair i, inv_freq[i] = 1 / theta^(2i / ropeDim) and the angle uses
 * the position component selected by mropeFreqDims(mropeSection)[i]. The ropeDim/2
 * values are written twice per token (first and second half of the row). When the
 * three position rows are equal this is the same as ordinary 1D rotary tables.
 *
 * @param positionIds3 Flat [3, seq] ids: T row, then H row, then W row.
 * @param seq          Sequence length.
 * @param ropeDim      Number of rotated dims per head.
 * @param theta        Rotary base.
 * @param mropeSection [tCount, hCount, wCount] pair counts per component.
 * @returns `{ cos, sin }` Float32Arrays of length seq * ropeDim.
 */
function buildMRoPECosSin(positionIds3, seq, ropeDim, theta, mropeSection) {
	const pairCount = ropeDim / 2;
	const componentOf = mropeFreqDims(mropeSection);
	const invFreq = new Float64Array(pairCount);
	for (let pair = 0; pair < pairCount; pair++) {
		invFreq[pair] = 1 / Math.pow(theta, 2 * pair / ropeDim);
	}
	const cos = new Float32Array(seq * ropeDim);
	const sin = new Float32Array(seq * ropeDim);
	for (let token = 0; token < seq; token++) {
		const firstHalf = token * ropeDim;
		const secondHalf = firstHalf + pairCount;
		for (let pair = 0; pair < pairCount; pair++) {
			const position = positionIds3[componentOf[pair] * seq + token];
			const angle = position * invFreq[pair];
			const cosine = Math.cos(angle);
			const sine = Math.sin(angle);
			cos[firstHalf + pair] = cosine;
			sin[firstHalf + pair] = sine;
			cos[secondHalf + pair] = cosine;
			sin[secondHalf + pair] = sine;
		}
	}
	return { cos, sin };
}
2173
+
2174
+ //#endregion
2175
+ //#region src/browser/device-guards.ts
2176
/** localStorage key under which the WebKit group-size probe record is stored. */
const WEBKIT_GROUP_PROBE_KEY = "gerbil-webkit-group-v1";
/**
 * Probe record used when nothing (or nothing parseable) is stored.
 * Frozen so that a caller can never corrupt the shared defaults; consumers
 * take a copy with `{ ...DEFAULT_PROBE }`.
 */
const DEFAULT_PROBE = Object.freeze({
  knownGood: 1,
  trying: null,
  capped: false
});
2182
/**
 * Load the persisted WebKit group probe record.
 *
 * Returns a fresh copy of DEFAULT_PROBE when localStorage does not exist, when
 * nothing is stored, or when reading/parsing throws. Stored fields are
 * sanitised: `knownGood` must be a number >= 1 (else 1), `trying` must be a
 * number (else null), and `capped` is true only for a literal `true`.
 */
function readGroupProbe() {
  const fallback = () => ({ ...DEFAULT_PROBE });
  if (typeof localStorage === "undefined") return fallback();
  try {
    const stored = localStorage.getItem(WEBKIT_GROUP_PROBE_KEY);
    if (!stored) return fallback();
    const { knownGood, trying, capped } = JSON.parse(stored);
    const knownGoodIsValid = typeof knownGood === "number" && knownGood >= 1;
    return {
      knownGood: knownGoodIsValid ? knownGood : 1,
      trying: typeof trying === "number" ? trying : null,
      capped: capped === true
    };
  } catch {
    // Unreadable storage or malformed JSON: behave as if nothing was stored.
    return fallback();
  }
}
2198
/**
 * Store the WebKit group probe record as JSON under WEBKIT_GROUP_PROBE_KEY.
 * Does nothing when localStorage does not exist; serialisation and storage
 * failures are deliberately ignored (best-effort persistence).
 */
function writeGroupProbe(rec) {
  const storageAvailable = typeof localStorage !== "undefined";
  if (!storageAvailable) return;
  try {
    const serialized = JSON.stringify(rec);
    localStorage.setItem(WEBKIT_GROUP_PROBE_KEY, serialized);
  } catch {}
}
2205
/**
 * Group size that non-phone WebKit devices jump to directly. An iPad sweep
 * measured 1→7.9, 8→19, 32→24.8, 64→26.6, 128→26.9 (peak), 256→26.2 tok/s: a
 * plateau from roughly 64 upward, so 128 is the chosen target (larger groups
 * only cost memory). The crash breadcrumb caps a device that cannot sustain it.
 */
const NONPHONE_TARGET_GROUP = 128;
/**
 * Pick the WebKit group size for this session. As a side effect it may persist
 * `trying`, so that a crash during this load is detectable on the next one.
 *
 * Decision order:
 *   1. A positive `args.override` wins outright (nothing is read or written).
 *   2. Not WebKit → 1 (nothing is read or written).
 *   3. Stored `trying` is non-null → the previous load never cleared it, which
 *      is taken to mean it crashed at that size. Persist
 *      `{knownGood, trying: null, capped: true}` and use `knownGood`.
 *   4. Not capped, not `args.conservative`, and `knownGood` is below
 *      NONPHONE_TARGET_GROUP → persist `trying = NONPHONE_TARGET_GROUP` and
 *      use the target.
 *   5. Otherwise use `knownGood`.
 *
 * @param args `{ override?, isWebKit, conservative? }`.
 * @returns the group size to use this session.
 */
function resolveWebkitGroupSize(args) {
  const { override, isWebKit, conservative } = args;
  if (override && override > 0) return override;
  if (!isWebKit) return 1;
  const { knownGood, trying, capped } = readGroupProbe();
  if (trying !== null) {
    // Stale breadcrumb from the previous load: cap and fall back.
    writeGroupProbe({
      knownGood,
      trying: null,
      capped: true
    });
    console.log(`[engine] webkit group probe: previous load crashed at group=${trying} → capping at knownGood=${knownGood}`);
    return knownGood;
  }
  const mayEscalate = !(capped || conservative) && knownGood < NONPHONE_TARGET_GROUP;
  if (!mayEscalate) {
    console.log(`[engine] webkit group probe: knownGood=${knownGood} trying=null capped=${capped} → using ${knownGood}`);
    return knownGood;
  }
  // Record the attempt before returning so a crash leaves the breadcrumb behind.
  writeGroupProbe({
    knownGood,
    trying: NONPHONE_TARGET_GROUP,
    capped: false
  });
  console.log(`[engine] webkit group probe: trying target group=${NONPHONE_TARGET_GROUP} (knownGood=${knownGood})`);
  return NONPHONE_TARGET_GROUP;
}
2254
/**
 * Settle a pending WebKit group probe once a first forward has finished.
 *
 * No-op when localStorage does not exist or when the stored record has no
 * `trying` value. Otherwise:
 *   - correct   → `knownGood` becomes the tried size, `trying` is cleared and
 *                 `capped` is left as stored.
 *   - incorrect → `knownGood` is kept, `trying` is cleared and the record is
 *                 marked capped.
 *
 * A page crash never reaches this function; that case is handled on the next
 * load through the `trying` value left behind in storage.
 *
 * @param correct true if the first forward produced non-corrupt output.
 */
function promoteGroupProbe(correct) {
  if (typeof localStorage === "undefined") return;
  const { knownGood, trying, capped } = readGroupProbe();
  if (trying === null) return;
  if (!correct) {
    writeGroupProbe({
      knownGood,
      trying: null,
      capped: true
    });
    console.log(`[engine] webkit group probe: group=${trying} produced INCORRECT output → capping at knownGood=${knownGood}`);
    return;
  }
  writeGroupProbe({
    knownGood: trying,
    trying: null,
    capped
  });
  console.log(`[engine] webkit group probe: PROMOTED group=${trying} to known-good`);
}
2286
+
2287
+ //#endregion
2288
+ //#region src/gpu/index.ts
2289
+ /**
2290
+ * WebGPUEngine -- gerbil's native WebGPU inference engine.
2291
+ *
2292
+ * Public API:
2293
+ * const engine = await WebGPUEngine.create({ repo: "Qwen/Qwen3.5-0.8B" });
2294
+ * const result = await engine.generate("What is 2+2?");
2295
+ * engine.destroy();
2296
+ */
2297
/**
 * EmbeddingGemma task prefixes (from config_sentence_transformers.json prompts),
 * keyed by input role: `query` or `document`. Frozen because it is a shared
 * lookup table that must not be edited at runtime.
 */
const EMBEDDING_GEMMA_PROMPTS = Object.freeze({
  query: "task: search result | query: ",
  document: "title: none | text: "
});
2302
/**
 * Find the index of the bracket that closes the `{` or `[` at `start`.
 * Only the bracket type found at `start` is counted; double-quoted strings and
 * backslash escapes inside them are skipped. Returns -1 when it never closes.
 */
function findBalancedEnd(source, start) {
  const open = source[start];
  const close = open === "{" ? "}" : "]";
  let depth = 0;
  let inString = false;
  let escaped = false;
  for (let i = start; i < source.length; i++) {
    const ch = source[i];
    if (inString) {
      if (escaped) escaped = false;
      else if (ch === "\\") escaped = true;
      else if (ch === "\"") inString = false;
      continue;
    }
    if (ch === "\"") inString = true;
    else if (ch === open) depth++;
    else if (ch === close) {
      depth--;
      if (depth === 0) return i;
    }
  }
  return -1;
}
/**
 * Extract the first complete JSON object or array from a model's text output.
 *
 * Tolerant of prose, markdown, and ```json code fences: fence markers are
 * removed, then each `{` / `[` is tried in order as the start of a balanced
 * span. The first balanced span that `JSON.parse` accepts is returned, so a
 * bracketed aside in the prose (e.g. "[sic]") no longer hides the real payload.
 *
 * Conservative fallbacks:
 *   - a balanced span that is not JSON is skipped as a whole (its inner
 *     fragments are never returned on their own);
 *   - an opener that never closes ends the scan, so a truncated object cannot
 *     yield one of its nested fragments;
 *   - when no span parses, the first balanced span is returned unchanged so the
 *     caller can still report the parse error; `undefined` when there is none.
 *
 * Used by {@link WebGPUEngine.generateObject}.
 */
function extractJson(text) {
  const cleaned = text.replace(/```(?:json)?/gi, "");
  let firstBalanced;
  let cursor = 0;
  while (cursor < cleaned.length) {
    const ch = cleaned[cursor];
    if (ch !== "{" && ch !== "[") {
      cursor++;
      continue;
    }
    const end = findBalancedEnd(cleaned, cursor);
    if (end === -1) break;
    const span = cleaned.slice(cursor, end + 1);
    if (firstBalanced === void 0) firstBalanced = span;
    try {
      JSON.parse(span);
      return span;
    } catch {
      // Balanced but not JSON: resume scanning after this span.
      cursor = end + 1;
    }
  }
  return firstBalanced;
}
2336
+ /**
2337
+ * The main WebGPU inference engine.
2338
+ *
2339
+ * Usage:
2340
+ * const engine = await WebGPUEngine.create({ repo: "Qwen/Qwen3.5-0.8B" });
2341
+ * const result = await engine.generate("Hello!");
2342
+ * console.log(result.text);
2343
+ * engine.destroy();
2344
+ */
2345
+ var WebGPUEngine = class WebGPUEngine {
2346
+ ctx;
2347
+ executor;
2348
+ tokenizer;
2349
+ _destroyed = false;
2350
+ _isEmbedding;
2351
+ /** HF architecture string (e.g. "Gemma3TextModel", "Qwen3ForCausalLM"). */
2352
+ _architecture;
2353
+ /** Vision encoder (built only when enableVision and the model is vision-capable). */
2354
+ visionExecutor;
2355
+ /** Raw vision_config (for host preprocessing of grids). */
2356
+ visionConfig;
2357
+ /** Raw pos_embed.weight table for bilinear interpolation. */
2358
+ visionPosEmbedTable;
2359
+ /** True when the LM graph was built with the multimodal (M-RoPE + splice) path. */
2360
+ _multimodalGraph;
2361
+ /** Raw config.json (for M-RoPE params: mrope_section, rope_theta, partial factor). */
2362
+ rawConfig;
2363
+ /** Effective max sequence length (cos/sin table coverage). */
2364
+ maxSeqLen;
2365
+ /** Original create() options (used to lazily spin up the Kani-TTS engine for speak()). */
2366
+ _createOptions;
2367
+ /** Lazily-created Kani-TTS engine (codec-LM + NanoCodec) backing speak(). */
2368
+ _kaniTTS = null;
2369
+ /**
2370
+ * WebKit group-size probe state. When true, a candidate group size is being
2371
+ * tried this page-load and must be promoted (or capped) after the FIRST
2372
+ * successful forward produces non-corrupt logits. Goes false once handled so
2373
+ * promotion runs at most once per session. Always false on Dawn/node.
2374
+ */
2375
+ _groupProbePending = false;
2376
+ /** Model capabilities (text, vision, moe). */
2377
+ capabilities;
2378
+ /** Model architecture config. */
2379
+ config;
2380
+ constructor(ctx, executor, tokenizer, graph, opts, vision) {
2381
+ this.ctx = ctx;
2382
+ this.executor = executor;
2383
+ this.tokenizer = tokenizer;
2384
+ this.capabilities = graph.capabilities;
2385
+ this.config = graph.config;
2386
+ this._isEmbedding = graph.outputs.includes("embedding");
2387
+ this._architecture = graph.architecture;
2388
+ this.visionExecutor = vision?.executor ?? null;
2389
+ this.visionConfig = vision?.config ?? null;
2390
+ this.visionPosEmbedTable = vision?.posEmbedTable ?? null;
2391
+ this._multimodalGraph = opts.multimodalGraph;
2392
+ this.rawConfig = opts.rawConfig;
2393
+ this.maxSeqLen = opts.maxSeqLen;
2394
+ this._createOptions = opts.createOptions;
2395
+ this._groupProbePending = opts.groupProbePending ?? false;
2396
+ }
2397
+ /** True if this engine has a vision encoder built (use encodeImage()). */
2398
+ get hasVision() {
2399
+ return this.visionExecutor !== null;
2400
+ }
2401
+ /** Per-opType decode GPU-time breakdown (only populated under GERBIL_PROFILE). */
2402
+ getDecodeProfile() {
2403
+ return this.executor.getProfile();
2404
+ }
2405
+ /** Clear accumulated decode profiler data (e.g. to drop warm-up tokens). */
2406
+ resetDecodeProfile() {
2407
+ this.executor.resetProfile();
2408
+ }
2409
+ /** Profile ONE real decode step (the pipelined-greedy kernels). Token-independent
2410
+ * timing — pass any valid id. Only meaningful under GERBIL_PROFILE. */
2411
+ async profileDecodeStep(tokenId) {
2412
+ await this.executor.profileDecodeStep(tokenId);
2413
+ }
2414
+ /** Decode dispatch count + the device's storage-buffer limit (which gates the
2415
+ * INT4 projection fusions). Lets the iPad runner report whether fusions applied
2416
+ * on-device or silently fell back (8 < 9 ⇒ more dispatches ⇒ more mobile drains). */
2417
+ getDecodeStats() {
2418
+ return {
2419
+ dispatches: this.executor.decodeDispatchCount,
2420
+ maxStorageBuffers: this.executor.maxStorageBuffers
2421
+ };
2422
+ }
2423
+ /**
2424
+ * Write a coarse crash-phase breadcrumb that survives a GPU-process kill / page
2425
+ * reload. The iPad harness reads `localStorage["gerbil-crash-phase"]` after a
2426
+ * crash; without these, a describe-time crash only shows the last load phase
2427
+ * ("engine:ready"). The describe path tags vit-encode / splice / text-decode so
2428
+ * the next run shows WHERE it died, not just "crashed after load".
2429
+ */
2430
+ setPhase(phase) {
2431
+ try {
2432
+ if (typeof localStorage !== "undefined") localStorage.setItem("gerbil-crash-phase", phase);
2433
+ } catch {}
2434
+ }
2435
+ /** True if this engine was loaded as an embedding model (use embed(), not generate()). */
2436
+ get isEmbedding() {
2437
+ return this._isEmbedding;
2438
+ }
2439
+ /**
2440
+ * WebKit group-size probe promotion hook. Runs at most once per session, after
2441
+ * the FIRST forward completes without the page dying. If the page had crashed
2442
+ * at this group size, this code never runs and the localStorage breadcrumb
2443
+ * (left by the resolver) caps the device on the next load — that is what makes
2444
+ * the probe survive the crash class. Here we additionally handle the
2445
+ * wrong-output class by inspecting the first forward's logits for corruption
2446
+ * (NaN / Inf / all-zero / all-same), reusing the same signals as integrityCheck().
2447
+ */
2448
+ maybePromoteGroupProbe(logits) {
2449
+ if (!this._groupProbePending) return;
2450
+ this._groupProbePending = false;
2451
+ const n = Math.min(logits.length, 256);
2452
+ let allZero = true;
2453
+ let allSame = true;
2454
+ let finite = true;
2455
+ const first = logits[0];
2456
+ for (let i = 0; i < n; i++) {
2457
+ const v = logits[i];
2458
+ if (!Number.isFinite(v)) {
2459
+ finite = false;
2460
+ break;
2461
+ }
2462
+ if (v !== 0) allZero = false;
2463
+ if (v !== first) allSame = false;
2464
+ }
2465
+ promoteGroupProbe(finite && !allZero && !allSame);
2466
+ }
2467
+ /**
2468
+ * Create and initialize a WebGPUEngine.
2469
+ *
2470
+ * Downloads the model from HuggingFace, compiles shaders, uploads weights.
2471
+ */
2472
+ static async create(options = {}) {
2473
+ options = {
2474
+ ...options,
2475
+ repo: resolveDefaultRepo(options)
2476
+ };
2477
+ const ctx = await initGPU();
2478
+ const isBrowser = typeof navigator !== "undefined" && typeof location !== "undefined";
2479
+ const isSafari = isBrowser && /Safari/.test(navigator.userAgent) && !/Chrome/.test(navigator.userAgent);
2480
+ const params = isBrowser ? new URLSearchParams(location.search) : null;
2481
+ const forceKvF32 = params?.has("kvf32");
2482
+ const maxSeqOverride = params?.get("maxseq");
2483
+ const groupOverride = params?.get("group");
2484
+ let kvMode;
2485
+ if (options.kvMode) kvMode = options.kvMode;
2486
+ else if (forceKvF32) kvMode = "f32";
2487
+ else if (isSafari && ctx.hasF16) kvMode = "packed-f16";
2488
+ else if (ctx.hasF16) kvMode = "native-f16";
2489
+ else kvMode = "f32";
2490
+ const kvDtype = kvMode === "f32" ? "f32" : "f16";
2491
+ console.log(`[engine] kvMode: ${kvMode}, kvDtype: ${kvDtype}, f16 supported: ${ctx.hasF16}, safari: ${isSafari}`);
2492
+ const setPhase = (phase) => {
2493
+ try {
2494
+ if (typeof localStorage !== "undefined") localStorage.setItem("gerbil-crash-phase", phase);
2495
+ } catch {}
2496
+ };
2497
+ setPhase("engine:loading-model");
2498
+ const maxVisionPatches = options.maxVisionPatches ?? (ctx.isWebKitWebGPU ? 1024 : 4096);
2499
+ const multimodal = options.enableVision ? { maxVisionTokens: Math.ceil(maxVisionPatches / 4) } : void 0;
2500
+ const { graph, tokenizer, weights, rawConfig, pleSource } = await loadModel({
2501
+ ...options,
2502
+ repo: resolveDefaultRepo(options),
2503
+ kvDtype,
2504
+ multimodal
2505
+ });
2506
+ setPhase(`engine:model-loaded:${weights.size}-weights`);
2507
+ let visionBundle = null;
2508
+ const hasVisionConfig = rawConfig.vision_config != null;
2509
+ const isGemma4Vision = rawConfig.vision_config?.model_type === "gemma4_vision" || [...weights.keys()].some((k) => k.startsWith("vision_tower."));
2510
+ const visPosKey = isGemma4Vision ? "vision_tower.patch_embedder.position_embedding_table" : "visual.pos_embed.weight";
2511
+ const towerPrefix = isGemma4Vision ? "vision_tower." : "visual.";
2512
+ const hasVisualWeights = weights.keys().some((k) => k.startsWith(towerPrefix));
2513
+ if (Boolean(options.enableVision && hasVisionConfig) && hasVisualWeights) {
2514
+ const visGraph = isGemma4Vision ? generateGemma4VisionGraph(rawConfig) : generateQwen3_5VisionGraph(rawConfig);
2515
+ const maxPatches = maxVisionPatches;
2516
+ const visWeights = /* @__PURE__ */ new Map();
2517
+ for (const k of weights.keys()) if (k.startsWith(towerPrefix) || k.startsWith("embed_vision.") || k === visPosKey) {
2518
+ const w = await weights.get(k);
2519
+ if (w) visWeights.set(k, w);
2520
+ }
2521
+ if (isGemma4Vision) {
2522
+ patchGemma4VisionClips(visGraph, visWeights);
2523
+ const gi = resolveGemma4VisionInfo(rawConfig);
2524
+ dequantizeGemma4VisionProjection(visWeights, (rawConfig.quantization_config ?? rawConfig.quantization)?.group_size ?? 64, gi.textHidden, gi.hiddenSize);
2525
+ ensureGemma4VisionEmbedderNorms(visWeights, gi.hiddenSize, gi.textHidden);
2526
+ }
2527
+ const visExec = new VisionExecutor(ctx, visGraph, maxPatches);
2528
+ const posW = visWeights.get(visPosKey);
2529
+ const posTable = posW && posW.data instanceof Float32Array ? new Float32Array(posW.data) : posW ? new Float32Array(posW.data.buffer, posW.data.byteOffset, posW.data.byteLength / 4).slice() : null;
2530
+ await visExec.uploadWeights(visWeights);
2531
+ visExec.initBindGroups();
2532
+ if (posTable) visionBundle = {
2533
+ executor: visExec,
2534
+ config: rawConfig.vision_config ?? rawConfig,
2535
+ posEmbedTable: posTable
2536
+ };
2537
+ }
2538
+ let maxSeqLen;
2539
+ if (maxSeqOverride) maxSeqLen = Math.min(Number.parseInt(maxSeqOverride, 10), graph.config.context_length);
2540
+ else if (forceKvF32 && isSafari) maxSeqLen = Math.min(options.maxSeqLen ?? 1024, graph.config.context_length, 1024);
2541
+ else if (ctx.isWebKitWebGPU) maxSeqLen = Math.min(options.maxSeqLen ?? 512, graph.config.context_length, 2048);
2542
+ else maxSeqLen = Math.min(options.maxSeqLen ?? graph.config.context_length, graph.config.context_length, 4096);
2543
+ console.log(`[engine] maxSeqLen: ${maxSeqLen}, architecture: ${graph.architecture}`);
2544
+ if (ctx.isWebKitWebGPU) console.log(`[engine] device limits: maxBufferSize=${ctx.limits.maxBufferSize}, maxStorageBufferBindingSize=${ctx.limits.maxStorageBufferBindingSize}, maxComputeWorkgroupStorageSize=${ctx.limits.maxComputeWorkgroupStorageSize}`);
2545
+ setPhase("engine:allocating-buffers");
2546
+ ctx.device.pushErrorScope("out-of-memory");
2547
+ ctx.device.pushErrorScope("validation");
2548
+ const groupOverrideNum = groupOverride ? Number.parseInt(groupOverride, 10) : void 0;
2549
+ let webkitGroupSize;
2550
+ let probingGroup = false;
2551
+ if (groupOverrideNum) {
2552
+ webkitGroupSize = groupOverrideNum;
2553
+ console.log(`[engine] webkitGroupSize override: ${webkitGroupSize} dispatches/command buffer`);
2554
+ } else if (ctx.isWebKitWebGPU) {
2555
+ webkitGroupSize = resolveWebkitGroupSize({ isWebKit: true });
2556
+ probingGroup = true;
2557
+ }
2558
+ const executor = new Executor(ctx, graph, {
2559
+ maxSeqLen,
2560
+ kvMode,
2561
+ webkitGroupSize
2562
+ });
2563
+ if (pleSource) executor.setPleSource(pleSource);
2564
+ setPhase("engine:uploading-weights");
2565
+ await executor.uploadWeights(weights);
2566
+ await weights.dispose?.();
2567
+ setPhase("engine:compiling-shaders");
2568
+ executor.initBindGroups();
2569
+ const validationError = await ctx.device.popErrorScope();
2570
+ const oomError = await ctx.device.popErrorScope();
2571
+ if (oomError || validationError) {
2572
+ const detail = [oomError ? `out-of-memory: ${oomError.message}` : null, validationError ? `validation: ${validationError.message}` : null].filter(Boolean).join("; ");
2573
+ setPhase(`engine:gpu-error:${detail.slice(0, 120)}`);
2574
+ throw new Error(`GPU setup failed (${detail}). The model likely exceeds this device's memory budget — try a smaller maxSeqLen or a q4-quantized model.`);
2575
+ }
2576
+ setPhase("engine:ready");
2577
+ return new WebGPUEngine(ctx, executor, tokenizer, graph, {
2578
+ multimodalGraph: Boolean(multimodal),
2579
+ rawConfig,
2580
+ maxSeqLen,
2581
+ createOptions: options,
2582
+ groupProbePending: probingGroup
2583
+ }, visionBundle);
2584
+ }
2585
+ /**
2586
+ * Encode an image (already preprocessed into patches) into merged
2587
+ * image-embedding tokens of dim `out_hidden_size` (1024 for Qwen3.5).
2588
+ *
2589
+ * This is the VISION ENCODER ONLY — it returns the image tokens; it does not
2590
+ * splice them into a text sequence or apply M-RoPE (that is the LM-side
2591
+ * integration phase). Requires `enableVision: true` at create() on a
2592
+ * vision-capable checkpoint.
2593
+ *
2594
+ * @param patches Flattened patches, row-major [numPatches, patch_dim].
2595
+ * patch_dim = in_channels * temporal_patch_size * patch_size^2 (1536 for Qwen3.5).
2596
+ * Patches must already be ordered in spatial_merge_size×spatial_merge_size
2597
+ * groups (as the HF image processor emits them).
2598
+ * @param gridTHW The (temporal, height, width) patch-grid dims for the image.
2599
+ * numPatches must equal t*h*w.
2600
+ */
2601
+ async encodeImage(patches, gridTHW, onStage) {
2602
+ this.checkDestroyed();
2603
+ if (!this.visionExecutor || !this.visionConfig || !this.visionPosEmbedTable) throw new Error("encodeImage() requires a vision encoder. Load with { enableVision: true } on a vision-capable checkpoint (e.g. Qwen/Qwen3.5-0.8B).");
2604
+ const vcfg = this.visionConfig;
2605
+ if (this.visionExecutor.gemma4) {
2606
+ const info = resolveGemma4VisionInfo({ vision_config: vcfg });
2607
+ const gridH = gridTHW[1];
2608
+ const gridW = gridTHW[2];
2609
+ const posSize = Math.floor(this.visionPosEmbedTable.length / (2 * info.hiddenSize));
2610
+ const host$1 = buildGemma4VisionPositionTensors(gridH, gridW, this.visionPosEmbedTable, posSize, {
2611
+ hiddenSize: info.hiddenSize,
2612
+ numHeads: info.numHeads,
2613
+ headDim: info.headDim,
2614
+ ropeTheta: info.ropeTheta,
2615
+ poolingKernelSize: info.poolingKernelSize
2616
+ });
2617
+ return this.visionExecutor.encodeGemma4({
2618
+ patches,
2619
+ posEmbeds: host$1.posEmbeds,
2620
+ cos: host$1.cos,
2621
+ sin: host$1.sin,
2622
+ poolMatrix: host$1.poolMatrix,
2623
+ numPatches: host$1.numPatches,
2624
+ numPooled: host$1.numPooled
2625
+ }, onStage);
2626
+ }
2627
+ const hiddenSize = vcfg.hidden_size;
2628
+ const numHeads = vcfg.num_heads;
2629
+ const host = buildVisionPositionTensors(gridTHW, this.visionPosEmbedTable, {
2630
+ hiddenSize,
2631
+ numHeads,
2632
+ numPositionEmbeddings: vcfg.num_position_embeddings,
2633
+ spatialMergeSize: vcfg.spatial_merge_size,
2634
+ ropeTheta: 1e4
2635
+ });
2636
+ const numPatches = gridTHW[0] * gridTHW[1] * gridTHW[2];
2637
+ return this.visionExecutor.encode({
2638
+ patches,
2639
+ posEmbeds: host.posEmbeds,
2640
+ cos: host.cos,
2641
+ sin: host.sin,
2642
+ numPatches
2643
+ }, onStage);
2644
+ }
2645
+ /** Resolve M-RoPE params from rawConfig: rope_dim, theta, mrope_section. */
2646
+ mropeParams() {
2647
+ const cfg = this.rawConfig ?? {};
2648
+ const rope = (cfg.text_config ?? cfg).rope_parameters ?? {};
2649
+ const headDim = this.config.head_dim;
2650
+ const partial = rope.partial_rotary_factor ?? .25;
2651
+ const ropeDim = Math.floor(headDim * partial);
2652
+ const theta = rope.rope_theta ?? this.config.rope_base ?? 1e4;
2653
+ const section = rope.mrope_section ?? [
2654
+ 11,
2655
+ 11,
2656
+ 10
2657
+ ];
2658
+ const visCfg = cfg.vision_config ?? {};
2659
+ return {
2660
+ ropeDim,
2661
+ theta,
2662
+ section,
2663
+ imageTokenId: cfg.image_token_id ?? 248056,
2664
+ mergeSize: visCfg.spatial_merge_size ?? 2
2665
+ };
2666
+ }
2667
+ /**
2668
+ * Write the M-RoPE cos/sin (token order) + image row-map for a prefill of
2669
+ * `positionIds3` ([3, seq]). `rowMap[i]` = vision-buffer row for image tokens,
2670
+ * -1 for text. Returns the logical position of the last token (for decode).
2671
+ */
2672
+ writeMRoPEPrefill(positionIds3, seq, rowMap) {
2673
+ const { ropeDim, theta, section } = this.mropeParams();
2674
+ const { cos, sin } = buildMRoPECosSin(positionIds3, seq, ropeDim, theta, section);
2675
+ this.executor.writeInput("mrope_cos", cos);
2676
+ this.executor.writeInput("mrope_sin", sin);
2677
+ this.executor.writeInput("vision_row_map", rowMap);
2678
+ const last = seq - 1;
2679
+ return Math.max(positionIds3[last], positionIds3[seq + last], positionIds3[2 * seq + last]);
2680
+ }
2681
+ /**
2682
+ * Write a single decode-step M-RoPE cos/sin row at table slot `seqPos` for a
2683
+ * text token at logical position `logicalPos`, plus a -1 row-map entry.
2684
+ */
2685
+ writeMRoPEDecodeStep(seqPos, logicalPos) {
2686
+ const { ropeDim, theta, section } = this.mropeParams();
2687
+ const pid = new Int32Array(3);
2688
+ pid[0] = logicalPos;
2689
+ pid[1] = logicalPos;
2690
+ pid[2] = logicalPos;
2691
+ const { cos, sin } = buildMRoPECosSin(pid, 1, ropeDim, theta, section);
2692
+ const rowBytes = ropeDim * 4;
2693
+ this.executor.writeInputAt("mrope_cos", cos, seqPos * rowBytes);
2694
+ this.executor.writeInputAt("mrope_sin", sin, seqPos * rowBytes);
2695
+ }
2696
+ /** Write linear-position M-RoPE inputs for a pure-text forward (no image). */
2697
+ writeMRoPELinearText(seq) {
2698
+ const { ropeDim, theta, section } = this.mropeParams();
2699
+ const pid = new Int32Array(3 * seq);
2700
+ for (let i = 0; i < seq; i++) {
2701
+ pid[i] = i;
2702
+ pid[seq + i] = i;
2703
+ pid[2 * seq + i] = i;
2704
+ }
2705
+ const { cos, sin } = buildMRoPECosSin(pid, seq, ropeDim, theta, section);
2706
+ this.executor.writeInput("mrope_cos", cos);
2707
+ this.executor.writeInput("mrope_sin", sin);
2708
+ const rowMap = new Int32Array(seq).fill(-1);
2709
+ this.executor.writeInput("vision_row_map", rowMap);
2710
+ }
2711
+ /**
2712
+ * Generate text from a prompt.
2713
+ */
2714
+ async generate(prompt, options = {}) {
2715
+ this.checkDestroyed();
2716
+ const { maxTokens = 512, stopSequences = [], sampling = {}, systemPrompt, onToken } = options;
2717
+ this.executor.reset();
2718
+ let inputIds;
2719
+ if (typeof prompt === "string") {
2720
+ const messages = [];
2721
+ if (systemPrompt) messages.push({
2722
+ role: "system",
2723
+ content: systemPrompt
2724
+ });
2725
+ messages.push({
2726
+ role: "user",
2727
+ content: prompt
2728
+ });
2729
+ inputIds = this.tokenizer.encodeChat(messages, { addGenerationPrompt: true });
2730
+ } else {
2731
+ const messages = systemPrompt ? [{
2732
+ role: "system",
2733
+ content: systemPrompt
2734
+ }, ...prompt] : prompt;
2735
+ inputIds = this.tokenizer.encodeChat(messages, { addGenerationPrompt: true });
2736
+ }
2737
+ const startTime = performance.now();
2738
+ const isGreedy = (sampling.temperature ?? .7) < 1e-6;
2739
+ const hasMRoPE = this._multimodalGraph && this.executor.hasBuffer("mrope_cos");
2740
+ if (this._multimodalGraph) {
2741
+ if (hasMRoPE) this.writeMRoPELinearText(inputIds.length);
2742
+ else if (this.executor.hasBuffer("vision_row_map")) this.executor.writeInput("vision_row_map", new Int32Array(inputIds.length).fill(-1));
2743
+ }
2744
+ let { logits } = await this.executor.forward(new Uint32Array(inputIds));
2745
+ this.maybePromoteGroupProbe(logits);
2746
+ const generatedIds = [];
2747
+ let finishReason = "max_tokens";
2748
+ let generatedText = "";
2749
+ const eosId = this.tokenizer.config.eosTokenId;
2750
+ const consumeToken = (nextToken) => {
2751
+ generatedIds.push(nextToken);
2752
+ if (eosId !== null && nextToken === eosId) {
2753
+ finishReason = "eos";
2754
+ return true;
2755
+ }
2756
+ const tokenText = this.tokenizer.decode([nextToken], true);
2757
+ generatedText += tokenText;
2758
+ onToken?.(tokenText);
2759
+ if (stopSequences.some((s) => generatedText.includes(s))) {
2760
+ for (const s of stopSequences) {
2761
+ const idx = generatedText.indexOf(s);
2762
+ if (idx !== -1) generatedText = generatedText.slice(0, idx);
2763
+ }
2764
+ finishReason = "stop_sequence";
2765
+ return true;
2766
+ }
2767
+ return false;
2768
+ };
2769
+ const mmDecode = hasMRoPE;
2770
+ let mmLogicalPos = inputIds.length;
2771
+ if (isGreedy && !this.executor.needsMultiEncoder && !mmDecode) {
2772
+ const firstToken = sampleToken(logits, sampling, [...inputIds, ...generatedIds]);
2773
+ if (!consumeToken(firstToken)) {
2774
+ const depth = Executor.PIPELINE_DEPTH;
2775
+ const stepsNeeded = Math.min(maxTokens - 1, this.executor.decodeCapacityRemaining());
2776
+ let submitted = 0;
2777
+ let consumed = 0;
2778
+ while (consumed < stepsNeeded) {
2779
+ while (submitted < stepsNeeded && submitted < consumed + depth) {
2780
+ this.executor.submitGreedyDecodeStep(submitted === 0 ? firstToken : null, submitted % depth);
2781
+ submitted++;
2782
+ }
2783
+ const tok = await this.executor.readDecodeToken(consumed % depth);
2784
+ consumed++;
2785
+ if (consumeToken(tok)) break;
2786
+ }
2787
+ }
2788
+ } else for (let step = 0; step < maxTokens; step++) {
2789
+ let nextToken;
2790
+ if (step === 0 || !isGreedy) nextToken = sampleToken(logits, sampling, [...inputIds, ...generatedIds]);
2791
+ else {
2792
+ if (mmDecode) this.writeMRoPEDecodeStep(this.executor.currentSeqPos, mmLogicalPos);
2793
+ nextToken = await this.executor.forwardArgmax(new Uint32Array([generatedIds[generatedIds.length - 1]]));
2794
+ mmLogicalPos++;
2795
+ }
2796
+ if (consumeToken(nextToken)) break;
2797
+ if (!isGreedy) {
2798
+ if (mmDecode) this.writeMRoPEDecodeStep(this.executor.currentSeqPos, mmLogicalPos);
2799
+ logits = (await this.executor.forward(new Uint32Array([nextToken]))).logits;
2800
+ mmLogicalPos++;
2801
+ }
2802
+ }
2803
+ const totalTime = performance.now() - startTime;
2804
+ const tokensGenerated = generatedIds.length;
2805
+ const tokensPerSecond = tokensGenerated / (totalTime / 1e3);
2806
+ return {
2807
+ text: generatedText,
2808
+ tokensGenerated,
2809
+ tokensPerSecond,
2810
+ totalTime,
2811
+ finishReason
2812
+ };
2813
+ }
2814
+ /**
2815
+ * Generate a STRUCTURED object: generate text, extract the first JSON
2816
+ * object/array, parse it, validate it, and RETRY until it is valid (on-device
2817
+ * tokens are free, so re-rolling a malformed JSON is cheap).
2818
+ *
2819
+ * Extraction is tolerant: prose, markdown, and ```json code fences are
2820
+ * stripped, then the outermost balanced `{...}` or `[...]` is matched and
2821
+ * `JSON.parse`d. Validation is one of:
2822
+ * - a predicate `(o) => boolean` (return false to reject),
2823
+ * - a minimal JSON-schema-ish object with `required` (those keys must exist),
2824
+ * - nothing (only valid JSON is required).
2825
+ *
2826
+ * On each retry the prompt is nudged with a terse "return ONLY valid JSON…"
2827
+ * instruction (including the required-key shape when known). Throws a clear
2828
+ * error if it never validates within `maxRetries + 1` attempts.
2829
+ *
2830
+ * ```ts
2831
+ * const { object } = await engine.generateObject(
2832
+ * 'Extract {name, age} from: "I am Sarah, 28"',
2833
+ * { schema: { required: ["name", "age"] } },
2834
+ * );
2835
+ * // object === { name: "Sarah", age: 28 }
2836
+ * ```
2837
+ *
2838
+ * @typeParam T Expected object type (not enforced at runtime — validate via schema).
2839
+ */
2840
+ async generateObject(prompt, options = {}) {
2841
+ this.checkDestroyed();
2842
+ const { schema, maxRetries = 4, ...generateOpts } = options;
2843
+ const validate = (value) => {
2844
+ if (typeof schema === "function") return schema(value);
2845
+ if (schema && typeof schema === "object" && Array.isArray(schema.required)) {
2846
+ if (value === null || typeof value !== "object") return false;
2847
+ const obj = value;
2848
+ return schema.required.every((key) => key in obj);
2849
+ }
2850
+ return true;
2851
+ };
2852
+ const nudge = `\n\nReturn ONLY valid JSON matching ${schema && typeof schema === "object" && Array.isArray(schema.required) ? `{ ${schema.required.join(", ")} }` : "the requested shape"}. No prose, no markdown, no code fences.`;
2853
+ const attemptsMax = Math.max(0, maxRetries) + 1;
2854
+ let lastText = "";
2855
+ let lastError = "";
2856
+ for (let attempt = 1; attempt <= attemptsMax; attempt++) {
2857
+ const promptForAttempt = attempt === 1 ? prompt : prompt + nudge;
2858
+ const result = await this.generate(promptForAttempt, generateOpts);
2859
+ lastText = result.text;
2860
+ const parsed = extractJson(result.text);
2861
+ if (parsed === void 0) {
2862
+ lastError = "no JSON object/array found in output";
2863
+ continue;
2864
+ }
2865
+ let value;
2866
+ try {
2867
+ value = JSON.parse(parsed);
2868
+ } catch (e) {
2869
+ lastError = `JSON.parse failed: ${e instanceof Error ? e.message : String(e)}`;
2870
+ continue;
2871
+ }
2872
+ if (!validate(value)) {
2873
+ lastError = "parsed JSON failed schema validation";
2874
+ continue;
2875
+ }
2876
+ return {
2877
+ object: value,
2878
+ text: result.text,
2879
+ attempts: attempt
2880
+ };
2881
+ }
2882
+ throw new Error(`generateObject failed after ${attemptsMax} attempt(s): ${lastError}. Last output: ${JSON.stringify(lastText.slice(0, 200))}`);
2883
+ }
2884
+ /**
2885
+ * Text-to-speech: text → 22 kHz PCM via Kani-TTS-2 (LFM2-350M codec-LM + NVIDIA
2886
+ * NeMo NanoCodec). Returns `{ pcm: Float32Array, sampleRate: 22050 }`.
2887
+ *
2888
+ * Runs the full pipeline: the codec-LM backbone autoregressively emits NanoCodec
2889
+ * audio tokens (4 per frame, frame-level positions + learnable per-layer RoPE),
2890
+ * then the bit-exact NanoCodec decoder (FSQ + causal HiFi-GAN) turns the codes
2891
+ * into PCM. The heavy lifting lives in {@link KaniTTS} (src/gpu/kani-tts.ts); this
2892
+ * lazily constructs that engine on first use (downloading the NanoCodec codec
2893
+ * checkpoint alongside the backbone).
2894
+ *
2895
+ * Requires a Kani-TTS-2 checkpoint (architecture "KaniTTS2ForCausalLM").
2896
+ */
2897
+ async speak(text, options = {}) {
2898
+ this.checkDestroyed();
2899
+ if (this._architecture !== "KaniTTS2ForCausalLM") throw new Error(`speak() requires a Kani-TTS-2 model (architecture "KaniTTS2ForCausalLM"), loaded engine is "${this._architecture}".`);
2900
+ if (!this._kaniTTS) this._kaniTTS = await KaniTTS.create({
2901
+ repo: this._createOptions.repo,
2902
+ revision: this._createOptions.revision,
2903
+ hfToken: this._createOptions.hfToken,
2904
+ cacheDir: this._createOptions.cacheDir,
2905
+ maxSeqLen: this.maxSeqLen
2906
+ });
2907
+ return this._kaniTTS.speak(text, options);
2908
+ }
2909
+ /**
2910
+ * Describe an image: image-in → text-out. Runs the vision encoder, splices the
2911
+ * merged image tokens into a text prompt, applies multimodal M-RoPE, and
2912
+ * generates a description. Requires `enableVision: true` at create().
2913
+ *
2914
+ * Image input forms:
2915
+ * - `{ pixels, width, height }` — decoded RGB (HWC, 0..255), host-preprocessed
2916
+ * (smart-resize/normalize/patchify) to match the HF image processor.
2917
+ * - `{ patches, gridTHW }` — already-built [N,1536] patch tensor + grid (e.g.
2918
+ * HF-exact pixel_values from a reference; skips host preprocessing).
2919
+ */
2920
+ async describeImage(image, prompt = "Describe this image.", options = {}) {
2921
+ this.checkDestroyed();
2922
+ if (!this.visionExecutor) throw new Error("describeImage() requires a vision encoder. Load with { enableVision: true } on a vision-capable checkpoint (Qwen3.5 or Gemma 4).");
2923
+ if (this.visionExecutor.gemma4) {
2924
+ if (!this._multimodalGraph) throw new Error("Gemma 4 describeImage() requires the multimodal text graph. Load with { enableVision: true } so generateGemma4Graph builds the vision_embeds + EmbedSplice variant.");
2925
+ let g4patches;
2926
+ let g4gridHW;
2927
+ if ("patches" in image) {
2928
+ g4patches = image.patches;
2929
+ g4gridHW = [image.gridTHW[1], image.gridTHW[2]];
2930
+ } else {
2931
+ const pre$1 = preprocessImageGemma4(image.pixels, image.width, image.height);
2932
+ g4patches = pre$1.patches;
2933
+ g4gridHW = pre$1.gridHW;
2934
+ }
2935
+ const gThw = [
2936
+ 1,
2937
+ g4gridHW[0],
2938
+ g4gridHW[1]
2939
+ ];
2940
+ const numPatches$1 = gThw[0] * gThw[1] * gThw[2];
2941
+ this.setPhase(`describe:vit-encode:N=${numPatches$1}`);
2942
+ const vision$1 = await this.encodeImage(g4patches, gThw, (stage, info) => {
2943
+ const suffix = info?.layer != null ? `:L${info.layer}` : "";
2944
+ this.setPhase(`describe:${stage}${suffix}`);
2945
+ });
2946
+ const cfg = this.rawConfig ?? {};
2947
+ const imageTokenId$1 = cfg.image_token_id ?? 258880;
2948
+ const boiId = this.tokenizer.tokenToId("<|image>") ?? cfg.boi_token_id ?? 255999;
2949
+ const eoiId = this.tokenizer.tokenToId("<image|>") ?? cfg.eoi_token_id ?? 258882;
2950
+ const bos = this.tokenizer.config.bosToken ? this.tokenizer.tokenToId(this.tokenizer.config.bosToken) ?? [] : [];
2951
+ const bosIds = Array.isArray(bos) ? bos : [bos];
2952
+ const turnUserOpen = this.tokenizer.encode("<|turn>user\n");
2953
+ const imgLead = this.tokenizer.encode("\n\n");
2954
+ const imgTrail = this.tokenizer.encode("\n\n");
2955
+ const promptIds = this.tokenizer.encode(prompt);
2956
+ const turnClose = this.tokenizer.encode("<turn|>\n<|turn>model\n");
2957
+ const imageRun$1 = new Array(vision$1.rows).fill(imageTokenId$1);
2958
+ const inputIds$1 = [
2959
+ ...bosIds,
2960
+ ...turnUserOpen,
2961
+ ...imgLead,
2962
+ boiId,
2963
+ ...imageRun$1,
2964
+ eoiId,
2965
+ ...imgTrail,
2966
+ ...promptIds,
2967
+ ...turnClose
2968
+ ];
2969
+ return this.runMultimodalGemma4(vision$1.embeds, inputIds$1, imageTokenId$1, options);
2970
+ }
2971
+ if (!this._multimodalGraph) throw new Error("describeImage() requires a multimodal engine. Load with { enableVision: true } on a vision-capable checkpoint (e.g. Qwen/Qwen3.5-0.8B).");
2972
+ const procCfg = options.imageProcessor ?? QWEN3_5_IMAGE_PROCESSOR;
2973
+ let patches;
2974
+ let gridTHW;
2975
+ if ("patches" in image) {
2976
+ patches = image.patches;
2977
+ gridTHW = image.gridTHW;
2978
+ } else {
2979
+ const pre$1 = preprocessImage(image.pixels, image.width, image.height, procCfg);
2980
+ patches = pre$1.patches;
2981
+ gridTHW = pre$1.gridTHW;
2982
+ }
2983
+ const numPatches = gridTHW[0] * gridTHW[1] * gridTHW[2];
2984
+ this.setPhase(`describe:vit-encode:N=${numPatches}`);
2985
+ const vision = await this.encodeImage(patches, gridTHW, (stage, info) => {
2986
+ const suffix = info?.layer != null ? `:L${info.layer}` : "";
2987
+ this.setPhase(`describe:${stage}${suffix}`);
2988
+ });
2989
+ const visStart = this.tokenizer.tokenToId("<|vision_start|>") ?? 248053;
2990
+ const visEnd = this.tokenizer.tokenToId("<|vision_end|>") ?? 248054;
2991
+ const { imageTokenId } = this.mropeParams();
2992
+ const numImageTokens = vision.rows;
2993
+ const pre = this.tokenizer.encode("<|im_start|>user\n");
2994
+ const post = this.tokenizer.encode(`${prompt}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n`);
2995
+ const imageRun = new Array(numImageTokens).fill(imageTokenId);
2996
+ const inputIds = [
2997
+ ...pre,
2998
+ visStart,
2999
+ ...imageRun,
3000
+ visEnd,
3001
+ ...post
3002
+ ];
3003
+ return this.runMultimodal(vision.embeds, gridTHW, inputIds, options);
3004
+ }
3005
+ /**
3006
+ * Prepare the multimodal prefill: upload vision embeds, build the image row-map
3007
+ * and 3D M-RoPE cos/sin, reset state, and write all host inputs. Returns the
3008
+ * input ids and the post-image logical cursor for decode. Does NOT run forward.
3009
+ */
3010
+ prepareMultimodalPrefill(visionEmbeds, gridTHW, inputIds) {
3011
+ const { imageTokenId, mergeSize } = this.mropeParams();
3012
+ const seq = inputIds.length;
3013
+ if (seq > this.maxSeqLen) throw new Error(`describeImage: prompt+image is ${seq} tokens > maxSeqLen ${this.maxSeqLen}. Increase maxSeqLen or use a smaller image.`);
3014
+ this.executor.reset();
3015
+ this.executor.writeInput("vision_embeds", visionEmbeds);
3016
+ const rowMap = new Int32Array(seq).fill(-1);
3017
+ let v = 0;
3018
+ for (let i = 0; i < seq; i++) if (inputIds[i] === imageTokenId) rowMap[i] = v++;
3019
+ const positionIds3 = buildMRoPEPositionIds(inputIds, [gridTHW], imageTokenId, mergeSize);
3020
+ return { lastLogicalPos: this.writeMRoPEPrefill(positionIds3, seq, rowMap) };
3021
+ }
3022
+ /**
3023
+ * Gemma 4 multimodal prefill + decode. Unlike Qwen3.5 (M-RoPE), Gemma 4 uses
3024
+ * STANDARD sequential 1D RoPE computed inside each layer from the KV write
3025
+ * position, so there are no host cos/sin inputs and decode positions are simply
3026
+ * the running seqPos — identical to plain text generation. We only upload the
3027
+ * merged vision embeds + an image-token row-map (EmbedSplice scatters them into
3028
+ * the image_token rows) before the forward pass.
3029
+ */
3030
+ async runMultimodalGemma4(visionEmbeds, inputIds, imageTokenId, options) {
3031
+ const seq = inputIds.length;
3032
+ if (seq > this.maxSeqLen) throw new Error(`describeImage: prompt+image is ${seq} tokens > maxSeqLen ${this.maxSeqLen}. Increase maxSeqLen or use a smaller image.`);
3033
+ this.setPhase(`describe:splice:seq=${seq}`);
3034
+ this.executor.reset();
3035
+ this.executor.writeInput("vision_embeds", visionEmbeds);
3036
+ const rowMap = new Int32Array(seq).fill(-1);
3037
+ let v = 0;
3038
+ for (let i = 0; i < seq; i++) if (inputIds[i] === imageTokenId) rowMap[i] = v++;
3039
+ this.executor.writeInput("vision_row_map", rowMap);
3040
+ this.setPhase("describe:text-decode");
3041
+ const { maxTokens = 512, stopSequences = [], sampling = {}, onToken } = options;
3042
+ const startTime = performance.now();
3043
+ const isGreedy = (sampling.temperature ?? .7) < 1e-6;
3044
+ let { logits } = await this.executor.forward(new Uint32Array(inputIds));
3045
+ const generatedIds = [];
3046
+ let finishReason = "max_tokens";
3047
+ let generatedText = "";
3048
+ const eosId = this.tokenizer.config.eosTokenId;
3049
+ const eotId = this.tokenizer.tokenToId("<turn|>");
3050
+ const consumeToken = (nextToken) => {
3051
+ generatedIds.push(nextToken);
3052
+ if (eosId !== null && nextToken === eosId || eotId !== null && nextToken === eotId) {
3053
+ finishReason = "eos";
3054
+ return true;
3055
+ }
3056
+ const tokenText = this.tokenizer.decode([nextToken], true);
3057
+ generatedText += tokenText;
3058
+ onToken?.(tokenText);
3059
+ if (stopSequences.some((s) => generatedText.includes(s))) {
3060
+ for (const s of stopSequences) {
3061
+ const idx = generatedText.indexOf(s);
3062
+ if (idx !== -1) generatedText = generatedText.slice(0, idx);
3063
+ }
3064
+ finishReason = "stop_sequence";
3065
+ return true;
3066
+ }
3067
+ return false;
3068
+ };
3069
+ for (let step = 0; step < maxTokens; step++) {
3070
+ let nextToken;
3071
+ if (step === 0) nextToken = sampleToken(logits, sampling, [...inputIds, ...generatedIds]);
3072
+ else if (isGreedy) nextToken = await this.executor.forwardArgmax(new Uint32Array([generatedIds[generatedIds.length - 1]]));
3073
+ else nextToken = sampleToken(logits, sampling, [...inputIds, ...generatedIds]);
3074
+ if (consumeToken(nextToken)) break;
3075
+ if (!isGreedy) logits = (await this.executor.forward(new Uint32Array([nextToken]))).logits;
3076
+ }
3077
+ const totalTime = performance.now() - startTime;
3078
+ const tokensGenerated = generatedIds.length;
3079
+ return {
3080
+ text: generatedText,
3081
+ tokensGenerated,
3082
+ tokensPerSecond: tokensGenerated / (totalTime / 1e3),
3083
+ totalTime,
3084
+ finishReason
3085
+ };
3086
+ }
3087
+ /** Prepare + prefill + decode for a fully-specified multimodal token sequence. */
3088
+ async runMultimodal(visionEmbeds, gridTHW, inputIds, options) {
3089
+ this.setPhase(`describe:splice:seq=${inputIds.length}`);
3090
+ const { lastLogicalPos } = this.prepareMultimodalPrefill(visionEmbeds, gridTHW, inputIds);
3091
+ this.setPhase("describe:text-decode");
3092
+ return this.generateFromPrepared(inputIds, lastLogicalPos + 1, options);
3093
+ }
3094
+ /**
3095
+ * Debug: run ONLY the multimodal prefill for an explicit token sequence and
3096
+ * return the spliced input embeddings [seq, hidden] + first-token logits. Lets
3097
+ * tests compare the fused text+vision stream and M-RoPE numerically vs HF
3098
+ * without the decode loop overwriting intermediate buffers.
3099
+ */
3100
+ async debugMultimodalPrefill(patches, gridTHW, inputIds) {
3101
+ this.checkDestroyed();
3102
+ if (!this._multimodalGraph || !this.visionExecutor) throw new Error("debugMultimodalPrefill requires a multimodal engine (enableVision: true).");
3103
+ const vision = await this.encodeImage(patches, gridTHW);
3104
+ this.prepareMultimodalPrefill(vision.embeds, gridTHW, inputIds);
3105
+ const { logits } = await this.executor.forward(new Uint32Array(inputIds));
3106
+ return {
3107
+ splicedEmbeds: await this.executor.debugReadBuffer("embed_spliced", inputIds.length * this.config.hidden_size),
3108
+ logits,
3109
+ seq: inputIds.length
3110
+ };
3111
+ }
3112
+ /**
3113
+ * Internal: run prefill (assumes M-RoPE/splice inputs already written) + decode,
3114
+ * with decode logical positions starting at `decodeStartPos`. Used by
3115
+ * describeImage so the post-image cursor is honored.
3116
+ */
3117
+ async generateFromPrepared(inputIds, decodeStartPos, options) {
3118
+ const { maxTokens = 512, stopSequences = [], sampling = {}, onToken } = options;
3119
+ const startTime = performance.now();
3120
+ const isGreedy = (sampling.temperature ?? .7) < 1e-6;
3121
+ let { logits } = await this.executor.forward(new Uint32Array(inputIds));
3122
+ const generatedIds = [];
3123
+ let finishReason = "max_tokens";
3124
+ let generatedText = "";
3125
+ const eosId = this.tokenizer.config.eosTokenId;
3126
+ let mmLogicalPos = decodeStartPos;
3127
+ const consumeToken = (nextToken) => {
3128
+ generatedIds.push(nextToken);
3129
+ if (eosId !== null && nextToken === eosId) {
3130
+ finishReason = "eos";
3131
+ return true;
3132
+ }
3133
+ const tokenText = this.tokenizer.decode([nextToken], true);
3134
+ generatedText += tokenText;
3135
+ onToken?.(tokenText);
3136
+ if (stopSequences.some((s) => generatedText.includes(s))) {
3137
+ for (const s of stopSequences) {
3138
+ const idx = generatedText.indexOf(s);
3139
+ if (idx !== -1) generatedText = generatedText.slice(0, idx);
3140
+ }
3141
+ finishReason = "stop_sequence";
3142
+ return true;
3143
+ }
3144
+ return false;
3145
+ };
3146
+ for (let step = 0; step < maxTokens; step++) {
3147
+ let nextToken;
3148
+ if (step === 0) nextToken = sampleToken(logits, sampling, [...inputIds, ...generatedIds]);
3149
+ else if (isGreedy) {
3150
+ this.writeMRoPEDecodeStep(this.executor.currentSeqPos, mmLogicalPos);
3151
+ nextToken = await this.executor.forwardArgmax(new Uint32Array([generatedIds[generatedIds.length - 1]]));
3152
+ mmLogicalPos++;
3153
+ } else nextToken = sampleToken(logits, sampling, [...inputIds, ...generatedIds]);
3154
+ if (consumeToken(nextToken)) break;
3155
+ if (!isGreedy) {
3156
+ this.writeMRoPEDecodeStep(this.executor.currentSeqPos, mmLogicalPos);
3157
+ logits = (await this.executor.forward(new Uint32Array([nextToken]))).logits;
3158
+ mmLogicalPos++;
3159
+ }
3160
+ }
3161
+ const totalTime = performance.now() - startTime;
3162
+ const tokensGenerated = generatedIds.length;
3163
+ return {
3164
+ text: generatedText,
3165
+ tokensGenerated,
3166
+ tokensPerSecond: tokensGenerated / (totalTime / 1e3),
3167
+ totalTime,
3168
+ finishReason
3169
+ };
3170
+ }
3171
+ /**
3172
+ * Embed text into an L2-normalized vector. The pooling strategy depends on the
3173
+ * model: Qwen3-Embedding uses last-token (EOS-position) pooling, while
3174
+ * EmbeddingGemma (Gemma3 encoder) uses mean pooling over all tokens followed by
3175
+ * a 2-layer Dense head. Requires an embedding model (loaded with
3176
+ * { embedding: true }).
3177
+ *
3178
+ * The returned Float32Array has unit L2 norm, so cosine similarity reduces to a
3179
+ * dot product. Length is the model's embedding dim (768 for EmbeddingGemma;
3180
+ * config.hidden_size for Qwen3-Embedding).
3181
+ *
3182
+ * EmbeddingGemma is asymmetric — pass `{ taskType: "query" }` for search
3183
+ * queries and `{ taskType: "document" }` for the corpus, or a raw
3184
+ * `{ taskPrompt }` for other tasks (clustering/classification/STS).
3185
+ */
3186
+ async embed(text, options = {}) {
3187
+ this.checkDestroyed();
3188
+ if (!this._isEmbedding) throw new Error("embed() requires an embedding model. Load with { embedding: true } (e.g. repo: 'Qwen/Qwen3-Embedding-0.6B' or 'mlx-community/embeddinggemma-300m-4bit').");
3189
+ this.executor.reset();
3190
+ const { instruction, taskType, taskPrompt, maxTokens } = options;
3191
+ if (this._architecture === "Gemma3TextModel" || this._architecture === "Gemma3Model") {
3192
+ const input$1 = `${taskPrompt ?? EMBEDDING_GEMMA_PROMPTS[taskType ?? "query"] ?? EMBEDDING_GEMMA_PROMPTS.query}${text}`;
3193
+ let ids$1 = this.tokenizer.encode(input$1);
3194
+ const cap$1 = maxTokens ?? this.config.context_length;
3195
+ if (ids$1.length > cap$1) ids$1 = ids$1.slice(0, cap$1);
3196
+ return this.executor.embed(new Uint32Array(ids$1));
3197
+ }
3198
+ const input = instruction ? `Instruct: ${instruction}\nQuery:${text}` : text;
3199
+ let ids = this.tokenizer.encode(input);
3200
+ const padId = this.tokenizer.tokenToId("<|endoftext|>") ?? this.tokenizer.config.eosTokenId;
3201
+ if (padId !== null && ids[ids.length - 1] !== padId) ids.push(padId);
3202
+ const cap = maxTokens ?? this.config.context_length;
3203
+ if (ids.length > cap) {
3204
+ const tail = padId !== null && ids[ids.length - 1] === padId ? [padId] : [];
3205
+ ids = [...ids.slice(0, cap - tail.length), ...tail];
3206
+ }
3207
+ return this.executor.embed(new Uint32Array(ids));
3208
+ }
3209
+ /**
3210
+ * Generate text as an async iterator (streaming).
3211
+ *
3212
+ * Uses the onToken callback from generate() to push tokens into a queue
3213
+ * that the async generator yields from. The generator returns the full
3214
+ * GenerateResult when generation completes.
3215
+ *
3216
+ * Usage:
3217
+ * const gen = engine.stream("Hello!");
3218
+ * for await (const token of gen) {
3219
+ * process.stdout.write(token);
3220
+ * }
3221
+ * const result = gen.next(); // { done: true, value: GenerateResult }
3222
+ */
3223
+ async *stream(prompt, options = {}) {
3224
+ this.checkDestroyed();
3225
+ const tokenQueue = [];
3226
+ let resolve = null;
3227
+ let done = false;
3228
+ const pushToken = (token) => {
3229
+ tokenQueue.push(token);
3230
+ if (resolve) {
3231
+ const r = resolve;
3232
+ resolve = null;
3233
+ r();
3234
+ }
3235
+ };
3236
+ const waitForToken = () => {
3237
+ if (tokenQueue.length > 0 || done) return Promise.resolve();
3238
+ return new Promise((r) => {
3239
+ resolve = r;
3240
+ });
3241
+ };
3242
+ const genPromise = this.generate(prompt, {
3243
+ ...options,
3244
+ onToken: (token) => {
3245
+ options.onToken?.(token);
3246
+ pushToken(token);
3247
+ }
3248
+ }).then((result) => {
3249
+ done = true;
3250
+ if (resolve) {
3251
+ const r = resolve;
3252
+ resolve = null;
3253
+ r();
3254
+ }
3255
+ return result;
3256
+ });
3257
+ let yielded = 0;
3258
+ while (true) {
3259
+ await waitForToken();
3260
+ while (yielded < tokenQueue.length) yield tokenQueue[yielded++];
3261
+ if (done && yielded >= tokenQueue.length) break;
3262
+ }
3263
+ return await genPromise;
3264
+ }
3265
+ /**
3266
+ * Debug: read back a named GPU buffer (weight or activation).
3267
+ * Call after forward() to inspect intermediate values.
3268
+ */
3269
+ async debugReadBuffer(tensorName, maxElements) {
3270
+ return this.executor.debugReadBuffer(tensorName, maxElements);
3271
+ }
3272
+ /**
3273
+ * Run GPU diagnostics (buffer integrity, compute, shared memory).
3274
+ * Useful for isolating Safari/WebKit-specific WebGPU issues.
3275
+ */
3276
+ async diagnose() {
3277
+ return verifyGPU(this.ctx);
3278
+ }
3279
+ /**
3280
+ * Run GPU diagnostics without loading a model.
3281
+ * Quick way to check if WebGPU is working correctly on this device.
3282
+ */
3283
+ static async quickDiagnose() {
3284
+ const ctx = await initGPU();
3285
+ const result = await verifyGPU(ctx);
3286
+ clearPipelineCache(ctx.device);
3287
+ ctx.device.destroy();
3288
+ return result;
3289
+ }
3290
+ /**
3291
+ * Run a raw forward pass (no tokenization/chat template).
3292
+ * Returns logits for the last token.
3293
+ */
3294
+ async rawForward(inputIds) {
3295
+ return this.executor.forward(inputIds);
3296
+ }
3297
+ /**
3298
+ * Reset executor state (SSM, positions, etc.)
3299
+ */
3300
+ resetState() {
3301
+ this.executor.reset();
3302
+ }
3303
+ /**
3304
+ * Encode text to token IDs (useful for debugging / token counting).
3305
+ */
3306
+ encode(text) {
3307
+ return this.tokenizer.encode(text);
3308
+ }
3309
+ /**
3310
+ * Decode token IDs to text.
3311
+ */
3312
+ decode(ids, skipSpecialTokens) {
3313
+ return this.tokenizer.decode(ids, skipSpecialTokens);
3314
+ }
3315
+ /**
3316
+ * Integrity check: reads back key weight tensors and runs a single forward pass,
3317
+ * returning checksums for comparison against a known-good reference (Dawn/Node.js).
3318
+ *
3319
+ * Use this to isolate Safari/iPad corruption:
3320
+ * - If weights mismatch → fetch/download pipeline is corrupt
3321
+ * - If weights match but logits mismatch → kernel computation bug on Metal
3322
+ *
3323
+ * Resets executor state before and after (safe to call anytime).
3324
+ */
3325
+ async integrityCheck() {
3326
+ this.checkDestroyed();
3327
+ const checks = [];
3328
+ const log = (label, data) => {
3329
+ const f = data instanceof Float32Array ? data : new Float32Array(data.buffer);
3330
+ let sum = 0;
3331
+ let maxVal = -Infinity;
3332
+ let argmax$1 = 0;
3333
+ for (let i = 0; i < f.length; i++) {
3334
+ sum += f[i];
3335
+ if (f[i] > maxVal) {
3336
+ maxVal = f[i];
3337
+ argmax$1 = i;
3338
+ }
3339
+ }
3340
+ const entry = {
3341
+ label,
3342
+ length: f.length,
3343
+ sum: Math.round(sum * 1e6) / 1e6,
3344
+ first4: [
3345
+ Math.round(f[0] * 1e6) / 1e6,
3346
+ Math.round(f[1] * 1e6) / 1e6,
3347
+ Math.round(f[2] * 1e6) / 1e6,
3348
+ Math.round(f[3] * 1e6) / 1e6
3349
+ ],
3350
+ argmax: argmax$1,
3351
+ maxVal: Math.round(maxVal * 1e4) / 1e4
3352
+ };
3353
+ checks.push(entry);
3354
+ return entry;
3355
+ };
3356
+ checks.push({
3357
+ label: `webkit_detection: isWebKitWebGPU=${this.ctx.isWebKitWebGPU} needsMultiEncoder=${this.executor.needsMultiEncoder}`,
3358
+ length: 0,
3359
+ sum: 0,
3360
+ first4: [
3361
+ 0,
3362
+ 0,
3363
+ 0,
3364
+ 0
3365
+ ],
3366
+ argmax: 0,
3367
+ maxVal: 0,
3368
+ note: this.ctx.adapterDescription
3369
+ });
3370
+ for (const name of ["norm.weight", "layers.0.input_layernorm.weight"]) try {
3371
+ log(name, await this.debugReadBuffer(name));
3372
+ } catch {
3373
+ checks.push({
3374
+ label: name,
3375
+ length: 0,
3376
+ sum: NaN,
3377
+ first4: [
3378
+ NaN,
3379
+ NaN,
3380
+ NaN,
3381
+ NaN
3382
+ ],
3383
+ argmax: -1,
3384
+ maxVal: NaN,
3385
+ error: "buffer not found"
3386
+ });
3387
+ }
3388
+ for (const qName of [
3389
+ "layers.0.mlp.gate_proj.weight.q",
3390
+ "embed_tokens.weight.q",
3391
+ "embed_tokens.weight.scales",
3392
+ "embed_tokens.weight.zeros"
3393
+ ]) try {
3394
+ const q = await this.debugReadBuffer(qName, 16);
3395
+ log(`${qName} (reinterpret)`, q);
3396
+ } catch {
3397
+ checks.push({
3398
+ label: qName,
3399
+ length: 0,
3400
+ sum: NaN,
3401
+ first4: [
3402
+ NaN,
3403
+ NaN,
3404
+ NaN,
3405
+ NaN
3406
+ ],
3407
+ argmax: -1,
3408
+ maxVal: NaN,
3409
+ error: "not found"
3410
+ });
3411
+ }
3412
+ this.executor.reset();
3413
+ try {
3414
+ const result = await this.executor.debugFirstDispatch(new Uint32Array([1]));
3415
+ const entry = log(`single_dispatch(${result.opType})`, result.output);
3416
+ entry.note = `dispatch=${result.dispatchSize.join(",")} node=${result.nodeId}`;
3417
+ } catch (e) {
3418
+ checks.push({
3419
+ label: "single_dispatch",
3420
+ length: 0,
3421
+ sum: NaN,
3422
+ first4: [
3423
+ NaN,
3424
+ NaN,
3425
+ NaN,
3426
+ NaN
3427
+ ],
3428
+ argmax: -1,
3429
+ maxVal: NaN,
3430
+ error: e.message
3431
+ });
3432
+ }
3433
+ try {
3434
+ const result = await this.executor.debugDispatchEntry(1, 1);
3435
+ const entry = log(`isolated_entry1(${result.opType})`, result.output);
3436
+ entry.note = `node=${result.nodeId}`;
3437
+ } catch (e) {
3438
+ checks.push({
3439
+ label: "isolated_entry1",
3440
+ length: 0,
3441
+ sum: NaN,
3442
+ first4: [
3443
+ NaN,
3444
+ NaN,
3445
+ NaN,
3446
+ NaN
3447
+ ],
3448
+ argmax: -1,
3449
+ maxVal: NaN,
3450
+ error: e.message
3451
+ });
3452
+ }
3453
+ try {
3454
+ const result = await this.executor.debugDispatchEntry(2, 1);
3455
+ const entry = log(`isolated_entry2(${result.opType})`, result.output);
3456
+ entry.note = `node=${result.nodeId}`;
3457
+ } catch (e) {
3458
+ checks.push({
3459
+ label: "isolated_entry2",
3460
+ length: 0,
3461
+ sum: NaN,
3462
+ first4: [
3463
+ NaN,
3464
+ NaN,
3465
+ NaN,
3466
+ NaN
3467
+ ],
3468
+ argmax: -1,
3469
+ maxVal: NaN,
3470
+ error: e.message
3471
+ });
3472
+ }
3473
+ this.executor.reset();
3474
+ const jsParams = this.executor.debugComputeParams(1, 5);
3475
+ for (const p of jsParams) checks.push({
3476
+ label: `jsParams[${p.idx}] ${p.opType} dispatch=[${p.dispatchSize}] u32=[${p.paramsU32}]`,
3477
+ length: p.paramsU32.length,
3478
+ sum: p.paramsU32.reduce((a, b) => a + b, 0),
3479
+ first4: p.paramsU32.slice(0, 4).map(Number),
3480
+ argmax: 0,
3481
+ maxVal: 0
3482
+ });
3483
+ this.executor.reset();
3484
+ try {
3485
+ const { logits } = await this.executor.forward(new Uint32Array([1]));
3486
+ log("logits(token=1)", logits);
3487
+ try {
3488
+ log("embed_out", await this.debugReadBuffer("embed_out", 16));
3489
+ } catch {}
3490
+ try {
3491
+ const probes = await this.executor.debugPipelineProbe(1);
3492
+ for (const p of probes) {
3493
+ const paramsStr = p.uniformParams ? ` params=[${p.uniformParams.join(",")}]` : "";
3494
+ const entry = {
3495
+ label: `probe[${p.idx}] ${p.opType} → ${p.tensor}${paramsStr}`,
3496
+ length: 16,
3497
+ sum: p.sum,
3498
+ first4: p.first4,
3499
+ argmax: 0,
3500
+ maxVal: Math.max(...p.first4.map(Math.abs))
3501
+ };
3502
+ if (p.sum === 0 && p.first4.every((v) => v === 0)) {
3503
+ entry.match = "FAIL";
3504
+ entry.note = "all zeros — activation dead";
3505
+ }
3506
+ checks.push(entry);
3507
+ }
3508
+ } catch (e) {
3509
+ checks.push({
3510
+ label: "pipeline_probe",
3511
+ length: 0,
3512
+ sum: NaN,
3513
+ first4: [
3514
+ NaN,
3515
+ NaN,
3516
+ NaN,
3517
+ NaN
3518
+ ],
3519
+ argmax: -1,
3520
+ maxVal: NaN,
3521
+ error: e.message
3522
+ });
3523
+ }
3524
+ } catch (e) {
3525
+ checks.push({
3526
+ label: "logits(token=1)",
3527
+ length: 0,
3528
+ sum: NaN,
3529
+ first4: [
3530
+ NaN,
3531
+ NaN,
3532
+ NaN,
3533
+ NaN
3534
+ ],
3535
+ argmax: -1,
3536
+ maxVal: NaN,
3537
+ error: e.message
3538
+ });
3539
+ }
3540
+ this.executor.reset();
3541
+ const reference = {
3542
+ "norm.weight": {
3543
+ sum: 4412.571289,
3544
+ first4: [
3545
+ .222656,
3546
+ 3.75,
3547
+ 3.640625,
3548
+ 4.0625
3549
+ ]
3550
+ },
3551
+ "layers.0.input_layernorm.weight": {
3552
+ sum: 1267.312683,
3553
+ first4: [
3554
+ 2,
3555
+ 1.294922,
3556
+ 1.535156,
3557
+ 1.746094
3558
+ ]
3559
+ },
3560
+ embed_out: {
3561
+ sum: -.067839,
3562
+ first4: [
3563
+ .010059,
3564
+ .0264,
3565
+ .001888,
3566
+ -.030794
3567
+ ]
3568
+ }
3569
+ };
3570
+ for (const check of checks) {
3571
+ const ref = reference[check.label];
3572
+ if (ref) {
3573
+ const sumMatch = Math.abs(check.sum - ref.sum) < Math.max(Math.abs(ref.sum) * .001, 1);
3574
+ const first4Match = check.first4.every((v, i) => Math.abs(v - ref.first4[i]) < .01);
3575
+ check.match = sumMatch && first4Match ? "PASS" : "FAIL";
3576
+ check.refSum = ref.sum;
3577
+ } else if (check.label === "logits(token=1)") {
3578
+ const hasNaN = check.first4.some((v) => Number.isNaN(v));
3579
+ const hasInf = check.first4.some((v) => !Number.isFinite(v));
3580
+ const allSame = check.first4.every((v) => v === check.first4[0]);
3581
+ const argmaxValid = check.argmax >= 0 && check.argmax < check.length;
3582
+ check.match = !hasNaN && !hasInf && !allSame && argmaxValid ? "PASS" : "FAIL";
3583
+ }
3584
+ }
3585
+ console.log("\n=== INTEGRITY CHECK ===");
3586
+ for (const c of checks) {
3587
+ const status = c.error ? "ERR" : c.match ?? "---";
3588
+ console.log(`[${status}] ${c.label}: sum=${c.sum}, first4=[${c.first4.join(", ")}], argmax=${c.argmax}, len=${c.length}` + (c.refSum !== void 0 ? ` (ref sum=${c.refSum})` : "") + (c.error ? ` ERROR: ${c.error}` : "") + (c.note ? ` | ${c.note}` : ""));
3589
+ }
3590
+ console.log("=== END INTEGRITY CHECK ===\n");
3591
+ return {
3592
+ checks,
3593
+ allPass: checks.every((c) => c.match === "PASS" || c.match === void 0)
3594
+ };
3595
+ }
3596
+ /**
3597
+ * Destroy the engine and free all GPU resources.
3598
+ */
3599
+ destroy() {
3600
+ if (this._destroyed) return;
3601
+ this._destroyed = true;
3602
+ this.executor.destroy();
3603
+ this._kaniTTS?.destroy();
3604
+ this.visionExecutor?.destroy();
3605
+ clearPipelineCache(this.ctx.device);
3606
+ this.ctx.device.destroy();
3607
+ }
3608
+ checkDestroyed() {
3609
+ if (this._destroyed) throw new Error("WebGPUEngine has been destroyed");
3610
+ }
3611
+ };
3612
+
3613
+ //#endregion
3614
+ export { generateGemma4VisionGraph as C, dequantizeMLXProjection as S, resolveGemma4VisionInfo as T, smartResize as _, buildGemma4PosEmbeds as a, generateQwen3_5VisionGraph as b, buildMRoPECosSin as c, buildPositionIds as d, buildRotaryCosSin as f, preprocessImageGemma4 as g, preprocessImage as h, buildGemma4PoolMatrix as i, buildMRoPEPositionIds as l, mropeFreqDims as m, GEMMA4_IMAGE_PROCESSOR as n, buildGemma4RotaryCosSin as o, buildVisionPositionTensors as p, QWEN3_5_IMAGE_PROCESSOR as r, buildGemma4VisionPositionTensors as s, WebGPUEngine as t, buildPosEmbeds as u, VisionExecutor as v, patchGemma4VisionClips as w, dequantizeGemma4VisionProjection as x, KaniTTS as y };
3615
+ //# sourceMappingURL=gpu-33qCAtHW.mjs.map