@seanhogg/builderforce-memory-engine 2026.6.20 → 2026.6.27

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,373 @@
1
+ /**
2
+ * limbic_model.ts – LimbicModel: a small, trainable recurrent affect head.
3
+ *
4
+ * The limbic model learns the *dynamics* of an agent's affective/motivational
5
+ * state. Given an experience embedding (produced upstream by the hippocampus
6
+ * SSM) and the current affective state, it predicts:
7
+ * • a bounded affect *delta* (how valence/arousal/drives/attention/exploration
8
+ * should move in response to this experience), and
9
+ * • a scalar *reward* prediction (how good/bad this experience was).
10
+ *
11
+ * Architecture (intentionally tiny — the heavy representation work happens in
12
+ * the hippocampus; this head just maps representation → affect):
13
+ *
14
+ * pre[j] = Σ_i Win[j,i]·x[i] + Σ_k Ws[j,k]·s[k] (hidden pre-activation)
15
+ * a[j] = sigmoid(A[j]) (per-channel SSM gate)
16
+ * h'[j] = a[j]·h[j] + (1-a[j])·tanh(pre[j]) (recurrent leak/input)
17
+ * Δ[k] = tanh( Σ_j Wout[k,j]·h'[j] + b[k] ) (bounded affect delta)
18
+ * r = Σ_j Wr[j]·h'[j] + br (reward prediction)
19
+ *
20
+ * It runs on WebGPU (per-turn `step()` via {@link LIMBIC_AFFECT_WGSL}, and the
21
+ * AdamW optimiser via the shared WEIGHT_UPDATE_WGSL kernel during training) with
22
+ * a numerically-equivalent (up to floating-point rounding) pure-CPU reference path used when no GPUDevice is
23
+ * available — the same WebGPU-or-fallback contract the rest of the engine uses.
24
+ *
25
+ * Gradients use truncated BPTT(1): a caller may carry the recurrent state forward
26
+ * across a sequence, but each step's gradient treats the incoming hidden state as
27
+ * a constant. For an affect head this is stable and sufficient — each
28
+ * (experience, state) → affect pair is close to an independent regression (the bundled LimbicTrainer goes further and resets the hidden state for every sample).
29
+ */
30
+
31
+ import { SeededRng } from "../utils/rng.js";
32
+ import { quantizeFp16, dequantizeFp16 } from "../utils/quantization.js";
33
+ import { LIMBIC_STATE_DIM } from "./regions.js";
34
+
35
/**
 * Shape and training knobs for a LimbicModel. The model constructor accepts a
 * `Partial<LimbicModelConfig>`; anything omitted falls back to
 * DEFAULT_LIMBIC_CONFIG.
 */
export interface LimbicModelConfig {
  /** Width of the experience-embedding input vector `x`. Default: 32. */
  inputDim: number;
  /**
   * Width of the recurrent hidden state `h`. Capped at 64, the workgroup size
   * of the GPU kernel. Default: 16.
   */
  hiddenDim: number;
  /** Width of the affective state `s` and of the predicted delta. Default: LIMBIC_STATE_DIM (8). */
  stateDim: number;
  /**
   * Seed for the deterministic weight initialiser; defaults to
   * DEFAULT_LIMBIC_SEED. Coerced to uint32, and a value that coerces to 0 is
   * replaced by 1.
   */
  seed?: number;
  /** Multiplier applied to the reward-prediction term of the training loss. Default: 0.5. */
  rewardWeight?: number;
}
47
+
48
/**
 * Fallback values for every LimbicModelConfig field except `seed` (the
 * constructor falls back to DEFAULT_LIMBIC_SEED for that one).
 */
export const DEFAULT_LIMBIC_CONFIG: Required<Omit<LimbicModelConfig, "seed">> = {
  inputDim: 32, // experience-embedding width
  hiddenDim: 16, // recurrent width (must stay ≤ 64)
  stateDim: LIMBIC_STATE_DIM, // affective-state width
  rewardWeight: 0.5, // weight of the reward term in the training loss
};
54
+
55
/** Output of one LimbicModel forward step. */
export interface LimbicForward {
  /** Updated recurrent state h′ (`hiddenDim` entries); pass it back as `hPrev` on the next step. */
  hidden: Float32Array;
  /** Per-dimension affect change, each entry tanh-bounded to (−1, 1) (`stateDim` entries). */
  delta: Float32Array;
  /** Scalar reward estimate from the linear reward head. */
  reward: number;
}
64
+
65
/** Intermediates from one forward step, kept so the backward pass can reuse them. */
interface ForwardCache {
  /** Input embedding, copied and zero-padded/truncated to `inputDim`. */
  x: Float32Array;
  /** Affective state at the start of the step, sized to `stateDim`. */
  sPrev: Float32Array;
  /** Hidden state at the start of the step, sized to `hiddenDim`. */
  hPrev: Float32Array;
  /** Leak gate per hidden channel: sigmoid(aLogit). */
  a: Float32Array;
  /** Candidate activation per hidden channel: tanh(pre). */
  t: Float32Array;
  /** New hidden state: a·hPrev + (1 − a)·t. */
  hn: Float32Array;
  /** Affect delta after the output tanh. */
  delta: Float32Array;
  /** Reward prediction from the linear head. */
  reward: number;
}
76
+
77
/**
 * A named handle on one flat, row-major tensor: either a trainable parameter
 * or its gradient buffer. As returned by LimbicModel's parameters() and
 * gradients(), `data` is the model's own storage rather than a copy.
 */
