@seanhogg/builderforce-memory-engine 2026.6.18
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +393 -0
- package/dist/index.d.ts +32 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +40 -0
- package/dist/index.js.map +1 -0
- package/dist/kernels/activations.d.ts +5 -0
- package/dist/kernels/activations.d.ts.map +1 -0
- package/dist/kernels/activations.js +171 -0
- package/dist/kernels/activations.js.map +1 -0
- package/dist/kernels/attention.d.ts +19 -0
- package/dist/kernels/attention.d.ts.map +1 -0
- package/dist/kernels/attention.js +263 -0
- package/dist/kernels/attention.js.map +1 -0
- package/dist/kernels/complex_ssd.d.ts +33 -0
- package/dist/kernels/complex_ssd.d.ts.map +1 -0
- package/dist/kernels/complex_ssd.js +305 -0
- package/dist/kernels/complex_ssd.js.map +1 -0
- package/dist/kernels/conv1d.d.ts +3 -0
- package/dist/kernels/conv1d.d.ts.map +1 -0
- package/dist/kernels/conv1d.js +158 -0
- package/dist/kernels/conv1d.js.map +1 -0
- package/dist/kernels/linear_projection.d.ts +3 -0
- package/dist/kernels/linear_projection.d.ts.map +1 -0
- package/dist/kernels/linear_projection.js +219 -0
- package/dist/kernels/linear_projection.js.map +1 -0
- package/dist/kernels/selective_scan.d.ts +3 -0
- package/dist/kernels/selective_scan.d.ts.map +1 -0
- package/dist/kernels/selective_scan.js +348 -0
- package/dist/kernels/selective_scan.js.map +1 -0
- package/dist/kernels/ssd.d.ts +29 -0
- package/dist/kernels/ssd.d.ts.map +1 -0
- package/dist/kernels/ssd.js +276 -0
- package/dist/kernels/ssd.js.map +1 -0
- package/dist/kernels/weight_update.d.ts +3 -0
- package/dist/kernels/weight_update.d.ts.map +1 -0
- package/dist/kernels/weight_update.js +119 -0
- package/dist/kernels/weight_update.js.map +1 -0
- package/dist/model/attention_block.d.ts +48 -0
- package/dist/model/attention_block.d.ts.map +1 -0
- package/dist/model/attention_block.js +262 -0
- package/dist/model/attention_block.js.map +1 -0
- package/dist/model/mamba1_block.d.ts +70 -0
- package/dist/model/mamba1_block.d.ts.map +1 -0
- package/dist/model/mamba1_block.js +333 -0
- package/dist/model/mamba1_block.js.map +1 -0
- package/dist/model/mamba2_block.d.ts +44 -0
- package/dist/model/mamba2_block.d.ts.map +1 -0
- package/dist/model/mamba2_block.js +252 -0
- package/dist/model/mamba2_block.js.map +1 -0
- package/dist/model/mamba3_block.d.ts +51 -0
- package/dist/model/mamba3_block.d.ts.map +1 -0
- package/dist/model/mamba3_block.js +270 -0
- package/dist/model/mamba3_block.js.map +1 -0
- package/dist/model/mamba_block.d.ts +64 -0
- package/dist/model/mamba_block.d.ts.map +1 -0
- package/dist/model/mamba_block.js +303 -0
- package/dist/model/mamba_block.js.map +1 -0
- package/dist/model/mamba_model.d.ts +140 -0
- package/dist/model/mamba_model.d.ts.map +1 -0
- package/dist/model/mamba_model.js +527 -0
- package/dist/model/mamba_model.js.map +1 -0
- package/dist/model/sequence_layer.d.ts +25 -0
- package/dist/model/sequence_layer.d.ts.map +1 -0
- package/dist/model/sequence_layer.js +8 -0
- package/dist/model/sequence_layer.js.map +1 -0
- package/dist/tokenizer/bpe.d.ts +29 -0
- package/dist/tokenizer/bpe.d.ts.map +1 -0
- package/dist/tokenizer/bpe.js +164 -0
- package/dist/tokenizer/bpe.js.map +1 -0
- package/dist/training/autograd.d.ts +27 -0
- package/dist/training/autograd.d.ts.map +1 -0
- package/dist/training/autograd.js +120 -0
- package/dist/training/autograd.js.map +1 -0
- package/dist/training/trainer.d.ts +36 -0
- package/dist/training/trainer.d.ts.map +1 -0
- package/dist/training/trainer.js +183 -0
- package/dist/training/trainer.js.map +1 -0
- package/dist/utils/gpu_utils.d.ts +21 -0
- package/dist/utils/gpu_utils.d.ts.map +1 -0
- package/dist/utils/gpu_utils.js +111 -0
- package/dist/utils/gpu_utils.js.map +1 -0
- package/dist/utils/quantization.d.ts +26 -0
- package/dist/utils/quantization.d.ts.map +1 -0
- package/dist/utils/quantization.js +116 -0
- package/dist/utils/quantization.js.map +1 -0
- package/dist/utils/rng.d.ts +36 -0
- package/dist/utils/rng.d.ts.map +1 -0
- package/dist/utils/rng.js +61 -0
- package/dist/utils/rng.js.map +1 -0
- package/package.json +99 -0
- package/src/index.ts +114 -0
- package/src/kernels/activations.ts +174 -0
- package/src/kernels/attention.ts +268 -0
- package/src/kernels/complex_ssd.ts +307 -0
- package/src/kernels/conv1d.ts +159 -0
- package/src/kernels/linear_projection.ts +220 -0
- package/src/kernels/selective_scan.ts +350 -0
- package/src/kernels/ssd.ts +278 -0
- package/src/kernels/weight_update.ts +120 -0
- package/src/model/attention_block.ts +344 -0
- package/src/model/mamba1_block.ts +437 -0
- package/src/model/mamba2_block.ts +319 -0
- package/src/model/mamba3_block.ts +335 -0
- package/src/model/mamba_block.ts +401 -0
- package/src/model/mamba_model.ts +678 -0
- package/src/model/sequence_layer.ts +29 -0
- package/src/tokenizer/bpe.ts +186 -0
- package/src/training/autograd.ts +135 -0
- package/src/training/trainer.ts +309 -0
- package/src/utils/gpu_utils.ts +147 -0
- package/src/utils/quantization.ts +154 -0
- package/src/utils/rng.ts +65 -0
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
// Activation WGSL kernels: SiLU (Swish) and its backward pass (this file also holds RMSNorm forward and a simple row-wise softmax).
|
|
2
|
+
// Used in the gating mechanism of the Mamba Mixer Block.
|
|
3
|
+
|
|
4
|
+
/**
 * WGSL module with two independent compute entry points:
 *
 * - `silu_forward` (workgroup size 256): one invocation per element; writes
 *   `y[i] = x[i] * sigmoid(x[i])` for every `i < p.num_elements`.
 * - `rmsnorm_forward` (workgroup size 64): one invocation per row of
 *   `rms_p.dim` contiguous values (row `r` starts at `r * dim`). Writes
 *   `rms_y = rms_x * inv_rms * rms_w` and caches
 *   `inv_rms = 1 / sqrt(mean(x^2) + eps)` per row in `rms_inv`.
 *
 * Both entry points declare resources at `@group(0) @binding(0..2)` with
 * different types. NOTE(review): this relies on each pipeline being built for
 * a single entry point (so only one set of bindings is statically used).
 * Confirm the caller never shares one layout between both.
 */
