helloRL 1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
helloRL/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .modules import *
2
+ from .utils import *
3
+
4
+ from . import trainer
5
+ from . import modal_training
6
+
7
+ __version__ = "1.0.0"
@@ -0,0 +1,98 @@
1
+ import modal
2
+ import time
3
+ from functools import partial
4
+
5
+ from . import trainer
6
+ from .utils import plot
7
+
8
def gather_modal_results_for_calls(calls, n_timesteps, n_sessions, progress_dict):
    """Show a progress bar while remote sessions train, then collect their results.

    Args:
        calls: Spawned Modal function calls, forwarded to ``modal.FunctionCall.gather``.
        n_timesteps: Timesteps each session trains for; a session counts as
            finished once its reported progress equals this value.
        n_sessions: Number of sessions; progress is read under keys ``0..n_sessions-1``.
        progress_dict: Mapping polled with ``.get(session_index)`` for the latest
            timestep reported by each session.

    Returns:
        Whatever ``modal.FunctionCall.gather(*calls)`` returns.
    """
    expected_total = n_sessions * n_timesteps

    # this file is imported to modal, but progress bar shouldn't be used there
    from .utils.progress import RemoteProgressBar

    with RemoteProgressBar("Training", n_steps=expected_total, n_sessions=n_sessions) as bar:
        all_sessions_done = False

        # NOTE(review): a session that never reports n_timesteps (e.g. its remote
        # call fails) keeps this loop polling forever — confirm this is acceptable.
        while not all_sessions_done:
            time.sleep(0.2)  # Poll every 0.2 seconds

            # Sessions that have not reported anything yet are simply left out.
            latest = [progress_dict.get(session) for session in range(n_sessions)]
            reported = [value for value in latest if value is not None]

            steps_done = sum(reported)
            sessions_done = sum(1 for value in reported if value == n_timesteps)

            bar.update_completed_sessions(sessions_done)
            bar.update_value(steps_done)

            all_sessions_done = sessions_done >= n_sessions

    # Gather results
    return modal.FunctionCall.gather(*calls)
44
+
45
def train_session_on_modal_with_func(train_func, session_id, n_timesteps, progress_dict):
    """Run ``train_func`` while publishing its progress under ``session_id``.

    Args:
        train_func: Callable invoked as ``train_func(progress_callback=...)``.
        session_id: Key under which progress is written into ``progress_dict``.
        n_timesteps: Value written once ``train_func`` has returned, marking the
            session as finished.
        progress_dict: Mutable mapping that receives the progress values.

    Returns:
        The value returned by ``train_func``.
    """
    def report(current_timestep):
        progress_dict[session_id] = current_timestep

    outcome = train_func(progress_callback=report)

    # Always finish on exactly n_timesteps, whatever the last intermediate report was.
    report(n_timesteps)

    return outcome
53
+
54
def create_modal_train_function(app, image, timeout=3600):
    """Register and return the remote training function on ``app``.

    Args:
        app: Object exposing a ``function(image=..., timeout=..., serialized=...)``
            decorator factory (a Modal app).
        image: Image passed through to ``app.function``.
        timeout: Timeout passed through to ``app.function``.

    Returns:
        The decorated ``_modal_train`` function.
    """
    @app.function(image=image, timeout=timeout, serialized=True)
    def _modal_train(n_timesteps, setup_func, session_id=None, progress_dict=None):
        # setup_func builds everything the session needs.
        agent, env_name, continuous, params = setup_func()

        run_training = partial(
            trainer.train,
            agent,
            env_name,
            continuous=continuous,
            n_timesteps=n_timesteps,
            should_print=False,
        )
        train_results = train_session_on_modal_with_func(
            run_training, session_id, n_timesteps, progress_dict
        )

        # Training outputs first, followed by the objects created by setup_func.
        return (*train_results, agent, env_name, continuous, params)

    return _modal_train