export interface LimbicParam {
  /** Stable tensor name, e.g. "win" or "boutReward". */
  name: string;
  /** Backing storage for the tensor. */
  data: Float32Array;
  /** Number of elements (equal to `data.length` for model-produced handles). */
  numel: number;
}
83
+
84
/**
 * Checkpoint magic number. Its bytes spell "LMBC" when read as a big-endian
 * u32. Checkpoints write it through a Uint32Array (host byte order), so on
 * little-endian hosts the first four bytes on disk read "CBML".
 */
const MAGIC = 0x4c4d4243;
85
+
86
/**
 * Init seed used when the model config supplies none, giving a reproducible
 * cold start: the same seed yields the same initial weights on a given JS
 * engine. Weight init draws Gaussians through Math.log/Math.cos, which
 * ECMAScript leaves implementation-approximated, so bit-identical weights
 * across different engines are not guaranteed.
 */
export const DEFAULT_LIMBIC_SEED = 0x11b1c5ee;
88
+
89
/**
 * Logistic function σ(x) = 1 / (1 + e^(−x)). For finite inputs it saturates to
 * exactly 0 or 1 at large |x| rather than producing NaN.
 */
function sigmoid(x: number): number {
  const expNeg = Math.exp(-x);
  return 1 / (1 + expNeg);
}
92
+
93
/**
 * Small recurrent affect head. It maps an experience embedding `x`, the
 * affective state `s` and a hidden state `h` to a tanh-bounded affect delta
 * and a scalar reward prediction. It owns its parameters and the matching
 * gradient accumulators.
 */
export class LimbicModel {
  readonly config: Required<Omit<LimbicModelConfig, "seed">>;

  // Parameters (flat, row-major).
  win: Float32Array; // hidden × input
  ws: Float32Array; // hidden × state
  aLogit: Float32Array; // hidden (pre-sigmoid leak gate)
  woutState: Float32Array; // state × hidden
  boutState: Float32Array; // state
  woutReward: Float32Array; // hidden
  boutReward: Float32Array; // 1

  // Gradient accumulators, one per parameter with the same length. They are
  // allocated eagerly in the constructor; backwardStep() adds into them and
  // zeroGrad() resets them.
  private gWin: Float32Array;
  private gWs: Float32Array;
  private gALogit: Float32Array;
  private gWoutState: Float32Array;
  private gBoutState: Float32Array;
  private gWoutReward: Float32Array;
  private gBoutReward: Float32Array;