export const ACTIVATIONS_WGSL: string = /* wgsl */`

struct ActParams {
  num_elements : u32,
};

@group(0) @binding(0) var<uniform> p : ActParams;
@group(0) @binding(1) var<storage, read> x : array<f32>;
@group(0) @binding(2) var<storage, read_write> y : array<f32>;

// SiLU(x) = x * sigmoid(x)
@compute @workgroup_size(256, 1, 1)
fn silu_forward(
  @builtin(global_invocation_id) gid : vec3<u32>,
) {
  let i = gid.x;
  if (i >= p.num_elements) { return; }
  let v = x[i];
  y[i] = v / (1.0 + exp(-v));
}

// RMSNorm forward: y = x / rms(x) * weight
// Requires separate uniform for rms norm params.
struct RMSNormParams {
  num_rows : u32,   // number of vectors (batch * seq_len)
  dim      : u32,   // feature dimension
  eps      : f32,
};

@group(0) @binding(0) var<uniform> rms_p : RMSNormParams;
@group(0) @binding(1) var<storage, read> rms_x : array<f32>;
@group(0) @binding(2) var<storage, read> rms_w : array<f32>;   // scale (dim,)
@group(0) @binding(3) var<storage, read_write> rms_y : array<f32>;
@group(0) @binding(4) var<storage, read_write> rms_inv : array<f32>; // cache 1/rms per row

@compute @workgroup_size(64, 1, 1)
fn rmsnorm_forward(
  @builtin(global_invocation_id) gid : vec3<u32>,
) {
  let row = gid.x;
  if (row >= rms_p.num_rows) { return; }

  let D = rms_p.dim;
  let base = row * D;

  var sq_sum: f32 = 0.0;
  for (var i: u32 = 0u; i < D; i = i + 1u) {
    let v = rms_x[base + i];
    sq_sum = sq_sum + v * v;
  }
  let inv_rms = 1.0 / sqrt(sq_sum / f32(D) + rms_p.eps);
  rms_inv[row] = inv_rms;

  for (var i: u32 = 0u; i < D; i = i + 1u) {
    rms_y[base + i] = rms_x[base + i] * inv_rms * rms_w[i];
  }
}
`;
|
|
62
|
+
|
|
63
|
+
// ---- Softmax (row-wise with optional causal mask) ----
// Standalone softmax used by AttentionBlock for the score matrix.
// Dispatch: (L, H, B) — one workgroup per (row, head, batch).
// This version is a simple sequential-within-workgroup implementation;
// for large L prefer the cooperative version in attention.ts.
//
// The score buffer is laid out [B, H, rows, cols]. SoftmaxParams carries no
// head count, so H is read from the dispatch size (num_workgroups.y == H for
// the (L, H, B) dispatch above). H must be part of the batch stride, otherwise
// different (batch, head) pairs alias the same rows when B > 1 and H > 1.
// Positions beyond the causal limit are written as 0.
export const SOFTMAX_FORWARD_WGSL: string = /* wgsl */`
struct SoftmaxParams {
  rows   : u32, // L
  cols   : u32, // L
  causal : u32, // 1 = apply causal mask, 0 = full softmax
};

@group(0) @binding(0) var<uniform> sp : SoftmaxParams;
@group(0) @binding(1) var<storage, read_write> data : array<f32>;

@compute @workgroup_size(1, 1, 1)
fn softmax_forward_simple(@builtin(global_invocation_id) gid: vec3<u32>,
                          @builtin(num_workgroups) nwg: vec3<u32>) {
  let row = gid.x;
  let head = gid.y;
  let bat = gid.z;

  if (row >= sp.rows) { return; }

  let L = sp.cols;
  // [B, H, rows, cols] layout; with workgroup_size (1,1,1) nwg.y is the head count.
  let n_heads = nwg.y;
  let base = ((bat * n_heads + head) * sp.rows + row) * L;
  // Causal rows see columns 0..row; clamp so rows > cols cannot overrun the row.
  let lim = min(select(L, row + 1u, sp.causal == 1u), L);

  var max_val = -1e38;
  for (var c = 0u; c < lim; c = c + 1u) {
    if (data[base + c] > max_val) { max_val = data[base + c]; }
  }

  var sum_exp = 0.0;
  for (var c = 0u; c < lim; c = c + 1u) {
    let e = exp(data[base + c] - max_val);
    data[base + c] = e;
    sum_exp = sum_exp + e;
  }

  let inv = 1.0 / (sum_exp + 1e-12);
  for (var c = 0u; c < lim; c = c + 1u) {
    data[base + c] = data[base + c] * inv;
  }
  // Zero out masked positions
  for (var c = lim; c < L; c = c + 1u) {
    data[base + c] = 0.0;
  }
}
`;
|
|
112
|
+
|
|
113
|
+
/**
 * Backward pass of the row-wise softmax: `dx = p * (dp - sum_j p[j] * dp[j])`
 * over the first `lim` columns of each row. `lim = row + 1` when `causal == 1`,
 * otherwise `cols`.
 *
 * Uses the same [B, H, rows, cols] layout and index mapping as
 * `SOFTMAX_FORWARD_WGSL` (gid = (row, head, batch), workgroup size 1). The head
 * count is taken from `num_workgroups.y`. NOTE(review): this assumes the caller
 * dispatches (L, H, B) exactly as for the forward kernel; confirm against the
 * caller.
 *
 * Masked columns have `p == 0` after the forward pass, so their gradient is
 * exactly zero. They are written as 0 here so no stale values survive in `dx`.
 */
