gymcts 1.0.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gymcts/colorful_console_utils.py +26 -3
- gymcts/{gymcts_deterministic_wrapper.py → gymcts_action_history_wrapper.py} +74 -4
- gymcts/gymcts_agent.py +29 -69
- gymcts/{gymcts_naive_wrapper.py → gymcts_deepcopy_wrapper.py} +60 -3
- gymcts/gymcts_distributed_agent.py +299 -0
- gymcts/gymcts_env_abc.py +61 -0
- gymcts/gymcts_node.py +107 -44
- gymcts/gymcts_tree_plotter.py +96 -0
- gymcts/logger.py +1 -4
- {gymcts-1.0.0.dist-info → gymcts-1.2.1.dist-info}/METADATA +54 -56
- gymcts-1.2.1.dist-info/RECORD +15 -0
- {gymcts-1.0.0.dist-info → gymcts-1.2.1.dist-info}/WHEEL +1 -1
- gymcts/gymcts_gym_env.py +0 -28
- gymcts-1.0.0.dist-info/RECORD +0 -13
- {gymcts-1.0.0.dist-info → gymcts-1.2.1.dist-info/licenses}/LICENSE +0 -0
- {gymcts-1.0.0.dist-info → gymcts-1.2.1.dist-info}/top_level.txt +0 -0
gymcts/colorful_console_utils.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
1
3
|
import matplotlib.pyplot as plt
|
|
2
4
|
import numpy as np
|
|
3
5
|
|
|
@@ -103,8 +105,19 @@ def wrap_with_color_codes(s: object, /, r: int | float, g: int | float, b: int |
|
|
|
103
105
|
f"{CEND}"
|
|
104
106
|
|
|
105
107
|
|
|
106
|
-
|
|
107
|
-
|
|
108
|
+
def wrap_evenly_spaced_color(s: Any, n_of_item: int, n_classes: int, c_map="rainbow") -> str:
|
|
109
|
+
"""
|
|
110
|
+
Wraps a string with a color scale (a matplotlib c_map) based on the n_of_item and n_classes.
|
|
111
|
+
This function is used to color code the available actions in the MCTS tree visualisation.
|
|
112
|
+
The children of the MCTS tree are colored based on their action for a clearer visualisation.
|
|
113
|
+
|
|
114
|
+
:param s: the string (or object) to be wrapped. objects are converted to string (using the __str__ function).
|
|
115
|
+
:param n_of_item: the index of the item to be colored. In a mcts tree, this is the (parent-)action of the node.
|
|
116
|
+
:param n_classes: the number of classes (or items) to be colored. In a mcts tree, this is the number of available actions.
|
|
117
|
+
:param c_map: the colormap to be used (default is 'rainbow').
|
|
118
|
+
The colormap can be any matplotlib colormap, e.g. 'viridis', 'plasma', 'inferno', 'magma', 'cividis'.
|
|
119
|
+
:return: a string that contains the color-codes (prefix and suffix) and the string s in between.
|
|
120
|
+
"""
|
|
108
121
|
if s is None or n_of_item is None or n_classes is None:
|
|
109
122
|
return s
|
|
110
123
|
|
|
@@ -117,7 +130,17 @@ def wrap_evenly_spaced_color(s: str, n_of_item:int, n_classes:int, c_map="rainbo
|
|
|
117
130
|
return f"{color_asni}{s}{CEND}"
|
|
118
131
|
|
|
119
132
|
|
|
120
|
-
def wrap_with_color_scale(s: str, value: float, min_val:float, max_val:float, c_map=None) -> str:
|
|
133
|
+
def wrap_with_color_scale(s: str, value: float, min_val: float, max_val: float, c_map=None) -> str:
|
|
134
|
+
"""
|
|
135
|
+
Wraps a string with a color scale (a matplotlib c_map) based on the value, min_val, and max_val.
|
|
136
|
+
|
|
137
|
+
:param s: the string to be wrapped
|
|
138
|
+
:param value: the value to be mapped to a color
|
|
139
|
+
:param min_val: the minimum value of the scale
|
|
140
|
+
:param max_val: the maximum value of the scale
|
|
141
|
+
:param c_map: the colormap to be used (default is 'rainbow')
|
|
142
|
+
:return:
|
|
143
|
+
"""
|
|
121
144
|
if s is None or min_val is None or max_val is None or min_val >= max_val:
|
|
122
145
|
return s
|
|
123
146
|
|
|
@@ -1,18 +1,32 @@
|
|
|
1
1
|
import random
|
|
2
|
-
import copy
|
|
3
2
|
|
|
4
3
|
import numpy as np
|
|
5
|
-
from typing import
|
|
4
|
+
from typing import Any, SupportsFloat, Callable
|
|
6
5
|
import gymnasium as gym
|
|
7
6
|
from gymnasium.core import WrapperActType, WrapperObsType
|
|
8
7
|
from gymnasium.wrappers import RecordEpisodeStatistics
|
|
9
8
|
|
|
10
|
-
from gymcts.
|
|
9
|
+
from gymcts.gymcts_env_abc import GymctsABC
|
|
11
10
|
|
|
12
11
|
from gymcts.logger import log
|
|
13
12
|
|
|
14
13
|
|
|
15
|
-
class
|
|
14
|
+
class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
|
|
15
|
+
"""
|
|
16
|
+
A wrapper for gym environments that implements the GymctsABC interface.
|
|
17
|
+
It uses the action history as state representation.
|
|
18
|
+
Please note that this is not the most efficient way to implement the state representation.
|
|
19
|
+
It is supposed to be used to see if your use-case works well with the MCTS algorithm.
|
|
20
|
+
If it does, you can consider implementing all GymctsABC methods in a more efficient way.
|
|
21
|
+
The action history is a list of actions taken in the environment.
|
|
22
|
+
The state is represented as a list of actions taken in the environment.
|
|
23
|
+
The state is used to restore the environment using the load_state method.
|
|
24
|
+
|
|
25
|
+
It is supposed to be used to see if your use-case works well with the MCTS algorithm.
|
|
26
|
+
If it does, you can consider implementing all GymctsABC methods in a more efficient way.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
# helper attributes for the wrapper
|
|
16
30
|
_terminal_flag: bool = False
|
|
17
31
|
_last_reward: SupportsFloat = 0
|
|
18
32
|
_step_tuple: tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]] = None
|
|
@@ -25,6 +39,17 @@ class DeterministicSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
|
|
|
25
39
|
action_mask_fn: str | Callable[[gym.Env], np.ndarray] | None = None,
|
|
26
40
|
buffer_length: int = 100,
|
|
27
41
|
):
|
|
42
|
+
"""
|
|
43
|
+
A wrapper for gym environments that implements the GymctsABC interface.
|
|
44
|
+
It uses the action history as state representation.
|
|
45
|
+
Please note that this is not the most efficient way to implement the state representation.
|
|
46
|
+
It is supposed to be used to see if your use-case works well with the MCTS algorithm.
|
|
47
|
+
If it does, you can consider implementing all GymctsABC methods in a more efficient way.
|
|
48
|
+
|
|
49
|
+
:param env: the environment to wrap
|
|
50
|
+
:param action_mask_fn: a function that takes the environment as input and returns a mask of valid actions
|
|
51
|
+
:param buffer_length: the length of the buffer for recording episodes for determining their rollout returns
|
|
52
|
+
"""
|
|
28
53
|
# wrap with RecordEpisodeStatistics if it is not already wrapped
|
|
29
54
|
env = RecordEpisodeStatistics(env, buffer_length=buffer_length)
|
|
30
55
|
|
|
@@ -48,6 +73,17 @@ class DeterministicSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
|
|
|
48
73
|
self._action_mask_fn = action_mask_fn
|
|
49
74
|
|
|
50
75
|
def load_state(self, state: list[int]) -> None:
|
|
76
|
+
"""
|
|
77
|
+
Loads the state of the environment. The state is a list of actions taken in the environment.
|
|
78
|
+
|
|
79
|
+
The environment is reset and all actions in the state are performed in order to restore the environment to the
|
|
80
|
+
same state.
|
|
81
|
+
|
|
82
|
+
This works only for deterministic environments!
|
|
83
|
+
|
|
84
|
+
:param state: the state to load
|
|
85
|
+
:return: None
|
|
86
|
+
"""
|
|
51
87
|
self.env.reset()
|
|
52
88
|
self._wrapper_action_history = []
|
|
53
89
|
|
|
@@ -56,15 +92,30 @@ class DeterministicSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
|
|
|
56
92
|
self._wrapper_action_history.append(action)
|
|
57
93
|
|
|
58
94
|
def is_terminal(self) -> bool:
|
|
95
|
+
"""
|
|
96
|
+
Returns True if the environment is in a terminal state, False otherwise.
|
|
97
|
+
|
|
98
|
+
:return:
|
|
99
|
+
"""
|
|
59
100
|
if not len(self.get_valid_actions()):
|
|
60
101
|
return True
|
|
61
102
|
else:
|
|
62
103
|
return self._terminal_flag
|
|
63
104
|
|
|
64
105
|
def action_masks(self) -> np.ndarray | None:
|
|
106
|
+
"""
|
|
107
|
+
Returns the action masks for the environment. If the action_mask_fn is not set, it returns None.
|
|
108
|
+
|
|
109
|
+
:return:
|
|
110
|
+
"""
|
|
65
111
|
return self._action_mask_fn(self.env) if self._action_mask_fn is not None else None
|
|
66
112
|
|
|
67
113
|
def get_valid_actions(self) -> list[int]:
|
|
114
|
+
"""
|
|
115
|
+
Returns a list of valid actions for the current state of the environment.
|
|
116
|
+
|
|
117
|
+
:return: a list of valid actions
|
|
118
|
+
"""
|
|
68
119
|
if self._action_mask_fn is None:
|
|
69
120
|
action_space: gym.spaces.Discrete = self.env.action_space # Type hinting
|
|
70
121
|
return list(range(action_space.n))
|
|
@@ -72,6 +123,12 @@ class DeterministicSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
|
|
|
72
123
|
return [i for i, mask in enumerate(self.action_masks()) if mask]
|
|
73
124
|
|
|
74
125
|
def rollout(self) -> float:
|
|
126
|
+
"""
|
|
127
|
+
Performs a random rollout from the current state of the environment and returns the return (sum of rewards)
|
|
128
|
+
of the rollout.
|
|
129
|
+
|
|
130
|
+
:return: the return of the rollout
|
|
131
|
+
"""
|
|
75
132
|
log.debug("performing rollout")
|
|
76
133
|
# random rollout
|
|
77
134
|
# perform random valid actions until terminal
|
|
@@ -92,11 +149,24 @@ class DeterministicSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
|
|
|
92
149
|
return episode_return
|
|
93
150
|
|
|
94
151
|
def get_state(self) -> list[int]:
|
|
152
|
+
"""
|
|
153
|
+
Returns the current state of the environment. The state is a list of actions taken in the environment,
|
|
154
|
+
namely all actions that have been taken in the environment so far (since the last reset).
|
|
155
|
+
|
|
156
|
+
:return: a list of actions taken in the environment
|
|
157
|
+
"""
|
|
158
|
+
|
|
95
159
|
return self._wrapper_action_history.copy()
|
|
96
160
|
|
|
97
161
|
def step(
|
|
98
162
|
self, action: WrapperActType
|
|
99
163
|
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
|
164
|
+
"""
|
|
165
|
+
Performs a step in the environment. It adds the action to the action history and updates the terminal flag.
|
|
166
|
+
|
|
167
|
+
:param action: action to perform in the environment
|
|
168
|
+
:return: the step tuple of the environment (obs, reward, terminated, truncated, info)
|
|
169
|
+
"""
|
|
100
170
|
step_tuple = self.env.step(action)
|
|
101
171
|
self._wrapper_action_history.append(action)
|
|
102
172
|
obs, reward, terminated, truncated, info = step_tuple
|
gymcts/gymcts_agent.py
CHANGED
|
@@ -1,29 +1,31 @@
|
|
|
1
1
|
import copy
|
|
2
|
+
import random
|
|
2
3
|
import gymnasium as gym
|
|
3
4
|
|
|
4
5
|
from typing import TypeVar, Any, SupportsFloat, Callable
|
|
5
6
|
|
|
6
|
-
from gymcts.
|
|
7
|
-
from gymcts.
|
|
8
|
-
from gymcts.gymcts_node import
|
|
7
|
+
from gymcts.gymcts_env_abc import GymctsABC
|
|
8
|
+
from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
|
|
9
|
+
from gymcts.gymcts_node import GymctsNode
|
|
10
|
+
from gymcts.gymcts_tree_plotter import _generate_mcts_tree
|
|
9
11
|
|
|
10
12
|
from gymcts.logger import log
|
|
11
13
|
|
|
12
14
|
TSoloMCTSNode = TypeVar("TSoloMCTSNode", bound="SoloMCTSNode")
|
|
13
15
|
|
|
14
16
|
|
|
15
|
-
class
|
|
17
|
+
class GymctsAgent:
|
|
16
18
|
render_tree_after_step: bool = False
|
|
17
19
|
render_tree_max_depth: int = 2
|
|
18
20
|
exclude_unvisited_nodes_from_render: bool = False
|
|
19
21
|
number_of_simulations_per_step: int = 25
|
|
20
22
|
|
|
21
|
-
env:
|
|
22
|
-
search_root_node:
|
|
23
|
+
env: GymctsABC
|
|
24
|
+
search_root_node: GymctsNode # NOTE: this is not the same as the root of the tree!
|
|
23
25
|
clear_mcts_tree_after_step: bool
|
|
24
26
|
|
|
25
27
|
def __init__(self,
|
|
26
|
-
env:
|
|
28
|
+
env: GymctsABC,
|
|
27
29
|
clear_mcts_tree_after_step: bool = True,
|
|
28
30
|
render_tree_after_step: bool = False,
|
|
29
31
|
render_tree_max_depth: int = 2,
|
|
@@ -43,13 +45,13 @@ class SoloMCTSAgent:
|
|
|
43
45
|
self.env = env
|
|
44
46
|
self.clear_mcts_tree_after_step = clear_mcts_tree_after_step
|
|
45
47
|
|
|
46
|
-
self.search_root_node =
|
|
48
|
+
self.search_root_node = GymctsNode(
|
|
47
49
|
action=None,
|
|
48
50
|
parent=None,
|
|
49
51
|
env_reference=env,
|
|
50
52
|
)
|
|
51
53
|
|
|
52
|
-
def navigate_to_leaf(self, from_node:
|
|
54
|
+
def navigate_to_leaf(self, from_node: GymctsNode) -> GymctsNode:
|
|
53
55
|
log.debug(f"Navigate to leaf. from_node: {from_node}")
|
|
54
56
|
if from_node.terminal:
|
|
55
57
|
log.debug("Node is terminal. Returning from_node")
|
|
@@ -62,11 +64,14 @@ class SoloMCTSAgent:
|
|
|
62
64
|
# NAVIGATION STRATEGY
|
|
63
65
|
# select child with highest UCB score
|
|
64
66
|
while not temp_node.is_leaf():
|
|
65
|
-
|
|
67
|
+
children = list(temp_node.children.values())
|
|
68
|
+
max_ucb_score = max(child.ucb_score() for child in children)
|
|
69
|
+
best_children = [child for child in children if child.ucb_score() == max_ucb_score]
|
|
70
|
+
temp_node = random.choice(best_children)
|
|
66
71
|
log.debug(f"Selected leaf node: {temp_node}")
|
|
67
72
|
return temp_node
|
|
68
73
|
|
|
69
|
-
def expand_node(self, node:
|
|
74
|
+
def expand_node(self, node: GymctsNode) -> None:
|
|
70
75
|
log.debug(f"expanding node: {node}")
|
|
71
76
|
# EXPANSION STRATEGY
|
|
72
77
|
# expand all children
|
|
@@ -78,7 +83,7 @@ class SoloMCTSAgent:
|
|
|
78
83
|
self._load_state(node)
|
|
79
84
|
|
|
80
85
|
obs, reward, terminal, truncated, _ = self.env.step(action)
|
|
81
|
-
child_dict[action] =
|
|
86
|
+
child_dict[action] = GymctsNode(
|
|
82
87
|
action=action,
|
|
83
88
|
parent=node,
|
|
84
89
|
env_reference=self.env,
|
|
@@ -110,14 +115,14 @@ class SoloMCTSAgent:
|
|
|
110
115
|
# restore state of current node
|
|
111
116
|
return action_list
|
|
112
117
|
|
|
113
|
-
def _load_state(self, node:
|
|
114
|
-
if isinstance(self.env,
|
|
118
|
+
def _load_state(self, node: GymctsNode) -> None:
|
|
119
|
+
if isinstance(self.env, DeepCopyMCTSGymEnvWrapper):
|
|
115
120
|
self.env = copy.deepcopy(node.state)
|
|
116
121
|
else:
|
|
117
122
|
self.env.load_state(node.state)
|
|
118
123
|
|
|
119
|
-
def perform_mcts_step(self, search_start_node:
|
|
120
|
-
render_tree_after_step: bool = None) -> tuple[int,
|
|
124
|
+
def perform_mcts_step(self, search_start_node: GymctsNode = None, num_simulations: int = None,
|
|
125
|
+
render_tree_after_step: bool = None) -> tuple[int, GymctsNode]:
|
|
121
126
|
|
|
122
127
|
if render_tree_after_step is None:
|
|
123
128
|
render_tree_after_step = self.render_tree_after_step
|
|
@@ -149,7 +154,7 @@ class SoloMCTSAgent:
|
|
|
149
154
|
|
|
150
155
|
return action, next_node
|
|
151
156
|
|
|
152
|
-
def vanilla_mcts_search(self, search_start_node:
|
|
157
|
+
def vanilla_mcts_search(self, search_start_node: GymctsNode = None, num_simulations=10) -> int:
|
|
153
158
|
log.debug(f"performing one MCTS search step with {num_simulations} simulations")
|
|
154
159
|
if search_start_node is None:
|
|
155
160
|
search_start_node = self.search_root_node
|
|
@@ -178,7 +183,7 @@ class SoloMCTSAgent:
|
|
|
178
183
|
|
|
179
184
|
return search_start_node.get_best_action()
|
|
180
185
|
|
|
181
|
-
def show_mcts_tree(self, start_node:
|
|
186
|
+
def show_mcts_tree(self, start_node: GymctsNode = None, tree_max_depth: int = None) -> None:
|
|
182
187
|
|
|
183
188
|
if start_node is None:
|
|
184
189
|
start_node = self.search_root_node
|
|
@@ -187,13 +192,17 @@ class SoloMCTSAgent:
|
|
|
187
192
|
tree_max_depth = self.render_tree_max_depth
|
|
188
193
|
|
|
189
194
|
print(start_node.__str__(colored=True, action_space_n=self.env.action_space.n))
|
|
190
|
-
for line in
|
|
195
|
+
for line in _generate_mcts_tree(
|
|
196
|
+
start_node=start_node,
|
|
197
|
+
depth=tree_max_depth,
|
|
198
|
+
action_space_n=self.env.action_space.n,
|
|
199
|
+
):
|
|
191
200
|
print(line)
|
|
192
201
|
|
|
193
202
|
def show_mcts_tree_from_root(self, tree_max_depth: int = None) -> None:
|
|
194
203
|
self.show_mcts_tree(start_node=self.search_root_node.get_root(), tree_max_depth=tree_max_depth)
|
|
195
204
|
|
|
196
|
-
def backpropagation(self, node:
|
|
205
|
+
def backpropagation(self, node: GymctsNode, episode_return: float) -> None:
|
|
197
206
|
log.debug(f"performing backpropagation from leaf node: {node}")
|
|
198
207
|
while not node.is_root():
|
|
199
208
|
# node.mean_value = ((node.mean_value * node.visit_count) + episode_return) / (node.visit_count + 1)
|
|
@@ -209,53 +218,4 @@ class SoloMCTSAgent:
|
|
|
209
218
|
node.max_value = max(node.max_value, episode_return)
|
|
210
219
|
node.min_value = min(node.min_value, episode_return)
|
|
211
220
|
|
|
212
|
-
def _generate_mcts_tree(self, start_node: SoloMCTSNode = None, prefix: str = None, depth: int = None) -> list[str]:
|
|
213
221
|
|
|
214
|
-
if prefix is None:
|
|
215
|
-
prefix = ""
|
|
216
|
-
import gymcts.colorful_console_utils as ccu
|
|
217
|
-
|
|
218
|
-
if start_node is None:
|
|
219
|
-
start_node = self.search_root_node
|
|
220
|
-
|
|
221
|
-
# prefix components:
|
|
222
|
-
space = ' '
|
|
223
|
-
branch = '│ '
|
|
224
|
-
# pointers:
|
|
225
|
-
tee = '├── '
|
|
226
|
-
last = '└── '
|
|
227
|
-
|
|
228
|
-
contents = start_node.children.values() if start_node.children is not None else []
|
|
229
|
-
if self.exclude_unvisited_nodes_from_render:
|
|
230
|
-
contents = [node for node in contents if node.visit_count > 0]
|
|
231
|
-
# contents each get pointers that are ├── with a final └── :
|
|
232
|
-
# pointers = [tee] * (len(contents) - 1) + [last]
|
|
233
|
-
pointers = [tee for _ in range(len(contents) - 1)] + [last]
|
|
234
|
-
|
|
235
|
-
for pointer, current_node in zip(pointers, contents):
|
|
236
|
-
n_item = current_node.parent.action if current_node.parent is not None else 0
|
|
237
|
-
n_classes = self.env.action_space.n
|
|
238
|
-
|
|
239
|
-
pointer = ccu.wrap_evenly_spaced_color(
|
|
240
|
-
s=pointer,
|
|
241
|
-
n_of_item=n_item,
|
|
242
|
-
n_classes=n_classes,
|
|
243
|
-
)
|
|
244
|
-
|
|
245
|
-
yield prefix + pointer + f"{current_node.__str__(colored=True, action_space_n=n_classes)}"
|
|
246
|
-
if current_node.children and len(current_node.children): # extend the prefix and recurse:
|
|
247
|
-
# extension = branch if pointer == tee else space
|
|
248
|
-
extension = branch if tee in pointer else space
|
|
249
|
-
# i.e. space because last, └── , above so no more |
|
|
250
|
-
extension = ccu.wrap_evenly_spaced_color(
|
|
251
|
-
s=extension,
|
|
252
|
-
n_of_item=n_item,
|
|
253
|
-
n_classes=n_classes,
|
|
254
|
-
)
|
|
255
|
-
if depth is not None and depth <= 0:
|
|
256
|
-
continue
|
|
257
|
-
yield from self._generate_mcts_tree(
|
|
258
|
-
current_node,
|
|
259
|
-
prefix=prefix + extension,
|
|
260
|
-
depth=depth - 1 if depth is not None else None
|
|
261
|
-
)
|
|
@@ -7,14 +7,21 @@ import gymnasium as gym
|
|
|
7
7
|
from gymnasium.core import WrapperActType, WrapperObsType
|
|
8
8
|
from gymnasium.wrappers import RecordEpisodeStatistics
|
|
9
9
|
|
|
10
|
-
from gymcts.
|
|
10
|
+
from gymcts.gymcts_env_abc import GymctsABC
|
|
11
11
|
|
|
12
12
|
from gymcts.logger import log
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
class
|
|
16
|
-
|
|
15
|
+
class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
|
|
16
|
+
"""
|
|
17
|
+
A wrapper for gym environments that implements the GymctsABC interface.
|
|
18
|
+
It uses deep copies of the environment as state representation.
|
|
19
|
+
Please note that this is not the most efficient way to implement the state representation.
|
|
20
|
+
It is supposed to be used to see if your use-case works well with the MCTS algorithm.
|
|
21
|
+
If it does, you can consider implementing all GymctsABC methods in a more efficient way.
|
|
22
|
+
"""
|
|
17
23
|
|
|
24
|
+
# helper attributes for the wrapper
|
|
18
25
|
_terminal_flag:bool = False
|
|
19
26
|
_last_reward: SupportsFloat = 0
|
|
20
27
|
_step_tuple: tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]] = None
|
|
@@ -22,9 +29,21 @@ class NaiveSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
|
|
|
22
29
|
_action_mask_fn: Callable[[gym.Env], np.ndarray] | None = None
|
|
23
30
|
|
|
24
31
|
def is_terminal(self) -> bool:
|
|
32
|
+
"""
|
|
33
|
+
Returns True if the environment is in a terminal state, False otherwise.
|
|
34
|
+
|
|
35
|
+
:return: True if the environment is in a terminal state, False otherwise.
|
|
36
|
+
"""
|
|
25
37
|
return self._terminal_flag
|
|
26
38
|
|
|
27
39
|
def load_state(self, state: Any) -> None:
|
|
40
|
+
"""
|
|
41
|
+
The load_state method is not implemented. The state is loaded by replacing the env with the 'state' (the copy
|
|
42
|
+
provided by 'get_state'). 'self' in a method cannot be replaced with another object (as far as I know).
|
|
43
|
+
|
|
44
|
+
:param state: a deepcopy of the environment
|
|
45
|
+
:return: None
|
|
46
|
+
"""
|
|
28
47
|
msg = """
|
|
29
48
|
The NaiveSoloMCTSGymEnvWrapper uses deepcopies of the entire env as the state.
|
|
30
49
|
The loading of the state is done by replacing the env with the 'state' (the copy provided my 'get_state').
|
|
@@ -39,6 +58,16 @@ class NaiveSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
|
|
|
39
58
|
buffer_length: int = 100,
|
|
40
59
|
record_video: bool = False,
|
|
41
60
|
):
|
|
61
|
+
"""
|
|
62
|
+
The constructor of the wrapper. It wraps the environment with RecordEpisodeStatistics and checks if the action
|
|
63
|
+
space is discrete. It also checks if the action_mask_fn is a string or a callable. If it is a string, it tries to
|
|
64
|
+
find the method in the environment. If it is a callable, it assigns it to the _action_mask_fn attribute.
|
|
65
|
+
|
|
66
|
+
:param env: the environment to wrap
|
|
67
|
+
:param action_mask_fn:
|
|
68
|
+
:param buffer_length:
|
|
69
|
+
:param record_video:
|
|
70
|
+
"""
|
|
42
71
|
# wrap with RecordEpisodeStatistics if it is not already wrapped
|
|
43
72
|
env = RecordEpisodeStatistics(env, buffer_length=buffer_length)
|
|
44
73
|
|
|
@@ -61,6 +90,10 @@ class NaiveSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
|
|
|
61
90
|
self._action_mask_fn = action_mask_fn
|
|
62
91
|
|
|
63
92
|
def get_state(self) -> Any:
|
|
93
|
+
"""
|
|
94
|
+
Returns the current state of the environment as a deepcopy of the environment.
|
|
95
|
+
:return: a deepcopy of the environment
|
|
96
|
+
"""
|
|
64
97
|
log.debug("getting state")
|
|
65
98
|
original_state = self
|
|
66
99
|
copied_state = copy.deepcopy(self)
|
|
@@ -71,9 +104,19 @@ class NaiveSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
|
|
|
71
104
|
return copied_state
|
|
72
105
|
|
|
73
106
|
def action_masks(self) -> np.ndarray | None:
|
|
107
|
+
"""
|
|
108
|
+
Returns the action masks for the environment. If the action_mask_fn is not set, it returns None.
|
|
109
|
+
:return: the action masks for the environment
|
|
110
|
+
"""
|
|
74
111
|
return self._action_mask_fn(self.env) if self._action_mask_fn is not None else None
|
|
75
112
|
|
|
76
113
|
def get_valid_actions(self) -> list[int]:
|
|
114
|
+
"""
|
|
115
|
+
Returns a list of valid actions for the current state of the environment.
|
|
116
|
+
This is used to obtain potential actions/subsequent states for the MCTS tree.
|
|
117
|
+
|
|
118
|
+
:return: the list of valid actions
|
|
119
|
+
"""
|
|
77
120
|
if self._action_mask_fn is None:
|
|
78
121
|
action_space: gym.spaces.Discrete = self.env.action_space # Type hinting
|
|
79
122
|
return list(range(action_space.n))
|
|
@@ -83,6 +126,14 @@ class NaiveSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
|
|
|
83
126
|
def step(
|
|
84
127
|
self, action: WrapperActType
|
|
85
128
|
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
|
129
|
+
"""
|
|
130
|
+
Performs a step in the environment.
|
|
131
|
+
This method is used to update the wrapper with the new state and the new action, to realize the terminal state
|
|
132
|
+
functionality.
|
|
133
|
+
|
|
134
|
+
:param action: action to perform in the environment
|
|
135
|
+
:return: the step tuple of the environment (obs, reward, terminated, truncated, info)
|
|
136
|
+
"""
|
|
86
137
|
step_tuple = self.env.step(action)
|
|
87
138
|
|
|
88
139
|
obs, reward, terminated, truncated, info = step_tuple
|
|
@@ -93,6 +144,12 @@ class NaiveSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
|
|
|
93
144
|
|
|
94
145
|
|
|
95
146
|
def rollout(self) -> float:
|
|
147
|
+
"""
|
|
148
|
+
Performs a rollout from the current state of the environment and returns the return (sum of rewards) of the
|
|
149
|
+
rollout.
|
|
150
|
+
|
|
151
|
+
:return: the return of the rollout
|
|
152
|
+
"""
|
|
96
153
|
log.debug("performing rollout")
|
|
97
154
|
# random rollout
|
|
98
155
|
# perform random valid actions until terminal
|