  /**
   * Build a model with small, seeded Gaussian weights and zero biases, so an
   * untrained model produces near-zero affect deltas.
   *
   * @param config Overrides for DEFAULT_LIMBIC_CONFIG. Fields that are left out,
   *   or passed as `undefined`/`null`, use the default. `seed` defaults to
   *   DEFAULT_LIMBIC_SEED.
   * @throws Error if inputDim, hiddenDim or stateDim is not a positive
   *   integer, or if hiddenDim exceeds 64 (the GPU kernel's workgroup size).
   */
  constructor(config: Partial<LimbicModelConfig> = {}) {
    // Keep the original key order and any extra keys (e.g. `seed`). Resolve
    // each configurable field with `??` so an explicit `undefined` cannot
    // clobber a default, as a plain spread would.
    const cfg = {
      ...DEFAULT_LIMBIC_CONFIG,
      ...config,
      inputDim: config.inputDim ?? DEFAULT_LIMBIC_CONFIG.inputDim,
      hiddenDim: config.hiddenDim ?? DEFAULT_LIMBIC_CONFIG.hiddenDim,
      stateDim: config.stateDim ?? DEFAULT_LIMBIC_CONFIG.stateDim,
      rewardWeight: config.rewardWeight ?? DEFAULT_LIMBIC_CONFIG.rewardWeight,
    };
    for (const key of ["inputDim", "hiddenDim", "stateDim"] as const) {
      const dim = cfg[key];
      // NaN or fractional sizes would otherwise yield silently empty or
      // mis-sized typed arrays.
      if (!Number.isInteger(dim) || dim <= 0) {
        throw new Error(`LimbicModel ${key} must be a positive integer (got ${dim})`);
      }
    }
    if (cfg.hiddenDim > 64) {
      throw new Error(`LimbicModel hiddenDim must be ≤ 64 (got ${cfg.hiddenDim})`);
    }
    this.config = cfg;
    const { inputDim, hiddenDim, stateDim } = cfg;

    // Coerce the seed to uint32; a zero seed is replaced by 1.
    const rng = new SeededRng(((config.seed ?? DEFAULT_LIMBIC_SEED) >>> 0) || 1);
    // Box–Muller transform; u1 is clamped away from 0 so log() stays finite.
    const randn = (std: number): number => {
      const u1 = Math.max(rng.next(), 1e-12);
      const u2 = rng.next();
      return std * Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
    };
    const gauss = (n: number, std: number): Float32Array => {
      const a = new Float32Array(n);
      for (let i = 0; i < n; i++) a[i] = randn(std);
      return a;
    };

    // Small init: affect deltas should start near zero so an untrained model is
    // inert (no spurious mood swings) until it has learned from experience.
    this.win = gauss(hiddenDim * inputDim, 0.1);
    this.ws = gauss(hiddenDim * stateDim, 0.1);
    this.aLogit = gauss(hiddenDim, 0.05); // sigmoid(~0) ≈ 0.5 leak
    this.woutState = gauss(stateDim * hiddenDim, 0.05);
    this.boutState = new Float32Array(stateDim);
    this.woutReward = gauss(hiddenDim, 0.05);
    this.boutReward = new Float32Array(1);

    this.gWin = new Float32Array(this.win.length);
    this.gWs = new Float32Array(this.ws.length);
    this.gALogit = new Float32Array(this.aLogit.length);
    this.gWoutState = new Float32Array(this.woutState.length);
    this.gBoutState = new Float32Array(this.boutState.length);
    this.gWoutReward = new Float32Array(this.woutReward.length);
    this.gBoutReward = new Float32Array(1);
  }

  /**
   * Trainable parameters in canonical checkpoint order. Each `data` field is
   * the live parameter array, not a copy.
   */
  parameters(): LimbicParam[] {
    return [
      { name: "win", data: this.win, numel: this.win.length },
      { name: "ws", data: this.ws, numel: this.ws.length },
      { name: "aLogit", data: this.aLogit, numel: this.aLogit.length },
      { name: "woutState", data: this.woutState, numel: this.woutState.length },
      { name: "boutState", data: this.boutState, numel: this.boutState.length },
      { name: "woutReward", data: this.woutReward, numel: this.woutReward.length },
      { name: "boutReward", data: this.boutReward, numel: this.boutReward.length },
    ];
  }

  /** Gradient accumulators, index-aligned with parameters(); `data` is live. */
  gradients(): LimbicParam[] {
    return [
      { name: "win", data: this.gWin, numel: this.gWin.length },
      { name: "ws", data: this.gWs, numel: this.gWs.length },
      { name: "aLogit", data: this.gALogit, numel: this.gALogit.length },
      { name: "woutState", data: this.gWoutState, numel: this.gWoutState.length },
      { name: "boutState", data: this.gBoutState, numel: this.gBoutState.length },
      { name: "woutReward", data: this.gWoutReward, numel: this.gWoutReward.length },
      { name: "boutReward", data: this.gBoutReward, numel: this.gBoutReward.length },
    ];
  }

  /** Reset every gradient accumulator to zero. */
  zeroGrad(): void {
    this.gWin.fill(0);
    this.gWs.fill(0);
    this.gALogit.fill(0);
    this.gWoutState.fill(0);
    this.gBoutState.fill(0);
    this.gWoutReward.fill(0);
    this.gBoutReward.fill(0);
  }

  /** A fresh zeroed hidden state of length `hiddenDim`. */
  initHidden(): Float32Array {
    return new Float32Array(this.config.hiddenDim);
  }

  /**
   * One forward step on the CPU. Pure: neither the model nor the inputs are
   * mutated, and freshly allocated arrays are returned.
   *
   * Inputs are read by index. Missing entries (shorter arrays) count as 0, and
   * entries beyond the configured dimension are ignored. This path accumulates
   * in float64 (JS numbers) before storing to float32, so a float32 GPU kernel
   * may differ from it in the last bits.
   */
  forward(x: ArrayLike<number>, hPrev: ArrayLike<number>, sPrev: ArrayLike<number>): LimbicForward {
    const cache = this._forwardCached(x, hPrev, sPrev);
    return { hidden: cache.hn, delta: cache.delta, reward: cache.reward };
  }

