MASA-Safe-RL 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. masa/__init__.py +0 -0
  2. masa/algorithms/__init__.py +0 -0
  3. masa/algorithms/a2c/__init__.py +1 -0
  4. masa/algorithms/a2c/a2c.py +163 -0
  5. masa/algorithms/ppo/__init__.py +1 -0
  6. masa/algorithms/ppo/ppo.py +197 -0
  7. masa/algorithms/tabular/__init__.py +6 -0
  8. masa/algorithms/tabular/base.py +37 -0
  9. masa/algorithms/tabular/lcrl.py +81 -0
  10. masa/algorithms/tabular/q_learning.py +207 -0
  11. masa/algorithms/tabular/q_learning_lambda.py +84 -0
  12. masa/algorithms/tabular/recovery_rl.py +227 -0
  13. masa/algorithms/tabular/recreg.py +559 -0
  14. masa/algorithms/tabular/sem.py +150 -0
  15. masa/cli/__init__.py +0 -0
  16. masa/cli/cli_app.py +262 -0
  17. masa/common/__init__.py +0 -0
  18. masa/common/base_class.py +241 -0
  19. masa/common/buffers.py +139 -0
  20. masa/common/configs.py +516 -0
  21. masa/common/constraints/__init__.py +0 -0
  22. masa/common/constraints/base.py +275 -0
  23. masa/common/constraints/cmdp.py +133 -0
  24. masa/common/constraints/ltl_safety.py +648 -0
  25. masa/common/constraints/multi_agent/__init__.py +3 -0
  26. masa/common/constraints/multi_agent/cmg.py +257 -0
  27. masa/common/constraints/pctl.py +112 -0
  28. masa/common/constraints/prob.py +131 -0
  29. masa/common/constraints/reach_avoid.py +132 -0
  30. masa/common/dummy.py +6 -0
  31. masa/common/label_fn.py +16 -0
  32. masa/common/labelled_env.py +118 -0
  33. masa/common/labelled_pz_env.py +67 -0
  34. masa/common/layers.py +42 -0
  35. masa/common/ltl.py +572 -0
  36. masa/common/metrics.py +934 -0
  37. masa/common/on_policy_algorithm.py +314 -0
  38. masa/common/pctl.py +1774 -0
  39. masa/common/pettingzoo_record_video.py +353 -0
  40. masa/common/policies.py +286 -0
  41. masa/common/registry.py +38 -0
  42. masa/common/running_mean_std.py +40 -0
  43. masa/common/schedule.py +28 -0
  44. masa/common/utils.py +227 -0
  45. masa/common/wrappers.py +1480 -0
  46. masa/configs/__init__.py +0 -0
  47. masa/configs/algorithms/__init__.py +0 -0
  48. masa/configs/envs/__init__.py +0 -0
  49. masa/envs/__init__.py +0 -0
  50. masa/envs/continuous/__init__.py +0 -0
  51. masa/envs/continuous/base.py +8 -0
  52. masa/envs/continuous/cartpole.py +145 -0
  53. masa/envs/discrete/__init__.py +0 -0
  54. masa/envs/discrete/base.py +8 -0
  55. masa/envs/discrete/cartpole.py +146 -0
  56. masa/envs/discrete/conveyor_belt.py +333 -0
  57. masa/envs/discrete/island_navigation.py +291 -0
  58. masa/envs/discrete/mini_pacman_with_coins.py +208 -0
  59. masa/envs/discrete/pacman_with_coins.py +217 -0
  60. masa/envs/discrete/renderers/__init__.py +2 -0
  61. masa/envs/discrete/renderers/cartpole.py +295 -0
  62. masa/envs/discrete/renderers/pacman.py +92 -0
  63. masa/envs/discrete/sokoban.py +320 -0
  64. masa/envs/multiagent/matrix/_label_utils.py +22 -0
  65. masa/envs/multiagent/matrix/bertrand.py +402 -0
  66. masa/envs/multiagent/matrix/chicken.py +403 -0
  67. masa/envs/multiagent/matrix/congestion.py +496 -0
  68. masa/envs/multiagent/matrix/dpgg.py +448 -0
  69. masa/envs/multiagent/matrix/inspection.py +401 -0
  70. masa/envs/tabular/__init__.py +0 -0
  71. masa/envs/tabular/base.py +36 -0
  72. masa/envs/tabular/bridge_crossing.py +107 -0
  73. masa/envs/tabular/bridge_crossing_v2.py +107 -0
  74. masa/envs/tabular/colour_bomb_grid_world.py +121 -0
  75. masa/envs/tabular/colour_bomb_grid_world_v2.py +142 -0
  76. masa/envs/tabular/colour_bomb_grid_world_v3.py +168 -0
  77. masa/envs/tabular/colour_grid_world.py +106 -0
  78. masa/envs/tabular/media_streaming.py +115 -0
  79. masa/envs/tabular/mini_pacman.py +161 -0
  80. masa/envs/tabular/pacman.py +170 -0
  81. masa/envs/tabular/renderers/__init__.py +2 -0
  82. masa/envs/tabular/renderers/bridge_crossing.py +277 -0
  83. masa/envs/tabular/renderers/colour_bomb_grid_world.py +407 -0
  84. masa/envs/tabular/renderers/colour_grid_world.py +300 -0
  85. masa/envs/tabular/renderers/media_streaming.py +282 -0
  86. masa/envs/tabular/renderers/pacman.py +489 -0
  87. masa/envs/tabular/utils.py +414 -0
  88. masa/examples/__init__.py +0 -0
  89. masa/examples/colour_bomb_grid_world/__init__.py +0 -0
  90. masa/examples/colour_bomb_grid_world/property_1.py +23 -0
  91. masa/examples/colour_bomb_grid_world/property_2.py +30 -0
  92. masa/examples/colour_bomb_grid_world/property_3.py +89 -0
  93. masa/examples/norm_obs_example.py +116 -0
  94. masa/examples/prob_shield_cont_example.py +100 -0
  95. masa/examples/prob_shield_cont_ltl_example.py +144 -0
  96. masa/examples/prob_shield_example.py +107 -0
  97. masa/examples/prob_shield_ltl_example.py +105 -0
  98. masa/examples/prob_shield_safety_abstraction_example.py +109 -0
  99. masa/examples/prob_shield_vec_ltl_example.py +142 -0
  100. masa/examples/reward_shaping_example.py +103 -0
  101. masa/plugins/__init__.py +0 -0
  102. masa/plugins/helpers.py +12 -0
  103. masa/plugins/supported.py +52 -0
  104. masa/prob_shield/__init__.py +0 -0
  105. masa/prob_shield/eventual_discounted_vi.py +105 -0
  106. masa/prob_shield/helpers.py +177 -0
  107. masa/prob_shield/interval_bound_vi.py +127 -0
  108. masa/prob_shield/parameterized_policy.py +418 -0
  109. masa/prob_shield/parameterized_policy_v2.py +424 -0
  110. masa/prob_shield/parameterized_ppo.py +381 -0
  111. masa/prob_shield/parameterized_ppo_v2.py +267 -0
  112. masa/prob_shield/prob_shield_wrapper_v1.py +487 -0
  113. masa/prob_shield/prob_shield_wrapper_v2.py +443 -0
  114. masa/run.py +122 -0
  115. masa_safe_rl-0.1.0.dist-info/METADATA +151 -0
  116. masa_safe_rl-0.1.0.dist-info/RECORD +120 -0
  117. masa_safe_rl-0.1.0.dist-info/WHEEL +5 -0
  118. masa_safe_rl-0.1.0.dist-info/entry_points.txt +2 -0
  119. masa_safe_rl-0.1.0.dist-info/licenses/LICENSE +201 -0
  120. masa_safe_rl-0.1.0.dist-info/top_level.txt +1 -0
