mambacode.js 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +196 -0
- package/package.json +54 -0
- package/src/index.js +89 -0
- package/src/kernels/activations.js +88 -0
- package/src/kernels/conv1d.js +153 -0
- package/src/kernels/linear_projection.js +220 -0
- package/src/kernels/selective_scan.js +350 -0
- package/src/kernels/weight_update.js +120 -0
- package/src/model/mamba_block.js +443 -0
- package/src/model/mamba_model.js +335 -0
- package/src/tokenizer/bpe.js +256 -0
- package/src/training/autograd.js +221 -0
- package/src/training/trainer.js +394 -0
- package/src/utils/gpu_utils.js +217 -0
- package/src/utils/quantization.js +215 -0
package/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Sean Hogg
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
package/README.md
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
# Mamba
|
|
2
|
+
MambaCode.js — WebGPU-accelerated Mamba SSM library for browser-based code model training and inference.
|
|
3
|
+
|
|
4
|
+
## Overview
|
|
5
|
+
|
|
6
|
+
MambaCode.js is a pure JavaScript/WGSL implementation of the **Mamba State Space Model (SSM)** architecture, optimised for on-device code model training and inference in the browser. It targets models in the Qwen3.5-Coder-0.8B class (tokenizer-compatible) and supports full **on-device training** (backpropagation) via WebGPU, allowing models to adapt to a user's private codebase locally — without any data leaving the browser.
|
|
7
|
+
|
|
8
|
+
### Key features
|
|
9
|
+
|
|
10
|
+
| Feature | Detail |
|
|
11
|
+
|---|---|
|
|
12
|
+
| **Architecture** | Selective State Space Model (S6) — linear O(N) context scaling |
|
|
13
|
+
| **Hardware target** | WebGPU (WGSL) — Chrome 113+, Edge 113+, Firefox Nightly |
|
|
14
|
+
| **Memory ceiling** | ≤ 3 GB VRAM (Chrome/Edge/Firefox stable) |
|
|
15
|
+
| **No heavy frameworks** | Zero TensorFlow.js / Transformers.js dependencies |
|
|
16
|
+
| **On-device training** | Tape-based autograd + AdamW GPU optimizer |
|
|
17
|
+
| **Quantization** | FP16 weights, Int8 activations |
|
|
18
|
+
| **Tokenizer** | Browser-side BPE (Qwen3.5-Coder compatible) |
|
|
19
|
+
| **WSLA mode** | Fine-tune only B & C matrices for rapid local adaptation |
|
|
20
|
+
|
|
21
|
+
---
|
|
22
|
+
|
|
23
|
+
## Architecture
|
|
24
|
+
|
|
25
|
+
```
|
|
26
|
+
Token IDs
|
|
27
|
+
│
|
|
28
|
+
▼
|
|
29
|
+
Embedding Lookup (GPU gather kernel)
|
|
30
|
+
│
|
|
31
|
+
▼ ┌─────────────────────────────────────────┐
|
|
32
|
+
│ Mamba Block × N │
|
|
33
|
+
│ │
|
|
34
|
+
│ Input ──► RMSNorm │
|
|
35
|
+
│ │ │
|
|
36
|
+
│ ┌────────┴────────┐ │
|
|
37
|
+
│ ▼ ▼ │
|
|
38
|
+
│ in_proj(x) in_proj(z) [gate] │
|
|
39
|
+
│ │ │
|
|
40
|
+
│ Conv1D (causal, K=4) │
|
|
41
|
+
│ │ │
|
|
42
|
+
│ SiLU activation │
|
|
43
|
+
│ │ │
|
|
44
|
+
│ x_proj → Δ, B, C (selective) │
|
|
45
|
+
│ │ │
|
|
46
|
+
│ Δ → dt_proj (full D_inner width) │
|
|
47
|
+
│ │ │
|
|
48
|
+
│ ┌───▼──────────────────────────────┐ │
|
|
49
|
+
│ │ Selective Scan S6 │ │
|
|
50
|
+
│ │ (Kogge-Stone parallel prefix) │ │
|
|
51
|
+
│ │ h_t = Ā·h_{t-1} + B̄·x_t │ │
|
|
52
|
+
│ │ y_t = C·h_t + D·x_t │ │
|
|
53
|
+
│ └──────────────────────────────────┘ │
|
|
54
|
+
│ │ │
|
|
55
|
+
│ Gate: y * SiLU(z) │
|
|
56
|
+
│ │ │
|
|
57
|
+
│ out_proj → residual add ──► output │
|
|
58
|
+
└─────────────────────────────────────────┘
|
|
59
|
+
│
|
|
60
|
+
▼
|
|
61
|
+
Final RMSNorm → LM Head (tied embedding) → Logits
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
---
|
|
65
|
+
|
|
66
|
+
## Quick Start
|
|
67
|
+
|
|
68
|
+
```js
|
|
69
|
+
import { MambaModel, MambaTrainer, BPETokenizer, initWebGPU } from './src/index.js';
|
|
70
|
+
|
|
71
|
+
// 1. Initialise WebGPU
|
|
72
|
+
const { device } = await initWebGPU();
|
|
73
|
+
|
|
74
|
+
// 2. Load tokenizer
|
|
75
|
+
const tokenizer = new BPETokenizer();
|
|
76
|
+
await tokenizer.load('/vocab.json', '/merges.txt');
|
|
77
|
+
|
|
78
|
+
// 3. Create model
|
|
79
|
+
const model = new MambaModel(device, {
|
|
80
|
+
vocabSize : tokenizer.vocabSize, // e.g. 151936 for Qwen3.5-Coder
|
|
81
|
+
dModel : 512,
|
|
82
|
+
numLayers : 8,
|
|
83
|
+
dState : 16,
|
|
84
|
+
dConv : 4,
|
|
85
|
+
expand : 2,
|
|
86
|
+
});
|
|
87
|
+
|
|
88
|
+
// 4. Train on local code
|
|
89
|
+
const trainer = new MambaTrainer(model, tokenizer);
|
|
90
|
+
const losses = await trainer.train(myCodeString, {
|
|
91
|
+
learningRate : 1e-4,
|
|
92
|
+
epochs : 5,
|
|
93
|
+
device : 'webgpu',
|
|
94
|
+
onEpochEnd : (epoch, loss) => console.log(`Epoch ${epoch}: loss=${loss.toFixed(4)}`),
|
|
95
|
+
});
|
|
96
|
+
|
|
97
|
+
// 5. Generate code
|
|
98
|
+
const promptIds = tokenizer.encode('function fibonacci(');
|
|
99
|
+
const outputIds = await model.generate(promptIds, 200, { temperature: 0.8 });
|
|
100
|
+
console.log(tokenizer.decode(outputIds));
|
|
101
|
+
```
|
|
102
|
+
|
|
103
|
+
### WSLA (Weight-Selective Local Adaptation)
|
|
104
|
+
|
|
105
|
+
Fine-tune only the B and C matrices for rapid private-codebase adaptation:
|
|
106
|
+
|
|
107
|
+
```js
|
|
108
|
+
await trainer.train(apiUsageExamples, {
|
|
109
|
+
learningRate : 1e-4,
|
|
110
|
+
epochs : 3,
|
|
111
|
+
wsla : true, // only B and C matrices are updated
|
|
112
|
+
});
|
|
113
|
+
```
|
|
114
|
+
|
|
115
|
+
---
|
|
116
|
+
|
|
117
|
+
## File Structure
|
|
118
|
+
|
|
119
|
+
```
|
|
120
|
+
src/
|
|
121
|
+
├── index.js ← public API entry point
|
|
122
|
+
├── kernels/
|
|
123
|
+
│ ├── selective_scan.js ← WGSL: S6 forward + backward (Kogge-Stone)
|
|
124
|
+
│ ├── conv1d.js ← WGSL: 1D causal convolution
|
|
125
|
+
│ ├── linear_projection.js ← WGSL: tiled matrix multiplication
|
|
126
|
+
│ ├── weight_update.js ← WGSL: AdamW optimizer + gradient clipping
|
|
127
|
+
│ └── activations.js ← WGSL: SiLU, RMSNorm
|
|
128
|
+
├── model/
|
|
129
|
+
│ ├── mamba_block.js ← Mamba Mixer Block (forward pass)
|
|
130
|
+
│ └── mamba_model.js ← Full stacked model + generation
|
|
131
|
+
├── training/
|
|
132
|
+
│ ├── autograd.js ← Tape-based AD engine + loss helpers
|
|
133
|
+
│ └── trainer.js ← MambaTrainer class
|
|
134
|
+
├── tokenizer/
|
|
135
|
+
│ └── bpe.js ← Browser-side BPE tokenizer
|
|
136
|
+
└── utils/
|
|
137
|
+
├── gpu_utils.js ← WebGPU device/buffer management
|
|
138
|
+
└── quantization.js ← FP16 / Int8 quantization utilities
|
|
139
|
+
tests/
|
|
140
|
+
├── kernels.test.js ← WGSL kernel source smoke tests
|
|
141
|
+
├── autograd.test.js ← Autograd engine unit tests
|
|
142
|
+
├── bpe.test.js ← BPE tokenizer unit tests
|
|
143
|
+
└── quantization.test.js ← Quantization round-trip tests
|
|
144
|
+
```
|
|
145
|
+
|
|
146
|
+
---
|
|
147
|
+
|
|
148
|
+
## WGSL Kernels
|
|
149
|
+
|
|
150
|
+
### Parallel Selective Scan (`selective_scan.js`)
|
|
151
|
+
Implements the S6 core using a **Kogge-Stone parallel prefix-sum** inside each workgroup tile. Each tile of 64 time steps is scanned in log₂(64) = 6 GPU barrier rounds; since tiles are chained sequentially via a carry-in state, total wall-clock cost scales as O((N/64)·log 64) — a constant-factor speedup per tile over a naive sequential scan, not a global O(log N).
|
|
152
|
+
|
|
153
|
+
The associative operator for the recurrence `h_t = Ā·h_{t-1} + B̄·x_t` is:
|
|
154
|
+
|
|
155
|
+
```
|
|
156
|
+
(a₁, b₁) ∘ (a₂, b₂) = (a₁·a₂, a₁·b₂ + b₁)
|
|
157
|
+
```
|
|
158
|
+
|
|
159
|
+
Tiles are chained via a carry-in state, covering arbitrarily long sequences.
|
|
160
|
+
|
|
161
|
+
### 1D Causal Convolution (`conv1d.js`)
|
|
162
|
+
Depthwise 1D causal conv (kernel size K=4) with zero left-padding. Enforces causality by only reading positions `t-k` for `k ≥ 0`, contributing 0 for `t < k`.
|
|
163
|
+
|
|
164
|
+
### Linear Projection (`linear_projection.js`)
|
|
165
|
+
Tiled 16×16 GEMM in WGSL using workgroup shared memory. Handles arbitrary (M, K) × (N, K) → (M, N) shapes with boundary guards.
|
|
166
|
+
|
|
167
|
+
### AdamW Optimizer (`weight_update.js`)
|
|
168
|
+
Fused single-kernel AdamW update with decoupled weight decay. Includes a two-pass gradient norm clipping kernel (reduce → scale).
|
|
169
|
+
|
|
170
|
+
---
|
|
171
|
+
|
|
172
|
+
## Testing
|
|
173
|
+
|
|
174
|
+
```bash
|
|
175
|
+
npm test
|
|
176
|
+
```
|
|
177
|
+
|
|
178
|
+
Runs 58 unit tests covering quantization, BPE tokenization, autograd, and WGSL kernel source validation. GPU execution tests require a real browser with WebGPU support.
|
|
179
|
+
|
|
180
|
+
---
|
|
181
|
+
|
|
182
|
+
## Browser Compatibility
|
|
183
|
+
|
|
184
|
+
| Browser | Version | Status |
|
|
185
|
+
|---|---|---|
|
|
186
|
+
| Chrome | 113+ | ✅ Supported |
|
|
187
|
+
| Edge | 113+ | ✅ Supported |
|
|
188
|
+
| Firefox | Nightly | ✅ Supported (flag: `dom.webgpu.enabled`) |
|
|
189
|
+
| Safari | 18+ | ⚠️ Partial (WebGPU in preview) |
|
|
190
|
+
| Node.js | — | ❌ Not supported (no `navigator.gpu`) |
|
|
191
|
+
|
|
192
|
+
---
|
|
193
|
+
|
|
194
|
+
## License
|
|
195
|
+
|
|
196
|
+
MIT
|
package/package.json
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
{
|
|
2
|
+
"name": "mambacode.js",
|
|
3
|
+
"version": "1.0.0",
|
|
4
|
+
"description": "High-performance JavaScript/WGSL Mamba SSM library for browser-based code model training and inference",
|
|
5
|
+
"main": "src/index.js",
|
|
6
|
+
"type": "module",
|
|
7
|
+
"files": [
|
|
8
|
+
"src",
|
|
9
|
+
"README.md",
|
|
10
|
+
"LICENSE"
|
|
11
|
+
],
|
|
12
|
+
"scripts": {
|
|
13
|
+
"test": "node --experimental-vm-modules node_modules/.bin/jest",
|
|
14
|
+
"lint": "eslint src/ tests/"
|
|
15
|
+
},
|
|
16
|
+
"keywords": [
|
|
17
|
+
"mamba",
|
|
18
|
+
"ssm",
|
|
19
|
+
"state-space-model",
|
|
20
|
+
"webgpu",
|
|
21
|
+
"wgsl",
|
|
22
|
+
"machine-learning",
|
|
23
|
+
"code-model",
|
|
24
|
+
"bpe",
|
|
25
|
+
"transformer-alternative"
|
|
26
|
+
],
|
|
27
|
+
"author": {
|
|
28
|
+
"name": "Sean Hogg",
|
|
29
|
+
"email": "seanhogg@gmail.com",
|
|
30
|
+
"url": "https://builderforce.ai"
|
|
31
|
+
},
|
|
32
|
+
"license": "MIT",
|
|
33
|
+
"repository": {
|
|
34
|
+
"type": "git",
|
|
35
|
+
"url": "https://github.com/SeanHogg/Mamba.git"
|
|
36
|
+
},
|
|
37
|
+
"bugs": {
|
|
38
|
+
"url": "https://github.com/SeanHogg/Mamba/issues"
|
|
39
|
+
},
|
|
40
|
+
"homepage": "https://github.com/SeanHogg/Mamba#readme",
|
|
41
|
+
"engines": {
|
|
42
|
+
"node": ">=18.0.0"
|
|
43
|
+
},
|
|
44
|
+
"publishConfig": {
|
|
45
|
+
"access": "public"
|
|
46
|
+
},
|
|
47
|
+
"devDependencies": {
|
|
48
|
+
"jest": "^29.7.0",
|
|
49
|
+
"eslint": "^8.57.0"
|
|
50
|
+
},
|
|
51
|
+
"jest": {
|
|
52
|
+
"transform": {}
|
|
53
|
+
}
|
|
54
|
+
}
|
package/src/index.js
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* MambaCode.js – Entry Point
|
|
3
|
+
*
|
|
4
|
+
* High-performance JavaScript/WGSL Mamba SSM library for browser-based
|
|
5
|
+
* code model training and inference.
|
|
6
|
+
*
|
|
7
|
+
* Quick-start example
|
|
8
|
+
* -------------------
|
|
9
|
+
* import { MambaModel, MambaTrainer, BPETokenizer, initWebGPU } from 'mambacode.js';
|
|
10
|
+
*
|
|
11
|
+
* const { device } = await initWebGPU();
|
|
12
|
+
* const tokenizer = new BPETokenizer();
|
|
13
|
+
* await tokenizer.load('/vocab.json', '/merges.txt');
|
|
14
|
+
*
|
|
15
|
+
* const model = new MambaModel(device, {
|
|
16
|
+
* vocabSize : tokenizer.vocabSize,
|
|
17
|
+
* dModel : 512,
|
|
18
|
+
* numLayers : 8,
|
|
19
|
+
* });
|
|
20
|
+
*
|
|
21
|
+
* const trainer = new MambaTrainer(model, tokenizer);
|
|
22
|
+
* const losses = await trainer.train(myCodeString, { learningRate: 1e-4, epochs: 5 });
|
|
23
|
+
*
|
|
24
|
+
* const generated = await model.generate(tokenizer.encode('function '), 100);
|
|
25
|
+
* console.log(tokenizer.decode(generated));
|
|
26
|
+
*/
|
|
27
|
+
|
|
28
|
+
// ── Core model ────────────────────────────────────────────────────────────────
// Full stacked model (embedding → N Mamba blocks → LM head) and the single
// mixer block, exported separately so custom stacks can be assembled.
export { MambaModel } from './model/mamba_model.js';
export { MambaBlock } from './model/mamba_block.js';

// ── Training ──────────────────────────────────────────────────────────────────
// MambaTrainer is the high-level loop; the autograd exports expose the
// tape-based AD engine for custom training code.
export { MambaTrainer } from './training/trainer.js';
export {
  Tensor,
  backward,
  enableGrad,
  noGrad,
  clearTape,
  recordOperation,
  crossEntropyLoss,
  crossEntropyGrad,
} from './training/autograd.js';

// ── Tokenizer ─────────────────────────────────────────────────────────────────
export { BPETokenizer } from './tokenizer/bpe.js';

// ── WebGPU utilities ──────────────────────────────────────────────────────────
// Device/buffer lifecycle helpers plus `cdiv` (ceiling division, used to size
// dispatch grids).
export {
  initWebGPU,
  createStorageBuffer,
  createEmptyStorageBuffer,
  createUniformBuffer,
  createComputePipeline,
  createBindGroup,
  dispatchKernel,
  readBuffer,
  uploadBuffer,
  cdiv,
} from './utils/gpu_utils.js';

// ── Quantization utilities ────────────────────────────────────────────────────
// FP16 and Int8 (per-tensor and per-channel) conversion helpers, plus a
// memory-footprint estimator.
export {
  quantizeFp16,
  dequantizeFp16,
  floatToFp16,
  fp16ToFloat,
  quantizeInt8,
  dequantizeInt8,
  quantizeInt8PerChannel,
  dequantizeInt8PerChannel,
  estimateMemory,
} from './utils/quantization.js';

// ── Raw WGSL kernel sources (for advanced users / custom pipelines) ───────────
// Each export is a WGSL source string; compile with device.createShaderModule.
export { SELECTIVE_SCAN_FORWARD_WGSL, SELECTIVE_SCAN_BACKWARD_WGSL }
  from './kernels/selective_scan.js';
export { CONV1D_FORWARD_WGSL, CONV1D_BACKWARD_WGSL }
  from './kernels/conv1d.js';
export { LINEAR_FORWARD_WGSL, LINEAR_BACKWARD_WGSL }
  from './kernels/linear_projection.js';
export { WEIGHT_UPDATE_WGSL, GRAD_CLIP_WGSL }
  from './kernels/weight_update.js';
export { ACTIVATIONS_WGSL, ACTIVATIONS_BACKWARD_WGSL }
  from './kernels/activations.js';
|
|
86
|
+
|
|
87
|
+
// ── Library metadata ──────────────────────────────────────────────────────────
// VERSION must stay in sync with the "version" field in package.json (1.0.0);
// the previous value ('0.1.0') had drifted out of sync with the published
// package version.
export const VERSION = '1.0.0';
export const DESCRIPTION = 'MambaCode.js: WebGPU-accelerated Mamba SSM for browser code models';
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
// Activation function WGSL kernels: SiLU (Swish) and its backward pass.
// Used in the gating mechanism of the Mamba Mixer Block.
//
// This module string contains TWO entry points: `silu_forward` (element-wise)
// and `rmsnorm_forward` (one thread per row). NOTE(review): the two entry
// points declare overlapping @group(0) @binding(0..2) slots — this is only
// valid because each entry point is compiled into its own pipeline with its
// own bind group layout; never bind both interfaces at once. Confirm the
// pipeline-creation code keeps them separate.

export const ACTIVATIONS_WGSL = /* wgsl */`