  /** Forward pass that also returns the intermediates backwardStep() needs. */
  private _forwardCached(
    x: ArrayLike<number>,
    hPrev: ArrayLike<number>,
    sPrev: ArrayLike<number>,
  ): ForwardCache {
    const { inputDim, hiddenDim, stateDim } = this.config;
    // Copy and zero-pad/truncate the inputs to the configured sizes.
    const xa = Float32Array.from({ length: inputDim }, (_, i) => x[i] ?? 0);
    const ha = Float32Array.from({ length: hiddenDim }, (_, j) => hPrev[j] ?? 0);
    const sa = Float32Array.from({ length: stateDim }, (_, k) => sPrev[k] ?? 0);

    const a = new Float32Array(hiddenDim);
    const t = new Float32Array(hiddenDim);
    const hn = new Float32Array(hiddenDim);
    for (let j = 0; j < hiddenDim; j++) {
      let pre = 0;
      const wiOff = j * inputDim;
      for (let i = 0; i < inputDim; i++) pre += this.win[wiOff + i]! * xa[i]!;
      const wsOff = j * stateDim;
      for (let k = 0; k < stateDim; k++) pre += this.ws[wsOff + k]! * sa[k]!;
      a[j] = sigmoid(this.aLogit[j]!);
      t[j] = Math.tanh(pre);
      // Leaky recurrent update: keep a fraction a of the old state.
      hn[j] = a[j]! * ha[j]! + (1 - a[j]!) * t[j]!;
    }

    const delta = new Float32Array(stateDim);
    for (let k = 0; k < stateDim; k++) {
      let acc = this.boutState[k]!;
      const off = k * hiddenDim;
      for (let j = 0; j < hiddenDim; j++) acc += this.woutState[off + j]! * hn[j]!;
      delta[k] = Math.tanh(acc);
    }

    let reward = this.boutReward[0]!;
    for (let j = 0; j < hiddenDim; j++) reward += this.woutReward[j]! * hn[j]!;

    return { x: xa, sPrev: sa, hPrev: ha, a, t, hn, delta, reward };
  }

  /**
   * Accumulate gradients for one sample using truncated BPTT(1): `hPrev` is
   * treated as a constant, so no gradient flows into earlier steps.
   *
   * Loss = ½‖δ − deltaTarget‖² + ½·rewardWeight·(r − rewardTarget)².
   * Gradients are *added* to the accumulators, so call zeroGrad() before a
   * batch and run the optimiser afterwards. Missing deltaTarget entries count
   * as 0.
   *
   * @returns The loss for this sample (with the current weights) and the next
   *   hidden state, for callers that carry it across a sequence.
   */
  backwardStep(
    x: ArrayLike<number>,
    hPrev: ArrayLike<number>,
    sPrev: ArrayLike<number>,
    deltaTarget: ArrayLike<number>,
    rewardTarget: number,
  ): { loss: number; hidden: Float32Array } {
    const { inputDim, hiddenDim, stateDim, rewardWeight } = this.config;
    const c = this._forwardCached(x, hPrev, sPrev);

    // Loss and its gradient with respect to each head's output.
    let loss = 0;
    const dDelta = new Float32Array(stateDim);
    for (let k = 0; k < stateDim; k++) {
      const diff = c.delta[k]! - (deltaTarget[k] ?? 0);
      loss += 0.5 * diff * diff;
      dDelta[k] = diff;
    }
    const dRewardDiff = c.reward - rewardTarget;
    loss += 0.5 * rewardWeight * dRewardDiff * dRewardDiff;
    const dReward = rewardWeight * dRewardDiff;

    // Backprop to the hidden state through both heads.
    const dHn = new Float32Array(hiddenDim);
    for (let k = 0; k < stateDim; k++) {
      const dPreDelta = dDelta[k]! * (1 - c.delta[k]! * c.delta[k]!); // tanh'
      this.gBoutState[k] = this.gBoutState[k]! + dPreDelta;
      const off = k * hiddenDim;
      for (let j = 0; j < hiddenDim; j++) {
        this.gWoutState[off + j] = this.gWoutState[off + j]! + dPreDelta * c.hn[j]!;
        dHn[j] = dHn[j]! + dPreDelta * this.woutState[off + j]!;
      }
    }
    this.gBoutReward[0] = this.gBoutReward[0]! + dReward;
    for (let j = 0; j < hiddenDim; j++) {
      this.gWoutReward[j] = this.gWoutReward[j]! + dReward * c.hn[j]!;
      dHn[j] = dHn[j]! + dReward * this.woutReward[j]!;
    }

    // Backprop through the recurrent update hn = a·hPrev + (1-a)·t.
    for (let j = 0; j < hiddenDim; j++) {
      const aj = c.a[j]!;
      const tj = c.t[j]!;
      // ∂hn/∂A = (hPrev - t)·a·(1-a)
      this.gALogit[j] = this.gALogit[j]! + dHn[j]! * (c.hPrev[j]! - tj) * aj * (1 - aj);
      // ∂hn/∂t = (1-a); ∂t/∂pre = 1 - t²
      const dPre = dHn[j]! * (1 - aj) * (1 - tj * tj);
      const wiOff = j * inputDim;
      for (let i = 0; i < inputDim; i++) this.gWin[wiOff + i] = this.gWin[wiOff + i]! + dPre * c.x[i]!;
      const wsOff = j * stateDim;
      for (let k = 0; k < stateDim; k++) this.gWs[wsOff + k] = this.gWs[wsOff + k]! + dPre * c.sPrev[k]!;
    }

    return { loss, hidden: c.hn };
  }