masa/__init__.py ADDED
File without changes
File without changes
@@ -0,0 +1 @@
1
+ from masa.algorithms.a2c.a2c import A2C
@@ -0,0 +1,163 @@
1
+ from __future__ import annotations
2
+ import jax.random as jr
3
+ import jax.numpy as jnp
4
+ import optax
5
+ from jax import jit
6
+ import jax
7
+ from functools import partial
8
+ from flax.training.train_state import TrainState
9
+ import numpy as np
10
+ import gymnasium as gym
11
+ from gymnasium import spaces
12
+ from typing import Any, Optional, TypeVar, Union, Callable
13
+ from masa.common.base_class import BaseJaxPolicy
14
+ from masa.common.on_policy_algorithm import OnPolicyAlgorithm
15
+ from masa.common.policies import PPOPolicy
16
+ from tqdm.auto import tqdm
17
+
18
class A2C(OnPolicyAlgorithm):
    """Advantage Actor-Critic (A2C) built on ``OnPolicyAlgorithm``.

    Each call to :meth:`optimize` reads the collected rollout from
    ``self.rollout_buffer`` (queried with ``batch_size=None``) and applies one
    joint gradient step per yielded batch to the featurizer, actor and critic
    ``TrainState`` objects held by ``self.policy``.
    """

    def __init__(
        self,
        env: gym.Env,
        tensorboard_logdir: Optional[str] = None,
        wandb_project: Optional[str] = None,
        wandb_name: Optional[str] = None,
        seed: Optional[int] = None,
        monitor: bool = True,
        device: str = "auto",
        verbose: int = 0,
        env_fn: Optional[Callable[[], gym.Env]] = None,
        eval_env: Optional[gym.Env] = None,
        learning_rate: Union[float, optax.Schedule] = 3e-4,
        n_steps: int = 16,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        normalize_advantage: bool = False,
        ent_coef: float = 0.0,
        vf_coef: float = 1.0,
        max_grad_norm: float = 0.5,
        policy_class: type[BaseJaxPolicy] = PPOPolicy,
        policy_kwargs: Optional[dict[str, Any]] = None,
    ):
        """Create an A2C learner.

        Every argument except ``normalize_advantage`` is forwarded to
        ``OnPolicyAlgorithm``. The rollout progress bar is always disabled
        for A2C (``use_tqdm_rollout=False``).

        Args:
            normalize_advantage: If True, advantages are standardised (zero
                mean, unit std) inside each update. This requires more than one
                sample per rollout, i.e. ``n_steps * n_envs > 1``.
        """
        super().__init__(
            env,
            tensorboard_logdir=tensorboard_logdir,
            wandb_project=wandb_project,
            wandb_name=wandb_name,
            seed=seed,
            monitor=monitor,
            device=device,
            verbose=verbose,
            env_fn=env_fn,
            eval_env=eval_env,
            use_tqdm_rollout=False,  # Turn off tqdm progress bar for rollout
            learning_rate=learning_rate,
            n_steps=n_steps,
            gamma=gamma,
            gae_lambda=gae_lambda,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            policy_class=policy_class,
            policy_kwargs=policy_kwargs,
        )

        if normalize_advantage:
            # Standardising fewer than two samples is meaningless (std is 0).
            assert n_steps * self.n_envs > 1, "n_steps * n_envs must be > 1 when normalize_advantage = True"

        self.normalize_advantage = normalize_advantage

    @staticmethod
    @partial(jit, static_argnames=["normalize_advantage"])
    def _one_update(
        featurizer_state: TrainState,
        actor_state: TrainState,
        critic_state: TrainState,
        observations: jnp.ndarray,
        actions: jnp.ndarray,
        advantages: jnp.ndarray,
        returns: jnp.ndarray,
        ent_coef: float,
        vf_coef: float,
        normalize_advantage: bool = True,
    ):
        """Apply one jitted A2C gradient step to all three train states.

        Returns:
            ``((featurizer_state, actor_state, critic_state), (policy_loss,
            value_loss))``. ``policy_loss`` includes the entropy term and
            ``value_loss`` is already scaled by ``vf_coef``.
        """
        # ``normalize_advantage`` is a static arg and ``len`` reads a static
        # shape, so this Python-level branch is resolved at trace time.
        if normalize_advantage and len(advantages) > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        def actor_critic_loss(featurizer_params, actor_params, critic_params):
            features = featurizer_state.apply_fn(featurizer_params, observations)
            dist = actor_state.apply_fn(actor_params, features)
            log_prob = dist.log_prob(actions)
            entropy = dist.entropy()

            # Vanilla policy-gradient loss.
            policy_loss = -(advantages * log_prob).mean()

            # Negative mean entropy: minimising it (with ent_coef > 0)
            # rewards exploration. Uses the distribution's analytic entropy.
            entropy_loss = -jnp.mean(entropy)

            total_policy_loss = policy_loss + ent_coef * entropy_loss

            # Critic regression towards the computed returns.
            critic_values = critic_state.apply_fn(critic_params, features).flatten()
            value_loss = vf_coef * ((returns - critic_values) ** 2).mean()

            total_loss = total_policy_loss + value_loss
            return total_loss, (total_policy_loss, value_loss)

        (_total_loss, (pg_loss, vf_loss)), grads = jax.value_and_grad(
            actor_critic_loss, argnums=(0, 1, 2), has_aux=True
        )(featurizer_state.params, actor_state.params, critic_state.params)

        featurizer_state = featurizer_state.apply_gradients(grads=grads[0])
        actor_state = actor_state.apply_gradients(grads=grads[1])
        critic_state = critic_state.apply_gradients(grads=grads[2])

        return (featurizer_state, actor_state, critic_state), (pg_loss, vf_loss)

    def optimize(
        self,
        step: int,
        logger: Optional[TrainLogger] = None,
        tqdm_position: int = 1  # unused
    ):
        """Run the A2C update(s) for the current rollout and optionally log stats.

        Args:
            step: Training step, used to evaluate the learning-rate schedule
                for logging.
            logger: If given, receives ``train/stats`` with the losses of the
                last update and the current learning rate.
            tqdm_position: Unused; kept for interface parity with PPO.
        """
        current_lr = self.lr_schedule(step)

        self.key, subkey = jr.split(self.key)
        # Stays None if the buffer yields no batch, so we never log
        # undefined losses.
        last_losses = None
        for rollout_data in self.rollout_buffer.get(subkey, None):
            observations, actions, _rewards, _values, returns, advantages, _old_log_probs = rollout_data

            if isinstance(self.action_space, spaces.Discrete):
                # Convert discrete action from float to int
                actions = actions.flatten().astype(np.int32)

            (self.policy.featurizer_state, self.policy.actor_state, self.policy.critic_state), last_losses = \
                self._one_update(
                    featurizer_state=self.policy.featurizer_state,
                    actor_state=self.policy.actor_state,
                    critic_state=self.policy.critic_state,
                    observations=observations,
                    actions=actions,
                    advantages=advantages,
                    returns=returns,
                    ent_coef=self.ent_coef,
                    vf_coef=self.vf_coef,
                    normalize_advantage=self.normalize_advantage,
                )

        if logger and last_losses is not None:
            pg_loss, vf_loss = last_losses
            logger.add("train/stats", {
                "policy_loss": float(pg_loss),
                "value_loss": float(vf_loss),
                "lr": float(current_lr)
            })

    @property
    def train_ratio(self):
        """Number of environment transitions collected per rollout (``n_steps * n_envs``)."""
        return self.n_steps * self.n_envs
