gr-libs 0.1.7.post0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86) hide show
  1. gr_libs/__init__.py +4 -1
  2. gr_libs/_evaluation/__init__.py +1 -0
  3. gr_libs/_evaluation/_analyze_results_cross_alg_cross_domain.py +260 -0
  4. gr_libs/_evaluation/_generate_experiments_results.py +141 -0
  5. gr_libs/_evaluation/_generate_task_specific_statistics_plots.py +497 -0
  6. gr_libs/_evaluation/_get_plans_images.py +61 -0
  7. gr_libs/_evaluation/_increasing_and_decreasing_.py +106 -0
  8. gr_libs/_version.py +2 -2
  9. gr_libs/all_experiments.py +294 -0
  10. gr_libs/environment/__init__.py +30 -9
  11. gr_libs/environment/_utils/utils.py +27 -0
  12. gr_libs/environment/environment.py +417 -54
  13. gr_libs/metrics/__init__.py +7 -0
  14. gr_libs/metrics/metrics.py +231 -54
  15. gr_libs/ml/__init__.py +2 -5
  16. gr_libs/ml/agent.py +21 -6
  17. gr_libs/ml/base/__init__.py +3 -1
  18. gr_libs/ml/base/rl_agent.py +81 -13
  19. gr_libs/ml/consts.py +1 -1
  20. gr_libs/ml/neural/__init__.py +1 -3
  21. gr_libs/ml/neural/deep_rl_learner.py +619 -378
  22. gr_libs/ml/neural/utils/__init__.py +1 -2
  23. gr_libs/ml/neural/utils/dictlist.py +3 -3
  24. gr_libs/ml/planner/mcts/{utils → _utils}/__init__.py +1 -1
  25. gr_libs/ml/planner/mcts/{utils → _utils}/node.py +11 -7
  26. gr_libs/ml/planner/mcts/{utils → _utils}/tree.py +15 -11
  27. gr_libs/ml/planner/mcts/mcts_model.py +571 -312
  28. gr_libs/ml/sequential/__init__.py +0 -1
  29. gr_libs/ml/sequential/_lstm_model.py +270 -0
  30. gr_libs/ml/tabular/__init__.py +1 -3
  31. gr_libs/ml/tabular/state.py +7 -7
  32. gr_libs/ml/tabular/tabular_q_learner.py +150 -82
  33. gr_libs/ml/tabular/tabular_rl_agent.py +42 -28
  34. gr_libs/ml/utils/__init__.py +2 -3
  35. gr_libs/ml/utils/format.py +28 -97
  36. gr_libs/ml/utils/math.py +5 -3
  37. gr_libs/ml/utils/other.py +3 -3
  38. gr_libs/ml/utils/storage.py +88 -81
  39. gr_libs/odgr_executor.py +268 -0
  40. gr_libs/problems/consts.py +1549 -1227
  41. gr_libs/recognizer/_utils/__init__.py +0 -0
  42. gr_libs/recognizer/_utils/format.py +18 -0
  43. gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +233 -88
  44. gr_libs/recognizer/graml/_gr_dataset.py +233 -0
  45. gr_libs/recognizer/graml/graml_recognizer.py +586 -252
  46. gr_libs/recognizer/recognizer.py +90 -30
  47. gr_libs/tutorials/draco_panda_tutorial.py +58 -0
  48. gr_libs/tutorials/draco_parking_tutorial.py +56 -0
  49. gr_libs/tutorials/gcdraco_panda_tutorial.py +62 -0
  50. gr_libs/tutorials/gcdraco_parking_tutorial.py +57 -0
  51. gr_libs/tutorials/graml_minigrid_tutorial.py +64 -0
  52. gr_libs/tutorials/graml_panda_tutorial.py +57 -0
  53. gr_libs/tutorials/graml_parking_tutorial.py +52 -0
  54. gr_libs/tutorials/graml_point_maze_tutorial.py +60 -0
  55. gr_libs/tutorials/graql_minigrid_tutorial.py +50 -0
  56. {gr_libs-0.1.7.post0.dist-info → gr_libs-0.2.2.dist-info}/METADATA +84 -29
  57. gr_libs-0.2.2.dist-info/RECORD +71 -0
  58. {gr_libs-0.1.7.post0.dist-info → gr_libs-0.2.2.dist-info}/WHEEL +1 -1
  59. gr_libs-0.2.2.dist-info/top_level.txt +2 -0
  60. tests/test_draco.py +14 -0
  61. tests/test_gcdraco.py +10 -0
  62. tests/test_graml.py +12 -8
  63. tests/test_graql.py +3 -2
  64. evaluation/analyze_results_cross_alg_cross_domain.py +0 -277
  65. evaluation/create_minigrid_map_image.py +0 -34
  66. evaluation/file_system.py +0 -42
  67. evaluation/generate_experiments_results.py +0 -92
  68. evaluation/generate_experiments_results_new_ver1.py +0 -254
  69. evaluation/generate_experiments_results_new_ver2.py +0 -331
  70. evaluation/generate_task_specific_statistics_plots.py +0 -272
  71. evaluation/get_plans_images.py +0 -47
  72. evaluation/increasing_and_decreasing_.py +0 -63
  73. gr_libs/environment/utils/utils.py +0 -17
  74. gr_libs/ml/neural/utils/penv.py +0 -57
  75. gr_libs/ml/sequential/lstm_model.py +0 -192
  76. gr_libs/recognizer/graml/gr_dataset.py +0 -134
  77. gr_libs/recognizer/utils/__init__.py +0 -1
  78. gr_libs/recognizer/utils/format.py +0 -13
  79. gr_libs-0.1.7.post0.dist-info/RECORD +0 -67
  80. gr_libs-0.1.7.post0.dist-info/top_level.txt +0 -4
  81. tutorials/graml_minigrid_tutorial.py +0 -34
  82. tutorials/graml_panda_tutorial.py +0 -41
  83. tutorials/graml_parking_tutorial.py +0 -39
  84. tutorials/graml_point_maze_tutorial.py +0 -39
  85. tutorials/graql_minigrid_tutorial.py +0 -34
  86. /gr_libs/environment/{utils → _utils}/__init__.py +0 -0
