gymcts 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gymcts/colorful_console_utils.py +22 -0
- gymcts/gymcts_action_history_wrapper.py +72 -2
- gymcts/gymcts_agent.py +54 -7
- gymcts/gymcts_deepcopy_wrapper.py +59 -2
- gymcts/gymcts_distributed_agent.py +30 -12
- gymcts/gymcts_env_abc.py +45 -2
- gymcts/gymcts_neural_agent.py +479 -0
- gymcts/gymcts_node.py +161 -17
- gymcts/gymcts_tree_plotter.py +22 -1
- gymcts/logger.py +1 -4
- {gymcts-1.2.0.dist-info → gymcts-1.3.0.dist-info}/METADATA +39 -39
- gymcts-1.3.0.dist-info/RECORD +16 -0
- {gymcts-1.2.0.dist-info → gymcts-1.3.0.dist-info}/WHEEL +1 -1
- gymcts-1.2.0.dist-info/RECORD +0 -15
- {gymcts-1.2.0.dist-info → gymcts-1.3.0.dist-info}/licenses/LICENSE +0 -0
- {gymcts-1.2.0.dist-info → gymcts-1.3.0.dist-info}/top_level.txt +0 -0
gymcts/colorful_console_utils.py
CHANGED
@@ -106,6 +106,18 @@ def wrap_with_color_codes(s: object, /, r: int | float, g: int | float, b: int |
 
 
 def wrap_evenly_spaced_color(s: Any, n_of_item: int, n_classes: int, c_map="rainbow") -> str:
+    """
+    Wraps a string with a color scale (a matplotlib c_map) based on the n_of_item and n_classes.
+    This function is used to color code the available actions in the MCTS tree visualisation.
+    The children of the MCTS tree are colored based on their action for a clearer visualisation.
+
+    :param s: the string (or object) to be wrapped. objects are converted to string (using the __str__ function).
+    :param n_of_item: the index of the item to be colored. In a mcts tree, this is the (parent-)action of the node.
+    :param n_classes: the number of classes (or items) to be colored. In a mcts tree, this is the number of available actions.
+    :param c_map: the colormap to be used (default is 'rainbow').
+        The colormap can be any matplotlib colormap, e.g. 'viridis', 'plasma', 'inferno', 'magma', 'cividis'.
+    :return: a string that contains the color-codes (prefix and suffix) and the string s in between.
+    """
     if s is None or n_of_item is None or n_classes is None:
         return s
 
@@ -119,6 +131,16 @@ def wrap_evenly_spaced_color(s: Any, n_of_item: int, n_classes: int, c_map="rain
 
 
 def wrap_with_color_scale(s: str, value: float, min_val: float, max_val: float, c_map=None) -> str:
+    """
+    Wraps a string with a color scale (a matplotlib c_map) based on the value, min_val, and max_val.
+
+    :param s: the string to be wrapped
+    :param value: the value to be mapped to a color
+    :param min_val: the minimum value of the scale
+    :param max_val: the maximum value of the scale
+    :param c_map: the colormap to be used (default is 'rainbow')
+    :return:
+    """
    if s is None or min_val is None or max_val is None or min_val >= max_val:
        return s
 
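Below is a minimal, hypothetical sketch (not the package's actual implementation) of the technique these new docstrings describe: sampling a matplotlib colormap at evenly spaced points and wrapping a string in 24-bit ANSI color codes.

    import matplotlib.pyplot as plt

    def wrap_evenly_spaced_color_sketch(s: str, n_of_item: int, n_classes: int, c_map: str = "rainbow") -> str:
        # sample the colormap at one of n_classes evenly spaced points
        r, g, b, _ = plt.get_cmap(c_map)(n_of_item / max(n_classes - 1, 1))
        # 24-bit ANSI foreground color prefix; "\033[0m" resets the color
        prefix = f"\033[38;2;{int(r * 255)};{int(g * 255)};{int(b * 255)}m"
        return f"{prefix}{s}\033[0m"

    print(wrap_evenly_spaced_color_sketch("action 2", n_of_item=2, n_classes=5))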
gymcts/gymcts_action_history_wrapper.py
CHANGED

@@ -1,8 +1,7 @@
 import random
-import copy
 
 import numpy as np
-from typing import
+from typing import Any, SupportsFloat, Callable
 import gymnasium as gym
 from gymnasium.core import WrapperActType, WrapperObsType
 from gymnasium.wrappers import RecordEpisodeStatistics
@@ -13,6 +12,21 @@ from gymcts.logger import log
 
 
 class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
+    """
+    A wrapper for gym environments that implements the GymctsABC interface.
+    It uses the action history as state representation.
+    Please note that this is not the most efficient way to implement the state representation.
+    It is supposed to be used to see if your use-case works well with the MCTS algorithm.
+    If it does, you can consider implementing all GymctsABC methods in a more efficient way.
+    The action history is a list of actions taken in the environment.
+    The state is represented as a list of actions taken in the environment.
+    The state is used to restore the environment using the load_state method.
+
+    It is supposed to be used to see if your use-case works well with the MCTS algorithm.
+    If it does, you can consider implementing all GymctsABC methods in a more efficient way.
+    """
+
+    # helper attributes for the wrapper
     _terminal_flag: bool = False
     _last_reward: SupportsFloat = 0
     _step_tuple: tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]] = None
@@ -25,6 +39,17 @@ class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
                  action_mask_fn: str | Callable[[gym.Env], np.ndarray] | None = None,
                  buffer_length: int = 100,
                  ):
+        """
+        A wrapper for gym environments that implements the GymctsABC interface.
+        It uses the action history as state representation.
+        Please note that this is not the most efficient way to implement the state representation.
+        It is supposed to be used to see if your use-case works well with the MCTS algorithm.
+        If it does, you can consider implementing all GymctsABC methods in a more efficient way.
+
+        :param env: the environment to wrap
+        :param action_mask_fn: a function that takes the environment as input and returns a mask of valid actions
+        :param buffer_length: the length of the buffer for recording episodes for determining their rollout returns
+        """
         # wrap with RecordEpisodeStatistics if it is not already wrapped
         env = RecordEpisodeStatistics(env, buffer_length=buffer_length)
 
@@ -48,6 +73,17 @@ class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
         self._action_mask_fn = action_mask_fn
 
     def load_state(self, state: list[int]) -> None:
+        """
+        Loads the state of the environment. The state is a list of actions taken in the environment.
+
+        The environment is reset and all actions in the state are performed in order to restore the environment to the
+        same state.
+
+        This works only for deterministic environments!
+
+        :param state: the state to load
+        :return: None
+        """
         self.env.reset()
         self._wrapper_action_history = []
 
@@ -56,15 +92,30 @@ class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
         self._wrapper_action_history.append(action)
 
     def is_terminal(self) -> bool:
+        """
+        Returns True if the environment is in a terminal state, False otherwise.
+
+        :return:
+        """
         if not len(self.get_valid_actions()):
             return True
         else:
             return self._terminal_flag
 
     def action_masks(self) -> np.ndarray | None:
+        """
+        Returns the action masks for the environment. If the action_mask_fn is not set, it returns None.
+
+        :return:
+        """
         return self._action_mask_fn(self.env) if self._action_mask_fn is not None else None
 
     def get_valid_actions(self) -> list[int]:
+        """
+        Returns a list of valid actions for the current state of the environment.
+
+        :return: a list of valid actions
+        """
         if self._action_mask_fn is None:
             action_space: gym.spaces.Discrete = self.env.action_space  # Type hinting
             return list(range(action_space.n))
@@ -72,6 +123,12 @@ class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
             return [i for i, mask in enumerate(self.action_masks()) if mask]
 
     def rollout(self) -> float:
+        """
+        Performs a random rollout from the current state of the environment and returns the return (sum of rewards)
+        of the rollout.
+
+        :return: the return of the rollout
+        """
         log.debug("performing rollout")
         # random rollout
         # perform random valid action until terminal
@@ -92,11 +149,24 @@ class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
         return episode_return
 
     def get_state(self) -> list[int]:
+        """
+        Returns the current state of the environment. The state is a list of actions taken in the environment,
+        namely all actions that have been taken in the environment so far (since the last reset).
+
+        :return: a list of actions taken in the environment
+        """
+
         return self._wrapper_action_history.copy()
 
     def step(
         self, action: WrapperActType
     ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
+        """
+        Performs a step in the environment. It adds the action to the action history and updates the terminal flag.
+
+        :param action: action to perform in the environment
+        :return: the step tuple of the environment (obs, reward, terminated, truncated, info)
+        """
         step_tuple = self.env.step(action)
         self._wrapper_action_history.append(action)
         obs, reward, terminated, truncated, info = step_tuple
gymcts/gymcts_agent.py
CHANGED
@@ -1,7 +1,8 @@
 import copy
+import random
 import gymnasium as gym
 
-from typing import TypeVar, Any, SupportsFloat, Callable
+from typing import TypeVar, Any, SupportsFloat, Callable, Literal
 
 from gymcts.gymcts_env_abc import GymctsABC
 from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
@@ -10,7 +11,9 @@ from gymcts.gymcts_tree_plotter import _generate_mcts_tree
 
 from gymcts.logger import log
 
-
+
+
+
 
 
 class GymctsAgent:
@@ -23,17 +26,50 @@ class GymctsAgent:
     search_root_node: GymctsNode  # NOTE: this is not the same as the root of the tree!
     clear_mcts_tree_after_step: bool
 
+
+    # (num_simulations: int, step_idx: int) -> int
+    @staticmethod
+    def calc_number_of_simulations_per_step(num_simulations: int, step_idx: int) -> int:
+        """
+        A function that returns a constant number of simulations per step.
+
+        :param num_simulations: The number of simulations to return.
+        :param step_idx: The current step index (not used in this function).
+        :return: the constant number of simulations.
+        """
+        return num_simulations
+
     def __init__(self,
                  env: GymctsABC,
                  clear_mcts_tree_after_step: bool = True,
                  render_tree_after_step: bool = False,
                  render_tree_max_depth: int = 2,
                  number_of_simulations_per_step: int = 25,
-                 exclude_unvisited_nodes_from_render: bool = False
+                 exclude_unvisited_nodes_from_render: bool = False,
+                 calc_number_of_simulations_per_step: Callable[[int, int], int] = None,
+                 score_variate: Literal["UCT_v0", "UCT_v1", "UCT_v2",] = "UCT_v0",
+                 best_action_weight=None,
                  ):
         # check if action space of env is discrete
         if not isinstance(env.action_space, gym.spaces.Discrete):
             raise ValueError("Action space must be discrete.")
+        if calc_number_of_simulations_per_step is not None:
+            # check if the provided function is callable
+            if not callable(calc_number_of_simulations_per_step):
+                raise ValueError("calc_number_of_simulations_per_step must be a callable accepting two arguments: num_simulations and step_idx.")
+            # assign the provided function to the attribute
+            # it needs to be a staticmethod to be used as a class attribute
+            print("Using provided calc_number_of_simulations_per_step function.")
+            self.calc_number_of_simulations_per_step = staticmethod(calc_number_of_simulations_per_step)
+        if score_variate not in ["UCT_v0", "UCT_v1", "UCT_v2"]:
+            raise ValueError("score_variate must be one of ['UCT_v0', 'UCT_v1', 'UCT_v2'].")
+        GymctsNode.score_variate = score_variate
+
+        if best_action_weight is not None:
+            if best_action_weight < 0 or best_action_weight > 1:
+                raise ValueError("best_action_weight must be in range [0, 1].")
+            GymctsNode.best_action_weight = best_action_weight
+
 
         self.render_tree_after_step = render_tree_after_step
         self.exclude_unvisited_nodes_from_render = exclude_unvisited_nodes_from_render
@@ -63,7 +99,10 @@ class GymctsAgent:
         # NAVIGATION STRATEGY
         # select child with highest UCB score
         while not temp_node.is_leaf():
-
+            children = list(temp_node.children.values())
+            max_ucb_score = max(child.tree_policy_score() for child in children)
+            best_children = [child for child in children if child.tree_policy_score() == max_ucb_score]
+            temp_node = random.choice(best_children)
         log.debug(f"Selected leaf node: {temp_node}")
         return temp_node
 
@@ -84,7 +123,6 @@ class GymctsAgent:
                 parent=node,
                 env_reference=self.env,
             )
-
         node.children = child_dict
 
     def solve(self, num_simulations_per_step: int = None, render_tree_after_step: bool = None) -> list[int]:
@@ -100,13 +138,20 @@ class GymctsAgent:
 
         action_list = []
 
+        idx = 0
         while not current_node.terminal:
-
+            num_sims = self.calc_number_of_simulations_per_step(num_simulations_per_step, idx)
+
+            log.info(f"Performing MCTS step {idx} with {num_sims} simulations.")
+
+            next_action, current_node = self.perform_mcts_step(num_simulations=num_sims,
                                                                render_tree_after_step=render_tree_after_step)
-            log.info(f"selected action {next_action} after {
+            log.info(f"selected action {next_action} after {num_sims} simulations.")
             action_list.append(next_action)
             log.info(f"current action list: {action_list}")
 
+            idx += 1
+
         log.info(f"Final action list: {action_list}")
         # restore state of current node
         return action_list
@@ -145,6 +190,8 @@ class GymctsAgent:
             # we also need to reset the children of the current node
             # this is done by calling the reset method
             next_node.reset()
+        else:
+            next_node.remove_parent()
 
         self.search_root_node = next_node
 
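A hedged sketch of the new constructor hooks (the decay schedule below is illustrative; the parameter names come from the diff): calc_number_of_simulations_per_step receives the configured budget and the current step index and returns the simulation count for that step.

    from gymcts.gymcts_agent import GymctsAgent

    def decaying_schedule(num_simulations: int, step_idx: int) -> int:
        # spend more of the budget on early, more consequential steps
        return max(1, num_simulations // (step_idx + 1))

    agent = GymctsAgent(
        env=env,  # a GymctsABC-conforming environment, e.g. one of the wrappers above
        number_of_simulations_per_step=50,
        calc_number_of_simulations_per_step=decaying_schedule,
        score_variate="UCT_v1",  # one of "UCT_v0", "UCT_v1", "UCT_v2"
        best_action_weight=0.5,  # must lie in [0, 1]
    )
    actions = agent.solve()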
gymcts/gymcts_deepcopy_wrapper.py
CHANGED

@@ -13,8 +13,15 @@ from gymcts.logger import log
 
 
 class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
-
-
+    """
+    A wrapper for gym environments that implements the GymctsABC interface.
+    It uses deepcopies as state representation.
+    Please note that this is not the most efficient way to implement the state representation.
+    It is supposed to be used to see if your use-case works well with the MCTS algorithm.
+    If it does, you can consider implementing all GymctsABC methods in a more efficient way.
+    """
+
+    # helper attributes for the wrapper
     _terminal_flag:bool = False
     _last_reward: SupportsFloat = 0
     _step_tuple: tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]] = None
@@ -22,9 +29,21 @@ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
     _action_mask_fn: Callable[[gym.Env], np.ndarray] | None = None
 
     def is_terminal(self) -> bool:
+        """
+        Returns True if the environment is in a terminal state, False otherwise.
+
+        :return: True if the environment is in a terminal state, False otherwise.
+        """
         return self._terminal_flag
 
     def load_state(self, state: Any) -> None:
+        """
+        The load_state method is not implemented. The state is loaded by replacing the env with the 'state' (the copy
+        provided by 'get_state'). 'self' in a method cannot be replaced with another object (as far as I know).
+
+        :param state: a deepcopy of the environment
+        :return: None
+        """
         msg = """
         The NaiveSoloMCTSGymEnvWrapper uses deepcopies of the entire env as the state.
         The loading of the state is done by replacing the env with the 'state' (the copy provided by 'get_state').
@@ -39,6 +58,16 @@ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
                  buffer_length: int = 100,
                  record_video: bool = False,
                  ):
+        """
+        The constructor of the wrapper. It wraps the environment with RecordEpisodeStatistics and checks if the action
+        space is discrete. It also checks if the action_mask_fn is a string or a callable. If it is a string, it tries to
+        find the method in the environment. If it is a callable, it assigns it to the _action_mask_fn attribute.
+
+        :param env: the environment to wrap
+        :param action_mask_fn:
+        :param buffer_length:
+        :param record_video:
+        """
         # wrap with RecordEpisodeStatistics if it is not already wrapped
         env = RecordEpisodeStatistics(env, buffer_length=buffer_length)
 
@@ -61,6 +90,10 @@ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
         self._action_mask_fn = action_mask_fn
 
     def get_state(self) -> Any:
+        """
+        Returns the current state of the environment as a deepcopy of the environment.
+        :return: a deepcopy of the environment
+        """
         log.debug("getting state")
         original_state = self
         copied_state = copy.deepcopy(self)
@@ -71,9 +104,19 @@ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
         return copied_state
 
     def action_masks(self) -> np.ndarray | None:
+        """
+        Returns the action masks for the environment. If the action_mask_fn is not set, it returns None.
+        :return: the action masks for the environment
+        """
         return self._action_mask_fn(self.env) if self._action_mask_fn is not None else None
 
     def get_valid_actions(self) -> list[int]:
+        """
+        Returns a list of valid actions for the current state of the environment.
+        This is used to obtain potential actions/subsequent states for the MCTS tree.
+
+        :return: the list of valid actions
+        """
         if self._action_mask_fn is None:
             action_space: gym.spaces.Discrete = self.env.action_space  # Type hinting
             return list(range(action_space.n))
@@ -83,6 +126,14 @@ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
     def step(
         self, action: WrapperActType
     ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
+        """
+        Performs a step in the environment.
+        This method is used to update the wrapper with the new state and the new action, to realize the terminal state
+        functionality.
+
+        :param action: action to perform in the environment
+        :return: the step tuple of the environment (obs, reward, terminated, truncated, info)
+        """
         step_tuple = self.env.step(action)
 
         obs, reward, terminated, truncated, info = step_tuple
@@ -93,6 +144,12 @@ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
 
 
     def rollout(self) -> float:
+        """
+        Performs a rollout from the current state of the environment and returns the return (sum of rewards) of the
+        rollout.
+
+        :return: the return of the rollout
+        """
         log.debug("performing rollout")
         # random rollout
         # perform random valid action until terminal
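A hedged sketch of supplying an action_mask_fn to this wrapper. Per the constructor docstring above it may be a callable taking the env and returning a mask array; the masking logic and environment id here are illustrative.

    import numpy as np
    import gymnasium as gym
    from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper

    def my_action_mask_fn(env: gym.Env) -> np.ndarray:
        n = env.action_space.n
        mask = np.ones(n, dtype=bool)  # start with all actions valid
        # ... zero out entries that are invalid in the current state ...
        return mask

    env = DeepCopyMCTSGymEnvWrapper(gym.make("CliffWalking-v0"), action_mask_fn=my_action_mask_fn)
    env.get_valid_actions()  # indices where the mask is truthy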
gymcts/gymcts_distributed_agent.py
CHANGED

@@ -118,6 +118,7 @@ class DistributedGymctsAgent:
                  render_tree_after_step: bool = False,
                  render_tree_max_depth: int = 2,
                  num_parallel: int = 4,
+                 clear_mcts_tree_after_step: bool = False,
                  number_of_simulations_per_step: int = 25,
                  exclude_unvisited_nodes_from_render: bool = False
                  ):
@@ -134,6 +135,7 @@ class DistributedGymctsAgent:
         self.number_of_simulations_per_step = number_of_simulations_per_step
 
         self.env = env
+        self.clear_mcts_tree_after_step = clear_mcts_tree_after_step
 
         self.search_root_node = GymctsNode(
             action=None,
@@ -206,6 +208,8 @@ class DistributedGymctsAgent:
         ready_node = ray.get(ready_node_ref)
 
         # merge the tree
+        if not self.clear_mcts_tree_after_step:
+            self.backpropagation(search_start_node, ready_node.mean_value, ready_node.visit_count)
         search_start_node = merge_nodes(search_start_node, ready_node)
 
         action = search_start_node.get_best_action()
@@ -217,22 +221,34 @@ class DistributedGymctsAgent:
             tree_max_depth=self.render_tree_max_depth
         )
 
-
-
-
-
-
-
-
-        # in a distributed setting we need we delete all previous nodes
-        # this is because backpropagation merging trees is already computationally expensive
-        # and backpropagating the whole tree would be even more expensive
-        next_node.reset()
+        if self.clear_mcts_tree_after_step:
+            # to clear memory we need to remove all nodes except the current node
+            # this is done by setting the root node to the current node
+            # and setting the parent of the current node to None
+            # we also need to reset the children of the current node
+            # this is done by calling the reset method
+            next_node.reset()
 
         self.search_root_node = next_node
 
         return action, next_node
 
+    def backpropagation(self, node: GymctsNode, average_episode_return: float, num_episodes: int) -> None:
+        log.debug(f"performing backpropagation from leaf node: {node}")
+        while not node.is_root():
+            node.mean_value = (node.mean_value * node.visit_count + average_episode_return * num_episodes) / (
+                    node.visit_count + num_episodes)
+            node.visit_count += num_episodes
+            node.max_value = max(node.max_value, average_episode_return)
+            node.min_value = min(node.min_value, average_episode_return)
+            node = node.parent
+        # also update root node
+        node.mean_value = (node.mean_value * node.visit_count + average_episode_return * num_episodes) / (
+                node.visit_count + num_episodes)
+        node.visit_count += num_episodes
+        node.max_value = max(node.max_value, average_episode_return)
+        node.min_value = min(node.min_value, average_episode_return)
+
     def show_mcts_tree(self, start_node: GymctsNode = None, tree_max_depth: int = None) -> None:
 
         if start_node is None:
@@ -268,7 +284,7 @@ if __name__ == '__main__':
     agent1 = DistributedGymctsAgent(
         env=env,
         render_tree_after_step=True,
-        number_of_simulations_per_step=
+        number_of_simulations_per_step=10,
         exclude_unvisited_nodes_from_render=True,
         num_parallel=1,
     )
@@ -278,4 +294,6 @@ if __name__ == '__main__':
     actions = agent1.solve()
     end_time = time.perf_counter()
 
+    agent1.show_mcts_tree_from_root()
+
     print(f"solution time per action: {end_time - start_time}/{len(actions)}")
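The backpropagation added above folds a worker subtree's statistics into each ancestor as a visit-weighted average. A standalone sketch of that update (the function name is illustrative):

    def merge_stats(mean_value: float, visit_count: int,
                    average_episode_return: float, num_episodes: int) -> tuple[float, int]:
        # visit-weighted average of the old mean and the incoming episodes' mean
        new_mean = (mean_value * visit_count + average_episode_return * num_episodes) / (
                visit_count + num_episodes)
        return new_mean, visit_count + num_episodes

    # a node with mean 0.2 over 10 visits absorbing 5 episodes averaging 0.8:
    print(merge_stats(0.2, 10, 0.8, 5))  # -> (0.4, 15)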
gymcts/gymcts_env_abc.py
CHANGED
@@ -1,28 +1,71 @@
 from typing import TypeVar, Any, SupportsFloat, Callable
 from abc import ABC, abstractmethod
 import gymnasium as gym
-
-TSoloMCTSNode = TypeVar("TSoloMCTSNode", bound="SoloMCTSNode")
+import numpy as np
 
 
 class GymctsABC(ABC, gym.Env):
 
     @abstractmethod
     def get_state(self) -> Any:
+        """
+        Returns the current state of the environment. The state can be any datatype in principle that allows to restore
+        the environment to the same state. The state is used to restore the environment using the load_state method.
+
+        It's recommended to use a numpy array if possible, as it is easy to serialize and deserialize.
+
+        :return: the current state of the environment
+        """
         pass
 
     @abstractmethod
     def load_state(self, state: Any) -> None:
+        """
+        Loads the state of the environment. The state can be any datatype in principle that allows to restore the
+        environment to the same state. The state is used to restore the environment using the load_state method.
+
+        :param state: the state to load
+        :return: None
+        """
         pass
 
     @abstractmethod
     def is_terminal(self) -> bool:
+        """
+        Returns True if the environment is in a terminal state, False otherwise.
+        :return:
+        """
         pass
 
     @abstractmethod
     def get_valid_actions(self) -> list[int]:
+        """
+        Returns a list of valid actions for the current state of the environment.
+        This is used to obtain potential actions/subsequent states for the MCTS tree.
+        :return: the list of valid actions
+        """
+        pass
+
+    @abstractmethod
+    def action_masks(self) -> np.ndarray | None:
+        """
+        Returns a numpy array of action masks for the environment. The array should have the same length as the number
+        of actions in the action space. If an action is valid, the corresponding mask value should be 1, otherwise 0.
+        If no action mask is available, it should return None.
+
+        :return: a numpy array of action masks or None
+        """
         pass
 
     @abstractmethod
     def rollout(self) -> float:
+        """
+        Performs a rollout from the current state of the environment and returns the return (sum of rewards) of the rollout.
+
+        Please make sure the return value is in the interval [-1, 1].
+        Otherwise, the MCTS algorithm will not work as expected (due to an ill-fitted exploration coefficient;
+        exploration and exploitation are not well-balanced then).
+
+        :return: the return of the rollout
+        """
         pass