@seanhogg/builderforce-memory-engine 2026.6.18

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. package/LICENSE +21 -0
  2. package/README.md +393 -0
  3. package/dist/index.d.ts +32 -0
  4. package/dist/index.d.ts.map +1 -0
  5. package/dist/index.js +40 -0
  6. package/dist/index.js.map +1 -0
  7. package/dist/kernels/activations.d.ts +5 -0
  8. package/dist/kernels/activations.d.ts.map +1 -0
  9. package/dist/kernels/activations.js +171 -0
  10. package/dist/kernels/activations.js.map +1 -0
  11. package/dist/kernels/attention.d.ts +19 -0
  12. package/dist/kernels/attention.d.ts.map +1 -0
  13. package/dist/kernels/attention.js +263 -0
  14. package/dist/kernels/attention.js.map +1 -0
  15. package/dist/kernels/complex_ssd.d.ts +33 -0
  16. package/dist/kernels/complex_ssd.d.ts.map +1 -0
  17. package/dist/kernels/complex_ssd.js +305 -0
  18. package/dist/kernels/complex_ssd.js.map +1 -0
  19. package/dist/kernels/conv1d.d.ts +3 -0
  20. package/dist/kernels/conv1d.d.ts.map +1 -0
  21. package/dist/kernels/conv1d.js +158 -0
  22. package/dist/kernels/conv1d.js.map +1 -0
  23. package/dist/kernels/linear_projection.d.ts +3 -0
  24. package/dist/kernels/linear_projection.d.ts.map +1 -0
  25. package/dist/kernels/linear_projection.js +219 -0
  26. package/dist/kernels/linear_projection.js.map +1 -0
  27. package/dist/kernels/selective_scan.d.ts +3 -0
  28. package/dist/kernels/selective_scan.d.ts.map +1 -0
  29. package/dist/kernels/selective_scan.js +348 -0
  30. package/dist/kernels/selective_scan.js.map +1 -0
  31. package/dist/kernels/ssd.d.ts +29 -0
  32. package/dist/kernels/ssd.d.ts.map +1 -0
  33. package/dist/kernels/ssd.js +276 -0
  34. package/dist/kernels/ssd.js.map +1 -0
  35. package/dist/kernels/weight_update.d.ts +3 -0
  36. package/dist/kernels/weight_update.d.ts.map +1 -0
  37. package/dist/kernels/weight_update.js +119 -0
  38. package/dist/kernels/weight_update.js.map +1 -0
  39. package/dist/model/attention_block.d.ts +48 -0
  40. package/dist/model/attention_block.d.ts.map +1 -0
  41. package/dist/model/attention_block.js +262 -0
  42. package/dist/model/attention_block.js.map +1 -0
  43. package/dist/model/mamba1_block.d.ts +70 -0
  44. package/dist/model/mamba1_block.d.ts.map +1 -0
  45. package/dist/model/mamba1_block.js +333 -0
  46. package/dist/model/mamba1_block.js.map +1 -0
  47. package/dist/model/mamba2_block.d.ts +44 -0
  48. package/dist/model/mamba2_block.d.ts.map +1 -0
  49. package/dist/model/mamba2_block.js +252 -0
  50. package/dist/model/mamba2_block.js.map +1 -0
  51. package/dist/model/mamba3_block.d.ts +51 -0
  52. package/dist/model/mamba3_block.d.ts.map +1 -0
  53. package/dist/model/mamba3_block.js +270 -0
  54. package/dist/model/mamba3_block.js.map +1 -0
  55. package/dist/model/mamba_block.d.ts +64 -0
  56. package/dist/model/mamba_block.d.ts.map +1 -0
  57. package/dist/model/mamba_block.js +303 -0
  58. package/dist/model/mamba_block.js.map +1 -0
  59. package/dist/model/mamba_model.d.ts +140 -0
  60. package/dist/model/mamba_model.d.ts.map +1 -0
  61. package/dist/model/mamba_model.js +527 -0
  62. package/dist/model/mamba_model.js.map +1 -0
  63. package/dist/model/sequence_layer.d.ts +25 -0
  64. package/dist/model/sequence_layer.d.ts.map +1 -0
  65. package/dist/model/sequence_layer.js +8 -0
  66. package/dist/model/sequence_layer.js.map +1 -0
  67. package/dist/tokenizer/bpe.d.ts +29 -0
  68. package/dist/tokenizer/bpe.d.ts.map +1 -0
  69. package/dist/tokenizer/bpe.js +164 -0
  70. package/dist/tokenizer/bpe.js.map +1 -0
  71. package/dist/training/autograd.d.ts +27 -0
  72. package/dist/training/autograd.d.ts.map +1 -0
  73. package/dist/training/autograd.js +120 -0
  74. package/dist/training/autograd.js.map +1 -0
  75. package/dist/training/trainer.d.ts +36 -0
  76. package/dist/training/trainer.d.ts.map +1 -0
  77. package/dist/training/trainer.js +183 -0
  78. package/dist/training/trainer.js.map +1 -0
  79. package/dist/utils/gpu_utils.d.ts +21 -0
  80. package/dist/utils/gpu_utils.d.ts.map +1 -0
  81. package/dist/utils/gpu_utils.js +111 -0
  82. package/dist/utils/gpu_utils.js.map +1 -0
  83. package/dist/utils/quantization.d.ts +26 -0
  84. package/dist/utils/quantization.d.ts.map +1 -0
  85. package/dist/utils/quantization.js +116 -0
  86. package/dist/utils/quantization.js.map +1 -0
  87. package/dist/utils/rng.d.ts +36 -0
  88. package/dist/utils/rng.d.ts.map +1 -0
  89. package/dist/utils/rng.js +61 -0
  90. package/dist/utils/rng.js.map +1 -0
  91. package/package.json +99 -0
  92. package/src/index.ts +114 -0
  93. package/src/kernels/activations.ts +174 -0
  94. package/src/kernels/attention.ts +268 -0
  95. package/src/kernels/complex_ssd.ts +307 -0
  96. package/src/kernels/conv1d.ts +159 -0
  97. package/src/kernels/linear_projection.ts +220 -0
  98. package/src/kernels/selective_scan.ts +350 -0
  99. package/src/kernels/ssd.ts +278 -0
  100. package/src/kernels/weight_update.ts +120 -0
  101. package/src/model/attention_block.ts +344 -0
  102. package/src/model/mamba1_block.ts +437 -0
  103. package/src/model/mamba2_block.ts +319 -0
  104. package/src/model/mamba3_block.ts +335 -0
  105. package/src/model/mamba_block.ts +401 -0
  106. package/src/model/mamba_model.ts +678 -0
  107. package/src/model/sequence_layer.ts +29 -0
  108. package/src/tokenizer/bpe.ts +186 -0
  109. package/src/training/autograd.ts +135 -0
  110. package/src/training/trainer.ts +309 -0
  111. package/src/utils/gpu_utils.ts +147 -0
  112. package/src/utils/quantization.ts +154 -0
  113. package/src/utils/rng.ts +65 -0
