@seanhogg/builderforce-memory-engine 2026.6.18

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113) hide show
  1. package/LICENSE +21 -0
  2. package/README.md +393 -0
  3. package/dist/index.d.ts +32 -0
  4. package/dist/index.d.ts.map +1 -0
  5. package/dist/index.js +40 -0
  6. package/dist/index.js.map +1 -0
  7. package/dist/kernels/activations.d.ts +5 -0
  8. package/dist/kernels/activations.d.ts.map +1 -0
  9. package/dist/kernels/activations.js +171 -0
  10. package/dist/kernels/activations.js.map +1 -0
  11. package/dist/kernels/attention.d.ts +19 -0
  12. package/dist/kernels/attention.d.ts.map +1 -0
  13. package/dist/kernels/attention.js +263 -0
  14. package/dist/kernels/attention.js.map +1 -0
  15. package/dist/kernels/complex_ssd.d.ts +33 -0
  16. package/dist/kernels/complex_ssd.d.ts.map +1 -0
  17. package/dist/kernels/complex_ssd.js +305 -0
  18. package/dist/kernels/complex_ssd.js.map +1 -0
  19. package/dist/kernels/conv1d.d.ts +3 -0
  20. package/dist/kernels/conv1d.d.ts.map +1 -0
  21. package/dist/kernels/conv1d.js +158 -0
  22. package/dist/kernels/conv1d.js.map +1 -0
  23. package/dist/kernels/linear_projection.d.ts +3 -0
  24. package/dist/kernels/linear_projection.d.ts.map +1 -0
  25. package/dist/kernels/linear_projection.js +219 -0
  26. package/dist/kernels/linear_projection.js.map +1 -0
  27. package/dist/kernels/selective_scan.d.ts +3 -0
  28. package/dist/kernels/selective_scan.d.ts.map +1 -0
  29. package/dist/kernels/selective_scan.js +348 -0
  30. package/dist/kernels/selective_scan.js.map +1 -0
  31. package/dist/kernels/ssd.d.ts +29 -0
  32. package/dist/kernels/ssd.d.ts.map +1 -0
  33. package/dist/kernels/ssd.js +276 -0
  34. package/dist/kernels/ssd.js.map +1 -0
  35. package/dist/kernels/weight_update.d.ts +3 -0
  36. package/dist/kernels/weight_update.d.ts.map +1 -0
  37. package/dist/kernels/weight_update.js +119 -0
  38. package/dist/kernels/weight_update.js.map +1 -0
  39. package/dist/model/attention_block.d.ts +48 -0
  40. package/dist/model/attention_block.d.ts.map +1 -0
  41. package/dist/model/attention_block.js +262 -0
  42. package/dist/model/attention_block.js.map +1 -0
  43. package/dist/model/mamba1_block.d.ts +70 -0
  44. package/dist/model/mamba1_block.d.ts.map +1 -0
  45. package/dist/model/mamba1_block.js +333 -0
  46. package/dist/model/mamba1_block.js.map +1 -0
  47. package/dist/model/mamba2_block.d.ts +44 -0
  48. package/dist/model/mamba2_block.d.ts.map +1 -0
  49. package/dist/model/mamba2_block.js +252 -0
  50. package/dist/model/mamba2_block.js.map +1 -0
  51. package/dist/model/mamba3_block.d.ts +51 -0
  52. package/dist/model/mamba3_block.d.ts.map +1 -0
  53. package/dist/model/mamba3_block.js +270 -0
  54. package/dist/model/mamba3_block.js.map +1 -0
  55. package/dist/model/mamba_block.d.ts +64 -0
  56. package/dist/model/mamba_block.d.ts.map +1 -0
  57. package/dist/model/mamba_block.js +303 -0
  58. package/dist/model/mamba_block.js.map +1 -0
  59. package/dist/model/mamba_model.d.ts +140 -0
  60. package/dist/model/mamba_model.d.ts.map +1 -0
  61. package/dist/model/mamba_model.js +527 -0
  62. package/dist/model/mamba_model.js.map +1 -0
  63. package/dist/model/sequence_layer.d.ts +25 -0
  64. package/dist/model/sequence_layer.d.ts.map +1 -0
  65. package/dist/model/sequence_layer.js +8 -0
  66. package/dist/model/sequence_layer.js.map +1 -0
  67. package/dist/tokenizer/bpe.d.ts +29 -0
  68. package/dist/tokenizer/bpe.d.ts.map +1 -0
  69. package/dist/tokenizer/bpe.js +164 -0
  70. package/dist/tokenizer/bpe.js.map +1 -0
  71. package/dist/training/autograd.d.ts +27 -0
  72. package/dist/training/autograd.d.ts.map +1 -0
  73. package/dist/training/autograd.js +120 -0
  74. package/dist/training/autograd.js.map +1 -0
  75. package/dist/training/trainer.d.ts +36 -0
  76. package/dist/training/trainer.d.ts.map +1 -0
  77. package/dist/training/trainer.js +183 -0
  78. package/dist/training/trainer.js.map +1 -0
  79. package/dist/utils/gpu_utils.d.ts +21 -0
  80. package/dist/utils/gpu_utils.d.ts.map +1 -0
  81. package/dist/utils/gpu_utils.js +111 -0
  82. package/dist/utils/gpu_utils.js.map +1 -0
  83. package/dist/utils/quantization.d.ts +26 -0
  84. package/dist/utils/quantization.d.ts.map +1 -0
  85. package/dist/utils/quantization.js +116 -0
  86. package/dist/utils/quantization.js.map +1 -0
  87. package/dist/utils/rng.d.ts +36 -0
  88. package/dist/utils/rng.d.ts.map +1 -0
  89. package/dist/utils/rng.js +61 -0
  90. package/dist/utils/rng.js.map +1 -0
  91. package/package.json +99 -0
  92. package/src/index.ts +114 -0
  93. package/src/kernels/activations.ts +174 -0
  94. package/src/kernels/attention.ts +268 -0
  95. package/src/kernels/complex_ssd.ts +307 -0
  96. package/src/kernels/conv1d.ts +159 -0
  97. package/src/kernels/linear_projection.ts +220 -0
  98. package/src/kernels/selective_scan.ts +350 -0
  99. package/src/kernels/ssd.ts +278 -0
  100. package/src/kernels/weight_update.ts +120 -0
  101. package/src/model/attention_block.ts +344 -0
  102. package/src/model/mamba1_block.ts +437 -0
  103. package/src/model/mamba2_block.ts +319 -0
  104. package/src/model/mamba3_block.ts +335 -0
  105. package/src/model/mamba_block.ts +401 -0
  106. package/src/model/mamba_model.ts +678 -0
  107. package/src/model/sequence_layer.ts +29 -0
  108. package/src/tokenizer/bpe.ts +186 -0
  109. package/src/training/autograd.ts +135 -0
  110. package/src/training/trainer.ts +309 -0
  111. package/src/utils/gpu_utils.ts +147 -0
  112. package/src/utils/quantization.ts +154 -0
  113. package/src/utils/rng.ts +65 -0