  // ── Checkpoint serialisation ──────────────────────────────────────────────

  /**
   * Serialise weights to a compact "LMBC" binary.
   *
   * - Version 1 (default) stores exact float32 values.
   * - Version 2 (`fp16: true`) halves the payload by storing half-precision
   *   values (IEEE binary16 has an 11-bit significand, roughly 3 decimal
   *   digits).
   *
   * Layout: five u32 header words in host byte order (magic, version,
   * inputDim, hiddenDim, stateDim), then every parameter flattened in
   * parameters() order.
   */
  exportWeights(opts: { fp16?: boolean } = {}): ArrayBuffer {
    const fp16 = opts.fp16 ?? false;
    const params = this.parameters();
    const total = params.reduce((n, p) => n + p.numel, 0);
    const headerEls = 5; // magic, version, inputDim, hiddenDim, stateDim
    const headerBytes = headerEls * 4;
    const dataBytes = fp16 ? total * 2 : total * 4;
    const buf = new ArrayBuffer(headerBytes + dataBytes);
    const head = new Uint32Array(buf, 0, headerEls);
    head[0] = MAGIC;
    head[1] = fp16 ? 2 : 1;
    head[2] = this.config.inputDim;
    head[3] = this.config.hiddenDim;
    head[4] = this.config.stateDim;

    if (fp16) {
      const flat = new Float32Array(total);
      let o = 0;
      for (const p of params) {
        flat.set(p.data, o);
        o += p.numel;
      }
      const q = quantizeFp16(flat); // Uint16Array
      new Uint16Array(buf, headerBytes, total).set(q);
    } else {
      const out = new Float32Array(buf, headerBytes, total);
      let o = 0;
      for (const p of params) {
        out.set(p.data, o);
        o += p.numel;
      }
    }
    return buf;
  }

