gr-libs 0.1.7.post0__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- evaluation/analyze_results_cross_alg_cross_domain.py +236 -246
- evaluation/create_minigrid_map_image.py +10 -6
- evaluation/file_system.py +16 -5
- evaluation/generate_experiments_results.py +123 -74
- evaluation/generate_experiments_results_new_ver1.py +227 -243
- evaluation/generate_experiments_results_new_ver2.py +317 -317
- evaluation/generate_task_specific_statistics_plots.py +481 -253
- evaluation/get_plans_images.py +41 -26
- evaluation/increasing_and_decreasing_.py +97 -56
- gr_libs/__init__.py +2 -1
- gr_libs/_version.py +2 -2
- gr_libs/environment/__init__.py +16 -8
- gr_libs/environment/environment.py +167 -39
- gr_libs/environment/utils/utils.py +22 -12
- gr_libs/metrics/__init__.py +5 -0
- gr_libs/metrics/metrics.py +76 -34
- gr_libs/ml/__init__.py +2 -0
- gr_libs/ml/agent.py +21 -6
- gr_libs/ml/base/__init__.py +1 -1
- gr_libs/ml/base/rl_agent.py +13 -10
- gr_libs/ml/consts.py +1 -1
- gr_libs/ml/neural/deep_rl_learner.py +433 -352
- gr_libs/ml/neural/utils/__init__.py +1 -1
- gr_libs/ml/neural/utils/dictlist.py +3 -3
- gr_libs/ml/neural/utils/penv.py +5 -2
- gr_libs/ml/planner/mcts/mcts_model.py +524 -302
- gr_libs/ml/planner/mcts/utils/__init__.py +1 -1
- gr_libs/ml/planner/mcts/utils/node.py +11 -7
- gr_libs/ml/planner/mcts/utils/tree.py +14 -10
- gr_libs/ml/sequential/__init__.py +1 -1
- gr_libs/ml/sequential/lstm_model.py +256 -175
- gr_libs/ml/tabular/state.py +7 -7
- gr_libs/ml/tabular/tabular_q_learner.py +123 -73
- gr_libs/ml/tabular/tabular_rl_agent.py +20 -19
- gr_libs/ml/utils/__init__.py +8 -2
- gr_libs/ml/utils/format.py +78 -70
- gr_libs/ml/utils/math.py +2 -1
- gr_libs/ml/utils/other.py +1 -1
- gr_libs/ml/utils/storage.py +88 -28
- gr_libs/problems/consts.py +1549 -1227
- gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +145 -80
- gr_libs/recognizer/graml/gr_dataset.py +209 -110
- gr_libs/recognizer/graml/graml_recognizer.py +431 -240
- gr_libs/recognizer/recognizer.py +38 -27
- gr_libs/recognizer/utils/__init__.py +1 -1
- gr_libs/recognizer/utils/format.py +8 -3
- {gr_libs-0.1.7.post0.dist-info → gr_libs-0.1.8.dist-info}/METADATA +1 -1
- gr_libs-0.1.8.dist-info/RECORD +70 -0
- {gr_libs-0.1.7.post0.dist-info → gr_libs-0.1.8.dist-info}/WHEEL +1 -1
- tests/test_gcdraco.py +10 -0
- tests/test_graml.py +8 -4
- tests/test_graql.py +2 -1
- tutorials/gcdraco_panda_tutorial.py +66 -0
- tutorials/gcdraco_parking_tutorial.py +61 -0
- tutorials/graml_minigrid_tutorial.py +42 -12
- tutorials/graml_panda_tutorial.py +35 -14
- tutorials/graml_parking_tutorial.py +37 -20
- tutorials/graml_point_maze_tutorial.py +33 -13
- tutorials/graql_minigrid_tutorial.py +31 -15
- gr_libs-0.1.7.post0.dist-info/RECORD +0 -67
- {gr_libs-0.1.7.post0.dist-info → gr_libs-0.1.8.dist-info}/top_level.txt +0 -0
gr_libs/ml/tabular/tabular_q_learner.py
CHANGED
@@ -13,7 +13,7 @@ from typing import Any
 from random import Random
 from typing import List, Iterable
 from gymnasium.error import InvalidAction
-from gr_libs.environment.environment import QLEARNING,
+from gr_libs.environment.environment import QLEARNING, EnvProperty
 from gr_libs.ml.tabular import TabularState
 from gr_libs.ml.tabular.tabular_rl_agent import TabularRLAgent
 from gr_libs.ml.utils import get_agent_model_dir, random_subset_with_order, softmax
@@ -27,21 +27,23 @@ class TabularQLearner(TabularRLAgent):
     MODEL_FILE_NAME = r"tabular_model.txt"
     CONF_FILE = r"conf.pkl"

-    def __init__(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def __init__(
+        self,
+        domain_name: str,
+        problem_name: str,
+        env_prop: EnvProperty,
+        algorithm: str,
+        num_timesteps: int,
+        decaying_eps: bool = True,
+        eps: float = 1.0,
+        alpha: float = 0.5,
+        decay: float = 0.000002,
+        gamma: float = 0.9,
+        rand: Random = Random(),
+        learning_rate: float = 0.001,
+        check_partial_goals: bool = True,
+        valid_only: bool = False,
+    ):
         super().__init__(
             domain_name=domain_name,
             problem_name=problem_name,
@@ -52,14 +54,23 @@ class TabularQLearner(TabularRLAgent):
             decay=decay,
             gamma=gamma,
             rand=rand,
-            learning_rate=learning_rate
+            learning_rate=learning_rate,
         )
-        assert
+        assert (
+            algorithm == QLEARNING
+        ), f"algorithm {algorithm} is not supported by {self.__class__.__name__}"
+        self.env_prop = env_prop
         self.valid_only = valid_only
         self.check_partial_goals = check_partial_goals
         self.goal_literals_achieved = set()
-        self.model_directory = get_agent_model_dir(
-
+        self.model_directory = get_agent_model_dir(
+            domain_name=domain_name,
+            model_name=problem_name,
+            class_name=self.class_name(),
+        )
+        self.model_file_path = os.path.join(
+            self.model_directory, TabularQLearner.MODEL_FILE_NAME
+        )
         self._conf_file = os.path.join(self.model_directory, TabularQLearner.CONF_FILE)

         self._learned_episodes = 0
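For orientation, a minimal sketch of how the reworked constructor might now be called. The domain and problem names are taken from the `__main__` block removed later in this diff; `MinigridProperty`, its constructor, the timestep count, and the call as a whole are assumptions for illustration, not a verified gr_libs invocation.

```python
from random import Random

# QLEARNING and EnvProperty come from the import change above; MinigridProperty is assumed here.
from gr_libs.environment.environment import QLEARNING, MinigridProperty
from gr_libs.ml.tabular.tabular_q_learner import TabularQLearner

problem = "MiniGrid-LavaCrossingS9N2-DynamicGoal-1x7-v0"
agent = TabularQLearner(
    domain_name="minigrid",
    problem_name=problem,
    env_prop=MinigridProperty(problem),  # assumed constructor; any EnvProperty instance for the domain
    algorithm=QLEARNING,                 # anything else now trips the added assert
    num_timesteps=300_000,               # illustrative value
    rand=Random(42),
)
```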
@@ -73,12 +84,13 @@ class TabularQLearner(TabularRLAgent):
             print(f"Loading pre-existing conf file in {self._conf_file}")
             with open(self._conf_file, "rb") as f:
                 conf = dill.load(file=f)
-                self._learned_episodes = conf[
+                self._learned_episodes = conf["learned_episodes"]

         # hyperparameters
         self.base_eps = eps
         self.patience = 400000
         if self.decaying_eps:
+
             def epsilon():
                 self._c_eps = max((self.episodes - self.step) / self.episodes, 0.01)
                 return self._c_eps
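The `epsilon()` closure above implements a linear decay over the episode index with a floor of 0.01. A standalone illustration of the schedule (the episode count here is made up):

```python
episodes = 100_000

def epsilon(step: int) -> float:
    # mirrors: self._c_eps = max((self.episodes - self.step) / self.episodes, 0.01)
    return max((episodes - step) / episodes, 0.01)

print(epsilon(0))       # 1.0  -> fully exploratory at the start
print(epsilon(75_000))  # 0.25
print(epsilon(99_500))  # 0.01 -> clamped at the floor
```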
@@ -146,22 +158,22 @@ class TabularQLearner(TabularRLAgent):
         if not os.path.exists(directory):
             os.makedirs(directory)

-        with open(path,
+        with open(path, "wb") as f:
             pickle.dump(self.q_table, f)

     def load_q_table(self, path: str):
-        with open(path,
+        with open(path, "rb") as f:
             table = pickle.load(f)
             self.q_table = table

     def add_new_state(self, state: TabularState):
-        self.q_table[str(state)] = [0.] * self.number_of_actions
+        self.q_table[str(state)] = [0.0] * self.number_of_actions

     def get_all_q_values(self, state: TabularState) -> List[float]:
         if str(state) in self.q_table:
             return self.q_table[str(state)]
         else:
-            return [0.] * self.number_of_actions
+            return [0.0] * self.number_of_actions

     def best_action(self, state: TabularState) -> float:
         if str(state) not in self.q_table:
@@ -229,7 +241,7 @@ class TabularQLearner(TabularRLAgent):
         """
         old_q = self.get_q_value(self.last_state, self.last_action)

-        td_error = -
+        td_error = -old_q

         new_q = old_q + self.alpha * (reward + td_error)
         self.set_q_value(self.last_state, self.last_action, new_q)
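With `td_error = -old_q` there is no bootstrapped next-state term, so the update reduces to `new_q = old_q + alpha * (reward - old_q)`: the target is just the immediate reward, which is typically the end-of-episode case. A quick numeric check with made-up values:

```python
alpha = 0.5
old_q = 0.4
reward = 1.0

td_error = -old_q                             # as in the hunk above
new_q = old_q + alpha * (reward + td_error)
print(new_q)                                  # 0.7 == 0.4 + 0.5 * (1.0 - 0.4)
```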
@@ -244,14 +256,18 @@ class TabularQLearner(TabularRLAgent):
         if self._learned_episodes >= self.episodes:
             print("learned episodes is above the requsted episodes")
             return
-        print(f
-        tq = tqdm(
-
+        print(f"Using {self.__class__.__name__}")
+        tq = tqdm(
+            range(self.episodes - self._learned_episodes),
+            postfix=f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}",
+        )
         for n in tq:
             self.step = n
             episode_r = 0
             observation, info = self.env.reset()
-            tabular_state = TabularState.gen_tabular_state(
+            tabular_state = TabularState.gen_tabular_state(
+                environment=self.env, observation=observation
+            )
             action = self.agent_start(state=tabular_state)

             self.update_states_counter(observation_str=str(tabular_state))
@@ -264,7 +280,9 @@ class TabularQLearner(TabularRLAgent):
                     done_times += 1

                 # standard q-learning algorithm
-                next_tabular_state = TabularState.gen_tabular_state(
+                next_tabular_state = TabularState.gen_tabular_state(
+                    environment=self.env, observation=observation
+                )
                 self.update_states_counter(observation_str=str(next_tabular_state))
                 action = self.agent_step(reward, next_tabular_state)
                 tstep += 1
@@ -277,13 +295,16 @@ class TabularQLearner(TabularRLAgent):
                 max_r = episode_r
                 # print("New all time high reward:", episode_r)
                 tq.set_postfix_str(
-                    f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}"
+                    f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}"
+                )
             if (n + 1) % 100 == 0:
                 tq.set_postfix_str(
-                    f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}"
+                    f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}"
+                )
             if (n + 1) % 1000 == 0:
                 tq.set_postfix_str(
-                    f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}"
+                    f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}"
+                )
             if done_times <= 10:
                 patience += 1
                 if patience >= self.patience:
@@ -297,14 +318,18 @@ class TabularQLearner(TabularRLAgent):
                 done_times = 0
                 self.goal_literals_achieved.clear()

-        print(
+        print(
+            f"number of unique states found during training:{self.get_number_of_unique_states()}"
+        )
         print("finish learning and saving status")
         self.save_models_to_files()

     def exploit(self, number_of_steps=20):
         observation, info = self.env.reset()
         for step_number in range(number_of_steps):
-            tabular_state = TabularState.gen_tabular_state(
+            tabular_state = TabularState.gen_tabular_state(
+                environment=self.env, observation=observation
+            )
             action = self.policy(state=tabular_state)
             observation, reward, terminated, truncated, _ = self.env.step(action)
             done = terminated | truncated
@@ -314,16 +339,18 @@ class TabularQLearner(TabularRLAgent):

     def get_actions_probabilities(self, observation):
         obs, agent_pos = observation
-        direction = obs[
+        direction = obs["direction"]

         x, y = agent_pos
-        tabular_state = TabularState(
+        tabular_state = TabularState(
+            agent_x_position=x, agent_y_position=y, agent_direction=direction
+        )
         return softmax(self.get_all_q_values(tabular_state))

     def get_q_of_specific_cell(self, cell_key):
         cell_q_table = {}
         for i in range(4):
-            key = cell_key +
+            key = cell_key + ":" + str(i)
             if key in self.q_table:
                 cell_q_table[key] = self.q_table[key]
         return cell_q_table
@@ -331,15 +358,14 @@ class TabularQLearner(TabularRLAgent):
     def get_all_cells(self):
         cells = set()
         for key in self.q_table.keys():
-            cell = key.split(
+            cell = key.split(":")[0]
             cells.add(cell)
         return list(cells)

-
     def _save_conf_file(self):
         conf = {
-
-
+            "learned_episodes": self._learned_episodes,
+            "states_counter": self.states_counter,
         }
         with open(self._conf_file, "wb") as f:
             dill.dump(conf, f)
@@ -347,11 +373,20 @@ class TabularQLearner(TabularRLAgent):
     def save_models_to_files(self):
         self.save_q_table(path=self.model_file_path)
         self._save_conf_file()
-
+
     def simplify_observation(self, observation):
-        return [
-
-
+        return [
+            (obs["direction"], agent_pos_x, agent_pos_y, action)
+            for ((obs, (agent_pos_x, agent_pos_y)), action) in observation
+        ]  # list of tuples, each tuple the sample
+
+    def generate_observation(
+        self,
+        action_selection_method: MethodType,
+        random_optimalism,
+        save_fig=False,
+        fig_path: str = None,
+    ):
         """
         Generate a single observation given a list of agents

@@ -363,26 +398,32 @@ class TabularQLearner(TabularRLAgent):
             list: A list of state-action pairs representing the generated observation.

         Notes:
-            The function randomly selects an agent from the given list and generates a sequence of state-action pairs
-            based on the Q-table of the selected agent. The action selection is stochastic, where each action is
+            The function randomly selects an agent from the given list and generates a sequence of state-action pairs
+            based on the Q-table of the selected agent. The action selection is stochastic, where each action is
             selected based on the probability distribution defined by the Q-values in the Q-table.

-            The generated sequence terminates when a maximum number of steps is reached or when the environment
+            The generated sequence terminates when a maximum number of steps is reached or when the environment
            episode terminates.
         """
         if save_fig == False:
-            assert
+            assert (
+                fig_path == None
+            ), "You can't specify a vid path when you don't even save the figure."
         else:
-            assert
+            assert (
+                fig_path != None
+            ), "You must specify a vid path when you save the figure."
         obs, _ = self.env.reset()
         MAX_STEPS = 32
         done = False
         steps = []
         for step_index in range(MAX_STEPS):
             x, y = self.env.unwrapped.agent_pos
-            str_state = "({},{}):{}".format(x, y, obs[
+            str_state = "({},{}):{}".format(x, y, obs["direction"])
             relevant_actions_idx = 3
-            action_probs = self.q_table[str_state][:relevant_actions_idx] / np.sum(
+            action_probs = self.q_table[str_state][:relevant_actions_idx] / np.sum(
+                self.q_table[str_state][:relevant_actions_idx]
+            )  # Normalize probabilities
             if step_index == 0 and random_optimalism:
                 # print("in 1st step in generating plan and got random optimalism.")
                 std_dev = np.std(action_probs)
@@ -398,7 +439,8 @@ class TabularQLearner(TabularRLAgent):
                     assert reward >= 0
                     action = 2
                     step_index += 1
-                else:
+                else:
+                    action = action_selection_method(action_probs)
             else:
                 action = action_selection_method(action_probs)
             steps.append(((obs, self.env.unwrapped.agent_pos), action))
@@ -408,16 +450,26 @@ class TabularQLearner(TabularRLAgent):
             if done:
                 break

-        #assert len(steps) >= 2
+        # assert len(steps) >= 2
         if save_fig:
             sequence = [pos for ((state, pos), action) in steps]
-            #print(f"sequence to {self.problem_name} is:\n\t{steps}\ngenerating image at {img_path}.")
+            # print(f"sequence to {self.problem_name} is:\n\t{steps}\ngenerating image at {img_path}.")
             print(f"generating sequence image at {fig_path}.")
-            env_prop.create_sequence_image(
+            self.env_prop.create_sequence_image(
+                sequence, fig_path, self.problem_name
+            )  # TODO change that assumption, cannot assume this is minigrid env

         return steps

-    def generate_partial_observation(
+    def generate_partial_observation(
+        self,
+        action_selection_method: MethodType,
+        percentage: float,
+        save_fig=False,
+        is_consecutive=True,
+        random_optimalism=True,
+        fig_path=None,
+    ):
         """
         Generate a single observation given a list of agents

@@ -429,25 +481,23 @@ class TabularQLearner(TabularRLAgent):
             list: A list of state-action pairs representing the generated observation.

         Notes:
-            The function randomly selects an agent from the given list and generates a sequence of state-action pairs
-            based on the Q-table of the selected agent. The action selection is stochastic, where each action is
+            The function randomly selects an agent from the given list and generates a sequence of state-action pairs
+            based on the Q-table of the selected agent. The action selection is stochastic, where each action is
             selected based on the probability distribution defined by the Q-values in the Q-table.

-            The generated sequence terminates when a maximum number of steps is reached or when the environment
+            The generated sequence terminates when a maximum number of steps is reached or when the environment
             episode terminates.
         """

-        steps = self.generate_observation(
-
+        steps = self.generate_observation(
+            action_selection_method=action_selection_method,
+            random_optimalism=random_optimalism,
+            save_fig=save_fig,
+            fig_path=fig_path,
+        )  # steps are a full observation
+        result = random_subset_with_order(
+            steps, (int)(percentage * len(steps)), is_consecutive
+        )
         if percentage >= 0.8:
             assert len(result) > 2
         return result
-
-if __name__ == "__main__":
-    from gr_libs.metrics.metrics import greedy_selection
-    import gr_envs # to register everything
-    agent = TabularQLearner(domain_name="minigrid", problem_name="MiniGrid-LavaCrossingS9N2-DynamicGoal-1x7-v0")
-    agent.generate_observation(greedy_selection, True, True)
-
-# python experiments.py --recognizer graml --domain point_maze --task L5 --partial_obs_type continuing --point_maze_env obstacles --collect_stats --inference_same_seq_len
-
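`generate_partial_observation` now hands the full trace to `random_subset_with_order`, keeping `int(percentage * len(steps))` state-action pairs either as a consecutive prefix or as a sorted random subset. A self-contained rendition of that truncation (mirroring the helper shown later in this diff rather than importing it):

```python
import random

def random_subset_with_order(sequence, subset_size, is_consecutive=True):
    # mirrors gr_libs.ml.utils.format.random_subset_with_order as it appears in 0.1.8
    if subset_size >= len(sequence):
        return sequence
    if is_consecutive:
        indices = list(range(subset_size))
    else:
        indices = sorted(random.sample(range(len(sequence)), subset_size))
    return [sequence[i] for i in indices]

steps = [(f"s{i}", f"a{i}") for i in range(10)]                       # stand-in for (state, action) pairs
print(random_subset_with_order(steps, int(0.5 * len(steps))))         # first 5 pairs
print(random_subset_with_order(steps, int(0.5 * len(steps)), False))  # 5 random pairs, original order kept
```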
gr_libs/ml/tabular/tabular_rl_agent.py
CHANGED
@@ -15,18 +15,19 @@ class TabularRLAgent(RLAgent):
     recommended as development goes on.
     """

-    def __init__(
-
-
-
-
-
-
-
-
-
-
-
+    def __init__(
+        self,
+        domain_name: str,
+        problem_name: str,
+        episodes: int,
+        decaying_eps: bool,
+        eps: float,
+        alpha: float,
+        decay: float,
+        gamma: float,
+        rand: Random,
+        learning_rate,
+    ):
         super().__init__(
             episodes=episodes,
             decaying_eps=decaying_eps,
@@ -34,7 +35,7 @@ class TabularRLAgent(RLAgent):
             learning_rate=learning_rate,
             gamma=gamma,
             domain_name=domain_name,
-            problem_name=problem_name
+            problem_name=problem_name,
         )
         self.env = gym.make(id=problem_name)
         self.actions = self.env.unwrapped.actions
@@ -87,11 +88,11 @@ class TabularRLAgent(RLAgent):
     @abstractmethod
     def policy(self, state: State) -> Any:
         """The action for the specified state under the currently learned policy
-
-
-
-
-
+        (unlike agent_step, this does not update the policy using state as a sample.
+        Args:
+            state (Any): the state observation from the environment
+        Returns:
+            The action prescribed for that state
         """
         pass

@@ -122,5 +123,5 @@ class TabularRLAgent(RLAgent):

         Returns:
             Any: [description]
-        """""
+        """ ""
         return self.softmax_policy(state)
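The restored docstring spells out the contract: `policy` only reads the learned values, while `agent_step` is the call that updates them. A toy, package-independent sketch of a subclass honoring that split (hypothetical class, for illustration only):

```python
import numpy as np

class GreedyTabularAgent:
    """Toy stand-in for a TabularRLAgent subclass."""

    def __init__(self, n_actions: int):
        self.n_actions = n_actions
        self.q_table = {}

    def policy(self, state) -> int:
        # pure lookup: no learning side effects, as the docstring requires
        q_values = self.q_table.get(str(state), [0.0] * self.n_actions)
        return int(np.argmax(q_values))

    def agent_step(self, reward: float, state) -> int:
        # ...update self.q_table from (reward, state) here, then act...
        return self.policy(state)
```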
gr_libs/ml/utils/__init__.py
CHANGED
@@ -1,6 +1,12 @@
-#from .agent import *
+# from .agent import *
 from .env import make_env
-from .format import
+from .format import (
+    Vocabulary,
+    preprocess_images,
+    preprocess_texts,
+    get_obss_preprocessor,
+    random_subset_with_order,
+)
 from .other import device, seed, synthesize
 from .storage import *
 from .math import softmax
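With the `.format` import now spelled out explicitly, these helpers remain importable from `gr_libs.ml.utils` itself, e.g.:

```python
from gr_libs.ml.utils import (
    Vocabulary,
    get_obss_preprocessor,
    random_subset_with_order,
    softmax,
)
```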
gr_libs/ml/utils/format.py
CHANGED
@@ -5,96 +5,104 @@ import gr_libs.ml
 import gymnasium as gym
 import random

-def get_obss_preprocessor(obs_space):
-    # Check if obs_space is an image space
-    if isinstance(obs_space, gym.spaces.Box):
-        obs_space = {"image": obs_space.shape}

-
-
-
-
+def get_obss_preprocessor(obs_space):
+    # Check if obs_space is an image space
+    if isinstance(obs_space, gym.spaces.Box):
+        obs_space = {"image": obs_space.shape}

-
-
-        obs_space = {"image": obs_space.spaces["image"].shape, "text": 100}
+        def preprocess_obss(obss, device=None):
+            return ml.DictList({"image": preprocess_images(obss, device=device)})

-
+    # Check if it is a MiniGrid observation space
+    elif isinstance(obs_space, gym.spaces.Dict) and "image" in obs_space.spaces.keys():
+        obs_space = {"image": obs_space.spaces["image"].shape, "text": 100}

-
-        return ml.DictList({
-            "image": preprocess_images([obs["image"] for obs in obss], device=device),
-            "text": preprocess_texts([obs["mission"] for obs in obss], vocab, device=device)
-        })
+        vocab = Vocabulary(obs_space["text"])

-
+        def preprocess_obss(obss, device=None):
+            return ml.DictList(
+                {
+                    "image": preprocess_images(
+                        [obs["image"] for obs in obss], device=device
+                    ),
+                    "text": preprocess_texts(
+                        [obs["mission"] for obs in obss], vocab, device=device
+                    ),
+                }
+            )

-
-    elif isinstance(obs_space, gym.spaces.Dict) and "observation" in obs_space.spaces.keys():
-        obs_space = {"observation": obs_space.spaces["observation"].shape}
+        preprocess_obss.vocab = vocab

-
-
-
-
+    # Check if it is a MiniGrid observation space
+    elif (
+        isinstance(obs_space, gym.spaces.Dict)
+        and "observation" in obs_space.spaces.keys()
+    ):
+        obs_space = {"observation": obs_space.spaces["observation"].shape}

+        def preprocess_obss(obss, device=None):
+            return ml.DictList({"observation": preprocess_images(obss, device=device)})

-
-
+    else:
+        raise ValueError("Unknown observation space: " + str(obs_space))

-
+    return obs_space, preprocess_obss


 def preprocess_images(images, device=None):
-
-
-
-
-
-def random_subset_with_order(sequence, subset_size, is_consecutive
-
-
-
-
-
-
-
-
-
+    # Bug of Pytorch: very slow if not first converted to numpy array
+    images = numpy.array(images)
+    return torch.tensor(images, device=device, dtype=torch.float)
+
+
+def random_subset_with_order(sequence, subset_size, is_consecutive=True):
+    if subset_size >= len(sequence):
+        return sequence
+    else:
+        if is_consecutive:
+            indices_to_select = [i for i in range(subset_size)]
+        else:
+            indices_to_select = sorted(
+                random.sample(range(len(sequence)), subset_size)
+            )  # Randomly select indices to keep
+        return [
+            sequence[i] for i in indices_to_select
+        ]  # Return the elements corresponding to the selected indices


 def preprocess_texts(texts, vocab, device=None):
-
-
+    var_indexed_texts = []
+    max_text_len = 0

-
-
-
-
-
+    for text in texts:
+        tokens = re.findall("([a-z]+)", text.lower())
+        var_indexed_text = numpy.array([vocab[token] for token in tokens])
+        var_indexed_texts.append(var_indexed_text)
+        max_text_len = max(len(var_indexed_text), max_text_len)

-
+    indexed_texts = numpy.zeros((len(texts), max_text_len))

-
-
+    for i, indexed_text in enumerate(var_indexed_texts):
+        indexed_texts[i, : len(indexed_text)] = indexed_text

-
+    return torch.tensor(indexed_texts, device=device, dtype=torch.long)


 class Vocabulary:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    """A mapping from tokens to ids with a capacity of `max_size` words.
+    It can be saved in a `vocab.json` file."""
+
+    def __init__(self, max_size):
+        self.max_size = max_size
+        self.vocab = {}
+
+    def load_vocab(self, vocab):
+        self.vocab = vocab
+
+    def __getitem__(self, token):
+        if not token in self.vocab.keys():
+            if len(self.vocab) >= self.max_size:
+                raise ValueError("Maximum vocabulary capacity reached")
+            self.vocab[token] = len(self.vocab) + 1
+        return self.vocab[token]
gr_libs/ml/utils/math.py
CHANGED
@@ -1,6 +1,7 @@
 import math
 from typing import Callable, Generator, List

+
 def softmax(values: List[float]) -> List[float]:
     """Computes softmax probabilities for an array of values
     TODO We should probably use numpy arrays here
@@ -10,4 +11,4 @@ def softmax(values: List[float]) -> List[float]:
     Returns:
         np.array: softmax probabilities
     """
-    return [(math.exp(q)) / sum([math.exp(_q) for _q in values]) for q in values]
+    return [(math.exp(q)) / sum([math.exp(_q) for _q in values]) for q in values]
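A quick numeric check of the softmax above (plain Python, no gr_libs import needed):

```python
import math

def softmax(values):
    # same formula as gr_libs.ml.utils.math.softmax
    return [math.exp(q) / sum(math.exp(_q) for _q in values) for q in values]

probs = softmax([1.0, 2.0, 3.0])
print([round(p, 3) for p in probs])  # [0.09, 0.245, 0.665]
print(round(sum(probs), 10))         # 1.0
```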
gr_libs/ml/utils/other.py
CHANGED