@@ -0,0 +1,307 @@
1
+ /**
2
+ * complex_ssd.ts – Complex-valued SSD kernels for Mamba-3.
3
+ *
4
+ * Three targeted improvements over Mamba-2 SSD:
5
+ *
6
+ * 1. Complex-valued states
7
+ * h ∈ ℂ^(N/2) stored as interleaved (real, imag) f32 pairs.
8
+ * A ∈ ℂ encoded as A_log[H, 2] = [log|A|, arg(A)].
9
+ *
10
+ * 2. Exponential-Trapezoidal (ET) discretisation
11
+ * A_bar = exp(Δ·log|A| + i·Δ·arg(A)) = A^Δ (as implemented; note this is not exp(Δ·A))
12
+ * B_bar = (A_bar − 1) · A⁻¹ · B (complex product; the exact ZOH B_bar only when A_bar = exp(Δ·A))
13
+ *
14
+ * 3. MIMO recurrence (G groups of G inputs/outputs per head)
15
+ * Implemented here with G=1 (SISO) as the default; G>1 is a future
16
+ * extension that enlarges the B/C projections.
17
+ *
18
+ * Buffer layout:
19
+ * x : [B, L, D_inner] real-valued
20
+ * B_proj : [B, L, n_groups, N*2] interleaved complex (re,im)
21
+ * C_proj : [B, L, n_groups, N*2]
22
+ * dt : [B, L, H] real-valued
23
+ * A_log : [H, 2] [log|A|, arg(A)] per head
24
+ * dt_bias : [H]
25
+ * D_vec : [H]
26
+ * out : [B, L, D_inner] real-valued (Re(C·h))
27
+ * state_carry: [n_chunks+1, B, H, N*2, d_head] complex states
28
+ *
29
+ * Dispatch: (n_chunks, H, B)
30
+ */
31
+
32
/**
 * WGSL source for the complex-valued SSD forward pass (entry point
 * `complex_ssd_forward`).
 *
 * Per (head, batch) it runs the recurrence h_t = A_bar * h_{t-1} + B_bar * x_t
 * over every chunk in order. It writes y_t = Re(C_t * h_t) + D * x_t to out_buf
 * and snapshots the state after chunk c into state_carry slot c + 1 (slot 0 is
 * the carry-in).
 *
 * Dispatch: (n_chunks, H, B) as before. Only the x == 0 invocations do work,
 * because chunk c depends on the state produced by chunk c - 1 and workgroups
 * in one dispatch have no ordering guarantee.
 */
