gr-libs 0.1.7.post0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. gr_libs/__init__.py +4 -1
  2. gr_libs/_evaluation/__init__.py +1 -0
  3. gr_libs/_evaluation/_analyze_results_cross_alg_cross_domain.py +260 -0
  4. gr_libs/_evaluation/_generate_experiments_results.py +141 -0
  5. gr_libs/_evaluation/_generate_task_specific_statistics_plots.py +497 -0
  6. gr_libs/_evaluation/_get_plans_images.py +61 -0
  7. gr_libs/_evaluation/_increasing_and_decreasing_.py +106 -0
  8. gr_libs/_version.py +2 -2
  9. gr_libs/all_experiments.py +294 -0
  10. gr_libs/environment/__init__.py +30 -9
  11. gr_libs/environment/_utils/utils.py +27 -0
  12. gr_libs/environment/environment.py +417 -54
  13. gr_libs/metrics/__init__.py +7 -0
  14. gr_libs/metrics/metrics.py +231 -54
  15. gr_libs/ml/__init__.py +2 -5
  16. gr_libs/ml/agent.py +21 -6
  17. gr_libs/ml/base/__init__.py +3 -1
  18. gr_libs/ml/base/rl_agent.py +81 -13
  19. gr_libs/ml/consts.py +1 -1
  20. gr_libs/ml/neural/__init__.py +1 -3
  21. gr_libs/ml/neural/deep_rl_learner.py +619 -378
  22. gr_libs/ml/neural/utils/__init__.py +1 -2
  23. gr_libs/ml/neural/utils/dictlist.py +3 -3
  24. gr_libs/ml/planner/mcts/{utils → _utils}/__init__.py +1 -1
  25. gr_libs/ml/planner/mcts/{utils → _utils}/node.py +11 -7
  26. gr_libs/ml/planner/mcts/{utils → _utils}/tree.py +15 -11
  27. gr_libs/ml/planner/mcts/mcts_model.py +571 -312
  28. gr_libs/ml/sequential/__init__.py +0 -1
  29. gr_libs/ml/sequential/_lstm_model.py +270 -0
  30. gr_libs/ml/tabular/__init__.py +1 -3
  31. gr_libs/ml/tabular/state.py +7 -7
  32. gr_libs/ml/tabular/tabular_q_learner.py +150 -82
  33. gr_libs/ml/tabular/tabular_rl_agent.py +42 -28
  34. gr_libs/ml/utils/__init__.py +2 -3
  35. gr_libs/ml/utils/format.py +28 -97
  36. gr_libs/ml/utils/math.py +5 -3
  37. gr_libs/ml/utils/other.py +3 -3
  38. gr_libs/ml/utils/storage.py +88 -81
  39. gr_libs/odgr_executor.py +268 -0
  40. gr_libs/problems/consts.py +1549 -1227
  41. gr_libs/recognizer/_utils/__init__.py +0 -0
  42. gr_libs/recognizer/_utils/format.py +18 -0
  43. gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +233 -88
  44. gr_libs/recognizer/graml/_gr_dataset.py +233 -0
  45. gr_libs/recognizer/graml/graml_recognizer.py +586 -252
  46. gr_libs/recognizer/recognizer.py +90 -30
  47. gr_libs/tutorials/draco_panda_tutorial.py +58 -0
  48. gr_libs/tutorials/draco_parking_tutorial.py +56 -0
  49. gr_libs/tutorials/gcdraco_panda_tutorial.py +62 -0
  50. gr_libs/tutorials/gcdraco_parking_tutorial.py +57 -0
  51. gr_libs/tutorials/graml_minigrid_tutorial.py +64 -0
  52. gr_libs/tutorials/graml_panda_tutorial.py +57 -0
  53. gr_libs/tutorials/graml_parking_tutorial.py +52 -0
  54. gr_libs/tutorials/graml_point_maze_tutorial.py +60 -0
  55. gr_libs/tutorials/graql_minigrid_tutorial.py +50 -0
  56. {gr_libs-0.1.7.post0.dist-info → gr_libs-0.2.2.dist-info}/METADATA +84 -29
  57. gr_libs-0.2.2.dist-info/RECORD +71 -0
  58. {gr_libs-0.1.7.post0.dist-info → gr_libs-0.2.2.dist-info}/WHEEL +1 -1
  59. gr_libs-0.2.2.dist-info/top_level.txt +2 -0
  60. tests/test_draco.py +14 -0
  61. tests/test_gcdraco.py +10 -0
  62. tests/test_graml.py +12 -8
  63. tests/test_graql.py +3 -2
  64. evaluation/analyze_results_cross_alg_cross_domain.py +0 -277
  65. evaluation/create_minigrid_map_image.py +0 -34
  66. evaluation/file_system.py +0 -42
  67. evaluation/generate_experiments_results.py +0 -92
  68. evaluation/generate_experiments_results_new_ver1.py +0 -254
  69. evaluation/generate_experiments_results_new_ver2.py +0 -331
  70. evaluation/generate_task_specific_statistics_plots.py +0 -272
  71. evaluation/get_plans_images.py +0 -47
  72. evaluation/increasing_and_decreasing_.py +0 -63
  73. gr_libs/environment/utils/utils.py +0 -17
  74. gr_libs/ml/neural/utils/penv.py +0 -57
  75. gr_libs/ml/sequential/lstm_model.py +0 -192
  76. gr_libs/recognizer/graml/gr_dataset.py +0 -134
  77. gr_libs/recognizer/utils/__init__.py +0 -1
  78. gr_libs/recognizer/utils/format.py +0 -13
  79. gr_libs-0.1.7.post0.dist-info/RECORD +0 -67
  80. gr_libs-0.1.7.post0.dist-info/top_level.txt +0 -4
  81. tutorials/graml_minigrid_tutorial.py +0 -34
  82. tutorials/graml_panda_tutorial.py +0 -41
  83. tutorials/graml_parking_tutorial.py +0 -39
  84. tutorials/graml_point_maze_tutorial.py +0 -39
  85. tutorials/graql_minigrid_tutorial.py +0 -34
  86. /gr_libs/environment/{utils → _utils}/__init__.py +0 -0
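Taken together, these renames fold the old top-level evaluation/ and tutorials/ directories into the installed gr_libs package and mark helper packages as private (utils → _utils, evaluation → _evaluation), alongside new top-level modules such as all_experiments.py and odgr_executor.py. A minimal sketch of the resulting import paths, assuming only the file locations listed above (nothing about the modules' contents is visible in this diff):

    # Hedged sketch: module paths implied by the 0.2.2 file list above.
    import gr_libs.all_experiments                          # new module gr_libs/all_experiments.py
    import gr_libs.odgr_executor                            # new module gr_libs/odgr_executor.py
    from gr_libs.tutorials import graml_minigrid_tutorial   # moved from tutorials/graml_minigrid_tutorial.py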

gr_libs/ml/tabular/tabular_q_learner.py
@@ -1,19 +1,18 @@
-# Don't import stuff from metrics! it's a higher level module.
+""" implementation of q-learning """
+
 import os.path
 import pickle
-import random
+from collections.abc import Iterable
+from random import Random
 from types import MethodType
+from typing import Any
 
 import dill
-from gymnasium import register
 import numpy as np
-
-from tqdm import tqdm
-from typing import Any
-from random import Random
-from typing import List, Iterable
 from gymnasium.error import InvalidAction
-from gr_libs.environment.environment import QLEARNING, MinigridProperty
+from tqdm import tqdm
+
+from gr_libs.environment.environment import QLEARNING, EnvProperty
 from gr_libs.ml.tabular import TabularState
 from gr_libs.ml.tabular.tabular_rl_agent import TabularRLAgent
 from gr_libs.ml.utils import get_agent_model_dir, random_subset_with_order, softmax
@@ -27,21 +26,42 @@ class TabularQLearner(TabularRLAgent):
     MODEL_FILE_NAME = r"tabular_model.txt"
     CONF_FILE = r"conf.pkl"
 
-    def __init__(self,
-                 domain_name: str,
-                 problem_name: str,
-                 algorithm: str,
-                 num_timesteps: int,
-                 decaying_eps: bool = True,
-                 eps: float = 1.0,
-                 alpha: float = 0.5,
-                 decay: float = 0.000002,
-                 gamma: float = 0.9,
-                 rand: Random = Random(),
-                 learning_rate: float = 0.001,
-                 check_partial_goals: bool = True,
-                 valid_only: bool = False
-                 ):
+    def __init__(
+        self,
+        domain_name: str,
+        problem_name: str,
+        env_prop: EnvProperty,
+        algorithm: str,
+        num_timesteps: int,
+        decaying_eps: bool = True,
+        eps: float = 1.0,
+        alpha: float = 0.5,
+        decay: float = 0.000002,
+        gamma: float = 0.9,
+        rand: Random = Random(),
+        learning_rate: float = 0.001,
+        check_partial_goals: bool = True,
+        valid_only: bool = False,
+    ):
+        """
+        Initialize a TabularQLearner object.
+
+        Args:
+            domain_name (str): The name of the domain.
+            problem_name (str): The name of the problem.
+            env_prop (EnvProperty): The environment properties.
+            algorithm (str): The algorithm to use.
+            num_timesteps (int): The number of timesteps.
+            decaying_eps (bool, optional): Whether to use decaying epsilon. Defaults to True.
+            eps (float, optional): The initial epsilon value. Defaults to 1.0.
+            alpha (float, optional): The learning rate. Defaults to 0.5.
+            decay (float, optional): The decay rate. Defaults to 0.000002.
+            gamma (float, optional): The discount factor. Defaults to 0.9.
+            rand (Random, optional): The random number generator. Defaults to Random().
+            learning_rate (float, optional): The learning rate. Defaults to 0.001.
+            check_partial_goals (bool, optional): Whether to check partial goals. Defaults to True.
+            valid_only (bool, optional): Whether to use valid goals only. Defaults to False.
+        """
         super().__init__(
             domain_name=domain_name,
             problem_name=problem_name,
@@ -52,14 +72,23 @@ class TabularQLearner(TabularRLAgent):
             decay=decay,
             gamma=gamma,
             rand=rand,
-            learning_rate=learning_rate
+            learning_rate=learning_rate,
         )
-        assert algorithm == QLEARNING, f"algorithm {algorithm} is not supported by {self.__class__.__name__}"
+        assert (
+            algorithm == QLEARNING
+        ), f"algorithm {algorithm} is not supported by {self.__class__.__name__}"
+        self.env_prop = env_prop
         self.valid_only = valid_only
         self.check_partial_goals = check_partial_goals
         self.goal_literals_achieved = set()
-        self.model_directory = get_agent_model_dir(domain_name=domain_name, model_name=problem_name, class_name=self.class_name())
-        self.model_file_path = os.path.join(self.model_directory, TabularQLearner.MODEL_FILE_NAME)
+        self.model_directory = get_agent_model_dir(
+            domain_name=domain_name,
+            model_name=problem_name,
+            class_name=self.class_name(),
+        )
+        self.model_file_path = os.path.join(
+            self.model_directory, TabularQLearner.MODEL_FILE_NAME
+        )
         self._conf_file = os.path.join(self.model_directory, TabularQLearner.CONF_FILE)
 
         self._learned_episodes = 0
@@ -73,12 +102,13 @@ class TabularQLearner(TabularRLAgent):
             print(f"Loading pre-existing conf file in {self._conf_file}")
             with open(self._conf_file, "rb") as f:
                 conf = dill.load(file=f)
-                self._learned_episodes = conf['learned_episodes']
+                self._learned_episodes = conf["learned_episodes"]
 
         # hyperparameters
         self.base_eps = eps
         self.patience = 400000
         if self.decaying_eps:
+
             def epsilon():
                 self._c_eps = max((self.episodes - self.step) / self.episodes, 0.01)
                 return self._c_eps
@@ -146,22 +176,22 @@ class TabularQLearner(TabularRLAgent):
         if not os.path.exists(directory):
             os.makedirs(directory)
 
-        with open(path, 'wb') as f:
+        with open(path, "wb") as f:
             pickle.dump(self.q_table, f)
 
     def load_q_table(self, path: str):
-        with open(path, 'rb') as f:
+        with open(path, "rb") as f:
             table = pickle.load(f)
         self.q_table = table
 
     def add_new_state(self, state: TabularState):
-        self.q_table[str(state)] = [0.] * self.number_of_actions
+        self.q_table[str(state)] = [0.0] * self.number_of_actions
 
-    def get_all_q_values(self, state: TabularState) -> List[float]:
+    def get_all_q_values(self, state: TabularState) -> list[float]:
         if str(state) in self.q_table:
             return self.q_table[str(state)]
         else:
-            return [0.] * self.number_of_actions
+            return [0.0] * self.number_of_actions
 
     def best_action(self, state: TabularState) -> float:
         if str(state) not in self.q_table:
@@ -229,7 +259,7 @@ class TabularQLearner(TabularRLAgent):
         """
         old_q = self.get_q_value(self.last_state, self.last_action)
 
-        td_error = - old_q
+        td_error = -old_q
 
         new_q = old_q + self.alpha * (reward + td_error)
         self.set_q_value(self.last_state, self.last_action, new_q)
@@ -244,14 +274,18 @@ class TabularQLearner(TabularRLAgent):
         if self._learned_episodes >= self.episodes:
             print("learned episodes is above the requsted episodes")
             return
-        print(f'Using {self.__class__.__name__}')
-        tq = tqdm(range(self.episodes - self._learned_episodes),
-                  postfix=f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}")
+        print(f"Using {self.__class__.__name__}")
+        tq = tqdm(
+            range(self.episodes - self._learned_episodes),
+            postfix=f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}",
+        )
         for n in tq:
             self.step = n
             episode_r = 0
             observation, info = self.env.reset()
-            tabular_state = TabularState.gen_tabular_state(environment=self.env, observation=observation)
+            tabular_state = TabularState.gen_tabular_state(
+                environment=self.env, observation=observation
+            )
             action = self.agent_start(state=tabular_state)
 
             self.update_states_counter(observation_str=str(tabular_state))
@@ -264,7 +298,9 @@ class TabularQLearner(TabularRLAgent):
                     done_times += 1
 
                 # standard q-learning algorithm
-                next_tabular_state = TabularState.gen_tabular_state(environment=self.env, observation=observation)
+                next_tabular_state = TabularState.gen_tabular_state(
+                    environment=self.env, observation=observation
+                )
                 self.update_states_counter(observation_str=str(next_tabular_state))
                 action = self.agent_step(reward, next_tabular_state)
                 tstep += 1
@@ -277,13 +313,16 @@ class TabularQLearner(TabularRLAgent):
                 max_r = episode_r
                 # print("New all time high reward:", episode_r)
                 tq.set_postfix_str(
-                    f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}")
+                    f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}"
+                )
             if (n + 1) % 100 == 0:
                 tq.set_postfix_str(
-                    f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}")
+                    f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}"
+                )
             if (n + 1) % 1000 == 0:
                 tq.set_postfix_str(
-                    f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}")
+                    f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}"
+                )
             if done_times <= 10:
                 patience += 1
                 if patience >= self.patience:
@@ -297,14 +336,18 @@ class TabularQLearner(TabularRLAgent):
                 done_times = 0
                 self.goal_literals_achieved.clear()
 
-        print(f"number of unique states found during training:{self.get_number_of_unique_states()}")
+        print(
+            f"number of unique states found during training:{self.get_number_of_unique_states()}"
+        )
         print("finish learning and saving status")
         self.save_models_to_files()
 
     def exploit(self, number_of_steps=20):
         observation, info = self.env.reset()
         for step_number in range(number_of_steps):
-            tabular_state = TabularState.gen_tabular_state(environment=self.env, observation=observation)
+            tabular_state = TabularState.gen_tabular_state(
+                environment=self.env, observation=observation
+            )
             action = self.policy(state=tabular_state)
             observation, reward, terminated, truncated, _ = self.env.step(action)
             done = terminated | truncated
@@ -314,16 +357,18 @@ class TabularQLearner(TabularRLAgent):
 
     def get_actions_probabilities(self, observation):
         obs, agent_pos = observation
-        direction = obs['direction']
+        direction = obs["direction"]
 
         x, y = agent_pos
-        tabular_state = TabularState(agent_x_position=x, agent_y_position=y, agent_direction=direction)
+        tabular_state = TabularState(
+            agent_x_position=x, agent_y_position=y, agent_direction=direction
+        )
         return softmax(self.get_all_q_values(tabular_state))
 
     def get_q_of_specific_cell(self, cell_key):
         cell_q_table = {}
         for i in range(4):
-            key = cell_key + ':' + str(i)
+            key = cell_key + ":" + str(i)
             if key in self.q_table:
                 cell_q_table[key] = self.q_table[key]
         return cell_q_table
@@ -331,15 +376,14 @@ class TabularQLearner(TabularRLAgent):
     def get_all_cells(self):
         cells = set()
         for key in self.q_table.keys():
-            cell = key.split(':')[0]
+            cell = key.split(":")[0]
             cells.add(cell)
         return list(cells)
 
-
     def _save_conf_file(self):
         conf = {
-            'learned_episodes': self._learned_episodes,
-            'states_counter': self.states_counter
+            "learned_episodes": self._learned_episodes,
+            "states_counter": self.states_counter,
        }
         with open(self._conf_file, "wb") as f:
             dill.dump(conf, f)
@@ -347,11 +391,20 @@ class TabularQLearner(TabularRLAgent):
     def save_models_to_files(self):
         self.save_q_table(path=self.model_file_path)
         self._save_conf_file()
-
+
     def simplify_observation(self, observation):
-        return [(obs['direction'], agent_pos_x, agent_pos_y, action) for ((obs, (agent_pos_x, agent_pos_y)), action) in observation] # list of tuples, each tuple the sample
-
-    def generate_observation(self, action_selection_method: MethodType, random_optimalism, save_fig=False, fig_path: str=None, env_prop=None):
+        return [
+            (obs["direction"], agent_pos_x, agent_pos_y, action)
+            for ((obs, (agent_pos_x, agent_pos_y)), action) in observation
+        ] # list of tuples, each tuple the sample
+
+    def generate_observation(
+        self,
+        action_selection_method: MethodType,
+        random_optimalism,
+        save_fig=False,
+        fig_path: str = None,
+    ):
         """
         Generate a single observation given a list of agents
 
@@ -363,26 +416,32 @@ class TabularQLearner(TabularRLAgent):
         list: A list of state-action pairs representing the generated observation.
 
         Notes:
-        The function randomly selects an agent from the given list and generates a sequence of state-action pairs
-        based on the Q-table of the selected agent. The action selection is stochastic, where each action is
+        The function randomly selects an agent from the given list and generates a sequence of state-action pairs
+        based on the Q-table of the selected agent. The action selection is stochastic, where each action is
         selected based on the probability distribution defined by the Q-values in the Q-table.
 
-        The generated sequence terminates when a maximum number of steps is reached or when the environment
+        The generated sequence terminates when a maximum number of steps is reached or when the environment
         episode terminates.
         """
         if save_fig == False:
-            assert fig_path == None, "You can't specify a vid path when you don't even save the figure."
+            assert (
+                fig_path == None
+            ), "You can't specify a vid path when you don't even save the figure."
         else:
-            assert fig_path != None, "You must specify a vid path when you save the figure."
+            assert (
+                fig_path != None
+            ), "You must specify a vid path when you save the figure."
         obs, _ = self.env.reset()
         MAX_STEPS = 32
         done = False
         steps = []
         for step_index in range(MAX_STEPS):
             x, y = self.env.unwrapped.agent_pos
-            str_state = "({},{}):{}".format(x, y, obs['direction'])
+            str_state = "({},{}):{}".format(x, y, obs["direction"])
             relevant_actions_idx = 3
-            action_probs = self.q_table[str_state][:relevant_actions_idx] / np.sum(self.q_table[str_state][:relevant_actions_idx]) # Normalize probabilities
+            action_probs = self.q_table[str_state][:relevant_actions_idx] / np.sum(
+                self.q_table[str_state][:relevant_actions_idx]
+            ) # Normalize probabilities
             if step_index == 0 and random_optimalism:
                 # print("in 1st step in generating plan and got random optimalism.")
                 std_dev = np.std(action_probs)
@@ -398,7 +457,8 @@ class TabularQLearner(TabularRLAgent):
                     assert reward >= 0
                    action = 2
                     step_index += 1
-                else: action = action_selection_method(action_probs)
+                else:
+                    action = action_selection_method(action_probs)
             else:
                 action = action_selection_method(action_probs)
             steps.append(((obs, self.env.unwrapped.agent_pos), action))
@@ -408,16 +468,26 @@ class TabularQLearner(TabularRLAgent):
             if done:
                 break
 
-        #assert len(steps) >= 2
+        # assert len(steps) >= 2
         if save_fig:
             sequence = [pos for ((state, pos), action) in steps]
-            #print(f"sequence to {self.problem_name} is:\n\t{steps}\ngenerating image at {img_path}.")
+            # print(f"sequence to {self.problem_name} is:\n\t{steps}\ngenerating image at {img_path}.")
             print(f"generating sequence image at {fig_path}.")
-            env_prop.create_sequence_image(sequence, fig_path, self.problem_name) # TODO change that assumption, cannot assume this is minigrid env
+            self.env_prop.create_sequence_image(
+                sequence, fig_path, self.problem_name
+            ) # TODO change that assumption, cannot assume this is minigrid env
 
         return steps
 
-    def generate_partial_observation(self, action_selection_method: MethodType, percentage: float, save_fig = False, is_consecutive = True, random_optimalism=True, fig_path=None):
+    def generate_partial_observation(
+        self,
+        action_selection_method: MethodType,
+        percentage: float,
+        save_fig=False,
+        is_consecutive=True,
+        random_optimalism=True,
+        fig_path=None,
+    ):
         """
         Generate a single observation given a list of agents
 
@@ -429,25 +499,23 @@ class TabularQLearner(TabularRLAgent):
         list: A list of state-action pairs representing the generated observation.
 
         Notes:
-        The function randomly selects an agent from the given list and generates a sequence of state-action pairs
-        based on the Q-table of the selected agent. The action selection is stochastic, where each action is
+        The function randomly selects an agent from the given list and generates a sequence of state-action pairs
+        based on the Q-table of the selected agent. The action selection is stochastic, where each action is
         selected based on the probability distribution defined by the Q-values in the Q-table.
 
-        The generated sequence terminates when a maximum number of steps is reached or when the environment
+        The generated sequence terminates when a maximum number of steps is reached or when the environment
         episode terminates.
         """
 
-        steps = self.generate_observation(action_selection_method=action_selection_method, random_optimalism=random_optimalism, save_fig=save_fig, fig_path=fig_path) # steps are a full observation
-        result = random_subset_with_order(steps, (int)(percentage * len(steps)), is_consecutive)
+        steps = self.generate_observation(
+            action_selection_method=action_selection_method,
+            random_optimalism=random_optimalism,
+            save_fig=save_fig,
+            fig_path=fig_path,
+        ) # steps are a full observation
+        result = random_subset_with_order(
+            steps, (int)(percentage * len(steps)), is_consecutive
+        )
         if percentage >= 0.8:
             assert len(result) > 2
         return result
-
-if __name__ == "__main__":
-    from gr_libs.metrics.metrics import greedy_selection
-    import gr_envs # to register everything
-    agent = TabularQLearner(domain_name="minigrid", problem_name="MiniGrid-LavaCrossingS9N2-DynamicGoal-1x7-v0")
-    agent.generate_observation(greedy_selection, True, True)
-
-    # python experiments.py --recognizer graml --domain point_maze --task L5 --partial_obs_type continuing --point_maze_env obstacles --collect_stats --inference_same_seq_len
-
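In short, TabularQLearner now takes the environment properties at construction time (env_prop: EnvProperty) and stores them on the agent, while generate_observation drops its env_prop parameter and the module-level __main__ demo is removed. A minimal usage sketch under stated assumptions: MinigridProperty and greedy_selection come from imports this diff removes, so their exact 0.2.2 locations and signatures should be verified against the installed package.

    # Hedged sketch of the 0.2.2 constructor call; MinigridProperty's constructor
    # arguments and greedy_selection's import path are assumptions.
    from gr_libs.environment.environment import QLEARNING, MinigridProperty
    from gr_libs.metrics.metrics import greedy_selection
    from gr_libs.ml.tabular.tabular_q_learner import TabularQLearner

    problem = "MiniGrid-LavaCrossingS9N2-DynamicGoal-1x7-v0"
    agent = TabularQLearner(
        domain_name="minigrid",
        problem_name=problem,
        env_prop=MinigridProperty(problem),  # new required argument in 0.2.2
        algorithm=QLEARNING,
        num_timesteps=100000,
    )
    agent.learn()  # train (or load) the Q-table before generating observations
    # generate_observation no longer accepts env_prop; sequence images use self.env_prop.
    steps = agent.generate_observation(
        action_selection_method=greedy_selection,
        random_optimalism=True,
        save_fig=False,
    )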

gr_libs/ml/tabular/tabular_rl_agent.py
@@ -1,11 +1,11 @@
-import gymnasium as gym
 from abc import abstractmethod
-from typing import Collection, Literal, Any
 from random import Random
+from typing import Any
+
+import gymnasium as gym
 import numpy as np
 
-from gr_libs.ml.base import RLAgent
-from gr_libs.ml.base import State
+from gr_libs.ml.base import RLAgent, State
 
 
 class TabularRLAgent(RLAgent):
@@ -15,18 +15,37 @@ class TabularRLAgent(RLAgent):
     recommended as development goes on.
     """
 
-    def __init__(self,
-                 domain_name: str,
-                 problem_name: str,
-                 episodes: int,
-                 decaying_eps: bool,
-                 eps: float,
-                 alpha: float,
-                 decay: float,
-                 gamma: float,
-                 rand: Random,
-                 learning_rate
-                 ):
+    def __init__(
+        self,
+        domain_name: str,
+        problem_name: str,
+        episodes: int,
+        decaying_eps: bool,
+        eps: float,
+        alpha: float,
+        decay: float,
+        gamma: float,
+        rand: Random,
+        learning_rate,
+    ):
+        """
+        Initializes a TabularRLAgent object.
+
+        Args:
+            domain_name (str): The name of the domain.
+            problem_name (str): The name of the problem.
+            episodes (int): The number of episodes to run.
+            decaying_eps (bool): Whether to use decaying epsilon.
+            eps (float): The initial epsilon value.
+            alpha (float): The learning rate.
+            decay (float): The decay rate for epsilon.
+            gamma (float): The discount factor.
+            rand (Random): The random number generator.
+            learning_rate: The learning rate.
+
+        Returns:
+            None
+        """
         super().__init__(
             episodes=episodes,
             decaying_eps=decaying_eps,
@@ -34,7 +53,7 @@ class TabularRLAgent(RLAgent):
             learning_rate=learning_rate,
             gamma=gamma,
             domain_name=domain_name,
-            problem_name=problem_name
+            problem_name=problem_name,
         )
         self.env = gym.make(id=problem_name)
         self.actions = self.env.unwrapped.actions
@@ -59,7 +78,6 @@ class TabularRLAgent(RLAgent):
         Returns:
             (int) the first action the agent takes.
         """
-        pass
 
     @abstractmethod
     def agent_step(self, reward: float, state: State) -> Any:
@@ -72,7 +90,6 @@ class TabularRLAgent(RLAgent):
         Returns:
             The action the agent is taking.
         """
-        pass
 
     @abstractmethod
     def agent_end(self, reward: float) -> Any:
@@ -82,18 +99,16 @@ class TabularRLAgent(RLAgent):
             reward (float): the reward the agent received for entering the
                             terminal state.
         """
-        pass
 
     @abstractmethod
     def policy(self, state: State) -> Any:
         """The action for the specified state under the currently learned policy
-        (unlike agent_step, this does not update the policy using state as a sample.
-        Args:
-            state (Any): the state observation from the environment
-        Returns:
-            The action prescribed for that state
+        (unlike agent_step, this does not update the policy using state as a sample.
+        Args:
+            state (Any): the state observation from the environment
+        Returns:
+            The action prescribed for that state
         """
-        pass
 
     @abstractmethod
     def softmax_policy(self, state: State) -> np.array:
@@ -105,7 +120,6 @@ class TabularRLAgent(RLAgent):
         Returns:
             np.array: probability of taking each action in self.actions given a state
         """
-        pass
 
     @abstractmethod
     def learn(self, init_threshold: int = 20):
@@ -122,5 +136,5 @@ class TabularRLAgent(RLAgent):
 
         Returns:
             Any: [description]
-        """""
+        """ ""
         return self.softmax_policy(state)
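TabularRLAgent stays an abstract base class; the removed pass statements were redundant because each abstract method already ends with a docstring. A concrete agent still has to override every @abstractmethod shown above. A minimal subclass skeleton with placeholder bodies, illustrative only:

    # Hedged sketch of the subclassing contract; the placeholder return values are
    # not meaningful policies, they only satisfy the abstract interface.
    import numpy as np
    from gr_libs.ml.tabular.tabular_rl_agent import TabularRLAgent

    class UniformTabularAgent(TabularRLAgent):
        def agent_start(self, state):
            return 0  # first action of an episode (placeholder)

        def agent_step(self, reward, state):
            return 0  # next action after observing reward and state (placeholder)

        def agent_end(self, reward):
            pass  # bookkeeping for the terminal reward

        def policy(self, state):
            return 0  # action under the learned policy, no update (placeholder)

        def softmax_policy(self, state):
            return np.ones(len(self.actions)) / len(self.actions)  # uniform probabilities

        def learn(self, init_threshold: int = 20):
            pass  # training loop goes here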

gr_libs/ml/utils/__init__.py
@@ -1,6 +1,5 @@
-#from .agent import *
 from .env import make_env
-from .format import Vocabulary, preprocess_images, preprocess_texts, get_obss_preprocessor, random_subset_with_order
+from .format import random_subset_with_order
+from .math import softmax
 from .other import device, seed, synthesize
 from .storage import *
-from .math import softmax
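After this change the gr_libs.ml.utils package re-exports a much smaller surface: the text and observation preprocessing helpers are no longer re-exported here (and format.py itself shrinks in this release, so they may not exist at all anymore). What remains importable from the package root, per the new __init__ above plus the names pulled in by the wildcard import from .storage:

    # Grounded in the new __init__.py contents shown above.
    from gr_libs.ml.utils import (
        make_env,
        random_subset_with_order,
        softmax,
        device,
        seed,
        synthesize,
    )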