@sparkleideas/neural 3.5.2-patch.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +260 -0
- package/__tests__/README.md +235 -0
- package/__tests__/algorithms.test.ts +582 -0
- package/__tests__/patterns.test.ts +549 -0
- package/__tests__/sona.test.ts +445 -0
- package/docs/SONA_INTEGRATION.md +460 -0
- package/docs/SONA_QUICKSTART.md +168 -0
- package/examples/sona-usage.ts +318 -0
- package/package.json +23 -0
- package/src/algorithms/a2c.d.ts +86 -0
- package/src/algorithms/a2c.d.ts.map +1 -0
- package/src/algorithms/a2c.js +361 -0
- package/src/algorithms/a2c.js.map +1 -0
- package/src/algorithms/a2c.ts +478 -0
- package/src/algorithms/curiosity.d.ts +82 -0
- package/src/algorithms/curiosity.d.ts.map +1 -0
- package/src/algorithms/curiosity.js +392 -0
- package/src/algorithms/curiosity.js.map +1 -0
- package/src/algorithms/curiosity.ts +509 -0
- package/src/algorithms/decision-transformer.d.ts +82 -0
- package/src/algorithms/decision-transformer.d.ts.map +1 -0
- package/src/algorithms/decision-transformer.js +415 -0
- package/src/algorithms/decision-transformer.js.map +1 -0
- package/src/algorithms/decision-transformer.ts +521 -0
- package/src/algorithms/dqn.d.ts +72 -0
- package/src/algorithms/dqn.d.ts.map +1 -0
- package/src/algorithms/dqn.js +303 -0
- package/src/algorithms/dqn.js.map +1 -0
- package/src/algorithms/dqn.ts +382 -0
- package/src/algorithms/index.d.ts +32 -0
- package/src/algorithms/index.d.ts.map +1 -0
- package/src/algorithms/index.js +74 -0
- package/src/algorithms/index.js.map +1 -0
- package/src/algorithms/index.ts +122 -0
- package/src/algorithms/ppo.d.ts +72 -0
- package/src/algorithms/ppo.d.ts.map +1 -0
- package/src/algorithms/ppo.js +331 -0
- package/src/algorithms/ppo.js.map +1 -0
- package/src/algorithms/ppo.ts +429 -0
- package/src/algorithms/q-learning.d.ts +77 -0
- package/src/algorithms/q-learning.d.ts.map +1 -0
- package/src/algorithms/q-learning.js +259 -0
- package/src/algorithms/q-learning.js.map +1 -0
- package/src/algorithms/q-learning.ts +333 -0
- package/src/algorithms/sarsa.d.ts +82 -0
- package/src/algorithms/sarsa.d.ts.map +1 -0
- package/src/algorithms/sarsa.js +297 -0
- package/src/algorithms/sarsa.js.map +1 -0
- package/src/algorithms/sarsa.ts +383 -0
- package/src/algorithms/tmp.json +0 -0
- package/src/application/index.ts +11 -0
- package/src/application/services/neural-application-service.ts +217 -0
- package/src/domain/entities/pattern.ts +169 -0
- package/src/domain/index.ts +18 -0
- package/src/domain/services/learning-service.ts +256 -0
- package/src/index.d.ts +118 -0
- package/src/index.d.ts.map +1 -0
- package/src/index.js +201 -0
- package/src/index.js.map +1 -0
- package/src/index.ts +363 -0
- package/src/modes/balanced.d.ts +60 -0
- package/src/modes/balanced.d.ts.map +1 -0
- package/src/modes/balanced.js +234 -0
- package/src/modes/balanced.js.map +1 -0
- package/src/modes/balanced.ts +299 -0
- package/src/modes/base.ts +163 -0
- package/src/modes/batch.d.ts +82 -0
- package/src/modes/batch.d.ts.map +1 -0
- package/src/modes/batch.js +316 -0
- package/src/modes/batch.js.map +1 -0
- package/src/modes/batch.ts +434 -0
- package/src/modes/edge.d.ts +85 -0
- package/src/modes/edge.d.ts.map +1 -0
- package/src/modes/edge.js +310 -0
- package/src/modes/edge.js.map +1 -0
- package/src/modes/edge.ts +409 -0
- package/src/modes/index.d.ts +55 -0
- package/src/modes/index.d.ts.map +1 -0
- package/src/modes/index.js +83 -0
- package/src/modes/index.js.map +1 -0
- package/src/modes/index.ts +16 -0
- package/src/modes/real-time.d.ts +58 -0
- package/src/modes/real-time.d.ts.map +1 -0
- package/src/modes/real-time.js +196 -0
- package/src/modes/real-time.js.map +1 -0
- package/src/modes/real-time.ts +257 -0
- package/src/modes/research.d.ts +79 -0
- package/src/modes/research.d.ts.map +1 -0
- package/src/modes/research.js +389 -0
- package/src/modes/research.js.map +1 -0
- package/src/modes/research.ts +486 -0
- package/src/modes/tmp.json +0 -0
- package/src/pattern-learner.d.ts +117 -0
- package/src/pattern-learner.d.ts.map +1 -0
- package/src/pattern-learner.js +603 -0
- package/src/pattern-learner.js.map +1 -0
- package/src/pattern-learner.ts +757 -0
- package/src/reasoning-bank.d.ts +259 -0
- package/src/reasoning-bank.d.ts.map +1 -0
- package/src/reasoning-bank.js +993 -0
- package/src/reasoning-bank.js.map +1 -0
- package/src/reasoning-bank.ts +1279 -0
- package/src/reasoningbank-adapter.ts +697 -0
- package/src/sona-integration.d.ts +168 -0
- package/src/sona-integration.d.ts.map +1 -0
- package/src/sona-integration.js +316 -0
- package/src/sona-integration.js.map +1 -0
- package/src/sona-integration.ts +432 -0
- package/src/sona-manager.d.ts +147 -0
- package/src/sona-manager.d.ts.map +1 -0
- package/src/sona-manager.js +695 -0
- package/src/sona-manager.js.map +1 -0
- package/src/sona-manager.ts +835 -0
- package/src/tmp.json +0 -0
- package/src/types.d.ts +431 -0
- package/src/types.d.ts.map +1 -0
- package/src/types.js +11 -0
- package/src/types.js.map +1 -0
- package/src/types.ts +590 -0
- package/tmp.json +0 -0
- package/tsconfig.json +9 -0
- package/vitest.config.ts +19 -0
|
@@ -0,0 +1,259 @@
|
|
|
1
|
+
/**
 * Tabular Q-Learning
 *
 * Classic Q-learning over a hashed (discretized) state space:
 * - Epsilon-greedy action selection with linear epsilon decay
 * - State discretization for continuous observation vectors
 * - Optional eligibility traces (replacing-traces variant)
 * - LRU-style pruning once the Q-table exceeds its size budget
 *
 * Suitable for smaller state spaces or discretized environments.
 * Performance Target: <1ms per update
 */