68
+
69
def train(n_sessions, n_timesteps, setup_func, app, image, timeout=3600):
    """Spawn ``n_sessions`` remote training sessions and wait for their results.

    Args:
        n_sessions: Number of sessions to spawn (session ids ``0..n_sessions-1``).
        n_timesteps: Timesteps each session trains for.
        setup_func: Forwarded to every remote call as ``setup_func``.
        app: Modal app used to register and run the remote function.
        image: Image for the remote function.
        timeout: Timeout for the remote function.

    Returns:
        The results returned by ``gather_modal_results_for_calls``.
    """
    # Create the modal function before entering app.run() so it gets registered
    remote_train = create_modal_train_function(app, image, timeout)

    with app.run():
        # NOTE(review): the dict is looked up by a fixed name and cleared here, so
        # concurrent runs would presumably share and reset each other's progress — confirm.
        progress_dict = modal.Dict.from_name("training-progress", create_if_missing=True)
        progress_dict.clear()

        calls = []
        for session_id in range(n_sessions):
            call = remote_train.spawn(
                n_timesteps,
                setup_func=setup_func,
                session_id=session_id,
                progress_dict=progress_dict,
            )
            calls.append(call)

        results = gather_modal_results_for_calls(calls, n_timesteps, n_sessions, progress_dict)

    return results
83
+
84
def plot_results(results, title, n_timesteps, nb_name=None, save_dir=None):
    """Plot the per-session training curves contained in ``results``.

    Args:
        results: Sequence of 6-tuples
            ``(returns, lengths, agent, env_name, continuous, params)``.
        title: Plot title.
        n_timesteps: Forwarded to ``plot.plot_sessions``.
        nb_name: Forwarded to ``plot.plot_sessions``.
        save_dir: Forwarded to ``plot.plot_sessions``.
    """
    session_curves = []
    for returns, lengths, _, _, _, _ in results:
        session_curves.append((returns, lengths))

    # The metadata is taken from the first session only.
    _, _, agent, env_name, continuous, params = results[0]

    metadata = dict(
        env_name=env_name,
        continuous=continuous,
        params=params,
        agent=agent,
        n_timesteps=n_timesteps,
        nb_name=nb_name,
        save_dir=save_dir,
    )
    plot.plot_sessions(session_curves, title, **metadata)
@@ -0,0 +1,17 @@
1
+ from .actors import *
2
+ from .critics import *
3
+ from .agents import *
4
+ from .params import *
5
+ from .foundation import *
6
+
7
+ from .a2c import *
8
+ from .critic_loss_clipped import *
9
+ from .critic_loss_q import *
10
+ from .epochs import *
11
+ from .gae import *
12
+ from .grad_norm import *
13
+ from .lr_anneal import *
14
+ from .monte_carlo import *
15
+ from .po_clipped import *
16
+ from .replay import *
17
+ from .rollout_data import *
helloRL/modules/a2c.py ADDED
@@ -0,0 +1,102 @@
1
+ from dataclasses import dataclass
2
+ import torch
3
+
4
+ from .foundation import *
5
+
6
@dataclass
class AdvantageTransformNormalize(AdvantageTransform):
    """Advantage transform that rescales raw advantages.

    With ``method == "sum"`` the advantages are divided by their sum; any other
    value selects mean/std standardization.
    """

    # mean or sum normalization
    method: str = "mean"

    def transform(self, raw_advantages: torch.Tensor) -> torch.Tensor:
        """Return the normalized advantages.

        Args:
            raw_advantages: Tensor of advantages of any shape.

        Returns:
            A tensor of the same shape. For mean normalization, tensors with
            fewer than two elements are returned unchanged.
        """
        if self.method == "sum":
            return raw_advantages / (raw_advantages.sum() + 1e-9)

        # default to mean normalization
        # Count elements, not the size of the first dimension: std() reduces over
        # every element, so a (1, n_steps, 1) batch must still be normalized while
        # a single-element tensor must be skipped.
        if raw_advantages.numel() < 2:
            return raw_advantages

        return (raw_advantages - raw_advantages.mean()) / (raw_advantages.std() + 1e-9)
