gr_libs-0.1.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evaluation/analyze_results_cross_alg_cross_domain.py +277 -0
- evaluation/create_minigrid_map_image.py +34 -0
- evaluation/file_system.py +42 -0
- evaluation/generate_experiments_results.py +92 -0
- evaluation/generate_experiments_results_new_ver1.py +254 -0
- evaluation/generate_experiments_results_new_ver2.py +331 -0
- evaluation/generate_task_specific_statistics_plots.py +272 -0
- evaluation/get_plans_images.py +47 -0
- evaluation/increasing_and_decreasing_.py +63 -0
- gr_libs/__init__.py +2 -0
- gr_libs/environment/__init__.py +0 -0
- gr_libs/environment/environment.py +227 -0
- gr_libs/environment/utils/__init__.py +0 -0
- gr_libs/environment/utils/utils.py +17 -0
- gr_libs/metrics/__init__.py +0 -0
- gr_libs/metrics/metrics.py +224 -0
- gr_libs/ml/__init__.py +6 -0
- gr_libs/ml/agent.py +56 -0
- gr_libs/ml/base/__init__.py +1 -0
- gr_libs/ml/base/rl_agent.py +54 -0
- gr_libs/ml/consts.py +22 -0
- gr_libs/ml/neural/__init__.py +3 -0
- gr_libs/ml/neural/deep_rl_learner.py +395 -0
- gr_libs/ml/neural/utils/__init__.py +2 -0
- gr_libs/ml/neural/utils/dictlist.py +33 -0
- gr_libs/ml/neural/utils/penv.py +57 -0
- gr_libs/ml/planner/__init__.py +0 -0
- gr_libs/ml/planner/mcts/__init__.py +0 -0
- gr_libs/ml/planner/mcts/mcts_model.py +330 -0
- gr_libs/ml/planner/mcts/utils/__init__.py +2 -0
- gr_libs/ml/planner/mcts/utils/node.py +33 -0
- gr_libs/ml/planner/mcts/utils/tree.py +102 -0
- gr_libs/ml/sequential/__init__.py +1 -0
- gr_libs/ml/sequential/lstm_model.py +192 -0
- gr_libs/ml/tabular/__init__.py +3 -0
- gr_libs/ml/tabular/state.py +21 -0
- gr_libs/ml/tabular/tabular_q_learner.py +453 -0
- gr_libs/ml/tabular/tabular_rl_agent.py +126 -0
- gr_libs/ml/utils/__init__.py +6 -0
- gr_libs/ml/utils/env.py +7 -0
- gr_libs/ml/utils/format.py +100 -0
- gr_libs/ml/utils/math.py +13 -0
- gr_libs/ml/utils/other.py +24 -0
- gr_libs/ml/utils/storage.py +127 -0
- gr_libs/recognizer/__init__.py +0 -0
- gr_libs/recognizer/gr_as_rl/__init__.py +0 -0
- gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +102 -0
- gr_libs/recognizer/graml/__init__.py +0 -0
- gr_libs/recognizer/graml/gr_dataset.py +134 -0
- gr_libs/recognizer/graml/graml_recognizer.py +266 -0
- gr_libs/recognizer/recognizer.py +46 -0
- gr_libs/recognizer/utils/__init__.py +1 -0
- gr_libs/recognizer/utils/format.py +13 -0
- gr_libs-0.1.3.dist-info/METADATA +197 -0
- gr_libs-0.1.3.dist-info/RECORD +62 -0
- gr_libs-0.1.3.dist-info/WHEEL +5 -0
- gr_libs-0.1.3.dist-info/top_level.txt +3 -0
- tutorials/graml_minigrid_tutorial.py +30 -0
- tutorials/graml_panda_tutorial.py +32 -0
- tutorials/graml_parking_tutorial.py +38 -0
- tutorials/graml_point_maze_tutorial.py +43 -0
- tutorials/graql_minigrid_tutorial.py +29 -0
gr_libs/ml/tabular/state.py
@@ -0,0 +1,21 @@
+from abc import ABC
+
+
+class TabularState(ABC):
+    def __init__(self,
+                 agent_x_position: int,
+                 agent_y_position: int,
+                 agent_direction: int
+                 ):
+        self._agent_x_position = agent_x_position
+        self._agent_y_position = agent_y_position
+        self._agent_direction = agent_direction
+
+    @staticmethod
+    def gen_tabular_state(environment, observation):
+        x, y = environment.unwrapped.agent_pos
+        direction = observation['direction']
+        return TabularState(agent_x_position=x, agent_y_position=y, agent_direction=direction)
+
+    def __str__(self):
+        return f"({self._agent_x_position},{self._agent_y_position}):{self._agent_direction}"
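The file above defines the hashable key type used by the tabular agents later in this diff: a state is reduced to the agent's (x, y) cell and facing direction, and its string form "(x,y):direction" is what indexes the Q-table. A minimal usage sketch, assuming the minigrid package is installed so the (illustrative) environment id below is registered:

    import gymnasium as gym
    import minigrid  # registers the MiniGrid-* environment ids (assumed installed)
    from gr_libs.ml.tabular import TabularState

    # Build a state by hand and inspect its Q-table key.
    state = TabularState(agent_x_position=1, agent_y_position=3, agent_direction=0)
    print(str(state))  # -> "(1,3):0"

    # Or derive one from a live environment and its latest observation.
    env = gym.make("MiniGrid-Empty-5x5-v0")  # illustrative env id, not from gr_libs
    observation, _ = env.reset()
    state = TabularState.gen_tabular_state(environment=env, observation=observation)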
gr_libs/ml/tabular/tabular_q_learner.py
@@ -0,0 +1,453 @@
+# Don't import stuff from metrics! it's a higher level module.
+import os.path
+import pickle
+import random
+from types import MethodType
+
+import dill
+from gymnasium import register
+import numpy as np
+
+from tqdm import tqdm
+from typing import Any
+from random import Random
+from typing import List, Iterable
+from gymnasium.error import InvalidAction
+from gr_libs.environment.environment import QLEARNING, MinigridProperty
+from gr_libs.ml.tabular import TabularState
+from gr_libs.ml.tabular.tabular_rl_agent import TabularRLAgent
+from gr_libs.ml.utils import get_agent_model_dir, random_subset_with_order, softmax
+
+
+class TabularQLearner(TabularRLAgent):
+    """
+    A simple Tabular Q-Learning agent.
+    """
+
+    MODEL_FILE_NAME = r"tabular_model.txt"
+    CONF_FILE = r"conf.pkl"
+
+    def __init__(self,
+                 domain_name: str,
+                 problem_name: str,
+                 algorithm: str,
+                 num_timesteps: int,
+                 decaying_eps: bool = True,
+                 eps: float = 1.0,
+                 alpha: float = 0.5,
+                 decay: float = 0.000002,
+                 gamma: float = 0.9,
+                 rand: Random = Random(),
+                 learning_rate: float = 0.001,
+                 check_partial_goals: bool = True,
+                 valid_only: bool = False
+                 ):
+        super().__init__(
+            domain_name=domain_name,
+            problem_name=problem_name,
+            episodes=num_timesteps,
+            decaying_eps=decaying_eps,
+            eps=eps,
+            alpha=alpha,
+            decay=decay,
+            gamma=gamma,
+            rand=rand,
+            learning_rate=learning_rate
+        )
+        assert algorithm == QLEARNING, f"algorithm {algorithm} is not supported by {self.__class__.__name__}"
+        self.valid_only = valid_only
+        self.check_partial_goals = check_partial_goals
+        self.goal_literals_achieved = set()
+        self.model_directory = get_agent_model_dir(domain_name=domain_name, model_name=problem_name, class_name=self.class_name())
+        self.model_file_path = os.path.join(self.model_directory, TabularQLearner.MODEL_FILE_NAME)
+        self._conf_file = os.path.join(self.model_directory, TabularQLearner.CONF_FILE)
+
+        self._learned_episodes = 0
+
+        if os.path.exists(self.model_file_path):
+            print(f"Loading pre-existing model in {self.model_file_path}")
+            self.load_q_table(path=self.model_file_path)
+        else:
+            print(f"Creating new model in {self.model_file_path}")
+        if os.path.exists(self._conf_file):
+            print(f"Loading pre-existing conf file in {self._conf_file}")
+            with open(self._conf_file, "rb") as f:
+                conf = dill.load(file=f)
+                self._learned_episodes = conf['learned_episodes']
+
+        # hyperparameters
+        self.base_eps = eps
+        self.patience = 400000
+        if self.decaying_eps:
+            def epsilon():
+                self._c_eps = max((self.episodes - self.step) / self.episodes, 0.01)
+                return self._c_eps
+
+            self.eps = epsilon
+        else:
+            self.eps = lambda: eps
+        self.decaying_eps = decaying_eps
+        self.alpha = alpha
+        self.last_state = None
+        self.last_action = None
+
+    def states_in_q(self) -> Iterable:
+        """Returns the states stored in the q_values table
+
+        Returns:
+            List: The states for which we have a mapping in the q-table
+        """
+        return self.q_table.keys()
+
+    def policy(self, state: TabularState) -> Any:
+        """Returns the greedy deterministic policy for the specified state
+
+        Args:
+            state (State): the state for which we want the action
+
+        Raises:
+            InvalidAction: Not sure about this one
+
+        Returns:
+            Any: The greedy action learned for state
+        """
+        return self.best_action(state)
+
+    def epsilon_greedy_policy(self, state: TabularState) -> Any:
+        eps = self.eps()
+        if self._random.random() <= eps:
+            action = self._random.randint(0, self.number_of_actions - 1)
+        else:
+            action = self.policy(state)
+        return action
+
+    def softmax_policy(self, state: TabularState) -> np.array:
+        """Returns a softmax policy over the q-value returns stored in the q-table
+
+        Args:
+            state (State): the state for which we want a softmax policy
+
+        Returns:
+            np.array: probability of taking each action in self.actions given a state
+        """
+        if str(state) not in self.q_table:
+            self.add_new_state(state)
+            # If we query a state we have not visited, return a uniform distribution
+            # return softmax([0]*self.actions)
+        return softmax(self.q_table[str(state)])
+
+    def save_q_table(self, path: str):
+        # sadly, this does not work, because the state we are using
+        # is a frozenset of literals, which are not serializable.
+        # a way to fix this is to use array states built using
+        # common_functions.build_state
+
+        directory = os.path.dirname(path)
+        if not os.path.exists(directory):
+            os.makedirs(directory)
+
+        with open(path, 'wb') as f:
+            pickle.dump(self.q_table, f)
+
+    def load_q_table(self, path: str):
+        with open(path, 'rb') as f:
+            table = pickle.load(f)
+        self.q_table = table
+
+    def add_new_state(self, state: TabularState):
+        self.q_table[str(state)] = [0.] * self.number_of_actions
+
+    def get_all_q_values(self, state: TabularState) -> List[float]:
+        if str(state) in self.q_table:
+            return self.q_table[str(state)]
+        else:
+            return [0.] * self.number_of_actions
+
+    def best_action(self, state: TabularState) -> float:
+        if str(state) not in self.q_table:
+            self.add_new_state(state)
+        return np.argmax(self.q_table[str(state)])
+
+    def get_max_q(self, state) -> float:
+        if str(state) not in self.q_table:
+            self.add_new_state(state)
+        return np.max(self.q_table[str(state)])
+
+    def set_q_value(self, state: TabularState, action: Any, q_value: float):
+        if str(state) not in self.q_table:
+            self.add_new_state(state)
+        self.q_table[str(state)][action] = q_value
+
+    def get_q_value(self, state: TabularState, action: Any) -> float:
+        if str(state) not in self.q_table:
+            self.add_new_state(state)
+        return self.q_table[str(state)][action]
+
+    def agent_start(self, state: TabularState) -> int:
+        """The first method called when the experiment starts,
+        called after the environment starts.
+        Args:
+            state (Numpy array): the state from the
+                environment's env_start function.
+        Returns:
+            (int) the first action the agent takes.
+        """
+        self.last_state = state
+        self.last_action = self.policy(state)
+        return self.last_action
+
+    def agent_step(self, reward: float, state: TabularState) -> int:
+        """A step taken by the agent.
+
+        Args:
+            reward (float): the reward received for taking the last action taken
+            state (Any): the state from the
+                environment's step based on where the agent ended up after the
+                last step
+        Returns:
+            (int) The action the agent takes given this state.
+        """
+        max_q = self.get_max_q(state)
+        old_q = self.get_q_value(self.last_state, self.last_action)
+
+        td_error = self.gamma * max_q - old_q
+        new_q = old_q + self.alpha * (reward + td_error)
+
+        self.set_q_value(self.last_state, self.last_action, new_q)
+        # action = self.best_action(state)
+        action = self.epsilon_greedy_policy(state)
+        self.last_state = state
+        self.last_action = action
+        return action
+
+    def agent_end(self, reward: float) -> Any:
+        """Called when the agent terminates.
+
+        Args:
+            reward (float): the reward the agent received for entering the
+                terminal state.
+        """
+        old_q = self.get_q_value(self.last_state, self.last_action)
+
+        td_error = - old_q
+
+        new_q = old_q + self.alpha * (reward + td_error)
+        self.set_q_value(self.last_state, self.last_action, new_q)
+
+    def learn(self, init_threshold: int = 20):
+        tsteps = 2000
+        done_times = 0
+        patience = 0
+        converged_at = None
+        max_r = float("-inf")
+        print(f"{self._learned_episodes}->{self.episodes}")
+        if self._learned_episodes >= self.episodes:
+            print("learned episodes is above the requsted episodes")
+            return
+        print(f'Using {self.__class__.__name__}')
+        tq = tqdm(range(self.episodes - self._learned_episodes),
+                  postfix=f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}")
+        for n in tq:
+            self.step = n
+            episode_r = 0
+            observation, info = self.env.reset()
+            tabular_state = TabularState.gen_tabular_state(environment=self.env, observation=observation)
+            action = self.agent_start(state=tabular_state)
+
+            self.update_states_counter(observation_str=str(tabular_state))
+            done = False
+            tstep = 0
+            while tstep < tsteps and not done:
+                observation, reward, terminated, truncated, _ = self.env.step(action)
+                done = terminated | truncated
+                if done:
+                    done_times += 1
+
+                # standard q-learning algorithm
+                next_tabular_state = TabularState.gen_tabular_state(environment=self.env, observation=observation)
+                self.update_states_counter(observation_str=str(next_tabular_state))
+                action = self.agent_step(reward, next_tabular_state)
+                tstep += 1
+                episode_r += reward
+            self._learned_episodes = self._learned_episodes + 1
+            if done:  # One last update at the terminal state
+                self.agent_end(reward)
+
+            if episode_r > max_r:
+                max_r = episode_r
+                # print("New all time high reward:", episode_r)
+                tq.set_postfix_str(
+                    f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}")
+            if (n + 1) % 100 == 0:
+                tq.set_postfix_str(
+                    f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}")
+            if (n + 1) % 1000 == 0:
+                tq.set_postfix_str(
+                    f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}")
+                if done_times <= 10:
+                    patience += 1
+                    if patience >= self.patience:
+                        print(f"Did not find goal after {n} episodes. Retrying.")
+                        raise InvalidAction("Did not learn")
+                else:
+                    patience = 0
+                if done_times == 1000 and converged_at is not None:
+                    converged_at = n
+                    print(f"***Policy converged to goal at {converged_at}***")
+                done_times = 0
+                self.goal_literals_achieved.clear()
+
+        print(f"number of unique states found during training:{self.get_number_of_unique_states()}")
+        print("finish learning and saving status")
+        self.save_models_to_files()
+
+    def exploit(self, number_of_steps=20):
+        observation, info = self.env.reset()
+        for step_number in range(number_of_steps):
+            tabular_state = TabularState.gen_tabular_state(environment=self.env, observation=observation)
+            action = self.policy(state=tabular_state)
+            observation, reward, terminated, truncated, _ = self.env.step(action)
+            done = terminated | truncated
+            if done:
+                print(f"reached goal after {step_number + 1} steps!")
+                break
+
+    def get_actions_probabilities(self, observation):
+        obs, agent_pos = observation
+        direction = obs['direction']
+
+        x, y = agent_pos
+        tabular_state = TabularState(agent_x_position=x, agent_y_position=y, agent_direction=direction)
+        return softmax(self.get_all_q_values(tabular_state))
+
+    def get_q_of_specific_cell(self, cell_key):
+        cell_q_table = {}
+        for i in range(4):
+            key = cell_key + ':' + str(i)
+            if key in self.q_table:
+                cell_q_table[key] = self.q_table[key]
+        return cell_q_table
+
+    def get_all_cells(self):
+        cells = set()
+        for key in self.q_table.keys():
+            cell = key.split(':')[0]
+            cells.add(cell)
+        return list(cells)
+
+
+    def _save_conf_file(self):
+        conf = {
+            'learned_episodes': self._learned_episodes,
+            'states_counter': self.states_counter
+        }
+        with open(self._conf_file, "wb") as f:
+            dill.dump(conf, f)
+
+    def save_models_to_files(self):
+        self.save_q_table(path=self.model_file_path)
+        self._save_conf_file()
+
+    def simplify_observation(self, observation):
+        return [(obs['direction'], agent_pos_x, agent_pos_y, action) for ((obs, (agent_pos_x, agent_pos_y)), action) in observation]  # list of tuples, each tuple the sample
+
+    def generate_observation(self, action_selection_method: MethodType, random_optimalism, save_fig = False, fig_path: str=None, env_prop=None):
+        """
+        Generate a single observation given a list of agents
+
+        Args:
+            agents (list): A list of agents from which to select one randomly.
+            action_selection_method : a MethodType, to generate the observation stochastically, greedily, or softmax.
+
+        Returns:
+            list: A list of state-action pairs representing the generated observation.
+
+        Notes:
+            The function randomly selects an agent from the given list and generates a sequence of state-action pairs
+            based on the Q-table of the selected agent. The action selection is stochastic, where each action is
+            selected based on the probability distribution defined by the Q-values in the Q-table.
+
+            The generated sequence terminates when a maximum number of steps is reached or when the environment
+            episode terminates.
+        """
+        if save_fig == False:
+            assert fig_path == None, "You can't specify a vid path when you don't even save the figure."
+        else:
+            assert fig_path != None, "You must specify a vid path when you save the figure."
+        obs, _ = self.env.reset()
+        MAX_STEPS = 32
+        done = False
+        steps = []
+        for step_index in range(MAX_STEPS):
+            x, y = self.env.unwrapped.agent_pos
+            str_state = "({},{}):{}".format(x, y, obs['direction'])
+            relevant_actions_idx = 3
+            action_probs = self.q_table[str_state][:relevant_actions_idx] / np.sum(self.q_table[str_state][:relevant_actions_idx])  # Normalize probabilities
+            if step_index == 0 and random_optimalism:
+                # print("in 1st step in generating plan and got random optimalism.")
+                std_dev = np.std(action_probs)
+                # uniques_sorted = np.unique(action_probs)
+                num_of_stds = abs(action_probs[0] - action_probs[2]) / std_dev
+                if num_of_stds < 2.1:
+                    # sorted_indices = np.argsort(action_probs)
+                    # action = np.random.choice([sorted_indices[-1], sorted_indices[-2]])
+                    action = np.random.choice([0, 2])
+                    if action == 0:
+                        steps.append(((obs, self.env.unwrapped.agent_pos), action))
+                        obs, reward, terminated, truncated, info = self.env.step(action)
+                        assert reward >= 0
+                        action = 2
+                        step_index += 1
+                else: action = action_selection_method(action_probs)
+            else:
+                action = action_selection_method(action_probs)
+            steps.append(((obs, self.env.unwrapped.agent_pos), action))
+            obs, reward, terminated, truncated, info = self.env.step(action)
+            assert reward >= 0
+            done = terminated | truncated
+            if done:
+                break
+
+        #assert len(steps) >= 2
+        if save_fig:
+            sequence = [pos for ((state, pos), action) in steps]
+            #print(f"sequence to {self.problem_name} is:\n\t{steps}\ngenerating image at {img_path}.")
+            print(f"generating sequence image at {fig_path}.")
+            env_prop.create_sequence_image(sequence, fig_path, self.problem_name)  # TODO change that assumption, cannot assume this is minigrid env
+
+        return steps
+
+    def generate_partial_observation(self, action_selection_method: MethodType, percentage: float, save_fig = False, is_consecutive = True, random_optimalism=True, fig_path=None):
+        """
+        Generate a single observation given a list of agents
+
+        Args:
+            agents (list): A list of agents from which to select one randomly.
+            action_selection_method : a MethodType, to generate the observation stochastically, greedily, or softmax.
+
+        Returns:
+            list: A list of state-action pairs representing the generated observation.
+
+        Notes:
+            The function randomly selects an agent from the given list and generates a sequence of state-action pairs
+            based on the Q-table of the selected agent. The action selection is stochastic, where each action is
+            selected based on the probability distribution defined by the Q-values in the Q-table.
+
+            The generated sequence terminates when a maximum number of steps is reached or when the environment
+            episode terminates.
+        """
+
+        steps = self.generate_observation(action_selection_method=action_selection_method, random_optimalism=random_optimalism, save_fig=save_fig, fig_path=fig_path)  # steps are a full observation
+        result = random_subset_with_order(steps, (int)(percentage * len(steps)), is_consecutive)
+        if percentage >= 0.8:
+            assert len(result) > 2
+        return result
+
+if __name__ == "__main__":
+    from gr_libs.metrics.metrics import greedy_selection
+    import gr_envs  # to register everything
+    agent = TabularQLearner(domain_name="minigrid", problem_name="MiniGrid-LavaCrossingS9N2-DynamicGoal-1x7-v0")
+    agent.generate_observation(greedy_selection, True, True)
+
+# python experiments.py --recognizer graml --domain point_maze --task L5 --partial_obs_type continuing --point_maze_env obstacles --collect_stats --inference_same_seq_len
+
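The agent above applies the standard tabular update in agent_step, Q(s,a) <- Q(s,a) + alpha * (reward + gamma * max_a' Q(s',a') - Q(s,a)), with agent_end dropping the bootstrap term for the terminal transition. A minimal training sketch built only from the signatures shown in this file; it assumes gr_envs is installed so the MiniGrid problem id used in the module's own __main__ block is registered, and the timestep count is illustrative:

    import gr_envs  # registers the custom MiniGrid problems (assumed installed)
    from gr_libs.environment.environment import QLEARNING
    from gr_libs.metrics.metrics import greedy_selection
    from gr_libs.ml.tabular.tabular_q_learner import TabularQLearner

    agent = TabularQLearner(
        domain_name="minigrid",
        problem_name="MiniGrid-LavaCrossingS9N2-DynamicGoal-1x7-v0",
        algorithm=QLEARNING,
        num_timesteps=100_000,  # illustrative; stored as self.episodes
    )
    agent.learn()                      # trains, then pickles the Q-table and conf file
    agent.exploit(number_of_steps=20)  # greedy rollout from a fresh reset
    steps = agent.generate_observation(
        action_selection_method=greedy_selection,
        random_optimalism=True,
    )  # list of ((obs, agent_pos), action) pairs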
gr_libs/ml/tabular/tabular_rl_agent.py
@@ -0,0 +1,126 @@
+import gymnasium as gym
+from abc import abstractmethod
+from typing import Collection, Literal, Any
+from random import Random
+import numpy as np
+
+from gr_libs.ml.base import RLAgent
+from gr_libs.ml.base import State
+
+
+class TabularRLAgent(RLAgent):
+    """
+    This is a base class used as parent class for any
+    RL agent. This is currently not much in use, but is
+    recommended as development goes on.
+    """
+
+    def __init__(self,
+                 domain_name: str,
+                 problem_name: str,
+                 episodes: int,
+                 decaying_eps: bool,
+                 eps: float,
+                 alpha: float,
+                 decay: float,
+                 gamma: float,
+                 rand: Random,
+                 learning_rate
+                 ):
+        super().__init__(
+            episodes=episodes,
+            decaying_eps=decaying_eps,
+            epsilon=eps,
+            learning_rate=learning_rate,
+            gamma=gamma,
+            domain_name=domain_name,
+            problem_name=problem_name
+        )
+        self.env = gym.make(id=problem_name)
+        self.actions = self.env.unwrapped.actions
+        self.number_of_actions = len(self.actions)
+        self._actions_space = self.env.action_space
+        self._random = rand
+        self._alpha = alpha
+        self._decay = decay
+        self._c_eps = eps
+        self.q_table = {}
+
+        # TODO:: maybe need to save env.reset output
+        self.env.reset()
+
+    @abstractmethod
+    def agent_start(self, state) -> Any:
+        """The first method called when the experiment starts,
+        called after the environment starts.
+        Args:
+            state (Numpy array): the state from the
+                environment's env_start function.
+        Returns:
+            (int) the first action the agent takes.
+        """
+        pass
+
+    @abstractmethod
+    def agent_step(self, reward: float, state: State) -> Any:
+        """A step taken by the agent.
+        Args:
+            reward (float): the reward received for taking the last action taken
+            state (Any): the state observation from the
+                environment's step based, where the agent ended up after the
+                last step
+        Returns:
+            The action the agent is taking.
+        """
+        pass
+
+    @abstractmethod
+    def agent_end(self, reward: float) -> Any:
+        """Called when the agent terminates.
+
+        Args:
+            reward (float): the reward the agent received for entering the
+                terminal state.
+        """
+        pass
+
+    @abstractmethod
+    def policy(self, state: State) -> Any:
+        """The action for the specified state under the currently learned policy
+        (unlike agent_step, this does not update the policy using state as a sample.
+        Args:
+            state (Any): the state observation from the environment
+        Returns:
+            The action prescribed for that state
+        """
+        pass
+
+    @abstractmethod
+    def softmax_policy(self, state: State) -> np.array:
+        """Returns a softmax policy over the q-value returns stored in the q-table
+
+        Args:
+            state (State): the state for which we want a softmax policy
+
+        Returns:
+            np.array: probability of taking each action in self.actions given a state
+        """
+        pass
+
+    @abstractmethod
+    def learn(self, init_threshold: int = 20):
+        pass
+
+    def __getitem__(self, state: State) -> Any:
+        """[summary]
+
+        Args:
+            state (Any): The state for which we want to get the policy
+
+        Raises:
+            InvalidAction: [description]
+
+        Returns:
+            Any: [description]
+        """""
+        return self.softmax_policy(state)
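TabularRLAgent fixes the interface that TabularQLearner fills in: concrete subclasses supply the update rule (agent_start / agent_step / agent_end), action selection (policy / softmax_policy), and the training loop (learn), while this base class builds the gymnasium environment and the empty q_table. A bare-bones skeleton of a custom subclass, shown only to indicate which methods must be overridden; the method bodies below are placeholders rather than gr_libs code, and the chosen problem_name must resolve to an environment whose unwrapped form exposes .actions, since the base constructor relies on it:

    from typing import Any

    import numpy as np

    from gr_libs.ml.base import State
    from gr_libs.ml.tabular.tabular_rl_agent import TabularRLAgent


    class MyTabularAgent(TabularRLAgent):
        def agent_start(self, state) -> Any:
            ...  # choose and remember the first action

        def agent_step(self, reward: float, state: State) -> Any:
            ...  # apply the update rule, then pick the next action

        def agent_end(self, reward: float) -> Any:
            ...  # final update for the terminal transition

        def policy(self, state: State) -> Any:
            ...  # greedy action for `state` from self.q_table

        def softmax_policy(self, state: State) -> np.array:
            ...  # action distribution over self.actions for `state`

        def learn(self, init_threshold: int = 20):
            ...  # episode loop over self.env, in the spirit of TabularQLearner.learn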