@aleph-ai/tinyaleph 1.1.0 → 1.2.1

This diff shows the changes between publicly released versions of the package as they appear in the supported public registries, and is provided for informational purposes only.
@@ -0,0 +1,811 @@
/**
 * ResoFormer Layers - Complete Transformer-style Architecture
 *
 * Building on the primitives in rformer.js, this module provides:
 * - ResonantMultiHeadAttention: Multiple attention heads with different weights
 * - PrimeFFN: Feed-forward network maintaining prime structure
 * - PrimeLayerNorm: Normalization preserving prime properties
 * - PositionalPrimeEncoding: Position encoded as prime phases
 * - ResoFormerBlock: Complete transformer block
 *
 * These components enable building full ResoFormer models for
 * prime-resonant sequence processing.
 */

'use strict';

const {
  Quaternion,
  SparsePrimeState,
  resonanceScore,
  resonantAttention,
  hamiltonCompose,
  computeCoherence
} = require('./rformer');

const { Complex, PrimeState } = require('./hilbert');
const { firstNPrimes, nthPrime, isPrime } = require('./prime');

/**
 * ResonantMultiHeadAttention
 *
 * Multiple attention heads with different resonance weight configurations.
 * Each head can emphasize different aspects:
 * - Prime overlap (Jaccard)
 * - Quaternion alignment
 * - Phase coherence
 */
class ResonantMultiHeadAttention {
  /**
   * @param {object} config
   * @param {number} [config.numHeads=8] - Number of attention heads
   * @param {number} [config.numPrimes=4096] - Size of prime vocabulary
   * @param {number} [config.activeK=32] - Sparsity per state
   * @param {number[][]} [config.headWeights] - Per-head [alpha, beta, gamma]
   * @param {number} [config.temperature=1.0] - Softmax temperature
   */
  constructor(config) {
    this.numHeads = config.numHeads || 8;
    this.numPrimes = config.numPrimes || 4096;
    this.activeK = config.activeK || 32;
    this.temperature = config.temperature || 1.0;

    // Initialize per-head weights
    // Default: vary emphasis across heads
    this.headWeights = config.headWeights || this._defaultHeadWeights();

    // Output projection weights (learnable in training)
    this.outputScale = config.outputScale || 1.0 / Math.sqrt(this.numHeads);
  }

  /**
   * Generate default head weights with varying emphasis
   * @private
   */
  _defaultHeadWeights() {
    const weights = [];
    for (let h = 0; h < this.numHeads; h++) {
      const t = h / (this.numHeads - 1 || 1);

      // Interpolate between different emphasis patterns
      // Head 0: Prime overlap focus
      // Head numHeads-1: Phase coherence focus
      const alpha = 0.5 - 0.3 * t; // Jaccard weight
      const beta = 0.3; // Quaternion (constant)
      const gamma = 0.2 + 0.3 * t; // Phase weight

      weights.push([alpha, beta, gamma]);
    }
    return weights;
  }

  /**
   * Apply multi-head attention
   *
   * @param {SparsePrimeState} query - Query state
   * @param {SparsePrimeState[]} keys - Key states
   * @param {SparsePrimeState[]} values - Value states
   * @returns {object} { result, headOutputs, attentionWeights }
   */
  forward(query, keys, values) {
    const headOutputs = [];
    const allWeights = [];

    // Apply each head
    for (let h = 0; h < this.numHeads; h++) {
      const [alpha, beta, gamma] = this.headWeights[h];

      // Compute head-specific attention
      const headResult = this._headAttention(
        query, keys, values, alpha, beta, gamma
      );

      headOutputs.push(headResult.result);
      allWeights.push(headResult.weights);
    }

    // Combine head outputs
    const combined = this._combineHeads(headOutputs);

    return {
      result: combined,
      headOutputs,
      attentionWeights: allWeights
    };
  }

  /**
   * Single head attention with custom weights
   * @private
   */
  _headAttention(query, keys, values, alpha, beta, gamma) {
    const n = keys.length;
    if (n === 0) {
      return { result: query, weights: [], scores: [] };
    }

    // Compute resonance scores with head-specific weights
    const primesQ = new Set(query.getActivePrimes());
    const scores = keys.map(k => {
      const primesK = new Set(k.getActivePrimes());

      // Jaccard
      const intersection = new Set([...primesQ].filter(p => primesK.has(p)));
      const union = new Set([...primesQ, ...primesK]);
      const jaccard = intersection.size / (union.size || 1);

      if (intersection.size === 0) {
        return alpha * jaccard;
      }

      // Quaternion alignment
      let quatSum = 0;
      for (const p of intersection) {
        const qi = query.get(p).quaternion;
        const qk = k.get(p).quaternion;
        quatSum += Math.abs(qi.dot(qk));
      }
      const quatAlign = quatSum / intersection.size;

      // Phase coherence
      let phaseSum = 0;
      for (const p of intersection) {
        const phaseQ = query.get(p).amplitude.phase();
        const phaseK = k.get(p).amplitude.phase();
        phaseSum += Math.cos(phaseQ - phaseK);
      }
      const phaseCoherence = (phaseSum / intersection.size + 1) / 2;

      return alpha * jaccard + beta * quatAlign + gamma * phaseCoherence;
    });

    // Softmax
    const maxScore = Math.max(...scores);
    const expScores = scores.map(s => Math.exp((s - maxScore) / this.temperature));
    const sumExp = expScores.reduce((a, b) => a + b, 0);
    const weights = expScores.map(e => e / sumExp);

    // Weighted combination
    const result = new SparsePrimeState(this.numPrimes, this.activeK);

    for (let i = 0; i < n; i++) {
      const w = weights[i];
      for (const [p, act] of values[i].activations) {
        const current = result.get(p);
        const newAmp = current.amplitude.add(act.amplitude.scale(w));
        const newQuat = current.quaternion.add(act.quaternion.scale(w));
        result.set(p, newAmp, newQuat.normalize());
      }
    }

    return { result: result.normalize(), weights, scores };
  }

  /**
   * Combine head outputs
   * @private
   */
  _combineHeads(headOutputs) {
    const result = new SparsePrimeState(this.numPrimes, this.activeK);

    for (const headOut of headOutputs) {
      for (const [p, act] of headOut.activations) {
        const current = result.get(p);
        const newAmp = current.amplitude.add(act.amplitude.scale(this.outputScale));
        const newQuat = current.quaternion.add(act.quaternion.scale(this.outputScale));
        result.set(p, newAmp, newQuat.normalize());
      }
    }

    return result.normalize();
  }

  /**
   * Set head weights (for training)
   */
  setHeadWeights(headIdx, weights) {
    if (headIdx >= 0 && headIdx < this.numHeads) {
      this.headWeights[headIdx] = weights;
    }
  }

  /**
   * Get all parameters (for serialization)
   */
  getParameters() {
    return {
      numHeads: this.numHeads,
      headWeights: this.headWeights,
      temperature: this.temperature,
      outputScale: this.outputScale
    };
  }
}
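
// --- Editorial usage sketch (illustration only; not part of the published file) ---
// Minimal example of driving ResonantMultiHeadAttention directly. It assumes only
// the SparsePrimeState / Complex / Quaternion APIs already used above (set,
// normalize, fromPolar, fromAxisAngle); makeState is a hypothetical helper and the
// chosen primes/phases are arbitrary.
const makeState = (primes, phaseStep) => {
  const s = new SparsePrimeState(4096, 32);
  primes.forEach((p, i) => {
    s.set(
      p,
      Complex.fromPolar(1, i * phaseStep),
      Quaternion.fromAxisAngle([0, 0, 1], i * phaseStep).normalize()
    );
  });
  return s.normalize();
};

const mha = new ResonantMultiHeadAttention({ numHeads: 4, temperature: 0.8 });
const query = makeState([2, 3, 5, 7], 0.3);
const keys = [makeState([2, 3, 11], 0.2), makeState([5, 7, 13], 0.5)];
const attn = mha.forward(query, keys, keys);
// attn.result is the pooled state; attn.attentionWeights[h][i] is the softmax
// weight that head h assigns to keys[i].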

/**
 * PrimeFFN - Prime-Indexed Feed-Forward Network
 *
 * Two-layer MLP that operates on sparse prime activations:
 * FFN(x) = activation(x·W1 + b1)·W2 + b2
 *
 * Maintains sparsity by only operating on active primes.
 */
class PrimeFFN {
  /**
   * @param {object} config
   * @param {number} [config.hiddenDim=256] - Hidden layer dimension
   * @param {number} [config.numPrimes=4096] - Prime vocabulary size
   * @param {string} [config.activation='relu'] - Activation function
   * @param {number} [config.dropout=0.0] - Dropout probability
   */
  constructor(config) {
    this.hiddenDim = config.hiddenDim || 256;
    this.numPrimes = config.numPrimes || 4096;
    this.activation = config.activation || 'relu';
    this.dropout = config.dropout || 0.0;
    this.training = false;

    // Initialize weights (simplified: diagonal + bias)
    // In full implementation, these would be learned
    this.w1Scale = config.w1Scale || 2.0;
    this.w2Scale = config.w2Scale || 0.5;
    this.bias1 = config.bias1 || 0.1;
    this.bias2 = config.bias2 || 0.0;
  }

  /**
   * Apply feed-forward network
   * @param {SparsePrimeState} x - Input state
   * @returns {SparsePrimeState} Output state
   */
  forward(x) {
    const result = new SparsePrimeState(this.numPrimes, x.k);

    for (const [p, act] of x.activations) {
      // First layer: scale + bias
      let hidden = act.amplitude.scale(this.w1Scale);
      hidden = hidden.add(new Complex(this.bias1, 0));

      // Activation
      hidden = this._activate(hidden);

      // Dropout (training only)
      if (this.training && this.dropout > 0 && Math.random() < this.dropout) {
        continue;
      }

      // Second layer
      const output = hidden.scale(this.w2Scale);
      const finalAmp = output.add(new Complex(this.bias2, 0));

      // Quaternion passes through with slight rotation
      const quatRotation = Quaternion.fromAxisAngle(
        [1, 0, 0],
        0.1 * act.amplitude.phase()
      );
      const newQuat = act.quaternion.mul(quatRotation);

      result.set(p, finalAmp, newQuat.normalize());
    }

    return result.normalize();
  }

  /**
   * Apply activation function
   * @private
   */
  _activate(c) {
    switch (this.activation) {
      case 'relu':
        return new Complex(Math.max(0, c.re), Math.max(0, c.im));

      case 'gelu': {
        // GELU approximation
        const x = c.re;
        const gelu = 0.5 * x * (1 + Math.tanh(Math.sqrt(2 / Math.PI) * (x + 0.044715 * x * x * x)));
        return new Complex(gelu, c.im * (c.im > 0 ? 1 : 0));
      }

      case 'swish': {
        const sigmoid = (v) => 1 / (1 + Math.exp(-v));
        return new Complex(c.re * sigmoid(c.re), c.im * sigmoid(c.im));
      }

      case 'tanh':
        return new Complex(Math.tanh(c.re), Math.tanh(c.im));

      default:
        return c;
    }
  }

  /**
   * Set training mode (enables dropout)
   */
  train(mode = true) {
    this.training = mode;
    return this;
  }

  /**
   * Set evaluation mode (disables dropout)
   */
  eval() {
    return this.train(false);
  }
}
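
// --- Editorial usage sketch (illustration only; not part of the published file) ---
// Continues the sketch above: runs a state through the feed-forward sub-layer.
// Dropout is only applied in training mode (train()), so eval() is used here.
const ffn = new PrimeFFN({ activation: 'gelu', dropout: 0.1 });
ffn.eval();
const transformed = ffn.forward(query);
// Amplitudes follow scale -> bias -> GELU -> scale -> bias; each quaternion gets
// a small phase-dependent rotation about the x-axis.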

/**
 * PrimeLayerNorm - Layer Normalization for Sparse Prime States
 *
 * Normalizes activations while preserving prime structure.
 * Computes mean and variance statistics over the active prime amplitudes.
 */
class PrimeLayerNorm {
  /**
   * @param {object} config
   * @param {number} [config.eps=1e-6] - Epsilon for numerical stability
   * @param {boolean} [config.elementwiseAffine=true] - Learn gamma/beta
   */
  constructor(config = {}) {
    this.eps = config.eps || 1e-6;
    this.elementwiseAffine = config.elementwiseAffine ?? true;

    // Learnable parameters (simplified: scalar)
    this.gamma = config.gamma || 1.0;
    this.beta = config.beta || 0.0;
  }

  /**
   * Apply layer normalization
   * @param {SparsePrimeState} x - Input state
   * @returns {SparsePrimeState} Normalized state
   */
  forward(x) {
    const activePrimes = x.getActivePrimes();
    if (activePrimes.length === 0) return x;

    // Compute mean amplitude
    let sum = 0;
    let count = 0;
    for (const p of activePrimes) {
      sum += x.get(p).amplitude.norm();
      count++;
    }
    const mean = sum / count;

    // Compute variance
    let varSum = 0;
    for (const p of activePrimes) {
      const diff = x.get(p).amplitude.norm() - mean;
      varSum += diff * diff;
    }
    const variance = varSum / count;
    const std = Math.sqrt(variance + this.eps);

    // Normalize
    const result = new SparsePrimeState(x.allPrimes.length, x.k);

    for (const p of activePrimes) {
      const act = x.get(p);
      const normAmp = act.amplitude.norm();
      const normalizedNorm = (normAmp - mean) / std;

      // Apply affine transform
      const scaledNorm = this.gamma * normalizedNorm + this.beta;

      // Preserve phase, scale magnitude
      const phase = act.amplitude.phase();
      const newAmp = Complex.fromPolar(Math.max(0, scaledNorm), phase);

      result.set(p, newAmp, act.quaternion);
    }

    return result.normalize();
  }

  /**
   * Get parameters
   */
  getParameters() {
    return { gamma: this.gamma, beta: this.beta };
  }

  /**
   * Set parameters
   */
  setParameters(params) {
    if (params.gamma !== undefined) this.gamma = params.gamma;
    if (params.beta !== undefined) this.beta = params.beta;
  }
}
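
// --- Editorial usage sketch (illustration only; not part of the published file) ---
// Continues the sketch above. PrimeLayerNorm standardizes the magnitudes of the
// active amplitudes, emitting gamma * (|a_p| - mean) / sqrt(var + eps) + beta,
// keeping each prime's phase and quaternion, then renormalizes the state.
const norm = new PrimeLayerNorm({ gamma: 1.0, beta: 0.0 });
const normalized = norm.forward(transformed);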

/**
 * PositionalPrimeEncoding
 *
 * Encodes position information using prime-based phases.
 * Each position activates a unique combination of primes with position-dependent phases.
 */
class PositionalPrimeEncoding {
  /**
   * @param {object} config
   * @param {number} [config.maxLength=512] - Maximum sequence length
   * @param {number} [config.numPrimes=4096] - Prime vocabulary size
   * @param {number} [config.activeK=32] - Sparsity per position
   * @param {string} [config.type='sinusoidal'] - Encoding type
   */
  constructor(config = {}) {
    this.maxLength = config.maxLength || 512;
    this.numPrimes = config.numPrimes || 4096;
    this.activeK = config.activeK || 32;
    this.type = config.type || 'sinusoidal';

    // Precompute position encodings
    this.encodings = this._precompute();
  }

  /**
   * Precompute position encodings
   * @private
   */
  _precompute() {
    const encodings = [];
    const primes = firstNPrimes(this.activeK);

    for (let pos = 0; pos < this.maxLength; pos++) {
      const state = new SparsePrimeState(this.numPrimes, this.activeK);

      for (let i = 0; i < primes.length; i++) {
        const p = primes[i];

        let phase;
        switch (this.type) {
          case 'sinusoidal': {
            // Classic transformer-style
            const freq = 1 / Math.pow(10000, i / primes.length);
            phase = pos * freq;
            break;
          }

          case 'prime':
            // Prime-based: period equal to the (i+1)-th prime
            phase = 2 * Math.PI * pos / nthPrime(i + 1);
            break;

          case 'golden': {
            // Golden ratio based
            const phi = (1 + Math.sqrt(5)) / 2;
            phase = 2 * Math.PI * pos * Math.pow(phi, -i);
            break;
          }

          default:
            phase = 2 * Math.PI * pos * i / primes.length;
        }

        const amplitude = Complex.fromPolar(1 / Math.sqrt(primes.length), phase);
        const quaternion = Quaternion.fromAxisAngle(
          [Math.sin(phase), Math.cos(phase), 0],
          phase / 2
        );

        state.set(p, amplitude, quaternion.normalize());
      }

      encodings.push(state.normalize());
    }

    return encodings;
  }

  /**
   * Get encoding for a position
   * @param {number} pos - Position index
   * @returns {SparsePrimeState} Position encoding
   */
  getEncoding(pos) {
    if (pos < 0) pos = 0;
    if (pos >= this.maxLength) pos = this.maxLength - 1;
    return this.encodings[pos];
  }

  /**
   * Add position encoding to a state
   * @param {SparsePrimeState} state - Input state
   * @param {number} pos - Position index
   * @returns {SparsePrimeState} State with position encoding added
   */
  encode(state, pos) {
    const posEnc = this.getEncoding(pos);
    return hamiltonCompose(state, posEnc);
  }

  /**
   * Encode a sequence of states
   * @param {SparsePrimeState[]} sequence - Input sequence
   * @returns {SparsePrimeState[]} Encoded sequence
   */
  encodeSequence(sequence) {
    return sequence.map((state, pos) => this.encode(state, pos));
  }
}
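
// --- Editorial usage sketch (illustration only; not part of the published file) ---
// Continues the sketch above (makeState is the hypothetical helper from the
// attention sketch). Encodings are precomputed per (maxLength, type): 'sinusoidal'
// mirrors classic transformer frequencies, 'prime' gives channel i a period equal
// to the (i+1)-th prime, and 'golden' scales phases by powers of 1/phi. A position
// encoding is merged into each token state via hamiltonCompose.
const posEnc = new PositionalPrimeEncoding({ maxLength: 64, type: 'prime' });
const sequence = [makeState([2, 3, 5], 0.1), makeState([7, 11], 0.2)];
const encoded = posEnc.encodeSequence(sequence); // one encoded state per position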

/**
 * ResoFormerBlock - Complete Transformer Block
 *
 * Combines:
 * - Multi-head resonant attention
 * - Feed-forward network
 * - Layer normalization
 * - Residual connections
 */
class ResoFormerBlock {
  /**
   * @param {object} config
   * @param {number} [config.numHeads=8] - Number of attention heads
   * @param {number} [config.hiddenDim=256] - FFN hidden dimension
   * @param {number} [config.numPrimes=4096] - Prime vocabulary size
   * @param {number} [config.activeK=32] - Sparsity parameter
   * @param {number} [config.dropout=0.1] - Dropout probability
   * @param {boolean} [config.preNorm=true] - Pre-norm or post-norm
   */
  constructor(config = {}) {
    this.preNorm = config.preNorm ?? true;
    this.numPrimes = config.numPrimes || 4096;
    this.activeK = config.activeK || 32;

    // Sub-layers
    this.attention = new ResonantMultiHeadAttention({
      numHeads: config.numHeads || 8,
      numPrimes: this.numPrimes,
      activeK: this.activeK,
      temperature: config.attentionTemperature || 1.0
    });

    this.ffn = new PrimeFFN({
      hiddenDim: config.hiddenDim || 256,
      numPrimes: this.numPrimes,
      activation: config.activation || 'gelu',
      dropout: config.dropout || 0.1
    });

    this.norm1 = new PrimeLayerNorm();
    this.norm2 = new PrimeLayerNorm();

    this.dropoutRate = config.dropout || 0.1;
    this.training = false;
  }

  /**
   * Forward pass through the block
   *
   * @param {SparsePrimeState} x - Input state
   * @param {SparsePrimeState[]} context - Context states for attention
   * @returns {object} { output, attentionWeights }
   */
  forward(x, context = null) {
    // Use self-attention if no context provided
    const keys = context || [x];
    const values = context || [x];

    if (this.preNorm) {
      // Pre-norm: Norm -> Attn -> Add -> Norm -> FFN -> Add
      const attnInput = this.norm1.forward(x);
      const attnOut = this.attention.forward(attnInput,
        keys.map(k => this.norm1.forward(k)),
        values.map(v => this.norm1.forward(v))
      );

      // Residual connection
      const afterAttn = this._add(x, this._dropout(attnOut.result));

      // FFN
      const ffnInput = this.norm2.forward(afterAttn);
      const ffnOut = this.ffn.forward(ffnInput);

      // Residual connection
      const output = this._add(afterAttn, this._dropout(ffnOut));

      return { output, attentionWeights: attnOut.attentionWeights };

    } else {
      // Post-norm: Attn -> Add -> Norm -> FFN -> Add -> Norm
      const attnOut = this.attention.forward(x, keys, values);
      const afterAttn = this.norm1.forward(this._add(x, this._dropout(attnOut.result)));

      const ffnOut = this.ffn.forward(afterAttn);
      const output = this.norm2.forward(this._add(afterAttn, this._dropout(ffnOut)));

      return { output, attentionWeights: attnOut.attentionWeights };
    }
  }

  /**
   * Add two sparse states (residual connection)
   * @private
   */
  _add(a, b) {
    const result = new SparsePrimeState(this.numPrimes, this.activeK);

    const allPrimes = new Set([...a.getActivePrimes(), ...b.getActivePrimes()]);

    for (const p of allPrimes) {
      const actA = a.get(p);
      const actB = b.get(p);

      const newAmp = actA.amplitude.add(actB.amplitude);
      const newQuat = actA.quaternion.add(actB.quaternion);

      result.set(p, newAmp, newQuat.normalize());
    }

    return result.normalize();
  }

  /**
   * Apply dropout
   * @private
   */
  _dropout(state) {
    if (!this.training || this.dropoutRate <= 0) return state;

    const result = new SparsePrimeState(this.numPrimes, this.activeK);
    const scale = 1 / (1 - this.dropoutRate);

    for (const [p, act] of state.activations) {
      if (Math.random() >= this.dropoutRate) {
        result.set(p, act.amplitude.scale(scale), act.quaternion);
      }
    }

    return result;
  }

  /**
   * Set training mode
   */
  train(mode = true) {
    this.training = mode;
    this.ffn.train(mode);
    return this;
  }

  /**
   * Set evaluation mode
   */
  eval() {
    return this.train(false);
  }
}
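
// --- Editorial usage sketch (illustration only; not part of the published file) ---
// Continues the sketch above: a single pre-norm block applied with the whole
// encoded sequence as self-attention context
// (norm -> attention -> residual -> norm -> FFN -> residual).
const block = new ResoFormerBlock({ numHeads: 4, dropout: 0.1, preNorm: true });
block.eval();
const blockOut = block.forward(encoded[0], encoded);
// blockOut.output is a normalized SparsePrimeState; blockOut.attentionWeights
// holds one weight vector per head over the provided context states.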

/**
 * ResoFormer - Complete Multi-Layer Model
 *
 * Stacks multiple ResoFormerBlocks with optional position encoding.
 */
class ResoFormer {
  /**
   * @param {object} config
   * @param {number} [config.numLayers=6] - Number of transformer blocks
   * @param {number} [config.numHeads=8] - Attention heads per block
   * @param {number} [config.hiddenDim=256] - FFN hidden dimension
   * @param {number} [config.numPrimes=4096] - Prime vocabulary size
   * @param {number} [config.activeK=32] - Sparsity parameter
   * @param {number} [config.dropout=0.1] - Dropout probability
   * @param {boolean} [config.usePositionalEncoding=true] - Add position encoding
   */
  constructor(config = {}) {
    this.numLayers = config.numLayers || 6;
    this.numPrimes = config.numPrimes || 4096;
    this.activeK = config.activeK || 32;

    // Position encoding
    this.usePositionalEncoding = config.usePositionalEncoding ?? true;
    if (this.usePositionalEncoding) {
      this.posEncoder = new PositionalPrimeEncoding({
        numPrimes: this.numPrimes,
        activeK: this.activeK
      });
    }

    // Stack of transformer blocks
    this.blocks = [];
    for (let i = 0; i < this.numLayers; i++) {
      this.blocks.push(new ResoFormerBlock({
        numHeads: config.numHeads || 8,
        hiddenDim: config.hiddenDim || 256,
        numPrimes: this.numPrimes,
        activeK: this.activeK,
        dropout: config.dropout || 0.1,
        preNorm: config.preNorm ?? true
      }));
    }

    // Final normalization
    this.finalNorm = new PrimeLayerNorm();
  }

  /**
   * Forward pass through all layers
   *
   * @param {SparsePrimeState|SparsePrimeState[]} input - Input state(s)
   * @returns {object} { output, layerOutputs, attentionMaps }
   */
  forward(input) {
    // Handle single input or sequence
    const isSequence = Array.isArray(input);
    let states = isSequence ? input : [input];

    // Add position encoding
    if (this.usePositionalEncoding) {
      states = this.posEncoder.encodeSequence(states);
    }

    const layerOutputs = [];
    const attentionMaps = [];

    // Process through each block
    // For simplicity, process each state with all others as context
    for (let layer = 0; layer < this.numLayers; layer++) {
      const block = this.blocks[layer];
      const newStates = [];
      const layerAttention = [];

      for (let i = 0; i < states.length; i++) {
        const { output, attentionWeights } = block.forward(states[i], states);
        newStates.push(output);
        layerAttention.push(attentionWeights);
      }

      states = newStates;
      layerOutputs.push([...states]);
      attentionMaps.push(layerAttention);
    }

    // Final normalization
    states = states.map(s => this.finalNorm.forward(s));

    return {
      output: isSequence ? states : states[0],
      layerOutputs,
      attentionMaps
    };
  }

  /**
   * Set training mode
   */
  train(mode = true) {
    for (const block of this.blocks) {
      block.train(mode);
    }
    return this;
  }

  /**
   * Set evaluation mode
   */
  eval() {
    return this.train(false);
  }

  /**
   * Get total parameter count (approximate)
   */
  getParameterCount() {
    // Simplified count
    const perBlock =
      8 * 3 + // Attention head weights
      4 +     // FFN weights
      2;      // LayerNorm
    return this.numLayers * perBlock + (this.usePositionalEncoding ? this.activeK * 4 : 0);
  }
}
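
// --- Editorial usage sketch (illustration only; not part of the published file) ---
// Continues the sketch above: a small end-to-end model. Position encoding is
// applied internally, every state attends over the full sequence in each block,
// and a final layer norm is applied to the outputs.
const model = new ResoFormer({ numLayers: 2, numHeads: 4, dropout: 0.1 });
model.eval();
const run = model.forward(sequence);
// run.output is an array of SparsePrimeStates (one per input position);
// run.layerOutputs and run.attentionMaps expose per-layer intermediates.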

module.exports = {
  ResonantMultiHeadAttention,
  PrimeFFN,
  PrimeLayerNorm,
  PositionalPrimeEncoding,
  ResoFormerBlock,
  ResoFormer
};