gymcts 1.0.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,299 @@
1
+ import copy
2
+ import gymnasium as gym
3
+
4
+ from typing import TypeVar, Any, SupportsFloat, Callable
5
+
6
+ from ray.types import ObjectRef
7
+
8
+ from gymcts.gymcts_agent import GymctsAgent
9
+ from gymcts.gymcts_env_abc import GymctsABC
10
+ from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
11
+ from gymcts.gymcts_node import GymctsNode
12
+ from gymcts.gymcts_tree_plotter import _generate_mcts_tree
13
+
14
+ from gymcts.logger import log
15
+
16
+ import ray
17
+ import copy
18
+
19
# Type variable for "a node or a subclass of it". The historical name is kept so existing imports keep
# working; the bound points at GymctsNode because the former SoloMCTSNode class was renamed to GymctsNode.
TSoloMCTSNode = TypeVar("TSoloMCTSNode", bound="GymctsNode")
20
+
21
+
22
@ray.remote
def mcts_lookahead(
        gymcts_start_node: GymctsNode,
        env: GymctsABC,
        num_simulations: int) -> GymctsNode:
    """
    Ray task that runs one vanilla MCTS search starting at ``gymcts_start_node``.

    A fresh ``GymctsAgent`` is built around ``env`` (configured to keep its tree after a step), the given
    node is installed as the agent's search root and ``num_simulations`` simulations are performed from it.

    :param gymcts_start_node: the node the search starts from; it becomes the agent's search root node.
    :param env: the environment the agent is created with.
    :param num_simulations: number of simulations; used both as the agent's per-step budget and as the
                            budget of the single search performed here.
    :return: the agent's search root node after the search has finished.
    """
    worker_agent = GymctsAgent(
        env=env,
        clear_mcts_tree_after_step=False,
        number_of_simulations_per_step=num_simulations,
    )
    worker_agent.search_root_node = gymcts_start_node

    search_arguments = {
        "search_start_node": gymcts_start_node,
        "num_simulations": num_simulations,
    }
    worker_agent.vanilla_mcts_search(**search_arguments)

    return worker_agent.search_root_node
