@sparkleideas/neural 3.5.2-patch.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +260 -0
- package/__tests__/README.md +235 -0
- package/__tests__/algorithms.test.ts +582 -0
- package/__tests__/patterns.test.ts +549 -0
- package/__tests__/sona.test.ts +445 -0
- package/docs/SONA_INTEGRATION.md +460 -0
- package/docs/SONA_QUICKSTART.md +168 -0
- package/examples/sona-usage.ts +318 -0
- package/package.json +23 -0
- package/src/algorithms/a2c.d.ts +86 -0
- package/src/algorithms/a2c.d.ts.map +1 -0
- package/src/algorithms/a2c.js +361 -0
- package/src/algorithms/a2c.js.map +1 -0
- package/src/algorithms/a2c.ts +478 -0
- package/src/algorithms/curiosity.d.ts +82 -0
- package/src/algorithms/curiosity.d.ts.map +1 -0
- package/src/algorithms/curiosity.js +392 -0
- package/src/algorithms/curiosity.js.map +1 -0
- package/src/algorithms/curiosity.ts +509 -0
- package/src/algorithms/decision-transformer.d.ts +82 -0
- package/src/algorithms/decision-transformer.d.ts.map +1 -0
- package/src/algorithms/decision-transformer.js +415 -0
- package/src/algorithms/decision-transformer.js.map +1 -0
- package/src/algorithms/decision-transformer.ts +521 -0
- package/src/algorithms/dqn.d.ts +72 -0
- package/src/algorithms/dqn.d.ts.map +1 -0
- package/src/algorithms/dqn.js +303 -0
- package/src/algorithms/dqn.js.map +1 -0
- package/src/algorithms/dqn.ts +382 -0
- package/src/algorithms/index.d.ts +32 -0
- package/src/algorithms/index.d.ts.map +1 -0
- package/src/algorithms/index.js +74 -0
- package/src/algorithms/index.js.map +1 -0
- package/src/algorithms/index.ts +122 -0
- package/src/algorithms/ppo.d.ts +72 -0
- package/src/algorithms/ppo.d.ts.map +1 -0
- package/src/algorithms/ppo.js +331 -0
- package/src/algorithms/ppo.js.map +1 -0
- package/src/algorithms/ppo.ts +429 -0
- package/src/algorithms/q-learning.d.ts +77 -0
- package/src/algorithms/q-learning.d.ts.map +1 -0
- package/src/algorithms/q-learning.js +259 -0
- package/src/algorithms/q-learning.js.map +1 -0
- package/src/algorithms/q-learning.ts +333 -0
- package/src/algorithms/sarsa.d.ts +82 -0
- package/src/algorithms/sarsa.d.ts.map +1 -0
- package/src/algorithms/sarsa.js +297 -0
- package/src/algorithms/sarsa.js.map +1 -0
- package/src/algorithms/sarsa.ts +383 -0
- package/src/algorithms/tmp.json +0 -0
- package/src/application/index.ts +11 -0
- package/src/application/services/neural-application-service.ts +217 -0
- package/src/domain/entities/pattern.ts +169 -0
- package/src/domain/index.ts +18 -0
- package/src/domain/services/learning-service.ts +256 -0
- package/src/index.d.ts +118 -0
- package/src/index.d.ts.map +1 -0
- package/src/index.js +201 -0
- package/src/index.js.map +1 -0
- package/src/index.ts +363 -0
- package/src/modes/balanced.d.ts +60 -0
- package/src/modes/balanced.d.ts.map +1 -0
- package/src/modes/balanced.js +234 -0
- package/src/modes/balanced.js.map +1 -0
- package/src/modes/balanced.ts +299 -0
- package/src/modes/base.ts +163 -0
- package/src/modes/batch.d.ts +82 -0
- package/src/modes/batch.d.ts.map +1 -0
- package/src/modes/batch.js +316 -0
- package/src/modes/batch.js.map +1 -0
- package/src/modes/batch.ts +434 -0
- package/src/modes/edge.d.ts +85 -0
- package/src/modes/edge.d.ts.map +1 -0
- package/src/modes/edge.js +310 -0
- package/src/modes/edge.js.map +1 -0
- package/src/modes/edge.ts +409 -0
- package/src/modes/index.d.ts +55 -0
- package/src/modes/index.d.ts.map +1 -0
- package/src/modes/index.js +83 -0
- package/src/modes/index.js.map +1 -0
- package/src/modes/index.ts +16 -0
- package/src/modes/real-time.d.ts +58 -0
- package/src/modes/real-time.d.ts.map +1 -0
- package/src/modes/real-time.js +196 -0
- package/src/modes/real-time.js.map +1 -0
- package/src/modes/real-time.ts +257 -0
- package/src/modes/research.d.ts +79 -0
- package/src/modes/research.d.ts.map +1 -0
- package/src/modes/research.js +389 -0
- package/src/modes/research.js.map +1 -0
- package/src/modes/research.ts +486 -0
- package/src/modes/tmp.json +0 -0
- package/src/pattern-learner.d.ts +117 -0
- package/src/pattern-learner.d.ts.map +1 -0
- package/src/pattern-learner.js +603 -0
- package/src/pattern-learner.js.map +1 -0
- package/src/pattern-learner.ts +757 -0
- package/src/reasoning-bank.d.ts +259 -0
- package/src/reasoning-bank.d.ts.map +1 -0
- package/src/reasoning-bank.js +993 -0
- package/src/reasoning-bank.js.map +1 -0
- package/src/reasoning-bank.ts +1279 -0
- package/src/reasoningbank-adapter.ts +697 -0
- package/src/sona-integration.d.ts +168 -0
- package/src/sona-integration.d.ts.map +1 -0
- package/src/sona-integration.js +316 -0
- package/src/sona-integration.js.map +1 -0
- package/src/sona-integration.ts +432 -0
- package/src/sona-manager.d.ts +147 -0
- package/src/sona-manager.d.ts.map +1 -0
- package/src/sona-manager.js +695 -0
- package/src/sona-manager.js.map +1 -0
- package/src/sona-manager.ts +835 -0
- package/src/tmp.json +0 -0
- package/src/types.d.ts +431 -0
- package/src/types.d.ts.map +1 -0
- package/src/types.js +11 -0
- package/src/types.js.map +1 -0
- package/src/types.ts +590 -0
- package/tmp.json +0 -0
- package/tsconfig.json +9 -0
- package/vitest.config.ts +19 -0
|
@@ -0,0 +1,429 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Proximal Policy Optimization (PPO)
|
|
3
|
+
*
|
|
4
|
+
* Implements PPO algorithm for stable policy learning with:
|
|
5
|
+
* - Clipped surrogate objective
|
|
6
|
+
* - GAE (Generalized Advantage Estimation)
|
|
7
|
+
* - Value function clipping
|
|
8
|
+
* - Entropy bonus
|
|
9
|
+
*
|
|
10
|
+
* Performance Target: <10ms per update step
|
|
11
|
+
*/
|
|
12
|
+
|
|
13
|
+
import type {
|
|
14
|
+
PPOConfig,
|
|
15
|
+
Trajectory,
|
|
16
|
+
TrajectoryStep,
|
|
17
|
+
} from '../types.js';
|
|
18
|
+
|
|
19
|
+
/**
 * Baseline PPO hyper-parameters.
 *
 * `PPOAlgorithm` spreads caller overrides on top of this object, so every
 * field here is the value used when the caller does not supply one.
 */
