gymcts 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gymcts/colorful_console_utils.py +4 -3
- gymcts/{gymcts_deterministic_wrapper.py → gymcts_action_history_wrapper.py} +2 -2
- gymcts/gymcts_agent.py +24 -68
- gymcts/{gymcts_naive_wrapper.py → gymcts_deepcopy_wrapper.py} +2 -2
- gymcts/gymcts_distributed_agent.py +281 -0
- gymcts/{gymcts_gym_env.py → gymcts_env_abc.py} +1 -1
- gymcts/gymcts_node.py +25 -39
- gymcts/gymcts_tree_plotter.py +75 -0
- {gymcts-1.0.0.dist-info → gymcts-1.2.0.dist-info}/METADATA +25 -23
- gymcts-1.2.0.dist-info/RECORD +15 -0
- {gymcts-1.0.0.dist-info → gymcts-1.2.0.dist-info}/WHEEL +1 -1
- gymcts-1.0.0.dist-info/RECORD +0 -13
- {gymcts-1.0.0.dist-info → gymcts-1.2.0.dist-info/licenses}/LICENSE +0 -0
- {gymcts-1.0.0.dist-info → gymcts-1.2.0.dist-info}/top_level.txt +0 -0
gymcts/colorful_console_utils.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
1
3
|
import matplotlib.pyplot as plt
|
|
2
4
|
import numpy as np
|
|
3
5
|
|
|
@@ -103,8 +105,7 @@ def wrap_with_color_codes(s: object, /, r: int | float, g: int | float, b: int |
|
|
|
103
105
|
f"{CEND}"
|
|
104
106
|
|
|
105
107
|
|
|
106
|
-
|
|
107
|
-
def wrap_evenly_spaced_color(s: str, n_of_item:int, n_classes:int, c_map="rainbow") -> str:
|
|
108
|
+
def wrap_evenly_spaced_color(s: Any, n_of_item: int, n_classes: int, c_map="rainbow") -> str:
|
|
108
109
|
if s is None or n_of_item is None or n_classes is None:
|
|
109
110
|
return s
|
|
110
111
|
|
|
@@ -117,7 +118,7 @@ def wrap_evenly_spaced_color(s: str, n_of_item:int, n_classes:int, c_map="rainbo
|
|
|
117
118
|
return f"{color_asni}{s}{CEND}"
|
|
118
119
|
|
|
119
120
|
|
|
120
|
-
def wrap_with_color_scale(s: str, value: float, min_val:float, max_val:float, c_map=None) -> str:
|
|
121
|
+
def wrap_with_color_scale(s: str, value: float, min_val: float, max_val: float, c_map=None) -> str:
|
|
121
122
|
if s is None or min_val is None or max_val is None or min_val >= max_val:
|
|
122
123
|
return s
|
|
123
124
|
|
|
@@ -7,12 +7,12 @@ import gymnasium as gym
|
|
|
7
7
|
from gymnasium.core import WrapperActType, WrapperObsType
|
|
8
8
|
from gymnasium.wrappers import RecordEpisodeStatistics
|
|
9
9
|
|
|
10
|
-
from gymcts.
|
|
10
|
+
from gymcts.gymcts_env_abc import GymctsABC
|
|
11
11
|
|
|
12
12
|
from gymcts.logger import log
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
class
|
|
15
|
+
class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
|
|
16
16
|
_terminal_flag: bool = False
|
|
17
17
|
_last_reward: SupportsFloat = 0
|
|
18
18
|
_step_tuple: tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]] = None
|
gymcts/gymcts_agent.py
CHANGED
|
@@ -3,27 +3,28 @@ import gymnasium as gym
|
|
|
3
3
|
|
|
4
4
|
from typing import TypeVar, Any, SupportsFloat, Callable
|
|
5
5
|
|
|
6
|
-
from gymcts.
|
|
7
|
-
from gymcts.
|
|
8
|
-
from gymcts.gymcts_node import
|
|
6
|
+
from gymcts.gymcts_env_abc import GymctsABC
|
|
7
|
+
from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
|
|
8
|
+
from gymcts.gymcts_node import GymctsNode
|
|
9
|
+
from gymcts.gymcts_tree_plotter import _generate_mcts_tree
|
|
9
10
|
|
|
10
11
|
from gymcts.logger import log
|
|
11
12
|
|
|
12
13
|
TSoloMCTSNode = TypeVar("TSoloMCTSNode", bound="SoloMCTSNode")
|
|
13
14
|
|
|
14
15
|
|
|
15
|
-
class
|
|
16
|
+
class GymctsAgent:
|
|
16
17
|
render_tree_after_step: bool = False
|
|
17
18
|
render_tree_max_depth: int = 2
|
|
18
19
|
exclude_unvisited_nodes_from_render: bool = False
|
|
19
20
|
number_of_simulations_per_step: int = 25
|
|
20
21
|
|
|
21
|
-
env:
|
|
22
|
-
search_root_node:
|
|
22
|
+
env: GymctsABC
|
|
23
|
+
search_root_node: GymctsNode # NOTE: this is not the same as the root of the tree!
|
|
23
24
|
clear_mcts_tree_after_step: bool
|
|
24
25
|
|
|
25
26
|
def __init__(self,
|
|
26
|
-
env:
|
|
27
|
+
env: GymctsABC,
|
|
27
28
|
clear_mcts_tree_after_step: bool = True,
|
|
28
29
|
render_tree_after_step: bool = False,
|
|
29
30
|
render_tree_max_depth: int = 2,
|
|
@@ -43,13 +44,13 @@ class SoloMCTSAgent:
|
|
|
43
44
|
self.env = env
|
|
44
45
|
self.clear_mcts_tree_after_step = clear_mcts_tree_after_step
|
|
45
46
|
|
|
46
|
-
self.search_root_node =
|
|
47
|
+
self.search_root_node = GymctsNode(
|
|
47
48
|
action=None,
|
|
48
49
|
parent=None,
|
|
49
50
|
env_reference=env,
|
|
50
51
|
)
|
|
51
52
|
|
|
52
|
-
def navigate_to_leaf(self, from_node:
|
|
53
|
+
def navigate_to_leaf(self, from_node: GymctsNode) -> GymctsNode:
|
|
53
54
|
log.debug(f"Navigate to leaf. from_node: {from_node}")
|
|
54
55
|
if from_node.terminal:
|
|
55
56
|
log.debug("Node is terminal. Returning from_node")
|
|
@@ -66,7 +67,7 @@ class SoloMCTSAgent:
|
|
|
66
67
|
log.debug(f"Selected leaf node: {temp_node}")
|
|
67
68
|
return temp_node
|
|
68
69
|
|
|
69
|
-
def expand_node(self, node:
|
|
70
|
+
def expand_node(self, node: GymctsNode) -> None:
|
|
70
71
|
log.debug(f"expanding node: {node}")
|
|
71
72
|
# EXPANSION STRATEGY
|
|
72
73
|
# expand all children
|
|
@@ -78,7 +79,7 @@ class SoloMCTSAgent:
|
|
|
78
79
|
self._load_state(node)
|
|
79
80
|
|
|
80
81
|
obs, reward, terminal, truncated, _ = self.env.step(action)
|
|
81
|
-
child_dict[action] =
|
|
82
|
+
child_dict[action] = GymctsNode(
|
|
82
83
|
action=action,
|
|
83
84
|
parent=node,
|
|
84
85
|
env_reference=self.env,
|
|
@@ -110,14 +111,14 @@ class SoloMCTSAgent:
|
|
|
110
111
|
# restore state of current node
|
|
111
112
|
return action_list
|
|
112
113
|
|
|
113
|
-
def _load_state(self, node:
|
|
114
|
-
if isinstance(self.env,
|
|
114
|
+
def _load_state(self, node: GymctsNode) -> None:
|
|
115
|
+
if isinstance(self.env, DeepCopyMCTSGymEnvWrapper):
|
|
115
116
|
self.env = copy.deepcopy(node.state)
|
|
116
117
|
else:
|
|
117
118
|
self.env.load_state(node.state)
|
|
118
119
|
|
|
119
|
-
def perform_mcts_step(self, search_start_node:
|
|
120
|
-
render_tree_after_step: bool = None) -> tuple[int,
|
|
120
|
+
def perform_mcts_step(self, search_start_node: GymctsNode = None, num_simulations: int = None,
|
|
121
|
+
render_tree_after_step: bool = None) -> tuple[int, GymctsNode]:
|
|
121
122
|
|
|
122
123
|
if render_tree_after_step is None:
|
|
123
124
|
render_tree_after_step = self.render_tree_after_step
|
|
@@ -149,7 +150,7 @@ class SoloMCTSAgent:
|
|
|
149
150
|
|
|
150
151
|
return action, next_node
|
|
151
152
|
|
|
152
|
-
def vanilla_mcts_search(self, search_start_node:
|
|
153
|
+
def vanilla_mcts_search(self, search_start_node: GymctsNode = None, num_simulations=10) -> int:
|
|
153
154
|
log.debug(f"performing one MCTS search step with {num_simulations} simulations")
|
|
154
155
|
if search_start_node is None:
|
|
155
156
|
search_start_node = self.search_root_node
|
|
@@ -178,7 +179,7 @@ class SoloMCTSAgent:
|
|
|
178
179
|
|
|
179
180
|
return search_start_node.get_best_action()
|
|
180
181
|
|
|
181
|
-
def show_mcts_tree(self, start_node:
|
|
182
|
+
def show_mcts_tree(self, start_node: GymctsNode = None, tree_max_depth: int = None) -> None:
|
|
182
183
|
|
|
183
184
|
if start_node is None:
|
|
184
185
|
start_node = self.search_root_node
|
|
@@ -187,13 +188,17 @@ class SoloMCTSAgent:
|
|
|
187
188
|
tree_max_depth = self.render_tree_max_depth
|
|
188
189
|
|
|
189
190
|
print(start_node.__str__(colored=True, action_space_n=self.env.action_space.n))
|
|
190
|
-
for line in
|
|
191
|
+
for line in _generate_mcts_tree(
|
|
192
|
+
start_node=start_node,
|
|
193
|
+
depth=tree_max_depth,
|
|
194
|
+
action_space_n=self.env.action_space.n,
|
|
195
|
+
):
|
|
191
196
|
print(line)
|
|
192
197
|
|
|
193
198
|
def show_mcts_tree_from_root(self, tree_max_depth: int = None) -> None:
|
|
194
199
|
self.show_mcts_tree(start_node=self.search_root_node.get_root(), tree_max_depth=tree_max_depth)
|
|
195
200
|
|
|
196
|
-
def backpropagation(self, node:
|
|
201
|
+
def backpropagation(self, node: GymctsNode, episode_return: float) -> None:
|
|
197
202
|
log.debug(f"performing backpropagation from leaf node: {node}")
|
|
198
203
|
while not node.is_root():
|
|
199
204
|
# node.mean_value = ((node.mean_value * node.visit_count) + episode_return) / (node.visit_count + 1)
|
|
@@ -209,53 +214,4 @@ class SoloMCTSAgent:
|
|
|
209
214
|
node.max_value = max(node.max_value, episode_return)
|
|
210
215
|
node.min_value = min(node.min_value, episode_return)
|
|
211
216
|
|
|
212
|
-
def _generate_mcts_tree(self, start_node: SoloMCTSNode = None, prefix: str = None, depth: int = None) -> list[str]:
|
|
213
217
|
|
|
214
|
-
if prefix is None:
|
|
215
|
-
prefix = ""
|
|
216
|
-
import gymcts.colorful_console_utils as ccu
|
|
217
|
-
|
|
218
|
-
if start_node is None:
|
|
219
|
-
start_node = self.search_root_node
|
|
220
|
-
|
|
221
|
-
# prefix components:
|
|
222
|
-
space = ' '
|
|
223
|
-
branch = '│ '
|
|
224
|
-
# pointers:
|
|
225
|
-
tee = '├── '
|
|
226
|
-
last = '└── '
|
|
227
|
-
|
|
228
|
-
contents = start_node.children.values() if start_node.children is not None else []
|
|
229
|
-
if self.exclude_unvisited_nodes_from_render:
|
|
230
|
-
contents = [node for node in contents if node.visit_count > 0]
|
|
231
|
-
# contents each get pointers that are ├── with a final └── :
|
|
232
|
-
# pointers = [tee] * (len(contents) - 1) + [last]
|
|
233
|
-
pointers = [tee for _ in range(len(contents) - 1)] + [last]
|
|
234
|
-
|
|
235
|
-
for pointer, current_node in zip(pointers, contents):
|
|
236
|
-
n_item = current_node.parent.action if current_node.parent is not None else 0
|
|
237
|
-
n_classes = self.env.action_space.n
|
|
238
|
-
|
|
239
|
-
pointer = ccu.wrap_evenly_spaced_color(
|
|
240
|
-
s=pointer,
|
|
241
|
-
n_of_item=n_item,
|
|
242
|
-
n_classes=n_classes,
|
|
243
|
-
)
|
|
244
|
-
|
|
245
|
-
yield prefix + pointer + f"{current_node.__str__(colored=True, action_space_n=n_classes)}"
|
|
246
|
-
if current_node.children and len(current_node.children): # extend the prefix and recurse:
|
|
247
|
-
# extension = branch if pointer == tee else space
|
|
248
|
-
extension = branch if tee in pointer else space
|
|
249
|
-
# i.e. space because last, └── , above so no more |
|
|
250
|
-
extension = ccu.wrap_evenly_spaced_color(
|
|
251
|
-
s=extension,
|
|
252
|
-
n_of_item=n_item,
|
|
253
|
-
n_classes=n_classes,
|
|
254
|
-
)
|
|
255
|
-
if depth is not None and depth <= 0:
|
|
256
|
-
continue
|
|
257
|
-
yield from self._generate_mcts_tree(
|
|
258
|
-
current_node,
|
|
259
|
-
prefix=prefix + extension,
|
|
260
|
-
depth=depth - 1 if depth is not None else None
|
|
261
|
-
)
|
|
@@ -7,12 +7,12 @@ import gymnasium as gym
|
|
|
7
7
|
from gymnasium.core import WrapperActType, WrapperObsType
|
|
8
8
|
from gymnasium.wrappers import RecordEpisodeStatistics
|
|
9
9
|
|
|
10
|
-
from gymcts.
|
|
10
|
+
from gymcts.gymcts_env_abc import GymctsABC
|
|
11
11
|
|
|
12
12
|
from gymcts.logger import log
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
class
|
|
15
|
+
class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
_terminal_flag:bool = False
|
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import gymnasium as gym
|
|
3
|
+
|
|
4
|
+
from typing import TypeVar, Any, SupportsFloat, Callable
|
|
5
|
+
|
|
6
|
+
from ray.types import ObjectRef
|
|
7
|
+
|
|
8
|
+
from gymcts.gymcts_agent import GymctsAgent
|
|
9
|
+
from gymcts.gymcts_env_abc import GymctsABC
|
|
10
|
+
from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
|
|
11
|
+
from gymcts.gymcts_node import GymctsNode
|
|
12
|
+
from gymcts.gymcts_tree_plotter import _generate_mcts_tree
|
|
13
|
+
|
|
14
|
+
from gymcts.logger import log
|
|
15
|
+
|
|
16
|
+
import ray
|
|
17
|
+
import copy
|
|
18
|
+
|
|
19
|
+
TSoloMCTSNode = TypeVar("TSoloMCTSNode", bound="SoloMCTSNode")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@ray.remote
|
|
23
|
+
def mcts_lookahead(
|
|
24
|
+
gymcts_start_node: GymctsNode,
|
|
25
|
+
env: GymctsABC,
|
|
26
|
+
num_simulations: int) -> GymctsNode:
|
|
27
|
+
agent = GymctsAgent(
|
|
28
|
+
env=env,
|
|
29
|
+
clear_mcts_tree_after_step=False,
|
|
30
|
+
number_of_simulations_per_step=num_simulations,
|
|
31
|
+
)
|
|
32
|
+
agent.search_root_node = gymcts_start_node
|
|
33
|
+
|
|
34
|
+
agent.vanilla_mcts_search(
|
|
35
|
+
search_start_node=gymcts_start_node,
|
|
36
|
+
num_simulations=num_simulations,
|
|
37
|
+
)
|
|
38
|
+
return agent.search_root_node
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def merge_nodes(gymcts_node1, gymcts_node2, perform_state_equality_check=False):
|
|
42
|
+
log.debug(f"merging {gymcts_node1} and {gymcts_node2}")
|
|
43
|
+
# maybe add some state equality check here
|
|
44
|
+
if perform_state_equality_check:
|
|
45
|
+
if gymcts_node1.state != gymcts_node2.state:
|
|
46
|
+
raise ValueError("States are different")
|
|
47
|
+
|
|
48
|
+
if gymcts_node1 is None:
|
|
49
|
+
log.debug(f"first node is None, returning second node ({gymcts_node2})")
|
|
50
|
+
return gymcts_node2
|
|
51
|
+
if gymcts_node2 is None:
|
|
52
|
+
log.debug(f"second node is None, returning first node ({gymcts_node1})")
|
|
53
|
+
return gymcts_node1
|
|
54
|
+
if gymcts_node1 is None and gymcts_node2 is None:
|
|
55
|
+
log.error("Both nodes are None")
|
|
56
|
+
raise ValueError("Both nodes are None")
|
|
57
|
+
|
|
58
|
+
if gymcts_node1.is_leaf() and not gymcts_node2.is_leaf():
|
|
59
|
+
log.debug(f"first node is leaf, second node is not leaf")
|
|
60
|
+
gymcts_node2.parent = gymcts_node1.parent
|
|
61
|
+
log.debug(f"returning first node: {gymcts_node2}")
|
|
62
|
+
return gymcts_node2
|
|
63
|
+
|
|
64
|
+
if gymcts_node2.is_leaf() and not gymcts_node1.is_leaf():
|
|
65
|
+
log.debug(f"second node is leaf, first node is not leaf")
|
|
66
|
+
log.debug(f"returning first node: {gymcts_node1}")
|
|
67
|
+
return gymcts_node1
|
|
68
|
+
|
|
69
|
+
if gymcts_node1.is_leaf() and gymcts_node2.is_leaf():
|
|
70
|
+
log.debug(f"both nodes are leafs, returning first node")
|
|
71
|
+
log.debug(f"returning first node: {gymcts_node1}")
|
|
72
|
+
return gymcts_node1
|
|
73
|
+
|
|
74
|
+
# check if gymcts_node1 and gymcts_node2 have the same children
|
|
75
|
+
if gymcts_node1.children.keys() != gymcts_node2.children.keys():
|
|
76
|
+
log.error("Nodes have different children")
|
|
77
|
+
raise ValueError("Nodes have different children")
|
|
78
|
+
|
|
79
|
+
for (action1, child1), (action2, child2) in zip(gymcts_node1.children.items(), gymcts_node2.children.items()):
|
|
80
|
+
if action1 != action2:
|
|
81
|
+
log.error("Actions are different")
|
|
82
|
+
raise ValueError("Actions are different")
|
|
83
|
+
log.debug(f"merging children with action {action1} for node {gymcts_node1}")
|
|
84
|
+
gymcts_node1.children[action1] = merge_nodes(
|
|
85
|
+
child1,
|
|
86
|
+
child2,
|
|
87
|
+
perform_state_equality_check=perform_state_equality_check
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
visit_count = gymcts_node1.visit_count + gymcts_node2.visit_count
|
|
91
|
+
mean_value = (
|
|
92
|
+
gymcts_node1.mean_value * gymcts_node1.visit_count + gymcts_node2.mean_value * gymcts_node2.visit_count) / visit_count
|
|
93
|
+
max_value = max(gymcts_node1.max_value, gymcts_node2.max_value)
|
|
94
|
+
min_value = min(gymcts_node1.min_value, gymcts_node2.min_value)
|
|
95
|
+
|
|
96
|
+
gymcts_node1.visit_count = visit_count
|
|
97
|
+
gymcts_node1.mean_value = mean_value
|
|
98
|
+
gymcts_node1.max_value = max_value
|
|
99
|
+
gymcts_node1.min_value = min_value
|
|
100
|
+
log.debug(f"merged node: {gymcts_node1}")
|
|
101
|
+
log.debug(f"returning node: {gymcts_node1}")
|
|
102
|
+
return gymcts_node1
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class DistributedGymctsAgent:
|
|
106
|
+
render_tree_after_step: bool = False
|
|
107
|
+
render_tree_max_depth: int = 2
|
|
108
|
+
exclude_unvisited_nodes_from_render: bool = False
|
|
109
|
+
number_of_simulations_per_step: int = 25
|
|
110
|
+
num_parallel: int = 4
|
|
111
|
+
|
|
112
|
+
env: GymctsABC
|
|
113
|
+
search_root_node: GymctsNode # NOTE: this is not the same as the root of the tree!
|
|
114
|
+
clear_mcts_tree_after_step: bool
|
|
115
|
+
|
|
116
|
+
def __init__(self,
|
|
117
|
+
env: GymctsABC,
|
|
118
|
+
render_tree_after_step: bool = False,
|
|
119
|
+
render_tree_max_depth: int = 2,
|
|
120
|
+
num_parallel: int = 4,
|
|
121
|
+
number_of_simulations_per_step: int = 25,
|
|
122
|
+
exclude_unvisited_nodes_from_render: bool = False
|
|
123
|
+
):
|
|
124
|
+
# check if action space of env is discrete
|
|
125
|
+
if not isinstance(env.action_space, gym.spaces.Discrete):
|
|
126
|
+
raise ValueError("Action space must be discrete.")
|
|
127
|
+
|
|
128
|
+
self.num_parallel = num_parallel
|
|
129
|
+
|
|
130
|
+
self.render_tree_after_step = render_tree_after_step
|
|
131
|
+
self.exclude_unvisited_nodes_from_render = exclude_unvisited_nodes_from_render
|
|
132
|
+
self.render_tree_max_depth = render_tree_max_depth
|
|
133
|
+
|
|
134
|
+
self.number_of_simulations_per_step = number_of_simulations_per_step
|
|
135
|
+
|
|
136
|
+
self.env = env
|
|
137
|
+
|
|
138
|
+
self.search_root_node = GymctsNode(
|
|
139
|
+
action=None,
|
|
140
|
+
parent=None,
|
|
141
|
+
env_reference=env,
|
|
142
|
+
)
|
|
143
|
+
|
|
144
|
+
def solve(self, num_simulations_per_step: int = None, render_tree_after_step: bool = None) -> list[int]:
|
|
145
|
+
|
|
146
|
+
if num_simulations_per_step is None:
|
|
147
|
+
num_simulations_per_step = self.number_of_simulations_per_step
|
|
148
|
+
if render_tree_after_step is None:
|
|
149
|
+
render_tree_after_step = self.render_tree_after_step
|
|
150
|
+
|
|
151
|
+
log.debug(f"Solving from root node: {self.search_root_node}")
|
|
152
|
+
|
|
153
|
+
current_node = self.search_root_node
|
|
154
|
+
|
|
155
|
+
action_list = []
|
|
156
|
+
|
|
157
|
+
while not current_node.terminal:
|
|
158
|
+
next_action, current_node = self.perform_mcts_step(num_simulations=num_simulations_per_step,
|
|
159
|
+
render_tree_after_step=render_tree_after_step)
|
|
160
|
+
log.info(
|
|
161
|
+
f"selected action {next_action} after {self.num_parallel} x {num_simulations_per_step} simulations.")
|
|
162
|
+
action_list.append(next_action)
|
|
163
|
+
log.info(f"current action list: {action_list}")
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
log.info(f"Final action list: {action_list}")
|
|
167
|
+
# restore state of current node
|
|
168
|
+
return action_list
|
|
169
|
+
|
|
170
|
+
def perform_mcts_step(self, search_start_node: GymctsNode = None, num_simulations: int = None,
|
|
171
|
+
render_tree_after_step: bool = None, num_parallel: int = None) -> tuple[int, GymctsNode]:
|
|
172
|
+
|
|
173
|
+
if render_tree_after_step is None:
|
|
174
|
+
render_tree_after_step = self.render_tree_after_step
|
|
175
|
+
|
|
176
|
+
if render_tree_after_step is None:
|
|
177
|
+
render_tree_after_step = self.render_tree_after_step
|
|
178
|
+
|
|
179
|
+
if num_simulations is None:
|
|
180
|
+
num_simulations = self.number_of_simulations_per_step
|
|
181
|
+
|
|
182
|
+
if search_start_node is None:
|
|
183
|
+
search_start_node = self.search_root_node
|
|
184
|
+
|
|
185
|
+
if num_parallel is None:
|
|
186
|
+
num_parallel = self.num_parallel
|
|
187
|
+
|
|
188
|
+
# action = self.vanilla_mcts_search(
|
|
189
|
+
# search_start_node=search_start_node,
|
|
190
|
+
# num_simulations=num_simulations,
|
|
191
|
+
# )
|
|
192
|
+
# next_node = search_start_node.children[action]
|
|
193
|
+
|
|
194
|
+
mcts_interation_futures = [
|
|
195
|
+
mcts_lookahead.remote(
|
|
196
|
+
copy.deepcopy(search_start_node),
|
|
197
|
+
copy.deepcopy(self.env),
|
|
198
|
+
num_simulations=num_simulations
|
|
199
|
+
)
|
|
200
|
+
for _ in range(num_parallel)
|
|
201
|
+
]
|
|
202
|
+
|
|
203
|
+
while mcts_interation_futures:
|
|
204
|
+
ready_gymcts_nodes, mcts_interation_futures = ray.wait(mcts_interation_futures)
|
|
205
|
+
for ready_node_ref in ready_gymcts_nodes:
|
|
206
|
+
ready_node = ray.get(ready_node_ref)
|
|
207
|
+
|
|
208
|
+
# merge the tree
|
|
209
|
+
search_start_node = merge_nodes(search_start_node, ready_node)
|
|
210
|
+
|
|
211
|
+
action = search_start_node.get_best_action()
|
|
212
|
+
next_node = search_start_node.children[action]
|
|
213
|
+
|
|
214
|
+
if self.render_tree_after_step:
|
|
215
|
+
self.show_mcts_tree(
|
|
216
|
+
start_node=search_start_node,
|
|
217
|
+
tree_max_depth=self.render_tree_max_depth
|
|
218
|
+
)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
# to clear memory we need to remove all nodes except the current node
|
|
222
|
+
# this is done by setting the root node to the current node
|
|
223
|
+
# and setting the parent of the current node to None
|
|
224
|
+
# we also need to reset the children of the current node
|
|
225
|
+
# this is done by calling the reset method
|
|
226
|
+
#
|
|
227
|
+
# in a distributed setting we need we delete all previous nodes
|
|
228
|
+
# this is because backpropagation merging trees is already computationally expensive
|
|
229
|
+
# and backpropagating the whole tree would be even more expensive
|
|
230
|
+
next_node.reset()
|
|
231
|
+
|
|
232
|
+
self.search_root_node = next_node
|
|
233
|
+
|
|
234
|
+
return action, next_node
|
|
235
|
+
|
|
236
|
+
def show_mcts_tree(self, start_node: GymctsNode = None, tree_max_depth: int = None) -> None:
|
|
237
|
+
|
|
238
|
+
if start_node is None:
|
|
239
|
+
start_node = self.search_root_node
|
|
240
|
+
|
|
241
|
+
if tree_max_depth is None:
|
|
242
|
+
tree_max_depth = self.render_tree_max_depth
|
|
243
|
+
|
|
244
|
+
print(start_node.__str__(colored=True, action_space_n=self.env.action_space.n))
|
|
245
|
+
for line in _generate_mcts_tree(
|
|
246
|
+
start_node=start_node,
|
|
247
|
+
depth=tree_max_depth,
|
|
248
|
+
action_space_n=self.env.action_space.n
|
|
249
|
+
):
|
|
250
|
+
print(line)
|
|
251
|
+
|
|
252
|
+
def show_mcts_tree_from_root(self, tree_max_depth: int = None) -> None:
|
|
253
|
+
self.show_mcts_tree(start_node=self.search_root_node.get_root(), tree_max_depth=tree_max_depth)
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
if __name__ == '__main__':
|
|
257
|
+
ray.init()
|
|
258
|
+
|
|
259
|
+
log.setLevel(20) # 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CR
|
|
260
|
+
env = gym.make('FrozenLake-v1', desc=None, map_name="4x4", is_slippery=False)
|
|
261
|
+
env.reset()
|
|
262
|
+
|
|
263
|
+
# 1. wrap the environment with the naive wrapper or a custom gymcts wrapper
|
|
264
|
+
# env1 = ActionHistoryMCTSGymEnvWrapper(env1)
|
|
265
|
+
env = DeepCopyMCTSGymEnvWrapper(env)
|
|
266
|
+
|
|
267
|
+
# 2. create the agent
|
|
268
|
+
agent1 = DistributedGymctsAgent(
|
|
269
|
+
env=env,
|
|
270
|
+
render_tree_after_step=True,
|
|
271
|
+
number_of_simulations_per_step=1000,
|
|
272
|
+
exclude_unvisited_nodes_from_render=True,
|
|
273
|
+
num_parallel=1,
|
|
274
|
+
)
|
|
275
|
+
import time
|
|
276
|
+
|
|
277
|
+
start_time = time.perf_counter()
|
|
278
|
+
actions = agent1.solve()
|
|
279
|
+
end_time = time.perf_counter()
|
|
280
|
+
|
|
281
|
+
print(f"solution time pro action: {end_time - start_time}/{len(actions)}")
|
gymcts/gymcts_node.py
CHANGED
|
@@ -4,21 +4,17 @@ import math
|
|
|
4
4
|
|
|
5
5
|
from typing import TypeVar, Any, SupportsFloat, Callable, Generator
|
|
6
6
|
|
|
7
|
-
from gymcts.
|
|
7
|
+
from gymcts.gymcts_env_abc import GymctsABC
|
|
8
8
|
|
|
9
9
|
from gymcts.logger import log
|
|
10
10
|
|
|
11
|
-
|
|
11
|
+
TGymctsNode = TypeVar("TGymctsNode", bound="GymctsNode")
|
|
12
12
|
|
|
13
13
|
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
class SoloMCTSNode:
|
|
17
|
-
|
|
14
|
+
class GymctsNode:
|
|
18
15
|
# static properties
|
|
19
16
|
best_action_weight: float = 0.05
|
|
20
|
-
ubc_c
|
|
21
|
-
|
|
17
|
+
ubc_c = 0.707
|
|
22
18
|
|
|
23
19
|
# attributes
|
|
24
20
|
visit_count: int = 0
|
|
@@ -28,7 +24,6 @@ class SoloMCTSNode:
|
|
|
28
24
|
terminal: bool = False
|
|
29
25
|
state: Any
|
|
30
26
|
|
|
31
|
-
|
|
32
27
|
def __str__(self, colored=False, action_space_n=None) -> str:
|
|
33
28
|
if not colored:
|
|
34
29
|
|
|
@@ -39,11 +34,9 @@ class SoloMCTSNode:
|
|
|
39
34
|
|
|
40
35
|
import gymcts.colorful_console_utils as ccu
|
|
41
36
|
|
|
42
|
-
|
|
43
37
|
if self.is_root():
|
|
44
38
|
return f"({ccu.CYELLOW}N{ccu.CEND}={self.visit_count}, {ccu.CYELLOW}Q_v{ccu.CEND}={self.mean_value:.2f}, {ccu.CYELLOW}best{ccu.CEND}={self.max_value:.2f})"
|
|
45
39
|
|
|
46
|
-
|
|
47
40
|
if action_space_n is None:
|
|
48
41
|
raise ValueError("action_space_n must be provided if colored is True")
|
|
49
42
|
|
|
@@ -68,25 +61,23 @@ class SoloMCTSNode:
|
|
|
68
61
|
if isinstance(value, int):
|
|
69
62
|
return f"{color}{value}{e}"
|
|
70
63
|
|
|
71
|
-
|
|
72
64
|
root_node = self.get_root()
|
|
73
65
|
mean_val = f"{self.mean_value:.2f}"
|
|
74
66
|
|
|
75
67
|
return ((f"("
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
68
|
+
f"{p}a{e}={ccu.wrap_evenly_spaced_color(s=self.action, n_of_item=self.action, n_classes=action_space_n)}, "
|
|
69
|
+
f"{p}N{e}={colorful_value(self.visit_count)}, "
|
|
70
|
+
f"{p}Q_v{e}={ccu.wrap_with_color_scale(s=mean_val, value=self.mean_value, min_val=root_node.min_value, max_val=root_node.max_value)}, "
|
|
71
|
+
f"{p}best{e}={colorful_value(self.max_value)}") +
|
|
80
72
|
(f", {p}ubc{e}={colorful_value(self.ucb_score())})" if not self.is_root() else ")"))
|
|
81
73
|
|
|
82
|
-
|
|
83
|
-
def traverse_nodes(self) -> Generator[TSoloMCTSNode, None, None]:
|
|
74
|
+
def traverse_nodes(self) -> Generator[TGymctsNode, None, None]:
|
|
84
75
|
yield self
|
|
85
76
|
if self.children:
|
|
86
77
|
for child in self.children.values():
|
|
87
78
|
yield from child.traverse_nodes()
|
|
88
79
|
|
|
89
|
-
def get_root(self) ->
|
|
80
|
+
def get_root(self) -> TGymctsNode:
|
|
90
81
|
if self.is_root():
|
|
91
82
|
return self
|
|
92
83
|
return self.parent.get_root()
|
|
@@ -101,19 +92,17 @@ class SoloMCTSNode:
|
|
|
101
92
|
return 0
|
|
102
93
|
return len(self.children) + sum(child.n_children_recursively() for child in self.children.values())
|
|
103
94
|
|
|
104
|
-
|
|
105
|
-
|
|
106
95
|
def __init__(self,
|
|
107
96
|
action: int | None,
|
|
108
|
-
parent:
|
|
109
|
-
env_reference:
|
|
97
|
+
parent: TGymctsNode | None,
|
|
98
|
+
env_reference: GymctsABC,
|
|
110
99
|
):
|
|
111
100
|
|
|
112
101
|
# field depending on whether the node is a root node or not
|
|
113
102
|
self.action: int | None
|
|
114
103
|
|
|
115
|
-
self.env_reference:
|
|
116
|
-
self.parent:
|
|
104
|
+
self.env_reference: GymctsABC
|
|
105
|
+
self.parent: GymctsNode | None
|
|
117
106
|
self.uuid = uuid.uuid4()
|
|
118
107
|
|
|
119
108
|
if parent is None:
|
|
@@ -133,7 +122,7 @@ class SoloMCTSNode:
|
|
|
133
122
|
|
|
134
123
|
from copy import copy
|
|
135
124
|
self.state = env_reference.get_state()
|
|
136
|
-
#log.debug(f"saving state of node '{str(self)}' to memory location: {hex(id(self.state))}")
|
|
125
|
+
# log.debug(f"saving state of node '{str(self)}' to memory location: {hex(id(self.state))}")
|
|
137
126
|
self.visit_count: int = 0
|
|
138
127
|
|
|
139
128
|
self.mean_value: float = 0
|
|
@@ -143,8 +132,7 @@ class SoloMCTSNode:
|
|
|
143
132
|
# safe valid action instead of calling the environment
|
|
144
133
|
# this reduces the compute but increases the memory usage
|
|
145
134
|
self.valid_actions: list[int] = env_reference.get_valid_actions()
|
|
146
|
-
self.children: dict[int,
|
|
147
|
-
|
|
135
|
+
self.children: dict[int, GymctsNode] | None = None # may be expanded later
|
|
148
136
|
|
|
149
137
|
def reset(self) -> None:
|
|
150
138
|
self.parent = None
|
|
@@ -153,35 +141,33 @@ class SoloMCTSNode:
|
|
|
153
141
|
self.mean_value: float = 0
|
|
154
142
|
self.max_value: float = -float("inf")
|
|
155
143
|
self.min_value: float = +float("inf")
|
|
156
|
-
self.children: dict[int,
|
|
144
|
+
self.children: dict[int, GymctsNode] | None = None # may be expanded later
|
|
157
145
|
|
|
158
146
|
# just setting the children of the parent node to None should be enough to trigger garbage collection
|
|
159
147
|
# however, we also set the parent to None to make sure that the parent is not referenced anymore
|
|
160
148
|
if self.parent:
|
|
161
149
|
self.parent.reset()
|
|
162
150
|
|
|
163
|
-
|
|
164
|
-
|
|
165
151
|
def is_root(self) -> bool:
|
|
166
152
|
return self.parent is None
|
|
167
153
|
|
|
168
154
|
def is_leaf(self) -> bool:
|
|
169
155
|
return self.children is None or len(self.children) == 0
|
|
170
156
|
|
|
171
|
-
def get_random_child(self) ->
|
|
157
|
+
def get_random_child(self) -> TGymctsNode:
|
|
172
158
|
if self.is_leaf():
|
|
173
|
-
raise ValueError("cannot get random child of leaf node")
|
|
159
|
+
raise ValueError("cannot get random child of leaf node") # todo: maybe return self instead?
|
|
174
160
|
|
|
175
161
|
return list(self.children.values())[random.randint(0, len(self.children) - 1)]
|
|
176
162
|
|
|
177
163
|
def get_best_action(self) -> int:
|
|
178
164
|
return max(self.children.values(), key=lambda child: child.get_score()).action
|
|
179
165
|
|
|
180
|
-
def get_score(self) -> float:
|
|
166
|
+
def get_score(self) -> float: # todo: make it an attribute?
|
|
181
167
|
# return self.mean_value
|
|
182
|
-
assert 0 <=
|
|
183
|
-
a =
|
|
184
|
-
return (1-a) * self.mean_value + a * self.max_value
|
|
168
|
+
assert 0 <= GymctsNode.best_action_weight <= 1
|
|
169
|
+
a = GymctsNode.best_action_weight
|
|
170
|
+
return (1 - a) * self.mean_value + a * self.max_value
|
|
185
171
|
|
|
186
172
|
def get_mean_value(self) -> float:
|
|
187
173
|
return self.mean_value
|
|
@@ -207,7 +193,7 @@ class SoloMCTSNode:
|
|
|
207
193
|
if self.is_root():
|
|
208
194
|
raise ValueError("ucb_score can only be called on non-root nodes")
|
|
209
195
|
# c = 0.707 # todo: make it an attribute?
|
|
210
|
-
c =
|
|
196
|
+
c = GymctsNode.ubc_c
|
|
211
197
|
if self.visit_count == 0:
|
|
212
198
|
return float("inf")
|
|
213
|
-
return self.mean_value + c * math.sqrt(math.log(self.parent.visit_count) / (self.visit_count))
|
|
199
|
+
return self.mean_value + c * math.sqrt(math.log(self.parent.visit_count) / (self.visit_count))
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
from gymcts.gymcts_node import GymctsNode
|
|
2
|
+
|
|
3
|
+
from gymcts.logger import log
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def _generate_mcts_tree(
|
|
7
|
+
start_node: GymctsNode = None,
|
|
8
|
+
prefix: str = None,
|
|
9
|
+
depth: int = None,
|
|
10
|
+
exclude_unvisited_nodes_from_render: bool = True,
|
|
11
|
+
action_space_n: int = None
|
|
12
|
+
) -> list[str]:
|
|
13
|
+
if prefix is None:
|
|
14
|
+
prefix = ""
|
|
15
|
+
import gymcts.colorful_console_utils as ccu
|
|
16
|
+
|
|
17
|
+
if start_node is None:
|
|
18
|
+
raise ValueError("start_node must not be None")
|
|
19
|
+
|
|
20
|
+
if action_space_n is None:
|
|
21
|
+
log.warning("action_space_n is None, defaulting to 100")
|
|
22
|
+
action_space_n = 100
|
|
23
|
+
|
|
24
|
+
# prefix components:
|
|
25
|
+
space = ' '
|
|
26
|
+
branch = '│ '
|
|
27
|
+
# pointers:
|
|
28
|
+
tee = '├── '
|
|
29
|
+
last = '└── '
|
|
30
|
+
|
|
31
|
+
contents = start_node.children.values() if start_node.children is not None else []
|
|
32
|
+
if exclude_unvisited_nodes_from_render:
|
|
33
|
+
contents = [node for node in contents if node.visit_count > 0]
|
|
34
|
+
# contents each get pointers that are ├── with a final └── :
|
|
35
|
+
# pointers = [tee] * (len(contents) - 1) + [last]
|
|
36
|
+
pointers = [tee for _ in range(len(contents) - 1)] + [last]
|
|
37
|
+
|
|
38
|
+
for pointer, current_node in zip(pointers, contents):
|
|
39
|
+
n_item = current_node.parent.action if current_node.parent is not None else 0
|
|
40
|
+
n_classes = action_space_n
|
|
41
|
+
|
|
42
|
+
pointer = ccu.wrap_evenly_spaced_color(
|
|
43
|
+
s=pointer,
|
|
44
|
+
n_of_item=n_item,
|
|
45
|
+
n_classes=n_classes,
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
yield prefix + pointer + f"{current_node.__str__(colored=True, action_space_n=n_classes)}"
|
|
49
|
+
if current_node.children and len(current_node.children): # extend the prefix and recurse:
|
|
50
|
+
# extension = branch if pointer == tee else space
|
|
51
|
+
extension = branch if tee in pointer else space
|
|
52
|
+
# i.e. space because last, └── , above so no more |
|
|
53
|
+
extension = ccu.wrap_evenly_spaced_color(
|
|
54
|
+
s=extension,
|
|
55
|
+
n_of_item=n_item,
|
|
56
|
+
n_classes=n_classes,
|
|
57
|
+
)
|
|
58
|
+
if depth is not None and depth <= 0:
|
|
59
|
+
continue
|
|
60
|
+
yield from _generate_mcts_tree(
|
|
61
|
+
current_node,
|
|
62
|
+
prefix=prefix + extension,
|
|
63
|
+
action_space_n=action_space_n,
|
|
64
|
+
depth=depth - 1 if depth is not None else None
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def show_mcts_tree(
        start_node: GymctsNode = None,
        tree_max_depth: int = None,
        action_space_n: int = None
) -> None:
    """Print an MCTS (sub)tree to stdout, starting at ``start_node``.

    The start node itself is printed first (colored), followed by one line per
    descendant as yielded by ``_generate_mcts_tree``.

    :param start_node: node to render from; must not be ``None``.
    :param tree_max_depth: forwarded as ``depth`` to ``_generate_mcts_tree``;
        ``None`` disables the depth cutoff.
    :param action_space_n: number of actions, used to pick evenly spaced
        colors. It is forwarded to ``_generate_mcts_tree`` so the root line and
        the descendant lines share the same color scale (previously the
        descendants always fell back to the generator's default).
    :raises ValueError: if ``start_node`` is ``None``.
    """
    # Fail with the same explicit error as the tree generator instead of an
    # AttributeError on ``None.__str__``.
    if start_node is None:
        raise ValueError("start_node must not be None")

    print(start_node.__str__(colored=True, action_space_n=action_space_n))
    for line in _generate_mcts_tree(
            start_node=start_node,
            depth=tree_max_depth,
            action_space_n=action_space_n,
    ):
        print(line)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: gymcts
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.2.0
|
|
4
4
|
Summary: A minimalistic implementation of the Monte Carlo Tree Search algorithm for planning problems formulated as Gymnasium reinforcement learning environments.
|
|
5
5
|
Author: Alexander Nasuta
|
|
6
6
|
Author-email: Alexander Nasuta <alexander.nasuta@wzl-iqs.rwth-aachen.de>
|
|
@@ -25,7 +25,7 @@ License: MIT License
|
|
|
25
25
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
26
26
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
27
27
|
SOFTWARE.
|
|
28
|
-
Project-URL: Homepage, https://github.com/Alexander-Nasuta/
|
|
28
|
+
Project-URL: Homepage, https://github.com/Alexander-Nasuta/gymcts
|
|
29
29
|
Platform: unix
|
|
30
30
|
Platform: linux
|
|
31
31
|
Platform: osx
|
|
@@ -34,7 +34,7 @@ Platform: win32
|
|
|
34
34
|
Classifier: License :: OSI Approved :: MIT License
|
|
35
35
|
Classifier: Programming Language :: Python
|
|
36
36
|
Classifier: Programming Language :: Python :: 3
|
|
37
|
-
Requires-Python: >=3.
|
|
37
|
+
Requires-Python: >=3.11
|
|
38
38
|
Description-Content-Type: text/markdown
|
|
39
39
|
License-File: LICENSE
|
|
40
40
|
Requires-Dist: rich
|
|
@@ -63,6 +63,9 @@ Requires-Dist: furo; extra == "dev"
|
|
|
63
63
|
Requires-Dist: twine; extra == "dev"
|
|
64
64
|
Requires-Dist: sphinx-copybutton; extra == "dev"
|
|
65
65
|
Requires-Dist: nbsphinx; extra == "dev"
|
|
66
|
+
Requires-Dist: jupytext; extra == "dev"
|
|
67
|
+
Requires-Dist: jupyter; extra == "dev"
|
|
68
|
+
Dynamic: license-file
|
|
66
69
|
|
|
67
70
|
# Graph Matrix Job Shop Env
|
|
68
71
|
|
|
@@ -118,8 +121,8 @@ The NaiveSoloMCTSGymEnvWrapper can be used with non-deterministic environments,
|
|
|
118
121
|
```python
|
|
119
122
|
import gymnasium as gym
|
|
120
123
|
|
|
121
|
-
from gymcts.gymcts_agent import
|
|
122
|
-
from gymcts.
|
|
124
|
+
from gymcts.gymcts_agent import GymctsAgent
|
|
125
|
+
from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
|
|
123
126
|
|
|
124
127
|
from gymcts.logger import log
|
|
125
128
|
|
|
@@ -133,10 +136,10 @@ if __name__ == '__main__':
|
|
|
133
136
|
env.reset()
|
|
134
137
|
|
|
135
138
|
# 1. wrap the environment with the naive wrapper or a custom gymcts wrapper
|
|
136
|
-
env =
|
|
139
|
+
env = DeepCopyMCTSGymEnvWrapper(env)
|
|
137
140
|
|
|
138
141
|
# 2. create the agent
|
|
139
|
-
agent =
|
|
142
|
+
agent = GymctsAgent(
|
|
140
143
|
env=env,
|
|
141
144
|
clear_mcts_tree_after_step=False,
|
|
142
145
|
render_tree_after_step=True,
|
|
@@ -170,13 +173,13 @@ if __name__ == '__main__':
|
|
|
170
173
|
A minimal example of how to use the package with the FrozenLake environment and the DeterministicSoloMCTSGymEnvWrapper is provided in the following code snippet below.
|
|
171
174
|
The DeterministicSoloMCTSGymEnvWrapper can be used with deterministic environments, such as the FrozenLake environment without slippery ice.
|
|
172
175
|
|
|
173
|
-
The DeterministicSoloMCTSGymEnvWrapper saves the action sequence that lead to the current state in the MCTS node.
|
|
176
|
+
The DeterministicSoloMCTSGymEnvWrapper saves the action sequence that led to the current state in the MCTS node.
|
|
174
177
|
|
|
175
178
|
```python
|
|
176
179
|
import gymnasium as gym
|
|
177
180
|
|
|
178
|
-
from gymcts.gymcts_agent import
|
|
179
|
-
from gymcts.
|
|
181
|
+
from gymcts.gymcts_agent import GymctsAgent
|
|
182
|
+
from gymcts.gymcts_action_history_wrapper import ActionHistoryMCTSGymEnvWrapper
|
|
180
183
|
|
|
181
184
|
from gymcts.logger import log
|
|
182
185
|
|
|
@@ -190,10 +193,10 @@ if __name__ == '__main__':
|
|
|
190
193
|
env.reset()
|
|
191
194
|
|
|
192
195
|
# 1. wrap the environment with the wrapper
|
|
193
|
-
env =
|
|
196
|
+
env = ActionHistoryMCTSGymEnvWrapper(env)
|
|
194
197
|
|
|
195
198
|
# 2. create the agent
|
|
196
|
-
agent =
|
|
199
|
+
agent = GymctsAgent(
|
|
197
200
|
env=env,
|
|
198
201
|
clear_mcts_tree_after_step=False,
|
|
199
202
|
render_tree_after_step=True,
|
|
@@ -232,8 +235,8 @@ To create a video of the solution of the FrozenLake environment, you can use the
|
|
|
232
235
|
```python
|
|
233
236
|
import gymnasium as gym
|
|
234
237
|
|
|
235
|
-
from gymcts.gymcts_agent import
|
|
236
|
-
from gymcts.
|
|
238
|
+
from gymcts.gymcts_agent import GymctsAgent
|
|
239
|
+
from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
|
|
237
240
|
|
|
238
241
|
from gymcts.logger import log
|
|
239
242
|
|
|
@@ -249,10 +252,10 @@ if __name__ == '__main__':
|
|
|
249
252
|
env.reset()
|
|
250
253
|
|
|
251
254
|
# 1. wrap the environment with the naive wrapper or a custom gymcts wrapper
|
|
252
|
-
env =
|
|
255
|
+
env = DeepCopyMCTSGymEnvWrapper(env)
|
|
253
256
|
|
|
254
257
|
# 2. create the agent
|
|
255
|
-
agent =
|
|
258
|
+
agent = GymctsAgent(
|
|
256
259
|
env=env,
|
|
257
260
|
clear_mcts_tree_after_step=False,
|
|
258
261
|
render_tree_after_step=True,
|
|
@@ -413,13 +416,12 @@ The color gradient is based on the minimum and maximum values of the respective
|
|
|
413
416
|
The visualisation is rendered in the terminal and can be limited to a certain depth of the tree.
|
|
414
417
|
The default depth is 2.
|
|
415
418
|
|
|
416
|
-
|
|
417
419
|
```python
|
|
418
420
|
import gymnasium as gym
|
|
419
421
|
|
|
420
|
-
from gymcts.gymcts_agent import
|
|
421
|
-
from gymcts.
|
|
422
|
-
from gymcts.
|
|
422
|
+
from gymcts.gymcts_agent import GymctsAgent
|
|
423
|
+
from gymcts.gymcts_action_history_wrapper import ActionHistoryMCTSGymEnvWrapper
|
|
424
|
+
from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
|
|
423
425
|
|
|
424
426
|
from gymcts.logger import log
|
|
425
427
|
|
|
@@ -433,10 +435,10 @@ if __name__ == '__main__':
|
|
|
433
435
|
env.reset()
|
|
434
436
|
|
|
435
437
|
# wrap the environment with the naive wrapper or a custom gymcts wrapper
|
|
436
|
-
env =
|
|
438
|
+
env = ActionHistoryMCTSGymEnvWrapper(env)
|
|
437
439
|
|
|
438
440
|
# create the agent
|
|
439
|
-
agent =
|
|
441
|
+
agent = GymctsAgent(
|
|
440
442
|
env=env,
|
|
441
443
|
clear_mcts_tree_after_step=False,
|
|
442
444
|
render_tree_after_step=False,
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
gymcts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
gymcts/colorful_console_utils.py,sha256=OhULcXHKbEA4uJDAEYCTcW6wUv0LsHX_XSYzZ_Szsv4,4553
|
|
3
|
+
gymcts/gymcts_action_history_wrapper.py,sha256=AjvBBwd1t9-nTYP09aMdlScAkFNXf5vOagejpjWYOPo,3810
|
|
4
|
+
gymcts/gymcts_agent.py,sha256=O2y98jKFjR5TzqVV7DO1jlcYDyzAgd_H2RF4-w4NP0g,8499
|
|
5
|
+
gymcts/gymcts_deepcopy_wrapper.py,sha256=OleQTnvxv3gLEo8-2asyeo-CpZ4HEbgyFGS5DTCD7NM,4167
|
|
6
|
+
gymcts/gymcts_distributed_agent.py,sha256=M7dyBfC8u3M99PJFoXKgIc_CPTyHGppmktkH-y9ci4U,10448
|
|
7
|
+
gymcts/gymcts_env_abc.py,sha256=7nCRiiClmmVLX-d_Q1dxeztmuvmAtmWZwjT81zrG1_w,575
|
|
8
|
+
gymcts/gymcts_node.py,sha256=PT_YZFwt1zjuvd8i9Wb5LEkHAqmJOFyPDp3GFD05lqM,7138
|
|
9
|
+
gymcts/gymcts_tree_plotter.py,sha256=eg207wHcDepwWODXzmDYQn1Aai29Cs4jFS1HNvAhlXs,2651
|
|
10
|
+
gymcts/logger.py,sha256=nAkUa4djiuCR7hF0EUsplhqFHCp76QcOX1cV3lIPzOI,937
|
|
11
|
+
gymcts-1.2.0.dist-info/licenses/LICENSE,sha256=UGe75WojDiw_77SEnK2aysEDlElRlkWie7U7NaAFx00,1072
|
|
12
|
+
gymcts-1.2.0.dist-info/METADATA,sha256=zhEIFo0rOnv5hCv6ukImkq-9nshO4EfXMbHlhNlYhyA,23640
|
|
13
|
+
gymcts-1.2.0.dist-info/WHEEL,sha256=DK49LOLCYiurdXXOXwGJm6U4DkHkg4lcxjhqwRa0CP4,91
|
|
14
|
+
gymcts-1.2.0.dist-info/top_level.txt,sha256=E8MoLsPimUPD0H1Y6lum4TVe-lhSDAyBAXGrkYIT52w,7
|
|
15
|
+
gymcts-1.2.0.dist-info/RECORD,,
|
gymcts-1.0.0.dist-info/RECORD
DELETED
|
@@ -1,13 +0,0 @@
|
|
|
1
|
-
gymcts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
gymcts/colorful_console_utils.py,sha256=bbZzRFzimhsIhbT-nmz6v62WJLxFDgzFvqI_pmIsckE,4526
|
|
3
|
-
gymcts/gymcts_agent.py,sha256=TJXJH77T95EP3ZNtzWqlGw9iFF1R-nsItp7UA1ZlXUs,10537
|
|
4
|
-
gymcts/gymcts_deterministic_wrapper.py,sha256=PILGPaQnyG2u_2u48MEE3aeJCtdgjjO55ZFDxeIVeH0,3824
|
|
5
|
-
gymcts/gymcts_gym_env.py,sha256=R1Z1fhoywdXmPt_FYgrarIh0YFQvCifayAWnCcEiJKE,580
|
|
6
|
-
gymcts/gymcts_naive_wrapper.py,sha256=qeQ7rzBz7BFv2yCJj3GmdFt5UlTx5VHMw5ImZUl9H5k,4178
|
|
7
|
-
gymcts/gymcts_node.py,sha256=jxdtuC1iqeRtEA-Qfvq-mOuM8vdDl43iWe5hqItG90w,7185
|
|
8
|
-
gymcts/logger.py,sha256=nAkUa4djiuCR7hF0EUsplhqFHCp76QcOX1cV3lIPzOI,937
|
|
9
|
-
gymcts-1.0.0.dist-info/LICENSE,sha256=UGe75WojDiw_77SEnK2aysEDlElRlkWie7U7NaAFx00,1072
|
|
10
|
-
gymcts-1.0.0.dist-info/METADATA,sha256=sAXJQreADqEOviVL8nT8fmrx7hP-qM7C_-SC5FNw-94,23572
|
|
11
|
-
gymcts-1.0.0.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
|
|
12
|
-
gymcts-1.0.0.dist-info/top_level.txt,sha256=E8MoLsPimUPD0H1Y6lum4TVe-lhSDAyBAXGrkYIT52w,7
|
|
13
|
-
gymcts-1.0.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|