39
+
40
+
41
def merge_nodes(gymcts_node1, gymcts_node2, perform_state_equality_check=False):
    """
    Merges two MCTS (sub)trees that describe the same position and returns the merged node.

    Rules, in the order they are applied:

    * if exactly one node is ``None`` the other node is returned unchanged,
    * if the second node is a leaf (has no children) the first node is returned unchanged,
    * if only the first node is a leaf, the second node takes its place: it gets the parent of the first
      node and is returned,
    * otherwise both nodes are expanded: their children are merged recursively (matched by action) and the
      statistics of the second node are folded into the first node, which is returned.

    Note that in the leaf cases the statistics of the discarded node are dropped, not summed.

    :param gymcts_node1: first node (updated in place when both nodes are expanded), or ``None``.
    :param gymcts_node2: second node, or ``None``.
    :param perform_state_equality_check: if true, the ``state`` attributes of both nodes are compared
                                         before merging.
    :return: the merged node (``gymcts_node1`` or ``gymcts_node2``, see rules above).
    :raises ValueError: if both nodes are ``None``, if the states differ (only with
                        ``perform_state_equality_check``), or if the nodes have different sets of actions.
    """
    log.debug(f"merging {gymcts_node1} and {gymcts_node2}")

    # None handling has to happen before any attribute of the nodes is touched. The "both None" case has to be
    # tested first, otherwise it is shadowed by the single-None checks and None is returned silently.
    if gymcts_node1 is None and gymcts_node2 is None:
        log.error("Both nodes are None")
        raise ValueError("Both nodes are None")
    if gymcts_node1 is None:
        log.debug(f"first node is None, returning second node ({gymcts_node2})")
        return gymcts_node2
    if gymcts_node2 is None:
        log.debug(f"second node is None, returning first node ({gymcts_node1})")
        return gymcts_node1

    if perform_state_equality_check:
        states_differ = gymcts_node1.state != gymcts_node2.state
        # array-like states (e.g. numpy arrays) compare element-wise; reduce the result to a single truth value
        # instead of failing with an "ambiguous truth value" error.
        if hasattr(states_differ, "any"):
            states_differ = states_differ.any()
        if states_differ:
            raise ValueError("States are different")

    first_is_leaf = gymcts_node1.is_leaf()
    second_is_leaf = gymcts_node2.is_leaf()

    if first_is_leaf and not second_is_leaf:
        log.debug(f"first node is leaf, second node is not leaf")
        # the second node replaces the first one, so it has to hang below the first node's parent
        gymcts_node2.parent = gymcts_node1.parent
        log.debug(f"returning second node: {gymcts_node2}")
        return gymcts_node2

    if second_is_leaf and not first_is_leaf:
        log.debug(f"second node is leaf, first node is not leaf")
        log.debug(f"returning first node: {gymcts_node1}")
        return gymcts_node1

    if first_is_leaf and second_is_leaf:
        log.debug(f"both nodes are leafs, returning first node")
        log.debug(f"returning first node: {gymcts_node1}")
        return gymcts_node1

    # check if gymcts_node1 and gymcts_node2 have the same children
    if gymcts_node1.children.keys() != gymcts_node2.children.keys():
        log.error("Nodes have different children")
        raise ValueError("Nodes have different children")

    # Children are matched by action (dict key), not by position: two dicts with the same keys may still
    # have a different insertion order. Iterating over a snapshot of the keys keeps the loop independent of
    # the assignments made to the dict.
    for action in list(gymcts_node1.children):
        log.debug(f"merging children with action {action} for node {gymcts_node1}")
        merged_child = merge_nodes(
            gymcts_node1.children[action],
            gymcts_node2.children[action],
            perform_state_equality_check=perform_state_equality_check
        )
        # the merged child may be the object that came from the second tree; make sure it points to this node
        merged_child.parent = gymcts_node1
        gymcts_node1.children[action] = merged_child

    visit_count = gymcts_node1.visit_count + gymcts_node2.visit_count
    if visit_count > 0:
        # visit-weighted mean of both nodes; skipped if neither node was visited (avoids a division by zero)
        gymcts_node1.mean_value = (
            gymcts_node1.mean_value * gymcts_node1.visit_count
            + gymcts_node2.mean_value * gymcts_node2.visit_count
        ) / visit_count
    gymcts_node1.max_value = max(gymcts_node1.max_value, gymcts_node2.max_value)
    gymcts_node1.min_value = min(gymcts_node1.min_value, gymcts_node2.min_value)
    gymcts_node1.visit_count = visit_count

    log.debug(f"merged node: {gymcts_node1}")
    log.debug(f"returning node: {gymcts_node1}")
    return gymcts_node1
