@sparkleideas/neural 3.5.2-patch.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. package/README.md +260 -0
  2. package/__tests__/README.md +235 -0
  3. package/__tests__/algorithms.test.ts +582 -0
  4. package/__tests__/patterns.test.ts +549 -0
  5. package/__tests__/sona.test.ts +445 -0
  6. package/docs/SONA_INTEGRATION.md +460 -0
  7. package/docs/SONA_QUICKSTART.md +168 -0
  8. package/examples/sona-usage.ts +318 -0
  9. package/package.json +23 -0
  10. package/src/algorithms/a2c.d.ts +86 -0
  11. package/src/algorithms/a2c.d.ts.map +1 -0
  12. package/src/algorithms/a2c.js +361 -0
  13. package/src/algorithms/a2c.js.map +1 -0
  14. package/src/algorithms/a2c.ts +478 -0
  15. package/src/algorithms/curiosity.d.ts +82 -0
  16. package/src/algorithms/curiosity.d.ts.map +1 -0
  17. package/src/algorithms/curiosity.js +392 -0
  18. package/src/algorithms/curiosity.js.map +1 -0
  19. package/src/algorithms/curiosity.ts +509 -0
  20. package/src/algorithms/decision-transformer.d.ts +82 -0
  21. package/src/algorithms/decision-transformer.d.ts.map +1 -0
  22. package/src/algorithms/decision-transformer.js +415 -0
  23. package/src/algorithms/decision-transformer.js.map +1 -0
  24. package/src/algorithms/decision-transformer.ts +521 -0
  25. package/src/algorithms/dqn.d.ts +72 -0
  26. package/src/algorithms/dqn.d.ts.map +1 -0
  27. package/src/algorithms/dqn.js +303 -0
  28. package/src/algorithms/dqn.js.map +1 -0
  29. package/src/algorithms/dqn.ts +382 -0
  30. package/src/algorithms/index.d.ts +32 -0
  31. package/src/algorithms/index.d.ts.map +1 -0
  32. package/src/algorithms/index.js +74 -0
  33. package/src/algorithms/index.js.map +1 -0
  34. package/src/algorithms/index.ts +122 -0
  35. package/src/algorithms/ppo.d.ts +72 -0
  36. package/src/algorithms/ppo.d.ts.map +1 -0
  37. package/src/algorithms/ppo.js +331 -0
  38. package/src/algorithms/ppo.js.map +1 -0
  39. package/src/algorithms/ppo.ts +429 -0
  40. package/src/algorithms/q-learning.d.ts +77 -0
  41. package/src/algorithms/q-learning.d.ts.map +1 -0
  42. package/src/algorithms/q-learning.js +259 -0
  43. package/src/algorithms/q-learning.js.map +1 -0
  44. package/src/algorithms/q-learning.ts +333 -0
  45. package/src/algorithms/sarsa.d.ts +82 -0
  46. package/src/algorithms/sarsa.d.ts.map +1 -0
  47. package/src/algorithms/sarsa.js +297 -0
  48. package/src/algorithms/sarsa.js.map +1 -0
  49. package/src/algorithms/sarsa.ts +383 -0
  50. package/src/algorithms/tmp.json +0 -0
  51. package/src/application/index.ts +11 -0
  52. package/src/application/services/neural-application-service.ts +217 -0
  53. package/src/domain/entities/pattern.ts +169 -0
  54. package/src/domain/index.ts +18 -0
  55. package/src/domain/services/learning-service.ts +256 -0
  56. package/src/index.d.ts +118 -0
  57. package/src/index.d.ts.map +1 -0
  58. package/src/index.js +201 -0
  59. package/src/index.js.map +1 -0
  60. package/src/index.ts +363 -0
  61. package/src/modes/balanced.d.ts +60 -0
  62. package/src/modes/balanced.d.ts.map +1 -0
  63. package/src/modes/balanced.js +234 -0
  64. package/src/modes/balanced.js.map +1 -0
  65. package/src/modes/balanced.ts +299 -0
  66. package/src/modes/base.ts +163 -0
  67. package/src/modes/batch.d.ts +82 -0
  68. package/src/modes/batch.d.ts.map +1 -0
  69. package/src/modes/batch.js +316 -0
  70. package/src/modes/batch.js.map +1 -0
  71. package/src/modes/batch.ts +434 -0
  72. package/src/modes/edge.d.ts +85 -0
  73. package/src/modes/edge.d.ts.map +1 -0
  74. package/src/modes/edge.js +310 -0
  75. package/src/modes/edge.js.map +1 -0
  76. package/src/modes/edge.ts +409 -0
  77. package/src/modes/index.d.ts +55 -0
  78. package/src/modes/index.d.ts.map +1 -0
  79. package/src/modes/index.js +83 -0
  80. package/src/modes/index.js.map +1 -0
  81. package/src/modes/index.ts +16 -0
  82. package/src/modes/real-time.d.ts +58 -0
  83. package/src/modes/real-time.d.ts.map +1 -0
  84. package/src/modes/real-time.js +196 -0
  85. package/src/modes/real-time.js.map +1 -0
  86. package/src/modes/real-time.ts +257 -0
  87. package/src/modes/research.d.ts +79 -0
  88. package/src/modes/research.d.ts.map +1 -0
  89. package/src/modes/research.js +389 -0
  90. package/src/modes/research.js.map +1 -0
  91. package/src/modes/research.ts +486 -0
  92. package/src/modes/tmp.json +0 -0
  93. package/src/pattern-learner.d.ts +117 -0
  94. package/src/pattern-learner.d.ts.map +1 -0
  95. package/src/pattern-learner.js +603 -0
  96. package/src/pattern-learner.js.map +1 -0
  97. package/src/pattern-learner.ts +757 -0
  98. package/src/reasoning-bank.d.ts +259 -0
  99. package/src/reasoning-bank.d.ts.map +1 -0
  100. package/src/reasoning-bank.js +993 -0
  101. package/src/reasoning-bank.js.map +1 -0
  102. package/src/reasoning-bank.ts +1279 -0
  103. package/src/reasoningbank-adapter.ts +697 -0
  104. package/src/sona-integration.d.ts +168 -0
  105. package/src/sona-integration.d.ts.map +1 -0
  106. package/src/sona-integration.js +316 -0
  107. package/src/sona-integration.js.map +1 -0
  108. package/src/sona-integration.ts +432 -0
  109. package/src/sona-manager.d.ts +147 -0
  110. package/src/sona-manager.d.ts.map +1 -0
  111. package/src/sona-manager.js +695 -0
  112. package/src/sona-manager.js.map +1 -0
  113. package/src/sona-manager.ts +835 -0
  114. package/src/tmp.json +0 -0
  115. package/src/types.d.ts +431 -0
  116. package/src/types.d.ts.map +1 -0
  117. package/src/types.js +11 -0
  118. package/src/types.js.map +1 -0
  119. package/src/types.ts +590 -0
  120. package/tmp.json +0 -0
  121. package/tsconfig.json +9 -0
  122. package/vitest.config.ts +19 -0
