@sparkleideas/neural 3.5.2-patch.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +260 -0
- package/__tests__/README.md +235 -0
- package/__tests__/algorithms.test.ts +582 -0
- package/__tests__/patterns.test.ts +549 -0
- package/__tests__/sona.test.ts +445 -0
- package/docs/SONA_INTEGRATION.md +460 -0
- package/docs/SONA_QUICKSTART.md +168 -0
- package/examples/sona-usage.ts +318 -0
- package/package.json +23 -0
- package/src/algorithms/a2c.d.ts +86 -0
- package/src/algorithms/a2c.d.ts.map +1 -0
- package/src/algorithms/a2c.js +361 -0
- package/src/algorithms/a2c.js.map +1 -0
- package/src/algorithms/a2c.ts +478 -0
- package/src/algorithms/curiosity.d.ts +82 -0
- package/src/algorithms/curiosity.d.ts.map +1 -0
- package/src/algorithms/curiosity.js +392 -0
- package/src/algorithms/curiosity.js.map +1 -0
- package/src/algorithms/curiosity.ts +509 -0
- package/src/algorithms/decision-transformer.d.ts +82 -0
- package/src/algorithms/decision-transformer.d.ts.map +1 -0
- package/src/algorithms/decision-transformer.js +415 -0
- package/src/algorithms/decision-transformer.js.map +1 -0
- package/src/algorithms/decision-transformer.ts +521 -0
- package/src/algorithms/dqn.d.ts +72 -0
- package/src/algorithms/dqn.d.ts.map +1 -0
- package/src/algorithms/dqn.js +303 -0
- package/src/algorithms/dqn.js.map +1 -0
- package/src/algorithms/dqn.ts +382 -0
- package/src/algorithms/index.d.ts +32 -0
- package/src/algorithms/index.d.ts.map +1 -0
- package/src/algorithms/index.js +74 -0
- package/src/algorithms/index.js.map +1 -0
- package/src/algorithms/index.ts +122 -0
- package/src/algorithms/ppo.d.ts +72 -0
- package/src/algorithms/ppo.d.ts.map +1 -0
- package/src/algorithms/ppo.js +331 -0
- package/src/algorithms/ppo.js.map +1 -0
- package/src/algorithms/ppo.ts +429 -0
- package/src/algorithms/q-learning.d.ts +77 -0
- package/src/algorithms/q-learning.d.ts.map +1 -0
- package/src/algorithms/q-learning.js +259 -0
- package/src/algorithms/q-learning.js.map +1 -0
- package/src/algorithms/q-learning.ts +333 -0
- package/src/algorithms/sarsa.d.ts +82 -0
- package/src/algorithms/sarsa.d.ts.map +1 -0
- package/src/algorithms/sarsa.js +297 -0
- package/src/algorithms/sarsa.js.map +1 -0
- package/src/algorithms/sarsa.ts +383 -0
- package/src/algorithms/tmp.json +0 -0
- package/src/application/index.ts +11 -0
- package/src/application/services/neural-application-service.ts +217 -0
- package/src/domain/entities/pattern.ts +169 -0
- package/src/domain/index.ts +18 -0
- package/src/domain/services/learning-service.ts +256 -0
- package/src/index.d.ts +118 -0
- package/src/index.d.ts.map +1 -0
- package/src/index.js +201 -0
- package/src/index.js.map +1 -0
- package/src/index.ts +363 -0
- package/src/modes/balanced.d.ts +60 -0
- package/src/modes/balanced.d.ts.map +1 -0
- package/src/modes/balanced.js +234 -0
- package/src/modes/balanced.js.map +1 -0
- package/src/modes/balanced.ts +299 -0
- package/src/modes/base.ts +163 -0
- package/src/modes/batch.d.ts +82 -0
- package/src/modes/batch.d.ts.map +1 -0
- package/src/modes/batch.js +316 -0
- package/src/modes/batch.js.map +1 -0
- package/src/modes/batch.ts +434 -0
- package/src/modes/edge.d.ts +85 -0
- package/src/modes/edge.d.ts.map +1 -0
- package/src/modes/edge.js +310 -0
- package/src/modes/edge.js.map +1 -0
- package/src/modes/edge.ts +409 -0
- package/src/modes/index.d.ts +55 -0
- package/src/modes/index.d.ts.map +1 -0
- package/src/modes/index.js +83 -0
- package/src/modes/index.js.map +1 -0
- package/src/modes/index.ts +16 -0
- package/src/modes/real-time.d.ts +58 -0
- package/src/modes/real-time.d.ts.map +1 -0
- package/src/modes/real-time.js +196 -0
- package/src/modes/real-time.js.map +1 -0
- package/src/modes/real-time.ts +257 -0
- package/src/modes/research.d.ts +79 -0
- package/src/modes/research.d.ts.map +1 -0
- package/src/modes/research.js +389 -0
- package/src/modes/research.js.map +1 -0
- package/src/modes/research.ts +486 -0
- package/src/modes/tmp.json +0 -0
- package/src/pattern-learner.d.ts +117 -0
- package/src/pattern-learner.d.ts.map +1 -0
- package/src/pattern-learner.js +603 -0
- package/src/pattern-learner.js.map +1 -0
- package/src/pattern-learner.ts +757 -0
- package/src/reasoning-bank.d.ts +259 -0
- package/src/reasoning-bank.d.ts.map +1 -0
- package/src/reasoning-bank.js +993 -0
- package/src/reasoning-bank.js.map +1 -0
- package/src/reasoning-bank.ts +1279 -0
- package/src/reasoningbank-adapter.ts +697 -0
- package/src/sona-integration.d.ts +168 -0
- package/src/sona-integration.d.ts.map +1 -0
- package/src/sona-integration.js +316 -0
- package/src/sona-integration.js.map +1 -0
- package/src/sona-integration.ts +432 -0
- package/src/sona-manager.d.ts +147 -0
- package/src/sona-manager.d.ts.map +1 -0
- package/src/sona-manager.js +695 -0
- package/src/sona-manager.js.map +1 -0
- package/src/sona-manager.ts +835 -0
- package/src/tmp.json +0 -0
- package/src/types.d.ts +431 -0
- package/src/types.d.ts.map +1 -0
- package/src/types.js +11 -0
- package/src/types.js.map +1 -0
- package/src/types.ts +590 -0
- package/tmp.json +0 -0
- package/tsconfig.json +9 -0
- package/vitest.config.ts +19 -0
|
@@ -0,0 +1,582 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* RL Algorithms Tests
|
|
3
|
+
*
|
|
4
|
+
* Tests for reinforcement learning algorithms:
|
|
5
|
+
* - Q-Learning
|
|
6
|
+
* - SARSA
|
|
7
|
+
* - DQN
|
|
8
|
+
* - PPO
|
|
9
|
+
* - Decision Transformer
|
|
10
|
+
*
|
|
11
|
+
* Performance target: <10ms per update
|
|
12
|
+
*/
|
|
13
|
+
|
|
14
|
+
import { describe, it, expect, beforeEach } from 'vitest';
|
|
15
|
+
import { QLearning, createQLearning } from '../src/algorithms/q-learning.js';
|
|
16
|
+
import { SARSAAlgorithm, createSARSA } from '../src/algorithms/sarsa.js';
|
|
17
|
+
import { DQNAlgorithm, createDQN } from '../src/algorithms/dqn.js';
|
|
18
|
+
import { PPOAlgorithm, createPPO } from '../src/algorithms/ppo.js';
|
|
19
|
+
import { DecisionTransformer, createDecisionTransformer } from '../src/algorithms/decision-transformer.js';
|
|
20
|
+
import type { Trajectory } from '../src/types.js';
|
|
21
|
+
|
|
22
|
+
// Helper function to create test trajectories
|
|
23
|
+
function createTestTrajectory(steps: number = 5): Trajectory {
|
|
24
|
+
return {
|
|
25
|
+
trajectoryId: `test-traj-${Date.now()}`,
|
|
26
|
+
context: 'Test task',
|
|
27
|
+
domain: 'code',
|
|
28
|
+
steps: Array.from({ length: steps }, (_, i) => ({
|
|
29
|
+
stepId: `step-${i}`,
|
|
30
|
+
timestamp: Date.now() + i * 100,
|
|
31
|
+
action: `action-${i % 4}`, // 4 discrete actions
|
|
32
|
+
stateBefore: new Float32Array(768).fill(i * 0.1),
|
|
33
|
+
stateAfter: new Float32Array(768).fill((i + 1) * 0.1),
|
|
34
|
+
reward: 0.5 + (i / steps) * 0.5, // Increasing rewards
|
|
35
|
+
})),
|
|
36
|
+
qualityScore: 0.75,
|
|
37
|
+
isComplete: true,
|
|
38
|
+
startTime: Date.now() - 1000,
|
|
39
|
+
endTime: Date.now(),
|
|
40
|
+
};
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
describe('Q-Learning Algorithm', () => {
|
|
44
|
+
let qlearning: QLearning;
|
|
45
|
+
|
|
46
|
+
beforeEach(() => {
|
|
47
|
+
qlearning = createQLearning({
|
|
48
|
+
learningRate: 0.1,
|
|
49
|
+
gamma: 0.99,
|
|
50
|
+
explorationInitial: 1.0,
|
|
51
|
+
explorationFinal: 0.01,
|
|
52
|
+
explorationDecay: 1000,
|
|
53
|
+
});
|
|
54
|
+
});
|
|
55
|
+
|
|
56
|
+
it('should initialize correctly', () => {
|
|
57
|
+
expect(qlearning).toBeDefined();
|
|
58
|
+
const stats = qlearning.getStats();
|
|
59
|
+
expect(stats.updateCount).toBe(0);
|
|
60
|
+
expect(stats.qTableSize).toBe(0);
|
|
61
|
+
expect(stats.epsilon).toBeCloseTo(1.0);
|
|
62
|
+
});
|
|
63
|
+
|
|
64
|
+
it('should update Q-values from trajectory', () => {
|
|
65
|
+
const trajectory = createTestTrajectory(5);
|
|
66
|
+
const result = qlearning.update(trajectory);
|
|
67
|
+
|
|
68
|
+
expect(result.tdError).toBeGreaterThanOrEqual(0);
|
|
69
|
+
const stats = qlearning.getStats();
|
|
70
|
+
expect(stats.updateCount).toBe(1);
|
|
71
|
+
expect(stats.qTableSize).toBeGreaterThan(0);
|
|
72
|
+
});
|
|
73
|
+
|
|
74
|
+
it('should update under performance target (<1ms)', () => {
|
|
75
|
+
const trajectory = createTestTrajectory(10);
|
|
76
|
+
|
|
77
|
+
const startTime = performance.now();
|
|
78
|
+
qlearning.update(trajectory);
|
|
79
|
+
const elapsed = performance.now() - startTime;
|
|
80
|
+
|
|
81
|
+
expect(elapsed).toBeLessThan(10); // Reasonable target for small trajectories
|
|
82
|
+
});
|
|
83
|
+
|
|
84
|
+
it('should decay exploration rate', () => {
|
|
85
|
+
const trajectory = createTestTrajectory(5);
|
|
86
|
+
const initialEpsilon = qlearning.getStats().epsilon;
|
|
87
|
+
|
|
88
|
+
for (let i = 0; i < 10; i++) {
|
|
89
|
+
qlearning.update(trajectory);
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
const finalEpsilon = qlearning.getStats().epsilon;
|
|
93
|
+
expect(finalEpsilon).toBeLessThan(initialEpsilon);
|
|
94
|
+
});
|
|
95
|
+
|
|
96
|
+
it('should select actions with epsilon-greedy', () => {
|
|
97
|
+
const state = new Float32Array(768).fill(0.5);
|
|
98
|
+
|
|
99
|
+
// First call should be random (high epsilon)
|
|
100
|
+
const action1 = qlearning.getAction(state, true);
|
|
101
|
+
expect(action1).toBeGreaterThanOrEqual(0);
|
|
102
|
+
expect(action1).toBeLessThan(4);
|
|
103
|
+
|
|
104
|
+
// Without exploration, should be deterministic
|
|
105
|
+
const action2 = qlearning.getAction(state, false);
|
|
106
|
+
expect(action2).toBeDefined();
|
|
107
|
+
});
|
|
108
|
+
|
|
109
|
+
it('should return Q-values for a state', () => {
|
|
110
|
+
const trajectory = createTestTrajectory(5);
|
|
111
|
+
qlearning.update(trajectory);
|
|
112
|
+
|
|
113
|
+
const state = new Float32Array(768).fill(0.5);
|
|
114
|
+
const qValues = qlearning.getQValues(state);
|
|
115
|
+
|
|
116
|
+
expect(qValues).toBeInstanceOf(Float32Array);
|
|
117
|
+
expect(qValues.length).toBe(4);
|
|
118
|
+
});
|
|
119
|
+
|
|
120
|
+
it('should handle eligibility traces', () => {
|
|
121
|
+
const qlearningWithTraces = createQLearning({
|
|
122
|
+
useEligibilityTraces: true,
|
|
123
|
+
traceDecay: 0.9,
|
|
124
|
+
});
|
|
125
|
+
|
|
126
|
+
const trajectory = createTestTrajectory(10);
|
|
127
|
+
expect(() => qlearningWithTraces.update(trajectory)).not.toThrow();
|
|
128
|
+
});
|
|
129
|
+
|
|
130
|
+
it('should prune Q-table when over capacity', () => {
|
|
131
|
+
const smallQLearning = createQLearning({
|
|
132
|
+
maxStates: 10,
|
|
133
|
+
});
|
|
134
|
+
|
|
135
|
+
// Add many different trajectories to fill Q-table
|
|
136
|
+
for (let i = 0; i < 20; i++) {
|
|
137
|
+
const trajectory = createTestTrajectory(5);
|
|
138
|
+
smallQLearning.update(trajectory);
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
const stats = smallQLearning.getStats();
|
|
142
|
+
expect(stats.qTableSize).toBeLessThanOrEqual(10);
|
|
143
|
+
});
|
|
144
|
+
|
|
145
|
+
it('should reset correctly', () => {
|
|
146
|
+
const trajectory = createTestTrajectory(5);
|
|
147
|
+
qlearning.update(trajectory);
|
|
148
|
+
|
|
149
|
+
qlearning.reset();
|
|
150
|
+
const stats = qlearning.getStats();
|
|
151
|
+
|
|
152
|
+
expect(stats.updateCount).toBe(0);
|
|
153
|
+
expect(stats.qTableSize).toBe(0);
|
|
154
|
+
expect(stats.epsilon).toBeCloseTo(1.0);
|
|
155
|
+
});
|
|
156
|
+
});
|
|
157
|
+
|
|
158
|
+
describe('SARSA Algorithm', () => {
|
|
159
|
+
let sarsa: SARSAAlgorithm;
|
|
160
|
+
|
|
161
|
+
beforeEach(() => {
|
|
162
|
+
sarsa = createSARSA({
|
|
163
|
+
learningRate: 0.1,
|
|
164
|
+
gamma: 0.99,
|
|
165
|
+
explorationInitial: 1.0,
|
|
166
|
+
explorationFinal: 0.01,
|
|
167
|
+
explorationDecay: 1000,
|
|
168
|
+
});
|
|
169
|
+
});
|
|
170
|
+
|
|
171
|
+
it('should initialize correctly', () => {
|
|
172
|
+
expect(sarsa).toBeDefined();
|
|
173
|
+
const stats = sarsa.getStats();
|
|
174
|
+
expect(stats.updateCount).toBe(0);
|
|
175
|
+
expect(stats.qTableSize).toBe(0);
|
|
176
|
+
});
|
|
177
|
+
|
|
178
|
+
it('should update using SARSA rule', () => {
|
|
179
|
+
const trajectory = createTestTrajectory(5);
|
|
180
|
+
const result = sarsa.update(trajectory);
|
|
181
|
+
|
|
182
|
+
expect(result.tdError).toBeGreaterThanOrEqual(0);
|
|
183
|
+
const stats = sarsa.getStats();
|
|
184
|
+
expect(stats.updateCount).toBe(1);
|
|
185
|
+
});
|
|
186
|
+
|
|
187
|
+
it('should handle expected SARSA variant', () => {
|
|
188
|
+
const expectedSARSA = createSARSA({
|
|
189
|
+
useExpectedSARSA: true,
|
|
190
|
+
});
|
|
191
|
+
|
|
192
|
+
const trajectory = createTestTrajectory(5);
|
|
193
|
+
expect(() => expectedSARSA.update(trajectory)).not.toThrow();
|
|
194
|
+
});
|
|
195
|
+
|
|
196
|
+
it('should return action probabilities', () => {
|
|
197
|
+
const state = new Float32Array(768).fill(0.5);
|
|
198
|
+
const probs = sarsa.getActionProbabilities(state);
|
|
199
|
+
|
|
200
|
+
expect(probs).toBeInstanceOf(Float32Array);
|
|
201
|
+
expect(probs.length).toBe(4);
|
|
202
|
+
|
|
203
|
+
// Probabilities should sum to ~1
|
|
204
|
+
const sum = Array.from(probs).reduce((a, b) => a + b, 0);
|
|
205
|
+
expect(sum).toBeCloseTo(1.0, 2);
|
|
206
|
+
});
|
|
207
|
+
|
|
208
|
+
it('should select actions with epsilon-greedy policy', () => {
|
|
209
|
+
const state = new Float32Array(768).fill(0.5);
|
|
210
|
+
const action = sarsa.getAction(state, true);
|
|
211
|
+
|
|
212
|
+
expect(action).toBeGreaterThanOrEqual(0);
|
|
213
|
+
expect(action).toBeLessThan(4);
|
|
214
|
+
});
|
|
215
|
+
|
|
216
|
+
it('should handle eligibility traces (SARSA-lambda)', () => {
|
|
217
|
+
const sarsaLambda = createSARSA({
|
|
218
|
+
useEligibilityTraces: true,
|
|
219
|
+
traceDecay: 0.9,
|
|
220
|
+
});
|
|
221
|
+
|
|
222
|
+
const trajectory = createTestTrajectory(10);
|
|
223
|
+
expect(() => sarsaLambda.update(trajectory)).not.toThrow();
|
|
224
|
+
});
|
|
225
|
+
|
|
226
|
+
it('should handle short trajectories gracefully', () => {
|
|
227
|
+
const shortTrajectory = createTestTrajectory(1);
|
|
228
|
+
const result = sarsa.update(shortTrajectory);
|
|
229
|
+
|
|
230
|
+
expect(result.tdError).toBe(0); // Not enough steps for SARSA
|
|
231
|
+
});
|
|
232
|
+
|
|
233
|
+
it('should reset algorithm state', () => {
|
|
234
|
+
const trajectory = createTestTrajectory(5);
|
|
235
|
+
sarsa.update(trajectory);
|
|
236
|
+
|
|
237
|
+
sarsa.reset();
|
|
238
|
+
const stats = sarsa.getStats();
|
|
239
|
+
|
|
240
|
+
expect(stats.updateCount).toBe(0);
|
|
241
|
+
expect(stats.qTableSize).toBe(0);
|
|
242
|
+
});
|
|
243
|
+
});
|
|
244
|
+
|
|
245
|
+
describe('DQN Algorithm', () => {
|
|
246
|
+
let dqn: DQNAlgorithm;
|
|
247
|
+
|
|
248
|
+
beforeEach(() => {
|
|
249
|
+
dqn = createDQN({
|
|
250
|
+
learningRate: 0.0001,
|
|
251
|
+
bufferSize: 1000,
|
|
252
|
+
miniBatchSize: 32,
|
|
253
|
+
doubleDQN: true,
|
|
254
|
+
targetUpdateFreq: 100,
|
|
255
|
+
});
|
|
256
|
+
});
|
|
257
|
+
|
|
258
|
+
it('should initialize correctly', () => {
|
|
259
|
+
expect(dqn).toBeDefined();
|
|
260
|
+
const stats = dqn.getStats();
|
|
261
|
+
expect(stats.updateCount).toBe(0);
|
|
262
|
+
expect(stats.bufferSize).toBe(0);
|
|
263
|
+
});
|
|
264
|
+
|
|
265
|
+
it('should add experience to replay buffer', () => {
|
|
266
|
+
const trajectory = createTestTrajectory(10);
|
|
267
|
+
dqn.addExperience(trajectory);
|
|
268
|
+
|
|
269
|
+
const stats = dqn.getStats();
|
|
270
|
+
expect(stats.bufferSize).toBe(10);
|
|
271
|
+
});
|
|
272
|
+
|
|
273
|
+
it('should perform DQN update', () => {
|
|
274
|
+
// Add enough experiences
|
|
275
|
+
for (let i = 0; i < 5; i++) {
|
|
276
|
+
dqn.addExperience(createTestTrajectory(10));
|
|
277
|
+
}
|
|
278
|
+
|
|
279
|
+
const result = dqn.update();
|
|
280
|
+
expect(result.loss).toBeGreaterThanOrEqual(0);
|
|
281
|
+
expect(result.epsilon).toBeGreaterThan(0);
|
|
282
|
+
});
|
|
283
|
+
|
|
284
|
+
it('should update under performance target (<10ms)', () => {
|
|
285
|
+
// Add experiences
|
|
286
|
+
for (let i = 0; i < 5; i++) {
|
|
287
|
+
dqn.addExperience(createTestTrajectory(10));
|
|
288
|
+
}
|
|
289
|
+
|
|
290
|
+
const startTime = performance.now();
|
|
291
|
+
dqn.update();
|
|
292
|
+
const elapsed = performance.now() - startTime;
|
|
293
|
+
|
|
294
|
+
// Allow generous overhead for neural network in test environment
|
|
295
|
+
// (actual production target is <10ms, but tests run in CI may be slower)
|
|
296
|
+
expect(elapsed).toBeLessThan(500);
|
|
297
|
+
});
|
|
298
|
+
|
|
299
|
+
it('should use double DQN when enabled', () => {
|
|
300
|
+
const doubleDQN = createDQN({
|
|
301
|
+
doubleDQN: true,
|
|
302
|
+
miniBatchSize: 16,
|
|
303
|
+
});
|
|
304
|
+
|
|
305
|
+
for (let i = 0; i < 3; i++) {
|
|
306
|
+
doubleDQN.addExperience(createTestTrajectory(10));
|
|
307
|
+
}
|
|
308
|
+
|
|
309
|
+
expect(() => doubleDQN.update()).not.toThrow();
|
|
310
|
+
});
|
|
311
|
+
|
|
312
|
+
it('should select actions with epsilon-greedy', () => {
|
|
313
|
+
const state = new Float32Array(768).fill(0.5);
|
|
314
|
+
const action = dqn.getAction(state, true);
|
|
315
|
+
|
|
316
|
+
expect(action).toBeGreaterThanOrEqual(0);
|
|
317
|
+
expect(action).toBeLessThan(4);
|
|
318
|
+
});
|
|
319
|
+
|
|
320
|
+
it('should return Q-values for a state', () => {
|
|
321
|
+
const state = new Float32Array(768).fill(0.5);
|
|
322
|
+
const qValues = dqn.getQValues(state);
|
|
323
|
+
|
|
324
|
+
expect(qValues).toBeInstanceOf(Float32Array);
|
|
325
|
+
expect(qValues.length).toBe(4);
|
|
326
|
+
});
|
|
327
|
+
|
|
328
|
+
it('should update target network periodically', () => {
|
|
329
|
+
const dqnWithFreqUpdate = createDQN({
|
|
330
|
+
targetUpdateFreq: 5,
|
|
331
|
+
miniBatchSize: 16,
|
|
332
|
+
});
|
|
333
|
+
|
|
334
|
+
for (let i = 0; i < 3; i++) {
|
|
335
|
+
dqnWithFreqUpdate.addExperience(createTestTrajectory(10));
|
|
336
|
+
}
|
|
337
|
+
|
|
338
|
+
// Perform multiple updates to trigger target network update
|
|
339
|
+
for (let i = 0; i < 10; i++) {
|
|
340
|
+
dqnWithFreqUpdate.update();
|
|
341
|
+
}
|
|
342
|
+
|
|
343
|
+
const stats = dqnWithFreqUpdate.getStats();
|
|
344
|
+
expect(stats.stepCount).toBeGreaterThan(5);
|
|
345
|
+
});
|
|
346
|
+
|
|
347
|
+
it('should handle circular replay buffer correctly', () => {
|
|
348
|
+
const smallDQN = createDQN({
|
|
349
|
+
bufferSize: 10,
|
|
350
|
+
miniBatchSize: 4,
|
|
351
|
+
});
|
|
352
|
+
|
|
353
|
+
// Add more experiences than buffer size
|
|
354
|
+
for (let i = 0; i < 15; i++) {
|
|
355
|
+
smallDQN.addExperience(createTestTrajectory(2));
|
|
356
|
+
}
|
|
357
|
+
|
|
358
|
+
const stats = smallDQN.getStats();
|
|
359
|
+
expect(stats.bufferSize).toBe(10);
|
|
360
|
+
});
|
|
361
|
+
});
|
|
362
|
+
|
|
363
|
+
describe('PPO Algorithm', () => {
|
|
364
|
+
let ppo: PPOAlgorithm;
|
|
365
|
+
|
|
366
|
+
beforeEach(() => {
|
|
367
|
+
ppo = createPPO({
|
|
368
|
+
learningRate: 0.0003,
|
|
369
|
+
clipRange: 0.2,
|
|
370
|
+
gaeLambda: 0.95,
|
|
371
|
+
epochs: 4,
|
|
372
|
+
miniBatchSize: 64,
|
|
373
|
+
});
|
|
374
|
+
});
|
|
375
|
+
|
|
376
|
+
it('should initialize correctly', () => {
|
|
377
|
+
expect(ppo).toBeDefined();
|
|
378
|
+
const stats = ppo.getStats();
|
|
379
|
+
expect(stats.updateCount).toBe(0);
|
|
380
|
+
});
|
|
381
|
+
|
|
382
|
+
it('should add experience from trajectory', () => {
|
|
383
|
+
const trajectory = createTestTrajectory(10);
|
|
384
|
+
expect(() => ppo.addExperience(trajectory)).not.toThrow();
|
|
385
|
+
|
|
386
|
+
const stats = ppo.getStats();
|
|
387
|
+
expect(stats.bufferSize).toBe(10);
|
|
388
|
+
});
|
|
389
|
+
|
|
390
|
+
it('should perform PPO update with clipping', () => {
|
|
391
|
+
// Add enough experiences
|
|
392
|
+
for (let i = 0; i < 10; i++) {
|
|
393
|
+
ppo.addExperience(createTestTrajectory(10));
|
|
394
|
+
}
|
|
395
|
+
|
|
396
|
+
const result = ppo.update();
|
|
397
|
+
|
|
398
|
+
// Policy loss can be negative in PPO (we minimize -surrogate_objective)
|
|
399
|
+
expect(typeof result.policyLoss).toBe('number');
|
|
400
|
+
expect(result.valueLoss).toBeGreaterThanOrEqual(0);
|
|
401
|
+
expect(result.entropy).toBeGreaterThanOrEqual(0);
|
|
402
|
+
});
|
|
403
|
+
|
|
404
|
+
it('should update under performance target (<10ms for small batches)', () => {
|
|
405
|
+
const smallPPO = createPPO({
|
|
406
|
+
miniBatchSize: 16,
|
|
407
|
+
epochs: 1,
|
|
408
|
+
});
|
|
409
|
+
|
|
410
|
+
for (let i = 0; i < 3; i++) {
|
|
411
|
+
smallPPO.addExperience(createTestTrajectory(10));
|
|
412
|
+
}
|
|
413
|
+
|
|
414
|
+
const startTime = performance.now();
|
|
415
|
+
smallPPO.update();
|
|
416
|
+
const elapsed = performance.now() - startTime;
|
|
417
|
+
|
|
418
|
+
expect(elapsed).toBeLessThan(100); // Allow overhead for PPO complexity
|
|
419
|
+
});
|
|
420
|
+
|
|
421
|
+
it('should compute GAE advantages', () => {
|
|
422
|
+
const trajectory = createTestTrajectory(20);
|
|
423
|
+
expect(() => ppo.addExperience(trajectory)).not.toThrow();
|
|
424
|
+
|
|
425
|
+
// Verify experiences were added with advantages
|
|
426
|
+
const stats = ppo.getStats();
|
|
427
|
+
expect(stats.bufferSize).toBe(20);
|
|
428
|
+
});
|
|
429
|
+
|
|
430
|
+
it('should sample actions from policy', () => {
|
|
431
|
+
const state = new Float32Array(768).fill(0.5);
|
|
432
|
+
const result = ppo.getAction(state);
|
|
433
|
+
|
|
434
|
+
expect(result.action).toBeGreaterThanOrEqual(0);
|
|
435
|
+
expect(result.action).toBeLessThan(4);
|
|
436
|
+
expect(result.logProb).toBeDefined();
|
|
437
|
+
expect(result.value).toBeDefined();
|
|
438
|
+
});
|
|
439
|
+
|
|
440
|
+
it('should handle multiple training epochs', () => {
|
|
441
|
+
const multiEpochPPO = createPPO({
|
|
442
|
+
epochs: 8,
|
|
443
|
+
miniBatchSize: 32,
|
|
444
|
+
});
|
|
445
|
+
|
|
446
|
+
for (let i = 0; i < 5; i++) {
|
|
447
|
+
multiEpochPPO.addExperience(createTestTrajectory(10));
|
|
448
|
+
}
|
|
449
|
+
|
|
450
|
+
expect(() => multiEpochPPO.update()).not.toThrow();
|
|
451
|
+
});
|
|
452
|
+
|
|
453
|
+
it('should clear buffer after update', () => {
|
|
454
|
+
for (let i = 0; i < 10; i++) {
|
|
455
|
+
ppo.addExperience(createTestTrajectory(10));
|
|
456
|
+
}
|
|
457
|
+
|
|
458
|
+
ppo.update();
|
|
459
|
+
const stats = ppo.getStats();
|
|
460
|
+
|
|
461
|
+
expect(stats.bufferSize).toBe(0);
|
|
462
|
+
});
|
|
463
|
+
});
|
|
464
|
+
|
|
465
|
+
describe('Decision Transformer', () => {
|
|
466
|
+
let dt: DecisionTransformer;
|
|
467
|
+
|
|
468
|
+
beforeEach(() => {
|
|
469
|
+
dt = createDecisionTransformer({
|
|
470
|
+
contextLength: 20,
|
|
471
|
+
numHeads: 4,
|
|
472
|
+
numLayers: 2,
|
|
473
|
+
hiddenDim: 64,
|
|
474
|
+
embeddingDim: 32,
|
|
475
|
+
});
|
|
476
|
+
});
|
|
477
|
+
|
|
478
|
+
it('should initialize correctly', () => {
|
|
479
|
+
expect(dt).toBeDefined();
|
|
480
|
+
const stats = dt.getStats();
|
|
481
|
+
expect(stats.updateCount).toBe(0);
|
|
482
|
+
expect(stats.bufferSize).toBe(0);
|
|
483
|
+
expect(stats.contextLength).toBe(20);
|
|
484
|
+
expect(stats.numLayers).toBe(2);
|
|
485
|
+
});
|
|
486
|
+
|
|
487
|
+
it('should add complete trajectories to buffer', () => {
|
|
488
|
+
const trajectory = createTestTrajectory(10);
|
|
489
|
+
dt.addTrajectory(trajectory);
|
|
490
|
+
|
|
491
|
+
const stats = dt.getStats();
|
|
492
|
+
expect(stats.bufferSize).toBe(1);
|
|
493
|
+
});
|
|
494
|
+
|
|
495
|
+
it('should not add incomplete trajectories', () => {
|
|
496
|
+
const incompleteTrajectory: Trajectory = {
|
|
497
|
+
...createTestTrajectory(5),
|
|
498
|
+
isComplete: false,
|
|
499
|
+
};
|
|
500
|
+
|
|
501
|
+
dt.addTrajectory(incompleteTrajectory);
|
|
502
|
+
const stats = dt.getStats();
|
|
503
|
+
|
|
504
|
+
expect(stats.bufferSize).toBe(0);
|
|
505
|
+
});
|
|
506
|
+
|
|
507
|
+
it('should train on buffered trajectories', () => {
|
|
508
|
+
// Add multiple trajectories
|
|
509
|
+
for (let i = 0; i < 5; i++) {
|
|
510
|
+
dt.addTrajectory(createTestTrajectory(10));
|
|
511
|
+
}
|
|
512
|
+
|
|
513
|
+
const result = dt.train();
|
|
514
|
+
|
|
515
|
+
expect(result.loss).toBeGreaterThanOrEqual(0);
|
|
516
|
+
expect(result.accuracy).toBeGreaterThanOrEqual(0);
|
|
517
|
+
expect(result.accuracy).toBeLessThanOrEqual(1);
|
|
518
|
+
});
|
|
519
|
+
|
|
520
|
+
it('should train under performance target (<10ms per batch)', () => {
|
|
521
|
+
for (let i = 0; i < 3; i++) {
|
|
522
|
+
dt.addTrajectory(createTestTrajectory(5));
|
|
523
|
+
}
|
|
524
|
+
|
|
525
|
+
const startTime = performance.now();
|
|
526
|
+
dt.train();
|
|
527
|
+
const elapsed = performance.now() - startTime;
|
|
528
|
+
|
|
529
|
+
expect(elapsed).toBeLessThan(100); // Allow overhead for transformer
|
|
530
|
+
});
|
|
531
|
+
|
|
532
|
+
it('should get action conditioned on target return', () => {
|
|
533
|
+
const states = [
|
|
534
|
+
new Float32Array(768).fill(0.1),
|
|
535
|
+
new Float32Array(768).fill(0.2),
|
|
536
|
+
new Float32Array(768).fill(0.3),
|
|
537
|
+
];
|
|
538
|
+
const actions = [0, 1, 2];
|
|
539
|
+
const targetReturn = 0.9;
|
|
540
|
+
|
|
541
|
+
const action = dt.getAction(states, actions, targetReturn);
|
|
542
|
+
|
|
543
|
+
expect(action).toBeGreaterThanOrEqual(0);
|
|
544
|
+
expect(action).toBeLessThan(4);
|
|
545
|
+
});
|
|
546
|
+
|
|
547
|
+
it('should handle causal attention masking', () => {
|
|
548
|
+
// Train with sequence data
|
|
549
|
+
for (let i = 0; i < 5; i++) {
|
|
550
|
+
dt.addTrajectory(createTestTrajectory(15));
|
|
551
|
+
}
|
|
552
|
+
|
|
553
|
+
expect(() => dt.train()).not.toThrow();
|
|
554
|
+
});
|
|
555
|
+
|
|
556
|
+
it('should maintain bounded trajectory buffer', () => {
|
|
557
|
+
// Add more than max capacity (1000)
|
|
558
|
+
for (let i = 0; i < 1100; i++) {
|
|
559
|
+
dt.addTrajectory(createTestTrajectory(5));
|
|
560
|
+
}
|
|
561
|
+
|
|
562
|
+
const stats = dt.getStats();
|
|
563
|
+
expect(stats.bufferSize).toBe(1000);
|
|
564
|
+
});
|
|
565
|
+
|
|
566
|
+
it('should handle varying trajectory lengths', () => {
|
|
567
|
+
dt.addTrajectory(createTestTrajectory(3));
|
|
568
|
+
dt.addTrajectory(createTestTrajectory(10));
|
|
569
|
+
dt.addTrajectory(createTestTrajectory(25));
|
|
570
|
+
|
|
571
|
+
expect(() => dt.train()).not.toThrow();
|
|
572
|
+
});
|
|
573
|
+
|
|
574
|
+
it('should compute returns-to-go correctly', () => {
|
|
575
|
+
const trajectory = createTestTrajectory(5);
|
|
576
|
+
dt.addTrajectory(trajectory);
|
|
577
|
+
|
|
578
|
+
expect(() => dt.train()).not.toThrow();
|
|
579
|
+
const stats = dt.getStats();
|
|
580
|
+
expect(stats.avgLoss).toBeGreaterThanOrEqual(0);
|
|
581
|
+
});
|
|
582
|
+
});
|