@aleph-ai/tinyaleph 1.3.0 → 1.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,892 @@
1
+ /**
2
+ * CRT-Enhanced ResoFormer Layers
3
+ *
4
+ * Integrates Chinese Remainder Theorem reconstruction with Birkhoff polytope
5
+ * constraints and homology-based regularization into the ResoFormer architecture.
6
+ *
7
+ * Key additions:
8
+ * - CRTResonantAttention: Modular attention with per-modulus Birkhoff projection
9
+ * - HomologyRegularizedBlock: Detects semantic inconsistencies as topological holes
10
+ * - CRTResoFormer: Full model with homology loss integration
11
+ *
12
+ * Mathematical Foundation:
13
+ * - Encode: r_k = softmax(W_k h + b_k) ∈ Δ(ℤ/p_k) for coprime moduli p_k
14
+ * - Reconstruct: L̂ = Σ_k E[r_k] · (P/p_k) · (P/p_k)^{-1} mod P via CRT
15
+ * - Project: A_k ∈ Birkhoff(n) via Sinkhorn-Knopp
16
+ * - Regularize: ℒ_homology = Σ_{cycles} f(cycle) detects Ker(ℛ) holes
17
+ */
18
+
19
+ 'use strict';
20
+
21
+ const {
22
+ Quaternion,
23
+ SparsePrimeState,
24
+ resonanceScore,
25
+ resonantAttention,
26
+ hamiltonCompose,
27
+ computeCoherence
28
+ } = require('./rformer');
29
+
30
+ const {
31
+ ResonantMultiHeadAttention,
32
+ PrimeFFN,
33
+ PrimeLayerNorm,
34
+ PositionalPrimeEncoding,
35
+ ResoFormerBlock
36
+ } = require('./rformer-layers');
37
+
38
+ const {
39
+ CRTReconstructor,
40
+ BirkhoffProjector,
41
+ HomologyLoss,
42
+ CRTModularLayer,
43
+ CRTFusedAttention,
44
+ CoprimeSelector,
45
+ ResidueEncoder
46
+ } = require('./crt-homology');
47
+
48
+ const { Complex, PrimeState } = require('./hilbert');
49
+ const { firstNPrimes, isPrime } = require('./prime');
50
+
51
+ /**
52
+ * CRTResonantAttention - Multi-head attention with CRT-fused modular structure
53
+ *
54
+ * Each head computes attention in a different modular field, then
55
+ * fuses via CRT reconstruction. Attention matrices are projected
56
+ * onto the Birkhoff polytope for doubly-stochastic structure.
57
+ */
58
+ class CRTResonantAttention {
59
+ /**
60
+ * @param {object} config
61
+ * @param {number} config.numHeads - Number of attention heads (maps to moduli)
62
+ * @param {number} [config.numPrimes=4096] - Size of prime vocabulary
63
+ * @param {number} [config.activeK=32] - Sparsity per state
64
+ * @param {number} [config.sinkhornIterations=10] - Iterations for Birkhoff projection
65
+ * @param {number} [config.temperature=1.0] - Softmax temperature
66
+ */
67
+ constructor(config) {
68
+ this.numHeads = config.numHeads || 8;
69
+ this.numPrimes = config.numPrimes || 4096;
70
+ this.activeK = config.activeK || 32;
71
+ this.temperature = config.temperature || 1.0;
72
+
73
+ // Select coprime moduli for CRT reconstruction
74
+ this.coprimeSelector = new CoprimeSelector(this.numHeads);
75
+ this.moduli = this.coprimeSelector.selectMinimal();
76
+
77
+ // CRT components
78
+ this.crtReconstructor = new CRTReconstructor(this.moduli);
79
+ this.birkhoffProjector = new BirkhoffProjector({
80
+ maxIterations: config.sinkhornIterations || 10,
81
+ tolerance: config.birkhoffTolerance || 1e-3
82
+ });
83
+
84
+ // Per-head weights [alpha, beta, gamma] for resonance scoring
85
+ this.headWeights = config.headWeights || this._defaultHeadWeights();
86
+
87
+ // Output scaling
88
+ this.outputScale = config.outputScale || 1.0 / Math.sqrt(this.numHeads);
89
+ }
90
+
91
+ /**
92
+ * Generate default head weights with moduli-specific emphasis
93
+ * @private
94
+ */
95
+ _defaultHeadWeights() {
96
+ const weights = [];
97
+ for (let h = 0; h < this.numHeads; h++) {
98
+ const modulus = this.moduli[h];
99
+ // Weight configuration varies with modulus size
100
+ // Larger moduli get more phase weight, smaller get more Jaccard
101
+ const t = Math.log(modulus) / Math.log(this.moduli[this.numHeads - 1]);
102
+
103
+ const alpha = 0.5 - 0.2 * t; // Jaccard weight
104
+ const beta = 0.3; // Quaternion (constant)
105
+ const gamma = 0.2 + 0.2 * t; // Phase weight
106
+
107
+ weights.push([alpha, beta, gamma]);
108
+ }
109
+ return weights;
110
+ }
111
+
112
+ /**
113
+ * Apply CRT-fused multi-head attention
114
+ *
115
+ * @param {SparsePrimeState} query - Query state
116
+ * @param {SparsePrimeState[]} keys - Key states
117
+ * @param {SparsePrimeState[]} values - Value states
118
+ * @returns {object} { result, headOutputs, attentionWeights, crtResidues, homologyInfo }
119
+ */
120
+ forward(query, keys, values) {
121
+ const n = keys.length;
122
+ if (n === 0) {
123
+ return {
124
+ result: query,
125
+ headOutputs: [],
126
+ attentionWeights: [],
127
+ crtResidues: [],
128
+ homologyInfo: { hasHoles: false, bettiNumbers: [1, 0] }
129
+ };
130
+ }
131
+
132
+ const headOutputs = [];
133
+ const allWeights = [];
134
+ const crtResidues = [];
135
+
136
+ // Apply each head (modular attention)
137
+ for (let h = 0; h < this.numHeads; h++) {
138
+ const modulus = this.moduli[h];
139
+ const [alpha, beta, gamma] = this.headWeights[h];
140
+
141
+ // Compute head-specific attention with modular structure
142
+ const headResult = this._modularHeadAttention(
143
+ query, keys, values, modulus, alpha, beta, gamma
144
+ );
145
+
146
+ headOutputs.push(headResult.result);
147
+ allWeights.push(headResult.birkhoffWeights);
148
+ crtResidues.push(headResult.residue);
149
+ }
150
+
151
+ // Fuse head outputs via CRT reconstruction
152
+ const fusedResult = this._crtFuseHeads(headOutputs);
153
+
154
+ // Detect homology holes (cycles in kernel)
155
+ const homologyInfo = this._detectHomologyHoles(crtResidues, allWeights);
156
+
157
+ return {
158
+ result: fusedResult,
159
+ headOutputs,
160
+ attentionWeights: allWeights,
161
+ crtResidues,
162
+ homologyInfo
163
+ };
164
+ }
165
+
166
+ /**
167
+ * Single head attention with modular structure and Birkhoff projection
168
+ * @private
169
+ */
170
+ _modularHeadAttention(query, keys, values, modulus, alpha, beta, gamma) {
171
+ const n = keys.length;
172
+
173
+ // Compute resonance scores
174
+ const scores = keys.map((k, idx) => {
175
+ const primesQ = new Set(query.getActivePrimes());
176
+ const primesK = new Set(k.getActivePrimes());
177
+
178
+ // Jaccard
179
+ const intersection = new Set([...primesQ].filter(p => primesK.has(p)));
180
+ const union = new Set([...primesQ, ...primesK]);
181
+ const jaccard = intersection.size / (union.size || 1);
182
+
183
+ if (intersection.size === 0) {
184
+ return alpha * jaccard;
185
+ }
186
+
187
+ // Quaternion alignment
188
+ let quatSum = 0;
189
+ for (const p of intersection) {
190
+ const qi = query.get(p).quaternion;
191
+ const qk = k.get(p).quaternion;
192
+ quatSum += Math.abs(qi.dot(qk));
193
+ }
194
+ const quatAlign = quatSum / intersection.size;
195
+
196
+ // Phase coherence (modular)
197
+ let phaseSum = 0;
198
+ for (const p of intersection) {
199
+ const phaseQ = query.get(p).amplitude.phase();
200
+ const phaseK = k.get(p).amplitude.phase();
201
+ // Modular phase difference
202
+ const phaseDiff = ((phaseQ - phaseK) % (2 * Math.PI / modulus)) * modulus;
203
+ phaseSum += Math.cos(phaseDiff);
204
+ }
205
+ const phaseCoherence = (phaseSum / intersection.size + 1) / 2;
206
+
207
+ return alpha * jaccard + beta * quatAlign + gamma * phaseCoherence;
208
+ });
209
+
210
+ // Create attention matrix (single row for query)
211
+ // Expand to n×n by computing pairwise scores among keys
212
+ const attentionMatrix = [];
213
+ for (let i = 0; i < n; i++) {
214
+ const row = keys.map((k, j) => resonanceScore(keys[i], k, alpha, beta, gamma));
215
+ // Apply modular structure
216
+ for (let j = 0; j < n; j++) {
217
+ row[j] = row[j] % 1.0; // Keep in [0, 1) range
218
+ }
219
+ attentionMatrix.push(row);
220
+ }
221
+
222
+ // Project onto Birkhoff polytope (doubly-stochastic)
223
+ const birkhoffResult = this.birkhoffProjector.project(attentionMatrix);
224
+ const birkhoffMatrix = birkhoffResult && birkhoffResult.matrix ? birkhoffResult.matrix : null;
225
+
226
+ // Extract query-specific weights (first row after projection)
227
+ // For proper attention, use the scores projected through Birkhoff
228
+ const softmaxWeights = this._softmax(scores.map(s => s / this.temperature));
229
+
230
+ // Blend Birkhoff structure with softmax attention
231
+ const blendFactor = 0.3; // How much Birkhoff structure to incorporate
232
+ const blendedWeights = softmaxWeights.map((w, i) => {
233
+ // Safely access Birkhoff weights with full null checking
234
+ if (birkhoffMatrix && Array.isArray(birkhoffMatrix) && birkhoffMatrix.length > 0) {
235
+ const birkhoffRow = birkhoffMatrix[0];
236
+ if (Array.isArray(birkhoffRow) && i < birkhoffRow.length && typeof birkhoffRow[i] === 'number') {
237
+ return (1 - blendFactor) * w + blendFactor * birkhoffRow[i];
238
+ }
239
+ }
240
+ return w; // Fall back to pure softmax if Birkhoff unavailable
241
+ });
242
+
243
+ // Normalize
244
+ const weightSum = blendedWeights.reduce((a, b) => a + b, 0);
245
+ const normalizedWeights = blendedWeights.map(w => w / (weightSum || 1));
246
+
247
+ // Compute residue for CRT reconstruction
248
+ // Residue = weighted sum of value indices mod modulus
249
+ let residue = 0;
250
+ for (let i = 0; i < n; i++) {
251
+ residue = (residue + normalizedWeights[i] * i) % modulus;
252
+ }
253
+
254
+ // Weighted combination of values
255
+ const result = new SparsePrimeState(this.numPrimes, this.activeK);
256
+
257
+ for (let i = 0; i < n; i++) {
258
+ const w = normalizedWeights[i];
259
+ for (const [p, act] of values[i].activations) {
260
+ const current = result.get(p);
261
+ const newAmp = current.amplitude.add(act.amplitude.scale(w));
262
+ const newQuat = current.quaternion.add(act.quaternion.scale(w));
263
+ result.set(p, newAmp, newQuat.normalize());
264
+ }
265
+ }
266
+
267
+ return {
268
+ result: result.normalize(),
269
+ birkhoffWeights: normalizedWeights,
270
+ residue,
271
+ convergenceInfo: birkhoffResult.convergenceInfo
272
+ };
273
+ }
274
+
275
+ /**
276
+ * Softmax function
277
+ * @private
278
+ */
279
+ _softmax(scores) {
280
+ const maxScore = Math.max(...scores);
281
+ const expScores = scores.map(s => Math.exp(s - maxScore));
282
+ const sumExp = expScores.reduce((a, b) => a + b, 0);
283
+ return expScores.map(e => e / (sumExp || 1));
284
+ }
285
+
286
+ /**
287
+ * Fuse head outputs via CRT reconstruction
288
+ * @private
289
+ */
290
+ _crtFuseHeads(headOutputs) {
291
+ const result = new SparsePrimeState(this.numPrimes, this.activeK);
292
+
293
+ // Collect all active primes across heads
294
+ const allPrimes = new Set();
295
+ for (const head of headOutputs) {
296
+ for (const p of head.getActivePrimes()) {
297
+ allPrimes.add(p);
298
+ }
299
+ }
300
+
301
+ // For each prime, fuse activations via CRT-style combination
302
+ for (const p of allPrimes) {
303
+ // Collect residues (amplitude norms) from each head
304
+ const residues = headOutputs.map(h => h.get(p).amplitude.norm());
305
+
306
+ // Compute reconstructed norm using weighted average (simpler than full CRT)
307
+ // This avoids issues with probability distributions of varying sizes
308
+ let reconstructedNorm = 0;
309
+ let totalWeight = 0;
310
+ for (let h = 0; h < headOutputs.length; h++) {
311
+ const w = 1 / (h + 1); // Decaying weights by head index
312
+ reconstructedNorm += residues[h] * w;
313
+ totalWeight += w;
314
+ }
315
+ reconstructedNorm = totalWeight > 0 ? reconstructedNorm / totalWeight : 0;
316
+
317
+ // Fuse phases and quaternions via weighted average
318
+ let phaseSum = 0;
319
+ let quatSum = Quaternion.zero();
320
+ let weightSum = 0;
321
+
322
+ for (let h = 0; h < headOutputs.length; h++) {
323
+ const act = headOutputs[h].get(p);
324
+ const w = act.amplitude.norm();
325
+ phaseSum += w * act.amplitude.phase();
326
+ quatSum = quatSum.add(act.quaternion.scale(w));
327
+ weightSum += w;
328
+ }
329
+
330
+ const avgPhase = weightSum > 0 ? phaseSum / weightSum : 0;
331
+ const avgQuat = weightSum > 0 ? quatSum.scale(1/weightSum).normalize() : Quaternion.one();
332
+
333
+ // Scale by output factor
334
+ const finalAmp = Complex.fromPolar(
335
+ reconstructedNorm * this.outputScale / this.numHeads,
336
+ avgPhase
337
+ );
338
+
339
+ result.set(p, finalAmp, avgQuat);
340
+ }
341
+
342
+ return result.normalize();
343
+ }
344
+
345
+ /**
346
+ * Detect homology holes in the kernel
347
+ * @private
348
+ */
349
+ _detectHomologyHoles(residues, weights) {
350
+ // Build adjacency from attention weights
351
+ const n = weights[0]?.length || 0;
352
+ if (n < 2) {
353
+ return { hasHoles: false, bettiNumbers: [1, 0], cycles: [] };
354
+ }
355
+
356
+ // Create averaged adjacency matrix
357
+ const adjacency = [];
358
+ for (let i = 0; i < n; i++) {
359
+ adjacency.push(new Array(n).fill(0));
360
+ }
361
+
362
+ for (const headWeights of weights) {
363
+ for (let i = 0; i < Math.min(headWeights.length, n); i++) {
364
+ for (let j = 0; j < Math.min(headWeights.length, n); j++) {
365
+ adjacency[i][j] += headWeights[i] * headWeights[j] / weights.length;
366
+ }
367
+ }
368
+ }
369
+
370
+ // Threshold adjacency to binary
371
+ const threshold = 0.1;
372
+ for (let i = 0; i < n; i++) {
373
+ for (let j = 0; j < n; j++) {
374
+ adjacency[i][j] = adjacency[i][j] > threshold ? 1 : 0;
375
+ }
376
+ }
377
+
378
+ // Compute error terms (CRT inconsistencies)
379
+ const errors = [];
380
+ for (let i = 0; i < residues.length; i++) {
381
+ // Error = deviation from expected value
382
+ errors.push(Math.abs(residues[i] - Math.floor(residues[i])));
383
+ }
384
+
385
+ // Build kernel (high-error nodes)
386
+ const errorThreshold = 0.1;
387
+ const kernel = errors.map((e, i) => ({ index: i, error: e, inKernel: e > errorThreshold }));
388
+ const kernelNodes = kernel.filter(k => k.inKernel).map(k => k.index);
389
+
390
+ // Detect cycles in kernel subgraph
391
+ const cycles = this._findCyclesInSubgraph(adjacency, kernelNodes);
392
+
393
+ // Compute Betti numbers
394
+ const beta0 = this._countConnectedComponents(adjacency, kernelNodes);
395
+ const beta1 = cycles.length;
396
+
397
+ return {
398
+ hasHoles: beta1 > 0,
399
+ bettiNumbers: [beta0, beta1],
400
+ cycles,
401
+ kernelNodes,
402
+ errors
403
+ };
404
+ }
405
+
406
+ /**
407
+ * Find cycles in a subgraph
408
+ * @private
409
+ */
410
+ _findCyclesInSubgraph(adjacency, nodes) {
411
+ const cycles = [];
412
+ const n = nodes.length;
413
+
414
+ if (n < 3) return cycles;
415
+
416
+ // Simple cycle detection: look for triangles
417
+ for (let i = 0; i < n; i++) {
418
+ for (let j = i + 1; j < n; j++) {
419
+ for (let k = j + 1; k < n; k++) {
420
+ const a = nodes[i], b = nodes[j], c = nodes[k];
421
+ if (a < adjacency.length && b < adjacency.length && c < adjacency.length) {
422
+ if (adjacency[a][b] && adjacency[b][c] && adjacency[c][a]) {
423
+ cycles.push([a, b, c]);
424
+ }
425
+ }
426
+ }
427
+ }
428
+ }
429
+
430
+ return cycles;
431
+ }
432
+
433
+ /**
434
+ * Count connected components in subgraph
435
+ * @private
436
+ */
437
+ _countConnectedComponents(adjacency, nodes) {
438
+ if (nodes.length === 0) return 0;
439
+
440
+ const visited = new Set();
441
+ let components = 0;
442
+
443
+ const dfs = (node) => {
444
+ visited.add(node);
445
+ for (const neighbor of nodes) {
446
+ if (!visited.has(neighbor) &&
447
+ node < adjacency.length && neighbor < adjacency.length &&
448
+ adjacency[node][neighbor]) {
449
+ dfs(neighbor);
450
+ }
451
+ }
452
+ };
453
+
454
+ for (const node of nodes) {
455
+ if (!visited.has(node)) {
456
+ dfs(node);
457
+ components++;
458
+ }
459
+ }
460
+
461
+ return components;
462
+ }
463
+
464
+ /**
465
+ * Set head weights (for training)
466
+ */
467
+ setHeadWeights(headIdx, weights) {
468
+ if (headIdx >= 0 && headIdx < this.numHeads) {
469
+ this.headWeights[headIdx] = weights;
470
+ }
471
+ }
472
+
473
+ /**
474
+ * Get all parameters (for serialization)
475
+ */
476
+ getParameters() {
477
+ return {
478
+ numHeads: this.numHeads,
479
+ moduli: this.moduli,
480
+ headWeights: this.headWeights,
481
+ temperature: this.temperature,
482
+ outputScale: this.outputScale
483
+ };
484
+ }
485
+ }
486
+
487
+ /**
488
+ * HomologyRegularizedBlock - ResoFormer block with homology-based regularization
489
+ *
490
+ * Extends ResoFormerBlock with:
491
+ * - CRT-fused attention
492
+ * - Homology loss computation
493
+ * - Kernel detection for semantic inconsistencies
494
+ */
495
+ class HomologyRegularizedBlock {
496
+ /**
497
+ * @param {object} config
498
+ * @param {number} [config.numHeads=8] - Number of attention heads
499
+ * @param {number} [config.hiddenDim=256] - FFN hidden dimension
500
+ * @param {number} [config.numPrimes=4096] - Prime vocabulary size
501
+ * @param {number} [config.activeK=32] - Sparsity parameter
502
+ * @param {number} [config.dropout=0.1] - Dropout probability
503
+ * @param {boolean} [config.preNorm=true] - Pre-norm or post-norm
504
+ * @param {number} [config.homologyWeight=0.1] - Weight for homology loss
505
+ */
506
+ constructor(config = {}) {
507
+ this.preNorm = config.preNorm ?? true;
508
+ this.numPrimes = config.numPrimes || 4096;
509
+ this.activeK = config.activeK || 32;
510
+ this.homologyWeight = config.homologyWeight || 0.1;
511
+
512
+ // CRT-enhanced attention
513
+ this.attention = new CRTResonantAttention({
514
+ numHeads: config.numHeads || 8,
515
+ numPrimes: this.numPrimes,
516
+ activeK: this.activeK,
517
+ temperature: config.attentionTemperature || 1.0,
518
+ sinkhornIterations: config.sinkhornIterations || 10
519
+ });
520
+
521
+ // FFN and norms from standard block
522
+ this.ffn = new PrimeFFN({
523
+ hiddenDim: config.hiddenDim || 256,
524
+ numPrimes: this.numPrimes,
525
+ activation: config.activation || 'gelu',
526
+ dropout: config.dropout || 0.1
527
+ });
528
+
529
+ this.norm1 = new PrimeLayerNorm();
530
+ this.norm2 = new PrimeLayerNorm();
531
+
532
+ // Homology loss computer
533
+ this.homologyLoss = new HomologyLoss({
534
+ errorThreshold: config.errorThreshold || 0.1,
535
+ alpha: config.homologyAlpha || 0.5,
536
+ beta: config.homologyBeta || 1.0,
537
+ gamma: config.homologyGamma || 0.5
538
+ });
539
+
540
+ this.dropoutRate = config.dropout || 0.1;
541
+ this.training = false;
542
+ }
543
+
544
+ /**
545
+ * Forward pass with homology regularization
546
+ *
547
+ * @param {SparsePrimeState} x - Input state
548
+ * @param {SparsePrimeState[]} context - Context states for attention
549
+ * @returns {object} { output, attentionWeights, homologyInfo, loss }
550
+ */
551
+ forward(x, context = null) {
552
+ const keys = context || [x];
553
+ const values = context || [x];
554
+
555
+ let attnInput, ffnInput;
556
+ let homologyLossValue = 0;
557
+ let homologyInfo = null;
558
+
559
+ if (this.preNorm) {
560
+ // Pre-norm: Norm -> Attn -> Add -> Norm -> FFN -> Add
561
+ attnInput = this.norm1.forward(x);
562
+ const attnOut = this.attention.forward(
563
+ attnInput,
564
+ keys.map(k => this.norm1.forward(k)),
565
+ values.map(v => this.norm1.forward(v))
566
+ );
567
+
568
+ homologyInfo = attnOut.homologyInfo;
569
+
570
+ // Compute homology loss if there are holes
571
+ if (homologyInfo.hasHoles && homologyInfo.cycles && homologyInfo.cycles.length > 0) {
572
+ // Compute loss directly from detected cycles
573
+ // f(cycle) = |cycle|^α * β^γ * Σ errors
574
+ for (const cycle of homologyInfo.cycles) {
575
+ const cycleLength = cycle.length;
576
+ // Sum of errors for nodes in cycle
577
+ let errorSum = 0;
578
+ for (const nodeIdx of cycle) {
579
+ if (homologyInfo.errors && nodeIdx < homologyInfo.errors.length) {
580
+ errorSum += this.homologyLoss.sigmoid(homologyInfo.errors[nodeIdx] - 0.1);
581
+ }
582
+ }
583
+ homologyLossValue += errorSum *
584
+ Math.pow(cycleLength, this.homologyLoss.alpha) *
585
+ Math.pow(this.homologyLoss.beta, this.homologyLoss.gamma);
586
+ }
587
+ }
588
+
589
+ // Residual connection
590
+ const afterAttn = this._add(x, this._dropout(attnOut.result));
591
+
592
+ // FFN
593
+ ffnInput = this.norm2.forward(afterAttn);
594
+ const ffnOut = this.ffn.forward(ffnInput);
595
+
596
+ // Residual connection
597
+ const output = this._add(afterAttn, this._dropout(ffnOut));
598
+
599
+ return {
600
+ output,
601
+ attentionWeights: attnOut.attentionWeights,
602
+ homologyInfo,
603
+ loss: this.homologyWeight * homologyLossValue,
604
+ crtResidues: attnOut.crtResidues
605
+ };
606
+
607
+ } else {
608
+ // Post-norm
609
+ const attnOut = this.attention.forward(x, keys, values);
610
+ homologyInfo = attnOut.homologyInfo;
611
+
612
+ const afterAttn = this.norm1.forward(this._add(x, this._dropout(attnOut.result)));
613
+ const ffnOut = this.ffn.forward(afterAttn);
614
+ const output = this.norm2.forward(this._add(afterAttn, this._dropout(ffnOut)));
615
+
616
+ return {
617
+ output,
618
+ attentionWeights: attnOut.attentionWeights,
619
+ homologyInfo,
620
+ loss: 0,
621
+ crtResidues: attnOut.crtResidues
622
+ };
623
+ }
624
+ }
625
+
626
+ /**
627
+ * Add two sparse states (residual connection)
628
+ * @private
629
+ */
630
+ _add(a, b) {
631
+ const result = new SparsePrimeState(this.numPrimes, this.activeK);
632
+ const allPrimes = new Set([...a.getActivePrimes(), ...b.getActivePrimes()]);
633
+
634
+ for (const p of allPrimes) {
635
+ const actA = a.get(p);
636
+ const actB = b.get(p);
637
+
638
+ const newAmp = actA.amplitude.add(actB.amplitude);
639
+ const newQuat = actA.quaternion.add(actB.quaternion);
640
+
641
+ result.set(p, newAmp, newQuat.normalize());
642
+ }
643
+
644
+ return result.normalize();
645
+ }
646
+
647
+ /**
648
+ * Apply dropout
649
+ * @private
650
+ */
651
+ _dropout(state) {
652
+ if (!this.training || this.dropoutRate <= 0) return state;
653
+
654
+ const result = new SparsePrimeState(this.numPrimes, this.activeK);
655
+ const scale = 1 / (1 - this.dropoutRate);
656
+
657
+ for (const [p, act] of state.activations) {
658
+ if (Math.random() >= this.dropoutRate) {
659
+ result.set(p, act.amplitude.scale(scale), act.quaternion);
660
+ }
661
+ }
662
+
663
+ return result;
664
+ }
665
+
666
+ /**
667
+ * Set training mode
668
+ */
669
+ train(mode = true) {
670
+ this.training = mode;
671
+ this.ffn.train(mode);
672
+ return this;
673
+ }
674
+
675
+ /**
676
+ * Set evaluation mode
677
+ */
678
+ eval() {
679
+ return this.train(false);
680
+ }
681
+ }
682
+
683
/**
 * CRTResoFormer - Complete CRT-enhanced ResoFormer model
 *
 * Stacks HomologyRegularizedBlocks with:
 * - CRT-fused attention at each layer
 * - Accumulated homology loss for training
 * - Semantic hole detection across layers
 */