export const SOFTMAX_BACKWARD_WGSL: string = /* wgsl */`
struct SoftmaxParams {
  rows   : u32,
  cols   : u32,
  causal : u32,
};

@group(0) @binding(0) var<uniform> sp : SoftmaxParams;
@group(0) @binding(1) var<storage, read> p : array<f32>;          // post-softmax probs
@group(0) @binding(2) var<storage, read> dp : array<f32>;         // upstream gradient
@group(0) @binding(3) var<storage, read_write> dx : array<f32>;   // output gradient

@compute @workgroup_size(1, 1, 1)
fn softmax_backward(@builtin(global_invocation_id) gid: vec3<u32>,
                    @builtin(num_workgroups) nwg: vec3<u32>) {
  let row = gid.x;
  let head = gid.y;
  let bat = gid.z;

  if (row >= sp.rows) { return; }

  let L = sp.cols;
  // [B, H, rows, cols] layout; with workgroup_size (1,1,1) nwg.y is the head count.
  let n_heads = nwg.y;
  let base = ((bat * n_heads + head) * sp.rows + row) * L;
  let lim = min(select(L, row + 1u, sp.causal == 1u), L);

  // row_dot = sum_i p[i] * dp[i]  (named to avoid shadowing the dot() builtin)
  var row_dot = 0.0;
  for (var i = 0u; i < lim; i = i + 1u) {
    row_dot = row_dot + p[base + i] * dp[base + i];
  }

  for (var i = 0u; i < lim; i = i + 1u) {
    dx[base + i] = p[base + i] * (dp[base + i] - row_dot);
  }

  // Masked positions carry no gradient.
  for (var i = lim; i < L; i = i + 1u) {
    dx[base + i] = 0.0;
  }
}
`;
|
|
148
|
+
|
|
149
|
+
// ---- Backward for SiLU ----
/**
 * `silu_backward` (workgroup size 256): one invocation per element, for every
 * `i < p.num_elements`:
 *
 *   `dx[i] = dy[i] * s * (1 + x[i] * (1 - s))`, where `s = sigmoid(x[i])`.
 *
 * This is the upstream gradient times d/dx SiLU(x). `x` must therefore be the
 * forward *input* to SiLU, not its output. `dx` is overwritten, not accumulated.
 */