struct ActParams {
  num_elements : u32,
};

@group(0) @binding(0) var<uniform> p : ActParams;
@group(0) @binding(1) var<storage, read> x : array<f32>;
@group(0) @binding(2) var<storage, read_write> y : array<f32>;

// SiLU(x) = x * sigmoid(x)
@compute @workgroup_size(256, 1, 1)
fn silu_forward(
  @builtin(global_invocation_id) gid : vec3<u32>,
) {
  let i = gid.x;
  if (i >= p.num_elements) { return; }
  let v = x[i];
  y[i] = v / (1.0 + exp(-v));
}

// RMSNorm forward: y = x / rms(x) * weight
// Requires separate uniform for rms norm params.
struct RMSNormParams {
  num_rows : u32, // number of vectors (batch * seq_len)
  dim : u32,      // feature dimension
  eps : f32,
};

@group(0) @binding(0) var<uniform> rms_p : RMSNormParams;
@group(0) @binding(1) var<storage, read> rms_x : array<f32>;
@group(0) @binding(2) var<storage, read> rms_w : array<f32>; // scale (dim,)
@group(0) @binding(3) var<storage, read_write> rms_y : array<f32>;
@group(0) @binding(4) var<storage, read_write> rms_inv : array<f32>; // cache 1/rms per row

