@sparkleideas/neural 3.5.2-patch.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +260 -0
- package/__tests__/README.md +235 -0
- package/__tests__/algorithms.test.ts +582 -0
- package/__tests__/patterns.test.ts +549 -0
- package/__tests__/sona.test.ts +445 -0
- package/docs/SONA_INTEGRATION.md +460 -0
- package/docs/SONA_QUICKSTART.md +168 -0
- package/examples/sona-usage.ts +318 -0
- package/package.json +23 -0
- package/src/algorithms/a2c.d.ts +86 -0
- package/src/algorithms/a2c.d.ts.map +1 -0
- package/src/algorithms/a2c.js +361 -0
- package/src/algorithms/a2c.js.map +1 -0
- package/src/algorithms/a2c.ts +478 -0
- package/src/algorithms/curiosity.d.ts +82 -0
- package/src/algorithms/curiosity.d.ts.map +1 -0
- package/src/algorithms/curiosity.js +392 -0
- package/src/algorithms/curiosity.js.map +1 -0
- package/src/algorithms/curiosity.ts +509 -0
- package/src/algorithms/decision-transformer.d.ts +82 -0
- package/src/algorithms/decision-transformer.d.ts.map +1 -0
- package/src/algorithms/decision-transformer.js +415 -0
- package/src/algorithms/decision-transformer.js.map +1 -0
- package/src/algorithms/decision-transformer.ts +521 -0
- package/src/algorithms/dqn.d.ts +72 -0
- package/src/algorithms/dqn.d.ts.map +1 -0
- package/src/algorithms/dqn.js +303 -0
- package/src/algorithms/dqn.js.map +1 -0
- package/src/algorithms/dqn.ts +382 -0
- package/src/algorithms/index.d.ts +32 -0
- package/src/algorithms/index.d.ts.map +1 -0
- package/src/algorithms/index.js +74 -0
- package/src/algorithms/index.js.map +1 -0
- package/src/algorithms/index.ts +122 -0
- package/src/algorithms/ppo.d.ts +72 -0
- package/src/algorithms/ppo.d.ts.map +1 -0
- package/src/algorithms/ppo.js +331 -0
- package/src/algorithms/ppo.js.map +1 -0
- package/src/algorithms/ppo.ts +429 -0
- package/src/algorithms/q-learning.d.ts +77 -0
- package/src/algorithms/q-learning.d.ts.map +1 -0
- package/src/algorithms/q-learning.js +259 -0
- package/src/algorithms/q-learning.js.map +1 -0
- package/src/algorithms/q-learning.ts +333 -0
- package/src/algorithms/sarsa.d.ts +82 -0
- package/src/algorithms/sarsa.d.ts.map +1 -0
- package/src/algorithms/sarsa.js +297 -0
- package/src/algorithms/sarsa.js.map +1 -0
- package/src/algorithms/sarsa.ts +383 -0
- package/src/algorithms/tmp.json +0 -0
- package/src/application/index.ts +11 -0
- package/src/application/services/neural-application-service.ts +217 -0
- package/src/domain/entities/pattern.ts +169 -0
- package/src/domain/index.ts +18 -0
- package/src/domain/services/learning-service.ts +256 -0
- package/src/index.d.ts +118 -0
- package/src/index.d.ts.map +1 -0
- package/src/index.js +201 -0
- package/src/index.js.map +1 -0
- package/src/index.ts +363 -0
- package/src/modes/balanced.d.ts +60 -0
- package/src/modes/balanced.d.ts.map +1 -0
- package/src/modes/balanced.js +234 -0
- package/src/modes/balanced.js.map +1 -0
- package/src/modes/balanced.ts +299 -0
- package/src/modes/base.ts +163 -0
- package/src/modes/batch.d.ts +82 -0
- package/src/modes/batch.d.ts.map +1 -0
- package/src/modes/batch.js +316 -0
- package/src/modes/batch.js.map +1 -0
- package/src/modes/batch.ts +434 -0
- package/src/modes/edge.d.ts +85 -0
- package/src/modes/edge.d.ts.map +1 -0
- package/src/modes/edge.js +310 -0
- package/src/modes/edge.js.map +1 -0
- package/src/modes/edge.ts +409 -0
- package/src/modes/index.d.ts +55 -0
- package/src/modes/index.d.ts.map +1 -0
- package/src/modes/index.js +83 -0
- package/src/modes/index.js.map +1 -0
- package/src/modes/index.ts +16 -0
- package/src/modes/real-time.d.ts +58 -0
- package/src/modes/real-time.d.ts.map +1 -0
- package/src/modes/real-time.js +196 -0
- package/src/modes/real-time.js.map +1 -0
- package/src/modes/real-time.ts +257 -0
- package/src/modes/research.d.ts +79 -0
- package/src/modes/research.d.ts.map +1 -0
- package/src/modes/research.js +389 -0
- package/src/modes/research.js.map +1 -0
- package/src/modes/research.ts +486 -0
- package/src/modes/tmp.json +0 -0
- package/src/pattern-learner.d.ts +117 -0
- package/src/pattern-learner.d.ts.map +1 -0
- package/src/pattern-learner.js +603 -0
- package/src/pattern-learner.js.map +1 -0
- package/src/pattern-learner.ts +757 -0
- package/src/reasoning-bank.d.ts +259 -0
- package/src/reasoning-bank.d.ts.map +1 -0
- package/src/reasoning-bank.js +993 -0
- package/src/reasoning-bank.js.map +1 -0
- package/src/reasoning-bank.ts +1279 -0
- package/src/reasoningbank-adapter.ts +697 -0
- package/src/sona-integration.d.ts +168 -0
- package/src/sona-integration.d.ts.map +1 -0
- package/src/sona-integration.js +316 -0
- package/src/sona-integration.js.map +1 -0
- package/src/sona-integration.ts +432 -0
- package/src/sona-manager.d.ts +147 -0
- package/src/sona-manager.d.ts.map +1 -0
- package/src/sona-manager.js +695 -0
- package/src/sona-manager.js.map +1 -0
- package/src/sona-manager.ts +835 -0
- package/src/tmp.json +0 -0
- package/src/types.d.ts +431 -0
- package/src/types.d.ts.map +1 -0
- package/src/types.js +11 -0
- package/src/types.js.map +1 -0
- package/src/types.ts +590 -0
- package/tmp.json +0 -0
- package/tsconfig.json +9 -0
- package/vitest.config.ts +19 -0
|
@@ -0,0 +1,521 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Decision Transformer
|
|
3
|
+
*
|
|
4
|
+
* Implements sequence modeling approach for RL:
|
|
5
|
+
* - Trajectory as sequence: (R, s, a, R, s, a, ...), where R is the return-to-go
|
|
6
|
+
* - Return-conditioned generation
|
|
7
|
+
* - Causal transformer attention
|
|
8
|
+
* - Offline RL from trajectories
|
|
9
|
+
*
|
|
10
|
+
* Performance Target: <10ms per forward pass
|
|
11
|
+
*/
|
|
12
|
+
|
|
13
|
+
import type {
|
|
14
|
+
DecisionTransformerConfig,
|
|
15
|
+
Trajectory,
|
|
16
|
+
TrajectoryStep,
|
|
17
|
+
} from '../types.js';
|
|
18
|
+
|
|
19
|
+
/**
 * Default hyper-parameters for the Decision Transformer.
 *
 * `entropyCoef`, `valueLossCoef` and `dropout` are carried for config-shape
 * compatibility; the `DecisionTransformer` class in this module does not read them.
 */
