moflo 4.8.32 → 4.8.34
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/generate-code-map.mjs +955 -955
- package/bin/index-guidance.mjs +905 -905
- package/bin/index-tests.mjs +728 -728
- package/bin/setup-project.mjs +252 -252
- package/package.json +10 -5
- package/src/@claude-flow/cli/dist/src/commands/doctor.js +1339 -1107
- package/src/@claude-flow/cli/dist/src/index.js +2 -18
- package/src/@claude-flow/cli/dist/src/mcp-tools/hooks-tools.js +17 -0
- package/src/@claude-flow/cli/dist/src/memory/memory-initializer.js +4 -7
- package/src/@claude-flow/cli/dist/src/version.js +6 -0
- package/src/@claude-flow/cli/package.json +1 -1
- package/src/@claude-flow/neural/README.md +260 -0
- package/src/@claude-flow/neural/dist/algorithms/a2c.js +361 -0
- package/src/@claude-flow/neural/dist/algorithms/curiosity.js +392 -0
- package/src/@claude-flow/neural/dist/algorithms/decision-transformer.js +415 -0
- package/src/@claude-flow/neural/dist/algorithms/dqn.js +303 -0
- package/src/@claude-flow/neural/dist/algorithms/index.js +74 -0
- package/src/@claude-flow/neural/dist/algorithms/ppo.js +331 -0
- package/src/@claude-flow/neural/dist/algorithms/q-learning.js +259 -0
- package/src/@claude-flow/neural/dist/algorithms/sarsa.js +297 -0
- package/src/@claude-flow/neural/dist/application/index.js +7 -0
- package/src/@claude-flow/neural/dist/application/services/neural-application-service.js +161 -0
- package/src/@claude-flow/neural/dist/domain/entities/pattern.js +134 -0
- package/src/@claude-flow/neural/dist/domain/index.js +8 -0
- package/src/@claude-flow/neural/dist/domain/services/learning-service.js +195 -0
- package/src/@claude-flow/neural/dist/index.js +201 -0
- package/src/@claude-flow/neural/dist/modes/balanced.js +234 -0
- package/src/@claude-flow/neural/dist/modes/base.js +77 -0
- package/src/@claude-flow/neural/dist/modes/batch.js +316 -0
- package/src/@claude-flow/neural/dist/modes/edge.js +310 -0
- package/src/@claude-flow/neural/dist/modes/index.js +13 -0
- package/src/@claude-flow/neural/dist/modes/real-time.js +196 -0
- package/src/@claude-flow/neural/dist/modes/research.js +389 -0
- package/src/@claude-flow/neural/dist/pattern-learner.js +603 -0
- package/src/@claude-flow/neural/dist/reasoning-bank.js +993 -0
- package/src/@claude-flow/neural/dist/reasoningbank-adapter.js +463 -0
- package/src/@claude-flow/neural/dist/sona-integration.js +326 -0
- package/src/@claude-flow/neural/dist/sona-manager.js +695 -0
- package/src/@claude-flow/neural/dist/types.js +11 -0
- package/src/@claude-flow/neural/package.json +26 -0
|
@@ -0,0 +1,259 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Tabular Q-Learning
|
|
3
|
+
*
|
|
4
|
+
* Classic Q-learning algorithm with:
|
|
5
|
+
* - Epsilon-greedy exploration
|
|
6
|
+
* - State hashing for continuous states
|
|
7
|
+
* - Eligibility traces (optional)
|
|
8
|
+
* - Bounded Q-table (least-recently-updated states are pruned past maxStates)
|
|
9
|
+
*
|
|
10
|
+
* Suitable for smaller state spaces or discretized environments.
|
|
11
|
+
* Performance Target: <1ms per update
|
|
12
|
+
*/
|
|
13
|
+
/**
 * Default hyper-parameters for the tabular Q-learning agent.
 *
 * The QLearning constructor shallow-merges caller overrides on top of
 * these values, so any subset of keys may be supplied.
 */
