@seanhogg/builderforce-memory-engine 2026.6.18

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. package/LICENSE +21 -0
  2. package/README.md +393 -0
  3. package/dist/index.d.ts +32 -0
  4. package/dist/index.d.ts.map +1 -0
  5. package/dist/index.js +40 -0
  6. package/dist/index.js.map +1 -0
  7. package/dist/kernels/activations.d.ts +5 -0
  8. package/dist/kernels/activations.d.ts.map +1 -0
  9. package/dist/kernels/activations.js +171 -0
  10. package/dist/kernels/activations.js.map +1 -0
  11. package/dist/kernels/attention.d.ts +19 -0
  12. package/dist/kernels/attention.d.ts.map +1 -0
  13. package/dist/kernels/attention.js +263 -0
  14. package/dist/kernels/attention.js.map +1 -0
  15. package/dist/kernels/complex_ssd.d.ts +33 -0
  16. package/dist/kernels/complex_ssd.d.ts.map +1 -0
  17. package/dist/kernels/complex_ssd.js +305 -0
  18. package/dist/kernels/complex_ssd.js.map +1 -0
  19. package/dist/kernels/conv1d.d.ts +3 -0
  20. package/dist/kernels/conv1d.d.ts.map +1 -0
  21. package/dist/kernels/conv1d.js +158 -0
  22. package/dist/kernels/conv1d.js.map +1 -0
  23. package/dist/kernels/linear_projection.d.ts +3 -0
  24. package/dist/kernels/linear_projection.d.ts.map +1 -0
  25. package/dist/kernels/linear_projection.js +219 -0
  26. package/dist/kernels/linear_projection.js.map +1 -0
  27. package/dist/kernels/selective_scan.d.ts +3 -0
  28. package/dist/kernels/selective_scan.d.ts.map +1 -0
  29. package/dist/kernels/selective_scan.js +348 -0
  30. package/dist/kernels/selective_scan.js.map +1 -0
  31. package/dist/kernels/ssd.d.ts +29 -0
  32. package/dist/kernels/ssd.d.ts.map +1 -0
  33. package/dist/kernels/ssd.js +276 -0
  34. package/dist/kernels/ssd.js.map +1 -0
  35. package/dist/kernels/weight_update.d.ts +3 -0
  36. package/dist/kernels/weight_update.d.ts.map +1 -0
  37. package/dist/kernels/weight_update.js +119 -0
  38. package/dist/kernels/weight_update.js.map +1 -0
  39. package/dist/model/attention_block.d.ts +48 -0
  40. package/dist/model/attention_block.d.ts.map +1 -0
  41. package/dist/model/attention_block.js +262 -0
  42. package/dist/model/attention_block.js.map +1 -0
  43. package/dist/model/mamba1_block.d.ts +70 -0
  44. package/dist/model/mamba1_block.d.ts.map +1 -0
  45. package/dist/model/mamba1_block.js +333 -0
  46. package/dist/model/mamba1_block.js.map +1 -0
  47. package/dist/model/mamba2_block.d.ts +44 -0
  48. package/dist/model/mamba2_block.d.ts.map +1 -0
  49. package/dist/model/mamba2_block.js +252 -0
  50. package/dist/model/mamba2_block.js.map +1 -0
  51. package/dist/model/mamba3_block.d.ts +51 -0
  52. package/dist/model/mamba3_block.d.ts.map +1 -0
  53. package/dist/model/mamba3_block.js +270 -0
  54. package/dist/model/mamba3_block.js.map +1 -0
  55. package/dist/model/mamba_block.d.ts +64 -0
  56. package/dist/model/mamba_block.d.ts.map +1 -0
  57. package/dist/model/mamba_block.js +303 -0
  58. package/dist/model/mamba_block.js.map +1 -0
  59. package/dist/model/mamba_model.d.ts +140 -0
  60. package/dist/model/mamba_model.d.ts.map +1 -0
  61. package/dist/model/mamba_model.js +527 -0
  62. package/dist/model/mamba_model.js.map +1 -0
  63. package/dist/model/sequence_layer.d.ts +25 -0
  64. package/dist/model/sequence_layer.d.ts.map +1 -0
  65. package/dist/model/sequence_layer.js +8 -0
  66. package/dist/model/sequence_layer.js.map +1 -0
  67. package/dist/tokenizer/bpe.d.ts +29 -0
  68. package/dist/tokenizer/bpe.d.ts.map +1 -0
  69. package/dist/tokenizer/bpe.js +164 -0
  70. package/dist/tokenizer/bpe.js.map +1 -0
  71. package/dist/training/autograd.d.ts +27 -0
  72. package/dist/training/autograd.d.ts.map +1 -0
  73. package/dist/training/autograd.js +120 -0
  74. package/dist/training/autograd.js.map +1 -0
  75. package/dist/training/trainer.d.ts +36 -0
  76. package/dist/training/trainer.d.ts.map +1 -0
  77. package/dist/training/trainer.js +183 -0
  78. package/dist/training/trainer.js.map +1 -0
  79. package/dist/utils/gpu_utils.d.ts +21 -0
  80. package/dist/utils/gpu_utils.d.ts.map +1 -0
  81. package/dist/utils/gpu_utils.js +111 -0
  82. package/dist/utils/gpu_utils.js.map +1 -0
  83. package/dist/utils/quantization.d.ts +26 -0
  84. package/dist/utils/quantization.d.ts.map +1 -0
  85. package/dist/utils/quantization.js +116 -0
  86. package/dist/utils/quantization.js.map +1 -0
  87. package/dist/utils/rng.d.ts +36 -0
  88. package/dist/utils/rng.d.ts.map +1 -0
  89. package/dist/utils/rng.js +61 -0
  90. package/dist/utils/rng.js.map +1 -0
  91. package/package.json +99 -0
  92. package/src/index.ts +114 -0
  93. package/src/kernels/activations.ts +174 -0
  94. package/src/kernels/attention.ts +268 -0
  95. package/src/kernels/complex_ssd.ts +307 -0
  96. package/src/kernels/conv1d.ts +159 -0
  97. package/src/kernels/linear_projection.ts +220 -0
  98. package/src/kernels/selective_scan.ts +350 -0
  99. package/src/kernels/ssd.ts +278 -0
  100. package/src/kernels/weight_update.ts +120 -0
  101. package/src/model/attention_block.ts +344 -0
  102. package/src/model/mamba1_block.ts +437 -0
  103. package/src/model/mamba2_block.ts +319 -0
  104. package/src/model/mamba3_block.ts +335 -0
  105. package/src/model/mamba_block.ts +401 -0
  106. package/src/model/mamba_model.ts +678 -0
  107. package/src/model/sequence_layer.ts +29 -0
  108. package/src/tokenizer/bpe.ts +186 -0
  109. package/src/training/autograd.ts +135 -0
  110. package/src/training/trainer.ts +309 -0
  111. package/src/utils/gpu_utils.ts +147 -0
  112. package/src/utils/quantization.ts +154 -0
  113. package/src/utils/rng.ts +65 -0