@@ -0,0 +1 @@
1
+ from masa.algorithms.ppo.ppo import PPO
@@ -0,0 +1,197 @@
1
+ from __future__ import annotations
2
+ import jax.random as jr
3
+ import jax.numpy as jnp
4
+ import optax
5
+ from jax import jit
6
+ import jax
7
+ from functools import partial
8
+ from flax.training.train_state import TrainState
9
+ import numpy as np
10
+ import gymnasium as gym
11
+ from gymnasium import spaces
12
+ from typing import Any, Optional, TypeVar, Union, Callable
13
+ from masa.common.base_class import BaseJaxPolicy
14
+ from masa.common.on_policy_algorithm import OnPolicyAlgorithm
15
+ from masa.common.policies import PPOPolicy
16
+ from tqdm.auto import tqdm
17
+
18
class PPO(OnPolicyAlgorithm):
    """Proximal Policy Optimization (clipped surrogate) built on ``OnPolicyAlgorithm``.

    Each call to :meth:`optimize` runs ``n_epochs`` passes over the rollout
    buffer in minibatches and applies one joint gradient step per minibatch to
    the featurizer, actor and critic ``TrainState`` objects held by
    ``self.policy``.
    """

    def __init__(
        self,
        env: gym.Env,
        tensorboard_logdir: Optional[str] = None,
        wandb_project: Optional[str] = None,
        wandb_name: Optional[str] = None,
        seed: Optional[int] = None,
        monitor: bool = True,
        device: str = "auto",
        verbose: int = 0,
        env_fn: Optional[Callable[[], gym.Env]] = None,
        eval_env: Optional[gym.Env] = None,
        learning_rate: Union[float, optax.Schedule] = 3e-4,
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: Union[float, optax.Schedule] = 0.2,
        normalize_advantage: bool = True,
        ent_coef: float = 0.0,
        vf_coef: float = 1.0,
        max_grad_norm: float = 0.5,
        policy_class: type[BaseJaxPolicy] = PPOPolicy,
        policy_kwargs: Optional[dict[str, Any]] = None,
    ):
        """Create a PPO learner.

        Most arguments are forwarded to ``OnPolicyAlgorithm``. The rollout
        progress bar is always enabled for PPO.

        Args:
            batch_size: Total minibatch size. ``optimize`` passes
                ``batch_size // n_envs`` to the rollout buffer, so this must
                be at least ``n_envs``.
            n_epochs: Passes over the rollout buffer per ``optimize`` call.
            clip_range: Constant clip range (any real number) or a schedule
                callable evaluated at the training step.
            normalize_advantage: Standardise advantages per minibatch.

        Raises:
            ValueError: If ``batch_size < n_envs``.
            AssertionError: If ``clip_range`` is neither a number nor a
                callable, or ``normalize_advantage`` is set with
                ``batch_size <= 1``.
        """
        super().__init__(
            env,
            tensorboard_logdir=tensorboard_logdir,
            wandb_project=wandb_project,
            wandb_name=wandb_name,
            seed=seed,
            monitor=monitor,
            device=device,
            verbose=verbose,
            env_fn=env_fn,
            eval_env=eval_env,
            use_tqdm_rollout=True,  # Turn on tqdm progress bar for rollout
            learning_rate=learning_rate,
            n_steps=n_steps,
            gamma=gamma,
            gae_lambda=gae_lambda,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            policy_class=policy_class,
            policy_kwargs=policy_kwargs,
        )

        if normalize_advantage:
            assert batch_size > 1, "batch_size must be > 1 when normalize_advantage = True"

        if batch_size < self.n_envs:
            # optimize() passes ``batch_size // n_envs`` to the buffer and
            # divides by it for the progress-bar total; 0 would raise
            # ZeroDivisionError there, so fail fast here instead.
            raise ValueError(
                f"batch_size ({batch_size}) must be >= n_envs ({self.n_envs})"
            )

        if callable(clip_range):
            self.clip_range_schedule = clip_range
        else:
            # Accept any real scalar (int, float, numpy scalar), not only a
            # builtin float.
            assert isinstance(clip_range, (int, float, np.integer, np.floating)), \
                f"clip_range for class PPO must be float or optax.Schedule not {clip_range}"
            self.clip_range_schedule = optax.schedules.constant_schedule(clip_range)

        self.normalize_advantage = normalize_advantage
        self.batch_size = batch_size
        self.n_epochs = n_epochs

    @staticmethod
    @partial(jit, static_argnames=["normalize_advantage"])
    def _one_update(
        featurizer_state: TrainState,
        actor_state: TrainState,
        critic_state: TrainState,
        observations: jnp.ndarray,
        actions: jnp.ndarray,
        advantages: jnp.ndarray,
        returns: jnp.ndarray,
        old_log_prob: jnp.ndarray,
        clip_range: float,
        ent_coef: float,
        vf_coef: float,
        normalize_advantage: bool = True,
    ):
        """Apply one jitted PPO gradient step to all three train states.

        Returns:
            ``((featurizer_state, actor_state, critic_state), (policy_loss,
            value_loss))``. ``policy_loss`` includes the entropy term and
            ``value_loss`` is already scaled by ``vf_coef``.
        """
        # Static arg + static shape: this branch is resolved at trace time.
        if normalize_advantage and len(advantages) > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        def actor_critic_loss(featurizer_params, actor_params, critic_params):
            features = featurizer_state.apply_fn(featurizer_params, observations)
            dist = actor_state.apply_fn(actor_params, features)
            log_prob = dist.log_prob(actions)
            entropy = dist.entropy()

            # ratio between old and new policy, should be one at the first iteration
            ratio = jnp.exp(log_prob - old_log_prob)
            # clipped surrogate loss
            policy_loss_1 = advantages * ratio
            policy_loss_2 = advantages * jnp.clip(ratio, 1 - clip_range, 1 + clip_range)
            policy_loss = -jnp.minimum(policy_loss_1, policy_loss_2).mean()

            # Negative mean entropy: minimising it (with ent_coef > 0)
            # rewards exploration. Uses the distribution's analytic entropy.
            entropy_loss = -jnp.mean(entropy)

            total_policy_loss = policy_loss + ent_coef * entropy_loss

            # Critic regression towards the computed returns.
            critic_values = critic_state.apply_fn(critic_params, features).flatten()
            value_loss = vf_coef * ((returns - critic_values) ** 2).mean()

            total_loss = total_policy_loss + value_loss
            return total_loss, (total_policy_loss, value_loss)

        (_total_loss, (pg_loss, vf_loss)), grads = jax.value_and_grad(
            actor_critic_loss, argnums=(0, 1, 2), has_aux=True
        )(featurizer_state.params, actor_state.params, critic_state.params)

        featurizer_state = featurizer_state.apply_gradients(grads=grads[0])
        actor_state = actor_state.apply_gradients(grads=grads[1])
        critic_state = critic_state.apply_gradients(grads=grads[2])

        return (featurizer_state, actor_state, critic_state), (pg_loss, vf_loss)

    def optimize(
        self,
        step: int,
        logger: Optional[TrainLogger] = None,
        tqdm_position: int = 1
    ):
        """Run ``n_epochs`` of minibatch PPO updates and optionally log stats.

        Args:
            step: Training step, used to evaluate the clip-range and
                learning-rate schedules.
            logger: If given, receives ``train/stats`` once per call with the
                losses of the last minibatch, the clip range and the learning
                rate.
            tqdm_position: Terminal row for the ``optimize`` progress bar.
        """
        clip_range = self.clip_range_schedule(step)
        current_lr = self.lr_schedule(step)
        minibatch_steps = self.batch_size // self.n_envs

        # Stays None if no minibatch was processed (e.g. n_epochs == 0), so we
        # never log undefined losses.
        last_losses = None
        with tqdm(
            total=self.n_epochs * self.n_steps // minibatch_steps,
            desc="optimize",
            position=tqdm_position,
            leave=False,
            dynamic_ncols=True,
            colour="cyan",
        ) as pbar:

            for _ in range(self.n_epochs):
                self.key, subkey = jr.split(self.key)
                for rollout_data in self.rollout_buffer.get(subkey, minibatch_steps):

                    observations, actions, _rewards, _values, returns, advantages, old_log_probs = rollout_data

                    if isinstance(self.action_space, spaces.Discrete):
                        # Convert discrete action from float to int
                        actions = actions.flatten().astype(np.int32)

                    (self.policy.featurizer_state, self.policy.actor_state, self.policy.critic_state), last_losses = \
                        self._one_update(
                            featurizer_state=self.policy.featurizer_state,
                            actor_state=self.policy.actor_state,
                            critic_state=self.policy.critic_state,
                            observations=observations,
                            actions=actions,
                            advantages=advantages,
                            returns=returns,
                            old_log_prob=old_log_probs,
                            clip_range=clip_range,
                            ent_coef=self.ent_coef,
                            vf_coef=self.vf_coef,
                            normalize_advantage=self.normalize_advantage,
                        )

                    pbar.update(1)

        if logger and last_losses is not None:
            pg_loss, vf_loss = last_losses
            logger.add("train/stats", {
                "policy_loss": float(pg_loss),
                "value_loss": float(vf_loss),
                "clip_range": float(clip_range),
                "lr": float(current_lr)
            })

    @property
    def train_ratio(self):
        """Number of environment transitions collected per rollout (``n_steps * n_envs``)."""
        return self.n_steps * self.n_envs
