@sparkleideas/neural 3.5.2-patch.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. package/README.md +260 -0
  2. package/__tests__/README.md +235 -0
  3. package/__tests__/algorithms.test.ts +582 -0
  4. package/__tests__/patterns.test.ts +549 -0
  5. package/__tests__/sona.test.ts +445 -0
  6. package/docs/SONA_INTEGRATION.md +460 -0
  7. package/docs/SONA_QUICKSTART.md +168 -0
  8. package/examples/sona-usage.ts +318 -0
  9. package/package.json +23 -0
  10. package/src/algorithms/a2c.d.ts +86 -0
  11. package/src/algorithms/a2c.d.ts.map +1 -0
  12. package/src/algorithms/a2c.js +361 -0
  13. package/src/algorithms/a2c.js.map +1 -0
  14. package/src/algorithms/a2c.ts +478 -0
  15. package/src/algorithms/curiosity.d.ts +82 -0
  16. package/src/algorithms/curiosity.d.ts.map +1 -0
  17. package/src/algorithms/curiosity.js +392 -0
  18. package/src/algorithms/curiosity.js.map +1 -0
  19. package/src/algorithms/curiosity.ts +509 -0
  20. package/src/algorithms/decision-transformer.d.ts +82 -0
  21. package/src/algorithms/decision-transformer.d.ts.map +1 -0
  22. package/src/algorithms/decision-transformer.js +415 -0
  23. package/src/algorithms/decision-transformer.js.map +1 -0
  24. package/src/algorithms/decision-transformer.ts +521 -0
  25. package/src/algorithms/dqn.d.ts +72 -0
  26. package/src/algorithms/dqn.d.ts.map +1 -0
  27. package/src/algorithms/dqn.js +303 -0
  28. package/src/algorithms/dqn.js.map +1 -0
  29. package/src/algorithms/dqn.ts +382 -0
  30. package/src/algorithms/index.d.ts +32 -0
  31. package/src/algorithms/index.d.ts.map +1 -0
  32. package/src/algorithms/index.js +74 -0
  33. package/src/algorithms/index.js.map +1 -0
  34. package/src/algorithms/index.ts +122 -0
  35. package/src/algorithms/ppo.d.ts +72 -0
  36. package/src/algorithms/ppo.d.ts.map +1 -0
  37. package/src/algorithms/ppo.js +331 -0
  38. package/src/algorithms/ppo.js.map +1 -0
  39. package/src/algorithms/ppo.ts +429 -0
  40. package/src/algorithms/q-learning.d.ts +77 -0
  41. package/src/algorithms/q-learning.d.ts.map +1 -0
  42. package/src/algorithms/q-learning.js +259 -0
  43. package/src/algorithms/q-learning.js.map +1 -0
  44. package/src/algorithms/q-learning.ts +333 -0
  45. package/src/algorithms/sarsa.d.ts +82 -0
  46. package/src/algorithms/sarsa.d.ts.map +1 -0
  47. package/src/algorithms/sarsa.js +297 -0
  48. package/src/algorithms/sarsa.js.map +1 -0
  49. package/src/algorithms/sarsa.ts +383 -0
  50. package/src/algorithms/tmp.json +0 -0
  51. package/src/application/index.ts +11 -0
  52. package/src/application/services/neural-application-service.ts +217 -0
  53. package/src/domain/entities/pattern.ts +169 -0
  54. package/src/domain/index.ts +18 -0
  55. package/src/domain/services/learning-service.ts +256 -0
  56. package/src/index.d.ts +118 -0
  57. package/src/index.d.ts.map +1 -0
  58. package/src/index.js +201 -0
  59. package/src/index.js.map +1 -0
  60. package/src/index.ts +363 -0
  61. package/src/modes/balanced.d.ts +60 -0
  62. package/src/modes/balanced.d.ts.map +1 -0
  63. package/src/modes/balanced.js +234 -0
  64. package/src/modes/balanced.js.map +1 -0
  65. package/src/modes/balanced.ts +299 -0
  66. package/src/modes/base.ts +163 -0
  67. package/src/modes/batch.d.ts +82 -0
  68. package/src/modes/batch.d.ts.map +1 -0
  69. package/src/modes/batch.js +316 -0
  70. package/src/modes/batch.js.map +1 -0
  71. package/src/modes/batch.ts +434 -0
  72. package/src/modes/edge.d.ts +85 -0
  73. package/src/modes/edge.d.ts.map +1 -0
  74. package/src/modes/edge.js +310 -0
  75. package/src/modes/edge.js.map +1 -0
  76. package/src/modes/edge.ts +409 -0
  77. package/src/modes/index.d.ts +55 -0
  78. package/src/modes/index.d.ts.map +1 -0
  79. package/src/modes/index.js +83 -0
  80. package/src/modes/index.js.map +1 -0
  81. package/src/modes/index.ts +16 -0
  82. package/src/modes/real-time.d.ts +58 -0
  83. package/src/modes/real-time.d.ts.map +1 -0
  84. package/src/modes/real-time.js +196 -0
  85. package/src/modes/real-time.js.map +1 -0
  86. package/src/modes/real-time.ts +257 -0
  87. package/src/modes/research.d.ts +79 -0
  88. package/src/modes/research.d.ts.map +1 -0
  89. package/src/modes/research.js +389 -0
  90. package/src/modes/research.js.map +1 -0
  91. package/src/modes/research.ts +486 -0
  92. package/src/modes/tmp.json +0 -0
  93. package/src/pattern-learner.d.ts +117 -0
  94. package/src/pattern-learner.d.ts.map +1 -0
  95. package/src/pattern-learner.js +603 -0
  96. package/src/pattern-learner.js.map +1 -0
  97. package/src/pattern-learner.ts +757 -0
  98. package/src/reasoning-bank.d.ts +259 -0
  99. package/src/reasoning-bank.d.ts.map +1 -0
  100. package/src/reasoning-bank.js +993 -0
  101. package/src/reasoning-bank.js.map +1 -0
  102. package/src/reasoning-bank.ts +1279 -0
  103. package/src/reasoningbank-adapter.ts +697 -0
  104. package/src/sona-integration.d.ts +168 -0
  105. package/src/sona-integration.d.ts.map +1 -0
  106. package/src/sona-integration.js +316 -0
  107. package/src/sona-integration.js.map +1 -0
  108. package/src/sona-integration.ts +432 -0
  109. package/src/sona-manager.d.ts +147 -0
  110. package/src/sona-manager.d.ts.map +1 -0
  111. package/src/sona-manager.js +695 -0
  112. package/src/sona-manager.js.map +1 -0
  113. package/src/sona-manager.ts +835 -0
  114. package/src/tmp.json +0 -0
  115. package/src/types.d.ts +431 -0
  116. package/src/types.d.ts.map +1 -0
  117. package/src/types.js +11 -0
  118. package/src/types.js.map +1 -0
  119. package/src/types.ts +590 -0
  120. package/tmp.json +0 -0
  121. package/tsconfig.json +9 -0
  122. package/vitest.config.ts +19 -0
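
Most sources ship in triplicate: the TypeScript source (.ts), compiled JavaScript (.js), and type declarations (.d.ts) plus source maps, so the package is usable from both TypeScript and plain JavaScript. As a minimal consumption sketch, assuming the root src/index.ts re-exports the algorithm factories (the exact entry points depend on package.json, which is not expanded in this diff, so treat the bare-specifier import as an assumption):

    // Sketch only — assumes the package's "main"/"exports" point at the
    // compiled src/index.js and that it re-exports the factories below.
    import { createDQN, createDecisionTransformer } from '@sparkleideas/neural';

    const dqn = createDQN();
    const dt = createDecisionTransformer({ contextLength: 10 });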
package/src/algorithms/decision-transformer.ts
@@ -0,0 +1,521 @@
+ /**
+  * Decision Transformer
+  *
+  * Implements a sequence-modeling approach to RL:
+  * - Trajectory as sequence: (R, s, a, R, s, a, ...), matching the
+  *   return-state-action token order used in forward()
+  * - Return-conditioned generation
+  * - Causal transformer attention
+  * - Offline RL from trajectories
+  *
+  * Performance Target: <10ms per forward pass
+  */
+ 
+ import type {
+   DecisionTransformerConfig,
+   Trajectory,
+   TrajectoryStep,
+ } from '../types.js';
+ 
+ /**
+  * Default Decision Transformer configuration
+  */
+ export const DEFAULT_DT_CONFIG: DecisionTransformerConfig = {
+   algorithm: 'decision-transformer',
+   learningRate: 0.0001,
+   gamma: 0.99,
+   entropyCoef: 0,
+   valueLossCoef: 0,
+   maxGradNorm: 1.0,
+   epochs: 1,
+   miniBatchSize: 64,
+   contextLength: 20,
+   numHeads: 4,
+   numLayers: 2,
+   hiddenDim: 64,
+   embeddingDim: 32,
+   dropout: 0.1,
+ };
+ 
+ /**
+  * Sequence entry for transformer
+  */
+ interface SequenceEntry {
+   returnToGo: number;
+   state: Float32Array;
+   action: number;
+   timestep: number;
+ }
+ 
+ /**
+  * Decision Transformer Implementation
+  */
+ export class DecisionTransformer {
+   private config: DecisionTransformerConfig;
+ 
+   // Embeddings
+   private stateEmbed: Float32Array;
+   private actionEmbed: Float32Array;
+   private returnEmbed: Float32Array;
+   private posEmbed: Float32Array;
+ 
+   // Transformer layers (simplified)
+   private attentionWeights: Float32Array[][];
+   private ffnWeights: Float32Array[][];
+ 
+   // Output head
+   private actionHead: Float32Array;
+ 
+   // Training buffer
+   private trajectoryBuffer: Trajectory[] = [];
+ 
+   // Dimensions
+   private stateDim = 768;
+   private numActions = 4;
+ 
+   // Statistics
+   private updateCount = 0;
+   private avgLoss = 0;
+ 
+   constructor(config: Partial<DecisionTransformerConfig> = {}) {
+     this.config = { ...DEFAULT_DT_CONFIG, ...config };
+ 
+     // Initialize embeddings
+     this.stateEmbed = this.initEmbedding(this.stateDim, this.config.embeddingDim);
+     this.actionEmbed = this.initEmbedding(this.numActions, this.config.embeddingDim);
+     this.returnEmbed = this.initEmbedding(1, this.config.embeddingDim);
+     this.posEmbed = this.initEmbedding(this.config.contextLength * 3, this.config.embeddingDim);
+ 
+     // Initialize transformer layers
+     this.attentionWeights = [];
+     this.ffnWeights = [];
+ 
+     for (let l = 0; l < this.config.numLayers; l++) {
+       // Attention: Q, K, V, O projections
+       this.attentionWeights.push([
+         this.initWeight(this.config.embeddingDim, this.config.hiddenDim), // Q
+         this.initWeight(this.config.embeddingDim, this.config.hiddenDim), // K
+         this.initWeight(this.config.embeddingDim, this.config.hiddenDim), // V
+         this.initWeight(this.config.hiddenDim, this.config.embeddingDim), // O
+       ]);
+ 
+       // FFN: up and down projections
+       this.ffnWeights.push([
+         this.initWeight(this.config.embeddingDim, this.config.hiddenDim * 4),
+         this.initWeight(this.config.hiddenDim * 4, this.config.embeddingDim),
+       ]);
+     }
+ 
+     // Action prediction head
+     this.actionHead = this.initWeight(this.config.embeddingDim, this.numActions);
+   }
+ 
+   /**
+    * Add trajectory for training
+    */
+   addTrajectory(trajectory: Trajectory): void {
+     if (trajectory.isComplete && trajectory.steps.length > 0) {
+       this.trajectoryBuffer.push(trajectory);
+ 
+       // Keep buffer bounded
+       if (this.trajectoryBuffer.length > 1000) {
+         this.trajectoryBuffer = this.trajectoryBuffer.slice(-1000);
+       }
+     }
+   }
+ 
+   /**
+    * Train on buffered trajectories
+    * Target: <10ms per batch
+    */
+   train(): { loss: number; accuracy: number } {
+     const startTime = performance.now();
+ 
+     if (this.trajectoryBuffer.length === 0) {
+       return { loss: 0, accuracy: 0 };
+     }
+ 
+     // Sample mini-batch of trajectories
+     const batchSize = Math.min(this.config.miniBatchSize, this.trajectoryBuffer.length);
+     const batch: Trajectory[] = [];
+ 
+     for (let i = 0; i < batchSize; i++) {
+       const idx = Math.floor(Math.random() * this.trajectoryBuffer.length);
+       batch.push(this.trajectoryBuffer[idx]);
+     }
+ 
+     let totalLoss = 0;
+     let correct = 0;
+     let total = 0;
+ 
+     for (const trajectory of batch) {
+       // Create sequence from trajectory
+       const sequence = this.createSequence(trajectory);
+ 
+       if (sequence.length < 2) continue;
+ 
+       // Forward pass and compute loss
+       for (let t = 1; t < sequence.length; t++) {
+         // Use context up to position t
+         const context = sequence.slice(Math.max(0, t - this.config.contextLength), t);
+         const target = sequence[t];
+ 
+         // Predict action
+         const predicted = this.forward(context);
+         const predictedAction = this.argmax(predicted);
+ 
+         // Cross-entropy loss
+         const loss = -Math.log(predicted[target.action] + 1e-8);
+         totalLoss += loss;
+ 
+         if (predictedAction === target.action) {
+           correct++;
+         }
+         total++;
+ 
+         // Gradient update (simplified)
+         this.updateWeights(context, target.action, predicted);
+       }
+     }
+ 
+     this.updateCount++;
+     this.avgLoss = total > 0 ? totalLoss / total : 0;
+ 
+     const elapsed = performance.now() - startTime;
+     if (elapsed > 10) {
+       console.warn(`DT training exceeded target: ${elapsed.toFixed(2)}ms > 10ms`);
+     }
+ 
+     return {
+       loss: this.avgLoss,
+       accuracy: total > 0 ? correct / total : 0,
+     };
+   }
+ 
+   /**
+    * Get action conditioned on target return
+    */
+   getAction(
+     states: Float32Array[],
+     actions: number[],
+     targetReturn: number
+   ): number {
+     // Build sequence
+     const sequence: SequenceEntry[] = [];
+     let returnToGo = targetReturn;
+ 
+     for (let i = 0; i < states.length; i++) {
+       sequence.push({
+         returnToGo,
+         state: states[i],
+         action: actions[i] ?? 0,
+         timestep: i,
+       });
+ 
+       // Decrease return-to-go by estimated reward
+       if (i > 0) {
+         returnToGo -= 0.1; // Default reward decrement for inference
+       }
+     }
+ 
+     // Forward pass
+     const logits = this.forward(sequence);
+     return this.argmax(logits);
+   }
+ 
+   /**
+    * Forward pass through transformer
+    */
+   forward(sequence: SequenceEntry[]): Float32Array {
+     // Embed sequence elements
+     const seqLen = Math.min(sequence.length, this.config.contextLength);
+     const embedDim = this.config.embeddingDim;
+ 
+     // Initialize hidden states (simplified: stack all modalities)
+     const hidden = new Float32Array(seqLen * 3 * embedDim);
+ 
+     for (let t = 0; t < seqLen; t++) {
+       const entry = sequence[sequence.length - seqLen + t];
+       const baseIdx = t * 3 * embedDim;
+ 
+       // Embed return
+       for (let d = 0; d < embedDim; d++) {
+         hidden[baseIdx + d] = entry.returnToGo * this.returnEmbed[d];
+       }
+ 
+       // Embed state
+       for (let d = 0; d < embedDim; d++) {
+         let stateSum = 0;
+         for (let s = 0; s < Math.min(entry.state.length, this.stateDim); s++) {
+           stateSum += entry.state[s] * this.stateEmbed[s * embedDim + d];
+         }
+         hidden[baseIdx + embedDim + d] = stateSum;
+       }
+ 
+       // Embed action
+       for (let d = 0; d < embedDim; d++) {
+         hidden[baseIdx + 2 * embedDim + d] = this.actionEmbed[entry.action * embedDim + d];
+       }
+ 
+       // Add positional embedding
+       for (let d = 0; d < 3 * embedDim; d++) {
+         hidden[baseIdx + d] += this.posEmbed[t * 3 * embedDim + d] || 0;
+       }
+     }
+ 
+     // Apply transformer layers
+     for (let l = 0; l < this.config.numLayers; l++) {
+       hidden.set(this.transformerLayer(hidden, seqLen * 3, l));
+     }
+ 
+     // Extract last state position embedding for action prediction
+     const lastStateIdx = (seqLen * 3 - 2) * embedDim;
+     const lastState = hidden.slice(lastStateIdx, lastStateIdx + embedDim);
+ 
+     // Action prediction
+     const logits = new Float32Array(this.numActions);
+     for (let a = 0; a < this.numActions; a++) {
+       let sum = 0;
+       for (let d = 0; d < embedDim; d++) {
+         sum += lastState[d] * this.actionHead[d * this.numActions + a];
+       }
+       logits[a] = sum;
+     }
+ 
+     return this.softmax(logits);
+   }
+ 
+   /**
+    * Get statistics
+    */
+   getStats(): Record<string, number> {
+     return {
+       updateCount: this.updateCount,
+       bufferSize: this.trajectoryBuffer.length,
+       avgLoss: this.avgLoss,
+       contextLength: this.config.contextLength,
+       numLayers: this.config.numLayers,
+     };
+   }
+ 
+   // ==========================================================================
+   // Private Methods
+   // ==========================================================================
+ 
+   private initEmbedding(inputDim: number, outputDim: number): Float32Array {
+     const embed = new Float32Array(inputDim * outputDim);
+     const scale = Math.sqrt(2 / inputDim);
+     for (let i = 0; i < embed.length; i++) {
+       embed[i] = (Math.random() - 0.5) * scale;
+     }
+     return embed;
+   }
+ 
+   private initWeight(inputDim: number, outputDim: number): Float32Array {
+     const weight = new Float32Array(inputDim * outputDim);
+     const scale = Math.sqrt(2 / inputDim);
+     for (let i = 0; i < weight.length; i++) {
+       weight[i] = (Math.random() - 0.5) * scale;
+     }
+     return weight;
+   }
+ 
+   private createSequence(trajectory: Trajectory): SequenceEntry[] {
+     const sequence: SequenceEntry[] = [];
+ 
+     // Compute returns-to-go
+     const rewards = trajectory.steps.map(s => s.reward);
+     const returnsToGo = new Array(rewards.length).fill(0);
+     let cumReturn = 0;
+ 
+     for (let t = rewards.length - 1; t >= 0; t--) {
+       cumReturn = rewards[t] + this.config.gamma * cumReturn;
+       returnsToGo[t] = cumReturn;
+     }
+ 
+     // Create sequence entries
+     for (let t = 0; t < trajectory.steps.length; t++) {
+       sequence.push({
+         returnToGo: returnsToGo[t],
+         state: trajectory.steps[t].stateAfter,
+         action: this.hashAction(trajectory.steps[t].action),
+         timestep: t,
+       });
+     }
+ 
+     return sequence;
+   }
+ 
+   private transformerLayer(
+     hidden: Float32Array,
+     seqLen: number,
+     layerIdx: number
+   ): Float32Array {
+     const embedDim = this.config.embeddingDim;
+     const hiddenDim = this.config.hiddenDim;
+     const numHeads = this.config.numHeads;
+     const headDim = hiddenDim / numHeads;
+ 
+     const output = new Float32Array(hidden.length);
+ 
+     // Self-attention (simplified causal)
+     const [Wq, Wk, Wv, Wo] = this.attentionWeights[layerIdx];
+ 
+     // Compute Q, K, V for all positions
+     const Q = new Float32Array(seqLen * hiddenDim);
+     const K = new Float32Array(seqLen * hiddenDim);
+     const V = new Float32Array(seqLen * hiddenDim);
+ 
+     for (let pos = 0; pos < seqLen; pos++) {
+       for (let h = 0; h < hiddenDim; h++) {
+         let qSum = 0, kSum = 0, vSum = 0;
+         for (let d = 0; d < embedDim; d++) {
+           const hiddenVal = hidden[pos * embedDim + d];
+           qSum += hiddenVal * Wq[d * hiddenDim + h];
+           kSum += hiddenVal * Wk[d * hiddenDim + h];
+           vSum += hiddenVal * Wv[d * hiddenDim + h];
+         }
+         Q[pos * hiddenDim + h] = qSum;
+         K[pos * hiddenDim + h] = kSum;
+         V[pos * hiddenDim + h] = vSum;
+       }
+     }
+ 
+     // Causal attention
+     for (let pos = 0; pos < seqLen; pos++) {
+       // Compute attention scores for current position
+       const scores = new Float32Array(pos + 1);
+       for (let k = 0; k <= pos; k++) {
+         let score = 0;
+         for (let h = 0; h < hiddenDim; h++) {
+           score += Q[pos * hiddenDim + h] * K[k * hiddenDim + h];
+         }
+         scores[k] = score / Math.sqrt(headDim);
+       }
+ 
+       // Softmax
+       const maxScore = Math.max(...scores);
+       let sumExp = 0;
+       for (let k = 0; k <= pos; k++) {
+         scores[k] = Math.exp(scores[k] - maxScore);
+         sumExp += scores[k];
+       }
+       for (let k = 0; k <= pos; k++) {
+         scores[k] /= sumExp;
+       }
+ 
+       // Weighted sum of values
+       const attnOut = new Float32Array(hiddenDim);
+       for (let k = 0; k <= pos; k++) {
+         for (let h = 0; h < hiddenDim; h++) {
+           attnOut[h] += scores[k] * V[k * hiddenDim + h];
+         }
+       }
+ 
+       // Output projection
+       for (let d = 0; d < embedDim; d++) {
+         let sum = hidden[pos * embedDim + d]; // Residual
+         for (let h = 0; h < hiddenDim; h++) {
+           sum += attnOut[h] * Wo[h * embedDim + d];
+         }
+         output[pos * embedDim + d] = sum;
+       }
+     }
+ 
+     // FFN with residual
+     const [Wup, Wdown] = this.ffnWeights[layerIdx];
+     const ffnHiddenDim = hiddenDim * 4;
+ 
+     for (let pos = 0; pos < seqLen; pos++) {
+       // Up projection + GELU
+       const ffnHidden = new Float32Array(ffnHiddenDim);
+       for (let h = 0; h < ffnHiddenDim; h++) {
+         let sum = 0;
+         for (let d = 0; d < embedDim; d++) {
+           sum += output[pos * embedDim + d] * Wup[d * ffnHiddenDim + h];
+         }
+         // GELU approximation
+         ffnHidden[h] = sum * 0.5 * (1 + Math.tanh(0.7978845608 * (sum + 0.044715 * sum * sum * sum)));
+       }
+ 
+       // Down projection
+       for (let d = 0; d < embedDim; d++) {
+         let sum = output[pos * embedDim + d]; // Residual
+         for (let h = 0; h < ffnHiddenDim; h++) {
+           sum += ffnHidden[h] * Wdown[h * embedDim + d];
+         }
+         output[pos * embedDim + d] = sum;
+       }
+     }
+ 
+     return output;
+   }
+ 
+   private updateWeights(
+     context: SequenceEntry[],
+     targetAction: number,
+     predicted: Float32Array
+   ): void {
+     // Simplified gradient update for action head
+     const lr = this.config.learningRate;
+     const embedDim = this.config.embeddingDim;
+ 
+     // Gradient of cross-entropy
+     const grad = new Float32Array(this.numActions);
+     for (let a = 0; a < this.numActions; a++) {
+       grad[a] = predicted[a] - (a === targetAction ? 1 : 0);
+     }
+ 
+     // Update action head (simplified)
+     for (let d = 0; d < embedDim; d++) {
+       for (let a = 0; a < this.numActions; a++) {
+         this.actionHead[d * this.numActions + a] -= lr * grad[a] * 0.1;
+       }
+     }
+   }
+ 
+   private softmax(logits: Float32Array): Float32Array {
+     const max = Math.max(...logits);
+     const exps = new Float32Array(logits.length);
+     let sum = 0;
+ 
+     for (let i = 0; i < logits.length; i++) {
+       exps[i] = Math.exp(logits[i] - max);
+       sum += exps[i];
+     }
+ 
+     for (let i = 0; i < exps.length; i++) {
+       exps[i] /= sum;
+     }
+ 
+     return exps;
+   }
+ 
+   private argmax(values: Float32Array): number {
+     let maxIdx = 0;
+     let maxVal = values[0];
+     for (let i = 1; i < values.length; i++) {
+       if (values[i] > maxVal) {
+         maxVal = values[i];
+         maxIdx = i;
+       }
+     }
+     return maxIdx;
+   }
+ 
+   private hashAction(action: string): number {
+     let hash = 0;
+     for (let i = 0; i < action.length; i++) {
+       hash = (hash * 31 + action.charCodeAt(i)) % this.numActions;
+     }
+     return hash;
+   }
+ }
+ 
+ /**
+  * Factory function
+  */
+ export function createDecisionTransformer(
+   config?: Partial<DecisionTransformerConfig>
+ ): DecisionTransformer {
+   return new DecisionTransformer(config);
+ }
package/src/algorithms/dqn.d.ts
@@ -0,0 +1,72 @@
+ /**
+  * Deep Q-Network (DQN)
+  *
+  * Implements DQN with enhancements:
+  * - Experience replay
+  * - Target network
+  * - Double DQN (optional)
+  * - Dueling architecture (optional)
+  * - Epsilon-greedy exploration
+  *
+  * Performance Target: <10ms per update step
+  */
+ import type { DQNConfig, Trajectory } from '../types.js';
+ /**
+  * Default DQN configuration
+  */
+ export declare const DEFAULT_DQN_CONFIG: DQNConfig;
+ /**
+  * DQN Algorithm Implementation
+  */
+ export declare class DQNAlgorithm {
+   private config;
+   private qWeights;
+   private targetWeights;
+   private qMomentum;
+   private buffer;
+   private bufferIdx;
+   private epsilon;
+   private stepCount;
+   private numActions;
+   private inputDim;
+   private updateCount;
+   private avgLoss;
+   constructor(config?: Partial<DQNConfig>);
+   /**
+    * Add experience from trajectory
+    */
+   addExperience(trajectory: Trajectory): void;
+   /**
+    * Perform DQN update
+    * Target: <10ms
+    */
+   update(): {
+     loss: number;
+     epsilon: number;
+   };
+   /**
+    * Get action using epsilon-greedy
+    */
+   getAction(state: Float32Array, explore?: boolean): number;
+   /**
+    * Get Q-values for a state
+    */
+   getQValues(state: Float32Array): Float32Array;
+   /**
+    * Get statistics
+    */
+   getStats(): Record<string, number>;
+   private initializeNetwork;
+   private copyNetwork;
+   private forward;
+   private accumulateGradients;
+   private applyGradients;
+   private sampleBatch;
+   private hashAction;
+   private argmax;
+ }
+ /**
+  * Factory function
+  */
+ export declare function createDQN(config?: Partial<DQNConfig>): DQNAlgorithm;
+ //# sourceMappingURL=dqn.d.ts.map
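
Since this declaration file only exposes the public surface (the implementation lives in dqn.ts/dqn.js), here is a short sketch against that surface. DQNConfig's fields come from src/types.ts, which is not expanded in this diff, and the 768-dimension state merely mirrors the Decision Transformer's default (DQN's inputDim is private), so both are assumptions:

    // Hypothetical consumer code against the declared API only.
    import { createDQN } from './dqn.js';
    import type { Trajectory } from '../types.js';

    declare const trajectory: Trajectory; // produced elsewhere

    const dqn = createDQN(); // falls back to DEFAULT_DQN_CONFIG
    dqn.addExperience(trajectory);          // fill the replay buffer
    const { loss, epsilon } = dqn.update(); // one update step (<10ms target)

    const state = new Float32Array(768);        // dimension assumed
    const greedy = dqn.getAction(state, false); // explore=false → greedy action
    const qValues = dqn.getQValues(state);      // per-action estimates
    console.log(dqn.getStats(), { loss, epsilon, greedy, qValues });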
package/src/algorithms/dqn.d.ts.map
@@ -0,0 +1 @@
+ {"version":3,"file":"dqn.d.ts","sourceRoot":"","sources":["dqn.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;GAWG;AAEH,OAAO,KAAK,EACV,SAAS,EACT,UAAU,EAEX,MAAM,aAAa,CAAC;AAErB;;GAEG;AACH,eAAO,MAAM,kBAAkB,EAAE,SAgBhC,CAAC;AAaF;;GAEG;AACH,qBAAa,YAAY;IACvB,OAAO,CAAC,MAAM,CAAY;IAG1B,OAAO,CAAC,QAAQ,CAAiB;IACjC,OAAO,CAAC,aAAa,CAAiB;IAGtC,OAAO,CAAC,SAAS,CAAiB;IAGlC,OAAO,CAAC,MAAM,CAAuB;IACrC,OAAO,CAAC,SAAS,CAAK;IAGtB,OAAO,CAAC,OAAO,CAAS;IACxB,OAAO,CAAC,SAAS,CAAK;IAGtB,OAAO,CAAC,UAAU,CAAK;IACvB,OAAO,CAAC,QAAQ,CAAO;IAGvB,OAAO,CAAC,WAAW,CAAK;IACxB,OAAO,CAAC,OAAO,CAAK;gBAER,MAAM,GAAE,OAAO,CAAC,SAAS,CAAM;IAU3C;;OAEG;IACH,aAAa,CAAC,UAAU,EAAE,UAAU,GAAG,IAAI;IAyB3C;;;OAGG;IACH,MAAM,IAAI;QAAE,IAAI,EAAE,MAAM,CAAC;QAAC,OAAO,EAAE,MAAM,CAAA;KAAE;IA2E3C;;OAEG;IACH,SAAS,CAAC,KAAK,EAAE,YAAY,EAAE,OAAO,GAAE,OAAc,GAAG,MAAM;IAS/D;;OAEG;IACH,UAAU,CAAC,KAAK,EAAE,YAAY,GAAG,YAAY;IAI7C;;OAEG;IACH,QAAQ,IAAI,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC;IAclC,OAAO,CAAC,iBAAiB;IAwBzB,OAAO,CAAC,WAAW;IAInB,OAAO,CAAC,OAAO;IA0Bf,OAAO,CAAC,mBAAmB;IAkC3B,OAAO,CAAC,cAAc;IAmBtB,OAAO,CAAC,WAAW;IAenB,OAAO,CAAC,UAAU;IAQlB,OAAO,CAAC,MAAM;CAWf;AAED;;GAEG;AACH,wBAAgB,SAAS,CAAC,MAAM,CAAC,EAAE,OAAO,CAAC,SAAS,CAAC,GAAG,YAAY,CAEnE"}