@@ -0,0 +1,174 @@
// WGSL compute kernels for the forward activations: SiLU (Swish) and RMSNorm.
// SiLU is used in the gating mechanism of the Mamba Mixer Block.

export const ACTIVATIONS_WGSL: string = /* wgsl */`

struct ActParams {
    num_elements : u32,
};

@group(0) @binding(0) var<uniform>             p : ActParams;
@group(0) @binding(1) var<storage, read>       x : array<f32>;
@group(0) @binding(2) var<storage, read_write> y : array<f32>;

// Element-wise SiLU: y = x * sigmoid(x), evaluated as x / (1 + e^-x).
// One invocation per element; invocations at or past num_elements do nothing.
@compute @workgroup_size(256, 1, 1)
fn silu_forward(
    @builtin(global_invocation_id) gid : vec3<u32>,
) {
    let idx = gid.x;
    if (idx < p.num_elements) {
        let value = x[idx];
        let denom = 1.0 + exp(-value);
        y[idx] = value / denom;
    }
}

// RMSNorm forward: y = x / rms(x) * weight.
// Uses its own uniform block and bindings, separate from silu_forward.
struct RMSNormParams {
    num_rows : u32,   // number of vectors (batch * seq_len)
    dim      : u32,   // feature dimension
    eps      : f32,
};

@group(0) @binding(0) var<uniform>             rms_p   : RMSNormParams;
@group(0) @binding(1) var<storage, read>       rms_x   : array<f32>;
@group(0) @binding(2) var<storage, read>       rms_w   : array<f32>;  // scale (dim,)
@group(0) @binding(3) var<storage, read_write> rms_y   : array<f32>;
@group(0) @binding(4) var<storage, read_write> rms_inv : array<f32>;  // cache 1/rms per row

// One invocation normalises one whole row of length dim, sequentially.
@compute @workgroup_size(64, 1, 1)
fn rmsnorm_forward(
    @builtin(global_invocation_id) gid : vec3<u32>,
) {
    let row = gid.x;
    if (row >= rms_p.num_rows) { return; }

    let width  = rms_p.dim;
    let offset = row * width;

    // Pass 1: sum of squares over the row.
    var energy : f32 = 0.0;
    for (var j : u32 = 0u; j < width; j++) {
        let sample = rms_x[offset + j];
        energy += sample * sample;
    }

    // 1 / sqrt(mean(x^2) + eps); also stored per row in rms_inv.
    let mean_sq = energy / f32(width);
    let scale   = 1.0 / sqrt(mean_sq + rms_p.eps);
    rms_inv[row] = scale;

    // Pass 2: scale each element and apply the per-feature weight.
    for (var j : u32 = 0u; j < width; j++) {
        rms_y[offset + j] = rms_x[offset + j] * scale * rms_w[j];
    }
}
`;
62
+
// ---- Softmax (row-wise with optional causal mask) ----
// Standalone softmax used by AttentionBlock for the score matrix, in place.
// Dispatch: (L, H, B) — one workgroup per (row, head, batch).
// The head count is taken from the dispatch size (num_workgroups.y), so the
// dispatch MUST be exactly (rows, H, B) for the [B, H, rows, cols] addressing
// to be correct.
// This version is a simple sequential-within-workgroup implementation;
// for large L prefer the cooperative version in attention.ts.
export const SOFTMAX_FORWARD_WGSL: string = /* wgsl */`
struct SoftmaxParams {
    rows   : u32,   // L
    cols   : u32,   // L
    causal : u32,   // 1 = apply causal mask, 0 = full softmax
};

@group(0) @binding(0) var<uniform>             sp   : SoftmaxParams;
@group(0) @binding(1) var<storage, read_write> data : array<f32>;

@compute @workgroup_size(1, 1, 1)
fn softmax_forward_simple(
    @builtin(global_invocation_id) gid : vec3<u32>,
    @builtin(num_workgroups)       nwg : vec3<u32>,
) {
    let row  = gid.x;
    let head = gid.y;
    let bat  = gid.z;

    if (row >= sp.rows) { return; }

    let L = sp.cols;
    // Workgroup size is 1x1x1, so num_workgroups.y is the number of heads.
    let n_heads = nwg.y;
    // Row start inside a [B, H, rows, cols] tensor. The batch stride must
    // include the head count, otherwise (bat + 1, head 0) aliases (bat, head 1).
    let base = ((bat * n_heads + head) * sp.rows + row) * L;
    // Number of unmasked columns; clamped so a row index can never run past cols.
    let lim = select(L, min(row + 1u, L), sp.causal == 1u);

    // Pass 1: row maximum over the unmasked columns (numerical stability).
    var max_val = -1e38;
    for (var c = 0u; c < lim; c++) {
        let v = data[base + c];
        if (v > max_val) { max_val = v; }
    }

    // Pass 2: exponentiate in place and accumulate the normaliser.
    var sum_exp = 0.0;
    for (var c = 0u; c < lim; c++) {
        let e = exp(data[base + c] - max_val);
        data[base + c] = e;
        sum_exp += e;
    }

    // Pass 3: normalise.
    let inv = 1.0 / (sum_exp + 1e-12);
    for (var c = 0u; c < lim; c++) {
        data[base + c] = data[base + c] * inv;
    }

    // Masked (future) positions get probability zero.
    for (var c = lim; c < L; c++) {
        data[base + c] = 0.0;
    }
}
`;
112
+
// ---- Softmax backward (row-wise, optional causal mask) ----
// dx = p * (dp - sum_i p[i] * dp[i]) per row of a [B, H, rows, cols] tensor.
// Dispatch: (L, H, B) — one workgroup per (row, head, batch). The head count
// is taken from the dispatch size (num_workgroups.y).
export const SOFTMAX_BACKWARD_WGSL: string = /* wgsl */`
struct SoftmaxParams {
    rows   : u32,
    cols   : u32,
    causal : u32,
};

@group(0) @binding(0) var<uniform>             sp : SoftmaxParams;
@group(0) @binding(1) var<storage, read>       p  : array<f32>;   // post-softmax probs
@group(0) @binding(2) var<storage, read>       dp : array<f32>;   // upstream gradient
@group(0) @binding(3) var<storage, read_write> dx : array<f32>;   // output gradient

@compute @workgroup_size(1, 1, 1)
fn softmax_backward(
    @builtin(global_invocation_id) gid : vec3<u32>,
    @builtin(num_workgroups)       nwg : vec3<u32>,
) {
    let row  = gid.x;
    let head = gid.y;
    let bat  = gid.z;

    if (row >= sp.rows) { return; }

    let L = sp.cols;
    // Workgroup size is 1x1x1, so num_workgroups.y is the number of heads.
    let n_heads = nwg.y;
    // The batch stride includes the head count so batches do not alias heads.
    let base = ((bat * n_heads + head) * sp.rows + row) * L;
    // Number of unmasked columns, clamped to the row length.
    let lim = select(L, min(row + 1u, L), sp.causal == 1u);

    // dot = sum_i p[i] * dp[i]
    var dot = 0.0;
    for (var i = 0u; i < lim; i++) {
        dot += p[base + i] * dp[base + i];
    }

    for (var i = 0u; i < lim; i++) {
        dx[base + i] = p[base + i] * (dp[base + i] - dot);
    }

    // Masked positions have zero probability and therefore zero gradient;
    // write it explicitly so dx never carries stale buffer contents.
    for (var i = lim; i < L; i++) {
        dx[base + i] = 0.0;
    }
}
`;
148
+
// ---- Backward for SiLU ----
export const ACTIVATIONS_BACKWARD_WGSL: string = /* wgsl */`

struct ActParams {
    num_elements : u32,
};

@group(0) @binding(0) var<uniform>             p  : ActParams;
@group(0) @binding(1) var<storage, read>       x  : array<f32>;
@group(0) @binding(2) var<storage, read>       dy : array<f32>;
@group(0) @binding(3) var<storage, read_write> dx : array<f32>;

// With s = sigmoid(x):
//   d/dx [x * s] = s + x * s * (1 - s) = s * (1 + x * (1 - s))
// The kernel writes dx = dy * s * (1 + x * (1 - s)), one invocation per element.
@compute @workgroup_size(256, 1, 1)
fn silu_backward(
    @builtin(global_invocation_id) gid : vec3<u32>,
) {
    let idx = gid.x;
    if (idx < p.num_elements) {
        let value     = x[idx];
        let sig       = 1.0 / (1.0 + exp(-value));
        let gate_term = 1.0 + value * (1.0 - sig);
        dx[idx] = dy[idx] * sig * gate_term;
    }
}
`;
@@ -0,0 +1,268 @@
1
+ /**
2
+ * attention.ts – Causal multi-head self-attention kernels.
3
+ *
4
+ * Implements tiled 16×16 causal attention suitable for WebGPU.
5
+ * No Flash-Attention dependency — straightforward O(L²) with causal mask.
6
+ *
7
+ * Buffer layout:
8
+ * qkv_in : [B, L, 3*D_model] fused Q,K,V after wQKV projection
9
+ * out_buf : [B, L, D_model]
10
+ * scores : [B, H, L, L] intermediate (written then read by kernel)
11
+ *
12
+ * Dispatch attention_forward: (ceil(L/16), H, B)
13
+ * Dispatch softmax_forward: (L, H, B) — one workgroup per row
14
+ * Dispatch attention_backward: (ceil(L/16), H, B)
15
+ */
16
+
// ── Softmax ───────────────────────────────────────────────────────────────────

