MASA-Safe-RL 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- masa/__init__.py +0 -0
- masa/algorithms/__init__.py +0 -0
- masa/algorithms/a2c/__init__.py +1 -0
- masa/algorithms/a2c/a2c.py +163 -0
- masa/algorithms/ppo/__init__.py +1 -0
- masa/algorithms/ppo/ppo.py +197 -0
- masa/algorithms/tabular/__init__.py +6 -0
- masa/algorithms/tabular/base.py +37 -0
- masa/algorithms/tabular/lcrl.py +81 -0
- masa/algorithms/tabular/q_learning.py +207 -0
- masa/algorithms/tabular/q_learning_lambda.py +84 -0
- masa/algorithms/tabular/recovery_rl.py +227 -0
- masa/algorithms/tabular/recreg.py +559 -0
- masa/algorithms/tabular/sem.py +150 -0
- masa/cli/__init__.py +0 -0
- masa/cli/cli_app.py +262 -0
- masa/common/__init__.py +0 -0
- masa/common/base_class.py +241 -0
- masa/common/buffers.py +139 -0
- masa/common/configs.py +516 -0
- masa/common/constraints/__init__.py +0 -0
- masa/common/constraints/base.py +275 -0
- masa/common/constraints/cmdp.py +133 -0
- masa/common/constraints/ltl_safety.py +648 -0
- masa/common/constraints/multi_agent/__init__.py +3 -0
- masa/common/constraints/multi_agent/cmg.py +257 -0
- masa/common/constraints/pctl.py +112 -0
- masa/common/constraints/prob.py +131 -0
- masa/common/constraints/reach_avoid.py +132 -0
- masa/common/dummy.py +6 -0
- masa/common/label_fn.py +16 -0
- masa/common/labelled_env.py +118 -0
- masa/common/labelled_pz_env.py +67 -0
- masa/common/layers.py +42 -0
- masa/common/ltl.py +572 -0
- masa/common/metrics.py +934 -0
- masa/common/on_policy_algorithm.py +314 -0
- masa/common/pctl.py +1774 -0
- masa/common/pettingzoo_record_video.py +353 -0
- masa/common/policies.py +286 -0
- masa/common/registry.py +38 -0
- masa/common/running_mean_std.py +40 -0
- masa/common/schedule.py +28 -0
- masa/common/utils.py +227 -0
- masa/common/wrappers.py +1480 -0
- masa/configs/__init__.py +0 -0
- masa/configs/algorithms/__init__.py +0 -0
- masa/configs/envs/__init__.py +0 -0
- masa/envs/__init__.py +0 -0
- masa/envs/continuous/__init__.py +0 -0
- masa/envs/continuous/base.py +8 -0
- masa/envs/continuous/cartpole.py +145 -0
- masa/envs/discrete/__init__.py +0 -0
- masa/envs/discrete/base.py +8 -0
- masa/envs/discrete/cartpole.py +146 -0
- masa/envs/discrete/conveyor_belt.py +333 -0
- masa/envs/discrete/island_navigation.py +291 -0
- masa/envs/discrete/mini_pacman_with_coins.py +208 -0
- masa/envs/discrete/pacman_with_coins.py +217 -0
- masa/envs/discrete/renderers/__init__.py +2 -0
- masa/envs/discrete/renderers/cartpole.py +295 -0
- masa/envs/discrete/renderers/pacman.py +92 -0
- masa/envs/discrete/sokoban.py +320 -0
- masa/envs/multiagent/matrix/_label_utils.py +22 -0
- masa/envs/multiagent/matrix/bertrand.py +402 -0
- masa/envs/multiagent/matrix/chicken.py +403 -0
- masa/envs/multiagent/matrix/congestion.py +496 -0
- masa/envs/multiagent/matrix/dpgg.py +448 -0
- masa/envs/multiagent/matrix/inspection.py +401 -0
- masa/envs/tabular/__init__.py +0 -0
- masa/envs/tabular/base.py +36 -0
- masa/envs/tabular/bridge_crossing.py +107 -0
- masa/envs/tabular/bridge_crossing_v2.py +107 -0
- masa/envs/tabular/colour_bomb_grid_world.py +121 -0
- masa/envs/tabular/colour_bomb_grid_world_v2.py +142 -0
- masa/envs/tabular/colour_bomb_grid_world_v3.py +168 -0
- masa/envs/tabular/colour_grid_world.py +106 -0
- masa/envs/tabular/media_streaming.py +115 -0
- masa/envs/tabular/mini_pacman.py +161 -0
- masa/envs/tabular/pacman.py +170 -0
- masa/envs/tabular/renderers/__init__.py +2 -0
- masa/envs/tabular/renderers/bridge_crossing.py +277 -0
- masa/envs/tabular/renderers/colour_bomb_grid_world.py +407 -0
- masa/envs/tabular/renderers/colour_grid_world.py +300 -0
- masa/envs/tabular/renderers/media_streaming.py +282 -0
- masa/envs/tabular/renderers/pacman.py +489 -0
- masa/envs/tabular/utils.py +414 -0
- masa/examples/__init__.py +0 -0
- masa/examples/colour_bomb_grid_world/__init__.py +0 -0
- masa/examples/colour_bomb_grid_world/property_1.py +23 -0
- masa/examples/colour_bomb_grid_world/property_2.py +30 -0
- masa/examples/colour_bomb_grid_world/property_3.py +89 -0
- masa/examples/norm_obs_example.py +116 -0
- masa/examples/prob_shield_cont_example.py +100 -0
- masa/examples/prob_shield_cont_ltl_example.py +144 -0
- masa/examples/prob_shield_example.py +107 -0
- masa/examples/prob_shield_ltl_example.py +105 -0
- masa/examples/prob_shield_safety_abstraction_example.py +109 -0
- masa/examples/prob_shield_vec_ltl_example.py +142 -0
- masa/examples/reward_shaping_example.py +103 -0
- masa/plugins/__init__.py +0 -0
- masa/plugins/helpers.py +12 -0
- masa/plugins/supported.py +52 -0
- masa/prob_shield/__init__.py +0 -0
- masa/prob_shield/eventual_discounted_vi.py +105 -0
- masa/prob_shield/helpers.py +177 -0
- masa/prob_shield/interval_bound_vi.py +127 -0
- masa/prob_shield/parameterized_policy.py +418 -0
- masa/prob_shield/parameterized_policy_v2.py +424 -0
- masa/prob_shield/parameterized_ppo.py +381 -0
- masa/prob_shield/parameterized_ppo_v2.py +267 -0
- masa/prob_shield/prob_shield_wrapper_v1.py +487 -0
- masa/prob_shield/prob_shield_wrapper_v2.py +443 -0
- masa/run.py +122 -0
- masa_safe_rl-0.1.0.dist-info/METADATA +151 -0
- masa_safe_rl-0.1.0.dist-info/RECORD +120 -0
- masa_safe_rl-0.1.0.dist-info/WHEEL +5 -0
- masa_safe_rl-0.1.0.dist-info/entry_points.txt +2 -0
- masa_safe_rl-0.1.0.dist-info/licenses/LICENSE +201 -0
- masa_safe_rl-0.1.0.dist-info/top_level.txt +1 -0
masa/__init__.py
ADDED
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from masa.algorithms.a2c.a2c import A2C
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import jax.random as jr
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
import optax
|
|
5
|
+
from jax import jit
|
|
6
|
+
import jax
|
|
7
|
+
from functools import partial
|
|
8
|
+
from flax.training.train_state import TrainState
|
|
9
|
+
import numpy as np
|
|
10
|
+
import gymnasium as gym
|
|
11
|
+
from gymnasium import spaces
|
|
12
|
+
from typing import Any, Optional, TypeVar, Union, Callable
|
|
13
|
+
from masa.common.base_class import BaseJaxPolicy
|
|
14
|
+
from masa.common.on_policy_algorithm import OnPolicyAlgorithm
|
|
15
|
+
from masa.common.policies import PPOPolicy
|
|
16
|
+
from tqdm.auto import tqdm
|
|
17
|
+
|
|
18
|
+
class A2C(OnPolicyAlgorithm):
    """Advantage actor-critic learner layered on ``OnPolicyAlgorithm``.

    ``optimize`` asks the rollout buffer for its data with a batch size of
    ``None`` and applies one jitted gradient step per yielded batch to the
    featurizer, actor and critic train states stored on ``self.policy``.
    """

    def __init__(
        self,
        env: gym.Env,
        tensorboard_logdir: Optional[str] = None,
        wandb_project: Optional[str] = None,
        wandb_name: Optional[str] = None,
        seed: Optional[int] = None,
        monitor: bool = True,
        device: str = "auto",
        verbose: int = 0,
        env_fn: Optional[Callable[[], gym.Env]] = None,
        eval_env: Optional[gym.Env] = None,
        learning_rate: Union[float, optax.Schedule] = 3e-4,
        n_steps: int = 16,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        normalize_advantage: bool = False,
        ent_coef: float = 0.0,
        vf_coef: float = 1.0,
        max_grad_norm: float = 0.5,
        policy_class: type[BaseJaxPolicy] = PPOPolicy,
        policy_kwargs: Optional[dict[str, Any]] = None,
    ):
        """Hand the shared settings to ``OnPolicyAlgorithm`` and keep the A2C-only flag.

        Every argument except ``normalize_advantage`` is forwarded unchanged to
        the parent constructor; the rollout progress bar is always disabled.

        Raises:
            AssertionError: if ``normalize_advantage`` is set while
                ``n_steps * n_envs`` is not greater than 1.
        """
        super().__init__(
            env,
            tensorboard_logdir=tensorboard_logdir,
            wandb_project=wandb_project,
            wandb_name=wandb_name,
            seed=seed,
            monitor=monitor,
            device=device,
            verbose=verbose,
            env_fn=env_fn,
            eval_env=eval_env,
            use_tqdm_rollout=False,  # Turn off tqdm progress bar for rollout
            learning_rate=learning_rate,
            n_steps=n_steps,
            gamma=gamma,
            gae_lambda=gae_lambda,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            policy_class=policy_class,
            policy_kwargs=policy_kwargs
        )

        # The size check only matters when advantages get standardised; the
        # short-circuit keeps self.n_envs untouched otherwise.
        assert (not normalize_advantage) or n_steps * self.n_envs > 1, \
            "n_steps * n_envs must be > 1 when normalize_advantage = True"
        self.normalize_advantage = normalize_advantage

    @staticmethod
    @partial(jit, static_argnames=["normalize_advantage"])
    def _one_update(
        featurizer_state: TrainState,
        actor_state: TrainState,
        critic_state: TrainState,
        observations: jnp.ndarray,
        actions: jnp.ndarray,
        advantages: jnp.ndarray,
        returns: jnp.ndarray,
        ent_coef: float,
        vf_coef: float,
        normalize_advantage: bool = True,
    ):
        """Take one jitted gradient step on the featurizer, actor and critic.

        Args:
            featurizer_state: train state whose ``apply_fn`` maps observations to features.
            actor_state: train state whose ``apply_fn`` maps features to an action distribution.
            critic_state: train state whose ``apply_fn`` maps features to value predictions.
            observations: batch of observations fed to the featurizer.
            actions: batch of actions scored by the distribution's ``log_prob``.
            advantages: advantage estimates weighting the log-probabilities.
            returns: regression targets for the critic.
            ent_coef: weight of the entropy term.
            vf_coef: weight of the value loss.
            normalize_advantage: static flag; when set and more than one
                advantage is present they are standardised first.

        Returns:
            ``((featurizer_state, actor_state, critic_state), (pg_loss, vf_loss))``.
            ``pg_loss`` already includes the weighted entropy term and
            ``vf_loss`` is already scaled by ``vf_coef``.
        """
        standardise = normalize_advantage and len(advantages) > 1
        if standardise:
            centred = advantages - advantages.mean()
            advantages = centred / (advantages.std() + 1e-8)

        def loss_fn(feat_params, pi_params, vf_params):
            latent = featurizer_state.apply_fn(feat_params, observations)
            action_dist = actor_state.apply_fn(pi_params, latent)

            # Plain policy-gradient term: advantage-weighted log-likelihood.
            weighted_log_prob = advantages * action_dist.log_prob(actions)
            pg_term = -weighted_log_prob.mean()

            # Analytical entropy; minimising its negation favours exploration.
            # (A sampled alternative would be -jnp.mean(-log_prob).)
            entropy_term = jnp.mean(-action_dist.entropy())
            actor_objective = pg_term + ent_coef * entropy_term

            # The critic shares the featurizer output with the actor.
            predicted = critic_state.apply_fn(vf_params, latent).flatten()
            squared_error = (returns - predicted) ** 2
            critic_objective = vf_coef * squared_error.mean()

            combined = actor_objective + critic_objective
            return combined, (actor_objective, critic_objective)

        differentiate = jax.value_and_grad(loss_fn, argnums=(0, 1, 2), has_aux=True)
        (_, (pg_loss, vf_loss)), (feat_grads, pi_grads, vf_grads) = differentiate(
            featurizer_state.params, actor_state.params, critic_state.params
        )

        updated_states = (
            featurizer_state.apply_gradients(grads=feat_grads),
            actor_state.apply_gradients(grads=pi_grads),
            critic_state.apply_gradients(grads=vf_grads),
        )
        return updated_states, (pg_loss, vf_loss)

    def optimize(
        self,
        step: int,
        logger: Optional[TrainLogger] = None,
        tqdm_position: int = 1  # unused
    ):
        """Apply one gradient step per batch yielded by the rollout buffer.

        Args:
            step: value passed to ``self.lr_schedule`` to obtain the logged learning rate.
            logger: optional sink; when truthy it receives the policy loss,
                value loss and learning rate under ``"train/stats"``.
            tqdm_position: accepted for signature parity; not used here.
        """
        # NOTE(review): TrainLogger is only a lazy annotation here; no import
        # for it is visible among this module's imports.
        current_lr = self.lr_schedule(step)

        self.key, batch_key = jr.split(self.key)
        for batch in self.rollout_buffer.get(batch_key, None):
            obs, acts, _rewards, _values, rets, advs, _old_log_probs = batch

            if isinstance(self.action_space, spaces.Discrete):
                # Convert discrete action from float to int
                acts = acts.flatten().astype(np.int32)

            new_states, (pg_loss, vf_loss) = self._one_update(
                featurizer_state=self.policy.featurizer_state,
                actor_state=self.policy.actor_state,
                critic_state=self.policy.critic_state,
                observations=obs,
                actions=acts,
                advantages=advs,
                returns=rets,
                ent_coef=self.ent_coef,
                vf_coef=self.vf_coef,
                normalize_advantage=self.normalize_advantage,
            )
            (
                self.policy.featurizer_state,
                self.policy.actor_state,
                self.policy.critic_state,
            ) = new_states

        # NOTE(review): the source this was recovered from had lost its
        # indentation; logging is assumed to happen once, after the loop --
        # confirm it was not meant to run per batch.
        if logger:
            stats = {
                "policy_loss": float(pg_loss),
                "value_loss": float(vf_loss),
                "lr": float(current_lr)
            }
            logger.add("train/stats", stats)

    @property
    def train_ratio(self):
        """Product of ``n_steps`` and ``n_envs``."""
        return self.n_steps * self.n_envs
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from masa.algorithms.ppo.ppo import PPO
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import jax.random as jr
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
import optax
|
|
5
|
+
from jax import jit
|
|
6
|
+
import jax
|
|
7
|
+
from functools import partial
|
|
8
|
+
from flax.training.train_state import TrainState
|
|
9
|
+
import numpy as np
|
|
10
|
+
import gymnasium as gym
|
|
11
|
+
from gymnasium import spaces
|
|
12
|
+
from typing import Any, Optional, TypeVar, Union, Callable
|
|
13
|
+
from masa.common.base_class import BaseJaxPolicy
|
|
14
|
+
from masa.common.on_policy_algorithm import OnPolicyAlgorithm
|
|
15
|
+
from masa.common.policies import PPOPolicy
|
|
16
|
+
from tqdm.auto import tqdm
|
|
17
|
+
|
|
18
|
+
class PPO(OnPolicyAlgorithm):
    """Proximal Policy Optimization with a clipped surrogate objective.

    ``optimize`` makes ``n_epochs`` passes over the rollout buffer, asking it
    for minibatches of ``batch_size // n_envs`` and applying one jitted
    gradient step per minibatch to the featurizer, actor and critic train
    states stored on ``self.policy``.
    """

    def __init__(
        self,
        env: gym.Env,
        tensorboard_logdir: Optional[str] = None,
        wandb_project: Optional[str] = None,
        wandb_name: Optional[str] = None,
        seed: Optional[int] = None,
        monitor: bool = True,
        device: str = "auto",
        verbose: int = 0,
        env_fn: Optional[Callable[[], gym.Env]] = None,
        eval_env: Optional[gym.Env] = None,
        learning_rate: Union[float, optax.Schedule] = 3e-4,
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: Union[float, optax.Schedule] = 0.2,
        normalize_advantage: bool = True,
        ent_coef: float = 0.0,
        vf_coef: float = 1.0,
        max_grad_norm: float = 0.5,
        policy_class: type[BaseJaxPolicy] = PPOPolicy,
        policy_kwargs: Optional[dict[str, Any]] = None,
    ):
        """Forward the shared settings to ``OnPolicyAlgorithm`` and store the PPO-only ones.

        ``clip_range`` may be a real number (wrapped in a constant schedule) or
        a callable schedule evaluated with the ``step`` given to ``optimize``.
        The rollout progress bar is always enabled.

        Raises:
            AssertionError: if ``normalize_advantage`` is set with
                ``batch_size <= 1``, or if ``clip_range`` is neither a real
                number nor callable.
        """
        super().__init__(
            env,
            tensorboard_logdir=tensorboard_logdir,
            wandb_project=wandb_project,
            wandb_name=wandb_name,
            seed=seed,
            monitor=monitor,
            device=device,
            verbose=verbose,
            env_fn=env_fn,
            eval_env=eval_env,
            use_tqdm_rollout=True,  # Turn on tqdm progress bar for rollout
            learning_rate=learning_rate,
            n_steps=n_steps,
            gamma=gamma,
            gae_lambda=gae_lambda,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            policy_class=policy_class,
            policy_kwargs=policy_kwargs
        )

        if normalize_advantage:
            assert batch_size > 1, "batch_size must be > 1 when normalize_advantage = True"

        # Accept any real number, not just float instances, so clip_range=1 is
        # not rejected by the callable check below. bool is excluded on
        # purpose: it is an int subclass but never a sensible clip range.
        is_number = isinstance(clip_range, (int, float)) and not isinstance(clip_range, bool)
        if is_number:
            # float() is value-preserving for floats and promotes ints.
            self.clip_range_schedule = optax.schedules.constant_schedule(float(clip_range))
        else:
            assert callable(clip_range), f"clip_range for class PPO must be float or optax.Schedule not {clip_range}"
            self.clip_range_schedule = clip_range

        self.normalize_advantage = normalize_advantage
        self.batch_size = batch_size
        self.n_epochs = n_epochs

    @staticmethod
    @partial(jit, static_argnames=["normalize_advantage"])
    def _one_update(
        featurizer_state: TrainState,
        actor_state: TrainState,
        critic_state: TrainState,
        observations: jnp.ndarray,
        actions: jnp.ndarray,
        advantages: jnp.ndarray,
        returns: jnp.ndarray,
        old_log_prob: jnp.ndarray,
        clip_range: float,
        ent_coef: float,
        vf_coef: float,
        normalize_advantage: bool = True,
    ):
        """Take one jitted PPO gradient step on the featurizer, actor and critic.

        Args:
            featurizer_state: train state whose ``apply_fn`` maps observations to features.
            actor_state: train state whose ``apply_fn`` maps features to an action distribution.
            critic_state: train state whose ``apply_fn`` maps features to value predictions.
            observations: minibatch of observations.
            actions: minibatch of actions scored by the distribution's ``log_prob``.
            advantages: advantage estimates for the minibatch.
            returns: regression targets for the critic.
            old_log_prob: log-probabilities of ``actions`` recorded at rollout time.
            clip_range: half-width of the interval around 1 the ratio is clipped to.
            ent_coef: weight of the entropy term.
            vf_coef: weight of the value loss.
            normalize_advantage: static flag; when set and more than one
                advantage is present they are standardised first.

        Returns:
            ``((featurizer_state, actor_state, critic_state), (pg_loss, vf_loss))``.
            ``pg_loss`` already includes the weighted entropy term and
            ``vf_loss`` is already scaled by ``vf_coef``.
        """
        if normalize_advantage and len(advantages) > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        def actor_critic_loss(featurizer_params, actor_params, critic_params):
            features = featurizer_state.apply_fn(featurizer_params, observations)
            dist = actor_state.apply_fn(actor_params, features)
            log_prob = dist.log_prob(actions)
            entropy = dist.entropy()

            # ratio between old and new policy, should be one at the first iteration
            ratio = jnp.exp(log_prob - old_log_prob)
            # clipped surrogate loss: take the pessimistic of the two estimates
            policy_loss_1 = advantages * ratio
            policy_loss_2 = advantages * jnp.clip(ratio, 1 - clip_range, 1 + clip_range)
            policy_loss = -jnp.minimum(policy_loss_1, policy_loss_2).mean()

            # Entropy loss favours exploration; the analytical form is used.
            # (A sampled alternative would be -jnp.mean(-log_prob).)
            entropy_loss = -jnp.mean(entropy)

            total_policy_loss = policy_loss + ent_coef * entropy_loss

            # Critic loss; the critic shares the featurizer output with the actor.
            critic_values = critic_state.apply_fn(critic_params, features).flatten()
            value_loss = vf_coef * ((returns - critic_values) ** 2).mean()

            total_loss = total_policy_loss + value_loss
            return total_loss, (total_policy_loss, value_loss)

        (loss, (pg_loss, vf_loss)), grads = jax.value_and_grad(actor_critic_loss, argnums=(0, 1, 2), has_aux=True)(
            featurizer_state.params, actor_state.params, critic_state.params
        )

        featurizer_state = featurizer_state.apply_gradients(grads=grads[0])
        actor_state = actor_state.apply_gradients(grads=grads[1])
        critic_state = critic_state.apply_gradients(grads=grads[2])

        return (featurizer_state, actor_state, critic_state), (pg_loss, vf_loss)

    def optimize(
        self,
        step: int,
        logger: Optional[TrainLogger] = None,
        tqdm_position: int = 1
    ):
        """Run ``n_epochs`` passes of minibatch updates over the rollout buffer.

        Args:
            step: value passed to the clip-range and learning-rate schedules.
            logger: optional sink; when truthy it receives the policy loss,
                value loss, clip range and learning rate under ``"train/stats"``.
            tqdm_position: row used for the "optimize" progress bar.

        Raises:
            ValueError: if ``batch_size`` is smaller than ``n_envs``, which
                would make the per-minibatch size zero.
        """
        # NOTE(review): TrainLogger is only a lazy annotation here; no import
        # for it is visible among this module's imports.
        clip_range = self.clip_range_schedule(step)
        current_lr = self.lr_schedule(step)

        # The buffer is asked for batch_size // n_envs items per minibatch
        # (presumably time steps across all envs -- confirm against the buffer).
        minibatch_size = self.batch_size // self.n_envs
        if minibatch_size < 1:
            # Previously surfaced as a bare ZeroDivisionError from the
            # progress-bar total below.
            raise ValueError(
                f"batch_size ({self.batch_size}) must be >= n_envs ({self.n_envs})"
            )

        with tqdm(
            total=self.n_epochs * self.n_steps // minibatch_size,
            desc="optimize",
            position=tqdm_position,
            leave=False,
            dynamic_ncols=True,
            colour="cyan",
        ) as pbar:

            for _ in range(self.n_epochs):
                # Fresh key per epoch for the buffer's sampling.
                self.key, subkey = jr.split(self.key)
                for rollout_data in self.rollout_buffer.get(subkey, minibatch_size):

                    observations, actions, rewards, values, returns, advantages, old_log_probs = rollout_data

                    if isinstance(self.action_space, spaces.Discrete):
                        # Convert discrete action from float to int
                        actions = actions.flatten().astype(np.int32)

                    (self.policy.featurizer_state, self.policy.actor_state, self.policy.critic_state), (pg_loss, vf_loss) = \
                        self._one_update(
                            featurizer_state=self.policy.featurizer_state,
                            actor_state=self.policy.actor_state,
                            critic_state=self.policy.critic_state,
                            observations=observations,
                            actions=actions,
                            advantages=advantages,
                            returns=returns,
                            old_log_prob=old_log_probs,
                            clip_range=clip_range,
                            ent_coef=self.ent_coef,
                            vf_coef=self.vf_coef,
                            normalize_advantage=self.normalize_advantage,
                        )

                    pbar.update(1)

        # NOTE(review): the source this was recovered from had lost its
        # indentation; logging is assumed to happen once, after all epochs --
        # confirm it was not meant to run per minibatch.
        if logger:
            logger.add("train/stats", {
                "policy_loss": float(pg_loss),
                "value_loss": float(vf_loss),
                "clip_range": float(clip_range),
                "lr": float(current_lr)
            })

    @property
    def train_ratio(self):
        """Product of ``n_steps`` and ``n_envs``."""
        return self.n_steps * self.n_envs
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
from masa.algorithms.tabular.q_learning import QL
|
|
2
|
+
from masa.algorithms.tabular.q_learning_lambda import QL_Lambda
|
|
3
|
+
from masa.algorithms.tabular.sem import SEM
|
|
4
|
+
from masa.algorithms.tabular.lcrl import LCRL
|
|
5
|
+
from masa.algorithms.tabular.recreg import RECREG
|
|
6
|
+
from masa.algorithms.tabular.recovery_rl import RECOVERY_RL
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import Any, Optional, TypeVar, Union, Callable
|
|
3
|
+
from masa.common.base_class import BaseAlgorithm
|
|
4
|
+
import gymnasium as gym
|
|
5
|
+
from gymnasium import spaces
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TabularAlgorithm(BaseAlgorithm):
    """Common parent for tabular learners.

    Declares ``spaces.Discrete`` as the only supported observation and action
    space when calling ``BaseAlgorithm``; how that declaration is enforced is
    up to the parent class.
    """

    def __init__(
        self,
        env: gym.Env,
        tensorboard_logdir: Optional[str] = None,
        wandb_project: Optional[str] = None,
        wandb_name: Optional[str] = None,
        seed: Optional[int] = None,
        monitor: bool = True,
        device: str = "auto",
        verbose: int = 0,
        env_fn: Optional[Callable[[], gym.Env]] = None,
        eval_env: Optional[gym.Env] = None,
    ):
        """Pass every argument through to ``BaseAlgorithm`` alongside the Discrete-only space declarations."""
        # One immutable tuple serves both the action and observation side.
        discrete_only = (spaces.Discrete,)

        super().__init__(
            env,
            tensorboard_logdir=tensorboard_logdir,
            wandb_project=wandb_project,
            wandb_name=wandb_name,
            seed=seed,
            monitor=monitor,
            device=device,
            verbose=verbose,
            supported_action_spaces=discrete_only,
            supported_observation_spaces=discrete_only,
            env_fn=env_fn,
            eval_env=eval_env,
        )
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import Any, Optional, TypeVar, Union, Callable
|
|
3
|
+
from masa.common.metrics import TrainLogger
|
|
4
|
+
from masa.algorithms.tabular.q_learning import QL
|
|
5
|
+
from masa.common.ltl import DFACostFn, DFA
|
|
6
|
+
from gymnasium import spaces
|
|
7
|
+
import gymnasium as gym
|
|
8
|
+
import numpy as np
|
|
9
|
+
import jax.numpy as jnp
|
|
10
|
+
import jax.random as jr
|
|
11
|
+
from jax import jit
|
|
12
|
+
from functools import partial
|
|
13
|
+
|
|
14
|
+
class LCRL(QL):
    """Q-learning variant that penalises constraint-violating transitions.

    The update in ``optimize`` scales the reward and the bootstrap term by
    ``1 - violation`` and adds ``violation * r_min / (1 - gamma)``
    (``violation`` is presumably a 0/1 flag -- confirm against the buffer
    producer).
    """

    def __init__(
        self,
        env: gym.Env,
        tensorboard_logdir: Optional[str] = None,
        wandb_project: Optional[str] = None,
        wandb_name: Optional[str] = None,
        seed: Optional[int] = None,
        monitor: bool = True,
        device: str = "auto",
        verbose: int = 0,
        env_fn: Optional[Callable[[], gym.Env]] = None,
        eval_env: Optional[gym.Env] = None,
        alpha: float = 0.1,
        gamma: float = 0.9,
        r_min: float = 0.0,
        exploration: str = 'boltzmann',
        boltzmann_temp: float = 0.05,
        initial_epsilon: float = 1.0,
        final_epsilon: float = 0.1,
        epsilon_decay: str = 'linear',
        epsilon_decay_frames: int = 10000,
    ):
        """Forward everything except ``r_min`` to ``QL`` and remember ``r_min`` for the update rule."""
        super().__init__(
            env,
            tensorboard_logdir=tensorboard_logdir,
            wandb_project=wandb_project,
            wandb_name=wandb_name,
            seed=seed,
            monitor=monitor,
            device=device,
            verbose=verbose,
            env_fn=env_fn,
            eval_env=eval_env,
            alpha=alpha,
            gamma=gamma,
            exploration=exploration,
            boltzmann_temp=boltzmann_temp,
            initial_epsilon=initial_epsilon,
            final_epsilon=final_epsilon,
            epsilon_decay=epsilon_decay,
            epsilon_decay_frames=epsilon_decay_frames,
        )

        self.r_min = r_min

    def optimize(self, step: int, logger: Optional[TrainLogger] = None):
        """Replay every buffered transition into the Q table, then empty the buffer.

        Does nothing when the buffer is empty. When ``logger`` is truthy, the
        learning rate and the active exploration parameter are reported under
        ``"train/stats"``.
        """
        if len(self.buffer) == 0:
            return

        for transition in self.buffer:
            # The fourth field of each transition is not used by this update.
            state, action, reward, _, violation, next_state, terminal = transition

            next_q = self.Q[next_state]
            keep = 1 - violation

            # Reward and bootstrap are kept only where keep is non-zero; the
            # r_min / (1 - gamma) term is weighted by violation instead.
            target = (
                reward * keep
                + float(violation) * (self.r_min / (1.0 - self.gamma))
                + keep * (1 - terminal) * self.gamma * np.max(next_q)
            )

            previous = self.Q[state, action]
            self.Q[state, action] = (1 - self.alpha) * previous + self.alpha * target

        self.buffer.clear()

        if not logger:
            return

        logger.add("train/stats", {"alpha": self.alpha})
        if self.exploration == "boltzmann":
            logger.add("train/stats", {"temp": self.boltzmann_temp})
        if self.exploration == "epsilon_greedy":
            logger.add("train/stats", {"epsilon": self._epsilon})
|