export const DEFAULT_PPO_CONFIG: PPOConfig = {
  // Discriminant identifying this configuration as PPO.
  algorithm: 'ppo',

  // Step size applied to the momentum-smoothed gradients.
  learningRate: 3e-4,
  // Discount factor used for both GAE and discounted returns.
  gamma: 0.99,

  // Entropy bonus coefficient.
  entropyCoef: 1e-2,
  // Value loss coefficient.
  valueLossCoef: 0.5,
  // Gradient-norm clipping threshold.
  maxGradNorm: 0.5,

  // Number of passes over the experience buffer per update.
  epochs: 4,
  // Number of experiences per mini-batch.
  miniBatchSize: 64,

  // Half-width of the probability-ratio clipping interval [1 - c, 1 + c].
  clipRange: 0.2,
  // Value-function clip range; `null` selects the plain squared-error value loss.
  clipRangeVf: null,
  // A mini-batch whose approximate KL exceeds 1.5x this value triggers early stopping.
  targetKL: 1e-2,
  // Lambda for Generalized Advantage Estimation.
  gaeLambda: 0.95,
};
|
|
36
|
+
|
|
37
|
+
/**
 * One stored transition in the PPO experience buffer.
 *
 * Entries are produced by `PPOAlgorithm.addExperience` and consumed (then
 * discarded) by `PPOAlgorithm.update`.
 */
