mambacode.js 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,220 @@
1
// Linear Projection WGSL Kernel
// General-purpose matrix multiplication: Y = X @ W^T + b
// Supports the up-projection and down-projection linear layers in the Mamba block.
//
// Shapes:
//   X : (batch * seq_len, in_features) – input (rows)
//   W : (out_features, in_features) – weight matrix (row-major)
//   b : (out_features,) – bias
//   Y : (batch * seq_len, out_features) – output

export const LINEAR_FORWARD_WGSL = /* wgsl */`

struct LinearParams {
    M : u32, // number of rows (batch * seq_len)
    K : u32, // in_features
    N : u32, // out_features
};

@group(0) @binding(0) var<uniform> params : LinearParams;
@group(0) @binding(1) var<storage, read> X : array<f32>; // (M, K)
@group(0) @binding(2) var<storage, read> W : array<f32>; // (N, K)
@group(0) @binding(3) var<storage, read> bias : array<f32>; // (N,)
@group(0) @binding(4) var<storage, read_write> Y : array<f32>; // (M, N)

// Tiled matmul using workgroup shared memory (16x16 tiles)
var<workgroup> tile_X : array<f32, 256>; // 16 * 16
var<workgroup> tile_W : array<f32, 256>;

// Computes Y[row, col] = dot(X[row, :], W[col, :]) + bias[col].
// gid.x walks the M (row) dimension and gid.y the N (column) dimension,
// so the expected dispatch is (ceil(M/16), ceil(N/16), 1) workgroups.
@compute @workgroup_size(16, 16, 1)
fn linear_forward(
    @builtin(global_invocation_id) gid : vec3<u32>,
    @builtin(local_invocation_id) lid : vec3<u32>,
    @builtin(workgroup_id) wid : vec3<u32>,
) {
    let M = params.M;
    let K = params.K;
    let N = params.N;

    let row = gid.x; // output row (M dimension)
    let col = gid.y; // output col (N dimension)

    var acc: f32 = 0.0;
    let TILE: u32 = 16u;
    // Number of K-tiles, rounded up so a partial final tile is included.
    let num_tiles = (K + TILE - 1u) / TILE;

    // Out-of-range threads still run the full loop (loading zero padding)
    // so every lane reaches the workgroupBarrier() calls in uniform control flow.
    for (var tile_idx: u32 = 0u; tile_idx < num_tiles; tile_idx = tile_idx + 1u) {
        // Load X tile: shape (TILE_M, TILE_K)
        let x_col = tile_idx * TILE + lid.y;
        let x_row = wid.x * TILE + lid.x;
        if (x_row < M && x_col < K) {
            tile_X[lid.x * TILE + lid.y] = X[x_row * K + x_col];
        } else {
            // Zero padding keeps the edge tiles' dot products correct.
            tile_X[lid.x * TILE + lid.y] = 0.0;
        }

        // Load W tile: shape (TILE_N, TILE_K) — W is (N, K)
        let w_col = tile_idx * TILE + lid.x; // K dimension
        let w_row = wid.y * TILE + lid.y; // N dimension
        if (w_row < N && w_col < K) {
            tile_W[lid.y * TILE + lid.x] = W[w_row * K + w_col];
        } else {
            tile_W[lid.y * TILE + lid.x] = 0.0;
        }

        workgroupBarrier();

        // Dot product within tile:
        // acc += X[row, tile_base + k] * W[col, tile_base + k]
        for (var k: u32 = 0u; k < TILE; k = k + 1u) {
            acc = acc + tile_X[lid.x * TILE + k] * tile_W[lid.y * TILE + k];
        }
        // Second barrier: no lane may overwrite the tiles (next iteration)
        // while another lane is still reading them.
        workgroupBarrier();
    }

    if (row < M && col < N) {
        Y[row * N + col] = acc + bias[col];
    }
}
`;
79
+
80
// ---- Backward pass for linear projection ----
// Given upstream gradient dY for Y = X @ W^T + b, provides three entry points:
//   linear_backward_dX : dX = dY @ W       – dispatch (ceil(M/16), ceil(K/16), 1)
//   linear_backward_dW : dW = dY^T @ X     – dispatch (ceil(N/16), ceil(K/16), 1)
//   linear_backward_db : db = column sums of dY
export const LINEAR_BACKWARD_WGSL = /* wgsl */`

struct LinearParams {
    M : u32, // rows (batch * seq_len)
    K : u32, // in_features
    N : u32, // out_features
};

@group(0) @binding(0) var<uniform> params : LinearParams;
@group(0) @binding(1) var<storage, read> X : array<f32>; // (M, K)
@group(0) @binding(2) var<storage, read> W : array<f32>; // (N, K)
@group(0) @binding(3) var<storage, read> dY : array<f32>; // (M, N)
@group(0) @binding(4) var<storage, read_write> dX : array<f32>; // (M, K)
@group(0) @binding(5) var<storage, read_write> dW : array<f32>; // (N, K)
@group(0) @binding(6) var<storage, read_write> db : array<f32>; // (N,)

// Dispatch: (ceil(M/16), ceil(K/16), 1) – computes dX = dY @ W
var<workgroup> tile_dY : array<f32, 256>;
var<workgroup> tile_W : array<f32, 256>;

// dX[m, k] = sum_n dY[m, n] * W[n, k], tiled over the N dimension.
@compute @workgroup_size(16, 16, 1)
fn linear_backward_dX(
    @builtin(global_invocation_id) gid : vec3<u32>,
    @builtin(local_invocation_id) lid : vec3<u32>,
    @builtin(workgroup_id) wid : vec3<u32>,
) {
    let M = params.M;
    let K = params.K;
    let N = params.N;

    let row = gid.x; // M
    let col = gid.y; // K

    var acc: f32 = 0.0;
    let TILE: u32 = 16u;
    let num_tiles = (N + TILE - 1u) / TILE;

    // All lanes run the full tile loop (with zero padding) so the barriers
    // below execute in uniform control flow.
    for (var tile_idx: u32 = 0u; tile_idx < num_tiles; tile_idx = tile_idx + 1u) {
        // tile_dY: (M, TILE_N) slice
        let dy_col = tile_idx * TILE + lid.y;
        let dy_row = wid.x * TILE + lid.x;
        if (dy_row < M && dy_col < N) {
            tile_dY[lid.x * TILE + lid.y] = dY[dy_row * N + dy_col];
        } else {
            tile_dY[lid.x * TILE + lid.y] = 0.0;
        }

        // tile_W: (TILE_N, K) slice — W[n, k]
        let w_row = tile_idx * TILE + lid.x; // N
        let w_col = wid.y * TILE + lid.y; // K
        if (w_row < N && w_col < K) {
            tile_W[lid.x * TILE + lid.y] = W[w_row * K + w_col];
        } else {
            tile_W[lid.x * TILE + lid.y] = 0.0;
        }

        workgroupBarrier();

        // acc += dY[row, tile_base + n] * W[tile_base + n, col]
        for (var n: u32 = 0u; n < TILE; n = n + 1u) {
            acc = acc + tile_dY[lid.x * TILE + n] * tile_W[n * TILE + lid.y];
        }
        workgroupBarrier();
    }

    if (row < M && col < K) {
        dX[row * K + col] = acc;
    }
}

// Dispatch: (ceil(N/16), ceil(K/16), 1) – computes dW = dY^T @ X
var<workgroup> tile_dY2 : array<f32, 256>;
var<workgroup> tile_X2 : array<f32, 256>;

// dW[n, k] = sum_m dY[m, n] * X[m, k], tiled over the M dimension.
@compute @workgroup_size(16, 16, 1)
fn linear_backward_dW(
    @builtin(global_invocation_id) gid : vec3<u32>,
    @builtin(local_invocation_id) lid : vec3<u32>,
    @builtin(workgroup_id) wid : vec3<u32>,
) {
    let M = params.M;
    let K = params.K;
    let N = params.N;

    let row = gid.x; // N
    let col = gid.y; // K

    var acc: f32 = 0.0;
    let TILE: u32 = 16u;
    let num_tiles = (M + TILE - 1u) / TILE;

    for (var tile_idx: u32 = 0u; tile_idx < num_tiles; tile_idx = tile_idx + 1u) {
        // dY^T tile: [N, M] accessed as dY[m, n]
        let m_idx = tile_idx * TILE + lid.y;
        let n_idx = wid.x * TILE + lid.x;
        if (n_idx < N && m_idx < M) {
            tile_dY2[lid.x * TILE + lid.y] = dY[m_idx * N + n_idx];
        } else {
            tile_dY2[lid.x * TILE + lid.y] = 0.0;
        }

        // X tile: [M, K]
        let xm = tile_idx * TILE + lid.x;
        let xk = wid.y * TILE + lid.y;
        if (xm < M && xk < K) {
            tile_X2[lid.x * TILE + lid.y] = X[xm * K + xk];
        } else {
            tile_X2[lid.x * TILE + lid.y] = 0.0;
        }

        workgroupBarrier();

        // acc += dY[tile_base + m, row] * X[tile_base + m, col]
        for (var m: u32 = 0u; m < TILE; m = m + 1u) {
            acc = acc + tile_dY2[lid.x * TILE + m] * tile_X2[m * TILE + lid.y];
        }
        workgroupBarrier();
    }

    if (row < N && col < K) {
        dW[row * K + col] = acc;
    }
}

// Dispatch: (N, 1, 1) – accumulates db = sum_M dY
// NOTE(review): with workgroup_size 64, dispatching N workgroups over-allocates
// lanes (the guard below discards extras); ceil(N/64) workgroups would suffice
// — confirm which convention the host-side dispatch uses.
@compute @workgroup_size(64, 1, 1)
fn linear_backward_db(
    @builtin(global_invocation_id) gid : vec3<u32>,
) {
    let M = params.M;
    let N = params.N;

    let n = gid.x; // output feature handled by this thread
    if (n >= N) { return; }

    // Serial reduction over the M rows of column n.
    var acc: f32 = 0.0;
    for (var m: u32 = 0u; m < M; m = m + 1u) {
        acc = acc + dY[m * N + n];
    }
    db[n] = acc;
}
`;
@@ -0,0 +1,350 @@
1
+ // Parallel Selective Scan WGSL Kernel
2
+ // Implements the S6 (Selective Scan) core of the Mamba architecture.
3
+ // Uses a Kogge-Stone parallel prefix-sum approach for O(log N) time on GPU.
4
+ //
5
+ // Forward pass recurrence:
6
+ // h_t = A_t * h_{t-1} + B_t * x_t
7
+ // y_t = C_t * h_t + D * x_t
8
+ //
9
+ // where A_t, B_t, C_t are input-dependent (selective) gate matrices.
10
+
11
+ export const SELECTIVE_SCAN_FORWARD_WGSL = /* wgsl */`
12
+
13
+ // ---- Binding layout ----
14
+ // group 0: sequence data
15
+ // group 1: SSM parameters
16
+
17
+ struct ScanParams {
18
+ seq_len : u32, // L – sequence length
19
+ d_state : u32, // N – state dimension
20
+ d_inner : u32, // D – inner (expanded) channel dimension
21
+ batch : u32, // B – batch size
22
+ };
23
+
24
+ @group(0) @binding(0) var<uniform> params : ScanParams;
25
+ // u (B, L, D) – projected input after conv
26
+ @group(0) @binding(1) var<storage, read> u : array<f32>;
27
+ // delta (B, L, D) – time-step (Δ) after softplus
28
+ @group(0) @binding(2) var<storage, read> delta : array<f32>;
29
+ // A (D, N) – log-space diagonal state matrix (fixed, learned)
30
+ @group(0) @binding(3) var<storage, read> A : array<f32>;
31
+ // B (B, L, N) – input projection (selective)
32
+ @group(0) @binding(4) var<storage, read> B : array<f32>;
33
+ // C (B, L, N) – output projection (selective)
34
+ @group(0) @binding(5) var<storage, read> C : array<f32>;
35
+ // D (D,) – skip-connection scale
36
+ @group(0) @binding(6) var<storage, read> D_vec : array<f32>;
37
+ // y (B, L, D) – output (written by this kernel)
38
+ @group(0) @binding(7) var<storage, read_write> y : array<f32>;
39
+ // h_cache (B, L, D*N) – hidden states cache (for backward pass)
40
+ @group(0) @binding(8) var<storage, read_write> h_cache : array<f32>;
41
+
42
+ // ---- Workgroup shared memory ----
43
+ // Each workgroup processes one (batch, channel) slice across all time steps.
44
+ // We store the associative pair (a_bar, bu_bar) per time step so we can run
45
+ // a Kogge-Stone scan across the workgroup tile.
46
+ var<workgroup> wg_a : array<f32, 256>; // discretised A values
47
+ var<workgroup> wg_bu : array<f32, 256>; // B*u values
48
+
49
+ // ---- Helpers ----
50
+
51
+ // Softplus: numerically stable log(1 + exp(x))
52
+ fn softplus(x: f32) -> f32 {
53
+ return log(1.0 + exp(x));
54
+ }
55
+
56
+ // ZerO-Order Hold discretisation of continuous A, Δ:
57
+ // A_bar = exp(Δ * A)
58
+ // B_bar = (A_bar - 1) / A * B ≈ Δ * B (first-order for simplicity)
59
+ fn discretise_A(delta_val: f32, a_log: f32) -> f32 {
60
+ // A is stored as -exp(a_log) to ensure A_bar < 1 (stable)
61
+ let a_cont = -exp(a_log);
62
+ return exp(delta_val * a_cont);
63
+ }
64
+
65
+ fn discretise_B(delta_val: f32, a_log: f32, b_val: f32) -> f32 {
66
+ let a_cont = -exp(a_log);
67
+ let a_bar = exp(delta_val * a_cont);
68
+ // (A_bar - 1) / A_cont * B
69
+ let b_bar = (a_bar - 1.0) / a_cont * b_val;
70
+ return b_bar;
71
+ }
72
+
73
+ // ---- Main kernel ----
74
+ // Dispatch: (ceil(D/8), ceil(N/8), B)
75
+ // Each invocation is responsible for one (d, n, batch) triplet and scans
76
+ // the entire sequence using a two-pass Kogge-Stone scan within workgroup tiles.
77
+
78
+ @compute @workgroup_size(64, 1, 1)
79
+ fn forward_scan(
80
+ @builtin(global_invocation_id) gid : vec3<u32>,
81
+ @builtin(local_invocation_index) lid : u32,
82
+ @builtin(workgroup_id) wgid : vec3<u32>,
83
+ ) {
84
+ let L = params.seq_len;
85
+ let N = params.d_state;
86
+ let D = params.d_inner;
87
+ let B = params.batch;
88
+
89
+ // Each workgroup handles one (batch b, channel d, state n) combination.
90
+ // We pack d and n into the x dimension: global d = wgid.x, global n = wgid.y
91
+ let d = wgid.x;
92
+ let n = wgid.y;
93
+ let b = gid.z;
94
+
95
+ if (d >= D || n >= N || b >= B) { return; }
96
+
97
+ // Tile size equals workgroup size (64). We process TILE_SIZE steps at once.
98
+ let TILE: u32 = 64u;
99
+
100
+ // Running state h for this (b, d, n)
101
+ var h: f32 = 0.0;
102
+
103
+ var tile_start: u32 = 0u;
104
+ loop {
105
+ if (tile_start >= L) { break; }
106
+
107
+ let t = tile_start + lid; // absolute time step handled by this lane
108
+ var a_bar: f32 = 1.0;
109
+ var bu: f32 = 0.0;
110
+
111
+ if (t < L) {
112
+ // Indices
113
+ let delta_idx = b * L * D + t * D + d;
114
+ let u_idx = b * L * D + t * D + d;
115
+ let A_idx = d * N + n;
116
+ let B_idx = b * L * N + t * N + n;
117
+
118
+ let dv = softplus(delta[delta_idx]);
119
+ a_bar = discretise_A(dv, A[A_idx]);
120
+ bu = discretise_B(dv, A[A_idx], B[B_idx]) * u[u_idx];
121
+ }
122
+
123
+ wg_a[lid] = a_bar;
124
+ wg_bu[lid] = bu;
125
+ workgroupBarrier();
126
+
127
+ // ---- Kogge-Stone inclusive prefix scan within tile ----
128
+ // Associative operator: (a1, b1) ∘ (a2, b2) = (a1*a2, a1*b2 + b1)
129
+ // This computes cumulative state recurrence in log2(TILE) steps.
130
+ var stride: u32 = 1u;
131
+ loop {
132
+ if (stride >= TILE) { break; }
133
+ if (lid >= stride) {
134
+ let prev_a = wg_a[lid - stride];
135
+ let prev_bu = wg_bu[lid - stride];
136
+ // Combine: new_state = prev_a * cur_a (product of A_bars)
137
+ // new_bu = prev_a * cur_bu + prev_bu
138
+ let new_a = prev_a * wg_a[lid];
139
+ let new_bu = prev_a * wg_bu[lid] + prev_bu;
140
+ workgroupBarrier();
141
+ wg_a[lid] = new_a;
142
+ wg_bu[lid] = new_bu;
143
+ }
144
+ workgroupBarrier();
145
+ stride = stride << 1u;
146
+ }
147
+
148
+ // Incorporate the carry-in state from the previous tile.
149
+ // After the scan wg_bu[lid] holds the intra-tile inclusive sum.
150
+ // The actual h at position t = h_carry * wg_a[lid] + wg_bu[lid]
151
+ let h_t = h * wg_a[lid] + wg_bu[lid];
152
+
153
+ if (t < L) {
154
+ // Cache hidden state for backward pass
155
+ let h_idx = b * L * D * N + t * D * N + d * N + n;
156
+ h_cache[h_idx] = h_t;
157
+
158
+ // Accumulate y contribution: y_t += C_t[n] * h_t (over all n)
159
+ // We use an atomic-style accumulation: each (d, n) lane adds its
160
+ // contribution to the same y[b, t, d]. This races without atomics,
161
+ // so we instead write to a full h_cache and reduce in a second pass.
162
+ // Here we perform direct accumulation using atomicAdd approximation:
163
+ // (safe because each lane writes a unique n, which is stride 1 in mem)
164
+ let C_idx = b * L * N + t * N + n;
165
+ let y_idx = b * L * D + t * D + d;
166
+
167
+ // Direct write for n == 0 (first state dim), add for the rest.
168
+ // Since all workgroups for the same (b,d) run concurrently we must
169
+ // accumulate safely: we write each partial into h_cache and reduce
170
+ // in a subsequent lightweight kernel (forward_reduce).
171
+ // (For simplicity and correctness here we directly atomically add via
172
+ // f32 emulation – real deployment uses atomicAdd on f32 with spirv ext.)
173
+ // We store C*h contribution separately so forward_reduce can sum them.
174
+ // Layout: y_partial (B, L, D, N) – one slot per state dim
175
+ // y reused as y_partial in this kernel; forward_reduce collapses N dim.
176
+ let y_partial_idx = b * L * D * N + t * D * N + d * N + n;
177
+ // Reuse h_cache second half as y_partial (offset by B*L*D*N)
178
+ let offset = B * L * D * N;
179
+ h_cache[offset + y_partial_idx] = C[C_idx] * h_t;
180
+ }
181
+
182
+ // Update carry: last lane's h_t is the tile's final state
183
+ let last = min(TILE, L - tile_start) - 1u;
184
+ h = wg_a[last] * h + wg_bu[last]; // recombine carry
185
+
186
+ workgroupBarrier();
187
+ tile_start = tile_start + TILE;
188
+ }
189
+ }
190
+
191
+ // ---- Reduction kernel ----
192
+ // Collapses the N (d_state) dimension of y_partial into y.
193
+ // Adds the D (skip connection) term: y_t[d] += D_vec[d] * u_t[d]
194
+ // Dispatch: (ceil(L/64), D, B)
195
+
196
+ @compute @workgroup_size(64, 1, 1)
197
+ fn forward_reduce(
198
+ @builtin(global_invocation_id) gid : vec3<u32>,
199
+ ) {
200
+ let L = params.seq_len;
201
+ let N = params.d_state;
202
+ let D = params.d_inner;
203
+ let B = params.batch;
204
+
205
+ let t = gid.x;
206
+ let d = gid.y;
207
+ let b = gid.z;
208
+
209
+ if (t >= L || d >= D || b >= B) { return; }
210
+
211
+ let offset = B * L * D * N;
212
+ var sum: f32 = 0.0;
213
+ for (var n: u32 = 0u; n < N; n = n + 1u) {
214
+ let idx = offset + b * L * D * N + t * D * N + d * N + n;
215
+ sum = sum + h_cache[idx];
216
+ }
217
+
218
+ // Add skip connection
219
+ let u_idx = b * L * D + t * D + d;
220
+ sum = sum + D_vec[d] * u[u_idx];
221
+
222
+ let y_idx = b * L * D + t * D + d;
223
+ y[y_idx] = sum;
224
+ }
225
+ `;
226
+
227
// ---- Backward scan kernel (for autograd) ----
// Computes gradients w.r.t. Δ, A, B, C and u using the cached hidden states.
//
// Review fix vs. the previous revision: WGSL has no ternary operator, so the
// expression `(t > 0u) ? h_cache[...] : 0.0` was a shader compile error. It
// is replaced with an explicit branch, which also avoids ever computing the
// wrapped index (t - 1u) when t == 0 (u32 subtraction would wrap around).