export const DEFAULT_DT_CONFIG: DecisionTransformerConfig = {
  // --- generic RL / optimisation settings ---
  algorithm: 'decision-transformer',
  learningRate: 1e-4,
  gamma: 0.99,
  entropyCoef: 0,
  valueLossCoef: 0,
  maxGradNorm: 1,
  epochs: 1,
  miniBatchSize: 64,

  // --- transformer architecture ---
  contextLength: 20, // timesteps of history fed to the model
  numHeads: 4,
  numLayers: 2,
  hiddenDim: 64,
  embeddingDim: 32,
  dropout: 0.1,
};
|
|
38
|
+
|
|
39
|
+
/**
 * One trajectory step as fed to the transformer: the (return-to-go, state,
 * action) triple plus the step's position in its episode.
 */
interface SequenceEntry {
  /** 0-based index of this step within its trajectory. */
  timestep: number;
  /** Return still to be collected from this step onward. */
  returnToGo: number;
  /** Observation vector for this step. */
  state: Float32Array;
  /** Discrete action id. */
  action: number;
}
|
|
48
|
+
|
|
49
|
+
/**
 * Decision Transformer implementation.
 *
 * Each trajectory step becomes three tokens — return-to-go, state, action —
 * which pass through a small causal transformer. The next action is predicted
 * from the hidden vector of the most recent *state* token; causal masking keeps
 * that step's own action token out of view.
 *
 * Simplifications: only the action head is trained (embeddings and transformer
 * weights keep their random initialisation), dropout is not applied, and the
 * state / action spaces are fixed at 768 inputs and 4 discrete actions.
 */
