gymcts 1.0.0__tar.gz → 1.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {gymcts-1.0.0/src/gymcts.egg-info → gymcts-1.2.0}/PKG-INFO +25 -23
- {gymcts-1.0.0 → gymcts-1.2.0}/README.md +18 -19
- {gymcts-1.0.0 → gymcts-1.2.0}/pyproject.toml +6 -3
- {gymcts-1.0.0 → gymcts-1.2.0}/setup.cfg +2 -2
- {gymcts-1.0.0 → gymcts-1.2.0}/src/gymcts/colorful_console_utils.py +4 -3
- gymcts-1.0.0/src/gymcts/gymcts_deterministic_wrapper.py → gymcts-1.2.0/src/gymcts/gymcts_action_history_wrapper.py +2 -2
- {gymcts-1.0.0 → gymcts-1.2.0}/src/gymcts/gymcts_agent.py +24 -68
- gymcts-1.0.0/src/gymcts/gymcts_naive_wrapper.py → gymcts-1.2.0/src/gymcts/gymcts_deepcopy_wrapper.py +2 -2
- gymcts-1.2.0/src/gymcts/gymcts_distributed_agent.py +281 -0
- gymcts-1.0.0/src/gymcts/gymcts_gym_env.py → gymcts-1.2.0/src/gymcts/gymcts_env_abc.py +1 -1
- {gymcts-1.0.0 → gymcts-1.2.0}/src/gymcts/gymcts_node.py +25 -39
- gymcts-1.2.0/src/gymcts/gymcts_tree_plotter.py +75 -0
- {gymcts-1.0.0 → gymcts-1.2.0/src/gymcts.egg-info}/PKG-INFO +25 -23
- {gymcts-1.0.0 → gymcts-1.2.0}/src/gymcts.egg-info/SOURCES.txt +5 -3
- {gymcts-1.0.0 → gymcts-1.2.0}/src/gymcts.egg-info/requires.txt +2 -0
- {gymcts-1.0.0 → gymcts-1.2.0}/tests/test_graph_matrix_jsp_env.py +6 -15
- {gymcts-1.0.0 → gymcts-1.2.0}/tests/test_gymnasium_envs.py +8 -8
- {gymcts-1.0.0 → gymcts-1.2.0}/tests/test_number_of_visits.py +6 -8
- {gymcts-1.0.0 → gymcts-1.2.0}/LICENSE +0 -0
- {gymcts-1.0.0 → gymcts-1.2.0}/MANIFEST.in +0 -0
- {gymcts-1.0.0 → gymcts-1.2.0}/setup.py +0 -0
- {gymcts-1.0.0 → gymcts-1.2.0}/src/gymcts/__init__.py +0 -0
- {gymcts-1.0.0 → gymcts-1.2.0}/src/gymcts/logger.py +0 -0
- {gymcts-1.0.0 → gymcts-1.2.0}/src/gymcts.egg-info/dependency_links.txt +0 -0
- {gymcts-1.0.0 → gymcts-1.2.0}/src/gymcts.egg-info/not-zip-safe +0 -0
- {gymcts-1.0.0 → gymcts-1.2.0}/src/gymcts.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: gymcts
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.2.0
|
|
4
4
|
Summary: A minimalistic implementation of the Monte Carlo Tree Search algorithm for planning problems formulated as gymnasium reinforcement learning environments.
|
|
5
5
|
Author: Alexander Nasuta
|
|
6
6
|
Author-email: Alexander Nasuta <alexander.nasuta@wzl-iqs.rwth-aachen.de>
|
|
@@ -25,7 +25,7 @@ License: MIT License
|
|
|
25
25
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
26
26
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
27
27
|
SOFTWARE.
|
|
28
|
-
Project-URL: Homepage, https://github.com/Alexander-Nasuta/
|
|
28
|
+
Project-URL: Homepage, https://github.com/Alexander-Nasuta/gymcts
|
|
29
29
|
Platform: unix
|
|
30
30
|
Platform: linux
|
|
31
31
|
Platform: osx
|
|
@@ -34,7 +34,7 @@ Platform: win32
|
|
|
34
34
|
Classifier: License :: OSI Approved :: MIT License
|
|
35
35
|
Classifier: Programming Language :: Python
|
|
36
36
|
Classifier: Programming Language :: Python :: 3
|
|
37
|
-
Requires-Python: >=3.
|
|
37
|
+
Requires-Python: >=3.11
|
|
38
38
|
Description-Content-Type: text/markdown
|
|
39
39
|
License-File: LICENSE
|
|
40
40
|
Requires-Dist: rich
|
|
@@ -63,6 +63,9 @@ Requires-Dist: furo; extra == "dev"
|
|
|
63
63
|
Requires-Dist: twine; extra == "dev"
|
|
64
64
|
Requires-Dist: sphinx-copybutton; extra == "dev"
|
|
65
65
|
Requires-Dist: nbsphinx; extra == "dev"
|
|
66
|
+
Requires-Dist: jupytext; extra == "dev"
|
|
67
|
+
Requires-Dist: jupyter; extra == "dev"
|
|
68
|
+
Dynamic: license-file
|
|
66
69
|
|
|
67
70
|
# Graph Matrix Job Shop Env
|
|
68
71
|
|
|
@@ -118,8 +121,8 @@ The NaiveSoloMCTSGymEnvWrapper can be used with non-deterministic environments,
|
|
|
118
121
|
```python
|
|
119
122
|
import gymnasium as gym
|
|
120
123
|
|
|
121
|
-
from gymcts.gymcts_agent import
|
|
122
|
-
from gymcts.
|
|
124
|
+
from gymcts.gymcts_agent import GymctsAgent
|
|
125
|
+
from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
|
|
123
126
|
|
|
124
127
|
from gymcts.logger import log
|
|
125
128
|
|
|
@@ -133,10 +136,10 @@ if __name__ == '__main__':
|
|
|
133
136
|
env.reset()
|
|
134
137
|
|
|
135
138
|
# 1. wrap the environment with the naive wrapper or a custom gymcts wrapper
|
|
136
|
-
env =
|
|
139
|
+
env = DeepCopyMCTSGymEnvWrapper(env)
|
|
137
140
|
|
|
138
141
|
# 2. create the agent
|
|
139
|
-
agent =
|
|
142
|
+
agent = GymctsAgent(
|
|
140
143
|
env=env,
|
|
141
144
|
clear_mcts_tree_after_step=False,
|
|
142
145
|
render_tree_after_step=True,
|
|
@@ -170,13 +173,13 @@ if __name__ == '__main__':
|
|
|
170
173
|
A minimal example of how to use the package with the FrozenLake environment and the DeterministicSoloMCTSGymEnvWrapper is provided in the code snippet below.
|
|
171
174
|
The DeterministicSoloMCTSGymEnvWrapper can be used with deterministic environments, such as the FrozenLake environment without slippery ice.
|
|
172
175
|
|
|
173
|
-
The DeterministicSoloMCTSGymEnvWrapper saves the action sequence that led to the current state in the MCTS node.
|
|
176
|
+
The DeterministicSoloMCTSGymEnvWrapper saves the action sequence that led to the current state in the MCTS node.
|
|
174
177
|
|
|
175
178
|
```python
|
|
176
179
|
import gymnasium as gym
|
|
177
180
|
|
|
178
|
-
from gymcts.gymcts_agent import
|
|
179
|
-
from gymcts.
|
|
181
|
+
from gymcts.gymcts_agent import GymctsAgent
|
|
182
|
+
from gymcts.gymcts_action_history_wrapper import ActionHistoryMCTSGymEnvWrapper
|
|
180
183
|
|
|
181
184
|
from gymcts.logger import log
|
|
182
185
|
|
|
@@ -190,10 +193,10 @@ if __name__ == '__main__':
|
|
|
190
193
|
env.reset()
|
|
191
194
|
|
|
192
195
|
# 1. wrap the environment with the wrapper
|
|
193
|
-
env =
|
|
196
|
+
env = ActionHistoryMCTSGymEnvWrapper(env)
|
|
194
197
|
|
|
195
198
|
# 2. create the agent
|
|
196
|
-
agent =
|
|
199
|
+
agent = GymctsAgent(
|
|
197
200
|
env=env,
|
|
198
201
|
clear_mcts_tree_after_step=False,
|
|
199
202
|
render_tree_after_step=True,
|
|
@@ -232,8 +235,8 @@ To create a video of the solution of the FrozenLake environment, you can use the
|
|
|
232
235
|
```python
|
|
233
236
|
import gymnasium as gym
|
|
234
237
|
|
|
235
|
-
from gymcts.gymcts_agent import
|
|
236
|
-
from gymcts.
|
|
238
|
+
from gymcts.gymcts_agent import GymctsAgent
|
|
239
|
+
from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
|
|
237
240
|
|
|
238
241
|
from gymcts.logger import log
|
|
239
242
|
|
|
@@ -249,10 +252,10 @@ if __name__ == '__main__':
|
|
|
249
252
|
env.reset()
|
|
250
253
|
|
|
251
254
|
# 1. wrap the environment with the naive wrapper or a custom gymcts wrapper
|
|
252
|
-
env =
|
|
255
|
+
env = DeepCopyMCTSGymEnvWrapper(env)
|
|
253
256
|
|
|
254
257
|
# 2. create the agent
|
|
255
|
-
agent =
|
|
258
|
+
agent = GymctsAgent(
|
|
256
259
|
env=env,
|
|
257
260
|
clear_mcts_tree_after_step=False,
|
|
258
261
|
render_tree_after_step=True,
|
|
@@ -413,13 +416,12 @@ The color gradient is based on the minimum and maximum values of the respective
|
|
|
413
416
|
The visualisation is rendered in the terminal and can be limited to a certain depth of the tree.
|
|
414
417
|
The default depth is 2.
|
|
415
418
|
|
|
416
|
-
|
|
417
419
|
```python
|
|
418
420
|
import gymnasium as gym
|
|
419
421
|
|
|
420
|
-
from gymcts.gymcts_agent import
|
|
421
|
-
from gymcts.
|
|
422
|
-
from gymcts.
|
|
422
|
+
from gymcts.gymcts_agent import GymctsAgent
|
|
423
|
+
from gymcts.gymcts_action_history_wrapper import ActionHistoryMCTSGymEnvWrapper
|
|
424
|
+
from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
|
|
423
425
|
|
|
424
426
|
from gymcts.logger import log
|
|
425
427
|
|
|
@@ -433,10 +435,10 @@ if __name__ == '__main__':
|
|
|
433
435
|
env.reset()
|
|
434
436
|
|
|
435
437
|
# wrap the environment with the naive wrapper or a custom gymcts wrapper
|
|
436
|
-
env =
|
|
438
|
+
env = ActionHistoryMCTSGymEnvWrapper(env)
|
|
437
439
|
|
|
438
440
|
# create the agent
|
|
439
|
-
agent =
|
|
441
|
+
agent = GymctsAgent(
|
|
440
442
|
env=env,
|
|
441
443
|
clear_mcts_tree_after_step=False,
|
|
442
444
|
render_tree_after_step=False,
|
|
@@ -52,8 +52,8 @@ The NaiveSoloMCTSGymEnvWrapper can be used with non-deterministic environments,
|
|
|
52
52
|
```python
|
|
53
53
|
import gymnasium as gym
|
|
54
54
|
|
|
55
|
-
from gymcts.gymcts_agent import
|
|
56
|
-
from gymcts.
|
|
55
|
+
from gymcts.gymcts_agent import GymctsAgent
|
|
56
|
+
from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
|
|
57
57
|
|
|
58
58
|
from gymcts.logger import log
|
|
59
59
|
|
|
@@ -67,10 +67,10 @@ if __name__ == '__main__':
|
|
|
67
67
|
env.reset()
|
|
68
68
|
|
|
69
69
|
# 1. wrap the environment with the naive wrapper or a custom gymcts wrapper
|
|
70
|
-
env =
|
|
70
|
+
env = DeepCopyMCTSGymEnvWrapper(env)
|
|
71
71
|
|
|
72
72
|
# 2. create the agent
|
|
73
|
-
agent =
|
|
73
|
+
agent = GymctsAgent(
|
|
74
74
|
env=env,
|
|
75
75
|
clear_mcts_tree_after_step=False,
|
|
76
76
|
render_tree_after_step=True,
|
|
@@ -104,13 +104,13 @@ if __name__ == '__main__':
|
|
|
104
104
|
A minimal example of how to use the package with the FrozenLake environment and the DeterministicSoloMCTSGymEnvWrapper is provided in the code snippet below.
|
|
105
105
|
The DeterministicSoloMCTSGymEnvWrapper can be used with deterministic environments, such as the FrozenLake environment without slippery ice.
|
|
106
106
|
|
|
107
|
-
The DeterministicSoloMCTSGymEnvWrapper saves the action sequence that led to the current state in the MCTS node.
|
|
107
|
+
The DeterministicSoloMCTSGymEnvWrapper saves the action sequence that led to the current state in the MCTS node.
|
|
108
108
|
|
|
109
109
|
```python
|
|
110
110
|
import gymnasium as gym
|
|
111
111
|
|
|
112
|
-
from gymcts.gymcts_agent import
|
|
113
|
-
from gymcts.
|
|
112
|
+
from gymcts.gymcts_agent import GymctsAgent
|
|
113
|
+
from gymcts.gymcts_action_history_wrapper import ActionHistoryMCTSGymEnvWrapper
|
|
114
114
|
|
|
115
115
|
from gymcts.logger import log
|
|
116
116
|
|
|
@@ -124,10 +124,10 @@ if __name__ == '__main__':
|
|
|
124
124
|
env.reset()
|
|
125
125
|
|
|
126
126
|
# 1. wrap the environment with the wrapper
|
|
127
|
-
env =
|
|
127
|
+
env = ActionHistoryMCTSGymEnvWrapper(env)
|
|
128
128
|
|
|
129
129
|
# 2. create the agent
|
|
130
|
-
agent =
|
|
130
|
+
agent = GymctsAgent(
|
|
131
131
|
env=env,
|
|
132
132
|
clear_mcts_tree_after_step=False,
|
|
133
133
|
render_tree_after_step=True,
|
|
@@ -166,8 +166,8 @@ To create a video of the solution of the FrozenLake environment, you can use the
|
|
|
166
166
|
```python
|
|
167
167
|
import gymnasium as gym
|
|
168
168
|
|
|
169
|
-
from gymcts.gymcts_agent import
|
|
170
|
-
from gymcts.
|
|
169
|
+
from gymcts.gymcts_agent import GymctsAgent
|
|
170
|
+
from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
|
|
171
171
|
|
|
172
172
|
from gymcts.logger import log
|
|
173
173
|
|
|
@@ -183,10 +183,10 @@ if __name__ == '__main__':
|
|
|
183
183
|
env.reset()
|
|
184
184
|
|
|
185
185
|
# 1. wrap the environment with the naive wrapper or a custom gymcts wrapper
|
|
186
|
-
env =
|
|
186
|
+
env = DeepCopyMCTSGymEnvWrapper(env)
|
|
187
187
|
|
|
188
188
|
# 2. create the agent
|
|
189
|
-
agent =
|
|
189
|
+
agent = GymctsAgent(
|
|
190
190
|
env=env,
|
|
191
191
|
clear_mcts_tree_after_step=False,
|
|
192
192
|
render_tree_after_step=True,
|
|
@@ -347,13 +347,12 @@ The color gradient is based on the minimum and maximum values of the respective
|
|
|
347
347
|
The visualisation is rendered in the terminal and can be limited to a certain depth of the tree.
|
|
348
348
|
The default depth is 2.
|
|
349
349
|
|
|
350
|
-
|
|
351
350
|
```python
|
|
352
351
|
import gymnasium as gym
|
|
353
352
|
|
|
354
|
-
from gymcts.gymcts_agent import
|
|
355
|
-
from gymcts.
|
|
356
|
-
from gymcts.
|
|
353
|
+
from gymcts.gymcts_agent import GymctsAgent
|
|
354
|
+
from gymcts.gymcts_action_history_wrapper import ActionHistoryMCTSGymEnvWrapper
|
|
355
|
+
from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
|
|
357
356
|
|
|
358
357
|
from gymcts.logger import log
|
|
359
358
|
|
|
@@ -367,10 +366,10 @@ if __name__ == '__main__':
|
|
|
367
366
|
env.reset()
|
|
368
367
|
|
|
369
368
|
# wrap the environment with the naive wrapper or a custom gymcts wrapper
|
|
370
|
-
env =
|
|
369
|
+
env = ActionHistoryMCTSGymEnvWrapper(env)
|
|
371
370
|
|
|
372
371
|
# create the agent
|
|
373
|
-
agent =
|
|
372
|
+
agent = GymctsAgent(
|
|
374
373
|
env=env,
|
|
375
374
|
clear_mcts_tree_after_step=False,
|
|
376
375
|
render_tree_after_step=False,
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "gymcts"
|
|
7
|
-
version = "1.
|
|
7
|
+
version = "1.2.0"
|
|
8
8
|
description = "A minimalistic implementation of the Monte Carlo Tree Search algorithm for planning problems formulated as gymnasium reinforcement learning environments."
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
authors = [{ name = "Alexander Nasuta", email = "alexander.nasuta@wzl-iqs.rwth-aachen.de" }]
|
|
@@ -21,7 +21,7 @@ dependencies = [
|
|
|
21
21
|
"gymnasium",
|
|
22
22
|
"matplotlib<3.9",
|
|
23
23
|
]
|
|
24
|
-
requires-python = ">=3.
|
|
24
|
+
requires-python = ">=3.11"
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
[project.optional-dependencies]
|
|
@@ -53,10 +53,13 @@ dev = [
|
|
|
53
53
|
"twine",
|
|
54
54
|
"sphinx-copybutton", # for code copy buttons
|
|
55
55
|
"nbsphinx", # for jupyter notebook support in sphinx
|
|
56
|
+
|
|
57
|
+
"jupytext", # for converting .py examples to Jupyter notebooks: jupytext --to notebook *.py
|
|
58
|
+
"jupyter", # for jupyter notebook kernel
|
|
56
59
|
]
|
|
57
60
|
|
|
58
61
|
[project.urls]
|
|
59
|
-
Homepage = "https://github.com/Alexander-Nasuta/
|
|
62
|
+
Homepage = "https://github.com/Alexander-Nasuta/gymcts"
|
|
60
63
|
|
|
61
64
|
[tool.pytest.ini_options]
|
|
62
65
|
addopts = "--cov=gymcts -p no:warnings"
|
|
@@ -7,12 +7,12 @@ platforms = unix, linux, osx, cygwin, win32
|
|
|
7
7
|
classifiers =
|
|
8
8
|
Programming Language :: Python :: 3
|
|
9
9
|
Programming Language :: Python :: 3 :: Only
|
|
10
|
-
Programming Language :: Python :: 3.
|
|
10
|
+
Programming Language :: Python :: 3.11
|
|
11
11
|
|
|
12
12
|
[options]
|
|
13
13
|
packages =
|
|
14
14
|
gymcts
|
|
15
|
-
python_requires = >=3.
|
|
15
|
+
python_requires = >=3.11
|
|
16
16
|
package_dir =
|
|
17
17
|
=src
|
|
18
18
|
zip_safe = no
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
1
3
|
import matplotlib.pyplot as plt
|
|
2
4
|
import numpy as np
|
|
3
5
|
|
|
@@ -103,8 +105,7 @@ def wrap_with_color_codes(s: object, /, r: int | float, g: int | float, b: int |
|
|
|
103
105
|
f"{CEND}"
|
|
104
106
|
|
|
105
107
|
|
|
106
|
-
|
|
107
|
-
def wrap_evenly_spaced_color(s: str, n_of_item:int, n_classes:int, c_map="rainbow") -> str:
|
|
108
|
+
def wrap_evenly_spaced_color(s: Any, n_of_item: int, n_classes: int, c_map="rainbow") -> str:
|
|
108
109
|
if s is None or n_of_item is None or n_classes is None:
|
|
109
110
|
return s
|
|
110
111
|
|
|
@@ -117,7 +118,7 @@ def wrap_evenly_spaced_color(s: str, n_of_item:int, n_classes:int, c_map="rainbo
|
|
|
117
118
|
return f"{color_asni}{s}{CEND}"
|
|
118
119
|
|
|
119
120
|
|
|
120
|
-
def wrap_with_color_scale(s: str, value: float, min_val:float, max_val:float, c_map=None) -> str:
|
|
121
|
+
def wrap_with_color_scale(s: str, value: float, min_val: float, max_val: float, c_map=None) -> str:
|
|
121
122
|
if s is None or min_val is None or max_val is None or min_val >= max_val:
|
|
122
123
|
return s
|
|
123
124
|
|
|
@@ -7,12 +7,12 @@ import gymnasium as gym
|
|
|
7
7
|
from gymnasium.core import WrapperActType, WrapperObsType
|
|
8
8
|
from gymnasium.wrappers import RecordEpisodeStatistics
|
|
9
9
|
|
|
10
|
-
from gymcts.
|
|
10
|
+
from gymcts.gymcts_env_abc import GymctsABC
|
|
11
11
|
|
|
12
12
|
from gymcts.logger import log
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
class
|
|
15
|
+
class ActionHistoryMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
|
|
16
16
|
_terminal_flag: bool = False
|
|
17
17
|
_last_reward: SupportsFloat = 0
|
|
18
18
|
_step_tuple: tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]] = None
|
|
@@ -3,27 +3,28 @@ import gymnasium as gym
|
|
|
3
3
|
|
|
4
4
|
from typing import TypeVar, Any, SupportsFloat, Callable
|
|
5
5
|
|
|
6
|
-
from gymcts.
|
|
7
|
-
from gymcts.
|
|
8
|
-
from gymcts.gymcts_node import
|
|
6
|
+
from gymcts.gymcts_env_abc import GymctsABC
|
|
7
|
+
from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
|
|
8
|
+
from gymcts.gymcts_node import GymctsNode
|
|
9
|
+
from gymcts.gymcts_tree_plotter import _generate_mcts_tree
|
|
9
10
|
|
|
10
11
|
from gymcts.logger import log
|
|
11
12
|
|
|
12
13
|
TSoloMCTSNode = TypeVar("TSoloMCTSNode", bound="SoloMCTSNode")
|
|
13
14
|
|
|
14
15
|
|
|
15
|
-
class
|
|
16
|
+
class GymctsAgent:
|
|
16
17
|
render_tree_after_step: bool = False
|
|
17
18
|
render_tree_max_depth: int = 2
|
|
18
19
|
exclude_unvisited_nodes_from_render: bool = False
|
|
19
20
|
number_of_simulations_per_step: int = 25
|
|
20
21
|
|
|
21
|
-
env:
|
|
22
|
-
search_root_node:
|
|
22
|
+
env: GymctsABC
|
|
23
|
+
search_root_node: GymctsNode # NOTE: this is not the same as the root of the tree!
|
|
23
24
|
clear_mcts_tree_after_step: bool
|
|
24
25
|
|
|
25
26
|
def __init__(self,
|
|
26
|
-
env:
|
|
27
|
+
env: GymctsABC,
|
|
27
28
|
clear_mcts_tree_after_step: bool = True,
|
|
28
29
|
render_tree_after_step: bool = False,
|
|
29
30
|
render_tree_max_depth: int = 2,
|
|
@@ -43,13 +44,13 @@ class SoloMCTSAgent:
|
|
|
43
44
|
self.env = env
|
|
44
45
|
self.clear_mcts_tree_after_step = clear_mcts_tree_after_step
|
|
45
46
|
|
|
46
|
-
self.search_root_node =
|
|
47
|
+
self.search_root_node = GymctsNode(
|
|
47
48
|
action=None,
|
|
48
49
|
parent=None,
|
|
49
50
|
env_reference=env,
|
|
50
51
|
)
|
|
51
52
|
|
|
52
|
-
def navigate_to_leaf(self, from_node:
|
|
53
|
+
def navigate_to_leaf(self, from_node: GymctsNode) -> GymctsNode:
|
|
53
54
|
log.debug(f"Navigate to leaf. from_node: {from_node}")
|
|
54
55
|
if from_node.terminal:
|
|
55
56
|
log.debug("Node is terminal. Returning from_node")
|
|
@@ -66,7 +67,7 @@ class SoloMCTSAgent:
|
|
|
66
67
|
log.debug(f"Selected leaf node: {temp_node}")
|
|
67
68
|
return temp_node
|
|
68
69
|
|
|
69
|
-
def expand_node(self, node:
|
|
70
|
+
def expand_node(self, node: GymctsNode) -> None:
|
|
70
71
|
log.debug(f"expanding node: {node}")
|
|
71
72
|
# EXPANSION STRATEGY
|
|
72
73
|
# expand all children
|
|
@@ -78,7 +79,7 @@ class SoloMCTSAgent:
|
|
|
78
79
|
self._load_state(node)
|
|
79
80
|
|
|
80
81
|
obs, reward, terminal, truncated, _ = self.env.step(action)
|
|
81
|
-
child_dict[action] =
|
|
82
|
+
child_dict[action] = GymctsNode(
|
|
82
83
|
action=action,
|
|
83
84
|
parent=node,
|
|
84
85
|
env_reference=self.env,
|
|
@@ -110,14 +111,14 @@ class SoloMCTSAgent:
|
|
|
110
111
|
# restore state of current node
|
|
111
112
|
return action_list
|
|
112
113
|
|
|
113
|
-
def _load_state(self, node:
|
|
114
|
-
if isinstance(self.env,
|
|
114
|
+
def _load_state(self, node: GymctsNode) -> None:
|
|
115
|
+
if isinstance(self.env, DeepCopyMCTSGymEnvWrapper):
|
|
115
116
|
self.env = copy.deepcopy(node.state)
|
|
116
117
|
else:
|
|
117
118
|
self.env.load_state(node.state)
|
|
118
119
|
|
|
119
|
-
def perform_mcts_step(self, search_start_node:
|
|
120
|
-
render_tree_after_step: bool = None) -> tuple[int,
|
|
120
|
+
def perform_mcts_step(self, search_start_node: GymctsNode = None, num_simulations: int = None,
|
|
121
|
+
render_tree_after_step: bool = None) -> tuple[int, GymctsNode]:
|
|
121
122
|
|
|
122
123
|
if render_tree_after_step is None:
|
|
123
124
|
render_tree_after_step = self.render_tree_after_step
|
|
@@ -149,7 +150,7 @@ class SoloMCTSAgent:
|
|
|
149
150
|
|
|
150
151
|
return action, next_node
|
|
151
152
|
|
|
152
|
-
def vanilla_mcts_search(self, search_start_node:
|
|
153
|
+
def vanilla_mcts_search(self, search_start_node: GymctsNode = None, num_simulations=10) -> int:
|
|
153
154
|
log.debug(f"performing one MCTS search step with {num_simulations} simulations")
|
|
154
155
|
if search_start_node is None:
|
|
155
156
|
search_start_node = self.search_root_node
|
|
@@ -178,7 +179,7 @@ class SoloMCTSAgent:
|
|
|
178
179
|
|
|
179
180
|
return search_start_node.get_best_action()
|
|
180
181
|
|
|
181
|
-
def show_mcts_tree(self, start_node:
|
|
182
|
+
def show_mcts_tree(self, start_node: GymctsNode = None, tree_max_depth: int = None) -> None:
|
|
182
183
|
|
|
183
184
|
if start_node is None:
|
|
184
185
|
start_node = self.search_root_node
|
|
@@ -187,13 +188,17 @@ class SoloMCTSAgent:
|
|
|
187
188
|
tree_max_depth = self.render_tree_max_depth
|
|
188
189
|
|
|
189
190
|
print(start_node.__str__(colored=True, action_space_n=self.env.action_space.n))
|
|
190
|
-
for line in
|
|
191
|
+
for line in _generate_mcts_tree(
|
|
192
|
+
start_node=start_node,
|
|
193
|
+
depth=tree_max_depth,
|
|
194
|
+
action_space_n=self.env.action_space.n,
|
|
195
|
+
):
|
|
191
196
|
print(line)
|
|
192
197
|
|
|
193
198
|
def show_mcts_tree_from_root(self, tree_max_depth: int = None) -> None:
|
|
194
199
|
self.show_mcts_tree(start_node=self.search_root_node.get_root(), tree_max_depth=tree_max_depth)
|
|
195
200
|
|
|
196
|
-
def backpropagation(self, node:
|
|
201
|
+
def backpropagation(self, node: GymctsNode, episode_return: float) -> None:
|
|
197
202
|
log.debug(f"performing backpropagation from leaf node: {node}")
|
|
198
203
|
while not node.is_root():
|
|
199
204
|
# node.mean_value = ((node.mean_value * node.visit_count) + episode_return) / (node.visit_count + 1)
|
|
@@ -209,53 +214,4 @@ class SoloMCTSAgent:
|
|
|
209
214
|
node.max_value = max(node.max_value, episode_return)
|
|
210
215
|
node.min_value = min(node.min_value, episode_return)
|
|
211
216
|
|
|
212
|
-
def _generate_mcts_tree(self, start_node: SoloMCTSNode = None, prefix: str = None, depth: int = None) -> list[str]:
|
|
213
217
|
|
|
214
|
-
if prefix is None:
|
|
215
|
-
prefix = ""
|
|
216
|
-
import gymcts.colorful_console_utils as ccu
|
|
217
|
-
|
|
218
|
-
if start_node is None:
|
|
219
|
-
start_node = self.search_root_node
|
|
220
|
-
|
|
221
|
-
# prefix components:
|
|
222
|
-
space = ' '
|
|
223
|
-
branch = '│ '
|
|
224
|
-
# pointers:
|
|
225
|
-
tee = '├── '
|
|
226
|
-
last = '└── '
|
|
227
|
-
|
|
228
|
-
contents = start_node.children.values() if start_node.children is not None else []
|
|
229
|
-
if self.exclude_unvisited_nodes_from_render:
|
|
230
|
-
contents = [node for node in contents if node.visit_count > 0]
|
|
231
|
-
# contents each get pointers that are ├── with a final └── :
|
|
232
|
-
# pointers = [tee] * (len(contents) - 1) + [last]
|
|
233
|
-
pointers = [tee for _ in range(len(contents) - 1)] + [last]
|
|
234
|
-
|
|
235
|
-
for pointer, current_node in zip(pointers, contents):
|
|
236
|
-
n_item = current_node.parent.action if current_node.parent is not None else 0
|
|
237
|
-
n_classes = self.env.action_space.n
|
|
238
|
-
|
|
239
|
-
pointer = ccu.wrap_evenly_spaced_color(
|
|
240
|
-
s=pointer,
|
|
241
|
-
n_of_item=n_item,
|
|
242
|
-
n_classes=n_classes,
|
|
243
|
-
)
|
|
244
|
-
|
|
245
|
-
yield prefix + pointer + f"{current_node.__str__(colored=True, action_space_n=n_classes)}"
|
|
246
|
-
if current_node.children and len(current_node.children): # extend the prefix and recurse:
|
|
247
|
-
# extension = branch if pointer == tee else space
|
|
248
|
-
extension = branch if tee in pointer else space
|
|
249
|
-
# i.e. space because last, └── , above so no more |
|
|
250
|
-
extension = ccu.wrap_evenly_spaced_color(
|
|
251
|
-
s=extension,
|
|
252
|
-
n_of_item=n_item,
|
|
253
|
-
n_classes=n_classes,
|
|
254
|
-
)
|
|
255
|
-
if depth is not None and depth <= 0:
|
|
256
|
-
continue
|
|
257
|
-
yield from self._generate_mcts_tree(
|
|
258
|
-
current_node,
|
|
259
|
-
prefix=prefix + extension,
|
|
260
|
-
depth=depth - 1 if depth is not None else None
|
|
261
|
-
)
|
gymcts-1.0.0/src/gymcts/gymcts_naive_wrapper.py → gymcts-1.2.0/src/gymcts/gymcts_deepcopy_wrapper.py
RENAMED
|
@@ -7,12 +7,12 @@ import gymnasium as gym
|
|
|
7
7
|
from gymnasium.core import WrapperActType, WrapperObsType
|
|
8
8
|
from gymnasium.wrappers import RecordEpisodeStatistics
|
|
9
9
|
|
|
10
|
-
from gymcts.
|
|
10
|
+
from gymcts.gymcts_env_abc import GymctsABC
|
|
11
11
|
|
|
12
12
|
from gymcts.logger import log
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
class
|
|
15
|
+
class DeepCopyMCTSGymEnvWrapper(GymctsABC, gym.Wrapper):
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
_terminal_flag:bool = False
|