@arcanea/guardian-evolution 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/algorithms/a2c.d.ts +86 -0
- package/dist/algorithms/a2c.d.ts.map +1 -0
- package/dist/algorithms/a2c.js +361 -0
- package/dist/algorithms/a2c.js.map +1 -0
- package/dist/algorithms/curiosity.d.ts +82 -0
- package/dist/algorithms/curiosity.d.ts.map +1 -0
- package/dist/algorithms/curiosity.js +392 -0
- package/dist/algorithms/curiosity.js.map +1 -0
- package/dist/algorithms/decision-transformer.d.ts +82 -0
- package/dist/algorithms/decision-transformer.d.ts.map +1 -0
- package/dist/algorithms/decision-transformer.js +415 -0
- package/dist/algorithms/decision-transformer.js.map +1 -0
- package/dist/algorithms/dqn.d.ts +72 -0
- package/dist/algorithms/dqn.d.ts.map +1 -0
- package/dist/algorithms/dqn.js +303 -0
- package/dist/algorithms/dqn.js.map +1 -0
- package/dist/algorithms/index.d.ts +32 -0
- package/dist/algorithms/index.d.ts.map +1 -0
- package/dist/algorithms/index.js +74 -0
- package/dist/algorithms/index.js.map +1 -0
- package/dist/algorithms/ppo.d.ts +72 -0
- package/dist/algorithms/ppo.d.ts.map +1 -0
- package/dist/algorithms/ppo.js +331 -0
- package/dist/algorithms/ppo.js.map +1 -0
- package/dist/algorithms/q-learning.d.ts +77 -0
- package/dist/algorithms/q-learning.d.ts.map +1 -0
- package/dist/algorithms/q-learning.js +259 -0
- package/dist/algorithms/q-learning.js.map +1 -0
- package/dist/algorithms/sarsa.d.ts +82 -0
- package/dist/algorithms/sarsa.d.ts.map +1 -0
- package/dist/algorithms/sarsa.js +297 -0
- package/dist/algorithms/sarsa.js.map +1 -0
- package/dist/index.d.ts +118 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +201 -0
- package/dist/index.js.map +1 -0
- package/dist/modes/balanced.d.ts +60 -0
- package/dist/modes/balanced.d.ts.map +1 -0
- package/dist/modes/balanced.js +234 -0
- package/dist/modes/balanced.js.map +1 -0
- package/dist/modes/batch.d.ts +82 -0
- package/dist/modes/batch.d.ts.map +1 -0
- package/dist/modes/batch.js +316 -0
- package/dist/modes/batch.js.map +1 -0
- package/dist/modes/edge.d.ts +85 -0
- package/dist/modes/edge.d.ts.map +1 -0
- package/dist/modes/edge.js +310 -0
- package/dist/modes/edge.js.map +1 -0
- package/dist/modes/index.d.ts +55 -0
- package/dist/modes/index.d.ts.map +1 -0
- package/dist/modes/index.js +83 -0
- package/dist/modes/index.js.map +1 -0
- package/dist/modes/real-time.d.ts +58 -0
- package/dist/modes/real-time.d.ts.map +1 -0
- package/dist/modes/real-time.js +196 -0
- package/dist/modes/real-time.js.map +1 -0
- package/dist/modes/research.d.ts +79 -0
- package/dist/modes/research.d.ts.map +1 -0
- package/dist/modes/research.js +389 -0
- package/dist/modes/research.js.map +1 -0
- package/dist/pattern-learner.d.ts +117 -0
- package/dist/pattern-learner.d.ts.map +1 -0
- package/dist/pattern-learner.js +603 -0
- package/dist/pattern-learner.js.map +1 -0
- package/dist/reasoning-bank.d.ts +259 -0
- package/dist/reasoning-bank.d.ts.map +1 -0
- package/dist/reasoning-bank.js +993 -0
- package/dist/reasoning-bank.js.map +1 -0
- package/dist/reasoningbank-adapter.d.ts +168 -0
- package/dist/reasoningbank-adapter.d.ts.map +1 -0
- package/dist/reasoningbank-adapter.js +463 -0
- package/dist/reasoningbank-adapter.js.map +1 -0
- package/dist/sona-integration.d.ts +168 -0
- package/dist/sona-integration.d.ts.map +1 -0
- package/dist/sona-integration.js +316 -0
- package/dist/sona-integration.js.map +1 -0
- package/dist/sona-manager.d.ts +147 -0
- package/dist/sona-manager.d.ts.map +1 -0
- package/dist/sona-manager.js +695 -0
- package/dist/sona-manager.js.map +1 -0
- package/dist/types.d.ts +431 -0
- package/dist/types.d.ts.map +1 -0
- package/dist/types.js +11 -0
- package/dist/types.js.map +1 -0
- package/package.json +47 -0
|
@@ -0,0 +1,392 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Curiosity-Driven Exploration
|
|
3
|
+
*
|
|
4
|
+
* Implements intrinsic motivation for exploration:
|
|
5
|
+
* - Intrinsic Curiosity Module (ICM)
|
|
6
|
+
* - Random Network Distillation (RND)
|
|
7
|
+
* - Forward and inverse dynamics models
|
|
8
|
+
* - Exploration bonus generation
|
|
9
|
+
*
|
|
10
|
+
* Performance Target: <5ms per forward pass
|
|
11
|
+
*/
|
|
12
|
+
/**
|
|
13
|
+
* Default Curiosity configuration
|
|
14
|
+
*/
|
|
15
|
+
export const DEFAULT_CURIOSITY_CONFIG = {
    algorithm: 'curiosity',
    // Generic RL hyperparameters. Within CuriosityModule only `learningRate` is read
    // (as the step size of the RND predictor); the others are carried along for
    // whatever consumes this config elsewhere — TODO confirm against callers.
    learningRate: 1e-4,
    gamma: 0.99,
    entropyCoef: 0.01,
    valueLossCoef: 0.5,
    maxGradNorm: 0.5,
    epochs: 1,
    miniBatchSize: 32,
    // Multiplier applied to the normalized intrinsic reward before it is returned.
    intrinsicCoef: 1e-1,
    // Step sizes for the forward- and inverse-dynamics models respectively.
    forwardLR: 1e-3,
    inverseLR: 1e-3,
    // Width of the encoded feature space (also the RND output width).
    featureDim: 64,
    // false: ICM forward-model prediction error; true: RND prediction error.
    useRND: false,
};
|
|
30
|
+
/**
|
|
31
|
+
* Curiosity-Driven Exploration Module
|
|
32
|
+
*/
|
|
33
|
+
export class CuriosityModule {
    /** Effective configuration: DEFAULT_CURIOSITY_CONFIG merged with constructor overrides. */
    config;
    // Feature encoder weights, flattened as [stateDim][featureDim]. Never trained here.
    featureEncoder;
    // Forward dynamics model: predicts next feature from current feature + one-hot action.
    // Flattened as [featureDim + numActions][featureDim].
    forwardModel;
    // Inverse dynamics model: predicts action probabilities from current and next features.
    // Flattened as [2 * featureDim][numActions].
    inverseModel;
    // RND target (fixed random projection, never updated) and predictor (trained to match it).
    rndTarget;
    rndPredictor;
    // Momentum buffers, one per trained weight matrix.
    forwardMomentum;
    inverseMomentum;
    rndMomentum;
    // Dimensions
    stateDim = 768;
    numActions = 4;
    // Running statistics used to normalize raw prediction errors.
    intrinsicMean = 0;
    intrinsicVar = 1;
    updateCount = 0;
    // Statistics reported by getStats().
    avgForwardLoss = 0;
    avgInverseLoss = 0;
    // Exponential moving average of the (scaled) intrinsic rewards returned to callers.
    avgIntrinsicReward = 0;
    /**
     * @param config Partial overrides for DEFAULT_CURIOSITY_CONFIG.
     */
    constructor(config = {}) {
        this.config = { ...DEFAULT_CURIOSITY_CONFIG, ...config };
        const featureDim = this.config.featureDim;
        // Initialization order is kept stable so the Math.random() draw sequence is unchanged.
        this.featureEncoder = this.initWeight(this.stateDim, featureDim);
        this.forwardModel = this.initWeight(featureDim + this.numActions, featureDim);
        this.inverseModel = this.initWeight(2 * featureDim, this.numActions);
        this.rndTarget = this.initWeight(this.stateDim, featureDim);
        this.rndPredictor = this.initWeight(this.stateDim, featureDim);
        this.forwardMomentum = new Float32Array(this.forwardModel.length);
        this.inverseMomentum = new Float32Array(this.inverseModel.length);
        this.rndMomentum = new Float32Array(this.rndPredictor.length);
    }
    /**
     * Compute the intrinsic reward for a transition, dispatching on `config.useRND`.
     * With RND only `nextState` is used; with ICM all three arguments are used.
     */
    computeIntrinsicReward(state, action, nextState) {
        return this.config.useRND
            ? this.computeRNDReward(nextState)
            : this.computeICMReward(state, action, nextState);
    }
    /**
     * ICM intrinsic reward: squared error between the forward model's predicted next
     * feature and the actual encoded next state, normalized, clipped to [-5, 5] and
     * scaled by `config.intrinsicCoef`. Updates the running normalization statistics
     * and `avgIntrinsicReward`.
     */
    computeICMReward(state, action, nextState) {
        const startTime = performance.now();
        const stateFeature = this.encodeState(state);
        const nextStateFeature = this.encodeState(nextState);
        const predictedFeature = this.forwardPredict(stateFeature, this.hashAction(action));
        // The forward model's surprise is the novelty signal.
        const error = this.squaredError(predictedFeature, nextStateFeature);
        const reward = this.normalizeIntrinsic(error) * this.config.intrinsicCoef;
        this.recordIntrinsicReward(reward);
        this.warnIfSlow('ICM reward', startTime, 5);
        return reward;
    }
    /**
     * RND intrinsic reward: squared error between the trained predictor and the fixed
     * random target network on `state`, normalized, clipped to [-5, 5] and scaled by
     * `config.intrinsicCoef`. Updates the running normalization statistics and
     * `avgIntrinsicReward`.
     */
    computeRNDReward(state) {
        const startTime = performance.now();
        const targetOutput = this.rndForward(state, this.rndTarget);
        const predictorOutput = this.rndForward(state, this.rndPredictor);
        const error = this.squaredError(predictorOutput, targetOutput);
        const reward = this.normalizeIntrinsic(error) * this.config.intrinsicCoef;
        this.recordIntrinsicReward(reward);
        this.warnIfSlow('RND reward', startTime, 5);
        return reward;
    }
    /**
     * Train the forward/inverse models (and the RND predictor when `useRND`) on every
     * consecutive pair of steps in `trajectory.steps`.
     *
     * @returns Mean forward and inverse loss over the pairs; zeros when fewer than two steps.
     */
    update(trajectory) {
        const startTime = performance.now();
        const steps = trajectory.steps;
        if (steps.length < 2) {
            return { forwardLoss: 0, inverseLoss: 0 };
        }
        let totalForwardLoss = 0;
        let totalInverseLoss = 0;
        const count = steps.length - 1;
        // The encoder is never trained, so the "next" feature of pair i can be reused as
        // the "current" feature of pair i + 1 instead of encoding every state twice.
        let stateFeature = this.encodeState(steps[0].stateAfter);
        for (let i = 0; i < count; i++) {
            const step = steps[i];
            const nextStep = steps[i + 1];
            const nextStateFeature = this.encodeState(nextStep.stateAfter);
            // NOTE(review): step.action is paired with the stateAfter[i] -> stateAfter[i+1]
            // transition; confirm against the Trajectory semantics whether nextStep.action
            // is the action that produced that transition.
            const actionIdx = this.hashAction(step.action);
            totalForwardLoss += this.updateForwardModel(stateFeature, actionIdx, nextStateFeature);
            totalInverseLoss += this.updateInverseModel(stateFeature, nextStateFeature, actionIdx);
            if (this.config.useRND) {
                this.updateRNDPredictor(nextStep.stateAfter);
            }
            stateFeature = nextStateFeature;
        }
        this.updateCount++;
        this.avgForwardLoss = totalForwardLoss / count;
        this.avgInverseLoss = totalInverseLoss / count;
        this.warnIfSlow('Curiosity update', startTime, 10);
        return {
            forwardLoss: this.avgForwardLoss,
            inverseLoss: this.avgInverseLoss,
        };
    }
    /**
     * Return a shallow copy of `trajectory` whose steps (except the last) have the
     * intrinsic reward added to `reward`. The input trajectory and its steps are not mutated.
     */
    augmentTrajectory(trajectory) {
        const augmented = { ...trajectory, steps: [...trajectory.steps] };
        for (let i = 0; i < augmented.steps.length - 1; i++) {
            const step = augmented.steps[i];
            const nextStep = augmented.steps[i + 1];
            const intrinsic = this.computeIntrinsicReward(step.stateAfter, step.action, nextStep.stateAfter);
            augmented.steps[i] = { ...step, reward: step.reward + intrinsic };
        }
        return augmented;
    }
    /**
     * Snapshot of training and reward statistics.
     */
    getStats() {
        return {
            updateCount: this.updateCount,
            avgForwardLoss: this.avgForwardLoss,
            avgInverseLoss: this.avgInverseLoss,
            avgIntrinsicReward: this.avgIntrinsicReward,
            intrinsicMean: this.intrinsicMean,
            intrinsicStd: Math.sqrt(this.intrinsicVar),
        };
    }
    // ==========================================================================
    // Private Methods
    // ==========================================================================
    /** Random weight matrix of size inputDim * outputDim, uniform in ±0.5 * sqrt(2 / inputDim). */
    initWeight(inputDim, outputDim) {
        const weight = new Float32Array(inputDim * outputDim);
        const scale = Math.sqrt(2 / inputDim);
        for (let i = 0; i < weight.length; i++) {
            weight[i] = (Math.random() - 0.5) * scale;
        }
        return weight;
    }
    /** Fold an intrinsic reward into the exponential moving average reported by getStats(). */
    recordIntrinsicReward(reward) {
        const alpha = 0.01;
        this.avgIntrinsicReward += alpha * (reward - this.avgIntrinsicReward);
    }
    /** Emit a console warning when more than `budgetMs` has elapsed since `startTime`. */
    warnIfSlow(label, startTime, budgetMs) {
        const elapsed = performance.now() - startTime;
        if (elapsed > budgetMs) {
            console.warn(`${label} exceeded target: ${elapsed.toFixed(2)}ms > ${budgetMs}ms`);
        }
    }
    /** Sum of squared differences over the first `featureDim` entries of `a` and `b`. */
    squaredError(a, b) {
        let error = 0;
        for (let i = 0; i < this.config.featureDim; i++) {
            error += (a[i] - b[i]) ** 2;
        }
        return error;
    }
    /**
     * ReLU(state · weights) with `weights` laid out as [stateDim][featureDim]. Only the
     * first min(state.length, stateDim) state entries contribute.
     */
    projectRelu(state, weights) {
        const featureDim = this.config.featureDim;
        const output = new Float32Array(featureDim);
        const inputLen = Math.min(state.length, this.stateDim);
        for (let f = 0; f < featureDim; f++) {
            let sum = 0;
            for (let s = 0; s < inputLen; s++) {
                sum += state[s] * weights[s * featureDim + f];
            }
            output[f] = Math.max(0, sum);
        }
        return output;
    }
    /** Un-activated linear layer: output[o] = Σ_i input[i] * weights[i * outputDim + o]. */
    linear(input, weights, outputDim) {
        const output = new Float32Array(outputDim);
        for (let o = 0; o < outputDim; o++) {
            let sum = 0;
            for (let i = 0; i < input.length; i++) {
                sum += input[i] * weights[i * outputDim + o];
            }
            output[o] = sum;
        }
        return output;
    }
    /** Momentum-SGD step on a weight matrix whose gradient is the outer product input ⊗ grad. */
    momentumSgdStep(weights, momentum, input, grad, lr) {
        const beta = 0.9;
        const outputDim = grad.length;
        for (let i = 0; i < input.length; i++) {
            for (let o = 0; o < outputDim; o++) {
                const idx = i * outputDim + o;
                momentum[idx] = beta * momentum[idx] + (1 - beta) * (input[i] * grad[o]);
                weights[idx] -= lr * momentum[idx];
            }
        }
    }
    /** Forward-model input: [stateFeature, one-hot(action)]. */
    buildForwardInput(stateFeature, action) {
        const featureDim = this.config.featureDim;
        const input = new Float32Array(featureDim + this.numActions);
        input.set(stateFeature);
        input[featureDim + action] = 1;
        return input;
    }
    /** Inverse-model input: [stateFeature, nextStateFeature]. */
    buildInverseInput(stateFeature, nextStateFeature) {
        const featureDim = this.config.featureDim;
        const input = new Float32Array(2 * featureDim);
        input.set(stateFeature);
        input.set(nextStateFeature, featureDim);
        return input;
    }
    encodeState(state) {
        return this.projectRelu(state, this.featureEncoder);
    }
    forwardPredict(stateFeature, action) {
        return this.linear(this.buildForwardInput(stateFeature, action), this.forwardModel, this.config.featureDim);
    }
    inversePredict(stateFeature, nextStateFeature) {
        const input = this.buildInverseInput(stateFeature, nextStateFeature);
        return this.softmax(this.linear(input, this.inverseModel, this.numActions));
    }
    rndForward(state, weights) {
        return this.projectRelu(state, weights);
    }
    /** One momentum step on the forward model (squared-error loss); returns the pre-update loss. */
    updateForwardModel(stateFeature, action, targetFeature) {
        const featureDim = this.config.featureDim;
        const input = this.buildForwardInput(stateFeature, action);
        const predicted = this.linear(input, this.forwardModel, featureDim);
        let loss = 0;
        const grad = new Float32Array(featureDim);
        for (let f = 0; f < featureDim; f++) {
            const diff = predicted[f] - targetFeature[f];
            loss += diff * diff;
            grad[f] = 2 * diff;
        }
        this.momentumSgdStep(this.forwardModel, this.forwardMomentum, input, grad, this.config.forwardLR);
        return loss;
    }
    /** One momentum step on the inverse model (cross-entropy loss); returns the pre-update loss. */
    updateInverseModel(stateFeature, nextStateFeature, targetAction) {
        const input = this.buildInverseInput(stateFeature, nextStateFeature);
        const probs = this.softmax(this.linear(input, this.inverseModel, this.numActions));
        const loss = -Math.log(probs[targetAction] + 1e-8);
        // d(cross-entropy)/d(logits) = probs - one-hot(target)
        const grad = new Float32Array(this.numActions);
        for (let a = 0; a < this.numActions; a++) {
            grad[a] = probs[a] - (a === targetAction ? 1 : 0);
        }
        this.momentumSgdStep(this.inverseModel, this.inverseMomentum, input, grad, this.config.inverseLR);
        return loss;
    }
    /** One momentum step pulling the RND predictor toward the fixed target on `state`. */
    updateRNDPredictor(state) {
        const featureDim = this.config.featureDim;
        const lr = this.config.learningRate;
        const beta = 0.9;
        const targetOutput = this.rndForward(state, this.rndTarget);
        const predictorOutput = this.rndForward(state, this.rndPredictor);
        const grad = new Float32Array(featureDim);
        for (let f = 0; f < featureDim; f++) {
            grad[f] = 2 * (predictorOutput[f] - targetOutput[f]);
        }
        const inputLen = Math.min(state.length, this.stateDim);
        for (let s = 0; s < inputLen; s++) {
            for (let f = 0; f < featureDim; f++) {
                // Units whose ReLU was inactive receive no gradient.
                if (predictorOutput[f] > 0) {
                    const weightGrad = state[s] * grad[f];
                    const idx = s * featureDim + f;
                    this.rndMomentum[idx] = beta * this.rndMomentum[idx] + (1 - beta) * weightGrad;
                    this.rndPredictor[idx] -= lr * this.rndMomentum[idx];
                }
            }
        }
    }
    /** Update the EMA mean/variance with `raw`, then return the z-score clipped to [-5, 5]. */
    normalizeIntrinsic(raw) {
        const alpha = 0.01;
        this.intrinsicMean = (1 - alpha) * this.intrinsicMean + alpha * raw;
        this.intrinsicVar = (1 - alpha) * this.intrinsicVar + alpha * (raw - this.intrinsicMean) ** 2;
        const normalized = (raw - this.intrinsicMean) / (Math.sqrt(this.intrinsicVar) + 1e-8);
        return Math.max(-5, Math.min(5, normalized));
    }
    /** Numerically stable softmax (max-subtracted). */
    softmax(logits) {
        const max = Math.max(...logits);
        const exps = new Float32Array(logits.length);
        let sum = 0;
        for (let i = 0; i < logits.length; i++) {
            exps[i] = Math.exp(logits[i] - max);
            sum += exps[i];
        }
        for (let i = 0; i < exps.length; i++) {
            exps[i] /= sum;
        }
        return exps;
    }
    /** Map an action string onto [0, numActions) with a rolling base-31 hash. */
    hashAction(action) {
        let hash = 0;
        for (let i = 0; i < action.length; i++) {
            hash = (hash * 31 + action.charCodeAt(i)) % this.numActions;
        }
        return hash;
    }
}
|
|
386
|
+
/**
|
|
387
|
+
* Factory function
|
|
388
|
+
*/
|
|
389
|
+
/**
 * Construct a CuriosityModule; `config` is forwarded unchanged as its overrides.
 */
