agentic-qe 2.1.2 → 2.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.claude/skills/agentic-quality-engineering/SKILL.md +4 -4
- package/.claude/skills/cicd-pipeline-qe-orchestrator/README.md +14 -11
- package/.claude/skills/skills-manifest.json +2 -2
- package/CHANGELOG.md +138 -0
- package/README.md +92 -214
- package/dist/agents/BaseAgent.d.ts +5 -1
- package/dist/agents/BaseAgent.d.ts.map +1 -1
- package/dist/agents/BaseAgent.js +32 -17
- package/dist/agents/BaseAgent.js.map +1 -1
- package/dist/agents/index.d.ts.map +1 -1
- package/dist/agents/index.js +5 -1
- package/dist/agents/index.js.map +1 -1
- package/dist/cli/commands/improve/index.d.ts +8 -1
- package/dist/cli/commands/improve/index.d.ts.map +1 -1
- package/dist/cli/commands/improve/index.js +18 -16
- package/dist/cli/commands/improve/index.js.map +1 -1
- package/dist/cli/commands/learn/index.d.ts +10 -2
- package/dist/cli/commands/learn/index.d.ts.map +1 -1
- package/dist/cli/commands/learn/index.js +99 -63
- package/dist/cli/commands/learn/index.js.map +1 -1
- package/dist/cli/commands/patterns/index.d.ts +8 -1
- package/dist/cli/commands/patterns/index.d.ts.map +1 -1
- package/dist/cli/commands/patterns/index.js +79 -45
- package/dist/cli/commands/patterns/index.js.map +1 -1
- package/dist/cli/commands/routing/index.d.ts +5 -0
- package/dist/cli/commands/routing/index.d.ts.map +1 -1
- package/dist/cli/commands/routing/index.js +11 -10
- package/dist/cli/commands/routing/index.js.map +1 -1
- package/dist/cli/init/agents.d.ts +1 -1
- package/dist/cli/init/agents.js +2 -2
- package/dist/cli/init/database-init.d.ts +7 -0
- package/dist/cli/init/database-init.d.ts.map +1 -1
- package/dist/cli/init/database-init.js +29 -48
- package/dist/cli/init/database-init.js.map +1 -1
- package/dist/core/di/AgentDependencies.d.ts +127 -0
- package/dist/core/di/AgentDependencies.d.ts.map +1 -0
- package/dist/core/di/AgentDependencies.js +251 -0
- package/dist/core/di/AgentDependencies.js.map +1 -0
- package/dist/core/di/DIContainer.d.ts +149 -0
- package/dist/core/di/DIContainer.d.ts.map +1 -0
- package/dist/core/di/DIContainer.js +333 -0
- package/dist/core/di/DIContainer.js.map +1 -0
- package/dist/core/di/index.d.ts +11 -0
- package/dist/core/di/index.d.ts.map +1 -0
- package/dist/core/di/index.js +22 -0
- package/dist/core/di/index.js.map +1 -0
- package/dist/core/index.d.ts +1 -0
- package/dist/core/index.d.ts.map +1 -1
- package/dist/core/index.js +11 -1
- package/dist/core/index.js.map +1 -1
- package/dist/core/memory/HNSWVectorMemory.d.ts +261 -0
- package/dist/core/memory/HNSWVectorMemory.d.ts.map +1 -0
- package/dist/core/memory/HNSWVectorMemory.js +647 -0
- package/dist/core/memory/HNSWVectorMemory.js.map +1 -0
- package/dist/core/memory/SwarmMemoryManager.d.ts +7 -0
- package/dist/core/memory/SwarmMemoryManager.d.ts.map +1 -1
- package/dist/core/memory/SwarmMemoryManager.js +9 -0
- package/dist/core/memory/SwarmMemoryManager.js.map +1 -1
- package/dist/core/memory/index.d.ts +2 -0
- package/dist/core/memory/index.d.ts.map +1 -1
- package/dist/core/memory/index.js +11 -1
- package/dist/core/memory/index.js.map +1 -1
- package/dist/learning/ExperienceSharingProtocol.d.ts +243 -0
- package/dist/learning/ExperienceSharingProtocol.d.ts.map +1 -0
- package/dist/learning/ExperienceSharingProtocol.js +538 -0
- package/dist/learning/ExperienceSharingProtocol.js.map +1 -0
- package/dist/learning/ExplainableLearning.d.ts +191 -0
- package/dist/learning/ExplainableLearning.d.ts.map +1 -0
- package/dist/learning/ExplainableLearning.js +441 -0
- package/dist/learning/ExplainableLearning.js.map +1 -0
- package/dist/learning/GossipPatternSharingProtocol.d.ts +228 -0
- package/dist/learning/GossipPatternSharingProtocol.d.ts.map +1 -0
- package/dist/learning/GossipPatternSharingProtocol.js +590 -0
- package/dist/learning/GossipPatternSharingProtocol.js.map +1 -0
- package/dist/learning/LearningEngine.d.ts +104 -4
- package/dist/learning/LearningEngine.d.ts.map +1 -1
- package/dist/learning/LearningEngine.js +350 -16
- package/dist/learning/LearningEngine.js.map +1 -1
- package/dist/learning/PerformanceOptimizer.d.ts +268 -0
- package/dist/learning/PerformanceOptimizer.d.ts.map +1 -0
- package/dist/learning/PerformanceOptimizer.js +552 -0
- package/dist/learning/PerformanceOptimizer.js.map +1 -0
- package/dist/learning/PrivacyManager.d.ts +197 -0
- package/dist/learning/PrivacyManager.d.ts.map +1 -0
- package/dist/learning/PrivacyManager.js +551 -0
- package/dist/learning/PrivacyManager.js.map +1 -0
- package/dist/learning/QLearning.d.ts +38 -125
- package/dist/learning/QLearning.d.ts.map +1 -1
- package/dist/learning/QLearning.js +46 -267
- package/dist/learning/QLearning.js.map +1 -1
- package/dist/learning/QLearningLegacy.d.ts +154 -0
- package/dist/learning/QLearningLegacy.d.ts.map +1 -0
- package/dist/learning/QLearningLegacy.js +337 -0
- package/dist/learning/QLearningLegacy.js.map +1 -0
- package/dist/learning/TransferLearningManager.d.ts +212 -0
- package/dist/learning/TransferLearningManager.d.ts.map +1 -0
- package/dist/learning/TransferLearningManager.js +497 -0
- package/dist/learning/TransferLearningManager.js.map +1 -0
- package/dist/learning/algorithms/AbstractRLLearner.d.ts +162 -0
- package/dist/learning/algorithms/AbstractRLLearner.d.ts.map +1 -0
- package/dist/learning/algorithms/AbstractRLLearner.js +300 -0
- package/dist/learning/algorithms/AbstractRLLearner.js.map +1 -0
- package/dist/learning/algorithms/ActorCriticLearner.d.ts +201 -0
- package/dist/learning/algorithms/ActorCriticLearner.d.ts.map +1 -0
- package/dist/learning/algorithms/ActorCriticLearner.js +447 -0
- package/dist/learning/algorithms/ActorCriticLearner.js.map +1 -0
- package/dist/learning/algorithms/MAMLMetaLearner.d.ts +218 -0
- package/dist/learning/algorithms/MAMLMetaLearner.d.ts.map +1 -0
- package/dist/learning/algorithms/MAMLMetaLearner.js +532 -0
- package/dist/learning/algorithms/MAMLMetaLearner.js.map +1 -0
- package/dist/learning/algorithms/PPOLearner.d.ts +207 -0
- package/dist/learning/algorithms/PPOLearner.d.ts.map +1 -0
- package/dist/learning/algorithms/PPOLearner.js +490 -0
- package/dist/learning/algorithms/PPOLearner.js.map +1 -0
- package/dist/learning/algorithms/QLearning.d.ts +68 -0
- package/dist/learning/algorithms/QLearning.d.ts.map +1 -0
- package/dist/learning/algorithms/QLearning.js +116 -0
- package/dist/learning/algorithms/QLearning.js.map +1 -0
- package/dist/learning/algorithms/SARSALearner.d.ts +107 -0
- package/dist/learning/algorithms/SARSALearner.d.ts.map +1 -0
- package/dist/learning/algorithms/SARSALearner.js +252 -0
- package/dist/learning/algorithms/SARSALearner.js.map +1 -0
- package/dist/learning/algorithms/index.d.ts +32 -0
- package/dist/learning/algorithms/index.d.ts.map +1 -0
- package/dist/learning/algorithms/index.js +50 -0
- package/dist/learning/algorithms/index.js.map +1 -0
- package/dist/learning/index.d.ts +11 -0
- package/dist/learning/index.d.ts.map +1 -1
- package/dist/learning/index.js +31 -1
- package/dist/learning/index.js.map +1 -1
- package/dist/learning/types.d.ts +2 -0
- package/dist/learning/types.d.ts.map +1 -1
- package/dist/mcp/server-instructions.d.ts +1 -1
- package/dist/mcp/server-instructions.js +1 -1
- package/dist/memory/DistributedPatternLibrary.d.ts +159 -0
- package/dist/memory/DistributedPatternLibrary.d.ts.map +1 -0
- package/dist/memory/DistributedPatternLibrary.js +370 -0
- package/dist/memory/DistributedPatternLibrary.js.map +1 -0
- package/dist/memory/PatternQualityScorer.d.ts +169 -0
- package/dist/memory/PatternQualityScorer.d.ts.map +1 -0
- package/dist/memory/PatternQualityScorer.js +327 -0
- package/dist/memory/PatternQualityScorer.js.map +1 -0
- package/dist/memory/PatternReplicationService.d.ts +187 -0
- package/dist/memory/PatternReplicationService.d.ts.map +1 -0
- package/dist/memory/PatternReplicationService.js +392 -0
- package/dist/memory/PatternReplicationService.js.map +1 -0
- package/dist/providers/ClaudeProvider.d.ts +98 -0
- package/dist/providers/ClaudeProvider.d.ts.map +1 -0
- package/dist/providers/ClaudeProvider.js +418 -0
- package/dist/providers/ClaudeProvider.js.map +1 -0
- package/dist/providers/HybridRouter.d.ts +217 -0
- package/dist/providers/HybridRouter.d.ts.map +1 -0
- package/dist/providers/HybridRouter.js +679 -0
- package/dist/providers/HybridRouter.js.map +1 -0
- package/dist/providers/ILLMProvider.d.ts +287 -0
- package/dist/providers/ILLMProvider.d.ts.map +1 -0
- package/dist/providers/ILLMProvider.js +33 -0
- package/dist/providers/ILLMProvider.js.map +1 -0
- package/dist/providers/LLMProviderFactory.d.ts +154 -0
- package/dist/providers/LLMProviderFactory.d.ts.map +1 -0
- package/dist/providers/LLMProviderFactory.js +426 -0
- package/dist/providers/LLMProviderFactory.js.map +1 -0
- package/dist/providers/RuvllmProvider.d.ts +107 -0
- package/dist/providers/RuvllmProvider.d.ts.map +1 -0
- package/dist/providers/RuvllmProvider.js +417 -0
- package/dist/providers/RuvllmProvider.js.map +1 -0
- package/dist/providers/index.d.ts +32 -0
- package/dist/providers/index.d.ts.map +1 -0
- package/dist/providers/index.js +75 -0
- package/dist/providers/index.js.map +1 -0
- package/dist/telemetry/LearningTelemetry.d.ts +190 -0
- package/dist/telemetry/LearningTelemetry.d.ts.map +1 -0
- package/dist/telemetry/LearningTelemetry.js +403 -0
- package/dist/telemetry/LearningTelemetry.js.map +1 -0
- package/dist/telemetry/index.d.ts +1 -0
- package/dist/telemetry/index.d.ts.map +1 -1
- package/dist/telemetry/index.js +20 -2
- package/dist/telemetry/index.js.map +1 -1
- package/dist/telemetry/instrumentation/agent.d.ts +1 -1
- package/dist/telemetry/instrumentation/agent.js +1 -1
- package/dist/telemetry/instrumentation/index.d.ts +1 -1
- package/dist/telemetry/instrumentation/index.js +1 -1
- package/dist/utils/math.d.ts +11 -0
- package/dist/utils/math.d.ts.map +1 -0
- package/dist/utils/math.js +16 -0
- package/dist/utils/math.js.map +1 -0
- package/docs/reference/agents.md +1 -1
- package/docs/reference/skills.md +3 -3
- package/docs/reference/usage.md +4 -4
- package/package.json +1 -1
package/dist/learning/algorithms/PPOLearner.d.ts

@@ -0,0 +1,207 @@
/**
 * PPOLearner - Proximal Policy Optimization Algorithm
 *
 * Implements PPO-Clip, the most widely used variant of PPO:
 * - Clipped surrogate objective to prevent large policy updates
 * - Generalized Advantage Estimation (GAE) for variance reduction
 * - Value function clipping for stability
 * - Multiple epochs over collected trajectories
 *
 * Key features:
 * - Trust region optimization without KL constraint
 * - Sample efficient with mini-batch updates
 * - Robust to hyperparameter choices
 * - Suitable for continuous and discrete action spaces
 *
 * PPO-Clip objective:
 * L^CLIP(θ) = E[min(r(θ)Â, clip(r(θ), 1-ε, 1+ε)Â)]
 * where r(θ) = π_θ(a|s) / π_θ_old(a|s)
 *
 * @module learning/algorithms/PPOLearner
 * @version 1.0.0
 */
import { AbstractRLLearner, RLConfig } from './AbstractRLLearner';
import { TaskState, AgentAction, TaskExperience } from '../types';
/**
 * Configuration specific to PPO algorithm
 */
export interface PPOConfig extends RLConfig {
    /** Clipping parameter (ε) - typically 0.1-0.3 */
    clipEpsilon: number;
    /** Number of epochs to train on collected data */
    ppoEpochs: number;
    /** Mini-batch size for training */
    miniBatchSize: number;
    /** Value function loss coefficient */
    valueLossCoefficient: number;
    /** Entropy loss coefficient for exploration */
    entropyCoefficient: number;
    /** GAE lambda for advantage estimation */
    gaeLambda: number;
    /** Maximum gradient norm for clipping */
    maxGradNorm: number;
    /** Whether to clip value function updates */
    clipValueLoss: boolean;
    /** Learning rate for policy network */
    policyLearningRate: number;
    /** Learning rate for value network */
    valueLearningRate: number;
}
/**
 * Policy parameters for a state-action pair
 */
interface PolicyParams {
    preference: number;
    logProb: number;
    updateCount: number;
}
/**
 * PPOLearner - Proximal Policy Optimization implementation
 *
 * PPO is a state-of-the-art policy gradient method that achieves
 * strong performance while being simpler than TRPO.
 *
 * Usage:
 * ```typescript
 * const ppo = new PPOLearner({
 *   learningRate: 0.0003,
 *   discountFactor: 0.99,
 *   explorationRate: 0.0,
 *   explorationDecay: 1.0,
 *   minExplorationRate: 0.0,
 *   clipEpsilon: 0.2,
 *   ppoEpochs: 4,
 *   miniBatchSize: 64,
 *   valueLossCoefficient: 0.5,
 *   entropyCoefficient: 0.01,
 *   gaeLambda: 0.95,
 *   maxGradNorm: 0.5,
 *   clipValueLoss: true,
 *   policyLearningRate: 0.0003,
 *   valueLearningRate: 0.001,
 *   useExperienceReplay: false,
 *   replayBufferSize: 2048,
 *   batchSize: 64
 * });
 *
 * // Collect trajectory
 * ppo.collectStep(state, action, reward, nextState, done);
 *
 * // Train on collected trajectory
 * ppo.trainOnTrajectory();
 * ```
 */
export declare class PPOLearner extends AbstractRLLearner {
    private ppoConfig;
    private policyTable;
    private valueTable;
    private oldPolicyTable;
    private trajectory;
    private readonly defaultExploration;
    constructor(config: PPOConfig);
    /**
     * Select action using current policy (softmax)
     */
    selectAction(state: TaskState, availableActions: AgentAction[]): AgentAction;
    /**
     * Get action probabilities using softmax policy
     */
    private getActionProbabilities;
    /**
     * Get policy parameters for state-action pair
     */
    private getPolicyParams;
    /**
     * Get log probability of action under current policy
     */
    private getLogProb;
    /**
     * Get state value from value network
     */
    getStateValue(state: TaskState): number;
    /**
     * Collect a step in the trajectory
     */
    collectStep(state: TaskState, action: AgentAction, reward: number, nextState: TaskState, done: boolean): void;
    /**
     * Standard update interface - collects experience and trains when ready
     */
    update(experience: TaskExperience, nextAction?: AgentAction): void;
    /**
     * Train on collected trajectory using PPO
     */
    trainOnTrajectory(): void;
    /**
     * Compute Generalized Advantage Estimation (GAE)
     *
     * GAE: Â_t = Σ_{l=0}^∞ (γλ)^l δ_{t+l}
     * where δ_t = r_t + γV(s_{t+1}) - V(s_t)
     */
    private computeGAE;
    /**
     * Save current policy as old policy for ratio computation
     */
    private saveOldPolicy;
    /**
     * Get old log probability for ratio computation
     */
    private getOldLogProb;
    /**
     * Train one epoch on the trajectory
     */
    private trainEpoch;
    /**
     * Train on a mini-batch
     */
    private trainMiniBatch;
    /**
     * Update policy parameters
     */
    private updatePolicy;
    /**
     * Update value function
     */
    private updateValue;
    /**
     * Compute entropy of policy at state
     */
    private computeEntropy;
    /**
     * Get default exploration rate for reset
     */
    protected getDefaultExplorationRate(): number;
    /**
     * Get PPO-specific statistics
     */
    getPPOStatistics(): {
        trajectoryLength: number;
        valueTableSize: number;
        policyTableSize: number;
        avgValue: number;
        avgAdvantage: number;
        clipFraction: number;
    };
    /**
     * Reset PPO-specific state
     */
    reset(): void;
    /**
     * Export PPO state
     */
    exportPPO(): {
        base: ReturnType<AbstractRLLearner['export']>;
        policyTable: Record<string, Record<string, PolicyParams>>;
        valueTable: Record<string, number>;
        ppoConfig: PPOConfig;
    };
    /**
     * Import PPO state
     */
    importPPO(state: ReturnType<typeof this.exportPPO>): void;
}
/**
 * Create default PPO configuration
 */
export declare function createDefaultPPOConfig(): PPOConfig;
export {};
//# sourceMappingURL=PPOLearner.d.ts.map
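For orientation, here is a minimal sketch of driving the declared API above, built from `createDefaultPPOConfig()` instead of spelling out every field as the JSDoc example does. The deep-import paths and the `TaskExperience` shape are assumptions: the diff shows neither the package's export map nor the contents of `../types`.

```typescript
// Sketch only — import paths below are assumptions, not a documented entry point.
import {
  PPOLearner,
  createDefaultPPOConfig,
} from 'agentic-qe/dist/learning/algorithms/PPOLearner';
import { TaskExperience } from 'agentic-qe/dist/learning/types';

// Start from the declared defaults and override selected fields.
const ppo = new PPOLearner({
  ...createDefaultPPOConfig(),
  clipEpsilon: 0.1,
  ppoEpochs: 8,
});

// Per the compiled update(): each call appends one trajectory step, and once the
// buffer reaches replayBufferSize steps trainOnTrajectory() runs automatically.
declare const experiences: TaskExperience[]; // supplied by the caller (hypothetical)
for (const exp of experiences) {
  ppo.update(exp);
}

// A trajectory shorter than replayBufferSize can be flushed manually.
ppo.trainOnTrajectory();
console.log(ppo.getPPOStatistics());
```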
package/dist/learning/algorithms/PPOLearner.d.ts.map

@@ -0,0 +1 @@
{"version":3,"file":"PPOLearner.d.ts","sourceRoot":"","sources":["../../../src/learning/algorithms/PPOLearner.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;;;;;;;GAqBG;AAEH,OAAO,EAAE,iBAAiB,EAAE,QAAQ,EAAU,MAAM,qBAAqB,CAAC;AAC1E,OAAO,EAAE,SAAS,EAAE,WAAW,EAAE,cAAc,EAAE,MAAM,UAAU,CAAC;AAElE;;GAEG;AACH,MAAM,WAAW,SAAU,SAAQ,QAAQ;IACzC,iDAAiD;IACjD,WAAW,EAAE,MAAM,CAAC;IACpB,kDAAkD;IAClD,SAAS,EAAE,MAAM,CAAC;IAClB,mCAAmC;IACnC,aAAa,EAAE,MAAM,CAAC;IACtB,sCAAsC;IACtC,oBAAoB,EAAE,MAAM,CAAC;IAC7B,+CAA+C;IAC/C,kBAAkB,EAAE,MAAM,CAAC;IAC3B,0CAA0C;IAC1C,SAAS,EAAE,MAAM,CAAC;IAClB,yCAAyC;IACzC,WAAW,EAAE,MAAM,CAAC;IACpB,6CAA6C;IAC7C,aAAa,EAAE,OAAO,CAAC;IACvB,uCAAuC;IACvC,kBAAkB,EAAE,MAAM,CAAC;IAC3B,sCAAsC;IACtC,iBAAiB,EAAE,MAAM,CAAC;CAC3B;AAiBD;;GAEG;AACH,UAAU,YAAY;IACpB,UAAU,EAAE,MAAM,CAAC;IACnB,OAAO,EAAE,MAAM,CAAC;IAChB,WAAW,EAAE,MAAM,CAAC;CACrB;AAED;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GAmCG;AACH,qBAAa,UAAW,SAAQ,iBAAiB;IAC/C,OAAO,CAAC,SAAS,CAAY;IAC7B,OAAO,CAAC,WAAW,CAAyC;IAC5D,OAAO,CAAC,UAAU,CAAsB;IACxC,OAAO,CAAC,cAAc,CAAyC;IAC/D,OAAO,CAAC,UAAU,CAAmB;IACrC,OAAO,CAAC,QAAQ,CAAC,kBAAkB,CAAS;gBAEhC,MAAM,EAAE,SAAS;IAiB7B;;OAEG;IACM,YAAY,CAAC,KAAK,EAAE,SAAS,EAAE,gBAAgB,EAAE,WAAW,EAAE,GAAG,WAAW;IAsBrF;;OAEG;IACH,OAAO,CAAC,sBAAsB;IAiB9B;;OAEG;IACH,OAAO,CAAC,eAAe;IAQvB;;OAEG;IACH,OAAO,CAAC,UAAU;IA6BlB;;OAEG;IACM,aAAa,CAAC,KAAK,EAAE,SAAS,GAAG,MAAM;IAKhD;;OAEG;IACH,WAAW,CACT,KAAK,EAAE,SAAS,EAChB,MAAM,EAAE,WAAW,EACnB,MAAM,EAAE,MAAM,EACd,SAAS,EAAE,SAAS,EACpB,IAAI,EAAE,OAAO,GACZ,IAAI;IAqBP;;OAEG;IACM,MAAM,CAAC,UAAU,EAAE,cAAc,EAAE,UAAU,CAAC,EAAE,WAAW,GAAG,IAAI;IAe3E;;OAEG;IACH,iBAAiB,IAAI,IAAI;IAyBzB;;;;;OAKG;IACH,OAAO,CAAC,UAAU;IAsClB;;OAEG;IACH,OAAO,CAAC,aAAa;IAWrB;;OAEG;IACH,OAAO,CAAC,aAAa;IAQrB;;OAEG;IACH,OAAO,CAAC,UAAU;IAWlB;;OAEG;IACH,OAAO,CAAC,cAAc;IA0CtB;;OAEG;IACH,OAAO,CAAC,YAAY;IAkCpB;;OAEG;IACH,OAAO,CAAC,WAAW;IAMnB;;OAEG;IACH,OAAO,CAAC,cAAc;IAsBtB;;OAEG;IACH,SAAS,CAAC,yBAAyB,IAAI,MAAM;IAI7C;;OAEG;IACH,gBAAgB,IAAI;QAClB,gBAAgB,EAAE,MAAM,CAAC;QACzB,cAAc,EAAE,MAAM,CAAC;QACvB,eAAe,EAAE,MAAM,CAAC;QACxB,QAAQ,EAAE,MAAM,CAAC;QACjB,YAAY,EAAE,MAAM,CAAC;QACrB,YAAY,EAAE,MAAM,CAAC;KACtB;IAyBD;;OAEG;IACM,KAAK,IAAI,IAAI;IAStB;;OAEG;IACH,SAAS,IAAI;QACX,IAAI,EAAE,UAAU,CAAC,iBAAiB,CAAC,QAAQ,CAAC,CAAC,CAAC;QAC9C,WAAW,EAAE,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,MAAM,EAAE,YAAY,CAAC,CAAC,CAAC;QAC1D,UAAU,EAAE,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;QACnC,SAAS,EAAE,SAAS,CAAC;KACtB;IAsBD;;OAEG;IACH,SAAS,CAAC,KAAK,EAAE,UAAU,CAAC,OAAO,IAAI,CAAC,SAAS,CAAC,GAAG,IAAI;CAwB1D;AAED;;GAEG;AACH,wBAAgB,sBAAsB,IAAI,SAAS,CAqBlD"}

package/dist/learning/algorithms/PPOLearner.js

@@ -0,0 +1,490 @@
"use strict";
/**
 * PPOLearner - Proximal Policy Optimization Algorithm
 *
 * Implements PPO-Clip, the most widely used variant of PPO:
 * - Clipped surrogate objective to prevent large policy updates
 * - Generalized Advantage Estimation (GAE) for variance reduction
 * - Value function clipping for stability
 * - Multiple epochs over collected trajectories
 *
 * Key features:
 * - Trust region optimization without KL constraint
 * - Sample efficient with mini-batch updates
 * - Robust to hyperparameter choices
 * - Suitable for continuous and discrete action spaces
 *
 * PPO-Clip objective:
 * L^CLIP(θ) = E[min(r(θ)Â, clip(r(θ), 1-ε, 1+ε)Â)]
 * where r(θ) = π_θ(a|s) / π_θ_old(a|s)
 *
 * @module learning/algorithms/PPOLearner
 * @version 1.0.0
 */
Object.defineProperty(exports, "__esModule", { value: true });
exports.PPOLearner = void 0;
exports.createDefaultPPOConfig = createDefaultPPOConfig;
const AbstractRLLearner_1 = require("./AbstractRLLearner");
/**
 * PPOLearner - Proximal Policy Optimization implementation
 *
 * PPO is a state-of-the-art policy gradient method that achieves
 * strong performance while being simpler than TRPO.
 *
 * Usage:
 * ```typescript
 * const ppo = new PPOLearner({
 *   learningRate: 0.0003,
 *   discountFactor: 0.99,
 *   explorationRate: 0.0,
 *   explorationDecay: 1.0,
 *   minExplorationRate: 0.0,
 *   clipEpsilon: 0.2,
 *   ppoEpochs: 4,
 *   miniBatchSize: 64,
 *   valueLossCoefficient: 0.5,
 *   entropyCoefficient: 0.01,
 *   gaeLambda: 0.95,
 *   maxGradNorm: 0.5,
 *   clipValueLoss: true,
 *   policyLearningRate: 0.0003,
 *   valueLearningRate: 0.001,
 *   useExperienceReplay: false,
 *   replayBufferSize: 2048,
 *   batchSize: 64
 * });
 *
 * // Collect trajectory
 * ppo.collectStep(state, action, reward, nextState, done);
 *
 * // Train on collected trajectory
 * ppo.trainOnTrajectory();
 * ```
 */
class PPOLearner extends AbstractRLLearner_1.AbstractRLLearner {
    constructor(config) {
        super(config);
        this.ppoConfig = config;
        this.policyTable = new Map();
        this.valueTable = new Map();
        this.oldPolicyTable = new Map();
        this.trajectory = [];
        this.defaultExploration = config.explorationRate;
        this.logger.info('PPOLearner initialized', {
            clipEpsilon: config.clipEpsilon,
            epochs: config.ppoEpochs,
            gaeLambda: config.gaeLambda,
            entropyCoeff: config.entropyCoefficient
        });
    }
    /**
     * Select action using current policy (softmax)
     */
    selectAction(state, availableActions) {
        if (availableActions.length === 0) {
            throw new Error('No available actions to select from');
        }
        const stateKey = this.encodeState(state);
        const probs = this.getActionProbabilities(stateKey, availableActions);
        // Sample from distribution
        const random = Math.random();
        let cumulative = 0;
        for (let i = 0; i < availableActions.length; i++) {
            cumulative += probs[i];
            if (random <= cumulative) {
                return availableActions[i];
            }
        }
        return availableActions[availableActions.length - 1];
    }
    /**
     * Get action probabilities using softmax policy
     */
    getActionProbabilities(stateKey, availableActions) {
        const preferences = [];
        for (const action of availableActions) {
            const actionKey = this.encodeAction(action);
            const params = this.getPolicyParams(stateKey, actionKey);
            preferences.push(params.preference);
        }
        // Softmax with numerical stability
        const maxPref = Math.max(...preferences);
        const expPrefs = preferences.map(p => Math.exp(p - maxPref));
        const sumExp = expPrefs.reduce((sum, e) => sum + e, 0);
        return expPrefs.map(e => e / sumExp);
    }
    /**
     * Get policy parameters for state-action pair
     */
    getPolicyParams(stateKey, actionKey) {
        const statePolicy = this.policyTable.get(stateKey);
        if (!statePolicy) {
            return { preference: 0, logProb: 0, updateCount: 0 };
        }
        return statePolicy.get(actionKey) ?? { preference: 0, logProb: 0, updateCount: 0 };
    }
    /**
     * Get log probability of action under current policy
     */
    getLogProb(stateKey, actionKey, availableActions) {
        // Get preference for target action
        const params = this.getPolicyParams(stateKey, actionKey);
        // If we don't know the action space, return stored log prob
        if (!availableActions) {
            return params.logProb;
        }
        // Calculate actual log probability
        const prefs = [];
        let targetPref = params.preference;
        for (const action of availableActions) {
            const ak = this.encodeAction(action);
            const p = this.getPolicyParams(stateKey, ak);
            prefs.push(p.preference);
            if (ak === actionKey) {
                targetPref = p.preference;
            }
        }
        const maxPref = Math.max(...prefs, targetPref);
        const expTarget = Math.exp(targetPref - maxPref);
        const sumExp = prefs.reduce((sum, p) => sum + Math.exp(p - maxPref), 0);
        return Math.log(expTarget / sumExp);
    }
    /**
     * Get state value from value network
     */
    getStateValue(state) {
        const stateKey = this.encodeState(state);
        return this.valueTable.get(stateKey) ?? 0;
    }
    /**
     * Collect a step in the trajectory
     */
    collectStep(state, action, reward, nextState, done) {
        const stateKey = this.encodeState(state);
        const actionKey = this.encodeAction(action);
        const nextStateKey = this.encodeState(nextState);
        const value = this.valueTable.get(stateKey) ?? 0;
        const logProb = this.getLogProb(stateKey, actionKey);
        this.trajectory.push({
            state: stateKey,
            action: actionKey,
            reward,
            nextState: nextStateKey,
            done,
            value,
            logProb,
            advantage: 0, // Computed later
            returns: 0 // Computed later
        });
    }
    /**
     * Standard update interface - collects experience and trains when ready
     */
    update(experience, nextAction) {
        this.stepCount++;
        const { state, action, reward, nextState } = experience;
        const done = experience.done ?? false;
        // Collect step
        this.collectStep(state, action, reward, nextState, done);
        // Train when trajectory is large enough
        if (this.trajectory.length >= this.ppoConfig.replayBufferSize) {
            this.trainOnTrajectory();
        }
    }
    /**
     * Train on collected trajectory using PPO
     */
    trainOnTrajectory() {
        if (this.trajectory.length === 0) {
            return;
        }
        // Compute advantages using GAE
        this.computeGAE();
        // Save old policy for ratio computation
        this.saveOldPolicy();
        // Multiple epochs of training
        for (let epoch = 0; epoch < this.ppoConfig.ppoEpochs; epoch++) {
            this.trainEpoch();
        }
        // Clear trajectory
        this.trajectory = [];
        this.logger.info('PPO training complete', {
            epochs: this.ppoConfig.ppoEpochs,
            steps: this.stepCount
        });
    }
    /**
     * Compute Generalized Advantage Estimation (GAE)
     *
     * GAE: Â_t = Σ_{l=0}^∞ (γλ)^l δ_{t+l}
     * where δ_t = r_t + γV(s_{t+1}) - V(s_t)
     */
    computeGAE() {
        const gamma = this.config.discountFactor;
        const lambda = this.ppoConfig.gaeLambda;
        let lastGaeLam = 0;
        const n = this.trajectory.length;
        // Compute returns and advantages backwards
        for (let t = n - 1; t >= 0; t--) {
            const step = this.trajectory[t];
            const nextValue = step.done
                ? 0
                : (t < n - 1 ? this.trajectory[t + 1].value : this.valueTable.get(step.nextState) ?? 0);
            // TD error
            const delta = step.reward + gamma * nextValue - step.value;
            // GAE advantage
            lastGaeLam = step.done
                ? delta
                : delta + gamma * lambda * lastGaeLam;
            step.advantage = lastGaeLam;
            step.returns = step.advantage + step.value;
        }
        // Normalize advantages
        const advantages = this.trajectory.map(s => s.advantage);
        const mean = advantages.reduce((s, a) => s + a, 0) / advantages.length;
        const variance = advantages.reduce((s, a) => s + (a - mean) ** 2, 0) / advantages.length;
        const std = Math.sqrt(variance) + 1e-8;
        for (const step of this.trajectory) {
            step.advantage = (step.advantage - mean) / std;
        }
    }
    /**
     * Save current policy as old policy for ratio computation
     */
    saveOldPolicy() {
        this.oldPolicyTable.clear();
        for (const [state, actions] of this.policyTable.entries()) {
            const actionMap = new Map();
            for (const [action, params] of actions.entries()) {
                actionMap.set(action, { ...params });
            }
            this.oldPolicyTable.set(state, actionMap);
        }
    }
    /**
     * Get old log probability for ratio computation
     */
    getOldLogProb(stateKey, actionKey) {
        const statePolicy = this.oldPolicyTable.get(stateKey);
        if (!statePolicy) {
            return 0;
        }
        return statePolicy.get(actionKey)?.logProb ?? 0;
    }
    /**
     * Train one epoch on the trajectory
     */
    trainEpoch() {
        // Shuffle trajectory
        const shuffled = [...this.trajectory].sort(() => Math.random() - 0.5);
        // Mini-batch updates
        for (let i = 0; i < shuffled.length; i += this.ppoConfig.miniBatchSize) {
            const batch = shuffled.slice(i, i + this.ppoConfig.miniBatchSize);
            this.trainMiniBatch(batch);
        }
    }
    /**
     * Train on a mini-batch
     */
    trainMiniBatch(batch) {
        for (const step of batch) {
            // Compute probability ratio
            const newLogProb = this.getLogProb(step.state, step.action);
            const oldLogProb = step.logProb; // Use stored log prob
            const ratio = Math.exp(newLogProb - oldLogProb);
            // Compute clipped and unclipped objectives
            const eps = this.ppoConfig.clipEpsilon;
            const surr1 = ratio * step.advantage;
            const surr2 = Math.max(Math.min(ratio, 1 + eps), 1 - eps) * step.advantage;
            // Policy loss (negative because we want to maximize)
            const policyLoss = -Math.min(surr1, surr2);
            // Value loss
            const valueTarget = step.returns;
            const currentValue = this.valueTable.get(step.state) ?? 0;
            let valueLoss = (currentValue - valueTarget) ** 2;
            // Clip value loss if enabled
            if (this.ppoConfig.clipValueLoss) {
                const clippedValue = step.value + Math.max(Math.min(currentValue - step.value, eps), -eps);
                const clippedValueLoss = (clippedValue - valueTarget) ** 2;
                valueLoss = Math.max(valueLoss, clippedValueLoss);
            }
            // Entropy bonus
            const entropy = this.computeEntropy(step.state);
            const entropyLoss = -this.ppoConfig.entropyCoefficient * entropy;
            // Total loss
            const totalLoss = policyLoss + this.ppoConfig.valueLossCoefficient * valueLoss + entropyLoss;
            // Update policy (gradient ascent direction)
            this.updatePolicy(step.state, step.action, step.advantage, ratio);
            // Update value function
            this.updateValue(step.state, valueTarget);
        }
    }
    /**
     * Update policy parameters
     */
    updatePolicy(stateKey, actionKey, advantage, ratio) {
        if (!this.policyTable.has(stateKey)) {
            this.policyTable.set(stateKey, new Map());
        }
        const statePolicy = this.policyTable.get(stateKey);
        const current = statePolicy.get(actionKey) ?? { preference: 0, logProb: 0, updateCount: 0 };
        // Clipped gradient
        const eps = this.ppoConfig.clipEpsilon;
        let gradient = advantage;
        if ((ratio > 1 + eps && advantage > 0) || (ratio < 1 - eps && advantage < 0)) {
            gradient = 0; // Clipped - no update
        }
        // Update preference
        const newPreference = current.preference + this.ppoConfig.policyLearningRate * gradient;
        const newLogProb = this.getLogProb(stateKey, actionKey);
        statePolicy.set(actionKey, {
            preference: newPreference,
            logProb: newLogProb,
            updateCount: current.updateCount + 1
        });
        // Update Q-table for compatibility
        this.setQValue(stateKey, actionKey, newPreference);
    }
    /**
     * Update value function
     */
    updateValue(stateKey, target) {
        const current = this.valueTable.get(stateKey) ?? 0;
        const newValue = current + this.ppoConfig.valueLearningRate * (target - current);
        this.valueTable.set(stateKey, newValue);
    }
    /**
     * Compute entropy of policy at state
     */
    computeEntropy(stateKey) {
        const statePolicy = this.policyTable.get(stateKey);
        if (!statePolicy || statePolicy.size === 0) {
            return 0;
        }
        const prefs = Array.from(statePolicy.values()).map(p => p.preference);
        const maxPref = Math.max(...prefs);
        const expPrefs = prefs.map(p => Math.exp(p - maxPref));
        const sumExp = expPrefs.reduce((s, e) => s + e, 0);
        const probs = expPrefs.map(e => e / sumExp);
        let entropy = 0;
        for (const p of probs) {
            if (p > 0) {
                entropy -= p * Math.log(p);
            }
        }
        return entropy;
    }
    /**
     * Get default exploration rate for reset
     */
    getDefaultExplorationRate() {
        return this.defaultExploration;
    }
    /**
     * Get PPO-specific statistics
     */
    getPPOStatistics() {
        let totalValue = 0;
        for (const v of this.valueTable.values()) {
            totalValue += v;
        }
        let policySize = 0;
        for (const statePolicy of this.policyTable.values()) {
            policySize += statePolicy.size;
        }
        const avgAdvantage = this.trajectory.length > 0
            ? this.trajectory.reduce((s, t) => s + t.advantage, 0) / this.trajectory.length
            : 0;
        return {
            trajectoryLength: this.trajectory.length,
            valueTableSize: this.valueTable.size,
            policyTableSize: policySize,
            avgValue: this.valueTable.size > 0 ? totalValue / this.valueTable.size : 0,
            avgAdvantage,
            clipFraction: 0 // Would need tracking during training
        };
    }
    /**
     * Reset PPO-specific state
     */
    reset() {
        super.reset();
        this.policyTable.clear();
        this.valueTable.clear();
        this.oldPolicyTable.clear();
        this.trajectory = [];
        this.logger.info('PPOLearner reset');
    }
    /**
     * Export PPO state
     */
    exportPPO() {
        const serializedPolicy = {};
        for (const [state, actions] of this.policyTable.entries()) {
            serializedPolicy[state] = {};
            for (const [action, params] of actions.entries()) {
                serializedPolicy[state][action] = params;
            }
        }
        const serializedValue = {};
        for (const [state, value] of this.valueTable.entries()) {
            serializedValue[state] = value;
        }
        return {
            base: this.export(),
            policyTable: serializedPolicy,
            valueTable: serializedValue,
            ppoConfig: { ...this.ppoConfig }
        };
    }
    /**
     * Import PPO state
     */
    importPPO(state) {
        this.import(state.base);
        this.policyTable.clear();
        for (const [stateKey, actions] of Object.entries(state.policyTable)) {
            const actionMap = new Map();
            for (const [actionKey, params] of Object.entries(actions)) {
                actionMap.set(actionKey, params);
            }
            this.policyTable.set(stateKey, actionMap);
        }
        this.valueTable.clear();
        for (const [stateKey, value] of Object.entries(state.valueTable)) {
            this.valueTable.set(stateKey, value);
        }
        this.ppoConfig = { ...state.ppoConfig };
        this.logger.info('Imported PPO state', {
            policySize: this.policyTable.size,
            valueSize: this.valueTable.size
        });
    }
}
exports.PPOLearner = PPOLearner;
/**
 * Create default PPO configuration
 */
function createDefaultPPOConfig() {
    return {
        learningRate: 0.0003,
        discountFactor: 0.99,
        explorationRate: 0.0, // PPO uses entropy for exploration
        explorationDecay: 1.0,
        minExplorationRate: 0.0,
        clipEpsilon: 0.2,
        ppoEpochs: 4,
        miniBatchSize: 64,
        valueLossCoefficient: 0.5,
        entropyCoefficient: 0.01,
        gaeLambda: 0.95,
        maxGradNorm: 0.5,
        clipValueLoss: true,
        policyLearningRate: 0.0003,
        valueLearningRate: 0.001,
        useExperienceReplay: false, // PPO doesn't use replay buffer
        replayBufferSize: 2048, // Used as trajectory buffer size
        batchSize: 64
    };
}
//# sourceMappingURL=PPOLearner.js.map
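To make the clipping arithmetic in `trainMiniBatch()` and the GAE recursion in `computeGAE()` concrete, here is a standalone sketch with made-up numbers; it is not part of the package and only mirrors the expressions shown in the diff above.

```typescript
// Illustration of the PPO-Clip surrogate with the default clipEpsilon = 0.2.
const eps = 0.2;
const ratio = Math.exp(-0.1 - (-0.4)); // r(θ) = exp(newLogProb - oldLogProb) ≈ 1.35
const advantage = 1.0;

const surr1 = ratio * advantage;                                       // unclipped ≈ 1.35
const surr2 = Math.max(Math.min(ratio, 1 + eps), 1 - eps) * advantage; // clipped to 1.2
const policyLoss = -Math.min(surr1, surr2);                            // -1.2: the clip caps the step

// One backward GAE step: δ_t = r_t + γ·V(s_{t+1}) − V(s_t), Â_t = δ_t + γλ·Â_{t+1}
const gamma = 0.99, lambda = 0.95;
const delta = 1.0 + gamma * 0.5 - 0.4;           // reward 1.0, V(s') = 0.5, V(s) = 0.4
const advantageT = delta + gamma * lambda * 0.8; // Â_{t+1} = 0.8 from the later step
console.log({ surr1, surr2, policyLoss, advantageT });
```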