@seanhogg/builderforce-memory-engine 2026.6.18
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +393 -0
- package/dist/index.d.ts +32 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +40 -0
- package/dist/index.js.map +1 -0
- package/dist/kernels/activations.d.ts +5 -0
- package/dist/kernels/activations.d.ts.map +1 -0
- package/dist/kernels/activations.js +171 -0
- package/dist/kernels/activations.js.map +1 -0
- package/dist/kernels/attention.d.ts +19 -0
- package/dist/kernels/attention.d.ts.map +1 -0
- package/dist/kernels/attention.js +263 -0
- package/dist/kernels/attention.js.map +1 -0
- package/dist/kernels/complex_ssd.d.ts +33 -0
- package/dist/kernels/complex_ssd.d.ts.map +1 -0
- package/dist/kernels/complex_ssd.js +305 -0
- package/dist/kernels/complex_ssd.js.map +1 -0
- package/dist/kernels/conv1d.d.ts +3 -0
- package/dist/kernels/conv1d.d.ts.map +1 -0
- package/dist/kernels/conv1d.js +158 -0
- package/dist/kernels/conv1d.js.map +1 -0
- package/dist/kernels/linear_projection.d.ts +3 -0
- package/dist/kernels/linear_projection.d.ts.map +1 -0
- package/dist/kernels/linear_projection.js +219 -0
- package/dist/kernels/linear_projection.js.map +1 -0
- package/dist/kernels/selective_scan.d.ts +3 -0
- package/dist/kernels/selective_scan.d.ts.map +1 -0
- package/dist/kernels/selective_scan.js +348 -0
- package/dist/kernels/selective_scan.js.map +1 -0
- package/dist/kernels/ssd.d.ts +29 -0
- package/dist/kernels/ssd.d.ts.map +1 -0
- package/dist/kernels/ssd.js +276 -0
- package/dist/kernels/ssd.js.map +1 -0
- package/dist/kernels/weight_update.d.ts +3 -0
- package/dist/kernels/weight_update.d.ts.map +1 -0
- package/dist/kernels/weight_update.js +119 -0
- package/dist/kernels/weight_update.js.map +1 -0
- package/dist/model/attention_block.d.ts +48 -0
- package/dist/model/attention_block.d.ts.map +1 -0
- package/dist/model/attention_block.js +262 -0
- package/dist/model/attention_block.js.map +1 -0
- package/dist/model/mamba1_block.d.ts +70 -0
- package/dist/model/mamba1_block.d.ts.map +1 -0
- package/dist/model/mamba1_block.js +333 -0
- package/dist/model/mamba1_block.js.map +1 -0
- package/dist/model/mamba2_block.d.ts +44 -0
- package/dist/model/mamba2_block.d.ts.map +1 -0
- package/dist/model/mamba2_block.js +252 -0
- package/dist/model/mamba2_block.js.map +1 -0
- package/dist/model/mamba3_block.d.ts +51 -0
- package/dist/model/mamba3_block.d.ts.map +1 -0
- package/dist/model/mamba3_block.js +270 -0
- package/dist/model/mamba3_block.js.map +1 -0
- package/dist/model/mamba_block.d.ts +64 -0
- package/dist/model/mamba_block.d.ts.map +1 -0
- package/dist/model/mamba_block.js +303 -0
- package/dist/model/mamba_block.js.map +1 -0
- package/dist/model/mamba_model.d.ts +140 -0
- package/dist/model/mamba_model.d.ts.map +1 -0
- package/dist/model/mamba_model.js +527 -0
- package/dist/model/mamba_model.js.map +1 -0
- package/dist/model/sequence_layer.d.ts +25 -0
- package/dist/model/sequence_layer.d.ts.map +1 -0
- package/dist/model/sequence_layer.js +8 -0
- package/dist/model/sequence_layer.js.map +1 -0
- package/dist/tokenizer/bpe.d.ts +29 -0
- package/dist/tokenizer/bpe.d.ts.map +1 -0
- package/dist/tokenizer/bpe.js +164 -0
- package/dist/tokenizer/bpe.js.map +1 -0
- package/dist/training/autograd.d.ts +27 -0
- package/dist/training/autograd.d.ts.map +1 -0
- package/dist/training/autograd.js +120 -0
- package/dist/training/autograd.js.map +1 -0
- package/dist/training/trainer.d.ts +36 -0
- package/dist/training/trainer.d.ts.map +1 -0
- package/dist/training/trainer.js +183 -0
- package/dist/training/trainer.js.map +1 -0
- package/dist/utils/gpu_utils.d.ts +21 -0
- package/dist/utils/gpu_utils.d.ts.map +1 -0
- package/dist/utils/gpu_utils.js +111 -0
- package/dist/utils/gpu_utils.js.map +1 -0
- package/dist/utils/quantization.d.ts +26 -0
- package/dist/utils/quantization.d.ts.map +1 -0
- package/dist/utils/quantization.js +116 -0
- package/dist/utils/quantization.js.map +1 -0
- package/dist/utils/rng.d.ts +36 -0
- package/dist/utils/rng.d.ts.map +1 -0
- package/dist/utils/rng.js +61 -0
- package/dist/utils/rng.js.map +1 -0
- package/package.json +99 -0
- package/src/index.ts +114 -0
- package/src/kernels/activations.ts +174 -0
- package/src/kernels/attention.ts +268 -0
- package/src/kernels/complex_ssd.ts +307 -0
- package/src/kernels/conv1d.ts +159 -0
- package/src/kernels/linear_projection.ts +220 -0
- package/src/kernels/selective_scan.ts +350 -0
- package/src/kernels/ssd.ts +278 -0
- package/src/kernels/weight_update.ts +120 -0
- package/src/model/attention_block.ts +344 -0
- package/src/model/mamba1_block.ts +437 -0
- package/src/model/mamba2_block.ts +319 -0
- package/src/model/mamba3_block.ts +335 -0
- package/src/model/mamba_block.ts +401 -0
- package/src/model/mamba_model.ts +678 -0
- package/src/model/sequence_layer.ts +29 -0
- package/src/tokenizer/bpe.ts +186 -0
- package/src/training/autograd.ts +135 -0
- package/src/training/trainer.ts +309 -0
- package/src/utils/gpu_utils.ts +147 -0
- package/src/utils/quantization.ts +154 -0
- package/src/utils/rng.ts +65 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"activations.js","sourceRoot":"","sources":["../../src/kernels/activations.ts"],"names":[],"mappings":"AAAA,wEAAwE;AACxE,yDAAyD;AAEzD,MAAM,CAAC,MAAM,gBAAgB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAyDjD,CAAC;AAEF,yDAAyD;AACzD,kEAAkE;AAClE,8DAA8D;AAC9D,uEAAuE;AACvE,8DAA8D;AAC9D,MAAM,CAAC,MAAM,oBAAoB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CA2CrD,CAAC;AAEF,MAAM,CAAC,MAAM,qBAAqB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAkCtD,CAAC;AAEF,8BAA8B;AAC9B,MAAM,CAAC,MAAM,yBAAyB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;CAwB1D,CAAC"}
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
/**
 * attention.ts – Causal multi-head self-attention kernels.
 *
 * Implements tiled 16×16 causal attention suitable for WebGPU.
 * No Flash-Attention dependency — straightforward O(L²) with causal mask.
 *
 * Buffer layout (as bound by the WGSL modules):
 * Q, K, V : [B, L, H, d_head] each, bound as separate storage buffers
 * out_buf : [B, L, D_model], indexed in WGSL as [B, L, H, d_head]
 * scores : [B, H, L, L] intermediate (written then read by kernel)
 *
 * Dispatch attention_forward: (ceil(L/16), H, B)
 * Dispatch softmax_forward: (L, H, B) — one workgroup per row
 * Dispatch attention_backward: (ceil(L/16), H, B)
 * attention_value indexes its workgroups the same way as attention_forward,
 * so presumably (ceil(L/16), H, B) as well — TODO confirm against the host.
 */
