@arcanea/guardian-evolution 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/algorithms/a2c.d.ts +86 -0
- package/dist/algorithms/a2c.d.ts.map +1 -0
- package/dist/algorithms/a2c.js +361 -0
- package/dist/algorithms/a2c.js.map +1 -0
- package/dist/algorithms/curiosity.d.ts +82 -0
- package/dist/algorithms/curiosity.d.ts.map +1 -0
- package/dist/algorithms/curiosity.js +392 -0
- package/dist/algorithms/curiosity.js.map +1 -0
- package/dist/algorithms/decision-transformer.d.ts +82 -0
- package/dist/algorithms/decision-transformer.d.ts.map +1 -0
- package/dist/algorithms/decision-transformer.js +415 -0
- package/dist/algorithms/decision-transformer.js.map +1 -0
- package/dist/algorithms/dqn.d.ts +72 -0
- package/dist/algorithms/dqn.d.ts.map +1 -0
- package/dist/algorithms/dqn.js +303 -0
- package/dist/algorithms/dqn.js.map +1 -0
- package/dist/algorithms/index.d.ts +32 -0
- package/dist/algorithms/index.d.ts.map +1 -0
- package/dist/algorithms/index.js +74 -0
- package/dist/algorithms/index.js.map +1 -0
- package/dist/algorithms/ppo.d.ts +72 -0
- package/dist/algorithms/ppo.d.ts.map +1 -0
- package/dist/algorithms/ppo.js +331 -0
- package/dist/algorithms/ppo.js.map +1 -0
- package/dist/algorithms/q-learning.d.ts +77 -0
- package/dist/algorithms/q-learning.d.ts.map +1 -0
- package/dist/algorithms/q-learning.js +259 -0
- package/dist/algorithms/q-learning.js.map +1 -0
- package/dist/algorithms/sarsa.d.ts +82 -0
- package/dist/algorithms/sarsa.d.ts.map +1 -0
- package/dist/algorithms/sarsa.js +297 -0
- package/dist/algorithms/sarsa.js.map +1 -0
- package/dist/index.d.ts +118 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +201 -0
- package/dist/index.js.map +1 -0
- package/dist/modes/balanced.d.ts +60 -0
- package/dist/modes/balanced.d.ts.map +1 -0
- package/dist/modes/balanced.js +234 -0
- package/dist/modes/balanced.js.map +1 -0
- package/dist/modes/batch.d.ts +82 -0
- package/dist/modes/batch.d.ts.map +1 -0
- package/dist/modes/batch.js +316 -0
- package/dist/modes/batch.js.map +1 -0
- package/dist/modes/edge.d.ts +85 -0
- package/dist/modes/edge.d.ts.map +1 -0
- package/dist/modes/edge.js +310 -0
- package/dist/modes/edge.js.map +1 -0
- package/dist/modes/index.d.ts +55 -0
- package/dist/modes/index.d.ts.map +1 -0
- package/dist/modes/index.js +83 -0
- package/dist/modes/index.js.map +1 -0
- package/dist/modes/real-time.d.ts +58 -0
- package/dist/modes/real-time.d.ts.map +1 -0
- package/dist/modes/real-time.js +196 -0
- package/dist/modes/real-time.js.map +1 -0
- package/dist/modes/research.d.ts +79 -0
- package/dist/modes/research.d.ts.map +1 -0
- package/dist/modes/research.js +389 -0
- package/dist/modes/research.js.map +1 -0
- package/dist/pattern-learner.d.ts +117 -0
- package/dist/pattern-learner.d.ts.map +1 -0
- package/dist/pattern-learner.js +603 -0
- package/dist/pattern-learner.js.map +1 -0
- package/dist/reasoning-bank.d.ts +259 -0
- package/dist/reasoning-bank.d.ts.map +1 -0
- package/dist/reasoning-bank.js +993 -0
- package/dist/reasoning-bank.js.map +1 -0
- package/dist/reasoningbank-adapter.d.ts +168 -0
- package/dist/reasoningbank-adapter.d.ts.map +1 -0
- package/dist/reasoningbank-adapter.js +463 -0
- package/dist/reasoningbank-adapter.js.map +1 -0
- package/dist/sona-integration.d.ts +168 -0
- package/dist/sona-integration.d.ts.map +1 -0
- package/dist/sona-integration.js +316 -0
- package/dist/sona-integration.js.map +1 -0
- package/dist/sona-manager.d.ts +147 -0
- package/dist/sona-manager.d.ts.map +1 -0
- package/dist/sona-manager.js +695 -0
- package/dist/sona-manager.js.map +1 -0
- package/dist/types.d.ts +431 -0
- package/dist/types.d.ts.map +1 -0
- package/dist/types.js +11 -0
- package/dist/types.js.map +1 -0
- package/package.json +47 -0
|
@@ -0,0 +1,415 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Decision Transformer
|
|
3
|
+
*
|
|
4
|
+
* Implements sequence modeling approach for RL:
|
|
5
|
+
* - Trajectory as sequence: (s, a, R, s, a, R, ...)
|
|
6
|
+
* - Return-conditioned generation
|
|
7
|
+
* - Causal transformer attention
|
|
8
|
+
* - Offline RL from trajectories
|
|
9
|
+
*
|
|
10
|
+
* Performance Target: <10ms per forward pass
|
|
11
|
+
*/
|
|
12
|
+
/**
 * Default Decision Transformer configuration.
 *
 * Fields read by `DecisionTransformer`: learningRate, gamma, maxGradNorm,
 * miniBatchSize, contextLength, numHeads, numLayers, hiddenDim, embeddingDim.
 * `entropyCoef`, `valueLossCoef`, `epochs` and `dropout` are accepted for
 * config-shape parity with the other algorithms but are not used here.
 */
export const DEFAULT_DT_CONFIG = {
    algorithm: 'decision-transformer',
    learningRate: 0.0001,
    gamma: 0.99,
    entropyCoef: 0,
    valueLossCoef: 0,
    maxGradNorm: 1.0,
    epochs: 1,
    miniBatchSize: 64,
    contextLength: 20,
    numHeads: 4,
    numLayers: 2,
    hiddenDim: 64,
    embeddingDim: 32,
    dropout: 0.1,
};
/** Maximum number of trajectories kept for training; the oldest are evicted first. */
const MAX_TRAJECTORY_BUFFER = 1000;
/**
 * Per-timestep return-to-go decrement used by `getAction`, standing in for the
 * (unknown) reward collected at each earlier step during inference.
 */
const INFERENCE_REWARD_DECREMENT = 0.1;
/** Tanh approximation of the GELU activation. */
function gelu(x) {
    return x * 0.5 * (1 + Math.tanh(0.7978845608 * (x + 0.044715 * x * x * x)));
}
/**
 * Decision Transformer (simplified).
 *
 * Every timestep contributes three tokens: return-to-go, state and action. A
 * causal multi-head transformer (no layer norm) runs over them, and the next
 * action is predicted from the final state token. Only the action head is
 * trained (see `updateWeights`). Embeddings and transformer weights keep their
 * random initialisation.
 */
