gymcts 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gymcts/__init__.py +0 -0
- gymcts/colorful_console_utils.py +142 -0
- gymcts/gymcts_agent.py +261 -0
- gymcts/gymcts_deterministic_wrapper.py +107 -0
- gymcts/gymcts_gym_env.py +28 -0
- gymcts/gymcts_naive_wrapper.py +114 -0
- gymcts/gymcts_node.py +213 -0
- gymcts/logger.py +33 -0
- gymcts-1.0.0.dist-info/LICENSE +21 -0
- gymcts-1.0.0.dist-info/METADATA +634 -0
- gymcts-1.0.0.dist-info/RECORD +13 -0
- gymcts-1.0.0.dist-info/WHEEL +5 -0
- gymcts-1.0.0.dist-info/top_level.txt +1 -0
gymcts/gymcts_node.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
import uuid
|
|
2
|
+
import random
|
|
3
|
+
import math
|
|
4
|
+
|
|
5
|
+
from typing import TypeVar, Any, SupportsFloat, Callable, Generator
|
|
6
|
+
|
|
7
|
+
from gymcts.gymcts_gym_env import SoloMCTSGymEnv
|
|
8
|
+
|
|
9
|
+
from gymcts.logger import log
|
|
10
|
+
|
|
11
|
+
# Type variable standing for SoloMCTSNode (or a subclass); the bound is given
# as a string because the class is defined further down in this module.
TSoloMCTSNode = TypeVar(
    "TSoloMCTSNode",
    bound="SoloMCTSNode",
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class SoloMCTSNode:
|
|
17
|
+
|
|
18
|
+
# static properties
|
|
19
|
+
best_action_weight: float = 0.05
|
|
20
|
+
ubc_c = 0.707
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# attributes
|
|
24
|
+
visit_count: int = 0
|
|
25
|
+
mean_value: float = 0
|
|
26
|
+
max_value: float = -float("inf")
|
|
27
|
+
min_value: float = +float("inf")
|
|
28
|
+
terminal: bool = False
|
|
29
|
+
state: Any
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def __str__(self, colored=False, action_space_n=None) -> str:
|
|
33
|
+
if not colored:
|
|
34
|
+
|
|
35
|
+
if not self.is_root():
|
|
36
|
+
return f"(a={self.action}, N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f}, ubc={self.ucb_score():.2f})"
|
|
37
|
+
else:
|
|
38
|
+
return f"(N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f}) [root]"
|
|
39
|
+
|
|
40
|
+
import gymcts.colorful_console_utils as ccu
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
if self.is_root():
|
|
44
|
+
return f"({ccu.CYELLOW}N{ccu.CEND}={self.visit_count}, {ccu.CYELLOW}Q_v{ccu.CEND}={self.mean_value:.2f}, {ccu.CYELLOW}best{ccu.CEND}={self.max_value:.2f})"
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
if action_space_n is None:
|
|
48
|
+
raise ValueError("action_space_n must be provided if colored is True")
|
|
49
|
+
|
|
50
|
+
p = ccu.CYELLOW
|
|
51
|
+
e = ccu.CEND
|
|
52
|
+
v = ccu.CCYAN
|
|
53
|
+
|
|
54
|
+
def colorful_value(value: float | int | None) -> str:
|
|
55
|
+
if value == None:
|
|
56
|
+
return f"{ccu.CGREY}None{e}"
|
|
57
|
+
color = ccu.CCYAN
|
|
58
|
+
if value == 0:
|
|
59
|
+
color = ccu.CRED
|
|
60
|
+
if value == float("inf"):
|
|
61
|
+
color = ccu.CGREY
|
|
62
|
+
if value == -float("inf"):
|
|
63
|
+
color = ccu.CGREY
|
|
64
|
+
|
|
65
|
+
if isinstance(value, float):
|
|
66
|
+
return f"{color}{value:.2f}{e}"
|
|
67
|
+
|
|
68
|
+
if isinstance(value, int):
|
|
69
|
+
return f"{color}{value}{e}"
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
root_node = self.get_root()
|
|
73
|
+
mean_val = f"{self.mean_value:.2f}"
|
|
74
|
+
|
|
75
|
+
return ((f"("
|
|
76
|
+
f"{p}a{e}={ccu.wrap_evenly_spaced_color(s=self.action, n_of_item=self.action, n_classes=action_space_n)}, "
|
|
77
|
+
f"{p}N{e}={colorful_value(self.visit_count)}, "
|
|
78
|
+
f"{p}Q_v{e}={ccu.wrap_with_color_scale(s=mean_val, value=self.mean_value, min_val=root_node.min_value, max_val=root_node.max_value)}, "
|
|
79
|
+
f"{p}best{e}={colorful_value(self.max_value)}") +
|
|
80
|
+
(f", {p}ubc{e}={colorful_value(self.ucb_score())})" if not self.is_root() else ")"))
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def traverse_nodes(self) -> Generator[TSoloMCTSNode, None, None]:
|
|
84
|
+
yield self
|
|
85
|
+
if self.children:
|
|
86
|
+
for child in self.children.values():
|
|
87
|
+
yield from child.traverse_nodes()
|
|
88
|
+
|
|
89
|
+
def get_root(self) -> TSoloMCTSNode:
|
|
90
|
+
if self.is_root():
|
|
91
|
+
return self
|
|
92
|
+
return self.parent.get_root()
|
|
93
|
+
|
|
94
|
+
def max_tree_depth(self):
|
|
95
|
+
if self.is_leaf():
|
|
96
|
+
return 0
|
|
97
|
+
return 1 + max(child.max_tree_depth() for child in self.children.values())
|
|
98
|
+
|
|
99
|
+
def n_children_recursively(self):
|
|
100
|
+
if self.is_leaf():
|
|
101
|
+
return 0
|
|
102
|
+
return len(self.children) + sum(child.n_children_recursively() for child in self.children.values())
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def __init__(self,
|
|
107
|
+
action: int | None,
|
|
108
|
+
parent: TSoloMCTSNode | None,
|
|
109
|
+
env_reference: SoloMCTSGymEnv,
|
|
110
|
+
):
|
|
111
|
+
|
|
112
|
+
# field depending on whether the node is a root node or not
|
|
113
|
+
self.action: int | None
|
|
114
|
+
|
|
115
|
+
self.env_reference: SoloMCTSGymEnv
|
|
116
|
+
self.parent: SoloMCTSNode | None
|
|
117
|
+
self.uuid = uuid.uuid4()
|
|
118
|
+
|
|
119
|
+
if parent is None:
|
|
120
|
+
self.action = None
|
|
121
|
+
self.parent = None
|
|
122
|
+
if env_reference.is_terminal():
|
|
123
|
+
raise ValueError("Root nodes shall not be terminal.")
|
|
124
|
+
else:
|
|
125
|
+
if action is None:
|
|
126
|
+
raise ValueError("action must be provided if parent is not None")
|
|
127
|
+
|
|
128
|
+
self.action = action
|
|
129
|
+
self.parent = parent # not None
|
|
130
|
+
|
|
131
|
+
# fields that are always initialized the same way
|
|
132
|
+
self.terminal: bool = env_reference.is_terminal()
|
|
133
|
+
|
|
134
|
+
from copy import copy
|
|
135
|
+
self.state = env_reference.get_state()
|
|
136
|
+
#log.debug(f"saving state of node '{str(self)}' to memory location: {hex(id(self.state))}")
|
|
137
|
+
self.visit_count: int = 0
|
|
138
|
+
|
|
139
|
+
self.mean_value: float = 0
|
|
140
|
+
self.max_value: float = -float("inf")
|
|
141
|
+
self.min_value: float = +float("inf")
|
|
142
|
+
|
|
143
|
+
# safe valid action instead of calling the environment
|
|
144
|
+
# this reduces the compute but increases the memory usage
|
|
145
|
+
self.valid_actions: list[int] = env_reference.get_valid_actions()
|
|
146
|
+
self.children: dict[int, SoloMCTSNode] | None = None # may be expanded later
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def reset(self) -> None:
|
|
150
|
+
self.parent = None
|
|
151
|
+
self.visit_count: int = 0
|
|
152
|
+
|
|
153
|
+
self.mean_value: float = 0
|
|
154
|
+
self.max_value: float = -float("inf")
|
|
155
|
+
self.min_value: float = +float("inf")
|
|
156
|
+
self.children: dict[int, SoloMCTSNode] | None = None # may be expanded later
|
|
157
|
+
|
|
158
|
+
# just setting the children of the parent node to None should be enough to trigger garbage collection
|
|
159
|
+
# however, we also set the parent to None to make sure that the parent is not referenced anymore
|
|
160
|
+
if self.parent:
|
|
161
|
+
self.parent.reset()
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def is_root(self) -> bool:
|
|
166
|
+
return self.parent is None
|
|
167
|
+
|
|
168
|
+
def is_leaf(self) -> bool:
|
|
169
|
+
return self.children is None or len(self.children) == 0
|
|
170
|
+
|
|
171
|
+
def get_random_child(self) -> TSoloMCTSNode:
|
|
172
|
+
if self.is_leaf():
|
|
173
|
+
raise ValueError("cannot get random child of leaf node") #todo: maybe return self instead?
|
|
174
|
+
|
|
175
|
+
return list(self.children.values())[random.randint(0, len(self.children) - 1)]
|
|
176
|
+
|
|
177
|
+
def get_best_action(self) -> int:
|
|
178
|
+
return max(self.children.values(), key=lambda child: child.get_score()).action
|
|
179
|
+
|
|
180
|
+
def get_score(self) -> float: # todo: make it an attribute?
|
|
181
|
+
# return self.mean_value
|
|
182
|
+
assert 0 <= SoloMCTSNode.best_action_weight <= 1
|
|
183
|
+
a = SoloMCTSNode.best_action_weight
|
|
184
|
+
return (1-a) * self.mean_value + a * self.max_value
|
|
185
|
+
|
|
186
|
+
def get_mean_value(self) -> float:
|
|
187
|
+
return self.mean_value
|
|
188
|
+
|
|
189
|
+
def get_max_value(self) -> float:
|
|
190
|
+
return self.max_value
|
|
191
|
+
|
|
192
|
+
def ucb_score(self):
|
|
193
|
+
"""
|
|
194
|
+
The score for an action that would transition between the parent and child.
|
|
195
|
+
prior_score = child.prior * math.sqrt(parent.visit_count) / (child.visit_count + 1)
|
|
196
|
+
|
|
197
|
+
if child.visit_count > 0:
|
|
198
|
+
# The value of the child is from the perspective of the opposing player
|
|
199
|
+
value_score = -child.value()
|
|
200
|
+
else:
|
|
201
|
+
value_score = 0
|
|
202
|
+
|
|
203
|
+
return value_score + prior_score
|
|
204
|
+
|
|
205
|
+
:return:
|
|
206
|
+
"""
|
|
207
|
+
if self.is_root():
|
|
208
|
+
raise ValueError("ucb_score can only be called on non-root nodes")
|
|
209
|
+
# c = 0.707 # todo: make it an attribute?
|
|
210
|
+
c = SoloMCTSNode.ubc_c
|
|
211
|
+
if self.visit_count == 0:
|
|
212
|
+
return float("inf")
|
|
213
|
+
return self.mean_value + c * math.sqrt(math.log(self.parent.visit_count) / (self.visit_count))
|
gymcts/logger.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from rich.logging import RichHandler
|
|
4
|
+
|
|
5
|
+
# Record format: just the message text; rendering is left to the Rich handler.
FORMAT = "%(message)s"

# NOTE(review): this runs at import time and configures the *root* logger for
# the whole process (logging.basicConfig does nothing if the root logger
# already has handlers). Libraries usually leave handler setup to the
# application - consider moving this out of the package import path.
logging.basicConfig(
    handlers=[RichHandler(show_path=False)],  # show_path=False: no source-file column
    datefmt="[%X]",  # %X: the locale's time representation
    format=FORMAT,
    level=logging.INFO,
)

# Logger shared across the package (gymcts_node does `from gymcts.logger import log`).
log = logging.getLogger("rich")
|
|
14
|
+
|
|
15
|
+
banner_sw = f"""
|
|
16
|
+
|
|
17
|
+
▟█████▛▜█▙▟█▛ ▟█▙ ▟██▛▟█████▛████▛▟████▛
|
|
18
|
+
▟█▛ ▜██▛ ▟█▛██▛██▛▟█▛ ▟█▛ ▜███▙
|
|
19
|
+
▟█▛ ▟█▛ ▟█▛ ▟█▛ ▟█▛▟█▛ ▟█▛ ▟█▛
|
|
20
|
+
▜████▛ ▟█▛ ▟█▛ ▟█▛ ▜████▛ ▟█▛ ▟████▛
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
if __name__ == '__main__':
    # Manual smoke test: emit one message per level, in this order. With the
    # INFO level configured by basicConfig, the debug message is filtered out.
    for emit in (log.debug, log.info, log.error, log.warning, log.critical):
        emit("Hello, World!")
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Alexander Nasuta
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|