@compute @workgroup_size(64, 1, 1)
fn rmsnorm_forward(
  @builtin(global_invocation_id) gid : vec3<u32>,
) {
  let row = gid.x;
  if (row >= rms_p.num_rows) { return; }

  let D = rms_p.dim;
  let base = row * D;

  var sq_sum: f32 = 0.0;
  for (var i: u32 = 0u; i < D; i = i + 1u) {
    let v = rms_x[base + i];
    sq_sum = sq_sum + v * v;
  }
  let inv_rms = 1.0 / sqrt(sq_sum / f32(D) + rms_p.eps);
  rms_inv[row] = inv_rms;

  for (var i: u32 = 0u; i < D; i = i + 1u) {
    rms_y[base + i] = rms_x[base + i] * inv_rms * rms_w[i];
  }
}
`;
|
|
62
|
+
|
|
63
|
+
// ---- Backward for SiLU ----
// Element-wise gradient of SiLU; reads the forward input `x` and the upstream
// gradient `dy`, writes `dx`. Dispatch with ceil(num_elements / 256) groups.
export const ACTIVATIONS_BACKWARD_WGSL = /* wgsl */`

struct ActParams {
  num_elements : u32,
};

@group(0) @binding(0) var<uniform> p : ActParams;
@group(0) @binding(1) var<storage, read> x : array<f32>;
@group(0) @binding(2) var<storage, read> dy : array<f32>;
@group(0) @binding(3) var<storage, read_write> dx : array<f32>;