/**
 * Default Q-Learning configuration
 */
export const DEFAULT_QLEARNING_CONFIG = {
    algorithm: 'q-learning',
    learningRate: 0.1,
    gamma: 0.99,
    entropyCoef: 0,
    valueLossCoef: 1,
    maxGradNorm: 1,
    epochs: 1,
    miniBatchSize: 1,
    explorationInitial: 1.0,
    explorationFinal: 0.01,
    explorationDecay: 10000,
    maxStates: 10000,
    useEligibilityTraces: false,
    traceDecay: 0.9,
};
/**
 * Q-Learning Algorithm Implementation
 */
export class QLearning {
    config;
    // Q-table: discretized state key -> { qValues, visits, lastUpdate }
    qTable = new Map();
    // Epsilon-greedy exploration state
    epsilon;
    stepCount = 0;
    // Size of the discrete action space
    numActions = 4;
    // Eligibility traces: state key -> per-action trace strengths
    traces = new Map();
    // Statistics
    updateCount = 0;
    avgTDError = 0;
    constructor(config = {}) {
        // Caller overrides are merged over the defaults.
        this.config = { ...DEFAULT_QLEARNING_CONFIG, ...config };
        this.epsilon = this.config.explorationInitial;
    }
    /**
     * Update Q-values from a trajectory of steps.
     *
     * The final step is treated as terminal (its target is the bare reward);
     * every other step bootstraps from the greedy value of its successor
     * state. Also decays epsilon, prunes the table when over budget, and
     * refreshes the running statistics.
     *
     * @returns {{tdError: number}} mean absolute TD error over the steps
     */
    update(trajectory) {
        const t0 = performance.now();
        const { steps } = trajectory;
        if (steps.length === 0) {
            return { tdError: 0 };
        }
        // Traces are per-trajectory; start each update with a clean slate.
        if (this.config.useEligibilityTraces) {
            this.traces.clear();
        }
        let absErrorSum = 0;
        for (const [i, step] of steps.entries()) {
            const stateKey = this.hashState(step.stateBefore);
            const action = this.hashAction(step.action);
            const entry = this.getOrCreateEntry(stateKey);
            // TD target: bare reward at the terminal (last) step, otherwise
            // reward + gamma * max_a Q(s', a).
            const isTerminal = i === steps.length - 1;
            const targetQ = isTerminal
                ? step.reward
                : step.reward +
                    this.config.gamma *
                        Math.max(...this.getOrCreateEntry(this.hashState(step.stateAfter)).qValues);
            const tdError = targetQ - entry.qValues[action];
            absErrorSum += Math.abs(tdError);
            if (this.config.useEligibilityTraces) {
                // Credit-assign this error backwards through the trace set.
                this.updateTrace(stateKey, action);
                this.updateWithTraces(tdError);
            }
            else {
                // Plain one-step Q-learning update.
                entry.qValues[action] += this.config.learningRate * tdError;
                entry.visits++;
                entry.lastUpdate = Date.now();
            }
        }
        // Linear epsilon decay, floored at the configured final value.
        this.stepCount += steps.length;
        this.epsilon = Math.max(this.config.explorationFinal, this.config.explorationInitial - this.stepCount / this.config.explorationDecay);
        // Keep the table within its size budget.
        if (this.qTable.size > this.config.maxStates) {
            this.pruneQTable();
        }
        this.updateCount++;
        this.avgTDError = absErrorSum / steps.length;
        const elapsed = performance.now() - t0;
        if (elapsed > 1) {
            console.warn(`Q-learning update exceeded target: ${elapsed.toFixed(2)}ms > 1ms`);
        }
        return { tdError: this.avgTDError };
    }
    /**
     * Epsilon-greedy action selection.
     * Falls back to a random action for states not yet in the table.
     */
    getAction(state, explore = true) {
        if (explore && Math.random() < this.epsilon) {
            return Math.floor(Math.random() * this.numActions);
        }
        const entry = this.qTable.get(this.hashState(state));
        return entry
            ? this.argmax(entry.qValues)
            : Math.floor(Math.random() * this.numActions);
    }
    /**
     * Return a copy of the Q-values for a state (zeros if unseen).
     */
    getQValues(state) {
        const entry = this.qTable.get(this.hashState(state));
        return entry
            ? new Float32Array(entry.qValues)
            : new Float32Array(this.numActions);
    }
    /**
     * Counters useful for monitoring training progress.
     */
    getStats() {
        return {
            updateCount: this.updateCount,
            qTableSize: this.qTable.size,
            epsilon: this.epsilon,
            avgTDError: this.avgTDError,
            stepCount: this.stepCount,
        };
    }
    /**
     * Clear all learned values and restore initial exploration/statistics.
     */
    reset() {
        this.qTable.clear();
        this.traces.clear();
        this.epsilon = this.config.explorationInitial;
        this.stepCount = 0;
        this.updateCount = 0;
        this.avgTDError = 0;
    }
    // ==========================================================================
    // Private Methods
    // ==========================================================================
    // Discretize a state vector into a string key: the first 8 dimensions are
    // each binned into 10 buckets, assuming values in [-1, 1] (clamped).
    hashState(state) {
        const bins = 10;
        const dims = Math.min(8, state.length);
        const parts = new Array(dims);
        for (let i = 0; i < dims; i++) {
            const normalized = (state[i] + 1) / 2; // map assumed [-1, 1] to [0, 1]
            parts[i] = Math.floor(Math.max(0, Math.min(bins - 1, normalized * bins)));
        }
        return parts.join(',');
    }
    // Deterministically map an action string to an index in [0, numActions)
    // using a polynomial rolling hash (collisions are possible by design).
    hashAction(action) {
        let h = 0;
        for (let i = 0; i < action.length; i++) {
            h = (h * 31 + action.charCodeAt(i)) % this.numActions;
        }
        return h;
    }
    // Fetch the table entry for a key, inserting a zero-initialized one if absent.
    getOrCreateEntry(stateKey) {
        const existing = this.qTable.get(stateKey);
        if (existing) {
            return existing;
        }
        const created = {
            qValues: new Float32Array(this.numActions),
            visits: 0,
            lastUpdate: Date.now(),
        };
        this.qTable.set(stateKey, created);
        return created;
    }
    // Decay all traces by gamma * traceDecay, dropping negligible ones, then
    // mark the current state-action pair as fully eligible (replacing traces).
    updateTrace(stateKey, action) {
        const decay = this.config.gamma * this.config.traceDecay;
        for (const [key, trace] of this.traces) {
            for (let a = 0; a < this.numActions; a++) {
                trace[a] *= decay;
            }
            if (Math.max(...trace) < 0.001) {
                this.traces.delete(key);
            }
        }
        let trace = this.traces.get(stateKey);
        if (!trace) {
            trace = new Float32Array(this.numActions);
            this.traces.set(stateKey, trace);
        }
        trace[action] = 1.0;
    }
    // Apply the current TD error to every traced state, scaled by the
    // learning rate and each action's trace strength.
    updateWithTraces(tdError) {
        const lr = this.config.learningRate;
        for (const [stateKey, trace] of this.traces) {
            const entry = this.qTable.get(stateKey);
            if (!entry) {
                continue;
            }
            for (let a = 0; a < this.numActions; a++) {
                entry.qValues[a] += lr * tdError * trace[a];
            }
            entry.visits++;
            entry.lastUpdate = Date.now();
        }
    }
    // Evict the least-recently-updated entries until the table holds at most
    // 80% of maxStates.
    pruneQTable() {
        const byAge = [...this.qTable.entries()].sort((a, b) => a[1].lastUpdate - b[1].lastUpdate);
        const excess = byAge.length - Math.floor(this.config.maxStates * 0.8);
        for (let i = 0; i < excess; i++) {
            this.qTable.delete(byAge[i][0]);
        }
    }
    // Index of the largest value; ties resolve to the earliest index.
    argmax(values) {
        let best = 0;
        for (let i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }
}
|
|
253
|
+
/**
 * Factory function: build a QLearning instance from an optional partial config.
 */
export function createQLearning(config) {
    return new QLearning(config);
}
//# sourceMappingURL=q-learning.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"q-learning.js","sourceRoot":"","sources":["q-learning.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;GAWG;AAiBH;;GAEG;AACH,MAAM,CAAC,MAAM,wBAAwB,GAAoB;IACvD,SAAS,EAAE,YAAY;IACvB,YAAY,EAAE,GAAG;IACjB,KAAK,EAAE,IAAI;IACX,WAAW,EAAE,CAAC;IACd,aAAa,EAAE,CAAC;IAChB,WAAW,EAAE,CAAC;IACd,MAAM,EAAE,CAAC;IACT,aAAa,EAAE,CAAC;IAChB,kBAAkB,EAAE,GAAG;IACvB,gBAAgB,EAAE,IAAI;IACtB,gBAAgB,EAAE,KAAK;IACvB,SAAS,EAAE,KAAK;IAChB,oBAAoB,EAAE,KAAK;IAC3B,UAAU,EAAE,GAAG;CAChB,CAAC;AAWF;;GAEG;AACH,MAAM,OAAO,SAAS;IACZ,MAAM,CAAkB;IAEhC,UAAU;IACF,MAAM,GAAwB,IAAI,GAAG,EAAE,CAAC;IAEhD,cAAc;IACN,OAAO,CAAS;IAChB,SAAS,GAAG,CAAC,CAAC;IAEtB,oBAAoB;IACZ,UAAU,GAAG,CAAC,CAAC;IAEvB,qBAAqB;IACb,MAAM,GAA8B,IAAI,GAAG,EAAE,CAAC;IAEtD,aAAa;IACL,WAAW,GAAG,CAAC,CAAC;IAChB,UAAU,GAAG,CAAC,CAAC;IAEvB,YAAY,SAAmC,EAAE;QAC/C,IAAI,CAAC,MAAM,GAAG,EAAE,GAAG,wBAAwB,EAAE,GAAG,MAAM,EAAE,CAAC;QACzD,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC,MAAM,CAAC,kBAAkB,CAAC;IAChD,CAAC;IAED;;OAEG;IACH,MAAM,CAAC,UAAsB;QAC3B,MAAM,SAAS,GAAG,WAAW,CAAC,GAAG,EAAE,CAAC;QAEpC,IAAI,UAAU,CAAC,KAAK,CAAC,MAAM,KAAK,CAAC,EAAE,CAAC;YAClC,OAAO,EAAE,OAAO,EAAE,CAAC,EAAE,CAAC;QACxB,CAAC;QAED,IAAI,YAAY,GAAG,CAAC,CAAC;QAErB,8CAA8C;QAC9C,IAAI,IAAI,CAAC,MAAM,CAAC,oBAAoB,EAAE,CAAC;YACrC,IAAI,CAAC,MAAM,CAAC,KAAK,EAAE,CAAC;QACtB,CAAC;QAED,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,CAAC,KAAK,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACjD,MAAM,IAAI,GAAG,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YACjC,MAAM,QAAQ,GAAG,IAAI,CAAC,SAAS,CAAC,IAAI,CAAC,WAAW,CAAC,CAAC;YAClD,MAAM,MAAM,GAAG,IAAI,CAAC,UAAU,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;YAE5C,wBAAwB;YACxB,MAAM,MAAM,GAAG,IAAI,CAAC,gBAAgB,CAAC,QAAQ,CAAC,CAAC;YAE/C,kBAAkB;YAClB,MAAM,QAAQ,GAAG,MAAM,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC;YAExC,yBAAyB;YACzB,IAAI,OAAe,CAAC;YACpB,IAAI,CAAC,KAAK,UAAU,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC;gBACtC,iBAAiB;gBACjB,OAAO,GAAG,IAAI,CAAC,MAAM,CAAC;YACxB,CAAC;iBAAM,CAAC;gBACN,MAAM,YAAY,GAAG,IAAI,CAAC,SAAS,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;gBACrD,MAAM,SAAS,GAAG,IAAI,CAAC,gBAAgB,CAAC,YAAY,CAAC,CAAC;gBACtD,MAAM,QAAQ,GA
AG,IAAI,CAAC,GAAG,CAAC,GAAG,SAAS,CAAC,OAAO,CAAC,CAAC;gBAChD,OAAO,GAAG,IAAI,CAAC,MAAM,GAAG,IAAI,CAAC,MAAM,CAAC,KAAK,GAAG,QAAQ,CAAC;YACvD,CAAC;YAED,WAAW;YACX,MAAM,OAAO,GAAG,OAAO,GAAG,QAAQ,CAAC;YACnC,YAAY,IAAI,IAAI,CAAC,GAAG,CAAC,OAAO,CAAC,CAAC;YAElC,IAAI,IAAI,CAAC,MAAM,CAAC,oBAAoB,EAAE,CAAC;gBACrC,2BAA2B;gBAC3B,IAAI,CAAC,WAAW,CAAC,QAAQ,EAAE,MAAM,CAAC,CAAC;gBAEnC,gCAAgC;gBAChC,IAAI,CAAC,gBAAgB,CAAC,OAAO,CAAC,CAAC;YACjC,CAAC;iBAAM,CAAC;gBACN,2BAA2B;gBAC3B,MAAM,CAAC,OAAO,CAAC,MAAM,CAAC,IAAI,IAAI,CAAC,MAAM,CAAC,YAAY,GAAG,OAAO,CAAC;gBAC7D,MAAM,CAAC,MAAM,EAAE,CAAC;gBAChB,MAAM,CAAC,UAAU,GAAG,IAAI,CAAC,GAAG,EAAE,CAAC;YACjC,CAAC;QACH,CAAC;QAED,oBAAoB;QACpB,IAAI,CAAC,SAAS,IAAI,UAAU,CAAC,KAAK,CAAC,MAAM,CAAC;QAC1C,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC,GAAG,CACrB,IAAI,CAAC,MAAM,CAAC,gBAAgB,EAC5B,IAAI,CAAC,MAAM,CAAC,kBAAkB,GAAG,IAAI,CAAC,SAAS,GAAG,IAAI,CAAC,MAAM,CAAC,gBAAgB,CAC/E,CAAC;QAEF,6BAA6B;QAC7B,IAAI,IAAI,CAAC,MAAM,CAAC,IAAI,GAAG,IAAI,CAAC,MAAM,CAAC,SAAS,EAAE,CAAC;YAC7C,IAAI,CAAC,WAAW,EAAE,CAAC;QACrB,CAAC;QAED,IAAI,CAAC,WAAW,EAAE,CAAC;QACnB,IAAI,CAAC,UAAU,GAAG,YAAY,GAAG,UAAU,CAAC,KAAK,CAAC,MAAM,CAAC;QAEzD,MAAM,OAAO,GAAG,WAAW,CAAC,GAAG,EAAE,GAAG,SAAS,CAAC;QAC9C,IAAI,OAAO,GAAG,CAAC,EAAE,CAAC;YAChB,OAAO,CAAC,IAAI,CAAC,sCAAsC,OAAO,CAAC,OAAO,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC;QACnF,CAAC;QAED,OAAO,EAAE,OAAO,EAAE,IAAI,CAAC,UAAU,EAAE,CAAC;IACtC,CAAC;IAED;;OAEG;IACH,SAAS,CAAC,KAAmB,EAAE,UAAmB,IAAI;QACpD,IAAI,OAAO,IAAI,IAAI,CAAC,MAAM,EAAE,GAAG,IAAI,CAAC,OAAO,EAAE,CAAC;YAC5C,OAAO,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,MAAM,EAAE,GAAG,IAAI,CAAC,UAAU,CAAC,CAAC;QACrD,CAAC;QAED,MAAM,QAAQ,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC;QACvC,MAAM,KAAK,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,QAAQ,CAAC,CAAC;QAExC,IAAI,CAAC,KAAK,EAAE,CAAC;YACX,OAAO,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,MAAM,EAAE,GAAG,IAAI,CAAC,UAAU,CAAC,CAAC;QACrD,CAAC;QAED,OAAO,IAAI,CAAC,MAAM,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;IACpC,CAAC;IAED;;OAEG;IACH,UAAU,CAAC,KAAmB;QAC5B,MAAM,QAAQ,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC;QACvC,MAAM,KAAK,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC
,QAAQ,CAAC,CAAC;QAExC,IAAI,CAAC,KAAK,EAAE,CAAC;YACX,OAAO,IAAI,YAAY,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;QAC3C,CAAC;QAED,OAAO,IAAI,YAAY,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;IACzC,CAAC;IAED;;OAEG;IACH,QAAQ;QACN,OAAO;YACL,WAAW,EAAE,IAAI,CAAC,WAAW;YAC7B,UAAU,EAAE,IAAI,CAAC,MAAM,CAAC,IAAI;YAC5B,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,UAAU,EAAE,IAAI,CAAC,UAAU;YAC3B,SAAS,EAAE,IAAI,CAAC,SAAS;SAC1B,CAAC;IACJ,CAAC;IAED;;OAEG;IACH,KAAK;QACH,IAAI,CAAC,MAAM,CAAC,KAAK,EAAE,CAAC;QACpB,IAAI,CAAC,MAAM,CAAC,KAAK,EAAE,CAAC;QACpB,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC,MAAM,CAAC,kBAAkB,CAAC;QAC9C,IAAI,CAAC,SAAS,GAAG,CAAC,CAAC;QACnB,IAAI,CAAC,WAAW,GAAG,CAAC,CAAC;QACrB,IAAI,CAAC,UAAU,GAAG,CAAC,CAAC;IACtB,CAAC;IAED,6EAA6E;IAC7E,kBAAkB;IAClB,6EAA6E;IAErE,SAAS,CAAC,KAAmB;QACnC,qCAAqC;QACrC,MAAM,IAAI,GAAG,EAAE,CAAC;QAChB,MAAM,KAAK,GAAa,EAAE,CAAC;QAE3B,qCAAqC;QACrC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,MAAM,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;YACnD,MAAM,UAAU,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,uBAAuB;YAC9D,MAAM,GAAG,GAAG,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,IAAI,CAAC,GAAG,CAAC,IAAI,GAAG,CAAC,EAAE,UAAU,GAAG,IAAI,CAAC,CAAC,CAAC,CAAC;YAC3E,KAAK,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;QAClB,CAAC;QAED,OAAO,KAAK,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;IACzB,CAAC;IAEO,UAAU,CAAC,MAAc;QAC/B,IAAI,IAAI,GAAG,CAAC,CAAC;QACb,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACvC,IAAI,GAAG,CAAC,IAAI,GAAG,EAAE,GAAG,MAAM,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,UAAU,CAAC;QAC9D,CAAC;QACD,OAAO,IAAI,CAAC;IACd,CAAC;IAEO,gBAAgB,CAAC,QAAgB;QACvC,IAAI,KAAK,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,QAAQ,CAAC,CAAC;QAEtC,IAAI,CAAC,KAAK,EAAE,CAAC;YACX,KAAK,GAAG;gBACN,OAAO,EAAE,IAAI,YAAY,CAAC,IAAI,CAAC,UAAU,CAAC;gBAC1C,MAAM,EAAE,CAAC;gBACT,UAAU,EAAE,IAAI,CAAC,GAAG,EAAE;aACvB,CAAC;YACF,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,QAAQ,EAAE,KAAK,CAAC,CAAC;QACnC,CAAC;QAED,OAAO,KAAK,CAAC;IACf,CAAC;IAEO,WAAW,CAAC,QAAgB,EAAE,MAAc;QAClD,4BAA4B;QAC5B,KAAK,MAAM,CAAC,GAAG,EAAE,KAAK,CAAC,IAAI,I
AAI,CAAC,MAAM,EAAE,CAAC;YACvC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;gBACzC,KAAK,CAAC,CAAC,CAAC,IAAI,IAAI,CAAC,MAAM,CAAC,KAAK,GAAG,IAAI,CAAC,MAAM,CAAC,UAAU,CAAC;YACzD,CAAC;YAED,0BAA0B;YAC1B,MAAM,QAAQ,GAAG,IAAI,CAAC,GAAG,CAAC,GAAG,KAAK,CAAC,CAAC;YACpC,IAAI,QAAQ,GAAG,KAAK,EAAE,CAAC;gBACrB,IAAI,CAAC,MAAM,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC;YAC1B,CAAC;QACH,CAAC;QAED,qCAAqC;QACrC,IAAI,KAAK,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,QAAQ,CAAC,CAAC;QACtC,IAAI,CAAC,KAAK,EAAE,CAAC;YACX,KAAK,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;YAC1C,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,QAAQ,EAAE,KAAK,CAAC,CAAC;QACnC,CAAC;QACD,KAAK,CAAC,MAAM,CAAC,GAAG,GAAG,CAAC;IACtB,CAAC;IAEO,gBAAgB,CAAC,OAAe;QACtC,MAAM,EAAE,GAAG,IAAI,CAAC,MAAM,CAAC,YAAY,CAAC;QAEpC,KAAK,MAAM,CAAC,QAAQ,EAAE,KAAK,CAAC,IAAI,IAAI,CAAC,MAAM,EAAE,CAAC;YAC5C,MAAM,KAAK,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,QAAQ,CAAC,CAAC;YACxC,IAAI,KAAK,EAAE,CAAC;gBACV,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;oBACzC,KAAK,CAAC,OAAO,CAAC,CAAC,CAAC,IAAI,EAAE,GAAG,OAAO,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC;gBAC9C,CAAC;gBACD,KAAK,CAAC,MAAM,EAAE,CAAC;gBACf,KAAK,CAAC,UAAU,GAAG,IAAI,CAAC,GAAG,EAAE,CAAC;YAChC,CAAC;QACH,CAAC;IACH,CAAC;IAEO,WAAW;QACjB,oCAAoC;QACpC,MAAM,OAAO,GAAG,KAAK,CAAC,IAAI,CAAC,IAAI,CAAC,MAAM,CAAC,OAAO,EAAE,CAAC;aAC9C,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,UAAU,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC;QAErD,MAAM,QAAQ,GAAG,OAAO,CAAC,MAAM,GAAG,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,MAAM,CAAC,SAAS,GAAG,GAAG,CAAC,CAAC;QAC1E,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,QAAQ,EAAE,CAAC,EAAE,EAAE,CAAC;YAClC,IAAI,CAAC,MAAM,CAAC,MAAM,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QACpC,CAAC;IACH,CAAC;IAEO,MAAM,CAAC,MAAoB;QACjC,IAAI,MAAM,GAAG,CAAC,CAAC;QACf,IAAI,MAAM,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;QACvB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACvC,IAAI,MAAM,CAAC,CAAC,CAAC,GAAG,MAAM,EAAE,CAAC;gBACvB,MAAM,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC
;gBACnB,MAAM,GAAG,CAAC,CAAC;YACb,CAAC;QACH,CAAC;QACD,OAAO,MAAM,CAAC;IAChB,CAAC;CACF;AAED;;GAEG;AACH,MAAM,UAAU,eAAe,CAAC,MAAiC;IAC/D,OAAO,IAAI,SAAS,CAAC,MAAM,CAAC,CAAC;AAC/B,CAAC"}
|
|
@@ -0,0 +1,333 @@
|
|
|
1
|
+
/**
 * Tabular Q-Learning
 *
 * Classic Q-learning algorithm with:
 * - Epsilon-greedy exploration
 * - State hashing for continuous states
 * - Eligibility traces (optional)
 * - Experience replay
 *
 * Suitable for smaller state spaces or discretized environments.
 * Performance Target: <1ms per update
 */

import type { Trajectory, RLConfig } from '../types.js';

/**
 * Q-Learning configuration
 */
export interface QLearningConfig extends RLConfig {
  algorithm: 'q-learning';
  /** Starting epsilon for epsilon-greedy exploration (probability of a random action). */
  explorationInitial: number;
  /** Floor that epsilon decays toward. */
  explorationFinal: number;
  /** Step count over which epsilon decays linearly from initial toward final. */
  explorationDecay: number;
  /** Maximum Q-table size; least-recently-updated states are pruned beyond this. */
  maxStates: number;
  /** Enable eligibility-trace updates instead of plain one-step Q-learning. */
  useEligibilityTraces: boolean;
  /** Trace decay factor (lambda); applied each step together with gamma. */
  traceDecay: number;
}

/**
 * Default Q-Learning configuration
 */
export const DEFAULT_QLEARNING_CONFIG: QLearningConfig = {
  algorithm: 'q-learning',
  learningRate: 0.1,
  gamma: 0.99,
  // The next five fields are required by RLConfig but are not read by this
  // tabular implementation; they are set to neutral values.
  entropyCoef: 0,
  valueLossCoef: 1,
  maxGradNorm: 1,
  epochs: 1,
  miniBatchSize: 1,
  explorationInitial: 1.0,
  explorationFinal: 0.01,
  explorationDecay: 10000,
  maxStates: 10000,
  useEligibilityTraces: false,
  traceDecay: 0.9,
};

/**
 * Q-table entry
 */
interface QEntry {
  /** Per-action Q-value estimates (length = numActions). */
  qValues: Float32Array;
  /** Number of updates applied to this state. */
  visits: number;
  /** Timestamp (ms) of the last update; used for LRU pruning. */
  lastUpdate: number;
}
|
|
57
|
+
|
|
58
|
+
/**
|
|
59
|
+
* Q-Learning Algorithm Implementation
|
|
60
|
+
*/
|
|
61
|
+
export class QLearning {
|
|
62
|
+
private config: QLearningConfig;
|
|
63
|
+
|
|
64
|
+
// Q-table
|
|
65
|
+
private qTable: Map<string, QEntry> = new Map();
|
|
66
|
+
|
|
67
|
+
// Exploration
|
|
68
|
+
private epsilon: number;
|
|
69
|
+
private stepCount = 0;
|
|
70
|
+
|
|
71
|
+
// Number of actions
|
|
72
|
+
private numActions = 4;
|
|
73
|
+
|
|
74
|
+
// Eligibility traces
|
|
75
|
+
private traces: Map<string, Float32Array> = new Map();
|
|
76
|
+
|
|
77
|
+
// Statistics
|
|
78
|
+
private updateCount = 0;
|
|
79
|
+
private avgTDError = 0;
|
|
80
|
+
|
|
81
|
+
  constructor(config: Partial<QLearningConfig> = {}) {
    // Merge caller overrides over the defaults; epsilon starts at the
    // configured initial exploration rate.
    this.config = { ...DEFAULT_QLEARNING_CONFIG, ...config };
    this.epsilon = this.config.explorationInitial;
  }
|
|
85
|
+
|
|
86
|
+
  /**
   * Update Q-values from trajectory
   *
   * One tabular pass over every step: the TD target is
   * reward + gamma * max_a Q(s', a), except for the last step, which is
   * treated as terminal and targets the bare reward. Q(s, a) moves toward
   * the target by learningRate (or through eligibility traces when enabled).
   * Also decays epsilon, prunes the table over capacity, and records stats.
   *
   * @param trajectory - sequence of steps; the last step is assumed terminal
   * @returns mean absolute TD error over the trajectory's steps
   */
  update(trajectory: Trajectory): { tdError: number } {
    const startTime = performance.now();

    if (trajectory.steps.length === 0) {
      return { tdError: 0 };
    }

    let totalTDError = 0;

    // Reset eligibility traces for new trajectory
    if (this.config.useEligibilityTraces) {
      this.traces.clear();
    }

    for (let i = 0; i < trajectory.steps.length; i++) {
      const step = trajectory.steps[i];
      const stateKey = this.hashState(step.stateBefore);
      const action = this.hashAction(step.action);

      // Get or create Q-entry
      const qEntry = this.getOrCreateEntry(stateKey);

      // Current Q-value
      const currentQ = qEntry.qValues[action];

      // Compute target Q-value
      let targetQ: number;
      if (i === trajectory.steps.length - 1) {
        // Terminal state: no bootstrapping, target is the raw reward
        targetQ = step.reward;
      } else {
        // Bootstrap from the greedy value of the successor state
        const nextStateKey = this.hashState(step.stateAfter);
        const nextEntry = this.getOrCreateEntry(nextStateKey);
        const maxNextQ = Math.max(...nextEntry.qValues);
        targetQ = step.reward + this.config.gamma * maxNextQ;
      }

      // TD error
      const tdError = targetQ - currentQ;
      totalTDError += Math.abs(tdError);

      if (this.config.useEligibilityTraces) {
        // Update eligibility trace
        this.updateTrace(stateKey, action);

        // Update all states with traces
        this.updateWithTraces(tdError);
      } else {
        // Simple Q-learning update
        qEntry.qValues[action] += this.config.learningRate * tdError;
        qEntry.visits++;
        qEntry.lastUpdate = Date.now();
      }
    }

    // Decay exploration: linear in total steps, floored at explorationFinal
    this.stepCount += trajectory.steps.length;
    this.epsilon = Math.max(
      this.config.explorationFinal,
      this.config.explorationInitial - this.stepCount / this.config.explorationDecay
    );

    // Prune Q-table if too large
    if (this.qTable.size > this.config.maxStates) {
      this.pruneQTable();
    }

    this.updateCount++;
    this.avgTDError = totalTDError / trajectory.steps.length;

    const elapsed = performance.now() - startTime;
    if (elapsed > 1) {
      console.warn(`Q-learning update exceeded target: ${elapsed.toFixed(2)}ms > 1ms`);
    }

    return { tdError: this.avgTDError };
  }
|
|
166
|
+
|
|
167
|
+
  /**
   * Get action using epsilon-greedy policy
   *
   * @param state - observation vector, discretized via hashState
   * @param explore - when true, a uniformly random action is taken with probability epsilon
   * @returns an action index in [0, numActions)
   */
  getAction(state: Float32Array, explore: boolean = true): number {
    if (explore && Math.random() < this.epsilon) {
      return Math.floor(Math.random() * this.numActions);
    }

    const stateKey = this.hashState(state);
    const entry = this.qTable.get(stateKey);

    // Unseen state: no Q-values yet, fall back to a random action
    if (!entry) {
      return Math.floor(Math.random() * this.numActions);
    }

    return this.argmax(entry.qValues);
  }
|
|
184
|
+
|
|
185
|
+
  /**
   * Get Q-values for a state
   *
   * @returns a copy of the stored Q-values (zeros for unseen states);
   *          mutating the result does not affect the table
   */
  getQValues(state: Float32Array): Float32Array {
    const stateKey = this.hashState(state);
    const entry = this.qTable.get(stateKey);

    if (!entry) {
      return new Float32Array(this.numActions);
    }

    return new Float32Array(entry.qValues);
  }
|
|
198
|
+
|
|
199
|
+
  /**
   * Get statistics
   *
   * @returns counters useful for monitoring training progress
   */
  getStats(): Record<string, number> {
    return {
      updateCount: this.updateCount,
      qTableSize: this.qTable.size,
      epsilon: this.epsilon,
      avgTDError: this.avgTDError,
      stepCount: this.stepCount,
    };
  }
|
|
211
|
+
|
|
212
|
+
  /**
   * Reset Q-table
   *
   * Clears all learned values and traces, and restores exploration
   * and statistics to their initial state.
   */
  reset(): void {
    this.qTable.clear();
    this.traces.clear();
    this.epsilon = this.config.explorationInitial;
    this.stepCount = 0;
    this.updateCount = 0;
    this.avgTDError = 0;
  }
|
|
223
|
+
|
|
224
|
+
// ==========================================================================
|
|
225
|
+
// Private Methods
|
|
226
|
+
// ==========================================================================
|
|
227
|
+
|
|
228
|
+
  // Discretizes a continuous state vector into a string key: each of the
  // first 8 dimensions is binned into one of 10 buckets. Values are assumed
  // to lie in [-1, 1]; out-of-range values are clamped into the edge bins.
  private hashState(state: Float32Array): string {
    // Discretize state by binning values
    const bins = 10;
    const parts: number[] = [];

    // Use first 8 dimensions for hashing
    for (let i = 0; i < Math.min(8, state.length); i++) {
      const normalized = (state[i] + 1) / 2; // Assume [-1, 1] range
      const bin = Math.floor(Math.max(0, Math.min(bins - 1, normalized * bins)));
      parts.push(bin);
    }

    return parts.join(',');
  }
|
|
242
|
+
|
|
243
|
+
  // Deterministically maps an action string to an index in [0, numActions)
  // via a polynomial rolling hash. Distinct actions may collide by design.
  private hashAction(action: string): number {
    let hash = 0;
    for (let i = 0; i < action.length; i++) {
      hash = (hash * 31 + action.charCodeAt(i)) % this.numActions;
    }
    return hash;
  }
|
|
250
|
+
|
|
251
|
+
  // Returns the Q-table entry for a state key, inserting a zero-initialized
  // entry on first sight.
  private getOrCreateEntry(stateKey: string): QEntry {
    let entry = this.qTable.get(stateKey);

    if (!entry) {
      entry = {
        qValues: new Float32Array(this.numActions),
        visits: 0,
        lastUpdate: Date.now(),
      };
      this.qTable.set(stateKey, entry);
    }

    return entry;
  }
|
|
265
|
+
|
|
266
|
+
  // Decays every existing eligibility trace by gamma * traceDecay, drops
  // traces that have become negligible, then marks the current state-action
  // pair as fully eligible (replacing-traces variant: set to 1, not incremented).
  private updateTrace(stateKey: string, action: number): void {
    // Decay all existing traces
    for (const [key, trace] of this.traces) {
      for (let a = 0; a < this.numActions; a++) {
        trace[a] *= this.config.gamma * this.config.traceDecay;
      }

      // Remove near-zero traces
      const maxTrace = Math.max(...trace);
      if (maxTrace < 0.001) {
        this.traces.delete(key);
      }
    }

    // Set trace for current state-action
    let trace = this.traces.get(stateKey);
    if (!trace) {
      trace = new Float32Array(this.numActions);
      this.traces.set(stateKey, trace);
    }
    trace[action] = 1.0;
  }
|
|
288
|
+
|
|
289
|
+
private updateWithTraces(tdError: number): void {
|
|
290
|
+
const lr = this.config.learningRate;
|
|
291
|
+
|
|
292
|
+
for (const [stateKey, trace] of this.traces) {
|
|
293
|
+
const entry = this.qTable.get(stateKey);
|
|
294
|
+
if (entry) {
|
|
295
|
+
for (let a = 0; a < this.numActions; a++) {
|
|
296
|
+
entry.qValues[a] += lr * tdError * trace[a];
|
|
297
|
+
}
|
|
298
|
+
entry.visits++;
|
|
299
|
+
entry.lastUpdate = Date.now();
|
|
300
|
+
}
|
|
301
|
+
}
|
|
302
|
+
}
|
|
303
|
+
|
|
304
|
+
private pruneQTable(): void {
|
|
305
|
+
// Remove least recently used states
|
|
306
|
+
const entries = Array.from(this.qTable.entries())
|
|
307
|
+
.sort((a, b) => a[1].lastUpdate - b[1].lastUpdate);
|
|
308
|
+
|
|
309
|
+
const toRemove = entries.length - Math.floor(this.config.maxStates * 0.8);
|
|
310
|
+
for (let i = 0; i < toRemove; i++) {
|
|
311
|
+
this.qTable.delete(entries[i][0]);
|
|
312
|
+
}
|
|
313
|
+
}
|
|
314
|
+
|
|
315
|
+
private argmax(values: Float32Array): number {
  // Index of the largest value; ties resolve to the earliest index
  // because the comparison is strictly greater-than.
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) {
      best = i;
    }
  }
  return best;
}
|
|
326
|
+
}
|
|
327
|
+
|
|
328
|
+
/**
|
|
329
|
+
* Factory function
|
|
330
|
+
*/
|
|
331
|
+
export function createQLearning(config?: Partial<QLearningConfig>): QLearning {
|
|
332
|
+
return new QLearning(config);
|
|
333
|
+
}
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* SARSA (State-Action-Reward-State-Action)
|
|
3
|
+
*
|
|
4
|
+
* On-policy TD learning algorithm with:
|
|
5
|
+
* - Epsilon-greedy exploration
|
|
6
|
+
* - State hashing for continuous states
|
|
7
|
+
* - Expected SARSA variant (optional)
|
|
8
|
+
* - Eligibility traces (SARSA-lambda)
|
|
9
|
+
*
|
|
10
|
+
* Performance Target: <1ms per update
|
|
11
|
+
*/
|
|
12
|
+
import type { Trajectory, RLConfig } from '../types.js';
|
|
13
|
+
/**
 * SARSA configuration
 *
 * Extends the shared RLConfig with SARSA-specific settings for
 * epsilon-greedy exploration, Q-table capacity, and the optional
 * Expected-SARSA / eligibility-trace (SARSA-lambda) variants.
 */
