gymcts 1.0.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,5 @@
1
+ from typing import Any
2
+
1
3
  import matplotlib.pyplot as plt
2
4
  import numpy as np
3
5
 
@@ -103,8 +105,19 @@ def wrap_with_color_codes(s: object, /, r: int | float, g: int | float, b: int |
103
105
  f"{CEND}"
104
106
 
105
107
 
106
-
107
- def wrap_evenly_spaced_color(s: str, n_of_item:int, n_classes:int, c_map="rainbow") -> str:
108
+ def wrap_evenly_spaced_color(s: Any, n_of_item: int, n_classes: int, c_map="rainbow") -> str:
109
+ """
110
+ Wraps a string with a color scale (a matplotlib c_map) based on the n_of_item and n_classes.
111
+ This function is used to color code the available actions in the MCTS tree visualisation.
112
+ The children of the MCTS tree are colored based on their action for a clearer visualisation.
113
+
114
+ :param s: the string (or object) to be wrapped. objects are converted to string (using the __str__ function).
115
+ :param n_of_item: the index of the item to be colored. In an MCTS tree, this is the (parent-)action of the node.
116
+ :param n_classes: the number of classes (or items) to be colored. In an MCTS tree, this is the number of available actions.
117
+ :param c_map: the colormap to be used (default is 'rainbow').
118
+ The colormap can be any matplotlib colormap, e.g. 'viridis', 'plasma', 'inferno', 'magma', 'cividis'.
119
+ :return: a string that contains the color-codes (prefix and suffix) and the string s in between.
120
+ """
108
121
  if s is None or n_of_item is None or n_classes is None:
109
122
  return s
110
123
 
@@ -117,7 +130,17 @@ def wrap_evenly_spaced_color(s: str, n_of_item:int, n_classes:int, c_map="rainbo
117
130
  return f"{color_asni}{s}{CEND}"
118
131
 
119
132
 
120
- def wrap_with_color_scale(s: str, value: float, min_val:float, max_val:float, c_map=None) -> str:
133
+ def wrap_with_color_scale(s: str, value: float, min_val: float, max_val: float, c_map=None) -> str:
134
+ """
135
+ Wraps a string with a color scale (a matplotlib c_map) based on the value, min_val, and max_val.
136
+
137
+ :param s: the string to be wrapped
138
+ :param value: the value to be mapped to a color
139
+ :param min_val: the minimum value of the scale
140
+ :param max_val: the maximum value of the scale
141
+ :param c_map: the colormap to be used (default is 'rainbow')
142
+ :return:
143
+ """
121
144
  if s is None or min_val is None or max_val is None or min_val >= max_val:
122
145
  return s
123
146
 
@@ -1,18 +1,32 @@
1
1
  import random
2
- import copy
3
2
 
4
3
  import numpy as np
5
- from typing import TypeVar, Any, SupportsFloat, Callable
4
+ from typing import Any, SupportsFloat, Callable
6
5
  import gymnasium as gym
7
6
  from gymnasium.core import WrapperActType, WrapperObsType
8
7
  from gymnasium.wrappers import RecordEpisodeStatistics
9
8
 
10
- from gymcts.gymcts_gym_env import SoloMCTSGymEnv
9
+ from gymcts.gymcts_env_abc import GymctsABC
11
10
 
12
11
  from gymcts.logger import log
13
12
 
14
13
 
15
- class DeterministicSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
14
+ class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
15
+ """
16
+ A wrapper for gym environments that implements the GymctsABC interface.
17
+ It uses the action history as state representation.
18
+ Please note that this is not the most efficient way to implement the state representation.
19
+ It is supposed to be used to see if your use-case works well with the MCTS algorithm.
20
+ If it does, you can consider implementing all GymctsABC methods in a more efficient way.
21
+ The action history is a list of actions taken in the environment.
22
+ The state is represented as a list of actions taken in the environment.
23
+ The state is used to restore the environment using the load_state method.
24
+
25
+ It is supposed to be used to see if your use-case works well with the MCTS algorithm.
26
+ If it does, you can consider implementing all GymctsABC methods in a more efficient way.
27
+ """
28
+
29
+ # helper attributes for the wrapper
16
30
  _terminal_flag: bool = False
17
31
  _last_reward: SupportsFloat = 0
18
32
  _step_tuple: tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]] = None
