@seanhogg/builderforce-memory-engine 2026.6.18

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113) hide show
  1. package/LICENSE +21 -0
  2. package/README.md +393 -0
  3. package/dist/index.d.ts +32 -0
  4. package/dist/index.d.ts.map +1 -0
  5. package/dist/index.js +40 -0
  6. package/dist/index.js.map +1 -0
  7. package/dist/kernels/activations.d.ts +5 -0
  8. package/dist/kernels/activations.d.ts.map +1 -0
  9. package/dist/kernels/activations.js +171 -0
  10. package/dist/kernels/activations.js.map +1 -0
  11. package/dist/kernels/attention.d.ts +19 -0
  12. package/dist/kernels/attention.d.ts.map +1 -0
  13. package/dist/kernels/attention.js +263 -0
  14. package/dist/kernels/attention.js.map +1 -0
  15. package/dist/kernels/complex_ssd.d.ts +33 -0
  16. package/dist/kernels/complex_ssd.d.ts.map +1 -0
  17. package/dist/kernels/complex_ssd.js +305 -0
  18. package/dist/kernels/complex_ssd.js.map +1 -0
  19. package/dist/kernels/conv1d.d.ts +3 -0
  20. package/dist/kernels/conv1d.d.ts.map +1 -0
  21. package/dist/kernels/conv1d.js +158 -0
  22. package/dist/kernels/conv1d.js.map +1 -0
  23. package/dist/kernels/linear_projection.d.ts +3 -0
  24. package/dist/kernels/linear_projection.d.ts.map +1 -0
  25. package/dist/kernels/linear_projection.js +219 -0
  26. package/dist/kernels/linear_projection.js.map +1 -0
  27. package/dist/kernels/selective_scan.d.ts +3 -0
  28. package/dist/kernels/selective_scan.d.ts.map +1 -0
  29. package/dist/kernels/selective_scan.js +348 -0
  30. package/dist/kernels/selective_scan.js.map +1 -0
  31. package/dist/kernels/ssd.d.ts +29 -0
  32. package/dist/kernels/ssd.d.ts.map +1 -0
  33. package/dist/kernels/ssd.js +276 -0
  34. package/dist/kernels/ssd.js.map +1 -0
  35. package/dist/kernels/weight_update.d.ts +3 -0
  36. package/dist/kernels/weight_update.d.ts.map +1 -0
  37. package/dist/kernels/weight_update.js +119 -0
  38. package/dist/kernels/weight_update.js.map +1 -0
  39. package/dist/model/attention_block.d.ts +48 -0
  40. package/dist/model/attention_block.d.ts.map +1 -0
  41. package/dist/model/attention_block.js +262 -0
  42. package/dist/model/attention_block.js.map +1 -0
  43. package/dist/model/mamba1_block.d.ts +70 -0
  44. package/dist/model/mamba1_block.d.ts.map +1 -0
  45. package/dist/model/mamba1_block.js +333 -0
  46. package/dist/model/mamba1_block.js.map +1 -0
  47. package/dist/model/mamba2_block.d.ts +44 -0
  48. package/dist/model/mamba2_block.d.ts.map +1 -0
  49. package/dist/model/mamba2_block.js +252 -0
  50. package/dist/model/mamba2_block.js.map +1 -0
  51. package/dist/model/mamba3_block.d.ts +51 -0
  52. package/dist/model/mamba3_block.d.ts.map +1 -0
  53. package/dist/model/mamba3_block.js +270 -0
  54. package/dist/model/mamba3_block.js.map +1 -0
  55. package/dist/model/mamba_block.d.ts +64 -0
  56. package/dist/model/mamba_block.d.ts.map +1 -0
  57. package/dist/model/mamba_block.js +303 -0
  58. package/dist/model/mamba_block.js.map +1 -0
  59. package/dist/model/mamba_model.d.ts +140 -0
  60. package/dist/model/mamba_model.d.ts.map +1 -0
  61. package/dist/model/mamba_model.js +527 -0
  62. package/dist/model/mamba_model.js.map +1 -0
  63. package/dist/model/sequence_layer.d.ts +25 -0
  64. package/dist/model/sequence_layer.d.ts.map +1 -0
  65. package/dist/model/sequence_layer.js +8 -0
  66. package/dist/model/sequence_layer.js.map +1 -0
  67. package/dist/tokenizer/bpe.d.ts +29 -0
  68. package/dist/tokenizer/bpe.d.ts.map +1 -0
  69. package/dist/tokenizer/bpe.js +164 -0
  70. package/dist/tokenizer/bpe.js.map +1 -0
  71. package/dist/training/autograd.d.ts +27 -0
  72. package/dist/training/autograd.d.ts.map +1 -0
  73. package/dist/training/autograd.js +120 -0
  74. package/dist/training/autograd.js.map +1 -0
  75. package/dist/training/trainer.d.ts +36 -0
  76. package/dist/training/trainer.d.ts.map +1 -0
  77. package/dist/training/trainer.js +183 -0
  78. package/dist/training/trainer.js.map +1 -0
  79. package/dist/utils/gpu_utils.d.ts +21 -0
  80. package/dist/utils/gpu_utils.d.ts.map +1 -0
  81. package/dist/utils/gpu_utils.js +111 -0
  82. package/dist/utils/gpu_utils.js.map +1 -0
  83. package/dist/utils/quantization.d.ts +26 -0
  84. package/dist/utils/quantization.d.ts.map +1 -0
  85. package/dist/utils/quantization.js +116 -0
  86. package/dist/utils/quantization.js.map +1 -0
  87. package/dist/utils/rng.d.ts +36 -0
  88. package/dist/utils/rng.d.ts.map +1 -0
  89. package/dist/utils/rng.js +61 -0
  90. package/dist/utils/rng.js.map +1 -0
  91. package/package.json +99 -0
  92. package/src/index.ts +114 -0
  93. package/src/kernels/activations.ts +174 -0
  94. package/src/kernels/attention.ts +268 -0
  95. package/src/kernels/complex_ssd.ts +307 -0
  96. package/src/kernels/conv1d.ts +159 -0
  97. package/src/kernels/linear_projection.ts +220 -0
  98. package/src/kernels/selective_scan.ts +350 -0
  99. package/src/kernels/ssd.ts +278 -0
  100. package/src/kernels/weight_update.ts +120 -0
  101. package/src/model/attention_block.ts +344 -0
  102. package/src/model/mamba1_block.ts +437 -0
  103. package/src/model/mamba2_block.ts +319 -0
  104. package/src/model/mamba3_block.ts +335 -0
  105. package/src/model/mamba_block.ts +401 -0
  106. package/src/model/mamba_model.ts +678 -0
  107. package/src/model/sequence_layer.ts +29 -0
  108. package/src/tokenizer/bpe.ts +186 -0
  109. package/src/training/autograd.ts +135 -0
  110. package/src/training/trainer.ts +309 -0
  111. package/src/utils/gpu_utils.ts +147 -0
  112. package/src/utils/quantization.ts +154 -0
  113. package/src/utils/rng.ts +65 -0