@@ -0,0 +1,120 @@
1
// Weight Update WGSL Kernel (AdamW Optimizer)
// Implements a fused AdamW parameter update on the GPU: each invocation updates
// one element of param, m_state and v_state in place.
//
// AdamW update rule:
//   m_t     = beta1 * m_{t-1} + (1 - beta1) * g_t
//   v_t     = beta2 * v_{t-1} + (1 - beta2) * g_t^2
//   m_hat   = m_t / (1 - beta1^t)
//   v_hat   = v_t / (1 - beta2^t)
//   theta_t = theta_{t-1} * (1 - lr * weight_decay) - lr * m_hat / (sqrt(v_hat) + eps)
//
// Host contract, as read from the kernel below:
//   - AdamParams is 8 consecutive 4-byte scalars (one u32, then seven f32),
//     32 bytes in total.
//   - beta1_t / beta2_t are the powers beta^t supplied by the host. At t = 0
//     both equal 1 and the bias-correction divisions divide by zero, so the
//     step counter presumably starts at t = 1. TODO confirm against the trainer.
//   - Bindings: 0 = AdamParams (uniform), 1 = param (read_write),
//     2 = grad (read), 3 = m_state (read_write), 4 = v_state (read_write).
//     All are f32 arrays with at least num_elements entries.
//   - Dispatch ceil(num_elements / 256) workgroups along x.

