drlab 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
drlab/__init__.py ADDED
@@ -0,0 +1,36 @@
1
+ from drlab.controllers import (
2
+ Controller,
3
+ EpsilonGreedyController,
4
+ GreedyController,
5
+ StochasticController,
6
+ )
7
+ from drlab.experiments import (
8
+ ActorCriticExperiment,
9
+ ActorCriticExperimentConfig,
10
+ DQNExperiment,
11
+ DQNExperimentConfig,
12
+ )
13
+ from drlab.learners import ActorCritic, ActorCriticConfig, DQN, DQNConfig
14
+ from drlab.replay import ReplayBuffer, TransitionBatch
15
+ from drlab.runners import Runner
16
+
17
+ __version__ = "0.1.0"
18
+
19
+ __all__ = [
20
+ "__version__",
21
+ "ActorCritic",
22
+ "ActorCriticConfig",
23
+ "ActorCriticExperiment",
24
+ "ActorCriticExperimentConfig",
25
+ "Controller",
26
+ "DQN",
27
+ "DQNConfig",
28
+ "DQNExperiment",
29
+ "DQNExperimentConfig",
30
+ "EpsilonGreedyController",
31
+ "GreedyController",
32
+ "ReplayBuffer",
33
+ "Runner",
34
+ "StochasticController",
35
+ "TransitionBatch",
36
+ ]
@@ -0,0 +1,6 @@
1
+ from .base import Controller
2
+ from .greedy import GreedyController
3
+ from .e_greedy import EpsilonGreedyController
4
+ from .stochastic_controller import StochasticController
5
+
6
+ __all__ = ["Controller", "GreedyController", "EpsilonGreedyController", "StochasticController"]
@@ -0,0 +1,17 @@
1
+ from __future__ import annotations
2
+
3
+ import torch as th
4
+ from abc import ABC, abstractmethod
5
+
6
+ class Controller(ABC):
7
+ """Abstract controller interface."""
8
+
9
+ num_actions: int
10
+ model: th.nn.Module = None
11
+ controller: Controller = None
12
+
13
+ @abstractmethod
14
+ def choose(self, obs: th.Tensor, **kwargs) -> th.Tensor: ...
15
+
16
+ @abstractmethod
17
+ def probabilities(self, obs: th.Tensor, **kwargs) -> th.Tensor: ...
@@ -0,0 +1,46 @@
1
+ import numpy as np
2
+ import torch as th
3
+ from .base import Controller
4
+
5
+ class EpsilonGreedyController(Controller):
6
+
7
+ def __init__(
8
+ self,
9
+ controller: Controller,
10
+ num_actions: int,
11
+ max_eps: float = 1.0,
12
+ min_eps: float = 0.1,
13
+ anneal_steps: int = 10_000,
14
+ ):
15
+ self.controller = controller
16
+ self.num_actions = num_actions
17
+ self.model = controller.model
18
+ self.max_eps = max_eps
19
+ self.min_eps = min_eps
20
+ self.anneal_steps = anneal_steps
21
+ self.num_decisions = 0
22
+
23
+ if anneal_steps <= 1:
24
+ raise ValueError("anneal_steps must be >= 2")
25
+
26
+ def epsilon(self) -> float:
27
+ frac = max(1 - self.num_decisions / (self.anneal_steps - 1), 0.0)
28
+ return frac * (self.max_eps - self.min_eps) + self.min_eps
29
+
30
+ def choose(self, obs: th.Tensor, increase_counter: bool = True, **kwargs) -> th.Tensor:
31
+ eps = self.epsilon()
32
+ if increase_counter:
33
+ self.num_decisions += 1
34
+
35
+ B = obs.shape[0] if obs.ndim > 1 else 1
36
+ if np.random.rand() < eps:
37
+ return th.randint(self.num_actions, (B,), device=obs.device, dtype=th.long)
38
+
39
+ return self.controller.choose(obs, **kwargs)
40
+
41
+ def probabilities(self, obs: th.Tensor, **kwargs) -> th.Tensor:
42
+ eps = self.epsilon()
43
+ greedy = self.controller.probabilities(obs, **kwargs) # one-hot on argmax, shape [B,A]
44
+ B = greedy.shape[0]
45
+ uniform = th.full((B, self.num_actions), eps / self.num_actions, device=greedy.device)
46
+ return uniform + (1 - eps) * greedy
@@ -0,0 +1,20 @@
1
+ import torch as th
2
+ from .base import Controller
3
+
4
+ class GreedyController(Controller):
5
+
6
+ def __init__(self, model: th.nn.Module, num_actions: int):
7
+
8
+ self.num_actions = num_actions
9
+ self.model = model
10
+
11
+ def choose(self, obs: th.Tensor):
12
+ output: th.Tensor = self.model(obs)[:, :self.num_actions]
13
+ return th.argmax(output, dim=-1)
14
+
15
+ def probabilities(self, obs: th.Tensor):
16
+ output: th.Tensor = self.model(obs)[:, :self.num_actions]
17
+ idx = output.argmax(dim=-1, keepdim=True) # [B,1]
18
+ probs = th.zeros(output.shape[0], self.num_actions, device=output.device)
19
+ probs.scatter_(dim=-1, index=idx, value=1.0)
20
+ return probs
@@ -0,0 +1,17 @@
1
+ import torch as th
2
+ from .base import Controller
3
+
4
+ class StochasticController(Controller):
5
+
6
+ def __init__(self, model: th.nn.Module, num_actions: int):
7
+
8
+ self.num_actions = num_actions
9
+ self.model = model
10
+
11
+ def choose(self, obs: th.Tensor) -> th.Tensor:
12
+ probs = self.probabilities(obs)
13
+ return th.distributions.Categorical(probs=probs).sample()
14
+
15
+ def probabilities(self, obs: th.Tensor) -> th.Tensor:
16
+ output = self.model(obs)[:, :self.num_actions]
17
+ return th.nn.functional.softmax(output, dim=-1)
@@ -0,0 +1,9 @@
1
+ from .dqn_experiment import DQNExperiment, DQNExperimentConfig
2
+ from .ac_experiment import ActorCriticExperiment, ActorCriticExperimentConfig
3
+
4
+ __all__ = [
5
+ "DQNExperimentConfig",
6
+ "DQNExperiment",
7
+ "ActorCriticExperimentConfig",
8
+ "ActorCriticExperiment"
9
+ ]
@@ -0,0 +1,102 @@
1
+ import numpy as np
2
+ from tqdm import tqdm
3
+ import gymnasium as gym
4
+ from dataclasses import dataclass
5
+ from typing import Callable
6
+ from torch.utils.tensorboard import SummaryWriter
7
+
8
+ from drlab.learners import ActorCritic
9
+ from drlab.runners import Runner
10
+ from drlab.controllers import Controller
11
+
12
+
13
+ @dataclass
14
+ class ActorCriticExperimentConfig:
15
+ max_steps: int
16
+ gamma: float = 0.99
17
+ run_steps: int = 0
18
+ log_dir: str = "runs/reinforce_experiment"
19
+ experiment_name: str = "ActorCriticExperiment"
20
+ step_callback: Callable[[int], None] | None = None
21
+ step_callback_interval: int | None = None
22
+
23
+
24
+ class ActorCriticExperiment:
25
+
26
+ def __init__(
27
+ self,
28
+ env: gym.Env,
29
+ controller: Controller,
30
+ learner: ActorCritic,
31
+ config: ActorCriticExperimentConfig,
32
+ ):
33
+ # Init experiment settings
34
+ self.max_steps = config.max_steps
35
+ self.run_steps = config.run_steps
36
+ self.step_callback = config.step_callback
37
+ self.step_callback_interval = config.step_callback_interval
38
+
39
+ # Only ask the runner to compute discounted returns when the learner uses them.
40
+ calculate_returns = (
41
+ not learner.config.advantage_bootstrap
42
+ or (learner.config.use_bias and learner.config.value_targets == "returns")
43
+ )
44
+
45
+ # Init drl components
46
+ self.runner = Runner(env, controller, calculate_returns, False, config.gamma, learner.device)
47
+ self.learner = learner
48
+
49
+ # Init logging
50
+ self.writer = SummaryWriter(log_dir=config.log_dir)
51
+ self.experiment_name = config.experiment_name
52
+
53
+ if self.step_callback is not None and self.step_callback_interval is None:
54
+ raise ValueError("step_callback_interval must be set when step_callback is provided.")
55
+
56
+ def run(self):
57
+
58
+ steps = 0
59
+ if self.step_callback is not None:
60
+ self.step_callback(steps)
61
+ next_callback_step = self.step_callback_interval
62
+ pbar = tqdm(total=self.max_steps, desc=self.experiment_name, dynamic_ncols=True, mininterval=1.0)
63
+ while steps < self.max_steps:
64
+
65
+ # 1. Run batch of transitions
66
+ batch, ep_returns, ep_lengths, _ = self.runner.run(self.run_steps)
67
+ batch = batch.to(self.learner.device)
68
+
69
+ # 2. Learn from batch
70
+ loss = self.learner.train(
71
+ batch.rewards,
72
+ batch.dones,
73
+ batch.states,
74
+ batch.actions,
75
+ batch.next_states,
76
+ batch.returns,
77
+ )
78
+
79
+ # 3. Log results
80
+ self.writer.add_scalar("Loss", loss, steps)
81
+ if ep_returns:
82
+ mean_return = np.mean(ep_returns)
83
+ mean_length = np.mean(ep_lengths)
84
+ self.writer.add_scalar("Episode return", mean_return, steps)
85
+ self.writer.add_scalar("Episode Length", mean_length, steps)
86
+
87
+ pbar.set_postfix({
88
+ "loss": f"{loss:.3f}",
89
+ "return": f"{mean_return:.2f}",
90
+ "len": f"{mean_length:.1f}",
91
+ }, refresh=False)
92
+
93
+ step_inc = batch.states.shape[0]
94
+ steps += step_inc
95
+ pbar.update(step_inc)
96
+
97
+ while self.step_callback is not None and steps >= next_callback_step:
98
+ self.step_callback(next_callback_step)
99
+ next_callback_step += self.step_callback_interval
100
+
101
+ pbar.close()
102
+ self.writer.close()
@@ -0,0 +1,151 @@
1
+ import numpy as np
2
+ from tqdm import tqdm
3
+ import gymnasium as gym
4
+ from dataclasses import dataclass
5
+ from typing import Callable
6
+ from torch.utils.tensorboard import SummaryWriter
7
+
8
+ from drlab.learners import DQN
9
+ from drlab.runners import Runner
10
+ from drlab.controllers import Controller
11
+ from drlab.replay import TransitionBatch, ReplayBuffer
12
+
13
+ @dataclass
14
+ class DQNExperimentConfig:
15
+ max_steps: int
16
+ gamma: float = 0.99
17
+ run_steps: int = 0
18
+ log_dir: str = "runs/dqn_experiment"
19
+ experiment_name: str = "DQNExperiment"
20
+ use_replay: bool = True
21
+ replay_buffer_size: int = 10_000
22
+ batch_size: int = 128
23
+ use_last_episode: bool = True
24
+ grad_repeats: int = 1
25
+ step_callback: Callable[[int], None] | None = None
26
+ step_callback_interval: int | None = None
27
+
28
+
29
+ class DQNExperiment:
30
+
31
+ def __init__(
32
+ self,
33
+ env: gym.Env,
34
+ controller: Controller,
35
+ learner: DQN,
36
+ config: DQNExperimentConfig,
37
+ ):
38
+
39
+ # Init experiment settings
40
+ self.max_steps = config.max_steps
41
+ self.grad_repeats = config.grad_repeats
42
+ self.batch_size = config.batch_size
43
+ self.use_last_episode = config.use_last_episode
44
+ self.run_steps = config.run_steps
45
+ self.step_callback = config.step_callback
46
+ self.step_callback_interval = config.step_callback_interval
47
+
48
+ # Init drl components
49
+ self.runner = Runner(env, controller, False, True, config.gamma, learner.device)
50
+ self.learner = learner
51
+
52
+ # Init replay buffer
53
+ self.replay_buffer_size = config.replay_buffer_size
54
+ self.use_replay = config.use_replay
55
+ self.replay_buffer = ReplayBuffer(
56
+ capacity=config.replay_buffer_size if config.use_replay else config.batch_size,
57
+ obs_shape=env.observation_space.shape,
58
+ device=learner.device
59
+ )
60
+
61
+ # Init logging
62
+ self.writer = SummaryWriter(log_dir=config.log_dir)
63
+ self.experiment_name = config.experiment_name
64
+
65
+ if self.step_callback is not None and self.step_callback_interval is None:
66
+ raise ValueError("step_callback_interval must be set when step_callback is provided.")
67
+
68
+ def _make_minibatch(self, batch: TransitionBatch, last_episode: TransitionBatch | None) -> TransitionBatch:
69
+
70
+ # No replay buffer
71
+ if not self.use_replay:
72
+ return self.replay_buffer.get_all()
73
+
74
+ # Replay buffer without last episode
75
+ if not self.use_last_episode:
76
+ return self.replay_buffer.sample(self.batch_size)
77
+
78
+ # Replay buffer with last episode
79
+ episode = last_episode if last_episode else batch
80
+ episode = episode.to(self.learner.device)
81
+
82
+ ep_len = episode.states.shape[0]
83
+ if ep_len >= self.batch_size:
84
+ return episode
85
+
86
+ rest = self.replay_buffer.sample(self.batch_size - ep_len)
87
+ return rest.cat(episode)
88
+
89
+ def _learn_from_batch(self, batch: TransitionBatch, last_episode: TransitionBatch | None) -> float:
90
+ # 1) Add batch to replay buffer
91
+ self.replay_buffer.add(
92
+ states=batch.states.cpu().numpy(),
93
+ actions=batch.actions.cpu().numpy(),
94
+ rewards=batch.rewards.cpu().numpy(),
95
+ dones=batch.dones.cpu().numpy(),
96
+ next_states=batch.next_states.cpu().numpy(),
97
+ returns=batch.returns.cpu().numpy(),
98
+ )
99
+ if self.replay_buffer.size < self.batch_size:
100
+ return 0.0
101
+
102
+ # 2) train repeats
103
+ total_loss = 0.0
104
+ for _ in range(self.grad_repeats):
105
+ mb = self._make_minibatch(batch, last_episode)
106
+ total_loss += self.learner.train(
107
+ mb.rewards, mb.dones, mb.states, mb.actions, mb.next_states
108
+ )
109
+
110
+ return total_loss / self.grad_repeats
111
+
112
+ def run(self):
113
+
114
+ steps = 0
115
+ if self.step_callback is not None:
116
+ self.step_callback(steps)
117
+ next_callback_step = self.step_callback_interval
118
+ pbar = tqdm(total=self.max_steps, desc=self.experiment_name, dynamic_ncols=True, mininterval=1.0)
119
+ while steps < self.max_steps:
120
+
121
+ # 1. Run batch of transitions
122
+ batch, ep_returns, ep_lengths, last_episode = self.runner.run(self.run_steps)
123
+
124
+ # 2. Learn from batch
125
+ loss = self._learn_from_batch(batch, last_episode)
126
+
127
+ # 3. Log results
128
+ self.writer.add_scalar("Loss", loss, steps)
129
+ if ep_returns:
130
+ mean_return = np.mean(ep_returns)
131
+ mean_length = np.mean(ep_lengths)
132
+ self.writer.add_scalar("Episode return", mean_return, steps)
133
+ self.writer.add_scalar("Episode Length", mean_length, steps)
134
+
135
+ pbar.set_postfix({
136
+ "loss": f"{loss:.3f}",
137
+ "return": f"{mean_return:.2f}",
138
+ "len": f"{mean_length:.1f}",
139
+ }, refresh=False)
140
+
141
+ # 3. Update steps & progress bar
142
+ step_inc = batch.states.shape[0]
143
+ steps += step_inc
144
+ pbar.update(step_inc)
145
+
146
+ while self.step_callback is not None and steps >= next_callback_step:
147
+ self.step_callback(next_callback_step)
148
+ next_callback_step += self.step_callback_interval
149
+
150
+ pbar.close()
151
+ self.writer.close()
@@ -0,0 +1,4 @@
1
+ from .dqn import DQN, DQNConfig
2
+ from .actor_critic import ActorCritic, ActorCriticConfig
3
+
4
+ __all__ = ["DQNConfig", "DQN", "ActorCriticConfig", "ActorCritic"]
@@ -0,0 +1,206 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Callable
3
+
4
+ import torch as th
5
+
6
+ @dataclass
7
+ class ActorCriticConfig:
8
+ device: th.device | str = "cpu"
9
+ regularizers: list[Callable[..., th.Tensor | float]] = field(default_factory=list)
10
+ reg_lams: list[float] = field(default_factory=list)
11
+ num_actions: int = 2
12
+ clip_grad: bool = True
13
+ grad_norm_clip: float = 1.0
14
+ use_bias: bool = True
15
+ value_criterion: Callable = field(default_factory=th.nn.MSELoss)
16
+ value_lambda: float = 0.1
17
+ value_targets: str = "td"
18
+ gamma: float = 0.99
19
+ advantage_bootstrap: bool = True
20
+ off_policy_iterations: int = 0
21
+ ppo_clipping: float = 0.1
22
+ use_entropy: bool = False
23
+ entropy_max_lambda: float = 0.0
24
+ entropy_min_lambda: float = 0.0
25
+ entropy_anneal_steps: int = 1
26
+ normalize_advantages: bool = False
27
+
28
+
29
+ class ActorCritic:
30
+
31
+ def __init__(
32
+ self,
33
+ actor: th.nn.Module,
34
+ optimizer: th.optim.Optimizer,
35
+ config: ActorCriticConfig,
36
+ ):
37
+
38
+ device = th.device(config.device)
39
+ self.actor = actor
40
+ self.optimizer = optimizer
41
+ self.all_parameters = list(actor.parameters())
42
+ self.actor.to(device)
43
+ self.device = device
44
+ self.config = config
45
+
46
+ if len(config.regularizers) != len(config.reg_lams):
47
+ raise ValueError("Length of regularizers and reg_lams lists must match.")
48
+
49
+ if not config.use_bias and config.advantage_bootstrap:
50
+ raise ValueError("advantage_bootstrap=True requires use_bias=True.")
51
+
52
+ if config.value_targets not in {"returns", "td"}:
53
+ raise ValueError("value_targets must be either 'returns' or 'td'.")
54
+
55
+ if config.use_entropy and config.entropy_anneal_steps <= 1:
56
+ raise ValueError("entropy_anneal_steps must be > 1 when use_entropy=True.")
57
+ if config.use_entropy and config.entropy_max_lambda < config.entropy_min_lambda:
58
+ raise ValueError("entropy_max_lambda must be >= entropy_min_lambda.")
59
+
60
+ self.entropy_step = 0
61
+
62
+ def _value_loss(
63
+ self,
64
+ returns: th.Tensor,
65
+ rewards: th.Tensor,
66
+ dones: th.Tensor,
67
+ values: th.Tensor,
68
+ next_values: th.Tensor,
69
+ ) -> th.Tensor:
70
+ if self.config.value_targets == "returns":
71
+ targets = returns
72
+ elif self.config.value_targets == "td":
73
+ targets = rewards + self.config.gamma * (~dones * next_values)
74
+ else:
75
+ raise ValueError(f"Unknown value_targets: {self.config.value_targets}")
76
+
77
+ return self.config.value_criterion(values, targets)
78
+
79
+ def _get_policy(self, logits: th.Tensor, actions: th.Tensor) -> th.Tensor:
80
+ probs = th.nn.functional.softmax(logits, dim=-1)
81
+ pi = probs.gather(dim=-1, index=actions)
82
+ return pi
83
+
84
+ def _entropy_lambda(self) -> float:
85
+ progress = min(self.entropy_step / (self.config.entropy_anneal_steps - 1), 1.0)
86
+
87
+ return (
88
+ self.config.entropy_max_lambda
89
+ + progress * (self.config.entropy_min_lambda - self.config.entropy_max_lambda)
90
+ )
91
+
92
+ def _entropy_loss(self, logits: th.Tensor) -> th.Tensor:
93
+ entropy_lambda = self._entropy_lambda()
94
+ probs = th.nn.functional.softmax(logits, dim=-1)
95
+ log_probs = th.nn.functional.log_softmax(logits, dim=-1)
96
+ entropy = -(probs * log_probs).sum(dim=-1).mean()
97
+ return -entropy_lambda * entropy
98
+
99
+ def _advantages(
100
+ self,
101
+ returns: th.Tensor,
102
+ rewards: th.Tensor,
103
+ dones: th.Tensor,
104
+ values: th.Tensor,
105
+ next_values: th.Tensor
106
+ ) -> th.Tensor:
107
+ advantages = None
108
+ if self.config.advantage_bootstrap:
109
+ advantages = rewards + self.config.gamma * (~dones * next_values)
110
+ else:
111
+ advantages = returns
112
+ if self.config.use_bias:
113
+ advantages = advantages - values.detach()
114
+
115
+ if self.config.normalize_advantages and advantages.numel() > 1:
116
+ advantages = (advantages - advantages.mean()) / (
117
+ advantages.std(unbiased=False) + 1e-8
118
+ )
119
+
120
+ return advantages
121
+
122
+ def _policy_loss(self, pi: th.Tensor, advantage: th.Tensor) -> th.Tensor:
123
+
124
+ if self.old_pi is None:
125
+ self.old_pi = pi.detach()
126
+ return -th.mean(pi.clamp_min(1e-8).log() * advantage.detach())
127
+ else:
128
+ ratios = pi / self.old_pi.detach()
129
+ loss = advantage.detach() * ratios
130
+ ppo_loss = th.clamp(ratios, 1-self.config.ppo_clipping, 1+self.config.ppo_clipping) * advantage.detach()
131
+ loss = th.min(loss, ppo_loss)
132
+ return -th.mean(loss)
133
+
134
+ def train(
135
+ self,
136
+ rewards: th.Tensor, # float32, [B,1]
137
+ dones: th.Tensor, # bool or float(0/1), [B,1]
138
+ states: th.Tensor, # float32, [B, obs_dim] or [B,C,H,W]
139
+ actions: th.Tensor, # int64, [B,1]
140
+ next_states: th.Tensor, # float32, same as states
141
+ returns: th.Tensor, # float32, [B,1]
142
+ ) -> float:
143
+
144
+ self.actor.train(True)
145
+ need_next_values = (
146
+ self.config.advantage_bootstrap
147
+ or self.config.value_targets == "td"
148
+ )
149
+
150
+ self.old_pi, loss_sum = None, 0
151
+
152
+ for _ in range(1 + self.config.off_policy_iterations):
153
+ # 1. Compute model policy, values and next_values
154
+ output: th.Tensor = self.actor(states)
155
+ logits = output[:, :self.config.num_actions]
156
+ values = output[:, self.config.num_actions:self.config.num_actions + 1]
157
+
158
+ next_values = None
159
+ if need_next_values:
160
+ with th.no_grad():
161
+ next_output: th.Tensor = self.actor(next_states)
162
+ next_values = next_output[:, self.config.num_actions:self.config.num_actions + 1]
163
+ pi = self._get_policy(logits, actions)
164
+
165
+ # 2. Compute losses
166
+ policy_loss = self._policy_loss(
167
+ pi,
168
+ self._advantages(
169
+ returns,
170
+ rewards,
171
+ dones,
172
+ values,
173
+ next_values
174
+ )
175
+ )
176
+ entropy_loss = 0.0
177
+ if self.config.use_entropy:
178
+ entropy_loss = self._entropy_loss(logits)
179
+ value_loss = 0.0
180
+ if self.config.use_bias:
181
+ value_loss = self.config.value_lambda * self._value_loss(
182
+ returns,
183
+ rewards,
184
+ dones,
185
+ values,
186
+ next_values
187
+ )
188
+ reg_loss = (
189
+ sum(
190
+ lam * reg(self.actor, rewards, dones, states, actions, next_states)
191
+ for reg, lam in zip(self.config.regularizers, self.config.reg_lams)
192
+ )
193
+ if self.config.regularizers else 0.0
194
+ )
195
+ loss = policy_loss + value_loss + entropy_loss + reg_loss
196
+
197
+ # 3. Optimize
198
+ self.optimizer.zero_grad(set_to_none=True)
199
+ loss.backward()
200
+ if self.config.clip_grad:
201
+ th.nn.utils.clip_grad_norm_(self.all_parameters, self.config.grad_norm_clip)
202
+ self.optimizer.step()
203
+ loss_sum += loss.item()
204
+
205
+ self.entropy_step += 1
206
+ return loss_sum