export const ACTIVATIONS_BACKWARD_WGSL: string = /* wgsl */`

struct ActParams {
  num_elements : u32,
};

@group(0) @binding(0) var<uniform> p : ActParams;
@group(0) @binding(1) var<storage, read> x : array<f32>;
@group(0) @binding(2) var<storage, read> dy : array<f32>;
@group(0) @binding(3) var<storage, read_write> dx : array<f32>;

// d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
//                       = silu(x)/x + sigmoid(x) * (1 - sigmoid(x)) * x
// simplified: sigmoid(x) * (1 + x*(1 - sigmoid(x)))
@compute @workgroup_size(256, 1, 1)
fn silu_backward(
  @builtin(global_invocation_id) gid : vec3<u32>,
) {
  let i = gid.x;
  if (i >= p.num_elements) { return; }
  let v = x[i];
  let sig = 1.0 / (1.0 + exp(-v));
  dx[i] = dy[i] * sig * (1.0 + v * (1.0 - sig));
}
`;
|
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* attention.ts – Causal multi-head self-attention kernels.
|
|
3
|
+
*
|
|
4
|
+
* Implements tiled 16×16 causal attention suitable for WebGPU.
|
|
5
|
+
* No Flash-Attention dependency — straightforward O(L²) with causal mask.
|
|
6
|
+
*
|
|
7
|
+
* Buffer layout:
|
|
8
|
+
* Q, K, V : [B, L, H, d_head] each, bound as separate buffers (split from the fused [B, L, 3*D_model] wQKV projection)
|
|
9
|
+
* out_buf : [B, L, D_model]
|
|
10
|
+
* scores : [B, H, L, L] intermediate (written then read by kernel)
|
|
11
|
+
*
|
|
12
|
+
* Dispatch attention_forward and attention_value: (ceil(L/16), H, B)
|
|
13
|
+
* Dispatch softmax_forward: (L, H, B) — one workgroup per row
|
|
14
|
+
* Dispatch attention_backward: (ceil(L/16), H, B)
|
|
15
|
+
*/
|
|
16
|
+
|
|
17
|
+
// ── Softmax ───────────────────────────────────────────────────────────────────
|
|
18
|
+
|
|
19
|
+
/**
 * Cooperative causal softmax over the score tensor, in place.
 *
 * One 64-wide workgroup per (row, head, batch), dispatched as (L, H, B). Each
 * row is normalised over columns `0..row`. Columns `> row` are set to 0. Max
 * and sum are reduced through workgroup memory.
 *
 * The score buffer is [B, H, rows, cols]. `SoftmaxParams` has no head count, so
 * H is taken from the dispatch size (`num_workgroups.y`). Both the head offset
 * and H in the batch stride are required; without them every head would
 * normalise head 0's rows. The early `return` depends only on `workgroup_id`
 * and uniform params, so every barrier stays in uniform control flow.
 */
