@sparkleideas/neural 3.5.2-patch.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +260 -0
- package/__tests__/README.md +235 -0
- package/__tests__/algorithms.test.ts +582 -0
- package/__tests__/patterns.test.ts +549 -0
- package/__tests__/sona.test.ts +445 -0
- package/docs/SONA_INTEGRATION.md +460 -0
- package/docs/SONA_QUICKSTART.md +168 -0
- package/examples/sona-usage.ts +318 -0
- package/package.json +23 -0
- package/src/algorithms/a2c.d.ts +86 -0
- package/src/algorithms/a2c.d.ts.map +1 -0
- package/src/algorithms/a2c.js +361 -0
- package/src/algorithms/a2c.js.map +1 -0
- package/src/algorithms/a2c.ts +478 -0
- package/src/algorithms/curiosity.d.ts +82 -0
- package/src/algorithms/curiosity.d.ts.map +1 -0
- package/src/algorithms/curiosity.js +392 -0
- package/src/algorithms/curiosity.js.map +1 -0
- package/src/algorithms/curiosity.ts +509 -0
- package/src/algorithms/decision-transformer.d.ts +82 -0
- package/src/algorithms/decision-transformer.d.ts.map +1 -0
- package/src/algorithms/decision-transformer.js +415 -0
- package/src/algorithms/decision-transformer.js.map +1 -0
- package/src/algorithms/decision-transformer.ts +521 -0
- package/src/algorithms/dqn.d.ts +72 -0
- package/src/algorithms/dqn.d.ts.map +1 -0
- package/src/algorithms/dqn.js +303 -0
- package/src/algorithms/dqn.js.map +1 -0
- package/src/algorithms/dqn.ts +382 -0
- package/src/algorithms/index.d.ts +32 -0
- package/src/algorithms/index.d.ts.map +1 -0
- package/src/algorithms/index.js +74 -0
- package/src/algorithms/index.js.map +1 -0
- package/src/algorithms/index.ts +122 -0
- package/src/algorithms/ppo.d.ts +72 -0
- package/src/algorithms/ppo.d.ts.map +1 -0
- package/src/algorithms/ppo.js +331 -0
- package/src/algorithms/ppo.js.map +1 -0
- package/src/algorithms/ppo.ts +429 -0
- package/src/algorithms/q-learning.d.ts +77 -0
- package/src/algorithms/q-learning.d.ts.map +1 -0
- package/src/algorithms/q-learning.js +259 -0
- package/src/algorithms/q-learning.js.map +1 -0
- package/src/algorithms/q-learning.ts +333 -0
- package/src/algorithms/sarsa.d.ts +82 -0
- package/src/algorithms/sarsa.d.ts.map +1 -0
- package/src/algorithms/sarsa.js +297 -0
- package/src/algorithms/sarsa.js.map +1 -0
- package/src/algorithms/sarsa.ts +383 -0
- package/src/algorithms/tmp.json +0 -0
- package/src/application/index.ts +11 -0
- package/src/application/services/neural-application-service.ts +217 -0
- package/src/domain/entities/pattern.ts +169 -0
- package/src/domain/index.ts +18 -0
- package/src/domain/services/learning-service.ts +256 -0
- package/src/index.d.ts +118 -0
- package/src/index.d.ts.map +1 -0
- package/src/index.js +201 -0
- package/src/index.js.map +1 -0
- package/src/index.ts +363 -0
- package/src/modes/balanced.d.ts +60 -0
- package/src/modes/balanced.d.ts.map +1 -0
- package/src/modes/balanced.js +234 -0
- package/src/modes/balanced.js.map +1 -0
- package/src/modes/balanced.ts +299 -0
- package/src/modes/base.ts +163 -0
- package/src/modes/batch.d.ts +82 -0
- package/src/modes/batch.d.ts.map +1 -0
- package/src/modes/batch.js +316 -0
- package/src/modes/batch.js.map +1 -0
- package/src/modes/batch.ts +434 -0
- package/src/modes/edge.d.ts +85 -0
- package/src/modes/edge.d.ts.map +1 -0
- package/src/modes/edge.js +310 -0
- package/src/modes/edge.js.map +1 -0
- package/src/modes/edge.ts +409 -0
- package/src/modes/index.d.ts +55 -0
- package/src/modes/index.d.ts.map +1 -0
- package/src/modes/index.js +83 -0
- package/src/modes/index.js.map +1 -0
- package/src/modes/index.ts +16 -0
- package/src/modes/real-time.d.ts +58 -0
- package/src/modes/real-time.d.ts.map +1 -0
- package/src/modes/real-time.js +196 -0
- package/src/modes/real-time.js.map +1 -0
- package/src/modes/real-time.ts +257 -0
- package/src/modes/research.d.ts +79 -0
- package/src/modes/research.d.ts.map +1 -0
- package/src/modes/research.js +389 -0
- package/src/modes/research.js.map +1 -0
- package/src/modes/research.ts +486 -0
- package/src/modes/tmp.json +0 -0
- package/src/pattern-learner.d.ts +117 -0
- package/src/pattern-learner.d.ts.map +1 -0
- package/src/pattern-learner.js +603 -0
- package/src/pattern-learner.js.map +1 -0
- package/src/pattern-learner.ts +757 -0
- package/src/reasoning-bank.d.ts +259 -0
- package/src/reasoning-bank.d.ts.map +1 -0
- package/src/reasoning-bank.js +993 -0
- package/src/reasoning-bank.js.map +1 -0
- package/src/reasoning-bank.ts +1279 -0
- package/src/reasoningbank-adapter.ts +697 -0
- package/src/sona-integration.d.ts +168 -0
- package/src/sona-integration.d.ts.map +1 -0
- package/src/sona-integration.js +316 -0
- package/src/sona-integration.js.map +1 -0
- package/src/sona-integration.ts +432 -0
- package/src/sona-manager.d.ts +147 -0
- package/src/sona-manager.d.ts.map +1 -0
- package/src/sona-manager.js +695 -0
- package/src/sona-manager.js.map +1 -0
- package/src/sona-manager.ts +835 -0
- package/src/tmp.json +0 -0
- package/src/types.d.ts +431 -0
- package/src/types.d.ts.map +1 -0
- package/src/types.js +11 -0
- package/src/types.js.map +1 -0
- package/src/types.ts +590 -0
- package/tmp.json +0 -0
- package/tsconfig.json +9 -0
- package/vitest.config.ts +19 -0
|
@@ -0,0 +1,297 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* SARSA (State-Action-Reward-State-Action)
|
|
3
|
+
*
|
|
4
|
+
* On-policy TD learning algorithm with:
|
|
5
|
+
* - Epsilon-greedy exploration
|
|
6
|
+
* - State hashing for continuous states
|
|
7
|
+
* - Expected SARSA variant (optional)
|
|
8
|
+
* - Eligibility traces (SARSA-lambda)
|
|
9
|
+
*
|
|
10
|
+
* Performance Target: <1ms per update
|
|
11
|
+
*/
|
|
12
|
+
/**
 * Default SARSA configuration.
 *
 * Epsilon starts at `explorationInitial` and falls linearly by 1.0 every
 * `explorationDecay` steps, never going below `explorationFinal`. Expected
 * SARSA and eligibility traces are both disabled unless overridden.
 */