103
+
104
+
105
class DistributedGymctsAgent:
    """
    MCTS agent that distributes the look-ahead of every step over several ray tasks.

    For every step ``num_parallel`` ``mcts_lookahead`` tasks are started, each on its own deep copy of the
    current search node and of the environment. The trees returned by the tasks are folded into the agent's
    tree with ``merge_nodes`` and the best action of the merged node is selected.
    """
    # class level defaults; all of them are overwritten per instance in __init__
    render_tree_after_step: bool = False
    render_tree_max_depth: int = 2
    exclude_unvisited_nodes_from_render: bool = False
    number_of_simulations_per_step: int = 25
    num_parallel: int = 4

    env: GymctsABC
    search_root_node: GymctsNode  # NOTE: this is not the same as the root of the tree!
    clear_mcts_tree_after_step: bool

    def __init__(self,
                 env: GymctsABC,
                 render_tree_after_step: bool = False,
                 render_tree_max_depth: int = 2,
                 num_parallel: int = 4,
                 clear_mcts_tree_after_step: bool = False,
                 number_of_simulations_per_step: int = 25,
                 exclude_unvisited_nodes_from_render: bool = False
                 ):
        """
        Creates the agent and a fresh root node for ``env``.

        :param env: the environment to solve; its action space has to be ``gym.spaces.Discrete``.
        :param render_tree_after_step: default for printing the search tree after every step.
        :param render_tree_max_depth: default depth used when printing the search tree.
        :param num_parallel: default number of look-ahead tasks started per step.
        :param clear_mcts_tree_after_step: if true, the selected node is reset after every step.
        :param number_of_simulations_per_step: default number of simulations every look-ahead task performs.
        :param exclude_unvisited_nodes_from_render: stored on the agent; not read by any method of this class.
        :raises ValueError: if the action space of ``env`` is not discrete.
        """
        # check if action space of env is discrete
        if not isinstance(env.action_space, gym.spaces.Discrete):
            raise ValueError("Action space must be discrete.")

        self.num_parallel = num_parallel

        self.render_tree_after_step = render_tree_after_step
        self.exclude_unvisited_nodes_from_render = exclude_unvisited_nodes_from_render
        self.render_tree_max_depth = render_tree_max_depth

        self.number_of_simulations_per_step = number_of_simulations_per_step

        self.env = env
        self.clear_mcts_tree_after_step = clear_mcts_tree_after_step

        self.search_root_node = GymctsNode(
            action=None,
            parent=None,
            env_reference=env,
        )

    def solve(self, num_simulations_per_step: int | None = None,
              render_tree_after_step: bool | None = None) -> list[int]:
        """
        Performs MCTS steps, starting at the current search root node, until a terminal node is reached.

        :param num_simulations_per_step: simulations per look-ahead task and step; defaults to
                                         ``self.number_of_simulations_per_step``.
        :param render_tree_after_step: whether to print the tree after every step; defaults to
                                       ``self.render_tree_after_step``.
        :return: the list of selected actions, in the order they were selected.
        """
        if num_simulations_per_step is None:
            num_simulations_per_step = self.number_of_simulations_per_step
        if render_tree_after_step is None:
            render_tree_after_step = self.render_tree_after_step

        log.debug(f"Solving from root node: {self.search_root_node}")

        current_node = self.search_root_node
        action_list = []

        while not current_node.terminal:
            next_action, current_node = self.perform_mcts_step(num_simulations=num_simulations_per_step,
                                                               render_tree_after_step=render_tree_after_step)
            log.info(
                f"selected action {next_action} after {self.num_parallel} x {num_simulations_per_step} simulations.")
            action_list.append(next_action)
            log.info(f"current action list: {action_list}")

        log.info(f"Final action list: {action_list}")
        return action_list

    def perform_mcts_step(self, search_start_node: GymctsNode | None = None, num_simulations: int | None = None,
                          render_tree_after_step: bool | None = None,
                          num_parallel: int | None = None) -> tuple[int, GymctsNode]:
        """
        Performs one distributed MCTS step and advances the search root node to the selected child.

        :param search_start_node: node to search from; defaults to ``self.search_root_node``.
        :param num_simulations: simulations per look-ahead task; defaults to
                                ``self.number_of_simulations_per_step``.
        :param render_tree_after_step: whether to print the tree after the step; defaults to
                                       ``self.render_tree_after_step``.
        :param num_parallel: number of look-ahead tasks to start; defaults to ``self.num_parallel``.
        :return: the selected action and the child node reached by it (the new search root node).
        :raises ValueError: if ``num_parallel`` is smaller than 1.
        """
        if render_tree_after_step is None:
            render_tree_after_step = self.render_tree_after_step

        if num_simulations is None:
            num_simulations = self.number_of_simulations_per_step

        if search_start_node is None:
            search_start_node = self.search_root_node

        if num_parallel is None:
            num_parallel = self.num_parallel

        # without at least one look-ahead nothing gets expanded and no action could be selected below
        if num_parallel < 1:
            raise ValueError("num_parallel must be at least 1.")

        # every task gets its own deep copy of the node and of the environment
        pending_lookaheads = [
            mcts_lookahead.remote(
                copy.deepcopy(search_start_node),
                copy.deepcopy(self.env),
                num_simulations=num_simulations
            )
            for _ in range(num_parallel)
        ]

        # fold the results into the tree in the order in which the tasks finish
        while pending_lookaheads:
            finished_lookaheads, pending_lookaheads = ray.wait(pending_lookaheads)
            for finished_ref in finished_lookaheads:
                ready_node = ray.get(finished_ref)

                # NOTE(review): backpropagation adds the task's statistics to search_start_node (and its
                # ancestors) and merge_nodes adds them a second time when both nodes are expanded -- confirm
                # that this double counting is intended.
                if not self.clear_mcts_tree_after_step:
                    self.backpropagation(search_start_node, ready_node.mean_value, ready_node.visit_count)

                merged_node = merge_nodes(search_start_node, ready_node)
                if merged_node is not search_start_node:
                    # merge_nodes handed back the task's node object (the start node was still a leaf).
                    # Re-link it below the parent, otherwise the parent keeps pointing at the stale node.
                    parent_node = merged_node.parent
                    if parent_node is not None and parent_node.children is not None:
                        parent_node.children[merged_node.action] = merged_node
                search_start_node = merged_node

        action = search_start_node.get_best_action()
        next_node = search_start_node.children[action]

        # honour the (possibly overridden) argument, not only the instance wide default
        if render_tree_after_step:
            self.show_mcts_tree(
                start_node=search_start_node,
                tree_max_depth=self.render_tree_max_depth
            )

        if self.clear_mcts_tree_after_step:
            # to clear memory we need to remove all nodes except the current node
            # this is done by setting the root node to the current node
            # and setting the parent of the current node to None
            # we also need to reset the children of the current node
            # this is done by calling the reset method
            next_node.reset()

        self.search_root_node = next_node

        return action, next_node

    def backpropagation(self, node: GymctsNode, average_episode_return: float, num_episodes: int) -> None:
        """
        Folds the result of ``num_episodes`` episodes into ``node`` and all of its ancestors up to the root.

        For every node on the path the mean value becomes the visit-weighted mean of the old mean and
        ``average_episode_return``, the visit count grows by ``num_episodes`` and max/min value are updated
        with ``average_episode_return``.

        :param node: the node to start from.
        :param average_episode_return: the mean return of the episodes.
        :param num_episodes: the number of episodes the mean return was computed from; nothing is changed if
                             it is not positive.
        :return: None
        """
        log.debug(f"performing backpropagation from leaf node: {node}")

        # zero episodes carry no information and would divide by zero on a node that was never visited
        if num_episodes <= 0:
            return

        while True:
            total_visits = node.visit_count + num_episodes
            node.mean_value = (node.mean_value * node.visit_count
                               + average_episode_return * num_episodes) / total_visits
            node.visit_count = total_visits
            node.max_value = max(node.max_value, average_episode_return)
            node.min_value = min(node.min_value, average_episode_return)
            # the root node is updated as well, it is the last node on the path
            if node.is_root():
                break
            node = node.parent

    def show_mcts_tree(self, start_node: GymctsNode | None = None, tree_max_depth: int | None = None) -> None:
        """
        Prints the tree below ``start_node`` to stdout.

        :param start_node: the node to start printing at; defaults to ``self.search_root_node``.
        :param tree_max_depth: depth passed to the tree plotter; defaults to ``self.render_tree_max_depth``.
        :return: None
        """
        if start_node is None:
            start_node = self.search_root_node

        if tree_max_depth is None:
            tree_max_depth = self.render_tree_max_depth

        print(start_node.__str__(colored=True, action_space_n=self.env.action_space.n))
        for line in _generate_mcts_tree(
                start_node=start_node,
                depth=tree_max_depth,
                action_space_n=self.env.action_space.n
        ):
            print(line)

    def show_mcts_tree_from_root(self, tree_max_depth: int | None = None) -> None:
        """
        Prints the tree starting at the root of the tree the current search root node belongs to.

        :param tree_max_depth: depth passed to the tree plotter; defaults to ``self.render_tree_max_depth``.
        :return: None
        """
        self.show_mcts_tree(start_node=self.search_root_node.get_root(), tree_max_depth=tree_max_depth)