/**
 * WGSL source for entry point `softmax_forward` (workgroup size 64).
 * Applies a causal softmax in place to one score row per workgroup:
 * columns 0..row are normalised, columns > row are set to 0.
 */
export declare const SOFTMAX_WGSL: string;
/**
 * WGSL source with two entry points (workgroup size 16×16) sharing one
 * binding layout: `attention_forward` writes the scaled scores
 * Q·K / sqrt(d_head) for key <= query, and `attention_value` writes the
 * score-weighted sum of V over key <= query.
 */
export declare const ATTENTION_FORWARD_WGSL: string;
/**
 * WGSL source for entry point `attention_backward` (workgroup size 16×16).
 * Gradients are added (+=) into dQ, dK, dV and dscores, so those buffers
 * presumably have to be zeroed by the caller first — TODO confirm.
 */
export declare const ATTENTION_BACKWARD_WGSL: string;
|
|
19
|
+
//# sourceMappingURL=attention.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"attention.d.ts","sourceRoot":"","sources":["../../src/kernels/attention.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;GAcG;AAIH,eAAO,MAAM,YAAY,EAAE,MAuE1B,CAAC;AAIF,eAAO,MAAM,sBAAsB,EAAE,MAwGpC,CAAC;AAIF,eAAO,MAAM,uBAAuB,EAAE,MAkErC,CAAC"}
|
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* attention.ts – Causal multi-head self-attention kernels.
|
|
3
|
+
*
|
|
4
|
+
* Implements tiled 16×16 causal attention suitable for WebGPU.
|
|
5
|
+
* No Flash-Attention dependency — straightforward O(L²) with causal mask.
|
|
6
|
+
*
|
|
7
|
+
* Buffer layout:
|
|
8
|
+
* Q, K, V : [B, L, H, d_head] each, bound as separate storage buffers
|
|
9
|
+
* out_buf : [B, L, D_model]
|
|
10
|
+
* scores : [B, H, L, L] intermediate (written then read by kernel)
|
|
11
|
+
*
|
|
12
|
+
* Dispatch attention_forward: (ceil(L/16), H, B)
|
|
13
|
+
* Dispatch softmax_forward: (L, H, B) — one workgroup per row
|
|
14
|
+
* Dispatch attention_backward: (ceil(L/16), H, B)
|
|
15
|
+
*/
|
|
16
|
+
// ── Softmax ───────────────────────────────────────────────────────────────────
|
|
17
|
+
/**
 * WGSL source for `softmax_forward`: row-wise causal softmax applied in place
 * to the score buffer, laid out [B, H, rows, cols].
 *
 * Dispatch: (rows, n_heads, batch) workgroups of 64 invocations, one workgroup
 * per score row. The head count is read from the dispatch size
 * (`num_workgroups.y`), so `SoftmaxParams` keeps its original two-field layout.
 * Previously the row offset ignored the head index, so every head's workgroups
 * normalised head 0's rows.
 */
export const SOFTMAX_WGSL = /* wgsl */ `
struct SoftmaxParams {
  rows : u32, // L
  cols : u32, // L (score matrix is L×L per head)
};

@group(0) @binding(0) var<uniform> params : SoftmaxParams;
@group(0) @binding(1) var<storage, read_write> data : array<f32>;

// Workgroup size 64 – cooperative tree reductions for the row max and row sum.
var<workgroup> wg_max : array<f32, 64>;
var<workgroup> wg_sum : array<f32, 64>;

@compute @workgroup_size(64, 1, 1)
fn softmax_forward(@builtin(global_invocation_id) gid: vec3<u32>,
                   @builtin(local_invocation_id) lid: vec3<u32>,
                   @builtin(workgroup_id) wid: vec3<u32>,
                   @builtin(num_workgroups) nwg: vec3<u32>) {
  let row  = wid.x; // query row index
  let head = wid.y;
  let bat  = wid.z;
  let rows = params.rows;
  let cols = params.cols;

  // Depends only on workgroup_id and uniforms, so the whole workgroup exits
  // together and the barriers below remain in uniform control flow.
  if (row >= rows) { return; }

  // Scores are [B, H, rows, cols]; the dispatch's y extent is the head count.
  let n_heads = nwg.y;
  let base = ((bat * n_heads + head) * rows + row) * cols;

  // Causal mask: only columns 0..row are visible (clamped to the row width).
  let visible = min(row + 1u, cols);

  // Step 1: row max over the visible columns.
  var local_max = -1e38;
  for (var c = lid.x; c < visible; c = c + 64u) {
    let v = data[base + c];
    if (v > local_max) { local_max = v; }
  }
  wg_max[lid.x] = local_max;
  workgroupBarrier();
  for (var s = 32u; s > 0u; s = s >> 1u) {
    if (lid.x < s && wg_max[lid.x + s] > wg_max[lid.x]) {
      wg_max[lid.x] = wg_max[lid.x + s];
    }
    workgroupBarrier();
  }
  let row_max = wg_max[0u];

  // Step 2: exponentiate visible entries, zero masked ones, accumulate the sum.
  var local_sum = 0.0;
  for (var c = lid.x; c < cols; c = c + 64u) {
    if (c < visible) {
      let e = exp(data[base + c] - row_max);
      data[base + c] = e;
      local_sum = local_sum + e;
    } else {
      data[base + c] = 0.0;
    }
  }
  wg_sum[lid.x] = local_sum;
  workgroupBarrier();
  for (var s = 32u; s > 0u; s = s >> 1u) {
    if (lid.x < s) { wg_sum[lid.x] = wg_sum[lid.x] + wg_sum[lid.x + s]; }
    workgroupBarrier();
  }
  let inv_sum = 1.0 / (wg_sum[0u] + 1e-12);

  // Step 3: normalise. Each invocation rescales only entries it wrote in step 2.
  for (var c = lid.x; c < visible; c = c + 64u) {
    data[base + c] = data[base + c] * inv_sum;
  }
}
`;
|
|
89
|
+
// ── Attention forward ─────────────────────────────────────────────────────────
|
|
90
|
+
/**
 * WGSL module with two entry points sharing one binding layout
 * (16×16 workgroups, workgroups indexed as (query tile, head, batch)):
 *
 *  - `attention_forward`: scores[b,h,i,j] = Q[b,i,h,:]·K[b,j,h,:] / sqrt(d_head)
 *    for j <= i. Entries with j > i are not written; `softmax_forward` masks them.
 *  - `attention_value`: out[b,i,h,:] = Σ_{j<=i} scores[b,h,i,j] · V[b,j,h,:],
 *    reading the probabilities left in `scores` by `softmax_forward`.
 *
 * Fixes over the previous version:
 *  - `attention_forward` called workgroupBarrier() after a per-invocation early
 *    return and inside a loop bounded by the per-invocation row. That is
 *    non-uniform control flow, which WGSL uniformity analysis rejects. The tile
 *    loop is now bounded by the last row of the workgroup's tile (a uniform value).
 *  - The dot product only covered the first 16 head dimensions. It now walks
 *    d_head in 16-wide slices.
 *  - `attention_value` only wrote output dimensions d < 16. Invocations now
 *    stride over d_head.
 */