// Cooperative causal softmax over the rows of a [B, H, rows, cols] score
// tensor, computed in place.
// Dispatch: (L, H, B) — one 64-invocation workgroup per (row, head, batch).
// The head count is taken from the dispatch size (num_workgroups.y), so the
// dispatch MUST be exactly (rows, H, B).
export const SOFTMAX_WGSL: string = /* wgsl */`
struct SoftmaxParams {
    rows : u32,   // L
    cols : u32,   // L  (score matrix is L×L per head)
};

@group(0) @binding(0) var<uniform>             params : SoftmaxParams;
@group(0) @binding(1) var<storage, read_write> data   : array<f32>;

// One workgroup per row; each invocation handles columns lid.x, lid.x + 64, ...
// Workgroup size 64 – cooperative tree reduction for max and sum.
var<workgroup> wg_max : array<f32, 64>;
var<workgroup> wg_sum : array<f32, 64>;

@compute @workgroup_size(64, 1, 1)
fn softmax_forward(@builtin(local_invocation_id) lid : vec3<u32>,
                   @builtin(workgroup_id)        wid : vec3<u32>,
                   @builtin(num_workgroups)      nwg : vec3<u32>) {
    let row  = wid.x;   // L row index
    let head = wid.y;
    let bat  = wid.z;
    let rows = params.rows;
    let cols = params.cols;

    // Workgroup-uniform exit (depends only on workgroup_id and a uniform),
    // so the barriers below stay in uniform control flow.
    if (row >= rows) { return; }

    // Row start inside [B, H, rows, cols]. The head index must be part of the
    // offset, otherwise every head of a batch overwrites the same rows.
    let n_heads = nwg.y;
    let base = ((bat * n_heads + head) * rows + row) * cols;

    // Causal mask: only columns 0..row are live (clamped to the row length).
    let lim = min(row + 1u, cols);

    // Step 1: row max over the unmasked columns.
    var local_max = -1e38;
    for (var c = lid.x; c < lim; c += 64u) {
        let v = data[base + c];
        if (v > local_max) { local_max = v; }
    }
    wg_max[lid.x] = local_max;
    workgroupBarrier();
    for (var s = 32u; s > 0u; s = s >> 1u) {
        if (lid.x < s) {
            if (wg_max[lid.x + s] > wg_max[lid.x]) {
                wg_max[lid.x] = wg_max[lid.x + s];
            }
        }
        workgroupBarrier();
    }
    let row_max = wg_max[0u];

    // Step 2: exponentiate live columns, zero masked ones, accumulate the sum.
    var local_sum = 0.0;
    for (var c = lid.x; c < cols; c += 64u) {
        if (c < lim) {
            let e = exp(data[base + c] - row_max);
            data[base + c] = e;
            local_sum += e;
        } else {
            data[base + c] = 0.0;
        }
    }
    wg_sum[lid.x] = local_sum;
    workgroupBarrier();
    for (var s = 32u; s > 0u; s = s >> 1u) {
        if (lid.x < s) { wg_sum[lid.x] = wg_sum[lid.x] + wg_sum[lid.x + s]; }
        workgroupBarrier();
    }
    let inv_sum = 1.0 / (wg_sum[0u] + 1e-12);

    // Step 3: normalise. Each invocation revisits only the columns it wrote
    // in step 2, so no further barrier is needed.
    for (var c = lid.x; c < lim; c += 64u) {
        data[base + c] = data[base + c] * inv_sum;
    }
}
`;
91
+
// ── Attention forward ─────────────────────────────────────────────────────────