export const WEIGHT_UPDATE_WGSL: string = /* wgsl */`

struct AdamParams {
num_elements : u32,
lr : f32, // learning rate
beta1 : f32, // default 0.9
beta2 : f32, // default 0.999
eps : f32, // default 1e-8
weight_decay : f32, // default 0.01
beta1_t : f32, // beta1^t (precomputed bias correction term)
beta2_t : f32, // beta2^t
};

@group(0) @binding(0) var<uniform> adam : AdamParams;
// param (N,) – weight tensor (read-write: updated in-place)
@group(0) @binding(1) var<storage, read_write> param : array<f32>;
// grad (N,) – gradient
@group(0) @binding(2) var<storage, read> grad : array<f32>;
// m (N,) – first moment
@group(0) @binding(3) var<storage, read_write> m_state : array<f32>;
// v (N,) – second moment
@group(0) @binding(4) var<storage, read_write> v_state : array<f32>;

// Dispatch: (ceil(N / 256), 1, 1)
@compute @workgroup_size(256, 1, 1)
fn adamw_update(
@builtin(global_invocation_id) gid : vec3<u32>,
) {
let i = gid.x;
if (i >= adam.num_elements) { return; }

let g = grad[i];
let p = param[i];

// Moment updates
let m_new = adam.beta1 * m_state[i] + (1.0 - adam.beta1) * g;
let v_new = adam.beta2 * v_state[i] + (1.0 - adam.beta2) * g * g;
m_state[i] = m_new;
v_state[i] = v_new;

// Bias-corrected estimates
let m_hat = m_new / (1.0 - adam.beta1_t);
let v_hat = v_new / (1.0 - adam.beta2_t);

// Weight decay (decoupled) + gradient step
param[i] = p * (1.0 - adam.lr * adam.weight_decay) -
adam.lr * m_hat / (sqrt(v_hat) + adam.eps);
}
`;
60
+
61
// Gradient clipping kernels: rescale gradients so their global L2 norm does not
// exceed max_norm. Run before the weight update, as two separate dispatches:
//   1. grad_norm_reduce – sum of squares of grad[0..num_elements) into norm_sq[0].
//   2. grad_clip_scale  – if norm_sq[0] > max_norm^2, multiply every gradient by
//                         max_norm / ||g||.
//
// norm_sq[0] is shared by ALL workgroups of pass 1 (each handles 256 elements).
// A plain read-modify-write on it races as soon as a second workgroup exists, so
// per-workgroup partial sums are merged with an atomic compare-exchange loop.
// WGSL has no atomic f32 add, so the word is declared atomic<u32> and holds raw
// f32 bits. The summation order is therefore scheduling-dependent, and the norm
// may differ in the last bits between runs.
//
// Host contract:
//   - Zero norm_sq[0] before every pass-1 dispatch (bit pattern 0u == 0.0f).
//   - The buffer is still a single 4-byte word of f32 bits and can be read back
//     as a Float32Array.
//   - Dispatch ceil(num_elements / 256) workgroups along x for both passes.
export const GRAD_CLIP_WGSL: string = /* wgsl */`

struct ClipParams {
  num_elements : u32,
  max_norm_sq  : f32,   // max_norm^2
};

@group(0) @binding(0) var<uniform> clip_p : ClipParams;
@group(0) @binding(1) var<storage, read_write> grad : array<f32>;
// Size 1: global sum of squares, stored as f32 bits in an atomic word.
@group(0) @binding(2) var<storage, read_write> norm_sq : array<atomic<u32>>;

var<workgroup> local_sq : array<f32, 256>;

// Pass 1: reduce sum of squares into norm_sq[0]
@compute @workgroup_size(256, 1, 1)
fn grad_norm_reduce(
  @builtin(global_invocation_id) gid : vec3<u32>,
  @builtin(local_invocation_index) lid : u32,
) {
  let i = gid.x;
  var sq = 0.0;
  if (i < clip_p.num_elements) {
    let g = grad[i];
    sq = g * g;
  }
  local_sq[lid] = sq;
  workgroupBarrier();

  // Tree reduction within the workgroup; s is uniform so barriers stay legal.
  var s: u32 = 128u;
  loop {
    if (s == 0u) { break; }
    if (lid < s) {
      local_sq[lid] = local_sq[lid] + local_sq[lid + s];
    }
    workgroupBarrier();
    s = s >> 1u;
  }

  // Merge this workgroup's partial sum into the global accumulator.
  if (lid == 0u) {
    let partial = local_sq[0];
    var expected = atomicLoad(&norm_sq[0]);
    loop {
      let desired = bitcast<u32>(bitcast<f32>(expected) + partial);
      let result = atomicCompareExchangeWeak(&norm_sq[0], expected, desired);
      if (result.exchanged) { break; }
      expected = result.old_value;
    }
  }
}

// Pass 2: scale gradients if norm exceeds max_norm
@compute @workgroup_size(256, 1, 1)
fn grad_clip_scale(
  @builtin(global_invocation_id) gid : vec3<u32>,
) {
  let i = gid.x;
  if (i >= clip_p.num_elements) { return; }

  let ns = bitcast<f32>(atomicLoad(&norm_sq[0]));
  if (ns > clip_p.max_norm_sq) {
    let scale = sqrt(clip_p.max_norm_sq / ns);
    grad[i] = grad[i] * scale;
  }
}
`;
@@ -0,0 +1,344 @@
1
+ /**
2
+ * attention_block.ts – Causal Multi-Head Self-Attention Block.
3
+ *
4
+ * Intentionally simple for WebGPU — naive O(L²) tiled attention,
5
+ * no Flash-Attention dependency. Suitable for hybrid (Jamba/Zamba) schedules
6
+ * where a few attention layers interleave with SSM layers.
7
+ *
8
+ * Data flow:
9
+ * Input (B, L, D_model)
10
+ * └─ RMSNorm
11
+ * └─ wQKV → Q (B,L,H,dh), K (B,L,H,dh), V (B,L,H,dh)
12
+ * └─ causal attention scores = Q·Kᵀ / √dh (masked)
13
+ * └─ softmax
14
+ * └─ weighted V sum
15
+ * └─ concat heads → wO → D_model
16
+ * └─ + residual
17
+ * [optional FFN sublayer]
18
+ *
19
+ * Implements SequenceLayer.
20
+ */
21
+
22
+ import {
23
+ createComputePipeline,
24
+ createBindGroup,
25
+ createStorageBuffer,
26
+ createEmptyStorageBuffer,
27
+ createUniformBuffer,
28
+ dispatchKernel,
29
+ cdiv,
30
+ } from '../utils/gpu_utils.js';
31
+
32
+ import {
33
+ ATTENTION_FORWARD_WGSL,
34
+ SOFTMAX_WGSL,
35
+ } from '../kernels/attention.js';
36
+ import { LINEAR_FORWARD_WGSL } from '../kernels/linear_projection.js';
37
+ import { gaussianArray } from '../utils/rng.js';
38
+ import { ACTIVATIONS_WGSL } from '../kernels/activations.js';
39
+
40
+ import type { SequenceLayer, LayerForwardResult, LayerParam } from './sequence_layer.js';
41
+
42
/** Construction options for an {@link AttentionBlock}. */
export interface AttentionBlockConfig {
  /** Residual-stream width. Must be divisible by `nHeads`. */
  dModel  : number;
  /** Number of attention heads. */
  nHeads  : number;
  /**
   * Per-head width; defaults to `dModel / nHeads`.
   * Q/K/V and output weights are sized from `dModel`, so other values are not supported.
   */
  dHead?  : number;
  /** Append a SiLU feed-forward sublayer of width `ffnMult * dModel` (default `false`). */
  hasFfn? : boolean;
  /** Feed-forward expansion factor (default `4`). */
  ffnMult?: number;
}
49
+
50
/** Activations retained by {@link AttentionBlock.forward} for the backward pass. */
export interface AttentionCache {
  /** Post-softmax attention probabilities, float32, laid out (B, H, L, L). */
  scores: GPUBuffer;
}
53
+
54
// Element-wise add: c[i] = a[i] + b[i] for i < n. AttentionBlock uses it for its
// residual connections. Bindings: 0 = a (read), 1 = b (read), 2 = c (read_write),
// 3 = n (uniform u32). Dispatch ceil(n / 256) workgroups along x.
const ADD_SHADER = /* wgsl */`
@group(0) @binding(0) var<storage, read> a : array<f32>;
@group(0) @binding(1) var<storage, read> b : array<f32>;
@group(0) @binding(2) var<storage, read_write> c : array<f32>;
@group(0) @binding(3) var<uniform> n : u32;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.x;
if (i < n) { c[i] = a[i] + b[i]; }
}
`;
65
+
66
// SiLU activation for the FFN sublayer: y[i] = x[i] / (1 + exp(-x[i])) = x[i] * sigmoid(x[i]).
// Bindings: 0 = ActParams (uniform), 1 = x (read), 2 = y (read_write).
// Dispatch ceil(num_elements / 256) workgroups along x.
// WGSL struct members must be separated by commas. The `;` separator from early
// WGSL drafts is a compile error in current implementations.
const SILU_SHADER = /* wgsl */`
struct ActParams { num_elements: u32, };
@group(0) @binding(0) var<uniform> p : ActParams;
@group(0) @binding(1) var<storage, read> x : array<f32>;
@group(0) @binding(2) var<storage, read_write> y : array<f32>;
@compute @workgroup_size(256, 1, 1)
fn silu_forward(@builtin(global_invocation_id) gid: vec3<u32>) {
  let i = gid.x;
  if (i >= p.num_elements) { return; }
  let v = x[i];
  y[i] = v / (1.0 + exp(-v));
}
`;
80
+
81
/**
 * Causal multi-head self-attention layer with pre-RMSNorm and a residual
 * connection, plus an optional SiLU feed-forward sublayer. It is implemented
 * as a sequence of WebGPU compute dispatches.
 *
 * Activations are float32 GPU buffers of `batch * seqLen * dModel` elements.
 * Row-major (token, feature) layout is assumed throughout, consistent with the
 * `[B, L, D]` shape notes in `forward`.
 */