19
+
20
@dataclass
class RolloutMethodA2C(RolloutMethod):
    """Rollout method that steps a vector environment for a fixed number of steps."""

    # Number of environment steps collected per rollout.
    n_steps: int = 16
    # Not read by collect_rollout_data, which uses envs.num_envs instead.
    n_envs: int = 4

    def collect_rollout_data(
        self, envs: gym.vector.VectorEnv, initial_states: torch.Tensor, agent: AgentProtocol, tracker: SessionTracker
    ) -> tuple[RolloutData, list[np.ndarray]]:
        """Step ``envs`` for ``self.n_steps`` steps and record the transitions.

        Args:
            envs: Vector environment to step.
            initial_states: Observations to start from, one row per environment.
            agent: Provides the actor (action, exploration, log-prob) and the
                critic value used at every step.
            tracker: Receives timestep increments and finished-episode statistics.

        Returns:
            The collected ``RolloutData`` (with ``returns`` and ``advantages``
            left as zeros) and the observations to continue the next rollout from.
        """
        n_envs = envs.num_envs
        state_space = envs.single_observation_space.shape[0]
        # Non-Box action spaces are stored as a single value per step.
        action_dim = envs.single_action_space.shape[0] if isinstance(envs.single_action_space, gym.spaces.Box) else 1

        rollout_states_t = torch.zeros(n_envs, self.n_steps, state_space) # (n_envs, n_steps, state_space)
        rollout_actions_t = torch.zeros(n_envs, self.n_steps, action_dim) # (n_envs, n_steps, action_dim)
        rollout_next_states_t = torch.zeros(n_envs, self.n_steps, state_space) # (n_envs, n_steps, state_space)
        rollout_rewards_t = torch.zeros(n_envs, self.n_steps, 1) # (n_envs, n_steps, 1 value)
        rollout_terminateds_t = torch.zeros(n_envs, self.n_steps, 1) # (n_envs, n_steps, 1 value)
        rollout_truncateds_t = torch.zeros(n_envs, self.n_steps, 1) # (n_envs, n_steps, 1 value)
        rollout_dones_t = torch.zeros(n_envs, self.n_steps, 1) # (n_envs, n_steps, 1 value)
        rollout_critic_values_t = torch.zeros(n_envs, self.n_steps, 1) # (n_envs, n_steps, 1 value)
        rollout_log_probs_t = torch.zeros(n_envs, self.n_steps, 1) # (n_envs, n_steps, 1 value)

        states = initial_states

        for step in range(self.n_steps):
            # One timestep is counted per environment for every vector step.
            tracker.increment_timestep(n=n_envs)

            states_t = torch.tensor(states).float() # (n_envs, state_space)

            with torch.no_grad():
                actions_t, _ = agent.actor.output(states_t)
                # apply exploration noise etc to the actions
                actions_t = agent.actor.exploration(actions_t)
                # Log-probs are recomputed for the action after exploration was applied.
                log_probs_t, _ = agent.actor.get_log_prob_and_entropy(states_t, actions_t)
                critic_values_t = agent.get_critic_value(states_t, actions_t)

            # Only a trailing dimension of size 1 is removed by squeeze(-1).
            actions_np = actions_t.squeeze(-1).numpy()
            next_states, rewards, terminateds, truncateds, infos = envs.step(actions_np)

            dones = terminateds | truncateds
            dones_t = torch.tensor(dones).float().reshape(n_envs, 1) # (n_envs, 1 value)
            terminateds_t = torch.tensor(terminateds).float().reshape(n_envs, 1) # (n_envs, 1 value)
            truncateds_t = torch.tensor(truncateds).float().reshape(n_envs, 1) # (n_envs, 1 value)

            rewards_t = torch.tensor(rewards).float().reshape(n_envs, 1) # (n_envs, 1 value)

            # Write this step's slice of every buffer.
            rollout_states_t[:, step] = states_t # slice: (n_envs, state_space)
            rollout_actions_t[:, step] = actions_t # slice: (n_envs, action_dim)
            rollout_next_states_t[:, step] = torch.tensor(next_states).float() # slice: (n_envs, state_space)
            rollout_rewards_t[:, step] = rewards_t # slice: (n_envs, 1 value)
            rollout_terminateds_t[:, step] = terminateds_t # slice: (n_envs, 1 value)
            rollout_truncateds_t[:, step] = truncateds_t # slice: (n_envs, 1 value)
            rollout_dones_t[:, step] = dones_t # slice: (n_envs, 1 value)
            rollout_critic_values_t[:, step] = critic_values_t # slice: (n_envs, 1 value)
            rollout_log_probs_t[:, step] = log_probs_t # slice: (n_envs, 1 value)

            if dones.any():
                # presumably infos['episode'] is provided by an episode-statistics
                # wrapper around envs — TODO confirm
                episode_returns = infos['episode']['r'][dones]
                episode_lengths = infos['episode']['l'][dones]

                tracker.finish_episodes(episode_returns, episode_lengths)

                # NOTE(review): looks like envs resets only the masked environments
                # when given 'reset_mask' — confirm against the env setup.
                next_states, _ = envs.reset(options={'reset_mask': dones})

            states = next_states

        # Placeholders only: neither returns nor advantages are computed here.
        rollout_returns_t = torch.zeros_like(rollout_rewards_t) # (n_envs, n_steps, 1)
        rollout_advantages_t = torch.zeros_like(rollout_rewards_t) # (n_envs, n_steps, 1)

        rollout_data = RolloutData(
            states=rollout_states_t,
            actions=rollout_actions_t,
            next_states=rollout_next_states_t,
            rewards=rollout_rewards_t,
            terminateds=rollout_terminateds_t,
            truncateds=rollout_truncateds_t,
            dones=rollout_dones_t,
            critic_values=rollout_critic_values_t,
            log_probs=rollout_log_probs_t,
            returns=rollout_returns_t,
            advantages=rollout_advantages_t
        )

        # NOTE(review): next_states is only bound inside the loop, so n_steps == 0
        # would fail here — confirm n_steps is always positive.
        return rollout_data, next_states