interface PPOExperience {
  /** State vector the step was recorded with (the step's `stateAfter`). */
  state: Float32Array;
  /** Discrete action index obtained by hashing the step's action string (0-3). */
  action: number;
  /** Reward recorded for the step. */
  reward: number;
  /** Value estimate for `state` at the time the experience was added. */
  value: number;
  /** Log-probability of `action` under the policy at the time the experience was added. */
  logProb: number;
  /** GAE advantage; normalized in place across the buffer at the start of an update. */
  advantage: number;
  /** Discounted return from this step to the end of its trajectory. */
  return_: number;
}
|
|
49
|
+
|
|
50
|
+
/**
 * PPO Algorithm Implementation
 *
 * Holds a linear policy head and a linear value head (one 768-wide weight
 * vector each), buffers experience extracted from trajectories, and runs
 * clipped-surrogate updates with a momentum optimizer.
 *
 * Typical use: call `addExperience` for each finished trajectory, then
 * `update` once enough experience has accumulated; `getAction` samples from
 * the current policy and `getStats` reports running diagnostics.
 */
export class PPOAlgorithm {
  private config: PPOConfig;

  // Policy network weights (simplified linear model for speed)
  private policyWeights: Float32Array;
  private valueWeights: Float32Array;

  // Optimizer state (exponential moving average of the gradients)
  private policyMomentum: Float32Array;
  private valueMomentum: Float32Array;

  // Experience buffer, emptied at the end of every completed update
  private buffer: PPOExperience[] = [];

  // Statistics
  private updateCount = 0;
  private totalLoss = 0;
  private approxKL = 0;
  private clipFraction = 0;

  /**
   * @param config Overrides merged on top of `DEFAULT_PPO_CONFIG`.
   */
  constructor(config: Partial<PPOConfig> = {}) {
    this.config = { ...DEFAULT_PPO_CONFIG, ...config };

    // Initialize weights (768 input dim, simplified)
    const dim = 768;
    this.policyWeights = new Float32Array(dim);
    this.valueWeights = new Float32Array(dim);
    this.policyMomentum = new Float32Array(dim);
    this.valueMomentum = new Float32Array(dim);

    // Small random initialization: uniform in [-0.5, 0.5) scaled by sqrt(2 / dim).
    const scale = Math.sqrt(2 / dim);
    for (let i = 0; i < dim; i++) {
      this.policyWeights[i] = (Math.random() - 0.5) * scale;
      this.valueWeights[i] = (Math.random() - 0.5) * scale;
    }
  }

  /**
   * Add experience from trajectory.
   *
   * Every step is converted to a buffer entry holding its `stateAfter`, the
   * hashed action index, the reward, the current value estimate and
   * log-probability, plus the GAE advantage and discounted return computed
   * over the whole trajectory. Empty trajectories are ignored.
   *
   * @param trajectory Trajectory whose steps are appended to the buffer.
   */
  addExperience(trajectory: Trajectory): void {
    if (trajectory.steps.length === 0) return;

    // Compute values for each step
    const values = trajectory.steps.map(step =>
      this.computeValue(step.stateAfter)
    );

    // Compute advantages using GAE
    const advantages = this.computeGAE(
      trajectory.steps.map(s => s.reward),
      values
    );

    // Compute returns
    const returns = this.computeReturns(
      trajectory.steps.map(s => s.reward)
    );

    // Add to buffer
    for (let i = 0; i < trajectory.steps.length; i++) {
      const step = trajectory.steps[i];
      this.buffer.push({
        state: step.stateAfter,
        action: this.hashAction(step.action),
        reward: step.reward,
        value: values[i],
        logProb: this.computeLogProb(step.stateAfter, step.action),
        advantage: advantages[i],
        return_: returns[i],
      });
    }
  }