export class DecisionTransformer {
  /** Maximum number of trajectories retained for training (oldest dropped first). */
  private static readonly MAX_BUFFERED_TRAJECTORIES = 1000;
  /** Per-step reward assumed when rolling the target return forward at inference. */
  private static readonly INFERENCE_REWARD_STEP = 0.1;

  private config: DecisionTransformerConfig;

  // Embedding tables, row-major (inputDim x embeddingDim)
  private stateEmbed: Float32Array;
  private actionEmbed: Float32Array;
  private returnEmbed: Float32Array;
  private posEmbed: Float32Array;

  // Per-layer weights: attention [Wq, Wk, Wv, Wo] and FFN [Wup, Wdown]
  private attentionWeights: Float32Array[][];
  private ffnWeights: Float32Array[][];

  // Output head, row-major (embeddingDim x numActions)
  private actionHead: Float32Array;

  // Training buffer
  private trajectoryBuffer: Trajectory[] = [];

  // Dimensions
  private stateDim = 768;
  private numActions = 4;

  // Statistics
  private updateCount = 0;
  private avgLoss = 0;

  constructor(config: Partial<DecisionTransformerConfig> = {}) {
    this.config = { ...DEFAULT_DT_CONFIG, ...config };
    const { embeddingDim, hiddenDim, contextLength, numLayers } = this.config;

    this.stateEmbed = this.initWeight(this.stateDim, embeddingDim);
    this.actionEmbed = this.initWeight(this.numActions, embeddingDim);
    this.returnEmbed = this.initWeight(1, embeddingDim);
    // One positional vector per token: contextLength steps x 3 tokens per step
    this.posEmbed = this.initWeight(contextLength * 3, embeddingDim);

    this.attentionWeights = [];
    this.ffnWeights = [];
    for (let l = 0; l < numLayers; l++) {
      this.attentionWeights.push([
        this.initWeight(embeddingDim, hiddenDim), // Q
        this.initWeight(embeddingDim, hiddenDim), // K
        this.initWeight(embeddingDim, hiddenDim), // V
        this.initWeight(hiddenDim, embeddingDim), // O
      ]);
      this.ffnWeights.push([
        this.initWeight(embeddingDim, hiddenDim * 4), // up projection
        this.initWeight(hiddenDim * 4, embeddingDim), // down projection
      ]);
    }

    this.actionHead = this.initWeight(embeddingDim, this.numActions);
  }

  /**
   * Add a trajectory for training.
   *
   * Incomplete or empty trajectories are ignored. The buffer keeps only the
   * most recent {@link DecisionTransformer.MAX_BUFFERED_TRAJECTORIES} entries.
   */
  addTrajectory(trajectory: Trajectory): void {
    if (!trajectory.isComplete || trajectory.steps.length === 0) return;

    this.trajectoryBuffer.push(trajectory);
    const cap = DecisionTransformer.MAX_BUFFERED_TRAJECTORIES;
    if (this.trajectoryBuffer.length > cap) {
      this.trajectoryBuffer = this.trajectoryBuffer.slice(-cap);
    }
  }

