gymcts 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,479 @@
+ import copy
+ import sys
+ from typing import Any, Literal
+ 
+ import random
+ import math
+ import sb3_contrib
+ 
+ import gymnasium as gym
+ import numpy as np
+ 
+ from graph_jsp_env.disjunctive_graph_jsp_env import DisjunctiveGraphJspEnv
+ from jsp_instance_utils.instances import ft06, ft06_makespan
+ from sb3_contrib.common.maskable.distributions import MaskableCategoricalDistribution
+ from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
+ from sb3_contrib.common.wrappers import ActionMasker
+ 
+ from gymcts.gymcts_agent import GymctsAgent
+ from gymcts.gymcts_env_abc import GymctsABC
+ from gymcts.gymcts_node import GymctsNode
+ 
+ from gymcts.logger import log
+ 
+ 
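+ # This module wires a MaskablePPO policy (sb3_contrib) into the gymcts MCTS search: the policy's
+ # masked action distribution provides the prior P(s, a) used by the PUCT/MuZero-style selection
+ # scores below, on the disjunctive-graph job-shop scheduling environment (ft06 instance).
+ 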
+ class GraphJspNeuralGYMCTSWrapper(GymctsABC, gym.Wrapper):
+ 
+     def __init__(self, env: DisjunctiveGraphJspEnv):
+         gym.Wrapper.__init__(self, env)
+ 
+     def load_state(self, state: Any) -> None:
+         self.env.reset()
+         for action in state:
+             self.env.step(action)
+ 
+     def is_terminal(self) -> bool:
+         return self.env.unwrapped.is_terminal()
+ 
+     def get_valid_actions(self) -> list[int]:
+         return list(self.env.unwrapped.valid_actions())
+ 
+     def rollout(self) -> float:
+         terminal = self.is_terminal()
+ 
+         if terminal:
+             lower_bound = self.env.unwrapped.reward_function_parameters['scaling_divisor']
+             return -self.env.unwrapped.get_makespan() / lower_bound + 2
+ 
+         reward = 0
+         while not terminal:
+             action = random.choice(self.get_valid_actions())
+             obs, reward, terminal, truncated, _ = self.env.step(action)
+ 
+         return reward + 2
+ 
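+     # Note on the rollout return value above: with the "nasuta" reward function the terminal reward
+     # is roughly -makespan / scaling_divisor (close to -1 for schedules near the known optimum), so
+     # the "+ 2" offset appears intended to shift rollout returns into a positive range around 1 for
+     # the MCTS backups. This is a reading of the code, not documented library behaviour.
+ 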
+     def get_state(self) -> Any:
+         return self.env.unwrapped.get_action_history()
+ 
+     def action_masks(self) -> np.ndarray | None:
+         """Return the action mask for the current state."""
+         return self.env.unwrapped.valid_action_mask()
+ 
+ 
+ class GymctsNeuralNode(GymctsNode):
+     PUCT_v3_mu = 0.95
+ 
+     MuZero_c1 = 1.25
+     MuZero_c2 = 19652.0
+ 
+     """
+     PUCT (Predictor + UCT) exploration terms:
+ 
+     PUCT_v0:
+         c * P(s, a) * √( N(s) ) / (1 + N(s,a))
+ 
+     PUCT_v1:
+         c * P(s, a) * √( 2 * ln(N(s)) / N(s,a) )
+ 
+     PUCT_v2:
+         c * P(s, a) * √( N(s) ) / N(s,a)
+ 
+     PUCT_v3:
+         c * P(s, a)^μ * √( N(s) / (1 + N(s,a)) )
+ 
+     PUCT_v4:
+         c * ( P(s, a) / (1 + N(s,a)) )
+ 
+     PUCT_v5:
+         c * P(s, a) * ( √(N(s)) + 1 ) / (N(s,a) + 1)
+ 
+     PUCT_v6:
+         c * P(s, a) * N(s) / (1 + N(s,a))
+ 
+     PUCT_v7:
+         c * P(s, a) * ( √(N(s)) + ε ) / (N(s,a) + 1)
+ 
+     PUCT_v8:
+         c * P(s, a) * √( (ln(N(s)) + 1) / (1 + N(s,a)) )
+ 
+     PUCT_v9:
+         c * P(s, a) * √( N(s) / (1 + N(s,a)) )
+ 
+     PUCT_v10:
+         c * P(s, a) * √( ln(N(s)) / (1 + N(s,a)) )
+ 
+ 
+     MuZero exploration terms:
+ 
+     MuZero_v0:
+         c * P(s, a) * √( N(s) ) / (1 + N(s,a)) * [ c₁ + ln( (N(s) + c₂ + 1) / c₂ ) ]
+ 
+     MuZero_v1:
+         c * P(s, a) * √( N(s) ) / (1 + N(s,a)) * [ c₁ + ln( (N(s) + c₂ + 1) / c₂ ) ]
+ 
+ 
+     Where:
+     - N(s): number of times state s has been visited
+     - N(s,a): number of times action a was taken from state s
+     - P(s,a): prior probability of selecting action a from state s
+     - c, c₁, c₂: exploration constants
+     - μ: exponent applied to P(s,a) in some variants
+     - ε: small constant to avoid division by zero (in PUCT_v7)
+     """
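+     # Worked example (illustrative numbers, not taken from the library): for PUCT_v0 with
+     # c = 2.0, P(s,a) = 0.2, N(s) = 100 and N(s,a) = 10, the exploration term is
+     # 2.0 * 0.2 * sqrt(100) / (1 + 10) = 0.4 * 10 / 11 ≈ 0.36, which tree_policy_score() below
+     # adds to the child's mean value. The actual c comes from GymctsNode.ubc_c.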
+     score_variate: Literal[
+         "PUCT_v0",
+         "PUCT_v1",
+         "PUCT_v2",
+         "PUCT_v3",
+         "PUCT_v4",
+         "PUCT_v5",
+         "PUCT_v6",
+         "PUCT_v7",
+         "PUCT_v8",
+         "PUCT_v9",
+         "PUCT_v10",
+         "MuZero_v0",
+         "MuZero_v1",
+     ] = "PUCT_v0"
+ 
+     def __init__(
+             self,
+             action: int,
+             parent: 'GymctsNeuralNode',
+             env_reference: GymctsABC,
+             prior_selection_score: float,
+             observation: np.ndarray | None = None,
+     ):
+         super().__init__(action, parent, env_reference)
+ 
+         self._obs = observation
+         self._selection_score_prior = prior_selection_score
+ 
+     def tree_policy_score(self) -> float:
+         # exploration constant shared with the vanilla GymctsNode UCB score
+         c = GymctsNode.ubc_c
+         # the way AlphaZero does it:
+         # exploration_term = self._selection_score_prior * c * math.sqrt(math.log(self.parent.visit_count)) / (1 + self.visit_count)
+         p_sa = self._selection_score_prior
+         n_s = self.parent.visit_count
+         n_sa = self.visit_count
+         if GymctsNeuralNode.score_variate == "PUCT_v0":
+             return self.mean_value + c * p_sa * math.sqrt(n_s) / (1 + n_sa)
+         elif GymctsNeuralNode.score_variate == "PUCT_v1":
+             return self.mean_value + c * p_sa * math.sqrt(2 * math.log(n_s) / n_sa)
+         elif GymctsNeuralNode.score_variate == "PUCT_v2":
+             return self.mean_value + c * p_sa * math.sqrt(n_s) / n_sa
+         elif GymctsNeuralNode.score_variate == "PUCT_v3":
+             return self.mean_value + c * (p_sa ** GymctsNeuralNode.PUCT_v3_mu) * math.sqrt(n_s / (1 + n_sa))
+         elif GymctsNeuralNode.score_variate == "PUCT_v4":
+             return self.mean_value + c * (p_sa / (1 + n_sa))
+         elif GymctsNeuralNode.score_variate == "PUCT_v5":
+             return self.mean_value + c * p_sa * (math.sqrt(n_s) + 1) / (n_sa + 1)
+         elif GymctsNeuralNode.score_variate == "PUCT_v6":
+             return self.mean_value + c * p_sa * n_s / (1 + n_sa)
+         elif GymctsNeuralNode.score_variate == "PUCT_v7":
+             epsilon = 1e-8
+             return self.mean_value + c * p_sa * (math.sqrt(n_s) + epsilon) / (n_sa + 1)
+         elif GymctsNeuralNode.score_variate == "PUCT_v8":
+             return self.mean_value + c * p_sa * math.sqrt((math.log(n_s) + 1) / (1 + n_sa))
+         elif GymctsNeuralNode.score_variate == "PUCT_v9":
+             return self.mean_value + c * p_sa * math.sqrt(n_s / (1 + n_sa))
+         elif GymctsNeuralNode.score_variate == "PUCT_v10":
+             return self.mean_value + c * p_sa * math.sqrt(math.log(n_s) / (1 + n_sa))
+         elif GymctsNeuralNode.score_variate == "MuZero_v0":
+             c1 = GymctsNeuralNode.MuZero_c1
+             c2 = GymctsNeuralNode.MuZero_c2
+             return self.mean_value + c * p_sa * math.sqrt(n_s) / (1 + n_sa) * (c1 + math.log((n_s + c2 + 1) / c2))
+         elif GymctsNeuralNode.score_variate == "MuZero_v1":
+             c1 = GymctsNeuralNode.MuZero_c1
+             c2 = GymctsNeuralNode.MuZero_c2
+             return self.mean_value + c * p_sa * math.sqrt(n_s) / (1 + n_sa) * (c1 + math.log((n_s + c2 + 1) / c2))
+ 
+         # fallback: the vanilla gymcts UCT term (used if score_variate matches none of the branches above)
+         exploration_term = self._selection_score_prior * c * math.sqrt(math.log(self.parent.visit_count) / self.visit_count) if self.visit_count > 0 else float("inf")
+         return self.mean_value + exploration_term
+ 
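+     # Caveat: PUCT_v1 and PUCT_v2 divide by N(s,a), and the ln(N(s))-based variants require N(s) >= 1,
+     # so those variants assume the child being scored has already been visited at least once;
+     # PUCT_v0 and the MuZero terms handle N(s,a) = 0 via the (1 + N(s,a)) denominator.
+ 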
+     def get_best_action(self) -> int:
+         """
+         Returns the best action of the node, i.e. the action whose child node has the highest max value.
+ 
+         :return: the best action of the node.
+         """
+         return max(self.children.values(), key=lambda child: child.max_value).action
+ 
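+     # Design note: this is a "max-child" recommendation (best single return seen so far) rather than
+     # the more common "robust-child" choice (highest visit count), which seems a reasonable fit for a
+     # deterministic single-player problem such as the JSP, where the best rollout found is a valid schedule.
+ 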
212
+ def __str__(self, colored=False, action_space_n=None) -> str:
213
+ """
214
+ Returns a string representation of the node. The string representation is used for visualisation purposes.
215
+ It is used for example in the mcts tree visualisation functionality.
216
+
217
+ :param colored: true if the string representation should be colored, false otherwise. (ture is used by the mcts tree visualisation)
218
+ :param action_space_n: the number of actions in the action space. This is used for coloring the action in the string representation.
219
+ :return: a potentially colored string representation of the node.
220
+ """
221
+ if not colored:
222
+
223
+ if not self.is_root():
224
+ return f"(a={self.action}, N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f}, ubc={self.tree_policy_score():.2f})"
225
+ else:
226
+ return f"(N={self.visit_count}, Q_v={self.mean_value:.2f}, best={self.max_value:.2f}) [root]"
227
+
228
+ import gymcts.colorful_console_utils as ccu
229
+
230
+ if self.is_root():
231
+ return f"({ccu.CYELLOW}N{ccu.CEND}={self.visit_count}, {ccu.CYELLOW}Q_v{ccu.CEND}={self.mean_value:.2f}, {ccu.CYELLOW}best{ccu.CEND}={self.max_value:.2f})"
232
+
233
+ if action_space_n is None:
234
+ raise ValueError("action_space_n must be provided if colored is True")
235
+
236
+ p = ccu.CYELLOW
237
+ e = ccu.CEND
238
+ v = ccu.CCYAN
239
+
240
+ def colorful_value(value: float | int | None) -> str:
241
+ if value == None:
242
+ return f"{ccu.CGREY}None{e}"
243
+ color = ccu.CCYAN
244
+ if value == 0:
245
+ color = ccu.CRED
246
+ if value == float("inf"):
247
+ color = ccu.CGREY
248
+ if value == -float("inf"):
249
+ color = ccu.CGREY
250
+
251
+ if isinstance(value, float):
252
+ return f"{color}{value:.2f}{e}"
253
+
254
+ if isinstance(value, int):
255
+ return f"{color}{value}{e}"
256
+
257
+ root_node = self.get_root()
258
+ mean_val = f"{self.mean_value:.2f}"
259
+
260
+
261
+ return ((f"("
262
+ f"{p}a{e}={ccu.wrap_evenly_spaced_color(s=self.action, n_of_item=self.action, n_classes=action_space_n)}, "
263
+ f"{p}N{e}={colorful_value(self.visit_count)}, "
264
+ f"{p}Q_v{e}={ccu.wrap_with_color_scale(s=mean_val, value=self.mean_value, min_val=root_node.min_value, max_val=root_node.max_value)}, "
265
+ f"{p}best{e}={colorful_value(self.max_value)}") +
266
+ (f", {p}{GymctsNeuralNode.score_variate}{e}={colorful_value(self.tree_policy_score())})" if not self.is_root() else ")"))
267
+
268
+
269
+
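+ # Hypothetical usage sketch: the exploration term can be switched globally before building an agent,
+ # e.g. `GymctsNeuralNode.score_variate = "MuZero_v0"`. GymctsNeuralAgent below performs exactly this
+ # assignment through its `score_variate` constructor argument, which is the intended way to select a variant.
+ 
+ 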
+ class GymctsNeuralAgent(GymctsAgent):
+ 
+     def __init__(self,
+                  env: GymctsABC,
+                  *args,
+                  model_kwargs=None,
+                  score_variate: Literal[
+                      "PUCT_v0",
+                      "PUCT_v1",
+                      "PUCT_v2",
+                      "PUCT_v3",
+                      "PUCT_v4",
+                      "PUCT_v5",
+                      "PUCT_v6",
+                      "PUCT_v7",
+                      "PUCT_v8",
+                      "PUCT_v9",
+                      "PUCT_v10",
+                      "MuZero_v0",
+                      "MuZero_v1",
+                  ] = "PUCT_v0",
+                  **kwargs
+                  ):
+ 
+         # init super class
+         super().__init__(
+             env=env,
+             *args,
+             **kwargs
+         )
+         if score_variate not in [
+             "PUCT_v0", "PUCT_v1", "PUCT_v2",
+             "PUCT_v3", "PUCT_v4", "PUCT_v5",
+             "PUCT_v6", "PUCT_v7", "PUCT_v8",
+             "PUCT_v9", "PUCT_v10",
+             "MuZero_v0", "MuZero_v1"
+         ]:
+             raise ValueError(f"Invalid score_variate: {score_variate}. Must be one of: "
+                              f"PUCT_v0, PUCT_v1, PUCT_v2, PUCT_v3, PUCT_v4, PUCT_v5, "
+                              f"PUCT_v6, PUCT_v7, PUCT_v8, PUCT_v9, PUCT_v10, MuZero_v0, MuZero_v1")
+         GymctsNeuralNode.score_variate = score_variate
+ 
+         if model_kwargs is None:
+             model_kwargs = {}
+         obs, info = env.reset()
+ 
+         self.search_root_node = GymctsNeuralNode(
+             action=None,
+             parent=None,
+             env_reference=env,
+             observation=obs,
+             prior_selection_score=1.0,
+         )
+ 
+         def mask_fn(env: gym.Env) -> np.ndarray:
+             mask = env.action_masks()
+             if mask is None:
+                 mask = np.ones(env.action_space.n, dtype=np.float32)
+             return mask
+ 
+         env = ActionMasker(env, action_mask_fn=mask_fn)
+ 
+         model_kwargs = {
+             "policy": MaskableActorCriticPolicy,
+             "env": env,
+             "verbose": 1,
+         } | model_kwargs
+ 
+         self._model = sb3_contrib.MaskablePPO(**model_kwargs)
+ 
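+     # Notes on the constructor above: ActionMasker expects a callable returning the current action
+     # mask; the wrapped environment is assumed to expose action_masks() (as GraphJspNeuralGYMCTSWrapper
+     # does), with an all-ones mask as fallback. The dict merge `{defaults} | model_kwargs` relies on
+     # the right-hand operand winning for duplicate keys, so user-supplied model_kwargs override the
+     # defaults (policy, env, verbose).
+ 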
+     def learn(self, total_timesteps: int, **kwargs) -> None:
+         """Learn from the environment using the MaskablePPO model."""
+         self._model.learn(total_timesteps=total_timesteps, **kwargs)
+ 
+     def expand_node(self, node: GymctsNeuralNode) -> None:
+         log.debug(f"expanding node: {node}")
+         # EXPANSION STRATEGY: expand all children
+ 
+         child_dict = {}
+ 
+         self._load_state(node)
+ 
+         obs_tensor, vectorized_env = self._model.policy.obs_to_tensor(np.array([node._obs]))
+         action_masks = np.array([self.env.action_masks()])
+         distribution = self._model.policy.get_distribution(obs=obs_tensor, action_masks=action_masks)
+         unwrapped_distribution = distribution.distribution.probs[0]
+ 
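+         # The masked policy distribution from MaskablePPO is used as the prior P(s, a):
+         # obs_to_tensor() prepares the batched observation, get_distribution() applies the action
+         # mask, and probs[0] holds one probability per action for this single observation.
+ 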
+         # print(f'valid actions: {node.valid_actions}')
+         # print(f'env mask: {self.env.action_masks()}')
+         # print(f'env valid actions: {self.env.get_valid_actions()}')
+ 
+         # alternative expansion (disabled, kept as a string literal): expands every valid action
+         # with a flat prior of 1.0 instead of the policy probabilities
+         """
+         for action in node.valid_actions:
+             # reconstruct state: load state of leaf node
+             self._load_state(node)
+ 
+             obs, reward, terminal, truncated, _ = self.env.step(action)
+             child_dict[action] = GymctsNeuralNode(
+                 action=action,
+                 parent=node,
+                 env_reference=self.env,
+                 observation=obs,
+                 prior_selection_score=1.0,
+             )
+         node.children = child_dict
+         return
+         """
+ 
+         for action, prob in enumerate(unwrapped_distribution):
+             self._load_state(node)
+ 
+             log.debug(f"Probability for action {action}: {prob}")
+ 
+             if prob == 0.0:
+                 continue
+ 
+             assert action in node.valid_actions, f"Action {action} is not in valid actions: {node.valid_actions}"
+ 
+             obs, reward, terminal, truncated, _ = self.env.step(action)
+             child_dict[action] = GymctsNeuralNode(
+                 action=action,
+                 parent=node,
+                 observation=copy.deepcopy(obs),
+                 env_reference=self.env,
+                 prior_selection_score=float(prob)
+             )
+ 
+         node.children = child_dict
+         # print(f"Expanded node {node} with {len(node.children)} children.")
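+         # Because the distribution is masked, the priors of the created children sum to 1 over the
+         # valid actions, and masked-out (zero-probability) actions never become children.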
+ 
+ 
+ if __name__ == '__main__':
+     log.setLevel(20)  # 20 == logging.INFO
+ 
+     env_kwargs = {
+         "jps_instance": ft06,
+         "default_visualisations": ["gantt_console", "graph_console"],
+         "reward_function_parameters": {
+             "scaling_divisor": ft06_makespan
+         },
+         "reward_function": "nasuta",
+     }
+ 
+     env = DisjunctiveGraphJspEnv(**env_kwargs)
+     env.reset()
+ 
+     env = GraphJspNeuralGYMCTSWrapper(env)
+ 
+     import torch
+     model_kwargs = {
+         "gamma": 0.99013,
+         "gae_lambda": 0.9,
+         "normalize_advantage": True,
+         "n_epochs": 28,
+         "n_steps": 432,
+         "max_grad_norm": 0.5,
+         "learning_rate": 6e-4,
+         "policy_kwargs": {
+             "net_arch": {
+                 "pi": [90, 90],
+                 "vf": [90, 90],
+             },
+             "ortho_init": True,
+             "activation_fn": torch.nn.ELU,
+             "optimizer_kwargs": {
+                 "eps": 1e-7
+             }
+         }
+     }
+ 
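+     # These MaskablePPO hyperparameters are merged with the defaults set in GymctsNeuralAgent
+     # (policy, env, verbose); they appear to be tuned values for the ft06 instance (an assumption,
+     # not documented here), and any valid MaskablePPO keyword argument can be passed instead.
+ 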
+     agent = GymctsNeuralAgent(
+         env=env,
+         render_tree_after_step=True,
+         render_tree_max_depth=3,
+         exclude_unvisited_nodes_from_render=False,
+         number_of_simulations_per_step=15,
+         # clear_mcts_tree_after_step=False,
+         model_kwargs=model_kwargs
+     )
+ 
+     agent.learn(total_timesteps=10_000)
+ 
+     actions = agent.solve(render_tree_after_step=True)
+     for a in actions:
+         obs, rew, term, trun, info = env.step(a)
+ 
+     env.render()
+     makespan = env.unwrapped.get_makespan()
+     print(f"makespan: {makespan}")