gymcts 1.3.0__tar.gz → 1.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. {gymcts-1.3.0/src/gymcts.egg-info → gymcts-1.4.0}/PKG-INFO +1 -1
  2. {gymcts-1.3.0 → gymcts-1.4.0}/pyproject.toml +1 -1
  3. {gymcts-1.3.0 → gymcts-1.4.0}/src/gymcts/gymcts_agent.py +3 -1
  4. {gymcts-1.3.0 → gymcts-1.4.0}/src/gymcts/gymcts_neural_agent.py +19 -13
  5. {gymcts-1.3.0 → gymcts-1.4.0}/src/gymcts/gymcts_node.py +7 -3
  6. {gymcts-1.3.0 → gymcts-1.4.0/src/gymcts.egg-info}/PKG-INFO +1 -1
  7. {gymcts-1.3.0 → gymcts-1.4.0}/tests/test_number_of_visits.py +4 -3
  8. {gymcts-1.3.0 → gymcts-1.4.0}/LICENSE +0 -0
  9. {gymcts-1.3.0 → gymcts-1.4.0}/MANIFEST.in +0 -0
  10. {gymcts-1.3.0 → gymcts-1.4.0}/README.md +0 -0
  11. {gymcts-1.3.0 → gymcts-1.4.0}/setup.cfg +0 -0
  12. {gymcts-1.3.0 → gymcts-1.4.0}/setup.py +0 -0
  13. {gymcts-1.3.0 → gymcts-1.4.0}/src/gymcts/__init__.py +0 -0
  14. {gymcts-1.3.0 → gymcts-1.4.0}/src/gymcts/colorful_console_utils.py +0 -0
  15. {gymcts-1.3.0 → gymcts-1.4.0}/src/gymcts/gymcts_action_history_wrapper.py +0 -0
  16. {gymcts-1.3.0 → gymcts-1.4.0}/src/gymcts/gymcts_deepcopy_wrapper.py +0 -0
  17. {gymcts-1.3.0 → gymcts-1.4.0}/src/gymcts/gymcts_distributed_agent.py +0 -0
  18. {gymcts-1.3.0 → gymcts-1.4.0}/src/gymcts/gymcts_env_abc.py +0 -0
  19. {gymcts-1.3.0 → gymcts-1.4.0}/src/gymcts/gymcts_tree_plotter.py +0 -0
  20. {gymcts-1.3.0 → gymcts-1.4.0}/src/gymcts/logger.py +0 -0
  21. {gymcts-1.3.0 → gymcts-1.4.0}/src/gymcts.egg-info/SOURCES.txt +0 -0
  22. {gymcts-1.3.0 → gymcts-1.4.0}/src/gymcts.egg-info/dependency_links.txt +0 -0
  23. {gymcts-1.3.0 → gymcts-1.4.0}/src/gymcts.egg-info/not-zip-safe +0 -0
  24. {gymcts-1.3.0 → gymcts-1.4.0}/src/gymcts.egg-info/requires.txt +0 -0
  25. {gymcts-1.3.0 → gymcts-1.4.0}/src/gymcts.egg-info/top_level.txt +0 -0
  26. {gymcts-1.3.0 → gymcts-1.4.0}/tests/test_graph_matrix_jsp_env.py +0 -0
  27. {gymcts-1.3.0 → gymcts-1.4.0}/tests/test_gymnasium_envs.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gymcts
3
- Version: 1.3.0
3
+ Version: 1.4.0
4
4
  Summary: A minimalistic implementation of the Monte Carlo Tree Search algorithm for planning problems fomulated as gymnaisum reinforcement learning environments.
5
5
  Author: Alexander Nasuta
6
6
  Author-email: Alexander Nasuta <alexander.nasuta@wzl-iqs.rwth-aachen.de>
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "gymcts"
7
- version = "1.3.0"
7
+ version = "1.4.0"
8
8
  description = "A minimalistic implementation of the Monte Carlo Tree Search algorithm for planning problems fomulated as gymnaisum reinforcement learning environments."
9
9
  readme = "README.md"
10
10
  authors = [{ name = "Alexander Nasuta", email = "alexander.nasuta@wzl-iqs.rwth-aachen.de" }]
@@ -49,6 +49,7 @@ class GymctsAgent:
49
49
  calc_number_of_simulations_per_step: Callable[[int,int], int] = None,
50
50
  score_variate: Literal["UCT_v0", "UCT_v1", "UCT_v2",] = "UCT_v0",
51
51
  best_action_weight=None,
52
+ keep_whole_tree_till_initial_root: bool = False,
52
53
  ):
53
54
  # check if action space of env is discrete
54
55
  if not isinstance(env.action_space, gym.spaces.Discrete):
@@ -79,6 +80,7 @@ class GymctsAgent:
79
80
 
80
81
  self.env = env
81
82
  self.clear_mcts_tree_after_step = clear_mcts_tree_after_step
83
+ self.keep_whole_tree_till_initial_root = keep_whole_tree_till_initial_root
82
84
 