270
+
271
+
272
if __name__ == '__main__':
    # demo: solve the deterministic 4x4 FrozenLake with a single look-ahead task per step
    import time

    ray.init()

    log.setLevel(20)  # 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL

    # 1. create the environment and wrap it with the deepcopy wrapper (or a custom gymcts wrapper)
    env = gym.make('FrozenLake-v1', desc=None, map_name="4x4", is_slippery=False)
    env.reset()
    env = DeepCopyMCTSGymEnvWrapper(env)

    # 2. create the agent
    agent_settings = {
        "render_tree_after_step": True,
        "number_of_simulations_per_step": 10,
        "exclude_unvisited_nodes_from_render": True,
        "num_parallel": 1,
    }
    agent1 = DistributedGymctsAgent(env=env, **agent_settings)

    # 3. solve and measure the wall clock time of the whole run
    start_time = time.perf_counter()
    actions = agent1.solve()
    end_time = time.perf_counter()

    agent1.show_mcts_tree_from_root()

    print(f"solution time pro action: {end_time - start_time}/{len(actions)}")
@@ -0,0 +1,61 @@
1
+ from typing import TypeVar, Any, SupportsFloat, Callable
2
+ from abc import ABC, abstractmethod
3
+ import gymnasium as gym
4
+
5
+ TSoloMCTSNode = TypeVar("TSoloMCTSNode", bound="SoloMCTSNode")
6
+
7
+
8
class GymctsABC(ABC, gym.Env):
    """
    Abstract interface an environment has to implement to be usable by the MCTS agents of this library.

    On top of the ``gym.Env`` API it requires saving and restoring states, a terminal check, the list of
    valid actions and a rollout.
    """

    @abstractmethod
    def get_state(self) -> Any:
        """
        Returns the current state of the environment.

        The state may be of any datatype, as long as it allows restoring the environment to exactly this
        state by passing it to the load_state method.

        A numpy array is recommended where possible, since it is easy to serialize and deserialize.

        :return: the current state of the environment
        """

    @abstractmethod
    def load_state(self, state: Any) -> None:
        """
        Restores the environment to the given state.

        The state is expected to be a value that was obtained from the get_state method.

        :param state: the state to load
        :return: None
        """

    @abstractmethod
    def is_terminal(self) -> bool:
        """
        Tells whether the environment is in a terminal state.

        :return: True if the environment is in a terminal state, False otherwise
        """

    @abstractmethod
    def get_valid_actions(self) -> list[int]:
        """
        Returns the actions that are valid in the current state of the environment.

        The MCTS tree uses them to obtain the potential actions/subsequent states of a node.

        :return: the list of valid actions
        """

    @abstractmethod
    def rollout(self) -> float:
        """
        Performs a rollout from the current state of the environment and returns its return (sum of rewards).

        Please make sure the returned value lies in the interval [-1, 1]. Otherwise the MCTS algorithm will
        not work as expected: the exploration coefficient is ill-fitted then, so exploration and exploitation
        are not well-balanced.

        :return: the return of the rollout
        """