  /**
   * Train on a mini-batch sampled (with replacement) from the buffer.
   * Target: <10ms per batch; a warning is logged when exceeded.
   *
   * @returns Mean cross-entropy loss and action accuracy over all predicted
   *          steps, measured before each step's weight update.
   */
  train(): { loss: number; accuracy: number } {
    const startTime = performance.now();

    if (this.trajectoryBuffer.length === 0) {
      return { loss: 0, accuracy: 0 };
    }

    const batchSize = Math.min(this.config.miniBatchSize, this.trajectoryBuffer.length);
    const batch: Trajectory[] = [];
    for (let i = 0; i < batchSize; i++) {
      const idx = Math.floor(Math.random() * this.trajectoryBuffer.length);
      batch.push(this.trajectoryBuffer[idx]);
    }

    let totalLoss = 0;
    let correct = 0;
    let total = 0;

    for (const trajectory of batch) {
      const sequence = this.createSequence(trajectory);
      if (sequence.length < 2) continue;

      for (let t = 1; t < sequence.length; t++) {
        // Predict step t's action from (at most contextLength) preceding steps
        const context = sequence.slice(Math.max(0, t - this.config.contextLength), t);
        const target = sequence[t];

        const features = this.encode(context);
        const predicted = this.predictFromFeatures(features);

        totalLoss += -Math.log(predicted[target.action] + 1e-8);
        if (this.argmax(predicted) === target.action) {
          correct++;
        }
        total++;

        this.updateWeights(features, target.action, predicted);
      }
    }

    this.updateCount++;
    this.avgLoss = total > 0 ? totalLoss / total : 0;

    const elapsed = performance.now() - startTime;
    if (elapsed > 10) {
      console.warn(`DT training exceeded target: ${elapsed.toFixed(2)}ms > 10ms`);
    }

    return {
      loss: this.avgLoss,
      accuracy: total > 0 ? correct / total : 0,
    };
  }

  /**
   * Pick the most likely action conditioned on a target return.
   *
   * @param states       Observed states, oldest first.
   * @param actions      Actions taken at those states; missing entries default to 0.
   * @param targetReturn Return requested at the first state. Each later step's
   *                     return-to-go is reduced by a fixed estimated reward of 0.1.
   * @returns The argmax action id (0 when `states` is empty).
   */
  getAction(
    states: Float32Array[],
    actions: number[],
    targetReturn: number
  ): number {
    const rewardStep = DecisionTransformer.INFERENCE_REWARD_STEP;
    const sequence: SequenceEntry[] = states.map((state, i) => ({
      // Step i has already "collected" i estimated rewards
      returnToGo: targetReturn - rewardStep * i,
      state,
      action: actions[i] ?? 0,
      timestep: i,
    }));

    return this.argmax(this.forward(sequence));
  }

  /**
   * Forward pass through the transformer.
   *
   * Uses at most the last `contextLength` entries of `sequence`.
   *
   * @returns Action probabilities (softmax). An empty sequence yields a
   *          uniform distribution.
   */
  forward(sequence: SequenceEntry[]): Float32Array {
    return this.predictFromFeatures(this.encode(sequence));
  }

  /**
   * Get statistics
   */
  getStats(): Record<string, number> {
    return {
      updateCount: this.updateCount,
      bufferSize: this.trajectoryBuffer.length,
      avgLoss: this.avgLoss,
      contextLength: this.config.contextLength,
      numLayers: this.config.numLayers,
    };
  }

  // ==========================================================================
  // Private Methods
  // ==========================================================================

  /** Random matrix of shape (inputDim x outputDim), uniform in ±0.5·sqrt(2/inputDim). */
  private initWeight(inputDim: number, outputDim: number): Float32Array {
    const weight = new Float32Array(inputDim * outputDim);
    const scale = Math.sqrt(2 / inputDim);
    for (let i = 0; i < weight.length; i++) {
      weight[i] = (Math.random() - 0.5) * scale;
    }
    return weight;
  }