class CRTResoFormer {
  /**
   * @param {object} config
   * @param {number} [config.numLayers=6] - Number of transformer blocks
   * @param {number} [config.numHeads=8] - Attention heads per block
   * @param {number} [config.hiddenDim=256] - FFN hidden dimension
   * @param {number} [config.numPrimes=4096] - Prime vocabulary size
   * @param {number} [config.activeK=32] - Sparsity parameter
   * @param {number} [config.dropout=0.1] - Dropout probability
   * @param {boolean} [config.usePositionalEncoding=true] - Add position encoding
   * @param {number} [config.homologyWeight=0.1] - Weight for homology loss
   */
  constructor(config = {}) {
    this.numLayers = config.numLayers || 6;
    this.numPrimes = config.numPrimes || 4096;
    this.activeK = config.activeK || 32;

    // Optional positional encoding over the prime vocabulary
    this.usePositionalEncoding = config.usePositionalEncoding ?? true;
    if (this.usePositionalEncoding) {
      this.posEncoder = new PositionalPrimeEncoding({
        numPrimes: this.numPrimes,
        activeK: this.activeK
      });
    }

    // Shared per-block configuration, instantiated once per layer
    const blockConfig = {
      numHeads: config.numHeads || 8,
      hiddenDim: config.hiddenDim || 256,
      numPrimes: this.numPrimes,
      activeK: this.activeK,
      dropout: config.dropout || 0.1,
      preNorm: config.preNorm ?? true,
      homologyWeight: config.homologyWeight || 0.1,
      sinkhornIterations: config.sinkhornIterations || 10
    };
    this.blocks = Array.from(
      { length: this.numLayers },
      () => new HomologyRegularizedBlock(blockConfig)
    );

    // Final normalization
    this.finalNorm = new PrimeLayerNorm();
  }

