scratchkit 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlscratch/__init__.py +56 -0
- mlscratch/__main__.py +118 -0
- mlscratch/bayesian/__init__.py +53 -0
- mlscratch/bayesian/bayesian_linear_regression.py +171 -0
- mlscratch/bayesian/bayesian_network.py +248 -0
- mlscratch/bayesian/bayesian_nn.py +315 -0
- mlscratch/bayesian/gaussian_process.py +207 -0
- mlscratch/bayesian/hmm.py +277 -0
- mlscratch/bayesian/init.py +52 -0
- mlscratch/bayesian/kalman_filter.py +182 -0
- mlscratch/bayesian/naive_bayes.py +209 -0
- mlscratch/metrics/__init__.py +59 -0
- mlscratch/metrics/classification.py +365 -0
- mlscratch/metrics/regression.py +79 -0
- mlscratch/neural/__init__.py +121 -0
- mlscratch/neural/attention.py +420 -0
- mlscratch/neural/autoencoder.py +543 -0
- mlscratch/neural/boltzmann.py +231 -0
- mlscratch/neural/cnn.py +593 -0
- mlscratch/neural/cvnn.py +322 -0
- mlscratch/neural/gan.py +364 -0
- mlscratch/neural/hopfield.py +193 -0
- mlscratch/neural/perceptron.py +398 -0
- mlscratch/neural/rbf_network.py +230 -0
- mlscratch/neural/recurrent.py +569 -0
- mlscratch/preprocessing/__init__.py +38 -0
- mlscratch/preprocessing/encoders.py +140 -0
- mlscratch/preprocessing/model_selection.py +119 -0
- mlscratch/preprocessing/polynomial.py +105 -0
- mlscratch/preprocessing/scalers.py +220 -0
- mlscratch/py.typed +0 -0
- mlscratch/reinforcement/__init__.py +59 -0
- mlscratch/reinforcement/ddpg.py +363 -0
- mlscratch/reinforcement/dqn.py +319 -0
- mlscratch/reinforcement/ppo.py +452 -0
- mlscratch/reinforcement/q_learning.py +352 -0
- mlscratch/reinforcement/sac.py +382 -0
- mlscratch/reinforcement/utils.py +594 -0
- mlscratch/supervised/__init__.py +76 -0
- mlscratch/supervised/_validation.py +50 -0
- mlscratch/supervised/adaboost.py +255 -0
- mlscratch/supervised/decision_tree.py +495 -0
- mlscratch/supervised/gradient_boosting.py +354 -0
- mlscratch/supervised/knn.py +234 -0
- mlscratch/supervised/lasso_regression.py +125 -0
- mlscratch/supervised/linear_models.py +459 -0
- mlscratch/supervised/linear_regression.py +197 -0
- mlscratch/supervised/logistic_regression.py +119 -0
- mlscratch/supervised/naive_bayes.py +113 -0
- mlscratch/supervised/random_forest.py +321 -0
- mlscratch/supervised/ridge_regression.py +93 -0
- mlscratch/supervised/svm.py +356 -0
- mlscratch/unsupervised/__init__.py +39 -0
- mlscratch/unsupervised/apriori.py +178 -0
- mlscratch/unsupervised/dbscan.py +141 -0
- mlscratch/unsupervised/gmm.py +204 -0
- mlscratch/unsupervised/hierarchical_clustering.py +137 -0
- mlscratch/unsupervised/ica.py +167 -0
- mlscratch/unsupervised/kmeans.py +135 -0
- mlscratch/unsupervised/kmedoids.py +133 -0
- mlscratch/unsupervised/pca.py +103 -0
- mlscratch/unsupervised/tsne.py +200 -0
- scratchkit-0.2.0.dist-info/METADATA +241 -0
- scratchkit-0.2.0.dist-info/RECORD +68 -0
- scratchkit-0.2.0.dist-info/WHEEL +5 -0
- scratchkit-0.2.0.dist-info/entry_points.txt +2 -0
- scratchkit-0.2.0.dist-info/licenses/LICENSE +201 -0
- scratchkit-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,452 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Proximal Policy Optimization (PPO)
|
|
3
|
+
====================================
|
|
4
|
+
On-policy actor-critic algorithm that constrains the policy update via a
|
|
5
|
+
clipped surrogate objective, preventing destructively large steps.
|
|
6
|
+
|
|
7
|
+
Supports both:
|
|
8
|
+
- PPO-Clip (Schulman et al., 2017) — clips the probability ratio
|
|
9
|
+
- PPO-KL — uses an adaptive KL penalty instead of clipping
|
|
10
|
+
|
|
11
|
+
Architecture
|
|
12
|
+
------------
|
|
13
|
+
- Actor (policy) π_θ(a|s) — outputs logits for discrete or mean/log-std
|
|
14
|
+
for continuous actions
|
|
15
|
+
- Critic (value) V_φ(s) — baseline for advantage estimation
|
|
16
|
+
|
|
17
|
+
Training procedure per iteration
|
|
18
|
+
---------------------------------
|
|
19
|
+
1. Collect T timesteps with current policy (rollout)
|
|
20
|
+
2. Compute advantages Â_t using Generalised Advantage Estimation (GAE)
|
|
21
|
+
3. Run K epochs of minibatch SGD on the clipped surrogate + value + entropy losses:
|
|
22
|
+
|
|
23
|
+
L = E[ min(r_t Â_t, clip(r_t, 1-ε, 1+ε) Â_t) ]
|
|
24
|
+
- c_v (V_t - V_target)²
|
|
25
|
+
+ c_e H[π_θ(·|s_t)]
|
|
26
|
+
|
|
27
|
+
where r_t = π_θ(a_t|s_t) / π_θ_old(a_t|s_t)
|
|
28
|
+
|
|
29
|
+
Only numpy and Python stdlib are used.
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
from __future__ import annotations
|
|
33
|
+
import numpy as np
|
|
34
|
+
from .utils import MLP
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _softmax(x: np.ndarray) -> np.ndarray:
|
|
38
|
+
e = np.exp(x - x.max(axis=-1, keepdims=True))
|
|
39
|
+
return e / e.sum(axis=-1, keepdims=True)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _log_softmax(x: np.ndarray) -> np.ndarray:
|
|
43
|
+
return x - np.log(np.exp(x - x.max(axis=-1, keepdims=True)).sum(
|
|
44
|
+
axis=-1, keepdims=True)) - x.max(axis=-1, keepdims=True) + x.max(axis=-1, keepdims=True)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class PPO:
    """
    PPO agent supporting discrete and continuous action spaces.

    Parameters
    ----------
    state_dim     : int
    n_actions     : int         number of discrete actions (discrete mode)
    action_dim    : int | None  dimension of continuous actions; if not None,
                                continuous mode is used
    action_low    : float       lower bound sampled actions are clipped to
    action_high   : float       upper bound sampled actions are clipped to
    hidden_sizes  : list[int]   hidden layer widths (default [64, 64])
    actor_lr      : float
    critic_lr     : float
    gamma         : float       discount factor
    lam           : float       GAE lambda
    clip_eps      : float       PPO clip ε (0 = use KL mode)
    kl_target     : float       target KL divergence (KL mode)
    kl_beta       : float       initial KL penalty coefficient (KL mode);
                                doubled / halved after each update depending
                                on the measured KL
    value_coef    : float       critic loss weight
    entropy_coef  : float       entropy bonus weight
    n_epochs      : int         gradient epochs per iteration
    batch_size    : int
    rollout_len   : int         steps collected per iteration
    max_grad_norm : float       gradient clipping by norm
    random_state  : int | None

    Notes
    -----
    The networks are ``MLP`` instances from ``.utils``.  This class assumes
    ``MLP.backward(g)`` takes dLoss/dOutput for the most recent
    ``forward(..., training=True)`` call and applies a gradient-descent step
    (that is how the critic is trained below) -- confirm against ``utils.MLP``.

    NOTE(review): ``max_grad_norm`` is stored but nothing in this class
    applies it; clipping only happens if ``MLP`` does it internally.
    """

    # Bounds applied to the actor's raw log-std output (continuous mode).
    _LOG_STD_MIN = -5.0
    _LOG_STD_MAX = 2.0
    # Bound on |log ratio| so exp() cannot overflow during an update.
    _MAX_LOG_RATIO = 20.0

    def __init__(
        self,
        state_dim: int,
        n_actions: int = 4,
        action_dim: int | None = None,
        action_low: float = -1.0,
        action_high: float = 1.0,
        hidden_sizes: list[int] | None = None,
        actor_lr: float = 3e-4,
        critic_lr: float = 1e-3,
        gamma: float = 0.99,
        lam: float = 0.95,
        clip_eps: float = 0.2,
        kl_target: float = 0.01,
        kl_beta: float = 1.0,
        value_coef: float = 0.5,
        entropy_coef: float = 0.01,
        n_epochs: int = 10,
        batch_size: int = 64,
        rollout_len: int = 2048,
        max_grad_norm: float = 0.5,
        random_state: int | None = None,
    ):
        self.continuous = action_dim is not None
        self.n_actions = action_dim if self.continuous else n_actions
        self.action_dim = action_dim
        self.action_low = action_low
        self.action_high = action_high
        self.gamma = gamma
        self.lam = lam
        self.clip_eps = clip_eps
        self.use_kl = (clip_eps == 0.0)
        self.kl_target = kl_target
        self.kl_beta = kl_beta
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.rollout_len = rollout_len
        self.max_grad_norm = max_grad_norm
        self._rng = np.random.default_rng(random_state)

        hidden = list(hidden_sizes) if hidden_sizes else [64, 64]

        # Continuous: actor outputs [mean, log_std] concatenated (2 * action_dim).
        # Discrete:   actor outputs one logit per action.
        actor_out = action_dim * 2 if self.continuous else n_actions
        self.actor = MLP([state_dim] + hidden + [actor_out],
                         output_activation="linear", lr=actor_lr,
                         random_state=random_state)

        self.critic = MLP([state_dim] + hidden + [1],
                          output_activation="linear", lr=critic_lr,
                          random_state=random_state)

        # Rollout buffers
        self._reset_rollout()

        # Logging
        self.policy_losses_: list[float] = []
        self.value_losses_: list[float] = []
        self.entropies_: list[float] = []
        self.episode_rewards_: list[float] = []
        self._ep_reward = 0.0
        self._ep_steps = 0

    # ------------------------------------------------------------------
    # Rollout buffer
    # ------------------------------------------------------------------

    def _reset_rollout(self) -> None:
        """Empty every rollout buffer."""
        self._states: list[np.ndarray] = []
        self._actions: list[np.ndarray] = []
        self._log_probs: list[float] = []
        self._rewards: list[float] = []
        self._values: list[float] = []
        self._dones: list[bool] = []

    def _drop_pending(self) -> None:
        """Discard stored decisions that never received a reward.

        Keeps the per-decision buffers the same length as the reward buffer
        so index ``i`` always refers to one transition.
        """
        n = len(self._rewards)
        for buf in (self._states, self._actions, self._log_probs, self._values):
            del buf[n:]

    # ------------------------------------------------------------------
    # Policy helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _gaussian_stats(actions, mean, log_std):
        """Log-density and entropy of a diagonal Gaussian, summed over the last axis."""
        std = np.exp(log_std)
        log_probs = -0.5 * np.sum(
            ((actions - mean) / (std + 1e-8)) ** 2
            + 2 * log_std + np.log(2 * np.pi),
            axis=-1,
        )
        entropies = np.sum(log_std + 0.5 * np.log(2 * np.pi * np.e), axis=-1)
        return log_probs, entropies

    def _distribution_stats(
        self,
        out: np.ndarray,
        actions: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Log-probs of *actions* and entropies given batched actor output *out*."""
        if self.continuous:
            ad = self.action_dim
            log_std = np.clip(out[:, ad:], self._LOG_STD_MIN, self._LOG_STD_MAX)
            return self._gaussian_stats(actions, out[:, :ad], log_std)

        probs = _softmax(out)                                   # (B, A)
        act_idx = actions.ravel().astype(int)
        log_probs = np.log(probs[np.arange(len(probs)), act_idx] + 1e-8)
        entropies = -np.sum(probs * np.log(probs + 1e-8), axis=1)
        return log_probs, entropies

    def _actor_forward(self, state: np.ndarray) -> tuple:
        """
        Sample one action for a single state.

        Returns (action, log_prob, entropy).  ``action`` is an array: the
        clipped Gaussian sample in continuous mode, or a length-1 array
        holding the sampled action index in discrete mode.
        """
        out = self.actor.forward(state)

        if self.continuous:
            ad = self.action_dim
            mean = out[:ad]
            log_std = np.clip(out[ad:], self._LOG_STD_MIN, self._LOG_STD_MAX)
            action = mean + np.exp(log_std) * self._rng.standard_normal(ad)
            action = np.clip(action, self.action_low, self.action_high)

            # Simplification: the Gaussian density is evaluated at the
            # clipped action; no correction is made for the clipping.
            log_prob, entropy = self._gaussian_stats(action, mean, log_std)
            return action, float(log_prob), float(entropy)

        probs = _softmax(out)
        action = int(self._rng.choice(len(probs), p=probs))
        log_prob = float(np.log(probs[action] + 1e-8))
        entropy = float(-np.sum(probs * np.log(probs + 1e-8)))
        return np.array([action]), log_prob, entropy

    def _log_prob_batch(
        self,
        states: np.ndarray,
        actions: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Batch log-probs and entropies under current policy."""
        out = self.actor.forward(states, training=True)   # (B, out_dim)
        return self._distribution_stats(out, actions)

    # ------------------------------------------------------------------
    # GAE advantage estimation
    # ------------------------------------------------------------------

    def _compute_gae(
        self,
        rewards: np.ndarray,
        values: np.ndarray,
        dones: np.ndarray,
        last_value: float,
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Returns advantages and returns (value targets).

        ``last_value`` bootstraps the step after the final one; a ``dones``
        entry of 1 cuts both the bootstrap and the running GAE sum at that
        step.
        """
        T = len(rewards)
        advantages = np.zeros(T)
        gae = 0.0

        for t in reversed(range(T)):
            next_val = last_value if t == T - 1 else values[t + 1]
            mask = 1.0 - dones[t]
            delta = rewards[t] + self.gamma * next_val * mask - values[t]
            gae = delta + self.gamma * self.lam * mask * gae
            advantages[t] = gae

        returns = advantages + values
        return advantages, returns

    # ------------------------------------------------------------------
    # Learning
    # ------------------------------------------------------------------

    def _surrogate(
        self,
        log_ratio: np.ndarray,
        adv: np.ndarray,
    ) -> tuple[float, np.ndarray, float]:
        """
        Policy loss for one minibatch.

        Returns (loss, dLoss/d(new_log_prob) per sample *before* dividing by
        the batch size, mean KL estimate).  The KL estimate is 0.0 in clip
        mode.
        """
        ratio = np.exp(log_ratio)
        surr1 = ratio * adv

        if self.use_kl:
            # (r - 1) - log r is a non-negative estimate of KL(old || new).
            kl = (ratio - 1.0) - log_ratio
            loss = float(-(surr1 - self.kl_beta * kl).mean())
            # d r / d logp = r   and   d kl / d logp = r - 1
            grad = -(adv * ratio - self.kl_beta * (ratio - 1.0))
            return loss, grad, float(kl.mean())

        surr2 = np.clip(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * adv
        loss = float(-np.minimum(surr1, surr2).mean())
        # Where the clipped term is the strictly smaller one the objective is
        # flat in the ratio, so those samples contribute no gradient.
        unclipped = surr1 <= surr2
        grad = -(adv * ratio) * unclipped
        return loss, grad, 0.0

    def _actor_output_grad(
        self,
        out: np.ndarray,
        actions: np.ndarray,
        d_logp: np.ndarray,
    ) -> np.ndarray:
        """
        Chain dLoss/d(log_prob) back to the raw actor output and add the
        entropy-bonus gradient.

        ``d_logp`` has one entry per sample and is already divided by the
        batch size.
        """
        batch = len(out)
        d_logp = d_logp[:, np.newaxis]

        if self.continuous:
            ad = self.action_dim
            mean = out[:, :ad]
            raw_log_std = out[:, ad:]
            log_std = np.clip(raw_log_std, self._LOG_STD_MIN, self._LOG_STD_MAX)
            std = np.exp(log_std)
            z = (actions - mean) / (std + 1e-8)

            d_mean = d_logp * z / (std + 1e-8)          # d logp / d mean = z / std
            d_log_std = d_logp * (z ** 2 - 1.0)         # d logp / d log_std = z^2 - 1
            d_log_std = d_log_std - self.entropy_coef / batch   # d(-c_e H)/d log_std
            # np.clip passes no gradient where the raw output is saturated.
            in_range = ((raw_log_std >= self._LOG_STD_MIN)
                        & (raw_log_std <= self._LOG_STD_MAX))
            return np.concatenate([d_mean, d_log_std * in_range], axis=1)

        probs = _softmax(out)
        one_hot = np.zeros_like(probs)
        one_hot[np.arange(batch), actions.ravel().astype(int)] = 1.0
        # d log p(a) / d logits = one_hot(a) - probs
        grad = d_logp * (one_hot - probs)
        # d(-c_e H) / d logits_k = c_e * p_k * (log p_k + H)
        log_p = np.log(probs + 1e-8)
        ent = -np.sum(probs * log_p, axis=1, keepdims=True)
        grad += self.entropy_coef * probs * (log_p + ent) / batch
        return grad

    def _update(
        self,
        states: np.ndarray,
        actions: np.ndarray,
        old_log_probs: np.ndarray,
        advantages: np.ndarray,
        returns: np.ndarray,
    ) -> tuple[float, float, float]:
        """
        Run n_epochs of minibatch updates.  Returns mean losses
        (policy loss, value loss, entropy).

        In KL mode ``self.kl_beta`` is adapted afterwards: doubled when the
        mean KL estimate exceeds 1.5 * kl_target, halved when it falls below
        kl_target / 1.5.
        """
        n = len(states)
        if n == 0:
            return 0.0, 0.0, 0.0

        policy_losses, value_losses, entropies, kls = [], [], [], []

        for _ in range(self.n_epochs):
            order = self._rng.permutation(n)
            for start in range(0, n, self.batch_size):
                mb = order[start:start + self.batch_size]
                size = len(mb)
                s_b = states[mb]
                a_b = actions[mb]
                ret_b = returns[mb]

                # Normalise advantages per minibatch
                adv_b = advantages[mb]
                adv_b = (adv_b - adv_b.mean()) / (adv_b.std() + 1e-8)

                # One training-mode forward pass feeds both the loss and the
                # backward pass below.
                out = self.actor.forward(s_b, training=True)
                new_log_probs, ent_b = self._distribution_stats(out, a_b)

                log_ratio = np.clip(new_log_probs - old_log_probs[mb],
                                    -self._MAX_LOG_RATIO, self._MAX_LOG_RATIO)
                policy_loss, d_logp, kl = self._surrogate(log_ratio, adv_b)

                # --- Critic: mean-squared error against the GAE returns ---
                v_pred = self.critic.forward(s_b, training=True).ravel()
                err = v_pred - ret_b
                value_loss = float(np.mean(err ** 2))
                dv = 2.0 * err[:, np.newaxis] / size
                self.critic.backward(self.value_coef * dv)

                # --- Actor ---
                self.actor.backward(
                    self._actor_output_grad(out, a_b, d_logp / size))

                policy_losses.append(policy_loss)
                value_losses.append(value_loss)
                entropies.append(float(ent_b.mean()))
                if self.use_kl:
                    kls.append(kl)

        # Adaptive KL penalty
        if self.use_kl and kls:
            mean_kl = float(np.mean(kls))
            if mean_kl > 1.5 * self.kl_target:
                self.kl_beta *= 2.0
            elif mean_kl < self.kl_target / 1.5:
                self.kl_beta *= 0.5

        return (float(np.mean(policy_losses)),
                float(np.mean(value_losses)),
                float(np.mean(entropies)))

    # ------------------------------------------------------------------
    # Step / episode / train
    # ------------------------------------------------------------------

    def step(
        self,
        state: np.ndarray,
        reward: float | None = None,
        done: bool = False,
        next_state: np.ndarray | None = None,
    ) -> np.ndarray:
        """
        Collect one interaction step.

        Call as:
            action = agent.step(state)                   # first call per env step
            ...
            agent.step(state, reward, done, next_state)  # after env step

        The first call samples an action and stores the decision.  The second
        call attaches ``reward`` / ``done`` to that stored decision without
        sampling again, and returns the same action.

        If ``reward`` is given while no decision is waiting for one, the call
        samples an action for ``state`` and records the reward in the same
        call.

        Once ``rollout_len`` rewards are stored a policy update runs,
        bootstrapping from ``next_state`` (or ``state`` when it is None).

        Returns action (from first call).
        """
        n_recorded = len(self._rewards)

        if reward is not None and len(self._states) > n_recorded:
            # Second call of the protocol: finish the stored decision.
            action = self._actions[n_recorded]
        else:
            if reward is None:
                # A new decision replaces any earlier one that never got its
                # reward, so buffers cannot drift out of alignment.
                self._drop_pending()
            action, log_prob, _ = self._actor_forward(state)
            value = float(np.asarray(self.critic.forward(state)).ravel()[0])

            self._states.append(state.copy())
            self._actions.append(action.copy())
            self._log_probs.append(log_prob)
            self._values.append(value)

        if reward is not None:
            self._rewards.append(reward)
            self._dones.append(done)
            self._ep_reward += reward
            if done:
                self.episode_rewards_.append(self._ep_reward)
                self._ep_reward = 0.0

            # Trigger update when rollout is full
            if len(self._rewards) >= self.rollout_len:
                # Explicit None test: the truth value of an array is
                # ambiguous and raises for more than one element.
                bootstrap = state if next_state is None else next_state
                self._flush_rollout(bootstrap)

        return action

    def _flush_rollout(self, last_state: np.ndarray) -> None:
        """Run one policy update on the stored transitions, then clear the buffers.

        ``last_state`` is passed through the critic to bootstrap the value
        after the final stored step.  Does nothing when no reward is stored.
        """
        n = len(self._rewards)
        if n == 0:
            return

        # Only decisions that received a reward take part in the update.
        states = np.stack(self._states[:n])
        actions = np.stack(self._actions[:n])
        old_log_probs = np.array(self._log_probs[:n])
        rewards = np.array(self._rewards)
        values = np.array(self._values[:n])
        dones = np.array(self._dones, dtype=float)

        last_value = float(np.asarray(self.critic.forward(last_state)).ravel()[0])
        advantages, returns = self._compute_gae(rewards, values, dones, last_value)

        pl, vl, ent = self._update(states, actions, old_log_probs,
                                   advantages, returns)
        self.policy_losses_.append(pl)
        self.value_losses_.append(vl)
        self.entropies_.append(ent)

        self._reset_rollout()

    def train_episode(self, env) -> float:
        """
        Run one full episode using step-by-step collection.

        ``env.reset`` is called with this agent's RNG when ``env`` has a
        ``_rng`` attribute, otherwise with no arguments; ``env.step(action)``
        must return ``(next_state, reward, done)``.  Discrete actions are
        passed to the environment as a plain ``int``.

        A policy update runs whenever the rollout buffer fills and once more
        at the end of the episode on whatever is left.  Returns the episode's
        total reward.
        """
        self._drop_pending()
        state = env.reset() if not hasattr(env, '_rng') else env.reset(self._rng)
        total_reward = 0.0
        done = False

        while not done:
            # Sample once so the stored log-prob belongs to the action taken.
            action, log_prob, _ = self._actor_forward(state)
            value = float(np.asarray(self.critic.forward(state)).ravel()[0])
            self._states.append(state.copy())
            self._actions.append(action.copy())
            self._log_probs.append(log_prob)
            self._values.append(value)

            next_state, reward, done = env.step(
                action if self.continuous else int(action[0])
            )
            self._rewards.append(reward)
            self._dones.append(done)
            state = next_state
            total_reward += reward

            if len(self._rewards) >= self.rollout_len:
                self._flush_rollout(next_state)

        self.episode_rewards_.append(total_reward)

        if self._rewards:
            self._flush_rollout(state)

        return total_reward

    def train(self, env, n_episodes: int) -> "PPO":
        """Call ``train_episode`` *n_episodes* times and return ``self``."""
        for _ in range(n_episodes):
            self.train_episode(env)
        return self

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------

    def predict(self, state: np.ndarray, deterministic: bool = True) -> np.ndarray:
        """
        Return action for state.

        Deterministic = clipped mean (continuous) / argmax (discrete).
        With ``deterministic=False`` an action is sampled from the policy
        instead; nothing is stored in the rollout buffers either way.
        """
        if not deterministic:
            return self._actor_forward(state)[0]

        out = self.actor.forward(state)
        if self.continuous:
            mean = out[:self.action_dim]
            return np.clip(mean, self.action_low, self.action_high)
        return np.array([int(np.argmax(out))])
|