gr-libs 0.1.6.post1__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff compares two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
Files changed (64)
  1. evaluation/analyze_results_cross_alg_cross_domain.py +236 -246
  2. evaluation/create_minigrid_map_image.py +10 -6
  3. evaluation/file_system.py +16 -5
  4. evaluation/generate_experiments_results.py +123 -74
  5. evaluation/generate_experiments_results_new_ver1.py +227 -243
  6. evaluation/generate_experiments_results_new_ver2.py +317 -317
  7. evaluation/generate_task_specific_statistics_plots.py +481 -253
  8. evaluation/get_plans_images.py +41 -26
  9. evaluation/increasing_and_decreasing_.py +97 -56
  10. gr_libs/__init__.py +6 -1
  11. gr_libs/_version.py +2 -2
  12. gr_libs/environment/__init__.py +17 -9
  13. gr_libs/environment/environment.py +167 -39
  14. gr_libs/environment/utils/utils.py +22 -12
  15. gr_libs/metrics/__init__.py +5 -0
  16. gr_libs/metrics/metrics.py +76 -34
  17. gr_libs/ml/__init__.py +2 -0
  18. gr_libs/ml/agent.py +21 -6
  19. gr_libs/ml/base/__init__.py +1 -1
  20. gr_libs/ml/base/rl_agent.py +13 -10
  21. gr_libs/ml/consts.py +1 -1
  22. gr_libs/ml/neural/deep_rl_learner.py +433 -352
  23. gr_libs/ml/neural/utils/__init__.py +1 -1
  24. gr_libs/ml/neural/utils/dictlist.py +3 -3
  25. gr_libs/ml/neural/utils/penv.py +5 -2
  26. gr_libs/ml/planner/mcts/mcts_model.py +524 -302
  27. gr_libs/ml/planner/mcts/utils/__init__.py +1 -1
  28. gr_libs/ml/planner/mcts/utils/node.py +11 -7
  29. gr_libs/ml/planner/mcts/utils/tree.py +14 -10
  30. gr_libs/ml/sequential/__init__.py +1 -1
  31. gr_libs/ml/sequential/lstm_model.py +256 -175
  32. gr_libs/ml/tabular/state.py +7 -7
  33. gr_libs/ml/tabular/tabular_q_learner.py +123 -73
  34. gr_libs/ml/tabular/tabular_rl_agent.py +20 -19
  35. gr_libs/ml/utils/__init__.py +8 -2
  36. gr_libs/ml/utils/format.py +78 -70
  37. gr_libs/ml/utils/math.py +2 -1
  38. gr_libs/ml/utils/other.py +1 -1
  39. gr_libs/ml/utils/storage.py +95 -28
  40. gr_libs/problems/consts.py +1549 -1227
  41. gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +145 -80
  42. gr_libs/recognizer/graml/gr_dataset.py +209 -110
  43. gr_libs/recognizer/graml/graml_recognizer.py +431 -231
  44. gr_libs/recognizer/recognizer.py +38 -27
  45. gr_libs/recognizer/utils/__init__.py +1 -1
  46. gr_libs/recognizer/utils/format.py +8 -3
  47. {gr_libs-0.1.6.post1.dist-info → gr_libs-0.1.8.dist-info}/METADATA +1 -1
  48. gr_libs-0.1.8.dist-info/RECORD +70 -0
  49. {gr_libs-0.1.6.post1.dist-info → gr_libs-0.1.8.dist-info}/WHEEL +1 -1
  50. {gr_libs-0.1.6.post1.dist-info → gr_libs-0.1.8.dist-info}/top_level.txt +0 -1
  51. tests/test_gcdraco.py +10 -0
  52. tests/test_graml.py +8 -4
  53. tests/test_graql.py +2 -1
  54. tutorials/gcdraco_panda_tutorial.py +66 -0
  55. tutorials/gcdraco_parking_tutorial.py +61 -0
  56. tutorials/graml_minigrid_tutorial.py +42 -12
  57. tutorials/graml_panda_tutorial.py +35 -14
  58. tutorials/graml_parking_tutorial.py +37 -19
  59. tutorials/graml_point_maze_tutorial.py +33 -13
  60. tutorials/graql_minigrid_tutorial.py +31 -15
  61. CI/README.md +0 -12
  62. CI/docker_build_context/Dockerfile +0 -15
  63. gr_libs/recognizer/recognizer_doc.md +0 -61
  64. gr_libs-0.1.6.post1.dist-info/RECORD +0 -70
gr_libs/ml/tabular/tabular_q_learner.py CHANGED
@@ -13,7 +13,7 @@ from typing import Any
  from random import Random
  from typing import List, Iterable
  from gymnasium.error import InvalidAction
- from gr_libs.environment.environment import QLEARNING, MinigridProperty
+ from gr_libs.environment.environment import QLEARNING, EnvProperty
  from gr_libs.ml.tabular import TabularState
  from gr_libs.ml.tabular.tabular_rl_agent import TabularRLAgent
  from gr_libs.ml.utils import get_agent_model_dir, random_subset_with_order, softmax
@@ -27,21 +27,23 @@ class TabularQLearner(TabularRLAgent):
  MODEL_FILE_NAME = r"tabular_model.txt"
  CONF_FILE = r"conf.pkl"

- def __init__(self,
- domain_name: str,
- problem_name: str,
- algorithm: str,
- num_timesteps: int,
- decaying_eps: bool = True,
- eps: float = 1.0,
- alpha: float = 0.5,
- decay: float = 0.000002,
- gamma: float = 0.9,
- rand: Random = Random(),
- learning_rate: float = 0.001,
- check_partial_goals: bool = True,
- valid_only: bool = False
- ):
+ def __init__(
+ self,
+ domain_name: str,
+ problem_name: str,
+ env_prop: EnvProperty,
+ algorithm: str,
+ num_timesteps: int,
+ decaying_eps: bool = True,
+ eps: float = 1.0,
+ alpha: float = 0.5,
+ decay: float = 0.000002,
+ gamma: float = 0.9,
+ rand: Random = Random(),
+ learning_rate: float = 0.001,
+ check_partial_goals: bool = True,
+ valid_only: bool = False,
+ ):
  super().__init__(
  domain_name=domain_name,
  problem_name=problem_name,
@@ -52,14 +54,23 @@ class TabularQLearner(TabularRLAgent):
  decay=decay,
  gamma=gamma,
  rand=rand,
- learning_rate=learning_rate
+ learning_rate=learning_rate,
  )
- assert algorithm == QLEARNING, f"algorithm {algorithm} is not supported by {self.__class__.__name__}"
+ assert (
+ algorithm == QLEARNING
+ ), f"algorithm {algorithm} is not supported by {self.__class__.__name__}"
+ self.env_prop = env_prop
  self.valid_only = valid_only
  self.check_partial_goals = check_partial_goals
  self.goal_literals_achieved = set()
- self.model_directory = get_agent_model_dir(domain_name=domain_name, model_name=problem_name, class_name=self.class_name())
- self.model_file_path = os.path.join(self.model_directory, TabularQLearner.MODEL_FILE_NAME)
+ self.model_directory = get_agent_model_dir(
+ domain_name=domain_name,
+ model_name=problem_name,
+ class_name=self.class_name(),
+ )
+ self.model_file_path = os.path.join(
+ self.model_directory, TabularQLearner.MODEL_FILE_NAME
+ )
  self._conf_file = os.path.join(self.model_directory, TabularQLearner.CONF_FILE)

  self._learned_episodes = 0
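Note for readers: `env_prop` is now a required constructor argument. The sketch below is illustrative only and not part of the diff; the `MinigridProperty(env_name)` call and the `num_timesteps` value are assumptions, while the environment name is taken from the removed `__main__` block further down.

```python
# Hypothetical usage of the new TabularQLearner signature (env_prop required).
# MinigridProperty's constructor arguments are an assumption for illustration.
from gr_libs.environment.environment import QLEARNING, MinigridProperty
from gr_libs.ml.tabular.tabular_q_learner import TabularQLearner

env_name = "MiniGrid-LavaCrossingS9N2-DynamicGoal-1x7-v0"
agent = TabularQLearner(
    domain_name="minigrid",
    problem_name=env_name,
    env_prop=MinigridProperty(env_name),  # assumed constructor signature
    algorithm=QLEARNING,
    num_timesteps=100_000,
)
```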
@@ -73,12 +84,13 @@ class TabularQLearner(TabularRLAgent):
  print(f"Loading pre-existing conf file in {self._conf_file}")
  with open(self._conf_file, "rb") as f:
  conf = dill.load(file=f)
- self._learned_episodes = conf['learned_episodes']
+ self._learned_episodes = conf["learned_episodes"]

  # hyperparameters
  self.base_eps = eps
  self.patience = 400000
  if self.decaying_eps:
+
  def epsilon():
  self._c_eps = max((self.episodes - self.step) / self.episodes, 0.01)
  return self._c_eps
@@ -146,22 +158,22 @@ class TabularQLearner(TabularRLAgent):
  if not os.path.exists(directory):
  os.makedirs(directory)

- with open(path, 'wb') as f:
+ with open(path, "wb") as f:
  pickle.dump(self.q_table, f)

  def load_q_table(self, path: str):
- with open(path, 'rb') as f:
+ with open(path, "rb") as f:
  table = pickle.load(f)
  self.q_table = table

  def add_new_state(self, state: TabularState):
- self.q_table[str(state)] = [0.] * self.number_of_actions
+ self.q_table[str(state)] = [0.0] * self.number_of_actions

  def get_all_q_values(self, state: TabularState) -> List[float]:
  if str(state) in self.q_table:
  return self.q_table[str(state)]
  else:
- return [0.] * self.number_of_actions
+ return [0.0] * self.number_of_actions

  def best_action(self, state: TabularState) -> float:
  if str(state) not in self.q_table:
@@ -229,7 +241,7 @@ class TabularQLearner(TabularRLAgent):
  """
  old_q = self.get_q_value(self.last_state, self.last_action)

- td_error = - old_q
+ td_error = -old_q

  new_q = old_q + self.alpha * (reward + td_error)
  self.set_q_value(self.last_state, self.last_action, new_q)
@@ -244,14 +256,18 @@ class TabularQLearner(TabularRLAgent):
  if self._learned_episodes >= self.episodes:
  print("learned episodes is above the requsted episodes")
  return
- print(f'Using {self.__class__.__name__}')
- tq = tqdm(range(self.episodes - self._learned_episodes),
- postfix=f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}")
+ print(f"Using {self.__class__.__name__}")
+ tq = tqdm(
+ range(self.episodes - self._learned_episodes),
+ postfix=f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}",
+ )
  for n in tq:
  self.step = n
  episode_r = 0
  observation, info = self.env.reset()
- tabular_state = TabularState.gen_tabular_state(environment=self.env, observation=observation)
+ tabular_state = TabularState.gen_tabular_state(
+ environment=self.env, observation=observation
+ )
  action = self.agent_start(state=tabular_state)

  self.update_states_counter(observation_str=str(tabular_state))
@@ -264,7 +280,9 @@ class TabularQLearner(TabularRLAgent):
  done_times += 1

  # standard q-learning algorithm
- next_tabular_state = TabularState.gen_tabular_state(environment=self.env, observation=observation)
+ next_tabular_state = TabularState.gen_tabular_state(
+ environment=self.env, observation=observation
+ )
  self.update_states_counter(observation_str=str(next_tabular_state))
  action = self.agent_step(reward, next_tabular_state)
  tstep += 1
@@ -277,13 +295,16 @@ class TabularQLearner(TabularRLAgent):
  max_r = episode_r
  # print("New all time high reward:", episode_r)
  tq.set_postfix_str(
- f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}")
+ f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}"
+ )
  if (n + 1) % 100 == 0:
  tq.set_postfix_str(
- f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}")
+ f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}"
+ )
  if (n + 1) % 1000 == 0:
  tq.set_postfix_str(
- f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}")
+ f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}"
+ )
  if done_times <= 10:
  patience += 1
  if patience >= self.patience:
@@ -297,14 +318,18 @@ class TabularQLearner(TabularRLAgent):
  done_times = 0
  self.goal_literals_achieved.clear()

- print(f"number of unique states found during training:{self.get_number_of_unique_states()}")
+ print(
+ f"number of unique states found during training:{self.get_number_of_unique_states()}"
+ )
  print("finish learning and saving status")
  self.save_models_to_files()

  def exploit(self, number_of_steps=20):
  observation, info = self.env.reset()
  for step_number in range(number_of_steps):
- tabular_state = TabularState.gen_tabular_state(environment=self.env, observation=observation)
+ tabular_state = TabularState.gen_tabular_state(
+ environment=self.env, observation=observation
+ )
  action = self.policy(state=tabular_state)
  observation, reward, terminated, truncated, _ = self.env.step(action)
  done = terminated | truncated
@@ -314,16 +339,18 @@ class TabularQLearner(TabularRLAgent):

  def get_actions_probabilities(self, observation):
  obs, agent_pos = observation
- direction = obs['direction']
+ direction = obs["direction"]

  x, y = agent_pos
- tabular_state = TabularState(agent_x_position=x, agent_y_position=y, agent_direction=direction)
+ tabular_state = TabularState(
+ agent_x_position=x, agent_y_position=y, agent_direction=direction
+ )
  return softmax(self.get_all_q_values(tabular_state))

  def get_q_of_specific_cell(self, cell_key):
  cell_q_table = {}
  for i in range(4):
- key = cell_key + ':' + str(i)
+ key = cell_key + ":" + str(i)
  if key in self.q_table:
  cell_q_table[key] = self.q_table[key]
  return cell_q_table
@@ -331,15 +358,14 @@ class TabularQLearner(TabularRLAgent):
  def get_all_cells(self):
  cells = set()
  for key in self.q_table.keys():
- cell = key.split(':')[0]
+ cell = key.split(":")[0]
  cells.add(cell)
  return list(cells)

-
  def _save_conf_file(self):
  conf = {
- 'learned_episodes': self._learned_episodes,
- 'states_counter': self.states_counter
+ "learned_episodes": self._learned_episodes,
+ "states_counter": self.states_counter,
  }
  with open(self._conf_file, "wb") as f:
  dill.dump(conf, f)
@@ -347,11 +373,20 @@ class TabularQLearner(TabularRLAgent):
  def save_models_to_files(self):
  self.save_q_table(path=self.model_file_path)
  self._save_conf_file()
-
+
  def simplify_observation(self, observation):
- return [(obs['direction'], agent_pos_x, agent_pos_y, action) for ((obs, (agent_pos_x, agent_pos_y)), action) in observation] # list of tuples, each tuple the sample
-
- def generate_observation(self, action_selection_method: MethodType, random_optimalism, save_fig = False, fig_path: str=None, env_prop=None):
+ return [
+ (obs["direction"], agent_pos_x, agent_pos_y, action)
+ for ((obs, (agent_pos_x, agent_pos_y)), action) in observation
+ ] # list of tuples, each tuple the sample
+
+ def generate_observation(
+ self,
+ action_selection_method: MethodType,
+ random_optimalism,
+ save_fig=False,
+ fig_path: str = None,
+ ):
  """
  Generate a single observation given a list of agents

@@ -363,26 +398,32 @@ class TabularQLearner(TabularRLAgent):
  list: A list of state-action pairs representing the generated observation.

  Notes:
- The function randomly selects an agent from the given list and generates a sequence of state-action pairs
- based on the Q-table of the selected agent. The action selection is stochastic, where each action is
+ The function randomly selects an agent from the given list and generates a sequence of state-action pairs
+ based on the Q-table of the selected agent. The action selection is stochastic, where each action is
  selected based on the probability distribution defined by the Q-values in the Q-table.

- The generated sequence terminates when a maximum number of steps is reached or when the environment
+ The generated sequence terminates when a maximum number of steps is reached or when the environment
  episode terminates.
  """
  if save_fig == False:
- assert fig_path == None, "You can't specify a vid path when you don't even save the figure."
+ assert (
+ fig_path == None
+ ), "You can't specify a vid path when you don't even save the figure."
  else:
- assert fig_path != None, "You must specify a vid path when you save the figure."
+ assert (
+ fig_path != None
+ ), "You must specify a vid path when you save the figure."
  obs, _ = self.env.reset()
  MAX_STEPS = 32
  done = False
  steps = []
  for step_index in range(MAX_STEPS):
  x, y = self.env.unwrapped.agent_pos
- str_state = "({},{}):{}".format(x, y, obs['direction'])
+ str_state = "({},{}):{}".format(x, y, obs["direction"])
  relevant_actions_idx = 3
- action_probs = self.q_table[str_state][:relevant_actions_idx] / np.sum(self.q_table[str_state][:relevant_actions_idx]) # Normalize probabilities
+ action_probs = self.q_table[str_state][:relevant_actions_idx] / np.sum(
+ self.q_table[str_state][:relevant_actions_idx]
+ ) # Normalize probabilities
  if step_index == 0 and random_optimalism:
  # print("in 1st step in generating plan and got random optimalism.")
  std_dev = np.std(action_probs)
@@ -398,7 +439,8 @@ class TabularQLearner(TabularRLAgent):
  assert reward >= 0
  action = 2
  step_index += 1
- else: action = action_selection_method(action_probs)
+ else:
+ action = action_selection_method(action_probs)
  else:
  action = action_selection_method(action_probs)
  steps.append(((obs, self.env.unwrapped.agent_pos), action))
@@ -408,16 +450,26 @@ class TabularQLearner(TabularRLAgent):
  if done:
  break

- #assert len(steps) >= 2
+ # assert len(steps) >= 2
  if save_fig:
  sequence = [pos for ((state, pos), action) in steps]
- #print(f"sequence to {self.problem_name} is:\n\t{steps}\ngenerating image at {img_path}.")
+ # print(f"sequence to {self.problem_name} is:\n\t{steps}\ngenerating image at {img_path}.")
  print(f"generating sequence image at {fig_path}.")
- env_prop.create_sequence_image(sequence, fig_path, self.problem_name) # TODO change that assumption, cannot assume this is minigrid env
+ self.env_prop.create_sequence_image(
+ sequence, fig_path, self.problem_name
+ ) # TODO change that assumption, cannot assume this is minigrid env

  return steps

- def generate_partial_observation(self, action_selection_method: MethodType, percentage: float, save_fig = False, is_consecutive = True, random_optimalism=True, fig_path=None):
+ def generate_partial_observation(
+ self,
+ action_selection_method: MethodType,
+ percentage: float,
+ save_fig=False,
+ is_consecutive=True,
+ random_optimalism=True,
+ fig_path=None,
+ ):
  """
  Generate a single observation given a list of agents

@@ -429,25 +481,23 @@ class TabularQLearner(TabularRLAgent):
  list: A list of state-action pairs representing the generated observation.

  Notes:
- The function randomly selects an agent from the given list and generates a sequence of state-action pairs
- based on the Q-table of the selected agent. The action selection is stochastic, where each action is
+ The function randomly selects an agent from the given list and generates a sequence of state-action pairs
+ based on the Q-table of the selected agent. The action selection is stochastic, where each action is
  selected based on the probability distribution defined by the Q-values in the Q-table.

- The generated sequence terminates when a maximum number of steps is reached or when the environment
+ The generated sequence terminates when a maximum number of steps is reached or when the environment
  episode terminates.
  """

- steps = self.generate_observation(action_selection_method=action_selection_method, random_optimalism=random_optimalism, save_fig=save_fig, fig_path=fig_path) # steps are a full observation
- result = random_subset_with_order(steps, (int)(percentage * len(steps)), is_consecutive)
+ steps = self.generate_observation(
+ action_selection_method=action_selection_method,
+ random_optimalism=random_optimalism,
+ save_fig=save_fig,
+ fig_path=fig_path,
+ ) # steps are a full observation
+ result = random_subset_with_order(
+ steps, (int)(percentage * len(steps)), is_consecutive
+ )
  if percentage >= 0.8:
  assert len(result) > 2
  return result
-
- if __name__ == "__main__":
- from gr_libs.metrics.metrics import greedy_selection
- import gr_envs # to register everything
- agent = TabularQLearner(domain_name="minigrid", problem_name="MiniGrid-LavaCrossingS9N2-DynamicGoal-1x7-v0")
- agent.generate_observation(greedy_selection, True, True)
-
- # python experiments.py --recognizer graml --domain point_maze --task L5 --partial_obs_type continuing --point_maze_env obstacles --collect_stats --inference_same_seq_len
-
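With the removal of the `env_prop` parameter from `generate_observation` (the agent now uses `self.env_prop` internally), a rough, non-authoritative sketch of driving the refactored observation API follows; it continues the constructor sketch above, assumes the agent has already been trained, and reuses the `greedy_selection` import from the removed `__main__` block. The 0.5 fraction is arbitrary.

```python
# Sketch only: full and partial observation sequences from a trained agent.
from gr_libs.metrics.metrics import greedy_selection

full_obs = agent.generate_observation(
    action_selection_method=greedy_selection,
    random_optimalism=True,
)
partial_obs = agent.generate_partial_observation(
    action_selection_method=greedy_selection,
    percentage=0.5,          # keep half of the trace
    is_consecutive=True,     # take a consecutive prefix rather than a sparse subset
)
```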
gr_libs/ml/tabular/tabular_rl_agent.py CHANGED
@@ -15,18 +15,19 @@ class TabularRLAgent(RLAgent):
  recommended as development goes on.
  """

- def __init__(self,
- domain_name: str,
- problem_name: str,
- episodes: int,
- decaying_eps: bool,
- eps: float,
- alpha: float,
- decay: float,
- gamma: float,
- rand: Random,
- learning_rate
- ):
+ def __init__(
+ self,
+ domain_name: str,
+ problem_name: str,
+ episodes: int,
+ decaying_eps: bool,
+ eps: float,
+ alpha: float,
+ decay: float,
+ gamma: float,
+ rand: Random,
+ learning_rate,
+ ):
  super().__init__(
  episodes=episodes,
  decaying_eps=decaying_eps,
@@ -34,7 +35,7 @@ class TabularRLAgent(RLAgent):
  learning_rate=learning_rate,
  gamma=gamma,
  domain_name=domain_name,
- problem_name=problem_name
+ problem_name=problem_name,
  )
  self.env = gym.make(id=problem_name)
  self.actions = self.env.unwrapped.actions
@@ -87,11 +88,11 @@ class TabularRLAgent(RLAgent):
  @abstractmethod
  def policy(self, state: State) -> Any:
  """The action for the specified state under the currently learned policy
- (unlike agent_step, this does not update the policy using state as a sample.
- Args:
- state (Any): the state observation from the environment
- Returns:
- The action prescribed for that state
+ (unlike agent_step, this does not update the policy using state as a sample.
+ Args:
+ state (Any): the state observation from the environment
+ Returns:
+ The action prescribed for that state
  """
  pass

@@ -122,5 +123,5 @@ class TabularRLAgent(RLAgent):

  Returns:
  Any: [description]
- """""
+ """ ""
  return self.softmax_policy(state)
gr_libs/ml/utils/__init__.py CHANGED
@@ -1,6 +1,12 @@
- #from .agent import *
+ # from .agent import *
  from .env import make_env
- from .format import Vocabulary, preprocess_images, preprocess_texts, get_obss_preprocessor, random_subset_with_order
+ from .format import (
+ Vocabulary,
+ preprocess_images,
+ preprocess_texts,
+ get_obss_preprocessor,
+ random_subset_with_order,
+ )
  from .other import device, seed, synthesize
  from .storage import *
  from .math import softmax
gr_libs/ml/utils/format.py CHANGED
@@ -5,96 +5,104 @@ import gr_libs.ml
  import gymnasium as gym
  import random

- def get_obss_preprocessor(obs_space):
- # Check if obs_space is an image space
- if isinstance(obs_space, gym.spaces.Box):
- obs_space = {"image": obs_space.shape}

- def preprocess_obss(obss, device=None):
- return ml.DictList({
- "image": preprocess_images(obss, device=device)
- })
+ def get_obss_preprocessor(obs_space):
+ # Check if obs_space is an image space
+ if isinstance(obs_space, gym.spaces.Box):
+ obs_space = {"image": obs_space.shape}

- # Check if it is a MiniGrid observation space
- elif isinstance(obs_space, gym.spaces.Dict) and "image" in obs_space.spaces.keys():
- obs_space = {"image": obs_space.spaces["image"].shape, "text": 100}
+ def preprocess_obss(obss, device=None):
+ return ml.DictList({"image": preprocess_images(obss, device=device)})

- vocab = Vocabulary(obs_space["text"])
+ # Check if it is a MiniGrid observation space
+ elif isinstance(obs_space, gym.spaces.Dict) and "image" in obs_space.spaces.keys():
+ obs_space = {"image": obs_space.spaces["image"].shape, "text": 100}

- def preprocess_obss(obss, device=None):
- return ml.DictList({
- "image": preprocess_images([obs["image"] for obs in obss], device=device),
- "text": preprocess_texts([obs["mission"] for obs in obss], vocab, device=device)
- })
+ vocab = Vocabulary(obs_space["text"])

- preprocess_obss.vocab = vocab
+ def preprocess_obss(obss, device=None):
+ return ml.DictList(
+ {
+ "image": preprocess_images(
+ [obs["image"] for obs in obss], device=device
+ ),
+ "text": preprocess_texts(
+ [obs["mission"] for obs in obss], vocab, device=device
+ ),
+ }
+ )

- # Check if it is a MiniGrid observation space
- elif isinstance(obs_space, gym.spaces.Dict) and "observation" in obs_space.spaces.keys():
- obs_space = {"observation": obs_space.spaces["observation"].shape}
+ preprocess_obss.vocab = vocab

- def preprocess_obss(obss, device=None):
- return ml.DictList({
- "observation": preprocess_images(obss, device=device)
- })
+ # Check if it is a MiniGrid observation space
+ elif (
+ isinstance(obs_space, gym.spaces.Dict)
+ and "observation" in obs_space.spaces.keys()
+ ):
+ obs_space = {"observation": obs_space.spaces["observation"].shape}

+ def preprocess_obss(obss, device=None):
+ return ml.DictList({"observation": preprocess_images(obss, device=device)})

- else:
- raise ValueError("Unknown observation space: " + str(obs_space))
+ else:
+ raise ValueError("Unknown observation space: " + str(obs_space))

- return obs_space, preprocess_obss
+ return obs_space, preprocess_obss


  def preprocess_images(images, device=None):
- # Bug of Pytorch: very slow if not first converted to numpy array
- images = numpy.array(images)
- return torch.tensor(images, device=device, dtype=torch.float)
-
-
- def random_subset_with_order(sequence, subset_size, is_consecutive = True):
- if subset_size >= len(sequence):
- return sequence
- else:
- if is_consecutive:
- indices_to_select = [i for i in range(subset_size)]
- else:
- indices_to_select = sorted(random.sample(range(len(sequence)), subset_size)) # Randomly select indices to keep
- return [sequence[i] for i in indices_to_select] # Return the elements corresponding to the selected indices
-
+ # Bug of Pytorch: very slow if not first converted to numpy array
+ images = numpy.array(images)
+ return torch.tensor(images, device=device, dtype=torch.float)
+
+
+ def random_subset_with_order(sequence, subset_size, is_consecutive=True):
+ if subset_size >= len(sequence):
+ return sequence
+ else:
+ if is_consecutive:
+ indices_to_select = [i for i in range(subset_size)]
+ else:
+ indices_to_select = sorted(
+ random.sample(range(len(sequence)), subset_size)
+ ) # Randomly select indices to keep
+ return [
+ sequence[i] for i in indices_to_select
+ ] # Return the elements corresponding to the selected indices


  def preprocess_texts(texts, vocab, device=None):
- var_indexed_texts = []
- max_text_len = 0
+ var_indexed_texts = []
+ max_text_len = 0

- for text in texts:
- tokens = re.findall("([a-z]+)", text.lower())
- var_indexed_text = numpy.array([vocab[token] for token in tokens])
- var_indexed_texts.append(var_indexed_text)
- max_text_len = max(len(var_indexed_text), max_text_len)
+ for text in texts:
+ tokens = re.findall("([a-z]+)", text.lower())
+ var_indexed_text = numpy.array([vocab[token] for token in tokens])
+ var_indexed_texts.append(var_indexed_text)
+ max_text_len = max(len(var_indexed_text), max_text_len)

- indexed_texts = numpy.zeros((len(texts), max_text_len))
+ indexed_texts = numpy.zeros((len(texts), max_text_len))

- for i, indexed_text in enumerate(var_indexed_texts):
- indexed_texts[i, :len(indexed_text)] = indexed_text
+ for i, indexed_text in enumerate(var_indexed_texts):
+ indexed_texts[i, : len(indexed_text)] = indexed_text

- return torch.tensor(indexed_texts, device=device, dtype=torch.long)
+ return torch.tensor(indexed_texts, device=device, dtype=torch.long)


  class Vocabulary:
- """A mapping from tokens to ids with a capacity of `max_size` words.
- It can be saved in a `vocab.json` file."""
-
- def __init__(self, max_size):
- self.max_size = max_size
- self.vocab = {}
-
- def load_vocab(self, vocab):
- self.vocab = vocab
-
- def __getitem__(self, token):
- if not token in self.vocab.keys():
- if len(self.vocab) >= self.max_size:
- raise ValueError("Maximum vocabulary capacity reached")
- self.vocab[token] = len(self.vocab) + 1
- return self.vocab[token]
+ """A mapping from tokens to ids with a capacity of `max_size` words.
+ It can be saved in a `vocab.json` file."""
+
+ def __init__(self, max_size):
+ self.max_size = max_size
+ self.vocab = {}
+
+ def load_vocab(self, vocab):
+ self.vocab = vocab
+
+ def __getitem__(self, token):
+ if not token in self.vocab.keys():
+ if len(self.vocab) >= self.max_size:
+ raise ValueError("Maximum vocabulary capacity reached")
+ self.vocab[token] = len(self.vocab) + 1
+ return self.vocab[token]
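The helpers above keep their signatures through the reformatting, so a small self-contained check can exercise them; the snippet below is illustrative only (the inputs are made up) and relies on the re-exports shown in the `gr_libs/ml/utils/__init__.py` diff.

```python
# Illustrative only: exercising helpers re-exported via gr_libs.ml.utils,
# using the signatures visible in this diff.
from gr_libs.ml.utils import Vocabulary, preprocess_texts, random_subset_with_order

trace = [("s0", 0), ("s1", 1), ("s2", 2), ("s3", 0), ("s4", 1)]
prefix = random_subset_with_order(trace, 3, is_consecutive=True)   # first 3 items
sparse = random_subset_with_order(trace, 3, is_consecutive=False)  # 3 items, original order kept

vocab = Vocabulary(max_size=100)
batch = preprocess_texts(["go to the red door", "go to the green key"], vocab)
print(prefix, sparse, batch.shape)  # batch is a (2, max_text_len) LongTensor
```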
gr_libs/ml/utils/math.py CHANGED
@@ -1,6 +1,7 @@
  import math
  from typing import Callable, Generator, List

+
  def softmax(values: List[float]) -> List[float]:
  """Computes softmax probabilities for an array of values
  TODO We should probably use numpy arrays here
@@ -10,4 +11,4 @@ def softmax(values: List[float]) -> List[float]:
  Returns:
  np.array: softmax probabilities
  """
- return [(math.exp(q)) / sum([math.exp(_q) for _q in values]) for q in values]
+ return [(math.exp(q)) / sum([math.exp(_q) for _q in values]) for q in values]
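As a quick sanity check of the helper above (illustrative, not part of the diff):

```python
# The softmax helper returns a plain list of probabilities that sums to 1.
from gr_libs.ml.utils import softmax  # re-exported in the __init__ diff above

probs = softmax([1.0, 2.0, 3.0])
assert abs(sum(probs) - 1.0) < 1e-9
print(probs)  # roughly [0.090, 0.245, 0.665]
```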
gr_libs/ml/utils/other.py CHANGED
@@ -21,4 +21,4 @@ def synthesize(array):
  d["std"] = numpy.std(array)
  d["min"] = numpy.amin(array)
  d["max"] = numpy.amax(array)
- return d
+ return d
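For completeness, a hedged usage sketch of `synthesize`; only the std/min/max keys visible in this hunk are relied on (the full function may set additional keys such as a mean).

```python
# Sketch: summarizing a list of episode returns with synthesize.
from gr_libs.ml.utils import synthesize

stats = synthesize([0.1, 0.5, 0.9])
print(stats["min"], stats["max"], stats["std"])
```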