helloRL 1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- helloRL/__init__.py +7 -0
- helloRL/modal_training.py +98 -0
- helloRL/modules/__init__.py +17 -0
- helloRL/modules/a2c.py +102 -0
- helloRL/modules/actors.py +256 -0
- helloRL/modules/agents.py +117 -0
- helloRL/modules/critic_loss_clipped.py +20 -0
- helloRL/modules/critic_loss_q.py +18 -0
- helloRL/modules/critics.py +145 -0
- helloRL/modules/epochs.py +53 -0
- helloRL/modules/foundation.py +163 -0
- helloRL/modules/gae.py +62 -0
- helloRL/modules/grad_norm.py +10 -0
- helloRL/modules/lr_anneal.py +14 -0
- helloRL/modules/monte_carlo.py +168 -0
- helloRL/modules/params.py +30 -0
- helloRL/modules/po_clipped.py +16 -0
- helloRL/modules/replay.py +167 -0
- helloRL/modules/rollout_data.py +105 -0
- helloRL/modules/trainer.py +152 -0
- helloRL/trainer.py +150 -0
- helloRL/utils/__init__.py +4 -0
- helloRL/utils/plot.py +326 -0
- helloRL/utils/progress.py +182 -0
- helloRL/utils/session_tracker.py +61 -0
- helloRL/utils/sim.py +138 -0
- hellorl-1.0.dist-info/METADATA +181 -0
- hellorl-1.0.dist-info/RECORD +30 -0
- hellorl-1.0.dist-info/WHEEL +4 -0
- hellorl-1.0.dist-info/licenses/LICENSE +21 -0
helloRL/__init__.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
import modal
|
|
2
|
+
import time
|
|
3
|
+
from functools import partial
|
|
4
|
+
|
|
5
|
+
from . import trainer
|
|
6
|
+
from .utils import plot
|
|
7
|
+
|
|
8
|
+
def gather_modal_results_for_calls(calls, n_timesteps, n_sessions, progress_dict):
    """Show a progress bar while sessions run, then collect their results.

    Progress is read by polling ``progress_dict`` (keyed by session index
    ``0 .. n_sessions - 1``) until every session reports ``n_timesteps``.

    Args:
        calls: Spawned function calls to gather once all sessions finish.
        n_timesteps: Timestep count a session reports when it is complete.
        n_sessions: Number of sessions being tracked.
        progress_dict: Mapping with a ``get`` method from session index to
            the session's latest reported timestep.

    Returns:
        Whatever ``modal.FunctionCall.gather(*calls)`` returns.
    """
    # Imported here rather than at module level: this file is also imported
    # on Modal, where the progress bar should not be used.
    from .utils.progress import RemoteProgressBar

    poll_interval_s = 0.2

    with RemoteProgressBar("Training", n_steps=n_sessions * n_timesteps, n_sessions=n_sessions) as bar:
        while True:
            time.sleep(poll_interval_s)

            # Sessions that have not reported anything yet have no entry.
            reported = []
            for session in range(n_sessions):
                latest = progress_dict.get(session)
                if latest is not None:
                    reported.append(latest)

            finished = sum(1 for latest in reported if latest == n_timesteps)

            bar.update_completed_sessions(finished)
            bar.update_value(sum(reported))

            if finished >= n_sessions:
                break

    return modal.FunctionCall.gather(*calls)
|
|
44
|
+
|
|
45
|
+
def train_session_on_modal_with_func(train_func, session_id, n_timesteps, progress_dict):
    """Run ``train_func`` and publish its progress under ``session_id``.

    Args:
        train_func: Callable invoked as ``train_func(progress_callback=...)``;
            the callback takes the current timestep.
        session_id: Key under which progress is stored in ``progress_dict``.
        n_timesteps: Value written to ``progress_dict`` once training returns.
        progress_dict: Mutable mapping receiving progress updates, or ``None``
            to run without progress reporting.

    Returns:
        Whatever ``train_func`` returns.
    """
    def progress_callback(current_timestep):
        # The remote entry point defaults progress_dict to None; treat that as
        # "no reporting" instead of failing on item assignment.
        if progress_dict is not None:
            progress_dict[session_id] = current_timestep

    train_func_return = train_func(progress_callback=progress_callback)

    # Pin the final value to exactly n_timesteps: the poller in this module
    # counts a session as complete only when the stored value equals it.
    progress_callback(n_timesteps)

    return train_func_return
