drlab 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- drlab/__init__.py +36 -0
- drlab/controllers/__init__.py +6 -0
- drlab/controllers/base.py +17 -0
- drlab/controllers/e_greedy.py +46 -0
- drlab/controllers/greedy.py +20 -0
- drlab/controllers/stochastic_controller.py +17 -0
- drlab/experiments/__init__.py +9 -0
- drlab/experiments/ac_experiment.py +102 -0
- drlab/experiments/dqn_experiment.py +151 -0
- drlab/learners/__init__.py +4 -0
- drlab/learners/actor_critic.py +206 -0
- drlab/learners/dqn.py +149 -0
- drlab/replay/__init__.py +4 -0
- drlab/replay/replay_buffer.py +86 -0
- drlab/replay/transition_batch.py +33 -0
- drlab/runners/__init__.py +3 -0
- drlab/runners/runner.py +123 -0
- drlab-0.1.0.dist-info/METADATA +276 -0
- drlab-0.1.0.dist-info/RECORD +22 -0
- drlab-0.1.0.dist-info/WHEEL +5 -0
- drlab-0.1.0.dist-info/licenses/LICENSE +21 -0
- drlab-0.1.0.dist-info/top_level.txt +1 -0
drlab/__init__.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from drlab.controllers import (
|
|
2
|
+
Controller,
|
|
3
|
+
EpsilonGreedyController,
|
|
4
|
+
GreedyController,
|
|
5
|
+
StochasticController,
|
|
6
|
+
)
|
|
7
|
+
from drlab.experiments import (
|
|
8
|
+
ActorCriticExperiment,
|
|
9
|
+
ActorCriticExperimentConfig,
|
|
10
|
+
DQNExperiment,
|
|
11
|
+
DQNExperimentConfig,
|
|
12
|
+
)
|
|
13
|
+
from drlab.learners import ActorCritic, ActorCriticConfig, DQN, DQNConfig
|
|
14
|
+
from drlab.replay import ReplayBuffer, TransitionBatch
|
|
15
|
+
from drlab.runners import Runner
|
|
16
|
+
|
|
17
|
+
__version__ = "0.1.0"
|
|
18
|
+
|
|
19
|
+
__all__ = [
|
|
20
|
+
"__version__",
|
|
21
|
+
"ActorCritic",
|
|
22
|
+
"ActorCriticConfig",
|
|
23
|
+
"ActorCriticExperiment",
|
|
24
|
+
"ActorCriticExperimentConfig",
|
|
25
|
+
"Controller",
|
|
26
|
+
"DQN",
|
|
27
|
+
"DQNConfig",
|
|
28
|
+
"DQNExperiment",
|
|
29
|
+
"DQNExperimentConfig",
|
|
30
|
+
"EpsilonGreedyController",
|
|
31
|
+
"GreedyController",
|
|
32
|
+
"ReplayBuffer",
|
|
33
|
+
"Runner",
|
|
34
|
+
"StochasticController",
|
|
35
|
+
"TransitionBatch",
|
|
36
|
+
]
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
from .base import Controller
|
|
2
|
+
from .greedy import GreedyController
|
|
3
|
+
from .e_greedy import EpsilonGreedyController
|
|
4
|
+
from .stochastic_controller import StochasticController
|
|
5
|
+
|
|
6
|
+
__all__ = ["Controller", "GreedyController", "EpsilonGreedyController", "StochasticController"]
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import torch as th
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
|
|
6
|
+
class Controller(ABC):
    """Abstract interface for objects that select actions from observations.

    Concrete controllers implement :meth:`choose` and :meth:`probabilities`.
    The class-level attributes below only declare what subclasses are
    expected to set; ``model`` and ``controller`` default to ``None`` and
    are therefore annotated as optional.
    """

    # Number of discrete actions; subclasses assign it in ``__init__``.
    num_actions: int
    # Network wrapped by the controller, ``None`` until a subclass sets it.
    model: th.nn.Module | None = None
    # Inner controller for wrapper-style controllers, ``None`` otherwise.
    controller: Controller | None = None

    @abstractmethod
    def choose(self, obs: th.Tensor, **kwargs) -> th.Tensor:
        """Return the action(s) selected for ``obs``.

        Extra keyword arguments are controller specific; implementations
        must accept (and may ignore) them.
        """

    @abstractmethod
    def probabilities(self, obs: th.Tensor, **kwargs) -> th.Tensor:
        """Return the probability assigned to each action for ``obs``."""
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch as th
|
|
3
|
+
from .base import Controller
|
|
4
|
+
|
|
5
|
+
class EpsilonGreedyController(Controller):
    """Epsilon-greedy wrapper around another controller.

    With probability epsilon the whole batch receives uniformly random
    actions; otherwise the decision is delegated to the wrapped controller.
    Epsilon moves linearly from ``max_eps`` to ``min_eps`` while
    ``num_decisions`` grows from 0 to ``anneal_steps - 1`` and stays at
    ``min_eps`` afterwards.
    """

    def __init__(
        self,
        controller: Controller,
        num_actions: int,
        max_eps: float = 1.0,
        min_eps: float = 0.1,
        anneal_steps: int = 10_000,
    ):
        """Store the wrapped controller and the annealing schedule.

        Raises:
            ValueError: if ``anneal_steps`` is smaller than 2.
        """
        self.controller = controller
        self.num_actions = num_actions
        # Expose the wrapped controller's network under the same attribute.
        self.model = controller.model
        self.max_eps, self.min_eps = max_eps, min_eps
        self.anneal_steps = anneal_steps
        # Counts the calls to ``choose`` that advanced the schedule.
        self.num_decisions = 0

        if not anneal_steps > 1:
            raise ValueError("anneal_steps must be >= 2")

    def epsilon(self) -> float:
        """Return the exploration rate for the current ``num_decisions``."""
        remaining = 1 - self.num_decisions / (self.anneal_steps - 1)
        # Past the end of the schedule the fraction is pinned to zero.
        frac = 0.0 if remaining < 0.0 else remaining
        span = self.max_eps - self.min_eps
        return frac * span + self.min_eps

    def choose(self, obs: th.Tensor, increase_counter: bool = True, **kwargs) -> th.Tensor:
        """Pick random actions with probability epsilon, else delegate.

        Epsilon is read before the counter is advanced. A single random
        draw decides for the whole batch. Remaining keyword arguments are
        forwarded to the wrapped controller.
        """
        eps = self.epsilon()
        if increase_counter:
            self.num_decisions += 1

        # Observations without a batch axis are treated as a batch of one.
        batch_size = 1 if obs.ndim <= 1 else obs.shape[0]
        explore = np.random.rand() < eps
        if not explore:
            return self.controller.choose(obs, **kwargs)

        return th.randint(self.num_actions, (batch_size,), device=obs.device, dtype=th.long)

    def probabilities(self, obs: th.Tensor, **kwargs) -> th.Tensor:
        """Mix the wrapped controller's distribution with a uniform one.

        Returns ``eps / num_actions + (1 - eps) * p`` where ``p`` is what the
        wrapped controller reports; the schedule is not advanced.
        """
        eps = self.epsilon()
        inner = self.controller.probabilities(obs, **kwargs)
        rows = inner.shape[0]
        uniform = th.full((rows, self.num_actions), eps / self.num_actions, device=inner.device)
        exploit = (1 - eps) * inner
        return uniform + exploit
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import torch as th
|
|
2
|
+
from .base import Controller
|
|
3
|
+
|
|
4
|
+
class GreedyController(Controller):
    """Controller that always selects the highest-scoring action.

    Only the first ``num_actions`` columns of the model output are treated
    as action scores; any further columns are ignored.
    """

    def __init__(self, model: th.nn.Module, num_actions: int):
        """Store the scoring network and the number of actions."""
        self.num_actions = num_actions
        self.model = model

    def _scores(self, obs: th.Tensor) -> th.Tensor:
        """Run the model and keep the action-score columns.

        Gradients are disabled: both public methods only return arg-max
        derived results, which carry no gradient anyway, so building an
        autograd graph here would only cost memory.
        """
        with th.no_grad():
            return self.model(obs)[:, :self.num_actions]

    def choose(self, obs: th.Tensor, **kwargs) -> th.Tensor:
        """Return the index of the best action for every row of ``obs``.

        ``**kwargs`` is accepted (and ignored) to honour the ``Controller``
        interface, so wrappers that forward keyword arguments do not fail.
        """
        return th.argmax(self._scores(obs), dim=-1)

    def probabilities(self, obs: th.Tensor, **kwargs) -> th.Tensor:
        """Return a one-hot distribution on the best action per row.

        ``**kwargs`` is accepted (and ignored) for interface compatibility.
        """
        output = self._scores(obs)
        idx = output.argmax(dim=-1, keepdim=True)  # one column per row
        probs = th.zeros(output.shape[0], self.num_actions, device=output.device)
        probs.scatter_(dim=-1, index=idx, value=1.0)
        return probs
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
import torch as th
|
|
2
|
+
from .base import Controller
|
|
3
|
+
|
|
4
|
+
class StochasticController(Controller):
    """Controller that samples actions from a softmax over model outputs.

    Only the first ``num_actions`` columns of the model output are treated
    as logits; any further columns are ignored.
    """

    def __init__(self, model: th.nn.Module, num_actions: int):
        """Store the policy network and the number of actions."""
        self.num_actions = num_actions
        self.model = model

    def choose(self, obs: th.Tensor, **kwargs) -> th.Tensor:
        """Sample one action per row of ``obs`` from the policy.

        ``**kwargs`` is accepted to honour the ``Controller`` interface (so
        wrappers that forward keyword arguments do not fail) and is passed
        on to :meth:`probabilities`, which ignores it.
        """
        probs = self.probabilities(obs, **kwargs)
        return th.distributions.Categorical(probs=probs).sample()

    def probabilities(self, obs: th.Tensor, **kwargs) -> th.Tensor:
        """Return the softmax of the action logits for every row of ``obs``.

        ``**kwargs`` is accepted (and ignored) for interface compatibility.
        """
        output = self.model(obs)[:, :self.num_actions]
        return th.nn.functional.softmax(output, dim=-1)
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from .dqn_experiment import DQNExperiment, DQNExperimentConfig
|
|
2
|
+
from .ac_experiment import ActorCriticExperiment, ActorCriticExperimentConfig
|
|
3
|
+
|
|
4
|
+
__all__ = [
|
|
5
|
+
"DQNExperimentConfig",
|
|
6
|
+
"DQNExperiment",
|
|
7
|
+
"ActorCriticExperimentConfig",
|
|
8
|
+
"ActorCriticExperiment"
|
|
9
|
+
]
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from tqdm import tqdm
|
|
3
|
+
import gymnasium as gym
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Callable
|
|
6
|
+
from torch.utils.tensorboard import SummaryWriter
|
|
7
|
+
|
|
8
|
+
from drlab.learners import ActorCritic
|
|
9
|
+
from drlab.runners import Runner
|
|
10
|
+
from drlab.controllers import Controller
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
|
|
14
|
+
class ActorCriticExperimentConfig:
|
|
15
|
+
max_steps: int
|
|
16
|
+
gamma: float = 0.99
|
|
17
|
+
run_steps: int = 0
|
|
18
|
+
log_dir: str = "runs/reinforce_experiment"
|
|
19
|
+
experiment_name: str = "ActorCriticExperiment"
|
|
20
|
+
step_callback: Callable[[int], None] | None = None
|
|
21
|
+
step_callback_interval: int | None = None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ActorCriticExperiment:
    """Training loop that alternates data collection and actor-critic updates.

    Each iteration asks the runner for a batch of transitions, trains the
    learner on it, logs loss and episode statistics to TensorBoard and
    fires the optional step callback.
    """

    def __init__(
        self,
        env: gym.Env,
        controller: Controller,
        learner: ActorCritic,
        config: ActorCriticExperimentConfig,
    ):
        """Wire up runner, learner and logging.

        Raises:
            ValueError: if ``step_callback`` is given without a positive
                ``step_callback_interval``.
        """
        # Validate first so that an invalid config does not leave an open
        # SummaryWriter (and a freshly created log directory) behind.
        if config.step_callback is not None:
            if config.step_callback_interval is None:
                raise ValueError("step_callback_interval must be set when step_callback is provided.")
            if config.step_callback_interval <= 0:
                # A non-positive interval would make the callback loop in
                # ``run`` spin forever.
                raise ValueError("step_callback_interval must be a positive number of steps.")

        # Init experiment settings
        self.max_steps = config.max_steps
        self.run_steps = config.run_steps
        self.step_callback = config.step_callback
        self.step_callback_interval = config.step_callback_interval

        # Only ask the runner to compute discounted returns when the learner uses them.
        calculate_returns = (
            not learner.config.advantage_bootstrap
            or (learner.config.use_bias and learner.config.value_targets == "returns")
        )

        # Init drl components
        # NOTE(review): the meaning of the fourth positional Runner argument
        # (False here) is not visible in this file - confirm against Runner.
        self.runner = Runner(env, controller, calculate_returns, False, config.gamma, learner.device)
        self.learner = learner

        # Init logging
        self.writer = SummaryWriter(log_dir=config.log_dir)
        self.experiment_name = config.experiment_name

    def run(self):
        """Train until ``max_steps`` transitions were collected.

        The step callback (if any) is called once with 0 before training
        and afterwards once for every multiple of the interval that has
        been reached. The progress bar and the TensorBoard writer are
        closed even when an iteration raises.
        """
        steps = 0
        next_callback_step = self.step_callback_interval
        pbar = None
        try:
            if self.step_callback is not None:
                self.step_callback(steps)
            pbar = tqdm(total=self.max_steps, desc=self.experiment_name, dynamic_ncols=True, mininterval=1.0)

            while steps < self.max_steps:
                # 1. Run batch of transitions
                batch, ep_returns, ep_lengths, _ = self.runner.run(self.run_steps)
                batch = batch.to(self.learner.device)

                # 2. Learn from batch
                loss = self.learner.train(
                    batch.rewards,
                    batch.dones,
                    batch.states,
                    batch.actions,
                    batch.next_states,
                    batch.returns,
                )

                # 3. Log results
                self.writer.add_scalar("Loss", loss, steps)
                if ep_returns:
                    mean_return = np.mean(ep_returns)
                    mean_length = np.mean(ep_lengths)
                    self.writer.add_scalar("Episode return", mean_return, steps)
                    self.writer.add_scalar("Episode Length", mean_length, steps)

                    pbar.set_postfix({
                        "loss": f"{loss:.3f}",
                        "return": f"{mean_return:.2f}",
                        "len": f"{mean_length:.1f}",
                    }, refresh=False)

                # 4. Update steps & progress bar
                step_inc = batch.states.shape[0]
                steps += step_inc
                pbar.update(step_inc)

                # A large batch may cross several callback thresholds at once.
                while self.step_callback is not None and steps >= next_callback_step:
                    self.step_callback(next_callback_step)
                    next_callback_step += self.step_callback_interval
        finally:
            if pbar is not None:
                pbar.close()
            self.writer.close()
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from tqdm import tqdm
|
|
3
|
+
import gymnasium as gym
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Callable
|
|
6
|
+
from torch.utils.tensorboard import SummaryWriter
|
|
7
|
+
|
|
8
|
+
from drlab.learners import DQN
|
|
9
|
+
from drlab.runners import Runner
|
|
10
|
+
from drlab.controllers import Controller
|
|
11
|
+
from drlab.replay import TransitionBatch, ReplayBuffer
|
|
12
|
+
|
|
13
|
+
@dataclass
|
|
14
|
+
class DQNExperimentConfig:
|
|
15
|
+
max_steps: int
|
|
16
|
+
gamma: float = 0.99
|
|
17
|
+
run_steps: int = 0
|
|
18
|
+
log_dir: str = "runs/dqn_experiment"
|
|
19
|
+
experiment_name: str = "DQNExperiment"
|
|
20
|
+
use_replay: bool = True
|
|
21
|
+
replay_buffer_size: int = 10_000
|
|
22
|
+
batch_size: int = 128
|
|
23
|
+
use_last_episode: bool = True
|
|
24
|
+
grad_repeats: int = 1
|
|
25
|
+
step_callback: Callable[[int], None] | None = None
|
|
26
|
+
step_callback_interval: int | None = None
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class DQNExperiment:
    """Training loop that alternates data collection and DQN updates.

    Collected transitions are stored in a replay buffer; once it holds at
    least ``batch_size`` items, ``grad_repeats`` mini-batches are drawn and
    trained on after every collection step.
    """

    def __init__(
        self,
        env: gym.Env,
        controller: Controller,
        learner: DQN,
        config: DQNExperimentConfig,
    ):
        """Wire up runner, learner, replay buffer and logging.

        Raises:
            ValueError: if ``step_callback`` is given without a positive
                ``step_callback_interval``, or if ``grad_repeats`` is
                smaller than 1.
        """
        # Validate first so that an invalid config does not leave an open
        # SummaryWriter (and a freshly created log directory) behind.
        if config.step_callback is not None:
            if config.step_callback_interval is None:
                raise ValueError("step_callback_interval must be set when step_callback is provided.")
            if config.step_callback_interval <= 0:
                # A non-positive interval would make the callback loop in
                # ``run`` spin forever.
                raise ValueError("step_callback_interval must be a positive number of steps.")
        if config.grad_repeats < 1:
            # The mean loss divides by grad_repeats; zero would fail later
            # with an unhelpful ZeroDivisionError.
            raise ValueError("grad_repeats must be >= 1.")

        # Init experiment settings
        self.max_steps = config.max_steps
        self.grad_repeats = config.grad_repeats
        self.batch_size = config.batch_size
        self.use_last_episode = config.use_last_episode
        self.run_steps = config.run_steps
        self.step_callback = config.step_callback
        self.step_callback_interval = config.step_callback_interval

        # Init drl components
        # NOTE(review): the meaning of the third/fourth positional Runner
        # arguments (False, True) is not visible in this file - confirm
        # against Runner.
        self.runner = Runner(env, controller, False, True, config.gamma, learner.device)
        self.learner = learner

        # Init replay buffer; without replay it only needs to hold one batch.
        self.replay_buffer_size = config.replay_buffer_size
        self.use_replay = config.use_replay
        self.replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_size if config.use_replay else config.batch_size,
            obs_shape=env.observation_space.shape,
            device=learner.device,
        )

        # Init logging
        self.writer = SummaryWriter(log_dir=config.log_dir)
        self.experiment_name = config.experiment_name

    def _make_minibatch(self, batch: TransitionBatch, last_episode: TransitionBatch | None) -> TransitionBatch:
        """Assemble the transitions for one gradient step.

        * Without replay: everything currently in the buffer.
        * Replay without last episode: ``batch_size`` sampled transitions.
        * Replay with last episode: the latest episode (or the collected
          batch when no episode is available), topped up with sampled
          transitions until ``batch_size`` is reached.
        """
        # No replay buffer
        if not self.use_replay:
            return self.replay_buffer.get_all()

        # Replay buffer without last episode
        if not self.use_last_episode:
            return self.replay_buffer.sample(self.batch_size)

        # Replay buffer with last episode
        # NOTE(review): this relies on the truthiness of TransitionBatch;
        # whether an empty batch is falsy depends on that class - confirm.
        episode = last_episode if last_episode else batch
        episode = episode.to(self.learner.device)

        ep_len = episode.states.shape[0]
        if ep_len >= self.batch_size:
            # The episode alone already fills (or exceeds) a mini-batch.
            return episode

        rest = self.replay_buffer.sample(self.batch_size - ep_len)
        return rest.cat(episode)

    def _learn_from_batch(self, batch: TransitionBatch, last_episode: TransitionBatch | None) -> float:
        """Store ``batch`` in the buffer and run the gradient updates.

        Returns:
            The mean loss over ``grad_repeats`` updates, or ``0.0`` while
            the buffer holds fewer than ``batch_size`` transitions.
        """
        # 1) Add batch to replay buffer
        self.replay_buffer.add(
            states=batch.states.cpu().numpy(),
            actions=batch.actions.cpu().numpy(),
            rewards=batch.rewards.cpu().numpy(),
            dones=batch.dones.cpu().numpy(),
            next_states=batch.next_states.cpu().numpy(),
            returns=batch.returns.cpu().numpy(),
        )
        if self.replay_buffer.size < self.batch_size:
            return 0.0

        # 2) train repeats
        total_loss = 0.0
        for _ in range(self.grad_repeats):
            mb = self._make_minibatch(batch, last_episode)
            total_loss += self.learner.train(
                mb.rewards, mb.dones, mb.states, mb.actions, mb.next_states
            )

        return total_loss / self.grad_repeats

    def run(self):
        """Train until ``max_steps`` transitions were collected.

        The step callback (if any) is called once with 0 before training
        and afterwards once for every multiple of the interval that has
        been reached. The progress bar and the TensorBoard writer are
        closed even when an iteration raises.
        """
        steps = 0
        next_callback_step = self.step_callback_interval
        pbar = None
        try:
            if self.step_callback is not None:
                self.step_callback(steps)
            pbar = tqdm(total=self.max_steps, desc=self.experiment_name, dynamic_ncols=True, mininterval=1.0)

            while steps < self.max_steps:
                # 1. Run batch of transitions
                batch, ep_returns, ep_lengths, last_episode = self.runner.run(self.run_steps)

                # 2. Learn from batch
                loss = self._learn_from_batch(batch, last_episode)

                # 3. Log results
                self.writer.add_scalar("Loss", loss, steps)
                if ep_returns:
                    mean_return = np.mean(ep_returns)
                    mean_length = np.mean(ep_lengths)
                    self.writer.add_scalar("Episode return", mean_return, steps)
                    self.writer.add_scalar("Episode Length", mean_length, steps)

                    pbar.set_postfix({
                        "loss": f"{loss:.3f}",
                        "return": f"{mean_return:.2f}",
                        "len": f"{mean_length:.1f}",
                    }, refresh=False)

                # 4. Update steps & progress bar
                step_inc = batch.states.shape[0]
                steps += step_inc
                pbar.update(step_inc)

                # A large batch may cross several callback thresholds at once.
                while self.step_callback is not None and steps >= next_callback_step:
                    self.step_callback(next_callback_step)
                    next_callback_step += self.step_callback_interval
        finally:
            if pbar is not None:
                pbar.close()
            self.writer.close()