export const COMPLEX_SSD_FORWARD_WGSL: string = /* wgsl */`
struct CssdParams {
  seq_len   : u32,
  d_inner   : u32,
  n_heads   : u32,
  d_head    : u32,
  n_groups  : u32,
  n_complex : u32, // N/2 - number of complex state components
  chunk_len : u32,
  n_chunks  : u32, // chunks per sequence; state_carry holds n_chunks + 1 snapshots
  batch     : u32,
};

@group(0) @binding(0) var<uniform> params : CssdParams;
@group(0) @binding(1) var<storage, read> x_in : array<f32>;
@group(0) @binding(2) var<storage, read> B_proj : array<f32>; // complex: N_c*2 per token
@group(0) @binding(3) var<storage, read> C_proj : array<f32>;
@group(0) @binding(4) var<storage, read> dt_in : array<f32>;
@group(0) @binding(5) var<storage, read> A_log : array<f32>; // [H, 2] = [log|A|, arg(A)]
@group(0) @binding(6) var<storage, read> dt_bias : array<f32>;
@group(0) @binding(7) var<storage, read> D_vec : array<f32>;
@group(0) @binding(8) var<storage, read_write> out_buf : array<f32>;
@group(0) @binding(9) var<storage, read_write> state_carry : array<f32>; // complex states

// Numerically stable softplus: log(1 + e^v) = max(v, 0) + log(1 + e^-|v|).
// The naive log(1.0 + exp(v)) overflows to +inf in f32 once v exceeds ~88.
fn softplus(v: f32) -> f32 { return max(v, 0.0) + log(1.0 + exp(-abs(v))); }

// Complex multiply: (ar + i*ai) * (br + i*bi)
fn cmul_re(ar: f32, ai: f32, br: f32, bi: f32) -> f32 { return ar*br - ai*bi; }
fn cmul_im(ar: f32, ai: f32, br: f32, bi: f32) -> f32 { return ar*bi + ai*br; }

// Complex exp: exp(x + i*y) = exp(x) * (cos(y) + i*sin(y))
fn cexp_re(x: f32, y: f32) -> f32 { return exp(x) * cos(y); }
fn cexp_im(x: f32, y: f32) -> f32 { return exp(x) * sin(y); }

@compute @workgroup_size(1, 1, 1)
fn complex_ssd_forward(@builtin(global_invocation_id) gid: vec3<u32>) {
  let head_id  = gid.y;
  let batch_id = gid.z;

  let L  = params.seq_len;
  let D  = params.d_inner;
  let H  = params.n_heads;
  let dh = params.d_head;
  let G  = params.n_groups;
  let Nc = params.n_complex; // complex state count
  let N2 = Nc * 2u;          // float count (re/im rows)
  let CL = params.chunk_len;
  let B  = params.batch;

  // Chunk c starts from the state that chunk c-1 writes into slot c. Running
  // chunks as independent workgroups of one dispatch is a data race, so the
  // x == 0 invocation walks all chunks sequentially and the rest exit.
  if (gid.x != 0u || head_id >= H || batch_id >= B) { return; }

  let group_id = head_id * G / H;

  // A = exp(log_mag) * exp(i*phase)
  let log_mag = A_log[head_id * 2u + 0u];
  let phase   = A_log[head_id * 2u + 1u];
  let db      = dt_bias[head_id];
  let d_skip  = D_vec[head_id];

  // A^-1 = exp(-log_mag) * exp(-i*phase). It does not depend on t.
  let inv_re = cexp_re(-log_mag, -phase);
  let inv_im = cexp_im(-log_mag, -phase);

  // state_carry: [n_chunks+1, B, H, N2, dh]; row 2*nc holds Re, row 2*nc+1 holds Im.
  let state_len    = N2 * dh;
  let state_stride = B * H * state_len;
  let head_offset  = batch_id * H * state_len + head_id * state_len;

  for (var c: u32 = 0u; c < params.n_chunks; c = c + 1u) {
    let base_in  = c * state_stride + head_offset;
    let base_out = (c + 1u) * state_stride + head_offset;

    // Seed slot c+1 with the carry-in, then evolve it in place.
    for (var s: u32 = 0u; s < state_len; s = s + 1u) {
      state_carry[base_out + s] = state_carry[base_in + s];
    }

    let t_start = c * CL;
    let t_end   = min(t_start + CL, L);

    for (var t: u32 = t_start; t < t_end; t = t + 1u) {
      let dt_val = softplus(dt_in[batch_id * L * H + t * H + head_id] + db);

      // NOTE(review): this is exp(dt*log_mag + i*dt*phase) = A^dt, not
      // exp(dt*A) with A = exp(log_mag)*exp(i*phase). The B_bar factor below
      // uses the true A^-1, so the pair is the exact ZOH discretisation only if
      // A_bar becomes exp(dt*A). Any change here must be mirrored in the
      // backward kernel's dt-derivative.
      let a_bar_re = cexp_re(dt_val * log_mag, dt_val * phase);
      let a_bar_im = cexp_im(dt_val * log_mag, dt_val * phase);

      // B_bar scalar factor (A_bar - 1) * A^-1, applied to every B_proj element.
      let bbar_re = cmul_re(a_bar_re - 1.0, a_bar_im, inv_re, inv_im);
      let bbar_im = cmul_im(a_bar_re - 1.0, a_bar_im, inv_re, inv_im);

      let x_base  = batch_id * L * D + t * D + head_id * dh;
      // B_proj / C_proj: [B, L, G, N*2], interleaved re/im
      let bc_base = batch_id * L * G * N2 + t * G * N2 + group_id * N2;

      for (var i: u32 = 0u; i < dh; i = i + 1u) {
        let x_val = x_in[x_base + i];
        var y_re: f32 = 0.0;

        for (var nc: u32 = 0u; nc < Nc; nc = nc + 1u) {
          let b_re = B_proj[bc_base + nc * 2u + 0u];
          let b_im = B_proj[bc_base + nc * 2u + 1u];
          let c_re = C_proj[bc_base + nc * 2u + 0u];
          let c_im = C_proj[bc_base + nc * 2u + 1u];

          // B_bar * x (complex scaled by the real input)
          let inp_re = cmul_re(bbar_re, bbar_im, b_re, b_im) * x_val;
          let inp_im = cmul_im(bbar_re, bbar_im, b_re, b_im) * x_val;

          let s_re_idx = base_out + (nc * 2u) * dh + i;
          let s_im_idx = s_re_idx + dh;

          // h_t = A_bar * h_{t-1} + B_bar * x
          let h_prev_re = state_carry[s_re_idx];
          let h_prev_im = state_carry[s_im_idx];
          let h_new_re = cmul_re(a_bar_re, a_bar_im, h_prev_re, h_prev_im) + inp_re;
          let h_new_im = cmul_im(a_bar_re, a_bar_im, h_prev_re, h_prev_im) + inp_im;
          state_carry[s_re_idx] = h_new_re;
          state_carry[s_im_idx] = h_new_im;

          // y += Re(C * h) = c_re*h_re - c_im*h_im. This is the form the
          // backward kernel differentiates (dC, dh).
          y_re = y_re + cmul_re(c_re, c_im, h_new_re, h_new_im);
        }

        out_buf[x_base + i] = y_re + d_skip * x_val;
      }
    }
  }
}
`;
178
+
179
+ // ── Backward ──────────────────────────────────────────────────────────────────
180
+
181
/**
 * WGSL source for the (approximate) complex-valued SSD backward pass
 * (entry point `complex_ssd_backward`).
 *
 * Dispatch: (n_chunks, H, B), one invocation per (chunk, head, batch). All
 * outputs are accumulated with `+=`; presumably the host zeroes them first
 * (NOTE(review): confirm).
 */