  /**
   * Perform PPO update.
   * Target: <10ms (a warning is logged when exceeded).
   *
   * Does nothing and returns zeros while the buffer holds fewer entries than
   * one mini-batch. Otherwise normalizes advantages across the buffer, runs up
   * to `epochs` shuffled passes of mini-batch updates, clears the buffer and
   * refreshes the statistics reported by `getStats`.
   *
   * Training stops for the whole update (not just the current epoch) as soon
   * as a mini-batch reports an approximate KL above `1.5 * targetKL`.
   *
   * @returns Mean policy loss, value loss and entropy over the mini-batches
   *          that were processed (zeros if none were).
   */
  update(): { policyLoss: number; valueLoss: number; entropy: number } {
    const startTime = performance.now();

    // A batch size below 1 (or NaN) would keep the mini-batch loop from ever
    // advancing, so fall back to single-sample batches in that case.
    const requestedBatchSize = this.config.miniBatchSize;
    const batchSize = requestedBatchSize >= 1 ? Math.floor(requestedBatchSize) : 1;

    if (this.buffer.length < batchSize) {
      return { policyLoss: 0, valueLoss: 0, entropy: 0 };
    }

    // Normalize advantages (zero mean, unit variance; epsilon avoids division by zero)
    const advantages = this.buffer.map(e => e.advantage);
    const advMean = advantages.reduce((a, b) => a + b, 0) / advantages.length;
    const advStd = Math.sqrt(
      advantages.reduce((a, b) => a + (b - advMean) ** 2, 0) / advantages.length
    ) + 1e-8;

    for (const exp of this.buffer) {
      exp.advantage = (exp.advantage - advMean) / advStd;
    }

    let totalPolicyLoss = 0;
    let totalValueLoss = 0;
    let totalEntropy = 0;
    let totalClipFrac = 0;
    let totalKL = 0;
    let numUpdates = 0;
    let stoppedEarly = false;

    // Multiple epochs
    for (let epoch = 0; epoch < this.config.epochs && !stoppedEarly; epoch++) {
      // Shuffle buffer
      this.shuffleBuffer();

      // Process mini-batches
      for (let i = 0; i < this.buffer.length; i += batchSize) {
        const batch = this.buffer.slice(i, i + batchSize);
        // Skip a trailing batch that is less than half full.
        if (batch.length < batchSize / 2) continue;

        const result = this.updateMiniBatch(batch);
        totalPolicyLoss += result.policyLoss;
        totalValueLoss += result.valueLoss;
        totalEntropy += result.entropy;
        totalClipFrac += result.clipFrac;
        totalKL += result.kl;
        numUpdates++;

        // Early stopping if KL too high: abandon the remaining mini-batches
        // AND the remaining epochs, since the policy has already moved too far.
        if (result.kl > this.config.targetKL * 1.5) {
          stoppedEarly = true;
          break;
        }
      }
    }

    const policyLoss = numUpdates > 0 ? totalPolicyLoss / numUpdates : 0;
    const valueLoss = numUpdates > 0 ? totalValueLoss / numUpdates : 0;
    const entropy = numUpdates > 0 ? totalEntropy / numUpdates : 0;

    // Publish diagnostics so getStats() reflects the most recent update.
    if (numUpdates > 0) {
      this.totalLoss +=
        policyLoss +
        this.config.valueLossCoef * valueLoss -
        this.config.entropyCoef * entropy;
      this.approxKL = totalKL / numUpdates;
      this.clipFraction = totalClipFrac / numUpdates;
    }

    // Clear buffer
    this.buffer = [];
    this.updateCount++;

    const elapsed = performance.now() - startTime;
    if (elapsed > 10) {
      console.warn(`PPO update exceeded target: ${elapsed.toFixed(2)}ms > 10ms`);
    }

    return { policyLoss, valueLoss, entropy };
  }