@@ -0,0 +1,256 @@
1
+ from collections.abc import Callable
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.distributions import Categorical, Normal
5
+ from abc import ABC, abstractmethod
6
+ from dataclasses import dataclass, field
7
+
8
# Default hidden-layer widths for the actor networks in this module. The same list
# object is used as the default argument of several constructors, so it must not
# be mutated in place.
HIDDEN_SIZES_DEFAULT = [64, 64]
9
+
10
class ActorProtocol(ABC):
    """Abstract interface every actor implementation has to provide."""

    @abstractmethod
    def output(self, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Produce an action for ``state`` together with its log probability.

        Args:
            state: The input state (torch.Tensor)

        Returns:
            action: The action to take (torch.Tensor)
            log_prob: The log probability of the action (torch.Tensor)
        """

    @abstractmethod
    def get_loss(self, data, critic_value_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) -> torch.Tensor:
        """Compute the actor loss for a batch of rollout data.

        Advantages are expected to be precomputed and stored in ``data``.

        Args:
            data: The rollout data (RolloutData)
            critic_value_func: Function mapping state and action tensors to critic values
                (Callable[[torch.Tensor, torch.Tensor], torch.Tensor])
        """

    @abstractmethod
    def exploration(self, action) -> torch.Tensor:
        """Apply exploration noise to ``action``.

        Args:
            action: The action to apply exploration to (torch.Tensor)

        Returns:
            explored_action: The action after applying exploration (torch.Tensor)
        """

    @abstractmethod
    def get_log_prob_and_entropy(self, state, action) -> tuple[torch.Tensor, torch.Tensor]:
        """Evaluate ``action`` under the policy at ``state``.

        Args:
            state: The input state (torch.Tensor)
            action: The input action (torch.Tensor)

        Returns:
            log_prob: The log probability of the action (torch.Tensor)
            entropy: The entropy of the action distribution (torch.Tensor)
        """
61
+
62
@dataclass
class ActorParams:
    """Base class for actor hyperparameter containers; declares no fields."""
65
+
66
@dataclass
class PolicyObjectiveMethod(ABC):
    """Abstract strategy turning probability ratios and advantages into an objective."""

    @abstractmethod
    def compute_policy_objective(self, ratio: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
        """Compute the policy objective from ``ratio`` and ``advantages``.

        Args:
            ratio: Ratio of new to old policy probabilities (torch.Tensor)
            advantages: Advantages computed from the advantage method (torch.Tensor)

        Returns:
            Policy objective (torch.Tensor)
        """
80
+
81
class PolicyObjectiveMethodStandard(PolicyObjectiveMethod):
    """Policy objective with no clipping: the ratio scaled by the advantage."""

    def compute_policy_objective(self, ratio: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
        """Return the element-wise product ``ratio * advantages``."""
        objective = ratio * advantages
        return objective
84
+
85
@dataclass
class DistributionalActorParams(ActorParams):
    """Hyperparameters read by ``DistributionalActor.get_loss``."""

    # Strategy used to combine probability ratios with advantages; a fresh
    # PolicyObjectiveMethodStandard is created for every params instance.
    policy_objective_method: PolicyObjectiveMethod = field(default_factory=PolicyObjectiveMethodStandard)
    # Weight of the entropy term subtracted from the actor loss.
    entropy_coef: float = 0.0
90
+ entropy_coef: float = 0.0
91
+
92
class DiscreteActorNetwork(nn.Module):
    """Fully connected network mapping states to action logits."""

    def __init__(self, state_dim, action_dim, hidden_sizes=HIDDEN_SIZES_DEFAULT):
        """Build the ReLU hidden stack followed by a linear head.

        Args:
            state_dim: Size of the input state vector.
            action_dim: Number of logits produced by the head.
            hidden_sizes: Widths of the hidden layers, in order.
        """
        super().__init__()

        # Consecutive pairs of widths give each hidden layer's (in, out) sizes.
        widths = [state_dim, *hidden_sizes]

        self.hidden_layers = nn.ModuleList()
        for fan_in, fan_out in zip(widths, widths[1:]):
            self.hidden_layers.append(nn.Linear(fan_in, fan_out, dtype=torch.float32))

        self.head = nn.Linear(widths[-1], action_dim, dtype=torch.float32)

    def forward(self, state):  # logits: (batch_size, action_space)
        """Return the unnormalized logits for ``state``."""
        features = state

        for linear in self.hidden_layers:
            features = torch.relu(linear(features))

        return self.head(features)
115
+
116
class DistributionalActor(ActorProtocol, nn.Module):
    """Base class for actors that expose log probabilities and entropies.

    Subclasses set ``self.network`` and implement ``output`` and
    ``get_log_prob_and_entropy``.
    """

    def __init__(self, params=None):
        """Store the actor hyperparameters.

        Args:
            params: ``DistributionalActorParams`` to use. When omitted, a fresh
                instance is created for this actor, so changing one actor's
                params never affects another actor.
        """
        super().__init__()
        self.params = DistributionalActorParams() if params is None else params

    def get_log_prob_and_entropy(self, state, action):
        """Return ``(log_prob, entropy)``; must be provided by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_log_prob_and_entropy"
        )

    def forward(self, state):
        """Run ``state`` through ``self.network``."""
        return self.network(state)

    def get_loss(self, data, critic_value_func):
        """Compute the actor loss from stored and current log probabilities.

        Args:
            data: Rollout data providing ``log_probs``, ``states``, ``actions``
                and ``advantages``.
            critic_value_func: Unused here; part of the ``ActorProtocol`` signature.

        Returns:
            Negated sum of the policy objective minus the weighted entropy sum.
        """
        old_log_probs = data.log_probs.detach()  # stored at rollout time

        new_log_probs, entropies = self.get_log_prob_and_entropy(data.states, data.actions)
        ratio = torch.exp(new_log_probs - old_log_probs)

        policy_objective = self.params.policy_objective_method.compute_policy_objective(ratio, data.advantages)
        policy_loss = -(policy_objective.sum())
        entropy_bonus = self.params.entropy_coef * entropies.sum()

        return policy_loss - entropy_bonus

    def exploration(self, action) -> torch.Tensor:
        """Return ``action`` unchanged; no extra noise is added here."""
        return action