|
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import Callable
|
|
3
|
+
|
|
4
|
+
import torch as th
|
|
5
|
+
|
|
6
|
+
@dataclass
|
|
7
|
+
class ActorCriticConfig:
|
|
8
|
+
device: th.device | str = "cpu"
|
|
9
|
+
regularizers: list[Callable[..., th.Tensor | float]] = field(default_factory=list)
|
|
10
|
+
reg_lams: list[float] = field(default_factory=list)
|
|
11
|
+
num_actions: int = 2
|
|
12
|
+
clip_grad: bool = True
|
|
13
|
+
grad_norm_clip: float = 1.0
|
|
14
|
+
use_bias: bool = True
|
|
15
|
+
value_criterion: Callable = field(default_factory=th.nn.MSELoss)
|
|
16
|
+
value_lambda: float = 0.1
|
|
17
|
+
value_targets: str = "td"
|
|
18
|
+
gamma: float = 0.99
|
|
19
|
+
advantage_bootstrap: bool = True
|
|
20
|
+
off_policy_iterations: int = 0
|
|
21
|
+
ppo_clipping: float = 0.1
|
|
22
|
+
use_entropy: bool = False
|
|
23
|
+
entropy_max_lambda: float = 0.0
|
|
24
|
+
entropy_min_lambda: float = 0.0
|
|
25
|
+
entropy_anneal_steps: int = 1
|
|
26
|
+
normalize_advantages: bool = False
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ActorCritic:
|
|
30
|
+
|
|
31
|
+
def __init__(
|
|
32
|
+
self,
|
|
33
|
+
actor: th.nn.Module,
|
|
34
|
+
optimizer: th.optim.Optimizer,
|
|
35
|
+
config: ActorCriticConfig,
|
|
36
|
+
):
|
|
37
|
+
|
|
38
|
+
device = th.device(config.device)
|
|
39
|
+
self.actor = actor
|
|
40
|
+
self.optimizer = optimizer
|
|
41
|
+
self.all_parameters = list(actor.parameters())
|
|
42
|
+
self.actor.to(device)
|
|
43
|
+
self.device = device
|
|
44
|
+
self.config = config
|
|
45
|
+
|
|
46
|
+
if len(config.regularizers) != len(config.reg_lams):
|
|
47
|
+
raise ValueError("Length of regularizers and reg_lams lists must match.")
|
|
48
|
+
|
|
49
|
+
if not config.use_bias and config.advantage_bootstrap:
|
|
50
|
+
raise ValueError("advantage_bootstrap=True requires use_bias=True.")
|
|
51
|
+
|
|
52
|
+
if config.value_targets not in {"returns", "td"}:
|
|
53
|
+
raise ValueError("value_targets must be either 'returns' or 'td'.")
|
|
54
|
+
|
|
55
|
+
if config.use_entropy and config.entropy_anneal_steps <= 1:
|
|
56
|
+
raise ValueError("entropy_anneal_steps must be > 1 when use_entropy=True.")
|
|
57
|
+
if config.use_entropy and config.entropy_max_lambda < config.entropy_min_lambda:
|
|
58
|
+
raise ValueError("entropy_max_lambda must be >= entropy_min_lambda.")
|
|
59
|
+
|
|
60
|
+
self.entropy_step = 0
|
|
61
|
+
|
|
62
|
+
def _value_loss(
|
|
63
|
+
self,
|
|
64
|
+
returns: th.Tensor,
|
|
65
|
+
rewards: th.Tensor,
|
|
66
|
+
dones: th.Tensor,
|
|
67
|
+
values: th.Tensor,
|
|
68
|
+
next_values: th.Tensor,
|
|
69
|
+
) -> th.Tensor:
|
|
70
|
+
if self.config.value_targets == "returns":
|
|
71
|
+
targets = returns
|
|
72
|
+
elif self.config.value_targets == "td":
|
|
73
|
+
targets = rewards + self.config.gamma * (~dones * next_values)
|
|
74
|
+
else:
|
|
75
|
+
raise ValueError(f"Unknown value_targets: {self.config.value_targets}")
|
|
76
|
+
|
|
77
|
+
return self.config.value_criterion(values, targets)
|
|
78
|
+
|
|
79
|
+
def _get_policy(self, logits: th.Tensor, actions: th.Tensor) -> th.Tensor:
|
|
80
|
+
probs = th.nn.functional.softmax(logits, dim=-1)
|
|
81
|
+
pi = probs.gather(dim=-1, index=actions)
|
|
82
|
+
return pi
|
|
83
|
+
|
|
84
|
+
def _entropy_lambda(self) -> float:
|
|
85
|
+
progress = min(self.entropy_step / (self.config.entropy_anneal_steps - 1), 1.0)
|
|
86
|
+
|
|
87
|
+
return (
|
|
88
|
+
self.config.entropy_max_lambda
|
|
89
|
+
+ progress * (self.config.entropy_min_lambda - self.config.entropy_max_lambda)
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
def _entropy_loss(self, logits: th.Tensor) -> th.Tensor:
|
|
93
|
+
entropy_lambda = self._entropy_lambda()
|
|
94
|
+
probs = th.nn.functional.softmax(logits, dim=-1)
|
|
95
|
+
log_probs = th.nn.functional.log_softmax(logits, dim=-1)
|
|
96
|
+
entropy = -(probs * log_probs).sum(dim=-1).mean()
|
|
97
|
+
return -entropy_lambda * entropy
|
|
98
|
+
|
|
99
|
+
def _advantages(
|
|
100
|
+
self,
|
|
101
|
+
returns: th.Tensor,
|
|
102
|
+
rewards: th.Tensor,
|
|
103
|
+
dones: th.Tensor,
|
|
104
|
+
values: th.Tensor,
|
|
105
|
+
next_values: th.Tensor
|
|
106
|
+
) -> th.Tensor:
|
|
107
|
+
advantages = None
|
|
108
|
+
if self.config.advantage_bootstrap:
|
|
109
|
+
advantages = rewards + self.config.gamma * (~dones * next_values)
|
|
110
|
+
else:
|
|
111
|
+
advantages = returns
|
|
112
|
+
if self.config.use_bias:
|
|
113
|
+
advantages = advantages - values.detach()
|
|
114
|
+
|
|
115
|
+
if self.config.normalize_advantages and advantages.numel() > 1:
|
|
116
|
+
advantages = (advantages - advantages.mean()) / (
|
|
117
|
+
advantages.std(unbiased=False) + 1e-8
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
return advantages
|
|
121
|
+
|
|
122
|
+
def _policy_loss(self, pi: th.Tensor, advantage: th.Tensor) -> th.Tensor:
|
|
123
|
+
|
|
124
|
+
if self.old_pi is None:
|
|
125
|
+
self.old_pi = pi.detach()
|
|
126
|
+
return -th.mean(pi.clamp_min(1e-8).log() * advantage.detach())
|
|
127
|
+
else:
|
|
128
|
+
ratios = pi / self.old_pi.detach()
|
|
129
|
+
loss = advantage.detach() * ratios
|
|
130
|
+
ppo_loss = th.clamp(ratios, 1-self.config.ppo_clipping, 1+self.config.ppo_clipping) * advantage.detach()
|
|
131
|
+
loss = th.min(loss, ppo_loss)
|
|
132
|
+
return -th.mean(loss)
|
|
133
|
+
|
|
134
|
+
def train(
|
|
135
|
+
self,
|
|
136
|
+
rewards: th.Tensor, # float32, [B,1]
|
|
137
|
+
dones: th.Tensor, # bool or float(0/1), [B,1]
|
|
138
|
+
states: th.Tensor, # float32, [B, obs_dim] or [B,C,H,W]
|
|
139
|
+
actions: th.Tensor, # int64, [B,1]
|
|
140
|
+
next_states: th.Tensor, # float32, same as states
|
|
141
|
+
returns: th.Tensor, # float32, [B,1]
|
|
142
|
+
) -> float:
|
|
143
|
+
|
|
144
|
+
self.actor.train(True)
|
|
145
|
+
need_next_values = (
|
|
146
|
+
self.config.advantage_bootstrap
|
|
147
|
+
or self.config.value_targets == "td"
|
|
148
|
+
)
|
|
149
|
+
|
|
150
|
+
self.old_pi, loss_sum = None, 0
|
|
151
|
+
|
|
152
|
+
for _ in range(1 + self.config.off_policy_iterations):
|
|
153
|
+
# 1. Compute model policy, values and next_values
|
|
154
|
+
output: th.Tensor = self.actor(states)
|
|
155
|
+
logits = output[:, :self.config.num_actions]
|
|
156
|
+
values = output[:, self.config.num_actions:self.config.num_actions + 1]
|
|
157
|
+
|
|
158
|
+
next_values = None
|
|
159
|
+
if need_next_values:
|
|
160
|
+
with th.no_grad():
|
|
161
|
+
next_output: th.Tensor = self.actor(next_states)
|
|
162
|
+
next_values = next_output[:, self.config.num_actions:self.config.num_actions + 1]
|
|
163
|
+
pi = self._get_policy(logits, actions)
|
|
164
|
+
|
|
165
|
+
# 2. Compute losses
|
|
166
|
+
policy_loss = self._policy_loss(
|
|
167
|
+
pi,
|
|
168
|
+
self._advantages(
|
|
169
|
+
returns,
|
|
170
|
+
rewards,
|
|
171
|
+
dones,
|
|
172
|
+
values,
|
|
173
|
+
next_values
|
|
174
|
+
)
|
|
175
|
+
)
|
|
176
|
+
entropy_loss = 0.0
|
|
177
|
+
if self.config.use_entropy:
|
|
178
|
+
entropy_loss = self._entropy_loss(logits)
|
|
179
|
+
value_loss = 0.0
|
|
180
|
+
if self.config.use_bias:
|
|
181
|
+
value_loss = self.config.value_lambda * self._value_loss(
|
|
182
|
+
returns,
|
|
183
|
+
rewards,
|
|
184
|
+
dones,
|
|
185
|
+
values,
|
|
186
|
+
next_values
|
|
187
|
+
)
|
|
188
|
+
reg_loss = (
|
|
189
|
+
sum(
|
|
190
|
+
lam * reg(self.actor, rewards, dones, states, actions, next_states)
|
|
191
|
+
for reg, lam in zip(self.config.regularizers, self.config.reg_lams)
|
|
192
|
+
)
|
|
193
|
+
if self.config.regularizers else 0.0
|
|
194
|
+
)
|
|
195
|
+
loss = policy_loss + value_loss + entropy_loss + reg_loss
|
|
196
|
+
|
|
197
|
+
# 3. Optimize
|
|
198
|
+
self.optimizer.zero_grad(set_to_none=True)
|
|
199
|
+
loss.backward()
|
|
200
|
+
if self.config.clip_grad:
|
|
201
|
+
th.nn.utils.clip_grad_norm_(self.all_parameters, self.config.grad_norm_clip)
|
|
202
|
+
self.optimizer.step()
|
|
203
|
+
loss_sum += loss.item()
|
|
204
|
+
|
|
205
|
+
self.entropy_step += 1
|
|
206
|
+
return loss_sum
|