  /**
   * Encode a sequence and return the final hidden vector of its last state
   * token (length `embeddingDim`). Returns a zero vector for an empty sequence.
   */
  private encode(sequence: SequenceEntry[]): Float32Array {
    const embedDim = this.config.embeddingDim;
    const seqLen = Math.min(sequence.length, this.config.contextLength);
    if (seqLen <= 0) {
      return new Float32Array(embedDim);
    }

    const numTokens = seqLen * 3;
    const hidden = new Float32Array(numTokens * embedDim);
    const offset = sequence.length - seqLen;

    for (let t = 0; t < seqLen; t++) {
      const entry = sequence[offset + t];
      const base = t * 3 * embedDim;
      const stateLen = Math.min(entry.state.length, this.stateDim);
      // Unknown action ids get a zero embedding instead of reading past the table
      const validAction =
        Number.isInteger(entry.action) && entry.action >= 0 && entry.action < this.numActions;

      for (let d = 0; d < embedDim; d++) {
        // Token 0: return-to-go
        hidden[base + d] = entry.returnToGo * this.returnEmbed[d];

        // Token 1: state (linear projection)
        let stateSum = 0;
        for (let s = 0; s < stateLen; s++) {
          stateSum += entry.state[s] * this.stateEmbed[s * embedDim + d];
        }
        hidden[base + embedDim + d] = stateSum;

        // Token 2: action
        hidden[base + 2 * embedDim + d] = validAction
          ? this.actionEmbed[entry.action * embedDim + d]
          : 0;
      }

      // Positional embeddings; seqLen <= contextLength keeps this in bounds
      for (let d = 0; d < 3 * embedDim; d++) {
        hidden[base + d] += this.posEmbed[base + d];
      }
    }

    for (let l = 0; l < this.config.numLayers; l++) {
      hidden.set(this.transformerLayer(hidden, numTokens, l));
    }

    // With (R, s, a) ordering, the last state token is the second-to-last token
    const lastStateIdx = (numTokens - 2) * embedDim;
    return hidden.slice(lastStateIdx, lastStateIdx + embedDim);
  }

  /** Apply the action head to encoded features and return softmax probabilities. */
  private predictFromFeatures(features: Float32Array): Float32Array {
    const embedDim = this.config.embeddingDim;
    const logits = new Float32Array(this.numActions);
    for (let a = 0; a < this.numActions; a++) {
      let sum = 0;
      for (let d = 0; d < embedDim; d++) {
        sum += features[d] * this.actionHead[d * this.numActions + a];
      }
      logits[a] = sum;
    }
    return this.softmax(logits);
  }

  /**
   * Convert a trajectory into sequence entries with discounted returns-to-go.
   * Each step's `stateAfter` is used as its state token.
   */
  private createSequence(trajectory: Trajectory): SequenceEntry[] {
    const steps = trajectory.steps;
    const returnsToGo = new Array<number>(steps.length);

    let cumReturn = 0;
    for (let t = steps.length - 1; t >= 0; t--) {
      cumReturn = steps[t].reward + this.config.gamma * cumReturn;
      returnsToGo[t] = cumReturn;
    }

    return steps.map((step, t) => ({
      returnToGo: returnsToGo[t],
      state: step.stateAfter,
      action: this.hashAction(step.action),
      timestep: t,
    }));
  }