142
+
143
class DiscreteActor(DistributionalActor):
    """Actor for discrete actions backed by a categorical distribution."""

    def __init__(self, state_dim, action_dim, hidden_sizes=HIDDEN_SIZES_DEFAULT, params=None):
        """Create the logits network.

        Args:
            state_dim: Size of the input state vector.
            action_dim: Number of discrete actions.
            hidden_sizes: Widths of the hidden layers.
            params: ``DistributionalActorParams``; a fresh instance is created
                when omitted so actors never share one params object.
        """
        if params is None:
            params = DistributionalActorParams()

        super().__init__(params=params)

        self.network = DiscreteActorNetwork(state_dim, action_dim, hidden_sizes=hidden_sizes)

    def _distribution(self, state):
        """Build the categorical distribution over actions for ``state``."""
        return Categorical(logits=self.network(state))

    def output(self, state):  # action: (batch_size, 1), log_prob: (batch_size, 1)
        """Sample an action and return it with its log probability."""
        action_distribution = self._distribution(state)
        sampled = action_distribution.sample()
        # Score the raw sample before adding the trailing dimension, which avoids
        # the unsqueeze/squeeze round trip.
        log_prob = action_distribution.log_prob(sampled).unsqueeze(-1)

        return sampled.unsqueeze(-1), log_prob

    def get_log_prob_and_entropy(self, state, action):
        """Return log probability and entropy of ``action`` at ``state``.

        ``action`` is expected to carry a trailing dimension of size 1, as
        produced by ``output``.
        """
        action_distribution = self._distribution(state)
        log_prob = action_distribution.log_prob(action.squeeze(-1)).unsqueeze(-1)
        entropy = action_distribution.entropy().unsqueeze(-1)  # (batch_size, 1)
        return log_prob, entropy