export const DEFAULT_SARSA_CONFIG = {
    algorithm: 'sarsa',
    // Step size applied to each TD error
    learningRate: 0.1,
    // Discount factor for bootstrapped targets
    gamma: 0.99,
    // Shared RL fields; not referenced by SARSAAlgorithm itself
    entropyCoef: 0,
    valueLossCoef: 1,
    maxGradNorm: 1,
    epochs: 1,
    miniBatchSize: 1,
    // Linear epsilon-greedy schedule
    explorationInitial: 1,
    explorationFinal: 0.01,
    explorationDecay: 10000,
    // Q-table size above which least-recently-updated states are evicted
    maxStates: 10000,
    // Algorithm variants
    useExpectedSARSA: false,
    useEligibilityTraces: false,
    // Lambda: traces are scaled by gamma * traceDecay per step
    traceDecay: 0.9,
};
|
|
32
|
+
/**
 * SARSA Algorithm Implementation.
 *
 * Tabular, on-policy TD control:
 * - state vectors are bucketed into string keys by `hashState`;
 * - action strings are folded onto `numActions` slots by `hashAction`;
 * - once the Q-table grows past `config.maxStates`, the least-recently-updated
 *   states are evicted (see `pruneQTable`).
 */
export class SARSAAlgorithm {
    config;
    // Q-table: bucketed state key -> { qValues, visits, lastUpdate }
    qTable = new Map();
    // Exploration
    epsilon;
    stepCount = 0;
    // Number of action slots that action strings are hashed onto
    numActions = 4;
    // Eligibility traces: bucketed state key -> per-action trace values
    traces = new Map();
    // Statistics
    updateCount = 0;
    avgTDError = 0;
    /**
     * @param config - overrides shallow-merged over DEFAULT_SARSA_CONFIG.
     */
    constructor(config = {}) {
        this.config = { ...DEFAULT_SARSA_CONFIG, ...config };
        this.epsilon = this.config.explorationInitial;
    }
    /**
     * Update Q-values from a trajectory using the SARSA rule.
     *
     * Each non-final step (s, a, r, s') is paired with the next step's action a'
     * and moved towards `r + gamma * Q(s', a')`. When `useExpectedSARSA` is set,
     * the target is instead `r + gamma * E[Q(s', .)]` under the current
     * epsilon-greedy policy. The final step is treated as terminal and moved
     * towards its raw reward.
     *
     * With `useEligibilityTraces`, every TD error is spread over all traced
     * state-action pairs, including the terminal one. This lets the terminal
     * reward reach earlier states of the same trajectory.
     *
     * The call also advances the epsilon schedule, prunes the Q-table when it
     * exceeds `maxStates`, and logs a warning if it took longer than 1ms.
     *
     * @param trajectory - object whose `steps` carry `stateBefore`, `action`,
     *   `stateAfter` and `reward`.
     * @returns mean absolute TD error over the trajectory's steps. Returns
     *   `{ tdError: 0 }` without changing any state when fewer than two steps are given.
     */
    update(trajectory) {
        const startTime = performance.now();
        const steps = trajectory.steps;
        if (steps.length < 2) {
            return { tdError: 0 };
        }
        let totalTDError = 0;
        // Traces belong to a single trajectory; start each one from zero.
        if (this.config.useEligibilityTraces) {
            this.traces.clear();
        }
        for (let i = 0; i < steps.length - 1; i++) {
            const step = steps[i];
            const nextStep = steps[i + 1];
            const stateKey = this.hashState(step.stateBefore);
            const action = this.hashAction(step.action);
            const nextStateKey = this.hashState(step.stateAfter);
            const nextAction = this.hashAction(nextStep.action);
            const qEntry = this.getOrCreateEntry(stateKey);
            const nextEntry = this.getOrCreateEntry(nextStateKey);
            const currentQ = qEntry.qValues[action];
            let targetQ;
            if (this.config.useExpectedSARSA) {
                // Expected SARSA: expected value under the current policy
                targetQ = step.reward + this.config.gamma * this.expectedValue(nextEntry.qValues);
            }
            else {
                // Standard SARSA: value of the action actually taken next
                targetQ = step.reward + this.config.gamma * nextEntry.qValues[nextAction];
            }
            const tdError = targetQ - currentQ;
            totalTDError += Math.abs(tdError);
            this.applyTDError(qEntry, stateKey, action, tdError);
        }
        // Terminal step: no successor, so the target is the reward alone.
        const lastStep = steps[steps.length - 1];
        const lastStateKey = this.hashState(lastStep.stateBefore);
        const lastAction = this.hashAction(lastStep.action);
        const lastEntry = this.getOrCreateEntry(lastStateKey);
        const terminalTDError = lastStep.reward - lastEntry.qValues[lastAction];
        totalTDError += Math.abs(terminalTDError);
        // Use the same path as the loop:
        // - with traces, the terminal error propagates to earlier states;
        // - without traces, visits/lastUpdate are refreshed, so pruning does not
        //   evict this just-updated entry as stale.
        this.applyTDError(lastEntry, lastStateKey, lastAction, terminalTDError);
        // Linear epsilon decay: loses 1.0 per `explorationDecay` steps, floored.
        this.stepCount += steps.length;
        this.epsilon = Math.max(this.config.explorationFinal, this.config.explorationInitial - this.stepCount / this.config.explorationDecay);
        if (this.qTable.size > this.config.maxStates) {
            this.pruneQTable();
        }
        this.updateCount++;
        this.avgTDError = totalTDError / steps.length;
        const elapsed = performance.now() - startTime;
        if (elapsed > 1) {
            console.warn(`SARSA update exceeded target: ${elapsed.toFixed(2)}ms > 1ms`);
        }
        return { tdError: this.avgTDError };
    }
    /**
     * Choose an action with an epsilon-greedy policy.
     *
     * @param state - state vector (bucketed via hashState).
     * @param explore - when false, the random-exploration branch is skipped.
     * @returns an action index in [0, numActions). The index is uniformly random
     *   when exploring or when the state has no Q-table entry; otherwise it is
     *   the greedy action.
     */
    getAction(state, explore = true) {
        if (explore && Math.random() < this.epsilon) {
            return Math.floor(Math.random() * this.numActions);
        }
        const stateKey = this.hashState(state);
        const entry = this.qTable.get(stateKey);
        if (!entry) {
            return Math.floor(Math.random() * this.numActions);
        }
        return this.argmax(entry.qValues);
    }
    /**
     * Epsilon-greedy action probabilities for a state.
     *
     * @returns a new Float32Array of length numActions. The distribution is
     *   uniform for unseen states; otherwise each action gets
     *   epsilon / numActions and the greedy action also gets (1 - epsilon).
     */
    getActionProbabilities(state) {
        const probs = new Float32Array(this.numActions);
        const stateKey = this.hashState(state);
        const entry = this.qTable.get(stateKey);
        if (!entry) {
            // Uniform distribution
            const uniform = 1 / this.numActions;
            for (let a = 0; a < this.numActions; a++) {
                probs[a] = uniform;
            }
            return probs;
        }
        const greedyAction = this.argmax(entry.qValues);
        const exploreProb = this.epsilon / this.numActions;
        for (let a = 0; a < this.numActions; a++) {
            probs[a] = exploreProb;
        }
        probs[greedyAction] += 1 - this.epsilon;
        return probs;
    }
    /**
     * @returns a copy of the state's Q-values, or zeros for an unseen state.
     */
    getQValues(state) {
        const stateKey = this.hashState(state);
        const entry = this.qTable.get(stateKey);
        if (!entry) {
            return new Float32Array(this.numActions);
        }
        return new Float32Array(entry.qValues);
    }
    /**
     * @returns a snapshot of the learner's counters and current epsilon.
     */
    getStats() {
        return {
            updateCount: this.updateCount,
            qTableSize: this.qTable.size,
            epsilon: this.epsilon,
            avgTDError: this.avgTDError,
            stepCount: this.stepCount,
        };
    }
    /**
     * Reset algorithm state: empties the Q-table and traces, restores the
     * initial epsilon, and zeroes all counters. The configuration is kept.
     */
    reset() {
        this.qTable.clear();
        this.traces.clear();
        this.epsilon = this.config.explorationInitial;
        this.stepCount = 0;
        this.updateCount = 0;
        this.avgTDError = 0;
    }
    // ==========================================================================
    // Private Methods
    // ==========================================================================
    /**
     * Apply one TD error for (stateKey, action).
     *
     * - With eligibility traces, the pair's trace is set and the error is spread
     *   over every traced pair.
     * - Otherwise only `entry` is updated.
     *
     * In both cases the touched entries get their visits/lastUpdate refreshed.
     */
    applyTDError(entry, stateKey, action, tdError) {
        if (this.config.useEligibilityTraces) {
            this.updateTrace(stateKey, action);
            this.updateWithTraces(tdError);
        }
        else {
            entry.qValues[action] += this.config.learningRate * tdError;
            entry.visits++;
            entry.lastUpdate = Date.now();
        }
    }
    /**
     * Bucket the first 8 components of a state into 10 bins each.
     *
     * The (x + 1) / 2 mapping implies components are expected in [-1, 1].
     * Values outside that range clamp to the edge bins.
     */
    hashState(state) {
        const bins = 10;
        const parts = [];
        for (let i = 0; i < Math.min(8, state.length); i++) {
            const normalized = (state[i] + 1) / 2;
            const bin = Math.floor(Math.max(0, Math.min(bins - 1, normalized * bins)));
            parts.push(bin);
        }
        return parts.join(',');
    }
    /**
     * Fold an action string onto [0, numActions) with a base-31 rolling hash.
     * Distinct action strings can collide on the same slot.
     */
    hashAction(action) {
        let hash = 0;
        for (let i = 0; i < action.length; i++) {
            hash = (hash * 31 + action.charCodeAt(i)) % this.numActions;
        }
        return hash;
    }
    /** Fetch the Q-table entry for a state key, inserting a zeroed one if absent. */
    getOrCreateEntry(stateKey) {
        let entry = this.qTable.get(stateKey);
        if (!entry) {
            entry = {
                qValues: new Float32Array(this.numActions),
                visits: 0,
                lastUpdate: Date.now(),
            };
            this.qTable.set(stateKey, entry);
        }
        return entry;
    }
    /**
     * Expected Q-value under the current epsilon-greedy policy.
     * Each action has weight epsilon / numActions; the greedy action also gets
     * (1 - epsilon).
     */
    expectedValue(qValues) {
        const greedyAction = this.argmax(qValues);
        const exploreProb = this.epsilon / this.numActions;
        let expected = 0;
        for (let a = 0; a < this.numActions; a++) {
            const prob = exploreProb + (a === greedyAction ? 1 - this.epsilon : 0);
            expected += prob * qValues[a];
        }
        return expected;
    }
    /**
     * Advance the eligibility traces by one step:
     * 1. multiply every trace by gamma * traceDecay;
     * 2. drop states whose largest trace falls below 0.001 (deleting from a Map
     *    while iterating it with for...of is well-defined);
     * 3. set the (stateKey, action) trace to 1 (replacing traces).
     */
    updateTrace(stateKey, action) {
        for (const [key, trace] of this.traces) {
            for (let a = 0; a < this.numActions; a++) {
                trace[a] *= this.config.gamma * this.config.traceDecay;
            }
            const maxTrace = Math.max(...trace);
            if (maxTrace < 0.001) {
                this.traces.delete(key);
            }
        }
        let trace = this.traces.get(stateKey);
        if (!trace) {
            trace = new Float32Array(this.numActions);
            this.traces.set(stateKey, trace);
        }
        trace[action] = 1.0;
    }
    /**
     * Add learningRate * tdError * trace to every traced entry still present in
     * the Q-table. Each such entry also gets its visits/lastUpdate bumped, so
     * `visits` counts trace updates, not only direct visits.
     */
    updateWithTraces(tdError) {
        const lr = this.config.learningRate;
        for (const [stateKey, trace] of this.traces) {
            const entry = this.qTable.get(stateKey);
            if (entry) {
                for (let a = 0; a < this.numActions; a++) {
                    entry.qValues[a] += lr * tdError * trace[a];
                }
                entry.visits++;
                entry.lastUpdate = Date.now();
            }
        }
    }
    /**
     * Evict the least-recently-updated entries until the table holds
     * floor(0.8 * maxStates) states.
     */
    pruneQTable() {
        const entries = Array.from(this.qTable.entries())
            .sort((a, b) => a[1].lastUpdate - b[1].lastUpdate);
        const toRemove = entries.length - Math.floor(this.config.maxStates * 0.8);
        for (let i = 0; i < toRemove; i++) {
            this.qTable.delete(entries[i][0]);
        }
    }
    /** Index of the largest value; ties resolve to the lowest index. */
    argmax(values) {
        let maxIdx = 0;
        let maxVal = values[0];
        for (let i = 1; i < values.length; i++) {
            if (values[i] > maxVal) {
                maxVal = values[i];
                maxIdx = i;
            }
        }
        return maxIdx;
    }
}
|
|
291
|
+
/**
 * Factory function.
 *
 * @param config - partial configuration passed unchanged to the SARSAAlgorithm constructor.
 * @returns a newly constructed SARSAAlgorithm.
 */