export const COMPLEX_SSD_BACKWARD_WGSL: string = /* wgsl */`
struct CssdParams {
  seq_len   : u32,
  d_inner   : u32,
  n_heads   : u32,
  d_head    : u32,
  n_groups  : u32,
  n_complex : u32,
  chunk_len : u32,
  n_chunks  : u32,
  batch     : u32,
};

@group(0) @binding(0) var<uniform> params : CssdParams;
@group(0) @binding(1) var<storage, read> x_in : array<f32>;
@group(0) @binding(2) var<storage, read> B_proj : array<f32>;
@group(0) @binding(3) var<storage, read> C_proj : array<f32>;
@group(0) @binding(4) var<storage, read> dt_in : array<f32>;
@group(0) @binding(5) var<storage, read> A_log : array<f32>;
@group(0) @binding(6) var<storage, read> dt_bias : array<f32>;
@group(0) @binding(7) var<storage, read> state_carry : array<f32>;
@group(0) @binding(8) var<storage, read> dy : array<f32>;
@group(0) @binding(9) var<storage, read_write> dx : array<f32>;
@group(0) @binding(10) var<storage, read_write> dB : array<f32>;
@group(0) @binding(11) var<storage, read_write> dC : array<f32>;
@group(0) @binding(12) var<storage, read_write> ddt : array<f32>;
@group(0) @binding(13) var<storage, read_write> dA_log : array<f32>;
@group(0) @binding(14) var<storage, read_write> dD_vec : array<f32>;

// Numerically stable softplus (the naive log(1 + exp(v)) overflows for v > ~88).
fn softplus(v: f32) -> f32 { return max(v, 0.0) + log(1.0 + exp(-abs(v))); }
// d/dv softplus(v) = sigmoid(v)
fn d_softplus(v: f32) -> f32 { return 1.0 / (1.0 + exp(-v)); }
fn cmul_re(ar: f32, ai: f32, br: f32, bi: f32) -> f32 { return ar*br - ai*bi; }
fn cmul_im(ar: f32, ai: f32, br: f32, bi: f32) -> f32 { return ar*bi + ai*br; }
fn cexp_re(x: f32, y: f32) -> f32 { return exp(x) * cos(y); }
fn cexp_im(x: f32, y: f32) -> f32 { return exp(x) * sin(y); }

// Known approximations and hazards (NOTE(review)), kept as-is because fixing
// them needs per-timestep states, extra bindings, or atomics:
//   * h_t is read from the chunk's final snapshot (slot chunk_id+1) and
//     h_{t-1} from its carry-in snapshot (slot chunk_id). Both are exact only
//     when chunk_len == 1.
//   * dL/dh comes from the current output only; it is not propagated back
//     through A_bar to earlier timesteps.
//   * The B_bar factor is ignored in dB and dx, and B_bar's dt-dependence is
//     ignored in ddt.
//   * dA_log is never written.
//   * The skip path adds dy to dx, but the forward output uses D * x, and
//     D_vec is not bound here.
//   * dD_vec (shared by all chunks/batches of a head) and dB / dC (shared by
//     heads in one group) are read-modify-written by concurrent invocations
//     without atomics.
@compute @workgroup_size(1, 1, 1)
fn complex_ssd_backward(@builtin(global_invocation_id) gid: vec3<u32>) {
  let chunk_id = gid.x;
  let head_id  = gid.y;
  let batch_id = gid.z;

  let L  = params.seq_len;
  let D  = params.d_inner;
  let H  = params.n_heads;
  let dh = params.d_head;
  let G  = params.n_groups;
  let Nc = params.n_complex;
  let N2 = Nc * 2u;
  let CL = params.chunk_len;
  let B  = params.batch;

  let t_start = chunk_id * CL;
  // A chunk starting at or past L has no timesteps. Without this guard,
  // t_end - t_start below would wrap around as u32.
  if (t_start >= L || head_id >= H || batch_id >= B) { return; }
  let t_end = min(t_start + CL, L);
  let group_id = head_id * G / H;

  let log_mag = A_log[head_id * 2u + 0u];
  let phase   = A_log[head_id * 2u + 1u];
  let db      = dt_bias[head_id];

  let state_len    = N2 * dh;
  let state_stride = B * H * state_len;
  let head_offset  = batch_id * H * state_len + head_id * state_len;
  let state_base   = (chunk_id + 1u) * state_stride + head_offset; // chunk's final state
  let state_prev   = chunk_id * state_stride + head_offset;        // chunk's carry-in state

  for (var t_rev: u32 = 0u; t_rev < t_end - t_start; t_rev = t_rev + 1u) {
    let t = t_end - 1u - t_rev;

    let dt_idx = batch_id * L * H + t * H + head_id;
    let dt_raw = dt_in[dt_idx] + db;
    let dt_val = softplus(dt_raw);
    let dsp    = d_softplus(dt_raw);

    // Same A_bar as the forward: A_bar = exp(dt * z), z = log_mag + i*phase.
    let a_bar_re = cexp_re(dt_val * log_mag, dt_val * phase);
    let a_bar_im = cexp_im(dt_val * log_mag, dt_val * phase);
    // dA_bar/d(dt) = z * A_bar for that expression. (A * A_bar with
    // A = exp(z) would be the derivative of exp(dt*A), which the forward
    // does not compute.)
    let da_bar_re = cmul_re(log_mag, phase, a_bar_re, a_bar_im);
    let da_bar_im = cmul_im(log_mag, phase, a_bar_re, a_bar_im);

    let x_base  = batch_id * L * D + t * D + head_id * dh;
    let bc_base = batch_id * L * G * N2 + t * G * N2 + group_id * N2;

    for (var i: u32 = 0u; i < dh; i = i + 1u) {
      let dy_val = dy[x_base + i];
      let x_val  = x_in[x_base + i];

      dD_vec[head_id] = dD_vec[head_id] + dy_val * x_val;
      dx[x_base + i] = dx[x_base + i] + dy_val;

      for (var nc: u32 = 0u; nc < Nc; nc = nc + 1u) {
        let re_off = bc_base + nc * 2u;
        let c_re = C_proj[re_off];
        let c_im = C_proj[re_off + 1u];
        let b_re = B_proj[re_off];
        let b_im = B_proj[re_off + 1u];

        let s_re = nc * 2u * dh + i; // Re row offset within one state snapshot
        let s_im = s_re + dh;        // Im row follows the Re row
        let h_re = state_carry[state_base + s_re];
        let h_im = state_carry[state_base + s_im];

        // y = Re(C * h) = c_re*h_re - c_im*h_im
        dC[re_off]      = dC[re_off]      + dy_val * h_re;
        dC[re_off + 1u] = dC[re_off + 1u] - dy_val * h_im;

        // Real partials dL/dh_re and dL/dh_im of y = Re(C * h).
        let dh_re = c_re * dy_val;
        let dh_im = -c_im * dy_val;

        // dB (B_bar factor ignored, see header)
        dB[re_off]      = dB[re_off]      + dh_re * x_val;
        dB[re_off + 1u] = dB[re_off + 1u] + dh_im * x_val;

        // dx += Re(conj(B) * dh) = b_re*dh_re + b_im*dh_im
        dx[x_base + i] = dx[x_base + i] + cmul_re(b_re, -b_im, dh_re, dh_im);

        // ddt through A_bar: dh_t/d(dt) = dA_bar * h_prev, hence
        // dL/d(dt) = Re(dA_bar*h_prev) * dh_re + Im(dA_bar*h_prev) * dh_im.
        let h_prev_re = state_carry[state_prev + s_re];
        let h_prev_im = state_carry[state_prev + s_im];
        let g_re = cmul_re(da_bar_re, da_bar_im, h_prev_re, h_prev_im);
        let g_im = cmul_im(da_bar_re, da_bar_im, h_prev_re, h_prev_im);
        ddt[dt_idx] = ddt[dt_idx] + (g_re * dh_re + g_im * dh_im) * dsp;
      }
    }
  }
}
`;
@@ -0,0 +1,159 @@
1
+ // 1D Causal Convolution WGSL Kernel
2
+ // Implements a depthwise 1D causal convolution over the sequence dimension.
3
+ // "Causal" means the output at position t only depends on positions <= t,
4
+ // which is enforced by left-padding with (kernel_size - 1) zeros.
5
+ //
6
+ // Forward: y[b, t, d] = bias[d] + sum_{k=0}^{K-1} weight[d, k] * x[b, t-k, d]
7
+ // where x[b, t', d] = 0 for t' < 0 (causal padding)
8
+ //
9
+ // The `groups` uniform is included for Mamba-2/3 compatibility where conv runs
10
+ // over a fused (x, B, C) buffer. For standard Mamba-1 usage set groups = 1.
11
+ // The kernel math is identical; the field is reserved for future grouped
12
+ // weight-sharing variants.
13
+
14
/**
 * WGSL source for the depthwise causal 1D convolution forward pass
 * (entry point `conv1d_forward`).
 *
 * y[b, t, d] = bias[d] + sum_{k=0}^{K-1} weight[d, k] * x[b, t - k, d],
 * where x is treated as zero for t - k < 0.
 * Dispatch: (ceil(L/16), ceil(D/16), B).
 */
