gymcts 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gymcts/__init__.py +0 -0
- gymcts/colorful_console_utils.py +142 -0
- gymcts/gymcts_agent.py +261 -0
- gymcts/gymcts_deterministic_wrapper.py +107 -0
- gymcts/gymcts_gym_env.py +28 -0
- gymcts/gymcts_naive_wrapper.py +114 -0
- gymcts/gymcts_node.py +213 -0
- gymcts/logger.py +33 -0
- gymcts-1.0.0.dist-info/LICENSE +21 -0
- gymcts-1.0.0.dist-info/METADATA +634 -0
- gymcts-1.0.0.dist-info/RECORD +13 -0
- gymcts-1.0.0.dist-info/WHEEL +5 -0
- gymcts-1.0.0.dist-info/top_level.txt +1 -0
gymcts/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
import matplotlib.pyplot as plt
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
# --- ANSI escape sequences for console styling ------------------------------
# Reset: append after styled text to return to the default console style.
CEND = "\033[0m"

# Text attributes.
CBOLD = "\033[1m"
CITALIC = "\033[3m"
CURL = "\033[4m"
CBLINK = "\033[5m"
CBLINK2 = "\033[6m"
CSELECTED = "\033[7m"

# Foreground colors (codes 30-37).
# NOTE: CCYAN uses code 96 (same as CBEIGE2); CMAGENTA and CVIOLET share code 35.
CBLACK = "\033[30m"
CRED = "\033[31m"
CGREEN = "\033[32m"
CYELLOW = "\033[33m"
CBLUE = "\033[34m"
CCYAN = "\033[96m"
CMAGENTA = "\033[35m"
CVIOLET = "\033[35m"
CBEIGE = "\033[36m"
CWHITE = "\033[37m"

# Background colors (codes 40-47).
CBLACKBG = "\033[40m"
CREDBG = "\033[41m"
CGREENBG = "\033[42m"
CYELLOWBG = "\033[43m"
CBLUEBG = "\033[44m"
CVIOLETBG = "\033[45m"
CBEIGEBG = "\033[46m"
CWHITEBG = "\033[47m"

# Bright foreground colors (codes 90-97).
# NOTE: CCYAN2 uses code 36 (same as CBEIGE), not a bright code.
CGREY = "\033[90m"
CRED2 = "\033[91m"
CGREEN2 = "\033[92m"
CYELLOW2 = "\033[93m"
CBLUE2 = "\033[94m"
CCYAN2 = "\033[36m"
CVIOLET2 = "\033[95m"
CBEIGE2 = "\033[96m"
CWHITE2 = "\033[97m"