  /**
   * Load weights from an "LMBC" binary produced by exportWeights().
   *
   * Everything is validated before any parameter is touched, so a rejected
   * checkpoint leaves the model unchanged. Trailing bytes after the payload
   * are ignored.
   *
   * @throws Error if the buffer is too short for the header, has the wrong
   *   magic, has an unknown version, has dims that differ from this model, or
   *   is truncated.
   */
  loadWeights(buffer: ArrayBuffer): void {
    const headerEls = 5;
    const headerBytes = headerEls * 4;
    if (buffer.byteLength < headerBytes) {
      throw new Error(
        `LimbicModel.loadWeights: buffer too small for an LMBC header (${buffer.byteLength} < ${headerBytes} bytes)`,
      );
    }
    const head = new Uint32Array(buffer, 0, headerEls);
    if (head[0] !== MAGIC) throw new Error("LimbicModel.loadWeights: bad magic (not an LMBC checkpoint)");
    const version = head[1]!;
    if (version !== 1 && version !== 2) {
      throw new Error(`LimbicModel.loadWeights: unsupported LMBC version ${version} (expected 1 = f32 or 2 = fp16)`);
    }
    const inputDim = head[2]!;
    const hiddenDim = head[3]!;
    const stateDim = head[4]!;
    if (inputDim !== this.config.inputDim || hiddenDim !== this.config.hiddenDim || stateDim !== this.config.stateDim) {
      throw new Error(
        `LimbicModel.loadWeights: dim mismatch — checkpoint ${inputDim}/${hiddenDim}/${stateDim} vs model ${this.config.inputDim}/${this.config.hiddenDim}/${this.config.stateDim}`,
      );
    }
    const params = this.parameters();
    const total = params.reduce((n, p) => n + p.numel, 0);
    const bytesPerEl = version === 2 ? 2 : 4;
    const expectedBytes = headerBytes + total * bytesPerEl;
    // Without this check a short v1 buffer would be sliced short and only a
    // prefix of the parameters would be overwritten, silently.
    if (buffer.byteLength < expectedBytes) {
      throw new Error(
        `LimbicModel.loadWeights: truncated checkpoint (${buffer.byteLength} bytes, expected at least ${expectedBytes})`,
      );
    }
    let flat: Float32Array;
    if (version === 2) {
      flat = dequantizeFp16(new Uint16Array(buffer, headerBytes, total));
    } else {
      // slice() copies, which also sidesteps Float32Array alignment rules.
      flat = new Float32Array(buffer.slice(headerBytes, expectedBytes));
    }
    let o = 0;
    for (const p of params) {
      p.data.set(flat.subarray(o, o + p.numel));
      o += p.numel;
    }
  }
}
@@ -0,0 +1,253 @@
1
+ /**
2
+ * limbic_trainer.ts – LimbicTrainer: gradient-based training for the LimbicModel.
3
+ *
4
+ * Trains the affect head to predict affect deltas and reward from
5
+ * (experience embedding, current state) pairs, using full-batch gradient
6
+ * descent with AdamW. When a GPUDevice is supplied the optimiser step runs on
7
+ * the GPU via the shared WEIGHT_UPDATE_WGSL kernel (real WebGPU training); with
8
+ * no device it uses a numerically-equivalent CPU AdamW so training works
9
+ * everywhere (CI, Node without @webgpu/node, etc.).
10
+ *
11
+ * The objective is MSE (regression), not the cross-entropy used by the
12
+ * language-model {@link MambaTrainer} — a limbic experience has continuous
13
+ * targets, not a next-token distribution.
14
+ */
15
+
16
+ import {
17
+ createUniformBuffer,
18
+ createStorageBuffer,
19
+ createComputePipeline,
20
+ createBindGroup,
21
+ dispatchKernel,
22
+ readBuffer,
23
+ cdiv,
24
+ } from "../utils/gpu_utils.js";
25
+ import { WEIGHT_UPDATE_WGSL } from "../kernels/weight_update.js";
26
+ import type { LimbicModel, LimbicParam } from "./limbic_model.js";
27
+
28
/**
 * A single supervised example: the experience that happened, the mood it
 * happened in, and the affect change plus reward it should produce.
 */
export interface LimbicSample {
  /** Experience embedding; expected length is the model's `inputDim`. */
  input: ArrayLike<number>;
  /** Affective state when the experience occurred; expected length is the model's `stateDim`. */
  state: ArrayLike<number>;
  /**
   * Target affect delta per state dimension, expected in (−1, 1) because the
   * model's delta is tanh-bounded. Expected length is the model's `stateDim`.
   */
  deltaTarget: ArrayLike<number>;
  /** Observed scalar reward for the experience. */
  reward: number;
}
39
+
40
/** Optional knobs for LimbicTrainer.train(); the defaults below are applied there. */
export interface LimbicTrainOptions {
  /** AdamW step size. Default 0.05. */
  learningRate?: number;
  /** Number of full-batch passes over the samples. Default 50. */
  epochs?: number;
  /** Decoupled (AdamW-style) weight decay. Default 0. */
  weightDecay?: number;
  /** First-moment decay rate. Default 0.9. */
  beta1?: number;
  /** Second-moment decay rate. Default 0.999. */
  beta2?: number;
  /** Denominator stabiliser. Default 1e-8. */
  eps?: number;
  /** Max global gradient L2 norm before the optimiser step. Default 1.0. */
  maxGradNorm?: number;
  /** Called after every epoch with the 1-based epoch number and that epoch's mean loss. Default null. */
  onEpochEnd?: ((epoch: number, loss: number) => void) | null;
}
51
+
52
/** AdamW running statistics for one parameter tensor; both arrays match its length. */
interface AdamMoment {
  /** Exponential moving average of the gradient (first moment). */
  m: Float32Array;
  /** Exponential moving average of the squared gradient (second moment). */
  v: Float32Array;
}
56
+
57
/**
 * Pack the 32-byte AdamW uniform block, in host byte order:
 * - word 0: the element count as a u32;
 * - words 1–7: f32 values lr, β1, β2, ε, weight decay, β1^t and β2^t.
 *
 * NOTE(review): this layout must match the uniform struct read by the
 * `adamw_update` entry point of WEIGHT_UPDATE_WGSL. Confirm against that
 * kernel.
 */