// d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
//                       = silu(x)/x + sigmoid(x) * (1 - sigmoid(x)) * x
// simplified: sigmoid(x) * (1 + x*(1 - sigmoid(x)))
@compute @workgroup_size(256, 1, 1)
fn silu_backward(
  @builtin(global_invocation_id) gid : vec3<u32>,
) {
  let i = gid.x;
  if (i >= p.num_elements) { return; }
  let v = x[i];
  let sig = 1.0 / (1.0 + exp(-v));
  dx[i] = dy[i] * sig * (1.0 + v * (1.0 - sig));
}
`;
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
// 1D Causal Convolution WGSL Kernel
// Implements a depthwise 1D causal convolution over the sequence dimension.
// "Causal" means the output at position t only depends on positions <= t,
// which is enforced by left-padding with (kernel_size - 1) zeros.
//
// Forward: y[b, t, d] = sum_{k=0}^{K-1} weight[d, k] * x[b, t-k, d]
//          where x[b, t', d] = 0 for t' < 0 (causal padding)
//
// One thread per (t, d, b) output element; the zero padding is implicit —
// out-of-range reads are simply skipped rather than materialised.

export const CONV1D_FORWARD_WGSL = /* wgsl */`

struct ConvParams {
  seq_len     : u32, // L
  d_channels  : u32, // D (number of depthwise channels)
  kernel_size : u32, // K (typically 4)
  batch       : u32, // B
};