export const DEFAULT_QLEARNING_CONFIG = {
    algorithm: 'q-learning',
    // TD step size (alpha).
    learningRate: 0.1,
    // Discount applied to the bootstrapped next-state value.
    gamma: 0.99,
    // The next five keys are not read by the tabular update in this file;
    // presumably kept for parity with the other algorithm configs (verify).
    entropyCoef: 0,
    valueLossCoef: 1,
    maxGradNorm: 1,
    epochs: 1,
    miniBatchSize: 1,
    // Epsilon-greedy schedule: epsilon drops by 1/explorationDecay per
    // consumed step, floored at explorationFinal.
    explorationInitial: 1,
    explorationFinal: 0.01,
    explorationDecay: 10_000,
    // When the Q-table grows past this, it is pruned to 80% of it.
    maxStates: 10_000,
    // Eligibility traces; each step traces decay by gamma * traceDecay.
    useEligibilityTraces: false,
    traceDecay: 0.9,
};
|
|
32
|
+
/**
 * Tabular Q-learning agent.
 *
 * Learns from whole trajectories. For every step, Q(s, a) is moved toward
 * reward + gamma * max_a' Q(s', a'), except for the final step of the
 * trajectory, which is treated as terminal (target = reward).
 * With `useEligibilityTraces`, each TD error is instead spread over all
 * traced state-action pairs. A trace is set to 1 on visit and decays by
 * gamma * traceDecay on every later step.
 */
export class QLearning {
    /** Effective configuration (DEFAULT_QLEARNING_CONFIG merged with overrides). */
    config;
    /** Q-table: state key -> { qValues: Float32Array, visits, lastUpdate }. */
    qTable = new Map();
    /** Current epsilon for epsilon-greedy action selection. */
    epsilon;
    /** Total trajectory steps consumed so far; drives epsilon decay. */
    stepCount = 0;
    /** Size of the discrete action space. */
    numActions = 4;
    /** Eligibility traces: state key -> Float32Array of per-action traces. */
    traces = new Map();
    /** Number of update() calls that processed a non-empty trajectory. */
    updateCount = 0;
    /** Mean absolute TD error of the most recent non-empty update. */
    avgTDError = 0;

    /**
     * @param {object} [config] - Partial overrides for DEFAULT_QLEARNING_CONFIG.
     */
    constructor(config = {}) {
        this.config = { ...DEFAULT_QLEARNING_CONFIG, ...config };
        this.epsilon = this.config.explorationInitial;
    }

    /**
     * Learn from one trajectory.
     *
     * Epsilon decay, Q-table pruning and statistics are applied once, after all
     * steps are processed. A warning is logged when the call exceeds 1 ms.
     *
     * @param {{steps: Array}} trajectory - Steps carrying stateBefore, action
     *   (string), reward and stateAfter.
     * @returns {{tdError: number}} Mean absolute TD error over the steps (0 if empty).
     */
    update(trajectory) {
        const began = performance.now();
        const { steps } = trajectory;
        const total = steps.length;
        if (total === 0) {
            return { tdError: 0 };
        }
        const { gamma, learningRate, useEligibilityTraces } = this.config;
        // Traces never carry over between trajectories.
        if (useEligibilityTraces) {
            this.traces.clear();
        }
        let absErrorSum = 0;
        for (let idx = 0; idx < total; idx += 1) {
            const step = steps[idx];
            const key = this.hashState(step.stateBefore);
            const act = this.hashAction(step.action);
            const row = this.getOrCreateEntry(key);
            const predicted = row.qValues[act];
            // The last step is terminal: nothing to bootstrap from.
            let target = step.reward;
            if (idx < total - 1) {
                const successor = this.getOrCreateEntry(this.hashState(step.stateAfter));
                target += gamma * Math.max(...successor.qValues);
            }
            const delta = target - predicted;
            absErrorSum += Math.abs(delta);
            if (!useEligibilityTraces) {
                row.qValues[act] += learningRate * delta;
                row.visits++;
                row.lastUpdate = Date.now();
                continue;
            }
            this.updateTrace(key, act);
            this.updateWithTraces(delta);
        }
        this.stepCount += total;
        const { explorationInitial, explorationFinal, explorationDecay, maxStates } = this.config;
        const linearEpsilon = explorationInitial - this.stepCount / explorationDecay;
        this.epsilon = Math.max(explorationFinal, linearEpsilon);
        if (this.qTable.size > maxStates) {
            this.pruneQTable();
        }
        this.updateCount++;
        this.avgTDError = absErrorSum / total;
        const spent = performance.now() - began;
        if (spent > 1) {
            console.warn(`Q-learning update exceeded target: ${spent.toFixed(2)}ms > 1ms`);
        }
        return { tdError: this.avgTDError };
    }