|
|
53
|
+
|
|
54
|
+
def create_modal_train_function(app, image, timeout=3600):
    """Register a training function on ``app`` and return it.

    Args:
        app: Object exposing a ``function(image=..., timeout=..., serialized=...)``
            decorator factory (a Modal app in this module's usage).
        image: Passed through as ``image`` to ``app.function``.
        timeout: Passed through as ``timeout`` to ``app.function``.

    Returns:
        The result of decorating the inner ``_modal_train`` function.
    """
    def _modal_train(n_timesteps, setup_func, session_id=None, progress_dict=None):
        # setup_func supplies everything the session needs.
        agent, env_name, continuous, params = setup_func()

        run_training = partial(
            trainer.train,
            agent,
            env_name,
            continuous=continuous,
            n_timesteps=n_timesteps,
            should_print=False,
        )

        outcome = train_session_on_modal_with_func(
            run_training, session_id, n_timesteps, progress_dict
        )

        # Append the setup artefacts to the training results.
        return tuple(outcome) + (agent, env_name, continuous, params)

    register = app.function(image=image, timeout=timeout, serialized=True)
    return register(_modal_train)
|
|
68
|
+
|
|
69
|
+
def train(n_sessions, n_timesteps, setup_func, app, image, timeout=3600):
    """Spawn ``n_sessions`` training sessions and return their results.

    Args:
        n_sessions: Number of sessions to spawn (session ids ``0 .. n_sessions - 1``).
        n_timesteps: Timesteps each session trains for.
        setup_func: Zero-argument callable forwarded to every session.
        app: App the training function is registered on and run under.
        image: Image forwarded to the function registration.
        timeout: Timeout forwarded to the function registration.

    Returns:
        The gathered per-session results.
    """
    # The function must be created before entering app.run() so it gets registered.
    remote_train = create_modal_train_function(app, image, timeout)

    with app.run():
        # Shared, named Dict used for progress reporting; emptied before use.
        shared_progress = modal.Dict.from_name("training-progress", create_if_missing=True)
        shared_progress.clear()

        pending = []
        for session in range(n_sessions):
            handle = remote_train.spawn(
                n_timesteps,
                setup_func=setup_func,
                session_id=session,
                progress_dict=shared_progress,
            )
            pending.append(handle)

        gathered = gather_modal_results_for_calls(pending, n_timesteps, n_sessions, shared_progress)

    return gathered
|
|
83
|
+
|
|
84
|
+
def plot_results(results, title, n_timesteps, nb_name=None, save_dir=None):
    """Plot the returns and lengths of every session in ``results``.

    Args:
        results: Sequence of 6-tuples
            ``(returns, lengths, agent, env_name, continuous, params)``.
        title: Plot title forwarded to ``plot.plot_sessions``.
        n_timesteps: Forwarded to ``plot.plot_sessions``.
        nb_name: Forwarded to ``plot.plot_sessions``.
        save_dir: Forwarded to ``plot.plot_sessions``.
    """
    # Keep only the per-session curves; each entry must have exactly six items.
    session_curves = []
    for returns, lengths, _, _, _, _ in results:
        session_curves.append((returns, lengths))

    # Metadata for the plot is taken from the first session.
    first_session = results[0]
    _, _, agent, env_name, continuous, params = first_session

    plot.plot_sessions(
        session_curves,
        title,
        env_name=env_name,
        continuous=continuous,
        params=params,
        agent=agent,
        n_timesteps=n_timesteps,
        nb_name=nb_name,
        save_dir=save_dir,
    )
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from .actors import *
|
|
2
|
+
from .critics import *
|
|
3
|
+
from .agents import *
|
|
4
|
+
from .params import *
|
|
5
|
+
from .foundation import *
|
|
6
|
+
|
|
7
|
+
from .a2c import *
|
|
8
|
+
from .critic_loss_clipped import *
|
|
9
|
+
from .critic_loss_q import *
|
|
10
|
+
from .epochs import *
|
|
11
|
+
from .gae import *
|
|
12
|
+
from .grad_norm import *
|
|
13
|
+
from .lr_anneal import *
|
|
14
|
+
from .monte_carlo import *
|
|
15
|
+
from .po_clipped import *
|
|
16
|
+
from .replay import *
|
|
17
|
+
from .rollout_data import *
|
helloRL/modules/a2c.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
from .foundation import *
|
|
5
|
+
|
|
6
|
+
@dataclass
class AdvantageTransformNormalize(AdvantageTransform):
    """Advantage transform that rescales raw advantages.

    With ``method == "sum"`` the advantages are divided by their sum; any
    other value selects standardization (subtract the mean, divide by the
    standard deviation).
    """
    # "sum" selects sum normalization; anything else falls back to mean/std.
    method: str = "mean"

    def transform(self, raw_advantages: torch.Tensor) -> torch.Tensor:
        """Return the normalized advantages.

        Args:
            raw_advantages: Tensor of advantages of any shape.

        Returns:
            A tensor of the same shape. Inputs with fewer than two elements
            are returned unchanged in mean mode.
        """
        if self.method == "sum":
            return raw_advantages / (raw_advantages.sum() + 1e-9)

        # std() needs at least two elements. Count elements rather than using
        # len(), which only looks at the first dimension (and fails on 0-d
        # tensors), so e.g. a (1, n_steps, 1) batch is still normalized.
        if raw_advantages.numel() < 2:
            return raw_advantages

        return (raw_advantages - raw_advantages.mean()) / (raw_advantages.std() + 1e-9)