export interface SARSAConfig extends RLConfig {
  /** Discriminator tag identifying this config as SARSA. */
  algorithm: 'sarsa';
  /** Starting epsilon for epsilon-greedy exploration. */
  explorationInitial: number;
  /** Floor value that epsilon decays toward. */
  explorationFinal: number;
  /** Decay factor applied to epsilon — presumably multiplicative per step; confirm schedule in sarsa.ts. */
  explorationDecay: number;
  /** Maximum number of discretized states retained in the Q-table. */
  maxStates: number;
  /** When true, use the Expected SARSA update instead of vanilla SARSA. */
  useExpectedSARSA: boolean;
  /** When true, apply eligibility traces (SARSA-lambda). */
  useEligibilityTraces: boolean;
  /** Trace decay parameter (lambda) used with eligibility traces. */
  traceDecay: number;
}
|
|
26
|
+
/**
 * Default SARSA configuration
 *
 * The concrete default values live in sarsa.ts; this declaration only
 * exposes the typed constant to consumers of the package.
 */
export declare const DEFAULT_SARSA_CONFIG: SARSAConfig;
|
|
30
|
+
/**
 * SARSA Algorithm Implementation
 *
 * Type declarations for the on-policy TD learner implemented in
 * sarsa.ts. Behavior described here mirrors the public method names;
 * see the implementation for exact semantics.
 */