    /**
     * Pick an action index with an epsilon-greedy policy.
     *
     * @param {ArrayLike<number>} state - Raw state vector.
     * @param {boolean} [explore=true] - When false, never takes the random branch.
     * @returns {number} Action index in [0, numActions). Unseen states get a
     *   random action.
     */
    getAction(state, explore = true) {
        const pickRandom = () => Math.floor(Math.random() * this.numActions);
        if (explore && Math.random() < this.epsilon) {
            return pickRandom();
        }
        const row = this.qTable.get(this.hashState(state));
        return row ? this.argmax(row.qValues) : pickRandom();
    }

    /**
     * @param {ArrayLike<number>} state - Raw state vector.
     * @returns {Float32Array} A copy of the Q-values for the state (zeros if unseen).
     */
    getQValues(state) {
        const row = this.qTable.get(this.hashState(state));
        return row ? new Float32Array(row.qValues) : new Float32Array(this.numActions);
    }

    /**
     * @returns {{updateCount: number, qTableSize: number, epsilon: number,
     *   avgTDError: number, stepCount: number}} Snapshot of learner statistics.
     */
    getStats() {
        const { updateCount, qTable, epsilon, avgTDError, stepCount } = this;
        return { updateCount, qTableSize: qTable.size, epsilon, avgTDError, stepCount };
    }

    /** Forget everything learned and restart the exploration schedule. */
    reset() {
        this.qTable.clear();
        this.traces.clear();
        this.stepCount = 0;
        this.updateCount = 0;
        this.avgTDError = 0;
        this.epsilon = this.config.explorationInitial;
    }

    // ==========================================================================
    // Private Methods
    // ==========================================================================

    /** Discretize up to the first 8 state dimensions into 10 bins each, joined by commas. */
    hashState(state) {
        const binCount = 10;
        const dims = Math.min(8, state.length);
        return Array.from({ length: dims }, (_, d) => {
            // Values are assumed to lie in [-1, 1]; map to [0, 1], then clamp to a bin.
            const unit = (state[d] + 1) / 2;
            return Math.floor(Math.max(0, Math.min(binCount - 1, unit * binCount)));
        }).join(',');
    }

    /** Fold an action string into [0, numActions) with a base-31 rolling hash. */
    hashAction(action) {
        let code = 0;
        for (let pos = 0; pos < action.length; pos++) {
            code = (code * 31 + action.charCodeAt(pos)) % this.numActions;
        }
        return code;
    }

    /** Return the Q-table row for a key, inserting a zero-initialized one if absent. */
    getOrCreateEntry(stateKey) {
        const existing = this.qTable.get(stateKey);
        if (existing) {
            return existing;
        }
        const fresh = {
            qValues: new Float32Array(this.numActions),
            visits: 0,
            lastUpdate: Date.now(),
        };
        this.qTable.set(stateKey, fresh);
        return fresh;
    }

    /** Decay all traces (dropping those whose max falls below 0.001), then set the visited pair to 1. */
    updateTrace(stateKey, action) {
        const decay = this.config.gamma * this.config.traceDecay;
        for (const [key, trace] of this.traces) {
            for (let a = 0; a < this.numActions; a++) {
                trace[a] *= decay;
            }
            if (Math.max(...trace) < 0.001) {
                this.traces.delete(key);
            }
        }
        let current = this.traces.get(stateKey);
        if (!current) {
            current = new Float32Array(this.numActions);
            this.traces.set(stateKey, current);
        }
        current[action] = 1.0;
    }

    /** Apply learningRate * tdError * trace to every traced state still in the Q-table. */
    updateWithTraces(tdError) {
        const scale = this.config.learningRate * tdError;
        this.traces.forEach((trace, stateKey) => {
            const row = this.qTable.get(stateKey);
            if (!row) {
                return;
            }
            for (let a = 0; a < this.numActions; a++) {
                row.qValues[a] += scale * trace[a];
            }
            row.visits++;
            row.lastUpdate = Date.now();
        });
    }

    /** Evict least-recently-updated rows until the table holds floor(0.8 * maxStates). */
    pruneQTable() {
        const byAge = [...this.qTable.entries()];
        byAge.sort((x, y) => x[1].lastUpdate - y[1].lastUpdate);
        const excess = byAge.length - Math.floor(this.config.maxStates * 0.8);
        for (let i = 0; i < excess; i++) {
            this.qTable.delete(byAge[i][0]);
        }
    }