# Bright background colors (codes 100-107).
CGREYBG = "\033[100m"
CREDBG2 = "\033[101m"
CGREENBG2 = "\033[102m"
CYELLOWBG2 = "\033[103m"
CBLUEBG2 = "\033[104m"
CVIOLETBG2 = "\033[105m"
CBEIGEBG2 = "\033[106m"
CWHITEBG2 = "\033[107m"
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def rgb_color_sequence(r: int | float, g: int | float, b: int | float,
|
|
53
|
+
*, format_type: str = 'foreground') -> str:
|
|
54
|
+
"""
|
|
55
|
+
generates a color-codes, that change the color of text in console outputs.
|
|
56
|
+
|
|
57
|
+
rgb values must be numbers between 0 and 255 or 0.0 and 1.0.
|
|
58
|
+
|
|
59
|
+
:param r: red value.
|
|
60
|
+
:param g: green value
|
|
61
|
+
:param b: blue value
|
|
62
|
+
|
|
63
|
+
:param format_type: specifies weather the foreground-color or the background-color shall be adjusted.
|
|
64
|
+
valid options: 'foreground','background'
|
|
65
|
+
:return: a string that contains the color-codes.
|
|
66
|
+
"""
|
|
67
|
+
# type: ignore # noqa: F401
|
|
68
|
+
if format_type == 'foreground':
|
|
69
|
+
f = '\033[38;2;{};{};{}m'.format # font rgb format
|
|
70
|
+
elif format_type == 'background':
|
|
71
|
+
f = '\033[48;2;{};{};{}m'.format # font background rgb format
|
|
72
|
+
else:
|
|
73
|
+
raise ValueError(f"format {format_type} is not defined. Use 'foreground' or 'background'.")
|
|
74
|
+
rgb = [r, g, b]
|
|
75
|
+
|
|
76
|
+
if isinstance(r, int) and isinstance(g, int) and isinstance(b, int):
|
|
77
|
+
if min(rgb) < 0 and max(rgb) > 255:
|
|
78
|
+
raise ValueError("rgb values must be numbers between 0 and 255 or 0.0 and 1.0")
|
|
79
|
+
return f(r, g, b)
|
|
80
|
+
if isinstance(r, float) and isinstance(g, float) and isinstance(b, float):
|
|
81
|
+
if min(rgb) < 0 and max(rgb) > 1:
|
|
82
|
+
raise ValueError("rgb values must be numbers between 0 and 255 or 0.0 and 1.0")
|
|
83
|
+
return f(*[int(n * 255) for n in [r, g, b]])
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def wrap_with_color_codes(s: object, /, r: int | float, g: int | float, b: int | float, **kwargs) -> str:
    """
    Stringify an object and surround it with console color codes.

    The RGB color sequence is placed in front of the text, and ``CEND`` (which resets the
    console color) is appended after it.

    rgb values must be numbers between 0 and 255 or 0.0 and 1.0.

    :param s: the object to stringify and wrap.
    :param r: red value.
    :param g: green value.
    :param b: blue value.
    :param kwargs: additional keyword arguments forwarded to ``rgb_color_sequence``
                   (e.g. ``format_type``).
    :return: the stringified object enclosed in color codes.
    """
    color_prefix = rgb_color_sequence(r, g, b, **kwargs)
    return f"{color_prefix}{s}{CEND}"
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def wrap_evenly_spaced_color(s: str, n_of_item: int, n_classes: int, c_map="rainbow") -> str:
    """
    Wrap ``s`` with a foreground color sampled from a Matplotlib colormap.

    The colormap is sampled at ``n_classes + 1`` evenly spaced points in [0, 1], and the sample
    at index ``n_of_item`` is used, so different item indices get different colors.

    :param s: the string to color.
    :param n_of_item: index of the item, used to pick the sample (0 <= n_of_item <= n_classes).
    :param n_classes: number of classes to spread the colormap over.
    :param c_map: colormap name or ``matplotlib.colors.Colormap`` instance.
    :return: ``s`` enclosed in color codes, or ``s`` unchanged if any of ``s``,
             ``n_of_item``, ``n_classes`` is None.
    """
    if s is None or n_of_item is None or n_classes is None:
        return s

    # `plt.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
    # `plt.get_cmap` accepts both a name and a Colormap instance.
    cmap = plt.get_cmap(c_map)
    sample_points = np.linspace(0, 1, n_classes + 1)

    # drop the alpha channel; the remaining RGB components are floats in [0, 1]
    color_vals = cmap(sample_points[n_of_item])[:-1]
    color_ansi = rgb_color_sequence(*color_vals, format_type='foreground')

    return f"{color_ansi}{s}{CEND}"
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def wrap_with_color_scale(s: str, value: float, min_val: float, max_val: float, c_map=None) -> str:
    """
    Wrap ``s`` with a foreground color that encodes where ``value`` lies between
    ``min_val`` and ``max_val``.

    :param s: the string to color.
    :param value: the value to map onto the color scale.
    :param min_val: value mapped to the low end of the colormap.
    :param max_val: value mapped to the high end of the colormap.
    :param c_map: colormap name or ``matplotlib.colors.Colormap`` instance. If None, a
                  two-color gradient from RGB (255, 100, 128) to RGB (63, 197, 161) is used.
    :return: ``s`` enclosed in color codes, or ``s`` unchanged if an argument is None
             or ``min_val >= max_val``.
    """
    if s is None or value is None or min_val is None or max_val is None or min_val >= max_val:
        return s

    if c_map is not None:
        # `plt.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap` is the stable spelling
        cmap = plt.get_cmap(c_map)
    else:
        from matplotlib.colors import LinearSegmentedColormap
        colors = [
            np.array([255 / 255, 100 / 255, 128 / 255, 1.0]),  # RGBA values
            np.array([63 / 255, 197 / 255, 161 / 255, 1.0]),  # RGBA values
        ]
        cmap = LinearSegmentedColormap.from_list("custom_cmap", colors, N=256)

    # normalise to [0, 1] and drop the alpha channel of the sampled RGBA color
    color_vals = cmap((value - min_val) / (max_val - min_val))[:-1]
    color_ansi = rgb_color_sequence(*color_vals, format_type='foreground')

    return f"{color_ansi}{s}{CEND}"
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
if __name__ == '__main__':
    # quick visual check: the top of the default scale
    print(wrap_with_color_scale("test", 1.0, 0, 1))