@@ -1,330 +1,589 @@
1
+ """ model that performs mcts to find a plan in discrete state/action environments. """
2
+
1
3
  import os
4
+ import pickle
2
5
  import random
3
- from math import sqrt, log
6
+ from math import log, sqrt
4
7
 
8
+ import gymnasium as gym
5
9
  from tqdm import tqdm
6
- import pickle
7
10
 
8
11
  from gr_libs.ml.utils.storage import get_agent_model_dir
9
- from .utils import Node
10
- from .utils import Tree
11
- import gymnasium as gym
12
+
13
+ from ._utils import Node, Tree
12
14
 
13
15
  PROB = 0.8
14
16
  UNIFORM_PROB = 0.1
15
17
  newely_expanded = 0
16
- dict_dir_id_to_str = {0:'right', 1:'down', 2:'left', 3:'up'}
17
- dict_action_id_to_str = {0:'turn left', 1:'turn right', 2:'go straight'}
18
+ dict_dir_id_to_str = {0: "right", 1: "down", 2: "left", 3: "up"}
19
+ dict_action_id_to_str = {0: "turn left", 1: "turn right", 2: "go straight"}
20
+
18
21
 
19
22
  def save_figure(steps, env_name, problem_name, img_path, env_prop):
20
- sequence = [pos for ((state, pos), action) in steps]
21
- #print(f"sequence to {self.problem_name} is:\n\t{steps}\ngenerating image at {img_path}.")
22
- print(f"generating sequence image at {img_path}.")
23
- env_prop.create_sequence_image(sequence, img_path, problem_name)
23
+ """
24
+ Save a figure representing the sequence of steps taken in a problem.
25
+
26
+ Args:
27
+ steps (list): List of tuples representing the state, position, and action taken at each step.
28
+ env_name (str): Name of the environment.
29
+ problem_name (str): Name of the problem.
30
+ img_path (str): Path to save the generated image.
31
+ env_prop: Object with methods to create the sequence image.
32
+
33
+ Returns:
34
+ None
35
+ """
36
+ sequence = [pos for ((state, pos), action) in steps]
37
+ print(f"generating sequence image at {img_path}.")
38
+ env_prop.create_sequence_image(sequence, img_path, problem_name)
39
+
24
40
 
25
41
  # TODO add number of expanded nodes and debug by putting breakpoint on the creation of nodes representing (8,4) and checking if they're invalid or something
26
42
 
