mambacode.js 1.0.0 → 1.0.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +198 -76
- package/dist/index.d.ts +19 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +18 -0
- package/dist/index.js.map +1 -0
- package/dist/kernels/activations.d.ts +3 -0
- package/dist/kernels/activations.d.ts.map +1 -0
- package/dist/kernels/activations.js +87 -0
- package/dist/kernels/activations.js.map +1 -0
- package/dist/kernels/conv1d.d.ts +3 -0
- package/dist/kernels/conv1d.d.ts.map +1 -0
- package/dist/kernels/conv1d.js +152 -0
- package/dist/kernels/conv1d.js.map +1 -0
- package/dist/kernels/linear_projection.d.ts +3 -0
- package/dist/kernels/linear_projection.d.ts.map +1 -0
- package/dist/kernels/linear_projection.js +219 -0
- package/dist/kernels/linear_projection.js.map +1 -0
- package/dist/kernels/selective_scan.d.ts +3 -0
- package/dist/kernels/selective_scan.d.ts.map +1 -0
- package/dist/kernels/selective_scan.js +348 -0
- package/dist/kernels/selective_scan.js.map +1 -0
- package/dist/kernels/weight_update.d.ts +3 -0
- package/dist/kernels/weight_update.d.ts.map +1 -0
- package/dist/kernels/weight_update.js +119 -0
- package/dist/kernels/weight_update.js.map +1 -0
- package/dist/model/mamba_block.d.ts +64 -0
- package/dist/model/mamba_block.d.ts.map +1 -0
- package/dist/model/mamba_block.js +309 -0
- package/dist/model/mamba_block.js.map +1 -0
- package/dist/model/mamba_model.d.ts +66 -0
- package/dist/model/mamba_model.d.ts.map +1 -0
- package/dist/model/mamba_model.js +289 -0
- package/dist/model/mamba_model.js.map +1 -0
- package/dist/tokenizer/bpe.d.ts +29 -0
- package/dist/tokenizer/bpe.d.ts.map +1 -0
- package/dist/tokenizer/bpe.js +164 -0
- package/dist/tokenizer/bpe.js.map +1 -0
- package/dist/training/autograd.d.ts +27 -0
- package/dist/training/autograd.d.ts.map +1 -0
- package/dist/training/autograd.js +120 -0
- package/dist/training/autograd.js.map +1 -0
- package/dist/training/trainer.d.ts +37 -0
- package/dist/training/trainer.d.ts.map +1 -0
- package/dist/training/trainer.js +183 -0
- package/dist/training/trainer.js.map +1 -0
- package/dist/utils/gpu_utils.d.ts +21 -0
- package/dist/utils/gpu_utils.d.ts.map +1 -0
- package/dist/utils/gpu_utils.js +111 -0
- package/dist/utils/gpu_utils.js.map +1 -0
- package/dist/utils/quantization.d.ts +26 -0
- package/dist/utils/quantization.d.ts.map +1 -0
- package/dist/utils/quantization.js +116 -0
- package/dist/utils/quantization.js.map +1 -0
- package/package.json +43 -18
- package/src/index.ts +61 -0
- package/src/kernels/{activations.js → activations.ts} +2 -2
- package/src/kernels/{linear_projection.js → linear_projection.ts} +2 -2
- package/src/kernels/{selective_scan.js → selective_scan.ts} +2 -2
- package/src/kernels/{weight_update.js → weight_update.ts} +2 -2
- package/src/model/{mamba_block.js → mamba_block.ts} +134 -170
- package/src/model/{mamba_model.js → mamba_model.ts} +165 -121
- package/src/tokenizer/bpe.ts +186 -0
- package/src/training/autograd.ts +135 -0
- package/src/training/{trainer.js → trainer.ts} +79 -161
- package/src/utils/gpu_utils.ts +147 -0
- package/src/utils/quantization.ts +154 -0
- package/src/index.js +0 -89
- package/src/tokenizer/bpe.js +0 -256
- package/src/training/autograd.js +0 -221
- package/src/utils/gpu_utils.js +0 -217
- package/src/utils/quantization.js +0 -215
- /package/src/kernels/{conv1d.js → conv1d.ts} +0 -0
package/README.md
CHANGED
|
@@ -1,14 +1,22 @@
|
|
|
1
|
-
#
|
|
2
|
-
MambaCode.js — WebGPU-accelerated Mamba SSM library for browser-based code model training and inference.
|
|
1
|
+
# MambaCode.js
|
|
3
2
|
|
|
4
|
-
|
|
3
|
+
> WebGPU-accelerated Mamba State Space Model library — written in **TypeScript**, compiled for use in any JavaScript application.
|
|
5
4
|
|
|
6
|
-
|
|
5
|
+
[![npm version](https://img.shields.io/npm/v/mambacode.js)](https://www.npmjs.com/package/mambacode.js)
|
|
6
|
+
[![license: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](./LICENSE)
|
|
7
7
|
|
|
8
|
-
|
|
8
|
+
MambaCode.js is a **TypeScript-first** library that brings the [Mamba SSM](https://arxiv.org/abs/2312.00752) architecture to the browser via WebGPU. It targets the Qwen3.5-Coder-0.8B model shape and supports full **on-device training** (backpropagation), allowing models to adapt to a user's private codebase locally — without any data leaving the browser.
|
|
9
|
+
|
|
10
|
+
> 📖 **New to MambaCode.js?** Start with the [Getting Started Guide](./docs/getting-started.md).
|
|
11
|
+
|
|
12
|
+
---
|
|
13
|
+
|
|
14
|
+
## Key Features
|
|
9
15
|
|
|
10
16
|
| Feature | Detail |
|
|
11
17
|
|---|---|
|
|
18
|
+
| **TypeScript-first** | Full type declarations shipped with the package |
|
|
19
|
+
| **Plain JS compatible** | Import the compiled `dist/` in any JavaScript project — no TypeScript toolchain required |
|
|
12
20
|
| **Architecture** | Selective State Space Model (S6) — linear O(N) context scaling |
|
|
13
21
|
| **Hardware target** | WebGPU (WGSL) — Chrome 113+, Edge 113+, Firefox Nightly |
|
|
14
22
|
| **Memory ceiling** | ≤ 3 GB VRAM (Chrome/Edge/Firefox stable) |
|
|
@@ -20,91 +28,138 @@ MambaCode.js is a pure JavaScript/WGSL implementation of the **Mamba State Space
|
|
|
20
28
|
|
|
21
29
|
---
|
|
22
30
|
|
|
23
|
-
##
|
|
31
|
+
## Installation
|
|
24
32
|
|
|
33
|
+
```bash
|
|
34
|
+
npm install mambacode.js
|
|
25
35
|
```
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
▼ ┌─────────────────────────────────────────┐
|
|
32
|
-
│ Mamba Block × N │
|
|
33
|
-
│ │
|
|
34
|
-
│ Input ──► RMSNorm │
|
|
35
|
-
│ │ │
|
|
36
|
-
│ ┌────────┴────────┐ │
|
|
37
|
-
│ ▼ ▼ │
|
|
38
|
-
│ in_proj(x) in_proj(z) [gate] │
|
|
39
|
-
│ │ │
|
|
40
|
-
│ Conv1D (causal, K=4) │
|
|
41
|
-
│ │ │
|
|
42
|
-
│ SiLU activation │
|
|
43
|
-
│ │ │
|
|
44
|
-
│ x_proj → Δ, B, C (selective) │
|
|
45
|
-
│ │ │
|
|
46
|
-
│ Δ → dt_proj (full D_inner width) │
|
|
47
|
-
│ │ │
|
|
48
|
-
│ ┌───▼──────────────────────────────┐ │
|
|
49
|
-
│ │ Selective Scan S6 │ │
|
|
50
|
-
│ │ (Kogge-Stone parallel prefix) │ │
|
|
51
|
-
│ │ h_t = Ā·h_{t-1} + B̄·x_t │ │
|
|
52
|
-
│ │ y_t = C·h_t + D·x_t │ │
|
|
53
|
-
│ └──────────────────────────────────┘ │
|
|
54
|
-
│ │ │
|
|
55
|
-
│ Gate: y * SiLU(z) │
|
|
56
|
-
│ │ │
|
|
57
|
-
│ out_proj → residual add ──► output │
|
|
58
|
-
└─────────────────────────────────────────┘
|
|
59
|
-
│
|
|
60
|
-
▼
|
|
61
|
-
Final RMSNorm → LM Head (tied embedding) → Logits
|
|
36
|
+
|
|
37
|
+
Build the library from source:
|
|
38
|
+
|
|
39
|
+
```bash
|
|
40
|
+
npm run build # compiles TypeScript → dist/
|
|
62
41
|
```
|
|
63
42
|
|
|
64
43
|
---
|
|
65
44
|
|
|
45
|
+
## Documentation
|
|
46
|
+
|
|
47
|
+
| Guide | Description |
|
|
48
|
+
|---|---|
|
|
49
|
+
| **[Getting Started](docs/getting-started.md)** | Beginner-friendly introduction — what LLMs are, how Qwen fits in, the full model lifecycle, and what to do next |
|
|
50
|
+
| **[Integration & Architecture](docs/integration-architecture.md)** | Production architecture guide — embedding Mamba as a unified brain + memory system, integration patterns, advanced use cases, and design tradeoffs |
|
|
51
|
+
| **[Weight Lifecycle](docs/weight-lifecycle.md)** | Complete guide to obtaining Qwen vocabulary files, loading pre-trained checkpoints, fine-tuning, exporting weights, and sharing with your team |
|
|
52
|
+
| **[API Reference](docs/api-reference.md)** | Full technical reference — every exported class, interface, and function with TypeScript and JavaScript examples |
|
|
53
|
+
| **[MambaKit PRD](docs/mamba-kit-prd.md)** | Product requirements document for MambaKit — an opinionated, zero-boilerplate facade over MambaCode.js |
|
|
54
|
+
|
|
55
|
+
---
|
|
56
|
+
|
|
66
57
|
## Quick Start
|
|
67
58
|
|
|
68
|
-
|
|
69
|
-
|
|
59
|
+
### TypeScript
|
|
60
|
+
|
|
61
|
+
```ts
|
|
62
|
+
import {
|
|
63
|
+
MambaModel,
|
|
64
|
+
MambaTrainer,
|
|
65
|
+
BPETokenizer,
|
|
66
|
+
initWebGPU,
|
|
67
|
+
type MambaModelConfig,
|
|
68
|
+
type TrainOptions,
|
|
69
|
+
} from 'mambacode.js';
|
|
70
70
|
|
|
71
71
|
// 1. Initialise WebGPU
|
|
72
72
|
const { device } = await initWebGPU();
|
|
73
73
|
|
|
74
|
-
// 2. Load tokenizer
|
|
74
|
+
// 2. Load tokenizer (vocab.json + merges.txt from Qwen3.5-Coder)
|
|
75
75
|
const tokenizer = new BPETokenizer();
|
|
76
76
|
await tokenizer.load('/vocab.json', '/merges.txt');
|
|
77
77
|
|
|
78
78
|
// 3. Create model
|
|
79
|
-
const
|
|
80
|
-
vocabSize : tokenizer.vocabSize, //
|
|
79
|
+
const config: MambaModelConfig = {
|
|
80
|
+
vocabSize : tokenizer.vocabSize, // 151936 for Qwen3.5-Coder
|
|
81
81
|
dModel : 512,
|
|
82
82
|
numLayers : 8,
|
|
83
83
|
dState : 16,
|
|
84
84
|
dConv : 4,
|
|
85
85
|
expand : 2,
|
|
86
|
+
};
|
|
87
|
+
const model = new MambaModel(device, config);
|
|
88
|
+
|
|
89
|
+
// 4. Load a pre-trained checkpoint
|
|
90
|
+
const response = await fetch('/models/mamba-coder-checkpoint.bin');
|
|
91
|
+
await model.loadWeights(await response.arrayBuffer());
|
|
92
|
+
|
|
93
|
+
// 5. Fine-tune on local code
|
|
94
|
+
const trainer = new MambaTrainer(model, tokenizer);
|
|
95
|
+
const opts: TrainOptions = {
|
|
96
|
+
learningRate : 1e-4,
|
|
97
|
+
epochs : 5,
|
|
98
|
+
onEpochEnd : (epoch, loss) => console.log(`Epoch ${epoch}: loss=${loss.toFixed(4)}`),
|
|
99
|
+
};
|
|
100
|
+
const losses = await trainer.train(myCodeString, opts);
|
|
101
|
+
|
|
102
|
+
// 6. Generate code
|
|
103
|
+
const promptIds = tokenizer.encode('function fibonacci(');
|
|
104
|
+
const outputIds = await model.generate(promptIds, 200, { temperature: 0.8 });
|
|
105
|
+
console.log(tokenizer.decode(outputIds));
|
|
106
|
+
|
|
107
|
+
// 7. Save fine-tuned weights for next session
|
|
108
|
+
const checkpoint = await model.exportWeights();
|
|
109
|
+
```
|
|
110
|
+
|
|
111
|
+
### JavaScript (ESM)
|
|
112
|
+
|
|
113
|
+
The compiled output in `dist/` is plain JavaScript with no TypeScript runtime dependency:
|
|
114
|
+
|
|
115
|
+
```js
|
|
116
|
+
import {
|
|
117
|
+
MambaModel,
|
|
118
|
+
MambaTrainer,
|
|
119
|
+
BPETokenizer,
|
|
120
|
+
initWebGPU,
|
|
121
|
+
} from 'mambacode.js';
|
|
122
|
+
|
|
123
|
+
// 1. Initialise WebGPU
|
|
124
|
+
const { device } = await initWebGPU();
|
|
125
|
+
|
|
126
|
+
// 2. Load tokenizer (vocab.json + merges.txt from Qwen3.5-Coder)
|
|
127
|
+
const tokenizer = new BPETokenizer();
|
|
128
|
+
await tokenizer.load('/vocab.json', '/merges.txt');
|
|
129
|
+
|
|
130
|
+
// 3. Create model
|
|
131
|
+
const model = new MambaModel(device, {
|
|
132
|
+
vocabSize : tokenizer.vocabSize,
|
|
133
|
+
dModel : 512,
|
|
134
|
+
numLayers : 8,
|
|
86
135
|
});
|
|
87
136
|
|
|
88
|
-
// 4.
|
|
137
|
+
// 4. Load a pre-trained checkpoint
|
|
138
|
+
const response = await fetch('/models/mamba-coder-checkpoint.bin');
|
|
139
|
+
await model.loadWeights(await response.arrayBuffer());
|
|
140
|
+
|
|
141
|
+
// 5. Fine-tune on local code
|
|
89
142
|
const trainer = new MambaTrainer(model, tokenizer);
|
|
90
|
-
const losses
|
|
143
|
+
const losses = await trainer.train(myCodeString, {
|
|
91
144
|
learningRate : 1e-4,
|
|
92
145
|
epochs : 5,
|
|
93
|
-
device : 'webgpu',
|
|
94
146
|
onEpochEnd : (epoch, loss) => console.log(`Epoch ${epoch}: loss=${loss.toFixed(4)}`),
|
|
95
147
|
});
|
|
96
148
|
|
|
97
|
-
//
|
|
149
|
+
// 6. Generate code
|
|
98
150
|
const promptIds = tokenizer.encode('function fibonacci(');
|
|
99
151
|
const outputIds = await model.generate(promptIds, 200, { temperature: 0.8 });
|
|
100
152
|
console.log(tokenizer.decode(outputIds));
|
|
153
|
+
|
|
154
|
+
// 7. Save fine-tuned weights for next session
|
|
155
|
+
const checkpoint = await model.exportWeights();
|
|
101
156
|
```
|
|
102
157
|
|
|
103
158
|
### WSLA (Weight-Selective Local Adaptation)
|
|
104
159
|
|
|
105
160
|
Fine-tune only the B and C matrices for rapid private-codebase adaptation:
|
|
106
161
|
|
|
107
|
-
```
|
|
162
|
+
```ts
|
|
108
163
|
await trainer.train(apiUsageExamples, {
|
|
109
164
|
learningRate : 1e-4,
|
|
110
165
|
epochs : 3,
|
|
@@ -114,40 +169,95 @@ await trainer.train(apiUsageExamples, {
|
|
|
114
169
|
|
|
115
170
|
---
|
|
116
171
|
|
|
172
|
+
## Architecture
|
|
173
|
+
|
|
174
|
+
```
|
|
175
|
+
Token IDs
|
|
176
|
+
│
|
|
177
|
+
▼
|
|
178
|
+
Embedding Lookup (GPU gather kernel)
|
|
179
|
+
│
|
|
180
|
+
▼ ┌─────────────────────────────────────────┐
|
|
181
|
+
│ Mamba Block × N │
|
|
182
|
+
│ │
|
|
183
|
+
│ Input ──► RMSNorm │
|
|
184
|
+
│ │ │
|
|
185
|
+
│ ┌────────┴────────┐ │
|
|
186
|
+
│ ▼ ▼ │
|
|
187
|
+
│ in_proj(x) in_proj(z) [gate] │
|
|
188
|
+
│ │ │
|
|
189
|
+
│ Conv1D (causal, K=4) │
|
|
190
|
+
│ │ │
|
|
191
|
+
│ SiLU activation │
|
|
192
|
+
│ │ │
|
|
193
|
+
│ x_proj → Δ, B, C (selective) │
|
|
194
|
+
│ │ │
|
|
195
|
+
│ Δ → dt_proj (full D_inner width) │
|
|
196
|
+
│ │ │
|
|
197
|
+
│ ┌───▼──────────────────────────────┐ │
|
|
198
|
+
│ │ Selective Scan S6 │ │
|
|
199
|
+
│ │ (Kogge-Stone parallel prefix) │ │
|
|
200
|
+
│ │ h_t = Ā·h_{t-1} + B̄·x_t │ │
|
|
201
|
+
│ │ y_t = C·h_t + D·x_t │ │
|
|
202
|
+
│ └──────────────────────────────────┘ │
|
|
203
|
+
│ │ │
|
|
204
|
+
│ Gate: y * SiLU(z) │
|
|
205
|
+
│ │ │
|
|
206
|
+
│ out_proj → residual add ──► output │
|
|
207
|
+
└─────────────────────────────────────────┘
|
|
208
|
+
│
|
|
209
|
+
▼
|
|
210
|
+
Final RMSNorm → LM Head (tied embedding) → Logits
|
|
211
|
+
```
|
|
212
|
+
|
|
213
|
+
---
|
|
214
|
+
|
|
117
215
|
## File Structure
|
|
118
216
|
|
|
119
217
|
```
|
|
120
|
-
src/
|
|
121
|
-
├── index.
|
|
218
|
+
src/ ← TypeScript source (edit here)
|
|
219
|
+
├── index.ts ← public API entry point
|
|
122
220
|
├── kernels/
|
|
123
|
-
│ ├── selective_scan.
|
|
124
|
-
│ ├── conv1d.
|
|
125
|
-
│ ├── linear_projection.
|
|
126
|
-
│ ├── weight_update.
|
|
127
|
-
│ └── activations.
|
|
221
|
+
│ ├── selective_scan.ts ← WGSL: S6 forward + backward (Kogge-Stone)
|
|
222
|
+
│ ├── conv1d.ts ← WGSL: 1D causal convolution
|
|
223
|
+
│ ├── linear_projection.ts ← WGSL: tiled matrix multiplication
|
|
224
|
+
│ ├── weight_update.ts ← WGSL: AdamW optimizer + gradient clipping
|
|
225
|
+
│ └── activations.ts ← WGSL: SiLU, RMSNorm
|
|
128
226
|
├── model/
|
|
129
|
-
│ ├── mamba_block.
|
|
130
|
-
│ └── mamba_model.
|
|
227
|
+
│ ├── mamba_block.ts ← Mamba Mixer Block (forward pass)
|
|
228
|
+
│ └── mamba_model.ts ← Full stacked model + generation
|
|
131
229
|
├── training/
|
|
132
|
-
│ ├── autograd.
|
|
133
|
-
│ └── trainer.
|
|
230
|
+
│ ├── autograd.ts ← Tape-based AD engine + loss helpers
|
|
231
|
+
│ └── trainer.ts ← MambaTrainer class
|
|
134
232
|
├── tokenizer/
|
|
135
|
-
│ └── bpe.
|
|
233
|
+
│ └── bpe.ts ← Browser-side BPE tokenizer
|
|
136
234
|
└── utils/
|
|
137
|
-
├── gpu_utils.
|
|
138
|
-
└── quantization.
|
|
235
|
+
├── gpu_utils.ts ← WebGPU device/buffer management
|
|
236
|
+
└── quantization.ts ← FP16 / Int8 quantization utilities
|
|
237
|
+
|
|
238
|
+
dist/ ← Compiled output (JS + .d.ts, gitignored)
|
|
239
|
+
├── index.js ← ESM entry point for JS consumers
|
|
240
|
+
├── index.d.ts ← TypeScript declarations for TS consumers
|
|
241
|
+
└── ... ← mirrored sub-folders
|
|
242
|
+
|
|
139
243
|
tests/
|
|
140
|
-
├── kernels.test.
|
|
141
|
-
├── autograd.test.
|
|
142
|
-
├── bpe.test.
|
|
143
|
-
└── quantization.test.
|
|
244
|
+
├── kernels.test.ts ← WGSL kernel source smoke tests
|
|
245
|
+
├── autograd.test.ts ← Autograd engine unit tests
|
|
246
|
+
├── bpe.test.ts ← BPE tokenizer unit tests
|
|
247
|
+
└── quantization.test.ts ← Quantization round-trip tests
|
|
248
|
+
|
|
249
|
+
docs/
|
|
250
|
+
├── getting-started.md ← Step-by-step guide (TS & JS)
|
|
251
|
+
├── integration-architecture.md ← Brain + Memory architecture guide
|
|
252
|
+
├── weight-lifecycle.md ← Weight loading, fine-tuning, export
|
|
253
|
+
└── api-reference.md ← Full API reference
|
|
144
254
|
```
|
|
145
255
|
|
|
146
256
|
---
|
|
147
257
|
|
|
148
258
|
## WGSL Kernels
|
|
149
259
|
|
|
150
|
-
### Parallel Selective Scan (`selective_scan.
|
|
260
|
+
### Parallel Selective Scan (`selective_scan.ts`)
|
|
151
261
|
Implements the S6 core using a **Kogge-Stone parallel prefix-sum** inside each workgroup tile. Each tile of 64 time steps is scanned in log₂(64) = 6 GPU barrier rounds, giving O(log N) wall-clock time on the GPU.
|
|
152
262
|
|
|
153
263
|
The associative operator for the recurrence `h_t = Ā·h_{t-1} + B̄·x_t` is:
|
|
@@ -158,13 +268,13 @@ The associative operator for the recurrence `h_t = Ā·h_{t-1} + B̄·x_t` is:
|
|
|
158
268
|
|
|
159
269
|
Tiles are chained via a carry-in state, covering arbitrarily long sequences.
|
|
160
270
|
|
|
161
|
-
### 1D Causal Convolution (`conv1d.
|
|
271
|
+
### 1D Causal Convolution (`conv1d.ts`)
|
|
162
272
|
Depthwise 1D causal conv (kernel size K=4) with zero left-padding. Enforces causality by only reading positions `t-k` for `k ≥ 0`, contributing 0 for `t < k`.
|
|
163
273
|
|
|
164
|
-
### Linear Projection (`linear_projection.
|
|
274
|
+
### Linear Projection (`linear_projection.ts`)
|
|
165
275
|
Tiled 16×16 GEMM in WGSL using workgroup shared memory. Handles arbitrary (M, K) × (N, K) → (M, N) shapes with boundary guards.
|
|
166
276
|
|
|
167
|
-
### AdamW Optimizer (`weight_update.
|
|
277
|
+
### AdamW Optimizer (`weight_update.ts`)
|
|
168
278
|
Fused single-kernel AdamW update with decoupled weight decay. Includes a two-pass gradient norm clipping kernel (reduce → scale).
|
|
169
279
|
|
|
170
280
|
---
|
|
@@ -172,10 +282,12 @@ Fused single-kernel AdamW update with decoupled weight decay. Includes a two-pas
|
|
|
172
282
|
## Testing
|
|
173
283
|
|
|
174
284
|
```bash
|
|
175
|
-
npm test
|
|
285
|
+
npm test # run 58 unit tests (no GPU required)
|
|
286
|
+
npm run build # compile TypeScript → dist/
|
|
287
|
+
npm run lint # ESLint on src/ and tests/
|
|
176
288
|
```
|
|
177
289
|
|
|
178
|
-
|
|
290
|
+
Unit tests cover quantization, BPE tokenization, autograd, and WGSL kernel source validation. GPU execution tests require a real browser with WebGPU support.
|
|
179
291
|
|
|
180
292
|
---
|
|
181
293
|
|
|
@@ -191,6 +303,16 @@ Runs 58 unit tests covering quantization, BPE tokenization, autograd, and WGSL k
|
|
|
191
303
|
|
|
192
304
|
---
|
|
193
305
|
|
|
306
|
+
## Acknowledgements
|
|
307
|
+
|
|
308
|
+
This library builds on the Mamba selective state space model research. Special credit to:
|
|
309
|
+
|
|
310
|
+
- **Mamba-2** — Tri Dao's blog series [*State Space Duality (Mamba-2), Part 1*](https://tridao.me/blog/2024/mamba2-part1/) (2024), which describes the latest architectural refinements.
|
|
311
|
+
- **Mamba: The Hard Way** — [Sasha Rush's annotated walkthrough](https://srush.github.io/annotated-mamba/hard.html) of the Mamba selective-scan kernels, which informed this implementation.
|
|
312
|
+
- Original **Mamba SSM** paper — [*Mamba: Linear-Time Sequence Modeling with Selective State Spaces* (arXiv 2312.00752)](https://arxiv.org/abs/2312.00752) by Gu & Dao (2023).
|
|
313
|
+
|
|
314
|
+
---
|
|
315
|
+
|
|
194
316
|
## License
|
|
195
317
|
|
|
196
318
|
MIT
|
package/dist/index.d.ts
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* MambaCode.js – Entry Point
|
|
3
|
+
*/
|
|
4
|
+
export { MambaModel } from './model/mamba_model.js';
|
|
5
|
+
export { MambaBlock } from './model/mamba_block.js';
|
|
6
|
+
export { MambaTrainer } from './training/trainer.js';
|
|
7
|
+
export { Tensor, backward, enableGrad, noGrad, clearTape, recordOperation, crossEntropyLoss, crossEntropyGrad, } from './training/autograd.js';
|
|
8
|
+
export { BPETokenizer } from './tokenizer/bpe.js';
|
|
9
|
+
export type { MambaModelConfig, SamplingOptions } from './model/mamba_model.js';
|
|
10
|
+
export { initWebGPU, createStorageBuffer, createEmptyStorageBuffer, createUniformBuffer, createComputePipeline, createBindGroup, dispatchKernel, readBuffer, uploadBuffer, cdiv, } from './utils/gpu_utils.js';
|
|
11
|
+
export { quantizeFp16, dequantizeFp16, floatToFp16, fp16ToFloat, quantizeInt8, dequantizeInt8, quantizeInt8PerChannel, dequantizeInt8PerChannel, estimateMemory, } from './utils/quantization.js';
|
|
12
|
+
export { SELECTIVE_SCAN_FORWARD_WGSL, SELECTIVE_SCAN_BACKWARD_WGSL } from './kernels/selective_scan.js';
|
|
13
|
+
export { CONV1D_FORWARD_WGSL, CONV1D_BACKWARD_WGSL } from './kernels/conv1d.js';
|
|
14
|
+
export { LINEAR_FORWARD_WGSL, LINEAR_BACKWARD_WGSL } from './kernels/linear_projection.js';
|
|
15
|
+
export { WEIGHT_UPDATE_WGSL, GRAD_CLIP_WGSL } from './kernels/weight_update.js';
|
|
16
|
+
export { ACTIVATIONS_WGSL, ACTIVATIONS_BACKWARD_WGSL } from './kernels/activations.js';
|
|
17
|
+
export declare const VERSION = "1.0.2";
|
|
18
|
+
export declare const DESCRIPTION = "MambaCode.js: WebGPU-accelerated Mamba SSM for browser code models";
|
|
19
|
+
//# sourceMappingURL=index.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA;;GAEG;AAEH,OAAO,EAAE,UAAU,EAAE,MAAQ,wBAAwB,CAAC;AACtD,OAAO,EAAE,UAAU,EAAE,MAAQ,wBAAwB,CAAC;AAEtD,OAAO,EAAE,YAAY,EAAE,MAAM,uBAAuB,CAAC;AACrD,OAAO,EACH,MAAM,EACN,QAAQ,EACR,UAAU,EACV,MAAM,EACN,SAAS,EACT,eAAe,EACf,gBAAgB,EAChB,gBAAgB,GACnB,MAAM,wBAAwB,CAAC;AAEhC,OAAO,EAAE,YAAY,EAAE,MAAM,oBAAoB,CAAC;AAElD,YAAY,EAAE,gBAAgB,EAAE,eAAe,EAAE,MAAM,wBAAwB,CAAC;AAEhF,OAAO,EACH,UAAU,EACV,mBAAmB,EACnB,wBAAwB,EACxB,mBAAmB,EACnB,qBAAqB,EACrB,eAAe,EACf,cAAc,EACd,UAAU,EACV,YAAY,EACZ,IAAI,GACP,MAAM,sBAAsB,CAAC;AAE9B,OAAO,EACH,YAAY,EACZ,cAAc,EACd,WAAW,EACX,WAAW,EACX,YAAY,EACZ,cAAc,EACd,sBAAsB,EACtB,wBAAwB,EACxB,cAAc,GACjB,MAAM,yBAAyB,CAAC;AAEjC,OAAO,EAAE,2BAA2B,EAAE,4BAA4B,EAAE,MAC3D,6BAA6B,CAAC;AACvC,OAAO,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,MAC3C,qBAAqB,CAAC;AAC/B,OAAO,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,MAC3C,gCAAgC,CAAC;AAC1C,OAAO,EAAE,kBAAkB,EAAE,cAAc,EAAE,MACpC,4BAA4B,CAAC;AACtC,OAAO,EAAE,gBAAgB,EAAE,yBAAyB,EAAE,MAC7C,0BAA0B,CAAC;AAEpC,eAAO,MAAM,OAAO,UAAU,CAAC;AAC/B,eAAO,MAAM,WAAW,uEAAuE,CAAC"}
|
package/dist/index.js
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* MambaCode.js – Entry Point
|
|
3
|
+
*/
|
|
4
|
+
export { MambaModel } from './model/mamba_model.js';
|
|
5
|
+
export { MambaBlock } from './model/mamba_block.js';
|
|
6
|
+
export { MambaTrainer } from './training/trainer.js';
|
|
7
|
+
export { Tensor, backward, enableGrad, noGrad, clearTape, recordOperation, crossEntropyLoss, crossEntropyGrad, } from './training/autograd.js';
|
|
8
|
+
export { BPETokenizer } from './tokenizer/bpe.js';
|
|
9
|
+
export { initWebGPU, createStorageBuffer, createEmptyStorageBuffer, createUniformBuffer, createComputePipeline, createBindGroup, dispatchKernel, readBuffer, uploadBuffer, cdiv, } from './utils/gpu_utils.js';
|
|
10
|
+
export { quantizeFp16, dequantizeFp16, floatToFp16, fp16ToFloat, quantizeInt8, dequantizeInt8, quantizeInt8PerChannel, dequantizeInt8PerChannel, estimateMemory, } from './utils/quantization.js';
|
|
11
|
+
export { SELECTIVE_SCAN_FORWARD_WGSL, SELECTIVE_SCAN_BACKWARD_WGSL } from './kernels/selective_scan.js';
|
|
12
|
+
export { CONV1D_FORWARD_WGSL, CONV1D_BACKWARD_WGSL } from './kernels/conv1d.js';
|
|
13
|
+
export { LINEAR_FORWARD_WGSL, LINEAR_BACKWARD_WGSL } from './kernels/linear_projection.js';
|
|
14
|
+
export { WEIGHT_UPDATE_WGSL, GRAD_CLIP_WGSL } from './kernels/weight_update.js';
|
|
15
|
+
export { ACTIVATIONS_WGSL, ACTIVATIONS_BACKWARD_WGSL } from './kernels/activations.js';
|
|
16
|
+
export const VERSION = '1.0.2';
|
|
17
|
+
export const DESCRIPTION = 'MambaCode.js: WebGPU-accelerated Mamba SSM for browser code models';
|
|
18
|
+
//# sourceMappingURL=index.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"index.js","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA;;GAEG;AAEH,OAAO,EAAE,UAAU,EAAE,MAAQ,wBAAwB,CAAC;AACtD,OAAO,EAAE,UAAU,EAAE,MAAQ,wBAAwB,CAAC;AAEtD,OAAO,EAAE,YAAY,EAAE,MAAM,uBAAuB,CAAC;AACrD,OAAO,EACH,MAAM,EACN,QAAQ,EACR,UAAU,EACV,MAAM,EACN,SAAS,EACT,eAAe,EACf,gBAAgB,EAChB,gBAAgB,GACnB,MAAM,wBAAwB,CAAC;AAEhC,OAAO,EAAE,YAAY,EAAE,MAAM,oBAAoB,CAAC;AAIlD,OAAO,EACH,UAAU,EACV,mBAAmB,EACnB,wBAAwB,EACxB,mBAAmB,EACnB,qBAAqB,EACrB,eAAe,EACf,cAAc,EACd,UAAU,EACV,YAAY,EACZ,IAAI,GACP,MAAM,sBAAsB,CAAC;AAE9B,OAAO,EACH,YAAY,EACZ,cAAc,EACd,WAAW,EACX,WAAW,EACX,YAAY,EACZ,cAAc,EACd,sBAAsB,EACtB,wBAAwB,EACxB,cAAc,GACjB,MAAM,yBAAyB,CAAC;AAEjC,OAAO,EAAE,2BAA2B,EAAE,4BAA4B,EAAE,MAC3D,6BAA6B,CAAC;AACvC,OAAO,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,MAC3C,qBAAqB,CAAC;AAC/B,OAAO,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,MAC3C,gCAAgC,CAAC;AAC1C,OAAO,EAAE,kBAAkB,EAAE,cAAc,EAAE,MACpC,4BAA4B,CAAC;AACtC,OAAO,EAAE,gBAAgB,EAAE,yBAAyB,EAAE,MAC7C,0BAA0B,CAAC;AAEpC,MAAM,CAAC,MAAM,OAAO,GAAG,OAAO,CAAC;AAC/B,MAAM,CAAC,MAAM,WAAW,GAAG,oEAAoE,CAAC"}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"activations.d.ts","sourceRoot":"","sources":["../../src/kernels/activations.ts"],"names":[],"mappings":"AAGA,eAAO,MAAM,gBAAgB,EAAE,MAyD9B,CAAC;AAGF,eAAO,MAAM,yBAAyB,EAAE,MAwBvC,CAAC"}
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
// Activation function WGSL kernels: SiLU (Swish) and its backward pass.
|
|
2
|
+
// Used in the gating mechanism of the Mamba Mixer Block.
|
|
3
|
+
export const ACTIVATIONS_WGSL = /* wgsl */ `
|
|
4
|
+
|
|
5
|
+
struct ActParams {
|
|
6
|
+
num_elements : u32,
|
|
7
|
+
};
|
|
8
|
+
|
|
9
|
+
@group(0) @binding(0) var<uniform> p : ActParams;
|
|
10
|
+
@group(0) @binding(1) var<storage, read> x : array<f32>;
|
|
11
|
+
@group(0) @binding(2) var<storage, read_write> y : array<f32>;
|
|
12
|
+
|
|
13
|
+
// SiLU(x) = x * sigmoid(x)
|
|
14
|
+
@compute @workgroup_size(256, 1, 1)
|
|
15
|
+
fn silu_forward(
|
|
16
|
+
@builtin(global_invocation_id) gid : vec3<u32>,
|
|
17
|
+
) {
|
|
18
|
+
let i = gid.x;
|
|
19
|
+
if (i >= p.num_elements) { return; }
|
|
20
|
+
let v = x[i];
|
|
21
|
+
y[i] = v / (1.0 + exp(-v));
|
|
22
|
+
}
|
|
23
|
+
|
|
24
|
+
// RMSNorm forward: y = x / rms(x) * weight
|
|
25
|
+
// Requires separate uniform for rms norm params.
|
|
26
|
+
struct RMSNormParams {
|
|
27
|
+
num_rows : u32, // number of vectors (batch * seq_len)
|
|
28
|
+
dim : u32, // feature dimension
|
|
29
|
+
eps : f32,
|
|
30
|
+
};
|
|
31
|
+
|
|
32
|
+
@group(0) @binding(0) var<uniform> rms_p : RMSNormParams;
|
|
33
|
+
@group(0) @binding(1) var<storage, read> rms_x : array<f32>;
|
|
34
|
+
@group(0) @binding(2) var<storage, read> rms_w : array<f32>; // scale (dim,)
|
|
35
|
+
@group(0) @binding(3) var<storage, read_write> rms_y : array<f32>;
|
|
36
|
+
@group(0) @binding(4) var<storage, read_write> rms_inv : array<f32>; // cache 1/rms per row
|
|
37
|
+
|
|
38
|
+
@compute @workgroup_size(64, 1, 1)
|
|
39
|
+
fn rmsnorm_forward(
|
|
40
|
+
@builtin(global_invocation_id) gid : vec3<u32>,
|
|
41
|
+
) {
|
|
42
|
+
let row = gid.x;
|
|
43
|
+
if (row >= rms_p.num_rows) { return; }
|
|
44
|
+
|
|
45
|
+
let D = rms_p.dim;
|
|
46
|
+
let base = row * D;
|
|
47
|
+
|
|
48
|
+
var sq_sum: f32 = 0.0;
|
|
49
|
+
for (var i: u32 = 0u; i < D; i = i + 1u) {
|
|
50
|
+
let v = rms_x[base + i];
|
|
51
|
+
sq_sum = sq_sum + v * v;
|
|
52
|
+
}
|
|
53
|
+
let inv_rms = 1.0 / sqrt(sq_sum / f32(D) + rms_p.eps);
|
|
54
|
+
rms_inv[row] = inv_rms;
|
|
55
|
+
|
|
56
|
+
for (var i: u32 = 0u; i < D; i = i + 1u) {
|
|
57
|
+
rms_y[base + i] = rms_x[base + i] * inv_rms * rms_w[i];
|
|
58
|
+
}
|
|
59
|
+
}
|
|
60
|
+
`;
|
|
61
|
+
// ---- Backward for SiLU ----
|
|
62
|
+
export const ACTIVATIONS_BACKWARD_WGSL = /* wgsl */ `
|
|
63
|
+
|
|
64
|
+
struct ActParams {
|
|
65
|
+
num_elements : u32,
|
|
66
|
+
};
|
|
67
|
+
|
|
68
|
+
@group(0) @binding(0) var<uniform> p : ActParams;
|
|
69
|
+
@group(0) @binding(1) var<storage, read> x : array<f32>;
|
|
70
|
+
@group(0) @binding(2) var<storage, read> dy : array<f32>;
|
|
71
|
+
@group(0) @binding(3) var<storage, read_write> dx : array<f32>;
|
|
72
|
+
|
|
73
|
+
// d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
|
|
74
|
+
// = silu(x)/x + sigmoid(x) * (1 - sigmoid(x)) * x
|
|
75
|
+
// simplified: sigmoid(x) * (1 + x*(1 - sigmoid(x)))
|
|
76
|
+
@compute @workgroup_size(256, 1, 1)
|
|
77
|
+
fn silu_backward(
|
|
78
|
+
@builtin(global_invocation_id) gid : vec3<u32>,
|
|
79
|
+
) {
|
|
80
|
+
let i = gid.x;
|
|
81
|
+
if (i >= p.num_elements) { return; }
|
|
82
|
+
let v = x[i];
|
|
83
|
+
let sig = 1.0 / (1.0 + exp(-v));
|
|
84
|
+
dx[i] = dy[i] * sig * (1.0 + v * (1.0 - sig));
|
|
85
|
+
}
|
|
86
|
+
`;
|
|
87
|
+
//# sourceMappingURL=activations.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"activations.js","sourceRoot":"","sources":["../../src/kernels/activations.ts"],"names":[],"mappings":"AAAA,wEAAwE;AACxE,yDAAyD;AAEzD,MAAM,CAAC,MAAM,gBAAgB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAyDjD,CAAC;AAEF,8BAA8B;AAC9B,MAAM,CAAC,MAAM,yBAAyB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;CAwB1D,CAAC"}
|
|
@@ -0,0 +1,3 @@
|
|
|
1
|
+
export declare const CONV1D_FORWARD_WGSL = "\n\nstruct ConvParams {\n seq_len : u32, // L\n d_channels : u32, // D (number of depthwise channels)\n kernel_size : u32, // K (typically 4)\n batch : u32, // B\n};\n\n@group(0) @binding(0) var<uniform> params : ConvParams;\n// x (B, L, D) \u2013 input\n@group(0) @binding(1) var<storage, read> x : array<f32>;\n// weight (D, K) \u2013 depthwise conv weights\n@group(0) @binding(2) var<storage, read> weight : array<f32>;\n// bias (D,) \u2013 optional bias (zeros if unused)\n@group(0) @binding(3) var<storage, read> bias : array<f32>;\n// y (B, L, D) \u2013 output\n@group(0) @binding(4) var<storage, read_write> y : array<f32>;\n\n// Dispatch: (ceil(L/16), ceil(D/16), B)\n@compute @workgroup_size(16, 16, 1)\nfn conv1d_forward(\n @builtin(global_invocation_id) gid : vec3<u32>,\n) {\n let L = params.seq_len;\n let D = params.d_channels;\n let K = params.kernel_size;\n let B = params.batch;\n\n let t = gid.x; // time position\n let d = gid.y; // channel\n let b = gid.z; // batch\n\n if (t >= L || d >= D || b >= B) { return; }\n\n var acc: f32 = 0.0;\n\n // Causal: convolve over k = 0..K-1, reading position (t - k)\n for (var k: u32 = 0u; k < K; k = k + 1u) {\n let w_idx = d * K + k;\n let w_val = weight[w_idx];\n\n // t - k: use causal zero-padding for t < k\n if (t >= k) {\n let src = b * L * D + (t - k) * D + d;\n acc = acc + w_val * x[src];\n }\n // else: zero-padding contributes 0\n }\n\n acc = acc + bias[d];\n\n let out = b * L * D + t * D + d;\n y[out] = acc;\n}\n";
|
|
2
|
+
export declare const CONV1D_BACKWARD_WGSL = "\n\nstruct ConvParams {\n seq_len : u32,\n d_channels : u32,\n kernel_size : u32,\n batch : u32,\n};\n\n@group(0) @binding(0) var<uniform> params : ConvParams;\n@group(0) @binding(1) var<storage, read> x : array<f32>;\n@group(0) @binding(2) var<storage, read> weight : array<f32>;\n@group(0) @binding(3) var<storage, read> dy : array<f32>;\n@group(0) @binding(4) var<storage, read_write> dx : array<f32>;\n@group(0) @binding(5) var<storage, read_write> dweight : array<f32>;\n@group(0) @binding(6) var<storage, read_write> dbias : array<f32>;\n\n// Dispatch: (ceil(L/16), ceil(D/16), B) \u2013 computes dx\n@compute @workgroup_size(16, 16, 1)\nfn conv1d_backward_dx(\n @builtin(global_invocation_id) gid : vec3<u32>,\n) {\n let L = params.seq_len;\n let D = params.d_channels;\n let K = params.kernel_size;\n let B = params.batch;\n\n let t = gid.x;\n let d = gid.y;\n let b = gid.z;\n\n if (t >= L || d >= D || b >= B) { return; }\n\n var grad: f32 = 0.0;\n\n // dx[b, t, d] = sum_{k=0}^{K-1} dy[b, t+k, d] * weight[d, k]\n for (var k: u32 = 0u; k < K; k = k + 1u) {\n let tp = t + k;\n if (tp < L) {\n let dy_idx = b * L * D + tp * D + d;\n let w_idx = d * K + k;\n grad = grad + dy[dy_idx] * weight[w_idx];\n }\n }\n\n let dx_idx = b * L * D + t * D + d;\n dx[dx_idx] = grad;\n}\n\n// Dispatch: (K, D, 1) \u2013 accumulates dweight over (B, L)\n@compute @workgroup_size(1, 1, 1)\nfn conv1d_backward_dw(\n @builtin(global_invocation_id) gid : vec3<u32>,\n) {\n let L = params.seq_len;\n let D = params.d_channels;\n let K = params.kernel_size;\n let B = params.batch;\n\n let k = gid.x;\n let d = gid.y;\n\n if (k >= K || d >= D) { return; }\n\n var grad_w: f32 = 0.0;\n var grad_b: f32 = 0.0;\n\n for (var b: u32 = 0u; b < B; b = b + 1u) {\n for (var t: u32 = 0u; t < L; t = t + 1u) {\n let dy_idx = b * L * D + t * D + d;\n let dy_val = dy[dy_idx];\n if (t >= k) {\n let x_idx = b * L * D + (t - k) * D + d;\n grad_w = grad_w + dy_val * x[x_idx];\n 
}\n if (k == 0u) {\n grad_b = grad_b + dy_val;\n }\n }\n }\n\n dweight[d * K + k] = grad_w;\n if (k == 0u) {\n dbias[d] = grad_b;\n }\n}\n";
|
|
3
|
+
//# sourceMappingURL=conv1d.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"conv1d.d.ts","sourceRoot":"","sources":["../../src/kernels/conv1d.ts"],"names":[],"mappings":"AAQA,eAAO,MAAM,mBAAmB,urDAuD/B,CAAC;AAGF,eAAO,MAAM,oBAAoB,y6EAsFhC,CAAC"}
|