export const SOFTMAX_WGSL: string = /* wgsl */`
struct SoftmaxParams {
  rows : u32, // L
  cols : u32, // L (score matrix is L×L per head)
};

@group(0) @binding(0) var<uniform> params : SoftmaxParams;
@group(0) @binding(1) var<storage, read_write> data : array<f32>;

// One workgroup per row; each invocation handles a strided subset of the row.
// Workgroup size 64 – cooperative reduction for max and sum.
var<workgroup> wg_max : array<f32, 64>;
var<workgroup> wg_sum : array<f32, 64>;

@compute @workgroup_size(64, 1, 1)
fn softmax_forward(@builtin(global_invocation_id) gid: vec3<u32>,
                   @builtin(local_invocation_id) lid: vec3<u32>,
                   @builtin(workgroup_id) wid: vec3<u32>,
                   @builtin(num_workgroups) nwg: vec3<u32>) {
  let row = wid.x;    // L row index
  let head = wid.y;
  let bat = wid.z;
  let cols = params.cols;

  if (row >= params.rows) { return; }

  // [B, H, rows, cols]; for the (L, H, B) dispatch nwg.y is the head count.
  let n_heads = nwg.y;
  let base = ((bat * n_heads + head) * params.rows + row) * cols;

  // Step 1: find row max (with causal mask: positions > row are -inf)
  var local_max = -1e38;
  for (var c = lid.x; c < cols; c = c + 64u) {
    var v = -1e38;
    if (c <= row) { v = data[base + c]; }
    if (v > local_max) { local_max = v; }
  }
  wg_max[lid.x] = local_max;
  workgroupBarrier();
  for (var s = 32u; s >= 1u; s = s >> 1u) {
    if (lid.x < s) {
      if (wg_max[lid.x + s] > wg_max[lid.x]) {
        wg_max[lid.x] = wg_max[lid.x + s];
      }
    }
    workgroupBarrier();
  }
  let row_max = wg_max[0u];

  // Step 2: exp and sum
  var local_sum = 0.0;
  for (var c = lid.x; c < cols; c = c + 64u) {
    if (c <= row) {
      let e = exp(data[base + c] - row_max);
      data[base + c] = e;
      local_sum = local_sum + e;
    } else {
      data[base + c] = 0.0;
    }
  }
  wg_sum[lid.x] = local_sum;
  workgroupBarrier();
  for (var s = 32u; s >= 1u; s = s >> 1u) {
    if (lid.x < s) { wg_sum[lid.x] = wg_sum[lid.x] + wg_sum[lid.x + s]; }
    workgroupBarrier();
  }
  let inv_sum = 1.0 / (wg_sum[0u] + 1e-12);

  // Step 3: normalise
  for (var c = lid.x; c <= row; c = c + 64u) {
    data[base + c] = data[base + c] * inv_sum;
  }
}
`;
|
|
91
|
+
|
|
92
|
+
// ── Attention forward ─────────────────────────────────────────────────────────
|
|
93
|
+
|
|
94
|
+
/**
 * Causal attention forward pass. Two entry points share one bind group and are
 * both dispatched as (ceil(L/16), H, B) with 16×16 workgroups:
 *
 * - `attention_forward` writes scaled raw scores
 *   `scores[b, h, row, k] = (Q[row] · K[k]) / sqrt(d_head)` for `k <= row`.
 *   Entries above the diagonal are left untouched; the softmax pass masks them.
 * - `attention_value` writes
 *   `out_buf[b, row, h, d] = sum over k <= row of scores[b, h, row, k] * V[k, d]`.
 *   It expects `scores` to already hold post-softmax probabilities.
 *
 * `attention_forward` stages 16×16 Q/K tiles in workgroup memory. WGSL requires
 * `workgroupBarrier()` to be reached in uniform control flow. Therefore the only
 * early exit depends on `workgroup_id`, every loop enclosing a barrier is
 * bounded by workgroup-uniform values, and out-of-range elements are
 * zero-padded instead of returning. `d_head` is processed in chunks of 16, so
 * any head size is supported (not only `d_head <= 16`).
 */