27
- # Explanation on hashing and uncertainty in the acto outcome:
28
- # We want to detect circles, while not preventing expected behavior. To achieve it, hasing must include previous state, action, and resulting state.
29
- # Hashing the direction means coming to the same position from different positions gets different id's.
30
- # Example: the agent might have stood at (2,2), picked action 2 (forward), and accidently turned right, resulting at state ((2,2), right).
31
- # later, when the agent stood at (2,1), looked right and walked forward, it got to the same state. We would want to enable that, because
32
- # this is the expected behavior, so these nodes must have unique id's.
33
- # The situations where circles will indeed be detected, are only if the outcome was the same for the previous state, consistent with the action - whether it was or wasn't expected.
34
- class MonteCarloTreeSearch():
35
-
36
- def __init__(self, env, tree, goal, use_heuristic=True):
37
- self.env = env
38
- self.tree = tree
39
- self.action_space = self.env.action_space.n
40
- self.action_space = 3 # currently
41
- state, _ = self.env.reset()
42
- self.use_heuristic = use_heuristic
43
- self.goal = goal
44
- self.tree.add_node(Node(identifier=hash((None, None, tuple(self.env.unwrapped.agent_pos), state['direction'])), state=state, action=None, action_space=self.action_space, reward=0, terminal=False, pos=env.unwrapped.agent_pos, depth=0))
45
- self.plan = []
46
-
47
- # def mark_invalid_children(self, children_identifiers, action):
48
- # for child_id in children_identifiers:
49
- # child = self.tree.nodes[child_id]
50
- # if child.action == action:
51
- # child.invalid = True
52
-
53
- def decide_invalid_path(self, new_node_father, old_node, new_node): # new_node created the circle, old_node got to the configuration first.
54
- new_visits, old_visits = [1,1], [0,0] # stochasticity couldn't result a cycle directly, because it involves a different action. we can get it only by making the same stochastic action mistake or just an actual cycle.
55
- new_node_ptr = new_node_father
56
- old_node_ptr = old_node
57
-
58
- while new_node_ptr != None:
59
- new_visits[0] += new_node_ptr.num_visits
60
- new_visits[1] += 1
61
- new_node_ptr = self.tree.parent(new_node_ptr)
62
-
63
- while old_node_ptr != None: # getting to the old node wasn't necessarily through the current root. check all the way until None, the original root's parent.
64
- old_visits[0] += old_node_ptr.num_visits
65
- old_visits[1] += 1
66
- old_node_ptr = self.tree.parent(old_node_ptr)
67
-
68
- if new_visits[0] / new_visits[1] > old_visits[0] / old_visits[1]: # newer node is the more probable one. make the 1st path the invalid one: its the one that created the circle!
69
- old_node.invalid = True
70
- # self.tree.update_id(old_id=old_node.identifier, new_id=new_node.identifier)
71
- else:
72
- new_node.invalid = True
73
-
74
- def is_parent_child_same(self, new_node, node):
75
- return new_node.pos[0] == node.pos[0] and new_node.pos[1] == node.pos[1] and new_node.state['direction'] == node.state['direction']
76
-
77
- def expand(self, node, depth):
78
- global newely_expanded
79
- action = node.untried_action()
80
- state, reward, terminated, truncated, _ = self.env.step(self.stochastic_action(action))
81
- done = terminated | truncated
82
- new_identifier = hash((tuple(node.pos), node.state['direction'], action, tuple(self.env.unwrapped.agent_pos), state['direction']))
83
- valid_id = new_identifier
84
- while new_identifier in self.tree.nodes.keys(): # iterate over all circle nodes. important not to hash the parent node id to get the next id, because it will not be the same for all circle nodes.
85
- if self.tree.nodes[new_identifier].invalid == False:
86
- valid_id = new_identifier
87
- new_identifier = hash((666, new_identifier))
88
- # after this while, the id is for sure unused.
89
- new_node = Node(identifier=new_identifier, state=state, action=action, action_space=self.action_space, reward=reward, terminal=done, pos=self.env.unwrapped.agent_pos, depth=depth)
90
- if self.is_parent_child_same(new_node, node): # this is not a regular circle but it indicates that the action - regardless if happened with or without intention, led to staying put. note this could happen even if the first if is true - twice in history someone tried to go against the wall from 2 different paths. both should be tagged invalid.
91
- new_node.invalid = True
92
- new_node.got_invalid = True
93
- # if this is a legit (s,a,s'), find the valid one and check whether this one might be more valid.
94
- elif valid_id in self.tree.nodes.keys(): # who can tell which node is invalid? might be that this is the more probable way to get here, it just happened later. maybe think of summing back up the num of visits to decide which one to make invalid.
95
- # print("CIRCLE DETECTED!") # circle can be detected by 2 nodes making the wrong stochastic action one after another, in different times!
96
-
97
- self.decide_invalid_path(new_node_father=node, old_node=self.tree.nodes[valid_id], new_node=new_node)
98
- # self.mark_invalid_children(node.children_identifiers, action)
99
-
100
- self.tree.add_node(new_node, node)
101
- # if action == 2 and tuple(self.env.unwrapped.agent_pos) == tuple(node.pos): # if the new node is actually invalid, mark it along with the other nodes of the same action as invalid, meaning reward will be 0 for them.
102
- # self.mark_invalid_children(node.children_identifiers)
103
- newely_expanded += 1
104
- return new_node
105
-
106
- def stochastic_action(self, choice):
107
- prob_distribution = []
108
- actions = range(self.action_space)
109
- for action in actions:
110
- if action == choice: prob_distribution.append(PROB)
111
- else: prob_distribution.append(UNIFORM_PROB)
112
- return random.choices(actions, weights=prob_distribution, k=1)[0]
113
-
114
- def expand_selection_stochastic_node(self, node, resulting_identifier, terminated, truncated, reward, action, state, depth):
115
- global newely_expanded
116
- # the new node could result in a terminating state.
117
- done = terminated | truncated
118
- valid_id = resulting_identifier
119
- while resulting_identifier in self.tree.nodes.keys(): # iterate over all circle nodes. important not to hash the parent node id to get the next id, because it will not be the same for all circle nodes.
120
- if self.tree.nodes[resulting_identifier].invalid == False:
121
- valid_id = resulting_identifier
122
- resulting_identifier = hash((666, resulting_identifier))
123
- # after this while, the id is for sure unused.
124
- new_node = Node(identifier=resulting_identifier, state=state, action=action, action_space=self.action_space, reward=reward, terminal=done, pos=self.env.unwrapped.agent_pos, depth=depth)
125
- if self.is_parent_child_same(new_node, node): # this is not a regular circle but it indicates that the action - regardless if happened with or without intention, led to staying put. note this could happen even if the first if is true - twice in history someone tried to go against the wall from 2 different paths. both should be tagged invalid.
126
- new_node.invalid = True
127
- new_node.got_invalid = True
128
- # if this is a legit (s,a,s'), find the valid one and check whether this one might be more valid.
129
- elif valid_id in self.tree.nodes.keys(): # who can tell which node is invalid? might be that this is the more probable way to get here, it just happened later. maybe think of summing back up the num of visits to decide which one to make invalid.
130
- # print("CIRCLE DETECTED!") # circle can be detected by 2 nodes making the wrong stochastic action one after another, in different times!
131
- self.decide_invalid_path(new_node_father=node, old_node=self.tree.nodes[valid_id], new_node=new_node)
132
- # self.mark_invalid_children(node.children_identifiers, action)
133
- self.tree.add_node(new_node, node)
134
- newely_expanded += 1
135
- return new_node
136
-
137
- def simulation(self, node):
138
- if node.terminal:
139
- return node.reward
140
- if self.use_heuristic:
141
- # taken from Monte-Carlo Planning for Pathfinding in Real-Time Strategy Games , 2010.
142
- # need to handle the case of walking into a wall here: the resulting node will be considered invalid and it's reward and performance needs to be 0, but must handle stochasticity
143
- # suggestion to handle stochasticity - consider *all* the children associated with taking action 2 towards a wall as performance 0, even if they accidently led in walking to another direction.
144
- # which suggests the invalidity needs to be checked not according to the resulting state, rather according to the intended action itself and the environment! remember, you cannot access the "stochastic_action", it is meant to be hidden from you.
145
- if node.pos[0] == self.goal[0] and node.pos[1] == self.goal[1] : return 2
146
- if node.invalid: return -0.5
147
- else: return 0.8*(1 / (abs(node.pos[0] - self.goal[0]) + abs(node.pos[1] - self.goal[1]))) + 0.2*(1/node.depth) # large depth = less probability of obstacles -> larger nominator higher performance. further from goal -> larger denominator, lower performance.
148
- while True:
149
- action = random.randint(0, self.action_space-1)
150
- state, reward, terminated, truncated, _ = self.env.step(self.stochastic_action(action))
151
- done = terminated | truncated # this time there could be truncation unlike in the tree policy.
152
- if done:
153
- return reward
154
-
155
- def compute_value(self, parent, child, exploration_constant):
156
- exploration_term = exploration_constant * sqrt(2*log(parent.num_visits) / child.num_visits)
157
- return child.performance + exploration_term
158
-
159
- # return the best action from a node. the value of an action is the weighted sum of performance of all children that are associated with this action.
160
- def best_action(self, node, exploration_constant):
161
- tried_actions_values = {} # dictionary mapping actions to tuples of (cumulative number of visits of children, sum of (child performance * num of visits for child)) to compute the mean later
162
- if tuple(node.pos) == (1,2) and node.depth == 3 and node.action == 0:
163
- pass
164
- children = [child for child in self.tree.children(node) if not child.invalid]
165
- if not children: # all children are invalid. this node is invalid aswell.
166
- return 2
167
- for child in children:
168
- value = self.compute_value(node, child, exploration_constant)
169
- tried_actions_values.setdefault(child.action, [0, 0]) # create if it doesn't exist
170
- tried_actions_values[child.action][0] += child.num_visits # add the number of child visits
171
- tried_actions_values[child.action][1] += value * child.num_visits # add the relative performance of this child
172
- return max(tried_actions_values, key=lambda k: tried_actions_values[k][1] / tried_actions_values[k][0]) # return the key (action) with the highest average value
173
-
174
- # only changes the environment to make sure the actions which are already a part of the plan have been executed.
175
- def execute_partial_plan(self, plan):
176
- node = self.tree.root
177
- depth = 0
178
- for action in plan:
179
- depth += 1
180
- # important to simulate the env to get to some state, as the nodes don't hold this information.
181
- state, reward, terminated, truncated, _ = self.env.step(action)
182
- done = terminated
183
- if done: return None, False
184
- resulting_identifier = hash((tuple(node.pos), node.state['direction'], action, tuple(self.env.unwrapped.agent_pos), state['direction']))
185
- node = self.tree.nodes[resulting_identifier]
186
- return node, True
187
-
188
- # finds the ultimate path from the root node to a terminal state (the one that maximized rewards)
189
- def tree_policy(self, root_depth):
190
- node = self.tree.root
191
- depth = root_depth
192
- while not (node.terminal or node.invalid):
193
- depth += 1
194
- if self.tree.is_expandable(node):
195
- # expansion - in case there's an action that never been tried, its value is infinity to encourage exploration of all children of a node.
196
- return self.expand(node, depth), depth
197
- else:
198
- # selection - balance exploration and exploitation, coming down the tree - but note the selection might lead to new nodes because of stochaticity.
199
- best_action = self.best_action(node, exploration_constant=1/sqrt(2.0))
200
- if best_action == -1: break
201
- # important to simulate the env to get to some state, as the nodes don't hold this information.
202
- state, reward, terminated, truncated, _ = self.env.step(self.stochastic_action(best_action))
203
- # due to stochasticity, nodes could sometimes be terminal and sometimes they aren't. important to update it. also, the resulting state
204
- # could be a state we've never been at due to uncertainty of actions' outcomes.
205
- # if the resulting state creates a parent-action-child triplet that hasn't been seen before, add to the tree and return it, similar result to 'expand'.
206
- # the hashing must include the action, because we want to enable getting to the same state stochastically from 2 different states: walking forward from (1,2) looking right and getting to (2,2) - the expected behavior, should be allowed even if the agent once stood at (2,1), looked down, turned right and accidently proceeded forward.
207
- resulting_identifier = [child_id for child_id in node.children_identifiers if all(a == b for a, b in zip(self.tree.nodes[child_id].pos, self.env.unwrapped.agent_pos)) and self.tree.nodes[child_id].action == best_action]
208
- if len(resulting_identifier) == 0: # took an action done before, but it lead to a new state.
209
- resulting_identifier = hash((tuple(node.pos), node.state['direction'], best_action, tuple(self.env.unwrapped.agent_pos), state['direction']))
210
- return self.expand_selection_stochastic_node(node, resulting_identifier, terminated, truncated, reward, best_action, state, depth), depth
211
- assert len(resulting_identifier) == 1
212
- node = self.tree.nodes[resulting_identifier[0]]
213
- return node, depth
214
-
215
- # receives a final state node and updates the rewards of all the nodes on the path to the root
216
- def backpropagation(self, node, value):
217
- while node != self.tree.parent(self.tree.root):
218
- assert node != None # if we got to None it means we got to the actual root with the backpropogation instead of to the current root, which means in this path, someone had a differrent parent than it should, probably a double id.
219
- node.num_visits += 1
220
- node.total_simulation_reward += value
221
- node.performance = node.total_simulation_reward/node.num_visits
222
- node = self.tree.parent(node)
223
-
224
-
225
- def generate_full_policy_sequence(self, env_name, problem_name, save_fig=False, fig_path=None, env_prop=None):
226
- trace = []
227
- node, prev_node = self.tree.root, self.tree.root
228
- print("generating policy sequence.")
229
- for action in self.plan:
230
- print(f"position {tuple(node.pos)} direction {dict_dir_id_to_str[node.state['direction']]}, action {dict_action_id_to_str[action]}")
231
- candidate_children = [child for child in self.tree.children(node) if child.action == action] # there could be some children associated with the best action, representing different outcomes.
232
- assert len(candidate_children) > 0
233
- node = max(candidate_children, key=lambda node: node.num_visits) # pick the child that was visited most, meaning it represents the desired action and not the undesired outcomes.
234
- trace.append(((prev_node.state, tuple(prev_node.pos)), node.action)) # need to add the previous node with the action leading to the next node which is a property of the next node
235
- prev_node = node
236
- if save_fig:
237
- assert fig_path!=None
238
- save_figure(trace, env_name, problem_name, fig_path, env_prop)
239
- else:
240
- assert fig_path==None
241
- return trace
242
-
243
-
244
- def save_model_and_generate_policy(tree, original_root, model_file_path, monteCarloTreeSearch):
245
- tree.root = original_root
246
- with open(model_file_path, 'wb') as file: # Serialize the model
247
- monteCarloTreeSearch.env = None # pickle cannot serialize lambdas which exist in the env
248
- pickle.dump(monteCarloTreeSearch, file)
43
+
44
+ class MonteCarloTreeSearch:
45
+ """
46
+ Monte Carlo Tree Search class for performing search on an environment using a tree data structure.
47
+
48
+ Explanation on hashing and uncertainty in the action outcome:
49
+ We want to detect circles, while not preventing expected behavior.
50
+ To achieve this, hashing must include the previous state, the action, and the resulting state.
51
+ Hashing the direction means coming to the same position from different positions gets different id's.
52
+ Example: the agent might have stood at (2,2), picked action 2 (forward), and accidentally turned right,
53
+ resulting at state ((2,2), right).
54
+ later, when the agent stood at (2,1), looked right and walked forward,
55
+ it got to the same state. We would want to enable that, because
56
+ this is the expected behavior, so these nodes must have unique id's.
57
+ The situations where circles will indeed be detected, are only if the outcome was the same for the previous state,
58
+ consistent with the action - whether it was or wasn't expected.
59
+
60
+ Args:
61
+ env (gym.Env): The environment to perform the search on.
62
+ tree (Tree): The tree data structure to store the search tree.
63
+ goal (object): The goal state of the search.
64
+ use_heuristic (bool, optional): Whether to use a heuristic function during the search. Defaults to True.
65
+ """
66
+
67
+ def __init__(self, env, tree, goal, use_heuristic=True):
68
+ """
69
+ Initializes the Monte Carlo Tree Search.
70
+
71
+ Args:
72
+ env (gym.Env): The environment to perform the search on.
73
+ tree (Tree): The tree data structure to store the search tree.
74
+ goal (object): The goal state of the search.
75
+ use_heuristic (bool, optional): Whether to use a heuristic function during the search. Defaults to True.
76
+ """
77
+ self.env = env
78
+ self.tree = tree
79
+ self.action_space = self.env.action_space.n
80
+ self.action_space = 3 # currently
81
+ state, _ = self.env.reset()
82
+ self.use_heuristic = use_heuristic
83
+ self.goal = goal
84
+ self.tree.add_node(
85
+ Node(
86
+ identifier=hash(
87
+ (
88
+ None,
89
+ None,
90
+ tuple(self.env.unwrapped.agent_pos),
91
+ state["direction"],
92
+ )
93
+ ),
94
+ state=state,
95
+ action=None,
96
+ action_space=self.action_space,
97
+ reward=0,
98
+ terminal=False,
99
+ pos=env.unwrapped.agent_pos,
100
+ depth=0,
101
+ )
102
+ )
103
+ self.plan = []
104
+
105
+ # def mark_invalid_children(self, children_identifiers, action):
106
+ # for child_id in children_identifiers:
107
+ # child = self.tree.nodes[child_id]
108
+ # if child.action == action:
109
+ # child.invalid = True
110
+
111
+ def decide_invalid_path(
112
+ self, new_node_father, old_node, new_node
113
+ ): # new_node created the circle, old_node got to the configuration first.
114
+ new_visits, old_visits = [1, 1], [
115
+ 0,
116
+ 0,
117
+ ] # stochasticity couldn't result a cycle directly, because it involves a different action. we can get it only by making the same stochastic action mistake or just an actual cycle.
118
+ new_node_ptr = new_node_father
119
+ old_node_ptr = old_node
120
+
121
+ while new_node_ptr is not None:
122
+ new_visits[0] += new_node_ptr.num_visits
123
+ new_visits[1] += 1
124
+ new_node_ptr = self.tree.parent(new_node_ptr)
125
+
126
+ while (
127
+ old_node_ptr is not None
128
+ ): # getting to the old node wasn't necessarily through the current root. check all the way until None, the original root's parent.
129
+ old_visits[0] += old_node_ptr.num_visits
130
+ old_visits[1] += 1
131
+ old_node_ptr = self.tree.parent(old_node_ptr)
132
+
133
+ if (
134
+ new_visits[0] / new_visits[1] > old_visits[0] / old_visits[1]
135
+ ): # newer node is the more probable one. make the 1st path the invalid one: its the one that created the circle!
136
+ old_node.invalid = True
137
+ # self.tree.update_id(old_id=old_node.identifier, new_id=new_node.identifier)
138
+ else:
139
+ new_node.invalid = True
140
+
141
+ def is_parent_child_same(self, new_node, node):
142
+ return (
143
+ new_node.pos[0] == node.pos[0]
144
+ and new_node.pos[1] == node.pos[1]
145
+ and new_node.state["direction"] == node.state["direction"]
146
+ )
147
+
148
+ def expand(self, node, depth):
149
+ global newely_expanded
150
+ action = node.untried_action()
151
+ state, reward, terminated, truncated, _ = self.env.step(
152
+ self.stochastic_action(action)
153
+ )
154
+ done = terminated | truncated
155
+ new_identifier = hash(
156
+ (
157
+ tuple(node.pos),
158
+ node.state["direction"],
159
+ action,
160
+ tuple(self.env.unwrapped.agent_pos),
161
+ state["direction"],
162
+ )
163
+ )
164
+ valid_id = new_identifier
165
+ while (
166
+ new_identifier in self.tree.nodes.keys()
167
+ ): # iterate over all circle nodes. important not to hash the parent node id to get the next id, because it will not be the same for all circle nodes.
168
+ if self.tree.nodes[new_identifier].invalid is False:
169
+ valid_id = new_identifier
170
+ new_identifier = hash((666, new_identifier))
171
+ # after this while, the id is for sure unused.
172
+ new_node = Node(
173
+ identifier=new_identifier,
174
+ state=state,
175
+ action=action,
176
+ action_space=self.action_space,
177
+ reward=reward,
178
+ terminal=done,
179
+ pos=self.env.unwrapped.agent_pos,
180
+ depth=depth,
181
+ )
182
+ if self.is_parent_child_same(
183
+ new_node, node
184
+ ): # this is not a regular circle but it indicates that the action - regardless if happened with or without intention, led to staying put. note this could happen even if the first if is true - twice in history someone tried to go against the wall from 2 different paths. both should be tagged invalid.
185
+ new_node.invalid = True
186
+ new_node.got_invalid = True
187
+ # if this is a legit (s,a,s'), find the valid one and check whether this one might be more valid.
188
+ elif (
189
+ valid_id in self.tree.nodes.keys()
190
+ ): # who can tell which node is invalid? might be that this is the more probable way to get here, it just happened later. maybe think of summing back up the num of visits to decide which one to make invalid.
191
+ # print("CIRCLE DETECTED!") # circle can be detected by 2 nodes making the wrong stochastic action one after another, in different times!
192
+
193
+ self.decide_invalid_path(
194
+ new_node_father=node,
195
+ old_node=self.tree.nodes[valid_id],
196
+ new_node=new_node,
197
+ )
198
+ # self.mark_invalid_children(node.children_identifiers, action)
199
+
200
+ self.tree.add_node(new_node, node)
201
+ # if action == 2 and tuple(self.env.unwrapped.agent_pos) == tuple(node.pos): # if the new node is actually invalid, mark it along with the other nodes of the same action as invalid, meaning reward will be 0 for them.
202
+ # self.mark_invalid_children(node.children_identifiers)
203
+ newely_expanded += 1
204
+ return new_node
205
+
206
+ def stochastic_action(self, choice):
207
+ prob_distribution = []
208
+ actions = range(self.action_space)
209
+ for action in actions:
210
+ if action == choice:
211
+ prob_distribution.append(PROB)
212
+ else:
213
+ prob_distribution.append(UNIFORM_PROB)
214
+ return random.choices(actions, weights=prob_distribution, k=1)[0]
215
+
216
def expand_selection_stochastic_node(
    self,
    node,
    resulting_identifier,
    terminated,
    truncated,
    reward,
    action,
    state,
    depth,
):
    """Add a new child under ``node`` for an outcome reached during selection.

    Args:
        node: Parent node from which ``action`` was taken.
        resulting_identifier: Candidate id for the new node; if it is already
            present in the tree a fresh id is derived from it.
        terminated: Termination flag returned by the environment step.
        truncated: Truncation flag returned by the environment step.
        reward: Reward returned by the environment step.
        action: The action stored on the new node.
        state: Observation returned by the environment step.
        depth: Depth stored on the new node.

    Returns:
        The newly created node, already attached to the tree.
    """
    global newely_expanded
    known_nodes = self.tree.nodes

    # Follow the chain of derived ids until one is free. The next id is hashed
    # from the current id (not from the parent), and the last id on the chain
    # whose node is still valid is remembered along the way.
    last_valid_id = resulting_identifier
    free_id = resulting_identifier
    while free_id in known_nodes:
        if known_nodes[free_id].invalid is False:
            last_valid_id = free_id
        free_id = hash((666, free_id))

    # The step that produced this node may have ended the episode.
    fresh_node = Node(
        identifier=free_id,
        state=state,
        action=action,
        action_space=self.action_space,
        reward=reward,
        terminal=terminated | truncated,
        pos=self.env.unwrapped.agent_pos,
        depth=depth,
    )

    if self.is_parent_child_same(fresh_node, node):
        # The action left the agent where it was; such a node is tagged
        # invalid regardless of whether an equivalent node already exists.
        fresh_node.invalid = True
        fresh_node.got_invalid = True
    elif last_valid_id in known_nodes:
        # An equivalent, still-valid node exists: let decide_invalid_path
        # choose between the old path and the new one.
        self.decide_invalid_path(
            new_node_father=node,
            old_node=known_nodes[last_valid_id],
            new_node=fresh_node,
        )

    self.tree.add_node(fresh_node, node)
    newely_expanded += 1
    return fresh_node