|
|
19
|
+
|
|
20
|
+
@dataclass
class RolloutMethodA2C(RolloutMethod):
    """Rollout collector that steps a vector environment a fixed number of times.

    Each call to ``collect_rollout_data`` advances every sub-environment
    ``n_steps`` times and stores the transitions in tensors laid out as
    ``(n_envs, n_steps, ...)``.
    """
    # Number of environment steps recorded per rollout.
    n_steps: int = 16
    # Not read inside this class: the collector sizes its buffers from
    # envs.num_envs. Presumably consumed by whoever builds the vector env —
    # TODO confirm.
    n_envs: int = 4

    def collect_rollout_data(
        self, envs: gym.vector.VectorEnv, initial_states: torch.Tensor, agent: AgentProtocol, tracker: SessionTracker
    ) -> tuple[RolloutData, list[np.ndarray]]:
        """Collect ``n_steps`` transitions from every sub-environment.

        Args:
            envs: Vector environment to step. Its ``reset`` is called with
                ``options={'reset_mask': dones}`` for finished episodes.
            initial_states: Observations to start from, one row per env.
            agent: Supplies the actor (action, exploration, log-prob) and the
                critic value for each step.
            tracker: Receives the timestep increments and finished-episode
                returns/lengths.

        Returns:
            ``(rollout_data, next_states)``: the recorded rollout, with
            ``returns`` and ``advantages`` zero-filled, and the observations
            to continue from after the last step.
        """
        n_envs = envs.num_envs
        state_space = envs.single_observation_space.shape[0]
        # Non-Box (e.g. discrete) action spaces are stored as a single value per step.
        action_dim = envs.single_action_space.shape[0] if isinstance(envs.single_action_space, gym.spaces.Box) else 1

        rollout_states_t = torch.zeros(n_envs, self.n_steps, state_space) # (n_envs, n_steps, state_space)
        rollout_actions_t = torch.zeros(n_envs, self.n_steps, action_dim) # (n_envs, n_steps, action_dim)
        rollout_next_states_t = torch.zeros(n_envs, self.n_steps, state_space) # (n_envs, n_steps, state_space)
        rollout_rewards_t = torch.zeros(n_envs, self.n_steps, 1) # (n_envs, n_steps, 1 value)
        rollout_terminateds_t = torch.zeros(n_envs, self.n_steps, 1) # (n_envs, n_steps, 1 value)
        rollout_truncateds_t = torch.zeros(n_envs, self.n_steps, 1) # (n_envs, n_steps, 1 value)
        rollout_dones_t = torch.zeros(n_envs, self.n_steps, 1) # (n_envs, n_steps, 1 value)
        rollout_critic_values_t = torch.zeros(n_envs, self.n_steps, 1) # (n_envs, n_steps, 1 value)
        rollout_log_probs_t = torch.zeros(n_envs, self.n_steps, 1) # (n_envs, n_steps, 1 value)

        states = initial_states

        for step in range(self.n_steps):
            # One step of the vector env advances n_envs timesteps in total.
            tracker.increment_timestep(n=n_envs)

            states_t = torch.tensor(states).float() # (n_envs, state_space)

            # Rollout statistics are recorded without building an autograd graph.
            with torch.no_grad():
                actions_t, _ = agent.actor.output(states_t)
                # apply exploration noise etc to the actions
                actions_t = agent.actor.exploration(actions_t)
                # Log-prob is evaluated for the action after exploration was applied.
                log_probs_t, _ = agent.actor.get_log_prob_and_entropy(states_t, actions_t)
                critic_values_t = agent.get_critic_value(states_t, actions_t)
            # The trailing dimension is dropped before handing actions to the env.
            actions_np = actions_t.squeeze(-1).numpy() # (n_envs)
            next_states, rewards, terminateds, truncateds, infos = envs.step(actions_np)

            # An episode is done when it either terminated or was truncated.
            dones = terminateds | truncateds
            dones_t = torch.tensor(dones).float().reshape(n_envs, 1) # (n_envs, 1 value)
            terminateds_t = torch.tensor(terminateds).float().reshape(n_envs, 1) # (n_envs, 1 value)
            truncateds_t = torch.tensor(truncateds).float().reshape(n_envs, 1) # (n_envs, 1 value)

            rewards_t = torch.tensor(rewards).float().reshape(n_envs, 1) # (n_envs, 1 value)

            rollout_states_t[:, step] = states_t # (n_envs, n_steps, state_space)
            rollout_actions_t[:, step] = actions_t # (n_envs, n_steps, action_dim)
            # Stored before the masked reset below, i.e. as returned by envs.step.
            rollout_next_states_t[:, step] = torch.tensor(next_states).float() # (n_envs, n_steps, state_space)
            rollout_rewards_t[:, step] = rewards_t # (n_envs, n_steps, 1 value)
            rollout_terminateds_t[:, step] = terminateds_t # (n_envs, n_steps, 1 value)
            rollout_truncateds_t[:, step] = truncateds_t # (n_envs, n_steps, 1 value)
            rollout_dones_t[:, step] = dones_t # (n_envs, n_steps, 1 value)
            rollout_critic_values_t[:, step] = critic_values_t # (n_envs, n_steps, 1 value)
            rollout_log_probs_t[:, step] = log_probs_t # (n_envs, n_steps, 1 value)

            if dones.any():
                # Episode statistics are read from infos['episode'] for the finished envs only.
                episode_returns = infos['episode']['r'][dones]
                episode_lengths = infos['episode']['l'][dones]

                tracker.finish_episodes(episode_returns, episode_lengths)

                # Reset only the finished envs; the result replaces next_states
                # as the starting point for the following step.
                next_states, _ = envs.reset(options={'reset_mask': dones})

            states = next_states

        # Zero placeholders; presumably filled in later by the return/advantage
        # computation — TODO confirm against the caller.
        rollout_returns_t = torch.zeros_like(rollout_rewards_t) # (n_envs, n_steps, 1)
        rollout_advantages_t = torch.zeros_like(rollout_rewards_t) # (n_envs, n_steps, 1)

        rollout_data = RolloutData(
            states=rollout_states_t,
            actions=rollout_actions_t,
            next_states=rollout_next_states_t,
            rewards=rollout_rewards_t,
            terminateds=rollout_terminateds_t,
            truncateds=rollout_truncateds_t,
            dones=rollout_dones_t,
            critic_values=rollout_critic_values_t,
            log_probs=rollout_log_probs_t,
            returns=rollout_returns_t,
            advantages=rollout_advantages_t
        )

        return rollout_data, next_states