export class AttentionBlock implements SequenceLayer {
  readonly layerType = 'attention' as const;

  device    : GPUDevice;
  config    : Required<AttentionBlockConfig>;
  dHead     : number;

  gpuWeights: Record<string, GPUBuffer>;
  pipelines : Record<string, GPUComputePipeline>;

  /**
   * De-interleaves the fused QKV projection.
   *
   * The linear kernel's output is (rows, 3*d), so each token row holds
   * `[q | k | v]` back to back. Q, K and V must be gathered with a row stride of
   * 3*d. Taking three contiguous thirds of the buffer would mix q/k/v across
   * tokens. This assumes LINEAR_FORWARD_WGSL writes row-major (rows, out_dim),
   * as its dispatch grid suggests; confirm if that kernel changes.
   *
   * Bindings: 0 = {rows, d} (uniform), 1 = qkv (read), 2/3/4 = q/k/v (read_write).
   * Dispatch ceil(rows * d / 256) workgroups along x.
   */
  private static readonly QKV_SPLIT_WGSL = /* wgsl */`
struct SplitParams {
  rows : u32,
  d    : u32,
};
@group(0) @binding(0) var<uniform> sp : SplitParams;
@group(0) @binding(1) var<storage, read> qkv : array<f32>;
@group(0) @binding(2) var<storage, read_write> q_out : array<f32>;
@group(0) @binding(3) var<storage, read_write> k_out : array<f32>;
@group(0) @binding(4) var<storage, read_write> v_out : array<f32>;
@compute @workgroup_size(256, 1, 1)
fn qkv_split(@builtin(global_invocation_id) gid: vec3<u32>) {
  let i = gid.x;
  if (i >= sp.rows * sp.d) { return; }
  let row = i / sp.d;
  let col = i % sp.d;
  let src = row * 3u * sp.d + col;
  q_out[i] = qkv[src];
  k_out[i] = qkv[src + sp.d];
  v_out[i] = qkv[src + 2u * sp.d];
}
`;