267
+
268
def simulation(self, node):
    """Estimate the value of ``node``.

    A terminal node yields its stored reward. With ``self.use_heuristic`` set,
    the value is computed from the node alone: 2 when the node's position
    equals the goal, -0.5 when the node is invalid, otherwise a weighted sum
    of the inverse Manhattan distance to the goal (weight 0.8) and the inverse
    depth (weight 0.2). Without the heuristic, uniformly random actions are
    stepped through ``self.stochastic_action`` until the environment reports
    termination or truncation, and the reward of that last step is returned.

    Args:
        node: The node to evaluate.

    Returns:
        The estimated value of the node.
    """
    if node.terminal:
        return node.reward

    if not self.use_heuristic:
        # Random rollout; truncation can end it as well as termination.
        while True:
            random_action = random.randint(0, self.action_space - 1)
            state, reward, terminated, truncated, _ = self.env.step(
                self.stochastic_action(random_action)
            )
            if terminated | truncated:
                return reward

    # Heuristic taken from "Monte-Carlo Planning for Pathfinding in Real-Time
    # Strategy Games", 2010.
    goal_x, goal_y = self.goal[0], self.goal[1]
    if node.pos[0] == goal_x and node.pos[1] == goal_y:
        return 2
    if node.invalid:
        return -0.5
    manhattan = abs(node.pos[0] - goal_x) + abs(node.pos[1] - goal_y)
    # Closer to the goal -> larger first term; shallower node -> larger second term.
    return 0.8 * (1 / manhattan) + 0.2 * (1 / node.depth)