export const ATTENTION_FORWARD_WGSL: string = /* wgsl */`
struct AttnParams {
  batch   : u32,
  seq_len : u32,
  d_model : u32,
  n_heads : u32,
  d_head  : u32,
};

@group(0) @binding(0) var<uniform> params : AttnParams;
// Q, K, V: separate buffers, each [B, L, H, d_head].
@group(0) @binding(1) var<storage, read> Q : array<f32>;               // [B,L,H,dh]
@group(0) @binding(2) var<storage, read> K : array<f32>;               // [B,L,H,dh]
@group(0) @binding(3) var<storage, read> V : array<f32>;               // [B,L,H,dh]
@group(0) @binding(4) var<storage, read_write> scores : array<f32>;    // [B,H,L,L]
@group(0) @binding(5) var<storage, read_write> out_buf : array<f32>;   // [B,L,H,dh]

// 16x16 tiles: tile_q[i*16 + j] = Q[query i of tile, d0 + j],
//              tile_k[j*16 + i] = K[key j of tile,   d0 + i].
var<workgroup> tile_q : array<f32, 256>;
var<workgroup> tile_k : array<f32, 256>;

@compute @workgroup_size(16, 16, 1)
fn attention_forward(@builtin(global_invocation_id) gid: vec3<u32>,
                     @builtin(local_invocation_id) lid: vec3<u32>,
                     @builtin(workgroup_id) wid: vec3<u32>) {
  let q_tile = wid.x;   // tile index along query (row) dimension
  let head = wid.y;
  let batch = wid.z;

  let L = params.seq_len;
  let H = params.n_heads;
  let dh = params.d_head;
  let inv_sqrt = 1.0 / sqrt(f32(dh));

  // Workgroup-uniform bounds: every invocation must reach every barrier.
  let tile_first = q_tile * 16u;
  if (tile_first >= L) { return; }
  let tile_last = min(tile_first + 15u, L - 1u);

  let row = tile_first + lid.x;      // query token handled by this invocation
  let row_ok = row < L;
  let q_base = batch * L * H * dh + row * H * dh + head * dh;
  let s_base = batch * H * L * L + head * L * L + row * L;

  // Causal: no query in this tile attends to keys beyond tile_last.
  for (var k_start: u32 = 0u; k_start <= tile_last; k_start = k_start + 16u) {
    let k_tok = k_start + lid.y;     // key token handled by this invocation
    let k_ok = k_tok < L;
    let k_base = batch * L * H * dh + k_tok * H * dh + head * dh;

    var acc = 0.0;
    for (var d0: u32 = 0u; d0 < dh; d0 = d0 + 16u) {
      // Each invocation loads one Q and one K element; pad with zeros.
      var qv = 0.0;
      if (row_ok && d0 + lid.y < dh) { qv = Q[q_base + d0 + lid.y]; }
      tile_q[lid.x * 16u + lid.y] = qv;

      var kv = 0.0;
      if (k_ok && d0 + lid.x < dh) { kv = K[k_base + d0 + lid.x]; }
      tile_k[lid.y * 16u + lid.x] = kv;
      workgroupBarrier();

      for (var d = 0u; d < 16u; d = d + 1u) {
        acc = acc + tile_q[lid.x * 16u + d] * tile_k[lid.y * 16u + d];
      }
      // Keep the tiles intact until every invocation has consumed them.
      workgroupBarrier();
    }

    // k_tok <= row < L implies k_tok is in range.
    if (row_ok && k_tok <= row) {
      scores[s_base + k_tok] = acc * inv_sqrt;
    }
  }
}

// Phase 2: softmax is dispatched separately via softmax_forward kernel.

// Phase 3: weighted sum of V (scores must hold post-softmax probabilities).
@compute @workgroup_size(16, 16, 1)
fn attention_value(@builtin(global_invocation_id) gid: vec3<u32>,
                   @builtin(local_invocation_id) lid: vec3<u32>,
                   @builtin(workgroup_id) wid: vec3<u32>) {
  let q_tile = wid.x;
  let head = wid.y;
  let batch = wid.z;

  let L = params.seq_len;
  let H = params.n_heads;
  let dh = params.d_head;

  let row = q_tile * 16u + lid.x;
  if (row >= L) { return; }   // no barriers below, so a divergent exit is fine

  let s_base = batch * H * L * L + head * L * L + row * L;
  let o_base = batch * L * H * dh + row * H * dh + head * dh;

  // Stride over d_head by 16 so heads wider than 16 are fully covered.
  for (var d = lid.y; d < dh; d = d + 16u) {
    var acc = 0.0;
    for (var k: u32 = 0u; k <= row; k = k + 1u) {
      let v_idx = batch * L * H * dh + k * H * dh + head * dh + d;
      acc = acc + scores[s_base + k] * V[v_idx];
    }
    out_buf[o_base + d] = acc;
  }
}
`;
|
|
199
|
+
|
|
200
|
+
// ── Attention backward ────────────────────────────────────────────────────────
|
|
201
|
+
|
|
202
|
+
/**
 * Causal attention backward pass, dispatched as (ceil(L/16), H, B) with 16×16
 * workgroups. Workgroup `t` covers tokens [16t, 16t + 16) of one (batch, head):
 *
 * - As **query rows** it owns `dscores` and `dQ`. `dscores` accumulates
 *   dP = dy · Vᵀ, the gradient with respect to the post-softmax probabilities.
 * - As **key columns** it owns `dV` and `dK`.
 *
 * Every output element is read-modify-written by exactly one invocation.
 * Unsynchronised `+=` from several invocations on the same element would lose
 * updates. All outputs are accumulated into (`+=`), as before.
 *
 * Gradients (P = post-softmax `scores`, s = 1/sqrt(d_head)):
 *   dS[r,k] = P[r,k] * (dP[r,k] - sum_j P[r,j] * dP[r,j])
 *   dQ = s * dS · K,   dK = s * dSᵀ · Q,   dV = Pᵀ · dy
 *
 * Key-owning workgroups recompute the per-row dot products sum_j P·dP. This
 * trades extra arithmetic for a single dispatch with no cross-workgroup
 * synchronisation. `d_head` of any size is supported.
 *
 * NOTE(review): dQ/dK are now driven by dS (softmax backward applied), not by
 * raw dP. Confirm the caller does not additionally apply a softmax backward to
 * dQ/dK. `dscores` still holds dP for any later softmax-backward pass.
 */