export const CONV1D_FORWARD_WGSL = /* wgsl */`

struct ConvParams {
  seq_len     : u32, // L
  d_channels  : u32, // D (depthwise channels in this call)
  kernel_size : u32, // K
  batch       : u32, // B
  groups      : u32, // channel groups; reserved, not read by this kernel
};

@group(0) @binding(0) var<uniform> params : ConvParams;
@group(0) @binding(1) var<storage, read> x : array<f32>;           // (B, L, D) input
@group(0) @binding(2) var<storage, read> weight : array<f32>;      // (D, K) taps
@group(0) @binding(3) var<storage, read> bias : array<f32>;        // (D,)
@group(0) @binding(4) var<storage, read_write> y : array<f32>;     // (B, L, D) output

@compute @workgroup_size(16, 16, 1)
fn conv1d_forward(@builtin(global_invocation_id) gid : vec3<u32>) {
  let seq   = params.seq_len;
  let chans = params.d_channels;
  let taps  = params.kernel_size;

  let t = gid.x; // time step
  let d = gid.y; // channel
  let b = gid.z; // batch element
  let inside = t < seq && d < chans && b < params.batch;
  if (!inside) { return; }

  let batch_base = b * seq * chans;
  let tap_base   = d * taps;

  // Taps with k > t would land in the causal zero padding, so only the
  // first min(K, t + 1) taps contribute.
  let n_taps = min(taps, t + 1u);
  var sum: f32 = 0.0;
  for (var k: u32 = 0u; k < n_taps; k++) {
    sum += weight[tap_base + k] * x[batch_base + (t - k) * chans + d];
  }

  y[batch_base + t * chans + d] = sum + bias[d];
}
`;
71
+
72
+ // ---- Backward kernel for 1D convolution ----
73
/**
 * WGSL source for the depthwise causal 1D convolution backward pass.
 *
 * Entry points (each overwrites its outputs):
 *   - conv1d_backward_dx: dx[b, s, d] = sum_{k : s + k < L} dy[b, s + k, d] * weight[d, k]
 *     Dispatch: (ceil(L/16), ceil(D/16), B).
 *   - conv1d_backward_dw: dweight[d, k] = sum_{b, t >= k} dy[b, t, d] * x[b, t - k, d];
 *     the k == 0 invocation also writes dbias[d] = sum_{b, t} dy[b, t, d].
 *     Dispatch: (K, D, 1).
 */
