gymcts 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gymcts/colorful_console_utils.py +22 -0
- gymcts/gymcts_action_history_wrapper.py +72 -2
- gymcts/gymcts_agent.py +54 -7
- gymcts/gymcts_deepcopy_wrapper.py +59 -2
- gymcts/gymcts_distributed_agent.py +30 -12
- gymcts/gymcts_env_abc.py +45 -2
- gymcts/gymcts_neural_agent.py +479 -0
- gymcts/gymcts_node.py +161 -17
- gymcts/gymcts_tree_plotter.py +22 -1
- gymcts/logger.py +1 -4
- {gymcts-1.2.0.dist-info → gymcts-1.3.0.dist-info}/METADATA +39 -39
- gymcts-1.3.0.dist-info/RECORD +16 -0
- {gymcts-1.2.0.dist-info → gymcts-1.3.0.dist-info}/WHEEL +1 -1
- gymcts-1.2.0.dist-info/RECORD +0 -15
- {gymcts-1.2.0.dist-info → gymcts-1.3.0.dist-info}/licenses/LICENSE +0 -0
- {gymcts-1.2.0.dist-info → gymcts-1.3.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,479 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import sys
|
|
3
|
+
from typing import Any, Literal
|
|
4
|
+
|
|
5
|
+
import random
|
|
6
|
+
import math
|
|
7
|
+
import sb3_contrib
|
|
8
|
+
|
|
9
|
+
import gymnasium as gym
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
from graph_jsp_env.disjunctive_graph_jsp_env import DisjunctiveGraphJspEnv
|
|
13
|
+
from jsp_instance_utils.instances import ft06, ft06_makespan
|
|
14
|
+
from sb3_contrib.common.maskable.distributions import MaskableCategoricalDistribution
|
|
15
|
+
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
|
|
16
|
+
from sb3_contrib.common.wrappers import ActionMasker
|
|
17
|
+
|
|
18
|
+
from gymcts.gymcts_agent import GymctsAgent
|
|
19
|
+
from gymcts.gymcts_env_abc import GymctsABC
|
|
20
|
+
from gymcts.gymcts_node import GymctsNode
|
|
21
|
+
|
|
22
|
+
from gymcts.logger import log
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class GraphJspNeuralGYMCTSWrapper(GymctsABC, gym.Wrapper):
    """GymCTS adapter around a ``DisjunctiveGraphJspEnv`` for the neural MCTS agent.

    The search state is the environment's action history: ``get_state`` returns it and
    ``load_state`` rebuilds the environment by resetting it and replaying the actions.
    ``action_masks`` exposes the environment's valid-action mask (used by sb3-contrib's
    ``ActionMasker`` / ``MaskablePPO``).
    """

    def __init__(self, env: DisjunctiveGraphJspEnv):
        gym.Wrapper.__init__(self, env)

    def load_state(self, state: Any) -> None:
        """Restore ``state`` (a sequence of actions) by resetting and replaying it."""
        self.env.reset()
        for action in state:
            self.env.step(action)

    def is_terminal(self) -> bool:
        """Return whether the wrapped environment is in a terminal state."""
        return self.env.unwrapped.is_terminal()

    def get_valid_actions(self) -> list[int]:
        """Return the actions that are valid in the current state."""
        return list(self.env.unwrapped.valid_actions())

    def rollout(self) -> float:
        """Play uniformly random valid actions from the current state until the episode ends.

        :return: for an already-terminal state ``2 - makespan / scaling_divisor``;
            otherwise the reward of the last rollout step shifted by ``+2``.
        """
        # Always go through ``self``: the previous implementation used a module-level
        # ``env`` global that only exists when this file is executed as a script.
        terminal = self.is_terminal()

        if terminal:
            lower_bound = self.env.unwrapped.reward_function_parameters['scaling_divisor']
            return - self.env.unwrapped.get_makespan() / lower_bound + 2

        reward = 0
        truncated = False
        # Stop on truncation as well (Gymnasium step API), not only on termination.
        while not terminal and not truncated:
            action = random.choice(self.get_valid_actions())
            _obs, reward, terminal, truncated, _ = self.env.step(action)

        return reward + 2

    def get_state(self) -> Any:
        """Return the action history that identifies the current state."""
        return self.env.unwrapped.get_action_history()

    def action_masks(self) -> np.ndarray | None:
        """Return the action mask for the current state."""
        return self.env.unwrapped.valid_action_mask()
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class GymctsNeuralNode(GymctsNode):
    """
    MCTS node whose tree-policy score is weighted by a policy prior P(s, a).

    The formula used by ``tree_policy_score`` is chosen by the class attribute
    ``score_variate`` (shared by all nodes); ``c`` is ``GymctsNode.ubc_c``.
    The score is always ``mean_value`` plus one of the exploration terms below.

    PUCT (Predictor + UCT) exploration terms:

        PUCT_v0:
            c * P(s, a) * √N(s) / (1 + N(s,a))

        PUCT_v1:
            c * P(s, a) * √( 2 * ln(N(s)) / N(s,a) )        (+inf when N(s,a) = 0)

        PUCT_v2:
            c * P(s, a) * √( N(s) ) / N(s,a)                (+inf when N(s,a) = 0)

        PUCT_v3:
            c * P(s, a)^μ * √( N(s) / (1 + N(s,a)) )

        PUCT_v4:
            c * ( P(s, a) / (1 + N(s,a)) )

        PUCT_v5:
            c * P(s, a) * ( √(N(s)) + 1 ) / (N(s,a) + 1)

        PUCT_v6:
            c * P(s, a) * N(s) / (1 + N(s,a))

        PUCT_v7:
            c * P(s, a) * ( √(N(s)) + ε ) / (N(s,a) + 1)

        PUCT_v8:
            c * P(s, a) * √( (ln(N(s)) + 1) / (1 + N(s,a)) )

        PUCT_v9:
            c * P(s, a) * √( N(s) / (1 + N(s,a)) )

        PUCT_v10:
            c * P(s, a) * √( ln(N(s)) / (1 + N(s,a)) )

    MuZero exploration terms (MuZero_v0 and MuZero_v1 are currently identical):

        MuZero_v0 / MuZero_v1:
            c * P(s, a) * √N(s) / (1 + N(s,a)) * [ c₁ + ln( (N(s) + c₂ + 1) / c₂ ) ]

    The legacy misspellings "PUTC_v2" .. "PUTC_v10" are treated as "PUCT_v2" .. "PUCT_v10".
    ln(N(s)) is taken as 0 when N(s) = 0 instead of raising.

    Where:
        - N(s): visit count of the parent node
        - N(s,a): visit count of this node
        - P(s,a): prior probability of this node's action (``prior_selection_score``)
        - c, c₁, c₂: exploration constants (``ubc_c``, ``MuZero_c1``, ``MuZero_c2``)
        - μ: exponent applied to P(s,a) in PUCT_v3 (``PUCT_v3_mu``)
        - ε: small constant (1e-8) added to √N(s) in PUCT_v7
    """

    PUCT_v3_mu = 0.95

    MuZero_c1 = 1.25
    MuZero_c2 = 19652.0

    # Selection formula used by every node (set by GymctsNeuralAgent).
    # "PUTC_*" entries are legacy typo aliases of the matching "PUCT_*" variants.
    score_variate: Literal[
        "PUCT_v0",
        "PUCT_v1",
        "PUCT_v2",
        "PUCT_v3",
        "PUCT_v4",
        "PUCT_v5",
        "PUCT_v6",
        "PUCT_v7",
        "PUCT_v8",
        "PUCT_v9",
        "PUCT_v10",
        "PUTC_v2",
        "PUTC_v3",
        "PUTC_v4",
        "PUTC_v5",
        "PUTC_v6",
        "PUTC_v7",
        "PUTC_v8",
        "PUTC_v9",
        "PUTC_v10",
        "MuZero_v0",
        "MuZero_v1",
    ] = "PUCT_v0"

    def __init__(
            self,
            action: int,
            parent: 'GymctsNeuralNode',
            env_reference: GymctsABC,
            prior_selection_score: float,
            observation: np.ndarray | None = None,
    ):
        """
        :param action: action that leads from ``parent`` to this node (``None`` for the root).
        :param parent: parent node (``None`` for the root).
        :param env_reference: environment passed through to ``GymctsNode``.
        :param prior_selection_score: policy prior P(s, a) used by ``tree_policy_score``.
        :param observation: observation of this node's state, fed to the policy on expansion.
        """
        super().__init__(action, parent, env_reference)

        self._obs = observation
        self._selection_score_prior = prior_selection_score

    def tree_policy_score(self) -> float:
        """Return ``mean_value`` plus the exploration term selected by ``score_variate``.

        Reads ``self.parent.visit_count``, so it must not be called on the root node.
        """
        c = GymctsNode.ubc_c
        p_sa = self._selection_score_prior
        n_s = self.parent.visit_count
        n_sa = self.visit_count

        # ln(N(s)) with ln(0) := 0 so the log-based variants don't raise ValueError.
        log_n_s = math.log(n_s) if n_s > 0 else 0.0

        variate = GymctsNeuralNode.score_variate
        # "PUTC_vX" is a historical misspelling; without this mapping those variants
        # matched no branch below and silently used the fallback formula.
        if isinstance(variate, str) and variate.startswith("PUTC_"):
            variate = "PUCT_" + variate[len("PUTC_"):]

        if variate == "PUCT_v0":
            exploration = c * p_sa * math.sqrt(n_s) / (1 + n_sa)
        elif variate == "PUCT_v1":
            # divides by N(s,a): unvisited children are explored first
            if n_sa == 0:
                exploration = float("inf")
            else:
                exploration = c * p_sa * math.sqrt(2 * log_n_s / n_sa)
        elif variate == "PUCT_v2":
            # divides by N(s,a): unvisited children are explored first
            if n_sa == 0:
                exploration = float("inf")
            else:
                exploration = c * p_sa * math.sqrt(n_s) / n_sa
        elif variate == "PUCT_v3":
            exploration = c * (p_sa ** GymctsNeuralNode.PUCT_v3_mu) * math.sqrt(n_s / (1 + n_sa))
        elif variate == "PUCT_v4":
            exploration = c * (p_sa / (1 + n_sa))
        elif variate == "PUCT_v5":
            exploration = c * p_sa * (math.sqrt(n_s) + 1) / (n_sa + 1)
        elif variate == "PUCT_v6":
            exploration = c * p_sa * n_s / (1 + n_sa)
        elif variate == "PUCT_v7":
            epsilon = 1e-8
            exploration = c * p_sa * (math.sqrt(n_s) + epsilon) / (n_sa + 1)
        elif variate == "PUCT_v8":
            exploration = c * p_sa * math.sqrt((log_n_s + 1) / (1 + n_sa))
        elif variate == "PUCT_v9":
            exploration = c * p_sa * math.sqrt(n_s / (1 + n_sa))
        elif variate == "PUCT_v10":
            exploration = c * p_sa * math.sqrt(log_n_s / (1 + n_sa))
        elif variate in ("MuZero_v0", "MuZero_v1"):
            c1 = GymctsNeuralNode.MuZero_c1
            c2 = GymctsNeuralNode.MuZero_c2
            exploration = c * p_sa * math.sqrt(n_s) / (1 + n_sa) * (c1 + math.log((n_s + c2 + 1) / c2))
        else:
            # Fallback for unrecognised variates (kept from the previous implementation).
            exploration = p_sa * c * math.sqrt(log_n_s / n_sa) if n_sa > 0 else float("inf")

        return self.mean_value + exploration

    def get_best_action(self) -> int:
        """
        Returns the action of the child with the highest ``max_value``.

        :return: the best action of the node.
        """
        return max(self.children.values(), key=lambda child: child.max_value).action

    def __str__(self, colored=False, action_space_n=None) -> str:
        """
        Returns a string representation of the node. The string representation is used for visualisation purposes.
        It is used for example in the mcts tree visualisation functionality.

        :param colored: true if the string representation should be colored, false otherwise. (true is used by the mcts tree visualisation)
        :param action_space_n: the number of actions in the action space. This is used for coloring the action in the string representation.
        :return: a potentially colored string representation of the node.
        :raises ValueError: if ``colored`` is true, the node is not the root and ``action_space_n`` is None.
        """
        if not colored:

            if not self.is_root():
                return f"(a={self.action}, N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f}, ubc={self.tree_policy_score():.2f})"
            else:
                return f"(N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f}) [root]"

        import gymcts.colorful_console_utils as ccu

        if self.is_root():
            return f"({ccu.CYELLOW}N{ccu.CEND}={self.visit_count}, {ccu.CYELLOW}Q_v{ccu.CEND}={self.mean_value:.2f}, {ccu.CYELLOW}best{ccu.CEND}={self.max_value:.2f})"

        if action_space_n is None:
            raise ValueError("action_space_n must be provided if colored is True")

        p = ccu.CYELLOW
        e = ccu.CEND

        def colorful_value(value: float | int | None) -> str:
            if value is None:
                return f"{ccu.CGREY}None{e}"
            color = ccu.CCYAN
            if value == 0:
                color = ccu.CRED
            if value == float("inf"):
                color = ccu.CGREY
            if value == -float("inf"):
                color = ccu.CGREY

            if isinstance(value, float):
                return f"{color}{value:.2f}{e}"
            # ints and any other numeric type: print as-is instead of returning None
            return f"{color}{value}{e}"

        root_node = self.get_root()
        mean_val = f"{self.mean_value:.2f}"

        # the root was handled above, so the tree-policy score is always available here
        return (f"("
                f"{p}a{e}={ccu.wrap_evenly_spaced_color(s=self.action, n_of_item=self.action, n_classes=action_space_n)}, "
                f"{p}N{e}={colorful_value(self.visit_count)}, "
                f"{p}Q_v{e}={ccu.wrap_with_color_scale(s=mean_val, value=self.mean_value, min_val=root_node.min_value, max_val=root_node.max_value)}, "
                f"{p}best{e}={colorful_value(self.max_value)}"
                f", {p}{GymctsNeuralNode.score_variate}{e}={colorful_value(self.tree_policy_score())})")
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
class GymctsNeuralAgent(GymctsAgent):
    """MCTS agent whose node priors come from a ``sb3_contrib.MaskablePPO`` policy.

    On expansion, the policy's (action-masked) distribution over actions is evaluated at
    the node's observation; every action with non-zero probability becomes a child whose
    prior is that probability.
    """

    # Accepted ``score_variate`` values. The legacy misspellings "PUTC_v2" .. "PUTC_v10"
    # are also accepted and mapped to the corresponding "PUCT_*" name.
    _SCORE_VARIATES: tuple[str, ...] = (
        "PUCT_v0", "PUCT_v1", "PUCT_v2", "PUCT_v3", "PUCT_v4", "PUCT_v5",
        "PUCT_v6", "PUCT_v7", "PUCT_v8", "PUCT_v9", "PUCT_v10",
        "MuZero_v0", "MuZero_v1",
    )

    def __init__(self,
                 env: GymctsABC,
                 *args,
                 model_kwargs=None,
                 score_variate: Literal[
                     "PUCT_v0",
                     "PUCT_v1",
                     "PUCT_v2",
                     "PUCT_v3",
                     "PUCT_v4",
                     "PUCT_v5",
                     "PUCT_v6",
                     "PUCT_v7",
                     "PUCT_v8",
                     "PUCT_v9",
                     "PUCT_v10",
                     "PUTC_v2",
                     "PUTC_v3",
                     "PUTC_v4",
                     "PUTC_v5",
                     "PUTC_v6",
                     "PUTC_v7",
                     "PUTC_v8",
                     "PUTC_v9",
                     "PUTC_v10",
                     "MuZero_v0",
                     "MuZero_v1",
                 ] = "PUCT_v0",
                 **kwargs
                 ):
        """
        :param env: the GymCTS environment to search in; it must provide ``action_masks()``.
        :param model_kwargs: keyword arguments for ``sb3_contrib.MaskablePPO``; they override the
            defaults ``policy=MaskableActorCriticPolicy``, ``env=<action-masked env>``, ``verbose=1``.
        :param score_variate: tree-policy formula used by ``GymctsNeuralNode`` (set class-wide).
        :raises ValueError: if ``score_variate`` is not a supported variant.
        """
        # init super class
        super().__init__(
            env=env,
            *args,
            **kwargs
        )

        requested_variate = score_variate
        # map the legacy "PUTC_*" misspelling onto the real "PUCT_*" name
        if isinstance(score_variate, str) and score_variate.startswith("PUTC_"):
            score_variate = "PUCT_" + score_variate[len("PUTC_"):]
        if score_variate not in self._SCORE_VARIATES:
            raise ValueError(f"Invalid score_variate: {requested_variate}. Must be one of: "
                             f"{', '.join(self._SCORE_VARIATES)} "
                             f"(legacy 'PUTC_*' spellings are accepted as aliases)")
        # class-wide: every GymctsNeuralNode reads this attribute when scoring
        GymctsNeuralNode.score_variate = score_variate

        if model_kwargs is None:
            model_kwargs = {}
        obs, info = env.reset()

        self.search_root_node = GymctsNeuralNode(
            action=None,
            parent=None,
            env_reference=env,
            observation=obs,
            prior_selection_score=1.0,
        )

        def mask_fn(env: gym.Env) -> np.ndarray:
            mask = env.action_masks()
            if mask is None:
                # no mask available: allow every action
                mask = np.ones(env.action_space.n, dtype=np.float32)
            return mask

        env = ActionMasker(env, action_mask_fn=mask_fn)

        # user-supplied model_kwargs take precedence over these defaults
        model_kwargs = {
            "policy": MaskableActorCriticPolicy,
            "env": env,
            "verbose": 1,
        } | model_kwargs

        self._model = sb3_contrib.MaskablePPO(**model_kwargs)

    def learn(self, total_timesteps: int, **kwargs) -> None:
        """Learn from the environment using the MaskablePPO model."""
        self._model.learn(total_timesteps=total_timesteps, **kwargs)

    def expand_node(self, node: GymctsNeuralNode) -> None:
        """Expand ``node`` with one child per action the policy assigns non-zero probability.

        The child's ``prior_selection_score`` is the policy probability of its action.
        Leaves ``self.env`` in the state of the last child that was created.
        """
        log.debug(f"expanding node: {node}")

        # Put the env into the node's state so the action mask matches the node.
        self._load_state(node)

        obs_tensor, _ = self._model.policy.obs_to_tensor(np.array([node._obs]))
        action_masks = np.array([self.env.action_masks()])
        distribution = self._model.policy.get_distribution(obs=obs_tensor, action_masks=action_masks)
        # Convert the probabilities once to plain Python floats (works for CPU and GPU tensors).
        action_probs = distribution.distribution.probs[0].detach().cpu().tolist()

        child_dict = {}
        for action, prob in enumerate(action_probs):
            log.debug(f"Probabily for action {action}: {prob}")

            # masked-out actions have probability 0; skip them before the costly state reload
            if prob == 0.0:
                continue

            assert action in node.valid_actions, f"Action {action} is not in valid actions: {node.valid_actions}"

            # every child is reached from the node's state, so restore it before stepping
            self._load_state(node)
            obs, _reward, _terminal, _truncated, _ = self.env.step(action)
            child_dict[action] = GymctsNeuralNode(
                action=action,
                parent=node,
                observation=copy.deepcopy(obs),
                env_reference=self.env,
                prior_selection_score=float(prob)
            )

        node.children = child_dict
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
if __name__ == '__main__':
    # 20 == logging.INFO
    log.setLevel(20)

    env_kwargs = {
        "jps_instance": ft06,
        "default_visualisations": ["gantt_console", "graph_console"],
        "reward_function_parameters": {
            "scaling_divisor": ft06_makespan
        },
        "reward_function": "nasuta",
    }

    env = DisjunctiveGraphJspEnv(**env_kwargs)
    env.reset()

    env = GraphJspNeuralGYMCTSWrapper(env)

    import torch
    model_kwargs = {
        "gamma": 0.99013,
        "gae_lambda": 0.9,
        "normalize_advantage": True,
        "n_epochs": 28,
        "n_steps": 432,
        "max_grad_norm": 0.5,
        "learning_rate": 6e-4,
        "policy_kwargs": {
            "net_arch": {
                "pi": [90, 90],
                "vf": [90, 90],
            },
            "ortho_init": True,
            "activation_fn": torch.nn.ELU,
            "optimizer_kwargs": {
                "eps": 1e-7
            }
        }
    }

    agent = GymctsNeuralAgent(
        env=env,
        render_tree_after_step=True,
        render_tree_max_depth=3,
        exclude_unvisited_nodes_from_render=False,
        number_of_simulations_per_step=15,
        # clear_mcts_tree_after_step = False,
        model_kwargs=model_kwargs
    )

    agent.learn(total_timesteps=10_000)

    # Run the search once (a previous extra ``agent.solve()`` call repeated the whole search).
    actions = agent.solve(render_tree_after_step=True)

    # Training and search leave the env in some explored state; replay the found
    # action sequence from a fresh episode.
    env.reset()
    for a in actions:
        obs, rew, term, trun, info = env.step(a)

    env.render()
    makespan = env.unwrapped.get_makespan()
    print(f"makespan: {makespan}")
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
|
|
479
|
+
|