163
+
164
class ContinuousActorNetwork(nn.Module):
    """Fully connected network mapping states to actions scaled into a range."""

    def __init__(self, state_dim, action_dim, action_range, hidden_sizes=HIDDEN_SIZES_DEFAULT):
        """Build the ReLU hidden stack followed by a linear head.

        Args:
            state_dim: Size of the input state vector.
            action_dim: Size of the action vector.
            action_range: Tensor whose minimum and maximum along dim 0 give the
                action bounds (presumably rows of [low, high] — confirm with caller).
            hidden_sizes: Widths of the hidden layers, in order.
        """
        super().__init__()

        # Consecutive pairs of widths give each hidden layer's (in, out) sizes.
        widths = [state_dim, *hidden_sizes]

        self.hidden_layers = nn.ModuleList()
        for fan_in, fan_out in zip(widths, widths[1:]):
            self.hidden_layers.append(nn.Linear(fan_in, fan_out, dtype=torch.float32))

        self.head = nn.Linear(widths[-1], action_dim, dtype=torch.float32)
        self.action_range = action_range

    def forward(self, state):  # actions: (batch_size, action_space)
        """Return actions for ``state`` rescaled from [-1, 1] into the action range."""
        features = state

        for linear in self.hidden_layers:
            features = torch.relu(linear(features))

        squashed = torch.tanh(self.head(features))  # Bounds to -1 to 1

        lows = self.action_range.min(dim=0).values
        highs = self.action_range.max(dim=0).values
        span = highs - lows

        # Stretch by half the span, then shift to the midpoint of the range.
        return (squashed * (span / 2)) + ((lows + highs) / 2)
194
+
195
class StochasticActor(DistributionalActor):
    """Continuous-action actor using a Normal distribution with a learned log-std."""

    def __init__(self, state_dim, action_dim, action_range, hidden_sizes=HIDDEN_SIZES_DEFAULT, params=None):
        """Create the mean network and the log-std parameter.

        Args:
            state_dim: Size of the input state vector.
            action_dim: Size of the action vector.
            action_range: Passed through to ``ContinuousActorNetwork``.
            hidden_sizes: Widths of the hidden layers.
            params: ``DistributionalActorParams``; a fresh instance is created
                when omitted so actors never share one params object.
        """
        if params is None:
            params = DistributionalActorParams()

        super().__init__(params=params)
        self.network = ContinuousActorNetwork(state_dim, action_dim, action_range, hidden_sizes=hidden_sizes)
        # One state-independent log-std per action dimension, initialised to 0.
        self.log_std = nn.Parameter(torch.zeros(action_dim))

    def _distribution(self, state):
        """Build the Normal action distribution for ``state``."""
        # mean, also referred to as 'mu'
        mean = self.network(state)
        std = torch.exp(self.log_std)  # Exp to make positive
        return Normal(mean, std)

    def output(self, state):  # action: (batch_size, action_dim), log_prob: (batch_size, 1)
        """Sample an action and return it with its summed log probability."""
        dist = self._distribution(state)
        action = dist.sample()  # Shape: (batch_size, action_dim)
        # Sum over action dimensions to get one log-prob per sample.
        log_prob = dist.log_prob(action).sum(dim=-1, keepdim=True)  # (batch_size, 1)

        return action, log_prob

    def get_log_prob_and_entropy(self, state, action):
        """Return log probability and entropy of ``action`` at ``state``."""
        dist = self._distribution(state)
        log_prob = dist.log_prob(action).sum(dim=-1, keepdim=True)  # (batch_size, 1)
        entropy = dist.entropy().sum(dim=-1, keepdim=True)  # (batch_size, 1)
        return log_prob, entropy
218
+
219
@dataclass
class DeterministicActorParams(ActorParams):
    """Hyperparameters read by ``DeterministicActor``."""

    # Standard deviation of the Gaussian noise added in DeterministicActor.exploration.
    exploration_std: float = 0.1
