gymcts 1.2.1__tar.gz → 1.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. {gymcts-1.2.1/src/gymcts.egg-info → gymcts-1.3.0}/PKG-INFO +9 -5
  2. {gymcts-1.2.1 → gymcts-1.3.0}/README.md +8 -4
  3. {gymcts-1.2.1 → gymcts-1.3.0}/pyproject.toml +1 -1
  4. {gymcts-1.2.1 → gymcts-1.3.0}/src/gymcts/gymcts_agent.py +51 -8
  5. {gymcts-1.2.1 → gymcts-1.3.0}/src/gymcts/gymcts_env_abc.py +12 -2
  6. gymcts-1.3.0/src/gymcts/gymcts_neural_agent.py +479 -0
  7. {gymcts-1.2.1 → gymcts-1.3.0}/src/gymcts/gymcts_node.py +76 -9
  8. {gymcts-1.2.1 → gymcts-1.3.0/src/gymcts.egg-info}/PKG-INFO +9 -5
  9. {gymcts-1.2.1 → gymcts-1.3.0}/src/gymcts.egg-info/SOURCES.txt +1 -0
  10. {gymcts-1.2.1 → gymcts-1.3.0}/LICENSE +0 -0
  11. {gymcts-1.2.1 → gymcts-1.3.0}/MANIFEST.in +0 -0
  12. {gymcts-1.2.1 → gymcts-1.3.0}/setup.cfg +0 -0
  13. {gymcts-1.2.1 → gymcts-1.3.0}/setup.py +0 -0
  14. {gymcts-1.2.1 → gymcts-1.3.0}/src/gymcts/__init__.py +0 -0
  15. {gymcts-1.2.1 → gymcts-1.3.0}/src/gymcts/colorful_console_utils.py +0 -0
  16. {gymcts-1.2.1 → gymcts-1.3.0}/src/gymcts/gymcts_action_history_wrapper.py +0 -0
  17. {gymcts-1.2.1 → gymcts-1.3.0}/src/gymcts/gymcts_deepcopy_wrapper.py +0 -0
  18. {gymcts-1.2.1 → gymcts-1.3.0}/src/gymcts/gymcts_distributed_agent.py +0 -0
  19. {gymcts-1.2.1 → gymcts-1.3.0}/src/gymcts/gymcts_tree_plotter.py +0 -0
  20. {gymcts-1.2.1 → gymcts-1.3.0}/src/gymcts/logger.py +0 -0
  21. {gymcts-1.2.1 → gymcts-1.3.0}/src/gymcts.egg-info/dependency_links.txt +0 -0
  22. {gymcts-1.2.1 → gymcts-1.3.0}/src/gymcts.egg-info/not-zip-safe +0 -0
  23. {gymcts-1.2.1 → gymcts-1.3.0}/src/gymcts.egg-info/requires.txt +0 -0
  24. {gymcts-1.2.1 → gymcts-1.3.0}/src/gymcts.egg-info/top_level.txt +0 -0
  25. {gymcts-1.2.1 → gymcts-1.3.0}/tests/test_graph_matrix_jsp_env.py +0 -0
  26. {gymcts-1.2.1 → gymcts-1.3.0}/tests/test_gymnasium_envs.py +0 -0
  27. {gymcts-1.2.1 → gymcts-1.3.0}/tests/test_number_of_visits.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: gymcts
- Version: 1.2.1
+ Version: 1.3.0
  Summary: A minimalistic implementation of the Monte Carlo Tree Search algorithm for planning problems formulated as gymnasium reinforcement learning environments.
  Author: Alexander Nasuta
  Author-email: Alexander Nasuta <alexander.nasuta@wzl-iqs.rwth-aachen.de>
@@ -70,11 +70,18 @@ Requires-Dist: jupyter; extra == "dev"
  Requires-Dist: typing_extensions>=4.12.0; extra == "dev"
  Dynamic: license-file
 
- # Graph Matrix Job Shop Env
+ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.15283390.svg)](https://doi.org/10.5281/zenodo.15283390)
+ [![Python Badge](https://img.shields.io/badge/Python-3776AB?logo=python&logoColor=fff&style=flat)](https://www.python.org/downloads/)
+ [![PyPI version](https://img.shields.io/pypi/v/gymcts)](https://pypi.org/project/gymcts/)
+ [![License](https://img.shields.io/pypi/l/gymcts)](https://github.com/Alexander-Nasuta/gymcts/blob/master/LICENSE)
+ [![Documentation Status](https://readthedocs.org/projects/gymcts/badge/?version=latest)](https://gymcts.readthedocs.io/en/latest/?badge=latest)
+
+ # GYMCTS
 
  A Monte Carlo Tree Search Implementation for Gymnasium-style Environments.
 
  - Github: [GYMCTS on Github](https://github.com/Alexander-Nasuta/gymcts)
+ - GitLab: [GYMCTS on GitLab](https://git-ce.rwth-aachen.de/alexander.nasuta/gymcts)
  - Pypi: [GYMCTS on PyPi](https://pypi.org/project/gymcts/)
  - Documentation: [GYMCTS Docs](https://gymcts.readthedocs.io/en/latest/)
 
@@ -579,9 +586,6 @@ This project uses `pytest` for testing. To run the tests, run the following comm
  ```shell
  pytest
  ```
- Here is a screenshot of what the output might look like:
-
- ![](https://github.com/Alexander-Nasuta/GraphMatrixJobShopEnv/raw/master/resources/pytest-screenshot.png)
 
  For testing with `tox` run the following command:
 
@@ -1,8 +1,15 @@
- # Graph Matrix Job Shop Env
+ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.15283390.svg)](https://doi.org/10.5281/zenodo.15283390)
+ [![Python Badge](https://img.shields.io/badge/Python-3776AB?logo=python&logoColor=fff&style=flat)](https://www.python.org/downloads/)
+ [![PyPI version](https://img.shields.io/pypi/v/gymcts)](https://pypi.org/project/gymcts/)
+ [![License](https://img.shields.io/pypi/l/gymcts)](https://github.com/Alexander-Nasuta/gymcts/blob/master/LICENSE)
+ [![Documentation Status](https://readthedocs.org/projects/gymcts/badge/?version=latest)](https://gymcts.readthedocs.io/en/latest/?badge=latest)
+
+ # GYMCTS
 
  A Monte Carlo Tree Search Implementation for Gymnasium-style Environments.
 
  - Github: [GYMCTS on Github](https://github.com/Alexander-Nasuta/gymcts)
+ - GitLab: [GYMCTS on GitLab](https://git-ce.rwth-aachen.de/alexander.nasuta/gymcts)
  - Pypi: [GYMCTS on PyPi](https://pypi.org/project/gymcts/)
  - Documentation: [GYMCTS Docs](https://gymcts.readthedocs.io/en/latest/)
 
@@ -507,9 +514,6 @@ This project uses `pytest` for testing. To run the tests, run the following comm
  ```shell
  pytest
  ```
- Here is a screenshot of what the output might look like:
-
- ![](https://github.com/Alexander-Nasuta/GraphMatrixJobShopEnv/raw/master/resources/pytest-screenshot.png)
 
  For testing with `tox` run the following command:
 
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
  [project]
  name = "gymcts"
- version = "1.2.1"
+ version = "1.3.0"
  description = "A minimalistic implementation of the Monte Carlo Tree Search algorithm for planning problems formulated as gymnasium reinforcement learning environments."
  readme = "README.md"
  authors = [{ name = "Alexander Nasuta", email = "alexander.nasuta@wzl-iqs.rwth-aachen.de" }]
@@ -2,7 +2,7 @@ import copy
  import random
  import gymnasium as gym
 
- from typing import TypeVar, Any, SupportsFloat, Callable
+ from typing import TypeVar, Any, SupportsFloat, Callable, Literal
 
  from gymcts.gymcts_env_abc import GymctsABC
  from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper
@@ -11,7 +11,9 @@ from gymcts.gymcts_tree_plotter import _generate_mcts_tree
 
  from gymcts.logger import log
 
- TSoloMCTSNode = TypeVar("TSoloMCTSNode", bound="SoloMCTSNode")
+
+
+
 
 
  class GymctsAgent:
@@ -24,17 +26,50 @@ class GymctsAgent:
      search_root_node: GymctsNode  # NOTE: this is not the same as the root of the tree!
      clear_mcts_tree_after_step: bool
 
+     # signature: (num_simulations: int, step_idx: int) -> int
+     @staticmethod
+     def calc_number_of_simulations_per_step(num_simulations: int, step_idx: int) -> int:
+         """
+         Default schedule that returns a constant number of simulations per step.
+
+         :param num_simulations: The number of simulations to return.
+         :param step_idx: The current step index (not used by this default schedule).
+         :return: The constant number of simulations.
+         """
+         return num_simulations
+
      def __init__(self,
                   env: GymctsABC,
                   clear_mcts_tree_after_step: bool = True,
                   render_tree_after_step: bool = False,
                   render_tree_max_depth: int = 2,
                   number_of_simulations_per_step: int = 25,
-                  exclude_unvisited_nodes_from_render: bool = False
+                  exclude_unvisited_nodes_from_render: bool = False,
+                  calc_number_of_simulations_per_step: Callable[[int, int], int] = None,
+                  score_variate: Literal["UCT_v0", "UCT_v1", "UCT_v2"] = "UCT_v0",
+                  best_action_weight: float = None,
                   ):
          # check if action space of env is discrete
          if not isinstance(env.action_space, gym.spaces.Discrete):
              raise ValueError("Action space must be discrete.")
+         if calc_number_of_simulations_per_step is not None:
+             # check that the provided schedule is callable
+             if not callable(calc_number_of_simulations_per_step):
+                 raise ValueError("calc_number_of_simulations_per_step must be a callable accepting two arguments: num_simulations and step_idx.")
+             # functions assigned to instance attributes are not bound to the
+             # instance, so the schedule can be stored directly
+             log.debug("Using provided calc_number_of_simulations_per_step function.")
+             self.calc_number_of_simulations_per_step = calc_number_of_simulations_per_step
+         if score_variate not in ["UCT_v0", "UCT_v1", "UCT_v2"]:
+             raise ValueError("score_variate must be one of ['UCT_v0', 'UCT_v1', 'UCT_v2'].")
+         GymctsNode.score_variate = score_variate
+
+         if best_action_weight is not None:
+             if best_action_weight < 0 or best_action_weight > 1:
+                 raise ValueError("best_action_weight must be in range [0, 1].")
+             GymctsNode.best_action_weight = best_action_weight
 
          self.render_tree_after_step = render_tree_after_step
          self.exclude_unvisited_nodes_from_render = exclude_unvisited_nodes_from_render
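For illustration (not part of the released diff): constructing an agent with the new constructor options might look like the sketch below. `MyGymctsEnv` is a hypothetical `GymctsABC` implementation, not part of the package.

```python
from gymcts.gymcts_agent import GymctsAgent

# MyGymctsEnv is a hypothetical GymctsABC implementation standing in for a real environment
env = MyGymctsEnv()

agent = GymctsAgent(
    env=env,
    number_of_simulations_per_step=50,
    score_variate="UCT_v1",   # one of "UCT_v0", "UCT_v1", "UCT_v2"
    best_action_weight=0.1,   # must lie in [0, 1]
)
```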
@@ -65,8 +100,8 @@ class GymctsAgent:
          # select child with highest UCB score
          while not temp_node.is_leaf():
              children = list(temp_node.children.values())
-             max_ucb_score = max(child.ucb_score() for child in children)
-             best_children = [child for child in children if child.ucb_score() == max_ucb_score]
+             max_ucb_score = max(child.tree_policy_score() for child in children)
+             best_children = [child for child in children if child.tree_policy_score() == max_ucb_score]
              temp_node = random.choice(best_children)
          log.debug(f"Selected leaf node: {temp_node}")
          return temp_node
@@ -88,7 +123,6 @@ class GymctsAgent:
                  parent=node,
                  env_reference=self.env,
              )
-
          node.children = child_dict
 
      def solve(self, num_simulations_per_step: int = None, render_tree_after_step: bool = None) -> list[int]:
  def solve(self, num_simulations_per_step: int = None, render_tree_after_step: bool = None) -> list[int]:
@@ -104,13 +138,20 @@ class GymctsAgent:
104
138
 
105
139
  action_list = []
106
140
 
141
+ idx = 0
107
142
  while not current_node.terminal:
108
- next_action, current_node = self.perform_mcts_step(num_simulations=num_simulations_per_step,
143
+ num_sims = self.calc_number_of_simulations_per_step(num_simulations_per_step, idx)
144
+
145
+ log.info(f"Performing MCTS step {idx} with {num_sims} simulations.")
146
+
147
+ next_action, current_node = self.perform_mcts_step(num_simulations=num_sims,
109
148
  render_tree_after_step=render_tree_after_step)
110
- log.info(f"selected action {next_action} after {num_simulations_per_step} simulations.")
149
+ log.info(f"selected action {next_action} after {num_sims} simulations.")
111
150
  action_list.append(next_action)
112
151
  log.info(f"current action list: {action_list}")
113
152
 
153
+ idx += 1
154
+
114
155
  log.info(f"Final action list: {action_list}")
115
156
  # restore state of current node
116
157
  return action_list
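For illustration (not part of the released diff): a custom simulation schedule plugged into the `calc_number_of_simulations_per_step` hook used by `solve()` might look like this sketch; the decay constants are illustrative assumptions.

```python
def decaying_schedule(num_simulations: int, step_idx: int) -> int:
    # spend more simulations on early steps, fewer later on, but never fewer than 5
    return max(5, num_simulations - 2 * step_idx)

agent = GymctsAgent(
    env=env,  # env as in the previous sketch
    number_of_simulations_per_step=50,  # passed to the schedule as num_simulations
    calc_number_of_simulations_per_step=decaying_schedule,
)
actions = agent.solve()
```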
@@ -149,6 +190,8 @@ class GymctsAgent:
                  # we also need to reset the children of the current node
                  # this is done by calling the reset method
                  next_node.reset()
+             else:
+                 next_node.remove_parent()
 
          self.search_root_node = next_node
 
@@ -1,8 +1,7 @@
  from typing import TypeVar, Any, SupportsFloat, Callable
  from abc import ABC, abstractmethod
  import gymnasium as gym
-
- TSoloMCTSNode = TypeVar("TSoloMCTSNode", bound="SoloMCTSNode")
+ import numpy as np
 
 
  class GymctsABC(ABC, gym.Env):
@@ -47,6 +46,17 @@ class GymctsABC(ABC, gym.Env):
          """
          pass
 
+     @abstractmethod
+     def action_masks(self) -> np.ndarray | None:
+         """
+         Returns a numpy array of action masks for the environment. The array should have the same length as the
+         number of actions in the action space. If an action is valid, the corresponding mask value should be 1,
+         otherwise 0. If no action mask is available, it should return None.
+
+         :return: a numpy array of action masks, or None
+         """
+         pass
+
      @abstractmethod
      def rollout(self) -> float:
          """
@@ -0,0 +1,479 @@
+ import copy
+ import sys
+ from typing import Any, Literal
+
+ import random
+ import math
+ import sb3_contrib
+
+ import gymnasium as gym
+ import numpy as np
+
+ from graph_jsp_env.disjunctive_graph_jsp_env import DisjunctiveGraphJspEnv
+ from jsp_instance_utils.instances import ft06, ft06_makespan
+ from sb3_contrib.common.maskable.distributions import MaskableCategoricalDistribution
+ from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
+ from sb3_contrib.common.wrappers import ActionMasker
+
+ from gymcts.gymcts_agent import GymctsAgent
+ from gymcts.gymcts_env_abc import GymctsABC
+ from gymcts.gymcts_node import GymctsNode
+
+ from gymcts.logger import log
+
+
+ class GraphJspNeuralGYMCTSWrapper(GymctsABC, gym.Wrapper):
+
+     def __init__(self, env: DisjunctiveGraphJspEnv):
+         gym.Wrapper.__init__(self, env)
+
+     def load_state(self, state: Any) -> None:
+         self.env.reset()
+         for action in state:
+             self.env.step(action)
+
+     def is_terminal(self) -> bool:
+         return self.env.unwrapped.is_terminal()
+
+     def get_valid_actions(self) -> list[int]:
+         return list(self.env.unwrapped.valid_actions())
+
+     def rollout(self) -> float:
+         terminal = self.is_terminal()
+
+         if terminal:
+             lower_bound = self.env.unwrapped.reward_function_parameters['scaling_divisor']
+             return -self.env.unwrapped.get_makespan() / lower_bound + 2
+
+         reward = 0
+         while not terminal:
+             action = random.choice(self.get_valid_actions())
+             obs, reward, terminal, truncated, _ = self.env.step(action)
+
+         return reward + 2
+
+     def get_state(self) -> Any:
+         return self.env.unwrapped.get_action_history()
+
+     def action_masks(self) -> np.ndarray | None:
+         """Return the action mask for the current state."""
+         return self.env.unwrapped.valid_action_mask()
+
+ class GymctsNeuralNode(GymctsNode):
+     PUCT_v3_mu = 0.95
+
+     MuZero_c1 = 1.25
+     MuZero_c2 = 19652.0
+
+     """
+     PUCT (Predictor + UCT) exploration terms:
+
+     PUCT_v0:
+         c * P(s, a) * √( N(s) ) / (1 + N(s,a))
+
+     PUCT_v1:
+         c * P(s, a) * √( 2 * ln(N(s)) / N(s,a) )
+
+     PUCT_v2:
+         c * P(s, a) * √( N(s) ) / N(s,a)
+
+     PUCT_v3:
+         c * P(s, a)^μ * √( N(s) / (1 + N(s,a)) )
+
+     PUCT_v4:
+         c * ( P(s, a) / (1 + N(s,a)) )
+
+     PUCT_v5:
+         c * P(s, a) * ( √(N(s)) + 1 ) / (N(s,a) + 1)
+
+     PUCT_v6:
+         c * P(s, a) * N(s) / (1 + N(s,a))
+
+     PUCT_v7:
+         c * P(s, a) * ( √(N(s)) + ε ) / (N(s,a) + 1)
+
+     PUCT_v8:
+         c * P(s, a) * √( (ln(N(s)) + 1) / (1 + N(s,a)) )
+
+     PUCT_v9:
+         c * P(s, a) * √( N(s) / (1 + N(s,a)) )
+
+     PUCT_v10:
+         c * P(s, a) * √( ln(N(s)) / (1 + N(s,a)) )
+
+
+     MuZero exploration terms (v0 and v1 currently share the same formula):
+
+     MuZero_v0:
+         P(s, a) * √( N(s) / (1 + N(s,a)) ) * [ c₁ + ln( (N(s) + c₂ + 1) / c₂ ) ]
+
+     MuZero_v1:
+         P(s, a) * √( N(s) / (1 + N(s,a)) ) * [ c₁ + ln( (N(s) + c₂ + 1) / c₂ ) ]
+
+
+     Where:
+         - N(s): number of times state s has been visited
+         - N(s,a): number of times action a was taken from state s
+         - P(s,a): prior probability of selecting action a from state s
+         - c, c₁, c₂: exploration constants
+         - μ: exponent applied to P(s,a) in some variants
+         - ε: small constant used in PUCT_v7
+     """
+     score_variate: Literal[
+         "PUCT_v0", "PUCT_v1", "PUCT_v2", "PUCT_v3", "PUCT_v4", "PUCT_v5",
+         "PUCT_v6", "PUCT_v7", "PUCT_v8", "PUCT_v9", "PUCT_v10",
+         "MuZero_v0", "MuZero_v1",
+     ] = "PUCT_v0"
+
+     def __init__(
+             self,
+             action: int,
+             parent: 'GymctsNeuralNode',
+             env_reference: GymctsABC,
+             prior_selection_score: float,
+             observation: np.ndarray | None = None,
+     ):
+         super().__init__(action, parent, env_reference)
+
+         self._obs = observation
+         self._selection_score_prior = prior_selection_score
+
+     def tree_policy_score(self) -> float:
+         # prior-weighted variant of GymctsNode.tree_policy_score
+         c = GymctsNode.ubc_c
+         p_sa = self._selection_score_prior
+         n_s = self.parent.visit_count
+         n_sa = self.visit_count
+
+         # variants that divide by N(s,a) would raise a ZeroDivisionError for unvisited children
+         if n_sa == 0 and GymctsNeuralNode.score_variate in ("PUCT_v1", "PUCT_v2"):
+             return float("inf")
+
+         if GymctsNeuralNode.score_variate == "PUCT_v0":
+             return self.mean_value + c * p_sa * math.sqrt(n_s) / (1 + n_sa)
+         elif GymctsNeuralNode.score_variate == "PUCT_v1":
+             return self.mean_value + c * p_sa * math.sqrt(2 * math.log(n_s) / n_sa)
+         elif GymctsNeuralNode.score_variate == "PUCT_v2":
+             return self.mean_value + c * p_sa * math.sqrt(n_s) / n_sa
+         elif GymctsNeuralNode.score_variate == "PUCT_v3":
+             return self.mean_value + c * (p_sa ** GymctsNeuralNode.PUCT_v3_mu) * math.sqrt(n_s / (1 + n_sa))
+         elif GymctsNeuralNode.score_variate == "PUCT_v4":
+             return self.mean_value + c * (p_sa / (1 + n_sa))
+         elif GymctsNeuralNode.score_variate == "PUCT_v5":
+             return self.mean_value + c * p_sa * (math.sqrt(n_s) + 1) / (n_sa + 1)
+         elif GymctsNeuralNode.score_variate == "PUCT_v6":
+             return self.mean_value + c * p_sa * n_s / (1 + n_sa)
+         elif GymctsNeuralNode.score_variate == "PUCT_v7":
+             epsilon = 1e-8
+             return self.mean_value + c * p_sa * (math.sqrt(n_s) + epsilon) / (n_sa + 1)
+         elif GymctsNeuralNode.score_variate == "PUCT_v8":
+             return self.mean_value + c * p_sa * math.sqrt((math.log(n_s) + 1) / (1 + n_sa))
+         elif GymctsNeuralNode.score_variate == "PUCT_v9":
+             return self.mean_value + c * p_sa * math.sqrt(n_s / (1 + n_sa))
+         elif GymctsNeuralNode.score_variate == "PUCT_v10":
+             return self.mean_value + c * p_sa * math.sqrt(math.log(n_s) / (1 + n_sa))
+         elif GymctsNeuralNode.score_variate in ("MuZero_v0", "MuZero_v1"):
+             c1 = GymctsNeuralNode.MuZero_c1
+             c2 = GymctsNeuralNode.MuZero_c2
+             return self.mean_value + c * p_sa * math.sqrt(n_s) / (1 + n_sa) * (c1 + math.log((n_s + c2 + 1) / c2))
+
+         # fallback: vanilla UCB1 exploration term weighted by the prior
+         exploration_term = p_sa * c * math.sqrt(math.log(n_s) / n_sa) if n_sa > 0 else float("inf")
+         return self.mean_value + exploration_term
+
+     def get_best_action(self) -> int:
+         """
+         Returns the best action of the node, i.e. the action whose child has the highest best-seen value.
+
+         :return: the best action of the node.
+         """
+         return max(self.children.values(), key=lambda child: child.max_value).action
+
+     def __str__(self, colored=False, action_space_n=None) -> str:
+         """
+         Returns a string representation of the node. The string representation is used for visualisation purposes,
+         for example in the mcts tree visualisation functionality.
+
+         :param colored: true if the string representation should be colored, false otherwise. (true is used by the mcts tree visualisation)
+         :param action_space_n: the number of actions in the action space. This is used for coloring the action in the string representation.
+         :return: a potentially colored string representation of the node.
+         """
+         if not colored:
+             if not self.is_root():
+                 return f"(a={self.action}, N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f}, ubc={self.tree_policy_score():.2f})"
+             else:
+                 return f"(N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f}) [root]"
+
+         import gymcts.colorful_console_utils as ccu
+
+         if self.is_root():
+             return f"({ccu.CYELLOW}N{ccu.CEND}={self.visit_count}, {ccu.CYELLOW}Q_v{ccu.CEND}={self.mean_value:.2f}, {ccu.CYELLOW}best{ccu.CEND}={self.max_value:.2f})"
+
+         if action_space_n is None:
+             raise ValueError("action_space_n must be provided if colored is True")
+
+         p = ccu.CYELLOW
+         e = ccu.CEND
+         v = ccu.CCYAN
+
+         def colorful_value(value: float | int | None) -> str:
+             if value is None:
+                 return f"{ccu.CGREY}None{e}"
+             color = ccu.CCYAN
+             if value == 0:
+                 color = ccu.CRED
+             if value == float("inf"):
+                 color = ccu.CGREY
+             if value == -float("inf"):
+                 color = ccu.CGREY
+
+             if isinstance(value, float):
+                 return f"{color}{value:.2f}{e}"
+
+             if isinstance(value, int):
+                 return f"{color}{value}{e}"
+
+         root_node = self.get_root()
+         mean_val = f"{self.mean_value:.2f}"
+
+         return ((f"("
+                  f"{p}a{e}={ccu.wrap_evenly_spaced_color(s=self.action, n_of_item=self.action, n_classes=action_space_n)}, "
+                  f"{p}N{e}={colorful_value(self.visit_count)}, "
+                  f"{p}Q_v{e}={ccu.wrap_with_color_scale(s=mean_val, value=self.mean_value, min_val=root_node.min_value, max_val=root_node.max_value)}, "
+                  f"{p}best{e}={colorful_value(self.max_value)}") +
+                 (f", {p}{GymctsNeuralNode.score_variate}{e}={colorful_value(self.tree_policy_score())})" if not self.is_root() else ")"))
+
+ class GymctsNeuralAgent(GymctsAgent):
+
+     def __init__(self,
+                  env: GymctsABC,
+                  *args,
+                  model_kwargs=None,
+                  score_variate: Literal[
+                      "PUCT_v0", "PUCT_v1", "PUCT_v2", "PUCT_v3", "PUCT_v4", "PUCT_v5",
+                      "PUCT_v6", "PUCT_v7", "PUCT_v8", "PUCT_v9", "PUCT_v10",
+                      "MuZero_v0", "MuZero_v1",
+                  ] = "PUCT_v0",
+                  **kwargs
+                  ):
+
+         # init super class
+         super().__init__(
+             env=env,
+             *args,
+             **kwargs
+         )
+         if score_variate not in [
+             "PUCT_v0", "PUCT_v1", "PUCT_v2",
+             "PUCT_v3", "PUCT_v4", "PUCT_v5",
+             "PUCT_v6", "PUCT_v7", "PUCT_v8",
+             "PUCT_v9", "PUCT_v10",
+             "MuZero_v0", "MuZero_v1"
+         ]:
+             raise ValueError(f"Invalid score_variate: {score_variate}. Must be one of: "
+                              f"PUCT_v0, PUCT_v1, PUCT_v2, PUCT_v3, PUCT_v4, PUCT_v5, "
+                              f"PUCT_v6, PUCT_v7, PUCT_v8, PUCT_v9, PUCT_v10, MuZero_v0, MuZero_v1")
+         GymctsNeuralNode.score_variate = score_variate
+
+         if model_kwargs is None:
+             model_kwargs = {}
+         obs, info = env.reset()
+
+         self.search_root_node = GymctsNeuralNode(
+             action=None,
+             parent=None,
+             env_reference=env,
+             observation=obs,
+             prior_selection_score=1.0,
+         )
+
+         def mask_fn(env: gym.Env) -> np.ndarray:
+             mask = env.action_masks()
+             if mask is None:
+                 mask = np.ones(env.action_space.n, dtype=np.float32)
+             return mask
+
+         env = ActionMasker(env, action_mask_fn=mask_fn)
+
+         model_kwargs = {
+             "policy": MaskableActorCriticPolicy,
+             "env": env,
+             "verbose": 1,
+         } | model_kwargs
+
+         self._model = sb3_contrib.MaskablePPO(**model_kwargs)
+
+     def learn(self, total_timesteps: int, **kwargs) -> None:
+         """Learn from the environment using the MaskablePPO model."""
+         self._model.learn(total_timesteps=total_timesteps, **kwargs)
+
+     def expand_node(self, node: GymctsNeuralNode) -> None:
+         log.debug(f"expanding node: {node}")
+         # EXPANSION STRATEGY: expand all children with a non-zero prior probability
+
+         child_dict = {}
+
+         # reconstruct the state of the leaf node
+         self._load_state(node)
+
+         obs_tensor, vectorized_env = self._model.policy.obs_to_tensor(np.array([node._obs]))
+         action_masks = np.array([self.env.action_masks()])
+         distribution = self._model.policy.get_distribution(obs=obs_tensor, action_masks=action_masks)
+         unwrapped_distribution = distribution.distribution.probs[0]
+
+         for action, prob in enumerate(unwrapped_distribution):
+             self._load_state(node)
+
+             log.debug(f"Probability for action {action}: {prob}")
+
+             if prob == 0.0:
+                 continue
+
+             assert action in node.valid_actions, f"Action {action} is not in valid actions: {node.valid_actions}"
+
+             obs, reward, terminal, truncated, _ = self.env.step(action)
+             child_dict[action] = GymctsNeuralNode(
+                 action=action,
+                 parent=node,
+                 observation=copy.deepcopy(obs),
+                 env_reference=self.env,
+                 prior_selection_score=float(prob)
+             )
+
+         node.children = child_dict
+
412
+ log.setLevel(20)
413
+
414
+ env_kwargs = {
415
+ "jps_instance": ft06,
416
+ "default_visualisations": ["gantt_console", "graph_console"],
417
+ "reward_function_parameters": {
418
+ "scaling_divisor": ft06_makespan
419
+ },
420
+ "reward_function": "nasuta",
421
+ }
422
+
423
+
424
+
425
+ env = DisjunctiveGraphJspEnv(**env_kwargs)
426
+ env.reset()
427
+
428
+ env = GraphJspNeuralGYMCTSWrapper(env)
429
+
430
+ import torch
431
+ model_kwargs = {
432
+ "gamma": 0.99013,
433
+ "gae_lambda": 0.9,
434
+ "normalize_advantage": True,
435
+ "n_epochs": 28,
436
+ "n_steps": 432,
437
+ "max_grad_norm": 0.5,
438
+ "learning_rate": 6e-4,
439
+ "policy_kwargs": {
440
+ "net_arch": {
441
+ "pi": [90, 90],
442
+ "vf": [90, 90],
443
+ },
444
+ "ortho_init": True,
445
+ "activation_fn": torch.nn.ELU,
446
+ "optimizer_kwargs": {
447
+ "eps": 1e-7
448
+ }
449
+ }
450
+ }
451
+
452
+ agent = GymctsNeuralAgent(
453
+ env=env,
454
+ render_tree_after_step=True,
455
+ render_tree_max_depth=3,
456
+ exclude_unvisited_nodes_from_render=False,
457
+ number_of_simulations_per_step=15,
458
+ # clear_mcts_tree_after_step = False,
459
+ model_kwargs=model_kwargs
460
+ )
461
+
462
+ agent.learn(total_timesteps=10_000)
463
+
464
+
465
+ agent.solve()
466
+
467
+ actions = agent.solve(render_tree_after_step=True)
468
+ for a in actions:
469
+ obs, rew, term, trun, info = env.step(a)
470
+
471
+ env.render()
472
+ makespan = env.unwrapped.get_makespan()
473
+ print(f"makespan: {makespan}")
474
+
475
+
476
+
477
+
478
+
479
+
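For illustration (not part of the released diff): the PUCT_v0 tree-policy score above can be checked by hand; the counts and values below are made-up illustration numbers.

```python
import math

c, p_sa = 0.707, 0.4   # exploration constant and prior P(s,a); illustrative values
n_s, n_sa = 100, 9     # parent visits N(s) and child visits N(s,a)
mean_value = 1.3       # running mean value of the child node

# PUCT_v0: Q + c * P(s,a) * sqrt(N(s)) / (1 + N(s,a))
score = mean_value + c * p_sa * math.sqrt(n_s) / (1 + n_sa)
print(f"{score:.3f}")  # 1.3 + 0.707 * 0.4 * 10 / 10 ≈ 1.583
```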
@@ -2,7 +2,7 @@ import uuid
  import random
  import math
 
- from typing import TypeVar, Any, SupportsFloat, Callable, Generator
+ from typing import TypeVar, Any, SupportsFloat, Callable, Generator, Literal
 
  from gymcts.gymcts_env_abc import GymctsABC
 
@@ -16,6 +16,25 @@ class GymctsNode:
      best_action_weight: float = 0.05  # weight for the best action
      ubc_c = 0.707  # exploration coefficient
 
+     """
+     UCT (Upper Confidence Bound applied to Trees) exploration terms:
+
+     UCT_v0:
+         c * √( 2 * ln(N(s)) / N(s,a) )
+
+     UCT_v1:
+         c * √( ln(N(s)) / (1 + N(s,a)) )
+
+     UCT_v2:
+         c * ( √(N(s)) / (1 + N(s,a)) )
+
+     Where:
+         N(s) = number of times state s has been visited
+         N(s,a) = number of times action a was taken from state s
+         c = exploration constant
+     """
+     score_variate: Literal["UCT_v0", "UCT_v1", "UCT_v2"] = "UCT_v0"
+
 
 
  # attributes
@@ -42,7 +61,7 @@ class GymctsNode:
          if not colored:
 
              if not self.is_root():
-                 return f"(a={self.action}, N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f}, ubc={self.ucb_score():.2f})"
+                 return f"(a={self.action}, N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f}, ubc={self.tree_policy_score():.2f})"
              else:
                  return f"(N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f}) [root]"
 
@@ -83,7 +102,7 @@ class GymctsNode:
                   f"{p}N{e}={colorful_value(self.visit_count)}, "
                   f"{p}Q_v{e}={ccu.wrap_with_color_scale(s=mean_val, value=self.mean_value, min_val=root_node.min_value, max_val=root_node.max_value)}, "
                   f"{p}best{e}={colorful_value(self.max_value)}") +
-                 (f", {p}ubc{e}={colorful_value(self.ucb_score())})" if not self.is_root() else ")"))
+                 (f", {p}ubc{e}={colorful_value(self.tree_policy_score())})" if not self.is_root() else ")"))
 
      def traverse_nodes(self) -> Generator[TGymctsNode, None, None]:
          """
@@ -192,6 +211,12 @@ class GymctsNode:
          if self.parent:
              self.parent.reset()
 
+     def remove_parent(self) -> None:
+         # drop the back-reference to the parent so that the rest of the old
+         # search tree can be garbage-collected
+         self.parent = None
+
      def is_root(self) -> bool:
          """
          Returns true if the node is a root node. A root node is a node that has no parent.
@@ -252,9 +277,39 @@ class GymctsNode:
          """
          return self.max_value
 
-     def ucb_score(self):
+     def tree_policy_score(self):
          """
          The score for an action that would transition between the parent and child.
+         For vanilla MCTS, this is a UCB1-style score. Depending on GymctsNode.score_variate, one of the
+         following UCT (Upper Confidence Bound applied to Trees) exploration terms is added to the mean
+         value of the node:
+
+         UCT_v0:
+             c * √( 2 * ln(N(s)) / N(s,a) )
+
+         UCT_v1:
+             c * √( ln(N(s)) / (1 + N(s,a)) )
+
+         UCT_v2:
+             c * ( √(N(s)) / (1 + N(s,a)) )
+
+         Where:
+             N(s) = number of times state s has been visited (the parent's visit_count)
+             N(s,a) = number of times action a was taken from state s (the node's visit_count)
+             c = exploration constant controlling the exploration-exploitation trade-off (GymctsNode.ubc_c)
+
+         For UCT_v0, the score of an unvisited node is set to infinity.
+
          prior_score = child.prior * math.sqrt(parent.visit_count) / (child.visit_count + 1)
 
          if child.visit_count > 0:
@@ -269,8 +324,20 @@ class GymctsNode:
          """
          if self.is_root():
              raise ValueError("ucb_score can only be called on non-root nodes")
-         # c = 0.707 # todo: make it an attribute?
-         c = GymctsNode.ubc_c
-         if self.visit_count == 0:
-             return float("inf")
-         return self.mean_value + c * math.sqrt(math.log(self.parent.visit_count) / (self.visit_count))
+         c = GymctsNode.ubc_c  # default is 0.707
+
+         if GymctsNode.score_variate == "UCT_v0":
+             if self.visit_count == 0:
+                 return float("inf")
+             return self.mean_value + c * math.sqrt(2 * math.log(self.parent.visit_count) / self.visit_count)
+
+         if GymctsNode.score_variate == "UCT_v1":
+             return self.mean_value + c * math.sqrt(math.log(self.parent.visit_count) / (1 + self.visit_count))
+
+         if GymctsNode.score_variate == "UCT_v2":
+             return self.mean_value + c * math.sqrt(self.parent.visit_count) / (1 + self.visit_count)
+
+         raise ValueError(f"unknown score variate: {GymctsNode.score_variate}")
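For illustration (not part of the released diff): the three UCT variants implemented above can be compared numerically; the visit counts are made-up illustration numbers.

```python
import math

c = 0.707              # GymctsNode.ubc_c default
n_s, n_sa = 50, 4      # parent visits N(s), child visits N(s,a); illustrative values
mean_value = 0.8

uct_v0 = mean_value + c * math.sqrt(2 * math.log(n_s) / n_sa)    # ≈ 1.789
uct_v1 = mean_value + c * math.sqrt(math.log(n_s) / (1 + n_sa))  # ≈ 1.425
uct_v2 = mean_value + c * math.sqrt(n_s) / (1 + n_sa)            # ≈ 1.800
print(uct_v0, uct_v1, uct_v2)
```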
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: gymcts
- Version: 1.2.1
+ Version: 1.3.0
  Summary: A minimalistic implementation of the Monte Carlo Tree Search algorithm for planning problems formulated as gymnasium reinforcement learning environments.
  Author: Alexander Nasuta
  Author-email: Alexander Nasuta <alexander.nasuta@wzl-iqs.rwth-aachen.de>
@@ -70,11 +70,18 @@ Requires-Dist: jupyter; extra == "dev"
  Requires-Dist: typing_extensions>=4.12.0; extra == "dev"
  Dynamic: license-file
 
- # Graph Matrix Job Shop Env
+ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.15283390.svg)](https://doi.org/10.5281/zenodo.15283390)
+ [![Python Badge](https://img.shields.io/badge/Python-3776AB?logo=python&logoColor=fff&style=flat)](https://www.python.org/downloads/)
+ [![PyPI version](https://img.shields.io/pypi/v/gymcts)](https://pypi.org/project/gymcts/)
+ [![License](https://img.shields.io/pypi/l/gymcts)](https://github.com/Alexander-Nasuta/gymcts/blob/master/LICENSE)
+ [![Documentation Status](https://readthedocs.org/projects/gymcts/badge/?version=latest)](https://gymcts.readthedocs.io/en/latest/?badge=latest)
+
+ # GYMCTS
 
  A Monte Carlo Tree Search Implementation for Gymnasium-style Environments.
 
  - Github: [GYMCTS on Github](https://github.com/Alexander-Nasuta/gymcts)
+ - GitLab: [GYMCTS on GitLab](https://git-ce.rwth-aachen.de/alexander.nasuta/gymcts)
  - Pypi: [GYMCTS on PyPi](https://pypi.org/project/gymcts/)
  - Documentation: [GYMCTS Docs](https://gymcts.readthedocs.io/en/latest/)
 
@@ -579,9 +586,6 @@ This project uses `pytest` for testing. To run the tests, run the following comm
  ```shell
  pytest
  ```
- Here is a screenshot of what the output might look like:
-
- ![](https://github.com/Alexander-Nasuta/GraphMatrixJobShopEnv/raw/master/resources/pytest-screenshot.png)
 
  For testing with `tox` run the following command:
 
@@ -11,6 +11,7 @@ src/gymcts/gymcts_agent.py
  src/gymcts/gymcts_deepcopy_wrapper.py
  src/gymcts/gymcts_distributed_agent.py
  src/gymcts/gymcts_env_abc.py
+ src/gymcts/gymcts_neural_agent.py
  src/gymcts/gymcts_node.py
  src/gymcts/gymcts_tree_plotter.py
  src/gymcts/logger.py