gr-libs 0.1.7.post0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they were published to a supported public registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
- gr_libs/__init__.py +4 -1
- gr_libs/_evaluation/__init__.py +1 -0
- gr_libs/_evaluation/_analyze_results_cross_alg_cross_domain.py +260 -0
- gr_libs/_evaluation/_generate_experiments_results.py +141 -0
- gr_libs/_evaluation/_generate_task_specific_statistics_plots.py +497 -0
- gr_libs/_evaluation/_get_plans_images.py +61 -0
- gr_libs/_evaluation/_increasing_and_decreasing_.py +106 -0
- gr_libs/_version.py +2 -2
- gr_libs/all_experiments.py +294 -0
- gr_libs/environment/__init__.py +30 -9
- gr_libs/environment/_utils/utils.py +27 -0
- gr_libs/environment/environment.py +417 -54
- gr_libs/metrics/__init__.py +7 -0
- gr_libs/metrics/metrics.py +231 -54
- gr_libs/ml/__init__.py +2 -5
- gr_libs/ml/agent.py +21 -6
- gr_libs/ml/base/__init__.py +3 -1
- gr_libs/ml/base/rl_agent.py +81 -13
- gr_libs/ml/consts.py +1 -1
- gr_libs/ml/neural/__init__.py +1 -3
- gr_libs/ml/neural/deep_rl_learner.py +619 -378
- gr_libs/ml/neural/utils/__init__.py +1 -2
- gr_libs/ml/neural/utils/dictlist.py +3 -3
- gr_libs/ml/planner/mcts/{utils → _utils}/__init__.py +1 -1
- gr_libs/ml/planner/mcts/{utils → _utils}/node.py +11 -7
- gr_libs/ml/planner/mcts/{utils → _utils}/tree.py +15 -11
- gr_libs/ml/planner/mcts/mcts_model.py +571 -312
- gr_libs/ml/sequential/__init__.py +0 -1
- gr_libs/ml/sequential/_lstm_model.py +270 -0
- gr_libs/ml/tabular/__init__.py +1 -3
- gr_libs/ml/tabular/state.py +7 -7
- gr_libs/ml/tabular/tabular_q_learner.py +150 -82
- gr_libs/ml/tabular/tabular_rl_agent.py +42 -28
- gr_libs/ml/utils/__init__.py +2 -3
- gr_libs/ml/utils/format.py +28 -97
- gr_libs/ml/utils/math.py +5 -3
- gr_libs/ml/utils/other.py +3 -3
- gr_libs/ml/utils/storage.py +88 -81
- gr_libs/odgr_executor.py +268 -0
- gr_libs/problems/consts.py +1549 -1227
- gr_libs/recognizer/_utils/__init__.py +0 -0
- gr_libs/recognizer/_utils/format.py +18 -0
- gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +233 -88
- gr_libs/recognizer/graml/_gr_dataset.py +233 -0
- gr_libs/recognizer/graml/graml_recognizer.py +586 -252
- gr_libs/recognizer/recognizer.py +90 -30
- gr_libs/tutorials/draco_panda_tutorial.py +58 -0
- gr_libs/tutorials/draco_parking_tutorial.py +56 -0
- gr_libs/tutorials/gcdraco_panda_tutorial.py +62 -0
- gr_libs/tutorials/gcdraco_parking_tutorial.py +57 -0
- gr_libs/tutorials/graml_minigrid_tutorial.py +64 -0
- gr_libs/tutorials/graml_panda_tutorial.py +57 -0
- gr_libs/tutorials/graml_parking_tutorial.py +52 -0
- gr_libs/tutorials/graml_point_maze_tutorial.py +60 -0
- gr_libs/tutorials/graql_minigrid_tutorial.py +50 -0
- {gr_libs-0.1.7.post0.dist-info → gr_libs-0.2.2.dist-info}/METADATA +84 -29
- gr_libs-0.2.2.dist-info/RECORD +71 -0
- {gr_libs-0.1.7.post0.dist-info → gr_libs-0.2.2.dist-info}/WHEEL +1 -1
- gr_libs-0.2.2.dist-info/top_level.txt +2 -0
- tests/test_draco.py +14 -0
- tests/test_gcdraco.py +10 -0
- tests/test_graml.py +12 -8
- tests/test_graql.py +3 -2
- evaluation/analyze_results_cross_alg_cross_domain.py +0 -277
- evaluation/create_minigrid_map_image.py +0 -34
- evaluation/file_system.py +0 -42
- evaluation/generate_experiments_results.py +0 -92
- evaluation/generate_experiments_results_new_ver1.py +0 -254
- evaluation/generate_experiments_results_new_ver2.py +0 -331
- evaluation/generate_task_specific_statistics_plots.py +0 -272
- evaluation/get_plans_images.py +0 -47
- evaluation/increasing_and_decreasing_.py +0 -63
- gr_libs/environment/utils/utils.py +0 -17
- gr_libs/ml/neural/utils/penv.py +0 -57
- gr_libs/ml/sequential/lstm_model.py +0 -192
- gr_libs/recognizer/graml/gr_dataset.py +0 -134
- gr_libs/recognizer/utils/__init__.py +0 -1
- gr_libs/recognizer/utils/format.py +0 -13
- gr_libs-0.1.7.post0.dist-info/RECORD +0 -67
- gr_libs-0.1.7.post0.dist-info/top_level.txt +0 -4
- tutorials/graml_minigrid_tutorial.py +0 -34
- tutorials/graml_panda_tutorial.py +0 -41
- tutorials/graml_parking_tutorial.py +0 -39
- tutorials/graml_point_maze_tutorial.py +0 -39
- tutorials/graql_minigrid_tutorial.py +0 -34
- /gr_libs/environment/{utils → _utils}/__init__.py +0 -0
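Several of the renames above move helper packages behind leading-underscore names (`evaluation/` → `gr_libs/_evaluation/`, `recognizer/utils` → `recognizer/_utils`, `mcts/utils` → `mcts/_utils`) and relocate the tutorials from a top-level `tutorials/` directory into `gr_libs/tutorials/`. A minimal sketch of how import paths shift, assuming the modules are importable exactly as listed in the file tree (old paths shown as comments):

```python
# Import-path sketch for the 0.2.2 layout; paths are taken from the file
# listing above, not from package documentation.

# 0.1.7.post0 (top-level tutorials/, public helper subpackages):
#   import tutorials.graml_minigrid_tutorial
#   import gr_libs.recognizer.utils.format
#   import gr_libs.ml.planner.mcts.utils.node

# 0.2.2 (tutorials shipped inside the package, helpers made private):
import gr_libs.tutorials.graml_minigrid_tutorial
import gr_libs.recognizer._utils.format
import gr_libs.ml.planner.mcts._utils.node
```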
gr_libs/ml/tabular/tabular_q_learner.py
CHANGED
@@ -1,19 +1,18 @@
-
+""" implementation of q-learning """
+
 import os.path
 import pickle
-import
+from collections.abc import Iterable
+from random import Random
 from types import MethodType
+from typing import Any

 import dill
-from gymnasium import register
 import numpy as np
-
-from tqdm import tqdm
-from typing import Any
-from random import Random
-from typing import List, Iterable
 from gymnasium.error import InvalidAction
-from
+from tqdm import tqdm
+
+from gr_libs.environment.environment import QLEARNING, EnvProperty
 from gr_libs.ml.tabular import TabularState
 from gr_libs.ml.tabular.tabular_rl_agent import TabularRLAgent
 from gr_libs.ml.utils import get_agent_model_dir, random_subset_with_order, softmax
@@ -27,21 +26,42 @@ class TabularQLearner(TabularRLAgent):
 MODEL_FILE_NAME = r"tabular_model.txt"
 CONF_FILE = r"conf.pkl"

-def __init__(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+def __init__(
+self,
+domain_name: str,
+problem_name: str,
+env_prop: EnvProperty,
+algorithm: str,
+num_timesteps: int,
+decaying_eps: bool = True,
+eps: float = 1.0,
+alpha: float = 0.5,
+decay: float = 0.000002,
+gamma: float = 0.9,
+rand: Random = Random(),
+learning_rate: float = 0.001,
+check_partial_goals: bool = True,
+valid_only: bool = False,
+):
+"""
+Initialize a TabularQLearner object.
+
+Args:
+domain_name (str): The name of the domain.
+problem_name (str): The name of the problem.
+env_prop (EnvProperty): The environment properties.
+algorithm (str): The algorithm to use.
+num_timesteps (int): The number of timesteps.
+decaying_eps (bool, optional): Whether to use decaying epsilon. Defaults to True.
+eps (float, optional): The initial epsilon value. Defaults to 1.0.
+alpha (float, optional): The learning rate. Defaults to 0.5.
+decay (float, optional): The decay rate. Defaults to 0.000002.
+gamma (float, optional): The discount factor. Defaults to 0.9.
+rand (Random, optional): The random number generator. Defaults to Random().
+learning_rate (float, optional): The learning rate. Defaults to 0.001.
+check_partial_goals (bool, optional): Whether to check partial goals. Defaults to True.
+valid_only (bool, optional): Whether to use valid goals only. Defaults to False.
+"""
 super().__init__(
 domain_name=domain_name,
 problem_name=problem_name,
@@ -52,14 +72,23 @@ class TabularQLearner(TabularRLAgent):
 decay=decay,
 gamma=gamma,
 rand=rand,
-learning_rate=learning_rate
+learning_rate=learning_rate,
 )
-assert
+assert (
+algorithm == QLEARNING
+), f"algorithm {algorithm} is not supported by {self.__class__.__name__}"
+self.env_prop = env_prop
 self.valid_only = valid_only
 self.check_partial_goals = check_partial_goals
 self.goal_literals_achieved = set()
-self.model_directory = get_agent_model_dir(
-
+self.model_directory = get_agent_model_dir(
+domain_name=domain_name,
+model_name=problem_name,
+class_name=self.class_name(),
+)
+self.model_file_path = os.path.join(
+self.model_directory, TabularQLearner.MODEL_FILE_NAME
+)
 self._conf_file = os.path.join(self.model_directory, TabularQLearner.CONF_FILE)

 self._learned_episodes = 0
@@ -73,12 +102,13 @@ class TabularQLearner(TabularRLAgent):
 print(f"Loading pre-existing conf file in {self._conf_file}")
 with open(self._conf_file, "rb") as f:
 conf = dill.load(file=f)
-self._learned_episodes = conf[
+self._learned_episodes = conf["learned_episodes"]

 # hyperparameters
 self.base_eps = eps
 self.patience = 400000
 if self.decaying_eps:
+
 def epsilon():
 self._c_eps = max((self.episodes - self.step) / self.episodes, 0.01)
 return self._c_eps
@@ -146,22 +176,22 @@ class TabularQLearner(TabularRLAgent):
 if not os.path.exists(directory):
 os.makedirs(directory)

-with open(path,
+with open(path, "wb") as f:
 pickle.dump(self.q_table, f)

 def load_q_table(self, path: str):
-with open(path,
+with open(path, "rb") as f:
 table = pickle.load(f)
 self.q_table = table

 def add_new_state(self, state: TabularState):
-self.q_table[str(state)] = [0.] * self.number_of_actions
+self.q_table[str(state)] = [0.0] * self.number_of_actions

-def get_all_q_values(self, state: TabularState) ->
+def get_all_q_values(self, state: TabularState) -> list[float]:
 if str(state) in self.q_table:
 return self.q_table[str(state)]
 else:
-return [0.] * self.number_of_actions
+return [0.0] * self.number_of_actions

 def best_action(self, state: TabularState) -> float:
 if str(state) not in self.q_table:
@@ -229,7 +259,7 @@ class TabularQLearner(TabularRLAgent):
 """
 old_q = self.get_q_value(self.last_state, self.last_action)

-td_error = -
+td_error = -old_q

 new_q = old_q + self.alpha * (reward + td_error)
 self.set_q_value(self.last_state, self.last_action, new_q)
@@ -244,14 +274,18 @@ class TabularQLearner(TabularRLAgent):
 if self._learned_episodes >= self.episodes:
 print("learned episodes is above the requsted episodes")
 return
-print(f
-tq = tqdm(
-
+print(f"Using {self.__class__.__name__}")
+tq = tqdm(
+range(self.episodes - self._learned_episodes),
+postfix=f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}",
+)
 for n in tq:
 self.step = n
 episode_r = 0
 observation, info = self.env.reset()
-tabular_state = TabularState.gen_tabular_state(
+tabular_state = TabularState.gen_tabular_state(
+environment=self.env, observation=observation
+)
 action = self.agent_start(state=tabular_state)

 self.update_states_counter(observation_str=str(tabular_state))
@@ -264,7 +298,9 @@ class TabularQLearner(TabularRLAgent):
 done_times += 1

 # standard q-learning algorithm
-next_tabular_state = TabularState.gen_tabular_state(
+next_tabular_state = TabularState.gen_tabular_state(
+environment=self.env, observation=observation
+)
 self.update_states_counter(observation_str=str(next_tabular_state))
 action = self.agent_step(reward, next_tabular_state)
 tstep += 1
@@ -277,13 +313,16 @@ class TabularQLearner(TabularRLAgent):
 max_r = episode_r
 # print("New all time high reward:", episode_r)
 tq.set_postfix_str(
-f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}"
+f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}"
+)
 if (n + 1) % 100 == 0:
 tq.set_postfix_str(
-f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}"
+f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}"
+)
 if (n + 1) % 1000 == 0:
 tq.set_postfix_str(
-f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}"
+f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}"
+)
 if done_times <= 10:
 patience += 1
 if patience >= self.patience:
@@ -297,14 +336,18 @@ class TabularQLearner(TabularRLAgent):
 done_times = 0
 self.goal_literals_achieved.clear()

-print(
+print(
+f"number of unique states found during training:{self.get_number_of_unique_states()}"
+)
 print("finish learning and saving status")
 self.save_models_to_files()

 def exploit(self, number_of_steps=20):
 observation, info = self.env.reset()
 for step_number in range(number_of_steps):
-tabular_state = TabularState.gen_tabular_state(
+tabular_state = TabularState.gen_tabular_state(
+environment=self.env, observation=observation
+)
 action = self.policy(state=tabular_state)
 observation, reward, terminated, truncated, _ = self.env.step(action)
 done = terminated | truncated
@@ -314,16 +357,18 @@ class TabularQLearner(TabularRLAgent):

 def get_actions_probabilities(self, observation):
 obs, agent_pos = observation
-direction = obs[
+direction = obs["direction"]

 x, y = agent_pos
-tabular_state = TabularState(
+tabular_state = TabularState(
+agent_x_position=x, agent_y_position=y, agent_direction=direction
+)
 return softmax(self.get_all_q_values(tabular_state))

 def get_q_of_specific_cell(self, cell_key):
 cell_q_table = {}
 for i in range(4):
-key = cell_key +
+key = cell_key + ":" + str(i)
 if key in self.q_table:
 cell_q_table[key] = self.q_table[key]
 return cell_q_table
@@ -331,15 +376,14 @@ class TabularQLearner(TabularRLAgent):
 def get_all_cells(self):
 cells = set()
 for key in self.q_table.keys():
-cell = key.split(
+cell = key.split(":")[0]
 cells.add(cell)
 return list(cells)

-
 def _save_conf_file(self):
 conf = {
-
-
+"learned_episodes": self._learned_episodes,
+"states_counter": self.states_counter,
 }
 with open(self._conf_file, "wb") as f:
 dill.dump(conf, f)
@@ -347,11 +391,20 @@ class TabularQLearner(TabularRLAgent):
 def save_models_to_files(self):
 self.save_q_table(path=self.model_file_path)
 self._save_conf_file()
-
+
 def simplify_observation(self, observation):
-return [
-
-
+return [
+(obs["direction"], agent_pos_x, agent_pos_y, action)
+for ((obs, (agent_pos_x, agent_pos_y)), action) in observation
+] # list of tuples, each tuple the sample
+
+def generate_observation(
+self,
+action_selection_method: MethodType,
+random_optimalism,
+save_fig=False,
+fig_path: str = None,
+):
 """
 Generate a single observation given a list of agents

@@ -363,26 +416,32 @@ class TabularQLearner(TabularRLAgent):
 list: A list of state-action pairs representing the generated observation.

 Notes:
-The function randomly selects an agent from the given list and generates a sequence of state-action pairs
-based on the Q-table of the selected agent. The action selection is stochastic, where each action is
+The function randomly selects an agent from the given list and generates a sequence of state-action pairs
+based on the Q-table of the selected agent. The action selection is stochastic, where each action is
 selected based on the probability distribution defined by the Q-values in the Q-table.

-The generated sequence terminates when a maximum number of steps is reached or when the environment
+The generated sequence terminates when a maximum number of steps is reached or when the environment
 episode terminates.
 """
 if save_fig == False:
-assert
+assert (
+fig_path == None
+), "You can't specify a vid path when you don't even save the figure."
 else:
-assert
+assert (
+fig_path != None
+), "You must specify a vid path when you save the figure."
 obs, _ = self.env.reset()
 MAX_STEPS = 32
 done = False
 steps = []
 for step_index in range(MAX_STEPS):
 x, y = self.env.unwrapped.agent_pos
-str_state = "({},{}):{}".format(x, y, obs[
+str_state = "({},{}):{}".format(x, y, obs["direction"])
 relevant_actions_idx = 3
-action_probs = self.q_table[str_state][:relevant_actions_idx] / np.sum(
+action_probs = self.q_table[str_state][:relevant_actions_idx] / np.sum(
+self.q_table[str_state][:relevant_actions_idx]
+) # Normalize probabilities
 if step_index == 0 and random_optimalism:
 # print("in 1st step in generating plan and got random optimalism.")
 std_dev = np.std(action_probs)
@@ -398,7 +457,8 @@ class TabularQLearner(TabularRLAgent):
 assert reward >= 0
 action = 2
 step_index += 1
-else:
+else:
+action = action_selection_method(action_probs)
 else:
 action = action_selection_method(action_probs)
 steps.append(((obs, self.env.unwrapped.agent_pos), action))
@@ -408,16 +468,26 @@ class TabularQLearner(TabularRLAgent):
 if done:
 break

-#assert len(steps) >= 2
+# assert len(steps) >= 2
 if save_fig:
 sequence = [pos for ((state, pos), action) in steps]
-#print(f"sequence to {self.problem_name} is:\n\t{steps}\ngenerating image at {img_path}.")
+# print(f"sequence to {self.problem_name} is:\n\t{steps}\ngenerating image at {img_path}.")
 print(f"generating sequence image at {fig_path}.")
-env_prop.create_sequence_image(
+self.env_prop.create_sequence_image(
+sequence, fig_path, self.problem_name
+) # TODO change that assumption, cannot assume this is minigrid env

 return steps

-def generate_partial_observation(
+def generate_partial_observation(
+self,
+action_selection_method: MethodType,
+percentage: float,
+save_fig=False,
+is_consecutive=True,
+random_optimalism=True,
+fig_path=None,
+):
 """
 Generate a single observation given a list of agents

@@ -429,25 +499,23 @@ class TabularQLearner(TabularRLAgent):
 list: A list of state-action pairs representing the generated observation.

 Notes:
-The function randomly selects an agent from the given list and generates a sequence of state-action pairs
-based on the Q-table of the selected agent. The action selection is stochastic, where each action is
+The function randomly selects an agent from the given list and generates a sequence of state-action pairs
+based on the Q-table of the selected agent. The action selection is stochastic, where each action is
 selected based on the probability distribution defined by the Q-values in the Q-table.

-The generated sequence terminates when a maximum number of steps is reached or when the environment
+The generated sequence terminates when a maximum number of steps is reached or when the environment
 episode terminates.
 """

-steps = self.generate_observation(
-
+steps = self.generate_observation(
+action_selection_method=action_selection_method,
+random_optimalism=random_optimalism,
+save_fig=save_fig,
+fig_path=fig_path,
+) # steps are a full observation
+result = random_subset_with_order(
+steps, (int)(percentage * len(steps)), is_consecutive
+)
 if percentage >= 0.8:
 assert len(result) > 2
 return result
-
-if __name__ == "__main__":
-from gr_libs.metrics.metrics import greedy_selection
-import gr_envs # to register everything
-agent = TabularQLearner(domain_name="minigrid", problem_name="MiniGrid-LavaCrossingS9N2-DynamicGoal-1x7-v0")
-agent.generate_observation(greedy_selection, True, True)
-
-# python experiments.py --recognizer graml --domain point_maze --task L5 --partial_obs_type continuing --point_maze_env obstacles --collect_stats --inference_same_seq_len
-
gr_libs/ml/tabular/tabular_rl_agent.py
CHANGED
@@ -1,11 +1,11 @@
-import gymnasium as gym
 from abc import abstractmethod
-from typing import Collection, Literal, Any
 from random import Random
+from typing import Any
+
+import gymnasium as gym
 import numpy as np

-from gr_libs.ml.base import RLAgent
-from gr_libs.ml.base import State
+from gr_libs.ml.base import RLAgent, State


 class TabularRLAgent(RLAgent):
@@ -15,18 +15,37 @@ class TabularRLAgent(RLAgent):
 recommended as development goes on.
 """

-def __init__(
-
-
-
-
-
-
-
-
-
-
-
+def __init__(
+self,
+domain_name: str,
+problem_name: str,
+episodes: int,
+decaying_eps: bool,
+eps: float,
+alpha: float,
+decay: float,
+gamma: float,
+rand: Random,
+learning_rate,
+):
+"""
+Initializes a TabularRLAgent object.
+
+Args:
+domain_name (str): The name of the domain.
+problem_name (str): The name of the problem.
+episodes (int): The number of episodes to run.
+decaying_eps (bool): Whether to use decaying epsilon.
+eps (float): The initial epsilon value.
+alpha (float): The learning rate.
+decay (float): The decay rate for epsilon.
+gamma (float): The discount factor.
+rand (Random): The random number generator.
+learning_rate: The learning rate.
+
+Returns:
+None
+"""
 super().__init__(
 episodes=episodes,
 decaying_eps=decaying_eps,
@@ -34,7 +53,7 @@ class TabularRLAgent(RLAgent):
 learning_rate=learning_rate,
 gamma=gamma,
 domain_name=domain_name,
-problem_name=problem_name
+problem_name=problem_name,
 )
 self.env = gym.make(id=problem_name)
 self.actions = self.env.unwrapped.actions
@@ -59,7 +78,6 @@ class TabularRLAgent(RLAgent):
 Returns:
 (int) the first action the agent takes.
 """
-pass

 @abstractmethod
 def agent_step(self, reward: float, state: State) -> Any:
@@ -72,7 +90,6 @@ class TabularRLAgent(RLAgent):
 Returns:
 The action the agent is taking.
 """
-pass

 @abstractmethod
 def agent_end(self, reward: float) -> Any:
@@ -82,18 +99,16 @@ class TabularRLAgent(RLAgent):
 reward (float): the reward the agent received for entering the
 terminal state.
 """
-pass

 @abstractmethod
 def policy(self, state: State) -> Any:
 """The action for the specified state under the currently learned policy
-
-
-
-
-
+(unlike agent_step, this does not update the policy using state as a sample.
+Args:
+state (Any): the state observation from the environment
+Returns:
+The action prescribed for that state
 """
-pass

 @abstractmethod
 def softmax_policy(self, state: State) -> np.array:
@@ -105,7 +120,6 @@ class TabularRLAgent(RLAgent):
 Returns:
 np.array: probability of taking each action in self.actions given a state
 """
-pass

 @abstractmethod
 def learn(self, init_threshold: int = 20):
@@ -122,5 +136,5 @@ class TabularRLAgent(RLAgent):

 Returns:
 Any: [description]
-"""""
+""" ""
 return self.softmax_policy(state)
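Across these hunks, `TabularRLAgent` keeps the same abstract surface but drops the redundant `pass` statements from its docstring-only abstract methods and gains a documented constructor. A skeleton of what a concrete subclass still has to provide, with signatures copied from the hunks above (method bodies are placeholders, not package code; `agent_start` is shown without annotations because the diff does not include them):

```python
from typing import Any

import numpy as np

from gr_libs.ml.base import State
from gr_libs.ml.tabular.tabular_rl_agent import TabularRLAgent


class MyTabularAgent(TabularRLAgent):
    def agent_start(self, state):  # first action of an episode
        raise NotImplementedError

    def agent_step(self, reward: float, state: State) -> Any:
        raise NotImplementedError

    def agent_end(self, reward: float) -> Any:
        raise NotImplementedError

    def policy(self, state: State) -> Any:
        raise NotImplementedError

    def softmax_policy(self, state: State) -> np.array:
        raise NotImplementedError

    def learn(self, init_threshold: int = 20):
        raise NotImplementedError
```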
gr_libs/ml/utils/__init__.py
CHANGED
@@ -1,6 +1,5 @@
-#from .agent import *
 from .env import make_env
-from .format import
+from .format import random_subset_with_order
+from .math import softmax
 from .other import device, seed, synthesize
 from .storage import *
-from .math import softmax