gymcts 1.3.0__py3-none-any.whl → 1.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gymcts/gymcts_agent.py CHANGED
@@ -49,6 +49,7 @@ class GymctsAgent:
49
49
  calc_number_of_simulations_per_step: Callable[[int,int], int] = None,
50
50
  score_variate: Literal["UCT_v0", "UCT_v1", "UCT_v2",] = "UCT_v0",
51
51
  best_action_weight=None,
52
+ keep_whole_tree_till_initial_root: bool = False,
52
53
  ):
53
54
  # check if action space of env is discrete
54
55
  if not isinstance(env.action_space, gym.spaces.Discrete):
@@ -79,6 +80,7 @@ class GymctsAgent:
79
80
 
80
81
  self.env = env
81
82
  self.clear_mcts_tree_after_step = clear_mcts_tree_after_step
83
+ self.keep_whole_tree_till_initial_root = keep_whole_tree_till_initial_root
82
84
 
83
85
  self.search_root_node = GymctsNode(
84
86
  action=None,
@@ -190,7 +192,7 @@ class GymctsAgent:
190
192
  # we also need to reset the children of the current node
191
193
  # this is done by calling the reset method
192
194
  next_node.reset()
193
- else:
195
+ elif not self.keep_whole_tree_till_initial_root:
194
196
  next_node.remove_parent()
195
197
 
196
198
  self.search_root_node = next_node
@@ -162,37 +162,43 @@ class GymctsNeuralNode(GymctsNode):
162
162
  p_sa = self._selection_score_prior
163
163
  n_s = self.parent.visit_count
164
164
  n_sa = self.visit_count
165
+
166
+ assert 0 <= GymctsNode.best_action_weight <= 1
167
+ b = GymctsNode.best_action_weight
168
+ exploitation_term = 0.0 if self.visit_count == 0 else (1 - b) * self.mean_value + b * self.max_value
169
+
170
+
165
171
  if GymctsNeuralNode.score_variate == "PUCT_v0":
166
- return self.mean_value + c * p_sa * math.sqrt(n_s) / (1 + n_sa)
172
+ return exploitation_term + c * p_sa * math.sqrt(n_s) / (1 + n_sa)
167
173
  elif GymctsNeuralNode.score_variate == "PUCT_v1":
168
- return self.mean_value + c * p_sa * math.sqrt(2 * math.log(n_s) / (n_sa))
174
+ return exploitation_term + c * p_sa * math.sqrt(2 * math.log(n_s) / (n_sa))
169
175
  elif GymctsNeuralNode.score_variate == "PUCT_v2":
170
- return self.mean_value + c * p_sa * math.sqrt(n_s) / n_sa
176
+ return exploitation_term + c * p_sa * math.sqrt(n_s) / n_sa
171
177
  elif GymctsNeuralNode.score_variate == "PUCT_v3":
172
- return self.mean_value + c * (p_sa ** GymctsNeuralNode.PUCT_v3_mu) * math.sqrt(n_s / (1 + n_sa))
178
+ return exploitation_term + c * (p_sa ** GymctsNeuralNode.PUCT_v3_mu) * math.sqrt(n_s / (1 + n_sa))
173
179
  elif GymctsNeuralNode.score_variate == "PUCT_v4":
174
- return self.mean_value + c * (p_sa / (1 + n_sa))
180
+ return exploitation_term + c * (p_sa / (1 + n_sa))
175
181
  elif GymctsNeuralNode.score_variate == "PUCT_v5":
176
- return self.mean_value + c * p_sa * (math.sqrt(n_s) + 1) / (n_sa + 1)
182
+ return exploitation_term + c * p_sa * (math.sqrt(n_s) + 1) / (n_sa + 1)
177
183
  elif GymctsNeuralNode.score_variate == "PUCT_v6":
178
- return self.mean_value + c * p_sa * n_s / (1 + n_sa)
184
+ return exploitation_term + c * p_sa * n_s / (1 + n_sa)
179
185
  elif GymctsNeuralNode.score_variate == "PUCT_v7":
180
186
  epsilon = 1e-8
181
- return self.mean_value + c * p_sa * (math.sqrt(n_s) + epsilon) / (n_sa + 1)
187
+ return exploitation_term + c * p_sa * (math.sqrt(n_s) + epsilon) / (n_sa + 1)
182
188
  elif GymctsNeuralNode.score_variate == "PUCT_v8":
183
- return self.mean_value + c * p_sa * math.sqrt((math.log(n_s) + 1) / (1 + n_sa))
189
+ return exploitation_term + c * p_sa * math.sqrt((math.log(n_s) + 1) / (1 + n_sa))
184
190
  elif GymctsNeuralNode.score_variate == "PUCT_v9":
185
- return self.mean_value + c * p_sa * math.sqrt(n_s / (1 + n_sa))
191
+ return exploitation_term + c * p_sa * math.sqrt(n_s / (1 + n_sa))
186
192
  elif GymctsNeuralNode.score_variate == "PUCT_v10":
187
- return self.mean_value + c * p_sa * math.sqrt(math.log(n_s) / (1 + n_sa))
193
+ return exploitation_term + c * p_sa * math.sqrt(math.log(n_s) / (1 + n_sa))
188
194
  elif GymctsNeuralNode.score_variate == "MuZero_v0":
189
195
  c1 = GymctsNeuralNode.MuZero_c1
190
196
  c2 = GymctsNeuralNode.MuZero_c2
191
- return self.mean_value + c * p_sa * math.sqrt(n_s) / (1 + n_sa) * (c1 + math.log((n_s + c2 + 1) / c2))
197
+ return exploitation_term + c * p_sa * math.sqrt(n_s) / (1 + n_sa) * (c1 + math.log((n_s + c2 + 1) / c2))
192
198
  elif GymctsNeuralNode.score_variate == "MuZero_v1":
193
199
  c1 = GymctsNeuralNode.MuZero_c1
194
200
  c2 = GymctsNeuralNode.MuZero_c2
195
- return self.mean_value + c * p_sa * math.sqrt(n_s) / (1 + n_sa) * (c1 + math.log((n_s + c2 + 1) / c2))
201
+ return exploitation_term + c * p_sa * math.sqrt(n_s) / (1 + n_sa) * (c1 + math.log((n_s + c2 + 1) / c2))
196
202
 
197
203
 
198
204
  exploration_term = self._selection_score_prior * c * math.sqrt(math.log(self.parent.visit_count) / (self.visit_count)) if self.visit_count > 0 else float("inf")
@@ -221,7 +227,7 @@ class GymctsNeuralNode(GymctsNode):
221
227
  if not colored:
222
228
 
223
229
  if not self.is_root():
224
- return f"(a={self.action}, N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f}, ubc={self.tree_policy_score():.2f})"
230
+ return f"(a={self.action}, N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f}, {GymctsNeuralNode.score_variate}={self.tree_policy_score():.2f})"
225
231
  else:
226
232
  return f"(N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f}) [root]"
227
233
 
gymcts/gymcts_node.py CHANGED
@@ -61,7 +61,7 @@ class GymctsNode:
61
61
  if not colored:
62
62
 
63
63
  if not self.is_root():
64
- return f"(a={self.action}, N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f}, ubc={self.tree_policy_score():.2f})"
64
+ return f"(a={self.action}, N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f}, {GymctsNode.score_variate}={self.tree_policy_score():.2f})"
65
65
  else:
66
66
  return f"(N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f}) [root]"
67
67
 
@@ -102,7 +102,7 @@ class GymctsNode:
102
102
  f"{p}N{e}={colorful_value(self.visit_count)}, "
103
103
  f"{p}Q_v{e}={ccu.wrap_with_color_scale(s=mean_val, value=self.mean_value, min_val=root_node.min_value, max_val=root_node.max_value)}, "
104
104
  f"{p}best{e}={colorful_value(self.max_value)}") +