@@ -0,0 +1 @@
1
+ {"version":3,"file":"activations.js","sourceRoot":"","sources":["../../src/kernels/activations.ts"],"names":[],"mappings":"AAAA,wEAAwE;AACxE,yDAAyD;AAEzD,MAAM,CAAC,MAAM,gBAAgB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAyDjD,CAAC;AAEF,yDAAyD;AACzD,kEAAkE;AAClE,8DAA8D;AAC9D,uEAAuE;AACvE,8DAA8D;AAC9D,MAAM,CAAC,MAAM,oBAAoB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CA2CrD,CAAC;AAEF,MAAM,CAAC,MAAM,qBAAqB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAkCtD,CAAC;AAEF,8BAA8B;AAC9B,MAAM,CAAC,MAAM,yBAAyB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;CAwB1D,CAAC"}
@@ -0,0 +1,19 @@
1
+ /**
2
+ * attention.ts – Causal multi-head self-attention kernels.
3
+ *
4
+ * Implements tiled 16×16 causal attention suitable for WebGPU.
5
+ * No Flash-Attention dependency — straightforward O(L²) with causal mask.
6
+ *
7
+ * Buffer layout:
8
+ * qkv_in : [B, L, 3*D_model] fused Q,K,V after wQKV projection
9
+ * out_buf : [B, L, D_model]
10
+ * scores : [B, H, L, L] intermediate (written then read by kernel)
11
+ *
12
+ * Dispatch attention_forward: (ceil(L/16), H, B)
13
+ * Dispatch softmax_forward: (L, H, B) — one workgroup per row
14
+ * Dispatch attention_backward: (ceil(L/16), H, B)
15
+ */
16
+ export declare const SOFTMAX_WGSL: string;
17
+ export declare const ATTENTION_FORWARD_WGSL: string;
18
+ export declare const ATTENTION_BACKWARD_WGSL: string;
19
+ //# sourceMappingURL=attention.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"attention.d.ts","sourceRoot":"","sources":["../../src/kernels/attention.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;GAcG;AAIH,eAAO,MAAM,YAAY,EAAE,MAuE1B,CAAC;AAIF,eAAO,MAAM,sBAAsB,EAAE,MAwGpC,CAAC;AAIF,eAAO,MAAM,uBAAuB,EAAE,MAkErC,CAAC"}
@@ -0,0 +1,263 @@
1
+ /**
2
+ * attention.ts – Causal multi-head self-attention kernels.
3
+ *
4
+ * Implements tiled 16×16 causal attention suitable for WebGPU.
5
+ * No Flash-Attention dependency — straightforward O(L²) with causal mask.
6
+ *
7
+ * Buffer layout:
8
+ * qkv_in : [B, L, 3*D_model] fused Q,K,V after wQKV projection
9
+ * out_buf : [B, L, D_model]
10
+ * scores : [B, H, L, L] intermediate (written then read by kernel)
11
+ *
12
+ * Dispatch attention_forward: (ceil(L/16), H, B)
13
+ * Dispatch softmax_forward: (L, H, B) — one workgroup per row
14
+ * Dispatch attention_backward: (ceil(L/16), H, B)
15
+ */
16
+ // ── Softmax ───────────────────────────────────────────────────────────────────
17
+ export const SOFTMAX_WGSL = /* wgsl */ `
18
+ struct SoftmaxParams {
19
+ rows : u32, // L
20
+ cols : u32, // L (score matrix is L×L per head)
21
+ };
22
+
23
+ @group(0) @binding(0) var<uniform> params : SoftmaxParams;
24
+ @group(0) @binding(1) var<storage, read_write> data : array<f32>;
25
+
26
+ // One workgroup per row; each invocation handles one element within the row.
27
+ // Workgroup size 64 – cooperative reduction for max and sum.
28
+ var<workgroup> wg_max : array<f32, 64>;
29
+ var<workgroup> wg_sum : array<f32, 64>;
30
+
31
+ @compute @workgroup_size(64, 1, 1) // in-place causal softmax; one workgroup per score row
32
+ fn softmax_forward(@builtin(global_invocation_id) gid: vec3<u32>,
33
+ @builtin(local_invocation_id) lid: vec3<u32>,
34
+ @builtin(workgroup_id) wid: vec3<u32>, @builtin(num_workgroups) nwg: vec3<u32>) {
35
+ let row = wid.x; // L row index
36
+ let head = wid.y;
37
+ let bat = wid.z;
38
+ let cols = params.cols;
39
+ let n_heads = nwg.y; // grid is (L, H, B), so the y extent of the dispatch is the head count
40
+ if (row >= params.rows) { return; }
41
+ // scores are laid out [B, H, L, L]: select this (batch, head) matrix first, then the row.
42
+ let base = (bat * n_heads + head) * params.rows * cols
43
+ + row * cols;
44
+
45
+ // Step 1: find row max (with causal mask: positions > row are -inf)
46
+ var local_max = -1e38;
47
+ for (var c = lid.x; c < cols; c = c + 64u) {
48
+ var v = -1e38;
49
+ if (c <= row) { v = data[base + c]; }
50
+ if (v > local_max) { local_max = v; }
51
+ }
52
+ wg_max[lid.x] = local_max;
53
+ workgroupBarrier();
54
+ for (var s = 32u; s >= 1u; s = s >> 1u) {
55
+ if (lid.x < s) {
56
+ if (wg_max[lid.x + s] > wg_max[lid.x]) {
57
+ wg_max[lid.x] = wg_max[lid.x + s];
58
+ }
59
+ }
60
+ workgroupBarrier();
61
+ }
62
+ let row_max = wg_max[0u];
63
+
64
+ // Step 2: exp and sum
65
+ var local_sum = 0.0;
66
+ for (var c = lid.x; c < cols; c = c + 64u) {
67
+ if (c <= row) {
68
+ let e = exp(data[base + c] - row_max);
69
+ data[base + c] = e;
70
+ local_sum = local_sum + e;
71
+ } else {
72
+ data[base + c] = 0.0;
73
+ }
74
+ }
75
+ wg_sum[lid.x] = local_sum;
76
+ workgroupBarrier();
77
+ for (var s = 32u; s >= 1u; s = s >> 1u) {
78
+ if (lid.x < s) { wg_sum[lid.x] = wg_sum[lid.x] + wg_sum[lid.x + s]; }
79
+ workgroupBarrier();
80
+ }
81
+ let inv_sum = 1.0 / (wg_sum[0u] + 1e-12);
82
+
83
+ // Step 3: normalise
84
+ for (var c = lid.x; c <= row; c = c + 64u) {
85
+ data[base + c] = data[base + c] * inv_sum;
86
+ }
87
+ }
88
+ `;
89
+ // ── Attention forward ─────────────────────────────────────────────────────────
90
+ export const ATTENTION_FORWARD_WGSL = /* wgsl */ `
91
+ struct AttnParams {
92
+ batch : u32,
93
+ seq_len : u32,
94
+ d_model : u32,
95
+ n_heads : u32,
96
+ d_head : u32,
97
+ };
98
+
99
+ @group(0) @binding(0) var<uniform> params : AttnParams;
100
+ // Q, K, V packed: [B, L, 3, H, d_head] (after projection split)
101
+ @group(0) @binding(1) var<storage, read> Q : array<f32>; // [B,L,H,dh]
102
+ @group(0) @binding(2) var<storage, read> K : array<f32>; // [B,L,H,dh]
103
+ @group(0) @binding(3) var<storage, read> V : array<f32>; // [B,L,H,dh]
104
+ @group(0) @binding(4) var<storage, read_write> scores : array<f32>; // [B,H,L,L]
105
+ @group(0) @binding(5) var<storage, read_write> out_buf : array<f32>; // [B,L,H,dh]
106
+
107
+ // Tiled 16×16 shared memory for Q row and K col
108
+ var<workgroup> tile_q : array<f32, 256>; // 16 tokens × 16 d_head
109
+ var<workgroup> tile_k : array<f32, 256>;
110
+
111
+ @compute @workgroup_size(16, 16, 1)
112
+ fn attention_forward(@builtin(global_invocation_id) gid: vec3<u32>,
113
+ @builtin(local_invocation_id) lid: vec3<u32>,
114
+ @builtin(workgroup_id) wid: vec3<u32>) {
115
+ // Raw causal scores: scores[batch, head, row, k] = Q[row] · K[k] / sqrt(dh) for k <= row.
116
+ // Invocation (lid.x, lid.y) owns query row q_tile*16 + lid.x and key k_start + lid.y.
117
+ let q_tile = wid.x; // tile index along query (row) dimension
118
+ let head = wid.y;
119
+ let batch = wid.z;
120
+
121
+ let L = params.seq_len;
122
+ let H = params.n_heads;
123
+ let dh = params.d_head;
124
+ let inv_sqrt = 1.0 / sqrt(f32(dh));
125
+
126
+ let row = q_tile * 16u + lid.x; // query token index
127
+ let row_ok = row < L;
128
+ let q_base = batch * L * H * dh + row * H * dh + head * dh;
129
+
130
+ // workgroupBarrier() must be reached in uniform control flow, so out-of-range
131
+ // rows are masked instead of returning early, and the key loop runs to a bound
132
+ // derived only from workgroup_id and uniforms (one past the last key this tile may attend to).
133
+ let k_limit = min(L, q_tile * 16u + 16u);
134
+
135
+ for (var k_start: u32 = 0u; k_start < k_limit; k_start = k_start + 16u) {
136
+ let k_tok = k_start + lid.y;
137
+ let k_base = batch * L * H * dh + k_tok * H * dh + head * dh;
138
+ var acc = 0.0;
139
+ // Walk d_head in 16-wide slices so heads wider than one tile are fully summed.
140
+ for (var d0: u32 = 0u; d0 < dh; d0 = d0 + 16u) {
141
+ // tile_q[x][y] = Q[row(x)][d0 + y]; tile_k[y][x] = K[key(y)][d0 + x]; out-of-range entries are 0.
142
+ var q_val = 0.0;
143
+ if (row_ok && (d0 + lid.y < dh)) { q_val = Q[q_base + d0 + lid.y]; }
144
+ tile_q[lid.x * 16u + lid.y] = q_val;
145
+ var k_val = 0.0;
146
+ if ((k_tok < L) && (d0 + lid.x < dh)) { k_val = K[k_base + d0 + lid.x]; }
147
+ tile_k[lid.y * 16u + lid.x] = k_val;
148
+ workgroupBarrier();
149
+
150
+ for (var d = 0u; d < 16u; d = d + 1u) {
151
+ acc = acc + tile_q[lid.x * 16u + d] * tile_k[lid.y * 16u + d];
152
+ }
153
+ workgroupBarrier(); // both tiles are overwritten by the next slice
154
+ }
155
+
156
+ // Causal mask: only keys at or before the query position are written.
157
+ if (row_ok && (k_tok <= row)) {
158
+ let score_idx = batch * H * L * L + head * L * L + row * L + k_tok;
159
+ scores[score_idx] = acc * inv_sqrt;
160
+ }
161
+ }
162
+ }
163
+
164
+ // Phase 2: softmax is dispatched separately via softmax_forward kernel.
165
+
166
+ // Phase 3: weighted sum of V
167
+ @compute @workgroup_size(16, 16, 1)
168
+ fn attention_value(@builtin(global_invocation_id) gid: vec3<u32>,
169
+ @builtin(local_invocation_id) lid: vec3<u32>,
170
+ @builtin(workgroup_id) wid: vec3<u32>) {
171
+ // out_buf[row, d] = sum over k <= row of scores[row, k] * V[k, d], run after the softmax pass.
172
+ let q_tile = wid.x;
173
+ let head = wid.y;
174
+ let batch = wid.z;
175
+ let L = params.seq_len;
176
+ let H = params.n_heads;
177
+ let dh = params.d_head;
178
+
179
+ let row = q_tile * 16u + lid.x;
180
+ if (row >= L) { return; } // no barriers below, so an early exit is safe here
181
+
182
+ let s_base = batch * H * L * L + head * L * L + row * L;
183
+ // lid.y covers 16 channels per pass; stride by 16 so heads wider than 16 are written completely.
184
+ for (var d = lid.y; d < dh; d = d + 16u) {
185
+ var acc = 0.0;
186
+ for (var k: u32 = 0u; k <= row; k = k + 1u) {
187
+ let v_idx = batch * L * H * dh + k * H * dh + head * dh + d;
188
+ acc = acc + scores[s_base + k] * V[v_idx];
189
+ }
190
+ let out_idx = batch * L * H * dh + row * H * dh + head * dh + d;
191
+ out_buf[out_idx] = acc;
192
+ }
193
+ }
194
+ `;
195
+ // ── Attention backward ────────────────────────────────────────────────────────
196
+ export const ATTENTION_BACKWARD_WGSL = /* wgsl */ `
197
+ struct AttnParams {
198
+ batch : u32,
199
+ seq_len : u32,
200
+ d_model : u32,
201
+ n_heads : u32,
202
+ d_head : u32,
203
+ };
204
+
205
+ @group(0) @binding(0) var<uniform> params : AttnParams;
206
+ @group(0) @binding(1) var<storage, read> Q : array<f32>;
207
+ @group(0) @binding(2) var<storage, read> K : array<f32>;
208
+ @group(0) @binding(3) var<storage, read> V : array<f32>;
209
+ @group(0) @binding(4) var<storage, read> scores : array<f32>; // post-softmax
210
+ @group(0) @binding(5) var<storage, read> dy : array<f32>; // [B,L,H,dh]
211
+ @group(0) @binding(6) var<storage, read_write> dQ : array<f32>;
212
+ @group(0) @binding(7) var<storage, read_write> dK : array<f32>;
213
+ @group(0) @binding(8) var<storage, read_write> dV : array<f32>;
214
+ @group(0) @binding(9) var<storage, read_write> dscores : array<f32>;
215
+
216
+ @compute @workgroup_size(16, 16, 1)
217
+ fn attention_backward(@builtin(global_invocation_id) gid: vec3<u32>,
218
+ @builtin(local_invocation_id) lid: vec3<u32>,
219
+ @builtin(workgroup_id) wid: vec3<u32>) {
220
+ let q_tile = wid.x; // 16-row tile of query positions
221
+ let head = wid.y;
222
+ let batch = wid.z;
223
+ // Each invocation handles one (query row, head channel) pair and accumulates into dV, dscores, dQ and dK.
224
+ let L = params.seq_len;
225
+ let H = params.n_heads;
226
+ let dh = params.d_head;
227
+ let inv_sqrt = 1.0 / sqrt(f32(dh));
228
+
229
+ let row = q_tile * 16u + lid.x;
230
+ let d = lid.y; // 0..15 only: lid.y is bounded by the 16-wide workgroup, so channels >= 16 are never visited
231
+
232
+ if (row >= L || d >= dh) { return; }
233
+ // NOTE(review): the "+=" updates below are plain read-modify-writes. v_idx is independent of row and s_idx is independent of d, so different invocations target the same dV / dscores element — confirm this cannot race.
234
+ // dV[k, d] += score[row, k] * dy[row, d]
235
+ // dscores[row, k] += dy[row, d] * V[k, d] (before softmax backward)
236
+ for (var k: u32 = 0u; k <= row; k = k + 1u) {
237
+ let s_idx = batch * H * L * L + head * L * L + row * L + k;
238
+ let v_idx = batch * L * H * dh + k * H * dh + head * dh + d;
239
+ let dy_idx = batch * L * H * dh + row * H * dh + head * dh + d;
240
+
241
+ dV[v_idx] = dV[v_idx] + scores[s_idx] * dy[dy_idx];
242
+ dscores[s_idx] = dscores[s_idx] + dy[dy_idx] * V[v_idx];
243
+ }
244
+ // NOTE(review): dscores is read back below with no barrier after the loop above, and no softmax Jacobian is applied in this kernel — presumably a separate pass handles it; verify.
245
+ // dQ[row, d] += sum_k dscores_post_softmax[row, k] * K[k, d] * inv_sqrt
246
+ var dq_acc = 0.0;
247
+ for (var k: u32 = 0u; k <= row; k = k + 1u) {
248
+ let ds_idx = batch * H * L * L + head * L * L + row * L + k;
249
+ let k_idx = batch * L * H * dh + k * H * dh + head * dh + d;
250
+ dq_acc = dq_acc + dscores[ds_idx] * K[k_idx];
251
+ }
252
+ let q_idx = batch * L * H * dh + row * H * dh + head * dh + d;
253
+ dQ[q_idx] = dQ[q_idx] + dq_acc * inv_sqrt;
254
+ // NOTE(review): k_idx is independent of row, so dK has the same cross-row accumulation pattern as dV.
255
+ // dK[k, d] += dscores[row, k] * Q[row, d] * inv_sqrt (for all rows >= k)
256
+ for (var k: u32 = 0u; k <= row; k = k + 1u) {
257
+ let ds_idx = batch * H * L * L + head * L * L + row * L + k;
258
+ let k_idx = batch * L * H * dh + k * H * dh + head * dh + d;
259
+ dK[k_idx] = dK[k_idx] + dscores[ds_idx] * Q[q_idx] * inv_sqrt;
260
+ }
261
+ }
262
+ `;
263
+ //# sourceMappingURL=attention.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"attention.js","sourceRoot":"","sources":["../../src/kernels/attention.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;GAcG;AAEH,iFAAiF;AAEjF,MAAM,CAAC,MAAM,YAAY,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAuE7C,CAAC;AAEF,iFAAiF;AAEjF,MAAM,CAAC,MAAM,sBAAsB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAwGvD,CAAC;AAEF,iFAAiF;AAEjF,MAAM,CAAC,MAAM,uBAAuB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAkExD,CAAC"}
@@ -0,0 +1,33 @@
1
+ /**
2
+ * complex_ssd.ts – Complex-valued SSD kernels for Mamba-3.
3
+ *
4
+ * Three targeted improvements over Mamba-2 SSD:
5
+ *
6
+ * 1. Complex-valued states
7
+ * h ∈ ℂ^(N/2) stored as interleaved (real, imag) f32 pairs.
8
+ * A ∈ ℂ encoded as A_log[H, 2] = [log|A|, arg(A)].
9
+ *
10
+ * 2. Exponential-Trapezoidal (ET) discretisation
11
+ * A_bar = exp(Δ · A) (complex multiply)
12
+ * B_bar = (A_bar − 1) · A⁻¹ · B (exact, complex division)
13
+ *
14
+ * 3. MIMO recurrence (G groups of G inputs/outputs per head)
15
+ * Implemented here with G=1 (SISO) as the default; G>1 is a future
16
+ * extension that enlarges the B/C projections.
17
+ *
18
+ * Buffer layout:
19
+ * x : [B, L, D_inner] real-valued
20
+ * B_proj : [B, L, n_groups, N*2] interleaved complex (re,im)
21
+ * C_proj : [B, L, n_groups, N*2]
22
+ * dt : [B, L, H] real-valued
23
+ * A_log : [H, 2] [log|A|, arg(A)] per head
24
+ * dt_bias : [H]
25
+ * D_vec : [H]
26
+ * out : [B, L, D_inner] real-valued (Re(C·h))
27
+ * state_carry: [n_chunks+1, B, H, N*2, d_head] complex states
28
+ *
29
+ * Dispatch: (n_chunks, H, B)
30
+ */
31
+ export declare const COMPLEX_SSD_FORWARD_WGSL: string;
32
+ export declare const COMPLEX_SSD_BACKWARD_WGSL: string;
33
+ //# sourceMappingURL=complex_ssd.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"complex_ssd.d.ts","sourceRoot":"","sources":["../../src/kernels/complex_ssd.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GA6BG;AAEH,eAAO,MAAM,wBAAwB,EAAE,MAiJtC,CAAC;AAIF,eAAO,MAAM,yBAAyB,EAAE,MA8HvC,CAAC"}
@@ -0,0 +1,305 @@
1
+ /**
2
+ * complex_ssd.ts – Complex-valued SSD kernels for Mamba-3.
3
+ *
4
+ * Three targeted improvements over Mamba-2 SSD:
5
+ *
6
+ * 1. Complex-valued states
7
+ * h ∈ ℂ^(N/2) stored as interleaved (real, imag) f32 pairs.
8
+ * A ∈ ℂ encoded as A_log[H, 2] = [log|A|, arg(A)].
9
+ *
10
+ * 2. Exponential-Trapezoidal (ET) discretisation
11
+ * A_bar = exp(Δ · A) (complex multiply)
12
+ * B_bar = (A_bar − 1) · A⁻¹ · B (exact, complex division)
13
+ *
14
+ * 3. MIMO recurrence (G groups of G inputs/outputs per head)
15
+ * Implemented here with G=1 (SISO) as the default; G>1 is a future
16
+ * extension that enlarges the B/C projections.
17
+ *
18
+ * Buffer layout:
19
+ * x : [B, L, D_inner] real-valued
20
+ * B_proj : [B, L, n_groups, N*2] interleaved complex (re,im)
21
+ * C_proj : [B, L, n_groups, N*2]
22
+ * dt : [B, L, H] real-valued
23
+ * A_log : [H, 2] [log|A|, arg(A)] per head
24
+ * dt_bias : [H]
25
+ * D_vec : [H]
26
+ * out : [B, L, D_inner] real-valued (Re(C·h))
27
+ * state_carry: [n_chunks+1, B, H, N*2, d_head] complex states
28
+ *
29
+ * Dispatch: (n_chunks, H, B)
30
+ */
31
+ export const COMPLEX_SSD_FORWARD_WGSL = /* wgsl */ `
32
+ struct CssdParams {
33
+ seq_len : u32,
34
+ d_inner : u32,
35
+ n_heads : u32,
36
+ d_head : u32,
37
+ n_groups : u32,
38
+ n_complex : u32, // N/2 – number of complex state components
39
+ chunk_len : u32,
40
+ n_chunks : u32,
41
+ batch : u32,
42
+ };
43
+
44
+ @group(0) @binding(0) var<uniform> params : CssdParams;
45
+ @group(0) @binding(1) var<storage, read> x_in : array<f32>;
46
+ @group(0) @binding(2) var<storage, read> B_proj : array<f32>; // complex: N_c*2 per token
47
+ @group(0) @binding(3) var<storage, read> C_proj : array<f32>;
48
+ @group(0) @binding(4) var<storage, read> dt_in : array<f32>;
49
+ @group(0) @binding(5) var<storage, read> A_log : array<f32>; // [H, 2]
50
+ @group(0) @binding(6) var<storage, read> dt_bias : array<f32>;
51
+ @group(0) @binding(7) var<storage, read> D_vec : array<f32>;
52
+ @group(0) @binding(8) var<storage, read_write> out_buf : array<f32>;
53
+ @group(0) @binding(9) var<storage, read_write> state_carry : array<f32>; // complex states
54
+
55
+ fn softplus(v: f32) -> f32 { return max(v, 0.0) + log(1.0 + exp(-abs(v))); } // log(1 + e^v) rewritten so exp() never overflows for large v
56
+
57
+ // Complex multiply: (ar + i·ai) * (br + i·bi)
58
+ fn cmul_re(ar: f32, ai: f32, br: f32, bi: f32) -> f32 { return ar*br - ai*bi; }
59
+ fn cmul_im(ar: f32, ai: f32, br: f32, bi: f32) -> f32 { return ar*bi + ai*br; }
60
+
61
+ // Complex exp: exp(x + i·y) = exp(x)*(cos(y) + i*sin(y))
62
+ fn cexp_re(x: f32, y: f32) -> f32 { return exp(x) * cos(y); }
63
+ fn cexp_im(x: f32, y: f32) -> f32 { return exp(x) * sin(y); }
64
+
65
+ // ET discretisation B_bar = (A_bar - 1) * A^-1 * B
66
+ // A^-1 = 1/A = conj(A)/|A|^2. Here A = exp(log_mag)*exp(i*phase).
67
+ // |A| = exp(log_mag), A^-1 = exp(-log_mag)*exp(-i*phase)
68
+ // (A_bar - 1) * A^-1 = scalar complex product computed below.
69
+ fn et_bbar_re(a_bar_re: f32, a_bar_im: f32, log_mag: f32, phase: f32) -> f32 {
70
+ // (A_bar - 1)
71
+ let num_re = a_bar_re - 1.0;
72
+ let num_im = a_bar_im;
73
+ // A^-1 = exp(-log_mag - i*phase)
74
+ let inv_re = cexp_re(-log_mag, -phase);
75
+ let inv_im = cexp_im(-log_mag, -phase);
76
+ return cmul_re(num_re, num_im, inv_re, inv_im);
77
+ }
78
+ fn et_bbar_im(a_bar_re: f32, a_bar_im: f32, log_mag: f32, phase: f32) -> f32 {
79
+ let num_re = a_bar_re - 1.0;
80
+ let num_im = a_bar_im;
81
+ let inv_re = cexp_re(-log_mag, -phase);
82
+ let inv_im = cexp_im(-log_mag, -phase);
83
+ return cmul_im(num_re, num_im, inv_re, inv_im);
84
+ }
85
+
86
+ @compute @workgroup_size(1, 1, 1) // one invocation per (chunk, head, batch)
87
+ fn complex_ssd_forward(@builtin(global_invocation_id) gid: vec3<u32>) {
88
+ let chunk_id = gid.x;
89
+ let head_id = gid.y;
90
+ let batch_id = gid.z;
91
+ // Sequentially scans the time steps of one chunk for one head, updating the complex state in place and writing the real-valued output.
92
+ let L = params.seq_len;
93
+ let D = params.d_inner;
94
+ let H = params.n_heads;
95
+ let dh = params.d_head;
96
+ let G = params.n_groups;
97
+ let Nc = params.n_complex; // complex state count
98
+ let N2 = Nc * 2u; // float pairs
99
+ let CL = params.chunk_len;
100
+ let B = params.batch;
101
+
102
+ let t_start = chunk_id * CL;
103
+ let t_end = min(t_start + CL, L);
104
+ let group_id = head_id * G / H; // maps the head onto its B/C projection group
105
+
106
+ // Load A for this head: A = exp(log_mag) * exp(i*phase)
107
+ let log_mag = A_log[head_id * 2u + 0u];
108
+ let phase = A_log[head_id * 2u + 1u];
109
+ let db = dt_bias[head_id];
110
+ let d_skip = D_vec[head_id];
111
+
112
+ // State buffer strides (complex: N2*dh floats per head)
113
+ let state_stride = B * H * N2 * dh;
114
+ let state_base_in = chunk_id * state_stride
115
+ + batch_id * H * N2 * dh
116
+ + head_id * N2 * dh;
117
+ let state_base_out = (chunk_id + 1u) * state_stride
118
+ + batch_id * H * N2 * dh
119
+ + head_id * N2 * dh;
120
+ // NOTE(review): the carry-in slot (chunk_id) is the slot that chunk_id-1 writes as its output; nothing in this kernel orders chunks within one dispatch — confirm the caller serialises them.
121
+ // Copy carry-in to working slot
122
+ for (var s: u32 = 0u; s < N2 * dh; s = s + 1u) {
123
+ state_carry[state_base_out + s] = state_carry[state_base_in + s];
124
+ }
125
+
126
+ for (var t: u32 = t_start; t < t_end; t = t + 1u) {
127
+ let dt_idx = batch_id * L * H + t * H + head_id;
128
+ let dt_val = softplus(dt_in[dt_idx] + db);
129
+
130
+ // A_bar = exp(dt * A) = exp(dt*log_mag + i*dt*phase)
131
+ let a_bar_re = cexp_re(dt_val * log_mag, dt_val * phase);
132
+ let a_bar_im = cexp_im(dt_val * log_mag, dt_val * phase);
133
+
134
+ // ET B_bar scalar factor (applied per B_proj element)
135
+ let bbar_factor_re = et_bbar_re(a_bar_re, a_bar_im, log_mag, phase);
136
+ let bbar_factor_im = et_bbar_im(a_bar_re, a_bar_im, log_mag, phase);
137
+
138
+ let x_base = batch_id * L * D + t * D + head_id * dh;
139
+ // B_proj / C_proj: [B, L, G, N*2] — interleaved re/im
140
+ let bc_base = batch_id * L * G * N2 + t * G * N2 + group_id * N2;
141
+
142
+ for (var i: u32 = 0u; i < dh; i = i + 1u) {
143
+ let x_val = x_in[x_base + i];
144
+ var y_re = 0.0;
145
+
146
+ for (var nc: u32 = 0u; nc < Nc; nc = nc + 1u) {
147
+ let b_re = B_proj[bc_base + nc * 2u + 0u];
148
+ let b_im = B_proj[bc_base + nc * 2u + 1u];
149
+ let c_re = C_proj[bc_base + nc * 2u + 0u];
150
+ let c_im = C_proj[bc_base + nc * 2u + 1u];
151
+
152
+ // B_bar · x (complex * real = complex scale)
153
+ let inp_re = cmul_re(bbar_factor_re, bbar_factor_im, b_re, b_im) * x_val;
154
+ let inp_im = cmul_im(bbar_factor_re, bbar_factor_im, b_re, b_im) * x_val;
155
+ // Per-head state layout is [nc][re|im][dh]: a row of dh real parts followed by a row of dh imaginary parts for each component.
156
+ let s_re_idx = state_base_out + nc * 2u * dh + 0u * dh + i;
157
+ let s_im_idx = state_base_out + nc * 2u * dh + 1u * dh + i;
158
+
159
+ // h_t = A_bar * h_{t-1} + B_bar * x
160
+ let h_prev_re = state_carry[s_re_idx];
161
+ let h_prev_im = state_carry[s_im_idx];
162
+ let h_new_re = cmul_re(a_bar_re, a_bar_im, h_prev_re, h_prev_im) + inp_re;
163
+ let h_new_im = cmul_im(a_bar_re, a_bar_im, h_prev_re, h_prev_im) + inp_im;
164
+ state_carry[s_re_idx] = h_new_re;
165
+ state_carry[s_im_idx] = h_new_im;
166
+ // NOTE(review): passing -c_im evaluates c_re*h_re + c_im*h_im, i.e. Re(conj(C)·h) rather than Re(C·h) — confirm which is intended.
167
+ // y += Re(C · h)
168
+ y_re = y_re + cmul_re(c_re, -c_im, h_new_re, h_new_im); // C·h real part
169
+ }
170
+
171
+ let out_idx = batch_id * L * D + t * D + head_id * dh + i;
172
+ out_buf[out_idx] = y_re + d_skip * x_val; // skip connection scaled by the per-head D_vec entry
173
+ }
174
+ }
175
+ }
176
+ `;
177
+ // ── Backward ──────────────────────────────────────────────────────────────────
178
+ export const COMPLEX_SSD_BACKWARD_WGSL = /* wgsl */ `
179
+ struct CssdParams {
180
+ seq_len : u32,
181
+ d_inner : u32,
182
+ n_heads : u32,
183
+ d_head : u32,
184
+ n_groups : u32,
185
+ n_complex : u32,
186
+ chunk_len : u32,
187
+ n_chunks : u32,
188
+ batch : u32,
189
+ };
190
+
191
+ @group(0) @binding(0) var<uniform> params : CssdParams;
192
+ @group(0) @binding(1) var<storage, read> x_in : array<f32>;
193
+ @group(0) @binding(2) var<storage, read> B_proj : array<f32>;
194
+ @group(0) @binding(3) var<storage, read> C_proj : array<f32>;
195
+ @group(0) @binding(4) var<storage, read> dt_in : array<f32>;
196
+ @group(0) @binding(5) var<storage, read> A_log : array<f32>;
197
+ @group(0) @binding(6) var<storage, read> dt_bias : array<f32>;
198
+ @group(0) @binding(7) var<storage, read> state_carry : array<f32>;
199
+ @group(0) @binding(8) var<storage, read> dy : array<f32>;
200
+ @group(0) @binding(9) var<storage, read_write> dx : array<f32>;
201
+ @group(0) @binding(10) var<storage, read_write> dB : array<f32>;
202
+ @group(0) @binding(11) var<storage, read_write> dC : array<f32>;
203
+ @group(0) @binding(12) var<storage, read_write> ddt : array<f32>;
204
+ @group(0) @binding(13) var<storage, read_write> dA_log : array<f32>;
205
+ @group(0) @binding(14) var<storage, read_write> dD_vec : array<f32>;
206
+
207
+ fn softplus(v: f32) -> f32 { return max(v, 0.0) + log(1.0 + exp(-abs(v))); } // log(1 + e^v) rewritten so exp() never overflows for large v
208
+ fn d_softplus(v: f32) -> f32 { return 1.0 / (1.0 + exp(-v)); }
209
+ fn cmul_re(ar: f32, ai: f32, br: f32, bi: f32) -> f32 { return ar*br - ai*bi; }
210
+ fn cmul_im(ar: f32, ai: f32, br: f32, bi: f32) -> f32 { return ar*bi + ai*br; }
211
+ fn cexp_re(x: f32, y: f32) -> f32 { return exp(x) * cos(y); }
212
+ fn cexp_im(x: f32, y: f32) -> f32 { return exp(x) * sin(y); }
213
+
214
+ @compute @workgroup_size(1, 1, 1) // one invocation per (chunk, head, batch)
215
+ fn complex_ssd_backward(@builtin(global_invocation_id) gid: vec3<u32>) {
216
+ let chunk_id = gid.x;
217
+ let head_id = gid.y;
218
+ let batch_id = gid.z;
219
+ // Walks one chunk's time steps in reverse for one head and accumulates simplified gradients into dx, dB, dC, ddt and dD_vec (dA_log is bound but not written here).
220
+ let L = params.seq_len;
221
+ let D = params.d_inner;
222
+ let H = params.n_heads;
223
+ let dh = params.d_head;
224
+ let G = params.n_groups;
225
+ let Nc = params.n_complex;
226
+ let N2 = Nc * 2u;
227
+ let CL = params.chunk_len;
228
+ let B = params.batch;
229
+
230
+ let t_start = chunk_id * CL;
231
+ let t_end = min(t_start + CL, L);
232
+ let group_id = head_id * G / H;
233
+
234
+ let log_mag = A_log[head_id * 2u + 0u];
235
+ let phase = A_log[head_id * 2u + 1u];
236
+ let db = dt_bias[head_id];
237
+
238
+ let state_stride = B * H * N2 * dh;
239
+
240
+ for (var t_rev: u32 = 0u; t_rev < t_end - t_start; t_rev = t_rev + 1u) {
241
+ let t = t_end - 1u - t_rev; // reverse time order within the chunk
242
+
243
+ let dt_idx = batch_id * L * H + t * H + head_id;
244
+ let dt_raw = dt_in[dt_idx] + db;
245
+ let dt_val = softplus(dt_raw);
246
+ let a_bar_re = cexp_re(dt_val * log_mag, dt_val * phase);
247
+ let a_bar_im = cexp_im(dt_val * log_mag, dt_val * phase);
248
+ // NOTE(review): state_base / state_prev below do not depend on t, so every step reads the same two state slots (chunk_id+1 and chunk_id) rather than per-step states.
249
+ let x_base = batch_id * L * D + t * D + head_id * dh;
250
+ let bc_base = batch_id * L * G * N2 + t * G * N2 + group_id * N2;
251
+ let state_base = (chunk_id + 1u) * state_stride
252
+ + batch_id * H * N2 * dh
253
+ + head_id * N2 * dh;
254
+ let state_prev = chunk_id * state_stride
255
+ + batch_id * H * N2 * dh
256
+ + head_id * N2 * dh;
257
+
258
+ for (var i: u32 = 0u; i < dh; i = i + 1u) {
259
+ let dy_val = dy[batch_id * L * D + t * D + head_id * dh + i];
260
+ let x_val = x_in[x_base + i];
261
+ // NOTE(review): dD_vec[head_id] is shared by every (chunk, batch) invocation of this head and is updated non-atomically — confirm. The dx skip term adds dy with no D_vec factor (D_vec is not bound in this kernel) — verify intent.
262
+ dD_vec[head_id] = dD_vec[head_id] + dy_val * x_val;
263
+ dx[x_base + i] = dx[x_base + i] + dy_val;
264
+
265
+ for (var nc: u32 = 0u; nc < Nc; nc = nc + 1u) {
266
+ let c_re = C_proj[bc_base + nc * 2u + 0u];
267
+ let c_im = C_proj[bc_base + nc * 2u + 1u];
268
+ let b_re = B_proj[bc_base + nc * 2u + 0u];
269
+ let b_im = B_proj[bc_base + nc * 2u + 1u];
270
+
271
+ let h_re = state_carry[state_base + nc * 2u * dh + 0u * dh + i];
272
+ let h_im = state_carry[state_base + nc * 2u * dh + 1u * dh + i];
273
+ // NOTE(review): the signs of the imaginary terms below correspond to y = c_re*h_re - c_im*h_im; the forward kernel calls cmul_re(c_re, -c_im, ...), which yields c_re*h_re + c_im*h_im — confirm which is intended.
274
+ // dC from Re(C · h) output — gradient of Re(C·h) w.r.t. C is Re(h)
275
+ dC[bc_base + nc * 2u + 0u] = dC[bc_base + nc * 2u + 0u] + dy_val * h_re;
276
+ dC[bc_base + nc * 2u + 1u] = dC[bc_base + nc * 2u + 1u] - dy_val * h_im;
277
+
278
+ // dh from upstream: dh_re = c_re * dy, dh_im = -c_im * dy (Re(C·h) gradient)
279
+ let dh_re = c_re * dy_val;
280
+ let dh_im = -c_im * dy_val;
281
+
282
+ // dB: B_bar · x contributed h_new; gradient flows through B_bar
283
+ // simplified: dB += dh * x (ignoring complex B_bar Jacobian)
284
+ dB[bc_base + nc * 2u + 0u] = dB[bc_base + nc * 2u + 0u] + dh_re * x_val;
285
+ dB[bc_base + nc * 2u + 1u] = dB[bc_base + nc * 2u + 1u] + dh_im * x_val;
286
+
287
+ // dx += Re(B_bar* · dh) (simplified)
288
+ dx[x_base + i] = dx[x_base + i] + cmul_re(b_re, -b_im, dh_re, dh_im);
289
+
290
+ // ddt: from A_bar and B_bar dependence on dt
291
+ let h_prev_re = state_carry[state_prev + nc * 2u * dh + 0u * dh + i];
292
+ let h_prev_im = state_carry[state_prev + nc * 2u * dh + 1u * dh + i];
293
+ // dA_bar/ddt = A * A_bar — NOTE(review): a_bar is built as exp(dt*(log_mag + i*phase)), whose dt-derivative is (log_mag + i*phase) * A_bar; the factor used below is exp(log_mag + i*phase) — verify.
294
+ let da_bar_re = cmul_re(cexp_re(log_mag, phase), cexp_im(log_mag, phase), a_bar_re, a_bar_im);
295
+ let da_bar_im = cmul_im(cexp_re(log_mag, phase), cexp_im(log_mag, phase), a_bar_re, a_bar_im);
296
+ ddt[dt_idx] = ddt[dt_idx]
297
+ + (cmul_re(da_bar_re, da_bar_im, h_prev_re, h_prev_im) * dh_re
298
+ - cmul_im(da_bar_re, da_bar_im, h_prev_re, h_prev_im) * dh_im)
299
+ * d_softplus(dt_raw); // chain rule through dt_val = softplus(dt_raw)
300
+ }
301
+ }
302
+ }
303
+ }
304
+ `;
305
+ //# sourceMappingURL=complex_ssd.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"complex_ssd.js","sourceRoot":"","sources":["../../src/kernels/complex_ssd.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GA6BG;AAEH,MAAM,CAAC,MAAM,wBAAwB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAiJzD,CAAC;AAEF,iFAAiF;AAEjF,MAAM,CAAC,MAAM,yBAAyB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CA8H1D,CAAC"}
@@ -0,0 +1,3 @@
1
/**
 * WGSL shader source for the `conv1d_forward` compute entry point
 * (causal depthwise 1-D convolution), as described by the embedded text:
 *
 * - Bindings (group 0): `params` uniform `ConvParams` at 0, `x` at 1,
 *   `weight` at 2, `bias` at 3 (all read-only storage), `y` at 4 (read_write).
 * - `x` and `y` are indexed as `b * L * D + t * D + d`; `weight` as `d * K + k`.
 * - Each invocation handles one (t, d, b) triple taken from
 *   `global_invocation_id` and returns early when any index is out of range.
 * - Output = sum over k in [0, K) of `weight[d, k] * x[b, t - k, d]` for
 *   `t >= k` (positions before the sequence start contribute zero), plus `bias[d]`.
 * - Workgroup size is 16 x 16 x 1; the shader comment states the dispatch as
 *   (ceil(L/16), ceil(D/16), B).
 * - `ConvParams.groups` is declared in the struct but is not read by the entry point.
 */
