gymcts 1.2.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gymcts/colorful_console_utils.py CHANGED
@@ -106,6 +106,18 @@ def wrap_with_color_codes(s: object, /, r: int | float, g: int | float, b: int |
 
 
  def wrap_evenly_spaced_color(s: Any, n_of_item: int, n_classes: int, c_map="rainbow") -> str:
+ """
+ Wraps a string with a color scale (a matplotlib c_map) based on n_of_item and n_classes.
+ This function is used to color code the available actions in the MCTS tree visualisation.
+ The children of the MCTS tree are colored based on their action for a clearer visualisation.
+
+ :param s: the string (or object) to be wrapped. Objects are converted to strings (using their __str__ method).
+ :param n_of_item: the index of the item to be colored. In an MCTS tree, this is the (parent-)action of the node.
+ :param n_classes: the number of classes (or items) to be colored. In an MCTS tree, this is the number of available actions.
+ :param c_map: the colormap to be used (default is 'rainbow').
+ The colormap can be any matplotlib colormap, e.g. 'viridis', 'plasma', 'inferno', 'magma', 'cividis'.
+ :return: a string that contains the color codes (prefix and suffix) with the string s in between.
+ """
  if s is None or n_of_item is None or n_classes is None:
  return s
 
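The pattern this helper documents is sampling a matplotlib colormap at evenly spaced positions and emitting 24-bit ANSI escape codes. A minimal sketch of the idea, assuming matplotlib is installed (illustrative only, not the package's exact code):

```python
import matplotlib

def wrap_evenly_spaced_color_sketch(s, n_of_item: int, n_classes: int, c_map: str = "rainbow") -> str:
    # Sample the colormap at one of n_classes evenly spaced positions in [0, 1].
    r, g, b, _ = matplotlib.colormaps[c_map](n_of_item / max(n_classes - 1, 1))
    # 24-bit ANSI escape codes: the prefix sets the foreground color, the suffix resets it.
    return f"\033[38;2;{int(r * 255)};{int(g * 255)};{int(b * 255)}m{s}\033[0m"
```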
@@ -119,6 +131,16 @@ def wrap_evenly_spaced_color(s: Any, n_of_item: int, n_classes: int, c_map="rain
 
 
  def wrap_with_color_scale(s: str, value: float, min_val: float, max_val: float, c_map=None) -> str:
+ """
+ Wraps a string with a color scale (a matplotlib c_map) based on the value, min_val, and max_val.
+
+ :param s: the string to be wrapped
+ :param value: the value to be mapped to a color
+ :param min_val: the minimum value of the scale
+ :param max_val: the maximum value of the scale
+ :param c_map: the colormap to be used (default is 'rainbow')
+ :return: a string that contains the color codes (prefix and suffix) with the string s in between
+ """
  if s is None or min_val is None or max_val is None or min_val >= max_val:
  return s
 
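The value-based variant differs from the evenly spaced one only in how the colormap position is chosen: the value is first normalized into [0, 1]. A hedged sketch of that mapping, under the same matplotlib assumption:

```python
import matplotlib

def wrap_with_color_scale_sketch(s: str, value: float, min_val: float, max_val: float,
                                 c_map: str = "rainbow") -> str:
    # Normalize the value into [0, 1], then sample the colormap at that position.
    t = (value - min_val) / (max_val - min_val)
    r, g, b, _ = matplotlib.colormaps[c_map](t)
    return f"\033[38;2;{int(r * 255)};{int(g * 255)};{int(b * 255)}m{s}\033[0m"
```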
gymcts/gymcts_action_history_wrapper.py CHANGED
@@ -1,8 +1,7 @@
  import random
- import copy
 
  import numpy as np
- from typing import TypeVar, Any, SupportsFloat, Callable
+ from typing import Any, SupportsFloat, Callable
  import gymnasium as gym
  from gymnasium.core import WrapperActType, WrapperObsType
  from gymnasium.wrappers import RecordEpisodeStatistics
@@ -13,6 +12,21 @@ from gymcts.logger import log
 
 
  class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
+ """
+ A wrapper for gym environments that implements the GymctsABC interface.
+ It uses the action history as the state representation: the state is the list of actions
+ taken in the environment since the last reset, and the environment is restored via the
+ load_state method by replaying them.
+
+ Please note that this is not the most efficient way to implement the state representation.
+ It is meant to check whether your use case works well with the MCTS algorithm.
+ If it does, you can consider implementing all GymctsABC methods in a more efficient way.
+ """
+
+ # helper attributes for the wrapper
  _terminal_flag: bool = False
  _last_reward: SupportsFloat = 0
  _step_tuple: tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]] = None
@@ -25,6 +39,17 @@ class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
  action_mask_fn: str | Callable[[gym.Env], np.ndarray] | None = None,
  buffer_length: int = 100,
  ):
+ """
+ A wrapper for gym environments that implements the GymctsABC interface.
+ It uses the action history as the state representation.
+ Please note that this is not the most efficient way to implement the state representation.
+ It is meant to check whether your use case works well with the MCTS algorithm.
+ If it does, you can consider implementing all GymctsABC methods in a more efficient way.
+
+ :param env: the environment to wrap
+ :param action_mask_fn: a function that takes the environment as input and returns a mask of valid actions
+ :param buffer_length: the length of the buffer for recording episodes for determining their rollout returns
+ """
  # wrap with RecordEpisodeStatistics if it is not already wrapped
  env = RecordEpisodeStatistics(env, buffer_length=buffer_length)
 
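A hedged usage sketch of the constructor parameters documented above. The all-valid mask function is purely illustrative; a real action_mask_fn would inspect the environment state:

```python
import gymnasium as gym
import numpy as np
from gymcts.gymcts_action_history_wrapper import ActionHistoryMCTSGymEnvWrapper

def all_actions_valid(e: gym.Env) -> np.ndarray:
    # Placeholder mask that marks every action as valid.
    return np.ones(e.action_space.n, dtype=bool)

env = gym.make("FrozenLake-v1", is_slippery=False)
env = ActionHistoryMCTSGymEnvWrapper(env, action_mask_fn=all_actions_valid, buffer_length=100)
```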
@@ -48,6 +73,17 @@ class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
  self._action_mask_fn = action_mask_fn
 
  def load_state(self, state: list[int]) -> None:
+ """
+ Loads the state of the environment. The state is a list of actions taken in the environment.
+
+ The environment is reset and all actions in the state are performed in order to restore the environment to the
+ same state.
+
+ This works only for deterministic environments!
+
+ :param state: the state to load
+ :return: None
+ """
  self.env.reset()
  self._wrapper_action_history = []
 
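The replay-based restore described in the docstring boils down to a reset followed by re-stepping. A minimal sketch of the pattern (not the package's exact code; deterministic environments only):

```python
import gymnasium as gym

def load_state_by_replay(env: gym.Env, state: list[int]) -> None:
    # Replaying the same actions from a fresh reset reproduces the same
    # environment state, provided the dynamics are deterministic.
    env.reset()
    for action in state:
        env.step(action)
```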
@@ -56,15 +92,30 @@ class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
  self._wrapper_action_history.append(action)
 
  def is_terminal(self) -> bool:
+ """
+ Returns True if the environment is in a terminal state, False otherwise.
+
+ :return: True if the environment is in a terminal state, False otherwise
+ """
  if not len(self.get_valid_actions()):
  return True
  else:
  return self._terminal_flag
 
  def action_masks(self) -> np.ndarray | None:
+ """
+ Returns the action masks for the environment. If the action_mask_fn is not set, it returns None.
+
+ :return: the action masks, or None if no action_mask_fn is set
+ """
  return self._action_mask_fn(self.env) if self._action_mask_fn is not None else None
 
  def get_valid_actions(self) -> list[int]:
+ """
+ Returns a list of valid actions for the current state of the environment.
+
+ :return: a list of valid actions
+ """
  if self._action_mask_fn is None:
  action_space: gym.spaces.Discrete = self.env.action_space # Type hinting
  return list(range(action_space.n))
@@ -72,6 +123,12 @@ class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
  return [i for i, mask in enumerate(self.action_masks()) if mask]
 
  def rollout(self) -> float:
+ """
+ Performs a random rollout from the current state of the environment and returns the return (sum of rewards)
+ of the rollout.
+
+ :return: the return of the rollout
+ """
  log.debug("performing rollout")
  # random rollout
  # perform random valid actions until terminal
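The random rollout amounts to sampling valid actions until the episode ends and summing the rewards. A minimal sketch of that loop, assuming only the interface documented in this diff (is_terminal, get_valid_actions, step):

```python
import random

def random_rollout(env) -> float:
    # Take uniformly random valid actions until a terminal state is reached,
    # accumulating rewards into the episode return.
    episode_return = 0.0
    while not env.is_terminal():
        action = random.choice(env.get_valid_actions())
        _, reward, terminated, truncated, _ = env.step(action)
        episode_return += float(reward)
        if terminated or truncated:
            break
    return episode_return
```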
@@ -92,11 +149,24 @@ class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
  return episode_return
 
  def get_state(self) -> list[int]:
+ """
+ Returns the current state of the environment. The state is a list of actions taken in the environment,
+ namely all actions that have been taken in the environment so far (since the last reset).
+
+ :return: a list of actions taken in the environment
+ """
+
  return self._wrapper_action_history.copy()
 
  def step(
  self, action: WrapperActType
  ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
+ """
+ Performs a step in the environment. It adds the action to the action history and updates the terminal flag.
+
+ :param action: action to perform in the environment
+ :return: the step tuple of the environment (obs, reward, terminated, truncated, info)
+ """
  step_tuple = self.env.step(action)
  self._wrapper_action_history.append(action)
  obs, reward, terminated, truncated, info = step_tuple
gymcts/gymcts_agent.py CHANGED
@@ -1,4 +1,5 @@
  import copy
+ import random
  import gymnasium as gym
 
  from typing import TypeVar, Any, SupportsFloat, Callable
@@ -63,7 +64,10 @@ class GymctsAgent:
  # NAVIGATION STRATEGY
  # select child with highest UCB score
  while not temp_node.is_leaf():
- temp_node = max(temp_node.children.values(), key=lambda child: child.ucb_score())
+ children = list(temp_node.children.values())
+ max_ucb_score = max(child.ucb_score() for child in children)
+ best_children = [child for child in children if child.ucb_score() == max_ucb_score]
+ temp_node = random.choice(best_children)
  log.debug(f"Selected leaf node: {temp_node}")
  return temp_node
 
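Selecting uniformly among all children that attain the maximum UCB score avoids a systematic bias toward whichever child happens to come first in iteration order. The same pattern in isolation (a generic sketch, not part of the package):

```python
import random
from typing import Callable, Sequence, TypeVar

T = TypeVar("T")

def argmax_random_tiebreak(items: Sequence[T], key: Callable[[T], float]) -> T:
    # Return one of the items attaining the maximum key, chosen uniformly at random.
    best = max(key(item) for item in items)
    return random.choice([item for item in items if key(item) == best])
```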
gymcts/gymcts_deepcopy_wrapper.py CHANGED
@@ -13,8 +13,15 @@ from gymcts.logger import log
 
 
  class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
-
-
+ """
+ A wrapper for gym environments that implements the GymctsABC interface.
+ It uses deepcopies as the state representation.
+ Please note that this is not the most efficient way to implement the state representation.
+ It is meant to check whether your use case works well with the MCTS algorithm.
+ If it does, you can consider implementing all GymctsABC methods in a more efficient way.
+ """
+
+ # helper attributes for the wrapper
  _terminal_flag: bool = False
  _last_reward: SupportsFloat = 0
  _step_tuple: tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]] = None
@@ -22,9 +29,21 @@ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
  _action_mask_fn: Callable[[gym.Env], np.ndarray] | None = None
 
  def is_terminal(self) -> bool:
+ """
+ Returns True if the environment is in a terminal state, False otherwise.
+
+ :return: True if the environment is in a terminal state, False otherwise.
+ """
  return self._terminal_flag
 
  def load_state(self, state: Any) -> None:
+ """
+ The load_state method is not implemented. The state is loaded by replacing the env with the 'state' (the copy
+ provided by 'get_state'). 'self' in a method cannot be replaced with another object (as far as I know).
+
+ :param state: a deepcopy of the environment
+ :return: None
+ """
  msg = """
  The NaiveSoloMCTSGymEnvWrapper uses deepcopies of the entire env as the state.
  The loading of the state is done by replacing the env with the 'state' (the copy provided by 'get_state').
@@ -39,6 +58,16 @@ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
  buffer_length: int = 100,
  record_video: bool = False,
  ):
+ """
+ The constructor of the wrapper. It wraps the environment with RecordEpisodeStatistics and checks if the action
+ space is discrete. It also checks if the action_mask_fn is a string or a callable. If it is a string, it tries to
+ find the method in the environment. If it is a callable, it assigns it to the _action_mask_fn attribute.
+
+ :param env: the environment to wrap
+ :param action_mask_fn: a function (or the name of an env method) that returns a mask of valid actions
+ :param buffer_length: the length of the buffer for recording episode statistics
+ :param record_video: whether to record a video of the episodes
+ """
  # wrap with RecordEpisodeStatistics if it is not already wrapped
  env = RecordEpisodeStatistics(env, buffer_length=buffer_length)
 
@@ -61,6 +90,10 @@ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
  self._action_mask_fn = action_mask_fn
 
  def get_state(self) -> Any:
+ """
+ Returns the current state of the environment as a deepcopy of the environment.
+ :return: a deepcopy of the environment
+ """
  log.debug("getting state")
  original_state = self
  copied_state = copy.deepcopy(self)
@@ -71,9 +104,19 @@ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
  return copied_state
 
  def action_masks(self) -> np.ndarray | None:
+ """
+ Returns the action masks for the environment. If the action_mask_fn is not set, it returns None.
+ :return: the action masks for the environment
+ """
  return self._action_mask_fn(self.env) if self._action_mask_fn is not None else None
 
  def get_valid_actions(self) -> list[int]:
+ """
+ Returns a list of valid actions for the current state of the environment.
+ This is used to obtain potential actions/subsequent states for the MCTS tree.
+
+ :return: the list of valid actions
+ """
  if self._action_mask_fn is None:
  action_space: gym.spaces.Discrete = self.env.action_space # Type hinting
  return list(range(action_space.n))
@@ -83,6 +126,14 @@ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
  def step(
  self, action: WrapperActType
  ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
+ """
+ Performs a step in the environment.
+ This method is used to update the wrapper with the new state and the new action, to realize the terminal state
+ functionality.
+
+ :param action: action to perform in the environment
+ :return: the step tuple of the environment (obs, reward, terminated, truncated, info)
+ """
  step_tuple = self.env.step(action)
 
  obs, reward, terminated, truncated, info = step_tuple
@@ -93,6 +144,12 @@ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
 
 
  def rollout(self) -> float:
+ """
+ Performs a rollout from the current state of the environment and returns the return (sum of rewards) of the
+ rollout.
+
+ :return: the return of the rollout
+ """
  log.debug("performing rollout")
  # random rollout
  # perform random valid actions until terminal
gymcts/gymcts_distributed_agent.py CHANGED
@@ -118,6 +118,7 @@ class DistributedGymctsAgent:
  render_tree_after_step: bool = False,
  render_tree_max_depth: int = 2,
  num_parallel: int = 4,
+ clear_mcts_tree_after_step: bool = False,
  number_of_simulations_per_step: int = 25,
  exclude_unvisited_nodes_from_render: bool = False
  ):
@@ -134,6 +135,7 @@ class DistributedGymctsAgent:
  self.number_of_simulations_per_step = number_of_simulations_per_step
 
  self.env = env
+ self.clear_mcts_tree_after_step = clear_mcts_tree_after_step
 
  self.search_root_node = GymctsNode(
  action=None,
@@ -206,6 +208,8 @@ class DistributedGymctsAgent:
  ready_node = ray.get(ready_node_ref)
 
  # merge the tree
+ if not self.clear_mcts_tree_after_step:
+ self.backpropagation(search_start_node, ready_node.mean_value, ready_node.visit_count)
  search_start_node = merge_nodes(search_start_node, ready_node)
 
  action = search_start_node.get_best_action()
@@ -217,22 +221,34 @@ class DistributedGymctsAgent:
  tree_max_depth=self.render_tree_max_depth
  )
 
-
- # to clear memory we need to remove all nodes except the current node
- # this is done by setting the root node to the current node
- # and setting the parent of the current node to None
- # we also need to reset the children of the current node
- # this is done by calling the reset method
- #
- # in a distributed setting we need we delete all previous nodes
- # this is because backpropagation merging trees is already computationally expensive
- # and backpropagating the whole tree would be even more expensive
- next_node.reset()
+ if self.clear_mcts_tree_after_step:
+ # to clear memory we need to remove all nodes except the current node
+ # this is done by setting the root node to the current node
+ # and setting the parent of the current node to None
+ # we also need to reset the children of the current node
+ # this is done by calling the reset method
+ next_node.reset()
 
  self.search_root_node = next_node
 
  return action, next_node
 
+ def backpropagation(self, node: GymctsNode, average_episode_return: float, num_episodes: int) -> None:
+ log.debug(f"performing backpropagation from leaf node: {node}")
+ while not node.is_root():
+ node.mean_value = (node.mean_value * node.visit_count + average_episode_return * num_episodes) / (
+ node.visit_count + num_episodes)
+ node.visit_count += num_episodes
+ node.max_value = max(node.max_value, average_episode_return)
+ node.min_value = min(node.min_value, average_episode_return)
+ node = node.parent
+ # also update root node
+ node.mean_value = (node.mean_value * node.visit_count + average_episode_return * num_episodes) / (
+ node.visit_count + num_episodes)
+ node.visit_count += num_episodes
+ node.max_value = max(node.max_value, average_episode_return)
+ node.min_value = min(node.min_value, average_episode_return)
+
  def show_mcts_tree(self, start_node: GymctsNode = None, tree_max_depth: int = None) -> None:
 
  if start_node is None:
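The mean update in backpropagation is a weighted average of the node's existing statistics and the merged episodes. With illustrative numbers:

```python
# Merging k new episodes into a node's running statistics (illustrative values):
mean_value, visit_count = 0.5, 10        # node's current statistics
avg_return, num_episodes = 0.8, 5        # statistics of the merged episodes

new_mean = (mean_value * visit_count + avg_return * num_episodes) / (visit_count + num_episodes)
# (0.5 * 10 + 0.8 * 5) / 15 = 9.0 / 15 = 0.6
```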
@@ -268,7 +284,7 @@ if __name__ == '__main__':
  agent1 = DistributedGymctsAgent(
  env=env,
  render_tree_after_step=True,
- number_of_simulations_per_step=1000,
+ number_of_simulations_per_step=10,
  exclude_unvisited_nodes_from_render=True,
  num_parallel=1,
  )
@@ -278,4 +294,6 @@ if __name__ == '__main__':
  actions = agent1.solve()
  end_time = time.perf_counter()
 
+ agent1.show_mcts_tree_from_root()
+
  print(f"solution time per action: {end_time - start_time}/{len(actions)}")
gymcts/gymcts_env_abc.py CHANGED
@@ -9,20 +9,53 @@ class GymctsABC(ABC, gym.Env):
 
  @abstractmethod
  def get_state(self) -> Any:
+ """
+ Returns the current state of the environment. The state can in principle be any datatype that allows restoring
+ the environment to the same state. The state is used to restore the environment using the load_state method.
+
+ It's recommended to use a numpy array if possible, as it is easy to serialize and deserialize.
+
+ :return: the current state of the environment
+ """
  pass
 
  @abstractmethod
  def load_state(self, state: Any) -> None:
+ """
+ Loads the state of the environment, restoring the environment to the state captured by get_state.
+
+ :param state: the state to load
+ :return: None
+ """
  pass
 
  @abstractmethod
  def is_terminal(self) -> bool:
+ """
+ Returns True if the environment is in a terminal state, False otherwise.
+ :return: True if the environment is in a terminal state, False otherwise
+ """
  pass
 
  @abstractmethod
  def get_valid_actions(self) -> list[int]:
+ """
+ Returns a list of valid actions for the current state of the environment.
+ This is used to obtain potential actions/subsequent states for the MCTS tree.
+ :return: the list of valid actions
+ """
  pass
 
  @abstractmethod
  def rollout(self) -> float:
+ """
+ Performs a rollout from the current state of the environment and returns the return (sum of rewards) of the rollout.
+
+ Please make sure the return value is in the interval [-1, 1].
+ Otherwise, the MCTS algorithm will not work as expected (due to an ill-fitted exploration coefficient;
+ exploration and exploitation will not be well balanced).
+
+ :return: the return of the rollout
+ """
  pass
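To make the contract concrete, here is a hedged sketch of a toy deterministic environment implementing the five abstract methods. It is illustrative only, not part of the package, and skips the usual gym observation/reward plumbing:

```python
import random
import gymnasium as gym
from gymcts.gymcts_env_abc import GymctsABC

class CountdownEnv(GymctsABC):
    """Toy task: count down from 10 to exactly 0 with steps of size 1 or 2."""

    def __init__(self):
        self.remaining = 10
        self.action_space = gym.spaces.Discrete(2)  # action i means a step of i + 1

    def get_state(self) -> int:
        return self.remaining  # a plain int is enough to restore this env

    def load_state(self, state: int) -> None:
        self.remaining = state

    def is_terminal(self) -> bool:
        return self.remaining <= 0

    def get_valid_actions(self) -> list[int]:
        return [0, 1] if self.remaining >= 2 else [0]

    def rollout(self) -> float:
        # Random rollout; returns stay in [-1, 1] as the docstring above recommends.
        while not self.is_terminal():
            self.remaining -= random.choice(self.get_valid_actions()) + 1
        return 1.0 if self.remaining == 0 else -1.0
```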
gymcts/gymcts_node.py CHANGED
@@ -13,18 +13,32 @@ TGymctsNode = TypeVar("TGymctsNode", bound="GymctsNode")
 
 
  class GymctsNode:
  # static properties
- best_action_weight: float = 0.05
- ubc_c = 0.707
+ best_action_weight: float = 0.05 # weight for the best action
+ ubc_c = 0.707 # exploration coefficient
+
+
 
  # attributes
- visit_count: int = 0
- mean_value: float = 0
- max_value: float = -float("inf")
- min_value: float = +float("inf")
- terminal: bool = False
- state: Any
+ #
+ # Note that these attributes are not static. They are defined here to give developers a hint what fields
+ # are available in the class. They are not static because they are not shared between instances of the
+ # class in the scope of this library.
+ visit_count: int = 0 # number of times the node has been visited
+ mean_value: float = 0 # mean value of the node
+ max_value: float = -float("inf") # maximum value of the node
+ min_value: float = +float("inf") # minimum value of the node
+ terminal: bool = False # whether the node is terminal or not
+ state: Any = None # state of the node
 
  def __str__(self, colored=False, action_space_n=None) -> str:
+ """
+ Returns a string representation of the node. The string representation is used for visualisation purposes,
+ for example in the MCTS tree visualisation functionality.
+
+ :param colored: True if the string representation should be colored, False otherwise. (True is used by the MCTS tree visualisation.)
+ :param action_space_n: the number of actions in the action space. This is used for coloring the action in the string representation.
+ :return: a potentially colored string representation of the node.
+ """
  if not colored:
 
  if not self.is_root():
@@ -72,22 +86,44 @@
  (f", {p}ubc{e}={colorful_value(self.ucb_score())})" if not self.is_root() else ")"))
 
  def traverse_nodes(self) -> Generator[TGymctsNode, None, None]:
+ """
+ Traverses the tree and yields all nodes in the tree.
+
+ :return: a generator that yields all nodes in the tree.
+ """
  yield self
  if self.children:
  for child in self.children.values():
  yield from child.traverse_nodes()
 
  def get_root(self) -> TGymctsNode:
+ """
+ Returns the root node of the tree. The root node is the node that has no parent.
+
+ :return: the root node of the tree.
+ """
  if self.is_root():
  return self
  return self.parent.get_root()
 
  def max_tree_depth(self):
+ """
+ Returns the maximum depth of the tree. The depth of a node is the number of edges from
+ the node to the root node.
+
+ :return: the maximum depth of the tree.
+ """
  if self.is_leaf():
  return 0
  return 1 + max(child.max_tree_depth() for child in self.children.values())
 
  def n_children_recursively(self):
+ """
+ Returns the number of descendants of the node: its direct children plus, recursively,
+ the descendants of each child.
+
+ :return: the number of descendants of the node.
+ """
  if self.is_leaf():
  return 0
  return len(self.children) + sum(child.n_children_recursively() for child in self.children.values())
@@ -97,6 +133,14 @@
  parent: TGymctsNode | None,
  env_reference: GymctsABC,
  ):
+ """
+ Initializes the node with the state of the environment, the action that was taken to reach the node,
+ the parent node, and the environment reference.
+
+ :param action: the action that was taken to reach the node. If the node is a root node, this parameter is None.
+ :param parent: the parent node of the node. If the node is a root node, this parameter is None.
+ :param env_reference: a reference to the environment. The environment is used to get the state of the node and the valid actions.
+ """
 
  # field depending on whether the node is a root node or not
  self.action: int | None
@@ -149,21 +193,49 @@
  self.parent.reset()
 
  def is_root(self) -> bool:
+ """
+ Returns True if the node is a root node. A root node is a node that has no parent.
+
+ :return: True if the node is a root node, False otherwise.
+ """
  return self.parent is None
 
  def is_leaf(self) -> bool:
+ """
+ Returns True if the node is a leaf node. A leaf node is a node that has no children.
+
+ :return: True if the node is a leaf node, False otherwise.
+ """
  return self.children is None or len(self.children) == 0
 
  def get_random_child(self) -> TGymctsNode:
+ """
+ Returns a child of the node selected uniformly at random from the list of children.
+
+ :return: a randomly selected child of the node.
+ """
  if self.is_leaf():
  raise ValueError("cannot get random child of leaf node") # todo: maybe return self instead?
 
  return list(self.children.values())[random.randint(0, len(self.children) - 1)]
 
  def get_best_action(self) -> int:
+ """
+ Returns the best action of the node: the action leading to the child with the highest score,
+ as calculated by the get_score() method.
+
+ :return: the best action of the node.
+ """
  return max(self.children.values(), key=lambda child: child.get_score()).action
 
  def get_score(self) -> float: # todo: make it an attribute?
+ """
+ Returns the score of the node. The score is calculated from the mean value and the maximum value of the node
+ using the formula: score = (1 - a) * mean_value + a * max_value
+ where a is the best action weight.
+
+ :return: the score of the node.
+ """
  # return self.mean_value
  assert 0 <= GymctsNode.best_action_weight <= 1
  a = GymctsNode.best_action_weight
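Plugging illustrative numbers into the score formula from the docstring:

```python
# score = (1 - a) * mean_value + a * max_value, with a = best_action_weight
a = 0.05
mean_value, max_value = 0.6, 0.9
score = (1 - a) * mean_value + a * max_value  # 0.95 * 0.6 + 0.05 * 0.9 = 0.615
```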
@@ -173,6 +245,11 @@
  return self.mean_value
 
  def get_max_value(self) -> float:
+ """
+ Returns the maximum value of the node.
+
+ :return: the maximum value of the node.
+ """
  return self.max_value
 
  def ucb_score(self):
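The body of ucb_score is not part of this diff. For reference, the exploration coefficient ubc_c = 0.707 (roughly 1/sqrt(2)) matches the standard UCT formula, sketched below under that assumption; treat it as illustrative, not as the package's exact implementation:

```python
import math

def uct_score(mean_value: float, visit_count: int, parent_visit_count: int,
              c: float = 0.707) -> float:
    # Standard UCT: exploitation term plus an exploration bonus that shrinks
    # as the node is visited more often relative to its parent.
    return mean_value + c * math.sqrt(math.log(parent_visit_count) / visit_count)
```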
gymcts/gymcts_tree_plotter.py CHANGED
@@ -1,3 +1,5 @@
+ from typing import Any, Generator
+
  from gymcts.gymcts_node import GymctsNode
 
  from gymcts.logger import log
@@ -9,7 +11,19 @@ def _generate_mcts_tree(
  depth: int = None,
  exclude_unvisited_nodes_from_render: bool = True,
  action_space_n: int = None
- ) -> list[str]:
+ ) -> Generator[str, Any | None, None]:
+ """
+ Generates a tree representation of the MCTS tree starting from the given node.
+
+ This is a recursive function that yields the rendered tree line by line.
+
+ :param start_node: the node to start from
+ :param prefix: used to format the tree
+ :param depth: used to limit the depth of the tree
+ :param exclude_unvisited_nodes_from_render: used to exclude unvisited nodes from the render
+ :param action_space_n: the number of actions in the action space
+ :return: a generator of strings representing the tree
+ """
  if prefix is None:
  prefix = ""
  import gymcts.colorful_console_utils as ccu
@@ -70,6 +84,13 @@ def show_mcts_tree(
  tree_max_depth: int = None,
  action_space_n: int = None
  ) -> None:
+ """
+ Renders the MCTS tree starting from the given node.
+
+ :param start_node: the node to start from
+ :param tree_max_depth: the maximum depth of the tree to render
+ :param action_space_n: the number of actions in the action space
+ """
  print(start_node.__str__(colored=True, action_space_n=action_space_n))
  for line in _generate_mcts_tree(start_node=start_node, depth=tree_max_depth):
  print(line)
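A hedged usage sketch of the renderer, assuming an agent whose search_root_node is a GymctsNode (as seen elsewhere in this diff):

```python
from gymcts.gymcts_tree_plotter import show_mcts_tree

# 'agent' is assumed to be a previously constructed GymctsAgent.
# Print the first two levels of the search tree below the root node.
show_mcts_tree(agent.search_root_node, tree_max_depth=2)
```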
gymcts/logger.py CHANGED
@@ -18,10 +18,7 @@ banner_sw = f"""
  ▟█▛ ▜██▛ ▟█▛██▛██▛▟█▛ ▟█▛ ▜███▙
  ▟█▛ ▟█▛ ▟█▛ ▟█▛ ▟█▛▟█▛ ▟█▛ ▟█▛
  ▜████▛ ▟█▛ ▟█▛ ▟█▛ ▜████▛ ▟█▛ ▟████▛
-
-
-
-
+
  """
 
 
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: gymcts
- Version: 1.2.0
+ Version: 1.2.1
  Summary: A minimalistic implementation of the Monte Carlo Tree Search algorithm for planning problems formulated as gymnasium reinforcement learning environments.
  Author: Alexander Nasuta
  Author-email: Alexander Nasuta <alexander.nasuta@wzl-iqs.rwth-aachen.de>
@@ -47,7 +47,7 @@ Requires-Dist: graph-matrix-jsp-env; extra == "examples"
  Requires-Dist: graph-jsp-env; extra == "examples"
  Provides-Extra: dev
  Requires-Dist: jsp-instance-utils; extra == "dev"
- Requires-Dist: graph-matrix-jsp-env; extra == "dev"
+ Requires-Dist: graph-matrix-jsp-env>=0.3.0; extra == "dev"
  Requires-Dist: graph-jsp-env; extra == "dev"
  Requires-Dist: JSSEnv; extra == "dev"
  Requires-Dist: pip-tools; extra == "dev"
@@ -59,21 +59,24 @@ Requires-Dist: stable_baselines3; extra == "dev"
  Requires-Dist: sphinx; extra == "dev"
  Requires-Dist: myst-parser; extra == "dev"
  Requires-Dist: sphinx-autobuild; extra == "dev"
+ Requires-Dist: sphinx-copybutton; extra == "dev"
  Requires-Dist: furo; extra == "dev"
  Requires-Dist: twine; extra == "dev"
  Requires-Dist: sphinx-copybutton; extra == "dev"
  Requires-Dist: nbsphinx; extra == "dev"
+ Requires-Dist: pandoc; extra == "dev"
  Requires-Dist: jupytext; extra == "dev"
  Requires-Dist: jupyter; extra == "dev"
+ Requires-Dist: typing_extensions>=4.12.0; extra == "dev"
  Dynamic: license-file
 
  # GYMCTS
 
  A Monte Carlo Tree Search Implementation for Gymnasium-style Environments.
 
- - Github: [GYMCTS on Github](https://github.com/Alexander-Nasuta/GraphMatrixJobShopEnv)
- - Pypi: [GYMCTS on PyPi](https://pypi.org/project/graph-matrix-jsp-env/)
- - Documentation: [GYMCTS Docs](https://graphmatrixjobshopenv.readthedocs.io/en/latest/)
+ - Github: [GYMCTS on Github](https://github.com/Alexander-Nasuta/gymcts)
+ - Pypi: [GYMCTS on PyPi](https://pypi.org/project/gymcts/)
+ - Documentation: [GYMCTS Docs](https://gymcts.readthedocs.io/en/latest/)
 
  ## Description
 
@@ -101,22 +104,26 @@ The usage of a MCTS agent can roughly be organised into the following steps:
  - Render the solution
 
  The GYMCTS package provides two types of wrappers for Gymnasium-style environments:
- - `NaiveSoloMCTSGymEnvWrapper`: A wrapper that uses deepcopies of the environment to save a snapshot of the environment state for each node in the MCTS tree.
- - `DeterministicSoloMCTSGymEnvWrapper`: A wrapper that saves the action sequence that lead to the current state in the MCTS node.
+ - `DeepCopyMCTSGymEnvWrapper`: A wrapper that uses deepcopies of the environment to save a snapshot of the environment state for each node in the MCTS tree.
+ - `ActionHistoryMCTSGymEnvWrapper`: A wrapper that saves the action sequence that led to the current state in the MCTS node.
 
- These wrappers can be used with the `SoloMCTSAgent` to solve the environment.
- The wrapper implement methods that are required by the `SoloMCTSAgent` to interact with the environment.
+ These wrappers can be used with the `GymctsAgent` to solve the environment.
+ The wrappers implement the methods that are required by the `GymctsAgent` to interact with the environment.
  GYMCTS is designed to use a single environment instance and to reconstruct the environment state from a state snapshot when needed.
 
  NOTE: MCTS works best when the return of an episode is in the range of [-1, 1]. Please adjust the reward function of the environment accordingly (or change the ubc-scaling parameter of the MCTS agent).
  Adjusting the reward function of the environment is easily done with a [NormalizeReward](https://gymnasium.farama.org/api/wrappers/reward_wrappers/#gymnasium.wrappers.NormalizeReward) or [TransformReward](https://gymnasium.farama.org/api/wrappers/reward_wrappers/#gymnasium.wrappers.TransformReward) Wrapper.
+ ```python
+ env = NormalizeReward(env, gamma=0.99, epsilon=1e-8)
+ ```
 
- NormalizeReward(env, gamma=0.99, epsilon=1e-8)
- env = TransformReward(env, lambda r: r / 36)
- ### FrozenLake Example (NaiveSoloMCTSGymEnvWrapper)
+ ```python
+ env = TransformReward(env, lambda r: r / n_steps_per_episode)
+ ```
+ ### FrozenLake Example (DeepCopyMCTSGymEnvWrapper)
 
  A minimal example of how to use the package with the FrozenLake environment and the DeepCopyMCTSGymEnvWrapper is provided in the code snippet below.
- The NaiveSoloMCTSGymEnvWrapper can be used with non-deterministic environments, such as the FrozenLake environment with slippery ice.
+ The DeepCopyMCTSGymEnvWrapper can be used with non-deterministic environments, such as the FrozenLake environment with slippery ice.
 
  ```python
  import gymnasium as gym
@@ -135,7 +142,7 @@ if __name__ == '__main__':
  env = gym.make('FrozenLake-v1', desc=None, map_name="4x4", is_slippery=True, render_mode="ansi")
  env.reset()
 
- # 1. wrap the environment with the naive wrapper or a custom gymcts wrapper
+ # 1. wrap the environment with the deep copy wrapper or a custom gymcts wrapper
  env = DeepCopyMCTSGymEnvWrapper(env)
 
  # 2. create the agent
@@ -158,7 +165,7 @@
 
  # 5. print the solution
  # read the solution from the info provided by the RecordEpisodeStatistics wrapper
- # (that NaiveSoloMCTSGymEnvWrapper uses internally)
+ # (that DeepCopyMCTSGymEnvWrapper uses internally)
  episode_length = info["episode"]["l"]
  episode_return = info["episode"]["r"]
 
@@ -251,7 +258,7 @@
  env = gym.make('FrozenLake-v1', desc=None, map_name="4x4", is_slippery=False, render_mode="rgb_array")
  env.reset()
 
- # 1. wrap the environment with the naive wrapper or a custom gymcts wrapper
+ # 1. wrap the environment with the deep copy wrapper or a custom gymcts wrapper
  env = DeepCopyMCTSGymEnvWrapper(env)
 
  # 2. create the agent
@@ -280,7 +287,7 @@
  env.close()
 
  # 5. print the solution
- # read the solution from the info provided by the RecordEpisodeStatistics wrapper (that NaiveSoloMCTSGymEnvWrapper wraps internally)
+ # read the solution from the info provided by the RecordEpisodeStatistics wrapper (that DeepCopyMCTSGymEnvWrapper wraps internally)
  episode_length = info["episode"]["l"]
  episode_return = info["episode"]["r"]
 
@@ -321,13 +328,13 @@ import gymnasium as gym
  from graph_jsp_env.disjunctive_graph_jsp_env import DisjunctiveGraphJspEnv
  from jsp_instance_utils.instances import ft06, ft06_makespan
 
- from gymcts.gymcts_agent import SoloMCTSAgent
- from gymcts.gymcts_gym_env import SoloMCTSGymEnv
+ from gymcts.gymcts_agent import GymctsAgent
+ from gymcts.gymcts_env_abc import GymctsABC
 
  from gymcts.logger import log
 
 
- class GraphJspGYMCTSWrapper(SoloMCTSGymEnv, gym.Wrapper):
+ class GraphJspGYMCTSWrapper(GymctsABC, gym.Wrapper):
 
  def __init__(self, env: DisjunctiveGraphJspEnv):
  gym.Wrapper.__init__(self, env)
@@ -378,7 +385,7 @@ if __name__ == '__main__':
 
  env = GraphJspGYMCTSWrapper(env)
 
- agent = SoloMCTSAgent(
+ agent = GymctsAgent(
  env=env,
  clear_mcts_tree_after_step=True,
  render_tree_after_step=True,
@@ -421,7 +428,6 @@ import gymnasium as gym
 
  from gymcts.gymcts_agent import GymctsAgent
  from gymcts.gymcts_action_history_wrapper import ActionHistoryMCTSGymEnvWrapper
- from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
 
  from gymcts.logger import log
 
@@ -434,7 +440,7 @@
  env = gym.make('FrozenLake-v1', desc=None, map_name="4x4", is_slippery=False, render_mode="ansi")
  env.reset()
 
- # wrap the environment with the naive wrapper or a custom gymcts wrapper
+ # wrap the environment with the action history wrapper or a custom gymcts wrapper
  env = ActionHistoryMCTSGymEnvWrapper(env)
 
  # create the agent
@@ -505,11 +511,11 @@ clone the repository in your favorite code editor (for example PyCharm, VSCode,
 
  using https:
  ```shell
- git clone https://github.com/Alexander-Nasuta/todo
+ git clone https://github.com/Alexander-Nasuta/gymcts.git
  ```
  or by using the GitHub CLI:
  ```shell
- gh repo clone Alexander-Nasuta/todo
+ gh repo clone Alexander-Nasuta/gymcts
  ```
 
  if you are using PyCharm, I recommend doing the following additional steps:
@@ -518,9 +524,6 @@ if you are using PyCharm, I recommend doing the following additional steps:
  - mark the `tests` folder as test root (by right-clicking on the folder and selecting `Mark Directory as` -> `Test Sources Root`)
  - mark the `resources` folder as resources root (by right-clicking on the folder and selecting `Mark Directory as` -> `Resources Root`)
 
- at the end your project structure should look like this:
-
- todo
 
  ### Create a Virtual Environment (optional)
 
@@ -586,12 +589,6 @@ For testing with `tox` run the following command:
  tox
  ```
 
- Here is a screenshot of what the output might look like:
-
- ![](https://github.com/Alexander-Nasuta/GraphMatrixJobShopEnv/raw/master/resources/tox-screenshot.png)
-
- Tox will run the tests in a separate environment and will also check if the requirements are installed correctly.
-
 
  ### Building and Publishing the Project to PyPi
  In order to publish the project to PyPi, the project needs to be built and then uploaded to PyPi.
@@ -630,7 +627,6 @@ sphinx-autobuild ./docs/source/ ./docs/build/html/
  This project features most of the extensions featured in this Tutorial: [Document Your Scientific Project With Markdown, Sphinx, and Read the Docs | PyData Global 2021](https://www.youtube.com/watch?v=qRSb299awB0).
 
 
-
  ## Contact
 
  If you have any questions or feedback, feel free to contact me via [email](mailto:alexander.nasuta@wzl-iqs.rwth-aachen.de) or open an issue on the repository.
@@ -0,0 +1,15 @@
+ gymcts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ gymcts/colorful_console_utils.py,sha256=n7nymC8kKZnA_8nXcdn201NAzjZjgEHfKpbBcnl4oAE,5891
+ gymcts/gymcts_action_history_wrapper.py,sha256=7-p17Fgb80SRCBaCm6G8SJrEPsl2Y4aIO3InviuQP08,6993
+ gymcts/gymcts_agent.py,sha256=f2imP-Wv-E7EYE0-iWd86hY9cx-rqHZMlDusp-aE-ps,8698
+ gymcts/gymcts_deepcopy_wrapper.py,sha256=lCCT5-6JVCwUCP__4uPMMkT5HnO2JWm2ebzJ69zXp9c,6792
+ gymcts/gymcts_distributed_agent.py,sha256=Ha9UBQvFjoErfMWvPyN0JcTYz-JaiJ4eWjLMikp9Yhs,11569
+ gymcts/gymcts_env_abc.py,sha256=U1mPz0NWZZL1sdHX7oUP1UFKtmbHwyqHQOQidyh_Uck,2107
+ gymcts/gymcts_node.py,sha256=pxjY2Zb0kPuFQ5mWEs0ct3qXoyB47NZK7h2ZGbLJbRA,11052
+ gymcts/gymcts_tree_plotter.py,sha256=PR6C7q9Q4kuz1aLGyD7-aZsxk3RqlHZpOqmOiRpCyK0,3547
+ gymcts/logger.py,sha256=RI7B9cvbBGrj0_QIAI77wihzuu2tPG_-z9GM2Mw5aHE,926
+ gymcts-1.2.1.dist-info/licenses/LICENSE,sha256=UGe75WojDiw_77SEnK2aysEDlElRlkWie7U7NaAFx00,1072
+ gymcts-1.2.1.dist-info/METADATA,sha256=wUJEcWrAvdC42kl59qewCN5tK3DKMLxGWcCipnOX4pQ,23371
+ gymcts-1.2.1.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
+ gymcts-1.2.1.dist-info/top_level.txt,sha256=E8MoLsPimUPD0H1Y6lum4TVe-lhSDAyBAXGrkYIT52w,7
+ gymcts-1.2.1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (78.0.2)
+ Generator: setuptools (79.0.1)
  Root-Is-Purelib: true
  Tag: py3-none-any
 
@@ -1,15 +0,0 @@
- gymcts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- gymcts/colorful_console_utils.py,sha256=OhULcXHKbEA4uJDAEYCTcW6wUv0LsHX_XSYzZ_Szsv4,4553
- gymcts/gymcts_action_history_wrapper.py,sha256=AjvBBwd1t9-nTYP09aMdlScAkFNXf5vOagejpjWYOPo,3810
- gymcts/gymcts_agent.py,sha256=O2y98jKFjR5TzqVV7DO1jlcYDyzAgd_H2RF4-w4NP0g,8499
- gymcts/gymcts_deepcopy_wrapper.py,sha256=OleQTnvxv3gLEo8-2asyeo-CpZ4HEbgyFGS5DTCD7NM,4167
- gymcts/gymcts_distributed_agent.py,sha256=M7dyBfC8u3M99PJFoXKgIc_CPTyHGppmktkH-y9ci4U,10448
- gymcts/gymcts_env_abc.py,sha256=7nCRiiClmmVLX-d_Q1dxeztmuvmAtmWZwjT81zrG1_w,575
- gymcts/gymcts_node.py,sha256=PT_YZFwt1zjuvd8i9Wb5LEkHAqmJOFyPDp3GFD05lqM,7138
- gymcts/gymcts_tree_plotter.py,sha256=eg207wHcDepwWODXzmDYQn1Aai29Cs4jFS1HNvAhlXs,2651
- gymcts/logger.py,sha256=nAkUa4djiuCR7hF0EUsplhqFHCp76QcOX1cV3lIPzOI,937
- gymcts-1.2.0.dist-info/licenses/LICENSE,sha256=UGe75WojDiw_77SEnK2aysEDlElRlkWie7U7NaAFx00,1072
- gymcts-1.2.0.dist-info/METADATA,sha256=zhEIFo0rOnv5hCv6ukImkq-9nshO4EfXMbHlhNlYhyA,23640
- gymcts-1.2.0.dist-info/WHEEL,sha256=DK49LOLCYiurdXXOXwGJm6U4DkHkg4lcxjhqwRa0CP4,91
- gymcts-1.2.0.dist-info/top_level.txt,sha256=E8MoLsPimUPD0H1Y6lum4TVe-lhSDAyBAXGrkYIT52w,7
- gymcts-1.2.0.dist-info/RECORD,,