  /**
   * @param device GPU device on which weights are allocated and kernels run.
   * @param config Layer shape. Omitted or `undefined` optional fields fall back
   *               to their defaults (`dHead = dModel / nHeads`, `hasFfn = false`,
   *               `ffnMult = 4`).
   * @throws Error if `dModel` is not divisible by `nHeads`, or if an explicit
   *         `dHead` does not satisfy `dHead * nHeads === dModel`. All weights are
   *         sized from `dModel`, so any other head width would corrupt the
   *         attention kernels' indexing.
   */
  constructor(device: GPUDevice, config: AttentionBlockConfig) {
    this.device = device;

    const { dModel, nHeads } = config;
    if (dModel % nHeads !== 0) {
      throw new Error(
        `AttentionBlock: dModel (${config.dModel}) must be divisible by nHeads (${config.nHeads}).`
      );
    }

    // `??` rather than spreading defaults first: an explicit `undefined` in
    // `config` must not override a default.
    const dHead = config.dHead ?? dModel / nHeads;
    if (dHead * nHeads !== dModel) {
      throw new Error(
        `AttentionBlock: dHead (${dHead}) * nHeads (${nHeads}) must equal dModel (${dModel}).`
      );
    }

    this.config = {
      ...config,
      dHead,
      hasFfn : config.hasFfn ?? false,
      ffnMult: config.ffnMult ?? 4,
    };

    this.dHead = this.config.dHead;

    this.gpuWeights = {};
    this.pipelines = {};

    this._initWeights();
    this._buildPipelines();
  }