gymcts/gymcts_node.py CHANGED
@@ -4,32 +4,41 @@ import math
4
4
 
5
5
  from typing import TypeVar, Any, SupportsFloat, Callable, Generator
6
6
 
7
- from gymcts.gymcts_gym_env import SoloMCTSGymEnv
7
+ from gymcts.gymcts_env_abc import GymctsABC
8
8
 
9
9
  from gymcts.logger import log
10
10
 
11
- TSoloMCTSNode = TypeVar("TSoloMCTSNode", bound="SoloMCTSNode")
11
+ TGymctsNode = TypeVar("TGymctsNode", bound="GymctsNode")
12
12
 
13
13
 
14
-
15
-
16
- class SoloMCTSNode:
17
-
14
+ class GymctsNode:
18
15
  # static properties
19
- best_action_weight: float = 0.05
20
- ubc_c = 0.707
16
+ best_action_weight: float = 0.05 # weight for the best action
17
+ ubc_c = 0.707 # exploration coefficient
21
18
 
22
19
 
23
- # attributes
24
- visit_count: int = 0
25
- mean_value: float = 0
26
- max_value: float = -float("inf")
27
- min_value: float = +float("inf")
28
- terminal: bool = False
29
- state: Any
30
20
 
21
+ # attributes
22
+ #
23
+ # Note these attributes are not static. They are defined here to give developers a hint what fields are available
24
+ # in the class. They are not static because they are not shared between instances of the class in scope of
25
+ # this library.
26
+ visit_count: int = 0 # number of times the node has been visited
27
+ mean_value: float = 0 # mean value of the node
28
+ max_value: float = -float("inf") # maximum value of the node
29
+ min_value: float = +float("inf") # minimum value of the node
30
+ terminal: bool = False # whether the node is terminal or not
31
+ state: Any = None # state of the node
31
32
 