export const ATTENTION_FORWARD_WGSL = /* wgsl */ `
struct AttnParams {
  batch   : u32,
  seq_len : u32,
  d_model : u32,
  n_heads : u32,
  d_head  : u32,
};

@group(0) @binding(0) var<uniform> params : AttnParams;
// Q, K, V are separate buffers, each laid out [B, L, H, d_head].
@group(0) @binding(1) var<storage, read> Q : array<f32>;                // [B,L,H,dh]
@group(0) @binding(2) var<storage, read> K : array<f32>;                // [B,L,H,dh]
@group(0) @binding(3) var<storage, read> V : array<f32>;                // [B,L,H,dh]
@group(0) @binding(4) var<storage, read_write> scores : array<f32>;     // [B,H,L,L]
@group(0) @binding(5) var<storage, read_write> out_buf : array<f32>;    // [B,L,H,dh]

const TILE : u32 = 16u;

// Shared 16×16 tiles: 16 query rows × 16 head dims, 16 key rows × 16 head dims.
var<workgroup> tile_q : array<f32, 256>;
var<workgroup> tile_k : array<f32, 256>;

@compute @workgroup_size(16, 16, 1)
fn attention_forward(@builtin(global_invocation_id) gid: vec3<u32>,
                     @builtin(local_invocation_id) lid: vec3<u32>,
                     @builtin(workgroup_id) wid: vec3<u32>) {
  let q_tile = wid.x; // tile index along the query (row) dimension
  let head   = wid.y;
  let batch  = wid.z;

  let L  = params.seq_len;
  let H  = params.n_heads;
  let dh = params.d_head;
  let inv_sqrt = 1.0 / sqrt(f32(dh));

  // Uniform exit: depends only on workgroup_id and uniforms.
  let tile_first = q_tile * TILE;
  if (tile_first >= L) { return; }
  // Last query row handled by this workgroup; uniform, so loops bounded by it
  // keep every workgroupBarrier() in uniform control flow.
  let tile_last = min(tile_first + TILE - 1u, L - 1u);

  let row    = tile_first + lid.x; // query token owned by this invocation
  let row_ok = row < L;
  let q_base = batch * L * H * dh + row * H * dh + head * dh;

  for (var k_start = 0u; k_start <= tile_last; k_start = k_start + TILE) {
    let k_tok  = k_start + lid.y; // key token owned by this invocation
    let k_base = batch * L * H * dh + k_tok * H * dh + head * dh;
    var acc = 0.0;

    // Reduce over d_head in 16-wide slices; out-of-range lanes load zeros.
    for (var d0 = 0u; d0 < dh; d0 = d0 + TILE) {
      var qv = 0.0;
      if (row_ok && d0 + lid.y < dh) { qv = Q[q_base + d0 + lid.y]; }
      tile_q[lid.x * TILE + lid.y] = qv;

      var kv = 0.0;
      if (k_tok < L && d0 + lid.x < dh) { kv = K[k_base + d0 + lid.x]; }
      tile_k[lid.y * TILE + lid.x] = kv;

      workgroupBarrier();
      for (var d = 0u; d < TILE; d = d + 1u) {
        acc = acc + tile_q[lid.x * TILE + d] * tile_k[lid.y * TILE + d];
      }
      // Keep the tiles intact until every invocation has finished reading them.
      workgroupBarrier();
    }

    if (row_ok && k_tok <= row) {
      let score_idx = batch * H * L * L + head * L * L + row * L + k_tok;
      scores[score_idx] = acc * inv_sqrt;
    }
  }
}

// Phase 2: softmax is dispatched separately via the softmax_forward kernel.

// Phase 3: weighted sum of V.
@compute @workgroup_size(16, 16, 1)
fn attention_value(@builtin(global_invocation_id) gid: vec3<u32>,
                   @builtin(local_invocation_id) lid: vec3<u32>,
                   @builtin(workgroup_id) wid: vec3<u32>) {
  let q_tile = wid.x;
  let head   = wid.y;
  let batch  = wid.z;

  let L  = params.seq_len;
  let H  = params.n_heads;
  let dh = params.d_head;

  let row = q_tile * TILE + lid.x;
  // No barriers in this entry point, so a per-invocation exit is fine.
  if (row >= L) { return; }

  let score_row = batch * H * L * L + head * L * L + row * L;
  let head_off  = batch * L * H * dh + head * dh;

  // lid.y strides over d_head so every output dimension is written.
  for (var d = lid.y; d < dh; d = d + TILE) {
    var acc = 0.0;
    for (var k: u32 = 0u; k <= row; k = k + 1u) {
      acc = acc + scores[score_row + k] * V[head_off + k * H * dh + d];
    }
    out_buf[head_off + row * H * dh + d] = acc;
  }
}
`;
|
|
195
|
+
// ── Attention backward ────────────────────────────────────────────────────────
|
|
196
|
+
export const ATTENTION_BACKWARD_WGSL = /* wgsl */ `
|
|
197
|
+
struct AttnParams {
|
|
198
|
+
batch : u32,
|
|
199
|
+
seq_len : u32,
|
|
200
|
+
d_model : u32,
|
|
201
|
+
n_heads : u32,
|
|
202
|
+
d_head : u32,
|
|
203
|
+
};
|
|
204
|
+
|
|
205
|
+
@group(0) @binding(0) var<uniform> params : AttnParams;
|
|
206
|
+
@group(0) @binding(1) var<storage, read> Q : array<f32>;
|
|
207
|
+
@group(0) @binding(2) var<storage, read> K : array<f32>;
|
|
208
|
+
@group(0) @binding(3) var<storage, read> V : array<f32>;
|
|
209
|
+
@group(0) @binding(4) var<storage, read> scores : array<f32>; // post-softmax
|
|
210
|
+
@group(0) @binding(5) var<storage, read> dy : array<f32>; // [B,L,H,dh]
|
|
211
|
+
@group(0) @binding(6) var<storage, read_write> dQ : array<f32>;
|
|
212
|
+
@group(0) @binding(7) var<storage, read_write> dK : array<f32>;
|
|
213
|
+
@group(0) @binding(8) var<storage, read_write> dV : array<f32>;
|
|
214
|
+
@group(0) @binding(9) var<storage, read_write> dscores : array<f32>;
|
|
215
|
+
|
|
216
|
+
@compute @workgroup_size(16, 16, 1)
|
|
217
|
+
fn attention_backward(@builtin(global_invocation_id) gid: vec3<u32>,
|
|
218
|
+
@builtin(local_invocation_id) lid: vec3<u32>,
|
|
219
|
+
@builtin(workgroup_id) wid: vec3<u32>) {
|
|
220
|
+
let q_tile = wid.x;
|
|
221
|
+
let head = wid.y;
|
|
222
|
+
let batch = wid.z;
|
|
223
|
+
|
|
224
|
+
let L = params.seq_len;
|
|
225
|
+
let H = params.n_heads;
|
|
226
|
+
let dh = params.d_head;
|
|
227
|
+
let inv_sqrt = 1.0 / sqrt(f32(dh));
|
|
228
|
+
|
|
229
|
+
let row = q_tile * 16u + lid.x;
|
|
230
|
+
let d = lid.y;
|
|
231
|
+
|
|
232
|
+
if (row >= L || d >= dh) { return; }
|
|
233
|
+
|
|
234
|
+
// dV[k, d] += score[row, k] * dy[row, d]
|
|
235
|
+
// dscores[row, k] += dy[row, d] * V[k, d] (before softmax backward)
|
|
236
|
+
for (var k: u32 = 0u; k <= row; k = k + 1u) {
|
|
237
|
+
let s_idx = batch * H * L * L + head * L * L + row * L + k;
|
|
238
|
+
let v_idx = batch * L * H * dh + k * H * dh + head * dh + d;
|
|
239
|
+
let dy_idx = batch * L * H * dh + row * H * dh + head * dh + d;
|
|
240
|
+
|
|
241
|
+
dV[v_idx] = dV[v_idx] + scores[s_idx] * dy[dy_idx];
|
|
242
|
+
dscores[s_idx] = dscores[s_idx] + dy[dy_idx] * V[v_idx];
|
|
243
|
+
}
|
|
244
|
+
|
|
245
|
+
// dQ[row, d] += sum_k dscores_post_softmax[row, k] * K[k, d] * inv_sqrt
|
|
246
|
+
var dq_acc = 0.0;
|
|
247
|
+
for (var k: u32 = 0u; k <= row; k = k + 1u) {
|
|
248
|
+
let ds_idx = batch * H * L * L + head * L * L + row * L + k;
|
|
249
|
+
let k_idx = batch * L * H * dh + k * H * dh + head * dh + d;
|
|
250
|
+
dq_acc = dq_acc + dscores[ds_idx] * K[k_idx];
|
|
251
|
+
}
|
|
252
|
+
let q_idx = batch * L * H * dh + row * H * dh + head * dh + d;
|
|
253
|
+
dQ[q_idx] = dQ[q_idx] + dq_acc * inv_sqrt;
|
|
254
|
+
|
|
255
|
+
// dK[k, d] += dscores[row, k] * Q[row, d] * inv_sqrt (for all rows >= k)
|
|
256
|
+
for (var k: u32 = 0u; k <= row; k = k + 1u) {
|
|
257
|
+
let ds_idx = batch * H * L * L + head * L * L + row * L + k;
|
|
258
|
+
let k_idx = batch * L * H * dh + k * H * dh + head * dh + d;
|
|
259
|
+
dK[k_idx] = dK[k_idx] + dscores[ds_idx] * Q[q_idx] * inv_sqrt;
|
|
260
|
+
}
|
|
261
|
+
}
|
|
262
|
+
`;
|
|
263
|
+
//# sourceMappingURL=attention.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"attention.js","sourceRoot":"","sources":["../../src/kernels/attention.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;GAcG;AAEH,iFAAiF;AAEjF,MAAM,CAAC,MAAM,YAAY,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAuE7C,CAAC;AAEF,iFAAiF;AAEjF,MAAM,CAAC,MAAM,sBAAsB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAwGvD,CAAC;AAEF,iFAAiF;AAEjF,MAAM,CAAC,MAAM,uBAAuB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAkExD,CAAC"}
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
/**
 * complex_ssd.ts – Complex-valued SSD kernels for Mamba-3.
 *
 * Three targeted improvements over Mamba-2 SSD:
 *
 * 1. Complex-valued states
 * h ∈ ℂ^(N/2). In state_carry, the real and imaginary parts of complex
 * component nc are rows 2·nc and 2·nc+1 of an [N2, d_head] block per head.
 * B_proj / C_proj store interleaved (real, imag) f32 pairs.
 * A ∈ ℂ encoded as A_log[H, 2] = [log|A|, arg(A)].
 *
 * 2. Exponential-Trapezoidal (ET) discretisation
 * A_bar = exp(Δ · A) (complex multiply)
 * B_bar = (A_bar − 1) · A⁻¹ · B (exact, complex division)
 *
 * 3. MIMO recurrence (G groups of G inputs/outputs per head)
 * Implemented here with G=1 (SISO) as the default; G>1 is a future
 * extension that enlarges the B/C projections.
 *
 * Buffer layout:
 * x : [B, L, D_inner] real-valued
 * B_proj : [B, L, n_groups, N*2] interleaved complex (re,im)
 * C_proj : [B, L, n_groups, N*2]
 * dt : [B, L, H] real-valued
 * A_log : [H, 2] [log|A|, arg(A)] per head
 * dt_bias : [H]
 * D_vec : [H]
 * out : [B, L, D_inner] real-valued (Re(C·h))
 * state_carry: [n_chunks+1, B, H, N*2, d_head] complex states
 *
 * Dispatch: (n_chunks, H, B)
 *
 * NOTE(review): the forward WGSL computes A_bar as exp(Δ·(log_mag + i·phase))
 * but A⁻¹ as exp(−(log_mag + i·phase)), i.e. two different readings of the
 * A_log encoding; and it accumulates Re(conj(C)·h) = c_re·h_re + c_im·h_im
 * rather than the Re(C·h) stated above. Confirm the intended math.
 */