export class DecisionTransformer {
    config;
    // Embedding tables, row-major `rows x embeddingDim`.
    stateEmbed;
    actionEmbed;
    returnEmbed;
    posEmbed;
    // Per layer: [Wq, Wk, Wv, Wo] and [Wup, Wdown].
    attentionWeights;
    ffnWeights;
    // `embeddingDim x numActions` projection from the last state token to action logits.
    actionHead;
    // Completed trajectories available for training.
    trajectoryBuffer = [];
    // Fixed input/output sizes.
    stateDim = 768;
    numActions = 4;
    // Training statistics.
    updateCount = 0;
    avgLoss = 0;
    /**
     * @param config Partial overrides merged over `DEFAULT_DT_CONFIG`.
     */
    constructor(config = {}) {
        this.config = { ...DEFAULT_DT_CONFIG, ...config };
        const { embeddingDim, hiddenDim, contextLength, numLayers } = this.config;
        this.stateEmbed = this.initEmbedding(this.stateDim, embeddingDim);
        this.actionEmbed = this.initEmbedding(this.numActions, embeddingDim);
        this.returnEmbed = this.initEmbedding(1, embeddingDim);
        // One positional row per token: three tokens per timestep.
        this.posEmbed = this.initEmbedding(contextLength * 3, embeddingDim);
        this.attentionWeights = [];
        this.ffnWeights = [];
        for (let l = 0; l < numLayers; l++) {
            this.attentionWeights.push([
                this.initWeight(embeddingDim, hiddenDim), // Q
                this.initWeight(embeddingDim, hiddenDim), // K
                this.initWeight(embeddingDim, hiddenDim), // V
                this.initWeight(hiddenDim, embeddingDim), // O
            ]);
            this.ffnWeights.push([
                this.initWeight(embeddingDim, hiddenDim * 4), // up projection
                this.initWeight(hiddenDim * 4, embeddingDim), // down projection
            ]);
        }
        this.actionHead = this.initWeight(embeddingDim, this.numActions);
    }
    /**
     * Add a trajectory for training.
     *
     * Ignored unless `trajectory.isComplete` is truthy and it has at least one
     * step. The buffer keeps at most `MAX_TRAJECTORY_BUFFER` entries, evicting
     * the oldest first.
     */
    addTrajectory(trajectory) {
        if (!trajectory.isComplete || trajectory.steps.length === 0) {
            return;
        }
        this.trajectoryBuffer.push(trajectory);
        const overflow = this.trajectoryBuffer.length - MAX_TRAJECTORY_BUFFER;
        if (overflow > 0) {
            this.trajectoryBuffer.splice(0, overflow);
        }
    }
    /**
     * Run one pass of action-prediction updates over a random mini-batch.
     *
     * Samples `min(miniBatchSize, bufferSize)` trajectories with replacement.
     * For every timestep t >= 1 it predicts action t from the preceding context
     * (at most `contextLength` timesteps) and applies one action-head update.
     * Logs a warning when the call takes longer than 10 ms.
     *
     * @returns Mean cross-entropy loss and action accuracy over all predictions
     *   (both 0 when nothing was trained).
     */
    train() {
        const startTime = performance.now();
        const buffer = this.trajectoryBuffer;
        if (buffer.length === 0) {
            return { loss: 0, accuracy: 0 };
        }
        const contextLength = this.config.contextLength;
        const batchSize = Math.min(this.config.miniBatchSize, buffer.length);
        let totalLoss = 0;
        let correct = 0;
        let total = 0;
        for (let i = 0; i < batchSize; i++) {
            const trajectory = buffer[Math.floor(Math.random() * buffer.length)];
            const sequence = this.createSequence(trajectory);
            // Sequences shorter than 2 produce no (context, target) pairs.
            for (let t = 1; t < sequence.length; t++) {
                const context = sequence.slice(Math.max(0, t - contextLength), t);
                const targetAction = sequence[t].action;
                // Encode once; the same features drive prediction and the gradient.
                const features = this.encode(context);
                const predicted = this.predictFromFeatures(features);
                totalLoss += -Math.log(predicted[targetAction] + 1e-8);
                if (this.argmax(predicted) === targetAction) {
                    correct++;
                }
                total++;
                this.updateWeights(context, targetAction, predicted, features);
            }
        }
        this.updateCount++;
        this.avgLoss = total > 0 ? totalLoss / total : 0;
        const elapsed = performance.now() - startTime;
        if (elapsed > 10) {
            console.warn(`DT training exceeded target: ${elapsed.toFixed(2)}ms > 10ms`);
        }
        return {
            loss: this.avgLoss,
            accuracy: total > 0 ? correct / total : 0,
        };
    }
    /**
     * Choose the greedy action conditioned on a target return.
     *
     * @param states Observed states, oldest first. The action is predicted for the last one.
     * @param actions Actions taken so far, index-aligned with `states`. Missing
     *   entries default to action 0. The action slot of the final timestep is
     *   not visible to the prediction because attention is causal.
     * @param targetReturn Desired return-to-go at the first timestep.
     * @returns Index of the highest-probability action.
     */
    getAction(states, actions, targetReturn) {
        const sequence = [];
        for (let i = 0; i < states.length; i++) {
            sequence.push({
                // Timestep i assumes INFERENCE_REWARD_DECREMENT was earned at each earlier step.
                returnToGo: targetReturn - INFERENCE_REWARD_DECREMENT * i,
                state: states[i],
                action: actions[i] ?? 0,
                timestep: i,
            });
        }
        return this.argmax(this.forward(sequence));
    }
    /**
     * Action probabilities for the step after `sequence`.
     *
     * Only the most recent `contextLength` entries are used. An empty sequence
     * yields a uniform distribution.
     */
    forward(sequence) {
        return this.predictFromFeatures(this.encode(sequence));
    }
    /**
     * Get statistics
     */
    getStats() {
        return {
            updateCount: this.updateCount,
            bufferSize: this.trajectoryBuffer.length,
            avgLoss: this.avgLoss,
            contextLength: this.config.contextLength,
            numLayers: this.config.numLayers,
        };
    }
    // ==========================================================================
    // Private Methods
    // ==========================================================================
    /**
     * Run the transformer over `sequence` and return the final hidden vector
     * (`embeddingDim` values) of the last state token. An empty sequence returns
     * all zeros.
     */
    encode(sequence) {
        const embedDim = this.config.embeddingDim;
        const steps = Math.min(sequence.length, this.config.contextLength);
        if (steps <= 0) {
            // Nothing to attend over: zero features give uniform action probabilities.
            return new Float32Array(embedDim);
        }
        const numTokens = steps * 3;
        let hidden = new Float32Array(numTokens * embedDim);
        const first = sequence.length - steps;
        for (let t = 0; t < steps; t++) {
            this.embedTimestep(sequence[first + t], hidden, t * 3 * embedDim);
        }
        // Positional embeddings are laid out exactly like the token buffer.
        for (let i = 0; i < hidden.length; i++) {
            hidden[i] += this.posEmbed[i] ?? 0;
        }
        for (let l = 0; l < this.attentionWeights.length; l++) {
            hidden = this.transformerLayer(hidden, numTokens, l);
        }
        // Token order per timestep is [return, state, action]; take the last state token.
        const lastStateIdx = (numTokens - 2) * embedDim;
        return hidden.slice(lastStateIdx, lastStateIdx + embedDim);
    }
    /**
     * Write the return, state and action tokens for one timestep into `hidden`,
     * starting at flat offset `base`.
     */
    embedTimestep(entry, hidden, base) {
        const embedDim = this.config.embeddingDim;
        for (let d = 0; d < embedDim; d++) {
            hidden[base + d] = entry.returnToGo * this.returnEmbed[d];
        }
        // State token: linear projection of the first `stateDim` state features.
        // Accumulate row-by-row (contiguous reads); per-dimension summation order is unchanged.
        const stateLen = Math.min(entry.state.length, this.stateDim);
        const acc = new Float64Array(embedDim);
        for (let s = 0; s < stateLen; s++) {
            const x = entry.state[s];
            const row = s * embedDim;
            for (let d = 0; d < embedDim; d++) {
                acc[d] += x * this.stateEmbed[row + d];
            }
        }
        hidden.set(acc, base + embedDim);
        const actionRow = entry.action * embedDim;
        const actionBase = base + 2 * embedDim;
        for (let d = 0; d < embedDim; d++) {
            hidden[actionBase + d] = this.actionEmbed[actionRow + d];
        }
    }
    /** Project final features through the action head and apply softmax. */
    predictFromFeatures(features) {
        const embedDim = this.config.embeddingDim;
        const logits = new Float32Array(this.numActions);
        for (let a = 0; a < this.numActions; a++) {
            let sum = 0;
            for (let d = 0; d < embedDim; d++) {
                sum += features[d] * this.actionHead[d * this.numActions + a];
            }
            logits[a] = sum;
        }
        return this.softmax(logits);
    }
    /**
     * Random weight matrix (`inputDim x outputDim`, row-major) with entries
     * uniform in ±sqrt(2 / inputDim) / 2.
     */
    initWeight(inputDim, outputDim) {
        const weight = new Float32Array(inputDim * outputDim);
        const scale = Math.sqrt(2 / inputDim);
        for (let i = 0; i < weight.length; i++) {
            weight[i] = (Math.random() - 0.5) * scale;
        }
        return weight;
    }
    /** Embedding tables use the same initialisation as weight matrices. */
    initEmbedding(inputDim, outputDim) {
        return this.initWeight(inputDim, outputDim);
    }
    /**
     * Convert a trajectory into sequence entries with discounted returns-to-go
     * (discount `gamma`). Each entry uses the step's `stateAfter` as its state and
     * `hashAction(step.action)` as its action index.
     */
    createSequence(trajectory) {
        const steps = trajectory.steps;
        const gamma = this.config.gamma;
        const returnsToGo = new Array(steps.length).fill(0);
        let running = 0;
        for (let t = steps.length - 1; t >= 0; t--) {
            running = steps[t].reward + gamma * running;
            returnsToGo[t] = running;
        }
        return steps.map((step, t) => ({
            returnToGo: returnsToGo[t],
            state: step.stateAfter,
            action: this.hashAction(step.action),
            timestep: t,
        }));
    }
    /**
     * Split `hiddenDim` into contiguous per-head [start, end) ranges.
     * `numHeads` is clamped to [1, hiddenDim]. Uneven splits are allowed.
     */
    headRanges() {
        const hiddenDim = this.config.hiddenDim;
        const numHeads = Math.min(hiddenDim, Math.max(1, Math.floor(this.config.numHeads) || 1));
        const ranges = [];
        for (let i = 0; i < numHeads; i++) {
            const start = Math.floor((i * hiddenDim) / numHeads);
            const end = Math.floor(((i + 1) * hiddenDim) / numHeads);
            if (end > start) {
                ranges.push([start, end]);
            }
        }
        return ranges;
    }
    /**
     * One transformer block over `seqLen` tokens of width `embeddingDim`:
     * causal multi-head attention with residual, then a GELU FFN with residual.
     * There is no layer normalisation. Returns a new buffer.
     */
    transformerLayer(hidden, seqLen, layerIdx) {
        const embedDim = this.config.embeddingDim;
        const hiddenDim = this.config.hiddenDim;
        const ffnDim = hiddenDim * 4;
        const [Wq, Wk, Wv, Wo] = this.attentionWeights[layerIdx];
        const [Wup, Wdown] = this.ffnWeights[layerIdx];
        const output = new Float32Array(hidden.length);
        // Q/K/V projections for every token (seqLen x hiddenDim each).
        const Q = new Float32Array(seqLen * hiddenDim);
        const K = new Float32Array(seqLen * hiddenDim);
        const V = new Float32Array(seqLen * hiddenDim);
        for (let pos = 0; pos < seqLen; pos++) {
            const inBase = pos * embedDim;
            const outBase = pos * hiddenDim;
            for (let h = 0; h < hiddenDim; h++) {
                let q = 0;
                let k = 0;
                let v = 0;
                for (let d = 0; d < embedDim; d++) {
                    const x = hidden[inBase + d];
                    const w = d * hiddenDim + h;
                    q += x * Wq[w];
                    k += x * Wk[w];
                    v += x * Wv[w];
                }
                Q[outBase + h] = q;
                K[outBase + h] = k;
                V[outBase + h] = v;
            }
        }
        // Causal multi-head attention: each head attends over its own slice of
        // hiddenDim, scaled by 1/sqrt(head width).
        const heads = this.headRanges();
        const probs = new Float64Array(seqLen);
        const attnOut = new Float64Array(hiddenDim);
        for (let pos = 0; pos < seqLen; pos++) {
            const qBase = pos * hiddenDim;
            attnOut.fill(0);
            for (const [start, end] of heads) {
                const scale = 1 / Math.sqrt(end - start);
                let maxScore = -Infinity;
                for (let j = 0; j <= pos; j++) {
                    const kBase = j * hiddenDim;
                    let dot = 0;
                    for (let h = start; h < end; h++) {
                        dot += Q[qBase + h] * K[kBase + h];
                    }
                    probs[j] = dot * scale;
                    if (probs[j] > maxScore) {
                        maxScore = probs[j];
                    }
                }
                let sumExp = 0;
                for (let j = 0; j <= pos; j++) {
                    probs[j] = Math.exp(probs[j] - maxScore);
                    sumExp += probs[j];
                }
                for (let j = 0; j <= pos; j++) {
                    const p = probs[j] / sumExp;
                    const vBase = j * hiddenDim;
                    for (let h = start; h < end; h++) {
                        attnOut[h] += p * V[vBase + h];
                    }
                }
            }
            // Output projection plus residual.
            const rowBase = pos * embedDim;
            for (let d = 0; d < embedDim; d++) {
                let sum = hidden[rowBase + d];
                for (let h = 0; h < hiddenDim; h++) {
                    sum += attnOut[h] * Wo[h * embedDim + d];
                }
                output[rowBase + d] = sum;
            }
        }
        // Position-wise FFN with residual, applied in place. Each output element
        // reads only its own residual, so overwriting row entries as we go is safe.
        const ffnHidden = new Float32Array(ffnDim);
        for (let pos = 0; pos < seqLen; pos++) {
            const rowBase = pos * embedDim;
            for (let h = 0; h < ffnDim; h++) {
                let sum = 0;
                for (let d = 0; d < embedDim; d++) {
                    sum += output[rowBase + d] * Wup[d * ffnDim + h];
                }
                ffnHidden[h] = gelu(sum);
            }
            for (let d = 0; d < embedDim; d++) {
                let sum = output[rowBase + d];
                for (let h = 0; h < ffnDim; h++) {
                    sum += ffnHidden[h] * Wdown[h * embedDim + d];
                }
                output[rowBase + d] = sum;
            }
        }
        return output;
    }
    /**
     * One SGD step on the action head for softmax cross-entropy.
     *
     * The gradient w.r.t. `actionHead[d][a]` is `features[d] * (predicted[a] - y[a])`.
     * When `config.maxGradNorm > 0`, the step is rescaled so that the gradient's
     * Frobenius norm does not exceed it.
     *
     * @param context Sequence the prediction was made from. It is re-encoded only
     *   when `features` is not supplied.
     * @param targetAction Index of the correct action.
     * @param predicted Action probabilities produced from `features`.
     * @param features Final state-token features (defaults to `encode(context)`).
     */
    updateWeights(context, targetAction, predicted, features = this.encode(context)) {
        const lr = this.config.learningRate;
        const maxGradNorm = this.config.maxGradNorm;
        const embedDim = this.config.embeddingDim;
        const numActions = this.numActions;
        const gradLogits = new Float64Array(numActions);
        let gradSq = 0;
        for (let a = 0; a < numActions; a++) {
            const g = predicted[a] - (a === targetAction ? 1 : 0);
            gradLogits[a] = g;
            gradSq += g * g;
        }
        let featSq = 0;
        for (let d = 0; d < embedDim; d++) {
            featSq += features[d] * features[d];
        }
        // The Frobenius norm of an outer product is the product of the vector norms.
        const gradNorm = Math.sqrt(featSq * gradSq);
        let step = lr;
        if (maxGradNorm > 0 && gradNorm > maxGradNorm) {
            step *= maxGradNorm / gradNorm;
        }
        for (let d = 0; d < embedDim; d++) {
            const scaled = step * features[d];
            const row = d * numActions;
            for (let a = 0; a < numActions; a++) {
                this.actionHead[row + a] -= scaled * gradLogits[a];
            }
        }
    }
    /** Numerically stable softmax. */
    softmax(logits) {
        const max = Math.max(...logits);
        const exps = new Float32Array(logits.length);
        let sum = 0;
        for (let i = 0; i < logits.length; i++) {
            exps[i] = Math.exp(logits[i] - max);
            sum += exps[i];
        }
        for (let i = 0; i < exps.length; i++) {
            exps[i] /= sum;
        }
        return exps;
    }
    /** Index of the largest value (first one on ties). */
    argmax(values) {
        let maxIdx = 0;
        let maxVal = values[0];
        for (let i = 1; i < values.length; i++) {
            if (values[i] > maxVal) {
                maxVal = values[i];
                maxIdx = i;
            }
        }
        return maxIdx;
    }
    /** Map an action string to an index in [0, numActions) via a rolling hash. */
    hashAction(action) {
        let hash = 0;
        for (let i = 0; i < action.length; i++) {
            hash = (hash * 31 + action.charCodeAt(i)) % this.numActions;
        }
        return hash;
    }
}
/**
 * Create a `DecisionTransformer` with the given config overrides.
 */