export const CONV1D_BACKWARD_WGSL = /* wgsl */`

struct ConvParams {
  seq_len     : u32,
  d_channels  : u32,
  kernel_size : u32,
  batch       : u32,
};

@group(0) @binding(0) var<uniform> params : ConvParams;
@group(0) @binding(1) var<storage, read> x : array<f32>;
@group(0) @binding(2) var<storage, read> weight : array<f32>;
@group(0) @binding(3) var<storage, read> dy : array<f32>;
@group(0) @binding(4) var<storage, read_write> dx : array<f32>;
@group(0) @binding(5) var<storage, read_write> dweight : array<f32>;
@group(0) @binding(6) var<storage, read_write> dbias : array<f32>;

// One invocation per input element (s, d, b).
@compute @workgroup_size(16, 16, 1)
fn conv1d_backward_dx(@builtin(global_invocation_id) gid : vec3<u32>) {
  let seq   = params.seq_len;
  let chans = params.d_channels;
  let taps  = params.kernel_size;

  let s = gid.x;
  let d = gid.y;
  let b = gid.z;
  if (!(s < seq && d < chans && b < params.batch)) { return; }

  let batch_base = b * seq * chans;
  // Output positions s + k >= L do not exist, so at most L - s taps apply.
  let n_taps = min(taps, seq - s);
  var grad: f32 = 0.0;
  for (var k: u32 = 0u; k < n_taps; k++) {
    grad += dy[batch_base + (s + k) * chans + d] * weight[d * taps + k];
  }

  dx[batch_base + s * chans + d] = grad;
}

// One invocation per (tap k, channel d), reducing over batch and time.
@compute @workgroup_size(1, 1, 1)
fn conv1d_backward_dw(@builtin(global_invocation_id) gid : vec3<u32>) {
  let seq   = params.seq_len;
  let chans = params.d_channels;
  let taps  = params.kernel_size;

  let k = gid.x;
  let d = gid.y;
  if (k >= taps || d >= chans) { return; }

  var grad_w: f32 = 0.0;
  var grad_b: f32 = 0.0;
  for (var b: u32 = 0u; b < params.batch; b++) {
    let batch_base = b * seq * chans;
    // Steps t < k pair with causal padding and add nothing to grad_w. For
    // k == 0 the range is the full sequence, which is what the bias needs.
    for (var t: u32 = k; t < seq; t++) {
      let g = dy[batch_base + t * chans + d];
      grad_w += g * x[batch_base + (t - k) * chans + d];
      grad_b += g;
    }
  }

  dweight[d * taps + k] = grad_w;
  if (k == 0u) {
    dbias[d] = grad_b;
  }
}
`;
@@ -0,0 +1,220 @@
1
+ // Linear Projection WGSL Kernel
2
+ // General-purpose matrix multiplication: Y = X @ W^T + b
3
+ // Supports the up-projection and down-projection linear layers in the Mamba block.
4
+ //
5
+ // Shapes:
6
+ // X : (batch * seq_len, in_features) – input (rows)
7
+ // W : (out_features, in_features) – weight matrix (row-major)
8
+ // b : (out_features,) – bias
9
+ // Y : (batch * seq_len, out_features) – output
10
+
11
/**
 * WGSL source for the tiled linear-layer forward pass Y = X @ W^T + b
 * (entry point `linear_forward`).
 *
 * Each 16x16 workgroup computes one 16x16 tile of Y, staging matching
 * 16-wide slices of X and W in workgroup memory.
 * Dispatch: (ceil(M/16), ceil(N/16), 1).
 */
