@sparkleideas/neural 3.5.2-patch.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +260 -0
- package/__tests__/README.md +235 -0
- package/__tests__/algorithms.test.ts +582 -0
- package/__tests__/patterns.test.ts +549 -0
- package/__tests__/sona.test.ts +445 -0
- package/docs/SONA_INTEGRATION.md +460 -0
- package/docs/SONA_QUICKSTART.md +168 -0
- package/examples/sona-usage.ts +318 -0
- package/package.json +23 -0
- package/src/algorithms/a2c.d.ts +86 -0
- package/src/algorithms/a2c.d.ts.map +1 -0
- package/src/algorithms/a2c.js +361 -0
- package/src/algorithms/a2c.js.map +1 -0
- package/src/algorithms/a2c.ts +478 -0
- package/src/algorithms/curiosity.d.ts +82 -0
- package/src/algorithms/curiosity.d.ts.map +1 -0
- package/src/algorithms/curiosity.js +392 -0
- package/src/algorithms/curiosity.js.map +1 -0
- package/src/algorithms/curiosity.ts +509 -0
- package/src/algorithms/decision-transformer.d.ts +82 -0
- package/src/algorithms/decision-transformer.d.ts.map +1 -0
- package/src/algorithms/decision-transformer.js +415 -0
- package/src/algorithms/decision-transformer.js.map +1 -0
- package/src/algorithms/decision-transformer.ts +521 -0
- package/src/algorithms/dqn.d.ts +72 -0
- package/src/algorithms/dqn.d.ts.map +1 -0
- package/src/algorithms/dqn.js +303 -0
- package/src/algorithms/dqn.js.map +1 -0
- package/src/algorithms/dqn.ts +382 -0
- package/src/algorithms/index.d.ts +32 -0
- package/src/algorithms/index.d.ts.map +1 -0
- package/src/algorithms/index.js +74 -0
- package/src/algorithms/index.js.map +1 -0
- package/src/algorithms/index.ts +122 -0
- package/src/algorithms/ppo.d.ts +72 -0
- package/src/algorithms/ppo.d.ts.map +1 -0
- package/src/algorithms/ppo.js +331 -0
- package/src/algorithms/ppo.js.map +1 -0
- package/src/algorithms/ppo.ts +429 -0
- package/src/algorithms/q-learning.d.ts +77 -0
- package/src/algorithms/q-learning.d.ts.map +1 -0
- package/src/algorithms/q-learning.js +259 -0
- package/src/algorithms/q-learning.js.map +1 -0
- package/src/algorithms/q-learning.ts +333 -0
- package/src/algorithms/sarsa.d.ts +82 -0
- package/src/algorithms/sarsa.d.ts.map +1 -0
- package/src/algorithms/sarsa.js +297 -0
- package/src/algorithms/sarsa.js.map +1 -0
- package/src/algorithms/sarsa.ts +383 -0
- package/src/algorithms/tmp.json +0 -0
- package/src/application/index.ts +11 -0
- package/src/application/services/neural-application-service.ts +217 -0
- package/src/domain/entities/pattern.ts +169 -0
- package/src/domain/index.ts +18 -0
- package/src/domain/services/learning-service.ts +256 -0
- package/src/index.d.ts +118 -0
- package/src/index.d.ts.map +1 -0
- package/src/index.js +201 -0
- package/src/index.js.map +1 -0
- package/src/index.ts +363 -0
- package/src/modes/balanced.d.ts +60 -0
- package/src/modes/balanced.d.ts.map +1 -0
- package/src/modes/balanced.js +234 -0
- package/src/modes/balanced.js.map +1 -0
- package/src/modes/balanced.ts +299 -0
- package/src/modes/base.ts +163 -0
- package/src/modes/batch.d.ts +82 -0
- package/src/modes/batch.d.ts.map +1 -0
- package/src/modes/batch.js +316 -0
- package/src/modes/batch.js.map +1 -0
- package/src/modes/batch.ts +434 -0
- package/src/modes/edge.d.ts +85 -0
- package/src/modes/edge.d.ts.map +1 -0
- package/src/modes/edge.js +310 -0
- package/src/modes/edge.js.map +1 -0
- package/src/modes/edge.ts +409 -0
- package/src/modes/index.d.ts +55 -0
- package/src/modes/index.d.ts.map +1 -0
- package/src/modes/index.js +83 -0
- package/src/modes/index.js.map +1 -0
- package/src/modes/index.ts +16 -0
- package/src/modes/real-time.d.ts +58 -0
- package/src/modes/real-time.d.ts.map +1 -0
- package/src/modes/real-time.js +196 -0
- package/src/modes/real-time.js.map +1 -0
- package/src/modes/real-time.ts +257 -0
- package/src/modes/research.d.ts +79 -0
- package/src/modes/research.d.ts.map +1 -0
- package/src/modes/research.js +389 -0
- package/src/modes/research.js.map +1 -0
- package/src/modes/research.ts +486 -0
- package/src/modes/tmp.json +0 -0
- package/src/pattern-learner.d.ts +117 -0
- package/src/pattern-learner.d.ts.map +1 -0
- package/src/pattern-learner.js +603 -0
- package/src/pattern-learner.js.map +1 -0
- package/src/pattern-learner.ts +757 -0
- package/src/reasoning-bank.d.ts +259 -0
- package/src/reasoning-bank.d.ts.map +1 -0
- package/src/reasoning-bank.js +993 -0
- package/src/reasoning-bank.js.map +1 -0
- package/src/reasoning-bank.ts +1279 -0
- package/src/reasoningbank-adapter.ts +697 -0
- package/src/sona-integration.d.ts +168 -0
- package/src/sona-integration.d.ts.map +1 -0
- package/src/sona-integration.js +316 -0
- package/src/sona-integration.js.map +1 -0
- package/src/sona-integration.ts +432 -0
- package/src/sona-manager.d.ts +147 -0
- package/src/sona-manager.d.ts.map +1 -0
- package/src/sona-manager.js +695 -0
- package/src/sona-manager.js.map +1 -0
- package/src/sona-manager.ts +835 -0
- package/src/tmp.json +0 -0
- package/src/types.d.ts +431 -0
- package/src/types.d.ts.map +1 -0
- package/src/types.js +11 -0
- package/src/types.js.map +1 -0
- package/src/types.ts +590 -0
- package/tmp.json +0 -0
- package/tsconfig.json +9 -0
- package/vitest.config.ts +19 -0
|
@@ -0,0 +1,303 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Deep Q-Network (DQN)
|
|
3
|
+
*
|
|
4
|
+
* Implements DQN with enhancements:
|
|
5
|
+
* - Experience replay
|
|
6
|
+
* - Target network
|
|
7
|
+
* - Double DQN (optional)
|
|
8
|
+
* - Dueling architecture (optional)
|
|
9
|
+
* - Epsilon-greedy exploration
|
|
10
|
+
*
|
|
11
|
+
* Performance Target: <10ms per update step
|
|
12
|
+
*/
|
|
13
|
+
/**
|
|
14
|
+
* Default DQN configuration
|
|
15
|
+
*/
|
|
16
|
+
export const DEFAULT_DQN_CONFIG = {
|
|
17
|
+
algorithm: 'dqn',
|
|
18
|
+
learningRate: 0.0001,
|
|
19
|
+
gamma: 0.99,
|
|
20
|
+
entropyCoef: 0,
|
|
21
|
+
valueLossCoef: 1,
|
|
22
|
+
maxGradNorm: 10,
|
|
23
|
+
epochs: 1,
|
|
24
|
+
miniBatchSize: 32,
|
|
25
|
+
bufferSize: 10000,
|
|
26
|
+
explorationInitial: 1.0,
|
|
27
|
+
explorationFinal: 0.01,
|
|
28
|
+
explorationDecay: 10000,
|
|
29
|
+
targetUpdateFreq: 100,
|
|
30
|
+
doubleDQN: true,
|
|
31
|
+
duelingNetwork: false,
|
|
32
|
+
};
|
|
33
|
+
/**
|
|
34
|
+
* DQN Algorithm Implementation
|
|
35
|
+
*/
|
|
36
|
+
export class DQNAlgorithm {
|
|
37
|
+
config;
|
|
38
|
+
// Q-network weights
|
|
39
|
+
qWeights;
|
|
40
|
+
targetWeights;
|
|
41
|
+
// Optimizer state
|
|
42
|
+
qMomentum;
|
|
43
|
+
// Replay buffer (circular)
|
|
44
|
+
buffer = [];
|
|
45
|
+
bufferIdx = 0;
|
|
46
|
+
// Exploration
|
|
47
|
+
epsilon;
|
|
48
|
+
stepCount = 0;
|
|
49
|
+
// Number of actions
|
|
50
|
+
numActions = 4;
|
|
51
|
+
inputDim = 768;
|
|
52
|
+
// Statistics
|
|
53
|
+
updateCount = 0;
|
|
54
|
+
avgLoss = 0;
|
|
55
|
+
constructor(config = {}) {
|
|
56
|
+
this.config = { ...DEFAULT_DQN_CONFIG, ...config };
|
|
57
|
+
this.epsilon = this.config.explorationInitial;
|
|
58
|
+
// Initialize Q-network (2 hidden layers)
|
|
59
|
+
this.qWeights = this.initializeNetwork();
|
|
60
|
+
this.targetWeights = this.copyNetwork(this.qWeights);
|
|
61
|
+
this.qMomentum = this.qWeights.map(w => new Float32Array(w.length));
|
|
62
|
+
}
|
|
63
|
+
/**
|
|
64
|
+
* Add experience from trajectory
|
|
65
|
+
*/
|
|
66
|
+
addExperience(trajectory) {
|
|
67
|
+
for (let i = 0; i < trajectory.steps.length; i++) {
|
|
68
|
+
const step = trajectory.steps[i];
|
|
69
|
+
const nextStep = i < trajectory.steps.length - 1
|
|
70
|
+
? trajectory.steps[i + 1]
|
|
71
|
+
: null;
|
|
72
|
+
const experience = {
|
|
73
|
+
state: step.stateBefore,
|
|
74
|
+
action: this.hashAction(step.action),
|
|
75
|
+
reward: step.reward,
|
|
76
|
+
nextState: step.stateAfter,
|
|
77
|
+
done: nextStep === null,
|
|
78
|
+
};
|
|
79
|
+
// Add to circular buffer
|
|
80
|
+
if (this.buffer.length < this.config.bufferSize) {
|
|
81
|
+
this.buffer.push(experience);
|
|
82
|
+
}
|
|
83
|
+
else {
|
|
84
|
+
this.buffer[this.bufferIdx] = experience;
|
|
85
|
+
}
|
|
86
|
+
this.bufferIdx = (this.bufferIdx + 1) % this.config.bufferSize;
|
|
87
|
+
}
|
|
88
|
+
}
|
|
89
|
+
/**
|
|
90
|
+
* Perform DQN update
|
|
91
|
+
* Target: <10ms
|
|
92
|
+
*/
|
|
93
|
+
update() {
|
|
94
|
+
const startTime = performance.now();
|
|
95
|
+
if (this.buffer.length < this.config.miniBatchSize) {
|
|
96
|
+
return { loss: 0, epsilon: this.epsilon };
|
|
97
|
+
}
|
|
98
|
+
// Sample mini-batch
|
|
99
|
+
const batch = this.sampleBatch();
|
|
100
|
+
// Compute TD targets
|
|
101
|
+
let totalLoss = 0;
|
|
102
|
+
const gradients = this.qWeights.map(w => new Float32Array(w.length));
|
|
103
|
+
for (const exp of batch) {
|
|
104
|
+
// Current Q-values
|
|
105
|
+
const qValues = this.forward(exp.state, this.qWeights);
|
|
106
|
+
const currentQ = qValues[exp.action];
|
|
107
|
+
// Target Q-value
|
|
108
|
+
let targetQ;
|
|
109
|
+
if (exp.done) {
|
|
110
|
+
targetQ = exp.reward;
|
|
111
|
+
}
|
|
112
|
+
else {
|
|
113
|
+
if (this.config.doubleDQN) {
|
|
114
|
+
// Double DQN: use online network to select action, target to evaluate
|
|
115
|
+
const nextQOnline = this.forward(exp.nextState, this.qWeights);
|
|
116
|
+
const bestAction = this.argmax(nextQOnline);
|
|
117
|
+
const nextQTarget = this.forward(exp.nextState, this.targetWeights);
|
|
118
|
+
targetQ = exp.reward + this.config.gamma * nextQTarget[bestAction];
|
|
119
|
+
}
|
|
120
|
+
else {
|
|
121
|
+
// Standard DQN
|
|
122
|
+
const nextQ = this.forward(exp.nextState, this.targetWeights);
|
|
123
|
+
targetQ = exp.reward + this.config.gamma * Math.max(...nextQ);
|
|
124
|
+
}
|
|
125
|
+
}
|
|
126
|
+
// TD error
|
|
127
|
+
const tdError = targetQ - currentQ;
|
|
128
|
+
const loss = tdError * tdError;
|
|
129
|
+
totalLoss += loss;
|
|
130
|
+
// Accumulate gradients
|
|
131
|
+
this.accumulateGradients(gradients, exp.state, exp.action, tdError);
|
|
132
|
+
}
|
|
133
|
+
// Apply gradients
|
|
134
|
+
this.applyGradients(gradients, batch.length);
|
|
135
|
+
// Update target network periodically
|
|
136
|
+
this.stepCount++;
|
|
137
|
+
if (this.stepCount % this.config.targetUpdateFreq === 0) {
|
|
138
|
+
this.targetWeights = this.copyNetwork(this.qWeights);
|
|
139
|
+
}
|
|
140
|
+
// Decay exploration
|
|
141
|
+
this.epsilon = Math.max(this.config.explorationFinal, this.config.explorationInitial - this.stepCount / this.config.explorationDecay);
|
|
142
|
+
this.updateCount++;
|
|
143
|
+
this.avgLoss = totalLoss / batch.length;
|
|
144
|
+
const elapsed = performance.now() - startTime;
|
|
145
|
+
if (elapsed > 10) {
|
|
146
|
+
console.warn(`DQN update exceeded target: ${elapsed.toFixed(2)}ms > 10ms`);
|
|
147
|
+
}
|
|
148
|
+
return {
|
|
149
|
+
loss: this.avgLoss,
|
|
150
|
+
epsilon: this.epsilon,
|
|
151
|
+
};
|
|
152
|
+
}
|
|
153
|
+
/**
|
|
154
|
+
* Get action using epsilon-greedy
|
|
155
|
+
*/
|
|
156
|
+
getAction(state, explore = true) {
|
|
157
|
+
if (explore && Math.random() < this.epsilon) {
|
|
158
|
+
return Math.floor(Math.random() * this.numActions);
|
|
159
|
+
}
|
|
160
|
+
const qValues = this.forward(state, this.qWeights);
|
|
161
|
+
return this.argmax(qValues);
|
|
162
|
+
}
|
|
163
|
+
/**
|
|
164
|
+
* Get Q-values for a state
|
|
165
|
+
*/
|
|
166
|
+
getQValues(state) {
|
|
167
|
+
return this.forward(state, this.qWeights);
|
|
168
|
+
}
|
|
169
|
+
/**
|
|
170
|
+
* Get statistics
|
|
171
|
+
*/
|
|
172
|
+
getStats() {
|
|
173
|
+
return {
|
|
174
|
+
updateCount: this.updateCount,
|
|
175
|
+
bufferSize: this.buffer.length,
|
|
176
|
+
epsilon: this.epsilon,
|
|
177
|
+
avgLoss: this.avgLoss,
|
|
178
|
+
stepCount: this.stepCount,
|
|
179
|
+
};
|
|
180
|
+
}
|
|
181
|
+
// ==========================================================================
|
|
182
|
+
// Private Methods
|
|
183
|
+
// ==========================================================================
|
|
184
|
+
initializeNetwork() {
|
|
185
|
+
// Simple 2-layer network: input -> hidden -> output
|
|
186
|
+
const hiddenDim = 64;
|
|
187
|
+
const weights = [];
|
|
188
|
+
// Layer 1: input_dim -> hidden
|
|
189
|
+
const w1 = new Float32Array(this.inputDim * hiddenDim);
|
|
190
|
+
const scale1 = Math.sqrt(2 / this.inputDim);
|
|
191
|
+
for (let i = 0; i < w1.length; i++) {
|
|
192
|
+
w1[i] = (Math.random() - 0.5) * scale1;
|
|
193
|
+
}
|
|
194
|
+
weights.push(w1);
|
|
195
|
+
// Layer 2: hidden -> num_actions
|
|
196
|
+
const w2 = new Float32Array(hiddenDim * this.numActions);
|
|
197
|
+
const scale2 = Math.sqrt(2 / hiddenDim);
|
|
198
|
+
for (let i = 0; i < w2.length; i++) {
|
|
199
|
+
w2[i] = (Math.random() - 0.5) * scale2;
|
|
200
|
+
}
|
|
201
|
+
weights.push(w2);
|
|
202
|
+
return weights;
|
|
203
|
+
}
|
|
204
|
+
copyNetwork(weights) {
|
|
205
|
+
return weights.map(w => new Float32Array(w));
|
|
206
|
+
}
|
|
207
|
+
forward(state, weights) {
|
|
208
|
+
const hiddenDim = 64;
|
|
209
|
+
// Layer 1: ReLU(W1 * x)
|
|
210
|
+
const hidden = new Float32Array(hiddenDim);
|
|
211
|
+
for (let h = 0; h < hiddenDim; h++) {
|
|
212
|
+
let sum = 0;
|
|
213
|
+
for (let i = 0; i < Math.min(state.length, this.inputDim); i++) {
|
|
214
|
+
sum += state[i] * weights[0][i * hiddenDim + h];
|
|
215
|
+
}
|
|
216
|
+
hidden[h] = Math.max(0, sum); // ReLU
|
|
217
|
+
}
|
|
218
|
+
// Layer 2: W2 * hidden (no activation for Q-values)
|
|
219
|
+
const output = new Float32Array(this.numActions);
|
|
220
|
+
for (let a = 0; a < this.numActions; a++) {
|
|
221
|
+
let sum = 0;
|
|
222
|
+
for (let h = 0; h < hiddenDim; h++) {
|
|
223
|
+
sum += hidden[h] * weights[1][h * this.numActions + a];
|
|
224
|
+
}
|
|
225
|
+
output[a] = sum;
|
|
226
|
+
}
|
|
227
|
+
return output;
|
|
228
|
+
}
|
|
229
|
+
accumulateGradients(gradients, state, action, tdError) {
|
|
230
|
+
const hiddenDim = 64;
|
|
231
|
+
// Forward pass to get hidden activations
|
|
232
|
+
const hidden = new Float32Array(hiddenDim);
|
|
233
|
+
for (let h = 0; h < hiddenDim; h++) {
|
|
234
|
+
let sum = 0;
|
|
235
|
+
for (let i = 0; i < Math.min(state.length, this.inputDim); i++) {
|
|
236
|
+
sum += state[i] * this.qWeights[0][i * hiddenDim + h];
|
|
237
|
+
}
|
|
238
|
+
hidden[h] = Math.max(0, sum);
|
|
239
|
+
}
|
|
240
|
+
// Gradient for layer 2 (only for selected action)
|
|
241
|
+
for (let h = 0; h < hiddenDim; h++) {
|
|
242
|
+
gradients[1][h * this.numActions + action] += hidden[h] * tdError;
|
|
243
|
+
}
|
|
244
|
+
// Gradient for layer 1 (backprop through ReLU)
|
|
245
|
+
for (let h = 0; h < hiddenDim; h++) {
|
|
246
|
+
if (hidden[h] > 0) { // ReLU gradient
|
|
247
|
+
const grad = tdError * this.qWeights[1][h * this.numActions + action];
|
|
248
|
+
for (let i = 0; i < Math.min(state.length, this.inputDim); i++) {
|
|
249
|
+
gradients[0][i * hiddenDim + h] += state[i] * grad;
|
|
250
|
+
}
|
|
251
|
+
}
|
|
252
|
+
}
|
|
253
|
+
}
|
|
254
|
+
applyGradients(gradients, batchSize) {
|
|
255
|
+
const lr = this.config.learningRate / batchSize;
|
|
256
|
+
const beta = 0.9;
|
|
257
|
+
for (let layer = 0; layer < gradients.length; layer++) {
|
|
258
|
+
for (let i = 0; i < gradients[layer].length; i++) {
|
|
259
|
+
// Gradient clipping
|
|
260
|
+
const grad = Math.max(Math.min(gradients[layer][i], this.config.maxGradNorm), -this.config.maxGradNorm);
|
|
261
|
+
// Momentum update
|
|
262
|
+
this.qMomentum[layer][i] = beta * this.qMomentum[layer][i] + (1 - beta) * grad;
|
|
263
|
+
this.qWeights[layer][i] += lr * this.qMomentum[layer][i];
|
|
264
|
+
}
|
|
265
|
+
}
|
|
266
|
+
}
|
|
267
|
+
sampleBatch() {
|
|
268
|
+
const batch = [];
|
|
269
|
+
const indices = new Set();
|
|
270
|
+
while (indices.size < this.config.miniBatchSize && indices.size < this.buffer.length) {
|
|
271
|
+
indices.add(Math.floor(Math.random() * this.buffer.length));
|
|
272
|
+
}
|
|
273
|
+
for (const idx of indices) {
|
|
274
|
+
batch.push(this.buffer[idx]);
|
|
275
|
+
}
|
|
276
|
+
return batch;
|
|
277
|
+
}
|
|
278
|
+
hashAction(action) {
|
|
279
|
+
let hash = 0;
|
|
280
|
+
for (let i = 0; i < action.length; i++) {
|
|
281
|
+
hash = (hash * 31 + action.charCodeAt(i)) % this.numActions;
|
|
282
|
+
}
|
|
283
|
+
return hash;
|
|
284
|
+
}
|
|
285
|
+
argmax(values) {
|
|
286
|
+
let maxIdx = 0;
|
|
287
|
+
let maxVal = values[0];
|
|
288
|
+
for (let i = 1; i < values.length; i++) {
|
|
289
|
+
if (values[i] > maxVal) {
|
|
290
|
+
maxVal = values[i];
|
|
291
|
+
maxIdx = i;
|
|
292
|
+
}
|
|
293
|
+
}
|
|
294
|
+
return maxIdx;
|
|
295
|
+
}
|
|
296
|
+
}
|
|
297
|
+
/**
|
|
298
|
+
* Factory function
|
|
299
|
+
*/
|
|
300
|
+
export function createDQN(config) {
|
|
301
|
+
return new DQNAlgorithm(config);
|
|
302
|
+
}
|
|
303
|
+
//# sourceMappingURL=dqn.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"dqn.js","sourceRoot":"","sources":["dqn.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;GAWG;AAQH;;GAEG;AACH,MAAM,CAAC,MAAM,kBAAkB,GAAc;IAC3C,SAAS,EAAE,KAAK;IAChB,YAAY,EAAE,MAAM;IACpB,KAAK,EAAE,IAAI;IACX,WAAW,EAAE,CAAC;IACd,aAAa,EAAE,CAAC;IAChB,WAAW,EAAE,EAAE;IACf,MAAM,EAAE,CAAC;IACT,aAAa,EAAE,EAAE;IACjB,UAAU,EAAE,KAAK;IACjB,kBAAkB,EAAE,GAAG;IACvB,gBAAgB,EAAE,IAAI;IACtB,gBAAgB,EAAE,KAAK;IACvB,gBAAgB,EAAE,GAAG;IACrB,SAAS,EAAE,IAAI;IACf,cAAc,EAAE,KAAK;CACtB,CAAC;AAaF;;GAEG;AACH,MAAM,OAAO,YAAY;IACf,MAAM,CAAY;IAE1B,oBAAoB;IACZ,QAAQ,CAAiB;IACzB,aAAa,CAAiB;IAEtC,kBAAkB;IACV,SAAS,CAAiB;IAElC,2BAA2B;IACnB,MAAM,GAAoB,EAAE,CAAC;IAC7B,SAAS,GAAG,CAAC,CAAC;IAEtB,cAAc;IACN,OAAO,CAAS;IAChB,SAAS,GAAG,CAAC,CAAC;IAEtB,oBAAoB;IACZ,UAAU,GAAG,CAAC,CAAC;IACf,QAAQ,GAAG,GAAG,CAAC;IAEvB,aAAa;IACL,WAAW,GAAG,CAAC,CAAC;IAChB,OAAO,GAAG,CAAC,CAAC;IAEpB,YAAY,SAA6B,EAAE;QACzC,IAAI,CAAC,MAAM,GAAG,EAAE,GAAG,kBAAkB,EAAE,GAAG,MAAM,EAAE,CAAC;QACnD,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC,MAAM,CAAC,kBAAkB,CAAC;QAE9C,yCAAyC;QACzC,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,iBAAiB,EAAE,CAAC;QACzC,IAAI,CAAC,aAAa,GAAG,IAAI,CAAC,WAAW,CAAC,IAAI,CAAC,QAAQ,CAAC,CAAC;QACrD,IAAI,CAAC,SAAS,GAAG,IAAI,CAAC,QAAQ,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,IAAI,YAAY,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC;IACtE,CAAC;IAED;;OAEG;IACH,aAAa,CAAC,UAAsB;QAClC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,CAAC,KAAK,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACjD,MAAM,IAAI,GAAG,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YACjC,MAAM,QAAQ,GAAG,CAAC,GAAG,UAAU,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC;gBAC9C,CAAC,CAAC,UAAU,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAC;gBACzB,CAAC,CAAC,IAAI,CAAC;YAET,MAAM,UAAU,GAAkB;gBAChC,KAAK,EAAE,IAAI,CAAC,WAAW;gBACvB,MAAM,EAAE,IAAI,CAAC,UAAU,CAAC,IAAI,CAAC,MAAM,CAAC;gBACpC,MAAM,EAAE,IAAI,CAAC,MAAM;gBACnB,SAAS,EAAE,IAAI,CAAC,UAAU;gBAC1B,IAAI,EAAE,QAAQ,KAAK,IAAI;aACxB,CAAC;YAEF,yBAAyB;YACzB,IAAI,IAAI,CAAC,MAAM,CAAC,MAAM,GAAG,IAAI,CAAC,MAAM,CAAC,UAAU,EAAE,CAAC;gBAChD,IAAI,CAAC,MAAM,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;YAC/B,CAAC;iBAAM,CAAC;gBACN,IAAI,CAAC,MAAM,CAAC
,IAAI,CAAC,SAAS,CAAC,GAAG,UAAU,CAAC;YAC3C,CAAC;YACD,IAAI,CAAC,SAAS,GAAG,CAAC,IAAI,CAAC,SAAS,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC,UAAU,CAAC;QACjE,CAAC;IACH,CAAC;IAED;;;OAGG;IACH,MAAM;QACJ,MAAM,SAAS,GAAG,WAAW,CAAC,GAAG,EAAE,CAAC;QAEpC,IAAI,IAAI,CAAC,MAAM,CAAC,MAAM,GAAG,IAAI,CAAC,MAAM,CAAC,aAAa,EAAE,CAAC;YACnD,OAAO,EAAE,IAAI,EAAE,CAAC,EAAE,OAAO,EAAE,IAAI,CAAC,OAAO,EAAE,CAAC;QAC5C,CAAC;QAED,oBAAoB;QACpB,MAAM,KAAK,GAAG,IAAI,CAAC,WAAW,EAAE,CAAC;QAEjC,qBAAqB;QACrB,IAAI,SAAS,GAAG,CAAC,CAAC;QAClB,MAAM,SAAS,GAAG,IAAI,CAAC,QAAQ,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,IAAI,YAAY,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC;QAErE,KAAK,MAAM,GAAG,IAAI,KAAK,EAAE,CAAC;YACxB,mBAAmB;YACnB,MAAM,OAAO,GAAG,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,KAAK,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC;YACvD,MAAM,QAAQ,GAAG,OAAO,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC;YAErC,iBAAiB;YACjB,IAAI,OAAe,CAAC;YACpB,IAAI,GAAG,CAAC,IAAI,EAAE,CAAC;gBACb,OAAO,GAAG,GAAG,CAAC,MAAM,CAAC;YACvB,CAAC;iBAAM,CAAC;gBACN,IAAI,IAAI,CAAC,MAAM,CAAC,SAAS,EAAE,CAAC;oBAC1B,sEAAsE;oBACtE,MAAM,WAAW,GAAG,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,SAAS,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC;oBAC/D,MAAM,UAAU,GAAG,IAAI,CAAC,MAAM,CAAC,WAAW,CAAC,CAAC;oBAC5C,MAAM,WAAW,GAAG,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,SAAS,EAAE,IAAI,CAAC,aAAa,CAAC,CAAC;oBACpE,OAAO,GAAG,GAAG,CAAC,MAAM,GAAG,IAAI,CAAC,MAAM,CAAC,KAAK,GAAG,WAAW,CAAC,UAAU,CAAC,CAAC;gBACrE,CAAC;qBAAM,CAAC;oBACN,eAAe;oBACf,MAAM,KAAK,GAAG,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,SAAS,EAAE,IAAI,CAAC,aAAa,CAAC,CAAC;oBAC9D,OAAO,GAAG,GAAG,CAAC,MAAM,GAAG,IAAI,CAAC,MAAM,CAAC,KAAK,GAAG,IAAI,CAAC,GAAG,CAAC,GAAG,KAAK,CAAC,CAAC;gBAChE,CAAC;YACH,CAAC;YAED,WAAW;YACX,MAAM,OAAO,GAAG,OAAO,GAAG,QAAQ,CAAC;YACnC,MAAM,IAAI,GAAG,OAAO,GAAG,OAAO,CAAC;YAC/B,SAAS,IAAI,IAAI,CAAC;YAElB,uBAAuB;YACvB,IAAI,CAAC,mBAAmB,CAAC,SAAS,EAAE,GAAG,CAAC,KAAK,EAAE,GAAG,CAAC,MAAM,EAAE,OAAO,CAAC,CAAC;QACtE,CAAC;QAED,kBAAkB;QAClB,IAAI,CAAC,cAAc,CAAC,SAAS,EAAE,KAAK,CAAC,MAAM,CAAC,CAAC;QAE7C,qCAAqC;QACrC,IAAI,CAAC,SAAS,EAAE,CAAC;QACjB,IAAI,IAAI,CAAC,SAAS,GAAG,IAAI,CAAC,MAAM,CAAC,gBAAgB,KAAK,CAAC,EAAE,CAAC;YACxD,IAA
I,CAAC,aAAa,GAAG,IAAI,CAAC,WAAW,CAAC,IAAI,CAAC,QAAQ,CAAC,CAAC;QACvD,CAAC;QAED,oBAAoB;QACpB,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC,GAAG,CACrB,IAAI,CAAC,MAAM,CAAC,gBAAgB,EAC5B,IAAI,CAAC,MAAM,CAAC,kBAAkB,GAAG,IAAI,CAAC,SAAS,GAAG,IAAI,CAAC,MAAM,CAAC,gBAAgB,CAC/E,CAAC;QAEF,IAAI,CAAC,WAAW,EAAE,CAAC;QACnB,IAAI,CAAC,OAAO,GAAG,SAAS,GAAG,KAAK,CAAC,MAAM,CAAC;QAExC,MAAM,OAAO,GAAG,WAAW,CAAC,GAAG,EAAE,GAAG,SAAS,CAAC;QAC9C,IAAI,OAAO,GAAG,EAAE,EAAE,CAAC;YACjB,OAAO,CAAC,IAAI,CAAC,+BAA+B,OAAO,CAAC,OAAO,CAAC,CAAC,CAAC,WAAW,CAAC,CAAC;QAC7E,CAAC;QAED,OAAO;YACL,IAAI,EAAE,IAAI,CAAC,OAAO;YAClB,OAAO,EAAE,IAAI,CAAC,OAAO;SACtB,CAAC;IACJ,CAAC;IAED;;OAEG;IACH,SAAS,CAAC,KAAmB,EAAE,UAAmB,IAAI;QACpD,IAAI,OAAO,IAAI,IAAI,CAAC,MAAM,EAAE,GAAG,IAAI,CAAC,OAAO,EAAE,CAAC;YAC5C,OAAO,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,MAAM,EAAE,GAAG,IAAI,CAAC,UAAU,CAAC,CAAC;QACrD,CAAC;QAED,MAAM,OAAO,GAAG,IAAI,CAAC,OAAO,CAAC,KAAK,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC;QACnD,OAAO,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,CAAC;IAC9B,CAAC;IAED;;OAEG;IACH,UAAU,CAAC,KAAmB;QAC5B,OAAO,IAAI,CAAC,OAAO,CAAC,KAAK,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC;IAC5C,CAAC;IAED;;OAEG;IACH,QAAQ;QACN,OAAO;YACL,WAAW,EAAE,IAAI,CAAC,WAAW;YAC7B,UAAU,EAAE,IAAI,CAAC,MAAM,CAAC,MAAM;YAC9B,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,SAAS,EAAE,IAAI,CAAC,SAAS;SAC1B,CAAC;IACJ,CAAC;IAED,6EAA6E;IAC7E,kBAAkB;IAClB,6EAA6E;IAErE,iBAAiB;QACvB,oDAAoD;QACpD,MAAM,SAAS,GAAG,EAAE,CAAC;QACrB,MAAM,OAAO,GAAmB,EAAE,CAAC;QAEnC,+BAA+B;QAC/B,MAAM,EAAE,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,QAAQ,GAAG,SAAS,CAAC,CAAC;QACvD,MAAM,MAAM,GAAG,IAAI,CAAC,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,QAAQ,CAAC,CAAC;QAC5C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACnC,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,MAAM,EAAE,GAAG,GAAG,CAAC,GAAG,MAAM,CAAC;QACzC,CAAC;QACD,OAAO,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC;QAEjB,iCAAiC;QACjC,MAAM,EAAE,GAAG,IAAI,YAAY,CAAC,SAAS,GAAG,IAAI,CAAC,UAAU,CAAC,CAAC;QACzD,MAAM,MAAM,GAAG,IAAI,CAAC,IAAI,CAAC,CAAC,GAAG,SAAS,CAAC,CAAC;QACxC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,CAAC,MA
AM,EAAE,CAAC,EAAE,EAAE,CAAC;YACnC,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,MAAM,EAAE,GAAG,GAAG,CAAC,GAAG,MAAM,CAAC;QACzC,CAAC;QACD,OAAO,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC;QAEjB,OAAO,OAAO,CAAC;IACjB,CAAC;IAEO,WAAW,CAAC,OAAuB;QACzC,OAAO,OAAO,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC;IAC/C,CAAC;IAEO,OAAO,CAAC,KAAmB,EAAE,OAAuB;QAC1D,MAAM,SAAS,GAAG,EAAE,CAAC;QAErB,wBAAwB;QACxB,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,SAAS,CAAC,CAAC;QAC3C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,SAAS,EAAE,CAAC,EAAE,EAAE,CAAC;YACnC,IAAI,GAAG,GAAG,CAAC,CAAC;YACZ,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,KAAK,CAAC,MAAM,EAAE,IAAI,CAAC,QAAQ,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;gBAC/D,GAAG,IAAI,KAAK,CAAC,CAAC,CAAC,GAAG,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,SAAS,GAAG,CAAC,CAAC,CAAC;YAClD,CAAC;YACD,MAAM,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC,CAAC,OAAO;QACvC,CAAC;QAED,oDAAoD;QACpD,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;QACjD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;YACzC,IAAI,GAAG,GAAG,CAAC,CAAC;YACZ,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,SAAS,EAAE,CAAC,EAAE,EAAE,CAAC;gBACnC,GAAG,IAAI,MAAM,CAAC,CAAC,CAAC,GAAG,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,UAAU,GAAG,CAAC,CAAC,CAAC;YACzD,CAAC;YACD,MAAM,CAAC,CAAC,CAAC,GAAG,GAAG,CAAC;QAClB,CAAC;QAED,OAAO,MAAM,CAAC;IAChB,CAAC;IAEO,mBAAmB,CACzB,SAAyB,EACzB,KAAmB,EACnB,MAAc,EACd,OAAe;QAEf,MAAM,SAAS,GAAG,EAAE,CAAC;QAErB,yCAAyC;QACzC,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,SAAS,CAAC,CAAC;QAC3C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,SAAS,EAAE,CAAC,EAAE,EAAE,CAAC;YACnC,IAAI,GAAG,GAAG,CAAC,CAAC;YACZ,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,KAAK,CAAC,MAAM,EAAE,IAAI,CAAC,QAAQ,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;gBAC/D,GAAG,IAAI,KAAK,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,SAAS,GAAG,CAAC,CAAC,CAAC;YACxD,CAAC;YACD,MAAM,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;QAC/B,CAAC;QAED,kDAAkD;QAClD,KAAK,IAAI,CAAC,GAAG,CAAC,EA
AE,CAAC,GAAG,SAAS,EAAE,CAAC,EAAE,EAAE,CAAC;YACnC,SAAS,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,UAAU,GAAG,MAAM,CAAC,IAAI,MAAM,CAAC,CAAC,CAAC,GAAG,OAAO,CAAC;QACpE,CAAC;QAED,+CAA+C;QAC/C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,SAAS,EAAE,CAAC,EAAE,EAAE,CAAC;YACnC,IAAI,MAAM,CAAC,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC,gBAAgB;gBACnC,MAAM,IAAI,GAAG,OAAO,GAAG,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,UAAU,GAAG,MAAM,CAAC,CAAC;gBACtE,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,KAAK,CAAC,MAAM,EAAE,IAAI,CAAC,QAAQ,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;oBAC/D,SAAS,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,SAAS,GAAG,CAAC,CAAC,IAAI,KAAK,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC;gBACrD,CAAC;YACH,CAAC;QACH,CAAC;IACH,CAAC;IAEO,cAAc,CAAC,SAAyB,EAAE,SAAiB;QACjE,MAAM,EAAE,GAAG,IAAI,CAAC,MAAM,CAAC,YAAY,GAAG,SAAS,CAAC;QAChD,MAAM,IAAI,GAAG,GAAG,CAAC;QAEjB,KAAK,IAAI,KAAK,GAAG,CAAC,EAAE,KAAK,GAAG,SAAS,CAAC,MAAM,EAAE,KAAK,EAAE,EAAE,CAAC;YACtD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,SAAS,CAAC,KAAK,CAAC,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;gBACjD,oBAAoB;gBACpB,MAAM,IAAI,GAAG,IAAI,CAAC,GAAG,CACnB,IAAI,CAAC,GAAG,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,EAAE,IAAI,CAAC,MAAM,CAAC,WAAW,CAAC,EACtD,CAAC,IAAI,CAAC,MAAM,CAAC,WAAW,CACzB,CAAC;gBAEF,kBAAkB;gBAClB,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,IAAI,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,IAAI,CAAC;gBAC/E,IAAI,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,IAAI,EAAE,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;YAC3D,CAAC;QACH,CAAC;IACH,CAAC;IAEO,WAAW;QACjB,MAAM,KAAK,GAAoB,EAAE,CAAC;QAClC,MAAM,OAAO,GAAG,IAAI,GAAG,EAAU,CAAC;QAElC,OAAO,OAAO,CAAC,IAAI,GAAG,IAAI,CAAC,MAAM,CAAC,aAAa,IAAI,OAAO,CAAC,IAAI,GAAG,IAAI,CAAC,MAAM,CAAC,MAAM,EAAE,CAAC;YACrF,OAAO,CAAC,GAAG,CAAC,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,MAAM,EAAE,GAAG,IAAI,CAAC,MAAM,CAAC,MAAM,CAAC,CAAC,CAAC;QAC9D,CAAC;QAED,KAAK,MAAM,GAAG,IAAI,OAAO,EAAE,CAAC;YAC1B,KAAK,CAAC,IAAI,CAAC,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC,CAAC;QAC/B,CAAC;QAED,OAAO,KAAK,CAAC;IACf,CAAC;IAEO,
UAAU,CAAC,MAAc;QAC/B,IAAI,IAAI,GAAG,CAAC,CAAC;QACb,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACvC,IAAI,GAAG,CAAC,IAAI,GAAG,EAAE,GAAG,MAAM,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,UAAU,CAAC;QAC9D,CAAC;QACD,OAAO,IAAI,CAAC;IACd,CAAC;IAEO,MAAM,CAAC,MAAoB;QACjC,IAAI,MAAM,GAAG,CAAC,CAAC;QACf,IAAI,MAAM,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;QACvB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACvC,IAAI,MAAM,CAAC,CAAC,CAAC,GAAG,MAAM,EAAE,CAAC;gBACvB,MAAM,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;gBACnB,MAAM,GAAG,CAAC,CAAC;YACb,CAAC;QACH,CAAC;QACD,OAAO,MAAM,CAAC;IAChB,CAAC;CACF;AAED;;GAEG;AACH,MAAM,UAAU,SAAS,CAAC,MAA2B;IACnD,OAAO,IAAI,YAAY,CAAC,MAAM,CAAC,CAAC;AAClC,CAAC"}
|
|
@@ -0,0 +1,382 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Deep Q-Network (DQN)
|
|
3
|
+
*
|
|
4
|
+
* Implements DQN with enhancements:
|
|
5
|
+
* - Experience replay
|
|
6
|
+
* - Target network
|
|
7
|
+
* - Double DQN (optional)
|
|
8
|
+
* - Dueling architecture (optional)
|
|
9
|
+
* - Epsilon-greedy exploration
|
|
10
|
+
*
|
|
11
|
+
* Performance Target: <10ms per update step
|
|
12
|
+
*/
|
|
13
|
+
|
|
14
|
+
import type {
|
|
15
|
+
DQNConfig,
|
|
16
|
+
Trajectory,
|
|
17
|
+
TrajectoryStep,
|
|
18
|
+
} from '../types.js';
|
|
19
|
+
|
|
20
|
+
/**
|
|
21
|
+
* Default DQN configuration
|
|
22
|
+
*/
|
|
23
|
+
export const DEFAULT_DQN_CONFIG: DQNConfig = {
|
|
24
|
+
algorithm: 'dqn',
|
|
25
|
+
learningRate: 0.0001,
|
|
26
|
+
gamma: 0.99,
|
|
27
|
+
entropyCoef: 0,
|
|
28
|
+
valueLossCoef: 1,
|
|
29
|
+
maxGradNorm: 10,
|
|
30
|
+
epochs: 1,
|
|
31
|
+
miniBatchSize: 32,
|
|
32
|
+
bufferSize: 10000,
|
|
33
|
+
explorationInitial: 1.0,
|
|
34
|
+
explorationFinal: 0.01,
|
|
35
|
+
explorationDecay: 10000,
|
|
36
|
+
targetUpdateFreq: 100,
|
|
37
|
+
doubleDQN: true,
|
|
38
|
+
duelingNetwork: false,
|
|
39
|
+
};
|
|
40
|
+
|
|
41
|
+
/**
|
|
42
|
+
* Experience for replay buffer
|
|
43
|
+
*/
|
|
44
|
+
interface DQNExperience {
|
|
45
|
+
state: Float32Array;
|
|
46
|
+
action: number;
|
|
47
|
+
reward: number;
|
|
48
|
+
nextState: Float32Array;
|
|
49
|
+
done: boolean;
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
/**
|
|
53
|
+
* DQN Algorithm Implementation
|
|
54
|
+
*/
|
|
55
|
+
export class DQNAlgorithm {
|
|
56
|
+
private config: DQNConfig;
|
|
57
|
+
|
|
58
|
+
// Q-network weights
|
|
59
|
+
private qWeights: Float32Array[];
|
|
60
|
+
private targetWeights: Float32Array[];
|
|
61
|
+
|
|
62
|
+
// Optimizer state
|
|
63
|
+
private qMomentum: Float32Array[];
|
|
64
|
+
|
|
65
|
+
// Replay buffer (circular)
|
|
66
|
+
private buffer: DQNExperience[] = [];
|
|
67
|
+
private bufferIdx = 0;
|
|
68
|
+
|
|
69
|
+
// Exploration
|
|
70
|
+
private epsilon: number;
|
|
71
|
+
private stepCount = 0;
|
|
72
|
+
|
|
73
|
+
// Number of actions
|
|
74
|
+
private numActions = 4;
|
|
75
|
+
private inputDim = 768;
|
|
76
|
+
|
|
77
|
+
// Statistics
|
|
78
|
+
private updateCount = 0;
|
|
79
|
+
private avgLoss = 0;
|
|
80
|
+
|
|
81
|
+
constructor(config: Partial<DQNConfig> = {}) {
|
|
82
|
+
this.config = { ...DEFAULT_DQN_CONFIG, ...config };
|
|
83
|
+
this.epsilon = this.config.explorationInitial;
|
|
84
|
+
|
|
85
|
+
// Initialize Q-network (2 hidden layers)
|
|
86
|
+
this.qWeights = this.initializeNetwork();
|
|
87
|
+
this.targetWeights = this.copyNetwork(this.qWeights);
|
|
88
|
+
this.qMomentum = this.qWeights.map(w => new Float32Array(w.length));
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
/**
|
|
92
|
+
* Add experience from trajectory
|
|
93
|
+
*/
|
|
94
|
+
addExperience(trajectory: Trajectory): void {
|
|
95
|
+
for (let i = 0; i < trajectory.steps.length; i++) {
|
|
96
|
+
const step = trajectory.steps[i];
|
|
97
|
+
const nextStep = i < trajectory.steps.length - 1
|
|
98
|
+
? trajectory.steps[i + 1]
|
|
99
|
+
: null;
|
|
100
|
+
|
|
101
|
+
const experience: DQNExperience = {
|
|
102
|
+
state: step.stateBefore,
|
|
103
|
+
action: this.hashAction(step.action),
|
|
104
|
+
reward: step.reward,
|
|
105
|
+
nextState: step.stateAfter,
|
|
106
|
+
done: nextStep === null,
|
|
107
|
+
};
|
|
108
|
+
|
|
109
|
+
// Add to circular buffer
|
|
110
|
+
if (this.buffer.length < this.config.bufferSize) {
|
|
111
|
+
this.buffer.push(experience);
|
|
112
|
+
} else {
|
|
113
|
+
this.buffer[this.bufferIdx] = experience;
|
|
114
|
+
}
|
|
115
|
+
this.bufferIdx = (this.bufferIdx + 1) % this.config.bufferSize;
|
|
116
|
+
}
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
/**
 * Runs one DQN optimisation step: samples a mini-batch, computes TD
 * targets (optionally via Double DQN), accumulates and applies
 * momentum-SGD gradients, hard-syncs the target network every
 * targetUpdateFreq steps, and linearly decays epsilon.
 * Soft latency budget: 10ms (a warning is logged when exceeded).
 *
 * @returns the batch-mean squared TD error and the current epsilon
 */
update(): { loss: number; epsilon: number } {
  const t0 = performance.now();

  // Not enough stored transitions to form a mini-batch yet — no-op.
  if (this.buffer.length < this.config.miniBatchSize) {
    return { loss: 0, epsilon: this.epsilon };
  }

  const minibatch = this.sampleBatch();
  const gradients = this.qWeights.map(w => new Float32Array(w.length));
  let lossSum = 0;

  for (const sample of minibatch) {
    // Q(s, a) under the online network.
    const predicted = this.forward(sample.state, this.qWeights)[sample.action];

    // Bootstrap target: plain reward for terminal transitions,
    // otherwise reward + discounted next-state value.
    let target: number;
    if (sample.done) {
      target = sample.reward;
    } else if (this.config.doubleDQN) {
      // Double DQN: online net selects the action, target net scores it.
      const chosen = this.argmax(this.forward(sample.nextState, this.qWeights));
      const scored = this.forward(sample.nextState, this.targetWeights);
      target = sample.reward + this.config.gamma * scored[chosen];
    } else {
      // Vanilla DQN: max over the target net's Q-values.
      const nextQ = this.forward(sample.nextState, this.targetWeights);
      target = sample.reward + this.config.gamma * Math.max(...nextQ);
    }

    const tdError = target - predicted;
    lossSum += tdError * tdError;
    this.accumulateGradients(gradients, sample.state, sample.action, tdError);
  }

  this.applyGradients(gradients, minibatch.length);

  // Hard-sync the target network on a fixed step schedule.
  this.stepCount++;
  if (this.stepCount % this.config.targetUpdateFreq === 0) {
    this.targetWeights = this.copyNetwork(this.qWeights);
  }

  // Linear epsilon schedule, floored at explorationFinal.
  this.epsilon = Math.max(
    this.config.explorationFinal,
    this.config.explorationInitial - this.stepCount / this.config.explorationDecay
  );

  this.updateCount++;
  this.avgLoss = lossSum / minibatch.length;

  const elapsed = performance.now() - t0;
  if (elapsed > 10) {
    console.warn(`DQN update exceeded target: ${elapsed.toFixed(2)}ms > 10ms`);
  }

  return {
    loss: this.avgLoss,
    epsilon: this.epsilon,
  };
}
|
|
197
|
+
|
|
198
|
+
/**
 * Epsilon-greedy action selection: when exploring, returns a uniformly
 * random action with probability epsilon; otherwise the greedy argmax
 * of the online network's Q-values.
 */
getAction(state: Float32Array, explore: boolean = true): number {
  // Short-circuit keeps Math.random() unsampled when explore is false.
  const takeRandom = explore && Math.random() < this.epsilon;
  if (takeRandom) {
    return Math.floor(Math.random() * this.numActions);
  }
  return this.argmax(this.forward(state, this.qWeights));
}
|
|
209
|
+
|
|
210
|
+
/**
 * Returns the online network's Q-value estimate for every action.
 */
getQValues(state: Float32Array): Float32Array {
  const online = this.qWeights;
  return this.forward(state, online);
}
|
|
216
|
+
|
|
217
|
+
/**
|
|
218
|
+
* Get statistics
|
|
219
|
+
*/
|
|
220
|
+
getStats(): Record<string, number> {
|
|
221
|
+
return {
|
|
222
|
+
updateCount: this.updateCount,
|
|
223
|
+
bufferSize: this.buffer.length,
|
|
224
|
+
epsilon: this.epsilon,
|
|
225
|
+
avgLoss: this.avgLoss,
|
|
226
|
+
stepCount: this.stepCount,
|
|
227
|
+
};
|
|
228
|
+
}
|
|
229
|
+
|
|
230
|
+
// ==========================================================================
|
|
231
|
+
// Private Methods
|
|
232
|
+
// ==========================================================================
|
|
233
|
+
|
|
234
|
+
/**
 * Allocates the weight matrices of a one-hidden-layer MLP
 * (inputDim -> 64 -> numActions), each filled with uniform noise in
 * [-scale/2, scale/2] where scale = sqrt(2 / fanIn).
 * NOTE(review): the hidden width 64 is duplicated in forward() and
 * accumulateGradients() and must stay in sync across all three.
 */
private initializeNetwork(): Float32Array[] {
  const hiddenDim = 64;

  // Builds one randomly-initialised weight buffer.
  const randomLayer = (size: number, fanIn: number): Float32Array => {
    const layer = new Float32Array(size);
    const scale = Math.sqrt(2 / fanIn);
    for (let i = 0; i < size; i++) {
      layer[i] = (Math.random() - 0.5) * scale;
    }
    return layer;
  };

  return [
    randomLayer(this.inputDim * hiddenDim, this.inputDim), // input -> hidden
    randomLayer(hiddenDim * this.numActions, hiddenDim),   // hidden -> actions
  ];
}
|
|
257
|
+
|
|
258
|
+
/**
 * Deep-copies a list of weight buffers (used to snapshot the online
 * network into the target network).
 */
private copyNetwork(weights: Float32Array[]): Float32Array[] {
  const copy: Float32Array[] = [];
  for (const layer of weights) {
    copy.push(layer.slice());
  }
  return copy;
}
|
|
261
|
+
|
|
262
|
+
/**
 * Forward pass of the MLP: hidden = ReLU(W1 x), q = W2 hidden.
 * Weight layout is unit-major: W1[i * hiddenDim + h], W2[h * numActions + a].
 * State entries beyond inputDim are ignored; missing trailing entries
 * act as zeros.
 */
private forward(state: Float32Array, weights: Float32Array[]): Float32Array {
  const hiddenDim = 64;
  const [w1, w2] = weights;
  const inputLen = Math.min(state.length, this.inputDim);

  // Hidden layer with ReLU activation.
  const hidden = new Float32Array(hiddenDim);
  for (let h = 0; h < hiddenDim; h++) {
    let acc = 0;
    for (let i = 0; i < inputLen; i++) {
      acc += state[i] * w1[i * hiddenDim + h];
    }
    hidden[h] = Math.max(0, acc);
  }

  // Linear output head — raw Q-values, no activation.
  const qValues = new Float32Array(this.numActions);
  for (let a = 0; a < this.numActions; a++) {
    let acc = 0;
    for (let h = 0; h < hiddenDim; h++) {
      acc += hidden[h] * w2[h * this.numActions + a];
    }
    qValues[a] = acc;
  }

  return qValues;
}
|
|
287
|
+
|
|
288
|
+
private accumulateGradients(
|
|
289
|
+
gradients: Float32Array[],
|
|
290
|
+
state: Float32Array,
|
|
291
|
+
action: number,
|
|
292
|
+
tdError: number
|
|
293
|
+
): void {
|
|
294
|
+
const hiddenDim = 64;
|
|
295
|
+
|
|
296
|
+
// Forward pass to get hidden activations
|
|
297
|
+
const hidden = new Float32Array(hiddenDim);
|
|
298
|
+
for (let h = 0; h < hiddenDim; h++) {
|
|
299
|
+
let sum = 0;
|
|
300
|
+
for (let i = 0; i < Math.min(state.length, this.inputDim); i++) {
|
|
301
|
+
sum += state[i] * this.qWeights[0][i * hiddenDim + h];
|
|
302
|
+
}
|
|
303
|
+
hidden[h] = Math.max(0, sum);
|
|
304
|
+
}
|
|
305
|
+
|
|
306
|
+
// Gradient for layer 2 (only for selected action)
|
|
307
|
+
for (let h = 0; h < hiddenDim; h++) {
|
|
308
|
+
gradients[1][h * this.numActions + action] += hidden[h] * tdError;
|
|
309
|
+
}
|
|
310
|
+
|
|
311
|
+
// Gradient for layer 1 (backprop through ReLU)
|
|
312
|
+
for (let h = 0; h < hiddenDim; h++) {
|
|
313
|
+
if (hidden[h] > 0) { // ReLU gradient
|
|
314
|
+
const grad = tdError * this.qWeights[1][h * this.numActions + action];
|
|
315
|
+
for (let i = 0; i < Math.min(state.length, this.inputDim); i++) {
|
|
316
|
+
gradients[0][i * hiddenDim + h] += state[i] * grad;
|
|
317
|
+
}
|
|
318
|
+
}
|
|
319
|
+
}
|
|
320
|
+
}
|
|
321
|
+
|
|
322
|
+
private applyGradients(gradients: Float32Array[], batchSize: number): void {
|
|
323
|
+
const lr = this.config.learningRate / batchSize;
|
|
324
|
+
const beta = 0.9;
|
|
325
|
+
|
|
326
|
+
for (let layer = 0; layer < gradients.length; layer++) {
|
|
327
|
+
for (let i = 0; i < gradients[layer].length; i++) {
|
|
328
|
+
// Gradient clipping
|
|
329
|
+
const grad = Math.max(
|
|
330
|
+
Math.min(gradients[layer][i], this.config.maxGradNorm),
|
|
331
|
+
-this.config.maxGradNorm
|
|
332
|
+
);
|
|
333
|
+
|
|
334
|
+
// Momentum update
|
|
335
|
+
this.qMomentum[layer][i] = beta * this.qMomentum[layer][i] + (1 - beta) * grad;
|
|
336
|
+
this.qWeights[layer][i] += lr * this.qMomentum[layer][i];
|
|
337
|
+
}
|
|
338
|
+
}
|
|
339
|
+
}
|
|
340
|
+
|
|
341
|
+
/**
 * Draws a uniform random mini-batch WITHOUT replacement from the replay
 * buffer. Yields fewer than miniBatchSize items only when the buffer
 * itself is smaller.
 */
private sampleBatch(): DQNExperience[] {
  const picked = new Set<number>();
  const limit = Math.min(this.config.miniBatchSize, this.buffer.length);

  // Rejection-sample distinct indices.
  while (picked.size < limit) {
    picked.add(Math.floor(Math.random() * this.buffer.length));
  }

  // Set iteration preserves insertion order.
  return Array.from(picked, idx => this.buffer[idx]);
}
|
|
355
|
+
|
|
356
|
+
/**
 * Deterministically maps an action string to a discrete index in
 * [0, numActions) via a rolling base-31 hash over UTF-16 code units.
 * (split('') iterates code units, matching charCodeAt indexing.)
 */
private hashAction(action: string): number {
  return action
    .split('')
    .reduce((h, ch) => (h * 31 + ch.charCodeAt(0)) % this.numActions, 0);
}
|
|
363
|
+
|
|
364
|
+
/**
 * Index of the largest value; the first occurrence wins on ties.
 */
private argmax(values: Float32Array): number {
  let best = 0;
  values.forEach((v, i) => {
    // values[best] is always the running maximum.
    if (v > values[best]) {
      best = i;
    }
  });
  return best;
}
|
|
375
|
+
}
|
|
376
|
+
|
|
377
|
+
/**
|
|
378
|
+
* Factory function
|
|
379
|
+
*/
|
|
380
|
+
export function createDQN(config?: Partial<DQNConfig>): DQNAlgorithm {
|
|
381
|
+
return new DQNAlgorithm(config);
|
|
382
|
+
}
|