gymcts 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gymcts/colorful_console_utils.py +22 -0
- gymcts/gymcts_action_history_wrapper.py +72 -2
- gymcts/gymcts_agent.py +54 -7
- gymcts/gymcts_deepcopy_wrapper.py +59 -2
- gymcts/gymcts_distributed_agent.py +30 -12
- gymcts/gymcts_env_abc.py +45 -2
- gymcts/gymcts_neural_agent.py +479 -0
- gymcts/gymcts_node.py +161 -17
- gymcts/gymcts_tree_plotter.py +22 -1
- gymcts/logger.py +1 -4
- {gymcts-1.2.0.dist-info → gymcts-1.3.0.dist-info}/METADATA +39 -39
- gymcts-1.3.0.dist-info/RECORD +16 -0
- {gymcts-1.2.0.dist-info → gymcts-1.3.0.dist-info}/WHEEL +1 -1
- gymcts-1.2.0.dist-info/RECORD +0 -15
- {gymcts-1.2.0.dist-info → gymcts-1.3.0.dist-info}/licenses/LICENSE +0 -0
- {gymcts-1.2.0.dist-info → gymcts-1.3.0.dist-info}/top_level.txt +0 -0
gymcts/gymcts_node.py
CHANGED
|
@@ -2,7 +2,7 @@ import uuid
|
|
|
2
2
|
import random
|
|
3
3
|
import math
|
|
4
4
|
|
|
5
|
-
from typing import TypeVar, Any, SupportsFloat, Callable, Generator
|
|
5
|
+
from typing import TypeVar, Any, SupportsFloat, Callable, Generator, Literal
|
|
6
6
|
|
|
7
7
|
from gymcts.gymcts_env_abc import GymctsABC
|
|
8
8
|
|
|
@@ -13,22 +13,55 @@ TGymctsNode = TypeVar("TGymctsNode", bound="GymctsNode")
|
|
|
13
13
|
|
|
14
14
|
class GymctsNode:
|
|
15
15
|
# static properties
|
|
16
|
-
best_action_weight: float = 0.05
|
|
17
|
-
ubc_c = 0.707
|
|
16
|
+
best_action_weight: float = 0.05 # weight for the best action
|
|
17
|
+
ubc_c = 0.707 # exploration coefficient
|
|
18
|
+
|
|
19
|
+
"""
|
|
20
|
+
UCT (Upper Confidence Bound applied to Trees) exploration terms:
|
|
21
|
+
|
|
22
|
+
UCT_v0:
|
|
23
|
+
c * √( 2 * ln(N(s)) / N(s,a) )
|
|
24
|
+
|
|
25
|
+
UCT_v1:
|
|
26
|
+
c * √( ln(N(s)) / (1 + N(s,a)) )
|
|
27
|
+
|
|
28
|
+
UCT_v2:
|
|
29
|
+
c * ( √(N(s)) / (1 + N(s,a)) )
|
|
30
|
+
|
|
31
|
+
Where:
|
|
32
|
+
N(s) = number of times state s has been visited
|
|
33
|
+
N(s,a) = number of times action a was taken from state s
|
|
34
|
+
c = exploration constant
|
|
35
|
+
"""
|
|
36
|
+
score_variate: Literal["UCT_v0", "UCT_v1", "UCT_v2",] = "UCT_v0"
|
|
37
|
+
|
|
38
|
+
|
|
18
39
|
|
|
19
40
|
# attributes
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
41
|
+
#
|
|
42
|
+
# Note: these attributes are not static. They are defined here to give developers a hint about which fields are available
|
|
43
|
+
# in the class. They are not static because they are not shared between instances of the class in scope of
|
|
44
|
+
# this library.
|
|
45
|
+
visit_count: int = 0 # number of times the node has been visited
|
|
46
|
+
mean_value: float = 0 # mean value of the node
|
|
47
|
+
max_value: float = -float("inf") # maximum value of the node
|
|
48
|
+
min_value: float = +float("inf") # minimum value of the node
|
|
49
|
+
terminal: bool = False # whether the node is terminal or not
|
|
50
|
+
state: Any = None # state of the node
|
|
26
51
|
|
|
27
52
|
def __str__(self, colored=False, action_space_n=None) -> str:
|
|
53
|
+
"""
|
|
54
|
+
Returns a string representation of the node. The string representation is used for visualisation purposes.
|
|
55
|
+
It is used for example in the mcts tree visualisation functionality.
|
|
56
|
+
|
|
57
|
+
:param colored: True if the string representation should be colored, False otherwise (True is used by the MCTS tree visualisation).
|
|
58
|
+
:param action_space_n: the number of actions in the action space. This is used for coloring the action in the string representation.
|
|
59
|
+
:return: a potentially colored string representation of the node.
|
|
60
|
+
"""
|
|
28
61
|
if not colored:
|
|
29
62
|
|
|
30
63
|
if not self.is_root():
|
|
31
|
-
return f"(a={self.action}, N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f}, ubc={self.
|
|
64
|
+
return f"(a={self.action}, N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f}, ubc={self.tree_policy_score():.2f})"
|
|
32
65
|
else:
|
|
33
66
|
return f"(N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f}) [root]"
|
|
34
67
|
|
|
@@ -69,25 +102,47 @@ class GymctsNode:
|
|
|
69
102
|
f"{p}N{e}={colorful_value(self.visit_count)}, "
|
|
70
103
|
f"{p}Q_v{e}={ccu.wrap_with_color_scale(s=mean_val, value=self.mean_value, min_val=root_node.min_value, max_val=root_node.max_value)}, "
|
|
71
104
|
f"{p}best{e}={colorful_value(self.max_value)}") +
|
|
72
|
-
(f", {p}ubc{e}={colorful_value(self.
|
|
105
|
+
(f", {p}ubc{e}={colorful_value(self.tree_policy_score())})" if not self.is_root() else ")"))
|
|
73
106
|
|
|
74
107
|
def traverse_nodes(self) -> Generator[TGymctsNode, None, None]:
|
|
108
|
+
"""
|
|
109
|
+
Traverse the tree and yield all nodes in the tree.
|
|
110
|
+
|
|
111
|
+
:return: a generator that yields all nodes in the tree.
|
|
112
|
+
"""
|
|
75
113
|
yield self
|
|
76
114
|
if self.children:
|
|
77
115
|
for child in self.children.values():
|
|
78
116
|
yield from child.traverse_nodes()
|
|
79
117
|
|
|
80
118
|
def get_root(self) -> TGymctsNode:
|
|
119
|
+
"""
|
|
120
|
+
Returns the root node of the tree. The root node is the node that has no parent.
|
|
121
|
+
|
|
122
|
+
:return: the root node of the tree.
|
|
123
|
+
"""
|
|
81
124
|
if self.is_root():
|
|
82
125
|
return self
|
|
83
126
|
return self.parent.get_root()
|
|
84
127
|
|
|
85
128
|
def max_tree_depth(self):
|
|
129
|
+
"""
|
|
130
|
+
Returns the maximum depth of the subtree rooted at this node, i.e. the number of edges on the longest path from
|

131
|
+
this node down to a leaf (0 if this node is itself a leaf).
|
|
132
|
+
|
|
133
|
+
:return: the maximum depth of the subtree rooted at this node.
|
|
134
|
+
"""
|
|
86
135
|
if self.is_leaf():
|
|
87
136
|
return 0
|
|
88
137
|
return 1 + max(child.max_tree_depth() for child in self.children.values())
|
|
89
138
|
|
|
90
139
|
def n_children_recursively(self):
|
|
140
|
+
"""
|
|
141
|
+
Returns the number of children of the node recursively. The number of children of a node is the number of
|
|
142
|
+
children of the node plus the number of children of all children of the node.
|
|
143
|
+
|
|
144
|
+
:return: the number of children of the node recursively.
|
|
145
|
+
"""
|
|
91
146
|
if self.is_leaf():
|
|
92
147
|
return 0
|
|
93
148
|
return len(self.children) + sum(child.n_children_recursively() for child in self.children.values())
|
|
@@ -97,6 +152,14 @@ class GymctsNode:
|
|
|
97
152
|
parent: TGymctsNode | None,
|
|
98
153
|
env_reference: GymctsABC,
|
|
99
154
|
):
|
|
155
|
+
"""
|
|
156
|
+
Initializes the node. The node is initialized with the state of the environment and the action that was taken to
|
|
157
|
+
reach the node. The node is also initialized with the parent node and the environment reference.
|
|
158
|
+
|
|
159
|
+
:param action: the action that was taken to reach the node. If the node is a root node, this parameter is None.
|
|
160
|
+
:param parent: the parent node of the node. If the node is a root node, this parameter is None.
|
|
161
|
+
:param env_reference: a reference to the environment. The environment is used to get the state of the node and the valid actions.
|
|
162
|
+
"""
|
|
100
163
|
|
|
101
164
|
# field depending on whether the node is a root node or not
|
|
102
165
|
self.action: int | None
|
|
@@ -148,22 +211,56 @@ class GymctsNode:
|
|
|
148
211
|
if self.parent:
|
|
149
212
|
self.parent.reset()
|
|
150
213
|
|
|
214
|
+
def remove_parent(self) -> None:
|
|
215
|
+
self.parent = None
|
|
216
|
+
|
|
217
|
+
if self.parent is not None:
|
|
218
|
+
self.parent.remove_parent()
|
|
219
|
+
|
|
151
220
|
def is_root(self) -> bool:
|
|
221
|
+
"""
|
|
222
|
+
Returns true if the node is a root node. A root node is a node that has no parent.
|
|
223
|
+
|
|
224
|
+
:return: true if the node is a root node, false otherwise.
|
|
225
|
+
"""
|
|
152
226
|
return self.parent is None
|
|
153
227
|
|
|
154
228
|
def is_leaf(self) -> bool:
|
|
229
|
+
"""
|
|
230
|
+
Returns true if the node is a leaf node. A leaf node is a node that has no children.
|
|
231
|
+
|
|
232
|
+
:return: true if the node is a leaf node, false otherwise.
|
|
233
|
+
"""
|
|
155
234
|
return self.children is None or len(self.children) == 0
|
|
156
235
|
|
|
157
236
|
def get_random_child(self) -> TGymctsNode:
|
|
237
|
+
"""
|
|
238
|
+
Returns a random child of the node. A random child is a child that is selected randomly from the list of children.
|
|
239
|
+
:return: a child node chosen uniformly at random. Raises ValueError if the node is a leaf.
|
|
240
|
+
"""
|
|
158
241
|
if self.is_leaf():
|
|
159
242
|
raise ValueError("cannot get random child of leaf node") # todo: maybe return self instead?
|
|
160
243
|
|
|
161
244
|
return list(self.children.values())[random.randint(0, len(self.children) - 1)]
|
|
162
245
|
|
|
163
246
|
def get_best_action(self) -> int:
|
|
247
|
+
"""
|
|
248
|
+
Returns the best action of the node. The best action is the action that has the highest score.
|
|
249
|
+
The score of each child is calculated using its get_score() method.
|
|
250
|
+
If several children share the highest score, the first one in iteration order is chosen (behaviour of max()).
|
|
251
|
+
|
|
252
|
+
:return: the best action of the node.
|
|
253
|
+
"""
|
|
164
254
|
return max(self.children.values(), key=lambda child: child.get_score()).action
|
|
165
255
|
|
|
166
256
|
def get_score(self) -> float: # todo: make it an attribute?
|
|
257
|
+
"""
|
|
258
|
+
Returns the score of the node. The score is calculated using the mean value and the maximum value of the node.
|
|
259
|
+
The score is calculated using the formula: score = (1 - a) * mean_value + a * max_value
|
|
260
|
+
where a is the best action weight.
|
|
261
|
+
|
|
262
|
+
:return: the score of the node.
|
|
263
|
+
"""
|
|
167
264
|
# return self.mean_value
|
|
168
265
|
assert 0 <= GymctsNode.best_action_weight <= 1
|
|
169
266
|
a = GymctsNode.best_action_weight
|
|
@@ -173,11 +270,46 @@ class GymctsNode:
|
|
|
173
270
|
return self.mean_value
|
|
174
271
|
|
|
175
272
|
def get_max_value(self) -> float:
|
|
273
|
+
"""
|
|
274
|
+
Returns the maximum value of the node, as stored in its max_value attribute.
|
|
275
|
+
|
|
276
|
+
:return: the maximum value of the node.
|
|
277
|
+
"""
|
|
176
278
|
return self.max_value
|
|
177
279
|
|
|
178
|
-
def
|
|
280
|
+
def tree_policy_score(self):
|
|
179
281
|
"""
|
|
282
|
+
TODO: update docstring
|
|
283
|
+
|
|
180
284
|
The score for an action that would transition between the parent and child.
|
|
285
|
+
For vanilla MCTS, this is the UCB1 score.
|
|
286
|
+
|
|
287
|
+
The UCB1 score is calculated using the formula:
|
|
288
|
+
|
|
289
|
+
UCT (Upper Confidence Bound applied to Trees) exploration terms:
|
|
290
|
+
|
|
291
|
+
UCT_v0:
|
|
292
|
+
c * √( 2 * ln(N(s)) / N(s,a) )
|
|
293
|
+
|
|
294
|
+
UCT_v1:
|
|
295
|
+
c * √( ln(N(s)) / (1 + N(s,a)) )
|
|
296
|
+
|
|
297
|
+
UCT_v2:
|
|
298
|
+
c * ( √(N(s)) / (1 + N(s,a)) )
|
|
299
|
+
|
|
300
|
+
Where:
|
|
301
|
+
N(s) = number of times state s has been visited
|
|
302
|
+
N(s,a) = number of times action a was taken from state s
|
|
303
|
+
c = exploration constant
|
|
304
|
+
|
|
305
|
+
where:
|
|
306
|
+
- mean_value is the mean value of the node
|
|
307
|
+
- c is a constant that controls the exploration-exploitation trade-off (GymctsNode.ubc_c)
|
|
308
|
+
- parent_visit_count is the number of times the parent node has been visited
|
|
309
|
+
- visit_count is the number of times the node has been visited
|
|
310
|
+
|
|
311
|
+
For UCT_v0, if the node has not been visited yet, the score is set to infinity.
|
|
312
|
+
|
|
181
313
|
prior_score = child.prior * math.sqrt(parent.visit_count) / (child.visit_count + 1)
|
|
182
314
|
|
|
183
315
|
if child.visit_count > 0:
|
|
@@ -192,8 +324,20 @@ class GymctsNode:
|
|
|
192
324
|
"""
|
|
193
325
|
if self.is_root():
|
|
194
326
|
raise ValueError("ucb_score can only be called on non-root nodes")
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
if
|
|
198
|
-
|
|
199
|
-
|
|
327
|
+
c = GymctsNode.ubc_c # default is 0.707
|
|
328
|
+
|
|
329
|
+
if GymctsNode.score_variate == "UCT_v0":
|
|
330
|
+
if self.visit_count == 0:
|
|
331
|
+
return float("inf")
|
|
332
|
+
return self.mean_value + c * math.sqrt( 2 * math.log(self.parent.visit_count) / (self.visit_count))
|
|
333
|
+
|
|
334
|
+
if GymctsNode.score_variate == "UCT_v1":
|
|
335
|
+
return self.mean_value + c * math.sqrt( math.log(self.parent.visit_count) / (1 + self.visit_count))
|
|
336
|
+
|
|
337
|
+
if GymctsNode.score_variate == "UCT_v2":
|
|
338
|
+
return self.mean_value + c * math.sqrt(self.parent.visit_count) / (1 + self.visit_count)
|
|
339
|
+
|
|
340
|
+
raise ValueError(f"unknown score variate: {GymctsNode.score_variate}. ")
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
|
gymcts/gymcts_tree_plotter.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from typing import Any, Generator
|
|
2
|
+
|
|
1
3
|
from gymcts.gymcts_node import GymctsNode
|
|
2
4
|
|
|
3
5
|
from gymcts.logger import log
|
|
@@ -9,7 +11,19 @@ def _generate_mcts_tree(
|
|
|
9
11
|
depth: int = None,
|
|
10
12
|
exclude_unvisited_nodes_from_render: bool = True,
|
|
11
13
|
action_space_n: int = None
|
|
12
|
-
) ->
|
|
14
|
+
) -> Generator[str, Any | None, None]:
|
|
15
|
+
"""
|
|
16
|
+
Generates a tree representation of the MCTS tree starting from the given node.
|
|
17
|
+
|
|
18
|
+
This is a recursive function that generates a tree representation of the MCTS tree starting from the given node.
|
|
19
|
+
|
|
20
|
+
:param start_node: the node to start from
|
|
21
|
+
:param prefix: used to format the tree
|
|
22
|
+
:param depth: used to limit the depth of the tree
|
|
23
|
+
:param exclude_unvisited_nodes_from_render: used to exclude unvisited nodes from the render
|
|
24
|
+
:param action_space_n: the number of actions in the action space
|
|
25
|
+
:return: a generator yielding the lines (strings) of the tree representation
|
|
26
|
+
"""
|
|
13
27
|
if prefix is None:
|
|
14
28
|
prefix = ""
|
|
15
29
|
import gymcts.colorful_console_utils as ccu
|
|
@@ -70,6 +84,13 @@ def show_mcts_tree(
|
|
|
70
84
|
tree_max_depth: int = None,
|
|
71
85
|
action_space_n: int = None
|
|
72
86
|
) -> None:
|
|
87
|
+
"""
|
|
88
|
+
Renders the MCTS tree starting from the given node.
|
|
89
|
+
|
|
90
|
+
:param start_node: the node to start from
|
|
91
|
+
:param tree_max_depth: the maximum depth of the tree to render
|
|
92
|
+
:param action_space_n: the number of actions in the action space
|
|
93
|
+
"""
|
|
73
94
|
print(start_node.__str__(colored=True, action_space_n=action_space_n))
|
|
74
95
|
for line in _generate_mcts_tree(start_node=start_node, depth=tree_max_depth):
|
|
75
96
|
print(line)
|
gymcts/logger.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: gymcts
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.3.0
|
|
4
4
|
Summary: A minimalistic implementation of the Monte Carlo Tree Search algorithm for planning problems formulated as Gymnasium reinforcement learning environments.
|
|
5
5
|
Author: Alexander Nasuta
|
|
6
6
|
Author-email: Alexander Nasuta <alexander.nasuta@wzl-iqs.rwth-aachen.de>
|
|
@@ -47,7 +47,7 @@ Requires-Dist: graph-matrix-jsp-env; extra == "examples"
|
|
|
47
47
|
Requires-Dist: graph-jsp-env; extra == "examples"
|
|
48
48
|
Provides-Extra: dev
|
|
49
49
|
Requires-Dist: jsp-instance-utils; extra == "dev"
|
|
50
|
-
Requires-Dist: graph-matrix-jsp-env; extra == "dev"
|
|
50
|
+
Requires-Dist: graph-matrix-jsp-env>=0.3.0; extra == "dev"
|
|
51
51
|
Requires-Dist: graph-jsp-env; extra == "dev"
|
|
52
52
|
Requires-Dist: JSSEnv; extra == "dev"
|
|
53
53
|
Requires-Dist: pip-tools; extra == "dev"
|
|
@@ -59,21 +59,31 @@ Requires-Dist: stable_baselines3; extra == "dev"
|
|
|
59
59
|
Requires-Dist: sphinx; extra == "dev"
|
|
60
60
|
Requires-Dist: myst-parser; extra == "dev"
|
|
61
61
|
Requires-Dist: sphinx-autobuild; extra == "dev"
|
|
62
|
+
Requires-Dist: sphinx-copybutton; extra == "dev"
|
|
62
63
|
Requires-Dist: furo; extra == "dev"
|
|
63
64
|
Requires-Dist: twine; extra == "dev"
|
|
64
65
|
Requires-Dist: sphinx-copybutton; extra == "dev"
|
|
65
66
|
Requires-Dist: nbsphinx; extra == "dev"
|
|
67
|
+
Requires-Dist: pandoc; extra == "dev"
|
|
66
68
|
Requires-Dist: jupytext; extra == "dev"
|
|
67
69
|
Requires-Dist: jupyter; extra == "dev"
|
|
70
|
+
Requires-Dist: typing_extensions>=4.12.0; extra == "dev"
|
|
68
71
|
Dynamic: license-file
|
|
69
72
|
|
|
70
|
-
|
|
73
|
+
[](https://doi.org/10.5281/zenodo.15283390)
|
|
74
|
+
[](https://www.python.org/downloads/)
|
|
75
|
+
[](https://pypi.org/project/gymcts/)
|
|
76
|
+
[](https://github.com/Alexander-Nasuta/gymcts/blob/master/LICENSE)
|
|
77
|
+
[](https://gymcts.readthedocs.io/en/latest/?badge=latest)
|
|
78
|
+
|
|
79
|
+
# GYMCTS
|
|
71
80
|
|
|
72
81
|
A Monte Carlo Tree Search Implementation for Gymnasium-style Environments.
|
|
73
82
|
|
|
74
|
-
- Github: [GYMCTS on Github](https://github.com/Alexander-Nasuta/
|
|
75
|
-
-
|
|
76
|
-
-
|
|
83
|
+
- Github: [GYMCTS on Github](https://github.com/Alexander-Nasuta/gymcts)
|
|
84
|
+
- GitLab: [GYMCTS on GitLab](https://git-ce.rwth-aachen.de/alexander.nasuta/gymcts)
|
|
85
|
+
- Pypi: [GYMCTS on PyPi](https://pypi.org/project/gymcts/)
|
|
86
|
+
- Documentation: [GYMCTS Docs](https://gymcts.readthedocs.io/en/latest/)
|
|
77
87
|
|
|
78
88
|
## Description
|
|
79
89
|
|
|
@@ -101,22 +111,26 @@ The usage of a MCTS agent can roughly organised into the following steps:
|
|
|
101
111
|
- Render the solution
|
|
102
112
|
|
|
103
113
|
The GYMCTS package provides two types of wrappers for Gymnasium-style environments:
|
|
104
|
-
- `
|
|
105
|
-
- `
|
|
114
|
+
- `DeepCopyMCTSGymEnvWrapper`: A wrapper that uses deepcopies of the environment to save a snapshot of the environment state for each node in the MCTS tree.
|
|
115
|
+
- `ActionHistoryMCTSGymEnvWrapper`: A wrapper that saves the action sequence that led to the current state in the MCTS node.
|
|
106
116
|
|
|
107
|
-
These wrappers can be used with the `
|
|
108
|
-
The wrapper implement methods that are required by the `
|
|
117
|
+
These wrappers can be used with the `GymctsAgent` to solve the environment.
|
|
118
|
+
The wrappers implement the methods required by the `GymctsAgent` to interact with the environment.
|
|
109
119
|
GYMCTS is designed to use a single environment instance and to reconstruct the environment state from a state snapshot when needed.
|
|
110
120
|
|
|
111
121
|
NOTE: MCTS works best when the return of an episode is in the range of [-1, 1]. Please adjust the reward function of the environment accordingly (or change the ubc-scaling parameter of the MCTS agent).
|
|
112
122
|
Adjusting the reward function of the environment is easily done with a [NormalizeReward](https://gymnasium.farama.org/api/wrappers/reward_wrappers/#gymnasium.wrappers.NormalizeReward) or [TransformReward](https://gymnasium.farama.org/api/wrappers/reward_wrappers/#gymnasium.wrappers.TransformReward) Wrapper.
|
|
123
|
+
```python
|
|
124
|
+
env = NormalizeReward(env, gamma=0.99, epsilon=1e-8)
|
|
125
|
+
```
|
|
113
126
|
|
|
114
|
-
|
|
115
|
-
env = TransformReward(env, lambda r: r /
|
|
116
|
-
|
|
127
|
+
```python
|
|
128
|
+
env = TransformReward(env, lambda r: r / n_steps_per_episode)
|
|
129
|
+
```
|
|
130
|
+
### FrozenLake Example (DeepCopyMCTSGymEnvWrapper)
|
|
117
131
|
|
|
118
132
|
A minimal example of how to use the package with the FrozenLake environment and the DeepCopyMCTSGymEnvWrapper is provided in the code snippet below.
|
|
119
|
-
The
|
|
133
|
+
The DeepCopyMCTSGymEnvWrapper can be used with non-deterministic environments, such as the FrozenLake environment with slippery ice.
|
|
120
134
|
|
|
121
135
|
```python
|
|
122
136
|
import gymnasium as gym
|
|
@@ -135,7 +149,7 @@ if __name__ == '__main__':
|
|
|
135
149
|
env = gym.make('FrozenLake-v1', desc=None, map_name="4x4", is_slippery=True, render_mode="ansi")
|
|
136
150
|
env.reset()
|
|
137
151
|
|
|
138
|
-
# 1. wrap the environment with the
|
|
152
|
+
# 1. wrap the environment with the deep copy wrapper or a custom gymcts wrapper
|
|
139
153
|
env = DeepCopyMCTSGymEnvWrapper(env)
|
|
140
154
|
|
|
141
155
|
# 2. create the agent
|
|
@@ -158,7 +172,7 @@ if __name__ == '__main__':
|
|
|
158
172
|
|
|
159
173
|
# 5. print the solution
|
|
160
174
|
# read the solution from the info provided by the RecordEpisodeStatistics wrapper
|
|
161
|
-
# (that
|
|
175
|
+
# (that DeepCopyMCTSGymEnvWrapper uses internally)
|
|
162
176
|
episode_length = info["episode"]["l"]
|
|
163
177
|
episode_return = info["episode"]["r"]
|
|
164
178
|
|
|
@@ -251,7 +265,7 @@ if __name__ == '__main__':
|
|
|
251
265
|
env = gym.make('FrozenLake-v1', desc=None, map_name="4x4", is_slippery=False, render_mode="rgb_array")
|
|
252
266
|
env.reset()
|
|
253
267
|
|
|
254
|
-
# 1. wrap the environment with the
|
|
268
|
+
# 1. wrap the environment with the deep copy wrapper or a custom gymcts wrapper
|
|
255
269
|
env = DeepCopyMCTSGymEnvWrapper(env)
|
|
256
270
|
|
|
257
271
|
# 2. create the agent
|
|
@@ -280,7 +294,7 @@ if __name__ == '__main__':
|
|
|
280
294
|
env.close()
|
|
281
295
|
|
|
282
296
|
# 5. print the solution
|
|
283
|
-
# read the solution from the info provided by the RecordEpisodeStatistics wrapper (that
|
|
297
|
+
# read the solution from the info provided by the RecordEpisodeStatistics wrapper (that DeepCopyMCTSGymEnvWrapper wraps internally)
|
|
284
298
|
episode_length = info["episode"]["l"]
|
|
285
299
|
episode_return = info["episode"]["r"]
|
|
286
300
|
|
|
@@ -321,13 +335,13 @@ import gymnasium as gym
|
|
|
321
335
|
from graph_jsp_env.disjunctive_graph_jsp_env import DisjunctiveGraphJspEnv
|
|
322
336
|
from jsp_instance_utils.instances import ft06, ft06_makespan
|
|
323
337
|
|
|
324
|
-
from gymcts.gymcts_agent import
|
|
325
|
-
from gymcts.
|
|
338
|
+
from gymcts.gymcts_agent import GymctsAgent
|
|
339
|
+
from gymcts.gymcts_env_abc import GymctsABC
|
|
326
340
|
|
|
327
341
|
from gymcts.logger import log
|
|
328
342
|
|
|
329
343
|
|
|
330
|
-
class GraphJspGYMCTSWrapper(
|
|
344
|
+
class GraphJspGYMCTSWrapper(GymctsABC, gym.Wrapper):
|
|
331
345
|
|
|
332
346
|
def __init__(self, env: DisjunctiveGraphJspEnv):
|
|
333
347
|
gym.Wrapper.__init__(self, env)
|
|
@@ -378,7 +392,7 @@ if __name__ == '__main__':
|
|
|
378
392
|
|
|
379
393
|
env = GraphJspGYMCTSWrapper(env)
|
|
380
394
|
|
|
381
|
-
agent =
|
|
395
|
+
agent = GymctsAgent(
|
|
382
396
|
env=env,
|
|
383
397
|
clear_mcts_tree_after_step=True,
|
|
384
398
|
render_tree_after_step=True,
|
|
@@ -421,7 +435,6 @@ import gymnasium as gym
|
|
|
421
435
|
|
|
422
436
|
from gymcts.gymcts_agent import GymctsAgent
|
|
423
437
|
from gymcts.gymcts_action_history_wrapper import ActionHistoryMCTSGymEnvWrapper
|
|
424
|
-
from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
|
|
425
438
|
|
|
426
439
|
from gymcts.logger import log
|
|
427
440
|
|
|
@@ -434,7 +447,7 @@ if __name__ == '__main__':
|
|
|
434
447
|
env = gym.make('FrozenLake-v1', desc=None, map_name="4x4", is_slippery=False, render_mode="ansi")
|
|
435
448
|
env.reset()
|
|
436
449
|
|
|
437
|
-
# wrap the environment with the
|
|
450
|
+
# wrap the environment with the action history wrapper or a custom gymcts wrapper
|
|
438
451
|
env = ActionHistoryMCTSGymEnvWrapper(env)
|
|
439
452
|
|
|
440
453
|
# create the agent
|
|
@@ -505,11 +518,11 @@ clone the repository in your favorite code editor (for example PyCharm, VSCode,
|
|
|
505
518
|
|
|
506
519
|
using https:
|
|
507
520
|
```shell
|
|
508
|
-
git clone https://github.com/Alexander-Nasuta/
|
|
521
|
+
git clone https://github.com/Alexander-Nasuta/gymcts.git
|
|
509
522
|
```
|
|
510
523
|
or by using the GitHub CLI:
|
|
511
524
|
```shell
|
|
512
|
-
gh repo clone Alexander-Nasuta/
|
|
525
|
+
gh repo clone Alexander-Nasuta/gymcts
|
|
513
526
|
```
|
|
514
527
|
|
|
515
528
|
If you are using PyCharm, I recommend doing the following additional steps:
|
|
@@ -518,9 +531,6 @@ if you are using PyCharm, I recommend doing the following additional steps:
|
|
|
518
531
|
- mark the `tests` folder as test root (by right-clicking on the folder and selecting `Mark Directory as` -> `Test Sources Root`)
|
|
519
532
|
- mark the `resources` folder as resources root (by right-clicking on the folder and selecting `Mark Directory as` -> `Resources Root`)
|
|
520
533
|
|
|
521
|
-
at the end your project structure should look like this:
|
|
522
|
-
|
|
523
|
-
todo
|
|
524
534
|
|
|
525
535
|
### Create a Virtual Environment (optional)
|
|
526
536
|
|
|
@@ -576,9 +586,6 @@ This project uses `pytest` for testing. To run the tests, run the following comm
|
|
|
576
586
|
```shell
|
|
577
587
|
pytest
|
|
578
588
|
```
|
|
579
|
-
Here is a screenshot of what the output might look like:
|
|
580
|
-
|
|
581
|
-

|
|
582
589
|
|
|
583
590
|
For testing with `tox` run the following command:
|
|
584
591
|
|
|
@@ -586,12 +593,6 @@ For testing with `tox` run the following command:
|
|
|
586
593
|
tox
|
|
587
594
|
```
|
|
588
595
|
|
|
589
|
-
Here is a screenshot of what the output might look like:
|
|
590
|
-
|
|
591
|
-

|
|
592
|
-
|
|
593
|
-
Tox will run the tests in a separate environment and will also check if the requirements are installed correctly.
|
|
594
|
-
|
|
595
596
|
### Building and Publishing the Project to PyPi
|
|
596
597
|
|
|
597
598
|
In order to publish the project to PyPi, the project needs to be built and then uploaded to PyPi.
|
|
@@ -630,7 +631,6 @@ sphinx-autobuild ./docs/source/ ./docs/build/html/
|
|
|
630
631
|
This project features most of the extensions featured in this Tutorial: [Document Your Scientific Project With Markdown, Sphinx, and Read the Docs | PyData Global 2021](https://www.youtube.com/watch?v=qRSb299awB0).
|
|
631
632
|
|
|
632
633
|
|
|
633
|
-
|
|
634
634
|
## Contact
|
|
635
635
|
|
|
636
636
|
If you have any questions or feedback, feel free to contact me via [email](mailto:alexander.nasuta@wzl-iqs.rwth-aachen.de) or open an issue on the repository.
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
gymcts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
gymcts/colorful_console_utils.py,sha256=n7nymC8kKZnA_8nXcdn201NAzjZjgEHfKpbBcnl4oAE,5891
|
|
3
|
+
gymcts/gymcts_action_history_wrapper.py,sha256=7-p17Fgb80SRCBaCm6G8SJrEPsl2Y4aIO3InviuQP08,6993
|
|
4
|
+
gymcts/gymcts_agent.py,sha256=OAcN2-mFCR2AVJrRZlRtROF_zHk90SIM-uAebKektIc,10768
|
|
5
|
+
gymcts/gymcts_deepcopy_wrapper.py,sha256=lCCT5-6JVCwUCP__4uPMMkT5HnO2JWm2ebzJ69zXp9c,6792
|
|
6
|
+
gymcts/gymcts_distributed_agent.py,sha256=Ha9UBQvFjoErfMWvPyN0JcTYz-JaiJ4eWjLMikp9Yhs,11569
|
|
7
|
+
gymcts/gymcts_env_abc.py,sha256=iqrFNNSa-kZyAGk1UN2BjkdkV6NufAkYJT8d7PlQ07E,2525
|
|
8
|
+
gymcts/gymcts_neural_agent.py,sha256=urYGA5D6idChPke8Ac9zqhKy2NqkJzt3Zt-j8V6OpuQ,15785
|
|
9
|
+
gymcts/gymcts_node.py,sha256=-YKfK5fryPteCp-UTsAgzFVIBucZdXPMbXHCIb6mS24,13151
|
|
10
|
+
gymcts/gymcts_tree_plotter.py,sha256=PR6C7q9Q4kuz1aLGyD7-aZsxk3RqlHZpOqmOiRpCyK0,3547
|
|
11
|
+
gymcts/logger.py,sha256=RI7B9cvbBGrj0_QIAI77wihzuu2tPG_-z9GM2Mw5aHE,926
|
|
12
|
+
gymcts-1.3.0.dist-info/licenses/LICENSE,sha256=UGe75WojDiw_77SEnK2aysEDlElRlkWie7U7NaAFx00,1072
|
|
13
|
+
gymcts-1.3.0.dist-info/METADATA,sha256=pyhdSu_PAMi9IbVeSsHU0EcJSasAMttrtz-pKIjbePw,23864
|
|
14
|
+
gymcts-1.3.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
15
|
+
gymcts-1.3.0.dist-info/top_level.txt,sha256=E8MoLsPimUPD0H1Y6lum4TVe-lhSDAyBAXGrkYIT52w,7
|
|
16
|
+
gymcts-1.3.0.dist-info/RECORD,,
|
gymcts-1.2.0.dist-info/RECORD
DELETED
|
@@ -1,15 +0,0 @@
|
|
|
1
|
-
gymcts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
gymcts/colorful_console_utils.py,sha256=OhULcXHKbEA4uJDAEYCTcW6wUv0LsHX_XSYzZ_Szsv4,4553
|
|
3
|
-
gymcts/gymcts_action_history_wrapper.py,sha256=AjvBBwd1t9-nTYP09aMdlScAkFNXf5vOagejpjWYOPo,3810
|
|
4
|
-
gymcts/gymcts_agent.py,sha256=O2y98jKFjR5TzqVV7DO1jlcYDyzAgd_H2RF4-w4NP0g,8499
|
|
5
|
-
gymcts/gymcts_deepcopy_wrapper.py,sha256=OleQTnvxv3gLEo8-2asyeo-CpZ4HEbgyFGS5DTCD7NM,4167
|
|
6
|
-
gymcts/gymcts_distributed_agent.py,sha256=M7dyBfC8u3M99PJFoXKgIc_CPTyHGppmktkH-y9ci4U,10448
|
|
7
|
-
gymcts/gymcts_env_abc.py,sha256=7nCRiiClmmVLX-d_Q1dxeztmuvmAtmWZwjT81zrG1_w,575
|
|
8
|
-
gymcts/gymcts_node.py,sha256=PT_YZFwt1zjuvd8i9Wb5LEkHAqmJOFyPDp3GFD05lqM,7138
|
|
9
|
-
gymcts/gymcts_tree_plotter.py,sha256=eg207wHcDepwWODXzmDYQn1Aai29Cs4jFS1HNvAhlXs,2651
|
|
10
|
-
gymcts/logger.py,sha256=nAkUa4djiuCR7hF0EUsplhqFHCp76QcOX1cV3lIPzOI,937
|
|
11
|
-
gymcts-1.2.0.dist-info/licenses/LICENSE,sha256=UGe75WojDiw_77SEnK2aysEDlElRlkWie7U7NaAFx00,1072
|
|
12
|
-
gymcts-1.2.0.dist-info/METADATA,sha256=zhEIFo0rOnv5hCv6ukImkq-9nshO4EfXMbHlhNlYhyA,23640
|
|
13
|
-
gymcts-1.2.0.dist-info/WHEEL,sha256=DK49LOLCYiurdXXOXwGJm6U4DkHkg4lcxjhqwRa0CP4,91
|
|
14
|
-
gymcts-1.2.0.dist-info/top_level.txt,sha256=E8MoLsPimUPD0H1Y6lum4TVe-lhSDAyBAXGrkYIT52w,7
|
|
15
|
-
gymcts-1.2.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|