  /**
   * One transformer block: causal multi-head self-attention followed by a
   * GELU feed-forward network, each with a residual connection.
   *
   * @param hidden   Token vectors, row-major (seqLen x embeddingDim).
   * @param seqLen   Number of tokens (not timesteps).
   * @param layerIdx Which layer's weights to use.
   */
  private transformerLayer(
    hidden: Float32Array,
    seqLen: number,
    layerIdx: number
  ): Float32Array {
    const embedDim = this.config.embeddingDim;
    const hiddenDim = this.config.hiddenDim;
    const [Wq, Wk, Wv, Wo] = this.attentionWeights[layerIdx];
    const [Wup, Wdown] = this.ffnWeights[layerIdx];

    // Heads are contiguous slices of hiddenDim; the last one may be narrower
    // when hiddenDim is not divisible by numHeads.
    const requestedHeads = Math.floor(this.config.numHeads) || 1;
    const numHeads = Math.max(1, Math.min(hiddenDim, requestedHeads));
    const headDim = Math.ceil(hiddenDim / numHeads);

    // Q, K, V projections for every token
    const Q = new Float32Array(seqLen * hiddenDim);
    const K = new Float32Array(seqLen * hiddenDim);
    const V = new Float32Array(seqLen * hiddenDim);
    for (let pos = 0; pos < seqLen; pos++) {
      const inBase = pos * embedDim;
      const outBase = pos * hiddenDim;
      for (let h = 0; h < hiddenDim; h++) {
        let q = 0;
        let k = 0;
        let v = 0;
        for (let d = 0; d < embedDim; d++) {
          const x = hidden[inBase + d];
          const w = d * hiddenDim + h;
          q += x * Wq[w];
          k += x * Wk[w];
          v += x * Wv[w];
        }
        Q[outBase + h] = q;
        K[outBase + h] = k;
        V[outBase + h] = v;
      }
    }

    const output = new Float32Array(hidden.length);
    const scores = new Float64Array(seqLen);
    const attnOut = new Float32Array(hiddenDim);

    for (let pos = 0; pos < seqLen; pos++) {
      attnOut.fill(0);

      for (let head = 0; head < numHeads; head++) {
        const start = head * headDim;
        const end = Math.min(start + headDim, hiddenDim);
        if (start >= end) break;
        const scale = 1 / Math.sqrt(end - start);

        // Causal scaled dot-product: position `pos` only sees keys 0..pos
        let maxScore = -Infinity;
        for (let key = 0; key <= pos; key++) {
          let dot = 0;
          for (let h = start; h < end; h++) {
            dot += Q[pos * hiddenDim + h] * K[key * hiddenDim + h];
          }
          scores[key] = dot * scale;
          if (scores[key] > maxScore) maxScore = scores[key];
        }

        let sumExp = 0;
        for (let key = 0; key <= pos; key++) {
          scores[key] = Math.exp(scores[key] - maxScore);
          sumExp += scores[key];
        }

        for (let key = 0; key <= pos; key++) {
          const weight = scores[key] / sumExp;
          for (let h = start; h < end; h++) {
            attnOut[h] += weight * V[key * hiddenDim + h];
          }
        }
      }

      // Output projection + residual
      for (let d = 0; d < embedDim; d++) {
        let sum = hidden[pos * embedDim + d];
        for (let h = 0; h < hiddenDim; h++) {
          sum += attnOut[h] * Wo[h * embedDim + d];
        }
        output[pos * embedDim + d] = sum;
      }
    }

    // Position-wise FFN with residual
    const ffnHiddenDim = hiddenDim * 4;
    const ffnHidden = new Float32Array(ffnHiddenDim);
    for (let pos = 0; pos < seqLen; pos++) {
      const base = pos * embedDim;

      for (let h = 0; h < ffnHiddenDim; h++) {
        let sum = 0;
        for (let d = 0; d < embedDim; d++) {
          sum += output[base + d] * Wup[d * ffnHiddenDim + h];
        }
        // tanh approximation of GELU
        ffnHidden[h] = sum * 0.5 * (1 + Math.tanh(0.7978845608 * (sum + 0.044715 * sum * sum * sum)));
      }

      for (let d = 0; d < embedDim; d++) {
        let sum = output[base + d];
        for (let h = 0; h < ffnHiddenDim; h++) {
          sum += ffnHidden[h] * Wdown[h * embedDim + d];
        }
        output[base + d] = sum;
      }
    }

    return output;
  }