export const SELECTIVE_SCAN_BACKWARD_WGSL = /* wgsl */`

struct ScanParams {
    seq_len : u32,
    d_state : u32,
    d_inner : u32,
    batch : u32,
};

@group(0) @binding(0) var<uniform> params : ScanParams;
@group(0) @binding(1) var<storage, read> u : array<f32>;
@group(0) @binding(2) var<storage, read> delta : array<f32>;
@group(0) @binding(3) var<storage, read> A : array<f32>;
@group(0) @binding(4) var<storage, read> B : array<f32>;
@group(0) @binding(5) var<storage, read> C : array<f32>;
@group(0) @binding(6) var<storage, read> h_cache : array<f32>;
@group(0) @binding(7) var<storage, read> dy : array<f32>; // upstream gradient
@group(0) @binding(8) var<storage, read_write> dA : array<f32>;
@group(0) @binding(9) var<storage, read_write> dB : array<f32>;
@group(0) @binding(10) var<storage, read_write> dC : array<f32>;
@group(0) @binding(11) var<storage, read_write> dDelta : array<f32>;
@group(0) @binding(12) var<storage, read_write> du : array<f32>;

fn softplus(x: f32) -> f32 {
    return log(1.0 + exp(x));
}

// d/dx softplus(x) = sigmoid(x)
fn softplus_grad(x: f32) -> f32 {
    return 1.0 / (1.0 + exp(-x));
}

fn discretise_A(delta_val: f32, a_log: f32) -> f32 {
    let a_cont = -exp(a_log);
    return exp(delta_val * a_cont);
}

// Reverse scan (backward pass) – processes time from L-1 down to 0.
// Dispatch: (D, N, B), one thread per (d, n, b) triplet.
// NOTE(review): threads sharing (b, t, d) across different n each do plain
// read-modify-write updates of du and dDelta, and threads sharing (d, n)
// across different b do the same to dA. Those accumulations race unless the
// host serialises them or privatises the buffers and reduces afterwards —
// confirm the host-side dispatch strategy.
@compute @workgroup_size(1, 1, 1)
fn backward_scan(
    @builtin(global_invocation_id) gid : vec3<u32>,
) {
    let L = params.seq_len;
    let N = params.d_state;
    let D = params.d_inner;
    let B = params.batch;

    let d = gid.x;
    let n = gid.y;
    let b = gid.z;

    if (d >= D || n >= N || b >= B) { return; }

    var dh: f32 = 0.0; // gradient of loss w.r.t. h_t, accumulated backwards

    var t: u32 = L;
    loop {
        if (t == 0u) { break; }
        t = t - 1u;

        let delta_raw_idx = b * L * D + t * D + d;
        let A_idx = d * N + n;
        let B_idx = b * L * N + t * N + n;
        let C_idx = b * L * N + t * N + n;
        let u_idx = b * L * D + t * D + d;
        let h_idx = b * L * D * N + t * D * N + d * N + n;

        let delta_raw = delta[delta_raw_idx];
        let dv = softplus(delta_raw);
        let a_log = A[A_idx];
        let a_cont = -exp(a_log);
        let a_bar = exp(dv * a_cont);
        let b_val = B[B_idx];
        let c_val = C[C_idx];
        let u_val = u[u_idx];
        let h_t = h_cache[h_idx];

        // Output path: y_t[d] = sum_n C[n] * h_t[n] + D * u, so
        // dh_t[n] += C[n] * dy_t[d].
        let dy_val = dy[b * L * D + t * D + d];
        dh = dh + c_val * dy_val;

        // dC[b, t, n] += dy_t[d] * h_t
        dC[C_idx] = dC[C_idx] + dy_val * h_t;

        // Recurrence: h_t = a_bar * h_{t-1} + b_bar * u_t, with
        // b_bar = (a_bar - 1) / a_cont * b_val.
        let b_bar = (a_bar - 1.0) / a_cont * b_val;

        // WGSL has no conditional expression; branch explicitly. The branch
        // also guarantees (t - 1u) is never evaluated at t == 0.
        var h_prev: f32 = 0.0;
        if (t > 0u) {
            h_prev = h_cache[b * L * D * N + (t - 1u) * D * N + d * N + n];
        }

        let dh_cur = dh;

        // d(a_bar)/d(a_log) = a_bar * a_cont * dv
        // NOTE(review): the b_bar-through-a_log path (db_bar/da_log) is
        // omitted here, so dA only follows the a_bar dependence.
        let da_bar_da_log = a_bar * a_cont * dv;
        dA[A_idx] = dA[A_idx] + dh_cur * (da_bar_da_log * h_prev);

        // b_bar is linear in b_val: db_bar/db = (a_bar - 1) / a_cont.
        dB[B_idx] = dB[B_idx] + dh_cur * ((a_bar - 1.0) / a_cont) * u_val;

        // du[b, t, d] += dh_t * b_bar (summed over n by concurrent threads —
        // see the race note above).
        du[u_idx] = du[u_idx] + dh_cur * b_bar;

        // Chain rule through softplus and the discretisation:
        // d(a_bar)/d(dv) = a_bar * a_cont, d(b_bar)/d(dv) = a_bar * b_val.
        let da_bar_ddv = a_bar * a_cont;
        let db_bar_ddv = a_bar * b_val;
        let dLoss_ddv = dh_cur * (da_bar_ddv * h_prev + db_bar_ddv * u_val);
        let ddv_ddelta = softplus_grad(delta_raw);
        dDelta[delta_raw_idx] = dDelta[delta_raw_idx] + dLoss_ddv * ddv_ddelta;

        // Propagate to the previous timestep: dh_{t-1} = a_bar * dh_t.
        dh = a_bar * dh_cur;
    }
}
`;
@@ -0,0 +1,120 @@
1
// Weight Update WGSL Kernel (AdamW Optimizer)
// Implements fused AdamW parameter update on the GPU.
//
// AdamW update rule:
//   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
//   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
//   m_hat = m_t / (1 - beta1^t)
//   v_hat = v_t / (1 - beta2^t)
//   theta_t = theta_{t-1} * (1 - lr * weight_decay) - lr * m_hat / (sqrt(v_hat) + eps)