+ export declare const CONV1D_FORWARD_WGSL = "\n\nstruct ConvParams {\n seq_len : u32, // L\n d_channels : u32, // D (number of depthwise channels in this call)\n kernel_size : u32, // K (typically 4)\n batch : u32, // B\n groups : u32, // number of channel groups (1 = standard depthwise)\n};\n\n@group(0) @binding(0) var<uniform> params : ConvParams;\n// x (B, L, D) \u2013 input\n@group(0) @binding(1) var<storage, read> x : array<f32>;\n// weight (D, K) \u2013 depthwise conv weights\n@group(0) @binding(2) var<storage, read> weight : array<f32>;\n// bias (D,) \u2013 optional bias (zeros if unused)\n@group(0) @binding(3) var<storage, read> bias : array<f32>;\n// y (B, L, D) \u2013 output\n@group(0) @binding(4) var<storage, read_write> y : array<f32>;\n\n// Dispatch: (ceil(L/16), ceil(D/16), B)\n@compute @workgroup_size(16, 16, 1)\nfn conv1d_forward(\n @builtin(global_invocation_id) gid : vec3<u32>,\n) {\n let L = params.seq_len;\n let D = params.d_channels;\n let K = params.kernel_size;\n let B = params.batch;\n\n let t = gid.x; // time position\n let d = gid.y; // channel\n let b = gid.z; // batch\n\n if (t >= L || d >= D || b >= B) { return; }\n\n var acc: f32 = 0.0;\n\n // Causal: convolve over k = 0..K-1, reading position (t - k)\n for (var k: u32 = 0u; k < K; k = k + 1u) {\n let w_idx = d * K + k;\n let w_val = weight[w_idx];\n\n // t - k: use causal zero-padding for t < k\n if (t >= k) {\n let src = b * L * D + (t - k) * D + d;\n acc = acc + w_val * x[src];\n }\n // else: zero-padding contributes 0\n }\n\n acc = acc + bias[d];\n\n let out = b * L * D + t * D + d;\n y[out] = acc;\n}\n";
2
/**
 * WGSL shader source containing two compute entry points for the backward
 * pass of the causal depthwise 1-D convolution, as described by the embedded text:
 *
 * - Bindings (group 0): `params` uniform `ConvParams` at 0, `x` at 1,
 *   `weight` at 2, `dy` at 3 (read-only storage); `dx` at 4, `dweight` at 5,
 *   `dbias` at 6 (read_write storage).
 * - `conv1d_backward_dx` (workgroup 16 x 16 x 1): one invocation per (t, d, b);
 *   writes `dx[b, t, d]` = sum over k in [0, K) of `dy[b, t + k, d] * weight[d, k]`,
 *   skipping terms where `t + k >= L`.
 * - `conv1d_backward_dw` (workgroup 1 x 1 x 1): one invocation per (k, d); loops
 *   over every b and t, writes `dweight[d * K + k]` = sum of
 *   `dy[b, t, d] * x[b, t - k, d]` for `t >= k`. Invocations with `k == 0`
 *   additionally write `dbias[d]` = sum of `dy[b, t, d]`.
 * - All three outputs are assigned (`=`), not added to prior buffer contents.
 * - This shader's `ConvParams` has four u32 fields and no `groups` member,
 *   unlike the struct declared in the forward shader string.
 */