export declare class SARSAAlgorithm {
  /** Effective configuration (defaults merged with constructor overrides). */
  private config;
  /** Q-value table keyed by hashed/discretized state. */
  private qTable;
  /** Current epsilon for epsilon-greedy exploration. */
  private epsilon;
  /** Number of environment steps observed. */
  private stepCount;
  /** Size of the discrete action space. */
  private numActions;
  /** Eligibility traces keyed by state (used when traces are enabled). */
  private traces;
  /** Number of learning updates performed. */
  private updateCount;
  /** Running average of the TD error. */
  private avgTDError;
  constructor(config?: Partial<SARSAConfig>);
  /**
   * Update Q-values from trajectory using SARSA
   */
  update(trajectory: Trajectory): {
    tdError: number;
  };
  /**
   * Get action using epsilon-greedy policy
   */
  getAction(state: Float32Array, explore?: boolean): number;
  /**
   * Get action probabilities for a state
   */
  getActionProbabilities(state: Float32Array): Float32Array;
  /**
   * Get Q-values for a state
   */
  getQValues(state: Float32Array): Float32Array;
  /**
   * Get statistics
   */
  getStats(): Record<string, number>;
  /**
   * Reset algorithm state
   */
  reset(): void;
  private hashState;
  private hashAction;
  private getOrCreateEntry;
  private expectedValue;
  private updateTrace;
  private updateWithTraces;
  private pruneQTable;
  private argmax;
}
|
|
78
|
+
/**
 * Factory function
 *
 * @param config - Optional partial configuration — presumably merged over
 *   DEFAULT_SARSA_CONFIG; confirm merge semantics in sarsa.ts.
 * @returns A new SARSAAlgorithm instance.
 */