/**
 * WGSL source for entry point `complex_ssd_forward` (workgroup size 1).
 * Each invocation scans one (chunk, head, batch) triple, writes `out_buf`
 * for that chunk's tokens, and stores the chunk's final state in
 * state_carry slot chunk_id + 1.
 */
export declare const COMPLEX_SSD_FORWARD_WGSL: string;
/**
 * WGSL source for entry point `complex_ssd_backward` (workgroup size 1).
 * Gradient outputs are bound at 9–14: dx, dB, dC, ddt, dA_log, dD_vec.
 */
export declare const COMPLEX_SSD_BACKWARD_WGSL: string;
|
|
33
|
+
//# sourceMappingURL=complex_ssd.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"complex_ssd.d.ts","sourceRoot":"","sources":["../../src/kernels/complex_ssd.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GA6BG;AAEH,eAAO,MAAM,wBAAwB,EAAE,MAiJtC,CAAC;AAIF,eAAO,MAAM,yBAAyB,EAAE,MA8HvC,CAAC"}
|
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* complex_ssd.ts – Complex-valued SSD kernels for Mamba-3.
|
|
3
|
+
*
|
|
4
|
+
* Three targeted improvements over Mamba-2 SSD:
|
|
5
|
+
*
|
|
6
|
+
* 1. Complex-valued states
|
|
7
|
+
* h ∈ ℂ^(N/2) stored as interleaved (real, imag) f32 pairs.
|
|
8
|
+
* A ∈ ℂ encoded as A_log[H, 2] = [log|A|, arg(A)].
|
|
9
|
+
*
|
|
10
|
+
* 2. Exponential-Trapezoidal (ET) discretisation
|
|
11
|
+
* A_bar = exp(Δ · A) (complex multiply)
|
|
12
|
+
* B_bar = (A_bar − 1) · A⁻¹ · B (exact, complex division)
|
|
13
|
+
*
|
|
14
|
+
* 3. MIMO recurrence (G groups of G inputs/outputs per head)
|
|
15
|
+
* Implemented here with G=1 (SISO) as the default; G>1 is a future
|
|
16
|
+
* extension that enlarges the B/C projections.
|
|
17
|
+
*
|
|
18
|
+
* Buffer layout:
|
|
19
|
+
* x : [B, L, D_inner] real-valued
|
|
20
|
+
* B_proj : [B, L, n_groups, N*2] interleaved complex (re,im)
|
|
21
|
+
* C_proj : [B, L, n_groups, N*2]
|
|
22
|
+
* dt : [B, L, H] real-valued
|
|
23
|
+
* A_log : [H, 2] [log|A|, arg(A)] per head
|
|
24
|
+
* dt_bias : [H]
|
|
25
|
+
* D_vec : [H]
|
|
26
|
+
* out : [B, L, D_inner] real-valued (Re(C·h))
|
|
27
|
+
* state_carry: [n_chunks+1, B, H, N*2, d_head] complex states
|
|
28
|
+
*
|
|
29
|
+
* Dispatch: (n_chunks, H, B)
|
|
30
|
+
*/
|
|
31
|
+
/**
 * WGSL source for `complex_ssd_forward` (workgroup size 1, dispatch
 * (n_chunks, H, B)): complex-state selective scan for Mamba-3.
 *
 * state_carry holds n_chunks + 1 slots of [B, H, N2, d_head] floats. Slot 0 is
 * the initial state (read, never written here). Chunk `chunk_id` writes its
 * final state to slot chunk_id + 1.
 *
 * Workgroups of one dispatch run in no guaranteed order. The previous version
 * had chunk c read slot c while chunk c-1 was still writing it, which is a data
 * race whenever n_chunks > 1. Now every chunk seeds its own slot from slot 0,
 * replays the recurrence from t = 0, and only writes `out_buf` for its own
 * [t_start, t_end) range. Each step uses the same arithmetic as a serial scan,
 * so outputs match the intended sequential result.
 *
 * Also: softplus now uses a numerically stable form (the naive
 * log(1 + exp(v)) overflows to +inf for v above ~88 in f32), and the
 * loop-invariant A⁻¹ is computed once per invocation.
 *
 * NOTE(review): the output is Re(conj(C)·h) = c_re·h_re + c_im·h_im, while
 * the file header documents Re(C·h) — confirm which the backward expects.
 * NOTE(review): A_bar uses exp(Δ·(log_mag + i·phase)) but the B_bar factor uses
 * A⁻¹ = exp(−(log_mag + i·phase)); the header describes A = exp(log_mag)·e^(i·phase).
 * These are two different encodings of A — confirm against the parameter init.
 */