  /** Allocates all weights: projections ~ N(0, 0.02²), biases 0, norm gain 1. */
  private _initWeights(): void {
    const { dModel, hasFfn, ffnMult } = this.config;

    const randn = (n: number, std = 0.02): Float32Array => gaussianArray(n, std);
    const zeros = (n: number) => new Float32Array(n);
    const ones = (n: number) => new Float32Array(n).fill(1.0);
    const mk = (arr: Float32Array) => createStorageBuffer(this.device, arr, true);

    this.gpuWeights = {
      wQKV      : mk(randn(3 * dModel * dModel)),
      bQKV      : mk(zeros(3 * dModel)),
      wO        : mk(randn(dModel * dModel)),
      bO        : mk(zeros(dModel)),
      normWeight: mk(ones(dModel)),
    };

    if (hasFfn) {
      const ffnDim = dModel * ffnMult;
      this.gpuWeights['wFfn1'] = mk(randn(ffnDim * dModel));
      this.gpuWeights['bFfn1'] = mk(zeros(ffnDim));
      this.gpuWeights['wFfn2'] = mk(randn(dModel * ffnDim));
      this.gpuWeights['bFfn2'] = mk(zeros(dModel));
    }
  }

  /** Compiles every compute pipeline `forward` dispatches. */
  private _buildPipelines(): void {
    const d = this.device;
    this.pipelines = {
      linear  : createComputePipeline(d, LINEAR_FORWARD_WGSL, 'linear_forward'),
      rmsnorm : createComputePipeline(d, ACTIVATIONS_WGSL, 'rmsnorm_forward'),
      qkvSplit: createComputePipeline(d, AttentionBlock.QKV_SPLIT_WGSL, 'qkv_split'),
      attn_fwd: createComputePipeline(d, ATTENTION_FORWARD_WGSL, 'attention_forward'),
      attn_val: createComputePipeline(d, ATTENTION_FORWARD_WGSL, 'attention_value'),
      softmax : createComputePipeline(d, SOFTMAX_WGSL, 'softmax_forward'),
      elAdd   : createComputePipeline(d, ADD_SHADER, 'main'),
    };

    if (this.config.hasFfn) {
      this.pipelines['silu'] = createComputePipeline(d, SILU_SHADER, 'silu_forward');
    }
  }

  /**
   * Binds `bindings` (in binding-index order) to pipeline `key` and dispatches it.
   * @throws Error if the pipeline was not built for this configuration.
   */
  private _dispatch(
    key: string,
    bindings: Parameters<typeof createBindGroup>[2],
    workgroups: Parameters<typeof dispatchKernel>[3],
  ): void {
    const pipeline = this.pipelines[key];
    if (!pipeline) {
      throw new Error(`AttentionBlock: pipeline '${key}' is not built.`);
    }
    dispatchKernel(this.device, pipeline, createBindGroup(this.device, pipeline, bindings), workgroups);
  }