  /**
   * Get action from policy.
   *
   * @param state State vector to evaluate.
   * @returns The sampled action index, its log-probability under the current
   *          policy, and the current value estimate for `state`.
   */
  getAction(state: Float32Array): { action: number; logProb: number; value: number } {
    const logits = this.computeLogits(state);
    const probs = this.softmax(logits);
    const action = this.sampleAction(probs);

    return {
      action,
      logProb: Math.log(probs[action] + 1e-8),
      value: this.computeValue(state),
    };
  }

  /**
   * Get statistics.
   *
   * @returns `updateCount` (completed updates), `bufferSize` (pending
   *          experiences), `avgLoss` (combined loss averaged over updates),
   *          and `approxKL` / `clipFraction` from the most recent update that
   *          processed at least one mini-batch.
   */
  getStats(): Record<string, number> {
    return {
      updateCount: this.updateCount,
      bufferSize: this.buffer.length,
      avgLoss: this.updateCount > 0 ? this.totalLoss / this.updateCount : 0,
      approxKL: this.approxKL,
      clipFraction: this.clipFraction,
    };
  }

  // ==========================================================================
  // Private Methods
  // ==========================================================================

  /** Linear value estimate: dot product of `state` with the value weights over their common prefix. */
  private computeValue(state: Float32Array): number {
    let value = 0;
    for (let i = 0; i < Math.min(state.length, this.valueWeights.length); i++) {
      value += state[i] * this.valueWeights[i];
    }
    return value;
  }

  /**
   * Logits for the four discrete actions.
   *
   * Every action shares the same weight vector; action `a` scales the dot
   * product by `1 + 0.1 * a`, so all four logits are proportional to one another.
   */
  private computeLogits(state: Float32Array): Float32Array {
    // Simplified: 4 discrete actions
    const numActions = 4;
    const logits = new Float32Array(numActions);

    for (let a = 0; a < numActions; a++) {
      for (let i = 0; i < Math.min(state.length, this.policyWeights.length); i++) {
        logits[a] += state[i] * this.policyWeights[i] * (1 + a * 0.1);
      }
    }

    return logits;
  }

  /** Log-probability (with a 1e-8 floor inside the log) of the hashed `action` in `state`. */
  private computeLogProb(state: Float32Array, action: string): number {
    const logits = this.computeLogits(state);
    const probs = this.softmax(logits);
    const actionIdx = this.hashAction(action);
    return Math.log(probs[actionIdx] + 1e-8);
  }

  /** Maps an action string to an index in 0-3 with a rolling base-31 hash; the empty string maps to 0. */
  private hashAction(action: string): number {
    // Simple hash to action index (0-3)
    let hash = 0;
    for (let i = 0; i < action.length; i++) {
      hash = (hash * 31 + action.charCodeAt(i)) % 4;
    }
    return hash;
  }

  /** Softmax that subtracts the maximum logit before exponentiating to avoid overflow. */
  private softmax(logits: Float32Array): Float32Array {
    const max = Math.max(...logits);
    const exps = new Float32Array(logits.length);
    let sum = 0;

    for (let i = 0; i < logits.length; i++) {
      exps[i] = Math.exp(logits[i] - max);
      sum += exps[i];
    }

    for (let i = 0; i < exps.length; i++) {
      exps[i] /= sum;
    }

    return exps;
  }

  /** Samples an index from `probs` by inverse CDF; falls back to the last index if rounding leaves the sum below the draw. */
  private sampleAction(probs: Float32Array): number {
    const r = Math.random();
    let cumSum = 0;
    for (let i = 0; i < probs.length; i++) {
      cumSum += probs[i];
      if (r < cumSum) return i;
    }
    return probs.length - 1;
  }