function packAdamParams(
  numElements: number,
  lr: number,
  beta1: number,
  beta2: number,
  eps: number,
  weightDecay: number,
  beta1_t: number,
  beta2_t: number,
): ArrayBuffer {
  const block = new ArrayBuffer(32);
  const asU32 = new Uint32Array(block);
  const asF32 = new Float32Array(block);
  asU32[0] = numElements;
  asF32.set([lr, beta1, beta2, eps, weightDecay, beta1_t, beta2_t], 1);
  return block;
}
72
+
73
/** AdamW hyper-parameters for one optimiser step, with bias-correction powers precomputed. */
interface AdamStepParams {
  learningRate: number;
  weightDecay: number;
  beta1: number;
  beta2: number;
  eps: number;
  /** beta1 raised to the current step count. */
  beta1_t: number;
  /** beta2 raised to the current step count. */
  beta2_t: number;
}

/**
 * Full-batch AdamW trainer for a LimbicModel. The optimiser state (first and
 * second moments plus the step counter) lives on the trainer and persists
 * across train() calls, so repeated calls continue the same trajectory.
 */
export class LimbicTrainer {
  readonly model: LimbicModel;
  readonly device: GPUDevice | null;
  private _moments: AdamMoment[] | null = null;
  private _step = 0;
  private readonly _adamwPipeline: GPUComputePipeline | null;

  /**
   * @param model  Model whose parameters are updated in place.
   * @param device Optional WebGPU device. When given, the AdamW update runs
   *   through the `adamw_update` entry point of WEIGHT_UPDATE_WGSL; otherwise
   *   it runs on the CPU.
   */
  constructor(model: LimbicModel, device: GPUDevice | null = null) {
    this.model = model;
    this.device = device;
    this._adamwPipeline = device
      ? createComputePipeline(device, WEIGHT_UPDATE_WGSL, "adamw_update")
      : null;
  }

  /** Whether the optimiser step runs on the GPU. */
  get gpuTraining(): boolean {
    return this.device != null && this._adamwPipeline != null;
  }

  /** Lazily allocate zeroed AdamW moments, one pair per model parameter (once per trainer). */
  private _initMoments(): void {
    if (this._moments) return;
    this._moments = this.model.parameters().map((p) => ({
      m: new Float32Array(p.numel),
      v: new Float32Array(p.numel),
    }));
  }

  /**
   * Train on `samples` for `epochs` full-batch passes.
   *
   * Each epoch runs these steps in order:
   * 1. Zero the gradients.
   * 2. Accumulate gradients over every sample. The hidden state is reset to
   *    zeros for each sample; cross-experience context comes in through each
   *    sample's `state`.
   * 3. Average the gradients over the batch.
   * 4. Clip them to a global L2 norm of `maxGradNorm`.
   * 5. Apply one AdamW step (on the GPU when available, otherwise on the CPU).
   *
   * The loss is expected, but not guaranteed, to fall on a learnable mapping.
   *
   * @returns The mean per-sample loss of each epoch, measured with the weights
   *   in effect before that epoch's update.
   * @throws Error if `samples` is empty.
   */
  async train(samples: LimbicSample[], opts: LimbicTrainOptions = {}): Promise<number[]> {
    if (samples.length === 0) throw new Error("LimbicTrainer.train: no samples");
    const {
      learningRate = 0.05,
      epochs = 50,
      weightDecay = 0.0,
      beta1 = 0.9,
      beta2 = 0.999,
      eps = 1e-8,
      maxGradNorm = 1.0,
      onEpochEnd = null,
    } = opts;

    this._initMoments();
    const losses: number[] = [];

    for (let epoch = 0; epoch < epochs; epoch++) {
      this.model.zeroGrad();
      let epochLoss = 0;

      // Each sample is one experience appraisal. Cross-experience memory flows
      // through the affective state `s` (fed back by the runtime), not through
      // the hidden scratch `h`, so the hidden state resets per sample. This
      // also makes BPTT(1) exact: there is no truncated carry.
      for (const s of samples) {
        const { loss } = this.model.backwardStep(
          s.input,
          this.model.initHidden(),
          s.state,
          s.deltaTarget,
          s.reward,
        );
        epochLoss += loss;
      }

      // Average gradients over the batch.
      const grads = this.model.gradients();
      const invN = 1 / samples.length;
      for (const g of grads) {
        for (let i = 0; i < g.data.length; i++) g.data[i] = g.data[i]! * invN;
      }

      this._clipGradients(grads, maxGradNorm);

      this._step++;
      const hp: AdamStepParams = {
        learningRate,
        weightDecay,
        beta1,
        beta2,
        eps,
        beta1_t: Math.pow(beta1, this._step),
        beta2_t: Math.pow(beta2, this._step),
      };
      if (this.gpuTraining) {
        await this._adamwStepGpu(grads, hp);
      } else {
        this._adamwStepCpu(grads, hp);
      }

      const avg = epochLoss / samples.length;
      losses.push(avg);
      if (onEpochEnd) onEpochEnd(epoch + 1, avg);
    }
    return losses;
  }