32
33
  def __str__(self, colored=False, action_space_n=None) -> str:
34
+ """
35
+ Returns a string representation of the node. The string representation is used for visualisation purposes.
36
+ It is used for example in the mcts tree visualisation functionality.
37
+
38
+ :param colored: true if the string representation should be colored, false otherwise. (ture is used by the mcts tree visualisation)
39
+ :param action_space_n: the number of actions in the action space. This is used for coloring the action in the string representation.
40
+ :return: a potentially colored string representation of the node.
41
+ """
33
42
  if not colored:
34
43
 
35
44
  if not self.is_root():
@@ -39,11 +48,9 @@ class SoloMCTSNode:
39
48
 
40
49
  import gymcts.colorful_console_utils as ccu
41
50
 
42
-
43
51
  if self.is_root():
44
52
  return f"({ccu.CYELLOW}N{ccu.CEND}={self.visit_count}, {ccu.CYELLOW}Q_v{ccu.CEND}={self.mean_value:.2f}, {ccu.CYELLOW}best{ccu.CEND}={self.max_value:.2f})"
45
53
 
46
-
47
54
  if action_space_n is None:
48
55
  raise ValueError("action_space_n must be provided if colored is True")
49
56
 
@@ -68,52 +75,78 @@ class SoloMCTSNode:
68
75
  if isinstance(value, int):
69
76
  return f"{color}{value}{e}"
70
77
 
71
-
72
78
  root_node = self.get_root()
73
79
  mean_val = f"{self.mean_value:.2f}"
74
80
 
75
81
  return ((f"("
76
- f"{p}a{e}={ccu.wrap_evenly_spaced_color(s=self.action, n_of_item=self.action, n_classes=action_space_n)}, "
77
- f"{p}N{e}={colorful_value(self.visit_count)}, "
78
- f"{p}Q_v{e}={ccu.wrap_with_color_scale(s=mean_val, value=self.mean_value, min_val=root_node.min_value, max_val=root_node.max_value)}, "
79
- f"{p}best{e}={colorful_value(self.max_value)}") +
82
+ f"{p}a{e}={ccu.wrap_evenly_spaced_color(s=self.action, n_of_item=self.action, n_classes=action_space_n)}, "
83
+ f"{p}N{e}={colorful_value(self.visit_count)}, "
84
+ f"{p}Q_v{e}={ccu.wrap_with_color_scale(s=mean_val, value=self.mean_value, min_val=root_node.min_value, max_val=root_node.max_value)}, "
85
+ f"{p}best{e}={colorful_value(self.max_value)}") +
80
86
  (f", {p}ubc{e}={colorful_value(self.ucb_score())})" if not self.is_root() else ")"))
81
87
 