export function createSARSA(config) {
    const algorithm = new SARSAAlgorithm(config);
    return algorithm;
}
|
|
297
|
+
//# sourceMappingURL=sarsa.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"sarsa.js","sourceRoot":"","sources":["sarsa.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;GAUG;AAkBH;;GAEG;AACH,MAAM,CAAC,MAAM,oBAAoB,GAAgB;IAC/C,SAAS,EAAE,OAAO;IAClB,YAAY,EAAE,GAAG;IACjB,KAAK,EAAE,IAAI;IACX,WAAW,EAAE,CAAC;IACd,aAAa,EAAE,CAAC;IAChB,WAAW,EAAE,CAAC;IACd,MAAM,EAAE,CAAC;IACT,aAAa,EAAE,CAAC;IAChB,kBAAkB,EAAE,GAAG;IACvB,gBAAgB,EAAE,IAAI;IACtB,gBAAgB,EAAE,KAAK;IACvB,SAAS,EAAE,KAAK;IAChB,gBAAgB,EAAE,KAAK;IACvB,oBAAoB,EAAE,KAAK;IAC3B,UAAU,EAAE,GAAG;CAChB,CAAC;AAWF;;GAEG;AACH,MAAM,OAAO,cAAc;IACjB,MAAM,CAAc;IAE5B,UAAU;IACF,MAAM,GAA4B,IAAI,GAAG,EAAE,CAAC;IAEpD,cAAc;IACN,OAAO,CAAS;IAChB,SAAS,GAAG,CAAC,CAAC;IAEtB,oBAAoB;IACZ,UAAU,GAAG,CAAC,CAAC;IAEvB,qBAAqB;IACb,MAAM,GAA8B,IAAI,GAAG,EAAE,CAAC;IAEtD,aAAa;IACL,WAAW,GAAG,CAAC,CAAC;IAChB,UAAU,GAAG,CAAC,CAAC;IAEvB,YAAY,SAA+B,EAAE;QAC3C,IAAI,CAAC,MAAM,GAAG,EAAE,GAAG,oBAAoB,EAAE,GAAG,MAAM,EAAE,CAAC;QACrD,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC,MAAM,CAAC,kBAAkB,CAAC;IAChD,CAAC;IAED;;OAEG;IACH,MAAM,CAAC,UAAsB;QAC3B,MAAM,SAAS,GAAG,WAAW,CAAC,GAAG,EAAE,CAAC;QAEpC,IAAI,UAAU,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC;YAChC,OAAO,EAAE,OAAO,EAAE,CAAC,EAAE,CAAC;QACxB,CAAC;QAED,IAAI,YAAY,GAAG,CAAC,CAAC;QAErB,2BAA2B;QAC3B,IAAI,IAAI,CAAC,MAAM,CAAC,oBAAoB,EAAE,CAAC;YACrC,IAAI,CAAC,MAAM,CAAC,KAAK,EAAE,CAAC;QACtB,CAAC;QAED,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;YACrD,MAAM,IAAI,GAAG,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YACjC,MAAM,QAAQ,GAAG,UAAU,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;YAEzC,MAAM,QAAQ,GAAG,IAAI,CAAC,SAAS,CAAC,IAAI,CAAC,WAAW,CAAC,CAAC;YAClD,MAAM,MAAM,GAAG,IAAI,CAAC,UAAU,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;YAC5C,MAAM,YAAY,GAAG,IAAI,CAAC,SAAS,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;YACrD,MAAM,UAAU,GAAG,IAAI,CAAC,UAAU,CAAC,QAAQ,CAAC,MAAM,CAAC,CAAC;YAEpD,wBAAwB;YACxB,MAAM,MAAM,GAAG,IAAI,CAAC,gBAAgB,CAAC,QAAQ,CAAC,CAAC;YAC/C,MAAM,SAAS,GAAG,IAAI,CAAC,gBAAgB,CAAC,YAAY,CAAC,CAAC;YAEtD,kBAAkB;YAClB,MAAM,QAAQ,GAAG,MAAM,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC;YAExC,iDAAiD;YACjD,IAAI,OAAe,CAAC;YAEpB
,IAAI,IAAI,CAAC,MAAM,CAAC,gBAAgB,EAAE,CAAC;gBACjC,0DAA0D;gBAC1D,OAAO,GAAG,IAAI,CAAC,MAAM,GAAG,IAAI,CAAC,MAAM,CAAC,KAAK,GAAG,IAAI,CAAC,aAAa,CAAC,SAAS,CAAC,OAAO,CAAC,CAAC;YACpF,CAAC;iBAAM,CAAC;gBACN,yCAAyC;gBACzC,OAAO,GAAG,IAAI,CAAC,MAAM,GAAG,IAAI,CAAC,MAAM,CAAC,KAAK,GAAG,SAAS,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC;YAC5E,CAAC;YAED,WAAW;YACX,MAAM,OAAO,GAAG,OAAO,GAAG,QAAQ,CAAC;YACnC,YAAY,IAAI,IAAI,CAAC,GAAG,CAAC,OAAO,CAAC,CAAC;YAElC,IAAI,IAAI,CAAC,MAAM,CAAC,oBAAoB,EAAE,CAAC;gBACrC,IAAI,CAAC,WAAW,CAAC,QAAQ,EAAE,MAAM,CAAC,CAAC;gBACnC,IAAI,CAAC,gBAAgB,CAAC,OAAO,CAAC,CAAC;YACjC,CAAC;iBAAM,CAAC;gBACN,MAAM,CAAC,OAAO,CAAC,MAAM,CAAC,IAAI,IAAI,CAAC,MAAM,CAAC,YAAY,GAAG,OAAO,CAAC;gBAC7D,MAAM,CAAC,MAAM,EAAE,CAAC;gBAChB,MAAM,CAAC,UAAU,GAAG,IAAI,CAAC,GAAG,EAAE,CAAC;YACjC,CAAC;QACH,CAAC;QAED,wBAAwB;QACxB,MAAM,QAAQ,GAAG,UAAU,CAAC,KAAK,CAAC,UAAU,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;QAC/D,MAAM,YAAY,GAAG,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAC,WAAW,CAAC,CAAC;QAC1D,MAAM,UAAU,GAAG,IAAI,CAAC,UAAU,CAAC,QAAQ,CAAC,MAAM,CAAC,CAAC;QACpD,MAAM,SAAS,GAAG,IAAI,CAAC,gBAAgB,CAAC,YAAY,CAAC,CAAC;QAEtD,MAAM,eAAe,GAAG,QAAQ,CAAC,MAAM,GAAG,SAAS,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC;QACxE,SAAS,CAAC,OAAO,CAAC,UAAU,CAAC,IAAI,IAAI,CAAC,MAAM,CAAC,YAAY,GAAG,eAAe,CAAC;QAC5E,YAAY,IAAI,IAAI,CAAC,GAAG,CAAC,eAAe,CAAC,CAAC;QAE1C,oBAAoB;QACpB,IAAI,CAAC,SAAS,IAAI,UAAU,CAAC,KAAK,CAAC,MAAM,CAAC;QAC1C,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC,GAAG,CACrB,IAAI,CAAC,MAAM,CAAC,gBAAgB,EAC5B,IAAI,CAAC,MAAM,CAAC,kBAAkB,GAAG,IAAI,CAAC,SAAS,GAAG,IAAI,CAAC,MAAM,CAAC,gBAAgB,CAC/E,CAAC;QAEF,kBAAkB;QAClB,IAAI,IAAI,CAAC,MAAM,CAAC,IAAI,GAAG,IAAI,CAAC,MAAM,CAAC,SAAS,EAAE,CAAC;YAC7C,IAAI,CAAC,WAAW,EAAE,CAAC;QACrB,CAAC;QAED,IAAI,CAAC,WAAW,EAAE,CAAC;QACnB,IAAI,CAAC,UAAU,GAAG,YAAY,GAAG,UAAU,CAAC,KAAK,CAAC,MAAM,CAAC;QAEzD,MAAM,OAAO,GAAG,WAAW,CAAC,GAAG,EAAE,GAAG,SAAS,CAAC;QAC9C,IAAI,OAAO,GAAG,CAAC,EAAE,CAAC;YAChB,OAAO,CAAC,IAAI,CAAC,iCAAiC,OAAO,CAAC,OAAO,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC;QAC9E,CAAC;QAED,OAAO,EAAE,OAAO,EAAE,IAAI,CAAC,UAAU,EAAE,CAAC;IACtC,CAAC;IAED;;OAEG;IACH,SAAS,CAAC,K
AAmB,EAAE,UAAmB,IAAI;QACpD,IAAI,OAAO,IAAI,IAAI,CAAC,MAAM,EAAE,GAAG,IAAI,CAAC,OAAO,EAAE,CAAC;YAC5C,OAAO,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,MAAM,EAAE,GAAG,IAAI,CAAC,UAAU,CAAC,CAAC;QACrD,CAAC;QAED,MAAM,QAAQ,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC;QACvC,MAAM,KAAK,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,QAAQ,CAAC,CAAC;QAExC,IAAI,CAAC,KAAK,EAAE,CAAC;YACX,OAAO,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,MAAM,EAAE,GAAG,IAAI,CAAC,UAAU,CAAC,CAAC;QACrD,CAAC;QAED,OAAO,IAAI,CAAC,MAAM,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;IACpC,CAAC;IAED;;OAEG;IACH,sBAAsB,CAAC,KAAmB;QACxC,MAAM,KAAK,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;QAChD,MAAM,QAAQ,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC;QACvC,MAAM,KAAK,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,QAAQ,CAAC,CAAC;QAExC,IAAI,CAAC,KAAK,EAAE,CAAC;YACX,uBAAuB;YACvB,MAAM,OAAO,GAAG,CAAC,GAAG,IAAI,CAAC,UAAU,CAAC;YACpC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;gBACzC,KAAK,CAAC,CAAC,CAAC,GAAG,OAAO,CAAC;YACrB,CAAC;YACD,OAAO,KAAK,CAAC;QACf,CAAC;QAED,+BAA+B;QAC/B,MAAM,YAAY,GAAG,IAAI,CAAC,MAAM,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;QAChD,MAAM,WAAW,GAAG,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC,UAAU,CAAC;QAEnD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;YACzC,KAAK,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC;QACzB,CAAC;QACD,KAAK,CAAC,YAAY,CAAC,IAAI,CAAC,GAAG,IAAI,CAAC,OAAO,CAAC;QAExC,OAAO,KAAK,CAAC;IACf,CAAC;IAED;;OAEG;IACH,UAAU,CAAC,KAAmB;QAC5B,MAAM,QAAQ,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC;QACvC,MAAM,KAAK,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,QAAQ,CAAC,CAAC;QAExC,IAAI,CAAC,KAAK,EAAE,CAAC;YACX,OAAO,IAAI,YAAY,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;QAC3C,CAAC;QAED,OAAO,IAAI,YAAY,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;IACzC,CAAC;IAED;;OAEG;IACH,QAAQ;QACN,OAAO;YACL,WAAW,EAAE,IAAI,CAAC,WAAW;YAC7B,UAAU,EAAE,IAAI,CAAC,MAAM,CAAC,IAAI;YAC5B,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,UAAU,EAAE,IAAI,CAAC,UAAU;YAC3B,SAAS,EAAE,IAAI,CAAC,SAAS;SAC1B,CAAC;IACJ,CAAC;IAED;;OAEG;IACH,KAAK;QACH,IAAI,CAAC,MAAM,CAAC,KAAK,EAAE,CAAC;QACpB,IAAI,CAAC,MAAM,CAAC,KAAK,EAAE,CAAC;QACpB,IAAI,CAAC,OAAO,GAAG,IAAI,CAA
C,MAAM,CAAC,kBAAkB,CAAC;QAC9C,IAAI,CAAC,SAAS,GAAG,CAAC,CAAC;QACnB,IAAI,CAAC,WAAW,GAAG,CAAC,CAAC;QACrB,IAAI,CAAC,UAAU,GAAG,CAAC,CAAC;IACtB,CAAC;IAED,6EAA6E;IAC7E,kBAAkB;IAClB,6EAA6E;IAErE,SAAS,CAAC,KAAmB;QACnC,MAAM,IAAI,GAAG,EAAE,CAAC;QAChB,MAAM,KAAK,GAAa,EAAE,CAAC;QAE3B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,MAAM,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;YACnD,MAAM,UAAU,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC;YACtC,MAAM,GAAG,GAAG,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,IAAI,CAAC,GAAG,CAAC,IAAI,GAAG,CAAC,EAAE,UAAU,GAAG,IAAI,CAAC,CAAC,CAAC,CAAC;YAC3E,KAAK,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;QAClB,CAAC;QAED,OAAO,KAAK,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;IACzB,CAAC;IAEO,UAAU,CAAC,MAAc;QAC/B,IAAI,IAAI,GAAG,CAAC,CAAC;QACb,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACvC,IAAI,GAAG,CAAC,IAAI,GAAG,EAAE,GAAG,MAAM,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,UAAU,CAAC;QAC9D,CAAC;QACD,OAAO,IAAI,CAAC;IACd,CAAC;IAEO,gBAAgB,CAAC,QAAgB;QACvC,IAAI,KAAK,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,QAAQ,CAAC,CAAC;QAEtC,IAAI,CAAC,KAAK,EAAE,CAAC;YACX,KAAK,GAAG;gBACN,OAAO,EAAE,IAAI,YAAY,CAAC,IAAI,CAAC,UAAU,CAAC;gBAC1C,MAAM,EAAE,CAAC;gBACT,UAAU,EAAE,IAAI,CAAC,GAAG,EAAE;aACvB,CAAC;YACF,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,QAAQ,EAAE,KAAK,CAAC,CAAC;QACnC,CAAC;QAED,OAAO,KAAK,CAAC;IACf,CAAC;IAEO,aAAa,CAAC,OAAqB;QACzC,6CAA6C;QAC7C,MAAM,YAAY,GAAG,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,CAAC;QAC1C,MAAM,WAAW,GAAG,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC,UAAU,CAAC;QAEnD,IAAI,QAAQ,GAAG,CAAC,CAAC;QACjB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;YACzC,MAAM,IAAI,GAAG,WAAW,GAAG,CAAC,CAAC,KAAK,YAAY,CAAC,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YACvE,QAAQ,IAAI,IAAI,GAAG,OAAO,CAAC,CAAC,CAAC,CAAC;QAChC,CAAC;QAED,OAAO,QAAQ,CAAC;IAClB,CAAC;IAEO,WAAW,CAAC,QAAgB,EAAE,MAAc;QAClD,mBAAmB;QACnB,KAAK,MAAM,CAAC,GAAG,EAAE,KAAK,CAAC,IAAI,IAAI,CAAC,MAAM,EAAE,CAAC;YACvC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CA
AC,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;gBACzC,KAAK,CAAC,CAAC,CAAC,IAAI,IAAI,CAAC,MAAM,CAAC,KAAK,GAAG,IAAI,CAAC,MAAM,CAAC,UAAU,CAAC;YACzD,CAAC;YAED,MAAM,QAAQ,GAAG,IAAI,CAAC,GAAG,CAAC,GAAG,KAAK,CAAC,CAAC;YACpC,IAAI,QAAQ,GAAG,KAAK,EAAE,CAAC;gBACrB,IAAI,CAAC,MAAM,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC;YAC1B,CAAC;QACH,CAAC;QAED,6DAA6D;QAC7D,IAAI,KAAK,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,QAAQ,CAAC,CAAC;QACtC,IAAI,CAAC,KAAK,EAAE,CAAC;YACX,KAAK,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;YAC1C,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,QAAQ,EAAE,KAAK,CAAC,CAAC;QACnC,CAAC;QACD,KAAK,CAAC,MAAM,CAAC,GAAG,GAAG,CAAC;IACtB,CAAC;IAEO,gBAAgB,CAAC,OAAe;QACtC,MAAM,EAAE,GAAG,IAAI,CAAC,MAAM,CAAC,YAAY,CAAC;QAEpC,KAAK,MAAM,CAAC,QAAQ,EAAE,KAAK,CAAC,IAAI,IAAI,CAAC,MAAM,EAAE,CAAC;YAC5C,MAAM,KAAK,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,QAAQ,CAAC,CAAC;YACxC,IAAI,KAAK,EAAE,CAAC;gBACV,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;oBACzC,KAAK,CAAC,OAAO,CAAC,CAAC,CAAC,IAAI,EAAE,GAAG,OAAO,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC;gBAC9C,CAAC;gBACD,KAAK,CAAC,MAAM,EAAE,CAAC;gBACf,KAAK,CAAC,UAAU,GAAG,IAAI,CAAC,GAAG,EAAE,CAAC;YAChC,CAAC;QACH,CAAC;IACH,CAAC;IAEO,WAAW;QACjB,MAAM,OAAO,GAAG,KAAK,CAAC,IAAI,CAAC,IAAI,CAAC,MAAM,CAAC,OAAO,EAAE,CAAC;aAC9C,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,UAAU,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC;QAErD,MAAM,QAAQ,GAAG,OAAO,CAAC,MAAM,GAAG,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,MAAM,CAAC,SAAS,GAAG,GAAG,CAAC,CAAC;QAC1E,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,QAAQ,EAAE,CAAC,EAAE,EAAE,CAAC;YAClC,IAAI,CAAC,MAAM,CAAC,MAAM,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QACpC,CAAC;IACH,CAAC;IAEO,MAAM,CAAC,MAAoB;QACjC,IAAI,MAAM,GAAG,CAAC,CAAC;QACf,IAAI,MAAM,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;QACvB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACvC,IAAI,MAAM,CAAC,CAAC,CAAC,GAAG,MAAM,EAAE,CAAC;gBACvB,MAAM,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;gBACnB,MAAM,GAAG,CAAC,CAAC;YACb,CAAC;QACH,CAAC;QACD,OAAO,MAAM,CAAC;IAChB,CAAC;CACF;AAED;;GAEG;AACH,MAA
M,UAAU,WAAW,CAAC,MAA6B;IACvD,OAAO,IAAI,cAAc,CAAC,MAAM,CAAC,CAAC;AACpC,CAAC"}
|
|
@@ -0,0 +1,383 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* SARSA (State-Action-Reward-State-Action)
|
|
3
|
+
*
|
|
4
|
+
* On-policy TD learning algorithm with:
|
|
5
|
+
* - Epsilon-greedy exploration
|
|
6
|
+
* - State hashing for continuous states
|
|
7
|
+
* - Expected SARSA variant (optional)
|
|
8
|
+
* - Eligibility traces (SARSA-lambda)
|
|
9
|
+
*
|
|
10
|
+
* Performance Target: <1ms per update
|
|
11
|
+
*/
|
|
12
|
+
|
|
13
|
+
import type { Trajectory, RLConfig } from '../types.js';
|
|
14
|
+
|
|
15
|
+
/**
 * Configuration for SARSAAlgorithm.
 *
 * Extends the shared RL settings with SARSA-specific options for exploration,
 * Q-table size and algorithm variant.
 */
