gymcts 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -106,6 +106,18 @@ def wrap_with_color_codes(s: object, /, r: int | float, g: int | float, b: int |
 
 
  def wrap_evenly_spaced_color(s: Any, n_of_item: int, n_classes: int, c_map="rainbow") -> str:
+ """
+ Wraps a string with a color scale (a matplotlib c_map) based on n_of_item and n_classes.
+ This function is used to color-code the available actions in the MCTS tree visualisation.
+ The children of the MCTS tree are colored based on their action for a clearer visualisation.
+
+ :param s: the string (or object) to be wrapped. Objects are converted to strings (using the __str__ method).
+ :param n_of_item: the index of the item to be colored. In an MCTS tree, this is the (parent-)action of the node.
+ :param n_classes: the number of classes (or items) to be colored. In an MCTS tree, this is the number of available actions.
+ :param c_map: the colormap to be used (default is 'rainbow').
+ The colormap can be any matplotlib colormap, e.g. 'viridis', 'plasma', 'inferno', 'magma', 'cividis'.
+ :return: a string that contains the color codes (prefix and suffix) with the string s in between.
+ """
  if s is None or n_of_item is None or n_classes is None:
  return s
 
@@ -119,6 +131,16 @@ def wrap_evenly_spaced_color(s: Any, n_of_item: int, n_classes: int, c_map="rain
 
 
  def wrap_with_color_scale(s: str, value: float, min_val: float, max_val: float, c_map=None) -> str:
+ """
+ Wraps a string with a color scale (a matplotlib c_map) based on the value, min_val, and max_val.
+
+ :param s: the string to be wrapped
+ :param value: the value to be mapped to a color
+ :param min_val: the minimum value of the scale
+ :param max_val: the maximum value of the scale
+ :param c_map: the colormap to be used (default is 'rainbow')
+ :return: a string that contains the color codes (prefix and suffix) with the string s in between
+ """
  if s is None or min_val is None or max_val is None or min_val >= max_val:
  return s
 
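
Editor's note: the two helpers above are what the console tree visualisation uses to color nodes by action index and by value. A minimal sketch of how they could be called follows; the import path is an assumption, since the file name of these helpers is not shown in this diff excerpt.

# Minimal sketch (not from the package): coloring action labels the way the
# MCTS tree visualisation does. The import path below is hypothetical.
from gymcts.colorful_console_utils import wrap_evenly_spaced_color, wrap_with_color_scale  # hypothetical path

n_actions = 4
for action in range(n_actions):
    # each action index gets an evenly spaced color from the 'rainbow' colormap
    print(wrap_evenly_spaced_color(f"action {action}", n_of_item=action, n_classes=n_actions))

# map a node value onto a color scale between min_val and max_val
print(wrap_with_color_scale("mean value: 0.75", value=0.75, min_val=0.0, max_val=1.0))
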
@@ -1,8 +1,7 @@
  import random
- import copy
 
  import numpy as np
- from typing import TypeVar, Any, SupportsFloat, Callable
+ from typing import Any, SupportsFloat, Callable
  import gymnasium as gym
  from gymnasium.core import WrapperActType, WrapperObsType
  from gymnasium.wrappers import RecordEpisodeStatistics
@@ -13,6 +12,21 @@ from gymcts.logger import log
 
 
  class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
+ """
+ A wrapper for gym environments that implements the GymctsABC interface.
+ It uses the action history as the state representation: the state is the list of actions taken in the
+ environment since the last reset, and it is used to restore the environment via the load_state method.
+
+ Please note that this is not the most efficient way to implement the state representation.
+ It is supposed to be used to see if your use case works well with the MCTS algorithm.
+ If it does, you can consider implementing all GymctsABC methods in a more efficient way.
+ """
+
+ # helper attributes for the wrapper
  _terminal_flag: bool = False
  _last_reward: SupportsFloat = 0
  _step_tuple: tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]] = None
@@ -25,6 +39,17 @@ class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
  action_mask_fn: str | Callable[[gym.Env], np.ndarray] | None = None,
  buffer_length: int = 100,
  ):
+ """
+ A wrapper for gym environments that implements the GymctsABC interface.
+ It uses the action history as the state representation.
+ Please note that this is not the most efficient way to implement the state representation.
+ It is supposed to be used to see if your use case works well with the MCTS algorithm.
+ If it does, you can consider implementing all GymctsABC methods in a more efficient way.
+
+ :param env: the environment to wrap
+ :param action_mask_fn: a function that takes the environment as input and returns a mask of valid actions
+ :param buffer_length: the length of the RecordEpisodeStatistics buffer used to record episode (rollout) returns
+ """
  # wrap with RecordEpisodeStatistics if it is not already wrapped
  env = RecordEpisodeStatistics(env, buffer_length=buffer_length)
 
@@ -48,6 +73,17 @@ class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
  self._action_mask_fn = action_mask_fn
 
  def load_state(self, state: list[int]) -> None:
+ """
+ Loads the state of the environment. The state is a list of actions taken in the environment.
+
+ The environment is reset and all actions in the state are performed in order to restore the environment to the
+ same state.
+
+ This works only for deterministic environments!
+
+ :param state: the state to load
+ :return: None
+ """
  self.env.reset()
  self._wrapper_action_history = []
 
@@ -56,15 +92,30 @@ class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
  self._wrapper_action_history.append(action)
 
  def is_terminal(self) -> bool:
+ """
+ Returns True if the environment is in a terminal state, False otherwise.
+
+ :return: True if the environment is in a terminal state, False otherwise
+ """
  if not len(self.get_valid_actions()):
  return True
  else:
  return self._terminal_flag
 
  def action_masks(self) -> np.ndarray | None:
+ """
+ Returns the action masks for the environment. If the action_mask_fn is not set, it returns None.
+
+ :return: the action masks for the environment, or None
+ """
  return self._action_mask_fn(self.env) if self._action_mask_fn is not None else None
 
  def get_valid_actions(self) -> list[int]:
+ """
+ Returns a list of valid actions for the current state of the environment.
+
+ :return: a list of valid actions
+ """
  if self._action_mask_fn is None:
  action_space: gym.spaces.Discrete = self.env.action_space # Type hinting
  return list(range(action_space.n))
@@ -72,6 +123,12 @@ class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
  return [i for i, mask in enumerate(self.action_masks()) if mask]
 
  def rollout(self) -> float:
+ """
+ Performs a random rollout from the current state of the environment and returns the return (sum of rewards)
+ of the rollout.
+
+ :return: the return of the rollout
+ """
  log.debug("performing rollout")
  # random rollout
  # perform random valid actions until terminal
@@ -92,11 +149,24 @@ class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
  return episode_return
 
  def get_state(self) -> list[int]:
+ """
+ Returns the current state of the environment. The state is a list of actions taken in the environment,
+ namely all actions that have been taken in the environment so far (since the last reset).
+
+ :return: a list of actions taken in the environment
+ """
+
  return self._wrapper_action_history.copy()
 
  def step(
  self, action: WrapperActType
  ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
+ """
+ Performs a step in the environment. It adds the action to the action history and updates the terminal flag.
+
+ :param action: action to perform in the environment
+ :return: the step tuple of the environment (obs, reward, terminated, truncated, info)
+ """
  step_tuple = self.env.step(action)
  self._wrapper_action_history.append(action)
  obs, reward, terminated, truncated, info = step_tuple
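
Editor's note: since the wrapper's state is just the action history, saving and restoring a state amounts to an action replay. A minimal sketch follows; the import path and the mask function are assumptions for illustration, and FrozenLake is chosen because it can be made deterministic, which load_state (reset plus action replay) requires.

# Minimal sketch (not from the package): constructing the action-history wrapper
# and using the get_state/load_state round trip.
import gymnasium as gym
import numpy as np
from gymcts.gymcts_action_history_wrapper import ActionHistoryMCTSGymEnvWrapper  # hypothetical path

def all_actions_valid(env: gym.Env) -> np.ndarray:
    # hypothetical mask function: every action is allowed in every state
    return np.ones(env.action_space.n, dtype=bool)

env = gym.make("FrozenLake-v1", is_slippery=False)  # deterministic variant
env = ActionHistoryMCTSGymEnvWrapper(env, action_mask_fn=all_actions_valid, buffer_length=100)
env.reset()

for a in [2, 2, 1]:               # take a few arbitrary actions
    env.step(a)

snapshot = env.get_state()        # a copy of the action history: [2, 2, 1]
env.step(0)                       # keep playing past the snapshot ...
env.load_state(snapshot)          # reset the env and replay the recorded actions
assert env.get_state() == snapshot
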
gymcts/gymcts_agent.py CHANGED
@@ -1,7 +1,8 @@
  import copy
+ import random
  import gymnasium as gym
 
- from typing import TypeVar, Any, SupportsFloat, Callable
+ from typing import TypeVar, Any, SupportsFloat, Callable, Literal
 
  from gymcts.gymcts_env_abc import GymctsABC
  from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
@@ -10,7 +11,9 @@ from gymcts.gymcts_tree_plotter import _generate_mcts_tree
 
  from gymcts.logger import log
 
- TSoloMCTSNode = TypeVar("TSoloMCTSNode", bound="SoloMCTSNode")
+
+
+
 
 
  class GymctsAgent:
@@ -23,17 +26,50 @@ class GymctsAgent:
  search_root_node: GymctsNode # NOTE: this is not the same as the root of the tree!
  clear_mcts_tree_after_step: bool
 
+
+ # (num_simulations: int, step_idx: int) -> int
+ @staticmethod
+ def calc_number_of_simulations_per_step(num_simulations: int, step_idx: int) -> int:
+ """
+ A function that returns a constant number of simulations per step.
+
+ :param num_simulations: The number of simulations to return.
+ :param step_idx: The current step index (not used in this function).
+ :return: the constant number of simulations (num_simulations).
+ """
+ return num_simulations
+
  def __init__(self,
  env: GymctsABC,
  clear_mcts_tree_after_step: bool = True,
  render_tree_after_step: bool = False,
  render_tree_max_depth: int = 2,
  number_of_simulations_per_step: int = 25,
- exclude_unvisited_nodes_from_render: bool = False
+ exclude_unvisited_nodes_from_render: bool = False,
+ calc_number_of_simulations_per_step: Callable[[int,int], int] = None,
+ score_variate: Literal["UCT_v0", "UCT_v1", "UCT_v2",] = "UCT_v0",
+ best_action_weight=None,
  ):
  # check if action space of env is discrete
  if not isinstance(env.action_space, gym.spaces.Discrete):
  raise ValueError("Action space must be discrete.")
+ if calc_number_of_simulations_per_step is not None:
+ # check if the provided function is callable
+ if not callable(calc_number_of_simulations_per_step):
+ raise ValueError("calc_number_of_simulations_per_step must be a callable accepting two arguments: num_simulations and step_idx.")
+ # assign the provided function to the attribute
+ # it needs to be staticmethod to be used as a class attribute
+ print("Using provided calc_number_of_simulations_per_step function.")
+ self.calc_number_of_simulations_per_step = staticmethod(calc_number_of_simulations_per_step)
+ if score_variate not in ["UCT_v0", "UCT_v1", "UCT_v2"]:
+ raise ValueError("score_variate must be one of ['UCT_v0', 'UCT_v1', 'UCT_v2'].")
+ GymctsNode.score_variate = score_variate
+
+ if best_action_weight is not None:
+ if best_action_weight < 0 or best_action_weight > 1:
+ raise ValueError("best_action_weight must be in range [0, 1].")
+ GymctsNode.best_action_weight = best_action_weight
+
 
  self.render_tree_after_step = render_tree_after_step
  self.exclude_unvisited_nodes_from_render = exclude_unvisited_nodes_from_render
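
Editor's note: the new calc_number_of_simulations_per_step parameter lets callers vary the simulation budget over the course of an episode. A minimal sketch of a custom schedule follows; the (num_simulations, step_idx) -> int signature matches the parameter above, while 'wrapped_env' is assumed to be any GymctsABC-compatible environment (for example one of the wrappers in this diff).

# Minimal sketch (not from the package): a decaying per-step simulation budget.
from gymcts.gymcts_agent import GymctsAgent

def decaying_schedule(num_simulations: int, step_idx: int) -> int:
    # spend more simulations on early decisions, but never fewer than 5
    return max(5, num_simulations - 5 * step_idx)

agent = GymctsAgent(
    env=wrapped_env,                                   # assumed GymctsABC-compatible env
    number_of_simulations_per_step=50,
    calc_number_of_simulations_per_step=decaying_schedule,
    score_variate="UCT_v0",
)
actions = agent.solve()
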
@@ -63,7 +99,10 @@ class GymctsAgent:
  # NAVIGATION STRATEGY
  # select child with highest UCB score
  while not temp_node.is_leaf():
- temp_node = max(temp_node.children.values(), key=lambda child: child.ucb_score())
+ children = list(temp_node.children.values())
+ max_ucb_score = max(child.tree_policy_score() for child in children)
+ best_children = [child for child in children if child.tree_policy_score() == max_ucb_score]
+ temp_node = random.choice(best_children)
  log.debug(f"Selected leaf node: {temp_node}")
  return temp_node
 
@@ -84,7 +123,6 @@ class GymctsAgent:
  parent=node,
  env_reference=self.env,
  )
-
  node.children = child_dict
 
  def solve(self, num_simulations_per_step: int = None, render_tree_after_step: bool = None) -> list[int]:
@@ -100,13 +138,20 @@ class GymctsAgent:
 
  action_list = []
 
+ idx = 0
  while not current_node.terminal:
- next_action, current_node = self.perform_mcts_step(num_simulations=num_simulations_per_step,
+ num_sims = self.calc_number_of_simulations_per_step(num_simulations_per_step, idx)
+
+ log.info(f"Performing MCTS step {idx} with {num_sims} simulations.")
+
+ next_action, current_node = self.perform_mcts_step(num_simulations=num_sims,
  render_tree_after_step=render_tree_after_step)
- log.info(f"selected action {next_action} after {num_simulations_per_step} simulations.")
+ log.info(f"selected action {next_action} after {num_sims} simulations.")
  action_list.append(next_action)
  log.info(f"current action list: {action_list}")
 
+ idx += 1
+
  log.info(f"Final action list: {action_list}")
  # restore state of current node
  return action_list
@@ -145,6 +190,8 @@ class GymctsAgent:
  # we also need to reset the children of the current node
  # this is done by calling the reset method
  next_node.reset()
+ else:
+ next_node.remove_parent()
 
  self.search_root_node = next_node
 
@@ -13,8 +13,15 @@ from gymcts.logger import log
 
 
  class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
-
-
+ """
+ A wrapper for gym environments that implements the GymctsABC interface.
+ It uses deep copies of the environment as the state representation.
+ Please note that this is not the most efficient way to implement the state representation.
+ It is supposed to be used to see if your use case works well with the MCTS algorithm.
+ If it does, you can consider implementing all GymctsABC methods in a more efficient way.
+ """
+
+ # helper attributes for the wrapper
  _terminal_flag: bool = False
  _last_reward: SupportsFloat = 0
  _step_tuple: tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]] = None
@@ -22,9 +29,21 @@ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
  _action_mask_fn: Callable[[gym.Env], np.ndarray] | None = None
 
  def is_terminal(self) -> bool:
+ """
+ Returns True if the environment is in a terminal state, False otherwise.
+
+ :return: True if the environment is in a terminal state, False otherwise.
+ """
  return self._terminal_flag
 
  def load_state(self, state: Any) -> None:
+ """
+ The load_state method is not implemented. The state is loaded by replacing the env with the 'state' (the copy
+ provided by 'get_state'). 'self' in a method cannot be replaced with another object (as far as I know).
+
+ :param state: a deepcopy of the environment
+ :return: None
+ """
  msg = """
  The NaiveSoloMCTSGymEnvWrapper uses deepcopies of the entire env as the state.
  The loading of the state is done by replacing the env with the 'state' (the copy provided by 'get_state').
@@ -39,6 +58,16 @@ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
  buffer_length: int = 100,
  record_video: bool = False,
  ):
+ """
+ The constructor of the wrapper. It wraps the environment with RecordEpisodeStatistics and checks if the action
+ space is discrete. It also checks if the action_mask_fn is a string or a callable. If it is a string, it tries to
+ find the method in the environment. If it is a callable, it assigns it to the _action_mask_fn attribute.
+
+ :param env: the environment to wrap
+ :param action_mask_fn: a function (or the name of an env method) that returns a mask of valid actions
+ :param buffer_length: the length of the RecordEpisodeStatistics buffer used to record episode returns
+ :param record_video:
+ """
  # wrap with RecordEpisodeStatistics if it is not already wrapped
  env = RecordEpisodeStatistics(env, buffer_length=buffer_length)
 
@@ -61,6 +90,10 @@ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
  self._action_mask_fn = action_mask_fn
 
  def get_state(self) -> Any:
+ """
+ Returns the current state of the environment as a deepcopy of the environment.
+
+ :return: a deepcopy of the environment
+ """
  log.debug("getting state")
  original_state = self
  copied_state = copy.deepcopy(self)
@@ -71,9 +104,19 @@ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
  return copied_state
 
  def action_masks(self) -> np.ndarray | None:
+ """
+ Returns the action masks for the environment. If the action_mask_fn is not set, it returns None.
+
+ :return: the action masks for the environment
+ """
  return self._action_mask_fn(self.env) if self._action_mask_fn is not None else None
 
  def get_valid_actions(self) -> list[int]:
+ """
+ Returns a list of valid actions for the current state of the environment.
+ This is used to obtain potential actions/subsequent states for the MCTS tree.
+
+ :return: the list of valid actions
+ """
  if self._action_mask_fn is None:
  action_space: gym.spaces.Discrete = self.env.action_space # Type hinting
  return list(range(action_space.n))
@@ -83,6 +126,14 @@ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
  def step(
  self, action: WrapperActType
  ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
+ """
+ Performs a step in the environment.
+ This method is used to update the wrapper with the new state and action, to realize the terminal-state
+ functionality.
+
+ :param action: action to perform in the environment
+ :return: the step tuple of the environment (obs, reward, terminated, truncated, info)
+ """
  step_tuple = self.env.step(action)
 
  obs, reward, terminated, truncated, info = step_tuple
@@ -93,6 +144,12 @@ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
 
 
  def rollout(self) -> float:
+ """
+ Performs a rollout from the current state of the environment and returns the return (sum of rewards) of the
+ rollout.
+
+ :return: the return of the rollout
+ """
  log.debug("performing rollout")
  # random rollout
  # perform random valid actions until terminal
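
Editor's note: the deep-copy wrapper is the quickest way to try gymcts on a new environment. A minimal sketch of combining it with GymctsAgent follows; the environment choice is an assumption, and rollout returns should ideally lie in [-1, 1] (see the GymctsABC docstrings further down in this diff).

# Minimal sketch (not from the package): deep-copy wrapper plus agent.
import gymnasium as gym
from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
from gymcts.gymcts_agent import GymctsAgent

env = gym.make("FrozenLake-v1", is_slippery=False)  # rewards already lie in [0, 1]
env = DeepCopyMCTSGymEnvWrapper(env)
env.reset()

agent = GymctsAgent(env=env, number_of_simulations_per_step=25)
actions = agent.solve()
print(actions)
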
@@ -118,6 +118,7 @@ class DistributedGymctsAgent:
  render_tree_after_step: bool = False,
  render_tree_max_depth: int = 2,
  num_parallel: int = 4,
+ clear_mcts_tree_after_step: bool = False,
  number_of_simulations_per_step: int = 25,
  exclude_unvisited_nodes_from_render: bool = False
  ):
@@ -134,6 +135,7 @@ class DistributedGymctsAgent:
  self.number_of_simulations_per_step = number_of_simulations_per_step
 
  self.env = env
+ self.clear_mcts_tree_after_step = clear_mcts_tree_after_step
 
  self.search_root_node = GymctsNode(
  action=None,
@@ -206,6 +208,8 @@ class DistributedGymctsAgent:
  ready_node = ray.get(ready_node_ref)
 
  # merge the tree
+ if not self.clear_mcts_tree_after_step:
+ self.backpropagation(search_start_node, ready_node.mean_value, ready_node.visit_count)
  search_start_node = merge_nodes(search_start_node, ready_node)
 
  action = search_start_node.get_best_action()
@@ -217,22 +221,34 @@ class DistributedGymctsAgent:
  tree_max_depth=self.render_tree_max_depth
  )
 
-
- # to clear memory we need to remove all nodes except the current node
- # this is done by setting the root node to the current node
- # and setting the parent of the current node to None
- # we also need to reset the children of the current node
- # this is done by calling the reset method
- #
- # in a distributed setting we need we delete all previous nodes
- # this is because backpropagation merging trees is already computationally expensive
- # and backpropagating the whole tree would be even more expensive
- next_node.reset()
+ if self.clear_mcts_tree_after_step:
+ # to clear memory we need to remove all nodes except the current node
+ # this is done by setting the root node to the current node
+ # and setting the parent of the current node to None
+ # we also need to reset the children of the current node
+ # this is done by calling the reset method
+ next_node.reset()
 
  self.search_root_node = next_node
 
  return action, next_node
 
+ def backpropagation(self, node: GymctsNode, average_episode_return: float, num_episodes: int) -> None:
+ log.debug(f"performing backpropagation from leaf node: {node}")
+ while not node.is_root():
+ node.mean_value = (node.mean_value * node.visit_count + average_episode_return * num_episodes) / (
+ node.visit_count + num_episodes)
+ node.visit_count += num_episodes
+ node.max_value = max(node.max_value, average_episode_return)
+ node.min_value = min(node.min_value, average_episode_return)
+ node = node.parent
+ # also update root node
+ node.mean_value = (node.mean_value * node.visit_count + average_episode_return * num_episodes) / (
+ node.visit_count + num_episodes)
+ node.visit_count += num_episodes
+ node.max_value = max(node.max_value, average_episode_return)
+ node.min_value = min(node.min_value, average_episode_return)
+
  def show_mcts_tree(self, start_node: GymctsNode = None, tree_max_depth: int = None) -> None:
 
  if start_node is None:
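
Editor's note: the new backpropagation method folds a batch of num_episodes rollouts with average return avg into each node's running mean on the path to the root. A quick illustrative check of that weighted-mean update (not package code):

old_mean, old_count = 0.5, 8      # node.mean_value, node.visit_count before the merge
avg, num_episodes = 0.25, 4       # batch statistics from the remote worker's subtree

new_mean = (old_mean * old_count + avg * num_episodes) / (old_count + num_episodes)
new_count = old_count + num_episodes

assert new_count == 12
assert abs(new_mean - (0.5 * 8 + 0.25 * 4) / 12) < 1e-12   # new_mean is roughly 0.4167
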
@@ -268,7 +284,7 @@ if __name__ == '__main__':
  agent1 = DistributedGymctsAgent(
  env=env,
  render_tree_after_step=True,
- number_of_simulations_per_step=1000,
+ number_of_simulations_per_step=10,
  exclude_unvisited_nodes_from_render=True,
  num_parallel=1,
  )
@@ -278,4 +294,6 @@ if __name__ == '__main__':
  actions = agent1.solve()
  end_time = time.perf_counter()
 
+ agent1.show_mcts_tree_from_root()
+
  print(f"solution time per action: {end_time - start_time}/{len(actions)}")
gymcts/gymcts_env_abc.py CHANGED
@@ -1,28 +1,71 @@
  from typing import TypeVar, Any, SupportsFloat, Callable
  from abc import ABC, abstractmethod
  import gymnasium as gym
-
- TSoloMCTSNode = TypeVar("TSoloMCTSNode", bound="SoloMCTSNode")
+ import numpy as np
 
 
  class GymctsABC(ABC, gym.Env):
 
  @abstractmethod
  def get_state(self) -> Any:
+ """
+ Returns the current state of the environment. The state can in principle be any datatype that allows restoring
+ the environment to the same state. The state is used to restore the environment using the load_state method.
+
+ It's recommended to use a numpy array if possible, as it is easy to serialize and deserialize.
+
+ :return: the current state of the environment
+ """
  pass
 
  @abstractmethod
  def load_state(self, state: Any) -> None:
+ """
+ Loads the state of the environment. The state can in principle be any datatype that allows restoring the
+ environment to the same state (see get_state).
+
+ :param state: the state to load
+ :return: None
+ """
  pass
 
  @abstractmethod
  def is_terminal(self) -> bool:
+ """
+ Returns True if the environment is in a terminal state, False otherwise.
+
+ :return: True if the environment is in a terminal state, False otherwise
+ """
  pass
 
  @abstractmethod
  def get_valid_actions(self) -> list[int]:
+ """
+ Returns a list of valid actions for the current state of the environment.
+ This is used to obtain potential actions/subsequent states for the MCTS tree.
+
+ :return: the list of valid actions
+ """
+ pass
+
+ @abstractmethod
+ def action_masks(self) -> np.ndarray | None:
+ """
+ Returns a numpy array of action masks for the environment. The array should have the same length as the number
+ of actions in the action space. If an action is valid, the corresponding mask value should be 1, otherwise 0.
+ If no action mask is available, it should return None.
+
+ :return: a numpy array of action masks or None
+ """
  pass
 
  @abstractmethod
  def rollout(self) -> float:
+ """
+ Performs a rollout from the current state of the environment and returns the return (sum of rewards) of the rollout.
+
+ Please make sure the return value is in the interval [-1, 1].
+ Otherwise, the MCTS algorithm will not work as expected (the exploration coefficient is ill-fitted, and
+ exploration and exploitation are no longer well balanced).
+
+ :return: the return of the rollout
+ """
  pass
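
Editor's note: once a use case has been validated with the generic wrappers, the docstrings above describe what a custom, more efficient GymctsABC implementation has to provide. A minimal sketch for a toy counting task follows; all names are hypothetical, only the method set comes from the abstract class above, and the rollout return is clamped into [-1, 1] as the rollout docstring asks.

# Minimal sketch (not part of the package): a hand-rolled GymctsABC implementation.
import random
import numpy as np
import gymnasium as gym
from gymcts.gymcts_env_abc import GymctsABC

class ToyCountingEnv(GymctsABC):
    """Reach the target count in at most max_steps steps (illustrative only)."""

    def __init__(self, target: int = 5, max_steps: int = 10):
        self.action_space = gym.spaces.Discrete(2)    # 0: do nothing, 1: increment
        self.observation_space = gym.spaces.MultiDiscrete([target + 1, max_steps + 1])
        self.target, self.max_steps = target, max_steps
        self.count, self.steps = 0, 0

    def get_state(self) -> np.ndarray:
        return np.array([self.count, self.steps])     # small and cheap to copy

    def load_state(self, state: np.ndarray) -> None:
        self.count, self.steps = int(state[0]), int(state[1])

    def is_terminal(self) -> bool:
        return self.count >= self.target or self.steps >= self.max_steps

    def get_valid_actions(self) -> list[int]:
        return [] if self.is_terminal() else [0, 1]

    def action_masks(self) -> np.ndarray | None:
        return None                                   # no masking needed here

    def step(self, action):
        self.count += int(action)
        self.steps += 1
        reward = 1.0 if self.count == self.target else 0.0
        return self.get_state(), reward, self.is_terminal(), False, {}

    def reset(self, seed=None, options=None):
        self.count, self.steps = 0, 0
        return self.get_state(), {}

    def rollout(self) -> float:
        total = 0.0
        while not self.is_terminal():
            _, reward, *_ = self.step(random.choice(self.get_valid_actions()))
            total += reward
        return max(-1.0, min(1.0, total))             # keep the return inside [-1, 1]
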