  /**
   * Runs the layer on `xBuf`.
   *
   * @param xBuf   Input activations, `batch * seqLen * dModel` float32 values.
   *               Read only; neither modified nor destroyed.
   * @param batch  Batch size B.
   * @param seqLen Sequence length L.
   * @returns `output`, a new buffer of the same shape as `xBuf` that the caller
   *          owns, and `cache.scores`, the post-softmax (B, H, L, L) attention
   *          probabilities, also caller-owned.
   *
   * Intermediate activations are destroyed as soon as their consumer has been
   * dispatched. Per-call uniform/placeholder buffers are destroyed on exit.
   * Both follow this file's existing pattern of destroying a buffer after the
   * work that reads it has been submitted.
   */
  forward(xBuf: GPUBuffer, batch: number, seqLen: number): LayerForwardResult {
    const d = this.device;
    const { dModel, nHeads, hasFfn, ffnMult } = this.config;
    const dh = this.dHead;
    const B = batch;
    const L = seqLen;
    const M = B * L;
    const H = nHeads;
    const actBytes = M * dModel * 4;

    // Small per-call buffers (uniforms, placeholders), released in `finally`.
    const transient: GPUBuffer[] = [];
    const uniform = (data: Parameters<typeof createUniformBuffer>[1]): GPUBuffer => {
      const buf = createUniformBuffer(d, data);
      transient.push(buf);
      return buf;
    };

    // y = x·Wᵀ + b over M rows; returns a new (M, outDim) buffer.
    const linear = (
      input: GPUBuffer, wKey: string, bKey: string, inDim: number, outDim: number,
    ): GPUBuffer => {
      const out = createEmptyStorageBuffer(d, M * outDim * 4, true);
      this._dispatch('linear',
        [uniform(new Uint32Array([M, inDim, outDim]).buffer), input,
         this.gpuWeights[wKey]!, this.gpuWeights[bKey]!, out],
        [cdiv(M, 16), cdiv(outDim, 16), 1]);
      return out;
    };

    // Returns a new (M, dModel) buffer holding a + b.
    const residualAdd = (a: GPUBuffer, b: GPUBuffer): GPUBuffer => {
      const out = createEmptyStorageBuffer(d, actBytes, true);
      this._dispatch('elAdd',
        [a, b, out, uniform(new Uint32Array([M * dModel]).buffer)],
        [cdiv(M * dModel, 256), 1, 1]);
      return out;
    };

    try {
      // 1. Pre-block RMSNorm (uniform: u32 rows, u32 dim, f32 eps).
      const normOut = createEmptyStorageBuffer(d, actBytes, true);
      const normInv = createEmptyStorageBuffer(d, M * 4, true);
      const normParams = new ArrayBuffer(16);
      new Uint32Array(normParams, 0, 2).set([M, dModel]);
      new Float32Array(normParams, 8, 1).set([1e-6]);
      this._dispatch('rmsnorm',
        [uniform(normParams), xBuf, this.gpuWeights['normWeight']!, normOut, normInv],
        [cdiv(M, 64), 1, 1]);
      normInv.destroy();

      // 2. Fused QKV projection: [B, L, 3*D].
      const qkvOut = linear(normOut, 'wQKV', 'bQKV', dModel, 3 * dModel);
      normOut.destroy();

      // Split into Q, K, V, each [B, L, H, dh] = [B, L, D] (strided gather).
      const QBuf = createEmptyStorageBuffer(d, actBytes, true);
      const KBuf = createEmptyStorageBuffer(d, actBytes, true);
      const VBuf = createEmptyStorageBuffer(d, actBytes, true);
      this._dispatch('qkvSplit',
        [uniform(new Uint32Array([M, dModel]).buffer), qkvOut, QBuf, KBuf, VBuf],
        [cdiv(M * dModel, 256), 1, 1]);
      qkvOut.destroy();

      // 3. Attention scores: [B, H, L, L]. attention_forward shares its module
      //    with attention_value, so an output buffer is still bound.
      const scores = createEmptyStorageBuffer(d, B * H * L * L * 4, true);
      const outPlaceholder = createEmptyStorageBuffer(d, actBytes, true);
      transient.push(outPlaceholder);
      this._dispatch('attn_fwd',
        [uniform(new Uint32Array([B, L, dModel, H, dh]).buffer), QBuf, KBuf, VBuf, scores, outPlaceholder],
        [cdiv(L, 16), H, B]);

      // 4. Softmax (causal) per row: dispatch (L, H, B).
      this._dispatch('softmax',
        [uniform(new Uint32Array([L, L, 1]).buffer), scores],
        [L, H, B]);

      // 5. Weighted V sum → attention output [B, L, H, dh].
      const attnOut = createEmptyStorageBuffer(d, actBytes, true);
      this._dispatch('attn_val',
        [uniform(new Uint32Array([B, L, dModel, H, dh]).buffer), QBuf, KBuf, VBuf, scores, attnOut],
        [cdiv(L, 16), H, B]);
      QBuf.destroy();
      KBuf.destroy();
      VBuf.destroy();

      // 6. Output projection: [B, L, D] → [B, L, D].
      const outProjOut = linear(attnOut, 'wO', 'bO', dModel, dModel);
      attnOut.destroy();

      // 7. Residual add.
      let current = residualAdd(outProjOut, xBuf);
      outProjOut.destroy();

      // 8. Optional FFN sublayer: current + W2·SiLU(W1·current + b1) + b2.
      if (hasFfn) {
        const ffnDim = dModel * ffnMult;

        const ffn1Out = linear(current, 'wFfn1', 'bFfn1', dModel, ffnDim);

        const siluOut = createEmptyStorageBuffer(d, M * ffnDim * 4, true);
        this._dispatch('silu',
          [uniform(new Uint32Array([M * ffnDim]).buffer), ffn1Out, siluOut],
          [cdiv(M * ffnDim, 256), 1, 1]);
        ffn1Out.destroy();

        const ffn2Out = linear(siluOut, 'wFfn2', 'bFfn2', ffnDim, dModel);
        siluOut.destroy();

        const residual2 = residualAdd(ffn2Out, current);
        ffn2Out.destroy();
        current.destroy();
        current = residual2;
      }

      const cache: AttentionCache = { scores };
      return { output: current, cache };
    } finally {
      for (const buf of transient) buf.destroy();
    }
  }