  /**
   * One SGD step on the action head for softmax cross-entropy.
   *
   * The gradient w.r.t. the head is the outer product `features ⊗ (p − onehot)`.
   * Its Frobenius norm is clipped to `maxGradNorm` when that is positive.
   */
  private updateWeights(
    features: Float32Array,
    targetAction: number,
    predicted: Float32Array
  ): void {
    const embedDim = this.config.embeddingDim;

    const grad = new Float32Array(this.numActions);
    let gradNormSq = 0;
    for (let a = 0; a < this.numActions; a++) {
      grad[a] = predicted[a] - (a === targetAction ? 1 : 0);
      gradNormSq += grad[a] * grad[a];
    }

    let featNormSq = 0;
    for (let d = 0; d < embedDim; d++) {
      featNormSq += features[d] * features[d];
    }

    // ||features ⊗ grad||_F = ||features|| * ||grad||
    const fullNorm = Math.sqrt(gradNormSq * featNormSq);
    const maxNorm = this.config.maxGradNorm;
    const clip = maxNorm > 0 && fullNorm > maxNorm ? maxNorm / fullNorm : 1;
    const step = this.config.learningRate * clip;

    for (let d = 0; d < embedDim; d++) {
      const f = features[d];
      for (let a = 0; a < this.numActions; a++) {
        this.actionHead[d * this.numActions + a] -= step * f * grad[a];
      }
    }
  }

  /** Numerically stable softmax. */
  private softmax(logits: Float32Array): Float32Array {
    const max = Math.max(...logits);
    const exps = new Float32Array(logits.length);
    let sum = 0;

    for (let i = 0; i < logits.length; i++) {
      exps[i] = Math.exp(logits[i] - max);
      sum += exps[i];
    }

    for (let i = 0; i < exps.length; i++) {
      exps[i] /= sum;
    }

    return exps;
  }

  /** Index of the first maximum value (0 for an empty array). */
  private argmax(values: Float32Array): number {
    let maxIdx = 0;
    let maxVal = values[0];
    for (let i = 1; i < values.length; i++) {
      if (values[i] > maxVal) {
        maxVal = values[i];
        maxIdx = i;
      }
    }
    return maxIdx;
  }

  /** Map an action string to an id in [0, numActions). */
  private hashAction(action: string): number {
    let hash = 0;
    for (let i = 0; i < action.length; i++) {
      hash = (hash * 31 + action.charCodeAt(i)) % this.numActions;
    }
    return hash;
  }
}
|
|
513
|
+
|
|
514
|
+
/**
 * Create a {@link DecisionTransformer}.
 *
 * @param config - Optional overrides, merged over `DEFAULT_DT_CONFIG` by the
 *                 constructor.
 * @returns A newly constructed transformer.
 */
export function createDecisionTransformer(
  config?: Partial<DecisionTransformerConfig>
): DecisionTransformer {
  const transformer = new DecisionTransformer(config);
  return transformer;
}
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Deep Q-Network (DQN)
|
|
3
|
+
*
|
|
4
|
+
* Implements DQN with enhancements:
|
|
5
|
+
* - Experience replay
|
|
6
|
+
* - Target network
|
|
7
|
+
* - Double DQN (optional)
|
|
8
|
+
* - Dueling architecture (optional)
|
|
9
|
+
* - Epsilon-greedy exploration
|
|
10
|
+
*
|
|
11
|
+
* Performance Target: <10ms per update step
|
|
12
|
+
*/
|
|
13
|
+
import type { DQNConfig, Trajectory } from '../types.js';
|
|
14
|
+
/**
 * Default DQN configuration.
 *
 * Only the type is declared here; the concrete values are defined in the
 * implementation module (`dqn.ts`, compiled to `dqn.js`).
 */
export declare const DEFAULT_DQN_CONFIG: DQNConfig;
|
|
18
|
+
/**
 * DQN Algorithm Implementation
 *
 * Declaration only; the implementation lives in `dqn.ts`. `tsc` emits private
 * members without types, so their roles cannot be confirmed from this file.
 * NOTE(review): names suggest online/target Q-network weights, optimizer
 * momentum, a replay buffer with write index, and an epsilon schedule — verify
 * against dqn.ts.
 */