222
+
223
class DeterministicActor(ActorProtocol, nn.Module):
    """Actor that outputs actions directly and is trained through the critic."""

    def __init__(self, state_dim, action_dim, action_range, hidden_sizes=HIDDEN_SIZES_DEFAULT, params=None):
        """Create the action network.

        Args:
            state_dim: Size of the input state vector.
            action_dim: Size of the action vector.
            action_range: Passed through to ``ContinuousActorNetwork``; also used
                to clamp actions after exploration noise.
            hidden_sizes: Widths of the hidden layers.
            params: ``DeterministicActorParams``; a fresh instance is created
                when omitted so actors never share one params object.
        """
        super().__init__()
        self.network = ContinuousActorNetwork(state_dim, action_dim, action_range, hidden_sizes=hidden_sizes)
        self.params = DeterministicActorParams() if params is None else params

    @staticmethod
    def _zeros_per_sample(state):
        """Return a (batch_size, 1) float32 zero tensor on ``state``'s device."""
        return torch.zeros((state.shape[0], 1), dtype=torch.float32, device=state.device)

    def forward(self, state):  # action: (batch_size, action_dim)
        """Return the network's action for ``state``."""
        return self.network(state)

    def output(self, state):  # action: (batch_size, action_dim), log_prob: (batch_size, 1)
        """Return the action and a zero placeholder in place of log probabilities."""
        # return zeroes instead of logprobs
        return self.network(state), self._zeros_per_sample(state)

    def get_loss(self, data, critic_value_func):
        """Return the negated mean critic value of the actor's own actions.

        Args:
            data: Rollout data providing ``states``.
            critic_value_func: Maps ``(states, actions)`` to critic values.
        """
        actions_pi = self.network(data.states)
        return -(critic_value_func(data.states, actions_pi).mean())

    def exploration(self, action) -> torch.Tensor:
        """Add Gaussian noise to ``action`` and clamp it to the action range."""
        noisy = action + (torch.randn_like(action) * self.params.exploration_std)

        mins = self.network.action_range.min(dim=0).values
        maxs = self.network.action_range.max(dim=0).values

        return torch.clamp(noisy, mins, maxs)

    def get_log_prob_and_entropy(self, state, action):
        """Return zero placeholders for log probability and entropy."""
        # return zeroes instead of logprobs and entropies
        return self._zeros_per_sample(state), self._zeros_per_sample(state)
@@ -0,0 +1,117 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from abc import ABC, abstractmethod
4
+ from dataclasses import dataclass
5
+ from copy import deepcopy
6
+
7
+ from .actors import ActorProtocol
8
+ from .critics import CriticProtocol
9
+
10
class AgentProtocol(ABC):
    """Interface of an agent made of one actor and a list of critics."""

    actor: ActorProtocol
    critics: list[CriticProtocol]

    @abstractmethod
    def get_action(self, state):
        """Return the actor's action for ``state``.

        Args:
            state: The input state (torch.Tensor)
        """

    # NOTE(review): unlike the other methods, this one is not marked abstract, so
    # a subclass that omits it inherits an implementation returning None.
    def get_critic_value(self, state, action):
        """Return the critic value for ``state`` and ``action``.

        Args:
            state: The input state (torch.Tensor)
            action: The input action (torch.Tensor)
        """

    @abstractmethod
    def get_target_action(self, state):
        """Return the target actor's output for ``state``.

        Implementations without target networks may simply use the main actor.

        Args:
            state: The input state (torch.Tensor)
        """

    @abstractmethod
    def get_target_critic_value(self, state, action):
        """Return the target critic's output for ``state`` and ``action``.

        Implementations without target networks may simply use the main critics.

        Args:
            state: The input state (torch.Tensor)
            action: The input action (torch.Tensor)
        """

    @abstractmethod
    def update_targets(self):
        """Update the target networks, if the agent has any."""
57
+
58
@dataclass
class AgentParams:
    """Base class for agent hyperparameter containers; declares no fields."""
61
+
62
class Agent(AgentProtocol):
    """Agent without target networks: targets fall back to the main actor and critics."""

    def __init__(self, actor: ActorProtocol, critics: list[CriticProtocol], params: "AgentParams | None" = None):
        """Store the actor, critics and hyperparameters.

        Args:
            actor: The actor used to choose actions.
            critics: Critics whose element-wise minimum is used as the value.
            params: Agent hyperparameters. When omitted, a fresh ``AgentParams``
                is created for this agent rather than sharing one default
                instance between all agents.
        """
        super().__init__()

        self.actor = actor
        self.critics = critics
        self.params = AgentParams() if params is None else params

    def get_action(self, state):
        """Return the actor's action for ``state``, discarding the log probability."""
        action, _ = self.actor.output(state)

        return action

    def get_critic_value(self, state, action):
        """Return the element-wise minimum value across all critics."""
        values = [critic.output(state, action) for critic in self.critics]
        return torch.min(torch.stack(values), dim=0).values

    def get_target_action(self, state):
        """Return ``get_action(state)``; this agent has no target actor."""
        return self.get_action(state)

    def get_target_critic_value(self, state, action):
        """Return ``get_critic_value(state, action)``; this agent has no target critics."""
        return self.get_critic_value(state, action)

    def update_targets(self):
        """Do nothing: there are no target networks to update."""
        # No target networks to update
        pass