export const COMPLEX_SSD_FORWARD_WGSL = /* wgsl */ `
struct CssdParams {
  seq_len   : u32,
  d_inner   : u32,
  n_heads   : u32,
  d_head    : u32,
  n_groups  : u32,
  n_complex : u32, // N/2 – number of complex state components
  chunk_len : u32,
  n_chunks  : u32,
  batch     : u32,
};

@group(0) @binding(0) var<uniform> params : CssdParams;
@group(0) @binding(1) var<storage, read> x_in : array<f32>;
@group(0) @binding(2) var<storage, read> B_proj : array<f32>; // complex: N_c*2 per token
@group(0) @binding(3) var<storage, read> C_proj : array<f32>;
@group(0) @binding(4) var<storage, read> dt_in : array<f32>;
@group(0) @binding(5) var<storage, read> A_log : array<f32>; // [H, 2]
@group(0) @binding(6) var<storage, read> dt_bias : array<f32>;
@group(0) @binding(7) var<storage, read> D_vec : array<f32>;
@group(0) @binding(8) var<storage, read_write> out_buf : array<f32>;
@group(0) @binding(9) var<storage, read_write> state_carry : array<f32>; // complex states

// Numerically stable softplus: log(1 + e^v) = max(v, 0) + log(1 + e^-|v|).
fn softplus(v: f32) -> f32 { return max(v, 0.0) + log(1.0 + exp(-abs(v))); }

// Complex multiply: (ar + i·ai) * (br + i·bi)
fn cmul_re(ar: f32, ai: f32, br: f32, bi: f32) -> f32 { return ar*br - ai*bi; }
fn cmul_im(ar: f32, ai: f32, br: f32, bi: f32) -> f32 { return ar*bi + ai*br; }

// Complex exp: exp(x + i·y) = exp(x)*(cos(y) + i*sin(y))
fn cexp_re(x: f32, y: f32) -> f32 { return exp(x) * cos(y); }
fn cexp_im(x: f32, y: f32) -> f32 { return exp(x) * sin(y); }

@compute @workgroup_size(1, 1, 1)
fn complex_ssd_forward(@builtin(global_invocation_id) gid: vec3<u32>) {
  let chunk_id = gid.x;
  let head_id  = gid.y;
  let batch_id = gid.z;

  let L  = params.seq_len;
  let D  = params.d_inner;
  let H  = params.n_heads;
  let dh = params.d_head;
  let G  = params.n_groups;
  let Nc = params.n_complex; // complex state count
  let N2 = Nc * 2u;          // floats per complex state vector
  let CL = params.chunk_len;
  let B  = params.batch;

  let t_start  = chunk_id * CL;
  let t_end    = min(t_start + CL, L);
  let group_id = head_id * G / H;

  // Per-head parameters; A_log holds [log_mag, phase].
  let log_mag = A_log[head_id * 2u + 0u];
  let phase   = A_log[head_id * 2u + 1u];
  let db      = dt_bias[head_id];
  let d_skip  = D_vec[head_id];

  // Loop-invariant A^-1 = exp(-log_mag - i*phase) for the ET B_bar factor.
  let inv_a_re = cexp_re(-log_mag, -phase);
  let inv_a_im = cexp_im(-log_mag, -phase);

  // state_carry: [n_chunks + 1, B, H, N2, dh]
  let head_state     = N2 * dh;
  let state_stride   = B * H * head_state;
  let state_offset   = batch_id * H * head_state + head_id * head_state; // slot 0
  let state_base_out = (chunk_id + 1u) * state_stride + state_offset;

  // Seed this chunk's slot from the initial state. Only this invocation writes
  // slot chunk_id + 1, and nobody writes slot 0, so there is no cross-chunk race.
  for (var s: u32 = 0u; s < head_state; s = s + 1u) {
    state_carry[state_base_out + s] = state_carry[state_offset + s];
  }

  // Replay t in [0, t_start) to rebuild the carry-in, then emit [t_start, t_end).
  for (var t: u32 = 0u; t < t_end; t = t + 1u) {
    let emit = t >= t_start;

    let dt_val = softplus(dt_in[batch_id * L * H + t * H + head_id] + db);

    // A_bar = exp(dt*log_mag + i*dt*phase)
    let a_bar_re = cexp_re(dt_val * log_mag, dt_val * phase);
    let a_bar_im = cexp_im(dt_val * log_mag, dt_val * phase);

    // ET B_bar scalar factor = (A_bar - 1) * A^-1 (applied per B_proj element)
    let bbar_re = cmul_re(a_bar_re - 1.0, a_bar_im, inv_a_re, inv_a_im);
    let bbar_im = cmul_im(a_bar_re - 1.0, a_bar_im, inv_a_re, inv_a_im);

    let x_base = batch_id * L * D + t * D + head_id * dh;
    // B_proj / C_proj: [B, L, G, N2] — interleaved re/im
    let bc_base = batch_id * L * G * N2 + t * G * N2 + group_id * N2;

    for (var i: u32 = 0u; i < dh; i = i + 1u) {
      let x_val = x_in[x_base + i];
      var y_re = 0.0;

      for (var nc: u32 = 0u; nc < Nc; nc = nc + 1u) {
        let b_re = B_proj[bc_base + nc * 2u + 0u];
        let b_im = B_proj[bc_base + nc * 2u + 1u];
        let c_re = C_proj[bc_base + nc * 2u + 0u];
        let c_im = C_proj[bc_base + nc * 2u + 1u];

        // B_bar · x (complex * real = complex scale)
        let inp_re = cmul_re(bbar_re, bbar_im, b_re, b_im) * x_val;
        let inp_im = cmul_im(bbar_re, bbar_im, b_re, b_im) * x_val;

        // Row 2*nc holds the real parts, row 2*nc + 1 the imaginary parts.
        let s_re_idx = state_base_out + nc * 2u * dh + i;
        let s_im_idx = s_re_idx + dh;

        // h_t = A_bar * h_{t-1} + B_bar * x
        let h_prev_re = state_carry[s_re_idx];
        let h_prev_im = state_carry[s_im_idx];
        let h_new_re = cmul_re(a_bar_re, a_bar_im, h_prev_re, h_prev_im) + inp_re;
        let h_new_im = cmul_im(a_bar_re, a_bar_im, h_prev_re, h_prev_im) + inp_im;
        state_carry[s_re_idx] = h_new_re;
        state_carry[s_im_idx] = h_new_im;

        // y += Re(conj(C) · h) = c_re*h_re + c_im*h_im
        y_re = y_re + cmul_re(c_re, -c_im, h_new_re, h_new_im);
      }

      // Replayed steps only rebuild state; outputs belong to this chunk alone.
      if (emit) {
        out_buf[x_base + i] = y_re + d_skip * x_val;
      }
    }
  }
}
`;
|
|
177
|
+
// ── Backward ──────────────────────────────────────────────────────────────────
|
|
178
|
+
export const COMPLEX_SSD_BACKWARD_WGSL = /* wgsl */ `
|
|
179
|
+
struct CssdParams {
|
|
180
|
+
seq_len : u32,
|
|
181
|
+
d_inner : u32,
|
|
182
|
+
n_heads : u32,
|
|
183
|
+
d_head : u32,
|
|
184
|
+
n_groups : u32,
|
|
185
|
+
n_complex : u32,
|
|
186
|
+
chunk_len : u32,
|
|
187
|
+
n_chunks : u32,
|
|
188
|
+
batch : u32,
|
|
189
|
+
};
|
|
190
|
+
|
|
191
|
+
@group(0) @binding(0) var<uniform> params : CssdParams;
|
|
192
|
+
@group(0) @binding(1) var<storage, read> x_in : array<f32>;
|
|
193
|
+
@group(0) @binding(2) var<storage, read> B_proj : array<f32>;
|
|
194
|
+
@group(0) @binding(3) var<storage, read> C_proj : array<f32>;
|
|
195
|
+
@group(0) @binding(4) var<storage, read> dt_in : array<f32>;
|
|
196
|
+
@group(0) @binding(5) var<storage, read> A_log : array<f32>;
|
|
197
|
+
@group(0) @binding(6) var<storage, read> dt_bias : array<f32>;
|
|
198
|
+
@group(0) @binding(7) var<storage, read> state_carry : array<f32>;
|
|
199
|
+
@group(0) @binding(8) var<storage, read> dy : array<f32>;
|
|
200
|
+
@group(0) @binding(9) var<storage, read_write> dx : array<f32>;
|
|
201
|
+
@group(0) @binding(10) var<storage, read_write> dB : array<f32>;
|
|
202
|
+
@group(0) @binding(11) var<storage, read_write> dC : array<f32>;
|
|
203
|
+
@group(0) @binding(12) var<storage, read_write> ddt : array<f32>;
|
|
204
|
+
@group(0) @binding(13) var<storage, read_write> dA_log : array<f32>;
|
|
205
|
+
@group(0) @binding(14) var<storage, read_write> dD_vec : array<f32>;
|
|
206
|
+
|
|
207
|
+
fn softplus(v: f32) -> f32 { return log(1.0 + exp(v)); }
|
|
208
|
+
fn d_softplus(v: f32) -> f32 { return 1.0 / (1.0 + exp(-v)); }
|
|
209
|
+
fn cmul_re(ar: f32, ai: f32, br: f32, bi: f32) -> f32 { return ar*br - ai*bi; }
|
|
210
|
+
fn cmul_im(ar: f32, ai: f32, br: f32, bi: f32) -> f32 { return ar*bi + ai*br; }
|
|
211
|
+
fn cexp_re(x: f32, y: f32) -> f32 { return exp(x) * cos(y); }
|
|
212
|
+
fn cexp_im(x: f32, y: f32) -> f32 { return exp(x) * sin(y); }
|
|
213
|
+
|
|
214
|
+
// Backward pass of the chunked complex-diagonal SSD scan.
//
// One invocation per (chunk, head, batch): gid = (chunk_id, head_id, batch_id).
// Each invocation walks its chunk's timesteps from last to first. Every
// gradient buffer is updated with read-add-write, so the caller is presumably
// expected to zero them before dispatch (TODO confirm).
//
// Simplifications kept from the original kernel:
//   * the Re(C*h) term and the dB/dx paths use the state in state_carry slot
//     chunk_id + 1, not a per-timestep state;
//   * dh is not carried backward through A_bar between timesteps;
//   * B_bar is treated as B (no dt or A dependence in the dB/dx paths).
//
// NOTE(review): dD_vec[head_id] is shared by every chunk and batch of a head,
// and each dB/dC row is shared by every head mapped to the same group, so
// those read-modify-write updates race across invocations. WGSL has no
// atomic f32; a deterministic fix needs per-invocation partials plus a
// reduction pass.
// NOTE(review): dA_log is bound read_write but never written here, so A_log
// receives no gradient from this kernel.
// NOTE(review): the skip path adds dy to dx unscaled; if the forward adds
// D[h] * x this should be D[h] * dy. TODO confirm against the forward kernel.
@compute @workgroup_size(1, 1, 1)
fn complex_ssd_backward(@builtin(global_invocation_id) gid: vec3<u32>) {
  let chunk_id = gid.x;
  let head_id = gid.y;
  let batch_id = gid.z;

  let L = params.seq_len;
  let D = params.d_inner;
  let H = params.n_heads;
  let dh = params.d_head;
  let G = params.n_groups;
  let Nc = params.n_complex;
  let N2 = Nc * 2u;
  let CL = params.chunk_len;
  let B = params.batch;

  let t_start = chunk_id * CL;
  // Guard over-dispatch: if t_start > L, (t_end - t_start) below would wrap
  // around in u32 and the time loop would read far out of bounds.
  if (t_start >= L || head_id >= H || batch_id >= B) { return; }
  let t_end = min(t_start + CL, L);
  let group_id = head_id * G / H;

  // Per-head eigenvalue packed as (log_mag, phase); below the discretized
  // decay is formed as A_bar = exp(dt * (log_mag + i*phase)).
  let log_mag = A_log[head_id * 2u + 0u];
  let phase = A_log[head_id * 2u + 1u];
  let db = dt_bias[head_id];

  // state_carry holds one (B, H, Nc, 2, dh) snapshot per slot; this kernel
  // reads slot chunk_id (state_prev) and slot chunk_id + 1 (state_base).
  let state_stride = B * H * N2 * dh;
  let head_state = batch_id * H * N2 * dh + head_id * N2 * dh;
  let state_base = (chunk_id + 1u) * state_stride + head_state;
  let state_prev = chunk_id * state_stride + head_state;

  for (var t_rev: u32 = 0u; t_rev < t_end - t_start; t_rev = t_rev + 1u) {
    let t = t_end - 1u - t_rev;

    let dt_idx = batch_id * L * H + t * H + head_id;
    let dt_raw = dt_in[dt_idx] + db;
    let dt_val = softplus(dt_raw);
    let a_bar_re = cexp_re(dt_val * log_mag, dt_val * phase);
    let a_bar_im = cexp_im(dt_val * log_mag, dt_val * phase);

    // dA_bar/ddt = A * A_bar with A = log_mag + i*phase (the derivative of
    // the exponent), not exp(A) * A_bar. d_softplus carries the chain rule
    // from dt back to dt_raw. All of this depends only on t.
    let da_bar_re = cmul_re(log_mag, phase, a_bar_re, a_bar_im);
    let da_bar_im = cmul_im(log_mag, phase, a_bar_re, a_bar_im);
    let dsp = d_softplus(dt_raw);

    let x_base = batch_id * L * D + t * D + head_id * dh;
    let bc_base = batch_id * L * G * N2 + t * G * N2 + group_id * N2;

    for (var i: u32 = 0u; i < dh; i = i + 1u) {
      // dy is indexed with the same (batch, t, head, i) layout as x.
      let dy_val = dy[x_base + i];
      let x_val = x_in[x_base + i];

      // Skip path.
      dD_vec[head_id] = dD_vec[head_id] + dy_val * x_val;
      dx[x_base + i] = dx[x_base + i] + dy_val;

      for (var nc: u32 = 0u; nc < Nc; nc = nc + 1u) {
        let bc_re = bc_base + nc * 2u + 0u;
        let bc_im = bc_base + nc * 2u + 1u;
        let c_re = C_proj[bc_re];
        let c_im = C_proj[bc_im];
        let b_re = B_proj[bc_re];
        let b_im = B_proj[bc_im];

        let s_re = nc * 2u * dh + 0u * dh + i;
        let s_im = nc * 2u * dh + 1u * dh + i;
        let h_re = state_carry[state_base + s_re];
        let h_im = state_carry[state_base + s_im];
        let h_prev_re = state_carry[state_prev + s_re];
        let h_prev_im = state_carry[state_prev + s_im];

        // Output term Re(C * h) = c_re*h_re - c_im*h_im. All gradients below
        // are real partials with respect to each (re, im) component.
        dC[bc_re] = dC[bc_re] + dy_val * h_re;
        dC[bc_im] = dC[bc_im] - dy_val * h_im;
        let dh_re = c_re * dy_val;
        let dh_im = -c_im * dy_val;

        // h += B * x (B_bar Jacobian ignored).
        dB[bc_re] = dB[bc_re] + dh_re * x_val;
        dB[bc_im] = dB[bc_im] + dh_im * x_val;
        // dx += Re(conj(B) * dh) = b_re*dh_re + b_im*dh_im.
        dx[x_base + i] = dx[x_base + i] + cmul_re(b_re, -b_im, dh_re, dh_im);

        // h += A_bar * h_prev, so dh/ddt = (dA_bar/ddt) * h_prev. With real
        // partials (dh_re, dh_im) both components enter with a plus sign,
        // exactly as in the dx term above.
        let g_re = cmul_re(da_bar_re, da_bar_im, h_prev_re, h_prev_im);
        let g_im = cmul_im(da_bar_re, da_bar_im, h_prev_re, h_prev_im);
        ddt[dt_idx] = ddt[dt_idx] + (g_re * dh_re + g_im * dh_im) * dsp;
      }
    }
  }
}
|
|
304
|
+
`;
|
|
305
|
+
//# sourceMappingURL=complex_ssd.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"complex_ssd.js","sourceRoot":"","sources":["../../src/kernels/complex_ssd.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GA6BG;AAEH,MAAM,CAAC,MAAM,wBAAwB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAiJzD,CAAC;AAEF,iFAAiF;AAEjF,MAAM,CAAC,MAAM,yBAAyB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CA8H1D,CAAC"}
|
|
@@ -0,0 +1,3 @@
|
|
|
1
|
+
/**
 * WGSL source of the causal depthwise 1-D convolution forward kernel
 * (entry point conv1d_forward, workgroup_size 16x16x1, bindings 0-4).
 * Its ConvParams struct declares a trailing `groups` field that the kernel
 * body never reads and that the backward shader's ConvParams omits.
 */