  /**
   * Forward pass through all layers.
   *
   * @param {SparsePrimeState|SparsePrimeState[]} input - Input state(s)
   * @returns {object} { output, layerOutputs, attentionMaps, homologyReport, totalLoss }
   */
  forward(input) {
    const isSequence = Array.isArray(input);
    let states = isSequence ? input : [input];

    // Inject positional information before the first block
    if (this.usePositionalEncoding) {
      states = this.posEncoder.encodeSequence(states);
    }

    const layerOutputs = [];
    const attentionMaps = [];
    const homologyReports = [];
    let totalLoss = 0;

    // Run every state through each block, attending over the full sequence
    for (const block of this.blocks) {
      const nextStates = [];
      const blockAttention = [];
      const blockHomology = [];

      for (const state of states) {
        const stepOut = block.forward(state, states);
        nextStates.push(stepOut.output);
        blockAttention.push(stepOut.attentionWeights);
        blockHomology.push(stepOut.homologyInfo);
        totalLoss += stepOut.loss;
      }

      states = nextStates;
      layerOutputs.push([...states]);
      attentionMaps.push(blockAttention);
      homologyReports.push(blockHomology);
    }

    // Final normalization
    states = states.map(s => this.finalNorm.forward(s));

    // Aggregate homology findings across all layers
    const aggregateHomology = this._aggregateHomologyReports(homologyReports);

    return {
      output: isSequence ? states : states[0],
      layerOutputs,
      attentionMaps,
      homologyReport: aggregateHomology,
      totalLoss
    };
  }