  /**
   * Generalized Advantage Estimation, computed backwards over one trajectory.
   * The value after the final step is taken to be 0.
   */
  private computeGAE(rewards: number[], values: number[]): number[] {
    const advantages = new Array(rewards.length).fill(0);
    let lastGae = 0;

    for (let t = rewards.length - 1; t >= 0; t--) {
      const nextValue = t < rewards.length - 1 ? values[t + 1] : 0;
      const delta = rewards[t] + this.config.gamma * nextValue - values[t];
      lastGae = delta + this.config.gamma * this.config.gaeLambda * lastGae;
      advantages[t] = lastGae;
    }

    return advantages;
  }

  /** Discounted return for every step: `G_t = r_t + gamma * G_{t+1}`, with 0 after the last step. */
  private computeReturns(rewards: number[]): number[] {
    const returns = new Array(rewards.length).fill(0);
    let cumReturn = 0;

    for (let t = rewards.length - 1; t >= 0; t--) {
      cumReturn = rewards[t] + this.config.gamma * cumReturn;
      returns[t] = cumReturn;
    }

    return returns;
  }

  /** In-place Fisher-Yates shuffle of the experience buffer. */
  private shuffleBuffer(): void {
    for (let i = this.buffer.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [this.buffer[i], this.buffer[j]] = [this.buffer[j], this.buffer[i]];
    }
  }

  /**
   * Runs one optimizer step on a mini-batch and returns its mean diagnostics.
   *
   * NOTE(review): the accumulated "gradients" are `state * loss * 0.01`, a
   * heuristic rather than the analytic gradient of the losses, and
   * `entropyCoef`, `valueLossCoef` and `maxGradNorm` are not referenced in
   * this method - confirm whether that simplification is intended.
   *
   * @param batch Non-empty slice of the experience buffer.
   * @returns Per-sample means of the policy loss, value loss, entropy,
   *          clipped-ratio fraction and approximate KL.
   */
  private updateMiniBatch(batch: PPOExperience[]): {
    policyLoss: number;
    valueLoss: number;
    entropy: number;
    clipFrac: number;
    kl: number;
  } {
    let policyLoss = 0;
    let valueLoss = 0;
    let entropy = 0;
    let clipFrac = 0;
    let kl = 0;

    const policyGrad = new Float32Array(this.policyWeights.length);
    const valueGrad = new Float32Array(this.valueWeights.length);

    for (const exp of batch) {
      // Current policy
      const logits = this.computeLogits(exp.state);
      const probs = this.softmax(logits);
      const newLogProb = Math.log(probs[exp.action] + 1e-8);
      const currentValue = this.computeValue(exp.state);

      // Ratio for PPO: pi_new(a|s) / pi_old(a|s)
      const ratio = Math.exp(newLogProb - exp.logProb);

      // Clipped surrogate objective
      const surr1 = ratio * exp.advantage;
      const surr2 = Math.max(
        Math.min(ratio, 1 + this.config.clipRange),
        1 - this.config.clipRange
      ) * exp.advantage;

      const policyLossI = -Math.min(surr1, surr2);
      policyLoss += policyLossI;

      // Track clipping
      if (Math.abs(ratio - 1) > this.config.clipRange) {
        clipFrac++;
      }

      // KL divergence approximation: mean of (old log-prob - new log-prob)
      kl += (exp.logProb - newLogProb);

      // Value loss
      let valueLossI: number;
      if (this.config.clipRangeVf !== null) {
        // Clipped variant: limit how far the new prediction may move from the
        // stored one and take the worse of the two squared errors.
        const valuePred = currentValue;
        const valueClipped = exp.value + Math.max(
          Math.min(valuePred - exp.value, this.config.clipRangeVf),
          -this.config.clipRangeVf
        );
        const vf1 = (valuePred - exp.return_) ** 2;
        const vf2 = (valueClipped - exp.return_) ** 2;
        valueLossI = Math.max(vf1, vf2);
      } else {
        valueLossI = (currentValue - exp.return_) ** 2;
      }
      valueLoss += valueLossI;

      // Entropy
      let entropyI = 0;
      for (const p of probs) {
        if (p > 0) entropyI -= p * Math.log(p);
      }
      entropy += entropyI;

      // Compute gradients (simplified)
      for (let i = 0; i < Math.min(exp.state.length, policyGrad.length); i++) {
        policyGrad[i] += exp.state[i] * policyLossI * 0.01;
        valueGrad[i] += exp.state[i] * valueLossI * 0.01;
      }
    }

    // Apply gradients with momentum
    const lr = this.config.learningRate;
    const beta = 0.9;

    for (let i = 0; i < this.policyWeights.length; i++) {
      this.policyMomentum[i] = beta * this.policyMomentum[i] + (1 - beta) * policyGrad[i];
      this.policyWeights[i] -= lr * this.policyMomentum[i];

      this.valueMomentum[i] = beta * this.valueMomentum[i] + (1 - beta) * valueGrad[i];
      this.valueWeights[i] -= lr * this.valueMomentum[i];
    }

    return {
      policyLoss: policyLoss / batch.length,
      valueLoss: valueLoss / batch.length,
      entropy: entropy / batch.length,
      clipFrac: clipFrac / batch.length,
      kl: kl / batch.length,
    };
  }
}
|
|
423
|
+
|
|
424
|
+
/**
 * Factory function: builds a `PPOAlgorithm`, forwarding the optional
 * configuration overrides to its constructor unchanged.
 *
 * @param config Optional partial PPO configuration.
 * @returns A newly constructed `PPOAlgorithm`.
 */