export declare class DQNAlgorithm {
    private config;
    private qWeights;
    private targetWeights;
    private qMomentum;
    private buffer;
    private bufferIdx;
    private epsilon;
    private stepCount;
    private numActions;
    private inputDim;
    private updateCount;
    private avgLoss;
    /**
     * @param config Optional partial overrides of the DQN configuration.
     */
    constructor(config?: Partial<DQNConfig>);
    /**
     * Add experience from trajectory
     *
     * NOTE(review): presumably stores the trajectory's transitions in the
     * replay buffer — confirm in dqn.ts.
     */
    addExperience(trajectory: Trajectory): void;
    /**
     * Perform DQN update
     * Target: <10ms
     *
     * @returns The update's loss and the current exploration epsilon.
     */
    update(): {
        loss: number;
        epsilon: number;
    };
    /**
     * Get action using epsilon-greedy
     *
     * @param state   State vector to act on.
     * @param explore Optional flag; NOTE(review): presumably `false` forces the
     *                greedy action — confirm the default in dqn.ts.
     * @returns The chosen action index.
     */
    getAction(state: Float32Array, explore?: boolean): number;
    /**
     * Get Q-values for a state
     */
    getQValues(state: Float32Array): Float32Array;
    /**
     * Get statistics
     */
    getStats(): Record<string, number>;
    private initializeNetwork;
    private copyNetwork;
    private forward;
    private accumulateGradients;
    private applyGradients;
    private sampleBatch;
    private hashAction;
    private argmax;
}
|
|
68
|
+
/**
 * Factory function
 *
 * @param config Optional partial overrides of the DQN configuration.
 * @returns A new {@link DQNAlgorithm} instance.
 */
export declare function createDQN(config?: Partial<DQNConfig>): DQNAlgorithm;
|
|
72
|
+
//# sourceMappingURL=dqn.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"dqn.d.ts","sourceRoot":"","sources":["dqn.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;GAWG;AAEH,OAAO,KAAK,EACV,SAAS,EACT,UAAU,EAEX,MAAM,aAAa,CAAC;AAErB;;GAEG;AACH,eAAO,MAAM,kBAAkB,EAAE,SAgBhC,CAAC;AAaF;;GAEG;AACH,qBAAa,YAAY;IACvB,OAAO,CAAC,MAAM,CAAY;IAG1B,OAAO,CAAC,QAAQ,CAAiB;IACjC,OAAO,CAAC,aAAa,CAAiB;IAGtC,OAAO,CAAC,SAAS,CAAiB;IAGlC,OAAO,CAAC,MAAM,CAAuB;IACrC,OAAO,CAAC,SAAS,CAAK;IAGtB,OAAO,CAAC,OAAO,CAAS;IACxB,OAAO,CAAC,SAAS,CAAK;IAGtB,OAAO,CAAC,UAAU,CAAK;IACvB,OAAO,CAAC,QAAQ,CAAO;IAGvB,OAAO,CAAC,WAAW,CAAK;IACxB,OAAO,CAAC,OAAO,CAAK;gBAER,MAAM,GAAE,OAAO,CAAC,SAAS,CAAM;IAU3C;;OAEG;IACH,aAAa,CAAC,UAAU,EAAE,UAAU,GAAG,IAAI;IAyB3C;;;OAGG;IACH,MAAM,IAAI;QAAE,IAAI,EAAE,MAAM,CAAC;QAAC,OAAO,EAAE,MAAM,CAAA;KAAE;IA2E3C;;OAEG;IACH,SAAS,CAAC,KAAK,EAAE,YAAY,EAAE,OAAO,GAAE,OAAc,GAAG,MAAM;IAS/D;;OAEG;IACH,UAAU,CAAC,KAAK,EAAE,YAAY,GAAG,YAAY;IAI7C;;OAEG;IACH,QAAQ,IAAI,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC;IAclC,OAAO,CAAC,iBAAiB;IAwBzB,OAAO,CAAC,WAAW;IAInB,OAAO,CAAC,OAAO;IA0Bf,OAAO,CAAC,mBAAmB;IAkC3B,OAAO,CAAC,cAAc;IAmBtB,OAAO,CAAC,WAAW;IAenB,OAAO,CAAC,UAAU;IAQlB,OAAO,CAAC,MAAM;CAWf;AAED;;GAEG;AACH,wBAAgB,SAAS,CAAC,MAAM,CAAC,EAAE,OAAO,CAAC,SAAS,CAAC,GAAG,YAAY,CAEnE"}
|