export declare function createSARSA(config?: Partial<SARSAConfig>): SARSAAlgorithm;
|
|
82
|
+
//# sourceMappingURL=sarsa.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"sarsa.d.ts","sourceRoot":"","sources":["sarsa.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;GAUG;AAEH,OAAO,KAAK,EAAE,UAAU,EAAE,QAAQ,EAAE,MAAM,aAAa,CAAC;AAExD;;GAEG;AACH,MAAM,WAAW,WAAY,SAAQ,QAAQ;IAC3C,SAAS,EAAE,OAAO,CAAC;IACnB,kBAAkB,EAAE,MAAM,CAAC;IAC3B,gBAAgB,EAAE,MAAM,CAAC;IACzB,gBAAgB,EAAE,MAAM,CAAC;IACzB,SAAS,EAAE,MAAM,CAAC;IAClB,gBAAgB,EAAE,OAAO,CAAC;IAC1B,oBAAoB,EAAE,OAAO,CAAC;IAC9B,UAAU,EAAE,MAAM,CAAC;CACpB;AAED;;GAEG;AACH,eAAO,MAAM,oBAAoB,EAAE,WAgBlC,CAAC;AAWF;;GAEG;AACH,qBAAa,cAAc;IACzB,OAAO,CAAC,MAAM,CAAc;IAG5B,OAAO,CAAC,MAAM,CAAsC;IAGpD,OAAO,CAAC,OAAO,CAAS;IACxB,OAAO,CAAC,SAAS,CAAK;IAGtB,OAAO,CAAC,UAAU,CAAK;IAGvB,OAAO,CAAC,MAAM,CAAwC;IAGtD,OAAO,CAAC,WAAW,CAAK;IACxB,OAAO,CAAC,UAAU,CAAK;gBAEX,MAAM,GAAE,OAAO,CAAC,WAAW,CAAM;IAK7C;;OAEG;IACH,MAAM,CAAC,UAAU,EAAE,UAAU,GAAG;QAAE,OAAO,EAAE,MAAM,CAAA;KAAE;IAwFnD;;OAEG;IACH,SAAS,CAAC,KAAK,EAAE,YAAY,EAAE,OAAO,GAAE,OAAc,GAAG,MAAM;IAe/D;;OAEG;IACH,sBAAsB,CAAC,KAAK,EAAE,YAAY,GAAG,YAAY;IA0BzD;;OAEG;IACH,UAAU,CAAC,KAAK,EAAE,YAAY,GAAG,YAAY;IAW7C;;OAEG;IACH,QAAQ,IAAI,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC;IAUlC;;OAEG;IACH,KAAK,IAAI,IAAI;IAab,OAAO,CAAC,SAAS;IAajB,OAAO,CAAC,UAAU;IAQlB,OAAO,CAAC,gBAAgB;IAexB,OAAO,CAAC,aAAa;IAcrB,OAAO,CAAC,WAAW;IAsBnB,OAAO,CAAC,gBAAgB;IAexB,OAAO,CAAC,WAAW;IAUnB,OAAO,CAAC,MAAM;CAWf;AAED;;GAEG;AACH,wBAAgB,WAAW,CAAC,MAAM,CAAC,EAAE,OAAO,CAAC,WAAW,CAAC,GAAG,cAAc,CAEzE"}
|