mambacode.js 1.0.0 → 1.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. package/README.md +198 -76
  2. package/dist/index.d.ts +19 -0
  3. package/dist/index.d.ts.map +1 -0
  4. package/dist/index.js +18 -0
  5. package/dist/index.js.map +1 -0
  6. package/dist/kernels/activations.d.ts +3 -0
  7. package/dist/kernels/activations.d.ts.map +1 -0
  8. package/dist/kernels/activations.js +87 -0
  9. package/dist/kernels/activations.js.map +1 -0
  10. package/dist/kernels/conv1d.d.ts +3 -0
  11. package/dist/kernels/conv1d.d.ts.map +1 -0
  12. package/dist/kernels/conv1d.js +152 -0
  13. package/dist/kernels/conv1d.js.map +1 -0
  14. package/dist/kernels/linear_projection.d.ts +3 -0
  15. package/dist/kernels/linear_projection.d.ts.map +1 -0
  16. package/dist/kernels/linear_projection.js +219 -0
  17. package/dist/kernels/linear_projection.js.map +1 -0
  18. package/dist/kernels/selective_scan.d.ts +3 -0
  19. package/dist/kernels/selective_scan.d.ts.map +1 -0
  20. package/dist/kernels/selective_scan.js +348 -0
  21. package/dist/kernels/selective_scan.js.map +1 -0
  22. package/dist/kernels/weight_update.d.ts +3 -0
  23. package/dist/kernels/weight_update.d.ts.map +1 -0
  24. package/dist/kernels/weight_update.js +119 -0
  25. package/dist/kernels/weight_update.js.map +1 -0
  26. package/dist/model/mamba_block.d.ts +64 -0
  27. package/dist/model/mamba_block.d.ts.map +1 -0
  28. package/dist/model/mamba_block.js +309 -0
  29. package/dist/model/mamba_block.js.map +1 -0
  30. package/dist/model/mamba_model.d.ts +66 -0
  31. package/dist/model/mamba_model.d.ts.map +1 -0
  32. package/dist/model/mamba_model.js +289 -0
  33. package/dist/model/mamba_model.js.map +1 -0
  34. package/dist/tokenizer/bpe.d.ts +29 -0
  35. package/dist/tokenizer/bpe.d.ts.map +1 -0
  36. package/dist/tokenizer/bpe.js +164 -0
  37. package/dist/tokenizer/bpe.js.map +1 -0
  38. package/dist/training/autograd.d.ts +27 -0
  39. package/dist/training/autograd.d.ts.map +1 -0
  40. package/dist/training/autograd.js +120 -0
  41. package/dist/training/autograd.js.map +1 -0
  42. package/dist/training/trainer.d.ts +37 -0
  43. package/dist/training/trainer.d.ts.map +1 -0
  44. package/dist/training/trainer.js +183 -0
  45. package/dist/training/trainer.js.map +1 -0
  46. package/dist/utils/gpu_utils.d.ts +21 -0
  47. package/dist/utils/gpu_utils.d.ts.map +1 -0
  48. package/dist/utils/gpu_utils.js +111 -0
  49. package/dist/utils/gpu_utils.js.map +1 -0
  50. package/dist/utils/quantization.d.ts +26 -0
  51. package/dist/utils/quantization.d.ts.map +1 -0
  52. package/dist/utils/quantization.js +116 -0
  53. package/dist/utils/quantization.js.map +1 -0
  54. package/package.json +43 -18
  55. package/src/index.ts +61 -0
  56. package/src/kernels/{activations.js → activations.ts} +2 -2
  57. package/src/kernels/{linear_projection.js → linear_projection.ts} +2 -2
  58. package/src/kernels/{selective_scan.js → selective_scan.ts} +2 -2
  59. package/src/kernels/{weight_update.js → weight_update.ts} +2 -2
  60. package/src/model/{mamba_block.js → mamba_block.ts} +134 -170
  61. package/src/model/{mamba_model.js → mamba_model.ts} +165 -121
  62. package/src/tokenizer/bpe.ts +186 -0
  63. package/src/training/autograd.ts +135 -0
  64. package/src/training/{trainer.js → trainer.ts} +79 -161
  65. package/src/utils/gpu_utils.ts +147 -0
  66. package/src/utils/quantization.ts +154 -0
  67. package/src/index.js +0 -89
  68. package/src/tokenizer/bpe.js +0 -256
  69. package/src/training/autograd.js +0 -221
  70. package/src/utils/gpu_utils.js +0 -217
  71. package/src/utils/quantization.js +0 -215
  72. /package/src/kernels/{conv1d.js → conv1d.ts} +0 -0
package/README.md CHANGED
@@ -1,14 +1,22 @@
1
- # Mamba
2
- MambaCode.js — WebGPU-accelerated Mamba SSM library for browser-based code model training and inference.
1
+ # MambaCode.js
3
2
 
4
- ## Overview
3
+ > WebGPU-accelerated Mamba State Space Model library — written in **TypeScript**, compiled for use in any JavaScript application.
5
4
 
