gymcts 1.2.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gymcts/colorful_console_utils.py +22 -0
- gymcts/gymcts_action_history_wrapper.py +72 -2
- gymcts/gymcts_agent.py +5 -1
- gymcts/gymcts_deepcopy_wrapper.py +59 -2
- gymcts/gymcts_distributed_agent.py +30 -12
- gymcts/gymcts_env_abc.py +33 -0
- gymcts/gymcts_node.py +85 -8
- gymcts/gymcts_tree_plotter.py +22 -1
- gymcts/logger.py +1 -4
- {gymcts-1.2.0.dist-info → gymcts-1.2.1.dist-info}/METADATA +31 -35
- gymcts-1.2.1.dist-info/RECORD +15 -0
- {gymcts-1.2.0.dist-info → gymcts-1.2.1.dist-info}/WHEEL +1 -1
- gymcts-1.2.0.dist-info/RECORD +0 -15
- {gymcts-1.2.0.dist-info → gymcts-1.2.1.dist-info}/licenses/LICENSE +0 -0
- {gymcts-1.2.0.dist-info → gymcts-1.2.1.dist-info}/top_level.txt +0 -0
gymcts/colorful_console_utils.py
CHANGED
|
@@ -106,6 +106,18 @@ def wrap_with_color_codes(s: object, /, r: int | float, g: int | float, b: int |
|
|
|
106
106
|
|
|
107
107
|
|
|
108
108
|
def wrap_evenly_spaced_color(s: Any, n_of_item: int, n_classes: int, c_map="rainbow") -> str:
|
|
109
|
+
"""
|
|
110
|
+
Wraps a string with a color scale (a matplotlib c_map) based on the n_of_item and n_classes.
|
|
111
|
+
This function is used to color code the available actions in the MCTS tree visualisation.
|
|
112
|
+
The children of the MCTS tree are colored based on their action for a clearer visualisation.
|
|
113
|
+
|
|
114
|
+
:param s: the string (or object) to be wrapped. objects are converted to string (using the __str__ function).
|
|
115
|
+
:param n_of_item: the index of the item to be colored. In a mcts tree, this is the (parent-)action of the node.
|
|
116
|
+
:param n_classes: the number of classes (or items) to be colored. In a mcts tree, this is the number of available actions.
|
|
117
|
+
:param c_map: the colormap to be used (default is 'rainbow').
|
|
118
|
+
The colormap can be any matplotlib colormap, e.g. 'viridis', 'plasma', 'inferno', 'magma', 'cividis'.
|
|
119
|
+
:return: a string that contains the color-codes (prefix and suffix) and the string s in between.
|
|
120
|
+
"""
|
|
109
121
|
if s is None or n_of_item is None or n_classes is None:
|
|
110
122
|
return s
|
|
111
123
|
|
|
@@ -119,6 +131,16 @@ def wrap_evenly_spaced_color(s: Any, n_of_item: int, n_classes: int, c_map="rain
|
|
|
119
131
|
|
|
120
132
|
|
|
121
133
|
def wrap_with_color_scale(s: str, value: float, min_val: float, max_val: float, c_map=None) -> str:
|
|
134
|
+
"""
|
|
135
|
+
Wraps a string with a color scale (a matplotlib c_map) based on the value, min_val, and max_val.
|
|
136
|
+
|
|
137
|
+
:param s: the string to be wrapped
|
|
138
|
+
:param value: the value to be mapped to a color
|
|
139
|
+
:param min_val: the minimum value of the scale
|
|
140
|
+
:param max_val: the maximum value of the scale
|
|
141
|
+
:param c_map: the matplotlib colormap to be used (the parameter defaults to None)
|
|
142
|
+
:return:
|
|
143
|
+
"""
|
|
122
144
|
if s is None or min_val is None or max_val is None or min_val >= max_val:
|
|
123
145
|
return s
|
|
124
146
|
|
|
gymcts/gymcts_action_history_wrapper.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
|
1
1
|
import random
|
|
2
|
-
import copy
|
|
3
2
|
|
|
4
3
|
import numpy as np
|
|
5
|
-
from typing import
|
|
4
|
+
from typing import Any, SupportsFloat, Callable
|
|
6
5
|
import gymnasium as gym
|
|
7
6
|
from gymnasium.core import WrapperActType, WrapperObsType
|
|
8
7
|
from gymnasium.wrappers import RecordEpisodeStatistics
|
|
@@ -13,6 +12,21 @@ from gymcts.logger import log
|
|
|
13
12
|
|
|
14
13
|
|
|
15
14
|
class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
|
|
15
|
+
"""
|
|
16
|
+
A wrapper for gym environments that implements the GymctsABC interface.
|
|
17
|
+
It uses the action history as state representation.
|
|
18
|
+
Please note that this is not the most efficient way to implement the state representation.
|
|
19
|
+
It is supposed to be used to see if your use-case works well with the MCTS algorithm.
|
|
20
|
+
If it does, you can consider implementing all GymctsABC methods in a more efficient way.
|
|
21
|
+
The action history is a list of actions taken in the environment.
|
|
22
|
+
The state is represented as a list of actions taken in the environment.
|
|
23
|
+
The state is used to restore the environment using the load_state method.
|
|
24
|
+
|
|
25
|
+
It is supposed to be used to see if your use-case works well with the MCTS algorithm.
|
|
26
|
+
If it does, you can consider implementing all GymctsABC methods in a more efficient way.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
# helper attributes for the wrapper
|
|
16
30
|
_terminal_flag: bool = False
|
|
17
31
|
_last_reward: SupportsFloat = 0
|
|
18
32
|
_step_tuple: tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]] = None
|
|
@@ -25,6 +39,17 @@ class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
|
|
|
25
39
|
action_mask_fn: str | Callable[[gym.Env], np.ndarray] | None = None,
|
|
26
40
|
buffer_length: int = 100,
|
|
27
41
|
):
|
|
42
|
+
"""
|
|
43
|
+
A wrapper for gym environments that implements the GymctsABC interface.
|
|
44
|
+
It uses the action history as state representation.
|
|
45
|
+
Please note that this is not the most efficient way to implement the state representation.
|
|
46
|
+
It is supposed to be used to see if your use-case works well with the MCTS algorithm.
|
|
47
|
+
If it does, you can consider implementing all GymctsABC methods in a more efficient way.
|
|
48
|
+
|
|
49
|
+
:param env: the environment to wrap
|
|
50
|
+
:param action_mask_fn: a function that takes the environment as input and returns a mask of valid actions
|
|
51
|
+
:param buffer_length: the length of the buffer for recording episodes for determining their rollout returns
|
|
52
|
+
"""
|
|
28
53
|
# wrap with RecordEpisodeStatistics if it is not already wrapped
|
|
29
54
|
env = RecordEpisodeStatistics(env, buffer_length=buffer_length)
|
|
30
55
|
|
|
@@ -48,6 +73,17 @@ class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
|
|
|
48
73
|
self._action_mask_fn = action_mask_fn
|
|
49
74
|
|
|
50
75
|
def load_state(self, state: list[int]) -> None:
|
|
76
|
+
"""
|
|
77
|
+
Loads the state of the environment. The state is a list of actions taken in the environment.
|
|
78
|
+
|
|
79
|
+
The environment is reset and all actions in the state are performed in order to restore the environment to the
|
|
80
|
+
same state.
|
|
81
|
+
|
|
82
|
+
This works only for deterministic environments!
|
|
83
|
+
|
|
84
|
+
:param state: the state to load
|
|
85
|
+
:return: None
|
|
86
|
+
"""
|
|
51
87
|
self.env.reset()
|
|
52
88
|
self._wrapper_action_history = []
|
|
53
89
|
|
|
@@ -56,15 +92,30 @@ class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
|
|
|
56
92
|
self._wrapper_action_history.append(action)
|
|
57
93
|
|
|
58
94
|
def is_terminal(self) -> bool:
|
|
95
|
+
"""
|
|
96
|
+
Returns True if the environment is in a terminal state, False otherwise.
|
|
97
|
+
|
|
98
|
+
:return:
|
|
99
|
+
"""
|
|
59
100
|
if not len(self.get_valid_actions()):
|
|
60
101
|
return True
|
|
61
102
|
else:
|
|
62
103
|
return self._terminal_flag
|
|
63
104
|
|
|
64
105
|
def action_masks(self) -> np.ndarray | None:
|
|
106
|
+
"""
|
|
107
|
+
Returns the action masks for the environment. If the action_mask_fn is not set, it returns None.
|
|
108
|
+
|
|
109
|
+
:return:
|
|
110
|
+
"""
|
|
65
111
|
return self._action_mask_fn(self.env) if self._action_mask_fn is not None else None
|
|
66
112
|
|
|
67
113
|
def get_valid_actions(self) -> list[int]:
|
|
114
|
+
"""
|
|
115
|
+
Returns a list of valid actions for the current state of the environment.
|
|
116
|
+
|
|
117
|
+
:return: a list of valid actions
|
|
118
|
+
"""
|
|
68
119
|
if self._action_mask_fn is None:
|
|
69
120
|
action_space: gym.spaces.Discrete = self.env.action_space # Type hinting
|
|
70
121
|
return list(range(action_space.n))
|
|
@@ -72,6 +123,12 @@ class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
|
|
|
72
123
|
return [i for i, mask in enumerate(self.action_masks()) if mask]
|
|
73
124
|
|
|
74
125
|
def rollout(self) -> float:
|
|
126
|
+
"""
|
|
127
|
+
Performs a random rollout from the current state of the environment and returns the return (sum of rewards)
|
|
128
|
+
of the rollout.
|
|
129
|
+
|
|
130
|
+
:return: the return of the rollout
|
|
131
|
+
"""
|
|
75
132
|
log.debug("performing rollout")
|
|
76
133
|
# random rollout
|
|
77
134
|
# perform random valid action until terminal
|
|
@@ -92,11 +149,24 @@ class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
|
|
|
92
149
|
return episode_return
|
|
93
150
|
|
|
94
151
|
def get_state(self) -> list[int]:
|
|
152
|
+
"""
|
|
153
|
+
Returns the current state of the environment. The state is a list of actions taken in the environment,
|
|
154
|
+
namely all actions that have been taken in the environment so far (since the last reset).
|
|
155
|
+
|
|
156
|
+
:return: a list of actions taken in the environment
|
|
157
|
+
"""
|
|
158
|
+
|
|
95
159
|
return self._wrapper_action_history.copy()
|
|
96
160
|
|
|
97
161
|
def step(
|
|
98
162
|
self, action: WrapperActType
|
|
99
163
|
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
|
164
|
+
"""
|
|
165
|
+
Performs a step in the environment. It adds the action to the action history and updates the terminal flag.
|
|
166
|
+
|
|
167
|
+
:param action: action to perform in the environment
|
|
168
|
+
:return: the step tuple of the environment (obs, reward, terminated, truncated, info)
|
|
169
|
+
"""
|
|
100
170
|
step_tuple = self.env.step(action)
|
|
101
171
|
self._wrapper_action_history.append(action)
|
|
102
172
|
obs, reward, terminated, truncated, info = step_tuple
|
gymcts/gymcts_agent.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import copy
|
|
2
|
+
import random
|
|
2
3
|
import gymnasium as gym
|
|
3
4
|
|
|
4
5
|
from typing import TypeVar, Any, SupportsFloat, Callable
|
|
@@ -63,7 +64,10 @@ class GymctsAgent:
|
|
|
63
64
|
# NAVIGATION STRATEGY
|
|
64
65
|
# select child with highest UCB score
|
|
65
66
|
while not temp_node.is_leaf():
|
|
66
|
-
|
|
67
|
+
children = list(temp_node.children.values())
|
|
68
|
+
max_ucb_score = max(child.ucb_score() for child in children)
|
|
69
|
+
best_children = [child for child in children if child.ucb_score() == max_ucb_score]
|
|
70
|
+
temp_node = random.choice(best_children)
|
|
67
71
|
log.debug(f"Selected leaf node: {temp_node}")
|
|
68
72
|
return temp_node
|
|
69
73
|
|
|
gymcts/gymcts_deepcopy_wrapper.py
CHANGED
|
@@ -13,8 +13,15 @@ from gymcts.logger import log
|
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
|
|
16
|
-
|
|
17
|
-
|
|
16
|
+
"""
|
|
17
|
+
A wrapper for gym environments that implements the GymctsABC interface.
|
|
18
|
+
It uses deep copies of the environment as state representation.
|
|
19
|
+
Please note that this is not the most efficient way to implement the state representation.
|
|
20
|
+
It is supposed to be used to see if your use-case works well with the MCTS algorithm.
|
|
21
|
+
If it does, you can consider implementing all GymctsABC methods in a more efficient way.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
# helper attributes for the wrapper
|
|
18
25
|
_terminal_flag:bool = False
|
|
19
26
|
_last_reward: SupportsFloat = 0
|
|
20
27
|
_step_tuple: tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]] = None
|
|
@@ -22,9 +29,21 @@ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
|
|
|
22
29
|
_action_mask_fn: Callable[[gym.Env], np.ndarray] | None = None
|
|
23
30
|
|
|
24
31
|
def is_terminal(self) -> bool:
|
|
32
|
+
"""
|
|
33
|
+
Returns True if the environment is in a terminal state, False otherwise.
|
|
34
|
+
|
|
35
|
+
:return: True if the environment is in a terminal state, False otherwise.
|
|
36
|
+
"""
|
|
25
37
|
return self._terminal_flag
|
|
26
38
|
|
|
27
39
|
def load_state(self, state: Any) -> None:
|
|
40
|
+
"""
|
|
41
|
+
The load_state method is not implemented. The state is loaded by replacing the env with the 'state' (the copy
|
|
42
|
+
provided by 'get_state'). 'self' in a method cannot be replaced with another object (as far as I know).
|
|
43
|
+
|
|
44
|
+
:param state: a deepcopy of the environment
|
|
45
|
+
:return: None
|
|
46
|
+
"""
|
|
28
47
|
msg = """
|
|
29
48
|
The NaiveSoloMCTSGymEnvWrapper uses deepcopies of the entire env as the state.
|
|
30
49
|
The loading of the state is done by replacing the env with the 'state' (the copy provided my 'get_state').
|
|
@@ -39,6 +58,16 @@ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
|
|
|
39
58
|
buffer_length: int = 100,
|
|
40
59
|
record_video: bool = False,
|
|
41
60
|
):
|
|
61
|
+
"""
|
|
62
|
+
The constructor of the wrapper. It wraps the environment with RecordEpisodeStatistics and checks if the action
|
|
63
|
+
space is discrete. It also checks if the action_mask_fn is a string or a callable. If it is a string, it tries to
|
|
64
|
+
find the method in the environment. If it is a callable, it assigns it to the _action_mask_fn attribute.
|
|
65
|
+
|
|
66
|
+
:param env: the environment to wrap
|
|
67
|
+
:param action_mask_fn:
|
|
68
|
+
:param buffer_length:
|
|
69
|
+
:param record_video:
|
|
70
|
+
"""
|
|
42
71
|
# wrap with RecordEpisodeStatistics if it is not already wrapped
|
|
43
72
|
env = RecordEpisodeStatistics(env, buffer_length=buffer_length)
|
|
44
73
|
|
|
@@ -61,6 +90,10 @@ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
|
|
|
61
90
|
self._action_mask_fn = action_mask_fn
|
|
62
91
|
|
|
63
92
|
def get_state(self) -> Any:
|
|
93
|
+
"""
|
|
94
|
+
Returns the current state of the environment as a deepcopy of the environment.
|
|
95
|
+
:return: a deepcopy of the environment
|
|
96
|
+
"""
|
|
64
97
|
log.debug("getting state")
|
|
65
98
|
original_state = self
|
|
66
99
|
copied_state = copy.deepcopy(self)
|
|
@@ -71,9 +104,19 @@ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
|
|
|
71
104
|
return copied_state
|
|
72
105
|
|
|
73
106
|
def action_masks(self) -> np.ndarray | None:
|
|
107
|
+
"""
|
|
108
|
+
Returns the action masks for the environment. If the action_mask_fn is not set, it returns None.
|
|
109
|
+
:return: the action masks for the environment
|
|
110
|
+
"""
|
|
74
111
|
return self._action_mask_fn(self.env) if self._action_mask_fn is not None else None
|
|
75
112
|
|
|
76
113
|
def get_valid_actions(self) -> list[int]:
|
|
114
|
+
"""
|
|
115
|
+
Returns a list of valid actions for the current state of the environment.
|
|
116
|
+
This is used to obtain potential actions/subsequent states for the MCTS tree.
|
|
117
|
+
|
|
118
|
+
:return: the list of valid actions
|
|
119
|
+
"""
|
|
77
120
|
if self._action_mask_fn is None:
|
|
78
121
|
action_space: gym.spaces.Discrete = self.env.action_space # Type hinting
|
|
79
122
|
return list(range(action_space.n))
|
|
@@ -83,6 +126,14 @@ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
|
|
|
83
126
|
def step(
|
|
84
127
|
self, action: WrapperActType
|
|
85
128
|
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
|
129
|
+
"""
|
|
130
|
+
Performs a step in the environment.
|
|
131
|
+
This method is used to update the wrapper with the new state and the new action, to realize the terminal state
|
|
132
|
+
functionality.
|
|
133
|
+
|
|
134
|
+
:param action: action to perform in the environment
|
|
135
|
+
:return: the step tuple of the environment (obs, reward, terminated, truncated, info)
|
|
136
|
+
"""
|
|
86
137
|
step_tuple = self.env.step(action)
|
|
87
138
|
|
|
88
139
|
obs, reward, terminated, truncated, info = step_tuple
|
|
@@ -93,6 +144,12 @@ class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
|
|
|
93
144
|
|
|
94
145
|
|
|
95
146
|
def rollout(self) -> float:
|
|
147
|
+
"""
|
|
148
|
+
Performs a rollout from the current state of the environment and returns the return (sum of rewards) of the
|
|
149
|
+
rollout.
|
|
150
|
+
|
|
151
|
+
:return: the return of the rollout
|
|
152
|
+
"""
|
|
96
153
|
log.debug("performing rollout")
|
|
97
154
|
# random rollout
|
|
98
155
|
# perform random valid action until terminal
|
|
gymcts/gymcts_distributed_agent.py
CHANGED
|
@@ -118,6 +118,7 @@ class DistributedGymctsAgent:
|
|
|
118
118
|
render_tree_after_step: bool = False,
|
|
119
119
|
render_tree_max_depth: int = 2,
|
|
120
120
|
num_parallel: int = 4,
|
|
121
|
+
clear_mcts_tree_after_step: bool = False,
|
|
121
122
|
number_of_simulations_per_step: int = 25,
|
|
122
123
|
exclude_unvisited_nodes_from_render: bool = False
|
|
123
124
|
):
|
|
@@ -134,6 +135,7 @@ class DistributedGymctsAgent:
|
|
|
134
135
|
self.number_of_simulations_per_step = number_of_simulations_per_step
|
|
135
136
|
|
|
136
137
|
self.env = env
|
|
138
|
+
self.clear_mcts_tree_after_step = clear_mcts_tree_after_step
|
|
137
139
|
|
|
138
140
|
self.search_root_node = GymctsNode(
|
|
139
141
|
action=None,
|
|
@@ -206,6 +208,8 @@ class DistributedGymctsAgent:
|
|
|
206
208
|
ready_node = ray.get(ready_node_ref)
|
|
207
209
|
|
|
208
210
|
# merge the tree
|
|
211
|
+
if not self.clear_mcts_tree_after_step:
|
|
212
|
+
self.backpropagation(search_start_node, ready_node.mean_value, ready_node.visit_count)
|
|
209
213
|
search_start_node = merge_nodes(search_start_node, ready_node)
|
|
210
214
|
|
|
211
215
|
action = search_start_node.get_best_action()
|
|
@@ -217,22 +221,34 @@ class DistributedGymctsAgent:
|
|
|
217
221
|
tree_max_depth=self.render_tree_max_depth
|
|
218
222
|
)
|
|
219
223
|
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
# in a distributed setting we need we delete all previous nodes
|
|
228
|
-
# this is because backpropagation merging trees is already computationally expensive
|
|
229
|
-
# and backpropagating the whole tree would be even more expensive
|
|
230
|
-
next_node.reset()
|
|
224
|
+
if self.clear_mcts_tree_after_step:
|
|
225
|
+
# to clear memory we need to remove all nodes except the current node
|
|
226
|
+
# this is done by setting the root node to the current node
|
|
227
|
+
# and setting the parent of the current node to None
|
|
228
|
+
# we also need to reset the children of the current node
|
|
229
|
+
# this is done by calling the reset method
|
|
230
|
+
next_node.reset()
|
|
231
231
|
|
|
232
232
|
self.search_root_node = next_node
|
|
233
233
|
|
|
234
234
|
return action, next_node
|
|
235
235
|
|
|
236
|
+
def backpropagation(self, node: GymctsNode, average_episode_return: float, num_episodes: int) -> None:
|
|
237
|
+
log.debug(f"performing backpropagation from leaf node: {node}")
|
|
238
|
+
while not node.is_root():
|
|
239
|
+
node.mean_value = (node.mean_value * node.visit_count + average_episode_return * num_episodes) / (
|
|
240
|
+
node.visit_count + num_episodes)
|
|
241
|
+
node.visit_count += num_episodes
|
|
242
|
+
node.max_value = max(node.max_value, average_episode_return)
|
|
243
|
+
node.min_value = min(node.min_value, average_episode_return)
|
|
244
|
+
node = node.parent
|
|
245
|
+
# also update root node
|
|
246
|
+
node.mean_value = (node.mean_value * node.visit_count + average_episode_return * num_episodes) / (
|
|
247
|
+
node.visit_count + num_episodes)
|
|
248
|
+
node.visit_count += num_episodes
|
|
249
|
+
node.max_value = max(node.max_value, average_episode_return)
|
|
250
|
+
node.min_value = min(node.min_value, average_episode_return)
|
|
251
|
+
|
|
236
252
|
def show_mcts_tree(self, start_node: GymctsNode = None, tree_max_depth: int = None) -> None:
|
|
237
253
|
|
|
238
254
|
if start_node is None:
|
|
@@ -268,7 +284,7 @@ if __name__ == '__main__':
|
|
|
268
284
|
agent1 = DistributedGymctsAgent(
|
|
269
285
|
env=env,
|
|
270
286
|
render_tree_after_step=True,
|
|
271
|
-
number_of_simulations_per_step=
|
|
287
|
+
number_of_simulations_per_step=10,
|
|
272
288
|
exclude_unvisited_nodes_from_render=True,
|
|
273
289
|
num_parallel=1,
|
|
274
290
|
)
|
|
@@ -278,4 +294,6 @@ if __name__ == '__main__':
|
|
|
278
294
|
actions = agent1.solve()
|
|
279
295
|
end_time = time.perf_counter()
|
|
280
296
|
|
|
297
|
+
agent1.show_mcts_tree_from_root()
|
|
298
|
+
|
|
281
299
|
print(f"solution time pro action: {end_time - start_time}/{len(actions)}")
|
gymcts/gymcts_env_abc.py
CHANGED
|
@@ -9,20 +9,53 @@ class GymctsABC(ABC, gym.Env):
|
|
|
9
9
|
|
|
10
10
|
@abstractmethod
|
|
11
11
|
def get_state(self) -> Any:
|
|
12
|
+
"""
|
|
13
|
+
Returns the current state of the environment. The state can be any datatype in principle, that allows to restore
|
|
14
|
+
the environment to the same state. The state is used to restore the environment using the load_state method.
|
|
15
|
+
|
|
16
|
+
It's recommended to use a numpy array if possible, as it is easy to serialize and deserialize.
|
|
17
|
+
|
|
18
|
+
:return: the current state of the environment
|
|
19
|
+
"""
|
|
12
20
|
pass
|
|
13
21
|
|
|
14
22
|
@abstractmethod
|
|
15
23
|
def load_state(self, state: Any) -> None:
|
|
24
|
+
"""
|
|
25
|
+
Loads the state of the environment. The state can be any datatype in principle, that allows to restore the
|
|
26
|
+
environment to the same state. The state is used to restore the environment using the load_state method.
|
|
27
|
+
|
|
28
|
+
:param state: the state to load
|
|
29
|
+
:return: None
|
|
30
|
+
"""
|
|
16
31
|
pass
|
|
17
32
|
|
|
18
33
|
@abstractmethod
|
|
19
34
|
def is_terminal(self) -> bool:
|
|
35
|
+
"""
|
|
36
|
+
Returns True if the environment is in a terminal state, False otherwise.
|
|
37
|
+
:return:
|
|
38
|
+
"""
|
|
20
39
|
pass
|
|
21
40
|
|
|
22
41
|
@abstractmethod
|
|
23
42
|
def get_valid_actions(self) -> list[int]:
|
|
43
|
+
"""
|
|
44
|
+
Returns a list of valid actions for the current state of the environment.
|
|
45
|
+
This is used to obtain potential actions/subsequent states for the MCTS tree.
|
|
46
|
+
:return: the list of valid actions
|
|
47
|
+
"""
|
|
24
48
|
pass
|
|
25
49
|
|
|
26
50
|
@abstractmethod
|
|
27
51
|
def rollout(self) -> float:
|
|
52
|
+
"""
|
|
53
|
+
Performs a rollout from the current state of the environment and returns the return (sum of rewards) of the rollout.
|
|
54
|
+
|
|
55
|
+
Please make sure the return value is in the interval [-1, 1].
|
|
56
|
+
Otherwise, the MCTS algorithm will not work as expected (due to an ill-fitted exploration coefficient;
|
|
57
|
+
exploration and exploitation are not well-balanced then).
|
|
58
|
+
|
|
59
|
+
:return: the return of the rollout
|
|
60
|
+
"""
|
|
28
61
|
pass
|
gymcts/gymcts_node.py
CHANGED
|
@@ -13,18 +13,32 @@ TGymctsNode = TypeVar("TGymctsNode", bound="GymctsNode")
|
|
|
13
13
|
|
|
14
14
|
class GymctsNode:
|
|
15
15
|
# static properties
|
|
16
|
-
best_action_weight: float = 0.05
|
|
17
|
-
ubc_c = 0.707
|
|
16
|
+
best_action_weight: float = 0.05 # weight for the best action
|
|
17
|
+
ubc_c = 0.707 # exploration coefficient
|
|
18
|
+
|
|
19
|
+
|
|
18
20
|
|
|
19
21
|
# attributes
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
22
|
+
#
|
|
23
|
+
# Note these attributes are not static. They are defined here to give developers a hint what fields are available
|
|
24
|
+
# in the class. They are not static because they are not shared between instances of the class in scope of
|
|
25
|
+
# this library.
|
|
26
|
+
visit_count: int = 0 # number of times the node has been visited
|
|
27
|
+
mean_value: float = 0 # mean value of the node
|
|
28
|
+
max_value: float = -float("inf") # maximum value of the node
|
|
29
|
+
min_value: float = +float("inf") # minimum value of the node
|
|
30
|
+
terminal: bool = False # whether the node is terminal or not
|
|
31
|
+
state: Any = None # state of the node
|
|
26
32
|
|
|
27
33
|
def __str__(self, colored=False, action_space_n=None) -> str:
|
|
34
|
+
"""
|
|
35
|
+
Returns a string representation of the node. The string representation is used for visualisation purposes.
|
|
36
|
+
It is used for example in the mcts tree visualisation functionality.
|
|
37
|
+
|
|
38
|
+
:param colored: true if the string representation should be colored, false otherwise. (true is used by the mcts tree visualisation)
|
|
39
|
+
:param action_space_n: the number of actions in the action space. This is used for coloring the action in the string representation.
|
|
40
|
+
:return: a potentially colored string representation of the node.
|
|
41
|
+
"""
|
|
28
42
|
if not colored:
|
|
29
43
|
|
|
30
44
|
if not self.is_root():
|
|
@@ -72,22 +86,44 @@ class GymctsNode:
|
|
|
72
86
|
(f", {p}ubc{e}={colorful_value(self.ucb_score())})" if not self.is_root() else ")"))
|
|
73
87
|
|
|
74
88
|
def traverse_nodes(self) -> Generator[TGymctsNode, None, None]:
|
|
89
|
+
"""
|
|
90
|
+
Traverse the tree and yield all nodes in the tree.
|
|
91
|
+
|
|
92
|
+
:return: a generator that yields all nodes in the tree.
|
|
93
|
+
"""
|
|
75
94
|
yield self
|
|
76
95
|
if self.children:
|
|
77
96
|
for child in self.children.values():
|
|
78
97
|
yield from child.traverse_nodes()
|
|
79
98
|
|
|
80
99
|
def get_root(self) -> TGymctsNode:
|
|
100
|
+
"""
|
|
101
|
+
Returns the root node of the tree. The root node is the node that has no parent.
|
|
102
|
+
|
|
103
|
+
:return: the root node of the tree.
|
|
104
|
+
"""
|
|
81
105
|
if self.is_root():
|
|
82
106
|
return self
|
|
83
107
|
return self.parent.get_root()
|
|
84
108
|
|
|
85
109
|
def max_tree_depth(self):
|
|
110
|
+
"""
|
|
111
|
+
Returns the maximum depth of the tree. The depth of a node is the number of edges from
|
|
112
|
+
the node to the root node.
|
|
113
|
+
|
|
114
|
+
:return: the maximum depth of the tree.
|
|
115
|
+
"""
|
|
86
116
|
if self.is_leaf():
|
|
87
117
|
return 0
|
|
88
118
|
return 1 + max(child.max_tree_depth() for child in self.children.values())
|
|
89
119
|
|
|
90
120
|
def n_children_recursively(self):
|
|
121
|
+
"""
|
|
122
|
+
Returns the number of children of the node recursively. The number of children of a node is the number of
|
|
123
|
+
children of the node plus the number of children of all children of the node.
|
|
124
|
+
|
|
125
|
+
:return: the number of children of the node recursively.
|
|
126
|
+
"""
|
|
91
127
|
if self.is_leaf():
|
|
92
128
|
return 0
|
|
93
129
|
return len(self.children) + sum(child.n_children_recursively() for child in self.children.values())
|
|
@@ -97,6 +133,14 @@ class GymctsNode:
|
|
|
97
133
|
parent: TGymctsNode | None,
|
|
98
134
|
env_reference: GymctsABC,
|
|
99
135
|
):
|
|
136
|
+
"""
|
|
137
|
+
Initializes the node. The node is initialized with the state of the environment and the action that was taken to
|
|
138
|
+
reach the node. The node is also initialized with the parent node and the environment reference.
|
|
139
|
+
|
|
140
|
+
:param action: the action that was taken to reach the node. If the node is a root node, this parameter is None.
|
|
141
|
+
:param parent: the parent node of the node. If the node is a root node, this parameter is None.
|
|
142
|
+
:param env_reference: a reference to the environment. The environment is used to get the state of the node and the valid actions.
|
|
143
|
+
"""
|
|
100
144
|
|
|
101
145
|
# field depending on whether the node is a root node or not
|
|
102
146
|
self.action: int | None
|
|
@@ -149,21 +193,49 @@ class GymctsNode:
|
|
|
149
193
|
self.parent.reset()
|
|
150
194
|
|
|
151
195
|
def is_root(self) -> bool:
|
|
196
|
+
"""
|
|
197
|
+
Returns true if the node is a root node. A root node is a node that has no parent.
|
|
198
|
+
|
|
199
|
+
:return: true if the node is a root node, false otherwise.
|
|
200
|
+
"""
|
|
152
201
|
return self.parent is None
|
|
153
202
|
|
|
154
203
|
def is_leaf(self) -> bool:
|
|
204
|
+
"""
|
|
205
|
+
Returns true if the node is a leaf node. A leaf node is a node that has no children.
|
|
206
|
+
|
|
207
|
+
:return: true if the node is a leaf node, false otherwise.
|
|
208
|
+
"""
|
|
155
209
|
return self.children is None or len(self.children) == 0
|
|
156
210
|
|
|
157
211
|
def get_random_child(self) -> TGymctsNode:
|
|
212
|
+
"""
|
|
213
|
+
Returns a random child of the node. A random child is a child that is selected randomly from the list of children.
|
|
214
|
+
:return:
|
|
215
|
+
"""
|
|
158
216
|
if self.is_leaf():
|
|
159
217
|
raise ValueError("cannot get random child of leaf node") # todo: maybe return self instead?
|
|
160
218
|
|
|
161
219
|
return list(self.children.values())[random.randint(0, len(self.children) - 1)]
|
|
162
220
|
|
|
163
221
|
def get_best_action(self) -> int:
|
|
222
|
+
"""
|
|
223
|
+
Returns the best action of the node. The best action is the action that has the highest score.
|
|
224
|
+
The score is calculated using the get_score() method.
|
|
225
|
+
The best action is the action that has the highest score.
|
|
226
|
+
|
|
227
|
+
:return: the best action of the node.
|
|
228
|
+
"""
|
|
164
229
|
return max(self.children.values(), key=lambda child: child.get_score()).action
|
|
165
230
|
|
|
166
231
|
def get_score(self) -> float: # todo: make it an attribute?
|
|
232
|
+
"""
|
|
233
|
+
Returns the score of the node. The score is calculated using the mean value and the maximum value of the node.
|
|
234
|
+
The score is calculated using the formula: score = (1 - a) * mean_value + a * max_value
|
|
235
|
+
where a is the best action weight.
|
|
236
|
+
|
|
237
|
+
:return: the score of the node.
|
|
238
|
+
"""
|
|
167
239
|
# return self.mean_value
|
|
168
240
|
assert 0 <= GymctsNode.best_action_weight <= 1
|
|
169
241
|
a = GymctsNode.best_action_weight
|
|
@@ -173,6 +245,11 @@ class GymctsNode:
|
|
|
173
245
|
return self.mean_value
|
|
174
246
|
|
|
175
247
|
def get_max_value(self) -> float:
|
|
248
|
+
"""
|
|
249
|
+
Returns the maximum value of the node, i.e. the value stored in its max_value attribute.
|
|
250
|
+
|
|
251
|
+
:return: the maximum value of the node.
|
|
252
|
+
"""
|
|
176
253
|
return self.max_value
|
|
177
254
|
|
|
178
255
|
def ucb_score(self):
|
gymcts/gymcts_tree_plotter.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from typing import Any, Generator
|
|
2
|
+
|
|
1
3
|
from gymcts.gymcts_node import GymctsNode
|
|
2
4
|
|
|
3
5
|
from gymcts.logger import log
|
|
@@ -9,7 +11,19 @@ def _generate_mcts_tree(
|
|
|
9
11
|
depth: int = None,
|
|
10
12
|
exclude_unvisited_nodes_from_render: bool = True,
|
|
11
13
|
action_space_n: int = None
|
|
12
|
-
) ->
|
|
14
|
+
) -> Generator[str, Any | None, None]:
|
|
15
|
+
"""
|
|
16
|
+
Generates a tree representation of the MCTS tree starting from the given node.
|
|
17
|
+
|
|
18
|
+
This is a recursive function that generates a tree representation of the MCTS tree starting from the given node.
|
|
19
|
+
|
|
20
|
+
:param start_node: the node to start from
|
|
21
|
+
:param prefix: used to format the tree
|
|
22
|
+
:param depth: used to limit the depth of the tree
|
|
23
|
+
:param exclude_unvisited_nodes_from_render: used to exclude unvisited nodes from the render
|
|
24
|
+
:param action_space_n: the number of actions in the action space
|
|
25
|
+
:return: a generator that yields the lines (strings) of the tree representation
|
|
26
|
+
"""
|
|
13
27
|
if prefix is None:
|
|
14
28
|
prefix = ""
|
|
15
29
|
import gymcts.colorful_console_utils as ccu
|
|
@@ -70,6 +84,13 @@ def show_mcts_tree(
|
|
|
70
84
|
tree_max_depth: int = None,
|
|
71
85
|
action_space_n: int = None
|
|
72
86
|
) -> None:
|
|
87
|
+
"""
|
|
88
|
+
Renders the MCTS tree starting from the given node.
|
|
89
|
+
|
|
90
|
+
:param start_node: the node to start from
|
|
91
|
+
:param tree_max_depth: the maximum depth of the tree to render
|
|
92
|
+
:param action_space_n: the number of actions in the action space
|
|
93
|
+
"""
|
|
73
94
|
print(start_node.__str__(colored=True, action_space_n=action_space_n))
|
|
74
95
|
for line in _generate_mcts_tree(start_node=start_node, depth=tree_max_depth):
|
|
75
96
|
print(line)
|
gymcts/logger.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: gymcts
|
|
3
|
-
Version: 1.2.
|
|
3
|
+
Version: 1.2.1
|
|
4
4
|
Summary: A minimalistic implementation of the Monte Carlo Tree Search algorithm for planning problems formulated as gymnasium reinforcement learning environments.
|
|
5
5
|
Author: Alexander Nasuta
|
|
6
6
|
Author-email: Alexander Nasuta <alexander.nasuta@wzl-iqs.rwth-aachen.de>
|
|
@@ -47,7 +47,7 @@ Requires-Dist: graph-matrix-jsp-env; extra == "examples"
|
|
|
47
47
|
Requires-Dist: graph-jsp-env; extra == "examples"
|
|
48
48
|
Provides-Extra: dev
|
|
49
49
|
Requires-Dist: jsp-instance-utils; extra == "dev"
|
|
50
|
-
Requires-Dist: graph-matrix-jsp-env; extra == "dev"
|
|
50
|
+
Requires-Dist: graph-matrix-jsp-env>=0.3.0; extra == "dev"
|
|
51
51
|
Requires-Dist: graph-jsp-env; extra == "dev"
|
|
52
52
|
Requires-Dist: JSSEnv; extra == "dev"
|
|
53
53
|
Requires-Dist: pip-tools; extra == "dev"
|
|
@@ -59,21 +59,24 @@ Requires-Dist: stable_baselines3; extra == "dev"
|
|
|
59
59
|
Requires-Dist: sphinx; extra == "dev"
|
|
60
60
|
Requires-Dist: myst-parser; extra == "dev"
|
|
61
61
|
Requires-Dist: sphinx-autobuild; extra == "dev"
|
|
62
|
+
Requires-Dist: sphinx-copybutton; extra == "dev"
|
|
62
63
|
Requires-Dist: furo; extra == "dev"
|
|
63
64
|
Requires-Dist: twine; extra == "dev"
|
|
64
65
|
Requires-Dist: sphinx-copybutton; extra == "dev"
|
|
65
66
|
Requires-Dist: nbsphinx; extra == "dev"
|
|
67
|
+
Requires-Dist: pandoc; extra == "dev"
|
|
66
68
|
Requires-Dist: jupytext; extra == "dev"
|
|
67
69
|
Requires-Dist: jupyter; extra == "dev"
|
|
70
|
+
Requires-Dist: typing_extensions>=4.12.0; extra == "dev"
|
|
68
71
|
Dynamic: license-file
|
|
69
72
|
|
|
70
73
|
# Graph Matrix Job Shop Env
|
|
71
74
|
|
|
72
75
|
A Monte Carlo Tree Search Implementation for Gymnasium-style Environments.
|
|
73
76
|
|
|
74
|
-
- Github: [GYMCTS on Github](https://github.com/Alexander-Nasuta/
|
|
75
|
-
- Pypi: [GYMCTS on PyPi](https://pypi.org/project/
|
|
76
|
-
- Documentation: [GYMCTS Docs](https://
|
|
77
|
+
- Github: [GYMCTS on Github](https://github.com/Alexander-Nasuta/gymcts)
|
|
78
|
+
- Pypi: [GYMCTS on PyPi](https://pypi.org/project/gymcts/)
|
|
79
|
+
- Documentation: [GYMCTS Docs](https://gymcts.readthedocs.io/en/latest/)
|
|
77
80
|
|
|
78
81
|
## Description
|
|
79
82
|
|
|
@@ -101,22 +104,26 @@ The usage of a MCTS agent can roughly organised into the following steps:
|
|
|
101
104
|
- Render the solution
|
|
102
105
|
|
|
103
106
|
The GYMCTS package provides two types of wrappers for Gymnasium-style environments:
|
|
104
|
-
- `
|
|
105
|
-
- `
|
|
107
|
+
- `DeepCopyMCTSGymEnvWrapper`: A wrapper that uses deepcopies of the environment to save a snapshot of the environment state for each node in the MCTS tree.
|
|
108
|
+
- `ActionHistoryMCTSGymEnvWrapper`: A wrapper that saves the action sequence that led to the current state in the MCTS node.
|
|
106
109
|
|
|
107
|
-
These wrappers can be used with the `
|
|
108
|
-
The wrapper implement methods that are required by the `
|
|
110
|
+
These wrappers can be used with the `GymctsAgent` to solve the environment.
|
|
111
|
+
The wrappers implement the methods that are required by the `GymctsAgent` to interact with the environment.
|
|
109
112
|
GYMCTS is designed to use a single environment instance and to reconstruct the environment state from a state snapshot when needed.
|
|
110
113
|
|
|
111
114
|
NOTE: MCTS works best when the return of an episode is in the range of [-1, 1]. Please adjust the reward function of the environment accordingly (or change the UCB-scaling parameter of the MCTS agent).
|
|
112
115
|
Adjusting the reward function of the environment is easily done with a [NormalizeReward](https://gymnasium.farama.org/api/wrappers/reward_wrappers/#gymnasium.wrappers.NormalizeReward) or [TransformReward](https://gymnasium.farama.org/api/wrappers/reward_wrappers/#gymnasium.wrappers.TransformReward) Wrapper.
|
|
116
|
+
```python
|
|
117
|
+
env = NormalizeReward(env, gamma=0.99, epsilon=1e-8)
|
|
118
|
+
```
|
|
113
119
|
|
|
114
|
-
|
|
115
|
-
env = TransformReward(env, lambda r: r /
|
|
116
|
-
|
|
120
|
+
```python
|
|
121
|
+
env = TransformReward(env, lambda r: r / n_steps_per_episode)
|
|
122
|
+
```
|
|
123
|
+
### FrozenLake Example (DeepCopyMCTSGymEnvWrapper)
|
|
117
124
|
|
|
118
125
|
A minimal example of how to use the package with the FrozenLake environment and the DeepCopyMCTSGymEnvWrapper is provided in the code snippet below.
|
|
119
|
-
The
|
|
126
|
+
The DeepCopyMCTSGymEnvWrapper can be used with non-deterministic environments, such as the FrozenLake environment with slippery ice.
|
|
120
127
|
|
|
121
128
|
```python
|
|
122
129
|
import gymnasium as gym
|
|
@@ -135,7 +142,7 @@ if __name__ == '__main__':
|
|
|
135
142
|
env = gym.make('FrozenLake-v1', desc=None, map_name="4x4", is_slippery=True, render_mode="ansi")
|
|
136
143
|
env.reset()
|
|
137
144
|
|
|
138
|
-
# 1. wrap the environment with the
|
|
145
|
+
# 1. wrap the environment with the deep copy wrapper or a custom gymcts wrapper
|
|
139
146
|
env = DeepCopyMCTSGymEnvWrapper(env)
|
|
140
147
|
|
|
141
148
|
# 2. create the agent
|
|
@@ -158,7 +165,7 @@ if __name__ == '__main__':
|
|
|
158
165
|
|
|
159
166
|
# 5. print the solution
|
|
160
167
|
# read the solution from the info provided by the RecordEpisodeStatistics wrapper
|
|
161
|
-
# (that
|
|
168
|
+
# (that DeepCopyMCTSGymEnvWrapper uses internally)
|
|
162
169
|
episode_length = info["episode"]["l"]
|
|
163
170
|
episode_return = info["episode"]["r"]
|
|
164
171
|
|
|
@@ -251,7 +258,7 @@ if __name__ == '__main__':
|
|
|
251
258
|
env = gym.make('FrozenLake-v1', desc=None, map_name="4x4", is_slippery=False, render_mode="rgb_array")
|
|
252
259
|
env.reset()
|
|
253
260
|
|
|
254
|
-
# 1. wrap the environment with the
|
|
261
|
+
# 1. wrap the environment with the deep copy wrapper or a custom gymcts wrapper
|
|
255
262
|
env = DeepCopyMCTSGymEnvWrapper(env)
|
|
256
263
|
|
|
257
264
|
# 2. create the agent
|
|
@@ -280,7 +287,7 @@ if __name__ == '__main__':
|
|
|
280
287
|
env.close()
|
|
281
288
|
|
|
282
289
|
# 5. print the solution
|
|
283
|
-
# read the solution from the info provided by the RecordEpisodeStatistics wrapper (that
|
|
290
|
+
# read the solution from the info provided by the RecordEpisodeStatistics wrapper (that DeepCopyMCTSGymEnvWrapper wraps internally)
|
|
284
291
|
episode_length = info["episode"]["l"]
|
|
285
292
|
episode_return = info["episode"]["r"]
|
|
286
293
|
|
|
@@ -321,13 +328,13 @@ import gymnasium as gym
|
|
|
321
328
|
from graph_jsp_env.disjunctive_graph_jsp_env import DisjunctiveGraphJspEnv
|
|
322
329
|
from jsp_instance_utils.instances import ft06, ft06_makespan
|
|
323
330
|
|
|
324
|
-
from gymcts.gymcts_agent import
|
|
325
|
-
from gymcts.
|
|
331
|
+
from gymcts.gymcts_agent import GymctsAgent
|
|
332
|
+
from gymcts.gymcts_env_abc import GymctsABC
|
|
326
333
|
|
|
327
334
|
from gymcts.logger import log
|
|
328
335
|
|
|
329
336
|
|
|
330
|
-
class GraphJspGYMCTSWrapper(
|
|
337
|
+
class GraphJspGYMCTSWrapper(GymctsABC, gym.Wrapper):
|
|
331
338
|
|
|
332
339
|
def __init__(self, env: DisjunctiveGraphJspEnv):
|
|
333
340
|
gym.Wrapper.__init__(self, env)
|
|
@@ -378,7 +385,7 @@ if __name__ == '__main__':
|
|
|
378
385
|
|
|
379
386
|
env = GraphJspGYMCTSWrapper(env)
|
|
380
387
|
|
|
381
|
-
agent =
|
|
388
|
+
agent = GymctsAgent(
|
|
382
389
|
env=env,
|
|
383
390
|
clear_mcts_tree_after_step=True,
|
|
384
391
|
render_tree_after_step=True,
|
|
@@ -421,7 +428,6 @@ import gymnasium as gym
|
|
|
421
428
|
|
|
422
429
|
from gymcts.gymcts_agent import GymctsAgent
|
|
423
430
|
from gymcts.gymcts_action_history_wrapper import ActionHistoryMCTSGymEnvWrapper
|
|
424
|
-
from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
|
|
425
431
|
|
|
426
432
|
from gymcts.logger import log
|
|
427
433
|
|
|
@@ -434,7 +440,7 @@ if __name__ == '__main__':
|
|
|
434
440
|
env = gym.make('FrozenLake-v1', desc=None, map_name="4x4", is_slippery=False, render_mode="ansi")
|
|
435
441
|
env.reset()
|
|
436
442
|
|
|
437
|
-
# wrap the environment with the
|
|
443
|
+
# wrap the environment with the wrapper or a custom gymcts wrapper
|
|
438
444
|
env = ActionHistoryMCTSGymEnvWrapper(env)
|
|
439
445
|
|
|
440
446
|
# create the agent
|
|
@@ -505,11 +511,11 @@ clone the repository in your favorite code editor (for example PyCharm, VSCode,
|
|
|
505
511
|
|
|
506
512
|
using https:
|
|
507
513
|
```shell
|
|
508
|
-
git clone https://github.com/Alexander-Nasuta/
|
|
514
|
+
git clone https://github.com/Alexander-Nasuta/gymcts.git
|
|
509
515
|
```
|
|
510
516
|
or by using the GitHub CLI:
|
|
511
517
|
```shell
|
|
512
|
-
gh repo clone Alexander-Nasuta/
|
|
518
|
+
gh repo clone Alexander-Nasuta/gymcts
|
|
513
519
|
```
|
|
514
520
|
|
|
515
521
|
if you are using PyCharm, I recommend doing the following additional steps:
|
|
@@ -518,9 +524,6 @@ if you are using PyCharm, I recommend doing the following additional steps:
|
|
|
518
524
|
- mark the `tests` folder as test root (by right-clicking on the folder and selecting `Mark Directory as` -> `Test Sources Root`)
|
|
519
525
|
- mark the `resources` folder as resources root (by right-clicking on the folder and selecting `Mark Directory as` -> `Resources Root`)
|
|
520
526
|
|
|
521
|
-
at the end your project structure should look like this:
|
|
522
|
-
|
|
523
|
-
todo
|
|
524
527
|
|
|
525
528
|
### Create a Virtual Environment (optional)
|
|
526
529
|
|
|
@@ -586,12 +589,6 @@ For testing with `tox` run the following command:
|
|
|
586
589
|
tox
|
|
587
590
|
```
|
|
588
591
|
|
|
589
|
-
Here is a screenshot of what the output might look like:
|
|
590
|
-
|
|
591
|
-

|
|
592
|
-
|
|
593
|
-
Tox will run the tests in a separate environment and will also check if the requirements are installed correctly.
|
|
594
|
-
|
|
595
592
|
### Building and Publishing the Project to PyPi
|
|
596
593
|
|
|
597
594
|
In order to publish the project to PyPi, the project needs to be built and then uploaded to PyPi.
|
|
@@ -630,7 +627,6 @@ sphinx-autobuild ./docs/source/ ./docs/build/html/
|
|
|
630
627
|
This project features most of the extensions featured in this Tutorial: [Document Your Scientific Project With Markdown, Sphinx, and Read the Docs | PyData Global 2021](https://www.youtube.com/watch?v=qRSb299awB0).
|
|
631
628
|
|
|
632
629
|
|
|
633
|
-
|
|
634
630
|
## Contact
|
|
635
631
|
|
|
636
632
|
If you have any questions or feedback, feel free to contact me via [email](mailto:alexander.nasuta@wzl-iqs.rwth-aachen.de) or open an issue on the repository.
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
gymcts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
gymcts/colorful_console_utils.py,sha256=n7nymC8kKZnA_8nXcdn201NAzjZjgEHfKpbBcnl4oAE,5891
|
|
3
|
+
gymcts/gymcts_action_history_wrapper.py,sha256=7-p17Fgb80SRCBaCm6G8SJrEPsl2Y4aIO3InviuQP08,6993
|
|
4
|
+
gymcts/gymcts_agent.py,sha256=f2imP-Wv-E7EYE0-iWd86hY9cx-rqHZMlDusp-aE-ps,8698
|
|
5
|
+
gymcts/gymcts_deepcopy_wrapper.py,sha256=lCCT5-6JVCwUCP__4uPMMkT5HnO2JWm2ebzJ69zXp9c,6792
|
|
6
|
+
gymcts/gymcts_distributed_agent.py,sha256=Ha9UBQvFjoErfMWvPyN0JcTYz-JaiJ4eWjLMikp9Yhs,11569
|
|
7
|
+
gymcts/gymcts_env_abc.py,sha256=U1mPz0NWZZL1sdHX7oUP1UFKtmbHwyqHQOQidyh_Uck,2107
|
|
8
|
+
gymcts/gymcts_node.py,sha256=pxjY2Zb0kPuFQ5mWEs0ct3qXoyB47NZK7h2ZGbLJbRA,11052
|
|
9
|
+
gymcts/gymcts_tree_plotter.py,sha256=PR6C7q9Q4kuz1aLGyD7-aZsxk3RqlHZpOqmOiRpCyK0,3547
|
|
10
|
+
gymcts/logger.py,sha256=RI7B9cvbBGrj0_QIAI77wihzuu2tPG_-z9GM2Mw5aHE,926
|
|
11
|
+
gymcts-1.2.1.dist-info/licenses/LICENSE,sha256=UGe75WojDiw_77SEnK2aysEDlElRlkWie7U7NaAFx00,1072
|
|
12
|
+
gymcts-1.2.1.dist-info/METADATA,sha256=wUJEcWrAvdC42kl59qewCN5tK3DKMLxGWcCipnOX4pQ,23371
|
|
13
|
+
gymcts-1.2.1.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
|
|
14
|
+
gymcts-1.2.1.dist-info/top_level.txt,sha256=E8MoLsPimUPD0H1Y6lum4TVe-lhSDAyBAXGrkYIT52w,7
|
|
15
|
+
gymcts-1.2.1.dist-info/RECORD,,
|
gymcts-1.2.0.dist-info/RECORD
DELETED
|
@@ -1,15 +0,0 @@
|
|
|
1
|
-
gymcts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
gymcts/colorful_console_utils.py,sha256=OhULcXHKbEA4uJDAEYCTcW6wUv0LsHX_XSYzZ_Szsv4,4553
|
|
3
|
-
gymcts/gymcts_action_history_wrapper.py,sha256=AjvBBwd1t9-nTYP09aMdlScAkFNXf5vOagejpjWYOPo,3810
|
|
4
|
-
gymcts/gymcts_agent.py,sha256=O2y98jKFjR5TzqVV7DO1jlcYDyzAgd_H2RF4-w4NP0g,8499
|
|
5
|
-
gymcts/gymcts_deepcopy_wrapper.py,sha256=OleQTnvxv3gLEo8-2asyeo-CpZ4HEbgyFGS5DTCD7NM,4167
|
|
6
|
-
gymcts/gymcts_distributed_agent.py,sha256=M7dyBfC8u3M99PJFoXKgIc_CPTyHGppmktkH-y9ci4U,10448
|
|
7
|
-
gymcts/gymcts_env_abc.py,sha256=7nCRiiClmmVLX-d_Q1dxeztmuvmAtmWZwjT81zrG1_w,575
|
|
8
|
-
gymcts/gymcts_node.py,sha256=PT_YZFwt1zjuvd8i9Wb5LEkHAqmJOFyPDp3GFD05lqM,7138
|
|
9
|
-
gymcts/gymcts_tree_plotter.py,sha256=eg207wHcDepwWODXzmDYQn1Aai29Cs4jFS1HNvAhlXs,2651
|
|
10
|
-
gymcts/logger.py,sha256=nAkUa4djiuCR7hF0EUsplhqFHCp76QcOX1cV3lIPzOI,937
|
|
11
|
-
gymcts-1.2.0.dist-info/licenses/LICENSE,sha256=UGe75WojDiw_77SEnK2aysEDlElRlkWie7U7NaAFx00,1072
|
|
12
|
-
gymcts-1.2.0.dist-info/METADATA,sha256=zhEIFo0rOnv5hCv6ukImkq-9nshO4EfXMbHlhNlYhyA,23640
|
|
13
|
-
gymcts-1.2.0.dist-info/WHEEL,sha256=DK49LOLCYiurdXXOXwGJm6U4DkHkg4lcxjhqwRa0CP4,91
|
|
14
|
-
gymcts-1.2.0.dist-info/top_level.txt,sha256=E8MoLsPimUPD0H1Y6lum4TVe-lhSDAyBAXGrkYIT52w,7
|
|
15
|
-
gymcts-1.2.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|