gymcts 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,5 @@
1
+ from typing import Any
2
+
1
3
  import matplotlib.pyplot as plt
2
4
  import numpy as np
3
5
 
@@ -103,8 +105,7 @@ def wrap_with_color_codes(s: object, /, r: int | float, g: int | float, b: int |
103
105
  f"{CEND}"
104
106
 
105
107
 
106
-
107
- def wrap_evenly_spaced_color(s: str, n_of_item:int, n_classes:int, c_map="rainbow") -> str:
108
+ def wrap_evenly_spaced_color(s: Any, n_of_item: int, n_classes: int, c_map="rainbow") -> str:
108
109
  if s is None or n_of_item is None or n_classes is None:
109
110
  return s
110
111
 
@@ -117,7 +118,7 @@ def wrap_evenly_spaced_color(s: str, n_of_item:int, n_classes:int, c_map="rainbo
117
118
  return f"{color_asni}{s}{CEND}"
118
119
 
119
120
 
120
- def wrap_with_color_scale(s: str, value: float, min_val:float, max_val:float, c_map=None) -> str:
121
+ def wrap_with_color_scale(s: str, value: float, min_val: float, max_val: float, c_map=None) -> str:
121
122
  if s is None or min_val is None or max_val is None or min_val >= max_val:
122
123
  return s
123
124
 
@@ -7,12 +7,12 @@ import gymnasium as gym
7
7
  from gymnasium.core import WrapperActType, WrapperObsType
8
8
  from gymnasium.wrappers import RecordEpisodeStatistics
9
9
 
10
- from gymcts.gymcts_gym_env import SoloMCTSGymEnv
10
+ from gymcts.gymcts_env_abc import GymctsABC
11
11
 
12
12
  from gymcts.logger import log
13
13
 
14
14
 
15
- class DeterministicSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
15
+ class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
16
16
  _terminal_flag: bool = False
17
17
  _last_reward: SupportsFloat = 0
18
18
  _step_tuple: tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]] = None
gymcts/gymcts_agent.py CHANGED
@@ -3,27 +3,28 @@ import gymnasium as gym
3
3
 
4
4
  from typing import TypeVar, Any, SupportsFloat, Callable
5
5
 
6
- from gymcts.gymcts_gym_env import SoloMCTSGymEnv
7
- from gymcts.gymcts_naive_wrapper import NaiveSoloMCTSGymEnvWrapper
8
- from gymcts.gymcts_node import SoloMCTSNode
6
+ from gymcts.gymcts_env_abc import GymctsABC
7
+ from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
8
+ from gymcts.gymcts_node import GymctsNode
9
+ from gymcts.gymcts_tree_plotter import _generate_mcts_tree
9
10
 
10
11
  from gymcts.logger import log
11
12
 
12
13
  TSoloMCTSNode = TypeVar("TSoloMCTSNode", bound="SoloMCTSNode")
13
14
 
14
15
 
15
- class SoloMCTSAgent:
16
+ class GymctsAgent:
16
17
  render_tree_after_step: bool = False
17
18
  render_tree_max_depth: int = 2
18
19
  exclude_unvisited_nodes_from_render: bool = False
19
20
  number_of_simulations_per_step: int = 25
20
21
 
21
- env: SoloMCTSGymEnv
22
- search_root_node: SoloMCTSNode # NOTE: this is not the same as the root of the tree!
22
+ env: GymctsABC
23
+ search_root_node: GymctsNode # NOTE: this is not the same as the root of the tree!
23
24
  clear_mcts_tree_after_step: bool
24
25
 