6
- MambaCode.js is a pure JavaScript/WGSL implementation of the **Mamba State Space Model (SSM)** architecture, optimised for on-device code model training and inference in the browser. It targets the Qwen3.5-Coder-0.8B logic and supports full **on-device training** (backpropagation) via WebGPU, allowing models to adapt to a user's private codebase locally — without any data leaving the browser.
5
+ [![npm](https://img.shields.io/npm/v/mambacode.js)](https://www.npmjs.com/package/mambacode.js)
6
+ [![license](https://img.shields.io/badge/license-MIT-blue)](./LICENSE)
7
7
 
8
- ### Key features
8
+ MambaCode.js is a **TypeScript-first** library that brings the [Mamba SSM](https://arxiv.org/abs/2312.00752) architecture to the browser via WebGPU. It targets the Qwen3.5-Coder-0.8B model shape and supports full **on-device training** (backpropagation), allowing models to adapt to a user's private codebase locally — without any data leaving the browser.
9
+
10
+ > 📖 **New to MambaCode.js?** Start with the [Getting Started Guide](./docs/getting-started.md).
11
+
12
+ ---
13
+
14
+ ## Key Features
9
15
 
10
16
  | Feature | Detail |
11
17
  |---|---|
18
+ | **TypeScript-first** | Full type declarations shipped with the package |
19
+ | **Plain JS compatible** | Import the compiled `dist/` in any JavaScript project — no TypeScript toolchain required |
12
20
  | **Architecture** | Selective State Space Model (S6) — linear O(N) context scaling |
13
21
  | **Hardware target** | WebGPU (WGSL) — Chrome 113+, Edge 113+, Firefox Nightly |
14
22
  | **Memory ceiling** | ≤ 3 GB VRAM (Chrome/Edge/Firefox stable) |
@@ -20,91 +28,138 @@ MambaCode.js is a pure JavaScript/WGSL implementation of the **Mamba State Space
20
28
 
21
29
  ---
22
30
 
23
- ## Architecture
31
+ ## Installation
24
32
 
33
+ ```bash
34
+ npm install mambacode.js
25
35
  ```
26
- Token IDs
27
-
28
-
29
- Embedding Lookup (GPU gather kernel)
30
-
31
- ▼ ┌─────────────────────────────────────────┐
32
- │ Mamba Block × N │
33
- │ │
34
- │ Input ──► RMSNorm │
35
- │ │ │
36
- │ ┌────────┴────────┐ │
37
- │ ▼ ▼ │
38
- │ in_proj(x) in_proj(z) [gate] │
39
- │ │ │
40
- │ Conv1D (causal, K=4) │
41
- │ │ │
42
- │ SiLU activation │
43
- │ │ │
44
- │ x_proj → Δ, B, C (selective) │
45
- │ │ │
46
- │ Δ → dt_proj (full D_inner width) │
47
- │ │ │
48
- │ ┌───▼──────────────────────────────┐ │
49
- │ │ Selective Scan S6 │ │
50
- │ │ (Kogge-Stone parallel prefix) │ │
51
- │ │ h_t = Ā·h_{t-1} + B̄·x_t │ │
52
- │ │ y_t = C·h_t + D·x_t │ │
53
- │ └──────────────────────────────────┘ │
54
- │ │ │
55
- │ Gate: y * SiLU(z) │
56
- │ │ │
57
- │ out_proj → residual add ──► output │
58
- └─────────────────────────────────────────┘
59
-
60
-
61
- Final RMSNorm → LM Head (tied embedding) → Logits
36
+
37
+ Build the library from source:
38
+
39
+ ```bash
40
+ npm run build # compiles TypeScript → dist/
62
41
  ```
63
42
 
64
43
  ---
65
44
 
45
+ ## Documentation
46
+
47
+ | Guide | Description |
48
+ |---|---|
49
+ | **[Getting Started](docs/getting-started.md)** | Beginner-friendly introduction — what LLMs are, how Qwen fits in, the full model lifecycle, and what to do next |
50
+ | **[Integration & Architecture](docs/integration-architecture.md)** | Production architecture guide — embedding Mamba as a unified brain + memory system, integration patterns, advanced use cases, and design tradeoffs |
51
+ | **[Weight Lifecycle](docs/weight-lifecycle.md)** | Complete guide to obtaining Qwen vocabulary files, loading pre-trained checkpoints, fine-tuning, exporting weights, and sharing with your team |
52
+ | **[API Reference](docs/api-reference.md)** | Full technical reference — every exported class, interface, and function with TypeScript and JavaScript examples |
53
+ | **[MambaKit PRD](docs/mamba-kit-prd.md)** | Product requirements document for MambaKit — an opinionated, zero-boilerplate facade over MambaCode.js |
54
+
55
+ ---
56
+
66
57
  ## Quick Start
67
58
 
68
- ```js
69
- import { MambaModel, MambaTrainer, BPETokenizer, initWebGPU } from './src/index.js';
59
+ ### TypeScript
60
+
61
+ ```ts
62
+ import {
63
+ MambaModel,
64
+ MambaTrainer,
65
+ BPETokenizer,
66
+ initWebGPU,
67
+ type MambaModelConfig,
68
+ type TrainOptions,
69
+ } from 'mambacode.js';
70
70
 
71
71
  // 1. Initialise WebGPU
72
72
  const { device } = await initWebGPU();
73
73
 
74
- // 2. Load tokenizer
74
+ // 2. Load tokenizer (vocab.json + merges.txt from Qwen3.5-Coder)
75
75
  const tokenizer = new BPETokenizer();
76
76
  await tokenizer.load('/vocab.json', '/merges.txt');
77
77
 
78
78
  // 3. Create model
79
- const model = new MambaModel(device, {
80
- vocabSize : tokenizer.vocabSize, // e.g. 151936 for Qwen3.5-Coder
79
+ const config: MambaModelConfig = {
80
+ vocabSize : tokenizer.vocabSize, // 151936 for Qwen3.5-Coder
81
81
  dModel : 512,
82
82
  numLayers : 8,
83
83
  dState : 16,
84
84
  dConv : 4,
85
85
  expand : 2,
86
+ };
87
+ const model = new MambaModel(device, config);
88
+
89
+ // 4. Load a pre-trained checkpoint
90
+ const response = await fetch('/models/mamba-coder-checkpoint.bin');
91
+ await model.loadWeights(await response.arrayBuffer());
92
+
93
+ // 5. Fine-tune on local code
94
+ const trainer = new MambaTrainer(model, tokenizer);
95
+ const opts: TrainOptions = {
96
+ learningRate : 1e-4,
97
+ epochs : 5,
98
+ onEpochEnd : (epoch, loss) => console.log(`Epoch ${epoch}: loss=${loss.toFixed(4)}`),
99
+ };
100
+ const losses = await trainer.train(myCodeString, opts);
101
+
102
+ // 6. Generate code
103
+ const promptIds = tokenizer.encode('function fibonacci(');
104
+ const outputIds = await model.generate(promptIds, 200, { temperature: 0.8 });
105
+ console.log(tokenizer.decode(outputIds));
106
+
107
+ // 7. Save fine-tuned weights for next session
108
+ const checkpoint = await model.exportWeights();
109
+ ```
110
+
111
+ ### JavaScript (ESM)
112
+
113
+ The compiled output in `dist/` is plain JavaScript with no TypeScript runtime dependency:
114
+
115
+ ```js
116
+ import {
117
+ MambaModel,
118
+ MambaTrainer,
119
+ BPETokenizer,
120
+ initWebGPU,
121
+ } from 'mambacode.js';
122
+
123
+ // 1. Initialise WebGPU
124
+ const { device } = await initWebGPU();
125
+
126
+ // 2. Load tokenizer (vocab.json + merges.txt from Qwen3.5-Coder)
127
+ const tokenizer = new BPETokenizer();
128
+ await tokenizer.load('/vocab.json', '/merges.txt');
129
+
130
+ // 3. Create model
131
+ const model = new MambaModel(device, {
132
+ vocabSize : tokenizer.vocabSize,
133
+ dModel : 512,
134
+ numLayers : 8,
86
135
  });
87
136
 
88
- // 4. Train on local code
137
+ // 4. Load a pre-trained checkpoint
138
+ const response = await fetch('/models/mamba-coder-checkpoint.bin');
139
+ await model.loadWeights(await response.arrayBuffer());
140
+
141
+ // 5. Fine-tune on local code
89
142
  const trainer = new MambaTrainer(model, tokenizer);
90
- const losses = await trainer.train(myCodeString, {
143
+ const losses = await trainer.train(myCodeString, {
91
144
  learningRate : 1e-4,
92
145
  epochs : 5,
93
- device : 'webgpu',
94
146
  onEpochEnd : (epoch, loss) => console.log(`Epoch ${epoch}: loss=${loss.toFixed(4)}`),
95
147
  });
96
148
 
97
- // 5. Generate code
149
+ // 6. Generate code
98
150
  const promptIds = tokenizer.encode('function fibonacci(');
99
151
  const outputIds = await model.generate(promptIds, 200, { temperature: 0.8 });
100
152
  console.log(tokenizer.decode(outputIds));
153
+
154
+ // 7. Save fine-tuned weights for next session
155
+ const checkpoint = await model.exportWeights();
101
156
  ```
102
157
 
103
158
  ### WSLA (Weight-Selective Local Adaptation)
104
159
 
105
160
  Fine-tune only the B and C matrices for rapid private-codebase adaptation:
106
161
 
107
- ```js
162
+ ```ts
108
163
  await trainer.train(apiUsageExamples, {
109
164
  learningRate : 1e-4,
110
165
  epochs : 3,
@@ -114,40 +169,95 @@ await trainer.train(apiUsageExamples, {
114
169
 
115
170
  ---
116
171
 
172
+ ## Architecture
173
+
174
+ ```
175
+ Token IDs
176
+
177
+
178
+ Embedding Lookup (GPU gather kernel)
179
+
180
+ ▼ ┌─────────────────────────────────────────┐
181
+ │ Mamba Block × N │
182
+ │ │
183
+ │ Input ──► RMSNorm │
184
+ │ │ │
185
+ │ ┌────────┴────────┐ │
186
+ │ ▼ ▼ │
187
+ │ in_proj(x) in_proj(z) [gate] │
188
+ │ │ │
189
+ │ Conv1D (causal, K=4) │
190
+ │ │ │
191
+ │ SiLU activation │
192
+ │ │ │
193
+ │ x_proj → Δ, B, C (selective) │
194
+ │ │ │
195
+ │ Δ → dt_proj (full D_inner width) │
196
+ │ │ │
197
+ │ ┌───▼──────────────────────────────┐ │
198
+ │ │ Selective Scan S6 │ │
199
+ │ │ (Kogge-Stone parallel prefix) │ │
200
+ │ │ h_t = Ā·h_{t-1} + B̄·x_t │ │
201
+ │ │ y_t = C·h_t + D·x_t │ │
202
+ │ └──────────────────────────────────┘ │
203
+ │ │ │
204
+ │ Gate: y * SiLU(z) │
205
+ │ │ │
206
+ │ out_proj → residual add ──► output │
207
+ └─────────────────────────────────────────┘
208
+
209
+
210
+ Final RMSNorm → LM Head (tied embedding) → Logits
211
+ ```
212
+
213
+ ---
214
+
117
215
  ## File Structure
118
216
 
119
217
  ```
120
- src/
121
- ├── index.js ← public API entry point
218
+ src/ ← TypeScript source (edit here)
219
+ ├── index.ts ← public API entry point
122
220
  ├── kernels/
123
- │ ├── selective_scan.js ← WGSL: S6 forward + backward (Kogge-Stone)
124
- │ ├── conv1d.js ← WGSL: 1D causal convolution
125
- │ ├── linear_projection.js ← WGSL: tiled matrix multiplication
126
- │ ├── weight_update.js ← WGSL: AdamW optimizer + gradient clipping
127
- │ └── activations.js ← WGSL: SiLU, RMSNorm
221
+ │ ├── selective_scan.ts ← WGSL: S6 forward + backward (Kogge-Stone)
222
+ │ ├── conv1d.ts ← WGSL: 1D causal convolution
223
+ │ ├── linear_projection.ts ← WGSL: tiled matrix multiplication
224
+ │ ├── weight_update.ts ← WGSL: AdamW optimizer + gradient clipping
225
+ │ └── activations.ts ← WGSL: SiLU, RMSNorm
128
226
  ├── model/
129
- │ ├── mamba_block.js ← Mamba Mixer Block (forward pass)
130
- │ └── mamba_model.js ← Full stacked model + generation
227
+ │ ├── mamba_block.ts ← Mamba Mixer Block (forward pass)
228
+ │ └── mamba_model.ts ← Full stacked model + generation
131
229
  ├── training/
132
- │ ├── autograd.js ← Tape-based AD engine + loss helpers
133
- │ └── trainer.js ← MambaTrainer class
230
+ │ ├── autograd.ts ← Tape-based AD engine + loss helpers
231
+ │ └── trainer.ts ← MambaTrainer class
134
232
  ├── tokenizer/
135
- │ └── bpe.js ← Browser-side BPE tokenizer
233
+ │ └── bpe.ts ← Browser-side BPE tokenizer
136
234
  └── utils/
137
- ├── gpu_utils.js ← WebGPU device/buffer management
138
- └── quantization.js ← FP16 / Int8 quantization utilities
235
+ ├── gpu_utils.ts ← WebGPU device/buffer management
236
+ └── quantization.ts ← FP16 / Int8 quantization utilities
237
+
238
+ dist/ ← Compiled output (JS + .d.ts, gitignored)
239
+ ├── index.js ← ESM entry point for JS consumers
240
+ ├── index.d.ts ← TypeScript declarations for TS consumers
241
+ └── ... ← mirrored sub-folders
242
+
139
243
  tests/
140
- ├── kernels.test.js ← WGSL kernel source smoke tests
141
- ├── autograd.test.js ← Autograd engine unit tests
142
- ├── bpe.test.js ← BPE tokenizer unit tests
143
- └── quantization.test.js ← Quantization round-trip tests
244
+ ├── kernels.test.ts ← WGSL kernel source smoke tests
245
+ ├── autograd.test.ts ← Autograd engine unit tests
246
+ ├── bpe.test.ts ← BPE tokenizer unit tests
247
+ └── quantization.test.ts ← Quantization round-trip tests
248
+
249
+ docs/
250
+ ├── getting-started.md ← Step-by-step guide (TS & JS)
251
+ ├── integration-architecture.md ← Brain + Memory architecture guide
252
+ ├── weight-lifecycle.md ← Weight loading, fine-tuning, export
253
+ └── api-reference.md ← Full API reference
144
254
  ```
145
255
 
146
256
  ---
147
257
 
148
258
  ## WGSL Kernels
149
259
 
150
- ### Parallel Selective Scan (`selective_scan.js`)
260
+ ### Parallel Selective Scan (`selective_scan.ts`)
151
261
  Implements the S6 core using a **Kogge-Stone parallel prefix-sum** inside each workgroup tile. Each tile of 64 time steps is scanned in log₂(64) = 6 GPU barrier rounds, giving O(log N) wall-clock time on the GPU.
152
262
 
153
263
  The associative operator for the recurrence `h_t = Ā·h_{t-1} + B̄·x_t` is:
@@ -158,13 +268,13 @@ The associative operator for the recurrence `h_t = Ā·h_{t-1} + B̄·x_t` is:
158
268
 
159
269
  Tiles are chained via a carry-in state, covering arbitrarily long sequences.
160
270
 
161
- ### 1D Causal Convolution (`conv1d.js`)
271
+ ### 1D Causal Convolution (`conv1d.ts`)
162
272
  Depthwise 1D causal conv (kernel size K=4) with zero left-padding. Enforces causality by only reading positions `t-k` for `k ≥ 0`, contributing 0 for `t < k`.
163
273
 
164
- ### Linear Projection (`linear_projection.js`)
274
+ ### Linear Projection (`linear_projection.ts`)
165
275
  Tiled 16×16 GEMM in WGSL using workgroup shared memory. Handles arbitrary (M, K) × (N, K) → (M, N) shapes with boundary guards.
166
276
 
167
- ### AdamW Optimizer (`weight_update.js`)
277
+ ### AdamW Optimizer (`weight_update.ts`)
168
278
  Fused single-kernel AdamW update with decoupled weight decay. Includes a two-pass gradient norm clipping kernel (reduce → scale).
169
279
 
170
280
  ---
@@ -172,10 +282,12 @@ Fused single-kernel AdamW update with decoupled weight decay. Includes a two-pas
172
282
  ## Testing
173
283
 
174
284
  ```bash
175
- npm test
285
+ npm test # run 58 unit tests (no GPU required)
286
+ npm run build # compile TypeScript → dist/
287
+ npm run lint # ESLint on src/ and tests/
176
288
  ```
177
289
 
178
- Runs 58 unit tests covering quantization, BPE tokenization, autograd, and WGSL kernel source validation. GPU execution tests require a real browser with WebGPU support.
290
+ Unit tests cover quantization, BPE tokenization, autograd, and WGSL kernel source validation. GPU execution tests require a real browser with WebGPU support.
179
291
 
180
292
  ---
181
293
 
@@ -191,6 +303,16 @@ Runs 58 unit tests covering quantization, BPE tokenization, autograd, and WGSL k
191
303
 
192
304
  ---
193
305
 
306
+ ## Acknowledgements
307
+
308
+ This library builds on the Mamba selective state space model research. Special credit to:
309
+
310
+ - **Mamba 3** — Tri Dao's blog post [*Mamba 3, Part 1*](https://tridao.me/blog/2026/mamba3-part1/) (2026), which describes the latest architectural refinements.
311
+ - **Mamba 3 paper** — [*Mamba: The Hard Way* (arXiv 2603.15569)](https://arxiv.org/abs/2603.15569), the accompanying technical paper.
312
+ - Original **Mamba SSM** paper — [*Mamba: Linear-Time Sequence Modeling with Selective State Spaces* (arXiv 2312.00752)](https://arxiv.org/abs/2312.00752) by Gu & Dao (2023).
313
+
314
+ ---
315
+
194
316
  ## License
195
317
 
196
318
  MIT
@@ -0,0 +1,19 @@
1
+ /**
2
+ * MambaCode.js – Entry Point
3
+ */
4
+ export { MambaModel } from './model/mamba_model.js';
5
+ export { MambaBlock } from './model/mamba_block.js';
6
+ export { MambaTrainer } from './training/trainer.js';
7
+ export { Tensor, backward, enableGrad, noGrad, clearTape, recordOperation, crossEntropyLoss, crossEntropyGrad, } from './training/autograd.js';
8
+ export { BPETokenizer } from './tokenizer/bpe.js';
9
+ export type { MambaModelConfig, SamplingOptions } from './model/mamba_model.js';
10
+ export { initWebGPU, createStorageBuffer, createEmptyStorageBuffer, createUniformBuffer, createComputePipeline, createBindGroup, dispatchKernel, readBuffer, uploadBuffer, cdiv, } from './utils/gpu_utils.js';
11
+ export { quantizeFp16, dequantizeFp16, floatToFp16, fp16ToFloat, quantizeInt8, dequantizeInt8, quantizeInt8PerChannel, dequantizeInt8PerChannel, estimateMemory, } from './utils/quantization.js';
12
+ export { SELECTIVE_SCAN_FORWARD_WGSL, SELECTIVE_SCAN_BACKWARD_WGSL } from './kernels/selective_scan.js';
13
+ export { CONV1D_FORWARD_WGSL, CONV1D_BACKWARD_WGSL } from './kernels/conv1d.js';
14
+ export { LINEAR_FORWARD_WGSL, LINEAR_BACKWARD_WGSL } from './kernels/linear_projection.js';
15
+ export { WEIGHT_UPDATE_WGSL, GRAD_CLIP_WGSL } from './kernels/weight_update.js';
16
+ export { ACTIVATIONS_WGSL, ACTIVATIONS_BACKWARD_WGSL } from './kernels/activations.js';
17
+ export declare const VERSION = "1.0.2";
18
+ export declare const DESCRIPTION = "MambaCode.js: WebGPU-accelerated Mamba SSM for browser code models";
19
+ //# sourceMappingURL=index.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA;;GAEG;AAEH,OAAO,EAAE,UAAU,EAAE,MAAQ,wBAAwB,CAAC;AACtD,OAAO,EAAE,UAAU,EAAE,MAAQ,wBAAwB,CAAC;AAEtD,OAAO,EAAE,YAAY,EAAE,MAAM,uBAAuB,CAAC;AACrD,OAAO,EACH,MAAM,EACN,QAAQ,EACR,UAAU,EACV,MAAM,EACN,SAAS,EACT,eAAe,EACf,gBAAgB,EAChB,gBAAgB,GACnB,MAAM,wBAAwB,CAAC;AAEhC,OAAO,EAAE,YAAY,EAAE,MAAM,oBAAoB,CAAC;AAElD,YAAY,EAAE,gBAAgB,EAAE,eAAe,EAAE,MAAM,wBAAwB,CAAC;AAEhF,OAAO,EACH,UAAU,EACV,mBAAmB,EACnB,wBAAwB,EACxB,mBAAmB,EACnB,qBAAqB,EACrB,eAAe,EACf,cAAc,EACd,UAAU,EACV,YAAY,EACZ,IAAI,GACP,MAAM,sBAAsB,CAAC;AAE9B,OAAO,EACH,YAAY,EACZ,cAAc,EACd,WAAW,EACX,WAAW,EACX,YAAY,EACZ,cAAc,EACd,sBAAsB,EACtB,wBAAwB,EACxB,cAAc,GACjB,MAAM,yBAAyB,CAAC;AAEjC,OAAO,EAAE,2BAA2B,EAAE,4BAA4B,EAAE,MAC3D,6BAA6B,CAAC;AACvC,OAAO,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,MAC3C,qBAAqB,CAAC;AAC/B,OAAO,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,MAC3C,gCAAgC,CAAC;AAC1C,OAAO,EAAE,kBAAkB,EAAE,cAAc,EAAE,MACpC,4BAA4B,CAAC;AACtC,OAAO,EAAE,gBAAgB,EAAE,yBAAyB,EAAE,MAC7C,0BAA0B,CAAC;AAEpC,eAAO,MAAM,OAAO,UAAU,CAAC;AAC/B,eAAO,MAAM,WAAW,uEAAuE,CAAC"}
package/dist/index.js ADDED
@@ -0,0 +1,18 @@
1
+ /**
2
+ * MambaCode.js – Entry Point
3
+ */
4
+ export { MambaModel } from './model/mamba_model.js';
5
+ export { MambaBlock } from './model/mamba_block.js';
6
+ export { MambaTrainer } from './training/trainer.js';
7
+ export { Tensor, backward, enableGrad, noGrad, clearTape, recordOperation, crossEntropyLoss, crossEntropyGrad, } from './training/autograd.js';
8
+ export { BPETokenizer } from './tokenizer/bpe.js';
9
+ export { initWebGPU, createStorageBuffer, createEmptyStorageBuffer, createUniformBuffer, createComputePipeline, createBindGroup, dispatchKernel, readBuffer, uploadBuffer, cdiv, } from './utils/gpu_utils.js';
10
+ export { quantizeFp16, dequantizeFp16, floatToFp16, fp16ToFloat, quantizeInt8, dequantizeInt8, quantizeInt8PerChannel, dequantizeInt8PerChannel, estimateMemory, } from './utils/quantization.js';
11
+ export { SELECTIVE_SCAN_FORWARD_WGSL, SELECTIVE_SCAN_BACKWARD_WGSL } from './kernels/selective_scan.js';
12
+ export { CONV1D_FORWARD_WGSL, CONV1D_BACKWARD_WGSL } from './kernels/conv1d.js';
13
+ export { LINEAR_FORWARD_WGSL, LINEAR_BACKWARD_WGSL } from './kernels/linear_projection.js';
14
+ export { WEIGHT_UPDATE_WGSL, GRAD_CLIP_WGSL } from './kernels/weight_update.js';
15
+ export { ACTIVATIONS_WGSL, ACTIVATIONS_BACKWARD_WGSL } from './kernels/activations.js';
16
+ export const VERSION = '1.0.2';
17
+ export const DESCRIPTION = 'MambaCode.js: WebGPU-accelerated Mamba SSM for browser code models';
18
+ //# sourceMappingURL=index.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"index.js","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA;;GAEG;AAEH,OAAO,EAAE,UAAU,EAAE,MAAQ,wBAAwB,CAAC;AACtD,OAAO,EAAE,UAAU,EAAE,MAAQ,wBAAwB,CAAC;AAEtD,OAAO,EAAE,YAAY,EAAE,MAAM,uBAAuB,CAAC;AACrD,OAAO,EACH,MAAM,EACN,QAAQ,EACR,UAAU,EACV,MAAM,EACN,SAAS,EACT,eAAe,EACf,gBAAgB,EAChB,gBAAgB,GACnB,MAAM,wBAAwB,CAAC;AAEhC,OAAO,EAAE,YAAY,EAAE,MAAM,oBAAoB,CAAC;AAIlD,OAAO,EACH,UAAU,EACV,mBAAmB,EACnB,wBAAwB,EACxB,mBAAmB,EACnB,qBAAqB,EACrB,eAAe,EACf,cAAc,EACd,UAAU,EACV,YAAY,EACZ,IAAI,GACP,MAAM,sBAAsB,CAAC;AAE9B,OAAO,EACH,YAAY,EACZ,cAAc,EACd,WAAW,EACX,WAAW,EACX,YAAY,EACZ,cAAc,EACd,sBAAsB,EACtB,wBAAwB,EACxB,cAAc,GACjB,MAAM,yBAAyB,CAAC;AAEjC,OAAO,EAAE,2BAA2B,EAAE,4BAA4B,EAAE,MAC3D,6BAA6B,CAAC;AACvC,OAAO,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,MAC3C,qBAAqB,CAAC;AAC/B,OAAO,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,MAC3C,gCAAgC,CAAC;AAC1C,OAAO,EAAE,kBAAkB,EAAE,cAAc,EAAE,MACpC,4BAA4B,CAAC;AACtC,OAAO,EAAE,gBAAgB,EAAE,yBAAyB,EAAE,MAC7C,0BAA0B,CAAC;AAEpC,MAAM,CAAC,MAAM,OAAO,GAAG,OAAO,CAAC;AAC/B,MAAM,CAAC,MAAM,WAAW,GAAG,oEAAoE,CAAC"}
@@ -0,0 +1,3 @@
1
+ export declare const ACTIVATIONS_WGSL: string;
2
+ export declare const ACTIVATIONS_BACKWARD_WGSL: string;
3
+ //# sourceMappingURL=activations.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"activations.d.ts","sourceRoot":"","sources":["../../src/kernels/activations.ts"],"names":[],"mappings":"AAGA,eAAO,MAAM,gBAAgB,EAAE,MAyD9B,CAAC;AAGF,eAAO,MAAM,yBAAyB,EAAE,MAwBvC,CAAC"}
@@ -0,0 +1,87 @@
1
+ // Activation function WGSL kernels: SiLU (Swish) and its backward pass.
2
+ // Used in the gating mechanism of the Mamba Mixer Block.
3
+ export const ACTIVATIONS_WGSL = /* wgsl */ `
4
+
5
+ struct ActParams {
6
+ num_elements : u32,
7
+ };
8
+
9
+ @group(0) @binding(0) var<uniform> p : ActParams;
10
+ @group(0) @binding(1) var<storage, read> x : array<f32>;
11
+ @group(0) @binding(2) var<storage, read_write> y : array<f32>;
12
+
13
+ // SiLU(x) = x * sigmoid(x)
14
+ @compute @workgroup_size(256, 1, 1)
15
+ fn silu_forward(
16
+ @builtin(global_invocation_id) gid : vec3<u32>,
17
+ ) {
18
+ let i = gid.x;
19
+ if (i >= p.num_elements) { return; }
20
+ let v = x[i];
21
+ y[i] = v / (1.0 + exp(-v));
22
+ }
23
+
24
+ // RMSNorm forward: y = x / rms(x) * weight
25
+ // Requires separate uniform for rms norm params.
26
+ struct RMSNormParams {
27
+ num_rows : u32, // number of vectors (batch * seq_len)
28
+ dim : u32, // feature dimension
29
+ eps : f32,
30
+ };
31
+
32
+ @group(0) @binding(0) var<uniform> rms_p : RMSNormParams;
33
+ @group(0) @binding(1) var<storage, read> rms_x : array<f32>;
34
+ @group(0) @binding(2) var<storage, read> rms_w : array<f32>; // scale (dim,)
35
+ @group(0) @binding(3) var<storage, read_write> rms_y : array<f32>;
36
+ @group(0) @binding(4) var<storage, read_write> rms_inv : array<f32>; // cache 1/rms per row
37
+
38
+ @compute @workgroup_size(64, 1, 1)
39
+ fn rmsnorm_forward(
40
+ @builtin(global_invocation_id) gid : vec3<u32>,
41
+ ) {
42
+ let row = gid.x;
43
+ if (row >= rms_p.num_rows) { return; }
44
+
45
+ let D = rms_p.dim;
46
+ let base = row * D;
47
+
48
+ var sq_sum: f32 = 0.0;
49
+ for (var i: u32 = 0u; i < D; i = i + 1u) {
50
+ let v = rms_x[base + i];
51
+ sq_sum = sq_sum + v * v;
52
+ }
53
+ let inv_rms = 1.0 / sqrt(sq_sum / f32(D) + rms_p.eps);
54
+ rms_inv[row] = inv_rms;
55
+
56
+ for (var i: u32 = 0u; i < D; i = i + 1u) {
57
+ rms_y[base + i] = rms_x[base + i] * inv_rms * rms_w[i];
58
+ }
59
+ }
60
+ `;
61
+ // ---- Backward for SiLU ----
62
+ export const ACTIVATIONS_BACKWARD_WGSL = /* wgsl */ `
63
+
64
+ struct ActParams {
65
+ num_elements : u32,
66
+ };
67
+
68
+ @group(0) @binding(0) var<uniform> p : ActParams;
69
+ @group(0) @binding(1) var<storage, read> x : array<f32>;
70
+ @group(0) @binding(2) var<storage, read> dy : array<f32>;
71
+ @group(0) @binding(3) var<storage, read_write> dx : array<f32>;
72
+
73
+ // d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
74
+ // = silu(x)/x + sigmoid(x) * (1 - sigmoid(x)) * x
75
+ // simplified: sigmoid(x) * (1 + x*(1 - sigmoid(x)))
76
+ @compute @workgroup_size(256, 1, 1)
77
+ fn silu_backward(
78
+ @builtin(global_invocation_id) gid : vec3<u32>,
79
+ ) {
80
+ let i = gid.x;
81
+ if (i >= p.num_elements) { return; }
82
+ let v = x[i];
83
+ let sig = 1.0 / (1.0 + exp(-v));
84
+ dx[i] = dy[i] * sig * (1.0 + v * (1.0 - sig));
85
+ }
86
+ `;
87
+ //# sourceMappingURL=activations.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"activations.js","sourceRoot":"","sources":["../../src/kernels/activations.ts"],"names":[],"mappings":"AAAA,wEAAwE;AACxE,yDAAyD;AAEzD,MAAM,CAAC,MAAM,gBAAgB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAyDjD,CAAC;AAEF,8BAA8B;AAC9B,MAAM,CAAC,MAAM,yBAAyB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;CAwB1D,CAAC"}
@@ -0,0 +1,3 @@
1
+ export declare const CONV1D_FORWARD_WGSL = "\n\nstruct ConvParams {\n seq_len : u32, // L\n d_channels : u32, // D (number of depthwise channels)\n kernel_size : u32, // K (typically 4)\n batch : u32, // B\n};\n\n@group(0) @binding(0) var<uniform> params : ConvParams;\n// x (B, L, D) \u2013 input\n@group(0) @binding(1) var<storage, read> x : array<f32>;\n// weight (D, K) \u2013 depthwise conv weights\n@group(0) @binding(2) var<storage, read> weight : array<f32>;\n// bias (D,) \u2013 optional bias (zeros if unused)\n@group(0) @binding(3) var<storage, read> bias : array<f32>;\n// y (B, L, D) \u2013 output\n@group(0) @binding(4) var<storage, read_write> y : array<f32>;\n\n// Dispatch: (ceil(L/16), ceil(D/16), B)\n@compute @workgroup_size(16, 16, 1)\nfn conv1d_forward(\n @builtin(global_invocation_id) gid : vec3<u32>,\n) {\n let L = params.seq_len;\n let D = params.d_channels;\n let K = params.kernel_size;\n let B = params.batch;\n\n let t = gid.x; // time position\n let d = gid.y; // channel\n let b = gid.z; // batch\n\n if (t >= L || d >= D || b >= B) { return; }\n\n var acc: f32 = 0.0;\n\n // Causal: convolve over k = 0..K-1, reading position (t - k)\n for (var k: u32 = 0u; k < K; k = k + 1u) {\n let w_idx = d * K + k;\n let w_val = weight[w_idx];\n\n // t - k: use causal zero-padding for t < k\n if (t >= k) {\n let src = b * L * D + (t - k) * D + d;\n acc = acc + w_val * x[src];\n }\n // else: zero-padding contributes 0\n }\n\n acc = acc + bias[d];\n\n let out = b * L * D + t * D + d;\n y[out] = acc;\n}\n";
2
+ export declare const CONV1D_BACKWARD_WGSL = "\n\nstruct ConvParams {\n seq_len : u32,\n d_channels : u32,\n kernel_size : u32,\n batch : u32,\n};\n\n@group(0) @binding(0) var<uniform> params : ConvParams;\n@group(0) @binding(1) var<storage, read> x : array<f32>;\n@group(0) @binding(2) var<storage, read> weight : array<f32>;\n@group(0) @binding(3) var<storage, read> dy : array<f32>;\n@group(0) @binding(4) var<storage, read_write> dx : array<f32>;\n@group(0) @binding(5) var<storage, read_write> dweight : array<f32>;\n@group(0) @binding(6) var<storage, read_write> dbias : array<f32>;\n\n// Dispatch: (ceil(L/16), ceil(D/16), B) \u2013 computes dx\n@compute @workgroup_size(16, 16, 1)\nfn conv1d_backward_dx(\n @builtin(global_invocation_id) gid : vec3<u32>,\n) {\n let L = params.seq_len;\n let D = params.d_channels;\n let K = params.kernel_size;\n let B = params.batch;\n\n let t = gid.x;\n let d = gid.y;\n let b = gid.z;\n\n if (t >= L || d >= D || b >= B) { return; }\n\n var grad: f32 = 0.0;\n\n // dx[b, t, d] = sum_{k=0}^{K-1} dy[b, t+k, d] * weight[d, k]\n for (var k: u32 = 0u; k < K; k = k + 1u) {\n let tp = t + k;\n if (tp < L) {\n let dy_idx = b * L * D + tp * D + d;\n let w_idx = d * K + k;\n grad = grad + dy[dy_idx] * weight[w_idx];\n }\n }\n\n let dx_idx = b * L * D + t * D + d;\n dx[dx_idx] = grad;\n}\n\n// Dispatch: (K, D, 1) \u2013 accumulates dweight over (B, L)\n@compute @workgroup_size(1, 1, 1)\nfn conv1d_backward_dw(\n @builtin(global_invocation_id) gid : vec3<u32>,\n) {\n let L = params.seq_len;\n let D = params.d_channels;\n let K = params.kernel_size;\n let B = params.batch;\n\n let k = gid.x;\n let d = gid.y;\n\n if (k >= K || d >= D) { return; }\n\n var grad_w: f32 = 0.0;\n var grad_b: f32 = 0.0;\n\n for (var b: u32 = 0u; b < B; b = b + 1u) {\n for (var t: u32 = 0u; t < L; t = t + 1u) {\n let dy_idx = b * L * D + t * D + d;\n let dy_val = dy[dy_idx];\n if (t >= k) {\n let x_idx = b * L * D + (t - k) * D + d;\n grad_w = grad_w + dy_val * x[x_idx];\n }\n if (k == 0u) {\n grad_b = grad_b + dy_val;\n }\n }\n }\n\n dweight[d * K + k] = grad_w;\n if (k == 0u) {\n dbias[d] = grad_b;\n }\n}\n";
3
+ //# sourceMappingURL=conv1d.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"conv1d.d.ts","sourceRoot":"","sources":["../../src/kernels/conv1d.ts"],"names":[],"mappings":"AAQA,eAAO,MAAM,mBAAmB,urDAuD/B,CAAC;AAGF,eAAO,MAAM,oBAAoB,y6EAsFhC,CAAC"}