301
+
302
def compute_value(self, parent, child, exploration_constant):
    """Return the child's performance plus a visit-count exploration bonus.

    The bonus is ``exploration_constant * sqrt(2 * ln(parent visits) / child visits)``.

    Args:
        parent: Node whose ``num_visits`` feeds the logarithm.
        child: Node whose ``performance`` and ``num_visits`` are used.
        exploration_constant: Weight of the exploration bonus.

    Returns:
        ``child.performance`` plus the weighted bonus.
    """
    visit_ratio = 2 * log(parent.num_visits) / child.num_visits
    bonus = exploration_constant * sqrt(visit_ratio)
    return child.performance + bonus
307
+
308
+ # return the best action from a node. the value of an action is the weighted sum of performance of all children that are associated with this action.
309
def best_action(self, node, exploration_constant):
    """Pick the action of ``node`` with the highest visit-weighted mean value.

    Several children can share one action (different outcomes of the same
    action). For every action the values of its valid children, as returned
    by ``self.compute_value``, are averaged with the children's visit counts
    as weights.

    Args:
        node: The node whose children are compared.
        exploration_constant: Forwarded to ``self.compute_value``.

    Returns:
        The action with the highest weighted mean value, or ``2`` when the
        node has no valid child.
    """
    valid_children = [
        child for child in self.tree.children(node) if not child.invalid
    ]
    if not valid_children:
        # NOTE(review): callers in this file compare the result against -1,
        # which this sentinel never matches; the value is kept unchanged here
        # because switching it would alter the planner's behavior.
        return 2

    visits_per_action = {}  # action -> total visits of its valid children
    weighted_value_per_action = {}  # action -> sum of (value * visits)
    for child in valid_children:
        value = self.compute_value(node, child, exploration_constant)
        visits_per_action[child.action] = (
            visits_per_action.get(child.action, 0) + child.num_visits
        )
        weighted_value_per_action[child.action] = (
            weighted_value_per_action.get(child.action, 0)
            + value * child.num_visits
        )

    # Highest mean wins; ties go to the action encountered first.
    return max(
        visits_per_action,
        key=lambda act: weighted_value_per_action[act] / visits_per_action[act],
    )