@group(0) @binding(0) var<uniform> params : ConvParams;
// x (B, L, D) – input
@group(0) @binding(1) var<storage, read> x : array<f32>;
// weight (D, K) – depthwise conv weights
@group(0) @binding(2) var<storage, read> weight : array<f32>;
// bias (D,) – optional bias (zeros if unused)
@group(0) @binding(3) var<storage, read> bias : array<f32>;
// y (B, L, D) – output
@group(0) @binding(4) var<storage, read_write> y : array<f32>;

// Dispatch: (ceil(L/16), ceil(D/16), B)
@compute @workgroup_size(16, 16, 1)
fn conv1d_forward(
  @builtin(global_invocation_id) gid : vec3<u32>,
) {
  let L = params.seq_len;
  let D = params.d_channels;
  let K = params.kernel_size;
  let B = params.batch;

  let t = gid.x; // time position
  let d = gid.y; // channel
  let b = gid.z; // batch

  if (t >= L || d >= D || b >= B) { return; }

  var acc: f32 = 0.0;

  // Causal: convolve over k = 0..K-1, reading position (t - k)
  for (var k: u32 = 0u; k < K; k = k + 1u) {
    let w_idx = d * K + k;
    let w_val = weight[w_idx];

    // t - k: use causal zero-padding for t < k
    if (t >= k) {
      let src = b * L * D + (t - k) * D + d;
      acc = acc + w_val * x[src];
    }
    // else: zero-padding contributes 0
  }

  acc = acc + bias[d];

  let out = b * L * D + t * D + d;
  y[out] = acc;
}
`;
|
|
65
|
+
|
|
66
|
+
// ---- Backward kernel for 1D convolution ----
// Two entry points sharing one bind group layout:
//   conv1d_backward_dx — input gradient, one thread per (t, d, b) element
//     (transposed convolution: dx[t] gathers dy[t..t+K-1]).
//   conv1d_backward_dw — weight/bias gradients, one thread per (k, d) pair,
//     reducing over the whole (B, L) extent serially. The k == 0 thread also
//     accumulates dbias so the bias reduction is written exactly once per
//     channel. Workgroup size is (1,1,1), so throughput here scales only with
//     K*D — acceptable because K is small (typically 4).
export const CONV1D_BACKWARD_WGSL = /* wgsl */`