@@ -0,0 +1,582 @@
+ /**
+  * RL Algorithms Tests
+  *
+  * Tests for reinforcement learning algorithms:
+  * - Q-Learning
+  * - SARSA
+  * - DQN
+  * - PPO
+  * - Decision Transformer
+  *
+  * Performance target: <10ms per update
+  */
import { describe, it, expect, beforeEach } from 'vitest';

import { QLearning, createQLearning } from '../src/algorithms/q-learning.js';
import { SARSAAlgorithm, createSARSA } from '../src/algorithms/sarsa.js';
import { DQNAlgorithm, createDQN } from '../src/algorithms/dqn.js';
import { PPOAlgorithm, createPPO } from '../src/algorithms/ppo.js';
import { DecisionTransformer, createDecisionTransformer } from '../src/algorithms/decision-transformer.js';

import type { Trajectory } from '../src/types.js';
22
+ // Helper function to create test trajectories
23
+ function createTestTrajectory(steps: number = 5): Trajectory {
24
+ return {
25
+ trajectoryId: `test-traj-${Date.now()}`,
26
+ context: 'Test task',
27
+ domain: 'code',
28
+ steps: Array.from({ length: steps }, (_, i) => ({
29
+ stepId: `step-${i}`,
30
+ timestamp: Date.now() + i * 100,
31
+ action: `action-${i % 4}`, // 4 discrete actions
32
+ stateBefore: new Float32Array(768).fill(i * 0.1),
33
+ stateAfter: new Float32Array(768).fill((i + 1) * 0.1),
34
+ reward: 0.5 + (i / steps) * 0.5, // Increasing rewards
35
+ })),
36
+ qualityScore: 0.75,
37
+ isComplete: true,
38
+ startTime: Date.now() - 1000,
39
+ endTime: Date.now(),
40
+ };
41
+ }
42
+
43
+ describe('Q-Learning Algorithm', () => {
44
+ let qlearning: QLearning;
45
+
46
+ beforeEach(() => {
47
+ qlearning = createQLearning({
48
+ learningRate: 0.1,
49
+ gamma: 0.99,
50
+ explorationInitial: 1.0,
51
+ explorationFinal: 0.01,
52
+ explorationDecay: 1000,
53
+ });
54
+ });
55
+
56
+ it('should initialize correctly', () => {
57
+ expect(qlearning).toBeDefined();
58
+ const stats = qlearning.getStats();
59
+ expect(stats.updateCount).toBe(0);
60
+ expect(stats.qTableSize).toBe(0);
61
+ expect(stats.epsilon).toBeCloseTo(1.0);
62
+ });
63
+
64
+ it('should update Q-values from trajectory', () => {
65
+ const trajectory = createTestTrajectory(5);
66
+ const result = qlearning.update(trajectory);
67
+
68
+ expect(result.tdError).toBeGreaterThanOrEqual(0);
69
+ const stats = qlearning.getStats();
70
+ expect(stats.updateCount).toBe(1);
71
+ expect(stats.qTableSize).toBeGreaterThan(0);
72
+ });
73
+
74
+ it('should update under performance target (<1ms)', () => {
75
+ const trajectory = createTestTrajectory(10);
76
+
77
+ const startTime = performance.now();
78
+ qlearning.update(trajectory);
79
+ const elapsed = performance.now() - startTime;
80
+
81
+ expect(elapsed).toBeLessThan(10); // Reasonable target for small trajectories
82
+ });
83
+
84
+ it('should decay exploration rate', () => {
85
+ const trajectory = createTestTrajectory(5);
86
+ const initialEpsilon = qlearning.getStats().epsilon;
87
+
88
+ for (let i = 0; i < 10; i++) {
89
+ qlearning.update(trajectory);
90
+ }
91
+
92
+ const finalEpsilon = qlearning.getStats().epsilon;
93
+ expect(finalEpsilon).toBeLessThan(initialEpsilon);
94
+ });
95
+
96
+ it('should select actions with epsilon-greedy', () => {
97
+ const state = new Float32Array(768).fill(0.5);
98
+
99
+ // First call should be random (high epsilon)
100
+ const action1 = qlearning.getAction(state, true);
101
+ expect(action1).toBeGreaterThanOrEqual(0);
102
+ expect(action1).toBeLessThan(4);
103
+
104
+ // Without exploration, should be deterministic
105
+ const action2 = qlearning.getAction(state, false);
106
+ expect(action2).toBeDefined();
107
+ });
108
+
109
+ it('should return Q-values for a state', () => {
110
+ const trajectory = createTestTrajectory(5);
111
+ qlearning.update(trajectory);
112
+
113
+ const state = new Float32Array(768).fill(0.5);
114
+ const qValues = qlearning.getQValues(state);
115
+
116
+ expect(qValues).toBeInstanceOf(Float32Array);
117
+ expect(qValues.length).toBe(4);
118
+ });
119
+
120
+ it('should handle eligibility traces', () => {
121
+ const qlearningWithTraces = createQLearning({
122
+ useEligibilityTraces: true,
123
+ traceDecay: 0.9,
124
+ });
125
+
126
+ const trajectory = createTestTrajectory(10);
127
+ expect(() => qlearningWithTraces.update(trajectory)).not.toThrow();
128
+ });
129
+
130
+ it('should prune Q-table when over capacity', () => {
131
+ const smallQLearning = createQLearning({
132
+ maxStates: 10,
133
+ });
134
+
135
+ // Add many different trajectories to fill Q-table
136
+ for (let i = 0; i < 20; i++) {
137
+ const trajectory = createTestTrajectory(5);
138
+ smallQLearning.update(trajectory);
139
+ }
140
+
141
+ const stats = smallQLearning.getStats();
142
+ expect(stats.qTableSize).toBeLessThanOrEqual(10);
143
+ });
144
+
145
+ it('should reset correctly', () => {
146
+ const trajectory = createTestTrajectory(5);
147
+ qlearning.update(trajectory);
148
+
149
+ qlearning.reset();
150
+ const stats = qlearning.getStats();
151
+
152
+ expect(stats.updateCount).toBe(0);
153
+ expect(stats.qTableSize).toBe(0);
154
+ expect(stats.epsilon).toBeCloseTo(1.0);
155
+ });
156
+ });
157
+
158
+ describe('SARSA Algorithm', () => {
159
+ let sarsa: SARSAAlgorithm;
160
+
161
+ beforeEach(() => {
162
+ sarsa = createSARSA({
163
+ learningRate: 0.1,
164
+ gamma: 0.99,
165
+ explorationInitial: 1.0,
166
+ explorationFinal: 0.01,
167
+ explorationDecay: 1000,
168
+ });
169
+ });
170
+
171
+ it('should initialize correctly', () => {
172
+ expect(sarsa).toBeDefined();
173
+ const stats = sarsa.getStats();
174
+ expect(stats.updateCount).toBe(0);
175
+ expect(stats.qTableSize).toBe(0);
176
+ });
177
+
178
+ it('should update using SARSA rule', () => {
179
+ const trajectory = createTestTrajectory(5);
180
+ const result = sarsa.update(trajectory);
181
+
182
+ expect(result.tdError).toBeGreaterThanOrEqual(0);
183
+ const stats = sarsa.getStats();
184
+ expect(stats.updateCount).toBe(1);
185
+ });
186
+
187
+ it('should handle expected SARSA variant', () => {
188
+ const expectedSARSA = createSARSA({
189
+ useExpectedSARSA: true,
190
+ });
191
+
192
+ const trajectory = createTestTrajectory(5);
193
+ expect(() => expectedSARSA.update(trajectory)).not.toThrow();
194
+ });
195
+
196
+ it('should return action probabilities', () => {
197
+ const state = new Float32Array(768).fill(0.5);
198
+ const probs = sarsa.getActionProbabilities(state);
199
+
200
+ expect(probs).toBeInstanceOf(Float32Array);
201
+ expect(probs.length).toBe(4);
202
+
203
+ // Probabilities should sum to ~1
204
+ const sum = Array.from(probs).reduce((a, b) => a + b, 0);
205
+ expect(sum).toBeCloseTo(1.0, 2);
206
+ });
207
+
208
+ it('should select actions with epsilon-greedy policy', () => {
209
+ const state = new Float32Array(768).fill(0.5);
210
+ const action = sarsa.getAction(state, true);
211
+
212
+ expect(action).toBeGreaterThanOrEqual(0);
213
+ expect(action).toBeLessThan(4);
214
+ });
215
+
216
+ it('should handle eligibility traces (SARSA-lambda)', () => {
217
+ const sarsaLambda = createSARSA({
218
+ useEligibilityTraces: true,
219
+ traceDecay: 0.9,
220
+ });
221
+
222
+ const trajectory = createTestTrajectory(10);
223
+ expect(() => sarsaLambda.update(trajectory)).not.toThrow();
224
+ });
225
+
226
+ it('should handle short trajectories gracefully', () => {
227
+ const shortTrajectory = createTestTrajectory(1);
228
+ const result = sarsa.update(shortTrajectory);
229
+
230
+ expect(result.tdError).toBe(0); // Not enough steps for SARSA
231
+ });
232
+
233
+ it('should reset algorithm state', () => {
234
+ const trajectory = createTestTrajectory(5);
235
+ sarsa.update(trajectory);
236
+
237
+ sarsa.reset();
238
+ const stats = sarsa.getStats();
239
+
240
+ expect(stats.updateCount).toBe(0);
241
+ expect(stats.qTableSize).toBe(0);
242
+ });
243
+ });
244
+
245
+ describe('DQN Algorithm', () => {
246
+ let dqn: DQNAlgorithm;
247
+
248
+ beforeEach(() => {
249
+ dqn = createDQN({
250
+ learningRate: 0.0001,
251
+ bufferSize: 1000,
252
+ miniBatchSize: 32,
253
+ doubleDQN: true,
254
+ targetUpdateFreq: 100,
255
+ });
256
+ });
257
+
258
+ it('should initialize correctly', () => {
259
+ expect(dqn).toBeDefined();
260
+ const stats = dqn.getStats();
261
+ expect(stats.updateCount).toBe(0);
262
+ expect(stats.bufferSize).toBe(0);
263
+ });
264
+
265
+ it('should add experience to replay buffer', () => {
266
+ const trajectory = createTestTrajectory(10);
267
+ dqn.addExperience(trajectory);
268
+
269
+ const stats = dqn.getStats();
270
+ expect(stats.bufferSize).toBe(10);
271
+ });
272
+
273
+ it('should perform DQN update', () => {
274
+ // Add enough experiences
275
+ for (let i = 0; i < 5; i++) {
276
+ dqn.addExperience(createTestTrajectory(10));
277
+ }
278
+
279
+ const result = dqn.update();
280
+ expect(result.loss).toBeGreaterThanOrEqual(0);
281
+ expect(result.epsilon).toBeGreaterThan(0);
282
+ });
283
+
284
+ it('should update under performance target (<10ms)', () => {
285
+ // Add experiences
286
+ for (let i = 0; i < 5; i++) {
287
+ dqn.addExperience(createTestTrajectory(10));
288
+ }
289
+
290
+ const startTime = performance.now();
291
+ dqn.update();
292
+ const elapsed = performance.now() - startTime;
293
+
294
+ // Allow generous overhead for neural network in test environment
295
+ // (actual production target is <10ms, but tests run in CI may be slower)
296
+ expect(elapsed).toBeLessThan(500);
297
+ });
298
+
299
+ it('should use double DQN when enabled', () => {
300
+ const doubleDQN = createDQN({
301
+ doubleDQN: true,
302
+ miniBatchSize: 16,
303
+ });
304
+
305
+ for (let i = 0; i < 3; i++) {
306
+ doubleDQN.addExperience(createTestTrajectory(10));
307
+ }
308
+
309
+ expect(() => doubleDQN.update()).not.toThrow();
310
+ });
311
+
312
+ it('should select actions with epsilon-greedy', () => {
313
+ const state = new Float32Array(768).fill(0.5);
314
+ const action = dqn.getAction(state, true);
315
+
316
+ expect(action).toBeGreaterThanOrEqual(0);
317
+ expect(action).toBeLessThan(4);
318
+ });
319
+
320
+ it('should return Q-values for a state', () => {
321
+ const state = new Float32Array(768).fill(0.5);
322
+ const qValues = dqn.getQValues(state);
323
+
324
+ expect(qValues).toBeInstanceOf(Float32Array);
325
+ expect(qValues.length).toBe(4);
326
+ });
327
+
328
+ it('should update target network periodically', () => {
329
+ const dqnWithFreqUpdate = createDQN({
330
+ targetUpdateFreq: 5,
331
+ miniBatchSize: 16,
332
+ });
333
+
334
+ for (let i = 0; i < 3; i++) {
335
+ dqnWithFreqUpdate.addExperience(createTestTrajectory(10));
336
+ }
337
+
338
+ // Perform multiple updates to trigger target network update
339
+ for (let i = 0; i < 10; i++) {
340
+ dqnWithFreqUpdate.update();
341
+ }
342
+
343
+ const stats = dqnWithFreqUpdate.getStats();
344
+ expect(stats.stepCount).toBeGreaterThan(5);
345
+ });
346
+
347
+ it('should handle circular replay buffer correctly', () => {
348
+ const smallDQN = createDQN({
349
+ bufferSize: 10,
350
+ miniBatchSize: 4,
351
+ });
352
+
353
+ // Add more experiences than buffer size
354
+ for (let i = 0; i < 15; i++) {
355
+ smallDQN.addExperience(createTestTrajectory(2));
356
+ }
357
+
358
+ const stats = smallDQN.getStats();
359
+ expect(stats.bufferSize).toBe(10);
360
+ });
361
+ });
362
+
363
+ describe('PPO Algorithm', () => {
364
+ let ppo: PPOAlgorithm;
365
+
366
+ beforeEach(() => {
367
+ ppo = createPPO({
368
+ learningRate: 0.0003,
369
+ clipRange: 0.2,
370
+ gaeLambda: 0.95,
371
+ epochs: 4,
372
+ miniBatchSize: 64,
373
+ });
374
+ });
375
+
376
+ it('should initialize correctly', () => {
377
+ expect(ppo).toBeDefined();
378
+ const stats = ppo.getStats();
379
+ expect(stats.updateCount).toBe(0);
380
+ });
381
+
382
+ it('should add experience from trajectory', () => {
383
+ const trajectory = createTestTrajectory(10);
384
+ expect(() => ppo.addExperience(trajectory)).not.toThrow();
385
+
386
+ const stats = ppo.getStats();
387
+ expect(stats.bufferSize).toBe(10);
388
+ });
389
+
390
+ it('should perform PPO update with clipping', () => {
391
+ // Add enough experiences
392
+ for (let i = 0; i < 10; i++) {
393
+ ppo.addExperience(createTestTrajectory(10));
394
+ }
395
+
396
+ const result = ppo.update();
397
+
398
+ // Policy loss can be negative in PPO (we minimize -surrogate_objective)
399
+ expect(typeof result.policyLoss).toBe('number');
400
+ expect(result.valueLoss).toBeGreaterThanOrEqual(0);
401
+ expect(result.entropy).toBeGreaterThanOrEqual(0);
402
+ });
403
+
404
+ it('should update under performance target (<10ms for small batches)', () => {
405
+ const smallPPO = createPPO({
406
+ miniBatchSize: 16,
407
+ epochs: 1,
408
+ });
409
+
410
+ for (let i = 0; i < 3; i++) {
411
+ smallPPO.addExperience(createTestTrajectory(10));
412
+ }
413
+
414
+ const startTime = performance.now();
415
+ smallPPO.update();
416
+ const elapsed = performance.now() - startTime;
417
+
418
+ expect(elapsed).toBeLessThan(100); // Allow overhead for PPO complexity
419
+ });
420
+
421
+ it('should compute GAE advantages', () => {
422
+ const trajectory = createTestTrajectory(20);
423
+ expect(() => ppo.addExperience(trajectory)).not.toThrow();
424
+
425
+ // Verify experiences were added with advantages
426
+ const stats = ppo.getStats();
427
+ expect(stats.bufferSize).toBe(20);
428
+ });
429
+
430
+ it('should sample actions from policy', () => {
431
+ const state = new Float32Array(768).fill(0.5);
432
+ const result = ppo.getAction(state);
433
+
434
+ expect(result.action).toBeGreaterThanOrEqual(0);
435
+ expect(result.action).toBeLessThan(4);
436
+ expect(result.logProb).toBeDefined();
437
+ expect(result.value).toBeDefined();
438
+ });
439
+
440
+ it('should handle multiple training epochs', () => {
441
+ const multiEpochPPO = createPPO({
442
+ epochs: 8,
443
+ miniBatchSize: 32,
444
+ });
445
+
446
+ for (let i = 0; i < 5; i++) {
447
+ multiEpochPPO.addExperience(createTestTrajectory(10));
448
+ }
449
+
450
+ expect(() => multiEpochPPO.update()).not.toThrow();
451
+ });
452
+
453
+ it('should clear buffer after update', () => {
454
+ for (let i = 0; i < 10; i++) {
455
+ ppo.addExperience(createTestTrajectory(10));
456
+ }
457
+
458
+ ppo.update();
459
+ const stats = ppo.getStats();
460
+
461
+ expect(stats.bufferSize).toBe(0);
462
+ });
463
+ });
464
+
465
+ describe('Decision Transformer', () => {
466
+ let dt: DecisionTransformer;
467
+
468
+ beforeEach(() => {
469
+ dt = createDecisionTransformer({
470
+ contextLength: 20,
471
+ numHeads: 4,
472
+ numLayers: 2,
473
+ hiddenDim: 64,
474
+ embeddingDim: 32,
475
+ });
476
+ });
477
+
478
+ it('should initialize correctly', () => {
479
+ expect(dt).toBeDefined();
480
+ const stats = dt.getStats();
481
+ expect(stats.updateCount).toBe(0);
482
+ expect(stats.bufferSize).toBe(0);
483
+ expect(stats.contextLength).toBe(20);
484
+ expect(stats.numLayers).toBe(2);
485
+ });
486
+
487
+ it('should add complete trajectories to buffer', () => {
488
+ const trajectory = createTestTrajectory(10);
489
+ dt.addTrajectory(trajectory);
490
+
491
+ const stats = dt.getStats();
492
+ expect(stats.bufferSize).toBe(1);
493
+ });
494
+
495
+ it('should not add incomplete trajectories', () => {
496
+ const incompleteTrajectory: Trajectory = {
497
+ ...createTestTrajectory(5),
498
+ isComplete: false,
499
+ };
500
+
501
+ dt.addTrajectory(incompleteTrajectory);
502
+ const stats = dt.getStats();
503
+
504
+ expect(stats.bufferSize).toBe(0);
505
+ });
506
+
507
+ it('should train on buffered trajectories', () => {
508
+ // Add multiple trajectories
509
+ for (let i = 0; i < 5; i++) {
510
+ dt.addTrajectory(createTestTrajectory(10));
511
+ }
512
+
513
+ const result = dt.train();
514
+
515
+ expect(result.loss).toBeGreaterThanOrEqual(0);
516
+ expect(result.accuracy).toBeGreaterThanOrEqual(0);
517
+ expect(result.accuracy).toBeLessThanOrEqual(1);
518
+ });
519
+
520
+ it('should train under performance target (<10ms per batch)', () => {
521
+ for (let i = 0; i < 3; i++) {
522
+ dt.addTrajectory(createTestTrajectory(5));
523
+ }
524
+
525
+ const startTime = performance.now();
526
+ dt.train();
527
+ const elapsed = performance.now() - startTime;
528
+
529
+ expect(elapsed).toBeLessThan(100); // Allow overhead for transformer
530
+ });
531
+
532
+ it('should get action conditioned on target return', () => {
533
+ const states = [
534
+ new Float32Array(768).fill(0.1),
535
+ new Float32Array(768).fill(0.2),
536
+ new Float32Array(768).fill(0.3),
537
+ ];
538
+ const actions = [0, 1, 2];
539
+ const targetReturn = 0.9;
540
+
541
+ const action = dt.getAction(states, actions, targetReturn);
542
+
543
+ expect(action).toBeGreaterThanOrEqual(0);
544
+ expect(action).toBeLessThan(4);
545
+ });
546
+
547
+ it('should handle causal attention masking', () => {
548
+ // Train with sequence data
549
+ for (let i = 0; i < 5; i++) {
550
+ dt.addTrajectory(createTestTrajectory(15));
551
+ }
552
+
553
+ expect(() => dt.train()).not.toThrow();
554
+ });
555
+
556
+ it('should maintain bounded trajectory buffer', () => {
557
+ // Add more than max capacity (1000)
558
+ for (let i = 0; i < 1100; i++) {
559
+ dt.addTrajectory(createTestTrajectory(5));
560
+ }
561
+
562
+ const stats = dt.getStats();
563
+ expect(stats.bufferSize).toBe(1000);
564
+ });
565
+
566
+ it('should handle varying trajectory lengths', () => {
567
+ dt.addTrajectory(createTestTrajectory(3));
568
+ dt.addTrajectory(createTestTrajectory(10));
569
+ dt.addTrajectory(createTestTrajectory(25));
570
+
571
+ expect(() => dt.train()).not.toThrow();
572
+ });
573
+
574
+ it('should compute returns-to-go correctly', () => {
575
+ const trajectory = createTestTrajectory(5);
576
+ dt.addTrajectory(trajectory);
577
+
578
+ expect(() => dt.train()).not.toThrow();
579
+ const stats = dt.getStats();
580
+ expect(stats.avgLoss).toBeGreaterThanOrEqual(0);
581
+ });
582
+ });