83
85
  self.search_root_node = GymctsNode(
84
86
  action=None,
@@ -190,7 +192,7 @@ class GymctsAgent:
190
192
  # we also need to reset the children of the current node
191
193
  # this is done by calling the reset method
192
194
  next_node.reset()
193
- else:
195
+ elif not self.keep_whole_tree_till_initial_root:
194
196
  next_node.remove_parent()
195
197
 
196
198
  self.search_root_node = next_node
@@ -162,37 +162,43 @@ class GymctsNeuralNode(GymctsNode):
162
162
  p_sa = self._selection_score_prior
163
163
  n_s = self.parent.visit_count
164
164
  n_sa = self.visit_count
165
+
166
+ assert 0 <= GymctsNode.best_action_weight <= 1
167
+ b = GymctsNode.best_action_weight
168
+ exploitation_term = 0.0 if self.visit_count == 0 else (1 - b) * self.mean_value + b * self.max_value
169
+
170
+
165
171
  if GymctsNeuralNode.score_variate == "PUCT_v0":
166
- return self.mean_value + c * p_sa * math.sqrt(n_s) / (1 + n_sa)
172
+ return exploitation_term + c * p_sa * math.sqrt(n_s) / (1 + n_sa)
167
173
  elif GymctsNeuralNode.score_variate == "PUCT_v1":
168
- return self.mean_value + c * p_sa * math.sqrt(2 * math.log(n_s) / (n_sa))
174
+ return exploitation_term + c * p_sa * math.sqrt(2 * math.log(n_s) / (n_sa))
169
175
  elif GymctsNeuralNode.score_variate == "PUCT_v2":
170
- return self.mean_value + c * p_sa * math.sqrt(n_s) / n_sa
176
+ return exploitation_term + c * p_sa * math.sqrt(n_s) / n_sa
171
177
  elif GymctsNeuralNode.score_variate == "PUCT_v3":
172
- return self.mean_value + c * (p_sa ** GymctsNeuralNode.PUCT_v3_mu) * math.sqrt(n_s / (1 + n_sa))
178
+ return exploitation_term + c * (p_sa ** GymctsNeuralNode.PUCT_v3_mu) * math.sqrt(n_s / (1 + n_sa))
173
179
  elif GymctsNeuralNode.score_variate == "PUCT_v4":
174
- return self.mean_value + c * (p_sa / (1 + n_sa))
180
+ return exploitation_term + c * (p_sa / (1 + n_sa))
175
181
  elif GymctsNeuralNode.score_variate == "PUCT_v5":
176
- return self.mean_value + c * p_sa * (math.sqrt(n_s) + 1) / (n_sa + 1)
182
+ return exploitation_term + c * p_sa * (math.sqrt(n_s) + 1) / (n_sa + 1)
177
183
  elif GymctsNeuralNode.score_variate == "PUCT_v6":
178
- return self.mean_value + c * p_sa * n_s / (1 + n_sa)
184
+ return exploitation_term + c * p_sa * n_s / (1 + n_sa)
179
185
  elif GymctsNeuralNode.score_variate == "PUCT_v7":
180
186
  epsilon = 1e-8
181
- return self.mean_value + c * p_sa * (math.sqrt(n_s) + epsilon) / (n_sa + 1)
187
+ return exploitation_term + c * p_sa * (math.sqrt(n_s) + epsilon) / (n_sa + 1)
182
188
  elif GymctsNeuralNode.score_variate == "PUCT_v8":
183
- return self.mean_value + c * p_sa * math.sqrt((math.log(n_s) + 1) / (1 + n_sa))
189
+ return exploitation_term + c * p_sa * math.sqrt((math.log(n_s) + 1) / (1 + n_sa))
184
190
  elif GymctsNeuralNode.score_variate == "PUCT_v9":
185
- return self.mean_value + c * p_sa * math.sqrt(n_s / (1 + n_sa))
191
+ return exploitation_term + c * p_sa * math.sqrt(n_s / (1 + n_sa))
186
192
  elif GymctsNeuralNode.score_variate == "PUCT_v10":
187
- return self.mean_value + c * p_sa * math.sqrt(math.log(n_s) / (1 + n_sa))
193
+ return exploitation_term + c * p_sa * math.sqrt(math.log(n_s) / (1 + n_sa))
188
194
  elif GymctsNeuralNode.score_variate == "MuZero_v0":
189
195
  c1 = GymctsNeuralNode.MuZero_c1
190
196
  c2 = GymctsNeuralNode.MuZero_c2
191
- return self.mean_value + c * p_sa * math.sqrt(n_s) / (1 + n_sa) * (c1 + math.log((n_s + c2 + 1) / c2))
197
+ return exploitation_term + c * p_sa * math.sqrt(n_s) / (1 + n_sa) * (c1 + math.log((n_s + c2 + 1) / c2))
192
198
  elif GymctsNeuralNode.score_variate == "MuZero_v1":
193
199
  c1 = GymctsNeuralNode.MuZero_c1
194
200
  c2 = GymctsNeuralNode.MuZero_c2
195
- return self.mean_value + c * p_sa * math.sqrt(n_s) / (1 + n_sa) * (c1 + math.log((n_s + c2 + 1) / c2))
201
+ return exploitation_term + c * p_sa * math.sqrt(n_s) / (1 + n_sa) * (c1 + math.log((n_s + c2 + 1) / c2))
196
202
 
197
203
 
198
204
  exploration_term = self._selection_score_prior * c * math.sqrt(math.log(self.parent.visit_count) / (self.visit_count)) if self.visit_count > 0 else float("inf")
@@ -326,16 +326,20 @@ class GymctsNode:
326
326
  raise ValueError("ucb_score can only be called on non-root nodes")
327
327
  c = GymctsNode.ubc_c # default is 0.707
328
328
 
329
+ assert 0 <= GymctsNode.best_action_weight <= 1
330
+ b = GymctsNode.best_action_weight
331
+ exploitation_term = 0.0 if self.visit_count == 0 else (1 - b) * self.mean_value + b * self.max_value
332
+
329
333
  if GymctsNode.score_variate == "UCT_v0":
330
334
  if self.visit_count == 0:
331
335
  return float("inf")
332
- return self.mean_value + c * math.sqrt( 2 * math.log(self.parent.visit_count) / (self.visit_count))
336
+ return exploitation_term + c * math.sqrt( 2 * math.log(self.parent.visit_count) / (self.visit_count))
333
337
 
334
338
  if GymctsNode.score_variate == "UCT_v1":
335
- return self.mean_value + c * math.sqrt( math.log(self.parent.visit_count) / (1 + self.visit_count))
339
+ return exploitation_term + c * math.sqrt( math.log(self.parent.visit_count) / (1 + self.visit_count))
336
340
 
337
341
  if GymctsNode.score_variate == "UCT_v2":
338
- return self.mean_value + c * math.sqrt(self.parent.visit_count) / (1 + self.visit_count)
342
+ return exploitation_term + c * math.sqrt(self.parent.visit_count) / (1 + self.visit_count)
339
343
 
340
344
  raise ValueError(f"unknown score variate: {GymctsNode.score_variate}. ")
341
345
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gymcts
3
- Version: 1.3.0
3
+ Version: 1.4.0
4
4
  Summary: A minimalistic implementation of the Monte Carlo Tree Search algorithm for planning problems fomulated as gymnaisum reinforcement learning environments.
5
5
  Author: Alexander Nasuta
6
6
  Author-email: Alexander Nasuta <alexander.nasuta@wzl-iqs.rwth-aachen.de>
@@ -4,7 +4,7 @@ from gymcts.gymcts_agent import GymctsAgent
4
4
  def test_number_of_visits_without_clearing_root(graph_matrix_env_naive_wrapper_singe_job_jsp_instance):
5
5
  env = graph_matrix_env_naive_wrapper_singe_job_jsp_instance
6
6
 
7
- agent = GymctsAgent(env=env, clear_mcts_tree_after_step=False)
7
+ agent = GymctsAgent(env=env, clear_mcts_tree_after_step=False, keep_whole_tree_till_initial_root=True)
8
8
 
9
9
  assert agent.search_root_node.visit_count == 0
10
10
  agent.vanilla_mcts_search(search_start_node=agent.search_root_node, num_simulations=10)
@@ -20,7 +20,7 @@ def test_number_of_visits_without_clearing_root(graph_matrix_env_naive_wrapper_s
20
20
  def test_number_of_visits_without_clearing(graph_matrix_env_naive_wrapper_singe_job_jsp_instance):
21
21
  env = graph_matrix_env_naive_wrapper_singe_job_jsp_instance
22
22
 
23
- agent = GymctsAgent(env=env, clear_mcts_tree_after_step=False)
23
+ agent = GymctsAgent(env=env, clear_mcts_tree_after_step=False, keep_whole_tree_till_initial_root=True)
24
24
  assert agent.search_root_node.visit_count == 0
25
25
 
26
26
  actions = agent.solve(num_simulations_per_step=10)
@@ -43,7 +43,7 @@ def test_number_of_visits_without_clearing_root_dynamic_step_size(graph_matrix_e
43
43
 
44
44
  env = graph_matrix_env_naive_wrapper_singe_job_jsp_instance
45
45
 
46
- agent = GymctsAgent(env=env, clear_mcts_tree_after_step=False)
46
+ agent = GymctsAgent(env=env, clear_mcts_tree_after_step=False, keep_whole_tree_till_initial_root=True)
47
47
 
48
48
  tree_root = agent.search_root_node
49
49
 
@@ -89,6 +89,7 @@ def test_number_of_visits_with_clearing_root2(graph_matrix_env_naive_wrapper_two
89
89
  agent = GymctsAgent(
90
90
  env=env,
91
91
  clear_mcts_tree_after_step=False,
92
+ keep_whole_tree_till_initial_root=True,
92
93
  number_of_simulations_per_step=50
93
94
  )
94
95
 
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes