mambacode.js 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Sean Hogg
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
package/README.md ADDED
@@ -0,0 +1,196 @@
1
+ # Mamba
2
+ MambaCode.js — WebGPU-accelerated Mamba SSM library for browser-based code model training and inference.
3
+
4
+ ## Overview
5
+
6
+ MambaCode.js is a pure JavaScript/WGSL implementation of the **Mamba State Space Model (SSM)** architecture, optimised for on-device code model training and inference in the browser. It targets a Qwen3.5-Coder-0.8B-compatible configuration and supports full **on-device training** (backpropagation) via WebGPU, allowing models to adapt to a user's private codebase locally — without any data leaving the browser.
7
+
8
+ ### Key features
9
+
10
+ | Feature | Detail |
11
+ |---|---|
12
+ | **Architecture** | Selective State Space Model (S6) — linear O(N) context scaling |
13
+ | **Hardware target** | WebGPU (WGSL) — Chrome 113+, Edge 113+, Firefox Nightly |
14
+ | **Memory ceiling** | ≤ 3 GB VRAM (Chrome/Edge/Firefox stable) |
15
+ | **No heavy frameworks** | Zero TensorFlow.js / Transformers.js dependencies |
16
+ | **On-device training** | Tape-based autograd + AdamW GPU optimizer |
17
+ | **Quantization** | FP16 weights, Int8 activations |
18
+ | **Tokenizer** | Browser-side BPE (Qwen3.5-Coder compatible) |
19
+ | **WSLA mode** | Fine-tune only B & C matrices for rapid local adaptation |
20
+
21
+ ---
22
+
23
+ ## Architecture
24
+
25
+ ```
26
+ Token IDs
27
+
28
+
29
+ Embedding Lookup (GPU gather kernel)
30
+
31
+ ▼ ┌─────────────────────────────────────────┐
32
+ │ Mamba Block × N │
33
+ │ │
34
+ │ Input ──► RMSNorm │
35
+ │ │ │
36
+ │ ┌────────┴────────┐ │
37
+ │ ▼ ▼ │
38
+ │ in_proj(x) in_proj(z) [gate] │
39
+ │ │ │
40
+ │ Conv1D (causal, K=4) │
41
+ │ │ │
42
+ │ SiLU activation │
43
+ │ │ │
44
+ │ x_proj → Δ, B, C (selective) │
45
+ │ │ │
46
+ │ Δ → dt_proj (full D_inner width) │
47
+ │ │ │
48
+ │ ┌───▼──────────────────────────────┐ │
49
+ │ │ Selective Scan S6 │ │
50
+ │ │ (Kogge-Stone parallel prefix) │ │
51
+ │ │ h_t = Ā·h_{t-1} + B̄·x_t │ │
52
+ │ │ y_t = C·h_t + D·x_t │ │
53
+ │ └──────────────────────────────────┘ │
54
+ │ │ │
55
+ │ Gate: y * SiLU(z) │
56
+ │ │ │
57
+ │ out_proj → residual add ──► output │
58
+ └─────────────────────────────────────────┘
59
+
60
+
61
+ Final RMSNorm → LM Head (tied embedding) → Logits
62
+ ```
63
+
64
+ ---
65
+
66
+ ## Quick Start
67
+
68
+ ```js
69
+ import { MambaModel, MambaTrainer, BPETokenizer, initWebGPU } from './src/index.js';
70
+
71
+ // 1. Initialise WebGPU
72
+ const { device } = await initWebGPU();
73
+
74
+ // 2. Load tokenizer
75
+ const tokenizer = new BPETokenizer();
76
+ await tokenizer.load('/vocab.json', '/merges.txt');
77
+
78
+ // 3. Create model
79
+ const model = new MambaModel(device, {
80
+ vocabSize : tokenizer.vocabSize, // e.g. 151936 for Qwen3.5-Coder
81
+ dModel : 512,
82
+ numLayers : 8,
83
+ dState : 16,
84
+ dConv : 4,
85
+ expand : 2,
86
+ });
87
+
88
+ // 4. Train on local code
89
+ const trainer = new MambaTrainer(model, tokenizer);
90
+ const losses = await trainer.train(myCodeString, {
91
+ learningRate : 1e-4,
92
+ epochs : 5,
93
+ device : 'webgpu',
94
+ onEpochEnd : (epoch, loss) => console.log(`Epoch ${epoch}: loss=${loss.toFixed(4)}`),
95
+ });
96
+
97
+ // 5. Generate code
98
+ const promptIds = tokenizer.encode('function fibonacci(');
99
+ const outputIds = await model.generate(promptIds, 200, { temperature: 0.8 });
100
+ console.log(tokenizer.decode(outputIds));
101
+ ```
102
+
103
+ ### WSLA (Weight-Selective Local Adaptation)
104
+
105
+ Fine-tune only the B and C matrices for rapid private-codebase adaptation:
106
+
107
+ ```js
108
+ await trainer.train(apiUsageExamples, {
109
+ learningRate : 1e-4,
110
+ epochs : 3,
111
+ wsla : true, // only B and C matrices are updated
112
+ });
113
+ ```
114
+
115
+ ---
116
+
117
+ ## File Structure
118
+
119
+ ```
120
+ src/
121
+ ├── index.js ← public API entry point
122
+ ├── kernels/
123
+ │ ├── selective_scan.js ← WGSL: S6 forward + backward (Kogge-Stone)
124
+ │ ├── conv1d.js ← WGSL: 1D causal convolution
125
+ │ ├── linear_projection.js ← WGSL: tiled matrix multiplication
126
+ │ ├── weight_update.js ← WGSL: AdamW optimizer + gradient clipping
127
+ │ └── activations.js ← WGSL: SiLU, RMSNorm
128
+ ├── model/
129
+ │ ├── mamba_block.js ← Mamba Mixer Block (forward pass)
130
+ │ └── mamba_model.js ← Full stacked model + generation
131
+ ├── training/
132
+ │ ├── autograd.js ← Tape-based AD engine + loss helpers
133
+ │ └── trainer.js ← MambaTrainer class
134
+ ├── tokenizer/
135
+ │ └── bpe.js ← Browser-side BPE tokenizer
136
+ └── utils/
137
+ ├── gpu_utils.js ← WebGPU device/buffer management
138
+ └── quantization.js ← FP16 / Int8 quantization utilities
139
+ tests/
140
+ ├── kernels.test.js ← WGSL kernel source smoke tests
141
+ ├── autograd.test.js ← Autograd engine unit tests
142
+ ├── bpe.test.js ← BPE tokenizer unit tests
143
+ └── quantization.test.js ← Quantization round-trip tests
144
+ ```
145
+
146
+ ---
147
+
148
+ ## WGSL Kernels
149
+
150
+ ### Parallel Selective Scan (`selective_scan.js`)
151
+ Implements the S6 core using a **Kogge-Stone parallel prefix-sum** inside each workgroup tile. Each tile of 64 time steps is scanned in log₂(64) = 6 GPU barrier rounds, giving O(log N) wall-clock time on the GPU.
152
+
153
+ The associative operator for the recurrence `h_t = Ā·h_{t-1} + B̄·x_t` is:
154
+
155
+ ```
156
+ (a₁, b₁) ∘ (a₂, b₂) = (a₁·a₂, a₁·b₂ + b₁)
157
+ ```
158
+
159
+ Tiles are chained via a carry-in state, covering arbitrarily long sequences.
160
+
161
+ ### 1D Causal Convolution (`conv1d.js`)
162
+ Depthwise 1D causal conv (kernel size K=4) with zero left-padding. Enforces causality by only reading positions `t-k` for `k ≥ 0`, contributing 0 for `t < k`.
163
+
164
+ ### Linear Projection (`linear_projection.js`)
165
+ Tiled 16×16 GEMM in WGSL using workgroup shared memory. Handles arbitrary (M, K) × (N, K) → (M, N) shapes with boundary guards.
166
+
167
+ ### AdamW Optimizer (`weight_update.js`)
168
+ Fused single-kernel AdamW update with decoupled weight decay. Includes a two-pass gradient norm clipping kernel (reduce → scale).
169
+
170
+ ---
171
+
172
+ ## Testing
173
+
174
+ ```bash
175
+ npm test
176
+ ```
177
+
178
+ Runs 58 unit tests covering quantization, BPE tokenization, autograd, and WGSL kernel source validation. GPU execution tests require a real browser with WebGPU support.
179
+
180
+ ---
181
+
182
+ ## Browser Compatibility
183
+
184
+ | Browser | Version | Status |
185
+ |---|---|---|
186
+ | Chrome | 113+ | ✅ Supported |
187
+ | Edge | 113+ | ✅ Supported |
188
+ | Firefox | Nightly | ✅ Supported (flag: `dom.webgpu.enabled`) |
189
+ | Safari | 18+ | ⚠️ Partial (WebGPU in preview) |
190
+ | Node.js | — | ❌ Not supported (no `navigator.gpu`) |
191
+
192
+ ---
193
+
194
+ ## License
195
+
196
+ MIT
package/package.json ADDED
@@ -0,0 +1,54 @@
1
+ {
2
+ "name": "mambacode.js",
3
+ "version": "1.0.0",
4
+ "description": "High-performance JavaScript/WGSL Mamba SSM library for browser-based code model training and inference",
5
+ "main": "src/index.js",
6
+ "type": "module",
7
+ "files": [
8
+ "src",
9
+ "README.md",
10
+ "LICENSE"
11
+ ],
12
+ "scripts": {
13
+ "test": "node --experimental-vm-modules node_modules/.bin/jest",
14
+ "lint": "eslint src/ tests/"
15
+ },
16
+ "keywords": [
17
+ "mamba",
18
+ "ssm",
19
+ "state-space-model",
20
+ "webgpu",
21
+ "wgsl",
22
+ "machine-learning",
23
+ "code-model",
24
+ "bpe",
25
+ "transformer-alternative"
26
+ ],
27
+ "author": {
28
+ "name": "Sean Hogg",
29
+ "email": "seanhogg@gmail.com",
30
+ "url": "https://builderforce.ai"
31
+ },
32
+ "license": "MIT",
33
+ "repository": {
34
+ "type": "git",
35
+ "url": "https://github.com/SeanHogg/Mamba.git"
36
+ },
37
+ "bugs": {
38
+ "url": "https://github.com/SeanHogg/Mamba/issues"
39
+ },
40
+ "homepage": "https://github.com/SeanHogg/Mamba#readme",
41
+ "engines": {
42
+ "node": ">=18.0.0"
43
+ },
44
+ "publishConfig": {
45
+ "access": "public"
46
+ },
47
+ "devDependencies": {
48
+ "jest": "^29.7.0",
49
+ "eslint": "^8.57.0"
50
+ },
51
+ "jest": {
52
+ "transform": {}
53
+ }
54
+ }
package/src/index.js ADDED
@@ -0,0 +1,89 @@
/**
 * MambaCode.js – Entry Point
 *
 * High-performance JavaScript/WGSL Mamba SSM library for browser-based
 * code model training and inference. This module is the single public
 * surface of the package: it re-exports the model classes, the trainer and
 * autograd primitives, the tokenizer, the WebGPU/quantization utilities,
 * and the raw WGSL kernel sources for users building custom pipelines.
 *
 * Quick-start example
 * -------------------
 * import { MambaModel, MambaTrainer, BPETokenizer, initWebGPU } from 'mambacode.js';
 *
 * const { device } = await initWebGPU();
 * const tokenizer = new BPETokenizer();
 * await tokenizer.load('/vocab.json', '/merges.txt');
 *
 * const model = new MambaModel(device, {
 *   vocabSize : tokenizer.vocabSize,
 *   dModel    : 512,
 *   numLayers : 8,
 * });
 *
 * const trainer = new MambaTrainer(model, tokenizer);
 * const losses = await trainer.train(myCodeString, { learningRate: 1e-4, epochs: 5 });
 *
 * const generated = await model.generate(tokenizer.encode('function '), 100);
 * console.log(tokenizer.decode(generated));
 */

// ── Core model ────────────────────────────────────────────────────────────────
export { MambaModel } from './model/mamba_model.js';
export { MambaBlock } from './model/mamba_block.js';

// ── Training ──────────────────────────────────────────────────────────────────
export { MambaTrainer } from './training/trainer.js';
export {
  Tensor,
  backward,
  enableGrad,
  noGrad,
  clearTape,
  recordOperation,
  crossEntropyLoss,
  crossEntropyGrad,
} from './training/autograd.js';

// ── Tokenizer ─────────────────────────────────────────────────────────────────
export { BPETokenizer } from './tokenizer/bpe.js';

// ── WebGPU utilities ──────────────────────────────────────────────────────────
export {
  initWebGPU,
  createStorageBuffer,
  createEmptyStorageBuffer,
  createUniformBuffer,
  createComputePipeline,
  createBindGroup,
  dispatchKernel,
  readBuffer,
  uploadBuffer,
  cdiv,
} from './utils/gpu_utils.js';

// ── Quantization utilities ────────────────────────────────────────────────────
export {
  quantizeFp16,
  dequantizeFp16,
  floatToFp16,
  fp16ToFloat,
  quantizeInt8,
  dequantizeInt8,
  quantizeInt8PerChannel,
  dequantizeInt8PerChannel,
  estimateMemory,
} from './utils/quantization.js';

// ── Raw WGSL kernel sources (for advanced users / custom pipelines) ───────────
export { SELECTIVE_SCAN_FORWARD_WGSL, SELECTIVE_SCAN_BACKWARD_WGSL }
  from './kernels/selective_scan.js';
export { CONV1D_FORWARD_WGSL, CONV1D_BACKWARD_WGSL }
  from './kernels/conv1d.js';
export { LINEAR_FORWARD_WGSL, LINEAR_BACKWARD_WGSL }
  from './kernels/linear_projection.js';
export { WEIGHT_UPDATE_WGSL, GRAD_CLIP_WGSL }
  from './kernels/weight_update.js';
export { ACTIVATIONS_WGSL, ACTIVATIONS_BACKWARD_WGSL }
  from './kernels/activations.js';

// ── Library metadata ──────────────────────────────────────────────────────────
// Keep VERSION in sync with the "version" field of package.json (1.0.0);
// the previous value ('0.1.0') was stale.
export const VERSION = '1.0.0';
export const DESCRIPTION = 'MambaCode.js: WebGPU-accelerated Mamba SSM for browser code models';
package/src/kernels/activations.js ADDED
@@ -0,0 +1,88 @@
// Activation function WGSL kernels: SiLU (Swish) forward and RMSNorm forward.
// Used in the gating mechanism and normalization of the Mamba Mixer Block.
//
// This single module contains two entry points (silu_forward, rmsnorm_forward)
// whose resource variables reuse the same @group(0) @binding(0..2) slots.
// That is valid in WebGPU: binding uniqueness is only enforced per entry
// point when the compute pipeline is created, and each entry point here
// statically uses only its own set of bindings.

export const ACTIVATIONS_WGSL = /* wgsl */`

// Parameters for the element-wise SiLU kernel.
struct ActParams {
  num_elements : u32,   // total element count of the flattened input
};

@group(0) @binding(0) var<uniform> p : ActParams;
@group(0) @binding(1) var<storage, read> x : array<f32>;        // input
@group(0) @binding(2) var<storage, read_write> y : array<f32>;  // output

// SiLU(x) = x * sigmoid(x), computed here as x / (1 + exp(-x)).
// Dispatch: ceil(num_elements / 256) workgroups; one element per invocation.
@compute @workgroup_size(256, 1, 1)
fn silu_forward(
  @builtin(global_invocation_id) gid : vec3<u32>,
) {
  let i = gid.x;
  if (i >= p.num_elements) { return; }   // guard the ragged last workgroup
  let v = x[i];
  y[i] = v / (1.0 + exp(-v));
}

// RMSNorm forward: y = x / rms(x) * weight
// Requires separate uniform for rms norm params.
struct RMSNormParams {
  num_rows : u32, // number of vectors (batch * seq_len)
  dim : u32,      // feature dimension
  eps : f32,      // numerical-stability epsilon added under the sqrt
};

@group(0) @binding(0) var<uniform> rms_p : RMSNormParams;
@group(0) @binding(1) var<storage, read> rms_x : array<f32>;
@group(0) @binding(2) var<storage, read> rms_w : array<f32>;         // scale (dim,)
@group(0) @binding(3) var<storage, read_write> rms_y : array<f32>;
@group(0) @binding(4) var<storage, read_write> rms_inv : array<f32>; // cache 1/rms per row

// One invocation normalizes one full row (a dim-length vector); the loop
// over the feature dimension is serial within the invocation.
// Dispatch: ceil(num_rows / 64) workgroups.
@compute @workgroup_size(64, 1, 1)
fn rmsnorm_forward(
  @builtin(global_invocation_id) gid : vec3<u32>,
) {
  let row = gid.x;
  if (row >= rms_p.num_rows) { return; }

  let D = rms_p.dim;
  let base = row * D;   // rows are contiguous: row-major (num_rows, dim)

  // First pass: mean of squares over the row.
  var sq_sum: f32 = 0.0;
  for (var i: u32 = 0u; i < D; i = i + 1u) {
    let v = rms_x[base + i];
    sq_sum = sq_sum + v * v;
  }
  let inv_rms = 1.0 / sqrt(sq_sum / f32(D) + rms_p.eps);
  // Cache 1/rms so a backward pass can reuse it without recomputing.
  rms_inv[row] = inv_rms;

  // Second pass: scale each element by 1/rms and the learned weight.
  for (var i: u32 = 0u; i < D; i = i + 1u) {
    rms_y[base + i] = rms_x[base + i] * inv_rms * rms_w[i];
  }
}
`;
// ---- Backward for SiLU ----
// Element-wise gradient of SiLU with respect to its input. Reads the
// forward-pass *input* x (not the forward output), so the caller must
// retain x until the backward pass has run.
export const ACTIVATIONS_BACKWARD_WGSL = /* wgsl */`

struct ActParams {
  num_elements : u32,   // total element count of the flattened tensors
};

@group(0) @binding(0) var<uniform> p : ActParams;
@group(0) @binding(1) var<storage, read> x : array<f32>;         // forward input
@group(0) @binding(2) var<storage, read> dy : array<f32>;        // upstream gradient
@group(0) @binding(3) var<storage, read_write> dx : array<f32>;  // gradient w.r.t. x

// d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
//                       = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
// Dispatch: ceil(num_elements / 256) workgroups; one element per invocation.
@compute @workgroup_size(256, 1, 1)
fn silu_backward(
  @builtin(global_invocation_id) gid : vec3<u32>,
) {
  let i = gid.x;
  if (i >= p.num_elements) { return; }   // guard the ragged last workgroup
  let v = x[i];
  let sig = 1.0 / (1.0 + exp(-v));
  dx[i] = dy[i] * sig * (1.0 + v * (1.0 - sig));
}
`;
package/src/kernels/conv1d.js ADDED
@@ -0,0 +1,153 @@
// 1D Causal Convolution WGSL Kernel
// Implements a depthwise 1D causal convolution over the sequence dimension.
// "Causal" means the output at position t only depends on positions <= t,
// which is enforced by left-padding with (kernel_size - 1) zeros.
//
// Forward: y[b, t, d] = sum_{k=0}^{K-1} weight[d, k] * x[b, t-k, d]
//          where x[b, t', d] = 0 for t' < 0 (causal padding)
//
// All tensors are flat f32 buffers in row-major layout:
//   x, y : (B, L, D)   index = b*L*D + t*D + d
//   weight : (D, K)    index = d*K + k
//   bias   : (D,)

