gymcts 1.0.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gymcts/colorful_console_utils.py +26 -3
- gymcts/{gymcts_deterministic_wrapper.py → gymcts_action_history_wrapper.py} +74 -4
- gymcts/gymcts_agent.py +29 -69
- gymcts/{gymcts_naive_wrapper.py → gymcts_deepcopy_wrapper.py} +60 -3
- gymcts/gymcts_distributed_agent.py +299 -0
- gymcts/gymcts_env_abc.py +61 -0
- gymcts/gymcts_node.py +107 -44
- gymcts/gymcts_tree_plotter.py +96 -0
- gymcts/logger.py +1 -4
- {gymcts-1.0.0.dist-info → gymcts-1.2.1.dist-info}/METADATA +54 -56
- gymcts-1.2.1.dist-info/RECORD +15 -0
- {gymcts-1.0.0.dist-info → gymcts-1.2.1.dist-info}/WHEEL +1 -1
- gymcts/gymcts_gym_env.py +0 -28
- gymcts-1.0.0.dist-info/RECORD +0 -13
- {gymcts-1.0.0.dist-info → gymcts-1.2.1.dist-info/licenses}/LICENSE +0 -0
- {gymcts-1.0.0.dist-info → gymcts-1.2.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,299 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import gymnasium as gym
|
|
3
|
+
|
|
4
|
+
from typing import TypeVar, Any, SupportsFloat, Callable
|
|
5
|
+
|
|
6
|
+
from ray.types import ObjectRef
|
|
7
|
+
|
|
8
|
+
from gymcts.gymcts_agent import GymctsAgent
|
|
9
|
+
from gymcts.gymcts_env_abc import GymctsABC
|
|
10
|
+
from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
|
|
11
|
+
from gymcts.gymcts_node import GymctsNode
|
|
12
|
+
from gymcts.gymcts_tree_plotter import _generate_mcts_tree
|
|
13
|
+
|
|
14
|
+
from gymcts.logger import log
|
|
15
|
+
|
|
16
|
+
import ray
|
|
17
|
+
import copy
|
|
18
|
+
|
|
19
|
+
# NOTE(review): this TypeVar is not referenced anywhere in the visible part of this module, and its
# bound names "SoloMCTSNode" while the node class used here is GymctsNode — presumably a leftover
# from an earlier naming; confirm nothing imports it before removing.
TSoloMCTSNode = TypeVar("TSoloMCTSNode", bound="SoloMCTSNode")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@ray.remote
def mcts_lookahead(
        gymcts_start_node: GymctsNode,
        env: GymctsABC,
        num_simulations: int) -> GymctsNode:
    """
    Ray task that runs one independent MCTS search and hands back the searched tree.

    A throw-away ``GymctsAgent`` is created around ``env`` (tree clearing disabled), its search root
    is pointed at ``gymcts_start_node`` and ``vanilla_mcts_search`` is run from that node.

    :param gymcts_start_node: node the search starts from; it is installed as the worker's search root.
    :param env: environment the worker agent operates on.
    :param num_simulations: number of simulations passed to the agent and to the search call.
    :return: the worker agent's ``search_root_node`` after the search has finished.
    """
    worker = GymctsAgent(
        env=env,
        clear_mcts_tree_after_step=False,
        number_of_simulations_per_step=num_simulations,
    )
    worker.search_root_node = gymcts_start_node

    worker.vanilla_mcts_search(search_start_node=gymcts_start_node, num_simulations=num_simulations)

    # Return whatever the agent considers its search root once the search is done.
    return worker.search_root_node
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def merge_nodes(gymcts_node1, gymcts_node2, perform_state_equality_check=False):
    """
    Merge two MCTS (sub)trees into one and return the merged node.

    Rules, applied in this order:

    * both nodes ``None``: a ``ValueError`` is raised;
    * exactly one node ``None``: the other node is returned unchanged;
    * both nodes are leaves: the first node is returned unchanged;
    * exactly one node is a leaf: the non-leaf node is returned (when that is the second node,
      its ``parent`` is re-pointed to the first node's parent);
    * both nodes are expanded: children are merged recursively per action into the first node,
      visit counts are added, the mean value becomes the visit-weighted average and max/min values
      are combined. The first node is mutated in place and returned.

    :param gymcts_node1: first node (the merge target), or ``None``.
    :param gymcts_node2: second node, or ``None``.
    :param perform_state_equality_check: if true, compare ``state`` of both nodes with ``!=`` and
        raise when they differ. The flag is passed on to the recursive child merges.
    :return: the merged node (one of the two input objects).
    :raises ValueError: if both nodes are ``None``, if the states differ (only when the check is
        enabled), or if two expanded nodes do not have the same set of child actions.
    """
    log.debug(f"merging {gymcts_node1} and {gymcts_node2}")

    # Handle None before touching any attribute. The "both None" case has to be tested first,
    # otherwise the single-None branches would swallow it and return None.
    if gymcts_node1 is None and gymcts_node2 is None:
        log.error("Both nodes are None")
        raise ValueError("Both nodes are None")
    if gymcts_node1 is None:
        log.debug(f"first node is None, returning second node ({gymcts_node2})")
        return gymcts_node2
    if gymcts_node2 is None:
        log.debug(f"second node is None, returning first node ({gymcts_node1})")
        return gymcts_node1

    if perform_state_equality_check:
        # NOTE(review): this relies on `!=` producing a plain bool; array-like states would need an
        # element-wise comparison instead — confirm which state types are used with this flag.
        if gymcts_node1.state != gymcts_node2.state:
            raise ValueError("States are different")

    first_is_leaf = gymcts_node1.is_leaf()
    second_is_leaf = gymcts_node2.is_leaf()

    if first_is_leaf and second_is_leaf:
        log.debug(f"both nodes are leafs, returning first node")
        log.debug(f"returning first node: {gymcts_node1}")
        return gymcts_node1

    if first_is_leaf:
        log.debug(f"first node is leaf, second node is not leaf")
        # The expanded subtree takes the place of the leaf, so it has to hang below the leaf's parent.
        gymcts_node2.parent = gymcts_node1.parent
        log.debug(f"returning second node: {gymcts_node2}")
        return gymcts_node2

    if second_is_leaf:
        log.debug(f"second node is leaf, first node is not leaf")
        log.debug(f"returning first node: {gymcts_node1}")
        return gymcts_node1

    # Both nodes are expanded: they must offer exactly the same actions.
    if gymcts_node1.children.keys() != gymcts_node2.children.keys():
        log.error("Nodes have different children")
        raise ValueError("Nodes have different children")

    # Pair the children by action (dict lookup) rather than by position: two dicts with equal key
    # sets may still have been filled in a different order.
    for action, child1 in gymcts_node1.children.items():
        log.debug(f"merging children with action {action} for node {gymcts_node1}")
        gymcts_node1.children[action] = merge_nodes(
            child1,
            gymcts_node2.children[action],
            perform_state_equality_check=perform_state_equality_check
        )

    visit_count = gymcts_node1.visit_count + gymcts_node2.visit_count
    if visit_count:
        mean_value = (
            gymcts_node1.mean_value * gymcts_node1.visit_count
            + gymcts_node2.mean_value * gymcts_node2.visit_count
        ) / visit_count
    else:
        # Neither node has been visited: there is nothing to average, keep the first node's value.
        mean_value = gymcts_node1.mean_value
    max_value = max(gymcts_node1.max_value, gymcts_node2.max_value)
    min_value = min(gymcts_node1.min_value, gymcts_node2.min_value)

    gymcts_node1.visit_count = visit_count
    gymcts_node1.mean_value = mean_value
    gymcts_node1.max_value = max_value
    gymcts_node1.min_value = min_value
    log.debug(f"merged node: {gymcts_node1}")
    log.debug(f"returning node: {gymcts_node1}")
    return gymcts_node1
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class DistributedGymctsAgent:
    """
    MCTS agent that runs several independent searches in parallel as Ray tasks and merges them.

    For every step, ``num_parallel`` ``mcts_lookahead`` tasks are started on deep copies of the
    current search node and of the environment. The returned trees are merged into the current
    search node with ``merge_nodes`` and the best action of the merged node is taken.
    """

    # Class-level defaults; every instance overwrites them in __init__.
    render_tree_after_step: bool = False
    render_tree_max_depth: int = 2
    # NOTE(review): stored on the instance but not read by any method of this class — confirm
    # whether it is meant to be forwarded to the tree plotter.
    exclude_unvisited_nodes_from_render: bool = False
    number_of_simulations_per_step: int = 25
    num_parallel: int = 4

    env: GymctsABC
    search_root_node: GymctsNode  # NOTE: this is not the same as the root of the tree!
    clear_mcts_tree_after_step: bool

    def __init__(self,
                 env: GymctsABC,
                 render_tree_after_step: bool = False,
                 render_tree_max_depth: int = 2,
                 num_parallel: int = 4,
                 clear_mcts_tree_after_step: bool = False,
                 number_of_simulations_per_step: int = 25,
                 exclude_unvisited_nodes_from_render: bool = False
                 ):
        """
        Create the agent and a fresh search root node for ``env``.

        :param env: the environment to plan in; its action space must be ``gym.spaces.Discrete``.
        :param render_tree_after_step: default for printing the tree after every step.
        :param render_tree_max_depth: default depth used when printing the tree.
        :param num_parallel: default number of parallel lookahead tasks per step.
        :param clear_mcts_tree_after_step: if true, the chosen child is reset after every step and
            worker results are not backpropagated towards the ancestors of the search node.
        :param number_of_simulations_per_step: default number of simulations each task runs.
        :param exclude_unvisited_nodes_from_render: stored on the instance.
        :raises ValueError: if the action space of ``env`` is not discrete.
        """
        # check if action space of env is discrete
        if not isinstance(env.action_space, gym.spaces.Discrete):
            raise ValueError("Action space must be discrete.")

        self.num_parallel = num_parallel

        self.render_tree_after_step = render_tree_after_step
        self.exclude_unvisited_nodes_from_render = exclude_unvisited_nodes_from_render
        self.render_tree_max_depth = render_tree_max_depth

        self.number_of_simulations_per_step = number_of_simulations_per_step

        self.env = env
        self.clear_mcts_tree_after_step = clear_mcts_tree_after_step

        self.search_root_node = GymctsNode(
            action=None,
            parent=None,
            env_reference=env,
        )

    def solve(self, num_simulations_per_step: int | None = None,
              render_tree_after_step: bool | None = None) -> list[int]:
        """
        Perform MCTS steps from the current search root until a terminal node is reached.

        :param num_simulations_per_step: simulations per lookahead task; ``None`` uses the
            instance default.
        :param render_tree_after_step: whether to print the tree after every step; ``None`` uses
            the instance default.
        :return: the list of selected actions, in order.
        """
        if num_simulations_per_step is None:
            num_simulations_per_step = self.number_of_simulations_per_step
        if render_tree_after_step is None:
            render_tree_after_step = self.render_tree_after_step

        log.debug(f"Solving from root node: {self.search_root_node}")

        current_node = self.search_root_node
        action_list = []

        while not current_node.terminal:
            # perform_mcts_step advances self.search_root_node, so the next iteration starts
            # from the node that was just selected.
            next_action, current_node = self.perform_mcts_step(
                num_simulations=num_simulations_per_step,
                render_tree_after_step=render_tree_after_step,
            )
            log.info(
                f"selected action {next_action} after {self.num_parallel} x {num_simulations_per_step} simulations.")
            action_list.append(next_action)
            log.info(f"current action list: {action_list}")

        log.info(f"Final action list: {action_list}")
        return action_list

    def perform_mcts_step(self, search_start_node: GymctsNode | None = None, num_simulations: int | None = None,
                          render_tree_after_step: bool | None = None,
                          num_parallel: int | None = None) -> tuple[int, GymctsNode]:
        """
        Run one distributed search step and move the search root to the selected child.

        :param search_start_node: node to search from; ``None`` uses ``self.search_root_node``.
        :param num_simulations: simulations per lookahead task; ``None`` uses the instance default.
        :param render_tree_after_step: whether to print the tree after the step; ``None`` uses the
            instance default.
        :param num_parallel: number of lookahead tasks to start; ``None`` uses the instance default.
        :return: the selected action and the child node reached by it.
        """
        if render_tree_after_step is None:
            render_tree_after_step = self.render_tree_after_step

        if num_simulations is None:
            num_simulations = self.number_of_simulations_per_step

        if search_start_node is None:
            search_start_node = self.search_root_node

        if num_parallel is None:
            num_parallel = self.num_parallel

        # Every task receives its own deep copies of the start node and of the environment.
        mcts_interation_futures = [
            mcts_lookahead.remote(
                copy.deepcopy(search_start_node),
                copy.deepcopy(self.env),
                num_simulations=num_simulations
            )
            for _ in range(num_parallel)
        ]

        # Consume the results as they become available.
        while mcts_interation_futures:
            ready_gymcts_nodes, mcts_interation_futures = ray.wait(mcts_interation_futures)
            for ready_node_ref in ready_gymcts_nodes:
                ready_node = ray.get(ready_node_ref)

                # merge the tree
                # NOTE(review): backpropagation already adds the worker root's statistics to
                # search_start_node, and merge_nodes combines the two roots' statistics again when
                # both are expanded — confirm this accounting is intended.
                if not self.clear_mcts_tree_after_step:
                    self.backpropagation(search_start_node, ready_node.mean_value, ready_node.visit_count)
                search_start_node = merge_nodes(search_start_node, ready_node)

        action = search_start_node.get_best_action()
        next_node = search_start_node.children[action]

        # Honour the per-call argument (already defaulted from the instance above).
        if render_tree_after_step:
            self.show_mcts_tree(
                start_node=search_start_node,
                tree_max_depth=self.render_tree_max_depth
            )

        if self.clear_mcts_tree_after_step:
            # to clear memory we need to remove all nodes except the current node
            # this is done by calling the reset method of the selected child
            next_node.reset()

        self.search_root_node = next_node

        return action, next_node

    def backpropagation(self, node: GymctsNode, average_episode_return: float, num_episodes: int) -> None:
        """
        Fold a batch of episodes into ``node`` and all of its ancestors up to and including the root.

        For every node on the path the mean value becomes the visit-weighted average of the old mean
        and ``average_episode_return``, the visit count grows by ``num_episodes`` and the max/min
        values are updated with ``average_episode_return``.

        :param node: node to start from.
        :param average_episode_return: mean return of the batch.
        :param num_episodes: number of episodes in the batch.
        """
        log.debug(f"performing backpropagation from leaf node: {node}")
        while True:
            node.mean_value = (node.mean_value * node.visit_count + average_episode_return * num_episodes) / (
                    node.visit_count + num_episodes)
            node.visit_count += num_episodes
            node.max_value = max(node.max_value, average_episode_return)
            node.min_value = min(node.min_value, average_episode_return)
            # The root is updated as well; stop once it has been processed.
            if node.is_root():
                break
            node = node.parent

    def show_mcts_tree(self, start_node: GymctsNode | None = None, tree_max_depth: int | None = None) -> None:
        """
        Print the tree below ``start_node`` to stdout.

        :param start_node: node to print from; ``None`` uses ``self.search_root_node``.
        :param tree_max_depth: depth passed to the tree generator; ``None`` uses
            ``self.render_tree_max_depth``.
        """
        if start_node is None:
            start_node = self.search_root_node

        if tree_max_depth is None:
            tree_max_depth = self.render_tree_max_depth

        n_actions = self.env.action_space.n
        print(start_node.__str__(colored=True, action_space_n=n_actions))
        for line in _generate_mcts_tree(
                start_node=start_node,
                depth=tree_max_depth,
                action_space_n=n_actions
        ):
            print(line)

    def show_mcts_tree_from_root(self, tree_max_depth: int | None = None) -> None:
        """
        Print the tree starting at the root of the tree that contains the current search root.

        :param tree_max_depth: depth passed on to ``show_mcts_tree``.
        """
        self.show_mcts_tree(start_node=self.search_root_node.get_root(), tree_max_depth=tree_max_depth)
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
if __name__ == '__main__':
    # Small demo: solve the deterministic 4x4 FrozenLake with the distributed agent.
    import time

    ray.init()

    log.setLevel(20)  # logging levels: 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL
    env = gym.make('FrozenLake-v1', desc=None, map_name="4x4", is_slippery=False)
    env.reset()

    # Step 1: make the environment usable for MCTS by wrapping it
    # (the deep-copy wrapper here; a custom gymcts wrapper would work as well).
    env = DeepCopyMCTSGymEnvWrapper(env)

    # Step 2: build the agent.
    demo_agent = DistributedGymctsAgent(
        env=env,
        render_tree_after_step=True,
        number_of_simulations_per_step=10,
        exclude_unvisited_nodes_from_render=True,
        num_parallel=1,
    )

    # Step 3: solve and measure the wall-clock time of the whole run.
    start_time = time.perf_counter()
    actions = demo_agent.solve()
    end_time = time.perf_counter()

    demo_agent.show_mcts_tree_from_root()

    # Prints total time and number of actions side by side (no division is performed).
    print(f"solution time pro action: {end_time - start_time}/{len(actions)}")
|
gymcts/gymcts_env_abc.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from typing import TypeVar, Any, SupportsFloat, Callable
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
import gymnasium as gym
|
|
4
|
+
|
|
5
|
+
# NOTE(review): this TypeVar is not used by anything visible in this module and is bound to the
# name "SoloMCTSNode", which is not defined here — presumably a leftover; confirm before removing.
TSoloMCTSNode = TypeVar("TSoloMCTSNode", bound="SoloMCTSNode")
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class GymctsABC(ABC, gym.Env):
    """Abstract interface an environment has to implement to be usable by the MCTS agents."""

    @abstractmethod
    def get_state(self) -> Any:
        """
        Return the current state of the environment.

        The state may be of any datatype, as long as it allows the environment to be restored to
        exactly this state later on via the load_state method.

        Using a numpy array is recommended where possible, because it is easy to serialize and
        deserialize.

        :return: the current state of the environment
        """
        ...

    @abstractmethod
    def load_state(self, state: Any) -> None:
        """
        Restore the environment to the given state.

        The state may be of any datatype, as long as it allows the environment to be restored to
        the same state it was captured in.

        :param state: the state to load
        :return: None
        """
        ...

    @abstractmethod
    def is_terminal(self) -> bool:
        """
        Tell whether the environment is currently in a terminal state.

        :return: True if the environment is in a terminal state, False otherwise
        """
        ...

    @abstractmethod
    def get_valid_actions(self) -> list[int]:
        """
        Return the valid actions for the current state of the environment.

        These are used to obtain the potential actions/subsequent states for the MCTS tree.

        :return: the list of valid actions
        """
        ...

    @abstractmethod
    def rollout(self) -> float:
        """
        Perform a rollout from the current state and return its return (sum of rewards).

        Please make sure the returned value lies in the interval [-1, 1]. Otherwise the MCTS
        algorithm will not work as expected, because the exploration coefficient is then
        ill-fitted and exploration and exploitation are no longer well-balanced.

        :return: the return of the rollout
        """
        ...
|
gymcts/gymcts_node.py
CHANGED
|
@@ -4,32 +4,41 @@ import math
|
|
|
4
4
|
|
|
5
5
|
from typing import TypeVar, Any, SupportsFloat, Callable, Generator
|
|
6
6
|
|
|
7
|
-
from gymcts.
|
|
7
|
+
from gymcts.gymcts_env_abc import GymctsABC
|
|
8
8
|
|
|
9
9
|
from gymcts.logger import log
|
|
10
10
|
|
|
11
|
-
|
|
11
|
+
TGymctsNode = TypeVar("TGymctsNode", bound="GymctsNode")
|
|
12
12
|
|
|
13
13
|
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
class SoloMCTSNode:
|
|
17
|
-
|
|
14
|
+
class GymctsNode:
|
|
18
15
|
# static properties
|
|
19
|
-
best_action_weight: float = 0.05
|
|
20
|
-
ubc_c
|
|
16
|
+
best_action_weight: float = 0.05 # weight for the best action
|
|
17
|
+
ubc_c = 0.707 # exploration coefficient
|
|
21
18
|
|
|
22
19
|
|
|
23
|
-
# attributes
|
|
24
|
-
visit_count: int = 0
|
|
25
|
-
mean_value: float = 0
|
|
26
|
-
max_value: float = -float("inf")
|
|
27
|
-
min_value: float = +float("inf")
|
|
28
|
-
terminal: bool = False
|
|
29
|
-
state: Any
|
|
30
20
|
|
|
21
|
+
# attributes
|
|
22
|
+
#
|
|
23
|
+
# Note: these attributes are not static. They are defined here to give developers a hint about which fields are available
|
|
24
|
+
# in the class. They are not static because they are not shared between instances of the class in scope of
|
|
25
|
+
# this library.
|
|
26
|
+
visit_count: int = 0 # number of times the node has been visited
|
|
27
|
+
mean_value: float = 0 # mean value of the node
|
|
28
|
+
max_value: float = -float("inf") # maximum value of the node
|
|
29
|
+
min_value: float = +float("inf") # minimum value of the node
|
|
30
|
+
terminal: bool = False # whether the node is terminal or not
|
|
31
|
+
state: Any = None # state of the node
|
|
31
32
|
|
|
32
33
|
def __str__(self, colored=False, action_space_n=None) -> str:
|
|
34
|
+
"""
|
|
35
|
+
Returns a string representation of the node. The string representation is used for visualisation purposes.
|
|
36
|
+
It is used for example in the mcts tree visualisation functionality.
|
|
37
|
+
|
|
38
|
+
:param colored: true if the string representation should be colored, false otherwise. (true is used by the mcts tree visualisation)
|
|
39
|
+
:param action_space_n: the number of actions in the action space. This is used for coloring the action in the string representation.
|
|
40
|
+
:return: a potentially colored string representation of the node.
|
|
41
|
+
"""
|
|
33
42
|
if not colored:
|
|
34
43
|
|
|
35
44
|
if not self.is_root():
|
|
@@ -39,11 +48,9 @@ class SoloMCTSNode:
|
|
|
39
48
|
|
|
40
49
|
import gymcts.colorful_console_utils as ccu
|
|
41
50
|
|
|
42
|
-
|
|
43
51
|
if self.is_root():
|
|
44
52
|
return f"({ccu.CYELLOW}N{ccu.CEND}={self.visit_count}, {ccu.CYELLOW}Q_v{ccu.CEND}={self.mean_value:.2f}, {ccu.CYELLOW}best{ccu.CEND}={self.max_value:.2f})"
|
|
45
53
|
|
|
46
|
-
|
|
47
54
|
if action_space_n is None:
|
|
48
55
|
raise ValueError("action_space_n must be provided if colored is True")
|
|
49
56
|
|
|
@@ -68,52 +75,78 @@ class SoloMCTSNode:
|
|
|
68
75
|
if isinstance(value, int):
|
|
69
76
|
return f"{color}{value}{e}"
|
|
70
77
|
|
|
71
|
-
|
|
72
78
|
root_node = self.get_root()
|
|
73
79
|
mean_val = f"{self.mean_value:.2f}"
|
|
74
80
|
|
|
75
81
|
return ((f"("
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
82
|
+
f"{p}a{e}={ccu.wrap_evenly_spaced_color(s=self.action, n_of_item=self.action, n_classes=action_space_n)}, "
|
|
83
|
+
f"{p}N{e}={colorful_value(self.visit_count)}, "
|
|
84
|
+
f"{p}Q_v{e}={ccu.wrap_with_color_scale(s=mean_val, value=self.mean_value, min_val=root_node.min_value, max_val=root_node.max_value)}, "
|
|
85
|
+
f"{p}best{e}={colorful_value(self.max_value)}") +
|
|
80
86
|
(f", {p}ubc{e}={colorful_value(self.ucb_score())})" if not self.is_root() else ")"))
|
|
81
87
|
|
|
88
|
+
def traverse_nodes(self) -> Generator[TGymctsNode, None, None]:
|
|
89
|
+
"""
|
|
90
|
+
Traverse the tree and yield all nodes in the tree.
|
|
82
91
|
|
|
83
|
-
|
|
92
|
+
:return: a generator that yields all nodes in the tree.
|
|
93
|
+
"""
|
|
84
94
|
yield self
|
|
85
95
|
if self.children:
|
|
86
96
|
for child in self.children.values():
|
|
87
97
|
yield from child.traverse_nodes()
|
|
88
98
|
|
|
89
|
-
def get_root(self) ->
|
|
99
|
+
def get_root(self) -> TGymctsNode:
|
|
100
|
+
"""
|
|
101
|
+
Returns the root node of the tree. The root node is the node that has no parent.
|
|
102
|
+
|
|
103
|
+
:return: the root node of the tree.
|
|
104
|
+
"""
|
|
90
105
|
if self.is_root():
|
|
91
106
|
return self
|
|
92
107
|
return self.parent.get_root()
|
|
93
108
|
|
|
94
109
|
def max_tree_depth(self):
|
|
110
|
+
"""
|
|
111
|
+
Returns the maximum depth of the tree. The depth of a node is the number of edges from
|
|
112
|
+
the node to the root node.
|
|
113
|
+
|
|
114
|
+
:return: the maximum depth of the tree.
|
|
115
|
+
"""
|
|
95
116
|
if self.is_leaf():
|
|
96
117
|
return 0
|
|
97
118
|
return 1 + max(child.max_tree_depth() for child in self.children.values())
|
|
98
119
|
|
|
99
120
|
def n_children_recursively(self):
|
|
121
|
+
"""
|
|
122
|
+
Returns the number of children of the node recursively. The number of children of a node is the number of
|
|
123
|
+
children of the node plus the number of children of all children of the node.
|
|
124
|
+
|
|
125
|
+
:return: the number of children of the node recursively.
|
|
126
|
+
"""
|
|
100
127
|
if self.is_leaf():
|
|
101
128
|
return 0
|
|
102
129
|
return len(self.children) + sum(child.n_children_recursively() for child in self.children.values())
|
|
103
130
|
|
|
104
|
-
|
|
105
|
-
|
|
106
131
|
def __init__(self,
|
|
107
132
|
action: int | None,
|
|
108
|
-
parent:
|
|
109
|
-
env_reference:
|
|
133
|
+
parent: TGymctsNode | None,
|
|
134
|
+
env_reference: GymctsABC,
|
|
110
135
|
):
|
|
136
|
+
"""
|
|
137
|
+
Initializes the node. The node is initialized with the state of the environment and the action that was taken to
|
|
138
|
+
reach the node. The node is also initialized with the parent node and the environment reference.
|
|
139
|
+
|
|
140
|
+
:param action: the action that was taken to reach the node. If the node is a root node, this parameter is None.
|
|
141
|
+
:param parent: the parent node of the node. If the node is a root node, this parameter is None.
|
|
142
|
+
:param env_reference: a reference to the environment. The environment is used to get the state of the node and the valid actions.
|
|
143
|
+
"""
|
|
111
144
|
|
|
112
145
|
# field depending on whether the node is a root node or not
|
|
113
146
|
self.action: int | None
|
|
114
147
|
|
|
115
|
-
self.env_reference:
|
|
116
|
-
self.parent:
|
|
148
|
+
self.env_reference: GymctsABC
|
|
149
|
+
self.parent: GymctsNode | None
|
|
117
150
|
self.uuid = uuid.uuid4()
|
|
118
151
|
|
|
119
152
|
if parent is None:
|
|
@@ -133,7 +166,7 @@ class SoloMCTSNode:
|
|
|
133
166
|
|
|
134
167
|
from copy import copy
|
|
135
168
|
self.state = env_reference.get_state()
|
|
136
|
-
#log.debug(f"saving state of node '{str(self)}' to memory location: {hex(id(self.state))}")
|
|
169
|
+
# log.debug(f"saving state of node '{str(self)}' to memory location: {hex(id(self.state))}")
|
|
137
170
|
self.visit_count: int = 0
|
|
138
171
|
|
|
139
172
|
self.mean_value: float = 0
|
|
@@ -143,8 +176,7 @@ class SoloMCTSNode:
|
|
|
143
176
|
# safe valid action instead of calling the environment
|
|
144
177
|
# this reduces the compute but increases the memory usage
|
|
145
178
|
self.valid_actions: list[int] = env_reference.get_valid_actions()
|
|
146
|
-
self.children: dict[int,
|
|
147
|
-
|
|
179
|
+
self.children: dict[int, GymctsNode] | None = None # may be expanded later
|
|
148
180
|
|
|
149
181
|
def reset(self) -> None:
|
|
150
182
|
self.parent = None
|
|
@@ -153,40 +185,71 @@ class SoloMCTSNode:
|
|
|
153
185
|
self.mean_value: float = 0
|
|
154
186
|
self.max_value: float = -float("inf")
|
|
155
187
|
self.min_value: float = +float("inf")
|
|
156
|
-
self.children: dict[int,
|
|
188
|
+
self.children: dict[int, GymctsNode] | None = None # may be expanded later
|
|
157
189
|
|
|
158
190
|
# just setting the children of the parent node to None should be enough to trigger garbage collection
|
|
159
191
|
# however, we also set the parent to None to make sure that the parent is not referenced anymore
|
|
160
192
|
if self.parent:
|
|
161
193
|
self.parent.reset()
|
|
162
194
|
|
|
163
|
-
|
|
164
|
-
|
|
165
195
|
def is_root(self) -> bool:
|
|
196
|
+
"""
|
|
197
|
+
Returns true if the node is a root node. A root node is a node that has no parent.
|
|
198
|
+
|
|
199
|
+
:return: true if the node is a root node, false otherwise.
|
|
200
|
+
"""
|
|
166
201
|
return self.parent is None
|
|
167
202
|
|
|
168
203
|
def is_leaf(self) -> bool:
|
|
204
|
+
"""
|
|
205
|
+
Returns true if the node is a leaf node. A leaf node is a node that has no children.
|
|
206
|
+
|
|
207
|
+
:return: true if the node is a leaf node, false otherwise.
|
|
208
|
+
"""
|
|
169
209
|
return self.children is None or len(self.children) == 0
|
|
170
210
|
|
|
171
|
-
def get_random_child(self) ->
|
|
211
|
+
def get_random_child(self) -> TGymctsNode:
|
|
212
|
+
"""
|
|
213
|
+
Returns a random child of the node. A random child is a child that is selected randomly from the list of children.
|
|
214
|
+
:return:
|
|
215
|
+
"""
|
|
172
216
|
if self.is_leaf():
|
|
173
|
-
raise ValueError("cannot get random child of leaf node")
|
|
217
|
+
raise ValueError("cannot get random child of leaf node") # todo: maybe return self instead?
|
|
174
218
|
|
|
175
219
|
return list(self.children.values())[random.randint(0, len(self.children) - 1)]
|
|
176
220
|
|
|
177
221
|
def get_best_action(self) -> int:
|
|
222
|
+
"""
|
|
223
|
+
Returns the best action of the node. The best action is the action that has the highest score.
|
|
224
|
+
The score is calculated using the get_score() method. The best action is the action that has the highest score.
|
|
225
|
+
The best action is the action that has the highest score.
|
|
226
|
+
|
|
227
|
+
:return: the best action of the node.
|
|
228
|
+
"""
|
|
178
229
|
return max(self.children.values(), key=lambda child: child.get_score()).action
|
|
179
230
|
|
|
180
|
-
def get_score(self) -> float:
|
|
231
|
+
def get_score(self) -> float: # todo: make it an attribute?
|
|
232
|
+
"""
|
|
233
|
+
Returns the score of the node. The score is calculated using the mean value and the maximum value of the node.
|
|
234
|
+
The score is calculated using the formula: score = (1 - a) * mean_value + a * max_value
|
|
235
|
+
where a is the best action weight.
|
|
236
|
+
|
|
237
|
+
:return: the score of the node.
|
|
238
|
+
"""
|
|
181
239
|
# return self.mean_value
|
|
182
|
-
assert 0 <=
|
|
183
|
-
a =
|
|
184
|
-
return (1-a) * self.mean_value + a * self.max_value
|
|
240
|
+
assert 0 <= GymctsNode.best_action_weight <= 1
|
|
241
|
+
a = GymctsNode.best_action_weight
|
|
242
|
+
return (1 - a) * self.mean_value + a * self.max_value
|
|
185
243
|
|
|
186
244
|
def get_mean_value(self) -> float:
|
|
187
245
|
return self.mean_value
|
|
188
246
|
|
|
189
247
|
def get_max_value(self) -> float:
|
|
248
|
+
"""
|
|
249
|
+
Returns the maximum value of the node. The maximum value is the maximum value of the node.
|
|
250
|
+
|
|
251
|
+
:return: the maximum value of the node.
|
|
252
|
+
"""
|
|
190
253
|
return self.max_value
|
|
191
254
|
|
|
192
255
|
def ucb_score(self):
|
|
@@ -207,7 +270,7 @@ class SoloMCTSNode:
|
|
|
207
270
|
if self.is_root():
|
|
208
271
|
raise ValueError("ucb_score can only be called on non-root nodes")
|
|
209
272
|
# c = 0.707 # todo: make it an attribute?
|
|
210
|
-
c =
|
|
273
|
+
c = GymctsNode.ubc_c
|
|
211
274
|
if self.visit_count == 0:
|
|
212
275
|
return float("inf")
|
|
213
|
-
return self.mean_value + c * math.sqrt(math.log(self.parent.visit_count) / (self.visit_count))
|
|
276
|
+
return self.mean_value + c * math.sqrt(math.log(self.parent.visit_count) / (self.visit_count))
|