  /**
   * Mean per-sample loss, using the same objective as training, without
   * touching weights or gradients. The hidden state resets per sample.
   *
   * @returns 0 for an empty sample list.
   */
  evaluate(samples: LimbicSample[]): number {
    if (samples.length === 0) return 0;
    let total = 0;
    const { rewardWeight } = this.model.config;
    for (const s of samples) {
      const f = this.model.forward(s.input, this.model.initHidden(), s.state);
      let loss = 0;
      for (let k = 0; k < f.delta.length; k++) {
        const d = f.delta[k]! - (s.deltaTarget[k] ?? 0);
        loss += 0.5 * d * d;
      }
      const rd = f.reward - s.reward;
      loss += 0.5 * rewardWeight * rd * rd;
      total += loss;
    }
    return total / samples.length;
  }

  /**
   * Rescale all gradients in place so their global L2 norm is at most
   * `maxNorm`. This is a no-op when the norm is already within bounds, is
   * zero, or is NaN.
   */
  private _clipGradients(grads: LimbicParam[], maxNorm: number): void {
    let normSq = 0;
    for (const g of grads) for (let i = 0; i < g.data.length; i++) normSq += g.data[i]! * g.data[i]!;
    const norm = Math.sqrt(normSq);
    if (norm > maxNorm && norm > 0) {
      const scale = maxNorm / norm;
      for (const g of grads) for (let i = 0; i < g.data.length; i++) g.data[i] = g.data[i]! * scale;
    }
  }

  /** One AdamW step on the CPU, updating parameters and moments in place. */
  private _adamwStepCpu(grads: LimbicParam[], hp: AdamStepParams): void {
    const params = this.model.parameters();
    const { learningRate: lr, weightDecay: wd, beta1, beta2, eps, beta1_t, beta2_t } = hp;
    for (let pi = 0; pi < params.length; pi++) {
      const p = params[pi]!.data;
      const g = grads[pi]!.data;
      const mom = this._moments![pi]!;
      for (let i = 0; i < p.length; i++) {
        const gi = g[i]!;
        mom.m[i] = beta1 * mom.m[i]! + (1 - beta1) * gi;
        mom.v[i] = beta2 * mom.v[i]! + (1 - beta2) * gi * gi;
        const mHat = mom.m[i]! / (1 - beta1_t);
        const vHat = mom.v[i]! / (1 - beta2_t);
        // Decoupled weight decay, then the bias-corrected Adam update.
        p[i] = p[i]! * (1 - lr * wd) - (lr * mHat) / (Math.sqrt(vHat) + eps);
      }
    }
  }

  /**
   * AdamW on the GPU through the shared WEIGHT_UPDATE_WGSL kernel, one tensor
   * at a time, awaiting each readback.
   *
   * Every GPU buffer created for a tensor is destroyed even if a later
   * create, dispatch or readback throws. A tensor's parameters and moments
   * are written back only after all three readbacks succeed, so that tensor
   * is never left with updated parameters but stale moments.
   */
  private async _adamwStepGpu(grads: LimbicParam[], hp: AdamStepParams): Promise<void> {
    const device = this.device!;
    const pipeline = this._adamwPipeline!;
    const params = this.model.parameters();
    const { learningRate: lr, weightDecay: wd, beta1, beta2, eps, beta1_t, beta2_t } = hp;

    for (let pi = 0; pi < params.length; pi++) {
      const p = params[pi]!;
      const mom = this._moments![pi]!;
      const byteLength = p.numel * 4;
      const owned: Array<{ destroy(): void }> = [];
      const own = <T extends { destroy(): void }>(buffer: T): T => {
        owned.push(buffer);
        return buffer;
      };
      try {
        const paramBuf = own(createStorageBuffer(device, p.data, true));
        const gradBuf = own(createStorageBuffer(device, grads[pi]!.data, false));
        const mBuf = own(createStorageBuffer(device, mom.m, true));
        const vBuf = own(createStorageBuffer(device, mom.v, true));
        const uni = own(
          createUniformBuffer(device, packAdamParams(p.numel, lr, beta1, beta2, eps, wd, beta1_t, beta2_t)),
        );
        const bg = createBindGroup(device, pipeline, [uni, paramBuf, gradBuf, mBuf, vBuf]);
        dispatchKernel(device, pipeline, bg, [cdiv(p.numel, 256), 1, 1]);

        const newParams = await readBuffer(device, paramBuf, byteLength);
        const newM = await readBuffer(device, mBuf, byteLength);
        const newV = await readBuffer(device, vBuf, byteLength);
        p.data.set(newParams.subarray(0, p.numel));
        mom.m.set(newM.subarray(0, p.numel));
        mom.v.set(newV.subarray(0, p.numel));
      } finally {
        for (const buffer of owned) buffer.destroy();
      }
    }
  }
}