@@ -0,0 +1,6 @@
1
+ from masa.algorithms.tabular.q_learning import QL
2
+ from masa.algorithms.tabular.q_learning_lambda import QL_Lambda
3
+ from masa.algorithms.tabular.sem import SEM
4
+ from masa.algorithms.tabular.lcrl import LCRL
5
+ from masa.algorithms.tabular.recreg import RECREG
6
+ from masa.algorithms.tabular.recovery_rl import RECOVERY_RL
@@ -0,0 +1,37 @@
1
+ from __future__ import annotations
2
+ from typing import Any, Optional, TypeVar, Union, Callable
3
+ from masa.common.base_class import BaseAlgorithm
4
+ import gymnasium as gym
5
+ from gymnasium import spaces
6
+
7
+
8
class TabularAlgorithm(BaseAlgorithm):
    """Common base for tabular learners.

    This class only pins the admissible spaces: both the action space and the
    observation space must be ``gymnasium.spaces.Discrete``. Everything else is
    handed to ``BaseAlgorithm`` unchanged.
    """

    def __init__(
        self,
        env: gym.Env,
        tensorboard_logdir: Optional[str] = None,
        wandb_project: Optional[str] = None,
        wandb_name: Optional[str] = None,
        seed: Optional[int] = None,
        monitor: bool = True,
        device: str = "auto",
        verbose: int = 0,
        env_fn: Optional[Callable[[], gym.Env]] = None,
        eval_env: Optional[gym.Env] = None,
    ):
        """Forward all arguments to ``BaseAlgorithm`` with Discrete-only spaces."""
        discrete_only = (spaces.Discrete,)

        super().__init__(
            env,
            supported_action_spaces=discrete_only,
            supported_observation_spaces=discrete_only,
            env_fn=env_fn,
            eval_env=eval_env,
            seed=seed,
            monitor=monitor,
            device=device,
            verbose=verbose,
            tensorboard_logdir=tensorboard_logdir,
            wandb_project=wandb_project,
            wandb_name=wandb_name,
        )