// Two entry points sharing one bind group:
//   attention_forward – raw causal scores  scores[b,h,row,k] = Q[row]·K[k] / sqrt(d_head)
//   attention_value   – out[b,row,h,:]     = sum_{k<=row} scores[b,h,row,k] * V[k]
// The softmax between them is dispatched separately (softmax_forward).
// Dispatch for both: (ceil(L/16), H, B).
export const ATTENTION_FORWARD_WGSL: string = /* wgsl */`
struct AttnParams {
    batch   : u32,
    seq_len : u32,
    d_model : u32,
    n_heads : u32,
    d_head  : u32,
};

@group(0) @binding(0) var<uniform>             params  : AttnParams;
// Q, K and V are bound as separate buffers, each indexed as [B, L, H, d_head].
@group(0) @binding(1) var<storage, read>       Q       : array<f32>;  // [B,L,H,dh]
@group(0) @binding(2) var<storage, read>       K       : array<f32>;  // [B,L,H,dh]
@group(0) @binding(3) var<storage, read>       V       : array<f32>;  // [B,L,H,dh]
@group(0) @binding(4) var<storage, read_write> scores  : array<f32>;  // [B,H,L,L]
@group(0) @binding(5) var<storage, read_write> out_buf : array<f32>;  // [B,L,H,dh]

// 16×16 shared-memory tiles: 16 tokens × one 16-wide chunk of d_head.
var<workgroup> tile_q : array<f32, 256>;
var<workgroup> tile_k : array<f32, 256>;

// Invocation (lid.x, lid.y) produces the score of query row (q_tile*16 + lid.x)
// against key token (k_start + lid.y) for every key tile up to the diagonal.
@compute @workgroup_size(16, 16, 1)
fn attention_forward(@builtin(local_invocation_id) lid : vec3<u32>,
                     @builtin(workgroup_id)        wid : vec3<u32>) {
    let q_tile = wid.x;   // tile index along the query (row) dimension
    let head   = wid.y;
    let batch  = wid.z;

    let L  = params.seq_len;
    let H  = params.n_heads;
    let dh = params.d_head;
    let inv_sqrt = 1.0 / sqrt(f32(dh));

    // Workgroup-uniform exit: depends only on workgroup_id and uniforms.
    // A per-invocation return here would put the barriers below in
    // non-uniform control flow, which WGSL rejects.
    let tile_row0 = q_tile * 16u;
    if (tile_row0 >= L) { return; }

    let row       = tile_row0 + lid.x;   // query token index
    let row_valid = row < L;

    let tok_stride = H * dh;                                // elements per token
    let head_base  = batch * L * tok_stride + head * dh;    // (batch, token 0, head)
    let q_base     = head_base + row * tok_stride;
    let score_row  = (batch * H + head) * L * L + row * L;

    // Last valid query row of this tile. Used as the loop bound so that every
    // invocation of the workgroup runs the same number of iterations.
    let last_row = min(tile_row0 + 15u, L - 1u);

    for (var k_start = 0u; k_start <= last_row; k_start += 16u) {
        let k_tok   = k_start + lid.y;   // key token handled by this invocation
        let k_valid = k_tok < L;
        let k_base  = head_base + k_tok * tok_stride;

        // Accumulate the dot product over d_head in chunks of 16 so heads
        // wider than 16 are fully covered.
        var acc = 0.0;
        for (var d0 = 0u; d0 < dh; d0 += 16u) {
            // tile_q[token lid.x][element lid.y]; out-of-range slots are zero.
            let dq = d0 + lid.y;
            var q_val = 0.0;
            if (row_valid && dq < dh) { q_val = Q[q_base + dq]; }
            tile_q[lid.x * 16u + lid.y] = q_val;

            // tile_k[token lid.y][element lid.x]; out-of-range slots are zero.
            let dk = d0 + lid.x;
            var k_val = 0.0;
            if (k_valid && dk < dh) { k_val = K[k_base + dk]; }
            tile_k[lid.y * 16u + lid.x] = k_val;
            workgroupBarrier();

            for (var d = 0u; d < 16u; d++) {
                acc += tile_q[lid.x * 16u + d] * tile_k[lid.y * 16u + d];
            }
            // Keep the tiles intact until every invocation has consumed them.
            workgroupBarrier();
        }

        // Only causal entries (key <= query) are written.
        if (row_valid && k_tok <= row) {
            scores[score_row + k_tok] = acc * inv_sqrt;
        }
    }
}

// Phase 2: softmax is dispatched separately via softmax_forward kernel.

// Phase 3: weighted sum of V.
// Invocation (lid.x, lid.y) handles query row (q_tile*16 + lid.x) and the
// d_head elements lid.y, lid.y + 16, ... so heads wider than 16 are covered.
@compute @workgroup_size(16, 16, 1)
fn attention_value(@builtin(local_invocation_id) lid : vec3<u32>,
                   @builtin(workgroup_id)        wid : vec3<u32>) {
    let q_tile = wid.x;
    let head   = wid.y;
    let batch  = wid.z;

    let L  = params.seq_len;
    let H  = params.n_heads;
    let dh = params.d_head;

    let row = q_tile * 16u + lid.x;
    if (row >= L) { return; }   // no barriers below, a per-invocation exit is fine

    let tok_stride = H * dh;
    let head_base  = batch * L * tok_stride + head * dh;
    let score_row  = (batch * H + head) * L * L + row * L;

    for (var d = lid.y; d < dh; d += 16u) {
        var acc = 0.0;
        for (var k = 0u; k <= row; k++) {
            acc += scores[score_row + k] * V[head_base + k * tok_stride + d];
        }
        out_buf[head_base + row * tok_stride + d] = acc;
    }
}
`;
199
+
200
+ // ── Attention backward ────────────────────────────────────────────────────────
201
+
202
+ export const ATTENTION_BACKWARD_WGSL: string = /* wgsl */`
203
+ struct AttnParams {
204
+ batch : u32,
205
+ seq_len : u32,
206
+ d_model : u32,
207
+ n_heads : u32,
208
+ d_head : u32,
209
+ };
210
+
211
+ @group(0) @binding(0) var<uniform> params : AttnParams;
212
+ @group(0) @binding(1) var<storage, read> Q : array<f32>; // indexed below as [B,L,H,dh]
213
+ @group(0) @binding(2) var<storage, read> K : array<f32>; // indexed below as [B,L,H,dh]
214
+ @group(0) @binding(3) var<storage, read> V : array<f32>; // indexed below as [B,L,H,dh]
215
+ @group(0) @binding(4) var<storage, read> scores : array<f32>; // post-softmax
216
+ @group(0) @binding(5) var<storage, read> dy : array<f32>; // [B,L,H,dh]
217
+ @group(0) @binding(6) var<storage, read_write> dQ : array<f32>; // accumulated into (+=), same indexing as Q
218
+ @group(0) @binding(7) var<storage, read_write> dK : array<f32>; // accumulated into (+=), same indexing as K
219
+ @group(0) @binding(8) var<storage, read_write> dV : array<f32>; // accumulated into (+=), same indexing as V
220
+ @group(0) @binding(9) var<storage, read_write> dscores : array<f32>; // accumulated into (+=), indexed below as [B,H,L,L]
221
+
222
+ @compute @workgroup_size(16, 16, 1) // one invocation per (query row, d_head element); every output is accumulated with +=, so the kernel presumably expects pre-zeroed or running-gradient buffers — TODO confirm against the host code
223
+ fn attention_backward(@builtin(global_invocation_id) gid: vec3<u32>,
224
+ @builtin(local_invocation_id) lid: vec3<u32>,
225
+ @builtin(workgroup_id) wid: vec3<u32>) {
226
+ let q_tile = wid.x; // tile of 16 query rows
227
+ let head = wid.y;
228
+ let batch = wid.z;
229
+
230
+ let L = params.seq_len;
231
+ let H = params.n_heads;
232
+ let dh = params.d_head;
233
+ let inv_sqrt = 1.0 / sqrt(f32(dh)); // same 1/sqrt(d_head) scale applied to the gradients of Q and K below
234
+
235
+ let row = q_tile * 16u + lid.x; // query token index
236
+ let d = lid.y; // d_head element; lid.y spans 0..15 only. NOTE(review): elements d >= 16 are never visited by this kernel — confirm d_head <= 16 is guaranteed by the caller
237
+
238
+ if (row >= L || d >= dh) { return; } // skip invocations outside the sequence or the head dimension
239
+
240
+ // dV[k, d] += score[row, k] * dy[row, d]
241
+ // dscores[row, k] += dy[row, d] * V[k, d] (before softmax backward)
242
+ for (var k: u32 = 0u; k <= row; k = k + 1u) { // causal range: keys 0..row
243
+ let s_idx = batch * H * L * L + head * L * L + row * L + k;
244
+ let v_idx = batch * L * H * dh + k * H * dh + head * dh + d;
245
+ let dy_idx = batch * L * H * dh + row * H * dh + head * dh + d;
246
+
247
+ dV[v_idx] = dV[v_idx] + scores[s_idx] * dy[dy_idx]; // NOTE(review): v_idx does not depend on row, so invocations for different rows read-modify-write the same dV element; no atomic or barrier is visible here
248
+ dscores[s_idx] = dscores[s_idx] + dy[dy_idx] * V[v_idx]; // NOTE(review): s_idx does not depend on d, so the invocations for every d of this row update the same dscores element without synchronisation
249
+ }
250
+
251
+ // dQ[row, d] += sum_k dscores_post_softmax[row, k] * K[k, d] * inv_sqrt — NOTE(review): no softmax-backward step and no barrier appears in this kernel between the dscores accumulation above and the reads below; confirm the intended pass ordering
252
+ var dq_acc = 0.0;
253
+ for (var k: u32 = 0u; k <= row; k = k + 1u) {
254
+ let ds_idx = batch * H * L * L + head * L * L + row * L + k;
255
+ let k_idx = batch * L * H * dh + k * H * dh + head * dh + d;
256
+ dq_acc = dq_acc + dscores[ds_idx] * K[k_idx];
257
+ }
258
+ let q_idx = batch * L * H * dh + row * H * dh + head * dh + d;
259
+ dQ[q_idx] = dQ[q_idx] + dq_acc * inv_sqrt; // q_idx is unique per (batch, row, head, d), so only this invocation touches it
260
+
261
+ // dK[k, d] += dscores[row, k] * Q[row, d] * inv_sqrt (for all rows >= k)
262
+ for (var k: u32 = 0u; k <= row; k = k + 1u) {
263
+ let ds_idx = batch * H * L * L + head * L * L + row * L + k;
264
+ let k_idx = batch * L * H * dh + k * H * dh + head * dh + d;
265
+ dK[k_idx] = dK[k_idx] + dscores[ds_idx] * Q[q_idx] * inv_sqrt; // NOTE(review): k_idx does not depend on row; same unsynchronised cross-row accumulation pattern as dV above
266
+ }
267
+ }
268
+ `;