export function createCuriosity(config) {
    const curiosity = new CuriosityModule(config);
    return curiosity;
}
|
|
392
|
+
//# sourceMappingURL=curiosity.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"curiosity.js","sourceRoot":"","sources":["../../src/algorithms/curiosity.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;GAUG;AAIH;;GAEG;AACH,MAAM,CAAC,MAAM,wBAAwB,GAAoB;IACvD,SAAS,EAAE,WAAW;IACtB,YAAY,EAAE,MAAM;IACpB,KAAK,EAAE,IAAI;IACX,WAAW,EAAE,IAAI;IACjB,aAAa,EAAE,GAAG;IAClB,WAAW,EAAE,GAAG;IAChB,MAAM,EAAE,CAAC;IACT,aAAa,EAAE,EAAE;IACjB,aAAa,EAAE,GAAG;IAClB,SAAS,EAAE,KAAK;IAChB,SAAS,EAAE,KAAK;IAChB,UAAU,EAAE,EAAE;IACd,MAAM,EAAE,KAAK;CACd,CAAC;AAEF;;GAEG;AACH,MAAM,OAAO,eAAe;IAClB,MAAM,CAAkB;IAEhC,kBAAkB;IACV,cAAc,CAAe;IAErC,8EAA8E;IACtE,YAAY,CAAe;IAEnC,yEAAyE;IACjE,YAAY,CAAe;IAEnC,oCAAoC;IAC5B,SAAS,CAAe;IACxB,YAAY,CAAe;IAEnC,kBAAkB;IACV,eAAe,CAAe;IAC9B,eAAe,CAAe;IAC9B,WAAW,CAAe;IAElC,aAAa;IACL,QAAQ,GAAG,GAAG,CAAC;IACf,UAAU,GAAG,CAAC,CAAC;IAEvB,uCAAuC;IAC/B,aAAa,GAAG,CAAC,CAAC;IAClB,YAAY,GAAG,CAAC,CAAC;IACjB,WAAW,GAAG,CAAC,CAAC;IAExB,aAAa;IACL,cAAc,GAAG,CAAC,CAAC;IACnB,cAAc,GAAG,CAAC,CAAC;IACnB,kBAAkB,GAAG,CAAC,CAAC;IAE/B,YAAY,SAAmC,EAAE;QAC/C,IAAI,CAAC,MAAM,GAAG,EAAE,GAAG,wBAAwB,EAAE,GAAG,MAAM,EAAE,CAAC;QAEzD,MAAM,UAAU,GAAG,IAAI,CAAC,MAAM,CAAC,UAAU,CAAC;QAE1C,uDAAuD;QACvD,IAAI,CAAC,cAAc,GAAG,IAAI,CAAC,UAAU,CAAC,IAAI,CAAC,QAAQ,EAAE,UAAU,CAAC,CAAC;QAEjE,4DAA4D;QAC5D,IAAI,CAAC,YAAY,GAAG,IAAI,CAAC,UAAU,CAAC,UAAU,GAAG,IAAI,CAAC,UAAU,EAAE,UAAU,CAAC,CAAC;QAE9E,kDAAkD;QAClD,IAAI,CAAC,YAAY,GAAG,IAAI,CAAC,UAAU,CAAC,CAAC,GAAG,UAAU,EAAE,IAAI,CAAC,UAAU,CAAC,CAAC;QAErE,eAAe;QACf,IAAI,CAAC,SAAS,GAAG,IAAI,CAAC,UAAU,CAAC,IAAI,CAAC,QAAQ,EAAE,UAAU,CAAC,CAAC;QAC5D,IAAI,CAAC,YAAY,GAAG,IAAI,CAAC,UAAU,CAAC,IAAI,CAAC,QAAQ,EAAE,UAAU,CAAC,CAAC;QAE/D,mBAAmB;QACnB,IAAI,CAAC,eAAe,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,YAAY,CAAC,MAAM,CAAC,CAAC;QAClE,IAAI,CAAC,eAAe,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,YAAY,CAAC,MAAM,CAAC,CAAC;QAClE,IAAI,CAAC,WAAW,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,YAAY,CAAC,MAAM,CAAC,CAAC;IAChE,CAAC;IAED;;OAEG;IACH,sBAAsB,CACpB,KAAmB,EACnB,MAAc,EACd,SAAuB;QAEvB,IAAI,IAAI,CAAC,MAAM,CAAC,MAAM,EAAE,CAAC;YACvB,OAAO,IAAI,CAAC,gBAAgB,CAAC,SAAS,CAAC,CAAC;QAC1C,CAAC;aAAM,CAAC;YACN,OAAO,IAAI,
CAAC,gBAAgB,CAAC,KAAK,EAAE,MAAM,EAAE,SAAS,CAAC,CAAC;QACzD,CAAC;IACH,CAAC;IAED;;OAEG;IACH,gBAAgB,CACd,KAAmB,EACnB,MAAc,EACd,SAAuB;QAEvB,MAAM,SAAS,GAAG,WAAW,CAAC,GAAG,EAAE,CAAC;QAEpC,4BAA4B;QAC5B,MAAM,YAAY,GAAG,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC,CAAC;QAC7C,MAAM,gBAAgB,GAAG,IAAI,CAAC,WAAW,CAAC,SAAS,CAAC,CAAC;QAErD,6BAA6B;QAC7B,MAAM,SAAS,GAAG,IAAI,CAAC,UAAU,CAAC,MAAM,CAAC,CAAC;QAC1C,MAAM,gBAAgB,GAAG,IAAI,CAAC,cAAc,CAAC,YAAY,EAAE,SAAS,CAAC,CAAC;QAEtE,+CAA+C;QAC/C,IAAI,KAAK,GAAG,CAAC,CAAC;QACd,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;YAChD,KAAK,IAAI,CAAC,gBAAgB,CAAC,CAAC,CAAC,GAAG,gBAAgB,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;QAC5D,CAAC;QAED,6BAA6B;QAC7B,MAAM,SAAS,GAAG,IAAI,CAAC,kBAAkB,CAAC,KAAK,CAAC,CAAC;QAEjD,MAAM,OAAO,GAAG,WAAW,CAAC,GAAG,EAAE,GAAG,SAAS,CAAC;QAC9C,IAAI,OAAO,GAAG,CAAC,EAAE,CAAC;YAChB,OAAO,CAAC,IAAI,CAAC,+BAA+B,OAAO,CAAC,OAAO,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC;QAC5E,CAAC;QAED,OAAO,SAAS,GAAG,IAAI,CAAC,MAAM,CAAC,aAAa,CAAC;IAC/C,CAAC;IAED;;OAEG;IACH,gBAAgB,CAAC,KAAmB;QAClC,MAAM,SAAS,GAAG,WAAW,CAAC,GAAG,EAAE,CAAC;QAEpC,gDAAgD;QAChD,MAAM,YAAY,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,EAAE,IAAI,CAAC,SAAS,CAAC,CAAC;QAE5D,qDAAqD;QACrD,MAAM,eAAe,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,EAAE,IAAI,CAAC,YAAY,CAAC,CAAC;QAElE,2BAA2B;QAC3B,IAAI,KAAK,GAAG,CAAC,CAAC;QACd,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;YAChD,KAAK,IAAI,CAAC,eAAe,CAAC,CAAC,CAAC,GAAG,YAAY,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;QACvD,CAAC;QAED,YAAY;QACZ,MAAM,SAAS,GAAG,IAAI,CAAC,kBAAkB,CAAC,KAAK,CAAC,CAAC;QAEjD,MAAM,OAAO,GAAG,WAAW,CAAC,GAAG,EAAE,GAAG,SAAS,CAAC;QAC9C,IAAI,OAAO,GAAG,CAAC,EAAE,CAAC;YAChB,OAAO,CAAC,IAAI,CAAC,+BAA+B,OAAO,CAAC,OAAO,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC;QAC5E,CAAC;QAED,OAAO,SAAS,GAAG,IAAI,CAAC,MAAM,CAAC,aAAa,CAAC;IAC/C,CAAC;IAED;;OAEG;IACH,MAAM,CAAC,UAAsB;QAC3B,MAAM,SAAS,GAAG,WAAW,CAAC,GAAG,EAAE,CAAC;QAEpC,IAAI,UAAU,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC;YAChC,OAAO,EAAE,WAAW,EAAE,CAAC,EAAE,WAAW,EAAE,CAAC,EAAE,CAAC;QAC5C,CAAC;
QAED,IAAI,gBAAgB,GAAG,CAAC,CAAC;QACzB,IAAI,gBAAgB,GAAG,CAAC,CAAC;QACzB,IAAI,KAAK,GAAG,CAAC,CAAC;QAEd,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;YACrD,MAAM,IAAI,GAAG,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YACjC,MAAM,QAAQ,GAAG,UAAU,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;YAEzC,MAAM,YAAY,GAAG,IAAI,CAAC,WAAW,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;YACvD,MAAM,gBAAgB,GAAG,IAAI,CAAC,WAAW,CAAC,QAAQ,CAAC,UAAU,CAAC,CAAC;YAC/D,MAAM,SAAS,GAAG,IAAI,CAAC,UAAU,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;YAE/C,uBAAuB;YACvB,MAAM,WAAW,GAAG,IAAI,CAAC,kBAAkB,CAAC,YAAY,EAAE,SAAS,EAAE,gBAAgB,CAAC,CAAC;YACvF,gBAAgB,IAAI,WAAW,CAAC;YAEhC,uBAAuB;YACvB,MAAM,WAAW,GAAG,IAAI,CAAC,kBAAkB,CAAC,YAAY,EAAE,gBAAgB,EAAE,SAAS,CAAC,CAAC;YACvF,gBAAgB,IAAI,WAAW,CAAC;YAEhC,oCAAoC;YACpC,IAAI,IAAI,CAAC,MAAM,CAAC,MAAM,EAAE,CAAC;gBACvB,IAAI,CAAC,kBAAkB,CAAC,QAAQ,CAAC,UAAU,CAAC,CAAC;YAC/C,CAAC;YAED,KAAK,EAAE,CAAC;QACV,CAAC;QAED,IAAI,CAAC,WAAW,EAAE,CAAC;QACnB,IAAI,CAAC,cAAc,GAAG,KAAK,GAAG,CAAC,CAAC,CAAC,CAAC,gBAAgB,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;QAC/D,IAAI,CAAC,cAAc,GAAG,KAAK,GAAG,CAAC,CAAC,CAAC,CAAC,gBAAgB,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;QAE/D,MAAM,OAAO,GAAG,WAAW,CAAC,GAAG,EAAE,GAAG,SAAS,CAAC;QAC9C,IAAI,OAAO,GAAG,EAAE,EAAE,CAAC;YACjB,OAAO,CAAC,IAAI,CAAC,qCAAqC,OAAO,CAAC,OAAO,CAAC,CAAC,CAAC,WAAW,CAAC,CAAC;QACnF,CAAC;QAED,OAAO;YACL,WAAW,EAAE,IAAI,CAAC,cAAc;YAChC,WAAW,EAAE,IAAI,CAAC,cAAc;SACjC,CAAC;IACJ,CAAC;IAED;;OAEG;IACH,iBAAiB,CAAC,UAAsB;QACtC,MAAM,SAAS,GAAG,EAAE,GAAG,UAAU,EAAE,KAAK,EAAE,CAAC,GAAG,UAAU,CAAC,KAAK,CAAC,EAAE,CAAC;QAElE,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,SAAS,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;YACpD,MAAM,IAAI,GAAG,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YAChC,MAAM,QAAQ,GAAG,SAAS,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;YAExC,MAAM,SAAS,GAAG,IAAI,CAAC,sBAAsB,CAC3C,IAAI,CAAC,UAAU,EACf,IAAI,CAAC,MAAM,EACX,QAAQ,CAAC,UAAU,CACpB,CAAC;YAEF,iBAAiB;YACjB,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG;gBACnB,GAAG,IAAI;gBACP,MAAM,EAAE,IAAI,CAAC,MAAM,GAAG,SAAS;aAChC,CA
AC;QACJ,CAAC;QAED,OAAO,SAAS,CAAC;IACnB,CAAC;IAED;;OAEG;IACH,QAAQ;QACN,OAAO;YACL,WAAW,EAAE,IAAI,CAAC,WAAW;YAC7B,cAAc,EAAE,IAAI,CAAC,cAAc;YACnC,cAAc,EAAE,IAAI,CAAC,cAAc;YACnC,kBAAkB,EAAE,IAAI,CAAC,kBAAkB;YAC3C,aAAa,EAAE,IAAI,CAAC,aAAa;YACjC,YAAY,EAAE,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC,YAAY,CAAC;SAC3C,CAAC;IACJ,CAAC;IAED,6EAA6E;IAC7E,kBAAkB;IAClB,6EAA6E;IAErE,UAAU,CAAC,QAAgB,EAAE,SAAiB;QACpD,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,QAAQ,GAAG,SAAS,CAAC,CAAC;QACtD,MAAM,KAAK,GAAG,IAAI,CAAC,IAAI,CAAC,CAAC,GAAG,QAAQ,CAAC,CAAC;QACtC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACvC,MAAM,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,MAAM,EAAE,GAAG,GAAG,CAAC,GAAG,KAAK,CAAC;QAC5C,CAAC;QACD,OAAO,MAAM,CAAC;IAChB,CAAC;IAEO,WAAW,CAAC,KAAmB;QACrC,MAAM,UAAU,GAAG,IAAI,CAAC,MAAM,CAAC,UAAU,CAAC;QAC1C,MAAM,OAAO,GAAG,IAAI,YAAY,CAAC,UAAU,CAAC,CAAC;QAE7C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;YACpC,IAAI,GAAG,GAAG,CAAC,CAAC;YACZ,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,KAAK,CAAC,MAAM,EAAE,IAAI,CAAC,QAAQ,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;gBAC/D,GAAG,IAAI,KAAK,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,cAAc,CAAC,CAAC,GAAG,UAAU,GAAG,CAAC,CAAC,CAAC;YAC5D,CAAC;YACD,OAAO,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC,CAAC,OAAO;QACxC,CAAC;QAED,OAAO,OAAO,CAAC;IACjB,CAAC;IAEO,cAAc,CAAC,YAA0B,EAAE,MAAc;QAC/D,MAAM,UAAU,GAAG,IAAI,CAAC,MAAM,CAAC,UAAU,CAAC;QAC1C,MAAM,QAAQ,GAAG,UAAU,GAAG,IAAI,CAAC,UAAU,CAAC;QAC9C,MAAM,SAAS,GAAG,IAAI,YAAY,CAAC,UAAU,CAAC,CAAC;QAE/C,yCAAyC;QACzC,MAAM,KAAK,GAAG,IAAI,YAAY,CAAC,QAAQ,CAAC,CAAC;QACzC,KAAK,CAAC,GAAG,CAAC,YAAY,CAAC,CAAC;QACxB,KAAK,CAAC,UAAU,GAAG,MAAM,CAAC,GAAG,CAAC,CAAC;QAE/B,eAAe;QACf,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;YACpC,IAAI,GAAG,GAAG,CAAC,CAAC;YACZ,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,QAAQ,EAAE,CAAC,EAAE,EAAE,CAAC;gBAClC,GAAG,IAAI,KAAK,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,YAAY,CAAC,CAAC,GAAG,UAAU,GAAG,CAAC,CAAC,CAAC;YAC1D,CAAC;YACD,SAAS,CAAC,CAAC,CAAC,GAAG,GAAG,CAAC;QAC
rB,CAAC;QAED,OAAO,SAAS,CAAC;IACnB,CAAC;IAEO,cAAc,CACpB,YAA0B,EAC1B,gBAA8B;QAE9B,MAAM,UAAU,GAAG,IAAI,CAAC,MAAM,CAAC,UAAU,CAAC;QAC1C,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;QAEjD,uBAAuB;QACvB,MAAM,KAAK,GAAG,IAAI,YAAY,CAAC,CAAC,GAAG,UAAU,CAAC,CAAC;QAC/C,KAAK,CAAC,GAAG,CAAC,YAAY,CAAC,CAAC;QACxB,KAAK,CAAC,GAAG,CAAC,gBAAgB,EAAE,UAAU,CAAC,CAAC;QAExC,eAAe;QACf,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;YACzC,IAAI,GAAG,GAAG,CAAC,CAAC;YACZ,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;gBACxC,GAAG,IAAI,KAAK,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,YAAY,CAAC,CAAC,GAAG,IAAI,CAAC,UAAU,GAAG,CAAC,CAAC,CAAC;YAC/D,CAAC;YACD,MAAM,CAAC,CAAC,CAAC,GAAG,GAAG,CAAC;QAClB,CAAC;QAED,OAAO,IAAI,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC;IAC9B,CAAC;IAEO,UAAU,CAAC,KAAmB,EAAE,OAAqB;QAC3D,MAAM,UAAU,GAAG,IAAI,CAAC,MAAM,CAAC,UAAU,CAAC;QAC1C,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,UAAU,CAAC,CAAC;QAE5C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;YACpC,IAAI,GAAG,GAAG,CAAC,CAAC;YACZ,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,KAAK,CAAC,MAAM,EAAE,IAAI,CAAC,QAAQ,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;gBAC/D,GAAG,IAAI,KAAK,CAAC,CAAC,CAAC,GAAG,OAAO,CAAC,CAAC,GAAG,UAAU,GAAG,CAAC,CAAC,CAAC;YAChD,CAAC;YACD,MAAM,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC,CAAC,OAAO;QACvC,CAAC;QAED,OAAO,MAAM,CAAC;IAChB,CAAC;IAEO,kBAAkB,CACxB,YAA0B,EAC1B,MAAc,EACd,aAA2B;QAE3B,MAAM,UAAU,GAAG,IAAI,CAAC,MAAM,CAAC,UAAU,CAAC;QAC1C,MAAM,QAAQ,GAAG,UAAU,GAAG,IAAI,CAAC,UAAU,CAAC;QAC9C,MAAM,EAAE,GAAG,IAAI,CAAC,MAAM,CAAC,SAAS,CAAC;QACjC,MAAM,IAAI,GAAG,GAAG,CAAC;QAEjB,eAAe;QACf,MAAM,SAAS,GAAG,IAAI,CAAC,cAAc,CAAC,YAAY,EAAE,MAAM,CAAC,CAAC;QAE5D,4BAA4B;QAC5B,IAAI,IAAI,GAAG,CAAC,CAAC;QACb,MAAM,IAAI,GAAG,IAAI,YAAY,CAAC,SAAS,CAAC,MAAM,CAAC,CAAC;QAEhD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;YACpC,MAAM,IAAI,GAAG,SAAS,CAAC,CAAC,CAAC,GAAG,aAAa,CAAC,CAAC,CAAC,CAAC;YAC7C,IAAI,IAAI,IAAI,GAAG,IAAI,CAAC;YACpB,IAAI,CAAC,CAAC
,CAAC,GAAG,CAAC,GAAG,IAAI,CAAC;QACrB,CAAC;QAED,sBAAsB;QACtB,MAAM,KAAK,GAAG,IAAI,YAAY,CAAC,QAAQ,CAAC,CAAC;QACzC,KAAK,CAAC,GAAG,CAAC,YAAY,CAAC,CAAC;QACxB,KAAK,CAAC,UAAU,GAAG,MAAM,CAAC,GAAG,CAAC,CAAC;QAE/B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,QAAQ,EAAE,CAAC,EAAE,EAAE,CAAC;YAClC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;gBACpC,MAAM,UAAU,GAAG,KAAK,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,CAAC,CAAC,CAAC;gBACtC,MAAM,GAAG,GAAG,CAAC,GAAG,UAAU,GAAG,CAAC,CAAC;gBAC/B,IAAI,CAAC,eAAe,CAAC,GAAG,CAAC,GAAG,IAAI,GAAG,IAAI,CAAC,eAAe,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,UAAU,CAAC;gBACvF,IAAI,CAAC,YAAY,CAAC,GAAG,CAAC,IAAI,EAAE,GAAG,IAAI,CAAC,eAAe,CAAC,GAAG,CAAC,CAAC;YAC3D,CAAC;QACH,CAAC;QAED,OAAO,IAAI,CAAC;IACd,CAAC;IAEO,kBAAkB,CACxB,YAA0B,EAC1B,gBAA8B,EAC9B,YAAoB;QAEpB,MAAM,UAAU,GAAG,IAAI,CAAC,MAAM,CAAC,UAAU,CAAC;QAC1C,MAAM,EAAE,GAAG,IAAI,CAAC,MAAM,CAAC,SAAS,CAAC;QACjC,MAAM,IAAI,GAAG,GAAG,CAAC;QAEjB,eAAe;QACf,MAAM,KAAK,GAAG,IAAI,CAAC,cAAc,CAAC,YAAY,EAAE,gBAAgB,CAAC,CAAC;QAElE,qBAAqB;QACrB,MAAM,IAAI,GAAG,CAAC,IAAI,CAAC,GAAG,CAAC,KAAK,CAAC,YAAY,CAAC,GAAG,IAAI,CAAC,CAAC;QAEnD,WAAW;QACX,MAAM,IAAI,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;QAC/C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;YACzC,IAAI,CAAC,CAAC,CAAC,GAAG,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,KAAK,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QACpD,CAAC;QAED,sBAAsB;QACtB,MAAM,KAAK,GAAG,IAAI,YAAY,CAAC,CAAC,GAAG,UAAU,CAAC,CAAC;QAC/C,KAAK,CAAC,GAAG,CAAC,YAAY,CAAC,CAAC;QACxB,KAAK,CAAC,GAAG,CAAC,gBAAgB,EAAE,UAAU,CAAC,CAAC;QAExC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;YACxC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;gBACzC,MAAM,UAAU,GAAG,KAAK,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,CAAC,CAAC,CAAC;gBACtC,MAAM,GAAG,GAAG,CAAC,GAAG,IAAI,CAAC,UAAU,GAAG,CAAC,CAAC;gBACpC,IAAI,CAAC,eAAe,CAAC,GAAG,CAAC,GAAG,IAAI,GAAG,IAAI,CAAC,eAAe,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,UAAU,CAAC;gBACvF,IA
AI,CAAC,YAAY,CAAC,GAAG,CAAC,IAAI,EAAE,GAAG,IAAI,CAAC,eAAe,CAAC,GAAG,CAAC,CAAC;YAC3D,CAAC;QACH,CAAC;QAED,OAAO,IAAI,CAAC;IACd,CAAC;IAEO,kBAAkB,CAAC,KAAmB;QAC5C,MAAM,UAAU,GAAG,IAAI,CAAC,MAAM,CAAC,UAAU,CAAC;QAC1C,MAAM,EAAE,GAAG,IAAI,CAAC,MAAM,CAAC,YAAY,CAAC;QACpC,MAAM,IAAI,GAAG,GAAG,CAAC;QAEjB,wBAAwB;QACxB,MAAM,YAAY,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,EAAE,IAAI,CAAC,SAAS,CAAC,CAAC;QAE5D,mBAAmB;QACnB,MAAM,eAAe,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,EAAE,IAAI,CAAC,YAAY,CAAC,CAAC;QAElE,WAAW;QACX,MAAM,IAAI,GAAG,IAAI,YAAY,CAAC,UAAU,CAAC,CAAC;QAC1C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;YACpC,IAAI,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,eAAe,CAAC,CAAC,CAAC,GAAG,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC;QACvD,CAAC;QAED,2BAA2B;QAC3B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,KAAK,CAAC,MAAM,EAAE,IAAI,CAAC,QAAQ,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;YAC/D,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;gBACpC,IAAI,eAAe,CAAC,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC,gBAAgB;oBAC5C,MAAM,UAAU,GAAG,KAAK,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,CAAC,CAAC,CAAC;oBACtC,MAAM,GAAG,GAAG,CAAC,GAAG,UAAU,GAAG,CAAC,CAAC;oBAC/B,IAAI,CAAC,WAAW,CAAC,GAAG,CAAC,GAAG,IAAI,GAAG,IAAI,CAAC,WAAW,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,UAAU,CAAC;oBAC/E,IAAI,CAAC,YAAY,CAAC,GAAG,CAAC,IAAI,EAAE,GAAG,IAAI,CAAC,WAAW,CAAC,GAAG,CAAC,CAAC;gBACvD,CAAC;YACH,CAAC;QACH,CAAC;IACH,CAAC;IAEO,kBAAkB,CAAC,GAAW;QACpC,4BAA4B;QAC5B,MAAM,KAAK,GAAG,IAAI,CAAC;QACnB,IAAI,CAAC,aAAa,GAAG,CAAC,CAAC,GAAG,KAAK,CAAC,GAAG,IAAI,CAAC,aAAa,GAAG,KAAK,GAAG,GAAG,CAAC;QACpE,IAAI,CAAC,YAAY,GAAG,CAAC,CAAC,GAAG,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,GAAG,KAAK,GAAG,CAAC,GAAG,GAAG,IAAI,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC;QAE9F,YAAY;QACZ,MAAM,UAAU,GAAG,CAAC,GAAG,GAAG,IAAI,CAAC,aAAa,CAAC,GAAG,CAAC,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC,YAAY,CAAC,GAAG,IAAI,CAAC,CAAC;QAEtF,2BAA2B;QAC3B,OAAO,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,UAAU,CAAC,CAAC,CAAC;IAC/C,CAAC;IAEO,OAAO,CAAC,MAAoB;QAClC,MAAM,GAAG,GAAG,IAAI,CAAC,GAAG,CAAC,GAAG,MAAM
,CAAC,CAAC;QAChC,MAAM,IAAI,GAAG,IAAI,YAAY,CAAC,MAAM,CAAC,MAAM,CAAC,CAAC;QAC7C,IAAI,GAAG,GAAG,CAAC,CAAC;QAEZ,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACvC,IAAI,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,CAAC,GAAG,GAAG,CAAC,CAAC;YACpC,GAAG,IAAI,IAAI,CAAC,CAAC,CAAC,CAAC;QACjB,CAAC;QAED,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACrC,IAAI,CAAC,CAAC,CAAC,IAAI,GAAG,CAAC;QACjB,CAAC;QAED,OAAO,IAAI,CAAC;IACd,CAAC;IAEO,UAAU,CAAC,MAAc;QAC/B,IAAI,IAAI,GAAG,CAAC,CAAC;QACb,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACvC,IAAI,GAAG,CAAC,IAAI,GAAG,EAAE,GAAG,MAAM,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,UAAU,CAAC;QAC9D,CAAC;QACD,OAAO,IAAI,CAAC;IACd,CAAC;CACF;AAED;;GAEG;AACH,MAAM,UAAU,eAAe,CAAC,MAAiC;IAC/D,OAAO,IAAI,eAAe,CAAC,MAAM,CAAC,CAAC;AACrC,CAAC"}
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Decision Transformer
|
|
3
|
+
*
|
|
4
|
+
* Implements sequence modeling approach for RL:
|
|
5
|
+
* - Trajectory as sequence: (s, a, R, s, a, R, ...)
|
|
6
|
+
* - Return-conditioned generation
|
|
7
|
+
* - Causal transformer attention
|
|
8
|
+
* - Offline RL from trajectories
|
|
9
|
+
*
|
|
10
|
+
* Performance Target: <10ms per forward pass
|
|
11
|
+
*/
|
|
12
|
+
import type { DecisionTransformerConfig, Trajectory } from '../types.js';
|
|
13
|
+
/**
|
|
14
|
+
* Default Decision Transformer configuration
|
|
15
|
+
*/
|
|
16
|
+
export declare const DEFAULT_DT_CONFIG: DecisionTransformerConfig;
|
|
17
|
+
/**
|
|
18
|
+
* Sequence entry for transformer
|
|
19
|
+
*/
|
|
20
|
+
interface SequenceEntry {
|
|
21
|
+
returnToGo: number;
|
|
22
|
+
state: Float32Array;
|
|
23
|
+
action: number;
|
|
24
|
+
timestep: number;
|
|
25
|
+
}
|
|
26
|
+
/**
|
|
27
|
+
* Decision Transformer Implementation
|
|
28
|
+
*/
|
|
29
|
+
export declare class DecisionTransformer {
|
|
30
|
+
private config;
|
|
31
|
+
private stateEmbed;
|
|
32
|
+
private actionEmbed;
|
|
33
|
+
private returnEmbed;
|
|
34
|
+
private posEmbed;
|
|
35
|
+
private attentionWeights;
|
|
36
|
+
private ffnWeights;
|
|
37
|
+
private actionHead;
|
|
38
|
+
private trajectoryBuffer;
|
|
39
|
+
private stateDim;
|
|
40
|
+
private numActions;
|
|
41
|
+
private updateCount;
|
|
42
|
+
private avgLoss;
|
|
43
|
+
constructor(config?: Partial<DecisionTransformerConfig>);
|
|
44
|
+
/**
|
|
45
|
+
* Add trajectory for training
|
|
46
|
+
*/
|
|
47
|
+
addTrajectory(trajectory: Trajectory): void;
|
|
48
|
+
/**
|
|
49
|
+
* Train on buffered trajectories
|
|
50
|
+
* Target: <10ms per batch
|
|
51
|
+
*/
|
|
52
|
+
train(): {
|
|
53
|
+
loss: number;
|
|
54
|
+
accuracy: number;
|
|
55
|
+
};
|
|
56
|
+
/**
|
|
57
|
+
* Get action conditioned on target return
|
|
58
|
+
*/
|
|
59
|
+
getAction(states: Float32Array[], actions: number[], targetReturn: number): number;
|
|
60
|
+
/**
|
|
61
|
+
* Forward pass through transformer
|
|
62
|
+
*/
|
|
63
|
+
forward(sequence: SequenceEntry[]): Float32Array;
|
|
64
|
+
/**
|
|
65
|
+
* Get statistics
|
|
66
|
+
*/
|
|
67
|
+
getStats(): Record<string, number>;
|
|
68
|
+
private initEmbedding;
|
|
69
|
+
private initWeight;
|
|
70
|
+
private createSequence;
|
|
71
|
+
private transformerLayer;
|
|
72
|
+
private updateWeights;
|
|
73
|
+
private softmax;
|
|
74
|
+
private argmax;
|
|
75
|
+
private hashAction;
|
|
76
|
+
}
|
|
77
|
+
/**
|
|
78
|
+
* Factory function
|
|
79
|
+
*/
|
|
80
|
+
export declare function createDecisionTransformer(config?: Partial<DecisionTransformerConfig>): DecisionTransformer;
|
|
81
|
+
export {};
|
|
82
|
+
//# sourceMappingURL=decision-transformer.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"decision-transformer.d.ts","sourceRoot":"","sources":["../../src/algorithms/decision-transformer.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;GAUG;AAEH,OAAO,KAAK,EACV,yBAAyB,EACzB,UAAU,EAEX,MAAM,aAAa,CAAC;AAErB;;GAEG;AACH,eAAO,MAAM,iBAAiB,EAAE,yBAe/B,CAAC;AAEF;;GAEG;AACH,UAAU,aAAa;IACrB,UAAU,EAAE,MAAM,CAAC;IACnB,KAAK,EAAE,YAAY,CAAC;IACpB,MAAM,EAAE,MAAM,CAAC;IACf,QAAQ,EAAE,MAAM,CAAC;CAClB;AAED;;GAEG;AACH,qBAAa,mBAAmB;IAC9B,OAAO,CAAC,MAAM,CAA4B;IAG1C,OAAO,CAAC,UAAU,CAAe;IACjC,OAAO,CAAC,WAAW,CAAe;IAClC,OAAO,CAAC,WAAW,CAAe;IAClC,OAAO,CAAC,QAAQ,CAAe;IAG/B,OAAO,CAAC,gBAAgB,CAAmB;IAC3C,OAAO,CAAC,UAAU,CAAmB;IAGrC,OAAO,CAAC,UAAU,CAAe;IAGjC,OAAO,CAAC,gBAAgB,CAAoB;IAG5C,OAAO,CAAC,QAAQ,CAAO;IACvB,OAAO,CAAC,UAAU,CAAK;IAGvB,OAAO,CAAC,WAAW,CAAK;IACxB,OAAO,CAAC,OAAO,CAAK;gBAER,MAAM,GAAE,OAAO,CAAC,yBAAyB,CAAM;IAiC3D;;OAEG;IACH,aAAa,CAAC,UAAU,EAAE,UAAU,GAAG,IAAI;IAW3C;;;OAGG;IACH,KAAK,IAAI;QAAE,IAAI,EAAE,MAAM,CAAC;QAAC,QAAQ,EAAE,MAAM,CAAA;KAAE;IAgE3C;;OAEG;IACH,SAAS,CACP,MAAM,EAAE,YAAY,EAAE,EACtB,OAAO,EAAE,MAAM,EAAE,EACjB,YAAY,EAAE,MAAM,GACnB,MAAM;IAwBT;;OAEG;IACH,OAAO,CAAC,QAAQ,EAAE,aAAa,EAAE,GAAG,YAAY;IA2DhD;;OAEG;IACH,QAAQ,IAAI,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC;IAclC,OAAO,CAAC,aAAa;IASrB,OAAO,CAAC,UAAU;IASlB,OAAO,CAAC,cAAc;IA0BtB,OAAO,CAAC,gBAAgB;IAyGxB,OAAO,CAAC,aAAa;IAuBrB,OAAO,CAAC,OAAO;IAiBf,OAAO,CAAC,MAAM;IAYd,OAAO,CAAC,UAAU;CAOnB;AAED;;GAEG;AACH,wBAAgB,yBAAyB,CACvC,MAAM,CAAC,EAAE,OAAO,CAAC,yBAAyB,CAAC,GAC1C,mBAAmB,CAErB"}
|