@@ -0,0 +1,81 @@
1
+ from __future__ import annotations
2
+ from typing import Any, Optional, TypeVar, Union, Callable
3
+ from masa.common.metrics import TrainLogger
4
+ from masa.algorithms.tabular.q_learning import QL
5
+ from masa.common.ltl import DFACostFn, DFA
6
+ from gymnasium import spaces
7
+ import gymnasium as gym
8
+ import numpy as np
9
+ import jax.numpy as jnp
10
+ import jax.random as jr
11
+ from jax import jit
12
+ from functools import partial
13
+
14
class LCRL(QL):
    """Q-learning variant that assigns a fixed value to constraint violations.

    On a transition flagged as a violation, the TD target is
    ``r_min / (1 - gamma)``, which is the discounted sum of receiving ``r_min``
    at every step. Otherwise the target is the ordinary Q-learning target,
    with no bootstrapping past terminal transitions.
    """

    def __init__(
        self,
        env: gym.Env,
        tensorboard_logdir: Optional[str] = None,
        wandb_project: Optional[str] = None,
        wandb_name: Optional[str] = None,
        seed: Optional[int] = None,
        monitor: bool = True,
        device: str = "auto",
        verbose: int = 0,
        env_fn: Optional[Callable[[], gym.Env]] = None,
        eval_env: Optional[gym.Env] = None,
        alpha: float = 0.1,
        gamma: float = 0.9,
        r_min: float = 0.0,
        exploration: str = 'boltzmann',
        boltzmann_temp: float = 0.05,
        initial_epsilon: float = 1.0,
        final_epsilon: float = 0.1,
        epsilon_decay: str = 'linear',
        epsilon_decay_frames: int = 10000,
    ):
        """Create an LCRL learner.

        Args:
            r_min: Per-step reward whose discounted sum
                ``r_min / (1 - gamma)`` is used as the target for violating
                transitions.
            gamma: Discount factor; must be < 1 so that the violation value
                above is finite.

        Raises:
            ValueError: If ``gamma >= 1`` (or is NaN).
        """
        # Validate before QL sets anything up: gamma == 1 would otherwise
        # raise ZeroDivisionError on the first non-empty optimize() call.
        if not gamma < 1.0:
            raise ValueError(f"LCRL requires gamma < 1, got {gamma}")

        super().__init__(
            env,
            tensorboard_logdir=tensorboard_logdir,
            wandb_project=wandb_project,
            wandb_name=wandb_name,
            seed=seed,
            monitor=monitor,
            device=device,
            verbose=verbose,
            env_fn=env_fn,
            eval_env=eval_env,
            alpha=alpha,
            gamma=gamma,
            exploration=exploration,
            boltzmann_temp=boltzmann_temp,
            initial_epsilon=initial_epsilon,
            final_epsilon=final_epsilon,
            epsilon_decay=epsilon_decay,
            epsilon_decay_frames=epsilon_decay_frames,
        )

        self.r_min = r_min

    def optimize(self, step: int, logger: Optional[TrainLogger] = None):
        """Update the Q table with the buffered transitions, then clear the buffer.

        Each buffered entry is unpacked as
        ``(state, action, reward, _, violation, next_state, terminal)``. The
        update is ``Q[s, a] <- (1 - alpha) * Q[s, a] + alpha * target``, where
        ``target`` blends the violation value ``r_min / (1 - gamma)`` with the
        usual reward-plus-bootstrap target, weighted by ``violation``.
        """
        if len(self.buffer) == 0:
            return

        # Loop-invariant: the value assigned to a violating transition.
        violation_value = self.r_min / (1.0 - self.gamma)

        for (state, action, reward, _, violation, next_state, terminal) in self.buffer:
            # Read the successor row before writing Q[state, action]; the two
            # may coincide when next_state == state.
            current = self.Q[next_state]
            not_violated = 1 - violation
            target = (
                reward * not_violated
                + float(violation) * violation_value
                + not_violated * (1 - terminal) * self.gamma * np.max(current)
            )
            self.Q[state, action] = (1 - self.alpha) * self.Q[state, action] + self.alpha * target

        self.buffer.clear()

        if logger:
            logger.add("train/stats", {"alpha": self.alpha})
            if self.exploration == "boltzmann":
                logger.add("train/stats", {"temp": self.boltzmann_temp})
            if self.exploration == "epsilon_greedy":
                logger.add("train/stats", {"epsilon": self._epsilon})