  /**
   * Aggregate homology reports across layers.
   * @param {object[][]} reports - Per-layer, per-state homology reports
   * @returns {object} Summary with hole counts, max Betti number, and raw reports
   * @private
   */
  _aggregateHomologyReports(reports) {
    let totalHoles = 0;
    let maxBeta1 = 0;
    const allCycles = [];

    for (const layerReports of reports) {
      for (const report of layerReports) {
        if (!report.hasHoles) continue;
        totalHoles += 1;
        maxBeta1 = Math.max(maxBeta1, report.bettiNumbers[1]);
        allCycles.push(...(report.cycles || []));
      }
    }

    return {
      hasHoles: totalHoles > 0,
      totalHolesDetected: totalHoles,
      maxBettiNumber: maxBeta1,
      uniqueCycles: allCycles.length,
      layerReports: reports
    };
  }

  /**
   * Set training mode on every block.
   * @param {boolean} [mode=true] - Whether to enable training mode
   * @returns {CRTResoFormer} this, for chaining
   */
  train(mode = true) {
    this.blocks.forEach(block => block.train(mode));
    return this;
  }

  /**
   * Set evaluation mode.
   * @returns {CRTResoFormer} this, for chaining
   */
  eval() {
    return this.train(false);
  }

  /**
   * Get total parameter count (approximate).
   * @returns {number} Rough count of trainable parameters
   */
  getParameterCount() {
    // Per-block budget: 8 heads × 3 resonance weights, plus FFN,
    // LayerNorm, and homology parameters.
    const perBlock = 8 * 3 + 4 + 2 + 1;
    const posParams = this.usePositionalEncoding ? this.activeK * 4 : 0;
    return this.numLayers * perBlock + posParams;
  }

