gymcts 1.4.2__py3-none-any.whl → 1.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gymcts/gymcts_neural_agent.py +37 -52
- {gymcts-1.4.2.dist-info → gymcts-1.4.4.dist-info}/METADATA +1 -1
- {gymcts-1.4.2.dist-info → gymcts-1.4.4.dist-info}/RECORD +6 -6
- {gymcts-1.4.2.dist-info → gymcts-1.4.4.dist-info}/WHEEL +0 -0
- {gymcts-1.4.2.dist-info → gymcts-1.4.4.dist-info}/licenses/LICENSE +0 -0
- {gymcts-1.4.2.dist-info → gymcts-1.4.4.dist-info}/top_level.txt +0 -0
gymcts/gymcts_neural_agent.py
CHANGED
|
@@ -55,14 +55,11 @@ class GraphJspNeuralGYMCTSWrapper(GymctsABC, gym.Wrapper):
|
|
|
55
55
|
def get_state(self) -> Any:
|
|
56
56
|
return env.unwrapped.get_action_history()
|
|
57
57
|
|
|
58
|
-
|
|
59
58
|
def action_masks(self) -> np.ndarray | None:
|
|
60
59
|
"""Return the action mask for the current state."""
|
|
61
60
|
return self.env.unwrapped.valid_action_mask()
|
|
62
61
|
|
|
63
62
|
|
|
64
|
-
|
|
65
|
-
|
|
66
63
|
class GymctsNeuralNode(GymctsNode):
|
|
67
64
|
PUCT_v3_mu = 0.95
|
|
68
65
|
|
|
@@ -146,13 +143,12 @@ class GymctsNeuralNode(GymctsNode):
|
|
|
146
143
|
env_reference: GymctsABC,
|
|
147
144
|
prior_selection_score: float,
|
|
148
145
|
observation: np.ndarray | None = None,
|
|
149
|
-
|
|
146
|
+
):
|
|
150
147
|
super().__init__(action, parent, env_reference)
|
|
151
148
|
|
|
152
149
|
self._obs = observation
|
|
153
150
|
self._selection_score_prior = prior_selection_score
|
|
154
151
|
|
|
155
|
-
|
|
156
152
|
def tree_policy_score(self) -> float:
|
|
157
153
|
# call the superclass (GymctsNode) for ucb_score
|
|
158
154
|
c = GymctsNode.ubc_c
|
|
@@ -167,12 +163,13 @@ class GymctsNeuralNode(GymctsNode):
|
|
|
167
163
|
b = GymctsNode.best_action_weight
|
|
168
164
|
exploitation_term = 0.0 if self.visit_count == 0 else (1 - b) * self.mean_value + b * self.max_value
|
|
169
165
|
|
|
170
|
-
|
|
171
166
|
if GymctsNeuralNode.score_variate == "PUCT_v0":
|
|
172
167
|
return exploitation_term + c * p_sa * math.sqrt(n_s) / (1 + n_sa)
|
|
173
168
|
elif GymctsNeuralNode.score_variate == "PUCT_v1":
|
|
174
169
|
return exploitation_term + c * p_sa * math.sqrt(2 * math.log(n_s) / (n_sa))
|
|
175
170
|
elif GymctsNeuralNode.score_variate == "PUCT_v2":
|
|
171
|
+
if n_sa == 0:
|
|
172
|
+
return float("inf") # Avoid division by zero
|
|
176
173
|
return exploitation_term + c * p_sa * math.sqrt(n_s) / n_sa
|
|
177
174
|
elif GymctsNeuralNode.score_variate == "PUCT_v3":
|
|
178
175
|
return exploitation_term + c * (p_sa ** GymctsNeuralNode.PUCT_v3_mu) * math.sqrt(n_s / (1 + n_sa))
|
|
@@ -200,11 +197,10 @@ class GymctsNeuralNode(GymctsNode):
|
|
|
200
197
|
c2 = GymctsNeuralNode.MuZero_c2
|
|
201
198
|
return exploitation_term + c * p_sa * math.sqrt(n_s) / (1 + n_sa) * (c1 + math.log((n_s + c2 + 1) / c2))
|
|
202
199
|
|
|
203
|
-
|
|
204
|
-
|
|
200
|
+
exploration_term = self._selection_score_prior * c * math.sqrt(
|
|
201
|
+
math.log(self.parent.visit_count) / (self.visit_count)) if self.visit_count > 0 else float("inf")
|
|
205
202
|
return self.mean_value + exploration_term
|
|
206
203
|
|
|
207
|
-
|
|
208
204
|
def get_best_action(self) -> int:
|
|
209
205
|
"""
|
|
210
206
|
Returns the best action of the node. The best action is the action with the highest score.
|
|
@@ -214,7 +210,6 @@ class GymctsNeuralNode(GymctsNode):
|
|
|
214
210
|
"""
|
|
215
211
|
return max(self.children.values(), key=lambda child: child.max_value).action
|
|
216
212
|
|
|
217
|
-
|
|
218
213
|
def __str__(self, colored=False, action_space_n=None) -> str:
|
|
219
214
|
"""
|
|
220
215
|
Returns a string representation of the node. The string representation is used for visualisation purposes.
|
|
@@ -263,14 +258,13 @@ class GymctsNeuralNode(GymctsNode):
|
|
|
263
258
|
root_node = self.get_root()
|
|
264
259
|
mean_val = f"{self.mean_value:.2f}"
|
|
265
260
|
|
|
266
|
-
|
|
267
261
|
return ((f"("
|
|
268
262
|
f"{p}a{e}={ccu.wrap_evenly_spaced_color(s=self.action, n_of_item=self.action, n_classes=action_space_n)}, "
|
|
269
263
|
f"{p}N{e}={colorful_value(self.visit_count)}, "
|
|
270
264
|
f"{p}Q_v{e}={ccu.wrap_with_color_scale(s=mean_val, value=self.mean_value, min_val=root_node.min_value, max_val=root_node.max_value)}, "
|
|
271
265
|
f"{p}best{e}={colorful_value(self.max_value)}") +
|
|
272
|
-
(
|
|
273
|
-
|
|
266
|
+
(
|
|
267
|
+
f", {p}{GymctsNeuralNode.score_variate}{e}={colorful_value(self.tree_policy_score())})" if not self.is_root() else ")"))
|
|
274
268
|
|
|
275
269
|
|
|
276
270
|
class GymctsNeuralAgent(GymctsAgent):
|
|
@@ -282,15 +276,15 @@ class GymctsNeuralAgent(GymctsAgent):
|
|
|
282
276
|
score_variate: Literal[
|
|
283
277
|
"PUCT_v0",
|
|
284
278
|
"PUCT_v1",
|
|
285
|
-
"
|
|
286
|
-
"
|
|
287
|
-
"
|
|
288
|
-
"
|
|
289
|
-
"
|
|
290
|
-
"
|
|
291
|
-
"
|
|
292
|
-
"
|
|
293
|
-
"
|
|
279
|
+
"PUCT_v2",
|
|
280
|
+
"PUCT_v3",
|
|
281
|
+
"PUCT_v4",
|
|
282
|
+
"PUCT_v5",
|
|
283
|
+
"PUCT_v6",
|
|
284
|
+
"PUCT_v7",
|
|
285
|
+
"PUCT_v8",
|
|
286
|
+
"PUCT_v9",
|
|
287
|
+
"PUCT_v10",
|
|
294
288
|
"MuZero_v0",
|
|
295
289
|
"MuZero_v1",
|
|
296
290
|
] = "PUCT_v0",
|
|
@@ -304,15 +298,23 @@ class GymctsNeuralAgent(GymctsAgent):
|
|
|
304
298
|
**kwargs
|
|
305
299
|
)
|
|
306
300
|
if score_variate not in [
|
|
307
|
-
"PUCT_v0",
|
|
308
|
-
"
|
|
309
|
-
"
|
|
310
|
-
"
|
|
311
|
-
"
|
|
301
|
+
"PUCT_v0",
|
|
302
|
+
"PUCT_v1",
|
|
303
|
+
"PUCT_v2",
|
|
304
|
+
"PUCT_v3",
|
|
305
|
+
"PUCT_v4",
|
|
306
|
+
"PUCT_v5",
|
|
307
|
+
"PUCT_v6",
|
|
308
|
+
"PUCT_v7",
|
|
309
|
+
"PUCT_v8",
|
|
310
|
+
"PUCT_v9",
|
|
311
|
+
"PUCT_v10",
|
|
312
|
+
"MuZero_v0",
|
|
313
|
+
"MuZero_v1",
|
|
312
314
|
]:
|
|
313
315
|
raise ValueError(f"Invalid score_variate: {score_variate}. Must be one of: "
|
|
314
|
-
f"PUCT_v0, PUCT_v1,
|
|
315
|
-
f"
|
|
316
|
+
f"['PUCT_v0', 'PUCT_v1', 'PUCT_v2', 'PUCT_v3', 'PUCT_v4', 'PUCT_v5', "
|
|
317
|
+
f"'PUCT_v6', 'PUCT_v7', 'PUCT_v8', 'PUCT_v9', 'PUCT_v10', 'MuZero_v0', 'MuZero_v1']")
|
|
316
318
|
GymctsNeuralNode.score_variate = score_variate
|
|
317
319
|
|
|
318
320
|
if model_kwargs is None:
|
|
@@ -336,22 +338,17 @@ class GymctsNeuralAgent(GymctsAgent):
|
|
|
336
338
|
env = ActionMasker(env, action_mask_fn=mask_fn)
|
|
337
339
|
|
|
338
340
|
model_kwargs = {
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
341
|
+
"policy": MaskableActorCriticPolicy,
|
|
342
|
+
"env": env,
|
|
343
|
+
"verbose": 1,
|
|
344
|
+
} | model_kwargs
|
|
343
345
|
|
|
344
346
|
self._model = sb3_contrib.MaskablePPO(**model_kwargs)
|
|
345
347
|
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
def learn(self, total_timesteps:int, **kwargs) -> None:
|
|
348
|
+
def learn(self, total_timesteps: int, **kwargs) -> None:
|
|
351
349
|
"""Learn from the environment using the MaskablePPO model."""
|
|
352
350
|
self._model.learn(total_timesteps=total_timesteps, **kwargs)
|
|
353
351
|
|
|
354
|
-
|
|
355
352
|
def expand_node(self, node: GymctsNeuralNode) -> None:
|
|
356
353
|
log.debug(f"expanding node: {node}")
|
|
357
354
|
# EXPANSION STRATEGY
|
|
@@ -395,7 +392,6 @@ class GymctsNeuralAgent(GymctsAgent):
|
|
|
395
392
|
if prob == 0.0:
|
|
396
393
|
continue
|
|
397
394
|
|
|
398
|
-
|
|
399
395
|
assert action in node.valid_actions, f"Action {action} is not in valid actions: {node.valid_actions}"
|
|
400
396
|
|
|
401
397
|
obs, reward, terminal, truncated, _ = self.env.step(action)
|
|
@@ -411,9 +407,6 @@ class GymctsNeuralAgent(GymctsAgent):
|
|
|
411
407
|
# print(f"Expanded node {node} with {len(node.children)} children.")
|
|
412
408
|
|
|
413
409
|
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
410
|
if __name__ == '__main__':
|
|
418
411
|
log.setLevel(20)
|
|
419
412
|
|
|
@@ -426,14 +419,13 @@ if __name__ == '__main__':
|
|
|
426
419
|
"reward_function": "nasuta",
|
|
427
420
|
}
|
|
428
421
|
|
|
429
|
-
|
|
430
|
-
|
|
431
422
|
env = DisjunctiveGraphJspEnv(**env_kwargs)
|
|
432
423
|
env.reset()
|
|
433
424
|
|
|
434
425
|
env = GraphJspNeuralGYMCTSWrapper(env)
|
|
435
426
|
|
|
436
427
|
import torch
|
|
428
|
+
|
|
437
429
|
model_kwargs = {
|
|
438
430
|
"gamma": 0.99013,
|
|
439
431
|
"gae_lambda": 0.9,
|
|
@@ -467,7 +459,6 @@ if __name__ == '__main__':
|
|
|
467
459
|
|
|
468
460
|
agent.learn(total_timesteps=10_000)
|
|
469
461
|
|
|
470
|
-
|
|
471
462
|
agent.solve()
|
|
472
463
|
|
|
473
464
|
actions = agent.solve(render_tree_after_step=True)
|
|
@@ -477,9 +468,3 @@ if __name__ == '__main__':
|
|
|
477
468
|
env.render()
|
|
478
469
|
makespan = env.unwrapped.get_makespan()
|
|
479
470
|
print(f"makespan: {makespan}")
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: gymcts
|
|
3
|
-
Version: 1.4.2
|
|
3
|
+
Version: 1.4.4
|
|
4
4
|
Summary: A minimalistic implementation of the Monte Carlo Tree Search algorithm for planning problems formulated as gymnasium reinforcement learning environments.
|
|
5
5
|
Author: Alexander Nasuta
|
|
6
6
|
Author-email: Alexander Nasuta <alexander.nasuta@wzl-iqs.rwth-aachen.de>
|
|
@@ -5,12 +5,12 @@ gymcts/gymcts_agent.py,sha256=FzMPjHXyKN6enNJubmYEouvb0wBbE1-bpxuLuW4J1gU,10960
|
|
|
5
5
|
gymcts/gymcts_deepcopy_wrapper.py,sha256=lCCT5-6JVCwUCP__4uPMMkT5HnO2JWm2ebzJ69zXp9c,6792
|
|
6
6
|
gymcts/gymcts_distributed_agent.py,sha256=Ha9UBQvFjoErfMWvPyN0JcTYz-JaiJ4eWjLMikp9Yhs,11569
|
|
7
7
|
gymcts/gymcts_env_abc.py,sha256=iqrFNNSa-kZyAGk1UN2BjkdkV6NufAkYJT8d7PlQ07E,2525
|
|
8
|
-
gymcts/gymcts_neural_agent.py,sha256=
|
|
8
|
+
gymcts/gymcts_neural_agent.py,sha256=_PV_lNYVyZDjrPBRYK-DWiQRwUGnleAt3SKbwCZKCWU,16326
|
|
9
9
|
gymcts/gymcts_node.py,sha256=KAR5y1MrT8c_7ZXwTuCj77B7DiERDfHplF8avs76JHU,13410
|
|
10
10
|
gymcts/gymcts_tree_plotter.py,sha256=PR6C7q9Q4kuz1aLGyD7-aZsxk3RqlHZpOqmOiRpCyK0,3547
|
|
11
11
|
gymcts/logger.py,sha256=RI7B9cvbBGrj0_QIAI77wihzuu2tPG_-z9GM2Mw5aHE,926
|
|
12
|
-
gymcts-1.4.
|
|
13
|
-
gymcts-1.4.
|
|
14
|
-
gymcts-1.4.
|
|
15
|
-
gymcts-1.4.
|
|
16
|
-
gymcts-1.4.
|
|
12
|
+
gymcts-1.4.4.dist-info/licenses/LICENSE,sha256=UGe75WojDiw_77SEnK2aysEDlElRlkWie7U7NaAFx00,1072
|
|
13
|
+
gymcts-1.4.4.dist-info/METADATA,sha256=y_-_ktxyZpaLdB0i81ggKepZNycG-P1jiqqadBMwSzI,23864
|
|
14
|
+
gymcts-1.4.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
15
|
+
gymcts-1.4.4.dist-info/top_level.txt,sha256=E8MoLsPimUPD0H1Y6lum4TVe-lhSDAyBAXGrkYIT52w,7
|
|
16
|
+
gymcts-1.4.4.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|