    /** Index of the largest value; ties resolve to the lowest index. */
    argmax(values) {
        let best = 0;
        for (let i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }
}
|
|
253
|
+
/**
 * Factory for a fresh Q-learning agent.
 *
 * @param {object} [config] - Optional overrides for DEFAULT_QLEARNING_CONFIG.
 * @returns {QLearning} A new, untrained agent.
 */
export function createQLearning(config) {
    const agent = new QLearning(config);
    return agent;
}
|
|
259
|
+
//# sourceMappingURL=q-learning.js.map
|
|
@@ -0,0 +1,297 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* SARSA (State-Action-Reward-State-Action)
|
|
3
|
+
*
|
|
4
|
+
* On-policy TD learning algorithm with:
|
|
5
|
+
* - Epsilon-greedy exploration
|
|
6
|
+
* - State hashing for continuous states
|
|
7
|
+
* - Expected SARSA variant (optional)
|
|
8
|
+
* - Eligibility traces (SARSA-lambda)
|
|
9
|
+
*
|
|
10
|
+
* Performance Target: <1ms per update
|
|
11
|
+
*/
|
|
12
|
+
/**
 * Default hyper-parameters for the SARSA agent.
 *
 * The SARSAAlgorithm constructor shallow-merges caller overrides on top of
 * these values, so any subset of keys may be supplied.
 */
export const DEFAULT_SARSA_CONFIG = {
    algorithm: 'sarsa',
    // TD step size (alpha).
    learningRate: 0.1,
    // Discount applied to the bootstrapped next-state value.
    gamma: 0.99,
    // The next five keys are not read by the tabular update in this file;
    // presumably kept for parity with the other algorithm configs (verify).
    entropyCoef: 0,
    valueLossCoef: 1,
    maxGradNorm: 1,
    epochs: 1,
    miniBatchSize: 1,
    // Epsilon-greedy schedule: epsilon drops by 1/explorationDecay per
    // consumed step, floored at explorationFinal.
    explorationInitial: 1,
    explorationFinal: 0.01,
    explorationDecay: 10_000,
    // When the Q-table grows past this, it is pruned to 80% of it.
    maxStates: 10_000,
    // Bootstrap from the epsilon-greedy expectation instead of the taken next action.
    useExpectedSARSA: false,
    // SARSA(lambda); each step traces decay by gamma * traceDecay.
    useEligibilityTraces: false,
    traceDecay: 0.9,
};
|
|
32
|
+
/**
 * SARSA (on-policy TD control) with optional Expected-SARSA targets and
 * SARSA(lambda) eligibility traces.
 *
 * For every step, Q(s, a) moves toward reward + gamma * Q(s', a'), where a' is
 * the action actually taken at the next step. With Expected SARSA, Q(s', a')
 * is replaced by its expectation under the current epsilon-greedy policy.
 * The final step of a trajectory is treated as terminal (target = reward).
 * Every step, the terminal one included, goes through the same update path,
 * so with traces enabled the terminal TD error also propagates backward.
 */
export class SARSAAlgorithm {
    /** Effective configuration (DEFAULT_SARSA_CONFIG merged with overrides). */
    config;
    /** Q-table: state key -> { qValues: Float32Array, visits, lastUpdate }. */
    qTable = new Map();
    /** Current epsilon for epsilon-greedy action selection. */
    epsilon;
    /** Total trajectory steps consumed so far; drives epsilon decay. */
    stepCount = 0;
    /** Size of the discrete action space. */
    numActions = 4;
    /** Eligibility traces: state key -> Float32Array of per-action traces. */
    traces = new Map();
    /** Number of update() calls that processed a non-empty trajectory. */
    updateCount = 0;
    /** Mean absolute TD error of the most recent non-empty update. */
    avgTDError = 0;

    /**
     * @param {object} [config] - Partial overrides for DEFAULT_SARSA_CONFIG.
     */
    constructor(config = {}) {
        this.config = { ...DEFAULT_SARSA_CONFIG, ...config };
        this.epsilon = this.config.explorationInitial;
    }