  /**
   * Get CRT configuration of the first block's attention layer.
   * @returns {object|null} Parameter snapshot, or null if there are no blocks
   */
  getCRTConfig() {
    if (this.blocks.length === 0) return null;
    return this.blocks[0].attention.getParameters();
  }
}
857
+
858
/**
 * createCRTResoFormer - Factory function with sensible defaults
 *
 * @param {object} config - Configuration options
 * @returns {CRTResoFormer} Configured model
 */
function createCRTResoFormer(config = {}) {
  // Pull out the recognized options; defaults mirror the model's own.
  const {
    numLayers,
    numHeads,
    hiddenDim,
    numPrimes,
    activeK,
    dropout,
    usePositionalEncoding,
    homologyWeight,
    sinkhornIterations,
    preNorm
  } = config;

  return new CRTResoFormer({
    numLayers: numLayers || 6,
    numHeads: numHeads || 8,
    hiddenDim: hiddenDim || 256,
    numPrimes: numPrimes || 4096,
    activeK: activeK || 32,
    dropout: dropout || 0.1,
    usePositionalEncoding: usePositionalEncoding ?? true,
    homologyWeight: homologyWeight || 0.1,
    sinkhornIterations: sinkhornIterations || 10,
    preNorm: preNorm ?? true
  });
}
878
+
879
+ module.exports = {
880
+ // CRT-enhanced attention
881
+ CRTResonantAttention,
882
+
883
+ // Homology-regularized block
884
+ HomologyRegularizedBlock,
885
+
886
+ // Complete model
887
+ CRTResoFormer,
888
+
889
+ // Factory function
890
+ createCRTResoFormer
891
+ };
892
+