export interface SARSAConfig extends RLConfig {
  /** Algorithm identifier, fixed to 'sarsa'. */
  algorithm: 'sarsa';
  /** Starting epsilon for epsilon-greedy exploration. */
  explorationInitial: number;
  /** Floor below which epsilon is never decayed. */
  explorationFinal: number;
  /** Steps over which epsilon drops by 1.0 (linear schedule). */
  explorationDecay: number;
  /** Q-table size above which least-recently-updated states are evicted. */
  maxStates: number;
  /** Bootstrap from the expected Q-value under the current policy instead of Q(s', a'). */
  useExpectedSARSA: boolean;
  /** Enable SARSA(lambda) with replacing eligibility traces. */
  useEligibilityTraces: boolean;
  /** Lambda: traces are multiplied by gamma * traceDecay each step. */
  traceDecay: number;
}
|
|
28
|
+
|
|
29
|
+
/**
 * Default SARSA configuration.
 *
 * Epsilon starts at `explorationInitial` and falls linearly by 1.0 every
 * `explorationDecay` steps, never going below `explorationFinal`. Expected
 * SARSA and eligibility traces are off unless overridden.
 */
export const DEFAULT_SARSA_CONFIG: SARSAConfig = {
  algorithm: 'sarsa',
  // Step size applied to each TD error
  learningRate: 0.1,
  // Discount factor for bootstrapped targets
  gamma: 0.99,
  // Shared RL fields; not referenced by the tabular SARSA update
  entropyCoef: 0,
  valueLossCoef: 1,
  maxGradNorm: 1,
  epochs: 1,
  miniBatchSize: 1,
  // Linear epsilon-greedy schedule
  explorationInitial: 1,
  explorationFinal: 0.01,
  explorationDecay: 10000,
  // Q-table capacity before pruning
  maxStates: 10000,
  // Algorithm variants
  useExpectedSARSA: false,
  useEligibilityTraces: false,
  // Lambda for eligibility traces
  traceDecay: 0.9,
};
|
|
49
|
+
|
|
50
|
+
/**
 * One row of the Q-table, keyed externally by the bucketed state string.
 */
