scratchkit 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlscratch/__init__.py +56 -0
- mlscratch/__main__.py +118 -0
- mlscratch/bayesian/__init__.py +53 -0
- mlscratch/bayesian/bayesian_linear_regression.py +171 -0
- mlscratch/bayesian/bayesian_network.py +248 -0
- mlscratch/bayesian/bayesian_nn.py +315 -0
- mlscratch/bayesian/gaussian_process.py +207 -0
- mlscratch/bayesian/hmm.py +277 -0
- mlscratch/bayesian/init.py +52 -0
- mlscratch/bayesian/kalman_filter.py +182 -0
- mlscratch/bayesian/naive_bayes.py +209 -0
- mlscratch/metrics/__init__.py +59 -0
- mlscratch/metrics/classification.py +365 -0
- mlscratch/metrics/regression.py +79 -0
- mlscratch/neural/__init__.py +121 -0
- mlscratch/neural/attention.py +420 -0
- mlscratch/neural/autoencoder.py +543 -0
- mlscratch/neural/boltzmann.py +231 -0
- mlscratch/neural/cnn.py +593 -0
- mlscratch/neural/cvnn.py +322 -0
- mlscratch/neural/gan.py +364 -0
- mlscratch/neural/hopfield.py +193 -0
- mlscratch/neural/perceptron.py +398 -0
- mlscratch/neural/rbf_network.py +230 -0
- mlscratch/neural/recurrent.py +569 -0
- mlscratch/preprocessing/__init__.py +38 -0
- mlscratch/preprocessing/encoders.py +140 -0
- mlscratch/preprocessing/model_selection.py +119 -0
- mlscratch/preprocessing/polynomial.py +105 -0
- mlscratch/preprocessing/scalers.py +220 -0
- mlscratch/py.typed +0 -0
- mlscratch/reinforcement/__init__.py +59 -0
- mlscratch/reinforcement/ddpg.py +363 -0
- mlscratch/reinforcement/dqn.py +319 -0
- mlscratch/reinforcement/ppo.py +452 -0
- mlscratch/reinforcement/q_learning.py +352 -0
- mlscratch/reinforcement/sac.py +382 -0
- mlscratch/reinforcement/utils.py +594 -0
- mlscratch/supervised/__init__.py +76 -0
- mlscratch/supervised/_validation.py +50 -0
- mlscratch/supervised/adaboost.py +255 -0
- mlscratch/supervised/decision_tree.py +495 -0
- mlscratch/supervised/gradient_boosting.py +354 -0
- mlscratch/supervised/knn.py +234 -0
- mlscratch/supervised/lasso_regression.py +125 -0
- mlscratch/supervised/linear_models.py +459 -0
- mlscratch/supervised/linear_regression.py +197 -0
- mlscratch/supervised/logistic_regression.py +119 -0
- mlscratch/supervised/naive_bayes.py +113 -0
- mlscratch/supervised/random_forest.py +321 -0
- mlscratch/supervised/ridge_regression.py +93 -0
- mlscratch/supervised/svm.py +356 -0
- mlscratch/unsupervised/__init__.py +39 -0
- mlscratch/unsupervised/apriori.py +178 -0
- mlscratch/unsupervised/dbscan.py +141 -0
- mlscratch/unsupervised/gmm.py +204 -0
- mlscratch/unsupervised/hierarchical_clustering.py +137 -0
- mlscratch/unsupervised/ica.py +167 -0
- mlscratch/unsupervised/kmeans.py +135 -0
- mlscratch/unsupervised/kmedoids.py +133 -0
- mlscratch/unsupervised/pca.py +103 -0
- mlscratch/unsupervised/tsne.py +200 -0
- scratchkit-0.2.0.dist-info/METADATA +241 -0
- scratchkit-0.2.0.dist-info/RECORD +68 -0
- scratchkit-0.2.0.dist-info/WHEEL +5 -0
- scratchkit-0.2.0.dist-info/entry_points.txt +2 -0
- scratchkit-0.2.0.dist-info/licenses/LICENSE +201 -0
- scratchkit-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,382 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Soft Actor-Critic (SAC)
|
|
3
|
+
========================
|
|
4
|
+
Maximum-entropy off-policy actor-critic algorithm for continuous control
|
|
5
|
+
(Haarnoja et al., 2018, 2019).
|
|
6
|
+
|
|
7
|
+
Objective
|
|
8
|
+
---------
|
|
9
|
+
Jointly optimise expected return *and* policy entropy:
|
|
10
|
+
|
|
11
|
+
J(π) = E[ Σ_t γ^t (r_t + α H[π(·|s_t)]) ]
|
|
12
|
+
|
|
13
|
+
where α is a temperature parameter that trades off exploration vs. exploitation.
|
|
14
|
+
|
|
15
|
+
Key components
|
|
16
|
+
--------------
|
|
17
|
+
Actor π_φ(a|s) — stochastic Gaussian policy with reparameterisation
|
|
18
|
+
Critic Q_θ1, Q_θ2 — twin soft Q-functions (minimum used for targets)
|
|
19
|
+
Target Q̄_θ1, Q̄_θ2 — exponential moving average of critic weights
|
|
20
|
+
Temperature α (or log α) — entropy coefficient, optionally auto-tuned via:
|
|
21
|
+
J(α) = E[-α (log π(a|s) + H̄)]
|
|
22
|
+
|
|
23
|
+
Reparameterisation trick (squashed Gaussian):
|
|
24
|
+
ã = tanh(μ + σ ε), ε ~ N(0,I)
|
|
25
|
+
log π(ã|s) = log N(ε; 0,I) - sum log σ - sum log(1 - tanh²(μ + σε) + δ)
|
|
26
|
+
|
|
27
|
+
Update equations
|
|
28
|
+
----------------
|
|
29
|
+
Q targets:
|
|
30
|
+
ã' ~ π(·|s'), y = r + γ(1-d)[min Q̄_i(s',ã') - α log π(ã'|s')]
|
|
31
|
+
Critic loss:
|
|
32
|
+
L_Q = E[(Q_i(s,a) - y)²] for i ∈ {1,2}
|
|
33
|
+
Actor loss:
|
|
34
|
+
L_π = E[α log π(ã|s) - min Q_i(s, ã)]
|
|
35
|
+
Alpha loss (auto-tune):
|
|
36
|
+
L_α = E[-α (log π(a|s) + H̄)]
|
|
37
|
+
|
|
38
|
+
Only numpy and Python stdlib are used.
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
from __future__ import annotations
|
|
42
|
+
import numpy as np
|
|
43
|
+
from .utils import ReplayBuffer, MLP
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class SAC:
    """
    Soft Actor-Critic for continuous control.

    Parameters
    ----------
    state_dim : int
        Dimension of the observation vector.
    action_dim : int
        Dimension of the action vector.
    action_low : float
        Lower bound of the action range.
    action_high : float
        Upper bound of the action range.
    hidden_sizes : list[int] | None
        Hidden layer widths shared by the actor and critics
        (default: [256, 256]).
    actor_lr : float
        Learning rate passed to the actor network.
    critic_lr : float
        Learning rate passed to the critic networks.
    alpha_lr : float
        Learning rate for the entropy temperature (Adam on log α).
    gamma : float
        Discount factor.
    tau : float
        Soft target-update coefficient.
    alpha : float
        Entropy temperature; used as the starting value when auto_alpha
        is True, and kept fixed otherwise.
    auto_alpha : bool
        Automatically tune the entropy temperature.
    target_entropy : float | None
        Target entropy H̄ (default: -action_dim).
    buffer_capacity : int
        Replay buffer capacity.
    batch_size : int
        Mini-batch size per learning step; learning is skipped until the
        buffer holds at least this many transitions.
    warmup_steps : int
        Number of environment steps taken with uniform random actions
        (in train_episode) and without learning (in step).
    log_std_min : float
        Lower clip for the policy log standard deviation.
    log_std_max : float
        Upper clip for the policy log standard deviation.
    random_state : int | None
        Seed for the sampling RNG and the network initialisation.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        action_low: float = -1.0,
        action_high: float = 1.0,
        hidden_sizes: list[int] | None = None,
        actor_lr: float = 3e-4,
        critic_lr: float = 3e-4,
        alpha_lr: float = 3e-4,
        gamma: float = 0.99,
        tau: float = 0.005,
        alpha: float = 0.2,
        auto_alpha: bool = True,
        target_entropy: float | None = None,
        buffer_capacity: int = 100_000,
        batch_size: int = 256,
        warmup_steps: int = 1000,
        log_std_min: float = -5.0,
        log_std_max: float = 2.0,
        random_state: int | None = None,
    ):
        self.action_dim = action_dim
        self.action_low = action_low
        self.action_high = action_high
        self.gamma = gamma
        self.tau = tau
        self.log_alpha = np.log(alpha)  # optimise log α so that α stays > 0
        self.alpha = alpha
        self.auto_alpha = auto_alpha
        self.target_entropy = (target_entropy if target_entropy is not None
                               else -float(action_dim))
        self.alpha_lr = alpha_lr
        self.batch_size = batch_size
        self.warmup_steps = warmup_steps
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        self._rng = np.random.default_rng(random_state)
        self._step = 0

        # Affine map from the tanh output in [-1, 1] to [action_low, action_high].
        self._act_scale = (action_high - action_low) / 2.0
        self._act_bias = (action_high + action_low) / 2.0

        hidden = hidden_sizes or [256, 256]

        # Actor: s → [mean, log_std] (size 2*action_dim)
        self.actor = MLP([state_dim] + hidden + [action_dim * 2],
                         output_activation="linear", lr=actor_lr,
                         random_state=random_state)

        # Twin critics: (s, a) → Q.
        # critic2 gets its own seed: assuming MLP derives its initial weights
        # from random_state, a shared seed would make both critics identical,
        # they would then receive identical updates, and min(Q1, Q2) would
        # collapse to a single critic.
        critic2_seed = None if random_state is None else random_state + 1

        self.critic1 = MLP([state_dim + action_dim] + hidden + [1],
                           output_activation="linear", lr=critic_lr,
                           random_state=random_state)
        self.critic1_target = MLP([state_dim + action_dim] + hidden + [1],
                                  output_activation="linear", lr=critic_lr,
                                  random_state=random_state)
        self.critic1.hard_update(self.critic1_target)

        self.critic2 = MLP([state_dim + action_dim] + hidden + [1],
                           output_activation="linear", lr=critic_lr,
                           random_state=critic2_seed)
        self.critic2_target = MLP([state_dim + action_dim] + hidden + [1],
                                  output_activation="linear", lr=critic_lr,
                                  random_state=critic2_seed)
        self.critic2.hard_update(self.critic2_target)

        # Alpha optimiser state (single-parameter Adam)
        self._alpha_m = 0.0
        self._alpha_v = 0.0
        self._alpha_t = 0

        # Replay
        self.buffer = ReplayBuffer(buffer_capacity)

        # Logging
        self.actor_losses_: list[float] = []
        self.critic_losses_: list[float] = []
        self.alpha_losses_: list[float] = []
        self.alphas_: list[float] = []
        self.episode_rewards_: list[float] = []

    # ------------------------------------------------------------------
    # Squashed Gaussian policy
    # ------------------------------------------------------------------

    def _actor_output(self, states: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Return (mean, log_std) from the actor network output.

        Works for a single state (1-D output) or a batch (2-D output);
        log_std is clipped to [log_std_min, log_std_max].
        """
        out = self.actor.forward(states)  # (B, 2*A) or (2*A,)
        if out.ndim == 1:
            mean = out[:self.action_dim]
            log_std = out[self.action_dim:]
        else:
            mean = out[:, :self.action_dim]
            log_std = out[:, self.action_dim:]
        log_std = np.clip(log_std, self.log_std_min, self.log_std_max)
        return mean, log_std

    @staticmethod
    def _squashed_gaussian(
        mean: np.ndarray,
        log_std: np.ndarray,
        eps: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Reparameterised tanh-squashed Gaussian sample for fixed noise ``eps``.

        Returns
        -------
        pre_tanh : u = mean + exp(log_std) * eps
        a_tanh   : tanh(u), in [-1, 1]
        log_prob : log π summed over the last (action) axis, including the
                   tanh change-of-variables correction log(1 - tanh²(u) + 1e-6)
        """
        std = np.exp(log_std)
        pre_tanh = mean + std * eps
        a_tanh = np.tanh(pre_tanh)
        log_prob_gauss = -0.5 * (
            ((pre_tanh - mean) / (std + 1e-8)) ** 2
            + 2 * log_std + np.log(2 * np.pi)
        )
        # Squashing correction: log(1 - tanh²(u) + δ)
        log_prob_correction = np.log(1.0 - a_tanh ** 2 + 1e-6)
        log_prob = (log_prob_gauss - log_prob_correction).sum(axis=-1)
        return pre_tanh, a_tanh, log_prob

    def _sample_squashed(self, states: np.ndarray):
        """
        Draw one reparameterised action per state and keep the intermediates.

        Returns (action, log_prob, mean, log_std, eps, a_tanh). ``action`` is
        scaled to [action_low, action_high]; ``log_prob`` is a Python float
        for a single state and a (B,) array for a batch.
        """
        mean, log_std = self._actor_output(states)
        eps = self._rng.standard_normal(mean.shape)
        _, a_tanh, log_prob = self._squashed_gaussian(mean, log_std, eps)
        if np.ndim(log_prob) == 0:
            log_prob = float(log_prob)
        action = a_tanh * self._act_scale + self._act_bias
        return action, log_prob, mean, log_std, eps, a_tanh

    def _sample_action(
        self,
        states: np.ndarray,
        training: bool = False,
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Sample action via reparameterisation + squashing.

        ``training`` is accepted for API compatibility and currently unused.

        Returns
        -------
        action : scaled to [action_low, action_high]
        log_prob : log π(a|s) accounting for tanh squashing
        """
        action, log_prob, *_ = self._sample_squashed(states)
        return action, log_prob

    def select_action(self, state: np.ndarray, deterministic: bool = False) -> np.ndarray:
        """
        Select action for environment interaction.

        Parameters
        ----------
        state : (state_dim,)
        deterministic : if True, use the squashed mean action (no sampling)
        """
        if deterministic:
            mean, _ = self._actor_output(state)
            return np.tanh(mean) * self._act_scale + self._act_bias
        action, _ = self._sample_action(state)
        return action

    # ------------------------------------------------------------------
    # Alpha (entropy temperature) update
    # ------------------------------------------------------------------

    def _update_alpha(self, log_probs: np.ndarray) -> float:
        """
        Gradient step on J(α) = E[-α (log π + H̄)] using Adam on log α
        (ensures α > 0). Returns the alpha loss evaluated at the new α.
        """
        self._alpha_t += 1
        beta1, beta2, eps_adam = 0.9, 0.999, 1e-8

        # Standard SAC surrogate: gradient of E[-log α (log π + H̄)] w.r.t. log α.
        grad_log_alpha = float(-np.mean(log_probs + self.target_entropy))

        self._alpha_m = beta1 * self._alpha_m + (1 - beta1) * grad_log_alpha
        self._alpha_v = beta2 * self._alpha_v + (1 - beta2) * grad_log_alpha ** 2
        m_hat = self._alpha_m / (1 - beta1 ** self._alpha_t)
        v_hat = self._alpha_v / (1 - beta2 ** self._alpha_t)

        self.log_alpha -= self.alpha_lr * m_hat / (np.sqrt(v_hat) + eps_adam)
        self.alpha = float(np.exp(self.log_alpha))

        alpha_loss = float(np.mean(-self.alpha * (log_probs + self.target_entropy)))
        return alpha_loss

    # ------------------------------------------------------------------
    # Actor gradient helpers
    # ------------------------------------------------------------------

    def _min_q_tanh_grad(
        self,
        states: np.ndarray,
        a_tanh: np.ndarray,
        q1: np.ndarray,
        q2: np.ndarray,
        step: float = 1e-3,
    ) -> np.ndarray:
        """
        ∂ min(Q1, Q2) / ∂ã for the pre-scaling tanh action ã, shape (B, A).

        Uses central finite differences through ``critic.forward`` (the only
        critic interface relied upon here). For each sample, the gradient of
        whichever critic attains the minimum is used. Costs 4 * action_dim
        extra batched critic forward passes.
        """
        take_q1 = q1 <= q2
        grad = np.empty_like(a_tanh)
        for j in range(a_tanh.shape[1]):
            shift = np.zeros_like(a_tanh)
            shift[:, j] = step
            sa_hi = np.concatenate(
                [states, (a_tanh + shift) * self._act_scale + self._act_bias], axis=1)
            sa_lo = np.concatenate(
                [states, (a_tanh - shift) * self._act_scale + self._act_bias], axis=1)
            g1 = (self.critic1.forward(sa_hi).ravel()
                  - self.critic1.forward(sa_lo).ravel()) / (2.0 * step)
            g2 = (self.critic2.forward(sa_hi).ravel()
                  - self.critic2.forward(sa_lo).ravel()) / (2.0 * step)
            grad[:, j] = np.where(take_q1, g1, g2)
        return grad

    @staticmethod
    def _actor_grad(
        alpha: float,
        log_std: np.ndarray,
        eps: np.ndarray,
        a_tanh: np.ndarray,
        dq_dtanh: np.ndarray,
        log_std_min: float,
        log_std_max: float,
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Reparameterised gradient of L = mean_b[α log π(ã_b|s_b) - Q(s_b, ã_b)]
        w.r.t. the actor outputs (mean, log_std), with the noise eps held fixed.

        With u = μ + σ ε and t = tanh(u):
          ∂log π/∂u    = 2 t (1 - t²) / (1 - t² + 1e-6)   (tanh correction term)
          ∂log π/∂logσ = -1 + ∂log π/∂u · σ ε
          ∂Q/∂u        = ∂Q/∂t · (1 - t²)
        The gradient for log_std is zeroed where log_std sits at its clip bounds,
        since the clip has zero derivative there.
        """
        n = a_tanh.shape[0]
        std = np.exp(log_std)
        one_minus_t2 = 1.0 - a_tanh ** 2
        dlogp_du = 2.0 * a_tanh * one_minus_t2 / (one_minus_t2 + 1e-6)
        dq_du = dq_dtanh * one_minus_t2
        dl_du = (alpha * dlogp_du - dq_du) / n
        d_mean = dl_du
        d_log_std = dl_du * std * eps - alpha / n
        inside = (log_std > log_std_min) & (log_std < log_std_max)
        d_log_std = np.where(inside, d_log_std, 0.0)
        return d_mean, d_log_std

    # ------------------------------------------------------------------
    # Learning step
    # ------------------------------------------------------------------

    def _update_critics(
        self,
        states: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        next_states: np.ndarray,
        dones: np.ndarray,
    ) -> float:
        """One MSE step on both critics toward the soft Bellman target."""
        # ── Critic targets ────────────────────────────────────────────
        a_next, log_prob_next = self._sample_action(next_states)
        sa_next = np.concatenate([next_states, a_next], axis=1)

        q1_next = self.critic1_target.forward(sa_next).ravel()
        q2_next = self.critic2_target.forward(sa_next).ravel()
        q_next = np.minimum(q1_next, q2_next)

        # Soft Bellman target
        y = rewards + self.gamma * (1.0 - dones) * (q_next - self.alpha * log_prob_next)

        # ── Critic update ─────────────────────────────────────────────
        # backward receives dLoss/dOutput; NOTE(review): no separate optimiser
        # step is called here, so MLP.backward presumably applies the update.
        sa = np.concatenate([states, actions], axis=1)

        q1_pred = self.critic1.forward(sa, training=True).ravel()
        td1 = y - q1_pred
        c1_loss = float(np.mean(td1 ** 2))
        self.critic1.backward(-2.0 * td1[:, np.newaxis] / self.batch_size)

        q2_pred = self.critic2.forward(sa, training=True).ravel()
        td2 = y - q2_pred
        c2_loss = float(np.mean(td2 ** 2))
        self.critic2.backward(-2.0 * td2[:, np.newaxis] / self.batch_size)

        return (c1_loss + c2_loss) / 2.0

    def _update_actor(self, states: np.ndarray) -> tuple[float, np.ndarray]:
        """One step on L_π = E[α log π(ã|s) - min Q_i(s, ã)].

        Returns (actor_loss, log_prob_new); log_prob_new is reused for the
        alpha update.
        """
        a_new, log_prob_new, _, log_std, eps, a_tanh = self._sample_squashed(states)
        sa_new = np.concatenate([states, a_new], axis=1)

        q1_new = self.critic1.forward(sa_new).ravel()
        q2_new = self.critic2.forward(sa_new).ravel()
        q_min = np.minimum(q1_new, q2_new)

        actor_loss = float(np.mean(self.alpha * log_prob_new - q_min))

        dq_dtanh = self._min_q_tanh_grad(states, a_tanh, q1_new, q2_new)
        d_mean, d_log_std = self._actor_grad(
            self.alpha, log_std, eps, a_tanh, dq_dtanh,
            self.log_std_min, self.log_std_max,
        )

        # Forward with training=True to populate the actor's backward cache.
        self.actor.forward(states, training=True)
        self.actor.backward(np.concatenate([d_mean, d_log_std], axis=1))
        return actor_loss, log_prob_new

    def _learn(self) -> tuple[float | None, float | None, float | None]:
        """Run one SAC update on a replay mini-batch.

        Returns (actor_loss, critic_loss, alpha_loss), or three Nones when the
        buffer holds fewer than batch_size transitions. alpha_loss is 0.0
        when auto_alpha is False.
        """
        if len(self.buffer) < self.batch_size:
            return None, None, None

        states, actions, rewards, next_states, dones = \
            self.buffer.sample(self.batch_size, self._rng)

        critic_loss = self._update_critics(states, actions, rewards, next_states, dones)
        actor_loss, log_prob_new = self._update_actor(states)

        # ── Alpha update ──────────────────────────────────────────────
        alpha_loss = 0.0
        if self.auto_alpha:
            alpha_loss = self._update_alpha(log_prob_new)

        # ── Soft target updates ───────────────────────────────────────
        self.critic1.soft_update(self.critic1_target, self.tau)
        self.critic2.soft_update(self.critic2_target, self.tau)

        return actor_loss, critic_loss, alpha_loss

    # ------------------------------------------------------------------
    # Environment interaction
    # ------------------------------------------------------------------

    def step(
        self,
        state: np.ndarray,
        action: np.ndarray,
        reward: float,
        next_state: np.ndarray,
        done: bool,
    ) -> tuple[float | None, float | None, float | None]:
        """Store a transition and, after warm-up, run one learning step.

        Returns the (actor, critic, alpha) losses, or Nones when no update ran.
        """
        self.buffer.push(state, action, reward, next_state, done)
        self._step += 1

        if self._step < self.warmup_steps:
            return None, None, None

        al, cl, alpha_l = self._learn()
        if al is not None:
            self.actor_losses_.append(al)
            self.critic_losses_.append(cl)
            self.alpha_losses_.append(alpha_l)
            self.alphas_.append(self.alpha)

        return al, cl, alpha_l

    def train_episode(self, env) -> float:
        """Run one episode (random actions during warm-up) and return its total reward."""
        rng_arg = self._rng if hasattr(env, '_env') else None
        state = env.reset(rng_arg) if rng_arg is not None else env.reset()
        total_reward = 0.0
        done = False

        while not done:
            if self._step < self.warmup_steps:
                action = self._rng.uniform(
                    self.action_low, self.action_high, self.action_dim
                )
            else:
                action = self.select_action(state)
            next_state, reward, done = env.step(action)
            self.step(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward

        self.episode_rewards_.append(total_reward)
        return total_reward

    def train(self, env, n_episodes: int) -> "SAC":
        """Train for ``n_episodes`` episodes and return self."""
        for _ in range(n_episodes):
            self.train_episode(env)
        return self
|