90
+
91
@dataclass
class AgentWithTargetsParams(AgentParams):
    """Hyperparameters read by ``AgentWithTargets``."""

    # Interpolation weight of the online parameters in the soft target update.
    tau: float = 0.1
94
+
95
class AgentWithTargets(Agent):
    """Agent that keeps soft-updated copies of its actor and critic networks."""

    def __init__(self, actor: ActorProtocol, critics: list[CriticProtocol], params: AgentWithTargetsParams):
        """Store the components and deep-copy their networks as targets.

        Args:
            actor: Actor exposing a ``network`` attribute.
            critics: Critics, each exposing a ``network`` attribute.
            params: Hyperparameters providing ``tau``.
        """
        super().__init__(actor, critics, params)

        self.target_actor_network = deepcopy(self.actor.network)
        self.target_critic_networks = []
        for critic in self.critics:
            self.target_critic_networks.append(deepcopy(critic.network))

    def get_target_action(self, state):
        """Return the target actor network's output for ``state``."""
        return self.target_actor_network(state)

    def get_target_critic_value(self, state, action):
        """Return the element-wise minimum value across all target critics."""
        stacked = torch.stack([network(state, action) for network in self.target_critic_networks])
        return torch.min(stacked, dim=0).values

    def update_targets(self):
        """Move every target parameter a fraction ``tau`` towards its online counterpart."""
        tau = self.params.tau

        def blend(target_network, online_network):
            # Soft update: target <- tau * online + (1 - tau) * target
            for target_param, param in zip(target_network.parameters(), online_network.parameters()):
                target_param.data.copy_((tau * param.data) + ((1 - tau) * target_param.data))

        blend(self.target_actor_network, self.actor.network)

        for index, critic in enumerate(self.critics):
            blend(self.target_critic_networks[index], critic.network)
@@ -0,0 +1,20 @@
1
+ import torch
2
+
3
+ from .foundation import *
4
+ from .critics import CriticLossMethod
5
+
6
class CriticLossMethodClipped(CriticLossMethod):
    """Critic loss that limits how far new values may move from the rollout values."""

    def __init__(self, clip_range_vf: float):
        """Store the maximum allowed change of the value estimate.

        Args:
            clip_range_vf: Symmetric bound applied to ``new_values - old_values``.
        """
        self.clip_range_vf = clip_range_vf

    def compute_critic_loss(self, reference_actor, reference_critic, new_values, data, gamma):
        """Return the mean of the larger of the clipped and unclipped squared errors.

        Args:
            reference_actor: Unused by this loss.
            reference_critic: Unused by this loss.
            new_values: Values predicted by the current critic.
            data: Rollout data providing ``critic_values`` and ``returns``.
            gamma: Unused by this loss.
        """
        targets = data.returns
        old_values = data.critic_values.detach()

        # Keep the clipped prediction within clip_range_vf of the rollout-time value.
        bound = self.clip_range_vf
        clipped_delta = torch.clamp(new_values - old_values, -bound, bound)
        values_clipped = old_values + clipped_delta

        squared_error = (new_values - targets).pow(2)
        squared_error_clipped = (values_clipped - targets).pow(2)

        return torch.maximum(squared_error, squared_error_clipped).mean()
@@ -0,0 +1,18 @@
1
+ import torch
2
+
3
+ from .foundation import *
4
+ from .critics import CriticLossMethod
5
+
6
class CriticLossMethodQ(CriticLossMethod):
    """Critic loss regressing values onto a one-step bootstrapped target."""

    def compute_critic_loss(self, target_action_func, target_critic_value_func, new_values, data, gamma):
        """Return the MSE between ``new_values`` and the bootstrapped target.

        Args:
            target_action_func: Maps next states to actions (the target actor).
            target_critic_value_func: Maps ``(next_states, actions)`` to values
                (the target critic).
            new_values: Values from the current critic.
            data: Rollout data providing ``next_states``, ``rewards`` and ``terminateds``.
            gamma: Discount factor.
        """
        # The target is built without gradients, so only new_values is trained.
        with torch.no_grad():
            next_actions = target_action_func(data.next_states)
            next_values = target_critic_value_func(data.next_states, next_actions)
            # Terminated transitions get a zero discount and so do not bootstrap.
            discount = gamma * (1.0 - data.terminateds)
            y = data.rewards + (discount * next_values)

        mse = torch.nn.MSELoss()
        return mse(new_values, y)