struct ConvParams {
  seq_len     : u32,
  d_channels  : u32,
  kernel_size : u32,
  batch       : u32,
};

@group(0) @binding(0) var<uniform> params : ConvParams;
@group(0) @binding(1) var<storage, read> x : array<f32>;
@group(0) @binding(2) var<storage, read> weight : array<f32>;
@group(0) @binding(3) var<storage, read> dy : array<f32>;
@group(0) @binding(4) var<storage, read_write> dx : array<f32>;
@group(0) @binding(5) var<storage, read_write> dweight : array<f32>;
@group(0) @binding(6) var<storage, read_write> dbias : array<f32>;

// Dispatch: (ceil(L/16), ceil(D/16), B) – computes dx
@compute @workgroup_size(16, 16, 1)
fn conv1d_backward_dx(
  @builtin(global_invocation_id) gid : vec3<u32>,
) {
  let L = params.seq_len;
  let D = params.d_channels;
  let K = params.kernel_size;
  let B = params.batch;

  let t = gid.x;
  let d = gid.y;
  let b = gid.z;

  if (t >= L || d >= D || b >= B) { return; }

  var grad: f32 = 0.0;

  // dx[b, t, d] = sum_{k=0}^{K-1} dy[b, t+k, d] * weight[d, k]
  for (var k: u32 = 0u; k < K; k = k + 1u) {
    let tp = t + k;
    if (tp < L) {
      let dy_idx = b * L * D + tp * D + d;
      let w_idx = d * K + k;
      grad = grad + dy[dy_idx] * weight[w_idx];
    }
  }

  let dx_idx = b * L * D + t * D + d;
  dx[dx_idx] = grad;
}

// Dispatch: (K, D, 1) – accumulates dweight over (B, L)
@compute @workgroup_size(1, 1, 1)
fn conv1d_backward_dw(
  @builtin(global_invocation_id) gid : vec3<u32>,
) {
  let L = params.seq_len;
  let D = params.d_channels;
  let K = params.kernel_size;
  let B = params.batch;

  let k = gid.x;
  let d = gid.y;

  if (k >= K || d >= D) { return; }

  var grad_w: f32 = 0.0;
  var grad_b: f32 = 0.0;

  for (var b: u32 = 0u; b < B; b = b + 1u) {
    for (var t: u32 = 0u; t < L; t = t + 1u) {
      let dy_idx = b * L * D + t * D + d;
      let dy_val = dy[dy_idx];
      if (t >= k) {
        let x_idx = b * L * D + (t - k) * D + d;
        grad_w = grad_w + dy_val * x[x_idx];
      }
      if (k == 0u) {
        grad_b = grad_b + dy_val;
      }
    }
  }

  dweight[d * K + k] = grad_w;
  if (k == 0u) {
    dbias[d] = grad_b;
  }
}
`;
|