@@ -25,6 +39,17 @@ class DeterministicSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
25
39
  action_mask_fn: str | Callable[[gym.Env], np.ndarray] | None = None,
26
40
  buffer_length: int = 100,
27
41
  ):
42
+ """
43
+ A wrapper for gym environments that implements the GymctsABC interface.
44
+ It uses the action history as state representation.
45
+ Please note that this is not the most efficient way to implement the state representation.
46
+ It is supposed to be used to see if your use-case works well with the MCTS algorithm.
47
+ If it does, you can consider implementing all GymctsABC methods in a more efficient way.
48
+
49
+ :param env: the environment to wrap
50
+ :param action_mask_fn: a function that takes the environment as input and returns a mask of valid actions
51
+ :param buffer_length: the length of the buffer for recording episodes for determining their rollout returns
52
+ """
28
53
  # wrap with RecordEpisodeStatistics if it is not already wrapped
29
54
  env = RecordEpisodeStatistics(env, buffer_length=buffer_length)
30
55
 
@@ -48,6 +73,17 @@ class DeterministicSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
48
73
  self._action_mask_fn = action_mask_fn
49
74
 
50
75
  def load_state(self, state: list[int]) -> None:
76
+ """
77
+ Loads the state of the environment. The state is a list of actions taken in the environment.
78
+
79
+ The environment is reset and all actions in the state are performed in order to restore the environment to the
80
+ same state.
81
+
82
+ This works only for deterministic environments!
83
+
84
+ :param state: the state to load
85
+ :return: None
86
+ """
51
87
  self.env.reset()
52
88
  self._wrapper_action_history = []
53
89
 
@@ -56,15 +92,30 @@ class DeterministicSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
56
92
  self._wrapper_action_history.append(action)
57
93
 
58
94
  def is_terminal(self) -> bool:
95
+ """
96
+ Returns True if the environment is in a terminal state, False otherwise.
97
+
98
+ :return:
99
+ """
59
100
  if not len(self.get_valid_actions()):
60
101
  return True
61
102
  else:
62
103
  return self._terminal_flag
63
104
 
64
105
  def action_masks(self) -> np.ndarray | None:
106
+ """
107
+ Returns the action masks for the environment. If the action_mask_fn is not set, it returns None.
108
+
109
+ :return:
110
+ """
65
111
  return self._action_mask_fn(self.env) if self._action_mask_fn is not None else None
66
112
 
67
113
  def get_valid_actions(self) -> list[int]:
114
+ """
115
+ Returns a list of valid actions for the current state of the environment.
116
+
117
+ :return: a list of valid actions
118
+ """
68
119
  if self._action_mask_fn is None:
69
120
  action_space: gym.spaces.Discrete = self.env.action_space # Type hinting
70
121
  return list(range(action_space.n))
@@ -72,6 +123,12 @@ class DeterministicSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
72
123
  return [i for i, mask in enumerate(self.action_masks()) if mask]
73
124
 
74
125
  def rollout(self) -> float:
126
+ """
127
+ Performs a random rollout from the current state of the environment and returns the return (sum of rewards)
128
+ of the rollout.
129
+
130
+ :return: the return of the rollout
131
+ """
75
132
  log.debug("performing rollout")
76
133
  # random rollout
77
134
  # perform random valid action until terminal
@@ -92,11 +149,24 @@ class DeterministicSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
92
149
  return episode_return
93
150
 
94
151
  def get_state(self) -> list[int]:
152
+ """
153
+ Returns the current state of the environment. The state is a list of actions taken in the environment,
154
+ namely all actions that have been taken in the environment so far (since the last reset).
155
+
156
+ :return: a list of actions taken in the environment
157
+ """
158
+
95
159
  return self._wrapper_action_history.copy()
96
160
 
97
161
  def step(
98
162
  self, action: WrapperActType
99
163
  ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
164
+ """
165
+ Performs a step in the environment. It adds the action to the action history and updates the terminal flag.
166
+
167
+ :param action: action to perform in the environment
168
+ :return: the step tuple of the environment (obs, reward, terminated, truncated, info)
169
+ """
100
170
  step_tuple = self.env.step(action)
101
171
  self._wrapper_action_history.append(action)
102
172
  obs, reward, terminated, truncated, info = step_tuple
gymcts/gymcts_agent.py CHANGED
@@ -1,29 +1,31 @@
1
1
  import copy
2
+ import random
2
3
  import gymnasium as gym
3
4
 
4
5
  from typing import TypeVar, Any, SupportsFloat, Callable
5
6
 
6
- from gymcts.gymcts_gym_env import SoloMCTSGymEnv
7
- from gymcts.gymcts_naive_wrapper import NaiveSoloMCTSGymEnvWrapper
8
- from gymcts.gymcts_node import SoloMCTSNode
7
+ from gymcts.gymcts_env_abc import GymctsABC
8
+ from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
9
+ from gymcts.gymcts_node import GymctsNode
10
+ from gymcts.gymcts_tree_plotter import _generate_mcts_tree
9
11
 
10
12
  from gymcts.logger import log
11
13
 
12
14
  TSoloMCTSNode = TypeVar("TSoloMCTSNode", bound="SoloMCTSNode")
13
15
 
14
16
 
15
- class SoloMCTSAgent:
17
+ class GymctsAgent:
16
18
  render_tree_after_step: bool = False
17
19
  render_tree_max_depth: int = 2
18
20
  exclude_unvisited_nodes_from_render: bool = False
19
21
  number_of_simulations_per_step: int = 25
20
22
 
21
- env: SoloMCTSGymEnv
22
- search_root_node: SoloMCTSNode # NOTE: this is not the same as the root of the tree!
23
+ env: GymctsABC
24
+ search_root_node: GymctsNode # NOTE: this is not the same as the root of the tree!
23
25
  clear_mcts_tree_after_step: bool
24
26
 
25
27
  def __init__(self,
26
- env: SoloMCTSGymEnv,
28
+ env: GymctsABC,
27
29
  clear_mcts_tree_after_step: bool = True,
28
30
  render_tree_after_step: bool = False,
29
31
  render_tree_max_depth: int = 2,
@@ -43,13 +45,13 @@ class SoloMCTSAgent:
43
45
  self.env = env
44
46
  self.clear_mcts_tree_after_step = clear_mcts_tree_after_step
45
47
 
46
- self.search_root_node = SoloMCTSNode(
48
+ self.search_root_node = GymctsNode(
47
49
  action=None,
48
50
  parent=None,
49
51
  env_reference=env,
50
52
  )
51
53
 
52
- def navigate_to_leaf(self, from_node: SoloMCTSNode) -> SoloMCTSNode:
54
+ def navigate_to_leaf(self, from_node: GymctsNode) -> GymctsNode:
53
55
  log.debug(f"Navigate to leaf. from_node: {from_node}")
54
56
  if from_node.terminal:
55
57
  log.debug("Node is terminal. Returning from_node")
@@ -62,11 +64,14 @@ class SoloMCTSAgent:
62
64
  # NAVIGATION STRATEGY
63
65
  # select child with highest UCB score
64
66
  while not temp_node.is_leaf():
65
- temp_node = max(temp_node.children.values(), key=lambda child: child.ucb_score())
67
+ children = list(temp_node.children.values())
68
+ max_ucb_score = max(child.ucb_score() for child in children)
69
+ best_children = [child for child in children if child.ucb_score() == max_ucb_score]
70
+ temp_node = random.choice(best_children)
66
71
  log.debug(f"Selected leaf node: {temp_node}")
67
72
  return temp_node
68
73
 
69
- def expand_node(self, node: SoloMCTSNode) -> None:
74
+ def expand_node(self, node: GymctsNode) -> None:
70
75
  log.debug(f"expanding node: {node}")
71
76
  # EXPANSION STRATEGY
72
77
  # expand all children
@@ -78,7 +83,7 @@ class SoloMCTSAgent:
78
83
  self._load_state(node)
79
84
 
80
85
  obs, reward, terminal, truncated, _ = self.env.step(action)
81
- child_dict[action] = SoloMCTSNode(
86
+ child_dict[action] = GymctsNode(
82
87
  action=action,
83
88
  parent=node,
84
89
  env_reference=self.env,
@@ -110,14 +115,14 @@ class SoloMCTSAgent:
110
115
  # restore state of current node
111
116
  return action_list
112
117
 
113
- def _load_state(self, node: SoloMCTSNode) -> None:
114
- if isinstance(self.env, NaiveSoloMCTSGymEnvWrapper):
118
+ def _load_state(self, node: GymctsNode) -> None:
119
+ if isinstance(self.env, DeepCopyMCTSGymEnvWrapper):
115
120
  self.env = copy.deepcopy(node.state)
116
121
  else:
117
122
  self.env.load_state(node.state)
118
123
 
119
- def perform_mcts_step(self, search_start_node: SoloMCTSNode = None, num_simulations: int = None,
120
- render_tree_after_step: bool = None) -> tuple[int, SoloMCTSNode]:
124
+ def perform_mcts_step(self, search_start_node: GymctsNode = None, num_simulations: int = None,
125
+ render_tree_after_step: bool = None) -> tuple[int, GymctsNode]:
121
126
 
122
127
  if render_tree_after_step is None:
123
128
  render_tree_after_step = self.render_tree_after_step
@@ -149,7 +154,7 @@ class SoloMCTSAgent:
149
154
 
150
155
  return action, next_node
151
156
 
152
- def vanilla_mcts_search(self, search_start_node: SoloMCTSNode = None, num_simulations=10) -> int:
157
+ def vanilla_mcts_search(self, search_start_node: GymctsNode = None, num_simulations=10) -> int:
153
158
  log.debug(f"performing one MCTS search step with {num_simulations} simulations")
154
159
  if search_start_node is None:
155
160
  search_start_node = self.search_root_node
@@ -178,7 +183,7 @@ class SoloMCTSAgent:
178
183
 
179
184
  return search_start_node.get_best_action()
180
185
 
181
- def show_mcts_tree(self, start_node: SoloMCTSNode = None, tree_max_depth: int = None) -> None:
186
+ def show_mcts_tree(self, start_node: GymctsNode = None, tree_max_depth: int = None) -> None:
182
187
 
183
188
  if start_node is None:
184
189
  start_node = self.search_root_node
@@ -187,13 +192,17 @@ class SoloMCTSAgent:
187
192
  tree_max_depth = self.render_tree_max_depth
188
193
 
189
194
  print(start_node.__str__(colored=True, action_space_n=self.env.action_space.n))
190
- for line in self._generate_mcts_tree(start_node=start_node, depth=tree_max_depth):
195
+ for line in _generate_mcts_tree(
196
+ start_node=start_node,
197
+ depth=tree_max_depth,
198
+ action_space_n=self.env.action_space.n,
199
+ ):
191
200
  print(line)
192
201
 
193
202
  def show_mcts_tree_from_root(self, tree_max_depth: int = None) -> None:
194
203
  self.show_mcts_tree(start_node=self.search_root_node.get_root(), tree_max_depth=tree_max_depth)
195
204
 
196
- def backpropagation(self, node: SoloMCTSNode, episode_return: float) -> None:
205
+ def backpropagation(self, node: GymctsNode, episode_return: float) -> None:
197
206
  log.debug(f"performing backpropagation from leaf node: {node}")
198
207
  while not node.is_root():
199
208
  # node.mean_value = ((node.mean_value * node.visit_count) + episode_return) / (node.visit_count + 1)
@@ -209,53 +218,4 @@ class SoloMCTSAgent:
209
218
  node.max_value = max(node.max_value, episode_return)
210
219
  node.min_value = min(node.min_value, episode_return)
211
220
 
212
- def _generate_mcts_tree(self, start_node: SoloMCTSNode = None, prefix: str = None, depth: int = None) -> list[str]:
213
221
 
214
- if prefix is None:
215
- prefix = ""
216
- import gymcts.colorful_console_utils as ccu
217
-
218
- if start_node is None:
219
- start_node = self.search_root_node
220
-
221
- # prefix components:
222
- space = ' '
223
- branch = '│ '
224
- # pointers:
225
- tee = '├── '
226
- last = '└── '
227
-
228
- contents = start_node.children.values() if start_node.children is not None else []
229
- if self.exclude_unvisited_nodes_from_render:
230
- contents = [node for node in contents if node.visit_count > 0]
231
- # contents each get pointers that are ├── with a final └── :
232
- # pointers = [tee] * (len(contents) - 1) + [last]
233
- pointers = [tee for _ in range(len(contents) - 1)] + [last]
234
-
235
- for pointer, current_node in zip(pointers, contents):
236
- n_item = current_node.parent.action if current_node.parent is not None else 0
237
- n_classes = self.env.action_space.n
238
-
239
- pointer = ccu.wrap_evenly_spaced_color(
240
- s=pointer,
241
- n_of_item=n_item,
242
- n_classes=n_classes,
243
- )
244
-
245
- yield prefix + pointer + f"{current_node.__str__(colored=True, action_space_n=n_classes)}"
246
- if current_node.children and len(current_node.children): # extend the prefix and recurse:
247
- # extension = branch if pointer == tee else space
248
- extension = branch if tee in pointer else space
249
- # i.e. space because last, └── , above so no more |
250
- extension = ccu.wrap_evenly_spaced_color(
251
- s=extension,
252
- n_of_item=n_item,
253
- n_classes=n_classes,
254
- )
255
- if depth is not None and depth <= 0:
256
- continue
257
- yield from self._generate_mcts_tree(
258
- current_node,
259
- prefix=prefix + extension,
260
- depth=depth - 1 if depth is not None else None
261
- )
@@ -7,14 +7,21 @@ import gymnasium as gym
7
7
  from gymnasium.core import WrapperActType, WrapperObsType
8
8
  from gymnasium.wrappers import RecordEpisodeStatistics
9
9
 
10
- from gymcts.gymcts_gym_env import SoloMCTSGymEnv
10
+ from gymcts.gymcts_env_abc import GymctsABC
11
11
 
12
12
  from gymcts.logger import log
13
13
 
14
14
 
15
- class NaiveSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
16
-
15
+ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
16
+ """
17
+ A wrapper for gym environments that implements the GymctsABC interface.
18
+ It uses deep copies as state representation.
19
+ Please note that this is not the most efficient way to implement the state representation.
20
+ It is supposed to be used to see if your use-case works well with the MCTS algorithm.
21
+ If it does, you can consider implementing all GymctsABC methods in a more efficient way.
22
+ """
17
23
 
24
+ # helper attributes for the wrapper
18
25
  _terminal_flag:bool = False
19
26
  _last_reward: SupportsFloat = 0
20
27
  _step_tuple: tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]] = None
@@ -22,9 +29,21 @@ class NaiveSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
22
29
  _action_mask_fn: Callable[[gym.Env], np.ndarray] | None = None
23
30
 
24
31
  def is_terminal(self) -> bool:
32
+ """
33
+ Returns True if the environment is in a terminal state, False otherwise.
34
+
35
+ :return: True if the environment is in a terminal state, False otherwise.
36
+ """
25
37
  return self._terminal_flag
26
38
 
27
39
  def load_state(self, state: Any) -> None:
40
+ """
41
+ The load_state method is not implemented. The state is loaded by replacing the env with the 'state' (the copy
42
+ provided by 'get_state'). 'self' in a method cannot be replaced with another object (as far as I know).
43
+
44
+ :param state: a deepcopy of the environment
45
+ :return: None
46
+ """
28
47
  msg = """
29
48
  The NaiveSoloMCTSGymEnvWrapper uses deepcopies of the entire env as the state.
30
49
  The loading of the state is done by replacing the env with the 'state' (the copy provided my 'get_state').
@@ -39,6 +58,16 @@ class NaiveSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
39
58
  buffer_length: int = 100,
40
59
  record_video: bool = False,
41
60
  ):
61
+ """
62
+ The constructor of the wrapper. It wraps the environment with RecordEpisodeStatistics and checks if the action
63
+ space is discrete. It also checks if the action_mask_fn is a string or a callable. If it is a string, it tries to
64
+ find the method in the environment. If it is a callable, it assigns it to the _action_mask_fn attribute.
65
+
66
+ :param env: the environment to wrap
67
+ :param action_mask_fn:
68
+ :param buffer_length:
69
+ :param record_video:
70
+ """
42
71
  # wrap with RecordEpisodeStatistics if it is not already wrapped
43
72
  env = RecordEpisodeStatistics(env, buffer_length=buffer_length)
44
73
 
@@ -61,6 +90,10 @@ class NaiveSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
61
90
  self._action_mask_fn = action_mask_fn
62
91
 
63
92
  def get_state(self) -> Any:
93
+ """
94
+ Returns the current state of the environment as a deepcopy of the environment.
95
+ :return: a deepcopy of the environment
96
+ """
64
97
  log.debug("getting state")
65
98
  original_state = self
66
99
  copied_state = copy.deepcopy(self)
@@ -71,9 +104,19 @@ class NaiveSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
71
104
  return copied_state
72
105
 
73
106
  def action_masks(self) -> np.ndarray | None:
107
+ """
108
+ Returns the action masks for the environment. If the action_mask_fn is not set, it returns None.
109
+ :return: the action masks for the environment
110
+ """
74
111
  return self._action_mask_fn(self.env) if self._action_mask_fn is not None else None
75
112
 
76
113
  def get_valid_actions(self) -> list[int]:
114
+ """
115
+ Returns a list of valid actions for the current state of the environment.
116
+ This used to obtain potential actions/subsequent sates for the MCTS tree.
117
+
118
+ :return: the list of valid actions
119
+ """
77
120
  if self._action_mask_fn is None:
78
121
  action_space: gym.spaces.Discrete = self.env.action_space # Type hinting
79
122
  return list(range(action_space.n))
@@ -83,6 +126,14 @@ class NaiveSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
83
126
  def step(
84
127
  self, action: WrapperActType
85
128
  ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
129
+ """
130
+ Performs a step in the environment.
131
+ This method is used to update the wrapper with the new state and the new action, to realize the terminal state
132
+ functionality.
133
+
134
+ :param action: action to perform in the environment
135
+ :return: the step tuple of the environment (obs, reward, terminated, truncated, info)
136
+ """
86
137
  step_tuple = self.env.step(action)
87
138
 
88
139
  obs, reward, terminated, truncated, info = step_tuple
@@ -93,6 +144,12 @@ class NaiveSoloMCTSGymEnvWrapper(SoloMCTSGymEnv, gym.Wrapper):
93
144
 
94
145
 
95
146
  def rollout(self) -> float:
147
+ """
148
+ Performs a rollout from the current state of the environment and returns the return (sum of rewards) of the
149
+ rollout.
150
+
151
+ :return: the return of the rollout
152
+ """
96
153
  log.debug("performing rollout")
97
154
  # random rollout
98
155
  # perform random valid action until terminal