export const LINEAR_FORWARD_WGSL: string = /* wgsl */`

struct LinearParams {
  M : u32, // rows of X and Y (batch * seq_len)
  K : u32, // in_features (shared inner dimension)
  N : u32, // out_features
};

@group(0) @binding(0) var<uniform> params : LinearParams;
@group(0) @binding(1) var<storage, read> X : array<f32>;              // (M, K)
@group(0) @binding(2) var<storage, read> W : array<f32>;              // (N, K)
@group(0) @binding(3) var<storage, read> bias : array<f32>;           // (N,)
@group(0) @binding(4) var<storage, read_write> Y : array<f32>;        // (M, N)

// Staging tiles: tile_X[r * 16 + kk] = X[row0 + r, k0 + kk],
//                tile_W[c * 16 + kk] = W[col0 + c, k0 + kk]. Out-of-range slots hold 0.
var<workgroup> tile_X : array<f32, 256>; // 16 * 16
var<workgroup> tile_W : array<f32, 256>;

@compute @workgroup_size(16, 16, 1)
fn linear_forward(
  @builtin(global_invocation_id) gid : vec3<u32>,
  @builtin(local_invocation_id)  lid : vec3<u32>,
  @builtin(workgroup_id)         wid : vec3<u32>,
) {
  let r    = lid.x;          // local row inside the tile
  let c    = lid.y;          // local column inside the tile
  let row  = gid.x;          // output row (M dimension)
  let col  = gid.y;          // output column (N dimension)
  let row0 = wid.x * 16u;
  let col0 = wid.y * 16u;
  let k_dim = params.K;

  var acc: f32 = 0.0;
  let n_tiles = (k_dim + 15u) / 16u;

  for (var ti: u32 = 0u; ti < n_tiles; ti++) {
    let k0 = ti * 16u;

    // Stage X[row0 + r, k0 + c].
    let xr = row0 + r;
    let xk = k0 + c;
    var xv: f32 = 0.0;
    if (xr < params.M && xk < k_dim) {
      xv = X[xr * k_dim + xk];
    }
    tile_X[r * 16u + c] = xv;

    // Stage W[col0 + c, k0 + r], stored by output column so the dot product
    // below walks both tiles along the same K axis.
    let wn = col0 + c;
    let wk = k0 + r;
    var wv: f32 = 0.0;
    if (wn < params.N && wk < k_dim) {
      wv = W[wn * k_dim + wk];
    }
    tile_W[c * 16u + r] = wv;

    workgroupBarrier();

    for (var kk: u32 = 0u; kk < 16u; kk++) {
      acc += tile_X[r * 16u + kk] * tile_W[c * 16u + kk];
    }

    // Keep the next iteration from overwriting tiles still being read.
    workgroupBarrier();
  }

  if (row < params.M && col < params.N) {
    Y[row * params.N + col] = acc + bias[col];
  }
}
`;
79
+
80
+ // ---- Backward pass for linear projection ----
81
/**
 * WGSL source for the linear-layer backward pass of Y = X @ W^T + b.
 *
 * Entry points (each writes, rather than accumulates into, its output):
 *   - linear_backward_dX: dX = dY @ W     dispatch (ceil(M/16), ceil(K/16), 1)
 *   - linear_backward_dW: dW = dY^T @ X   dispatch (ceil(N/16), ceil(K/16), 1)
 *   - linear_backward_db: db = sum_M dY   dispatch (ceil(N/64), 1, 1)
 */