export declare const CONV1D_FORWARD_WGSL = "\n\nstruct ConvParams {\n seq_len : u32, // L\n d_channels : u32, // D (number of depthwise channels in this call)\n kernel_size : u32, // K (typically 4)\n batch : u32, // B\n groups : u32, // number of channel groups (1 = standard depthwise)\n};\n\n@group(0) @binding(0) var<uniform> params : ConvParams;\n// x (B, L, D) \u2013 input\n@group(0) @binding(1) var<storage, read> x : array<f32>;\n// weight (D, K) \u2013 depthwise conv weights\n@group(0) @binding(2) var<storage, read> weight : array<f32>;\n// bias (D,) \u2013 optional bias (zeros if unused)\n@group(0) @binding(3) var<storage, read> bias : array<f32>;\n// y (B, L, D) \u2013 output\n@group(0) @binding(4) var<storage, read_write> y : array<f32>;\n\n// Dispatch: (ceil(L/16), ceil(D/16), B)\n@compute @workgroup_size(16, 16, 1)\nfn conv1d_forward(\n @builtin(global_invocation_id) gid : vec3<u32>,\n) {\n let L = params.seq_len;\n let D = params.d_channels;\n let K = params.kernel_size;\n let B = params.batch;\n\n let t = gid.x; // time position\n let d = gid.y; // channel\n let b = gid.z; // batch\n\n if (t >= L || d >= D || b >= B) { return; }\n\n var acc: f32 = 0.0;\n\n // Causal: convolve over k = 0..K-1, reading position (t - k)\n for (var k: u32 = 0u; k < K; k = k + 1u) {\n let w_idx = d * K + k;\n let w_val = weight[w_idx];\n\n // t - k: use causal zero-padding for t < k\n if (t >= k) {\n let src = b * L * D + (t - k) * D + d;\n acc = acc + w_val * x[src];\n }\n // else: zero-padding contributes 0\n }\n\n acc = acc + bias[d];\n\n let out = b * L * D + t * D + d;\n y[out] = acc;\n}\n";
|
|
2
|
+
/**
 * WGSL source of the conv1d backward kernels:
 *   - conv1d_backward_dx (workgroup_size 16x16x1) writes dx;
 *   - conv1d_backward_dw (workgroup_size 1x1x1, one invocation per (k, d))
 *     writes dweight and, for k == 0, dbias.
 * Both kernels overwrite their output buffers rather than accumulating.
 */