interface SARSAEntry {
  /** Q-value for each action slot (length numActions). */
  qValues: Float32Array;
  /** Number of times this entry was updated. */
  visits: number;
  /** Date.now() timestamp of creation or last update; used for LRU pruning. */
  lastUpdate: number;
}
|
|
58
|
+
|
|
59
|
+
/**
|
|
60
|
+
* SARSA Algorithm Implementation
|
|
61
|
+
*/
|
|
62
|
+
export class SARSAAlgorithm {
|
|
63
|
+
// Effective configuration: defaults merged with constructor overrides
private config: SARSAConfig;

// Q-table keyed by bucketed state string
private qTable = new Map<string, SARSAEntry>();

// Exploration: current epsilon and total steps consumed by update()
private epsilon: number;
private stepCount = 0;

// Number of action slots that action strings are hashed onto
private numActions = 4;

// Eligibility traces: bucketed state key -> per-action trace values
private traces = new Map<string, Float32Array>();

// Statistics
private updateCount = 0;
private avgTDError = 0;
|
|
81
|
+
|
|
82
|
+
/**
 * Create a SARSA learner.
 *
 * @param config - partial overrides, shallow-merged over DEFAULT_SARSA_CONFIG.
 *   Epsilon starts at the merged `explorationInitial`.
 */
constructor(config: Partial<SARSAConfig> = {}) {
  const merged: SARSAConfig = { ...DEFAULT_SARSA_CONFIG, ...config };
  this.config = merged;
  this.epsilon = merged.explorationInitial;
}
|
|
86
|
+
|
|
87
|
+
/**
 * Update Q-values from a trajectory using the SARSA rule.
 *
 * Each non-final step (s, a, r, s') is paired with the next step's action a'
 * and moved towards `r + gamma * Q(s', a')`. When `useExpectedSARSA` is set,
 * the target is instead `r + gamma * E[Q(s', .)]` under the current
 * epsilon-greedy policy. The final step is treated as terminal and moved
 * towards its raw reward.
 *
 * With `useEligibilityTraces`, every TD error is spread over all traced
 * state-action pairs, including the terminal one. This lets the terminal
 * reward reach earlier states of the same trajectory.
 *
 * The call also advances the epsilon schedule, prunes the Q-table when it
 * exceeds `maxStates`, and logs a warning if it took longer than 1ms.
 *
 * @param trajectory - steps providing `stateBefore`, `action`, `stateAfter`
 *   and `reward`.
 * @returns mean absolute TD error over the trajectory's steps. Returns
 *   `{ tdError: 0 }` without changing any state when fewer than two steps are given.
 */
update(trajectory: Trajectory): { tdError: number } {
  const startTime = performance.now();
  const steps = trajectory.steps;

  if (steps.length < 2) {
    return { tdError: 0 };
  }

  let totalTDError = 0;

  // Traces belong to a single trajectory; start each one from zero.
  if (this.config.useEligibilityTraces) {
    this.traces.clear();
  }

  for (let i = 0; i < steps.length - 1; i++) {
    const step = steps[i];
    const nextStep = steps[i + 1];

    const stateKey = this.hashState(step.stateBefore);
    const action = this.hashAction(step.action);
    const nextStateKey = this.hashState(step.stateAfter);
    const nextAction = this.hashAction(nextStep.action);

    const qEntry = this.getOrCreateEntry(stateKey);
    const nextEntry = this.getOrCreateEntry(nextStateKey);

    const currentQ = qEntry.qValues[action];

    let targetQ: number;
    if (this.config.useExpectedSARSA) {
      // Expected SARSA: expected value under the current policy
      targetQ = step.reward + this.config.gamma * this.expectedValue(nextEntry.qValues);
    } else {
      // Standard SARSA: value of the action actually taken next
      targetQ = step.reward + this.config.gamma * nextEntry.qValues[nextAction];
    }

    const tdError = targetQ - currentQ;
    totalTDError += Math.abs(tdError);
    this.applyTDError(qEntry, stateKey, action, tdError);
  }

  // Terminal step: no successor, so the target is the reward alone.
  const lastStep = steps[steps.length - 1];
  const lastStateKey = this.hashState(lastStep.stateBefore);
  const lastAction = this.hashAction(lastStep.action);
  const lastEntry = this.getOrCreateEntry(lastStateKey);

  const terminalTDError = lastStep.reward - lastEntry.qValues[lastAction];
  totalTDError += Math.abs(terminalTDError);
  // Use the same path as the loop:
  // - with traces, the terminal error propagates to earlier states;
  // - without traces, visits/lastUpdate are refreshed, so LRU pruning does not
  //   evict this just-updated entry as stale.
  this.applyTDError(lastEntry, lastStateKey, lastAction, terminalTDError);

  // Linear epsilon decay: loses 1.0 per `explorationDecay` steps, floored.
  this.stepCount += steps.length;
  this.epsilon = Math.max(
    this.config.explorationFinal,
    this.config.explorationInitial - this.stepCount / this.config.explorationDecay
  );

  if (this.qTable.size > this.config.maxStates) {
    this.pruneQTable();
  }

  this.updateCount++;
  this.avgTDError = totalTDError / steps.length;

  const elapsed = performance.now() - startTime;
  if (elapsed > 1) {
    console.warn(`SARSA update exceeded target: ${elapsed.toFixed(2)}ms > 1ms`);
  }

  return { tdError: this.avgTDError };
}

/**
 * Apply one TD error for (stateKey, action).
 *
 * - With eligibility traces, the pair's trace is set and the error is spread
 *   over every traced pair.
 * - Otherwise only `entry` is updated and its visits/lastUpdate refreshed.
 */
private applyTDError(entry: SARSAEntry, stateKey: string, action: number, tdError: number): void {
  if (this.config.useEligibilityTraces) {
    this.updateTrace(stateKey, action);
    this.updateWithTraces(tdError);
  } else {
    entry.qValues[action] += this.config.learningRate * tdError;
    entry.visits++;
    entry.lastUpdate = Date.now();
  }
}
|
|
177
|
+
|
|
178
|
+
/**
 * Select an action for `state` using an epsilon-greedy rule.
 *
 * When `explore` is true, a uniformly random action is returned with
 * probability `this.epsilon`. Otherwise the action with the highest stored
 * Q-value is chosen (the lowest index wins ties, see `argmax`). A state that
 * has no Q-table entry yet also gets a uniformly random action, even when
 * `explore` is false.
 *
 * @param state - Observation vector; discretised through `hashState`.
 * @param explore - Whether epsilon exploration is applied (default `true`).
 * @returns An action index in `[0, numActions)`.
 */
getAction(state: Float32Array, explore: boolean = true): number {
  const exploring = explore && Math.random() < this.epsilon;
  const entry = exploring ? undefined : this.qTable.get(this.hashState(state));

  if (entry === undefined) {
    // Either an exploration step or a state with no learned values yet.
    return Math.floor(Math.random() * this.numActions);
  }

  return this.argmax(entry.qValues);
}
|
|
195
|
+
|
|
196
|
+
/**
 * Probability of choosing each action in `state` under the current
 * epsilon-greedy policy.
 *
 * An unseen state gives every action `1 / numActions`. Otherwise every action
 * gets `epsilon / numActions`, and the greedy action (per `argmax`)
 * additionally receives `1 - epsilon`.
 *
 * @param state - Observation vector; discretised through `hashState`.
 * @returns A new `Float32Array` of length `numActions`.
 */
getActionProbabilities(state: Float32Array): Float32Array {
  const entry = this.qTable.get(this.hashState(state));
  const distribution = new Float32Array(this.numActions);

  if (entry === undefined) {
    distribution.fill(1 / this.numActions);
    return distribution;
  }

  distribution.fill(this.epsilon / this.numActions);
  distribution[this.argmax(entry.qValues)] += 1 - this.epsilon;
  return distribution;
}
|
|
224
|
+
|
|
225
|
+
/**
 * Copy of the Q-values stored for `state`.
 *
 * The returned array is independent of the Q-table, so callers may mutate it
 * freely. An unseen state yields an all-zero array of length `numActions`.
 */
getQValues(state: Float32Array): Float32Array {
  const entry = this.qTable.get(this.hashState(state));
  return entry ? entry.qValues.slice() : new Float32Array(this.numActions);
}
|
|
238
|
+
|
|
239
|
+
/**
 * Snapshot of training counters for monitoring and logging.
 *
 * @returns The number of `update` calls, the current Q-table size, the
 *   current epsilon, the mean absolute TD error of the most recent update,
 *   and the cumulative number of trajectory steps processed.
 */
getStats(): Record<string, number> {
  const { updateCount, epsilon, avgTDError, stepCount } = this;
  const qTableSize = this.qTable.size;
  return { updateCount, qTableSize, epsilon, avgTDError, stepCount };
}
|
|
251
|
+
|
|
252
|
+
/**
 * Forget everything learned so far.
 *
 * Zeroes the statistics counters, restores epsilon to
 * `config.explorationInitial`, and empties the eligibility traces and the
 * Q-table.
 */
reset(): void {
  this.avgTDError = 0;
  this.updateCount = 0;
  this.stepCount = 0;
  this.epsilon = this.config.explorationInitial;
  this.traces.clear();
  this.qTable.clear();
}
|
|
263
|
+
|
|
264
|
+
// ==========================================================================
|
|
265
|
+
// Private Methods
|
|
266
|
+
// ==========================================================================
|
|
267
|
+
|
|
268
|
+
/**
 * Discretise a state vector into a string key for the tabular Q-table.
 *
 * Only the first 8 components are used. Each component is mapped from
 * [-1, 1] onto [0, 1] and assigned to one of 10 equal-width bins. Values
 * outside that range are clamped into the first or last bin. The bin indices
 * are joined with commas, e.g. "5,9,0".
 */
private hashState(state: Float32Array): string {
  const binCount = 10;
  const dims = Math.min(8, state.length);
  let key = '';

  for (let d = 0; d < dims; d++) {
    const scaled = ((state[d] + 1) / 2) * binCount;
    const bin = Math.floor(Math.max(0, Math.min(binCount - 1, scaled)));
    key += d === 0 ? `${bin}` : `,${bin}`;
  }

  return key;
}
|
|
280
|
+
|
|
281
|
+
/**
 * Map an action label to an action index in `[0, numActions)`.
 *
 * Uses a base-31 polynomial hash over UTF-16 code units, reduced modulo
 * `numActions` after every character so intermediate values stay small.
 * Different labels may collide on the same index. The empty string maps to 0.
 */
private hashAction(action: string): number {
  let index = 0;
  let pos = 0;
  while (pos < action.length) {
    index = (index * 31 + action.charCodeAt(pos)) % this.numActions;
    pos += 1;
  }
  return index;
}
|
|
288
|
+
|
|
289
|
+
/**
 * Look up the Q-table entry for `stateKey`.
 *
 * If none exists, inserts and returns a new entry with all-zero Q-values,
 * zero visits, and `lastUpdate` set to the current time.
 */
private getOrCreateEntry(stateKey: string): SARSAEntry {
  const existing = this.qTable.get(stateKey);
  if (existing) {
    return existing;
  }

  const created: SARSAEntry = {
    qValues: new Float32Array(this.numActions),
    visits: 0,
    lastUpdate: Date.now(),
  };
  this.qTable.set(stateKey, created);
  return created;
}
|
|
303
|
+
|
|
304
|
+
/**
 * Expected Q-value under the current epsilon-greedy policy.
 *
 * This is the bootstrap term used when `useExpectedSARSA` is enabled. Every
 * action is weighted by `epsilon / numActions`, and the greedy action
 * (per `argmax`) gets an extra `1 - epsilon`.
 */
private expectedValue(qValues: Float32Array): number {
  const best = this.argmax(qValues);
  const baseWeight = this.epsilon / this.numActions;
  let total = 0;

  for (let a = 0; a < this.numActions; a++) {
    const weight = a === best ? baseWeight + (1 - this.epsilon) : baseWeight;
    total += weight * qValues[a];
  }

  return total;
}
|
|
317
|
+
|
|
318
|
+
/**
 * Advance all eligibility traces by one step and mark `(stateKey, action)`
 * as just visited.
 *
 * Every existing trace is multiplied by `gamma * traceDecay`. A state whose
 * largest remaining trace drops below 0.001 is removed from the map, which
 * keeps the trace table small. The trace for the current state-action is
 * then set to 1.0 (replacing rather than accumulating). Traces for the other
 * actions of the same state are left unchanged.
 *
 * @param stateKey - Discretised state key from `hashState`.
 * @param action - Action index from `hashAction`.
 */
private updateTrace(stateKey: string, action: number): void {
  const decay = this.config.gamma * this.config.traceDecay;

  for (const [key, trace] of this.traces) {
    // Track the maximum while decaying, instead of calling `Math.max(...trace)`.
    // Spreading a typed array allocates a temporary argument list for every
    // traced state on every step, and throws a RangeError once numActions
    // exceeds the engine's argument-count limit.
    let maxTrace = -Infinity;
    for (let a = 0; a < this.numActions; a++) {
      trace[a] *= decay;
      if (trace[a] > maxTrace) {
        maxTrace = trace[a];
      }
    }

    // Deleting the entry currently being visited is safe during Map iteration.
    if (maxTrace < 0.001) {
      this.traces.delete(key);
    }
  }

  let current = this.traces.get(stateKey);
  if (!current) {
    current = new Float32Array(this.numActions);
    this.traces.set(stateKey, current);
  }
  current[action] = 1.0;
}
|
|
339
|
+
|
|
340
|
+
/**
 * Apply one TD error to every state that currently has an eligibility trace.
 *
 * Each action value is updated as `Q(s, a) += learningRate * tdError * e(s, a)`.
 * Traced states with no Q-table entry (for example, ones removed by pruning)
 * are skipped. Every updated entry has `visits` incremented and `lastUpdate`
 * refreshed, regardless of how small its traces are.
 */
private updateWithTraces(tdError: number): void {
  const scaledError = this.config.learningRate * tdError;

  this.traces.forEach((trace, stateKey) => {
    const entry = this.qTable.get(stateKey);
    if (!entry) {
      return;
    }

    const q = entry.qValues;
    for (let a = 0; a < this.numActions; a++) {
      q[a] += scaledError * trace[a];
    }
    entry.visits += 1;
    entry.lastUpdate = Date.now();
  });
}
|
|
354
|
+
|
|
355
|
+
/**
 * Evict the least-recently-updated states until the Q-table holds
 * `floor(config.maxStates * 0.8)` entries.
 *
 * Entries are ordered by `lastUpdate`, oldest first. Eligibility traces for
 * evicted states are not touched here.
 */
private pruneQTable(): void {
  const target = Math.floor(this.config.maxStates * 0.8);
  const excess = this.qTable.size - target;
  if (excess <= 0) {
    return;
  }

  const oldestFirst = Array.from(this.qTable).sort(
    ([, left], [, right]) => left.lastUpdate - right.lastUpdate,
  );
  for (let i = 0; i < excess; i++) {
    this.qTable.delete(oldestFirst[i][0]);
  }
}
|
|
364
|
+
|
|
365
|
+
/**
 * Index of the largest element of `values`; the first index wins ties.
 *
 * Returns 0 for an empty array. It also returns 0 if `values[0]` is NaN,
 * since no comparison against NaN succeeds.
 */
private argmax(values: Float32Array): number {
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) {
      best = i;
    }
  }
  return best;
}
|
|
376
|
+
}
|
|
377
|
+
|
|
378
|
+
/**
 * Convenience factory for {@link SARSAAlgorithm}.
 *
 * @param config - Optional configuration overrides, passed to the constructor
 *   unchanged.
 * @returns A new `SARSAAlgorithm` instance.
 */
export function createSARSA(config?: Partial<SARSAConfig>): SARSAAlgorithm {
  const algorithm = new SARSAAlgorithm(config);
  return algorithm;
}
|
|
File without changes
|