@tryhamster/gerbil 1.0.0-rc.9 → 1.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +1 -1
- package/README.md +318 -104
- package/dist/architectures-C1I5V3Dt.mjs +6070 -0
- package/dist/architectures-C1I5V3Dt.mjs.map +1 -0
- package/dist/browser/index.d.ts +276 -590
- package/dist/browser/index.d.ts.map +1 -1
- package/dist/browser/index.js +592 -2334
- package/dist/browser/index.js.map +1 -1
- package/dist/cli.mjs +625 -1098
- package/dist/cli.mjs.map +1 -1
- package/dist/defaults-9komdrbY.mjs +24 -0
- package/dist/defaults-9komdrbY.mjs.map +1 -0
- package/dist/frameworks/express.d.mts +1 -3
- package/dist/frameworks/express.d.mts.map +1 -1
- package/dist/frameworks/express.mjs +7 -7
- package/dist/frameworks/express.mjs.map +1 -1
- package/dist/frameworks/fastify.d.mts +1 -1
- package/dist/frameworks/fastify.d.mts.map +1 -1
- package/dist/frameworks/fastify.mjs +3 -3
- package/dist/frameworks/fastify.mjs.map +1 -1
- package/dist/frameworks/hono.d.mts +1 -1
- package/dist/frameworks/hono.d.mts.map +1 -1
- package/dist/frameworks/hono.mjs +4 -4
- package/dist/frameworks/hono.mjs.map +1 -1
- package/dist/frameworks/next.d.mts +3 -2
- package/dist/frameworks/next.d.mts.map +1 -1
- package/dist/frameworks/next.mjs +4 -4
- package/dist/frameworks/next.mjs.map +1 -1
- package/dist/frameworks/react.d.mts +1 -1
- package/dist/frameworks/trpc.d.mts +1 -1
- package/dist/frameworks/trpc.d.mts.map +1 -1
- package/dist/frameworks/trpc.mjs +4 -4
- package/dist/frameworks/trpc.mjs.map +1 -1
- package/dist/gerbil-BetB5xb0.d.mts +488 -0
- package/dist/gerbil-BetB5xb0.d.mts.map +1 -0
- package/dist/gerbil-CTZUa8EZ.mjs +4 -0
- package/dist/gerbil-DNniplr4.mjs +1656 -0
- package/dist/gerbil-DNniplr4.mjs.map +1 -0
- package/dist/gpu/hooks.d.mts +640 -0
- package/dist/gpu/hooks.d.mts.map +1 -0
- package/dist/gpu/hooks.mjs +1369 -0
- package/dist/gpu/hooks.mjs.map +1 -0
- package/dist/gpu/index.d.mts +2 -0
- package/dist/gpu/index.mjs +6 -0
- package/dist/gpu-DFuglcEx.mjs +3790 -0
- package/dist/gpu-DFuglcEx.mjs.map +1 -0
- package/dist/index-Dgmb2kE3.d.mts +245 -0
- package/dist/index-Dgmb2kE3.d.mts.map +1 -0
- package/dist/index-DukkJRMj.d.mts +2114 -0
- package/dist/index-DukkJRMj.d.mts.map +1 -0
- package/dist/index.d.mts +22 -487
- package/dist/index.d.mts.map +1 -1
- package/dist/index.mjs +13 -8
- package/dist/index.mjs.map +1 -1
- package/dist/indexeddb-store-BWIMtxxH.mjs +103 -0
- package/dist/indexeddb-store-BWIMtxxH.mjs.map +1 -0
- package/dist/indexeddb-store-ClH12Xnl.mjs +4 -0
- package/dist/integrations/ai-sdk.d.mts +75 -6
- package/dist/integrations/ai-sdk.d.mts.map +1 -1
- package/dist/integrations/ai-sdk.mjs +131 -15
- package/dist/integrations/ai-sdk.mjs.map +1 -1
- package/dist/integrations/langchain.d.mts +1 -1
- package/dist/integrations/langchain.d.mts.map +1 -1
- package/dist/integrations/langchain.mjs +5 -5
- package/dist/integrations/langchain.mjs.map +1 -1
- package/dist/integrations/llamaindex.d.mts +1 -1
- package/dist/integrations/llamaindex.d.mts.map +1 -1
- package/dist/integrations/llamaindex.mjs +5 -5
- package/dist/integrations/llamaindex.mjs.map +1 -1
- package/dist/integrations/mcp-client.mjs +3 -3
- package/dist/integrations/mcp-client.mjs.map +1 -1
- package/dist/integrations/mcp.d.mts +3 -2
- package/dist/integrations/mcp.d.mts.map +1 -1
- package/dist/integrations/mcp.mjs +5 -5
- package/dist/{mcp-BvbriaBy.mjs → mcp-D2vvH1Xc.mjs} +4 -4
- package/dist/mcp-D2vvH1Xc.mjs.map +1 -0
- package/dist/memory/index.d.mts +3 -0
- package/dist/memory/index.mjs +6 -0
- package/dist/memory-D1P7Tmda.mjs +4 -0
- package/dist/memory-DVN0MnIG.mjs +132 -0
- package/dist/memory-DVN0MnIG.mjs.map +1 -0
- package/dist/memory-Dj0J1v88.mjs +294 -0
- package/dist/memory-Dj0J1v88.mjs.map +1 -0
- package/dist/moonshine-stt-17dpP1kr.mjs +4 -0
- package/dist/moonshine-stt-4ojLtMq7.mjs +11962 -0
- package/dist/moonshine-stt-4ojLtMq7.mjs.map +1 -0
- package/dist/{one-liner-s-lD8rCC.mjs → one-liner-JhdIPxzF.mjs} +14 -16
- package/dist/one-liner-JhdIPxzF.mjs.map +1 -0
- package/dist/repl-BDRkwPGX.mjs +9 -0
- package/dist/skills/index.d.mts +270 -320
- package/dist/skills/index.d.mts.map +1 -1
- package/dist/skills/index.mjs +5 -5
- package/dist/{skills-CD3Orlex.mjs → skills-CU694Dc8.mjs} +187 -32
- package/dist/skills-CU694Dc8.mjs.map +1 -0
- package/dist/{tools-Bi1P7Xoy.mjs → tools-DQ1mPUw5.mjs} +34 -22
- package/dist/tools-DQ1mPUw5.mjs.map +1 -0
- package/dist/types-DQBe2lFo.d.mts +165 -0
- package/dist/types-DQBe2lFo.d.mts.map +1 -0
- package/dist/{types-CiTc7ez3.d.mts → types-LlyYILII.d.mts} +112 -14
- package/dist/types-LlyYILII.d.mts.map +1 -0
- package/dist/{utils-CZBZ8dgR.mjs → utils-DKO55ZmZ.mjs} +1 -1
- package/dist/{utils-CZBZ8dgR.mjs.map → utils-DKO55ZmZ.mjs.map} +1 -1
- package/dist/vector-B0panuy6.mjs +95 -0
- package/dist/vector-B0panuy6.mjs.map +1 -0
- package/docs/PROJECT-STATE.md +321 -0
- package/docs/adding-a-model-family.md +280 -0
- package/docs/ai-sdk.md +70 -61
- package/docs/architecture/overview.md +17 -7
- package/docs/browser.md +203 -8
- package/docs/embeddings.md +156 -0
- package/docs/gerbil-site-native-migration.md +217 -0
- package/docs/gpu-engine/architectures.md +398 -0
- package/docs/gpu-engine/ir.md +372 -0
- package/docs/gpu-engine/kernels.md +718 -0
- package/docs/gpu-engine/paper.html +1759 -0
- package/docs/gpu-engine/paper.md +2109 -0
- package/docs/gpu-engine/safetensors.md +312 -0
- package/docs/gpu-engine/tokenizer.md +302 -0
- package/docs/memory-rag.md +91 -0
- package/docs/metal-safari-intel.md +190 -0
- package/docs/mobile-failure-diagnosis.md +124 -0
- package/docs/mobile.md +99 -0
- package/docs/observability.md +230 -0
- package/docs/onnx-removal-plan.md +339 -0
- package/docs/research/autoresearch-portable.md +904 -0
- package/docs/research/dispatch-reduction-hivemind.md +84 -0
- package/docs/research/ios-safari-model-caching.md +117 -0
- package/docs/research/mobile-webgpu-speed-fusion.md +135 -0
- package/docs/research/native-stt-model-selection.md +49 -0
- package/docs/research/native-tts-model-selection.md +90 -0
- package/docs/research/native-vs-chromium-decision.md +152 -0
- package/docs/research/nemotron-mamba2-inference.md +910 -0
- package/docs/research/qwen35-multimodal.md +293 -0
- package/docs/research/qwen36-gemma4-targets.md +337 -0
- package/docs/research/sota-embedding-models.md +179 -0
- package/docs/research/sota-mobile-models-2026.md +263 -0
- package/docs/research/sota-modality-models.md +202 -0
- package/docs/research/tps-baselines.md +71 -0
- package/docs/research/webgpu-m4-reference.md +104 -0
- package/docs/site-update-plan.md +155 -0
- package/docs/structured-output.md +123 -0
- package/docs/stt.md +63 -446
- package/docs/tts.md +77 -499
- package/docs/vision.md +100 -338
- package/package.json +22 -7
- package/dist/chrome-backend-CORwaIyC.mjs +0 -1212
- package/dist/chrome-backend-CORwaIyC.mjs.map +0 -1
- package/dist/chrome-backend-DIKYoWj-.mjs +0 -3
- package/dist/gerbil-CJ3ifloF.mjs +0 -4
- package/dist/gerbil-Dw4Qj77e.mjs +0 -1631
- package/dist/gerbil-Dw4Qj77e.mjs.map +0 -1
- package/dist/gerbil-qOTe1nl2.d.mts +0 -431
- package/dist/gerbil-qOTe1nl2.d.mts.map +0 -1
- package/dist/kokoro-BNTb6egA.mjs +0 -20210
- package/dist/kokoro-BNTb6egA.mjs.map +0 -1
- package/dist/kokoro-CMOGDSgT.js +0 -20212
- package/dist/kokoro-CMOGDSgT.js.map +0 -1
- package/dist/mcp-BvbriaBy.mjs.map +0 -1
- package/dist/one-liner-s-lD8rCC.mjs.map +0 -1
- package/dist/repl-DveXw36T.mjs +0 -9
- package/dist/skills-CD3Orlex.mjs.map +0 -1
- package/dist/stt-Bu-E23Sc.js +0 -433
- package/dist/stt-Bu-E23Sc.js.map +0 -1
- package/dist/stt-CpLYbGFd.mjs +0 -433
- package/dist/stt-CpLYbGFd.mjs.map +0 -1
- package/dist/stt-DRPLEEHB.mjs +0 -3
- package/dist/tools-Bi1P7Xoy.mjs.map +0 -1
- package/dist/transformers.web-DiD1gTwk.js +0 -44695
- package/dist/transformers.web-DiD1gTwk.js.map +0 -1
- package/dist/transformers.web-u34VxRFM.js +0 -3
- package/dist/tts-CqroPaSK.js +0 -724
- package/dist/tts-CqroPaSK.js.map +0 -1
- package/dist/tts-DXgsKGCe.mjs +0 -3
- package/dist/tts-DeGANMNV.mjs +0 -730
- package/dist/tts-DeGANMNV.mjs.map +0 -1
- package/dist/types-CiTc7ez3.d.mts.map +0 -1
- /package/dist/{auto-update-S9s5-g0C.mjs → auto-update-BVaLXcDE.mjs} +0 -0
- /package/dist/{chunk-CkXuGtQK.mjs → chunk-B9cbKln6.mjs} +0 -0
- /package/dist/{microphone-DaMZFRuR.mjs → microphone-Bqmoz9_K.mjs} +0 -0
|
@@ -0,0 +1,398 @@
|
|
|
1
|
+
# Adding a New Model Architecture
|
|
2
|
+
|
|
3
|
+
This guide walks through how to add support for a new model family to the gerbil GPU engine. The reference implementation is the Qwen2 graph generator in `src/gpu/architectures/qwen2.ts`.
|
|
4
|
+
|
|
5
|
+
---
|
|
6
|
+
|
|
7
|
+
## How Architecture Support Works
|
|
8
|
+
|
|
9
|
+
The engine uses a **registry pattern**: HuggingFace models declare their architecture in `config.json` via the `architectures` field (e.g., `["Qwen2ForCausalLM"]`). The engine maps this string to a **graph generator function** that produces a complete `ModelGraph` (IR) from the config.
|
|
10
|
+
|
|
11
|
+
Adding a new model family requires:
|
|
12
|
+
1. Writing a graph generator function
|
|
13
|
+
2. Registering it in the architecture registry
|
|
14
|
+
3. (Optionally) writing a custom HF key mapper if the model uses non-standard weight naming
|
|
15
|
+
|
|
16
|
+
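No changes to the executor, kernels, or device layer are needed as long as the new architecture only uses op types the executor already implements -- the IR is the contract. If the model needs an op that has no kernel yet, that kernel must be added first.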
|
|
17
|
+
|
|
18
|
+
---
|
|
19
|
+
|
|
20
|
+
## Prerequisites
|
|
21
|
+
|
|
22
|
+
Before writing a generator, you need:
|
|
23
|
+
|
|
24
|
+
1. **A model's `config.json`**: Download from HuggingFace to understand the architecture dimensions
|
|
25
|
+
2. **Understanding of the model's layer structure**: Read the model's paper or reference implementation to know the computation graph
|
|
26
|
+
3. **A safetensors file**: To verify weight tensor names match your key mapping
|
|
27
|
+
|
|
28
|
+
---
|
|
29
|
+
|
|
30
|
+
## Step-by-Step Guide
|
|
31
|
+
|
|
32
|
+
### Step 1: Study the Config
|
|
33
|
+
|
|
34
|
+
Download `config.json` from the target model's HuggingFace repo. Example for LLaMA-3-8B:
|
|
35
|
+
|
|
36
|
+
```json
|
|
37
|
+
{
|
|
38
|
+
"architectures": ["LlamaForCausalLM"],
|
|
39
|
+
"hidden_size": 4096,
|
|
40
|
+
"intermediate_size": 14336,
|
|
41
|
+
"num_attention_heads": 32,
|
|
42
|
+
"num_hidden_layers": 32,
|
|
43
|
+
"num_key_value_heads": 8,
|
|
44
|
+
"rms_norm_eps": 1e-05,
|
|
45
|
+
"rope_theta": 500000.0,
|
|
46
|
+
"vocab_size": 128256,
|
|
47
|
+
"max_position_embeddings": 8192
|
|
48
|
+
}
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
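Map these to the `ModelArchConfig` fields. `head_dim` is read from `config.json` when present; otherwise it is derived as `hidden_size / num_attention_heads`.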
|
|
52
|
+
|
|
53
|
+
| Config Key | ModelArchConfig Field |
|
|
54
|
+
|-----------|---------------------|
|
|
55
|
+
| `hidden_size` | `hidden_size` |
|
|
56
|
+
| `num_hidden_layers` | `num_layers` |
|
|
57
|
+
| `num_attention_heads` | `num_heads` |
|
|
58
|
+
| `num_key_value_heads` | `num_kv_heads` |
|
|
59
|
+
| `intermediate_size` | `intermediate_size` |
|
|
60
|
+
| `vocab_size` | `vocab_size` |
|
|
61
|
+
| `max_position_embeddings` | `context_length` |
|
|
62
|
+
| `rms_norm_eps` | `rms_norm_eps` |
|
|
63
|
+
| `rope_theta` | `rope_base` |
|
|
64
|
+
|
|
65
|
+
### Step 2: Understand the Layer Structure
|
|
66
|
+
|
|
67
|
+
For LLaMA-family models, each transformer layer contains:
|
|
68
|
+
|
|
69
|
+
```
|
|
70
|
+
Input
|
|
71
|
+
-> RMSNorm (input_layernorm)
|
|
72
|
+
-> Q/K/V projections (3x MatMul)
|
|
73
|
+
-> RoPE (on Q and K)
|
|
74
|
+
-> GQA Attention
|
|
75
|
+
-> O projection (MatMul)
|
|
76
|
+
-> Residual Add
|
|
77
|
+
-> RMSNorm (post_attention_layernorm)
|
|
78
|
+
-> Gate projection (MatMul) ]
|
|
79
|
+
-> Up projection (MatMul) ] SwiGLU MLP
|
|
80
|
+
-> SiLU (on gate) ]
|
|
81
|
+
-> Mul (gate * up) ]
|
|
82
|
+
-> Down projection (MatMul) ]
|
|
83
|
+
-> Residual Add
|
|
84
|
+
Output
|
|
85
|
+
```
|
|
86
|
+
|
|
87
|
+
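This matches Qwen2's structure. The one notable difference is that Qwen2 adds bias terms to the Q/K/V projections and LLaMA does not (see Step 5). That is why the Qwen2 generator can often be reused for LLaMA-family models, as long as its bias handling is conditional.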
|
|
88
|
+
|
|
89
|
+
### Step 3: Check Weight Key Names
|
|
90
|
+
|
|
91
|
+
Download or inspect the safetensors file to see actual weight names. For LLaMA-3:
|
|
92
|
+
|
|
93
|
+
```
|
|
94
|
+
model.embed_tokens.weight
|
|
95
|
+
model.layers.0.input_layernorm.weight
|
|
96
|
+
model.layers.0.self_attn.q_proj.weight
|
|
97
|
+
model.layers.0.self_attn.k_proj.weight
|
|
98
|
+
model.layers.0.self_attn.v_proj.weight
|
|
99
|
+
model.layers.0.self_attn.o_proj.weight
|
|
100
|
+
model.layers.0.post_attention_layernorm.weight
|
|
101
|
+
model.layers.0.mlp.gate_proj.weight
|
|
102
|
+
model.layers.0.mlp.up_proj.weight
|
|
103
|
+
model.layers.0.mlp.down_proj.weight
|
|
104
|
+
model.norm.weight
|
|
105
|
+
lm_head.weight
|
|
106
|
+
```
|
|
107
|
+
|
|
108
|
+
After stripping the `model.` prefix, these are already in canonical form. The default HF key mapper handles this.
|
|
109
|
+
|
|
110
|
+
### Step 4: Write the Graph Generator
|
|
111
|
+
|
|
112
|
+
Create a new file, e.g., `src/gpu/architectures/llama.ts`:
|
|
113
|
+
|
|
114
|
+
```typescript
|
|
115
|
+
import type { ModelGraph, ModelArchConfig, ModelCapabilities, OpNode, TensorDesc } from "../ir.js";
|
|
116
|
+
import { CANONICAL_KEYS } from "../ir.js";
|
|
117
|
+
|
|
118
|
+
export function generateLlamaGraph(rawConfig: Record<string, unknown>): ModelGraph {
|
|
119
|
+
// Extract config values
|
|
120
|
+
const hidden_size = rawConfig.hidden_size as number;
|
|
121
|
+
const num_layers = rawConfig.num_hidden_layers as number;
|
|
122
|
+
const num_heads = rawConfig.num_attention_heads as number;
|
|
123
|
+
const num_kv_heads = (rawConfig.num_key_value_heads as number) ?? num_heads;
|
|
124
|
+
const intermediate_size = rawConfig.intermediate_size as number;
|
|
125
|
+
const vocab_size = rawConfig.vocab_size as number;
|
|
126
|
+
const context_length = (rawConfig.max_position_embeddings as number) ?? 8192;
|
|
127
|
+
const rms_norm_eps = (rawConfig.rms_norm_eps as number) ?? 1e-5;
|
|
128
|
+
const rope_base = (rawConfig.rope_theta as number) ?? 10000.0;
|
|
129
|
+
const head_dim = (rawConfig.head_dim as number) ?? Math.floor(hidden_size / num_heads);
|
|
130
|
+
const kv_dim = num_kv_heads * head_dim;
|
|
131
|
+
|
|
132
|
+
const config: ModelArchConfig = {
|
|
133
|
+
hidden_size,
|
|
134
|
+
num_layers,
|
|
135
|
+
num_heads,
|
|
136
|
+
num_kv_heads,
|
|
137
|
+
head_dim,
|
|
138
|
+
intermediate_size,
|
|
139
|
+
vocab_size,
|
|
140
|
+
context_length,
|
|
141
|
+
rms_norm_eps,
|
|
142
|
+
norm_type: "rmsnorm",
|
|
143
|
+
rope_base,
|
|
144
|
+
rope_dim: head_dim,
|
|
145
|
+
kv_layout: "LHSd",
|
|
146
|
+
is_moe: false,
|
|
147
|
+
has_vision_tower: false,
|
|
148
|
+
};
|
|
149
|
+
|
|
150
|
+
// ... (same tensor/node generation as Qwen2)
|
|
151
|
+
// For LLaMA, the structure is identical to Qwen2.
|
|
152
|
+
// The only differences are config values and default parameters.
|
|
153
|
+
}
|
|
154
|
+
```
|
|
155
|
+
|
|
156
|
+
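If the model uses the same layer structure as Qwen2 (as LLaMA does, apart from the Q/K/V biases), you can reuse `generateQwen2Graph` directly instead of writing a new generator. It reads all dimensions from the config, so it adapts automatically. Check that it only adds Q/K/V bias tensors when the checkpoint actually contains `*.bias` weights.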
|
|
157
|
+
|
|
158
|
+
### Step 5: Handle Architectural Differences
|
|
159
|
+
|
|
160
|
+
For models with different layer structures, you'll need custom node generation. Common differences:
|
|
161
|
+
|
|
162
|
+
#### Different MLP Structure (e.g., Phi)
|
|
163
|
+
|
|
164
|
+
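Phi-1 and Phi-2 models (`PhiForCausalLM`) use a two-layer GELU MLP with `fc1`/`fc2` instead of the SwiGLU `gate_proj`/`up_proj`/`down_proj`. Phi-3 uses SwiGLU; see the reference table below.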
|
|
165
|
+
|
|
166
|
+
```typescript
|
|
167
|
+
// Phi MLP: hidden -> fc1 (with GELU) -> fc2 -> hidden
|
|
168
|
+
// Instead of SwiGLU: hidden -> gate+up (with SiLU) -> down -> hidden
|
|
169
|
+
|
|
170
|
+
addNode({
|
|
171
|
+
id: `${prefix}_fc1`,
|
|
172
|
+
opType: "MatMul",
|
|
173
|
+
inputs: [norm2Out, fc1Weight],
|
|
174
|
+
outputs: [fc1Out],
|
|
175
|
+
attributes: { M_tensor: norm2Out, K: hidden_size, N: intermediate_size },
|
|
176
|
+
});
|
|
177
|
+
|
|
178
|
+
addNode({
|
|
179
|
+
id: `${prefix}_gelu`,
|
|
180
|
+
opType: "GELU",
|
|
181
|
+
inputs: [fc1Out],
|
|
182
|
+
outputs: [geluOut],
|
|
183
|
+
attributes: { count_tensor: fc1Out },
|
|
184
|
+
});
|
|
185
|
+
|
|
186
|
+
addNode({
|
|
187
|
+
id: `${prefix}_fc2`,
|
|
188
|
+
opType: "MatMul",
|
|
189
|
+
inputs: [geluOut, fc2Weight],
|
|
190
|
+
outputs: [fc2Out],
|
|
191
|
+
attributes: { M_tensor: geluOut, K: intermediate_size, N: hidden_size },
|
|
192
|
+
});
|
|
193
|
+
```
|
|
194
|
+
|
|
195
|
+
#### Different Normalization (LayerNorm vs RMSNorm)
|
|
196
|
+
|
|
197
|
+
Some models (GPT-2, older architectures) use LayerNorm instead of RMSNorm. Use the `LayerNorm` op type:
|
|
198
|
+
|
|
199
|
+
```typescript
|
|
200
|
+
addNode({
|
|
201
|
+
id: `${prefix}_norm`,
|
|
202
|
+
opType: "LayerNorm",
|
|
203
|
+
inputs: [prevOutput, normWeight, normBias], // LayerNorm has a bias term
|
|
204
|
+
outputs: [normOut],
|
|
205
|
+
attributes: { hidden_size, eps },
|
|
206
|
+
});
|
|
207
|
+
```
|
|
208
|
+
|
|
209
|
+
#### Q/K/V Bias
|
|
210
|
+
|
|
211
|
+
Some models (Qwen2, but not Qwen3 or LLaMA) add bias terms to the Q/K/V projections. When a bias is present, follow each projection's MatMul with an `Add` node:
|
|
212
|
+
|
|
213
|
+
```typescript
|
|
214
|
+
// After MatMul for Q projection:
|
|
215
|
+
addNode({
|
|
216
|
+
id: `${prefix}_q_bias`,
|
|
217
|
+
opType: "Add",
|
|
218
|
+
inputs: [qOut, qProjBias],
|
|
219
|
+
outputs: [qBiasOut],
|
|
220
|
+
attributes: { count_tensor: qOut },
|
|
221
|
+
});
|
|
222
|
+
```
|
|
223
|
+
|
|
224
|
+
### Step 6: Handle Non-Standard Weight Names
|
|
225
|
+
|
|
226
|
+
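If the model's weight names don't follow the canonical convention after stripping `model.`, write a custom key mapper. The example below targets Phi-1/Phi-2 naming. It maps `fc1` (a plain up-projection, not a gate) onto the `gate_proj` key, so the generator must consume that weight as the first layer of a GELU MLP, not as a SwiGLU gate: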
|
|
227
|
+
|
|
228
|
+
```typescript
|
|
229
|
+
// For a hypothetical model with unusual naming:
|
|
230
|
+
import type { HFKeyMapper } from "../ir.js";
|
|
231
|
+
|
|
232
|
+
export function createPhiKeyMapper(): HFKeyMapper {
|
|
233
|
+
return (hfKey: string): string | null => {
|
|
234
|
+
let key = hfKey;
|
|
235
|
+
if (key.startsWith("model.")) key = key.slice(6);
|
|
236
|
+
|
|
237
|
+
// Phi uses "fc1" and "fc2" instead of "gate_proj"/"up_proj"/"down_proj"
|
|
238
|
+
key = key.replace(/\.mlp\.fc1\./, ".mlp.gate_proj.");
|
|
239
|
+
key = key.replace(/\.mlp\.fc2\./, ".mlp.down_proj.");
|
|
240
|
+
|
|
241
|
+
// Phi uses "dense" instead of "o_proj"
|
|
242
|
+
key = key.replace(/\.self_attn\.dense\./, ".self_attn.o_proj.");
|
|
243
|
+
|
|
244
|
+
return key;
|
|
245
|
+
};
|
|
246
|
+
}
|
|
247
|
+
```
|
|
248
|
+
|
|
249
|
+
### Step 7: Register the Architecture
|
|
250
|
+
|
|
251
|
+
In `src/gpu/architectures/index.ts`, import and register:
|
|
252
|
+
|
|
253
|
+
```typescript
|
|
254
|
+
import { generateLlamaGraph } from "./llama.js";
|
|
255
|
+
|
|
256
|
+
export const ARCHITECTURES: Record<string, GraphGenerator> = {
|
|
257
|
+
// Existing
|
|
258
|
+
Qwen2ForCausalLM: generateQwen2Graph,
|
|
259
|
+
Qwen3ForCausalLM: generateQwen2Graph,
|
|
260
|
+
|
|
261
|
+
// New
|
|
262
|
+
LlamaForCausalLM: generateLlamaGraph,
|
|
263
|
+
MistralForCausalLM: generateLlamaGraph, // Same architecture as LLaMA
|
|
264
|
+
};
|
|
265
|
+
```
|
|
266
|
+
|
|
267
|
+
### Step 8: Test
|
|
268
|
+
|
|
269
|
+
```typescript
|
|
270
|
+
import { generateGraph } from "./architectures/index.js";
|
|
271
|
+
|
|
272
|
+
// Load a test config
|
|
273
|
+
const config = {
|
|
274
|
+
architectures: ["LlamaForCausalLM"],
|
|
275
|
+
hidden_size: 4096,
|
|
276
|
+
num_hidden_layers: 32,
|
|
277
|
+
num_attention_heads: 32,
|
|
278
|
+
num_key_value_heads: 8,
|
|
279
|
+
intermediate_size: 14336,
|
|
280
|
+
vocab_size: 128256,
|
|
281
|
+
max_position_embeddings: 8192,
|
|
282
|
+
rms_norm_eps: 1e-5,
|
|
283
|
+
rope_theta: 500000.0,
|
|
284
|
+
};
|
|
285
|
+
|
|
286
|
+
const graph = generateGraph("LlamaForCausalLM", config);
|
|
287
|
+
|
|
288
|
+
// Verify graph structure
|
|
289
|
+
console.log(`Architecture: ${graph.architecture}`);
|
|
290
|
+
console.log(`Nodes: ${graph.nodes.length}`);
|
|
291
|
+
console.log(`Tensors: ${Object.keys(graph.tensors).length}`);
|
|
292
|
+
console.log(`Execution order: ${graph.executionOrder.length} steps`);
|
|
293
|
+
|
|
294
|
+
// Check that all expected weight tensors exist
|
|
295
|
+
const weightTensors = Object.values(graph.tensors)
|
|
296
|
+
.filter(t => t.storage === "constant");
|
|
297
|
+
console.log(`Weight tensors: ${weightTensors.length}`);
|
|
298
|
+
|
|
299
|
+
// Verify execution order starts and ends correctly
|
|
300
|
+
console.log(`First op: ${graph.executionOrder[0]}`); // "embed"
|
|
301
|
+
console.log(`Last op: ${graph.executionOrder.at(-1)}`); // "lm_head"
|
|
302
|
+
```
|
|
303
|
+
|
|
304
|
+
---
|
|
305
|
+
|
|
306
|
+
## Architecture Checklist
|
|
307
|
+
|
|
308
|
+
When adding a new model family, verify each item:
|
|
309
|
+
|
|
310
|
+
- [ ] Config extraction: All required fields read from `config.json` with sensible defaults
|
|
311
|
+
- [ ] Architecture string: Matches exactly what HF `config.architectures[0]` contains
|
|
312
|
+
- [ ] Embedding: `Embedding` node with correct vocab_size and hidden_size
|
|
313
|
+
- [ ] Per-layer structure: Correct sequence of norm -> attention -> residual -> MLP -> residual
|
|
314
|
+
- [ ] Norm type: Using `RMSNorm` or `LayerNorm` as appropriate
|
|
315
|
+
- [ ] Q/K/V projections: Correct shapes (hidden_size -> num_heads * head_dim for Q, which equals hidden_size unless `head_dim` is set explicitly; hidden_size -> kv_dim for K/V)
|
|
316
|
+
- [ ] GQA: `num_kv_heads` correctly extracted (may differ from `num_heads`)
|
|
317
|
+
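- [ ] RoPE: Correct `head_dim`, `rope_base`, and `position_offset` handling. Also check `config.json` for a `rope_scaling` entry (e.g., LLaMA 3.1+), which changes the RoPE frequencies.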
|
|
318
|
+
- [ ] MLP structure: SwiGLU (gate/up/silu/mul/down) vs GELU (fc1/gelu/fc2)
|
|
319
|
+
- [ ] Residual connections: Add nodes connecting skip paths correctly
|
|
320
|
+
- [ ] Final norm: After all layers, before LM head
|
|
321
|
+
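- [ ] LM head: MatMul from hidden_size to vocab_size. If `tie_word_embeddings` is true, the checkpoint has no `lm_head.weight`, and the embedding matrix must be reused instead.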
|
|
322
|
+
- [ ] Biases: Included where the model uses them (check safetensors for `*.bias` keys)
|
|
323
|
+
- [ ] Key mapper: Weight names resolve correctly from safetensors to canonical names
|
|
324
|
+
- [ ] Input/output: Graph inputs = `["input_ids"]`, outputs = `["logits"]`
|
|
325
|
+
- [ ] Execution order: Topologically valid (each node's inputs are produced before it runs)
|
|
326
|
+
- [ ] KV cache tensors: Created for each layer with `storage: "kv_cache"`
|
|
327
|
+
|
|
328
|
+
---
|
|
329
|
+
|
|
330
|
+
## Model Family Reference
|
|
331
|
+
|
|
332
|
+
| Family | Architecture String | Layer Structure | MLP | Norm | Notes |
|
|
333
|
+
|--------|-------------------|----------------|-----|------|-------|
|
|
334
|
+
| Qwen2 | `Qwen2ForCausalLM` | Standard | SwiGLU | RMSNorm | Q/K/V bias |
|
|
335
|
+
| Qwen3 | `Qwen3ForCausalLM` | Standard | SwiGLU | RMSNorm | No Q/K/V bias; per-head RMSNorm on Q and K (`q_norm`/`k_norm`) |
|
|
336
|
+
| LLaMA 2 | `LlamaForCausalLM` | Standard | SwiGLU | RMSNorm | rope_theta=10000 |
|
|
337
|
+
| LLaMA 3 | `LlamaForCausalLM` | Standard | SwiGLU | RMSNorm | rope_theta=500000 |
|
|
338
|
+
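| Mistral | `MistralForCausalLM` | Standard | SwiGLU | RMSNorm | Sliding window attention when `sliding_window` is set in config; reusing the LLaMA generator is only exact when `sliding_window` is null or the context fits in the window |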
|
|
339
|
+
| Phi-3 | `Phi3ForCausalLM` | Standard | SwiGLU | RMSNorm | Fused QKV (`qkv_proj`) and fused `gate_up_proj` |
|
|
340
|
+
| GPT-2 | `GPT2LMHeadModel` | Different | GELU MLP | LayerNorm | Learned absolute position embeddings (no RoPE), fused `c_attn` QKV, biases on all projections |
|
|
341
|
+
| SmolLM | `LlamaForCausalLM` | Standard | SwiGLU | RMSNorm | Same as LLaMA, but with tied embeddings (`tie_word_embeddings: true`) |
|
|
342
|
+
|
|
343
|
+
"Standard" layer structure means: norm (`input_layernorm`) -> attention -> residual add -> norm (`post_attention_layernorm`, applied before the MLP) -> MLP -> residual add. Most modern models follow this pre-norm pattern.
|
|
344
|
+
|
|
345
|
+
---
|
|
346
|
+
|
|
347
|
+
## Common Pitfalls
|
|
348
|
+
|
|
349
|
+
### 1. Forgetting to Update prevOutput
|
|
350
|
+
|
|
351
|
+
Each layer must update the `prevOutput` variable to its final residual output. If you forget, the next layer reads the previous layer's output instead, silently dropping this layer's computation:
|
|
352
|
+
|
|
353
|
+
```typescript
|
|
354
|
+
// WRONG: prevOutput still points to previous layer
|
|
355
|
+
addNode({ id: `${prefix}_resid2`, ... outputs: [`${prefix}_resid2`], ... });
|
|
356
|
+
// prevOutput is not updated!
|
|
357
|
+
|
|
358
|
+
// RIGHT:
|
|
359
|
+
prevOutput = `${prefix}_resid2`;
|
|
360
|
+
```
|
|
361
|
+
|
|
362
|
+
### 2. Shape Mismatches in GQA
|
|
363
|
+
|
|
364
|
+
When `num_kv_heads < num_heads`, K and V projections output `kv_dim = num_kv_heads * head_dim`, not `hidden_size`. The attention kernel handles the head mapping internally, but the tensor shapes must be correct:
|
|
365
|
+
|
|
366
|
+
```typescript
|
|
367
|
+
// Q: [T, hidden_size] = [T, num_heads * head_dim]
|
|
368
|
+
// K: [T, kv_dim] = [T, num_kv_heads * head_dim] <-- NOT hidden_size
|
|
369
|
+
// V: [T, kv_dim]
|
|
370
|
+
```
|
|
371
|
+
|
|
372
|
+
### 3. Missing Safetensors Key Assignment
|
|
373
|
+
|
|
374
|
+
Weight tensors must have `safetensorsKey` set, or the model loader won't know to upload them:
|
|
375
|
+
|
|
376
|
+
```typescript
|
|
377
|
+
addTensor({
|
|
378
|
+
name: CANONICAL_KEYS.qProj(i),
|
|
379
|
+
shape: [hidden_size, hidden_size],
|
|
380
|
+
dtype: "f32",
|
|
381
|
+
storage: "constant",
|
|
382
|
+
safetensorsKey: CANONICAL_KEYS.qProj(i), // Don't forget this!
|
|
383
|
+
});
|
|
384
|
+
```
|
|
385
|
+
|
|
386
|
+
### 4. Activation Tensors with safetensorsKey
|
|
387
|
+
|
|
388
|
+
Activation tensors should NOT have `safetensorsKey`. They're computed during the forward pass, not loaded from files:
|
|
389
|
+
|
|
390
|
+
```typescript
|
|
391
|
+
addTensor({
|
|
392
|
+
name: qOut,
|
|
393
|
+
shape: ["T", hidden_size],
|
|
394
|
+
dtype: "f32",
|
|
395
|
+
storage: "activation",
|
|
396
|
+
// NO safetensorsKey here
|
|
397
|
+
});
|
|
398
|
+
```
|