export const LINEAR_BACKWARD_WGSL: string = /* wgsl */`

struct LinearParams {
  M : u32, // rows (batch * seq_len)
  K : u32, // in_features
  N : u32, // out_features
};

@group(0) @binding(0) var<uniform> params : LinearParams;
@group(0) @binding(1) var<storage, read> X : array<f32>; // (M, K)
@group(0) @binding(2) var<storage, read> W : array<f32>; // (N, K)
@group(0) @binding(3) var<storage, read> dY : array<f32>; // (M, N)
@group(0) @binding(4) var<storage, read_write> dX : array<f32>; // (M, K)
@group(0) @binding(5) var<storage, read_write> dW : array<f32>; // (N, K)
@group(0) @binding(6) var<storage, read_write> db : array<f32>; // (N,)

// Tiles for dX: tile_dY[r * 16 + n] = dY[row0 + r, n0 + n],
//               tile_W[n * 16 + c]  = W[n0 + n, col0 + c]. Out-of-range slots hold 0.
var<workgroup> tile_dY : array<f32, 256>;
var<workgroup> tile_W : array<f32, 256>;

// Dispatch: (ceil(M/16), ceil(K/16), 1) - computes dX = dY @ W
@compute @workgroup_size(16, 16, 1)
fn linear_backward_dX(
  @builtin(global_invocation_id) gid : vec3<u32>,
  @builtin(local_invocation_id) lid : vec3<u32>,
  @builtin(workgroup_id) wid : vec3<u32>,
) {
  let M = params.M;
  let K = params.K;
  let N = params.N;

  let row = gid.x; // M
  let col = gid.y; // K

  var acc: f32 = 0.0;
  let TILE: u32 = 16u;
  let num_tiles = (N + TILE - 1u) / TILE;

  for (var tile_idx: u32 = 0u; tile_idx < num_tiles; tile_idx = tile_idx + 1u) {
    // tile_dY: (M, TILE_N) slice
    let dy_col = tile_idx * TILE + lid.y;
    let dy_row = wid.x * TILE + lid.x;
    if (dy_row < M && dy_col < N) {
      tile_dY[lid.x * TILE + lid.y] = dY[dy_row * N + dy_col];
    } else {
      tile_dY[lid.x * TILE + lid.y] = 0.0;
    }

    // tile_W: (TILE_N, K) slice - W[n, k]
    let w_row = tile_idx * TILE + lid.x; // N
    let w_col = wid.y * TILE + lid.y;    // K
    if (w_row < N && w_col < K) {
      tile_W[lid.x * TILE + lid.y] = W[w_row * K + w_col];
    } else {
      tile_W[lid.x * TILE + lid.y] = 0.0;
    }

    workgroupBarrier();

    for (var n: u32 = 0u; n < TILE; n = n + 1u) {
      acc = acc + tile_dY[lid.x * TILE + n] * tile_W[n * TILE + lid.y];
    }
    // Keep the next iteration from overwriting tiles still being read.
    workgroupBarrier();
  }

  if (row < M && col < K) {
    dX[row * K + col] = acc;
  }
}

// Tiles for dW: tile_dY2[r * 16 + m] = dY[m0 + m, row0 + r] (dY transposed),
//               tile_X2[m * 16 + c]  = X[m0 + m, col0 + c].
var<workgroup> tile_dY2 : array<f32, 256>;
var<workgroup> tile_X2 : array<f32, 256>;

// Dispatch: (ceil(N/16), ceil(K/16), 1) - computes dW = dY^T @ X
@compute @workgroup_size(16, 16, 1)
fn linear_backward_dW(
  @builtin(global_invocation_id) gid : vec3<u32>,
  @builtin(local_invocation_id) lid : vec3<u32>,
  @builtin(workgroup_id) wid : vec3<u32>,
) {
  let M = params.M;
  let K = params.K;
  let N = params.N;

  let row = gid.x; // N
  let col = gid.y; // K

  var acc: f32 = 0.0;
  let TILE: u32 = 16u;
  let num_tiles = (M + TILE - 1u) / TILE;

  for (var tile_idx: u32 = 0u; tile_idx < num_tiles; tile_idx = tile_idx + 1u) {
    // dY^T tile: [N, M] accessed as dY[m, n]
    let m_idx = tile_idx * TILE + lid.y;
    let n_idx = wid.x * TILE + lid.x;
    if (n_idx < N && m_idx < M) {
      tile_dY2[lid.x * TILE + lid.y] = dY[m_idx * N + n_idx];
    } else {
      tile_dY2[lid.x * TILE + lid.y] = 0.0;
    }

    // X tile: [M, K]
    let xm = tile_idx * TILE + lid.x;
    let xk = wid.y * TILE + lid.y;
    if (xm < M && xk < K) {
      tile_X2[lid.x * TILE + lid.y] = X[xm * K + xk];
    } else {
      tile_X2[lid.x * TILE + lid.y] = 0.0;
    }

    workgroupBarrier();

    for (var m: u32 = 0u; m < TILE; m = m + 1u) {
      acc = acc + tile_dY2[lid.x * TILE + m] * tile_X2[m * TILE + lid.y];
    }
    // Keep the next iteration from overwriting tiles still being read.
    workgroupBarrier();
  }

  if (row < N && col < K) {
    dW[row * K + col] = acc;
  }
}

// Dispatch: (ceil(N/64), 1, 1) - one invocation per output feature n, computing
// db = sum_M dY. The n >= N guard makes any over-launch harmless: dispatching
// (N, 1, 1) is still correct but runs 64x more invocations than needed.
@compute @workgroup_size(64, 1, 1)
fn linear_backward_db(
  @builtin(global_invocation_id) gid : vec3<u32>,
) {
  let M = params.M;
  let N = params.N;

  let n = gid.x;
  if (n >= N) { return; }

  var acc: f32 = 0.0;
  for (var m: u32 = 0u; m < M; m = m + 1u) {
    acc = acc + dY[m * N + n];
  }
  db[n] = acc;
}
`;