25
26
  def __init__(self,
26
- env: SoloMCTSGymEnv,
27
+ env: GymctsABC,
27
28
  clear_mcts_tree_after_step: bool = True,
28
29
  render_tree_after_step: bool = False,
29
30
  render_tree_max_depth: int = 2,
@@ -43,13 +44,13 @@ class SoloMCTSAgent:
43
44
  self.env = env
44
45
  self.clear_mcts_tree_after_step = clear_mcts_tree_after_step
45
46
 
46
- self.search_root_node = SoloMCTSNode(
47
+ self.search_root_node = GymctsNode(
47
48
  action=None,
48
49
  parent=None,
49
50
  env_reference=env,
50
51
  )
51
52
 
52
- def navigate_to_leaf(self, from_node: SoloMCTSNode) -> SoloMCTSNode:
53
+ def navigate_to_leaf(self, from_node: GymctsNode) -> GymctsNode:
53
54
  log.debug(f"Navigate to leaf. from_node: {from_node}")
54
55
  if from_node.terminal:
55
56
  log.debug("Node is terminal. Returning from_node")
@@ -66,7 +67,7 @@ class SoloMCTSAgent:
66
67
  log.debug(f"Selected leaf node: {temp_node}")
67
68
  return temp_node
68
69
 
69
- def expand_node(self, node: SoloMCTSNode) -> None:
70
+ def expand_node(self, node: GymctsNode) -> None:
70
71
  log.debug(f"expanding node: {node}")
71
72
  # EXPANSION STRATEGY
72
73
  # expand all children
@@ -78,7 +79,7 @@ class SoloMCTSAgent:
78
79
  self._load_state(node)
79
80
 
80
81
  obs, reward, terminal, truncated, _ = self.env.step(action)
81
- child_dict[action] = SoloMCTSNode(
82
+ child_dict[action] = GymctsNode(
82
83
  action=action,
83
84
  parent=node,
84
85
  env_reference=self.env,
@@ -110,14 +111,14 @@ class SoloMCTSAgent:
110
111
  # restore state of current node
111
112
  return action_list
112
113
 
113
- def _load_state(self, node: SoloMCTSNode) -> None:
114
- if isinstance(self.env, NaiveSoloMCTSGymEnvWrapper):
114
+ def _load_state(self, node: GymctsNode) -> None:
115
+ if isinstance(self.env, DeepCopyMCTSGymEnvWrapper):
115
116
  self.env = copy.deepcopy(node.state)
116
117
  else:
117
118
  self.env.load_state(node.state)
118
119
 
119
- def perform_mcts_step(self, search_start_node: SoloMCTSNode = None, num_simulations: int = None,
120
- render_tree_after_step: bool = None) -> tuple[int, SoloMCTSNode]:
120
+ def perform_mcts_step(self, search_start_node: GymctsNode = None, num_simulations: int = None,
121
+ render_tree_after_step: bool = None) -> tuple[int, GymctsNode]:
121
122
 
122
123
  if render_tree_after_step is None:
123
124
  render_tree_after_step = self.render_tree_after_step
@@ -149,7 +150,7 @@ class SoloMCTSAgent:
149
150
 
150
151
  return action, next_node
151
152
 
152
- def vanilla_mcts_search(self, search_start_node: SoloMCTSNode = None, num_simulations=10) -> int:
153
+ def vanilla_mcts_search(self, search_start_node: GymctsNode = None, num_simulations=10) -> int:
153
154
  log.debug(f"performing one MCTS search step with {num_simulations} simulations")
154
155
  if search_start_node is None:
155
156
  search_start_node = self.search_root_node
@@ -178,7 +179,7 @@ class SoloMCTSAgent:
178
179
 
179
180
  return search_start_node.get_best_action()
180
181
 
181
- def show_mcts_tree(self, start_node: SoloMCTSNode = None, tree_max_depth: int = None) -> None:
182
+ def show_mcts_tree(self, start_node: GymctsNode = None, tree_max_depth: int = None) -> None:
182
183
 
183
184
  if start_node is None:
184
185
  start_node = self.search_root_node
@@ -187,13 +188,17 @@ class SoloMCTSAgent:
187
188
  tree_max_depth = self.render_tree_max_depth
188
189
 
189
190
  print(start_node.__str__(colored=True, action_space_n=self.env.action_space.n))
190
- for line in self._generate_mcts_tree(start_node=start_node, depth=tree_max_depth):
191
+ for line in _generate_mcts_tree(
192
+ start_node=start_node,
193
+ depth=tree_max_depth,
194
+ action_space_n=self.env.action_space.n,
195
+ ):
191
196
  print(line)
192
197
 
193
198
  def show_mcts_tree_from_root(self, tree_max_depth: int = None) -> None:
194
199
  self.show_mcts_tree(start_node=self.search_root_node.get_root(), tree_max_depth=tree_max_depth)
195
200
 
196
- def backpropagation(self, node: SoloMCTSNode, episode_return: float) -> None:
201
+ def backpropagation(self, node: GymctsNode, episode_return: float) -> None:
197
202
  log.debug(f"performing backpropagation from leaf node: {node}")
198
203
  while not node.is_root():
199
204
  # node.mean_value = ((node.mean_value * node.visit_count) + episode_return) / (node.visit_count + 1)
@@ -209,53 +214,4 @@ class SoloMCTSAgent:
209
214
  node.max_value = max(node.max_value, episode_return)
210
215
  node.min_value = min(node.min_value, episode_return)
211
216
 
212
- def _generate_mcts_tree(self, start_node: SoloMCTSNode = None, prefix: str = None, depth: int = None) -> list[str]:
213
217
 
214
- if prefix is None:
215
- prefix = ""
216
- import gymcts.colorful_console_utils as ccu
217
-
218
- if start_node is None:
219
- start_node = self.search_root_node
220
-
221
- # prefix components:
222
- space = ' '
223
- branch = '│ '
224
- # pointers:
225
- tee = '├── '
226
- last = '└── '
227
-
228
- contents = start_node.children.values() if start_node.children is not None else []
229
- if self.exclude_unvisited_nodes_from_render:
230
- contents = [node for node in contents if node.visit_count > 0]
231
- # contents each get pointers that are ├── with a final └── :
232
- # pointers = [tee] * (len(contents) - 1) + [last]
233
- pointers = [tee for _ in range(len(contents) - 1)] + [last]
234
-
235
- for pointer, current_node in zip(pointers, contents):
236
- n_item = current_node.parent.action if current_node.parent is not None else 0
237
- n_classes = self.env.action_space.n
238
-
239
- pointer = ccu.wrap_evenly_spaced_color(
240
- s=pointer,
241
- n_of_item=n_item,
242
- n_classes=n_classes,
243
- )
244
-
245
- yield prefix + pointer + f"{current_node.__str__(colored=True, action_space_n=n_classes)}"
246
- if current_node.children and len(current_node.children): # extend the prefix and recurse:
247
- # extension = branch if pointer == tee else space
248
- extension = branch if tee in pointer else space
249
- # i.e. space because last, └── , above so no more |
250
- extension = ccu.wrap_evenly_spaced_color(
251
- s=extension,
252
- n_of_item=n_item,
253
- n_classes=n_classes,
254
- )
255
- if depth is not None and depth <= 0:
256
- continue
257
- yield from self._generate_mcts_tree(
258
- current_node,
259
- prefix=prefix + extension,
260
- depth=depth - 1 if depth is not None else None
261
- )
@@ -7,12 +7,12 @@ import gymnasium as gym
7
7
  from gymnasium.core import WrapperActType, WrapperObsType
8
8
  from gymnasium.wrappers import RecordEpisodeStatistics
9
9
 
10
- from gymcts.gymcts_gym_env import SoloMCTSGymEnv
10
+ from gymcts.gymcts_env_abc import GymctsABC
11
11
 
12
12
  from gymcts.logger import log
13
13
 
14
14
 
15
- class NaiveSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
15
+ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
16
16
 
17
17
 
18
18
  _terminal_flag:bool = False
@@ -0,0 +1,281 @@
1
+ import copy
2
+ import gymnasium as gym
3
+
4
+ from typing import TypeVar, Any, SupportsFloat, Callable
5
+
6
+ from ray.types import ObjectRef
7
+
8
+ from gymcts.gymcts_agent import GymctsAgent
9
+ from gymcts.gymcts_env_abc import GymctsABC
10
+ from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
11
+ from gymcts.gymcts_node import GymctsNode
12
+ from gymcts.gymcts_tree_plotter import _generate_mcts_tree
13
+
14
+ from gymcts.logger import log
15
+
16
+ import ray
17
+ import copy
18
+
19
# Kept under its old name for backward compatibility; the bound now refers to the renamed
# node class (``SoloMCTSNode`` was renamed to ``GymctsNode``, so the old forward reference
# could no longer be resolved by type checkers).
TSoloMCTSNode = TypeVar("TSoloMCTSNode", bound="GymctsNode")
20
+
21
+
22
@ray.remote
def mcts_lookahead(
        gymcts_start_node: GymctsNode,
        env: GymctsABC,
        num_simulations: int) -> GymctsNode:
    """Run one independent MCTS search as a Ray task and return the searched node.

    A throw-away :class:`GymctsAgent` is built around ``env``, pointed at
    ``gymcts_start_node`` and asked to run ``num_simulations`` simulations from it.
    In this module, ``DistributedGymctsAgent.perform_mcts_step`` passes deep copies of
    both the node and the environment, so the task may mutate them freely.

    :param gymcts_start_node: node the search starts from.
    :param env: environment the temporary agent plans in.
    :param num_simulations: number of simulations to run in this task.
    :return: the temporary agent's ``search_root_node`` after the search.
    """
    worker_agent = GymctsAgent(
        env=env,
        number_of_simulations_per_step=num_simulations,
        clear_mcts_tree_after_step=False,
    )
    # search from the supplied node rather than the fresh root the constructor created
    worker_agent.search_root_node = gymcts_start_node

    worker_agent.vanilla_mcts_search(
        num_simulations=num_simulations,
        search_start_node=gymcts_start_node,
    )
    return worker_agent.search_root_node
39
+
40
+
41
def _merge_node_statistics(target, other) -> None:
    """Fold the visit statistics of ``other`` into ``target`` in place.

    Visit counts are summed, ``mean_value`` becomes the visit-weighted mean of both
    nodes, and ``max_value``/``min_value`` are combined. When neither node has been
    visited, ``target.mean_value`` is left unchanged (avoids dividing by zero).
    """
    visit_count = target.visit_count + other.visit_count
    if visit_count > 0:
        target.mean_value = (
            target.mean_value * target.visit_count + other.mean_value * other.visit_count
        ) / visit_count
    target.visit_count = visit_count
    target.max_value = max(target.max_value, other.max_value)
    target.min_value = min(target.min_value, other.min_value)


def merge_nodes(gymcts_node1, gymcts_node2, perform_state_equality_check=False):
    """Merge two search trees rooted at equivalent nodes and return the merged root.

    Used to combine the trees produced by parallel ``mcts_lookahead`` tasks. The visit
    statistics (``visit_count``, ``mean_value``, ``max_value``, ``min_value``) of both
    nodes are combined at every level, including leaves. Children are matched by their
    action key, independent of dict insertion order.

    :param gymcts_node1: first tree; it is mutated and normally returned.
    :param gymcts_node2: second tree; its subtrees may be grafted into the result.
    :param perform_state_equality_check: if True, raise when the nodes' ``state``
        attributes compare unequal.
    :return: the merged node. This is ``gymcts_node1``, except when it is ``None``, or
        when it is a leaf while ``gymcts_node2`` is expanded. In the latter case
        ``gymcts_node2`` is returned, re-parented to ``gymcts_node1``'s parent.
    :raises ValueError: if both nodes are ``None``, if the states differ (when checked),
        or if two expanded nodes have different sets of child actions.
    """
    log.debug(f"merging {gymcts_node1} and {gymcts_node2}")

    # None handling must come first: the state check below dereferences both nodes
    if gymcts_node1 is None and gymcts_node2 is None:
        log.error("Both nodes are None")
        raise ValueError("Both nodes are None")
    if gymcts_node1 is None:
        log.debug(f"first node is None, returning second node ({gymcts_node2})")
        return gymcts_node2
    if gymcts_node2 is None:
        log.debug(f"second node is None, returning first node ({gymcts_node1})")
        return gymcts_node1

    if perform_state_equality_check and gymcts_node1.state != gymcts_node2.state:
        raise ValueError("States are different")

    node1_is_leaf = gymcts_node1.is_leaf()
    node2_is_leaf = gymcts_node2.is_leaf()

    if node1_is_leaf and not node2_is_leaf:
        log.debug("first node is leaf, second node is not leaf")
        # keep the expanded subtree, but keep the visits recorded on the leaf as well
        _merge_node_statistics(gymcts_node2, gymcts_node1)
        gymcts_node2.parent = gymcts_node1.parent
        log.debug(f"returning second node: {gymcts_node2}")
        return gymcts_node2

    if not node1_is_leaf and not node2_is_leaf:
        # both nodes are expanded: merge the subtrees action by action
        if gymcts_node1.children.keys() != gymcts_node2.children.keys():
            log.error("Nodes have different children")
            raise ValueError("Nodes have different children")

        for action, child1 in list(gymcts_node1.children.items()):
            log.debug(f"merging children with action {action} for node {gymcts_node1}")
            merged_child = merge_nodes(
                child1,
                gymcts_node2.children[action],  # looked up by key, not by position
                perform_state_equality_check=perform_state_equality_check
            )
            merged_child.parent = gymcts_node1
            gymcts_node1.children[action] = merged_child

    # reached when both nodes are expanded, both are leaves, or only node2 is a leaf
    _merge_node_statistics(gymcts_node1, gymcts_node2)
    log.debug(f"merged node: {gymcts_node1}")
    return gymcts_node1
103
+
104
+
105
class DistributedGymctsAgent:
    """MCTS agent that parallelises each search step over several Ray tasks.

    For every step, ``num_parallel`` remote :func:`mcts_lookahead` tasks search from
    deep copies of the current search node and environment. Their trees are combined
    with :func:`merge_nodes`, and the best action of the merged tree is taken.
    """
    render_tree_after_step: bool = False
    render_tree_max_depth: int = 2
    exclude_unvisited_nodes_from_render: bool = False
    number_of_simulations_per_step: int = 25
    num_parallel: int = 4

    env: GymctsABC
    search_root_node: GymctsNode  # NOTE: this is not the same as the root of the tree!
    clear_mcts_tree_after_step: bool  # NOTE(review): declared but never assigned by this class

    def __init__(self,
                 env: GymctsABC,
                 render_tree_after_step: bool = False,
                 render_tree_max_depth: int = 2,
                 num_parallel: int = 4,
                 number_of_simulations_per_step: int = 25,
                 exclude_unvisited_nodes_from_render: bool = False
                 ):
        """Create the agent and a fresh search root node for ``env``.

        :param env: environment to plan in; its action space must be discrete.
        :param render_tree_after_step: print the merged search tree after every step.
        :param render_tree_max_depth: maximum depth of the printed tree.
        :param num_parallel: number of parallel Ray searches per step.
        :param number_of_simulations_per_step: simulations each parallel search runs per step.
        :param exclude_unvisited_nodes_from_render: hide nodes with zero visits when printing.
        :raises ValueError: if ``env.action_space`` is not a ``gym.spaces.Discrete`` space.
        """
        # only discrete action spaces are supported
        if not isinstance(env.action_space, gym.spaces.Discrete):
            raise ValueError("Action space must be discrete.")

        self.num_parallel = num_parallel

        self.render_tree_after_step = render_tree_after_step
        self.exclude_unvisited_nodes_from_render = exclude_unvisited_nodes_from_render
        self.render_tree_max_depth = render_tree_max_depth

        self.number_of_simulations_per_step = number_of_simulations_per_step

        self.env = env

        self.search_root_node = GymctsNode(
            action=None,
            parent=None,
            env_reference=env,
        )

    def solve(self, num_simulations_per_step: int = None, render_tree_after_step: bool = None) -> list[int]:
        """Perform MCTS steps from the current search root until a terminal node is reached.

        :param num_simulations_per_step: simulations per parallel search; defaults to
            ``self.number_of_simulations_per_step``.
        :param render_tree_after_step: print the tree after each step; defaults to
            ``self.render_tree_after_step``.
        :return: the list of selected actions, in order.
        """
        if num_simulations_per_step is None:
            num_simulations_per_step = self.number_of_simulations_per_step
        if render_tree_after_step is None:
            render_tree_after_step = self.render_tree_after_step

        log.debug(f"Solving from root node: {self.search_root_node}")

        current_node = self.search_root_node

        action_list = []

        while not current_node.terminal:
            next_action, current_node = self.perform_mcts_step(num_simulations=num_simulations_per_step,
                                                               render_tree_after_step=render_tree_after_step)
            log.info(
                f"selected action {next_action} after {self.num_parallel} x {num_simulations_per_step} simulations.")
            action_list.append(next_action)
            log.info(f"current action list: {action_list}")

        log.info(f"Final action list: {action_list}")
        return action_list

    def perform_mcts_step(self, search_start_node: GymctsNode = None, num_simulations: int = None,
                          render_tree_after_step: bool = None, num_parallel: int = None) -> tuple[int, GymctsNode]:
        """Run one parallel MCTS step and advance the search root to the chosen child.

        ``num_parallel`` remote searches start from deep copies of ``search_start_node``
        and ``self.env``. Each finished tree is merged into ``search_start_node``. The
        best action of the merged tree is selected, and the tree is optionally printed.
        The chosen child is then ``reset()`` (its parent link, children and value
        statistics are cleared) and becomes the new ``search_root_node``.

        :param search_start_node: node to search from; defaults to ``self.search_root_node``.
        :param num_simulations: simulations per parallel search; defaults to
            ``self.number_of_simulations_per_step``.
        :param render_tree_after_step: print the merged tree; defaults to
            ``self.render_tree_after_step``.
        :param num_parallel: number of parallel searches; defaults to ``self.num_parallel``.
        :return: ``(action, next_node)`` for the selected action.
        :raises ValueError: if ``num_parallel`` is smaller than 1.
        """
        if render_tree_after_step is None:
            render_tree_after_step = self.render_tree_after_step

        if num_simulations is None:
            num_simulations = self.number_of_simulations_per_step

        if search_start_node is None:
            search_start_node = self.search_root_node

        if num_parallel is None:
            num_parallel = self.num_parallel

        if num_parallel < 1:
            # without any search the start node would stay unexpanded and have no best action
            raise ValueError("num_parallel must be at least 1.")

        # every task gets its own deep copies so the parallel searches share no state
        mcts_iteration_futures = [
            mcts_lookahead.remote(
                copy.deepcopy(search_start_node),
                copy.deepcopy(self.env),
                num_simulations=num_simulations
            )
            for _ in range(num_parallel)
        ]

        # merge every finished tree as soon as its task completes
        while mcts_iteration_futures:
            ready_refs, mcts_iteration_futures = ray.wait(mcts_iteration_futures)
            for ready_node_ref in ready_refs:
                ready_node = ray.get(ready_node_ref)
                search_start_node = merge_nodes(search_start_node, ready_node)

        action = search_start_node.get_best_action()
        next_node = search_start_node.children[action]

        # honour the per-call argument, not only the instance default
        if render_tree_after_step:
            self.show_mcts_tree(
                start_node=search_start_node,
                tree_max_depth=self.render_tree_max_depth
            )

        # To free memory, only the chosen child is kept: reset() drops its parent link
        # and its children. In the distributed setting the previous tree is discarded
        # every step, because merging trees is already expensive and propagating
        # values back through the whole tree would be even more so.
        next_node.reset()

        self.search_root_node = next_node

        return action, next_node

    def show_mcts_tree(self, start_node: GymctsNode = None, tree_max_depth: int = None) -> None:
        """Print ``start_node`` and the tree below it (defaults to the current search root).

        :param start_node: node to print from; defaults to ``self.search_root_node``.
        :param tree_max_depth: render depth limit; defaults to ``self.render_tree_max_depth``.
        """
        if start_node is None:
            start_node = self.search_root_node

        if tree_max_depth is None:
            tree_max_depth = self.render_tree_max_depth

        print(start_node.__str__(colored=True, action_space_n=self.env.action_space.n))
        for line in _generate_mcts_tree(
                start_node=start_node,
                depth=tree_max_depth,
                action_space_n=self.env.action_space.n,
                exclude_unvisited_nodes_from_render=self.exclude_unvisited_nodes_from_render,
        ):
            print(line)

    def show_mcts_tree_from_root(self, tree_max_depth: int = None) -> None:
        """Print the whole tree, starting at the root above the current search node."""
        self.show_mcts_tree(start_node=self.search_root_node.get_root(), tree_max_depth=tree_max_depth)
254
+
255
+
256
if __name__ == '__main__':
    import time

    ray.init()

    log.setLevel(20)  # 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL
    env = gym.make('FrozenLake-v1', desc=None, map_name="4x4", is_slippery=False)
    env.reset()

    # 1. wrap the environment with the deepcopy wrapper or a custom gymcts wrapper, e.g.
    # env = ActionHistoryMCTSGymEnvWrapper(env)  # import from gymcts.gymcts_action_history_wrapper
    env = DeepCopyMCTSGymEnvWrapper(env)

    # 2. create the agent
    agent1 = DistributedGymctsAgent(
        env=env,
        render_tree_after_step=True,
        number_of_simulations_per_step=1000,
        exclude_unvisited_nodes_from_render=True,
        num_parallel=1,
    )

    start_time = time.perf_counter()
    actions = agent1.solve()
    end_time = time.perf_counter()

    # report the actual time per action (guarding against an empty action list)
    total_time = end_time - start_time
    time_per_action = total_time / len(actions) if actions else float("nan")
    print(f"solution time per action: {time_per_action:.4f}s ({total_time:.4f}s for {len(actions)} actions)")
@@ -5,7 +5,7 @@ import gymnasium as gym
5
5
  TSoloMCTSNode = TypeVar("TSoloMCTSNode", bound="SoloMCTSNode")
6
6
 
7
7
 
8
- class SoloMCTSGymEnv(ABC, gym.Env):
8
+ class GymctsABC(ABC, gym.Env):
9
9
 
10
10
  @abstractmethod
11
11
  def get_state(self) -> Any:
gymcts/gymcts_node.py CHANGED
@@ -4,21 +4,17 @@ import math
4
4
 
5
5
  from typing import TypeVar, Any, SupportsFloat, Callable, Generator
6
6
 
7
- from gymcts.gymcts_gym_env import SoloMCTSGymEnv
7
+ from gymcts.gymcts_env_abc import GymctsABC
8
8
 
9
9
  from gymcts.logger import log
10
10
 
11
- TSoloMCTSNode = TypeVar("TSoloMCTSNode", bound="SoloMCTSNode")
11
+ TGymctsNode = TypeVar("TGymctsNode", bound="GymctsNode")
12
12
 
13
13
 
14
-
15
-
16
- class SoloMCTSNode:
17
-
14
+ class GymctsNode:
18
15
  # static properties
19
16
  best_action_weight: float = 0.05
20
- ubc_c = 0.707
21
-
17
+ ubc_c = 0.707
22
18
 
23
19
  # attributes
24
20
  visit_count: int = 0
@@ -28,7 +24,6 @@ class SoloMCTSNode:
28
24
  terminal: bool = False
29
25
  state: Any
30
26
 
31
-
32
27
  def __str__(self, colored=False, action_space_n=None) -> str:
33
28
  if not colored:
34
29
 
@@ -39,11 +34,9 @@ class SoloMCTSNode:
39
34
 
40
35
  import gymcts.colorful_console_utils as ccu
41
36
 
42
-
43
37
  if self.is_root():
44
38
  return f"({ccu.CYELLOW}N{ccu.CEND}={self.visit_count}, {ccu.CYELLOW}Q_v{ccu.CEND}={self.mean_value:.2f}, {ccu.CYELLOW}best{ccu.CEND}={self.max_value:.2f})"
45
39
 
46
-
47
40
  if action_space_n is None:
48
41
  raise ValueError("action_space_n must be provided if colored is True")
49
42
 
@@ -68,25 +61,23 @@ class SoloMCTSNode:
68
61
  if isinstance(value, int):
69
62
  return f"{color}{value}{e}"
70
63
 
71
-
72
64
  root_node = self.get_root()
73
65
  mean_val = f"{self.mean_value:.2f}"
74
66
 
75
67
  return ((f"("
76
- f"{p}a{e}={ccu.wrap_evenly_spaced_color(s=self.action, n_of_item=self.action, n_classes=action_space_n)}, "
77
- f"{p}N{e}={colorful_value(self.visit_count)}, "
78
- f"{p}Q_v{e}={ccu.wrap_with_color_scale(s=mean_val, value=self.mean_value, min_val=root_node.min_value, max_val=root_node.max_value)}, "
79
- f"{p}best{e}={colorful_value(self.max_value)}") +
68
+ f"{p}a{e}={ccu.wrap_evenly_spaced_color(s=self.action, n_of_item=self.action, n_classes=action_space_n)}, "
69
+ f"{p}N{e}={colorful_value(self.visit_count)}, "
70
+ f"{p}Q_v{e}={ccu.wrap_with_color_scale(s=mean_val, value=self.mean_value, min_val=root_node.min_value, max_val=root_node.max_value)}, "
71
+ f"{p}best{e}={colorful_value(self.max_value)}") +
80
72
  (f", {p}ubc{e}={colorful_value(self.ucb_score())})" if not self.is_root() else ")"))
81
73
 
82
-
83
- def traverse_nodes(self) -> Generator[TSoloMCTSNode, None, None]:
74
+ def traverse_nodes(self) -> Generator[TGymctsNode, None, None]:
84
75
  yield self
85
76
  if self.children:
86
77
  for child in self.children.values():
87
78
  yield from child.traverse_nodes()
88
79
 
89
- def get_root(self) -> TSoloMCTSNode:
80
+ def get_root(self) -> TGymctsNode:
90
81
  if self.is_root():
91
82
  return self
92
83
  return self.parent.get_root()
@@ -101,19 +92,17 @@ class SoloMCTSNode:
101
92
  return 0
102
93
  return len(self.children) + sum(child.n_children_recursively() for child in self.children.values())
103
94
 
104
-
105
-
106
95
  def __init__(self,
107
96
  action: int | None,
108
- parent: TSoloMCTSNode | None,
109
- env_reference: SoloMCTSGymEnv,
97
+ parent: TGymctsNode | None,
98
+ env_reference: GymctsABC,
110
99
  ):
111
100
 
112
101
  # field depending on whether the node is a root node or not
113
102
  self.action: int | None
114
103
 
115
- self.env_reference: SoloMCTSGymEnv
116
- self.parent: SoloMCTSNode | None
104
+ self.env_reference: GymctsABC
105
+ self.parent: GymctsNode | None
117
106
  self.uuid = uuid.uuid4()
118
107
 
119
108
  if parent is None:
@@ -133,7 +122,7 @@ class SoloMCTSNode:
133
122
 
134
123
  from copy import copy
135
124
  self.state = env_reference.get_state()
136
- #log.debug(f"saving state of node '{str(self)}' to memory location: {hex(id(self.state))}")
125
+ # log.debug(f"saving state of node '{str(self)}' to memory location: {hex(id(self.state))}")
137
126
  self.visit_count: int = 0
138
127
 
139
128
  self.mean_value: float = 0
@@ -143,8 +132,7 @@ class SoloMCTSNode:
143
132
  # safe valid action instead of calling the environment
144
133
  # this reduces the compute but increases the memory usage
145
134
  self.valid_actions: list[int] = env_reference.get_valid_actions()
146
- self.children: dict[int, SoloMCTSNode] | None = None # may be expanded later
147
-
135
+ self.children: dict[int, GymctsNode] | None = None # may be expanded later
148
136
 
149
137
  def reset(self) -> None:
150
138
  self.parent = None
@@ -153,35 +141,33 @@ class SoloMCTSNode:
153
141
  self.mean_value: float = 0
154
142
  self.max_value: float = -float("inf")
155
143
  self.min_value: float = +float("inf")
156
- self.children: dict[int, SoloMCTSNode] | None = None # may be expanded later
144
+ self.children: dict[int, GymctsNode] | None = None # may be expanded later
157
145
 
158
146
  # just setting the children of the parent node to None should be enough to trigger garbage collection
159
147
  # however, we also set the parent to None to make sure that the parent is not referenced anymore
160
148
  if self.parent:
161
149
  self.parent.reset()
162
150
 
163
-
164
-
165
151
  def is_root(self) -> bool:
166
152
  return self.parent is None
167
153
 
168
154
  def is_leaf(self) -> bool:
169
155
  return self.children is None or len(self.children) == 0
170
156
 
171
- def get_random_child(self) -> TSoloMCTSNode:
157
+ def get_random_child(self) -> TGymctsNode:
172
158
  if self.is_leaf():
173
- raise ValueError("cannot get random child of leaf node") #todo: maybe return self instead?
159
+ raise ValueError("cannot get random child of leaf node") # todo: maybe return self instead?
174
160
 
175
161
  return list(self.children.values())[random.randint(0, len(self.children) - 1)]
176
162
 
177
163
  def get_best_action(self) -> int:
178
164
  return max(self.children.values(), key=lambda child: child.get_score()).action
179
165
 
180
- def get_score(self) -> float: # todo: make it an attribute?
166
+ def get_score(self) -> float: # todo: make it an attribute?
181
167
  # return self.mean_value
182
- assert 0 <= SoloMCTSNode.best_action_weight <= 1
183
- a = SoloMCTSNode.best_action_weight
184
- return (1-a) * self.mean_value + a * self.max_value
168
+ assert 0 <= GymctsNode.best_action_weight <= 1
169
+ a = GymctsNode.best_action_weight
170
+ return (1 - a) * self.mean_value + a * self.max_value
185
171
 
186
172
  def get_mean_value(self) -> float:
187
173
  return self.mean_value
@@ -207,7 +193,7 @@ class SoloMCTSNode:
207
193
  if self.is_root():
208
194
  raise ValueError("ucb_score can only be called on non-root nodes")
209
195
  # c = 0.707 # todo: make it an attribute?
210
- c = SoloMCTSNode.ubc_c
196
+ c = GymctsNode.ubc_c
211
197
  if self.visit_count == 0:
212
198
  return float("inf")
213
- return self.mean_value + c * math.sqrt(math.log(self.parent.visit_count) / (self.visit_count))
199
+ return self.mean_value + c * math.sqrt(math.log(self.parent.visit_count) / (self.visit_count))
@@ -0,0 +1,75 @@
1
+ from gymcts.gymcts_node import GymctsNode
2
+
3
+ from gymcts.logger import log
4
+
5
+
6
def _generate_mcts_tree(
        start_node: GymctsNode = None,
        prefix: str = None,
        depth: int = None,
        exclude_unvisited_nodes_from_render: bool = True,
        action_space_n: int = None
) -> list[str]:
    """Yield the lines of a colored, ``tree``-style rendering of the nodes below ``start_node``.

    NOTE: although annotated as ``list[str]``, this function is a generator. It yields
    one line per rendered node. ``start_node`` itself is not rendered.

    :param start_node: node whose descendants are rendered; must not be ``None``.
    :param prefix: text prepended to every line (used internally while recursing).
    :param depth: recursion limit. The children of a rendered node are only rendered
        while ``depth > 0``; ``None`` means unlimited.
    :param exclude_unvisited_nodes_from_render: skip nodes with ``visit_count == 0``
        on every level of the tree.
    :param action_space_n: number of discrete actions, used for coloring. Defaults to
        100 (with a warning) when ``None``.
    :raises ValueError: if ``start_node`` is ``None``.
    """
    import gymcts.colorful_console_utils as ccu

    if prefix is None:
        prefix = ""

    if start_node is None:
        raise ValueError("start_node must not be None")

    if action_space_n is None:
        log.warning("action_space_n is None, defaulting to 100")
        action_space_n = 100

    # prefix components; as wide as the pointers so that nested levels line up
    space = '    '
    branch = '│   '
    # pointers:
    tee = '├── '
    last = '└── '

    contents = list(start_node.children.values()) if start_node.children is not None else []
    if exclude_unvisited_nodes_from_render:
        contents = [node for node in contents if node.visit_count > 0]
    # every entry gets a ├── pointer except the final one, which gets └──
    pointers = [tee] * (len(contents) - 1) + [last]

    for pointer, current_node in zip(pointers, contents):
        # pointers and prefixes are colored by the action of the parent node
        n_item = current_node.parent.action if current_node.parent is not None else 0

        pointer = ccu.wrap_evenly_spaced_color(
            s=pointer,
            n_of_item=n_item,
            n_classes=action_space_n,
        )

        yield prefix + pointer + f"{current_node.__str__(colored=True, action_space_n=action_space_n)}"

        if not current_node.children:
            continue
        if depth is not None and depth <= 0:
            continue

        # below a ├── the vertical line continues; below the final └── it does not
        extension = branch if tee in pointer else space
        extension = ccu.wrap_evenly_spaced_color(
            s=extension,
            n_of_item=n_item,
            n_classes=action_space_n,
        )
        yield from _generate_mcts_tree(
            current_node,
            prefix=prefix + extension,
            depth=depth - 1 if depth is not None else None,
            # forward the flag so every level applies the same filtering
            exclude_unvisited_nodes_from_render=exclude_unvisited_nodes_from_render,
            action_space_n=action_space_n,
        )
66
+
67
+
68
def show_mcts_tree(
        start_node: GymctsNode = None,
        tree_max_depth: int = None,
        action_space_n: int = None,
        exclude_unvisited_nodes_from_render: bool = True,
) -> None:
    """Print ``start_node`` followed by a tree rendering of the nodes below it.

    :param start_node: node to print from; must not be ``None``.
    :param tree_max_depth: render depth limit passed to ``_generate_mcts_tree``
        (``None`` means unlimited).
    :param action_space_n: number of discrete actions, used for coloring both the
        header line and the tree lines.
    :param exclude_unvisited_nodes_from_render: skip nodes that were never visited.
    :raises ValueError: if ``start_node`` is ``None``.
    """
    if start_node is None:
        raise ValueError("start_node must not be None")

    print(start_node.__str__(colored=True, action_space_n=action_space_n))
    for line in _generate_mcts_tree(
            start_node=start_node,
            depth=tree_max_depth,
            # forward the action count so the tree is colored like the header line
            action_space_n=action_space_n,
            exclude_unvisited_nodes_from_render=exclude_unvisited_nodes_from_render,
    ):
        print(line)
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: gymcts
3
- Version: 1.0.0
3
+ Version: 1.2.0
4
4
  Summary: A minimalistic implementation of the Monte Carlo Tree Search algorithm for planning problems formulated as gymnasium reinforcement learning environments.
5
5
  Author: Alexander Nasuta
6
6
  Author-email: Alexander Nasuta <alexander.nasuta@wzl-iqs.rwth-aachen.de>
@@ -25,7 +25,7 @@ License: MIT License
25
25
  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26
26
  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
27
27
  SOFTWARE.
28
- Project-URL: Homepage, https://github.com/Alexander-Nasuta/pypitemplate
28
+ Project-URL: Homepage, https://github.com/Alexander-Nasuta/gymcts
29
29
  Platform: unix
30
30
  Platform: linux
31
31
  Platform: osx
@@ -34,7 +34,7 @@ Platform: win32
34
34
  Classifier: License :: OSI Approved :: MIT License
35
35
  Classifier: Programming Language :: Python
36
36
  Classifier: Programming Language :: Python :: 3
37
- Requires-Python: >=3.9
37
+ Requires-Python: >=3.11
38
38
  Description-Content-Type: text/markdown
39
39
  License-File: LICENSE
40
40
  Requires-Dist: rich
@@ -63,6 +63,9 @@ Requires-Dist: furo; extra == "dev"
63
63
  Requires-Dist: twine; extra == "dev"
64
64
  Requires-Dist: sphinx-copybutton; extra == "dev"
65
65
  Requires-Dist: nbsphinx; extra == "dev"
66
+ Requires-Dist: jupytext; extra == "dev"
67
+ Requires-Dist: jupyter; extra == "dev"
68
+ Dynamic: license-file
66
69
 
67
70
  # Graph Matrix Job Shop Env
68
71
 
@@ -118,8 +121,8 @@ The NaiveSoloMCTSGymEnvWrapper can be used with non-deterministic environments,
118
121
  ```python
119
122
  import gymnasium as gym
120
123
 
121
- from gymcts.gymcts_agent import SoloMCTSAgent
122
- from gymcts.gymcts_naive_wrapper import NaiveSoloMCTSGymEnvWrapper
124
+ from gymcts.gymcts_agent import GymctsAgent
125
+ from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
123
126
 
124
127
  from gymcts.logger import log
125
128
 
@@ -133,10 +136,10 @@ if __name__ == '__main__':
133
136
  env.reset()
134
137
 
135
138
  # 1. wrap the environment with the naive wrapper or a custom gymcts wrapper
136
- env = NaiveSoloMCTSGymEnvWrapper(env)
139
+ env = DeepCopyMCTSGymEnvWrapper(env)
137
140
 
138
141
  # 2. create the agent
139
- agent = SoloMCTSAgent(
142
+ agent = GymctsAgent(
140
143
  env=env,
141
144
  clear_mcts_tree_after_step=False,
142
145
  render_tree_after_step=True,
@@ -170,13 +173,13 @@ if __name__ == '__main__':
170
173
  A minimal example of how to use the package with the FrozenLake environment and the ActionHistoryMCTSGymEnvWrapper is provided in the code snippet below.
171
174
  The ActionHistoryMCTSGymEnvWrapper can be used with deterministic environments, such as the FrozenLake environment without slippery ice.
172
175
 
173
- The DeterministicSoloMCTSGymEnvWrapper saves the action sequence that lead to the current state in the MCTS node.
176
+ The ActionHistoryMCTSGymEnvWrapper saves the action sequence that led to the current state in the MCTS node.
174
177
 
175
178
  ```python
176
179
  import gymnasium as gym
177
180
 
178
- from gymcts.gymcts_agent import SoloMCTSAgent
179
- from gymcts.gymcts_deterministic_wrapper import DeterministicSoloMCTSGymEnvWrapper
181
+ from gymcts.gymcts_agent import GymctsAgent
182
+ from gymcts.gymcts_action_history_wrapper import ActionHistoryMCTSGymEnvWrapper
180
183
 
181
184
  from gymcts.logger import log
182
185
 
@@ -190,10 +193,10 @@ if __name__ == '__main__':
190
193
  env.reset()
191
194
 
192
195
  # 1. wrap the environment with the wrapper
193
- env = DeterministicSoloMCTSGymEnvWrapper(env)
196
+ env = ActionHistoryMCTSGymEnvWrapper(env)
194
197
 
195
198
  # 2. create the agent
196
- agent = SoloMCTSAgent(
199
+ agent = GymctsAgent(
197
200
  env=env,
198
201
  clear_mcts_tree_after_step=False,
199
202
  render_tree_after_step=True,
@@ -232,8 +235,8 @@ To create a video of the solution of the FrozenLake environment, you can use the
232
235
  ```python
233
236
  import gymnasium as gym
234
237
 
235
- from gymcts.gymcts_agent import SoloMCTSAgent
236
- from gymcts.gymcts_naive_wrapper import NaiveSoloMCTSGymEnvWrapper
238
+ from gymcts.gymcts_agent import GymctsAgent
239
+ from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
237
240
 
238
241
  from gymcts.logger import log
239
242
 
@@ -249,10 +252,10 @@ if __name__ == '__main__':
249
252
  env.reset()
250
253
 
251
254
  # 1. wrap the environment with the naive wrapper or a custom gymcts wrapper
252
- env = NaiveSoloMCTSGymEnvWrapper(env)
255
+ env = DeepCopyMCTSGymEnvWrapper(env)
253
256
 
254
257
  # 2. create the agent
255
- agent = SoloMCTSAgent(
258
+ agent = GymctsAgent(
256
259
  env=env,
257
260
  clear_mcts_tree_after_step=False,
258
261
  render_tree_after_step=True,
@@ -413,13 +416,12 @@ The color gradient is based on the minimum and maximum values of the respective
413
416
  The visualisation is rendered in the terminal and can be limited to a certain depth of the tree.
414
417
  The default depth is 2.
415
418
 
416
-
417
419
  ```python
418
420
  import gymnasium as gym
419
421
 
420
- from gymcts.gymcts_agent import SoloMCTSAgent
421
- from gymcts.gymcts_deterministic_wrapper import DeterministicSoloMCTSGymEnvWrapper
422
- from gymcts.gymcts_naive_wrapper import NaiveSoloMCTSGymEnvWrapper
422
+ from gymcts.gymcts_agent import GymctsAgent
423
+ from gymcts.gymcts_action_history_wrapper import ActionHistoryMCTSGymEnvWrapper
424
+ from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
423
425
 
424
426
  from gymcts.logger import log
425
427
 
@@ -433,10 +435,10 @@ if __name__ == '__main__':
433
435
  env.reset()
434
436
 
435
437
  # wrap the environment with the action history wrapper or a custom gymcts wrapper
436
- env = DeterministicSoloMCTSGymEnvWrapper(env)
438
+ env = ActionHistoryMCTSGymEnvWrapper(env)
437
439
 
438
440
  # create the agent
439
- agent = SoloMCTSAgent(
441
+ agent = GymctsAgent(
440
442
  env=env,
441
443
  clear_mcts_tree_after_step=False,
442
444
  render_tree_after_step=False,
@@ -0,0 +1,15 @@
1
+ gymcts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ gymcts/colorful_console_utils.py,sha256=OhULcXHKbEA4uJDAEYCTcW6wUv0LsHX_XSYzZ_Szsv4,4553
3
+ gymcts/gymcts_action_history_wrapper.py,sha256=AjvBBwd1t9-nTYP09aMdlScAkFNXf5vOagejpjWYOPo,3810
4
+ gymcts/gymcts_agent.py,sha256=O2y98jKFjR5TzqVV7DO1jlcYDyzAgd_H2RF4-w4NP0g,8499
5
+ gymcts/gymcts_deepcopy_wrapper.py,sha256=OleQTnvxv3gLEo8-2asyeo-CpZ4HEbgyFGS5DTCD7NM,4167
6
+ gymcts/gymcts_distributed_agent.py,sha256=M7dyBfC8u3M99PJFoXKgIc_CPTyHGppmktkH-y9ci4U,10448
7
+ gymcts/gymcts_env_abc.py,sha256=7nCRiiClmmVLX-d_Q1dxeztmuvmAtmWZwjT81zrG1_w,575
8
+ gymcts/gymcts_node.py,sha256=PT_YZFwt1zjuvd8i9Wb5LEkHAqmJOFyPDp3GFD05lqM,7138
9
+ gymcts/gymcts_tree_plotter.py,sha256=eg207wHcDepwWODXzmDYQn1Aai29Cs4jFS1HNvAhlXs,2651
10
+ gymcts/logger.py,sha256=nAkUa4djiuCR7hF0EUsplhqFHCp76QcOX1cV3lIPzOI,937
11
+ gymcts-1.2.0.dist-info/licenses/LICENSE,sha256=UGe75WojDiw_77SEnK2aysEDlElRlkWie7U7NaAFx00,1072
12
+ gymcts-1.2.0.dist-info/METADATA,sha256=zhEIFo0rOnv5hCv6ukImkq-9nshO4EfXMbHlhNlYhyA,23640
13
+ gymcts-1.2.0.dist-info/WHEEL,sha256=DK49LOLCYiurdXXOXwGJm6U4DkHkg4lcxjhqwRa0CP4,91
14
+ gymcts-1.2.0.dist-info/top_level.txt,sha256=E8MoLsPimUPD0H1Y6lum4TVe-lhSDAyBAXGrkYIT52w,7
15
+ gymcts-1.2.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.2)
2
+ Generator: setuptools (78.0.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,13 +0,0 @@
1
- gymcts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- gymcts/colorful_console_utils.py,sha256=bbZzRFzimhsIhbT-nmz6v62WJLxFDgzFvqI_pmIsckE,4526
3
- gymcts/gymcts_agent.py,sha256=TJXJH77T95EP3ZNtzWqlGw9iFF1R-nsItp7UA1ZlXUs,10537
4
- gymcts/gymcts_deterministic_wrapper.py,sha256=PILGPaQnyG2u_2u48MEE3aeJCtdgjjO55ZFDxeIVeH0,3824
5
- gymcts/gymcts_gym_env.py,sha256=R1Z1fhoywdXmPt_FYgrarIh0YFQvCifayAWnCcEiJKE,580
6
- gymcts/gymcts_naive_wrapper.py,sha256=qeQ7rzBz7BFv2yCJj3GmdFt5UlTx5VHMw5ImZUl9H5k,4178
7
- gymcts/gymcts_node.py,sha256=jxdtuC1iqeRtEA-Qfvq-mOuM8vdDl43iWe5hqItG90w,7185
8
- gymcts/logger.py,sha256=nAkUa4djiuCR7hF0EUsplhqFHCp76QcOX1cV3lIPzOI,937
9
- gymcts-1.0.0.dist-info/LICENSE,sha256=UGe75WojDiw_77SEnK2aysEDlElRlkWie7U7NaAFx00,1072
10
- gymcts-1.0.0.dist-info/METADATA,sha256=sAXJQreADqEOviVL8nT8fmrx7hP-qM7C_-SC5FNw-94,23572
11
- gymcts-1.0.0.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
12
- gymcts-1.0.0.dist-info/top_level.txt,sha256=E8MoLsPimUPD0H1Y6lum4TVe-lhSDAyBAXGrkYIT52w,7
13
- gymcts-1.0.0.dist-info/RECORD,,