  /** All weight buffers with their element counts, in a fixed order. */
  parameters(): LayerParam[] {
    const { dModel, hasFfn, ffnMult } = this.config;
    const params: LayerParam[] = [
      { buf: this.gpuWeights['wQKV']!, numel: 3 * dModel * dModel, name: 'wQKV' },
      { buf: this.gpuWeights['bQKV']!, numel: 3 * dModel, name: 'bQKV' },
      { buf: this.gpuWeights['wO']!, numel: dModel * dModel, name: 'wO' },
      { buf: this.gpuWeights['bO']!, numel: dModel, name: 'bO' },
      { buf: this.gpuWeights['normWeight']!, numel: dModel, name: 'normWeight' },
    ];

    if (hasFfn) {
      const ffnDim = dModel * ffnMult;
      params.push(
        { buf: this.gpuWeights['wFfn1']!, numel: ffnDim * dModel, name: 'wFfn1' },
        { buf: this.gpuWeights['bFfn1']!, numel: ffnDim, name: 'bFfn1' },
        { buf: this.gpuWeights['wFfn2']!, numel: dModel * ffnDim, name: 'wFfn2' },
        { buf: this.gpuWeights['bFfn2']!, numel: dModel, name: 'bFfn2' },
      );
    }

    return params;
  }

  /** Attention layers are always fully trained — there is no WSLA subset. */
  getTrainableParams(): LayerParam[] {
    return this.parameters();
  }

  /** No-op: WSLA does not apply to attention layers. */
  setWSLAMode(_enabled: boolean): void {
    // Intentionally empty.
  }

  /** Destroys all weight buffers. The block must not be used afterwards. */
  destroy(): void {
    for (const buf of Object.values(this.gpuWeights)) buf.destroy();
    this.gpuWeights = {};
  }
}