gymcts 1.3.0__py3-none-any.whl → 1.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gymcts/gymcts_agent.py +3 -1
- gymcts/gymcts_neural_agent.py +20 -14
- gymcts/gymcts_node.py +9 -5
- {gymcts-1.3.0.dist-info → gymcts-1.4.1.dist-info}/METADATA +1 -1
- {gymcts-1.3.0.dist-info → gymcts-1.4.1.dist-info}/RECORD +8 -8
- {gymcts-1.3.0.dist-info → gymcts-1.4.1.dist-info}/WHEEL +0 -0
- {gymcts-1.3.0.dist-info → gymcts-1.4.1.dist-info}/licenses/LICENSE +0 -0
- {gymcts-1.3.0.dist-info → gymcts-1.4.1.dist-info}/top_level.txt +0 -0
gymcts/gymcts_agent.py
CHANGED
|
@@ -49,6 +49,7 @@ class GymctsAgent:
|
|
|
49
49
|
calc_number_of_simulations_per_step: Callable[[int,int], int] = None,
|
|
50
50
|
score_variate: Literal["UCT_v0", "UCT_v1", "UCT_v2",] = "UCT_v0",
|
|
51
51
|
best_action_weight=None,
|
|
52
|
+
keep_whole_tree_till_initial_root: bool = False,
|
|
52
53
|
):
|
|
53
54
|
# check if action space of env is discrete
|
|
54
55
|
if not isinstance(env.action_space, gym.spaces.Discrete):
|
|
@@ -79,6 +80,7 @@ class GymctsAgent:
|
|
|
79
80
|
|
|
80
81
|
self.env = env
|
|
81
82
|
self.clear_mcts_tree_after_step = clear_mcts_tree_after_step
|
|
83
|
+
self.keep_whole_tree_till_initial_root = keep_whole_tree_till_initial_root
|
|
82
84
|
|
|
83
85
|
self.search_root_node = GymctsNode(
|
|
84
86
|
action=None,
|
|
@@ -190,7 +192,7 @@ class GymctsAgent:
|
|
|
190
192
|
# we also need to reset the children of the current node
|
|
191
193
|
# this is done by calling the reset method
|
|
192
194
|
next_node.reset()
|
|
193
|
-
|
|
195
|
+
elif not self.keep_whole_tree_till_initial_root:
|
|
194
196
|
next_node.remove_parent()
|
|
195
197
|
|
|
196
198
|
self.search_root_node = next_node
|
gymcts/gymcts_neural_agent.py
CHANGED
|
@@ -162,37 +162,43 @@ class GymctsNeuralNode(GymctsNode):
|
|
|
162
162
|
p_sa = self._selection_score_prior
|
|
163
163
|
n_s = self.parent.visit_count
|
|
164
164
|
n_sa = self.visit_count
|
|
165
|
+
|
|
166
|
+
assert 0 <= GymctsNode.best_action_weight <= 1
|
|
167
|
+
b = GymctsNode.best_action_weight
|
|
168
|
+
exploitation_term = 0.0 if self.visit_count == 0 else (1 - b) * self.mean_value + b * self.max_value
|
|
169
|
+
|
|
170
|
+
|
|
165
171
|
if GymctsNeuralNode.score_variate == "PUCT_v0":
|
|
166
|
-
return
|
|
172
|
+
return exploitation_term + c * p_sa * math.sqrt(n_s) / (1 + n_sa)
|
|
167
173
|
elif GymctsNeuralNode.score_variate == "PUCT_v1":
|
|
168
|
-
return
|
|
174
|
+
return exploitation_term + c * p_sa * math.sqrt(2 * math.log(n_s) / (n_sa))
|
|
169
175
|
elif GymctsNeuralNode.score_variate == "PUCT_v2":
|
|
170
|
-
return
|
|
176
|
+
return exploitation_term + c * p_sa * math.sqrt(n_s) / n_sa
|
|
171
177
|
elif GymctsNeuralNode.score_variate == "PUCT_v3":
|
|
172
|
-
return
|
|
178
|
+
return exploitation_term + c * (p_sa ** GymctsNeuralNode.PUCT_v3_mu) * math.sqrt(n_s / (1 + n_sa))
|
|
173
179
|
elif GymctsNeuralNode.score_variate == "PUCT_v4":
|
|
174
|
-
return
|
|
180
|
+
return exploitation_term + c * (p_sa / (1 + n_sa))
|
|
175
181
|
elif GymctsNeuralNode.score_variate == "PUCT_v5":
|
|
176
|
-
return
|
|
182
|
+
return exploitation_term + c * p_sa * (math.sqrt(n_s) + 1) / (n_sa + 1)
|
|
177
183
|
elif GymctsNeuralNode.score_variate == "PUCT_v6":
|
|
178
|
-
return
|
|
184
|
+
return exploitation_term + c * p_sa * n_s / (1 + n_sa)
|
|
179
185
|
elif GymctsNeuralNode.score_variate == "PUCT_v7":
|
|
180
186
|
epsilon = 1e-8
|
|
181
|
-
return
|
|
187
|
+
return exploitation_term + c * p_sa * (math.sqrt(n_s) + epsilon) / (n_sa + 1)
|
|
182
188
|
elif GymctsNeuralNode.score_variate == "PUCT_v8":
|
|
183
|
-
return
|
|
189
|
+
return exploitation_term + c * p_sa * math.sqrt((math.log(n_s) + 1) / (1 + n_sa))
|
|
184
190
|
elif GymctsNeuralNode.score_variate == "PUCT_v9":
|
|
185
|
-
return
|
|
191
|
+
return exploitation_term + c * p_sa * math.sqrt(n_s / (1 + n_sa))
|
|
186
192
|
elif GymctsNeuralNode.score_variate == "PUCT_v10":
|
|
187
|
-
return
|
|
193
|
+
return exploitation_term + c * p_sa * math.sqrt(math.log(n_s) / (1 + n_sa))
|
|
188
194
|
elif GymctsNeuralNode.score_variate == "MuZero_v0":
|
|
189
195
|
c1 = GymctsNeuralNode.MuZero_c1
|
|
190
196
|
c2 = GymctsNeuralNode.MuZero_c2
|
|
191
|
-
return
|
|
197
|
+
return exploitation_term + c * p_sa * math.sqrt(n_s) / (1 + n_sa) * (c1 + math.log((n_s + c2 + 1) / c2))
|
|
192
198
|
elif GymctsNeuralNode.score_variate == "MuZero_v1":
|
|
193
199
|
c1 = GymctsNeuralNode.MuZero_c1
|
|
194
200
|
c2 = GymctsNeuralNode.MuZero_c2
|
|
195
|
-
return
|
|
201
|
+
return exploitation_term + c * p_sa * math.sqrt(n_s) / (1 + n_sa) * (c1 + math.log((n_s + c2 + 1) / c2))
|
|
196
202
|
|
|
197
203
|
|
|
198
204
|
exploration_term = self._selection_score_prior * c * math.sqrt(math.log(self.parent.visit_count) / (self.visit_count)) if self.visit_count > 0 else float("inf")
|
|
@@ -221,7 +227,7 @@ class GymctsNeuralNode(GymctsNode):
|
|
|
221
227
|
if not colored:
|
|
222
228
|
|
|
223
229
|
if not self.is_root():
|
|
224
|
-
return f"(a={self.action}, N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f},
|
|
230
|
+
return f"(a={self.action}, N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f}, {GymctsNeuralNode.score_variate}={self.tree_policy_score():.2f})"
|
|
225
231
|
else:
|
|
226
232
|
return f"(N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f}) [root]"
|
|
227
233
|
|
gymcts/gymcts_node.py
CHANGED
|
@@ -61,7 +61,7 @@ class GymctsNode:
|
|
|
61
61
|
if not colored:
|
|
62
62
|
|
|
63
63
|
if not self.is_root():
|
|
64
|
-
return f"(a={self.action}, N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f},
|
|
64
|
+
return f"(a={self.action}, N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f}, {GymctsNode.score_variate}={self.tree_policy_score():.2f})"
|
|
65
65
|
else:
|
|
66
66
|
return f"(N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f}) [root]"
|
|
67
67
|
|
|
@@ -102,7 +102,7 @@ class GymctsNode:
|
|
|
102
102
|
f"{p}N{e}={colorful_value(self.visit_count)}, "
|
|
103
103
|
f"{p}Q_v{e}={ccu.wrap_with_color_scale(s=mean_val, value=self.mean_value, min_val=root_node.min_value, max_val=root_node.max_value)}, "
|
|
104
104
|
f"{p}best{e}={colorful_value(self.max_value)}") +
|
|
105
|
-
(f", {p}
|
|
105
|
+
(f", {p}{GymctsNode.score_variate}{e}={colorful_value(self.tree_policy_score())})" if not self.is_root() else ")"))
|
|
106
106
|
|
|
107
107
|
def traverse_nodes(self) -> Generator[TGymctsNode, None, None]:
|
|
108
108
|
"""
|
|
@@ -326,16 +326,20 @@ class GymctsNode:
|
|
|
326
326
|
raise ValueError("ucb_score can only be called on non-root nodes")
|
|
327
327
|
c = GymctsNode.ubc_c # default is 0.707
|
|
328
328
|
|
|
329
|
+
assert 0 <= GymctsNode.best_action_weight <= 1
|
|
330
|
+
b = GymctsNode.best_action_weight
|
|
331
|
+
exploitation_term = 0.0 if self.visit_count == 0 else (1 - b) * self.mean_value + b * self.max_value
|
|
332
|
+
|
|
329
333
|
if GymctsNode.score_variate == "UCT_v0":
|
|
330
334
|
if self.visit_count == 0:
|
|
331
335
|
return float("inf")
|
|
332
|
-
return
|
|
336
|
+
return exploitation_term + c * math.sqrt( 2 * math.log(self.parent.visit_count) / (self.visit_count))
|
|
333
337
|
|
|
334
338
|
if GymctsNode.score_variate == "UCT_v1":
|
|
335
|
-
return
|
|
339
|
+
return exploitation_term + c * math.sqrt( math.log(self.parent.visit_count) / (1 + self.visit_count))
|
|
336
340
|
|
|
337
341
|
if GymctsNode.score_variate == "UCT_v2":
|
|
338
|
-
return
|
|
342
|
+
return exploitation_term + c * math.sqrt(self.parent.visit_count) / (1 + self.visit_count)
|
|
339
343
|
|
|
340
344
|
raise ValueError(f"unknown score variate: {GymctsNode.score_variate}. ")
|
|
341
345
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: gymcts
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.4.1
|
|
4
4
|
Summary: A minimalistic implementation of the Monte Carlo Tree Search algorithm for planning problems fomulated as gymnaisum reinforcement learning environments.
|
|
5
5
|
Author: Alexander Nasuta
|
|
6
6
|
Author-email: Alexander Nasuta <alexander.nasuta@wzl-iqs.rwth-aachen.de>
|
|
@@ -1,16 +1,16 @@
|
|
|
1
1
|
gymcts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
2
|
gymcts/colorful_console_utils.py,sha256=n7nymC8kKZnA_8nXcdn201NAzjZjgEHfKpbBcnl4oAE,5891
|
|
3
3
|
gymcts/gymcts_action_history_wrapper.py,sha256=7-p17Fgb80SRCBaCm6G8SJrEPsl2Y4aIO3InviuQP08,6993
|
|
4
|
-
gymcts/gymcts_agent.py,sha256=
|
|
4
|
+
gymcts/gymcts_agent.py,sha256=FzMPjHXyKN6enNJubmYEouvb0wBbE1-bpxuLuW4J1gU,10960
|
|
5
5
|
gymcts/gymcts_deepcopy_wrapper.py,sha256=lCCT5-6JVCwUCP__4uPMMkT5HnO2JWm2ebzJ69zXp9c,6792
|
|
6
6
|
gymcts/gymcts_distributed_agent.py,sha256=Ha9UBQvFjoErfMWvPyN0JcTYz-JaiJ4eWjLMikp9Yhs,11569
|
|
7
7
|
gymcts/gymcts_env_abc.py,sha256=iqrFNNSa-kZyAGk1UN2BjkdkV6NufAkYJT8d7PlQ07E,2525
|
|
8
|
-
gymcts/gymcts_neural_agent.py,sha256=
|
|
9
|
-
gymcts/gymcts_node.py,sha256
|
|
8
|
+
gymcts/gymcts_neural_agent.py,sha256=kP2DwoZ6nM4lUYqePhEvUIAqmZegB0oxQ3uMtMFj-Hk,16049
|
|
9
|
+
gymcts/gymcts_node.py,sha256=KAR5y1MrT8c_7ZXwTuCj77B7DiERDfHplF8avs76JHU,13410
|
|
10
10
|
gymcts/gymcts_tree_plotter.py,sha256=PR6C7q9Q4kuz1aLGyD7-aZsxk3RqlHZpOqmOiRpCyK0,3547
|
|
11
11
|
gymcts/logger.py,sha256=RI7B9cvbBGrj0_QIAI77wihzuu2tPG_-z9GM2Mw5aHE,926
|
|
12
|
-
gymcts-1.
|
|
13
|
-
gymcts-1.
|
|
14
|
-
gymcts-1.
|
|
15
|
-
gymcts-1.
|
|
16
|
-
gymcts-1.
|
|
12
|
+
gymcts-1.4.1.dist-info/licenses/LICENSE,sha256=UGe75WojDiw_77SEnK2aysEDlElRlkWie7U7NaAFx00,1072
|
|
13
|
+
gymcts-1.4.1.dist-info/METADATA,sha256=DsGxePuo5m6SgNPRjrkuxUK-em2IWBm0b20ET-CVdP0,23864
|
|
14
|
+
gymcts-1.4.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
15
|
+
gymcts-1.4.1.dist-info/top_level.txt,sha256=E8MoLsPimUPD0H1Y6lum4TVe-lhSDAyBAXGrkYIT52w,7
|
|
16
|
+
gymcts-1.4.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|