+ export declare const CONV1D_BACKWARD_WGSL = "\n\nstruct ConvParams {\n seq_len : u32,\n d_channels : u32,\n kernel_size : u32,\n batch : u32,\n};\n\n@group(0) @binding(0) var<uniform> params : ConvParams;\n@group(0) @binding(1) var<storage, read> x : array<f32>;\n@group(0) @binding(2) var<storage, read> weight : array<f32>;\n@group(0) @binding(3) var<storage, read> dy : array<f32>;\n@group(0) @binding(4) var<storage, read_write> dx : array<f32>;\n@group(0) @binding(5) var<storage, read_write> dweight : array<f32>;\n@group(0) @binding(6) var<storage, read_write> dbias : array<f32>;\n\n// Dispatch: (ceil(L/16), ceil(D/16), B) \u2013 computes dx\n@compute @workgroup_size(16, 16, 1)\nfn conv1d_backward_dx(\n @builtin(global_invocation_id) gid : vec3<u32>,\n) {\n let L = params.seq_len;\n let D = params.d_channels;\n let K = params.kernel_size;\n let B = params.batch;\n\n let t = gid.x;\n let d = gid.y;\n let b = gid.z;\n\n if (t >= L || d >= D || b >= B) { return; }\n\n var grad: f32 = 0.0;\n\n // dx[b, t, d] = sum_{k=0}^{K-1} dy[b, t+k, d] * weight[d, k]\n for (var k: u32 = 0u; k < K; k = k + 1u) {\n let tp = t + k;\n if (tp < L) {\n let dy_idx = b * L * D + tp * D + d;\n let w_idx = d * K + k;\n grad = grad + dy[dy_idx] * weight[w_idx];\n }\n }\n\n let dx_idx = b * L * D + t * D + d;\n dx[dx_idx] = grad;\n}\n\n// Dispatch: (K, D, 1) \u2013 accumulates dweight over (B, L)\n@compute @workgroup_size(1, 1, 1)\nfn conv1d_backward_dw(\n @builtin(global_invocation_id) gid : vec3<u32>,\n) {\n let L = params.seq_len;\n let D = params.d_channels;\n let K = params.kernel_size;\n let B = params.batch;\n\n let k = gid.x;\n let d = gid.y;\n\n if (k >= K || d >= D) { return; }\n\n var grad_w: f32 = 0.0;\n var grad_b: f32 = 0.0;\n\n for (var b: u32 = 0u; b < B; b = b + 1u) {\n for (var t: u32 = 0u; t < L; t = t + 1u) {\n let dy_idx = b * L * D + t * D + d;\n let dy_val = dy[dy_idx];\n if (t >= k) {\n let x_idx = b * L * D + (t - k) * D + d;\n grad_w = grad_w + dy_val * x[x_idx];\n 
}\n if (k == 0u) {\n grad_b = grad_b + dy_val;\n }\n }\n }\n\n dweight[d * K + k] = grad_w;\n if (k == 0u) {\n dbias[d] = grad_b;\n }\n}\n";
3
+ //# sourceMappingURL=conv1d.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"conv1d.d.ts","sourceRoot":"","sources":["../../src/kernels/conv1d.ts"],"names":[],"mappings":"AAaA,eAAO,MAAM,mBAAmB,mxDAwD/B,CAAC;AAGF,eAAO,MAAM,oBAAoB,y6EAsFhC,CAAC"}