    /**
     * Update Q-values from a trajectory using the SARSA rule.
     *
     * Single-step trajectories are learned from: their only step is terminal.
     * Previously they were skipped entirely. The terminal step also updates
     * visits/lastUpdate and, when traces are on, is applied through them.
     *
     * @param {{steps: Array}} trajectory - Steps carrying stateBefore, action
     *   (string), reward and stateAfter.
     * @returns {{tdError: number}} Mean absolute TD error over the steps (0 if empty).
     */
    update(trajectory) {
        const startTime = performance.now();
        const steps = trajectory.steps;
        if (steps.length === 0) {
            return { tdError: 0 };
        }
        let totalTDError = 0;
        // Traces never carry over between trajectories.
        if (this.config.useEligibilityTraces) {
            this.traces.clear();
        }
        const lastIndex = steps.length - 1;
        for (let i = 0; i < steps.length; i++) {
            const step = steps[i];
            const stateKey = this.hashState(step.stateBefore);
            const action = this.hashAction(step.action);
            const qEntry = this.getOrCreateEntry(stateKey);
            const currentQ = qEntry.qValues[action];
            let targetQ;
            if (i === lastIndex) {
                // Terminal step: nothing to bootstrap from.
                targetQ = step.reward;
            }
            else {
                const nextEntry = this.getOrCreateEntry(this.hashState(step.stateAfter));
                if (this.config.useExpectedSARSA) {
                    // Expected SARSA: expectation under the current epsilon-greedy policy.
                    targetQ = step.reward + this.config.gamma * this.expectedValue(nextEntry.qValues);
                }
                else {
                    // Standard SARSA: the action actually taken at the next step.
                    const nextAction = this.hashAction(steps[i + 1].action);
                    targetQ = step.reward + this.config.gamma * nextEntry.qValues[nextAction];
                }
            }
            const tdError = targetQ - currentQ;
            totalTDError += Math.abs(tdError);
            if (this.config.useEligibilityTraces) {
                this.updateTrace(stateKey, action);
                this.updateWithTraces(tdError);
            }
            else {
                qEntry.qValues[action] += this.config.learningRate * tdError;
                qEntry.visits++;
                qEntry.lastUpdate = Date.now();
            }
        }
        // Linear epsilon decay, floored at explorationFinal.
        this.stepCount += steps.length;
        this.epsilon = Math.max(this.config.explorationFinal, this.config.explorationInitial - this.stepCount / this.config.explorationDecay);
        if (this.qTable.size > this.config.maxStates) {
            this.pruneQTable();
        }
        this.updateCount++;
        this.avgTDError = totalTDError / steps.length;
        const elapsed = performance.now() - startTime;
        if (elapsed > 1) {
            console.warn(`SARSA update exceeded target: ${elapsed.toFixed(2)}ms > 1ms`);
        }
        return { tdError: this.avgTDError };
    }

    /**
     * Pick an action index with an epsilon-greedy policy.
     *
     * @param {ArrayLike<number>} state - Raw state vector.
     * @param {boolean} [explore=true] - When false, never takes the random branch.
     * @returns {number} Action index; unseen states get a random action.
     */
    getAction(state, explore = true) {
        if (explore && Math.random() < this.epsilon) {
            return Math.floor(Math.random() * this.numActions);
        }
        const stateKey = this.hashState(state);
        const entry = this.qTable.get(stateKey);
        if (!entry) {
            return Math.floor(Math.random() * this.numActions);
        }
        return this.argmax(entry.qValues);
    }

    /**
     * Epsilon-greedy action distribution for a state.
     *
     * @param {ArrayLike<number>} state - Raw state vector.
     * @returns {Float32Array} Probabilities per action (uniform for unseen states).
     */
    getActionProbabilities(state) {
        const probs = new Float32Array(this.numActions);
        const stateKey = this.hashState(state);
        const entry = this.qTable.get(stateKey);
        if (!entry) {
            const uniform = 1 / this.numActions;
            for (let a = 0; a < this.numActions; a++) {
                probs[a] = uniform;
            }
            return probs;
        }
        const greedyAction = this.argmax(entry.qValues);
        const exploreProb = this.epsilon / this.numActions;
        for (let a = 0; a < this.numActions; a++) {
            probs[a] = exploreProb;
        }
        probs[greedyAction] += 1 - this.epsilon;
        return probs;
    }

    /**
     * @param {ArrayLike<number>} state - Raw state vector.
     * @returns {Float32Array} A copy of the Q-values for the state (zeros if unseen).
     */
    getQValues(state) {
        const stateKey = this.hashState(state);
        const entry = this.qTable.get(stateKey);
        if (!entry) {
            return new Float32Array(this.numActions);
        }
        return new Float32Array(entry.qValues);
    }

    /**
     * @returns {{updateCount: number, qTableSize: number, epsilon: number,
     *   avgTDError: number, stepCount: number}} Snapshot of learner statistics.
     */
    getStats() {
        return {
            updateCount: this.updateCount,
            qTableSize: this.qTable.size,
            epsilon: this.epsilon,
            avgTDError: this.avgTDError,
            stepCount: this.stepCount,
        };
    }