export function createDecisionTransformer(config) {
    return new DecisionTransformer(config);
}
|
|
415
|
+
//# sourceMappingURL=decision-transformer.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"decision-transformer.js","sourceRoot":"","sources":["../../src/algorithms/decision-transformer.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;GAUG;AAQH;;GAEG;AACH,MAAM,CAAC,MAAM,iBAAiB,GAA8B;IAC1D,SAAS,EAAE,sBAAsB;IACjC,YAAY,EAAE,MAAM;IACpB,KAAK,EAAE,IAAI;IACX,WAAW,EAAE,CAAC;IACd,aAAa,EAAE,CAAC;IAChB,WAAW,EAAE,GAAG;IAChB,MAAM,EAAE,CAAC;IACT,aAAa,EAAE,EAAE;IACjB,aAAa,EAAE,EAAE;IACjB,QAAQ,EAAE,CAAC;IACX,SAAS,EAAE,CAAC;IACZ,SAAS,EAAE,EAAE;IACb,YAAY,EAAE,EAAE;IAChB,OAAO,EAAE,GAAG;CACb,CAAC;AAYF;;GAEG;AACH,MAAM,OAAO,mBAAmB;IACtB,MAAM,CAA4B;IAE1C,aAAa;IACL,UAAU,CAAe;IACzB,WAAW,CAAe;IAC1B,WAAW,CAAe;IAC1B,QAAQ,CAAe;IAE/B,kCAAkC;IAC1B,gBAAgB,CAAmB;IACnC,UAAU,CAAmB;IAErC,cAAc;IACN,UAAU,CAAe;IAEjC,kBAAkB;IACV,gBAAgB,GAAiB,EAAE,CAAC;IAE5C,aAAa;IACL,QAAQ,GAAG,GAAG,CAAC;IACf,UAAU,GAAG,CAAC,CAAC;IAEvB,aAAa;IACL,WAAW,GAAG,CAAC,CAAC;IAChB,OAAO,GAAG,CAAC,CAAC;IAEpB,YAAY,SAA6C,EAAE;QACzD,IAAI,CAAC,MAAM,GAAG,EAAE,GAAG,iBAAiB,EAAE,GAAG,MAAM,EAAE,CAAC;QAElD,wBAAwB;QACxB,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,aAAa,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,CAAC,MAAM,CAAC,YAAY,CAAC,CAAC;QAC9E,IAAI,CAAC,WAAW,GAAG,IAAI,CAAC,aAAa,CAAC,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,MAAM,CAAC,YAAY,CAAC,CAAC;QACjF,IAAI,CAAC,WAAW,GAAG,IAAI,CAAC,aAAa,CAAC,CAAC,EAAE,IAAI,CAAC,MAAM,CAAC,YAAY,CAAC,CAAC;QACnE,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,aAAa,CAAC,IAAI,CAAC,MAAM,CAAC,aAAa,GAAG,CAAC,EAAE,IAAI,CAAC,MAAM,CAAC,YAAY,CAAC,CAAC;QAE5F,gCAAgC;QAChC,IAAI,CAAC,gBAAgB,GAAG,EAAE,CAAC;QAC3B,IAAI,CAAC,UAAU,GAAG,EAAE,CAAC;QAErB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC,SAAS,EAAE,CAAC,EAAE,EAAE,CAAC;YAC/C,oCAAoC;YACpC,IAAI,CAAC,gBAAgB,CAAC,IAAI,CAAC;gBACzB,IAAI,CAAC,UAAU,CAAC,IAAI,CAAC,MAAM,CAAC,YAAY,EAAE,IAAI,CAAC,MAAM,CAAC,SAAS,CAAC,EAAE,IAAI;gBACtE,IAAI,CAAC,UAAU,CAAC,IAAI,CAAC,MAAM,CAAC,YAAY,EAAE,IAAI,CAAC,MAAM,CAAC,SAAS,CAAC,EAAE,IAAI;gBACtE,IAAI,CAAC,UAAU,CAAC,IAAI,CAAC,MAAM,CAAC,YAAY,EAAE,IAAI,CAAC,MAAM,CAAC,SAAS,CAAC,EAAE,IAAI;gBACtE,IAAI,CAAC,UAAU,CAAC,IAAI,CAAC,MAAM,CAAC,SAAS,EAAE,IAAI,CAAC,MAAM,CAAC,YAAY,CAAC,E
AAE,IAAI;aACvE,CAAC,CAAC;YAEH,+BAA+B;YAC/B,IAAI,CAAC,UAAU,CAAC,IAAI,CAAC;gBACnB,IAAI,CAAC,UAAU,CAAC,IAAI,CAAC,MAAM,CAAC,YAAY,EAAE,IAAI,CAAC,MAAM,CAAC,SAAS,GAAG,CAAC,CAAC;gBACpE,IAAI,CAAC,UAAU,CAAC,IAAI,CAAC,MAAM,CAAC,SAAS,GAAG,CAAC,EAAE,IAAI,CAAC,MAAM,CAAC,YAAY,CAAC;aACrE,CAAC,CAAC;QACL,CAAC;QAED,yBAAyB;QACzB,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,UAAU,CAAC,IAAI,CAAC,MAAM,CAAC,YAAY,EAAE,IAAI,CAAC,UAAU,CAAC,CAAC;IAC/E,CAAC;IAED;;OAEG;IACH,aAAa,CAAC,UAAsB;QAClC,IAAI,UAAU,CAAC,UAAU,IAAI,UAAU,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC;YACzD,IAAI,CAAC,gBAAgB,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;YAEvC,sBAAsB;YACtB,IAAI,IAAI,CAAC,gBAAgB,CAAC,MAAM,GAAG,IAAI,EAAE,CAAC;gBACxC,IAAI,CAAC,gBAAgB,GAAG,IAAI,CAAC,gBAAgB,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,CAAC;YAC7D,CAAC;QACH,CAAC;IACH,CAAC;IAED;;;OAGG;IACH,KAAK;QACH,MAAM,SAAS,GAAG,WAAW,CAAC,GAAG,EAAE,CAAC;QAEpC,IAAI,IAAI,CAAC,gBAAgB,CAAC,MAAM,KAAK,CAAC,EAAE,CAAC;YACvC,OAAO,EAAE,IAAI,EAAE,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,CAAC;QAClC,CAAC;QAED,oCAAoC;QACpC,MAAM,SAAS,GAAG,IAAI,CAAC,GAAG,CAAC,IAAI,CAAC,MAAM,CAAC,aAAa,EAAE,IAAI,CAAC,gBAAgB,CAAC,MAAM,CAAC,CAAC;QACpF,MAAM,KAAK,GAAiB,EAAE,CAAC;QAE/B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,SAAS,EAAE,CAAC,EAAE,EAAE,CAAC;YACnC,MAAM,GAAG,GAAG,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,MAAM,EAAE,GAAG,IAAI,CAAC,gBAAgB,CAAC,MAAM,CAAC,CAAC;YACrE,KAAK,CAAC,IAAI,CAAC,IAAI,CAAC,gBAAgB,CAAC,GAAG,CAAC,CAAC,CAAC;QACzC,CAAC;QAED,IAAI,SAAS,GAAG,CAAC,CAAC;QAClB,IAAI,OAAO,GAAG,CAAC,CAAC;QAChB,IAAI,KAAK,GAAG,CAAC,CAAC;QAEd,KAAK,MAAM,UAAU,IAAI,KAAK,EAAE,CAAC;YAC/B,kCAAkC;YAClC,MAAM,QAAQ,GAAG,IAAI,CAAC,cAAc,CAAC,UAAU,CAAC,CAAC;YAEjD,IAAI,QAAQ,CAAC,MAAM,GAAG,CAAC;gBAAE,SAAS;YAElC,gCAAgC;YAChC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,QAAQ,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;gBACzC,+BAA+B;gBAC/B,MAAM,OAAO,GAAG,QAAQ,CAAC,KAAK,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC,aAAa,CAAC,EAAE,CAAC,CAAC,CAAC;gBAC9E,MAAM,MAAM,GAAG,QAAQ,CAAC,CAAC,CAAC,CAAC;gBAE3B,iBAAiB;gBACjB,MAAM,SAAS,GAAG,IAAI,CAAC,OAAO,CAAC,OAAO,CAAC,CAAC;gBACxC,MAAM,eAAe,GAAG
,IAAI,CAAC,MAAM,CAAC,SAAS,CAAC,CAAC;gBAE/C,qBAAqB;gBACrB,MAAM,IAAI,GAAG,CAAC,IAAI,CAAC,GAAG,CAAC,SAAS,CAAC,MAAM,CAAC,MAAM,CAAC,GAAG,IAAI,CAAC,CAAC;gBACxD,SAAS,IAAI,IAAI,CAAC;gBAElB,IAAI,eAAe,KAAK,MAAM,CAAC,MAAM,EAAE,CAAC;oBACtC,OAAO,EAAE,CAAC;gBACZ,CAAC;gBACD,KAAK,EAAE,CAAC;gBAER,+BAA+B;gBAC/B,IAAI,CAAC,aAAa,CAAC,OAAO,EAAE,MAAM,CAAC,MAAM,EAAE,SAAS,CAAC,CAAC;YACxD,CAAC;QACH,CAAC;QAED,IAAI,CAAC,WAAW,EAAE,CAAC;QACnB,IAAI,CAAC,OAAO,GAAG,KAAK,GAAG,CAAC,CAAC,CAAC,CAAC,SAAS,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;QAEjD,MAAM,OAAO,GAAG,WAAW,CAAC,GAAG,EAAE,GAAG,SAAS,CAAC;QAC9C,IAAI,OAAO,GAAG,EAAE,EAAE,CAAC;YACjB,OAAO,CAAC,IAAI,CAAC,gCAAgC,OAAO,CAAC,OAAO,CAAC,CAAC,CAAC,WAAW,CAAC,CAAC;QAC9E,CAAC;QAED,OAAO;YACL,IAAI,EAAE,IAAI,CAAC,OAAO;YAClB,QAAQ,EAAE,KAAK,GAAG,CAAC,CAAC,CAAC,CAAC,OAAO,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC;SAC1C,CAAC;IACJ,CAAC;IAED;;OAEG;IACH,SAAS,CACP,MAAsB,EACtB,OAAiB,EACjB,YAAoB;QAEpB,iBAAiB;QACjB,MAAM,QAAQ,GAAoB,EAAE,CAAC;QACrC,IAAI,UAAU,GAAG,YAAY,CAAC;QAE9B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACvC,QAAQ,CAAC,IAAI,CAAC;gBACZ,UAAU;gBACV,KAAK,EAAE,MAAM,CAAC,CAAC,CAAC;gBAChB,MAAM,EAAE,OAAO,CAAC,CAAC,CAAC,IAAI,CAAC;gBACvB,QAAQ,EAAE,CAAC;aACZ,CAAC,CAAC;YAEH,4CAA4C;YAC5C,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC;gBACV,UAAU,IAAI,GAAG,CAAC,CAAC,yCAAyC;YAC9D,CAAC;QACH,CAAC;QAED,eAAe;QACf,MAAM,MAAM,GAAG,IAAI,CAAC,OAAO,CAAC,QAAQ,CAAC,CAAC;QACtC,OAAO,IAAI,CAAC,MAAM,CAAC,MAAM,CAAC,CAAC;IAC7B,CAAC;IAED;;OAEG;IACH,OAAO,CAAC,QAAyB;QAC/B,0BAA0B;QAC1B,MAAM,MAAM,GAAG,IAAI,CAAC,GAAG,CAAC,QAAQ,CAAC,MAAM,EAAE,IAAI,CAAC,MAAM,CAAC,aAAa,CAAC,CAAC;QACpE,MAAM,QAAQ,GAAG,IAAI,CAAC,MAAM,CAAC,YAAY,CAAC;QAE1C,8DAA8D;QAC9D,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,MAAM,GAAG,CAAC,GAAG,QAAQ,CAAC,CAAC;QAEvD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YAChC,MAAM,KAAK,GAAG,QAAQ,CAAC,QAAQ,CAAC,MAAM,GAAG,MAAM,GAAG,CAAC,CAAC,CAAC;YACrD,MAAM,OAAO,GAAG,CAAC,GAAG,CAAC,GAAG,QAAQ,CAAC;YAEjC,eAAe;YACf,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,QAAQ,EAAE,CAAC,EAAE,EAAE,CAAC;g
BAClC,MAAM,CAAC,OAAO,GAAG,CAAC,CAAC,GAAG,KAAK,CAAC,UAAU,GAAG,IAAI,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC;YAC/D,CAAC;YAED,cAAc;YACd,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,QAAQ,EAAE,CAAC,EAAE,EAAE,CAAC;gBAClC,IAAI,QAAQ,GAAG,CAAC,CAAC;gBACjB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,KAAK,CAAC,KAAK,CAAC,MAAM,EAAE,IAAI,CAAC,QAAQ,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;oBACrE,QAAQ,IAAI,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,UAAU,CAAC,CAAC,GAAG,QAAQ,GAAG,CAAC,CAAC,CAAC;gBACjE,CAAC;gBACD,MAAM,CAAC,OAAO,GAAG,QAAQ,GAAG,CAAC,CAAC,GAAG,QAAQ,CAAC;YAC5C,CAAC;YAED,eAAe;YACf,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,QAAQ,EAAE,CAAC,EAAE,EAAE,CAAC;gBAClC,MAAM,CAAC,OAAO,GAAG,CAAC,GAAG,QAAQ,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC,MAAM,GAAG,QAAQ,GAAG,CAAC,CAAC,CAAC;YACrF,CAAC;YAED,2BAA2B;YAC3B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,QAAQ,EAAE,CAAC,EAAE,EAAE,CAAC;gBACtC,MAAM,CAAC,OAAO,GAAG,CAAC,CAAC,IAAI,IAAI,CAAC,QAAQ,CAAC,CAAC,GAAG,CAAC,GAAG,QAAQ,GAAG,CAAC,CAAC,IAAI,CAAC,CAAC;YAClE,CAAC;QACH,CAAC;QAED,2BAA2B;QAC3B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC,SAAS,EAAE,CAAC,EAAE,EAAE,CAAC;YAC/C,MAAM,CAAC,GAAG,CAAC,IAAI,CAAC,gBAAgB,CAAC,MAAM,EAAE,MAAM,GAAG,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC3D,CAAC;QAED,8DAA8D;QAC9D,MAAM,YAAY,GAAG,CAAC,MAAM,GAAG,CAAC,GAAG,CAAC,CAAC,GAAG,QAAQ,CAAC;QACjD,MAAM,SAAS,GAAG,MAAM,CAAC,KAAK,CAAC,YAAY,EAAE,YAAY,GAAG,QAAQ,CAAC,CAAC;QAEtE,oBAAoB;QACpB,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;QACjD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;YACzC,IAAI,GAAG,GAAG,CAAC,CAAC;YACZ,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,QAAQ,EAAE,CAAC,EAAE,EAAE,CAAC;gBAClC,GAAG,IAAI,SAAS,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,UAAU,CAAC,CAAC,GAAG,IAAI,CAAC,UAAU,GAAG,CAAC,CAAC,CAAC;YACjE,CAAC;YACD,MAAM,CAAC,CAAC,CAAC,GAAG,GAAG,CAAC;QAClB,CAAC;QAED,OAAO,IAAI,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC;IAC9B,CAAC;IAED;;OAEG;IACH,QAAQ;QACN,OAAO;YACL,WAAW,EAAE,IAAI,CAAC,WAAW;YAC7B,UAAU,EAAE,IAAI,CAAC,gBAAgB,CAAC,MAAM;YACxC,OAAO,EA
AE,IAAI,CAAC,OAAO;YACrB,aAAa,EAAE,IAAI,CAAC,MAAM,CAAC,aAAa;YACxC,SAAS,EAAE,IAAI,CAAC,MAAM,CAAC,SAAS;SACjC,CAAC;IACJ,CAAC;IAED,6EAA6E;IAC7E,kBAAkB;IAClB,6EAA6E;IAErE,aAAa,CAAC,QAAgB,EAAE,SAAiB;QACvD,MAAM,KAAK,GAAG,IAAI,YAAY,CAAC,QAAQ,GAAG,SAAS,CAAC,CAAC;QACrD,MAAM,KAAK,GAAG,IAAI,CAAC,IAAI,CAAC,CAAC,GAAG,QAAQ,CAAC,CAAC;QACtC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,KAAK,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACtC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,MAAM,EAAE,GAAG,GAAG,CAAC,GAAG,KAAK,CAAC;QAC3C,CAAC;QACD,OAAO,KAAK,CAAC;IACf,CAAC;IAEO,UAAU,CAAC,QAAgB,EAAE,SAAiB;QACpD,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,QAAQ,GAAG,SAAS,CAAC,CAAC;QACtD,MAAM,KAAK,GAAG,IAAI,CAAC,IAAI,CAAC,CAAC,GAAG,QAAQ,CAAC,CAAC;QACtC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACvC,MAAM,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,MAAM,EAAE,GAAG,GAAG,CAAC,GAAG,KAAK,CAAC;QAC5C,CAAC;QACD,OAAO,MAAM,CAAC;IAChB,CAAC;IAEO,cAAc,CAAC,UAAsB;QAC3C,MAAM,QAAQ,GAAoB,EAAE,CAAC;QAErC,wBAAwB;QACxB,MAAM,OAAO,GAAG,UAAU,CAAC,KAAK,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;QACpD,MAAM,WAAW,GAAG,IAAI,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QACtD,IAAI,SAAS,GAAG,CAAC,CAAC;QAElB,KAAK,IAAI,CAAC,GAAG,OAAO,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;YAC7C,SAAS,GAAG,OAAO,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC,KAAK,GAAG,SAAS,CAAC;YACvD,WAAW,CAAC,CAAC,CAAC,GAAG,SAAS,CAAC;QAC7B,CAAC;QAED,0BAA0B;QAC1B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,CAAC,KAAK,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACjD,QAAQ,CAAC,IAAI,CAAC;gBACZ,UAAU,EAAE,WAAW,CAAC,CAAC,CAAC;gBAC1B,KAAK,EAAE,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,UAAU;gBACrC,MAAM,EAAE,IAAI,CAAC,UAAU,CAAC,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC;gBACnD,QAAQ,EAAE,CAAC;aACZ,CAAC,CAAC;QACL,CAAC;QAED,OAAO,QAAQ,CAAC;IAClB,CAAC;IAEO,gBAAgB,CACtB,MAAoB,EACpB,MAAc,EACd,QAAgB;QAEhB,MAAM,QAAQ,GAAG,IAAI,CAAC,MAAM,CAAC,YAAY,CAAC;QAC1C,MAAM,SAAS,GAAG,IAAI,CAAC,MAAM,CAAC,SAAS,CAAC;QACxC,MAAM,QAAQ,GAAG,IAAI,CAAC,MAAM,CAAC,QAAQ,CAAC;Q
ACtC,MAAM,OAAO,GAAG,SAAS,GAAG,QAAQ,CAAC;QAErC,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,MAAM,CAAC,MAAM,CAAC,CAAC;QAE/C,qCAAqC;QACrC,MAAM,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,GAAG,IAAI,CAAC,gBAAgB,CAAC,QAAQ,CAAC,CAAC;QAEzD,oCAAoC;QACpC,MAAM,CAAC,GAAG,IAAI,YAAY,CAAC,MAAM,GAAG,SAAS,CAAC,CAAC;QAC/C,MAAM,CAAC,GAAG,IAAI,YAAY,CAAC,MAAM,GAAG,SAAS,CAAC,CAAC;QAC/C,MAAM,CAAC,GAAG,IAAI,YAAY,CAAC,MAAM,GAAG,SAAS,CAAC,CAAC;QAE/C,KAAK,IAAI,GAAG,GAAG,CAAC,EAAE,GAAG,GAAG,MAAM,EAAE,GAAG,EAAE,EAAE,CAAC;YACtC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,SAAS,EAAE,CAAC,EAAE,EAAE,CAAC;gBACnC,IAAI,IAAI,GAAG,CAAC,EAAE,IAAI,GAAG,CAAC,EAAE,IAAI,GAAG,CAAC,CAAC;gBACjC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,QAAQ,EAAE,CAAC,EAAE,EAAE,CAAC;oBAClC,MAAM,SAAS,GAAG,MAAM,CAAC,GAAG,GAAG,QAAQ,GAAG,CAAC,CAAC,CAAC;oBAC7C,IAAI,IAAI,SAAS,GAAG,EAAE,CAAC,CAAC,GAAG,SAAS,GAAG,CAAC,CAAC,CAAC;oBAC1C,IAAI,IAAI,SAAS,GAAG,EAAE,CAAC,CAAC,GAAG,SAAS,GAAG,CAAC,CAAC,CAAC;oBAC1C,IAAI,IAAI,SAAS,GAAG,EAAE,CAAC,CAAC,GAAG,SAAS,GAAG,CAAC,CAAC,CAAC;gBAC5C,CAAC;gBACD,CAAC,CAAC,GAAG,GAAG,SAAS,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC;gBAC9B,CAAC,CAAC,GAAG,GAAG,SAAS,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC;gBAC9B,CAAC,CAAC,GAAG,GAAG,SAAS,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC;YAChC,CAAC;QACH,CAAC;QAED,mBAAmB;QACnB,KAAK,IAAI,GAAG,GAAG,CAAC,EAAE,GAAG,GAAG,MAAM,EAAE,GAAG,EAAE,EAAE,CAAC;YACtC,gDAAgD;YAChD,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,GAAG,GAAG,CAAC,CAAC,CAAC;YACzC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,IAAI,GAAG,EAAE,CAAC,EAAE,EAAE,CAAC;gBAC9B,IAAI,KAAK,GAAG,CAAC,CAAC;gBACd,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,SAAS,EAAE,CAAC,EAAE,EAAE,CAAC;oBACnC,KAAK,IAAI,CAAC,CAAC,GAAG,GAAG,SAAS,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,GAAG,SAAS,GAAG,CAAC,CAAC,CAAC;gBACzD,CAAC;gBACD,MAAM,CAAC,CAAC,CAAC,GAAG,KAAK,GAAG,IAAI,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;YACzC,CAAC;YAED,UAAU;YACV,MAAM,QAAQ,GAAG,IAAI,CAAC,GAAG,CAAC,GAAG,MAAM,CAAC,CAAC;YACrC,IAAI,MAAM,GAAG,CAAC,CAAC;YACf,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,IAAI,GAAG,EAAE,CAAC,EAAE,EAAE,CAAC;gBAC9B,MAAM,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,
CAAC,CAAC,GAAG,QAAQ,CAAC,CAAC;gBAC3C,MAAM,IAAI,MAAM,CAAC,CAAC,CAAC,CAAC;YACtB,CAAC;YACD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,IAAI,GAAG,EAAE,CAAC,EAAE,EAAE,CAAC;gBAC9B,MAAM,CAAC,CAAC,CAAC,IAAI,MAAM,CAAC;YACtB,CAAC;YAED,yBAAyB;YACzB,MAAM,OAAO,GAAG,IAAI,YAAY,CAAC,SAAS,CAAC,CAAC;YAC5C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,IAAI,GAAG,EAAE,CAAC,EAAE,EAAE,CAAC;gBAC9B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,SAAS,EAAE,CAAC,EAAE,EAAE,CAAC;oBACnC,OAAO,CAAC,CAAC,CAAC,IAAI,MAAM,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,GAAG,SAAS,GAAG,CAAC,CAAC,CAAC;gBACjD,CAAC;YACH,CAAC;YAED,oBAAoB;YACpB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,QAAQ,EAAE,CAAC,EAAE,EAAE,CAAC;gBAClC,IAAI,GAAG,GAAG,MAAM,CAAC,GAAG,GAAG,QAAQ,GAAG,CAAC,CAAC,CAAC,CAAC,WAAW;gBACjD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,SAAS,EAAE,CAAC,EAAE,EAAE,CAAC;oBACnC,GAAG,IAAI,OAAO,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC,CAAC,GAAG,QAAQ,GAAG,CAAC,CAAC,CAAC;gBAC3C,CAAC;gBACD,MAAM,CAAC,GAAG,GAAG,QAAQ,GAAG,CAAC,CAAC,GAAG,GAAG,CAAC;YACnC,CAAC;QACH,CAAC;QAED,oBAAoB;QACpB,MAAM,CAAC,GAAG,EAAE,KAAK,CAAC,GAAG,IAAI,CAAC,UAAU,CAAC,QAAQ,CAAC,CAAC;QAC/C,MAAM,YAAY,GAAG,SAAS,GAAG,CAAC,CAAC;QAEnC,KAAK,IAAI,GAAG,GAAG,CAAC,EAAE,GAAG,GAAG,MAAM,EAAE,GAAG,EAAE,EAAE,CAAC;YACtC,uBAAuB;YACvB,MAAM,SAAS,GAAG,IAAI,YAAY,CAAC,YAAY,CAAC,CAAC;YACjD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,YAAY,EAAE,CAAC,EAAE,EAAE,CAAC;gBACtC,IAAI,GAAG,GAAG,CAAC,CAAC;gBACZ,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,QAAQ,EAAE,CAAC,EAAE,EAAE,CAAC;oBAClC,GAAG,IAAI,MAAM,CAAC,GAAG,GAAG,QAAQ,GAAG,CAAC,CAAC,GAAG,GAAG,CAAC,CAAC,GAAG,YAAY,GAAG,CAAC,CAAC,CAAC;gBAChE,CAAC;gBACD,qBAAqB;gBACrB,SAAS,CAAC,CAAC,CAAC,GAAG,GAAG,GAAG,GAAG,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,IAAI,CAAC,YAAY,GAAG,CAAC,GAAG,GAAG,QAAQ,GAAG,GAAG,GAAG,GAAG,GAAG,GAAG,CAAC,CAAC,CAAC,CAAC;YAChG,CAAC;YAED,kBAAkB;YAClB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,QAAQ,EAAE,CAAC,EAAE,EAAE,CAAC;gBAClC,IAAI,GAAG,GAAG,MAAM,CAAC,GAAG,GAAG,QAAQ,GAAG,CAAC,CAAC,CAAC,CAAC,WAAW;gBACjD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,YAAY,EAAE,CAAC,EAAE,EAAE,CAAC;oBACtC,GAAG,IAAI,SAAS
,CAAC,CAAC,CAAC,GAAG,KAAK,CAAC,CAAC,GAAG,QAAQ,GAAG,CAAC,CAAC,CAAC;gBAChD,CAAC;gBACD,MAAM,CAAC,GAAG,GAAG,QAAQ,GAAG,CAAC,CAAC,GAAG,GAAG,CAAC;YACnC,CAAC;QACH,CAAC;QAED,OAAO,MAAM,CAAC;IAChB,CAAC;IAEO,aAAa,CACnB,OAAwB,EACxB,YAAoB,EACpB,SAAuB;QAEvB,6CAA6C;QAC7C,MAAM,EAAE,GAAG,IAAI,CAAC,MAAM,CAAC,YAAY,CAAC;QACpC,MAAM,QAAQ,GAAG,IAAI,CAAC,MAAM,CAAC,YAAY,CAAC;QAE1C,4BAA4B;QAC5B,MAAM,IAAI,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;QAC/C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;YACzC,IAAI,CAAC,CAAC,CAAC,GAAG,SAAS,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,KAAK,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QACxD,CAAC;QAED,kCAAkC;QAClC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,QAAQ,EAAE,CAAC,EAAE,EAAE,CAAC;YAClC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;gBACzC,IAAI,CAAC,UAAU,CAAC,CAAC,GAAG,IAAI,CAAC,UAAU,GAAG,CAAC,CAAC,IAAI,EAAE,GAAG,IAAI,CAAC,CAAC,CAAC,GAAG,GAAG,CAAC;YACjE,CAAC;QACH,CAAC;IACH,CAAC;IAEO,OAAO,CAAC,MAAoB;QAClC,MAAM,GAAG,GAAG,IAAI,CAAC,GAAG,CAAC,GAAG,MAAM,CAAC,CAAC;QAChC,MAAM,IAAI,GAAG,IAAI,YAAY,CAAC,MAAM,CAAC,MAAM,CAAC,CAAC;QAC7C,IAAI,GAAG,GAAG,CAAC,CAAC;QAEZ,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACvC,IAAI,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,CAAC,GAAG,GAAG,CAAC,CAAC;YACpC,GAAG,IAAI,IAAI,CAAC,CAAC,CAAC,CAAC;QACjB,CAAC;QAED,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACrC,IAAI,CAAC,CAAC,CAAC,IAAI,GAAG,CAAC;QACjB,CAAC;QAED,OAAO,IAAI,CAAC;IACd,CAAC;IAEO,MAAM,CAAC,MAAoB;QACjC,IAAI,MAAM,GAAG,CAAC,CAAC;QACf,IAAI,MAAM,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;QACvB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACvC,IAAI,MAAM,CAAC,CAAC,CAAC,GAAG,MAAM,EAAE,CAAC;gBACvB,MAAM,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;gBACnB,MAAM,GAAG,CAAC,CAAC;YACb,CAAC;QACH,CAAC;QACD,OAAO,MAAM,CAAC;IAChB,CAAC;IAEO,UAAU,CAAC,MAAc;QAC/B,IAAI,IAAI,GAAG,CAAC,CAAC;QACb,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,
MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACvC,IAAI,GAAG,CAAC,IAAI,GAAG,EAAE,GAAG,MAAM,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,UAAU,CAAC;QAC9D,CAAC;QACD,OAAO,IAAI,CAAC;IACd,CAAC;CACF;AAED;;GAEG;AACH,MAAM,UAAU,yBAAyB,CACvC,MAA2C;IAE3C,OAAO,IAAI,mBAAmB,CAAC,MAAM,CAAC,CAAC;AACzC,CAAC"}
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Deep Q-Network (DQN)
|
|
3
|
+
*
|
|
4
|
+
* Implements DQN with enhancements:
|
|
5
|
+
* - Experience replay
|
|
6
|
+
* - Target network
|
|
7
|
+
* - Double DQN (optional)
|
|
8
|
+
* - Dueling architecture (optional)
|
|
9
|
+
* - Epsilon-greedy exploration
|
|
10
|
+
*
|
|
11
|
+
* Performance Target: <10ms per update step
|
|
12
|
+
*/
|
|
13
|
+
import type { DQNConfig, Trajectory } from '../types.js';
|
|
14
|
+
/**
|
|
15
|
+
* Default DQN configuration
|
|
16
|
+
*/
|
|
17
|
+
export declare const DEFAULT_DQN_CONFIG: DQNConfig;
|
|
18
|
+
/**
|
|
19
|
+
* DQN Algorithm Implementation
|
|
20
|
+
*/
|
|
21
|
+
export declare class DQNAlgorithm {
|
|
22
|
+
private config;
|
|
23
|
+
private qWeights;
|
|
24
|
+
private targetWeights;
|
|
25
|
+
private qMomentum;
|
|
26
|
+
private buffer;
|
|
27
|
+
private bufferIdx;
|
|
28
|
+
private epsilon;
|
|
29
|
+
private stepCount;
|
|
30
|
+
private numActions;
|
|
31
|
+
private inputDim;
|
|
32
|
+
private updateCount;
|
|
33
|
+
private avgLoss;
|
|
34
|
+
constructor(config?: Partial<DQNConfig>);
|
|
35
|
+
/**
|
|
36
|
+
* Add experience from trajectory
|
|
37
|
+
*/
|
|
38
|
+
addExperience(trajectory: Trajectory): void;
|
|
39
|
+
/**
|
|
40
|
+
* Perform DQN update
|
|
41
|
+
* Target: <10ms
|
|
42
|
+
*/
|
|
43
|
+
update(): {
|
|
44
|
+
loss: number;
|
|
45
|
+
epsilon: number;
|
|
46
|
+
};
|
|
47
|
+
/**
|
|
48
|
+
* Get action using epsilon-greedy
|
|
49
|
+
*/
|
|
50
|
+
getAction(state: Float32Array, explore?: boolean): number;
|
|
51
|
+
/**
|
|
52
|
+
* Get Q-values for a state
|
|
53
|
+
*/
|
|
54
|
+
getQValues(state: Float32Array): Float32Array;
|
|
55
|
+
/**
|
|
56
|
+
* Get statistics
|
|
57
|
+
*/
|
|
58
|
+
getStats(): Record<string, number>;
|
|
59
|
+
private initializeNetwork;
|
|
60
|
+
private copyNetwork;
|
|
61
|
+
private forward;
|
|
62
|
+
private accumulateGradients;
|
|
63
|
+
private applyGradients;
|
|
64
|
+
private sampleBatch;
|
|
65
|
+
private hashAction;
|
|
66
|
+
private argmax;
|
|
67
|
+
}
|
|
68
|
+
/**
|
|
69
|
+
* Factory function
|
|
70
|
+
*/
|
|
71
|
+
export declare function createDQN(config?: Partial<DQNConfig>): DQNAlgorithm;
|
|
72
|
+
//# sourceMappingURL=dqn.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"dqn.d.ts","sourceRoot":"","sources":["../../src/algorithms/dqn.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;GAWG;AAEH,OAAO,KAAK,EACV,SAAS,EACT,UAAU,EAEX,MAAM,aAAa,CAAC;AAErB;;GAEG;AACH,eAAO,MAAM,kBAAkB,EAAE,SAgBhC,CAAC;AAaF;;GAEG;AACH,qBAAa,YAAY;IACvB,OAAO,CAAC,MAAM,CAAY;IAG1B,OAAO,CAAC,QAAQ,CAAiB;IACjC,OAAO,CAAC,aAAa,CAAiB;IAGtC,OAAO,CAAC,SAAS,CAAiB;IAGlC,OAAO,CAAC,MAAM,CAAuB;IACrC,OAAO,CAAC,SAAS,CAAK;IAGtB,OAAO,CAAC,OAAO,CAAS;IACxB,OAAO,CAAC,SAAS,CAAK;IAGtB,OAAO,CAAC,UAAU,CAAK;IACvB,OAAO,CAAC,QAAQ,CAAO;IAGvB,OAAO,CAAC,WAAW,CAAK;IACxB,OAAO,CAAC,OAAO,CAAK;gBAER,MAAM,GAAE,OAAO,CAAC,SAAS,CAAM;IAU3C;;OAEG;IACH,aAAa,CAAC,UAAU,EAAE,UAAU,GAAG,IAAI;IAyB3C;;;OAGG;IACH,MAAM,IAAI;QAAE,IAAI,EAAE,MAAM,CAAC;QAAC,OAAO,EAAE,MAAM,CAAA;KAAE;IA2E3C;;OAEG;IACH,SAAS,CAAC,KAAK,EAAE,YAAY,EAAE,OAAO,GAAE,OAAc,GAAG,MAAM;IAS/D;;OAEG;IACH,UAAU,CAAC,KAAK,EAAE,YAAY,GAAG,YAAY;IAI7C;;OAEG;IACH,QAAQ,IAAI,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC;IAclC,OAAO,CAAC,iBAAiB;IAwBzB,OAAO,CAAC,WAAW;IAInB,OAAO,CAAC,OAAO;IA0Bf,OAAO,CAAC,mBAAmB;IAkC3B,OAAO,CAAC,cAAc;IAmBtB,OAAO,CAAC,WAAW;IAenB,OAAO,CAAC,UAAU;IAQlB,OAAO,CAAC,MAAM;CAWf;AAED;;GAEG;AACH,wBAAgB,SAAS,CAAC,MAAM,CAAC,EAAE,OAAO,CAAC,SAAS,CAAC,GAAG,YAAY,CAEnE"}
|