|
gymcts/gymcts_agent.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
1
|
+
import copy
from typing import TypeVar, Any, SupportsFloat, Callable, Iterator

import gymnasium as gym

from gymcts.gymcts_gym_env import SoloMCTSGymEnv
from gymcts.gymcts_naive_wrapper import NaiveSoloMCTSGymEnvWrapper
from gymcts.gymcts_node import SoloMCTSNode
from gymcts.logger import log
|
|
11
|
+
|
|
12
|
+
# Type variable bound to SoloMCTSNode (given as a forward-reference string).
# NOTE(review): not referenced by the code in this module; presumably kept for
# annotations elsewhere — confirm before removing.
TSoloMCTSNode = TypeVar("TSoloMCTSNode", bound="SoloMCTSNode")
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SoloMCTSAgent:
    """
    Single-player Monte Carlo Tree Search agent for environments with a discrete action space.

    Each simulation performs:

    - selection: descend by the highest UCB score;
    - expansion: one child per valid action;
    - a rollout via ``env.rollout()``;
    - backpropagation of the episode return.

    After a search step, the best child of the search root becomes the new search root.
    """

    # class-level defaults; __init__ sets per-instance values
    render_tree_after_step: bool = False
    render_tree_max_depth: int = 2
    exclude_unvisited_nodes_from_render: bool = False
    number_of_simulations_per_step: int = 25

    env: SoloMCTSGymEnv
    search_root_node: SoloMCTSNode  # NOTE: this is not the same as the root of the tree!
    clear_mcts_tree_after_step: bool

    def __init__(self,
                 env: SoloMCTSGymEnv,
                 clear_mcts_tree_after_step: bool = True,
                 render_tree_after_step: bool = False,
                 render_tree_max_depth: int = 2,
                 number_of_simulations_per_step: int = 25,
                 exclude_unvisited_nodes_from_render: bool = False
                 ):
        """
        :param env: the environment to search; must have a discrete action space.
        :param clear_mcts_tree_after_step: if True, ``reset()`` is called on the chosen node
                                           after every step to release memory.
        :param render_tree_after_step: print the search tree after every search step.
        :param render_tree_max_depth: recursion depth used when printing the tree.
        :param number_of_simulations_per_step: default number of simulations per search step.
        :param exclude_unvisited_nodes_from_render: hide nodes with ``visit_count == 0`` when printing.
        :raises ValueError: if the action space of ``env`` is not discrete.
        """
        if not isinstance(env.action_space, gym.spaces.Discrete):
            raise ValueError("Action space must be discrete.")

        self.render_tree_after_step = render_tree_after_step
        self.exclude_unvisited_nodes_from_render = exclude_unvisited_nodes_from_render
        self.render_tree_max_depth = render_tree_max_depth

        self.number_of_simulations_per_step = number_of_simulations_per_step

        self.env = env
        self.clear_mcts_tree_after_step = clear_mcts_tree_after_step

        self.search_root_node = SoloMCTSNode(
            action=None,
            parent=None,
            env_reference=env,
        )

    def navigate_to_leaf(self, from_node: SoloMCTSNode) -> SoloMCTSNode:
        """
        Selection phase: descend from ``from_node`` to a leaf.

        At every level, the child with the highest UCB score is chosen. Terminal nodes and
        leaves are returned unchanged.
        """
        log.debug(f"Navigate to leaf. from_node: {from_node}")
        if from_node.terminal:
            log.debug("Node is terminal. Returning from_node")
            return from_node
        if from_node.is_leaf():
            log.debug("Node is leaf. Returning from_node")
            return from_node

        temp_node = from_node
        while not temp_node.is_leaf():
            temp_node = max(temp_node.children.values(), key=lambda child: child.ucb_score())
        log.debug(f"Selected leaf node: {temp_node}")
        return temp_node

    def expand_node(self, node: SoloMCTSNode) -> None:
        """Expansion phase: create one child of ``node`` for every action in ``node.valid_actions``."""
        log.debug(f"expanding node: {node}")

        child_dict = {}
        for action in node.valid_actions:
            # every child starts from the state of `node`, so restore it before each step
            self._load_state(node)
            self.env.step(action)
            # the child receives the env right after the step
            # (presumably to snapshot its state — see SoloMCTSNode)
            child_dict[action] = SoloMCTSNode(
                action=action,
                parent=node,
                env_reference=self.env,
            )

        node.children = child_dict

    def solve(self, num_simulations_per_step: int | None = None,
              render_tree_after_step: bool | None = None) -> list[int]:
        """
        Repeatedly perform MCTS steps from the current search root until a terminal node is reached.

        :param num_simulations_per_step: simulations per step (default: ``self.number_of_simulations_per_step``).
        :param render_tree_after_step: print the tree after each step (default: ``self.render_tree_after_step``).
        :return: the selected actions, in order.
        """
        if num_simulations_per_step is None:
            num_simulations_per_step = self.number_of_simulations_per_step
        if render_tree_after_step is None:
            render_tree_after_step = self.render_tree_after_step

        log.debug(f"Solving from root node: {self.search_root_node}")

        current_node = self.search_root_node
        action_list = []

        while not current_node.terminal:
            next_action, current_node = self.perform_mcts_step(num_simulations=num_simulations_per_step,
                                                               render_tree_after_step=render_tree_after_step)
            log.info(f"selected action {next_action} after {num_simulations_per_step} simulations.")
            action_list.append(next_action)
            log.info(f"current action list: {action_list}")

        log.info(f"Final action list: {action_list}")
        # NOTE: the env is not restored here; it is left in whatever state the last simulation produced.
        return action_list

    def _load_state(self, node: SoloMCTSNode) -> None:
        """Put ``self.env`` into the state stored in ``node.state``."""
        if isinstance(self.env, NaiveSoloMCTSGymEnvWrapper):
            # the naive wrapper's state is a full copy of the env, which replaces self.env;
            # copy again so the node's snapshot is never mutated by later steps
            self.env = copy.deepcopy(node.state)
        else:
            self.env.load_state(node.state)

    def perform_mcts_step(self, search_start_node: SoloMCTSNode | None = None, num_simulations: int | None = None,
                          render_tree_after_step: bool | None = None) -> tuple[int, SoloMCTSNode]:
        """
        Run one MCTS search from ``search_start_node`` and advance the search root to the chosen child.

        :param search_start_node: node to search from (default: current search root).
        :param num_simulations: number of simulations (default: ``self.number_of_simulations_per_step``).
        :param render_tree_after_step: print the tree after the search (default: ``self.render_tree_after_step``).
        :return: the chosen action and the corresponding child node.
        """
        if render_tree_after_step is None:
            render_tree_after_step = self.render_tree_after_step

        if num_simulations is None:
            num_simulations = self.number_of_simulations_per_step

        if search_start_node is None:
            search_start_node = self.search_root_node

        action = self.vanilla_mcts_search(
            search_start_node=search_start_node,
            num_simulations=num_simulations,
            render_tree_after_step=render_tree_after_step,
        )
        next_node = search_start_node.children[action]

        if self.clear_mcts_tree_after_step:
            # free memory: SoloMCTSNode.reset() is expected to drop the node's tree links
            # (children and, presumably, parent) so the old tree can be garbage collected
            next_node.reset()

        self.search_root_node = next_node

        return action, next_node

    def vanilla_mcts_search(self, search_start_node: SoloMCTSNode | None = None, num_simulations=10,
                            render_tree_after_step: bool | None = None) -> int:
        """
        Run ``num_simulations`` MCTS simulations from ``search_start_node`` and return its best action.

        :param search_start_node: node to search from (default: current search root).
        :param num_simulations: number of select/expand/rollout/backpropagate iterations.
        :param render_tree_after_step: print the searched tree afterwards (default: ``self.render_tree_after_step``).
        :return: the action returned by ``search_start_node.get_best_action()``.
        """
        log.debug(f"performing one MCTS search step with {num_simulations} simulations")
        if search_start_node is None:
            search_start_node = self.search_root_node
        if render_tree_after_step is None:
            render_tree_after_step = self.render_tree_after_step

        for i in range(num_simulations):
            log.debug(f"simulation {i}")
            leaf_node = self.navigate_to_leaf(from_node=search_start_node)

            # a leaf is expanded only once it has been visited (rolled out) before
            if leaf_node.visit_count > 0 and not leaf_node.terminal:
                self.expand_node(leaf_node)
                leaf_node = leaf_node.get_random_child()

            self._load_state(leaf_node)
            episode_return = self.env.rollout()
            self.backpropagation(node=leaf_node, episode_return=episode_return)

        if render_tree_after_step:
            self.show_mcts_tree(start_node=search_start_node)

        return search_start_node.get_best_action()

    def show_mcts_tree(self, start_node: SoloMCTSNode | None = None, tree_max_depth: int | None = None) -> None:
        """
        Print ``start_node`` followed by its subtree.

        :param start_node: node to print from (default: current search root).
        :param tree_max_depth: recursion depth limit (default: ``self.render_tree_max_depth``).
        """
        if start_node is None:
            start_node = self.search_root_node

        if tree_max_depth is None:
            tree_max_depth = self.render_tree_max_depth

        print(start_node.__str__(colored=True, action_space_n=self.env.action_space.n))
        for line in self._generate_mcts_tree(start_node=start_node, depth=tree_max_depth):
            print(line)

    def show_mcts_tree_from_root(self, tree_max_depth: int | None = None) -> None:
        """Print the whole tree, starting at the root above the current search root."""
        self.show_mcts_tree(start_node=self.search_root_node.get_root(), tree_max_depth=tree_max_depth)

    @staticmethod
    def _update_node_statistics(node: SoloMCTSNode, episode_return: float) -> None:
        """Fold one episode return into the visit count and the mean/max/min statistics of a node."""
        # incremental form of ((mean * visits) + episode_return) / (visits + 1)
        node.mean_value = node.mean_value + (episode_return - node.mean_value) / (node.visit_count + 1)
        node.visit_count += 1
        node.max_value = max(node.max_value, episode_return)
        node.min_value = min(node.min_value, episode_return)

    def backpropagation(self, node: SoloMCTSNode, episode_return: float) -> None:
        """Update the statistics of ``node`` and of every ancestor up to and including the tree root."""
        log.debug(f"performing backpropagation from leaf node: {node}")
        while True:
            self._update_node_statistics(node, episode_return)
            if node.is_root():
                break
            node = node.parent

    def _generate_mcts_tree(self, start_node: SoloMCTSNode | None = None, prefix: str | None = None,
                            depth: int | None = None) -> Iterator[str]:
        """
        Yield the lines of a box-drawing rendering of the subtree below ``start_node``.

        :param start_node: node whose children are rendered (default: current search root).
        :param prefix: indentation/connector string prepended to every line (default: "").
        :param depth: remaining recursion depth; None means unlimited.
        """
        if prefix is None:
            prefix = ""
        import gymcts.colorful_console_utils as ccu

        if start_node is None:
            start_node = self.search_root_node

        # prefix components: as wide as the 4-character pointers so the tree columns line up
        space = '    '
        branch = '│   '
        # pointers:
        tee = '├── '
        last = '└── '

        contents = list(start_node.children.values()) if start_node.children is not None else []
        if self.exclude_unvisited_nodes_from_render:
            contents = [node for node in contents if node.visit_count > 0]
        # contents each get pointers that are ├── with a final └── :
        pointers = [tee] * (len(contents) - 1) + [last]

        n_classes = self.env.action_space.n
        for raw_pointer, current_node in zip(pointers, contents):
            # connectors are colored by the action that led to the parent node
            n_item = current_node.parent.action if current_node.parent is not None else 0
            pointer = ccu.wrap_evenly_spaced_color(
                s=raw_pointer,
                n_of_item=n_item,
                n_classes=n_classes,
            )

            yield prefix + pointer + current_node.__str__(colored=True, action_space_n=n_classes)

            if current_node.children:  # extend the prefix and recurse
                if depth is not None and depth <= 0:
                    continue
                # below a ├── more siblings follow, so keep the vertical bar; below └── only pad
                extension = branch if raw_pointer == tee else space
                extension = ccu.wrap_evenly_spaced_color(
                    s=extension,
                    n_of_item=n_item,
                    n_classes=n_classes,
                )
                yield from self._generate_mcts_tree(
                    current_node,
                    prefix=prefix + extension,
                    depth=depth - 1 if depth is not None else None
                )
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
import random
|
|
2
|
+
import copy
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
from typing import TypeVar, Any, SupportsFloat, Callable
|
|
6
|
+
import gymnasium as gym
|
|
7
|
+
from gymnasium.core import WrapperActType, WrapperObsType
|
|
8
|
+
from gymnasium.wrappers import RecordEpisodeStatistics
|
|
9
|
+
|
|
10
|
+
from gymcts.gymcts_gym_env import SoloMCTSGymEnv
|
|
11
|
+
|
|
12
|
+
from gymcts.logger import log
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class DeterministicSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
    """
    ``SoloMCTSGymEnv`` adapter for deterministic environments.

    A state is the list of actions taken since the last reset. ``load_state`` resets the wrapped
    env and replays that list, which reproduces the state only if the env is deterministic.
    The env is always wrapped in ``RecordEpisodeStatistics``; rollouts read ``info["episode"]["r"]``.
    """

    _terminal_flag: bool = False
    _last_reward: SupportsFloat = 0
    _episode_return: float = 0.0  # sum of rewards since the last reset
    _step_tuple: tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]] | None = None

    _action_mask_fn: Callable[[gym.Env], np.ndarray] | None = None

    def __init__(
            self,
            env,
            action_mask_fn: str | Callable[[gym.Env], np.ndarray] | None = None,
            buffer_length: int = 100,
    ):
        """
        :param env: the environment to wrap; its action space must be discrete.
        :param action_mask_fn: optional action mask provider (mirrors sb3-contrib's ActionMasker).
            It is either a callable receiving the wrapped env, or the name of an env method.
            The named method is also called with the wrapped env as its argument.
            If None, every action is valid.
        :param buffer_length: buffer length passed to ``RecordEpisodeStatistics``.
        :raises ValueError: if the action space is not discrete or the named attribute is not callable.
        """
        # always wrap (even if `env` already is wrapped): rollouts rely on info["episode"]["r"]
        env = RecordEpisodeStatistics(env, buffer_length=buffer_length)

        gym.Wrapper.__init__(self, env)

        self._wrapper_action_history = []

        if not isinstance(env.action_space, gym.spaces.Discrete):
            raise ValueError("Only discrete action spaces are supported.")

        if action_mask_fn is not None:
            if isinstance(action_mask_fn, str):
                # Gymnasium 1.0 (required for `buffer_length`) removed Wrapper.__getattr__ forwarding,
                # so look the method up through the whole wrapper stack instead of plain getattr
                found_method = self.env.get_wrapper_attr(action_mask_fn)
                if not callable(found_method):
                    raise ValueError(f"Environment attribute {action_mask_fn} is not a method")

                self._action_mask_fn = found_method
            else:
                self._action_mask_fn = action_mask_fn

    def reset(
            self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Reset the wrapped env and clear the action history and episode bookkeeping."""
        reset_result = self.env.reset(seed=seed, options=options)
        self._wrapper_action_history = []
        self._terminal_flag = False
        self._step_tuple = None
        self._episode_return = 0.0
        return reset_result

    def load_state(self, state: list[int]) -> None:
        """Restore a state by resetting the env and replaying ``state`` (a list of actions)."""
        self.reset()
        for action in state:
            # replay through self.step so the terminal flag, last step tuple and episode return
            # describe the restored state instead of whatever the previous rollout left behind
            self.step(action)

    def is_terminal(self) -> bool:
        """A state is terminal if no valid action is left or the last step terminated/truncated."""
        if not len(self.get_valid_actions()):
            return True
        else:
            return self._terminal_flag

    def action_masks(self) -> np.ndarray | None:
        """Return the action mask of the current state, or None if no mask function is configured."""
        return self._action_mask_fn(self.env) if self._action_mask_fn is not None else None

    def get_valid_actions(self) -> list[int]:
        """Return all actions allowed by the action mask (all actions if there is no mask)."""
        if self._action_mask_fn is None:
            action_space: gym.spaces.Discrete = self.env.action_space  # Type hinting
            return list(range(action_space.n))
        else:
            return [i for i, mask in enumerate(self.action_masks()) if mask]

    def _current_episode_return(self) -> float:
        """Return of the current episode: the recorded statistic if the episode ended, else the reward sum."""
        if self._step_tuple is not None:
            info = self._step_tuple[4]
            if "episode" in info:
                return info["episode"]["r"]
        # terminal only because no valid action is left (or no step taken yet)
        return self._episode_return

    def rollout(self) -> float:
        """
        Play uniformly random valid actions until the state is terminal and return the episode return.

        Uses the same termination criteria as ``is_terminal``: terminated, truncated, or no valid action left.
        """
        log.debug("performing rollout")
        while True:
            # same test as is_terminal(), but the valid actions are computed only once per step
            valid_actions = self.get_valid_actions()
            if not valid_actions or self._terminal_flag:
                break
            self.step(random.choice(valid_actions))

        episode_return = self._current_episode_return()
        log.debug(f"Rollout return: {episode_return}")
        return episode_return

    def get_state(self) -> list[int]:
        """Return a copy of the actions taken since the last reset."""
        return self._wrapper_action_history.copy()

    def step(
            self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Step the wrapped env, record the action and update the terminal/return bookkeeping."""
        step_tuple = self.env.step(action)
        self._wrapper_action_history.append(action)
        obs, reward, terminated, truncated, info = step_tuple

        self._terminal_flag = terminated or truncated
        self._last_reward = reward
        self._episode_return += float(reward)
        self._step_tuple = step_tuple

        return step_tuple
|
gymcts/gymcts_gym_env.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from typing import TypeVar, Any, SupportsFloat, Callable
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
import gymnasium as gym
|
|
4
|
+
|
|
5
|
+
# Type variable bound to SoloMCTSNode via a forward-reference string (the class is not imported here).
# NOTE(review): not referenced by the code in this module — confirm whether it is still needed.
TSoloMCTSNode = TypeVar("TSoloMCTSNode", bound="SoloMCTSNode")
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SoloMCTSGymEnv(ABC, gym.Env):
    """
    Interface an environment has to provide so that ``SoloMCTSAgent`` can search it.

    Implementations must be able to snapshot and restore their state, report terminal
    states and valid actions, and play an episode to its end (rollout).
    """

    @abstractmethod
    def get_state(self) -> Any:
        """Return a snapshot of the current state that ``load_state`` can restore."""

    @abstractmethod
    def load_state(self, state: Any) -> None:
        """Restore the environment to a snapshot previously returned by ``get_state``."""

    @abstractmethod
    def is_terminal(self) -> bool:
        """Return True if the current state ends the episode."""

    @abstractmethod
    def get_valid_actions(self) -> list[int]:
        """Return the actions that may be taken in the current state."""

    @abstractmethod
    def rollout(self) -> float:
        """Play from the current state to the end of the episode and return the episode return."""
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
import random
|
|
2
|
+
import copy
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
from typing import TypeVar, Any, SupportsFloat, Callable
|
|
6
|
+
import gymnasium as gym
|
|
7
|
+
from gymnasium.core import WrapperActType, WrapperObsType
|
|
8
|
+
from gymnasium.wrappers import RecordEpisodeStatistics
|
|
9
|
+
|
|
10
|
+
from gymcts.gymcts_gym_env import SoloMCTSGymEnv
|
|
11
|
+
|
|
12
|
+
from gymcts.logger import log
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class NaiveSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
    """
    ``SoloMCTSGymEnv`` adapter that snapshots states by deep-copying the whole wrapper.

    This works for any env that can be deep-copied, at the price of memory and copy time.
    Restoring is done by ``SoloMCTSAgent`` (it replaces its env with a copy of the snapshot),
    which is why ``load_state`` is not implemented here. The env is always wrapped in
    ``RecordEpisodeStatistics``; rollouts read ``info["episode"]["r"]``.
    """

    _terminal_flag: bool = False
    _last_reward: SupportsFloat = 0
    _episode_return: float = 0.0  # sum of rewards since the last reset
    _step_tuple: tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]] | None = None

    _action_mask_fn: Callable[[gym.Env], np.ndarray] | None = None

    def is_terminal(self) -> bool:
        """A state is terminal if no valid action is left or the last step terminated/truncated."""
        # consistent with DeterministicSoloMCTSGymEnvWrapper: an empty action set also ends the episode
        if not self.get_valid_actions():
            return True
        return self._terminal_flag

    def load_state(self, state: Any) -> None:
        """Not supported: the agent restores states by replacing its env with a copy of the snapshot."""
        msg = """
        The NaiveSoloMCTSGymEnvWrapper uses deepcopies of the entire env as the state.
        The loading of the state is done by replacing the env with the 'state' (the copy provided by 'get_state').
        'self' in a method cannot be replaced with another object (as far as i know). Therefore the copy is done by
        SoloMCTSAgent here.
        """
        raise NotImplementedError(msg)

    def __init__(self,
                 env,
                 action_mask_fn: str | Callable[[gym.Env], np.ndarray] | None = None,
                 buffer_length: int = 100,
                 record_video: bool = False,
                 ):
        """
        :param env: the environment to wrap; its action space must be discrete.
        :param action_mask_fn: optional action mask provider (mirrors sb3-contrib's ActionMasker).
            It is either a callable receiving the wrapped env, or the name of an env method.
            The named method is also called with the wrapped env as its argument.
            If None, every action is valid.
        :param buffer_length: buffer length passed to ``RecordEpisodeStatistics``.
        :param record_video: currently unused; kept for backward compatibility.
        :raises ValueError: if the action space is not discrete or the named attribute is not callable.
        """
        # always wrap (even if `env` already is wrapped): rollouts rely on info["episode"]["r"]
        env = RecordEpisodeStatistics(env, buffer_length=buffer_length)

        gym.Wrapper.__init__(self, env)

        if not isinstance(env.action_space, gym.spaces.Discrete):
            raise ValueError("Only discrete action spaces are supported.")

        if action_mask_fn is not None:
            if isinstance(action_mask_fn, str):
                # Gymnasium 1.0 (required for `buffer_length`) removed Wrapper.__getattr__ forwarding,
                # so look the method up through the whole wrapper stack instead of plain getattr
                found_method = self.env.get_wrapper_attr(action_mask_fn)
                if not callable(found_method):
                    raise ValueError(f"Environment attribute {action_mask_fn} is not a method")

                self._action_mask_fn = found_method
            else:
                self._action_mask_fn = action_mask_fn

    def reset(
            self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Reset the wrapped env and clear the terminal/return bookkeeping."""
        reset_result = self.env.reset(seed=seed, options=options)
        self._terminal_flag = False
        self._step_tuple = None
        self._episode_return = 0.0
        return reset_result

    def get_state(self) -> Any:
        """Return a deep copy of this wrapper (including the wrapped env) as the state snapshot."""
        log.debug("getting state")
        original_state = self
        copied_state = copy.deepcopy(self)

        log.debug(f"original state memory location: {hex(id(original_state))}")
        log.debug(f"copied memory location: {hex(id(copied_state))}")

        return copied_state

    def action_masks(self) -> np.ndarray | None:
        """Return the action mask of the current state, or None if no mask function is configured."""
        return self._action_mask_fn(self.env) if self._action_mask_fn is not None else None

    def get_valid_actions(self) -> list[int]:
        """Return all actions allowed by the action mask (all actions if there is no mask)."""
        if self._action_mask_fn is None:
            action_space: gym.spaces.Discrete = self.env.action_space  # Type hinting
            return list(range(action_space.n))
        else:
            return [i for i, mask in enumerate(self.action_masks()) if mask]

    def step(
            self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Step the wrapped env and update the terminal/return bookkeeping."""
        step_tuple = self.env.step(action)

        obs, reward, terminated, truncated, info = step_tuple
        self._terminal_flag = terminated or truncated
        self._last_reward = reward
        self._episode_return += float(reward)
        self._step_tuple = step_tuple

        return step_tuple

    def _current_episode_return(self) -> float:
        """Return of the current episode: the recorded statistic if the episode ended, else the reward sum."""
        if self._step_tuple is not None:
            info = self._step_tuple[4]
            if "episode" in info:
                return info["episode"]["r"]
        # terminal only because no valid action is left (or no step taken yet)
        return self._episode_return

    def rollout(self) -> float:
        """
        Play uniformly random valid actions until the state is terminal and return the episode return.

        Uses the same termination criteria as ``is_terminal``: terminated, truncated, or no valid action left.
        """
        log.debug("performing rollout")
        while True:
            # same test as is_terminal(), but the valid actions are computed only once per step
            valid_actions = self.get_valid_actions()
            if not valid_actions or self._terminal_flag:
                break
            self.step(random.choice(valid_actions))

        episode_return = self._current_episode_return()
        log.debug(f"Rollout return: {episode_return}")
        return episode_return
|