export declare const CONV1D_BACKWARD_WGSL = "\n\nstruct ConvParams {\n seq_len : u32,\n d_channels : u32,\n kernel_size : u32,\n batch : u32,\n};\n\n@group(0) @binding(0) var<uniform> params : ConvParams;\n@group(0) @binding(1) var<storage, read> x : array<f32>;\n@group(0) @binding(2) var<storage, read> weight : array<f32>;\n@group(0) @binding(3) var<storage, read> dy : array<f32>;\n@group(0) @binding(4) var<storage, read_write> dx : array<f32>;\n@group(0) @binding(5) var<storage, read_write> dweight : array<f32>;\n@group(0) @binding(6) var<storage, read_write> dbias : array<f32>;\n\n// Dispatch: (ceil(L/16), ceil(D/16), B) \u2013 computes dx\n@compute @workgroup_size(16, 16, 1)\nfn conv1d_backward_dx(\n @builtin(global_invocation_id) gid : vec3<u32>,\n) {\n let L = params.seq_len;\n let D = params.d_channels;\n let K = params.kernel_size;\n let B = params.batch;\n\n let t = gid.x;\n let d = gid.y;\n let b = gid.z;\n\n if (t >= L || d >= D || b >= B) { return; }\n\n var grad: f32 = 0.0;\n\n // dx[b, t, d] = sum_{k=0}^{K-1} dy[b, t+k, d] * weight[d, k]\n for (var k: u32 = 0u; k < K; k = k + 1u) {\n let tp = t + k;\n if (tp < L) {\n let dy_idx = b * L * D + tp * D + d;\n let w_idx = d * K + k;\n grad = grad + dy[dy_idx] * weight[w_idx];\n }\n }\n\n let dx_idx = b * L * D + t * D + d;\n dx[dx_idx] = grad;\n}\n\n// Dispatch: (K, D, 1) \u2013 accumulates dweight over (B, L)\n@compute @workgroup_size(1, 1, 1)\nfn conv1d_backward_dw(\n @builtin(global_invocation_id) gid : vec3<u32>,\n) {\n let L = params.seq_len;\n let D = params.d_channels;\n let K = params.kernel_size;\n let B = params.batch;\n\n let k = gid.x;\n let d = gid.y;\n\n if (k >= K || d >= D) { return; }\n\n var grad_w: f32 = 0.0;\n var grad_b: f32 = 0.0;\n\n for (var b: u32 = 0u; b < B; b = b + 1u) {\n for (var t: u32 = 0u; t < L; t = t + 1u) {\n let dy_idx = b * L * D + t * D + d;\n let dy_val = dy[dy_idx];\n if (t >= k) {\n let x_idx = b * L * D + (t - k) * D + d;\n grad_w = grad_w + dy_val * x[x_idx];\n 
}\n if (k == 0u) {\n grad_b = grad_b + dy_val;\n }\n }\n }\n\n dweight[d * K + k] = grad_w;\n if (k == 0u) {\n dbias[d] = grad_b;\n }\n}\n";
|
|
3
|
+
//# sourceMappingURL=conv1d.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"conv1d.d.ts","sourceRoot":"","sources":["../../src/kernels/conv1d.ts"],"names":[],"mappings":"AAaA,eAAO,MAAM,mBAAmB,mxDAwD/B,CAAC;AAGF,eAAO,MAAM,oBAAoB,y6EAsFhC,CAAC"}
|