export function createPPO(config?: Partial<PPOConfig>): PPOAlgorithm {
  const algorithm = new PPOAlgorithm(config);
  return algorithm;
}
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Tabular Q-Learning
|
|
3
|
+
*
|
|
4
|
+
* Classic Q-learning algorithm with:
|
|
5
|
+
* - Epsilon-greedy exploration
|
|
6
|
+
* - State hashing for continuous states
|
|
7
|
+
* - Eligibility traces (optional)
|
|
8
|
+
* - Experience replay
|
|
9
|
+
*
|
|
10
|
+
* Suitable for smaller state spaces or discretized environments.
|
|
11
|
+
* Performance Target: <1ms per update
|
|
12
|
+
*/
|
|
13
|
+
import type { Trajectory, RLConfig } from '../types.js';
|
|
14
|
+
/**
 * Q-Learning configuration.
 *
 * Extends the shared `RLConfig` with the options specific to the tabular
 * Q-learning implementation declared in this file.
 */
export interface QLearningConfig extends RLConfig {
  /** Discriminant identifying this configuration as Q-learning. */
  algorithm: 'q-learning';

  // NOTE(review): the three exploration fields presumably describe the
  // epsilon schedule of the epsilon-greedy policy - confirm against q-learning.ts.
  explorationInitial: number;
  explorationFinal: number;
  explorationDecay: number;

  // NOTE(review): presumably the cap on stored Q-table states - confirm against q-learning.ts.
  maxStates: number;

  /** Whether eligibility traces are enabled. */
  useEligibilityTraces: boolean;
  // NOTE(review): presumably the decay applied to eligibility traces - confirm against q-learning.ts.
  traceDecay: number;
}
|
|
26
|
+
/**
|
|
27
|
+
* Default Q-Learning configuration
|
|
28
|
+
*/
|
|
29
|
+
// Ambient declaration: only the type is stated here; the initializer is not part of this declaration file.
export declare const DEFAULT_QLEARNING_CONFIG: QLearningConfig;
|
|
30
|
+
/**
 * Q-Learning Algorithm Implementation (type declaration).
 *
 * Only the public surface is typed here; private members are listed without
 * types, as emitted for declaration files.
 */
export declare class QLearning {
  private config;
  private qTable;
  private epsilon;
  private stepCount;
  private numActions;
  private traces;
  private updateCount;
  private avgTDError;

  /**
   * @param config Optional partial configuration.
   */
  constructor(config?: Partial<QLearningConfig>);

  /**
   * Update Q-values from trajectory.
   *
   * @param trajectory Trajectory to learn from.
   * @returns Object carrying the resulting `tdError`.
   */
  update(trajectory: Trajectory): { tdError: number };