    /** Forget everything learned and restart the exploration schedule. */
    reset() {
        this.qTable.clear();
        this.traces.clear();
        this.epsilon = this.config.explorationInitial;
        this.stepCount = 0;
        this.updateCount = 0;
        this.avgTDError = 0;
    }

    // ==========================================================================
    // Private Methods
    // ==========================================================================

    /** Discretize up to the first 8 state dimensions into 10 bins each (values assumed in [-1, 1]). */
    hashState(state) {
        const bins = 10;
        const parts = [];
        for (let i = 0; i < Math.min(8, state.length); i++) {
            const normalized = (state[i] + 1) / 2;
            const bin = Math.floor(Math.max(0, Math.min(bins - 1, normalized * bins)));
            parts.push(bin);
        }
        return parts.join(',');
    }

    /** Fold an action string into [0, numActions) with a base-31 rolling hash. */
    hashAction(action) {
        let hash = 0;
        for (let i = 0; i < action.length; i++) {
            hash = (hash * 31 + action.charCodeAt(i)) % this.numActions;
        }
        return hash;
    }

    /** Return the Q-table row for a key, inserting a zero-initialized one if absent. */
    getOrCreateEntry(stateKey) {
        let entry = this.qTable.get(stateKey);
        if (!entry) {
            entry = {
                qValues: new Float32Array(this.numActions),
                visits: 0,
                lastUpdate: Date.now(),
            };
            this.qTable.set(stateKey, entry);
        }
        return entry;
    }

    /** Expected Q-value under the current epsilon-greedy policy. */
    expectedValue(qValues) {
        const greedyAction = this.argmax(qValues);
        const exploreProb = this.epsilon / this.numActions;
        let expected = 0;
        for (let a = 0; a < this.numActions; a++) {
            const prob = exploreProb + (a === greedyAction ? 1 - this.epsilon : 0);
            expected += prob * qValues[a];
        }
        return expected;
    }

    /** Decay all traces (dropping those whose max falls below 0.001), then set the visited pair to 1. */
    updateTrace(stateKey, action) {
        for (const [key, trace] of this.traces) {
            for (let a = 0; a < this.numActions; a++) {
                trace[a] *= this.config.gamma * this.config.traceDecay;
            }
            const maxTrace = Math.max(...trace);
            if (maxTrace < 0.001) {
                this.traces.delete(key);
            }
        }
        // Replacing trace for the visited state-action pair.
        let trace = this.traces.get(stateKey);
        if (!trace) {
            trace = new Float32Array(this.numActions);
            this.traces.set(stateKey, trace);
        }
        trace[action] = 1.0;
    }

    /** Apply learningRate * tdError * trace to every traced state still in the Q-table. */
    updateWithTraces(tdError) {
        const lr = this.config.learningRate;
        for (const [stateKey, trace] of this.traces) {
            const entry = this.qTable.get(stateKey);
            if (entry) {
                for (let a = 0; a < this.numActions; a++) {
                    entry.qValues[a] += lr * tdError * trace[a];
                }
                entry.visits++;
                entry.lastUpdate = Date.now();
            }
        }
    }

    /** Evict least-recently-updated rows until the table holds floor(0.8 * maxStates). */
    pruneQTable() {
        const entries = Array.from(this.qTable.entries())
            .sort((a, b) => a[1].lastUpdate - b[1].lastUpdate);
        const toRemove = entries.length - Math.floor(this.config.maxStates * 0.8);
        for (let i = 0; i < toRemove; i++) {
            this.qTable.delete(entries[i][0]);
        }
    }

    /** Index of the largest value; ties resolve to the lowest index. */
    argmax(values) {
        let maxIdx = 0;
        let maxVal = values[0];
        for (let i = 1; i < values.length; i++) {
            if (values[i] > maxVal) {
                maxVal = values[i];
                maxIdx = i;
            }
        }
        return maxIdx;
    }
}
|
|
291
|
+
/**
 * Factory for a fresh SARSA agent.
 *
 * @param {object} [config] - Optional overrides for DEFAULT_SARSA_CONFIG.
 * @returns {SARSAAlgorithm} A new, untrained agent.
 */
export function createSARSA(config) {
    const agent = new SARSAAlgorithm(config);
    return agent;
}
|
|
297
|
+
//# sourceMappingURL=sarsa.js.map
|