gymcts 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gymcts/__init__.py ADDED
File without changes
gymcts/colorful_console_utils.py ADDED
@@ -0,0 +1,142 @@
+ import matplotlib.pyplot as plt
+ import numpy as np
+
+ CEND = "\33[0m"
+ CBOLD = "\33[1m"
+ CITALIC = "\33[3m"
+ CURL = "\33[4m"
+ CBLINK = "\33[5m"
+ CBLINK2 = "\33[6m"
+ CSELECTED = "\33[7m"
+
+ CBLACK = "\33[30m"
+ CRED = "\33[31m"
+ CGREEN = "\33[32m"
+ CYELLOW = "\33[33m"
+ CBLUE = "\33[34m"
+ CCYAN = '\33[96m'
+ CMAGENTA = '\033[35m'
+ CVIOLET = "\33[35m"
+ CBEIGE = "\33[36m"
+ CWHITE = "\33[37m"
+
+ CBLACKBG = "\33[40m"
+ CREDBG = "\33[41m"
+ CGREENBG = "\33[42m"
+ CYELLOWBG = "\33[43m"
+ CBLUEBG = "\33[44m"
+ CVIOLETBG = "\33[45m"
+ CBEIGEBG = "\33[46m"
+ CWHITEBG = "\33[47m"
+
+ CGREY = "\33[90m"
+ CRED2 = "\33[91m"
+ CGREEN2 = "\33[92m"
+ CYELLOW2 = "\33[93m"
+ CBLUE2 = "\33[94m"
+ CCYAN2 = "\033[36m"
+ CVIOLET2 = "\33[95m"
+ CBEIGE2 = "\33[96m"
+ CWHITE2 = "\33[97m"
+
+ CGREYBG = "\33[100m"
+ CREDBG2 = "\33[101m"
+ CGREENBG2 = "\33[102m"
+ CYELLOWBG2 = "\33[103m"
+ CBLUEBG2 = "\33[104m"
+ CVIOLETBG2 = "\33[105m"
+ CBEIGEBG2 = "\33[106m"
+ CWHITEBG2 = "\33[107m"
+
+
+ def rgb_color_sequence(r: int | float, g: int | float, b: int | float,
+                        *, format_type: str = 'foreground') -> str:
+     """
+     Generates a color code that changes the color of text in console output.
+
+     RGB values must be numbers between 0 and 255 or 0.0 and 1.0.
+
+     :param r: red value.
+     :param g: green value.
+     :param b: blue value.
+
+     :param format_type: specifies whether the foreground color or the background color shall be adjusted.
+                         valid options: 'foreground', 'background'
+     :return: a string that contains the color code.
+     """
+     if format_type == 'foreground':
+         f = '\033[38;2;{};{};{}m'.format  # foreground rgb format
+     elif format_type == 'background':
+         f = '\033[48;2;{};{};{}m'.format  # background rgb format
+     else:
+         raise ValueError(f"format {format_type} is not defined. Use 'foreground' or 'background'.")
+     rgb = [r, g, b]
+
+     if isinstance(r, int) and isinstance(g, int) and isinstance(b, int):
+         if min(rgb) < 0 or max(rgb) > 255:
+             raise ValueError("rgb values must be numbers between 0 and 255 or 0.0 and 1.0")
+         return f(r, g, b)
+     if isinstance(r, float) and isinstance(g, float) and isinstance(b, float):
+         if min(rgb) < 0.0 or max(rgb) > 1.0:
+             raise ValueError("rgb values must be numbers between 0 and 255 or 0.0 and 1.0")
+         return f(*[int(n * 255) for n in [r, g, b]])
+     raise TypeError("r, g and b must either all be ints (0-255) or all be floats (0.0-1.0)")
+
+
+ def wrap_with_color_codes(s: object, /, r: int | float, g: int | float, b: int | float, **kwargs) \
+         -> str:
+     """
+     Stringify an object and wrap it with console color codes: the color control sequence is added in front
+     and a reset sequence is appended at the end.
+
+     RGB values must be numbers between 0 and 255 or 0.0 and 1.0.
+
+     :param s: the object to stringify and wrap.
+     :param r: red value.
+     :param g: green value.
+     :param b: blue value.
+     :param kwargs: additional arguments for the 'rgb_color_sequence' function.
+     :return: the wrapped string.
+     """
+     return f"{rgb_color_sequence(r, g, b, **kwargs)}" \
+            f"{s}" \
+            f"{CEND}"
+
+
+ def wrap_evenly_spaced_color(s: str, n_of_item: int, n_classes: int, c_map="rainbow") -> str:
+     if s is None or n_of_item is None or n_classes is None:
+         return s
+
+     c_map = plt.cm.get_cmap(c_map)  # select the desired cmap
+     arr = np.linspace(0, 1, n_classes + 1)  # n_classes + 1 evenly spaced values between 0 and 1
+
+     color_vals = c_map(arr[n_of_item])[:-1]
+     color_ansi = rgb_color_sequence(*color_vals, format_type='foreground')
+
+     return f"{color_ansi}{s}{CEND}"
+
+
+ def wrap_with_color_scale(s: str, value: float, min_val: float, max_val: float, c_map=None) -> str:
+     if s is None or value is None or min_val is None or max_val is None or min_val >= max_val:
+         return s
+
+     if c_map is not None:
+         c_map = plt.cm.get_cmap(c_map)  # select the desired cmap
+     else:
+         from matplotlib.colors import LinearSegmentedColormap
+         colors = [
+             np.array([255 / 255, 100 / 255, 128 / 255, 1.0]),  # RGBA values
+             np.array([63 / 255, 197 / 255, 161 / 255, 1.0]),  # RGBA values
+         ]
+         c_map = LinearSegmentedColormap.from_list("custom_cmap", colors, N=256)
+
+     color_vals = c_map((value - min_val) / (max_val - min_val))[:-1]
+     color_ansi = rgb_color_sequence(*color_vals, format_type='foreground')
+
+     return f"{color_ansi}{s}{CEND}"
+
+
+ if __name__ == '__main__':
+     res = wrap_with_color_scale("test", 1.0, 0, 1)
+     print(res)
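
A minimal usage sketch of the helpers above (assuming they ship as gymcts.colorful_console_utils, the module the agent file below imports, and a truecolor-capable terminal; the concrete values are illustrative):

    from gymcts.colorful_console_utils import wrap_evenly_spaced_color, wrap_with_color_scale

    # color the 2nd of 5 items with an evenly spaced colormap color
    print(wrap_evenly_spaced_color("action 2", n_of_item=2, n_classes=5))

    # color a value according to its position on a [min_val, max_val] scale
    print(wrap_with_color_scale("0.75", value=0.75, min_val=0.0, max_val=1.0))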
gymcts/gymcts_agent.py ADDED
@@ -0,0 +1,261 @@
+ import copy
+ import gymnasium as gym
+
+ from typing import TypeVar, Any, SupportsFloat, Callable, Iterator
+
+ from gymcts.gymcts_gym_env import SoloMCTSGymEnv
+ from gymcts.gymcts_naive_wrapper import NaiveSoloMCTSGymEnvWrapper
+ from gymcts.gymcts_node import SoloMCTSNode
+
+ from gymcts.logger import log
+
+ TSoloMCTSNode = TypeVar("TSoloMCTSNode", bound="SoloMCTSNode")
+
+
+ class SoloMCTSAgent:
+     render_tree_after_step: bool = False
+     render_tree_max_depth: int = 2
+     exclude_unvisited_nodes_from_render: bool = False
+     number_of_simulations_per_step: int = 25
+
+     env: SoloMCTSGymEnv
+     search_root_node: SoloMCTSNode  # NOTE: this is not the same as the root of the tree!
+     clear_mcts_tree_after_step: bool
+
+     def __init__(self,
+                  env: SoloMCTSGymEnv,
+                  clear_mcts_tree_after_step: bool = True,
+                  render_tree_after_step: bool = False,
+                  render_tree_max_depth: int = 2,
+                  number_of_simulations_per_step: int = 25,
+                  exclude_unvisited_nodes_from_render: bool = False
+                  ):
+         # check if the action space of the env is discrete
+         if not isinstance(env.action_space, gym.spaces.Discrete):
+             raise ValueError("Action space must be discrete.")
+
+         self.render_tree_after_step = render_tree_after_step
+         self.exclude_unvisited_nodes_from_render = exclude_unvisited_nodes_from_render
+         self.render_tree_max_depth = render_tree_max_depth
+
+         self.number_of_simulations_per_step = number_of_simulations_per_step
+
+         self.env = env
+         self.clear_mcts_tree_after_step = clear_mcts_tree_after_step
+
+         self.search_root_node = SoloMCTSNode(
+             action=None,
+             parent=None,
+             env_reference=env,
+         )
+
+     def navigate_to_leaf(self, from_node: SoloMCTSNode) -> SoloMCTSNode:
+         log.debug(f"Navigate to leaf. from_node: {from_node}")
+         if from_node.terminal:
+             log.debug("Node is terminal. Returning from_node")
+             return from_node
+         if from_node.is_leaf():
+             log.debug("Node is leaf. Returning from_node")
+             return from_node
+
+         temp_node = from_node
+         # NAVIGATION STRATEGY
+         # select the child with the highest UCB score until a leaf is reached
+         while not temp_node.is_leaf():
+             temp_node = max(temp_node.children.values(), key=lambda child: child.ucb_score())
+         log.debug(f"Selected leaf node: {temp_node}")
+         return temp_node
+
+     def expand_node(self, node: SoloMCTSNode) -> None:
+         log.debug(f"expanding node: {node}")
+         # EXPANSION STRATEGY
+         # expand all children
+
+         child_dict = {}
+         for action in node.valid_actions:
+             # reconstruct the state of the node before performing the action
+             self._load_state(node)
+
+             obs, reward, terminal, truncated, _ = self.env.step(action)
+             child_dict[action] = SoloMCTSNode(
+                 action=action,
+                 parent=node,
+                 env_reference=self.env,
+             )
+
+         node.children = child_dict
+
+     def solve(self, num_simulations_per_step: int = None, render_tree_after_step: bool = None) -> list[int]:
+         # perform MCTS steps until a terminal state is reached and return the selected actions
+         if num_simulations_per_step is None:
+             num_simulations_per_step = self.number_of_simulations_per_step
+         if render_tree_after_step is None:
+             render_tree_after_step = self.render_tree_after_step
+
+         log.debug(f"Solving from root node: {self.search_root_node}")
+
+         current_node = self.search_root_node
+
+         action_list = []
+
+         while not current_node.terminal:
+             next_action, current_node = self.perform_mcts_step(num_simulations=num_simulations_per_step,
+                                                                render_tree_after_step=render_tree_after_step)
+             log.info(f"selected action {next_action} after {num_simulations_per_step} simulations.")
+             action_list.append(next_action)
+             log.info(f"current action list: {action_list}")
+
+         log.info(f"Final action list: {action_list}")
+         return action_list
+
+     def _load_state(self, node: SoloMCTSNode) -> None:
+         if isinstance(self.env, NaiveSoloMCTSGymEnvWrapper):
+             # the naive wrapper uses deep copies of the whole env as states,
+             # so restoring a state means replacing the env with a copy of that snapshot
+             self.env = copy.deepcopy(node.state)
+         else:
+             self.env.load_state(node.state)
+
+     def perform_mcts_step(self, search_start_node: SoloMCTSNode = None, num_simulations: int = None,
+                           render_tree_after_step: bool = None) -> tuple[int, SoloMCTSNode]:
+
+         if render_tree_after_step is None:
+             render_tree_after_step = self.render_tree_after_step
+
+         if num_simulations is None:
+             num_simulations = self.number_of_simulations_per_step
+
+         if search_start_node is None:
+             search_start_node = self.search_root_node
+
+         action = self.vanilla_mcts_search(
+             search_start_node=search_start_node,
+             num_simulations=num_simulations,
+         )
+         next_node = search_start_node.children[action]
+
+         if self.clear_mcts_tree_after_step:
+             # to clear memory, all nodes except the selected child are discarded:
+             # the selected child becomes the new search root, its parent is set to None
+             # and its children are reset, all of which is done by the reset method
+             next_node.reset()
+
+         self.search_root_node = next_node
+
+         return action, next_node
+
+     def vanilla_mcts_search(self, search_start_node: SoloMCTSNode = None, num_simulations=10) -> int:
+         log.debug(f"performing one MCTS search step with {num_simulations} simulations")
+         if search_start_node is None:
+             search_start_node = self.search_root_node
+
+         for i in range(num_simulations):
+             log.debug(f"simulation {i}")
+             # selection: navigate to a leaf node
+             leaf_node = self.navigate_to_leaf(from_node=search_start_node)
+
+             if leaf_node.visit_count > 0 and not leaf_node.terminal:
+                 # expansion: expand the leaf and continue from a random child
+                 self.expand_node(leaf_node)
+                 leaf_node = leaf_node.get_random_child()
+
+             # load the state of the leaf node into the env
+             self._load_state(leaf_node)
+
+             # simulation: perform a rollout from the leaf node
+             episode_return = self.env.rollout()
+
+             # backpropagation: update the statistics along the path to the root
+             self.backpropagation(node=leaf_node, episode_return=episode_return)
+
+         if self.render_tree_after_step:
+             self.show_mcts_tree()
+
+         return search_start_node.get_best_action()
+
+     def show_mcts_tree(self, start_node: SoloMCTSNode = None, tree_max_depth: int = None) -> None:
+
+         if start_node is None:
+             start_node = self.search_root_node
+
+         if tree_max_depth is None:
+             tree_max_depth = self.render_tree_max_depth
+
+         print(start_node.__str__(colored=True, action_space_n=self.env.action_space.n))
+         for line in self._generate_mcts_tree(start_node=start_node, depth=tree_max_depth):
+             print(line)
+
+     def show_mcts_tree_from_root(self, tree_max_depth: int = None) -> None:
+         self.show_mcts_tree(start_node=self.search_root_node.get_root(), tree_max_depth=tree_max_depth)
+
+     def backpropagation(self, node: SoloMCTSNode, episode_return: float) -> None:
+         log.debug(f"performing backpropagation from leaf node: {node}")
+         while not node.is_root():
+             # incremental mean update: mean_value = ((mean_value * visit_count) + episode_return) / (visit_count + 1)
+             node.mean_value = node.mean_value + (episode_return - node.mean_value) / (node.visit_count + 1)
+             node.visit_count += 1
+             node.max_value = max(node.max_value, episode_return)
+             node.min_value = min(node.min_value, episode_return)
+             node = node.parent
+         # also update the root node
+         node.mean_value = node.mean_value + (episode_return - node.mean_value) / (node.visit_count + 1)
+         node.visit_count += 1
+         node.max_value = max(node.max_value, episode_return)
+         node.min_value = min(node.min_value, episode_return)
+
+     def _generate_mcts_tree(self, start_node: SoloMCTSNode = None, prefix: str = None,
+                             depth: int = None) -> Iterator[str]:
+
+         if prefix is None:
+             prefix = ""
+         import gymcts.colorful_console_utils as ccu
+
+         if start_node is None:
+             start_node = self.search_root_node
+
+         # prefix components:
+         space = '    '
+         branch = '│   '
+         # pointers:
+         tee = '├── '
+         last = '└── '
+
+         contents = start_node.children.values() if start_node.children is not None else []
+         if self.exclude_unvisited_nodes_from_render:
+             contents = [node for node in contents if node.visit_count > 0]
+         # contents each get pointers that are ├── with a final └── :
+         pointers = [tee for _ in range(len(contents) - 1)] + [last]
+
+         for pointer, current_node in zip(pointers, contents):
+             n_item = current_node.parent.action if current_node.parent is not None else 0
+             n_classes = self.env.action_space.n
+
+             pointer = ccu.wrap_evenly_spaced_color(
+                 s=pointer,
+                 n_of_item=n_item,
+                 n_classes=n_classes,
+             )
+
+             yield prefix + pointer + f"{current_node.__str__(colored=True, action_space_n=n_classes)}"
+             if current_node.children and len(current_node.children):  # extend the prefix and recurse:
+                 # branch (│) for ├── entries, blank space for the last (└── ) entry
+                 extension = branch if tee in pointer else space
+                 extension = ccu.wrap_evenly_spaced_color(
+                     s=extension,
+                     n_of_item=n_item,
+                     n_classes=n_classes,
+                 )
+                 if depth is not None and depth <= 0:
+                     continue
+                 yield from self._generate_mcts_tree(
+                     current_node,
+                     prefix=prefix + extension,
+                     depth=depth - 1 if depth is not None else None
+                 )
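
A minimal usage sketch of the agent (not part of the package diff; FrozenLake is an illustrative choice, and the exact node behaviour depends on gymcts_node.py, which is not shown in this diff). A plain Gymnasium env with a discrete action space is adapted via NaiveSoloMCTSGymEnvWrapper:

    import gymnasium as gym

    from gymcts.gymcts_agent import SoloMCTSAgent
    from gymcts.gymcts_naive_wrapper import NaiveSoloMCTSGymEnvWrapper

    # deterministic toy env with a discrete action space
    env = NaiveSoloMCTSGymEnvWrapper(gym.make("FrozenLake-v1", is_slippery=False))
    env.reset()  # assumption: the env is reset once before the search starts

    agent = SoloMCTSAgent(env=env, number_of_simulations_per_step=50)
    action_list = agent.solve()  # one MCTS step (50 simulations) per env step, until a terminal state
    print(action_list)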
@@ -0,0 +1,107 @@
+ import random
+ import copy
+
+ import numpy as np
+ from typing import TypeVar, Any, SupportsFloat, Callable
+ import gymnasium as gym
+ from gymnasium.core import WrapperActType, WrapperObsType
+ from gymnasium.wrappers import RecordEpisodeStatistics
+
+ from gymcts.gymcts_gym_env import SoloMCTSGymEnv
+
+ from gymcts.logger import log
+
+
+ class DeterministicSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
+     """
+     Wrapper that represents an MCTS state as the action history since reset and restores it by
+     replaying those actions. This only works for deterministic environments.
+     """
+
+     _terminal_flag: bool = False
+     _last_reward: SupportsFloat = 0
+     _step_tuple: tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]] = None
+
+     _action_mask_fn: Callable[[gym.Env], np.ndarray] | None = None
+
+     def __init__(
+             self,
+             env,
+             action_mask_fn: str | Callable[[gym.Env], np.ndarray] | None = None,
+             buffer_length: int = 100,
+     ):
+         # wrap with RecordEpisodeStatistics so that rollouts can read the episode return from the info dict
+         env = RecordEpisodeStatistics(env, buffer_length=buffer_length)
+
+         gym.Wrapper.__init__(self, env)
+
+         self._wrapper_action_history = []
+
+         # assert that the action space is discrete
+         if not isinstance(env.action_space, gym.spaces.Discrete):
+             raise ValueError("Only discrete action spaces are supported.")
+
+         if action_mask_fn is not None:
+             # copy of the stable-baselines3 contrib implementation
+             if isinstance(action_mask_fn, str):
+                 found_method = getattr(self.env, action_mask_fn)
+                 if not callable(found_method):
+                     raise ValueError(f"Environment attribute {action_mask_fn} is not a method")
+
+                 self._action_mask_fn = found_method
+             else:
+                 self._action_mask_fn = action_mask_fn
+
+     def load_state(self, state: list[int]) -> None:
+         # reconstruct the state by resetting the env and replaying the recorded actions
+         self.env.reset()
+         self._wrapper_action_history = []
+         self._terminal_flag = False
+         self._step_tuple = None
+
+         for action in state:
+             self.step(action)
+
+     def is_terminal(self) -> bool:
+         if not len(self.get_valid_actions()):
+             return True
+         else:
+             return self._terminal_flag
+
+     def action_masks(self) -> np.ndarray | None:
+         return self._action_mask_fn(self.env) if self._action_mask_fn is not None else None
+
+     def get_valid_actions(self) -> list[int]:
+         if self._action_mask_fn is None:
+             action_space: gym.spaces.Discrete = self.env.action_space  # type hinting
+             return list(range(action_space.n))
+         else:
+             return [i for i, mask in enumerate(self.action_masks()) if mask]
+
+     def rollout(self) -> float:
+         log.debug("performing rollout")
+         # random rollout: perform random valid actions until a terminal state is reached
+         is_terminal_state = self.is_terminal()
+
+         if is_terminal_state:
+             _, _, _, _, info = self._step_tuple
+             episode_return = info["episode"]["r"]
+             return episode_return
+
+         while not is_terminal_state:
+             action = random.choice(self.get_valid_actions())
+             _obs, _reward, terminated, truncated, info = self.step(action)
+             is_terminal_state = terminated or truncated
+
+         episode_return = info["episode"]["r"]
+         log.debug(f"Rollout return: {episode_return}")
+         return episode_return
+
+     def get_state(self) -> list[int]:
+         return self._wrapper_action_history.copy()
+
+     def step(
+             self, action: WrapperActType
+     ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
+         step_tuple = self.env.step(action)
+         self._wrapper_action_history.append(action)
+         obs, reward, terminated, truncated, info = step_tuple
+
+         self._terminal_flag = terminated or truncated
+         self._step_tuple = step_tuple
+
+         return step_tuple
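
The wrapper above encodes an MCTS state as the list of actions taken since reset and rebuilds it by replaying them, which is only valid for deterministic environments. A small round-trip sketch (the wrapper's module path is not shown in this hunk; the env name is illustrative):

    import gymnasium as gym

    # import DeterministicSoloMCTSGymEnvWrapper from its gymcts module

    env = DeterministicSoloMCTSGymEnvWrapper(gym.make("CliffWalking-v0"))
    env.reset()
    env.step(0)
    env.step(1)

    state = env.get_state()   # [0, 1]: the recorded action history
    env.load_state(state)     # reset() followed by a replay of the recorded actions
    assert env.get_state() == state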
gymcts/gymcts_gym_env.py ADDED
@@ -0,0 +1,28 @@
+ from typing import TypeVar, Any, SupportsFloat, Callable
+ from abc import ABC, abstractmethod
+ import gymnasium as gym
+
+ TSoloMCTSNode = TypeVar("TSoloMCTSNode", bound="SoloMCTSNode")
+
+
+ class SoloMCTSGymEnv(ABC, gym.Env):
+     """Interface a Gymnasium environment has to provide so that it can be searched by the MCTS agent."""
+
+     @abstractmethod
+     def get_state(self) -> Any:
+         """Return a representation of the current env state that can later be restored via load_state."""
+         pass
+
+     @abstractmethod
+     def load_state(self, state: Any) -> None:
+         """Restore the env to a state previously returned by get_state."""
+         pass
+
+     @abstractmethod
+     def is_terminal(self) -> bool:
+         """Return True if the current state is terminal."""
+         pass
+
+     @abstractmethod
+     def get_valid_actions(self) -> list[int]:
+         """Return the list of valid actions in the current state."""
+         pass
+
+     @abstractmethod
+     def rollout(self) -> float:
+         """Perform a rollout from the current state and return the episode return."""
+         pass
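
A minimal sketch of implementing the interface directly, using a hypothetical toy environment (not part of the package): add 1 or 2 per step and try to hit exactly 10.

    import random

    import gymnasium as gym

    from gymcts.gymcts_gym_env import SoloMCTSGymEnv


    class CountToTenEnv(SoloMCTSGymEnv):  # SoloMCTSGymEnv already derives from gym.Env
        def __init__(self):
            self.action_space = gym.spaces.Discrete(2)      # action 0 -> +1, action 1 -> +2
            self.observation_space = gym.spaces.Discrete(13)
            self.total = 0

        def reset(self, *, seed=None, options=None):
            super().reset(seed=seed)
            self.total = 0
            return self.total, {}

        def step(self, action):
            self.total += int(action) + 1
            terminated = self.total >= 10
            reward = 1.0 if self.total == 10 else 0.0
            return self.total, reward, terminated, False, {}

        # SoloMCTSGymEnv interface
        def get_state(self):
            return self.total

        def load_state(self, state):
            self.total = state

        def is_terminal(self):
            return self.total >= 10

        def get_valid_actions(self):
            return [] if self.is_terminal() else [0, 1]

        def rollout(self):
            # random rollout from the current state; 1.0 iff the counter lands exactly on 10
            while not self.is_terminal():
                self.step(random.choice(self.get_valid_actions()))
            return 1.0 if self.total == 10 else 0.0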
gymcts/gymcts_naive_wrapper.py ADDED
@@ -0,0 +1,114 @@
+ import random
+ import copy
+
+ import numpy as np
+ from typing import TypeVar, Any, SupportsFloat, Callable
+ import gymnasium as gym
+ from gymnasium.core import WrapperActType, WrapperObsType
+ from gymnasium.wrappers import RecordEpisodeStatistics
+
+ from gymcts.gymcts_gym_env import SoloMCTSGymEnv
+
+ from gymcts.logger import log
+
+
+ class NaiveSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
+     """Wrapper that uses deep copies of the entire (wrapped) env as MCTS states."""
+
+     _terminal_flag: bool = False
+     _last_reward: SupportsFloat = 0
+     _step_tuple: tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]] = None
+
+     _action_mask_fn: Callable[[gym.Env], np.ndarray] | None = None
+
+     def is_terminal(self) -> bool:
+         return self._terminal_flag
+
+     def load_state(self, state: Any) -> None:
+         msg = """
+         The NaiveSoloMCTSGymEnvWrapper uses deep copies of the entire env as the state.
+         Loading a state means replacing the env with the 'state' (the copy provided by 'get_state').
+         'self' cannot be replaced with another object from within a method, therefore the replacement
+         is performed by the MCTS agent instead.
+         """
+         raise NotImplementedError(msg)
+
+     def __init__(self,
+                  env,
+                  action_mask_fn: str | Callable[[gym.Env], np.ndarray] | None = None,
+                  buffer_length: int = 100,
+                  record_video: bool = False,
+                  ):
+         # wrap with RecordEpisodeStatistics so that rollouts can read the episode return from the info dict
+         env = RecordEpisodeStatistics(env, buffer_length=buffer_length)
+
+         gym.Wrapper.__init__(self, env)
+
+         # assert that the action space is discrete
+         if not isinstance(env.action_space, gym.spaces.Discrete):
+             raise ValueError("Only discrete action spaces are supported.")
+
+         if action_mask_fn is not None:
+             # copy of the stable-baselines3 contrib implementation
+             if isinstance(action_mask_fn, str):
+                 found_method = getattr(self.env, action_mask_fn)
+                 if not callable(found_method):
+                     raise ValueError(f"Environment attribute {action_mask_fn} is not a method")
+
+                 self._action_mask_fn = found_method
+             else:
+                 self._action_mask_fn = action_mask_fn
+
+     def get_state(self) -> Any:
+         log.debug("getting state")
+         original_state = self
+         copied_state = copy.deepcopy(self)
+
+         log.debug(f"original state memory location: {hex(id(original_state))}")
+         log.debug(f"copied memory location: {hex(id(copied_state))}")
+
+         return copied_state
+
+     def action_masks(self) -> np.ndarray | None:
+         return self._action_mask_fn(self.env) if self._action_mask_fn is not None else None
+
+     def get_valid_actions(self) -> list[int]:
+         if self._action_mask_fn is None:
+             action_space: gym.spaces.Discrete = self.env.action_space  # type hinting
+             return list(range(action_space.n))
+         else:
+             return [i for i, mask in enumerate(self.action_masks()) if mask]
+
+     def step(
+             self, action: WrapperActType
+     ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
+         step_tuple = self.env.step(action)
+
+         obs, reward, terminated, truncated, info = step_tuple
+         self._terminal_flag = terminated or truncated
+         self._step_tuple = step_tuple
+
+         return step_tuple
+
+     def rollout(self) -> float:
+         log.debug("performing rollout")
+         # random rollout: perform random valid actions until a terminal state is reached
+         is_terminal_state = self.is_terminal()
+
+         if is_terminal_state:
+             _, _, _, _, info = self._step_tuple
+             episode_return = info["episode"]["r"]
+             return episode_return
+
+         while not is_terminal_state:
+             action = random.choice(self.get_valid_actions())
+             _obs, _reward, terminated, truncated, info = self.step(action)
+             is_terminal_state = terminated or truncated
+
+         episode_return = info["episode"]["r"]
+         log.debug(f"Rollout return: {episode_return}")
+         return episode_return
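
For contrast with the deterministic wrapper, a short sketch of how this naive wrapper snapshots state (the env name is illustrative): get_state returns a deep copy of the whole wrapper, and SoloMCTSAgent._load_state restores a node by swapping a fresh copy of that snapshot in as its env.

    import gymnasium as gym

    from gymcts.gymcts_naive_wrapper import NaiveSoloMCTSGymEnvWrapper

    env = NaiveSoloMCTSGymEnvWrapper(gym.make("FrozenLake-v1", is_slippery=False))
    env.reset()
    env.step(2)

    snapshot = env.get_state()   # deep copy of the wrapper and the wrapped env
    env.step(1)                  # the snapshot is unaffected by further steps

    env = snapshot               # restoring = replacing the env reference with the copy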