  /**
   * Get action using epsilon-greedy policy.
   *
   * @param state State vector to act on.
   * @param explore Optional flag; presumably disables random exploration when false - confirm against q-learning.ts.
   * @returns The chosen action index.
   */
  getAction(state: Float32Array, explore?: boolean): number;

  /**
   * Get Q-values for a state.
   *
   * @param state State vector to look up.
   * @returns The Q-values for `state` as a `Float32Array`.
   */
  getQValues(state: Float32Array): Float32Array;

  /**
   * Get statistics.
   *
   * @returns Named numeric statistics.
   */
  getStats(): Record<string, number>;

  /**
   * Reset Q-table.
   */
  reset(): void;

  private hashState;
  private hashAction;
  private getOrCreateEntry;
  private updateTrace;
  private updateWithTraces;
  private pruneQTable;
  private argmax;
}
|
|
73
|
+
/**
|
|
74
|
+
* Factory function
|
|
75
|
+
*/
|
|
76
|
+
// Accepts an optional partial config and returns a QLearning instance; the implementation is not part of this declaration file.
export declare function createQLearning(config?: Partial<QLearningConfig>): QLearning;
|
|
77
|
+
//# sourceMappingURL=q-learning.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"q-learning.d.ts","sourceRoot":"","sources":["q-learning.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;GAWG;AAEH,OAAO,KAAK,EAAE,UAAU,EAAE,QAAQ,EAAE,MAAM,aAAa,CAAC;AAExD;;GAEG;AACH,MAAM,WAAW,eAAgB,SAAQ,QAAQ;IAC/C,SAAS,EAAE,YAAY,CAAC;IACxB,kBAAkB,EAAE,MAAM,CAAC;IAC3B,gBAAgB,EAAE,MAAM,CAAC;IACzB,gBAAgB,EAAE,MAAM,CAAC;IACzB,SAAS,EAAE,MAAM,CAAC;IAClB,oBAAoB,EAAE,OAAO,CAAC;IAC9B,UAAU,EAAE,MAAM,CAAC;CACpB;AAED;;GAEG;AACH,eAAO,MAAM,wBAAwB,EAAE,eAetC,CAAC;AAWF;;GAEG;AACH,qBAAa,SAAS;IACpB,OAAO,CAAC,MAAM,CAAkB;IAGhC,OAAO,CAAC,MAAM,CAAkC;IAGhD,OAAO,CAAC,OAAO,CAAS;IACxB,OAAO,CAAC,SAAS,CAAK;IAGtB,OAAO,CAAC,UAAU,CAAK;IAGvB,OAAO,CAAC,MAAM,CAAwC;IAGtD,OAAO,CAAC,WAAW,CAAK;IACxB,OAAO,CAAC,UAAU,CAAK;gBAEX,MAAM,GAAE,OAAO,CAAC,eAAe,CAAM;IAKjD;;OAEG;IACH,MAAM,CAAC,UAAU,EAAE,UAAU,GAAG;QAAE,OAAO,EAAE,MAAM,CAAA;KAAE;IA8EnD;;OAEG;IACH,SAAS,CAAC,KAAK,EAAE,YAAY,EAAE,OAAO,GAAE,OAAc,GAAG,MAAM;IAe/D;;OAEG;IACH,UAAU,CAAC,KAAK,EAAE,YAAY,GAAG,YAAY;IAW7C;;OAEG;IACH,QAAQ,IAAI,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC;IAUlC;;OAEG;IACH,KAAK,IAAI,IAAI;IAab,OAAO,CAAC,SAAS;IAejB,OAAO,CAAC,UAAU;IAQlB,OAAO,CAAC,gBAAgB;IAexB,OAAO,CAAC,WAAW;IAuBnB,OAAO,CAAC,gBAAgB;IAexB,OAAO,CAAC,WAAW;IAWnB,OAAO,CAAC,MAAM;CAWf;AAED;;GAEG;AACH,wBAAgB,eAAe,CAAC,MAAM,CAAC,EAAE,OAAO,CAAC,eAAe,CAAC,GAAG,SAAS,CAE5E"}
|