|
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
from torch.distributions import Categorical, Normal
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
|
|
8
|
+
HIDDEN_SIZES_DEFAULT = [64, 64]
|
|
9
|
+
|
|
10
|
+
class ActorProtocol(ABC):
    """Abstract interface every actor implements."""

    @abstractmethod
    def output(self, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Choose an action for ``state``.

        Args:
            state: The input state (torch.Tensor)

        Returns:
            A pair ``(action, log_prob)``: the action to take and the log
            probability of that action, both torch.Tensor.
        """

    @abstractmethod
    def get_loss(self, data, critic_value_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) -> torch.Tensor:
        """Return the actor loss for a batch of rollout data.

        Advantages are expected to be precomputed and stored in ``data``.

        Args:
            data: The rollout data (RolloutData)
            critic_value_func: Function mapping state and action tensors to
                critic values.
        """

    @abstractmethod
    def exploration(self, action) -> torch.Tensor:
        """Return ``action`` with exploration noise applied.

        Args:
            action: The action to apply exploration to (torch.Tensor)

        Returns:
            The explored action (torch.Tensor)
        """

    @abstractmethod
    def get_log_prob_and_entropy(self, state, action) -> tuple[torch.Tensor, torch.Tensor]:
        """Evaluate ``action`` under the policy at ``state``.

        Args:
            state: The input state (torch.Tensor)
            action: The input action (torch.Tensor)

        Returns:
            A pair ``(log_prob, entropy)``: the log probability of the action
            and the entropy of the action distribution, both torch.Tensor.
        """
|
|
61
|
+
|
|
62
|
+
@dataclass
class ActorParams:
    """Base class for actor parameter containers; declares no fields itself."""
|
|
65
|
+
|
|
66
|
+
@dataclass
class PolicyObjectiveMethod(ABC):
    """Abstract strategy for turning probability ratios and advantages into an objective."""

    @abstractmethod
    def compute_policy_objective(self, ratio: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
        """Return the policy objective.

        Args:
            ratio: Ratio of new to old policy probabilities (torch.Tensor)
            advantages: Advantages computed from the advantage method (torch.Tensor)

        Returns:
            Policy objective (torch.Tensor)
        """
|
|
80
|
+
|
|
81
|
+
class PolicyObjectiveMethodStandard(PolicyObjectiveMethod):
    """Plain objective: the ratio multiplied element-wise by the advantages."""

    def compute_policy_objective(self, ratio: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
        """Return ``ratio * advantages``."""
        weighted_advantages = ratio * advantages
        return weighted_advantages
|
|
84
|
+
|
|
85
|
+
@dataclass
class DistributionalActorParams(ActorParams):
    """Parameters for actors that expose an action distribution."""
    # Strategy used to build the policy objective; a fresh standard
    # objective is created per params instance.
    policy_objective_method: PolicyObjectiveMethod = field(default_factory=PolicyObjectiveMethodStandard)
    # Weight of the entropy term in the actor loss (0.0 disables it).
    entropy_coef: float = 0.0
|
|
91
|
+
|
|
92
|
+
class DiscreteActorNetwork(nn.Module):
    """Fully connected network producing one logit per discrete action."""

    def __init__(self, state_dim, action_dim, hidden_sizes=HIDDEN_SIZES_DEFAULT):
        """Build the hidden stack and the logit head.

        Args:
            state_dim: Size of the input state vector.
            action_dim: Number of logits produced by the head.
            hidden_sizes: Width of each hidden layer, in order.
        """
        super().__init__()

        # Consecutive pairs of widths give each hidden layer's (in, out) size.
        widths = [state_dim, *hidden_sizes]
        self.hidden_layers = nn.ModuleList(
            [
                nn.Linear(fan_in, fan_out, dtype=torch.float32)
                for fan_in, fan_out in zip(widths, widths[1:])
            ]
        )

        self.head = nn.Linear(widths[-1], action_dim, dtype=torch.float32)

    def forward(self, state):
        """Return the action logits for ``state``: (batch_size, action_space)."""
        hidden = state
        for layer in self.hidden_layers:
            hidden = torch.relu(layer(hidden))

        return self.head(hidden)
|
|
115
|
+
|
|
116
|
+
class DistributionalActor(ActorProtocol, nn.Module):
    """Shared base for actors whose policy is a probability distribution.

    Subclasses are expected to set ``self.network`` and to implement
    ``output`` and ``get_log_prob_and_entropy``.
    """

    def __init__(self, params=None):
        """Store the actor parameters.

        Args:
            params: A ``DistributionalActorParams`` instance, or ``None`` to
                use a fresh default instance.
        """
        super(DistributionalActor, self).__init__()
        # Build the default per instance: a default created in the signature
        # is evaluated once and would be shared (and mutated) across actors.
        self.params = DistributionalActorParams() if params is None else params

    def get_log_prob_and_entropy(self, state, action):
        """Placeholder returning ``None``; overridden by the concrete actors."""
        pass

    def forward(self, state):
        """Run ``state`` through ``self.network``."""
        return self.network(state)

    def get_loss(self, data, critic_value_func):
        """Return the actor loss for a batch of rollout data.

        Args:
            data: Rollout data providing ``states``, ``actions``,
                ``log_probs`` and ``advantages``.
            critic_value_func: Unused here; part of the actor interface.

        Returns:
            Scalar loss: the negated summed policy objective minus the
            entropy bonus scaled by ``entropy_coef``.
        """
        old_log_probs = data.log_probs.detach() # stored at rollout time

        new_log_probs, entropies = self.get_log_prob_and_entropy(data.states, data.actions)
        # Probability ratio new/old, computed in log space.
        log_ratio = new_log_probs - old_log_probs
        ratio = torch.exp(log_ratio)

        policy_objective = self.params.policy_objective_method.compute_policy_objective(ratio, data.advantages)
        # The objective is maximized, so the loss is its negation.
        actor_loss = -(policy_objective.sum())
        actor_loss -= (self.params.entropy_coef * entropies.sum())

        return actor_loss

    def exploration(self, action) -> torch.Tensor:
        """Return ``action`` unchanged; no extra exploration noise is added."""
        return action
|
|
142
|
+
|
|
143
|
+
class DiscreteActor(DistributionalActor):
    """Actor for discrete action spaces using a categorical distribution."""

    def __init__(self, state_dim, action_dim, hidden_sizes=HIDDEN_SIZES_DEFAULT, params=None):
        """Build the logit network.

        Args:
            state_dim: Size of the input state vector.
            action_dim: Number of discrete actions.
            hidden_sizes: Width of each hidden layer of the network.
            params: A ``DistributionalActorParams`` instance, or ``None`` to
                use a fresh default instance.
        """
        # Resolve the default here so every actor gets its own params object
        # instead of one instance shared through the signature.
        if params is None:
            params = DistributionalActorParams()

        super(DiscreteActor, self).__init__(params=params)

        self.network = DiscreteActorNetwork(state_dim, action_dim, hidden_sizes=hidden_sizes)

    def output(self, state): # action: (batch_size, 1), log_prob: (batch_size, 1)
        """Sample an action for ``state`` and return it with its log probability."""
        actor_logits = self.network.forward(state)
        action_distribution = Categorical(logits=actor_logits)
        sampled = action_distribution.sample() # shape: (batch_size,)

        # Evaluate the log-prob on the raw sample: squeezing the stored action
        # with no dim argument would also collapse a size-1 batch dimension.
        action = sampled.unsqueeze(-1) # shape: (batch_size, 1)
        log_prob = action_distribution.log_prob(sampled).unsqueeze(-1) # shape: (batch_size, 1)

        return action, log_prob

    def get_log_prob_and_entropy(self, state, action):
        """Return the log probability of ``action`` and the policy entropy at ``state``."""
        actor_logits = self.network.forward(state)
        action_distribution = Categorical(logits=actor_logits)
        # Actions carry a trailing singleton dimension; drop it for log_prob.
        log_prob = action_distribution.log_prob(action.squeeze(-1)).unsqueeze(-1)
        entropy = action_distribution.entropy().unsqueeze(-1) # (batch_size, 1)
        return log_prob, entropy
|
|
163
|
+
|
|
164
|
+
class ContinuousActorNetwork(nn.Module):
    """Fully connected network whose output is rescaled into an action range."""

    def __init__(self, state_dim, action_dim, action_range, hidden_sizes=HIDDEN_SIZES_DEFAULT):
        """Build the hidden stack and the action head.

        Args:
            state_dim: Size of the input state vector.
            action_dim: Number of action components produced by the head.
            action_range: Tensor whose min/max along dim 0 give the lower and
                upper action bounds.
            hidden_sizes: Width of each hidden layer, in order.
        """
        super().__init__()

        # Consecutive pairs of widths give each hidden layer's (in, out) size.
        widths = [state_dim, *hidden_sizes]
        self.hidden_layers = nn.ModuleList(
            [
                nn.Linear(fan_in, fan_out, dtype=torch.float32)
                for fan_in, fan_out in zip(widths, widths[1:])
            ]
        )

        self.head = nn.Linear(widths[-1], action_dim, dtype=torch.float32)
        self.action_range = action_range

    def forward(self, state):
        """Return actions for ``state``, mapped from (-1, 1) into the action range."""
        hidden = state
        for layer in self.hidden_layers:
            hidden = torch.relu(layer(hidden))

        squashed = torch.tanh(self.head(hidden)) # Bounds to -1 to 1

        lower = torch.min(self.action_range, dim=0).values
        upper = torch.max(self.action_range, dim=0).values
        span = upper - lower

        # Scale by half the span and shift to the middle of the range.
        return (squashed * (span / 2)) + ((lower + upper) / 2)
|
|
194
|
+
|
|
195
|
+
class StochasticActor(DistributionalActor):
    """Actor for continuous actions using a Normal distribution.

    The mean comes from the network; the standard deviation is a learned,
    state-independent parameter stored in log space.
    """

    def __init__(self, state_dim, action_dim, action_range, hidden_sizes=HIDDEN_SIZES_DEFAULT, params=None):
        """Build the mean network and the log-std parameter.

        Args:
            state_dim: Size of the input state vector.
            action_dim: Number of action components.
            action_range: Action bounds forwarded to the network.
            hidden_sizes: Width of each hidden layer of the network.
            params: A ``DistributionalActorParams`` instance, or ``None`` to
                use a fresh default instance.
        """
        # Resolve the default here so every actor gets its own params object
        # instead of one instance shared through the signature.
        if params is None:
            params = DistributionalActorParams()

        super(StochasticActor, self).__init__(params=params)
        self.network = ContinuousActorNetwork(state_dim, action_dim, action_range, hidden_sizes=hidden_sizes)
        # log_std starts at zero, i.e. an initial standard deviation of 1.
        self.log_std = nn.Parameter(torch.zeros(action_dim))

    def output(self, state): # action: (batch_size, action_dim), log_prob: (batch_size, 1)
        """Sample an action for ``state`` and return it with its log probability."""
        # mean, also referred to as 'mu'
        mean = self.network(state)
        std = torch.exp(self.log_std) # Exp to make positive
        dist = Normal(mean, std)
        action = dist.sample() # Shape: (batch_size, action_dim)
        # Per-component log-probs are summed over the last dimension.
        log_prob = dist.log_prob(action).sum(dim=-1, keepdim=True) # shape: (batch_size, 1)

        return action, log_prob

    def get_log_prob_and_entropy(self, state, action):
        """Return the log probability of ``action`` and the policy entropy at ``state``."""
        mean = self.network(state)
        std = torch.exp(self.log_std) # Exp to make positive
        dist = Normal(mean, std)
        log_prob = dist.log_prob(action).sum(dim=-1, keepdim=True) # (batch_size, 1)
        entropy = dist.entropy().sum(dim=-1, keepdim=True) # (batch_size, 1)
        return log_prob, entropy
|
|
218
|
+
|
|
219
|
+
@dataclass
class DeterministicActorParams(ActorParams):
    """Parameters for the deterministic actor."""
    # Standard deviation of the Gaussian noise added during exploration.
    exploration_std: float = 0.1
|
|
222
|
+
|
|
223
|
+
class DeterministicActor(ActorProtocol, nn.Module):
    """Actor that maps each state to a single action.

    Log probabilities and entropies are not defined for this policy, so the
    interface methods return zeros in their place.
    """

    def __init__(self, state_dim, action_dim, action_range, hidden_sizes=HIDDEN_SIZES_DEFAULT, params=None):
        """Build the action network.

        Args:
            state_dim: Size of the input state vector.
            action_dim: Number of action components.
            action_range: Action bounds forwarded to the network and used to
                clamp explored actions.
            hidden_sizes: Width of each hidden layer of the network.
            params: A ``DeterministicActorParams`` instance, or ``None`` to
                use a fresh default instance.
        """
        super(DeterministicActor, self).__init__()
        self.network = ContinuousActorNetwork(state_dim, action_dim, action_range, hidden_sizes=hidden_sizes)
        # Build the default per instance: a default created in the signature
        # is evaluated once and would be shared (and mutated) across actors.
        self.params = DeterministicActorParams() if params is None else params

    def forward(self, state): # action: (batch_size, action_dim)
        """Return the network's action for ``state``."""
        return self.network(state)

    def output(self, state): # action: (batch_size, action_dim), log_prob: (batch_size, 1)
        """Return the action for ``state`` and a zero tensor in place of log-probs."""
        # Zeros are created on the state's device so both outputs live together.
        placeholder = torch.zeros((state.shape[0], 1), dtype=torch.float32, device=state.device)
        return self.network(state), placeholder

    def get_loss(self, data, critic_value_func):
        """Return the negated mean critic value of the actor's own actions.

        Args:
            data: Rollout data providing ``states``.
            critic_value_func: Function mapping (states, actions) to values.
        """
        actions_pi = self.network(data.states)
        # Minimizing the negated value pushes the actor toward higher values.
        actor_loss = -(critic_value_func(data.states, actions_pi).mean())

        return actor_loss

    def exploration(self, action) -> torch.Tensor:
        """Add Gaussian noise to ``action`` and clamp it to the action range."""
        action = action + (torch.randn_like(action) * self.params.exploration_std)

        mins = self.network.action_range.min(dim=0).values
        maxs = self.network.action_range.max(dim=0).values
        action = torch.clamp(action, mins, maxs)

        return action

    def get_log_prob_and_entropy(self, state, action):
        """Return zero tensors of shape (batch_size, 1) for log-prob and entropy."""
        batch_size = state.shape[0]
        log_prob = torch.zeros((batch_size, 1), dtype=torch.float32, device=state.device)
        entropy = torch.zeros((batch_size, 1), dtype=torch.float32, device=state.device)
        return log_prob, entropy
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from copy import deepcopy
|
|
6
|
+
|
|
7
|
+
from .actors import ActorProtocol
|
|
8
|
+
from .critics import CriticProtocol
|
|
9
|
+
|
|
10
|
+
class AgentProtocol(ABC):
    """Interface for an agent made of one actor and a list of critics."""

    actor: ActorProtocol
    critics: list[CriticProtocol]

    @abstractmethod
    def get_action(self, state):
        """Return the actor's action for ``state``.

        Args:
            state: The input state (torch.Tensor)
        """

    def get_critic_value(self, state, action):
        """Return the critic value for ``state`` and ``action``.

        This method is not marked abstract; the base version returns ``None``.

        Args:
            state: The input state (torch.Tensor)
            action: The input action (torch.Tensor)
        """

    @abstractmethod
    def get_target_action(self, state):
        """Return the target actor's output for ``state``.

        Implementations without target networks can use the main actor.

        Args:
            state: The input state (torch.Tensor)
        """

    @abstractmethod
    def get_target_critic_value(self, state, action):
        """Return the target critic's output for ``state`` and ``action``.

        Implementations without target networks can use the main critics.

        Args:
            state: The input state (torch.Tensor)
            action: The input action (torch.Tensor)
        """

    @abstractmethod
    def update_targets(self):
        """Update the target networks, if there are any."""
|
|
57
|
+
|
|
58
|
+
@dataclass
class AgentParams:
    """Base class for agent parameter containers; declares no fields itself."""
|
|
61
|
+
|
|
62
|
+
class Agent(AgentProtocol):
    """Agent without target networks: targets fall back to the live actor and critics."""

    def __init__(self, actor: ActorProtocol, critics: list[CriticProtocol], params: AgentParams=AgentParams()):
        """Store the actor, the critics and the agent parameters."""
        super().__init__()

        self.actor, self.critics, self.params = actor, critics, params

    def get_action(self, state):
        """Return the action part of the actor's output for ``state``."""
        chosen_action, _log_prob = self.actor.output(state)

        return chosen_action

    def get_critic_value(self, state, action):
        """Return the element-wise minimum of all critics' values."""
        per_critic = [critic.output(state, action) for critic in self.critics]
        return torch.stack(per_critic).min(dim=0).values

    def get_target_action(self, state):
        """Return ``get_action(state)``; there is no separate target actor."""
        return self.get_action(state)

    def get_target_critic_value(self, state, action):
        """Return ``get_critic_value(state, action)``; there are no target critics."""
        return self.get_critic_value(state, action)

    def update_targets(self):
        """Do nothing: this agent has no target networks to update."""
        return None
|
|
90
|
+
|
|
91
|
+
@dataclass
class AgentWithTargetsParams(AgentParams):
    """Parameters for an agent that keeps target networks."""
    # Interpolation weight of the live parameters in each soft target update.
    tau: float = 0.1
|
|
94
|
+
|
|
95
|
+
class AgentWithTargets(Agent):
    """Agent that keeps deep copies of its networks as soft-updated targets."""

    def __init__(self, actor: ActorProtocol, critics: list[CriticProtocol], params: AgentWithTargetsParams):
        """Store the components and snapshot the actor and critic networks."""
        super().__init__(actor, critics, params)

        # Targets start as exact copies of the live networks.
        self.target_actor_network = deepcopy(self.actor.network)
        self.target_critic_networks = [deepcopy(critic.network) for critic in self.critics]

    def get_target_action(self, state):
        """Return the target actor network's output for ``state``."""
        return self.target_actor_network(state)

    def get_target_critic_value(self, state, action):
        """Return the element-wise minimum of all target critics' values."""
        per_target = [target_net(state, action) for target_net in self.target_critic_networks]
        return torch.stack(per_target).min(dim=0).values

    def update_targets(self):
        """Soft-update every target network toward its live counterpart."""
        tau = self.params.tau

        def blend(target_net, live_net):
            # target <- tau * live + (1 - tau) * target, parameter by parameter.
            for target_param, live_param in zip(target_net.parameters(), live_net.parameters()):
                mixed = (tau * live_param.data) + ((1 - tau) * target_param.data)
                target_param.data.copy_(mixed)

        blend(self.target_actor_network, self.actor.network)

        for index, critic in enumerate(self.critics):
            blend(self.target_critic_networks[index], critic.network)
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from .foundation import *
|
|
4
|
+
from .critics import CriticLossMethod
|
|
5
|
+
|
|
6
|
+
class CriticLossMethodClipped(CriticLossMethod):
    """Critic loss that limits how far new values may move from the rollout values."""

    def __init__(self, clip_range_vf: float):
        """Store the maximum allowed change of the value estimate."""
        self.clip_range_vf = clip_range_vf

    def compute_critic_loss(self, reference_actor, reference_critic, new_values, data, gamma):
        """Return the mean of the larger of the clipped and unclipped squared errors.

        Args:
            reference_actor: Unused by this loss.
            reference_critic: Unused by this loss.
            new_values: Current critic predictions.
            data: Rollout data providing ``critic_values`` and ``returns``.
            gamma: Unused by this loss.
        """
        baseline = data.critic_values.detach()
        limit = self.clip_range_vf

        # New prediction, restricted to within +/- limit of the rollout-time value.
        bounded_step = (new_values - baseline).clamp(-limit, limit)
        bounded_values = baseline + bounded_step

        squared_error = (new_values - data.returns).pow(2)
        squared_error_bounded = (bounded_values - data.returns).pow(2)

        # Taking the element-wise maximum keeps the more pessimistic error.
        return torch.maximum(squared_error, squared_error_bounded).mean()
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from .foundation import *
|
|
4
|
+
from .critics import CriticLossMethod
|
|
5
|
+
|
|
6
|
+
class CriticLossMethodQ(CriticLossMethod):
    """Critic loss against a one-step bootstrapped Q target."""

    def compute_critic_loss(self, target_action_func, target_critic_value_func, new_values, data, gamma):
        """Return the mean squared error between ``new_values`` and the Q target.

        Args:
            target_action_func: Maps next states to actions (the target actor).
            target_critic_value_func: Maps (next states, actions) to values
                (the target critic).
            new_values: Predictions of the current critic.
            data: Rollout data providing ``next_states``, ``rewards`` and
                ``terminateds``.
            gamma: Discount factor applied to the bootstrapped value.
        """
        # The target is a fixed regression label, so it is built without gradients.
        with torch.no_grad():
            next_actions = target_action_func(data.next_states)
            next_values = target_critic_value_func(data.next_states, next_actions)
            # Terminated transitions get no bootstrapped value.
            discount = gamma * (1.0 - data.terminateds)
            target = data.rewards + (discount * next_values)

        return torch.nn.functional.mse_loss(new_values, target)
|