@tryhamster/gerbil 1.0.0-rc.9 → 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +1 -1
- package/README.md +247 -84
- package/dist/architectures-C1I5V3Dt.mjs +6070 -0
- package/dist/architectures-C1I5V3Dt.mjs.map +1 -0
- package/dist/browser/index.d.ts +264 -588
- package/dist/browser/index.d.ts.map +1 -1
- package/dist/browser/index.js +585 -2334
- package/dist/browser/index.js.map +1 -1
- package/dist/cli.mjs +625 -1098
- package/dist/cli.mjs.map +1 -1
- package/dist/defaults-9komdrbY.mjs +24 -0
- package/dist/defaults-9komdrbY.mjs.map +1 -0
- package/dist/frameworks/express.d.mts +1 -3
- package/dist/frameworks/express.d.mts.map +1 -1
- package/dist/frameworks/express.mjs +7 -7
- package/dist/frameworks/express.mjs.map +1 -1
- package/dist/frameworks/fastify.d.mts +1 -1
- package/dist/frameworks/fastify.d.mts.map +1 -1
- package/dist/frameworks/fastify.mjs +3 -3
- package/dist/frameworks/fastify.mjs.map +1 -1
- package/dist/frameworks/hono.d.mts +1 -1
- package/dist/frameworks/hono.d.mts.map +1 -1
- package/dist/frameworks/hono.mjs +4 -4
- package/dist/frameworks/hono.mjs.map +1 -1
- package/dist/frameworks/next.d.mts +3 -2
- package/dist/frameworks/next.d.mts.map +1 -1
- package/dist/frameworks/next.mjs +4 -4
- package/dist/frameworks/next.mjs.map +1 -1
- package/dist/frameworks/react.d.mts +1 -1
- package/dist/frameworks/trpc.d.mts +1 -1
- package/dist/frameworks/trpc.d.mts.map +1 -1
- package/dist/frameworks/trpc.mjs +4 -4
- package/dist/frameworks/trpc.mjs.map +1 -1
- package/dist/gerbil-BHrJJIa4.mjs +1656 -0
- package/dist/gerbil-BHrJJIa4.mjs.map +1 -0
- package/dist/gerbil-BT9fCydo.d.mts +488 -0
- package/dist/gerbil-BT9fCydo.d.mts.map +1 -0
- package/dist/gerbil-DomNfIr1.mjs +4 -0
- package/dist/gpu/hooks.d.mts +520 -0
- package/dist/gpu/hooks.d.mts.map +1 -0
- package/dist/gpu/hooks.mjs +1188 -0
- package/dist/gpu/hooks.mjs.map +1 -0
- package/dist/gpu/index.d.mts +2 -0
- package/dist/gpu/index.mjs +6 -0
- package/dist/gpu-33qCAtHW.mjs +3615 -0
- package/dist/gpu-33qCAtHW.mjs.map +1 -0
- package/dist/index-Dgmb2kE3.d.mts +245 -0
- package/dist/index-Dgmb2kE3.d.mts.map +1 -0
- package/dist/index-jEAL2s-A.d.mts +2022 -0
- package/dist/index-jEAL2s-A.d.mts.map +1 -0
- package/dist/index.d.mts +22 -487
- package/dist/index.d.mts.map +1 -1
- package/dist/index.mjs +13 -8
- package/dist/index.mjs.map +1 -1
- package/dist/indexeddb-store-BWIMtxxH.mjs +103 -0
- package/dist/indexeddb-store-BWIMtxxH.mjs.map +1 -0
- package/dist/indexeddb-store-ClH12Xnl.mjs +4 -0
- package/dist/integrations/ai-sdk.d.mts +75 -6
- package/dist/integrations/ai-sdk.d.mts.map +1 -1
- package/dist/integrations/ai-sdk.mjs +131 -15
- package/dist/integrations/ai-sdk.mjs.map +1 -1
- package/dist/integrations/langchain.d.mts +1 -1
- package/dist/integrations/langchain.d.mts.map +1 -1
- package/dist/integrations/langchain.mjs +5 -5
- package/dist/integrations/langchain.mjs.map +1 -1
- package/dist/integrations/llamaindex.d.mts +1 -1
- package/dist/integrations/llamaindex.d.mts.map +1 -1
- package/dist/integrations/llamaindex.mjs +5 -5
- package/dist/integrations/llamaindex.mjs.map +1 -1
- package/dist/integrations/mcp-client.mjs +3 -3
- package/dist/integrations/mcp-client.mjs.map +1 -1
- package/dist/integrations/mcp.d.mts +3 -2
- package/dist/integrations/mcp.d.mts.map +1 -1
- package/dist/integrations/mcp.mjs +5 -5
- package/dist/{mcp-BvbriaBy.mjs → mcp-1DaMsaBc.mjs} +4 -4
- package/dist/mcp-1DaMsaBc.mjs.map +1 -0
- package/dist/memory/index.d.mts +3 -0
- package/dist/memory/index.mjs +6 -0
- package/dist/memory-D1P7Tmda.mjs +4 -0
- package/dist/memory-DVN0MnIG.mjs +132 -0
- package/dist/memory-DVN0MnIG.mjs.map +1 -0
- package/dist/memory-Dj0J1v88.mjs +294 -0
- package/dist/memory-Dj0J1v88.mjs.map +1 -0
- package/dist/moonshine-stt-BLyVoRpB.mjs +4 -0
- package/dist/moonshine-stt-v_P_Ci_m.mjs +11936 -0
- package/dist/moonshine-stt-v_P_Ci_m.mjs.map +1 -0
- package/dist/{one-liner-s-lD8rCC.mjs → one-liner-DnQn7HJK.mjs} +14 -16
- package/dist/one-liner-DnQn7HJK.mjs.map +1 -0
- package/dist/repl-jV5gcJFA.mjs +9 -0
- package/dist/skills/index.d.mts +270 -320
- package/dist/skills/index.d.mts.map +1 -1
- package/dist/skills/index.mjs +5 -5
- package/dist/{skills-CD3Orlex.mjs → skills-DX8D59UH.mjs} +187 -32
- package/dist/skills-DX8D59UH.mjs.map +1 -0
- package/dist/{tools-Bi1P7Xoy.mjs → tools-DQ1mPUw5.mjs} +34 -22
- package/dist/tools-DQ1mPUw5.mjs.map +1 -0
- package/dist/{types-CiTc7ez3.d.mts → types-D6FiR_oh.d.mts} +106 -12
- package/dist/types-D6FiR_oh.d.mts.map +1 -0
- package/dist/types-DQBe2lFo.d.mts +165 -0
- package/dist/types-DQBe2lFo.d.mts.map +1 -0
- package/dist/{utils-CZBZ8dgR.mjs → utils-DKO55ZmZ.mjs} +1 -1
- package/dist/{utils-CZBZ8dgR.mjs.map → utils-DKO55ZmZ.mjs.map} +1 -1
- package/dist/vector-B0panuy6.mjs +95 -0
- package/dist/vector-B0panuy6.mjs.map +1 -0
- package/docs/PROJECT-STATE.md +321 -0
- package/docs/adding-a-model-family.md +280 -0
- package/docs/ai-sdk.md +70 -61
- package/docs/architecture/overview.md +17 -7
- package/docs/browser.md +203 -8
- package/docs/embeddings.md +156 -0
- package/docs/gerbil-site-native-migration.md +217 -0
- package/docs/gpu-engine/architectures.md +398 -0
- package/docs/gpu-engine/ir.md +372 -0
- package/docs/gpu-engine/kernels.md +718 -0
- package/docs/gpu-engine/paper.html +1759 -0
- package/docs/gpu-engine/paper.md +2109 -0
- package/docs/gpu-engine/safetensors.md +312 -0
- package/docs/gpu-engine/tokenizer.md +302 -0
- package/docs/memory-rag.md +91 -0
- package/docs/metal-safari-intel.md +190 -0
- package/docs/mobile-failure-diagnosis.md +124 -0
- package/docs/mobile.md +99 -0
- package/docs/observability.md +230 -0
- package/docs/onnx-removal-plan.md +339 -0
- package/docs/research/autoresearch-portable.md +904 -0
- package/docs/research/dispatch-reduction-hivemind.md +84 -0
- package/docs/research/ios-safari-model-caching.md +117 -0
- package/docs/research/mobile-webgpu-speed-fusion.md +135 -0
- package/docs/research/native-stt-model-selection.md +49 -0
- package/docs/research/native-tts-model-selection.md +90 -0
- package/docs/research/native-vs-chromium-decision.md +152 -0
- package/docs/research/nemotron-mamba2-inference.md +910 -0
- package/docs/research/qwen35-multimodal.md +293 -0
- package/docs/research/qwen36-gemma4-targets.md +337 -0
- package/docs/research/sota-embedding-models.md +179 -0
- package/docs/research/sota-mobile-models-2026.md +263 -0
- package/docs/research/sota-modality-models.md +202 -0
- package/docs/research/tps-baselines.md +71 -0
- package/docs/research/webgpu-m4-reference.md +104 -0
- package/docs/site-update-plan.md +155 -0
- package/docs/structured-output.md +123 -0
- package/docs/stt.md +63 -446
- package/docs/tts.md +77 -499
- package/docs/vision.md +100 -338
- package/package.json +22 -7
- package/dist/chrome-backend-CORwaIyC.mjs +0 -1212
- package/dist/chrome-backend-CORwaIyC.mjs.map +0 -1
- package/dist/chrome-backend-DIKYoWj-.mjs +0 -3
- package/dist/gerbil-CJ3ifloF.mjs +0 -4
- package/dist/gerbil-Dw4Qj77e.mjs +0 -1631
- package/dist/gerbil-Dw4Qj77e.mjs.map +0 -1
- package/dist/gerbil-qOTe1nl2.d.mts +0 -431
- package/dist/gerbil-qOTe1nl2.d.mts.map +0 -1
- package/dist/kokoro-BNTb6egA.mjs +0 -20210
- package/dist/kokoro-BNTb6egA.mjs.map +0 -1
- package/dist/kokoro-CMOGDSgT.js +0 -20212
- package/dist/kokoro-CMOGDSgT.js.map +0 -1
- package/dist/mcp-BvbriaBy.mjs.map +0 -1
- package/dist/one-liner-s-lD8rCC.mjs.map +0 -1
- package/dist/repl-DveXw36T.mjs +0 -9
- package/dist/skills-CD3Orlex.mjs.map +0 -1
- package/dist/stt-Bu-E23Sc.js +0 -433
- package/dist/stt-Bu-E23Sc.js.map +0 -1
- package/dist/stt-CpLYbGFd.mjs +0 -433
- package/dist/stt-CpLYbGFd.mjs.map +0 -1
- package/dist/stt-DRPLEEHB.mjs +0 -3
- package/dist/tools-Bi1P7Xoy.mjs.map +0 -1
- package/dist/transformers.web-DiD1gTwk.js +0 -44695
- package/dist/transformers.web-DiD1gTwk.js.map +0 -1
- package/dist/transformers.web-u34VxRFM.js +0 -3
- package/dist/tts-CqroPaSK.js +0 -724
- package/dist/tts-CqroPaSK.js.map +0 -1
- package/dist/tts-DXgsKGCe.mjs +0 -3
- package/dist/tts-DeGANMNV.mjs +0 -730
- package/dist/tts-DeGANMNV.mjs.map +0 -1
- package/dist/types-CiTc7ez3.d.mts.map +0 -1
- /package/dist/{auto-update-S9s5-g0C.mjs → auto-update-BVaLXcDE.mjs} +0 -0
- /package/dist/{chunk-CkXuGtQK.mjs → chunk-B9cbKln6.mjs} +0 -0
- /package/dist/{microphone-DaMZFRuR.mjs → microphone-Bqmoz9_K.mjs} +0 -0
|
@@ -0,0 +1,280 @@
|
|
|
1
|
+
# Adding a Model Family to the Gerbil WebGPU Engine
|
|
2
|
+
|
|
3
|
+
This is the repeatable process for teaching the native WebGPU engine a new model
|
|
4
|
+
architecture. The engine is **not** Qwen-specific — it's a registry of graph
|
|
5
|
+
generators over a family-agnostic IR and kernel library. Adding a family is
|
|
6
|
+
usually "write one generator + register it," and only occasionally "write a new
|
|
7
|
+
kernel."
|
|
8
|
+
|
|
9
|
+
> TL;DR effort tiers:
|
|
10
|
+
> - **Tier 1 (hours):** the model is Llama-like (standard transformer). Reuse existing ops; write a config→IR generator. Llama, Mistral, Gemma-text, Phi, Qwen all live here.
|
|
11
|
+
> - **Tier 2 (days):** the model has one novel op (a new norm, sliding-window attention, a gate) OR new per-node params on existing kernels. Write the generator + 1–2 WGSL kernels.
|
|
12
|
+
> - **Tier 3 (weeks):** the model has a fundamentally new computation (SSM/Mamba, MoE routing, dual-graph encoder-decoder, codec decoder). New kernels + executor support. Qwen3.5's Mamba-2 path was Tier 3.
|
|
13
|
+
|
|
14
|
+
> **Repeatable SOP:** for the step-by-step loop use the `add-model-family` skill
|
|
15
|
+
> (`.claude/skills/add-model-family/SKILL.md`). This doc is the reference; the
|
|
16
|
+
> skill is the procedure. Both encode the lessons in "Lessons from production"
|
|
17
|
+
> below — read that section first.
|
|
18
|
+
|
|
19
|
+
---
|
|
20
|
+
|
|
21
|
+
## Lessons from production (read first)
|
|
22
|
+
|
|
23
|
+
A wave of families was added against this framework — Qwen3.5 (Mamba-2 SSM + ViT),
|
|
24
|
+
LFM2.5 (hybrid conv/attn), EmbeddingGemma (bidirectional Gemma3 encoder), Moonshine
|
|
25
|
+
STT (CrossAttention + dual-graph), Kani-TTS (LFM2 codec-LM + NanoCodec decoder),
|
|
26
|
+
Gemma 4 E2B (PLE CPU-streaming + proportional RoPE + double-wide MLP + value-norm +
|
|
27
|
+
head_dim-512). What actually mattered:
|
|
28
|
+
|
|
29
|
+
| Family | Arch string | Tier | What was new |
|
|
30
|
+
|---|---|---|---|
|
|
31
|
+
| Qwen2/3 | `Qwen2/3ForCausalLM` | 1 | baseline standard transformer (the template) |
|
|
32
|
+
| Qwen3-Embedding | `Qwen3ForCausalLM` (embedding flag) | 1 | `SliceLastRow`→`L2Norm` pooling tail, no new kernels |
|
|
33
|
+
| Qwen3.5 | `Qwen3_5ForConditionalGeneration` | 3 | Mamba-2 SSM kernels + ViT encoder (`VisionExecutor`) |
|
|
34
|
+
| LFM2.5 | `Lfm2ForCausalLM` | 2 | short-conv/attention hybrid; `out_proj`→`o_proj` rename |
|
|
35
|
+
| EmbeddingGemma | `Gemma3TextModel`/`Gemma3Model` | 2 | bidirectional encoder (`causal:false`) + mean-pool + dense tail |
|
|
36
|
+
| Moonshine STT | `MoonshineForConditionalGeneration` | 3 | new `CrossAttention` kernel; dual-graph (encoder+decoder) executor |
|
|
37
|
+
| Kani-TTS | `KaniTTS2ForCausalLM` | 3 | LFM2 codec-LM backbone + NanoCodec decoder (FSQ + causal HiFi-GAN) |
|
|
38
|
+
| Gemma 4 E2B | `Gemma4ForConditionalGeneration` | 2 | PLE CPU-streaming, proportional/partial RoPE, value-norm, head_dim-512, double-wide MLP, per-layer scalar — all via per-node attrs + loader hooks |
|
|
39
|
+
|
|
40
|
+
1. **VERIFY FIRST against the live checkpoint.** Fetch the live `config.json` and
|
|
41
|
+
range-read the safetensors header (keys/dtypes/shapes) BEFORE writing anything.
|
|
42
|
+
Confirm the arch string, dims, quant format, and classify every op. Do NOT
|
|
43
|
+
trust assumptions from sibling models: Gemma4 had no MatFormer though Gemma3n
|
|
44
|
+
did; OmniVoice's codec is diffusion not AR; MLX repos ship the ViT under a
|
|
45
|
+
different key prefix (`vision_tower.*`) than HF bf16 (`model.visual.*`), and the
|
|
46
|
+
LM under `language_model.model.*` (MLX) vs `model.language_model.*` (GPTQ/HF).
|
|
47
|
+
Watch for **nested config** — Qwen3.5/Gemma4 put the text tower under
|
|
48
|
+
`text_config`.
|
|
49
|
+
2. **Prefer per-node params over new kernels.** Most "novel" behavior is a new
|
|
50
|
+
parameter on an existing kernel. Add it to the WGSL `Params` struct and read it
|
|
51
|
+
in `buildParams` with a **default that reproduces the old behavior
|
|
52
|
+
byte-identically** (`attn_scale ?? 1/sqrt(head_dim)`), so every existing caller
|
|
53
|
+
is unchanged. Gemma4's attn_scale=1.0, partial RoPE
|
|
54
|
+
(`rope_dim/rope_half/rope_denom/rope_active_pairs`), and per-layer `Scale`
|
|
55
|
+
(loader-patched from `layer_scalar_key`) all landed this way — zero new attention
|
|
56
|
+
kernels.
|
|
57
|
+
3. **Validate with a LOAD-ONCE reference harness — the make-or-break lesson.** Dump
|
|
58
|
+
the HF/MLX reference activations to disk ONCE
|
|
59
|
+
(`{tokens, per_layer_last[L][hidden], argmax, logits_top}`), then iterate
|
|
60
|
+
engine-only in SECONDS via a single `engine.create()` + `rawForward()` +
|
|
61
|
+
per-layer `debugReadBuffer` cosine loop (`test-gemma4-perlayer.mjs`). Reloading
|
|
62
|
+
the 2.5GB model every cycle made Gemma4 cost an hour per iteration. Localize
|
|
63
|
+
divergence to the exact layer/op. Set `GERBIL_NO_ACT_POOL=1` when reading early
|
|
64
|
+
activations (the pool aliases buffers). Gate on coherent generation — bit-exact
|
|
65
|
+
is not the bar for lossy quants.
|
|
66
|
+
4. **No-regression suite is mandatory for shared-kernel changes.** Keep all
|
|
67
|
+
existing models bit-exact: `test-q4-generate`, `test-vision-e2e`,
|
|
68
|
+
`test-crossattention`, `test-nanocodec-decode`, the embedding test, and the fast
|
|
69
|
+
`test-gemma4-perlayer`. The byte-identical-default pattern (lesson 2) is what
|
|
70
|
+
keeps these green.
|
|
71
|
+
5. **Big tensors: shard or stream.** Per-buffer caps are ~256MB
|
|
72
|
+
(`maxBufferSize`) / ~128MB (`maxStorageBufferBindingSize`) on iPad. Shard the
|
|
73
|
+
embedding if vocab×hidden exceeds the cap; for vocab-scale auxiliary tables keep
|
|
74
|
+
them CPU-resident and stream per-layer slices. Gemma4's ~1.17GB PLE table has
|
|
75
|
+
**0 MB GPU residency** — the loader builds a `PleSource`, diverts it from the
|
|
76
|
+
weights map, and the executor uploads only the needed `[T, width]` rows per step.
|
|
77
|
+
6. **Merge additively, never commit node_modules.** Registry/IR/index conflicts are
|
|
78
|
+
always additive — keep ALL families/ops from both sides. Use targeted `git add`,
|
|
79
|
+
never `git add -A`.
|
|
80
|
+
|
|
81
|
+
---
|
|
82
|
+
|
|
83
|
+
## The architecture
|
|
84
|
+
|
|
85
|
+
When a model loads, `model-loader.ts` reads `config.json` → `architectures[0]`
|
|
86
|
+
(e.g. `"LlamaForCausalLM"`) and looks it up in the registry:
|
|
87
|
+
|
|
88
|
+
```
|
|
89
|
+
src/gpu/architectures/index.ts → ARCHITECTURES: Record<string, GraphGenerator>
|
|
90
|
+
```
|
|
91
|
+
|
|
92
|
+
A `GraphGenerator` turns the raw HF config into a `ModelGraph` (the IR):
|
|
93
|
+
|
|
94
|
+
```ts
|
|
95
|
+
type GraphGenerator = (
|
|
96
|
+
config: Record<string, unknown>,
|
|
97
|
+
dtype?: "f32" | "q4", // q4 = on-the-fly INT4 quantization
|
|
98
|
+
groupSize?: number, // INT4 group size (default 128)
|
|
99
|
+
kvDtype?: "f32" | "f16", // KV cache precision
|
|
100
|
+
) => ModelGraph;
|
|
101
|
+
```
|
|
102
|
+
|
|
103
|
+
The `ModelGraph` is `{ tensors, nodes, executionOrder, inputs, outputs, config, capabilities, architecture }`.
|
|
104
|
+
The executor (`executor.ts`) consumes it generically: it allocates buffers
|
|
105
|
+
(liveness-pooled), compiles one pipeline per node from `KERNEL_REGISTRY`, and
|
|
106
|
+
dispatches in `executionOrder`. **The generator never touches WebGPU** — it only
|
|
107
|
+
describes the computation.
|
|
108
|
+
|
|
109
|
+
---
|
|
110
|
+
|
|
111
|
+
## Step-by-step
|
|
112
|
+
|
|
113
|
+
### 1. Identify the family and check op coverage
|
|
114
|
+
|
|
115
|
+
Read the model's `config.json` and its modeling code on HuggingFace. Classify
|
|
116
|
+
every layer's ops and check them against what the engine already has
|
|
117
|
+
(`KERNEL_REGISTRY` in `src/gpu/kernels/registry.ts`):
|
|
118
|
+
|
|
119
|
+
**Implemented ops (reuse freely):** `Embedding`, `EmbeddingInt4`, `MatMul`,
|
|
120
|
+
`MatMulInt4`, `Add`, `Mul`, `Scale` (per-element, loader-patchable scalar),
|
|
121
|
+
`RMSNorm`, `LayerNorm`, `RoPE` (with optional partial-rotary attrs
|
|
122
|
+
`rope_dim/rope_half/rope_denom/rope_active_pairs`), `Attention` (causal *and*
|
|
123
|
+
bidirectional via the `causal` flag — set `causal: false` for an encoder/ViT;
|
|
124
|
+
takes a per-node `attn_scale` defaulting to `1/sqrt(head_dim)`, plus
|
|
125
|
+
`sliding_window`), `CrossAttention` (encoder-decoder), `Softmax`, `SiLU`,
|
|
126
|
+
`SwiGLU`, `GELU` (tanh approx), `GeluErf` (exact erf), `AddBias` (row-broadcast
|
|
127
|
+
bias), `ApplyRotaryEmb` (precomputed-cos/sin `rotate_half`), `SliceCols`
|
|
128
|
+
(column-range extract, e.g. split a fused QKV), `L2Norm` (row-wise, embedding
|
|
129
|
+
tail), `ResidualRMSNorm`, `KVCacheAppend`, `MambaSSM`, `CausalConv1d`,
|
|
130
|
+
`SigmoidGate`, `ConvStateUpdate`, `SliceLastRow`, and codec ops
|
|
131
|
+
(`Conv1d`/`ConvTranspose1d`/`Snake1d`, FSQ) for audio decoders.
|
|
132
|
+
|
|
133
|
+
If the model only uses these → **Tier 1**. If it needs something new → note it;
|
|
134
|
+
you'll write a kernel in Step 6.
|
|
135
|
+
|
|
136
|
+
**Embedding models** are Tier 1: an embedding model (e.g. Qwen3-Embedding-0.6B,
|
|
137
|
+
which is `Qwen3ForCausalLM`) is the causal-LM forward pass with a different tail —
|
|
138
|
+
pass the `embedding` flag to `generateQwen2Graph` to swap `lm_head` for the
|
|
139
|
+
`SliceLastRow` (last-token EOS pool) → `L2Norm` tail. No new kernels (paper §21).
|
|
140
|
+
|
|
141
|
+
**Vision-capable checkpoints** (e.g. Qwen3.5) ship a ViT in the same weights; the
|
|
142
|
+
encoder is a separate graph (`qwen3_5_vision.ts`) run by `VisionExecutor`, already
|
|
143
|
+
built and bit-exact vs HF (paper §22). Patches arrive pre-flattened, so the
|
|
144
|
+
patch-embed is a plain `MatMul` + `AddBias` — no Conv3d kernel.
|
|
145
|
+
|
|
146
|
+
**Exotic features that need new work (flag early):**
|
|
147
|
+
- Sliding-window / local attention (`Attention` accepts a per-node `sliding_window` attr; patterns beyond that need an attention kernel variant)
|
|
148
|
+
- Mixture-of-Experts routing (`MoERouter`/`ExpertMatMul` are stubbed, not built)
|
|
149
|
+
- Novel normalization (e.g. QK-norm, per-head norm) — new kernel
|
|
150
|
+
- MatFormer / elastic parameters (Gemma 3n) — needs slicing logic + a loader story
|
|
151
|
+
- Attention logit soft-capping (Gemma 2) — small kernel tweak
|
|
152
|
+
- Non-RoPE position encodings
|
|
153
|
+
|
|
154
|
+
### 2. Map config → dimensions
|
|
155
|
+
|
|
156
|
+
Every generator starts by pulling dimensions from the raw config. Copy the block
|
|
157
|
+
from `src/gpu/architectures/qwen2.ts` and adjust names to the new model's config:
|
|
158
|
+
|
|
159
|
+
```ts
|
|
160
|
+
const hidden_size = rawConfig.hidden_size as number;
|
|
161
|
+
const num_layers = rawConfig.num_hidden_layers as number;
|
|
162
|
+
const num_heads = rawConfig.num_attention_heads as number;
|
|
163
|
+
const num_kv_heads = (rawConfig.num_key_value_heads as number) ?? num_heads; // GQA
|
|
164
|
+
const intermediate_size = rawConfig.intermediate_size as number;
|
|
165
|
+
const vocab_size = rawConfig.vocab_size as number;
|
|
166
|
+
const context_length = (rawConfig.max_position_embeddings as number) ?? 32768;
|
|
167
|
+
const rms_norm_eps = (rawConfig.rms_norm_eps as number) ?? 1e-6;
|
|
168
|
+
const rope_base = (rawConfig.rope_theta as number) ?? 1_000_000.0;
|
|
169
|
+
const head_dim = (rawConfig.head_dim as number) ?? Math.floor(hidden_size / num_heads);
|
|
170
|
+
const tieWordEmbeddings = (rawConfig.tie_word_embeddings as boolean) ?? false;
|
|
171
|
+
```
|
|
172
|
+
|
|
173
|
+
Watch for family differences: **head_dim is often NOT hidden_size/num_heads**
|
|
174
|
+
(Gemma sets it explicitly); GQA means `num_kv_heads < num_heads`; some families
|
|
175
|
+
have **QKV projection bias** (Qwen2 does, Llama/Gemma don't).
|
|
176
|
+
|
|
177
|
+
### 3. Write the generator
|
|
178
|
+
|
|
179
|
+
Create `src/gpu/architectures/<family>.ts` exporting
|
|
180
|
+
`generate<Family>Graph(...)`. Use the local `addTensor`/`addNode` helpers (they
|
|
181
|
+
also push to `executionOrder`). The per-layer skeleton, following qwen2.ts:
|
|
182
|
+
|
|
183
|
+
1. **Embedding** — `Embedding`/`EmbeddingInt4` reading `input_ids` + `CANONICAL_KEYS.EMBED`.
|
|
184
|
+
2. **Per layer** (loop `num_layers`):
|
|
185
|
+
- input RMSNorm → Q/K/V projections (`MatMulInt4`) → RoPE → `KVCacheAppend` → `Attention` → output projection → residual `Add`
|
|
186
|
+
- post-attn RMSNorm → gate/up projections → `SwiGLU` → down projection → residual `Add`
|
|
187
|
+
- (use `ResidualRMSNorm` to fuse the residual+norm where the family allows)
|
|
188
|
+
3. **Final norm** (`RMSNorm`).
|
|
189
|
+
4. **`SliceLastRow`** on the final hidden state → **lm_head** (`MatMulInt4`) → `logits`.
|
|
190
|
+
The `[1, vocab]` logits + SliceLastRow are mandatory (saves ~485MB at long
|
|
191
|
+
context and skips the full-vocab prefill matmul — see paper §18). Honor
|
|
192
|
+
`tieWordEmbeddings` (reuse embed weights for lm_head).
|
|
193
|
+
|
|
194
|
+
Declare intermediate tensors with `storage: "activation"` — the executor pools
|
|
195
|
+
them automatically; don't manage buffers yourself. Set `capabilities`
|
|
196
|
+
(`text`, `vision`, `moe`) and return the graph.
|
|
197
|
+
|
|
198
|
+
### 4. Register it
|
|
199
|
+
|
|
200
|
+
In `src/gpu/architectures/index.ts`:
|
|
201
|
+
|
|
202
|
+
```ts
|
|
203
|
+
import { generateLlamaGraph } from "./llama.js";
|
|
204
|
+
export const ARCHITECTURES = {
|
|
205
|
+
// ...existing...
|
|
206
|
+
LlamaForCausalLM: generateLlamaGraph,
|
|
207
|
+
MistralForCausalLM: generateLlamaGraph, // Mistral == Llama arch
|
|
208
|
+
};
|
|
209
|
+
```
|
|
210
|
+
|
|
211
|
+
One generator can serve a whole family — Llama, Mistral, and many fine-tunes
|
|
212
|
+
share `LlamaForCausalLM`.
|
|
213
|
+
|
|
214
|
+
### 5. Weight key mapping
|
|
215
|
+
|
|
216
|
+
The loader maps HF safetensors keys → canonical names via a `HFKeyMapper`
|
|
217
|
+
(`createDefaultHFKeyMapper` strips the `model.` prefix). If the new family names
|
|
218
|
+
weights differently, supply a mapper or set `safetensorsKey` on each `addTensor`.
|
|
219
|
+
`CANONICAL_KEYS` has helpers (`qProj(i)`, `layerInputNorm(i)`, `EMBED`,
|
|
220
|
+
`LM_HEAD`, …) — reuse them so the loader and generator agree.
|
|
221
|
+
|
|
222
|
+
### 6. Write any new kernel (only if Step 1 flagged one)
|
|
223
|
+
|
|
224
|
+
Add to `src/gpu/kernels/registry.ts`: a WGSL string + a `KernelSpec`
|
|
225
|
+
(`shaderCode`, `entryPoint`, `bindings`, `getDispatchSize`, `buildParams`), then
|
|
226
|
+
register it in `KERNEL_REGISTRY` under your new `OpType` (add the op to the union
|
|
227
|
+
in `ir.ts`). **Mobile constraints (WebKit/iPad):** workgroup memory ≤ 16 KB
|
|
228
|
+
(`maxComputeWorkgroupStorageSize`); no `enable f16` (use packed `pack2x16float`);
|
|
229
|
+
clamp `exp()` for Metal fast-math; avoid `select()` (use if/else). See the
|
|
230
|
+
attention kernels for the safe patterns and the two-phase reduction.
|
|
231
|
+
|
|
232
|
+
### 7. Validate correctness (do not skip)
|
|
233
|
+
|
|
234
|
+
1. **Reference activations:** adapt `scripts/engine/test-reference.py` to dump the
|
|
235
|
+
new model's logits + key intermediate activations from HF transformers.
|
|
236
|
+
2. **Compare:** run `engine.integrityCheck()` (weights + isolated dispatches +
|
|
237
|
+
pipeline probes) and diff the engine's logits against the reference. Match the
|
|
238
|
+
argmax and check sums within tolerance.
|
|
239
|
+
3. **Coherence:** generate from a fixed prompt at temperature 0 and confirm
|
|
240
|
+
sensible text (the `scripts/engine/test-q4-generate.mjs` pattern).
|
|
241
|
+
4. **Cross-platform:** the desktop (Dawn) and mobile (WebKit) outputs must be
|
|
242
|
+
byte-identical — queue an iPad run via the harness (`scripts/engine/ipad-queue.json`).
|
|
243
|
+
|
|
244
|
+
### 8. Mobile / download budget checks
|
|
245
|
+
|
|
246
|
+
- **Download size at q4** is the headline mobile metric — report it.
|
|
247
|
+
- **Per-buffer limit:** iPad `maxBufferSize` defaults to 256MB and
|
|
248
|
+
`maxStorageBufferBindingSize` to 128MB. A single weight tensor (often the
|
|
249
|
+
embedding) must fit — shard it if the model's vocab×hidden exceeds the cap.
|
|
250
|
+
- **maxSeqLen policy:** the engine clamps `maxSeqLen` to 512 on iOS by default; large contexts
|
|
251
|
+
multiply KV-cache memory.
|
|
252
|
+
|
|
253
|
+
---
|
|
254
|
+
|
|
255
|
+
## Worked example: the smallest possible new family (Llama)
|
|
256
|
+
|
|
257
|
+
Llama is Qwen2 minus QKV bias. The fastest path:
|
|
258
|
+
1. Copy `qwen2.ts` → `llama.ts`, rename the export.
|
|
259
|
+
2. Remove the QKV-bias tensors/handling (Llama has none).
|
|
260
|
+
3. Confirm head_dim, GQA (`num_key_value_heads`), and `tie_word_embeddings` from config.
|
|
261
|
+
4. Register `LlamaForCausalLM`/`MistralForCausalLM` → `generateLlamaGraph`.
|
|
262
|
+
5. Validate against a reference (Step 7). No new kernels needed — Tier 1.
|
|
263
|
+
|
|
264
|
+
This single generator unlocks Llama, Mistral, and every model that ships as
|
|
265
|
+
`LlamaForCausalLM` on HuggingFace.
|
|
266
|
+
|
|
267
|
+
---
|
|
268
|
+
|
|
269
|
+
## Checklist
|
|
270
|
+
|
|
271
|
+
- [ ] Classified every op; new ops listed (or confirmed none)
|
|
272
|
+
- [ ] `generate<Family>Graph` written, dimensions mapped from config
|
|
273
|
+
- [ ] GQA / head_dim / QKV-bias / tied-embeddings handled
|
|
274
|
+
- [ ] `SliceLastRow` + `[1, vocab]` logits in place
|
|
275
|
+
- [ ] Registered in `ARCHITECTURES`
|
|
276
|
+
- [ ] Key mapping verified (loader finds every weight)
|
|
277
|
+
- [ ] New kernels (if any) respect the 16 KB / no-`enable f16` / clamped-`exp` mobile rules
|
|
278
|
+
- [ ] Validated vs HF reference (logits argmax + coherence)
|
|
279
|
+
- [ ] Desktop and iPad outputs byte-identical
|
|
280
|
+
- [ ] Download size at q4 reported; largest weight fits the 256MB/128MB buffer caps
|
package/docs/ai-sdk.md
CHANGED
|
@@ -1,6 +1,14 @@
|
|
|
1
1
|
# Gerbil + AI SDK
|
|
2
2
|
|
|
3
|
-
Gerbil works as a [Vercel AI SDK v5](https://sdk.vercel.ai/) provider, supporting text generation, speech synthesis (TTS), and transcription (STT).
|
|
3
|
+
Gerbil works as a [Vercel AI SDK v5](https://sdk.vercel.ai/) provider, supporting text generation, embeddings, speech synthesis (TTS), and transcription (STT).
|
|
4
|
+
|
|
5
|
+
> **Pre-1.0 note.** The AI SDK provider routes through the `Gerbil` class, which now runs on
|
|
6
|
+
> the native WebGPU engine (no ONNX / transformers.js). TTS uses Kani-TTS-2, STT uses
|
|
7
|
+
> Moonshine, and embeddings use EmbeddingGemma regardless of the model-id string you pass —
|
|
8
|
+
> legacy ids like `kokoro-82m` / `whisper-tiny.en` are vestigial labels and the device must
|
|
9
|
+
> have WebGPU (there is no CPU/WASM fallback). The first-class surface for the engine is
|
|
10
|
+
> `WebGPUEngine` / `useEngine` (see the [README](../README.md), [TTS](./tts.md),
|
|
11
|
+
> [STT](./stt.md), [Embeddings](./embeddings.md) docs).
|
|
4
12
|
|
|
5
13
|
## Setup
|
|
6
14
|
|
|
@@ -56,52 +64,64 @@ const { text } = await generateText({
|
|
|
56
64
|
});
|
|
57
65
|
```
|
|
58
66
|
|
|
59
|
-
##
|
|
67
|
+
## Embeddings
|
|
68
|
+
|
|
69
|
+
> **Native.** `gerbil.embedding()` runs native EmbeddingGemma-300M on the WebGPU engine
|
|
70
|
+
> (768-dim; the `all-MiniLM-L6-v2` default id is a vestigial label — the old ONNX
|
|
71
|
+
> MiniLM/BGE/GTE lane has been removed). For direct control use `engine.embed()` — see
|
|
72
|
+
> [Embeddings docs](./embeddings.md). Requires WebGPU.
|
|
60
73
|
|
|
61
|
-
Generate
|
|
74
|
+
Generate text embeddings for semantic search, similarity, and RAG:
|
|
62
75
|
|
|
63
76
|
```typescript
|
|
64
|
-
import {
|
|
77
|
+
import { embed, embedMany } from "ai";
|
|
65
78
|
import { gerbil } from "@tryhamster/gerbil/ai";
|
|
66
79
|
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
80
|
+
// Single embedding (768-dim EmbeddingGemma vector)
|
|
81
|
+
const { embedding } = await embed({
|
|
82
|
+
model: gerbil.embedding(),
|
|
83
|
+
value: "Hello world",
|
|
71
84
|
});
|
|
72
85
|
|
|
73
|
-
//
|
|
74
|
-
await
|
|
86
|
+
// Multiple embeddings
|
|
87
|
+
const { embeddings } = await embedMany({
|
|
88
|
+
model: gerbil.embedding(),
|
|
89
|
+
values: ["Hello", "World", "How are you?"],
|
|
90
|
+
});
|
|
75
91
|
```
|
|
76
92
|
|
|
77
|
-
|
|
93
|
+
## Speech Generation (TTS)
|
|
78
94
|
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
// Example voices:
|
|
84
|
-
// - af_heart (Female, American)
|
|
85
|
-
// - bf_emma (Female, British)
|
|
86
|
-
// - am_fenrir (Male, American)
|
|
87
|
-
// - bm_daniel (Male, British)
|
|
88
|
-
```
|
|
95
|
+
> **Native.** `gerbil.speech()` runs Kani-TTS-2 on the native engine (the `kokoro-82m`
|
|
96
|
+
> default id is a vestigial label). For direct control use `engine.speak()` — see
|
|
97
|
+
> [TTS docs](./tts.md). Requires WebGPU.
|
|
89
98
|
|
|
90
|
-
|
|
99
|
+
Generate speech from text:
|
|
91
100
|
|
|
92
101
|
```typescript
|
|
102
|
+
import { experimental_generateSpeech as generateSpeech } from "ai";
|
|
103
|
+
import { gerbil } from "@tryhamster/gerbil/ai";
|
|
104
|
+
|
|
93
105
|
const result = await generateSpeech({
|
|
94
|
-
model: gerbil.speech(
|
|
95
|
-
|
|
96
|
-
speed: 1.2, // Speed multiplier
|
|
97
|
-
}),
|
|
98
|
-
text: "Speak faster!",
|
|
106
|
+
model: gerbil.speech(), // native Kani-TTS-2
|
|
107
|
+
text: "Hello, welcome to Gerbil!",
|
|
99
108
|
});
|
|
109
|
+
|
|
110
|
+
// result.audio is the synthesized PCM clip
|
|
111
|
+
await writeFile("output.wav", result.audio);
|
|
100
112
|
```
|
|
101
113
|
|
|
114
|
+
> Voice/speed selection from the old Kokoro lane no longer applies — Kani-TTS-2 uses its own
|
|
115
|
+
> default voice. For full control over the native speech path, use `engine.speak()` directly
|
|
116
|
+
> (see [TTS docs](./tts.md)).
|
|
117
|
+
|
|
102
118
|
## Transcription (STT)
|
|
103
119
|
|
|
104
|
-
|
|
120
|
+
> **Native.** `gerbil.transcription()` runs Moonshine on the native engine (the
|
|
121
|
+
> `whisper-tiny.en` default id is a vestigial label; English only). For direct control use
|
|
122
|
+
> `MoonshineSTT` — see [STT docs](./stt.md). Requires WebGPU.
|
|
123
|
+
|
|
124
|
+
Transcribe audio to text:
|
|
105
125
|
|
|
106
126
|
```typescript
|
|
107
127
|
import { experimental_transcribe as transcribe } from "ai";
|
|
@@ -109,46 +129,17 @@ import { gerbil } from "@tryhamster/gerbil/ai";
|
|
|
109
129
|
import { readFile } from "fs/promises";
|
|
110
130
|
|
|
111
131
|
const result = await transcribe({
|
|
112
|
-
model: gerbil.transcription(), //
|
|
132
|
+
model: gerbil.transcription(), // native Moonshine (English)
|
|
113
133
|
audio: await readFile("audio.wav"),
|
|
114
134
|
});
|
|
115
135
|
|
|
116
136
|
console.log(result.text); // "Hello world"
|
|
117
137
|
console.log(result.language); // "en"
|
|
118
|
-
console.log(result.durationInSeconds); // 2.5
|
|
119
|
-
console.log(result.segments); // Timestamped segments
|
|
120
|
-
```
|
|
121
|
-
|
|
122
|
-
### Available Models
|
|
123
|
-
|
|
124
|
-
```typescript
|
|
125
|
-
const models = gerbil.listTranscriptionModels();
|
|
126
|
-
|
|
127
|
-
// Models (smallest to largest):
|
|
128
|
-
// - whisper-tiny.en (39M, English only, fastest)
|
|
129
|
-
// - whisper-tiny (39M, multilingual)
|
|
130
|
-
// - whisper-base.en (74M, English only)
|
|
131
|
-
// - whisper-base (74M, multilingual)
|
|
132
|
-
// - whisper-small.en (244M, English only)
|
|
133
|
-
// - whisper-small (244M, multilingual)
|
|
134
|
-
// - whisper-large-v3-turbo (809M, 80+ languages, best quality)
|
|
135
138
|
```
|
|
136
139
|
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
// Use a larger model for better accuracy
|
|
141
|
-
const result = await transcribe({
|
|
142
|
-
model: gerbil.transcription("whisper-base"),
|
|
143
|
-
audio: audioBuffer,
|
|
144
|
-
});
|
|
145
|
-
|
|
146
|
-
// Use multilingual model with language hint
|
|
147
|
-
const result = await transcribe({
|
|
148
|
-
model: gerbil.transcription("whisper-small", { language: "es" }),
|
|
149
|
-
audio: spanishAudio,
|
|
150
|
-
});
|
|
151
|
-
```
|
|
140
|
+
> The Whisper model family (multilingual variants, timestamped segments) has been removed.
|
|
141
|
+
> The native path is Moonshine, English-only, and does not produce timestamps. For full
|
|
142
|
+
> control use `MoonshineSTT` directly (see [STT docs](./stt.md)).
|
|
152
143
|
|
|
153
144
|
## Custom Provider
|
|
154
145
|
|
|
@@ -179,6 +170,24 @@ const transcript = await transcribe({
|
|
|
179
170
|
});
|
|
180
171
|
```
|
|
181
172
|
|
|
173
|
+
## Model Preloading
|
|
174
|
+
|
|
175
|
+
Download models ahead of time via the provider:
|
|
176
|
+
|
|
177
|
+
```typescript
|
|
178
|
+
import { gerbil } from "@tryhamster/gerbil/ai";
|
|
179
|
+
|
|
180
|
+
// Check if cached
|
|
181
|
+
if (!(await gerbil.isCached("qwen3-0.6b"))) {
|
|
182
|
+
// Preload during app init
|
|
183
|
+
await gerbil.preload("qwen3-0.6b", {
|
|
184
|
+
onProgress: (p) => console.log(p.status, p.progress),
|
|
185
|
+
});
|
|
186
|
+
}
|
|
187
|
+
|
|
188
|
+
// Later: generateText loads the model from the local cache instead of downloading it
|
|
189
|
+
```
|
|
190
|
+
|
|
182
191
|
## Specification
|
|
183
192
|
|
|
184
193
|
Gerbil implements the following AI SDK v5 interfaces:
|
|
@@ -20,6 +20,8 @@ Responsibilities:
|
|
|
20
20
|
- Generation orchestration
|
|
21
21
|
- Streaming coordination
|
|
22
22
|
- Session statistics
|
|
23
|
+
- Request queue (concurrency control)
|
|
24
|
+
- Telemetry hooks for observability
|
|
23
25
|
|
|
24
26
|
### 2. Model Registry (`src/core/models.ts`)
|
|
25
27
|
|
|
@@ -57,21 +59,29 @@ Enables WebGPU in Node.js by using headless Chrome as a GPU accelerator:
|
|
|
57
59
|
└───────────┘
|
|
58
60
|
```
|
|
59
61
|
|
|
60
|
-
### 4. Browser
|
|
62
|
+
### 4. Browser Module (`src/browser/index.ts`)
|
|
61
63
|
|
|
62
|
-
Provides
|
|
64
|
+
Provides React hooks and workers for browser applications:
|
|
63
65
|
|
|
64
66
|
```typescript
|
|
67
|
+
// LLM worker
|
|
65
68
|
const gerbil = await createGerbilWorker({
|
|
66
69
|
modelId: "qwen3-0.6b",
|
|
67
70
|
onToken: (token) => console.log(token.text),
|
|
68
71
|
});
|
|
72
|
+
|
|
73
|
+
// React hooks
|
|
74
|
+
const { messages, handleSubmit } = useChat();
|
|
75
|
+
const { speak, isSpeaking } = useSpeech();
|
|
76
|
+
const { startRecording, transcript } = useVoiceInput();
|
|
69
77
|
```
|
|
70
78
|
|
|
71
|
-
Uses
|
|
72
|
-
-
|
|
73
|
-
-
|
|
74
|
-
-
|
|
79
|
+
Uses inline Web Workers for:
|
|
80
|
+
- **LLM**: Model loading, token streaming, GPU memory management
|
|
81
|
+
- **TTS**: Kokoro/Supertonic speech synthesis
|
|
82
|
+
- **STT**: Whisper transcription
|
|
83
|
+
|
|
84
|
+
All workers load their dependencies from a CDN to avoid bundler issues with onnxruntime-web.
|
|
75
85
|
|
|
76
86
|
## Execution Paths
|
|
77
87
|
|
|
@@ -166,7 +176,7 @@ src/
|
|
|
166
176
|
│ ├── tools.ts # Tool calling system
|
|
167
177
|
│ └── chrome-backend.ts # Node.js WebGPU via Chrome
|
|
168
178
|
├── browser/
|
|
169
|
-
│ └── index.ts #
|
|
179
|
+
│ └── index.ts # React hooks + LLM/TTS/STT workers
|
|
170
180
|
├── skills/
|
|
171
181
|
│ └── ... # Built-in skills (commit, summarize, etc.)
|
|
172
182
|
├── integrations/
|