88
+ def traverse_nodes(self) -> Generator[TGymctsNode, None, None]:
89
+ """
90
+ Traverse the tree and yield all nodes in the tree.
82
91
 
83
- def traverse_nodes(self) -> Generator[TSoloMCTSNode, None, None]:
92
+ :return: a generator that yields all nodes in the tree.
93
+ """
84
94
  yield self
85
95
  if self.children:
86
96
  for child in self.children.values():
87
97
  yield from child.traverse_nodes()
88
98
 
89
- def get_root(self) -> TSoloMCTSNode:
99
+ def get_root(self) -> TGymctsNode:
100
+ """
101
+ Returns the root node of the tree. The root node is the node that has no parent.
102
+
103
+ :return: the root node of the tree.
104
+ """
90
105
  if self.is_root():
91
106
  return self
92
107
  return self.parent.get_root()
93
108
 
94
109
  def max_tree_depth(self):
110
+ """
111
+ Returns the maximum depth of the tree. The depth of a node is the number of edges from
112
+ the node to the root node.
113
+
114
+ :return: the maximum depth of the tree.
115
+ """
95
116
  if self.is_leaf():
96
117
  return 0
97
118
  return 1 + max(child.max_tree_depth() for child in self.children.values())
98
119
 
99
120
  def n_children_recursively(self):
121
+ """
122
+ Returns the number of children of the node recursively. The number of children of a node is the number of
123
+ children of the node plus the number of children of all children of the node.
124
+
125
+ :return: the number of children of the node recursively.
126
+ """
100
127
  if self.is_leaf():
101
128
  return 0
102
129
  return len(self.children) + sum(child.n_children_recursively() for child in self.children.values())
103
130
 
104
-
105
-
106
131
  def __init__(self,
107
132
  action: int | None,
108
- parent: TSoloMCTSNode | None,
109
- env_reference: SoloMCTSGymEnv,
133
+ parent: TGymctsNode | None,
134
+ env_reference: GymctsABC,
110
135
  ):
136
+ """
137
+ Initializes the node. The node is initialized with the state of the environment and the action that was taken to
138
+ reach the node. The node is also initialized with the parent node and the environment reference.
139
+
140
+ :param action: the action that was taken to reach the node. If the node is a root node, this parameter is None.
141
+ :param parent: the parent node of the node. If the node is a root node, this parameter is None.
142
+ :param env_reference: a reference to the environment. The environment is used to get the state of the node and the valid actions.
143
+ """
111
144
 
112
145
  # field depending on whether the node is a root node or not
113
146
  self.action: int | None
114
147
 
115
- self.env_reference: SoloMCTSGymEnv
116
- self.parent: SoloMCTSNode | None
148
+ self.env_reference: GymctsABC
149
+ self.parent: GymctsNode | None
117
150
  self.uuid = uuid.uuid4()
118
151
 
119
152
  if parent is None:
@@ -133,7 +166,7 @@ class SoloMCTSNode:
133
166
 
134
167
  from copy import copy
135
168
  self.state = env_reference.get_state()
136
- #log.debug(f"saving state of node '{str(self)}' to memory location: {hex(id(self.state))}")
169
+ # log.debug(f"saving state of node '{str(self)}' to memory location: {hex(id(self.state))}")
137
170
  self.visit_count: int = 0
138
171
 
139
172
  self.mean_value: float = 0
@@ -143,8 +176,7 @@ class SoloMCTSNode:
143
176
  # safe valid action instead of calling the environment
144
177
  # this reduces the compute but increases the memory usage
145
178
  self.valid_actions: list[int] = env_reference.get_valid_actions()
146
- self.children: dict[int, SoloMCTSNode] | None = None # may be expanded later
147
-
179
+ self.children: dict[int, GymctsNode] | None = None # may be expanded later
148
180
 
149
181
  def reset(self) -> None:
150
182
  self.parent = None
@@ -153,40 +185,71 @@ class SoloMCTSNode:
153
185
  self.mean_value: float = 0
154
186
  self.max_value: float = -float("inf")