333
+
334
+ # only changes the environment to make sure the actions which are already a part of the plan have been executed.
335
def execute_partial_plan(self, plan):
    """Replay ``plan`` in the environment and follow it through the tree.

    Only the environment is advanced; the tree is not modified. After each
    step the matching tree node is looked up by the hash of (parent position,
    parent direction, action, resulting position, resulting direction).

    Args:
        plan: Sequence of actions that were already committed to.

    Returns:
        ``(node, True)`` with the node reached after the last action, or
        ``(None, False)`` as soon as a step reports termination.
    """
    current = self.tree.root
    for action in plan:
        # The env must actually be stepped: nodes do not carry enough
        # information to restore its state.
        state, _reward, terminated, _truncated, _info = self.env.step(action)
        if terminated:
            return None, False
        child_key = hash(
            (
                tuple(current.pos),
                current.state["direction"],
                action,
                tuple(self.env.unwrapped.agent_pos),
                state["direction"],
            )
        )
        current = self.tree.nodes[child_key]
    return current, True
356
+
357
+ # finds the ultimate path from the root node to a terminal state (the one that maximized rewards)
358
def tree_policy(self, root_depth):
    """Descend from the current root until a node to evaluate is found.

    At each node: if the tree reports the node as expandable, a new child is
    expanded and returned. Otherwise the best action is chosen and stepped
    through ``self.stochastic_action``; if the outcome matches an existing
    child (same position and same action) the descent continues from it,
    otherwise a new child is created for that outcome and returned.

    Args:
        root_depth: Depth assigned to the current root.

    Returns:
        ``(node, depth)``: the selected or newly created node and its depth.
    """
    current = self.tree.root
    depth = root_depth
    while not current.terminal and not current.invalid:
        depth += 1

        # Expansion: an untried action exists, so try it right away.
        if self.tree.is_expandable(current):
            return self.expand(current, depth), depth

        # Selection: balance exploration and exploitation on the way down.
        chosen = self.best_action(current, exploration_constant=1 / sqrt(2.0))
        if chosen == -1:
            break

        # The env must be stepped to learn the outcome; nodes don't store it.
        state, reward, terminated, truncated, _ = self.env.step(
            self.stochastic_action(chosen)
        )

        # Children reached by this action that sit at the agent's new position.
        matching_ids = []
        for child_id in current.children_identifiers:
            same_position = all(
                a == b
                for a, b in zip(
                    self.tree.nodes[child_id].pos, self.env.unwrapped.agent_pos
                )
            )
            if same_position and self.tree.nodes[child_id].action == chosen:
                matching_ids.append(child_id)

        if not matching_ids:
            # A known action led to an outcome not yet stored under this
            # node. The id includes the action so the same state can be
            # reached from different parent/action pairs.
            new_id = hash(
                (
                    tuple(current.pos),
                    current.state["direction"],
                    chosen,
                    tuple(self.env.unwrapped.agent_pos),
                    state["direction"],
                )
            )
            created = self.expand_selection_stochastic_node(
                current,
                new_id,
                terminated,
                truncated,
                reward,
                chosen,
                state,
                depth,
            )
            return created, depth

        assert len(matching_ids) == 1
        current = self.tree.nodes[matching_ids[0]]
    return current, depth