export const ATTENTION_BACKWARD_WGSL: string = /* wgsl */`
struct AttnParams {
  batch   : u32,
  seq_len : u32,
  d_model : u32,
  n_heads : u32,
  d_head  : u32,
};

@group(0) @binding(0) var<uniform> params : AttnParams;
@group(0) @binding(1) var<storage, read> Q : array<f32>;               // [B,L,H,dh]
@group(0) @binding(2) var<storage, read> K : array<f32>;               // [B,L,H,dh]
@group(0) @binding(3) var<storage, read> V : array<f32>;               // [B,L,H,dh]
@group(0) @binding(4) var<storage, read> scores : array<f32>;          // post-softmax P [B,H,L,L]
@group(0) @binding(5) var<storage, read> dy : array<f32>;              // [B,L,H,dh]
@group(0) @binding(6) var<storage, read_write> dQ : array<f32>;        // [B,L,H,dh]
@group(0) @binding(7) var<storage, read_write> dK : array<f32>;        // [B,L,H,dh]
@group(0) @binding(8) var<storage, read_write> dV : array<f32>;        // [B,L,H,dh]
@group(0) @binding(9) var<storage, read_write> dscores : array<f32>;   // dP [B,H,L,L]

var<workgroup> wg_part : array<f32, 256>;  // partial row dots, [16 rows][16 lanes]
var<workgroup> wg_rdot : array<f32, 16>;   // sum_j P*dP for the 16 rows of a row tile
var<workgroup> wg_ds   : array<f32, 256>;  // dS tile, [16 rows][16 keys]

// dP[r, k] = sum_d dy[r, d] * V[k, d]; tok_base addresses token 0 of (batch, head).
fn grad_prob(tok_base: u32, tok_stride: u32, dh: u32, r: u32, k: u32) -> f32 {
  let dy_off = tok_base + r * tok_stride;
  let v_off = tok_base + k * tok_stride;
  var acc = 0.0;
  for (var d = 0u; d < dh; d = d + 1u) {
    acc = acc + dy[dy_off + d] * V[v_off + d];
  }
  return acc;
}

@compute @workgroup_size(16, 16, 1)
fn attention_backward(@builtin(global_invocation_id) gid: vec3<u32>,
                      @builtin(local_invocation_id) lid: vec3<u32>,
                      @builtin(workgroup_id) wid: vec3<u32>) {
  let tile = wid.x;
  let head = wid.y;
  let batch = wid.z;

  let L = params.seq_len;
  let H = params.n_heads;
  let dh = params.d_head;
  let inv_sqrt = 1.0 / sqrt(f32(dh));

  // Workgroup-uniform early exit; everything below must reach every barrier.
  let tile_first = tile * 16u;
  if (tile_first >= L) { return; }
  let tile_last = min(tile_first + 15u, L - 1u);
  let n_tiles = (L + 15u) / 16u;

  let tok_stride = H * dh;                            // distance between tokens
  let tok_base = batch * L * tok_stride + head * dh;  // token 0 of (batch, head)
  let s_base = batch * H * L * L + head * L * L;      // score row 0 of (batch, head)

  let i = lid.x;
  let j = lid.y;

  // ---- Part 1: query rows of this tile -> dscores (dP) and dQ ----
  let r = tile_first + i;
  let r_ok = r < L;

  // dP[r, k] for keys k = j (mod 16), plus this lane's share of sum_k P*dP.
  var part = 0.0;
  if (r_ok) {
    for (var k = j; k <= r; k = k + 16u) {
      let dp = grad_prob(tok_base, tok_stride, dh, r, k);
      let s_idx = s_base + r * L + k;
      dscores[s_idx] = dscores[s_idx] + dp;
      part = part + scores[s_idx] * dp;
    }
  }
  wg_part[i * 16u + j] = part;
  workgroupBarrier();
  var r_dot = 0.0;
  for (var jj = 0u; jj < 16u; jj = jj + 1u) {
    r_dot = r_dot + wg_part[i * 16u + jj];
  }

  // dQ[r, d] += s * sum_k dS[r, k] * K[k, d], one 16-key tile at a time.
  for (var k_start = 0u; k_start <= tile_last; k_start = k_start + 16u) {
    let k = k_start + j;
    var ds = 0.0;
    if (r_ok && k <= r) {
      ds = scores[s_base + r * L + k]
         * (grad_prob(tok_base, tok_stride, dh, r, k) - r_dot);
    }
    wg_ds[i * 16u + j] = ds;
    workgroupBarrier();

    if (r_ok) {
      // Invocation (i, j) owns dQ[r, d] for d = j (mod 16).
      for (var d = j; d < dh; d = d + 16u) {
        var acc = 0.0;
        for (var kk = 0u; kk < 16u; kk = kk + 1u) {
          let kt = k_start + kk;
          if (kt <= r) {
            acc = acc + wg_ds[i * 16u + kk] * K[tok_base + kt * tok_stride + d];
          }
        }
        let q_idx = tok_base + r * tok_stride + d;
        dQ[q_idx] = dQ[q_idx] + acc * inv_sqrt;
      }
    }
    workgroupBarrier();
  }

  // ---- Part 2: key columns of this tile -> dV and dK ----
  // Key k receives gradient from rows r >= k, i.e. from row tiles tile..n_tiles-1.
  let k_own = tile_first + i;
  let k_ok = k_own < L;
  for (var rt = tile; rt < n_tiles; rt = rt + 1u) {
    let row0 = rt * 16u;
    let rr = row0 + i;
    let rr_ok = rr < L;

    // sum_j P[rr, j] * dP[rr, j] for the 16 rows of this row tile.
    var p2 = 0.0;
    if (rr_ok) {
      for (var k = j; k <= rr; k = k + 16u) {
        p2 = p2 + scores[s_base + rr * L + k] * grad_prob(tok_base, tok_stride, dh, rr, k);
      }
    }
    wg_part[i * 16u + j] = p2;
    workgroupBarrier();
    if (j == 0u) {
      var row_total = 0.0;
      for (var jj = 0u; jj < 16u; jj = jj + 1u) {
        row_total = row_total + wg_part[i * 16u + jj];
      }
      wg_rdot[i] = row_total;
    }
    workgroupBarrier();

    // wg_ds[i][j] = dS[row0 + i, tile_first + j]
    let kc = tile_first + j;
    var ds2 = 0.0;
    if (rr_ok && kc <= rr) {
      ds2 = scores[s_base + rr * L + kc]
          * (grad_prob(tok_base, tok_stride, dh, rr, kc) - wg_rdot[i]);
    }
    wg_ds[i * 16u + j] = ds2;
    workgroupBarrier();

    if (k_ok) {
      // Invocation (i, j) owns dV/dK[k_own, d] for d = j (mod 16).
      for (var d = j; d < dh; d = d + 16u) {
        var dv = 0.0;
        var dk = 0.0;
        for (var a = 0u; a < 16u; a = a + 1u) {
          let r2 = row0 + a;
          if (r2 < L && k_own <= r2) {
            let x_idx = tok_base + r2 * tok_stride + d;
            dv = dv + scores[s_base + r2 * L + k_own] * dy[x_idx];
            dk = dk + wg_ds[a * 16u + i] * Q[x_idx];
          }
        }
        let kv_idx = tok_base + k_own * tok_stride + d;
        dV[kv_idx] = dV[kv_idx] + dv;
        dK[kv_idx] = dK[kv_idx] + dk * inv_sqrt;
      }
    }
    workgroupBarrier();
  }
}
`;
|