155
187
  self.min_value: float = +float("inf")
156
- self.children: dict[int, SoloMCTSNode] | None = None # may be expanded later
188
+ self.children: dict[int, GymctsNode] | None = None # may be expanded later
157
189
 
158
190
  # just setting the children of the parent node to None should be enough to trigger garbage collection
159
191
  # however, we also set the parent to None to make sure that the parent is not referenced anymore
160
192
  if self.parent:
161
193
  self.parent.reset()
162
194
 
163
-
164
-
165
195
  def is_root(self) -> bool:
196
+ """
197
+ Returns true if the node is a root node. A root node is a node that has no parent.
198
+
199
+ :return: true if the node is a root node, false otherwise.
200
+ """
166
201
  return self.parent is None
167
202
 
168
203
  def is_leaf(self) -> bool:
204
+ """
205
+ Returns true if the node is a leaf node. A leaf node is a node that has no children. A leaf node is a node that has no children.
206
+
207
+ :return: true if the node is a leaf node, false otherwise.
208
+ """
169
209
  return self.children is None or len(self.children) == 0
170
210
 
171
- def get_random_child(self) -> TSoloMCTSNode:
211
+ def get_random_child(self) -> TGymctsNode:
212
+ """
213
+ Returns a random child of the node. A random child is a child that is selected randomly from the list of children.
214
+ :return:
215
+ """
172
216
  if self.is_leaf():
173
- raise ValueError("cannot get random child of leaf node") #todo: maybe return self instead?
217
+ raise ValueError("cannot get random child of leaf node") # todo: maybe return self instead?
174
218
 
175
219
  return list(self.children.values())[random.randint(0, len(self.children) - 1)]
176
220
 
177
221
  def get_best_action(self) -> int:
222
+ """
223
+ Returns the best action of the node. The best action is the action that has the highest score.
224
+ The score is calculated using the get_score() method. The best action is the action that has the highest score.
225
+ The best action is the action that has the highest score.
226
+
227
+ :return: the best action of the node.
228
+ """
178
229
  return max(self.children.values(), key=lambda child: child.get_score()).action
179
230
 
180
- def get_score(self) -> float: # todo: make it an attribute?
231
+ def get_score(self) -> float: # todo: make it an attribute?
232
+ """
233
+ Returns the score of the node. The score is calculated using the mean value and the maximum value of the node.
234
+ The score is calculated using the formula: score = (1 - a) * mean_value + a * max_value
235
+ where a is the best action weight.
236
+
237
+ :return: the score of the node.
238
+ """
181
239
  # return self.mean_value
182
- assert 0 <= SoloMCTSNode.best_action_weight <= 1
183
- a = SoloMCTSNode.best_action_weight
184
- return (1-a) * self.mean_value + a * self.max_value
240
+ assert 0 <= GymctsNode.best_action_weight <= 1
241
+ a = GymctsNode.best_action_weight
242
+ return (1 - a) * self.mean_value + a * self.max_value
185
243
 
186
244
  def get_mean_value(self) -> float:
187
245
  return self.mean_value
188
246
 
189
247
  def get_max_value(self) -> float:
248
+ """
249
+ Returns the maximum value of the node. The maximum value is the maximum value of the node.
250
+
251
+ :return: the maximum value of the node.
252
+ """
190
253
  return self.max_value
191
254
 
192
255
  def ucb_score(self):
@@ -207,7 +270,7 @@ class SoloMCTSNode:
207
270
  if self.is_root():
208
271
  raise ValueError("ucb_score can only be called on non-root nodes")
209
272
  # c = 0.707 # todo: make it an attribute?
210
- c = SoloMCTSNode.ubc_c
273
+ c = GymctsNode.ubc_c
211
274
  if self.visit_count == 0:
212
275
  return float("inf")
213
- return self.mean_value + c * math.sqrt(math.log(self.parent.visit_count) / (self.visit_count))
276
+ return self.mean_value + c * math.sqrt(math.log(self.parent.visit_count) / (self.visit_count))