418
+
419
+ # receives a final state node and updates the rewards of all the nodes on the path to the root
420
def backpropagation(self, node, value):
    """Propagate a simulation result from ``node`` up to the current root.

    Every node on the path, the current root included, gets one more visit,
    ``value`` added to its accumulated reward, and its mean performance
    recomputed.

    Args:
        node: The node the simulation started from.
        value: The simulation result to add along the path.
    """
    # Stop once we step above the current root.
    stop_at = self.tree.parent(self.tree.root)
    current = node
    while current != stop_at:
        # Reaching None means the walk passed the real root without meeting
        # the current root, i.e. some node on the path has an unexpected parent.
        assert current is not None
        current.num_visits += 1
        current.total_simulation_reward += value
        current.performance = (
            current.total_simulation_reward / current.num_visits
        )
        current = self.tree.parent(current)
429
+
430
def generate_full_policy_sequence(
    self, env_name, problem_name, save_fig=False, fig_path=None, env_prop=None
):
    """Turn ``self.plan`` into a trace of ((state, position), action) pairs.

    Starting at the tree root, each planned action is followed to the most
    visited child carrying that action.

    Args:
        env_name: Forwarded to ``save_figure`` when ``save_fig`` is set.
        problem_name: Forwarded to ``save_figure`` when ``save_fig`` is set.
        save_fig: Whether to save a figure of the trace.
        fig_path: Target path; must be given iff ``save_fig`` is set.
        env_prop: Forwarded to ``save_figure`` when ``save_fig`` is set.

    Returns:
        A list of ``((state, position), action)`` tuples, one per plan step.
    """
    print("generating policy sequence.")
    trace = []
    current = self.tree.root
    previous = current
    for action in self.plan:
        print(
            f"position {tuple(current.pos)} direction {dict_dir_id_to_str[current.state['direction']]}, action {dict_action_id_to_str[action]}"
        )
        # Several children may share the action, one per observed outcome.
        same_action_children = [
            child for child in self.tree.children(current) if child.action == action
        ]
        assert len(same_action_children) > 0
        # The most visited child is taken as the outcome of the action.
        current = max(same_action_children, key=lambda child: child.num_visits)
        # The action is stored on the child, so pair it with the state and
        # position of the node it was taken from.
        trace.append(((previous.state, tuple(previous.pos)), current.action))
        previous = current

    if not save_fig:
        assert fig_path is None
        return trace

    assert fig_path is not None
    save_figure(trace, env_name, problem_name, fig_path, env_prop)
    return trace