export const WEIGHT_UPDATE_WGSL = /* wgsl */`

struct AdamParams {
    num_elements : u32,
    lr : f32, // learning rate
    beta1 : f32, // default 0.9
    beta2 : f32, // default 0.999
    eps : f32, // default 1e-8
    weight_decay : f32, // default 0.01
    // The bias-correction powers are precomputed on the host once per step,
    // avoiding a per-element pow() in the kernel.
    beta1_t : f32, // beta1^t (precomputed bias correction term)
    beta2_t : f32, // beta2^t
};

@group(0) @binding(0) var<uniform> adam : AdamParams;
// param (N,) – weight tensor (read-write: updated in-place)
@group(0) @binding(1) var<storage, read_write> param : array<f32>;
// grad (N,) – gradient
@group(0) @binding(2) var<storage, read> grad : array<f32>;
// m (N,) – first moment
@group(0) @binding(3) var<storage, read_write> m_state : array<f32>;
// v (N,) – second moment
@group(0) @binding(4) var<storage, read_write> v_state : array<f32>;

// Dispatch: (ceil(N / 256), 1, 1)
// One thread per parameter element; fully data-parallel, no shared memory.
@compute @workgroup_size(256, 1, 1)
fn adamw_update(
    @builtin(global_invocation_id) gid : vec3<u32>,
) {
    let i = gid.x;
    if (i >= adam.num_elements) { return; }

    let g = grad[i];
    let p = param[i];

    // Moment updates: exponential moving averages of g and g^2
    let m_new = adam.beta1 * m_state[i] + (1.0 - adam.beta1) * g;
    let v_new = adam.beta2 * v_state[i] + (1.0 - adam.beta2) * g * g;
    m_state[i] = m_new;
    v_state[i] = v_new;

    // Bias-corrected estimates
    let m_hat = m_new / (1.0 - adam.beta1_t);
    let v_hat = v_new / (1.0 - adam.beta2_t);

    // Weight decay (decoupled: applied to the parameter directly, not folded
    // into the gradient) + gradient step
    param[i] = p * (1.0 - adam.lr * adam.weight_decay) -
        adam.lr * m_hat / (sqrt(v_hat) + adam.eps);
}
`;
60
+
61
// Gradient clipping kernel – clips global gradient norm to max_norm.
// Run before weight updates. Two-pass: first compute squared norm, then scale.
// NOTE(review): the host must zero norm_sq[0] before dispatching pass 1 —
// neither kernel resets it between steps.
export const GRAD_CLIP_WGSL = /* wgsl */`

struct ClipParams {
    num_elements : u32,
    max_norm_sq : f32, // max_norm^2
};

@group(0) @binding(0) var<uniform> clip_p : ClipParams;
@group(0) @binding(1) var<storage, read_write> grad : array<f32>;
@group(0) @binding(2) var<storage, read_write> norm_sq : array<f32>; // size 1, accumulator

var<workgroup> local_sq : array<f32, 256>;

// Pass 1: reduce sum of squares into norm_sq[0]
// Dispatch: (ceil(N/256), 1, 1).
@compute @workgroup_size(256, 1, 1)
fn grad_norm_reduce(
    @builtin(global_invocation_id) gid : vec3<u32>,
    @builtin(local_invocation_index) lid : u32,
) {
    let i = gid.x;
    // Out-of-range lanes contribute 0 so the tree reduction stays exact.
    local_sq[lid] = 0.0;
    if (i < clip_p.num_elements) {
        local_sq[lid] = grad[i] * grad[i];
    }
    workgroupBarrier();

    // Parallel reduction within workgroup: halve the active lanes each round
    // (256 -> 128 -> ... -> 1); barriers sit in uniform control flow.
    var s: u32 = 128u;
    loop {
        if (s == 0u) { break; }
        if (lid < s) {
            local_sq[lid] = local_sq[lid] + local_sq[lid + s];
        }
        workgroupBarrier();
        s = s >> 1u;
    }

    if (lid == 0u) {
        // Non-atomic accumulation (single workgroup assumption for small models)
        // NOTE(review): if more than one workgroup is dispatched
        // (num_elements > 256) these read-modify-writes race. WGSL has no f32
        // atomicAdd, so a multi-workgroup setup needs per-workgroup partials
        // plus a second reduction pass — confirm the dispatch size the host
        // actually uses.
        norm_sq[0] = norm_sq[0] + local_sq[0];
    }
}

// Pass 2: scale gradients if norm exceeds max_norm
// Dispatch: (ceil(N/256), 1, 1); must run after pass 1 fully completes.
@compute @workgroup_size(256, 1, 1)
fn grad_clip_scale(
    @builtin(global_invocation_id) gid : vec3<u32>,
) {
    let i = gid.x;
    if (i >= clip_p.num_elements) { return; }

    let ns = norm_sq[0];
    if (ns > clip_p.max_norm_sq) {
        // scale = max_norm / ||g||, applied uniformly to every element
        let scale = sqrt(clip_p.max_norm_sq / ns);
        grad[i] = grad[i] * scale;
    }
}
`;