export const CONV1D_FORWARD_WGSL = /* wgsl */`

struct ConvParams {
  seq_len : u32,     // L
  d_channels : u32,  // D (number of depthwise channels)
  kernel_size : u32, // K (typically 4)
  batch : u32,       // B
};

@group(0) @binding(0) var<uniform> params : ConvParams;
// x (B, L, D) – input
@group(0) @binding(1) var<storage, read> x : array<f32>;
// weight (D, K) – depthwise conv weights
@group(0) @binding(2) var<storage, read> weight : array<f32>;
// bias (D,) – optional bias (zeros if unused)
@group(0) @binding(3) var<storage, read> bias : array<f32>;
// y (B, L, D) – output
@group(0) @binding(4) var<storage, read_write> y : array<f32>;

// One invocation produces one output element y[b, t, d].
// Dispatch: (ceil(L/16), ceil(D/16), B)
@compute @workgroup_size(16, 16, 1)
fn conv1d_forward(
  @builtin(global_invocation_id) gid : vec3<u32>,
) {
  let L = params.seq_len;
  let D = params.d_channels;
  let K = params.kernel_size;
  let B = params.batch;

  let t = gid.x; // time position
  let d = gid.y; // channel
  let b = gid.z; // batch

  // Out-of-range invocations from the padded dispatch grid do nothing.
  if (t >= L || d >= D || b >= B) { return; }

  var acc: f32 = 0.0;

  // Causal: convolve over k = 0..K-1, reading position (t - k)
  for (var k: u32 = 0u; k < K; k = k + 1u) {
    let w_idx = d * K + k;
    let w_val = weight[w_idx];

    // t - k: use causal zero-padding for t < k. The guard also prevents
    // u32 underflow in (t - k).
    if (t >= k) {
      let src = b * L * D + (t - k) * D + d;
      acc = acc + w_val * x[src];
    }
    // else: zero-padding contributes 0
  }

  acc = acc + bias[d];

  let out = b * L * D + t * D + d;
  y[out] = acc;
}
`;
+ `;
65
+
66
+ // ---- Backward kernel for 1D convolution ----
67
+ export const CONV1D_BACKWARD_WGSL = /* wgsl */`
68
+
69
+ struct ConvParams {
70
+ seq_len : u32,
71
+ d_channels : u32,
72
+ kernel_size : u32,
73
+ batch : u32,
74
+ };
75
+
76
+ @group(0) @binding(0) var<uniform> params : ConvParams;
77
+ @group(0) @binding(1) var<storage, read> x : array<f32>;
78
+ @group(0) @binding(2) var<storage, read> weight : array<f32>;
79
+ @group(0) @binding(3) var<storage, read> dy : array<f32>;
80
+ @group(0) @binding(4) var<storage, read_write> dx : array<f32>;
81
+ @group(0) @binding(5) var<storage, read_write> dweight : array<f32>;
82
+ @group(0) @binding(6) var<storage, read_write> dbias : array<f32>;
83
+
84
+ // Dispatch: (ceil(L/16), ceil(D/16), B) – computes dx
85
+ @compute @workgroup_size(16, 16, 1)
86
+ fn conv1d_backward_dx(
87
+ @builtin(global_invocation_id) gid : vec3<u32>,
88
+ ) {
89
+ let L = params.seq_len;
90
+ let D = params.d_channels;
91
+ let K = params.kernel_size;
92
+ let B = params.batch;
93
+
94
+ let t = gid.x;
95
+ let d = gid.y;
96
+ let b = gid.z;
97
+
98
+ if (t >= L || d >= D || b >= B) { return; }
99
+
100
+ var grad: f32 = 0.0;
101
+
102
+ // dx[b, t, d] = sum_{k=0}^{K-1} dy[b, t+k, d] * weight[d, k]
103
+ for (var k: u32 = 0u; k < K; k = k + 1u) {
104
+ let tp = t + k;
105
+ if (tp < L) {
106
+ let dy_idx = b * L * D + tp * D + d;
107
+ let w_idx = d * K + k;
108
+ grad = grad + dy[dy_idx] * weight[w_idx];
109
+ }
110
+ }
111
+
112
+ let dx_idx = b * L * D + t * D + d;
113
+ dx[dx_idx] = grad;
114
+ }
115
+
116
+ // Dispatch: (K, D, 1) – accumulates dweight over (B, L)
117
+ @compute @workgroup_size(1, 1, 1)
118
+ fn conv1d_backward_dw(
119
+ @builtin(global_invocation_id) gid : vec3<u32>,
120
+ ) {
121
+ let L = params.seq_len;
122
+ let D = params.d_channels;
123
+ let K = params.kernel_size;
124
+ let B = params.batch;
125
+
126
+ let k = gid.x;
127
+ let d = gid.y;
128
+
129
+ if (k >= K || d >= D) { return; }
130
+
131
+ var grad_w: f32 = 0.0;
132
+ var grad_b: f32 = 0.0;
133
+
134
+ for (var b: u32 = 0u; b < B; b = b + 1u) {
135
+ for (var t: u32 = 0u; t < L; t = t + 1u) {
136
+ let dy_idx = b * L * D + t * D + d;
137
+ let dy_val = dy[dy_idx];
138
+ if (t >= k) {
139
+ let x_idx = b * L * D + (t - k) * D + d;
140
+ grad_w = grad_w + dy_val * x[x_idx];
141
+ }
142
+ if (k == 0u) {
143
+ grad_b = grad_b + dy_val;
144
+ }
145
+ }
146
+ }
147
+
148
+ dweight[d * K + k] = grad_w;
149
+ if (k == 0u) {
150
+ dbias[d] = grad_b;
151
+ }
152
+ }
153
+ `;