457
+
458
+
459
def save_model_and_generate_policy(
    tree, original_root, model_file_path, monteCarloTreeSearch
):
    """Restore the tree's root and pickle the planner to ``model_file_path``.

    Args:
        tree: Tree whose ``root`` is reset to ``original_root``.
        original_root: The root to restore.
        model_file_path: Destination file for the pickled planner.
        monteCarloTreeSearch: Planner object to serialize; its ``env``
            attribute is set to ``None`` and stays that way afterwards.
    """
    tree.root = original_root
    with open(model_file_path, "wb") as model_file:
        # The env holds lambdas, which pickle cannot serialize, so drop it.
        monteCarloTreeSearch.env = None
        pickle.dump(monteCarloTreeSearch, model_file)
249
468
 
250
469
 
251
470
def plan(env_name, problem_name, goal, save_fig=False, fig_path=None, env_prop=None):
    """
    Plan a path using Monte Carlo Tree Search (MCTS) algorithm.

    If a pickled planner already exists for this environment/problem it is
    loaded and its stored plan is turned into a trace. Otherwise a new search
    is run, the planner is pickled, and the trace of the found plan is returned.

    Args:
        env_name (str): Name of the environment.
        problem_name (str): Name of the problem.
        goal (tuple): Goal state to reach.
        save_fig (bool): Flag to save the figure of the plan.
        fig_path (str): Path to save the figure.
        env_prop: Object with methods to create the sequence image.

    Returns:
        The trace produced by ``generate_full_policy_sequence``.
    """
    global newely_expanded
    model_dir = get_agent_model_dir(
        env_name=env_name, model_name=problem_name, class_name="MCTS"
    )
    model_file_path = os.path.join(model_dir, "mcts_model.pth")
    if os.path.exists(model_file_path):
        print(f"Loading pre-existing mcts planner in {model_file_path}")
        loaded_with_rename = False
        # NOTE(security): pickle executes arbitrary code on load; only model
        # files produced by this package should ever be placed at this path.
        with open(model_file_path, "rb") as file:  # Load the pre-existing model
            try:
                monteCarloTreeSearch = pickle.load(file)
            except Exception:
                # Deliberate best-effort fallback: retry with module paths
                # starting with "ml" remapped under the "gr_libs." package.
                class RenameUnpickler(pickle.Unpickler):
                    def find_class(self, module, name):
                        renamed_module = module
                        if module.startswith("ml"):
                            renamed_module = "gr_libs." + renamed_module
                        return super().find_class(renamed_module, name)

                file.seek(0)
                monteCarloTreeSearch = RenameUnpickler(file).load()
                loaded_with_rename = True

        if loaded_with_rename:
            # Re-save only after the read handle is closed, so the file is
            # never truncated while it is still open for reading.
            with open(model_file_path, "wb") as file:
                pickle.dump(monteCarloTreeSearch, file)

        return monteCarloTreeSearch.generate_full_policy_sequence(
            env_name, problem_name, save_fig, fig_path, env_prop
        )

    # If we reached here, the model doesn't exist. Make sure its folder exists.
    os.makedirs(model_dir, exist_ok=True)
    steps = 10000
    print(
        f"No tree found. Executing MCTS, starting with {steps} rollouts for each action."
    )
    env = gym.make(id=problem_name)
    random.seed(2)
    tree = Tree()
    mcts = MonteCarloTreeSearch(env=env, tree=tree, goal=goal)
    original_root = tree.root
    depth = 0
    # Iterate until the root is a terminal state, meaning the game is over.
    while not tree.root.terminal:
        max_reward = 0
        iteration = 0
        steps = max(2000, int(steps * 0.9))
        print(f"Executing {steps} rollouts for each action now.")
        tq = tqdm(
            range(steps),
            postfix=f"Iteration: {iteration}, Num of steps: {len(mcts.plan)}. depth: {depth}. Max reward: {max_reward}. plan to {tuple(env.unwrapped.agent_pos)}, newely expanded: {0}",
        )
        for n in tq:
            iteration = n
            mcts.env.reset()
            depth = len(mcts.plan)
            # Return to the original root before replaying the partial plan:
            # the replay starts from tree.root, which changes between iterations.
            mcts.tree.root = original_root
            node, result = mcts.execute_partial_plan(mcts.plan)
            if not result:
                # The environment terminated while replaying the plan, so the
                # plan is finished: mark the root terminal, save and return.
                tree.root.terminal = True
                save_model_and_generate_policy(
                    tree=tree,
                    original_root=original_root,
                    model_file_path=model_file_path,
                    monteCarloTreeSearch=mcts,
                )
                return mcts.generate_full_policy_sequence(
                    env_name, problem_name, save_fig, fig_path, env_prop
                )
            plan_pos, plan_dir = node.pos, dict_dir_id_to_str[node.state["direction"]]
            # The node reached after the plan is the root for this iteration.
            tree.root = node
            # Find a node to evaluate by expanding or descending the tree.
            node, depth = mcts.tree_policy(root_depth=depth)
            # For a terminal node, simulation returns its reward immediately.
            reward = mcts.simulation(node)
            if reward > max_reward:
                max_reward = reward
            # Update the performance of the nodes on the way up to the root.
            mcts.backpropagation(node, reward)
            tq.set_postfix_str(
                f"Iteration: {iteration}, Num of steps: {len(mcts.plan)}. depth: {depth}. Max reward: {max_reward}. plan to {tuple(plan_pos)}, looking {plan_dir}. newely expanded: {newely_expanded}"
            )
        # Commit the best action from the current root and start from it next time.
        newely_expanded = 0
        action = mcts.best_action(node=tree.root, exploration_constant=0)
        mcts.plan.append(action)
        print(f"Executed action {action}")
    # Save the planner that was just built. The original passed an unbound
    # name here, which raised NameError on this path.
    save_model_and_generate_policy(
        tree=tree,
        original_root=original_root,
        model_file_path=model_file_path,
        monteCarloTreeSearch=mcts,
    )
    return mcts.generate_full_policy_sequence(
        env_name, problem_name, save_fig, fig_path, env_prop
    )