105
- (f", {p}ubc{e}={colorful_value(self.tree_policy_score())})" if not self.is_root() else ")"))
105
+ (f", {p}{GymctsNode.score_variate}{e}={colorful_value(self.tree_policy_score())})" if not self.is_root() else ")"))
106
106
 
107
107
  def traverse_nodes(self) -> Generator[TGymctsNode, None, None]:
108
108
  """
@@ -326,16 +326,20 @@ class GymctsNode:
326
326
  raise ValueError("ucb_score can only be called on non-root nodes")
327
327
  c = GymctsNode.ubc_c # default is 0.707
328
328
 
329
+ assert 0 <= GymctsNode.best_action_weight <= 1
330
+ b = GymctsNode.best_action_weight
331
+ exploitation_term = 0.0 if self.visit_count == 0 else (1 - b) * self.mean_value + b * self.max_value
332
+
329
333
  if GymctsNode.score_variate == "UCT_v0":
330
334
  if self.visit_count == 0:
331
335
  return float("inf")
332
- return self.mean_value + c * math.sqrt( 2 * math.log(self.parent.visit_count) / (self.visit_count))
336
+ return exploitation_term + c * math.sqrt( 2 * math.log(self.parent.visit_count) / (self.visit_count))
333
337
 
334
338
  if GymctsNode.score_variate == "UCT_v1":
335
- return self.mean_value + c * math.sqrt( math.log(self.parent.visit_count) / (1 + self.visit_count))
339
+ return exploitation_term + c * math.sqrt( math.log(self.parent.visit_count) / (1 + self.visit_count))
336
340
 
337
341
  if GymctsNode.score_variate == "UCT_v2":
338
- return self.mean_value + c * math.sqrt(self.parent.visit_count) / (1 + self.visit_count)
342
+ return exploitation_term + c * math.sqrt(self.parent.visit_count) / (1 + self.visit_count)
339
343
 
340
344
  raise ValueError(f"unknown score variate: {GymctsNode.score_variate}. ")
341
345
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gymcts
3
- Version: 1.3.0
3
+ Version: 1.4.1
4
4
  Summary: A minimalistic implementation of the Monte Carlo Tree Search algorithm for planning problems fomulated as gymnaisum reinforcement learning environments.
5
5
  Author: Alexander Nasuta
6
6
  Author-email: Alexander Nasuta <alexander.nasuta@wzl-iqs.rwth-aachen.de>
@@ -1,16 +1,16 @@
1
1
  gymcts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  gymcts/colorful_console_utils.py,sha256=n7nymC8kKZnA_8nXcdn201NAzjZjgEHfKpbBcnl4oAE,5891
3
3
  gymcts/gymcts_action_history_wrapper.py,sha256=7-p17Fgb80SRCBaCm6G8SJrEPsl2Y4aIO3InviuQP08,6993
4
- gymcts/gymcts_agent.py,sha256=OAcN2-mFCR2AVJrRZlRtROF_zHk90SIM-uAebKektIc,10768
4
+ gymcts/gymcts_agent.py,sha256=FzMPjHXyKN6enNJubmYEouvb0wBbE1-bpxuLuW4J1gU,10960
5
5
  gymcts/gymcts_deepcopy_wrapper.py,sha256=lCCT5-6JVCwUCP__4uPMMkT5HnO2JWm2ebzJ69zXp9c,6792
6
6
  gymcts/gymcts_distributed_agent.py,sha256=Ha9UBQvFjoErfMWvPyN0JcTYz-JaiJ4eWjLMikp9Yhs,11569
7
7
  gymcts/gymcts_env_abc.py,sha256=iqrFNNSa-kZyAGk1UN2BjkdkV6NufAkYJT8d7PlQ07E,2525
8
- gymcts/gymcts_neural_agent.py,sha256=urYGA5D6idChPke8Ac9zqhKy2NqkJzt3Zt-j8V6OpuQ,15785
9
- gymcts/gymcts_node.py,sha256=-YKfK5fryPteCp-UTsAgzFVIBucZdXPMbXHCIb6mS24,13151
8
+ gymcts/gymcts_neural_agent.py,sha256=kP2DwoZ6nM4lUYqePhEvUIAqmZegB0oxQ3uMtMFj-Hk,16049
9
+ gymcts/gymcts_node.py,sha256=KAR5y1MrT8c_7ZXwTuCj77B7DiERDfHplF8avs76JHU,13410
10
10
  gymcts/gymcts_tree_plotter.py,sha256=PR6C7q9Q4kuz1aLGyD7-aZsxk3RqlHZpOqmOiRpCyK0,3547
11
11
  gymcts/logger.py,sha256=RI7B9cvbBGrj0_QIAI77wihzuu2tPG_-z9GM2Mw5aHE,926
12
- gymcts-1.3.0.dist-info/licenses/LICENSE,sha256=UGe75WojDiw_77SEnK2aysEDlElRlkWie7U7NaAFx00,1072
13
- gymcts-1.3.0.dist-info/METADATA,sha256=pyhdSu_PAMi9IbVeSsHU0EcJSasAMttrtz-pKIjbePw,23864
14
- gymcts-1.3.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
15
- gymcts-1.3.0.dist-info/top_level.txt,sha256=E8MoLsPimUPD0H1Y6lum4TVe-lhSDAyBAXGrkYIT52w,7
16
- gymcts-1.3.0.dist-info/RECORD,,
12
+ gymcts-1.4.1.dist-info/licenses/LICENSE,sha256=UGe75WojDiw_77SEnK2aysEDlElRlkWie7U7NaAFx00,1072
13
+ gymcts-1.4.1.dist-info/METADATA,sha256=DsGxePuo5m6SgNPRjrkuxUK-em2IWBm0b20ET-CVdP0,23864
14
+ gymcts-1.4.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
15
+ gymcts-1.4.1.dist-info/top_level.txt,sha256=E8MoLsPimUPD0H1Y6lum4TVe-lhSDAyBAXGrkYIT52w,7
16
+ gymcts-1.4.1.dist-info/RECORD,,
File without changes